towhee
/
            
				 2 changed files with 10 additions and 17 deletions
			
			
		@ -1 +0,0 @@ | 
			
		|||||
TEST_IMG = './test_data/test.jpg' | 
				 | 
			
		||||
@ -1,28 +1,22 @@ | 
			
		|||||
import unittest | 
				 | 
			
		||||
from PIL import Image | 
				 | 
			
		||||
from torchvision import transforms | 
				from torchvision import transforms | 
			
		||||
from transform_image_operator_template import TransformImageOperatorTemplate | 
				from transform_image_operator_template import TransformImageOperatorTemplate | 
			
		||||
from config import TEST_IMG, SIZE | 
				 | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				
 | 
			
		||||
class TestTransformImageOperatorTemplate(unittest.TestCase): | 
				 | 
			
		||||
 | 
				def get_transformed_img(size, img_path) | 
			
		||||
    tfms = transforms.Compose( | 
				    tfms = transforms.Compose( | 
			
		||||
        [ | 
				        [ | 
			
		||||
            transforms.Resize(SIZE), | 
				 | 
			
		||||
 | 
				            transforms.Resize(size), | 
			
		||||
            transforms.CenterCrop(224), | 
				            transforms.CenterCrop(224), | 
			
		||||
            transforms.ToTensor(), | 
				            transforms.ToTensor(), | 
			
		||||
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), | 
				            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), | 
			
		||||
        ] | 
				        ] | 
			
		||||
    ) | 
				    ) | 
			
		||||
    img1 = tfms(test_img).unsqueeze(0) | 
				 | 
			
		||||
 | 
				    img_transformed = tfms(img_path).unsqueeze(0) | 
			
		||||
 | 
				    return img_transformed | 
			
		||||
 | 
				
 | 
			
		||||
    def test_transform_image(self): | 
				 | 
			
		||||
        op = TransformImageOperatorTemplate(SIZE) | 
				 | 
			
		||||
        outputs = op(TEST_IMG) | 
				 | 
			
		||||
        print("The output tyep of operator:", type(outputs.img_transformed)) | 
				 | 
			
		||||
        c = (self.img1.numpy() == outputs.img_transformed.numpy()) | 
				 | 
			
		||||
        self.assertEqual(c.all(), True) | 
				 | 
			
		||||
 | 
				 | 
			
		||||
 | 
				 | 
			
		||||
if __name__ == '__main__': | 
				 | 
			
		||||
    unittest.main() | 
				 | 
			
		||||
 | 
				def test_transform_image(size=256, img_path='./test_data/test.jpg'): | 
			
		||||
 | 
				    op = TransformImageOperatorTemplate(size | 
			
		||||
 | 
				    outputs = op(img_path) | 
			
		||||
 | 
				    img_transformed = get_transformed_img(size, img_path) | 
			
		||||
 | 
				    c = (img_transformed.numpy() == outputs.img_transformed.numpy()) | 
			
		||||
 | 
				    assert c.all() | 
			
		||||
 | 
			
		|||||
					Loading…
					
					
				
		Reference in new issue