logo
Browse Source

update test

Signed-off-by: Your Name <you@example.com>
training
Your Name 3 years ago
parent
commit
9b0f12749e
  1. 90
      test.py

90
test.py

@ -35,61 +35,59 @@ if __name__ == '__main__':
# NOTE(review): this is a diff fragment from inside `if __name__ == '__main__':`
# of test.py; the original leading indentation was lost in extraction. Names
# such as `op`, `towhee_img`, `TrainingConfig`, `dataset` and `LayerFreezer`
# are defined/imported earlier in the file (not visible here) — presumably a
# towhee operator and its training utilities; confirm upstream.
# op.model_card = ModelCard(model_details="resnet test modelcard", training_data="use resnet test data")
# Run the operator once before training to capture a baseline feature vector.
old_out = op(towhee_img)
# print(old_out.feature_vector[0][:10])
print(old_out.feature_vector[:10])
# print(old_out.feature_vector[:10])
# print(old_out.feature_vector.shape)
# Build the training configuration and load its values from a YAML file.
training_config = TrainingConfig()
yaml_path = 'resnet_training_yaml.yaml'
# dump_default_yaml(yaml_path=yaml_path)
training_config.load_from_yaml(yaml_path)
# output_dir='./temp_output',
# overwrite_output_dir=True,
# epoch_num=2,
# per_gpu_train_batch_size=16,
# prediction_loss_only=True,
# metric='Accuracy'
# # device_str='cuda',
# # n_gpu=4
# )
#
# MNIST transform: to-tensor -> random 224x224 crop; repeat(3, 1, 1) copies
# the single grayscale channel into 3 channels for an RGB backbone.
mnist_transform = transforms.Compose([transforms.ToTensor(),
RandomResizedCrop(224),
Lambda(lambda x: x.repeat(3, 1, 1)),
transforms.Normalize(mean=[0.5], std=[0.5])])
# NOTE(review): these MNIST datasets are constructed (download=True, so the
# download still runs) but are overwritten by the 'fake' datasets below —
# looks like leftover experiment code; consider removing.
train_data = dataset('mnist', transform=mnist_transform, download=True, root='data', train=True)
eval_data = dataset('mnist', transform=mnist_transform, download=True, root='data', train=False)
# fake_transform = transforms.Compose([transforms.ToTensor(),
# RandomResizedCrop(224),])
# train_data = get_dataset('fake', size=20, transform=fake_transform)
#
# op.change_before_train(num_classes=10)
# # trainer = op.setup_trainer()
# Override selected config fields in code and persist them back to the YAML.
# device_str='cpu' with n_gpu=-1 presumably forces CPU-only training — confirm
# against TrainingConfig's semantics.
training_config.overwrite_output_dir=True
training_config.epoch_num=3
training_config.batch_size=256
training_config.device_str='cpu'
training_config.n_gpu=-1
training_config.save_to_yaml(yaml_path)
# mnist_transform = transforms.Compose([transforms.ToTensor(),
# RandomResizedCrop(224),
# Lambda(lambda x: x.repeat(3, 1, 1)),
# transforms.Normalize(mean=[0.5], std=[0.5])])
# train_data = dataset('mnist', transform=mnist_transform, download=True, root='data', train=True)
# eval_data = dataset('mnist', transform=mnist_transform, download=True, root='data', train=False)
# training_config.output_dir = 'mnist_output'
# Replace the MNIST data with a synthetic 'fake' dataset (1000 train / 500
# eval samples) for a quick smoke-test run; outputs go to 'fake_output'.
fake_transform = transforms.Compose([transforms.ToTensor(),
RandomResizedCrop(224),])
train_data = dataset('fake', size=1000, transform=fake_transform)
eval_data = dataset('fake', size=500, transform=fake_transform)
training_config.output_dir = 'fake_output'
# trainer = op.setup_trainer()
# print(op.get_model())
# # my_optimimzer = AdamW(op.get_model().parameters(), lr=0.002, betas=(0.91, 0.98), eps=1e-08, weight_decay=0.01, amsgrad=False)
# # op.setup_trainer()
# my_optimimzer = AdamW(op.get_model().parameters(), lr=0.002, betas=(0.91, 0.98), eps=1e-08, weight_decay=0.01, amsgrad=False)
# op.setup_trainer()
#
# # trainer.add_callback()
# # trainer.set_optimizer()
# trainer.add_callback()
# trainer.set_optimizer()
#
# # op.trainer.set_optimizer(my_optimimzer)
# # trainer.configs.save_to_yaml('changed_optimizer_yaml.yaml')
# op.trainer.set_optimizer(my_optimimzer)
# trainer.configs.save_to_yaml('changed_optimizer_yaml.yaml')
#
# # my_loss = nn.BCELoss()
# # trainer.set_loss(my_loss, 'my_loss111')
# # trainer.configs.save_to_yaml('chaned_loss_yaml.yaml')
# # op.trainer._create_optimizer()
# # op.trainer.set_optimizer()
# # trainer = op.setup_trainer(training_config, train_dataset=train_data, eval_dataset=eval_data)
#
# # freezer = LayerFreezer(op.get_model())
# # freezer.by_idx([-1])
# my_loss = nn.BCELoss()
# trainer.set_loss(my_loss, 'my_loss111')
# trainer.configs.save_to_yaml('chaned_loss_yaml.yaml')
# op.trainer._create_optimizer()
# op.trainer.set_optimizer()
# trainer = op.setup_trainer(training_config, train_dataset=train_data, eval_dataset=eval_data)
# Freeze part of the model before training; set_slice(-1) presumably freezes
# layers selected by that slice (e.g. all but the last) — TODO confirm against
# LayerFreezer's API.
freezer = LayerFreezer(op.get_model())
freezer.set_slice(-1)
# Kick off training with the overridden config on the fake train/eval data.
op.train(training_config, train_dataset=train_data, eval_dataset=eval_data)
# # op.trainer.run_train()
# # training_config.num_epoch = 3
# # op.train(training_config, train_dataset=train_data, resume_checkpoint_path=training_config.output_dir + '/epoch_2')
#
# # op.save('./test_save')
# # op.load('./test_save')
# # new_out = op(towhee_img)
# op.train(training_config, train_dataset=train_data, eval_dataset=eval_data, resume_checkpoint_path=training_config.output_dir + '/epoch_2')
# op.save('./test_save')
# op.load('./test_save')
# new_out = op(towhee_img)
#
# # assert (new_out[0]!=old_out[0]).all()
# assert (new_out[0]!=old_out[0]).all()

Loading…
Cancel
Save