Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files (browse the repository at this point in the history)
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Dec 5, 2024
1 parent b1b51c3 commit 39f0375
Showing 1 changed file with 9 additions and 15 deletions.
24 changes: 9 additions & 15 deletions transformer_engine/jax/gemm.py
Original file line number Diff line number Diff line change
@@ -225,9 +225,9 @@ def _gemm_bwd_rule(
if dgrad_overlap_config["method"] == "bulk":
# Set DGRAD output buffer to the comm buffer of WGRAD GEMM in order to do the
# bulk RS overlap without an extra memcpy.
assert wgrad_overlap_config is not None, (
f"Missing comm+GEMM overlap config for {wgrad_overlap_name}!"
)
assert (
wgrad_overlap_config is not None
), f"Missing comm+GEMM overlap config for {wgrad_overlap_name}!"
dgrad_pre_rs = tex.get_overlap_buffer(wgrad_overlap_name, False)

# Copy transposed input into the DGRAD overlap buffer for bulk AG.
@@ -275,7 +275,6 @@ def _gemm_bwd_rule(
# Otherwise, if DGRAD overlap is RS overlap, DGRAD output is the extra output tensor
dgrad = dgrad_extra_out


# WGRAD w/o Overlap:
# AG+GEMM: ([B], M/P, K)^T --(AG)--> ([B], M, K)^T x ([B], M, N/P) --> (K, N/P)
#
@@ -663,13 +662,11 @@ def _fp8_gemm_bwd_rule(
dgrad_scale = None
if dgrad_overlap_config is not None:
if dgrad_overlap_config["method"] == "bulk":
assert wgrad_overlap_config is not None, (
f"Missing comm+GEMM overlap config for {wgrad_overlap_name}!"
)
assert (
wgrad_overlap_config is not None
), f"Missing comm+GEMM overlap config for {wgrad_overlap_name}!"
# Set WGRAD buffer as output of DGRAD in order to avoid a memcpy for bulk RS overlap
dgrad_pre_rs = jax.dlpack.from_dlpack(
tex.get_overlap_buffer(wgrad_overlap_name, False)
)
dgrad_pre_rs = jax.dlpack.from_dlpack(tex.get_overlap_buffer(wgrad_overlap_name, False))
# Copy input into overlap buffer for all-gather
copy_into_overlap_buffer(casted_x_t, dgrad_overlap_name, True)

@@ -710,15 +707,12 @@ def _fp8_gemm_bwd_rule(
if wgrad_overlap_config is not None:
if wgrad_overlap_config["method"] == "bulk":
# Get all-gathered input from DGRAD bulk overlap
casted_x_t = jax.dlpack.from_dlpack(
tex.get_overlap_buffer(dgrad_overlap_name, False)
)
casted_x_t = jax.dlpack.from_dlpack(tex.get_overlap_buffer(dgrad_overlap_name, False))

elif tex.overlap_buffer_is_fp8(wgrad_overlap_name):
# Set FP8 scale inverse for non-bulk AG overlap
tex.set_overlap_buffer_scale_inverse(
wgrad_overlap_name,
jax.dlpack.to_dlpack(x_scale_inv)
wgrad_overlap_name, jax.dlpack.to_dlpack(x_scale_inv)
)

# WGRAD: ([B], N, M) x ([B], K, M)^T = (N, K)
Expand Down

0 comments on commit 39f0375

Please sign in to comment.