logo
Browse Source

update

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
training
Jael Gu 3 years ago
parent
commit
5341779d16
  1. 3
      resnet_training_yaml.yaml
  2. 17
      test.py

3
resnet_training_yaml.yaml

@ -4,7 +4,7 @@ callback:
monitor: eval_epoch_metric
patience: 2
model_checkpoint:
every_n_epoch: 1
every_n_epoch: 2
tensorboard:
comment: ''
log_dir: null
@ -40,4 +40,3 @@ train:
resume_from_checkpoint: null
seed: 42
val_batch_size: -1
freeze_bn: true

17
test.py

@ -32,7 +32,7 @@ if __name__ == '__main__':
towhee_img = Image(img_bytes, img_width, img_height, img_channel, img_mode, img_array)
op = ResnetImageEmbedding('resnet50', num_classes=10)
# op.model_card = ModelCard(model_details="resnet test modelcard", training_data="use resnet test data")
old_out = op(towhee_img)
# print(old_out.feature_vector[0][:10])
# print(old_out.feature_vector[:10])
@ -57,11 +57,14 @@ if __name__ == '__main__':
# train_data = dataset('mnist', transform=mnist_transform, download=True, root='data', train=True)
# eval_data = dataset('mnist', transform=mnist_transform, download=True, root='data', train=False)
# training_config.output_dir = 'mnist_output'
# op.model_card = ModelCard(datasets='mnist dataset')
fake_transform = transforms.Compose([transforms.ToTensor(),
RandomResizedCrop(224),])
train_data = dataset('fake', size=100, transform=fake_transform)
RandomResizedCrop(224)])
train_data = dataset('fake', size=20, transform=fake_transform)
eval_data = dataset('fake', size=10, transform=fake_transform)
training_config.output_dir = 'fake_output'
op.model_card = ModelCard(datasets='fake dataset')
# trainer = op.setup_trainer()
# print(op.get_model())
@ -83,11 +86,13 @@ if __name__ == '__main__':
freezer = LayerFreezer(op.get_model())
freezer.set_slice(-1)
op.train(training_config, train_dataset=train_data, eval_dataset=eval_data)
# op.train(training_config, train_dataset=train_data, eval_dataset=eval_data, resume_checkpoint_path=training_config.output_dir + '/epoch_2')
# op.train(training_config, train_dataset=train_data, eval_dataset=eval_data)
op.train(training_config, train_dataset=train_data, eval_dataset=eval_data,
resume_checkpoint_path=training_config.output_dir + '/epoch_2')
# op.save('./test_save')
# op.load('./test_save')\
new_out = op(towhee_img)
assert (new_out.feature_vector==old_out.feature_vector).all()
assert (new_out.feature_vector == old_out.feature_vector).all()

Loading…
Cancel
Save