Skip to content

Commit

Permalink
Always use 64 as the block size of moe_align kernel to avoid lds out of limit (#303)
Browse files Browse the repository at this point in the history

* Always use 64 as the block size to avoid exceeding the LDS (shared memory) limit

* lint
  • Loading branch information
charlifu authored Dec 5, 2024
1 parent ccdb5b8 commit b414ae9
Showing 1 changed file with 8 additions and 9 deletions.
17 changes: 8 additions & 9 deletions csrc/moe/moe_align_sum_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,11 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
__syncthreads();

// For each expert we accumulate the token counts from the different threads.
- if (threadIdx.x < num_experts) {
- tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
+ for (int eid = threadIdx.x; eid < num_experts; eid += blockDim.x) {
+ tokens_cnts[index(num_experts, 0, eid)] = 0;
for (int i = 1; i <= blockDim.x; ++i) {
- tokens_cnts[index(num_experts, i, threadIdx.x)] +=
- tokens_cnts[index(num_experts, i - 1, threadIdx.x)];
+ tokens_cnts[index(num_experts, i, eid)] +=
+ tokens_cnts[index(num_experts, i - 1, eid)];
}
}

Expand All @@ -83,10 +83,9 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
* For each expert, each thread processes the tokens of the corresponding
* blocks and stores the corresponding expert_id for each block.
*/
- if (threadIdx.x < num_experts) {
- for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
- i += block_size) {
- expert_ids[i / block_size] = threadIdx.x;
+ for (int eid = threadIdx.x; eid < num_experts; eid += blockDim.x) {
+ for (int i = cumsum[eid]; i < cumsum[eid + 1]; i += block_size) {
+ expert_ids[i / block_size] = eid;
}
}

Expand Down Expand Up @@ -141,7 +140,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
// calc needed amount of shared mem for `tokens_cnts` and `cumsum`
// tensors
- const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
+ const int32_t num_thread = WARP_SIZE;
const int32_t shared_mem =
((num_thread + 1) * num_experts + (num_experts + 1)) *
sizeof(int32_t);
Expand Down

0 comments on commit b414ae9

Please sign in to comment.