
resnet op train from yaml config.

Branch: training
Author: zhang chen (3 years ago)
Commit: 28cdc0f73a
Changed files:
  1. ILSVRC2012_val_00049771.JPEG  (BIN)
  2. pytorch/embedding_extractor.py  (66)
  3. pytorch/model.py  (32)
  4. resnet_image_embedding.py  (14)
  5. resnet_training_yaml.yaml  (22)
  6. test.py  (65)

BIN
ILSVRC2012_val_00049771.JPEG

Binary file not shown (image, 92 KiB).

66
pytorch/embedding_extractor.py

@@ -0,0 +1,66 @@
# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pprint


class EmbeddingOutput:
    """
    Container for embedding extractor.
    """
    def __init__(self):
        self.embeddings = []

    def __call__(self, module, module_in, module_out):
        self.embeddings.append(module_out)

    def clear(self):
        """
        Clear the embedding list.
        """
        self.embeddings = []


class EmbeddingExtractor:
    """
    Embedding extractor from a layer.

    Args:
        model (`nn.Module`):
            Model used for inference.
    """
    def __init__(self, model):
        # self.modules = model.modules()
        # self.modules_list = list(model.named_modules(remove_duplicate=False))
        self.modules_dict = dict(model.named_modules(remove_duplicate=False))
        self.emb_out = EmbeddingOutput()

    def disp_modules(self, full=False):
        """
        Display the modules of the model.
        """
        if not full:
            pprint.pprint(list(self.modules_dict.keys()))
        else:
            pprint.pprint(self.modules_dict)

    def register(self, layer_name: str):
        """
        Register a forward hook for embedding extraction.

        Args:
            layer_name (`str`):
                Name of the layer from which the embedding is extracted.
        """
        if layer_name in self.modules_dict:
            layer = self.modules_dict[layer_name]
            layer.register_forward_hook(self.emb_out)
        else:
            raise ValueError('layer_name not in modules')
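
A minimal usage sketch (not part of the commit, assuming a plain torchvision resnet34): register the 'avgpool' layer and run one forward pass; the hook stores the pooled feature tensor that model.py below returns as the embedding.

import torch
import torchvision
from pytorch.embedding_extractor import EmbeddingExtractor

net = torchvision.models.resnet34(pretrained=False)
ex = EmbeddingExtractor(net)
ex.disp_modules()                       # list layer names, e.g. 'avgpool', 'fc'
ex.register('avgpool')                  # forward hook appends the layer output
net.eval()
with torch.no_grad():
    net(torch.rand(1, 3, 224, 224))
print(ex.emb_out.embeddings[0].shape)   # torch.Size([1, 512, 1, 1])
ex.emb_out.clear()                      # reset before the next image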

32
pytorch/model.py

@@ -18,6 +18,13 @@ from typing import NamedTuple
 import numpy
 import torch
 import torchvision
+from torch.nn import Linear
+from timm.models.resnet import ResNet
+# ResNet.
+from pytorch.embedding_extractor import EmbeddingExtractor
+# todo: later switch to the EmbeddingExtractor under towhee.models.embedding; it is currently usable on the origin main branch but not yet on the train branch
 class Model():
@@ -38,14 +45,25 @@ class Model():
         if state_dict:
             self._model.load_state_dict(state_dict)
-        self._model.fc = torch.nn.Identity()
+        # self._model.fc = torch.nn.Identity()
         self._model.eval()
+        self.ex = EmbeddingExtractor(self._model)
+        # self.ex.disp_modules(full=True)
+        self.ex.register('avgpool')

     def __call__(self, img_tensor: torch.Tensor):
-        return self._model(img_tensor).flatten().detach().numpy()
+        self.ex.emb_out.clear()
+        self._model(img_tensor)
+        # return self.fc_input[0]
+        return self.ex.emb_out.embeddings[0]
+        # return self._model(img_tensor).flatten().detach().numpy()  # todo

-    def train(self):
-        """
-        For training model
-        """
-        pass
+    def create_classifier(self, num_classes):
+        self._model.fc = Linear(self._model.fc.in_features, num_classes, bias=True)
+        # self._model.classifier.register_forward_hook(self._forward_hook)
+    # def train(self):
+    #     """
+    #     For training model
+    #     """
+    #     pass
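
What create_classifier changes, shown as a standalone sketch (assuming the wrapped model is a torchvision resnet34, whose final fc layer has 512 input features): the embedding head is swapped for a fresh classification layer before fine-tuning.

import torchvision
from torch.nn import Linear

net = torchvision.models.resnet34(pretrained=False)
net.fc = Linear(net.fc.in_features, 10, bias=True)   # same effect as create_classifier(10)
print(net.fc)                                        # Linear(in_features=512, out_features=10, bias=True)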

14
resnet_image_embedding.py

@@ -21,12 +21,12 @@ from pathlib import Path
 from typing import NamedTuple
 import os
 from torchvision.transforms import InterpolationMode
-from towhee.operator import Operator
+from towhee.operator import NNOperator
 from towhee.utils.pil_utils import to_pil
 import warnings
 warnings.filterwarnings("ignore")

-class ResnetImageEmbedding(Operator):
+class ResnetImageEmbedding(NNOperator):
     """
     PyTorch model for image embedding.
     """
@@ -50,3 +50,13 @@ class ResnetImageEmbedding(Operator):
         embedding = self.model(img)
         Outputs = NamedTuple('Outputs', [('feature_vector', numpy.ndarray)])
         return Outputs(embedding)
+
+    def get_model(self):
+        return self.model._model
+
+    # def test(self):
+    #     return self.framework
+
+    def change_before_train(self, num_classes: int = 0):
+        if num_classes > 0:
+            self.model.create_classifier(num_classes)

22
resnet_training_yaml.yaml

@@ -0,0 +1,22 @@
device:
  device_str: null
  n_gpu: -1
  sync_bn: true
metrics:
  metric: Accuracy
train:
  batch_size: 16
learning:
  optimizer:
    name_: SGD
    lr: 0.03
    momentum: 0.001
    nesterov: 111
  loss:
    name_: CrossEntropyLoss
    label_smoothing: 0.1
#learning:
#  optimizer:
#    name_: Adam
#    lr: 0.02
#    eps: 0.001
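
A rough illustration (not from the commit, and assuming the nesting shown above) of how the learning section maps onto plain torch objects; towhee's Trainer is assumed to resolve the `name_` field to the corresponding class internally, so this only sanity-checks the values with PyYAML.

import yaml
import torch

with open('resnet_training_yaml.yaml', encoding='utf-8') as f:
    cfg = yaml.safe_load(f)

opt_cfg = dict(cfg['learning']['optimizer'])
opt_cls = getattr(torch.optim, opt_cfg.pop('name_'))      # torch.optim.SGD
loss_cfg = dict(cfg['learning']['loss'])
loss_cls = getattr(torch.nn, loss_cfg.pop('name_'))       # torch.nn.CrossEntropyLoss

head = torch.nn.Linear(512, 10)                           # stand-in for the new fc head
optimizer = opt_cls(head.parameters(), **opt_cfg)
criterion = loss_cls(**loss_cfg)
print(optimizer)
print(criterion)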

65
test.py

@@ -0,0 +1,65 @@
import numpy as np
from torchvision import transforms
from torchvision.transforms import RandomResizedCrop, Lambda
from towhee.trainer.modelcard import ModelCard
from towhee.trainer.training_config import TrainingConfig
from towhee.trainer.dataset import get_dataset
from resnet_image_embedding import ResnetImageEmbedding
from towhee.types import Image
from towhee.trainer.training_config import dump_default_yaml
from PIL import Image as PILImage
from timm.models.resnet import ResNet

if __name__ == '__main__':
    # img = torch.rand([1, 3, 224, 224])
    img_path = './ILSVRC2012_val_00049771.JPEG'
    # # logo_path = os.path.join(Path(__file__).parent.parent.parent.parent.resolve(), 'towhee_logo.png')
    img = PILImage.open(img_path)
    img_bytes = img.tobytes()
    img_width = img.width
    img_height = img.height
    img_channel = len(img.split())
    img_mode = img.mode
    img_array = np.array(img)
    array_size = np.array(img).shape
    towhee_img = Image(img_bytes, img_width, img_height, img_channel, img_mode, img_array)

    op = ResnetImageEmbedding('resnet34')
    op.model_card = ModelCard(model_details="resnet test modelcard", training_data="use resnet test data")
    # old_out = op(towhee_img)
    # print(old_out.feature_vector[0])

    training_config = TrainingConfig()
    yaml_path = 'resnet_training_yaml.yaml'
    # dump_default_yaml(yaml_path=yaml_path)
    training_config.load_from_yaml(yaml_path)
    #     output_dir='./temp_output',
    #     overwrite_output_dir=True,
    #     epoch_num=2,
    #     per_gpu_train_batch_size=16,
    #     prediction_loss_only=True,
    #     metric='Accuracy'
    #     # device_str='cuda',
    #     # n_gpu=4
    # )

    mnist_transform = transforms.Compose([transforms.ToTensor(),
                                          RandomResizedCrop(224),
                                          Lambda(lambda x: x.repeat(3, 1, 1)),
                                          transforms.Normalize(mean=[0.5], std=[0.5])])
    train_data = get_dataset('mnist', transform=mnist_transform, download=True, root='data')
    fake_transform = transforms.Compose([transforms.ToTensor(),
                                         RandomResizedCrop(224)])
    # train_data = get_dataset('fake', size=20, transform=fake_transform)

    op.change_before_train(10)
    op.train(training_config, train_dataset=train_data)
    # e.save('./test_save')
    # e.load('./test_save')
    # new_out = e(img)
    # assert (new_out[0] != old_out[0]).all()
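
A small sanity check (not in the commit) for the mnist_transform above: the Lambda repeat turns a single-channel MNIST-sized image into the 3-channel 224x224 tensor the resnet backbone expects.

from PIL import Image as PILImage
from torchvision import transforms
from torchvision.transforms import RandomResizedCrop, Lambda

mnist_transform = transforms.Compose([transforms.ToTensor(),
                                      RandomResizedCrop(224),
                                      Lambda(lambda x: x.repeat(3, 1, 1)),
                                      transforms.Normalize(mean=[0.5], std=[0.5])])

digit = PILImage.new('L', (28, 28))      # stand-in for one 28x28 grayscale MNIST digit
print(mnist_transform(digit).shape)      # torch.Size([3, 224, 224])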