logo
Browse Source

Add freeze_bn config

Signed-off-by: Your Name <you@example.com>
training
Your Name 3 years ago
parent
commit
0b1571476a
  1. 11
      resnet_training_yaml.yaml
  2. 24
      test.py

11
resnet_training_yaml.yaml

@ -4,7 +4,7 @@ callback:
monitor: eval_epoch_metric monitor: eval_epoch_metric
patience: 2 patience: 2
model_checkpoint: model_checkpoint:
every_n_epoch: 2
every_n_epoch: 1
tensorboard: tensorboard:
comment: '' comment: ''
log_dir: null log_dir: null
@ -23,20 +23,21 @@ logging:
logging_dir: null logging_dir: null
logging_strategy: steps logging_strategy: steps
print_steps: null print_steps: null
save_strategy: steps
metrics: metrics:
metric: Accuracy metric: Accuracy
train: train:
batch_size: 16
batch_size: 256
dataloader_drop_last: false dataloader_drop_last: false
dataloader_num_workers: 0 dataloader_num_workers: 0
epoch_num: 16
dataloader_pin_memory: true
epoch_num: 3
eval_steps: null eval_steps: null
eval_strategy: epoch eval_strategy: epoch
freeze_bn: true
load_best_model_at_end: false load_best_model_at_end: false
max_steps: -1
output_dir: ./output_dir output_dir: ./output_dir
overwrite_output_dir: true overwrite_output_dir: true
resume_from_checkpoint: null resume_from_checkpoint: null
seed: 42 seed: 42
val_batch_size: -1 val_batch_size: -1
freeze_bn: true

24
test.py

@ -43,12 +43,12 @@ if __name__ == '__main__':
# dump_default_yaml(yaml_path=yaml_path) # dump_default_yaml(yaml_path=yaml_path)
training_config.load_from_yaml(yaml_path) training_config.load_from_yaml(yaml_path)
training_config.overwrite_output_dir=True
training_config.epoch_num=3
training_config.batch_size=256
training_config.device_str='cpu'
training_config.n_gpu=-1
training_config.save_to_yaml(yaml_path)
# training_config.overwrite_output_dir=True
# training_config.epoch_num=3
# training_config.batch_size=256
# training_config.device_str='cpu'
# training_config.n_gpu=-1
# training_config.save_to_yaml(yaml_path)
# mnist_transform = transforms.Compose([transforms.ToTensor(), # mnist_transform = transforms.Compose([transforms.ToTensor(),
# RandomResizedCrop(224), # RandomResizedCrop(224),
@ -59,8 +59,8 @@ if __name__ == '__main__':
# training_config.output_dir = 'mnist_output' # training_config.output_dir = 'mnist_output'
fake_transform = transforms.Compose([transforms.ToTensor(), fake_transform = transforms.Compose([transforms.ToTensor(),
RandomResizedCrop(224),]) RandomResizedCrop(224),])
train_data = dataset('fake', size=1000, transform=fake_transform)
eval_data = dataset('fake', size=500, transform=fake_transform)
train_data = dataset('fake', size=100, transform=fake_transform)
eval_data = dataset('fake', size=10, transform=fake_transform)
training_config.output_dir = 'fake_output' training_config.output_dir = 'fake_output'
# trainer = op.setup_trainer() # trainer = op.setup_trainer()
@ -87,7 +87,7 @@ if __name__ == '__main__':
# op.train(training_config, train_dataset=train_data, eval_dataset=eval_data, resume_checkpoint_path=training_config.output_dir + '/epoch_2') # op.train(training_config, train_dataset=train_data, eval_dataset=eval_data, resume_checkpoint_path=training_config.output_dir + '/epoch_2')
# op.save('./test_save') # op.save('./test_save')
# op.load('./test_save')
# new_out = op(towhee_img)
#
# assert (new_out[0]!=old_out[0]).all()
# op.load('./test_save')\
new_out = op(towhee_img)
assert (new_out.feature_vector==old_out.feature_vector).all()

Loading…
Cancel
Save