
add bird prediction example.

main · zhang chen, 3 years ago · commit 1e18ecb529
5 changed files:
1. default_config.yaml (2 changed lines)
2. examples/1_quick_start.ipynb (2 changed lines)
3. examples/2_train_on_mnist.ipynb (4 changed lines)
4. examples/3_read_configs_from_yaml.ipynb (76 changed lines)
5. examples/4_fine_tune_from_image_net.ipynb (223 changed lines)

default_config.yaml (2 changed lines)

@@ -26,7 +26,7 @@ metrics:
 train:
   batch_size: 8
   dataloader_drop_last: true
-  dataloader_num_workers: -1
+  dataloader_num_workers: 0
   dataloader_pin_memory: true
   epoch_num: 2
   eval_steps: null
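Note: `dataloader_num_workers` now defaults to 0. In `torch.utils.data.DataLoader`, `num_workers=0` means batches are loaded in the main process (a raw DataLoader rejects negative values, so -1 must have carried trainer-specific meaning), which avoids multiprocessing pitfalls in notebooks. A minimal sketch of the DataLoader call these keys presumably map to — the wiring shown is an assumption, not towhee's actual code:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    # stand-in dataset so the sketch is self-contained
    data = TensorDataset(torch.randn(32, 3, 224, 224), torch.randint(0, 10, (32,)))
    loader = DataLoader(
        data,
        batch_size=8,      # train.batch_size
        drop_last=True,    # train.dataloader_drop_last
        num_workers=0,     # train.dataloader_num_workers: 0 -> load in the main process
        pin_memory=True,   # train.dataloader_pin_memory
    )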

examples/1_quick_start.ipynb (2 changed lines)

@@ -17,6 +17,8 @@
 },
 "outputs": [],
 "source": [
+"import sys\n",
+"sys.path.append('..')\n",
 "from resnet_image_embedding import ResnetImageEmbedding\n",
 "from towhee.trainer.training_config import TrainingConfig\n",
 "from torchvision import transforms\n",

examples/2_train_on_mnist.ipynb (4 changed lines)

@@ -17,6 +17,8 @@
 },
 "outputs": [],
 "source": [
+"import sys\n",
+"sys.path.append('..')\n",
 "import torch\n",
 "import matplotlib.pyplot as plt\n",
 "import random\n",
@@ -136,7 +138,7 @@
 "pil_img = img * std + mean\n",
 "plt.imshow(pil_img)\n",
 "plt.show()\n",
-"test_img = eval_data.dataset[img_index][0].unsqueeze(0)\n",
+"test_img = eval_data.dataset[img_index][0].unsqueeze(0).to(op.trainer.configs.device)\n",
 "out = op.trainer.predict(test_img)\n",
 "predict_num = torch.argmax(torch.softmax(out, dim=-1)).item()\n",
 "print('this picture is number {}'.format(predict_num))"

examples/3_read_configs_from_yaml.ipynb (76 changed lines)

@@ -11,12 +11,14 @@
 },
 {
 "cell_type": "code",
-"execution_count": 1,
+"execution_count": 2,
 "metadata": {
 "collapsed": true
 },
 "outputs": [],
 "source": [
+"import sys\n",
+"sys.path.append('..')\n",
 "from resnet_image_embedding import ResnetImageEmbedding\n",
 "from torchvision import transforms\n",
 "from towhee import dataset\n",
@@ -27,7 +29,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 2,
+"execution_count": 3,
 "outputs": [],
 "source": [
 "from towhee.trainer.training_config import dump_default_yaml\n",
@@ -54,13 +56,13 @@
 },
 {
 "cell_type": "code",
-"execution_count": 3,
+"execution_count": 4,
 "outputs": [
 {
 "data": {
 "text/plain": "TrainingConfig(output_dir='my_output', overwrite_output_dir=True, eval_strategy='epoch', eval_steps=None, batch_size=5, val_batch_size=-1, seed=42, epoch_num=3, dataloader_pin_memory=True, dataloader_drop_last=True, dataloader_num_workers=-1, lr=5e-05, metric='Accuracy', print_steps=1, load_best_model_at_end=False, early_stopping={'mode': 'max', 'monitor': 'eval_epoch_metric', 'patience': 4}, model_checkpoint={'every_n_epoch': 1}, tensorboard=None, loss='CrossEntropyLoss', optimizer='Adam', lr_scheduler_type='linear', warmup_ratio=0.0, warmup_steps=0, device_str=None, n_gpu=-1, sync_bn=False, freeze_bn=False)"
 },
-"execution_count": 3,
+"execution_count": 4,
 "metadata": {},
 "output_type": "execute_result"
 }
@@ -82,7 +84,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 4,
+"execution_count": 5,
 "outputs": [],
 "source": [
 "# prepare the fake dataset\n",
@@ -99,37 +101,37 @@
 },
 {
 "cell_type": "code",
-"execution_count": 5,
+"execution_count": 6,
 "outputs": [
 {
 "name": "stderr",
 "output_type": "stream",
 "text": [
-"2022-03-02 16:00:26,226 - 8666785280 - trainer.py-trainer:390 - WARNING: TrainingConfig(output_dir='my_output', overwrite_output_dir=True, eval_strategy='epoch', eval_steps=None, batch_size=5, val_batch_size=-1, seed=42, epoch_num=3, dataloader_pin_memory=True, dataloader_drop_last=True, dataloader_num_workers=-1, lr=5e-05, metric='Accuracy', print_steps=1, load_best_model_at_end=False, early_stopping={'mode': 'max', 'monitor': 'eval_epoch_metric', 'patience': 4}, model_checkpoint={'every_n_epoch': 1}, tensorboard=None, loss='CrossEntropyLoss', optimizer='Adam', lr_scheduler_type='linear', warmup_ratio=0.0, warmup_steps=0, device_str=None, n_gpu=-1, sync_bn=False, freeze_bn=False)\n"
+"2022-03-03 16:59:41,635 - 4310336896 - trainer.py-trainer:390 - WARNING: TrainingConfig(output_dir='my_output', overwrite_output_dir=True, eval_strategy='epoch', eval_steps=None, batch_size=5, val_batch_size=-1, seed=42, epoch_num=3, dataloader_pin_memory=True, dataloader_drop_last=True, dataloader_num_workers=-1, lr=5e-05, metric='Accuracy', print_steps=1, load_best_model_at_end=False, early_stopping={'mode': 'max', 'monitor': 'eval_epoch_metric', 'patience': 4}, model_checkpoint={'every_n_epoch': 1}, tensorboard=None, loss='CrossEntropyLoss', optimizer='Adam', lr_scheduler_type='linear', warmup_ratio=0.0, warmup_steps=0, device_str=None, n_gpu=-1, sync_bn=False, freeze_bn=False)\n"
 ]
 },
 {
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"epoch=1/3, global_step=1, epoch_loss=2.5702719688415527, epoch_metric=0.0\n",
-"epoch=1/3, global_step=2, epoch_loss=2.572024345397949, epoch_metric=0.0\n",
-"epoch=1/3, global_step=3, epoch_loss=2.558194160461426, epoch_metric=0.0\n",
-"epoch=1/3, global_step=4, epoch_loss=2.558873176574707, epoch_metric=0.15000000596046448\n",
-"epoch=1/3, eval_global_step=0, eval_epoch_loss=2.370976686477661, eval_epoch_metric=0.20000000298023224\n",
-"epoch=1/3, eval_global_step=1, eval_epoch_loss=2.2873291969299316, eval_epoch_metric=0.20000000298023224\n",
-"epoch=2/3, global_step=5, epoch_loss=1.3134113550186157, epoch_metric=0.20000000298023224\n",
-"epoch=2/3, global_step=6, epoch_loss=1.3073358535766602, epoch_metric=0.10000000149011612\n",
-"epoch=2/3, global_step=7, epoch_loss=1.41914701461792, epoch_metric=0.13333334028720856\n",
-"epoch=2/3, global_step=8, epoch_loss=1.3628838062286377, epoch_metric=0.15000000596046448\n",
-"epoch=2/3, eval_global_step=2, eval_epoch_loss=1.3158948421478271, eval_epoch_metric=0.20000000298023224\n",
-"epoch=2/3, eval_global_step=3, eval_epoch_loss=1.3246530294418335, eval_epoch_metric=0.20000000298023224\n",
-"epoch=3/3, global_step=9, epoch_loss=1.4589173793792725, epoch_metric=0.0\n",
-"epoch=3/3, global_step=10, epoch_loss=1.4343616962432861, epoch_metric=0.0\n",
-"epoch=3/3, global_step=11, epoch_loss=1.3701648712158203, epoch_metric=0.06666667014360428\n",
-"epoch=3/3, global_step=12, epoch_loss=1.1501117944717407, epoch_metric=0.10000000149011612\n",
-"epoch=3/3, eval_global_step=4, eval_epoch_loss=1.1129425764083862, eval_epoch_metric=0.0\n",
-"epoch=3/3, eval_global_step=5, eval_epoch_loss=1.1257113218307495, eval_epoch_metric=0.0\n"
+"epoch=1/3, global_step=1, epoch_loss=2.469155788421631, epoch_metric=0.20000000298023224\n",
+"epoch=1/3, global_step=2, epoch_loss=2.486016273498535, epoch_metric=0.20000000298023224\n",
+"epoch=1/3, global_step=3, epoch_loss=2.519146203994751, epoch_metric=0.20000000298023224\n",
+"epoch=1/3, global_step=4, epoch_loss=2.451723098754883, epoch_metric=0.20000000298023224\n",
+"epoch=1/3, eval_global_step=0, eval_epoch_loss=2.263216495513916, eval_epoch_metric=0.20000000298023224\n",
+"epoch=1/3, eval_global_step=1, eval_epoch_loss=2.1709983348846436, eval_epoch_metric=0.20000000298023224\n",
+"epoch=2/3, global_step=5, epoch_loss=1.2240798473358154, epoch_metric=0.20000000298023224\n",
+"epoch=2/3, global_step=6, epoch_loss=1.1725499629974365, epoch_metric=0.20000000298023224\n",
+"epoch=2/3, global_step=7, epoch_loss=1.2648464441299438, epoch_metric=0.20000000298023224\n",
+"epoch=2/3, global_step=8, epoch_loss=1.30061936378479, epoch_metric=0.15000000596046448\n",
+"epoch=2/3, eval_global_step=2, eval_epoch_loss=1.2398303747177124, eval_epoch_metric=0.0\n",
+"epoch=2/3, eval_global_step=3, eval_epoch_loss=1.2246357202529907, eval_epoch_metric=0.10000000149011612\n",
+"epoch=3/3, global_step=9, epoch_loss=1.501572847366333, epoch_metric=0.20000000298023224\n",
+"epoch=3/3, global_step=10, epoch_loss=1.365707516670227, epoch_metric=0.20000000298023224\n",
+"epoch=3/3, global_step=11, epoch_loss=1.2403526306152344, epoch_metric=0.13333334028720856\n",
+"epoch=3/3, global_step=12, epoch_loss=1.0921388864517212, epoch_metric=0.10000000149011612\n",
+"epoch=3/3, eval_global_step=4, eval_epoch_loss=1.0393352508544922, eval_epoch_metric=0.0\n",
+"epoch=3/3, eval_global_step=5, eval_epoch_loss=1.0277410745620728, eval_epoch_metric=0.10000000149011612\n"
 ]
 }
 ],
@@ -156,6 +158,30 @@
 }
 }
 },
+{
+"cell_type": "markdown",
+"source": [
+"### By the way, you can change the config in your python code and save the config into a yaml file. So it's easy to convert between the python config instance and yaml file."
+],
+"metadata": {
+"collapsed": false
+}
+},
+{
+"cell_type": "code",
+"execution_count": 7,
+"outputs": [],
+"source": [
+"training_config.batch_size = 2\n",
+"training_config.save_to_yaml('another_setting.yaml')\n"
+],
+"metadata": {
+"collapsed": false,
+"pycharm": {
+"name": "#%%\n"
+}
+}
+},
 {
 "cell_type": "code",
 "execution_count": null,

examples/4_fine_tune_from_image_net.ipynb (223 changed lines)

File diff suppressed because one or more lines are too long
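Note: since this diff is suppressed, only generic background can be given here, not the notebook's actual code. "Fine-tune from ImageNet" conventionally means loading ImageNet-pretrained weights and replacing the classification head, e.g. for the bird-prediction example named in the commit message:

    import torch.nn as nn
    from torchvision import models

    model = models.resnet50(pretrained=True)                  # ImageNet weights
    num_classes = 400                                         # hypothetical bird-species count
    model.fc = nn.Linear(model.fc.in_features, num_classes)   # new head to fine-tune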