From 37cdbf4f0ec88ba5064f46c3370633b5950bc7ae Mon Sep 17 00:00:00 2001 From: Po Yen Chen Date: Fri, 20 Dec 2024 14:41:01 +0800 Subject: [PATCH] [CK_TILE] Add fmha fwd N-Warp S-Shuffle pipeline (fmha fwd splitkv pipeline variant) (#1705) * Add check for zero values * Add static assertions * Remove invalid option '-e' in smoke_test.sh * Use correct path of smoke_test.sh * Avoid zero-sized shared memory array * Add warning comment * Replace expr by integer_divide_ceil() call * Use more readable constant names * Write down assumption as static assertion * Add more diagnostic error messages * Fix wrong BlockWarps when using default pipeline policy * Add more static assertions for A LDS desc * Allow using vector size < 8 for data type fp16/bf16 * Align vector size between DRAM dist & LDS desc * Remove no-longer used func decl * Fix wrong displayed piepline name * Undo policy template changes for tile_example_gemm_basic * Add missing space and make error message stands out * Unify print precision * Add missing include directive * Replace constant 64 by get_warp_size() call * Replace constant 128 by named variable: BankLength * Add kAMBlock/kBNBlock attributes * Allow usig different A/B warp dist for multiple blocks * Add helper function to get warp dist encodings * Add 4x64x4 fp16 warp gemm attribute impl * Complete the A/B warp dist encoding logic * Fix wrong thread mapping for C matrix * Use smaller vector size for small tile * Add static assert to block unsupported warp gemm impl * Extract common code out as helper method * Add 4x64x16 fp16 warp gemm type alias * Add comment to warning developers * Undo WarpGemmAtrributeMfma<> changes * Use more clear static assertion error message * Add trivial wrapper to get warp dstr encodings * Only transpose warp gemm result if it's square * Fix compilation error * Support multi-block warp gemm (on N direction) * Remove duplicated code * Fix output encoding of warp gemm * Fix wrong shape of WarpGemmAtrributeMfmaIterateK<> * Remove unused code * Fix wrong shape of WarpGemmAttributeMfmaImplF16F16F32M4N64K4 * Add type config for bf16_t * Add 4x64x16 bf16 warp gemm * Update WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution * Add 64x4x4 fp16/bf16 warp gemm impl * Add 64x4x16 fp16/bf16 warp gemm * Add static assertion for better error diagnostic * Get Q dram dstr directly form block gemm * Add missing header: fused_moe.hpp * Allow specifying different warp-gemm for gemm0 & gemm1 * Store P matrix into LDS before gemm1 * Fix inconsistant kernel name * Remove constraint on gemm0 & gemm1 block warps * Remove unsupported vector size from checking list * Allow using 4x64x16 warp gemm for gemm0 * Finish policy customization * Finish pipeline modification F# * Use block warps in codegen * Fix wrong rank of m_lds_window origin * Use better distributed tensor * Make P-store earlier * Remove duplicated experssions * Remove unnecessary tile window * Create new files for new splitkv pipeline * Separate old/new pipeline codegen logic * Sync changes form develop * Undo gemm kernel/pipeline changes * Undo gemm example changes * Remove blank lines * Fix typo * Use new warp gemm interface * Fix link error * Fix wrong pipeline tag * Fix more link error * Avoid unnecessary padding * Always use vector load for K * Padding on fastest dimension when necessary * Force padding Q on hdim_q * Set high dimension padding flag to false * Re-format headers * Use warps=<1, 4, 1> for both gemm0 & gemm1 * Fix complilation errors * Remove m/l shuffle logics * Ignore duplicate data when write lse_acc * Use gemm0 block warps as lds tile width * Remove hard-coded numbers * Fix wrong distribution width * Remove unnecessary code * Add s_barrier before writing to LDS * Store Q into LDS before gemm0 * Fix wrong Q tile size * Use simple Q lds descriptor for debuging * Use more realistic Q lds descriptor * Add comment & use better variable name * Make Q lds space not overlapped with others * Remove unnecessary block_tile_reduce_sync() call * Move Q load statements * Move block_sync_lds() right before use * Re-order instructions * Remove necessary lambda expression * Use 8 threads on kMaxSplits direction while doing reduction * Tiny correction for using 8 threads on kMaxSplits direction for combine kernel * Padding num_split direction of o_acc tile window to 4x * Update splitkv combine pipeline design * Add kN1 back to splitkv combine pipeline problem * Fix compilation errors * Add missing template parameter * Fix wrong splitkv combine kernel name * Fix wrong origin * Fix wrong LDS descriptor shape * Fix sync & reduction logics * Remove unnecessary static assertions * Extract tile size computation logics * Make sure we can reuse padding flags in combine kernels * Rename variables * Use OaccDataType in BlockFmhaSplitKVCombinePipelineTileSizes<> * Remove unnecessary static assertion * Fix function name typo * Add constraint on kN1 template parameter * Hide K tile loading latency in earlier iteration * Fix wrong splitkv kernel name * Use s_shuffling to replace p_shuffling which removes the needs of cross-warp reduction * Rename pipeline * Fix wrong pipeline name attribute * Add GetAlignmentQ() for NWarpSShuffle pipeline * Separate Q tile into dram tile & register tile concepts * Remove non-squre warp gemm transpose c type alias * Fallback tile size changes for fmha fwd splitkv * Remove redundant change * Refine naming for the S tile * Use better naming of the S tile dstr (read from lds) * Share Q lds with K lds * Tiny change * Fix with using static_for for passing CI checking --------- Co-authored-by: Qianfeng Zhang --- .../ck_tile/01_fmha/codegen/cpp_symbol_map.py | 1 + .../ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 42 +- .../01_fmha/codegen/ops/fmha_fwd_splitkv.py | 85 +- example/ck_tile/01_fmha/fmha_fwd.hpp | 2 - .../core/arch/amd_buffer_addressing.hpp | 4 +- .../core/tensor/static_distributed_tensor.hpp | 1 + include/ck_tile/ops/fmha.hpp | 2 + .../ops/fmha/kernel/fmha_fwd_kernel.hpp | 6 +- .../fmha_fwd_splitkv_combine_kernel.hpp | 56 +- .../fmha/kernel/fmha_fwd_splitkv_kernel.hpp | 9 +- ...lock_fmha_fwd_splitkv_combine_pipeline.hpp | 83 +- ...plitkv_combine_pipeline_default_policy.hpp | 173 ++-- ...litkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp | 794 ++++++++++++++++++ ...nwarp_sshuffle_qr_ks_vs_default_policy.hpp | 226 +++++ .../pipeline/block_fmha_pipeline_problem.hpp | 36 +- ...k_fmha_pipeline_qx_ks_vs_custom_policy.hpp | 55 +- .../ops/fmha/pipeline/tile_fmha_shape.hpp | 2 - ...block_gemm_areg_bsmem_creg_one_warp_v1.hpp | 44 +- .../block/block_gemm_areg_bsmem_creg_v2.hpp | 44 +- include/ck_tile/ops/gemm/warp/warp_gemm.hpp | 16 + .../gemm/warp/warp_gemm_attribute_mfma.hpp | 303 ++++++- .../warp/warp_gemm_attribute_mfma_impl.hpp | 271 ++++++ .../ops/gemm/warp/warp_gemm_dispatcher.hpp | 4 + 23 files changed, 1987 insertions(+), 272 deletions(-) create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp diff --git a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py index f6df44a318..332707eafd 100644 --- a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py +++ b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py @@ -119,6 +119,7 @@ def get_mask_check_map(mask : str): PIPELINE_ENUM_MAP = { "qr" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS", "qr_async" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC", + "qr_nwarp_sshuffle" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS", } BOOL_MAP = { diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index eca638784d..66814f5a16 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -44,13 +44,12 @@ using fmha_dtype_{F_idx} = {F_dtype}; using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>; -using fmha_warp_tile_{F_idx} = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>; using fmha_shape_{F_idx} = ck_tile::TileFmhaShape, - fmha_warp_tile_{F_idx}, + ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>, ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>, - fmha_warp_tile_{F_idx}, + ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>, {F_vlayout}>; using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, @@ -306,15 +305,19 @@ class FmhaFwdTileSize: F_rm1 : int # number of warps for gemm1 along q seqlen F_rn1 : int # number of warps for gemm1 along head dim v F_rk1 : int # number of warps for gemm1 along k seqlen (not used) - F_wm : int # warp size along m (warp size) - F_wn : int # warp size along n - F_wk : int # warp size along k + F_wm0 : int # gemm0 warp size along m + F_wn0 : int # gemm0 warp size along n + F_wk0 : int # gemm0 warp size along k + F_wm1 : int # gemm1 warp size along m + F_wn1 : int # gemm1 warp size along n + F_wk1 : int # gemm1 warp size along k F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy @property def name(self) -> str: return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}" +\ f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}" +\ - f"_w{self.F_wm}x{self.F_wn}x{self.F_wk}" + ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}") + f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}" +\ + ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}") @dataclass class FmhaFwdKernel: @@ -352,9 +355,12 @@ def template(self) -> str: F_rm1 = self.F_tile.F_rm1, F_rn1 = self.F_tile.F_rn1, F_rk1 = self.F_tile.F_rk1, - F_wm = self.F_tile.F_wm, - F_wn = self.F_tile.F_wn, - F_wk = self.F_tile.F_wk, + F_wm0 = self.F_tile.F_wm0, + F_wn0 = self.F_tile.F_wn0, + F_wk0 = self.F_tile.F_wk0, + F_wm1 = self.F_tile.F_wm1, + F_wn1 = self.F_tile.F_wn1, + F_wk1 = self.F_tile.F_wk1, F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout], F_spad = BOOL_MAP[self.F_pipeline.F_spad], F_skpad = BOOL_MAP[self.F_pipeline.F_skpad], @@ -409,17 +415,17 @@ def api_trait(self) -> FmhaFwdApiTrait: def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: if dtype == 'fp16' or dtype == 'bf16': return { - '32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, -1), - '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1), - ## '96' : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1), - '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1), - '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1), + '32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1), + '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + ### '96' : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), } elif dtype == 'fp8' or dtype == 'bf8': return { - '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, -1), - '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, -1), - '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, -1) + '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1), + '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), + '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), } else: return None diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index e448902cf8..df5b9cecc6 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -39,6 +39,7 @@ FMHA_FWD_SPLITKV_PIPELINE_MAP = { "qr" : "ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS", + "qr_nwarp_sshuffle" : "ck_tile::BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS", "qr_async" : "ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVSAsync", } @@ -50,13 +51,12 @@ template struct kernel_runner {{ using fmha_block_tile = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>; -using fmha_warp_tile = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>; using fmha_shape = ck_tile::TileFmhaShape, - fmha_warp_tile, + ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>, ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>, - fmha_warp_tile, + ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>, {F_vlayout}>; using fmha_trait = ck_tile::TileFmhaFwdSplitKVTraits<{F_spad}, @@ -161,9 +161,8 @@ typename FmhaFwdTypeConfig::OaccDataType, typename FmhaFwdTypeConfig::ODataType, {F_hdim}, - {F_bm0}, - {F_bn1}, {F_mode}, + {F_bn1}, fmha_trait>; using fmha_pipeline = ck_tile::BlockFmhaFwdSplitKVCombinePipeline< @@ -177,9 +176,11 @@ false, false>>; using fmha_kernel = - ck_tile::FmhaFwdSplitKVCombineKernel, - fmha_pipeline, - fmha_epilogue>; + ck_tile::FmhaFwdSplitKVCombineKernel< + ck_tile::FmhaFwdSplitKVCombineTilePartitioner< + fmha_pipeline_problem::kM0, fmha_pipeline_problem::kN1>, + fmha_pipeline, + fmha_epilogue>; static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) {{ @@ -192,7 +193,7 @@ }}; }} -using trait_{F_idx} = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn1}, +using trait_{F_idx} = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bn1}, {F_lse}, {F_squant}, {F_spad}, {F_dvpad}>; #include @@ -250,16 +251,25 @@ FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) && ((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + + // get combine kernel tile sizes + using OaccDataType = typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType; + constexpr ck_tile::index_t kM0 = ck_tile::BlockFmhaSplitKVCombinePipelineTileSizes::kM0; + + // make sure we can reuse the padding flags in combine kernels + static_assert({F_bm0} % kM0 == 0); + static_assert({F_bn1} % 32 == 0); + if (t.has_lse) {{ if constexpr (std::is_same_v<{F_dtype}, ck_tile::fp8_t>) {{ return -1; }} else {{ - using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}/2, {F_bn1}/2, true, {F_squant}, {F_spad}, {F_dvpad}>; + using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, /*F_bn1=*/32, true, {F_squant}, {F_spad}, {F_dvpad}>; return fmha_fwd_splitkv_(s, a); }} }} else {{ - using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}/2, {F_bn1}/2, false, {F_squant}, {F_spad}, {F_dvpad}>; + using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, /*F_bn1=*/32, false, {F_squant}, {F_spad}, {F_dvpad}>; return fmha_fwd_splitkv_(s, a); }} @@ -302,7 +312,7 @@ def scheck(self) -> str: if self.pipeline_tag == 'qr_async': if self.spad == 't' : return 'true' # always support else : return 'true' - elif self.pipeline_tag in ['qr']: + elif self.pipeline_tag in ['qr', 'qr_nwarp_sshuffle']: if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) else : return f'a.seqlen_q % {self.bm0} == 0' else: assert False @@ -313,7 +323,7 @@ def skcheck(self) -> str: if self.pipeline_tag == 'qr_async': if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0' else : return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0' - elif self.pipeline_tag in ['qr', 'qr_fp8']: + elif self.pipeline_tag in ['qr', 'qr_nwarp_sshuffle']: if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) else : return f'a.seqlen_k % {self.bn0} == 0' else: assert False @@ -324,7 +334,7 @@ def dcheck(self) -> str: vec = int((32 * 4) / DTYPE_BITS[self.dtype]) if self.dpad == 't': return f'a.hdim_q % {vec} == 0' else : assert False - elif self.pipeline_tag in ['qr']: + elif self.pipeline_tag in ['qr', 'qr_nwarp_sshuffle']: bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] if self.dpad == 't': return f'true /*a.hdim_q % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly) else : return f'a.hdim_q % {bk0submax} == 0' @@ -336,7 +346,7 @@ def dvcheck(self) -> str: vec = int((32 * 4) / DTYPE_BITS[self.dtype]) if self.dvpad == 't': return f'a.hdim_v % {vec} == 0' else : assert False - elif self.pipeline_tag in ['qr']: + elif self.pipeline_tag in ['qr', 'qr_nwarp_sshuffle']: bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] if self.dvpad == 't': return f'true /*a.hdim_v % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly) else : return f'a.hdim_v % {bk0submax} == 0' @@ -447,12 +457,11 @@ def api(self) -> str: @dataclass class FmhaFwdSplitKVCombineTileSize: - F_bm0 : int # tile size along q seqlen F_bn1 : int # tile size along v head_dim F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy @property def name(self) -> str: - return f"b{self.F_bm0}x{self.F_bn1}" +\ + return f"b{self.F_bn1}" +\ ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}") @dataclass @@ -485,9 +494,12 @@ def template(self) -> str: F_rm1 = self.F_tile.F_rm1, F_rn1 = self.F_tile.F_rn1, F_rk1 = self.F_tile.F_rk1, - F_wm = self.F_tile.F_wm, - F_wn = self.F_tile.F_wn, - F_wk = self.F_tile.F_wk, + F_wm0 = self.F_tile.F_wm0, + F_wn0 = self.F_tile.F_wn0, + F_wk0 = self.F_tile.F_wk0, + F_wm1 = self.F_tile.F_wm1, + F_wn1 = self.F_tile.F_wn1, + F_wk1 = self.F_tile.F_wk1, F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout], F_spad = BOOL_MAP[self.F_pipeline.F_spad], F_skpad = BOOL_MAP[self.F_pipeline.F_skpad], @@ -553,7 +565,6 @@ def template(self) -> str: F_idx = self.F_idx, F_hdim = self.F_hdim, F_dtype = FWD_DTYPE_MAP[self.F_dtype], - F_bm0 = self.F_tile.F_bm0, F_bn1 = self.F_tile.F_bn1, F_spad = BOOL_MAP[self.F_pipeline.F_spad], F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], @@ -577,17 +588,17 @@ def filename(self) -> str: def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: if dtype == 'fp16' or dtype == 'bf16': return { - '32' : FmhaFwdTileSize(32, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 16, 16, 16, -1), - '64' : FmhaFwdTileSize(64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, -1), - ## '96' : FmhaFwdTileSize(64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, -1), - '128' : FmhaFwdTileSize(64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, -1), - '256' : FmhaFwdTileSize(64, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, -1), + '32' : FmhaFwdTileSize(32, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1), + '64' : FmhaFwdTileSize(64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + ### '96' : FmhaFwdTileSize(64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + '128' : FmhaFwdTileSize(64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + '256' : FmhaFwdTileSize(64, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), } elif dtype == 'fp8' or dtype == 'bf8': return { - '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, -1), - '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, -1), - '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, -1) + '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1), + '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), + '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), } else: return None @@ -595,17 +606,17 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: def get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype : str) -> Optional[dict]: if dtype == 'fp16' or dtype == 'bf16': return { - '32' : FmhaFwdSplitKVCombineTileSize(16, 16, -1), - '64' : FmhaFwdSplitKVCombineTileSize(32, 32, -1), - ## '96' : FmhaFwdSplitKVCombineTileSize(32, 64, -1), - '128' : FmhaFwdSplitKVCombineTileSize(32, 64, -1), - '256' : FmhaFwdSplitKVCombineTileSize(32, 128, -1), + '32' : FmhaFwdSplitKVCombineTileSize(32, -1), + '64' : FmhaFwdSplitKVCombineTileSize(32, -1), + ### '96' : FmhaFwdSplitKVCombineTileSize(32, -1), + '128' : FmhaFwdSplitKVCombineTileSize(32, -1), + '256' : FmhaFwdSplitKVCombineTileSize(32, -1), } elif dtype == 'fp8' or dtype == 'bf8': return { - '64' : FmhaFwdSplitKVCombineTileSize(64, 32, -1), - '128' : FmhaFwdSplitKVCombineTileSize(64, 64, -1), - '256' : FmhaFwdSplitKVCombineTileSize(64, 128, -1), + '64' : FmhaFwdSplitKVCombineTileSize(32, -1), + '128' : FmhaFwdSplitKVCombineTileSize(32, -1), + '256' : FmhaFwdSplitKVCombineTileSize(32, -1), } else: return None diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index aee54b4758..0e821ed5d9 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -709,7 +709,6 @@ std::string fmha_fwd_splitkv_get_name_(); template ; static constexpr bool kIsGroupMode = kIsGroupMode_; - static constexpr ck_tile::index_t kM0 = kM0_; static constexpr ck_tile::index_t kN1 = kN1_; static constexpr bool kStoreLse = kStoreLse_; static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index bebf035e9c..107aae5516 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -1303,8 +1303,8 @@ CK_TILE_DEVICE thread_buffer amd_buffer_load_impl(int32x4_t src_wave_buffe static_assert( (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || + (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || diff --git a/include/ck_tile/core/tensor/static_distributed_tensor.hpp b/include/ck_tile/core/tensor/static_distributed_tensor.hpp index 568d618ec2..8d2f88af39 100644 --- a/include/ck_tile/core/tensor/static_distributed_tensor.hpp +++ b/include/ck_tile/core/tensor/static_distributed_tensor.hpp @@ -29,6 +29,7 @@ struct static_distributed_tensor remove_cvref_t; static constexpr index_t kThreadElementSpaceSize = ThreadTensorDesc{}.get_element_space_size(); + static_assert(0 < kThreadElementSpaceSize, "Make sure tile distribution is valid"); CK_TILE_HOST_DEVICE static constexpr auto get_num_of_dimension() { diff --git a/include/ck_tile/ops/fmha.hpp b/include/ck_tile/ops/fmha.hpp index e106264cef..7a09e4622d 100644 --- a/include/ck_tile/ops/fmha.hpp +++ b/include/ck_tile/ops/fmha.hpp @@ -29,6 +29,8 @@ #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp" diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 3de433d6a7..90102a6c6f 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -71,7 +71,8 @@ struct FmhaFwdKernel using bfs = typename FmhaPipeline::BlockFmhaShape; using g0br = typename bfs::Gemm0BlockWarps; using g1br = typename bfs::Gemm1BlockWarps; - using gwt = typename bfs::Gemm0WarpTile; + using g0wt = typename bfs::Gemm0WarpTile; + using g1wt = typename bfs::Gemm1WarpTile; #define _SS_ std::string #define _TS_ std::to_string auto pn = [&] () { @@ -88,7 +89,8 @@ struct FmhaFwdKernel _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kQKHeaddim) + "_" + "r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" + "r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) + "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" + - "w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" + + "w" + _TS_(g0wt::at(ck_tile::number<0>{})) + "x" + _TS_(g0wt::at(ck_tile::number<1>{})) + "x" + _TS_(g0wt::at(ck_tile::number<2>{})) + "_" + + "w" + _TS_(g1wt::at(ck_tile::number<0>{})) + "x" + _TS_(g1wt::at(ck_tile::number<1>{})) + "x" + _TS_(g1wt::at(ck_tile::number<2>{})) + "_" + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" + "v" + (std::is_same_v ? "r" : "c") + (pn.empty() ? "" : "_" + pn) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp index 0bccabdd2f..a0adfdc127 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp @@ -8,9 +8,11 @@ namespace ck_tile { template struct FmhaFwdSplitKVCombineKernel { - using TilePartitioner = remove_cvref_t; - using FmhaPipeline = remove_cvref_t; - using EpiloguePipeline = remove_cvref_t; + using TilePartitioner = remove_cvref_t; + using FmhaPipeline = remove_cvref_t; + using EpiloguePipeline = remove_cvref_t; + + static constexpr index_t kNumWarps = FmhaPipeline::kNumWarps; static constexpr index_t kBlockSize = FmhaPipeline::kBlockSize; static constexpr index_t kBlockPerCu = FmhaPipeline::kBlockPerCu; static_assert(kBlockPerCu > 0); @@ -50,8 +52,7 @@ struct FmhaFwdSplitKVCombineKernel return _SS_("fmha_fwd_splitkv_combine_d") + _TS_(FmhaPipeline::kHeadDimV) + "_" + _SS_(t2s::name) + "_" + (kIsGroupMode ? "group" : "batch") + "_" - "b" + _TS_(FmhaPipeline::kM0) + "x" + - _TS_(FmhaPipeline::kN1) + "_" + + "b" + _TS_(FmhaPipeline::kN1) + "_" + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + (pn.empty() ? "" : "_" + pn) + @@ -339,37 +340,56 @@ struct FmhaFwdSplitKVCombineKernel number{}, number<1>{}); + // read 4 * (kM0, kN1) o_acc tiles simultaneously by 4 warps const auto o_acc_dram_view = pad_tensor_view( o_acc_dram_naive, - make_tuple(number<1>{}, number{}, number{}), - sequence{}); + make_tuple( + number{}, number{}, number{}), + sequence{}); + const index_t padded_num_splits = + o_acc_dram_view.get_tensor_descriptor().get_lengths()[number<0>{}]; const index_t padded_seqlen_q = o_acc_dram_view.get_tensor_descriptor().get_lengths()[number<1>{}]; const index_t padded_hdim_v = o_acc_dram_view.get_tensor_descriptor().get_lengths()[number<2>{}]; - return transform_tensor_view( + const index_t num_m_tiles = integer_divide_floor(padded_seqlen_q, FmhaPipeline::kM0); + + // transform tensor view by following steps, given shape: (padded_num_splits, + // padded_seqlen_q, padded_hdim_v) + // 1. unmerge to (padded_num_splits, num_m_tiles, kM0, padded_hdim_v) + // 2. transpose to (num_m_tiles, padded_num_splits, kM0, padded_hdim_v) + // 3. merge to (num_m_tiles * padded_num_splits * kM0, padded_hdim_v) + auto transposed = transform_tensor_view( o_acc_dram_view, - make_tuple(make_merge_transform(make_tuple(kargs.num_splits, padded_seqlen_q)), + make_tuple(make_pass_through_transform(padded_num_splits), + make_unmerge_transform(make_tuple(num_m_tiles, FmhaPipeline::kM0)), make_pass_through_transform(padded_hdim_v)), - make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<1>{}, sequence<0, 2>{}, sequence<3>{})); + + return transform_tensor_view( + transposed, + make_tuple(make_merge_transform( + make_tuple(num_m_tiles, padded_num_splits, FmhaPipeline::kM0)), + make_pass_through_transform(padded_hdim_v)), + make_tuple(sequence<0, 1, 2>{}, sequence<3>{}), make_tuple(sequence<0>{}, sequence<1>{})); }(); auto lse_acc_dram_window = make_tile_window( lse_acc_dram, - [&]() { - return make_tuple(number{}, number{}); - }(), + make_tuple(number{}, number{}), {0, i_m0}); + const index_t padded_num_splits = + integer_divide_ceil(kargs.num_splits, kNumWarps) * kNumWarps; + auto o_acc_dram_window = make_tile_window( o_acc_dram, - [&]() { - return make_tuple(number{}, number{}); - }(), - {i_m0, i_n1}); + make_tuple(number{}, number{}), + {i_tile_m * padded_num_splits * FmhaPipeline::kM0, i_n1}); // LSE DRAM window auto lse_dram_window = [&, i_nhead_ = i_nhead]() { @@ -410,7 +430,6 @@ struct FmhaFwdSplitKVCombineKernel identity{}, // lse_element_func composes(saturates{}, scales{kargs.scale_o}), // o_acc_element_func kargs.num_splits, - kargs.seqlen_q, smem_ptr); } else @@ -419,7 +438,6 @@ struct FmhaFwdSplitKVCombineKernel o_acc_dram_window, lse_dram_window, kargs.num_splits, - kargs.seqlen_q, smem_ptr); } }(); diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp index f37e676da0..dc17487262 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp @@ -45,6 +45,7 @@ struct FmhaFwdSplitKVKernel static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ; static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV; static constexpr auto BiasEnum = FmhaPipeline::BiasEnum; + static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE; static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant; static constexpr bool kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV; @@ -67,7 +68,8 @@ struct FmhaFwdSplitKVKernel using bfs = typename FmhaPipeline::BlockFmhaShape; using g0br = typename bfs::Gemm0BlockWarps; using g1br = typename bfs::Gemm1BlockWarps; - using gwt = typename bfs::Gemm0WarpTile; + using g0wt = typename bfs::Gemm0WarpTile; + using g1wt = typename bfs::Gemm1WarpTile; #define _SS_ std::string #define _TS_ std::to_string auto pn = [&] () { @@ -84,11 +86,12 @@ struct FmhaFwdSplitKVKernel _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kQKHeaddim) + "_" + "r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" + "r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) + "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" + - "w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" + + "w" + _TS_(g0wt::at(ck_tile::number<0>{})) + "x" + _TS_(g0wt::at(ck_tile::number<1>{})) + "x" + _TS_(g0wt::at(ck_tile::number<2>{})) + "_" + + "w" + _TS_(g1wt::at(ck_tile::number<0>{})) + "x" + _TS_(g1wt::at(ck_tile::number<1>{})) + "x" + _TS_(g1wt::at(ck_tile::number<2>{})) + "_" + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" + "v" + (std::is_same_v ? "r" : "c") + (pn.empty() ? "" : "_" + pn) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + - (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kDoFp8StaticQuant ? "_squant" : "") + (kIsPagedKV ? "_pagedkv" : "" ); + (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kStoreLSE ? "_lse" : "" ) + (kDoFp8StaticQuant ? "_squant" : "") + (kIsPagedKV ? "_pagedkv" : "" ); #undef _SS_ #undef _TS_ // clang-format on diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp index 7c49fce99a..7ac86e6d12 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp @@ -53,6 +53,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline using OaccDataType = remove_cvref_t; using ODataType = remove_cvref_t; + static constexpr index_t kNumWarps = Problem::kNumWarps; static constexpr index_t kBlockSize = Problem::kBlockSize; static constexpr index_t kHeadDimV = Problem::kHeadDimV; @@ -117,7 +118,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline const LSEElementFunction& lse_element_func, const OaccElementFunction& o_acc_element_func, index_t num_splits, - index_t seqlen_q, void* smem_ptr) const { // lse_acc tile in LDS @@ -143,11 +143,12 @@ struct BlockFmhaFwdSplitKVCombinePipeline // copy lse_acc tile (shape=[kMaxSplits, kM0]) to LDS (shape=[kMaxSplits, kM0]). auto lse_acc_tile = load_tile(lse_acc_dram_window); store_tile(lse_acc_lds_write_window, lse_acc_tile); - block_sync_lds(); auto lse_accum = make_static_distributed_tensor( Policy::template MakeLSEaccRegTileDistribution()); + __builtin_amdgcn_sched_barrier(0); + block_sync_lds(); // copy LDS (shape=[kM0, kMaxSplits]) to lse_accum (shape=[kM0, kMaxSplits]) // and fill up -INF values outside the [kM0, num_splits] region. { @@ -264,46 +265,94 @@ struct BlockFmhaFwdSplitKVCombinePipeline } }); } - block_sync_lds(); if constexpr(kStoreLSE) { store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse_logsum)); } - auto o_acc_dist = Policy::template MakeOaccDramTileDistribution(); - auto o_acc_dram_window = + auto o_acc_4_dist = Policy::template MakeOacc4DramTileDistribution(); + auto o_acc_4_dram_window = make_tile_window(o_acc_dram_block_window_tmp.get_bottom_tensor_view(), o_acc_dram_block_window_tmp.get_window_lengths(), o_acc_dram_block_window_tmp.get_window_origin(), - o_acc_dist); - auto o_acc = make_static_distributed_tensor(o_acc_dist); - clear_tile(o_acc); + o_acc_4_dist); - const index_t padded_seqlen_q = integer_divide_ceil(seqlen_q, kM0) * kM0; + // shape=[4 * KM0, kN1] + auto o_acc_4 = make_static_distributed_tensor(o_acc_4_dist); + clear_tile(o_acc_4); - for(index_t i_split = 0; i_split < num_splits; ++i_split) + const index_t padded_num_splits = integer_divide_ceil(num_splits, kNumWarps) * kNumWarps; + + __builtin_amdgcn_sched_barrier(0); + block_sync_lds(); + // each warp handles a [KM0, kN1] tile + for(index_t split_start = 0; split_start < padded_num_splits; split_start += kNumWarps) { - auto o_tile = load_tile(o_acc_dram_window); + auto o_tile = load_tile(o_acc_4_dram_window); + const index_t i_split = split_start + get_warp_id(); + const index_t row_start = kM0 * get_warp_id(); { - constexpr auto spans = decltype(o_acc)::get_distributed_spans(); + constexpr auto spans = decltype(o_acc_4)::get_distributed_spans(); sweep_tile_span(spans[number<0>{}], [&](auto idx0) { sweep_tile_span(spans[number<1>{}], [&](auto idx1) { constexpr auto i_j_idx = make_tuple(idx0, idx1); const auto x_indices = get_x_indices_from_distributed_indices( - o_acc.get_tile_distribution(), i_j_idx); + o_acc_4.get_tile_distribution(), i_j_idx); const auto row = x_indices.at(number<0>{}); - const LSEDataType lse_scale = lse_acc_lds(row, i_split); - o_acc(i_j_idx) += lse_scale * o_tile(i_j_idx); + const LSEDataType lse_scale = lse_acc_lds(row - row_start, i_split); + o_acc_4(i_j_idx) += lse_scale * o_tile(i_j_idx); }); }); } - move_tile_window(o_acc_dram_window, {padded_seqlen_q, 0}); + move_tile_window(o_acc_4_dram_window, {kNumWarps * kM0, 0}); + } + + // 4 o_acc tiles in LDS. shape=[4 * kM0, kN1] + OaccDataType* o_acc_4_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeLSEacc())); + + { + auto o_acc_4_lds_window = [&]() { + auto desc = Policy::template MakeOacc4LdsBlockDescriptor(); + auto view = make_tensor_view(o_acc_4_lds_ptr, desc); + return make_tile_window(view, desc.get_lengths(), {0, 0}); + }(); + store_tile(o_acc_4_lds_window, o_acc_4); } + auto o_acc_dist = Policy::template MakeOaccDramTileDistribution(); + + auto o_acc_4_lds_window = [&]() { + auto desc = Policy::template MakeOacc4LdsBlockDescriptor(); + auto view = make_tensor_view(o_acc_4_lds_ptr, desc); + return make_tile_window(view, desc.get_lengths(), {0, 0}, o_acc_dist); + }(); + + auto o_acc = make_static_distributed_tensor(o_acc_dist); + clear_tile(o_acc); + + __builtin_amdgcn_sched_barrier(0); + block_sync_lds(); + static_for<0, kNumWarps, 1>{}([&](auto) { + auto o_acc_in = load_tile(o_acc_4_lds_window); + + { + constexpr auto spans = decltype(o_acc)::get_distributed_spans(); + sweep_tile_span(spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + o_acc(i_j_idx) += o_acc_in(i_j_idx); + }); + }); + } + + move_tile_window(o_acc_4_lds_window, {kM0, 0}); + }); + o_acc = tile_elementwise_in(o_acc_element_func, o_acc); return o_acc; @@ -316,7 +365,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline const OaccDramBlockWindow& o_acc_dram_block_window, LSEDramBlockWindow& lse_dram_block_window, index_t num_splits, - index_t seqlen_q, void* smem_ptr) const { return operator()(lse_acc_dram_block_window, @@ -325,7 +373,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline identity{}, identity{}, num_splits, - seqlen_q, smem_ptr); } }; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp index ebd69c0cf8..2d4abb3888 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp @@ -10,23 +10,38 @@ namespace ck_tile { struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy { - template + template + CK_TILE_HOST_DEVICE static constexpr auto GetMaxNumWarpsForTile() + { + static_assert(NumWarps == 1 || NumWarps == 2 || NumWarps == 4); + + constexpr index_t ElemPerThread = (M * N) / (NumWarps * get_warp_size()); + if constexpr(0 < ElemPerThread) + { + return NumWarps; + } + else + { // try dividing tile by smaller # of warps + return GetMaxNumWarpsForTile(); + } + } + + template CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeForTile() { - constexpr index_t PixelsPerThread = (M * N) / BlockSize; - static_assert(0 < PixelsPerThread); + constexpr index_t MaxNumWarps = GetMaxNumWarpsForTile(); - constexpr index_t MaxNPerThread = 16 / sizeof(DataType); - constexpr index_t NPerThread = min(MaxNPerThread, PixelsPerThread); + constexpr index_t ElemPerThread = (M * N) / (MaxNumWarps * get_warp_size()); - return NPerThread; + constexpr index_t MaxNPerThread = 16 / sizeof(DataType); + return min(MaxNPerThread, ElemPerThread); } // alignment for dram lse tile (shape=[kMaxSplits, kM0]) template CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentLSE() { - return GetVectorSizeForTile(); @@ -56,40 +71,54 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy } template - CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeLSEacc() { return sizeof(typename Problem::LSEDataType) * MakeLSEaccLdsBlockDescriptor().get_element_space_size(); } + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeOacc4() + { + return sizeof(typename Problem::OaccDataType) * + MakeOacc4LdsBlockDescriptor().get_element_space_size(); + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return GetSmemSizeLSEacc() + GetSmemSizeOacc4(); + } + // shape=[kMaxSplits, kM0] template CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccDramTileDistribution() { using LSEDataType = remove_cvref_t; - constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kNumWarps = Problem::kNumWarps; - - constexpr index_t kNPerBlock = Problem::kM0; constexpr index_t kMPerBlock = Problem::kMaxSplits; + constexpr index_t kNPerBlock = Problem::kM0; + + constexpr index_t MaxNumWarps = + GetMaxNumWarpsForTile(); + constexpr index_t Replicate = Problem::kNumWarps / MaxNumWarps; constexpr index_t NPerThread = - GetVectorSizeForTile(); + GetVectorSizeForTile(); constexpr index_t NThreads = kNPerBlock / NPerThread; constexpr index_t MThreadsPerWarp = get_warp_size() / NThreads; - constexpr index_t MPerThread = kMPerBlock / (kNumWarps * MThreadsPerWarp); + constexpr index_t MPerThread = kMPerBlock / (MaxNumWarps * MThreadsPerWarp); + static_assert(MPerThread * MaxNumWarps * MThreadsPerWarp == kMPerBlock); static_assert(NThreads * NPerThread == kNPerBlock); - static_assert(MPerThread * kNumWarps * MThreadsPerWarp == kMPerBlock); return make_static_tile_distribution( - tile_distribution_encoding, - tuple, + tile_distribution_encoding, + tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<2, 0>>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, sequence<1, 2>, sequence<0, 1>>{}); } @@ -100,17 +129,15 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy { using LSEDataType = remove_cvref_t; - constexpr index_t kBlockSize = Problem::kBlockSize; - - constexpr index_t kMPerBlock = Problem::kMaxSplits; - constexpr index_t kNPerBlock = Problem::kM0; + constexpr index_t kMPerBlock = Problem::kM0; + constexpr index_t kNPerBlock = Problem::kMaxSplits; constexpr index_t NPack = - GetVectorSizeForTile(); + GetVectorSizeForTile(); constexpr auto lse_acc_lds_block_desc_0 = make_naive_tensor_descriptor( make_tuple(number{}, number{}, number{}), make_tuple(number<(kMPerBlock + 1) * NPack>{}, number{}, number<1>{}), - number<8>{}, + number{}, number<1>{}); constexpr auto lse_acc_lds_block_desc = transform_tensor_descriptor( @@ -129,17 +156,15 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy { using LSEDataType = remove_cvref_t; - constexpr index_t kBlockSize = Problem::kBlockSize; - - constexpr index_t kMPerBlock = Problem::kMaxSplits; - constexpr index_t kNPerBlock = Problem::kM0; + constexpr index_t kMPerBlock = Problem::kM0; + constexpr index_t kNPerBlock = Problem::kMaxSplits; constexpr index_t NPack = - GetVectorSizeForTile(); + GetVectorSizeForTile(); constexpr auto lse_acc_lds_block_desc_0 = make_naive_tensor_descriptor( make_tuple(number{}, number{}, number{}), make_tuple(number<(kMPerBlock + 1) * NPack>{}, number{}, number<1>{}), - number<8>{}, + number{}, number<1>{}); constexpr auto lse_acc_t_lds_block_desc = transform_tensor_descriptor( @@ -152,33 +177,86 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy return lse_acc_t_lds_block_desc; } + // 3d + padding, shape=[4 * kM0, kN1] template - CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccRegTileDistribution() + CK_TILE_HOST_DEVICE static constexpr auto MakeOacc4LdsBlockDescriptor() { - constexpr index_t kBlockSize = Problem::kBlockSize; + using LSEDataType = remove_cvref_t; - constexpr index_t kNPerBlock = Problem::kMaxSplits; + constexpr index_t kMPerBlock = 4 * Problem::kM0; + constexpr index_t kNPerBlock = Problem::kN1; + constexpr index_t NPack = + GetVectorSizeForTile(); + + constexpr auto o_acc_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, number{}, number{}), + make_tuple(number<(kMPerBlock + 1) * NPack>{}, number{}, number<1>{}), + number<8>{}, + number<1>{}); + + constexpr auto o_acc_t_lds_block_desc = transform_tensor_descriptor( + o_acc_lds_block_desc_0, + make_tuple(make_pass_through_transform(kMPerBlock), + make_merge_transform(make_tuple(kNPerBlock / NPack, NPack))), + make_tuple(sequence<1>{}, sequence<0, 2>{}), + make_tuple(sequence<1>{}, sequence<0>{})); + + return o_acc_t_lds_block_desc; + } + + // shape=[kM0, kMaxSplits] + template + CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccRegTileDistribution() + { constexpr index_t kMPerBlock = Problem::kM0; + constexpr index_t kNPerBlock = Problem::kMaxSplits; - constexpr index_t NThreads = 4; - constexpr index_t NPerThread = kNPerBlock / NThreads; + constexpr index_t MaxNThreads = 8; + constexpr index_t NThreads = min(kNPerBlock, MaxNThreads); + constexpr index_t NPerThread = kNPerBlock / NThreads; - constexpr index_t MThreads = kBlockSize / NThreads; - constexpr index_t MPerThread = kMPerBlock / MThreads; - constexpr index_t MWarps = kBlockSize / get_warp_size(); + constexpr index_t MPerThread = 1; + constexpr index_t MThreads = kMPerBlock / MPerThread; constexpr index_t MThreadPerWarp = get_warp_size() / NThreads; + constexpr index_t MaxNumWarps = (MThreads * NThreads) / get_warp_size(); + constexpr index_t Replicate = Problem::kNumWarps / MaxNumWarps; + + static_assert(MaxNumWarps * MThreadPerWarp * MPerThread == kMPerBlock); static_assert(NThreads * NPerThread == kNPerBlock); - static_assert(MWarps * MThreadPerWarp * MPerThread == kMPerBlock); return make_static_tile_distribution( - tile_distribution_encoding< - sequence<1>, - tuple, sequence>, - tuple, sequence<2, 1>>, - tuple, sequence<0, 1>>, - sequence<1, 2>, - sequence<2, 1>>{}); + tile_distribution_encoding, + tuple, + sequence>, + tuple, sequence<2, 1>>, + tuple, sequence<0, 1>>, + sequence<1, 2>, + sequence<2, 1>>{}); + } + + // similar to MakeOaccDramTileDistribution(), but duplicate same 1-warp encoding 4 times on M + // direction + template + CK_TILE_HOST_DEVICE static constexpr auto MakeOacc4DramTileDistribution() + { + constexpr index_t kMPerBlock = Problem::kM0; // real kMPerBlock we want is (4 * kM0) + constexpr index_t kNPerBlock = Problem::kN1; + static_assert(get_warp_size() <= kMPerBlock * kNPerBlock); + + constexpr index_t M1 = 1; // compose encoding base on 1 warp + constexpr index_t M2 = min(kMPerBlock / M1, get_warp_size()); + constexpr index_t N0 = get_warp_size() / M2; + constexpr index_t N1 = kNPerBlock / N0; + constexpr index_t M0 = kMPerBlock / (M2 * M1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<3, 0>>, + sequence<1, 2>, + sequence<1, 1>>{}); } template @@ -187,6 +265,7 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kMPerBlock = Problem::kM0; constexpr index_t kNPerBlock = Problem::kN1; + static_assert(kBlockSize <= kMPerBlock * kNPerBlock); constexpr index_t M1 = kBlockSize / get_warp_size(); constexpr index_t M2 = min(kMPerBlock / M1, get_warp_size()); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp new file mode 100644 index 0000000000..3726cd433c --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp @@ -0,0 +1,794 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp" +#include "ck_tile/ops/reduce/block/block_reduce.hpp" + +namespace ck_tile { + +// This pipeline is qkv all located in LDS +template +struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS +{ + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using SaccDataType = remove_cvref_t; + using SMPLComputeDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using PDataType = remove_cvref_t; + using OaccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using FmhaMask = remove_cvref_t; + + using BlockFmhaShape = remove_cvref_t; + using VLayout = remove_cvref_t; + static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once + static_assert(kQLoadOnce == Policy::QLoadOnce); + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t kM0 = BlockFmhaShape::kM0; + static constexpr index_t kN0 = BlockFmhaShape::kN0; + static constexpr index_t kK0 = BlockFmhaShape::kK0; + static constexpr index_t kN1 = BlockFmhaShape::kN1; + static constexpr index_t kK1 = BlockFmhaShape::kK1; + static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; + static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim; + + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + static constexpr auto BiasEnum = Problem::BiasEnum; + static constexpr bool kStoreLSE = Problem::kStoreLSE; + static constexpr bool kIsPagedKV = Problem::kIsPagedKV; + static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits; + + // last dimension vector length used to create tensor view(and decide buffer_load vector length) + // ... together with tensor distribution. tensor dist should able to overwrite this + static constexpr index_t kAlignmentQ = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ(); + static constexpr index_t kAlignmentK = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentK(); + static constexpr index_t kAlignmentV = []() { + if constexpr(std::is_same_v) + return kPadHeadDimV ? 1 : Policy::template GetAlignmentV(); + else + return kPadSeqLenK ? 1 : Policy::template GetAlignmentV(); + }(); + + static constexpr index_t kAlignmentOacc = + kPadHeadDimV ? 1 : Policy::template GetAlignmentOacc(); + + static constexpr index_t kAlignmentBias = + kPadSeqLenK ? 1 : Policy::template GetAlignmentBias(); + + static constexpr index_t kBlockPerCu = []() { + if constexpr(Problem::kBlockPerCu != -1) + return Problem::kBlockPerCu; + else + { + if constexpr(kQKHeaddim <= 32) + { + return 2; + } + else if constexpr(kQKHeaddim <= 64) + { + return 3; + } + else if constexpr(kQKHeaddim <= 128) + { + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + return 1; + else + return 2; + } + else if constexpr(kQKHeaddim <= 256) + { + return 1; + } + } + }(); + + static constexpr const char* name = "qr_nwarp_sshuffle"; + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const QElementFunction& q_element_func, + const KDramBlockWindowLengths& k_dram_block_window_lengths, // N0*K0 tile + const KPageBlockNavigator& k_page_block_navigator, + const KElementFunction& k_element_func, + const VDramBlockWindowLengths& v_dram_block_window_lengths, // N1*K1 tile + const VPageBlockNavigator& v_page_block_navigator, + const VElementFunction& v_element_func, + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + const BiasElementFunction& bias_element_func, + LSEaccDramBlockWindowTmp& lse_acc_dram_window_tmp, // M0*1 tile + const LSEaccElementFunction& lse_acc_element_func, + const SAccElementFunction& s_acc_element_func, + const PComputeElementFunction& p_compute_element_func, + const OAccElementFunction& o_acc_element_func, + index_t num_splits, + index_t i_split, + FmhaMask mask, + PositionEncoding position_encoding, + float scale_s, + index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate + void* smem_ptr) const + { + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kSubQKHeaddim == + QDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kN0 == KDramBlockWindowLengths{}[number<0>{}] && + kK0 == KDramBlockWindowLengths{}[number<1>{}] && + kN1 == VDramBlockWindowLengths{}[number<0>{}] && + kK1 == VDramBlockWindowLengths{}[number<1>{}] && + kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); + // Q tile in LDS + QDataType* q_lds_ptr = + static_cast(static_cast(static_cast(smem_ptr))); + auto q_lds = make_tensor_view( + q_lds_ptr, Policy::template MakeQLdsBlockDescriptor()); + + // K tile in LDS + KDataType* k_lds_ptr = + static_cast(static_cast(static_cast(smem_ptr))); + auto k_lds = make_tensor_view( + k_lds_ptr, Policy::template MakeKLdsBlockDescriptor()); + auto k_lds_window = + make_tile_window(k_lds, make_tuple(number{}, number{}), {0, 0}); + + // V tile in LDS + auto v_lds = make_tensor_view( + reinterpret_cast(static_cast(smem_ptr) + + max(Policy::template GetSmemSizeQ(), + Policy::template GetSmemSizeK())), + Policy::template MakeVLdsBlockDescriptor()); + auto v_lds_window = make_tile_window( + v_lds, Policy::template MakeVLdsBlockDescriptor().get_lengths(), {0, 0}); + + // S tile in LDS + auto s_lds = make_tensor_view( + reinterpret_cast(reinterpret_cast(smem_ptr) + + max(Policy::template GetSmemSizeQ(), + Policy::template GetSmemSizeK())), + Policy::template MakeSLdsBlockDescriptor()); + auto s_write_lds_window = make_tile_window( + s_lds, Policy::template MakeSLdsBlockDescriptor().get_lengths(), {0, 0}); + auto s_read_lds_window = + make_tile_window(s_lds, + Policy::template MakeSLdsBlockDescriptor().get_lengths(), + {0, 0}, + Policy::template MakeSRegTileDistribution()); + + // Block GEMM + constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); + constexpr auto gemm_1 = Policy::template GetKVBlockGemm(); + + auto q_dram_window = + make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), + q_dram_block_window_tmp.get_window_lengths(), + q_dram_block_window_tmp.get_window_origin(), + Policy::template MakeQDramTileDistribution()); + + // load Q here, will store Q into LDS to maximize throughput + auto origin_q = load_tile(q_dram_window); + + using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile()); + auto s_acc = SaccBlockTileType{}; + + // reduction function for softmax + const auto f_max = [](auto e0, auto e1) { return max(e0, e1); }; + const auto f_sum = [](auto e0, auto e1) { return e0 + e1; }; + + using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile()); + + auto o_acc = OaccBlockTileType{}; + + // infer Sacc, S, P, M, L, Oacc type + using SBlockTileType = decltype(cast_tile(o_acc)); + + using MLBlockTileType = decltype(block_tile_reduce( + SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0})); + + // init M, L + auto m = MLBlockTileType{}; + auto l = MLBlockTileType{}; + + clear_tile(o_acc); + set_tile(m, -numeric::infinity()); + clear_tile(l); + + const auto q_origin = q_dram_window.get_window_origin(); + const auto [logical_seqlen_k_start, logical_seqlen_k_end] = mask.GetTileRangeAlongX( + q_origin.at(number<0>{}), number{}, number{}, num_splits, i_split); + + // check early exit if no work to do + if constexpr(FmhaMask::IsMasking || kPadSeqLenK || kHasUnevenSplits) + { + const index_t logical_num_total_loop = + integer_divide_ceil(logical_seqlen_k_end - logical_seqlen_k_start, kN0); + if(logical_num_total_loop <= 0) + { + if constexpr(kStoreLSE) + { + auto lse_acc = + make_static_distributed_tensor(m.get_tile_distribution()); + + set_tile(lse_acc, -numeric::infinity()); + + if(get_thread_local_1d_id() < kM0) + { + store_tile(lse_acc_dram_window_tmp, + tile_elementwise_in(lse_acc_element_func, lse_acc)); + } + } + + // Note: here occ are all cleard, return it + // Note: q loaded but no fence, ignore it. + return o_acc; + } + } + + const index_t physical_seqlen_k_start = logical_seqlen_k_start + kv_l2p_offset; + const index_t physical_seqlen_k_end = logical_seqlen_k_end + kv_l2p_offset; + // make sure the first tile is completely located in page-block (page-block size should be + // divisible by kN0) + // relationship between each *_start variables: aligned_physical_seqlen_k_start <= + // physical_seqlen_k_start, logical_seqlen_k_start <= physical_seqlen_k_start + const index_t aligned_physical_seqlen_k_start = + [&, physical_seqlen_k_start_ = physical_seqlen_k_start] { + if constexpr(kIsPagedKV) + { + return kN0 * integer_divide_floor(physical_seqlen_k_start_, kN0); + } + else + { + return physical_seqlen_k_start_; + } + }(); + const index_t num_total_loop = + integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0); + + auto [i_page_block_k, k_dram_block_window] = k_page_block_navigator.make_tile_window( + k_dram_block_window_lengths, {aligned_physical_seqlen_k_start, 0}); + + const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); + auto bias_dram_window = + make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), + bias_dram_block_window_tmp.get_window_lengths(), + {bias_origin.at(number<0>{}), + logical_seqlen_k_start - (physical_seqlen_k_start - + aligned_physical_seqlen_k_start)}, // M/N + Policy::template MakeBiasDramTileDistribution()); + + auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window( + v_dram_block_window_lengths, + {0, aligned_physical_seqlen_k_start}, // TODO: hdim split? + Policy::template MakeVDramTileDistribution()); + + // store Q into LDS + __builtin_amdgcn_sched_barrier(0); + auto q_lds_window_for_store = make_tile_window( + q_lds, Policy::template MakeQLdsBlockDescriptor().get_lengths(), {0, 0}); + + store_tile(q_lds_window_for_store, origin_q); + __builtin_amdgcn_sched_barrier(0); + + // load Q from LDS + __builtin_amdgcn_sched_barrier(0); + auto q_lds_window_for_load = make_tile_window( + q_lds, + Policy::template MakeQLdsBlockDescriptor().get_lengths(), + {0, 0}, + Policy::template MakeQRegTileDistribution()); + block_sync_lds(); + auto q = load_tile(q_lds_window_for_load); + __builtin_amdgcn_sched_barrier(0); + auto q_tile = tile_elementwise_in(q_element_func, q); + + // prefetch K tile + index_t i_total_loops = 0; + constexpr index_t k0_loops = kQKHeaddim / kK0; + constexpr index_t k1_loops = kN0 / kK1; + + static_assert(2 <= k0_loops); + static_assert(1 <= k1_loops); + + auto k_dram_window = make_tile_window( + k_dram_block_window, + Policy::template MakeKDramTileDistribution()); // K DRAM tile window for + + // load the first tile of the first iteration and store to LDS + auto k_block_tile = load_tile(k_dram_window); + // moving k_dram_window is an in-page-block operation, so there is + // no need to invoke k_page_block_navigator.move_tile_window() here. + move_tile_window(k_dram_window, {0, kK0}); + store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile)); + + do + { + // STAGE 1, QK gemm + clear_tile(s_acc); // initialize C + + // load the second tile of the first iteration + k_block_tile = load_tile(k_dram_window); + + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + __builtin_amdgcn_sched_barrier( + 0); // prevent from messing up the order of global loads + } + const auto bias_tile = load_tile(bias_dram_window); // load bias tile + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + __builtin_amdgcn_sched_barrier( + 0); // prevent from messing up the order of global loads + } + + if constexpr(k0_loops > 2) + { + static_for<0, k0_loops - 2, 1>{}([&](auto i_k0) { + block_sync_lds(); + gemm_0(s_acc, + get_slice_tile(q_tile, + sequence<0, i_k0 * kK0>{}, + sequence{}), + k_lds_window); + block_sync_lds(); + move_tile_window(k_dram_window, {0, kK0}); + + store_tile( + k_lds_window, + tile_elementwise_in(k_element_func, k_block_tile)); // LDS write i + 1 + k_block_tile = load_tile(k_dram_window); // global read i + 2 + }); + } + + const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile + { // tail + block_sync_lds(); + gemm_0(s_acc, + get_slice_tile(q_tile, + sequence<0, (k0_loops - 2) * kK0>{}, + sequence{}), + k_lds_window); + block_sync_lds(); + + store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile)); + block_sync_lds(); + + gemm_0(s_acc, + get_slice_tile(q_tile, + sequence<0, (k0_loops - 1) * kK0>{}, + sequence{}), + k_lds_window); + } + + // STAGE 2, scale_s, add bias, mask, softmax + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); + tile_elementwise_inout( + [&](auto& x, const auto& y) { +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + x += type_convert(bias_element_func(y)); +#else + x += log2e_v * + type_convert(bias_element_func(y)); +#endif + }, + s_acc, + bias_tile); + } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + const auto k_origin = k_page_block_navigator.to_global_window_origin( + i_page_block_k, k_dram_block_window.get_window_origin()); + constexpr auto s_spans = decltype(s_acc)::get_distributed_spans(); + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { + const auto tile_idx = get_x_indices_from_distributed_indices( + s_acc.get_tile_distribution(), make_tuple(idx0, idx1)); + + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + s_acc(i_j_idx) *= scale_s; + // position_encoding accept only logical coordinates, do conversion here + position_encoding.update(s_acc(i_j_idx), row, col - kv_l2p_offset); + }); + }); + } + else + { + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); +#endif + } + move_tile_window(bias_dram_window, {0, kN0}); + + /// TODO: only check in first/last iteration without increasing code size + if constexpr(kHasUnevenSplits) + { + const auto k_origin = k_page_block_navigator.to_global_window_origin( + i_page_block_k, k_dram_block_window.get_window_origin()); + set_tile_if( + s_acc, + -numeric::infinity(), + [&, + physical_seqlen_k_start_ = physical_seqlen_k_start, + physical_seqlen_k_end_ = physical_seqlen_k_end](auto tile_idx) { + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + if constexpr(kIsPagedKV) + { + return col < physical_seqlen_k_start_ || physical_seqlen_k_end_ <= col; + } + else + { + return physical_seqlen_k_end_ <= col; + } + }); + } + + if constexpr(kPadSeqLenK || FmhaMask::IsMasking) + { + const auto k_origin = k_page_block_navigator.to_global_window_origin( + i_page_block_k, k_dram_block_window.get_window_origin()); + // mask accept only logical coordinates, do conversion here + bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}), + k_origin.at(number<0>{}) - kv_l2p_offset, + number{}, + number{}); + if(need_perpixel_check) + { + set_tile_if( + s_acc, -numeric::infinity(), [&](auto tile_idx) { + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + return mask.IsOutOfBound(row, col - kv_l2p_offset); + }); + } + } + + __builtin_amdgcn_sched_barrier(0); + + // load the first tile for next iteration + if(i_total_loops < num_total_loop - 1) + { + // move K tile windows + i_page_block_k = k_page_block_navigator.move_tile_window( + i_page_block_k, k_dram_block_window, {kN0, 0}); + + k_dram_window = make_tile_window( + k_dram_block_window, + Policy::template MakeKDramTileDistribution()); // K DRAM tile window + + // laod the first tile of the first iteration and store to LDS + k_block_tile = load_tile(k_dram_window); + } + + __builtin_amdgcn_sched_barrier(0); + + const auto s = cast_tile(s_acc); // S{j} + + // shuffle through LDS so that the tile layout is consistent with required by Gemm1 + store_tile(s_write_lds_window, s); + block_sync_lds(); + auto s_new = load_tile(s_read_lds_window); + + auto m_local = block_tile_reduce( + s_new, + sequence<1>{}, + f_max, + -numeric::infinity()); // m_local = rowmax(S{j}) + block_tile_reduce_sync(m_local, f_max, bool_constant{}); + + const auto m_old = m; // m{j-1} + tile_elementwise_inout( + [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j} + + auto p_compute = make_static_distributed_tensor( + s_new.get_tile_distribution()); // Pcompute{j} + + static const auto get_validated_m = [](SMPLComputeDataType raw_m) { + /// NOTICE: bias might be materialized mask including -inf values, need + /// consideration + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + FmhaMask::IsMasking) + { + return raw_m == -numeric::infinity() + ? type_convert(0.f) + : raw_m; + } + else + { + return raw_m; + } + }; + + constexpr auto p_spans = decltype(p_compute)::get_distributed_spans(); + sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + auto row_max = scale_s * get_validated_m(m[i_idx]); +#endif + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + p_compute(i_j_idx) = exp2(s_new[i_j_idx] - get_validated_m(m[i_idx])); + } + else + { + p_compute(i_j_idx) = exp2(scale_s * s_new[i_j_idx] - row_max); + } +#else + p_compute(i_j_idx) = exp(s_new[i_j_idx] - get_validated_m(m[i_idx])); +#endif + }); + }); + + auto rowsum_p = block_tile_reduce( + p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j}) + + block_tile_reduce_sync(rowsum_p, f_sum, bool_constant{}); + + const auto p = + cast_tile(tile_elementwise_in(p_compute_element_func, p_compute)); + + // l{j}, Oacc{j} + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + const auto tmp = [&]() { + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); + } + else + { + auto row_max = scale_s * get_validated_m(m[i_idx]); + return exp2(scale_s * m_old[i_idx] - row_max); + } + }(); +#else + const auto tmp = exp(m_old[i_idx] - get_validated_m(m[i_idx])); +#endif + l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]; + sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + // FIXME: this use different equation from FA v2 paper, + // but produce correc result. + // Is the equation wrong? + o_acc(i_j_idx) *= tmp; + }); + }); + + block_sync_lds(); + if constexpr(std::is_same_v) + { + auto v_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledVRegBlockDescriptor()); + shuffle_tile(v_shuffle_tmp, v_prefetch); + store_tile( + v_lds_window, + tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch + } + else + { + store_tile(v_lds_window, + tile_elementwise_in(v_element_func, v_prefetch)); // store the prefetch + } + i_page_block_v = + v_page_block_navigator.move_tile_window(i_page_block_v, v_dram_window, {0, kK1}); + + // STAGE 3, KV gemm + if constexpr(k1_loops > 1) + { + static_for<0, k1_loops - 1, 1>{}([&, + &i_page_block_v_ = i_page_block_v, + &v_dram_window_ = v_dram_window](auto i_k1) { + const auto v = load_tile(v_dram_window_); // load next v + block_sync_lds(); + + gemm_1(o_acc, + get_slice_tile( + p, sequence<0, i_k1 * kK1>{}, sequence{}), + v_lds_window); + block_sync_lds(); + + if constexpr(std::is_same_v) + { + auto v_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledVRegBlockDescriptor()); + shuffle_tile(v_shuffle_tmp, v); + store_tile(v_lds_window, + tile_elementwise_in(v_element_func, + v_shuffle_tmp)); // store the prefetch + } + else + { + store_tile(v_lds_window, + tile_elementwise_in(v_element_func, v)); // store next v + } + i_page_block_v_ = v_page_block_navigator.move_tile_window( + i_page_block_v_, v_dram_window_, {0, kK1}); + }); + } + + // tail + { + block_sync_lds(); + gemm_1(o_acc, + get_slice_tile( + p, sequence<0, (k1_loops - 1) * kK1>{}, sequence{}), + v_lds_window); + block_sync_lds(); + } + + __builtin_amdgcn_sched_barrier(0); + + // load the first tile for next iteration + if(i_total_loops < num_total_loop - 1) + { + // store the first tile for next iteration to LDS + // moving k_dram_window is an in-page-block operation, so there is + // no need to invoke k_page_block_navigator.move_tile_window() here. + move_tile_window(k_dram_window, {0, kK0}); + store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile)); + } + } while(++i_total_loops < num_total_loop); + + if constexpr(kStoreLSE) + { + // store lse acc + auto lse_acc = make_static_distributed_tensor(m.get_tile_distribution()); + + constexpr auto lse_acc_spans = decltype(lse_acc)::get_distributed_spans(); + sweep_tile_span(lse_acc_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + lse_acc(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]); + } + else + { + lse_acc(i_idx) = m_[i_idx] * scale_s / C_LOG2E + log(l_[i_idx]); + } +#else + lse_acc(i_idx) = m_[i_idx] + log(l_[i_idx]); +#endif + }); + + if(get_thread_local_1d_id() < kM0) + { + store_tile(lse_acc_dram_window_tmp, + tile_elementwise_in(lse_acc_element_func, lse_acc)); + } + } + + // finally, O + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + const auto tmp = [&]() { + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + FmhaMask::IsMasking) + { + return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx]; + } + else + return 1 / l[i_idx]; + }(); + sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + o_acc(i_j_idx) *= tmp; + }); + }); + + o_acc = tile_elementwise_in(o_acc_element_func, o_acc); + + return o_acc; + } + + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const KDramBlockWindowLengths& k_dram_block_window_lengths, // N0*K0 tile + const KPageBlockNavigator& k_page_block_navigator, + const VDramBlockWindowLengths& v_dram_block_window_lengths, // N1*K1 tile + const VPageBlockNavigator& v_page_block_navigator, + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + LSEaccDramBlockWindowTmp& lse_acc_dram_block_window_tmp, // M0*1 tile + index_t num_splits, + index_t i_split, + FmhaMask mask, + PositionEncoding position_encoding, + float scale_s, + index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate + void* smem_ptr) const + { + return operator()(q_dram_block_window_tmp, + identity{}, + k_dram_block_window_lengths, + k_page_block_navigator, + identity{}, + v_dram_block_window_lengths, + v_page_block_navigator, + identity{}, + bias_dram_block_window_tmp, + identity{}, + lse_acc_dram_block_window_tmp, + identity{}, + identity{}, + identity{}, + identity{}, + num_splits, + i_split, + mask, + position_encoding, + scale_s, + kv_l2p_offset, + smem_ptr); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp new file mode 100644 index 0000000000..74d755ef39 --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp @@ -0,0 +1,226 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp" + +namespace ck_tile { + +// This pipeline is qkv all located in LDS +struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVSDefaultPolicy + : BlockFmhaPipelineQXKSVSCustomPolicy +{ + using BasePolicy = BlockFmhaPipelineQXKSVSCustomPolicy; + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim; + + constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::QDataType); + + // this should align with MakeQDramTileDistribution() + constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize; + static_assert(0 < ElemPerThread); + return min(ElemPerThread, MaxVectorSize); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentOacc() + { + using OaccDataType = remove_cvref_t; + + return static_cast(16 / sizeof(OaccDataType)); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeQDramTileDistribution() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim; + + constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::QDataType); + + constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize; + static_assert(0 < ElemPerThread); + constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize); + + constexpr index_t KPerThread = kMaxVecLoad; + constexpr index_t KThreads = kKPerBlock / KPerThread; + constexpr index_t MThreadPerWarp = get_warp_size() / KThreads; + constexpr index_t NumWarps = kBlockSize / get_warp_size(); + constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, + sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeQRegTileDistribution() + { + return BasePolicy::template MakeQDramTileDistribution(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackQ() + { + // TODO: this is for 3d layout + using QDataType = remove_cvref_t; + return static_cast(16 / sizeof(QDataType)); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptor() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim; + + constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize; + static_assert(0 < ElemPerThread); + constexpr index_t kKPack = min(ElemPerThread, GetSmemKPackQ()); + + constexpr auto q_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, number{}, number{}), + make_tuple(number<(kMPerBlock + 1) * kKPack>{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto q_lds_block_desc = transform_tensor_descriptor( + q_lds_block_desc_0, + make_tuple( + make_pass_through_transform(number{}), + make_merge_transform(make_tuple(number{}, number{}))), + make_tuple(sequence<1>{}, sequence<0, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return q_lds_block_desc; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemNPackS() + { + using SDataType = remove_cvref_t; + return static_cast(16 / sizeof(SDataType)); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeSLdsBlockDescriptor() + { + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kNPack = GetSmemNPackS(); + + constexpr auto s_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, number{}, number{}), + make_tuple(number<(kMPerBlock + 1) * kNPack>{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto s_lds_block_desc = transform_tensor_descriptor( + s_lds_block_desc_0, + make_tuple( + make_pass_through_transform(number{}), + make_merge_transform(make_tuple(number{}, number{}))), + make_tuple(sequence<1>{}, sequence<0, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return s_lds_block_desc; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeSRegTileDistribution() + { + using BlockGemm = remove_cvref_t())>; + + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + static_assert(MWarp == 1, "Check failed!"); + + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + constexpr index_t kTileK = Problem::BlockFmhaShape::kN0; + + // K2 is equal to Impl::kABKPerLane * kKIterPerWarpGemm + constexpr index_t K3 = WG::kK / WG::WarpGemmAttribute::Impl::kABKLane; + constexpr index_t K2 = WG::WarpGemmAttribute::Impl::kABKLane; + constexpr index_t K1 = kKPerBlock / (K2 * K3); + constexpr index_t K0 = kTileK / kKPerBlock; + constexpr index_t M2 = WG::WarpGemmAttribute::Impl::kAMLane; + constexpr index_t M1 = MWarp; + constexpr index_t M0 = kMPerBlock / (M2 * M1); + + constexpr auto s2_block_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1>>, + tuple, sequence<2, 2>>, + sequence<1, 2, 2, 2>, + sequence<0, 0, 1, 3>>{}; + + constexpr auto s2_block_dstr = make_static_tile_distribution(s2_block_dstr_encoding); + + return s2_block_dstr; + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeQ() + { + return MakeQLdsBlockDescriptor().get_element_space_size() * + sizeof(typename Problem::QDataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeK() + { + return MakeKLdsBlockDescriptor().get_element_space_size() * + sizeof(typename Problem::KDataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeV() + { + return MakeVLdsBlockDescriptor().get_element_space_size() * + sizeof(typename Problem::VDataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeS() + { + return MakeSLdsBlockDescriptor().get_element_space_size() * + sizeof(typename Problem::SaccDataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return max(GetSmemSizeQ(), GetSmemSizeK()) + + max(GetSmemSizeV(), GetSmemSizeS()); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp index d9da2f088c..1fe19faaf9 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp @@ -106,28 +106,43 @@ struct BlockFmhaFwdSplitKVPipelineProblem static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; }; +// extract tile size attributes to remove dependency on traits +template +struct BlockFmhaSplitKVCombinePipelineTileSizes +{ + static constexpr index_t MaxVectorSize = 16 / sizeof(OaccDataType_); + + static constexpr index_t kN1 = kN1_; + static constexpr index_t NThreads = kN1 / MaxVectorSize; + static constexpr index_t kM0 = get_warp_size() / NThreads; // MThreadPerWarp +}; + template struct BlockFmhaSplitKVCombinePipelineProblem + : BlockFmhaSplitKVCombinePipelineTileSizes { + using BaseType = BlockFmhaSplitKVCombinePipelineTileSizes; + using LSEDataType = remove_cvref_t; using OaccDataType = remove_cvref_t; using ODataType = remove_cvref_t; using Traits = remove_cvref_t; - static constexpr index_t kNumWarps = kM0_ / (get_warp_size() / 4); - static constexpr index_t kBlockSize = kNumWarps * get_warp_size(); - static constexpr bool kIsGroupMode = kIsGroupMode_; + static_assert(std::is_same_v); static constexpr index_t kHeadDimV = HeadDimV_; - static constexpr index_t kM0 = kM0_; - static constexpr index_t kN1 = kN1_; + static constexpr bool kIsGroupMode = kIsGroupMode_; + + using BaseType::kM0; + using BaseType::kN1; + + static_assert(kN1 <= kHeadDimV && kHeadDimV % kN1 == 0); // attributes from traits static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ; @@ -136,6 +151,13 @@ struct BlockFmhaSplitKVCombinePipelineProblem static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant; static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; static constexpr index_t kMaxSplits = Traits::kMaxSplits; + static_assert(8 <= kMaxSplits); + + static constexpr index_t kNumWarps = 4; // always use 4 warps for each workgroup + static constexpr index_t kBlockSize = kNumWarps * get_warp_size(); + + static_assert(get_warp_size() <= (kM0 * kMaxSplits) && + (kM0 * kMaxSplits) % get_warp_size() == 0); }; template template CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ() { + constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::QDataType); + using BlockGemm = remove_cvref_t())>; constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); using WG = remove_cvref_t())>; - return WG::kK / WG::WarpGemmAttribute::Impl::kABKLane; + + return min(MaxVectorSize, WG::kK / WG::WarpGemmAttribute::Impl::kABKLane); } template CK_TILE_HOST_DEVICE static constexpr auto MakeQDramTileDistribution() { - constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); - using WG = remove_cvref_t())>; - constexpr index_t MWarp = config.template at<1>(); - - constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim; - - constexpr index_t K2 = WG::kK / WG::WarpGemmAttribute::Impl::kABKLane; - constexpr index_t K1 = WG::WarpGemmAttribute::Impl::kABKLane; - constexpr index_t K0 = kKPerBlock / (K1 * K2); - - constexpr index_t M2 = WG::WarpGemmAttribute::Impl::kAMLane; - constexpr index_t M1 = MWarp; - constexpr index_t M0 = kMPerBlock / (M2 * M1); - - if constexpr(1 < Problem::kNumGemm0Warps) - { - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<2, 1>>, - tuple, sequence<1, 2>>, - sequence<1, 2, 2>, - sequence<0, 0, 2>>{}); - } - else - { - static_assert(MWarp == 1); - - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2, 2>, - sequence<0, 0, 2>>{}); - } + return BlockGemm::template MakeABlockTileDistribution< + Problem::BlockFmhaShape::kM0, + Problem::BlockFmhaShape::kSubQKHeaddim>(); } template @@ -105,7 +74,7 @@ struct BlockFmhaPipelineQXCustomPolicy constexpr auto warp_gemm = []() { constexpr index_t WarpGemmM = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}); - static_assert(WarpGemmM == 16 || WarpGemmM == 32); + static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32); if constexpr(std::is_same_v && std::is_same_v && @@ -113,8 +82,10 @@ struct BlockFmhaPipelineQXCustomPolicy { if constexpr(WarpGemmM == 32) return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{}; - else // WarpGemmM == 16 + else if constexpr(WarpGemmM == 16) return WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}; + else // WarpGemmM == 4 + return WarpGemmMfmaF16F16F32M4N64K16{}; } else if constexpr(std::is_same_v && std::is_same_v && @@ -122,8 +93,10 @@ struct BlockFmhaPipelineQXCustomPolicy { if constexpr(WarpGemmM == 32) return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{}; - else // WarpGemmM == 16 + else if constexpr(WarpGemmM == 16) return WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{}; + else // WarpGemmM == 4 + return WarpGemmMfmaBf16Bf16F32M4N64K16{}; } else if constexpr(std::is_same_v && std::is_same_v && diff --git a/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp b/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp index bb33b5f021..5ce80c2d1f 100644 --- a/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp +++ b/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp @@ -43,8 +43,6 @@ struct TileFmhaShape static constexpr index_t NumWarps = max(NumGemm0Warps, NumGemm1Warps); - static_assert(std::is_same_v); - static constexpr index_t kM0 = BlockTile::at(number<0>{}); // tile size along q seqlen static constexpr index_t kN0 = BlockTile::at(number<1>{}); // tile size along k seqlen static constexpr index_t kK0 = BlockTile::at(number<2>{}); // tile size along qk gemm unroll diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_one_warp_v1.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_one_warp_v1.hpp index ff23f63556..b99466b1ea 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_one_warp_v1.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_one_warp_v1.hpp @@ -65,14 +65,6 @@ struct BlockGemmARegBSmemCRegOneWarpV1 const index_t iNWarp = 0; - constexpr auto a_block_outer_dstr_encoding = - tile_distribution_encoding, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; - constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding, tuple, sequence>, @@ -81,19 +73,14 @@ struct BlockGemmARegBSmemCRegOneWarpV1 sequence<1, 2>, sequence<0, 0>>{}; - constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{}); - constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); - constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode); - // constrcut from A-block-tensor from A-Block-tensor-tmp // FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent // distribution - auto a_block_tensor = - make_static_distributed_tensor(a_block_dstr); + auto a_block_tensor = make_static_distributed_tensor( + MakeABlockTileDistribution()); a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer(); @@ -187,6 +174,33 @@ struct BlockGemmARegBSmemCRegOneWarpV1 }); } + template + CK_TILE_DEVICE static constexpr auto MakeABlockTileDistribution() + { + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + constexpr auto a_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{}); + + return make_static_tile_distribution(a_block_dstr_encode); + } + CK_TILE_DEVICE static constexpr auto MakeCBlockTile() { constexpr index_t MPerBlock = BlockGemmShape::kM; diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp index 173ef0a02e..0181c0eec8 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp @@ -59,14 +59,6 @@ struct BlockGemmARegBSmemCRegV2 const index_t iNWarp = get_warp_id() % NWarp; - constexpr auto a_block_outer_dstr_encoding = - tile_distribution_encoding, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; - constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< sequence<>, tuple, sequence>, @@ -75,19 +67,14 @@ struct BlockGemmARegBSmemCRegV2 sequence<1, 2>, sequence<0, 0>>{}; - constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{}); - constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); - constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode); - // constrcut from A-block-tensor from A-Block-tensor-tmp // FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent // distribution - auto a_block_tensor = - make_static_distributed_tensor(a_block_dstr); + auto a_block_tensor = make_static_distributed_tensor( + MakeABlockTileDistribution()); a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer(); @@ -182,6 +169,33 @@ struct BlockGemmARegBSmemCRegV2 }); } + template + CK_TILE_DEVICE static constexpr auto MakeABlockTileDistribution() + { + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + constexpr auto a_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{}); + + return make_static_tile_distribution(a_block_dstr_encode); + } + CK_TILE_DEVICE static constexpr auto MakeCBlockTile() { constexpr index_t MPerBlock = BlockGemmShape::kM; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp index 89ea82c5bd..1fd12973f6 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp @@ -56,6 +56,14 @@ using WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution = WarpGemmAttributeMfmaImplF16F16F32M32N32K8, 2>>; +using WarpGemmMfmaF16F16F32M4N64K16 = WarpGemmImpl, + 4>>; + +using WarpGemmMfmaF16F16F32M64N4K16 = WarpGemmImpl, + 4>>; + // bf16 using WarpGemmMfmaBf16Bf16F32M32N32K8 = WarpGemmImpl< @@ -104,6 +112,14 @@ using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution = WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8, 2>>; +using WarpGemmMfmaBf16Bf16F32M4N64K16 = WarpGemmImpl, + 4>>; + +using WarpGemmMfmaBf16Bf16F32M64N4K16 = WarpGemmImpl, + 4>>; + // fp8 using WarpGemmMfma_f32_32x32x16_fp8_fp8 = WarpGemmImpl< diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp index a9e466a796..e7d4c37966 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp @@ -28,6 +28,9 @@ struct WarpGemmAtrributeMfma CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; } + static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1, + "Multi-block WarpGemmAttributeMfmaImpl is not supported"); + using AWarpDstrEncoding = tile_distribution_encoding< sequence<>, tuple, sequence>, @@ -94,30 +97,130 @@ struct WarpGemmAtrributeMfmaIterateK CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; } - using AWarpDstrEncoding = tile_distribution_encoding< - sequence<>, - tuple, sequence>, - tuple>, - tuple>, - sequence<2>, - sequence<1>>; + static_assert(Impl::kAMBlock == 1 || Impl::kBNBlock == 1, + "Multi-block on both M & N directions is not supported"); - using BWarpDstrEncoding = tile_distribution_encoding< - sequence<>, - tuple, sequence>, - tuple>, - tuple>, - sequence<2>, - sequence<1>>; + CK_TILE_DEVICE static constexpr auto get_awarp_dstr_encoding() + { + if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1) + { + return tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<1>>{}; + } + else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock) + { + // each M blocks share the same data + return tile_distribution_encoding< + sequence, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<1>>{}; + } + else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1) + { + // single block to multi-block thread mapping + return tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<1>>{}; + } + } - using CWarpDstrEncoding = tile_distribution_encoding< - sequence<>, - tuple, - sequence>, - tuple>, - tuple>, - sequence<1, 1>, - sequence<0, 2>>; + CK_TILE_DEVICE static constexpr auto get_bwarp_dstr_encoding() + { + if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1) + { + return tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<1>>{}; + } + else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock) + { + // single block to multi-block thread mapping + return tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<1>>{}; + } + else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1) + { + // each N blocks share the same data + return tile_distribution_encoding< + sequence, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<1>>{}; + } + } + + CK_TILE_DEVICE static constexpr auto get_cwarp_dstr_encoding() + { + if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1) + { + return tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<1, 1>, + sequence<0, 2>>{}; + } + else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock) + { + return tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<1, 1>, + sequence<0, 2>>{}; + } + else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1) + { + return tile_distribution_encoding< + sequence<>, + tuple< + sequence, + sequence>, + tuple>, + tuple>, + sequence<1, 1>, + sequence<0, 2>>{}; + } + } + + using AWarpDstrEncoding = decltype(get_awarp_dstr_encoding()); + + using BWarpDstrEncoding = decltype(get_bwarp_dstr_encoding()); + + using CWarpDstrEncoding = decltype(get_cwarp_dstr_encoding()); // c_vec += a_vec * b_vec template @@ -206,6 +309,9 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; } + static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1, + "Multi-block WarpGemmAttributeMfmaImpl is not supported"); + using AWarpDstrEncoding = tile_distribution_encoding< sequence<>, tuple, sequence>, @@ -270,6 +376,9 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; } + static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1, + "Multi-block WarpGemmAttributeMfmaImpl is not supported"); + using AWarpDstrEncoding = tile_distribution_encoding< sequence<>, tuple, sequence>, @@ -341,30 +450,130 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; } - using AWarpDstrEncoding = tile_distribution_encoding< - sequence<>, - tuple, sequence>, - tuple>, - tuple>, - sequence<2>, - sequence<1>>; + static_assert(Impl::kAMBlock == 1 || Impl::kBNBlock == 1, + "Multi-block on both M & N directions is not supported"); - using BWarpDstrEncoding = tile_distribution_encoding< - sequence<>, - tuple, sequence>, - tuple>, - tuple>, - sequence<2>, - sequence<1>>; + CK_TILE_DEVICE static constexpr auto get_awarp_dstr_encoding() + { + if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1) + { + return tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<1>>{}; + } + else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock) + { + // single block to multi-block thread mapping + return tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<1>>{}; + } + else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1) + { + // each N blocks share the same data + return tile_distribution_encoding< + sequence, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<1>>{}; + } + } - using CWarpDstrEncoding = tile_distribution_encoding< - sequence<>, - tuple, - sequence>, - tuple>, - tuple>, - sequence<2, 2>, - sequence<0, 2>>; + CK_TILE_DEVICE static constexpr auto get_bwarp_dstr_encoding() + { + if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1) + { + return tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<1>>{}; + } + else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock) + { + // each M blocks share the same data + return tile_distribution_encoding< + sequence, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<1>>{}; + } + else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1) + { + // single block to multi-block thread mapping + return tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<1>>{}; + } + } + + CK_TILE_DEVICE static constexpr auto get_cwarp_dstr_encoding() + { + if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1) + { + return tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2, 2>, + sequence<0, 2>>{}; + } + else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock) + { + return tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2, 2>, + sequence<0, 2>>{}; + } + else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1) + { + return tile_distribution_encoding< + sequence<>, + tuple< + sequence, + sequence>, + tuple>, + tuple>, + sequence<2, 2>, + sequence<0, 2>>{}; + } + } + + using AWarpDstrEncoding = decltype(get_awarp_dstr_encoding()); + + using BWarpDstrEncoding = decltype(get_bwarp_dstr_encoding()); + + using CWarpDstrEncoding = decltype(get_cwarp_dstr_encoding()); template // c_vec += a_vec * b_vec @@ -457,6 +666,9 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; } + static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1, + "Multi-block WarpGemmAttributeMfmaImpl is not supported"); + using AWarpDstrEncoding = tile_distribution_encoding< sequence<>, tuple, sequence>, @@ -597,6 +809,9 @@ struct WarpGemmAtrributeMfmaIterateK_SwizzleA CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; } + static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1, + "Multi-block WarpGemmAttributeMfmaImpl is not supported"); + using AWarpDstrEncoding = tile_distribution_encoding< sequence<>, tuple +struct WarpGemmAttributeMfmaImplF16F16F32M4N64K4 +{ + static constexpr WGAttrCtlEnum Ctrl = Ctrl_; + using ADataType = fp16_t; + using BDataType = fp16_t; + using CDataType = float; + + using AVecType = ext_vector_t; + using BVecType = ext_vector_t; + using CVecType = ext_vector_t; + + static constexpr index_t kM = 4; + static constexpr index_t kN = 64; + static constexpr index_t kK = 4; + + static constexpr index_t kAMBlock = 1; + static constexpr index_t kBNBlock = 16; + + // we only write down single block (4 threads) thread mapping here + static constexpr index_t kAMLane = 4; + static constexpr index_t kBNLane = 4; + static constexpr index_t kABKLane = 1; + static constexpr index_t kABKPerLane = 4; + + static constexpr index_t kCMLane = 1; + static constexpr index_t kCNLane = 4; + static constexpr index_t kCM0PerLane = 1; + static constexpr index_t kCM1PerLane = 4; + + // c_vec += a_vec * b_vec + template + CK_TILE_DEVICE void operator()(CVecType& c_vec, + const AVecType& a_vec, + const BVecType& b_vec, + bool_constant = {}) const + { + DISPATCH_MFMA_CTRL_("v_mfma_f32_4x4x4f16", Ctrl) + else + { +#if defined(__gfx9__) + c_vec = __builtin_amdgcn_mfma_f32_4x4x4f16(a_vec, b_vec, c_vec, 0, 0, 0); +#else + ignore = c_vec; + ignore = a_vec; + ignore = b_vec; +#endif + } + } + + // c_vec = a_vec * b_vec + CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const + { +#if defined(__gfx9__) + return bit_cast( + __builtin_amdgcn_mfma_f32_4x4x4f16(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0)); +#else + ignore = a_vec; + ignore = b_vec; + return CVecType{0.f}; +#endif + } +}; + +template +struct WarpGemmAttributeMfmaImplF16F16F32M64N4K4 +{ + static constexpr WGAttrCtlEnum Ctrl = Ctrl_; + using ADataType = fp16_t; + using BDataType = fp16_t; + using CDataType = float; + + using AVecType = ext_vector_t; + using BVecType = ext_vector_t; + using CVecType = ext_vector_t; + + static constexpr index_t kM = 64; + static constexpr index_t kN = 4; + static constexpr index_t kK = 4; + + static constexpr index_t kAMBlock = 16; + static constexpr index_t kBNBlock = 1; + + // we only write down single block (4 threads) thread mapping here + static constexpr index_t kAMLane = 4; + static constexpr index_t kBNLane = 4; + static constexpr index_t kABKLane = 1; + static constexpr index_t kABKPerLane = 4; + + static constexpr index_t kCMLane = 1; + static constexpr index_t kCNLane = 4; + static constexpr index_t kCM0PerLane = 1; + static constexpr index_t kCM1PerLane = 4; + + // c_vec += a_vec * b_vec + template + CK_TILE_DEVICE void operator()(CVecType& c_vec, + const AVecType& a_vec, + const BVecType& b_vec, + bool_constant = {}) const + { + DISPATCH_MFMA_CTRL_("v_mfma_f32_4x4x4f16", Ctrl) + else + { +#if defined(__gfx9__) + c_vec = __builtin_amdgcn_mfma_f32_4x4x4f16(a_vec, b_vec, c_vec, 0, 0, 0); +#else + ignore = c_vec; + ignore = a_vec; + ignore = b_vec; +#endif + } + } + + // c_vec = a_vec * b_vec + CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const + { +#if defined(__gfx9__) + return bit_cast( + __builtin_amdgcn_mfma_f32_4x4x4f16(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0)); +#else + ignore = a_vec; + ignore = b_vec; + return CVecType{0.f}; +#endif + } +}; + // Bf16 template struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8 @@ -199,6 +333,9 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8 static constexpr index_t kN = 32; static constexpr index_t kK = 8; + static constexpr index_t kAMBlock = 1; + static constexpr index_t kBNBlock = 1; + static constexpr index_t kAMLane = 32; static constexpr index_t kBNLane = 32; static constexpr index_t kABKLane = 2; @@ -285,6 +422,9 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16 static constexpr index_t kN = 16; static constexpr index_t kK = 16; + static constexpr index_t kAMBlock = 1; + static constexpr index_t kBNBlock = 1; + static constexpr index_t kAMLane = 16; static constexpr index_t kBNLane = 16; static constexpr index_t kABKLane = 4; @@ -354,6 +494,134 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16 } }; +template +struct WarpGemmAttributeMfmaImplBf16Bf16F32M4N64K4 +{ + static constexpr WGAttrCtlEnum Ctrl = Ctrl_; + using ADataType = bf16_t; + using BDataType = bf16_t; + using CDataType = float; + + using AVecType = ext_vector_t; + using BVecType = ext_vector_t; + using CVecType = ext_vector_t; + + static constexpr index_t kM = 4; + static constexpr index_t kN = 64; + static constexpr index_t kK = 4; + + static constexpr index_t kAMBlock = 1; + static constexpr index_t kBNBlock = 16; + + // we only write down single block (4 threads) thread mapping here + static constexpr index_t kAMLane = 4; + static constexpr index_t kBNLane = 4; + static constexpr index_t kABKLane = 1; + static constexpr index_t kABKPerLane = 4; + + static constexpr index_t kCMLane = 1; + static constexpr index_t kCNLane = 4; + static constexpr index_t kCM0PerLane = 1; + static constexpr index_t kCM1PerLane = 4; + + // c_vec += a_vec * b_vec + template + CK_TILE_DEVICE void operator()(CVecType& c_vec, + const AVecType& a_vec, + const BVecType& b_vec, + bool_constant = {}) const + { + DISPATCH_MFMA_CTRL_("v_mfma_f32_4x4x4bf16_1k", Ctrl) + else + { +#if defined(__gfx9__) + c_vec = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0); +#else + ignore = c_vec; + ignore = a_vec; + ignore = b_vec; +#endif + } + } + + // c_vec = a_vec * b_vec + CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const + { +#if defined(__gfx9__) + return bit_cast( + __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0)); +#else + ignore = a_vec; + ignore = b_vec; + return CVecType{0.f}; +#endif + } +}; + +template +struct WarpGemmAttributeMfmaImplBf16Bf16F32M64N4K4 +{ + static constexpr WGAttrCtlEnum Ctrl = Ctrl_; + using ADataType = bf16_t; + using BDataType = bf16_t; + using CDataType = float; + + using AVecType = ext_vector_t; + using BVecType = ext_vector_t; + using CVecType = ext_vector_t; + + static constexpr index_t kM = 64; + static constexpr index_t kN = 4; + static constexpr index_t kK = 4; + + static constexpr index_t kAMBlock = 16; + static constexpr index_t kBNBlock = 1; + + // we only write down single block (4 threads) thread mapping here + static constexpr index_t kAMLane = 4; + static constexpr index_t kBNLane = 4; + static constexpr index_t kABKLane = 1; + static constexpr index_t kABKPerLane = 4; + + static constexpr index_t kCMLane = 1; + static constexpr index_t kCNLane = 4; + static constexpr index_t kCM0PerLane = 1; + static constexpr index_t kCM1PerLane = 4; + + // c_vec += a_vec * b_vec + template + CK_TILE_DEVICE void operator()(CVecType& c_vec, + const AVecType& a_vec, + const BVecType& b_vec, + bool_constant = {}) const + { + DISPATCH_MFMA_CTRL_("v_mfma_f32_4x4x4bf16_1k", Ctrl) + else + { +#if defined(__gfx9__) + c_vec = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0); +#else + ignore = c_vec; + ignore = a_vec; + ignore = b_vec; +#endif + } + } + + // c_vec = a_vec * b_vec + CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const + { +#if defined(__gfx9__) + return bit_cast( + __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0)); +#else + ignore = a_vec; + ignore = b_vec; + return CVecType{0.f}; +#endif + } +}; + // FP8 template struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base @@ -371,6 +639,9 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base static constexpr index_t kN = 32; static constexpr index_t kK = 16; + static constexpr index_t kAMBlock = 1; + static constexpr index_t kBNBlock = 1; + static constexpr index_t kAMLane = 32; static constexpr index_t kBNLane = 32; static constexpr index_t kABKLane = 2; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp index 99cd5d787e..9c319b5e5f 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp @@ -29,6 +29,8 @@ template<> struct WarpGemmMfmaDispatcher struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K32; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M4N64K16; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M64N4K16; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleA; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleA; }; @@ -42,6 +44,8 @@ template<> struct WarpGemmMfmaDispatcher struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M4N64K16; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M64N4K16; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA; };