
Commit

fix: use basic attention kernel for consumer GPU compatibility
reference: rhymes-ai#17
ai-anchorite committed Oct 25, 2024
1 parent 90ee2a9 commit 56e4236
Showing 1 changed file with 16 additions and 11 deletions.
27 changes: 16 additions & 11 deletions allegro/models/transformers/block.py
@@ -818,18 +818,23 @@ def __call__(

 # the output of sdp = (batch, num_heads, seq_len, head_dim)
 # TODO: add support for attn.scale when we move to Torch 2.1
-if self.attention_mode == 'flash':
-    # assert attention_mask is None, 'flash-attn do not support attention_mask'
-    with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, dropout_p=0.0, is_causal=False
-        )
-elif self.attention_mode == 'xformers':
-    with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-        )
+# if self.attention_mode == 'flash':
+#     # assert attention_mask is None, 'flash-attn do not support attention_mask'
+#     with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
+#         hidden_states = F.scaled_dot_product_attention(
+#             query, key, value, dropout_p=0.0, is_causal=False
+#         )
+# elif self.attention_mode == 'xformers':
+#     with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
+#         hidden_states = F.scaled_dot_product_attention(
+#             query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+#         )
 
+# Use basic attention implementation
+with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=True):
+    hidden_states = F.scaled_dot_product_attention(
+        query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+    )
 
 hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn_heads * head_dim)
 hidden_states = hidden_states.to(query.dtype)
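
For reference: the removed branches used the newer torch.nn.attention.sdpa_kernel context manager, while the replacement falls back to torch.backends.cuda.sdp_kernel, which is deprecated in recent PyTorch releases in favor of that newer API. Below is a minimal sketch (not part of this commit, assuming PyTorch 2.3+) of the same backend restriction written with the newer context manager:

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Toy tensors in the (batch, num_heads, seq_len, head_dim) layout used above.
query = torch.randn(1, 8, 128, 64)
key = torch.randn_like(query)
value = torch.randn_like(query)

# Allow only the math and memory-efficient backends, mirroring
# enable_flash=False, enable_math=True, enable_mem_efficient=True.
with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):
    hidden_states = F.scaled_dot_product_attention(
        query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
    )

Either form should steer scaled_dot_product_attention away from the flash-attention kernel, which is the point of the fix for consumer GPUs that lack flash-attention support.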
