Skip to content

Commit

Permalink
[CK_TILE] Add fmha fwd N-Warp S-Shuffle pipeline (fmha fwd splitkv pi…
Browse files Browse the repository at this point in the history
…peline variant) (#1705)

* Add check for zero values

* Add static assertions

* Remove invalid option '-e' in smoke_test.sh

* Use correct path of smoke_test.sh

* Avoid zero-sized shared memory array

* Add warning comment

* Replace expr by integer_divide_ceil() call

* Use more readable constant names

* Write down assumption as static assertion

* Add more diagnostic error messages

* Fix wrong BlockWarps when using default pipeline policy

* Add more static assertions for A LDS desc

* Allow using vector size < 8 for data type fp16/bf16

* Align vector size between DRAM dist & LDS desc

* Remove no-longer used func decl

* Fix wrong displayed piepline name

* Undo policy template changes for tile_example_gemm_basic

* Add missing space and make error message stands out

* Unify print precision

* Add missing include directive <iomanip>

* Replace constant 64 by get_warp_size() call

* Replace constant 128 by named variable: BankLength

* Add kAMBlock/kBNBlock attributes

* Allow usig different A/B warp dist for multiple blocks

* Add helper function to get warp dist encodings

* Add 4x64x4 fp16 warp gemm attribute impl

* Complete the A/B warp dist encoding logic

* Fix wrong thread mapping for C matrix

* Use smaller vector size for small tile

* Add static assert to block unsupported warp gemm impl

* Extract common code out as helper method

* Add 4x64x16 fp16 warp gemm type alias

* Add comment to warning developers

* Undo WarpGemmAtrributeMfma<> changes

* Use more clear static assertion error message

* Add trivial wrapper to get warp dstr encodings

* Only transpose warp gemm result if it's square

* Fix compilation error

* Support multi-block warp gemm (on N direction)

* Remove duplicated code

* Fix output encoding of warp gemm

* Fix wrong shape of WarpGemmAtrributeMfmaIterateK<>

* Remove unused code

* Fix wrong shape of WarpGemmAttributeMfmaImplF16F16F32M4N64K4

* Add type config for bf16_t

* Add 4x64x16 bf16 warp gemm

* Update WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution

* Add 64x4x4 fp16/bf16 warp gemm impl

* Add 64x4x16 fp16/bf16 warp gemm

* Add static assertion for better error diagnostic

* Get Q dram dstr directly form block gemm

* Add missing header: fused_moe.hpp

* Allow specifying different warp-gemm for gemm0 & gemm1

* Store P matrix into LDS before gemm1

* Fix inconsistant kernel name

* Remove constraint on gemm0 & gemm1 block warps

* Remove unsupported vector size from checking list

* Allow using 4x64x16 warp gemm for gemm0

* Finish policy customization

* Finish pipeline modification
F#

* Use block warps in codegen

* Fix wrong rank of m_lds_window origin

* Use better distributed tensor

* Make P-store earlier

* Remove duplicated experssions

* Remove unnecessary tile window

* Create new files for new splitkv pipeline

* Separate old/new pipeline codegen logic

* Sync changes form develop

* Undo gemm kernel/pipeline changes

* Undo gemm example changes

* Remove blank lines

* Fix typo

* Use new warp gemm interface

* Fix link error

* Fix wrong pipeline tag

* Fix more link error

* Avoid unnecessary padding

* Always use vector load for K

* Padding on fastest dimension when necessary

* Force padding Q on hdim_q

* Set high dimension padding flag to false

* Re-format headers

* Use warps=<1, 4, 1> for both gemm0 & gemm1

* Fix complilation errors

* Remove m/l shuffle logics

* Ignore duplicate data when write lse_acc

* Use gemm0 block warps as lds tile width

* Remove hard-coded numbers

* Fix wrong distribution width

* Remove unnecessary code

* Add s_barrier before writing to LDS

* Store Q into LDS before gemm0

* Fix wrong Q tile size

* Use simple Q lds descriptor for debuging

* Use more realistic Q lds descriptor

* Add comment & use better variable name

* Make Q lds space not overlapped with others

* Remove unnecessary block_tile_reduce_sync() call

* Move Q load statements

* Move block_sync_lds() right before use

* Re-order instructions

* Remove necessary lambda expression

* Use 8 threads on kMaxSplits direction while doing reduction

* Tiny correction for using 8 threads on kMaxSplits direction for combine kernel

* Padding num_split direction of o_acc tile window to 4x

* Update splitkv combine pipeline design

* Add kN1 back to splitkv combine pipeline problem

* Fix compilation errors

* Add missing template parameter

* Fix wrong splitkv combine kernel name

* Fix wrong origin

* Fix wrong LDS descriptor shape

* Fix sync & reduction logics

* Remove unnecessary static assertions

* Extract tile size computation logics

* Make sure we can reuse padding flags in combine kernels

* Rename variables

* Use OaccDataType in BlockFmhaSplitKVCombinePipelineTileSizes<>

* Remove unnecessary static assertion

* Fix function name typo

* Add constraint on kN1 template parameter

* Hide K tile loading latency in earlier iteration

* Fix wrong splitkv kernel name

* Use s_shuffling to replace p_shuffling which removes the needs of cross-warp reduction

* Rename pipeline

* Fix wrong pipeline name attribute

* Add GetAlignmentQ() for NWarpSShuffle pipeline

* Separate Q tile into dram tile & register tile concepts

* Remove non-squre warp gemm transpose c type alias

* Fallback tile size changes for fmha fwd splitkv

* Remove redundant change

* Refine naming for the S tile

* Use better naming of the S tile dstr (read from lds)

* Share Q lds with K lds

* Tiny change

* Fix with using static_for for passing CI checking

---------

Co-authored-by: Qianfeng Zhang <[email protected]>
  • Loading branch information
poyenc and qianfengz authored Dec 20, 2024
1 parent 2944c50 commit 37cdbf4
Show file tree
Hide file tree
Showing 23 changed files with 1,987 additions and 272 deletions.
1 change: 1 addition & 0 deletions example/ck_tile/01_fmha/codegen/cpp_symbol_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ def get_mask_check_map(mask : str):
PIPELINE_ENUM_MAP = {
"qr" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
"qr_async" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC",
"qr_nwarp_sshuffle" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
}

BOOL_MAP = {
Expand Down
42 changes: 24 additions & 18 deletions example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,12 @@
using fmha_dtype_{F_idx} = {F_dtype};
using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
using fmha_warp_tile_{F_idx} = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>;
using fmha_shape_{F_idx} = ck_tile::TileFmhaShape<fmha_block_tile_{F_idx},
ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>,
fmha_warp_tile_{F_idx},
ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>,
ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>,
fmha_warp_tile_{F_idx},
ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>,
{F_vlayout}>;
using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
Expand Down Expand Up @@ -306,15 +305,19 @@ class FmhaFwdTileSize:
F_rm1 : int # number of warps for gemm1 along q seqlen
F_rn1 : int # number of warps for gemm1 along head dim v
F_rk1 : int # number of warps for gemm1 along k seqlen (not used)
F_wm : int # warp size along m (warp size)
F_wn : int # warp size along n
F_wk : int # warp size along k
F_wm0 : int # gemm0 warp size along m
F_wn0 : int # gemm0 warp size along n
F_wk0 : int # gemm0 warp size along k
F_wm1 : int # gemm1 warp size along m
F_wn1 : int # gemm1 warp size along n
F_wk1 : int # gemm1 warp size along k
F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
@property
def name(self) -> str:
return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}" +\
f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}" +\
f"_w{self.F_wm}x{self.F_wn}x{self.F_wk}" + ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}")
f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}" +\
("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}")

