# run_simple_lfads.py
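# Trains a simplified LFADS-style sequential autoencoder (no inferred
# inputs; see simpleModelNoInputs) on synthetic Lorenz data, with per-epoch
# validation, learning-rate decay, and plotting of reconstructions.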
import tensorflow as tf
import numpy as np
import sys
import time
# utils defined by CP
#from helper_funcs import linear, init_linear_transform, makeInitialState
#from helper_funcs import DiagonalGaussianFromInput, DiagonalGaussian, DiagonalGaussianFromExisting
#from helper_funcs import BidirectionalDynamicRNN, DynamicRNN, LinearTimeVarying
#from helper_funcs import KLCost_GaussianGaussian, Poisson
#from complexcell import ComplexCell
from helper_funcs import ListOfRandomBatches, kind_dict
from plot_funcs import plot_data, close_all_plots
from data_funcs import read_datasets
from simpleModelNoInputs import SimpleModel
#data_dir = '/tmp/rnn_synth_data_v1.0/'
#data_fn = 'chaotic_rnn_inputs_g1p5_dataset_N50_S50'
#data_fn = 'chaotic_rnn_inputs_g1p5_dataset'
data_dir = '/tmp/lorenz/' # lorenz data
#data_fn = 'generated_data'
data_fn = 'generated'
datasets = read_datasets(data_dir, data_fn)
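# read_datasets returns a dict keyed by dataset name; judging by the usage
# below, each entry holds (at least) 'train_data' and 'valid_data' arrays of
# shape (trials x timesteps x channels)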
dkey = list(datasets.keys())[0]  # take the first dataset
#print(datasets[dkey].keys())
#sys.exit()
train_data = datasets[dkey]['train_data']
valid_data = datasets[dkey]['valid_data']
train_data = train_data[0:1000,:,:]
valid_data = valid_data[0:250,:,:]
#train_truth = datasets[dkey]['train_truth']
#valid_truth = datasets[dkey]['valid_truth']
# train_data = train_data[0::5,:,:]
# valid_data = valid_data[0::5,:,:]
# train_truth = train_truth[0::5,:,:]
# valid_truth = valid_truth[0::5,:,:]
print(train_data.shape)
print(valid_data.shape)
hps = {}
hps['num_steps'] = train_data.shape[1]
hps['dataset_dims'] = {}
hps['dataset_dims'][dkey] = train_data.shape[2]
hps['batch_size'] = 50
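# every trial here has the same duration, so the per-trial sequence lengths
# are simply num_steps repeated batch_size times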
hps['sequence_lengths'] = [hps['num_steps'] for i in range(hps['batch_size'])]
# hardcode some HPs for now
# network sizes
hps['ic_enc_dim'] = 64
hps['gen_dim'] = 64
# ICs, CIs, factors
hps['factors_dim'] = 20
hps['ic_dim'] = 10
# training hyperparameters
hps['keep_prob'] = 0.95
hps['ic_var_min'] = 0.1
hps['kind'] = 'train'
hps['max_grad_norm'] = 200.0
hps['learning_rate_init'] = 0.05
hps['learning_rate_decay_factor'] = 0.97
hps['learning_rate_n_to_compare'] = 6
hps['learning_rate_stop'] = 0.00001
## haven't implemented these in lfadslite
# hps['feedback_factors_or_rates'] = 'factors'
# hps['cell_clip_value'] = 5.0
# initialize the model with these hyperparams
model = SimpleModel(hps)
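# SimpleModel presumably wires up the encoder -> initial-condition
# distribution -> generator -> factors graph implied by the dimensions
# above (ic_enc_dim, ic_dim, gen_dim, factors_dim); here it is driven only
# through its train_batch / validation_batch interface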
# define an epoch
epoch_batches = []
for _ in range(200):  # pre-generate 200 epochs' worth of shuffled batches
    epoch_batches += ListOfRandomBatches(train_data.shape[0], hps['batch_size'])
valid_batches = ListOfRandomBatches(valid_data.shape[0], hps['batch_size'])
steps = range(0, len(epoch_batches))
steps_per_epoch = train_data.shape[0] // hps['batch_size']
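# ListOfRandomBatches(n, batch_size) is expected to shuffle the n trial
# indices and split them into batch_size-sized index lists; a minimal
# hypothetical equivalent, for reference only:
#   idx = np.random.permutation(n).tolist()
#   batches = [idx[i:i + batch_size] for i in range(0, n, batch_size)]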
#print_trainable_vars(trainable_vars)
#sys.exit()
# setup tf configuration
config = tf.ConfigProto(allow_soft_placement=True,
                        log_device_placement=False)
config.gpu_options.allow_growth = True
session = tf.Session(config=config)
writer = tf.summary.FileWriter('logs', session.graph)
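# the graph written above can be inspected with, e.g.:
#   tensorboard --logdir logs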
train_costs = []
nepoch = 0
with session.as_default():
    tf.global_variables_initializer().run()
    for nstep in steps:
        if nstep % steps_per_epoch == 0:
            # new epoch: reset the per-epoch cost accumulators
            nepoch += 1
            tc_e = []; rc_e = []; kc_e = []
        # train the model on this batch
        X = train_data[epoch_batches[nstep], :, :].astype(np.float32)
        feed_dict = {'input_data': X,
                     'keep_prob': hps['keep_prob']}
        evald = model.train_batch(feed_dict)
        [_, tc, rc, kc, rts, lr] = evald
        # print out the step cost (overwriting the same console line)
        print('step ' + str(nstep) + ': [ ' + str(tc) + ', ' + str(rc) + ']', end='\r')
        sys.stdout.flush()
        # store costs for the epoch
        tc_e.append(tc); rc_e.append(rc); kc_e.append(kc)
        # should we decrement the learning rate?
        n_lr = hps['learning_rate_n_to_compare']
        if len(train_costs) > n_lr and tc > np.max(train_costs[-n_lr:]):
            print("Decreasing learning rate")
            model.run_learning_rate_decay_opt()
        train_costs.append(tc)
        new_lr = model.get_learning_rate()
        # should we stop?
        if new_lr < hps['learning_rate_stop']:
            print("Learning rate criteria met")
            break
        # run the validation set once per epoch
        if nstep % steps_per_epoch == 0:
            # collect the total, reconstruction, and KL costs per batch
            tcv_all = []; rcv_all = []; kcv_all = []
            for nvbatch in range(len(valid_batches)):
                Xv = valid_data[valid_batches[nvbatch], :, :].astype(np.float32)
                feed_dict = {'input_data': Xv}
                [tcv_b, rcv_b, kcv_b] = model.validation_batch(feed_dict)
                tcv_all.append(tcv_b); rcv_all.append(rcv_b); kcv_all.append(kcv_b)
            # take the mean over all validation batches
            tcv = np.mean(tcv_all); rcv = np.mean(rcv_all); kcv = np.mean(kcv_all)
            # take the mean over all training batches so far this epoch
            tct = np.mean(tc_e); rct = np.mean(rc_e); kct = np.mean(kc_e)
            print()
            print("epoch: %i, step %i: total: (%f, %f), rec: (%f, %f), kl: (%f, %f), lr: %f"
                  % (nepoch, nstep, tct, tcv, rct, rcv, kct, kcv, lr))
            # plot the current results every epoch
            if nepoch % 1 == 0:
                plt = plot_data(X[0, :, :], rts[0, :, :])
            # periodically close accumulated figures
            if nepoch % 20 == 0:
                close_all_plots()
            # if nstep % 10 == 0:
            #     print("step %i: total: (%f), rec: (%f), kl: (%f), lr: %f" % (nstep, tc, rc, kc, lr))
print("Done training")
writer.close()
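# To run (assuming the Lorenz dataset has already been generated under
# /tmp/lorenz with the 'generated' file prefix):
#   python run_simple_lfads.py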