From b32cd3376643f0a701a263d3ed63a72102d5b4e2 Mon Sep 17 00:00:00 2001 From: zhang chen Date: Fri, 24 Dec 2021 15:46:21 +0800 Subject: [PATCH] fix import problem --- resnet_image_embedding.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/resnet_image_embedding.py b/resnet_image_embedding.py index b0161b9..ba18e2d 100644 --- a/resnet_image_embedding.py +++ b/resnet_image_embedding.py @@ -19,7 +19,7 @@ from PIL import Image from torchvision import transforms from pathlib import Path from typing import NamedTuple - +import os from torchvision.transforms import InterpolationMode from towhee.operator import Operator @@ -30,10 +30,14 @@ class ResnetImageEmbedding(Operator): """ def __init__(self, model_name: str, framework: str = 'pytorch') -> None: super().__init__() - sys.path.append(str(Path(__file__).parent)) if framework == 'pytorch': - from pytorch.model import Model - self.model = Model(model_name) + import importlib.util + path = os.path.join(str(Path(__file__).parent), 'pytorch', 'model.py') + opname = os.path.basename(str(Path(__file__))).split('.')[0] + spec = importlib.util.spec_from_file_location(opname, path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + self.model = module.Model(model_name) self.tfms = transforms.Compose([transforms.Resize(235, interpolation=InterpolationMode.BICUBIC), transforms.CenterCrop(224), transforms.ToTensor(),