logo
Browse Source

fix import problem

training
zhang chen 3 years ago
parent
commit
b32cd33766
  1. 12
      resnet_image_embedding.py

12
resnet_image_embedding.py

@ -19,7 +19,7 @@ from PIL import Image
from torchvision import transforms from torchvision import transforms
from pathlib import Path from pathlib import Path
from typing import NamedTuple from typing import NamedTuple
import os
from torchvision.transforms import InterpolationMode from torchvision.transforms import InterpolationMode
from towhee.operator import Operator from towhee.operator import Operator
@ -30,10 +30,14 @@ class ResnetImageEmbedding(Operator):
""" """
def __init__(self, model_name: str, framework: str = 'pytorch') -> None: def __init__(self, model_name: str, framework: str = 'pytorch') -> None:
super().__init__() super().__init__()
sys.path.append(str(Path(__file__).parent))
if framework == 'pytorch': if framework == 'pytorch':
from pytorch.model import Model
self.model = Model(model_name)
import importlib.util
path = os.path.join(str(Path(__file__).parent), 'pytorch', 'model.py')
opname = os.path.basename(str(Path(__file__))).split('.')[0]
spec = importlib.util.spec_from_file_location(opname, path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
self.model = module.Model(model_name)
self.tfms = transforms.Compose([transforms.Resize(235, interpolation=InterpolationMode.BICUBIC), self.tfms = transforms.Compose([transforms.Resize(235, interpolation=InterpolationMode.BICUBIC),
transforms.CenterCrop(224), transforms.CenterCrop(224),
transforms.ToTensor(), transforms.ToTensor(),

Loading…
Cancel
Save