forked from jhetherly/EnglishSpeechUpsampler
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodels.py
513 lines (440 loc) · 20.6 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
import tensorflow as tf
# lib_loc = 'src/shuffle_op.so'
# try:
# custom_shuffle_module = tf.load_op_library(lib_loc)
# shuffle = custom_shuffle_module.shuffle
# except NotFoundError:
# exit("Something is wrong with '{}', it probably has not been (re)compiled yet.\n".format(lib_loc) +
# "Recompile using 'COMPILE_FROM_BINARY.sh' in the src folder, or with the following commands in bash:\n\n"
# "TF_CFLAGS=( $(python3 -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_compile_flags()))') )\n"
# "TF_LFLAGS=( $(python3 -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_link_flags()))') )\n"
# "g++ -std=c++11 -shared shuffle_op.cc -o test.so -fPIC ${TF_CFLAGS[@]} ${TF_LFLAGS[@]} -O2")
# ###################
# TENSORBOARD HELPERS
# ###################
def comprehensive_variable_summaries(var):
    """
    Record mean, standard deviation, max, min and a histogram of `var` as
    TensorBoard summaries, all grouped under a 'summaries' name scope.
    """
    with tf.name_scope('summaries'):
        avg = tf.reduce_mean(var)
        tf.summary.scalar('mean', avg)
        with tf.name_scope('stddev'):
            # population standard deviation: sqrt(mean((var - mean)^2))
            spread = tf.sqrt(tf.reduce_mean(tf.square(var - avg)))
        tf.summary.scalar('stddev', spread)
        for label, reducer in (('max', tf.reduce_max),
                               ('min', tf.reduce_min)):
            tf.summary.scalar(label, reducer(var))
        tf.summary.histogram('histogram', var)
def histogram_variable_summaries(var):
    """
    Record only a histogram summary of `var` for TensorBoard, placed under a
    'summaries' name scope.
    """
    summary_scope = tf.name_scope('summaries')
    with summary_scope:
        tf.summary.histogram('histogram', var)
# ###################
# ###################
# ######################
# LAYER HELPER FUNCTIONS
# ######################
def subpixel_reshuffle_1D_impl(X, m):
    """
    Apply a 1-D subpixel reshuffle to a single 2-D tensor.

    The second (last) axis of X is treated as the filter axis. It is split
    into m equal groups. Each group is flattened row-major into one long
    sequence, and the m sequences become the m output channels. For X of
    shape (n, c), the result has shape (n * c / m, m).

    ref: https://github.com/Tetrachrome/subpixel
    """
    groups = tf.split(X, m, axis=1)
    flattened = [tf.reshape(group, (-1,)) for group in groups]
    # stack gives (m, n * c / m); transpose puts the groups on the last axis
    return tf.transpose(tf.stack(flattened))
def subpixel_reshuffle_1D(X, m, name=None):
    """
    Apply subpixel_reshuffle_1D_impl with factor m to every element along the
    leading (batch) axis of X, using tf.map_fn.
    """
    def _reshuffle_one(example):
        return subpixel_reshuffle_1D_impl(example, m)

    return tf.map_fn(_reshuffle_one, X, name=name)
def subpixel_restack_impl(X, n_prime, m_prime, name=None):
    """
    Restack trailing columns of a batched 2-D tensor onto additional rows.

    X is sliced as (batch, n, m):
      - The first m_prime columns are kept unchanged.
      - The remaining m - m_prime columns are flattened per example and
        truncated to (n_prime - n) * m_prime values.
      - Those values are reshaped into (n_prime - n) new rows of width
        m_prime and appended after the original n rows.

    The result has shape (batch, n_prime, m_prime).
    """
    batch_size = tf.shape(X)[0]
    extra_rows = n_prime - X.get_shape().as_list()[1]
    n_new_values = extra_rows * m_prime

    spare_cols = tf.slice(X, [0, 0, m_prime], [-1, -1, -1])
    flat = tf.reshape(spare_cols, (batch_size, -1))
    flat = tf.slice(flat, [0, 0], [-1, n_new_values])
    new_rows = tf.reshape(flat, (batch_size, -1, m_prime))
    new_rows = tf.slice(new_rows, [0, 0, 0], [-1, extra_rows, -1])

    kept_cols = tf.slice(X, [0, 0, 0], [-1, -1, m_prime])
    return tf.concat((kept_cols, new_rows), axis=1, name=name)
