From d484f2eb1863e02fd37dd23005d5b7b207b02d9e Mon Sep 17 00:00:00 2001 From: shiyu22 Date: Wed, 22 Dec 2021 19:04:52 +0800 Subject: [PATCH] Update the test script Signed-off-by: shiyu22 --- test_image_embedding_operator_template.py | 23 ++++------------------- 1 file changed, 4 insertions(+), 19 deletions(-) diff --git a/test_image_embedding_operator_template.py b/test_image_embedding_operator_template.py index f4cd205..78b18dc 100644 --- a/test_image_embedding_operator_template.py +++ b/test_image_embedding_operator_template.py @@ -1,23 +1,8 @@ -from PIL import Image -from torchvision import transforms from image_embedding_operator_template import ImageEmbeddingOperatorTemplate -def get_transformered_img(img_path): - img = Image.open(img_path) - tfms = transforms.Compose( - [ - transforms.Resize(256), - transforms.CenterCrop(224), - transforms.ToTensor(), - transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), - ] - ) - img_tensor = tfms(img).unsqueeze(0) - return img_tensor - def test_image_embedding_operator(model_name='resnet50', img_path='./test_data/test.jpg', dimension=2048): - op = ImageEmbeddingOperatorTemplate(model_name) - img_tensor = get_transformered_img(img_path) - embedding = op(img_tensor) - assert (1, dimension)==embedding[0].shape + op = ResnetImageEmbedding(model_name) + embedding = op(img_path) + # get the output shape in README + assert (dimension,)==embedding[0].shape