Commit

Fix: Return Attention Scores when return_attention_scores=True (keras-team#20684)

* Fix: Ensure Attention Layer Returns Attention Scores when `return_attention_scores=True`

This pull request addresses an issue in the `Attention` layer where the `return_attention_scores` parameter was not handled in the `compute_output_shape` method, so the layer's reported output shape ignored the attention scores. The fix ensures that attention scores are returned when `return_attention_scores=True`.

## Changes Made
Modified the `compute_output_shape` method to return the shapes of both the attention output and the attention scores when `return_attention_scores=True` (see the usage sketch after this list).

* Formatting

* Fixed score return and added unit tests for return_attention_scores=True

* Removed debug print statement
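
As a quick illustration of the fixed behaviour, here is a minimal usage sketch; the dummy data and shapes mirror the unit tests added in this commit, and standard `numpy`/`keras` imports are assumed:

```python
import numpy as np
from keras import layers

# Dummy inputs as in the new tests: (batch, Tq, dim) and (batch, Tv, dim).
query = np.random.random((2, 8, 16)).astype("float32")
value = np.random.random((2, 4, 16)).astype("float32")

layer = layers.Attention()

# With the flag set, the layer returns an (output, scores) tuple.
output, attention_scores = layer([query, value], return_attention_scores=True)
print(output.shape)            # (2, 8, 16) -- (batch, Tq, value_dim)
print(attention_scores.shape)  # (2, 8, 4)  -- (batch, Tq, Tv)
```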
Furkan-rgb authored Dec 24, 2024
1 parent c1316e5 commit df002a9
Showing 2 changed files with 111 additions and 6 deletions.
58 changes: 52 additions & 6 deletions keras/src/layers/attention/attention.py
@@ -1,6 +1,7 @@
from keras.src import backend
from keras.src import ops
from keras.src.api_export import keras_export
from keras.src.backend import KerasTensor
from keras.src.layers.layer import Layer


@@ -84,6 +85,8 @@ def __init__(
f"Received: score_mode={score_mode}"
)

self._return_attention_scores = False

def build(self, input_shape):
self._validate_inputs(input_shape)
self.scale = None
@@ -217,6 +220,7 @@ def call(
use_causal_mask=False,
):
self._validate_inputs(inputs=inputs, mask=mask)
self._return_attention_scores = return_attention_scores
q = inputs[0]
v = inputs[1]
k = inputs[2] if len(inputs) > 2 else v
@@ -226,16 +230,17 @@
scores_mask = self._calculate_score_mask(
scores, v_mask, use_causal_mask
)
result, attention_scores = self._apply_scores(
attention_output, attention_scores = self._apply_scores(
scores=scores, value=v, scores_mask=scores_mask, training=training
)
if q_mask is not None:
# Mask of shape [batch_size, Tq, 1].
q_mask = ops.expand_dims(q_mask, axis=-1)
result *= ops.cast(q_mask, dtype=result.dtype)
attention_output *= ops.cast(q_mask, dtype=attention_output.dtype)
if return_attention_scores:
return result, attention_scores
return result
return (attention_output, attention_scores)
else:
return attention_output

def compute_mask(self, inputs, mask=None):
self._validate_inputs(inputs=inputs, mask=mask)
@@ -244,8 +249,49 @@ def compute_mask(self, inputs, mask=None):
return ops.convert_to_tensor(mask[0])

def compute_output_shape(self, input_shape):
"""Returns shape of value tensor dim, but for query tensor length"""
return (*input_shape[0][:-1], input_shape[1][-1])
query_shape, value_shape, key_shape = input_shape
if key_shape is None:
key_shape = value_shape

output_shape = (*query_shape[:-1], value_shape[-1])
if self._return_attention_scores:
scores_shape = (query_shape[0], query_shape[1], key_shape[1])
return output_shape, scores_shape
return output_shape

def compute_output_spec(
self,
inputs,
mask=None,
return_attention_scores=False,
training=None,
use_causal_mask=False,
):
# Validate and unpack inputs
self._validate_inputs(inputs, mask)
query = inputs[0]
value = inputs[1]
key = inputs[2] if len(inputs) > 2 else value

# Compute primary output shape
output_shape = self.compute_output_shape(
[query.shape, value.shape, key.shape]
)
output_spec = KerasTensor(output_shape, dtype=self.compute_dtype)

# Handle attention scores if requested
if self._return_attention_scores:
scores_shape = (
query.shape[0],
query.shape[1],
key.shape[1],
) # (batch_size, Tq, Tv)
attention_scores_spec = KerasTensor(
scores_shape, dtype=self.compute_dtype
)
return (output_spec, attention_scores_spec)

return output_spec

def _validate_inputs(self, inputs, mask=None):
"""Validates arguments of the call method."""
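A rough sketch of how the updated `compute_output_shape` behaves under this patch: the flag is recorded on the layer when `call` runs, so the layer should be called with `return_attention_scores=True` before querying the shape. The shapes below are hypothetical and mirror the new tests:

```python
import numpy as np
from keras import layers

layer = layers.Attention()
query = np.random.random((2, 8, 16)).astype("float32")
value = np.random.random((2, 4, 16)).astype("float32")

# Calling with the flag set records the layer's internal
# _return_attention_scores state, as introduced in this change.
layer([query, value], return_attention_scores=True)

# compute_output_shape expects (query_shape, value_shape, key_shape) and now
# reports the scores shape alongside the output shape.
shapes = layer.compute_output_shape([(2, 8, 16), (2, 4, 16), (2, 4, 16)])
print(shapes)  # ((2, 8, 16), (2, 8, 4))
```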
59 changes: 59 additions & 0 deletions keras/src/layers/attention/attention_test.py
@@ -358,3 +358,62 @@ def test_attention_compute_output_shape(self):
),
output.shape,
)

def test_return_attention_scores_true(self):
"""Test that the layer returns attention scores along with outputs."""
# Generate dummy input data
query = np.random.random((2, 8, 16)).astype(np.float32)
value = np.random.random((2, 4, 16)).astype(np.float32)

# Initialize the Attention layer
layer = layers.Attention()

# Call the layer with return_attention_scores=True
output, attention_scores = layer(
[query, value], return_attention_scores=True
)

# Check the shape of the outputs
self.assertEqual(output.shape, (2, 8, 16)) # Output shape
self.assertEqual(
attention_scores.shape, (2, 8, 4)
) # Attention scores shape

def test_return_attention_scores_true_and_tuple(self):
"""Test that the layer outputs are a tuple when
return_attention_scores=True."""
# Generate dummy input data
query = np.random.random((2, 8, 16)).astype(np.float32)
value = np.random.random((2, 4, 16)).astype(np.float32)

# Initialize the Attention layer
layer = layers.Attention()

# Call the layer with return_attention_scores=True
outputs = layer([query, value], return_attention_scores=True)

# Check that outputs is a tuple
self.assertIsInstance(
outputs, tuple, "Expected the outputs to be a tuple"
)

def test_return_attention_scores_true_tuple_then_unpack(self):
"""Test that outputs can be unpacked correctly."""
# Generate dummy input data
query = np.random.random((2, 8, 16)).astype(np.float32)
value = np.random.random((2, 4, 16)).astype(np.float32)

# Initialize the Attention layer
layer = layers.Attention()

# Call the layer with return_attention_scores=True
outputs = layer([query, value], return_attention_scores=True)

# Unpack the outputs
output, attention_scores = outputs

# Check the shape of the unpacked outputs
self.assertEqual(output.shape, (2, 8, 16)) # Output shape
self.assertEqual(
attention_scores.shape, (2, 8, 4)
) # Attention scores shape
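
For contrast with the tests above, a short sketch of the default path, which still returns a single tensor when `return_attention_scores` is left at its default of `False`:

```python
import numpy as np
from keras import layers

query = np.random.random((2, 8, 16)).astype(np.float32)
value = np.random.random((2, 4, 16)).astype(np.float32)

layer = layers.Attention()

# Default: return_attention_scores=False, so a single tensor comes back.
output = layer([query, value])
print(output.shape)  # (2, 8, 16)

# Opting in yields a tuple that unpacks as exercised in the tests above.
output, attention_scores = layer([query, value], return_attention_scores=True)
print(attention_scores.shape)  # (2, 8, 4)
```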