def subpixel_restack(X, n_prime, m_prime=None, name=None):
    """
    Restack X, sliced as (batch, n, m), into (batch, n_prime, m_prime).

    If m_prime is not given, the largest width m_prime = m - r_m
    (r_m = 1, 2, ..., m - 1) is chosen for which the r_m spare columns hold
    enough values to fill the n_prime - n new rows, i.e.
    r_m * n >= m_prime * (n_prime - n).

    Args:
        X: 3-D tensor with statically known second and third dimensions.
        n_prime: target number of rows.
        m_prime: target number of columns; chosen automatically when None.
        name: name of the resulting concat op.

    Returns:
        The restacked tensor produced by subpixel_restack_impl.

    Raises:
        ValueError: if m_prime is None and no width satisfies the condition
            above (for example when m < 2). Previously this passed None or an
            unusable width on, and the failure only surfaced inside tf.slice.
    """
    n = X.get_shape().as_list()[1]
    m = X.get_shape().as_list()[2]
    r_n = n_prime - n
    if m_prime is None:
        for r_m in range(1, m):
            m_prime = m - r_m
            if r_m * n >= m_prime * r_n:
                break
        else:
            raise ValueError(
                'subpixel_restack: cannot grow {} rows to {} with {} '
                'columns; no m_prime < {} leaves enough spare values. '
                'Pass m_prime explicitly.'.format(n, n_prime, m, m))
    return subpixel_restack_impl(X, n_prime, m_prime, name=name)
def batch_norm(T, is_training, scope):
    """
    Batch-normalize T, switching between training and inference behaviour
    with tf.cond on `is_training`.

    Both branches call tf.contrib.layers.batch_norm with the same settings:
    the same `scope`, decay 0.99, center=True, scale=True and
    updates_collections=None. The training branch passes reuse=False and the
    inference branch passes reuse=True, so the two share one set of
    batch-norm variables under `scope`.

    Args:
        T: tensor to normalize.
        is_training: tf.bool tensor, or a plain Python bool. A Python bool is
            converted to a constant tensor so that defaults such as
            `is_training=True` in the block builders work.
        scope: scope passed to tf.contrib.layers.batch_norm.

    Returns:
        The normalized tensor selected by tf.cond.
    """
    if isinstance(is_training, bool):
        # TF 1.x tf.cond raises TypeError ("pred must not be a Python bool")
        # for a Python bool predicate, so lift it to a graph constant.
        is_training = tf.constant(is_training, dtype=tf.bool)

    def _normalize(reuse):
        return tf.contrib.layers.batch_norm(T,
                                            decay=0.99,
                                            is_training=is_training,
                                            center=True, scale=True,
                                            updates_collections=None,
                                            scope=scope,
                                            reuse=reuse)

    # tf.cond takes nullary callables for its true and false branches
    return tf.cond(is_training,
                   lambda: _normalize(False),
                   lambda: _normalize(True))
def weight_variable(shape, name=None):
    """
    Create a weight tf.Variable of the given shape. Values are drawn from a
    truncated normal distribution with mean 0.0 and stddev 0.1.
    """
    return tf.Variable(tf.truncated_normal(shape, mean=0.0, stddev=0.1),
                       name=name)
def bias_variable(shape, name=None):
    """Create a bias tf.Variable of the given shape with every entry 0.1."""
    return tf.Variable(tf.constant(0.1, shape=shape), name=name)
def conv1d(x, W, stride=1, padding='SAME', name=None):
    """
    Thin wrapper around tf.nn.conv1d. Stride defaults to 1 and padding to
    'SAME'.
    """
    return tf.nn.conv1d(x, W,
                        stride=stride,
                        padding=padding,
                        name=name)
