diff --git a/resnet_image_embedding.py b/resnet_image_embedding.py index 5b7cdc9..d10f224 100644 --- a/resnet_image_embedding.py +++ b/resnet_image_embedding.py @@ -26,12 +26,12 @@ class ResnetImageEmbedding(Operator): """ PyTorch model for image embedding. """ - def __init__(self, model_name: str, framework: str = 'pytorch', weights_path: str = None) -> None: + def __init__(self, model_name: str, framework: str = 'pytorch') -> None: super().__init__() sys.path.append(str(Path(__file__).parent)) if framework == 'pytorch': from pytorch.model import Model - self.model = Model(model_name, weights_path) + self.model = Model(model_name) def __call__(self, img_tensor: torch.Tensor) -> NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]): embedding = self.model(img_tensor)