@dataclass
class FmhaFwdKernel:
Expand Down Expand Up @@ -352,9 +355,12 @@ def template(self) -> str:
F_rm1 = self.F_tile.F_rm1,
F_rn1 = self.F_tile.F_rn1,
F_rk1 = self.F_tile.F_rk1,
F_wm = self.F_tile.F_wm,
F_wn = self.F_tile.F_wn,
F_wk = self.F_tile.F_wk,
F_wm0 = self.F_tile.F_wm0,
F_wn0 = self.F_tile.F_wn0,
F_wk0 = self.F_tile.F_wk0,
F_wm1 = self.F_tile.F_wm1,
F_wn1 = self.F_tile.F_wn1,
F_wk1 = self.F_tile.F_wk1,
F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout],
F_spad = BOOL_MAP[self.F_pipeline.F_spad],
F_skpad = BOOL_MAP[self.F_pipeline.F_skpad],
Expand Down Expand Up @@ -409,17 +415,17 @@ def api_trait(self) -> FmhaFwdApiTrait:
def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
if dtype == 'fp16' or dtype == 'bf16':
return {
'32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, -1),
'64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1),
## '96' : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1),
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1),
'256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1),
'32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1),
'64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
### '96' : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
'256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
}
elif dtype == 'fp8' or dtype == 'bf8':
return {
'64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, -1),
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, -1),
'256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, -1)
'64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1),
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
'256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
}
else:
return None
Expand Down
85 changes: 48 additions & 37 deletions example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@

