-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathtranslate.py
executable file
·161 lines (129 loc) · 7.17 KB
/
translate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import os
import io
import sys
import argparse
import torch
from XLM.src.utils import AttrDict
from XLM.src.utils import bool_flag, initialize_exp
from XLM.src.data.dictionary import Dictionary
from XLM.src.model.transformer import TransformerModel
class Param():
    """Plain container for experiment/translation settings.

    Mirrors the attribute layout of an argparse namespace so it can be
    handed to the XLM helpers (e.g. initialize_exp) unchanged.
    """

    def __init__(self, dump_path, exp_name, exp_id, batch_size, model_path, tgt_lang, src_lang, max_vocab=-1, min_count=0):
        self.dump_path = dump_path    # Experiment dump path
        self.exp_name = exp_name      # Experiment name
        self.exp_id = exp_id          # Experiment ID
        self.batch_size = batch_size  # Number of sentences per batch
        self.model_path = model_path  # Model path
        self.src_lang = src_lang      # Source language
        self.tgt_lang = tgt_lang      # Target language
        self.max_vocab = max_vocab    # Maximum vocabulary size (-1 to disable)
        # BUG FIX: the original line ended with a stray comma
        # (`self.min_count = min_count,`), which stored the 1-tuple
        # (min_count,) instead of the integer itself.
        self.min_count = min_count    # Minimum vocabulary count
class Translate():
    """Load a pretrained XLM encoder/decoder pair and translate sentences.

    The checkpoint is reloaded on the CPU (no GPU is assumed), the
    dictionary and both transformers are rebuilt from it, and the trained
    weights are restored.
    """

    def __init__(self, model_path, tgt_lang, src_lang, dump_path="./dumped/", exp_name="translate", exp_id="test", batch_size=32):
        """Reload the checkpoint at `model_path` and build encoder/decoder.

        Raises AssertionError if a language is empty or source == target.
        """
        params = Param(dump_path, exp_name, exp_id, batch_size, model_path, tgt_lang, src_lang)
        assert params.src_lang != '' and params.tgt_lang != '' and params.src_lang != params.tgt_lang
        # initialize the experiment (logging / dump directories)
        logger = initialize_exp(params)
        # No GPU available: force the checkpoint onto the CPU.
        reloaded = torch.load(params.model_path, map_location=torch.device('cpu'))
        model_params = AttrDict(reloaded['params'])
        self.supported_languages = model_params.lang2id.keys()
        logger.info("Supported languages: %s" % ", ".join(self.supported_languages))
        # copy dictionary/index parameters from the trained model
        for name in ['n_words', 'bos_index', 'eos_index', 'pad_index', 'unk_index', 'mask_index']:
            setattr(params, name, getattr(model_params, name))
        print("params = ", params)
        # build dictionary / build encoder / build decoder / reload weights
        self.dico = Dictionary(reloaded['dico_id2word'], reloaded['dico_word2id'], reloaded['dico_counts'])
        self.encoder = TransformerModel(model_params, self.dico, is_encoder=True, with_output=True).eval()
        self.decoder = TransformerModel(model_params, self.dico, is_encoder=False, with_output=True).eval()
        self.encoder.load_state_dict(reloaded['encoder'])
        self.decoder.load_state_dict(reloaded['decoder'])
        params.src_id = model_params.lang2id[params.src_lang]
        params.tgt_id = model_params.lang2id[params.tgt_lang]
        self.model_params = model_params
        self.params = params

    def translate(self, src_sent=None):
        """Translate one sentence or a batch of sentences.

        Accepts either a single string (returns a single string) or a
        list of strings (returns a list of translations in input order).
        Sentences are expected to be whitespace-tokenized; progress is
        logged to stderr.
        """
        # BUG FIX: the original used a mutable default argument (src_sent=[]).
        if src_sent is None:
            src_sent = []
        single_input = False
        # BUG FIX: use isinstance() instead of `type(src_sent) == str`.
        if isinstance(src_sent, str):
            src_sent = [src_sent]
            single_input = True
        tgt_sent = []
        for i in range(0, len(src_sent), self.params.batch_size):
            # prepare batch: map tokens to dictionary indices
            word_ids = [torch.LongTensor([self.dico.index(w) for w in s.strip().split()])
                        for s in src_sent[i:i + self.params.batch_size]]
            # +2 accounts for the <eos> delimiters added at both ends
            lengths = torch.LongTensor([len(s) + 2 for s in word_ids])
            batch = torch.LongTensor(lengths.max().item(), lengths.size(0)).fill_(self.params.pad_index)
            batch[0] = self.params.eos_index
            for j, s in enumerate(word_ids):
                if lengths[j] > 2:  # if sentence not empty
                    batch[1:lengths[j] - 1, j].copy_(s)
                batch[lengths[j] - 1, j] = self.params.eos_index
            langs = batch.clone().fill_(self.params.src_id)
            # encode source batch and translate it (CPU only, no .cuda())
            encoded = self.encoder('fwd', x=batch, lengths=lengths, langs=langs, causal=False)
            encoded = encoded.transpose(0, 1)
            decoded, dec_lengths = self.decoder.generate(encoded, lengths, self.params.tgt_id,
                                                         max_len=int(1.5 * lengths.max().item() + 10))
            # convert generated index tensors back into whitespace-joined words
            for j in range(decoded.size(1)):
                # remove <eos> delimiters
                sent = decoded[:, j]
                delimiters = (sent == self.params.eos_index).nonzero().view(-1)
                assert len(delimiters) >= 1 and delimiters[0].item() == 0
                sent = sent[1:] if len(delimiters) == 1 else sent[1:delimiters[1]]
                # output translation
                source = src_sent[i + j].strip()
                target = " ".join([self.dico[sent[k].item()] for k in range(len(sent))])
                sys.stderr.write("%i / %i: %s -> %s\n" % (i + j, len(src_sent), source, target))
                tgt_sent.append(target)
        if single_input:
            return tgt_sent[0]
        return tgt_sent
