From a0edff63c7c4ac177596024b518d11deb0e70ef5 Mon Sep 17 00:00:00 2001
From: Ruoming Pang
Date: Tue, 17 Dec 2024 14:25:53 +0000
Subject: [PATCH] black

---
 axlearn/common/flash_attention/gpu_attention_test.py | 8 ++++++--
 axlearn/common/flash_attention/tpu_attention.py      | 8 ++++++--
 2 files changed, 12 insertions(+), 4 deletions(-)

diff --git a/axlearn/common/flash_attention/gpu_attention_test.py b/axlearn/common/flash_attention/gpu_attention_test.py
index a031f795..901f9bf5 100644
--- a/axlearn/common/flash_attention/gpu_attention_test.py
+++ b/axlearn/common/flash_attention/gpu_attention_test.py
@@ -161,7 +161,9 @@ def test_triton_against_xla_ref(
         block_q=block_size,
         block_k=block_size,
     )
-    jax_ref_out = mha_reference(q, k, v, bias, segment_ids, causal=causal, softmax_scale=softmax_scale)
+    jax_ref_out = mha_reference(
+        q, k, v, bias, segment_ids, causal=causal, softmax_scale=softmax_scale
+    )
     if input_dtype == jnp.float16:
         chex.assert_trees_all_close(jax_out, jax_ref_out, atol=0.005)
     elif input_dtype == jnp.float32:
@@ -227,7 +229,9 @@ def test_cudnn_against_triton_ref(
     softmax_scale = q.shape[-1] ** -0.5
 
     # Compare outputs.
-    jax_out = cudnn_dot_product_attention(q, k, v, bias=None, causal=causal, softmax_scale=softmax_scale)
+    jax_out = cudnn_dot_product_attention(
+        q, k, v, bias=None, causal=causal, softmax_scale=softmax_scale
+    )
     jax_ref_out = flash_attention(q, k, v, bias=None, causal=causal, softmax_scale=softmax_scale)
     if dtype == jnp.bfloat16:
         # We relax the atol to support bf16 in the unit test.
diff --git a/axlearn/common/flash_attention/tpu_attention.py b/axlearn/common/flash_attention/tpu_attention.py
index 52c7d200..7b44266b 100644
--- a/axlearn/common/flash_attention/tpu_attention.py
+++ b/axlearn/common/flash_attention/tpu_attention.py
@@ -396,7 +396,9 @@ def pallas_tpu_flash_attention(
         block_sizes = LegacyBlockSizes.get_default(
             batch_size, num_heads, q_seq_len, kv_seq_len, d_model
         )
-    return _flash_attention(q, k, v, ab, segment_ids, False, causal, softmax_scale, block_sizes, debug)
+    return _flash_attention(
+        q, k, v, ab, segment_ids, False, causal, softmax_scale, block_sizes, debug
+    )
 
 
 @functools.partial(jax.custom_vjp, nondiff_argnums=range(5, 10))
@@ -443,7 +445,9 @@ def _flash_attention_fwd(
 ):
     if save_residuals:
         raise NotImplementedError("Higher-order AD not supported")
-    o, l, m = _flash_attention(q, k, v, ab, segment_ids, True, causal, softmax_scale, block_sizes, debug)
+    o, l, m = _flash_attention(
+        q, k, v, ab, segment_ids, True, causal, softmax_scale, block_sizes, debug
+    )
     return o, (q, k, v, ab, segment_ids, o, l, m)
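
Note (not part of the patch): a minimal sketch of how black arrives at the wrapped call form shown in the "+" lines above, assuming the repository's configured line length is 100 (check pyproject.toml); the snippet below uses placeholder names in a dummy function purely so black has valid syntax to format.

    import black

    # Pre-patch form: the call sits on one long line inside a test function body.
    src = (
        "def test_triton_against_xla_ref():\n"
        "    jax_ref_out = mha_reference("
        "q, k, v, bias, segment_ids, causal=causal, softmax_scale=softmax_scale)\n"
    )

    # With an assumed limit of 100 characters, the indented call exceeds the limit,
    # so black splits it after the opening parenthesis onto a single continuation
    # line, the same shape as the added lines in the hunks above.
    print(black.format_str(src, mode=black.Mode(line_length=100)))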