{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "# Train resnet18 operator on mnist dataset."
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "import random\n",
    "from resnet_image_embedding import ResnetImageEmbedding\n",
    "from towhee.trainer.training_config import TrainingConfig\n",
    "from torchvision import transforms\n",
    "from towhee import dataset\n",
    "from torchvision.transforms import Lambda\n",
    "\n",
    "# build a resnet op with 10 classes output, because the mnist has 10 classes:\n",
    "op = ResnetImageEmbedding('resnet18', num_classes=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "outputs": [],
   "source": [
    "# build a training config:\n",
    "training_config = TrainingConfig()\n",
    "training_config.batch_size = 64\n",
    "training_config.epoch_num = 2\n",
    "training_config.tensorboard = None\n",
    "training_config.output_dir = 'mnist_output'"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "outputs": [],
   "source": [
    "# prepare the mnist data\n",
    "mean = 0.1307\n",
    "std = 0.3081\n",
    "mnist_transform = transforms.Compose([transforms.ToTensor(),\n",
    "                                          Lambda(lambda x: x.repeat(3, 1, 1)),\n",
    "                                          transforms.Normalize(mean=[mean] * 3, std=[std] * 3)])\n",
    "train_data = dataset('mnist', transform=mnist_transform, download=True, root='data', train=True)\n",
    "eval_data = dataset('mnist', transform=mnist_transform, download=True, root='data', train=False)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2022-03-02 18:43:32,380 - 8612169216 - trainer.py-trainer:390 - WARNING: TrainingConfig(output_dir='mnist_output', overwrite_output_dir=True, eval_strategy='epoch', eval_steps=None, batch_size=64, val_batch_size=-1, seed=42, epoch_num=2, dataloader_pin_memory=True, dataloader_drop_last=True, dataloader_num_workers=0, lr=5e-05, metric='Accuracy', print_steps=None, load_best_model_at_end=False, early_stopping={'monitor': 'eval_epoch_metric', 'patience': 4, 'mode': 'max'}, model_checkpoint={'every_n_epoch': 1}, tensorboard=None, loss='CrossEntropyLoss', optimizer='Adam', lr_scheduler_type='linear', warmup_ratio=0.0, warmup_steps=0, device_str=None, n_gpu=-1, sync_bn=False, freeze_bn=False)\n",
      "[epoch 1/2] loss=0.203, metric=0.945, eval_loss=0.182, eval_metric=0.988: 100%|██████████| 937/937 [19:04<00:00,  1.22s/step]\n",
      "[epoch 2/2] loss=0.044, metric=0.994, eval_loss=0.045, eval_metric=0.989: 100%|██████████| 937/937 [19:08<00:00,  1.20s/step]"
     ]
    }
   ],
   "source": [
    "# start to train mnist, it will take about 30-100 minutes on a cpu machine.\n",
    "# if you train on a gpu machine, it will be much faster.\n",
    "op.train(training_config, train_dataset=train_data, eval_dataset=eval_data)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "### If you observe loss is decreasing and metric is increasing, it means you are training the model correctly.\n",
    "### After finishing training, we can use this model to do predict."
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "outputs": [
    {
     "data": {
      "text/plain": "<Figure size 432x288 with 1 Axes>",
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAANxklEQVR4nO3db6hc9Z3H8c9no2I0VRIlGq1Yt6i4LsGsIQhZVpdicX2gUag0YFWUvX2gG8Woe1U0Rp+Ejd26ChZSK02lq4a0oT6otSEEpQrFG7mrsbE1W932JiHZGLH6QPPvuw/uyXKNd35znXPmT/J9vyDMzPnOmfNluJ+cM/M7Z36OCAE4+v1VvxsA0BuEHUiCsANJEHYgCcIOJHFMLzdmm6/+gS6LCE+2vNae3fYVtn9ve6vt4TqvBaC73Ok4u+1pkv4g6XJJY5Jel7Q4In5XWIc9O9Bl3dizL5C0NSL+GBF7JT0n6eoarwegi+qE/UxJf57weKxa9jm2h2yP2B6psS0ANdX5gm6yQ4UvHKZHxCpJqyQO44F+qrNnH5N01oTHX5W0vV47ALqlTthfl3Su7XNsHyfp25JeaKYtAE3r+DA+Ivbbvk3SS5KmSXo6It5urDMAjep46K2jjfGZHei6rpxUA+DIQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IouP52SXJ9vuSPpZ0QNL+iJjfRFMAmlcr7JV/jIjdDbwOgC7iMB5Iom7YQ9KvbW+yPTTZE2wP2R6xPVJzWwBqcER0vrJ9RkRstz1b0npJ/xIRrxSe3/nGAExJRHiy5bX27BGxvbrdJWmdpAV1Xg9A93Qcdtsn2v7KofuSvilpc1ONAWhWnW/jT5O0zvah1/nPiPhVI10BaFytz+xfemN8Zge6riuf2QEcOQg7kARhB5Ig7EAShB1IookLYTDAZsyYUazfddddtda/8847i/UPP/ywZW358uXFdZ988sliff/+/cU6Po89O5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kwVVvR4Djjz++WB8eHm5ZW7p0aXHdE044oVivLmFuqZt/P+3G2duN8e/bt6/Jdo4YXPUGJEfYgSQIO5AEYQeSIOxAEoQdSIKwA0kwzt4D7cbJL7vssmL97rvvrrV+HXv27KlVP/bYY1vWzj777I56OuTFF18s1l9++eWWtccee6y47pE8Rs84O5AcYQeSIOxAEoQdSIKwA0kQdiAJwg4kwTh7A6ZPn16sP/7448X6zTff3GQ7n7N58+ZifcWKFcX66Ohosb5ly5ZivfS78y+99FJx3UsuuaRYr+P8888v1rdu3dq1bXdbx+Pstp+2vcv25gnLZtleb/vd6nZmk80CaN5UDuN/LOmKw5YNS9oQEedK2lA9BjDA2oY9Il6RdPg5kVdLWl3dXy1pUbNtAWhap3O9nRYROyQpInbYnt3qibaHJA11uB0ADen6xI4RsUrSKuno/YIOOBJ0OvS20/YcSapudzXXEoBu6DTsL0i6sbp/o6RfNNMOgG5pexhv+1lJl0k61faYpGWSVkhaY/sWSX+S9K1uNjnoLr/88mK97jj67t27i/Xnn3++Za3d/Ot79+7tqKepOuOMM1rWPvvss65uG5/XNuwRsbhF6RsN9wKgizhdFkiCsANJEHYgCcIOJEHYgSS6fgbd0aI0tXG74a26nnrqqWL9/vvv79q2jzmm/CdyzTXXFOuly3tnz255lnUjNm7c2LI2NjbW1W0PIvbsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AE4+xT9MADD7SsLVy4sNZrtxtHf+SRR2q9fskFF1xQrC9ZsqRYHxoa3F8cW7lyZcvap59+2sNOBgN7diAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgnH2KermtdfPPPNMsd5uTPi8885rWbv99tuL61533XXF+imnnFKs93LK78M98cQTxXrpevaM2LMDSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKMs0/RyMhIy9pNN91U67XXrVtXrO/bt69Ynz59esvaSSed1FFPh7Sb0vmGG24o1u+9996Wtblz53bU0yFr164t1rs9HfWRpu2e3fbTtnfZ3jxh2UO2t9kerf5d2d02AdQ1lcP4H0u6YpLl34+Ii6p/v2y2LQBNaxv2iHhF0p4e9AKgi+p8QXeb7Terw/yZrZ5ke8j2iO3WH3oBdF2nYf+BpK9LukjSDknfa/XEiFgVEfMjYn6H2wLQgI7CHhE7I+JARByU9ENJC5ptC0DTOgq77TkTHl4jaXOr5wIYDG53PbLtZyVdJulUSTslLaseXyQpJL0v6bsRsaPtxuz+Xfxck+2WtXbjvYsWLWq4m+a89tprxfrDDz9crLcbx1+zZs2X7umQdr1deumlxfrBgwc73vaRLCIm/WNte1JNRCyeZPGPancEoKc4XRZIgrADSRB2IAnCDiRB2IEkuMR1ikpDlLfeemtx3Z07dxbr119/fbH+zjvvFOulS2Tb/dzyJ598Uqwfd9xxxfqrr75arJeGLNsNja1fv75Yzzq01in27EAShB1IgrADSRB2IAnCDiRB2IEkCDuQRNtLXBvd2BF8iWtWp59+erG+bdu2jl97dHS0WL/44os7fu3MWl3iyp4dSIKwA0kQdiAJwg4kQdiBJAg7kARhB5LgenYULVu2rNb6pemmn3vuuVqvjS+HPTuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJMH17Mlde+21xXq76ajb/f2sXLmyZW14eLi4LjrT8fXsts+yvdH2Fttv2769Wj7L9nrb71a3M5tuGkBzpnIYv1/S0oi4QNIlkm61/TeShiVtiIhzJW2oHgMYUG3DHhE7IuKN6v7HkrZIOlPS1ZJWV09bLWlRl3oE0IAvdW687a9Jmifpt5JOi4gd0vh/CLZnt1hnSNJQzT4B1DTlsNueIelnku6IiL+UJuybKCJWSVpVvQZf0AF9MqWhN9vHajzoP42In1eLd9qeU9XnSNrVnRYBNKHt0JvHd+GrJe2JiDsmLF8p6YOIWGF7WNKsiLinzWuxZx8w7X7Oee7cucX6Bx98UKzPmzevZW1sbKy4LjrTauhtKofxCyV9R9JbtkerZfdJWiFpje1bJP1J0rca6BNAl7QNe0T8RlKrD+jfaLYdAN3C6bJAEoQdSIKwA0kQdiAJwg4kwU9JH+Xuuad46oMuvPDCYv3AgQPF+oMPPlisM5Y+ONizA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAS/JT0UeCcc85pWdu0aVNx3ZNPPrlYb7f+ggULinX0Xsc/JQ3g6EDYgSQIO5AEYQeSIOxAEoQdSIKwA0lwPftRYMmSJS1r7cbR21m+fHmt9TE42LMDSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBJTmZ/9LEk/kXS6pIOSVkXEf9h+SNI/S/rf6qn3RcQv27wW17N34KqrrirW165d27I2bdq0Wtuuuz56r8787PslLY2IN2x/RdIm2+ur2vcj4tGmmgTQPVOZn32HpB3V/Y9tb5F0ZrcbA9CsL/WZ3fbXJM2T9Ntq0W2237T9tO2ZLdYZsj1ie6ReqwDqmHLYbc+Q9DNJd0TEXyT9QNLXJV2k8T3/9yZbLyJWRcT8iJhfv10AnZpS2G0fq/Gg/zQifi5JEbEzIg5ExEFJP5TELw8CA6xt2G1b0o8kbYmIf5+wfM6Ep10jaXPz7QFoylS+jV8o6TuS3rI9Wi27T9Ji2xdJCknvS/puF/qDpPfee69Y/+ijj1rWZs2aVVz30UcZTMliKt/G/0bSZON2xTF1AIOFM+iAJAg7kARhB5Ig7EAShB1IgrADSTBlM3CUYcpmIDnCDiRB2IEkCDuQBGEHkiDsQBKEHUii11M275b0PxMen1otG0SD2tug9iXRW6ea7O3sVoWenlTzhY3bI4P623SD2tug9iXRW6d61RuH8UAShB1Iot9hX9Xn7ZcMam+D2pdEb53qSW99/cwOoHf6vWcH0COEHUiiL2G3fYXt39veanu4Hz20Yvt922/ZHu33/HTVHHq7bG+esGyW7fW2361uJ51jr0+9PWR7W/Xejdq+sk+9nWV7o+0ttt+2fXu1vK/vXaGvnrxvPf/MbnuapD9IulzSmKTXJS2OiN/1tJEWbL8vaX5E9P0EDNv/IOkTST+JiL+tlv2bpD0RsaL6j3JmRPzrgPT2kKRP+j2NdzVb0ZyJ04xLWiTpJvXxvSv0dZ168L71Y8++QNLWiPhjROyV9Jykq/vQx8CLiFck7Tls8dWSVlf3V2v8j6XnWvQ2ECJiR0S8Ud3/WNKhacb7+t4V+uqJfoT9TEl/nvB4TIM133tI+rXtTbaH+t3MJE6LiB3S+B+PpNl97udwbafx7qXDphkfmPeuk+nP6+pH2Cf7faxBGv9bGBF/J+mfJN1aHa5iaqY0jXevTDLN+EDodPrzuvoR9jFJZ014/FVJ2/vQx6QiYnt1u0vSOg3eVNQ7D82gW93u6nM//2+QpvGebJpxDcB718/pz/sR9tclnWv7HNvHSfq2pBf60McX2D6x+uJEtk+U9E0N3lTUL0i6sbp/o6Rf9LGXzxmUabxbTTOuPr93fZ/+PCJ6/k/SlRr/Rv6/Jd3fjx5a9PXXkv6r+vd2v3uT9KzGD+v2afyI6BZJp0jaIOnd6nbWAPX2jKS3JL2p8WDN6VNvf6/xj4ZvShqt/l3Z7/eu0FdP3jdOlwWS4Aw6IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUji/wD4PWdK1kkK0AAAAABJRU5ErkJggg==\n"
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "this picture is number 9\n"
     ]
    }
   ],
   "source": [
    "# get random picture and predict it.\n",
    "img_index = random.randint(0, len(eval_data))\n",
    "img = eval_data.dataset[img_index][0]\n",
    "img = img.numpy().transpose(1, 2, 0)  # (C, H, W) -> (H, W, C)\n",
    "pil_img = img * std + mean\n",
    "plt.imshow(pil_img)\n",
    "plt.show()\n",
    "test_img = eval_data.dataset[img_index][0].unsqueeze(0)\n",
    "out = op.trainer.predict(test_img)\n",
    "predict_num = torch.argmax(torch.softmax(out, dim=-1)).item()\n",
    "print('this picture is number {}'.format(predict_num))"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "### You can repeat running this prediction code cell multiple times and check if the prediction result is right."
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}