diff --git a/keras/src/layers/attention/attention.py b/keras/src/layers/attention/attention.py
index 592468fe802..15ff906e592 100644
--- a/keras/src/layers/attention/attention.py
+++ b/keras/src/layers/attention/attention.py
@@ -1,6 +1,7 @@
 from keras.src import backend
 from keras.src import ops
 from keras.src.api_export import keras_export
+from keras.src.backend import KerasTensor
 from keras.src.layers.layer import Layer
 
 
@@ -84,6 +85,8 @@ def __init__(
                 f"Received: score_mode={score_mode}"
             )
 
+        self._return_attention_scores = False
+
     def build(self, input_shape):
         self._validate_inputs(input_shape)
         self.scale = None
@@ -217,6 +220,7 @@ def call(
         use_causal_mask=False,
     ):
         self._validate_inputs(inputs=inputs, mask=mask)
+        self._return_attention_scores = return_attention_scores
         q = inputs[0]
         v = inputs[1]
         k = inputs[2] if len(inputs) > 2 else v
@@ -226,16 +230,17 @@ def call(
         scores_mask = self._calculate_score_mask(
             scores, v_mask, use_causal_mask
         )
-        result, attention_scores = self._apply_scores(
+        attention_output, attention_scores = self._apply_scores(
             scores=scores, value=v, scores_mask=scores_mask, training=training
         )
         if q_mask is not None:
             # Mask of shape [batch_size, Tq, 1].
             q_mask = ops.expand_dims(q_mask, axis=-1)
-            result *= ops.cast(q_mask, dtype=result.dtype)
+            attention_output *= ops.cast(q_mask, dtype=attention_output.dtype)
         if return_attention_scores:
-            return result, attention_scores
-        return result
+            return (attention_output, attention_scores)
+        else:
+            return attention_output
 
     def compute_mask(self, inputs, mask=None):
         self._validate_inputs(inputs=inputs, mask=mask)
@@ -244,8 +249,49 @@ def compute_mask(self, inputs, mask=None):
         return ops.convert_to_tensor(mask[0])
 
     def compute_output_shape(self, input_shape):
-        """Returns shape of value tensor dim, but for query tensor length"""
-        return (*input_shape[0][:-1], input_shape[1][-1])
+        query_shape, value_shape, key_shape = input_shape
+        if key_shape is None:
+            key_shape = value_shape
+
+        output_shape = (*query_shape[:-1], value_shape[-1])
+        if self._return_attention_scores:
+            scores_shape = (query_shape[0], query_shape[1], key_shape[1])
+            return output_shape, scores_shape
+        return output_shape
+
+    def compute_output_spec(
+        self,
+        inputs,
+        mask=None,
+        return_attention_scores=False,
+        training=None,
+        use_causal_mask=False,
+    ):
+        # Validate and unpack inputs
+        self._validate_inputs(inputs, mask)
+        query = inputs[0]
+        value = inputs[1]
+        key = inputs[2] if len(inputs) > 2 else value
+
+        # Compute primary output shape
+        output_shape = self.compute_output_shape(
+            [query.shape, value.shape, key.shape]
+        )
+        output_spec = KerasTensor(output_shape, dtype=self.compute_dtype)
+
+        # Handle attention scores if requested
+        if self._return_attention_scores:
+            scores_shape = (
+                query.shape[0],
+                query.shape[1],
+                key.shape[1],
+            )  # (batch_size, Tq, Tv)
+            attention_scores_spec = KerasTensor(
+                scores_shape, dtype=self.compute_dtype
+            )
+            return (output_spec, attention_scores_spec)
+
+        return output_spec
 
     def _validate_inputs(self, inputs, mask=None):
         """Validates arguments of the call method."""
diff --git a/keras/src/layers/attention/attention_test.py b/keras/src/layers/attention/attention_test.py
index de8dba64340..eab40b2a038 100644
--- a/keras/src/layers/attention/attention_test.py
+++ b/keras/src/layers/attention/attention_test.py
@@ -358,3 +358,62 @@ def test_attention_compute_output_shape(self):
             ),
             output.shape,
         )
+
+    def test_return_attention_scores_true(self):
+        """Test that the layer returns attention scores along with
+        outputs."""
+        # Generate dummy input data
+        query = np.random.random((2, 8, 16)).astype(np.float32)
+        value = np.random.random((2, 4, 16)).astype(np.float32)
+
+        # Initialize the Attention layer
+        layer = layers.Attention()
+
+        # Call the layer with return_attention_scores=True
+        output, attention_scores = layer(
+            [query, value], return_attention_scores=True
+        )
+
+        # Check the shape of the outputs
+        self.assertEqual(output.shape, (2, 8, 16))  # Output shape
+        self.assertEqual(
+            attention_scores.shape, (2, 8, 4)
+        )  # Attention scores shape
+
+    def test_return_attention_scores_true_and_tuple(self):
+        """Test that the layer outputs are a tuple when
+        return_attention_scores=True."""
+        # Generate dummy input data
+        query = np.random.random((2, 8, 16)).astype(np.float32)
+        value = np.random.random((2, 4, 16)).astype(np.float32)
+
+        # Initialize the Attention layer
+        layer = layers.Attention()
+
+        # Call the layer with return_attention_scores=True
+        outputs = layer([query, value], return_attention_scores=True)
+
+        # Check that outputs is a tuple
+        self.assertIsInstance(
+            outputs, tuple, "Expected the outputs to be a tuple"
+        )
+
+    def test_return_attention_scores_true_tuple_then_unpack(self):
+        """Test that outputs can be unpacked correctly."""
+        # Generate dummy input data
+        query = np.random.random((2, 8, 16)).astype(np.float32)
+        value = np.random.random((2, 4, 16)).astype(np.float32)
+
+        # Initialize the Attention layer
+        layer = layers.Attention()
+
+        # Call the layer with return_attention_scores=True
+        outputs = layer([query, value], return_attention_scores=True)
+
+        # Unpack the outputs
+        output, attention_scores = outputs
+
+        # Check the shape of the unpacked outputs
+        self.assertEqual(output.shape, (2, 8, 16))  # Output shape
+        self.assertEqual(
+            attention_scores.shape, (2, 8, 4)
+        )  # Attention scores shape