"""Tests for ``TransformImageOperatorTemplate``.

The operator's output is compared against a reference preprocessing pipeline
built directly from torchvision transforms.
"""
from PIL import Image
from torchvision import transforms

from transform_image_operator_template import TransformImageOperatorTemplate

# Per-channel mean / std for ``transforms.Normalize``; these are the widely
# used ImageNet RGB statistics.
_MEAN = [0.485, 0.456, 0.406]
_STD = [0.229, 0.224, 0.225]
# Side length of the square center crop applied after resizing.
_CROP_SIZE = 224


def get_transformed_img(size, img_path):
    """Load an image and run the reference preprocessing pipeline on it.

    The pipeline is: ``transforms.Resize(size)``, a 224x224 center crop,
    ``transforms.ToTensor()``, then per-channel normalization.

    NOTE(review): the 3-element mean/std assume a 3-channel (RGB) input;
    grayscale or RGBA images would fail in ``Normalize`` — confirm the test
    data is RGB.

    Args:
        size: Passed straight to ``transforms.Resize``.
        img_path: Path of the image file to load.

    Returns:
        The transformed image tensor with a leading batch dimension of size 1
        (added by ``unsqueeze(0)``).
    """
    tfms = transforms.Compose(
        [
            transforms.Resize(size),
            transforms.CenterCrop(_CROP_SIZE),
            transforms.ToTensor(),
            transforms.Normalize(_MEAN, _STD),
        ]
    )
    # Close the file handle deterministically. PIL opens images lazily, so the
    # transforms must run while the file is still open.
    with Image.open(img_path) as img:
        img_transformed = tfms(img).unsqueeze(0)
    return img_transformed


def test_transform_image(size=256, img_path='./test_data/test.jpg'):
    """Check that the operator's output equals the reference pipeline exactly.

    NOTE(review): the default ``img_path`` is relative to the current working
    directory, so this test presumably has to be run from the directory that
    contains ``test_data`` — confirm.
    """
    op = TransformImageOperatorTemplate(size)
    outputs = op(img_path)

    expected = get_transformed_img(size, img_path).numpy()
    actual = outputs.img_transformed.numpy()

    # Compare shapes first: elementwise ``==`` broadcasts, so e.g. a missing
    # batch dimension would otherwise compare "equal" without any complaint.
    assert actual.shape == expected.shape, (
        f"shape mismatch: {actual.shape} != {expected.shape}"
    )
    # Exact (not approximate) equality: the operator is expected to apply the
    # identical pipeline to the same input.
    assert (actual == expected).all()