FMHA_FWD_SPLITKV_PIPELINE_MAP = {
"qr" : "ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS",
"qr_nwarp_sshuffle" : "ck_tile::BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS",
"qr_async" : "ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVSAsync",
}

Expand All @@ -50,13 +51,12 @@
template <bool kHasUnevenSplits>
struct kernel_runner {{
using fmha_block_tile = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
using fmha_warp_tile = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>;
using fmha_shape = ck_tile::TileFmhaShape<fmha_block_tile,
ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>,
fmha_warp_tile,
ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>,
ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>,
fmha_warp_tile,
ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>,
{F_vlayout}>;
using fmha_trait = ck_tile::TileFmhaFwdSplitKVTraits<{F_spad},
Expand Down Expand Up @@ -161,9 +161,8 @@
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
{F_hdim},
{F_bm0},
{F_bn1},
{F_mode},
{F_bn1},
fmha_trait>;
using fmha_pipeline = ck_tile::BlockFmhaFwdSplitKVCombinePipeline<
Expand All @@ -177,9 +176,11 @@
false, false>>;
using fmha_kernel =
ck_tile::FmhaFwdSplitKVCombineKernel<ck_tile::FmhaFwdSplitKVCombineTilePartitioner<{F_bm0}, {F_bn1}>,
fmha_pipeline,
fmha_epilogue>;
ck_tile::FmhaFwdSplitKVCombineKernel<
ck_tile::FmhaFwdSplitKVCombineTilePartitioner<
fmha_pipeline_problem::kM0, fmha_pipeline_problem::kN1>,
fmha_pipeline,
fmha_epilogue>;
static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
{{
Expand All @@ -192,7 +193,7 @@
}};
}}
using trait_{F_idx} = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn1},
using trait_{F_idx} = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bn1},
{F_lse}, {F_squant}, {F_spad}, {F_dvpad}>;
#include <iostream>
Expand Down Expand Up @@ -250,16 +251,25 @@
FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) &&
((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
// get combine kernel tile sizes
using OaccDataType = typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType;
constexpr ck_tile::index_t kM0 = ck_tile::BlockFmhaSplitKVCombinePipelineTileSizes<OaccDataType, /*F_bn1=*/32>::kM0;
// make sure we can reuse the padding flags in combine kernels
static_assert({F_bm0} % kM0 == 0);
static_assert({F_bn1} % 32 == 0);
if (t.has_lse) {{
if constexpr (std::is_same_v<{F_dtype}, ck_tile::fp8_t>) {{
return -1;
}} else {{
using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}/2, {F_bn1}/2, true, {F_squant}, {F_spad}, {F_dvpad}>;
using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, /*F_bn1=*/32, true, {F_squant}, {F_spad}, {F_dvpad}>;
return fmha_fwd_splitkv_<traits_, traits2_>(s, a);
}}
}} else {{
using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}/2, {F_bn1}/2, false, {F_squant}, {F_spad}, {F_dvpad}>;
using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, /*F_bn1=*/32, false, {F_squant}, {F_spad}, {F_dvpad}>;
return fmha_fwd_splitkv_<traits_, traits2_>(s, a);
}}
Expand Down Expand Up @@ -302,7 +312,7 @@ def scheck(self) -> str:
if self.pipeline_tag == 'qr_async':
if self.spad == 't' : return 'true' # always support
else : return 'true'
elif self.pipeline_tag in ['qr']:
elif self.pipeline_tag in ['qr', 'qr_nwarp_sshuffle']:
if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.seqlen_q % {self.bm0} == 0'
else: assert False
Expand All @@ -313,7 +323,7 @@ def skcheck(self) -> str:
if self.pipeline_tag == 'qr_async':
if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0'
else : return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0'
elif self.pipeline_tag in ['qr', 'qr_fp8']:
elif self.pipeline_tag in ['qr', 'qr_nwarp_sshuffle']:
if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.seqlen_k % {self.bn0} == 0'
else: assert False
Expand All @@ -324,7 +334,7 @@ def dcheck(self) -> str:
vec = int((32 * 4) / DTYPE_BITS[self.dtype])
if self.dpad == 't': return f'a.hdim_q % {vec} == 0'
else : assert False
elif self.pipeline_tag in ['qr']:
elif self.pipeline_tag in ['qr', 'qr_nwarp_sshuffle']:
bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
if self.dpad == 't': return f'true /*a.hdim_q % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.hdim_q % {bk0submax} == 0'
Expand All @@ -336,7 +346,7 @@ def dvcheck(self) -> str:
vec = int((32 * 4) / DTYPE_BITS[self.dtype])
if self.dvpad == 't': return f'a.hdim_v % {vec} == 0'
else : assert False
elif self.pipeline_tag in ['qr']:
elif self.pipeline_tag in ['qr', 'qr_nwarp_sshuffle']:
bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
if self.dvpad == 't': return f'true /*a.hdim_v % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.hdim_v % {bk0submax} == 0'
Expand Down Expand Up @@ -447,12 +457,11 @@ def api(self) -> str:

