Only launch splitkv combine kernel when necessary
poyenc committed Jan 1, 2025
1 parent ab1b16a commit b102083
Showing 2 changed files with 206 additions and 162 deletions.
164 changes: 84 additions & 80 deletions example/ck_tile/01_fmha/fmha_fwd.cpp
@@ -11,6 +11,7 @@
#include <array>
#include <cstring>
#include <functional>
+#include <map>
#include <numeric>
#include <ostream>
#include <string>
@@ -176,61 +177,14 @@ auto get_elimit<FmhaFwdFp8>(std::string init_method)
    }
}

-int num_splits_heuristic(int batch_nhead_mblocks, int num_SMs, int num_n_blocks, int max_splits)
-{
-    // If we have enough to almost fill the SMs, then just use 1 split
-    if(batch_nhead_mblocks >= 0.8f * num_SMs)
-    {
-        return 1;
-    }
-    max_splits = std::min({max_splits, num_SMs, num_n_blocks});
-    float max_efficiency = 0.f;
-    std::vector<float> efficiency;
-    efficiency.reserve(max_splits);
-    auto ceildiv = [](int a, int b) { return (a + b - 1) / b; };
-    // Some splits are not eligible. For example, if we have 64 blocks and choose 11 splits,
-    // we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks
-    // (i.e. it's 11 splits anyway).
-    // So we check if the number of blocks per split is the same as the previous num_splits.
-    auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) {
-        return num_splits == 1 ||
-               ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1);
-    };
-    for(int num_splits = 1; num_splits <= max_splits; num_splits++)
-    {
-        if(!is_split_eligible(num_splits))
-        {
-            efficiency.push_back(0.f);
-        }
-        else
-        {
-            float n_waves = float(batch_nhead_mblocks * num_splits) / num_SMs;
-            float eff = n_waves / ceil(n_waves);
-            // printf("num_splits = %d, eff = %f\n", num_splits, eff);
-            if(eff > max_efficiency)
-            {
-                max_efficiency = eff;
-            }
-            efficiency.push_back(eff);
-        }
-    }
-    for(int num_splits = 1; num_splits <= max_splits; num_splits++)
-    {
-        if(!is_split_eligible(num_splits))
-        {
-            continue;
-        }
-        if(efficiency[num_splits - 1] >= 0.85 * max_efficiency)
-        {
-            // printf("num_splits chosen = %d\n", num_splits);
-            return num_splits;
-        }
-    }
-    return 1;
-}
-
-int override_num_splits_if_necessary(
-    int batch, int nhead, int max_seqlen_q, int hdim_v, float p_drop, int num_splits)
+int override_num_splits_if_necessary(int batch,
+                                     int nhead,
+                                     int max_seqlen_q,
+                                     int hdim_q,
+                                     int hdim_v,
+                                     float p_drop,
+                                     bool is_prefill,
+                                     int num_splits)
{
    int device;
    auto status = hipGetDevice(&device);
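For context: the 4-parameter num_splits_heuristic deleted above is replaced by a 3-parameter overload, called later in this diff with a split cap of 8; its definition lives in the second changed file, whose diff is not shown on this page. Below is a minimal sketch of what such an overload could look like, assuming it keeps the wave-quantization efficiency logic and drops the num_n_blocks eligibility check (the new call site notes num_n_blocks is always 1). The name num_splits_heuristic_sketch marks this as an illustration, not the actual code from the commit.

#include <algorithm>
#include <cmath>
#include <vector>

// Hypothetical sketch, not the commit's actual overload: a 3-parameter
// num_splits_heuristic consistent with the new call site, assuming the
// num_n_blocks parameter is gone because the N dimension is never split.
int num_splits_heuristic_sketch(int batch_nhead_mblocks, int num_SMs, int max_splits)
{
    // If there is enough work to almost fill the SMs, use a single split.
    if(batch_nhead_mblocks >= 0.8f * num_SMs)
    {
        return 1;
    }
    max_splits = std::min(max_splits, num_SMs);

    // Efficiency of a candidate split count = occupancy of its last wave.
    std::vector<float> efficiency;
    efficiency.reserve(max_splits);
    float max_efficiency = 0.f;
    for(int num_splits = 1; num_splits <= max_splits; num_splits++)
    {
        const float n_waves = float(batch_nhead_mblocks * num_splits) / num_SMs;
        const float eff = n_waves / std::ceil(n_waves);
        max_efficiency = std::max(max_efficiency, eff);
        efficiency.push_back(eff);
    }

    // Pick the smallest split count within 85% of the best efficiency.
    for(int num_splits = 1; num_splits <= max_splits; num_splits++)
    {
        if(efficiency[num_splits - 1] >= 0.85f * max_efficiency)
        {
            return num_splits;
        }
    }
    return 1;
}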
@@ -246,17 +200,41 @@ int override_num_splits_if_necessary(
        return num_splits;
    }

-    // tile size should match the generate.py
-    const int kM0 = 64;
-    const int kN1 = hdim_v;
+    const int kM0 = [&] {
+        // get kM0 for prefill phase
+        if(is_prefill)
+        {
+            return 128;
+        }
+
+        // get kM0 for decode phase
+        /// TODO: take dtype=fp8/bf8 into consideration
+        const std::map<int, int> hdim_to_m0 = {
+            {32, 32},
+            {64, 64},
+            // {96, 64},
+            {128, 64},
+            {256, 64},
+        };
+
+        for(auto [hdim, m0] : hdim_to_m0)
+        {
+            if(hdim_q <= hdim && hdim_v <= hdim)
+            {
+                return m0;
+            }
+        }
+
+        return 64; // fallback for unsupported hdim_q/hdim_v
+    }();
+    // const int kN1 = hdim_v;

    const int num_m_blocks = ck_tile::integer_divide_ceil(max_seqlen_q, kM0);
-    const int num_n_blocks = ck_tile::integer_divide_ceil(hdim_v, kN1);
+    // const int num_n_blocks = ck_tile::integer_divide_ceil(hdim_v, kN1); // always 1

    if(num_splits < 1 && p_drop == 0.0f)
    {
-        return num_splits_heuristic(
-            batch * nhead * num_m_blocks, props.multiProcessorCount * 2, num_n_blocks, 128);
+        return num_splits_heuristic(batch * nhead * num_m_blocks, props.multiProcessorCount * 2, 8);
    }

    return num_splits;
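For intuition, here is a standalone sketch of the new decode-phase tile-size selection and the block count it feeds into the heuristic. The table values are copied from the diff; select_m0_for_decode is an illustrative name, not part of the change.

#include <cstdio>
#include <map>

// Standalone illustration of the decode-phase kM0 selection above.
int select_m0_for_decode(int hdim_q, int hdim_v)
{
    const std::map<int, int> hdim_to_m0 = {{32, 32}, {64, 64}, {128, 64}, {256, 64}};
    // std::map iterates keys in ascending order, so this picks the smallest
    // supported head dim that covers both hdim_q and hdim_v.
    for(auto [hdim, m0] : hdim_to_m0)
    {
        if(hdim_q <= hdim && hdim_v <= hdim)
        {
            return m0;
        }
    }
    return 64; // fallback for unsupported hdim_q/hdim_v
}

int main()
{
    // e.g. hdim_q = hdim_v = 128 selects kM0 = 64; a decode step with
    // max_seqlen_q = 1 then yields a single M block per (batch, head).
    const int kM0          = select_m0_for_decode(128, 128);
    const int max_seqlen_q = 1;
    const int num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0;
    std::printf("kM0 = %d, num_m_blocks = %d\n", kM0, num_m_blocks);
    return 0;
}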
@@ -556,8 +534,15 @@ bool run(const ck_tile::ArgParser& arg_parser)
    // legalize num_splits according to other options
    if(num_splits < 1)
    {
-        num_splits = override_num_splits_if_necessary(
-            batch, nhead, max_seqlen_q, hdim_v, p_drop, num_splits);
+        num_splits = override_num_splits_if_necessary(batch,
+                                                      nhead,
+                                                      max_seqlen_q,
+                                                      hdim_q,
+                                                      hdim_v,
+                                                      p_drop,
+                                                      /*is_prefill=*/mode == mode_enum::group &&
+                                                          0 < page_block_size,
+                                                      num_splits);
    }
    if(128 < num_splits)
    {
@@ -632,17 +617,18 @@ bool run(const ck_tile::ArgParser& arg_parser)
    auto [rotary_cos_host, rotary_sin_host] = generate_rotary_cos_sin<KDataType>(
        std::max(shape_seqlen_q, shape_seqlen_k), rotary_dim, seed);

+    // lse_acc_host & o_acc_host are only used when 1 < num_splits
    ck_tile::HostTensor<LSEDataType> lse_acc_host(
-        1 < num_splits || use_kvcache
+        1 < num_splits
            ? std::array<ck_tile::index_t, 4>{shape_batch, nhead, num_splits, shape_seqlen_q}
            : std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
    ck_tile::HostTensor<OaccDataType> o_acc_host(
-        1 < num_splits || use_kvcache ? std::array<ck_tile::index_t, 5>{shape_batch,
-                                                                        nhead,
-                                                                        num_splits,
-                                                                        shape_seqlen_q,
-                                                                        hdim_v}
-                                      : std::array<ck_tile::index_t, 5>{1, 1, 1, 1, 1});
+        1 < num_splits ? std::array<ck_tile::index_t, 5>{shape_batch,
+                                                         nhead,
+                                                         num_splits,
+                                                         shape_seqlen_q,
+                                                         hdim_v}
+                       : std::array<ck_tile::index_t, 5>{1, 1, 1, 1, 1});

    // batch mode of lse data layout is [batch, nhead, seqlen_q]
    // group mode of lse data layout is [nhead, total_seqlen_q]
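Note the dropped "|| use_kvcache" term: the accumulation tensors now collapse to 1-element dummies whenever there is a single split, even in KV-cache runs. A rough size comparison under assumed dimensions (all numbers illustrative):

#include <cstdio>

int main()
{
    // Illustrative decode-style shapes: batch=2, nhead=8, seqlen_q=1, hdim_v=128.
    const long long shape_batch = 2, nhead = 8, shape_seqlen_q = 1, hdim_v = 128;
    const long long num_splits = 1; // single split, but use_kvcache == true

    // Old condition (1 < num_splits || use_kvcache): full-size o_acc allocated.
    const long long old_elems = shape_batch * nhead * num_splits * shape_seqlen_q * hdim_v;
    // New condition (1 < num_splits): dummy {1, 1, 1, 1, 1} tensor instead.
    const long long new_elems = 1;

    std::printf("o_acc elements: old=%lld, new=%lld\n", old_elems, new_elems);
    return 0;
}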
@@ -1043,9 +1029,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
        }
        else if constexpr(std::is_same_v<fmha_fwd_splitkv_args, std::decay_t<decltype(args)>>)
        {
-            args.lse_acc_ptr = lse_acc_buf.GetDeviceBuffer();
-            args.o_acc_ptr = o_acc_buf.GetDeviceBuffer();
-
+            // lse_acc_buf & o_acc_buf are only used when 1 < num_splits
            args.block_table_ptr =
                (0 < page_block_size ? block_table_buf.GetDeviceBuffer() : nullptr);
            args.batch_stride_block_table = batch_stride_block_table;
@@ -1057,13 +1041,33 @@

            args.num_splits = num_splits;

-            args.stride_o_acc = stride_o_acc;
-            args.nhead_stride_lse_acc = nhead_stride_lse_acc;
-            args.nhead_stride_o_acc = nhead_stride_o_acc;
-            args.batch_stride_lse_acc = batch_stride_lse_acc;
-            args.batch_stride_o_acc = batch_stride_o_acc;
-            args.split_stride_lse_acc = split_stride_lse_acc;
-            args.split_stride_o_acc = split_stride_o_acc;
+            if(1 < num_splits)
+            {
+                args.lse_acc_ptr = lse_acc_buf.GetDeviceBuffer();
+                args.o_acc_ptr = o_acc_buf.GetDeviceBuffer();
+
+                args.stride_o_acc = stride_o_acc;
+                args.nhead_stride_lse_acc = nhead_stride_lse_acc;
+                args.nhead_stride_o_acc = nhead_stride_o_acc;
+                args.batch_stride_lse_acc = batch_stride_lse_acc;
+                args.batch_stride_o_acc = batch_stride_o_acc;
+                args.split_stride_lse_acc = split_stride_lse_acc;
+                args.split_stride_o_acc = split_stride_o_acc;
+            }
+            else
+            {
+                // the following attributes are ignored by fmha_fwd_splitkv()
+                args.lse_acc_ptr = nullptr;
+                args.o_acc_ptr = nullptr;
+
+                args.stride_o_acc = 0;
+                args.nhead_stride_lse_acc = 0;
+                args.nhead_stride_o_acc = 0;
+                args.batch_stride_lse_acc = 0;
+                args.batch_stride_o_acc = 0;
+                args.split_stride_lse_acc = 0;
+                args.split_stride_o_acc = 0;
+            }
        }
    }
};
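Taken together with the allocation change above, zeroing the accumulator pointers and strides reflects that with a single split the splitkv kernel writes its output directly, so no combine pass is needed. A hypothetical sketch of the gating this commit enables; the launch function names and the trimmed args struct are stand-ins, not the real ck_tile API:

#include <cstdio>

// Trimmed, hypothetical stand-in for fmha_fwd_splitkv_args; only the field
// relevant to the gating is kept.
struct splitkv_args_sketch
{
    int num_splits;
};

// Stand-ins for the real kernel launches (names are illustrative).
void launch_splitkv_kernel(const splitkv_args_sketch&) { std::puts("splitkv kernel"); }
void launch_splitkv_combine_kernel(const splitkv_args_sketch&) { std::puts("combine kernel"); }

// With num_splits == 1 there is nothing to reduce, so the combine kernel
// launch is skipped; the splitkv kernel writes the final output directly.
void dispatch(const splitkv_args_sketch& args)
{
    launch_splitkv_kernel(args);
    if(1 < args.num_splits)
    {
        launch_splitkv_combine_kernel(args);
    }
}

int main()
{
    dispatch({1}); // single split: combine kernel not launched
    dispatch({4}); // multiple splits: combine kernel reduces the partials
    return 0;
}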
[diff for the second changed file not shown]
