-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathioueval.py
83 lines (67 loc) · 2.84 KB
/
ioueval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
#!/usr/bin/env python3
# This file is covered by the LICENSE file in the root of this project.
import numpy as np
import torch
class iouEval:
    """Accumulate a confusion matrix over batches and report IoU / accuracy.

    The confusion matrix is indexed ``conf[pred, gt]``: rows are predictions,
    columns are ground truth. Classes listed in ``ignore`` are removed from the
    statistics (their row and column are zeroed) and excluded from the mean IoU.
    """

    def __init__(self, n_classes, device, ignore=None):
        """Create an evaluator with an empty confusion matrix.

        Args:
            n_classes: number of classes; labels passed to ``addBatch`` must be
                integers in ``[0, n_classes)``.
            device: torch device on which the confusion matrix is kept.
            ignore: a class index, an iterable of class indices, or None.
                None means no class is ignored. Indices outside
                ``[0, n_classes)`` are dropped, i.e. treated as "no ignore
                index" rather than raising or wrapping around.
        """
        self.n_classes = n_classes
        self.device = device
        # torch.tensor(None) raises, so None is normalised to "ignore nothing".
        # Out-of-range indices are filtered out: a large one would raise in
        # getStats and a negative one would wrap around to the last classes.
        requested = (
            [] if ignore is None else torch.as_tensor(ignore).reshape(-1).tolist()
        )
        valid = [int(i) for i in requested if 0 <= int(i) < n_classes]
        self.ignore = torch.tensor(valid, dtype=torch.long)
        ignored = set(valid)
        self.include = torch.tensor(
            [n for n in range(self.n_classes) if n not in ignored], dtype=torch.long
        )
        self.reset()

    def num_classes(self):
        """Return the number of classes this evaluator was built for."""
        return self.n_classes

    def reset(self):
        """Clear the accumulated confusion matrix and the cached ones buffer."""
        self.conf_matrix = torch.zeros(
            (self.n_classes, self.n_classes), dtype=torch.long, device=self.device
        )
        self.ones = None
        self.last_scan_size = None  # for when variable scan size is used

    def _to_long_tensor(self, a):
        """Return ``a`` (numpy array, tensor or sequence) as a long tensor on ``self.device``."""
        if isinstance(a, np.ndarray):
            # Cast in numpy first: torch.from_numpy cannot map every numpy
            # integer dtype, and np.array always returns a fresh copy.
            a = torch.from_numpy(np.array(a, dtype=np.int64))
        return torch.as_tensor(a).long().to(self.device)

    def addBatch(self, x, y):  # x=preds, y=targets
        """Accumulate predictions ``x`` against ground truth ``y``.

        Args:
            x: predicted labels (numpy array or tensor), e.g. "batch x H x W";
                it is flattened, so any shape works.
            y: ground-truth labels with the same number of elements as ``x``.
        """
        x_row = self._to_long_tensor(x).reshape(-1)  # de-batchify
        y_row = self._to_long_tensor(y).reshape(-1)  # de-batchify
        # idxs[0] are predictions (rows), idxs[1] are labels (cols)
        idxs = torch.stack([x_row, y_row], dim=0)
        n_points = idxs.shape[-1]
        # Buffer of ones added at every (pred, gt) position; it is reused as
        # long as consecutive batches contain the same number of points.
        if self.ones is None or self.last_scan_size != n_points:
            self.ones = torch.ones(n_points, dtype=torch.long, device=self.device)
            self.last_scan_size = n_points
        # accumulate=True makes repeated (pred, gt) pairs add up instead of overwrite
        self.conf_matrix.index_put_(tuple(idxs), self.ones, accumulate=True)

    def getStats(self):
        """Return per-class ``(tp, fp, fn)`` as double tensors, ignoring ignored classes."""
        # remove fp and fn from confusion on the ignore classes cols and rows
        conf = self.conf_matrix.clone().double()
        conf[self.ignore] = 0
        conf[:, self.ignore] = 0
        tp = conf.diag()
        fp = conf.sum(dim=1) - tp  # predicted as class c, but gt differs
        fn = conf.sum(dim=0) - tp  # gt is class c, but predicted differently
        return tp, fp, fn

    def getIoU(self):
        """Return ``(mean IoU over included classes, per-class IoU for ALL classes)``."""
        tp, fp, fn = self.getStats()
        intersection = tp
        union = tp + fp + fn + 1e-15  # epsilon avoids 0/0 for absent classes
        iou = intersection / union
        iou_mean = iou[self.include].mean()
        return iou_mean, iou  # returns "iou mean", "iou per class" ALL CLASSES

    def getacc(self):
        """Return overall accuracy ``sum(tp) / (sum(tp) + sum(fp))`` over included classes."""
        tp, fp, fn = self.getStats()
        total_tp = tp.sum()
        total = tp[self.include].sum() + fp[self.include].sum() + 1e-15
        acc_mean = total_tp / total
        return acc_mean  # returns "acc mean"