-
Notifications
You must be signed in to change notification settings - Fork 1.9k
/
Copy pathtrain.py
132 lines (111 loc) · 4.57 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import tensorflow as tf
import numpy as np
import os
import pickle
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, LSTM, Dropout
from tensorflow.keras.callbacks import ModelCheckpoint
from string import punctuation
# number of input characters the model sees per training sample
sequence_length = 100
BATCH_SIZE = 128
EPOCHS = 3
# dataset file path
FILE_PATH = "data/wonderland.txt"
# FILE_PATH = "data/python_code.py"
BASENAME = os.path.basename(FILE_PATH)
# commented because already downloaded
# import requests
# content = requests.get("http://www.gutenberg.org/cache/epub/11/pg11.txt").text
# open("data/wonderland.txt", "w", encoding="utf-8").write(content)
# read the data; the context manager closes the file handle once it is read
with open(FILE_PATH, encoding="utf-8") as fp:
    text = fp.read()
# remove caps, comment this code if you want uppercase characters as well
text = text.lower()
# remove punctuation (every character in string.punctuation is deleted)
text = text.translate(str.maketrans("", "", punctuation))
# print some stats
n_chars = len(text)
# sorted so the char <-> int mapping is the same on every run over the same text
vocab = ''.join(sorted(set(text)))
print("unique_chars:", vocab)
n_unique_chars = len(vocab)
print("Number of characters:", n_chars)
print("Number of unique characters:", n_unique_chars)
# dictionary that converts characters to integers
char2int = {c: i for i, c in enumerate(vocab)}
# dictionary that converts integers to characters
int2char = {i: c for i, c in enumerate(vocab)}
# save these dictionaries for later generation; "with" guarantees each pickle
# is flushed and closed before the (long) training run starts
with open(f"{BASENAME}-char2int.pickle", "wb") as fp:
    pickle.dump(char2int, fp)
with open(f"{BASENAME}-int2char.pickle", "wb") as fp:
    pickle.dump(int2char, fp)
# convert all text into integers
encoded_text = np.array([char2int[c] for c in text])
# construct tf.data.Dataset object (one element per encoded character)
char_dataset = tf.data.Dataset.from_tensor_slices(encoded_text)
# print first 8 characters
for char in char_dataset.take(8):
    print(char.numpy(), int2char[char.numpy()])
# build sequences by batching: each sequence holds 2*sequence_length + 1
# consecutive characters, later split into (input, target) pairs by
# split_sample(); drop_remainder=True discards the short final batch so every
# sequence has the same, statically known length
sequences = char_dataset.batch(2*sequence_length + 1, drop_remainder=True)
# print sequences
for sequence in sequences.take(2):
    print(''.join([int2char[i] for i in sequence.numpy()]))
def split_sample(sample):
    """Turn one sequence of encoded characters into (input, target) pairs.

    A window of ``sequence_length`` characters slides one step at a time
    across ``sample``; the character right after each window is its target.
    Every window is produced, including the one whose target is the last
    character of ``sample``.

    Example (sequence_length = 10, sample = "python is a great pro", 21 chars):
        ('python is ', 'a'), ('ython is a', ' '), ('thon is a ', 'g'),
        ('hon is a g', 'r'), ..., (' a great p', 'r'), ('a great pr', 'o')
    with every character encoded as an integer.

    Args:
        sample: 1-D integer tensor of encoded characters; presumably one
            2*sequence_length + 1 element batch from ``sequences``.

    Returns:
        A ``tf.data.Dataset`` of ``len(sample) - sequence_length`` pairs
        ``(input_, target)``; ``input_`` has shape ``(sequence_length,)`` and
        ``target`` is a scalar.
    """
    # Row i of `windows` is sample[i : i + sequence_length + 1]. Building all
    # windows in one op replaces a chain of ~sequence_length
    # Dataset.concatenate() calls per sample.
    windows = tf.signal.frame(
        sample, frame_length=sequence_length + 1, frame_step=1
    )
    # The first sequence_length columns are the input; the last is the target.
    return tf.data.Dataset.from_tensor_slices(
        (windows[:, :sequence_length], windows[:, sequence_length])
    )
# expand every batched sequence into its sliding-window (input, target) pairs
# and flatten them into a single dataset of individual training samples
dataset = sequences.flat_map(map_func=split_sample)
def one_hot_samples(input_, target):
    """One-hot encode an (input, target) pair over the character vocabulary.

    Each integer code k becomes a vector of length ``n_unique_chars`` with a
    1 at position k. For example, with n_unique_chars = 5 the code 3 maps to
    [0, 0, 0, 1, 0].

    Args:
        input_: integer tensor of character codes (the model's input window).
        target: integer code of the character that follows the window.

    Returns:
        Tuple ``(encoded_input, encoded_target)``; each gains a trailing axis
        of size ``n_unique_chars``.
    """
    depth = n_unique_chars
    encoded_input = tf.one_hot(input_, depth)
    encoded_target = tf.one_hot(target, depth)
    return encoded_input, encoded_target
# one-hot encode every (input, target) pair
dataset = dataset.map(one_hot_samples)
# print first 2 samples, decoding the one-hot vectors back to characters
for input_vectors, target_vector in dataset.take(2):
    print("Input:", ''.join([int2char[np.argmax(char_vector)] for char_vector in input_vectors.numpy()]))
    print("Target:", int2char[np.argmax(target_vector.numpy())])
    print("Input shape:", input_vectors.shape)
    print("Target shape:", target_vector.shape)
    print("="*50, "\n")
# repeat, shuffle and batch the dataset; repeat() makes the stream endless,
# so steps_per_epoch below decides how many batches count as one "epoch"
ds = dataset.repeat().shuffle(1024).batch(BATCH_SIZE, drop_remainder=True)
# building the model
# model = Sequential([
#     LSTM(128, input_shape=(sequence_length, n_unique_chars)),
#     Dense(n_unique_chars, activation="softmax"),
# ])
# a better model (slower to train obviously)
model = Sequential([
    LSTM(256, input_shape=(sequence_length, n_unique_chars), return_sequences=True),
    Dropout(0.3),
    LSTM(256),
    Dense(n_unique_chars, activation="softmax"),
])
# define the model path
model_weights_path = f"results/{BASENAME}-{sequence_length}.h5"
# if os.path.isfile(model_weights_path):
#     model.load_weights(model_weights_path)
model.summary()
model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
# create the output folder; exist_ok=True avoids the check-then-create race
os.makedirs("results", exist_ok=True)
# {{loss:.2f}} is escaped so ModelCheckpoint (not the f-string) fills in the loss
# checkpoint = ModelCheckpoint(f"results/{BASENAME}-{{loss:.2f}}.h5", verbose=1)
# train the model
# NOTE(review): steps_per_epoch assumes roughly one sample per character of
# text, but each 2*sequence_length+1 character sequence yields only about
# sequence_length pairs, so one "epoch" here is roughly two passes over the
# generated samples -- confirm this training budget is intended.
model.fit(ds, steps_per_epoch=(len(encoded_text) - sequence_length) // BATCH_SIZE, epochs=EPOCHS)
# save the model
model.save(model_weights_path)