diff --git a/config.py b/config.py deleted file mode 100644 index 3c33d66..0000000 --- a/config.py +++ /dev/null @@ -1 +0,0 @@ -TEST_IMG = './test_data/test.jpg' diff --git a/test_transform_image_operator_template.py b/test_transform_image_operator_template.py index 83385e8..44bd3a5 100644 --- a/test_transform_image_operator_template.py +++ b/test_transform_image_operator_template.py @@ -1,28 +1,22 @@ -import unittest -from PIL import Image from torchvision import transforms from transform_image_operator_template import TransformImageOperatorTemplate -from config import TEST_IMG, SIZE -class TestTransformImageOperatorTemplate(unittest.TestCase): +def get_transformed_img(size, img_path) tfms = transforms.Compose( [ - transforms.Resize(SIZE), + transforms.Resize(size), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), ] ) - img1 = tfms(test_img).unsqueeze(0) + img_transformed = tfms(img_path).unsqueeze(0) + return img_transformed - def test_transform_image(self): - op = TransformImageOperatorTemplate(SIZE) - outputs = op(TEST_IMG) - print("The output tyep of operator:", type(outputs.img_transformed)) - c = (self.img1.numpy() == outputs.img_transformed.numpy()) - self.assertEqual(c.all(), True) - - -if __name__ == '__main__': - unittest.main() \ No newline at end of file +def test_transform_image(size=256, img_path='./test_data/test.jpg'): + op = TransformImageOperatorTemplate(size + outputs = op(img_path) + img_transformed = get_transformed_img(size, img_path) + c = (img_transformed.numpy() == outputs.img_transformed.numpy()) + assert c.all()