diff --git a/test_image_embedding_operator_template.py b/test_image_embedding_operator_template.py index f4cd205..78b18dc 100644 --- a/test_image_embedding_operator_template.py +++ b/test_image_embedding_operator_template.py @@ -1,23 +1,8 @@ -from PIL import Image -from torchvision import transforms from image_embedding_operator_template import ImageEmbeddingOperatorTemplate -def get_transformered_img(img_path): - img = Image.open(img_path) - tfms = transforms.Compose( - [ - transforms.Resize(256), - transforms.CenterCrop(224), - transforms.ToTensor(), - transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), - ] - ) - img_tensor = tfms(img).unsqueeze(0) - return img_tensor - def test_image_embedding_operator(model_name='resnet50', img_path='./test_data/test.jpg', dimension=2048): - op = ImageEmbeddingOperatorTemplate(model_name) - img_tensor = get_transformered_img(img_path) - embedding = op(img_tensor) - assert (1, dimension)==embedding[0].shape + op = ResnetImageEmbedding(model_name) + embedding = op(img_path) + # get the output shape in README + assert (dimension,)==embedding[0].shape