From 8907bcbff80cfd93b6b4c148eb7a660a95ac69ab Mon Sep 17 00:00:00 2001
From: Surya
Date: Thu, 26 Dec 2024 23:12:15 +0530
Subject: [PATCH] fix attention output with symbolic tensors and attention
 scores (#20689)

---
 keras/src/layers/attention/attention.py      |  2 +-
 keras/src/layers/attention/attention_test.py | 12 ++++++++++++
 2 files changed, 13 insertions(+), 1 deletion(-)

diff --git a/keras/src/layers/attention/attention.py b/keras/src/layers/attention/attention.py
index 15ff906e592..d336781c8b3 100644
--- a/keras/src/layers/attention/attention.py
+++ b/keras/src/layers/attention/attention.py
@@ -280,7 +280,7 @@ def compute_output_spec(
         output_spec = KerasTensor(output_shape, dtype=self.compute_dtype)
 
         # Handle attention scores if requested
-        if self._return_attention_scores:
+        if self._return_attention_scores or return_attention_scores:
             scores_shape = (
                 query.shape[0],
                 query.shape[1],
diff --git a/keras/src/layers/attention/attention_test.py b/keras/src/layers/attention/attention_test.py
index eab40b2a038..88598d72112 100644
--- a/keras/src/layers/attention/attention_test.py
+++ b/keras/src/layers/attention/attention_test.py
@@ -417,3 +417,15 @@ def test_return_attention_scores_true_tuple_then_unpack(self):
         self.assertEqual(
             attention_scores.shape, (2, 8, 4)
         )  # Attention scores shape
+
+    def test_return_attention_scores_with_symbolic_tensors(self):
+        """Test to check outputs with symbolic tensors with
+        return_attention_scores = True"""
+        attention = layers.Attention()
+        x = layers.Input(shape=(3, 5))
+        y = layers.Input(shape=(4, 5))
+        output, attention_scores = attention(
+            [x, y], return_attention_scores=True
+        )
+        self.assertEqual(output.shape, (None, 3, 5))  # Output shape
+        self.assertEqual(attention_scores.shape, (None, 3, 4))