diff --git a/.travis.yml b/.travis.yml index 2cdcd85bf..5e0aff001 100644 --- a/.travis.yml +++ b/.travis.yml @@ -3,8 +3,11 @@ python: - "2.7" - "3.6" before_install: + - echo "deb [arch=amd64] http://storage.googleapis.com/tensorflow-serving-apt stable tensorflow-model-server tensorflow-model-server-universal" | sudo tee /etc/apt/sources.list.d/tensorflow-serving.list + - curl https://storage.googleapis.com/tensorflow-serving-apt/tensorflow-serving.release.pub.gpg | sudo apt-key add - - sudo apt-get update -qq - sudo apt-get install -qq libhdf5-dev + - sudo apt-get install -qq tensorflow-model-server install: - pip install -q .[tensorflow] - pip install -q .[tests] @@ -21,7 +24,7 @@ script: - python -c "from tensor2tensor.models import transformer; print(transformer.Transformer.__name__)" # Run tests - - pytest --ignore=tensor2tensor/utils/registry_test.py --ignore=tensor2tensor/problems_test.py --ignore=tensor2tensor/utils/trainer_lib_test.py --ignore=tensor2tensor/data_generators/algorithmic_math_test.py + - pytest --ignore=tensor2tensor/utils/registry_test.py --ignore=tensor2tensor/problems_test.py --ignore=tensor2tensor/utils/trainer_lib_test.py --ignore=tensor2tensor/data_generators/algorithmic_math_test.py --ignore=tensor2tensor/bin/t2t_trainer_test.py - pytest tensor2tensor/utils/registry_test.py - pytest tensor2tensor/utils/trainer_lib_test.py @@ -36,5 +39,14 @@ script: - t2t-datagen --problem=$T2T_PROBLEM --data_dir=$T2T_DATA_DIR - t2t-trainer --problems=$T2T_PROBLEM --data_dir=$T2T_DATA_DIR --model=transformer --hparams_set=transformer_tiny --train_steps=5 --eval_steps=5 --output_dir=$T2T_TRAIN_DIR - t2t-decoder --problems=$T2T_PROBLEM --data_dir=$T2T_DATA_DIR --model=transformer --hparams_set=transformer_tiny --output_dir=$T2T_TRAIN_DIR --decode_hparams='num_samples=10' + + # Export and query (on Python 2 only) + - t2t-exporter --problems=$T2T_PROBLEM --data_dir=$T2T_DATA_DIR --model=transformer --hparams_set=transformer_tiny --output_dir=$T2T_TRAIN_DIR + - if [[ "$TRAVIS_PYTHON_VERSION" == "2.7" ]]; then + pip install tensorflow-serving-api; + tensorflow_model_server --port=9000 --model_name=my_model --model_base_path=$T2T_TRAIN_DIR/export/Servo & + sleep 10; + t2t-query-server --problem=$T2T_PROBLEM --server=localhost:9000 --servable_name=my_model --data_dir=$T2T_DATA_DIR --inputs_once='1 0 1 0 1 0'; + fi git: depth: 3 diff --git a/docs/cloud_mlengine.md b/docs/cloud_mlengine.md new file mode 100644 index 000000000..b257fab25 --- /dev/null +++ b/docs/cloud_mlengine.md @@ -0,0 +1,80 @@ +# Running on Cloud ML Engine + +Google Cloud Platform offers a managed training environment for TensorFlow +models called [Cloud ML Engine](https://cloud.google.com/ml-engine/) and +you can easily launch Tensor2Tensor on it, including for hyperparameter tuning. + +# Launch + +It's the same `t2t-trainer` you know and love with the addition of the +`--cloud_mlengine` flag, which by default will launch on a 1-GPU machine. + +``` +# Note that both the data dir and output dir have to be on GCS +DATA_DIR=gs://my-bucket/data +OUTPUT_DIR=gs://my-bucket/train +t2t-trainer \ + --problems=translate_ende_wmt32k \ + --model=transformer \ + --hparams_set=transformer_base \ + --data_dir=$DATA_DIR \ + --output_dir=$OUTPUT_DIR \ + --cloud_mlengine +``` + +By passing `--worker_gpu=4` or `--worker_gpu=8` it will automatically launch on +machines with 4 or 8 GPUs. 
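+
+For example, to run the same job on a 4-GPU machine, add the flag to the
+command above:
+
+```
+t2t-trainer \
+  --problems=translate_ende_wmt32k \
+  --model=transformer \
+  --hparams_set=transformer_base \
+  --data_dir=$DATA_DIR \
+  --output_dir=$OUTPUT_DIR \
+  --cloud_mlengine \
+  --worker_gpu=4
+```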
+ +You can additionally pass the `--cloud_mlengine_master_type` to select another +kind of machine (see the [docs for +`masterType`](https://cloud.google.com/ml-engine/reference/rest/v1/projects.jobs#traininginput) +for your options). If you provide this flag yourself, make sure you pass the +correct value for `--worker_gpu`. + +**Note**: `t2t-trainer` only currently supports launching with single machines, +possibly with multiple GPUs. Multi-machine setups are not yet supported out of +the box with the `--cloud_mlengine` flag, though multi-machine should in +principle work just fine. Contributions/testers welcome. + +## `--t2t_usr_dir` + +Launching on Cloud ML Engine works with `--t2t_usr_dir` as well as long as the +directory is fully self-contained (i.e. the imports only refer to other modules +in the directory). If there are additional PyPI dependencies that you need, you +can include a `setup.py` file in your directory (ensure that it uses +`setuptools.find_packages`). + +# Hyperparameter Tuning + +Hyperparameter tuning with `t2t-trainer` and Cloud ML Engine is also a breeze +with `--hparams_range` and the `--autotune_*` flags: + +``` +t2t-trainer \ + --problems=translate_ende_wmt32k \ + --model=transformer \ + --hparams_set=transformer_base \ + --data_dir=$DATA_DIR \ + --output_dir=$OUTPUT_DIR \ + --cloud_mlengine \ + --hparams_range=transformer_base_range \ + --autotune_objective='metrics-translate_ende_wmt32k/neg_log_perplexity' \ + --autotune_maximize \ + --autotune_max_trials=100 \ + --autotune_parallel_trials=3 +``` + +The `--hparams_range` specifies the search space and should be registered with +`@register_ranged_hparams`. It defines a `RangedHParams` object that sets +search ranges and scales for various parameters. See `transformer_base_range` +in +[`transformer.py`](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/transformer.py) +for an example. + +The metric name passed as `--autotune_objective` should be exactly what you'd +see in TensorBoard. To minimize a metric, set `--autotune_maximize=False`. + +You control how many total trials to run with `--autotune_max_trials` and the +number of jobs to launch in parallel with `--autotune_parallel_trials`. + +Happy tuning! diff --git a/setup.py b/setup.py index ee0eb0d09..1153dbba8 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setup( name='tensor2tensor', - version='1.4.3', + version='1.4.4', description='Tensor2Tensor', author='Google Inc.', author_email='no-reply@google.com', @@ -35,9 +35,9 @@ 'flask', 'future', 'gevent', + 'google-api-python-client', 'gunicorn', 'gym<=0.9.5', # gym in version 0.9.6 has some temporary issues. 
- 'munch', 'numpy', 'requests', 'scipy', diff --git a/tensor2tensor/bin/t2t-rl-trainer b/tensor2tensor/bin/t2t-rl-trainer deleted file mode 100644 index 06c97d2d5..000000000 --- a/tensor2tensor/bin/t2t-rl-trainer +++ /dev/null @@ -1,16 +0,0 @@ -#!/usr/bin/env python -"""t2t-rl-trainer.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensor2tensor.bin import t2t_rl_trainer - -import tensorflow as tf - -def main(argv): - t2t_rl_trainer.main(argv) - - -if __name__ == "__main__": - tf.app.run() diff --git a/tensor2tensor/bin/t2t_trainer.py b/tensor2tensor/bin/t2t_trainer.py index 8f1f0dfdc..469734883 100644 --- a/tensor2tensor/bin/t2t_trainer.py +++ b/tensor2tensor/bin/t2t_trainer.py @@ -26,7 +26,8 @@ from tensor2tensor import models # pylint: disable=unused-import from tensor2tensor import problems as problems_lib # pylint: disable=unused-import -from tensor2tensor.utils import cloud +from tensor2tensor.utils import cloud_mlengine +from tensor2tensor.utils import cloud_tpu from tensor2tensor.utils import decoding from tensor2tensor.utils import flags as t2t_flags # pylint: disable=unused-import from tensor2tensor.utils import registry @@ -81,6 +82,34 @@ flags.DEFINE_bool("cloud_delete_on_done", False, "Whether to delete the VM and TPU instance when done.") +# Google Cloud ML Engine +flags.DEFINE_bool("cloud_mlengine", False, + "Whether to launch on Cloud ML Engine.") +flags.DEFINE_string("cloud_mlengine_master_type", None, + "Machine type for master on Cloud ML Engine. " + "If provided, overrides default selections based on " + "--worker_gpu. User is responsible for ensuring " + "type is valid and that --worker_gpu matches number of " + "GPUs on machine type. See documentation: " + "https://cloud.google.com/ml-engine/reference/rest/v1/" + "projects.jobs#traininginput") +# Hyperparameter tuning on Cloud ML Engine +# Pass an --hparams_range to enable +flags.DEFINE_string("autotune_objective", None, + "TensorBoard metric name to optimize.") +flags.DEFINE_bool("autotune_maximize", True, + "Whether to maximize (vs. minimize) autotune_objective.") +flags.DEFINE_integer("autotune_max_trials", 10, + "Maximum number of tuning experiments to run.") +flags.DEFINE_integer("autotune_parallel_trials", 1, + "How many trials to run in parallel (will spin up this " + "many jobs.") +# Note than in open-source TensorFlow, the dash gets converted to an underscore, +# so access is FLAGS.job_dir. +flags.DEFINE_string("job-dir", None, + "DO NOT USE. Exists only for Cloud ML Engine to pass in " + "during hyperparameter tuning. Overrides --output_dir.") + def get_problem_name(): problems = FLAGS.problems.split("-") @@ -88,6 +117,33 @@ def get_problem_name(): return problems[0] +def set_hparams_from_args(args): + """Set hparams overrides from unparsed args list.""" + if not args: + return + + hp_prefix = "--hp_" + tf.logging.info("Found unparsed command-line arguments. 
Checking if any " + "start with %s and interpreting those as hparams " + "settings.", hp_prefix) + + pairs = [] + i = 0 + while i < len(args): + arg = args[i] + if arg.startswith(hp_prefix): + pairs.append((arg.lstrip(hp_prefix), args[i+1])) + i += 2 + else: + tf.logging.warn("Found unknown flag: %s", arg) + i += 1 + + as_hparams = ",".join(["%s=%s" % (key, val) for key, val in pairs]) + if FLAGS.hparams: + as_hparams = "," + as_hparams + FLAGS.hparams += as_hparams + + def create_hparams(): if (FLAGS.cloud_tpu or FLAGS.use_tpu) and "tpu" not in FLAGS.hparams_set: tf.logging.warn("Not all hyperparameter sets work on TPU. " @@ -244,7 +300,7 @@ def maybe_cloud_tpu(): "be gs:// paths, i.e. on Google Cloud Storage.") FLAGS.use_tpu = True - with cloud.cloud_tpu( + with cloud_tpu.cloud_tpu( FLAGS.cloud_vm_name, FLAGS.cloud_tpu_name, delete_on_done=FLAGS.cloud_delete_on_done) as tpu_master: @@ -252,15 +308,23 @@ def maybe_cloud_tpu(): yield -def main(_): +def main(argv): tf.logging.set_verbosity(tf.logging.INFO) trainer_lib.set_random_seed(FLAGS.random_seed) usr_dir.import_usr_dir(FLAGS.t2t_usr_dir) log_registry() + if FLAGS.cloud_mlengine: + return cloud_mlengine.launch() + if FLAGS.generate_data: generate_data() + if hasattr(FLAGS, "job_dir") and FLAGS.job_dir: + FLAGS.output_dir = FLAGS.job_dir + + if argv: + set_hparams_from_args(argv[1:]) hparams = create_hparams() if is_chief(): save_metadata(hparams) diff --git a/tensor2tensor/bin/t2t_trainer_test.py b/tensor2tensor/bin/t2t_trainer_test.py new file mode 100644 index 000000000..b1f38cec5 --- /dev/null +++ b/tensor2tensor/bin/t2t_trainer_test.py @@ -0,0 +1,50 @@ +# coding=utf-8 +# Copyright 2017 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for t2t_trainer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# Dependency imports + +from tensor2tensor.bin import t2t_trainer +from tensor2tensor.utils import trainer_lib_test + +import tensorflow as tf + +FLAGS = tf.flags.FLAGS + + +class TrainerTest(tf.test.TestCase): + + @classmethod + def setUpClass(cls): + trainer_lib_test.TrainerLibTest.setUpClass() + + def testTrain(self): + FLAGS.problems = "tiny_algo" + FLAGS.model = "transformer" + FLAGS.hparams_set = "transformer_tiny" + FLAGS.train_steps = 1 + FLAGS.eval_steps = 1 + FLAGS.output_dir = tf.test.get_temp_dir() + FLAGS.data_dir = tf.test.get_temp_dir() + t2t_trainer.main(None) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensor2tensor/data_generators/librispeech.py b/tensor2tensor/data_generators/librispeech.py index ad8e931d8..0c59824c1 100644 --- a/tensor2tensor/data_generators/librispeech.py +++ b/tensor2tensor/data_generators/librispeech.py @@ -66,8 +66,7 @@ def _collect_data(directory, input_ext, transcription_ext): transcript_path = os.path.join(root, transcript) with open(transcript_path, "r") as transcript_file: for transcript_line in transcript_file: - line_contents = transcript_line.split(" ", 1) - assert len(line_contents) == 2 + line_contents = transcript_line.strip().split(" ", 1) media_base, label = line_contents key = os.path.join(root, media_base) assert key not in data_files diff --git a/tensor2tensor/data_generators/problem.py b/tensor2tensor/data_generators/problem.py index 890271dbe..a2c330c2d 100644 --- a/tensor2tensor/data_generators/problem.py +++ b/tensor2tensor/data_generators/problem.py @@ -517,16 +517,26 @@ def _maybe_reverse_and_copy(example): if shuffle_files: random.shuffle(data_files) dataset = tf.data.Dataset.from_tensor_slices(tf.constant(data_files)) - dataset = dataset.apply( - tf.contrib.data.parallel_interleave( - _load_records, sloppy=is_training, cycle_length=8)) + + if hasattr(tf.contrib.data, "parallel_interleave"): + dataset = dataset.apply( + tf.contrib.data.parallel_interleave( + _load_records, sloppy=is_training, cycle_length=8)) + else: + dataset = dataset.interleave(_load_records, cycle_length=8, + block_length=16) + if repeat: dataset = dataset.repeat() dataset = dataset.map(self.decode_example, num_parallel_calls=num_threads) if preprocess: - dataset = dataset.apply( - tf.contrib.data.parallel_interleave( - _preprocess, sloppy=is_training, cycle_length=8)) + if hasattr(tf.contrib.data, "parallel_interleave"): + dataset = dataset.apply( + tf.contrib.data.parallel_interleave( + _preprocess, sloppy=is_training, cycle_length=8)) + else: + dataset = dataset.interleave(_preprocess, cycle_length=8, + block_length=16) dataset = dataset.map( _maybe_reverse_and_copy, num_parallel_calls=num_threads) @@ -633,6 +643,8 @@ def _dataset_partition(self, mode, config): num_partitions: an integer """ if mode != tf.estimator.ModeKeys.TRAIN or not hasattr(config, "tpu_config"): + # Reset in the case when using TPU but alternating TRAIN and EVAL. 
+ self._next_partition_id = 0 return 0, 1 if config.tpu_config.per_host_input_for_training: num_partitions = max(config.tpu_config.num_shards // 8, 1) @@ -670,7 +682,7 @@ def input_fn(self, partition_id, num_partitions = self._dataset_partition(mode, config) is_training = mode == tf.estimator.ModeKeys.TRAIN - if config.use_tpu: + if config and config.use_tpu: num_threads = 64 else: num_threads = 4 if is_training else 1 diff --git a/tensor2tensor/data_generators/speech_recognition.py b/tensor2tensor/data_generators/speech_recognition.py index 01a3db564..e17e4de85 100644 --- a/tensor2tensor/data_generators/speech_recognition.py +++ b/tensor2tensor/data_generators/speech_recognition.py @@ -32,6 +32,7 @@ from tensor2tensor.data_generators import problem from tensor2tensor.data_generators import text_encoder +from tensor2tensor.layers import common_attention from tensor2tensor.layers import common_layers from tensor2tensor.utils import metrics from tensor2tensor.utils import modality @@ -76,7 +77,7 @@ def compute_mel_filterbank_features( frame_length=25, frame_step=10, fft_length=None, window_fn=functools.partial(tf.contrib.signal.hann_window, periodic=True), lower_edge_hertz=80.0, upper_edge_hertz=7600.0, num_mel_bins=80, - log_noise_floor=1e-3): + log_noise_floor=1e-3, apply_mask=True): """Implement mel-filterbank extraction using tf ops. Args: @@ -93,6 +94,7 @@ def compute_mel_filterbank_features( upper_edge_hertz: highest frequency of the filterbank num_mel_bins: filterbank size log_noise_floor: clip small values to prevent numeric overflow in log + apply_mask: When working on a batch of samples, set padding frames to zero Returns: filterbanks: a float32 tensor with shape [batch_size, len, num_bins, 1] """ @@ -100,14 +102,24 @@ def compute_mel_filterbank_features( # Transform of each signal in `signals`. Its shape is # [batch_size, ?, fft_unique_bins] # where fft_unique_bins = fft_length // 2 + 1 + + # Find the wave length: the largest index for which the value is !=0 + # note that waveforms samples that are exactly 0.0 are quite common, so + # simply doing sum(waveforms != 0, axis=-1) will not work correctly. + wav_lens = tf.reduce_max( + tf.expand_dims(tf.range(tf.shape(waveforms)[1]), 0) * + tf.to_int32(tf.not_equal(waveforms, 0.0)), + axis=-1) + 1 if dither > 0: waveforms += tf.random_normal(tf.shape(waveforms), stddev=dither) if preemphasis > 0: waveforms = waveforms[:, 1:] - preemphasis * waveforms[:, :-1] + wav_lens -= 1 frame_length = int(frame_length * sample_rate / 1e3) frame_step = int(frame_step * sample_rate / 1e3) if fft_length is None: fft_length = int(2**(np.ceil(np.log2(frame_length)))) + stfts = tf.contrib.signal.stft( waveforms, frame_length=frame_length, @@ -116,6 +128,11 @@ def compute_mel_filterbank_features( window_fn=window_fn, pad_end=True) + stft_lens = (wav_lens + (frame_step - 1)) // frame_step + masks = tf.to_float(tf.less_equal( + tf.expand_dims(tf.range(tf.shape(stfts)[1]), 0), + tf.expand_dims(stft_lens, 1))) + # An energy spectrogram is the magnitude of the complex-valued STFT. # A float32 Tensor of shape [batch_size, ?, 257]. 
magnitude_spectrograms = tf.abs(stfts) @@ -134,7 +151,10 @@ def compute_mel_filterbank_features( log_mel_sgram = tf.log(tf.maximum(log_noise_floor, mel_spectrograms)) - return tf.expand_dims(log_mel_sgram, -1) + if apply_mask: + log_mel_sgram *= tf.expand_dims(tf.to_float(masks), -1) + + return tf.expand_dims(log_mel_sgram, -1, name="mel_sgrams") # @@ -207,12 +227,21 @@ def vocab_size(self): return 256 +class ByteTextEncoderWithEos(text_encoder.ByteTextEncoder): + """Encodes each byte to an id and appends the EOS token.""" + + def encode(self, s): + return super(ByteTextEncoderWithEos, self).encode(s) + [text_encoder.EOS_ID] + + class SpeechRecognitionProblem(problem.Problem): """Base class for speech recognition problems.""" def hparams(self, defaults, model_hparams): p = model_hparams # Filterbank extraction + # Filterbank extraction in bottom instead of preprocess_example is faster. + p.add_hparam("audio_preproc_in_bottom", False) # The trainer seems to reserve memory for all members of the input dict p.add_hparam("audio_keep_example_waveforms", False) p.add_hparam("audio_sample_rate", 16000) @@ -248,7 +277,7 @@ def feature_encoders(self, _): # decoding.py doesn't try to convert the floats # into text... "waveforms": AudioEncoder(), - "targets": text_encoder.ByteTextEncoder(), + "targets": ByteTextEncoderWithEos(), } def example_reading_spec(self): @@ -263,25 +292,30 @@ def example_reading_spec(self): def preprocess_example(self, example, mode, hparams): p = hparams - waveforms = tf.expand_dims(example["waveforms"], 0) - mel_fbanks = compute_mel_filterbank_features( - waveforms, - sample_rate=p.audio_sample_rate, - dither=p.audio_dither, - preemphasis=p.audio_preemphasis, - frame_length=p.audio_frame_length, - frame_step=p.audio_frame_step, - lower_edge_hertz=p.audio_lower_edge_hertz, - upper_edge_hertz=p.audio_upper_edge_hertz, - num_mel_bins=p.audio_num_mel_bins) - if p.audio_add_delta_deltas: - mel_fbanks = add_delta_deltas(mel_fbanks) - fbank_size = common_layers.shape_list(mel_fbanks) - assert fbank_size[0] == 1 - # Later models like to flatten the two spatial dims. Instead, we add a - # unit spatial dim and flatten the frequencies and channels. - example["inputs"] = tf.reshape( - mel_fbanks, [fbank_size[1], 1, fbank_size[2] * fbank_size[3]]) + if p.audio_preproc_in_bottom: + example["inputs"] = tf.expand_dims( + tf.expand_dims(example["waveforms"], -1), -1) + else: + waveforms = tf.expand_dims(example["waveforms"], 0) + mel_fbanks = compute_mel_filterbank_features( + waveforms, + sample_rate=p.audio_sample_rate, + dither=p.audio_dither, + preemphasis=p.audio_preemphasis, + frame_length=p.audio_frame_length, + frame_step=p.audio_frame_step, + lower_edge_hertz=p.audio_lower_edge_hertz, + upper_edge_hertz=p.audio_upper_edge_hertz, + num_mel_bins=p.audio_num_mel_bins, + apply_mask=False) + if p.audio_add_delta_deltas: + mel_fbanks = add_delta_deltas(mel_fbanks) + fbank_size = common_layers.shape_list(mel_fbanks) + assert fbank_size[0] == 1 + # Later models like to flatten the two spatial dims. Instead, we add a + # unit spatial dim and flatten the frequencies and channels. 
+ example["inputs"] = tf.reshape( + mel_fbanks, [fbank_size[1], 1, fbank_size[2] * fbank_size[3]]) if not p.audio_keep_example_waveforms: del example["waveforms"] return super(SpeechRecognitionProblem, self @@ -306,37 +340,72 @@ def bottom(self, inputs): float32 tensor with shape [batch_size, shorter_len, 1, hidden_size] """ p = self._model_hparams - training = p.mode == tf.estimator.ModeKeys.TRAIN + + num_mel_bins = p.audio_num_mel_bins + num_channels = 3 if p.audio_add_delta_deltas else 1 with tf.variable_scope(self.name): - x = inputs - num_mel_bins = p.audio_num_mel_bins - num_channels = 3 if p.audio_add_delta_deltas else 1 + if p.audio_preproc_in_bottom: + # Compute filterbanks + with tf.variable_scope("fbanks"): + waveforms = tf.squeeze(inputs, [2, 3]) + mel_fbanks = compute_mel_filterbank_features( + waveforms, + sample_rate=p.audio_sample_rate, + dither=p.audio_dither, + preemphasis=p.audio_preemphasis, + frame_length=p.audio_frame_length, + frame_step=p.audio_frame_step, + lower_edge_hertz=p.audio_lower_edge_hertz, + upper_edge_hertz=p.audio_upper_edge_hertz, + num_mel_bins=p.audio_num_mel_bins, + apply_mask=True) + if p.audio_add_delta_deltas: + mel_fbanks = add_delta_deltas(mel_fbanks) + x = tf.reshape(mel_fbanks, + common_layers.shape_list(mel_fbanks)[:2] + + [1, num_mel_bins * num_channels]) + else: + x = inputs + # The convention is that the models are flattened along the spatial, # dimensions, thus the speech preprocessor treats frequencies and # channels as image colors (last axis) x.set_shape([None, None, 1, num_mel_bins * num_channels]) + xshape = common_layers.shape_list(x) + + nonpadding_mask = 1. - common_attention.embedding_to_padding(x) + num_of_nonpadding_elements = tf.reduce_sum( + nonpadding_mask) * num_mel_bins * num_channels + # This replaces CMVN estimation on data - x = tf.layers.batch_normalization( - x, axis=3, center=False, scale=False, training=training) + mean = tf.reduce_sum( + x, axis=[1, 2], keepdims=True) / num_of_nonpadding_elements + variance = (num_of_nonpadding_elements * mean**2. - + 2. * mean * tf.reduce_sum(x, axis=[1, 2], keepdims=True) + + tf.reduce_sum(x**2, axis=[1, 2], keepdims=True) + ) / num_of_nonpadding_elements + x = (x - mean) / variance * tf.expand_dims(nonpadding_mask, -1) - xshape = common_layers.shape_list(x) # restore batch_size x time x frequency x channel layout x = tf.reshape(x, [xshape[0], xshape[1], num_mel_bins, num_channels]) # TODO(chorowski): how to specify bottom's hparams and avoid hardcoding? 
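      # The two stride-2 convolutions below halve the time and frequency
      # dimensions twice (4x total); the tf.pad on the time axis keeps the
      # trailing frames from being dropped by the VALID-padded convolutions.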
for _ in range(2): + x = tf.pad(x, [[0, 0], [0, 2], [0, 0], [0, 0]]) x = tf.layers.conv2d( x, 128, (3, 3), (2, 2), use_bias=False) - x = tf.layers.batch_normalization(x, axis=3, training=training) + x = common_layers.layer_norm(x) x = tf.nn.relu(x) xshape = common_layers.shape_list(x) # apply a conv that will remove all frequencies and at the same time # project the output into desired hidden_size + x = tf.pad(x, [[0, 0], [0, 2], [0, 0], [0, 0]]) x = tf.layers.conv2d(x, p.hidden_size, (3, xshape[2]), use_bias=False) + assert common_layers.shape_list(x)[2] == 1 - x = tf.layers.batch_normalization(x, axis=3, training=training) + x = common_layers.layer_norm(x) x = tf.nn.relu(x) return x diff --git a/tensor2tensor/layers/common_attention.py b/tensor2tensor/layers/common_attention.py index e82e6d471..63bf8d6cd 100644 --- a/tensor2tensor/layers/common_attention.py +++ b/tensor2tensor/layers/common_attention.py @@ -192,9 +192,9 @@ def memeff_attention_fn(*args, **kwargs): attention_type="local_mask_right", ) - # === Memory-compressed multihead self attention layer === + # === Masked memory-compressed multihead self attention layer === # Only works for self attention. Always mask the future. - compressed_attention_fn = register_layer( + compressed_attention_masked_fn = register_layer( multihead_self_attention_reduced, default_kwargs=dict( factor=hparams.attention_red_factor, @@ -209,6 +209,13 @@ def memeff_attention_fn(*args, **kwargs): ), ) + # === Unmasked memory-compressed multihead self attention layer === + # Only works for self attention. Never mask the future. Bias never added + compressed_attention_fn = partial( + compressed_attention_masked_fn, + add_mask=False, + ) + # Feed-forwards layers: # === Mixture of expert layer === @@ -259,14 +266,17 @@ def memeff_attention_fn(*args, **kwargs): # Define all available layers layers = dict( + # Attention layers: a=multihead_attention_fn, # Multihead full attention loc=local_attention_fn, # Local attention - locm=local_attention_masked_fn, # Local masked attention + locm=local_attention_masked_fn, # Local attention (masked) red=compressed_attention_fn, # Memory-compressed attention + redm=compressed_attention_masked_fn, # Memory-compressed att (masked) mem=memeff_attention_fn, # Memory efficient - fc=conv_hidden_relu, - sep=sep_conv_relu, # Fully connected - sepm=sep_conv_relu_masked, # masked separable convolution + # Feed-forward layers: + fc=conv_hidden_relu, # Fully connected + sep=sep_conv_relu, # Separable convolution (unmasked) + sepm=sep_conv_relu_masked, # Separable convolution (masked) moe=distributed_moe, # Mixture of expert layer ) return layers @@ -317,6 +327,34 @@ def add_standard_attention_hparams(hparams): return hparams +def encoder_decoder_attention_loss(expected_attention, actual_attentions): + """Computes encdec attention loss between expected and actual attentions. + + Args: + expected_attention: Tensor storing the expected encoder-decoder attention + weights with shape [batch_size, target_length, input_length]. + actual_attentions: Dictionary with actual attention weights for different + attention types and hidden layers. + + Returns: + MSE loss between the actual and expected attention weights. + """ + # For each hidden layer, we have an attention weight tensor with shape + # [batch_size, num_heads, target_length, input_length]. 
+ actual_encdec_attention_weights = [ + t for layer_key, t in actual_attentions.items() + if "encdec_attention" in layer_key + ] + # Stack all hidden layer attention weight tensors to get a tensor with shape + # [num_hidden_layers, batch_size, num_heads, target_length, input_length]. + actual_attention_weights = tf.stack(actual_encdec_attention_weights) + # Reduce mean across all layers (axis=0) and all heads (axis=2) to get a + # tensor with shape [batch_size, target_length, input_length]. + actual_attention_weights = tf.reduce_mean(actual_attention_weights, [0, 2]) + return tf.losses.mean_squared_error(expected_attention, + actual_attention_weights) + + @expert_utils.add_name_scope() def get_timing_signal_1d(length, channels, @@ -3392,7 +3430,7 @@ def pad_and_reshape(x): block_length, block_length, # Restore the block length dimension ]) - weights = tf.reduce_sum(weights, axis=3, keepdims=True) # Compress block + weights = tf.reduce_sum(weights, axis=3, keep_dims=True) # Compress block v_out = tf.matmul(weights, v) # [1, block_length] @ [block_length, depth] v_out = tf.squeeze(v_out, axis=3) return v_out @@ -3415,6 +3453,7 @@ def multihead_self_attention_reduced( multihead_params=None, nonlinearity="none", reduction_type="conv", + add_mask=True, ): """Reduce the length dimension by compressing with conv. @@ -3426,6 +3465,7 @@ def multihead_self_attention_reduced( multihead_params (dict): parameters for multihead attention nonlinearity (str): Add some non-linearity after the memory block reduction_type (str): type of compression + add_mask (bool): If True, add the bias to prevent attention to the future Returns: (tf.Tensor): float32 of shape [batch, length, depth] @@ -3475,18 +3515,21 @@ def construct_bias_vectors(t, axis): # [1, length_k] or [length_q, 1] return length_coordinates - bias = tf.to_float( - tf.greater( - # Because we add the first elem to the memory block and it can be - # attended by anyone,we don't need to add +1 anymore to prevent self - # attention Use * factor to make sure the last tokens of a block - # cannot attend the block - construct_bias_vectors(memory_x, 0) * factor, - # +epsilon to avoid float equality - construct_bias_vectors(x, 1) + 1e-3, - )) * -1e9 - bias = tf.expand_dims(bias, axis=0) - bias = tf.expand_dims(bias, axis=0) # [1, 1, length_k, length_q] + if add_mask: # Create mask to prevent attention to the future + bias = tf.to_float( + tf.greater( + # Because we add the first elem to the memory block and it can be + # attended by anyone,we don't need to add +1 anymore to prevent self + # attention Use * factor to make sure the last tokens of a block + # cannot attend the block + construct_bias_vectors(memory_x, 0) * factor, + # +epsilon to avoid float equality + construct_bias_vectors(x, 1) + 1e-3, + )) * -1e9 + bias = tf.expand_dims(bias, axis=0) + bias = tf.expand_dims(bias, axis=0) # [1, 1, length_k, length_q] + else: + bias = None return multihead_attention( query_antecedent=x, diff --git a/tensor2tensor/layers/common_hparams.py b/tensor2tensor/layers/common_hparams.py index 02a5df2f3..8fbd88bd2 100644 --- a/tensor2tensor/layers/common_hparams.py +++ b/tensor2tensor/layers/common_hparams.py @@ -21,7 +21,6 @@ # Dependency imports -import six from six.moves import zip # pylint: disable=redefined-builtin from tensor2tensor.utils import registry @@ -64,6 +63,11 @@ def basic_params1(): optimizer_momentum_nesterov=False, weight_decay=1e-6, weight_noise=0.0, + learning_rate_schedule="warmup_and_decay", + # If learning_rate_schedule=="warmup_and_decay", then 
this specifies + # the decay part of the schedule. + # The warmup is always exponential. + # TODO(noam): add a hyperparameter to control the warmup. learning_rate_decay_scheme="none", # decay_steps and decay_staircase for learning_rate_decay_scheme=="exp" learning_rate_decay_steps=5000, @@ -224,10 +228,15 @@ class RangedHParams(object): LOG_SCALE = 2 REVERSE_LOG_SCALE = 3 + SCALES_STR = { + LINEAR_SCALE: "UNIT_LINEAR_SCALE", + LOG_SCALE: "UNIT_LOG_SCALE", + REVERSE_LOG_SCALE: "UNIT_REVERSE_LOG_SCALE", + } + def __init__(self): self._categorical_params = {} self._discrete_params = {} - self._discrete_float_params = {} self._float_params = {} self._int_params = {} @@ -237,10 +246,12 @@ def _check_reset_and_type_change(self, name, orig_ctr): if name in orig_ctr: tf.logging.warning("Overwriting hparam %s", name) - ctr_names = [(self._categorical_params, - "categorical"), (self._discrete_params, "discrete"), - (self._float_params, "float"), (self._int_params, "int"), - (self._discrete_float_params, "discrete_float")] + ctr_names = [ + (self._categorical_params, "categorical"), + (self._discrete_params, "discrete"), + (self._float_params, "float"), + (self._int_params, "int"), + ] ctrs, names = list(zip(*ctr_names)) orig_name = names[ctrs.index(orig_ctr)] @@ -263,23 +274,17 @@ def set_discrete(self, name, feasible_points, scale=None, length=None): self._discrete_params[name] = (name, feasible_points, scale, length) def set_float(self, name, min_val, max_val, scale=None, length=None): - if name in self._discrete_float_params: - del self._discrete_float_params[name] self._check_reset_and_type_change(name, self._float_params) self._float_params[name] = (name, min_val, max_val, scale, length) - def set_discrete_float(self, name, val): - self._check_reset_and_type_change(name, self._discrete_float_params) - self._discrete_float_params[name] = (name, [val]) - def set_int(self, name, min_val, max_val, scale=None, length=None): self._check_reset_and_type_change(name, self._int_params) self._int_params[name] = (name, min_val, max_val, scale, length) def fix_select_params(self, hp): ctrs = [ - self._categorical_params, self._discrete_params, - self._discrete_float_params, self._float_params, self._int_params + self._categorical_params, self._discrete_params, self._float_params, + self._int_params ] for key, val in hp.values().iteritems(): for ctr in ctrs: @@ -287,52 +292,56 @@ def fix_select_params(self, hp): del ctr[key] self.set_discrete(key, [val]) + def to_parameter_specs(self, name_prefix=""): + """To list of dicts suitable for Cloud ML Engine hyperparameter tuning.""" + specs = [] + for name, categories, _ in self._categorical_params.values(): + spec = { + "parameterName": name_prefix + name, + "type": "CATEGORICAL", + "categoricalValues": categories, + } + specs.append(spec) -def fill_ranged_hparams_from_hparams(hparams, ranged_hparams): - """Fill ranged_hparams with singleton values from hparams. 
+ for name, feasible_points, scale, _ in self._discrete_params.values(): + spec = { + "parameterName": name_prefix + name, + "type": "DISCRETE", + "discreteValues": feasible_points, + } + if scale: + spec["scaleType"] = self.SCALES_STR[scale] + specs.append(spec) - HParams are placed in RangedHParams with the following functions, according to - type: - * int: set_discrete - * bool: set_discrete - * float: set_discrete_float - * str: set_categorical + for name, min_val, max_val, scale, _ in self._float_params.values(): + spec = { + "parameterName": name_prefix + name, + "type": "DOUBLE", + "minValue": min_val, + "maxValue": max_val, + } + if scale: + spec["scaleType"] = self.SCALES_STR[scale] + specs.append(spec) - Args: - hparams: tf.contrib.training.HParams; contains the hyperparameters to copy - over to ranged_hparams. - ranged_hparams: RangedHParams; will have hparams values copied to it. + for name, min_val, max_val, scale, _ in self._int_params.values(): + spec = { + "parameterName": name_prefix + name, + "type": "INTEGER", + "minValue": min_val, + "maxValue": max_val, + } + if scale: + spec["scaleType"] = self.SCALES_STR[scale] + specs.append(spec) - Raises: - ValueError: if hparams contains a hyperparameter not of type - {int, float, str, bool}. - """ - for name, (hp_type, is_multivalent) in six.iteritems(hparams._hparam_types): # pylint: disable=protected-access - - if is_multivalent: - raise ValueError("Multivalent hparams not supported in RangedHParams. " - "Hyperparameter %s is multivalent." % name) - val = getattr(hparams, name) - if hp_type == int: - ranged_hparams.set_discrete(name, [val]) - elif hp_type == bool: - ranged_hparams.set_discrete(name, [int(val)]) - elif hp_type == float: - ranged_hparams.set_discrete_float(name, val) - elif hp_type == str: - ranged_hparams.set_categorical(name, [val]) - else: - raise ValueError("Unsupported type %s for param %s" % (hp_type, name)) + return specs @registry.register_ranged_hparams("basic1") def basic_range1(ranged_hparams): """A basic range of hyperparameters.""" rhp = ranged_hparams - - hparams = basic_params1() - fill_ranged_hparams_from_hparams(hparams, rhp) - rhp.set_discrete("batch_size", [1024, 2048, 4096]) rhp.set_discrete("num_hidden_layers", [1, 2, 3, 4, 5, 6]) rhp.set_discrete("hidden_size", [32, 64, 128, 256, 512], scale=rhp.LOG_SCALE) diff --git a/tensor2tensor/layers/common_image_attention.py b/tensor2tensor/layers/common_image_attention.py new file mode 100644 index 000000000..88fc8ed93 --- /dev/null +++ b/tensor2tensor/layers/common_image_attention.py @@ -0,0 +1,544 @@ +# coding=utf-8 +# Copyright 2017 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Utils for attention mechanism for images.""" +from tensor2tensor.layers import common_attention +from tensor2tensor.layers import common_layers +from tensor2tensor.utils import expert_utils + +import tensorflow as tf + + +class AttentionType(object): + LOCAL_1D = "local_1d" + LOCAL_2D = "local_2d" + GLOBAL = "global" + GLOCAL = "global_local" + MOE_LOCAL_1D = "moe_local1d" + + @staticmethod + def get_choices(): + return [ + AttentionType.GLOBAL, + AttentionType.GLOCAL, + AttentionType.MOE_LOCAL_1D, + AttentionType.LOCAL_1D, + AttentionType.LOCAL_2D, + ] + + +def maybe_reshape_4d_to_3d(x, hparams): + """Reshape input from 4D to 3D if necessary.""" + x_shape = common_layers.shape_list(x) + is_4d = False + if len(x_shape) == 4: + x = tf.reshape(x, [x_shape[0], x_shape[1]*x_shape[2], x_shape[3]]) + is_4d = True + x.set_shape([None, None, hparams.hidden_size]) + return x, x_shape, is_4d + + +def local_attention_2d(x, hparams, attention_type="local_attention_2d"): + """Local 2d, self attention layer.""" + # self-attention + with tf.variable_scope("local_2d_self_att"): + y = common_attention.multihead_attention_2d( + x, + None, + hparams.attention_key_channels or hparams.hidden_size, + hparams.attention_value_channels or hparams.hidden_size, + hparams.hidden_size, + hparams.num_heads, + attention_type=attention_type, + query_shape=hparams.query_shape, + memory_flange=hparams.memory_flange, + name="self_attention") + return y + + +def local_attention_1d(x, + self_attention_bias, + hparams, + attention_type="local_unmasked", + q_padding="VALID", + kv_padding="VALID"): + """Local 1d self attention.""" + # self-attention + x, x_shape, is_4d = maybe_reshape_4d_to_3d(x, hparams) + with tf.variable_scope("local_1d_self_att"): + y = common_attention.multihead_attention( + x, + None, + self_attention_bias, + hparams.attention_key_channels or hparams.hidden_size, + hparams.attention_value_channels or hparams.hidden_size, + hparams.hidden_size, + hparams.num_heads, + hparams.attention_dropout, + attention_type=attention_type, + block_width=hparams.block_width, + block_length=hparams.block_length, + q_padding=q_padding, + kv_padding=kv_padding, + q_filter_width=hparams.q_filter_width, + kv_filter_width=hparams.kv_filter_width, + name="self_attention") + if is_4d: + y = tf.reshape(y, x_shape) + y.set_shape([None, None, None, hparams.hidden_size]) + return y + + +def local_global_attention(x, + self_attention_bias, + hparams, + q_padding="LEFT", + kv_padding="LEFT"): + """Local and global 1d self attention.""" + with tf.variable_scope("self_local_global_att"): + [x_global, x_local] = tf.split(x, 2, axis=-1) + split_hidden_size = int(hparams.hidden_size / 2) + split_heads = int(hparams.num_heads / 2) + y_global = common_attention.multihead_attention( + x_global, + None, + self_attention_bias, + hparams.attention_key_channels or split_hidden_size, + hparams.attention_value_channels or split_hidden_size, + split_hidden_size, + split_heads, + hparams.attention_dropout, + q_filter_width=hparams.q_filter_width, + kv_filter_width=hparams.kv_filter_width, + q_padding=q_padding, + kv_padding=kv_padding, + name="global_self_att") + y_local = common_attention.multihead_attention( + x_local, + None, + self_attention_bias, + hparams.attention_key_channels or split_hidden_size, + hparams.attention_value_channels or split_hidden_size, + split_hidden_size, + split_heads, + hparams.attention_dropout, + attention_type="local_masked", + block_length=hparams.block_length, + block_width=hparams.block_width, + 
q_filter_width=hparams.q_filter_width, + kv_filter_width=hparams.kv_filter_width, + q_padding=q_padding, + kv_padding=kv_padding, + name="local_self_att") + y = tf.concat([y_global, y_local], axis=-1) + return y + + +def full_self_attention(x, + self_attention_bias, + hparams, + q_padding="LEFT", + kv_padding="LEFT"): + """Full self-attention layer.""" + x, x_shape, is_4d = maybe_reshape_4d_to_3d(x, hparams) + with tf.variable_scope("self_att"): + y = common_attention.multihead_attention( + x, + None, + self_attention_bias, + hparams.attention_key_channels or hparams.hidden_size, + hparams.attention_value_channels or hparams.hidden_size, + hparams.hidden_size, + hparams.num_heads, + hparams.attention_dropout, + q_filter_width=hparams.q_filter_width, + kv_filter_width=hparams.kv_filter_width, + q_padding=q_padding, + kv_padding=kv_padding, + name="self_att") + if is_4d: + y = tf.reshape(y, [x_shape[0], x_shape[1], x_shape[2], x_shape[3]]) + y.set_shape([None, None, None, hparams.hidden_size]) + return y + + +def encdec_attention_1d(x, + encoder_output, + hparams): + """Local 1d self attention.""" + x, x_shape, is_4d = maybe_reshape_4d_to_3d(x, hparams) + encoder_output, _, _ = maybe_reshape_4d_to_3d(encoder_output, hparams) + with tf.variable_scope("encdec_attention"): + # Encoder Decoder attention + y = common_attention.multihead_attention( + x, + encoder_output, + None, + hparams.attention_key_channels or hparams.hidden_size, + hparams.attention_value_channels or hparams.hidden_size, + hparams.hidden_size, + hparams.num_heads, + hparams.attention_dropout, + name="encdec_attention") + if is_4d: + y = tf.reshape(y, x_shape) + y.set_shape([None, None, None, hparams.hidden_size]) + return y + + +def transformer_decoder_layers(inputs, + encoder_output, + bias, + num_layers, + hparams, + attention_type=AttentionType.LOCAL_2D, + name="transformer"): + """Multi layer transformer.""" + x = inputs + x = tf.nn.dropout(x, 1.0 - hparams.layer_prepostprocess_dropout) + for layer in xrange(num_layers): + with tf.variable_scope("%s_layer_%d" % (name, layer)): + # self-attention + skip connections + if attention_type == AttentionType.LOCAL_2D: + y = local_attention_2d(common_layers.layer_preprocess(x, hparams), + hparams, + attention_type="masked_local_attention_2d") + elif attention_type == AttentionType.LOCAL_1D: + y = local_attention_1d(common_layers.layer_preprocess(x, hparams), + bias, hparams, + attention_type="local_mask_right", + q_padding="LEFT", kv_padding="LEFT") + elif attention_type == AttentionType.GLOCAL: + y = local_global_attention(common_layers.layer_preprocess(x, hparams), + bias, hparams, + q_padding="LEFT", kv_padding="LEFT") + elif attention_type == AttentionType.GLOBAL: + y = full_self_attention(common_layers.layer_preprocess(x, hparams), + bias, hparams, + q_padding="LEFT", kv_padding="LEFT") + # TODO(nikip): Add support for dilated attention. 
+ x = common_layers.layer_postprocess(x, y, hparams) + # enc-dec attention + skip connections + if encoder_output is not None: + y = encdec_attention_1d(common_layers.layer_preprocess(x, hparams), + encoder_output, hparams) + x = common_layers.layer_postprocess(x, y, hparams) + # feed-fwd layers + skip connections + y = ffn_layer(common_layers.layer_preprocess(x, hparams), hparams) + x = common_layers.layer_postprocess(x, y, hparams) + return common_layers.layer_preprocess(x, hparams) + + +def transformer_encoder_layers(inputs, + num_layers, + hparams, + attention_type=AttentionType.GLOBAL, + self_attention_bias=None, + q_padding="VALID", + kv_padding="VALID", + name="transformer"): + """Multi layer transformer encoder.""" + x = inputs + x = tf.nn.dropout(x, 1.0 - hparams.layer_prepostprocess_dropout) + + for layer in xrange(num_layers): + # attention layers + skip connections + with tf.variable_scope("%s_layer_%d" % (name, layer)): + if attention_type == AttentionType.LOCAL_2D: + y = local_attention_2d(common_layers.layer_preprocess(x, hparams), + hparams, + attention_type="local_attention_2d") + elif attention_type == AttentionType.LOCAL_1D: + y = local_attention_1d(common_layers.layer_preprocess(x, hparams), + self_attention_bias, hparams, + attention_type="local_unmasked", + q_padding=q_padding, kv_padding=kv_padding) + elif attention_type == AttentionType.GLOBAL: + y = full_self_attention(common_layers.layer_preprocess(x, hparams), + self_attention_bias, hparams, + q_padding=q_padding, kv_padding=kv_padding) + x = common_layers.layer_postprocess(x, y, hparams) + # feed-fwd layer + skip connections + y = ffn_layer(common_layers.layer_preprocess(x, hparams), hparams) + x = common_layers.layer_postprocess(x, y, hparams) + return common_layers.layer_preprocess(x, hparams) + + +def ffn_layer(x, hparams): + """ffn layer transformer.""" + with tf.variable_scope("ffn"): + if hparams.ffn_layer == "none": + return x + if hparams.ffn_layer == "conv_hidden_relu": + y = common_layers.dense_relu_dense( + x, + hparams.filter_size, + hparams.hidden_size, + dropout=hparams.relu_dropout) + elif hparams.ffn_layer == "normed_conv_hidden_relu": + y = common_layers.normed_conv_hidden_relu( + x, + hparams.norm_type, + hparams.layer_norm_epsilon, + hparams.filter_size, + hparams.hidden_size, + dropout=hparams.relu_dropout, + norm_name="convnorm") + elif hparams.ffn_layer == "self_attention_ffn": + x_shape = tf.shape(x) + x = tf.reshape(x, [x_shape[0], -1, hparams.hidden_size]) + y = common_attention.ffn_self_attention_layer( + x, hparams.filter_size, hparams.hidden_size, hparams.num_parts, + hparams.attention_dropout, hparams.share_kv) + y = tf.reshape(y, x_shape) + else: + assert hparams.ffn_layer == "glu_ffn" + y = common_layers.gated_linear_unit_layer(x) + return y + + +def transformer_layers_sharded(dp, + ps_devices, + inputs, + num_layers, + hparams, + self_attention_bias=None, + enc_output=None, + attention_type=AttentionType.GLOBAL, + name="transformer"): + """Multi layer transformer, sharded by the data parallelism dp.""" + x = inputs + extra_loss = tf.constant(0.0) + moe_hidden_sizes = [int(s) for s in hparams.moe_hidden_sizes.split(",")] + expert_fn = expert_utils.ffn_expert_fn( + hparams.hidden_size, moe_hidden_sizes, hparams.hidden_size) + x = dp(tf.nn.dropout, x, 1.0 - hparams.layer_prepostprocess_dropout) + for layer in xrange(num_layers): + with tf.variable_scope("%s_layer_%d" % (name, layer)): + # self-attention + if attention_type == AttentionType.LOCAL_2D: + y = 
dp(local_attention_2d(common_layers.layer_preprocess(x, hparams), + hparams, + attention_type="masked_local_attention_2d")) + elif attention_type == AttentionType.LOCAL_1D: + y = dp(local_attention_1d(common_layers.layer_preprocess(x, hparams), + self_attention_bias, hparams, + attention_type="local_mask_right", + q_padding="LEFT", kv_padding="LEFT")) + elif attention_type == AttentionType.GLOCAL: + y = dp(local_global_attention( + common_layers.layer_preprocess(x, hparams), self_attention_bias, + hparams, q_padding="LEFT", kv_padding="LEFT")) + elif attention_type == AttentionType.GLOBAL: + y = dp(full_self_attention(common_layers.layer_preprocess(x, hparams), + self_attention_bias, hparams, + q_padding="LEFT", kv_padding="LEFT")) + x = common_layers.layer_postprocess(x, y, hparams) + if enc_output is not None: + y = dp(encdec_attention_1d(common_layers.layer_preprocess(x, hparams), + enc_output, hparams)) + x = dp(common_layers.layer_postprocess, x, y, hparams) + with tf.variable_scope("ffn"): + if str(layer) in hparams.moe_layers_decoder.split(","): + y, loss = expert_utils.distributed_moe( + dp, + ps_devices, + common_layers.layer_preprocess(x, hparams), + hparams.mode == tf.estimator.ModeKeys.TRAIN, + input_size=hparams.hidden_size, + expert_fn=expert_fn, + num_experts=hparams.moe_num_experts, + k=hparams.moe_k, + loss_coef=hparams.moe_loss_coef) + extra_loss += loss + x = dp(common_layers.layer_postprocess, x, y, hparams) + else: + y = dp(ffn_layer, common_layers.layer_preprocess(x, hparams), hparams) + x = dp(common_layers.layer_postprocess, x, y, hparams) + return dp(common_layers.layer_preprocess, x, hparams), extra_loss + + +def postprocess_image(x, rows, cols, hparams): + """Postprocessing after decoding.""" + batch = common_layers.shape_list(x)[0] + channels = 256 + x = tf.reshape(x, [batch, rows, cols, hparams.hidden_size]) + # targets = common_layers.conv(x, 256, (1, 1), name="output_conv") + targets = tf.layers.dense(x, 256, use_bias=True, activation=None, + name="output_conv") + if hparams.mode == tf.contrib.learn.ModeKeys.INFER: + y = targets + y = tf.reshape(y, [batch, -1, hparams.img_len*3, channels]) + yshape = common_layers.shape_list(y) + block_length = hparams.query_shape[0] + block_width = hparams.query_shape[1] + + # Break into block row wise. + y = tf.reshape(y, + [batch, yshape[1] // block_length, + block_length, + yshape[2], channels]) + yshape = common_layers.shape_list(y) + # Break into blocks width wise. + y_blocks = tf.reshape(y, + [batch, yshape[1], yshape[2], + yshape[3] // block_width, + block_width, channels]) + + # Reshape targets as [batch_size, num_blocks_rows, num_block_cols, + # block_length, block_width, channels] + targets = tf.transpose(y_blocks, [0, 1, 3, 2, 4, 5]) + + return targets + + +def prepare_encoder(inputs, hparams, attention_type="local_1d"): + """Prepare encoder for images.""" + x = prepare_image(inputs, hparams, name="enc_channels") + # Add position signals. 
+ x = add_pos_signals(x, hparams, "enc_pos") + x_shape = common_layers.shape_list(x) + if attention_type == "local_1d": + x = tf.reshape(x, [x_shape[0], x_shape[1]*x_shape[2], hparams.hidden_size]) + x.set_shape([None, None, hparams.hidden_size]) + elif attention_type == "local_2d": + x.set_shape([None, None, None, hparams.hidden_size]) + return x + + +def prepare_decoder(targets, hparams): + """Prepare decoder for images.""" + targets_shape = common_layers.shape_list(targets) + channels = hparams.num_channels + curr_infer_length = None + + # during training, images are [batch, IMG_LEN, IMG_LEN, 3]. + # At inference, they are [batch, curr_infer_length, 1, 1] + if hparams.mode == tf.contrib.learn.ModeKeys.INFER: + curr_infer_length = targets_shape[1] + if hparams.block_rastor_scan: + assert hparams.img_len*channels % hparams.query_shape[1] == 0 + assert hparams.img_len % hparams.query_shape[0] == 0 + total_block_width = hparams.img_len*channels + # Decoding is in block rastor scan order. We divide the image into + # hparams.query_shape blocks and then decode each block in rastor scan. + # To make that compatible with our inference pipeline, pad the target so + # that rows is a multiple of query_shape and columns is a multiple of + # hparams.img_len*channels + curr_infer_length = targets_shape[1] + block_padding_factor = total_block_width * hparams.query_shape[0] + targets = tf.pad(targets, [ + [0, 0], [0, -curr_infer_length % block_padding_factor], + [0, 0], [0, 0]]) + + num_blocks = total_block_width // hparams.query_shape[1] + # Reshape the image to represent blocks + target_blocks = tf.reshape( + targets, [targets_shape[0], -1, num_blocks, hparams.query_shape[0], + hparams.query_shape[1]]) + # Transpose to read the image in 2D fashion. + targets = tf.transpose(target_blocks, [0, 1, 3, 2, 4]) + else: + # add padding to make sure the size of targets is a multiple of img_height + # times number of channels. This is needed for positional encodings and + # for doing the RGB lookup. + padding_factor = channels * hparams.img_len + targets = tf.pad(targets, [ + [0, 0], [0, -curr_infer_length % padding_factor], [0, 0], [0, 0]]) + targets = tf.reshape(targets, + [targets_shape[0], -1, hparams.img_len, channels]) + # Preprocess image + x = prepare_image(targets, hparams, name="dec_channels") + x_shape = common_layers.shape_list(x) + # mask out upper triangle to avoid looking into the future. + bias = common_attention.attention_bias_lower_triangle(x_shape[1]*x_shape[2]) + if hparams.dec_attention_type == AttentionType.LOCAL_2D: + x = common_attention.right_shift_blockwise(x, hparams.query_shape) + x = add_pos_signals(x, hparams, "dec_pos") + else: + # Add position signals + x = tf.reshape(x, [-1, x_shape[1]*x_shape[2], hparams.hidden_size]) + x = common_layers.shift_right_3d(x) + x = tf.reshape(x, [-1, x_shape[1], x_shape[2], hparams.hidden_size]) + x = add_pos_signals(x, hparams, "dec_pos") + x.set_shape([None, None, None, hparams.hidden_size]) + return x, x_shape[1], x_shape[2], bias + + +def prepare_image(inputs, hparams, name=None): + """Prepare image.""" + inputs_shape = common_layers.shape_list(inputs) + batch = inputs_shape[0] + orig_rows = inputs_shape[1] + orig_cols = inputs_shape[2] + channels = hparams.num_channels + + hidden_size = hparams.hidden_size + # Only do lookup if the embeddings haven't been looked up already. + # if the last dimension is number of channels, then this is very likely the + # channel ids tensor. We have to make sure. 
+ if inputs_shape[-1] == hparams.num_channels: + inputs = tf.to_int32(inputs) + x = get_channel_embeddings(channels, inputs, hidden_size, name=name) + else: + x = inputs + x = tf.reshape(x, [batch, orig_rows, orig_cols * channels, hidden_size]) + + return x + + +def create_output(decoder_output, rows, cols, targets, hparams): + """Create output from decoder output and vars.""" + decoded_image = postprocess_image(decoder_output, rows, cols, hparams) + targets_shape = common_layers.shape_list(targets) + if hparams.mode == tf.estimator.ModeKeys.PREDICT: + # Hardcoding that the number of intensity values is 256. + y = tf.reshape(decoded_image, [targets_shape[0], -1, 1, 1, 256]) + output = y[:, :targets_shape[1], :, :, :] + else: + output = tf.reshape(decoded_image, [ + targets_shape[0], targets_shape[1], targets_shape[2], + targets_shape[3], 256 + ]) + return output + + +def get_channel_embeddings(io_depth, targets, hidden_size, name="channel"): + """Get separate embedding for each of the channels.""" + targets_split = tf.split(targets, io_depth, axis=3) + rgb_embedding_var = tf.get_variable("rgb_target_emb_%s" % name, + [256 * io_depth, hidden_size]) + rgb_embedding_var = tf.identity(rgb_embedding_var) + rgb_embedding_var *= float(hidden_size)**0.5 + channel_target_embs = [] + for i in xrange(io_depth): + # Adding the channel offsets to get the right embedding since the + # embedding tensor has shape 256 * io_depth, hidden_size + target_ids = tf.squeeze(targets_split[i], axis=3) + i * 256 + target_embs = common_layers.gather(rgb_embedding_var, target_ids) + channel_target_embs.append(target_embs) + + return tf.concat(channel_target_embs, axis=-1) + + +def add_pos_signals(x, hparams, name="pos_emb"): + with tf.variable_scope(name, reuse=False): + if hparams.pos == "timing": + x = common_attention.add_timing_signal_nd(x) + else: + assert hparams.pos == "emb" + x = common_attention.add_positional_embedding_nd( + x, hparams.max_length, name=name) + return x diff --git a/tensor2tensor/layers/common_layers.py b/tensor2tensor/layers/common_layers.py index c8d54fb99..7b22dc44b 100644 --- a/tensor2tensor/layers/common_layers.py +++ b/tensor2tensor/layers/common_layers.py @@ -509,8 +509,8 @@ def layer_norm_vars(filters): def layer_norm_compute_python(x, epsilon, scale, bias): """Layer norm raw computation.""" - mean = tf.reduce_mean(x, axis=[-1], keepdims=True) - variance = tf.reduce_mean(tf.square(x - mean), axis=[-1], keepdims=True) + mean = tf.reduce_mean(x, axis=[-1], keep_dims=True) + variance = tf.reduce_mean(tf.square(x - mean), axis=[-1], keep_dims=True) norm_x = (x - mean) * tf.rsqrt(variance + epsilon) return norm_x * scale + bias @@ -1171,7 +1171,7 @@ def mask_from_embedding(emb): Returns: a 0.0/1.0 Tensor with shape [batch, width, height, 1]. """ - return weights_nonzero(tf.reduce_sum(tf.abs(emb), axis=3, keepdims=True)) + return weights_nonzero(tf.reduce_sum(tf.abs(emb), axis=3, keep_dims=True)) def mask_leq(target_length, source_length): @@ -1703,7 +1703,7 @@ def smoothing_cross_entropy(logits, depth=vocab_size, on_value=confidence, off_value=low_confidence) - xentropy = tf.nn.softmax_cross_entropy_with_logits_v2( + xentropy = tf.nn.softmax_cross_entropy_with_logits( logits=logits, labels=soft_targets) return xentropy - normalizing @@ -1737,7 +1737,7 @@ def global_pool_1d(inputs, pooling_type="MAX", mask=None): if mask is not None: # Some elems are dummy elems so we can't just reduce the average. 
output = tf.reduce_sum(inputs, axis=1) - num_elems = tf.reduce_sum(mask, axis=1, keepdims=True) + num_elems = tf.reduce_sum(mask, axis=1, keep_dims=True) output = tf.div(output, tf.maximum(num_elems, 1)) else: output = tf.reduce_mean(inputs, axis=1) diff --git a/tensor2tensor/layers/modalities.py b/tensor2tensor/layers/modalities.py index 26063388b..f5788701c 100644 --- a/tensor2tensor/layers/modalities.py +++ b/tensor2tensor/layers/modalities.py @@ -298,6 +298,46 @@ def top(self, body_output, _): return x +@registry.register_image_modality("channel_embeddings_bottom") +class ImageChannelEmbeddingsBottom(modality.Modality): + """Modality for images using channel compression for generation.""" + + def get_channel_embeddings(self, io_depth, targets, hidden_size, + name="channel"): + """Get separate embedding for each of the channels.""" + targets_split = tf.split(targets, io_depth, axis=3) + rgb_embedding_var = tf.get_variable("rgb_target_emb_%s" % name, + [256 * io_depth, hidden_size]) + rgb_embedding_var = tf.identity(rgb_embedding_var) + rgb_embedding_var *= float(hidden_size)**0.5 + channel_target_embs = [] + for i in xrange(io_depth): + # Adding the channel offsets to get the right embedding since the + # embedding tensor has shape 256 * io_depth, hidden_size + target_ids = tf.squeeze(targets_split[i], axis=3) + i * 256 + target_embs = common_layers.gather(rgb_embedding_var, target_ids) + channel_target_embs.append(target_embs) + + return tf.concat(channel_target_embs, axis=-1) + + def targets_bottom(self, inputs): + io_depth = self._model_hparams.num_channels + hidden_size = self._model_hparams.hidden_size + return self.get_channel_embeddings(io_depth, inputs, hidden_size, + "input_bottom") + + def top(self, body_output, _): + with tf.variable_scope(self.name): + img_len = self._model_hparams.img_len + channels = self._model_hparams.num_channels + x = tf.layers.dense(body_output, 256, + use_bias=True, activation=None, + name="output_conv") + x = tf.reshape(x, + [-1, img_len, img_len, channels, self.top_dimensionality]) + return x + + @registry.register_audio_modality("default") class AudioModality(modality.Modality): """Performs strided conv compressions for audio data.""" @@ -421,7 +461,7 @@ def top(self, body_output, _): """ with tf.variable_scope(self.name): x = body_output - x = tf.reduce_mean(x, axis=[1, 2], keepdims=True) + x = tf.reduce_mean(x, axis=[1, 2], keep_dims=True) res = tf.layers.dense(x, self._vocab_size) return tf.expand_dims(res, 3) diff --git a/tensor2tensor/models/revnet.py b/tensor2tensor/models/revnet.py index 28b4cf681..3a6c7b32b 100644 --- a/tensor2tensor/models/revnet.py +++ b/tensor2tensor/models/revnet.py @@ -277,7 +277,7 @@ def final_block(x1, x2, dim='2d', training=True, scope='final_block'): # Global average pooling net = tf.reduce_mean(y, CONFIG[dim]['reduction_dimensions'], - name='final_pool', keepdims=True) + name='final_pool', keep_dims=True) return net diff --git a/tensor2tensor/models/shake_shake.py b/tensor2tensor/models/shake_shake.py index 5e1680edb..31a576338 100644 --- a/tensor2tensor/models/shake_shake.py +++ b/tensor2tensor/models/shake_shake.py @@ -191,3 +191,11 @@ def shakeshake_big(): hparams.layer_prepostprocess_dropout = 0.0 hparams.hidden_size = 96 return hparams + + +@registry.register_hparams +def shakeshake_tpu(): + hparams = shakeshake_big() + hparams.learning_rate_cosine_cycle_steps = 180000 + hparams.learning_rate = 0.6 + return hparams diff --git a/tensor2tensor/models/slicenet.py b/tensor2tensor/models/slicenet.py index 
e77412513..43cfa571e 100644 --- a/tensor2tensor/models/slicenet.py +++ b/tensor2tensor/models/slicenet.py @@ -381,10 +381,6 @@ def slicenet_params1_tiny(): def slicenet_range1(ranged_hparams): """Small range of hyperparameters.""" rhp = ranged_hparams - - hparams = slicenet_params1() - common_hparams.fill_ranged_hparams_from_hparams(hparams, rhp) - rhp.set_float("clip_grad_norm", 1.0, 10.0, scale=rhp.LOG_SCALE) rhp.set_float("learning_rate", 0.02, 1.0, scale=rhp.LOG_SCALE) rhp.set_float("optimizer_adam_beta2", 0.995, 0.998) diff --git a/tensor2tensor/models/transformer.py b/tensor2tensor/models/transformer.py index b241cc24a..061b68ab7 100644 --- a/tensor2tensor/models/transformer.py +++ b/tensor2tensor/models/transformer.py @@ -146,12 +146,13 @@ def body(self, features): """ hparams = self._hparams - inputs = features.get("inputs") - encoder_output, encoder_decoder_attention_bias = (None, None) - if inputs is not None: + if self.has_input: + inputs = features["inputs"] target_space = features["target_space_id"] encoder_output, encoder_decoder_attention_bias = self.encode( inputs, target_space, hparams, features=features) + else: + encoder_output, encoder_decoder_attention_bias = (None, None) targets = features["targets"] targets = common_layers.flatten4d3d(targets) @@ -159,10 +160,21 @@ def body(self, features): decoder_input, decoder_self_attention_bias = transformer_prepare_decoder( targets, hparams, features=features) - return self.decode(decoder_input, encoder_output, - encoder_decoder_attention_bias, - decoder_self_attention_bias, hparams, - nonpadding=features_to_nonpadding(features, "targets")) + decoder_output = self.decode( + decoder_input, + encoder_output, + encoder_decoder_attention_bias, + decoder_self_attention_bias, + hparams, + nonpadding=features_to_nonpadding(features, "targets")) + + expected_attention_weights = features.get("expected_attention_weights") + if expected_attention_weights is not None: + attention_loss = common_attention.encoder_decoder_attention_loss( + expected_attention_weights, self.attention_weights) + return decoder_output, {"attention_loss": attention_loss} + + return decoder_output def _greedy_infer(self, features, decode_length): """Fast version of greedy decoding. @@ -245,31 +257,44 @@ def _fast_decode(self, raise NotImplementedError("Fast decoding only supports a single shard.") dp = self._data_parallelism hparams = self._hparams - - inputs = features["inputs"] target_modality = self._problem_hparams.target_modality - if target_modality.is_class_modality: - decode_length = 1 + + if self.has_input: + inputs = features["inputs"] + if target_modality.is_class_modality: + decode_length = 1 + else: + decode_length = common_layers.shape_list(inputs)[1] + decode_length + + # TODO(llion): Clean up this reshaping logic. 
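The Transformer body above now optionally returns an extra loss dictionary when the features carry expected_attention_weights. The exact form of common_attention.encoder_decoder_attention_loss is not shown in this diff; purely as an illustration, a supervision term over attention maps could be a mean-squared error between the expected and realized weights:

```
import numpy as np

def toy_attention_loss(expected, actual):
  """Illustrative only: penalize deviation of realized attention from a target.

  expected, actual: [batch, target_length, input_length] arrays of weights.
  The real encoder_decoder_attention_loss may use a different formulation.
  """
  return np.mean((expected - actual) ** 2)

batch, tgt_len, src_len = 2, 5, 7
expected = np.random.dirichlet(np.ones(src_len), size=(batch, tgt_len))
actual = np.random.dirichlet(np.ones(src_len), size=(batch, tgt_len))
print(toy_attention_loss(expected, actual))  # a scalar extra loss
```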
+ inputs = tf.expand_dims(inputs, axis=1) + if len(inputs.shape) < 5: + inputs = tf.expand_dims(inputs, axis=4) + s = common_layers.shape_list(inputs) + batch_size = s[0] + inputs = tf.reshape(inputs, [s[0] * s[1], s[2], s[3], s[4]]) + # _shard_features called to ensure that the variable names match + inputs = self._shard_features({"inputs": inputs})["inputs"] + input_modality = self._problem_hparams.input_modality["inputs"] + with tf.variable_scope(input_modality.name): + inputs = input_modality.bottom_sharded(inputs, dp) + with tf.variable_scope("body"): + encoder_output, encoder_decoder_attention_bias = dp( + self.encode, inputs, features["target_space_id"], hparams, + features=features) + encoder_output = encoder_output[0] + encoder_decoder_attention_bias = encoder_decoder_attention_bias[0] + partial_targets = None else: - decode_length = common_layers.shape_list(inputs)[1] + decode_length - - # TODO(llion): Clean up this reshaping logic. - inputs = tf.expand_dims(inputs, axis=1) - if len(inputs.shape) < 5: - inputs = tf.expand_dims(inputs, axis=4) - s = common_layers.shape_list(inputs) - inputs = tf.reshape(inputs, [s[0] * s[1], s[2], s[3], s[4]]) - # _shard_features called to ensure that the variable names match - inputs = self._shard_features({"inputs": inputs})["inputs"] - input_modality = self._problem_hparams.input_modality["inputs"] - with tf.variable_scope(input_modality.name): - inputs = input_modality.bottom_sharded(inputs, dp) - with tf.variable_scope("body"): - encoder_output, encoder_decoder_attention_bias = dp( - self.encode, inputs, features["target_space_id"], hparams, - features=features) - encoder_output = encoder_output[0] - encoder_decoder_attention_bias = encoder_decoder_attention_bias[0] + # The problem has no inputs. + # In this case, features["inputs"] contains partial targets. + # We force the outputs to begin with these sequences. + encoder_output = None + encoder_decoder_attention_bias = None + partial_targets = tf.squeeze(tf.to_int64(features["inputs"]), [2, 3]) + partial_targets_length = common_layers.shape_list(partial_targets)[1] + decode_length += partial_targets_length + batch_size = tf.shape(partial_targets)[0] if hparams.pos == "timing": timing_signal = common_attention.get_timing_signal_1d( @@ -320,16 +345,30 @@ def symbols_to_logits_fn(ids, i, cache): with tf.variable_scope("body"): body_outputs = dp( - self.decode, targets, cache["encoder_output"], - cache["encoder_decoder_attention_bias"], bias, hparams, cache, + self.decode, targets, cache.get("encoder_output"), + cache.get("encoder_decoder_attention_bias"), + bias, hparams, cache, nonpadding=features_to_nonpadding(features, "targets")) with tf.variable_scope(target_modality.name): logits = target_modality.top_sharded(body_outputs, None, dp)[0] - return tf.squeeze(logits, axis=[1, 2, 3]), cache - - return fast_decode( + ret = tf.squeeze(logits, axis=[1, 2, 3]) + if partial_targets is not None: + # If the position is within the given partial targets, we alter the + # logits to always return those values. + # A faster approach would be to process the partial targets in one + # iteration in order to fill the corresponding parts of the cache. + # This would require broader changes, though. 
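For problems without inputs, the partial targets are forced into the decoded output by overriding the logits at those positions, as in the hunk that follows. A toy sketch of the idea with hypothetical sizes:

```
import numpy as np

vocab_size = 10
partial_target_id = 7      # token we want to force at the current position
model_logits = np.random.randn(1, vocab_size)

# A row with 0.0 at the forced id and -1e9 everywhere else: after softmax,
# essentially all probability mass lands on the forced token.
forced = np.full((1, vocab_size), -1e9)
forced[0, partial_target_id] = 0.0

within_partial_targets = True   # i.e. i < partial_targets_length
logits = forced if within_partial_targets else model_logits
print(np.argmax(logits, axis=-1))        # [7]
```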
+ vocab_size = tf.shape(ret)[1] + def forced_logits(): + return tf.one_hot(tf.tile(partial_targets[:, i], [beam_size]), + vocab_size, 0.0, -1e9) + ret = tf.cond( + tf.less(i, partial_targets_length), forced_logits, lambda: ret) + return ret, cache + + ret = fast_decode( encoder_output=encoder_output, encoder_decoder_attention_bias=encoder_decoder_attention_bias, symbols_to_logits_fn=symbols_to_logits_fn, @@ -338,7 +377,11 @@ def symbols_to_logits_fn(ids, i, cache): vocab_size=target_modality.top_dimensionality, beam_size=beam_size, top_beams=top_beams, - alpha=alpha) + alpha=alpha, + batch_size=batch_size) + if partial_targets is not None: + ret["outputs"] = ret["outputs"][:, partial_targets_length:] + return ret def fast_decode(encoder_output, @@ -350,7 +393,8 @@ def fast_decode(encoder_output, beam_size=1, top_beams=1, alpha=1.0, - eos_id=beam_search.EOS_ID): + eos_id=beam_search.EOS_ID, + batch_size=None): """Given encoder output and a symbols to logits function, does fast decoding. Implements both greedy and beam search decoding, uses beam search iff @@ -370,6 +414,7 @@ def fast_decode(encoder_output, alpha: Float that controls the length penalty. larger the alpha, stronger the preference for slonger translations. eos_id: End-of-sequence symbol in beam search. + batch_size: an integer scalar - must be passed if there is no input Returns: A dict of decoding results { @@ -379,8 +424,12 @@ def fast_decode(encoder_output, "scores": decoding log probs from the beam search, None if using greedy decoding (beam_size=1) } + + Raises: + NotImplementedError: If beam size > 1 with partial targets. """ - batch_size = common_layers.shape_list(encoder_output)[0] + if encoder_output is not None: + batch_size = common_layers.shape_list(encoder_output)[0] key_channels = hparams.attention_key_channels or hparams.hidden_size value_channels = hparams.attention_value_channels or hparams.hidden_size @@ -394,8 +443,9 @@ def fast_decode(encoder_output, for layer in range(num_layers) } - cache["encoder_output"] = encoder_output - cache["encoder_decoder_attention_bias"] = encoder_decoder_attention_bias + if encoder_output is not None: + cache["encoder_output"] = encoder_output + cache["encoder_decoder_attention_bias"] = encoder_decoder_attention_bias if beam_size > 1: # Beam Search initial_ids = tf.zeros([batch_size], dtype=tf.int32) @@ -417,6 +467,7 @@ def fast_decode(encoder_output, else: # Greedy def inner_loop(i, finished, next_id, decoded_ids, cache): + """One step of greedy decoding.""" logits, cache = symbols_to_logits_fn(next_id, i, cache) temperature = (0.0 if hparams.sampling_method == "argmax" else hparams.sampling_temp) @@ -430,7 +481,7 @@ def is_not_finished(i, finished, *_): return (i < decode_length) & tf.logical_not(tf.reduce_all(finished)) decoded_ids = tf.zeros([batch_size, 0], dtype=tf.int64) - finished = tf.constant(False, shape=[batch_size]) + finished = tf.fill([batch_size], False) next_id = tf.zeros([batch_size, 1], dtype=tf.int64) _, _, _, decoded_ids, _ = tf.while_loop( is_not_finished, @@ -826,7 +877,7 @@ def transformer_base_v1(): hparams.max_length = 256 hparams.clip_grad_norm = 0. # i.e. 
no gradient clipping hparams.optimizer_adam_epsilon = 1e-9 - hparams.learning_rate_decay_scheme = "noam" + hparams.learning_rate_schedule = "linear_warmup_rsqrt_decay" hparams.learning_rate = 0.1 hparams.learning_rate_warmup_steps = 4000 hparams.initializer_gain = 1.0 @@ -1177,11 +1228,9 @@ def transformer_prepend(): return transformer_prepend_v2() -@registry.register_ranged_hparams("transformer_base") +@registry.register_ranged_hparams def transformer_base_range(rhp): """Small range of hyperparameters.""" - hparams = transformer_base() - common_hparams.fill_ranged_hparams_from_hparams(hparams, rhp) # After starting from base, set intervals for some parameters. rhp.set_float("learning_rate", 0.3, 3.0, scale=rhp.LOG_SCALE) rhp.set_discrete("learning_rate_warmup_steps", @@ -1286,8 +1335,6 @@ def transformer_tiny_tpu(): @registry.register_ranged_hparams def transformer_tiny_tpu_range(rhp): """Small range of hyperparameters.""" - hparams = transformer_tiny_tpu() - common_hparams.fill_ranged_hparams_from_hparams(hparams, rhp) rhp.set_float("learning_rate", 0.3, 3.0, scale=rhp.LOG_SCALE) rhp.set_float("weight_decay", 0.0, 2.0) @@ -1295,8 +1342,6 @@ def transformer_tiny_tpu_range(rhp): @registry.register_ranged_hparams def transformer_tpu_range(rhp): """Small range of hyperparameters.""" - hparams = transformer_tpu() - common_hparams.fill_ranged_hparams_from_hparams(hparams, rhp) # After starting from base, set intervals for some parameters. rhp.set_float("learning_rate", 0.3, 3.0, scale=rhp.LOG_SCALE) rhp.set_discrete("learning_rate_warmup_steps", @@ -1307,13 +1352,6 @@ def transformer_tpu_range(rhp): rhp.set_float("weight_decay", 0.0, 2.0) -@registry.register_ranged_hparams -def transformer_tpu_batch_range(rhp): - hparams = transformer_tpu() - common_hparams.fill_ranged_hparams_from_hparams(hparams, rhp) - rhp.set_discrete("batch_size", [256, 512, 768, 1024]) - - @registry.register_hparams def transformer_small_tpu(): """TPU-friendly version of transformer_small. 
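transformer_base_v1 now names its schedule linear_warmup_rsqrt_decay instead of the old "noam" decay scheme; the formula itself (see the optimize.py hunk later in this diff) is unchanged. A small pure-Python sketch of its shape, with hypothetical hidden_size and warmup settings:

```
def linear_warmup_rsqrt_decay(step, hidden_size=512, warmup_steps=4000):
  # Same formula as in optimize.py: grows linearly during warmup, then
  # decays proportionally to 1/sqrt(step).
  return 5000.0 * hidden_size**-0.5 * min(
      (step + 1) * warmup_steps**-1.5, (step + 1)**-0.5)

for step in [0, 1000, 4000, 16000, 100000]:
  print(step, linear_warmup_rsqrt_decay(step))
# The value peaks around step == warmup_steps and decreases afterwards.
```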
diff --git a/tensor2tensor/models/transformer_moe.py b/tensor2tensor/models/transformer_moe.py index efa67bf27..2bf807b19 100644 --- a/tensor2tensor/models/transformer_moe.py +++ b/tensor2tensor/models/transformer_moe.py @@ -329,7 +329,7 @@ def transformer_moe_8k_lm(): # * Memory efficient multihead attention (slow): # hparams.layer_types = "#mem/mem/mem-moe/mem/mem" # * Alternate between local/compressed attention layers (faster): - # hparams.layer_types = "#locm/red/locm-moe/red/locm" + # hparams.layer_types = "#locm/redm/locm-moe/redm/locm" return hparams @@ -386,6 +386,6 @@ def transformer_moe_prepend_8k(): hparams.eval_drop_long_sequences = False hparams.max_input_seq_length = 7500, hparams.default_ff = "sepm" - hparams.layer_types = "locm/red/locm-moe/red/locm" + hparams.layer_types = "locm/redm/locm-moe/redm/locm" hparams.moe_num_experts = 256 return hparams diff --git a/tensor2tensor/models/transformer_sketch.py b/tensor2tensor/models/transformer_sketch.py index 913243f00..e8fe796a8 100644 --- a/tensor2tensor/models/transformer_sketch.py +++ b/tensor2tensor/models/transformer_sketch.py @@ -22,7 +22,6 @@ # Dependency imports -from tensor2tensor.layers import common_hparams from tensor2tensor.layers import common_layers from tensor2tensor.models import transformer from tensor2tensor.models import transformer_vae @@ -125,10 +124,6 @@ def transformer_sketch_6layer(): @registry.register_ranged_hparams("transformer_sketch_ranged") def transformer_sketch_ranged(rhp): """Range of hparams for vizier.""" - - hparams = transformer_sketch() - common_hparams.fill_ranged_hparams_from_hparams(hparams, rhp) - rhp.set_categorical("ffn_layer", ["conv_hidden_relu_with_sepconv", "conv_hidden_relu"]) rhp.set_discrete("batch_size", [1024, 2048, 4096]) diff --git a/tensor2tensor/models/transformer_test.py b/tensor2tensor/models/transformer_test.py index f67476006..1a6134b51 100644 --- a/tensor2tensor/models/transformer_test.py +++ b/tensor2tensor/models/transformer_test.py @@ -37,13 +37,15 @@ class TransformerTest(tf.test.TestCase): - def getModel(self, hparams, mode=tf.estimator.ModeKeys.TRAIN): + def getModel(self, hparams, mode=tf.estimator.ModeKeys.TRAIN, has_input=True): hparams.hidden_size = 8 hparams.filter_size = 32 hparams.num_heads = 1 hparams.layer_prepostprocess_dropout = 0.0 p_hparams = problem_hparams.test_problem_hparams(VOCAB_SIZE, VOCAB_SIZE) + if not has_input: + p_hparams.input_modality = {} hparams.problems = [p_hparams] inputs = -1 + np.random.random_integers( @@ -108,6 +110,41 @@ def testGreedyVsFast(self): self.assertEqual(fast_res.shape, (BATCH_SIZE, INPUT_LENGTH + decode_length)) self.assertAllClose(greedy_res, fast_res) + def testSlowVsFastNoInput(self): + model, features = self.getModel( + transformer.transformer_small(), has_input=False) + + decode_length = 2 + + out_logits, _ = model(features) + out_logits = tf.squeeze(out_logits, axis=[2, 3]) + loss = tf.nn.sparse_softmax_cross_entropy_with_logits( + logits=tf.reshape(out_logits, [-1, VOCAB_SIZE]), + labels=tf.reshape(features["targets"], [-1])) + loss = tf.reduce_mean(loss) + apply_grad = tf.train.AdamOptimizer(0.001).minimize(loss) + + with self.test_session(): + tf.global_variables_initializer().run() + for _ in range(100): + apply_grad.run() + + model.set_mode(tf.estimator.ModeKeys.PREDICT) + + with tf.variable_scope(tf.get_variable_scope(), reuse=True): + slow_result = model._slow_greedy_infer( + features, decode_length)["outputs"] + slow_result = tf.squeeze(slow_result, axis=[2, 3]) + + fast_result = 
model._greedy_infer(features, decode_length)["outputs"] + + with self.test_session(): + slow_res = slow_result.eval() + fast_res = fast_result.eval() + + self.assertEqual(fast_res.shape, (BATCH_SIZE, decode_length)) + self.assertAllClose(slow_res, fast_res) + def testBeamVsFast(self): model, features = self.getModel(transformer.transformer_small()) @@ -170,6 +207,18 @@ def testTransformerWithoutProblem(self): body_out.get_shape().as_list(), [BATCH_SIZE, TARGET_LENGTH, 1, hparams.hidden_size]) + def testTransformerWithEncoderDecoderAttentionLoss(self): + model, features = self.getModel(transformer.transformer_small()) + expected_attention_weights = np.random.random_sample( + size=(BATCH_SIZE, TARGET_LENGTH, INPUT_LENGTH)) + features["expected_attention_weights"] = tf.constant( + expected_attention_weights, dtype=tf.float32) + _, extra_loss = model(features) + with self.test_session() as session: + session.run(tf.global_variables_initializer()) + res = session.run(extra_loss["attention_loss"]) + self.assertEqual(res.shape, ()) + if __name__ == "__main__": tf.test.main() diff --git a/tensor2tensor/models/transformer_vae.py b/tensor2tensor/models/transformer_vae.py index ac9a66b77..59f2d08ba 100644 --- a/tensor2tensor/models/transformer_vae.py +++ b/tensor2tensor/models/transformer_vae.py @@ -24,14 +24,16 @@ # Dependency imports from six.moves import xrange # pylint: disable=redefined-builtin - from tensor2tensor.layers import common_attention +from tensor2tensor.layers import common_image_attention as cia from tensor2tensor.layers import common_layers from tensor2tensor.models import transformer +from tensor2tensor.utils import beam_search from tensor2tensor.utils import expert_utils from tensor2tensor.utils import registry from tensor2tensor.utils import t2t_model + import tensorflow as tf from tensorflow.python.training import moving_averages @@ -90,9 +92,9 @@ def top_k_softmax(x, k): """Calculate softmax(x), select top-k and rescale to sum to 1.""" x = tf.nn.softmax(x) top_x, _ = tf.nn.top_k(x, k=k+1) - min_top = tf.reduce_min(top_x, axis=-1, keepdims=True) + min_top = tf.reduce_min(top_x, axis=-1, keep_dims=True) x = tf.nn.relu((x - min_top) + 1e-12) - x /= tf.reduce_sum(x, axis=-1, keepdims=True) + x /= tf.reduce_sum(x, axis=-1, keep_dims=True) return x, tf.reduce_max(top_x, axis=-1) @@ -140,7 +142,7 @@ def dae(x, hparams, name): maxvhot = tf.stop_gradient(tf.one_hot(maxvec, hparams.v_size)) # Add losses that prevent too few being used. 
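top_k_softmax above keeps only the k largest probabilities and rescales them to sum to 1. An equivalent NumPy sketch on a toy vector (the +1e-12 matches the epsilon used in the TF version):

```
import numpy as np

def top_k_softmax(x, k):
  x = np.exp(x) / np.sum(np.exp(x), axis=-1, keepdims=True)   # softmax
  top_x = np.sort(x, axis=-1)[..., ::-1][..., :k + 1]         # top k+1 values
  min_top = top_x[..., -1:]                                    # (k+1)-th largest
  x = np.maximum((x - min_top) + 1e-12, 0.0)                   # zero out the rest
  x = x / np.sum(x, axis=-1, keepdims=True)                    # rescale to sum to 1
  return x, np.max(top_x, axis=-1)

probs, max_prob = top_k_softmax(np.array([2.0, 1.0, 0.1, -1.0]), k=2)
print(probs)   # only the two largest entries are non-zero (up to the epsilon)
```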
distrib = tf.reshape(logsm, [-1, hparams.v_size]) * maxvhot - d_mean = tf.reduce_mean(distrib, axis=[0], keepdims=True) + d_mean = tf.reduce_mean(distrib, axis=[0], keep_dims=True) d_variance = tf.reduce_mean(tf.square(distrib - d_mean), axis=[0]) d_dev = - tf.reduce_mean(d_variance) ret = s @@ -200,8 +202,8 @@ def slice_hidden(x, hparams): def nearest(x, means, hparams): """Find the nearest means to elements in x.""" x_reshaped = hparams.reshape_fn(x, hparams) - x_norm_sq = tf.reduce_sum(tf.square(x_reshaped), axis=-1, keepdims=True) - means_norm_sq = tf.reduce_sum(tf.square(means), axis=-1, keepdims=True) + x_norm_sq = tf.reduce_sum(tf.square(x_reshaped), axis=-1, keep_dims=True) + means_norm_sq = tf.reduce_sum(tf.square(means), axis=-1, keep_dims=True) scalar_prod = tf.matmul( tf.transpose(x_reshaped, perm=[1, 0, 2]), tf.transpose(means, perm=[0, 2, 1])) @@ -287,22 +289,23 @@ def embed(x): hot = tf.one_hot(x, hparams.v_size) h1 = tf.layers.dense(hot, hparams.hidden_size, name="dae_dense") elif hparams.bottleneck_kind == "vq-vae": - means_embed = means shape_x = common_layers.shape_list(x) x_flat = tf.reshape(x, [-1, 1]) - c = int_to_bit(x_flat, nbits=int(math.log(hparams.v_size, 2)), base=2) + c = int_to_bit(x_flat, nbits=hparams.z_size, base=2) shape = common_layers.shape_list(c) new_shape = shape new_shape[-1] = hparams.num_blocks - new_shape.append(int(math.log(hparams.v_size, 2) // hparams.num_blocks)) + new_shape.append(int(hparams.z_size / hparams.num_blocks)) c = tf.to_int32(tf.reshape(c, shape=new_shape)) c = bit_to_int( c, - nbits=int(math.log(hparams.v_size, 2) // hparams.num_blocks), + nbits=int(hparams.z_size / hparams.num_blocks), base=2) - h1 = tf.gather(tf.transpose(means_embed, [1, 0, 2]), c) - h1 = tf.stack( - [h1[:, :, i, i, :] for i in range(hparams.num_blocks)], axis=-2) + c_hot = tf.one_hot(c, depth=hparams.block_v_size, axis=-1) + c_hot_flat = tf.reshape( + c_hot, shape=[-1, hparams.num_blocks, hparams.block_v_size]) + h1 = tf.matmul(tf.transpose(c_hot_flat, perm=[1, 0, 2]), means) + h1 = tf.transpose(h1, perm=[1, 0, 2]) new_shape = shape_x new_shape.append(hparams.hidden_size) h1 = tf.reshape(h1, new_shape) @@ -354,18 +357,19 @@ def embed(x): # Get the discrete latent represenation x_means_idx = tf.argmax(x_means_hot, axis=-1) + # Get the binary representation x_means_bits = int_to_bit( x_means_idx, - nbits=int(math.log(hparams.v_size, 2) // hparams.num_blocks), + nbits=int(hparams.z_size / hparams.num_blocks), base=2) shape = common_layers.shape_list(x_means_bits) new_shape = shape[:-1] - new_shape[-1] = int(math.log(hparams.v_size, 2)) + new_shape[-1] = hparams.z_size x_means_bits = tf.reshape(x_means_bits, shape=new_shape) c = bit_to_int( tf.to_int32(x_means_bits), - nbits=int(math.log(hparams.v_size, 2)), + nbits=hparams.z_size, base=2) # Update the ema variables @@ -389,7 +393,7 @@ def embed(x): tf.transpose(x_reshaped, perm=[1, 0, 2])) updated_ema_means = moving_averages.assign_moving_average( ema_means, dw, hparams.decay, zero_debias=False) - n = tf.reduce_sum(updated_ema_count, axis=-1, keepdims=True) + n = tf.reduce_sum(updated_ema_count, axis=-1, keep_dims=True) updated_ema_count = ((updated_ema_count + hparams.epsilon) / (n + hparams.v_size * hparams.epsilon) * n) updated_ema_means /= tf.expand_dims(updated_ema_count, axis=-1) @@ -459,26 +463,57 @@ def decode_transformer(encoder_output, encoder_decoder_attention_bias, targets, hparams, - name): + name, + task=None): """Original Transformer decoder.""" with tf.variable_scope(name): - targets = 
common_layers.flatten4d3d(targets) - - decoder_input, decoder_self_bias = transformer.transformer_prepare_decoder( - targets, hparams) - - decoder_input = tf.nn.dropout(decoder_input, - 1.0 - hparams.layer_prepostprocess_dropout) - - decoder_output = transformer.transformer_decoder( - decoder_input, - encoder_output, - decoder_self_bias, - encoder_decoder_attention_bias, - hparams) - + if task is None: + task = hparams.task + if task == "translate": + targets = common_layers.flatten4d3d(targets) + + decoder_input, decoder_self_bias = ( + transformer.transformer_prepare_decoder(targets, hparams)) + + decoder_input = tf.nn.dropout(decoder_input, + 1.0 - hparams.layer_prepostprocess_dropout) + + decoder_output = transformer.transformer_decoder( + decoder_input, + encoder_output, + decoder_self_bias, + encoder_decoder_attention_bias, + hparams) + decoder_output = tf.expand_dims(decoder_output, axis=2) + else: + assert task == "image" + inputs = None + # have to reshape targets as b, 32, 32, 3 * hidden size] beacuse otherwise + # prepare_image will choke + targets = tf.reshape(targets, [tf.shape(targets)[0], hparams.img_len, + hparams.img_len, + hparams.num_channels*hparams.hidden_size]) + + # Prepare decoder inputs and bias. + decoder_input, _, _, bias = cia.prepare_decoder(targets, hparams) + # Add class label to decoder input. + if not hparams.drop_inputs: + decoder_input += tf.reshape( + inputs, + [common_layers.shape_list(targets)[0], 1, 1, hparams.hidden_size]) + decoder_output = cia.transformer_decoder_layers( + decoder_input, + None, + bias, + hparams.num_decoder_layers or hparams.num_hidden_layers, + hparams, + attention_type=hparams.dec_attention_type, + name="decoder") + decoder_output_shape = common_layers.shape_list(decoder_output) + decoder_output = tf.reshape(decoder_output, [decoder_output_shape[0], -1, 1, + hparams.hidden_size]) # Expand since t2t expects 4d tensors. - return tf.expand_dims(decoder_output, axis=2) + return decoder_output def multinomial_sample(x, vocab_size, temperature): @@ -496,7 +531,7 @@ def ae_latent_softmax(latents_pred, latents_discrete, hparams): vocab_size = hparams.v_size if hparams.bottleneck_kind == "semhash": vocab_size = 2**hparams.z_size - if hparams.num_blocks < 2: + if hparams.num_decode_blocks < 2: latents_logits = tf.layers.dense(latents_pred, vocab_size, name="extra_logits") loss = None @@ -510,15 +545,17 @@ def ae_latent_softmax(latents_pred, latents_discrete, hparams): # Multi-block case. 
vocab_bits = int(math.log(vocab_size, 2)) assert vocab_size == 2**vocab_bits - assert vocab_bits % hparams.num_blocks == 0 - block_vocab_size = 2**(vocab_bits // hparams.num_blocks) - latents_logits = [tf.layers.dense(latents_pred, block_vocab_size, - name="extra_logits_%d" % i) - for i in xrange(hparams.num_blocks)] + assert vocab_bits % hparams.num_decode_blocks == 0 + block_vocab_size = 2**(vocab_bits // hparams.num_decode_blocks) + latents_logits = [ + tf.layers.dense( + latents_pred, block_vocab_size, name="extra_logits_%d" % i) + for i in xrange(hparams.num_decode_blocks) + ] loss = None if latents_discrete is not None: losses = [] - for i in xrange(hparams.num_blocks): + for i in xrange(hparams.num_decode_blocks): d = tf.floormod(tf.floordiv(latents_discrete, block_vocab_size**i), block_vocab_size) losses.append(tf.nn.sparse_softmax_cross_entropy_with_logits( @@ -530,8 +567,42 @@ def ae_latent_softmax(latents_pred, latents_discrete, hparams): return sample, loss +def ae_latent_sample_beam(latents_dense_in, inputs, ed, embed, hparams): + """Sample from the latent space in the autoencoder.""" + vocab_size = 2**hparams.z_size + beam_size = 1 # TODO(lukaszkaiser): larger beam sizes seem to work bad. + inputs = tf.tile(inputs, [beam_size, 1, 1]) + ed = tf.tile(ed, [beam_size, 1, 1, 1]) + + def symbols_to_logits_fn(ids): + """Go from ids to logits.""" + ids = tf.expand_dims(ids, axis=2) # Ids start with added all-zeros. + latents_discrete = tf.pad(ids[:, 1:], [[0, 0], [0, 1], [0, 0]]) + + with tf.variable_scope(tf.get_variable_scope(), reuse=False): + latents_dense = embed(latents_discrete) + latents_pred = decode_transformer( + inputs, ed, latents_dense, hparams, "extra") + logits = tf.layers.dense(latents_pred, vocab_size, name="extra_logits") + current_output_position = common_layers.shape_list(ids)[1] - 1 + logits = logits[:, current_output_position, :, :] + return tf.squeeze(logits, axis=[1]) + + initial_ids = tf.zeros([tf.shape(latents_dense_in)[0]], dtype=tf.int32) + length = tf.shape(latents_dense_in)[1] + ids, _ = beam_search.beam_search( + symbols_to_logits_fn, initial_ids, beam_size, length, + vocab_size, alpha=0.0, eos_id=-1, stop_early=False) + + res = tf.expand_dims(ids[:, 0, :], axis=2) # Pick first beam. + return res[:, 1:] # Remove the added all-zeros from ids. + + def ae_latent_sample(latents_dense, inputs, ed, embed, iters, hparams): """Sample from the latent space in the autoencoder.""" + if hparams.num_decode_blocks < 2: + # TODO(lukaszkaiser): beam-search only works in non-blocked mode for now. + return ae_latent_sample_beam(latents_dense, inputs, ed, embed, hparams) latents_pred = decode_transformer(inputs, ed, latents_dense, hparams, "extra") latents_discrete, _ = ae_latent_softmax(latents_pred, None, hparams) @@ -566,7 +637,10 @@ def ae_transformer_internal(inputs, _DO_SUMMARIES = False # Prepare. - batch_size = common_layers.shape_list(inputs)[0] + if inputs is not None: + batch_size = common_layers.shape_list(inputs)[0] + else: + batch_size = common_layers.shape_list(targets)[0] targets = tf.reshape(targets, [batch_size, -1, 1, hparams.hidden_size]) # Encoder. @@ -579,14 +653,22 @@ def ae_transformer_internal(inputs, # Autoencoding. 
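In the multi-block branch above, one discrete latent in [0, 2**z_size) is split into num_decode_blocks digits base block_vocab_size, each handled by its own small softmax. A tiny worked example of the floordiv/floormod decomposition, with hypothetical sizes:

```
z_size, num_decode_blocks = 6, 2          # so vocab_size = 64
vocab_size = 2 ** z_size
block_vocab_size = 2 ** (z_size // num_decode_blocks)   # 8

latent = 29                               # some id in [0, 64)
digits = [(latent // block_vocab_size ** i) % block_vocab_size
          for i in range(num_decode_blocks)]
print(digits)                             # [5, 3] since 29 == 5 + 3 * 8

# Each digit gets its own block-level logits and cross-entropy in
# ae_latent_softmax.
```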
losses = {"extra": tf.constant(0.0), "latent_pred": tf.constant(0.0)} if hparams.do_ae: - max_targets_len_from_inputs = tf.concat([inputs, inputs], axis=1) + # flatten here + original_targets_shape = tf.shape(targets) + if hparams.task == "image": + cia.maybe_reshape_4d_to_3d(targets, hparams) + if hparams.task == "translate": + max_targets_len_from_inputs = tf.concat([inputs, inputs], axis=1) + else: + assert hparams.task == "image" + max_targets_len_from_inputs = targets targets, _ = common_layers.pad_to_same_length( targets, max_targets_len_from_inputs, final_length_divisible_by=2**hparams.num_compress_steps) targets_c = compress(targets, inputs, False, hparams, "compress") if hparams.mode != tf.estimator.ModeKeys.PREDICT: # Compress and bottleneck. - latents_dense, latents_discrete, extra_loss, _ = bottleneck( + latents_dense, latents_discrete, extra_loss, embed = bottleneck( targets_c, hparams, 2 * 2048, "vc", means, ema_count, ema_means) if _DO_SUMMARIES: tf.summary.histogram("b0", tf.reshape(latents_discrete[:, 0, :], [-1])) @@ -599,8 +681,10 @@ def ae_transformer_internal(inputs, # Extra loss predicting latent code from input. Discrete only. if hparams.bottleneck_kind not in ["dense", "vae"]: latents_pred = decode_transformer( - tf.stop_gradient(inputs), tf.stop_gradient(ed), - tf.stop_gradient(latents_dense), hparams, "extra") + tf.stop_gradient(inputs) if inputs is not None else None, + tf.stop_gradient(ed) if inputs is not None else None, + tf.stop_gradient(latents_dense), hparams, "extra", + task="translate") _, latent_pred_loss = ae_latent_softmax( latents_pred, latents_discrete, hparams) losses["latent_pred"] = tf.reduce_mean( @@ -631,7 +715,8 @@ def bn_inputs(): ema_count, ema_means) latents_dense = tf.zeros_like(targets_c[:, :latent_len, :, :]) if cache is None: - cache = ae_latent_sample(latents_dense, inputs, ed, embed, 8, hparams) + cache = ae_latent_sample( + latents_dense, inputs, ed, embed, 16, hparams) latents_dense = embed(cache) # Postprocess. 
d = latents_dense @@ -659,12 +744,18 @@ def bn_inputs(): if hparams.do_attend_decompress: d = attend(d, inputs, hparams, "decompress_attend_%d" % j) d = decompress_step(d, hparams, i > 0, False, "decompress_%d" % j) + # targets is always [batch, length, 1, depth] targets = mask * targets + (1.0 - mask) * d - targets = tf.concat([tf.reverse(latents_dense, [1]), targets], axis=1) + # reshape back to 4d here + if hparams.task == "image": + targets = tf.reshape(targets, original_targets_shape) + if hparams.task == "translate": + targets = tf.concat([tf.reverse(latents_dense, [1]), targets], axis=1) res = decode_transformer(inputs, ed, targets, hparams, "decoder") if hparams.do_ae: - res = res[:, common_layers.shape_list(latents_dense)[1]:, :, :] + if hparams.task == "translate": + res = res[:, common_layers.shape_list(latents_dense)[1]:, :, :] if hparams.do_mask and hparams.do_refine: def refine_res(): # return residual_conv(res, 1, (5, 1), hparams, "refine") @@ -701,11 +792,11 @@ def __init__(self, *args, **kwargs): self._hparams.block_dim = int( self._hparams.hidden_size // self._hparams.num_blocks) self._hparams.block_v_size = 2**( - math.log(self._hparams.v_size, 2) / self._hparams.num_blocks) + self._hparams.z_size / self._hparams.num_blocks) self._hparams.block_v_size = int(self._hparams.block_v_size) if self._hparams.reshape_method == "project": - tf.logging.info("Using random projections for hierarchical vq-vae") + tf.logging.info("Using projections for decomposed vq-vae") tf.logging.info("Trainable projections = {}".format( self._hparams.trainable_projections)) self._hparams.projection_tensors = tf.get_variable( @@ -718,7 +809,7 @@ def __init__(self, *args, **kwargs): trainable=self._hparams.trainable_projections) self._hparams.reshape_fn = project_hidden elif self._hparams.reshape_method == "slice": - tf.logging.info("Using slices for hierarchical vq-vae") + tf.logging.info("Using slices for decomposed vq-vae") self._hparams.reshape_fn = slice_hidden else: raise ValueError("Unknown reshape method") @@ -763,7 +854,7 @@ def body(self, features): return res, loss def prepare_features_for_infer(self, features): - if not self._hparams.do_ae: + if self._hparams.do_mask or not self._hparams.do_ae: return features beam_batch_size = self._decode_hparams.beam_size beam_batch_size *= self._decode_hparams.batch_size @@ -830,17 +921,21 @@ def transformer_ae_small(): hparams.hidden_size = 384 hparams.filter_size = 2048 hparams.label_smoothing = 0.0 - hparams.optimizer = "Adafactor" - hparams.add_hparam("z_size", 16) - hparams.add_hparam("noise_dev", 0.0) + hparams.optimizer = "Adam" # Can be unstable, maybe try Adam. + hparams.optimizer_adam_epsilon = 1e-9 + hparams.optimizer_adam_beta1 = 0.9 + hparams.optimizer_adam_beta2 = 0.997 # Needs tuning, try 0.98 to 0.999. + hparams.add_hparam("z_size", 14) + hparams.add_hparam("noise_dev", 0.5) hparams.add_hparam("d_mix", 0.5) # Bottleneck kinds supported: dense, vae, semhash, gumbel-softmax, vq-vae. 
hparams.add_hparam("bottleneck_kind", "semhash") hparams.add_hparam("num_blocks", 1) - # Reshape method for hierarchical vq-vae: slice, project + hparams.add_hparam("num_decode_blocks", 1) + # Reshape method for decomposed vq-vae: slice, project hparams.add_hparam("reshape_method", "slice") hparams.add_hparam("trainable_projections", False) - hparams.add_hparam("unmasked_percentage", 0.3) + hparams.add_hparam("unmasked_percentage", 0.1) hparams.add_hparam("do_ae", True) hparams.add_hparam("do_mask", True) hparams.add_hparam("do_refine", False) @@ -869,25 +964,81 @@ def transformer_ae_small(): hparams.add_hparam("random_top_k", 1) hparams.kl_warmup_steps = 150000 hparams.force_full_predict = True + + # task params + hparams.add_hparam("task", "translate") # translate or image tasks supported return hparams @registry.register_hparams -def transformer_ae_cifar(): +def imagetransformer_ae_cifar(): """Hyperparameters for CIFAR-10 experiments.""" hparams = transformer_ae_small() - hparams.hidden_size = 256 hparams.filter_size = 512 - hparams.batch_size = 1024 * 4 - hparams.num_compress_steps = 2 + hparams.num_compress_steps = 3 hparams.v_size = 1024 * 64 - hparams.kl_warmup_steps = 150000 hparams.startup_steps = 10000 hparams.kmeans_lr_factor = 0.0 - hparams.is_2d = 1 + hparams.is_2d = 0 hparams.learning_rate_warmup_steps = 8000 hparams.learning_rate = 0.2 - hparams.ffn_layer = "conv_hidden_relu_with_sepconv" + hparams.hidden_size = 512 + hparams.batch_size = 1 + hparams.max_length = 256 + hparams.dropout = 0.0 + hparams.clip_grad_norm = 0. # i.e. no gradient clipping + hparams.optimizer_adam_epsilon = 1e-9 + hparams.learning_rate_decay_scheme = "noam" + hparams.learning_rate = 0.1 + hparams.initializer_gain = 0.2 + hparams.num_hidden_layers = 6 + hparams.initializer = "uniform_unit_scaling" + hparams.weight_decay = 0.0 + hparams.optimizer_adam_beta1 = 0.9 + hparams.optimizer_adam_beta2 = 0.98 + hparams.label_smoothing = 0.0 + hparams.norm_type = "layer" + hparams.layer_prepostprocess_dropout = 0.0 + hparams.num_heads = 8 + hparams.task = "image" + hparams.ffn_layer = "conv_hidden_relu" + # All hyperparameters ending in "dropout" are automatically set to 0.0 + # when not in training mode. + hparams.attention_dropout = 0.0 + hparams.relu_dropout = 0. 
+ hparams.pos = "timing" # timing, none + hparams.nbr_decoder_problems = 1 + hparams.num_output_layers = 3 + hparams.add_hparam("block_size", 1) + + # dilated attention based flags + hparams.add_hparam("gap_sizes", [2, 4, 8, 16, 32, 64, 2, 4, 8, 16, 32, 64]) + hparams.add_hparam("dilated_attention", False) + + # image size related flags + # assuming that the image has same height and width + hparams.add_hparam("img_len", 32) + hparams.add_hparam("num_channels", 3) + # Local attention params + hparams.add_hparam("local_and_global_att", False) + hparams.add_hparam("block_length", 256) + hparams.add_hparam("block_width", 128) + hparams.num_encoder_layers = 4 + hparams.num_decoder_layers = 12 + hparams.sep_rgb_embed = False + hparams.add_hparam("dec_attention_type", cia.AttentionType.LOCAL_1D) + hparams.add_hparam("block_rastor_scan", False) + + # multipos attention params + hparams.add_hparam("q_filter_width", 1) + hparams.add_hparam("kv_filter_width", 1) + + hparams.add_hparam("unconditional", False) # unconditional generation + + hparams.target_modality = "image:channel_embeddings_bottom" + hparams.drop_inputs = True + hparams.do_attend_compress = False + hparams.do_attend_decompress = False return hparams @@ -895,7 +1046,7 @@ def transformer_ae_cifar(): def transformer_ae_base(): """Set of hyperparameters.""" hparams = transformer_ae_small() - hparams.batch_size = 1024 + hparams.batch_size = 2048 hparams.hidden_size = 512 hparams.filter_size = 4096 hparams.num_hidden_layers = 6 diff --git a/tensor2tensor/notebooks/hello_t2t.ipynb b/tensor2tensor/notebooks/hello_t2t.ipynb index cc9f66a02..3b3573098 100644 --- a/tensor2tensor/notebooks/hello_t2t.ipynb +++ b/tensor2tensor/notebooks/hello_t2t.ipynb @@ -60,7 +60,7 @@ }, "source": [ "# Install deps\n", - "!pip install -q tensor2tensor" + "!pip install -q -U tensor2tensor tensorflow" ], "cell_type": "code", "execution_count": 0, @@ -87,7 +87,7 @@ "from tensor2tensor import models\n", "from tensor2tensor import problems\n", "from tensor2tensor.layers import common_layers\n", - "from tensor2tensor.tpu import tpu_trainer_lib\n", + "from tensor2tensor.utils import trainer_lib\n", "from tensor2tensor.utils import t2t_model\n", "from tensor2tensor.utils import registry\n", "from tensor2tensor.utils import metrics\n", @@ -597,7 +597,7 @@ "model_name = \"transformer\"\n", "hparams_set = \"transformer_base\"\n", "\n", - "hparams = tpu_trainer_lib.create_hparams(hparams_set, data_dir=data_dir, problem_name=\"translate_ende_wmt32k\")\n", + "hparams = trainer_lib.create_hparams(hparams_set, data_dir=data_dir, problem_name=\"translate_ende_wmt32k\")\n", "\n", "# NOTE: Only create the model once when restoring from a checkpoint; it's a\n", "# Layer and so subsequent instantiations will have different variable scopes\n", @@ -1407,7 +1407,7 @@ " return tf.layers.conv2d(tf.nn.relu(h2), filters,\n", " kernel_size=(3, 3))\n", "\n", - "hparams = tpu_trainer_lib.create_hparams(\"basic_1\", data_dir=data_dir, problem_name=\"image_mnist\")\n", + "hparams = trainer_lib.create_hparams(\"basic_1\", data_dir=data_dir, problem_name=\"image_mnist\")\n", "hparams.hidden_size = 64\n", "model = MySimpleModel(hparams, Modes.TRAIN)" ], @@ -1584,7 +1584,7 @@ " break\n", "\n", " # Make the inputs and targets 4D\n", - " example[\"inputs\"] = tf.reshape(example[\"inputs\"], [1, 28, 28, 3])\n", + " example[\"inputs\"] = tf.reshape(example[\"inputs\"], [1, 28, 28, 1])\n", " example[\"targets\"] = tf.reshape(example[\"targets\"], [1, 1, 1, 1])\n", "\n", " # Call the model\n", diff --git 
a/tensor2tensor/rl/README.md b/tensor2tensor/rl/README.md index bf21ab1ad..053119f05 100644 --- a/tensor2tensor/rl/README.md +++ b/tensor2tensor/rl/README.md @@ -1,10 +1,11 @@ -# Tensor2Tensor Reinforcement Learning starter. +# Tensor2Tensor experimental Reinforcement Learning. The rl package intention is to provide possiblity to run reinforcement -algorithms within Tensorflow's computation graph. +algorithms within Tensorflow's computation graph. It's very experimental +for now and under heavy development. Currently the only supported algorithm is Proximy Policy Optimization - PPO. ## Sample usage - training in Pendulum-v0 environment. -```t2t-rl-trainer``` +```python rl/t2t_rl_trainer.py``` diff --git a/tensor2tensor/rl/__init__.py b/tensor2tensor/rl/__init__.py index e69de29bb..3f714ce1f 100644 --- a/tensor2tensor/rl/__init__.py +++ b/tensor2tensor/rl/__init__.py @@ -0,0 +1,15 @@ +# coding=utf-8 +# Copyright 2017 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/tensor2tensor/rl/collect.py b/tensor2tensor/rl/collect.py index dadab4d92..2ea262143 100644 --- a/tensor2tensor/rl/collect.py +++ b/tensor2tensor/rl/collect.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2018 The Tensor2Tensor Authors. +# Copyright 2017 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,7 +19,7 @@ def define_collect(policy_factory, batch_env, config): - + """Collect trajectories.""" memory_shape = [config.epoch_length] + [batch_env.observ.shape.as_list()[0]] memories_shapes_and_types = [ # observation @@ -45,6 +45,7 @@ def define_collect(policy_factory, batch_env, config): with tf.control_dependencies([reset_once_op]): def step(index, scores_sum, scores_num): + """Single step.""" # Note - the only way to ensure making a copy of tensor is to run simple # operation. 
We are waiting for tf.copy: # https://github.com/tensorflow/tensorflow/issues/11186 @@ -63,17 +64,17 @@ def step(index, scores_sum, scores_num): save_ops = [tf.scatter_update(memory_slot, index, value) for memory_slot, value in zip(memory, to_save)] cumulate_rewards_op = cumulative_rewards.assign_add(reward) - agent_indicies_to_reset = tf.where(done)[:, 0] + agent_indices_to_reset = tf.where(done)[:, 0] with tf.control_dependencies([cumulate_rewards_op]): scores_sum_delta = tf.reduce_sum( - tf.gather(cumulative_rewards, agent_indicies_to_reset)) + tf.gather(cumulative_rewards, agent_indices_to_reset)) scores_num_delta = tf.count_nonzero(done, dtype=tf.int32) with tf.control_dependencies(save_ops + [scores_sum_delta, scores_num_delta]): - reset_env_op = batch_env.reset(agent_indicies_to_reset) + reset_env_op = batch_env.reset(agent_indices_to_reset) reset_cumulative_rewards_op = tf.scatter_update( - cumulative_rewards, agent_indicies_to_reset, - tf.zeros(tf.shape(agent_indicies_to_reset))) + cumulative_rewards, agent_indices_to_reset, + tf.zeros(tf.shape(agent_indices_to_reset))) with tf.control_dependencies([reset_env_op, reset_cumulative_rewards_op]): return [index + 1, scores_sum + scores_sum_delta, diff --git a/tensor2tensor/rl/envs/__init__.py b/tensor2tensor/rl/envs/__init__.py index e69de29bb..3f714ce1f 100644 --- a/tensor2tensor/rl/envs/__init__.py +++ b/tensor2tensor/rl/envs/__init__.py @@ -0,0 +1,15 @@ +# coding=utf-8 +# Copyright 2017 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/tensor2tensor/rl/envs/batch_env.py b/tensor2tensor/rl/envs/batch_env.py index 30bfdce55..453348976 100644 --- a/tensor2tensor/rl/envs/batch_env.py +++ b/tensor2tensor/rl/envs/batch_env.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2018 The Tensor2Tensor Authors. +# Copyright 2017 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,11 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""Combine multiple environments to step them in batch.""" + # The code was based on Danijar Hafner's code from tf.agents: # https://github.com/tensorflow/agents/blob/master/agents/tools/batch_env.py -"""Combine multiple environments to step them in batch.""" - from __future__ import absolute_import from __future__ import division from __future__ import print_function diff --git a/tensor2tensor/rl/envs/in_graph_batch_env.py b/tensor2tensor/rl/envs/in_graph_batch_env.py index d0e1e4c26..eae0826a3 100644 --- a/tensor2tensor/rl/envs/in_graph_batch_env.py +++ b/tensor2tensor/rl/envs/in_graph_batch_env.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2018 The Tensor2Tensor Authors. +# Copyright 2017 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -13,16 +13,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""Batch of environments inside the TensorFlow graph.""" + # The code was based on Danijar Hafner's code from tf.agents: # https://github.com/tensorflow/agents/blob/master/agents/tools/in_graph_batch_env.py -"""Batch of environments inside the TensorFlow graph.""" - from __future__ import absolute_import from __future__ import division from __future__ import print_function +# Dependency imports + import gym + import tensorflow as tf @@ -92,7 +95,6 @@ def simulate(self, action): with tf.control_dependencies([self._observ.assign(observ)]): return tf.identity(reward), tf.identity(done) - def reset(self, indices=None): """Reset the batch of environments. diff --git a/tensor2tensor/rl/envs/utils.py b/tensor2tensor/rl/envs/utils.py index 2b81af270..8171fbe17 100644 --- a/tensor2tensor/rl/envs/utils.py +++ b/tensor2tensor/rl/envs/utils.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2018 The Tensor2Tensor Authors. +# Copyright 2017 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,20 +13,23 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""Utilities for using batched environments.""" + # The code was based on Danijar Hafner's code from tf.agents: # https://github.com/tensorflow/agents/blob/master/agents/tools/wrappers.py # https://github.com/tensorflow/agents/blob/master/agents/scripts/utility.py -"""Utilities for using batched environments.""" - import atexit import multiprocessing import sys import traceback -import tensorflow as tf + +# Dependency imports from tensor2tensor.rl.envs import batch_env from tensor2tensor.rl.envs import in_graph_batch_env +import tensorflow as tf + class ExternalProcessEnv(object): """Step environment in a separate process for lock free paralellism.""" @@ -202,6 +205,7 @@ def _worker(self, constructor, conn): conn.send((self._EXCEPTION, stacktrace)) conn.close() + def define_batch_env(constructor, num_agents, env_processes=True): """Create environments and apply all desired wrappers. diff --git a/tensor2tensor/rl/networks.py b/tensor2tensor/rl/networks.py index af8709191..4ad7c5020 100644 --- a/tensor2tensor/rl/networks.py +++ b/tensor2tensor/rl/networks.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2018 The Tensor2Tensor Authors. +# Copyright 2017 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,11 +15,14 @@ """Neural networks for actor-critic algorithms.""" -import operator -import functools import collections -import tensorflow as tf +import functools +import operator + +# Dependency imports + import gym +import tensorflow as tf NetworkOutput = collections.namedtuple( @@ -28,6 +31,7 @@ def feed_forward_gaussian_fun(observation_space, action_space, config, observations): + """Feed-forward gaussian.""" assert isinstance(observation_space, gym.spaces.box.Box) mean_weights_initializer = tf.contrib.layers.variance_scaling_initializer( diff --git a/tensor2tensor/rl/ppo.py b/tensor2tensor/rl/ppo.py index 1c9654608..a2b34c797 100644 --- a/tensor2tensor/rl/ppo.py +++ b/tensor2tensor/rl/ppo.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2018 The Tensor2Tensor Authors. +# Copyright 2017 The Tensor2Tensor Authors. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,11 +18,13 @@ Based on: https://arxiv.org/abs/1707.06347 """ + import tensorflow as tf + def define_ppo_step(observation, action, reward, done, value, old_pdf, policy_factory, config): - + """A step of PPO.""" new_policy_dist, new_value, _ = policy_factory(observation) new_pdf = new_policy_dist.prob(action) @@ -58,6 +60,7 @@ def define_ppo_step(observation, action, reward, done, value, old_pdf, def define_ppo_epoch(memory, policy_factory, config): + """An epoch of PPO.""" observation, reward, done, action, old_pdf, value = memory # This is to avoid propagating gradients though simulation of simulation @@ -69,8 +72,9 @@ def define_ppo_epoch(memory, policy_factory, config): old_pdf = tf.stop_gradient(old_pdf) policy_loss, value_loss, entropy_loss = tf.scan( - lambda _1, _2: define_ppo_step(observation, action, reward, done, value, - old_pdf, policy_factory, config), + lambda _1, _2: define_ppo_step( # pylint: disable=g-long-lambda + observation, action, reward, done, value, + old_pdf, policy_factory, config), tf.range(config.optimization_epochs), [0., 0., 0.], parallel_iterations=1) diff --git a/tensor2tensor/bin/t2t_rl_trainer.py b/tensor2tensor/rl/rl_trainer_lib.py similarity index 56% rename from tensor2tensor/bin/t2t_rl_trainer.py rename to tensor2tensor/rl/rl_trainer_lib.py index b53692ccc..ced6da342 100644 --- a/tensor2tensor/bin/t2t_rl_trainer.py +++ b/tensor2tensor/rl/rl_trainer_lib.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2018 The Tensor2Tensor Authors. +# Copyright 2017 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -18,38 +18,40 @@ from __future__ import absolute_import import functools -from munch import Munch -import tensorflow as tf -from tensor2tensor.rl.collect import define_collect -from tensor2tensor.rl.envs.utils import define_batch_env -from tensor2tensor.rl.ppo import define_ppo_epoch +# Dependency imports + +import gym +from tensor2tensor.rl import collect +from tensor2tensor.rl import networks +from tensor2tensor.rl import ppo +from tensor2tensor.rl.envs import utils + +import tensorflow as tf def define_train(policy_lambda, env_lambda, config): + """Define the training setup.""" env = env_lambda() action_space = env.action_space observation_space = env.observation_space - batch_env = define_batch_env(env_lambda, config["num_agents"]) + batch_env = utils.define_batch_env(env_lambda, config.num_agents) policy_factory = tf.make_template( - 'network', + "network", functools.partial(policy_lambda, observation_space, action_space, config)) - (collect_op, memory) = define_collect(policy_factory, batch_env, config) + (collect_op, memory) = collect.define_collect( + policy_factory, batch_env, config) with tf.control_dependencies([collect_op]): - ppo_op = define_ppo_epoch(memory, policy_factory, config) + ppo_op = ppo.define_ppo_epoch(memory, policy_factory, config) return ppo_op -def main(): - train(example_params()) - - def train(params): policy_lambda, env_lambda, config = params ppo_op = define_train(policy_lambda, env_lambda, config) @@ -61,32 +63,25 @@ def train(params): def example_params(): - from tensor2tensor.rl import networks - config = {} - config['init_mean_factor'] = 0.1 - config['init_logstd'] = 0.1 - config['policy_layers'] = 100, 100 - config['value_layers'] = 100, 100 - config['num_agents'] = 30 - config['clipping_coef'] = 0.2 - config['gae_gamma'] = 0.99 - config['gae_lambda'] = 0.95 - config['entropy_loss_coef'] = 0.01 - config['value_loss_coef'] = 1 - config['optimizer'] = tf.train.AdamOptimizer - config['learning_rate'] = 1e-4 - config['optimization_epochs'] = 15 - config['epoch_length'] = 200 - config['epochs_num'] = 2000 - - config = Munch(config) + """Example hyperparameters.""" + config = tf.contrib.training.HParams( + init_mean_factor=0.1, + init_logstd=0.1, + policy_layers=(100, 100), + value_layers=(100, 100), + num_agents=30, + clipping_coef=0.2, + gae_gamma=0.99, + gae_lambda=0.95, + entropy_loss_coef=0.01, + value_loss_coef=1, + optimizer=tf.train.AdamOptimizer, + learning_rate=1e-4, + optimization_epochs=15, + epoch_length=200, + epochs_num=2000) return networks.feed_forward_gaussian_fun, pendulum_lambda, config def pendulum_lambda(): - import gym return gym.make("Pendulum-v0") - - -if __name__ == '__main__': - main() diff --git a/tensor2tensor/rl/train_test.py b/tensor2tensor/rl/rl_trainer_lib_test.py similarity index 78% rename from tensor2tensor/rl/train_test.py rename to tensor2tensor/rl/rl_trainer_lib_test.py index ac14c2083..1a276ef39 100644 --- a/tensor2tensor/rl/train_test.py +++ b/tensor2tensor/rl/rl_trainer_lib_test.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2018 The Tensor2Tensor Authors. +# Copyright 2017 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -15,22 +15,20 @@ """Tests of basic flow of collecting trajectories and training PPO.""" -import tensorflow as tf - -from tensor2tensor.bin import t2t_rl_trainer +# Dependency imports +from tensor2tensor.rl import rl_trainer_lib -FLAGS = tf.app.flags.FLAGS +import tensorflow as tf class TrainTest(tf.test.TestCase): def test_no_crash_pendulum(self): - params = t2t_rl_trainer.example_params() + params = rl_trainer_lib.example_params() params[2].epochs_num = 10 - t2t_rl_trainer.train(params) + rl_trainer_lib.train(params) if __name__ == '__main__': - FLAGS.config = 'unused' tf.test.main() diff --git a/tensor2tensor/rl/t2t_rl_trainer.py b/tensor2tensor/rl/t2t_rl_trainer.py new file mode 100644 index 000000000..875c28567 --- /dev/null +++ b/tensor2tensor/rl/t2t_rl_trainer.py @@ -0,0 +1,30 @@ +# coding=utf-8 +# Copyright 2017 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Training of RL agent with PPO algorithm.""" + +# Dependency imports + +from tensor2tensor.rl import rl_trainer_lib + +import tensorflow as tf + + +def main(_): + rl_trainer_lib.train(rl_trainer_lib.example_params()) + + +if __name__ == "__main__": + tf.app.run() diff --git a/tensor2tensor/utils/beam_search.py b/tensor2tensor/utils/beam_search.py index 3841b5953..3c7b8c203 100644 --- a/tensor2tensor/utils/beam_search.py +++ b/tensor2tensor/utils/beam_search.py @@ -90,7 +90,7 @@ def get_state_shape_invariants(tensor): def log_prob_from_logits(logits): - return logits - tf.reduce_logsumexp(logits, axis=2, keepdims=True) + return logits - tf.reduce_logsumexp(logits, axis=2, keep_dims=True) def compute_batch_indices(batch_size, beam_size): diff --git a/tensor2tensor/utils/cloud_mlengine.py b/tensor2tensor/utils/cloud_mlengine.py new file mode 100644 index 000000000..82ac23a39 --- /dev/null +++ b/tensor2tensor/utils/cloud_mlengine.py @@ -0,0 +1,242 @@ +# coding=utf-8 +# Copyright 2017 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Launch on GCP's ML Engine.""" + +import os +import shutil +import sys +import tempfile + +from googleapiclient import discovery +from oauth2client.client import GoogleCredentials +from tensor2tensor.layers import common_hparams +from tensor2tensor.utils import cloud_tpu as cloud +from tensor2tensor.utils import registry +from tensor2tensor.utils import usr_dir as usr_dir_lib +import tensorflow as tf + +FLAGS = tf.flags.FLAGS + +CONSOLE_URL = 'https://console.cloud.google.com/mlengine/jobs/' + +# TODO(rsepassi): +# * Enable multi-machine sync/async training + +SETUP_PY = """ +from setuptools import find_packages +from setuptools import setup +setup( + name='DummyUsrDirPackage', + version='0.1', + packages=find_packages(), +) +""" + + +def flags_as_args(): + """Convert FLAGS to list of args suitable for passing on cmd line.""" + args_dict = dict(FLAGS.__dict__['__flags']) + del args_dict['cloud_mlengine'] + # Configured later + del args_dict['t2t_usr_dir'] + args = [] + for name, val in args_dict.items(): + if val is None: + continue + if name.startswith('autotune'): + continue + args.extend(['--%s' % name, str(val)]) + return args + + +def machine_config(num_gpus=1, use_tpu=False, master_type=None): + """Return dict specifying machine config for trainingInput.""" + scale_tier = 'BASIC_GPU' + if use_tpu: + scale_tier = 'BASIC_TPU' + elif num_gpus <= 0: + scale_tier = 'BASIC' + elif num_gpus > 1: + scale_tier = 'CUSTOM' + + config = {'scaleTier': scale_tier} + + if scale_tier == 'CUSTOM': + assert num_gpus > 1 + if num_gpus not in [4, 8]: + raise ValueError('Must use exactly 1, 4, or 8 GPUs.') + config['masterType'] = ('complex_model_m_gpu' + if num_gpus == 4 else 'complex_model_l_gpu') + + if master_type: + config['masterType'] = master_type + + return config + + +def configure_job(): + """Construct jobSpec for ML Engine job.""" + train_dir = FLAGS.output_dir + assert train_dir.startswith('gs://') + job_name = os.path.basename(train_dir) + + # See documentation: + # https://cloud.google.com/ml-engine/reference/rest/v1/projects.jobs#traininginput + training_input = { + 'pythonModule': 'tensor2tensor.bin.t2t_trainer', + 'args': flags_as_args(), + 'region': cloud.default_region(), + 'runtimeVersion': '1.4', + 'pythonVersion': '3.5' if sys.version_info.major == 3 else '2.7', + 'jobDir': train_dir, + } + training_input.update( + machine_config( + num_gpus=FLAGS.worker_gpu, + use_tpu=FLAGS.use_tpu, + master_type=FLAGS.cloud_mlengine_master_type)) + if FLAGS.hparams_range: + assert FLAGS.autotune_objective + tf.logging.info('Configuring hyperparameter tuning.') + training_input['hyperparameters'] = configure_autotune( + FLAGS.hparams_range, + FLAGS.autotune_objective, + FLAGS.autotune_maximize, + FLAGS.autotune_max_trials, + FLAGS.autotune_parallel_trials, + ) + + if training_input['scaleTier'] == 'CUSTOM': + assert 'masterType' in training_input + + job_spec = {'jobId': job_name, 'trainingInput': training_input} + return job_spec + + +def launch_job(job_spec): + """Launch job on ML Engine.""" + project_id = 'projects/{}'.format(cloud.default_project()) + credentials = GoogleCredentials.get_application_default() + cloudml = discovery.build('ml', 'v1', credentials=credentials) + request = cloudml.projects().jobs().create(body=job_spec, parent=project_id) + request.execute() + + +def _tar_and_copy(src_dir, target_dir): + """Tar and gzip src_dir and copy to GCS target_dir.""" + src_dir = src_dir.rstrip('/') + target_dir = target_dir.rstrip('/') + tmp_dir = tempfile.gettempdir().rstrip('/') + 
src_base = os.path.basename(src_dir) + cloud.shell_run( + 'tar -zcf {tmp_dir}/{src_base}.tar.gz -C {src_dir} .', + src_dir=src_dir, + src_base=src_base, + tmp_dir=tmp_dir) + final_destination = '%s/%s.tar.gz' % (target_dir, src_base) + cloud.shell_run( + ('gsutil cp {tmp_dir}/{src_base}.tar.gz ' + '{final_destination}'), + tmp_dir=tmp_dir, + src_base=src_base, + final_destination=final_destination) + return final_destination + + +def tar_and_copy_t2t(train_dir): + """Tar Tensor2Tensor and cp to train_dir.""" + tf.logging.info('Tarring and pushing local Tensor2Tensor package.') + t2t_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) + t2t_tar = _tar_and_copy(t2t_dir, train_dir) + return t2t_tar + + +def tar_and_copy_usr_dir(usr_dir, train_dir): + """Package, tar, and copy usr_dir to GCS train_dir.""" + tf.logging.info('Tarring and pushing t2t_usr_dir.') + usr_dir = os.path.abspath(os.path.expanduser(usr_dir)) + # Copy usr dir to a temp location + top_dir = os.path.join(tempfile.gettempdir(), 't2t_usr_container') + tmp_usr_dir = os.path.join(top_dir, usr_dir_lib.INTERNAL_USR_DIR_PACKAGE) + shutil.rmtree(top_dir, ignore_errors=True) + shutil.copytree(usr_dir, tmp_usr_dir) + # Insert setup.py if one does not exist + top_setup_fname = os.path.join(top_dir, 'setup.py') + usr_setup_fname = os.path.join(tmp_usr_dir, 'setup.py') + if tf.gfile.Exists(usr_setup_fname): + tf.gfile.Move(usr_setup_fname, top_setup_fname) + else: + with tf.gfile.Open(top_setup_fname, 'w') as f: + f.write(SETUP_PY) + usr_tar = _tar_and_copy(top_dir, train_dir) + return usr_tar + + +def autotune_paramspecs(hparams_range): + rhp = common_hparams.RangedHParams() + registry.ranged_hparams(hparams_range)(rhp) + return rhp.to_parameter_specs(name_prefix='hp_') + + +def configure_autotune(hparams_range, + objective, + maximize=True, + max_trials=10, + parallel_trials=1): + return { + 'goal': 'MAXIMIZE' if maximize else 'MINIMIZE', + 'params': autotune_paramspecs(hparams_range), + 'maxTrials': max_trials, + 'maxParallelTrials': parallel_trials, + 'hyperparameterMetricTag': objective, + } + + +def configure_trainer_package(job_spec, t2t_tar): + assert t2t_tar.startswith('gs://') + job_spec['trainingInput']['packageUris'] = [t2t_tar] + + +def configure_usr_dir(job_spec, usr_tar): + assert usr_tar.startswith('gs://') + job_spec['trainingInput']['packageUris'].append(usr_tar) + usr_args = ['--t2t_usr_dir', usr_dir_lib.INTERNAL_USR_DIR_PACKAGE] + job_spec['trainingInput']['args'].extend(usr_args) + + +def launch(): + """Launch t2t_trainer on Cloud ML Engine.""" + assert not FLAGS.cloud_tpu + assert not FLAGS.job_dir + assert FLAGS.output_dir.startswith('gs://') + assert FLAGS.data_dir.startswith('gs://') + assert FLAGS.worker_replicas <= 1 + assert FLAGS.ps_replicas <= 0 + + job_spec = configure_job() + job_name = job_spec['jobId'] + tf.logging.info('Launching job %s with ML Engine spec:\n%s', job_name, + job_spec) + assert cloud.confirm() + train_dir = FLAGS.output_dir + t2t_tar = tar_and_copy_t2t(train_dir) + configure_trainer_package(job_spec, t2t_tar) + if FLAGS.t2t_usr_dir: + usr_tar = tar_and_copy_usr_dir(FLAGS.t2t_usr_dir, train_dir) + configure_usr_dir(job_spec, usr_tar) + launch_job(job_spec) + tf.logging.info('Launched %s. 
See console to track: %s.', job_name, + CONSOLE_URL) diff --git a/tensor2tensor/utils/cloud.py b/tensor2tensor/utils/cloud_tpu.py similarity index 97% rename from tensor2tensor/utils/cloud.py rename to tensor2tensor/utils/cloud_tpu.py index 937c6ee46..53dd36bd0 100644 --- a/tensor2tensor/utils/cloud.py +++ b/tensor2tensor/utils/cloud_tpu.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Launch on GCP.""" +"""Launch on TPU on GCP.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -191,6 +191,9 @@ def create_tpu(cls): gcloud compute ssh {name} -- -N """ + DEFAULT_PROJECT = "gcloud config get-value project" + DEFAULT_REGION = "gcloud config get-value compute/region" + @contextlib.contextmanager def shell_background(cmd_, **kwargs): @@ -224,6 +227,14 @@ def format_cmd(cmd_, **kwargs): return cmd_.format(**kwargs).strip().split() +def default_region(): + return shell_output(Gcloud.DEFAULT_REGION).strip() + + +def default_project(): + return shell_output(Gcloud.DEFAULT_PROJECT).strip() + + def create_vm(vm_name): out = shell_output(Gcloud.create_vm(), name=vm_name) return out.split("\n")[1:-1][0].split()[4] diff --git a/tensor2tensor/utils/diet.py b/tensor2tensor/utils/diet.py index 19702338b..7ecfba693 100644 --- a/tensor2tensor/utils/diet.py +++ b/tensor2tensor/utils/diet.py @@ -193,10 +193,10 @@ def update_variable(self, var, grad_var): beta2_pow = tf.pow(params.beta2, global_step) if params.factored_second_moment_accumulator and len(var.shape) == 2: vr_update = tf.assign(slots["adam_vr"], slots["adam_vr"] * params.beta2 + - tf.reduce_mean(grad_squared, 1, keepdims=True) * + tf.reduce_mean(grad_squared, 1, keep_dims=True) * (1.0 - params.beta2)) vc_update = tf.assign(slots["adam_vc"], slots["adam_vc"] * params.beta2 + - tf.reduce_mean(grad_squared, 0, keepdims=True) * + tf.reduce_mean(grad_squared, 0, keep_dims=True) * (1.0 - params.beta2)) with tf.control_dependencies([vr_update, vc_update]): vr = tf.sqrt(slots["adam_vr"] / (1.0 - beta2_pow)) + params.epsilon diff --git a/tensor2tensor/utils/optimize.py b/tensor2tensor/utils/optimize.py index a497d56bd..bc09c009d 100644 --- a/tensor2tensor/utils/optimize.py +++ b/tensor2tensor/utils/optimize.py @@ -168,10 +168,6 @@ def learning_rate_decay(hparams, warmup_steps=0): hparams.learning_rate_boundaries, hparams.learning_rate_multiples) - if scheme == "noam": - return 5000.0 * hparams.hidden_size**-0.5 * tf.minimum( - (global_step + 1) * warmup_steps**-1.5, (global_step + 1)**-0.5) - if scheme == "cosine": cycle_steps = hparams.learning_rate_cosine_cycle_steps cycle_position = global_step % (2 * cycle_steps) @@ -224,6 +220,23 @@ def learning_rate_decay_with_warmup(hparams, num_worker_replicas=1): return tf.where(global_step < warmup_steps, warmup, decay) +def learning_rate_schedule(hparams, num_worker_replicas=1): + """Learning rate schedule based on hparams.""" + schedule = hparams.learning_rate_schedule + warmup_steps = tf.to_float(hparams.learning_rate_warmup_steps) + global_step = tf.to_float(tf.train.get_or_create_global_step()) + if hparams.learning_rate_decay_scheme == "noam": + # backwards compatiblity with previous behavior + schedule = "linear_warmup_rsqrt_decay" + if schedule == "warmup_and_decay": + return learning_rate_decay_with_warmup(hparams, num_worker_replicas) + elif schedule == "linear_warmup_rsqrt_decay": + return 5000.0 * hparams.hidden_size**-0.5 * tf.minimum( + (global_step + 1) * 
warmup_steps**-1.5, (global_step + 1)**-0.5) + else: + raise ValueError("Unrecognized learning rate schedule: %s" % schedule) + + def weight_decay_and_noise(loss, hparams, learning_rate, var_list=None): """Apply weight decay and weight noise.""" if var_list is None: diff --git a/tensor2tensor/utils/t2t_model.py b/tensor2tensor/utils/t2t_model.py index 0623a975e..aadf5e358 100644 --- a/tensor2tensor/utils/t2t_model.py +++ b/tensor2tensor/utils/t2t_model.py @@ -124,7 +124,7 @@ def hparams(self): @property def has_input(self): if self._problem_hparams: - return self._problem_hparams.input_modality + return "inputs" in self._problem_hparams.input_modality else: return True @@ -296,7 +296,7 @@ def optimize(self, loss, num_async_replicas=1): """Return a training op minimizing loss.""" tf.logging.info("Base learning rate: %f", self.hparams.learning_rate) lr = self.hparams.learning_rate - decay_rate = optimize.learning_rate_decay_with_warmup(self.hparams) + decay_rate = optimize.learning_rate_schedule(self.hparams) lr *= decay_rate if self.hparams.learning_rate_minimum: lr_min = float(self.hparams.learning_rate_minimum) diff --git a/tensor2tensor/utils/trainer_lib.py b/tensor2tensor/utils/trainer_lib.py index 039b06e68..0e64f2475 100644 --- a/tensor2tensor/utils/trainer_lib.py +++ b/tensor2tensor/utils/trainer_lib.py @@ -68,8 +68,12 @@ def create_hparams(hparams_set, hparams_overrides_str="", data_dir=None, problem_name=None): + """Create HParams with data_dir and problem hparams, if kwargs provided.""" hparams = registry.hparams(hparams_set)() if hparams_overrides_str: + tf.logging.info("Overriding hparams in %s with %s", + hparams_set, + hparams_overrides_str) hparams = hparams.parse(hparams_overrides_str) if data_dir: hparams.add_hparam("data_dir", data_dir) diff --git a/tensor2tensor/utils/usr_dir.py b/tensor2tensor/utils/usr_dir.py index d89745b98..5edd6f1a2 100644 --- a/tensor2tensor/utils/usr_dir.py +++ b/tensor2tensor/utils/usr_dir.py @@ -27,13 +27,20 @@ import tensorflow as tf +INTERNAL_USR_DIR_PACKAGE = "t2t_usr_dir_internal" + + def import_usr_dir(usr_dir): """Import module at usr_dir, if provided.""" if not usr_dir: return - dir_path = os.path.expanduser(usr_dir) - if dir_path[-1] == "/": - dir_path = dir_path[:-1] + if usr_dir == INTERNAL_USR_DIR_PACKAGE: + # The package has been installed with pip under this name for Cloud ML + # Engine so just import it. + importlib.import_module(INTERNAL_USR_DIR_PACKAGE) + return + + dir_path = os.path.abspath(os.path.expanduser(usr_dir).rstrip("/")) containing_dir, module_name = os.path.split(dir_path) tf.logging.info("Importing user module %s from path %s", module_name, containing_dir)
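
Taken together with `configure_usr_dir` in `cloud_mlengine.py`, the `usr_dir.py` hunk above means a `--t2t_usr_dir` launched on Cloud ML Engine arrives as a pip-installed package named `t2t_usr_dir_internal` and is imported directly, while a local path is normalized and imported by name. Below is a minimal sketch of that flow; the `sys.path`/`importlib` handling at the end is an assumption (the hunk cuts off after the logging call), and `import_usr_dir_sketch` is a purely illustrative name, not part of the patch.

```python
import importlib
import os
import sys

INTERNAL_USR_DIR_PACKAGE = "t2t_usr_dir_internal"


def import_usr_dir_sketch(usr_dir):
  """Illustrative version of the import flow implied by the usr_dir.py hunk."""
  if not usr_dir:
    return
  if usr_dir == INTERNAL_USR_DIR_PACKAGE:
    # On Cloud ML Engine the usr dir has been pip-installed under this fixed
    # package name (see configure_usr_dir), so a plain import suffices.
    importlib.import_module(INTERNAL_USR_DIR_PACKAGE)
    return
  # Locally: normalize the path, then import the directory as a module from
  # its parent directory (assumed mechanism; not shown in the hunk above).
  dir_path = os.path.abspath(os.path.expanduser(usr_dir).rstrip("/"))
  containing_dir, module_name = os.path.split(dir_path)
  sys.path.insert(0, containing_dir)
  importlib.import_module(module_name)
  sys.path.pop(0)
```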