import unittest from PIL import Image from torchvision import transforms from image_embedding_operator_template import ImageEmbeddingOperatorTemplate def get_transformered_img(img_path): img = Image.open(img_path) tfms = transforms.Compose( [ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), ] ) img_tensor = tfms(img).unsqueeze(0) return img_tensor def test_image_embedding_operator(model_name='resnet50', img_path='./test_data/test.jpg', dimension=2048): op = ImageEmbeddingOperatorTemplate(model_name) img_tensor = get_transformered_img(img_path) embedding = op(img_tensor) assert (1, dimension)==embedding[0].shape