towhee
/
resnet-image-embedding
copied
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions
32 lines
967 B
32 lines
967 B
3 years ago
|
import os
|
||
|
import unittest
|
||
|
from PIL import Image
|
||
|
from torchvision import transforms
|
||
|
from resnet50_image_embedding import Resnet50ImageEmbedding
|
||
|
|
||
|
|
||
|
class TestResnet50ImageEmbedding(unittest.TestCase):
    """Smoke test for the Resnet50ImageEmbedding operator.

    Loads a sample JPEG from the ``test_data`` directory next to this file,
    applies resize / center-crop / normalize preprocessing, runs the operator
    once and checks the shape of the first element of its output.
    """

    def test_image_embedding(self):
        # Resolve the fixture relative to this file rather than the current
        # working directory, so the test passes no matter where the test
        # runner is launched from.
        test_img = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), 'test_data', 'test.jpg'
        )
        tfms = transforms.Compose(
            [
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                # Per-channel mean/std; these are the widely used ImageNet
                # statistics, presumably matching the model's training
                # preprocessing -- confirm against the operator.
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )
        # `with` closes the underlying file handle. convert('RGB') guarantees
        # the three channels Normalize expects, even for grayscale/CMYK JPEGs.
        # unsqueeze(0) adds a leading batch dimension of size 1.
        with Image.open(test_img) as img:
            img_tensor = tfms(img.convert('RGB')).unsqueeze(0)

        model_name = 'resnet50'
        # Expected size of the second (feature) axis of the output.
        dimension = 1000
        op = Resnet50ImageEmbedding(model_name)

        # Run inference once and reuse the result for both the log line and
        # the assertion. Assumes the operator returns an indexable result
        # whose first element is a tensor -- verify against the operator.
        output_shape = op(img_tensor)[0].shape
        print("The output shape of operator:", output_shape)
        self.assertEqual((1, dimension), output_shape)
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
    # Executed as a script (not imported): discover and run this module's tests.
    unittest.main()
|