diff --git a/pytorch/model.py b/pytorch/model.py index 007b645..96b71ee 100644 --- a/pytorch/model.py +++ b/pytorch/model.py @@ -14,6 +14,8 @@ import torch +from torch.nn import Linear +from torch import nn import timm @@ -21,19 +23,20 @@ class Model(): """ PyTorch model class """ - def __init__(self, model_name: str, weights_path: str): + def __init__(self, model_name: str, weights_path: str, num_classes=1000): super().__init__() if weights_path: - self._model = timm.create_model(model_name, checkpoint_path=weights_path, num_classes=0) + self._model = timm.create_model(model_name, checkpoint_path=weights_path, num_classes=num_classes) else: - self._model = timm.create_model(model_name, pretrained=True, num_classes=0) + self._model = timm.create_model(model_name, pretrained=True, num_classes=num_classes) self._model.eval() def __call__(self, img_tensor: torch.Tensor): - return self._model(img_tensor) + self._model.eval() + features = self._model.forward_features(img_tensor) + if features.dim() == 4: # if the shape of feature map is [N, C, H, W], where H > 1 and W > 1 + global_pool = nn.AdaptiveAvgPool2d(1) + features = global_pool(features) + return features.flatten().detach().numpy() + - def train(self): - """ - For training model - """ - pass diff --git a/vit_image_embedding.py b/vit_image_embedding.py index cc4a5eb..9314cea 100644 --- a/vit_image_embedding.py +++ b/vit_image_embedding.py @@ -18,10 +18,11 @@ from typing import NamedTuple from pathlib import Path from PIL import Image import torch +from torch import nn as nn import numpy import os -from towhee.operator import Operator +from towhee.operator import Operator, NNOperator from timm.data import resolve_data_config from timm.data.transforms_factory import create_transform from towhee.utils.pil_utils import to_pil @@ -29,7 +30,7 @@ from towhee.utils.pil_utils import to_pil import warnings warnings.filterwarnings("ignore") -class VitImageEmbedding(Operator): +class VitImageEmbedding(NNOperator): """ Embedding extractor using ViT. Args: @@ -39,7 +40,7 @@ class VitImageEmbedding(Operator): Path to local weights. """ - def __init__(self, model_name: str = 'vit_large_patch16_224', + def __init__(self, model_name: str = 'vit_large_patch16_224', num_classes: int = 1000, framework: str = 'pytorch', weights_path: str = None) -> None: super().__init__() if framework == 'pytorch': @@ -49,7 +50,7 @@ class VitImageEmbedding(Operator): spec = importlib.util.spec_from_file_location(opname, path) module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) - self.model = module.Model(model_name, weights_path) + self.model = module.Model(model_name, weights_path, num_classes=num_classes) config = resolve_data_config({}, model=self.model._model) self.tfms = create_transform(**config) @@ -57,4 +58,7 @@ class VitImageEmbedding(Operator): img = self.tfms(to_pil(image)).unsqueeze(0) Outputs = NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]) features = self.model(img) - return Outputs(features.flatten().detach().numpy()) + return Outputs(features) + + def get_model(self) -> nn.Module: + return self.model._model \ No newline at end of file