def build_1d_conv_layer(prev_tensor, prev_conv_depth,
                        conv_window, conv_depth,
                        act, layer_number,
                        stride=1,
                        padding='SAME',
                        tensorboard_output=False,
                        name=None):
    """
    Build one 1-D convolution layer: act(conv1d(prev_tensor, W) + b).

    Args:
        prev_tensor: tensor to convolve.
        prev_conv_depth: number of input channels (second dimension of W).
        conv_window: filter width.
        conv_depth: number of output channels.
        act: activation callable accepting a `name` keyword.
        layer_number: prefix for this layer's name scopes.
        stride: forwarded to conv1d.
        padding: forwarded to conv1d.
        tensorboard_output: if True, attach histogram summaries to the
            weights, biases, pre-activation and activation.
        name: op name given to the activation output.

    Returns:
        The activated output tensor.
    """
    def _maybe_summarize(tensor):
        # summaries are only attached when tensorboard_output is set
        if tensorboard_output:
            histogram_variable_summaries(tensor)

    with tf.name_scope('{}_layer_weights'.format(layer_number)):
        kernel = weight_variable([conv_window, prev_conv_depth, conv_depth])
        _maybe_summarize(kernel)
    with tf.name_scope('{}_layer_biases'.format(layer_number)):
        bias = bias_variable([conv_depth])
        _maybe_summarize(bias)
    with tf.name_scope('{}_layer_conv_preactivation'.format(layer_number)):
        preactivation = conv1d(prev_tensor, kernel,
                               stride=stride, padding=padding) + bias
        _maybe_summarize(preactivation)
    with tf.name_scope('{}_layer_conv_activation'.format(layer_number)):
        activation = act(preactivation, name=name)
        _maybe_summarize(activation)
    return activation
def build_1d_conv_layer_with_res(prev_tensor, prev_conv_depth,
                                 conv_window, conv_depth,
                                 res, act, layer_number,
                                 tensorboard_output=False,
                                 name=None):
    """
    Build one 1-D convolution layer with an additive residual:
    act(conv1d(prev_tensor, W) + b + res), using conv1d's default stride (1)
    and padding ('SAME').

    Args:
        prev_tensor: tensor to convolve.
        prev_conv_depth: number of input channels (second dimension of W).
        conv_window: filter width.
        conv_depth: number of output channels.
        res: tensor added to the pre-activation before `act`; it must be
            addable to the conv output.
        act: activation callable accepting a `name` keyword.
        layer_number: prefix for this layer's name scopes.
        tensorboard_output: if True, attach histogram summaries.
        name: op name given to the activation output.

    Returns:
        The activated output tensor.
    """
    def _maybe_summarize(tensor):
        if tensorboard_output:
            histogram_variable_summaries(tensor)

    with tf.name_scope('{}_layer_weights'.format(layer_number)):
        kernel = weight_variable([conv_window, prev_conv_depth, conv_depth])
        _maybe_summarize(kernel)
    with tf.name_scope('{}_layer_biases'.format(layer_number)):
        bias = bias_variable([conv_depth])
        _maybe_summarize(bias)
    with tf.name_scope('{}_layer_conv_preactivation'.format(layer_number)):
        preactivation = conv1d(prev_tensor, kernel) + bias
        _maybe_summarize(preactivation)
    with tf.name_scope('{}_layer_conv_activation'.format(layer_number)):
        activation = act(tf.add(preactivation, res), name=name)
        _maybe_summarize(activation)
    return activation
