dataloader.py
from __future__ import print_function

import torch
import pandas as pd
import numpy as np

import util as ut


class TextClassDataLoader(object):
    def __init__(self, path_file, word_to_index, batch_size=32):
        """
        Args:
            path_file: path to a tab-separated file whose first column is the
                label and whose 'body' column holds the raw text.
            word_to_index: dict mapping word -> integer index; must contain
                an '__UNK__' entry for out-of-vocabulary words.
            batch_size: number of samples per batch.
        """
        self.batch_size = batch_size
        self.word_to_index = word_to_index

        # read file: tokenize the text and map each token to its vocabulary index
        df = pd.read_csv(path_file, delimiter='\t')
        df['body'] = df['body'].apply(ut._tokenize)
        df['body'] = df['body'].apply(self.generate_indexifyer())
        self.samples = df.values.tolist()

        # for batch
        self.n_samples = len(self.samples)
        self.n_batches = self.n_samples // self.batch_size
        self.max_length = self._get_max_length()
        self._shuffle_indices()

        self.report()
    def _shuffle_indices(self):
        # re-shuffle the sample order and reset the batch cursors
        self.indices = np.random.permutation(self.n_samples)
        self.index = 0
        self.batch_index = 0

    def _get_max_length(self):
        # length (in tokens) of the longest sample; sample[1] is the index list
        length = 0
        for sample in self.samples:
            length = max(length, len(sample[1]))
        return length
    def generate_indexifyer(self):
        # returns a closure that maps a list of tokens to vocabulary indices,
        # falling back to the '__UNK__' index for unknown words
        def indexify(lst_text):
            indices = []
            for word in lst_text:
                if word in self.word_to_index:
                    indices.append(self.word_to_index[word])
                else:
                    indices.append(self.word_to_index['__UNK__'])
            return indices
        return indexify
    @staticmethod
    def _padding(batch_x):
        # right-pad every sequence with zeros up to the length of the longest one
        batch_s = sorted(batch_x, key=lambda x: len(x))
        size = len(batch_s[-1])
        for i, x in enumerate(batch_x):
            missing = size - len(x)
            batch_x[i] = batch_x[i] + [0 for _ in range(missing)]
        return batch_x
    def _create_batch(self):
        batch = []
        n = 0
        while n < self.batch_size:
            _index = self.indices[self.index]
            batch.append(self.samples[_index])
            self.index += 1
            n += 1
        self.batch_index += 1

        label, string = tuple(zip(*batch))

        # get the length of each seq in your batch
        seq_lengths = torch.LongTensor(list(map(len, string)))

        # dump padding everywhere, and place seqs on the left.
        # NOTE: you only need a tensor as big as your longest sequence
        seq_tensor = torch.zeros((len(string), seq_lengths.max())).long()
        for idx, (seq, seqlen) in enumerate(zip(string, seq_lengths)):
            seq_tensor[idx, :seqlen] = torch.LongTensor(seq)

        # SORT YOUR TENSORS BY LENGTH!
        seq_lengths, perm_idx = seq_lengths.sort(0, descending=True)
        seq_tensor = seq_tensor[perm_idx]
        # seq_tensor = seq_tensor.transpose(0, 1)

        label = torch.LongTensor(label)
        label = label[perm_idx]

        return seq_tensor, label, seq_lengths
    def __len__(self):
        return self.n_batches

    def __iter__(self):
        self._shuffle_indices()
        for i in range(self.n_batches):
            if self.batch_index == self.n_batches:
                # end the generator instead of raising StopIteration (PEP 479)
                return
            yield self._create_batch()
    def show_samples(self, n=10):
        for sample in self.samples[:n]:
            print(sample)

    def report(self):
        print('# samples: {}'.format(len(self.samples)))
        print('max len: {}'.format(self.max_length))
        print('# vocab: {}'.format(len(self.word_to_index)))
        print('# batches: {} (batch_size = {})'.format(self.n_batches, self.batch_size))
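
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module). It builds a
# tiny vocabulary and a throwaway tab-separated file, then iterates over the
# loader. The file name 'train_demo.tsv', the column names and the toy
# vocabulary are assumptions; the class only requires a leading label column,
# a 'body' text column, and an '__UNK__' vocabulary entry. It also assumes
# util._tokenize splits the body text into word tokens.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    # hypothetical two-column TSV: label first, then the raw text
    rows = [
        (0, 'the movie was great'),
        (1, 'the movie was terrible'),
        (0, 'great acting and plot'),
        (1, 'terrible plot'),
    ]
    with open('train_demo.tsv', 'w') as f:
        f.write('label\tbody\n')
        for y, text in rows:
            f.write('{}\t{}\n'.format(y, text))

    # minimal vocabulary: index 0 is reserved for padding, '__UNK__' catches
    # out-of-vocabulary tokens
    vocab = ['__PAD__', '__UNK__', 'the', 'movie', 'was', 'great',
             'terrible', 'acting', 'and', 'plot']
    word_to_index = {w: i for i, w in enumerate(vocab)}

    loader = TextClassDataLoader('train_demo.tsv', word_to_index, batch_size=2)
    for seq_tensor, label, seq_lengths in loader:
        # seq_tensor: (batch, max_len) LongTensor, zero-padded and sorted by
        # decreasing length, ready for nn.utils.rnn.pack_padded_sequence
        print(seq_tensor.size(), label, seq_lengths)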