[CK TILE] GEMM and Batched GEMM SplitK support #1724
Good work! However, I have a few things for you to reconsider.
memory_operation_enum out_memory_data_op = memory_operation_enum::set>
CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp, OAccTile& o_acc_tile)
What if you added a third function parameter, memory_operation_enum o_mem_data_op = out_memory_data_op? Then you wouldn't have to pass all the template parameters; you could just pass the memory op when you need one other than the default.
I don't think that is possible. If the enum is passed as a function argument, it is not a constant expression, so you cannot compare it in if constexpr.
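To make the constraint above concrete, here is a minimal sketch of one possible middle ground: the enum is still supplied as a defaulted function argument, but wrapped in a tag type so its value stays visible at compile time. The names `mem_op`, `mem_op_tag`, and `write_value` are hypothetical stand-ins, not part of ck_tile, and the accumulate branch only imitates an atomic add.

```cpp
#include <cassert>
#include <type_traits>

// Hypothetical stand-in for ck_tile::memory_operation_enum.
enum class mem_op { set, atomic_add };

// Tag type that carries the enum value in its type, so the value
// remains usable in constant expressions.
template <mem_op Op>
using mem_op_tag = std::integral_constant<mem_op, Op>;

// The memory op arrives as a defaulted function argument. Because the value
// is encoded in the argument's type, `if constexpr` can still branch on it.
template <mem_op Op = mem_op::set>
int write_value(int& dst, int value, mem_op_tag<Op> = {})
{
    if constexpr(Op == mem_op::set)
    {
        dst = value; // overwrite the destination
    }
    else
    {
        dst += value; // accumulate (stand-in for an atomic add)
    }
    return dst;
}

// By contrast, a plain runtime argument does not compile in `if constexpr`:
//   void write_value(int& dst, int value, mem_op op)
//   { if constexpr(op == mem_op::set) { /* ... */ } } // error: not a constant expression
```

A caller would write `write_value(x, 5)` for the default and `write_value(x, 3, mem_op_tag<mem_op::atomic_add>{})` to override it, without spelling out any other template parameters.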
@bartekxk You're right. What about having a store_tile API that takes the memory operation enum as its last parameter? For set it would do a store, while for atomic_add and others it would do an update.
We can do that. At the moment there are separate store_tile and update_tile functions, so this is just a concept for now.
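The concept discussed in this thread could look roughly like the sketch below: a single entry point that selects between a store and an update at compile time. Everything here is an assumption for illustration. `store_values`, `update_values`, and `store_with_op` are hypothetical names operating on plain vectors, not the real store_tile and update_tile, and the memory op is taken as a template parameter so that if constexpr can be used.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for ck_tile::memory_operation_enum.
enum class mem_op { set, atomic_add };

// Stand-in for a plain store: overwrite the destination (sizes assumed equal).
void store_values(std::vector<float>& dst, const std::vector<float>& src)
{
    for(std::size_t i = 0; i < dst.size(); ++i)
        dst[i] = src[i];
}

// Stand-in for an update: accumulate into the destination (sizes assumed equal).
void update_values(std::vector<float>& dst, const std::vector<float>& src)
{
    for(std::size_t i = 0; i < dst.size(); ++i)
        dst[i] += src[i];
}

// Single entry point: `set` maps to a store, any other op maps to an update.
template <mem_op Op = mem_op::set>
void store_with_op(std::vector<float>& dst, const std::vector<float>& src)
{
    if constexpr(Op == mem_op::set)
        store_values(dst, src);
    else
        update_values(dst, src);
}
```

With this shape, an epilogue would call `store_with_op<Op>(...)` once instead of choosing between two functions at each call site.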
@@ -158,12 +167,26 @@ struct CShuffleEpilogue
// Store the tile data to the permuted location
if constexpr(kPadM || kPadN)
@bartekxk By the way, do we really need this check here? The *_raw version of the tile API just does things using assembly... I'm not sure we really need it here. The plain tile API should work as well, regardless of padding.
Are you sure we need it here? It looks like it could improve performance, as for example in #1752.
@@ -54,8 +54,7 @@ using CDataType = Types::CDataType;
 auto create_args(int argc, char* argv[])
 {
 ck_tile::ArgParser arg_parser;
-arg_parser.insert("b", "1", "batch size")
-.insert("m", "3840", "m dimension")
+arg_parser.insert("m", "3840", "m dimension")
Do we not support batch (b) in this example?
This is a simple GEMM, not a batched one.
@@ -78,7 +78,9 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
#endif
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>>;

const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(args.K);
const ck_tile::index_t k_grain = args.k_batch * K_Tile;
Is it true that if we set split_k=1 from the command-line argument, the kernel will run only K_Tile for each kernel's unroll? And what if we want to disable split-K from the command-line arguments: is that done through split_k=0, or is it not considered?
I think that, in the case of split_k=1, this is analogous to just rounding up the K dimension.
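The rounding described above can be checked with a small sketch. The diff excerpt only shows `k_grain = args.k_batch * K_Tile`; the ceiling division that follows is an assumption about how the per-split K extent might be derived, and `k_per_split` is a hypothetical helper name using plain ints rather than ck_tile::index_t.

```cpp
#include <cassert>

// Number of K elements each split would process, assuming K is rounded up
// to a multiple of k_grain = k_batch * k_tile (assumption, not shown in the diff).
constexpr int k_per_split(int K, int k_batch, int k_tile)
{
    const int k_grain = k_batch * k_tile;
    // Ceiling division by k_grain, then scale back by the tile size.
    return (K + k_grain - 1) / k_grain * k_tile;
}
```

Under this assumption, k_batch = 1 does not limit the kernel to a single K_Tile: with K = 4100 and K_Tile = 32 the single split covers 4128 elements, which is simply K rounded up to the next multiple of K_Tile. With k_batch = 4 and K = 4096, each split would cover 1024 elements.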
LGTM