diff --git a/pytorch/__init__.py b/pytorch/__init__.py index 86316bc..b661573 100644 --- a/pytorch/__init__.py +++ b/pytorch/__init__.py @@ -18,4 +18,7 @@ import os try: import timm except ModuleNotFoundError: - os.system('pip install timm') \ No newline at end of file + os.system('pip install timm') + +from timm.data import resolve_data_config +from timm.data.transforms_factory import create_transform \ No newline at end of file diff --git a/vit_image_embedding.py b/vit_image_embedding.py index ed88e2b..9df2437 100644 --- a/vit_image_embedding.py +++ b/vit_image_embedding.py @@ -20,6 +20,7 @@ from PIL import Image import torch import numpy + from towhee.operator import Operator @@ -39,11 +40,9 @@ class VitImageEmbedding(Operator): sys.path.append(str(Path(__file__).parent)) if framework == 'pytorch': from pytorch.model import Model - from timm.data import resolve_data_config - from timm.data.transforms_factory import create_transform self.model = Model(model_name, weights_path) - config = resolve_data_config({}, model=self.model._model) - self.tfms = create_transform(**config) + config = pytorch.resolve_data_config({}, model=self.model._model) + self.tfms = pytorch.create_transform(**config) def __call__(self, img_path: str) -> NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]): Outputs = NamedTuple('Outputs', [('feature_vector', numpy.ndarray)])