logo
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
Readme
Files and versions

32 lines
985 B

import unittest
from PIL import Image
from torchvision import transforms
from image_embedding_operator_template import ImageEmbeddingOperatorTemplate
from config import DIMENSION, MODEL_NAME, TEST_IMG
class TestImageEmbeddingOperatorTemplate(unittest.TestCase):
    """
    Simple operator test.

    Feeds one preprocessed test image through
    ``ImageEmbeddingOperatorTemplate`` and checks that the shape of the first
    returned element equals ``(1, DIMENSION)``.
    """

    @classmethod
    def setUpClass(cls):
        """Load ``TEST_IMG`` and build the input batch once for the class.

        Doing this here rather than in the class body keeps importing the
        module (e.g. during test discovery) free of file I/O, so a missing
        image is reported as a test error instead of an import failure.
        """
        # The context manager releases the file handle; convert() returns an
        # in-memory copy with three channels, matching the three-entry
        # mean/std lists passed to Normalize below.
        with Image.open(TEST_IMG) as raw:
            cls.img = raw.convert("RGB")
        cls.tfms = transforms.Compose(
            [
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )
        # unsqueeze(0) prepends a batch axis of size 1.
        cls.img_tensor = cls.tfms(cls.img).unsqueeze(0)

    def test_image_embedding(self):
        """Check that the first output element has shape ``(1, DIMENSION)``."""
        self.model_name = MODEL_NAME
        self.dimension = DIMENSION
        op = ImageEmbeddingOperatorTemplate(self.model_name)
        # Run the operator once and reuse the result for both the log line
        # and the check, instead of paying for two forward passes.
        output_shape = op(self.img_tensor)[0].shape
        print("The output shape of operator:", output_shape)
        # assertEqual is not stripped under ``python -O`` and reports both
        # values on failure; tuple() normalises whichever shape type the
        # operator returns.
        self.assertEqual((1, self.dimension), tuple(output_shape))
# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()