Browse Source
Merge remote-tracking branch 'origin/main'
training
zhang chen
3 years ago
2 changed files with
5 additions and
1 deletions
-
requirements.txt
-
resnet_image_embedding.py
|
@ -0,0 +1,4 @@ |
|
|
|
|
|
torch==1.9.0 |
|
|
|
|
|
torchvision==0.10.0 |
|
|
|
|
|
pillow==8.3.1 |
|
|
|
|
|
numpy==1.19.5 |
|
@ -46,7 +46,7 @@ class ResnetImageEmbedding(Operator): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __call__(self, img_path: str) -> NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]): |
|
|
def __call__(self, img_path: str) -> NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]): |
|
|
img = self.tfms(Image.open(img_path)).unsqueeze(0) |
|
|
|
|
|
|
|
|
img = self.tfms(Image.open(img_path).convert('RGB')).unsqueeze(0) |
|
|
embedding = self.model(img) |
|
|
embedding = self.model(img) |
|
|
Outputs = NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]) |
|
|
Outputs = NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]) |
|
|
return Outputs(embedding) |
|
|
return Outputs(embedding) |
|
|