diff --git a/vit_image_embedding.py b/vit_image_embedding.py index 9df2437..8812503 100644 --- a/vit_image_embedding.py +++ b/vit_image_embedding.py @@ -39,6 +39,7 @@ class VitImageEmbedding(Operator): super().__init__() sys.path.append(str(Path(__file__).parent)) if framework == 'pytorch': + import pytorch from pytorch.model import Model self.model = Model(model_name, weights_path) config = pytorch.resolve_data_config({}, model=self.model._model)