-
-
Notifications
You must be signed in to change notification settings - Fork 70
/
Copy pathmain.py
246 lines (195 loc) · 11.4 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
import os
import sys
import time
import random
from tqdm import tqdm
from argparse import ArgumentParser
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler
if torch.__version__ == 'parrots':
from pavi import SummaryWriter
else:
from torch.utils.tensorboard import SummaryWriter
from data import PlanningDataset, SequencePlanningDataset, Comma2k19SequenceDataset
from model import PlaningNetwork, MultipleTrajectoryPredictionLoss, SequencePlanningNetwork
from utils import draw_trajectory_on_ax, get_val_metric, get_val_metric_keys
def _str2bool(value):
    """Convert a command-line flag value to ``bool``.

    ``argparse`` with ``type=bool`` maps every non-empty string (including
    ``"False"`` and ``"0"``) to ``True``; this converter recognises the usual
    spellings explicitly instead.

    Args:
        value: Raw string from the command line, or an already-boolean default.

    Returns:
        The parsed boolean.

    Raises:
        ValueError: If ``value`` is not a recognised boolean spelling
            (argparse reports this as a usage error).
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise ValueError('expected a boolean value, got %r' % value)


def get_hyperparameters(parser: ArgumentParser) -> ArgumentParser:
    """Register the training hyper-parameters on ``parser``.

    Args:
        parser: Parser to extend; it is modified in place.

    Returns:
        The same ``parser`` instance, so calls can be chained with
        ``parse_args()``.
    """
    parser.add_argument('--batch_size', type=int, default=6)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--n_workers', type=int, default=4)
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--log_per_n_step', type=int, default=20)
    parser.add_argument('--val_per_n_epoch', type=int, default=1)
    parser.add_argument('--resume', type=str, default='')
    parser.add_argument('--M', type=int, default=5)
    parser.add_argument('--num_pts', type=int, default=33)
    parser.add_argument('--mtp_alpha', type=float, default=1.0)
    parser.add_argument('--optimizer', type=str, default='sgd')
    # ``type=bool`` would turn "--sync_bn False" into True; parse explicitly.
    parser.add_argument('--sync_bn', type=_str2bool, default=True)
    parser.add_argument('--tqdm', type=_str2bool, default=False)
    parser.add_argument('--optimize_per_n_step', type=int, default=40)
    # Name the experiment after the SLURM job id when available, otherwise
    # after the current Unix timestamp.
    exp_name = os.environ.get('SLURM_JOB_ID', str(time.time()))
    parser.add_argument('--exp_name', type=str, default=exp_name)
    return parser
def setup(rank, world_size):
    """Pin this process to a CUDA device and join the NCCL process group.

    The rendezvous address is ``tcp://localhost:<PORT>``, with ``PORT`` taken
    from the environment (a missing variable raises ``KeyError``).

    Args:
        rank: Rank of this process in the group; also used as its CUDA device
            index.
        world_size: Total number of participating processes.
    """
    torch.cuda.set_device(rank)
    port = os.environ['PORT']
    init_url = 'tcp://localhost:%s' % port
    dist.init_process_group('nccl', init_method=init_url, rank=rank, world_size=world_size)
    stamp = '[%.2f]' % time.time()
    banner = 'DDP Initialized at %s:%s' % ('localhost', port)
    print(stamp, banner, rank, 'of', world_size, flush=True)
def get_dataloader(rank, world_size, batch_size, pin_memory=False, num_workers=0):
    """Build the distributed train/validation loaders for comma2k19.

    Args:
        rank: Rank of this process; selects this process's data shard.
        world_size: Number of processes the datasets are sharded across.
        batch_size: Per-process training batch size (validation always uses 1).
        pin_memory: Forwarded to both ``DataLoader`` instances.
        num_workers: Loader worker processes; ``0`` loads in the main process.

    Returns:
        ``(train_loader, val_loader)``. Both use a ``DistributedSampler`` with
        ``shuffle=True``; call ``train_loader.sampler.set_epoch(epoch)`` each
        epoch to get a new shuffle order.
    """
    train = Comma2k19SequenceDataset('data/comma2k19_train_non_overlap.txt', 'data/comma2k19/', 'train', use_memcache=False)
    val = Comma2k19SequenceDataset('data/comma2k19_val_non_overlap.txt', 'data/comma2k19/', 'demo', use_memcache=False)

    if torch.__version__ == 'parrots':
        # NOTE(review): ``drop_last`` is deliberately omitted under parrots,
        # presumably because its sampler does not accept it — confirm.
        dist_sampler_params = dict(num_replicas=world_size, rank=rank, shuffle=True)
    else:
        dist_sampler_params = dict(num_replicas=world_size, rank=rank, shuffle=True, drop_last=True)
    train_sampler = DistributedSampler(train, **dist_sampler_params)
    # NOTE(review): validation reuses the training sampler settings, so it is
    # shuffled and (outside parrots) drops the tail that does not divide
    # evenly across ranks.
    val_sampler = DistributedSampler(val, **dist_sampler_params)

    loader_args = dict(num_workers=num_workers, pin_memory=pin_memory)
    if num_workers > 0:
        # Both options only apply to multi-process loading; recent PyTorch
        # raises ValueError when ``prefetch_factor`` is given with
        # ``num_workers == 0``.
        loader_args.update(persistent_workers=True, prefetch_factor=2)
    train_loader = DataLoader(train, batch_size, sampler=train_sampler, **loader_args)
    val_loader = DataLoader(val, batch_size=1, sampler=val_sampler, **loader_args)
    return train_loader, val_loader
def cleanup():
    """Destroy the default process group if one is active.

    Does nothing when ``torch.distributed`` is unavailable or no group has
    been initialised, so it is safe to call from error paths (e.g. before
    ``setup`` completed).
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()
class SequenceBaselineV1(nn.Module):
    """Recurrent planning baseline wrapping ``SequencePlanningNetwork``.

    All constructor arguments are stored as attributes; only ``M`` and
    ``num_pts`` are used here (to build ``self.net``). The rest are read by the
    training code.
    """

    def __init__(self, M, num_pts, mtp_alpha, lr, optimizer, optimize_per_n_step=40) -> None:
        """
        Args:
            M: Number of candidate trajectories predicted per sample.
            num_pts: Number of points per predicted trajectory.
            mtp_alpha: Weight of the regression term in the training loss.
            lr: Learning rate (stored only).
            optimizer: Optimizer name (stored only).
            optimize_per_n_step: Number of recurrent time steps accumulated
                before each optimizer update (the GRU truncation window).
        """
        super().__init__()
        self.M = M
        self.num_pts = num_pts
        self.mtp_alpha = mtp_alpha
        self.lr = lr
        self.optimizer = optimizer
        self.net = SequencePlanningNetwork(M, num_pts)
        self.optimize_per_n_step = optimize_per_n_step  # for the gru module

    @staticmethod
    def configure_optimizers(args, model):
        """Create the optimizer and LR scheduler selected by ``args.optimizer``.

        Args:
            args: Namespace providing ``optimizer`` (``'sgd'``, ``'adam'`` or
                ``'adamw'``) and ``lr``.
            model: Module whose parameters are optimized.

        Returns:
            ``(optimizer, lr_scheduler)``. The scheduler is a ``StepLR`` that
            multiplies the learning rate by 0.9 every 20 ``step()`` calls.

        Raises:
            NotImplementedError: If ``args.optimizer`` is not supported.
        """
        params = model.parameters()
        if args.optimizer == 'sgd':
            optimizer = optim.SGD(params, lr=args.lr, momentum=0.9, weight_decay=0.01)
        elif args.optimizer == 'adam':
            optimizer = optim.Adam(params, lr=args.lr, weight_decay=0.01)
        elif args.optimizer == 'adamw':
            optimizer = optim.AdamW(params, lr=args.lr)
        else:
            raise NotImplementedError('Unsupported optimizer: %r' % (args.optimizer,))
        lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.9)
        return optimizer, lr_scheduler

    def forward(self, x, hidden=None):
        """Run the wrapped network on one time step.

        Args:
            x: Input batch; ``x.size(0)`` is taken as the batch size.
            hidden: Recurrent state. When ``None``, a zero tensor of shape
                ``(2, batch, 512)`` is created on ``x``'s device.

        Returns:
            The result of ``self.net(x, hidden)``.
        """
        if hidden is None:
            # nn.Module has no ``device`` attribute; use the input's device.
            hidden = torch.zeros((2, x.size(0), 512), device=x.device)
        return self.net(x, hidden)
