diff --git a/pytorch/__init__.py b/pytorch/__init__.py
deleted file mode 100644
index b661573..0000000
--- a/pytorch/__init__.py
+++ /dev/null
@@ -1,24 +0,0 @@
-# Copyright 2021 Zilliz. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-
-# For requirements.
-try:
-    import timm
-except ModuleNotFoundError:
-    os.system('pip install timm')
-
-from timm.data import resolve_data_config
-from timm.data.transforms_factory import create_transform
\ No newline at end of file
diff --git a/pytorch/model.py b/pytorch/model.py
deleted file mode 100644
index 96b71ee..0000000
--- a/pytorch/model.py
+++ /dev/null
@@ -1,42 +0,0 @@
-# Copyright 2021 Zilliz. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-import torch
-from torch.nn import Linear
-from torch import nn
-import timm
-
-
-class Model():
-    """
-    PyTorch model class
-    """
-    def __init__(self, model_name: str, weights_path: str, num_classes=1000):
-        super().__init__()
-        if weights_path:
-            self._model = timm.create_model(model_name, checkpoint_path=weights_path, num_classes=num_classes)
-        else:
-            self._model = timm.create_model(model_name, pretrained=True, num_classes=num_classes)
-        self._model.eval()
-
-    def __call__(self, img_tensor: torch.Tensor):
-        self._model.eval()
-        features = self._model.forward_features(img_tensor)
-        if features.dim() == 4:  # if the shape of feature map is [N, C, H, W], where H > 1 and W > 1
-            global_pool = nn.AdaptiveAvgPool2d(1)
-            features = global_pool(features)
-        return features.flatten().detach().numpy()
-
-
diff --git a/vit_image_embedding.py b/vit_image_embedding.py
index 9314cea..a8ca545 100644
--- a/vit_image_embedding.py
+++ b/vit_image_embedding.py
@@ -13,23 +13,21 @@
 # limitations under the License.
 
 
-import sys
-from typing import NamedTuple
-from pathlib import Path
-from PIL import Image
+import timm
 import torch
-from torch import nn as nn
 import numpy
-import os
-
-from towhee.operator import Operator, NNOperator
+from torch import nn as nn
+from typing import NamedTuple
+from towhee.operator import NNOperator
 from timm.data import resolve_data_config
 from timm.data.transforms_factory import create_transform
 from towhee.utils.pil_utils import to_pil
 
 import warnings
+
 warnings.filterwarnings("ignore")
 
+
 class VitImageEmbedding(NNOperator):
     """
     Embedding extractor using ViT.
@@ -42,23 +40,25 @@ class VitImageEmbedding(NNOperator):
     def __init__(self, model_name: str = 'vit_large_patch16_224',
                  num_classes: int = 1000, framework: str = 'pytorch',
                  weights_path: str = None) -> None:
-        super().__init__()
-        if framework == 'pytorch':
-            import importlib.util
-            path = os.path.join(str(Path(__file__).parent), 'pytorch', 'model.py')
-            opname = os.path.basename(str(Path(__file__))).split('.')[0]
-            spec = importlib.util.spec_from_file_location(opname, path)
-            module = importlib.util.module_from_spec(spec)
-            spec.loader.exec_module(module)
-            self.model = module.Model(model_name, weights_path, num_classes=num_classes)
-            config = resolve_data_config({}, model=self.model._model)
+        super().__init__(framework=framework)
+        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+        if weights_path:
+            self.model = timm.create_model(model_name, checkpoint_path=weights_path, num_classes=num_classes)
+        else:
+            self.model = timm.create_model(model_name, pretrained=True, num_classes=num_classes)
+        self.model.eval()
+        config = resolve_data_config({}, model=self.model)
         self.tfms = create_transform(**config)
 
     def __call__(self, image: 'towhee.types.Image') -> NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]):
-        img = self.tfms(to_pil(image)).unsqueeze(0)
+        img_tensor = self.tfms(to_pil(image)).unsqueeze(0)
         Outputs = NamedTuple('Outputs', [('feature_vector', numpy.ndarray)])
-        features = self.model(img)
+        self.model.to(self.device)
+        self.model.eval()
+        features = self.model.forward_features(img_tensor)
+        if features.dim() == 4:  # if the shape of feature map is [N, C, H, W], where H > 1 and W > 1
+            global_pool = nn.AdaptiveAvgPool2d(1)
+            features = global_pool(features)
+        features = features.to('cpu')
+        features = features.flatten().detach().numpy()
         return Outputs(features)
-
-    def get_model(self) -> nn.Module:
-        return self.model._model
\ No newline at end of file
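
For reference, a minimal standalone sketch of the embedding path this patch inlines into VitImageEmbedding.__call__: timm model creation, resolve_data_config/create_transform preprocessing, forward_features, and the adaptive-average-pool fallback for 4-D feature maps. It bypasses the towhee operator wrapper; the model name 'vit_base_patch16_224', the image path 'test.jpg', and the torch.no_grad() guard are illustrative assumptions, not part of the patch.

import timm
import torch
from torch import nn
from PIL import Image
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform

# Build a pretrained ViT and the preprocessing pipeline that matches it.
model = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=1000)
model.eval()
config = resolve_data_config({}, model=model)
tfms = create_transform(**config)

# Preprocess a single image and extract backbone features.
# 'test.jpg' is a placeholder path, not a file shipped with the operator.
img_tensor = tfms(Image.open('test.jpg').convert('RGB')).unsqueeze(0)
with torch.no_grad():
    features = model.forward_features(img_tensor)

# CNN-style backbones return [N, C, H, W] feature maps; pool them to a vector,
# mirroring the AdaptiveAvgPool2d branch added in __call__.
if features.dim() == 4:
    features = nn.AdaptiveAvgPool2d(1)(features)

embedding = features.flatten().detach().numpy()
print(embedding.shape)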