towhee / resnet-image-embedding
copied
7 changed files with 604 additions and 33 deletions
@@ -0,0 +1,39 @@
callback:
  early_stopping:
    mode: max
    monitor: eval_epoch_metric
    patience: 4
  model_checkpoint:
    every_n_epoch: 1
  tensorboard:
    comment: ''
    log_dir: null
device:
  device_str: null
  n_gpu: -1
  sync_bn: false
learning:
  loss: CrossEntropyLoss
  lr: 5.0e-05
  lr_scheduler_type: linear
  optimizer: Adam
  warmup_ratio: 0.0
  warmup_steps: 0
logging:
  print_steps: null
metrics:
  metric: Accuracy
train:
  batch_size: 8
  dataloader_drop_last: true
  dataloader_num_workers: -1
  dataloader_pin_memory: true
  epoch_num: 2
  eval_steps: null
  eval_strategy: epoch
  freeze_bn: false
  load_best_model_at_end: false
  output_dir: ./output_dir
  overwrite_output_dir: true
  seed: 42
  val_batch_size: -1
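
The YAML above is the default training configuration; its groups (callback, device, learning, logging, metrics, train) flatten into the TrainingConfig attributes used in the notebooks below. As a minimal sketch of loading it and overriding a couple of fields in code (the filename 'default_training_config.yaml' is a placeholder, since this view does not show the file's real name; load_from_yaml is the same helper used later in this changeset):

from towhee.trainer.training_config import TrainingConfig

# placeholder path standing in for the YAML file above
config = TrainingConfig()
config.load_from_yaml('default_training_config.yaml')
config.batch_size = 2           # overrides train.batch_size: 8
config.output_dir = './output_dir'
print(config)                   # prints the flattened TrainingConfig repr
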
@@ -0,0 +1,136 @@
{
"cells": [
{
"cell_type": "markdown",
"source": [
"# Quick start: train an operator on a toy fake dataset."
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from resnet_image_embedding import ResnetImageEmbedding\n",
"from towhee.trainer.training_config import TrainingConfig\n",
"from torchvision import transforms\n",
"from towhee import dataset\n",
"\n",
"# build a resnet op:\n",
"op = ResnetImageEmbedding('resnet18', num_classes=10)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"outputs": [],
"source": [
"# build a training config:\n",
"training_config = TrainingConfig()\n",
"training_config.batch_size = 2\n",
"training_config.epoch_num = 2\n",
"training_config.tensorboard = None\n",
"training_config.output_dir = 'quick_start_output'"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 11,
"outputs": [],
"source": [
"# prepare the dataset\n",
"fake_transform = transforms.Compose([transforms.ToTensor()])\n",
"train_data = dataset('fake', size=20, transform=fake_transform)\n",
"eval_data = dataset('fake', size=10, transform=fake_transform)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 9,
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2022-03-02 15:09:06,334 - 8665081344 - trainer.py-trainer:390 - WARNING: TrainingConfig(output_dir='quick_start_output', overwrite_output_dir=True, eval_strategy='epoch', eval_steps=None, batch_size=2, val_batch_size=-1, seed=42, epoch_num=2, dataloader_pin_memory=True, dataloader_drop_last=True, dataloader_num_workers=-1, lr=5e-05, metric='Accuracy', print_steps=None, load_best_model_at_end=False, early_stopping={'monitor': 'eval_epoch_metric', 'patience': 4, 'mode': 'max'}, model_checkpoint={'every_n_epoch': 1}, tensorboard=None, loss='CrossEntropyLoss', optimizer='Adam', lr_scheduler_type='linear', warmup_ratio=0.0, warmup_steps=0, device_str=None, n_gpu=-1, sync_bn=False, freeze_bn=False)\n",
"[epoch 1/2] loss=2.402, metric=0.0, eval_loss=2.254, eval_metric=0.0: 100%|██████████| 10/10 [00:32<00:00, 3.25s/step]\n",
"[epoch 2/2] loss=1.88, metric=0.1, eval_loss=1.855, eval_metric=0.1: 100%|██████████| 10/10 [00:22<00:00, 1.14step/s] "
]
}
],
"source": [
"# start training; it will take about 2 minutes on a CPU machine.\n",
"op.train(training_config, train_dataset=train_data, eval_dataset=eval_data)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"### If the progress bars for both epochs run to completion and a `quick_start_output` folder is created, the training succeeded."
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
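
For reference, the quick-start notebook above condenses into a short standalone script; this is only a sketch that repeats the cells in order, with nothing changed except the comment wording:

from torchvision import transforms

from resnet_image_embedding import ResnetImageEmbedding
from towhee import dataset
from towhee.trainer.training_config import TrainingConfig

# build a resnet op with a 10-class head
op = ResnetImageEmbedding('resnet18', num_classes=10)

# two epochs with batch size 2, no tensorboard logging
training_config = TrainingConfig()
training_config.batch_size = 2
training_config.epoch_num = 2
training_config.tensorboard = None
training_config.output_dir = 'quick_start_output'

# toy fake dataset: 20 training images, 10 evaluation images
fake_transform = transforms.Compose([transforms.ToTensor()])
train_data = dataset('fake', size=20, transform=fake_transform)
eval_data = dataset('fake', size=10, transform=fake_transform)

# takes roughly two minutes on a CPU machine
op.train(training_config, train_dataset=train_data, eval_dataset=eval_data)
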
@@ -0,0 +1,163 @@
{
"cells": [
{
"cell_type": "markdown",
"source": [
"# Train the resnet18 operator on the MNIST dataset."
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from resnet_image_embedding import ResnetImageEmbedding\n",
"from towhee.trainer.training_config import TrainingConfig\n",
"from torchvision import transforms\n",
"from towhee import dataset\n",
"from torchvision.transforms import Lambda\n",
"\n",
"# build a resnet op with a 10-class output, because MNIST has 10 classes:\n",
"op = ResnetImageEmbedding('resnet18', num_classes=10)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"outputs": [],
"source": [
"# build a training config:\n",
"training_config = TrainingConfig()\n",
"training_config.batch_size = 64\n",
"training_config.epoch_num = 5\n",
"training_config.tensorboard = None\n",
"training_config.output_dir = 'mnist_output'\n",
"training_config.dataloader_num_workers = 0"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 3,
"outputs": [],
"source": [
"# prepare the MNIST data\n",
"mnist_transform = transforms.Compose([transforms.ToTensor(),\n",
" Lambda(lambda x: x.repeat(3, 1, 1)),\n",
" transforms.Normalize(mean=[0.1307,0.1307,0.1307], std=[0.3081,0.3081,0.3081])])\n",
"train_data = dataset('mnist', transform=mnist_transform, download=True, root='data', train=True)\n",
"eval_data = dataset('mnist', transform=mnist_transform, download=True, root='data', train=False)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 4,
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2022-03-02 15:24:19,702 - 8669324800 - trainer.py-trainer:390 - WARNING: TrainingConfig(output_dir='mnist_output', overwrite_output_dir=True, eval_strategy='epoch', eval_steps=None, batch_size=64, val_batch_size=-1, seed=42, epoch_num=5, dataloader_pin_memory=True, dataloader_drop_last=True, dataloader_num_workers=0, lr=5e-05, metric='Accuracy', print_steps=None, load_best_model_at_end=False, early_stopping={'monitor': 'eval_epoch_metric', 'patience': 4, 'mode': 'max'}, model_checkpoint={'every_n_epoch': 1}, tensorboard=None, loss='CrossEntropyLoss', optimizer='Adam', lr_scheduler_type='linear', warmup_ratio=0.0, warmup_steps=0, device_str=None, n_gpu=-1, sync_bn=False, freeze_bn=False)\n",
"[epoch 1/5] loss=0.388, metric=0.884: 41%|████ | 383/937 [07:38<11:00, 1.19s/step]"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001B[0;31m---------------------------------------------------------------------------\u001B[0m", |
"\u001B[0;31mKeyboardInterrupt\u001B[0m Traceback (most recent call last)", |
"\u001B[0;32m/var/folders/wn/4wflyq8x0f9bhkwryvss30880000gn/T/ipykernel_5732/1544844912.py\u001B[0m in \u001B[0;36m<module>\u001B[0;34m\u001B[0m\n\u001B[0;32m----> 1\u001B[0;31m \u001B[0mop\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mtrain\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mtraining_config\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mtrain_dataset\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0mtrain_data\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0meval_dataset\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0meval_data\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m", |
"\u001B[0;32m/opt/homebrew/anaconda3/envs/conda_towhee/lib/python3.9/site-packages/towhee-0.4.1.dev110-py3.9.egg/towhee/operator/base.py\u001B[0m in \u001B[0;36mtrain\u001B[0;34m(self, training_config, train_dataset, eval_dataset, resume_checkpoint_path, **kwargs)\u001B[0m\n\u001B[1;32m 136\u001B[0m \"\"\"\n\u001B[1;32m 137\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0msetup_trainer\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mtraining_config\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mtrain_dataset\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0meval_dataset\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0;34m**\u001B[0m\u001B[0mkwargs\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m--> 138\u001B[0;31m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mtrainer\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mtrain\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mresume_checkpoint_path\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m 139\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 140\u001B[0m \u001B[0;34m@\u001B[0m\u001B[0mproperty\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n", |
"\u001B[0;32m/opt/homebrew/anaconda3/envs/conda_towhee/lib/python3.9/site-packages/towhee-0.4.1.dev110-py3.9.egg/towhee/trainer/trainer.py\u001B[0m in \u001B[0;36mtrain\u001B[0;34m(self, resume_checkpoint_path)\u001B[0m\n\u001B[1;32m 281\u001B[0m \u001B[0;32melse\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 282\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mdistributed\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0;32mFalse\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m--> 283\u001B[0;31m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mrun_train\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mresume_checkpoint_path\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m 284\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 285\u001B[0m \u001B[0;32mdef\u001B[0m \u001B[0m_spawn_train_process\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mself\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mresume_checkpoint_path\u001B[0m\u001B[0;34m:\u001B[0m \u001B[0mOptional\u001B[0m\u001B[0;34m[\u001B[0m\u001B[0mstr\u001B[0m\u001B[0;34m]\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n", |
"\u001B[0;32m/opt/homebrew/anaconda3/envs/conda_towhee/lib/python3.9/site-packages/towhee-0.4.1.dev110-py3.9.egg/towhee/trainer/trainer.py\u001B[0m in \u001B[0;36mrun_train\u001B[0;34m(self, resume_checkpoint_path, rank, world_size)\u001B[0m\n\u001B[1;32m 407\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mcallbacks\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mon_train_batch_begin\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0minputs\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mlogs\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 408\u001B[0m \u001B[0minputs\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mprepare_inputs\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0minputs\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m--> 409\u001B[0;31m \u001B[0mstep_logs\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mtrain_step\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mmodel\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0minputs\u001B[0m\u001B[0;34m)\u001B[0m \u001B[0;31m# , train_dataloader)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m 410\u001B[0m \u001B[0mlogs\u001B[0m\u001B[0;34m[\u001B[0m\u001B[0;34m\"lr\"\u001B[0m\u001B[0;34m]\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mlr_scheduler\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mget_lr\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m[\u001B[0m\u001B[0;36m0\u001B[0m\u001B[0;34m]\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 411\u001B[0m \u001B[0mlogs\u001B[0m\u001B[0;34m[\u001B[0m\u001B[0;34m\"global_step\"\u001B[0m\u001B[0;34m]\u001B[0m \u001B[0;34m+=\u001B[0m \u001B[0;36m1\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n", |
"\u001B[0;32m/opt/homebrew/anaconda3/envs/conda_towhee/lib/python3.9/site-packages/towhee-0.4.1.dev110-py3.9.egg/towhee/trainer/trainer.py\u001B[0m in \u001B[0;36mtrain_step\u001B[0;34m(self, model, inputs)\u001B[0m\n\u001B[1;32m 615\u001B[0m \u001B[0mStep\u001B[0m \u001B[0mlogs\u001B[0m \u001B[0mwhich\u001B[0m \u001B[0mcontains\u001B[0m \u001B[0mthe\u001B[0m \u001B[0mstep\u001B[0m \u001B[0mloss\u001B[0m \u001B[0;32mand\u001B[0m \u001B[0mmetric\u001B[0m \u001B[0minfos\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 616\u001B[0m \"\"\"\n\u001B[0;32m--> 617\u001B[0;31m \u001B[0mstep_loss\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mcompute_loss\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mmodel\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0minputs\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m 618\u001B[0m \u001B[0mstep_loss\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mreduce_value\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mstep_loss\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0maverage\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0;32mTrue\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 619\u001B[0m \u001B[0mstep_loss\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mbackward\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n", |
"\u001B[0;32m/opt/homebrew/anaconda3/envs/conda_towhee/lib/python3.9/site-packages/towhee-0.4.1.dev110-py3.9.egg/towhee/trainer/trainer.py\u001B[0m in \u001B[0;36mcompute_loss\u001B[0;34m(self, model, inputs)\u001B[0m\n\u001B[1;32m 653\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mset_train_mode\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mmodel\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 654\u001B[0m \u001B[0mlabels\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0minputs\u001B[0m\u001B[0;34m[\u001B[0m\u001B[0;36m1\u001B[0m\u001B[0;34m]\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m--> 655\u001B[0;31m \u001B[0moutputs\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mmodel\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0minputs\u001B[0m\u001B[0;34m[\u001B[0m\u001B[0;36m0\u001B[0m\u001B[0;34m]\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m 656\u001B[0m \u001B[0mloss\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mloss\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0moutputs\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mlabels\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 657\u001B[0m \u001B[0;32mreturn\u001B[0m \u001B[0mloss\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n", |
"\u001B[0;32m/opt/homebrew/anaconda3/envs/conda_towhee/lib/python3.9/site-packages/torch/nn/modules/module.py\u001B[0m in \u001B[0;36m_call_impl\u001B[0;34m(self, *input, **kwargs)\u001B[0m\n\u001B[1;32m 1049\u001B[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001B[1;32m 1050\u001B[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001B[0;32m-> 1051\u001B[0;31m \u001B[0;32mreturn\u001B[0m \u001B[0mforward_call\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m*\u001B[0m\u001B[0minput\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0;34m**\u001B[0m\u001B[0mkwargs\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m 1052\u001B[0m \u001B[0;31m# Do not call functions when jit is used\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 1053\u001B[0m \u001B[0mfull_backward_hooks\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mnon_full_backward_hooks\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0;34m[\u001B[0m\u001B[0;34m]\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0;34m[\u001B[0m\u001B[0;34m]\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n", |
"\u001B[0;32m/opt/homebrew/anaconda3/envs/conda_towhee/lib/python3.9/site-packages/timm/models/resnet.py\u001B[0m in \u001B[0;36mforward\u001B[0;34m(self, x)\u001B[0m\n\u001B[1;32m 683\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 684\u001B[0m \u001B[0;32mdef\u001B[0m \u001B[0mforward\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mself\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mx\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m--> 685\u001B[0;31m \u001B[0mx\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mforward_features\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mx\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m 686\u001B[0m \u001B[0mx\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mglobal_pool\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mx\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 687\u001B[0m \u001B[0;32mif\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mdrop_rate\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n", |
"\u001B[0;32m/opt/homebrew/anaconda3/envs/conda_towhee/lib/python3.9/site-packages/timm/models/resnet.py\u001B[0m in \u001B[0;36mforward_features\u001B[0;34m(self, x)\u001B[0m\n\u001B[1;32m 676\u001B[0m \u001B[0mx\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mmaxpool\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mx\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 677\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m--> 678\u001B[0;31m \u001B[0mx\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mlayer1\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mx\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m 679\u001B[0m \u001B[0mx\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mlayer2\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mx\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 680\u001B[0m \u001B[0mx\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mlayer3\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mx\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n", |
"\u001B[0;32m/opt/homebrew/anaconda3/envs/conda_towhee/lib/python3.9/site-packages/torch/nn/modules/module.py\u001B[0m in \u001B[0;36m_call_impl\u001B[0;34m(self, *input, **kwargs)\u001B[0m\n\u001B[1;32m 1049\u001B[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001B[1;32m 1050\u001B[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001B[0;32m-> 1051\u001B[0;31m \u001B[0;32mreturn\u001B[0m \u001B[0mforward_call\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m*\u001B[0m\u001B[0minput\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0;34m**\u001B[0m\u001B[0mkwargs\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m 1052\u001B[0m \u001B[0;31m# Do not call functions when jit is used\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 1053\u001B[0m \u001B[0mfull_backward_hooks\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mnon_full_backward_hooks\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0;34m[\u001B[0m\u001B[0;34m]\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0;34m[\u001B[0m\u001B[0;34m]\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n", |
"\u001B[0;32m/opt/homebrew/anaconda3/envs/conda_towhee/lib/python3.9/site-packages/torch/nn/modules/container.py\u001B[0m in \u001B[0;36mforward\u001B[0;34m(self, input)\u001B[0m\n\u001B[1;32m 137\u001B[0m \u001B[0;32mdef\u001B[0m \u001B[0mforward\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mself\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0minput\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 138\u001B[0m \u001B[0;32mfor\u001B[0m \u001B[0mmodule\u001B[0m \u001B[0;32min\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m--> 139\u001B[0;31m \u001B[0minput\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mmodule\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0minput\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m 140\u001B[0m \u001B[0;32mreturn\u001B[0m \u001B[0minput\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 141\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n", |
"\u001B[0;32m/opt/homebrew/anaconda3/envs/conda_towhee/lib/python3.9/site-packages/torch/nn/modules/module.py\u001B[0m in \u001B[0;36m_call_impl\u001B[0;34m(self, *input, **kwargs)\u001B[0m\n\u001B[1;32m 1049\u001B[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001B[1;32m 1050\u001B[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001B[0;32m-> 1051\u001B[0;31m \u001B[0;32mreturn\u001B[0m \u001B[0mforward_call\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m*\u001B[0m\u001B[0minput\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0;34m**\u001B[0m\u001B[0mkwargs\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m 1052\u001B[0m \u001B[0;31m# Do not call functions when jit is used\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 1053\u001B[0m \u001B[0mfull_backward_hooks\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mnon_full_backward_hooks\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0;34m[\u001B[0m\u001B[0;34m]\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0;34m[\u001B[0m\u001B[0;34m]\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n", |
"\u001B[0;32m/opt/homebrew/anaconda3/envs/conda_towhee/lib/python3.9/site-packages/timm/models/resnet.py\u001B[0m in \u001B[0;36mforward\u001B[0;34m(self, x)\u001B[0m\n\u001B[1;32m 339\u001B[0m \u001B[0mx\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0maa\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mx\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 340\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m--> 341\u001B[0;31m \u001B[0mx\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mconv2\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mx\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m 342\u001B[0m \u001B[0mx\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mbn2\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mx\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 343\u001B[0m \u001B[0;32mif\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mdrop_block\u001B[0m \u001B[0;32mis\u001B[0m \u001B[0;32mnot\u001B[0m \u001B[0;32mNone\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n", |
"\u001B[0;32m/opt/homebrew/anaconda3/envs/conda_towhee/lib/python3.9/site-packages/torch/nn/modules/module.py\u001B[0m in \u001B[0;36m_call_impl\u001B[0;34m(self, *input, **kwargs)\u001B[0m\n\u001B[1;32m 1049\u001B[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001B[1;32m 1050\u001B[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001B[0;32m-> 1051\u001B[0;31m \u001B[0;32mreturn\u001B[0m \u001B[0mforward_call\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m*\u001B[0m\u001B[0minput\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0;34m**\u001B[0m\u001B[0mkwargs\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m 1052\u001B[0m \u001B[0;31m# Do not call functions when jit is used\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 1053\u001B[0m \u001B[0mfull_backward_hooks\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mnon_full_backward_hooks\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0;34m[\u001B[0m\u001B[0;34m]\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0;34m[\u001B[0m\u001B[0;34m]\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n", |
"\u001B[0;32m/opt/homebrew/anaconda3/envs/conda_towhee/lib/python3.9/site-packages/torch/nn/modules/conv.py\u001B[0m in \u001B[0;36mforward\u001B[0;34m(self, input)\u001B[0m\n\u001B[1;32m 441\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 442\u001B[0m \u001B[0;32mdef\u001B[0m \u001B[0mforward\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mself\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0minput\u001B[0m\u001B[0;34m:\u001B[0m \u001B[0mTensor\u001B[0m\u001B[0;34m)\u001B[0m \u001B[0;34m->\u001B[0m \u001B[0mTensor\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m--> 443\u001B[0;31m \u001B[0;32mreturn\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0m_conv_forward\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0minput\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mweight\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mbias\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m 444\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 445\u001B[0m \u001B[0;32mclass\u001B[0m \u001B[0mConv3d\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0m_ConvNd\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n", |
"\u001B[0;32m/opt/homebrew/anaconda3/envs/conda_towhee/lib/python3.9/site-packages/torch/nn/modules/conv.py\u001B[0m in \u001B[0;36m_conv_forward\u001B[0;34m(self, input, weight, bias)\u001B[0m\n\u001B[1;32m 437\u001B[0m \u001B[0mweight\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mbias\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mstride\u001B[0m\u001B[0;34m,\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 438\u001B[0m _pair(0), self.dilation, self.groups)\n\u001B[0;32m--> 439\u001B[0;31m return F.conv2d(input, weight, bias, self.stride,\n\u001B[0m\u001B[1;32m 440\u001B[0m self.padding, self.dilation, self.groups)\n\u001B[1;32m 441\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n", |
"\u001B[0;31mKeyboardInterrupt\u001B[0m: " |
]
}
],
"source": [
"# start training on MNIST; it will take about 30-100 minutes on a CPU machine.\n",
"# training on a GPU machine will be much faster.\n",
"op.train(training_config, train_dataset=train_data, eval_dataset=eval_data)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"### If the loss is decreasing and the metric is increasing, the model is training correctly."
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
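
The MNIST run above was stopped with a KeyboardInterrupt partway through epoch 1. The train signature visible in the traceback (training_config, train_dataset, eval_dataset, resume_checkpoint_path, **kwargs) suggests an interrupted run can be resumed from a saved checkpoint; the sketch below is hypothetical in that it assumes the model_checkpoint callback (every_n_epoch: 1) has written a checkpoint under mnist_output and that this path is what resume_checkpoint_path expects:

from torchvision import transforms
from torchvision.transforms import Lambda

from resnet_image_embedding import ResnetImageEmbedding
from towhee import dataset
from towhee.trainer.training_config import TrainingConfig

# rebuild the same op, config, and data pipeline as in the notebook above
op = ResnetImageEmbedding('resnet18', num_classes=10)

training_config = TrainingConfig()
training_config.batch_size = 64
training_config.epoch_num = 5
training_config.tensorboard = None
training_config.output_dir = 'mnist_output'
training_config.dataloader_num_workers = 0

mnist_transform = transforms.Compose([transforms.ToTensor(),
                                      Lambda(lambda x: x.repeat(3, 1, 1)),
                                      transforms.Normalize(mean=[0.1307, 0.1307, 0.1307], std=[0.3081, 0.3081, 0.3081])])
train_data = dataset('mnist', transform=mnist_transform, download=True, root='data', train=True)
eval_data = dataset('mnist', transform=mnist_transform, download=True, root='data', train=False)

# hypothetical resume path; the argument name comes from the traceback above
op.train(training_config,
         train_dataset=train_data,
         eval_dataset=eval_data,
         resume_checkpoint_path='mnist_output')
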
@@ -0,0 +1,193 @@
{
"cells": [
{
"cell_type": "markdown",
"source": [
"# Read the configs from a YAML file."
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from resnet_image_embedding import ResnetImageEmbedding\n",
"from torchvision import transforms\n",
"from towhee import dataset\n",
"\n",
"# build a resnet op:\n",
"op = ResnetImageEmbedding('resnet18', num_classes=10)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"outputs": [],
"source": [
"from towhee.trainer.training_config import dump_default_yaml\n",
"\n",
"# If you want to see the default settings YAML, run dump_default_yaml()\n",
"dump_default_yaml('default_setting.yaml')"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"##### Then you can open `default_setting.yaml` to inspect the YAML structure.\n",
"##### Change `batch_size` to 5, `epoch_num` to 3, `tensorboard` to `null`, `output_dir` to `my_output`, and `print_steps` to 1, then save it as `my_setting.yaml`."
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 3,
"outputs": [
{
"data": {
"text/plain": "TrainingConfig(output_dir='my_output', overwrite_output_dir=True, eval_strategy='epoch', eval_steps=None, batch_size=5, val_batch_size=-1, seed=42, epoch_num=3, dataloader_pin_memory=True, dataloader_drop_last=True, dataloader_num_workers=-1, lr=5e-05, metric='Accuracy', print_steps=1, load_best_model_at_end=False, early_stopping={'mode': 'max', 'monitor': 'eval_epoch_metric', 'patience': 4}, model_checkpoint={'every_n_epoch': 1}, tensorboard=None, loss='CrossEntropyLoss', optimizer='Adam', lr_scheduler_type='linear', warmup_ratio=0.0, warmup_steps=0, device_str=None, n_gpu=-1, sync_bn=False, freeze_bn=False)"
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from towhee.trainer.training_config import TrainingConfig\n",
"\n",
"# now, read from your custom YAML file.\n",
"training_config = TrainingConfig()\n",
"training_config.load_from_yaml('my_setting.yaml')\n",
"training_config\n"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 4,
"outputs": [],
"source": [
"# prepare the fake dataset\n",
"fake_transform = transforms.Compose([transforms.ToTensor()])\n",
"train_data = dataset('fake', size=20, transform=fake_transform)\n",
"eval_data = dataset('fake', size=10, transform=fake_transform)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 5,
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2022-03-02 16:00:26,226 - 8666785280 - trainer.py-trainer:390 - WARNING: TrainingConfig(output_dir='my_output', overwrite_output_dir=True, eval_strategy='epoch', eval_steps=None, batch_size=5, val_batch_size=-1, seed=42, epoch_num=3, dataloader_pin_memory=True, dataloader_drop_last=True, dataloader_num_workers=-1, lr=5e-05, metric='Accuracy', print_steps=1, load_best_model_at_end=False, early_stopping={'mode': 'max', 'monitor': 'eval_epoch_metric', 'patience': 4}, model_checkpoint={'every_n_epoch': 1}, tensorboard=None, loss='CrossEntropyLoss', optimizer='Adam', lr_scheduler_type='linear', warmup_ratio=0.0, warmup_steps=0, device_str=None, n_gpu=-1, sync_bn=False, freeze_bn=False)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch=1/3, global_step=1, epoch_loss=2.5702719688415527, epoch_metric=0.0\n",
"epoch=1/3, global_step=2, epoch_loss=2.572024345397949, epoch_metric=0.0\n",
"epoch=1/3, global_step=3, epoch_loss=2.558194160461426, epoch_metric=0.0\n",
"epoch=1/3, global_step=4, epoch_loss=2.558873176574707, epoch_metric=0.15000000596046448\n",
"epoch=1/3, eval_global_step=0, eval_epoch_loss=2.370976686477661, eval_epoch_metric=0.20000000298023224\n",
"epoch=1/3, eval_global_step=1, eval_epoch_loss=2.2873291969299316, eval_epoch_metric=0.20000000298023224\n",
"epoch=2/3, global_step=5, epoch_loss=1.3134113550186157, epoch_metric=0.20000000298023224\n",
"epoch=2/3, global_step=6, epoch_loss=1.3073358535766602, epoch_metric=0.10000000149011612\n",
"epoch=2/3, global_step=7, epoch_loss=1.41914701461792, epoch_metric=0.13333334028720856\n",
"epoch=2/3, global_step=8, epoch_loss=1.3628838062286377, epoch_metric=0.15000000596046448\n",
"epoch=2/3, eval_global_step=2, eval_epoch_loss=1.3158948421478271, eval_epoch_metric=0.20000000298023224\n",
"epoch=2/3, eval_global_step=3, eval_epoch_loss=1.3246530294418335, eval_epoch_metric=0.20000000298023224\n",
"epoch=3/3, global_step=9, epoch_loss=1.4589173793792725, epoch_metric=0.0\n",
"epoch=3/3, global_step=10, epoch_loss=1.4343616962432861, epoch_metric=0.0\n",
"epoch=3/3, global_step=11, epoch_loss=1.3701648712158203, epoch_metric=0.06666667014360428\n",
"epoch=3/3, global_step=12, epoch_loss=1.1501117944717407, epoch_metric=0.10000000149011612\n",
"epoch=3/3, eval_global_step=4, eval_epoch_loss=1.1129425764083862, eval_epoch_metric=0.0\n",
"epoch=3/3, eval_global_step=5, eval_epoch_loss=1.1257113218307495, eval_epoch_metric=0.0\n"
]
}
],
"source": [
"# start training:\n",
"op.train(training_config, train_dataset=train_data, eval_dataset=eval_data)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"### Because you set `print_steps` to 1, you will not see the progress bar; instead, the result of every batch step is printed to the screen. You can check whether the other configs are working correctly."
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
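
The notebook above asks you to edit default_setting.yaml by hand. Here is a minimal sketch of making the same edits programmatically, assuming PyYAML is available and that the dumped file uses the grouped layout shown in this changeset (train, callback, logging sections):

import yaml

from towhee.trainer.training_config import dump_default_yaml

# dump the defaults, then rewrite the handful of fields the notebook changes
dump_default_yaml('default_setting.yaml')

with open('default_setting.yaml', encoding='utf-8') as f:
    cfg = yaml.safe_load(f)

cfg['train']['batch_size'] = 5
cfg['train']['epoch_num'] = 3
cfg['train']['output_dir'] = 'my_output'
cfg['callback']['tensorboard'] = None
cfg['logging']['print_steps'] = 1

with open('my_setting.yaml', 'w', encoding='utf-8') as f:
    yaml.safe_dump(cfg, f)
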
@@ -0,0 +1,37 @@
callback:
  early_stopping:
    mode: max
    monitor: eval_epoch_metric
    patience: 4
  model_checkpoint:
    every_n_epoch: 1
  tensorboard: null
device:
  device_str: null
  n_gpu: -1
  sync_bn: false
learning:
  loss: CrossEntropyLoss
  lr: 5.0e-05
  lr_scheduler_type: linear
  optimizer: Adam
  warmup_ratio: 0.0
  warmup_steps: 0
logging:
  print_steps: 1
metrics:
  metric: Accuracy
train:
  batch_size: 5
  dataloader_drop_last: true
  dataloader_num_workers: -1
  dataloader_pin_memory: true
  epoch_num: 3
  eval_steps: null
  eval_strategy: epoch
  freeze_bn: false
  load_best_model_at_end: false
  output_dir: my_output
  overwrite_output_dir: true
  seed: 42
  val_batch_size: -1
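
As a quick sanity check on the edited file above, loading it back into a TrainingConfig (using only the load_from_yaml call already shown in the notebooks) should reflect the changed values; this is a sketch, assuming the attribute names match the repr printed earlier:

from towhee.trainer.training_config import TrainingConfig

config = TrainingConfig()
config.load_from_yaml('my_setting.yaml')

# the fields overridden in the walkthrough
assert config.batch_size == 5
assert config.epoch_num == 3
assert config.print_steps == 1
assert config.output_dir == 'my_output'
assert config.tensorboard is None
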