From 39f0375a4fd4f9ec4abcbc11a5b51f2301739d22 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 5 Dec 2024 19:54:48 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/jax/gemm.py | 24 +++++++++--------------- 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/transformer_engine/jax/gemm.py b/transformer_engine/jax/gemm.py index 1a275ceed7..7024dcb9fe 100644 --- a/transformer_engine/jax/gemm.py +++ b/transformer_engine/jax/gemm.py @@ -225,9 +225,9 @@ def _gemm_bwd_rule( if dgrad_overlap_config["method"] == "bulk": # Set DGRAD output buffer to the comm buffer of WGRAD GEMM in order to do the # bulk RS overlap without an extra memcpy. - assert wgrad_overlap_config is not None, ( - f"Missing comm+GEMM overlap config for {wgrad_overlap_name}!" - ) + assert ( + wgrad_overlap_config is not None + ), f"Missing comm+GEMM overlap config for {wgrad_overlap_name}!" dgrad_pre_rs = tex.get_overlap_buffer(wgrad_overlap_name, False) # Copy transposed input into the DGRAD overlap buffer for bulk AG. @@ -275,7 +275,6 @@ def _gemm_bwd_rule( # Otherwise, if DGRAD overlap is RS overlap, DGRAD output is the extra output tensor dgrad = dgrad_extra_out - # WGRAD w/o Overlap: # AG+GEMM: ([B], M/P, K)^T --(AG)--> ([B], M, K)^T x ([B], M, N/P) --> (K, N/P) # @@ -663,13 +662,11 @@ def _fp8_gemm_bwd_rule( dgrad_scale = None if dgrad_overlap_config is not None: if dgrad_overlap_config["method"] == "bulk": - assert wgrad_overlap_config is not None, ( - f"Missing comm+GEMM overlap config for {wgrad_overlap_name}!" - ) + assert ( + wgrad_overlap_config is not None + ), f"Missing comm+GEMM overlap config for {wgrad_overlap_name}!" # Set WGRAD buffer as output of DGRAD in order to avoid a memcpy for bulk RS overlap - dgrad_pre_rs = jax.dlpack.from_dlpack( - tex.get_overlap_buffer(wgrad_overlap_name, False) - ) + dgrad_pre_rs = jax.dlpack.from_dlpack(tex.get_overlap_buffer(wgrad_overlap_name, False)) # Copy input into overlap buffer for all-gather copy_into_overlap_buffer(casted_x_t, dgrad_overlap_name, True) @@ -710,15 +707,12 @@ def _fp8_gemm_bwd_rule( if wgrad_overlap_config is not None: if wgrad_overlap_config["method"] == "bulk": # Get all-gathered input from DGRAD bulk overlap - casted_x_t = jax.dlpack.from_dlpack( - tex.get_overlap_buffer(dgrad_overlap_name, False) - ) + casted_x_t = jax.dlpack.from_dlpack(tex.get_overlap_buffer(dgrad_overlap_name, False)) elif tex.overlap_buffer_is_fp8(wgrad_overlap_name): # Set FP8 scale inverse for non-bulk AG overlap tex.set_overlap_buffer_scale_inverse( - wgrad_overlap_name, - jax.dlpack.to_dlpack(x_scale_inv) + wgrad_overlap_name, jax.dlpack.to_dlpack(x_scale_inv) ) # WGRAD: ([B], N, M) x ([B], K, M)^T = (N, K)