import os import unittest from PIL import Image from torchvision import transforms from resnet50_image_embedding import Resnet50ImageEmbedding class TestResnet50ImageEmbedding(unittest.TestCase): """ Simple operator test """ def test_image_embedding(self): test_img = './test_data/test.jpg' img = Image.open(test_img) tfms = transforms.Compose( [ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), ] ) img_tensor = tfms(img).unsqueeze(0) model_name = 'resnet50' dimension = 1000 op = Resnet50ImageEmbedding(model_name) print("The output shape of operator:", op(img_tensor)[0].shape) self.assertEqual((1, dimension), op(img_tensor)[0].shape) if __name__ == '__main__': unittest.main()