# mlp_backprop.py
# Gradient of a small MLP's loss computed three ways: PyTorch autograd,
# an nn.Module model, and a hand-written backward pass.
import torch
from torch import nn
def sigma(x):
    """Apply the logistic sigmoid elementwise to ``x`` via ``torch.sigmoid``."""
    activated = torch.sigmoid(x)
    return activated
def sigma_prime(x):
    """Return ``sigma(x) * (1 - sigma(x))``, the derivative of ``sigma`` at ``x``."""
    s = sigma(x)
    return s * (1 - s)
# Fix the RNG seed so the random data and parameters below are reproducible.
torch.manual_seed(0)

L = 6  # number of (A, b) layers built below
X_data = torch.rand(4, 1)
Y_data = torch.rand(1, 1)

# NOTE: the weight and the bias of each layer are drawn alternately, layer by
# layer.  Keep this order: drawing all weights first would consume the seeded
# RNG stream differently and change every value.
A_list = []
b_list = []
for _out_dim in [4] * (L - 1) + [1]:  # L-1 layers of width 4, then a 1-row output layer
    A_list.append(torch.rand(_out_dim, 4))
    b_list.append(torch.rand(_out_dim, 1))
'''
# Option 1: directly use PyTorch's autograd feature
for A in A_list:
    A.requires_grad = True
for b in b_list:
    b.requires_grad = True
y = X_data
for ell in range(L):
    S = sigma if ell<L-1 else lambda x: x
    y = S(A_list[ell]@y+b_list[ell])
# backward pass in pytorch
loss=torch.square(y-Y_data)/2
loss.backward()
print(A_list[0].grad)
'''
'''
# Option 2: construct a NN model and use backprop
class MLP(nn.Module) :
    def __init__(self) :
        super().__init__()
        self.linear = nn.ModuleList([nn.Linear(4,4) for _ in range(L-1)])
        self.linear.append(nn.Linear(4,1))
        for ell in range(L):
            self.linear[ell].weight.data = A_list[ell]
            self.linear[ell].bias.data = b_list[ell].squeeze()
    def forward(self, x) :
        x = x.squeeze()
        for ell in range(L-1):
            x = sigma(self.linear[ell](x))
        x = self.linear[-1](x)
        return x
model = MLP()
loss = torch.square(model(X_data)-Y_data)/2
loss.backward()
print(model.linear[0].weight.grad)
'''
# Option 3: implement backprop yourself
#
# Forward pass.  y_list[ell] is the input fed to layer ell and y_list[ell + 1]
# is that layer's output; every hidden layer applies sigma, the last layer is
# purely affine (identity activation).
y_list = [X_data]
y = X_data
for ell in range(L):
    S = sigma if ell < L-1 else lambda x: x
    y = S(A_list[ell]@y + b_list[ell])
    y_list.append(y)

# Backward pass for loss = (y_L - Y_data)^2 / 2.  Every gradient is stored with
# the same shape as the quantity it differentiates, so dA_list[ell] matches
# A_list[ell] and db_list[ell] matches b_list[ell].
dA_list = []
db_list = []
dy = y - Y_data  # dloss/dy_L
for ell in reversed(range(L)):
    # Derivative of the activation used by layer ell in the forward pass.
    S = sigma_prime if ell < L-1 else lambda x: torch.ones_like(x)
    A, b, y = A_list[ell], b_list[ell], y_list[ell]  # y is the layer's input
    # Chain rule through the activation: gradient w.r.t. the pre-activation
    # A@y + b, which is also the gradient w.r.t. the bias.
    db = dy * S(A@y + b)  # dloss/db_ell
    # Outer product of the pre-activation gradient with the layer input.
    dA = db @ y.T         # dloss/dA_ell
    # Propagate to the layer input; it becomes the upstream gradient of the
    # next (earlier) layer handled by this loop.
    dy = A.T @ db         # dloss/dy_{ell-1}
    # Iterating from the last layer to the first, so prepend to keep the
    # lists indexed by layer.
    dA_list.insert(0, dA)
    db_list.insert(0, db)
print(dA_list[0])