logo
Browse Source

Fix the memory-leak problem in embedding_extractor.py, and add a `num_classes` parameter to the operator's constructor.

training
zhang chen 3 years ago
parent
commit
086c225df0
  1. 66
      pytorch/embedding_extractor.py
  2. 52
      pytorch/model.py
  3. 5
      resnet_image_embedding.py
  4. 53
      resnet_training_yaml.yaml
  5. 77
      test.py

66
pytorch/embedding_extractor.py

@ -1,66 +0,0 @@
# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pprint
class EmbeddingOutput:
    """
    Callable container that collects outputs captured by a forward hook.

    An instance is passed to ``register_forward_hook``; on every forward
    pass the hooked module's output is appended to ``embeddings``.
    """

    def __init__(self):
        # Captured module outputs, in the order the hook fired.
        self.embeddings = []

    def __call__(self, module, module_in, module_out):
        # PyTorch forward-hook signature: (module, input, output).
        self.embeddings.append(module_out)

    def clear(self):
        """Drop every captured output."""
        self.embeddings = []
class EmbeddingExtractor:
    """
    Embedding extractor from a named layer of a model.

    Args:
        model (`nn.Module`):
            Model used for inference.
    """

    def __init__(self, model):
        # remove_duplicate=False keeps re-used submodules addressable
        # under every name they appear with in the model tree.
        self.modules_dict = dict(model.named_modules(remove_duplicate=False))
        self.emb_out = EmbeddingOutput()
        # Handles returned by register_forward_hook. The original code
        # discarded them, so hooks could never be detached and every
        # captured activation stayed reachable forever (memory leak).
        self.handles = []

    def disp_modules(self, full=False):
        """
        Display the modules of the model.

        Args:
            full (`bool`):
                If True, print the module objects; otherwise only names.
        """
        if not full:
            pprint.pprint(list(self.modules_dict.keys()))
        else:
            pprint.pprint(self.modules_dict)

    def register(self, layer_name: str):
        """
        Register a forward hook for embedding extraction.

        Args:
            layer_name (`str`):
                Name of the layer from which the embedding is extracted.

        Raises:
            ValueError: if ``layer_name`` is not a module of the model.
        """
        if layer_name not in self.modules_dict:
            raise ValueError('layer_name not in modules')
        layer = self.modules_dict[layer_name]
        # Keep the RemovableHandle so the hook can be detached later.
        self.handles.append(layer.register_forward_hook(self.emb_out))

    def remove_hooks(self):
        """Detach all registered hooks and drop captured embeddings."""
        for handle in self.handles:
            handle.remove()
        self.handles = []
        self.emb_out.clear()

52
pytorch/model.py

@ -13,57 +13,39 @@
# limitations under the License.
from typing import NamedTuple
import numpy
import torch
import torchvision
from torch.nn import Linear
from timm.models.resnet import ResNet
# ResNet.
from pytorch.embedding_extractor import EmbeddingExtractor
#todo:后面改成用towhee.models.embedding.下面的EmbeddingExtractor,这个现在在origin main分支上可用,但在train分支上不可用
from torch import nn
import timm
class Model():
    """
    PyTorch image-embedding model backed by timm.

    Args:
        model_name (`str`):
            timm model name, e.g. 'resnet50'.
        num_classes (`int`):
            Size of the classification head. When it differs from the
            timm default of 1000 the head is re-created to match.
    """
    # TODO: switch to the EmbeddingExtractor under
    # towhee.models.embedding once it is available on the train branch
    # (it currently works only on the origin main branch).

    def __init__(self, model_name, num_classes=1000):
        super().__init__()
        self._model = timm.create_model(model_name, pretrained=True)
        # Prefer the stronger "ResNet Strikes Back" (a1/a1h) weights
        # when they exist for the requested architecture.
        pretrained_dict = None
        if model_name == 'resnet101':
            pretrained_dict = torch.hub.load_state_dict_from_url(
                'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet101_a1h-36d3f2aa.pth')
        if model_name == 'resnet50':
            pretrained_dict = torch.hub.load_state_dict_from_url(
                'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_a1_0-14fe96d1.pth')
        if pretrained_dict:
            # strict=False: the RSB checkpoints may omit/rename a few
            # keys relative to the timm-constructed module.
            self._model.load_state_dict(pretrained_dict, strict=False)
        if num_classes != 1000:
            self.create_classifier(num_classes=num_classes)
        self._model.eval()

    def __call__(self, img_tensor: torch.Tensor):
        # Use forward_features instead of a forward hook: the previous
        # hook-based extraction accumulated stored activations and
        # leaked memory across calls.
        features = self._model.forward_features(img_tensor)
        if features.dim() == 4:
            # Feature map is [N, C, H, W] with H, W > 1: pool to [N, C, 1, 1].
            global_pool = nn.AdaptiveAvgPool2d(1)
            features = global_pool(features)
        return features.flatten().detach().numpy()

    def create_classifier(self, num_classes):
        """Replace the final fully-connected layer with a fresh head of
        ``num_classes`` outputs (weights randomly initialized)."""
        self._model.fc = Linear(self._model.fc.in_features, num_classes, bias=True)

