forked from uber-research/LaneGCN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
118 lines (102 loc) · 3.8 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
# ---------------------------------------------------------------------------
# Learning Lane Graph Representations for Motion Forecasting
#
# Copyright (c) 2020 Uber Technologies, Inc.
#
# Licensed under the Uber Non-Commercial License (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at the root directory of this project.
#
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Written by Ming Liang, Yun Chen
# ---------------------------------------------------------------------------
import argparse
import os
# Clear the process umask so files and directories created during this run
# keep exactly the permission bits requested when they are created.
os.umask(0)
# Cap MKL, NumExpr and OpenMP at one thread each. These assignments sit ahead
# of the torch import below -- presumably so the native libraries pick the
# values up when they initialise; keep them before the heavy imports.
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"
import pickle
import sys
from importlib import import_module
import torch
from torch.utils.data import DataLoader, Sampler
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm
# `data` and `utils` are presumably project-local modules that live next to
# this script -- confirm against the repository layout.
from data import ArgoTestDataset
from utils import Logger, load_pretrain
# Put this file's directory at the front of sys.path.
# NOTE(review): this runs after the `data` / `utils` imports above, so it only
# affects imports performed later (e.g. import_module(args.model) in main()).
root_path = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, root_path)
# Command-line interface of the evaluation / submission script.
parser = argparse.ArgumentParser(
    description="Argoverse Motion Forecasting in Pytorch"
)

# Name of the Python module to import; main() calls its get_model().
parser.add_argument(
    "-m",
    "--model",
    type=str,
    default="angle90",
    metavar="MODEL",
    help="model name",
)

# Already True by default, so passing --eval changes nothing.
parser.add_argument("--eval", default=True, action="store_true")

# Dataset split to run on; main() computes metrics only for "val".
parser.add_argument(
    "--split",
    default="val",
    type=str,
    help='data split, "val" or "test"',
)

# Checkpoint file; main() resolves relative paths against config["save_dir"].
parser.add_argument(
    "--weight",
    type=str,
    default="",
    metavar="WEIGHT",
    help="checkpoint path",
)
def main():
    """Run a trained model over an Argoverse split, then score or export it.

    Steps:

    1. Parse the command line with the module-level ``parser``.
    2. Import the module named by ``--model`` and build the network through
       its ``get_model()`` factory.
    3. Load the checkpoint given by ``--weight``; a relative path is resolved
       against ``config["save_dir"]``.
    4. Run inference over ``ArgoTestDataset(args.split, ...)`` and collect,
       per ``argo_id``, the prediction, the city and (when the batch carries
       ``"gt_preds"``) the ground truth.
    5. For ``--split val`` call ``compute_forecasting_metrics`` twice (max
       guesses 6 and 1); for any other split write
       ``<save_dir>/submit.h5`` via ``generate_forecasting_h5``.

    Exits through ``parser.error`` (status 2) when ``--weight`` is empty.
    """
    args = parser.parse_args()
    if not args.weight:
        # With the default "" the checkpoint path below would resolve to the
        # save directory itself and torch.load would fail with an error that
        # does not point at the missing option.
        parser.error("--weight is required: pass the checkpoint to evaluate")

    # Only the config, collate function and network are needed here; the
    # remaining items returned by get_model() are ignored.
    model = import_module(args.model)
    config, _, collate_fn, net, _, _, _ = model.get_model()

    # Load the pretrained weights; the map_location callable returns each
    # storage unchanged instead of moving it to the device it was saved from.
    ckpt_path = args.weight
    if not os.path.isabs(ckpt_path):
        ckpt_path = os.path.join(config["save_dir"], ckpt_path)
    ckpt = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
    load_pretrain(net, ckpt["state_dict"])
    net.eval()

    # Data loader for evaluation. Results are keyed by argo_id, so sample
    # order does not affect them; shuffling is disabled to keep runs
    # reproducible.
    dataset = ArgoTestDataset(args.split, config, train=False)
    data_loader = DataLoader(
        dataset,
        batch_size=config["val_batch_size"],
        num_workers=config["val_workers"],
        collate_fn=collate_fn,
        shuffle=False,
        pin_memory=True,
    )

    # Inference. Wrapping the loader directly (rather than enumerate(loader))
    # lets tqdm read its length and show a total / ETA.
    preds = {}
    gts = {}
    cities = {}
    for data in tqdm(data_loader):
        data = dict(data)
        with torch.no_grad():
            output = net(data)
            # Keep only the first row of every element of output["reg"]
            # (presumably the target agent -- confirm against the model).
            results = [x[0:1].detach().cpu().numpy() for x in output["reg"]]
        has_gt = "gt_preds" in data
        for i, (argo_idx, pred_traj) in enumerate(zip(data["argo_id"], results)):
            preds[argo_idx] = pred_traj.squeeze()
            cities[argo_idx] = data["city"][i]
            gts[argo_idx] = data["gt_preds"][i][0] if has_gt else None

    # Evaluate or submit.
    if args.split == "val":
        # Validation split: compute metrics.
        from argoverse.evaluation.eval_forecasting import (
            compute_forecasting_metrics,
        )

        # Max #guesses (K): 6
        compute_forecasting_metrics(preds, gts, cities, 6, 30, 2)
        # Max #guesses (K): 1
        compute_forecasting_metrics(preds, gts, cities, 1, 30, 2)
    else:
        # Test split: save as h5 for submission to the evaluation server.
        from argoverse.evaluation.competition_util import generate_forecasting_h5

        # This might take a while.
        generate_forecasting_h5(preds, f"{config['save_dir']}/submit.h5")


if __name__ == "__main__":
    main()