diff --git a/tensor2tensor/layers/common_attention.py b/tensor2tensor/layers/common_attention.py
index a43afec47..2b1bd124f 100644
--- a/tensor2tensor/layers/common_attention.py
+++ b/tensor2tensor/layers/common_attention.py
@@ -166,17 +166,17 @@ def add_positional_embedding_nd(x, max_length, name):
 
 
 def embedding_to_padding(emb):
-  """Input embeddings -> is_padding.
+  """Calculates the padding mask based on which embeddings are all zero.
 
   We have hacked symbol_modality to return all-zero embeddings for padding.
 
   Args:
     emb: a Tensor with shape [..., depth].
   Returns:
-    a boolean Tensor with shape [...].
+    a float Tensor with shape [...].
   """
   emb_sum = tf.reduce_sum(tf.abs(emb), axis=-1)
-  return tf.equal(emb_sum, 0.0)
+  return tf.to_float(tf.equal(emb_sum, 0.0))
 
 
 def attention_bias_lower_triangle(length):
@@ -197,13 +197,13 @@ def attention_bias_ignore_padding(memory_padding):
   """Create an bias tensor to be added to attention logits.
 
   Args:
-    memory_padding: a boolean `Tensor` with shape [batch, memory_length].
+    memory_padding: a float `Tensor` with shape [batch, memory_length].
 
   Returns:
     a `Tensor` with shape [batch, 1, 1, memory_length].
   """
-  ret = tf.to_float(memory_padding) * -1e9
-  return tf.expand_dims(tf.expand_dims(ret, 1), 1)
+  ret = memory_padding * -1e9
+  return tf.expand_dims(tf.expand_dims(ret, axis=1), axis=1)
 
 
 def attention_bias_proximal(length):
@@ -523,8 +523,7 @@ def pad_l_and_r(x, pad_length):
   # [batch, heads, blocks, block_length, dim]
   k_new = tf.transpose(k_new, [2, 3, 0, 1, 4])
 
-  attention_bias = tf.expand_dims(
-      tf.to_float(embedding_to_padding(k_new)) * -1e9, axis=-2)
+  attention_bias = tf.expand_dims(embedding_to_padding(k_new) * -1e9, axis=-2)
 
   v_t = tf.transpose(v, [2, 0, 1, 3])
   v_new = tf.gather(v_t, gather_indices)
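
For context (not part of the patch), here is a minimal sketch of how the two updated helpers compose after this change. It assumes TensorFlow 1.x (where `tf.to_float` exists); the example embedding tensor and variable names below are illustrative, not taken from the repository.

```python
# Sketch only: the float-valued padding mask from embedding_to_padding()
# feeding attention_bias_ignore_padding(), as defined in the patch above.
# Assumes TensorFlow 1.x, where tf.to_float is available.
import tensorflow as tf


def embedding_to_padding(emb):
  # 1.0 where the embedding vector is all zeros (i.e. padding), else 0.0.
  emb_sum = tf.reduce_sum(tf.abs(emb), axis=-1)
  return tf.to_float(tf.equal(emb_sum, 0.0))


def attention_bias_ignore_padding(memory_padding):
  # Scale the float mask to a large negative bias and reshape it to
  # [batch, 1, 1, memory_length] so it can be added to attention logits.
  ret = memory_padding * -1e9
  return tf.expand_dims(tf.expand_dims(ret, axis=1), axis=1)


# Illustrative input: batch 1, memory_length 3, depth 2; the last position
# is an all-zero embedding, i.e. padding.
emb = tf.constant([[[0.5, -0.3], [1.0, 2.0], [0.0, 0.0]]])
bias = attention_bias_ignore_padding(embedding_to_padding(emb))
# bias has shape [1, 1, 1, 3] with values [0.0, 0.0, -1e9] on the last axis,
# which drives the padded position's attention weight to ~0 after the softmax.
```

Returning a float mask directly is what lets callers such as the local-attention code in the last hunk multiply by `-1e9` without a second `tf.to_float` cast.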