logo
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions

163 lines
24 KiB

{
"cells": [
{
"cell_type": "markdown",
"source": [
"# Train a ResNet-18 operator on the MNIST dataset."
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from resnet_image_embedding import ResnetImageEmbedding\n",
"from towhee.trainer.training_config import TrainingConfig\n",
"from torchvision import transforms\n",
"from towhee import dataset\n",
"from torchvision.transforms import Lambda\n",
"\n",
"# Build a ResNet-18 operator with a 10-class output, because MNIST has 10 classes:\n",
"op = ResnetImageEmbedding('resnet18', num_classes=10)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"outputs": [],
"source": [
"# build a training config:\n",
"training_config = TrainingConfig()\n",
"training_config.batch_size = 64\n",
"training_config.epoch_num = 5\n",
"training_config.tensorboard = None\n",
"training_config.output_dir = 'mnist_output'\n",
"training_config.dataloader_num_workers = 0"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 3,
"outputs": [],
"source": [
"# Prepare the MNIST data.\n",
"# Mean/std handed to Normalize below; presumably the commonly cited MNIST pixel statistics (TODO confirm).\n",
"MNIST_MEAN = 0.1307\n",
"MNIST_STD = 0.3081\n",
"# Normalize is given one mean/std entry per channel, so the image is expanded to this many channels.\n",
"NUM_CHANNELS = 3\n",
"\n",
"mnist_transform = transforms.Compose([\n",
"    transforms.ToTensor(),\n",
"    # Repeat the tensor NUM_CHANNELS times along dim 0 (assumes ToTensor yields a 1-channel image; TODO confirm).\n",
"    Lambda(lambda x: x.repeat(NUM_CHANNELS, 1, 1)),\n",
"    transforms.Normalize(mean=[MNIST_MEAN] * NUM_CHANNELS, std=[MNIST_STD] * NUM_CHANNELS),\n",
"])\n",
"train_data = dataset('mnist', transform=mnist_transform, download=True, root='data', train=True)\n",
"eval_data = dataset('mnist', transform=mnist_transform, download=True, root='data', train=False)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 4,
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2022-03-02 15:24:19,702 - 8669324800 - trainer.py-trainer:390 - WARNING: TrainingConfig(output_dir='mnist_output', overwrite_output_dir=True, eval_strategy='epoch', eval_steps=None, batch_size=64, val_batch_size=-1, seed=42, epoch_num=5, dataloader_pin_memory=True, dataloader_drop_last=True, dataloader_num_workers=0, lr=5e-05, metric='Accuracy', print_steps=None, load_best_model_at_end=False, early_stopping={'monitor': 'eval_epoch_metric', 'patience': 4, 'mode': 'max'}, model_checkpoint={'every_n_epoch': 1}, tensorboard=None, loss='CrossEntropyLoss', optimizer='Adam', lr_scheduler_type='linear', warmup_ratio=0.0, warmup_steps=0, device_str=None, n_gpu=-1, sync_bn=False, freeze_bn=False)\n",
"[epoch 1/5] loss=0.388, metric=0.884: 41%|████ | 383/937 [07:38<11:00, 1.19s/step]"
]
}
],
"source": [
"# Start training on MNIST; this takes roughly 30-100 minutes on a CPU-only machine.\n",
"# Training on a GPU machine is much faster.\n",
"op.train(training_config, train_dataset=train_data, eval_dataset=eval_data)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"### If you observe that the loss is decreasing and the metric is increasing, the model is training correctly."
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
}