diff --git a/test_transform_image_operator_template.py b/test_transform_image_operator_template.py index 44bd3a5..e981748 100644 --- a/test_transform_image_operator_template.py +++ b/test_transform_image_operator_template.py @@ -1,8 +1,10 @@ from torchvision import transforms from transform_image_operator_template import TransformImageOperatorTemplate +from PIL import Image -def get_transformed_img(size, img_path) +def get_transformed_img(size, img_path): + img = Image.open(img_path) tfms = transforms.Compose( [ transforms.Resize(size), @@ -11,11 +13,11 @@ def get_transformed_img(size, img_path) transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), ] ) - img_transformed = tfms(img_path).unsqueeze(0) + img_transformed = tfms(img).unsqueeze(0) return img_transformed def test_transform_image(size=256, img_path='./test_data/test.jpg'): - op = TransformImageOperatorTemplate(size + op = TransformImageOperatorTemplate(size) outputs = op(img_path) img_transformed = get_transformed_img(size, img_path) c = (img_transformed.numpy() == outputs.img_transformed.numpy())