From f1346e4f45cbd37780a745e93c2ed63e051f0c38 Mon Sep 17 00:00:00 2001
From: zhang chen
Date: Thu, 17 Feb 2022 11:37:04 +0800
Subject: [PATCH] add eval ability

---
 resnet_training_yaml.yaml |  4 +++-
 test.py                   | 20 +++++++++-----------
 2 files changed, 12 insertions(+), 12 deletions(-)

diff --git a/resnet_training_yaml.yaml b/resnet_training_yaml.yaml
index 1ae4559..34129b5 100644
--- a/resnet_training_yaml.yaml
+++ b/resnet_training_yaml.yaml
@@ -7,7 +7,7 @@ metrics:
 train:
   batch_size: 32
   overwrite_output_dir: true
-  epoch_num: 1
+  epoch_num: 2
 learning:
   optimizer:
     name_: SGD
@@ -16,6 +16,8 @@ learning:
   loss:
     name_: CrossEntropyLoss
     ignore_index: -1
+logging:
+  print_steps: 2
 #learning:
 #  optimizer:
 #    name_: Adam
diff --git a/test.py b/test.py
index 1650e5e..23d52ca 100644
--- a/test.py
+++ b/test.py
@@ -29,7 +29,7 @@ if __name__ == '__main__':
     towhee_img = Image(img_bytes, img_width, img_height, img_channel, img_mode, img_array)
 
     op = ResnetImageEmbedding('resnet34')
-    op.model_card = ModelCard(model_details="resnet test modelcard", training_data="use resnet test data")
+    # op.model_card = ModelCard(model_details="resnet test modelcard", training_data="use resnet test data")
 
     # old_out = op(towhee_img)
     # print(old_out.feature_vector[0])
@@ -51,33 +51,31 @@ if __name__ == '__main__':
         RandomResizedCrop(224),
         Lambda(lambda x: x.repeat(3, 1, 1)),
         transforms.Normalize(mean=[0.5], std=[0.5])])
-    train_data = get_dataset('mnist', transform=mnist_transform, download=True, root='data')
-
+    train_data = get_dataset('mnist', transform=mnist_transform, download=True, root='data', train=True)
+    eval_data = get_dataset('mnist', transform=mnist_transform, download=True, root='data', train=False)
     # fake_transform = transforms.Compose([transforms.ToTensor(),
     #                                      RandomResizedCrop(224),])
     # train_data = get_dataset('fake', size=20, transform=fake_transform)
 
     op.change_before_train(10)
     trainer = op.setup_trainer()
-
-
-    my_optimimzer = AdamW(op.get_model().parameters(), lr=0.002, betas=(0.91, 0.98), eps=1e-08, weight_decay=0.01, amsgrad=False)
+    # my_optimimzer = AdamW(op.get_model().parameters(), lr=0.002, betas=(0.91, 0.98), eps=1e-08, weight_decay=0.01, amsgrad=False)
     # op.setup_trainer()
     # trainer.add_callback()
     # trainer.set_optimizer()
-    op.trainer.set_optimizer(my_optimimzer)
-    trainer.configs.save_to_yaml('changed_optimizer_yaml.yaml')
+    # op.trainer.set_optimizer(my_optimimzer)
+    # trainer.configs.save_to_yaml('changed_optimizer_yaml.yaml')
 
     # my_loss = nn.BCELoss()
     # trainer.set_loss(my_loss, 'my_loss111')
     # trainer.configs.save_to_yaml('chaned_loss_yaml.yaml')
 
     # op.trainer._create_optimizer()
     # op.trainer.set_optimizer()
 
-    op.train(training_config, train_dataset=train_data)
-    training_config.num_epoch = 3
-    op.train(training_config, train_dataset=train_data, resume_checkpoint_path=training_config.output_dir + '/epoch_2')
+    op.train(training_config, train_dataset=train_data, eval_dataset=eval_data)
+    # training_config.num_epoch = 3
+    # op.train(training_config, train_dataset=train_data, resume_checkpoint_path=training_config.output_dir + '/epoch_2')
 
     # op.save('./test_save')
     # op.load('./test_save')
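
A minimal sketch of the eval-enabled flow that test.py ends up with after this change. It
assumes the training_config object, the mnist_transform, and the
ResnetImageEmbedding/get_dataset imports defined in the unchanged parts of the file; only
the calls touched by the diff are restated here, for illustration.

    # operator under training; 'resnet34' backbone as in test.py
    op = ResnetImageEmbedding('resnet34')

    # same MNIST dataset helper as before, now split explicitly:
    # train=True for the training split, train=False for the eval split
    train_data = get_dataset('mnist', transform=mnist_transform, download=True, root='data', train=True)
    eval_data = get_dataset('mnist', transform=mnist_transform, download=True, root='data', train=False)

    # passing eval_dataset alongside train_dataset is what exercises the new eval ability;
    # training_config is loaded from resnet_training_yaml.yaml elsewhere in the script
    op.train(training_config, train_dataset=train_data, eval_dataset=eval_data)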