diff --git a/resnet_image_embedding.py b/resnet_image_embedding.py index b0161b9..ba18e2d 100644 --- a/resnet_image_embedding.py +++ b/resnet_image_embedding.py @@ -19,7 +19,7 @@ from PIL import Image from torchvision import transforms from pathlib import Path from typing import NamedTuple - +import os from torchvision.transforms import InterpolationMode from towhee.operator import Operator @@ -30,10 +30,14 @@ class ResnetImageEmbedding(Operator): """ def __init__(self, model_name: str, framework: str = 'pytorch') -> None: super().__init__() - sys.path.append(str(Path(__file__).parent)) if framework == 'pytorch': - from pytorch.model import Model - self.model = Model(model_name) + import importlib.util + path = os.path.join(str(Path(__file__).parent), 'pytorch', 'model.py') + opname = os.path.basename(str(Path(__file__))).split('.')[0] + spec = importlib.util.spec_from_file_location(opname, path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + self.model = module.Model(model_name) self.tfms = transforms.Compose([transforms.Resize(235, interpolation=InterpolationMode.BICUBIC), transforms.CenterCrop(224), transforms.ToTensor(),