towhee
/
resnet-image-embedding
copied
5 changed files with 18 additions and 109 deletions
@ -1,2 +0,0 @@ |
|||
torch>=1.2.0 |
|||
torchvision>=0.4.0 |
Before Width: | Height: | Size: 262 KiB After Width: | Height: | Size: 262 KiB |
Before Width: | Height: | Size: 178 KiB |
@ -1,32 +0,0 @@ |
|||
import os |
|||
import unittest |
|||
from PIL import Image |
|||
from torchvision import transforms |
|||
from resnet50_image_embedding import Resnet50ImageEmbedding |
|||
|
|||
|
|||
class TestResnet50ImageEmbedding(unittest.TestCase): |
|||
""" |
|||
Simple operator test |
|||
""" |
|||
def test_image_embedding(self): |
|||
test_img = './test_data/test.jpg' |
|||
img = Image.open(test_img) |
|||
tfms = transforms.Compose( |
|||
[ |
|||
transforms.Resize(256), |
|||
transforms.CenterCrop(224), |
|||
transforms.ToTensor(), |
|||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), |
|||
] |
|||
) |
|||
img_tensor = tfms(img).unsqueeze(0) |
|||
model_name = 'resnet50' |
|||
dimension = 1000 |
|||
op = Resnet50ImageEmbedding(model_name) |
|||
print("The output shape of operator:", op(img_tensor)[0].shape) |
|||
self.assertEqual((1, dimension), op(img_tensor)[0].shape) |
|||
|
|||
|
|||
if __name__ == '__main__': |
|||
unittest.main() |
Loading…
Reference in new issue