def build_downsampling_block(input_tensor,
                             filter_size, stride,
                             layer_number,
                             act=tf.nn.relu,
                             is_training=True,
                             depth=None,
                             padding='VALID',
                             tensorboard_output=False,
                             name=None):
    """
    Build a downsampling block: strided conv1d + bias -> batch norm -> act.

    Args:
        input_tensor: tensor to convolve; its last dimension is read as the
            input channel count.
        filter_size: convolution window width.
        stride: convolution stride.
        layer_number: prefix for this block's name scopes.
        act: activation callable accepting a `name` keyword.
        is_training: forwarded to batch_norm as its tf.cond predicate.
        depth: output channels; defaults to twice the input channel count.
        padding: forwarded to tf.nn.conv1d.
        tensorboard_output: if True, attach histogram summaries.
        name: op name forwarded to both the conv1d op and the activation.

    Returns:
        The activated output tensor.
    """
    in_channels = input_tensor.get_shape().as_list()[-1]
    if depth is None:
        # default: double the channel count
        depth = 2 * in_channels

    def _maybe_summarize(tensor):
        if tensorboard_output:
            histogram_variable_summaries(tensor)

    with tf.name_scope('{}_layer_weights'.format(layer_number)):
        kernel = weight_variable([filter_size, in_channels, depth])
        _maybe_summarize(kernel)
    with tf.name_scope('{}_layer_biases'.format(layer_number)):
        bias = bias_variable([depth])
        _maybe_summarize(bias)
    with tf.name_scope('{}_layer_conv_preactivation'.format(layer_number)):
        out = tf.nn.conv1d(input_tensor, kernel, stride=stride,
                           padding=padding, name=name) + bias
        _maybe_summarize(out)
    with tf.name_scope('{}_layer_batch_norm'.format(layer_number)) as bn_scope:
        out = batch_norm(out, is_training, bn_scope)
    with tf.name_scope('{}_layer_conv_activation'.format(layer_number)):
        out = act(out, name=name)
        _maybe_summarize(out)
    return out
def build_upsampling_block(input_tensor, residual_tensor,
                           filter_size,
                           layer_number,
                           act=tf.nn.relu,
                           is_training=True,
                           depth=None,
                           padding='VALID',
                           tensorboard_output=False,
                           name=None):
    """
    Build an upsampling block, in this order:
      1. conv1d (stride 1) + bias,
      2. batch norm,
      3. act,
      4. subpixel reshuffle,
      5. concatenation with a crop of residual_tensor on the last axis.

    The reshuffle regroups the activation's channels into as many channels
    as residual_tensor has, which lengthens the sequence. residual_tensor is
    then sliced from the start to that length and concatenated.

    Args:
        input_tensor: tensor to convolve; its last dimension is read as the
            input channel count.
        residual_tensor: skip-connection tensor to concatenate.
        filter_size: convolution window width.
        layer_number: prefix for this block's name scopes.
        act: activation callable accepting a `name` keyword.
        is_training: forwarded to batch_norm as its tf.cond predicate.
        depth: conv output channels; defaults to half the input channels.
        padding: forwarded to tf.nn.conv1d.
        tensorboard_output: if True, attach histogram summaries.
        name: op name forwarded to the conv, activation, reshuffle and concat.

    Returns:
        The concatenated output tensor.
    """
    in_channels = input_tensor.get_shape().as_list()[-1]
    if depth is None:
        # default: halve the channel count
        depth = in_channels // 2

    def _maybe_summarize(tensor):
        if tensorboard_output:
            histogram_variable_summaries(tensor)

    with tf.name_scope('{}_layer_weights'.format(layer_number)):
        kernel = weight_variable([filter_size, in_channels, depth])
        _maybe_summarize(kernel)
    with tf.name_scope('{}_layer_biases'.format(layer_number)):
        bias = bias_variable([depth])
        _maybe_summarize(bias)
    with tf.name_scope('{}_layer_conv_preactivation'.format(layer_number)):
        out = tf.nn.conv1d(input_tensor, kernel, stride=1,
                           padding=padding, name=name) + bias
        _maybe_summarize(out)
    with tf.name_scope('{}_layer_batch_norm'.format(layer_number)) as bn_scope:
        out = batch_norm(out, is_training, bn_scope)
    with tf.name_scope('{}_layer_conv_activation'.format(layer_number)):
        out = act(out, name=name)
        _maybe_summarize(out)
    with tf.name_scope('{}_layer_subpixel_reshuffle'.format(layer_number)):
        skip_channels = residual_tensor.get_shape().as_list()[-1]
        out = subpixel_reshuffle_1D(out, skip_channels, name=name)
        _maybe_summarize(out)
    with tf.name_scope('{}_layer_stacking'.format(layer_number)):
        out_length = out.get_shape().as_list()[1]
        cropped = tf.slice(residual_tensor,
                           begin=[0, 0, 0],
                           size=[-1, out_length, -1])
        out = tf.concat((out, cropped), axis=2, name=name)
        _maybe_summarize(out)
    return out
