From 216e382fa11f4dfb858ae283a6fa69b8e7b4f507 Mon Sep 17 00:00:00 2001
From: Gregory Shtrasberg
Date: Wed, 18 Dec 2024 16:36:21 +0000
Subject: [PATCH 1/2] Using tensors in the explicit cache function calls from
 mllama implementation

---
 vllm/model_executor/models/mllama.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py
index 7f9b5c3cdb5ee..32f7c115cfe35 100644
--- a/vllm/model_executor/models/mllama.py
+++ b/vllm/model_executor/models/mllama.py
@@ -829,6 +829,7 @@ def _attention_with_mask(
     ) -> torch.Tensor:
         # Skip writing kv-cache for the initial profiling run.
         if len(kv_cache.shape) > 1:
+            i = torch.ones(dtype=torch.float32)
             if current_platform.is_rocm():
                 key_cache, value_cache = PagedAttention.split_kv_cache(
                     kv_cache, self.num_local_key_value_heads, self.head_dim)
@@ -836,7 +837,7 @@
                 cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode])
                 PagedAttention.write_to_paged_cache(
                     cached_k, cached_v, key_cache, value_cache,
-                    attn_metadata.cross_slot_mapping, "auto", 1.0, 1.0)
+                    attn_metadata.cross_slot_mapping, "auto", i, i)
             else:
                 if self.attn.backend in (_Backend.FLASH_ATTN,
                                          _Backend.FLASH_ATTN_VLLM_V1):
@@ -852,8 +853,8 @@
                         attn_metadata.
                         cross_slot_mapping,  # type: ignore[union-attr]
                         "auto",
-                        1.0,
-                        1.0,
+                        i,
+                        i,
                     )
                 elif self.attn.backend in (_Backend.XFORMERS,
                                            _Backend.TORCH_SDPA):
@@ -866,7 +867,7 @@
                         [v[s:e] for s, e in kv_range_for_decode])
                     PagedAttention.write_to_paged_cache(
                         cached_k, cached_v, key_cache, value_cache,
-                        attn_metadata.cross_slot_mapping, "auto", 1.0, 1.0)
+                        attn_metadata.cross_slot_mapping, "auto", i, i)
                 else:
                     raise ValueError(
                         f"Unsupported Attention backend {self.attn.backend} "

From 04e64246f9c6aed798334084ad5e07379fe20256 Mon Sep 17 00:00:00 2001
From: Gregory Shtrasberg
Date: Wed, 18 Dec 2024 16:40:56 +0000
Subject: [PATCH 2/2] Properly creating the tensor

---
 vllm/model_executor/models/mllama.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py
index 32f7c115cfe35..583f4be1b28b1 100644
--- a/vllm/model_executor/models/mllama.py
+++ b/vllm/model_executor/models/mllama.py
@@ -829,7 +829,7 @@ def _attention_with_mask(
     ) -> torch.Tensor:
         # Skip writing kv-cache for the initial profiling run.
         if len(kv_cache.shape) > 1:
-            i = torch.ones(dtype=torch.float32)
+            i = torch.ones(1, dtype=torch.float32)
             if current_platform.is_rocm():
                 key_cache, value_cache = PagedAttention.split_kv_cache(
                     kv_cache, self.num_local_key_value_heads, self.head_dim)
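
Note (illustrative, not part of either patch): the series replaces the Python float scale arguments (1.0, 1.0) in the explicit cache-write calls with a tensor, and the follow-up commit passes an explicit size to torch.ones so that a one-element float32 tensor is created. A minimal sketch of that construction, using hypothetical k_scale/v_scale names for the two scale arguments:

    import torch

    # One-element float32 tensor holding the scale value 1.0; the explicit
    # size argument (1) is what the second commit adds.
    i = torch.ones(1, dtype=torch.float32)

    # Hypothetical names: the same tensor is passed for both scale arguments
    # of the cache-write calls in the patch.
    k_scale = i
    v_scale = i
    print(k_scale, v_scale)  # tensor([1.]) tensor([1.])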