towhee
/
            
				 2 changed files with 10 additions and 17 deletions
			
			
		@ -1 +0,0 @@ | 
				
			|||
TEST_IMG = './test_data/test.jpg' | 
				
			|||
@ -1,28 +1,22 @@ | 
				
			|||
import unittest | 
				
			|||
from PIL import Image | 
				
			|||
from torchvision import transforms | 
				
			|||
from transform_image_operator_template import TransformImageOperatorTemplate | 
				
			|||
from config import TEST_IMG, SIZE | 
				
			|||
 | 
				
			|||
 | 
				
			|||
class TestTransformImageOperatorTemplate(unittest.TestCase): | 
				
			|||
def get_transformed_img(size, img_path) | 
				
			|||
    tfms = transforms.Compose( | 
				
			|||
        [ | 
				
			|||
            transforms.Resize(SIZE), | 
				
			|||
            transforms.Resize(size), | 
				
			|||
            transforms.CenterCrop(224), | 
				
			|||
            transforms.ToTensor(), | 
				
			|||
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), | 
				
			|||
        ] | 
				
			|||
    ) | 
				
			|||
    img1 = tfms(test_img).unsqueeze(0) | 
				
			|||
    img_transformed = tfms(img_path).unsqueeze(0) | 
				
			|||
    return img_transformed | 
				
			|||
 | 
				
			|||
    def test_transform_image(self): | 
				
			|||
        op = TransformImageOperatorTemplate(SIZE) | 
				
			|||
        outputs = op(TEST_IMG) | 
				
			|||
        print("The output tyep of operator:", type(outputs.img_transformed)) | 
				
			|||
        c = (self.img1.numpy() == outputs.img_transformed.numpy()) | 
				
			|||
        self.assertEqual(c.all(), True) | 
				
			|||
 | 
				
			|||
 | 
				
			|||
if __name__ == '__main__': | 
				
			|||
    unittest.main() | 
				
			|||
def test_transform_image(size=256, img_path='./test_data/test.jpg'): | 
				
			|||
    op = TransformImageOperatorTemplate(size | 
				
			|||
    outputs = op(img_path) | 
				
			|||
    img_transformed = get_transformed_img(size, img_path) | 
				
			|||
    c = (img_transformed.numpy() == outputs.img_transformed.numpy()) | 
				
			|||
    assert c.all() | 
				
			|||
 | 
				
			|||
					Loading…
					
					
				
		Reference in new issue