# util.py (forked from hrlinlp/entail_sum)
from collections import defaultdict

import numpy as np


def map_ent_label(labels):
    """Map textual entailment class names to integer class ids."""
    ent_label_dict = {'entailment': 0, 'neutral': 1, 'contradiction': 2}
    new_labels = [[ent_label_dict[label[0]]] for label in labels]
    return new_labels
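# Usage sketch (illustrative, not part of the original file): each element of
# `labels` is expected to be a one-element sequence holding the class name, e.g.
#   map_ent_label([['entailment'], ['contradiction']])  # -> [[0], [2]]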
def read_corpus(file_path, source):
    """Read a tokenized corpus with one whitespace-separated sentence per line."""
    data = []
    with open(file_path) as f:
        for line in f:
            sent = line.strip().split(' ')
            # only append <s> and </s> to the target sentence
            if source == 'tgt':
                sent = ['<s>'] + sent + ['</s>']
            data.append(sent)
    return data
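# Usage sketch (illustrative; 'train.src' and 'train.tgt' are hypothetical paths):
#   src_data = read_corpus('train.src', source='src')  # [['w1', 'w2', ...], ...]
#   tgt_data = read_corpus('train.tgt', source='tgt')  # [['<s>', 'w1', ..., '</s>'], ...]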
def batch_slice(data, batch_size, sort=True, ent=False):
    # if each example is a (src, tgt, label) triple, also yield the labels
    if len(data[0]) == 3:
        ent = True
    batch_num = int(np.ceil(len(data) / float(batch_size)))
    for i in range(batch_num):
        cur_batch_size = batch_size if i < batch_num - 1 else len(data) - batch_size * i
        src_sents = [data[i * batch_size + b][0] for b in range(cur_batch_size)]
        tgt_sents = [data[i * batch_size + b][1] for b in range(cur_batch_size)]
        if ent:
            labels = [data[i * batch_size + b][2] for b in range(cur_batch_size)]

        if sort:
            # reorder the batch so that source sentences are in decreasing length
            src_ids = sorted(range(cur_batch_size), key=lambda src_id: len(src_sents[src_id]), reverse=True)
            src_sents = [src_sents[src_id] for src_id in src_ids]
            tgt_sents = [tgt_sents[src_id] for src_id in src_ids]
            if ent:
                labels = [labels[src_id] for src_id in src_ids]

        if ent:
            yield src_sents, tgt_sents, labels
        else:
            yield src_sents, tgt_sents
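# Usage sketch (illustrative): zip pre-loaded sentences into (src, tgt) pairs, or
# into (src, tgt, label) triples to have labels yielded alongside each batch.
#   pairs = list(zip(src_data, tgt_data))
#   for src_batch, tgt_batch in batch_slice(pairs, batch_size=32):
#       ...  # each batch is sorted by decreasing source length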
def data_iter(data, batch_size, shuffle=True):
    """
    Randomly permute the data, group examples by source length, and partition
    each group into batches, so that source-sentence lengths within a batch
    are non-increasing; finally shuffle the order of the batches.
    """
    buckets = defaultdict(list)
    for pair in data:
        src_sent = pair[0]
        buckets[len(src_sent)].append(pair)

    batched_data = []
    for src_len in buckets:
        tuples = buckets[src_len]
        if shuffle:
            np.random.shuffle(tuples)
        batched_data.extend(list(batch_slice(tuples, batch_size)))

    if shuffle:
        np.random.shuffle(batched_data)
    for batch in batched_data:
        yield batch
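

if __name__ == '__main__':
    # Minimal self-contained sketch (not in the original file): build a toy
    # corpus of (src, tgt, label) triples in memory and walk over the batches.
    # All sentences, labels, and the batch size below are illustrative only.
    toy_src = [['a', 'b', 'c'], ['d', 'e'], ['f', 'g', 'h', 'i'], ['j']]
    toy_tgt = [['<s>', 'x', 'y', '</s>']] * 4
    toy_labels = map_ent_label([['entailment'], ['neutral'], ['contradiction'], ['entailment']])
    toy_data = list(zip(toy_src, toy_tgt, toy_labels))
    for src_batch, tgt_batch, label_batch in data_iter(toy_data, batch_size=2):
        print([len(s) for s in src_batch], label_batch)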