[MoE][PyTorch] Add mask-based MoE permutation #1373
Signed-off-by: Hongxiao Bai <[email protected]>
]


class _moe_permute(torch.autograd.Function):
    """functional Permute"""
class _moe_permute_indice_map(torch.autograd.Function):
- class _moe_permute_indice_map(torch.autograd.Function):
+ class _moe_permute_index_map(torch.autograd.Function):
We should make sure to use "index" in user-facing APIs like moe_permute/moe_unpermute.
OK, modified.
import warnings
from typing import Tuple
import torch

import transformer_engine_torch as tex
from .constants import TE_DType
from .float8_tensor import Float8Tensor
import transformer_engine.pytorch.triton.permutation as triton_permuataion
Nit:
- import transformer_engine.pytorch.triton.permutation as triton_permuataion
+ import transformer_engine.pytorch.triton.permutation as triton_permutation
Fixed.
if ctx.fp8:
    assert isinstance(
        permuted_act_grad, Float8Tensor
    ), "Grad of the output must be in Float8Tensor type for FP8 moe_permute."
Couldn't we decouple FP8 in the forward and backward?
- if ctx.fp8:
-     assert isinstance(
-         permuted_act_grad, Float8Tensor
-     ), "Grad of the output must be in Float8Tensor type for FP8 moe_permute."
+ fp8 = isinstance(permuted_act_grad, Float8Tensor)
+ if fp8:
If there are no obstacles, we could also do the same thing for _moe_unpermute_mask_map and _moe_chunk_sort.
Modified. Now for bwd, it would follow the dtype of the grad tensor.
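The difference between the two dispatch styles discussed above can be sketched in plain Python. This is only an illustration: `StandInFloat8Tensor`, `backward_path_coupled`, and `backward_path_decoupled` are hypothetical names made up for this sketch, and the stand-in class is not the real Float8Tensor.

```python
class StandInFloat8Tensor:
    """Hypothetical placeholder for an FP8 tensor wrapper (not the real Float8Tensor)."""

    def __init__(self, data):
        self.data = data


def backward_path_coupled(forward_was_fp8, grad):
    # Coupled style: the backward path is dictated by a flag saved in forward,
    # so an FP8 forward pass requires an FP8 gradient.
    if forward_was_fp8:
        assert isinstance(grad, StandInFloat8Tensor), "grad must be FP8 when forward was FP8"
        return "fp8"
    return "high_precision"


def backward_path_decoupled(grad):
    # Decoupled style: the backward path follows the type of the incoming gradient,
    # regardless of what the forward pass used.
    return "fp8" if isinstance(grad, StandInFloat8Tensor) else "high_precision"


plain_grad = [1.0, 2.0]
fp8_grad = StandInFloat8Tensor([1.0, 2.0])

print(backward_path_decoupled(plain_grad))  # high_precision
print(backward_path_decoupled(fp8_grad))    # fp8
```

With the decoupled style, a high-precision gradient after an FP8 forward pass is simply handled on the high-precision path instead of tripping the assertion.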
# Results Check
#
###################################################################################################################################
tols = dtype_tols(te_dtype)
Shouldn't we expect bit-wise exact results?
- tols = dtype_tols(te_dtype)
+ tols = { "atol": 0, "rtol": 0 }
Hi, Tim. I made some modifications here; it now uses two types of tols.
We cannot use bit-wise matching for all cases. First, for the fp8 case of the fusion, the function in the PyTorch version uses fp32. Second, there are reductions in the unpermutation kernels, so we cannot get bit-wise matching results for permute bwd, unpermute fwd, and unpermute bwd with probs.
For the other cases, I changed the check to bit-wise matching. Is this OK with you?
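A small pure-Python sketch of why the two kinds of tolerance are needed: moving rows around is a copy and round-trips bit-exactly, while a floating-point sum can change with the order of accumulation. The helper names here are hypothetical and the example uses Python floats rather than the actual kernels.

```python
def permute_rows(rows, order):
    # Pure data movement: output row i is a copy of input row order[i].
    return [rows[i] for i in order]


def sum_in_order(values, order):
    # Accumulate the values in the given order, as a reduction would.
    total = 0.0
    for i in order:
        total += values[i]
    return total


rows = [0.1, 0.2, 0.3]
order = [2, 0, 1]

# Permute, then scatter back: every value is copied, never recomputed.
permuted = permute_rows(rows, order)
restored = [None] * len(rows)
for dst, src in enumerate(order):
    restored[src] = permuted[dst]
print(restored == rows)  # True

# The same three values summed in two different orders.
forward_sum = sum_in_order(rows, [0, 1, 2])  # (0.1 + 0.2) + 0.3
reverse_sum = sum_in_order(rows, [2, 1, 0])  # (0.3 + 0.2) + 0.1
print(forward_sum == reverse_sum)  # False
```

A zero tolerance is therefore reasonable for the copy-only paths, while paths that reduce over several values need a small dtype-dependent tolerance.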
tests/pytorch/test_permutation.py
Outdated
# Results Check
#
###################################################################################################################################
tols = dtype_tols(te_dtype)
We should expect bit-wise exact results.
- tols = dtype_tols(te_dtype)
+ tols = { "atol": 0, "rtol": 0 }
Like the one above, I changed to bit-wise matching except for fp8.
    mask=(offset < num_tokens),
    other=0,
).to(tl.int64)
expert_token_cumsum = tl.cumsum(expert_token_mask) * expert_token_mask
An interesting way to exclude the zero token_mask. Happy to learn!
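The trick in that line can be reproduced in plain Python as a rough sketch (the function name `masked_cumsum` is made up for illustration; the kernel itself uses `tl.cumsum`). Multiplying the running count by the mask keeps a 1-based position for selected tokens and forces unselected tokens back to 0.

```python
from itertools import accumulate


def masked_cumsum(mask):
    # Running count of selected tokens, zeroed wherever the mask is 0.
    running = list(accumulate(mask))
    return [count * m for count, m in zip(running, mask)]


expert_token_mask = [1, 0, 1, 1, 0]
print(masked_cumsum(expert_token_mask))  # [1, 0, 2, 3, 0]
```

Because selected tokens start counting at 1, the value 0 is free to mean "this token is not routed to this expert".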
chunk_cumsum = tl.load(
    row_id_map_ptr + pid_m * num_tokens + offset, mask=(offset < num_tokens), other=0
)

workspace_off = tl.arange(0, WORKSPACE_LOAD_WIDTH)
chunk_sums = tl.load(workspace_ptr + workspace_off, mask=workspace_off < chunk_idx)
chunk_cumsum = tl.where(chunk_cumsum == 0, -1, chunk_cumsum + tl.sum(chunk_sums) - 1)
These three names, chunk_cumsum, chunk_sums, and chunk_cumsum, are quite confusing.
If I understand it correctly, I suggest renaming them to:
chunk_cumsum -> row_id_within_token_block
chunk_sums -> n_tokens_per_expert
chunk_cumsum -> row_id
In addition, I think we should move the -1 to pass1, as it is the correction for the calculation of expert_token_cumsum:
expert_token_cumsum = (tl.cumsum(expert_token_mask) - 1) * expert_token_mask
Thanks. You are right. I modified these names (renamed chunk_sums to n_tokens_per_block rather than n_tokens_per_expert).
For the -1, if we move it to pass1, then we cannot easily distinguish row_id 0 from mask 0, and we would need an extra way to tell whether a token is masked out. So I left the -1 in pass2. Do you think it is OK?
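The ambiguity described in this reply can be seen in a short plain-Python sketch. The function names `pass1` and `pass2`, the `shift` parameter, and the `tokens_before` argument are assumptions made for illustration and only mimic the structure of the two kernel passes.

```python
from itertools import accumulate


def pass1(mask, shift):
    # shift=0 mirrors cumsum(mask) * mask; shift=1 mirrors (cumsum(mask) - 1) * mask.
    return [(count - shift) * m for count, m in zip(accumulate(mask), mask)]


def pass2(block_values, tokens_before):
    # 0 marks a masked-out token; otherwise add the number of tokens in earlier
    # blocks and apply the -1 to convert the 1-based position to a 0-based row id.
    return [-1 if v == 0 else v + tokens_before - 1 for v in block_values]


mask = [1, 0, 1]

one_based = pass1(mask, shift=0)
zero_based = pass1(mask, shift=1)
print(one_based)   # [1, 0, 2]
print(zero_based)  # [0, 0, 1]  (first token and masked token both read as 0)

print(pass2(one_based, tokens_before=4))  # [4, -1, 5]
```

Keeping the values 1-based after the first pass leaves 0 as an unambiguous "masked out" marker, which is what lets the second pass map it to -1.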
Description
Add mask-based token permutation and local chunk permutation fused kernels. These kernels are implemented with OpenAI Triton.
Related commit in Megatron-LM NVIDIA/Megatron-LM@ac0474d
Changes
Please list the changes introduced in this PR:
te.pytorch.permutation.moe_permute and te.pytorch.permutation.moe_unpermute
te.pytorch.permutation.moe_sort_chunks_by_indices