From da0295be1eca97ce53d073835442d4e2f68e24f0 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Thu, 19 Dec 2024 17:25:38 -0800 Subject: [PATCH] Promote latest CK develop (#1767) * updated fp16 instances to be on parity with universal gemm instances (#1754) * updated fp16 instances to be on parity with universal gemm instances * corrected instance name to streamk instance * [CK_TILE] Move hipmalloc/memcpy calls out of gpu reference gemm (#1743) * [CK_TILE] Move hipmalloc/memcpy calls out of gpu reference gemm * [CK_TILE] Move hipmalloc/memcpy calls out of gpu reference gemm - review changes * [CK_TILE] Move hipmalloc/memcpy calls out of gpu reference gemm - review fix * Disambiguate bit_cast (#1749) Adding namespace to disambiguate with std::bit_cast Co-authored-by: Po Yen Chen * [CK TILE] Refactor GemmKernel to be reused by other GEMM related operators (#1730) * Gemm Kernel Refactor part1 * Gemm Kernel Refactor common gemm pipeline part2 * [CK TILE] Refactor batched gemm to reuse GemmKernel * [CK TILE] Refactor GemmKernel - review changes part1 * [CK TILE] Refactor GemmKernel - references fix * [CK TILE] Refactor GemmKernel - naming changes, add problem * [CK_TILE] Refactor GemmKernel - update tests * [CK_TILE] Refactor GemmKernel - review changes * [CK_TILE] Refactor GemmKernel - update test * [CK_TILE] Refactor GemmKernel - constness fixes * [CK_TILE] Refactor GemmKernel - update tests * Apply Ck-tile argument parser for vectors [I/O] (#1758) * Parser for a vector was added. Additionaly we valid correctnes of numbers * Remove unnecessary comments * Review part 1 * Review part 2 * Add const to variadic lambda * Rename C->K * fix profiler_grouped_gemm (#1766) --------- Co-authored-by: Harisankar Sadasivan <135730918+hsadasiv@users.noreply.github.com> Co-authored-by: aledudek Co-authored-by: Xiaodong Wang Co-authored-by: Po Yen Chen Co-authored-by: Mateusz Ozga <110818320+mozga-amd@users.noreply.github.com> --- example/ck_tile/03_gemm/gemm_basic.cpp | 16 +- example/ck_tile/03_gemm/gemm_basic.hpp | 16 +- example/ck_tile/03_gemm/run_gemm_example.inc | 39 ++- .../ck_tile/16_batched_gemm/batched_gemm.cpp | 6 +- .../ck_tile/16_batched_gemm/batched_gemm.hpp | 6 +- .../run_batched_gemm_example.inc | 35 ++- .../ck_tile/17_grouped_gemm/grouped_gemm.hpp | 20 +- .../run_grouped_gemm_example.inc | 34 ++- .../core/container/meta_data_buffer.hpp | 6 +- include/ck_tile/host/arg_parser.hpp | 46 ++- .../ck_tile/host/reference/reference_gemm.hpp | 162 +---------- .../ops/gemm/kernel/batched_gemm_kernel.hpp | 274 +++++------------- .../ck_tile/ops/gemm/kernel/gemm_kernel.hpp | 259 ++++++++++++----- ...universal_streamk_f16_f16_f16_mk_kn_mn.hpp | 18 +- ...universal_streamk_f16_f16_f16_mk_nk_mn.hpp | 29 +- .../profiler/profile_grouped_gemm_impl.hpp | 2 +- .../batched_gemm/test_batched_gemm_util.hpp | 42 ++- test/ck_tile/gemm/test_gemm_pipeline_util.hpp | 40 +-- 18 files changed, 489 insertions(+), 561 deletions(-) mode change 100644 => 100755 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp mode change 100644 => 100755 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp index f5260c306e..4c630375f4 100644 --- a/example/ck_tile/03_gemm/gemm_basic.cpp +++ b/example/ck_tile/03_gemm/gemm_basic.cpp @@ -15,7 +15,7 @@ #include "gemm_basic.hpp" template -float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) +float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) { // The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part. constexpr bool kPadM = false; @@ -79,17 +79,9 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy. using Kernel = ck_tile::GemmKernel; - auto kargs = Kernel::MakeKargs(args.p_a, - args.p_b, - args.p_c, - args.M, - args.N, - args.K, - args.stride_A, - args.stride_B, - args.stride_C); - - const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch); + auto kargs = Kernel::MakeKernelArgs(args); + + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); constexpr dim3 blocks = Kernel::BlockSize(); if(!Kernel::IsSupportedArgument(kargs)) diff --git a/example/ck_tile/03_gemm/gemm_basic.hpp b/example/ck_tile/03_gemm/gemm_basic.hpp index 23e99bc2a8..58cdaea7d8 100644 --- a/example/ck_tile/03_gemm/gemm_basic.hpp +++ b/example/ck_tile/03_gemm/gemm_basic.hpp @@ -51,20 +51,6 @@ using BDataType = Types::BDataType; using AccDataType = Types::AccDataType; using CDataType = Types::CDataType; -struct gemm_basic_args -{ - const void* p_a; - const void* p_b; - void* p_c; - ck_tile::index_t kbatch; - ck_tile::index_t M; - ck_tile::index_t N; - ck_tile::index_t K; - ck_tile::index_t stride_A; - ck_tile::index_t stride_B; - ck_tile::index_t stride_C; -}; - auto create_args(int argc, char* argv[]) { ck_tile::ArgParser arg_parser; @@ -89,4 +75,4 @@ auto create_args(int argc, char* argv[]) } // host API -float gemm_calc(gemm_basic_args args, const ck_tile::stream_config& s); +float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s); diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc index a1fc155775..68df389bfc 100644 --- a/example/ck_tile/03_gemm/run_gemm_example.inc +++ b/example/ck_tile/03_gemm/run_gemm_example.inc @@ -16,11 +16,11 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, int n_warmup, int n_repeat) { - gemm_basic_args args; - args.p_a = a_m_k_dev_buf.GetDeviceBuffer(); - args.p_b = b_k_n_dev_buf.GetDeviceBuffer(); - args.p_c = c_m_n_dev_buf.GetDeviceBuffer(); - args.kbatch = kbatch; + ck_tile::GemmHostArgs args; + args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer(); + args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer(); + args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer(); + args.k_batch = kbatch; args.M = M; args.N = N; args.K = K; @@ -161,14 +161,39 @@ int run_gemm_example_with_layouts(int argc, c_m_n_gpu_ref.SetZero(); c_m_n_gpu_buf_ref.SetZero(); + ADataType* d_A; + BDataType* d_B; + CDataType* d_C; + + ck_tile::hip_check_error(hipMalloc(&d_A, M * K * sizeof(ADataType))); + ck_tile::hip_check_error(hipMalloc(&d_B, N * K * sizeof(BDataType))); + ck_tile::hip_check_error(hipMalloc(&d_C, M * N * sizeof(CDataType))); + + ck_tile::hip_check_error(hipMemcpy(d_A, + a_m_k_dev_buf.GetDeviceBuffer(), + M * K * sizeof(ADataType), + hipMemcpyHostToDevice)); + ck_tile::hip_check_error(hipMemcpy(d_B, + b_k_n_dev_buf.GetDeviceBuffer(), + N * K * sizeof(BDataType), + hipMemcpyHostToDevice)); + ck_tile::reference_gemm_gpu( - a_m_k_dev_buf, b_k_n_dev_buf, c_m_n_gpu_buf_ref, M, N, K, stride_A, stride_B, stride_C); + CLayout>(d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C); + + ck_tile::hip_check_error(hipMemcpy(c_m_n_gpu_buf_ref.GetDeviceBuffer(), + d_C, + M * N * sizeof(CDataType), + hipMemcpyDeviceToHost)); + + ck_tile::hip_check_error(hipFree(d_A)); + ck_tile::hip_check_error(hipFree(d_B)); + ck_tile::hip_check_error(hipFree(d_C)); c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data()); pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_gpu_ref); diff --git a/example/ck_tile/16_batched_gemm/batched_gemm.cpp b/example/ck_tile/16_batched_gemm/batched_gemm.cpp index bfdd74126e..9b4ed9a9e7 100644 --- a/example/ck_tile/16_batched_gemm/batched_gemm.cpp +++ b/example/ck_tile/16_batched_gemm/batched_gemm.cpp @@ -16,7 +16,7 @@ #include "batched_gemm.hpp" template -float batched_gemm(const batched_gemm_kargs& args, const ck_tile::stream_config& s) +float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stream_config& s) { // The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part. constexpr bool kPadM = false; @@ -79,9 +79,9 @@ float batched_gemm(const batched_gemm_kargs& args, const ck_tile::stream_config& // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy. using Kernel = ck_tile::BatchedGemmKernel; - auto kargs = Kernel::MakeKargs(args); + auto kargs = Kernel::MakeKernelArgs(args); - const dim3 grids = Kernel::GridSize(args); + const dim3 grids = Kernel::GridSize(args.M, args.N, args.batch_count); constexpr dim3 blocks = Kernel::BlockSize(); if(s.log_level_ > 0) diff --git a/example/ck_tile/16_batched_gemm/batched_gemm.hpp b/example/ck_tile/16_batched_gemm/batched_gemm.hpp index e252c0f673..f0c0c9efba 100644 --- a/example/ck_tile/16_batched_gemm/batched_gemm.hpp +++ b/example/ck_tile/16_batched_gemm/batched_gemm.hpp @@ -29,10 +29,6 @@ using BDataType = Types::BDataType; using AccDataType = Types::AccDataType; using CDataType = Types::CDataType; -struct batched_gemm_kargs : public ck_tile::BatchedGemmHostArgs -{ -}; - auto create_args(int argc, char* argv[]) { ck_tile::ArgParser arg_parser; @@ -60,4 +56,4 @@ auto create_args(int argc, char* argv[]) } // host API -float batched_gemm(batched_gemm_kargs args, const ck_tile::stream_config& s); +float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stream_config& s); diff --git a/example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc b/example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc index dacca2042e..4e7218b5b1 100644 --- a/example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc +++ b/example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc @@ -20,7 +20,7 @@ float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, int n_warmup, int n_repeat) { - batched_gemm_kargs args; + ck_tile::BatchedGemmHostArgs args; args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer(); args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer(); args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer(); @@ -188,15 +188,33 @@ int run_batched_gemm_example_with_layouts(int argc, c_m_n_gpu_ref.SetZero(); c_m_n_gpu_buf_ref.SetZero(); + ADataType* d_A; + BDataType* d_B; + CDataType* d_C; + + ck_tile::hip_check_error(hipMalloc(&d_A, batch_count * M * K * sizeof(ADataType))); + ck_tile::hip_check_error(hipMalloc(&d_B, batch_count * N * K * sizeof(BDataType))); + ck_tile::hip_check_error(hipMalloc(&d_C, batch_count * M * N * sizeof(CDataType))); + + ck_tile::hip_check_error(hipMemcpy(d_A, + a_m_k_dev_buf.GetDeviceBuffer(), + batch_count * M * K * sizeof(ADataType), + hipMemcpyHostToDevice)); + + ck_tile::hip_check_error(hipMemcpy(d_B, + b_k_n_dev_buf.GetDeviceBuffer(), + batch_count * N * K * sizeof(BDataType), + hipMemcpyHostToDevice)); + ck_tile::reference_batched_gemm_gpu(a_m_k_dev_buf, - b_k_n_dev_buf, - c_m_n_gpu_buf_ref, + CLayout>(d_A, + d_B, + d_C, M, N, K, @@ -208,6 +226,15 @@ int run_batched_gemm_example_with_layouts(int argc, batch_stride_C, batch_count); + ck_tile::hip_check_error(hipMemcpy(c_m_n_gpu_buf_ref.GetDeviceBuffer(), + d_C, + batch_count * M * N * sizeof(CDataType), + hipMemcpyDeviceToHost)); + + ck_tile::hip_check_error(hipFree(d_A)); + ck_tile::hip_check_error(hipFree(d_B)); + ck_tile::hip_check_error(hipFree(d_C)); + c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data()); pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_gpu_ref); diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp index 94af4711d1..20ba740884 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp @@ -34,13 +34,19 @@ using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs; auto create_args(int argc, char* argv[]) { ck_tile::ArgParser arg_parser; - arg_parser.insert("a_layout", "R", "A tensor data layout - Row by default") - .insert("b_layout", "R", "B tensor data layout - Row by default") - .insert("c_layout", "R", "C tensor data layout - Row by default") - .insert("validate", "1", "0. No validation, 1. Validation on CPU") - .insert("warmup", "10", "number of iterations before benchmark the kernel") - .insert("repeat", "100", "number of iterations to benchmark the kernel") - .insert("group_count", "16", "group count"); + arg_parser.insert("Ms", "", "M dimensions - empty by default.") + .insert("Ns", "", "N dimensions - empty by default.") + .insert("Ks", "", "K dimensions - empty by default.") + .insert("stride_As", "", "Tensor A strides - it is empty by default.") + .insert("stride_Bs", "", "Tensor B strides - it is empty by default.") + .insert("stride_Cs", "", "Tensor C strides - it is empty by default.") + .insert("a_layout", "R", "A tensor data layout - Row by default.") + .insert("b_layout", "R", "B tensor data layout - Row by default.") + .insert("c_layout", "R", "C tensor data layout - Row by default.") + .insert("validate", "1", "0. No validation, 1. Validation on CPU.") + .insert("warmup", "10", "number of iterations before benchmark the kernel.") + .insert("repeat", "100", "number of iterations to benchmark the kernel.") + .insert("group_count", "16", "group count."); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); diff --git a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc index cd5b1c2864..11faa6642c 100644 --- a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc +++ b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc @@ -53,26 +53,34 @@ int run_grouped_gemm_example_with_layouts(int argc, return -1; }; + auto valid_input_data = [&](int group_count, const auto&... args) { + return !(args.empty() || ...) && group_count == (args.size() == ...); + }; + const int group_count = arg_parser.get_int("group_count"); const int repeat = arg_parser.get_int("repeat"); const int warmup = arg_parser.get_int("warmup"); - std::vector Ms; - std::vector Ns; - std::vector Ks; - std::vector stride_As; - std::vector stride_Bs; - std::vector stride_Cs; + std::vector Ms = arg_parser.get_int_vec("Ms"); + std::vector Ns = arg_parser.get_int_vec("Ns"); + std::vector Ks = arg_parser.get_int_vec("Ks"); + std::vector stride_As = arg_parser.get_int_vec("stride_As"); + std::vector stride_Bs = arg_parser.get_int_vec("stride_Bs"); + std::vector stride_Cs = arg_parser.get_int_vec("stride_Cs"); - for(int i = 0; i < group_count; i++) + if(!valid_input_data(group_count, Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs)) { - Ms.push_back(256 + 256 * i); - Ns.push_back(128 + 128 * i); - Ks.push_back(128 + 64 * i); + std::cout << "Please check the input data. Default values will be used." << std::endl; + for(int i = 0; i < group_count; i++) + { + Ms.push_back(256 + 256 * i); + Ns.push_back(128 + 128 * i); + Ks.push_back(128 + 64 * i); - stride_As.push_back(Ks[i]); - stride_Bs.push_back(Ks[i]); - stride_Cs.push_back(Ns[i]); + stride_As.push_back(Ks[i]); + stride_Bs.push_back(Ks[i]); + stride_Cs.push_back(Ns[i]); + } } std::vector> a_m_k_tensors; diff --git a/include/ck_tile/core/container/meta_data_buffer.hpp b/include/ck_tile/core/container/meta_data_buffer.hpp index 7493b93d80..eba60fac75 100644 --- a/include/ck_tile/core/container/meta_data_buffer.hpp +++ b/include/ck_tile/core/container/meta_data_buffer.hpp @@ -30,7 +30,7 @@ struct meta_data_buffer { constexpr index_t size = sizeof(T); - auto tmp = bit_cast>(data); + auto tmp = ck_tile::bit_cast>(data); for(int i = 0; i < size; i++) { @@ -66,7 +66,7 @@ struct meta_data_buffer pos++; } - data = bit_cast(tmp); + data = ck_tile::bit_cast(tmp); } return data; @@ -86,7 +86,7 @@ struct meta_data_buffer pos++; } - auto data = bit_cast(tmp); + auto data = ck_tile::bit_cast(tmp); return data; } diff --git a/include/ck_tile/host/arg_parser.hpp b/include/ck_tile/host/arg_parser.hpp index 3765156df0..df309f312a 100644 --- a/include/ck_tile/host/arg_parser.hpp +++ b/include/ck_tile/host/arg_parser.hpp @@ -15,11 +15,14 @@ namespace ck_tile { /* - * a host side utility, arg parser for - * -[key0]=[value0] -[key1]=[value1] ... + * a host side utility, arg parser for, either + * -[key0] = [value0, value1, value2] + * or + * -[key0]=[value0] -[key1]=[value1] ... */ class ArgParser { + public: class Arg { @@ -187,6 +190,45 @@ class ArgParser return value; } + std::vector get_string_vec(const std::string& name, + const std::string& delimiter = ",") const + { + if(get_str(name).empty()) + { + return {}; + } + std::string s = get_str(name); + std::vector tokens; + size_t pos = 0; + std::string token; + while((pos = s.find(delimiter)) != std::string::npos) + { + token = s.substr(0, pos); + tokens.push_back(token); + s.erase(0, pos + delimiter.length()); + } + tokens.push_back(s); + + return tokens; + } + + std::vector get_int_vec(const std::string& name, const std::string& delimiter = ",") const + { + if(get_str(name).empty()) + { + return {}; + } + const std::vector args = get_string_vec(name, delimiter); + std::vector tokens; + tokens.reserve(static_cast(args.size())); + for(const std::string& token : args) + { + int value = atoi(token.c_str()); + tokens.push_back(value); + } + return tokens; + } + private: std::unordered_map input_map; std::vector keys; diff --git a/include/ck_tile/host/reference/reference_gemm.hpp b/include/ck_tile/host/reference/reference_gemm.hpp index 8bd1f5b048..fc412e8831 100644 --- a/include/ck_tile/host/reference/reference_gemm.hpp +++ b/include/ck_tile/host/reference/reference_gemm.hpp @@ -97,9 +97,9 @@ template -void reference_gemm_gpu(DeviceMem& a_device, - DeviceMem& b_device, - DeviceMem& c_device, +void reference_gemm_gpu(ADataType* a_ptr, + BDataType* b_ptr, + CDataType* c_ptr, index_t M, index_t N, index_t K, @@ -107,79 +107,13 @@ void reference_gemm_gpu(DeviceMem& a_device, index_t stride_b, index_t stride_c) { - - ADataType* d_A; - BDataType* d_B; - CDataType* d_C; - - hipError_t errA = hipMalloc(&d_A, M * K * sizeof(ADataType)); - hipError_t errB = hipMalloc(&d_B, N * K * sizeof(BDataType)); - hipError_t errC = hipMalloc(&d_C, M * N * sizeof(CDataType)); - if(errA != hipSuccess) - { - std::cerr << "Error allocating device memory for A: " << hipGetErrorString(errA) - << std::endl; - return; // Early exit on error - } - - if(errB != hipSuccess) - { - std::cerr << "Error allocating device memory for B: " << hipGetErrorString(errB) - << std::endl; - return; // Early exit on error - } - - if(errC != hipSuccess) - { - std::cerr << "Error allocating device memory for C: " << hipGetErrorString(errC) - << std::endl; - return; // Early exit on error - } - - errA = hipMemcpy( - d_A, a_device.GetDeviceBuffer(), M * K * sizeof(ADataType), hipMemcpyHostToDevice); - if(errA != hipSuccess) - { - std::cerr << "Error copying A to device: " << hipGetErrorString(errA) << std::endl; - } - - errB = hipMemcpy( - d_B, b_device.GetDeviceBuffer(), N * K * sizeof(BDataType), hipMemcpyHostToDevice); - if(errB != hipSuccess) - { - std::cerr << "Error copying B to device: " << hipGetErrorString(errB) << std::endl; - } - int totalElements = M * N; int numThreadsPerBlock = 256; // Common choice for threads per block int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock; naive_gemm_kernel - <<>>(d_A, d_B, d_C, M, N, K, stride_a, stride_b, stride_c); - errC = hipMemcpy( - c_device.GetDeviceBuffer(), d_C, M * N * sizeof(CDataType), hipMemcpyDeviceToHost); - if(errC != hipSuccess) - { - std::cerr << "Error copying C to device: " << hipGetErrorString(errC) << std::endl; - } - - errA = hipFree(d_A); - if(errA != hipSuccess) - { - std::cerr << "Error free the A memory: " << hipGetErrorString(errA) << std::endl; - } - - errB = hipFree(d_B); - if(errB != hipSuccess) - { - std::cerr << "Error free the B memory: " << hipGetErrorString(errB) << std::endl; - } - - errC = hipFree(d_C); - if(errC != hipSuccess) - { - std::cerr << "Error free the C memory: " << hipGetErrorString(errC) << std::endl; - } + <<>>( + a_ptr, b_ptr, c_ptr, M, N, K, stride_a, stride_b, stride_c); return; } @@ -191,9 +125,9 @@ template -void reference_batched_gemm_gpu(DeviceMem& a_device, - DeviceMem& b_device, - DeviceMem& c_device, +void reference_batched_gemm_gpu(ADataType* a_ptr, + BDataType* b_ptr, + CDataType* c_ptr, index_t M, index_t N, index_t K, @@ -205,94 +139,20 @@ void reference_batched_gemm_gpu(DeviceMem& a_device, index_t batch_stride_C, index_t batch_count) { - - ADataType* d_A; - BDataType* d_B; - CDataType* d_C; - - hipError_t errA = hipMalloc(&d_A, batch_count * M * K * sizeof(ADataType)); - hipError_t errB = hipMalloc(&d_B, batch_count * N * K * sizeof(BDataType)); - hipError_t errC = hipMalloc(&d_C, batch_count * M * N * sizeof(CDataType)); - if(errA != hipSuccess) - { - std::cerr << "Error allocating device memory for A: " << hipGetErrorString(errA) - << std::endl; - return; // Early exit on error - } - - if(errB != hipSuccess) - { - std::cerr << "Error allocating device memory for B: " << hipGetErrorString(errB) - << std::endl; - return; // Early exit on error - } - - if(errC != hipSuccess) - { - std::cerr << "Error allocating device memory for C: " << hipGetErrorString(errC) - << std::endl; - return; // Early exit on error - } - - errA = hipMemcpy(d_A, - a_device.GetDeviceBuffer(), - batch_count * M * K * sizeof(ADataType), - hipMemcpyHostToDevice); - if(errA != hipSuccess) - { - std::cerr << "Error copying A to device: " << hipGetErrorString(errA) << std::endl; - } - - errB = hipMemcpy(d_B, - b_device.GetDeviceBuffer(), - batch_count * N * K * sizeof(BDataType), - hipMemcpyHostToDevice); - if(errB != hipSuccess) - { - std::cerr << "Error copying B to device: " << hipGetErrorString(errB) << std::endl; - } - int totalElements = M * N; int numThreadsPerBlock = 256; // Common choice for threads per block int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock; for(index_t batch_id = 0; batch_id < batch_count; ++batch_id) { - ADataType* d_ATemp = d_A + batch_id * batch_stride_A; - BDataType* d_BTemp = d_B + batch_id * batch_stride_B; - CDataType* d_CTemp = d_C + batch_id * batch_stride_C; + ADataType* d_ATemp = a_ptr + batch_id * batch_stride_A; + BDataType* d_BTemp = b_ptr + batch_id * batch_stride_B; + CDataType* d_CTemp = c_ptr + batch_id * batch_stride_C; naive_gemm_kernel <<>>( d_ATemp, d_BTemp, d_CTemp, M, N, K, stride_a, stride_b, stride_c); } - errC = hipMemcpy(c_device.GetDeviceBuffer(), - d_C, - batch_count * M * N * sizeof(CDataType), - hipMemcpyDeviceToHost); - if(errC != hipSuccess) - { - std::cerr << "Error copying C to device: " << hipGetErrorString(errC) << std::endl; - } - - errA = hipFree(d_A); - if(errA != hipSuccess) - { - std::cerr << "Error free the A memory: " << hipGetErrorString(errA) << std::endl; - } - - errB = hipFree(d_B); - if(errB != hipSuccess) - { - std::cerr << "Error free the B memory: " << hipGetErrorString(errB) << std::endl; - } - - errC = hipFree(d_C); - if(errC != hipSuccess) - { - std::cerr << "Error free the C memory: " << hipGetErrorString(errC) << std::endl; - } - return; } } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp index 07b4af5730..07a4cf8fbe 100644 --- a/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp @@ -3,90 +3,93 @@ #pragma once -#include -#include - -#include "ck_tile/core.hpp" -#include "ck_tile/ops/common.hpp" +#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp" namespace ck_tile { -struct BatchedGemmHostArgs +struct BatchedGemmHostArgs : public ck_tile::GemmHostArgs { - const void* a_ptr; - const void* b_ptr; - void* c_ptr; - index_t M; - index_t N; - index_t K; - index_t stride_A; - index_t stride_B; - index_t stride_C; - index_t batch_stride_A; - index_t batch_stride_B; - index_t batch_stride_C; - index_t batch_count; + CK_TILE_HOST BatchedGemmHostArgs() = default; + CK_TILE_HOST BatchedGemmHostArgs(const void* a_ptr_, + const void* b_ptr_, + void* c_ptr_, + ck_tile::index_t k_batch_, + ck_tile::index_t M_, + ck_tile::index_t N_, + ck_tile::index_t K_, + ck_tile::index_t stride_A_, + ck_tile::index_t stride_B_, + ck_tile::index_t stride_C_, + ck_tile::index_t batch_stride_A_, + ck_tile::index_t batch_stride_B_, + ck_tile::index_t batch_stride_C_, + ck_tile::index_t batch_count_) + : GemmHostArgs( + a_ptr_, b_ptr_, c_ptr_, k_batch_, M_, N_, K_, stride_A_, stride_B_, stride_C_), + batch_stride_A(batch_stride_A_), + batch_stride_B(batch_stride_B_), + batch_stride_C(batch_stride_C_), + batch_count(batch_count_) + { + } + + ck_tile::index_t batch_stride_A; + ck_tile::index_t batch_stride_B; + ck_tile::index_t batch_stride_C; + ck_tile::index_t batch_count; }; template -struct BatchedGemmKernel +struct BatchedGemmKernel : public GemmKernel { - using TilePartitioner = remove_cvref_t; - using GemmPipeline = remove_cvref_t; - using EpiloguePipeline = remove_cvref_t; - using ALayout = remove_cvref_t; - using BLayout = remove_cvref_t; - using CLayout = remove_cvref_t; - static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize; + using Base = GemmKernel; - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using CDataType = remove_cvref_t; + using GemmKernelArgs = typename Base::GemmKernelArgs; - struct BatchedGemmKargs + using ADataType = typename Base::ADataType; + using BDataType = typename Base::BDataType; + using CDataType = typename Base::CDataType; + + using TilePartitioner = typename Base::TilePartitioner; + using GemmPipeline = typename Base::GemmPipeline; + using EpiloguePipeline = typename Base::EpiloguePipeline; + using ALayout = typename Base::ALayout; + using BLayout = typename Base::BLayout; + using CLayout = typename Base::CLayout; + + struct BatchedGemmKernelArgs : GemmKernelArgs { - const void* a_ptr; - const void* b_ptr; - void* c_ptr; - index_t M; - index_t N; - index_t K; - index_t stride_A; - index_t stride_B; - index_t stride_C; index_t batch_stride_A; index_t batch_stride_B; index_t batch_stride_C; index_t batch_count; }; - using Kargs = BatchedGemmKargs; - using Hargs = BatchedGemmHostArgs; + using KernelArgs = BatchedGemmKernelArgs; - __host__ static constexpr auto GridSize(const Hargs& h) + __host__ static constexpr auto GridSize(index_t M, index_t N, index_t batch_count) { - return TilePartitioner::GridSize(h.M, h.N, h.batch_count); + return TilePartitioner::GridSize(M, N, batch_count); } - __host__ static constexpr auto BlockSize() { return dim3(KernelBlockSize); } + __host__ static constexpr auto BlockSize() { return dim3(Base::KernelBlockSize); } - CK_TILE_HOST static constexpr BatchedGemmKargs MakeKargs(const Hargs& h) + CK_TILE_HOST static constexpr BatchedGemmKernelArgs + MakeKernelArgs(const BatchedGemmHostArgs& hostArgs) { - Kargs k; - k.a_ptr = h.a_ptr; - k.b_ptr = h.b_ptr; - k.c_ptr = h.c_ptr; - k.M = h.M; - k.N = h.N; - k.K = h.K; - k.stride_A = h.stride_A; - k.stride_B = h.stride_B; - k.stride_C = h.stride_C; - k.batch_stride_A = h.batch_stride_A; - k.batch_stride_B = h.batch_stride_B; - k.batch_stride_C = h.batch_stride_C; - k.batch_count = h.batch_count; - return k; + return BatchedGemmKernelArgs{{hostArgs.a_ptr, + hostArgs.b_ptr, + hostArgs.c_ptr, + hostArgs.M, + hostArgs.N, + hostArgs.K, + hostArgs.stride_A, + hostArgs.stride_B, + hostArgs.stride_C}, + hostArgs.batch_stride_A, + hostArgs.batch_stride_B, + hostArgs.batch_stride_C, + hostArgs.batch_count}; } CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() @@ -94,7 +97,7 @@ struct BatchedGemmKernel return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); } - CK_TILE_DEVICE void operator()(Kargs kargs) const + CK_TILE_DEVICE void operator()(BatchedGemmKernelArgs kargs) const { const auto [i_m, i_n] = TilePartitioner{}(); const auto i_batch = __builtin_amdgcn_readfirstlane(blockIdx.z); @@ -102,156 +105,17 @@ struct BatchedGemmKernel // options const auto batch_stride_A = __builtin_amdgcn_readfirstlane(kargs.batch_stride_A); const auto batch_offset_A = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_A); - const ADataType* a_start = static_cast(kargs.a_ptr); + const ADataType* a_ptr = static_cast(kargs.a_ptr) + batch_offset_A; const auto batch_stride_B = __builtin_amdgcn_readfirstlane(kargs.batch_stride_B); const auto batch_offset_B = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_B); - const BDataType* b_start = static_cast(kargs.b_ptr); - - // Convert pointers to tensor views - auto a_tensor_view = [&]() { - if constexpr(std::is_same_v) - { - return make_naive_tensor_view( - a_start + batch_offset_A, - make_tuple(kargs.M, kargs.K), - make_tuple(kargs.stride_A, 1), - number{}, - number<1>{}); - } - else - { - return make_naive_tensor_view( - a_start + batch_offset_A, - make_tuple(kargs.M, kargs.K), - make_tuple(1, kargs.stride_A), - number<1>{}, - number<1>{}); - } - }(); - - auto b_tensor_view = [&]() { - if constexpr(std::is_same_v) - { - return make_naive_tensor_view( - b_start + batch_offset_B, - make_tuple(kargs.N, kargs.K), - make_tuple(1, kargs.stride_B), - number<1>{}, - number<1>{}); - } - else - { - return make_naive_tensor_view( - b_start + batch_offset_B, - make_tuple(kargs.N, kargs.K), - make_tuple(kargs.stride_B, 1), - number{}, - number<1>{}); - } - }(); - - auto a_pad_view = [&]() { - if constexpr(std::is_same_v) - { - return pad_tensor_view( - a_tensor_view, - make_tuple(number{}, number{}), - sequence{}); - } - else - { - return pad_tensor_view( - a_tensor_view, - make_tuple(number{}, number{}), - sequence{}); - } - }(); - // clang-format on - - auto a_block_window = make_tile_window( - a_pad_view, - make_tuple(number{}, number{}), - {i_m, 0}); - - auto b_pad_view = [&]() { - if constexpr(std::is_same_v) - { - return pad_tensor_view( - b_tensor_view, - make_tuple(number{}, number{}), - sequence{}); - } - else - { - return pad_tensor_view( - b_tensor_view, - make_tuple(number{}, number{}), - sequence{}); - } - }(); - // clang-format on - - auto b_block_window = make_tile_window( - b_pad_view, - make_tuple(number{}, number{}), - {i_n, 0}); - - // allocate LDS - __shared__ char smem_ptr[GetSmemSize()]; - - const index_t num_loop = TilePartitioner::GetLoopNum(kargs.K); - - // Run GEMM cooperatively by whole wokrgroup. - auto c_block_tile = - GemmPipeline{}.template operator()(a_block_window, b_block_window, num_loop, smem_ptr); + const BDataType* b_ptr = static_cast(kargs.b_ptr) + batch_offset_B; const auto batch_stride_C = __builtin_amdgcn_readfirstlane(kargs.batch_stride_C); const auto batch_offset_C = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_C); - CDataType* c_start = static_cast(kargs.c_ptr); - auto c_tensor_view = [&]() { - if constexpr(std::is_same_v) - { - return make_naive_tensor_view( - c_start + batch_offset_C, - make_tuple(kargs.M, kargs.N), - make_tuple(kargs.stride_C, 1), - number{}, - number<1>{}); - } - else - { - return make_naive_tensor_view( - c_start + batch_offset_C, - make_tuple(kargs.M, kargs.N), - make_tuple(1, kargs.stride_C), - number<1>{}, - number<1>{}); - } - }(); - - auto c_pad_view = [&]() { - if constexpr(std::is_same_v) - { - return pad_tensor_view( - c_tensor_view, - make_tuple(number{}, number{}), - sequence{}); - } - else - { - return pad_tensor_view( - c_tensor_view, - make_tuple(number{}, number{}), - sequence{}); - } - }(); - auto c_block_window = make_tile_window( - c_pad_view, - make_tuple(number{}, number{}), - {i_m, i_n}); + CDataType* c_ptr = static_cast(kargs.c_ptr) + batch_offset_C; - EpiloguePipeline{}(c_block_window, c_block_tile); + this->RunGemm(a_ptr, b_ptr, c_ptr, kargs, i_m, i_n); } }; diff --git a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp index 763d8cad9c..925648a886 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp @@ -12,6 +12,50 @@ namespace ck_tile { +struct GemmProblem +{ + CK_TILE_HOST GemmProblem() = default; + CK_TILE_HOST GemmProblem( + index_t M_, index_t N_, index_t K_, index_t stride_A_, index_t stride_B_, index_t stride_C_) + : M(M_), N(N_), K(K_), stride_A(stride_A_), stride_B(stride_B_), stride_C(stride_C_) + { + } + + index_t M; + index_t N; + index_t K; + index_t stride_A; + index_t stride_B; + index_t stride_C; +}; + +struct GemmHostArgs : public GemmProblem +{ + CK_TILE_HOST GemmHostArgs() = default; + CK_TILE_HOST GemmHostArgs(const void* a_ptr_, + const void* b_ptr_, + void* c_ptr_, + index_t k_batch_, + index_t M_, + index_t N_, + index_t K_, + index_t stride_A_, + index_t stride_B_, + index_t stride_C_) + : GemmProblem(M_, N_, K_, stride_A_, stride_B_, stride_C_), + a_ptr(a_ptr_), + b_ptr(b_ptr_), + c_ptr(c_ptr_), + k_batch(k_batch_) + { + } + + const void* a_ptr; + const void* b_ptr; + void* c_ptr; + index_t k_batch; +}; + template struct GemmKernel { @@ -25,9 +69,12 @@ struct GemmKernel using ADataType = remove_cvref_t; using BDataType = remove_cvref_t; - // using CAccDataType = remove_cvref_t; using CDataType = remove_cvref_t; + static constexpr auto I0 = number<0>(); + static constexpr auto I1 = number<1>(); + static constexpr auto I2 = number<2>(); + __host__ static constexpr auto GridSize(index_t M, index_t N, index_t KBatch) { return TilePartitioner::GridSize(M, N, KBatch); @@ -35,7 +82,7 @@ struct GemmKernel __host__ static constexpr auto BlockSize() { return dim3(KernelBlockSize); } - struct GemmCommonKargs + struct GemmKernelArgs { const void* a_ptr; const void* b_ptr; @@ -48,25 +95,37 @@ struct GemmKernel index_t stride_C; }; - CK_TILE_HOST static constexpr GemmCommonKargs MakeKargs(const void* a_ptr, - const void* b_ptr, - void* c_ptr, - index_t M, - index_t N, - index_t K, - index_t stride_A, - index_t stride_B, - index_t stride_C) + CK_TILE_HOST static constexpr GemmKernelArgs MakeKernelArgs(const GemmHostArgs& hostArgs) { - return GemmCommonKargs{a_ptr, b_ptr, c_ptr, M, N, K, stride_A, stride_B, stride_C}; + return GemmKernelArgs{hostArgs.a_ptr, + hostArgs.b_ptr, + hostArgs.c_ptr, + hostArgs.M, + hostArgs.N, + hostArgs.K, + hostArgs.stride_A, + hostArgs.stride_B, + hostArgs.stride_C}; } + // CK_TILE_HOST static constexpr GemmKernelArgs MakeKernelArgs(const void* a_ptr, + // const void* b_ptr, + // void* c_ptr, + // index_t M, + // index_t N, + // index_t K, + // index_t stride_A, + // index_t stride_B, + // index_t stride_C) + // { + // return GemmKernelArgs{a_ptr, b_ptr, c_ptr, M, N, K, stride_A, stride_B, stride_C}; + // } CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); } - CK_TILE_HOST static bool IsSupportedArgument(const GemmCommonKargs& kargs) + CK_TILE_HOST static bool IsSupportedArgument(const GemmKernelArgs& kargs) { if constexpr(std::is_same_v) { @@ -139,18 +198,16 @@ struct GemmKernel return true; } - CK_TILE_DEVICE void operator()(GemmCommonKargs kargs) const + CK_TILE_DEVICE auto MakeGemmTensorViews(const ADataType* a_ptr, + const BDataType* b_ptr, + CDataType* c_ptr, + const GemmKernelArgs& kargs) const { - const auto [i_m, i_n] = TilePartitioner{}(); - // options - const ADataType* a_start = static_cast(kargs.a_ptr); - const BDataType* b_start = static_cast(kargs.b_ptr); - // Convert pointers to tensor views - auto a_tensor_view = [&]() { + const auto& a_tensor_view = [&]() { if constexpr(std::is_same_v) { return make_naive_tensor_view( - a_start, + a_ptr, make_tuple(kargs.M, kargs.K), make_tuple(kargs.stride_A, 1), number{}, @@ -159,7 +216,7 @@ struct GemmKernel else { return make_naive_tensor_view( - a_start, + a_ptr, make_tuple(kargs.M, kargs.K), make_tuple(1, kargs.stride_A), number<1>{}, @@ -167,11 +224,11 @@ struct GemmKernel } }(); - auto b_tensor_view = [&]() { + const auto& b_tensor_view = [&]() { if constexpr(std::is_same_v) { return make_naive_tensor_view( - b_start, + b_ptr, make_tuple(kargs.N, kargs.K), make_tuple(1, kargs.stride_B), number<1>{}, @@ -180,7 +237,7 @@ struct GemmKernel else { return make_naive_tensor_view( - b_start, + b_ptr, make_tuple(kargs.N, kargs.K), make_tuple(kargs.stride_B, 1), number{}, @@ -188,7 +245,35 @@ struct GemmKernel } }(); - auto a_pad_view = [&]() { + const auto& c_tensor_view = [&]() { + if constexpr(std::is_same_v) + { + return make_naive_tensor_view( + c_ptr, + make_tuple(kargs.M, kargs.N), + make_tuple(kargs.stride_C, 1), + number{}, + number<1>{}); + } + else + { + return make_naive_tensor_view( + c_ptr, + make_tuple(kargs.M, kargs.N), + make_tuple(1, kargs.stride_C), + number<1>{}, + number<1>{}); + } + }(); + + return make_tuple(a_tensor_view, b_tensor_view, c_tensor_view); + } + + template + CK_TILE_DEVICE auto MakeGemmPadViews(const TensorView& views) const + { + const auto& a_pad_view = [&]() { + const auto& a_tensor_view = views.at(I0); if constexpr(std::is_same_v) { return pad_tensor_view( @@ -204,14 +289,9 @@ struct GemmKernel sequence{}); } }(); - // clang-format on - - auto a_block_window = make_tile_window( - a_pad_view, - make_tuple(number{}, number{}), - {i_m, 0}); - auto b_pad_view = [&]() { + const auto& b_pad_view = [&]() { + const auto& b_tensor_view = views.at(I1); if constexpr(std::is_same_v) { return pad_tensor_view( @@ -228,43 +308,8 @@ struct GemmKernel } }(); - auto b_block_window = make_tile_window( - b_pad_view, - make_tuple(number{}, number{}), - {i_n, 0}); - - // allocate LDS - __shared__ char smem_ptr[GetSmemSize()]; - - const index_t num_loop = TilePartitioner::GetLoopNum(kargs.K); - - // Run GEMM cooperatively by whole wokrgroup. - auto c_block_tile = - GemmPipeline{}.template operator()(a_block_window, b_block_window, num_loop, smem_ptr); - - CDataType* c_start = static_cast(kargs.c_ptr); - auto c_tensor_view = [&]() { - if constexpr(std::is_same_v) - { - return make_naive_tensor_view( - c_start, - make_tuple(kargs.M, kargs.N), - make_tuple(kargs.stride_C, 1), - number{}, - number<1>{}); - } - else - { - return make_naive_tensor_view( - c_start, - make_tuple(kargs.M, kargs.N), - make_tuple(1, kargs.stride_C), - number<1>{}, - number<1>{}); - } - }(); - - auto c_pad_view = [&]() { + const auto& c_pad_view = [&]() { + const auto& c_tensor_view = views.at(I2); if constexpr(std::is_same_v) { return pad_tensor_view( @@ -280,12 +325,82 @@ struct GemmKernel sequence{}); } }(); - auto CBlockWindow_pad = make_tile_window( + + return make_tuple(a_pad_view, b_pad_view, c_pad_view); + } + + template + CK_TILE_DEVICE auto + MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n) const + { + const auto& a_pad_view = views.at(I0); + const auto& a_block_window = make_tile_window( + a_pad_view, + make_tuple(number{}, number{}), + {i_m, 0}); + + const auto& b_pad_view = views.at(I1); + const auto& b_block_window = make_tile_window( + b_pad_view, + make_tuple(number{}, number{}), + {i_n, 0}); + + const auto& c_pad_view = views.at(I2); + auto c_block_window = make_tile_window( c_pad_view, make_tuple(number{}, number{}), {i_m, i_n}); - EpiloguePipeline{}(CBlockWindow_pad, c_block_tile); + return make_tuple(a_block_window, b_block_window, c_block_window); + } + + /** + * @brief Runs single GEMM problem cooperatively by whole workgroup. + * + * @param a_ptr input A pointer + * @param b_ptr input B pointer + * @param c_ptr output C pointer + * @param kargs GEMM kernel arguments + * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. + * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. + */ + CK_TILE_DEVICE void RunGemm(const ADataType* a_ptr, + const BDataType* b_ptr, + CDataType* c_ptr, + const GemmKernelArgs& kargs, + const index_t block_idx_m, + const index_t block_idx_n) const + { + // Create Gemm tensor views, pad views and tile windows + const auto& gemm_tensor_views_tuple = MakeGemmTensorViews(a_ptr, b_ptr, c_ptr, kargs); + const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); + auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); + + // allocate LDS + __shared__ char smem_ptr[GetSmemSize()]; + + const index_t num_loop = TilePartitioner::GetLoopNum(kargs.K); + + // Run GEMM cooperatively by whole workgroup. + const auto& a_block_window = gemm_tile_windows.at(I0); + const auto& b_block_window = gemm_tile_windows.at(I1); + const auto& c_block_tile = + GemmPipeline{}.template operator()(a_block_window, b_block_window, num_loop, smem_ptr); + + // Run Epilogue Pipeline + auto& c_block_window = gemm_tile_windows.at(I2); + EpiloguePipeline{}(c_block_window, c_block_tile); + } + + CK_TILE_DEVICE void operator()(GemmKernelArgs kargs) const + { + const auto [i_m, i_n] = TilePartitioner{}(); + // options + const ADataType* a_ptr = static_cast(kargs.a_ptr); + const BDataType* b_ptr = static_cast(kargs.b_ptr); + CDataType* c_ptr = static_cast(kargs.c_ptr); + + RunGemm(a_ptr, b_ptr, c_ptr, kargs, i_m, i_n); } }; diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp old mode 100644 new mode 100755 index 6e8d5c798b..5460f7f857 --- a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp @@ -41,6 +41,8 @@ using device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_instances = st //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 4, 4, 32, 32, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 2, 2, 32, 32, 4, 4, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, @@ -49,7 +51,9 @@ using device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_instances = st DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 32, 8, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, - DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 2, 2, 32, 32, 2, 2, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> // clang-format on >; @@ -61,14 +65,21 @@ using device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_instances = std //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - // Latency friendly + // Latency friendly DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 8, 4, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 4, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 2, 2, 16, 16, 1, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 4, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 4, 4, 16, 16, 1, 1, S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<32, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 8, 4, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 4, 4, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 2, 2, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 8, 4, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, // Memory friendly DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 32, 64, 8, 2, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 32, 64, 2, 2, 32, 32, 2, 1, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 32, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 64, 8, 2, 16, 16, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 32, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 64, 2, 2, 16, 16, 4, 1, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 32, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 32, 64, 8, 4, 32, 32, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 64, 8, 4, 16, 16, 4, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 32, 64, 8, 4, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, @@ -82,6 +93,7 @@ using device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_instances = std DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 64, 8, 4, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 64, 8, 4, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 8, 4, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 4, 4, 16, 16, 1, 4, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 256, 64, 8, 4, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp old mode 100644 new mode 100755 index e00c1733e0..e716b3e85c --- a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp @@ -42,14 +42,21 @@ using device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_instances = st // Compute friendly DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 4, 4, 32, 32, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 4, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 2, 2, 32, 32, 2, 2, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 4, 4, 32, 32, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 2, 2, 32, 32, 4, 4, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 4, 4, 32, 32, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 2, 2, 32, 32, 4, 4, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 16, 16, 8, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, - // AGPR Spill - // DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 16, 16, 8, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, - // AGPR Spill when use permuted lds layout. so, use padding for these two. + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 4, 4, 16, 16, 8, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 2, 2, 16, 16, 8, 8, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 8, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 224, 64, 8, 8, 16, 16, 8, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 2, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, @@ -68,15 +75,23 @@ using device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_instances = std //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - // Latency friendly + // Latency friendly DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 2, 2, 16, 16, 1, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 8, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 2, 2, 16, 16, 1, 1, S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, // Memory friendly DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 32, 64, 8, 8, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 32, 64, 4, 4, 32, 32, 2, 1, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 32, 64, 2, 2, 32, 32, 2, 1, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, 1, 1, S<1, 32, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 64, 8, 8, 16, 16, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 32, 64, 8, 8, 32, 32, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 32, 64, 4, 4, 32, 32, 2, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 32, 64, 2, 2, 32, 32, 2, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 64, 8, 8, 16, 16, 4, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 32, 64, 8, 8, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 16, 64, 8, 8, 16, 16, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, @@ -84,12 +99,16 @@ using device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_instances = std DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 8, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 4, 4, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 2, 2, 16, 16, 1, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 64, 8, 8, 16, 16, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 64, 64, 8, 8, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 64, 8, 8, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 64, 8, 8, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 8, 8, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 256, 64, 8, 8, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 256, 64, 8, 8, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 256, 64, 4, 4, 32, 32, 1, 2, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 256, 64, 2, 2, 32, 32, 1, 2, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> // clang-format on >; } // namespace instance diff --git a/profiler/include/profiler/profile_grouped_gemm_impl.hpp b/profiler/include/profiler/profile_grouped_gemm_impl.hpp index c10cd0ea9f..367e94de11 100644 --- a/profiler/include/profiler/profile_grouped_gemm_impl.hpp +++ b/profiler/include/profiler/profile_grouped_gemm_impl.hpp @@ -77,7 +77,7 @@ bool profile_grouped_gemm_impl(int do_verification, std::vector> c_m_n_host_results; std::vector> c_m_n_device_results; - ComputeDataType max_abs_in_val = 0.f; + double max_abs_in_val = 0.f; for(std::size_t i = 0; i < group_count; i++) { a_m_k.push_back( diff --git a/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp b/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp index 88145b987b..d3f3077870 100644 --- a/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp +++ b/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp @@ -24,12 +24,9 @@ class TestCkTileBatchedGemm : public ::testing::Test using AccDataType = std::tuple_element_t<5, Tuple>; using CDataType = std::tuple_element_t<6, Tuple>; - struct batched_gemm_kargs : public ck_tile::BatchedGemmHostArgs - { - }; - template - void invoke_batched_gemm(const batched_gemm_kargs& args, const ck_tile::stream_config& s) + void invoke_batched_gemm(const ck_tile::BatchedGemmHostArgs& args, + const ck_tile::stream_config& s) { // The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part. constexpr bool kPadM = false; @@ -94,9 +91,9 @@ class TestCkTileBatchedGemm : public ::testing::Test using Kernel = ck_tile::BatchedGemmKernel; - auto kargs = Kernel::MakeKargs(args); + auto kargs = Kernel::MakeKernelArgs(args); - const dim3 grids = Kernel::GridSize(args); + const dim3 grids = Kernel::GridSize(args.M, args.N, args.batch_count); constexpr dim3 blocks = Kernel::BlockSize(); if(s.log_level_ > 0) @@ -185,21 +182,22 @@ class TestCkTileBatchedGemm : public ::testing::Test c_m_n_dev_buf.SetZero(); c_m_n_dev_result.SetZero(); - batched_gemm_kargs kargs{a_m_k_dev_buf.GetDeviceBuffer(), - b_k_n_dev_buf.GetDeviceBuffer(), - c_m_n_dev_buf.GetDeviceBuffer(), - M, - N, - K, - StrideA, - StrideB, - StrideC, - BatchStrideA, - BatchStrideB, - BatchStrideC, - BatchCount}; - - invoke_batched_gemm(kargs, + ck_tile::BatchedGemmHostArgs args; + args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer(); + args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer(); + args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer(); + args.M = M; + args.N = N; + args.K = K; + args.stride_A = StrideA; + args.stride_B = StrideB; + args.stride_C = StrideC; + args.batch_stride_A = BatchStrideA; + args.batch_stride_B = BatchStrideB; + args.batch_stride_C = BatchStrideC; + args.batch_count = BatchCount; + + invoke_batched_gemm(args, ck_tile::stream_config{nullptr, false}); std::cout << "Run kernel with M =" << M << " N =" << N << " K =" << K diff --git a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp index a514986024..53ead4d8d6 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp @@ -31,22 +31,8 @@ class TestCkTileGemmPipeline : public ::testing::Test static constexpr auto PipelineType = std::tuple_element_t<8, Tuple>::value; // TODO: expose tile size through test t-param ? - struct gemm_args - { - const void* p_a; - const void* p_b; - void* p_c; - ck_tile::index_t kbatch; - ck_tile::index_t M; - ck_tile::index_t N; - ck_tile::index_t K; - ck_tile::index_t stride_A; - ck_tile::index_t stride_B; - ck_tile::index_t stride_C; - }; - template - void invoke_gemm(const gemm_args& args, const ck_tile::stream_config& s) + void invoke_gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) { // TODO: This should be parameterized in tests constexpr ck_tile::index_t M_Tile = 128; @@ -117,17 +103,9 @@ class TestCkTileGemmPipeline : public ::testing::Test has_hot_loop_v, tail_number_v>>>; using Kernel = ck_tile::GemmKernel; - auto kargs = Kernel::MakeKargs(args.p_a, - args.p_b, - args.p_c, - args.M, - args.N, - args.K, - args.stride_A, - args.stride_B, - args.stride_C); - - const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch); + auto kargs = Kernel::MakeKernelArgs(args); + + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); constexpr dim3 blocks = Kernel::BlockSize(); if(!Kernel::IsSupportedArgument(kargs)) @@ -319,11 +297,11 @@ class TestCkTileGemmPipeline : public ::testing::Test c_m_n_dev_buf.SetZero(); c_m_n_dev_result.SetZero(); - gemm_args args; - args.p_a = a_m_k_dev_buf.GetDeviceBuffer(); - args.p_b = b_k_n_dev_buf.GetDeviceBuffer(); - args.p_c = c_m_n_dev_buf.GetDeviceBuffer(); - args.kbatch = kbatch; + ck_tile::GemmHostArgs args; + args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer(); + args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer(); + args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer(); + args.k_batch = kbatch; args.M = M; args.N = N; args.K = K;