net.py
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable as V


class ModifiableModule(nn.Module):
    """A module whose parameters can be swapped out for arbitrary Variables,
    which makes functional-style (e.g. MAML) parameter updates possible."""

    def params(self):
        return [p for _, p in self.named_params()]

    def named_leaves(self):
        return []

    def named_submodules(self):
        return []

    def named_params(self):
        # Collect this module's own parameters plus those of its submodules,
        # using dotted names such as 'hidden1.weights'.
        subparams = []
        for name, mod in self.named_submodules():
            for subname, param in mod.named_params():
                subparams.append((name + '.' + subname, param))
        return self.named_leaves() + subparams

    def set_param(self, name, param):
        if '.' in name:
            # Route the assignment to the matching submodule.
            module_name, rest = name.split('.', 1)
            for mod_name, mod in self.named_submodules():
                if module_name == mod_name:
                    mod.set_param(rest, param)
                    break
        else:
            setattr(self, name, param)

    def copy(self, other, same_var=False):
        # Copy parameters from another module; with same_var=True the two
        # modules share the same Variables instead of cloned data.
        for name, param in other.named_params():
            if not same_var:
                param = V(param.data.clone(), requires_grad=True)
            self.set_param(name, param)


class GradLinear(ModifiableModule):
    def __init__(self, *args, **kwargs):
        super().__init__()
        # Use a throwaway nn.Linear only to get its default initialization.
        ignore = nn.Linear(*args, **kwargs)
        self.weights = V(ignore.weight.data, requires_grad=True)
        self.bias = V(ignore.bias.data, requires_grad=True)

    def forward(self, x):
        return F.linear(x, self.weights, self.bias)

    def named_leaves(self):
        return [('weights', self.weights), ('bias', self.bias)]


class SineModel(ModifiableModule):
    def __init__(self):
        super().__init__()
        self.hidden1 = GradLinear(1, 40)
        self.hidden2 = GradLinear(40, 40)
        self.out = GradLinear(40, 1)

    def forward(self, x):
        x = F.relu(self.hidden1(x))
        x = F.relu(self.hidden2(x))
        return self.out(x)

    def named_submodules(self):
        return [('hidden1', self.hidden1), ('hidden2', self.hidden2), ('out', self.out)]
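

# --- Illustrative usage sketch (not part of the original file) ---
# A minimal MAML-style inner update, assuming this is how the classes above are
# meant to be combined; the names below (base_model, learner, inner_lr) are
# hypothetical and chosen only for the example.
if __name__ == '__main__':
    import torch

    base_model = SineModel()
    learner = SineModel()
    # Share the base model's Variables so the inner step stays differentiable
    # with respect to the base parameters.
    learner.copy(base_model, same_var=True)

    x = V(torch.linspace(-5, 5, 10).view(-1, 1))
    y = V(torch.sin(x.data))
    loss = F.mse_loss(learner(x), y)

    # One manual gradient step: replace each parameter with its updated value.
    inner_lr = 0.01
    grads = torch.autograd.grad(loss, learner.params(), create_graph=True)
    for (name, param), grad in zip(learner.named_params(), grads):
        learner.set_param(name, param - inner_lr * grad)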