Browse Source
Update model
Signed-off-by: shiyu22 <shiyu.chen@zilliz.com>
training
2 changed files with
23 additions and
20 deletions
-
pytorch/model.py
-
resnet50_image_embedding.py
|
|
@ -13,28 +13,28 @@ |
|
|
|
# limitations under the License. |
|
|
|
|
|
|
|
|
|
|
|
from typing import NamedTuple |
|
|
|
|
|
|
|
import numpy |
|
|
|
import torch |
|
|
|
import torchvision |
|
|
|
|
|
|
|
|
|
|
|
class Resnet50(): |
|
|
|
class Model(): |
|
|
|
""" |
|
|
|
PyTorch model for image embedding. |
|
|
|
PyTorch model class |
|
|
|
""" |
|
|
|
def __init__(self, model_name: str): |
|
|
|
self.model_name = model_name |
|
|
|
def __init__(self, model_name): |
|
|
|
super().__init__() |
|
|
|
model_func = getattr(torchvision.models, model_name) |
|
|
|
self._model = model_func(pretrained=True) |
|
|
|
self._model.eval() |
|
|
|
|
|
|
|
def load_model(self): |
|
|
|
""" |
|
|
|
For loading model |
|
|
|
""" |
|
|
|
model_func = getattr(torchvision.models, self.model_name) |
|
|
|
self.model = model_func(pretrained=True) |
|
|
|
self.model.eval() |
|
|
|
return self.model |
|
|
|
|
|
|
|
def train_model(self): |
|
|
|
def __call__(self, img_tensor: torch.Tensor): |
|
|
|
return self._model(img_tensor).detach().numpy() |
|
|
|
|
|
|
|
def train(self): |
|
|
|
""" |
|
|
|
For training model |
|
|
|
""" |
|
|
|
pass |
|
|
|
pass |
|
|
@ -26,13 +26,16 @@ class Resnet50ImageEmbedding(Operator): |
|
|
|
""" |
|
|
|
PyTorch model for image embedding. |
|
|
|
""" |
|
|
|
def __init__(self, model_name: str) -> None: |
|
|
|
def __init__(self, model_name: str, framework: str = 'pytorch') -> None: |
|
|
|
super().__init__() |
|
|
|
sys.path.append(str(Path(__file__).parent)) |
|
|
|
from pytorch.resnet50 import Resnet50 |
|
|
|
resnet50_image_embedding = Resnet50(model_name) |
|
|
|
self._model = resnet50_image_embedding.load_model() |
|
|
|
if framework == 'pytorch': |
|
|
|
from pytorch.model import Model |
|
|
|
if framework == 'tensorflow': |
|
|
|
from tensorflow.model import Model |
|
|
|
self.model = Model(model_name) |
|
|
|
|
|
|
|
def __call__(self, img_tensor: torch.Tensor) -> NamedTuple('Outputs', [('cnn', numpy.ndarray)]): |
|
|
|
embedding = self.model(img_tensor) |
|
|
|
Outputs = NamedTuple('Outputs', [('cnn', numpy.ndarray)]) |
|
|
|
return Outputs(self._model(img_tensor).detach().numpy()) |
|
|
|
return Outputs(embedding) |
|
|
|