logo
Browse Source

resnet op train from yaml config.

training
zhang chen 3 years ago
parent
commit
28cdc0f73a
  1. BIN
      ILSVRC2012_val_00049771.JPEG
  2. 66
      pytorch/embedding_extractor.py
  3. 32
      pytorch/model.py
  4. 14
      resnet_image_embedding.py
  5. 22
      resnet_training_yaml.yaml
  6. 65
      test.py

BIN
ILSVRC2012_val_00049771.JPEG

Binary file not shown.

After

Width:  |  Height:  |  Size: 92 KiB

66
pytorch/embedding_extractor.py

@ -0,0 +1,66 @@
# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pprint
class EmbeddingOutput:
    """
    Callable collector with the ``(module, module_in, module_out)`` hook signature.

    Each call stores the received ``module_out`` at the end of ``embeddings``,
    so the list holds the captured outputs in call order.
    """

    def __init__(self):
        # Captured outputs, oldest first.
        self.embeddings: list = []

    def __call__(self, module, module_in, module_out):
        # Only the output is recorded; the module and its input are ignored.
        # Nothing is returned.
        captured = self.embeddings
        captured.append(module_out)

    def clear(self):
        """
        Forget every captured output by rebinding ``embeddings`` to a new empty list.
        """
        self.embeddings = []
class EmbeddingExtractor:
    """
    Embedding extractor from a layer.

    Builds a name -> module lookup from ``model.named_modules`` and owns one
    ``EmbeddingOutput`` collector (``emb_out``) that is attached as a forward
    hook to the layers passed to ``register``.

    Args:
        model (`nn.Module`):
            Model used for inference.
    """

    def __init__(self, model):
        # Name -> module lookup used by `register` and `disp_modules`.
        self.modules_dict = dict(model.named_modules(remove_duplicate=False))
        # Single collector shared by every hook registered through this extractor.
        self.emb_out = EmbeddingOutput()

    def disp_modules(self, full=False):
        """
        Display the modules of the model.

        Args:
            full (`bool`):
                If False (default), pretty-print only the module names;
                otherwise pretty-print the whole name -> module mapping.
        """
        if full:
            pprint.pprint(self.modules_dict)
        else:
            pprint.pprint(list(self.modules_dict.keys()))

    def register(self, layer_name: str):
        """
        Registration for embedding extraction.

        Args:
            layer_name (`str`):
                Name of the layer from which the embedding is extracted.

        Returns:
            Whatever the layer's ``register_forward_hook`` returns. Keeping it
            lets the caller detach the hook later instead of leaving it
            attached for the lifetime of the model.

        Raises:
            ValueError: If ``layer_name`` is not one of the model's module names.
        """
        if layer_name not in self.modules_dict:
            raise ValueError('layer_name not in modules')
        layer = self.modules_dict[layer_name]
        # Hand the hook handle back to the caller rather than discarding it.
        return layer.register_forward_hook(self.emb_out)

32
pytorch/model.py

@ -18,6 +18,13 @@ from typing import NamedTuple
import numpy
import torch
import torchvision
from torch.nn import Linear
from timm.models.resnet import ResNet
# ResNet.
from pytorch.embedding_extractor import EmbeddingExtractor
# TODO: switch to the EmbeddingExtractor under towhee.models.embedding later; it is currently available on the origin main branch but not on the train branch.
class Model():
@ -38,14 +45,25 @@ class Model():
if state_dict:
self._model.load_state_dict(state_dict)
self._model.fc = torch.nn.Identity()
# self._model.fc = torch.nn.Identity()
self._model.eval()
self.ex = EmbeddingExtractor(self._model)
# self.ex.disp_modules(full=True)
self.ex.register('avgpool')
def __call__(self, img_tensor: torch.Tensor):
# Run the wrapped model on img_tensor and return an embedding for it.
# NOTE(review): this span comes from a diff view without +/- markers; the return
# directly below looks like the pre-change line shown next to its replacement —
# confirm against the real file. Read literally, it returns before the
# hook-based lines that follow can run.
return self._model(img_tensor).flatten().detach().numpy()
# Drop outputs captured by earlier calls so that index 0 below belongs to this pass.
self.ex.emb_out.clear()
# The forward pass triggers the hook registered on 'avgpool', which appends that
# layer's output to emb_out.embeddings; the model's own return value is discarded.
self._model(img_tensor)
# return self.fc_input[0]
# Returns the captured output as stored by the hook (no flatten/detach/numpy
# conversion, unlike the return at the top of this span).
return self.ex.emb_out.embeddings[0]
# return self._model(img_tensor).flatten().detach().numpy() #todo
def create_classifier(self, num_classes):
    """
    Replace the model's ``fc`` layer with a new ``Linear`` classification head.

    The new layer keeps the input width of the current ``fc`` layer, has
    ``num_classes`` outputs and uses a bias term.

    Args:
        num_classes:
            Number of output features of the new ``fc`` layer.
    """
    backbone = self._model
    in_features = backbone.fc.in_features
    backbone.fc = Linear(in_features, num_classes, bias=True)
