towhee
/
            
				 1 changed files with 4 additions and 19 deletions
			
			
		@ -1,23 +1,8 @@ | 
			
		|||||
from PIL import Image | 
				 | 
			
		||||
from torchvision import transforms | 
				 | 
			
		||||
from image_embedding_operator_template import ImageEmbeddingOperatorTemplate | 
				from image_embedding_operator_template import ImageEmbeddingOperatorTemplate | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				
 | 
			
		||||
def get_transformered_img(img_path): | 
				 | 
			
		||||
    img = Image.open(img_path) | 
				 | 
			
		||||
    tfms = transforms.Compose( | 
				 | 
			
		||||
        [ | 
				 | 
			
		||||
            transforms.Resize(256), | 
				 | 
			
		||||
            transforms.CenterCrop(224), | 
				 | 
			
		||||
            transforms.ToTensor(), | 
				 | 
			
		||||
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), | 
				 | 
			
		||||
        ] | 
				 | 
			
		||||
    ) | 
				 | 
			
		||||
    img_tensor = tfms(img).unsqueeze(0) | 
				 | 
			
		||||
    return img_tensor | 
				 | 
			
		||||
 | 
				 | 
			
		||||
def test_image_embedding_operator(model_name='resnet50', img_path='./test_data/test.jpg', dimension=2048): | 
				def test_image_embedding_operator(model_name='resnet50', img_path='./test_data/test.jpg', dimension=2048): | 
			
		||||
    op = ImageEmbeddingOperatorTemplate(model_name) | 
				 | 
			
		||||
    img_tensor = get_transformered_img(img_path) | 
				 | 
			
		||||
    embedding = op(img_tensor) | 
				 | 
			
		||||
    assert (1, dimension)==embedding[0].shape | 
				 | 
			
		||||
 | 
				    op = ResnetImageEmbedding(model_name) | 
			
		||||
 | 
				    embedding = op(img_path) | 
			
		||||
 | 
				    # get the output shape in README | 
			
		||||
 | 
				    assert (dimension,)==embedding[0].shape | 
			
		||||
 | 
			
		|||||
					Loading…
					
					
				
		Reference in new issue