| 
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -2,14 +2,10 @@ import unittest | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					from PIL import Image | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					from torchvision import transforms | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					from image_embedding_operator_template import ImageEmbeddingOperatorTemplate | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					from config import DIMENSION, MODEL_NAME, TEST_IMG | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					class TestImageEmbeddingOperatorTemplate(unittest.TestCase): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    """ | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    Simple operator test | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    """ | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    img = Image.open(TEST_IMG) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					def get_transformered_img(img_path): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    img = Image.open(img_path) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    tfms = transforms.Compose( | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        [ | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            transforms.Resize(256), | 
				
			
			
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
				 | 
				
					@ -19,14 +15,10 @@ class TestImageEmbeddingOperatorTemplate(unittest.TestCase): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        ] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    ) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    img_tensor = tfms(img).unsqueeze(0) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    return img_tensor | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def test_image_embedding(self): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.model_name = MODEL_NAME | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.dimension = DIMENSION | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        op = ImageEmbeddingOperatorTemplate(self.model_name) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        print("The output shape of operator:", op(self.img_tensor)[0].shape) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        assert (1, self.dimension)==op(self.img_tensor)[0].shape | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					if __name__ == '__main__': | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    unittest.main() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					def test_image_embedding_operator(model_name='resnet50', img_path='./test_data/test.jpg', dimension=2048): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    op = ImageEmbeddingOperatorTemplate(model_name) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    img_tensor = get_transformered_img(img_path) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    embedding = op(img_tensor) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    assert (1, dimension)==embedding[0].shape | 
				
			
			
		
	
	
		
			
				
					| 
						
						
						
					 | 
				
				 | 
				
					
  |