|
|
|
import numpy as np
|
|
|
|
from torch.optim import AdamW
|
|
|
|
from torchvision import transforms
|
|
|
|
from torchvision.transforms import RandomResizedCrop, Lambda
|
|
|
|
from towhee.data.dataset.dataset import dataset
|
|
|
|
from towhee.trainer.modelcard import ModelCard
|
|
|
|
|
|
|
|
from towhee.trainer.training_config import TrainingConfig
|
|
|
|
# from towhee.trainer.dataset import get_dataset
|
|
|
|
from towhee.trainer.utils.layer_freezer import LayerFreezer
|
|
|
|
|
|
|
|
from resnet_image_embedding import ResnetImageEmbedding
|
|
|
|
from towhee.types import Image
|
|
|
|
from towhee.trainer.training_config import dump_default_yaml
|
|
|
|
from PIL import Image as PILImage
|
|
|
|
from timm.models.resnet import ResNet
|
|
|
|
from torch import nn
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Manual smoke test for ResnetImageEmbedding:
    #   1. run the operator on a single image,
    #   2. train it on a small fake dataset (resuming from a saved checkpoint),
    #   3. run it on the same image again and compare the two outputs.

    # Write the default training configuration to disk for reference.
    dump_default_yaml(yaml_path='default_config.yaml')

    # img = torch.rand([1, 3, 224, 224])
    img_path = './ILSVRC2012_val_00049771.JPEG'
    # logo_path = os.path.join(Path(__file__).parent.parent.parent.parent.resolve(), 'towhee_logo.png')

    # Open the image inside a context manager so the underlying file is
    # released as soon as everything needed has been extracted from it.
    with PILImage.open(img_path) as img:
        img_bytes = img.tobytes()
        img_width = img.width
        img_height = img.height
        img_channel = len(img.split())
        img_mode = img.mode
        # Converted once; the original converted twice only to read an unused shape.
        img_array = np.array(img)

    towhee_img = Image(img_bytes, img_width, img_height, img_channel, img_mode, img_array)

    op = ResnetImageEmbedding('resnet50', num_classes=10)

    # Output before training; compared against the post-training output below.
    old_out = op(towhee_img)
    # print(old_out.feature_vector[0][:10])
    # print(old_out.feature_vector[:10])
    # print(old_out.feature_vector.shape)

    # Training settings are read from an existing YAML file.
    training_config = TrainingConfig()
    yaml_path = 'resnet_training_yaml.yaml'
    # dump_default_yaml(yaml_path=yaml_path)
    training_config.load_from_yaml(yaml_path)

    # Optional in-code overrides of the loaded configuration:
    # training_config.overwrite_output_dir=True
    # training_config.epoch_num=3
    # training_config.batch_size=256
    # training_config.device_str='cpu'
    # training_config.n_gpu=-1
    # training_config.save_to_yaml(yaml_path)

    # Alternative: train on MNIST instead of the fake dataset.
    # mnist_transform = transforms.Compose([transforms.ToTensor(),
    #                                       RandomResizedCrop(224),
    #                                       Lambda(lambda x: x.repeat(3, 1, 1)),
    #                                       transforms.Normalize(mean=[0.5], std=[0.5])])
    # train_data = dataset('mnist', transform=mnist_transform, download=True, root='data', train=True)
    # eval_data = dataset('mnist', transform=mnist_transform, download=True, root='data', train=False)
    # training_config.output_dir = 'mnist_output'
    # op.model_card = ModelCard(datasets='mnist dataset')

    # Small fake dataset so the script runs quickly without downloading data.
    fake_transform = transforms.Compose([transforms.ToTensor(),
                                         RandomResizedCrop(224)])
    train_data = dataset('fake', size=20, transform=fake_transform)
    eval_data = dataset('fake', size=10, transform=fake_transform)
    training_config.output_dir = 'fake_output'
    op.model_card = ModelCard(datasets='fake dataset')

    # Alternative: build the trainer by hand and swap in a custom optimizer/loss.
    # trainer = op.setup_trainer()
    # print(op.get_model())
    # my_optimimzer = AdamW(op.get_model().parameters(), lr=0.002, betas=(0.91, 0.98), eps=1e-08, weight_decay=0.01, amsgrad=False)
    # op.setup_trainer()
    #
    # trainer.add_callback()
    # trainer.set_optimizer()
    #
    # op.trainer.set_optimizer(my_optimimzer)
    # trainer.configs.save_to_yaml('changed_optimizer_yaml.yaml')
    #
    # my_loss = nn.BCELoss()
    # trainer.set_loss(my_loss, 'my_loss111')
    # trainer.configs.save_to_yaml('chaned_loss_yaml.yaml')
    # op.trainer._create_optimizer()
    # op.trainer.set_optimizer()
    # trainer = op.setup_trainer(training_config, train_dataset=train_data, eval_dataset=eval_data)

    # NOTE(review): presumably freezes every layer except the last one --
    # confirm against the LayerFreezer.set_slice documentation.
    freezer = LayerFreezer(op.get_model())
    freezer.set_slice(-1)

    # Resume from the 'epoch_2' checkpoint under the output directory; that
    # checkpoint must already exist from an earlier run.
    # op.train(training_config, train_dataset=train_data, eval_dataset=eval_data)
    op.train(training_config, train_dataset=train_data, eval_dataset=eval_data,
             resume_checkpoint_path=training_config.output_dir + '/epoch_2')

    # op.save('./test_save')
    # op.load('./test_save')

    new_out = op(towhee_img)
    # NOTE(review): this requires the feature vector to be identical before and
    # after training; confirm that is the intended expectation for this setup.
    assert (new_out.feature_vector == old_out.feature_vector).all()
|