# The model can be invoked from the command line like this:
#python translate.py --exp_name translate \
#    --src_lang en --tgt_lang fr \
#    --model_path trained_model.pth --output_path output
def get_parser():
    """Build the command-line argument parser for the translation script."""
    parser = argparse.ArgumentParser(description="Translate sentences")
    # (flag, value type, default, help text) for every supported option:
    # experiment bookkeeping, model/output paths, then the language pair.
    option_specs = [
        ("--dump_path", str, "./dumped/", "Experiment dump path"),
        ("--exp_name", str, "", "Experiment name"),
        ("--exp_id", str, "", "Experiment ID"),
        ("--batch_size", int, 32, "Number of sentences per batch"),
        ("--model_path", str, "", "Model path"),
        ("--output_path", str, "", "Output path"),
        ("--src_lang", str, "", "Source language"),
        ("--tgt_lang", str, "", "Target language"),
    ]
    for flag, value_type, default, help_text in option_specs:
        parser.add_argument(flag, type=value_type, default=default, help=help_text)
    return parser
def main(params):
    """Run a demo translation with the parsed command-line parameters.

    `params` is the argparse namespace produced by get_parser(). The
    experiment logger is initialized inside Translate.__init__, so it is
    not (re-)initialized here.
    """
    # BUG FIX: the original body called initialize_exp(params) and then
    # re-parsed sys.argv, discarding the `params` argument it was given
    # (and initializing the experiment twice, since Translate.__init__
    # also calls initialize_exp). Use the namespace passed by the caller.
    trainer = Translate(model_path=params.model_path, tgt_lang=params.tgt_lang, src_lang=params.src_lang,
                        dump_path=params.dump_path, exp_name=params.exp_name,
                        exp_id=params.exp_id, batch_size=params.batch_size)
    # Demo: list input -> list of translations, string input -> single string.
    print(trainer.translate(["I eat something", "Good morning my frend"]))
    print(trainer.translate("Good morning my frend, I eat something"))
if __name__ == '__main__':
    # Parse the command-line options and run the translation demo with
    # gradient tracking disabled (inference only).
    args = get_parser().parse_args()
    with torch.no_grad():
        main(args)