Browse Source
Update test.py
Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
training
1 changed files with
8 additions and
5 deletions
-
test.py
|
|
@ -52,14 +52,17 @@ if __name__ == '__main__': |
|
|
|
transforms.Normalize(mean=[0.5], std=[0.5])]) |
|
|
|
train_data = get_dataset('mnist', transform=mnist_transform, download=True, root='data') |
|
|
|
|
|
|
|
fake_transform = transforms.Compose([transforms.ToTensor(), |
|
|
|
RandomResizedCrop(224),]) |
|
|
|
# fake_transform = transforms.Compose([transforms.ToTensor(), |
|
|
|
# RandomResizedCrop(224),]) |
|
|
|
# train_data = get_dataset('fake', size=20, transform=fake_transform) |
|
|
|
|
|
|
|
op.change_before_train(10) |
|
|
|
op.train(training_config, train_dataset=train_data) |
|
|
|
# e.save('./test_save') |
|
|
|
# e.load('./test_save') |
|
|
|
# new_out = e(img) |
|
|
|
training_config.num_epoch = 3 |
|
|
|
op.train(training_config, train_dataset=train_data, resume_checkpoint_path=training_config.output_dir + '/epoch_2') |
|
|
|
|
|
|
|
# op.save('./test_save') |
|
|
|
# op.load('./test_save') |
|
|
|
# new_out = op(towhee_img) |
|
|
|
|
|
|
|
# assert (new_out[0]!=old_out[0]).all() |
|
|
|