diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py
index 7f9b5c3cdb5ee..583f4be1b28b1 100644
--- a/vllm/model_executor/models/mllama.py
+++ b/vllm/model_executor/models/mllama.py
@@ -829,6 +829,7 @@ def _attention_with_mask(
     ) -> torch.Tensor:
         # Skip writing kv-cache for the initial profiling run.
         if len(kv_cache.shape) > 1:
+            i = torch.ones(1, dtype=torch.float32)
             if current_platform.is_rocm():
                 key_cache, value_cache = PagedAttention.split_kv_cache(
                     kv_cache, self.num_local_key_value_heads, self.head_dim)
@@ -836,7 +837,7 @@ def _attention_with_mask(
                 cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode])
                 PagedAttention.write_to_paged_cache(
                     cached_k, cached_v, key_cache, value_cache,
-                    attn_metadata.cross_slot_mapping, "auto", 1.0, 1.0)
+                    attn_metadata.cross_slot_mapping, "auto", i, i)
             else:
                 if self.attn.backend in (_Backend.FLASH_ATTN,
                                          _Backend.FLASH_ATTN_VLLM_V1):
@@ -852,8 +853,8 @@ def _attention_with_mask(
                         attn_metadata.
                         cross_slot_mapping,  # type: ignore[union-attr]
                         "auto",
-                        1.0,
-                        1.0,
+                        i,
+                        i,
                     )
                 elif self.attn.backend in (_Backend.XFORMERS,
                                            _Backend.TORCH_SDPA):
@@ -866,7 +867,7 @@ def _attention_with_mask(
                         [v[s:e] for s, e in kv_range_for_decode])
                     PagedAttention.write_to_paged_cache(
                         cached_k, cached_v, key_cache, value_cache,
-                        attn_metadata.cross_slot_mapping, "auto", 1.0, 1.0)
+                        attn_metadata.cross_slot_mapping, "auto", i, i)
                 else:
                     raise ValueError(
                         f"Unsupported Attention backend {self.attn.backend} "
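
The hunks above pass the k_scale/v_scale arguments of the cache-write calls as a one-element float32 tensor instead of the Python floats 1.0, 1.0. The snippet below is a minimal sketch of that calling convention using plain PyTorch only; write_kv is a hypothetical stand-in for the real cache-write kernels, not vLLM code.

import torch


def write_kv(key: torch.Tensor, value: torch.Tensor,
             k_scale: torch.Tensor, v_scale: torch.Tensor) -> None:
    # Hypothetical stand-in for write_to_paged_cache / reshape_and_cache_flash:
    # the scales arrive as tensors rather than Python floats.
    assert k_scale.dtype == torch.float32 and v_scale.dtype == torch.float32
    _ = key * k_scale, value * v_scale  # placeholder for the real cache write


i = torch.ones(1, dtype=torch.float32)  # unit scale as a tensor, as in the diff
write_kv(torch.randn(4, 8), torch.randn(4, 8), i, i)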