logo
Browse Source

Update

Signed-off-by: shiyu22 <shiyu.chen@zilliz.com>
main
shiyu22 3 years ago
parent
commit
377679bb92
  1. 5
      pytorch/__init__.py
  2. 7
      vit_image_embedding.py

5
pytorch/__init__.py

@ -18,4 +18,7 @@ import os
try: try:
import timm import timm
except ModuleNotFoundError: except ModuleNotFoundError:
os.system('pip install timm')
os.system('pip install timm')
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform

7
vit_image_embedding.py

@ -20,6 +20,7 @@ from PIL import Image
import torch import torch
import numpy import numpy
from towhee.operator import Operator from towhee.operator import Operator
@ -39,11 +40,9 @@ class VitImageEmbedding(Operator):
sys.path.append(str(Path(__file__).parent)) sys.path.append(str(Path(__file__).parent))
if framework == 'pytorch': if framework == 'pytorch':
from pytorch.model import Model from pytorch.model import Model
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
self.model = Model(model_name, weights_path) self.model = Model(model_name, weights_path)
config = resolve_data_config({}, model=self.model._model)
self.tfms = create_transform(**config)
config = pytorch.resolve_data_config({}, model=self.model._model)
self.tfms = pytorch.create_transform(**config)
def __call__(self, img_path: str) -> NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]): def __call__(self, img_path: str) -> NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]):
Outputs = NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]) Outputs = NamedTuple('Outputs', [('feature_vector', numpy.ndarray)])

Loading…
Cancel
Save