
Commit

Vision model ROCm fix
gshtras committed Nov 18, 2024
1 parent 9540837 commit d889855
Showing 2 changed files with 32 additions and 22 deletions.
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/vocab_parallel_embedding.py
@@ -134,7 +134,7 @@ def __post_init__(self):
assert self.num_added_elements <= self.num_added_elements_padded


@torch.compile(dynamic=True)
#@torch.compile(dynamic=True)
def get_masked_input_and_mask(
input_: torch.Tensor, org_vocab_start_index: int,
org_vocab_end_index: int, num_org_vocab_padding: int,
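The only change in this file is commenting out the @torch.compile(dynamic=True) decorator on get_masked_input_and_mask, presumably because dynamic compilation of this helper misbehaved on ROCm. As a hedged sketch only (this is not what the commit does; the commit simply disables the decorator), the same effect could be made platform-conditional by reusing the current_platform helper that this commit imports in mllama.py:

import torch

from vllm.platforms import current_platform


def _compile_unless_rocm(fn):
    # Hypothetical helper, not part of this commit: apply torch.compile only
    # on platforms where dynamic compilation of this path is known to work,
    # and fall back to the plain Python function on ROCm.
    if current_platform.is_rocm():
        return fn
    return torch.compile(fn, dynamic=True)


# The decorator on get_masked_input_and_mask could then read:
# @_compile_unless_rocm
# def get_masked_input_and_mask(...):
#     ...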
52 changes: 31 additions & 21 deletions vllm/model_executor/models/mllama.py
@@ -32,8 +32,6 @@

import vllm.distributed.parallel_state as ps
from vllm.attention import Attention, AttentionMetadata, AttentionType
from vllm.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.attention.backends.xformers import XFormersMetadata
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
@@ -53,6 +51,7 @@
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.platforms import current_platform
from vllm.sequence import SequenceData
from vllm.utils import is_list_of
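
The module-level imports of FlashAttentionMetadata and XFormersMetadata are dropped in favor of current_platform; the backend classes are now imported lazily inside _attention_with_mask (next hunk), which presumably keeps ROCm builds from importing CUDA-only attention backends at module load time. A small illustrative sketch of that deferred-import pattern, using only the current_platform.is_rocm() check this commit relies on (the helper function itself is not part of the commit):

from vllm.platforms import current_platform


def resolve_metadata_classes():
    # Illustrative only: import backend-specific metadata classes lazily so
    # that a ROCm build never touches the flash-attn backend module.
    if current_platform.is_rocm():
        return ()
    from vllm.attention.backends.flash_attn import FlashAttentionMetadata
    from vllm.attention.backends.xformers import XFormersMetadata
    return (FlashAttentionMetadata, XFormersMetadata)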

@@ -829,21 +828,7 @@ def _attention_with_mask(
) -> torch.Tensor:
# Skip writing kv-cache for the initial profiling run.
if len(kv_cache.shape) > 1:
if isinstance(attn_metadata, FlashAttentionMetadata):
cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode])
cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode])
torch.ops._C_cache_ops.reshape_and_cache_flash(
cached_k,
cached_v,
kv_cache[0],
kv_cache[1],
attn_metadata.
cross_slot_mapping, # type: ignore[union-attr]
"auto",
1.0,
1.0,
)
elif isinstance(attn_metadata, XFormersMetadata):
if current_platform.is_rocm():
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_local_key_value_heads, self.head_dim)
cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode])
@@ -852,10 +837,35 @@
cached_k, cached_v, key_cache, value_cache,
attn_metadata.cross_slot_mapping, "auto", 1.0, 1.0)
else:
raise ValueError(
f"Unsupported AttentionMetadata {type(attn_metadata)} "
f"class found. Expected the AttentionMetadata to "
f"be either XFormersMetadata or FlashAttentionMetadata.")
from vllm.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.attention.backends.xformers import XFormersMetadata
if isinstance(attn_metadata, FlashAttentionMetadata):
cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode])
cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode])
torch.ops._C_cache_ops.reshape_and_cache_flash(
cached_k,
cached_v,
kv_cache[0],
kv_cache[1],
attn_metadata.
cross_slot_mapping, # type: ignore[union-attr]
"auto",
1.0,
1.0,
)
elif isinstance(attn_metadata, XFormersMetadata):
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_local_key_value_heads, self.head_dim)
cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode])
cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode])
PagedAttention.write_to_paged_cache(
cached_k, cached_v, key_cache, value_cache,
attn_metadata.cross_slot_mapping, "auto", 1.0, 1.0)
else:
raise ValueError(
f"Unsupported AttentionMetadata {type(attn_metadata)} "
f"class found. Expected the AttentionMetadata to "
f"be either XFormersMetadata or FlashAttentionMetadata.")

Check failure on line 868 in vllm/model_executor/models/mllama.py

View workflow job for this annotation

GitHub Actions / ruff (3.12)

Ruff (E501)

vllm/model_executor/models/mllama.py:868:81: E501 Line too long (81 > 80)

# We have to call torch.sdpa for prefill when using a
# custom cross-attention mask. Because the mask is not a
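For readability, here is a condensed, self-contained sketch of the kv-cache write logic after this change. The real code lives inside the model's _attention_with_mask; hoisting cached_k/cached_v above the branch is a simplification for this sketch (the diff keeps them inside each branch), and num_kv_heads/head_dim stand in for self.num_local_key_value_heads/self.head_dim.

import torch

from vllm.attention.ops.paged_attn import PagedAttention
from vllm.platforms import current_platform


def write_cross_attn_kv_cache(k, v, kv_cache, kv_range_for_decode,
                              attn_metadata, num_kv_heads, head_dim):
    # Condensed illustration of the branching this commit introduces;
    # variable names mirror the hunk above.
    cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode])
    cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode])

    if current_platform.is_rocm():
        # ROCm always takes the PagedAttention cache layout.
        key_cache, value_cache = PagedAttention.split_kv_cache(
            kv_cache, num_kv_heads, head_dim)
        PagedAttention.write_to_paged_cache(
            cached_k, cached_v, key_cache, value_cache,
            attn_metadata.cross_slot_mapping, "auto", 1.0, 1.0)
        return

    # Deferred imports: the flash-attn backend may not be importable on ROCm.
    from vllm.attention.backends.flash_attn import FlashAttentionMetadata
    from vllm.attention.backends.xformers import XFormersMetadata

    if isinstance(attn_metadata, FlashAttentionMetadata):
        torch.ops._C_cache_ops.reshape_and_cache_flash(
            cached_k, cached_v, kv_cache[0], kv_cache[1],
            attn_metadata.cross_slot_mapping, "auto", 1.0, 1.0)
    elif isinstance(attn_metadata, XFormersMetadata):
        key_cache, value_cache = PagedAttention.split_kv_cache(
            kv_cache, num_kv_heads, head_dim)
        PagedAttention.write_to_paged_cache(
            cached_k, cached_v, key_cache, value_cache,
            attn_metadata.cross_slot_mapping, "auto", 1.0, 1.0)
    else:
        raise ValueError(
            f"Unsupported AttentionMetadata {type(attn_metadata)} class "
            f"found. Expected XFormersMetadata or FlashAttentionMetadata.")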
