"""
Example: fine-tune the ResnetImageEmbedding operator with the towhee trainer,
freezing every layer except the last one, then check that the embedding output
is unchanged after training.
"""
import numpy as np
from torch import nn
from torch.optim import AdamW
from torchvision import transforms
from torchvision.transforms import RandomResizedCrop, Lambda
from towhee.data.dataset.dataset import dataset
from towhee.trainer.modelcard import ModelCard
from towhee.trainer.training_config import TrainingConfig, dump_default_yaml
# from towhee.trainer.dataset import get_dataset
from towhee.trainer.utils.layer_freezer import LayerFreezer
from towhee.types import Image
from timm.models.resnet import ResNet
from PIL import Image as PILImage

from resnet_image_embedding import ResnetImageEmbedding

if __name__ == '__main__':
    # Write the default training configuration template to a YAML file.
    dump_default_yaml(yaml_path='default_config.yaml')

    # Build a towhee Image from a local JPEG so the operator can consume it.
    # img = torch.rand([1, 3, 224, 224])
    img_path = './ILSVRC2012_val_00049771.JPEG'
    # logo_path = os.path.join(Path(__file__).parent.parent.parent.parent.resolve(), 'towhee_logo.png')
    img = PILImage.open(img_path)
    img_bytes = img.tobytes()
    img_width = img.width
    img_height = img.height
    img_channel = len(img.split())
    img_mode = img.mode
    img_array = np.array(img)
    array_size = img_array.shape
    towhee_img = Image(img_bytes, img_width, img_height, img_channel, img_mode, img_array)

    # Create the operator and record its embedding before training.
    op = ResnetImageEmbedding('resnet50', num_classes=10)
    old_out = op(towhee_img)
    # print(old_out.feature_vector[0][:10])
    # print(old_out.feature_vector[:10])
    # print(old_out.feature_vector.shape)

    # Load the training configuration from a YAML file.
    training_config = TrainingConfig()
    yaml_path = 'resnet_training_yaml.yaml'
    # dump_default_yaml(yaml_path=yaml_path)
    training_config.load_from_yaml(yaml_path)
    # The same settings can also be changed in code and written back:
    # training_config.overwrite_output_dir = True
    # training_config.epoch_num = 3
    # training_config.batch_size = 256
    # training_config.device_str = 'cpu'
    # training_config.n_gpu = -1
    # training_config.save_to_yaml(yaml_path)

    # Alternative: train on MNIST instead of the fake dataset.
    # mnist_transform = transforms.Compose([transforms.ToTensor(),
    #                                       RandomResizedCrop(224),
    #                                       Lambda(lambda x: x.repeat(3, 1, 1)),
    #                                       transforms.Normalize(mean=[0.5], std=[0.5])])
    # train_data = dataset('mnist', transform=mnist_transform, download=True, root='data', train=True)
    # eval_data = dataset('mnist', transform=mnist_transform, download=True, root='data', train=False)
    # training_config.output_dir = 'mnist_output'
    # op.model_card = ModelCard(datasets='mnist dataset')

    # Use a small fake dataset so the example runs quickly.
    fake_transform = transforms.Compose([transforms.ToTensor(),
                                         RandomResizedCrop(224)])
    train_data = dataset('fake', size=20, transform=fake_transform)
    eval_data = dataset('fake', size=10, transform=fake_transform)
    training_config.output_dir = 'fake_output'
    op.model_card = ModelCard(datasets='fake dataset')

    # Optional: set up the trainer explicitly to swap in a custom optimizer or loss.
    # trainer = op.setup_trainer(training_config, train_dataset=train_data, eval_dataset=eval_data)
    # print(op.get_model())
    # my_optimizer = AdamW(op.get_model().parameters(), lr=0.002, betas=(0.91, 0.98),
    #                      eps=1e-08, weight_decay=0.01, amsgrad=False)
    # trainer.set_optimizer(my_optimizer)
    # trainer.configs.save_to_yaml('changed_optimizer_yaml.yaml')
    # my_loss = nn.BCELoss()
    # trainer.set_loss(my_loss, 'my_loss111')
    # trainer.configs.save_to_yaml('changed_loss_yaml.yaml')

    # Freeze every layer except the last one, so only the head is fine-tuned.
    freezer = LayerFreezer(op.get_model())
    freezer.set_slice(-1)

    # Resume training from a checkpoint written by an earlier run
    # (training_config.output_dir + '/epoch_2' must already exist).
    # op.train(training_config, train_dataset=train_data, eval_dataset=eval_data)
    op.train(training_config, train_dataset=train_data, eval_dataset=eval_data,
             resume_checkpoint_path=training_config.output_dir + '/epoch_2')
    # op.save('./test_save')
    # op.load('./test_save')

    new_out = op(towhee_img)
    # With all layers except the last frozen, the backbone embedding should be
    # unchanged by training.
    assert (new_out.feature_vector == old_out.feature_vector).all()