logo
Browse Source

add some training examples.

main
zhang chen 3 years ago
parent
commit
7c78b94de1
  1. 39
      default_config.yaml
  2. 136
      examples/1_quick_start.ipynb
  3. 163
      examples/2_train_on_mnist.ipynb
  4. 193
      examples/3_read_configs_from_yaml.ipynb
  5. 37
      examples/my_setting.yaml
  6. 17
      resnet_training_yaml.yaml
  7. 52
      test.py

39
default_config.yaml

@ -0,0 +1,39 @@
callback:
early_stopping:
mode: max
monitor: eval_epoch_metric
patience: 4
model_checkpoint:
every_n_epoch: 1
tensorboard:
comment: ''
log_dir: null
device:
device_str: null
n_gpu: -1
sync_bn: false
learning:
loss: CrossEntropyLoss
lr: 5.0e-05
lr_scheduler_type: linear
optimizer: Adam
warmup_ratio: 0.0
warmup_steps: 0
logging:
print_steps: null
metrics:
metric: Accuracy
train:
batch_size: 8
dataloader_drop_last: true
dataloader_num_workers: -1
dataloader_pin_memory: true
epoch_num: 2
eval_steps: null
eval_strategy: epoch
freeze_bn: false
load_best_model_at_end: false
output_dir: ./output_dir
overwrite_output_dir: true
seed: 42
val_batch_size: -1

136
examples/1_quick_start.ipynb