# ######################
# ######################
# #################
# MODEL DEFINITIONS
# #################
def single_fully_connected_model(input_type, input_shape,
                                 n_inputs, n_weights,
                                 tensorboard_output=True,
                                 scope_name='single_fully_connected_layer'):
    """
    Build a single linear fully-connected layer.

    Args:
        input_type: dtype of the input placeholder.
        input_shape: per-example input shape. Each example is flattened to a
            vector of prod(input_shape) entries.
        n_inputs: not used by this function.
        n_weights: number of output units.
        tensorboard_output: if True, attach histogram summaries.
        scope_name: name scope for the layer and name of the output op.

    Returns:
        (x, y):
            x: the input placeholder, shape (None, *input_shape).
            y: the output, reshape(x) @ W + b with no nonlinearity,
                shape (None, n_weights).
    """
    def _maybe_summarize(tensor):
        if tensorboard_output:
            histogram_variable_summaries(tensor)

    with tf.name_scope(scope_name):
        dims = list(input_shape)
        flat_size = 1
        for dim in dims:
            flat_size *= dim
        x = tf.placeholder(input_type, shape=[None] + dims)
        flat_x = tf.reshape(x, [-1, flat_size])
        with tf.name_scope('first_layer_weights'):
            W = weight_variable([flat_size, n_weights])
            _maybe_summarize(W)
        with tf.name_scope('first_layer_biases'):
            b = bias_variable([n_weights])
            _maybe_summarize(b)
        with tf.name_scope('first_layer_preactivation'):
            affine = tf.matmul(flat_x, W) + b
            _maybe_summarize(affine)
        with tf.name_scope('first_layer_activation'):
            # identity "activation": only names the output op scope_name
            y = tf.identity(affine, name=scope_name)
            _maybe_summarize(y)
    return x, y
def three_layer_conv_model(input_type, input_shape,
                           first_conv_window=30, first_conv_depth=128,
                           second_conv_window=10, second_conv_depth=64,
                           third_conv_window=15,
                           tensorboard_output=False,
                           scope_name='3-layer_conv'):
    """
    Build a three-layer 1-D convolutional model (elu, elu, linear).

    Args:
        input_type: dtype of the input placeholder.
        input_shape: per-example input shape. The first layer is built with
            1 input channel, so this presumably ends in 1 (e.g. (length, 1));
            confirm with callers.
        first_conv_window, first_conv_depth: width / channels of layer 1.
        second_conv_window, second_conv_depth: width / channels of layer 2.
        third_conv_window: width of the final single-channel layer.
        tensorboard_output: if True, attach histogram summaries.
        scope_name: name scope of the model and name of the output op.

    Returns:
        (x, y): the input placeholder of shape (None, *input_shape) and the
        single-channel output tensor.
    """
    with tf.name_scope(scope_name):
        # input of the model (examples)
        x = tf.placeholder(input_type, shape=[None] + list(input_shape))
        # tensorboard_output and name are passed by keyword. Positionally
        # they would bind to build_1d_conv_layer's `stride` and `padding`
        # (stride=False -> 0, padding=scope_name), producing an invalid conv.
        h1 = build_1d_conv_layer(x, 1,
                                 first_conv_window, first_conv_depth,
                                 tf.nn.elu, 1,
                                 tensorboard_output=tensorboard_output)
        h2 = build_1d_conv_layer(h1, first_conv_depth,
                                 second_conv_window, second_conv_depth,
                                 tf.nn.elu, 2,
                                 tensorboard_output=tensorboard_output)
        # third (last) conv layer: linear, single output channel
        y = build_1d_conv_layer(h2, second_conv_depth,
                                third_conv_window, 1,
                                tf.identity, 3,
                                tensorboard_output=tensorboard_output,
                                name=scope_name)
    return x, y
