logo
Browse Source

add custom optimizer and loss case

training
zhang chen 3 years ago
parent
commit
51a17a4bb0
  1. 9
      resnet_image_embedding.py
  2. 9
      resnet_training_yaml.yaml
  3. 22
      test.py

9
resnet_image_embedding.py

@ -16,6 +16,7 @@ import numpy
import torch import torch
import torchvision import torchvision
from PIL import Image from PIL import Image
from torch import nn as nn
from torchvision import transforms from torchvision import transforms
from pathlib import Path from pathlib import Path
from typing import NamedTuple from typing import NamedTuple
@ -31,7 +32,7 @@ class ResnetImageEmbedding(NNOperator):
PyTorch model for image embedding. PyTorch model for image embedding.
""" """
def __init__(self, model_name: str, framework: str = 'pytorch') -> None: def __init__(self, model_name: str, framework: str = 'pytorch') -> None:
super().__init__()
super().__init__(framework=framework)
if framework == 'pytorch': if framework == 'pytorch':
import importlib.util import importlib.util
path = os.path.join(str(Path(__file__).parent), 'pytorch', 'model.py') path = os.path.join(str(Path(__file__).parent), 'pytorch', 'model.py')
@ -40,6 +41,7 @@ class ResnetImageEmbedding(NNOperator):
module = importlib.util.module_from_spec(spec) module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module) spec.loader.exec_module(module)
self.model = module.Model(model_name) self.model = module.Model(model_name)
self.tfms = transforms.Compose([transforms.Resize(235, interpolation=InterpolationMode.BICUBIC), self.tfms = transforms.Compose([transforms.Resize(235, interpolation=InterpolationMode.BICUBIC),
transforms.CenterCrop(224), transforms.CenterCrop(224),
transforms.ToTensor(), transforms.ToTensor(),
@ -51,12 +53,9 @@ class ResnetImageEmbedding(NNOperator):
Outputs = NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]) Outputs = NamedTuple('Outputs', [('feature_vector', numpy.ndarray)])
return Outputs(embedding) return Outputs(embedding)
def get_model(self):
def get_model(self) -> nn.Module:
return self.model._model return self.model._model
# def test(self):
# return self.framework
def change_before_train(self, num_classes: int = 0): def change_before_train(self, num_classes: int = 0):
if num_classes > 0: if num_classes > 0:
self.model.create_classifier(num_classes) self.model.create_classifier(num_classes)

9
resnet_training_yaml.yaml

@ -5,16 +5,17 @@ device:
metrics: metrics:
metric: Accuracy metric: Accuracy
train: train:
batch_size: 16
batch_size: 32
overwrite_output_dir: true
epoch_num: 1
learning: learning:
optimizer: optimizer:
name_: SGD name_: SGD
lr: 0.03
lr: 0.04
momentum: 0.001 momentum: 0.001
nesterov: 111
loss: loss:
name_: CrossEntropyLoss name_: CrossEntropyLoss
label_smoothing: 0.1
ignore_index: -1
#learning: #learning:
# optimizer: # optimizer:
# name_: Adam # name_: Adam

22
test.py

@ -1,4 +1,5 @@
import numpy as np import numpy as np
from torch.optim import AdamW
from torchvision import transforms from torchvision import transforms
from torchvision.transforms import RandomResizedCrop, Lambda from torchvision.transforms import RandomResizedCrop, Lambda
from towhee.trainer.modelcard import ModelCard from towhee.trainer.modelcard import ModelCard
@ -10,8 +11,10 @@ from towhee.types import Image
from towhee.trainer.training_config import dump_default_yaml from towhee.trainer.training_config import dump_default_yaml
from PIL import Image as PILImage from PIL import Image as PILImage
from timm.models.resnet import ResNet from timm.models.resnet import ResNet
from torch import nn
if __name__ == '__main__': if __name__ == '__main__':
dump_default_yaml(yaml_path='default_config.yaml')
# img = torch.rand([1, 3, 224, 224]) # img = torch.rand([1, 3, 224, 224])
img_path = './ILSVRC2012_val_00049771.JPEG' img_path = './ILSVRC2012_val_00049771.JPEG'
# # logo_path = os.path.join(Path(__file__).parent.parent.parent.parent.resolve(), 'towhee_logo.png') # # logo_path = os.path.join(Path(__file__).parent.parent.parent.parent.resolve(), 'towhee_logo.png')
@ -28,7 +31,6 @@ if __name__ == '__main__':
op = ResnetImageEmbedding('resnet34') op = ResnetImageEmbedding('resnet34')
op.model_card = ModelCard(model_details="resnet test modelcard", training_data="use resnet test data") op.model_card = ModelCard(model_details="resnet test modelcard", training_data="use resnet test data")
# old_out = op(towhee_img) # old_out = op(towhee_img)
# print(old_out.feature_vector[0]) # print(old_out.feature_vector[0])
training_config = TrainingConfig() training_config = TrainingConfig()
@ -45,7 +47,6 @@ if __name__ == '__main__':
# # n_gpu=4 # # n_gpu=4
# ) # )
mnist_transform = transforms.Compose([transforms.ToTensor(), mnist_transform = transforms.Compose([transforms.ToTensor(),
RandomResizedCrop(224), RandomResizedCrop(224),
Lambda(lambda x: x.repeat(3, 1, 1)), Lambda(lambda x: x.repeat(3, 1, 1)),
@ -57,6 +58,23 @@ if __name__ == '__main__':
# train_data = get_dataset('fake', size=20, transform=fake_transform) # train_data = get_dataset('fake', size=20, transform=fake_transform)
op.change_before_train(10) op.change_before_train(10)
trainer = op.setup_trainer()
my_optimimzer = AdamW(op.get_model().parameters(), lr=0.002, betas=(0.91, 0.98), eps=1e-08, weight_decay=0.01, amsgrad=False)
# op.setup_trainer()
# trainer.add_callback()
# trainer.set_optimizer()
op.trainer.set_optimizer(my_optimimzer)
trainer.configs.save_to_yaml('changed_optimizer_yaml.yaml')
# my_loss = nn.BCELoss()
# trainer.set_loss(my_loss, 'my_loss111')
# trainer.configs.save_to_yaml('chaned_loss_yaml.yaml')
# op.trainer._create_optimizer()
# op.trainer.set_optimizer()
op.train(training_config, train_dataset=train_data) op.train(training_config, train_dataset=train_data)
training_config.num_epoch = 3 training_config.num_epoch = 3
op.train(training_config, train_dataset=train_data, resume_checkpoint_path=training_config.output_dir + '/epoch_2') op.train(training_config, train_dataset=train_data, resume_checkpoint_path=training_config.output_dir + '/epoch_2')

Loading…
Cancel
Save