From db97b78a2c25bd5cddbbd69a2bca0cc1c449d767 Mon Sep 17 00:00:00 2001 From: Aiden Nibali Date: Sun, 30 Aug 2015 14:47:16 +1000 Subject: [PATCH] Calculate GRU input sums together --- model/GRU.lua | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/model/GRU.lua b/model/GRU.lua index 11ae34e3..74f1f51e 100644 --- a/model/GRU.lua +++ b/model/GRU.lua @@ -6,7 +6,8 @@ Creates one timestep of one GRU Paper reference: http://arxiv.org/pdf/1412.3555v1.pdf ]]-- function GRU.gru(input_size, rnn_size, n, dropout) - dropout = dropout or 0 + dropout = dropout or 0 + -- there are n+1 inputs (hiddens on each layer and x) local inputs = {} table.insert(inputs, nn.Identity()()) -- x @@ -14,30 +15,30 @@ function GRU.gru(input_size, rnn_size, n, dropout) table.insert(inputs, nn.Identity()()) -- prev_h[L] end - function new_input_sum(insize, xv, hv) - local i2h = nn.Linear(insize, rnn_size)(xv) - local h2h = nn.Linear(rnn_size, rnn_size)(hv) - return nn.CAddTable()({i2h, h2h}) - end - local x, input_size_L local outputs = {} for L = 1,n do local prev_h = inputs[L+1] -- the input to this layer - if L == 1 then + if L == 1 then x = OneHot(input_size)(inputs[1]) input_size_L = input_size - else - x = outputs[(L-1)] + else + x = outputs[(L-1)] if dropout > 0 then x = nn.Dropout(dropout)(x) end -- apply dropout, if any input_size_L = rnn_size end -- GRU tick + -- evaluate the input sums at once for efficiency + local i2h = nn.Linear(input_size_L, 2 * rnn_size)(x) + local h2h = nn.Linear(rnn_size, 2 * rnn_size)(prev_h) + local all_input_sums = nn.CAddTable()({i2h, h2h}) + local reshaped = nn.Reshape(2, rnn_size)(all_input_sums) + local n1, n2 = nn.SplitTable(2)(reshaped):split(2) -- forward the update and reset gates - local update_gate = nn.Sigmoid()(new_input_sum(input_size_L, x, prev_h)) - local reset_gate = nn.Sigmoid()(new_input_sum(input_size_L, x, prev_h)) + local update_gate = nn.Sigmoid()(n1) + local reset_gate = nn.Sigmoid()(n2) -- compute candidate hidden state local gated_hidden = nn.CMulTable()({reset_gate, prev_h}) local p2 = nn.Linear(rnn_size, rnn_size)(gated_hidden) @@ -50,7 +51,8 @@ function GRU.gru(input_size, rnn_size, n, dropout) table.insert(outputs, next_h) end --- set up the decoder + + -- set up the decoder local top_h = outputs[#outputs] if dropout > 0 then top_h = nn.Dropout(dropout)(top_h) end local proj = nn.Linear(rnn_size, input_size)(top_h)