logo
Browse Source

Merge remote-tracking branch 'origin/main'

training
zhang chen 3 years ago
parent
commit
499ca03321
  1. 4
      requirements.txt
  2. 2
      resnet_image_embedding.py

4
requirements.txt

@ -0,0 +1,4 @@
torch==1.9.0
torchvision==0.10.0
pillow==8.3.1
numpy==1.19.5

2
resnet_image_embedding.py

@ -46,7 +46,7 @@ class ResnetImageEmbedding(Operator):
def __call__(self, img_path: str) -> NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]):
img = self.tfms(Image.open(img_path)).unsqueeze(0)
img = self.tfms(Image.open(img_path).convert('RGB')).unsqueeze(0)
embedding = self.model(img)
Outputs = NamedTuple('Outputs', [('feature_vector', numpy.ndarray)])
return Outputs(embedding)

Loading…
Cancel
Save