|
|
@ -32,7 +32,7 @@ if __name__ == '__main__': |
|
|
|
towhee_img = Image(img_bytes, img_width, img_height, img_channel, img_mode, img_array) |
|
|
|
|
|
|
|
op = ResnetImageEmbedding('resnet50', num_classes=10) |
|
|
|
# op.model_card = ModelCard(model_details="resnet test modelcard", training_data="use resnet test data") |
|
|
|
|
|
|
|
old_out = op(towhee_img) |
|
|
|
# print(old_out.feature_vector[0][:10]) |
|
|
|
# print(old_out.feature_vector[:10]) |
|
|
@ -57,11 +57,14 @@ if __name__ == '__main__': |
|
|
|
# train_data = dataset('mnist', transform=mnist_transform, download=True, root='data', train=True) |
|
|
|
# eval_data = dataset('mnist', transform=mnist_transform, download=True, root='data', train=False) |
|
|
|
# training_config.output_dir = 'mnist_output' |
|
|
|
# op.model_card = ModelCard(datasets='mnist dataset') |
|
|
|
|
|
|
|
fake_transform = transforms.Compose([transforms.ToTensor(), |
|
|
|
RandomResizedCrop(224),]) |
|
|
|
train_data = dataset('fake', size=100, transform=fake_transform) |
|
|
|
RandomResizedCrop(224)]) |
|
|
|
train_data = dataset('fake', size=20, transform=fake_transform) |
|
|
|
eval_data = dataset('fake', size=10, transform=fake_transform) |
|
|
|
training_config.output_dir = 'fake_output' |
|
|
|
op.model_card = ModelCard(datasets='fake dataset') |
|
|
|
|
|
|
|
# trainer = op.setup_trainer() |
|
|
|
# print(op.get_model()) |
|
|
@ -83,11 +86,13 @@ if __name__ == '__main__': |
|
|
|
|
|
|
|
freezer = LayerFreezer(op.get_model()) |
|
|
|
freezer.set_slice(-1) |
|
|
|
op.train(training_config, train_dataset=train_data, eval_dataset=eval_data) |
|
|
|
# op.train(training_config, train_dataset=train_data, eval_dataset=eval_data, resume_checkpoint_path=training_config.output_dir + '/epoch_2') |
|
|
|
|
|
|
|
# op.train(training_config, train_dataset=train_data, eval_dataset=eval_data) |
|
|
|
op.train(training_config, train_dataset=train_data, eval_dataset=eval_data, |
|
|
|
resume_checkpoint_path=training_config.output_dir + '/epoch_2') |
|
|
|
|
|
|
|
# op.save('./test_save') |
|
|
|
# op.load('./test_save')\ |
|
|
|
|
|
|
|
new_out = op(towhee_img) |
|
|
|
assert (new_out.feature_vector==old_out.feature_vector).all() |
|
|
|
assert (new_out.feature_vector == old_out.feature_vector).all() |
|
|
|