import unittest from PIL import Image from torchvision import transforms from transform_image import TransformImage class TestTransformImage(unittest.TestCase): def test_transform_image(self): img_src = './test_data/test.jpg' test_img = Image.open(img_src) tfms = transforms.Compose( [ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), ] ) img1 = tfms(test_img).unsqueeze(0) op = TransformImage(256) outputs = op(test_img) print("The output tyep of operator:", type(outputs.img_transformed)) c = (img1.numpy() == outputs.img_transformed.numpy()) self.assertTrue(c.all()) if __name__ == '__main__': unittest.main()