{
"cells": [
{
"cell_type": "markdown",
"source": [
"# Train the resnet18 operator on the MNIST dataset"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import random\n",
"import sys\n",
"\n",
"# Let `resnet_image_embedding` be imported from the parent directory\n",
"# (appended last, so installed packages still take precedence).\n",
"sys.path.append('..')\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import torch\n",
"from torchvision import transforms\n",
"from torchvision.transforms import Lambda\n",
"from towhee import dataset\n",
"from towhee.trainer.training_config import TrainingConfig\n",
"\n",
"from resnet_image_embedding import ResnetImageEmbedding\n",
"\n",
"# Seed Python's and torch's RNGs so a fresh Restart & Run All is reproducible\n",
"# (e.g. which sample the prediction cell picks first).\n",
"SEED = 42\n",
"random.seed(SEED)\n",
"torch.manual_seed(SEED)\n",
"\n",
"# Build a resnet18 op with a 10-way output: one class per MNIST digit.\n",
"op = ResnetImageEmbedding('resnet18', num_classes=10)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"outputs": [],
"source": [
"# Training hyper-parameters; any field not set here keeps its TrainingConfig() default.\n",
"BATCH_SIZE = 64\n",
"EPOCH_NUM = 2\n",
"OUTPUT_DIR = 'mnist_output'\n",
"\n",
"training_config = TrainingConfig()\n",
"training_config.batch_size = BATCH_SIZE\n",
"training_config.epoch_num = EPOCH_NUM\n",
"training_config.tensorboard = None\n",
"training_config.output_dir = OUTPUT_DIR"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 3,
"outputs": [],
"source": [
"# Normalization statistics applied to the MNIST images.\n",
"mean = 0.1307\n",
"std = 0.3081\n",
"\n",
"# Repeat the single grey channel 3 times so each image has 3 channels for the resnet op.\n",
"to_three_channels = Lambda(lambda x: x.repeat(3, 1, 1))\n",
"mnist_transform = transforms.Compose([\n",
"    transforms.ToTensor(),\n",
"    to_three_channels,\n",
"    transforms.Normalize(mean=[mean] * 3, std=[std] * 3),\n",
"])\n",
"\n",
"# Both splits share the same transform and local download directory.\n",
"mnist_kwargs = dict(transform=mnist_transform, download=True, root='data')\n",
"train_data = dataset('mnist', train=True, **mnist_kwargs)\n",
"eval_data = dataset('mnist', train=False, **mnist_kwargs)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 4,
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2022-03-02 18:43:32,380 - 8612169216 - trainer.py-trainer:390 - WARNING: TrainingConfig(output_dir='mnist_output', overwrite_output_dir=True, eval_strategy='epoch', eval_steps=None, batch_size=64, val_batch_size=-1, seed=42, epoch_num=2, dataloader_pin_memory=True, dataloader_drop_last=True, dataloader_num_workers=0, lr=5e-05, metric='Accuracy', print_steps=None, load_best_model_at_end=False, early_stopping={'monitor': 'eval_epoch_metric', 'patience': 4, 'mode': 'max'}, model_checkpoint={'every_n_epoch': 1}, tensorboard=None, loss='CrossEntropyLoss', optimizer='Adam', lr_scheduler_type='linear', warmup_ratio=0.0, warmup_steps=0, device_str=None, n_gpu=-1, sync_bn=False, freeze_bn=False)\n",
"[epoch 1/2] loss=0.203, metric=0.945, eval_loss=0.182, eval_metric=0.988: 100%|██████████| 937/937 [19:04<00:00, 1.22s/step]\n",
"[epoch 2/2] loss=0.044, metric=0.994, eval_loss=0.045, eval_metric=0.989: 100%|██████████| 937/937 [19:08<00:00, 1.20s/step]"
]
}
],
"source": [
"# start to train mnist, it will take about 30-100 minutes on a cpu machine.\n",
"# if you train on a gpu machine, it will be much faster.\n",
"op.train(training_config, train_dataset=train_data, eval_dataset=eval_data)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"### If the loss decreases and the metric increases from epoch to epoch, the model is training correctly.\n",
"### Once training has finished, we can use the model to make predictions."
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 12,
"outputs": [
{
"data": {
"text/plain": "<Figure size 432x288 with 1 Axes>",
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAANxklEQVR4nO3db6hc9Z3H8c9no2I0VRIlGq1Yt6i4LsGsIQhZVpdicX2gUag0YFWUvX2gG8Woe1U0Rp+Ejd26ChZSK02lq4a0oT6otSEEpQrFG7mrsbE1W932JiHZGLH6QPPvuw/uyXKNd35znXPmT/J9vyDMzPnOmfNluJ+cM/M7Z36OCAE4+v1VvxsA0BuEHUiCsANJEHYgCcIOJHFMLzdmm6/+gS6LCE+2vNae3fYVtn9ve6vt4TqvBaC73Ok4u+1pkv4g6XJJY5Jel7Q4In5XWIc9O9Bl3dizL5C0NSL+GBF7JT0n6eoarwegi+qE/UxJf57weKxa9jm2h2yP2B6psS0ANdX5gm6yQ4UvHKZHxCpJqyQO44F+qrNnH5N01oTHX5W0vV47ALqlTthfl3Su7XNsHyfp25JeaKYtAE3r+DA+Ivbbvk3SS5KmSXo6It5urDMAjep46K2jjfGZHei6rpxUA+DIQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IouP52SXJ9vuSPpZ0QNL+iJjfRFMAmlcr7JV/jIjdDbwOgC7iMB5Iom7YQ9KvbW+yPTTZE2wP2R6xPVJzWwBqcER0vrJ9RkRstz1b0npJ/xIRrxSe3/nGAExJRHiy5bX27BGxvbrdJWmdpAV1Xg9A93Qcdtsn2v7KofuSvilpc1ONAWhWnW/jT5O0zvah1/nPiPhVI10BaFytz+xfemN8Zge6riuf2QEcOQg7kARhB5Ig7EAShB1IookLYTDAZsyYUazfddddtda/8847i/UPP/ywZW358uXFdZ988sliff/+/cU6Po89O5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kwVVvR4Djjz++WB8eHm5ZW7p0aXHdE044oVivLmFuqZt/P+3G2duN8e/bt6/Jdo4YXPUGJEfYgSQIO5AEYQeSIOxAEoQdSIKwA0kwzt4D7cbJL7vssmL97rvvrrV+HXv27KlVP/bYY1vWzj777I56OuTFF18s1l9++eWWtccee6y47pE8Rs84O5AcYQeSIOxAEoQdSIKwA0kQdiAJwg4kwTh7A6ZPn16sP/7448X6zTff3GQ7n7N58+ZifcWKFcX66Ohosb5ly5ZivfS78y+99FJx3UsuuaRYr+P8888v1rdu3dq1bXdbx+Pstp+2vcv25gnLZtleb/vd6nZmk80CaN5UDuN/LOmKw5YNS9oQEedK2lA9BjDA2oY9Il6RdPg5kVdLWl3dXy1pUbNtAWhap3O9nRYROyQpInbYnt3qibaHJA11uB0ADen6xI4RsUrSKuno/YIOOBJ0OvS20/YcSapudzXXEoBu6DTsL0i6sbp/o6RfNNMOgG5pexhv+1lJl0k61faYpGWSVkhaY/sWSX+S9K1uNjnoLr/88mK97jj67t27i/Xnn3++Za3d/Ot79+7tqKepOuOMM1rWPvvss65uG5/XNuwRsbhF6RsN9wKgizhdFkiCsANJEHYgCcIOJEHYgSS6fgbd0aI0tXG74a26nnrqqWL9/vvv79q2jzmm/CdyzTXXFOuly3tnz255lnUjNm7c2LI2NjbW1W0PIvbsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AE4+xT9MADD7SsLVy4sNZrtxtHf+SRR2q9fskFF1xQrC9ZsqRYHxoa3F8cW7lyZcvap59+2sNOBgN7diAJwg4kQdiBJAg7kARhB5
Ig7EAShB1IgnH2KermtdfPPPNMsd5uTPi8885rWbv99tuL61533XXF+imnnFKs93LK78M98cQTxXrpevaM2LMDSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKMs0/RyMhIy9pNN91U67XXrVtXrO/bt69Ynz59esvaSSed1FFPh7Sb0vmGG24o1u+9996Wtblz53bU0yFr164t1rs9HfWRpu2e3fbTtnfZ3jxh2UO2t9kerf5d2d02AdQ1lcP4H0u6YpLl34+Ii6p/v2y2LQBNaxv2iHhF0p4e9AKgi+p8QXeb7Terw/yZrZ5ke8j2iO3WH3oBdF2nYf+BpK9LukjSDknfa/XEiFgVEfMjYn6H2wLQgI7CHhE7I+JARByU9ENJC5ptC0DTOgq77TkTHl4jaXOr5wIYDG53PbLtZyVdJulUSTslLaseXyQpJL0v6bsRsaPtxuz+Xfxck+2WtXbjvYsWLWq4m+a89tprxfrDDz9crLcbx1+zZs2X7umQdr1deumlxfrBgwc73vaRLCIm/WNte1JNRCyeZPGPancEoKc4XRZIgrADSRB2IAnCDiRB2IEkuMR1ikpDlLfeemtx3Z07dxbr119/fbH+zjvvFOulS2Tb/dzyJ598Uqwfd9xxxfqrr75arJeGLNsNja1fv75Yzzq01in27EAShB1IgrADSRB2IAnCDiRB2IEkCDuQRNtLXBvd2BF8iWtWp59+erG+bdu2jl97dHS0WL/44os7fu3MWl3iyp4dSIKwA0kQdiAJwg4kQdiBJAg7kARhB5LgenYULVu2rNb6pemmn3vuuVqvjS+HPTuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJMH17Mlde+21xXq76ajb/f2sXLmyZW14eLi4LjrT8fXsts+yvdH2Fttv2769Wj7L9nrb71a3M5tuGkBzpnIYv1/S0oi4QNIlkm61/TeShiVtiIhzJW2oHgMYUG3DHhE7IuKN6v7HkrZIOlPS1ZJWV09bLWlRl3oE0IAvdW687a9Jmifpt5JOi4gd0vh/CLZnt1hnSNJQzT4B1DTlsNueIelnku6IiL+UJuybKCJWSVpVvQZf0AF9MqWhN9vHajzoP42In1eLd9qeU9XnSNrVnRYBNKHt0JvHd+GrJe2JiDsmLF8p6YOIWGF7WNKsiLinzWuxZx8w7X7Oee7cucX6Bx98UKzPmzevZW1sbKy4LjrTauhtKofxCyV9R9JbtkerZfdJWiFpje1bJP1J0rca6BNAl7QNe0T8RlKrD+jfaLYdAN3C6bJAEoQdSIKwA0kQdiAJwg4kwU9JH+Xuuad46oMuvPDCYv3AgQPF+oMPPlisM5Y+ONizA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAS/JT0UeCcc85pWdu0aVNx3ZNPPrlYb7f+ggULinX0Xsc/JQ3g6EDYgSQIO5AEYQeSIOxAEoQdSIKwA0lwPftRYMmSJS1r7cbR21m+fHmt9TE42LMDSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBJTmZ/9LEk/kXS6pIOSVkXEf9h+SNI/S/rf6qn3RcQv27wW17N34KqrrirW165d27I2bdq0Wtuuuz56r8787PslLY2IN2x/RdIm2+ur2vcj4tGmmgTQPVOZn32HpB3V/Y9tb5F0ZrcbA9CsL/WZ3fbXJM2T9Ntq0W2237T9tO2ZLdYZsj1ie6ReqwDqmHLYbc+Q9DNJd0TEXyT9QNLXJV2k8T3/9yZbLyJWRcT8iJhfv10AnZpS2G0fq/Gg/zQifi5JEbEzIg5ExEFJP5TELw8CA6xt2G1b0o8kbYmIf5+wfM6Ep10jaXPz7QFoylS+jV8o6TuS3rI9Wi27T9Ji2xdJCknvS/puF/qDpPfee69Y/+ijj1rWZs2aVVz30UcZTMliKt/G/0bSZON2xTF1AIOFM+iAJAg7kARhB5Ig7EAShB1IgrADSTBlM3CUYcpmIDnCDiRB2IEkCDuQBGEHkiDsQBKEHU
ii11M275b0PxMen1otG0SD2tug9iXRW6ea7O3sVoWenlTzhY3bI4P623SD2tug9iXRW6d61RuH8UAShB1Iot9hX9Xn7ZcMam+D2pdEb53qSW99/cwOoHf6vWcH0COEHUiiL2G3fYXt39veanu4Hz20Yvt922/ZHu33/HTVHHq7bG+esGyW7fW2361uJ51jr0+9PWR7W/Xejdq+sk+9nWV7o+0ttt+2fXu1vK/vXaGvnrxvPf/MbnuapD9IulzSmKTXJS2OiN/1tJEWbL8vaX5E9P0EDNv/IOkTST+JiL+tlv2bpD0RsaL6j3JmRPzrgPT2kKRP+j2NdzVb0ZyJ04xLWiTpJvXxvSv0dZ168L71Y8++QNLWiPhjROyV9Jykq/vQx8CLiFck7Tls8dWSVlf3V2v8j6XnWvQ2ECJiR0S8Ud3/WNKhacb7+t4V+uqJfoT9TEl/nvB4TIM133tI+rXtTbaH+t3MJE6LiB3S+B+PpNl97udwbafx7qXDphkfmPeuk+nP6+pH2Cf7faxBGv9bGBF/J+mfJN1aHa5iaqY0jXevTDLN+EDodPrzuvoR9jFJZ014/FVJ2/vQx6QiYnt1u0vSOg3eVNQ7D82gW93u6nM//2+QpvGebJpxDcB718/pz/sR9tclnWv7HNvHSfq2pBf60McX2D6x+uJEtk+U9E0N3lTUL0i6sbp/o6Rf9LGXzxmUabxbTTOuPr93fZ/+PCJ6/k/SlRr/Rv6/Jd3fjx5a9PXXkv6r+vd2v3uT9KzGD+v2afyI6BZJp0jaIOnd6nbWAPX2jKS3JL2p8WDN6VNvf6/xj4ZvShqt/l3Z7/eu0FdP3jdOlwWS4Aw6IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUji/wD4PWdK1kkK0AAAAABJRU5ErkJggg==\n"
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"this picture is number 9\n"
]
}
],
"source": [
"# get random picture and predict it.\n",
"img_index = random.randint(0, len(eval_data))\n",
"img = eval_data.dataset[img_index][0]\n",
"img = img.numpy().transpose(1, 2, 0) # (C, H, W) -> (H, W, C)\n",
"pil_img = img * std + mean\n",
"plt.imshow(pil_img)\n",
"plt.show()\n",
"test_img = eval_data.dataset[img_index][0].unsqueeze(0).to(op.trainer.configs.device)\n",
"out = op.trainer.predict(test_img)\n",
"predict_num = torch.argmax(torch.softmax(out, dim=-1)).item()\n",
"print('this picture is number {}'.format(predict_num))"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"### You can re-run the prediction cell above several times and check whether each prediction matches the digit shown."
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
}