diff --git "a/competition/\347\247\221\345\244\247\350\256\257\351\243\236AI\345\274\200\345\217\221\350\200\205\345\244\247\350\265\2332024/\346\277\222\345\215\261\345\244\247\345\236\213\345\212\250\347\211\251\347\247\215\347\261\273\350\257\206\345\210\253\346\214\221\346\210\230\350\265\233_baseline.ipynb" "b/competition/\347\247\221\345\244\247\350\256\257\351\243\236AI\345\274\200\345\217\221\350\200\205\345\244\247\350\265\2332024/\346\277\222\345\215\261\345\244\247\345\236\213\345\212\250\347\211\251\347\247\215\347\261\273\350\257\206\345\210\253\346\214\221\346\210\230\350\265\233_baseline.ipynb" new file mode 100644 index 0000000..c630fd4 --- /dev/null +++ "b/competition/\347\247\221\345\244\247\350\256\257\351\243\236AI\345\274\200\345\217\221\350\200\205\345\244\247\350\265\2332024/\346\277\222\345\215\261\345\244\247\345\236\213\345\212\250\347\211\251\347\247\215\347\261\273\350\257\206\345\210\253\346\214\221\346\210\230\350\265\233_baseline.ipynb" @@ -0,0 +1,754 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "e8896660-62a3-4bbc-b80b-40ffddf38cba", + "metadata": {}, + "source": [ + "

濒危大型动物种类识别挑战赛

\n", + "\n", + "> 报名链接:https://challenge.xfyun.cn/topic/info?type=species-recognition&option=ssgy&ch=dw24_AtTCK9\n", + "\n", + "\n", + "# 一、赛事背景\n", + "\n", + "全球生物多样性持续下降,尤其濒危大型动物的数量锐减,成为全球环保的紧急议题。濒危动物如大象、犀牛及大型猫科动物面临众多威胁,如栖息地丧失、非法狩猎等,这不仅威胁到生物多样性,也影响生态平衡。为此,精确及时的监测和保护措施显得至关重要。然而,传统监测方法耗时且效率低,难以满足迫切需求。借助人工智能技术,开发自动化的濒危大型动物种类识别系统,通过分析图像数据自动识别特定动物,将极大提升保护效率和准确性。因此,举办一场濒危大型动物种类识别挑战赛,具有重要的研究意义,旨在推动AI技术在野生动物保护中的应用,提高公众保护意识,促进全球生态保护工作。\n", + "\n", + "# 二、赛事任务\n", + "\n", + "本次濒危大型动物种类识别挑战赛旨在应用人工智能技术,提高对濒危大型动物的监测与保护效率。赛事提供了包括照相陷阱、无人机等多源采集的动物图像数据及其所在保护区的地理位置信息。参赛者需基于这些数据,开发出能够准确识别濒危大型动物种类的模型。挑战的关键在于处理和识别在不同光照、角度和背景条件下相似物种。\n", + "\n", + "# 三、评审规则\n", + "\n", + "## 1.数据说明\n", + "\n", + "本次濒危大型动物种类识别挑战赛为参赛选手提供了一个包含9种不同濒危动物的庞大图像数据库,这些图像通过照相陷阱、无人机等多样化的采集方法,在不同的时间、光照条件和多样化的自然背景下收集,旨在模拟实际野外监测的多种情形。\n", + "\n", + "(1)数据集细分说明:\n", + "\n", + "训练集用于模型训练,提供动物图像及其对应的类别标签;测试集用于评估模型性能,仅包含图像数据,参赛者需要预测图像对应的动物种类。\n", + "\n", + "(2)种类识别数据集具体分类:\n", + "\n", + "| Label |\n", + "| ----------- |\n", + "| Badger |\n", + "| BlackBear |\n", + "| Cheetah |\n", + "| Hare |\n", + "| LeopardCat |\n", + "| MuskDeer |\n", + "| AmurLeopard |\n", + "| Tiger |\n", + "| RedFox |\n", + "\n", + "## 2.评估指标\n", + "\n", + "本模型依据提交的结果文件,采用macro-F1分数进行评价。\n", + "\n", + "## 3.评测及排行\n", + "\n", + "(1)初赛和复赛均提供下载数据,选手在本地进行算法调试,在比赛页面提交结果。\n", + "\n", + "(2)比赛采用AB榜,A榜成绩供参赛队伍比赛中查看,最终比赛排名采用B榜最佳成绩。\n", + "\n", + "# 四、作品提交要求\n", + "\n", + "1、文件格式:按照csv格式提交\n", + "\n", + "2、文件大小:无要求\n", + "\n", + "3、提交次数限制:每支队伍每天最多3次\n", + "\n", + "4、文件详细说明:编码为UTF-8,第一行为表头,提交格式见样例\n", + "\n", + "5、不需要上传其他文件\n", + "\n", + "# 五、赛程规则\n", + "\n", + "本赛题实行一轮赛制\n", + "\n", + "## 【赛程周期】\n", + "\n", + "7月17日-8月29日\n", + "\n", + "1、7月17日10:00发布训练集、开发集、即开启比赛榜单,8月15日发布B榜测试\n", + "\n", + "2、比赛作品提交截止日期为8月29日17:00,公布名次日期为9月9日10:00\n", + "\n", + "## 【现场答辩】\n", + "\n", + "1、最终前三名团队将受邀参加科大讯飞全球1024开发者节并于现场进行答辩\n", + "\n", + "2、答辩以(10mins陈述+5mins问答)的形式进行\n", + "\n", + "3、根据作品成绩和答辩成绩综合评分(作品成绩占比70%,现场答辩分数占比30%)\n", + "\n", + "# 六、奖项设置\n", + "\n", + "本赛题设立一、二、三等奖共三名,具体详情如下:\n", + "\n", + "## 【奖项激励】\n", + "\n", + "1. TOP3团队颁发获奖证书\n", + "2. 赛道奖金,第一名5000元、第二名3000元、第三名2000元\n", + "\n", + "## 【资源激励】\n", + "\n", + "1. 讯飞开放平台优质AI能力个人资源包\n", + "2. 讯飞AI全链创业扶持资源\n", + "3. 讯飞绿色实习/就业通道\n", + "\n", + "注:\n", + "\n", + "1. 鼓励选手分享参赛心得、参赛技术攻略、大赛相关技术或产品使用体验等文章至组委会邮箱(AICompetition@iflytek.com),有机会获得大赛周边;\n", + "2. 赛事规则及奖金发放解释权归科大讯飞所有;以上全部奖金均为税前金额,将由主办方代扣代缴个人所得税。" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "69cfe08a-db6c-49f2-adc2-d52cb1387e99", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Looking in indexes: http://mirrors.aliyun.com/pypi/simple\n", + "Requirement already satisfied: timm in ./miniconda3/lib/python3.8/site-packages (1.0.7)\n", + "Collecting pandas\n", + " Downloading http://mirrors.aliyun.com/pypi/packages/f8/7f/5b047effafbdd34e52c9e2d7e44f729a0655efafb22198c45cf692cdc157/pandas-2.0.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (12.4 MB)\n", + "\u001b[K |████████████████████████████████| 12.4 MB 561 kB/s eta 0:00:01\n", + "\u001b[?25hRequirement already satisfied: pyyaml in ./miniconda3/lib/python3.8/site-packages (from timm) (6.0.1)\n", + "Requirement already satisfied: huggingface_hub in ./miniconda3/lib/python3.8/site-packages (from timm) (0.23.5)\n", + "Requirement already satisfied: torch in ./miniconda3/lib/python3.8/site-packages (from timm) (1.10.0+cu113)\n", + "Requirement already satisfied: safetensors in ./miniconda3/lib/python3.8/site-packages (from timm) (0.4.3)\n", + "Requirement already satisfied: torchvision in ./miniconda3/lib/python3.8/site-packages (from timm) (0.11.1+cu113)\n", + "Requirement already satisfied: python-dateutil>=2.8.2 in ./miniconda3/lib/python3.8/site-packages (from pandas) (2.8.2)\n", + "Collecting tzdata>=2022.1\n", + " Downloading http://mirrors.aliyun.com/pypi/packages/65/58/f9c9e6be752e9fcb8b6a0ee9fb87e6e7a1f6bcab2cdc73f02bb7ba91ada0/tzdata-2024.1-py2.py3-none-any.whl (345 kB)\n", + "\u001b[K |████████████████████████████████| 345 kB 550 kB/s eta 0:00:01\n", + "\u001b[?25hRequirement already satisfied: numpy>=1.20.3 in ./miniconda3/lib/python3.8/site-packages (from pandas) (1.21.4)\n", + "Requirement already satisfied: pytz>=2020.1 in ./miniconda3/lib/python3.8/site-packages (from pandas) (2021.3)\n", + "Requirement already satisfied: six>=1.5 in ./miniconda3/lib/python3.8/site-packages (from python-dateutil>=2.8.2->pandas) (1.16.0)\n", + "Requirement already satisfied: requests in ./miniconda3/lib/python3.8/site-packages (from huggingface_hub->timm) (2.25.1)\n", + "Requirement already satisfied: typing-extensions>=3.7.4.3 in ./miniconda3/lib/python3.8/site-packages (from huggingface_hub->timm) (4.0.0)\n", + "Requirement already satisfied: fsspec>=2023.5.0 in ./miniconda3/lib/python3.8/site-packages (from huggingface_hub->timm) (2024.6.1)\n", + "Requirement already satisfied: tqdm>=4.42.1 in ./miniconda3/lib/python3.8/site-packages (from huggingface_hub->timm) (4.61.2)\n", + "Requirement already satisfied: packaging>=20.9 in ./miniconda3/lib/python3.8/site-packages (from huggingface_hub->timm) (21.3)\n", + "Requirement already satisfied: filelock in ./miniconda3/lib/python3.8/site-packages (from huggingface_hub->timm) (3.15.4)\n", + "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in ./miniconda3/lib/python3.8/site-packages (from packaging>=20.9->huggingface_hub->timm) (3.0.6)\n", + "Requirement already satisfied: urllib3<1.27,>=1.21.1 in ./miniconda3/lib/python3.8/site-packages (from requests->huggingface_hub->timm) (1.26.6)\n", + "Requirement already satisfied: certifi>=2017.4.17 in ./miniconda3/lib/python3.8/site-packages (from requests->huggingface_hub->timm) (2021.5.30)\n", + "Requirement already satisfied: idna<3,>=2.5 in ./miniconda3/lib/python3.8/site-packages (from requests->huggingface_hub->timm) (2.10)\n", + "Requirement already satisfied: chardet<5,>=3.0.2 in ./miniconda3/lib/python3.8/site-packages (from requests->huggingface_hub->timm) (4.0.0)\n", + "Requirement already satisfied: pillow!=8.3.0,>=5.3.0 in ./miniconda3/lib/python3.8/site-packages (from torchvision->timm) (8.4.0)\n", + "Installing collected packages: tzdata, pandas\n", + "Successfully installed pandas-2.0.3 tzdata-2024.1\n", + "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\n" + ] + } + ], + "source": [ + "!pip install timm pandas" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "c64d3134-4dec-44ff-9f5e-55b033c48edb", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "torch.manual_seed(0)\n", + "torch.backends.cudnn.deterministic = False\n", + "torch.backends.cudnn.benchmark = True\n", + "\n", + "import torchvision.models as models\n", + "import torchvision.transforms as transforms\n", + "import torchvision.datasets as datasets\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "import torch.optim as optim\n", + "from torch.autograd import Variable\n", + "from torch.utils.data.dataset import Dataset\n", + "import time\n", + "import glob\n", + "\n", + "import pandas as pd\n", + "import numpy as np\n", + "from PIL import Image\n", + "from tqdm import tqdm_notebook" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "01acdbde-f012-45db-ae94-beff97baa8f7", + "metadata": {}, + "outputs": [], + "source": [ + "label_list = ['Badger', 'BlackBear', 'Cheetah', 'Hare', 'LeopardCat', 'MuskDeer', 'AmurLeopard', 'Tiger', 'RedFox']\n", + "\n", + "train_path = glob.glob('./train/*/*.jpg')\n", + "np.random.shuffle(train_path)\n", + "train_label = [label_list.index(x.split('/')[-2]) for x in train_path]\n", + "\n", + "test_path = glob.glob('./testA/*.jpg')" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "ab48904c-3f33-4299-a6ab-c44ee9a40670", + "metadata": {}, + "outputs": [], + "source": [ + "class AverageMeter(object):\n", + " \"\"\"Computes and stores the average and current value\"\"\"\n", + " def __init__(self, name, fmt=':f'):\n", + " self.name = name\n", + " self.fmt = fmt\n", + " self.reset()\n", + "\n", + " def reset(self):\n", + " self.val = 0\n", + " self.avg = 0\n", + " self.sum = 0\n", + " self.count = 0\n", + "\n", + " def update(self, val, n=1):\n", + " self.val = val\n", + " self.sum += val * n\n", + " self.count += n\n", + " self.avg = self.sum / self.count\n", + "\n", + " def __str__(self):\n", + " fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'\n", + " return fmtstr.format(**self.__dict__)\n", + "\n", + "class ProgressMeter(object):\n", + " def __init__(self, num_batches, *meters):\n", + " self.batch_fmtstr = self._get_batch_fmtstr(num_batches)\n", + " self.meters = meters\n", + " self.prefix = \"\"\n", + "\n", + "\n", + " def pr2int(self, batch):\n", + " entries = [self.prefix + self.batch_fmtstr.format(batch)]\n", + " entries += [str(meter) for meter in self.meters]\n", + " print('\\t'.join(entries))\n", + "\n", + " def _get_batch_fmtstr(self, num_batches):\n", + " num_digits = len(str(num_batches // 1))\n", + " fmt = '{:' + str(num_digits) + 'd}'\n", + " return '[' + fmt + '/' + fmt.format(num_batches) + ']'\n", + "def validate(val_loader, model, criterion):\n", + " batch_time = AverageMeter('Time', ':6.3f')\n", + " losses = AverageMeter('Loss', ':.4e')\n", + " top1 = AverageMeter('Acc@1', ':6.2f')\n", + " progress = ProgressMeter(len(val_loader), batch_time, losses, top1)\n", + "\n", + " # switch to evaluate mode\n", + " model.eval()\n", + "\n", + " with torch.no_grad():\n", + " end = time.time()\n", + " for i, (input, target) in tqdm_notebook(enumerate(val_loader), total=len(val_loader)):\n", + " input = input.cuda()\n", + " target = target.cuda()\n", + "\n", + " # compute output\n", + " output = model(input)\n", + " loss = criterion(output, target)\n", + "\n", + " # measure accuracy and record loss\n", + " acc = (output.argmax(1).view(-1) == target.float().view(-1)).float().mean() * 100\n", + " losses.update(loss.item(), input.size(0))\n", + " top1.update(acc, input.size(0))\n", + " # measure elapsed time\n", + " batch_time.update(time.time() - end)\n", + " end = time.time()\n", + "\n", + " # TODO: this should also be done with the ProgressMeter\n", + " print(' * Acc@1 {top1.avg:.3f}'\n", + " .format(top1=top1))\n", + " return top1\n", + "\n", + "def predict(test_loader, model, tta=10):\n", + " # switch to evaluate mode\n", + " model.eval()\n", + " \n", + " test_pred_tta = None\n", + " for _ in range(tta):\n", + " test_pred = []\n", + " with torch.no_grad():\n", + " end = time.time()\n", + " for i, (input, target) in tqdm_notebook(enumerate(test_loader), total=len(test_loader)):\n", + " input = input.cuda()\n", + " target = target.cuda()\n", + "\n", + " # compute output\n", + " output = model(input)\n", + " output = F.softmax(output, dim=1)\n", + " output = output.data.cpu().numpy()\n", + "\n", + " test_pred.append(output)\n", + " test_pred = np.vstack(test_pred)\n", + " \n", + " if test_pred_tta is None:\n", + " test_pred_tta = test_pred\n", + " else:\n", + " test_pred_tta += test_pred\n", + " \n", + " return test_pred_tta\n", + "\n", + "def train(train_loader, model, criterion, optimizer, epoch):\n", + " batch_time = AverageMeter('Time', ':6.3f')\n", + " losses = AverageMeter('Loss', ':.4e')\n", + " top1 = AverageMeter('Acc@1', ':6.2f')\n", + " progress = ProgressMeter(len(train_loader), batch_time, losses, top1)\n", + "\n", + " # switch to train mode\n", + " model.train()\n", + "\n", + " end = time.time()\n", + " for i, (input, target) in enumerate(train_loader):\n", + " input = input.cuda(non_blocking=True)\n", + " target = target.cuda(non_blocking=True)\n", + "\n", + " # compute output\n", + " output = model(input)\n", + " loss = criterion(output, target)\n", + "\n", + " # measure accuracy and record loss\n", + " losses.update(loss.item(), input.size(0))\n", + "\n", + " acc = (output.argmax(1).view(-1) == target.float().view(-1)).float().mean() * 100\n", + " top1.update(acc, input.size(0))\n", + "\n", + " # compute gradient and do SGD step\n", + " optimizer.zero_grad()\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + " # measure elapsed time\n", + " batch_time.update(time.time() - end)\n", + " end = time.time()\n", + "\n", + " if i % 100 == 0:\n", + " progress.pr2int(i)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "dde84a60-5939-4b69-b0a9-bb43773186b9", + "metadata": {}, + "outputs": [], + "source": [ + "class XFDataset(Dataset):\n", + " def __init__(self, img_path, img_label, transform=None):\n", + " self.img_path = img_path\n", + " self.img_label = img_label\n", + " \n", + " if transform is not None:\n", + " self.transform = transform\n", + " else:\n", + " self.transform = None\n", + " \n", + " def __getitem__(self, index):\n", + " img = Image.open(self.img_path[index]).convert('RGB')\n", + " \n", + " if self.transform is not None:\n", + " img = self.transform(img)\n", + " \n", + " return img, torch.from_numpy(np.array(self.img_label[index]))\n", + " \n", + " def __len__(self):\n", + " return len(self.img_path)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "aad4267f-b007-48a4-8e06-cf89df9b9f2f", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'\n", + "\n", + "import timm\n", + "model = timm.create_model('efficientnet_b1', pretrained=False, num_classes=9)\n", + "model = model.cuda()" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "abc0c12c-5cdd-42fd-bfd6-9027431e46e2", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch: 0\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/root/miniconda3/lib/python3.8/site-packages/torch/optim/lr_scheduler.py:129: UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`. In PyTorch 1.1.0 and later, you should call them in the opposite order: `optimizer.step()` before `lr_scheduler.step()`. Failure to do this will result in PyTorch skipping the first value of the learning rate schedule. See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate\n", + " warnings.warn(\"Detected call of `lr_scheduler.step()` before `optimizer.step()`. \"\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[ 0/96]\tTime 2.672 ( 2.672)\tLoss 4.4935e+00 (4.4935e+00)\tAcc@1 10.00 ( 10.00)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_1641/2676678360.py:51: TqdmDeprecationWarning: This function will be removed in tqdm==5.0.0\n", + "Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`\n", + " for i, (input, target) in tqdm_notebook(enumerate(val_loader), total=len(val_loader)):\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "3fac89d56fcf4143ae10b4d22c1aaf45", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/8 [00:00\u001b[0;34m\u001b[0m\n\u001b[1;32m 30\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Epoch: '\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mepoch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 31\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 32\u001b[0;31m \u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_loader\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcriterion\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mepoch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 33\u001b[0m \u001b[0mval_acc\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mvalidate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mval_loader\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcriterion\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 34\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/tmp/ipykernel_1641/2676678360.py\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(train_loader, model, criterion, optimizer, epoch)\u001b[0m\n\u001b[1;32m 108\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 109\u001b[0m \u001b[0mend\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 110\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtarget\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_loader\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 111\u001b[0m \u001b[0minput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcuda\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnon_blocking\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 112\u001b[0m \u001b[0mtarget\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtarget\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcuda\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnon_blocking\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/torch/utils/data/dataloader.py\u001b[0m in \u001b[0;36m__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 519\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_sampler_iter\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 520\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_reset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 521\u001b[0;31m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_next_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 522\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_num_yielded\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 523\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_dataset_kind\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0m_DatasetKind\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mIterable\u001b[0m \u001b[0;32mand\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m\\\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/torch/utils/data/dataloader.py\u001b[0m in \u001b[0;36m_next_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1184\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1185\u001b[0m \u001b[0;32massert\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_shutdown\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_tasks_outstanding\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1186\u001b[0;31m \u001b[0midx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_get_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1187\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_tasks_outstanding\u001b[0m \u001b[0;34m-=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1188\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_dataset_kind\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0m_DatasetKind\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mIterable\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/torch/utils/data/dataloader.py\u001b[0m in \u001b[0;36m_get_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1140\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_pin_memory\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1141\u001b[0m \u001b[0;32mwhile\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_pin_memory_thread\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_alive\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1142\u001b[0;31m \u001b[0msuccess\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_try_get_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1143\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0msuccess\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1144\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/torch/utils/data/dataloader.py\u001b[0m in \u001b[0;36m_try_get_data\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m 988\u001b[0m \u001b[0;31m# (bool: whether successfully get data, any: data if successful else None)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 989\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 990\u001b[0;31m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_data_queue\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtimeout\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtimeout\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 991\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 992\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/lib/python3.8/queue.py\u001b[0m in \u001b[0;36mget\u001b[0;34m(self, block, timeout)\u001b[0m\n\u001b[1;32m 177\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mremaining\u001b[0m \u001b[0;34m<=\u001b[0m \u001b[0;36m0.0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 178\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mEmpty\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 179\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnot_empty\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwait\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mremaining\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 180\u001b[0m \u001b[0mitem\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_get\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 181\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnot_full\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnotify\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/lib/python3.8/threading.py\u001b[0m in \u001b[0;36mwait\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m 304\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 305\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mtimeout\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 306\u001b[0;31m \u001b[0mgotit\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mwaiter\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0macquire\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtimeout\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 307\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 308\u001b[0m \u001b[0mgotit\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mwaiter\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0macquire\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + ] + } + ], + "source": [ + "train_loader = torch.utils.data.DataLoader(\n", + " XFDataset(train_path[:-500], train_label[:-500], \n", + " transforms.Compose([\n", + " transforms.Resize((256, 256)),\n", + " transforms.RandomHorizontalFlip(),\n", + " transforms.RandomVerticalFlip(),\n", + " transforms.ColorJitter(brightness=.5, hue=.3),\n", + " transforms.ToTensor(),\n", + " transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n", + " ])\n", + " ), batch_size=70, shuffle=True, num_workers=4, pin_memory=True\n", + ")\n", + "\n", + "val_loader = torch.utils.data.DataLoader(\n", + " XFDataset(train_path[-500:], train_label[-500:], \n", + " transforms.Compose([\n", + " transforms.Resize((256, 256)),\n", + " transforms.ToTensor(),\n", + " transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n", + " ])\n", + " ), batch_size=70, shuffle=False, num_workers=4, pin_memory=True\n", + ")\n", + "\n", + "criterion = nn.CrossEntropyLoss().cuda()\n", + "optimizer = torch.optim.Adam(model.parameters(), 0.007)\n", + "scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.85)\n", + "best_acc = 0.0\n", + "for epoch in range(20):\n", + " scheduler.step()\n", + " print('Epoch: ', epoch)\n", + "\n", + " train(train_loader, model, criterion, optimizer, epoch)\n", + " val_acc = validate(val_loader, model, criterion)\n", + " \n", + " if val_acc.avg.item() > best_acc:\n", + " best_acc = round(val_acc.avg.item(), 2)\n", + " torch.save(model.state_dict(), f'./model_{best_acc}.pt')" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "41e86d06-e364-4e3f-8631-c2f3c0722e35", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_1641/2676678360.py:81: TqdmDeprecationWarning: This function will be removed in tqdm==5.0.0\n", + "Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`\n", + " for i, (input, target) in tqdm_notebook(enumerate(test_loader), total=len(test_loader)):\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "b8b22c4bf5ac4b598ed6bac9b8186218", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/60 [00:00