diff --git a/allegro/models/transformers/block.py b/allegro/models/transformers/block.py
index 190f910..38a8265 100644
--- a/allegro/models/transformers/block.py
+++ b/allegro/models/transformers/block.py
@@ -818,18 +818,23 @@ def __call__(
         # the output of sdp = (batch, num_heads, seq_len, head_dim)
         # TODO: add support for attn.scale when we move to Torch 2.1
-        if self.attention_mode == 'flash':
-            # assert attention_mask is None, 'flash-attn do not support attention_mask'
-            with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
-                hidden_states = F.scaled_dot_product_attention(
-                    query, key, value, dropout_p=0.0, is_causal=False
-                )
-        elif self.attention_mode == 'xformers':
-            with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
-                hidden_states = F.scaled_dot_product_attention(
-                    query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-                )
+        # if self.attention_mode == 'flash':
+        #     # assert attention_mask is None, 'flash-attn do not support attention_mask'
+        #     with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
+        #         hidden_states = F.scaled_dot_product_attention(
+        #             query, key, value, dropout_p=0.0, is_causal=False
+        #         )
+        # elif self.attention_mode == 'xformers':
+        #     with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
+        #         hidden_states = F.scaled_dot_product_attention(
+        #             query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        #         )
+        # Use basic attention implementation
+        with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=True):
+            hidden_states = F.scaled_dot_product_attention(
+                query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+            )
 
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn_heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
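
For reference, the replacement branch uses torch.backends.cuda.sdp_kernel, which recent PyTorch releases deprecate in favor of the torch.nn.attention.sdpa_kernel context manager that the commented-out branches already rely on. The following is a minimal standalone sketch of the same backend restriction, assuming PyTorch 2.3+ (where sdpa_kernel accepts a list of backends); the tensor shapes below are illustrative placeholders, not values taken from the model.

import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

# Illustrative shapes only: (batch, num_heads, seq_len, head_dim)
query = torch.randn(2, 8, 128, 64)
key = torch.randn(2, 8, 128, 64)
value = torch.randn(2, 8, 128, 64)
attention_mask = None  # or a mask broadcastable to (batch, num_heads, seq_len, seq_len)

# Allow only the math and memory-efficient backends, mirroring
# enable_flash=False, enable_math=True, enable_mem_efficient=True in the patch.
with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):
    hidden_states = F.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
    )

# Output keeps the (batch, num_heads, seq_len, head_dim) layout noted in the patched code.
print(hidden_states.shape)  # torch.Size([2, 8, 128, 64])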