def five_layer_conv_model(input_type, input_shape,
                          first_conv_window=30, first_conv_depth=256,
                          second_conv_window=20, second_conv_depth=128,
                          third_conv_window=10, third_conv_depth=64,
                          fourth_conv_window=5, fourth_conv_depth=32,
                          fifth_conv_window=5,
                          tensorboard_output=False,
                          scope_name='5-layer_conv'):
    """
    Build a five-layer 1-D convolutional model: four elu layers followed by
    a linear single-channel layer.

    Args:
        input_type: dtype of the input placeholder.
        input_shape: per-example input shape. The first layer is built with
            1 input channel, so this presumably ends in 1 (e.g. (length, 1));
            confirm with callers.
        first_conv_window ... fourth_conv_depth: width / channels of the four
            hidden layers.
        fifth_conv_window: width of the final single-channel layer.
        tensorboard_output: if True, attach histogram summaries.
        scope_name: name scope of the model and name of the output op.

    Returns:
        (x, y): the input placeholder of shape (None, *input_shape) and the
        single-channel output tensor.
    """
    hidden_layers = ((first_conv_window, first_conv_depth),
                     (second_conv_window, second_conv_depth),
                     (third_conv_window, third_conv_depth),
                     (fourth_conv_window, fourth_conv_depth))
    with tf.name_scope(scope_name):
        # input of the model (examples)
        x = tf.placeholder(input_type, shape=[None] + list(input_shape))
        h = x
        prev_depth = 1
        # tensorboard_output and name are passed by keyword. Positionally
        # they would bind to build_1d_conv_layer's `stride` and `padding`
        # (stride=False -> 0, padding=scope_name), producing an invalid conv.
        for layer_number, (window, depth) in enumerate(hidden_layers, 1):
            h = build_1d_conv_layer(h, prev_depth,
                                    window, depth,
                                    tf.nn.elu, layer_number,
                                    tensorboard_output=tensorboard_output)
            prev_depth = depth
        # fifth (last) conv layer: linear, single output channel
        y = build_1d_conv_layer(h, prev_depth,
                                fifth_conv_window, 1,
                                tf.identity, 5,
                                tensorboard_output=tensorboard_output,
                                name=scope_name)
    return x, y
