-
Notifications
You must be signed in to change notification settings - Fork 139
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[CK_TILE] Add fmha fwd N-Warp S-Shuffle pipeline (fmha fwd splitkv pi…
…peline variant) (#1705) * Add check for zero values * Add static assertions * Remove invalid option '-e' in smoke_test.sh * Use correct path of smoke_test.sh * Avoid zero-sized shared memory array * Add warning comment * Replace expr by integer_divide_ceil() call * Use more readable constant names * Write down assumption as static assertion * Add more diagnostic error messages * Fix wrong BlockWarps when using default pipeline policy * Add more static assertions for A LDS desc * Allow using vector size < 8 for data type fp16/bf16 * Align vector size between DRAM dist & LDS desc * Remove no-longer used func decl * Fix wrong displayed piepline name * Undo policy template changes for tile_example_gemm_basic * Add missing space and make error message stands out * Unify print precision * Add missing include directive <iomanip> * Replace constant 64 by get_warp_size() call * Replace constant 128 by named variable: BankLength * Add kAMBlock/kBNBlock attributes * Allow usig different A/B warp dist for multiple blocks * Add helper function to get warp dist encodings * Add 4x64x4 fp16 warp gemm attribute impl * Complete the A/B warp dist encoding logic * Fix wrong thread mapping for C matrix * Use smaller vector size for small tile * Add static assert to block unsupported warp gemm impl * Extract common code out as helper method * Add 4x64x16 fp16 warp gemm type alias * Add comment to warning developers * Undo WarpGemmAtrributeMfma<> changes * Use more clear static assertion error message * Add trivial wrapper to get warp dstr encodings * Only transpose warp gemm result if it's square * Fix compilation error * Support multi-block warp gemm (on N direction) * Remove duplicated code * Fix output encoding of warp gemm * Fix wrong shape of WarpGemmAtrributeMfmaIterateK<> * Remove unused code * Fix wrong shape of WarpGemmAttributeMfmaImplF16F16F32M4N64K4 * Add type config for bf16_t * Add 4x64x16 bf16 warp gemm * Update WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution * Add 64x4x4 fp16/bf16 warp gemm impl * Add 64x4x16 fp16/bf16 warp gemm * Add static assertion for better error diagnostic * Get Q dram dstr directly form block gemm * Add missing header: fused_moe.hpp * Allow specifying different warp-gemm for gemm0 & gemm1 * Store P matrix into LDS before gemm1 * Fix inconsistant kernel name * Remove constraint on gemm0 & gemm1 block warps * Remove unsupported vector size from checking list * Allow using 4x64x16 warp gemm for gemm0 * Finish policy customization * Finish pipeline modification F# * Use block warps in codegen * Fix wrong rank of m_lds_window origin * Use better distributed tensor * Make P-store earlier * Remove duplicated experssions * Remove unnecessary tile window * Create new files for new splitkv pipeline * Separate old/new pipeline codegen logic * Sync changes form develop * Undo gemm kernel/pipeline changes * Undo gemm example changes * Remove blank lines * Fix typo * Use new warp gemm interface * Fix link error * Fix wrong pipeline tag * Fix more link error * Avoid unnecessary padding * Always use vector load for K * Padding on fastest dimension when necessary * Force padding Q on hdim_q * Set high dimension padding flag to false * Re-format headers * Use warps=<1, 4, 1> for both gemm0 & gemm1 * Fix complilation errors * Remove m/l shuffle logics * Ignore duplicate data when write lse_acc * Use gemm0 block warps as lds tile width * Remove hard-coded numbers * Fix wrong distribution width * Remove unnecessary code * Add s_barrier before writing to LDS * Store Q into LDS before gemm0 * Fix wrong Q tile size * Use simple Q lds descriptor for debuging * Use more realistic Q lds descriptor * Add comment & use better variable name * Make Q lds space not overlapped with others * Remove unnecessary block_tile_reduce_sync() call * Move Q load statements * Move block_sync_lds() right before use * Re-order instructions * Remove necessary lambda expression * Use 8 threads on kMaxSplits direction while doing reduction * Tiny correction for using 8 threads on kMaxSplits direction for combine kernel * Padding num_split direction of o_acc tile window to 4x * Update splitkv combine pipeline design * Add kN1 back to splitkv combine pipeline problem * Fix compilation errors * Add missing template parameter * Fix wrong splitkv combine kernel name * Fix wrong origin * Fix wrong LDS descriptor shape * Fix sync & reduction logics * Remove unnecessary static assertions * Extract tile size computation logics * Make sure we can reuse padding flags in combine kernels * Rename variables * Use OaccDataType in BlockFmhaSplitKVCombinePipelineTileSizes<> * Remove unnecessary static assertion * Fix function name typo * Add constraint on kN1 template parameter * Hide K tile loading latency in earlier iteration * Fix wrong splitkv kernel name * Use s_shuffling to replace p_shuffling which removes the needs of cross-warp reduction * Rename pipeline * Fix wrong pipeline name attribute * Add GetAlignmentQ() for NWarpSShuffle pipeline * Separate Q tile into dram tile & register tile concepts * Remove non-squre warp gemm transpose c type alias * Fallback tile size changes for fmha fwd splitkv * Remove redundant change * Refine naming for the S tile * Use better naming of the S tile dstr (read from lds) * Share Q lds with K lds * Tiny change * Fix with using static_for for passing CI checking --------- Co-authored-by: Qianfeng Zhang <[email protected]>
- Loading branch information
Showing
23 changed files
with
1,987 additions
and
272 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.