
Fix the memory leak in embedding_extractor.py and add a `num_classes` parameter to the operator's constructor.

training
zhang chen committed 3 years ago (commit 086c225df0)
5 changed files:
  1. pytorch/embedding_extractor.py (66)
  2. pytorch/model.py (52)
  3. resnet_image_embedding.py (5)
  4. resnet_training_yaml.yaml (53)
  5. test.py (77)

pytorch/embedding_extractor.py (66)

@@ -1,66 +0,0 @@
-# Copyright 2021 Zilliz. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import pprint
-
-
-class EmbeddingOutput:
-    """
-    Container for embedding extractor.
-    """
-    def __init__(self):
-        self.embeddings = []
-
-    def __call__(self, module, module_in, module_out):
-        self.embeddings.append(module_out)
-
-    def clear(self):
-        """
-        Clear the embedding list.
-        """
-        self.embeddings = []
-
-
-class EmbeddingExtractor:
-    """
-    Embedding extractor from a layer.
-
-    Args:
-        model (`nn.Module`):
-            Model used for inference.
-    """
-    def __init__(self, model):
-        # self.modules = model.modules()
-        # self.modules_list = list(model.named_modules(remove_duplicate=False))
-        self.modules_dict = dict(model.named_modules(remove_duplicate=False))
-        self.emb_out = EmbeddingOutput()
-
-    def disp_modules(self, full=False):
-        """
-        Display the modules of the model.
-        """
-        if not full:
-            pprint.pprint(list(self.modules_dict.keys()))
-        else:
-            pprint.pprint(self.modules_dict)
-
-    def register(self, layer_name: str):
-        """
-        Registration for embedding extraction.
-
-        Args:
-            layer_name (`str`):
-                Name of the layer from which the embedding is extracted.
-        """
-        if layer_name in self.modules_dict:
-            layer = self.modules_dict[layer_name]
-            layer.register_forward_hook(self.emb_out)
-        else:
-            raise ValueError('layer_name not in modules')
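Note on the leak being fixed here: `EmbeddingOutput.__call__` appends every raw forward-hook output to a Python list, and those outputs still reference their autograd graphs, so memory keeps growing unless the list is cleared and the tensors are detached after every call. A minimal sketch of that failure mode, using a stock torchvision ResNet purely for illustration (the loop and tensor shapes are not from this repository):

    import torch
    import torchvision

    outputs = []

    def hook(module, module_in, module_out):
        # module_out still carries the autograd graph of this forward pass,
        # so holding it in a plain list keeps the whole graph alive.
        outputs.append(module_out)

    model = torchvision.models.resnet18()
    model.avgpool.register_forward_hook(hook)

    for _ in range(100):
        model(torch.rand(1, 3, 224, 224))

    print(len(outputs))  # 100 retained activations (and their graphs)

Detaching the output inside the hook, or dropping the hook entirely as this commit does, avoids the accumulation.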

pytorch/model.py (52)

@@ -13,57 +13,39 @@
 # limitations under the License.
-from typing import NamedTuple
-import numpy
 import torch
-import torchvision
 from torch.nn import Linear
-from timm.models.resnet import ResNet
-# ResNet.
-from pytorch.embedding_extractor import EmbeddingExtractor
-# TODO: switch to the EmbeddingExtractor under towhee.models.embedding later;
-# it is available on the origin main branch now, but not on the train branch.
+from torch import nn
+import timm


 class Model():
     """
     PyTorch model class
     """
-    def __init__(self, model_name):
+    def __init__(self, model_name, num_classes=1000):
         super().__init__()
-        model_func = getattr(torchvision.models, model_name)
-        self._model = model_func(pretrained=True)
-        state_dict = None
+        self._model = timm.create_model(model_name, pretrained=True)
+        pretrained_dict = None
         if model_name == 'resnet101':
-            state_dict = torch.hub.load_state_dict_from_url(
+            pretrained_dict = torch.hub.load_state_dict_from_url(
                 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet101_a1h-36d3f2aa.pth')
         if model_name == 'resnet50':
-            state_dict = torch.hub.load_state_dict_from_url(
+            pretrained_dict = torch.hub.load_state_dict_from_url(
                 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_a1_0-14fe96d1.pth')
-        if state_dict:
-            self._model.load_state_dict(state_dict)
-        # self._model.fc = torch.nn.Identity()
+        if pretrained_dict:
+            self._model.load_state_dict(pretrained_dict, strict=False)
+        if num_classes != 1000:
+            self.create_classifier(num_classes=num_classes)
         self._model.eval()
-        self.ex = EmbeddingExtractor(self._model)
-        # self.ex.disp_modules(full=True)
-        self.ex.register('avgpool')

     def __call__(self, img_tensor: torch.Tensor):
-        self.ex.emb_out.clear()
-        self._model(img_tensor)
-        # return self.fc_input[0]
-        return self.ex.emb_out.embeddings[0]
-        # return self._model(img_tensor).flatten().detach().numpy()  # todo
+        features = self._model.forward_features(img_tensor)
+        if features.dim() == 4:  # feature map of shape [N, C, H, W], where H > 1 and W > 1
+            global_pool = nn.AdaptiveAvgPool2d(1)
+            features = global_pool(features)
+        return features.flatten().detach().numpy()

     def create_classifier(self, num_classes):
         self._model.fc = Linear(self._model.fc.in_features, num_classes, bias=True)
-        # self._model.classifier.register_forward_hook(self._forward_hook)
-
-    # def train(self):
-    #     """
-    #     For training model
-    #     """
-    #     pass
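The replacement path in `Model.__call__` above avoids hooks entirely: timm's `forward_features` returns the backbone feature map, which is pooled and flattened explicitly. A standalone sketch of that flow (the model name, input size, and use of `torch.no_grad()` are illustrative choices, not taken verbatim from the repository):

    import timm
    import torch
    from torch import nn

    model = timm.create_model('resnet50', pretrained=False)
    model.eval()

    img = torch.rand(1, 3, 224, 224)
    with torch.no_grad():
        features = model.forward_features(img)       # [1, 2048, 7, 7] for resnet50
        if features.dim() == 4:                      # pool spatial dims to a single vector
            features = nn.AdaptiveAvgPool2d(1)(features)
    embedding = features.flatten().numpy()           # length-2048 embedding
    print(embedding.shape)

Because nothing is stored on the model between calls, repeated inference no longer accumulates activations the way the hook-based extractor could.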

resnet_image_embedding.py (5)

@@ -31,7 +31,7 @@ class ResnetImageEmbedding(NNOperator):
     """
     PyTorch model for image embedding.
     """
-    def __init__(self, model_name: str, framework: str = 'pytorch') -> None:
+    def __init__(self, model_name: str, num_classes: int = 1000, framework: str = 'pytorch') -> None:
         super().__init__(framework=framework)
         if framework == 'pytorch':
             import importlib.util
@@ -40,8 +40,7 @@ class ResnetImageEmbedding(NNOperator):
             spec = importlib.util.spec_from_file_location(opname, path)
             module = importlib.util.module_from_spec(spec)
             spec.loader.exec_module(module)
-            self.model = module.Model(model_name)
+            self.model = module.Model(model_name, num_classes=num_classes)
         self.tfms = transforms.Compose([transforms.Resize(235, interpolation=InterpolationMode.BICUBIC),
                                         transforms.CenterCrop(224),
                                         transforms.ToTensor(),
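The new `num_classes` argument flows from the operator into `Model`, which swaps in a fresh classification head whenever the value differs from the 1000 ImageNet classes. A self-contained sketch of what that head replacement amounts to, mirroring `create_classifier()` in pytorch/model.py (the 10-class value matches the MNIST setup in test.py below):

    import timm
    from torch.nn import Linear

    # Equivalent of Model.create_classifier(num_classes=10):
    model = timm.create_model('resnet50', pretrained=False)
    model.fc = Linear(model.fc.in_features, 10, bias=True)
    print(model.fc)  # Linear(in_features=2048, out_features=10, bias=True)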

resnet_training_yaml.yaml (53)

@@ -1,25 +1,42 @@
+callback:
+  early_stopping:
+    mode: max
+    monitor: eval_epoch_metric
+    patience: 2
+  model_checkpoint:
+    every_n_epoch: 2
+  tensorboard:
+    comment: ''
+    log_dir: null
 device:
   device_str: null
   n_gpu: -1
-  sync_bn: true
+  sync_bn: false
+learning:
+  loss: CrossEntropyLoss
+  lr: 5.0e-05
+  lr_scheduler_type: linear
+  optimizer: Adam
+  warmup_ratio: 0.0
+  warmup_steps: 0
+logging:
+  logging_dir: null
+  logging_strategy: steps
+  print_steps: null
+  save_strategy: steps
 metrics:
   metric: Accuracy
 train:
-  batch_size: 32
+  batch_size: 16
+  dataloader_drop_last: false
+  dataloader_num_workers: 0
+  epoch_num: 16
+  eval_steps: null
+  eval_strategy: epoch
+  load_best_model_at_end: false
+  max_steps: -1
+  output_dir: ./output_dir
   overwrite_output_dir: true
-  epoch_num: 2
-learning:
-  optimizer:
-    name_: SGD
-    lr: 0.04
-    momentum: 0.001
-  loss:
-    name_: CrossEntropyLoss
-    ignore_index: -1
-logging:
-  print_steps: 2
-#learning:
-#  optimizer:
-#    name_: Adam
-#    lr: 0.02
-#    eps: 0.001
+  resume_from_checkpoint: null
+  seed: 42
+  val_batch_size: -1
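The reworked config appears to follow the flat section layout that `dump_default_yaml` writes out (callback, device, learning, logging, metrics, train). A quick way to sanity-check the values before a run is to read the file with PyYAML; this snippet only inspects the YAML and does not touch the towhee trainer API:

    import yaml

    with open('resnet_training_yaml.yaml', encoding='utf-8') as f:
        cfg = yaml.safe_load(f)

    print(cfg['train']['batch_size'])    # 16
    print(cfg['train']['epoch_num'])     # 16
    print(cfg['learning']['optimizer'])  # Adam
    print(cfg['learning']['lr'])         # 5e-05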

test.py (77)

@@ -2,10 +2,13 @@ import numpy as np
 from torch.optim import AdamW
 from torchvision import transforms
 from torchvision.transforms import RandomResizedCrop, Lambda
+from towhee.data.dataset.dataset import dataset
 from towhee.trainer.modelcard import ModelCard
 from towhee.trainer.training_config import TrainingConfig
-from towhee.trainer.dataset import get_dataset
+# from towhee.trainer.dataset import get_dataset
+from towhee.trainer.utils.layer_freezer import LayerFreezer
 from resnet_image_embedding import ResnetImageEmbedding
 from towhee.types import Image
 from towhee.trainer.training_config import dump_default_yaml
@@ -17,7 +20,7 @@ if __name__ == '__main__':
     dump_default_yaml(yaml_path='default_config.yaml')
     # img = torch.rand([1, 3, 224, 224])
     img_path = './ILSVRC2012_val_00049771.JPEG'
-    # # logo_path = os.path.join(Path(__file__).parent.parent.parent.parent.resolve(), 'towhee_logo.png')
+    # logo_path = os.path.join(Path(__file__).parent.parent.parent.parent.resolve(), 'towhee_logo.png')
     img = PILImage.open(img_path)
     img_bytes = img.tobytes()
     img_width = img.width
@@ -28,10 +31,12 @@ if __name__ == '__main__':
     array_size = np.array(img).shape
     towhee_img = Image(img_bytes, img_width, img_height, img_channel, img_mode, img_array)
-    op = ResnetImageEmbedding('resnet34')
+    op = ResnetImageEmbedding('resnet50', num_classes=10)
     # op.model_card = ModelCard(model_details="resnet test modelcard", training_data="use resnet test data")
-    # old_out = op(towhee_img)
-    # print(old_out.feature_vector[0])
+    old_out = op(towhee_img)
+    # print(old_out.feature_vector[0][:10])
+    print(old_out.feature_vector[:10])
+    # print(old_out.feature_vector.shape)

     training_config = TrainingConfig()
     yaml_path = 'resnet_training_yaml.yaml'
@@ -46,39 +51,45 @@ if __name__ == '__main__':
     # # device_str='cuda',
     # # n_gpu=4
     # )
+    #
     mnist_transform = transforms.Compose([transforms.ToTensor(),
                                           RandomResizedCrop(224),
                                           Lambda(lambda x: x.repeat(3, 1, 1)),
                                           transforms.Normalize(mean=[0.5], std=[0.5])])
-    train_data = get_dataset('mnist', transform=mnist_transform, download=True, root='data', train=True)
-    eval_data = get_dataset('mnist', transform=mnist_transform, download=True, root='data', train=False)
+    train_data = dataset('mnist', transform=mnist_transform, download=True, root='data', train=True)
+    eval_data = dataset('mnist', transform=mnist_transform, download=True, root='data', train=False)
     # fake_transform = transforms.Compose([transforms.ToTensor(),
     #                                      RandomResizedCrop(224),])
     # train_data = get_dataset('fake', size=20, transform=fake_transform)
-    op.change_before_train(10)
-    trainer = op.setup_trainer()
-    # my_optimimzer = AdamW(op.get_model().parameters(), lr=0.002, betas=(0.91, 0.98), eps=1e-08, weight_decay=0.01, amsgrad=False)
-    # op.setup_trainer()
-    # trainer.add_callback()
-    # trainer.set_optimizer()
-    # op.trainer.set_optimizer(my_optimimzer)
-    # trainer.configs.save_to_yaml('changed_optimizer_yaml.yaml')
-    # my_loss = nn.BCELoss()
-    # trainer.set_loss(my_loss, 'my_loss111')
-    # trainer.configs.save_to_yaml('chaned_loss_yaml.yaml')
-    # op.trainer._create_optimizer()
-    # op.trainer.set_optimizer()
+    # op.change_before_train(num_classes=10)
+    # # trainer = op.setup_trainer()
+    # print(op.get_model())
+    # # my_optimimzer = AdamW(op.get_model().parameters(), lr=0.002, betas=(0.91, 0.98), eps=1e-08, weight_decay=0.01, amsgrad=False)
+    # # op.setup_trainer()
+    #
+    # # trainer.add_callback()
+    # # trainer.set_optimizer()
+    #
+    # # op.trainer.set_optimizer(my_optimimzer)
+    # # trainer.configs.save_to_yaml('changed_optimizer_yaml.yaml')
+    #
+    # # my_loss = nn.BCELoss()
+    # # trainer.set_loss(my_loss, 'my_loss111')
+    # # trainer.configs.save_to_yaml('chaned_loss_yaml.yaml')
+    # # op.trainer._create_optimizer()
+    # # op.trainer.set_optimizer()
+    # # trainer = op.setup_trainer(training_config, train_dataset=train_data, eval_dataset=eval_data)
+    #
+    # # freezer = LayerFreezer(op.get_model())
+    # # freezer.by_idx([-1])
     op.train(training_config, train_dataset=train_data, eval_dataset=eval_data)
-    # training_config.num_epoch = 3
-    # op.train(training_config, train_dataset=train_data, resume_checkpoint_path=training_config.output_dir + '/epoch_2')
-    # op.save('./test_save')
-    # op.load('./test_save')
-    # new_out = op(towhee_img)
-    # assert (new_out[0]!=old_out[0]).all()
+    # # op.trainer.run_train()
+    # # training_config.num_epoch = 3
+    # # op.train(training_config, train_dataset=train_data, resume_checkpoint_path=training_config.output_dir + '/epoch_2')
+    #
+    # # op.save('./test_save')
+    # # op.load('./test_save')
+    # # new_out = op(towhee_img)
+    #
+    # # assert (new_out[0]!=old_out[0]).all()