def main(rank, world_size, args):
    """Train ``SequenceBaselineV1`` with DDP and periodically validate it.

    Each training sequence is unrolled step by step. Losses are accumulated
    and, every ``optimize_per_n_step`` steps (plus once for any remainder),
    back-propagated, clipped to a gradient norm of 1.0 and applied. The
    recurrent state is detached at each update. Every ``args.val_per_n_epoch``
    epochs, rank 0 saves a checkpoint and all ranks evaluate their validation
    shard. Per-rank metric means are combined into a sample-count-weighted
    mean, which rank 0 logs to TensorBoard.

    Args:
        rank: Rank of this process (also its CUDA device index for DDP).
        world_size: Number of DDP processes.
        args: Parsed hyper-parameters (see ``get_hyperparameters``).
    """
    if rank == 0:
        writer = SummaryWriter()
    train_dataloader, val_dataloader = get_dataloader(rank, world_size, args.batch_size, False, args.n_workers)
    model = SequenceBaselineV1(args.M, args.num_pts, args.mtp_alpha, args.lr, args.optimizer, args.optimize_per_n_step)
    if args.sync_bn:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model = model.cuda()
    optimizer, lr_scheduler = model.configure_optimizers(args, model)

    if args.resume and rank == 0:
        # Only rank 0 loads the weights; DDP broadcasts rank 0's parameters to
        # the other ranks when it wraps the model below.
        print('Loading weights from', args.resume)
        model.load_state_dict(torch.load(args.resume, map_location='cpu'), strict=True)
    dist.barrier()
    model = nn.parallel.DistributedDataParallel(model, device_ids=[rank], find_unused_parameters=True, broadcast_buffers=False)

    loss = MultipleTrajectoryPredictionLoss(args.mtp_alpha, args.M, args.num_pts, distance_type='angle')
    num_steps = 0
    disable_tqdm = (not args.tqdm) or (rank != 0)

    for epoch in tqdm(range(args.epochs), disable=disable_tqdm, position=0):
        train_dataloader.sampler.set_epoch(epoch)
        for data in tqdm(train_dataloader, leave=False, disable=disable_tqdm, position=1):
            seq_inputs, seq_labels = data['seq_input_img'].cuda(), data['seq_future_poses'].cuda()
            bs = seq_labels.size(0)
            seq_length = seq_labels.size(1)
            hidden = torch.zeros((2, bs, 512)).cuda()
            total_loss = 0
            for t in tqdm(range(seq_length), leave=False, disable=disable_tqdm, position=2):
                num_steps += 1
                inputs, labels = seq_inputs[:, t, :, :, :], seq_labels[:, t, :, :]
                pred_cls, pred_trajectory, hidden = model(inputs, hidden)
                cls_loss, reg_loss = loss(pred_cls, pred_trajectory, labels)
                total_loss += (cls_loss + args.mtp_alpha * reg_loss.mean()) / model.module.optimize_per_n_step

                if rank == 0 and (num_steps + 1) % args.log_per_n_step == 0:
                    # TODO: add a customized log function
                    writer.add_scalar('train/epoch', epoch, num_steps)
                    writer.add_scalar('loss/cls', cls_loss, num_steps)
                    writer.add_scalar('loss/reg', reg_loss.mean(), num_steps)
                    writer.add_scalar('loss/reg_x', reg_loss[0], num_steps)
                    writer.add_scalar('loss/reg_y', reg_loss[1], num_steps)
                    writer.add_scalar('loss/reg_z', reg_loss[2], num_steps)
                    writer.add_scalar('param/lr', optimizer.param_groups[0]['lr'], num_steps)

                if (t + 1) % model.module.optimize_per_n_step == 0:
                    # Truncate back-propagation through time at this window.
                    hidden = hidden.clone().detach()
                    optimizer.zero_grad()
                    total_loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # TODO: move to args
                    optimizer.step()
                    if rank == 0:
                        writer.add_scalar('loss/total', total_loss, num_steps)
                    total_loss = 0

            # Apply the leftover steps when seq_length is not a multiple of
            # optimize_per_n_step (total_loss is still a tensor then).
            if not isinstance(total_loss, int):
                optimizer.zero_grad()
                total_loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # TODO: move to args
                optimizer.step()
                if rank == 0:
                    writer.add_scalar('loss/total', total_loss, num_steps)

        lr_scheduler.step()

        if (epoch + 1) % args.val_per_n_epoch == 0:
            if rank == 0:
                # save model
                ckpt_path = os.path.join(writer.log_dir, 'epoch_%d.pth' % epoch)
                torch.save(model.module.state_dict(), ckpt_path)
                print('[Epoch %d] checkpoint saved at %s' % (epoch, ckpt_path))

            model.eval()
            with torch.no_grad():
                saved_metric_epoch = get_val_metric_keys()
                for data in tqdm(val_dataloader, leave=False, disable=disable_tqdm, position=1):
                    seq_inputs, seq_labels = data['seq_input_img'].cuda(), data['seq_future_poses'].cuda()
                    bs = seq_labels.size(0)
                    seq_length = seq_labels.size(1)
                    hidden = torch.zeros((2, bs, 512), device=seq_inputs.device)
                    for t in tqdm(range(seq_length), leave=False, disable=True, position=2):
                        inputs, labels = seq_inputs[:, t, :, :, :], seq_labels[:, t, :, :]
                        pred_cls, pred_trajectory, hidden = model(inputs, hidden)
                        metrics = get_val_metric(pred_cls, pred_trajectory.view(-1, args.M, args.num_pts, 3), labels)
                        for k, v in metrics.items():
                            saved_metric_epoch[k].append(v.float().mean().item())

                dist.barrier()  # Wait for all processes

                # Sync metrics across ranks. Keys are sorted so that every
                # rank packs its values in the same order before all_gather.
                metric_keys = sorted(saved_metric_epoch.keys())
                n_keys = len(metric_keys)
                metric_single = torch.zeros((n_keys, ), dtype=torch.float32, device='cuda')
                counter_single = torch.zeros((n_keys, ), dtype=torch.int32, device='cuda')
                for i, k in enumerate(metric_keys):
                    values = saved_metric_epoch[k]
                    # An empty shard must contribute weight 0; np.mean([]) is
                    # NaN, and NaN * 0 would poison the weighted sum.
                    metric_single[i] = float(np.mean(values)) if values else 0.0
                    counter_single[i] = len(values)

                metric_gather = [torch.zeros((n_keys, ), dtype=torch.float32, device='cuda')[None] for _ in range(world_size)]
                counter_gather = [torch.zeros((n_keys, ), dtype=torch.int32, device='cuda')[None] for _ in range(world_size)]
                dist.all_gather(metric_gather, metric_single[None])
                dist.all_gather(counter_gather, counter_single[None])

                if rank == 0:
                    metric_gather = torch.cat(metric_gather, dim=0)  # [world_size, num_metric_keys]
                    counter_gather = torch.cat(counter_gather, dim=0)  # [world_size, num_metric_keys]
                    metric_gather_weighted_mean = (metric_gather * counter_gather).sum(0) / counter_gather.sum(0)
                    for i, k in enumerate(metric_keys):
                        writer.add_scalar(k, metric_gather_weighted_mean[i], num_steps)
                dist.barrier()
            model.train()

    if rank == 0:
        # Flush pending TensorBoard events before the process group goes away.
        writer.close()
    cleanup()
if __name__ == "__main__":
    # Rank and world size come from SLURM's SLURM_PROCID / SLURM_NTASKS.
    proc_id = os.environ['SLURM_PROCID']
    n_tasks = os.environ['SLURM_NTASKS']
    print('[%.2f]' % time.time(), 'starting job...', proc_id, 'of', n_tasks, flush=True)
    args = get_hyperparameters(ArgumentParser()).parse_args()
    rank, world_size = int(proc_id), int(n_tasks)
    setup(rank=rank, world_size=world_size)
    main(rank=rank, world_size=world_size, args=args)