-- LanguageModel.lua
require 'torch'
require 'nn'
-- NOTE(review): the VanillaRNN require below is disabled, yet LM:__init still
-- calls nn.VanillaRNN when model_type == 'rnn'. Unless something else loads
-- that class, the 'rnn' model type presumably cannot be constructed.
-- require 'VanillaRNN'
require 'GRU'
require 'GRIDGRU'
require 'LSTM'
local utils = require 'util.utils'
-- Declare the torch class nn.LanguageModel with nn.Module as its parent.
-- `LM` receives the methods defined below; `parent` is used to chain to the
-- parent class's training() and evaluate().
local LM, parent = torch.class('nn.LanguageModel', 'nn.Module')
-- Construct one recurrent layer for `model_type`. Raises a descriptive error
-- when the type is unknown or when the matching layer class is not present in
-- the nn namespace, instead of failing later on a nil value.
local function build_rnn_layer(model_type, wordvec_dim, prev_dim, hidden_dim)
  local class_name, input_dim
  if model_type == 'rnn' then
    class_name, input_dim = 'VanillaRNN', prev_dim
  elseif model_type == 'lstm' then
    class_name, input_dim = 'LSTM', prev_dim
  elseif model_type == 'gru' then
    class_name, input_dim = 'GRU', prev_dim
  elseif model_type == 'gridgru' then
    -- NOTE(review): unlike the other types, GRIDGRU is always built with the
    -- word-vector size as its first argument, even for layers after the
    -- first. Kept as-is; confirm against GRIDGRU's constructor contract.
    class_name, input_dim = 'GRIDGRU', wordvec_dim
  else
    error(string.format(
      "Unknown model_type '%s'; expected 'rnn', 'lstm', 'gru' or 'gridgru'",
      tostring(model_type)))
  end

  local layer_class = nn[class_name]
  if layer_class == nil then
    error(string.format(
      "nn.%s is not loaded; require '%s' before using model_type '%s'",
      class_name, class_name, model_type))
  end
  return layer_class(input_dim, hidden_dim)
end

--[[
Build the network: LookupTable -> num_layers x (recurrent layer
[-> BatchNormalization] [-> Dropout]) -> Linear producing vocabulary scores.

kwargs (all read through utils.get_kwarg without a fallback value):
- idx_to_token: table mapping index -> token; its inverse is stored in
  self.token_to_idx and its entry count becomes self.vocab_size
- model_type: 'rnn', 'lstm', 'gru' or 'gridgru'
- wordvec_size: embedding dimension D
- rnn_size: hidden dimension H
- num_layers: number of recurrent layers
- dropout: a Dropout layer follows each recurrent layer when this is > 0
- batchnorm: BatchNormalization follows each recurrent layer when this is 1

Raises an error for an unknown model_type or an unloaded layer class.
--]]
function LM:__init(kwargs)
  self.idx_to_token = utils.get_kwarg(kwargs, 'idx_to_token')
  self.token_to_idx = {}
  self.vocab_size = 0
  for idx, token in pairs(self.idx_to_token) do
    self.token_to_idx[token] = idx
    self.vocab_size = self.vocab_size + 1
  end

  self.model_type = utils.get_kwarg(kwargs, 'model_type')
  self.wordvec_dim = utils.get_kwarg(kwargs, 'wordvec_size')
  self.rnn_size = utils.get_kwarg(kwargs, 'rnn_size')
  self.num_layers = utils.get_kwarg(kwargs, 'num_layers')
  self.dropout = utils.get_kwarg(kwargs, 'dropout')
  self.batchnorm = utils.get_kwarg(kwargs, 'batchnorm')

  local V, D, H = self.vocab_size, self.wordvec_dim, self.rnn_size

  self.net = nn.Sequential()
  self.rnns = {}
  self.bn_view_in = {}
  self.bn_view_out = {}

  self.net:add(nn.LookupTable(V, D))
  for i = 1, self.num_layers do
    -- The first recurrent layer consumes word vectors; later ones consume the
    -- previous layer's hidden output.
    local prev_dim = H
    if i == 1 then prev_dim = D end

    local rnn = build_rnn_layer(self.model_type, D, prev_dim, H)
    rnn.remember_states = true
    table.insert(self.rnns, rnn)
    self.net:add(rnn)

    if self.batchnorm == 1 then
      -- The view sizes here are placeholders; updateOutput resets them to the
      -- actual N and T of each minibatch.
      local view_in = nn.View(1, 1, -1):setNumInputDims(3)
      table.insert(self.bn_view_in, view_in)
      self.net:add(view_in)
      self.net:add(nn.BatchNormalization(H))
      local view_out = nn.View(1, -1):setNumInputDims(2)
      table.insert(self.bn_view_out, view_out)
      self.net:add(view_out)
    end
    if self.dropout > 0 then
      self.net:add(nn.Dropout(self.dropout))
    end
  end

  -- After all the RNNs run, we will have a tensor of shape (N, T, H);
  -- we want to apply a 1D temporal convolution to predict scores for each
  -- vocab element, giving a tensor of shape (N, T, V). Unfortunately
  -- nn.TemporalConvolution is SUPER slow, so instead we will use a pair of
  -- views (N, T, H) -> (NT, H) and (NT, V) -> (N, T, V) with a nn.Linear in
  -- between. Unfortunately N and T can change on every minibatch, so we need
  -- to set them in the forward pass.
  self.view1 = nn.View(1, 1, -1):setNumInputDims(3)
  self.view2 = nn.View(1, -1):setNumInputDims(2)

  self.net:add(self.view1)
  self.net:add(nn.Linear(H, V))
  self.net:add(self.view2)
