-
Notifications
You must be signed in to change notification settings - Fork 349
/
Copy pathBeam.py
89 lines (61 loc) · 2.88 KB
/
Beam.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import torch
from Batch import nopeak_mask
import torch.nn.functional as F
import math
def init_vars(src, model, SRC, TRG, opt):
    """Run the first decoding step and seed ``opt.k`` beam hypotheses.

    Encodes ``src``, feeds a single ``<sos>`` token to the decoder, and keeps
    the ``opt.k`` most likely next tokens as the starting beams.

    Args:
        src: source token ids. Only the first batch element's encoder output
            and scores are used to seed the beams (assumes batch size 1 --
            TODO confirm with callers).
        model: object exposing ``encoder``, ``decoder`` and ``out``.
        SRC: source field; ``SRC.vocab.stoi['<pad>']`` builds the source mask.
        TRG: target field; ``TRG.vocab.stoi['<sos>']`` starts every hypothesis.
        opt: options providing ``device``, ``k`` (beam width) and ``max_len``.

    Returns:
        outputs: long tensor ``(k, max_len)`` on ``opt.device``. Column 0 is
            ``<sos>``, column 1 holds the k best first tokens, and the rest are 0.
        e_outputs: the first encoder output repeated k times,
            ``(k, e_output.size(-2), e_output.size(-1))``.
        log_scores: ``(1, k)`` log-probabilities of those first tokens, on the CPU.
    """
    init_tok = TRG.vocab.stoi['<sos>']
    src_mask = (src != SRC.vocab.stoi['<pad>']).unsqueeze(-2)
    e_output = model.encoder(src, src_mask)

    # torch.tensor accepts any device; the legacy torch.LongTensor(..., device=)
    # constructor rejects non-CPU devices in current PyTorch releases.
    outputs = torch.tensor([[init_tok]], dtype=torch.long, device=opt.device)
    trg_mask = nopeak_mask(1, opt)
    out = model.out(model.decoder(outputs, e_output, src_mask, trg_mask))

    # log_softmax is monotonic, so topk selects the same tokens softmax would.
    # It also avoids log(0) raising a math domain error when a probability underflows.
    log_probs = F.log_softmax(out, dim=-1)
    top_log_probs, ix = log_probs[:, -1].detach().topk(opt.k)
    # Kept on the CPU in the default float dtype: that is the device/dtype the
    # running scores have always had when handed to k_best_outputs/beam_search.
    log_scores = top_log_probs[:1].to(device='cpu', dtype=torch.get_default_dtype())

    outputs = torch.zeros(opt.k, opt.max_len, dtype=torch.long, device=opt.device)
    outputs[:, 0] = init_tok
    outputs[:, 1] = ix[0]

    # Every beam attends to the same (first) encoder output.
    e_outputs = e_output[:1].repeat(opt.k, 1, 1)
    return outputs, e_outputs, log_scores
def k_best_outputs(outputs, out, log_scores, i, k):
    """Extend every beam by one token and keep the ``k`` best continuations.

    Args:
        outputs: long tensor ``(k, max_len)`` of partial hypotheses. It is
            updated in place: rows are reordered and column ``i`` is filled.
        out: decoder probabilities; ``out[:, -1]`` is ``(k, vocab)``, the
            next-token distribution of each beam.
        log_scores: ``(1, k)`` cumulative log-probability of each beam. It may
            live on any device and is moved next to ``out``.
        i: index of the column being filled; columns ``:i`` are the prefix.
        k: beam width.

    Returns:
        ``(outputs, log_scores)``: the same ``outputs`` tensor holding the k
        surviving hypotheses, and their ``(1, k)`` cumulative scores in
        descending order.
    """
    probs, ix = out[:, -1].detach().topk(k)
    # Candidate score = parent beam score + log p(token).
    # torch.log maps a zero probability to -inf instead of raising like math.log.
    log_probs = torch.log(probs) + log_scores.transpose(0, 1).to(probs.device)
    k_probs, k_ix = log_probs.view(-1).topk(k)
    row = k_ix // k  # which parent beam each winner extends
    col = k_ix % k   # which of that parent's top-k tokens it appends
    outputs[:, :i] = outputs[row, :i]
    outputs[:, i] = ix[row, col]
    log_scores = k_probs.unsqueeze(0)
    return outputs, log_scores
def _first_eos_positions(outputs, eos_tok):
    """Return the column of the first ``eos_tok`` in each row of ``outputs`` (0 if none)."""
    seq_len = outputs.size(1)
    positions = torch.arange(seq_len, device=outputs.device).expand_as(outputs)
    # Non-EOS cells get the sentinel seq_len, so the row minimum is the first EOS column.
    first = positions.masked_fill(outputs != eos_tok, seq_len).min(dim=1)[0]
    return first.masked_fill(first == seq_len, 0)


def beam_search(src, model, SRC, TRG, opt):
    """Decode ``src`` with beam search and return the best hypothesis as text.

    Decoding stops early once every one of the ``opt.k`` beams contains
    ``<eos>``. The winner is then chosen by log-score divided by
    ``length ** 0.7``. If ``opt.max_len`` is reached first, beam 0 is used.

    Args:
        src: source token ids (see ``init_vars``).
        model: object exposing ``encoder``, ``decoder`` and ``out``.
        SRC: source field; ``SRC.vocab.stoi['<pad>']`` builds the source mask.
        TRG: target field; supplies ``<eos>`` via ``stoi`` and token strings via ``itos``.
        opt: options providing ``device``, ``k`` and ``max_len``.

    Returns:
        The chosen hypothesis's tokens after the leading ``<sos>`` (column 0),
        up to but excluding its first ``<eos>``, joined by single spaces. A
        hypothesis with no ``<eos>`` is returned in full.
    """
    outputs, e_outputs, log_scores = init_vars(src, model, SRC, TRG, opt)
    eos_tok = TRG.vocab.stoi['<eos>']
    src_mask = (src != SRC.vocab.stoi['<pad>']).unsqueeze(-2)
    alpha = 0.7  # length-normalisation exponent for finished hypotheses
    ind = None

    for i in range(2, opt.max_len):
        trg_mask = nopeak_mask(i, opt)
        out = model.out(model.decoder(outputs[:, :i], e_outputs, src_mask, trg_mask))
        out = F.softmax(out, dim=-1)
        outputs, log_scores = k_best_outputs(outputs, out, log_scores, i, opt.k)

        # Built on outputs' own device (previously hard-coded to .cuda(),
        # which crashed CPU-only runs).
        sentence_lengths = _first_eos_positions(outputs, eos_tok)
        if bool((sentence_lengths > 0).all()):
            # type_as also moves the lengths onto log_scores' device.
            div = 1 / (sentence_lengths.type_as(log_scores) ** alpha)
            _, best = torch.max(log_scores * div, 1)
            ind = best[0].item()
            break

    if ind is None:
        # No early stop: use beam 0. Beams come out of topk in descending
        # score order, so this is presumably the best raw score.
        ind = 0

    best_tokens = outputs[ind]
    eos_positions = (best_tokens == eos_tok).nonzero()
    # A hypothesis cut off by max_len may contain no <eos>; keep all of it
    # instead of raising IndexError.
    length = eos_positions[0].item() if len(eos_positions) > 0 else best_tokens.size(0)
    return ' '.join(TRG.vocab.itos[tok] for tok in best_tokens[1:length].tolist())