logo
Browse Source

Update

Signed-off-by: shiyu22 <shiyu.chen@zilliz.com>
training
shiyu22 3 years ago
parent
commit
f6b8ce3ba7
  1. 4
      resnet_image_embedding.py

4
resnet_image_embedding.py

@ -26,12 +26,12 @@ class ResnetImageEmbedding(Operator):
""" """
PyTorch model for image embedding. PyTorch model for image embedding.
""" """
def __init__(self, model_name: str, framework: str = 'pytorch', weights_path: str = None) -> None:
def __init__(self, model_name: str, framework: str = 'pytorch') -> None:
super().__init__() super().__init__()
sys.path.append(str(Path(__file__).parent)) sys.path.append(str(Path(__file__).parent))
if framework == 'pytorch': if framework == 'pytorch':
from pytorch.model import Model from pytorch.model import Model
self.model = Model(model_name, weights_path)
self.model = Model(model_name)
def __call__(self, img_tensor: torch.Tensor) -> NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]): def __call__(self, img_tensor: torch.Tensor) -> NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]):
embedding = self.model(img_tensor) embedding = self.model(img_tensor)

Loading…
Cancel
Save