From 4adfae8f02d7ffad111da8529262977639ffbdc2 Mon Sep 17 00:00:00 2001 From: shiyu22 Date: Thu, 9 Dec 2021 14:03:48 +0800 Subject: [PATCH] Update model Signed-off-by: shiyu22 --- pytorch/{resnet50.py => model.py} | 30 +++++++++++++++--------------- resnet50_image_embedding.py | 13 ++++++++----- 2 files changed, 23 insertions(+), 20 deletions(-) rename pytorch/{resnet50.py => model.py} (62%) diff --git a/pytorch/resnet50.py b/pytorch/model.py similarity index 62% rename from pytorch/resnet50.py rename to pytorch/model.py index 7044957..ea0e408 100644 --- a/pytorch/resnet50.py +++ b/pytorch/model.py @@ -13,28 +13,28 @@ # limitations under the License. +from typing import NamedTuple + +import numpy import torch import torchvision -class Resnet50(): +class Model(): """ - PyTorch model for image embedding. + PyTorch model class """ - def __init__(self, model_name: str): - self.model_name = model_name + def __init__(self, model_name): + super().__init__() + model_func = getattr(torchvision.models, model_name) + self._model = model_func(pretrained=True) + self._model.eval() - def load_model(self): - """ - For loading model - """ - model_func = getattr(torchvision.models, self.model_name) - self.model = model_func(pretrained=True) - self.model.eval() - return self.model - - def train_model(self): + def __call__(self, img_tensor: torch.Tensor): + return self._model(img_tensor).detach().numpy() + + def train(self): """ For training model """ - pass + pass \ No newline at end of file diff --git a/resnet50_image_embedding.py b/resnet50_image_embedding.py index 6ce366d..66f3bdb 100644 --- a/resnet50_image_embedding.py +++ b/resnet50_image_embedding.py @@ -26,13 +26,16 @@ class Resnet50ImageEmbedding(Operator): """ PyTorch model for image embedding. """ - def __init__(self, model_name: str) -> None: + def __init__(self, model_name: str, framework: str = 'pytorch') -> None: super().__init__() sys.path.append(str(Path(__file__).parent)) - from pytorch.resnet50 import Resnet50 - resnet50_image_embedding = Resnet50(model_name) - self._model = resnet50_image_embedding.load_model() + if framework == 'pytorch': + from pytorch.model import Model + if framework == 'tensorflow': + from tensorflow.model import Model + self.model = Model(model_name) def __call__(self, img_tensor: torch.Tensor) -> NamedTuple('Outputs', [('cnn', numpy.ndarray)]): + embedding = self.model(img_tensor) Outputs = NamedTuple('Outputs', [('cnn', numpy.ndarray)]) - return Outputs(self._model(img_tensor).detach().numpy()) + return Outputs(embedding)