logo
Browse Source

Update test.py

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
training
Jael Gu 3 years ago
parent
commit
4ae024fdeb
  1. 13
      test.py

13
test.py

@ -52,14 +52,17 @@ if __name__ == '__main__':
transforms.Normalize(mean=[0.5], std=[0.5])]) transforms.Normalize(mean=[0.5], std=[0.5])])
train_data = get_dataset('mnist', transform=mnist_transform, download=True, root='data') train_data = get_dataset('mnist', transform=mnist_transform, download=True, root='data')
fake_transform = transforms.Compose([transforms.ToTensor(),
RandomResizedCrop(224),])
# fake_transform = transforms.Compose([transforms.ToTensor(),
# RandomResizedCrop(224),])
# train_data = get_dataset('fake', size=20, transform=fake_transform) # train_data = get_dataset('fake', size=20, transform=fake_transform)
op.change_before_train(10) op.change_before_train(10)
op.train(training_config, train_dataset=train_data) op.train(training_config, train_dataset=train_data)
# e.save('./test_save')
# e.load('./test_save')
# new_out = e(img)
training_config.num_epoch = 3
op.train(training_config, train_dataset=train_data, resume_checkpoint_path=training_config.output_dir + '/epoch_2')
# op.save('./test_save')
# op.load('./test_save')
# new_out = op(towhee_img)
# assert (new_out[0]!=old_out[0]).all() # assert (new_out[0]!=old_out[0]).all()

Loading…
Cancel
Save