Skip to content

Commit

Permalink
fix grouped gating for deepseek-v2
Browse files Browse the repository at this point in the history
  • Loading branch information
lzhangzz committed Jan 14, 2025
1 parent 46aa4e5 commit 6f603bc
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 40 deletions.
97 changes: 62 additions & 35 deletions src/turbomind/kernels/gemm/moe_utils_v2.cu
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,6 @@ __global__ void MoeScanKernel_v2(int* f2n, // [e*n]

template<int max_expert_num,
int max_top_k,
// bool norm_top_k,
int items_per_thread,
int block_dim,
int access_size,
Expand All @@ -265,6 +264,7 @@ __global__ void MoeGateKernel_v8(float* scales, // [e,n]
int token_num_padded,
int expert_num,
int top_k,
bool softmax,
bool norm_topk,
float routed_scale)
{
Expand Down Expand Up @@ -426,8 +426,7 @@ __global__ void MoeGateKernel_v8(float* scales, // [e,n]
unsigned mask = (unsigned)-1;
float max_logit;

int count{};
float sum_prob{};
int count{};

const int warp_ti_offset = warp_ti * threads_per_token;

Expand All @@ -442,6 +441,7 @@ __global__ void MoeGateKernel_v8(float* scales, // [e,n]
max_bit = bit;
max_val = data[i];
}
// nvcc tends to emit a funnel shift for `bit <<= 1`; use inline PTX to force a plain 32-bit left shift
asm("shl.b32 %0, %1, 1;\n" : "=r"(bit) : "r"(bit));
}

Expand Down Expand Up @@ -484,21 +484,26 @@ __global__ void MoeGateKernel_v8(float* scales, // [e,n]
}
}

PRAGMA_UNROLL
for (int i = 0; i < items_per_thread; ++i) {
if (!norm_topk || used[i]) {
data[i] = expf(data[i] - max_logit);
sum_prob += data[i];
float sum_prob{};

if (softmax) {
PRAGMA_UNROLL
for (int i = 0; i < items_per_thread; ++i) {
if (!norm_topk || used[i]) {
data[i] = expf(data[i] - max_logit);
sum_prob += data[i];
}
}
PRAGMA_UNROLL
for (int m = threads_per_token / 2; m >= 1; m /= 2) {
sum_prob += __shfl_xor_sync((uint32_t)-1, sum_prob, m);
}
sum_prob = fdividef(1.f, sum_prob);
}

PRAGMA_UNROLL
for (int m = threads_per_token / 2; m >= 1; m /= 2) {
sum_prob += __shfl_xor_sync((uint32_t)-1, sum_prob, m);
else {
sum_prob = 1.f;
}

sum_prob = fdividef(1.f, sum_prob);

using WarpScan = cub::WarpScan<int, threads_per_token>;
__shared__ typename WarpScan::TempStorage temp_storage[tokens_per_cta];

Expand Down Expand Up @@ -569,6 +574,7 @@ void invokeMoeGate_V2(int* f2n, // [e*n] -> n
int tokens_padded, // round_up(n, 4)
int experts, // E
int experts_per_token,
bool softmax,
bool norm_topk,
float routed_scale,
cudaStream_t st)
Expand Down Expand Up @@ -602,29 +608,24 @@ void invokeMoeGate_V2(int* f2n, // [e*n] -> n
tokens_padded,
experts,
experts_per_token,
softmax,
norm_topk,
routed_scale);
};

auto fail = [&] {
std::cerr << __FILE__ << "(" << __LINE__ << "): unsupported moe config: expert_num=" << experts
<< ", top_k=" << experts_per_token << "\n";
<< ", top_k=" << experts_per_token << ", softmax=" << softmax << ", norm_topk=" << norm_topk << "\n";
std::abort();
};

if (!softmax && norm_topk) {
// top-k normalization is implemented inside the softmax path, so `norm_topk` without `softmax` is unsupported
fail();
}

if (experts <= 8) {
if (experts_per_token <= 2) {
// MoeGateKernel_V2<2, 128><<<cdiv(tokens, 128), 128, 0, st>>>(scales,
// (int8_t*)masks,
// accum,
// logits,
// log_tile,
// tiles,
// tokens,
// tokens_padded,
// experts);

// std::cout << tokens << " " << experts << " " << experts_per_token << " " << tokens_padded << "\n";
invoke(_Int<8>, _Int<2>, _Int<8>, _Int<4>);
}
else {
Expand Down Expand Up @@ -879,7 +880,7 @@ std::vector<int> SampleBalanced(int token_num, int expert_num, int exp_per_tok,
}

template<int max_expert_num, int items_per_thread, int access_size>
__global__ void MoeMaskTopKGroups(float* logits, int token_num, int expert_num, int top_k)
__global__ void MoeSoftmaxMaskTopKGroups(float* logits, int token_num, int expert_num, int top_k)
{
constexpr int threads_per_token = max_expert_num / items_per_thread;

Expand All @@ -896,11 +897,12 @@ __global__ void MoeMaskTopKGroups(float* logits, int token_num, int expert_num,
for (int i = 0; i < items_per_thread; ++i) {
data[i] = -std::numeric_limits<float>::infinity();
}
// max logit in the group
float max_val = -std::numeric_limits<float>::infinity();
if (ti < token_num) {
PRAGMA_UNROLL
for (int i = 0; i < items_per_thread; i += access_size) {
const int e = ei * items_per_thread + i;
const int e = ei * items_per_thread + i; // blocked partition
if (e < expert_num) {
Ldg((Array<float, access_size>&)data[i], &logits[ti * expert_num + e]);
PRAGMA_UNROLL
Expand All @@ -914,7 +916,8 @@ __global__ void MoeMaskTopKGroups(float* logits, int token_num, int expert_num,
const int warp_ti = threadIdx.x % WARP_SIZE / threads_per_token;
const int warp_ti_offset = warp_ti * threads_per_token;

bool alive = false;
bool alive = false;
float max_logit = 0;

for (int k = 0; k < top_k; ++k) {
int g_max_ei = ei;
Expand All @@ -926,34 +929,58 @@ __global__ void MoeMaskTopKGroups(float* logits, int token_num, int expert_num,
// tie breaking
const auto active = __ballot_sync((uint32_t)-1, max_val == g_max_val);
g_max_ei = __ffs(active >> (unsigned)warp_ti_offset) - 1;
if (k == 0) {
max_logit = g_max_val;
}
if (ei == g_max_ei) {
alive = true;
max_val = -std::numeric_limits<float>::infinity();
}
}

if (!alive && ti < token_num) {
Array<float, access_size> vec;
fill(vec, -std::numeric_limits<float>::infinity());
float sum_prob{};

PRAGMA_NO_UNROLL
for (int i = 0; i < items_per_thread; ++i) {
data[i] = expf(data[i] - max_logit);
sum_prob += data[i];
}

PRAGMA_UNROLL
for (int m = threads_per_token / 2; m >= 1; m /= 2) {
sum_prob += __shfl_xor_sync((uint32_t)-1, sum_prob, m);
}

// mask dead logits: threads whose group was never selected (`!alive`) scale by 0, zeroing their outputs
sum_prob = alive ? fdividef(1.f, sum_prob) : 0;

PRAGMA_UNROLL
for (int i = 0; i < items_per_thread; ++i) {
data[i] *= sum_prob;
}

if (ti < token_num) {
PRAGMA_UNROLL
for (int i = 0; i < items_per_thread; i += access_size) {
const int e = ei * items_per_thread + i;
if (e < expert_num) {
Store(&logits[ti * expert_num + e], vec);
Store(&logits[ti * expert_num + e], (Array<float, access_size>&)data[i]);
}
}
}
}

void invokeMaskMoeTopKGroups(float* logits, int token_num, int expert_num, int group_size, int top_k, cudaStream_t st)
void invokeMoeSoftmaxMaskTopKGroups(
float* logits, int token_num, int expert_num, int group_size, int top_k, cudaStream_t st)
{
auto invoke = [&](auto max_expert_num, auto items_per_thread, auto vec_size) {
constexpr int thrs_per_tok = max_expert_num.value / items_per_thread.value;
constexpr int threads = 256;
const int blocks = ceil_div(token_num, threads / thrs_per_tok);
MoeMaskTopKGroups<max_expert_num.value, items_per_thread.value, vec_size.value>
MoeSoftmaxMaskTopKGroups<max_expert_num.value, items_per_thread.value, vec_size.value>
<<<blocks, threads, 0, st>>>(logits, token_num, expert_num, top_k);
};

if (expert_num == 160 && group_size == 20) {
return invoke(_Int<160>, _Int<20>, _Int<4>);
}
Expand Down
4 changes: 3 additions & 1 deletion src/turbomind/kernels/gemm/moe_utils_v2.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ void invokeMoeGate_V2(int* f2n,
int tokens_padded,
int experts,
int exp_per_tok,
bool softmax,
bool norm_topk,
float routed_scale,
cudaStream_t st);
Expand Down Expand Up @@ -58,7 +59,8 @@ void invokeMoeReduce(T* dst,
float dst_scale,
cudaStream_t st);

void invokeMaskMoeTopKGroups(float* logits, int token_num, int expert_num, int group_size, int top_k, cudaStream_t st);
void invokeMoeSoftmaxMaskTopKGroups(
float* logits, int token_num, int expert_num, int group_size, int top_k, cudaStream_t st);

// Sample `e` from `E` experts uniformly for every token
std::vector<int> SampleUniform(int token_num, int expert_num, int exp_per_tok, std::mt19937& g);
Expand Down
13 changes: 10 additions & 3 deletions src/turbomind/kernels/gemm/test/test_moe_utils.cu
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,12 @@ bool test_moe_gate(int tokens, //
cudaMemPrefetchAsync(scales.data().get(), sizeof(float) * scales.size(), 0);
cudaMemPrefetchAsync(logits.data().get(), sizeof(float) * logits.size(), 0);

// invokeMaskMoeTopKGroups(logits.data().get(), tokens, expert_num, expert_num / 8, 3, nullptr);
bool softmax = true;

if (1) {
invokeMoeSoftmaxMaskTopKGroups(logits.data().get(), tokens, expert_num, expert_num / 8, 8, nullptr);
softmax = false;
}

for (int i = 0; i < 1; ++i) {
gemm::CacheFlushing::flush();
Expand All @@ -222,6 +227,7 @@ bool test_moe_gate(int tokens, //
tokens_padded,
expert_num,
experts_per_token,
softmax,
false,
1.f,
nullptr);
Expand Down Expand Up @@ -307,8 +313,8 @@ bool test_moe_gate(int tokens, //

// thrust::host_vector<int4> tile_offsets(tape.max_ctas);
// std::cout << tape.max_ctas << std::endl;
// cudaMemcpy(tile_offsets.data(), tape.tile_offsets, sizeof(int4) * tile_offsets.size(), cudaMemcpyDefault);
// cudaDeviceSynchronize();
// cudaMemcpy(tile_offsets.data(), tape.tile_offsets, sizeof(int4) * tile_offsets.size(),
// cudaMemcpyDefault); cudaDeviceSynchronize();

// std::cout << "coords:\n";
// int last = -1;
Expand Down Expand Up @@ -342,6 +348,7 @@ int main()
// test_moe_gate(8, 60, 4, tape, tiling);

test_moe_gate(16, 160, 6, tape, tiling);

return 0;

for (int i = 1; i < 16384; ++i) {
Expand Down
6 changes: 5 additions & 1 deletion src/turbomind/models/llama/moe_ffn_layer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,12 @@ void MoeFfnLayer<T>::forward(T* output, const T* input, int tokens, int layer_id

// dump_logits(tokens, layer_id);

bool softmax = true;
if (param_.topk_method == "group_limited_greedy") {
invokeMaskMoeTopKGroups(logits_, tokens, expert_num, expert_num / param_.n_group, param_.topk_group, stream_);
invokeMoeSoftmaxMaskTopKGroups(
logits_, tokens, expert_num, expert_num / param_.n_group, param_.topk_group, stream_);
sync_check_cuda_error();
softmax = false;
}

/// TODO: avoid illegal memory accesses even when NaNs are present in logits
Expand All @@ -127,6 +130,7 @@ void MoeFfnLayer<T>::forward(T* output, const T* input, int tokens, int layer_id
padded,
expert_num,
param_.experts_per_token,
softmax,
param_.norm_topk_prob,
param_.routed_scale,
stream_);
Expand Down

0 comments on commit 6f603bc

Please sign in to comment.