Skip to content
This repository has been archived by the owner on Jul 7, 2023. It is now read-only.

Commit

Permalink
Merge pull request #13 from rsepassi/push
Browse files Browse the repository at this point in the history
v1.0.5
  • Loading branch information
lukaszkaiser authored Jun 22, 2017
2 parents 7ec178b + 8195f34 commit 8ec4233
Show file tree
Hide file tree
Showing 9 changed files with 238 additions and 35 deletions.
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Compiled python modules.
*.pyc

# Python egg metadata, regenerated from source files by setuptools.
/*.egg-info
48 changes: 48 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ version](https://badge.fury.io/py/tensor2tensor.svg)](https://badge.fury.io/py/t
Issues](https://img.shields.io/github/issues/tensorflow/tensor2tensor.svg)](https://github.com/tensorflow/tensor2tensor/issues)
[![Contributions
welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)](CONTRIBUTING.md)
[![Gitter](https://img.shields.io/gitter/room/nwjs/nw.js.svg)](https://gitter.im/tensor2tensor/Lobby)
[![License](https://img.shields.io/badge/License-Apache%202.0-brightgreen.svg)](https://opensource.org/licenses/Apache-2.0)

[T2T](https://github.com/tensorflow/tensor2tensor) is a modular and extensible
Expand All @@ -22,6 +23,8 @@ send along a pull request to add your data-set or model.
See [our contribution
doc](CONTRIBUTING.md) for details and our [open
issues](https://github.com/tensorflow/tensor2tensor/issues).
And chat with us and other users on
[Gitter](https://gitter.im/tensor2tensor/Lobby).

---

Expand Down Expand Up @@ -95,7 +98,14 @@ cat $DECODE_FILE.$MODEL.$HPARAMS.beam$BEAM_SIZE.alpha$ALPHA.decodes
## Installation

```
# Assumes tensorflow or tensorflow-gpu installed
pip install tensor2tensor
# Installs with tensorflow-gpu requirement
pip install tensor2tensor[tensorflow_gpu]
# Installs with tensorflow (cpu) requirement
pip install tensor2tensor[tensorflow]
```

Binaries:
Expand Down Expand Up @@ -191,6 +201,44 @@ related flags control local and distributed training/evaluation

---

## Adding your own components

T2T's components are registered through a central registry, which makes it easy
to add new components and to switch between them with a command-line flag. You
can add your own components without editing the T2T codebase by specifying the
`--t2t_usr_dir` flag in `t2t-trainer`.

You can currently do so for models, hyperparameter sets, and modalities. Please
do submit a pull request if your component might be useful to others.

Here's an example with a new hyperparameter set:

```python
# In ~/usr/t2t_usr/my_registrations.py

from tensor2tensor.models import transformer
from tensor2tensor.utils import registry

@registry.register_hparams
def transformer_my_very_own_hparams_set():
hparams = transformer.transformer_base()
hparams.hidden_size = 1024
...
```

```python
# In ~/usr/t2t_usr/__init__.py
from . import my_registrations
```

```
t2t-trainer --t2t_usr_dir=~/usr/t2t_usr --registry_help
```

You'll see under the registered HParams your
`transformer_my_very_own_hparams_set`, which you can directly use on the command
line with the `--hparams_set` flag.

## Adding a dataset

See the [data generators
Expand Down
7 changes: 5 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setup(
name='tensor2tensor',
version='1.0.4',
version='1.0.5',
description='Tensor2Tensor',
author='Google Inc.',
author_email='[email protected]',
Expand All @@ -17,8 +17,11 @@
'numpy',
'sympy',
'six',
'tensorflow-gpu>=1.2.0rc1',
],
extras_require={
'tensorflow': ['tensorflow>=1.2.0rc1'],
'tensorflow_gpu': ['tensorflow-gpu>=1.2.0rc1'],
},
classifiers=[
'Development Status :: 4 - Beta',
'Intended Audience :: Developers',
Expand Down
30 changes: 29 additions & 1 deletion tensor2tensor/bin/t2t-trainer
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,45 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import importlib
import os
import sys

# Dependency imports

from tensor2tensor.utils import trainer_utils as utils

import tensorflow as tf

FLAGS = tf.flags.FLAGS
flags = tf.flags
FLAGS = flags.FLAGS

# Optional path to a user-supplied Python package. When set, import_usr_dir()
# imports it so that any registrations executed at import time take effect.
# The default "" means that no user package is imported.
flags.DEFINE_string("t2t_usr_dir", "",
"Path to a Python module that will be imported. The "
"__init__.py file should include the necessary imports. "
"The imported files should contain registrations, "
"e.g. @registry.register_model calls, that will then be "
"available to the t2t-trainer.")


def import_usr_dir():
  """Import the user package located at FLAGS.t2t_usr_dir, if provided.

  The flag value is interpreted as the path of a package directory. "~" is
  expanded and the path is normalized (so any number of trailing slashes is
  ignored), then split into the containing directory and the package name.
  The containing directory is placed at the front of sys.path only for the
  duration of the import and is removed again afterwards, even if the import
  fails.

  Does nothing when the flag is empty.

  Raises:
    ValueError: if no package name can be derived from the path (e.g. "/").
    ImportError: propagated from importlib if the package cannot be imported.
  """
  if not FLAGS.t2t_usr_dir:
    return
  # normpath drops trailing separators, so "~/usr/t2t_usr/" and
  # "~/usr/t2t_usr//" both resolve to the package "t2t_usr".
  dir_path = os.path.normpath(os.path.expanduser(FLAGS.t2t_usr_dir))
  containing_dir, module_name = os.path.split(dir_path)
  if not module_name:
    raise ValueError(
        "Cannot determine a module name from t2t_usr_dir=%r" %
        FLAGS.t2t_usr_dir)
  tf.logging.info("Importing user module %s from path %s", module_name,
                  containing_dir)
  sys.path.insert(0, containing_dir)
  try:
    importlib.import_module(module_name)
  finally:
    # Remove the entry by value rather than by position: the imported package
    # may itself have prepended to sys.path, in which case pop(0) would drop
    # the wrong entry. The finally clause also restores sys.path on failure.
    if containing_dir in sys.path:
      sys.path.remove(containing_dir)


def main(_):
tf.logging.set_verbosity(tf.logging.INFO)
import_usr_dir()
utils.log_registry()
utils.validate_flags()
utils.run(
Expand Down
2 changes: 1 addition & 1 deletion tensor2tensor/data_generators/lm_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
tar xvf 1-billion-word-language-modeling-benchmark-r13output.tar.gz
# replace oov words with UNK
./blaze-bin/third_party/py/tensor2tensor/data_generators/replace_oov \
$BINARYDIR/replace_oov \
--vocab_file=$DATADIR/vocab-2016-09-10.txt \
--in_filepattern=\
$DATADIR/1-billion-word-language-modeling-benchmark-r13output/\
Expand Down
107 changes: 93 additions & 14 deletions tensor2tensor/models/common_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,10 +271,14 @@ def attention_image_summary(attn, image_shapes=None):
Args:
attn: a Tensor with shape [batch, num_heads, query_length, memory_length]
image_shapes: optional quadruple of integer scalars.
image_shapes: optional tuple of integer scalars.
If the query positions and memory positions represent the
pixels of a flattened image, then pass in their dimensions:
pixels of flattened images, then pass in their dimensions:
(query_rows, query_cols, memory_rows, memory_cols).
If the query positions and memory positions represent the
pixels x channels of flattened images, then pass in their dimensions:
(query_rows, query_cols, query_channels,
memory_rows, memory_cols, memory_channels).
"""
num_heads = attn.get_shape().as_list()[1]
# [batch, query_length, memory_length, num_heads]
Expand All @@ -286,10 +290,20 @@ def attention_image_summary(attn, image_shapes=None):
image = split_last_dimension(image, 3)
image = tf.reduce_max(image, 4)
if image_shapes is not None:
q_rows, q_cols, m_rows, m_cols = list(image_shapes)
image = tf.reshape(image, [-1, q_rows, q_cols, m_rows, m_cols, 3])
image = tf.transpose(image, [0, 1, 3, 2, 4, 5])
image = tf.reshape(image, [-1, q_rows * m_rows, q_cols * m_cols, 3])
if len(image_shapes) == 4:
q_rows, q_cols, m_rows, m_cols = list(image_shapes)
image = tf.reshape(image, [-1, q_rows, q_cols, m_rows, m_cols, 3])
image = tf.transpose(image, [0, 1, 3, 2, 4, 5])
image = tf.reshape(image, [-1, q_rows * m_rows, q_cols * m_cols, 3])
else:
assert len(image_shapes) == 6
q_rows, q_cols, q_channnels, m_rows, m_cols, m_channels = list(
image_shapes)
image = tf.reshape(image, [-1, q_rows, q_cols, q_channnels,
m_rows, m_cols, m_channels, 3])
image = tf.transpose(image, [0, 1, 4, 3, 2, 5, 6, 7])
image = tf.reshape(image, [-1, q_rows * m_rows * q_channnels,
q_cols * m_cols * m_channels, 3])
tf.summary.image("attention", image, max_outputs=1)


Expand All @@ -310,10 +324,8 @@ def dot_product_attention(q,
bias: bias Tensor (see attention_bias())
dropout_rate: a floating point number
summaries: a boolean
image_shapes: optional quadruple of integer scalars for image summary.
If the query positions and memory positions represent the
pixels of a flattened image, then pass in their dimensions:
(query_rows, query_cols, memory_rows, memory_cols).
image_shapes: optional tuple of integer scalars.
see comments for attention_image_summary()
name: an optional string
Returns:
Expand Down Expand Up @@ -356,10 +368,8 @@ def multihead_attention(query_antecedent,
num_heads: an integer dividing total_key_depth and total_value_depth
dropout_rate: a floating point number
summaries: a boolean
image_shapes: optional quadruple of integer scalars for image summary.
If the query positions and memory positions represent the
pixels of a flattened image, then pass in their dimensions:
(query_rows, query_cols, memory_rows, memory_cols).
image_shapes: optional tuple of integer scalars.
see comments for attention_image_summary()
name: an optional string
Returns:
Expand Down Expand Up @@ -398,3 +408,72 @@ def multihead_attention(query_antecedent,
x = combine_heads(x)
x = common_layers.conv1d(x, output_depth, 1, name="output_transform")
return x


def parameter_attention(x,
                        total_key_depth,
                        total_value_depth,
                        output_depth,
                        memory_rows,
                        num_heads,
                        dropout_rate,
                        name=None):
  """Attention over parameters.

  We use the same multi-headed attention as in the other layers, but the memory
  keys and values are model parameters. There are no linear transformations
  on the keys or values.

  We are also a bit more careful about memory usage, since the number of
  memory positions may be very large.

  Args:
    x: a Tensor with shape [batch, length_q, channels]
    total_key_depth: an integer
    total_value_depth: an integer
    output_depth: an integer
    memory_rows: an integer
    num_heads: an integer dividing total_key_depth and total_value_depth
    dropout_rate: a floating point number
    name: an optional string

  Returns:
    A Tensor.

  Raises:
    ValueError: if num_heads does not divide total_key_depth or
      total_value_depth.
  """
  # Fail early with a clear message: the per-head sizes below use floor
  # division, so a non-divisible depth would otherwise only surface later as
  # a shape mismatch in one of the reshapes.
  if total_key_depth % num_heads != 0:
    raise ValueError("Key depth (%d) must be divisible by the number of "
                     "attention heads (%d)." % (total_key_depth, num_heads))
  if total_value_depth % num_heads != 0:
    raise ValueError("Value depth (%d) must be divisible by the number of "
                     "attention heads (%d)." % (total_value_depth, num_heads))
  with tf.variable_scope(name, default_name="parameter_attention",
                         values=[x]):
    head_size_k = total_key_depth // num_heads
    head_size_v = total_value_depth // num_heads
    var_shape_k = [num_heads, memory_rows, head_size_k]
    var_shape_v = [num_heads, memory_rows, head_size_v]
    # The memory keys and values are trainable variables, one set per head.
    k = tf.get_variable(
        "k", var_shape_k,
        initializer=tf.random_normal_initializer(
            0, output_depth ** -0.5)) * (num_heads ** 0.5)
    v = tf.get_variable(
        "v", var_shape_v,
        initializer=tf.random_normal_initializer(
            0, output_depth ** -0.5)) * (output_depth ** 0.5)
    batch_size = tf.shape(x)[0]
    length = tf.shape(x)[1]
    q = common_layers.conv1d(x, total_key_depth, 1, name="q_transform")
    if dropout_rate:
      # This is a cheaper form of attention dropout where we use
      # the same dropout decisions across batch elements and query positions,
      # but different decisions across heads and memory positions.
      v = tf.nn.dropout(v, 1.0 - dropout_rate,
                        noise_shape=[num_heads, memory_rows, 1])
    # query is [batch, length, hidden_size]
    # reshape and transpose it to [heads, batch * length, head_size]
    q = tf.reshape(q, [batch_size, length, num_heads, head_size_k])
    q = tf.transpose(q, [2, 0, 1, 3])
    q = tf.reshape(q, [num_heads, batch_size * length, head_size_k])
    # weights is [heads, batch * length, memory_rows]
    weights = tf.matmul(q, k, transpose_b=True)
    weights = tf.nn.softmax(weights)
    # y is [heads, batch * length, head_size_v]
    y = tf.matmul(weights, v)
    # Undo the head-major layout: back to [batch, length, total_value_depth].
    y = tf.reshape(y, [num_heads, batch_size, length, head_size_v])
    y = tf.transpose(y, [1, 2, 0, 3])
    y = tf.reshape(y, [batch_size, length, total_value_depth])
    # The dynamic reshape loses the static depth; restore it for conv1d.
    y.set_shape([None, None, total_value_depth])
    y = common_layers.conv1d(y, output_depth, 1, name="output_transform")
    return y
6 changes: 3 additions & 3 deletions tensor2tensor/models/modalities.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,8 +441,8 @@ class IdentityModality(modality.Modality):
def targets_dimensionality(self):
return self._vocab_size

def inputs_bottom_simple(self, inputs):
return tf.to_float(inputs)
def bottom(self, x):
return tf.to_float(x)

def targets_top_simple(self, body_output, _):
def top(self, body_output, _):
return body_output
Loading

0 comments on commit 8ec4233

Please sign in to comment.