Mllama kv scale fix (#335)
* Using tensors for the kv scales in the explicit cache function calls in the mllama implementation (a sketch of the pattern follows the diff below)

* Properly creating the tensor
gshtras authored Dec 18, 2024
1 parent 27f53a2 commit fa1ff83
Showing 1 changed file with 5 additions and 4 deletions.
vllm/model_executor/models/mllama.py (5 additions & 4 deletions):

@@ -829,14 +829,15 @@ def _attention_with_mask(
     ) -> torch.Tensor:
         # Skip writing kv-cache for the initial profiling run.
         if len(kv_cache.shape) > 1:
+            i = torch.ones(1, dtype=torch.float32)
             if current_platform.is_rocm():
                 key_cache, value_cache = PagedAttention.split_kv_cache(
                     kv_cache, self.num_local_key_value_heads, self.head_dim)
                 cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode])
                 cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode])
                 PagedAttention.write_to_paged_cache(
                     cached_k, cached_v, key_cache, value_cache,
-                    attn_metadata.cross_slot_mapping, "auto", 1.0, 1.0)
+                    attn_metadata.cross_slot_mapping, "auto", i, i)
             else:
                 if self.attn.backend in (_Backend.FLASH_ATTN,
                                          _Backend.FLASH_ATTN_VLLM_V1):
@@ -852,8 +853,8 @@ def _attention_with_mask(
                         attn_metadata.
                         cross_slot_mapping,  # type: ignore[union-attr]
                         "auto",
-                        1.0,
-                        1.0,
+                        i,
+                        i,
                     )
                 elif self.attn.backend in (_Backend.XFORMERS,
                                            _Backend.TORCH_SDPA):
@@ -866,7 +867,7 @@ def _attention_with_mask(
                         [v[s:e] for s, e in kv_range_for_decode])
                     PagedAttention.write_to_paged_cache(
                         cached_k, cached_v, key_cache, value_cache,
-                        attn_metadata.cross_slot_mapping, "auto", 1.0, 1.0)
+                        attn_metadata.cross_slot_mapping, "auto", i, i)
                 else:
                     raise ValueError(
                         f"Unsupported Attention backend {self.attn.backend} "
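The shape of the fix, as a minimal self-contained sketch (the stub function below is hypothetical and only illustrates the calling convention, not the vLLM API): a one-element float32 tensor is created once and passed as the kv-cache scale arguments instead of the Python floats 1.0, 1.0.

import torch

def write_to_cache_stub(key, value, slot_mapping, kv_cache_dtype, k_scale, v_scale):
    # Hypothetical stand-in for a cache-write kernel binding that expects
    # the kv scales as tensors rather than Python floats.
    assert isinstance(k_scale, torch.Tensor), "k_scale must be a tensor"
    assert isinstance(v_scale, torch.Tensor), "v_scale must be a tensor"
    # (kernel work elided)

key = torch.randn(4, 8)
value = torch.randn(4, 8)
slots = torch.arange(4)

# Before the fix: plain floats were passed for the scales.
#   write_to_cache_stub(key, value, slots, "auto", 1.0, 1.0)  # fails the tensor check
# After the fix: a one-element tensor is created once and reused for both scales.
i = torch.ones(1, dtype=torch.float32)
write_to_cache_stub(key, value, slots, "auto", i, i)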
