-
Notifications
You must be signed in to change notification settings - Fork 1.9k
/
Copy pathgenerate.py
54 lines (49 loc) · 1.61 KB
/
generate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import numpy as np
import pickle
import tqdm
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, LSTM, Dropout, Activation
import os
# Number of timesteps (characters) in every input window fed to the network.
sequence_length = 100
# dataset file path
FILE_PATH = "data/wonderland.txt"
# FILE_PATH = "data/python_code.py"
# Only the file name is used: it prefixes the vocabulary pickles and the
# weights file loaded below.
BASENAME = os.path.basename(FILE_PATH)

# Load the vocabulary dictionaries (character -> index and index -> character).
# The files are opened in `with` blocks so the handles are closed right after
# loading instead of being left to the garbage collector.
# NOTE: pickle can execute arbitrary code while loading; only point these
# paths at files you produced yourself.
with open(f"{BASENAME}-char2int.pickle", "rb") as char2int_file:
    char2int = pickle.load(char2int_file)
with open(f"{BASENAME}-int2char.pickle", "rb") as int2char_file:
    int2char = pickle.load(int2char_file)

# One output unit (and one input feature) per character in the vocabulary.
vocab_size = len(char2int)

# Build the network: two stacked LSTM layers with dropout in between and a
# softmax over the vocabulary on top. The layer stack has to match the network
# the weights below were saved from.
model = Sequential([
    LSTM(256, input_shape=(sequence_length, vocab_size), return_sequences=True),
    Dropout(0.3),
    LSTM(256),
    Dense(vocab_size, activation="softmax"),
])

# Load the saved weights for this dataset / sequence length
# (presumably written by the training script -- confirm the path convention).
model.load_weights(f"results/{BASENAME}-{sequence_length}.h5")
# Text the model is primed with before it starts generating.
seed = "alice is pretty"
# Untouched copy of the seed; `seed` itself is consumed by the loop below.
s = seed
# Number of characters to generate.
n_chars = 400

# Fail early with a readable message instead of a bare KeyError raised from
# inside the generation loop when the seed uses a character the vocabulary
# does not contain.
unknown_chars = sorted(set(seed) - set(char2int))
if unknown_chars:
    raise ValueError(
        f"seed contains characters missing from the vocabulary: {unknown_chars!r}"
    )

# A window longer than `sequence_length` would produce negative time indices
# below, which NumPy wraps around and which would corrupt the one-hot input.
# Only the most recent `sequence_length` characters can fit, so keep those.
seed = seed[-sequence_length:]

# Collect characters in a list and join once at the end.
generated_chars = []
for _ in tqdm.tqdm(range(n_chars), "Generating text"):
    # One-hot encode the current window, right-aligned in the input so that
    # the unused leading timesteps stay all-zero.
    X = np.zeros((1, sequence_length, vocab_size))
    offset = sequence_length - len(seed)
    for t, char in enumerate(seed):
        X[0, offset + t, char2int[char]] = 1
    # Predict the distribution over the next character.
    predicted = model.predict(X, verbose=0)[0]
    # Greedy decoding: take the most probable character index ...
    next_index = np.argmax(predicted)
    # ... and map it back to its character.
    next_char = int2char[next_index]
    generated_chars.append(next_char)
    # Slide the window: drop the oldest character, append the new one.
    # NOTE(review): this keeps the window at the seed's length, so the model
    # never sees more than len(seed) real characters of context. Letting it
    # grow up to `sequence_length` would change the generated text -- confirm
    # the intent before changing it.
    seed = seed[1:] + next_char
generated = "".join(generated_chars)

print("Seed:", s)
print("Generated text:")
print(generated)