|
|
@ -2,14 +2,10 @@ import unittest |
|
|
|
from PIL import Image |
|
|
|
from torchvision import transforms |
|
|
|
from image_embedding_operator_template import ImageEmbeddingOperatorTemplate |
|
|
|
from config import DIMENSION, MODEL_NAME, TEST_IMG |
|
|
|
|
|
|
|
|
|
|
|
class TestImageEmbeddingOperatorTemplate(unittest.TestCase): |
|
|
|
""" |
|
|
|
Simple operator test |
|
|
|
""" |
|
|
|
img = Image.open(TEST_IMG) |
|
|
|
def get_transformered_img(img_path): |
|
|
|
img = Image.open(img_path) |
|
|
|
tfms = transforms.Compose( |
|
|
|
[ |
|
|
|
transforms.Resize(256), |
|
|
@ -19,14 +15,10 @@ class TestImageEmbeddingOperatorTemplate(unittest.TestCase): |
|
|
|
] |
|
|
|
) |
|
|
|
img_tensor = tfms(img).unsqueeze(0) |
|
|
|
return img_tensor |
|
|
|
|
|
|
|
def test_image_embedding(self): |
|
|
|
self.model_name = MODEL_NAME |
|
|
|
self.dimension = DIMENSION |
|
|
|
op = ImageEmbeddingOperatorTemplate(self.model_name) |
|
|
|
print("The output shape of operator:", op(self.img_tensor)[0].shape) |
|
|
|
assert (1, self.dimension)==op(self.img_tensor)[0].shape |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
|
unittest.main() |
|
|
|
def test_image_embedding_operator(model_name='resnet50', img_path='./test_data/test.jpg', dimension=2048): |
|
|
|
op = ImageEmbeddingOperatorTemplate(model_name) |
|
|
|
img_tensor = get_transformered_img(img_path) |
|
|
|
embedding = op(img_tensor) |
|
|
|
assert (1, dimension)==embedding[0].shape |
|
|
|