towhee
/
1 changed files with 4 additions and 19 deletions
@ -1,23 +1,8 @@ |
|||
from PIL import Image |
|||
from torchvision import transforms |
|||
from image_embedding_operator_template import ImageEmbeddingOperatorTemplate |
|||
|
|||
|
|||
def get_transformered_img(img_path): |
|||
img = Image.open(img_path) |
|||
tfms = transforms.Compose( |
|||
[ |
|||
transforms.Resize(256), |
|||
transforms.CenterCrop(224), |
|||
transforms.ToTensor(), |
|||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), |
|||
] |
|||
) |
|||
img_tensor = tfms(img).unsqueeze(0) |
|||
return img_tensor |
|||
|
|||
def test_image_embedding_operator(model_name='resnet50', img_path='./test_data/test.jpg', dimension=2048): |
|||
op = ImageEmbeddingOperatorTemplate(model_name) |
|||
img_tensor = get_transformered_img(img_path) |
|||
embedding = op(img_tensor) |
|||
assert (1, dimension)==embedding[0].shape |
|||
op = ResnetImageEmbedding(model_name) |
|||
embedding = op(img_path) |
|||
# get the output shape in README |
|||
assert (dimension,)==embedding[0].shape |
|||
|
Loading…
Reference in new issue