towhee
/
1 changed files with 4 additions and 19 deletions
@ -1,23 +1,8 @@ |
|||||
from PIL import Image |
|
||||
from torchvision import transforms |
|
||||
from image_embedding_operator_template import ImageEmbeddingOperatorTemplate |
from image_embedding_operator_template import ImageEmbeddingOperatorTemplate |
||||
|
|
||||
|
|
||||
def get_transformered_img(img_path): |
|
||||
img = Image.open(img_path) |
|
||||
tfms = transforms.Compose( |
|
||||
[ |
|
||||
transforms.Resize(256), |
|
||||
transforms.CenterCrop(224), |
|
||||
transforms.ToTensor(), |
|
||||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), |
|
||||
] |
|
||||
) |
|
||||
img_tensor = tfms(img).unsqueeze(0) |
|
||||
return img_tensor |
|
||||
|
|
||||
def test_image_embedding_operator(model_name='resnet50', img_path='./test_data/test.jpg', dimension=2048): |
def test_image_embedding_operator(model_name='resnet50', img_path='./test_data/test.jpg', dimension=2048): |
||||
op = ImageEmbeddingOperatorTemplate(model_name) |
|
||||
img_tensor = get_transformered_img(img_path) |
|
||||
embedding = op(img_tensor) |
|
||||
assert (1, dimension)==embedding[0].shape |
|
||||
|
op = ResnetImageEmbedding(model_name) |
||||
|
embedding = op(img_path) |
||||
|
# get the output shape in README |
||||
|
assert (dimension,)==embedding[0].shape |
||||
|
Loading…
Reference in new issue