LRP_linear_layer.py
'''
@author: Leila Arras
@maintainer: Leila Arras
@date: 21.06.2017
@version: 1.0+
@copyright: Copyright (c) 2017, Leila Arras, Gregoire Montavon, Klaus-Robert Mueller, Wojciech Samek
@license: see LICENSE file in repository root
'''
import numpy as np
from numpy import newaxis as na
def lrp_linear(hin, w, b, hout, Rout, bias_nb_units, eps, bias_factor=0.0, debug=False):
"""
LRP for a linear layer with input dim D and output dim M.
Args:
- hin: forward pass input, of shape (D,)
- w: connection weights, of shape (D, M)
- b: biases, of shape (M,)
- hout: forward pass output, of shape (M,) (unequal to np.dot(w.T,hin)+b if more than one incoming layer!)
- Rout: relevance at layer output, of shape (M,)
- bias_nb_units: total number of connected lower-layer units (onto which the bias/stabilizer contribution is redistributed for sanity check)
- eps: stabilizer (small positive number)
- bias_factor: set to 1.0 to check global relevance conservation, otherwise use 0.0 to ignore bias/stabilizer redistribution (recommended)
Returns:
- Rin: relevance at layer input, of shape (D,)
"""
    sign_out = np.where(hout[na, :] >= 0, 1., -1.)  # shape (1, M)
    # print(sign_out)

    # numerator (Eq. 13)
    numer = (w * hin[:, na]) + (bias_factor * (b[na, :] * 1. + eps * sign_out * 1.) / bias_nb_units)  # shape (D, M)
    # Note: here we multiply the bias_factor with both the bias b and the stabilizer eps, since in fact
    # using the term (b[na,:]*1. + eps*sign_out*1.) / bias_nb_units in the numerator is only useful for the sanity check
    # (in the initial paper version we were using (bias_factor*b[na,:]*1. + eps*sign_out*1.) / bias_nb_units instead)

    # denominator (Eq. 13)
    denom = hout[na, :] + (eps * sign_out * 1.)  # shape (1, M)

    message = (numer / denom) * Rout[na, :]  # shape (D, M)

    Rin = message.sum(axis=1)  # shape (D,)
    if debug:
        print("local diff: ", Rout.sum() - Rin.sum())
    # Note:
    # - local layer relevance conservation holds if bias_factor == 1.0 and bias_nb_units == D (i.e. when there is only one incoming layer)
    # - global network relevance conservation holds if bias_factor == 1.0 and bias_nb_units is set to the total number of lower-layer connections
    # -> can be used as a sanity check

    return Rin
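

# Minimal usage sketch (not part of the original module): random toy data for a
# single dense layer, so that hout == np.dot(w.T, hin) + b and bias_nb_units == D.
# With bias_factor=1.0 the relevance is conserved exactly (the "local diff"
# printed by debug=True is ~0); the recommended bias_factor=0.0 ignores the
# bias/stabilizer share. Array shapes follow the docstring above.
if __name__ == "__main__":
    np.random.seed(0)
    D, M = 4, 3                              # input / output dimensions
    hin  = np.random.randn(D)                # forward pass input
    w    = np.random.randn(D, M)             # connection weights
    b    = np.random.randn(M)                # biases
    hout = np.dot(w.T, hin) + b              # forward pass output
    Rout = hout.copy()                       # relevance arriving at the layer output

    # sanity check: bias/stabilizer redistributed onto the D input units
    Rin_check = lrp_linear(hin, w, b, hout, Rout, bias_nb_units=D, eps=0.001,
                           bias_factor=1.0, debug=True)

    # recommended setting: bias/stabilizer redistribution ignored
    Rin = lrp_linear(hin, w, b, hout, Rout, bias_nb_units=D, eps=0.001,
                     bias_factor=0.0)
    print("Rin: ", Rin, " sum: ", Rin.sum(), " Rout sum: ", Rout.sum())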