From ddf0974d27a1aa561d7d9e1070475a3cbbad7039 Mon Sep 17 00:00:00 2001 From: zhang chen Date: Wed, 29 Dec 2021 20:58:54 +0800 Subject: [PATCH] change input type --- README.md | 8 ++++---- vit_image_embedding.py | 5 +++-- vit_image_embedding.yaml | 2 +- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index d27701b..f33705e 100644 --- a/README.md +++ b/README.md @@ -26,14 +26,14 @@ __init__(self, model_name: str = 'vit_large_patch16_224', - supported types: `str`, default is None, using pretrained weights ```python -__call__(self, img_path: str) +__call__(self, image: 'towhee.types.Image') ``` **Args:** -- img_path: - - the input image path - - supported types: `str` +- img_tensor: + - the input image tensor + - supported types: `torch.Tensor` **Returns:** diff --git a/vit_image_embedding.py b/vit_image_embedding.py index 310d435..cc4a5eb 100644 --- a/vit_image_embedding.py +++ b/vit_image_embedding.py @@ -24,6 +24,7 @@ import os from towhee.operator import Operator from timm.data import resolve_data_config from timm.data.transforms_factory import create_transform +from towhee.utils.pil_utils import to_pil import warnings warnings.filterwarnings("ignore") @@ -52,8 +53,8 @@ class VitImageEmbedding(Operator): config = resolve_data_config({}, model=self.model._model) self.tfms = create_transform(**config) - def __call__(self, img_path: str) -> NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]): + def __call__(self, image: 'towhee.types.Image') -> NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]): + img = self.tfms(to_pil(image)).unsqueeze(0) Outputs = NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]) - img = self.tfms(Image.open(img_path)).unsqueeze(0) features = self.model(img) return Outputs(features.flatten().detach().numpy()) diff --git a/vit_image_embedding.yaml b/vit_image_embedding.yaml index fa903c1..8011a86 100644 --- a/vit_image_embedding.yaml +++ b/vit_image_embedding.yaml @@ -8,6 +8,6 @@ init: model_name: str call: input: - img_path: str + image: towhee.types.Image output: feature_vector: numpy.ndarray