diff --git a/.gitignore b/.gitignore
index dd84837dd..c9dd3db88 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,9 +1,16 @@
 # Compiled python modules.
 *.pyc
 
+# Byte-compiled
+__pycache__/
+
 # Python egg metadata, regenerated from source files by setuptools.
 /*.egg-info
 
-# PyPI distribution artificats
+# PyPI distribution artifacts.
 build/
 dist/
+
+# Sublime project files
+*.sublime-project
+*.sublime-workspace
diff --git a/tensor2tensor/data_generators/text_encoder.py b/tensor2tensor/data_generators/text_encoder.py
index 1bf7539d3..61078b3f4 100644
--- a/tensor2tensor/data_generators/text_encoder.py
+++ b/tensor2tensor/data_generators/text_encoder.py
@@ -310,6 +310,7 @@ def build_to_target_size(cls,
     tf.logging.info("Alphabet contains %d characters" % len(alphabet_set))
 
     def bisect(min_val, max_val):
+      """Bisection to find the right size."""
       present_count = (max_val + min_val) // 2
       tf.logging.info("Trying min_count %d" % present_count)
       subtokenizer = cls()
@@ -317,14 +318,16 @@ def bisect(min_val, max_val):
                                            present_count, num_iterations)
       if min_val >= max_val or subtokenizer.vocab_size == target_size:
         return subtokenizer
+
       if subtokenizer.vocab_size > target_size:
         other_subtokenizer = bisect(present_count + 1, max_val)
       else:
         other_subtokenizer = bisect(min_val, present_count - 1)
-      if (abs(other_subtokenizer.vocab_size - target_size) <
-          abs(subtokenizer.vocab_size - target_size)):
-        return other_subtokenizer
-      return subtokenizer
+
+      if (abs(other_subtokenizer.vocab_size - target_size) <
+          abs(subtokenizer.vocab_size - target_size)):
+        return other_subtokenizer
+      return subtokenizer
 
     return bisect(min_val, max_val)
diff --git a/tensor2tensor/models/bluenet.py b/tensor2tensor/models/bluenet.py
index 19bed2032..8f4c89eac 100644
--- a/tensor2tensor/models/bluenet.py
+++ b/tensor2tensor/models/bluenet.py
@@ -77,7 +77,8 @@ def run_binary_modules(modules, cur1, cur2, hparams):
   """Run binary modules."""
   selection_var = tf.get_variable("selection", [len(modules)],
                                   initializer=tf.zeros_initializer())
-  inv_t = 100.0 * common_layers.inverse_exp_decay(100000, min_value=0.01)
+  inv_t = 100.0 * common_layers.inverse_exp_decay(
+      hparams.anneal_until, min_value=0.01)
   selected_weights = tf.nn.softmax(selection_var * inv_t)
   all_res = [modules[n](cur1, cur2, hparams) for n in xrange(len(modules))]
   all_res = tf.concat([tf.expand_dims(r, axis=0) for r in all_res], axis=0)
@@ -89,7 +90,8 @@ def run_unary_modules_basic(modules, cur, hparams):
   """Run unary modules."""
   selection_var = tf.get_variable("selection", [len(modules)],
                                   initializer=tf.zeros_initializer())
-  inv_t = 100.0 * common_layers.inverse_exp_decay(100000, min_value=0.01)
+  inv_t = 100.0 * common_layers.inverse_exp_decay(
+      hparams.anneal_until, min_value=0.01)
   selected_weights = tf.nn.softmax(selection_var * inv_t)
   all_res = [modules[n](cur, hparams) for n in xrange(len(modules))]
   all_res = tf.concat([tf.expand_dims(r, axis=0) for r in all_res], axis=0)
@@ -109,7 +111,8 @@ def run_unary_modules_sample(modules, cur, hparams, k):
                      lambda: tf.zeros_like(cur),
                      lambda i=n: modules[i](cur, hparams))
              for n in xrange(len(modules))]
-  inv_t = 100.0 * common_layers.inverse_exp_decay(100000, min_value=0.01)
+  inv_t = 100.0 * common_layers.inverse_exp_decay(
+      hparams.anneal_until, min_value=0.01)
   selected_weights = tf.nn.softmax(selection_var * inv_t - 1e9 * (1.0 - to_run))
   all_res = tf.concat([tf.expand_dims(r, axis=0) for r in all_res], axis=0)
   res = all_res * tf.reshape(selected_weights, [-1, 1, 1, 1, 1])
@@ -122,6 +125,14 @@ def run_unary_modules(modules, cur, hparams):
   return run_unary_modules_sample(modules, cur, hparams, 4)
 
 
+def batch_deviation(x):
+  """Average deviation of the batch."""
+  x_mean = tf.reduce_mean(x, axis=[0], keep_dims=True)
+  x_variance = tf.reduce_mean(
+      tf.square(x - x_mean), axis=[0], keep_dims=True)
+  return tf.reduce_mean(tf.sqrt(x_variance))
+
+
 @registry.register_model
 class BlueNet(t2t_model.T2TModel):
 
@@ -153,14 +164,15 @@ def run_unary(x, name):
       with tf.variable_scope("conv"):
         x = run_unary_modules(conv_modules, x, hparams)
         x.set_shape(x_shape)
-      return x
+      return tf.nn.dropout(x, 1.0 - hparams.dropout), batch_deviation(x)
 
-    cur1, cur2 = inputs, inputs
+    cur1, cur2, extra_loss = inputs, inputs, 0.0
     cur_shape = inputs.get_shape()
     for i in xrange(hparams.num_hidden_layers):
       with tf.variable_scope("layer_%d" % i):
-        cur1 = run_unary(cur1, "unary1")
-        cur2 = run_unary(cur2, "unary2")
+        cur1, loss1 = run_unary(cur1, "unary1")
+        cur2, loss2 = run_unary(cur2, "unary2")
+        extra_loss += (loss1 + loss2) / float(hparams.num_hidden_layers)
         with tf.variable_scope("binary1"):
           next1 = run_binary_modules(binary_modules, cur1, cur2, hparams)
           next1.set_shape(cur_shape)
@@ -169,7 +181,9 @@ def run_unary(x, name):
           next2.set_shape(cur_shape)
         cur1, cur2 = next1, next2
 
-    return cur1
+    anneal = common_layers.inverse_exp_decay(hparams.anneal_until)
+    extra_loss *= hparams.batch_deviation_loss_factor * anneal
+    return cur1, extra_loss
 
 
 @registry.register_hparams
@@ -185,7 +199,7 @@ def bluenet_base():
   hparams.num_hidden_layers = 8
   hparams.kernel_height = 3
   hparams.kernel_width = 3
-  hparams.learning_rate_decay_scheme = "exp50k"
+  hparams.learning_rate_decay_scheme = "exp10k"
   hparams.learning_rate = 0.05
   hparams.learning_rate_warmup_steps = 3000
   hparams.initializer_gain = 1.0
@@ -196,6 +210,8 @@ def bluenet_base():
   hparams.optimizer_adam_beta1 = 0.85
   hparams.optimizer_adam_beta2 = 0.997
   hparams.add_hparam("imagenet_use_2d", True)
+  hparams.add_hparam("anneal_until", 40000)
+  hparams.add_hparam("batch_deviation_loss_factor", 0.001)
   return hparams
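
Note on the bluenet.py hunks: the hard-coded 100000-step annealing horizon for the module-selection softmax becomes the new `anneal_until` hparam, and a `batch_deviation` penalty, scaled by `batch_deviation_loss_factor` and the same annealing schedule, is returned as an extra loss. For intuition, here is a minimal NumPy sketch of the two mechanisms; it is not part of the patch, and the `inverse_exp_decay` formula is an assumption inferred from the function's name and how it is used here:

import numpy as np

def inverse_exp_decay(step, max_step, min_value=0.01):
  # Assumed behavior of common_layers.inverse_exp_decay: rises
  # exponentially from min_value at step 0 to 1.0 at max_step,
  # then stays at 1.0.
  inv_base = min_value ** (1.0 / max_step)
  return inv_base ** max(max_step - step, 0)

def selection_weights(selection_var, step, anneal_until):
  # The softmax temperature anneals: early in training inv_t ~= 1,
  # giving a soft mixture over modules; by `anneal_until` steps
  # inv_t ~= 100, so the softmax approaches a hard argmax.
  inv_t = 100.0 * inverse_exp_decay(step, anneal_until, min_value=0.01)
  logits = selection_var * inv_t
  e = np.exp(logits - logits.max())
  return e / e.sum()

def batch_deviation(x):
  # Mean over features of the per-feature standard deviation across
  # the batch axis -- the auxiliary loss the patch adds per layer.
  x_mean = x.mean(axis=0, keepdims=True)
  return np.mean(np.sqrt(np.mean((x - x_mean) ** 2, axis=0)))

v = np.array([0.3, 0.1, -0.2, 0.0])
print(selection_weights(v, step=0, anneal_until=40000))      # near-uniform
print(selection_weights(v, step=40000, anneal_until=40000))  # near one-hot
print(batch_deviation(np.random.randn(16, 8)))               # ~1.0 for N(0,1)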
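
The text_encoder.py hunks are mostly cosmetic (a docstring and blank lines around the recursive bisection), but the control flow deserves spelling out: vocabulary size shrinks as the `min_count` threshold rises, so the search recurses into the half that could still contain `target_size` and, while unwinding, keeps whichever probe landed closest. A self-contained sketch of that pattern under this monotonicity assumption, with a hypothetical `vocab_size_for` standing in for building a real SubwordTextEncoder at each probe:

def bisect_for_target(vocab_size_for, target_size, min_val, max_val):
  # Binary-search min_count: higher thresholds admit fewer subtokens.
  present = (max_val + min_val) // 2
  size = vocab_size_for(present)
  if min_val >= max_val or size == target_size:
    return present, size

  if size > target_size:  # vocab too big -> raise the threshold
    other = bisect_for_target(vocab_size_for, target_size,
                              present + 1, max_val)
  else:                   # vocab too small -> lower the threshold
    other = bisect_for_target(vocab_size_for, target_size,
                              min_val, present - 1)

  # Keep whichever probe landed closer to the target, as in the patch.
  if abs(other[1] - target_size) < abs(size - target_size):
    return other
  return present, size

# Toy monotone model of vocab size vs. min_count (hypothetical).
print(bisect_for_target(lambda c: 100000 // c, target_size=8192,
                        min_val=1, max_val=1000))  # -> (12, 8333)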