From 086c225df00029da7e9884ad61bba94399dcf8a4 Mon Sep 17 00:00:00 2001 From: zhang chen Date: Mon, 21 Feb 2022 17:49:11 +0800 Subject: [PATCH] fix embedding_extractor.py memory leak problem, and add a `num_classes` param in the construct function of operator. --- pytorch/embedding_extractor.py | 66 ----------------------------- pytorch/model.py | 52 ++++++++--------------- resnet_image_embedding.py | 5 +-- resnet_training_yaml.yaml | 53 +++++++++++++++-------- test.py | 77 +++++++++++++++++++--------------- 5 files changed, 98 insertions(+), 155 deletions(-) delete mode 100644 pytorch/embedding_extractor.py diff --git a/pytorch/embedding_extractor.py b/pytorch/embedding_extractor.py deleted file mode 100644 index 873b1d8..0000000 --- a/pytorch/embedding_extractor.py +++ /dev/null @@ -1,66 +0,0 @@ -# Copyright 2021 Zilliz. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import pprint - -class EmbeddingOutput: - """ - Container for embedding extractor. - """ - def __init__(self): - self.embeddings = [] - - def __call__(self, module, module_in, module_out): - self.embeddings.append(module_out) - - def clear(self): - """ - clear list - """ - self.embeddings = [] - - -class EmbeddingExtractor: - """ - Embedding extractor from a layer - Args: - model (`nn.Module`): - Model used for inference. - """ - def __init__(self, model): - # self.modules = model.modules() - # self.modules_list = list(model.named_modules(remove_duplicate=False)) - self.modules_dict = dict(model.named_modules(remove_duplicate=False)) - self.emb_out = EmbeddingOutput() - - def disp_modules(self, full=False): - """ - Display the the modules of the model. - """ - if not full: - pprint.pprint(list(self.modules_dict.keys())) - else: - pprint.pprint(self.modules_dict) - - def register(self, layer_name: str): - """ - Registration for embedding extraction. - Args: - layer_name (`str`): - Name of the layer from which the embedding is extracted. - """ - if layer_name in self.modules_dict: - layer = self.modules_dict[layer_name] - layer.register_forward_hook(self.emb_out) - else: - raise ValueError('layer_name not in modules') \ No newline at end of file diff --git a/pytorch/model.py b/pytorch/model.py index 183d217..2d5d36a 100644 --- a/pytorch/model.py +++ b/pytorch/model.py @@ -13,57 +13,39 @@ # limitations under the License. -from typing import NamedTuple - -import numpy import torch -import torchvision from torch.nn import Linear -from timm.models.resnet import ResNet - -# ResNet. -from pytorch.embedding_extractor import EmbeddingExtractor -#todo:后面改成用towhee.models.embedding.下面的EmbeddingExtractor,这个现在在origin main分支上可用,但在train分支上不可用 - +from torch import nn +import timm class Model(): """ PyTorch model class """ - def __init__(self, model_name): + + def __init__(self, model_name, num_classes=1000): super().__init__() - model_func = getattr(torchvision.models, model_name) - self._model = model_func(pretrained=True) - state_dict = None + self._model = timm.create_model(model_name, pretrained=True) + pretrained_dict = None if model_name == 'resnet101': - state_dict = torch.hub.load_state_dict_from_url( + pretrained_dict = torch.hub.load_state_dict_from_url( 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet101_a1h-36d3f2aa.pth') if model_name == 'resnet50': - state_dict = torch.hub.load_state_dict_from_url( + pretrained_dict = torch.hub.load_state_dict_from_url( 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_a1_0-14fe96d1.pth') - if state_dict: - self._model.load_state_dict(state_dict) - - # self._model.fc = torch.nn.Identity() + if pretrained_dict: + self._model.load_state_dict(pretrained_dict, strict=False) + if num_classes != 1000: + self.create_classifier(num_classes=num_classes) self._model.eval() - self.ex = EmbeddingExtractor(self._model) - # self.ex.disp_modules(full=True) - self.ex.register('avgpool') def __call__(self, img_tensor: torch.Tensor): - self.ex.emb_out.clear() - self._model(img_tensor) - # return self.fc_input[0] - return self.ex.emb_out.embeddings[0] - # return self._model(img_tensor).flatten().detach().numpy() #todo + features = self._model.forward_features(img_tensor) + if features.dim() == 4: # if the shape of feature map is [N, C, H, W], where H > 1 and W > 1 + global_pool = nn.AdaptiveAvgPool2d(1) + features = global_pool(features) + return features.flatten().detach().numpy() def create_classifier(self, num_classes): self._model.fc = Linear(self._model.fc.in_features, num_classes, bias=True) - # self._model.classifier.register_forward_hook(self._forward_hook) - - # def train(self): - # """ - # For training model - # """ - # pass \ No newline at end of file diff --git a/resnet_image_embedding.py b/resnet_image_embedding.py index 0cd54f3..f4de412 100644 --- a/resnet_image_embedding.py +++ b/resnet_image_embedding.py @@ -31,7 +31,7 @@ class ResnetImageEmbedding(NNOperator): """ PyTorch model for image embedding. """ - def __init__(self, model_name: str, framework: str = 'pytorch') -> None: + def __init__(self, model_name: str, num_classes: int = 1000, framework: str = 'pytorch') -> None: super().__init__(framework=framework) if framework == 'pytorch': import importlib.util @@ -40,8 +40,7 @@ class ResnetImageEmbedding(NNOperator): spec = importlib.util.spec_from_file_location(opname, path) module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) - self.model = module.Model(model_name) - + self.model = module.Model(model_name, num_classes=num_classes) self.tfms = transforms.Compose([transforms.Resize(235, interpolation=InterpolationMode.BICUBIC), transforms.CenterCrop(224), transforms.ToTensor(), diff --git a/resnet_training_yaml.yaml b/resnet_training_yaml.yaml index 34129b5..65c222a 100644 --- a/resnet_training_yaml.yaml +++ b/resnet_training_yaml.yaml @@ -1,25 +1,42 @@ +callback: + early_stopping: + mode: max + monitor: eval_epoch_metric + patience: 2 + model_checkpoint: + every_n_epoch: 2 + tensorboard: + comment: '' + log_dir: null device: device_str: null n_gpu: -1 - sync_bn: true + sync_bn: false +learning: + loss: CrossEntropyLoss + lr: 5.0e-05 + lr_scheduler_type: linear + optimizer: Adam + warmup_ratio: 0.0 + warmup_steps: 0 +logging: + logging_dir: null + logging_strategy: steps + print_steps: null + save_strategy: steps metrics: metric: Accuracy train: - batch_size: 32 + batch_size: 16 + dataloader_drop_last: false + dataloader_num_workers: 0 + epoch_num: 16 + eval_steps: null + eval_strategy: epoch + load_best_model_at_end: false + max_steps: -1 + output_dir: ./output_dir overwrite_output_dir: true - epoch_num: 2 -learning: - optimizer: - name_: SGD - lr: 0.04 - momentum: 0.001 - loss: - name_: CrossEntropyLoss - ignore_index: -1 -logging: - print_steps: 2 -#learning: -# optimizer: -# name_: Adam -# lr: 0.02 -# eps: 0.001 + resume_from_checkpoint: null + seed: 42 + val_batch_size: -1 diff --git a/test.py b/test.py index 23d52ca..75b75ce 100644 --- a/test.py +++ b/test.py @@ -2,10 +2,13 @@ import numpy as np from torch.optim import AdamW from torchvision import transforms from torchvision.transforms import RandomResizedCrop, Lambda +from towhee.data.dataset.dataset import dataset from towhee.trainer.modelcard import ModelCard from towhee.trainer.training_config import TrainingConfig -from towhee.trainer.dataset import get_dataset +# from towhee.trainer.dataset import get_dataset +from towhee.trainer.utils.layer_freezer import LayerFreezer + from resnet_image_embedding import ResnetImageEmbedding from towhee.types import Image from towhee.trainer.training_config import dump_default_yaml @@ -17,7 +20,7 @@ if __name__ == '__main__': dump_default_yaml(yaml_path='default_config.yaml') # img = torch.rand([1, 3, 224, 224]) img_path = './ILSVRC2012_val_00049771.JPEG' - # # logo_path = os.path.join(Path(__file__).parent.parent.parent.parent.resolve(), 'towhee_logo.png') + # logo_path = os.path.join(Path(__file__).parent.parent.parent.parent.resolve(), 'towhee_logo.png') img = PILImage.open(img_path) img_bytes = img.tobytes() img_width = img.width @@ -28,10 +31,12 @@ if __name__ == '__main__': array_size = np.array(img).shape towhee_img = Image(img_bytes, img_width, img_height, img_channel, img_mode, img_array) - op = ResnetImageEmbedding('resnet34') + op = ResnetImageEmbedding('resnet50', num_classes=10) # op.model_card = ModelCard(model_details="resnet test modelcard", training_data="use resnet test data") - # old_out = op(towhee_img) - # print(old_out.feature_vector[0]) + old_out = op(towhee_img) + # print(old_out.feature_vector[0][:10]) + print(old_out.feature_vector[:10]) + # print(old_out.feature_vector.shape) training_config = TrainingConfig() yaml_path = 'resnet_training_yaml.yaml' @@ -46,39 +51,45 @@ if __name__ == '__main__': # # device_str='cuda', # # n_gpu=4 # ) - + # mnist_transform = transforms.Compose([transforms.ToTensor(), RandomResizedCrop(224), Lambda(lambda x: x.repeat(3, 1, 1)), transforms.Normalize(mean=[0.5], std=[0.5])]) - train_data = get_dataset('mnist', transform=mnist_transform, download=True, root='data', train=True) - eval_data = get_dataset('mnist', transform=mnist_transform, download=True, root='data', train=False) + train_data = dataset('mnist', transform=mnist_transform, download=True, root='data', train=True) + eval_data = dataset('mnist', transform=mnist_transform, download=True, root='data', train=False) # fake_transform = transforms.Compose([transforms.ToTensor(), # RandomResizedCrop(224),]) # train_data = get_dataset('fake', size=20, transform=fake_transform) - - op.change_before_train(10) - trainer = op.setup_trainer() - # my_optimimzer = AdamW(op.get_model().parameters(), lr=0.002, betas=(0.91, 0.98), eps=1e-08, weight_decay=0.01, amsgrad=False) - # op.setup_trainer() - - # trainer.add_callback() - # trainer.set_optimizer() - - # op.trainer.set_optimizer(my_optimimzer) - # trainer.configs.save_to_yaml('changed_optimizer_yaml.yaml') - - # my_loss = nn.BCELoss() - # trainer.set_loss(my_loss, 'my_loss111') - # trainer.configs.save_to_yaml('chaned_loss_yaml.yaml') - # op.trainer._create_optimizer() - # op.trainer.set_optimizer() + # + # op.change_before_train(num_classes=10) + # # trainer = op.setup_trainer() + # print(op.get_model()) + # # my_optimimzer = AdamW(op.get_model().parameters(), lr=0.002, betas=(0.91, 0.98), eps=1e-08, weight_decay=0.01, amsgrad=False) + # # op.setup_trainer() + # + # # trainer.add_callback() + # # trainer.set_optimizer() + # + # # op.trainer.set_optimizer(my_optimimzer) + # # trainer.configs.save_to_yaml('changed_optimizer_yaml.yaml') + # + # # my_loss = nn.BCELoss() + # # trainer.set_loss(my_loss, 'my_loss111') + # # trainer.configs.save_to_yaml('chaned_loss_yaml.yaml') + # # op.trainer._create_optimizer() + # # op.trainer.set_optimizer() + # # trainer = op.setup_trainer(training_config, train_dataset=train_data, eval_dataset=eval_data) + # + # # freezer = LayerFreezer(op.get_model()) + # # freezer.by_idx([-1]) op.train(training_config, train_dataset=train_data, eval_dataset=eval_data) - # training_config.num_epoch = 3 - # op.train(training_config, train_dataset=train_data, resume_checkpoint_path=training_config.output_dir + '/epoch_2') - - # op.save('./test_save') - # op.load('./test_save') - # new_out = op(towhee_img) - - # assert (new_out[0]!=old_out[0]).all() + # # op.trainer.run_train() + # # training_config.num_epoch = 3 + # # op.train(training_config, train_dataset=train_data, resume_checkpoint_path=training_config.output_dir + '/epoch_2') + # + # # op.save('./test_save') + # # op.load('./test_save') + # # new_out = op(towhee_img) + # + # # assert (new_out[0]!=old_out[0]).all()