@ -0,0 +1,136 @@
{
"cells": [
{
"cell_type": "markdown",
"source": [
"# Quick start to train an operator on a toy fake dataset."
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from resnet_image_embedding import ResnetImageEmbedding\n",
"from towhee.trainer.training_config import TrainingConfig\n",
"from torchvision import transforms\n",
"from towhee import dataset\n",
"\n",
"# build a resnet op:\n",
"op = ResnetImageEmbedding('resnet18', num_classes=10)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"outputs": [],
"source": [
"# build a training config:\n",
"training_config = TrainingConfig()\n",
"training_config.batch_size = 2\n",
"training_config.epoch_num = 2\n",
"training_config.tensorboard = None\n",
"training_config.output_dir = 'quick_start_output'"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 11,
"outputs": [],
"source": [
"# prepare the dataset\n",
"fake_transform = transforms.Compose([transforms.ToTensor()])\n",
"train_data = dataset('fake', size=20, transform=fake_transform)\n",
"eval_data = dataset('fake', size=10, transform=fake_transform)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 9,
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2022-03-02 15:09:06,334 - 8665081344 - trainer.py-trainer:390 - WARNING: TrainingConfig(output_dir='quick_start_output', overwrite_output_dir=True, eval_strategy='epoch', eval_steps=None, batch_size=2, val_batch_size=-1, seed=42, epoch_num=2, dataloader_pin_memory=True, dataloader_drop_last=True, dataloader_num_workers=-1, lr=5e-05, metric='Accuracy', print_steps=None, load_best_model_at_end=False, early_stopping={'monitor': 'eval_epoch_metric', 'patience': 4, 'mode': 'max'}, model_checkpoint={'every_n_epoch': 1}, tensorboard=None, loss='CrossEntropyLoss', optimizer='Adam', lr_scheduler_type='linear', warmup_ratio=0.0, warmup_steps=0, device_str=None, n_gpu=-1, sync_bn=False, freeze_bn=False)\n",
"[epoch 1/2] loss=2.402, metric=0.0, eval_loss=2.254, eval_metric=0.0: 100%|██████████| 10/10 [00:32<00:00, 3.25s/step]\n",
"[epoch 2/2] loss=1.88, metric=0.1, eval_loss=1.855, eval_metric=0.1: 100%|██████████| 10/10 [00:22<00:00, 1.14step/s] "
]
}
],
"source": [
"# start training, it will take about 2 minutes on a cpu machine.\n",
"op.train(training_config, train_dataset=train_data, eval_dataset=eval_data)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"### If you see the two epochs progress bar finish its schedule and a `quick_start_output` folder result, it means you succeeded."
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
}

163
examples/2_train_on_mnist.ipynb

@ -0,0 +1,163 @@
{
"cells": [
{
"cell_type": "markdown",
"source": [
"# Train resnet18 operator on mnist dataset."
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from resnet_image_embedding import ResnetImageEmbedding\n",
"from towhee.trainer.training_config import TrainingConfig\n",
"from torchvision import transforms\n",
"from towhee import dataset\n",
"from torchvision.transforms import Lambda\n",
"\n",
"# build a resnet op with 10 classes output, because the mnist has 10 classes:\n",
"op = ResnetImageEmbedding('resnet18', num_classes=10)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"outputs": [],
"source": [
"# build a training config:\n",
"training_config = TrainingConfig()\n",
"training_config.batch_size = 64\n",
"training_config.epoch_num = 5\n",
"training_config.tensorboard = None\n",
"training_config.output_dir = 'mnist_output'\n",
"training_config.dataloader_num_workers = 0"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 3,
"outputs": [],
"source": [
"# prepare the mnist data\n",
"mnist_transform = transforms.Compose([transforms.ToTensor(),\n",
" Lambda(lambda x: x.repeat(3, 1, 1)),\n",
" transforms.Normalize(mean=[0.1307,0.1307,0.1307], std=[0.3081,0.3081,0.3081])])\n",
"train_data = dataset('mnist', transform=mnist_transform, download=True, root='data', train=True)\n",
"eval_data = dataset('mnist', transform=mnist_transform, download=True, root='data', train=False)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 4,
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2022-03-02 15:24:19,702 - 8669324800 - trainer.py-trainer:390 - WARNING: TrainingConfig(output_dir='mnist_output', overwrite_output_dir=True, eval_strategy='epoch', eval_steps=None, batch_size=64, val_batch_size=-1, seed=42, epoch_num=5, dataloader_pin_memory=True, dataloader_drop_last=True, dataloader_num_workers=0, lr=5e-05, metric='Accuracy', print_steps=None, load_best_model_at_end=False, early_stopping={'monitor': 'eval_epoch_metric', 'patience': 4, 'mode': 'max'}, model_checkpoint={'every_n_epoch': 1}, tensorboard=None, loss='CrossEntropyLoss', optimizer='Adam', lr_scheduler_type='linear', warmup_ratio=0.0, warmup_steps=0, device_str=None, n_gpu=-1, sync_bn=False, freeze_bn=False)\n",
"[epoch 1/5] loss=0.388, metric=0.884: 41%|████ | 383/937 [07:38<11:00, 1.19s/step]"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
"\u001B[0;31mKeyboardInterrupt\u001B[0m Traceback (most recent call last)",
"\u001B[0;32m/var/folders/wn/4wflyq8x0f9bhkwryvss30880000gn/T/ipykernel_5732/1544844912.py\u001B[0m in \u001B[0;36m<module>\u001B[0;34m\u001B[0m\n\u001B[0;32m----> 1\u001B[0;31m \u001B[0mop\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mtrain\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mtraining_config\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mtrain_dataset\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0mtrain_data\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0meval_dataset\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0meval_data\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m",
"\u001B[0;32m/opt/homebrew/anaconda3/envs/conda_towhee/lib/python3.9/site-packages/towhee-0.4.1.dev110-py3.9.egg/towhee/operator/base.py\u001B[0m in \u001B[0;36mtrain\u001B[0;34m(self, training_config, train_dataset, eval_dataset, resume_checkpoint_path, **kwargs)\u001B[0m\n\u001B[1;32m 136\u001B[0m \"\"\"\n\u001B[1;32m 137\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0msetup_trainer\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mtraining_config\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mtrain_dataset\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0meval_dataset\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0;34m**\u001B[0m\u001B[0mkwargs\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m--> 138\u001B[0;31m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mtrainer\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mtrain\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mresume_checkpoint_path\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m 139\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 140\u001B[0m \u001B[0;34m@\u001B[0m\u001B[0mproperty\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n",
"\u001B[0;32m/opt/homebrew/anaconda3/envs/conda_towhee/lib/python3.9/site-packages/towhee-0.4.1.dev110-py3.9.egg/towhee/trainer/trainer.py\u001B[0m in \u001B[0;36mtrain\u001B[0;34m(self, resume_checkpoint_path)\u001B[0m\n\u001B[1;32m 281\u001B[0m \u001B[0;32melse\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 282\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mdistributed\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0;32mFalse\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m--> 283\u001B[0;31m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mrun_train\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mresume_checkpoint_path\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m 284\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 285\u001B[0m \u001B[0;32mdef\u001B[0m \u001B[0m_spawn_train_process\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mself\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mresume_checkpoint_path\u001B[0m\u001B[0;34m:\u001B[0m \u001B[0mOptional\u001B[0m\u001B[0;34m[\u001B[0m\u001B[0mstr\u001B[0m\u001B[0;34m]\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n",
"\u001B[0;32m/opt/homebrew/anaconda3/envs/conda_towhee/lib/python3.9/site-packages/towhee-0.4.1.dev110-py3.9.egg/towhee/trainer/trainer.py\u001B[0m in \u001B[0;36mrun_train\u001B[0;34m(self, resume_checkpoint_path, rank, world_size)\u001B[0m\n\u001B[1;32m 407\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mcallbacks\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mon_train_batch_begin\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0minputs\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mlogs\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 408\u001B[0m \u001B[0minputs\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mprepare_inputs\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0minputs\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m--> 409\u001B[0;31m \u001B[0mstep_logs\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mtrain_step\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mmodel\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0minputs\u001B[0m\u001B[0;34m)\u001B[0m \u001B[0;31m# , train_dataloader)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m 410\u001B[0m \u001B[0mlogs\u001B[0m\u001B[0;34m[\u001B[0m\u001B[0;34m\"lr\"\u001B[0m\u001B[0;34m]\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mlr_scheduler\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mget_lr\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m[\u001B[0m\u001B[0;36m0\u001B[0m\u001B[0;34m]\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 411\u001B[0m \u001B[0mlogs\u001B[0m\u001B[0;34m[\u001B[0m\u001B[0;34m\"global_step\"\u001B[0m\u001B[0;34m]\u001B[0m \u001B[0;34m+=\u001B[0m \u001B[0;36m1\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n",
"\u001B[0;32m/opt/homebrew/anaconda3/envs/conda_towhee/lib/python3.9/site-packages/towhee-0.4.1.dev110-py3.9.egg/towhee/trainer/trainer.py\u001B[0m in \u001B[0;36mtrain_step\u001B[0;34m(self, model, inputs)\u001B[0m\n\u001B[1;32m 615\u001B[0m \u001B[0mStep\u001B[0m \u001B[0mlogs\u001B[0m \u001B[0mwhich\u001B[0m \u001B[0mcontains\u001B[0m \u001B[0mthe\u001B[0m \u001B[0mstep\u001B[0m \u001B[0mloss\u001B[0m \u001B[0;32mand\u001B[0m \u001B[0mmetric\u001B[0m \u001B[0minfos\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 616\u001B[0m \"\"\"\n\u001B[0;32m--> 617\u001B[0;31m \u001B[0mstep_loss\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mcompute_loss\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mmodel\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0minputs\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m 618\u001B[0m \u001B[0mstep_loss\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mreduce_value\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mstep_loss\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0maverage\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0;32mTrue\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 619\u001B[0m \u001B[0mstep_loss\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mbackward\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n",
"\u001B[0;32m/opt/homebrew/anaconda3/envs/conda_towhee/lib/python3.9/site-packages/towhee-0.4.1.dev110-py3.9.egg/towhee/trainer/trainer.py\u001B[0m in \u001B[0;36mcompute_loss\u001B[0;34m(self, model, inputs)\u001B[0m\n\u001B[1;32m 653\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mset_train_mode\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mmodel\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 654\u001B[0m \u001B[0mlabels\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0minputs\u001B[0m\u001B[0;34m[\u001B[0m\u001B[0;36m1\u001B[0m\u001B[0;34m]\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m--> 655\u001B[0;31m \u001B[0moutputs\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mmodel\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0minputs\u001B[0m\u001B[0;34m[\u001B[0m\u001B[0;36m0\u001B[0m\u001B[0;34m]\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m 656\u001B[0m \u001B[0mloss\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mloss\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0moutputs\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mlabels\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 657\u001B[0m \u001B[0;32mreturn\u001B[0m \u001B[0mloss\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n",
"\u001B[0;32m/opt/homebrew/anaconda3/envs/conda_towhee/lib/python3.9/site-packages/torch/nn/modules/module.py\u001B[0m in \u001B[0;36m_call_impl\u001B[0;34m(self, *input, **kwargs)\u001B[0m\n\u001B[1;32m 1049\u001B[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001B[1;32m 1050\u001B[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001B[0;32m-> 1051\u001B[0;31m \u001B[0;32mreturn\u001B[0m \u001B[0mforward_call\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m*\u001B[0m\u001B[0minput\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0;34m**\u001B[0m\u001B[0mkwargs\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m 1052\u001B[0m \u001B[0;31m# Do not call functions when jit is used\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 1053\u001B[0m \u001B[0mfull_backward_hooks\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mnon_full_backward_hooks\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0;34m[\u001B[0m\u001B[0;34m]\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0;34m[\u001B[0m\u001B[0;34m]\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n",
"\u001B[0;32m/opt/homebrew/anaconda3/envs/conda_towhee/lib/python3.9/site-packages/timm/models/resnet.py\u001B[0m in \u001B[0;36mforward\u001B[0;34m(self, x)\u001B[0m\n\u001B[1;32m 683\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 684\u001B[0m \u001B[0;32mdef\u001B[0m \u001B[0mforward\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mself\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mx\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m--> 685\u001B[0;31m \u001B[0mx\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mforward_features\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mx\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m 686\u001B[0m \u001B[0mx\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mglobal_pool\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mx\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 687\u001B[0m \u001B[0;32mif\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mdrop_rate\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n",
"\u001B[0;32m/opt/homebrew/anaconda3/envs/conda_towhee/lib/python3.9/site-packages/timm/models/resnet.py\u001B[0m in \u001B[0;36mforward_features\u001B[0;34m(self, x)\u001B[0m\n\u001B[1;32m 676\u001B[0m \u001B[0mx\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mmaxpool\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mx\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 677\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m--> 678\u001B[0;31m \u001B[0mx\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mlayer1\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mx\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m 679\u001B[0m \u001B[0mx\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mlayer2\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mx\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 680\u001B[0m \u001B[0mx\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mlayer3\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mx\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n",
"\u001B[0;32m/opt/homebrew/anaconda3/envs/conda_towhee/lib/python3.9/site-packages/torch/nn/modules/module.py\u001B[0m in \u001B[0;36m_call_impl\u001B[0;34m(self, *input, **kwargs)\u001B[0m\n\u001B[1;32m 1049\u001B[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001B[1;32m 1050\u001B[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001B[0;32m-> 1051\u001B[0;31m \u001B[0;32mreturn\u001B[0m \u001B[0mforward_call\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m*\u001B[0m\u001B[0minput\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0;34m**\u001B[0m\u001B[0mkwargs\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m 1052\u001B[0m \u001B[0;31m# Do not call functions when jit is used\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 1053\u001B[0m \u001B[0mfull_backward_hooks\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mnon_full_backward_hooks\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0;34m[\u001B[0m\u001B[0;34m]\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0;34m[\u001B[0m\u001B[0;34m]\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n",
"\u001B[0;32m/opt/homebrew/anaconda3/envs/conda_towhee/lib/python3.9/site-packages/torch/nn/modules/container.py\u001B[0m in \u001B[0;36mforward\u001B[0;34m(self, input)\u001B[0m\n\u001B[1;32m 137\u001B[0m \u001B[0;32mdef\u001B[0m \u001B[0mforward\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mself\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0minput\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 138\u001B[0m \u001B[0;32mfor\u001B[0m \u001B[0mmodule\u001B[0m \u001B[0;32min\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m--> 139\u001B[0;31m \u001B[0minput\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mmodule\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0minput\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m 140\u001B[0m \u001B[0;32mreturn\u001B[0m \u001B[0minput\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 141\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n",
"\u001B[0;32m/opt/homebrew/anaconda3/envs/conda_towhee/lib/python3.9/site-packages/torch/nn/modules/module.py\u001B[0m in \u001B[0;36m_call_impl\u001B[0;34m(self, *input, **kwargs)\u001B[0m\n\u001B[1;32m 1049\u001B[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001B[1;32m 1050\u001B[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001B[0;32m-> 1051\u001B[0;31m \u001B[0;32mreturn\u001B[0m \u001B[0mforward_call\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m*\u001B[0m\u001B[0minput\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0;34m**\u001B[0m\u001B[0mkwargs\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m 1052\u001B[0m \u001B[0;31m# Do not call functions when jit is used\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 1053\u001B[0m \u001B[0mfull_backward_hooks\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mnon_full_backward_hooks\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0;34m[\u001B[0m\u001B[0;34m]\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0;34m[\u001B[0m\u001B[0;34m]\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n",
"\u001B[0;32m/opt/homebrew/anaconda3/envs/conda_towhee/lib/python3.9/site-packages/timm/models/resnet.py\u001B[0m in \u001B[0;36mforward\u001B[0;34m(self, x)\u001B[0m\n\u001B[1;32m 339\u001B[0m \u001B[0mx\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0maa\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mx\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 340\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m--> 341\u001B[0;31m \u001B[0mx\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mconv2\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mx\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m 342\u001B[0m \u001B[0mx\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mbn2\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mx\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 343\u001B[0m \u001B[0;32mif\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mdrop_block\u001B[0m \u001B[0;32mis\u001B[0m \u001B[0;32mnot\u001B[0m \u001B[0;32mNone\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n",
"\u001B[0;32m/opt/homebrew/anaconda3/envs/conda_towhee/lib/python3.9/site-packages/torch/nn/modules/module.py\u001B[0m in \u001B[0;36m_call_impl\u001B[0;34m(self, *input, **kwargs)\u001B[0m\n\u001B[1;32m 1049\u001B[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001B[1;32m 1050\u001B[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001B[0;32m-> 1051\u001B[0;31m \u001B[0;32mreturn\u001B[0m \u001B[0mforward_call\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m*\u001B[0m\u001B[0minput\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0;34m**\u001B[0m\u001B[0mkwargs\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m 1052\u001B[0m \u001B[0;31m# Do not call functions when jit is used\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 1053\u001B[0m \u001B[0mfull_backward_hooks\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mnon_full_backward_hooks\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0;34m[\u001B[0m\u001B[0;34m]\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0;34m[\u001B[0m\u001B[0;34m]\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n",
"\u001B[0;32m/opt/homebrew/anaconda3/envs/conda_towhee/lib/python3.9/site-packages/torch/nn/modules/conv.py\u001B[0m in \u001B[0;36mforward\u001B[0;34m(self, input)\u001B[0m\n\u001B[1;32m 441\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 442\u001B[0m \u001B[0;32mdef\u001B[0m \u001B[0mforward\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mself\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0minput\u001B[0m\u001B[0;34m:\u001B[0m \u001B[0mTensor\u001B[0m\u001B[0;34m)\u001B[0m \u001B[0;34m->\u001B[0m \u001B[0mTensor\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m--> 443\u001B[0;31m \u001B[0;32mreturn\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0m_conv_forward\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0minput\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mweight\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mbias\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m 444\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 445\u001B[0m \u001B[0;32mclass\u001B[0m \u001B[0mConv3d\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0m_ConvNd\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n",
"\u001B[0;32m/opt/homebrew/anaconda3/envs/conda_towhee/lib/python3.9/site-packages/torch/nn/modules/conv.py\u001B[0m in \u001B[0;36m_conv_forward\u001B[0;34m(self, input, weight, bias)\u001B[0m\n\u001B[1;32m 437\u001B[0m \u001B[0mweight\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mbias\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mstride\u001B[0m\u001B[0;34m,\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 438\u001B[0m _pair(0), self.dilation, self.groups)\n\u001B[0;32m--> 439\u001B[0;31m return F.conv2d(input, weight, bias, self.stride,\n\u001B[0m\u001B[1;32m 440\u001B[0m self.padding, self.dilation, self.groups)\n\u001B[1;32m 441\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n",
"\u001B[0;31mKeyboardInterrupt\u001B[0m: "
]
}
],
"source": [
"# start to train mnist, it will take about 30-100 minutes on a cpu machine.\n",
"# if you train on a gpu machine, it will be much faster.\n",
"op.train(training_config, train_dataset=train_data, eval_dataset=eval_data)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"### If you observe that the loss is decreasing and the metric is increasing, it means you are training the model correctly."
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
}

193
examples/3_read_configs_from_yaml.ipynb

@ -0,0 +1,193 @@
{
"cells": [
{
"cell_type": "markdown",
"source": [
"# Read the configs from a yaml file."
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from resnet_image_embedding import ResnetImageEmbedding\n",
"from torchvision import transforms\n",
"from towhee import dataset\n",
"\n",
"# build a resnet op:\n",
"op = ResnetImageEmbedding('resnet18', num_classes=10)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"outputs": [],
"source": [
"from towhee.trainer.training_config import dump_default_yaml\n",
"\n",
"# If you want to see the default setting yaml, run dump_default_yaml()\n",
"dump_default_yaml('default_setting.yaml')"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"##### Then you can open `default_setting.yaml` to observe the yaml structure.\n",
"##### Change `batch_size` to 5, `epoch_num` to 3, `tensorboard` to `null`, `output_dir` to `my_output`, `print_steps` to 1, and save it as `my_setting.yaml`"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 3,
"outputs": [
{
"data": {
"text/plain": "TrainingConfig(output_dir='my_output', overwrite_output_dir=True, eval_strategy='epoch', eval_steps=None, batch_size=5, val_batch_size=-1, seed=42, epoch_num=3, dataloader_pin_memory=True, dataloader_drop_last=True, dataloader_num_workers=-1, lr=5e-05, metric='Accuracy', print_steps=1, load_best_model_at_end=False, early_stopping={'mode': 'max', 'monitor': 'eval_epoch_metric', 'patience': 4}, model_checkpoint={'every_n_epoch': 1}, tensorboard=None, loss='CrossEntropyLoss', optimizer='Adam', lr_scheduler_type='linear', warmup_ratio=0.0, warmup_steps=0, device_str=None, n_gpu=-1, sync_bn=False, freeze_bn=False)"
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from towhee.trainer.training_config import TrainingConfig\n",
"\n",
"# now, read from your custom yaml.\n",
"training_config = TrainingConfig()\n",
"training_config.load_from_yaml('my_setting.yaml')\n",
"training_config\n"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 4,
"outputs": [],
"source": [
"# prepare the fake dataset\n",
"fake_transform = transforms.Compose([transforms.ToTensor()])\n",
"train_data = dataset('fake', size=20, transform=fake_transform)\n",
"eval_data = dataset('fake', size=10, transform=fake_transform)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 5,
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2022-03-02 16:00:26,226 - 8666785280 - trainer.py-trainer:390 - WARNING: TrainingConfig(output_dir='my_output', overwrite_output_dir=True, eval_strategy='epoch', eval_steps=None, batch_size=5, val_batch_size=-1, seed=42, epoch_num=3, dataloader_pin_memory=True, dataloader_drop_last=True, dataloader_num_workers=-1, lr=5e-05, metric='Accuracy', print_steps=1, load_best_model_at_end=False, early_stopping={'mode': 'max', 'monitor': 'eval_epoch_metric', 'patience': 4}, model_checkpoint={'every_n_epoch': 1}, tensorboard=None, loss='CrossEntropyLoss', optimizer='Adam', lr_scheduler_type='linear', warmup_ratio=0.0, warmup_steps=0, device_str=None, n_gpu=-1, sync_bn=False, freeze_bn=False)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch=1/3, global_step=1, epoch_loss=2.5702719688415527, epoch_metric=0.0\n",
"epoch=1/3, global_step=2, epoch_loss=2.572024345397949, epoch_metric=0.0\n",
"epoch=1/3, global_step=3, epoch_loss=2.558194160461426, epoch_metric=0.0\n",
"epoch=1/3, global_step=4, epoch_loss=2.558873176574707, epoch_metric=0.15000000596046448\n",
"epoch=1/3, eval_global_step=0, eval_epoch_loss=2.370976686477661, eval_epoch_metric=0.20000000298023224\n",
"epoch=1/3, eval_global_step=1, eval_epoch_loss=2.2873291969299316, eval_epoch_metric=0.20000000298023224\n",
"epoch=2/3, global_step=5, epoch_loss=1.3134113550186157, epoch_metric=0.20000000298023224\n",
"epoch=2/3, global_step=6, epoch_loss=1.3073358535766602, epoch_metric=0.10000000149011612\n",
"epoch=2/3, global_step=7, epoch_loss=1.41914701461792, epoch_metric=0.13333334028720856\n",
"epoch=2/3, global_step=8, epoch_loss=1.3628838062286377, epoch_metric=0.15000000596046448\n",
"epoch=2/3, eval_global_step=2, eval_epoch_loss=1.3158948421478271, eval_epoch_metric=0.20000000298023224\n",
"epoch=2/3, eval_global_step=3, eval_epoch_loss=1.3246530294418335, eval_epoch_metric=0.20000000298023224\n",
"epoch=3/3, global_step=9, epoch_loss=1.4589173793792725, epoch_metric=0.0\n",
"epoch=3/3, global_step=10, epoch_loss=1.4343616962432861, epoch_metric=0.0\n",
"epoch=3/3, global_step=11, epoch_loss=1.3701648712158203, epoch_metric=0.06666667014360428\n",
"epoch=3/3, global_step=12, epoch_loss=1.1501117944717407, epoch_metric=0.10000000149011612\n",
"epoch=3/3, eval_global_step=4, eval_epoch_loss=1.1129425764083862, eval_epoch_metric=0.0\n",
"epoch=3/3, eval_global_step=5, eval_epoch_loss=1.1257113218307495, eval_epoch_metric=0.0\n"
]
}
],
"source": [
"# start training,\n",
"op.train(training_config, train_dataset=train_data, eval_dataset=eval_data)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"### Because you have set `print_steps` to 1, you will not see the progress bar; instead, you will see each batch step's result printed on the screen. You can check whether the other configs work correctly."
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
}

37
examples/my_setting.yaml

@ -0,0 +1,37 @@
callback:
early_stopping:
mode: max
monitor: eval_epoch_metric
patience: 4
model_checkpoint:
every_n_epoch: 1
tensorboard: null
device:
device_str: null
n_gpu: -1
sync_bn: false
learning:
loss: CrossEntropyLoss
lr: 5.0e-05
lr_scheduler_type: linear
optimizer: Adam
warmup_ratio: 0.0
warmup_steps: 0
logging:
print_steps: 1
metrics:
metric: Accuracy
train:
batch_size: 5
dataloader_drop_last: true
dataloader_num_workers: -1
dataloader_pin_memory: true
epoch_num: 3
eval_steps: null
eval_strategy: epoch
freeze_bn: false
load_best_model_at_end: false
output_dir: my_output
overwrite_output_dir: true
seed: 42
val_batch_size: -1

17
resnet_training_yaml.yaml

@ -2,9 +2,9 @@ callback:
early_stopping: early_stopping:
mode: max mode: max
monitor: eval_epoch_metric monitor: eval_epoch_metric
patience: 2
patience: 4
model_checkpoint: model_checkpoint:
every_n_epoch: 2
every_n_epoch: 1
tensorboard: tensorboard:
comment: '' comment: ''
log_dir: null log_dir: null
@ -20,23 +20,20 @@ learning:
warmup_ratio: 0.0 warmup_ratio: 0.0
warmup_steps: 0 warmup_steps: 0
logging: logging:
logging_dir: null
logging_strategy: steps
print_steps: null print_steps: null
metrics: metrics:
metric: Accuracy metric: Accuracy
train: train:
batch_size: 16
dataloader_drop_last: false
dataloader_num_workers: 0
batch_size: 8
dataloader_drop_last: true
dataloader_num_workers: -1
dataloader_pin_memory: true dataloader_pin_memory: true
epoch_num: 3
epoch_num: 2
eval_steps: null eval_steps: null
eval_strategy: epoch eval_strategy: epoch
freeze_bn: true
freeze_bn: false
load_best_model_at_end: false load_best_model_at_end: false
output_dir: ./output_dir output_dir: ./output_dir
overwrite_output_dir: true overwrite_output_dir: true
resume_from_checkpoint: null
seed: 42 seed: 42
val_batch_size: -1 val_batch_size: -1

52
test.py

@ -2,7 +2,7 @@ import numpy as np
from torch.optim import AdamW from torch.optim import AdamW
from torchvision import transforms from torchvision import transforms
from torchvision.transforms import RandomResizedCrop, Lambda from torchvision.transforms import RandomResizedCrop, Lambda
from towhee.data.dataset.dataset import dataset
from towhee import dataset
from towhee.trainer.modelcard import ModelCard from towhee.trainer.modelcard import ModelCard
from towhee.trainer.training_config import TrainingConfig from towhee.trainer.training_config import TrainingConfig
@ -31,8 +31,8 @@ if __name__ == '__main__':
array_size = np.array(img).shape array_size = np.array(img).shape
towhee_img = Image(img_bytes, img_width, img_height, img_channel, img_mode, img_array) towhee_img = Image(img_bytes, img_width, img_height, img_channel, img_mode, img_array)
op = ResnetImageEmbedding('resnet50', num_classes=10)
op = ResnetImageEmbedding('resnet18', num_classes=10)
# op.model_card = ModelCard(model_details="resnet test modelcard", training_data="use resnet test data")
old_out = op(towhee_img) old_out = op(towhee_img)
# print(old_out.feature_vector[0][:10]) # print(old_out.feature_vector[0][:10])
# print(old_out.feature_vector[:10]) # print(old_out.feature_vector[:10])
@ -46,26 +46,34 @@ if __name__ == '__main__':
# training_config.overwrite_output_dir=True # training_config.overwrite_output_dir=True
# training_config.epoch_num=3 # training_config.epoch_num=3
# training_config.batch_size=256 # training_config.batch_size=256
# training_config.device_str='cpu'
training_config.device_str='cpu'
training_config.tensorboard = None
training_config.batch_size = 2
training_config.epoch_num = 2
# training_config.n_gpu=-1 # training_config.n_gpu=-1
# training_config.save_to_yaml(yaml_path) # training_config.save_to_yaml(yaml_path)
# mnist_transform = transforms.Compose([transforms.ToTensor(),
# RandomResizedCrop(224),
# Lambda(lambda x: x.repeat(3, 1, 1)),
# transforms.Normalize(mean=[0.5], std=[0.5])])
# train_data = dataset('mnist', transform=mnist_transform, download=True, root='data', train=True)
# eval_data = dataset('mnist', transform=mnist_transform, download=True, root='data', train=False)
mnist_transform = transforms.Compose([transforms.ToTensor(),
Lambda(lambda x: x.repeat(3, 1, 1)),
transforms.Normalize(mean=[0.1307,0.1307,0.1307], std=[0.3081,0.3081,0.3081])])
train_data = dataset('mnist', transform=mnist_transform, download=True, root='data', train=True)
eval_data = dataset('mnist', transform=mnist_transform, download=True, root='data', train=False)
# training_config.output_dir = 'mnist_output' # training_config.output_dir = 'mnist_output'
# fake_transform = transforms.Compose([transforms.ToTensor(),
# RandomResizedCrop(224),])
# train_data = dataset('fake', size=100, transform=fake_transform)
# eval_data = dataset('fake', size=10, transform=fake_transform)
# training_config.output_dir = 'mnist_0228_5'
# op.model_card = ModelCard(datasets='mnist dataset') # op.model_card = ModelCard(datasets='mnist dataset')
fake_transform = transforms.Compose([transforms.ToTensor(),
RandomResizedCrop(224)])
train_data = dataset('fake', size=20, transform=fake_transform)
eval_data = dataset('fake', size=10, transform=fake_transform)
training_config.output_dir = 'fake_output'
op.model_card = ModelCard(datasets='fake dataset')
# fake_transform = transforms.Compose([transforms.ToTensor()])
# # RandomResizedCrop(224)])
# train_data = dataset('fake', size=20, transform=fake_transform)
# eval_data = dataset('fake', size=10, transform=fake_transform)
# training_config.output_dir = 'fake_output'
# op.model_card = ModelCard(datasets='fake dataset')
# op.trainer
# trainer = op.setup_trainer() # trainer = op.setup_trainer()
# print(op.get_model()) # print(op.get_model())
# my_optimimzer = AdamW(op.get_model().parameters(), lr=0.002, betas=(0.91, 0.98), eps=1e-08, weight_decay=0.01, amsgrad=False) # my_optimimzer = AdamW(op.get_model().parameters(), lr=0.002, betas=(0.91, 0.98), eps=1e-08, weight_decay=0.01, amsgrad=False)
@ -84,15 +92,13 @@ if __name__ == '__main__':
# op.trainer.set_optimizer() # op.trainer.set_optimizer()
# trainer = op.setup_trainer(training_config, train_dataset=train_data, eval_dataset=eval_data) # trainer = op.setup_trainer(training_config, train_dataset=train_data, eval_dataset=eval_data)
freezer = LayerFreezer(op.get_model())
freezer.set_slice(-1)
# freezer = LayerFreezer(op.get_model())
# freezer.set_slice(-1)
op.train(training_config, train_dataset=train_data, eval_dataset=eval_data) op.train(training_config, train_dataset=train_data, eval_dataset=eval_data)
# op.train(training_config, train_dataset=train_data, eval_dataset=eval_data,
# resume_checkpoint_path=training_config.output_dir + '/epoch_2')
# op.train(training_config, train_dataset=train_data, eval_dataset=eval_data, resume_checkpoint_path=training_config.output_dir + '/epoch_2')
# op.save('./test_save') # op.save('./test_save')
# op.load('./test_save')\ # op.load('./test_save')\
new_out = op(towhee_img)
assert (new_out.feature_vector == old_out.feature_vector).all()
# new_out = op(towhee_img)
# assert (new_out.feature_vector==old_out.feature_vector).all()

Loading…
Cancel
Save