
Commit

black
ruomingp committed Dec 17, 2024
1 parent 80e87b2 commit a0edff6
Showing 2 changed files with 12 additions and 4 deletions.
8 changes: 6 additions & 2 deletions axlearn/common/flash_attention/gpu_attention_test.py
@@ -161,7 +161,9 @@ def test_triton_against_xla_ref(
             block_q=block_size,
             block_k=block_size,
         )
-        jax_ref_out = mha_reference(q, k, v, bias, segment_ids, causal=causal, softmax_scale=softmax_scale)
+        jax_ref_out = mha_reference(
+            q, k, v, bias, segment_ids, causal=causal, softmax_scale=softmax_scale
+        )
         if input_dtype == jnp.float16:
             chex.assert_trees_all_close(jax_out, jax_ref_out, atol=0.005)
         elif input_dtype == jnp.float32:
@@ -227,7 +229,9 @@ def test_cudnn_against_triton_ref(
         softmax_scale = q.shape[-1] ** -0.5

         # Compare outputs.
-        jax_out = cudnn_dot_product_attention(q, k, v, bias=None, causal=causal, softmax_scale=softmax_scale)
+        jax_out = cudnn_dot_product_attention(
+            q, k, v, bias=None, causal=causal, softmax_scale=softmax_scale
+        )
         jax_ref_out = flash_attention(q, k, v, bias=None, causal=causal, softmax_scale=softmax_scale)
         if dtype == jnp.bfloat16:
             # We relax the atol to support bf16 in the unit test.
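Both hunks in this test file only re-wrap long call lines to satisfy the formatter; the test logic is unchanged: run an accelerated attention implementation, run a reference implementation, and compare the outputs with a dtype-dependent tolerance. Below is a minimal, self-contained sketch of that comparison pattern; the reference_attention helper, the tensor shapes, and the tolerance are illustrative assumptions, not code from the repository.

import functools

import chex
import jax
import jax.numpy as jnp


def reference_attention(q, k, v, causal=False, softmax_scale=1.0):
    # q, k, v: [batch, seq_len, num_heads, head_dim].
    logits = jnp.einsum("bqhd,bkhd->bhqk", q, k) * softmax_scale
    if causal:
        q_len, k_len = logits.shape[-2], logits.shape[-1]
        mask = jnp.tril(jnp.ones((q_len, k_len), dtype=bool))
        logits = jnp.where(mask, logits, jnp.finfo(logits.dtype).min)
    probs = jax.nn.softmax(logits, axis=-1)
    return jnp.einsum("bhqk,bkhd->bqhd", probs, v)


keys = jax.random.split(jax.random.PRNGKey(0), 3)
q, k, v = (jax.random.normal(kk, (2, 128, 4, 64), dtype=jnp.float16) for kk in keys)

attn = functools.partial(reference_attention, causal=True, softmax_scale=q.shape[-1] ** -0.5)
out_ref = attn(q, k, v)
out_jit = jax.jit(attn)(q, k, v)
# fp16 comparisons in the diff above use a looser atol; same idea here.
chex.assert_trees_all_close(out_ref, out_jit, atol=0.005)
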
8 changes: 6 additions & 2 deletions axlearn/common/flash_attention/tpu_attention.py
@@ -396,7 +396,9 @@ def pallas_tpu_flash_attention(
         block_sizes = LegacyBlockSizes.get_default(
             batch_size, num_heads, q_seq_len, kv_seq_len, d_model
         )
-    return _flash_attention(q, k, v, ab, segment_ids, False, causal, softmax_scale, block_sizes, debug)
+    return _flash_attention(
+        q, k, v, ab, segment_ids, False, causal, softmax_scale, block_sizes, debug
+    )


 @functools.partial(jax.custom_vjp, nondiff_argnums=range(5, 10))
@@ -443,7 +445,9 @@ def _flash_attention_fwd(
 ):
     if save_residuals:
         raise NotImplementedError("Higher-order AD not supported")
-    o, l, m = _flash_attention(q, k, v, ab, segment_ids, True, causal, softmax_scale, block_sizes, debug)
+    o, l, m = _flash_attention(
+        q, k, v, ab, segment_ids, True, causal, softmax_scale, block_sizes, debug
+    )
     return o, (q, k, v, ab, segment_ids, o, l, m)

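These hunks are likewise formatting-only. The surrounding code follows the standard jax.custom_vjp recipe: the forward rule returns the primal output together with a tuple of residuals (here (q, k, v, ab, segment_ids, o, l, m)) that the backward rule consumes, and nondiff_argnums marks the static arguments. A toy sketch of that recipe, with a made-up scaled_square function standing in for the flash-attention kernel:

import functools

import jax
import jax.numpy as jnp


@functools.partial(jax.custom_vjp, nondiff_argnums=(2,))
def scaled_square(x, w, scale):
    return scale * (x * w) ** 2


def scaled_square_fwd(x, w, scale):
    out = scaled_square(x, w, scale)
    # Residuals: everything the backward pass needs.
    return out, (x, w)


def scaled_square_bwd(scale, residuals, g):
    # Non-differentiable args come first, then residuals, then the cotangent.
    x, w = residuals
    common = 2.0 * scale * (x * w) * g
    return common * w, common * x  # Gradients w.r.t. x and w only.


scaled_square.defvjp(scaled_square_fwd, scaled_square_bwd)

x = jnp.arange(4.0)
w = jnp.ones(4) * 0.5
print(jax.grad(lambda x: scaled_square(x, w, 2.0).sum())(x))
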
