-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgenerate_data.py
185 lines (152 loc) · 5.57 KB
/
generate_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
import sys
import dataset
import nltk
from nltk import word_tokenize
from nltk.util import ngrams
import time
from collections import defaultdict
# nltk.download('shakespeare')
# nltk.download('brown')
# nltk.download('punkt')
from config import DATABASE_URI
db = dataset.connect(DATABASE_URI)
#create_table = '''CREATE TABLE {} (text TEXT);'''.format(table_name)
#truncate_table = ''' DELETE FROM lines '''
def insert_brown_phrases(categories=None, cutoff=10000):
    """Split the Brown corpus word stream into phrases and insert them
    into the ``brown`` table as ``chunk`` rows.

    Words are accumulated until a token that does not start with an
    ASCII letter (i.e. punctuation) is reached; the accumulated words
    are then joined with spaces and stored as one phrase.

    Args:
        categories: Brown corpus categories to draw from; defaults to
            every category.
        cutoff: maximum number of corpus words to process.
    """
    import string
    from nltk.corpus import brown

    if not categories:
        categories = brown.categories()
    start_time = time.time()
    brown_table = db['brown']
    words = brown.words(categories=categories)
    print("importing brown words from categories {},"
          " w maximum word count of {}".format(categories, cutoff))
    cutoff = min(len(words), cutoff)
    insertions = 0
    db.begin()
    try:
        sentence = []
        for idx, w in enumerate(words[:cutoff]):
            if idx % 1000 == 0:
                print('processed {} words'.format(idx))
            # A token starting with a non-letter (punctuation) ends the
            # current phrase.
            if w[0] not in string.ascii_letters:
                brown_table.insert({'chunk': ' '.join(sentence)})
                sentence = []
                insertions += 1
                continue
            sentence.append(w)
        # BUG FIX: flush the trailing phrase; the original silently
        # dropped words left unterminated at the cutoff.
        if sentence:
            brown_table.insert({'chunk': ' '.join(sentence)})
            insertions += 1
        db.commit()
    except Exception as exc:  # narrowed from bare except; still rolls back
        db.rollback()
        print("error inserting data: {}".format(exc))
    end_time = time.time() - start_time
    print("Done, w {} phrase insertions, took {}s".format(insertions, end_time))
def insert_shakespeare_lines():
    """Insert the LINE elements of the first two Shakespeare plays into
    the ``speare_lines`` table (created lazily by dataset).
    """
    table = db['speare_lines']  # dataset creates the table on first insert
    from nltk.corpus import shakespeare
    plays = [shakespeare.xml(i) for i in shakespeare.fileids()]
    start_time = time.time()
    for p in plays[0:2]:
        lines = p.findall('*/*/*/LINE')
        for line in lines:
            text = line.text
            if text:
                # BUG FIX: the old code doubled single quotes — SQL
                # escaping left over from the commented-out raw-SQL
                # insert. dataset parameterizes queries, so that stored
                # corrupted text ("don''t"); insert the raw line.
                table.insert({'text': text})
        print("inserted {} values into db".format(len(lines)))
    print('took{}'.format((time.time() - start_time)))
def get_lines_from_play(amount=None):
    """Collect the text of LINE elements from the first two Shakespeare
    plays, optionally stopping once *amount* lines are gathered.

    Args:
        amount: if given, return as soon as this many lines are collected.

    Returns:
        A list of line texts (entries may be None for empty elements).
    """
    from nltk.corpus import shakespeare
    collected = []
    plays = [shakespeare.xml(fid) for fid in shakespeare.fileids()]
    for play in plays[:2]:
        for element in play.findall('*/*/*/LINE'):
            if amount and amount <= len(collected):
                return collected
            if len(collected) % 100 == 0:
                print("processed {} lines".format(len(collected)))
            collected.append(element.text)
    return collected