end
--[[
Forward pass through the wrapped network.

The first two dimensions of `input` are read as N and T. Every View created in
the constructor is resized to match them before the input is forwarded:
flattening views become (N*T, -1) and restoring views become (N, T, -1).

Returns whatever self.net:forward(input) returns.
--]]
function LM:updateOutput(input)
  local batch_size = input:size(1)
  local seq_len = input:size(2)
  local flat_rows = batch_size * seq_len

  -- Views surrounding the final Linear layer.
  self.view1:resetSize(flat_rows, -1)
  self.view2:resetSize(batch_size, seq_len, -1)

  -- Views surrounding each BatchNormalization layer (empty lists when
  -- batch normalization was not enabled).
  local flatten_views = self.bn_view_in
  for i = 1, #flatten_views do
    flatten_views[i]:resetSize(flat_rows, -1)
  end
  local restore_views = self.bn_view_out
  for i = 1, #restore_views do
    restore_views[i]:resetSize(batch_size, seq_len, -1)
  end

  return self.net:forward(input)
end
-- Backward pass: forwards input, gradOutput and scale unchanged to the wrapped
-- network's backward() and returns its result.
function LM:backward(input, gradOutput, scale)
  local net = self.net
  return net:backward(input, gradOutput, scale)
end
-- Expose the wrapped network's parameters as this module's parameters;
-- returns exactly what self.net:parameters() returns.
function LM:parameters()
  local net = self.net
  return net:parameters()
end
-- Switch to training mode: first the wrapped network, then this module itself
-- via the parent class's training().
function LM:training()
  local net = self.net
  net:training()
  parent.training(self)
end
-- Switch to evaluation mode: first the wrapped network, then this module
-- itself via the parent class's evaluate().
function LM:evaluate()
  local net = self.net
  net:evaluate()
  parent.evaluate(self)
end
-- Call resetStates() on every recurrent layer registered in self.rnns, in the
-- order the layers were added.
function LM:resetStates()
  local layers = self.rnns
  for i = 1, #layers do
    layers[i]:resetStates()
  end
end
--[[
Convert a string into a 1-D LongTensor of vocabulary indices.

The string is walked one position at a time (s:sub(i, i)), so each token looked
up in self.token_to_idx is a single byte of `s`.

Inputs:
- s: string to encode
Returns:
- torch.LongTensor of length #s holding the index of each position's token
Raises:
- assertion failure 'Got invalid idx' when a token is not in the vocabulary
--]]
function LM:encode_string(s)
  local length = #s
  local encoded = torch.LongTensor(length)
  local vocab = self.token_to_idx
  for pos = 1, length do
    local idx = vocab[s:sub(pos, pos)]
    assert(idx ~= nil, 'Got invalid idx')
    encoded[pos] = idx
  end
  return encoded