@dataclass
class FmhaFwdSplitKVCombineTileSize:
F_bm0 : int # tile size along q seqlen
F_bn1 : int # tile size along v head_dim
F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
@property
def name(self) -> str:
return f"b{self.F_bm0}x{self.F_bn1}" +\
return f"b{self.F_bn1}" +\
("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}")

@dataclass
Expand Down Expand Up @@ -485,9 +494,12 @@ def template(self) -> str:
F_rm1 = self.F_tile.F_rm1,
F_rn1 = self.F_tile.F_rn1,
F_rk1 = self.F_tile.F_rk1,
F_wm = self.F_tile.F_wm,
F_wn = self.F_tile.F_wn,
F_wk = self.F_tile.F_wk,
F_wm0 = self.F_tile.F_wm0,
F_wn0 = self.F_tile.F_wn0,
F_wk0 = self.F_tile.F_wk0,
F_wm1 = self.F_tile.F_wm1,
F_wn1 = self.F_tile.F_wn1,
F_wk1 = self.F_tile.F_wk1,
F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout],
F_spad = BOOL_MAP[self.F_pipeline.F_spad],
F_skpad = BOOL_MAP[self.F_pipeline.F_skpad],
Expand Down Expand Up @@ -553,7 +565,6 @@ def template(self) -> str:
F_idx = self.F_idx,
F_hdim = self.F_hdim,
F_dtype = FWD_DTYPE_MAP[self.F_dtype],
F_bm0 = self.F_tile.F_bm0,
F_bn1 = self.F_tile.F_bn1,
F_spad = BOOL_MAP[self.F_pipeline.F_spad],
F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
Expand All @@ -577,35 +588,35 @@ def filename(self) -> str:
def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
if dtype == 'fp16' or dtype == 'bf16':
return {
'32' : FmhaFwdTileSize(32, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 16, 16, 16, -1),
'64' : FmhaFwdTileSize(64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, -1),
## '96' : FmhaFwdTileSize(64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, -1),
'128' : FmhaFwdTileSize(64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, -1),
'256' : FmhaFwdTileSize(64, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, -1),
'32' : FmhaFwdTileSize(32, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1),
'64' : FmhaFwdTileSize(64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
### '96' : FmhaFwdTileSize(64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
'128' : FmhaFwdTileSize(64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
'256' : FmhaFwdTileSize(64, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
}
elif dtype == 'fp8' or dtype == 'bf8':
return {
'64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, -1),
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, -1),
'256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, -1)
'64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1),
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
'256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
}
else:
return None

def get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
if dtype == 'fp16' or dtype == 'bf16':
return {
'32' : FmhaFwdSplitKVCombineTileSize(16, 16, -1),
'64' : FmhaFwdSplitKVCombineTileSize(32, 32, -1),
## '96' : FmhaFwdSplitKVCombineTileSize(32, 64, -1),
'128' : FmhaFwdSplitKVCombineTileSize(32, 64, -1),
'256' : FmhaFwdSplitKVCombineTileSize(32, 128, -1),
'32' : FmhaFwdSplitKVCombineTileSize(32, -1),
'64' : FmhaFwdSplitKVCombineTileSize(32, -1),
### '96' : FmhaFwdSplitKVCombineTileSize(32, -1),
'128' : FmhaFwdSplitKVCombineTileSize(32, -1),
'256' : FmhaFwdSplitKVCombineTileSize(32, -1),
}
elif dtype == 'fp8' or dtype == 'bf8':
return {
'64' : FmhaFwdSplitKVCombineTileSize(64, 32, -1),
'128' : FmhaFwdSplitKVCombineTileSize(64, 64, -1),
'256' : FmhaFwdSplitKVCombineTileSize(64, 128, -1),
'64' : FmhaFwdSplitKVCombineTileSize(32, -1),
'128' : FmhaFwdSplitKVCombineTileSize(32, -1),
'256' : FmhaFwdSplitKVCombineTileSize(32, -1),
}
else:
return None
Expand Down
2 changes: 0 additions & 2 deletions example/ck_tile/01_fmha/fmha_fwd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -709,7 +709,6 @@ std::string fmha_fwd_splitkv_get_name_();
template <ck_tile::index_t HDim_,
typename DataType_,
bool kIsGroupMode_,
ck_tile::index_t kM0_,
ck_tile::index_t kN1_,
bool kStoreLse_,
bool kDoFp8StaticQuant_,
Expand All @@ -720,7 +719,6 @@ struct fmha_fwd_splitkv_combine_traits_
static constexpr ck_tile::index_t HDim = HDim_;
using DataType = ck_tile::remove_cvref_t<DataType_>;
static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr ck_tile::index_t kM0 = kM0_;
static constexpr ck_tile::index_t kN1 = kN1_;
static constexpr bool kStoreLse = kStoreLse_;
static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
Expand Down
4 changes: 2 additions & 2 deletions include/ck_tile/core/arch/amd_buffer_addressing.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1303,8 +1303,8 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe
static_assert(
(std::is_same<T, double>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(std::is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, fp16_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, bf16_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, fp16_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(std::is_same<T, bf16_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(std::is_same<T, int32_t>::value &&
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, fp8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
Expand Down
1 change: 1 addition & 0 deletions include/ck_tile/core/tensor/static_distributed_tensor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ struct static_distributed_tensor
remove_cvref_t<decltype(StaticTileDistribution{}.get_ys_to_d_descriptor())>;

static constexpr index_t kThreadElementSpaceSize = ThreadTensorDesc{}.get_element_space_size();
static_assert(0 < kThreadElementSpaceSize, "Make sure tile distribution is valid");

CK_TILE_HOST_DEVICE static constexpr auto get_num_of_dimension()
{
Expand Down
Loading

0 comments on commit 37cdbf4

Please sign in to comment.