You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
|
|
|
from PIL import Image
|
|
|
|
from torchvision import transforms
|
|
|
|
from image_embedding_operator_template import ImageEmbeddingOperatorTemplate
|
|
|
|
|
|
|
|
|
|
|
|
def get_transformered_img(img_path):
    """Load an image and preprocess it into a single-image batch tensor.

    The pipeline resizes the shorter side to 256, center-crops to 224x224,
    converts to a float tensor and normalizes each channel with the given
    mean/std (the usual ImageNet statistics).

    Args:
        img_path: Path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        A tensor of shape ``(1, 3, 224, 224)``: three channels after the
        RGB conversion, 224x224 after ``CenterCrop``, plus the batch axis
        added by ``unsqueeze(0)``.
    """
    # Use a context manager so the underlying file handle is released.
    # ``convert`` forces the (lazy) pixel data to load and returns an
    # independent image, so it is safe to use after the file is closed.
    # Converting to RGB guarantees 3 channels, so grayscale, RGBA and palette
    # images do not break the 3-channel ``Normalize`` below.
    with Image.open(img_path) as img:
        rgb_img = img.convert('RGB')

    tfms = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )
    # Add a leading batch dimension: (C, H, W) -> (1, C, H, W).
    img_tensor = tfms(rgb_img).unsqueeze(0)
    return img_tensor
|
|
|
|
|
|
|
|
def test_image_embedding_operator(model_name='resnet50', img_path='./test_data/test.jpg', dimension=2048):
    """Smoke-test ``ImageEmbeddingOperatorTemplate`` on a single image.

    Builds the operator for ``model_name``, feeds it the preprocessed image
    at ``img_path``, and checks that the first element of the operator's
    output has shape ``(1, dimension)``.

    NOTE(review): the default ``img_path`` is relative, so it is resolved
    against the current working directory; run from the directory that
    contains ``test_data/``.
    """
    expected_shape = (1, dimension)

    operator = ImageEmbeddingOperatorTemplate(model_name)
    batch = get_transformered_img(img_path)
    outputs = operator(batch)

    assert expected_shape == outputs[0].shape
|