Skip to content

Commit

Permalink
Disable building DPP kernels by default (#1804)
Browse files Browse the repository at this point in the history
* Disable building DPP kernels by default

* Disable building dpp instances, examples, or tests if DPP_KERNELS is not set

* Add new DPP_KERNELS flag to readme
  • Loading branch information
darren-amd authored Jan 8, 2025
1 parent ad697c7 commit 26b3829
Show file tree
Hide file tree
Showing 9 changed files with 127 additions and 40 deletions.
4 changes: 4 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,10 @@ if(DL_KERNELS)
add_definitions(-DDL_KERNELS)
set(CK_ENABLE_DL_KERNELS "ON")
endif()
if(DPP_KERNELS)
add_definitions(-DDPP_KERNELS)
set(CK_ENABLE_DPP_KERNELS "ON")
endif()
option(CK_USE_CODEGEN "Enable codegen library" OFF)
if(CK_USE_CODEGEN)
add_definitions(-DCK_USE_CODEGEN)
Expand Down
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,9 @@ Additional cmake flags can be used to significantly speed-up the build:
`batched_gemm_multi_d_dl`. These instances are useful on architectures like the NAVI2x, as most
other platforms have faster instances, such as `xdl` or `wmma`, available.

* `DPP_KERNELS` (default is OFF) must be set to ON in order to build instances, such as `gemm_dpp`.
These instances are useful on architectures like the NAVI2x, as most other platforms have faster instances, such as `xdl` or `wmma`, available.

* `CK_USE_FP8_ON_UNSUPPORTED_ARCH` (default is OFF) must be set to ON in order to build instances,
such as `gemm_universal`, `gemm_universal_streamk` and `gemm_multiply_multiply` for fp8 data type for GPU targets which do not have native support for fp8 data type, such as gfx908 or gfx90a. These instances are useful on
architectures like the MI100/MI200 for the functional support only.
Expand Down
4 changes: 2 additions & 2 deletions example/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,9 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
list(REMOVE_ITEM FILE_NAME "${source}")
endif()
endforeach()
#Do not build any DPP examples if DL_KERNELS not set
#Do not build any DPP examples if DPP_KERNELS not set
foreach(source IN LISTS FILE_NAME)
if(NOT DEFINED DL_KERNELS AND source MATCHES "_dpp")
if(NOT DEFINED DPP_KERNELS AND source MATCHES "_dpp")
message("removing dpp example ${source} ")
list(REMOVE_ITEM FILE_NAME "${source}")
endif()
Expand Down
4 changes: 4 additions & 0 deletions include/ck/config.h.in
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,10 @@
#cmakedefine CK_ENABLE_DL_KERNELS @CK_ENABLE_DL_KERNELS@
#endif

#ifndef CK_ENABLE_DPP_KERNELS
#cmakedefine CK_ENABLE_DPP_KERNELS @CK_ENABLE_DPP_KERNELS@
#endif

//
// CK kernels which support XDL (MI series)
//
Expand Down
44 changes: 36 additions & 8 deletions library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
#ifdef DL_KERNELS
#include "gemm_dl.inc"
#endif
#ifdef DPP_KERNELS
#include "gemm_dpp.inc"
#endif
#ifdef CK_USE_WMMA
#include "gemm_wmma.inc"
#endif
Expand Down Expand Up @@ -92,32 +95,24 @@ struct DeviceOperationInstanceFactory<
{
add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
add_device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instances(op_ptrs);
add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_irregular_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
add_device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instances(op_ptrs);
add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_irregular_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(op_ptrs);
add_device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instances(op_ptrs);
add_device_gemm_dpp_f16_f16_f16_km_kn_mn_instances(op_ptrs);
add_device_gemm_dpp_f16_f16_f16_km_kn_mn_irregular_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(op_ptrs);
add_device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instances(op_ptrs);
add_device_gemm_dpp_f16_f16_f16_km_nk_mn_instances(op_ptrs);
add_device_gemm_dpp_f16_f16_f16_km_nk_mn_irregular_instances(op_ptrs);
}
}
#endif
Expand Down Expand Up @@ -153,6 +148,39 @@ struct DeviceOperationInstanceFactory<
#endif
#endif // DL_KERNELS

#ifdef DPP_KERNELS
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<CDataType, half_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_irregular_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_irregular_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_dpp_f16_f16_f16_km_kn_mn_instances(op_ptrs);
add_device_gemm_dpp_f16_f16_f16_km_kn_mn_irregular_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_dpp_f16_f16_f16_km_nk_mn_instances(op_ptrs);
add_device_gemm_dpp_f16_f16_f16_km_nk_mn_irregular_instances(op_ptrs);
}
}
#endif
#endif // DPP_KERNELS

#ifdef CK_USE_WMMA
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,6 @@ void add_device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instances(
DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);

void add_device_gemm_dpp_f16_f16_f16_km_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);

void add_device_gemm_dpp_f16_f16_f16_km_kn_mn_irregular_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);

void add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
Expand All @@ -48,16 +38,6 @@ void add_device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instances(
DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);

void add_device_gemm_dpp_f16_f16_f16_km_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);

void add_device_gemm_dpp_f16_f16_f16_km_nk_mn_irregular_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);

void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
Expand All @@ -68,16 +48,6 @@ void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instances(
DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);

void add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);

void add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_irregular_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);

void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <memory>
#include <vector>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

#if defined(CK_ENABLE_FP16)
void add_device_gemm_dpp_f16_f16_f16_km_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);

void add_device_gemm_dpp_f16_f16_f16_km_kn_mn_irregular_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);

void add_device_gemm_dpp_f16_f16_f16_km_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);

void add_device_gemm_dpp_f16_f16_f16_km_nk_mn_irregular_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);

void add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);

void add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_irregular_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);

void add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);

void add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_irregular_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
#endif

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
7 changes: 7 additions & 0 deletions library/src/tensor_operation_instance/gpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,13 @@ function(add_instance_library INSTANCE_NAME)

set(INST_TARGETS ${SUPPORTED_GPU_TARGETS})

# Do not build DPP instances if DPP_KERNELS macro is not set
foreach(source IN LISTS ARGN)
if(NOT DEFINED DPP_KERNELS AND source MATCHES "_dpp")
message("removing dpp instance ${source} ")
list(REMOVE_ITEM ARGN "${source}")
endif()
endforeach()
# Do not build DL instances if DL_KERNELS macro is not set
foreach(source IN LISTS ARGN)
if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
Expand Down
6 changes: 6 additions & 0 deletions test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@ function(add_test_executable TEST_NAME)

set(TEST_TARGETS ${SUPPORTED_GPU_TARGETS})

foreach(source IN LISTS ARGN)
if(NOT DEFINED DPP_KERNELS AND source MATCHES "_dpp")
message("removing dpp test ${source} ")
list(REMOVE_ITEM ARGN "${source}")
endif()
endforeach()
foreach(source IN LISTS ARGN)
if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
message("removing dl test ${source} ")
Expand Down

0 comments on commit 26b3829

Please sign in to comment.