def deep_residual_network(input_type, input_shape,
                          number_of_downsample_layers=8,
                          channel_multiple=8,
                          initial_filter_window=5,
                          initial_stride=2,
                          downsample_filter_window=3,
                          downsample_stride=2,
                          bottleneck_filter_window=4,
                          bottleneck_stride=2,
                          upsample_filter_window=3,
                          tensorboard_output=False,
                          scope_name='deep_residual'):
    """
    Build the downsample / bottleneck / upsample network with skip connections.

    Layout (each stage prints its output shape):
      1. An initial downsampling block with channel_multiple * channels
         filters and stride initial_stride.
      2. number_of_downsample_layers - 1 further downsampling blocks, using
         build_downsampling_block's default depth.
      3. A bottleneck downsampling block.
      4. Upsampling blocks that walk the downsampling outputs in reverse.
         Each concatenates the matching downsampling output.
      5. An output head: subpixel restack, a 'VALID' conv, a subpixel
         reshuffle back to num_of_channels channels, and a final addition of
         the input x.

    Args:
        input_type: dtype of the input placeholder.
        input_shape: per-example shape. Its last two entries are read as the
            sequence length and channel count (presumably
            (length, channels)). The length is assumed divisible by
            initial_stride so the output matches x.
        number_of_downsample_layers: downsampling blocks before the
            bottleneck (including the initial one).
        channel_multiple: filters per input channel in the first block.
        initial_filter_window, initial_stride: first block's conv settings.
        downsample_filter_window, downsample_stride: later blocks' settings.
        bottleneck_filter_window, bottleneck_stride: bottleneck settings.
        upsample_filter_window: conv width of upsampling blocks and of the
            final conv.
        tensorboard_output: if True, attach histogram summaries.
        scope_name: name scope of the network and name of the output op.

    Returns:
        (train_flag, x, y):
            train_flag: tf.bool placeholder fed to every batch-norm layer.
            x: the input placeholder of shape (None, *input_shape).
            y: the output tensor (the head's result added to x).
    """
    print('layer summary for {} network'.format(scope_name))
    downsample_layers = []
    upsample_layers = []
    with tf.name_scope(scope_name):
        # training flag
        train_flag = tf.placeholder(tf.bool)
        # input of the model (examples)
        s = [None] + list(input_shape)
        x = tf.placeholder(input_type, shape=s)
        input_size = s[-2]
        num_of_channels = s[-1]
        print('input: {}'.format(x.get_shape().as_list()[1:]))

        # downsampling path
        d1 = build_downsampling_block(x,
                                      filter_size=initial_filter_window,
                                      stride=initial_stride,
                                      tensorboard_output=tensorboard_output,
                                      depth=channel_multiple * num_of_channels,
                                      is_training=train_flag,
                                      layer_number=1)
        print('downsample layer: {}'.format(d1.get_shape().as_list()[1:]))
        downsample_layers.append(d1)
        layer_count = 2
        for _ in range(number_of_downsample_layers - 1):
            d = build_downsampling_block(
                downsample_layers[-1],
                filter_size=downsample_filter_window,
                stride=downsample_stride,
                tensorboard_output=tensorboard_output,
                is_training=train_flag,
                layer_number=layer_count)
            print('downsample layer: {}'.format(d.get_shape().as_list()[1:]))
            downsample_layers.append(d)
            layer_count += 1

        # bottleneck
        bn = build_downsampling_block(downsample_layers[-1],
                                      filter_size=bottleneck_filter_window,
                                      stride=bottleneck_stride,
                                      tensorboard_output=tensorboard_output,
                                      is_training=train_flag,
                                      layer_number=layer_count)
        print('bottleneck layer: {}'.format(bn.get_shape().as_list()[1:]))
        layer_count += 1

        # upsampling path, concatenating the matching downsampling outputs
        u1 = build_upsampling_block(bn, downsample_layers[-1],
                                    depth=bn.get_shape().as_list()[-1],
                                    filter_size=upsample_filter_window,
                                    tensorboard_output=tensorboard_output,
                                    is_training=train_flag,
                                    layer_number=layer_count)
        print('upsample layer: {}'.format(u1.get_shape().as_list()[1:]))
        upsample_layers.append(u1)
        layer_count += 1
        for i in range(number_of_downsample_layers - 2, -1, -1):
            u = build_upsampling_block(upsample_layers[-1],
                                       downsample_layers[i],
                                       filter_size=upsample_filter_window,
                                       tensorboard_output=tensorboard_output,
                                       is_training=train_flag,
                                       layer_number=layer_count)
            print('upsample layer: {}'.format(u.get_shape().as_list()[1:]))
            upsample_layers.append(u)
            layer_count += 1

        # output head
        target_size = input_size // initial_stride
        # pad by (window - 1) rows so the VALID conv below yields target_size
        restack = subpixel_restack(upsample_layers[-1],
                                   target_size + (upsample_filter_window - 1))
        print('restack layer: {}'.format(restack.get_shape().as_list()[1:]))
        # initial_stride filters per input channel. The reshuffle below splits
        # the channels into num_of_channels groups of initial_stride each, so
        # every output channel regains target_size * initial_stride samples.
        # With only initial_stride filters, multi-channel inputs came out too
        # short and tf.add(y, x) failed; mono inputs are unchanged.
        conv = build_1d_conv_layer(restack, restack.get_shape().as_list()[-1],
                                   upsample_filter_window,
                                   initial_stride * num_of_channels,
                                   tf.nn.elu, layer_count,
                                   padding='VALID',
                                   tensorboard_output=tensorboard_output)
        print('final conv layer: {}'.format(conv.get_shape().as_list()[1:]))
        # NOTE(review): an earlier comment here said this is "effectively a
        # linear activation on the last conv layer", yet the conv above uses
        # tf.nn.elu. Confirm which is intended.
        y = subpixel_reshuffle_1D(conv,
                                  num_of_channels)
        # global residual: the network predicts a correction added to x
        y = tf.add(y, x, name=scope_name)
        print('output: {}'.format(y.get_shape().as_list()[1:]))
    return train_flag, x, y
# #################
# #################