logo
Browse Source

Update model

Signed-off-by: shiyu22 <shiyu.chen@zilliz.com>
training
shiyu22 3 years ago
parent
commit
4adfae8f02
  1. 30
      pytorch/model.py
  2. 13
      resnet50_image_embedding.py

30
pytorch/resnet50.py → pytorch/model.py

@ -13,28 +13,28 @@
# limitations under the License.
from typing import NamedTuple
import numpy
import torch
import torchvision
class Resnet50():
class Model():
"""
PyTorch model for image embedding.
PyTorch model class
"""
def __init__(self, model_name: str):
self.model_name = model_name
def __init__(self, model_name):
super().__init__()
model_func = getattr(torchvision.models, model_name)
self._model = model_func(pretrained=True)
self._model.eval()
def load_model(self):
"""
For loading model
"""
model_func = getattr(torchvision.models, self.model_name)
self.model = model_func(pretrained=True)
self.model.eval()
return self.model
def train_model(self):
def __call__(self, img_tensor: torch.Tensor):
return self._model(img_tensor).detach().numpy()
def train(self):
"""
For training model
"""
pass
pass

13
resnet50_image_embedding.py

@ -26,13 +26,16 @@ class Resnet50ImageEmbedding(Operator):
"""
PyTorch model for image embedding.
"""
def __init__(self, model_name: str) -> None:
def __init__(self, model_name: str, framework: str = 'pytorch') -> None:
super().__init__()
sys.path.append(str(Path(__file__).parent))
from pytorch.resnet50 import Resnet50
resnet50_image_embedding = Resnet50(model_name)
self._model = resnet50_image_embedding.load_model()
if framework == 'pytorch':
from pytorch.model import Model
if framework == 'tensorflow':
from tensorflow.model import Model
self.model = Model(model_name)
def __call__(self, img_tensor: torch.Tensor) -> NamedTuple('Outputs', [('cnn', numpy.ndarray)]):
embedding = self.model(img_tensor)
Outputs = NamedTuple('Outputs', [('cnn', numpy.ndarray)])
return Outputs(self._model(img_tensor).detach().numpy())
return Outputs(embedding)

Loading…
Cancel
Save