{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "# Quick start to train an operator on a toy fake dataset."
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('..')\n",
    "from resnet_image_embedding import ResnetImageEmbedding\n",
    "from towhee.trainer.training_config import TrainingConfig\n",
    "from torchvision import transforms\n",
    "from towhee import dataset\n",
    "\n",
    "# build an resnet op:\n",
    "op = ResnetImageEmbedding('resnet18', num_classes=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "outputs": [],
   "source": [
    "# build a training config:\n",
    "training_config = TrainingConfig()\n",
    "training_config.batch_size = 2\n",
    "training_config.epoch_num = 2\n",
    "training_config.tensorboard = None\n",
    "training_config.output_dir = 'quick_start_output'"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "outputs": [],
   "source": [
    "# prepare the dataset\n",
    "fake_transform = transforms.Compose([transforms.ToTensor()])\n",
    "train_data = dataset('fake', size=20, transform=fake_transform)\n",
    "eval_data = dataset('fake', size=10, transform=fake_transform)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2022-03-02 15:09:06,334 - 8665081344 - trainer.py-trainer:390 - WARNING: TrainingConfig(output_dir='quick_start_output', overwrite_output_dir=True, eval_strategy='epoch', eval_steps=None, batch_size=2, val_batch_size=-1, seed=42, epoch_num=2, dataloader_pin_memory=True, dataloader_drop_last=True, dataloader_num_workers=-1, lr=5e-05, metric='Accuracy', print_steps=None, load_best_model_at_end=False, early_stopping={'monitor': 'eval_epoch_metric', 'patience': 4, 'mode': 'max'}, model_checkpoint={'every_n_epoch': 1}, tensorboard=None, loss='CrossEntropyLoss', optimizer='Adam', lr_scheduler_type='linear', warmup_ratio=0.0, warmup_steps=0, device_str=None, n_gpu=-1, sync_bn=False, freeze_bn=False)\n",
      "[epoch 1/2] loss=2.402, metric=0.0, eval_loss=2.254, eval_metric=0.0: 100%|██████████| 10/10 [00:32<00:00,  3.25s/step]\n",
      "[epoch 2/2] loss=1.88, metric=0.1, eval_loss=1.855, eval_metric=0.1: 100%|██████████| 10/10 [00:22<00:00,  1.14step/s]  "
     ]
    }
   ],
   "source": [
    "# start training, it will take about 2 minute on a cpu machine.\n",
    "op.train(training_config, train_dataset=train_data, eval_dataset=eval_data)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "### If you see the two epochs progress bar finish its schedule and a `quick_start_output` folder result, it means you succeeded."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}