# To train on the MNIST dataset, change in_c in class PatchEmbed(nn.Module) and in
# class VisionTransformer(nn.Module) in vision_transformer.py to 1, because MNIST images are single-channel.
# Project walkthrough (video): https://www.bilibili.com/video/BV1px4y1W7zs/?spm_id_from=333.999.0.0&vd_source=6528929ff3772e61e9f5baf9b8ab1e64
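# A minimal sketch of that change (parameter names assumed from a typical ViT
# implementation; check the actual signatures in vision_transformer.py):
#   class PatchEmbed(nn.Module):
#       def __init__(self, img_size=224, patch_size=16, in_c=1, embed_dim=768):  # in_c: 3 -> 1
#           ...
#   class VisionTransformer(nn.Module):
#       def __init__(self, img_size=224, patch_size=16, in_c=1, num_classes=10):  # in_c: 3 -> 1
#           ...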
import os
import argparse
import math
import shutil
import random
import numpy as np
import torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms, datasets
import torch.optim.lr_scheduler as lr_scheduler
import classic_models
from utils.lr_methods import warmup
from utils.train_engin import train_one_epoch, evaluate

parser = argparse.ArgumentParser()
parser.add_argument('--num_classes', type=int, default=10, help='the number of classes')
parser.add_argument('--epochs', type=int, default=50, help='the number of training epochs')
parser.add_argument('--batch_size', type=int, default=32, help='batch size for training')
parser.add_argument('--lr', type=float, default=0.0002, help='start learning rate')
parser.add_argument('--lrf', type=float, default=0.0001, help='end learning rate')
parser.add_argument('--seed', default=False, action='store_true', help='fix the initialization of parameters')
parser.add_argument('--tensorboard', default=True, action='store_true', help='use tensorboard for visualization')  # note: with default=True this flag is effectively always on
parser.add_argument('--use_amp', default=False, action='store_true', help='train with mixed precision')
parser.add_argument('--data_path', type=str, default=r"./mnist")
parser.add_argument('--model', type=str, default="vision_transformer1", help='select a model for training')
parser.add_argument('--device', default='cuda', help='device id (i.e. 0 or 0,1 or cpu)')
# parser.add_argument('--weights', type=str, default=r'model_pth\vit_base_patch16_224.pth', help='initial weights path')

opt = parser.parse_args()
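# Example invocation (defaults shown explicitly; adjust paths and model name to your setup):
#   python train_transformer_mnist.py --model vision_transformer1 --data_path ./mnist --epochs 50 --batch_size 32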

if opt.seed:
    def seed_torch(seed=7):
        random.seed(seed)  # Python random module
        os.environ['PYTHONHASHSEED'] = str(seed)  # disable hash randomization so experiments are reproducible
        np.random.seed(seed)  # NumPy module
        torch.manual_seed(seed)  # seed the CPU RNG
        torch.cuda.manual_seed(seed)  # seed the current GPU
        torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU
        # cuDNN settings: cuDNN optimizes convolutions, trading some precision for speed.
        # For full reproducibility, use:
        # torch.backends.cudnn.benchmark = False
        # torch.backends.cudnn.deterministic = True
        # In practice this only changes results a few decimal places in, so unless you need
        # exact reproducibility it is usually not worth the slowdown.
        print('random seed has been fixed')
    seed_torch()

def main(args):
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")
    print(args)

    if opt.tensorboard:
        # directory where the data to be visualized with tensorboard is stored
        log_path = os.path.join('./results/tensorboard', args.model)
        print('Start Tensorboard with "tensorboard --logdir={}"'.format(log_path))

        if os.path.exists(log_path) is False:
            os.makedirs(log_path)
            print("tensorboard logs saved in {}".format(log_path))
        else:
            shutil.rmtree(log_path)  # if the log directory already exists, delete it (shutil is imported at the top)

        # instantiate a tensorboard writer (recreates log_path if it was removed above)
        tb_writer = SummaryWriter(log_path)
    data_transform = {
        "train": transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ]),
        "val": transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])
    }
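    # Note: (0.1307,) and (0.3081,) are the commonly used per-channel mean/std of the MNIST
    # training set, and the 224x224 resize matches the input size the ViT patch embedding expects.
    # RandomHorizontalFlip is kept from the original pipeline, though flipped digits are
    # arguably out-of-distribution for MNIST.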
    # use the MNIST dataset built into PyTorch
    train_dataset = datasets.MNIST(root=args.data_path, train=True, download=True, transform=data_transform["train"])
    val_dataset = datasets.MNIST(root=args.data_path, train=False, download=True, transform=data_transform["val"])

    if args.num_classes != 10:
        raise ValueError("MNIST dataset has 10 classes, but input {}".format(args.num_classes))

    nw = min([os.cpu_count(), args.batch_size if args.batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers per process'.format(nw))

    # wrap the datasets in DataLoaders for batched loading
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, pin_memory=True, num_workers=nw)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=nw)
    # create model
    model = classic_models.find_model_using_name(opt.model, num_classes=opt.num_classes).to(device)

    # if args.weights != "":
    #     assert os.path.exists(args.weights), "weights file: '{}' not exist.".format(args.weights)
    #     weights_dict = torch.load(args.weights, map_location=device)
    #     # drop the weights we do not need
    #     del_keys = ['head.weight', 'head.bias']  # adjust as needed
    #     for k in del_keys:
    #         del weights_dict[k]
    #     model.load_state_dict(weights_dict, strict=False)

    pg = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.Adam(pg, lr=args.lr)

    # Scheduler https://arxiv.org/pdf/1812.01187.pdf
    lf = lambda x: ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf) + args.lrf  # cosine
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
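    # The lambda above scales the base lr by a cosine factor: at epoch 0 it returns 1
    # (lr = args.lr), and as x approaches args.epochs it decays toward args.lrf, so the
    # final lr is approximately args.lr * args.lrf.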
    best_acc = 0.

    # path for saving parameters
    save_path = os.path.join(os.getcwd(), 'results/weights', args.model)
    if os.path.exists(save_path) is False:
        os.makedirs(save_path)

    for epoch in range(args.epochs):
        # train
        mean_loss, train_acc = train_one_epoch(model=model, optimizer=optimizer, data_loader=train_loader, device=device, epoch=epoch, use_amp=args.use_amp, lr_method=warmup)
        scheduler.step()

        # validate
        val_acc = evaluate(model=model, data_loader=val_loader, device=device)

        print('[epoch %d] train_loss: %.3f train_acc: %.3f val_accuracy: %.3f' % (epoch + 1, mean_loss, train_acc, val_acc))
        with open(os.path.join(save_path, "vision_transformer_mnist_log.txt"), 'a') as f:
            f.write('[epoch %d] train_loss: %.3f train_acc: %.3f val_accuracy: %.3f' % (epoch + 1, mean_loss, train_acc, val_acc) + '\n')

        if opt.tensorboard:
            tags = ["train_loss", "train_acc", "val_accuracy", "learning_rate"]
            tb_writer.add_scalar(tags[0], mean_loss, epoch)
            tb_writer.add_scalar(tags[1], train_acc, epoch)
            tb_writer.add_scalar(tags[2], val_acc, epoch)
            tb_writer.add_scalar(tags[3], optimizer.param_groups[0]["lr"], epoch)

        # if the current validation accuracy is the best so far, overwrite the previously saved weights
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), os.path.join(save_path, "vision_transformer_mnist.pth"))


if __name__ == '__main__':
    main(opt)