Apply universal gemm to bwd_weight_cshuffle operator #1658
base: develop
Conversation
Force-pushed from 118fd3f to 932e6a0
I will continue the review later.
example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_bf16.cpp (resolved, outdated)
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp (resolved, outdated)
Hi @mozga-amd, please remove merged-groups support for this kernel, then update the instances. After that we can review it again.
Force-pushed from 932e6a0 to 860433e
Nothing to review for docs. If anything needs to be documented, let me know.
I added a few comments. Please focus first on the API changes. We have two options, but please take some local measurements first:
- If gridwise gemm v3 is better in every case, can we move BlkGemmPipeSched and BlkGemmPipelineVer to the end of the template parameter list and keep the API consistent? (See the sketch after this list.)
- If gridwise gemm v3 is better, but not in every case, can we restore the previous implementation and copy your new DeviceGroupedConvBwdWeight_Xdl_CShuffle as DeviceGroupedConvBwdWeight_Xdl_CShuffleV3?
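For the first option, here is a minimal sketch of the trailing-defaults idea. This is not the real CK declaration: the struct name, the abbreviated parameter list, and the local enum stand-ins for `ck::BlockGemmPipelineScheduler` / `ck::BlockGemmPipelineVersion` are all illustrative.

```cpp
namespace ck_sketch {

// Local stand-ins for the ck:: pipeline enums, so the sketch is self-contained.
enum class BlockGemmPipelineScheduler { Intrawave, Interwave };
enum class BlockGemmPipelineVersion { v1, v2, v3 };

template <typename ADataType,
          typename BDataType,
          typename CDataType,
          // ... all existing tuning parameters stay here, unchanged and in order ...
          BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
          BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1>
struct DeviceGroupedConvBwdWeight_Xdl_CShuffle_Sketch
{
};

// Existing call sites compile unchanged, because the trailing parameters default:
using OldStyleInstance = DeviceGroupedConvBwdWeight_Xdl_CShuffle_Sketch<float, float, float>;

// New instances opt into the v3 pipeline explicitly:
using V3Instance = DeviceGroupedConvBwdWeight_Xdl_CShuffle_Sketch<
    float, float, float,
    BlockGemmPipelineScheduler::Intrawave,
    BlockGemmPipelineVersion::v3>;

} // namespace ck_sketch
```

Because the new parameters default to the current scheduler and the v1 pipeline, every existing instantiation keeps compiling, while new instances can select v3 explicitly.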
example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_bf16.cpp (resolved, outdated)
```cpp
ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer
ComputeTypeA,                     // ComputeTypeA
ComputeTypeB>;                    // ComputeTypeB
// clang-format on
```
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp (resolved, outdated)
```cpp
const index_t K,
const std::array<index_t, NDimSpatial + 3>& output_strides)
{
    const index_t BatchStride = output_strides[0];
```
For out, in, and wei: can we add a condition like `if (NumGroupsToMerge == 1) { /* create the descriptor in the typical way, as in v1 */ }`? Or, if possible, use transform_v1 in device_op? I think it could impact performance.
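A minimal sketch of the guard being suggested, under assumed names: `MakeOutDescriptor_v1` and `MakeOutDescriptor_Merged` are hypothetical stand-ins for the actual v1-style and merged-group descriptor builders, and the signature is simplified from the snippet above. Since `NumGroupsToMerge` is a compile-time parameter, `if constexpr` makes the fast path free at runtime.

```cpp
#include <array>
#include <cstdint>

using index_t = std::int32_t; // stand-in for ck's index_t

// Hypothetical helpers standing in for the v1 and merged-group descriptor paths.
template <index_t NDimSpatial>
void MakeOutDescriptor_v1(const std::array<index_t, NDimSpatial + 3>& /*output_strides*/) {}

template <index_t NDimSpatial>
void MakeOutDescriptor_Merged(const std::array<index_t, NDimSpatial + 3>& /*output_strides*/) {}

// Suggested guard: when no groups are merged, create the descriptor the
// typical way (as in v1) and only pay for the merged-group transform otherwise.
template <index_t NDimSpatial, index_t NumGroupsToMerge>
void MakeOutDescriptor(const std::array<index_t, NDimSpatial + 3>& output_strides)
{
    if constexpr(NumGroupsToMerge == 1)
    {
        MakeOutDescriptor_v1<NDimSpatial>(output_strides);
    }
    else
    {
        MakeOutDescriptor_Merged<NDimSpatial>(output_strides);
    }
}

// Usage sketch: 2D convolution (NDimSpatial = 2, so 5 stride entries), no merged groups.
int main()
{
    const std::array<index_t, 5> output_strides{1000, 100, 10, 5, 1}; // illustrative values
    MakeOutDescriptor<2, 1>(output_strides); // compiles down to the v1-style path only
}
```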