logo
Browse Source

change input type

main
zhang chen 3 years ago
parent
commit
ddf0974d27
  1. 8
      README.md
  2. 5
      vit_image_embedding.py
  3. 2
      vit_image_embedding.yaml

8
README.md

@ -26,14 +26,14 @@ __init__(self, model_name: str = 'vit_large_patch16_224',
- supported types: `str`, default is None, using pretrained weights
```python
__call__(self, img_path: str)
__call__(self, image: 'towhee.types.Image')
```
**Args:**
- img_path:
- the input image path
- supported types: `str`
- img_tensor:
- the input image tensor
- supported types: `torch.Tensor`
**Returns:**

5
vit_image_embedding.py

@ -24,6 +24,7 @@ import os
from towhee.operator import Operator
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from towhee.utils.pil_utils import to_pil
import warnings
warnings.filterwarnings("ignore")
@ -52,8 +53,8 @@ class VitImageEmbedding(Operator):
config = resolve_data_config({}, model=self.model._model)
self.tfms = create_transform(**config)
def __call__(self, img_path: str) -> NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]):
def __call__(self, image: 'towhee.types.Image') -> NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]):
img = self.tfms(to_pil(image)).unsqueeze(0)
Outputs = NamedTuple('Outputs', [('feature_vector', numpy.ndarray)])
img = self.tfms(Image.open(img_path)).unsqueeze(0)
features = self.model(img)
return Outputs(features.flatten().detach().numpy())

2
vit_image_embedding.yaml

@ -8,6 +8,6 @@ init:
model_name: str
call:
input:
img_path: str
image: towhee.types.Image
output:
feature_vector: numpy.ndarray

Loading…
Cancel
Save