infer_cam.py
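# Multi-scale CAM inference: load a trained classification network and save a
# class activation map file (.npy) for every image in the configured split,
# using one worker process per GPU.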
import argparse
import os
from datetime import datetime

TIMESTAMP = "{0:%Y-%m-%d-%H-%M-%S/}".format(datetime.now())

parser = argparse.ArgumentParser()
parser.add_argument("--gpu", default='0,1,2', type=str, help="gpu")
parser.add_argument("--config", default='configs/voc.yaml', type=str, help="config")
args = parser.parse_args()
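# CUDA_VISIBLE_DEVICES must be set before torch initializes CUDA, which is
# why torch is imported only after --gpu has been parsed.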
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

import torch
import torch.nn.functional as F
from omegaconf import OmegaConf
from torch import multiprocessing
from tqdm import tqdm
import numpy as np
from collections import OrderedDict

from dataset import voc
from network import resnet_cam
from utils import imutils


def makedirs(path):
    if not os.path.exists(path):
        os.makedirs(path)
    return True

def _infer_cam(pid, model=None, dataset=None, config=None):
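    """Per-GPU worker: run multi-scale, flip-merged CAM inference over this
    process's dataset split and save one .npy file per image."""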
    data_loader = torch.utils.data.DataLoader(dataset[pid], batch_size=1, shuffle=False,
                                              num_workers=2, pin_memory=False)
    model.eval()

    cam_dir = os.path.join(config.exp.backbone, config.exp.cam_dir)
    makedirs(cam_dir)

    with torch.no_grad(), torch.cuda.device(pid):
        model.cuda()
        for _, data in tqdm(enumerate(data_loader), total=len(data_loader), ncols=100):
            img_name, input_list, labels, size_ = data['name'], data['img'], data['label'], data['size']

            img_size = input_list[0].shape[-2:]
            strided_size = imutils.get_strided_size(size_, 4)

            # For each scale, merge the CAMs of the image and its horizontal
            # flip by taking the element-wise maximum.
            cam_list = []
            for inputs in input_list:
                inputs = inputs[0].cuda(non_blocking=True)
                _, cams = model(inputs, return_cam=True)
                cams_ = torch.max(cams[0], cams[1].flip(-1))
                cam_list.append(cams_)

            # Sum the multi-scale CAMs at stride-4 resolution and at full
            # image resolution.
            strided_cam = torch.sum(torch.stack(
                [F.interpolate(torch.unsqueeze(cam, 0), strided_size, mode='bilinear',
                               align_corners=False)[0] for cam in cam_list]), 0)

            resized_cam_list = [F.interpolate(cam.unsqueeze(0), img_size, mode='bilinear',
                                              align_corners=False)[0] for cam in cam_list]
            out_cam = torch.sum(torch.stack(resized_cam_list, dim=0), dim=0)

            # Keep only the classes present in the image-level label, then
            # normalize each map to [0, 1] by its maximum.
            valid_label = torch.nonzero(labels[0])[:, 0]

            strided_cam = strided_cam[valid_label]
            strided_cam /= F.adaptive_max_pool2d(strided_cam, (1, 1)) + 1e-5

            high_res_cam = out_cam[valid_label, :, :]
            high_res_cam /= F.adaptive_max_pool2d(high_res_cam, (1, 1)) + 1e-5

            np.save(os.path.join(cam_dir, img_name[0] + '.npy'),
                    {"keys": valid_label.cpu().numpy(),
                     "cam": strided_cam.cpu().numpy(),
                     "high_res": high_res_cam.cpu().numpy()})
    return None
def main(config=None):
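    """Load the trained classifier, split the dataset across the available
    GPUs, and spawn one CAM-inference worker per GPU."""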
    infer_dataset = voc.VOC12ClassificationDatasetMSF(voc12_root=config.dataset.root_dir,
                                                      img_name_list_path=config.cam.split,
                                                      scales=config.cam.scales)

    # Split the dataset round-robin across the available GPUs.
    n_gpus = torch.cuda.device_count()
    split_dataset = [torch.utils.data.Subset(infer_dataset, np.arange(i, len(infer_dataset), n_gpus))
                     for i in range(n_gpus)]

    # Build the model and load the trained weights, stripping the 'module.'
    # prefix that DataParallel training prepends to parameter names.
    model = resnet_cam.Net(n_classes=config.dataset.n_classes, backbone=config.exp.backbone)
    model_path = os.path.join(config.exp.backbone, config.exp.checkpoint_dir, config.exp.final_weights)

    state_dict = torch.load(model_path)
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        k = k.replace('module.', '')
        new_state_dict[k] = v
    model.load_state_dict(state_dict=new_state_dict, strict=True)
    model.eval()

    print('Inferring...')
    makedirs(os.path.join(config.exp.backbone, config.exp.cam_dir))

    # One inference process per GPU; each handles its own dataset split.
    multiprocessing.spawn(_infer_cam, nprocs=n_gpus, args=(model, split_dataset, config), join=True)
    torch.cuda.empty_cache()
    return True


if __name__ == "__main__":
    config = OmegaConf.load(args.config)
    main(config)
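
# A minimal sketch of how the saved files could be read back (hypothetical
# usage, not part of this script). np.save pickles the dict, so
# allow_pickle=True and .item() are needed to recover it:
#
#   cam_dict = np.load(os.path.join(cam_dir, name + '.npy'), allow_pickle=True).item()
#   keys = cam_dict['keys']          # indices of classes present in the image
#   cams = cam_dict['cam']           # stride-4 CAMs, one per present class
#   high_res = cam_dict['high_res']  # full-resolution CAMs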