-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathinit.lua
36 lines (33 loc) · 1.16 KB
/
init.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
require 'nn'
require 'nngraph'
-- Package namespace with a sub-table for shared helpers.
-- NOTE(review): deliberately global rather than `local`; presumably the
-- included LSTM.lua reads/extends `lstm` — confirm before localizing.
lstm = { utils = {} }
--- Make `target` use the same parameter tensors (same pointers) as `sharing`.
-- This is what makes recursive/recurrent networks work: several copies of a
-- unit can be made to train one shared set of weights.
--
-- * `nn.gModule` targets are walked over their `forwardnodes`. The i-th
--   forward node of `target` is paired positionally with the i-th forward
--   node of `sharing`, so both graphs are expected to have the same
--   structure. Nodes without a module are skipped. Nested gModules are
--   handled by recursion.
-- * Any other `nn.Module` target delegates to
--   `target:share(sharing, 'weight', 'bias', 'gradWeight', 'gradBias')`.
--
-- @param sharing module whose parameters are shared with `target`
-- @param target module whose parameters are redirected to `sharing`'s
-- @raise error if `target` is not an `nn.Module`, or if `target` is a
--   gModule and `sharing` has no forward node to pair with one of
--   `target`'s module-bearing nodes
function lstm.utils.share_parameters(sharing, target)
  if torch.isTypeOf(target, 'nn.gModule') then
    local target_nodes = target.forwardnodes
    for i = 1, #target_nodes do
      local node = target_nodes[i]
      -- Only nodes wrapping a module can carry parameters; other graph
      -- nodes are skipped without consulting `sharing` at all.
      if node.data.module then
        local sharing_nodes = sharing and sharing.forwardnodes
        if not sharing_nodes then
          error('cannot share parameters: sharing module has no forwardnodes'
                .. ' (expected an nn.gModule)')
        end
        local sharing_node = sharing_nodes[i]
        if not sharing_node then
          error(string.format(
            'cannot share parameters: sharing gModule has no forward node %d'
            .. ' (graphs differ in structure)', i))
        end
        lstm.utils.share_parameters(sharing_node.data.module, node.data.module)
      end
    end
  elseif torch.isTypeOf(target, 'nn.Module') then
    target:share(sharing, 'weight', 'bias', 'gradWeight', 'gradBias')
  else
    error('cannot share parameters of the argument type')
  end
end
-- Load the LSTM definition from LSTM.lua via Torch's global `include`
-- (presumably resolved relative to this file — confirm against torch docs).
include 'LSTM.lua'