-
Notifications
You must be signed in to change notification settings - Fork 16
/
Copy pathVQAperformance.py
42 lines (36 loc) · 1.21 KB
/
VQAperformance.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
from ignite.metrics.metric import Metric
import numpy as np
from scipy import stats
class VQAPerformance(Metric):
    """
    Evaluation of VQA methods using SROCC, KROCC, PLCC, RMSE.

    `update` must receive output of the form (y_pred, y), where ``y_pred``
    is indexable with three prediction heads (``y_pred[0]``, ``y_pred[1]``,
    ``y_pred[2]``) and ``y`` holds the ground-truth score. Each ``update``
    call records one scalar per head plus one ground-truth scalar.

    ``compute`` returns a dict with:
        SROCC: Spearman rank correlation between ground truth and ``y_pred[0]``.
        KROCC: Kendall rank correlation between ground truth and ``y_pred[0]``.
        PLCC:  Pearson linear correlation between ground truth and ``y_pred[1]``.
        RMSE:  root-mean-square error between ground truth and ``y_pred[2]``.
        sq, rq, mq, aq: the accumulated 1-D arrays of ground truth and of
            the three prediction heads, in update order.
    """

    def reset(self):
        """Clear all scores accumulated since the last reset."""
        # Per-head predictions. The names suggest relative / mapped / aligned
        # quality scores -- NOTE(review): confirm against the model's outputs.
        self._rq = []
        self._mq = []
        self._aq = []
        # Ground-truth scores.
        self._y = []

    def update(self, output):
        """Record one ``(y_pred, y)`` pair.

        NOTE(review): ``[0]`` followed by ``.item()`` keeps only the first
        element along dim 0 of each tensor. This presumably assumes a batch
        size of 1; with larger batches the remaining samples are ignored.
        Confirm with the evaluator's data loader.
        """
        y_pred, y = output
        self._y.append(y[0].item())
        self._rq.append(y_pred[0][0].item())
        self._mq.append(y_pred[1][0].item())
        self._aq.append(y_pred[2][0].item())

    def compute(self):
        """Compute SROCC/KROCC/PLCC/RMSE over everything recorded since reset.

        With fewer than two recorded samples the statistics are undefined;
        depending on the SciPy version, scipy may warn, return NaN, or raise
        (``pearsonr`` in particular).

        Returns:
            dict with keys 'SROCC', 'KROCC', 'PLCC', 'RMSE' (scalars) and
            'sq', 'rq', 'mq', 'aq' (1-D numpy arrays).
        """
        sq = np.asarray(self._y).reshape(-1)
        rq = np.asarray(self._rq).reshape(-1)
        mq = np.asarray(self._mq).reshape(-1)
        aq = np.asarray(self._aq).reshape(-1)

        srocc = stats.spearmanr(sq, rq)[0]
        # Use the public ``scipy.stats`` namespace. ``scipy.stats.stats`` is a
        # deprecated private module (DeprecationWarning since SciPy 1.8,
        # scheduled for removal).
        krocc = stats.kendalltau(sq, rq)[0]
        plcc = stats.pearsonr(sq, mq)[0]
        rmse = np.sqrt(np.mean((sq - aq) ** 2))

        return {'SROCC': srocc,
                'KROCC': krocc,
                'PLCC': plcc,
                'RMSE': rmse,
                'sq': sq,
                'rq': rq,
                'mq': mq,
                'aq': aq}