From b414ae9cc894735c3c39402c611249a828d55eaa Mon Sep 17 00:00:00 2001
From: Charlie Fu
Date: Thu, 5 Dec 2024 10:29:36 -0600
Subject: [PATCH] Always use 64 as the block size of moe_align kernel to avoid
 lds out of limit (#303)

* always use 64 as the block size to avoid lds out of limit

* lint
---
 csrc/moe/moe_align_sum_kernels.cu | 17 ++++++++---------
 1 file changed, 8 insertions(+), 9 deletions(-)

diff --git a/csrc/moe/moe_align_sum_kernels.cu b/csrc/moe/moe_align_sum_kernels.cu
index fff7ce34c838a..dd90c38d9a721 100644
--- a/csrc/moe/moe_align_sum_kernels.cu
+++ b/csrc/moe/moe_align_sum_kernels.cu
@@ -55,11 +55,11 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
   __syncthreads();
 
   // For each expert we accumulate the token counts from the different threads.
-  if (threadIdx.x < num_experts) {
-    tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
+  for (int eid = threadIdx.x; eid < num_experts; eid += blockDim.x) {
+    tokens_cnts[index(num_experts, 0, eid)] = 0;
     for (int i = 1; i <= blockDim.x; ++i) {
-      tokens_cnts[index(num_experts, i, threadIdx.x)] +=
-          tokens_cnts[index(num_experts, i - 1, threadIdx.x)];
+      tokens_cnts[index(num_experts, i, eid)] +=
+          tokens_cnts[index(num_experts, i - 1, eid)];
     }
   }
 
@@ -83,10 +83,9 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
    * For each expert, each thread processes the tokens of the corresponding
    * blocks and stores the corresponding expert_id for each block.
    */
-  if (threadIdx.x < num_experts) {
-    for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
-         i += block_size) {
-      expert_ids[i / block_size] = threadIdx.x;
+  for (int eid = threadIdx.x; eid < num_experts; eid += blockDim.x) {
+    for (int i = cumsum[eid]; i < cumsum[eid + 1]; i += block_size) {
+      expert_ids[i / block_size] = eid;
     }
   }
 
@@ -141,7 +140,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
       topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
         // calc needed amount of shared mem for `tokens_cnts` and `cumsum`
         // tensors
-        const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
+        const int32_t num_thread = WARP_SIZE;
         const int32_t shared_mem =
             ((num_thread + 1) * num_experts + (num_experts + 1)) *
             sizeof(int32_t);
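
Reviewer note (not part of the patch): the change does two coupled things. It pins
num_thread to WARP_SIZE (64 on ROCm) so the `tokens_cnts` scratch buffer, sized as
(num_thread + 1) * num_experts int32 entries, no longer grows with the expert count,
and it replaces the `if (threadIdx.x < num_experts)` guards with block-stride loops so
those 64 threads still cover every expert. The standalone sketch below illustrates both
points; the kernel and buffer names (expert_stride_demo, per_expert_out), the
128-expert example, and the 64 KiB LDS limit are illustrative assumptions, not taken
from the patch.

// Standalone CUDA/HIP sketch (not the vLLM kernel): shows the block-stride
// loop the patch switches to and the shared-memory arithmetic behind the fix.
#include <cstdio>
#include <cuda_runtime.h>

#ifndef WARP_SIZE
  #define WARP_SIZE 64  // assumption: AMD wavefront size; NVIDIA warps are 32
#endif

// Each thread strides over expert ids by blockDim.x, so a block of WARP_SIZE
// threads covers num_experts > blockDim.x; the old `if (threadIdx.x <
// num_experts)` guard would have left the extra experts unprocessed.
__global__ void expert_stride_demo(int* per_expert_out, int num_experts) {
  for (int eid = threadIdx.x; eid < num_experts; eid += blockDim.x) {
    per_expert_out[eid] = eid;  // stand-in for the per-expert cumsum work
  }
}

int main() {
  const int num_experts = 128;       // e.g. a 128-expert MoE model
  const int num_thread = WARP_SIZE;  // fixed block size, as in the patch

  // Shared-memory footprint from the patched host code:
  //   ((num_thread + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t)
  // Old sizing, num_thread = max(num_experts, WARP_SIZE) = 128:
  //   (129 * 128 + 129) * 4 B = 66564 B  -> exceeds a 64 KiB LDS limit
  // New sizing, num_thread = 64:
  //   (65 * 128 + 129) * 4 B  = 33796 B  -> fits comfortably
  const size_t shared_mem =
      ((num_thread + 1) * num_experts + (num_experts + 1)) * sizeof(int);
  printf("shared mem with num_thread=%d: %zu bytes\n", num_thread, shared_mem);

  int* d_out = nullptr;
  cudaMalloc(&d_out, num_experts * sizeof(int));
  expert_stride_demo<<<1, num_thread>>>(d_out, num_experts);
  cudaDeviceSynchronize();
  cudaFree(d_out);
  return 0;
}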