logo
Browse Source

fix import problem

main
zhang chen 3 years ago
parent
commit
6e4ab40e06
  1. 20
      vit_image_embedding.py

20
vit_image_embedding.py

@ -19,10 +19,11 @@ from pathlib import Path
from PIL import Image from PIL import Image
import torch import torch
import numpy import numpy
import os
from towhee.operator import Operator from towhee.operator import Operator
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
class VitImageEmbedding(Operator): class VitImageEmbedding(Operator):
""" """
@ -37,13 +38,16 @@ class VitImageEmbedding(Operator):
def __init__(self, model_name: str = 'vit_large_patch16_224', def __init__(self, model_name: str = 'vit_large_patch16_224',
framework: str = 'pytorch', weights_path: str = None) -> None: framework: str = 'pytorch', weights_path: str = None) -> None:
super().__init__() super().__init__()
sys.path.append(str(Path(__file__).parent))
if framework == 'pytorch': if framework == 'pytorch':
import pytorch
from pytorch.model import Model
self.model = Model(model_name, weights_path)
config = pytorch.resolve_data_config({}, model=self.model._model)
self.tfms = pytorch.create_transform(**config)
import importlib.util
path = os.path.join(str(Path(__file__).parent), 'pytorch', 'model.py')
opname = os.path.basename(str(Path(__file__))).split('.')[0]
spec = importlib.util.spec_from_file_location(opname, path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
self.model = module.Model(model_name, weights_path)
config = resolve_data_config({}, model=self.model._model)
self.tfms = create_transform(**config)
def __call__(self, img_path: str) -> NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]): def __call__(self, img_path: str) -> NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]):
Outputs = NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]) Outputs = NamedTuple('Outputs', [('feature_vector', numpy.ndarray)])

Loading…
Cancel
Save