diff --git a/src/turbomind/kernels/gemm/moe_utils_v2.cu b/src/turbomind/kernels/gemm/moe_utils_v2.cu
index 44fec6774..3309933db 100644
--- a/src/turbomind/kernels/gemm/moe_utils_v2.cu
+++ b/src/turbomind/kernels/gemm/moe_utils_v2.cu
@@ -250,7 +250,6 @@ __global__ void MoeScanKernel_v2(int* f2n,  // [e*n]
 template<...>
@@ ... @@ __global__ void MoeGateKernel_v8(float* scales,  // [e,n]
                                  int          expert_num,
                                  int          top_k,
+                                 bool         softmax,
                                  bool         norm_topk,
                                  float        routed_scale)
@@ ... @@ __global__ void MoeGateKernel_v8(float* scales,  // [e,n]
+    if (softmax) {
+        PRAGMA_UNROLL
+        for (int m = threads_per_token / 2; m >= 1; m /= 2) {
+            sum_prob += __shfl_xor_sync((uint32_t)-1, sum_prob, m);
+        }
+        sum_prob = fdividef(1.f, sum_prob);
+    }
-
-    PRAGMA_UNROLL
-    for (int m = threads_per_token / 2; m >= 1; m /= 2) {
-        sum_prob += __shfl_xor_sync((uint32_t)-1, sum_prob, m);
+    else {
+        sum_prob = 1.f;
     }
-    sum_prob = fdividef(1.f, sum_prob);
-
     using WarpScan = cub::WarpScan<int>;
     __shared__ typename WarpScan::TempStorage temp_storage[tokens_per_cta];
@@ -569,6 +574,7 @@ void invokeMoeGate_V2(int* f2n,  // [e*n] -> n
                       int          tokens_padded,  // round_up(n, 4)
                       int          experts,        // E
                       int          experts_per_token,
+                      bool         softmax,
                       bool         norm_topk,
                       float        routed_scale,
                       cudaStream_t st)
@@ -602,29 +608,24 @@ void invokeMoeGate_V2(int* f2n,  // [e*n] -> n
                tokens_padded,
                experts,
                experts_per_token,
+               softmax,
                norm_topk,
                routed_scale);
     };
 
     auto fail = [&] {
         std::cerr << __FILE__ << "(" << __LINE__ << "): unsupported moe config: expert_num=" << experts
-                  << ", top_k=" << experts_per_token << "\n";
+                  << ", top_k=" << experts_per_token << ", softmax=" << softmax << ", norm_topk=" << norm_topk << "\n";
         std::abort();
     };
 
+    if (!softmax && norm_topk) {
+        // norm top-k is part of softmax impl
+        fail();
+    }
+
     if (experts <= 8) {
         if (experts_per_token <= 2) {
-            // MoeGateKernel_V2<2, 128><<<...>>>(scales,
-            //                                   (int8_t*)masks,
-            //                                   accum,
-            //                                   logits,
-            //                                   log_tile,
-            //                                   tiles,
-            //                                   tokens,
-            //                                   tokens_padded,
-            //                                   experts);
-
-            // std::cout << tokens << " " << experts << " " << experts_per_token << " " << tokens_padded << "\n";
             invoke(_Int<8>, _Int<2>, _Int<8>, _Int<4>);
         }
         else {
@@ -879,7 +880,7 @@ std::vector<int> SampleBalanced(int token_num, int expert_num, int exp_per_tok,
 }
 
 template<int max_expert_num, int items_per_thread, int access_size>
-__global__ void MoeMaskTopKGroups(float* logits, int token_num, int expert_num, int top_k)
+__global__ void MoeSoftmaxMaskTopKGroups(float* logits, int token_num, int expert_num, int top_k)
 {
     constexpr int threads_per_token = max_expert_num / items_per_thread;
 
@@ -896,11 +897,12 @@ __global__ void MoeMaskTopKGroups(float* logits, int token_num, int expert_num,
     for (int i = 0; i < items_per_thread; ++i) {
         data[i] = -std::numeric_limits<float>::infinity();
     }
+    // max logit in the group
     float max_val = -std::numeric_limits<float>::infinity();
     if (ti < token_num) {
         PRAGMA_UNROLL
         for (int i = 0; i < items_per_thread; i += access_size) {
-            const int e = ei * items_per_thread + i;
+            const int e = ei * items_per_thread + i;  // blocked partition
             if (e < expert_num) {
                 Ldg((Array<float, access_size>&)data[i], &logits[ti * expert_num + e]);
                 PRAGMA_UNROLL
@@ -914,7 +916,8 @@ __global__ void MoeMaskTopKGroups(float* logits, int token_num, int expert_num,
     const int warp_ti        = threadIdx.x % WARP_SIZE / threads_per_token;
     const int warp_ti_offset = warp_ti * threads_per_token;
 
-    bool alive = false;
+    bool  alive     = false;
+    float max_logit = 0;
 
     for (int k = 0; k < top_k; ++k) {
         int g_max_ei = ei;
@@ -926,34 +929,58 @@ __global__ void MoeMaskTopKGroups(float* logits, int token_num, int expert_num,
         // tie breaking
         const auto active = __ballot_sync((uint32_t)-1, max_val == g_max_val);
         g_max_ei          = __ffs(active >> (unsigned)warp_ti_offset) - 1;
+        if (k == 0) {
+            max_logit = g_max_val;
+        }
         if (ei == g_max_ei) {
            alive   = true;
            max_val = -std::numeric_limits<float>::infinity();
         }
     }
 
-    if (!alive && ti < token_num) {
-        Array<float, access_size> vec;
-        fill(vec, -std::numeric_limits<float>::infinity());
+    float sum_prob{};
+
+    PRAGMA_NO_UNROLL
+    for (int i = 0; i < items_per_thread; ++i) {
+        data[i] = expf(data[i] - max_logit);
+        sum_prob += data[i];
+    }
+
+    PRAGMA_UNROLL
+    for (int m = threads_per_token / 2; m >= 1; m /= 2) {
+        sum_prob += __shfl_xor_sync((uint32_t)-1, sum_prob, m);
+    }
+
+    // mask dead logits
+    sum_prob = alive ? fdividef(1.f, sum_prob) : 0;
+
+    PRAGMA_UNROLL
+    for (int i = 0; i < items_per_thread; ++i) {
+        data[i] *= sum_prob;
+    }
+
+    if (ti < token_num) {
         PRAGMA_UNROLL
         for (int i = 0; i < items_per_thread; i += access_size) {
             const int e = ei * items_per_thread + i;
             if (e < expert_num) {
-                Store(&logits[ti * expert_num + e], vec);
+                Store(&logits[ti * expert_num + e], (Array<float, access_size>&)data[i]);
             }
         }
     }
 }
 
-void invokeMaskMoeTopKGroups(float* logits, int token_num, int expert_num, int group_size, int top_k, cudaStream_t st)
+void invokeMoeSoftmaxMaskTopKGroups(
+    float* logits, int token_num, int expert_num, int group_size, int top_k, cudaStream_t st)
 {
     auto invoke = [&](auto max_expert_num, auto items_per_thread, auto vec_size) {
         constexpr int thrs_per_tok = max_expert_num.value / items_per_thread.value;
         constexpr int threads      = 256;
         const int     blocks       = ceil_div(token_num, threads / thrs_per_tok);
-        MoeMaskTopKGroups<max_expert_num.value, items_per_thread.value, vec_size.value>
+        MoeSoftmaxMaskTopKGroups<max_expert_num.value, items_per_thread.value, vec_size.value>
             <<<blocks, threads, 0, st>>>(logits, token_num, expert_num, top_k);
     };
+
     if (expert_num == 160 && group_size == 20) {
         return invoke(_Int<160>, _Int<20>, _Int<4>);
     }
diff --git a/src/turbomind/kernels/gemm/moe_utils_v2.h b/src/turbomind/kernels/gemm/moe_utils_v2.h
index d53de1354..4a603a07b 100644
--- a/src/turbomind/kernels/gemm/moe_utils_v2.h
+++ b/src/turbomind/kernels/gemm/moe_utils_v2.h
@@ -21,6 +21,7 @@ void invokeMoeGate_V2(int* f2n,
                       int          tokens_padded,
                       int          experts,
                       int          exp_per_tok,
+                      bool         softmax,
                       bool         norm_topk,
                       float        routed_scale,
                       cudaStream_t st);
@@ -58,7 +59,8 @@ void invokeMoeReduce(T* dst,
                      float        dst_scale,
                      cudaStream_t st);
 
-void invokeMaskMoeTopKGroups(float* logits, int token_num, int expert_num, int group_size, int top_k, cudaStream_t st);
+void invokeMoeSoftmaxMaskTopKGroups(
+    float* logits, int token_num, int expert_num, int group_size, int top_k, cudaStream_t st);
 
 // Sample `e` from `E` experts uniformly for every token
 std::vector<int> SampleUniform(int token_num, int expert_num, int exp_per_tok, std::mt19937& g);
diff --git a/src/turbomind/kernels/gemm/test/test_moe_utils.cu b/src/turbomind/kernels/gemm/test/test_moe_utils.cu
index 4b2ea6a83..1fb6fe0c6 100644
--- a/src/turbomind/kernels/gemm/test/test_moe_utils.cu
+++ b/src/turbomind/kernels/gemm/test/test_moe_utils.cu
@@ -205,7 +205,12 @@ bool test_moe_gate(int tokens,  //
     cudaMemPrefetchAsync(scales.data().get(), sizeof(float) * scales.size(), 0);
     cudaMemPrefetchAsync(logits.data().get(), sizeof(float) * logits.size(), 0);
 
-    // invokeMaskMoeTopKGroups(logits.data().get(), tokens, expert_num, expert_num / 8, 3, nullptr);
+    bool softmax = true;
+
+    if (1) {
+        invokeMoeSoftmaxMaskTopKGroups(logits.data().get(), tokens, expert_num, expert_num / 8, 8, nullptr);
+        softmax = false;
+    }
 
     for (int i = 0; i < 1; ++i) {
         gemm::CacheFlushing::flush();
@@ -222,6 +227,7 @@ bool test_moe_gate(int tokens,  //
                           tokens_padded,
                           expert_num,
                           experts_per_token,
+                          softmax,
                           false,
                           1.f,
                           nullptr);
@@ -307,8 +313,8 @@ bool test_moe_gate(int tokens,  //
 
     // thrust::host_vector<int4> tile_offsets(tape.max_ctas);
     // std::cout << tape.max_ctas << std::endl;
-    // cudaMemcpy(tile_offsets.data(), tape.tile_offsets, sizeof(int4) * tile_offsets.size(), cudaMemcpyDefault);
-    // cudaDeviceSynchronize();
+    // cudaMemcpy(tile_offsets.data(), tape.tile_offsets, sizeof(int4) * tile_offsets.size(),
+    // cudaMemcpyDefault); cudaDeviceSynchronize();
 
     // std::cout << "coords:\n";
     // int last = -1;
@@ -342,6 +348,7 @@ int main()
     // test_moe_gate(8, 60, 4, tape, tiling);
 
     test_moe_gate(16, 160, 6, tape, tiling);
+    return 0;
 
     for (int i = 1; i < 16384; ++i) {
diff --git a/src/turbomind/models/llama/moe_ffn_layer.cc b/src/turbomind/models/llama/moe_ffn_layer.cc
index 390d14754..038d54f97 100644
--- a/src/turbomind/models/llama/moe_ffn_layer.cc
+++ b/src/turbomind/models/llama/moe_ffn_layer.cc
@@ -110,9 +110,12 @@ void MoeFfnLayer<T>::forward(T* output, const T* input, int tokens, int layer_id
 
     // dump_logits(tokens, layer_id);
 
+    bool softmax = true;
     if (param_.topk_method == "group_limited_greedy") {
-        invokeMaskMoeTopKGroups(logits_, tokens, expert_num, expert_num / param_.n_group, param_.topk_group, stream_);
+        invokeMoeSoftmaxMaskTopKGroups(
+            logits_, tokens, expert_num, expert_num / param_.n_group, param_.topk_group, stream_);
         sync_check_cuda_error();
+        softmax = false;
     }
 
     /// TODO: fix illegal memory access even if NaN are present in logits
@@ -127,6 +130,7 @@ void MoeFfnLayer<T>::forward(T* output, const T* input, int tokens, int layer_id
                       padded,
                       expert_num,
                       param_.experts_per_token,
+                      softmax,
                       param_.norm_topk_prob,
                       param_.routed_scale,
                       stream_);
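
For intuition, here is a minimal CPU sketch of the gating semantics this patch sets up for `topk_method == "group_limited_greedy"` (illustration only, not part of the patch; the helper name is made up): `invokeMoeSoftmaxMaskTopKGroups` now turns the logits into softmax probabilities and zeroes every expert outside the `topk_group` best groups, ranked by each group's max logit, so the subsequent `invokeMoeGate_V2` call runs with `softmax == false` (`sum_prob = 1.f`) and consumes the masked probabilities directly. The guard `if (!softmax && norm_topk) fail();` exists because top-k renormalization lives inside the fused softmax path.

```cpp
// Hypothetical CPU reference for softmax + top-k-group masking (not the kernel).
#include <algorithm>
#include <cmath>
#include <numeric>
#include <vector>

// One token's logits -> masked probabilities. Groups are ranked by their max
// probability, which orders the same as max logit since exp is monotonic.
std::vector<float> softmax_mask_topk_groups(std::vector<float> x, int group_size, int topk_group)
{
    const int n_group = (int)x.size() / group_size;

    // Stabilized softmax over all experts (the kernel subtracts the global max
    // logit, then reduces the sum across the warp slice with __shfl_xor_sync).
    const float max_logit = *std::max_element(x.begin(), x.end());
    float       sum       = 0.f;
    for (auto& v : x) {
        v = std::exp(v - max_logit);
        sum += v;
    }
    for (auto& v : x) {
        v /= sum;
    }

    // Rank groups by their largest probability.
    auto group_max = [&](int g) {
        return *std::max_element(x.begin() + g * group_size, x.begin() + (g + 1) * group_size);
    };
    std::vector<int> order(n_group);
    std::iota(order.begin(), order.end(), 0);
    std::stable_sort(order.begin(), order.end(), [&](int a, int b) { return group_max(a) > group_max(b); });

    // "mask dead logits": experts in the losing groups contribute zero downstream.
    for (int g = topk_group; g < n_group; ++g) {
        std::fill_n(x.begin() + order[g] * group_size, group_size, 0.f);
    }
    return x;
}
```

The kernel performs the same computation with one warp slice per token (for DeepSeek-V2's 160 experts in groups of 20, eight threads each own one group). The observable change versus the old `MoeMaskTopKGroups` is that dead groups are written as zero probabilities instead of `-inf` logits, which is what allows the downstream top-k selection to use the values as-is with `softmax == false`.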