end
-- Map the index stored at encoded[1][1] back to its token via
-- self.idx_to_token. `encoded` must be indexable twice (e.g. a 1x1 tensor).
function LM:decode_char(encoded)
  local idx = encoded[1][1]
  return self.idx_to_token[idx]
end
--[[
Sample a single line of text from the language model.

This function does NOT reset the RNN states itself; call resetStates() first
if sampling should not continue from whatever state the recurrent layers
currently hold.

kwargs (read through utils.get_kwarg; the value in parentheses is the fallback
passed to it):
- length (100): last time-step index. The characters of start_text count
  towards it, so at most length - #start_text characters are drawn.
- start_text (''): seed string run through the network before sampling. When
  empty, the first character is drawn from uniform scores.
- verbose (0): when > 0, prints how sampling was seeded.
- sample (1): 0 picks the highest-scoring character at every step; any other
  value draws from softmax(scores / temperature).
- temperature (1): divisor applied to the scores before the softmax. Must be
  non-zero when sampling.

Returns:
- string built from the drawn characters. Characters drawn before the first
  alphanumeric one are dropped. "... " is prepended when that first character
  is a lowercase letter. Sampling stops at the first newline drawn after that
  (the newline is not included), or when the time steps run out.
--]]
function LM:sample(kwargs)
  local T = utils.get_kwarg(kwargs, 'length', 100)
  local start_text = utils.get_kwarg(kwargs, 'start_text', '')
  local verbose = utils.get_kwarg(kwargs, 'verbose', 0)
  local sample = utils.get_kwarg(kwargs, 'sample', 1)
  local temperature = utils.get_kwarg(kwargs, 'temperature', 1)

  local scores, first_t
  if #start_text > 0 then
    if verbose > 0 then
      print('Seeding with: "' .. start_text .. '"')
    end
    local x = self:encode_string(start_text):view(1, -1)
    local T0 = x:size(2)
    -- Keep only the scores predicted after the last seed character.
    scores = self:forward(x)[{{}, {T0, T0}}]
    first_t = T0 + 1
  else
    if verbose > 0 then
      print('Seeding with uniform probabilities')
    end
    -- Borrow the embedding weight's tensor type for the uniform scores.
    local w = self.net:get(1).weight
    scores = w.new(1, 1, self.vocab_size):fill(1)
    first_t = 1
  end

  local started = false  -- becomes true once the first character is kept
  local answer = {}
  for _ = first_t, T do
    local next_char
    if sample == 0 then
      local _max, argmax = scores:max(3)
      next_char = argmax[{{}, {}, 1}]
    else
      -- Softmax over scores / temperature. Subtracting the maximum first
      -- leaves the distribution unchanged but keeps exp() from overflowing
      -- for large scores or small temperatures.
      local logits = torch.div(scores, temperature):double()
      logits:add(-logits:max())
      local probs = logits:exp():squeeze()
      probs:div(torch.sum(probs))
      next_char = torch.multinomial(probs, 1):view(1, 1)
    end

    local current_char = self:decode_char(next_char)
    if not started then
      -- Ignore punctuation, spaces and new-lines as the first char.
      if string.find(current_char, "[a-zA-Z0-9]") then
        -- If the first letter is small, prepend with "... ".
        if string.find(current_char, "[a-z]") then
          answer[#answer + 1] = "... "
        end
        answer[#answer + 1] = current_char
        started = true
      end
    elseif current_char == '\n' then
      -- New line means end of the response.
      return table.concat(answer)
    else
      answer[#answer + 1] = current_char
    end

    -- Feed the drawn character back in to get the next step's scores. This
    -- also runs for skipped characters and on the final step.
    scores = self:forward(next_char)
  end
  return table.concat(answer)
end
-- Delegate clearState() to the wrapped network. Returns nothing.
function LM:clearState()
  local net = self.net
  net:clearState()
end