logo
Browse Source

Update the test script

Signed-off-by: shiyu22 <shiyu.chen@zilliz.com>
main
shiyu22 3 years ago
parent
commit
d484f2eb18
  1. 23
      test_image_embedding_operator_template.py

23
test_image_embedding_operator_template.py

@ -1,23 +1,8 @@
from PIL import Image
from torchvision import transforms
from image_embedding_operator_template import ImageEmbeddingOperatorTemplate from image_embedding_operator_template import ImageEmbeddingOperatorTemplate
def get_transformered_img(img_path):
img = Image.open(img_path)
tfms = transforms.Compose(
[
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
]
)
img_tensor = tfms(img).unsqueeze(0)
return img_tensor
def test_image_embedding_operator(model_name='resnet50', img_path='./test_data/test.jpg', dimension=2048): def test_image_embedding_operator(model_name='resnet50', img_path='./test_data/test.jpg', dimension=2048):
op = ImageEmbeddingOperatorTemplate(model_name)
img_tensor = get_transformered_img(img_path)
embedding = op(img_tensor)
assert (1, dimension)==embedding[0].shape
op = ResnetImageEmbedding(model_name)
embedding = op(img_path)
# get the output shape in README
assert (dimension,)==embedding[0].shape

Loading…
Cancel
Save