Promote latest CK develop (#1767)
* updated fp16 instances to be at parity with universal gemm instances (#1754)

* updated fp16 instances to be at parity with universal gemm instances

* corrected instance name to streamk instance

* [CK_TILE] Move hipmalloc/memcpy calls out of gpu reference gemm (#1743)

* [CK_TILE] Move hipmalloc/memcpy calls out of gpu reference gemm

* [CK_TILE] Move hipmalloc/memcpy calls out of gpu reference gemm - review changes

* [CK_TILE] Move hipmalloc/memcpy calls out of gpu reference gemm - review fix

* Disambiguate bit_cast (#1749)

Adding namespace qualification to disambiguate from std::bit_cast

Co-authored-by: Po Yen Chen <[email protected]>

* [CK TILE] Refactor GemmKernel to be reused by other GEMM related operators (#1730)

* Gemm Kernel Refactor part1

* Gemm Kernel Refactor common gemm pipeline part2

* [CK TILE] Refactor batched gemm to reuse GemmKernel

* [CK TILE] Refactor GemmKernel - review changes part1

* [CK TILE] Refactor GemmKernel - references fix

* [CK TILE] Refactor GemmKernel - naming changes, add problem

* [CK_TILE] Refactor GemmKernel - update tests

* [CK_TILE] Refactor GemmKernel - review changes

* [CK_TILE] Refactor GemmKernel - update test

* [CK_TILE] Refactor GemmKernel - constness fixes

* [CK_TILE] Refactor GemmKernel - update tests

* Apply Ck-tile argument parser for vectors [I/O] (#1758)

* Parser for a vector was added. Additionally, we validate the correctness of the numbers

* Remove unnecessary comments

* Review part 1

* Review part 2

* Add const to variadic lambda

* Rename C->K

* fix profiler_grouped_gemm (#1766)

---------

Co-authored-by: Harisankar Sadasivan <[email protected]>
Co-authored-by: aledudek <[email protected]>
Co-authored-by: Xiaodong Wang <[email protected]>
Co-authored-by: Po Yen Chen <[email protected]>
Co-authored-by: Mateusz Ozga <[email protected]>
6 people authored Dec 20, 2024
1 parent 05e8a4c commit da0295b
Showing 18 changed files with 489 additions and 561 deletions.
16 changes: 4 additions & 12 deletions example/ck_tile/03_gemm/gemm_basic.cpp
@@ -15,7 +15,7 @@
#include "gemm_basic.hpp"

template <typename ALayout, typename BLayout, typename CLayout>
float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
constexpr bool kPadM = false;
@@ -79,17 +79,9 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
using Kernel = ck_tile::GemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;

auto kargs = Kernel::MakeKargs(args.p_a,
args.p_b,
args.p_c,
args.M,
args.N,
args.K,
args.stride_A,
args.stride_B,
args.stride_C);

const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch);
auto kargs = Kernel::MakeKernelArgs(args);

const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
constexpr dim3 blocks = Kernel::BlockSize();

if(!Kernel::IsSupportedArgument(kargs))
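For orientation, after this refactor the host-side setup in gemm_basic.cpp reduces to roughly the sketch below. The Kernel alias is the one defined above; the stride members of ck_tile::GemmHostArgs and the error handling are assumptions, not taken from this diff.

// Illustrative sketch only; not the literal example code.
ck_tile::GemmHostArgs args;
args.a_ptr   = a_m_k_dev_buf.GetDeviceBuffer();
args.b_ptr   = b_k_n_dev_buf.GetDeviceBuffer();
args.c_ptr   = c_m_n_dev_buf.GetDeviceBuffer();
args.M       = M;
args.N       = N;
args.K       = K;
args.k_batch = 1;
// args.stride_A = stride_A; ...   // assumed member names, mirroring the removed struct

auto kargs            = Kernel::MakeKernelArgs(args); // replaces the per-field MakeKargs call
const dim3 grids      = Kernel::GridSize(args.M, args.N, args.k_batch);
constexpr dim3 blocks = Kernel::BlockSize();

if(!Kernel::IsSupportedArgument(kargs))
{
    throw std::runtime_error("unsupported GEMM arguments"); // needs <stdexcept>; handling here is an assumption
}
// ... launch with grids/blocks and the stream_config as in the example above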
16 changes: 1 addition & 15 deletions example/ck_tile/03_gemm/gemm_basic.hpp
@@ -51,20 +51,6 @@ using BDataType = Types::BDataType;
using AccDataType = Types::AccDataType;
using CDataType = Types::CDataType;

struct gemm_basic_args
{
const void* p_a;
const void* p_b;
void* p_c;
ck_tile::index_t kbatch;
ck_tile::index_t M;
ck_tile::index_t N;
ck_tile::index_t K;
ck_tile::index_t stride_A;
ck_tile::index_t stride_B;
ck_tile::index_t stride_C;
};

auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
@@ -89,4 +75,4 @@ auto create_args(int argc, char* argv[])
}

// host API
float gemm_calc(gemm_basic_args args, const ck_tile::stream_config& s);
float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s);
39 changes: 32 additions & 7 deletions example/ck_tile/03_gemm/run_gemm_example.inc
@@ -16,11 +16,11 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
int n_warmup,
int n_repeat)
{
gemm_basic_args args;
args.p_a = a_m_k_dev_buf.GetDeviceBuffer();
args.p_b = b_k_n_dev_buf.GetDeviceBuffer();
args.p_c = c_m_n_dev_buf.GetDeviceBuffer();
args.kbatch = kbatch;
ck_tile::GemmHostArgs args;
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
args.k_batch = kbatch;
args.M = M;
args.N = N;
args.K = K;
@@ -161,14 +161,39 @@ int run_gemm_example_with_layouts(int argc,
c_m_n_gpu_ref.SetZero();
c_m_n_gpu_buf_ref.SetZero();

ADataType* d_A;
BDataType* d_B;
CDataType* d_C;

ck_tile::hip_check_error(hipMalloc(&d_A, M * K * sizeof(ADataType)));
ck_tile::hip_check_error(hipMalloc(&d_B, N * K * sizeof(BDataType)));
ck_tile::hip_check_error(hipMalloc(&d_C, M * N * sizeof(CDataType)));

ck_tile::hip_check_error(hipMemcpy(d_A,
a_m_k_dev_buf.GetDeviceBuffer(),
M * K * sizeof(ADataType),
hipMemcpyHostToDevice));
ck_tile::hip_check_error(hipMemcpy(d_B,
b_k_n_dev_buf.GetDeviceBuffer(),
N * K * sizeof(BDataType),
hipMemcpyHostToDevice));

ck_tile::reference_gemm_gpu<ADataType,
BDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
CLayout>(
a_m_k_dev_buf, b_k_n_dev_buf, c_m_n_gpu_buf_ref, M, N, K, stride_A, stride_B, stride_C);
CLayout>(d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C);

ck_tile::hip_check_error(hipMemcpy(c_m_n_gpu_buf_ref.GetDeviceBuffer(),
d_C,
M * N * sizeof(CDataType),
hipMemcpyDeviceToHost));

ck_tile::hip_check_error(hipFree(d_A));
ck_tile::hip_check_error(hipFree(d_B));
ck_tile::hip_check_error(hipFree(d_C));

c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data());
pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_gpu_ref);
6 changes: 3 additions & 3 deletions example/ck_tile/16_batched_gemm/batched_gemm.cpp
@@ -16,7 +16,7 @@
#include "batched_gemm.hpp"

template <typename ALayout, typename BLayout, typename CLayout>
float batched_gemm(const batched_gemm_kargs& args, const ck_tile::stream_config& s)
float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stream_config& s)
{
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
constexpr bool kPadM = false;
@@ -79,9 +79,9 @@ float batched_gemm(const batched_gemm_kargs& args, const ck_tile::stream_config&
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
using Kernel = ck_tile::BatchedGemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;

auto kargs = Kernel::MakeKargs(args);
auto kargs = Kernel::MakeKernelArgs(args);

const dim3 grids = Kernel::GridSize(args);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.batch_count);
constexpr dim3 blocks = Kernel::BlockSize();

if(s.log_level_ > 0)
6 changes: 1 addition & 5 deletions example/ck_tile/16_batched_gemm/batched_gemm.hpp
@@ -29,10 +29,6 @@ using BDataType = Types::BDataType;
using AccDataType = Types::AccDataType;
using CDataType = Types::CDataType;

struct batched_gemm_kargs : public ck_tile::BatchedGemmHostArgs
{
};

auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
@@ -60,4 +56,4 @@ auto create_args(int argc, char* argv[])
}

// host API
float batched_gemm(batched_gemm_kargs args, const ck_tile::stream_config& s);
float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stream_config& s);
35 changes: 31 additions & 4 deletions example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc
@@ -20,7 +20,7 @@ float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
int n_warmup,
int n_repeat)
{
batched_gemm_kargs args;
ck_tile::BatchedGemmHostArgs args;
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
@@ -188,15 +188,33 @@ int run_batched_gemm_example_with_layouts(int argc,
c_m_n_gpu_ref.SetZero();
c_m_n_gpu_buf_ref.SetZero();

ADataType* d_A;
BDataType* d_B;
CDataType* d_C;

ck_tile::hip_check_error(hipMalloc(&d_A, batch_count * M * K * sizeof(ADataType)));
ck_tile::hip_check_error(hipMalloc(&d_B, batch_count * N * K * sizeof(BDataType)));
ck_tile::hip_check_error(hipMalloc(&d_C, batch_count * M * N * sizeof(CDataType)));

ck_tile::hip_check_error(hipMemcpy(d_A,
a_m_k_dev_buf.GetDeviceBuffer(),
batch_count * M * K * sizeof(ADataType),
hipMemcpyHostToDevice));

ck_tile::hip_check_error(hipMemcpy(d_B,
b_k_n_dev_buf.GetDeviceBuffer(),
batch_count * N * K * sizeof(BDataType),
hipMemcpyHostToDevice));

ck_tile::reference_batched_gemm_gpu<ADataType,
BDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
CLayout>(a_m_k_dev_buf,
b_k_n_dev_buf,
c_m_n_gpu_buf_ref,
CLayout>(d_A,
d_B,
d_C,
M,
N,
K,
@@ -208,6 +226,15 @@
batch_stride_C,
batch_count);

ck_tile::hip_check_error(hipMemcpy(c_m_n_gpu_buf_ref.GetDeviceBuffer(),
d_C,
batch_count * M * N * sizeof(CDataType),
hipMemcpyDeviceToHost));

ck_tile::hip_check_error(hipFree(d_A));
ck_tile::hip_check_error(hipFree(d_B));
ck_tile::hip_check_error(hipFree(d_C));

c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data());
pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_gpu_ref);

20 changes: 13 additions & 7 deletions example/ck_tile/17_grouped_gemm/grouped_gemm.hpp
@@ -34,13 +34,19 @@ using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs;
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("a_layout", "R", "A tensor data layout - Row by default")
.insert("b_layout", "R", "B tensor data layout - Row by default")
.insert("c_layout", "R", "C tensor data layout - Row by default")
.insert("validate", "1", "0. No validation, 1. Validation on CPU")
.insert("warmup", "10", "number of iterations before benchmark the kernel")
.insert("repeat", "100", "number of iterations to benchmark the kernel")
.insert("group_count", "16", "group count");
arg_parser.insert("Ms", "", "M dimensions - empty by default.")
.insert("Ns", "", "N dimensions - empty by default.")
.insert("Ks", "", "K dimensions - empty by default.")
.insert("stride_As", "", "Tensor A strides - it is empty by default.")
.insert("stride_Bs", "", "Tensor B strides - it is empty by default.")
.insert("stride_Cs", "", "Tensor C strides - it is empty by default.")
.insert("a_layout", "R", "A tensor data layout - Row by default.")
.insert("b_layout", "R", "B tensor data layout - Row by default.")
.insert("c_layout", "R", "C tensor data layout - Row by default.")
.insert("validate", "1", "0. No validation, 1. Validation on CPU.")
.insert("warmup", "10", "number of iterations before benchmark the kernel.")
.insert("repeat", "100", "number of iterations to benchmark the kernel.")
.insert("group_count", "16", "group count.");

bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
34 changes: 21 additions & 13 deletions example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc
@@ -53,26 +53,34 @@ int run_grouped_gemm_example_with_layouts(int argc,
return -1;
};

auto valid_input_data = [&](int group_count, const auto&... args) {
return !(args.empty() || ...) && group_count == (args.size() == ...);
};

const int group_count = arg_parser.get_int("group_count");
const int repeat = arg_parser.get_int("repeat");
const int warmup = arg_parser.get_int("warmup");

std::vector<ck_tile::index_t> Ms;
std::vector<ck_tile::index_t> Ns;
std::vector<ck_tile::index_t> Ks;
std::vector<ck_tile::index_t> stride_As;
std::vector<ck_tile::index_t> stride_Bs;
std::vector<ck_tile::index_t> stride_Cs;
std::vector<ck_tile::index_t> Ms = arg_parser.get_int_vec("Ms");
std::vector<ck_tile::index_t> Ns = arg_parser.get_int_vec("Ns");
std::vector<ck_tile::index_t> Ks = arg_parser.get_int_vec("Ks");
std::vector<ck_tile::index_t> stride_As = arg_parser.get_int_vec("stride_As");
std::vector<ck_tile::index_t> stride_Bs = arg_parser.get_int_vec("stride_Bs");
std::vector<ck_tile::index_t> stride_Cs = arg_parser.get_int_vec("stride_Cs");

for(int i = 0; i < group_count; i++)
if(!valid_input_data(group_count, Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs))
{
Ms.push_back(256 + 256 * i);
Ns.push_back(128 + 128 * i);
Ks.push_back(128 + 64 * i);
std::cout << "Please check the input data. Default values will be used." << std::endl;
for(int i = 0; i < group_count; i++)
{
Ms.push_back(256 + 256 * i);
Ns.push_back(128 + 128 * i);
Ks.push_back(128 + 64 * i);

stride_As.push_back(Ks[i]);
stride_Bs.push_back(Ks[i]);
stride_Cs.push_back(Ns[i]);
stride_As.push_back(Ks[i]);
stride_Bs.push_back(Ks[i]);
stride_Cs.push_back(Ns[i]);
}
}

std::vector<ck_tile::HostTensor<ADataType>> a_m_k_tensors;
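The validation lambda near the top of this file relies on C++17 fold expressions. As a standalone illustration (not the code from the commit), the intended check, namely that every vector is non-empty and has exactly group_count entries, can be written as:

#include <cstddef>
#include <iostream>
#include <vector>

// Standalone illustration of the input validation performed before falling back
// to default problem sizes: every vector must be non-empty and its size must
// match group_count. Names here are illustrative.
template <typename... Vecs>
bool valid_input_data(int group_count, const Vecs&... vecs)
{
    return !(vecs.empty() || ...) &&
           ((vecs.size() == static_cast<std::size_t>(group_count)) && ...);
}

int main()
{
    std::vector<int> Ms{256, 512}, Ns{128, 256}, Ks{64, 64};
    std::cout << std::boolalpha
              << valid_input_data(2, Ms, Ns, Ks) << '\n'   // true: all sizes equal 2
              << valid_input_data(3, Ms, Ns, Ks) << '\n';  // false: sizes do not equal 3
    return 0;
}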
6 changes: 3 additions & 3 deletions include/ck_tile/core/container/meta_data_buffer.hpp
@@ -30,7 +30,7 @@ struct meta_data_buffer
{
constexpr index_t size = sizeof(T);

auto tmp = bit_cast<array<std::byte, size>>(data);
auto tmp = ck_tile::bit_cast<array<std::byte, size>>(data);

for(int i = 0; i < size; i++)
{
@@ -66,7 +66,7 @@
pos++;
}

data = bit_cast<T>(tmp);
data = ck_tile::bit_cast<T>(tmp);
}

return data;
@@ -86,7 +86,7 @@
pos++;
}

auto data = bit_cast<T>(tmp);
auto data = ck_tile::bit_cast<T>(tmp);

return data;
}
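The qualification added above guards against an overload ambiguity that can appear once std::bit_cast (C++20) is visible. A small self-contained illustration of the failure mode and the fix, using a simplified stand-in for ck_tile::bit_cast, might look like this:

#include <array>
#include <bit>
#include <cstdint>

namespace ck_tile {
// Simplified stand-in for ck_tile::bit_cast; the real signature may differ (assumption).
template <typename Dst, typename Src>
constexpr Dst bit_cast(const Src& src)
{
    return std::bit_cast<Dst>(src);
}

inline std::uint64_t pack(const std::array<std::uint32_t, 2>& words)
{
    // An unqualified call "bit_cast<std::uint64_t>(words)" also finds std::bit_cast
    // through argument-dependent lookup (std::array makes std an associated namespace),
    // and both candidates are viable, so the call is reported as ambiguous.
    return ck_tile::bit_cast<std::uint64_t>(words); // explicit qualification resolves it
}
} // namespace ck_tile

int main()
{
    const std::array<std::uint32_t, 2> words{1u, 2u};
    return ck_tile::pack(words) == 0 ? 1 : 0;
}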
46 changes: 44 additions & 2 deletions include/ck_tile/host/arg_parser.hpp
@@ -15,11 +15,14 @@

namespace ck_tile {
/*
* a host side utility, arg parser for
* -[key0]=[value0] -[key1]=[value1] ...
* a host side utility, arg parser for, either
* -[key0] = [value0, value1, value2]
* or
* -[key0]=[value0] -[key1]=[value1] ...
*/
class ArgParser
{

public:
class Arg
{
@@ -187,6 +190,45 @@ class ArgParser
return value;
}

std::vector<std::string> get_string_vec(const std::string& name,
const std::string& delimiter = ",") const
{
if(get_str(name).empty())
{
return {};
}
std::string s = get_str(name);
std::vector<std::string> tokens;
size_t pos = 0;
std::string token;
while((pos = s.find(delimiter)) != std::string::npos)
{
token = s.substr(0, pos);
tokens.push_back(token);
s.erase(0, pos + delimiter.length());
}
tokens.push_back(s);

return tokens;
}

std::vector<int> get_int_vec(const std::string& name, const std::string& delimiter = ",") const
{
if(get_str(name).empty())
{
return {};
}
const std::vector<std::string> args = get_string_vec(name, delimiter);
std::vector<int> tokens;
tokens.reserve(static_cast<int>(args.size()));
for(const std::string& token : args)
{
int value = atoi(token.c_str());
tokens.push_back(value);
}
return tokens;
}

private:
std::unordered_map<std::string, Arg> input_map;
std::vector<std::string> keys;
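A small, self-contained usage sketch of the new vector getters (the driver program and flag wiring below are illustrative, not part of the commit):

// Illustrative usage of ArgParser::get_int_vec with a comma-separated flag value.
#include "ck_tile/host/arg_parser.hpp"

#include <iostream>
#include <vector>

int main(int argc, char* argv[])
{
    ck_tile::ArgParser parser;
    parser.insert("Ms", "", "comma separated M dimensions - empty by default.");

    if(!parser.parse(argc, argv))
        return -1;

    // e.g. "-Ms=256,512,1024" yields {256, 512, 1024};
    // an empty or omitted value yields an empty vector.
    const std::vector<int> Ms = parser.get_int_vec("Ms");
    for(const int m : Ms)
        std::cout << m << '\n';
    return 0;
}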