5
resnet_image_embedding.py

@ -31,7 +31,7 @@ class ResnetImageEmbedding(NNOperator):
"""
PyTorch model for image embedding.
"""
def __init__(self, model_name: str, framework: str = 'pytorch') -> None:
def __init__(self, model_name: str, num_classes: int = 1000, framework: str = 'pytorch') -> None:
super().__init__(framework=framework)
if framework == 'pytorch':
import importlib.util
@ -40,8 +40,7 @@ class ResnetImageEmbedding(NNOperator):
spec = importlib.util.spec_from_file_location(opname, path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
self.model = module.Model(model_name)
self.model = module.Model(model_name, num_classes=num_classes)
self.tfms = transforms.Compose([transforms.Resize(235, interpolation=InterpolationMode.BICUBIC),
transforms.CenterCrop(224),
transforms.ToTensor(),

53
resnet_training_yaml.yaml

@ -1,25 +1,42 @@
callback:
early_stopping:
mode: max
monitor: eval_epoch_metric
patience: 2
model_checkpoint:
every_n_epoch: 2
tensorboard:
comment: ''
log_dir: null
device:
device_str: null
n_gpu: -1
sync_bn: true
sync_bn: false
learning:
loss: CrossEntropyLoss
lr: 5.0e-05
lr_scheduler_type: linear
optimizer: Adam
warmup_ratio: 0.0
warmup_steps: 0
logging:
logging_dir: null
logging_strategy: steps
print_steps: null
save_strategy: steps
metrics:
metric: Accuracy
train:
batch_size: 32
batch_size: 16
dataloader_drop_last: false
dataloader_num_workers: 0
epoch_num: 16
eval_steps: null
eval_strategy: epoch
load_best_model_at_end: false
max_steps: -1
output_dir: ./output_dir
overwrite_output_dir: true
epoch_num: 2
learning:
optimizer:
name_: SGD
lr: 0.04
momentum: 0.001
loss:
name_: CrossEntropyLoss
ignore_index: -1
logging:
print_steps: 2
#learning:
# optimizer:
# name_: Adam
# lr: 0.02
# eps: 0.001
resume_from_checkpoint: null
seed: 42
val_batch_size: -1

77
test.py

@ -2,10 +2,13 @@ import numpy as np
from torch.optim import AdamW
from torchvision import transforms
from torchvision.transforms import RandomResizedCrop, Lambda
from towhee.data.dataset.dataset import dataset
from towhee.trainer.modelcard import ModelCard
from towhee.trainer.training_config import TrainingConfig
from towhee.trainer.dataset import get_dataset
# from towhee.trainer.dataset import get_dataset
from towhee.trainer.utils.layer_freezer import LayerFreezer
from resnet_image_embedding import ResnetImageEmbedding
from towhee.types import Image
from towhee.trainer.training_config import dump_default_yaml
@ -17,7 +20,7 @@ if __name__ == '__main__':
dump_default_yaml(yaml_path='default_config.yaml')
# img = torch.rand([1, 3, 224, 224])
img_path = './ILSVRC2012_val_00049771.JPEG'
# # logo_path = os.path.join(Path(__file__).parent.parent.parent.parent.resolve(), 'towhee_logo.png')
# logo_path = os.path.join(Path(__file__).parent.parent.parent.parent.resolve(), 'towhee_logo.png')
img = PILImage.open(img_path)
img_bytes = img.tobytes()
img_width = img.width
@ -28,10 +31,12 @@ if __name__ == '__main__':
array_size = np.array(img).shape
towhee_img = Image(img_bytes, img_width, img_height, img_channel, img_mode, img_array)
op = ResnetImageEmbedding('resnet34')
op = ResnetImageEmbedding('resnet50', num_classes=10)
# op.model_card = ModelCard(model_details="resnet test modelcard", training_data="use resnet test data")
# old_out = op(towhee_img)
# print(old_out.feature_vector[0])
old_out = op(towhee_img)
# print(old_out.feature_vector[0][:10])
print(old_out.feature_vector[:10])
# print(old_out.feature_vector.shape)
training_config = TrainingConfig()
yaml_path = 'resnet_training_yaml.yaml'
@ -46,39 +51,45 @@ if __name__ == '__main__':
# # device_str='cuda',
# # n_gpu=4
# )
#
mnist_transform = transforms.Compose([transforms.ToTensor(),
RandomResizedCrop(224),
Lambda(lambda x: x.repeat(3, 1, 1)),
transforms.Normalize(mean=[0.5], std=[0.5])])
train_data = get_dataset('mnist', transform=mnist_transform, download=True, root='data', train=True)
eval_data = get_dataset('mnist', transform=mnist_transform, download=True, root='data', train=False)
train_data = dataset('mnist', transform=mnist_transform, download=True, root='data', train=True)
eval_data = dataset('mnist', transform=mnist_transform, download=True, root='data', train=False)
# fake_transform = transforms.Compose([transforms.ToTensor(),
# RandomResizedCrop(224),])
# train_data = get_dataset('fake', size=20, transform=fake_transform)
op.change_before_train(10)
trainer = op.setup_trainer()
# my_optimimzer = AdamW(op.get_model().parameters(), lr=0.002, betas=(0.91, 0.98), eps=1e-08, weight_decay=0.01, amsgrad=False)
# op.setup_trainer()
# trainer.add_callback()
# trainer.set_optimizer()
# op.trainer.set_optimizer(my_optimimzer)
# trainer.configs.save_to_yaml('changed_optimizer_yaml.yaml')
# my_loss = nn.BCELoss()
# trainer.set_loss(my_loss, 'my_loss111')
# trainer.configs.save_to_yaml('chaned_loss_yaml.yaml')
# op.trainer._create_optimizer()
# op.trainer.set_optimizer()
#
# op.change_before_train(num_classes=10)
# # trainer = op.setup_trainer()
# print(op.get_model())
# # my_optimimzer = AdamW(op.get_model().parameters(), lr=0.002, betas=(0.91, 0.98), eps=1e-08, weight_decay=0.01, amsgrad=False)
# # op.setup_trainer()
#
# # trainer.add_callback()
# # trainer.set_optimizer()
#
# # op.trainer.set_optimizer(my_optimimzer)
# # trainer.configs.save_to_yaml('changed_optimizer_yaml.yaml')
#
# # my_loss = nn.BCELoss()
# # trainer.set_loss(my_loss, 'my_loss111')
# # trainer.configs.save_to_yaml('chaned_loss_yaml.yaml')
# # op.trainer._create_optimizer()
# # op.trainer.set_optimizer()
# # trainer = op.setup_trainer(training_config, train_dataset=train_data, eval_dataset=eval_data)
#
# # freezer = LayerFreezer(op.get_model())
# # freezer.by_idx([-1])
op.train(training_config, train_dataset=train_data, eval_dataset=eval_data)
# training_config.num_epoch = 3
# op.train(training_config, train_dataset=train_data, resume_checkpoint_path=training_config.output_dir + '/epoch_2')
# op.save('./test_save')
# op.load('./test_save')
# new_out = op(towhee_img)
# assert (new_out[0]!=old_out[0]).all()
# # op.trainer.run_train()
# # training_config.num_epoch = 3
# # op.train(training_config, train_dataset=train_data, resume_checkpoint_path=training_config.output_dir + '/epoch_2')
#
# # op.save('./test_save')
# # op.load('./test_save')
# # new_out = op(towhee_img)
#
# # assert (new_out[0]!=old_out[0]).all()

Loading…
Cancel
Save