logo
Browse Source

change the input format.

training
wxywb 3 years ago
parent
commit
951ae888e7
  1. 9
      README.md
  2. 6
      resnet_image_embedding.py

9
README.md

@ -24,14 +24,15 @@ __init__(self, model_name: str, framework: str = 'pytorch')
- supported types: `str`, default is 'pytorch'
```python
__call__(self, img_tensor: torch.Tensor)
__call__(self, image: 'towhee.types.Image')
```
**Args:**
- img_tensor:
- the input image tensor
- supported types: `torch.Tensor`
image:
- the input image
- supported types: `towhee.types.Image`
**Returns:**

6
resnet_image_embedding.py

@ -22,6 +22,7 @@ from typing import NamedTuple
import os
from torchvision.transforms import InterpolationMode
from towhee.operator import Operator
from towhee.utils.pil_utils import to_pil
import warnings
warnings.filterwarnings("ignore")
@ -44,9 +45,8 @@ class ResnetImageEmbedding(Operator):
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
def __call__(self, img_path: str) -> NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]):
img = self.tfms(Image.open(img_path).convert('RGB')).unsqueeze(0)
def __call__(self, image: 'towhee.types.Image') -> NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]):
img = self.tfms(to_pil(image)).unsqueeze(0)
embedding = self.model(img)
Outputs = NamedTuple('Outputs', [('feature_vector', numpy.ndarray)])
return Outputs(embedding)

Loading…
Cancel
Save