diff --git a/ILSVRC2012_val_00049771.JPEG b/ILSVRC2012_val_00049771.JPEG
new file mode 100644
index 0000000..5e9e980
Binary files /dev/null and b/ILSVRC2012_val_00049771.JPEG differ
diff --git a/pytorch/embedding_extractor.py b/pytorch/embedding_extractor.py
new file mode 100644
index 0000000..873b1d8
--- /dev/null
+++ b/pytorch/embedding_extractor.py
@@ -0,0 +1,66 @@
+# Copyright 2021 Zilliz. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import pprint
+
+class EmbeddingOutput:
+    """
+    Container that stores embeddings captured by a forward hook.
+    """
+    def __init__(self):
+        self.embeddings = []
+
+    def __call__(self, module, module_in, module_out):
+        self.embeddings.append(module_out)
+
+    def clear(self):
+        """
+        Clear the stored embeddings.
+        """
+        self.embeddings = []
+
+
+class EmbeddingExtractor:
+    """
+    Extract embeddings from a chosen layer of a model.
+    Args:
+        model (`nn.Module`):
+            Model used for inference.
+    """
+    def __init__(self, model):
+        # self.modules = model.modules()
+        # self.modules_list = list(model.named_modules(remove_duplicate=False))
+        self.modules_dict = dict(model.named_modules(remove_duplicate=False))
+        self.emb_out = EmbeddingOutput()
+
+    def disp_modules(self, full=False):
+        """
+        Display the modules of the model.
+        """
+        if not full:
+            pprint.pprint(list(self.modules_dict.keys()))
+        else:
+            pprint.pprint(self.modules_dict)
+
+    def register(self, layer_name: str):
+        """
+        Register a forward hook on the named layer for embedding extraction.
+        Args:
+            layer_name (`str`):
+                Name of the layer from which the embedding is extracted.
+        """
+        if layer_name in self.modules_dict:
+            layer = self.modules_dict[layer_name]
+            layer.register_forward_hook(self.emb_out)
+        else:
+            raise ValueError(f'layer "{layer_name}" not found in model modules')
\ No newline at end of file
diff --git a/pytorch/model.py b/pytorch/model.py
index 3a32524..183d217 100644
--- a/pytorch/model.py
+++ b/pytorch/model.py
@@ -18,6 +18,13 @@ from typing import NamedTuple
 import numpy
 import torch
 import torchvision
+from torch.nn import Linear
+from timm.models.resnet import ResNet
+
+# Hook-based embedding extraction from an intermediate layer.
+from pytorch.embedding_extractor import EmbeddingExtractor
+# TODO: switch to the EmbeddingExtractor under towhee.models.embedding; it is available on the origin main branch but not yet on the train branch.
+
 
 
 class Model():
@@ -38,14 +45,25 @@ class Model():
 
         if state_dict:
             self._model.load_state_dict(state_dict)
-        self._model.fc = torch.nn.Identity()
+        # self._model.fc = torch.nn.Identity()
         self._model.eval()
+        self.ex = EmbeddingExtractor(self._model)
+        # self.ex.disp_modules(full=True)
+        self.ex.register('avgpool')
 
     def __call__(self, img_tensor: torch.Tensor):
-        return self._model(img_tensor).flatten().detach().numpy()
+        self.ex.emb_out.clear()
+        self._model(img_tensor)
+        # return self.fc_input[0]
+        return self.ex.emb_out.embeddings[0]
+        # return self._model(img_tensor).flatten().detach().numpy()  # TODO
+
+    def create_classifier(self, num_classes):
+        self._model.fc = Linear(self._model.fc.in_features, num_classes, bias=True)
+        # self._model.classifier.register_forward_hook(self._forward_hook)
 
-    def train(self):
-        """
-        For training model
-        """
-        pass
\ No newline at end of file
+    # def train(self):
+    #     """
+    #     For training model
+    #     """
+    #     pass
\ No newline at end of file
diff --git a/resnet_image_embedding.py b/resnet_image_embedding.py
index b62b9a4..c4aafa1 100644
--- a/resnet_image_embedding.py
+++ b/resnet_image_embedding.py
@@ -21,12 +21,12 @@ from pathlib import Path
 from typing import NamedTuple
 import os
 from torchvision.transforms import InterpolationMode
-from towhee.operator import Operator
+from towhee.operator import NNOperator
 from towhee.utils.pil_utils import to_pil
 import warnings
 warnings.filterwarnings("ignore")
 
-class ResnetImageEmbedding(Operator):
+class ResnetImageEmbedding(NNOperator):
     """
     PyTorch model for image embedding.
     """
@@ -50,3 +50,13 @@ class ResnetImageEmbedding(Operator):
         embedding = self.model(img)
         Outputs = NamedTuple('Outputs', [('feature_vector', numpy.ndarray)])
         return Outputs(embedding)
+
+    def get_model(self):
+        return self.model._model
+
+    # def test(self):
+    #     return self.framework
+
+    def change_before_train(self, num_classes: int = 0):
+        if num_classes > 0:
+            self.model.create_classifier(num_classes)
\ No newline at end of file
diff --git a/resnet_training_yaml.yaml b/resnet_training_yaml.yaml
new file mode 100644
index 0000000..c003c2f
--- /dev/null
+++ b/resnet_training_yaml.yaml
@@ -0,0 +1,22 @@
+device:
+  device_str: null
+  n_gpu: -1
+  sync_bn: true
+metrics:
+  metric: Accuracy
+train:
+  batch_size: 16
+learning:
+  optimizer:
+    name_: SGD
+    lr: 0.03
+    momentum: 0.001
+    nesterov: true
+  loss:
+    name_: CrossEntropyLoss
+    label_smoothing: 0.1
+#learning:
+#  optimizer:
+#    name_: Adam
+#    lr: 0.02
+#    eps: 0.001
diff --git a/test.py b/test.py
new file mode 100644
index 0000000..04ad646
--- /dev/null
+++ b/test.py
@@ -0,0 +1,65 @@
+import numpy as np
+from torchvision import transforms
+from torchvision.transforms import RandomResizedCrop, Lambda
+from towhee.trainer.modelcard import ModelCard
+
+from towhee.trainer.training_config import TrainingConfig
+from towhee.trainer.dataset import get_dataset
+from resnet_image_embedding import ResnetImageEmbedding
+from towhee.types import Image
+from towhee.trainer.training_config import dump_default_yaml
+from PIL import Image as PILImage
+from timm.models.resnet import ResNet
+
+if __name__ == '__main__':
+    # img = torch.rand([1, 3, 224, 224])
+    img_path = './ILSVRC2012_val_00049771.JPEG'
+    # # logo_path = os.path.join(Path(__file__).parent.parent.parent.parent.resolve(), 'towhee_logo.png')
+    img = PILImage.open(img_path)
+    img_bytes = img.tobytes()
+    img_width = img.width
+    img_height = img.height
+    img_channel = len(img.split())
+    img_mode = img.mode
+    img_array = np.array(img)
+    array_size = np.array(img).shape
+    towhee_img = Image(img_bytes, img_width, img_height, img_channel, img_mode, img_array)
+
+    op = ResnetImageEmbedding('resnet34')
+    op.model_card = ModelCard(model_details="resnet test modelcard", training_data="use resnet test data")
+    # old_out = op(towhee_img)
+
+    # print(old_out.feature_vector[0])
+
+    training_config = TrainingConfig()
+    yaml_path = 'resnet_training_yaml.yaml'
+    # dump_default_yaml(yaml_path=yaml_path)
+    training_config.load_from_yaml(yaml_path)
+    # output_dir='./temp_output',
+    # overwrite_output_dir=True,
+    # epoch_num=2,
+    # per_gpu_train_batch_size=16,
+    # prediction_loss_only=True,
+    # metric='Accuracy'
+    # # device_str='cuda',
+    # # n_gpu=4
+    # )
+
+
+    mnist_transform = transforms.Compose([transforms.ToTensor(),
+                                          RandomResizedCrop(224),
+                                          Lambda(lambda x: x.repeat(3, 1, 1)),
+                                          transforms.Normalize(mean=[0.5], std=[0.5])])
+    train_data = get_dataset('mnist', transform=mnist_transform, download=True, root='data')
+
+    fake_transform = transforms.Compose([transforms.ToTensor(),
+                                         RandomResizedCrop(224),])
+    # train_data = get_dataset('fake', size=20, transform=fake_transform)
+
+    op.change_before_train(10)
+    op.train(training_config, train_dataset=train_data)
+    # op.save('./test_save')
+    # op.load('./test_save')
+    # new_out = op(towhee_img)
+
+    # assert (new_out[0] != old_out[0]).all()
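
Reviewer note: below is a minimal, self-contained sketch of how the new EmbeddingExtractor is intended to be used. It runs the extractor against a plain torchvision resnet34 rather than through the Model wrapper, so the torchvision model, the dummy input size, and the printed shapes are illustrative assumptions; the layer name 'avgpool' matches the one registered in pytorch/model.py.

```python
# Sketch only: exercises pytorch/embedding_extractor.py outside towhee.
import torch
import torchvision

from pytorch.embedding_extractor import EmbeddingExtractor

model = torchvision.models.resnet34(pretrained=True)
model.eval()

ex = EmbeddingExtractor(model)
ex.disp_modules()        # prints layer names, e.g. 'layer4', 'avgpool', 'fc'
ex.register('avgpool')   # hook the global average pool, as model.py does

with torch.no_grad():
    model(torch.rand(1, 3, 224, 224))   # dummy batch triggers the hook

emb = ex.emb_out.embeddings[0]   # avgpool output, shape [1, 512, 1, 1]
print(emb.flatten().shape)       # 512-dim embedding for resnet34
ex.emb_out.clear()               # reset before the next forward pass
```

Because the hook reads the avgpool output directly, the head no longer has to be removed to expose the embedding, which is why `self._model.fc = torch.nn.Identity()` is commented out in pytorch/model.py.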
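
In the same spirit, a sketch of the head swap that `Model.create_classifier` performs (reached via `change_before_train`); `num_classes=10` mirrors the MNIST run in test.py, and the torchvision model is again a stand-in:

```python
# Sketch only: replace the 1000-way ImageNet head with a fresh Linear layer,
# the same swap create_classifier applies to self._model.fc before training.
import torchvision
from torch.nn import Linear

model = torchvision.models.resnet34(pretrained=True)
model.fc = Linear(model.fc.in_features, 10, bias=True)
print(model.fc)   # Linear(in_features=512, out_features=10, bias=True)
```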