logo
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions

138 lines
3.8 KiB

{
"cells": [
{
"cell_type": "markdown",
"source": [
"# Quick start: train an operator on a toy fake dataset"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import sys\n",
"sys.path.append('..')\n",
"from resnet_image_embedding import ResnetImageEmbedding\n",
"from towhee.trainer.training_config import TrainingConfig\n",
"from torchvision import transforms\n",
"from towhee import dataset\n",
"\n",
"# build a resnet op:\n",
"op = ResnetImageEmbedding('resnet18', num_classes=10)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"outputs": [],
"source": [
"# build a training config:\n",
"training_config = TrainingConfig()\n",
"training_config.batch_size = 2\n",
"training_config.epoch_num = 2\n",
"training_config.tensorboard = None\n",
"training_config.output_dir = 'quick_start_output'"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 11,
"outputs": [],
"source": [
"# prepare the dataset\n",
"fake_transform = transforms.Compose([transforms.ToTensor()])\n",
"train_data = dataset('fake', size=20, transform=fake_transform)\n",
"eval_data = dataset('fake', size=10, transform=fake_transform)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 9,
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2022-03-02 15:09:06,334 - 8665081344 - trainer.py-trainer:390 - WARNING: TrainingConfig(output_dir='quick_start_output', overwrite_output_dir=True, eval_strategy='epoch', eval_steps=None, batch_size=2, val_batch_size=-1, seed=42, epoch_num=2, dataloader_pin_memory=True, dataloader_drop_last=True, dataloader_num_workers=-1, lr=5e-05, metric='Accuracy', print_steps=None, load_best_model_at_end=False, early_stopping={'monitor': 'eval_epoch_metric', 'patience': 4, 'mode': 'max'}, model_checkpoint={'every_n_epoch': 1}, tensorboard=None, loss='CrossEntropyLoss', optimizer='Adam', lr_scheduler_type='linear', warmup_ratio=0.0, warmup_steps=0, device_str=None, n_gpu=-1, sync_bn=False, freeze_bn=False)\n",
"[epoch 1/2] loss=2.402, metric=0.0, eval_loss=2.254, eval_metric=0.0: 100%|██████████| 10/10 [00:32<00:00, 3.25s/step]\n",
"[epoch 2/2] loss=1.88, metric=0.1, eval_loss=1.855, eval_metric=0.1: 100%|██████████| 10/10 [00:22<00:00, 1.14step/s] "
]
}
],
"source": [
"# start training; it takes about 2 minutes on a CPU machine.\n",
"op.train(training_config, train_dataset=train_data, eval_dataset=eval_data)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"### If the progress bars for both epochs run to completion and a `quick_start_output` folder is created, training succeeded."
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
}