def get_word_quadgram_from_lines(lines):
    """Build a mapping from lowercased first word to a list of word
    4-gram tuples drawn from *lines*.

    Args:
        lines: iterable of strings (entries that cannot be tokenized,
            e.g. None, are counted as failures and skipped).

    Returns:
        defaultdict(list) of key word -> list of 4-tuples, capped at
        QUADGRAMS_PER_KEY entries per key.
    """
    fails = []
    quad_dict = defaultdict(list)
    # Store at most 5 quadgrams per key.
    QUADGRAMS_PER_KEY = 5
    for line in lines:
        try:
            tokens = nltk.word_tokenize(line)
        except Exception:  # narrowed scope: only tokenization can fail here
            fails.append(line)
            continue
        for quad in ngrams(tokens, 4):
            key = quad[0].lower()
            # BUG FIX: '>' allowed 6 entries per key despite the
            # 5-per-key cap; '>=' enforces it exactly.
            if len(quad_dict[key]) >= QUADGRAMS_PER_KEY:
                continue
            quad_dict[key].append(quad)
    print("failed to quadgram {} lines".format(len(fails)))
    return quad_dict
def set_db_quadgrams(qd, amount=None):
    """Insert quadgram tuples from mapping *qd* into the ``quadgram``
    table inside a single transaction; rolls back on any failure.

    Args:
        qd: mapping of key word -> list of 4-tuples (key, w2, w3, w4).
        amount: optional maximum number of keys to insert.
    """
    print('inserting keys ({})'.format(amount))
    table = db['quadgram']  # dataset creates the table lazily
    db.begin()
    try:
        for idx, key in enumerate(qd.keys()):
            if idx % 10 == 0:
                print('set {} values in db'.format(idx))
            if amount and amount <= idx:
                break
            for q in qd[key]:
                table.insert({'one': key,
                              'two': q[1],
                              'three': q[2],
                              'four': q[3]})
        db.commit()
        # BUG FIX: the success message used to print once per key,
        # before the commit happened, and was misspelled.
        print('committed')
    except Exception:  # narrowed from bare except
        print("rolling back qd insertion")
        db.rollback()
def get_db_quadgrams(amount=None):
    """Load quadgrams from the ``quadgram`` table into a defaultdict
    keyed by the first word.

    Args:
        amount: optional LIMIT on the number of rows fetched.

    Returns:
        defaultdict(list) of first word -> [one, two, three, four] lists.
    """
    table = db['quadgram']
    qd = defaultdict(list)
    if amount:
        # BUG FIX: the query selected a nonexistent ``root`` column and
        # omitted ``four``, but the loop below reads one..four. int()
        # also keeps the formatted LIMIT injection-safe.
        quads = db.query('''SELECT one, two, three, four
                            FROM quadgram ORDER BY id
                            LIMIT {}'''.format(int(amount)))
    else:
        quads = table.all()
    for quad in quads:
        qd[quad['one']].append([quad['one'], quad['two'],
                                quad['three'], quad['four']])
    return qd
if __name__ == '__main__':
    # BUG FIX: guard against running with no argument, which raised
    # IndexError on sys.argv[1].
    if len(sys.argv) < 2:
        print("usage: generate_data.py setup|gen-brown [cutoff]|gen-speare|stat")
        sys.exit(1)
    command = sys.argv[1]
    if command == 'setup':
        # One-time download of the corpora the generators need.
        nltk.download('shakespeare')
        nltk.download('brown')
        nltk.download('punkt')
    elif command == 'gen-brown':
        # Available Brown categories: adventure, belles_lettres,
        # editorial, fiction, government, hobbies, humor, learned,
        # lore, mystery, news, religion, reviews, romance,
        # science_fiction.
        cutoff = int(sys.argv[2]) if len(sys.argv) == 3 else 10000
        insert_brown_phrases(categories=['adventure', 'lore',
                                         'belles_lettres',
                                         'romance', 'science_fiction'],
                             cutoff=cutoff)
    elif command == 'gen-speare':
        lines = get_lines_from_play(amount=1000)
        qd = get_word_quadgram_from_lines(lines)
        print('generated quads: {}'.format(len(qd.keys())))
        set_db_quadgrams(qd)
    elif command == 'stat':
        qd2 = get_db_quadgrams()
        print('fetched quads: {}'.format(len(qd2.keys())))