# self._model.classifier.register_forward_hook(self._forward_hook)
def train(self):
    """
    Placeholder for training the model; it currently performs no work and returns None.
    """
    return None
# def train(self):
# """
# For training model
# """
# pass

14
resnet_image_embedding.py

@ -21,12 +21,12 @@ from pathlib import Path
from typing import NamedTuple
import os
from torchvision.transforms import InterpolationMode
from towhee.operator import Operator
from towhee.operator import NNOperator
from towhee.utils.pil_utils import to_pil
import warnings
warnings.filterwarnings("ignore")
class ResnetImageEmbedding(Operator):
class ResnetImageEmbedding(NNOperator):
"""
PyTorch model for image embedding.
"""
@ -50,3 +50,13 @@ class ResnetImageEmbedding(Operator):
embedding = self.model(img)
Outputs = NamedTuple('Outputs', [('feature_vector', numpy.ndarray)])
return Outputs(embedding)
def get_model(self):
    """
    Return the ``_model`` attribute of the wrapper stored in ``self.model``.
    """
    wrapper = self.model
    return wrapper._model
# def test(self):
# return self.framework
def change_before_train(self, num_classes: int = 0):
    """
    Prepare the operator's model for training.

    Args:
        num_classes (`int`):
            When greater than 0, ``self.model.create_classifier`` is called
            with this value; otherwise the model is left untouched.
    """
    if not num_classes > 0:
        return
    self.model.create_classifier(num_classes)

22
resnet_training_yaml.yaml

@ -0,0 +1,22 @@
# Training configuration; test.py loads this file by name through
# TrainingConfig.load_from_yaml.
# NOTE(review): the nesting indentation is not visible in this view, so which
# keys belong under which section is assumed from key order — confirm against
# the real file before relying on the grouping.
device:
device_str: null
n_gpu: -1
sync_bn: true
metrics:
metric: Accuracy
train:
batch_size: 16
learning:
optimizer:
name_: SGD
lr: 0.03
momentum: 0.001
# NOTE(review): torch.optim.SGD documents `nesterov` as a boolean flag; 111 is
# merely truthy — confirm whether `true` was intended.
nesterov: 111
loss:
name_: CrossEntropyLoss
label_smoothing: 0.1
#learning:
# optimizer:
# name_: Adam
# lr: 0.02
# eps: 0.001

65
test.py

@ -0,0 +1,65 @@
import numpy as np
from torchvision import transforms
from torchvision.transforms import RandomResizedCrop, Lambda
from towhee.trainer.modelcard import ModelCard
from towhee.trainer.training_config import TrainingConfig
from towhee.trainer.dataset import get_dataset
from resnet_image_embedding import ResnetImageEmbedding
from towhee.types import Image
from towhee.trainer.training_config import dump_default_yaml
from PIL import Image as PILImage
from timm.models.resnet import ResNet
if __name__ == '__main__':
    # Manual smoke test: build a towhee Image from a sample JPEG, create the
    # resnet34 operator, load the training configuration from YAML and train
    # the operator on MNIST.
    img_path = './ILSVRC2012_val_00049771.JPEG'
    # Open through a context manager so the image file is closed as soon as
    # everything needed has been read from it.
    with PILImage.open(img_path) as img:
        img_bytes = img.tobytes()
        img_width = img.width
        img_height = img.height
        img_channel = len(img.split())
        img_mode = img.mode
        img_array = np.array(img)
    towhee_img = Image(img_bytes, img_width, img_height, img_channel, img_mode, img_array)

    op = ResnetImageEmbedding('resnet34')
    op.model_card = ModelCard(model_details="resnet test modelcard", training_data="use resnet test data")
    # Optional inference check before training:
    # old_out = op(towhee_img)
    # print(old_out.feature_vector[0])

    training_config = TrainingConfig()
    yaml_path = 'resnet_training_yaml.yaml'
    # To regenerate a default config file first:
    # dump_default_yaml(yaml_path=yaml_path)
    training_config.load_from_yaml(yaml_path)

    mnist_transform = transforms.Compose([
        transforms.ToTensor(),
        RandomResizedCrop(224),
        # Repeat along the first axis three times — presumably to turn
        # single-channel MNIST tensors into 3-channel input; confirm.
        Lambda(lambda x: x.repeat(3, 1, 1)),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ])
    train_data = get_dataset('mnist', transform=mnist_transform, download=True, root='data')

    # Alternative dataset for a quicker run:
    fake_transform = transforms.Compose([
        transforms.ToTensor(),
        RandomResizedCrop(224),
    ])
    # train_data = get_dataset('fake', size=20, transform=fake_transform)

    # Give the model a 10-class head before training.
    op.change_before_train(10)
    op.train(training_config, train_dataset=train_data)
Loading…
Cancel
Save