{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Fine-tune from ImageNet pretrained model to train a bird classification model."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"sys.path.append('..')\n",
"import torch\n",
"import matplotlib.pyplot as plt\n",
"import random\n",
"import os\n",
"from resnet_image_embedding import ResnetImageEmbedding\n",
"from towhee.trainer.training_config import TrainingConfig\n",
"from torchvision import transforms\n",
"from torchvision.datasets import ImageFolder\n",
"\n",
"# build a resnet32 op with 400 classes output\n",
"op = ResnetImageEmbedding('resnet34', num_classes=400)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### download [BIRDS 400](https://www.kaggle.com/gpiosenka/100-bird-species) from kaggle dataset.\n",
"### replace this the `bird_400_path` with your download path."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"# replace with your own dataset path.\n",
"bird_400_path = '/home/zhangchen/zhangchen_workspace/dataset/bird_400/'"
]
},
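{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### optionally sanity-check the dataset path.\n",
"A minimal sketch, assuming the dataset is unpacked with the `train` and `valid` subfolders used later in this notebook; it only verifies that the folders exist and counts the class subfolders."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# optional sanity check: verify the expected dataset layout before building the datasets\n",
"for split in ('train', 'valid'):\n",
"    split_dir = os.path.join(bird_400_path, split)\n",
"    assert os.path.isdir(split_dir), 'missing split folder: {}'.format(split_dir)\n",
"    print('{}: {} class folders'.format(split, len(os.listdir(split_dir))))"
]
},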
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### build a training config:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"training_config = TrainingConfig()\n",
"training_config.batch_size = 32\n",
"training_config.epoch_num = 4\n",
"training_config.output_dir = 'bird_output'"
]
},
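{
"cell_type": "markdown",
"metadata": {},
"source": [
"Other `TrainingConfig` fields can be set the same way. The values below are the defaults reported in the training log further down (a sketch only; the attribute names are taken from that logged config dump, and keeping the defaults leaves the run unchanged)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# optional sketch: these are the logged defaults, shown only to illustrate how to override them\n",
"training_config.lr = 5e-5\n",
"training_config.optimizer = 'Adam'\n",
"training_config.loss = 'CrossEntropyLoss'"
]
},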
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"std = (0.229, 0.224, 0.229)\n",
"mean = (0.485, 0.456, 0.406)\n",
"transform = transforms.Compose([transforms.ToTensor(),\n",
" transforms.Normalize(mean=mean, std=std),\n",
" transforms.RandomHorizontalFlip(p=0.5)\n",
" ])\n",
"train_data = ImageFolder(os.path.join(bird_400_path, 'train'), transform=transform)\n",
"eval_data = ImageFolder(os.path.join(bird_400_path, 'valid'), transform=transform)"
]
},
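{
"cell_type": "markdown",
"metadata": {},
"source": [
"Since the operator above was built with `num_classes=400`, it is worth confirming that the dataset really exposes 400 class folders. A small check using the `classes` attribute provided by `ImageFolder`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# confirm the dataset class count matches the 400-class head built above\n",
"print('train classes:', len(train_data.classes))\n",
"print('eval classes:', len(eval_data.classes))\n",
"assert len(train_data.classes) == 400, 'class count does not match num_classes=400'"
]
},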
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### start to train."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2022-03-04 17:41:49,336 - 139967684245312 - trainer.py-trainer:390 - WARNING: TrainingConfig(output_dir='bird_output', overwrite_output_dir=True, eval_strategy='epoch', eval_steps=None, batch_size=32, val_batch_size=-1, seed=42, epoch_num=4, dataloader_pin_memory=True, dataloader_drop_last=True, dataloader_num_workers=0, lr=5e-05, metric='Accuracy', print_steps=None, load_best_model_at_end=False, early_stopping={'monitor': 'eval_epoch_metric', 'patience': 4, 'mode': 'max'}, model_checkpoint={'every_n_epoch': 1}, tensorboard={'log_dir': None, 'comment': ''}, loss='CrossEntropyLoss', optimizer='Adam', lr_scheduler_type='linear', warmup_ratio=0.0, warmup_steps=0, device_str=None, freeze_bn=False)\n",
"[epoch 1/4] loss=2.244, metric=0.668, eval_loss=2.222, eval_metric=0.935: 100%|███████████████████████████████████████████| 1824/1824 [03:39<00:00, 8.31step/s]\n",
"[epoch 2/4] loss=0.403, metric=0.939, eval_loss=0.426, eval_metric=0.964: 100%|███████████████████████████████████████████| 1824/1824 [03:38<00:00, 8.35step/s]\n",
"[epoch 3/4] loss=0.195, metric=0.972, eval_loss=0.219, eval_metric=0.978: 100%|███████████████████████████████████████████| 1824/1824 [03:39<00:00, 8.31step/s]\n",
"[epoch 4/4] loss=0.115, metric=0.985, eval_loss=0.14, eval_metric=0.979: 100%|████████████████████████████████████████████| 1824/1824 [03:37<00:00, 8.61step/s]"
]
}
],
"source": [
"op.train(training_config, train_dataset=train_data, eval_dataset=eval_data)"
]
},
{
"cell_type": "code",
"execution_count": 85,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2022-03-04 17:25:15,019 - 140613268227904 - image.py-image:725 - WARNING: Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAOcAAADnCAYAAADl9EEgAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8/fFQqAAAACXBIWXMAAAsTAAALEwEAmpwYAAEAAElEQVR4nOz9ya8k25beif12b2bufpqI270u82WSLJGsKkAJ1H+siQYaqQaSphpIgFBAEagiq6gima+/TUSczt2t2a0Ga5vHTUJM4l0NxMFz4CHui3PiHHcz23uv9a2vUa01/vL6y+svr//yXvr/32/gL6+/vP7y+v/++svi/MvrL6//Ql9/WZx/ef3l9V/o6y+L8y+vv7z+C339ZXH+5fWX13+hL/uPffE//P7v2/fffeJ6ydRimMLAFq/81V99QSNyvbyQU6EVTdoyOSWcB1RCkaBWavaUrNG6YOyGUjNxW7m8Rf4f//f/F0+fnvhn//Rv+Jf/8p/y7R9+z2/+/o98880v+eabE8e7EzknSql8/+ETL28vKK359ttv+V//t3/Lx0/P2DAwTieU8oChFFDaovxIrY1SMlAxxqC1prVKjIlSEqiGMRqtFcuykHOmlIK2AecD3jug0Vql1Qat0Rq0VlFKk+JGjBsAWmu0NigFSsG2bczzjFIK7z3jeEApua5aGYZhYl1XPn36RAiBEAaMMQyTobYsPwtNTpnWFMuysa0r3gcOw8QwBrR1tFZZtxljLKUUKBljLN772/sxGrz3WKNptQCFnCI5Z2gKUCgUKEBXrLXo/mattVhrMcaQcyGXyhozy7YRt0zTCm3c7f1qpXHOA5qSG8YYANZ1JZdIo6G0/L1SCpQGGlobilIopdBa3661XFf5e6Wg1kylYa1BGdjixrIuuFaxxqCUnDe1VlqT92OM3Eel5P4750kp0UqhNSglk8hYazlNBxqV1iolZXJOHKYDIXhSyqQUUejbffXeY5RGAfT3W0ql1krwI94P/T0ZLueZ0+mEtRalYF1nrvOZ/8P/8f+k/uzF+dvf/paaDd6foDlo4JwDNNu2sa4btYLBgoJSC6SG0hUj11xuaCrUGmlcgCs5R4wZOBwOfPj+e85vr8zXK8M48v79F3z95de8ezdhjOEtvnF+u/L999+BVnz11ddYa1lTRpk/8OHTEzFmFFpuvNIY1R+2fpEVUEqhtXZbII0mFxRFra0v3EZrjVILumRSAqU0tSZaqfJ1YBpHci6gFM55ai3U2vrPl58qCz2jlCbnRM4J55z8vFbJJVFr7Q+/wTlZBFpDybK4W2nye5CLOU0Hhv5AgKKkSG5FPoVSsvFsK8ZYam395ykwFhUUxhpyzP3zK2pplFJQSmGtR2uFsVY2pSZf21/ysH8eu7XWqDRouv99o9VKplJrA7Q8G0UWZ22V1v+dVgpjDNa5fq0KtVVQpt8RQGmUkoddKUUpiZwzrVW00RirUQ2s1ow+4LSC1m4bbKv9/mvzD+6L1hrnHLVW0JpW5ZmorV8vY2QDrpBbo9ZKqYVS5FnSyvTrKn8abWhFvk+epQo0nHMcDkeGYWTbEgrDMFSGYcAYTa0FHxy5hP/k+vtHF+cP3z/xePeO0zTg7HDbaUupxFVODG8s1gSSymzLQkoZ50FpTUmFLW2UpCh5o5IoNbJcz5T0itKVcQqUkriez1jr+MUvfsHpeMTZwPV65enpzPfff8933/3AMA4EP3KZr1wvM9N4wJozrWpKq7TaH1IawWVazshRqmhAKkUeVvab1fqDrfHDQKWhcgZlbzdTTkLTv1MW9LbJyauU6jfa9MVfqTX3k1VuUq2FnBWl5H4K7A8cGKMIIeC9xzmHtZaGLJ6SMtu2kbNsKs4N+MmjVAMqpWRijORW++Jy/WuNmDZKrfjqe8VgqAVqhYY8kKU0cpWfbZQsYmcdxum+UKGWjKxHOcVQyEmp9W3RVJpsTEoqC60MTclCra1ScsUYjbEKioa+4Iy1GGPlgdcNqgajUajbSQmgUWjV76uCXCu1FGpWYDW1yu9HKdkkmkYrRdN9g1BS6egqJ7AxcpJqHfvG20ArbG1449BKU2Vv6BsesonlSikVrQ3BD1gtp59Ck0r80Sa3L9wg176pvohd3/zlZ1hn8eHYF/NPWJzrdUPfe4IfOE73QCWXDa0WlNJY4/DWMYQDOVbSlohpppSMMQq0ue10uSjWNRG3lW1d+dMf/kDL8O7dPcfDiNKGuCZ+9fNvesmY+O7bD/z+D7/j+x9+4PXtjaenTzw/PzOvC0+vb0zHO7z3NCxUTcoNtMYZQ22FlCO1VpxzaKVQfePXRqO0IedMyrJorLJ4N2BNRVvXyygp+VptFFVuCzAlOfW899DLMKUatcoCSDnjnCeEyrqutNvp1mgUbjusAucN1mm0kYcxxijbNv10quW2QK0CM460Vogxsm0bVUk1Mwxe3q8xlCIthla2l3SKmAulVll0tZHr/n6lzFTa4ELAOdsfMg3OkVIk5dg3Mi3XQ8l1VtagqwIl5WTJGRsc1hhK3k+T/tA6g7GySSotG2gpmVQLSlkwoJqUulppUJVWK0pJNWSdxXlLrYUYN3LK6P25rpWqNFpJBYJWUp2kSuv32DlLawqr9o3F9tI6y4adDWYv91q7lftaa6wxt1Lf28DgBwBS2jdkWcAAwWnGMBDCyOA9WhlSytAqx8OR2ira0Mt3xf393U9bnMfpyHE6kWPhdXsBVVE6gdqgVFST8mtdEnFNXC4XUAXrFCkrUtpItfFw/8DPTj8np4Xz2w+c3z4whoAzmpYbz5+e+e1vf8dxPOGd43e/+R3LsjAdJ/7u7/6O1/MrrWU+vXzij9/+iX/zv/xrYpz56y9+zRdffs2Hjy9clw3rDaCwxmB0I/V+0Wiw/aFrCukJspRpKSZQkFLC+4C1ltL6bkyTddIaOVVKKbcdspRKSplSFLUkct8IrNOUWmS3dLaXdYqmKsooQBNzZI0rqj9QNUZyKXIKF+mPrXVMk6aUwrqufVOBXKJsJMESRkdDkbP0OLlvNEpJ/2ucw/T+L6dK6S2eRhZmbopaQTVoSklvpI3UlcqgKL1vL1I5KMilEXvpqJRC9Y1uX6Dee4yx0to4WQDGKLQ2pBJpVTacVKS6yLnhPVIeA7qfWHJM9/ZDgTUaYzStaWiZthVakd7VuYDxXja0Kie6oUHLpJTJOd9aCKUU67oS44a19kcnl5ThsuCKbObK4GzA2oB3AWUb1nqMcXJ610zJhZoqWisOhwPv3r1jHMfb/VjWBVA8v3zicDgwTBMhTP2+Lr3a+QmL0+rAOm98vDyzLpFhcDw+HBgmKQnneUahMNrJbl0SSlW0MSjlmA4nTqeA0Zbnl1e29UqrBeMC0+lAWWder68sy7X3hZV//7/9u14aW8bg8d5xnAaMB2zhcOd5fH/i2x9+YIuZ59czpUbCEMhzpFRFsJ5lfuM6z+SUyXkiDIOcmEqRcibXIg+mNrfyrLRGLZVaKkor6WOrLFBjHMa4W/8yz3JRjbEoVVmWRM4b3t+xpoVSC
sZoQgi3Eq211k+uyrouUt71/lCh8D6gFZTyGR+QMr1CKZQiJ9peElelSFvqi3fspyS9RPYMw9h/vmwwGoWq0ivmIg9fGDzOOlCGmAsV1XuiRqmJlBOlyXptVFIpxBRJpQLSo7GX/VqRapU+uEq5bazGGEOtssBzLQLGNHnQlQaUplbp57S2OGtRFDIVmrQSWvffnyO1NjkhAYXBWYe2gVIrqSQB9bTpLUe7gT6mn4Axl1tl9LmnlQ3OKC0bhQKjdQespCIyyqCUkVO1aYwyoGGYAsF7pmliCKNs+indTtPWZBMZx5HgBaMQPKIxjuNPW5zeWNK2ktYNakE1QRJLgVI3QVJTASLbGrleLhirUCrggsO4wGE8ss4b87KQU2GaRmniB83z00qpUjLVqljmhe+++5bD4cB0uCcMhppXXt8+kvKCcZq70wFjH0l15cOnV2pbcd7SUFgHZYvkqig5Yag0BYpeXgGtKUp/wBoKYyzWOemRc5GHMaeOUFpqkX9jLb2P0FJ69p0dQGvps3ItKKM
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"It is black baza.\n",
"probability = 0.26655182242393494\n"
]
}
],
"source": [
"# get random picture and predict it.\n",
"img_index = random.randint(0, len(eval_data))\n",
"img = eval_data[img_index][0]\n",
"img = img.numpy().transpose(1, 2, 0) # (C, H, W) -> (H, W, C)\n",
"pil_img = img * std + mean\n",
"plt.axis('off')\n",
"plt.imshow(pil_img)\n",
"plt.show()\n",
"test_img = eval_data[img_index][0].unsqueeze(0).to(op.trainer.configs.device)\n",
"logits = op.trainer.predict(test_img)\n",
"out = torch.softmax(logits, dim=-1)\n",
"probability = torch.max(out).item()\n",
"predict_num = torch.argmax(out).item()\n",
"print('It is {}.'.format(eval_data.classes[predict_num].lower()))\n",
"print('probability = {}'.format(probability))"
]
},
{
"cell_type": "markdown",
"metadata": {
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"### You can re-run this predicting code cell to make sure whether this bird prediction bird class is right.\n"
]
},
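{
"cell_type": "markdown",
"metadata": {},
"source": [
"Instead of re-running the single-image cell repeatedly, you can also spot-check accuracy on a handful of random eval images. This is a minimal sketch that reuses `op.trainer.predict` exactly as in the cell above; the sample size of 50 is an arbitrary choice."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# optional sketch: spot-check accuracy on a few random eval images (sample size is arbitrary)\n",
"num_samples = 50\n",
"correct = 0\n",
"for i in random.sample(range(len(eval_data)), num_samples):\n",
"    img_tensor, label = eval_data[i]\n",
"    img_tensor = img_tensor.unsqueeze(0).to(op.trainer.configs.device)\n",
"    pred = torch.argmax(op.trainer.predict(img_tensor)).item()\n",
"    correct += int(pred == label)\n",
"print('spot-check accuracy: {}/{} = {:.2%}'.format(correct, num_samples, correct / num_samples))"
]
},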
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8"
}
},
"nbformat": 4,
"nbformat_minor": 1
}