-- UnsupTrainer.lua
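-- UnsupTrainer drives unsupervised training of a module over a dataset.
-- It runs plain stochastic gradient descent, optionally scaling each step
-- by an estimate of the diagonal of the Hessian (stochastic diagonal
-- Levenberg-Marquardt style), re-estimated every `hessianinterval` steps.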
local UnsupTrainer = torch.class('unsup.UnsupTrainer')

function UnsupTrainer:__init(module, data)
   -- flatten the module's parameters, gradients and diag-Hessian storage
   local x,dx,ddx = module:getParameters()
   self.parameters = {x,dx,ddx}
   if not self.parameters or #self.parameters == 0 then
      error('could not get parameters from module')
   end
   self.module = module
   self.data = data
end

function UnsupTrainer:train(params)
   -- essential stuff
   local data = self.data
   local eta = params.eta
   local etadecay = params.etadecay or 0
   local maxiter = params.maxiter
   local statinterval = params.statinterval or math.ceil(maxiter/100)
   local etadecayinterval = params.etadecayinterval or statinterval

   -- optional diagonal-Hessian stuff
   local dohessian = params.hessian or false
   local hessianinterval = params.hessianinterval or statinterval
   if not dohessian then self.parameters[3] = nil end

   local age = 1
   local err = 0
   while age <= maxiter do
      -- HESSIAN: periodically re-estimate the diagonal Hessian
      if dohessian and (age-1) % hessianinterval == 0 then
         print('Computing Hessian')
         params.di = age
         self:computeDiagHessian(params)
         print('done')
      end
      -- DATA: fetch the next sample {input, target}
      local ex = data[age]
      -- SGD UPDATE on this single sample
      local sres = self:trainSample(ex, eta)
      local serr = sres[1]
      err = err + serr
      -- HOOK SAMPLE: optional per-sample user callback
      if self.hookSample then self.hookSample(self, age, ex, sres) end
      if age % statinterval == 0 then
         -- HOOK EPOCH: optional per-interval user callback
         if self.hookEpoch then self.hookEpoch(self, age/statinterval) end
         print('# iter= ' .. age .. ' eta= ' .. eta .. ' current error= ' .. err)
         -- ETA DECAY
         eta = params.eta/(1 + (age/etadecayinterval)*etadecay)
         err = 0
      end
      age = age + 1
   end
end
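
-- Note on the decay schedule: eta is recomputed from params.eta each time,
-- so earlier decay steps do not compound:
--    eta_t = eta_0 / (1 + (t/etadecayinterval) * etadecay)
-- For example (illustrative values, not from this file): with eta_0 = 1e-3,
-- etadecayinterval = 1000 and etadecay = 0.1, after t = 10000 iterations
-- eta = 1e-3 / (1 + 10 * 0.1) = 5e-4.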

function UnsupTrainer:computeDiagHessian(params)
   local hessiansamples = params.hessiansamples or 500
   local minhessian = params.minhessian or 0.02
   local maxhessian = params.maxhessian or 1/minhessian
   local di = params.di
   print('Min Hessian=' .. minhessian .. ' Max Hessian=' .. maxhessian)

   local parameters = self.parameters
   local data = self.data
   local module = self.module
   local x = parameters[1]
   local dx = parameters[2]
   local ddx = parameters[3]

   -- accumulate per-sample diag-Hessian estimates into ddeltax
   local knew = 1/hessiansamples
   local kold = 1
   self.ddeltax = self.ddeltax or ddx.new():resizeAs(ddx)
   local ddeltax = self.ddeltax
   ddeltax:zero()

   for i = 1,hessiansamples do
      local ex = data[di+i]
      local input = ex[1]
      local target = ex[2]
      module:updateOutput(input, target)
      -- gradient
      dx:zero()
      module:updateGradInput(input, target)
      module:accGradParameters(input, target)
      -- hessian
      ddx:zero()
      module:updateDiagHessianInput(input, target)
      module:accDiagHessianParameters(input, target)
      if ddx:min() < 0 then
         error('Negative ddx')
      end
      ddeltax:mul(kold)
      ddeltax:add(knew, ddx)
   end

   print('ddeltax : min/max = ' .. ddeltax:min() .. '/' .. ddeltax:max())
   -- clamp the estimate to [minhessian, maxhessian] so the scaled SGD step
   -- stays bounded even where the curvature estimate is tiny or huge
   ddeltax[torch.lt(ddeltax,minhessian)] = minhessian
   ddeltax[torch.gt(ddeltax,maxhessian)] = maxhessian
   print('ddeltax : min/max = ' .. ddeltax:min() .. '/' .. ddeltax:max())
   ddx:copy(ddeltax)
end
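
-- Note: with kold = 1 and knew = 1/hessiansamples, the loop above computes
-- ddeltax = (1/N) * sum_i ddx_i over the N samples, i.e. the sample mean of
-- the per-example diagonal Hessian estimates, which is then clamped and
-- copied into ddx for use by trainSample.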

function UnsupTrainer:trainSample(ex, eta)
   local module = self.module
   local parameters = self.parameters
   local input = ex[1]
   local target = ex[2]
   local x = parameters[1]
   local dx = parameters[2]
   local ddx = parameters[3]

   -- forward pass; res[1] is the sample loss
   local res = {module:updateOutput(input, target)}

   -- backward pass: clear derivatives, then accumulate gradients
   dx:zero()
   module:updateGradInput(input, target)
   module:accGradParameters(input, target)

   -- sanity checks on the gradient
   if dx:max() > 100 or dx:min() < -100 then
      print('oops, large dx: ' .. dx:max() .. ' ' .. dx:min())
   end
   if torch.ne(dx,dx):sum() > 0 then
      error('nan in dx')
   end

   -- parameter update
   if not ddx then
      -- plain SGD: x <- x - eta * dx
      x:add(-eta, dx)
   else
      -- diagonal-Hessian scaled step: x <- x - eta * dx / ddx
      x:addcdiv(-eta, dx, ddx)
   end
   if torch.ne(x,x):sum() > 0 then
      error('nan in x after update')
   end

   -- project parameters back onto the model's constraint set
   -- (e.g. re-normalize decoder dictionary columns)
   module:normalize()
   if torch.ne(x,x):sum() > 0 then
      error('nan in x after normalize')
   end
   return res
end
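
-- Minimal usage sketch (an assumption, not part of the original file):
-- `module` stands for any unsup model that implements updateOutput /
-- updateGradInput / accGradParameters, normalize(), and the diag-Hessian
-- methods when `hessian = true`; `data` is anything indexable as
-- data[i] = {input, target}. Hyper-parameter values are illustrative.
--
--   require 'unsup'
--   local trainer = unsup.UnsupTrainer(module, data)
--   trainer:train{eta = 2e-3,             -- initial learning rate
--                 etadecay = 1e-5,        -- decay coefficient
--                 maxiter = 100000,       -- total SGD steps
--                 statinterval = 5000,    -- logging/decay interval
--                 hessian = true,         -- enable diag-Hessian scaling
--                 hessianinterval = 10000,
--                 hessiansamples = 500,
--                 minhessian = 0.02}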