Skip to content

Commit

Permalink
[FEA] Support for half-float mixed precision in brute-force (#2382)
Browse files Browse the repository at this point in the history
- distance supports half-float
- SDDMM support half-float
- gemm supports multi-type composition
- transpose & copy support half
- random supports half

Authors:
  - rhdong (https://github.com/rhdong)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #2382
  • Loading branch information
rhdong authored Aug 22, 2024
1 parent 32f3703 commit db07998
Show file tree
Hide file tree
Showing 29 changed files with 1,050 additions and 441 deletions.
76 changes: 45 additions & 31 deletions cpp/include/raft/core/detail/copy.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include <raft/core/resource/cublas_handle.hpp>
#include <raft/linalg/detail/cublas_wrappers.hpp>
#ifdef __CUDACC__
#include <raft/linalg/transpose.cuh>
#include <raft/util/cuda_dev_essentials.cuh>
#endif
#endif
Expand Down Expand Up @@ -449,38 +450,51 @@ mdspan_copyable_t<DstType, SrcType> copy(resources const& res, DstType&& dst, Sr
#endif
} else if constexpr (config::can_use_cublas) {
#ifndef RAFT_DISABLE_CUDA
auto constexpr const alpha = typename std::remove_reference_t<DstType>::value_type{1};
auto constexpr const beta = typename std::remove_reference_t<DstType>::value_type{0};
if constexpr (std::is_same_v<typename config::dst_layout_type, layout_c_contiguous>) {
CUBLAS_TRY(linalg::detail::cublasgeam(resource::get_cublas_handle(res),
CUBLAS_OP_T,
CUBLAS_OP_N,
dst.extent(1),
dst.extent(0),
&alpha,
src.data_handle(),
src.extent(0),
&beta,
dst.data_handle(),
dst.extent(1),
dst.data_handle(),
dst.extent(1),
resource::get_cuda_stream(res)));
if constexpr (!((std::is_same_v<typename std::remove_reference_t<DstType>::value_type, half>)&&(
std::is_same_v<typename std::remove_reference_t<SrcType>::value_type, half>))) {
auto constexpr const alpha = typename std::remove_reference_t<DstType>::value_type{1};
auto constexpr const beta = typename std::remove_reference_t<DstType>::value_type{0};
if constexpr (std::is_same_v<typename config::dst_layout_type, layout_c_contiguous>) {
CUBLAS_TRY(linalg::detail::cublasgeam(resource::get_cublas_handle(res),
CUBLAS_OP_T,
CUBLAS_OP_N,
dst.extent(1),
dst.extent(0),
&alpha,
src.data_handle(),
src.extent(0),
&beta,
dst.data_handle(),
dst.extent(1),
dst.data_handle(),
dst.extent(1),
resource::get_cuda_stream(res)));
} else {
CUBLAS_TRY(linalg::detail::cublasgeam(resource::get_cublas_handle(res),
CUBLAS_OP_T,
CUBLAS_OP_N,
dst.extent(0),
dst.extent(1),
&alpha,
src.data_handle(),
src.extent(1),
&beta,
dst.data_handle(),
dst.extent(0),
dst.data_handle(),
dst.extent(0),
resource::get_cuda_stream(res)));
}
} else {
CUBLAS_TRY(linalg::detail::cublasgeam(resource::get_cublas_handle(res),
CUBLAS_OP_T,
CUBLAS_OP_N,
dst.extent(0),
dst.extent(1),
&alpha,
src.data_handle(),
src.extent(1),
&beta,
dst.data_handle(),
dst.extent(0),
dst.data_handle(),
dst.extent(0),
resource::get_cuda_stream(res)));
#ifdef __CUDACC__
raft::linalg::transpose(res, dst, src);
#else
// Should never actually reach this because of enable_ifs. Included for
// safety.
RAFT_FAIL(
"raft::copy called in a way that requires custom kernel. Please use "
"raft/core/copy.cuh and include the header in a .cu file");
#endif
}
#else
// Not possible to reach this due to enable_ifs. Included for safety.
Expand Down
14 changes: 13 additions & 1 deletion cpp/include/raft/core/math.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,13 @@ template <typename T>
RAFT_INLINE_FUNCTION auto asin(T x)
{
#ifdef __CUDA_ARCH__
return ::asin(x);
if constexpr (std::is_same<T, __half>::value) {
float x_float = __half2float(x);
float result_float = ::asin(x_float);
return __float2half(result_float);
} else {
return ::asin(x);
}
#else
return std::asin(x);
#endif
Expand Down Expand Up @@ -337,6 +343,12 @@ RAFT_INLINE_FUNCTION auto max(const T1& x, const T2& y)
((std::is_same_v<T1, float> || std::is_same_v<T1, double>)&&(
std::is_same_v<T2, float> || std::is_same_v<T2, double>))) {
return ::max(x, y);
} else if constexpr (std::is_same_v<T1, float> && std::is_same_v<T2, __half>) {
const float f_y = __half2float(y);
return (x < f_y) ? f_y : x;
} else if constexpr (std::is_same_v<T1, __half> && std::is_same_v<T2, float>) {
const float f_x = __half2float(x);
return (f_x < y) ? y : f_x;
}
// Else, check that the types are the same and provide a generic implementation
else {
Expand Down
18 changes: 17 additions & 1 deletion cpp/include/raft/core/operators.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
#include <raft/core/detail/macros.hpp>
#include <raft/core/math.hpp>

#include <cuda_fp16.h>

#include <algorithm>
#include <cmath>
#include <tuple>
Expand Down Expand Up @@ -104,13 +106,27 @@ struct sq_op {
{
return in * in;
}

template <typename... UnusedArgs>
constexpr RAFT_INLINE_FUNCTION auto operator()(const half& in, UnusedArgs...) const
{
  // Square a half input: widen to float first via __half2float, so the
  // result of this overload is a float rather than a half.
  const float in_f = __half2float(in);
  return in_f * in_f;
}
};

struct add_op {
  // Adds two values. Any half operand is first widened to float with
  // __half2float, so expressions involving half accumulate in float;
  // all other type combinations use plain operator+.
  template <typename T1, typename T2>
  constexpr RAFT_INLINE_FUNCTION auto operator()(const T1& a, const T2& b) const
  {
    constexpr bool a_is_half = std::is_same_v<T1, half>;
    constexpr bool b_is_half = std::is_same_v<T2, half>;
    if constexpr (a_is_half && b_is_half) {
      return __half2float(a) + __half2float(b);
    } else if constexpr (a_is_half) {
      return __half2float(a) + b;
    } else if constexpr (b_is_half) {
      return a + __half2float(b);
    } else {
      return a + b;
    }
  }
};

Expand Down
4 changes: 2 additions & 2 deletions cpp/include/raft/distance/detail/masked_distance_base.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -266,7 +266,7 @@ struct MaskedDistances : public BaseClass {
for (int i = 0; i < P::AccRowsPerTh; ++i) {
#pragma unroll
for (int j = 0; j < P::AccColsPerTh; ++j) {
acc[i][j] = BaseClass::Zero;
acc[i][j] = BaseClass::Zero();
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions cpp/include/raft/distance/detail/pairwise_distance_base.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
* Copyright (c) 2022-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -200,7 +200,7 @@ struct PairwiseDistances : public BaseClass {
for (int i = 0; i < P::AccRowsPerTh; ++i) {
#pragma unroll
for (int j = 0; j < P::AccColsPerTh; ++j) {
acc[i][j] = BaseClass::Zero;
acc[i][j] = BaseClass::Zero();
}
}
}
Expand Down
14 changes: 13 additions & 1 deletion cpp/include/raft/linalg/contractions.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
* Copyright (c) 2022-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -164,6 +164,12 @@ struct Policy4x4<double, _veclen> {
typedef KernelPolicy<double, _veclen, 16, 4, 4, 16, 16> Policy;
typedef ColKernelPolicy<double, _veclen, 16, 4, 4, 16, 16> ColPolicy;
};

template <int _veclen>
struct Policy4x4<half, _veclen> {
  // Kernel-policy specialization for half-precision data, parallel to the
  // float/double specializations above.
  using Policy    = KernelPolicy<half, _veclen, 64, 4, 4, 16, 16>;
  using ColPolicy = ColKernelPolicy<half, _veclen, 64, 4, 4, 16, 16>;
};
/** @} */

/**
Expand Down Expand Up @@ -204,6 +210,12 @@ struct Policy2x8<double, _veclen> {
// this is not used just for keeping compiler happy.
typedef KernelPolicy<double, _veclen, 32, 1, 2, 8, 32> Policy;
};

template <int _veclen>
struct Policy2x8<half, _veclen> {
  // Kernel-policy specialization for half-precision data, parallel to the
  // float/double specializations above.
  using Policy = KernelPolicy<half, _veclen, 16, 2, 8, 8, 32>;
};

/** @} */

/**
Expand Down
14 changes: 8 additions & 6 deletions cpp/include/raft/linalg/detail/contractions.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
* Copyright (c) 2022-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -72,7 +72,9 @@ struct Contractions_NT {
/** block of Y data loaded from global mem after `ldgXY()` */
DataT ldgDataY[P::LdgPerThY][P::Veclen];

static constexpr DataT Zero = (DataT)0;
// Zero value of DataT, used to pad loads that fall outside the valid range.
// Exposed as a constexpr function rather than the former static constexpr
// data member — NOTE(review): presumably because a static constexpr member
// of type half is problematic in device code; confirm the motivation.
static constexpr DataT Zero() { return DataT{0}; }

public:
/**
Expand Down Expand Up @@ -197,7 +199,7 @@ struct Contractions_NT {
} else {
#pragma unroll
for (int j = 0; j < P::Veclen; ++j) {
ldgDataX[i][j] = Zero;
ldgDataX[i][j] = Zero();
}
}
}
Expand All @@ -211,7 +213,7 @@ struct Contractions_NT {
} else {
#pragma unroll
for (int j = 0; j < P::Veclen; ++j) {
ldgDataX[i][j] = Zero;
ldgDataX[i][j] = Zero();
}
}
}
Expand All @@ -235,7 +237,7 @@ struct Contractions_NT {
} else {
#pragma unroll
for (int j = 0; j < P::Veclen; ++j) {
ldgDataY[i][j] = Zero;
ldgDataY[i][j] = Zero();
}
}
}
Expand All @@ -249,7 +251,7 @@ struct Contractions_NT {
} else {
#pragma unroll
for (int j = 0; j < P::Veclen; ++j) {
ldgDataY[i][j] = Zero;
ldgDataY[i][j] = Zero();
}
}
}
Expand Down
Loading

0 comments on commit db07998

Please sign in to comment.