towhee
/
            
				 1 changed files with 4 additions and 19 deletions
			
			
		@ -1,23 +1,8 @@ | 
				
			|||
from PIL import Image | 
				
			|||
from torchvision import transforms | 
				
			|||
from image_embedding_operator_template import ImageEmbeddingOperatorTemplate | 
				
			|||
 | 
				
			|||
 | 
				
			|||
def get_transformered_img(img_path): | 
				
			|||
    img = Image.open(img_path) | 
				
			|||
    tfms = transforms.Compose( | 
				
			|||
        [ | 
				
			|||
            transforms.Resize(256), | 
				
			|||
            transforms.CenterCrop(224), | 
				
			|||
            transforms.ToTensor(), | 
				
			|||
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), | 
				
			|||
        ] | 
				
			|||
    ) | 
				
			|||
    img_tensor = tfms(img).unsqueeze(0) | 
				
			|||
    return img_tensor | 
				
			|||
 | 
				
			|||
def test_image_embedding_operator(model_name='resnet50', img_path='./test_data/test.jpg', dimension=2048): | 
				
			|||
    op = ImageEmbeddingOperatorTemplate(model_name) | 
				
			|||
    img_tensor = get_transformered_img(img_path) | 
				
			|||
    embedding = op(img_tensor) | 
				
			|||
    assert (1, dimension)==embedding[0].shape | 
				
			|||
    op = ResnetImageEmbedding(model_name) | 
				
			|||
    embedding = op(img_path) | 
				
			|||
    # get the output shape in README | 
				
			|||
    assert (dimension,)==embedding[0].shape | 
				
			|||
 | 
				
			|||
					Loading…
					
					
				
		Reference in new issue