From 4ae024fdeb4455e38b59796ec82821c8e2588e09 Mon Sep 17 00:00:00 2001 From: Jael Gu Date: Tue, 15 Feb 2022 17:45:05 +0800 Subject: [PATCH] Update test.py Signed-off-by: Jael Gu --- test.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/test.py b/test.py index 04ad646..95eead9 100644 --- a/test.py +++ b/test.py @@ -52,14 +52,17 @@ if __name__ == '__main__': transforms.Normalize(mean=[0.5], std=[0.5])]) train_data = get_dataset('mnist', transform=mnist_transform, download=True, root='data') - fake_transform = transforms.Compose([transforms.ToTensor(), - RandomResizedCrop(224),]) + # fake_transform = transforms.Compose([transforms.ToTensor(), + # RandomResizedCrop(224),]) # train_data = get_dataset('fake', size=20, transform=fake_transform) op.change_before_train(10) op.train(training_config, train_dataset=train_data) - # e.save('./test_save') - # e.load('./test_save') - # new_out = e(img) + training_config.num_epoch = 3 + op.train(training_config, train_dataset=train_data, resume_checkpoint_path=training_config.output_dir + '/epoch_2') + + # op.save('./test_save') + # op.load('./test_save') + # new_out = op(towhee_img) # assert (new_out[0]!=old_out[0]).all()