[MoE][PyTorch] Add mask-based MoE permutation #1373

Open
wants to merge 9 commits into
base: main

@hxbai hxbai commented Dec 13, 2024

Description

Add mask-based token permutation and local chunk permutation fused kernels. These kernels are implemented with OpenAI Triton.

Related commit in Megatron-LM: NVIDIA/Megatron-LM@ac0474d

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Non-breaking API changes to te.pytorch.permutation.moe_permute and te.pytorch.permutation.moe_unpermute
  • Add a new API, te.pytorch.permutation.moe_sort_chunks_by_indices
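
To make the idea of mask-based permutation concrete, here is a minimal plain-Python sketch. It is not the Triton kernel from this PR: the function name `permute_by_mask`, the 0/1 mask layout `[num_tokens][num_experts]`, and the `[num_experts][num_tokens]` layout of `row_id_map` with `-1` for unrouted tokens are all assumptions chosen for illustration.

```python
from typing import List, Sequence, Tuple


def permute_by_mask(
    tokens: Sequence[str], routing_mask: Sequence[Sequence[int]]
) -> Tuple[List[str], List[List[int]]]:
    """Group tokens by expert using a [num_tokens][num_experts] 0/1 mask.

    Returns the permuted tokens and a row_id_map where row_id_map[e][t] is the
    destination row of token t for expert e, or -1 if t is not routed to e.
    (Hypothetical reference helper, not part of Transformer Engine.)
    """
    num_tokens = len(tokens)
    num_experts = len(routing_mask[0]) if num_tokens else 0
    permuted: List[str] = []
    row_id_map = [[-1] * num_tokens for _ in range(num_experts)]
    for e in range(num_experts):      # output is expert-major
        for t in range(num_tokens):   # token order is preserved within an expert
            if routing_mask[t][e]:
                row_id_map[e][t] = len(permuted)
                permuted.append(tokens[t])
    return permuted, row_id_map


tokens = ["t0", "t1", "t2"]
mask = [[1, 0], [0, 1], [1, 1]]  # t2 is routed to both experts
permuted, row_id_map = permute_by_mask(tokens, mask)
print(permuted)    # ['t0', 't2', 't1', 't2']
print(row_id_map)  # [[0, -1, 1], [-1, 2, 3]]
```

A token routed to several experts appears once per expert in the output, which is why the unpermute direction has to reduce over those copies.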

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@hxbai hxbai changed the title [MoE][Common/PyTorch] Add mask-based MoE permutation [MoE][PyTorch] Add mask-based MoE permutation Dec 13, 2024
@phu0ngng phu0ngng self-requested a review January 8, 2025 15:20
]


class _moe_permute(torch.autograd.Function):
"""functional Permute"""
class _moe_permute_indice_map(torch.autograd.Function):
Collaborator

Suggested change
- class _moe_permute_indice_map(torch.autograd.Function):
+ class _moe_permute_index_map(torch.autograd.Function):

We should make sure to use "index" in user-facing APIs like moe_permute/moe_unpermute.

Author

OK, modified.

import warnings
from typing import Tuple
import torch

import transformer_engine_torch as tex
from .constants import TE_DType
from .float8_tensor import Float8Tensor
import transformer_engine.pytorch.triton.permutation as triton_permuataion
Collaborator

Nit:

Suggested change
- import transformer_engine.pytorch.triton.permutation as triton_permuataion
+ import transformer_engine.pytorch.triton.permutation as triton_permutation

Author

Fixed.

Comment on lines 292 to 295
if ctx.fp8:
    assert isinstance(
        permuted_act_grad, Float8Tensor
    ), "Grad of the output must be in Float8Tensor type for FP8 moe_permute."
Collaborator

Couldn't we decouple FP8 in the forward and backward?

Suggested change
- if ctx.fp8:
-     assert isinstance(
-         permuted_act_grad, Float8Tensor
-     ), "Grad of the output must be in Float8Tensor type for FP8 moe_permute."
+ fp8 = isinstance(permuted_act_grad, Float8Tensor)
+ if fp8:

If there are no obstacles, we could also do the same thing for _moe_unpermute_mask_map and _moe_chunk_sort.

Author

Modified. The backward pass now follows the dtype of the grad tensor.
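
The decoupling discussed in this thread can be sketched in plain Python. This is only an illustration of the dispatch idea, not the code in this PR: `Fp8Grad` is a hypothetical stand-in for an FP8 tensor wrapper, and `backward_uses_fp8` is an invented helper name.

```python
class Fp8Grad:
    """Hypothetical stand-in for an FP8 tensor wrapper (not a real TE class)."""

    def __init__(self, payload):
        self.payload = payload


def backward_uses_fp8(grad) -> bool:
    """Choose the backward path from the incoming grad, not from a forward flag."""
    return isinstance(grad, Fp8Grad)


# The forward pass may have run in high precision while the grad arrives in FP8,
# or the other way round; the backward path only looks at what it receives.
print(backward_uses_fp8(Fp8Grad([1, 2, 3])))  # True
print(backward_uses_fp8([0.5, 0.25]))         # False
```

With this shape, a mismatch between forward and backward precision is handled by dispatch instead of being rejected by an assertion.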

# Results Check
#
###################################################################################################################################
tols = dtype_tols(te_dtype)
Collaborator

Shouldn't we expect bit-wise exact results?

Suggested change
- tols = dtype_tols(te_dtype)
+ tols = { "atol": 0, "rtol": 0 }

Author

Hi, Tim. I made some modifications here; the test now uses two kinds of tols.

We cannot use bit-wise matching in all cases. First, in the FP8 case the fused kernels are compared against a PyTorch version of the function that uses fp32. Second, the unpermutation kernels contain reductions, so we cannot get bit-wise matching results for permute bwd, unpermute fwd, and unpermute bwd with probs.

For the other cases, I switched to bit-wise matching. Is this OK with you?
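
The reason reductions rule out bit-wise matching is that floating-point addition is not associative, so two kernels that sum the same values in a different order can differ in the last bit. The sketch below shows this with plain Python floats and a tolerance picker; `tolerances`, `all_close`, and the values in `LOOSE` are illustrative assumptions, not the `dtype_tols` helper or the thresholds used in the test file.

```python
from typing import Dict, Sequence


def tolerances(is_fp8: bool, has_reduction: bool, loose: Dict[str, float]) -> Dict[str, float]:
    """Use exact comparison unless the case involves FP8 or a reduction."""
    if is_fp8 or has_reduction:
        return dict(loose)
    return {"atol": 0.0, "rtol": 0.0}


def all_close(actual: Sequence[float], expected: Sequence[float], atol: float, rtol: float) -> bool:
    """Element-wise |a - e| <= atol + rtol * |e|."""
    return all(abs(a - e) <= atol + rtol * abs(e) for a, e in zip(actual, expected))


LOOSE = {"atol": 1e-6, "rtol": 1e-6}  # illustrative values only

# Same three addends, two summation orders.
left_to_right = (0.1 + 0.2) + 0.3
right_to_left = 0.1 + (0.2 + 0.3)
print(left_to_right == right_to_left)  # False

exact = tolerances(is_fp8=False, has_reduction=False, loose=LOOSE)
relaxed = tolerances(is_fp8=False, has_reduction=True, loose=LOOSE)
print(all_close([left_to_right], [right_to_left], **exact))    # False
print(all_close([left_to_right], [right_to_left], **relaxed))  # True
```

A pure copy, such as the permute forward pass, moves values without arithmetic, which is why exact tolerances are reasonable there.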

# Results Check
#
###################################################################################################################################
tols = dtype_tols(te_dtype)
Collaborator

We should expect bit-wise exact results.

Suggested change
- tols = dtype_tols(te_dtype)
+ tols = { "atol": 0, "rtol": 0 }

Author

As with the comment above, I changed this to bit-wise matching except for FP8.

@timmoon10 timmoon10 self-requested a review January 8, 2025 21:57
mask=(offset < num_tokens),
other=0,
).to(tl.int64)
expert_token_cumsum = tl.cumsum(expert_token_mask) * expert_token_mask
Collaborator

An interesting way to exclude the zero token_mask. Happy to learn!
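
For readers less familiar with the trick, the effect of multiplying the cumulative sum by the mask can be reproduced in plain Python. This is a sketch of the arithmetic only, with `masked_cumsum` as an invented name; the kernel does the same thing with `tl.cumsum` on a block of the mask.

```python
from itertools import accumulate
from typing import List


def masked_cumsum(mask: List[int]) -> List[int]:
    """Running count of selected tokens, zeroed wherever the mask is 0."""
    running = accumulate(mask)  # 1-based position among selected tokens
    return [c * m for c, m in zip(running, mask)]


print(masked_cumsum([1, 0, 1, 1, 0]))  # [1, 0, 2, 3, 0]
```

Selected tokens get a 1-based position, so the value 0 stays free to mean "not routed to this expert".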

Comment on lines 61 to 67
chunk_cumsum = tl.load(
    row_id_map_ptr + pid_m * num_tokens + offset, mask=(offset < num_tokens), other=0
)

workspace_off = tl.arange(0, WORKSPACE_LOAD_WIDTH)
chunk_sums = tl.load(workspace_ptr + workspace_off, mask=workspace_off < chunk_idx)
chunk_cumsum = tl.where(chunk_cumsum == 0, -1, chunk_cumsum + tl.sum(chunk_sums) - 1)
Collaborator

@phu0ngng phu0ngng Jan 10, 2025

The three names here, chunk_cumsum (as loaded), chunk_sums, and chunk_cumsum (as reassigned), are quite confusing.
If I understand the code correctly, I suggest renaming them as follows:

  • chunk_cumsum (loaded value) -> row_id_within_token_block
  • chunk_sums -> n_tokens_per_expert
  • chunk_cumsum (reassigned value) -> row_id

In addition, I think we should move the -1 into pass 1, since it is a correction to the calculation of expert_token_cumsum:

expert_token_cumsum = (tl.cumsum(expert_token_mask) - 1) * expert_token_mask

Author

Thanks, you are right. I renamed them (chunk_sums became n_tokens_per_block rather than n_tokens_per_expert).

As for the -1: if we move it into pass 1, a row_id of 0 becomes indistinguishable from a masked-out 0, and we would need an extra mechanism to tell whether a token is masked out. So I left the -1 in pass 2. Do you think that is OK?
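
The ambiguity described in this reply can be seen in a small plain-Python sketch of the two passes. The function names `pass1` and `pass2`, the `subtract_early` flag, and the single scalar `tokens_in_earlier_blocks` are simplifications made up for this illustration; the kernel works on Triton blocks and sums a workspace of per-block counts instead.

```python
from itertools import accumulate
from typing import List


def pass1(mask: List[int], subtract_early: bool = False) -> List[int]:
    """Per-block masked cumsum; optionally apply the -1 here (the proposed variant)."""
    running = accumulate(mask)
    if subtract_early:
        return [(c - 1) * m for c, m in zip(running, mask)]
    return [c * m for c, m in zip(running, mask)]


def pass2(block_values: List[int], tokens_in_earlier_blocks: int) -> List[int]:
    """Turn 1-based in-block positions into global 0-based row ids; 0 means masked out."""
    return [-1 if v == 0 else v + tokens_in_earlier_blocks - 1 for v in block_values]


mask = [1, 0, 1]
print(pass1(mask))                       # [1, 0, 2]
print(pass1(mask, subtract_early=True))  # [0, 0, 1]
print(pass2(pass1(mask), 4))             # [4, -1, 5]
```

With the early subtraction, the first selected token and the masked-out token both become 0, so pass 2 can no longer map masked entries to -1 without extra information.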

4 participants