From 687318637b7cee552b26036fddce7d0edcc5daf5 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Fri, 14 Oct 2022 19:26:49 -0400 Subject: [PATCH 01/35] MOving gram matrix over to raft --- .../distance/detail/kernels/gram_matrix.cuh | 218 ++++++++++ .../detail/kernels/kernel_factory.cuh | 48 +++ .../detail/kernels/kernel_matrices.cuh | 376 ++++++++++++++++ cpp/include/raft/distance/distance_types.hpp | 21 + cpp/include/raft/distance/kernels.cuh | 32 ++ cpp/include/raft/linalg/init.cuh | 13 + cpp/include/raft/util/cache.cuh | 406 ++++++++++++++++++ 7 files changed, 1114 insertions(+) create mode 100644 cpp/include/raft/distance/detail/kernels/gram_matrix.cuh create mode 100644 cpp/include/raft/distance/detail/kernels/kernel_factory.cuh create mode 100644 cpp/include/raft/distance/detail/kernels/kernel_matrices.cuh create mode 100644 cpp/include/raft/distance/kernels.cuh create mode 100644 cpp/include/raft/util/cache.cuh diff --git a/cpp/include/raft/distance/detail/kernels/gram_matrix.cuh b/cpp/include/raft/distance/detail/kernels/gram_matrix.cuh new file mode 100644 index 0000000000..54ac490ca4 --- /dev/null +++ b/cpp/include/raft/distance/detail/kernels/gram_matrix.cuh @@ -0,0 +1,218 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +#include +#include + +namespace raft::distance::kernels::detail { + +/** + * Base class for general Gram matrices + * A Gram matrix is the Hermitian matrix of inner probucts G_ik = + * Here, the inner product is evaluated for all elements from vectors sets X1, + * and X2. + * + * To be more precise, on exit the output buffer will store: + * - if is_row_major == true: out[j+k*n1] = , + * - if is_row_major == false: out[j*n2 + k] = , + * where x1_j is the j-th vector from the x1 set and x2_k is the k-th vector + * from the x2 set. + */ +template +class GramMatrixBase { + cublasHandle_t cublas_handle; + + public: + GramMatrixBase(cublasHandle_t cublas_handle) : cublas_handle(cublas_handle){}; + + virtual ~GramMatrixBase(){}; + + /** Convenience function to evaluate the Gram matrix for two vector sets. + * + * @param [in] x1 device array of vectors, size [n1*n_cols] + * @param [in] n1 number vectors in x1 + * @param [in] n_cols number of columns (features) in x1 and x2 + * @param [in] x2 device array of vectors, size [n2*n_cols] + * @param [in] n2 number vectors in x2 + * @param [out] out device buffer to store the Gram matrix, size [n1*n2] + * @param [in] is_row_major whether the input and output matrices are in row + * major format + * @param [in] stream cuda stream + * @param ld1 leading dimension of x1 + * @param ld2 leading dimension of x2 + * @param ld_out leading dimension of out + */ + virtual void operator()(const math_t* x1, + int n1, + int n_cols, + const math_t* x2, + int n2, + math_t* out, + bool is_row_major, + cudaStream_t stream, + int ld1 = 0, + int ld2 = 0, + int ld_out = 0) + { + if (ld1 <= 0) { ld1 = is_row_major ? 
n_cols : n1; } + if (ld2 <= 0) { ld2 = is_row_major ? n_cols : n2; } + if (ld_out <= 0) { ld_out = is_row_major ? n2 : n1; } + evaluate(x1, n1, n_cols, x2, n2, out, is_row_major, stream, ld1, ld2, ld_out); + } + + /** Evaluate the Gram matrix for two vector sets using simple dot product. + * + * @param [in] x1 device array of vectors, size [n1*n_cols] + * @param [in] n1 number vectors in x1 + * @param [in] n_cols number of columns (features) in x1 and x2 + * @param [in] x2 device array of vectors, size [n2*n_cols] + * @param [in] n2 number vectors in x2 + * @param [out] out device buffer to store the Gram matrix, size [n1*n2] + * @param [in] is_row_major whether the input and output matrices are in row + * major format + * @param [in] stream cuda stream + * @param ld1 leading dimension of x1 (usually it is n1) + * @param ld2 leading dimension of x2 (usually it is n2) + * @param ld_out leading dimension of out (usually it is n1) + */ + virtual void evaluate(const math_t* x1, + int n1, + int n_cols, + const math_t* x2, + int n2, + math_t* out, + bool is_row_major, + cudaStream_t stream, + int ld1, + int ld2, + int ld_out) + { + linear(x1, n1, n_cols, x2, n2, out, is_row_major, stream, ld1, ld2, ld_out); + } + + // private: + // The following methods should be private, they are kept public to avoid: + // "error: The enclosing parent function ("distance") for an extended + // __device__ lambda cannot have private or protected access within its class" + + /** Calculates the Gram matrix using simple dot product between vector sets. + * + * out = x1 * x2 + * + * Can be used as a building block for more complex kernel functions. + * + * @param [in] x1 device array of vectors, size [n1*n_cols] + * @param [in] n1 number vectors in x1 + * @param [in] n_cols number of colums (features) in x1 and x2 + * @param [in] x2 device array of vectors, size [n2*n_cols] + * @param [in] n2 number vectors in x2 + * @param [out] out device buffer to store the Gram matrix, size [n1*n2] + * @param [in] is_row_major whether the input and output matrices are in row + * major format + * @param [in] stream cuda stream + * @param ld1 leading dimension of x1 + * @param ld2 leading dimension of x2 + * @param ld_out leading dimension of out + */ + void linear(const math_t* x1, + int n1, + int n_cols, + const math_t* x2, + int n2, + math_t* out, + bool is_row_major, + cudaStream_t stream, + int ld1, + int ld2, + int ld_out) + { + math_t alpha = 1.0; + math_t beta = 0.0; + if (is_row_major) { + // #TODO: Call from public API when ready + RAFT_CUBLAS_TRY(raft::linalg::detail::cublasgemm(cublas_handle, + CUBLAS_OP_T, + CUBLAS_OP_N, + n2, + n1, + n_cols, + &alpha, + x2, + ld2, + x1, + ld1, + &beta, + out, + ld_out, + stream)); + } else { + // #TODO: Call from public API when ready + RAFT_CUBLAS_TRY(raft::linalg::detail::cublasgemm(cublas_handle, + CUBLAS_OP_N, + CUBLAS_OP_T, + n1, + n2, + n_cols, + &alpha, + x1, + ld1, + x2, + ld2, + &beta, + out, + ld_out, + stream)); + } + } + + /** Calculates the Gram matrix using Euclidean distance. + * + * Can be used as a building block for more complex kernel functions. 
+ * + * @param [in] x1 device array of vectors, size [n1*n_cols] + * @param [in] n1 number vectors in x1 + * @param [in] n_cols number of columns (features) in x1 and x2 + * @param [in] x2 device array of vectors, size [n2*n_cols] + * @param [in] n2 number vectors in x2 + * @param [out] out device buffer to store the Gram matrix, size [n1*n2] + * @param [in] is_row_major whether the input and output matrices are in row + * major format + * @param [in] stream cuda stream + * @param ld1 leading dimension of x1 + * @param ld2 leading dimension of x2 + * @param ld_out leading dimension of out + */ + virtual void distance(const math_t* x1, + int n1, + int n_cols, + const math_t* x2, + int n2, + math_t* out, + bool is_row_major, + cudaStream_t stream, + int ld1, + int ld2, + int ld_out) + { + raft::distance::distance( + x1, x2, out, n1, n2, n_cols, stream, is_row_major); + } +}; +}; // end namespace raft::distance::kernels::detail \ No newline at end of file diff --git a/cpp/include/raft/distance/detail/kernels/kernel_factory.cuh b/cpp/include/raft/distance/detail/kernels/kernel_factory.cuh new file mode 100644 index 0000000000..0103ecb003 --- /dev/null +++ b/cpp/include/raft/distance/detail/kernels/kernel_factory.cuh @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "gram_matrix.cuh" +#include "kernel_matrices.cuh" +#include +#include + +namespace raft::distance::kernels::detail { + +template +class KernelFactory { + public: + static GramMatrixBase* create(KernelParams params, cublasHandle_t cublas_handle) + { + GramMatrixBase* res; + // KernelParams is not templated, we convert the parameters to math_t here: + math_t coef0 = params.coef0; + math_t gamma = params.gamma; + switch (params.kernel) { + case LINEAR: res = new GramMatrixBase(cublas_handle); break; + case POLYNOMIAL: + res = new PolynomialKernel(params.degree, gamma, coef0, cublas_handle); + break; + case TANH: res = new TanhKernel(gamma, coef0, cublas_handle); break; + case RBF: res = new RBFKernel(gamma); break; + default: throw raft::exception("Kernel not implemented"); + } + return res; + } +}; + +}; // end namespace raft::distance::kernels::detail diff --git a/cpp/include/raft/distance/detail/kernels/kernel_matrices.cuh b/cpp/include/raft/distance/detail/kernels/kernel_matrices.cuh new file mode 100644 index 0000000000..6d59e1c7c5 --- /dev/null +++ b/cpp/include/raft/distance/detail/kernels/kernel_matrices.cuh @@ -0,0 +1,376 @@ +/* + * Copyright (c) 2019-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "gram_matrix.cuh" +#include + +#include +#include + +namespace raft::distance::kernels::detail { + +/** Epiloge function for polynomial kernel without padding. + * Calculates output = (gain*in + offset)^exponent + * @param inout device vector in column major format, size [len] + * @param len array length + * @param exponent + * @param gain + * @param offset + */ +template +__global__ void polynomial_kernel_nopad( + math_t* inout, size_t len, exp_t exponent, math_t gain, math_t offset) +{ + for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < len; + tid += blockDim.x * gridDim.x) { + inout[tid] = pow(gain * inout[tid] + offset, exponent); + } +} + +/** Epiloge function for polynomial kernel with padding. + * Calculates output = (gain*input + offset)^exponent + * @param inout device vector in column major format, size [ld * cols] + * @param ld leading dimension of the inout buffer + * @param rows number of rows (rows <= ld) + * @param cols number of colums + * @param exponent + * @param gain + * @param offset + */ +template +__global__ void polynomial_kernel( + math_t* inout, int ld, int rows, int cols, exp_t exponent, math_t gain, math_t offset) +{ + for (size_t tidy = threadIdx.y + blockIdx.y * blockDim.y; tidy < cols; + tidy += blockDim.y * gridDim.y) + for (size_t tidx = threadIdx.x + blockIdx.x * blockDim.x; tidx < rows; + tidx += blockDim.x * gridDim.x) { + inout[tidx + tidy * ld] = pow(gain * inout[tidx + tidy * ld] + offset, exponent); + } +} + +/** Epiloge function for tanh kernel without padding. + * Calculates output = tanh(gain*input + offset) + * @param inout device vector, size [len] + * @param len length of the input vector + * @param gain + * @param offset + */ +template +__global__ void tanh_kernel_nopad(math_t* inout, size_t len, math_t gain, math_t offset) +{ + for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < len; + tid += blockDim.x * gridDim.x) { + inout[tid] = tanh(gain * inout[tid] + offset); + } +} + +/** Epiloge function for tanh kernel without padding. + * Calculates output = tanh(gain*input + offset) + * @param inout device vector in column major format, size [ld * cols] + * @param ld leading dimension of the inout buffer + * @param rows number of rows (rows <= ld) + * @param cols number of colums + * @param gain + * @param offset + */ +template +__global__ void tanh_kernel(math_t* inout, int ld, int rows, int cols, math_t gain, math_t offset) +{ + for (size_t tidy = threadIdx.y + blockIdx.y * blockDim.y; tidy < cols; + tidy += blockDim.y * gridDim.y) + for (size_t tidx = threadIdx.x + blockIdx.x * blockDim.x; tidx < rows; + tidx += blockDim.x * gridDim.x) { + inout[tidx + tidy * ld] = tanh(gain * inout[tidx + tidy * ld] + offset); + } +} + +/** + * Create a kernel matrix using polynomial kernel function. + */ +template +class PolynomialKernel : public GramMatrixBase { + exp_t exponent; + math_t gain; + math_t offset; + + void applyKernel( + math_t* inout, int ld, int rows, int cols, bool is_row_major, cudaStream_t stream) + { + const int n_minor = is_row_major ? 
cols : rows; + if (ld == n_minor) { + polynomial_kernel_nopad<<((size_t)rows * cols, 128), 128, 0, stream>>>( + inout, rows * cols, exponent, gain, offset); + } else { + int n1 = is_row_major ? cols : rows; + int n2 = is_row_major ? rows : cols; + polynomial_kernel<<>>(inout, ld, n1, n2, exponent, gain, offset); + } + RAFT_CUDA_TRY(cudaPeekAtLastError()); + } + + public: + /** + * Constructs a polynomial kernel object. + * It evaluates the kernel matrix using the following formula: + * K_ij = (gain* + offset)^exponent + * + * @tparam math_t floating point type + * @tparam exp_t type of exponent + * @param exponent + * @param gain + * @param offset + * @param cublas_handle + */ + PolynomialKernel(exp_t exponent, math_t gain, math_t offset, cublasHandle_t cublas_handle) + : GramMatrixBase(cublas_handle), exponent(exponent), gain(gain), offset(offset) + { + } + + /** Evaluate kernel matrix using polynomial kernel. + * + * output[i,k] = (gain* + offset)^exponent, + * where x1_i is the i-th vector from the x1 set, and x2_k is k-th vector + * in the x2 set, and < , > denotes dot product. + * + * @param [in] x1 device array of vectors, size [n1*n_cols] + * @param [in] n1 number vectors in x1 + * @param [in] n_cols number of features in x1 and x2 + * @param [in] x2 device array of vectors, size [n2*cols] + * @param [in] n2 number vectors in x2 + * @param [out] out device buffer to store the Gram matrix, size [n1*n2] + * @param [in] is_row_major whether the input and output matrices are in row + * major format + * @param [in] stream cuda stream + * @param ld1 leading dimension of x1 + * @param ld2 leading dimension of x2 + * @param ld_out leading dimension of out + */ + void evaluate(const math_t* x1, + int n1, + int n_cols, + const math_t* x2, + int n2, + math_t* out, + bool is_row_major, + cudaStream_t stream, + int ld1, + int ld2, + int ld_out) + { + GramMatrixBase::linear( + x1, n1, n_cols, x2, n2, out, is_row_major, stream, ld1, ld2, ld_out); + applyKernel(out, ld_out, n1, n2, is_row_major, stream); + } +}; + +/** + * Create a kernel matrix using tanh kernel function. + */ +template +class TanhKernel : public GramMatrixBase { + math_t gain, offset; + + void applyKernel( + math_t* inout, int ld, int rows, int cols, bool is_row_major, cudaStream_t stream) + { + const int n_minor = is_row_major ? cols : rows; + if (ld == n_minor) { + tanh_kernel_nopad<<((size_t)rows * cols, 128), 128, 0, stream>>>( + inout, rows * cols, gain, offset); + } else { + int n1 = is_row_major ? cols : rows; + int n2 = is_row_major ? rows : cols; + tanh_kernel<<>>(inout, ld, n1, n2, gain, offset); + } + RAFT_CUDA_TRY(cudaPeekAtLastError()); + } + + public: + /** + * Constructs a tanh kernel object. + * It evaluates the kernel matrix using the following formula: + * K_ij = tanh(gain* + offset) + * + * @tparam math_t floating point type + * @param gain + * @param offset + * @param cublas_handle + */ + TanhKernel(math_t gain, math_t offset, cublasHandle_t cublas_handle) + : GramMatrixBase(cublas_handle), gain(gain), offset(offset) + { + } + + /** Evaluate kernel matrix using tanh kernel. + * + * output_[i + k*n1] = (gain* + offset)^exponent, + * where x1_i is the i-th vector from the x1 set, and x2_k is k-th vector + * in the x2 set, and < , > denotes dot product. 
+ * + * @param [in] x1 device array of vectors, + * size [n1*n_cols] + * @param [in] n1 number vectors in x1 + * @param [in] n_cols number of features in x1 and x2 + * @param [in] x2 device array of vectors, + * size [n2*n_cols] + * @param [in] n2 number vectors in x2 + * @param [out] out device buffer to store the Gram matrix, size [n1*n2] + * @param [in] is_row_major whether the input and output matrices are in row + * major format + * @param [in] stream cuda stream + * @param ld1 leading dimension of x1 (usually it is n1) + * @param ld2 leading dimension of x2 (usually it is n2) + * @param ld_out leading dimension of out (usually it is n1) + */ + void evaluate(const math_t* x1, + int n1, + int n_cols, + const math_t* x2, + int n2, + math_t* out, + bool is_row_major, + cudaStream_t stream, + int ld1, + int ld2, + int ld_out) + { + GramMatrixBase::linear( + x1, n1, n_cols, x2, n2, out, is_row_major, stream, ld1, ld2, ld_out); + applyKernel(out, ld_out, n1, n2, is_row_major, stream); + } +}; + +/** + * Create a kernel matrix using RBF kernel function. + */ +template +class RBFKernel : public GramMatrixBase { + math_t gain; + + void applyKernel( + math_t* inout, int ld, int rows, int cols, bool is_row_major, cudaStream_t stream) + { + const int n_minor = is_row_major ? cols : rows; + if (ld == n_minor) { + rbf_kernel_nopad<<((size_t)rows * cols, 128), 128, 0, stream>>>( + inout, rows * cols, gain); + } else { + int n1 = is_row_major ? cols : rows; + int n2 = is_row_major ? rows : cols; + rbf_kernel<<>>(inout, ld, n1, n2, gain); + } + } + + public: + /** + * Constructs a RBF kernel object. + * It evaluates the kernel matrix using the following formula: + * K_ij = exp(-gain*|x1_i- x2_k|^2) + * + * @tparam math_t floating point type + * @param gain + */ + RBFKernel(math_t gain) : GramMatrixBase(NULL), gain(gain) {} + + /** Evaluate kernel matrix using RBF kernel. + * + * output_[i + k*n1] = exp(-gain*|x1_i - x2_k|^2), + * where x1_i is the i-th vector from the x1 set, and x2_k is k-th vector + * in the x2 set, and | | euclidean distance. + * + * @param [in] x1 device array of vectors, size [n1*n_cols] + * @param [in] n1 number vectors in x1 + * @param [in] n_cols number of features in x1 and x2 + * @param [in] x2 device array of vectors, size [n2*n_cols] + * @param [in] n2 number vectors in x2 + * @param [out] out device buffer to store the Gram matrix, size [n1*n2] + * @param [in] is_row_major whether the input and output matrices are in row + * major format + * @param [in] stream cuda stream + * @param ld1 leading dimension of x1, currently only ld1 == n1 is supported + * @param ld2 leading dimension of x2, currently only ld2 == n2 is supported + * @param ld_out leading dimension of out, only ld_out == n1 is supported + */ + void evaluate(const math_t* x1, + int n1, + int n_cols, + const math_t* x2, + int n2, + math_t* out, + bool is_row_major, + cudaStream_t stream, + int ld1, + int ld2, + int ld_out) + { + int minor1 = is_row_major ? n_cols : n1; + int minor2 = is_row_major ? n_cols : n2; + int minor_out = is_row_major ? 
n2 : n1; + ASSERT(ld1 == minor1, "RBF Kernel distance does not support ld1 parameter"); + ASSERT(ld2 == minor2, "RBF Kernel distance does not support ld2 parameter"); + ASSERT(ld_out == minor_out, "RBF Kernel distance does not support ld_out parameter"); + distance(x1, n1, n_cols, x2, n2, out, is_row_major, stream, ld1, ld2, ld_out); + } + + /** Customize distance function withe RBF epilogue */ + void distance(const math_t* x1, + int n1, + int n_cols, + const math_t* x2, + int n2, + math_t* out, + bool is_row_major, + cudaStream_t stream, + int ld1, + int ld2, + int ld_out) + { + math_t gain = this->gain; + using index_t = int64_t; + + auto fin_op = [gain] __device__(math_t d_val, index_t idx) { return exp(-gain * d_val); }; + raft::distance::distance(const_cast(x1), + const_cast(x2), + out, + n1, + n2, + n_cols, + NULL, + 0, + fin_op, + stream, + is_row_major); + } +}; + +}; // end namespace raft::distance::kernels::detail diff --git a/cpp/include/raft/distance/distance_types.hpp b/cpp/include/raft/distance/distance_types.hpp index f75263b00d..f5ed68af4a 100644 --- a/cpp/include/raft/distance/distance_types.hpp +++ b/cpp/include/raft/distance/distance_types.hpp @@ -65,5 +65,26 @@ enum DistanceType : unsigned short { /** Precomputed (special value) **/ Precomputed = 100 }; + +namespace kernels { +enum KernelType { LINEAR, POLYNOMIAL, RBF, TANH }; + +/** + * Parameters for kernel matrices. + * The following kernels are implemented: + * - LINEAR \f[ K(x_1,x_2) = , \f] where \f$< , >\f$ is the dot product + * - POLYNOMIAL \f[ K(x_1, x_2) = (\gamma + \mathrm{coef0})^\mathrm{degree} \f] + * - RBF \f[ K(x_1, x_2) = \exp(- \gamma |x_1-x_2|^2) \f] + * - TANH \f[ K(x_1, x_2) = \tanh(\gamma + \mathrm{coef0}) \f] + */ +struct KernelParams { + // Kernel function parameters + KernelType kernel; //!< Type of the kernel function + int degree; //!< Degree of polynomial kernel (ignored by others) + double gamma; //!< multiplier in the + double coef0; //!< additive constant in poly and tanh kernels +}; +} // end namespace kernels + }; // namespace distance }; // end namespace raft diff --git a/cpp/include/raft/distance/kernels.cuh b/cpp/include/raft/distance/kernels.cuh new file mode 100644 index 0000000000..86f9f82406 --- /dev/null +++ b/cpp/include/raft/distance/kernels.cuh @@ -0,0 +1,32 @@ +/* + * Copyright (c) 2019-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +#include +#include + +#include +#include + +namespace raft::distance::kernels { + +// TODO: Need to expose formal APIs for this that are more consistent w/ other APIs in RAFT +using raft::distance::kernels::detail::GramMatrixBase; +using raft::distance::kernels::detail::KernelFactory; + +}; // end namespace raft::distance::kernels diff --git a/cpp/include/raft/linalg/init.cuh b/cpp/include/raft/linalg/init.cuh index 2fdf9dceb9..7a09cee289 100644 --- a/cpp/include/raft/linalg/init.cuh +++ b/cpp/include/raft/linalg/init.cuh @@ -54,6 +54,19 @@ void range(T* out, int n, cudaStream_t stream) detail::range(out, n, stream); } +/** + * @brief Zeros the output. + * + * \param [out] out device array, size [n] + * \param [in] n length of the array + * \param [in] stream cuda stream + */ +template +void zero(T* out, int n, cudaStream_t stream) +{ + RAFT_CUDA_TRY(cudaMemsetAsync(static_cast(out), 0, n * sizeof(T), stream)); +} + } // namespace linalg } // namespace raft diff --git a/cpp/include/raft/util/cache.cuh b/cpp/include/raft/util/cache.cuh new file mode 100644 index 0000000000..ef210fad82 --- /dev/null +++ b/cpp/include/raft/util/cache.cuh @@ -0,0 +1,406 @@ +/* + * Copyright (c) 2019-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +#include +#include +#include +#include +#include +#include + +#include + +namespace raft::util::cache + + /** + * @brief Associative cache with least recently used replacement policy. + * + * SW managed cache in device memory, for ML algos where we can trade memory + * access for computation. The two main functions of this class are the + * management of cache indices, and methods to retrieve/store data using the + * cache indices. + * + * The index management can be considered as a hash map, where the int + * keys are the original vector indices that we want to store, and the values are + * the cache location of these vectors. The keys are hashed into a bucket + * whose size equals the associativity. These are the cache sets. If a cache + * set is full, then new indices are stored by replacing the oldest entries. + * + * Using this index mapping we implement methods to store and retrive data from + * the cache buffer, where a unit of data that we are storing is math_t[n_vec]. + * For example in SVM we store full columns of the kernel matrix at each cache + * entry. + * + * Note: we should have a look if the index management could be simplified using + * concurrent_unordered_map.cuh from cudf. See Issue #914. + * + * Example usage: + * @code{.cpp} + * + * // An expensive calculation that we want to accelerate with caching: + * // we have n keys, and for each key we generate a vector with m elements. + * // The keys and the output values are stored in GPU memory. + * void calc(int *key, int n, int m, float *out, cudaStream_t stream) { + * for (k=0; k cache(h.get_device_allocator(), stream, m); + * + * // A buffer that we will reuse to store the cache indices. 
+ * rmm::device_uvector cache_idx(h.get_device_allocator(), stream, n); + * + * void cached_calc(int *key, int n, int m, float *out, stream) { + * int n_cached = 0; + * + * cache.GetCacheIdxPartitioned(key, n, cache_idx.data(), &n_cached, + * cudaStream_t stream); + * + * // Note: GetCacheIdxPartitioned has reordered the keys so that + * // key[0..n_cached-1] are the keys already in the cache. + * // We collect the corresponding values + * cache.GetVecs(cache_idx.data(), n_cached, out, stream); + * + * // Calculate the elements not in the cache + * int non_cached = n - n_cached; + * if (non_cached > 0) { + * int *key_new = key + n_cached; + * int *cache_idx_new = cache_idx.data() + n_cached; + * float *out_new = out + n_cached * m; + * // AssignCacheIdx can permute the keys, therefore it has to come before + * // we call calc. + * // Note: a call to AssignCacheIdx should always be preceded with + * // GetCacheIdxPartitioned, because that initializes the cache_idx_new array + * // with the cache set (hash bucket) that correspond to the keys. + * // The cache idx will be assigned from that cache set. + * cache.AssignCacheIdx(key_new, non_cached, cache_idx_new, stream); + * + * calc(key_new, non_cached, m, out_new, stream); + * + * // Store the calculated vectors into the cache. + * cache.StoreVecs(out_new, non_cached, non_cached, cache_idx_new, stream); + * } + * } + * @endcode + */ + template + class Cache { + public: + /** + * @brief Construct a Cache object + * + * @tparam math_t type of elements to be cached + * @tparam associativity number of vectors in a cache set + * + * @param stream cuda stream + * @param n_vec number of elements in a single vector that is stored in a + * cache entry + * @param cache_size in MiB + */ + Cache(cudaStream_t stream, int n_vec, float cache_size = 200) + : n_vec(n_vec), + cache_size(cache_size), + cache(0, stream), + cached_keys(0, stream), + cache_time(0, stream), + is_cached(0, stream), + ws_tmp(0, stream), + idx_tmp(0, stream), + d_num_selected_out(stream), + d_temp_storage(0, stream) + { + ASSERT(n_vec > 0, "Parameter n_vec: shall be larger than zero"); + ASSERT(associativity > 0, "Associativity shall be larger than zero"); + ASSERT(cache_size >= 0, "Cache size should not be negative"); + + // Calculate how many vectors would fit the cache + int n_cache_vecs = (cache_size * 1024 * 1024) / (sizeof(math_t) * n_vec); + + // The available memory shall be enough for at least one cache set + if (n_cache_vecs >= associativity) { + n_cache_sets = n_cache_vecs / associativity; + n_cache_vecs = n_cache_sets * associativity; + cache.resize(n_cache_vecs * n_vec, stream); + cached_keys.resize(n_cache_vecs, stream); + cache_time.resize(n_cache_vecs, stream); + RAFT_CUDA_TRY( + cudaMemsetAsync(cached_keys.data(), 0, cached_keys.size() * sizeof(int), stream)); + RAFT_CUDA_TRY(cudaMemsetAsync(cache_time.data(), 0, cache_time.size() * sizeof(int), stream)); + } else { + if (cache_size > 0) { + CUML_LOG_WARN( + "Warning: not enough memory to cache a single set of " + "rows, not using cache"); + } + n_cache_sets = 0; + cache_size = 0; + } + CUML_LOG_DEBUG( + "Creating cache with size=%f MiB, to store %d vectors, in " + "%d sets with associativity=%d", + cache_size, + n_cache_vecs, + n_cache_sets, + associativity); + } + + Cache(const Cache& other) = delete; + + Cache& operator=(const Cache& other) = delete; + + /** @brief Collect cached data into contiguous memory space. 
+ * + * On exit, the tile array is filled the following way: + * out[i + n_vec*k] = cache[i + n_vec * idx[k]]), where i=0..n_vec-1, + * k = 0..n-1 + * + * Idx values less than 0 are ignored. + * + * @param [in] idx cache indices, size [n] + * @param [in] n the number of vectors that need to be collected + * @param [out] out vectors collected from cache, size [n_vec*n] + * @param [in] stream cuda stream + */ + void GetVecs(const int* idx, int n, math_t* out, cudaStream_t stream) + { + if (n > 0) { + get_vecs<<>>(cache.data(), n_vec, idx, n, out); + RAFT_CUDA_TRY(cudaPeekAtLastError()); + } + } + + /** @brief Store vectors of data into the cache. + * + * Roughly the opposite of GetVecs, but the input vectors can be scattered + * in memory. The cache is updated using the following formula: + * + * cache[i + cache_idx[k]*n_vec] = tile[i + tile_idx[k]*n_vec], + * for i=0..n_vec-1, k=0..n-1 + * + * If tile_idx==nullptr, then we assume tile_idx[k] = k. + * + * Elements within a vector should be contiguous in memory (i.e. column vectors + * for column major data storage, or row vectors of row major data). + * + * @param [in] tile stores the data to be cashed cached, size [n_vec x n_tile] + * @param [in] n_tile number of vectors in tile (at least n) + * @param [in] n number of vectors that need to be stored in the cache (a subset + * of all the vectors in the tile) + * @param [in] cache_idx cache indices for storing the vectors (negative values + * are ignored), size [n] + * @param [in] stream cuda stream + * @param [in] tile_idx indices of vectors that need to be stored + */ + void StoreVecs(const math_t* tile, + int n_tile, + int n, + int* cache_idx, + cudaStream_t stream, + const int* tile_idx = nullptr) + { + if (n > 0) { + store_vecs<<>>( + tile, n_tile, n_vec, tile_idx, n, cache_idx, cache.data(), cache.size() / n_vec); + RAFT_CUDA_TRY(cudaPeekAtLastError()); + } + } + + /** @brief Map a set of keys to cache indices. + * + * For each k in 0..n-1, if keys[k] is found in the cache, then cache_idx[k] + * will tell the corresponding cache idx, and is_cached[k] is set to true. + * + * If keys[k] is not found in the cache, then is_cached[k] is set to false. + * In this case we assign the cache set for keys[k], and cache_idx[k] will + * store the cache set. + * + * @note in order to retrieve the cached vector j=cache_idx[k] from the cache, + * we have to access cache[i + j*n_vec], where i=0..n_vec-1. + * + * @note: do not use simultaneous GetCacheIdx and AssignCacheIdx + * + * @param [in] keys device array of keys, size [n] + * @param [in] n number of keys + * @param [out] cache_idx device array of cache indices corresponding to the + * input keys, size [n] + * @param [out] is_cached whether the element is already available in the + * cache, size [n] + * @param [in] stream + */ + void GetCacheIdx(int* keys, int n, int* cache_idx, bool* is_cached, cudaStream_t stream) + { + n_iter++; // we increase the iteration counter, that is used to time stamp + // accessing entries from the cache + get_cache_idx<<>>(keys, + n, + cached_keys.data(), + n_cache_sets, + associativity, + cache_time.data(), + cache_idx, + is_cached, + n_iter); + RAFT_CUDA_TRY(cudaPeekAtLastError()); + } + + /** @brief Map a set of keys to cache indices. 
+ * + * Same as GetCacheIdx, but partitions the keys, and cache_idx arrays in a way + * that keys[0..n_cached-1] and cache_idx[0..n_cached-1] store the indices of + * vectors that are found in the cache, while keys[n_cached..n-1] are the + * indices of vectors that are not found in the cache. For the vectors not + * found in the cache, cache_idx[n_cached..n-1] stores the cache set, and this + * can be used to call AssignCacheIdx. + * + * @param [inout] keys device array of keys, size [n] + * @param [in] n number of indices + * @param [out] cache_idx device array of cache indices corresponding to + * the input keys, size [n] + * @param [out] n_cached number of elements that are cached + * @param [in] stream cuda stream + */ + void GetCacheIdxPartitioned(int* keys, int n, int* cache_idx, int* n_cached, cudaStream_t stream) + { + ResizeTmpBuffers(n, stream); + + GetCacheIdx(keys, n, ws_tmp.data(), is_cached.data(), stream); + + // Group cache indices as [already cached, non_cached] + cub::DevicePartition::Flagged(d_temp_storage.data(), + d_temp_storage_size, + ws_tmp.data(), + is_cached.data(), + cache_idx, + d_num_selected_out.data(), + n, + stream); + + raft::update_host(n_cached, d_num_selected_out.data(), 1, stream); + + // Similarily re-group the input indices + raft::copy(ws_tmp.data(), keys, n, stream); + cub::DevicePartition::Flagged(d_temp_storage.data(), + d_temp_storage_size, + ws_tmp.data(), + is_cached.data(), + keys, + d_num_selected_out.data(), + n, + stream); + + raft::interruptible::synchronize(stream); + } + + /** + * @brief Assign cache location to a set of keys. + * + * Note: call GetCacheIdx first, to get the cache_set assigned to the keys. + * Keys that cannot be cached are assigned to -1. + * + * @param [inout] keys device array of keys, size [n] + * @param [in] n number of elements that we want to cache + * @param [inout] cidx on entry: cache_set, on exit: assigned cache_idx or -1, + * size[n] + * @param [in] stream cuda stream + */ + void AssignCacheIdx(int* keys, int n, int* cidx, cudaStream_t stream) + { + if (n <= 0) return; + cub::DeviceRadixSort::SortPairs(d_temp_storage.data(), + d_temp_storage_size, + cidx, + ws_tmp.data(), + keys, + idx_tmp.data(), + n, + 0, + sizeof(int) * 8, + stream); + + raft::copy(keys, idx_tmp.data(), n, stream); + + // set it to -1 + RAFT_CUDA_TRY(cudaMemsetAsync(cidx, 255, n * sizeof(int), stream)); + const int nthreads = associativity <= 32 ? associativity : 32; + + assign_cache_idx<<>>( + keys, n, ws_tmp.data(), cached_keys.data(), n_cache_sets, cache_time.data(), n_iter, cidx); + + RAFT_CUDA_TRY(cudaPeekAtLastError()); + if (debug_mode) RAFT_CUDA_TRY(cudaDeviceSynchronize()); + } + + /** Return approximate cache size in MiB. */ + float GetSizeInMiB() const { return cache_size; } + + /** + * Returns the number of vectors that can be cached. 
+ */ + int GetSize() const { return cached_keys.size(); } + + private: + int n_vec; //!< Number of elements in a cached vector + float cache_size; //!< in MiB + int n_cache_sets; //!< number of cache sets + + const int TPB = 256; //!< threads per block for kernel launch + int n_iter = 0; //!< Counter for time stamping cache operation + + bool debug_mode = false; + + rmm::device_uvector cache; //!< The value of cached vectors + rmm::device_uvector cached_keys; //!< Keys stored at each cache loc + rmm::device_uvector cache_time; //!< Time stamp for LRU cache + + // Helper arrays for GetCacheIdx + rmm::device_uvector is_cached; + rmm::device_uvector ws_tmp; + rmm::device_uvector idx_tmp; + + // Helper arrays for cub + rmm::device_scalar d_num_selected_out; + rmm::device_uvector d_temp_storage; + size_t d_temp_storage_size = 0; + + void ResizeTmpBuffers(int n, cudaStream_t stream) + { + if (ws_tmp.size() < static_cast(n)) { + ws_tmp.resize(n, stream); + is_cached.resize(n, stream); + idx_tmp.resize(n, stream); + cub::DevicePartition::Flagged(NULL, + d_temp_storage_size, + cached_keys.data(), + is_cached.data(), + cached_keys.data(), + d_num_selected_out.data(), + n, + stream); + d_temp_storage.resize(d_temp_storage_size, stream); + } + } +}; +} +; // namespace raft::util::cache From 4df4df1ad12776f3ce021afcbe20d8771c64e0d0 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Fri, 14 Oct 2022 19:51:14 -0400 Subject: [PATCH 02/35] Adding specializations for gram matrix kernels --- cpp/CMakeLists.txt | 9 ++ cpp/bench/distance/distance_common.cuh | 2 +- cpp/bench/distance/kernels.cu | 132 ++++++++++++++++++ .../specializations/detail/kernels.cuh | 30 ++++ .../distance/specializations/distance.cuh | 1 + .../detail/kernels/gram_matrix_base_double.cu | 19 +++ .../detail/kernels/gram_matrix_base_float.cu | 19 +++ .../kernels/polynomial_kernel_double_int.cu | 19 +++ .../kernels/polynomial_kernel_float_int.cu | 19 +++ .../detail/kernels/rbf_kernel_double.cu | 19 +++ .../detail/kernels/rbf_kernel_float.cu | 19 +++ .../detail/kernels/tanh_kernel_double.cu | 19 +++ .../detail/kernels/tanh_kernel_float.cu | 19 +++ 13 files changed, 325 insertions(+), 1 deletion(-) create mode 100644 cpp/bench/distance/kernels.cu create mode 100644 cpp/include/raft/distance/specializations/detail/kernels.cuh create mode 100644 cpp/src/distance/specializations/detail/kernels/gram_matrix_base_double.cu create mode 100644 cpp/src/distance/specializations/detail/kernels/gram_matrix_base_float.cu create mode 100644 cpp/src/distance/specializations/detail/kernels/polynomial_kernel_double_int.cu create mode 100644 cpp/src/distance/specializations/detail/kernels/polynomial_kernel_float_int.cu create mode 100644 cpp/src/distance/specializations/detail/kernels/rbf_kernel_double.cu create mode 100644 cpp/src/distance/specializations/detail/kernels/rbf_kernel_float.cu create mode 100644 cpp/src/distance/specializations/detail/kernels/tanh_kernel_double.cu create mode 100644 cpp/src/distance/specializations/detail/kernels/tanh_kernel_float.cu diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 12bebfa2a5..ce7af8535a 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -248,11 +248,20 @@ if(RAFT_COMPILE_DIST_LIBRARY) src/distance/specializations/detail/chebyshev.cu src/distance/specializations/detail/correlation.cu src/distance/specializations/detail/cosine.cu + src/distance/specializations/detail/cosine.cu src/distance/specializations/detail/hamming_unexpanded.cu src/distance/specializations/detail/hellinger_expanded.cu 
src/distance/specializations/detail/jensen_shannon_float_float_float_int.cu src/distance/specializations/detail/jensen_shannon_float_float_float_uint32.cu src/distance/specializations/detail/jensen_shannon_double_double_double_int.cu + src/distance/specializations/detail/kernels/gram_matrix_base_double.cu + src/distance/specializations/detail/kernels/gram_matrix_base_float.cu + src/distance/specializations/detail/kernels/polynomial_kernel_double_int.cu + src/distance/specializations/detail/kernels/polynomial_kernel_float_int.cu + src/distance/specializations/detail/kernels/rbf_kernel_double.cu + src/distance/specializations/detail/kernels/rbf_kernel_float.cu + src/distance/specializations/detail/kernels/tanh_kernel_double.cu + src/distance/specializations/detail/kernels/tanh_kernel_float.cu src/distance/specializations/detail/kl_divergence_float_float_float_int.cu src/distance/specializations/detail/kl_divergence_float_float_float_uint32.cu src/distance/specializations/detail/kl_divergence_double_double_double_int.cu diff --git a/cpp/bench/distance/distance_common.cuh b/cpp/bench/distance/distance_common.cuh index 4f1a8ccab1..14ebb55ebe 100644 --- a/cpp/bench/distance/distance_common.cuh +++ b/cpp/bench/distance/distance_common.cuh @@ -15,8 +15,8 @@ */ #include -#include #include +#include #if defined RAFT_DISTANCE_COMPILED #include #endif diff --git a/cpp/bench/distance/kernels.cu b/cpp/bench/distance/kernels.cu new file mode 100644 index 0000000000..c2486506e5 --- /dev/null +++ b/cpp/bench/distance/kernels.cu @@ -0,0 +1,132 @@ +/* + * Copyright (c) 2019-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#if defined RAFT_DISTANCE_COMPILED +#include +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace raft::distance::kernels::bench { + +struct GramTestParams { + int m; // m parameter of the GEMM + int k; // k parameter of the GEMM + int n; // n parameter of the GEMM + KernelParams kernel_params; + bool is_row_major; +}; // struct GramTestParams + +template +struct GramMatrix : public Fixture { + GramMatrix(const std::string& name, const GramTestParams& p) + : Fixture(name), params(p), A(0, stream), B(0, stream), C(0, stream) + { + std::vector kernel_names{"linear", "poly", "rbf", "tanh"}; + std::ostringstream oss; + oss << name << "/" << kernel_names[p.kernel_params.kernel] << "/" << p.m << "x" << p.k << "x" + << p.n << "/" << (p.is_row_major ? 
"row_major" : "col_major"); + this->SetName(oss.str().c_str()); + + kernel = std::unique_ptr>( + KernelFactory::create(p.kernel_params, handle.get_cublas_handle())); + } + + ~GramMatrix() {} + + protected: + void allocateBuffers(const ::benchmark::State& state) override + { + A.resize(params.m * params.k, stream); + B.resize(params.k * params.n, stream); + C.resize(params.m * params.n, stream); + raft::random::Rng r(123456ULL); + r.uniform(A.data(), params.m * params.k, T(-1.0), T(1.0), stream); + r.uniform(B.data(), params.k * params.n, T(-1.0), T(1.0), stream); + } + void deallocateBuffers(const ::benchmark::State& state) override + { + A.release(); + B.release(); + C.release(); + } + void runBenchmark(::benchmark::State& state) override + { + if (!this->kernel) { state.SkipWithError("Kernel matrix is not initialized"); } + loopOnState(state, [this]() { + (*this->kernel)(A.data(), + this->params.m, + this->params.k, + B.data(), + this->params.n, + C.data(), + this->params.is_row_major, + this->stream); + }); + } + + private: + raft::handle_t& handle; + std::unique_ptr> kernel; + GramTestParams params; + + rmm::device_uvector A; // input matrix A, size [m * k] + rmm::device_uvector B; // input matrix B, size [n * k] + rmm::device_uvector C; // output matrix C, size [m*n] +}; + +static std::vector getInputs() +{ + std::vector param_vec; + std::vector kernel_params{KernelParams{LINEAR, 3, 1, 0}, + KernelParams{POLYNOMIAL, 2, 1.3, 1}, + KernelParams{TANH, 2, 0.5, 2.4}, + KernelParams{RBF, 2, 0.5, 0}}; + struct TestSize { + int m; + int k; + int n; + }; + std::vector data_size{{4096, 10, 1024}, + {4096, 100, 1024}, + {4096, 1000, 1024}, + {4096, 10000, 1024}, + {100000, 10, 1024}, + {100000, 100, 1024}, + {100000, 1000, 1024}}; + + param_vec.reserve(kernel_params.size() * data_size.size()); + for (TestSize s : data_size) { + for (auto kernel : kernel_params) { + for (bool row_major : {false, true}) { + param_vec.push_back(GramTestParams{s.m, s.k, s.n, kernel, row_major}); + } + } + } + return param_vec; +} + +ML_BENCH_REGISTER(GramTestParams, GramMatrix, "", getInputs()); +ML_BENCH_REGISTER(GramTestParams, GramMatrix, "", getInputs()); + +} // namespace raft::distance::kernels::bench diff --git a/cpp/include/raft/distance/specializations/detail/kernels.cuh b/cpp/include/raft/distance/specializations/detail/kernels.cuh new file mode 100644 index 0000000000..a8159e7610 --- /dev/null +++ b/cpp/include/raft/distance/specializations/detail/kernels.cuh @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include + +extern template class raft::distance::kernels::detail::GramMatrixBase; +extern template class raft::distance::kernels::detail::GramMatrixBase; + +extern template class raft::distance::kernels::detail::PolynomialKernel; +extern template class raft::distance::kernels::detail::PolynomialKernel; + +extern template class raft::distance::kernels::detail::TanhKernel; +extern template class raft::distance::kernels::detail::TanhKernel; + +extern template class raft::distance::kernels::detail::RBFKernel; +extern template class raft::distance::kernels::detail::RBFKernel; \ No newline at end of file diff --git a/cpp/include/raft/distance/specializations/distance.cuh b/cpp/include/raft/distance/specializations/distance.cuh index 73d075f260..053441d68a 100644 --- a/cpp/include/raft/distance/specializations/distance.cuh +++ b/cpp/include/raft/distance/specializations/distance.cuh @@ -23,6 +23,7 @@ #include #include #include +#include #include #include #include diff --git a/cpp/src/distance/specializations/detail/kernels/gram_matrix_base_double.cu b/cpp/src/distance/specializations/detail/kernels/gram_matrix_base_double.cu new file mode 100644 index 0000000000..c893e9a358 --- /dev/null +++ b/cpp/src/distance/specializations/detail/kernels/gram_matrix_base_double.cu @@ -0,0 +1,19 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +template class raft::distance::kernels::detail::GramMatrixBase; \ No newline at end of file diff --git a/cpp/src/distance/specializations/detail/kernels/gram_matrix_base_float.cu b/cpp/src/distance/specializations/detail/kernels/gram_matrix_base_float.cu new file mode 100644 index 0000000000..3265f828e6 --- /dev/null +++ b/cpp/src/distance/specializations/detail/kernels/gram_matrix_base_float.cu @@ -0,0 +1,19 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +template class raft::distance::kernels::detail::GramMatrixBase; \ No newline at end of file diff --git a/cpp/src/distance/specializations/detail/kernels/polynomial_kernel_double_int.cu b/cpp/src/distance/specializations/detail/kernels/polynomial_kernel_double_int.cu new file mode 100644 index 0000000000..36c7945d27 --- /dev/null +++ b/cpp/src/distance/specializations/detail/kernels/polynomial_kernel_double_int.cu @@ -0,0 +1,19 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +template class raft::distance::kernels::detail::PolynomalKernel; \ No newline at end of file diff --git a/cpp/src/distance/specializations/detail/kernels/polynomial_kernel_float_int.cu b/cpp/src/distance/specializations/detail/kernels/polynomial_kernel_float_int.cu new file mode 100644 index 0000000000..37e173fbc3 --- /dev/null +++ b/cpp/src/distance/specializations/detail/kernels/polynomial_kernel_float_int.cu @@ -0,0 +1,19 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +template class raft::distance::kernels::detail::PolynomalKernel; \ No newline at end of file diff --git a/cpp/src/distance/specializations/detail/kernels/rbf_kernel_double.cu b/cpp/src/distance/specializations/detail/kernels/rbf_kernel_double.cu new file mode 100644 index 0000000000..6577e1b6c7 --- /dev/null +++ b/cpp/src/distance/specializations/detail/kernels/rbf_kernel_double.cu @@ -0,0 +1,19 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +template class raft::distance::kernels::detail::RBFKernel; \ No newline at end of file diff --git a/cpp/src/distance/specializations/detail/kernels/rbf_kernel_float.cu b/cpp/src/distance/specializations/detail/kernels/rbf_kernel_float.cu new file mode 100644 index 0000000000..1d2582cf81 --- /dev/null +++ b/cpp/src/distance/specializations/detail/kernels/rbf_kernel_float.cu @@ -0,0 +1,19 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +template class raft::distance::kernels::detail::RBFKernel; \ No newline at end of file diff --git a/cpp/src/distance/specializations/detail/kernels/tanh_kernel_double.cu b/cpp/src/distance/specializations/detail/kernels/tanh_kernel_double.cu new file mode 100644 index 0000000000..13d5159504 --- /dev/null +++ b/cpp/src/distance/specializations/detail/kernels/tanh_kernel_double.cu @@ -0,0 +1,19 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +template class raft::distance::kernels::detail::TanhKernel; \ No newline at end of file diff --git a/cpp/src/distance/specializations/detail/kernels/tanh_kernel_float.cu b/cpp/src/distance/specializations/detail/kernels/tanh_kernel_float.cu new file mode 100644 index 0000000000..ee62de7d34 --- /dev/null +++ b/cpp/src/distance/specializations/detail/kernels/tanh_kernel_float.cu @@ -0,0 +1,19 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +template class raft::distance::kernels::detail::TanhKernel; \ No newline at end of file From c9e82cd89b30e0de6c102fa111a3c1aa20ca5663 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Fri, 14 Oct 2022 20:14:11 -0400 Subject: [PATCH 03/35] Commenting out rbf kernel instantiations for now. 
--- cpp/CMakeLists.txt | 5 +++-- cpp/bench/CMakeLists.txt | 1 + cpp/include/raft/distance/specializations/detail/kernels.cuh | 5 +++-- .../detail/kernels/polynomial_kernel_double_int.cu | 2 +- .../detail/kernels/polynomial_kernel_float_int.cu | 2 +- 5 files changed, 9 insertions(+), 6 deletions(-) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index ce7af8535a..21e2103c4b 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -258,8 +258,9 @@ if(RAFT_COMPILE_DIST_LIBRARY) src/distance/specializations/detail/kernels/gram_matrix_base_float.cu src/distance/specializations/detail/kernels/polynomial_kernel_double_int.cu src/distance/specializations/detail/kernels/polynomial_kernel_float_int.cu - src/distance/specializations/detail/kernels/rbf_kernel_double.cu - src/distance/specializations/detail/kernels/rbf_kernel_float.cu +# These are somehow missing a kernel definition which is causing a compile error. +# src/distance/specializations/detail/kernels/rbf_kernel_double.cu +# src/distance/specializations/detail/kernels/rbf_kernel_float.cu src/distance/specializations/detail/kernels/tanh_kernel_double.cu src/distance/specializations/detail/kernels/tanh_kernel_float.cu src/distance/specializations/detail/kl_divergence_float_float_float_int.cu diff --git a/cpp/bench/CMakeLists.txt b/cpp/bench/CMakeLists.txt index 4af3df9a1a..9c6b60d83b 100644 --- a/cpp/bench/CMakeLists.txt +++ b/cpp/bench/CMakeLists.txt @@ -85,6 +85,7 @@ if(BUILD_BENCH) bench/distance/distance_exp_l2.cu bench/distance/distance_l1.cu bench/distance/distance_unexp_l2.cu + bench/distance/kernels.cu bench/main.cpp OPTIONAL DIST ) diff --git a/cpp/include/raft/distance/specializations/detail/kernels.cuh b/cpp/include/raft/distance/specializations/detail/kernels.cuh index a8159e7610..d34a30edd9 100644 --- a/cpp/include/raft/distance/specializations/detail/kernels.cuh +++ b/cpp/include/raft/distance/specializations/detail/kernels.cuh @@ -26,5 +26,6 @@ extern template class raft::distance::kernels::detail::PolynomialKernel; extern template class raft::distance::kernels::detail::TanhKernel; -extern template class raft::distance::kernels::detail::RBFKernel; -extern template class raft::distance::kernels::detail::RBFKernel; \ No newline at end of file +// These are somehow missing a kernel definition which is causing a compile error +//extern template class raft::distance::kernels::detail::RBFKernel; +//extern template class raft::distance::kernels::detail::RBFKernel; \ No newline at end of file diff --git a/cpp/src/distance/specializations/detail/kernels/polynomial_kernel_double_int.cu b/cpp/src/distance/specializations/detail/kernels/polynomial_kernel_double_int.cu index 36c7945d27..0edf45a6f1 100644 --- a/cpp/src/distance/specializations/detail/kernels/polynomial_kernel_double_int.cu +++ b/cpp/src/distance/specializations/detail/kernels/polynomial_kernel_double_int.cu @@ -16,4 +16,4 @@ #include -template class raft::distance::kernels::detail::PolynomalKernel; \ No newline at end of file +template class raft::distance::kernels::detail::PolynomialKernel; \ No newline at end of file diff --git a/cpp/src/distance/specializations/detail/kernels/polynomial_kernel_float_int.cu b/cpp/src/distance/specializations/detail/kernels/polynomial_kernel_float_int.cu index 37e173fbc3..a719175e6b 100644 --- a/cpp/src/distance/specializations/detail/kernels/polynomial_kernel_float_int.cu +++ b/cpp/src/distance/specializations/detail/kernels/polynomial_kernel_float_int.cu @@ -16,4 +16,4 @@ #include -template class 
raft::distance::kernels::detail::PolynomalKernel; \ No newline at end of file +template class raft::distance::kernels::detail::PolynomialKernel; \ No newline at end of file From 443130bea3dba32e6f5adc36751efdc1805fe4a4 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Fri, 14 Oct 2022 20:27:23 -0400 Subject: [PATCH 04/35] Fixing style after commenting RBC out --- cpp/include/raft/distance/specializations/detail/kernels.cuh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/include/raft/distance/specializations/detail/kernels.cuh b/cpp/include/raft/distance/specializations/detail/kernels.cuh index d34a30edd9..75c9c023e8 100644 --- a/cpp/include/raft/distance/specializations/detail/kernels.cuh +++ b/cpp/include/raft/distance/specializations/detail/kernels.cuh @@ -27,5 +27,5 @@ extern template class raft::distance::kernels::detail::TanhKernel; extern template class raft::distance::kernels::detail::TanhKernel; // These are somehow missing a kernel definition which is causing a compile error -//extern template class raft::distance::kernels::detail::RBFKernel; -//extern template class raft::distance::kernels::detail::RBFKernel; \ No newline at end of file +// extern template class raft::distance::kernels::detail::RBFKernel; +// extern template class raft::distance::kernels::detail::RBFKernel; \ No newline at end of file From 231c3d93cac785fbaa88d4996e59a4f5a7d10780 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Fri, 14 Oct 2022 23:51:39 -0400 Subject: [PATCH 05/35] Adding cudart_utils.hpp to init.cuh --- cpp/include/raft/linalg/init.cuh | 1 + 1 file changed, 1 insertion(+) diff --git a/cpp/include/raft/linalg/init.cuh b/cpp/include/raft/linalg/init.cuh index 7a09cee289..f23b454636 100644 --- a/cpp/include/raft/linalg/init.cuh +++ b/cpp/include/raft/linalg/init.cuh @@ -18,6 +18,7 @@ #pragma once +#include #include "detail/init.hpp" namespace raft { From 1bbd5f0f555d87507fca1d5feac06b8080d81e42 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Sat, 15 Oct 2022 00:15:59 -0400 Subject: [PATCH 06/35] Style --- cpp/include/raft/linalg/init.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/include/raft/linalg/init.cuh b/cpp/include/raft/linalg/init.cuh index f23b454636..5a810bf2ba 100644 --- a/cpp/include/raft/linalg/init.cuh +++ b/cpp/include/raft/linalg/init.cuh @@ -18,8 +18,8 @@ #pragma once -#include #include "detail/init.hpp" +#include namespace raft { namespace linalg { From b663d3e37e43b2a78e43829f3f84632d07353998 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Sat, 15 Oct 2022 01:56:40 -0400 Subject: [PATCH 07/35] Fixing typo --- cpp/bench/distance/distance_common.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/bench/distance/distance_common.cuh b/cpp/bench/distance/distance_common.cuh index 14ebb55ebe..73faacce37 100644 --- a/cpp/bench/distance/distance_common.cuh +++ b/cpp/bench/distance/distance_common.cuh @@ -16,7 +16,7 @@ #include #include -#include +#include #if defined RAFT_DISTANCE_COMPILED #include #endif From 94bc17c4d98e922ce5b1ecbb8432a4e066e1aab6 Mon Sep 17 00:00:00 2001 From: "Corey J. 
Nolet" Date: Sat, 15 Oct 2022 09:34:14 -0400 Subject: [PATCH 08/35] Fixing benchmark gramm --- cpp/bench/distance/kernels.cu | 42 +++++++++++++++++------------------ 1 file changed, 20 insertions(+), 22 deletions(-) diff --git a/cpp/bench/distance/kernels.cu b/cpp/bench/distance/kernels.cu index c2486506e5..6139d00d0b 100644 --- a/cpp/bench/distance/kernels.cu +++ b/cpp/bench/distance/kernels.cu @@ -27,8 +27,9 @@ #include #include -namespace raft::distance::kernels::bench { +namespace raft::bench::distance::kernels { +using namespace raft::distance::kernels; struct GramTestParams { int m; // m parameter of the GEMM int k; // k parameter of the GEMM @@ -38,25 +39,20 @@ struct GramTestParams { }; // struct GramTestParams template -struct GramMatrix : public Fixture { - GramMatrix(const std::string& name, const GramTestParams& p) - : Fixture(name), params(p), A(0, stream), B(0, stream), C(0, stream) +struct GramMatrix : public fixture { + GramMatrix(const GramTestParams& p) + : params(p), handle(stream), A(0, stream), B(0, stream), C(0, stream) { - std::vector kernel_names{"linear", "poly", "rbf", "tanh"}; - std::ostringstream oss; - oss << name << "/" << kernel_names[p.kernel_params.kernel] << "/" << p.m << "x" << p.k << "x" - << p.n << "/" << (p.is_row_major ? "row_major" : "col_major"); - this->SetName(oss.str().c_str()); + // std::vector kernel_names{"linear", "poly", "rbf", "tanh"}; + // std::ostringstream oss; + // oss << name << "/" << kernel_names[p.kernel_params.kernel] << "/" << p.m << "x" << p.k << + // "x" + // << p.n << "/" << (p.is_row_major ? "row_major" : "col_major"); + // this->SetName(oss.str().c_str()); kernel = std::unique_ptr>( KernelFactory::create(p.kernel_params, handle.get_cublas_handle())); - } - - ~GramMatrix() {} - protected: - void allocateBuffers(const ::benchmark::State& state) override - { A.resize(params.m * params.k, stream); B.resize(params.k * params.n, stream); C.resize(params.m * params.n, stream); @@ -64,16 +60,18 @@ struct GramMatrix : public Fixture { r.uniform(A.data(), params.m * params.k, T(-1.0), T(1.0), stream); r.uniform(B.data(), params.k * params.n, T(-1.0), T(1.0), stream); } - void deallocateBuffers(const ::benchmark::State& state) override + + ~GramMatrix() { A.release(); B.release(); C.release(); } - void runBenchmark(::benchmark::State& state) override + + void run_benchmark(::benchmark::State& state) override { if (!this->kernel) { state.SkipWithError("Kernel matrix is not initialized"); } - loopOnState(state, [this]() { + loop_on_state(state, [this]() { (*this->kernel)(A.data(), this->params.m, this->params.k, @@ -86,7 +84,7 @@ struct GramMatrix : public Fixture { } private: - raft::handle_t& handle; + const raft::handle_t handle; std::unique_ptr> kernel; GramTestParams params; @@ -126,7 +124,7 @@ static std::vector getInputs() return param_vec; } -ML_BENCH_REGISTER(GramTestParams, GramMatrix, "", getInputs()); -ML_BENCH_REGISTER(GramTestParams, GramMatrix, "", getInputs()); +RAFT_BENCH_REGISTER(GramMatrix, "", getInputs()); +RAFT_BENCH_REGISTER(GramMatrix, "", getInputs()); -} // namespace raft::distance::kernels::bench +} // namespace raft::bench::distance::kernels From f94cc7f282ed9018382f3253ee1d4367b6f01c9c Mon Sep 17 00:00:00 2001 From: "Corey J. 
Nolet" Date: Sat, 15 Oct 2022 11:16:50 -0400 Subject: [PATCH 09/35] Fixing include --- cpp/include/raft/distance/detail/kernels/kernel_factory.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/include/raft/distance/detail/kernels/kernel_factory.cuh b/cpp/include/raft/distance/detail/kernels/kernel_factory.cuh index 0103ecb003..1aa6809bcd 100644 --- a/cpp/include/raft/distance/detail/kernels/kernel_factory.cuh +++ b/cpp/include/raft/distance/detail/kernels/kernel_factory.cuh @@ -18,7 +18,7 @@ #include "gram_matrix.cuh" #include "kernel_matrices.cuh" -#include +#include #include namespace raft::distance::kernels::detail { From ada42fda724eddf78831a780d588fba68c3573b1 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Sat, 15 Oct 2022 11:50:59 -0400 Subject: [PATCH 10/35] Adding gram test to distances --- cpp/test/CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index b10de9d1cc..07ec85bf1e 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -116,6 +116,7 @@ if(BUILD_TESTS) test/distance/dist_minkowski.cu test/distance/dist_russell_rao.cu test/distance/fused_l2_nn.cu + test/distance/gram.cu OPTIONAL DIST ) From 510cb6d105ca16c234a9620f62c05ed97cb59f98 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Sat, 15 Oct 2022 12:28:33 -0400 Subject: [PATCH 11/35] Adding gram.cu --- cpp/test/distance/gram.cu | 189 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 189 insertions(+) create mode 100644 cpp/test/distance/gram.cu diff --git a/cpp/test/distance/gram.cu b/cpp/test/distance/gram.cu new file mode 100644 index 0000000000..d1cdbcacee --- /dev/null +++ b/cpp/test/distance/gram.cu @@ -0,0 +1,189 @@ +/* + * Copyright (c) 2019-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#if defined RAFT_DISTANCE_COMPILED +#include +#endif + +#include "test_utils.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace raft::distance::kernels { + +// Get the offset of element [i,k]. +HDI int get_offset(int i, int k, int ld, bool is_row_major) +{ + return is_row_major ? i * ld + k : i + k * ld; +} + +struct GramMatrixInputs { + int n1; // feature vectors in matrix 1 + int n2; // featuer vectors in matrix 2 + int n_cols; // number of elements in a feature vector + bool is_row_major; + KernelParams kernel; + int ld1; + int ld2; + int ld_out; + // We will generate random input using the dimensions given here. + // The reference output is calculated by a custom kernel. +}; + +std::ostream& operator<<(std::ostream& os, const GramMatrixInputs& p) +{ + std::vector kernel_names{"linear", "poly", "rbf", "tanh"}; + os << "/" << p.n1 << "x" << p.n2 << "x" << p.n_cols << "/" + << (p.is_row_major ? 
"RowMajor/" : "ColMajor/") << kernel_names[p.kernel.kernel] << "/ld_" + << p.ld1 << "x" << p.ld2 << "x" << p.ld_out; + return os; +} + +const std::vector inputs = { + {42, 137, 2, false, {KernelType::LINEAR}}, + {42, 137, 2, true, {KernelType::LINEAR}}, + {42, 137, 2, false, {KernelType::LINEAR}, 64, 179, 181}, + {42, 137, 2, true, {KernelType::LINEAR}, 64, 179, 181}, + {137, 42, 2, false, {KernelType::POLYNOMIAL, 2, 0.5, 2.4}}, + {137, 42, 2, true, {KernelType::POLYNOMIAL, 2, 0.5, 2.4}}, + {137, 42, 2, false, {KernelType::POLYNOMIAL, 2, 0.5, 2.4}, 159, 73, 144}, + {137, 42, 2, true, {KernelType::POLYNOMIAL, 2, 0.5, 2.4}, 159, 73, 144}, + {42, 137, 2, false, {KernelType::TANH, 0, 0.5, 2.4}}, + {42, 137, 2, true, {KernelType::TANH, 0, 0.5, 2.4}}, + {42, 137, 2, false, {KernelType::TANH, 0, 0.5, 2.4}, 64, 155, 49}, + {42, 137, 2, true, {KernelType::TANH, 0, 0.5, 2.4}, 64, 155, 143}, + {3, 4, 2, false, {KernelType::RBF, 0, 0.5}}, + {42, 137, 2, false, {KernelType::RBF, 0, 0.5}}, + {42, 137, 2, true, {KernelType::RBF, 0, 0.5}}, + // Distance kernel does not support LD parameter yet. + //{42, 137, 2, false, {KernelType::RBF, 0, 0.5}, 64, 155, 49}, + // {42, 137, 2, true, {KernelType::RBF, 0, 0.5}, 64, 155, 143}, +}; + +template +class GramMatrixTest : public ::testing::TestWithParam { + protected: + GramMatrixTest() + : params(GetParam()), stream(0), x1(0, stream), x2(0, stream), gram(0, stream), gram_host(0) + { + RAFT_CUDA_TRY(cudaStreamCreate(&stream)); + + if (params.ld1 == 0) { params.ld1 = params.is_row_major ? params.n_cols : params.n1; } + if (params.ld2 == 0) { params.ld2 = params.is_row_major ? params.n_cols : params.n2; } + if (params.ld_out == 0) { params.ld_out = params.is_row_major ? params.n2 : params.n1; } + // Derive the size of the ouptut from the offset of the last element. + size_t size = get_offset(params.n1 - 1, params.n_cols - 1, params.ld1, params.is_row_major) + 1; + x1.resize(size, stream); + size = get_offset(params.n2 - 1, params.n_cols - 1, params.ld2, params.is_row_major) + 1; + x2.resize(size, stream); + size = get_offset(params.n1 - 1, params.n2 - 1, params.ld_out, params.is_row_major) + 1; + + gram.resize(size, stream); + RAFT_CUDA_TRY(cudaMemsetAsync(gram.data(), 0, gram.size() * sizeof(math_t), stream)); + gram_host.resize(gram.size()); + std::fill(gram_host.begin(), gram_host.end(), 0); + + raft::random::Rng r(42137ULL); + r.uniform(x1.data(), x1.size(), math_t(0), math_t(1), stream); + r.uniform(x2.data(), x2.size(), math_t(0), math_t(1), stream); + } + + ~GramMatrixTest() override { RAFT_CUDA_TRY_NO_THROW(cudaStreamDestroy(stream)); } + + // Calculate the Gram matrix on the host. 
+ void naiveKernel() + { + std::vector x1_host(x1.size()); + raft::update_host(x1_host.data(), x1.data(), x1.size(), stream); + std::vector x2_host(x2.size()); + raft::update_host(x2_host.data(), x2.data(), x2.size(), stream); + handle.sync_stream(stream); + + for (int i = 0; i < params.n1; i++) { + for (int j = 0; j < params.n2; j++) { + float d = 0; + for (int k = 0; k < params.n_cols; k++) { + if (params.kernel.kernel == KernelType::RBF) { + math_t diff = x1_host[get_offset(i, k, params.ld1, params.is_row_major)] - + x2_host[get_offset(j, k, params.ld2, params.is_row_major)]; + d += diff * diff; + } else { + d += x1_host[get_offset(i, k, params.ld1, params.is_row_major)] * + x2_host[get_offset(j, k, params.ld2, params.is_row_major)]; + } + } + int idx = get_offset(i, j, params.ld_out, params.is_row_major); + math_t v = 0; + switch (params.kernel.kernel) { + case (KernelType::LINEAR): gram_host[idx] = d; break; + case (KernelType::POLYNOMIAL): + v = params.kernel.gamma * d + params.kernel.coef0; + gram_host[idx] = std::pow(v, params.kernel.degree); + break; + case (KernelType::TANH): + gram_host[idx] = std::tanh(params.kernel.gamma * d + params.kernel.coef0); + break; + case (KernelType::RBF): gram_host[idx] = exp(-params.kernel.gamma * d); break; + } + } + } + } + + void runTest() + { + std::unique_ptr> kernel = std::unique_ptr>( + KernelFactory::create(params.kernel, handle.get_cublas_handle())); + + kernel->evaluate(x1.data(), + params.n1, + params.n_cols, + x2.data(), + params.n2, + gram.data(), + params.is_row_major, + stream, + params.ld1, + params.ld2, + params.ld_out); + naiveKernel(); + ASSERT_TRUE(raft::devArrMatchHost( + gram_host.data(), gram.data(), gram.size(), raft::CompareApprox(1e-6f))); + } + + raft::handle_t handle; + cudaStream_t stream = 0; + GramMatrixInputs params; + + rmm::device_uvector x1; + rmm::device_uvector x2; + rmm::device_uvector gram; + std::vector gram_host; +}; + +typedef GramMatrixTest GramMatrixTestFloat; +typedef GramMatrixTest GramMatrixTestDouble; + +TEST_P(GramMatrixTestFloat, Gram) { runTest(); } + +INSTANTIATE_TEST_SUITE_P(GramMatrixTests, GramMatrixTestFloat, ::testing::ValuesIn(inputs)); +}; // end namespace raft::distance::kernels \ No newline at end of file From d1ab18e49a85f46e3cf5fe0c5bbb783d2b3da451 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Sat, 15 Oct 2022 12:33:37 -0400 Subject: [PATCH 12/35] Fixing import --- cpp/test/distance/gram.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/test/distance/gram.cu b/cpp/test/distance/gram.cu index d1cdbcacee..cf7215bddb 100644 --- a/cpp/test/distance/gram.cu +++ b/cpp/test/distance/gram.cu @@ -18,7 +18,7 @@ #include #endif -#include "test_utils.h" +#include "../test_utils.h" #include #include #include From 915d659413af93aa13a351b16f65a46bd574be05 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Sat, 15 Oct 2022 21:24:42 -0400 Subject: [PATCH 13/35] Adding missing curly brace --- cpp/include/raft/util/cache.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/include/raft/util/cache.cuh b/cpp/include/raft/util/cache.cuh index ef210fad82..b5569467e6 100644 --- a/cpp/include/raft/util/cache.cuh +++ b/cpp/include/raft/util/cache.cuh @@ -27,7 +27,7 @@ #include -namespace raft::util::cache +namespace raft::util::cache { /** * @brief Associative cache with least recently used replacement policy. From 86b72e6283ae7dbb61d90865a9a1d6b90a7c114b Mon Sep 17 00:00:00 2001 From: "Corey J. 
Nolet" Date: Sun, 16 Oct 2022 08:05:43 -0400 Subject: [PATCH 14/35] CHanging namespace --- cpp/include/raft/util/cache.cuh | 4 +- cpp/include/raft/util/fast_int_div.cuh | 123 +++++++++++++++++++++++++ 2 files changed, 125 insertions(+), 2 deletions(-) create mode 100644 cpp/include/raft/util/fast_int_div.cuh diff --git a/cpp/include/raft/util/cache.cuh b/cpp/include/raft/util/cache.cuh index b5569467e6..699ab97476 100644 --- a/cpp/include/raft/util/cache.cuh +++ b/cpp/include/raft/util/cache.cuh @@ -27,7 +27,7 @@ #include -namespace raft::util::cache { +namespace raft::cache { /** * @brief Associative cache with least recently used replacement policy. @@ -403,4 +403,4 @@ namespace raft::util::cache { } }; } -; // namespace raft::util::cache +; // namespace raft::cache diff --git a/cpp/include/raft/util/fast_int_div.cuh b/cpp/include/raft/util/fast_int_div.cuh new file mode 100644 index 0000000000..a8fb3cc457 --- /dev/null +++ b/cpp/include/raft/util/fast_int_div.cuh @@ -0,0 +1,123 @@ +/* + * Copyright (c) 2020-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +namespace raft::util { + +/** + * @brief Perform fast integer division and modulo using a known divisor + * + * @note This currently only supports 32b signed integers + * @todo Extend support for signed divisors + * @ref Hacker's Delight, Second Edition, Chapter 10 + */ + struct FastIntDiv { + /** + * @defgroup HostMethods Ctor's that are accessible only from host + * @{ + * @brief Host-only ctor's + * @param _d the divisor + */ + FastIntDiv(int _d) : d(_d) { computeScalars(); } + FastIntDiv& operator=(int _d) + { + d = _d; + computeScalars(); + return *this; + } + /** @} */ + + /** + * @defgroup DeviceMethods Ctor's which even the device-side can access + * @{ + * @brief host and device ctor's + * @param other source object to be copied from + */ + HDI FastIntDiv(const FastIntDiv& other) : d(other.d), m(other.m), p(other.p) {} + HDI FastIntDiv& operator=(const FastIntDiv& other) + { + d = other.d; + m = other.m; + p = other.p; + return *this; + } + /** @} */ + + /** divisor */ + int d; + /** the term 'm' as found in the reference chapter */ + unsigned m; + /** the term 'p' as found in the reference chapter */ + int p; + + private: + void computeScalars() + { + if (d == 1) { + m = 0; + p = 1; + return; + } else if (d < 0) { + ASSERT(false, "FastIntDiv: division by negative numbers not supported!"); + } else if (d == 0) { + ASSERT(false, "FastIntDiv: got division by zero!"); + } + int64_t nc = ((1LL << 31) / d) * d - 1; + p = 31; + int64_t twoP, rhs; + do { + ++p; + twoP = 1LL << p; + rhs = nc * (d - twoP % d); + } while (twoP <= rhs); + m = (twoP + d - twoP % d) / d; + } + }; // struct FastIntDiv + +/** + * @brief Division overload, so that FastIntDiv can be transparently switched + * to even on device + * @param n numerator + * @param divisor the denominator + * @return the quotient + */ + HDI int operator/(int n, const FastIntDiv& divisor) + { + if 
(divisor.d == 1) return n; + int ret = (int64_t(divisor.m) * int64_t(n)) >> divisor.p; + if (n < 0) ++ret; + return ret; + } + +/** + * @brief Modulo overload, so that FastIntDiv can be transparently switched + * to even on device + * @param n numerator + * @param divisor the denominator + * @return the remainder + */ + HDI int operator%(int n, const FastIntDiv& divisor) + { + int quotient = n / divisor; + int remainder = n - quotient * divisor.d; + return remainder; + } + +}; // namespace raft::util From 271abfd2d714812c0300d2c17fdaba9aaa0df3d3 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Sun, 16 Oct 2022 12:00:16 -0400 Subject: [PATCH 15/35] Adding logger --- cpp/include/raft/util/cache.cuh | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/cpp/include/raft/util/cache.cuh b/cpp/include/raft/util/cache.cuh index 699ab97476..0c5f0a19e0 100644 --- a/cpp/include/raft/util/cache.cuh +++ b/cpp/include/raft/util/cache.cuh @@ -18,6 +18,7 @@ #include +#include #include #include #include @@ -150,14 +151,14 @@ namespace raft::cache { RAFT_CUDA_TRY(cudaMemsetAsync(cache_time.data(), 0, cache_time.size() * sizeof(int), stream)); } else { if (cache_size > 0) { - CUML_LOG_WARN( + RAFT_LOG_WARN( "Warning: not enough memory to cache a single set of " "rows, not using cache"); } n_cache_sets = 0; cache_size = 0; } - CUML_LOG_DEBUG( + RAFT_LOG_DEBUG( "Creating cache with size=%f MiB, to store %d vectors, in " "%d sets with associativity=%d", cache_size, From 8a13f8da4e1203c21db8258d153da9b6aa158efb Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Sun, 16 Oct 2022 14:22:46 -0400 Subject: [PATCH 16/35] Fixing style --- cpp/include/raft/util/cache.cuh | 161 ++++++++++++------------- cpp/include/raft/util/fast_int_div.cuh | 144 +++++++++++----------- 2 files changed, 152 insertions(+), 153 deletions(-) diff --git a/cpp/include/raft/util/cache.cuh b/cpp/include/raft/util/cache.cuh index 0c5f0a19e0..8394ce83b8 100644 --- a/cpp/include/raft/util/cache.cuh +++ b/cpp/include/raft/util/cache.cuh @@ -18,8 +18,8 @@ #include -#include #include +#include #include #include #include @@ -30,84 +30,84 @@ namespace raft::cache { - /** - * @brief Associative cache with least recently used replacement policy. - * - * SW managed cache in device memory, for ML algos where we can trade memory - * access for computation. The two main functions of this class are the - * management of cache indices, and methods to retrieve/store data using the - * cache indices. - * - * The index management can be considered as a hash map, where the int - * keys are the original vector indices that we want to store, and the values are - * the cache location of these vectors. The keys are hashed into a bucket - * whose size equals the associativity. These are the cache sets. If a cache - * set is full, then new indices are stored by replacing the oldest entries. - * - * Using this index mapping we implement methods to store and retrive data from - * the cache buffer, where a unit of data that we are storing is math_t[n_vec]. - * For example in SVM we store full columns of the kernel matrix at each cache - * entry. - * - * Note: we should have a look if the index management could be simplified using - * concurrent_unordered_map.cuh from cudf. See Issue #914. - * - * Example usage: - * @code{.cpp} - * - * // An expensive calculation that we want to accelerate with caching: - * // we have n keys, and for each key we generate a vector with m elements. - * // The keys and the output values are stored in GPU memory. 
- * void calc(int *key, int n, int m, float *out, cudaStream_t stream) { - * for (k=0; k cache(h.get_device_allocator(), stream, m); - * - * // A buffer that we will reuse to store the cache indices. - * rmm::device_uvector cache_idx(h.get_device_allocator(), stream, n); - * - * void cached_calc(int *key, int n, int m, float *out, stream) { - * int n_cached = 0; - * - * cache.GetCacheIdxPartitioned(key, n, cache_idx.data(), &n_cached, - * cudaStream_t stream); - * - * // Note: GetCacheIdxPartitioned has reordered the keys so that - * // key[0..n_cached-1] are the keys already in the cache. - * // We collect the corresponding values - * cache.GetVecs(cache_idx.data(), n_cached, out, stream); - * - * // Calculate the elements not in the cache - * int non_cached = n - n_cached; - * if (non_cached > 0) { - * int *key_new = key + n_cached; - * int *cache_idx_new = cache_idx.data() + n_cached; - * float *out_new = out + n_cached * m; - * // AssignCacheIdx can permute the keys, therefore it has to come before - * // we call calc. - * // Note: a call to AssignCacheIdx should always be preceded with - * // GetCacheIdxPartitioned, because that initializes the cache_idx_new array - * // with the cache set (hash bucket) that correspond to the keys. - * // The cache idx will be assigned from that cache set. - * cache.AssignCacheIdx(key_new, non_cached, cache_idx_new, stream); - * - * calc(key_new, non_cached, m, out_new, stream); - * - * // Store the calculated vectors into the cache. - * cache.StoreVecs(out_new, non_cached, non_cached, cache_idx_new, stream); - * } - * } - * @endcode - */ - template - class Cache { +/** + * @brief Associative cache with least recently used replacement policy. + * + * SW managed cache in device memory, for ML algos where we can trade memory + * access for computation. The two main functions of this class are the + * management of cache indices, and methods to retrieve/store data using the + * cache indices. + * + * The index management can be considered as a hash map, where the int + * keys are the original vector indices that we want to store, and the values are + * the cache location of these vectors. The keys are hashed into a bucket + * whose size equals the associativity. These are the cache sets. If a cache + * set is full, then new indices are stored by replacing the oldest entries. + * + * Using this index mapping we implement methods to store and retrive data from + * the cache buffer, where a unit of data that we are storing is math_t[n_vec]. + * For example in SVM we store full columns of the kernel matrix at each cache + * entry. + * + * Note: we should have a look if the index management could be simplified using + * concurrent_unordered_map.cuh from cudf. See Issue #914. + * + * Example usage: + * @code{.cpp} + * + * // An expensive calculation that we want to accelerate with caching: + * // we have n keys, and for each key we generate a vector with m elements. + * // The keys and the output values are stored in GPU memory. + * void calc(int *key, int n, int m, float *out, cudaStream_t stream) { + * for (k=0; k cache(h.get_device_allocator(), stream, m); + * + * // A buffer that we will reuse to store the cache indices. 
+ * rmm::device_uvector cache_idx(h.get_device_allocator(), stream, n); + * + * void cached_calc(int *key, int n, int m, float *out, stream) { + * int n_cached = 0; + * + * cache.GetCacheIdxPartitioned(key, n, cache_idx.data(), &n_cached, + * cudaStream_t stream); + * + * // Note: GetCacheIdxPartitioned has reordered the keys so that + * // key[0..n_cached-1] are the keys already in the cache. + * // We collect the corresponding values + * cache.GetVecs(cache_idx.data(), n_cached, out, stream); + * + * // Calculate the elements not in the cache + * int non_cached = n - n_cached; + * if (non_cached > 0) { + * int *key_new = key + n_cached; + * int *cache_idx_new = cache_idx.data() + n_cached; + * float *out_new = out + n_cached * m; + * // AssignCacheIdx can permute the keys, therefore it has to come before + * // we call calc. + * // Note: a call to AssignCacheIdx should always be preceded with + * // GetCacheIdxPartitioned, because that initializes the cache_idx_new array + * // with the cache set (hash bucket) that correspond to the keys. + * // The cache idx will be assigned from that cache set. + * cache.AssignCacheIdx(key_new, non_cached, cache_idx_new, stream); + * + * calc(key_new, non_cached, m, out_new, stream); + * + * // Store the calculated vectors into the cache. + * cache.StoreVecs(out_new, non_cached, non_cached, cache_idx_new, stream); + * } + * } + * @endcode + */ +template +class Cache { public: /** * @brief Construct a Cache object @@ -403,5 +403,4 @@ namespace raft::cache { } } }; -} -; // namespace raft::cache +}; // namespace raft::cache diff --git a/cpp/include/raft/util/fast_int_div.cuh b/cpp/include/raft/util/fast_int_div.cuh index a8fb3cc457..46e2159ed7 100644 --- a/cpp/include/raft/util/fast_int_div.cuh +++ b/cpp/include/raft/util/fast_int_div.cuh @@ -28,68 +28,68 @@ namespace raft::util { * @todo Extend support for signed divisors * @ref Hacker's Delight, Second Edition, Chapter 10 */ - struct FastIntDiv { - /** - * @defgroup HostMethods Ctor's that are accessible only from host - * @{ - * @brief Host-only ctor's - * @param _d the divisor - */ - FastIntDiv(int _d) : d(_d) { computeScalars(); } - FastIntDiv& operator=(int _d) - { - d = _d; - computeScalars(); - return *this; - } - /** @} */ +struct FastIntDiv { + /** + * @defgroup HostMethods Ctor's that are accessible only from host + * @{ + * @brief Host-only ctor's + * @param _d the divisor + */ + FastIntDiv(int _d) : d(_d) { computeScalars(); } + FastIntDiv& operator=(int _d) + { + d = _d; + computeScalars(); + return *this; + } + /** @} */ - /** - * @defgroup DeviceMethods Ctor's which even the device-side can access - * @{ - * @brief host and device ctor's - * @param other source object to be copied from - */ - HDI FastIntDiv(const FastIntDiv& other) : d(other.d), m(other.m), p(other.p) {} - HDI FastIntDiv& operator=(const FastIntDiv& other) - { - d = other.d; - m = other.m; - p = other.p; - return *this; - } - /** @} */ + /** + * @defgroup DeviceMethods Ctor's which even the device-side can access + * @{ + * @brief host and device ctor's + * @param other source object to be copied from + */ + HDI FastIntDiv(const FastIntDiv& other) : d(other.d), m(other.m), p(other.p) {} + HDI FastIntDiv& operator=(const FastIntDiv& other) + { + d = other.d; + m = other.m; + p = other.p; + return *this; + } + /** @} */ - /** divisor */ - int d; - /** the term 'm' as found in the reference chapter */ - unsigned m; - /** the term 'p' as found in the reference chapter */ - int p; + /** divisor */ + int d; + /** the term 'm' 
as found in the reference chapter */ + unsigned m; + /** the term 'p' as found in the reference chapter */ + int p; - private: - void computeScalars() - { - if (d == 1) { - m = 0; - p = 1; - return; - } else if (d < 0) { - ASSERT(false, "FastIntDiv: division by negative numbers not supported!"); - } else if (d == 0) { - ASSERT(false, "FastIntDiv: got division by zero!"); - } - int64_t nc = ((1LL << 31) / d) * d - 1; - p = 31; - int64_t twoP, rhs; - do { - ++p; - twoP = 1LL << p; - rhs = nc * (d - twoP % d); - } while (twoP <= rhs); - m = (twoP + d - twoP % d) / d; - } - }; // struct FastIntDiv + private: + void computeScalars() + { + if (d == 1) { + m = 0; + p = 1; + return; + } else if (d < 0) { + ASSERT(false, "FastIntDiv: division by negative numbers not supported!"); + } else if (d == 0) { + ASSERT(false, "FastIntDiv: got division by zero!"); + } + int64_t nc = ((1LL << 31) / d) * d - 1; + p = 31; + int64_t twoP, rhs; + do { + ++p; + twoP = 1LL << p; + rhs = nc * (d - twoP % d); + } while (twoP <= rhs); + m = (twoP + d - twoP % d) / d; + } +}; // struct FastIntDiv /** * @brief Division overload, so that FastIntDiv can be transparently switched @@ -98,13 +98,13 @@ namespace raft::util { * @param divisor the denominator * @return the quotient */ - HDI int operator/(int n, const FastIntDiv& divisor) - { - if (divisor.d == 1) return n; - int ret = (int64_t(divisor.m) * int64_t(n)) >> divisor.p; - if (n < 0) ++ret; - return ret; - } +HDI int operator/(int n, const FastIntDiv& divisor) +{ + if (divisor.d == 1) return n; + int ret = (int64_t(divisor.m) * int64_t(n)) >> divisor.p; + if (n < 0) ++ret; + return ret; +} /** * @brief Modulo overload, so that FastIntDiv can be transparently switched @@ -113,11 +113,11 @@ namespace raft::util { * @param divisor the denominator * @return the remainder */ - HDI int operator%(int n, const FastIntDiv& divisor) - { - int quotient = n / divisor; - int remainder = n - quotient * divisor.d; - return remainder; - } +HDI int operator%(int n, const FastIntDiv& divisor) +{ + int quotient = n / divisor; + int remainder = n - quotient * divisor.d; + return remainder; +} }; // namespace raft::util From 8d7ab1c00eba26889a000971f7f8856bd1868569 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Sun, 16 Oct 2022 16:17:36 -0400 Subject: [PATCH 17/35] Fixing doc --- cpp/include/raft/util/fast_int_div.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/include/raft/util/fast_int_div.cuh b/cpp/include/raft/util/fast_int_div.cuh index 46e2159ed7..a0cb8f0f53 100644 --- a/cpp/include/raft/util/fast_int_div.cuh +++ b/cpp/include/raft/util/fast_int_div.cuh @@ -23,10 +23,10 @@ namespace raft::util { /** * @brief Perform fast integer division and modulo using a known divisor + * From Hacker's Delight, Second Edition, Chapter 10 * * @note This currently only supports 32b signed integers * @todo Extend support for signed divisors - * @ref Hacker's Delight, Second Edition, Chapter 10 */ struct FastIntDiv { /** From 405f38793ac4d724a2260f1a16381fed46b9533e Mon Sep 17 00:00:00 2001 From: "Corey J. 
Nolet" Date: Mon, 17 Oct 2022 13:52:53 -0400 Subject: [PATCH 18/35] Pulling solver files over, starting to update qn solver public API --- .../raft/solver/coordinate_descent.cuh | 23 + cpp/include/raft/solver/detail/cd.cuh | 355 +++++ cpp/include/raft/solver/detail/lars.cuh | 1142 +++++++++++++++++ .../raft/solver/detail/learning_rate.h | 71 + .../raft/solver/detail/objectives/hinge.cuh | 191 +++ .../solver/detail/objectives/linearReg.cuh | 124 ++ .../raft/solver/detail/objectives/log.cuh | 30 + .../solver/detail/objectives/logisticReg.cuh | 153 +++ .../raft/solver/detail/objectives/penalty.cuh | 99 ++ .../raft/solver/detail/objectives/sigmoid.cuh | 32 + .../raft/solver/detail/objectives/sign.cuh | 49 + .../solver/detail/objectives/softThres.cuh | 42 + .../raft/solver/detail/qn/objectives/base.cuh | 239 ++++ .../solver/detail/qn/objectives/hinge.cuh | 156 +++ .../solver/detail/qn/objectives/linear.cuh | 79 ++ .../solver/detail/qn/objectives/logistic.cuh | 65 + .../detail/qn/objectives/regularizer.cuh | 94 ++ .../solver/detail/qn/objectives/softmax.cuh | 197 +++ .../raft/solver/detail/qn/qn_decision.cuh | 51 + .../raft/solver/detail/qn/qn_linesearch.cuh | 210 +++ .../raft/solver/detail/qn/qn_solvers.cuh | 469 +++++++ cpp/include/raft/solver/detail/qn/qn_util.cuh | 170 +++ .../raft/solver/detail/qn/simple_mat.cuh | 20 + .../raft/solver/detail/qn/simple_mat/base.hpp | 54 + .../solver/detail/qn/simple_mat/dense.hpp | 413 ++++++ .../solver/detail/qn/simple_mat/sparse.hpp | 216 ++++ cpp/include/raft/solver/detail/sgd.cuh | 422 ++++++ cpp/include/raft/solver/detail/shuffle.h | 39 + cpp/include/raft/solver/gradient_descent.cuh | 23 + .../raft/solver/least_angle_regression.cuh | 23 + cpp/include/raft/solver/quasi_newton.cuh | 79 ++ cpp/include/raft/solver/solver_types.hpp | 238 ++++ 32 files changed, 5568 insertions(+) create mode 100644 cpp/include/raft/solver/coordinate_descent.cuh create mode 100644 cpp/include/raft/solver/detail/cd.cuh create mode 100644 cpp/include/raft/solver/detail/lars.cuh create mode 100644 cpp/include/raft/solver/detail/learning_rate.h create mode 100644 cpp/include/raft/solver/detail/objectives/hinge.cuh create mode 100644 cpp/include/raft/solver/detail/objectives/linearReg.cuh create mode 100644 cpp/include/raft/solver/detail/objectives/log.cuh create mode 100644 cpp/include/raft/solver/detail/objectives/logisticReg.cuh create mode 100644 cpp/include/raft/solver/detail/objectives/penalty.cuh create mode 100644 cpp/include/raft/solver/detail/objectives/sigmoid.cuh create mode 100644 cpp/include/raft/solver/detail/objectives/sign.cuh create mode 100644 cpp/include/raft/solver/detail/objectives/softThres.cuh create mode 100644 cpp/include/raft/solver/detail/qn/objectives/base.cuh create mode 100644 cpp/include/raft/solver/detail/qn/objectives/hinge.cuh create mode 100644 cpp/include/raft/solver/detail/qn/objectives/linear.cuh create mode 100644 cpp/include/raft/solver/detail/qn/objectives/logistic.cuh create mode 100644 cpp/include/raft/solver/detail/qn/objectives/regularizer.cuh create mode 100644 cpp/include/raft/solver/detail/qn/objectives/softmax.cuh create mode 100644 cpp/include/raft/solver/detail/qn/qn_decision.cuh create mode 100644 cpp/include/raft/solver/detail/qn/qn_linesearch.cuh create mode 100644 cpp/include/raft/solver/detail/qn/qn_solvers.cuh create mode 100644 cpp/include/raft/solver/detail/qn/qn_util.cuh create mode 100644 cpp/include/raft/solver/detail/qn/simple_mat.cuh create mode 100644 cpp/include/raft/solver/detail/qn/simple_mat/base.hpp create mode 
100644 cpp/include/raft/solver/detail/qn/simple_mat/dense.hpp create mode 100644 cpp/include/raft/solver/detail/qn/simple_mat/sparse.hpp create mode 100644 cpp/include/raft/solver/detail/sgd.cuh create mode 100644 cpp/include/raft/solver/detail/shuffle.h create mode 100644 cpp/include/raft/solver/gradient_descent.cuh create mode 100644 cpp/include/raft/solver/least_angle_regression.cuh create mode 100644 cpp/include/raft/solver/quasi_newton.cuh create mode 100644 cpp/include/raft/solver/solver_types.hpp diff --git a/cpp/include/raft/solver/coordinate_descent.cuh b/cpp/include/raft/solver/coordinate_descent.cuh new file mode 100644 index 0000000000..39b255f524 --- /dev/null +++ b/cpp/include/raft/solver/coordinate_descent.cuh @@ -0,0 +1,23 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace raft::solver::coordinate_descent { + +} \ No newline at end of file diff --git a/cpp/include/raft/solver/detail/cd.cuh b/cpp/include/raft/solver/detail/cd.cuh new file mode 100644 index 0000000000..bd23f39850 --- /dev/null +++ b/cpp/include/raft/solver/detail/cd.cuh @@ -0,0 +1,355 @@ +/* + * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "shuffle.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace raft::solver::detail { + +namespace { + +/** Epoch and iteration -related state. */ +template +struct ConvState { + math_t coef; + math_t coefMax; + math_t diffMax; +}; + +/** + * Update a single CD coefficient and the corresponding convergence criteria. + * + * @param[inout] coefLoc pointer to the coefficient (arr ptr + column index offset) + * @param[in] squaredLoc pointer to the precomputed data - L2 norm of input for across rows + * @param[inout] convStateLoc pointer to the structure holding the convergence state + * @param[in] l1_alpha L1 regularization coef + */ +template +__global__ void __launch_bounds__(1, 1) cdUpdateCoefKernel(math_t* coefLoc, + const math_t* squaredLoc, + ConvState* convStateLoc, + const math_t l1_alpha) +{ + auto coef = *coefLoc; + auto r = coef > l1_alpha ? coef - l1_alpha : (coef < -l1_alpha ? coef + l1_alpha : 0); + auto squared = *squaredLoc; + r = squared > math_t(1e-5) ? 
r / squared : math_t(0); + auto diff = raft::myAbs(convStateLoc->coef - r); + if (convStateLoc->diffMax < diff) convStateLoc->diffMax = diff; + auto absv = raft::myAbs(r); + if (convStateLoc->coefMax < absv) convStateLoc->coefMax = absv; + convStateLoc->coef = -r; + *coefLoc = r; +} + +} // namespace + +/** + * Fits a linear, lasso, and elastic-net regression model using Coordinate Descent solver. + * + * i.e. finds coefficients that minimize the following loss function: + * + * f(coef) = 1/2 * || labels - input * coef ||^2 + * + 1/2 * alpha * (1 - l1_ratio) * ||coef||^2 + * + alpha * l1_ratio * ||coef||_1 + * + * + * @param handle + * Reference of raft::handle_t + * @param input + * pointer to an array in column-major format (size of n_rows, n_cols) + * @param n_rows + * n_samples or rows in input + * @param n_cols + * n_features or columns in X + * @param labels + * pointer to an array for labels (size of n_rows) + * @param coef + * pointer to an array for coefficients (size of n_cols). This will be filled with + * coefficients once the function is executed. + * @param intercept + * pointer to a scalar for intercept. This will be filled + * once the function is executed + * @param fit_intercept + * boolean parameter to control if the intercept will be fitted or not + * @param normalize + * boolean parameter to control if the data will be normalized or not; + * NB: the input is scaled by the column-wise biased sample standard deviation estimator. + * @param epochs + * Maximum number of iterations that solver will run + * @param loss + * enum to use different loss functions. Only linear regression loss functions is supported + * right now + * @param alpha + * L1 parameter + * @param l1_ratio + * ratio of alpha will be used for L1. (1 - l1_ratio) * alpha will be used for L2 + * @param shuffle + * boolean parameter to control whether coordinates will be picked randomly or not + * @param tol + * tolerance to stop the solver + * @param sample_weight + * device pointer to sample weight vector of length n_rows (nullptr or uniform weights) + * This vector is modified during the computation + */ +template +void cdFit(const raft::handle_t& handle, + math_t* input, + int n_rows, + int n_cols, + math_t* labels, + math_t* coef, + math_t* intercept, + bool fit_intercept, + bool normalize, + int epochs, + ML::loss_funct loss, + math_t alpha, + math_t l1_ratio, + bool shuffle, + math_t tol, + math_t* sample_weight = nullptr) +{ + raft::common::nvtx::range fun_scope("ML::Solver::cdFit-%d-%d", n_rows, n_cols); + ASSERT(n_cols > 0, "Parameter n_cols: number of columns cannot be less than one"); + ASSERT(n_rows > 1, "Parameter n_rows: number of rows cannot be less than two"); + ASSERT(loss == ML::loss_funct::SQRD_LOSS, + "Parameter loss: Only SQRT_LOSS function is supported for now"); + + cudaStream_t stream = handle.get_stream(); + rmm::device_uvector residual(n_rows, stream); + rmm::device_uvector squared(n_cols, stream); + rmm::device_uvector mu_input(0, stream); + rmm::device_uvector mu_labels(0, stream); + rmm::device_uvector norm2_input(0, stream); + math_t h_sum_sw = 0; + + if (sample_weight != nullptr) { + rmm::device_scalar sum_sw(stream); + raft::stats::sum(sum_sw.data(), sample_weight, 1, n_rows, true, stream); + raft::update_host(&h_sum_sw, sum_sw.data(), 1, stream); + + raft::linalg::multiplyScalar( + sample_weight, sample_weight, (math_t)n_rows / h_sum_sw, n_rows, stream); + } + + if (fit_intercept) { + mu_input.resize(n_cols, stream); + mu_labels.resize(1, stream); + if (normalize) { 
norm2_input.resize(n_cols, stream); } + + GLM::preProcessData(handle, + input, + n_rows, + n_cols, + labels, + intercept, + mu_input.data(), + mu_labels.data(), + norm2_input.data(), + fit_intercept, + normalize, + sample_weight); + } + if (sample_weight != nullptr) { + raft::linalg::sqrt(sample_weight, sample_weight, n_rows, stream); + raft::matrix::matrixVectorBinaryMult( + input, sample_weight, n_rows, n_cols, false, false, stream); + raft::linalg::map_k( + labels, + n_rows, + [] __device__(math_t a, math_t b) { return a * b; }, + stream, + labels, + sample_weight); + } + + std::vector ri(n_cols); + std::mt19937 g(rand()); + initShuffle(ri, g); + + math_t l2_alpha = (1 - l1_ratio) * alpha * n_rows; + math_t l1_alpha = l1_ratio * alpha * n_rows; + + // Precompute the residual + if (normalize) { + // if we normalized the data, we know sample variance for each column is 1, + // thus no need to compute the norm again. + math_t scalar = math_t(n_rows) + l2_alpha; + raft::matrix::setValue(squared.data(), squared.data(), scalar, n_cols, stream); + } else { + raft::linalg::colNorm( + squared.data(), input, n_cols, n_rows, raft::linalg::L2Norm, false, stream); + raft::linalg::addScalar(squared.data(), squared.data(), l2_alpha, n_cols, stream); + } + + raft::copy(residual.data(), labels, n_rows, stream); + + ConvState h_convState; + rmm::device_uvector> convStateBuf(1, stream); + auto convStateLoc = convStateBuf.data(); + + rmm::device_scalar cublas_alpha(1.0, stream); + rmm::device_scalar cublas_beta(0.0, stream); + + for (int i = 0; i < epochs; i++) { + raft::common::nvtx::range epoch_scope("ML::Solver::cdFit::epoch-%d", i); + if (i > 0 && shuffle) { Solver::shuffle(ri, g); } + + RAFT_CUDA_TRY(cudaMemsetAsync(convStateLoc, 0, sizeof(ConvState), stream)); + + for (int j = 0; j < n_cols; j++) { + raft::common::nvtx::range iter_scope("ML::Solver::cdFit::col-%d", j); + int ci = ri[j]; + math_t* coef_loc = coef + ci; + math_t* squared_loc = squared.data() + ci; + math_t* input_col_loc = input + (ci * n_rows); + + // remember current coef + raft::copy(&(convStateLoc->coef), coef_loc, 1, stream); + // calculate the residual without the contribution from column ci + // residual[:] += coef[ci] * X[:, ci] + raft::linalg::axpy( + handle, n_rows, coef_loc, input_col_loc, 1, residual.data(), 1, stream); + + // coef[ci] = dot(X[:, ci], residual[:]) + raft::linalg::gemv(handle, + false, + 1, + n_rows, + cublas_alpha.data(), + input_col_loc, + 1, + residual.data(), + 1, + cublas_beta.data(), + coef_loc, + 1, + stream); + + // Calculate the new coefficient that minimizes f along coordinate line ci + // coef[ci] = SoftTreshold(dot(X[:, ci], residual[:]), l1_alpha) / dot(X[:, ci], X[:, ci])) + // Also, update the convergence criteria. 
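+      // Here SoftThreshold(c, a) = c - a if c > a, c + a if c < -a, and 0 otherwise,
+      // exactly as implemented in cdUpdateCoefKernel above; squared_loc already folds
+      // the L2 term into the denominator (column norm + l2_alpha from the precompute step).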
+ cdUpdateCoefKernel<<>>( + coef_loc, squared_loc, convStateLoc, l1_alpha); + RAFT_CUDA_TRY(cudaGetLastError()); + + // Restore the residual using the updated coeffecient + raft::linalg::axpy( + handle, n_rows, &(convStateLoc->coef), input_col_loc, 1, residual.data(), 1, stream); + } + raft::update_host(&h_convState, convStateLoc, 1, stream); + handle.sync_stream(stream); + + if (h_convState.coefMax < tol || (h_convState.diffMax / h_convState.coefMax) < tol) break; + } + + if (sample_weight != nullptr) { + raft::matrix::matrixVectorBinaryDivSkipZero( + input, sample_weight, n_rows, n_cols, false, false, stream); + raft::linalg::map_k( + labels, + n_rows, + [] __device__(math_t a, math_t b) { return a / b; }, + stream, + labels, + sample_weight); + raft::linalg::powerScalar(sample_weight, sample_weight, (math_t)2, n_rows, stream); + raft::linalg::multiplyScalar(sample_weight, sample_weight, h_sum_sw / n_rows, n_rows, stream); + } + + if (fit_intercept) { + GLM::postProcessData(handle, + input, + n_rows, + n_cols, + labels, + coef, + intercept, + mu_input.data(), + mu_labels.data(), + norm2_input.data(), + fit_intercept, + normalize); + + } else { + *intercept = math_t(0); + } +} + +/** + * Fits a linear, lasso, and elastic-net regression model using Coordinate Descent solver + * @param handle + * cuml handle + * @param input + * pointer to an array in column-major format (size of n_rows, n_cols) + * @param n_rows + * n_samples or rows in input + * @param n_cols + * n_features or columns in X + * @param coef + * pointer to an array for coefficients (size of n_cols). Calculated in cdFit function. + * @param intercept + * intercept value calculated in cdFit function + * @param preds + * pointer to an array for predictions (size of n_rows). This will be fitted once functions + * is executed. + * @param loss + * enum to use different loss functions. Only linear regression loss functions is supported + * right now. + */ +template +void cdPredict(const raft::handle_t& handle, + const math_t* input, + int n_rows, + int n_cols, + const math_t* coef, + math_t intercept, + math_t* preds, + ML::loss_funct loss) +{ + ASSERT(n_cols > 0, "Parameter n_cols: number of columns cannot be less than one"); + ASSERT(n_rows > 1, "Parameter n_rows: number of rows cannot be less than two"); + ASSERT(loss == ML::loss_funct::SQRD_LOSS, + "Parameter loss: Only SQRT_LOSS function is supported for now"); + + Functions::linearRegH(handle, input, n_rows, n_cols, coef, preds, intercept, handle.get_stream()); +} + +}; // namespace raft::solver::detail diff --git a/cpp/include/raft/solver/detail/lars.cuh b/cpp/include/raft/solver/detail/lars.cuh new file mode 100644 index 0000000000..1c2bd04285 --- /dev/null +++ b/cpp/include/raft/solver/detail/lars.cuh @@ -0,0 +1,1142 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +#include +#include +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace raft::solver::detail { + + +/** + * @brief Select the largest element from the inactive working set. + * + * The inactive set consist of cor[n_active..n-1]. This function returns the + * index of the most correlated element. The value of the largest element is + * returned in cj. + * + * The correlation value is checked for numeric error and convergence, and the + * return status indicates whether training should continue. + * + * @param n_active number of active elements (n_active <= n ) + * @param n number of elements in vector cor + * @param correlation device array of correlations, size [n] + * @param cj host pointer to return the value of the largest element + * @param wokspace buffer, size >= n_cols + * @param max_idx host pointer the index of the max correlation is returned here + * @param indices host pointer of feature column indices, size [n_cols] + * @param n_iter iteration counter + * @param stream CUDA stream + * + * @return fit status + */ +template +LarsFitStatus selectMostCorrelated(idx_t n_active, + idx_t n, + math_t* correlation, + math_t* cj, + rmm::device_uvector& workspace, + idx_t* max_idx, + idx_t n_rows, + idx_t* indices, + idx_t n_iter, + cudaStream_t stream) +{ + const idx_t align_bytes = 16 * sizeof(math_t); + // We might need to start a few elements earlier to ensure that the unary + // op has aligned access for vectorized load. + int start = raft::alignDown(n_active, align_bytes) / sizeof(math_t); + raft::linalg::unaryOp( + workspace.data(), correlation + start, n, [] __device__(math_t a) { return abs(a); }, stream); + thrust::device_ptr ptr(workspace.data() + n_active - start); + auto max_ptr = thrust::max_element(thrust::cuda::par.on(stream), ptr, ptr + n - n_active); + raft::update_host(cj, max_ptr.get(), 1, stream); + raft::interruptible::synchronize(stream); + + *max_idx = n_active + (max_ptr - ptr); // the index of the maximum element + + RAFT_LOG_DEBUG( + "Iteration %d, selected feature %d with correlation %f", n_iter, indices[*max_idx], *cj); + + if (!std::isfinite(*cj)) { + RAFT_LOG_ERROR("Correlation is not finite, aborting."); + return LarsFitStatus::kError; + } + + // Tolerance for early stopping. Note we intentionally use here fp32 epsilon, + // otherwise the tolerance is too small (which could result in numeric error + // in Cholesky rank one update if eps < 0, or exploding regression parameters + // if eps > 0). + const math_t tolerance = std::numeric_limits::epsilon(); + if (abs(*cj) / n_rows < tolerance) { + RAFT_LOG_WARN("Reached tolarence limit with %e", abs(*cj)); + return LarsFitStatus::kStop; + } + return LarsFitStatus::kOk; +} + +/** + * @brief Swap two feature vectors. + * + * The function swaps feature column j and k or the corresponding rows and + * and columns of the Gram matrix. The elements of the cor and indices arrays + * are also swapped. 
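+ *
+ * For example, swapFeatures(handle, 1, 4, ...) exchanges row/column 1 and 4 of G
+ * when G is given (otherwise the corresponding columns of X), together with
+ * cor[1] <-> cor[4] and indices[1] <-> indices[4].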
+ * + * @param handle cuBLAS handle + * @param j column index + * @param k column index + * @param X device array of feature vectors in column major format, size + * [n_cols * ld_X] + * @param n_rows number of training vectors + * @param n_cols number of features + * @param ld_X leading dimension of X + * @param cor device array of correlations, size [n_cols] + * @param indices host array of indices, size [n_cols] + * @param G device pointer of Gram matrix (or nullptr), size [n_cols * ld_G] + * @param ld_G leading dimension of G + * @param stream CUDA stream + */ +template +void swapFeatures(cublasHandle_t handle, + idx_t j, + idx_t k, + math_t* X, + idx_t n_rows, + idx_t n_cols, + idx_t ld_X, + math_t* cor, + idx_t* indices, + math_t* G, + idx_t ld_G, + cudaStream_t stream) +{ + std::swap(indices[j], indices[k]); + if (G) { + // #TODO: Call from public API when ready + RAFT_CUBLAS_TRY( + raft::linalg::detail::cublasSwap(handle, n_cols, G + ld_G * j, 1, G + ld_G * k, 1, stream)); + // #TODO: Call from public API when ready + RAFT_CUBLAS_TRY( + raft::linalg::detail::cublasSwap(handle, n_cols, G + j, ld_G, G + k, ld_G, stream)); + } else { + // Only swap X if G is nullptr. Only in that case will we use the feature + // columns, otherwise all the necessary information is already there in G. + // #TODO: Call from public API when ready + RAFT_CUBLAS_TRY( + raft::linalg::detail::cublasSwap(handle, n_rows, X + ld_X * j, 1, X + ld_X * k, 1, stream)); + } + // swap (c[j], c[k]) + // #TODO: Call from public API when ready + RAFT_CUBLAS_TRY(raft::linalg::detail::cublasSwap(handle, 1, cor + j, 1, cor + k, 1, stream)); +} + +/** + * @brief Move feature at idx=j into the active set. + * + * We have an active set with n_active elements, and an inactive set with + * n_valid_cols - n_active elements. The matrix X [n_samples, n_features] is + * partitioned in a way that the first n_active columns store the active set. + * Similarily the vectors correlation and indices are partitioned in a way + * that the first n_active elements belong to the active set: + * - active set: X[:,:n_active], correlation[:n_active], indices[:n_active] + * - inactive set: X[:,n_active:], correlation[n_active:], indices[n_active:]. + * + * This function moves the feature column X[:,idx] into the active set by + * replacing the first inactive element with idx. The indices and correlation + * vectors are modified accordinly. The sign array is updated with the sign + * of correlation[n_active]. 
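+ *
+ * For example, with n_active == 2 and j == 5, feature 5 is swapped into
+ * position 2 (the X or G columns, cor and indices all follow the swap),
+ * sign[2] is set to the sign of cor[2], and n_active becomes 3.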
+ * + * @param handle cuBLAS handle + * @param n_active number of active elements, will be increased by one after + * we move the new element j into the active set + * @param j index of the new element (n_active <= j < n_cols) + * @param X device array of feature vectors in column major format, size + * [n_cols * ld_X] + * @param n_rows number of training vectors + * @param n_cols number of valid features colums (ignoring those features which + * are detected to be collinear with the active set) + * @param ld_X leading dimension of X + * @param cor device array of correlations, size [n_cols] + * @param indices host array of indices, size [n_cols] + * @param G device pointer of Gram matrix (or nullptr), size [n_cols * ld_G] + * @param ld_G leading dimension of G + * @param sign device pointer to sign array, size[n] + * @param stream CUDA stream + */ +template +void moveToActive(cublasHandle_t handle, + idx_t* n_active, + idx_t j, + math_t* X, + idx_t n_rows, + idx_t n_cols, + idx_t ld_X, + math_t* cor, + idx_t* indices, + math_t* G, + idx_t ld_G, + math_t* sign, + cudaStream_t stream) +{ + idx_t idx_free = *n_active; + swapFeatures(handle, idx_free, j, X, n_rows, n_cols, ld_X, cor, indices, G, ld_G, stream); + + // sign[n_active] = sign(c[n_active]) + raft::linalg::unaryOp( + sign + idx_free, + cor + idx_free, + 1, + [] __device__(math_t c) -> math_t { + // return the sign of c + return (math_t(0) < c) - (c < math_t(0)); + }, + stream); + + (*n_active)++; +} + +/** + * @brief Update the Cholesky decomposition of the Gram matrix of the active set + * + * G0 = X.T * X, Gram matrix without signs. We use the part that corresponds to + * the active set, [n_A x n_A] + * + * At each step on the LARS path we add one column to the active set, therefore + * the Gram matrix grows incrementally. We update the Cholesky decomposition + * G0 = U.T * U. + * + * The Cholesky decomposition can use the same storage as G0, if the input + * pointers are same. + * + * @param handle RAFT handle + * @param n_active number of active elements + * @param X device array of feature vectors in column major format, size + * [n_rows * n_cols] + * @param n_rows number of training vectors + * @param n_cols number of features + * @param ld_X leading dimension of X (stride of columns) + * @param U device pointer to the Cholesky decomposition of G0, + * size [n_cols * ld_U] + * @param ld_U leading dimension of U + * @param G0 device pointer to Gram matrix G0 = X.T*X (can be nullptr), + * size [n_cols * ld_G]. + * @param ld_G leading dimension of G + * @param workspace workspace for the Cholesky update + * @param eps parameter for cheleskyRankOneUpdate + * @param stream CUDA stream + */ +template +void updateCholesky(const raft::handle_t& handle, + idx_t n_active, + const math_t* X, + idx_t n_rows, + idx_t n_cols, + idx_t ld_X, + math_t* U, + idx_t ld_U, + const math_t* G0, + idx_t ld_G, + rmm::device_uvector& workspace, + math_t eps, + cudaStream_t stream) +{ + const cublasFillMode_t fillmode = CUBLAS_FILL_MODE_UPPER; + if (G0 == nullptr) { + // Calculate the new column of G0. It is stored in U. 
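+    // A single GEMV produces G0[:, n_active-1] = X^T * X[:, n_active-1]; it is written
+    // directly into column n_active-1 of U and factorized in place further below.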
+ math_t* G_row = U + (n_active - 1) * ld_U; + const math_t* X_row = X + (n_active - 1) * ld_X; + math_t one = 1; + math_t zero = 0; + // #TODO: Call from public API when ready + RAFT_CUBLAS_TRY(raft::linalg::detail::cublasgemv(handle.get_cublas_handle(), + CUBLAS_OP_T, + n_rows, + n_cols, + &one, + X, + n_rows, + X_row, + 1, + &zero, + G_row, + 1, + stream)); + } else if (G0 != U) { + // Copy the new column of G0 into U, because the factorization works in + // place. + raft::copy(U + (n_active - 1) * ld_U, G0 + (n_active - 1) * ld_G, n_active, stream); + } // Otherwise the new data is already in place in U. + + // Update the Cholesky decomposition + int n_work = workspace.size(); + if (n_work == 0) { + // Query workspace size and allocate it + raft::linalg::choleskyRank1Update( + handle, U, n_active, ld_U, nullptr, &n_work, fillmode, stream); + workspace.resize(n_work, stream); + } + raft::linalg::choleskyRank1Update( + handle, U, n_active, ld_U, workspace.data(), &n_work, fillmode, stream, eps); +} + +/** + * @brief Solve for ws = S * GA^(-1) * 1_A using a Cholesky decomposition. + * + * See calcEquiangularVec for more details on the formulas. In this function we + * calculate ws = S * (S * G0 * S)^{-1} 1_A = G0^{-1} (S 1_A) = G0^{-1} sign_A. + * + * @param handle RAFT handle + * @param n_active number of active elements + * @param n_cols number of features + * @param sign array with sign of the active set, size [n_cols] + * @param U device pointer to the Cholesky decomposition of G0, + * size [n_cols * n_cols] + * @param ld_U leading dimension of U (column stride) + * @param ws device pointer, size [n_active] + * @param stream CUDA stream + */ +template +void calcW0(const raft::handle_t& handle, + idx_t n_active, + idx_t n_cols, + const math_t* sign, + const math_t* U, + idx_t ld_U, + math_t* ws, + cudaStream_t stream) +{ + const cublasFillMode_t fillmode = CUBLAS_FILL_MODE_UPPER; + + // First we calculate x by solving equation U.T x = sign_A. + raft::copy(ws, sign, n_active, stream); + math_t alpha = 1; + // #TODO: Call from public API when ready + RAFT_CUBLAS_TRY(raft::linalg::detail::cublastrsm(handle.get_cublas_handle(), + CUBLAS_SIDE_LEFT, + fillmode, + CUBLAS_OP_T, + CUBLAS_DIAG_NON_UNIT, + n_active, + 1, + &alpha, + U, + ld_U, + ws, + ld_U, + stream)); + + // ws stores x, the solution of U.T x = sign_A. Now we solve U * ws = x + // #TODO: Call from public API when ready + RAFT_CUBLAS_TRY(raft::linalg::detail::cublastrsm(handle.get_cublas_handle(), + CUBLAS_SIDE_LEFT, + fillmode, + CUBLAS_OP_N, + CUBLAS_DIAG_NON_UNIT, + n_active, + 1, + &alpha, + U, + ld_U, + ws, + ld_U, + stream)); + // Now ws = G0^(-1) sign_A = S GA^{-1} 1_A. +} + +/** + * @brief Calculate A = (1_A * GA^{-1} * 1_A)^{-1/2}. + * + * See calcEquiangularVec for more details on the formulas. 
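+ * In that notation, and assuming ws holds the (unscaled) vector computed by
+ * calcW0, this boils down to A = 1 / sqrt(sum_j ws_j * sign_j), a restatement
+ * of eq (2.5) in the reference cited there.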
+ * + * @param handle RAFT handle + * @param A device pointer to store the result + * @param n_active number of active elements + * @param sign array with sign of the active set, size [n_cols] + * @param ws device pointer, size [n_active] + * @param stream CUDA stream + */ +template +void calcA(const raft::handle_t& handle, + math_t* A, + idx_t n_active, + const math_t* sign, + const math_t* ws, + cudaStream_t stream) +{ + // Calculate sum (w) = sum(ws * sign) + auto multiply = [] __device__(math_t w, math_t s) { return w * s; }; + raft::linalg::mapThenSumReduce(A, n_active, multiply, stream, ws, sign); + // Calc Aa = 1 / sqrt(sum(w)) + raft::linalg::unaryOp( + A, A, 1, [] __device__(math_t a) { return 1 / sqrt(a); }, stream); +} + +/** + * @brief Calculate the equiangular vector u, w and A according to [1]. + * + * We introduce the following variables (Python like indexing): + * - n_A number of elements in the active set + * - S = diag(sign_A): diagonal matrix with the signs, size [n_A x n_A] + * - X_A = X[:,:n_A] * S, column vectors of the active set size [n_A x n_A] + * - G0 = X.T * X, Gram matrix without signs. We just use the part that + * corresponds to the active set, [n_A x n_A] + * - GA = X_A.T * X_A is the Gram matrix of the active set, size [n_A x n_A] + * GA = S * G0[:n_A, :n_A] * S + * - 1_A = np.ones(n_A) + * - A = (1_A * GA^{-1} * 1_A)^{-1/2}, scalar, see eq (2.5) in [1] + * - w = A GA^{-1} * 1_A, vector of size [n_A] see eq (2.6) in [1] + * - ws = S * w, vector of size [n_A] + * + * The equiangular vector can be expressed the following way (equation 2.6): + * u = X_A * w = X[:,:n_A] S * w = X[:,:n_A] * ws. + * + * The equiangular vector later appears only in an expression like X.T u, which + * can be reformulated as X.T u = X.T X[:,:n_A] S * w = G[:n_A,:n_A] * ws. + * If the gram matrix is given, then we do not need to calculate u, it will be + * sufficient to calculate ws and A. + * + * We use Cholesky decomposition G0 = U.T * U to solve to calculate A and w + * which depend on GA^{-1}. + * + * References: + * [1] B. Efron, T. Hastie, I. Johnstone, R Tibshirani, Least Angle Regression + * The Annals of Statistics (2004) Vol 32, No 2, 407-499 + * http://statweb.stanford.edu/~tibs/ftp/lars.pdf + * + * @param handle RAFT handle + * @param n_active number of active elements + * @param X device array of feature vectors in column major format, size + * [ld_X * n_cols] + * @param n_rows number of training vectors + * @param n_cols number of features + * @param ld_X leading dimension of array X (column stride, ld_X >= n_rows) + * @param sign array with sign of the active set, size [n_cols] + * @param U device pointer to the Cholesky decomposition of G0, + * size [ld_U * n_cols] + * @param ld_U leading dimension of array U (ld_U >= n_cols) + * @param G0 device pointer to Gram matrix G0 = X.T*X (can be nullptr), + * size [ld_G * n_cols]. Note the difference between G0 and + * GA = X_A.T * X_A + * @param ld_G leading dimension of array G0 (ld_G >= n_cols) + * @param workspace workspace for the Cholesky update + * @param ws device pointer, size [n_active] + * @param A device pointer to a scalar + * @param u_eq device pointer to the equiangular vector, only used if + * Gram==nullptr, size [n_rows]. 
+ * @param eps numerical regularization parameter for the Cholesky decomposition
+ * @param stream CUDA stream
+ *
+ * @return fit status
+ */
+template <typename math_t, typename idx_t>
+LarsFitStatus calcEquiangularVec(const raft::handle_t& handle,
+                                 idx_t n_active,
+                                 math_t* X,
+                                 idx_t n_rows,
+                                 idx_t n_cols,
+                                 idx_t ld_X,
+                                 math_t* sign,
+                                 math_t* U,
+                                 idx_t ld_U,
+                                 math_t* G0,
+                                 idx_t ld_G,
+                                 rmm::device_uvector<math_t>& workspace,
+                                 math_t* ws,
+                                 math_t* A,
+                                 math_t* u_eq,
+                                 math_t eps,
+                                 cudaStream_t stream)
+{
+  // Since we added a new vector to the active set, we update the Cholesky
+  // decomposition (U)
+  updateCholesky(
+    handle, n_active, X, n_rows, n_cols, ld_X, U, ld_U, G0, ld_G, workspace, eps, stream);
+
+  // Calculate ws = S GA^{-1} 1_A using U
+  calcW0(handle, n_active, n_cols, sign, U, ld_U, ws, stream);
+
+  calcA(handle, A, n_active, sign, ws, stream);
+
+  // ws *= Aa
+  raft::linalg::unaryOp(
+    ws, ws, n_active, [A] __device__(math_t w) { return (*A) * w; }, stream);
+
+  // Check for numeric error
+  math_t ws_host;
+  raft::update_host(&ws_host, ws, 1, stream);
+  math_t diag_host;  // U[n_active-1, n_active-1]
+  raft::update_host(&diag_host, U + ld_U * (n_active - 1) + n_active - 1, 1, stream);
+  handle.sync_stream(stream);
+  if (diag_host < 1e-7) {
+    RAFT_LOG_WARN(
+      "Vanishing diagonal in Cholesky factorization (%e). This indicates "
+      "collinear features. Dropping current regressor.",
+      diag_host);
+    return LarsFitStatus::kCollinear;
+  }
+  if (!std::isfinite(ws_host)) {
+    RAFT_LOG_WARN("ws=%f is not finite at iteration %d", ws_host, n_active);
+    return LarsFitStatus::kError;
+  }
+
+  if (G0 == nullptr) {
+    // Calculate u_eq only if the Gram matrix is not stored.
+    math_t one  = 1;
+    math_t zero = 0;
+    // #TODO: Call from public API when ready
+    RAFT_CUBLAS_TRY(raft::linalg::detail::cublasgemv(handle.get_cublas_handle(),
+                                                     CUBLAS_OP_N,
+                                                     n_rows,
+                                                     n_active,
+                                                     &one,
+                                                     X,
+                                                     ld_X,
+                                                     ws,
+                                                     1,
+                                                     &zero,
+                                                     u_eq,
+                                                     1,
+                                                     stream));
+  }
+  return LarsFitStatus::kOk;
+}
+
+/**
+ * @brief Calculate the maximum step size (gamma) in the equiangular direction.
+ *
+ * Let mu = X beta.T be the current prediction vector. The modified solution
+ * after taking step gamma is defined as mu' = mu + gamma u. With this
+ * solution the correlation of the covariates in the active set will decrease
+ * equally, to a new value |c_j(gamma)| = Cmax - gamma A. At the same time
+ * the correlation of the values in the inactive set changes according to the
+ * following formula: c_j(gamma) = c_j - gamma a_j. We increase gamma until
+ * one of the correlations from the inactive set becomes equal to the
+ * correlation of the active set.
+ *
+ * References:
+ * [1] B. Efron, T. Hastie, I. Johnstone, R Tibshirani, Least Angle Regression
+ * The Annals of Statistics (2004) Vol 32, No 2, 407-499
+ * http://statweb.stanford.edu/~tibs/ftp/lars.pdf
+ *
+ * @param handle RAFT handle
+ * @param max_iter maximum number of iterations
+ * @param n_rows number of samples
+ * @param n_cols number of valid feature columns
+ * @param n_active size of the active set (n_active <= max_iter <= n_cols)
+ * @param cj value of the maximum correlation
+ * @param A device pointer to a scalar, as defined by eq 2.5 in [1]
+ * @param cor device pointer to correlation vector, size [n_active]
+ * @param G device pointer to Gram matrix of the active set (without signs)
+ *     size [n_active * ld_G]
+ * @param ld_G leading dimension of G (ld_G >= n_cols)
+ * @param X device array of training vectors in column major format,
+ *     size [n_rows * n_cols].
Only used if the gram matrix is not avaiable. + * @param ld_X leading dimension of X (ld_X >= n_rows) + * @param u device pointer to equiangular vector size [n_rows]. Only used if the + * Gram matrix G is not available. + * @param ws device pointer to the ws vector defined in calcEquiangularVec, + * size [n_active] + * @param gamma device pointer to a scalar. The max step size is returned here. + * @param a_vec device pointer, size [n_cols] + * @param stream CUDA stream + */ +template +void calcMaxStep(const raft::handle_t& handle, + idx_t max_iter, + idx_t n_rows, + idx_t n_cols, + idx_t n_active, + math_t cj, + const math_t* A, + math_t* cor, + const math_t* G, + idx_t ld_G, + const math_t* X, + idx_t ld_X, + const math_t* u, + const math_t* ws, + math_t* gamma, + math_t* a_vec, + cudaStream_t stream) +{ + // In the active set each element has the same correlation, whose absolute + // value is given by Cmax. + math_t Cmax = std::abs(cj); + if (n_active == n_cols) { + // Last iteration, the inactive set is empty we use equation (2.21) + raft::linalg::unaryOp( + gamma, A, 1, [Cmax] __device__(math_t A) { return Cmax / A; }, stream); + } else { + const int n_inactive = n_cols - n_active; + if (G == nullptr) { + // Calculate a = X.T[:,n_active:] * u (2.11) + math_t one = 1; + math_t zero = 0; + // #TODO: Call from public API when ready + RAFT_CUBLAS_TRY(raft::linalg::detail::cublasgemv(handle.get_cublas_handle(), + CUBLAS_OP_T, + n_rows, + n_inactive, + &one, + X + n_active * ld_X, + ld_X, + u, + 1, + &zero, + a_vec, + 1, + stream)); + } else { + // Calculate a = X.T[:,n_A:] * u = X.T[:, n_A:] * X[:,:n_A] * ws + // = G[n_A:,:n_A] * ws (2.11) + math_t one = 1; + math_t zero = 0; + // #TODO: Call from public API when ready + RAFT_CUBLAS_TRY(raft::linalg::detail::cublasgemv(handle.get_cublas_handle(), + CUBLAS_OP_N, + n_inactive, + n_active, + &one, + G + n_active, + ld_G, + ws, + 1, + &zero, + a_vec, + 1, + stream)); + } + const math_t tiny = std::numeric_limits::min(); + const math_t huge = std::numeric_limits::max(); + // + // gamma = min^+_{j \in inactive} {(Cmax - cor_j) / (A-a_j), + // (Cmax + cor_j) / (A+a_j)} (2.13) + auto map = [Cmax, A, tiny, huge] __device__(math_t c, math_t a) -> math_t { + math_t tmp1 = (Cmax - c) / (*A - a + tiny); + math_t tmp2 = (Cmax + c) / (*A + a + tiny); + // We consider only positive elements while we search for the minimum + math_t val = (tmp1 > 0) ? tmp1 : huge; + if (tmp2 > 0 && tmp2 < val) val = tmp2; + return val; + }; + raft::linalg::mapThenReduce( + gamma, n_inactive, huge, map, cub::Min(), stream, cor + n_active, a_vec); + } +} + +/** + * @brief Initialize for Lars training. + * + * We calculate the initial correlation, initialize the indices array, and set + * up pointers to store the Cholesky factorization. + * + * @param handle RAFT handle + * @param X device array of training vectors in column major format, + * size [ld_X * n_cols]. + * @param n_rows number of samples + * @param n_cols number of valid feature columns + * @param ld_X leading dimension of X (ld_X >= n_rows) + * @param y device pointer to regression targets, size [n_rows] + * @param Gram device pointer to Gram matrix (X.T * X), size [n_cols * ld_G], + * can be nullptr + * @param ld_G leading dimension of G (ld_G >= n_cols) + * @param U_buffer device buffer that will be initialized to store the Cholesky + * factorization. Only used if Gram is nullptr. 
+ * @param U device pointer to U
+ * @param ld_U leading dimension of U
+ * @param indices host buffer to store feature column indices
+ * @param cor device pointer to correlation vector, size [n_cols]
+ * @param max_iter host pointer to the maximum number of iterations
+ * @param coef_path device pointer to store coefficients along the
+ *     regularization path, size [(max_iter + 1) * max_iter], can be nullptr
+ * @param stream CUDA stream
+ */
+template <typename math_t, typename idx_t>
+void larsInit(const raft::handle_t& handle,
+              const math_t* X,
+              idx_t n_rows,
+              idx_t n_cols,
+              idx_t ld_X,
+              const math_t* y,
+              math_t* Gram,
+              idx_t ld_G,
+              rmm::device_uvector<math_t>& U_buffer,
+              math_t** U,
+              idx_t* ld_U,
+              std::vector<idx_t>& indices,
+              rmm::device_uvector<math_t>& cor,
+              int* max_iter,
+              math_t* coef_path,
+              cudaStream_t stream)
+{
+  if (n_cols < *max_iter) { *max_iter = n_cols; }
+  if (Gram == nullptr) {
+    const idx_t align_bytes = 256;
+    *ld_U = raft::alignTo<idx_t>(*max_iter, align_bytes);
+    try {
+      U_buffer.resize((*ld_U) * (*max_iter), stream);
+    } catch (std::bad_alloc const&) {
+      THROW(
+        "Not enough GPU memory! The memory usage depends quadratically on the "
+        "n_nonzero_coefs parameter, try to decrease it.");
+    }
+    *U = U_buffer.data();
+  } else {
+    // Set U as G. During the solution in larsFit, the Cholesky factorization
+    // U will overwrite G.
+    *U    = Gram;
+    *ld_U = ld_G;
+  }
+  std::iota(indices.data(), indices.data() + n_cols, 0);
+
+  math_t one  = 1;
+  math_t zero = 0;
+  // Set initial correlation to X.T * y
+  // #TODO: Call from public API when ready
+  RAFT_CUBLAS_TRY(raft::linalg::detail::cublasgemv(handle.get_cublas_handle(),
+                                                   CUBLAS_OP_T,
+                                                   n_rows,
+                                                   n_cols,
+                                                   &one,
+                                                   X,
+                                                   ld_X,
+                                                   y,
+                                                   1,
+                                                   &zero,
+                                                   cor.data(),
+                                                   1,
+                                                   stream));
+  if (coef_path) {
+    RAFT_CUDA_TRY(
+      cudaMemsetAsync(coef_path, 0, sizeof(math_t) * (*max_iter + 1) * (*max_iter), stream));
+  }
+}
+
+/**
+ * @brief Update regression coefficients and correlations
+ *
+ * After we calculated the equiangular vector and the step size (gamma) we
+ * adjust the regression coefficients here.
+ *
+ * See calcEquiangularVec for the definition of ws.
+ *
+ * @param handle RAFT handle
+ * @param max_iter maximum number of iterations
+ * @param n_cols number of valid feature columns
+ * @param n_active number of elements in the active set (n_active <= n_cols)
+ * @param gamma device pointer to the maximum step size (scalar)
+ * @param ws device pointer to the ws vector, size [n_cols]
+ * @param cor device pointer to the correlations, size [n_cols]
+ * @param a_vec device pointer to a = X.T[:,n_A:] * u, size [n_cols]
+ * @param beta pointer to regression coefficients, size [max_iter]
+ * @param coef_path device pointer to all the coefficients along the
+ *     regularization path, size [(max_iter + 1) * max_iter]
+ * @param stream CUDA stream
+ */
+template <typename math_t, typename idx_t>
+void updateCoef(const raft::handle_t& handle,
+                idx_t max_iter,
+                idx_t n_cols,
+                idx_t n_active,
+                math_t* gamma,
+                const math_t* ws,
+                math_t* cor,
+                math_t* a_vec,
+                math_t* beta,
+                math_t* coef_path,
+                cudaStream_t stream)
+{
+  // It is sufficient to update correlations only for the inactive set.
+  // cor[n_active:] -= gamma * a_vec
+  int n_inactive = n_cols - n_active;
+  if (n_inactive > 0) {
+    raft::linalg::binaryOp(
+      cor + n_active,
+      cor + n_active,
+      a_vec,
+      n_inactive,
+      [gamma] __device__(math_t c, math_t a) { return c - *gamma * a; },
+      stream);
+  }
+  // beta[:n_active] += gamma * ws
+  raft::linalg::binaryOp(
+    beta,
+    beta,
+    ws,
+    n_active,
+    [gamma] __device__(math_t b, math_t w) { return b + *gamma * w; },
+    stream);
+  if (coef_path) { raft::copy(coef_path + n_active * max_iter, beta, n_active, stream); }
+}
+
+/**
+ * @brief Train a regressor using Least Angle Regression.
+ *
+ * Least Angle Regression (LAR or LARS) is a model selection algorithm. It
+ * builds up the model using the following algorithm:
+ *
+ * 1. We start with all the coefficients equal to zero.
+ * 2. At each step we select the predictor that has the largest absolute
+ *    correlation with the residual.
+ * 3. We take the largest step possible in the direction which is equiangular
+ *    with all the predictors selected so far. The largest step is determined
+ *    such that using this step a new predictor will have as much correlation
+ *    with the residual as any of the currently active predictors.
+ * 4. Stop if max_iter is reached, if all the predictors are used, or if the
+ *    correlation between any unused predictor and the residual is lower than
+ *    a tolerance.
+ *
+ * The solver is based on [1]. The equations referred to in the comments
+ * correspond to the equations in the paper.
+ *
+ * Note: this algorithm assumes that the offset is removed from X and y, and
+ * each feature is normalized:
+ * - sum_i y_i = 0,
+ * - sum_i x_{i,j} = 0, sum_i x_{i,j}^2=1 for j=0..n_col-1
+ *
+ * References:
+ * [1] B. Efron, T. Hastie, I. Johnstone, R Tibshirani, Least Angle Regression
+ * The Annals of Statistics (2004) Vol 32, No 2, 407-499
+ * http://statweb.stanford.edu/~tibs/ftp/lars.pdf
+ *
+ * @param handle RAFT handle
+ * @param X device array of training vectors in column major format,
+ *     size [n_rows * n_cols]. Note that the columns of X will be permuted if
+ *     the Gram matrix is not specified. It is expected that X is normalized so
+ *     that each column has zero mean and unit variance.
+ * @param n_rows number of training samples
+ * @param n_cols number of feature columns
+ * @param y device array of the regression targets, size [n_rows]. y should
+ *     be normalized to have zero mean.
+ * @param beta device array of regression coefficients, has to be allocated on
+ *     entry, size [max_iter]
+ * @param active_idx device array containing the indices of active variables.
+ *     Must be allocated on entry. Size [max_iter]
+ * @param alphas device array to return the maximum correlation along the
+ *     regularization path. Must be allocated on entry, size [max_iter+1].
+ * @param n_active host pointer to return the number of active elements (scalar)
+ * @param Gram device array containing the Gram matrix X.T * X. Can be
+ *     nullptr.
+ * @param max_iter maximum number of iterations, this equals the maximum
+ *     number of coefficients returned. max_iter <= n_cols.
+ * @param coef_path coefficients along the regularization path are returned
+ *     here. Must be nullptr, or a device array already allocated on entry.
+ *     Size [max_iter * (max_iter+1)].
+ * @param verbosity verbosity level + * @param ld_X leading dimension of X (stride of columns) + * @param ld_G leading dimesion of G + * @param eps numeric parameter for Cholesky rank one update + */ +template +void larsFit(const raft::handle_t& handle, + math_t* X, + idx_t n_rows, + idx_t n_cols, + const math_t* y, + math_t* beta, + idx_t* active_idx, + math_t* alphas, + idx_t* n_active, + math_t* Gram = nullptr, + int max_iter = 500, + math_t* coef_path = nullptr, + int verbosity = 0, + idx_t ld_X = 0, + idx_t ld_G = 0, + math_t eps = -1) +{ + ASSERT(n_cols > 0, "Parameter n_cols: number of columns cannot be less than one"); + ASSERT(n_rows > 0, "Parameter n_rows: number of rows cannot be less than one"); + ML::Logger::get().setLevel(verbosity); + + // Set default ld parameters if needed. + if (ld_X == 0) ld_X = n_rows; + if (Gram && ld_G == 0) ld_G = n_cols; + + cudaStream_t stream = handle.get_stream(); + + // We will use either U_buffer.data() to store the Cholesky factorization, or + // store it in place at Gram. Pointer U will point to the actual storage. + rmm::device_uvector U_buffer(0, stream); + idx_t ld_U = 0; + math_t* U = nullptr; + + // Indices of elements in the active set. + std::vector indices(n_cols); + // Sign of the correlation at the time when the element was added to the + // active set. + rmm::device_uvector sign(n_cols, stream); + + // Correlation between the residual mu = y - X.T*beta and columns of X + rmm::device_uvector cor(n_cols, stream); + + // Temporary arrays used by the solver + rmm::device_scalar A(stream); + rmm::device_uvector a_vec(n_cols, stream); + rmm::device_scalar gamma(stream); + rmm::device_uvector u_eq(n_rows, stream); + rmm::device_uvector ws(max_iter, stream); + rmm::device_uvector workspace(n_cols, stream); + + larsInit(handle, + X, + n_rows, + n_cols, + ld_X, + y, + Gram, + ld_G, + U_buffer, + &U, + &ld_U, + indices, + cor, + &max_iter, + coef_path, + stream); + + // If we detect collinear features, then we will move them to the end of the + // correlation array and mark them as invalid (simply by decreasing + // n_valid_cols). At every iteration the solver is only working with the valid + // columns stored at X[:,:n_valid_cols], and G[:n_valid_cols, :n_valid_cols] + // cor[:n_valid_cols]. 
+ int n_valid_cols = n_cols; + + *n_active = 0; + for (int i = 0; i < max_iter; i++) { + math_t cj; + idx_t j; + LarsFitStatus status = selectMostCorrelated( + *n_active, n_valid_cols, cor.data(), &cj, workspace, &j, n_rows, indices.data(), i, stream); + if (status != LarsFitStatus::kOk) { break; } + + moveToActive(handle.get_cublas_handle(), + n_active, + j, + X, + n_rows, + n_valid_cols, + ld_X, + cor.data(), + indices.data(), + Gram, + ld_G, + sign.data(), + stream); + + status = calcEquiangularVec(handle, + *n_active, + X, + n_rows, + n_valid_cols, + ld_X, + sign.data(), + U, + ld_U, + Gram, + ld_G, + workspace, + ws.data(), + A.data(), + u_eq.data(), + eps, + stream); + + if (status == LarsFitStatus::kError) { + if (*n_active > 1) { RAFT_LOG_WARN("Returning with last valid model."); } + *n_active -= 1; + break; + } else if (status == LarsFitStatus::kCollinear) { + // We move the current feature to the invalid set + swapFeatures(handle.get_cublas_handle(), + n_valid_cols - 1, + *n_active - 1, + X, + n_rows, + n_cols, + ld_X, + cor.data(), + indices.data(), + Gram, + ld_G, + stream); + *n_active -= 1; + n_valid_cols--; + continue; + } + + calcMaxStep(handle, + max_iter, + n_rows, + n_valid_cols, + *n_active, + cj, + A.data(), + cor.data(), + Gram, + ld_G, + X, + ld_X, + u_eq.data(), + ws.data(), + gamma.data(), + a_vec.data(), + stream); + + updateCoef(handle, + max_iter, + n_valid_cols, + *n_active, + gamma.data(), + ws.data(), + cor.data(), + a_vec.data(), + beta, + coef_path, + stream); + } + + if (*n_active > 0) { + // Apply sklearn definition of alphas = cor / n_rows + raft::linalg::unaryOp( + alphas, + cor.data(), + *n_active, + [n_rows] __device__(math_t c) { return abs(c) / n_rows; }, + stream); + + // Calculate the final correlation. We use the correlation from the last + // iteration and apply the changed during the last LARS iteration: + // alpha[n_active] = cor[n_active-1] - gamma * A + math_t* gamma_ptr = gamma.data(); + math_t* A_ptr = A.data(); + raft::linalg::unaryOp( + alphas + *n_active, + cor.data() + *n_active - 1, + 1, + [gamma_ptr, A_ptr, n_rows] __device__(math_t c) { + return abs(c - (*gamma_ptr) * (*A_ptr)) / n_rows; + }, + stream); + + raft::update_device(active_idx, indices.data(), *n_active, stream); + } else { + THROW("Model is not fitted."); + } +} + +/** + * @brief Predict with least angle regressor. + * + * @param handle RAFT handle + * @param X device array of training vectors in column major format, + * size [n_rows * n_cols]. + * @param n_rows number of training samples + * @param n_cols number of feature columns + * @param ld_X leading dimension of X (stride of columns) + * @param beta device array of regression coefficients, size [n_active] + * @param n_active the number of regression coefficients + * @param active_idx device array containing the indices of active variables. + * Only these columns of X will be used for prediction, size [n_active]. + * @param intercept + * @param preds device array to store the predictions, size [n_rows]. Must be + * allocated on entry. 
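+ *
+ * For illustration, a minimal (hypothetical) call consuming the outputs of a
+ * previous larsFit call might look like the sketch below; the buffer names and
+ * the float/int types are assumptions for the example, not requirements of the API:
+ * @code{.cpp}
+ *   // beta, active_idx and n_active come from larsFit; preds is a device
+ *   // buffer of size [n_rows]; a zero intercept is assumed here.
+ *   larsPredict<float, int>(handle, X, n_rows, n_cols, n_rows, beta, n_active,
+ *                           active_idx, 0.0f, preds);
+ * @endcode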
+ */ +template +void larsPredict(const raft::handle_t& handle, + const math_t* X, + idx_t n_rows, + idx_t n_cols, + idx_t ld_X, + const math_t* beta, + idx_t n_active, + idx_t* active_idx, + math_t intercept, + math_t* preds) +{ + cudaStream_t stream = handle.get_stream(); + rmm::device_uvector beta_sorted(0, stream); + rmm::device_uvector X_active_cols(0, stream); + auto execution_policy = handle.get_thrust_policy(); + + if (n_active == 0 || n_rows == 0) return; + + if (n_active == n_cols) { + // We make a copy of the beta coefs and sort them + beta_sorted.resize(n_active, stream); + rmm::device_uvector idx_sorted(n_active, stream); + raft::copy(beta_sorted.data(), beta, n_active, stream); + raft::copy(idx_sorted.data(), active_idx, n_active, stream); + thrust::device_ptr beta_ptr(beta_sorted.data()); + thrust::device_ptr idx_ptr(idx_sorted.data()); + thrust::sort_by_key(execution_policy, idx_ptr, idx_ptr + n_active, beta_ptr); + beta = beta_sorted.data(); + } else { + // We collect active columns of X to contiguous space + X_active_cols.resize(n_active * ld_X, stream); + const int TPB = 64; + raft::cache::get_vecs<<>>( + X, ld_X, active_idx, n_active, X_active_cols.data()); + RAFT_CUDA_TRY(cudaGetLastError()); + X = X_active_cols.data(); + } + // Initialize preds = intercept + thrust::device_ptr pred_ptr(preds); + thrust::fill(execution_policy, pred_ptr, pred_ptr + n_rows, intercept); + math_t one = 1; + // #TODO: Call from public API when ready + RAFT_CUBLAS_TRY(raft::linalg::detail::cublasgemv(handle.get_cublas_handle(), + CUBLAS_OP_N, + n_rows, + n_active, + &one, + X, + ld_X, + beta, + 1, + &one, + preds, + 1, + stream)); +} +}; // namespace raft::solver::detail diff --git a/cpp/include/raft/solver/detail/learning_rate.h b/cpp/include/raft/solver/detail/learning_rate.h new file mode 100644 index 0000000000..c83a65d472 --- /dev/null +++ b/cpp/include/raft/solver/detail/learning_rate.h @@ -0,0 +1,71 @@ +/* + * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +namespace raft::solver::detail { + +template +math_t max(math_t a, math_t b) +{ + return (a < b) ? 
b : a; + ; +} + +template +math_t invScaling(math_t eta, math_t power_t, int t) +{ + return (eta / pow(t, power_t)); +} + +template +math_t regDLoss(math_t a, math_t b) +{ + return a - b; +} + +template +math_t calOptimalInit(math_t alpha) +{ + math_t typw = sqrt(math_t(1.0) / sqrt(alpha)); + math_t initial_eta0 = typw / max(math_t(1.0), regDLoss(-typw, math_t(1.0))); + return (math_t(1.0) / (initial_eta0 * alpha)); +} + +template +math_t optimal(math_t alpha, math_t optimal_init, int t) +{ + return math_t(1.0) / (alpha * (optimal_init + t - 1)); +} + +template +math_t calLearningRate(ML::lr_type lr_type, math_t eta, math_t power_t, math_t alpha, math_t t) +{ + if (lr_type == ML::lr_type::CONSTANT) { + return eta; + } else if (lr_type == ML::lr_type::INVSCALING) { + return invScaling(eta, power_t, t); + } else if (lr_type == ML::lr_type::OPTIMAL) { + return optimal(alpha, eta, t); + } else { + return math_t(0); + } +} + +}; // namespace raft::solver::detail diff --git a/cpp/include/raft/solver/detail/objectives/hinge.cuh b/cpp/include/raft/solver/detail/objectives/hinge.cuh new file mode 100644 index 0000000000..c6152a8fbe --- /dev/null +++ b/cpp/include/raft/solver/detail/objectives/hinge.cuh @@ -0,0 +1,191 @@ +/* + * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include "penalty.cuh" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace raft::solver::detail::objectives { + +template +void hingeLossGradMult(math_t* data, + const math_t* vec1, + const math_t* vec2, + idx_type n_row, + idx_type n_col, + cudaStream_t stream) +{ + raft::linalg::matrixVectorOp( + data, + data, + vec1, + vec2, + n_col, + n_row, + false, + false, + [] __device__(math_t a, math_t b, math_t c) { + if (c < math_t(1)) + return -a * b; + else + return math_t(0); + }, + stream); +} + +template +void hingeLossSubtract( + math_t* out, const math_t* in, math_t scalar, idx_type len, cudaStream_t stream) +{ + raft::linalg::unaryOp( + out, + in, + len, + [scalar] __device__(math_t in) { + if (in < scalar) + return math_t(1) - in; + else + return math_t(0); + }, + stream); +} + +template +void hingeH(const raft::handle_t& handle, + const math_t* input, + idx_type n_rows, + idx_type n_cols, + const math_t* coef, + math_t* pred, + math_t intercept, + cudaStream_t stream) +{ + raft::linalg::gemm( + handle, input, n_rows, n_cols, coef, pred, n_rows, 1, CUBLAS_OP_N, CUBLAS_OP_N, stream); + + if (intercept != math_t(0)) raft::linalg::addScalar(pred, pred, intercept, n_rows, stream); + + sign(pred, pred, math_t(1.0), n_rows, stream); +} + +template +void hingeLossGrads(const raft::handle_t& handle, + math_t* input, + int n_rows, + int n_cols, + const math_t* labels, + const math_t* coef, + math_t* grads, + penalty pen, + math_t alpha, + math_t l1_ratio, + cudaStream_t stream) +{ + rmm::device_uvector labels_pred(n_rows, stream); + + raft::linalg::gemm(handle, + input, + n_rows, + n_cols, + coef, + labels_pred.data(), + n_rows, + 1, + CUBLAS_OP_N, + CUBLAS_OP_N, + stream); + + raft::linalg::eltwiseMultiply(labels_pred.data(), labels_pred.data(), labels, n_rows, stream); + hingeLossGradMult(input, labels, labels_pred.data(), n_rows, n_cols, stream); + raft::stats::mean(grads, input, n_cols, n_rows, false, false, stream); + + rmm::device_uvector pen_grads(0, stream); + + if (pen != penalty::NONE) pen_grads.resize(n_cols, stream); + + if (pen == penalty::L1) { + lassoGrad(pen_grads.data(), coef, n_cols, alpha, stream); + } else if (pen == penalty::L2) { + ridgeGrad(pen_grads.data(), coef, n_cols, alpha, stream); + } else if (pen == penalty::ELASTICNET) { + elasticnetGrad(pen_grads.data(), coef, n_cols, alpha, l1_ratio, stream); + } + + if (pen != penalty::NONE) { raft::linalg::add(grads, grads, pen_grads.data(), n_cols, stream); } +} + +template +void hingeLoss(const raft::handle_t& handle, + math_t* input, + int n_rows, + int n_cols, + const math_t* labels, + const math_t* coef, + math_t* loss, + penalty pen, + math_t alpha, + math_t l1_ratio, + cudaStream_t stream) +{ + rmm::device_uvector labels_pred(n_rows, stream); + + raft::linalg::gemm(handle, + input, + n_rows, + n_cols, + coef, + labels_pred.data(), + n_rows, + 1, + CUBLAS_OP_N, + CUBLAS_OP_N, + stream); + + raft::linalg::eltwiseMultiply(labels_pred.data(), labels_pred.data(), labels, n_rows, stream); + + hingeLossSubtract(labels_pred.data(), labels_pred.data(), math_t(1), n_rows, stream); + + raft::stats::sum(loss, labels_pred.data(), 1, n_rows, false, stream); + + rmm::device_uvector pen_val(0, stream); + + if (pen != penalty::NONE) pen_val.resize(1, stream); + + if (pen == penalty::L1) { + lasso(pen_val.data(), coef, n_cols, alpha, stream); + } else if (pen == penalty::L2) { + ridge(pen_val.data(), coef, n_cols, alpha, 
stream); + } else if (pen == penalty::ELASTICNET) { + elasticnet(pen_val.data(), coef, n_cols, alpha, l1_ratio, stream); + } + + if (pen != penalty::NONE) { raft::linalg::add(loss, loss, pen_val.data(), 1, stream); } +} + +}; // namespace raft::solver::detail::objectives diff --git a/cpp/include/raft/solver/detail/objectives/linearReg.cuh b/cpp/include/raft/solver/detail/objectives/linearReg.cuh new file mode 100644 index 0000000000..78a22b10e7 --- /dev/null +++ b/cpp/include/raft/solver/detail/objectives/linearReg.cuh @@ -0,0 +1,124 @@ +/* + * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "penalty.cuh" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace raft::solver::detail::objectives { + +template +void linearRegH(const raft::handle_t& handle, + const math_t* input, + int n_rows, + int n_cols, + const math_t* coef, + math_t* pred, + math_t intercept, + cudaStream_t stream) +{ + raft::linalg::gemm( + handle, input, n_rows, n_cols, coef, pred, n_rows, 1, CUBLAS_OP_N, CUBLAS_OP_N, stream); + + if (intercept != math_t(0)) raft::linalg::addScalar(pred, pred, intercept, n_rows, stream); +} + +template +void linearRegLossGrads(const raft::handle_t& handle, + math_t* input, + int n_rows, + int n_cols, + const math_t* labels, + const math_t* coef, + math_t* grads, + penalty pen, + math_t alpha, + math_t l1_ratio, + cudaStream_t stream) +{ + rmm::device_uvector labels_pred(n_rows, stream); + + linearRegH(handle, input, n_rows, n_cols, coef, labels_pred.data(), math_t(0), stream); + raft::linalg::subtract(labels_pred.data(), labels_pred.data(), labels, n_rows, stream); + raft::matrix::matrixVectorBinaryMult( + input, labels_pred.data(), n_rows, n_cols, false, false, stream); + + raft::stats::mean(grads, input, n_cols, n_rows, false, false, stream); + raft::linalg::scalarMultiply(grads, grads, math_t(2), n_cols, stream); + + rmm::device_uvector pen_grads(0, stream); + + if (pen != penalty::NONE) pen_grads.resize(n_cols, stream); + + if (pen == penalty::L1) { + lassoGrad(pen_grads.data(), coef, n_cols, alpha, stream); + } else if (pen == penalty::L2) { + ridgeGrad(pen_grads.data(), coef, n_cols, alpha, stream); + } else if (pen == penalty::ELASTICNET) { + elasticnetGrad(pen_grads.data(), coef, n_cols, alpha, l1_ratio, stream); + } + + if (pen != penalty::NONE) { raft::linalg::add(grads, grads, pen_grads.data(), n_cols, stream); } +} + +template +void linearRegLoss(const raft::handle_t& handle, + math_t* input, + int n_rows, + int n_cols, + const math_t* labels, + const math_t* coef, + math_t* loss, + penalty pen, + math_t alpha, + math_t l1_ratio, + cudaStream_t stream) +{ + rmm::device_uvector labels_pred(n_rows, stream); + + linearRegH(handle, input, n_rows, n_cols, coef, labels_pred.data(), math_t(0), stream); + + raft::linalg::subtract(labels_pred.data(), labels, labels_pred.data(), n_rows, stream); + 
raft::matrix::power(labels_pred.data(), n_rows, stream); + raft::stats::mean(loss, labels_pred.data(), 1, n_rows, false, false, stream); + + rmm::device_uvector pen_val(0, stream); + + if (pen != penalty::NONE) pen_val.resize(1, stream); + + if (pen == penalty::L1) { + lasso(pen_val.data(), coef, n_cols, alpha, stream); + } else if (pen == penalty::L2) { + ridge(pen_val.data(), coef, n_cols, alpha, stream); + } else if (pen == penalty::ELASTICNET) { + elasticnet(pen_val.data(), coef, n_cols, alpha, l1_ratio, stream); + } + + if (pen != penalty::NONE) { raft::linalg::add(loss, loss, pen_val.data(), 1, stream); } +} + +}; // namespace raft::solver::detail::objectives \ No newline at end of file diff --git a/cpp/include/raft/solver/detail/objectives/log.cuh b/cpp/include/raft/solver/detail/objectives/log.cuh new file mode 100644 index 0000000000..c62e7e580c --- /dev/null +++ b/cpp/include/raft/solver/detail/objectives/log.cuh @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace raft::solver::detail::objectives { + +template +void f_log(T* out, T* in, T scalar, IdxType len, cudaStream_t stream) +{ + raft::linalg::unaryOp( + out, in, len, [scalar] __device__(T in) { return raft::myLog(in) * scalar; }, stream); +} + +}; // end namespace raft::solver::detail::objectives diff --git a/cpp/include/raft/solver/detail/objectives/logisticReg.cuh b/cpp/include/raft/solver/detail/objectives/logisticReg.cuh new file mode 100644 index 0000000000..40a2d4b2c4 --- /dev/null +++ b/cpp/include/raft/solver/detail/objectives/logisticReg.cuh @@ -0,0 +1,153 @@ +/* + * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#pragma once
+
+#include "penalty.cuh"
+#include "sigmoid.cuh"
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace raft::solver::detail::objectives {
+
+template <typename math_t>
+void logisticRegH(const raft::handle_t& handle,
+                  const math_t* input,
+                  int n_rows,
+                  int n_cols,
+                  const math_t* coef,
+                  math_t* pred,
+                  math_t intercept,
+                  cudaStream_t stream)
+{
+  raft::linalg::gemm(
+    handle, input, n_rows, n_cols, coef, pred, n_rows, 1, CUBLAS_OP_N, CUBLAS_OP_N, stream);
+
+  if (intercept != math_t(0)) raft::linalg::addScalar(pred, pred, intercept, n_rows, stream);
+
+  sigmoid(pred, pred, n_rows, stream);
+}
+
+template <typename math_t>
+void logisticRegLossGrads(const raft::handle_t& handle,
+                          math_t* input,
+                          int n_rows,
+                          int n_cols,
+                          const math_t* labels,
+                          const math_t* coef,
+                          math_t* grads,
+                          penalty pen,
+                          math_t alpha,
+                          math_t l1_ratio,
+                          cudaStream_t stream)
+{
+  rmm::device_uvector<math_t> labels_pred(n_rows, stream);
+
+  logisticRegH(handle, input, n_rows, n_cols, coef, labels_pred.data(), math_t(0), stream);
+  raft::linalg::subtract(labels_pred.data(), labels_pred.data(), labels, n_rows, stream);
+  raft::matrix::matrixVectorBinaryMult(
+    input, labels_pred.data(), n_rows, n_cols, false, false, stream);
+
+  raft::stats::mean(grads, input, n_cols, n_rows, false, false, stream);
+
+  rmm::device_uvector<math_t> pen_grads(0, stream);
+
+  if (pen != penalty::NONE) pen_grads.resize(n_cols, stream);
+
+  if (pen == penalty::L1) {
+    lassoGrad(pen_grads.data(), coef, n_cols, alpha, stream);
+  } else if (pen == penalty::L2) {
+    ridgeGrad(pen_grads.data(), coef, n_cols, alpha, stream);
+  } else if (pen == penalty::ELASTICNET) {
+    elasticnetGrad(pen_grads.data(), coef, n_cols, alpha, l1_ratio, stream);
+  }
+
+  if (pen != penalty::NONE) { raft::linalg::add(grads, grads, pen_grads.data(), n_cols, stream); }
+}
+
+template <typename T>
+void logLoss(T* out, T* label, T* label_pred, int len, cudaStream_t stream);
+
+template <>
+inline void logLoss(float* out, float* label, float* label_pred, int len, cudaStream_t stream)
+{
+  raft::linalg::binaryOp(
+    out,
+    label,
+    label_pred,
+    len,
+    [] __device__(float y, float y_pred) { return -y * logf(y_pred) - (1 - y) * logf(1 - y_pred); },
+    stream);
+}
+
+template <>
+inline void logLoss(double* out, double* label, double* label_pred, int len, cudaStream_t stream)
+{
+  raft::linalg::binaryOp(
+    out,
+    label,
+    label_pred,
+    len,
+    [] __device__(double y, double y_pred) {
+      // Use the double-precision log in both terms of the cross-entropy.
+      return -y * log(y_pred) - (1 - y) * log(1 - y_pred);
+    },
+    stream);
+}
+
+template <typename math_t>
+void logisticRegLoss(const raft::handle_t& handle,
+                     math_t* input,
+                     int n_rows,
+                     int n_cols,
+                     math_t* labels,
+                     const math_t* coef,
+                     math_t* loss,
+                     penalty pen,
+                     math_t alpha,
+                     math_t l1_ratio,
+                     cudaStream_t stream)
+{
+  rmm::device_uvector<math_t> labels_pred(n_rows, stream);
+  logisticRegH(handle, input, n_rows, n_cols, coef, labels_pred.data(), math_t(0), stream);
+  logLoss(labels_pred.data(), labels, labels_pred.data(), n_rows, stream);
+
+  raft::stats::mean(loss, labels_pred.data(), 1, n_rows, false, false, stream);
+
+  rmm::device_uvector<math_t> pen_val(0, stream);
+
+  if (pen != penalty::NONE) pen_val.resize(1, stream);
+
+  if (pen == penalty::L1) {
+    lasso(pen_val.data(), coef, n_cols, alpha, stream);
+  } else if (pen == penalty::L2) {
+    ridge(pen_val.data(), coef, n_cols, alpha, stream);
+  } else if (pen == penalty::ELASTICNET) {
+    elasticnet(pen_val.data(), coef, n_cols, alpha, l1_ratio, stream);
+  }
+
+  if (pen != penalty::NONE) { raft::linalg::add(loss, loss,
pen_val.data(), 1, stream); } +} + +}; // namespace raft::solver::detail::objectives diff --git a/cpp/include/raft/solver/detail/objectives/penalty.cuh b/cpp/include/raft/solver/detail/objectives/penalty.cuh new file mode 100644 index 0000000000..db60f029a9 --- /dev/null +++ b/cpp/include/raft/solver/detail/objectives/penalty.cuh @@ -0,0 +1,99 @@ +/* + * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "sign.cuh" +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace raft::solver::detail::objectives { + +enum penalty { + NONE, + L1, + L2, + ELASTICNET, +}; + +template +void lasso(math_t* out, const math_t* coef, const int len, const math_t alpha, cudaStream_t stream) +{ + raft::linalg::rowNorm(out, coef, len, 1, raft::linalg::NormType::L1Norm, true, stream); + raft::linalg::scalarMultiply(out, out, alpha, 1, stream); +} + +template +void lassoGrad( + math_t* grad, const math_t* coef, const int len, const math_t alpha, cudaStream_t stream) +{ + sign(grad, coef, alpha, len, stream); +} + +template +void ridge(math_t* out, const math_t* coef, const int len, const math_t alpha, cudaStream_t stream) +{ + raft::linalg::rowNorm(out, coef, len, 1, raft::linalg::NormType::L2Norm, true, stream); + raft::linalg::scalarMultiply(out, out, alpha, 1, stream); +} + +template +void ridgeGrad( + math_t* grad, const math_t* coef, const int len, const math_t alpha, cudaStream_t stream) +{ + raft::linalg::scalarMultiply(grad, coef, math_t(2) * alpha, len, stream); +} + +template +void elasticnet(math_t* out, + const math_t* coef, + const int len, + const math_t alpha, + const math_t l1_ratio, + cudaStream_t stream) +{ + rmm::device_scalar out_lasso(stream); + + ridge(out, coef, len, alpha * (math_t(1) - l1_ratio), stream); + lasso(out_lasso.data(), coef, len, alpha * l1_ratio, stream); + + raft::linalg::add(out, out, out_lasso.data(), 1, stream); +} + +template +void elasticnetGrad(math_t* grad, + const math_t* coef, + const int len, + const math_t alpha, + const math_t l1_ratio, + cudaStream_t stream) +{ + rmm::device_uvector grad_lasso(len, stream); + + ridgeGrad(grad, coef, len, alpha * (math_t(1) - l1_ratio), stream); + lassoGrad(grad_lasso.data(), coef, len, alpha * l1_ratio, stream); + + raft::linalg::add(grad, grad, grad_lasso.data(), len, stream); +} + +}; // namespace raft::solver::detail::objectives diff --git a/cpp/include/raft/solver/detail/objectives/sigmoid.cuh b/cpp/include/raft/solver/detail/objectives/sigmoid.cuh new file mode 100644 index 0000000000..a06a305e44 --- /dev/null +++ b/cpp/include/raft/solver/detail/objectives/sigmoid.cuh @@ -0,0 +1,32 @@ +/* + * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +namespace raft::solver::detail::objectives { + +template +void sigmoid(T* out, T* in, IdxType len, cudaStream_t stream) +{ + T one = T(1); + raft::linalg::unaryOp( + out, in, len, [one] __device__(T in) { return one / (one + raft::myExp(-in)); }, stream); +} + +}; // end namespace raft::solver::detail::objectives diff --git a/cpp/include/raft/solver/detail/objectives/sign.cuh b/cpp/include/raft/solver/detail/objectives/sign.cuh new file mode 100644 index 0000000000..ca37727355 --- /dev/null +++ b/cpp/include/raft/solver/detail/objectives/sign.cuh @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace raft::solver::detail::objectives { + +template +void sign( + math_t* out, const math_t* in, const math_t scalar, const idx_type len, cudaStream_t stream) +{ + raft::linalg::unaryOp( + out, + in, + len, + [scalar] __device__(math_t in) { + if (in < math_t(0)) + return (math_t(-1) * scalar); + else if (in > math_t(0)) + return (math_t(1) * scalar); + else + return math_t(0); + }, + stream); +} + +template +void sign(math_t* out, const math_t* in, const idx_type n_len, cudaStream_t stream) +{ + math_t scalar = math_t(1); + sign(out, in, scalar, n_len, stream); +} + +}; // namespace raft::solver::detail::objectives diff --git a/cpp/include/raft/solver/detail/objectives/softThres.cuh b/cpp/include/raft/solver/detail/objectives/softThres.cuh new file mode 100644 index 0000000000..485fc4f688 --- /dev/null +++ b/cpp/include/raft/solver/detail/objectives/softThres.cuh @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include + +namespace raft::solver::detail::objectives { + +template +void softThres( + math_t* out, const math_t* in, const math_t thres, const int len, cudaStream_t stream) +{ + raft::linalg::unaryOp( + out, + in, + len, + [thres] __device__(math_t in) { + if (in > math_t(0) && thres < raft::myAbs(in)) + return in - thres; + else if (in < math_t(0) && thres < raft::myAbs(in)) + return in + thres; + else + return math_t(0); + }, + stream); +} + +}; // namespace raft::solver::detail::objectives diff --git a/cpp/include/raft/solver/detail/qn/objectives/base.cuh b/cpp/include/raft/solver/detail/qn/objectives/base.cuh new file mode 100644 index 0000000000..1edc1904a5 --- /dev/null +++ b/cpp/include/raft/solver/detail/qn/objectives/base.cuh @@ -0,0 +1,239 @@ +/* + * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "../simple_mat.cuh" +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include + +namespace raft::solver::quasi_newton::detail::objectives { + +template +inline void linearFwd(const raft::handle_t& handle, + SimpleDenseMat& Z, + const SimpleMat& X, + const SimpleDenseMat& W, + cudaStream_t stream) +{ + // Forward pass: compute Z <- W * X.T + bias + const bool has_bias = X.n != W.n; + const int D = X.n; + if (has_bias) { + SimpleVec bias; + SimpleDenseMat weights; + col_ref(W, bias, D); + col_slice(W, weights, 0, D); + // We implement Z <- W * X^T + b by + // - Z <- b (broadcast): TODO reads Z unnecessarily atm + // - Z <- W * X^T + Z : TODO can be fused in CUTLASS? + auto set_bias = [] __device__(const T z, const T b) { return b; }; + raft::linalg::matrixVectorOp( + Z.data, Z.data, bias.data, Z.n, Z.m, false, false, set_bias, stream); + + Z.assign_gemm(handle, 1, weights, false, X, true, 1, stream); + } else { + Z.assign_gemm(handle, 1, W, false, X, true, 0, stream); + } +} + +template +inline void linearBwd(const raft::handle_t& handle, + SimpleDenseMat& G, + const SimpleMat& X, + const SimpleDenseMat& dZ, + bool setZero, + cudaStream_t stream) +{ + // Backward pass: + // - compute G <- dZ * X.T + // - for bias: Gb = mean(dZ, 1) + + const bool has_bias = X.n != G.n; + const int D = X.n; + const T beta = setZero ? T(0) : T(1); + if (has_bias) { + SimpleVec Gbias; + SimpleDenseMat Gweights; + col_ref(G, Gbias, D); + col_slice(G, Gweights, 0, D); + + // TODO can this be fused somehow? 
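+    // Gweights <- (1 / X.m) * dZ * X + beta * Gweights, i.e. the gradient of the
+    // linear part averaged over the X.m samples; Gbias is then the mean of dZ
+    // over the samples (computed just below).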
+ Gweights.assign_gemm(handle, 1.0 / X.m, dZ, false, X, false, beta, stream); + raft::stats::mean(Gbias.data, dZ.data, dZ.m, dZ.n, false, true, stream); + } else { + G.assign_gemm(handle, 1.0 / X.m, dZ, false, X, false, beta, stream); + } +} + + + +template +struct QNLinearBase : LinearDims { + typedef SimpleDenseMat Mat; + typedef SimpleVec Vec; + + const raft::handle_t& handle; + T* sample_weights; + T weights_sum; + + QNLinearBase(const raft::handle_t& handle, int D, int C, bool fit_intercept) + : LinearDims(C, D, fit_intercept), handle(handle), sample_weights(nullptr), weights_sum(0) + { + } + + void add_sample_weights(T* sample_weights, int n_samples, cudaStream_t stream) + { + this->sample_weights = sample_weights; + this->weights_sum = thrust::reduce(thrust::cuda::par.on(stream), + sample_weights, + sample_weights + n_samples, + (T)0, + thrust::plus()); + } + + /* + * Computes the following: + * 1. Z <- dL/DZ + * 2. loss_val <- sum loss(Z) + * + * Default: elementwise application of loss and its derivative + * + * NB: for this method to work, loss implementations must have two functor fields `lz` and `dlz`. + * These two compute loss value and its derivative w.r.t. `z`. + */ + inline void getLossAndDZ(T* loss_val, + SimpleDenseMat& Z, + const SimpleVec& y, + cudaStream_t stream) + { + // Base impl assumes simple case C = 1 + // TODO would be nice to have a kernel that fuses these two steps + // This would be easy, if mapThenSumReduce allowed outputing the result of + // map (supporting inplace) + auto lz_copy = static_cast(this)->lz; + auto dlz_copy = static_cast(this)->dlz; + if (this->sample_weights) { // Sample weights are in use + T normalization = 1.0 / this->weights_sum; + raft::linalg::mapThenSumReduce( + loss_val, + y.len, + [lz_copy, normalization] __device__(const T y, const T z, const T weight) { + return lz_copy(y, z) * (weight * normalization); + }, + stream, + y.data, + Z.data, + sample_weights); + raft::linalg::map_k( + Z.data, + y.len, + [dlz_copy] __device__(const T y, const T z, const T weight) { + return weight * dlz_copy(y, z); + }, + stream, + y.data, + Z.data, + sample_weights); + } else { // Sample weights are not used + T normalization = 1.0 / y.len; + raft::linalg::mapThenSumReduce( + loss_val, + y.len, + [lz_copy, normalization] __device__(const T y, const T z) { + return lz_copy(y, z) * normalization; + }, + stream, + y.data, + Z.data); + raft::linalg::binaryOp(Z.data, y.data, Z.data, y.len, dlz_copy, stream); + } + } + + inline void loss_grad(T* loss_val, + Mat& G, + const Mat& W, + const SimpleMat& Xb, + const Vec& yb, + Mat& Zb, + cudaStream_t stream, + bool initGradZero = true) + { + Loss* loss = static_cast(this); // static polymorphism + + linearFwd(handle, Zb, Xb, W, stream); // linear part: forward pass + loss->getLossAndDZ(loss_val, Zb, yb, stream); // loss specific part + linearBwd(handle, G, Xb, Zb, initGradZero, + stream); // linear part: backward pass + } +}; + +template +struct QNWithData : LinearDims { + const SimpleMat* X; + const SimpleVec* y; + SimpleDenseMat* Z; + QuasiNewtonObjective* objective; + + QNWithData(QuasiNewtonObjective* obj, const SimpleMat& X, const SimpleVec& y, SimpleDenseMat& Z) + : objective(obj), X(&X), y(&y), Z(&Z), LinearDims(obj->C, obj->D, obj->fit_intercept) + { + } + + // interface exposed to typical non-linear optimizers + inline T operator()(const SimpleVec& wFlat, + SimpleVec& gradFlat, + T* dev_scalar, + cudaStream_t stream) + { + SimpleDenseMat W(wFlat.data, C, dims); + SimpleDenseMat G(gradFlat.data, C, 
dims); + objective->loss_grad(dev_scalar, G, W, *X, *y, *Z, stream); + T loss_host; + raft::update_host(&loss_host, dev_scalar, 1, stream); + raft::interruptible::synchronize(stream); + return loss_host; + } + + /** + * @brief Calculate a norm of the gradient computed using the given Loss instance. + * + * This function is intended to be used in `check_convergence`; it's output is supposed + * to be proportional to the loss value w.r.t. the number of features (D). + * + * Different loss functions may scale differently with the number of features (D). + * This has an effect on the convergence criteria. To account for that, we let a + * loss function define its preferred metric. Normally, we differentiate between the + * L2 norm (e.g. for Squared loss) and LInf norm (e.g. for Softmax loss). + */ + inline T gradNorm(const SimpleVec& grad, T* dev_scalar, cudaStream_t stream) + { + return objective->gradNorm(grad, dev_scalar, stream); + } +}; + +}; // namespace raft::solver::quasi_newton::detail::objectives diff --git a/cpp/include/raft/solver/detail/qn/objectives/hinge.cuh b/cpp/include/raft/solver/detail/qn/objectives/hinge.cuh new file mode 100644 index 0000000000..6a43babd6c --- /dev/null +++ b/cpp/include/raft/solver/detail/qn/objectives/hinge.cuh @@ -0,0 +1,156 @@ +/* + * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "base.cuh" +#include "../simple_mat.cuh" +#include +#include + +namespace raft::solver::quasi_newton::detail::objectives { + +template +struct HingeLoss : QNLinearBase> { + typedef QNLinearBase> Super; + + const struct Lz { + inline __device__ T operator()(const T y, const T z) const + { + T s = 2 * y - 1; + return raft::myMax(0, 1 - s * z); + } + } lz; + + const struct Dlz { + inline __device__ T operator()(const T y, const T z) const + { + T s = 2 * y - 1; + return s * z <= 1 ? -s : 0; + } + } dlz; + + HingeLoss(const raft::handle_t& handle, int D, bool has_bias) + : Super(handle, D, 1, has_bias), lz{}, dlz{} + { + } + + inline T gradNorm(const SimpleVec& grad, T* dev_scalar, cudaStream_t stream) + { + return nrm1(grad, dev_scalar, stream); + } +}; + +template +struct SqHingeLoss : QNLinearBase> { + typedef QNLinearBase> Super; + + const struct Lz { + inline __device__ T operator()(const T y, const T z) const + { + T s = 2 * y - 1; + T t = raft::myMax(0, 1 - s * z); + return t * t; + } + } lz; + + const struct Dlz { + inline __device__ T operator()(const T y, const T z) const + { + T s = 2 * y - 1; + return s * z <= 1 ? 
z - s : 0; + } + } dlz; + + SqHingeLoss(const raft::handle_t& handle, int D, bool has_bias) + : Super(handle, D, 1, has_bias), lz{}, dlz{} + { + } + + inline T gradNorm(const SimpleVec& grad, T* dev_scalar, cudaStream_t stream) + { + return squaredNorm(grad, dev_scalar, stream) * 0.5; + } +}; + +template +struct EpsInsHingeLoss : QNLinearBase> { + typedef QNLinearBase> Super; + + const struct Lz { + T sensitivity; + inline __device__ T operator()(const T y, const T z) const + { + T t = y - z; + return t > sensitivity ? t - sensitivity : t < -sensitivity ? -t - sensitivity : 0; + } + } lz; + + const struct Dlz { + T sensitivity; + inline __device__ T operator()(const T y, const T z) const + { + T t = y - z; + return t > sensitivity ? -1 : (t < -sensitivity ? 1 : 0); + } + } dlz; + + EpsInsHingeLoss(const raft::handle_t& handle, int D, bool has_bias, T sensitivity) + : Super(handle, D, 1, has_bias), lz{sensitivity}, dlz{sensitivity} + { + } + + inline T gradNorm(const SimpleVec& grad, T* dev_scalar, cudaStream_t stream) + { + return nrm1(grad, dev_scalar, stream); + } +}; + +template +struct SqEpsInsHingeLoss : QNLinearBase> { + typedef QNLinearBase> Super; + + const struct Lz { + T sensitivity; + inline __device__ T operator()(const T y, const T z) const + { + T t = y - z; + T s = t > sensitivity ? t - sensitivity : t < -sensitivity ? -t - sensitivity : 0; + return s * s; + } + } lz; + + const struct Dlz { + T sensitivity; + inline __device__ T operator()(const T y, const T z) const + { + T t = y - z; + return -2 * (t > sensitivity ? t - sensitivity : t < -sensitivity ? (t + sensitivity) : 0); + } + } dlz; + + SqEpsInsHingeLoss(const raft::handle_t& handle, int D, bool has_bias, T sensitivity) + : Super(handle, D, 1, has_bias), lz{sensitivity}, dlz{sensitivity} + { + } + + inline T gradNorm(const SimpleVec& grad, T* dev_scalar, cudaStream_t stream) + { + return squaredNorm(grad, dev_scalar, stream) * 0.5; + } +}; + +}; // namespace raft::solver::quasi_newton::detail::objectives diff --git a/cpp/include/raft/solver/detail/qn/objectives/linear.cuh b/cpp/include/raft/solver/detail/qn/objectives/linear.cuh new file mode 100644 index 0000000000..93c4363f78 --- /dev/null +++ b/cpp/include/raft/solver/detail/qn/objectives/linear.cuh @@ -0,0 +1,79 @@ +/* + * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include "base.cuh" +#include "../simple_mat.cuh" +#include +#include + +namespace raft::solver::quasi_newton::detail::objectives { + +template +struct SquaredLoss : QNLinearBase> { + typedef QNLinearBase> Super; + + const struct Lz { + inline __device__ T operator()(const T y, const T z) const + { + T diff = z - y; + return diff * diff * 0.5; + } + } lz; + + const struct Dlz { + inline __device__ T operator()(const T y, const T z) const { return z - y; } + } dlz; + + SquaredLoss(const raft::handle_t& handle, int D, bool has_bias) + : Super(handle, D, 1, has_bias), lz{}, dlz{} + { + } + + inline T gradNorm(const SimpleVec& grad, T* dev_scalar, cudaStream_t stream) + { + return squaredNorm(grad, dev_scalar, stream) * 0.5; + } +}; + +template +struct AbsLoss : QNLinearBase> { + typedef QNLinearBase> Super; + + const struct Lz { + inline __device__ T operator()(const T y, const T z) const { return raft::myAbs(z - y); } + } lz; + + const struct Dlz { + inline __device__ T operator()(const T y, const T z) const + { + return z > y ? 1 : (z < y ? -1 : 0); + } + } dlz; + + AbsLoss(const raft::handle_t& handle, int D, bool has_bias) + : Super(handle, D, 1, has_bias), lz{}, dlz{} + { + } + + inline T gradNorm(const SimpleVec& grad, T* dev_scalar, cudaStream_t stream) + { + return nrm1(grad, dev_scalar, stream); + } +}; + +}; // namespace raft::solver::quasi_newton::detail::objectives diff --git a/cpp/include/raft/solver/detail/qn/objectives/logistic.cuh b/cpp/include/raft/solver/detail/qn/objectives/logistic.cuh new file mode 100644 index 0000000000..064cf4b793 --- /dev/null +++ b/cpp/include/raft/solver/detail/qn/objectives/logistic.cuh @@ -0,0 +1,65 @@ +/* + * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "base.cuh" +#include "../simple_mat.cuh" +#include +#include + +namespace raft::solver::quasi_newton::detail::objectives { + +template +struct LogisticLoss : QNLinearBase> { + typedef QNLinearBase> Super; + + const struct Lz { + inline __device__ T log_sigmoid(const T x) const + { + // To avoid floating point overflow in the exp function + T temp = raft::myLog(1 + raft::myExp(x < 0 ? x : -x)); + return x < 0 ? x - temp : -temp; + } + + inline __device__ T operator()(const T y, const T z) const + { + T ytil = 2 * y - 1; + return -log_sigmoid(ytil * z); + } + } lz; + + const struct Dlz { + inline __device__ T operator()(const T y, const T z) const + { + // To avoid fp overflow with exp(z) when abs(z) is large + T ez = raft::myExp(z < 0 ? z : -z); + T numerator = z < 0 ? 
ez : T(1.0); + return numerator / (T(1.0) + ez) - y; + } + } dlz; + + LogisticLoss(const raft::handle_t& handle, int D, bool has_bias) + : Super(handle, D, 1, has_bias), lz{}, dlz{} + { + } + + inline T gradNorm(const SimpleVec& grad, T* dev_scalar, cudaStream_t stream) + { + return nrmMax(grad, dev_scalar, stream); + } +}; +}; // namespace raft::solver::quasi_newton::detail::objectives \ No newline at end of file diff --git a/cpp/include/raft/solver/detail/qn/objectives/regularizer.cuh b/cpp/include/raft/solver/detail/qn/objectives/regularizer.cuh new file mode 100644 index 0000000000..45395cbd0e --- /dev/null +++ b/cpp/include/raft/solver/detail/qn/objectives/regularizer.cuh @@ -0,0 +1,94 @@ +/* + * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "base.cuh" +#include "../simple_mat.cuh" +#include +#include +#include +#include +#include + +namespace raft::solver::quasi_newton::detail::objectives { + +template +struct Tikhonov { + T l2_penalty; + Tikhonov(T l2) : l2_penalty(l2) {} + Tikhonov(const Tikhonov& other) : l2_penalty(other.l2_penalty) {} + + HDI T operator()(const T w) const { return 0.5 * l2_penalty * w * w; } + + inline void reg_grad(T* reg_val, + SimpleDenseMat& G, + const SimpleDenseMat& W, + const bool has_bias, + cudaStream_t stream) const + { + // NOTE: scikit generally does not penalize biases + SimpleDenseMat Gweights; + SimpleDenseMat Wweights; + col_slice(G, Gweights, 0, G.n - has_bias); + col_slice(W, Wweights, 0, G.n - has_bias); + Gweights.ax(l2_penalty, Wweights, stream); + + raft::linalg::mapThenSumReduce(reg_val, Wweights.len, *this, stream, Wweights.data); + } +}; + +template +struct RegularizedQN : LinearDims { + Reg* reg; + Loss* loss; + + RegularizedQN(Loss* loss, Reg* reg) + : reg(reg), loss(loss), LinearDims(loss->C, loss->D, loss->fit_intercept) + { + } + + inline void loss_grad(T* loss_val, + SimpleDenseMat& G, + const SimpleDenseMat& W, + const SimpleMat& Xb, + const SimpleVec& yb, + SimpleDenseMat& Zb, + cudaStream_t stream, + bool initGradZero = true) + { + T reg_host, loss_host; + SimpleVec lossVal(loss_val, 1); + + G.fill(0, stream); + + reg->reg_grad(lossVal.data, G, W, loss->fit_intercept, stream); + raft::update_host(®_host, lossVal.data, 1, stream); + + loss->loss_grad(lossVal.data, G, W, Xb, yb, Zb, stream, false); + raft::update_host(&loss_host, lossVal.data, 1, stream); + + raft::interruptible::synchronize(stream); + + lossVal.fill(loss_host + reg_host, stream); + } + + inline T gradNorm(const SimpleVec& grad, T* dev_scalar, cudaStream_t stream) + { + return loss->gradNorm(grad, dev_scalar, stream); + } +}; +}; // namespace raft::solver::quasi_newton::detail::objectives diff --git a/cpp/include/raft/solver/detail/qn/objectives/softmax.cuh b/cpp/include/raft/solver/detail/qn/objectives/softmax.cuh new file mode 100644 index 0000000000..642800a7ff --- /dev/null +++ b/cpp/include/raft/solver/detail/qn/objectives/softmax.cuh @@ -0,0 +1,197 @@ +/* + * Copyright 
(c) 2018-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "base.cuh" +#include "../simple_mat.cuh" +#include +#include + +namespace raft::solver::quasi_newton::detail::objectives { +using raft::ceildiv; +using raft::myExp; +using raft::myLog; +using raft::myMax; + +// Input: matrix Z (dims: CxN) +// Computes softmax cross entropy loss across columns, i.e. normalization +// column-wise. +// +// This kernel performs best for small number of classes C. +// It's much faster than implementation based on ml-prims (up to ~2x - ~10x for +// small C <= BX). More importantly, it does not require another CxN scratch +// space. In that case the block covers the whole column and warp reduce is fast +// TODO for very large C, there should be maybe rather something along the lines +// of +// coalesced reduce, i.e. blocks should take care of columns +// TODO split into two kernels for small and large case? +template +__global__ void logSoftmaxKernel( + T* out, T* dZ, const T* in, const T* labels, int C, int N, bool getDerivative = true) +{ + typedef cub::WarpReduce WarpRed; + typedef cub::BlockReduce BlockRed; + + __shared__ union { + typename WarpRed::TempStorage warpStore[BY]; + typename BlockRed::TempStorage blockStore; + T sh_val[BY]; + } shm; + + int y = threadIdx.y + blockIdx.x * BY; + int len = C * N; + + bool delta = false; + // TODO is there a better way to read this? + if (getDerivative && threadIdx.x == 0) { + if (y < N) { + shm.sh_val[threadIdx.y] = labels[y]; + } else { + shm.sh_val[threadIdx.y] = std::numeric_limits::lowest(); + } + } + __syncthreads(); + T label = shm.sh_val[threadIdx.y]; + __syncthreads(); + T eta_y = 0; + T myEta = 0; + T etaMax = -1e9; + T lse = 0; + /* + * Phase 1: Find Maximum m over column + */ + for (int x = threadIdx.x; x < C; x += BX) { + int idx = x + y * C; + if (x < C && idx < len) { + myEta = in[idx]; + if (x == label) { + delta = true; + eta_y = myEta; + } + etaMax = myMax(myEta, etaMax); + } + } + T tmpMax = WarpRed(shm.warpStore[threadIdx.y]).Reduce(etaMax, cub::Max()); + if (threadIdx.x == 0) { shm.sh_val[threadIdx.y] = tmpMax; } + __syncthreads(); + etaMax = shm.sh_val[threadIdx.y]; + __syncthreads(); + + /* + * Phase 2: Compute stabilized log-sum-exp over column + * lse = m + log(sum(exp(eta - m))) + */ + // TODO there must be a better way to do this... 
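+  // Note: subtracting the column maximum (etaMax) keeps every exponent non-positive,
+  // so exp() cannot overflow, and the result is unchanged because
+  // m + log(sum_j exp(eta_j - m)) == log(sum_j exp(eta_j)) for any m.
+  // E.g. for eta = {1000, 1001}, exp(1000) overflows in double precision, whereas
+  // exp(eta - 1001) = {exp(-1), 1} is safe and lse = 1001 + log(1 + exp(-1)).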
+ if (C <= BX) { // this means one block covers a column and myEta is valid + int idx = threadIdx.x + y * C; + if (threadIdx.x < C && idx < len) { lse = myExp(myEta - etaMax); } + } else { + for (int x = threadIdx.x; x < C; x += BX) { + int idx = x + y * C; + if (x < C && idx < len) { lse += myExp(in[idx] - etaMax); } + } + } + T tmpLse = WarpRed(shm.warpStore[threadIdx.y]).Sum(lse); + if (threadIdx.x == 0) { shm.sh_val[threadIdx.y] = etaMax + myLog(tmpLse); } + __syncthreads(); + lse = shm.sh_val[threadIdx.y]; + __syncthreads(); + + /* + * Phase 3: Compute derivatives dL/dZ = P - delta_y + * P is the softmax distribution, delta_y the kronecker delta for the class of + * label y If we getDerivative=false, dZ will just contain P, which might be + * useful + */ + + if (C <= BX) { // this means one block covers a column and myEta is valid + int idx = threadIdx.x + y * C; + if (threadIdx.x < C && idx < len) { + dZ[idx] = (myExp(myEta - lse) - (getDerivative ? (threadIdx.x == label) : T(0))); + } + } else { + for (int x = threadIdx.x; x < C; x += BX) { + int idx = x + y * C; + if (x < C && idx < len) { + T logP = in[idx] - lse; + dZ[idx] = (myExp(logP) - (getDerivative ? (x == label) : T(0))); + } + } + } + + if (!getDerivative) // no need to continue, lossval will be undefined + return; + + T lossVal = 0; + if (delta) { lossVal = (lse - eta_y) / N; } + + /* + * Phase 4: accumulate loss value + */ + T blockSum = BlockRed(shm.blockStore).Sum(lossVal); + if (threadIdx.x == 0 && threadIdx.y == 0) { raft::myAtomicAdd(out, blockSum); } +} + +template +void launchLogsoftmax( + T* loss_val, T* dldZ, const T* Z, const T* labels, int C, int N, cudaStream_t stream) +{ + RAFT_CUDA_TRY(cudaMemsetAsync(loss_val, 0, sizeof(T), stream)); + raft::interruptible::synchronize(stream); + if (C <= 4) { + dim3 bs(4, 64); + dim3 gs(ceildiv(N, 64)); + logSoftmaxKernel<<>>(loss_val, dldZ, Z, labels, C, N); + } else if (C <= 8) { + dim3 bs(8, 32); + dim3 gs(ceildiv(N, 32)); + logSoftmaxKernel<<>>(loss_val, dldZ, Z, labels, C, N); + } else if (C <= 16) { + dim3 bs(16, 16); + dim3 gs(ceildiv(N, 16)); + logSoftmaxKernel<<>>(loss_val, dldZ, Z, labels, C, N); + } else { + dim3 bs(32, 8); + dim3 gs(ceildiv(N, 8)); + logSoftmaxKernel<<>>(loss_val, dldZ, Z, labels, C, N); + } + RAFT_CUDA_TRY(cudaPeekAtLastError()); +} + +template +struct Softmax : QNLinearBase> { + typedef QNLinearBase> Super; + + Softmax(const raft::handle_t& handle, int D, int C, bool has_bias) : Super(handle, D, C, has_bias) + { + } + + inline void getLossAndDZ(T* loss_val, + SimpleDenseMat& Z, + const SimpleVec& y, + cudaStream_t stream) + { + launchLogsoftmax(loss_val, Z.data, Z.data, y.data, Z.m, Z.n, stream); + } + + inline T gradNorm(const SimpleVec& grad, T* dev_scalar, cudaStream_t stream) + { + return nrmMax(grad, dev_scalar, stream); + } +}; + +}; // namespace raft::solver::quasi_newton::detail::objectives diff --git a/cpp/include/raft/solver/detail/qn/qn_decision.cuh b/cpp/include/raft/solver/detail/qn/qn_decision.cuh new file mode 100644 index 0000000000..c7bf123e55 --- /dev/null +++ b/cpp/include/raft/solver/detail/qn/qn_decision.cuh @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "objectives/base.cuh" +#include "objectives/linear.cuh" +#include "objectives/logistic.cuh" +#include "objectives/regularizer.cuh" +#include "objectives/softmax.cuh" +#include "objectives/hinge.cuh" +#include "qn_solvers.cuh" +#include "qn_util.cuh" + +#include +#include +#include + +namespace raft::solver::quasi_newton::detail { + +template +void linear_decision_function(const raft::handle_t& handle, + const qn_params& pams, + SimpleMat& X, + int C, + T* params, + T* scores, + cudaStream_t stream) { + // NOTE: While gtests pass X as row-major, and python API passes X as + // col-major, no extensive testing has been done to ensure that + // this function works correctly for both input types + int n_targets = qn_is_classification(pams.loss) && C == 2 ? 1 : C; + LinearDims dims(n_targets, X.n, pams.fit_intercept); + SimpleDenseMat W(params, n_targets, dims.dims); + SimpleDenseMat Z(scores, n_targets, X.m); + linearFwd(handle, Z, X, W, stream); +} +}; // namespace raft::solver::quasi_newton::detail diff --git a/cpp/include/raft/solver/detail/qn/qn_linesearch.cuh b/cpp/include/raft/solver/detail/qn/qn_linesearch.cuh new file mode 100644 index 0000000000..70ed2a471f --- /dev/null +++ b/cpp/include/raft/solver/detail/qn/qn_linesearch.cuh @@ -0,0 +1,210 @@ +/* + * Copyright (c) 2018-2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include "qn_util.cuh" + +/* + * Linesearch functions + */ + +namespace raft::solver::quasi_newton::detail { + +template +struct LSProjectedStep { + typedef SimpleVec Vector; + struct op_pstep { + T step; + op_pstep(const T s) : step(s) {} + + HDI T operator()(const T xp, const T drt, const T pg) const + { + T xi = xp == 0 ? 
-pg : xp; + return project_orth(xp + step * drt, xi); + } + }; + + void operator()(const T step, + Vector& x, + const Vector& drt, + const Vector& xp, + const Vector& pgrad, + cudaStream_t stream) const + { + op_pstep pstep(step); + x.assign_ternary(xp, drt, pgrad, pstep, stream); + } +}; + +template +inline bool ls_success(const LBFGSParam& param, + const T fx_init, + const T dg_init, + const T fx, + const T dg_test, + const T step, + const SimpleVec& grad, + const SimpleVec& drt, + T* width, + T* dev_scalar, + cudaStream_t stream) +{ + if (fx > fx_init + step * dg_test) { + *width = param.ls_dec; + } else { + // Armijo condition is met + if (param.linesearch == LBFGS_LS_BT_ARMIJO) return true; + + const T dg = dot(grad, drt, dev_scalar, stream); + if (dg < param.wolfe * dg_init) { + *width = param.ls_inc; + } else { + // Regular Wolfe condition is met + if (param.linesearch == LBFGS_LS_BT_WOLFE) return true; + + if (dg > -param.wolfe * dg_init) { + *width = param.ls_dec; + } else { + // Strong Wolfe condition is met + return true; + } + } + } + + return false; +} + +/** + * Backtracking linesearch + * + * \param param LBFGS parameters + * \param f A function object such that `f(x, grad)` returns the + * objective function value at `x`, and overwrites `grad` + * with the gradient. + * \param fx In: The objective function value at the current point. + * Out: The function value at the new point. + * \param x Out: The new point moved to. + * \param grad In: The current gradient vector. + * Out: The gradient at the new point. + * \param step In: The initial step length. + * Out: The calculated step length. + * \param drt The current moving direction. + * \param xp The current point. + * \param dev_scalar Device pointer to workspace of at least 1 + * \param stream Device pointer to workspace of at least 1 + */ +template +LINE_SEARCH_RETCODE ls_backtrack(const LBFGSParam& param, + Function& f, + T& fx, + SimpleVec& x, + SimpleVec& grad, + T& step, + const SimpleVec& drt, + const SimpleVec& xp, + T* dev_scalar, + cudaStream_t stream) +{ + // Check the value of step + if (step <= T(0)) return LS_INVALID_STEP; + + // Save the function value at the current x + const T fx_init = fx; + // Projection of gradient on the search direction + const T dg_init = dot(grad, drt, dev_scalar, stream); + // Make sure d points to a descent direction + if (dg_init > 0) return LS_INVALID_DIR; + + const T dg_test = param.ftol * dg_init; + T width; + + RAFT_LOG_TRACE("Starting line search fx_init=%f, dg_init=%f", fx_init, dg_init); + + int iter; + for (iter = 0; iter < param.max_linesearch; iter++) { + // x_{k+1} = x_k + step * d_k + x.axpy(step, drt, xp, stream); + // Evaluate this candidate + fx = f(x, grad, dev_scalar, stream); + RAFT_LOG_TRACE("Line search iter %d, fx=%f", iter, fx); + // if (is_success(fx_init, dg_init, fx, dg_test, step, grad, drt, &width)) + if (ls_success( + param, fx_init, dg_init, fx, dg_test, step, grad, drt, &width, dev_scalar, stream)) + return LS_SUCCESS; + + if (step < param.min_step) return LS_INVALID_STEP_MIN; + + if (step > param.max_step) return LS_INVALID_STEP_MAX; + + step *= width; + } + return LS_MAX_ITERS_REACHED; +} + +template +LINE_SEARCH_RETCODE ls_backtrack_projected(const LBFGSParam& param, + Function& f, + T& fx, + SimpleVec& x, + SimpleVec& grad, + const SimpleVec& pseudo_grad, + T& step, + const SimpleVec& drt, + const SimpleVec& xp, + T l1_penalty, + T* dev_scalar, + cudaStream_t stream) +{ + LSProjectedStep lsstep; + + // Check the value of step + if (step <= T(0)) 
return LS_INVALID_STEP; + + // Save the function value at the current x + const T fx_init = fx; + // Projection of gradient on the search direction + const T dg_init = dot(pseudo_grad, drt, dev_scalar, stream); + // Make sure d points to a descent direction + if (dg_init > 0) return LS_INVALID_DIR; + + const T dg_test = param.ftol * dg_init; + T width; + + int iter; + for (iter = 0; iter < param.max_linesearch; iter++) { + // x_{k+1} = proj_orth(x_k + step * d_k) + lsstep(step, x, drt, xp, pseudo_grad, stream); + // evaluates fx with l1 term, but only grad of the loss term + fx = f(x, grad, dev_scalar, stream); + + // if (is_success(fx_init, dg_init, fx, dg_test, step, pseudo_grad, drt, + // &width)) + if (ls_success( + param, fx_init, dg_init, fx, dg_test, step, pseudo_grad, drt, &width, dev_scalar, stream)) + return LS_SUCCESS; + + if (step < param.min_step) return LS_INVALID_STEP_MIN; + + if (step > param.max_step) return LS_INVALID_STEP_MAX; + + step *= width; + } + return LS_MAX_ITERS_REACHED; +} + +}; // namespace raft::solver::quasi_newton::detail diff --git a/cpp/include/raft/solver/detail/qn/qn_solvers.cuh b/cpp/include/raft/solver/detail/qn/qn_solvers.cuh new file mode 100644 index 0000000000..48c698ee4e --- /dev/null +++ b/cpp/include/raft/solver/detail/qn/qn_solvers.cuh @@ -0,0 +1,469 @@ +/* + * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +/* + * This file contains implementations of two popular Quasi-Newton methods: + * - Limited-memory Broyden Fletcher Goldfarb Shanno (L-BFGS) [Nocedal, Wright - + * Numerical Optimization (1999)] + * - Orthant-wise limited-memory quasi-newton (OWL-QN) [Andrew, Gao - ICML 2007] + * https://www.microsoft.com/en-us/research/publication/scalable-training-of-l1-regularized-log-linear-models/ + * + * L-BFGS is a classical method to solve unconstrained optimization problems of + * differentiable multi-variate functions f: R^D \mapsto R, i.e. it solves + * + * \min_{x \in R^D} f(x) + * + * iteratively by building up a m-dimensional (inverse) Hessian approximation. + * + * OWL-QN is an extension of L-BFGS that is specifically designed to optimize + * functions of the form + * + * f(x) + \lambda * \sum_i |x_i|, + * + * i.e. functions with an l1 penalty, by leveraging that |z| is differentiable + * when restricted to an orthant. + * + */ + +#include "qn_linesearch.cuh" +#include "qn_util.cuh" +#include "simple_mat.cuh" +#include +#include +#include + +namespace raft::solver::quasi_newton::detail { + +// TODO better way to deal with alignment? Smaller aligne possible? 
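+
+// For reference: qn_minimize() at the bottom of this file sizes a workspace with the
+// helpers below and dispatches to min_lbfgs / min_owlqn. A minimal usage sketch
+// (illustrative only; `objective_with_data`, `w_dev_ptr` and `n_param` are assumed to be
+// provided by the caller, e.g. a QNWithData wrapping one of the losses in objectives/):
+//
+//   LBFGSParam<double> opt_param;              // default L-BFGS/OWL-QN settings
+//   SimpleVec<double> w(w_dev_ptr, n_param);   // initial point, holds the result
+//   double fx;
+//   int num_iters;
+//   qn_minimize(handle, w, &fx, &num_iters, objective_with_data,
+//               /* l1 = */ 0.0, opt_param, handle.get_stream());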
+constexpr size_t qn_align = 256;
+
+template <typename T>
+inline size_t lbfgs_workspace_size(const LBFGSParam<T>& param, const int n)
+{
+  size_t mat_size = raft::alignTo(sizeof(T) * param.m * n, qn_align);
+  size_t vec_size = raft::alignTo(sizeof(T) * n, qn_align);
+  return 2 * mat_size + 4 * vec_size + qn_align;
+}
+
+template <typename T>
+inline size_t owlqn_workspace_size(const LBFGSParam<T>& param, const int n)
+{
+  size_t vec_size = raft::alignTo(sizeof(T) * n, qn_align);
+  return lbfgs_workspace_size(param, n) + vec_size;
+}
+
+template <typename T>
+inline bool update_and_check(const char* solver,
+                             const LBFGSParam<T>& param,
+                             int iter,
+                             LINE_SEARCH_RETCODE lsret,
+                             T& fx,
+                             T& fxp,
+                             const T& gnorm,
+                             ML::SimpleVec<T>& x,
+                             ML::SimpleVec<T>& xp,
+                             ML::SimpleVec<T>& grad,
+                             ML::SimpleVec<T>& gradp,
+                             std::vector<T>& fx_hist,
+                             T* dev_scalar,
+                             OPT_RETCODE& outcode,
+                             cudaStream_t stream)
+{
+  bool stop      = false;
+  bool converged = false;
+  bool isLsValid = !isnan(fx) && !isinf(fx);
+  // Linesearch may fail to converge, but still come closer to the solution;
+  // if that is not the case, let `check_convergence` ("insufficient change")
+  // below terminate the loop.
+  bool isLsNonCritical = lsret == LS_INVALID_STEP_MIN || lsret == LS_MAX_ITERS_REACHED;
+  // If the error is not critical, check that the target function does not grow.
+  // This shouldn't really happen, but weird things can happen if the convergence
+  // thresholds are too small.
+  bool isLsInDoubt = isLsValid && fx <= fxp + param.ftol && isLsNonCritical;
+  bool isLsSuccess = lsret == LS_SUCCESS || isLsInDoubt;
+
+  RAFT_LOG_TRACE("%s iteration %d, fx=%f", solver, iter, fx);
+
+  // if the target is at least finite, we can check the convergence
+  if (isLsValid) converged = check_convergence(param, iter, fx, gnorm, fx_hist);
+
+  if (!isLsSuccess && !converged) {
+    RAFT_LOG_WARN(
+      "%s line search failed (code %d); stopping at the last valid step", solver, lsret);
+    outcode = OPT_LS_FAILED;
+    stop    = true;
+  } else if (!isLsValid) {
+    RAFT_LOG_ERROR(
+      "%s error fx=%f at iteration %d; stopping at the last valid step", solver, fx, iter);
+    outcode = OPT_NUMERIC_ERROR;
+    stop    = true;
+  } else if (converged) {
+    RAFT_LOG_DEBUG("%s converged", solver);
+    outcode = OPT_SUCCESS;
+    stop    = true;
+  } else if (isLsInDoubt && fx + param.ftol >= fxp) {
+    // If a non-critical error has happened during the line search, check if the target
+    // is improved at least a bit. Otherwise, stop to avoid spinning till the iteration limit.
+    RAFT_LOG_WARN(
+      "%s stopped, because the line search failed to advance (step delta = %f)", solver, fx - fxp);
+    outcode = OPT_LS_FAILED;
+    stop    = true;
+  }
+
+  // if line search wasn't successful, undo the update.
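+  // (the caller is thus always left with the last valid iterate, loss value and gradient)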
+ if (!isLsSuccess || !isLsValid) { + fx = fxp; + x.copy_async(xp, stream); + grad.copy_async(gradp, stream); + } + + return stop; +} + +template +inline OPT_RETCODE min_lbfgs(const LBFGSParam& param, + Function& f, // function to minimize + SimpleVec& x, // initial point, holds result + T& fx, // output function value + int* k, // output iterations + SimpleVec& workspace, // scratch space + cudaStream_t stream, + int verbosity = 0) +{ + int n = x.len; + const int workspace_size = lbfgs_workspace_size(param, n); + ASSERT(workspace.len >= workspace_size, "LBFGS: workspace insufficient"); + + // SETUP WORKSPACE + size_t mat_size = raft::alignTo(sizeof(T) * param.m * n, qn_align); + size_t vec_size = raft::alignTo(sizeof(T) * n, qn_align); + T* p_ws = workspace.data; + SimpleDenseMat S(p_ws, n, param.m); + p_ws += mat_size; + SimpleDenseMat Y(p_ws, n, param.m); + p_ws += mat_size; + SimpleVec xp(p_ws, n); + p_ws += vec_size; + SimpleVec grad(p_ws, n); + p_ws += vec_size; + SimpleVec gradp(p_ws, n); + p_ws += vec_size; + SimpleVec drt(p_ws, n); + p_ws += vec_size; + T* dev_scalar = p_ws; + + SimpleVec svec, yvec; // mask vectors + + std::vector ys(param.m); + std::vector alpha(param.m); + std::vector fx_hist(param.past > 0 ? param.past : 0); + + *k = 0; + ML::Logger::get().setLevel(verbosity); + RAFT_LOG_DEBUG("Running L-BFGS"); + + // Evaluate function and compute gradient + fx = f(x, grad, dev_scalar, stream); + T gnorm = f.gradNorm(grad, dev_scalar, stream); + + if (param.past > 0) fx_hist[0] = fx; + + // Early exit if the initial x is already a minimizer + if (check_convergence(param, *k, fx, gnorm, fx_hist)) { + RAFT_LOG_DEBUG("Initial solution fulfills optimality condition."); + return OPT_SUCCESS; + } + + // Initial direction + drt.ax(-1.0, grad, stream); + + // Initial step + T step = T(1.0) / nrm2(drt, dev_scalar, stream); + T fxp = fx; + + *k = 1; + int end = 0; + int n_vec = 0; // number of vector updates made in lbfgs_search_dir + OPT_RETCODE retcode; + LINE_SEARCH_RETCODE lsret; + for (; *k <= param.max_iterations; (*k)++) { + // Save the curent x and gradient + xp.copy_async(x, stream); + gradp.copy_async(grad, stream); + fxp = fx; + + // Line search to update x, fx and gradient + lsret = ls_backtrack(param, f, fx, x, grad, step, drt, xp, dev_scalar, stream); + gnorm = f.gradNorm(grad, dev_scalar, stream); + + if (update_and_check("L-BFGS", + param, + *k, + lsret, + fx, + fxp, + gnorm, + x, + xp, + grad, + gradp, + fx_hist, + dev_scalar, + retcode, + stream)) + return retcode; + + // Update s and y + // s_{k+1} = x_{k+1} - x_k + // y_{k+1} = g_{k+1} - g_k + col_ref(S, svec, end); + col_ref(Y, yvec, end); + svec.axpy(-1.0, xp, x, stream); + yvec.axpy(-1.0, gradp, grad, stream); + // drt <- -H * g + end = lbfgs_search_dir( + param, &n_vec, end, S, Y, grad, svec, yvec, drt, ys, alpha, dev_scalar, stream); + + // step = 1.0 as initial guess + step = T(1.0); + } + RAFT_LOG_WARN("L-BFGS: max iterations reached"); + return OPT_MAX_ITERS_REACHED; +} + +template +inline void update_pseudo(const SimpleVec& x, + const SimpleVec& grad, + const op_pseudo_grad& pseudo_grad, + const int pg_limit, + SimpleVec& pseudo, + cudaStream_t stream) +{ + if (grad.len > pg_limit) { + pseudo.copy_async(grad, stream); + SimpleVec mask(pseudo.data, pg_limit); + mask.assign_binary(x, grad, pseudo_grad, stream); + } else { + pseudo.assign_binary(x, grad, pseudo_grad, stream); + } +} + +template +inline OPT_RETCODE min_owlqn(const LBFGSParam& param, + Function& f, + const T l1_penalty, + const int pg_limit, + 
SimpleVec& x, + T& fx, + int* k, + SimpleVec& workspace, // scratch space + cudaStream_t stream, + const int verbosity = 0) +{ + int n = x.len; + const int workspace_size = owlqn_workspace_size(param, n); + ASSERT(workspace.len >= workspace_size, "LBFGS: workspace insufficient"); + ASSERT(pg_limit <= n && pg_limit > 0, "OWL-QN: Invalid pseudo grad limit parameter"); + + // SETUP WORKSPACE + size_t mat_size = raft::alignTo(sizeof(T) * param.m * n, qn_align); + size_t vec_size = raft::alignTo(sizeof(T) * n, qn_align); + T* p_ws = workspace.data; + SimpleDenseMat S(p_ws, n, param.m); + p_ws += mat_size; + SimpleDenseMat Y(p_ws, n, param.m); + p_ws += mat_size; + SimpleVec xp(p_ws, n); + p_ws += vec_size; + SimpleVec grad(p_ws, n); + p_ws += vec_size; + SimpleVec gradp(p_ws, n); + p_ws += vec_size; + SimpleVec drt(p_ws, n); + p_ws += vec_size; + SimpleVec pseudo(p_ws, n); + p_ws += vec_size; + T* dev_scalar = p_ws; + + ML::Logger::get().setLevel(verbosity); + + SimpleVec svec, yvec; // mask vectors + + std::vector ys(param.m); + std::vector alpha(param.m); + std::vector fx_hist(param.past > 0 ? param.past : 0); + + op_project project_neg(T(-1.0)); + + auto f_wrap = [&f, &l1_penalty, &pg_limit]( + SimpleVec& x, SimpleVec& grad, T* dev_scalar, cudaStream_t stream) { + T tmp = f(x, grad, dev_scalar, stream); + SimpleVec mask(x.data, pg_limit); + return tmp + l1_penalty * nrm1(mask, dev_scalar, stream); + }; + + *k = 0; + RAFT_LOG_DEBUG("Running OWL-QN with lambda=%f", l1_penalty); + + // op to compute the pseudo gradients + op_pseudo_grad pseudo_grad(l1_penalty); + + fx = f_wrap(x, grad, dev_scalar, + stream); // fx is loss+regularizer, grad is grad of loss only + T gnorm = f.gradNorm(grad, dev_scalar, stream); + + // compute pseudo grad, but don't overwrite grad: used to build H + // pseudo.assign_binary(x, grad, pseudo_grad); + update_pseudo(x, grad, pseudo_grad, pg_limit, pseudo, stream); + + if (param.past > 0) fx_hist[0] = fx; + + // Early exit if the initial x is already a minimizer + if (check_convergence(param, *k, fx, gnorm, fx_hist)) { + RAFT_LOG_DEBUG("Initial solution fulfills optimality condition."); + return OPT_SUCCESS; + } + + // Initial direction + drt.ax(-1.0, pseudo, stream); // using Pseudo gradient here + // below should be done for consistency but seems unnecessary + // drt.assign_k_ary(project, pseudo, x); + + // Initial step + T step = T(1.0) / std::max(T(1), nrm2(drt, dev_scalar, stream)); + T fxp = fx; + + int end = 0; + int n_vec = 0; // number of vector updates made in lbfgs_search_dir + OPT_RETCODE retcode; + LINE_SEARCH_RETCODE lsret; + for ((*k) = 1; (*k) <= param.max_iterations; (*k)++) { + // Save the curent x and gradient + xp.copy_async(x, stream); + gradp.copy_async(grad, stream); + fxp = fx; + + // Projected line search to update x, fx and gradient + lsret = ls_backtrack_projected( + param, f_wrap, fx, x, grad, pseudo, step, drt, xp, l1_penalty, dev_scalar, stream); + gnorm = f.gradNorm(grad, dev_scalar, stream); + + if (update_and_check("QWL-QN", + param, + *k, + lsret, + fx, + fxp, + gnorm, + x, + xp, + grad, + gradp, + fx_hist, + dev_scalar, + retcode, + stream)) + return retcode; + + // recompute pseudo + // pseudo.assign_binary(x, grad, pseudo_grad); + update_pseudo(x, grad, pseudo_grad, pg_limit, pseudo, stream); + + // Update s and y - We should only do this if there is no skipping condition + + col_ref(S, svec, end); + col_ref(Y, yvec, end); + svec.axpy(-1.0, xp, x, stream); + yvec.axpy(-1.0, gradp, grad, stream); + // drt <- -H * -> pseudo grad <- + 
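+    // Per OWL-QN, the curvature pairs (S, Y) above use the gradient of the smooth loss
+    // term only, while the search direction below is computed against the pseudo-gradient
+    // and then projected onto the orthant of -pseudo.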
end = lbfgs_search_dir( + param, &n_vec, end, S, Y, pseudo, svec, yvec, drt, ys, alpha, dev_scalar, stream); + + // Project drt onto orthant of -pseudog + drt.assign_binary(drt, pseudo, project_neg, stream); + + // step = 1.0 as initial guess + step = T(1.0); + } + RAFT_LOG_WARN("QWL-QN: max iterations reached"); + return OPT_MAX_ITERS_REACHED; +} +/* + * Chooses the right algorithm, depending on presence of l1 term + */ +template +inline int qn_minimize(const raft::handle_t& handle, + SimpleVec& x, + T* fx, + int* num_iters, + LossFunction& loss, + const T l1, + const LBFGSParam& opt_param, + cudaStream_t stream, + const int verbosity = 0) +{ + // TODO should the worksapce allocation happen outside? + OPT_RETCODE ret; + if (l1 == 0.0) { + rmm::device_uvector tmp(lbfgs_workspace_size(opt_param, x.len), stream); + SimpleVec workspace(tmp.data(), tmp.size()); + + ret = min_lbfgs(opt_param, + loss, // function to minimize + x, // initial point, holds result + *fx, // output function value + num_iters, // output iterations + workspace, // scratch space + stream, + verbosity); + + RAFT_LOG_DEBUG("L-BFGS Done"); + } else { + // There might not be a better way to deal with dispatching + // for the l1 case: + // The algorithm explicitely expects a differentiable + // function f(x). It takes care of adding and + // handling the term l1norm(x) * l1_pen explicitely, i.e. + // it needs to evaluate f(x) and its gradient separately + + rmm::device_uvector tmp(owlqn_workspace_size(opt_param, x.len), stream); + SimpleVec workspace(tmp.data(), tmp.size()); + + ret = min_owlqn(opt_param, + loss, // function to minimize + l1, + loss.D * loss.C, + x, // initial point, holds result + *fx, // output function value + num_iters, // output iterations + workspace, // scratch space + stream, + verbosity); + + RAFT_LOG_DEBUG("OWL-QN Done"); + } + if (ret == OPT_MAX_ITERS_REACHED) { + RAFT_LOG_WARN( + "Maximum iterations reached before solver is converged. To increase " + "model accuracy you can increase the number of iterations (max_iter) or " + "improve the scaling of the input data."); + } + return ret; +} + +}; // namespace raft::solver::quasi_newton::detail \ No newline at end of file diff --git a/cpp/include/raft/solver/detail/qn/qn_util.cuh b/cpp/include/raft/solver/detail/qn/qn_util.cuh new file mode 100644 index 0000000000..0124d248c9 --- /dev/null +++ b/cpp/include/raft/solver/detail/qn/qn_util.cuh @@ -0,0 +1,170 @@ +/* + * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +namespace raft::solver::quasi_newton::detail { + + + +inline bool qn_is_classification(qn_loss_type t) +{ + switch (t) { + case QN_LOSS_LOGISTIC: + case QN_LOSS_SOFTMAX: + case QN_LOSS_SVC_L1: + case QN_LOSS_SVC_L2: return true; + default: return false; + } +} + +template +HDI T project_orth(T x, T y) { + return x * y <= T(0) ? 
T(0) : x; +} + +template +inline bool check_convergence( + const LBFGSParam& param, const int k, const T fx, const T gnorm, std::vector& fx_hist) +{ + // Positive scale factor for the stop condition + T fmag = std::max(fx, param.epsilon); + + RAFT_LOG_DEBUG( + "%04d: f(x)=%.8f conv.crit=%.8f (gnorm=%.8f, fmag=%.8f)", k, fx, gnorm / fmag, gnorm, fmag); + // Convergence test -- gradient + if (gnorm <= param.epsilon * fmag) { + RAFT_LOG_DEBUG("Converged after %d iterations: f(x)=%.6f", k, fx); + return true; + } + // Convergence test -- objective function value + if (param.past > 0) { + if (k >= param.past && std::abs(fx_hist[k % param.past] - fx) <= param.delta * fmag) { + RAFT_LOG_DEBUG("Insufficient change in objective value"); + return true; + } + + fx_hist[k % param.past] = fx; + } + return false; +} + +/* + * Multiplies a vector g with the inverse hessian approximation, i.e. + * drt = - H * g, + * e.g. to compute the new search direction for g = \nabla f(x) + */ +template +inline int lbfgs_search_dir(const LBFGSParam& param, + int* n_vec, + const int end_prev, + const SimpleDenseMat& S, + const SimpleDenseMat& Y, + const SimpleVec& g, + const SimpleVec& svec, + const SimpleVec& yvec, + SimpleVec& drt, + std::vector& yhist, + std::vector& alpha, + T* dev_scalar, + cudaStream_t stream) +{ + SimpleVec sj, yj; // mask vectors + int end = end_prev; + // note: update_state assigned svec, yvec to m_s[:,end], m_y[:,end] + T ys = dot(svec, yvec, dev_scalar, stream); + T yy = dot(yvec, yvec, dev_scalar, stream); + RAFT_LOG_TRACE("ys=%e, yy=%e", ys, yy); + // Skipping test: + if (ys <= std::numeric_limits::epsilon() * yy) { + // We can land here for example if yvec == 0 (no change in the gradient, + // g_k == g_k+1). That means the Hessian is approximately zero. We cannot + // use the QN model to update the search dir, we just continue along the + // previous direction. + // + // See eq (3.9) and Section 6 in "A limited memory algorithm for bound + // constrained optimization" Richard H. Byrd, Peihuang Lu, Jorge Nocedal and + // Ciyou Zhu Technical Report NAM-08 (1994) NORTHWESTERN UNIVERSITY. + // + // Alternative condition to skip update is: ys / (-gs) <= epsmch, + // (where epsmch = std::numeric_limits::epsilon) given in Section 5 of + // "L-BFGS-B Fortran subroutines for large-scale bound constrained + // optimization" Ciyou Zhu, Richard H. Byrd, Peihuang Lu and Jorge Nocedal + // (1994). 
+ RAFT_LOG_DEBUG("L-BFGS WARNING: skipping update step ys=%f, yy=%f", ys, yy); + return end; + } + (*n_vec)++; + yhist[end] = ys; + + // Recursive formula to compute d = -H * g + drt.ax(-1.0, g, stream); + int bound = std::min(param.m, *n_vec); + end = (end + 1) % param.m; + int j = end; + for (int i = 0; i < bound; i++) { + j = (j + param.m - 1) % param.m; + col_ref(S, sj, j); + col_ref(Y, yj, j); + alpha[j] = dot(sj, drt, dev_scalar, stream) / yhist[j]; + drt.axpy(-alpha[j], yj, drt, stream); + } + + drt.ax(ys / yy, drt, stream); + + for (int i = 0; i < bound; i++) { + col_ref(S, sj, j); + col_ref(Y, yj, j); + T beta = dot(yj, drt, dev_scalar, stream) / yhist[j]; + drt.axpy((alpha[j] - beta), sj, drt, stream); + j = (j + 1) % param.m; + } + + return end; +} + +template +HDI T get_pseudo_grad(T x, T dlossx, T C) +{ + if (x != 0) { return dlossx + raft::sgn(x) * C; } + T dplus = dlossx + C; + T dmins = dlossx - C; + if (dmins > T(0)) return dmins; + if (dplus < T(0)) return dplus; + return T(0); +} + +template +struct op_project { + T scal; + op_project(T s) : scal(s) {} + + HDI T operator()(const T x, const T y) const { return project_orth(x, scal * y); } +}; + +template +struct op_pseudo_grad { + T l1; + op_pseudo_grad(const T lam) : l1(lam) {} + + HDI T operator()(const T x, const T dlossx) const { return get_pseudo_grad(x, dlossx, l1); } +}; + +}; // namespace raft::solver::quasi_newton::detail diff --git a/cpp/include/raft/solver/detail/qn/simple_mat.cuh b/cpp/include/raft/solver/detail/qn/simple_mat.cuh new file mode 100644 index 0000000000..f455f6a1e1 --- /dev/null +++ b/cpp/include/raft/solver/detail/qn/simple_mat.cuh @@ -0,0 +1,20 @@ +/* + * Copyright (c) 2018-2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "simple_mat/base.hpp" +#include "simple_mat/dense.hpp" +#include "simple_mat/sparse.hpp" diff --git a/cpp/include/raft/solver/detail/qn/simple_mat/base.hpp b/cpp/include/raft/solver/detail/qn/simple_mat/base.hpp new file mode 100644 index 0000000000..2e9ae5dfcd --- /dev/null +++ b/cpp/include/raft/solver/detail/qn/simple_mat/base.hpp @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include +#include +#include + +namespace raft::solver::detail { + +template +struct SimpleDenseMat; + +template +struct SimpleMat { + int m, n; + + SimpleMat(int m, int n) : m(m), n(n) {} + + void operator=(const SimpleMat& other) = delete; + + virtual void print(std::ostream& oss) const = 0; + + /** + * GEMM assigning to C where `this` refers to B. + * + * ``` + * C <- alpha * A^transA * (*this)^transB + beta * C + * ``` + */ + virtual void gemmb(const raft::handle_t& handle, + const T alpha, + const SimpleDenseMat& A, + const bool transA, + const bool transB, + const T beta, + SimpleDenseMat& C, + cudaStream_t stream) const = 0; +}; + +}; // namespace raft::solver::detail diff --git a/cpp/include/raft/solver/detail/qn/simple_mat/dense.hpp b/cpp/include/raft/solver/detail/qn/simple_mat/dense.hpp new file mode 100644 index 0000000000..c915586cc2 --- /dev/null +++ b/cpp/include/raft/solver/detail/qn/simple_mat/dense.hpp @@ -0,0 +1,413 @@ +/* + * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include + +#include "base.hpp" +#include +#include +#include +#include +#include +// #TODO: Replace with public header when ready +#include +#include +#include +#include +#include + +namespace raft::solver::detail { + +enum STORAGE_ORDER { COL_MAJOR = 0, ROW_MAJOR = 1 }; + +template +struct SimpleDenseMat : SimpleMat { + typedef SimpleMat Super; + int len; + T* data; + + STORAGE_ORDER ord; // storage order: runtime param for compile time sake + + SimpleDenseMat(STORAGE_ORDER order = COL_MAJOR) : Super(0, 0), data(nullptr), len(0), ord(order) + { + } + + SimpleDenseMat(T* data, int m, int n, STORAGE_ORDER order = COL_MAJOR) + : Super(m, n), data(data), len(m * n), ord(order) + { + } + + void reset(T* data_, int m_, int n_) + { + this->m = m_; + this->n = n_; + data = data_; + len = m_ * n_; + } + + // Implemented GEMM as a static method here to improve readability + inline static void gemm(const raft::handle_t& handle, + const T alpha, + const SimpleDenseMat& A, + const bool transA, + const SimpleDenseMat& B, + const bool transB, + const T beta, + SimpleDenseMat& C, + cudaStream_t stream) + { + int kA = A.n; + int kB = B.m; + + if (transA) { + ASSERT(A.n == C.m, "GEMM invalid dims: m"); + kA = A.m; + } else { + ASSERT(A.m == C.m, "GEMM invalid dims: m"); + } + + if (transB) { + ASSERT(B.m == C.n, "GEMM invalid dims: n"); + kB = B.n; + } else { + ASSERT(B.n == C.n, "GEMM invalid dims: n"); + } + ASSERT(kA == kB, "GEMM invalid dims: k"); + + if (A.ord == COL_MAJOR && B.ord == COL_MAJOR && C.ord == COL_MAJOR) { + // #TODO: Call from public API when ready + raft::linalg::detail::cublasgemm(handle.get_cublas_handle(), // handle + transA ? CUBLAS_OP_T : CUBLAS_OP_N, // transA + transB ? 
CUBLAS_OP_T : CUBLAS_OP_N, // transB + C.m, + C.n, + kA, // dimensions m,n,k + &alpha, + A.data, + A.m, // lda + B.data, + B.m, // ldb + &beta, + C.data, + C.m, // ldc, + stream); + return; + } + if (A.ord == ROW_MAJOR) { + const SimpleDenseMat Acm(A.data, A.n, A.m, COL_MAJOR); + gemm(handle, alpha, Acm, !transA, B, transB, beta, C, stream); + return; + } + if (B.ord == ROW_MAJOR) { + const SimpleDenseMat Bcm(B.data, B.n, B.m, COL_MAJOR); + gemm(handle, alpha, A, transA, Bcm, !transB, beta, C, stream); + return; + } + if (C.ord == ROW_MAJOR) { + SimpleDenseMat Ccm(C.data, C.n, C.m, COL_MAJOR); + gemm(handle, alpha, B, !transB, A, !transA, beta, Ccm, stream); + return; + } + } + + inline void gemmb(const raft::handle_t& handle, + const T alpha, + const SimpleDenseMat& A, + const bool transA, + const bool transB, + const T beta, + SimpleDenseMat& C, + cudaStream_t stream) const override + { + SimpleDenseMat::gemm(handle, alpha, A, transA, *this, transB, beta, C, stream); + } + + /** + * GEMM assigning to C where `this` refers to C. + * + * ``` + * *this <- alpha * A^transA * B^transB + beta * (*this) + * ``` + */ + inline void assign_gemm(const raft::handle_t& handle, + const T alpha, + const SimpleDenseMat& A, + const bool transA, + const SimpleMat& B, + const bool transB, + const T beta, + cudaStream_t stream) + { + B.gemmb(handle, alpha, A, transA, transB, beta, *this, stream); + } + + // this = a*x + inline void ax(const T a, const SimpleDenseMat& x, cudaStream_t stream) + { + ASSERT(ord == x.ord, "SimpleDenseMat::ax: Storage orders must match"); + + auto scale = [a] __device__(const T x) { return a * x; }; + raft::linalg::unaryOp(data, x.data, len, scale, stream); + } + + // this = a*x + y + inline void axpy(const T a, + const SimpleDenseMat& x, + const SimpleDenseMat& y, + cudaStream_t stream) + { + ASSERT(ord == x.ord, "SimpleDenseMat::axpy: Storage orders must match"); + ASSERT(ord == y.ord, "SimpleDenseMat::axpy: Storage orders must match"); + + auto axpy = [a] __device__(const T x, const T y) { return a * x + y; }; + raft::linalg::binaryOp(data, x.data, y.data, len, axpy, stream); + } + + template + inline void assign_unary(const SimpleDenseMat& other, Lambda f, cudaStream_t stream) + { + ASSERT(ord == other.ord, "SimpleDenseMat::assign_unary: Storage orders must match"); + + raft::linalg::unaryOp(data, other.data, len, f, stream); + } + + template + inline void assign_binary(const SimpleDenseMat& other1, + const SimpleDenseMat& other2, + Lambda& f, + cudaStream_t stream) + { + ASSERT(ord == other1.ord, "SimpleDenseMat::assign_binary: Storage orders must match"); + ASSERT(ord == other2.ord, "SimpleDenseMat::assign_binary: Storage orders must match"); + + raft::linalg::binaryOp(data, other1.data, other2.data, len, f, stream); + } + + template + inline void assign_ternary(const SimpleDenseMat& other1, + const SimpleDenseMat& other2, + const SimpleDenseMat& other3, + Lambda& f, + cudaStream_t stream) + { + ASSERT(ord == other1.ord, "SimpleDenseMat::assign_ternary: Storage orders must match"); + ASSERT(ord == other2.ord, "SimpleDenseMat::assign_ternary: Storage orders must match"); + ASSERT(ord == other3.ord, "SimpleDenseMat::assign_ternary: Storage orders must match"); + + raft::linalg::ternaryOp(data, other1.data, other2.data, other3.data, len, f, stream); + } + + inline void fill(const T val, cudaStream_t stream) + { + // TODO this reads data unnecessary, though it's mostly used for testing + auto f = [val] __device__(const T x) { return val; }; + raft::linalg::unaryOp(data, data, 
len, f, stream); + } + + inline void copy_async(const SimpleDenseMat& other, cudaStream_t stream) + { + ASSERT((ord == other.ord) && (this->m == other.m) && (this->n == other.n), + "SimpleDenseMat::copy: matrices not compatible"); + + RAFT_CUDA_TRY( + cudaMemcpyAsync(data, other.data, len * sizeof(T), cudaMemcpyDeviceToDevice, stream)); + } + + void print(std::ostream& oss) const override { oss << (*this) << std::endl; } + + void operator=(const SimpleDenseMat& other) = delete; +}; + +template +struct SimpleVec : SimpleDenseMat { + typedef SimpleDenseMat Super; + + SimpleVec(T* data, const int n) : Super(data, n, 1, COL_MAJOR) {} + // this = alpha * A * x + beta * this + void assign_gemv(const raft::handle_t& handle, + const T alpha, + const SimpleDenseMat& A, + bool transA, + const SimpleVec& x, + const T beta, + cudaStream_t stream) + { + Super::assign_gemm(handle, alpha, A, transA, x, false, beta, stream); + } + + SimpleVec() : Super(COL_MAJOR) {} + + inline void reset(T* new_data, int n) { Super::reset(new_data, n, 1); } +}; + +template +inline void col_ref(const SimpleDenseMat& mat, SimpleVec& mask_vec, int c) +{ + ASSERT(mat.ord == COL_MAJOR, "col_ref only available for column major mats"); + T* tmp = &mat.data[mat.m * c]; + mask_vec.reset(tmp, mat.m); +} + +template +inline void col_slice(const SimpleDenseMat& mat, + SimpleDenseMat& mask_mat, + int c_from, + int c_to) +{ + ASSERT(c_from >= 0 && c_from < mat.n, "col_slice: invalid from"); + ASSERT(c_to >= 0 && c_to <= mat.n, "col_slice: invalid to"); + + ASSERT(mat.ord == COL_MAJOR, "col_ref only available for column major mats"); + ASSERT(mask_mat.ord == COL_MAJOR, "col_ref only available for column major mask"); + T* tmp = &mat.data[mat.m * c_from]; + mask_mat.reset(tmp, mat.m, c_to - c_from); +} + +// Reductions such as dot or norm require an additional location in dev mem +// to hold the result. 
We don't want to deal with this in the SimpleVec class +// as it impedes thread safety and constness + +template +inline T dot(const SimpleVec& u, const SimpleVec& v, T* tmp_dev, cudaStream_t stream) +{ + auto f = [] __device__(const T x, const T y) { return x * y; }; + raft::linalg::mapThenSumReduce(tmp_dev, u.len, f, stream, u.data, v.data); + T tmp_host; + raft::update_host(&tmp_host, tmp_dev, 1, stream); + + raft::interruptible::synchronize(stream); + return tmp_host; +} + +template +inline T squaredNorm(const SimpleVec& u, T* tmp_dev, cudaStream_t stream) +{ + return dot(u, u, tmp_dev, stream); +} + +template +inline T nrmMax(const SimpleVec& u, T* tmp_dev, cudaStream_t stream) +{ + auto f = [] __device__(const T x) { return raft::myAbs(x); }; + auto r = [] __device__(const T x, const T y) { return raft::myMax(x, y); }; + raft::linalg::mapThenReduce(tmp_dev, u.len, T(0), f, r, stream, u.data); + T tmp_host; + raft::update_host(&tmp_host, tmp_dev, 1, stream); + raft::interruptible::synchronize(stream); + return tmp_host; +} + +template +inline T nrm2(const SimpleVec& u, T* tmp_dev, cudaStream_t stream) +{ + return raft::mySqrt(squaredNorm(u, tmp_dev, stream)); +} + +template +inline T nrm1(const SimpleVec& u, T* tmp_dev, cudaStream_t stream) +{ + raft::linalg::rowNorm( + tmp_dev, u.data, u.len, 1, raft::linalg::L1Norm, true, stream, raft::Nop()); + T tmp_host; + raft::update_host(&tmp_host, tmp_dev, 1, stream); + raft::interruptible::synchronize(stream); + return tmp_host; +} + +template +std::ostream& operator<<(std::ostream& os, const SimpleVec& v) +{ + std::vector out(v.len); + raft::update_host(&out[0], v.data, v.len, 0); + raft::interruptible::synchronize(rmm::cuda_stream_view()); + int it = 0; + for (; it < v.len - 1;) { + os << out[it] << " "; + it++; + } + os << out[it]; + return os; +} + +template +std::ostream& operator<<(std::ostream& os, const SimpleDenseMat& mat) +{ + os << "ord=" << (mat.ord == COL_MAJOR ? 
"CM" : "RM") << "\n"; + std::vector out(mat.len); + raft::update_host(&out[0], mat.data, mat.len, rmm::cuda_stream_default); + raft::interruptible::synchronize(rmm::cuda_stream_view()); + if (mat.ord == COL_MAJOR) { + for (int r = 0; r < mat.m; r++) { + int idx = r; + for (int c = 0; c < mat.n - 1; c++) { + os << out[idx] << ","; + idx += mat.m; + } + os << out[idx] << std::endl; + } + } else { + for (int c = 0; c < mat.m; c++) { + int idx = c * mat.n; + for (int r = 0; r < mat.n - 1; r++) { + os << out[idx] << ","; + idx += 1; + } + os << out[idx] << std::endl; + } + } + + return os; +} + +template +struct SimpleVecOwning : SimpleVec { + typedef SimpleVec Super; + typedef rmm::device_uvector Buffer; + Buffer buf; + + SimpleVecOwning() = delete; + + SimpleVecOwning(int n, cudaStream_t stream) : Super(), buf(n, stream) + { + Super::reset(buf.data(), n); + } + + void operator=(const SimpleVec& other) = delete; +}; + +template +struct SimpleMatOwning : SimpleDenseMat { + typedef SimpleDenseMat Super; + typedef rmm::device_uvector Buffer; + Buffer buf; + using Super::m; + using Super::n; + using Super::ord; + + SimpleMatOwning() = delete; + + SimpleMatOwning(int m, int n, cudaStream_t stream, STORAGE_ORDER order = COL_MAJOR) + : Super(order), buf(m * n, stream) + { + Super::reset(buf.data(), m, n); + } + + void operator=(const SimpleVec& other) = delete; +}; + +}; // namespace raft::solver::detail diff --git a/cpp/include/raft/solver/detail/qn/simple_mat/sparse.hpp b/cpp/include/raft/solver/detail/qn/simple_mat/sparse.hpp new file mode 100644 index 0000000000..1d2a025ccd --- /dev/null +++ b/cpp/include/raft/solver/detail/qn/simple_mat/sparse.hpp @@ -0,0 +1,216 @@ +/* + * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include + +#include "base.hpp" +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include + +namespace raft::solver::detail { + +/** + * Sparse matrix in CSR format. + * + * Note, we use cuSPARSE to manimulate matrices, and it guarantees: + * + * 1. row_ids[m] == nnz + * 2. cols are sorted within rows. + * + * However, when the data comes from the outside, we cannot guarantee that. 
+ */ +template +struct SimpleSparseMat : SimpleMat { + typedef SimpleMat Super; + T* values; + int* cols; + int* row_ids; + int nnz; + + SimpleSparseMat() : Super(0, 0), values(nullptr), cols(nullptr), row_ids(nullptr), nnz(0) {} + + SimpleSparseMat(T* values, int* cols, int* row_ids, int nnz, int m, int n) + : Super(m, n), values(values), cols(cols), row_ids(row_ids), nnz(nnz) + { + check_csr(*this, 0); + } + + void print(std::ostream& oss) const override { oss << (*this) << std::endl; } + + void operator=(const SimpleSparseMat& other) = delete; + + inline void gemmb(const raft::handle_t& handle, + const T alpha, + const SimpleDenseMat& A, + const bool transA, + const bool transB, + const T beta, + SimpleDenseMat& C, + cudaStream_t stream) const override + { + const SimpleSparseMat& B = *this; + int kA = A.n; + int kB = B.m; + + if (transA) { + ASSERT(A.n == C.m, "GEMM invalid dims: m"); + kA = A.m; + } else { + ASSERT(A.m == C.m, "GEMM invalid dims: m"); + } + + if (transB) { + ASSERT(B.m == C.n, "GEMM invalid dims: n"); + kB = B.n; + } else { + ASSERT(B.n == C.n, "GEMM invalid dims: n"); + } + ASSERT(kA == kB, "GEMM invalid dims: k"); + + // matrix C must change the order and be transposed, because we need + // to swap arguments A and B in cusparseSpMM. + cusparseDnMatDescr_t descrC; + auto order = C.ord == COL_MAJOR ? CUSPARSE_ORDER_ROW : CUSPARSE_ORDER_COL; + RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsecreatednmat( + &descrC, C.n, C.m, order == CUSPARSE_ORDER_COL ? C.n : C.m, C.data, order)); + + /* + The matrix A must have the same order as the matrix C in the input + of function cusparseSpMM (i.e. swapped order w.r.t. original C). + To account this requirement, I may need to flip transA (whether to transpose A). + + C C' rowsC' colsC' ldC' A A' rowsA' colsA' ldA' flipTransA + c r n m m c r n m m x + c r n m m r r m n n o + r c n m n c c m n m o + r c n m n r c n m n x + + where: + c/r - column/row major order + A,C - input to gemmb + A', C' - input to cusparseSpMM + ldX' - leading dimension - m or n, depending on order and transX + */ + cusparseDnMatDescr_t descrA; + RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsecreatednmat(&descrA, + C.ord == A.ord ? A.n : A.m, + C.ord == A.ord ? A.m : A.n, + A.ord == COL_MAJOR ? A.m : A.n, + A.data, + order)); + auto opA = + transA ^ (C.ord == A.ord) ? CUSPARSE_OPERATION_NON_TRANSPOSE : CUSPARSE_OPERATION_TRANSPOSE; + + cusparseSpMatDescr_t descrB; + RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsecreatecsr( + &descrB, B.m, B.n, B.nnz, B.row_ids, B.cols, B.values)); + auto opB = transB ? CUSPARSE_OPERATION_NON_TRANSPOSE : CUSPARSE_OPERATION_TRANSPOSE; + + auto alg = order == CUSPARSE_ORDER_COL ? 
CUSPARSE_SPMM_CSR_ALG1 : CUSPARSE_SPMM_CSR_ALG2; + + size_t bufferSize; + RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsespmm_bufferSize(handle.get_cusparse_handle(), + opB, + opA, + &alpha, + descrB, + descrA, + &beta, + descrC, + alg, + &bufferSize, + stream)); + + raft::interruptible::synchronize(stream); + rmm::device_uvector tmp(bufferSize, stream); + + RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsespmm(handle.get_cusparse_handle(), + opB, + opA, + &alpha, + descrB, + descrA, + &beta, + descrC, + alg, + tmp.data(), + stream)); + + RAFT_CUSPARSE_TRY(cusparseDestroyDnMat(descrA)); + RAFT_CUSPARSE_TRY(cusparseDestroySpMat(descrB)); + RAFT_CUSPARSE_TRY(cusparseDestroyDnMat(descrC)); + } +}; + +template +inline void check_csr(const SimpleSparseMat& mat, cudaStream_t stream) +{ + int row_ids_nnz; + raft::update_host(&row_ids_nnz, &mat.row_ids[mat.m], 1, stream); + raft::interruptible::synchronize(stream); + ASSERT(row_ids_nnz == mat.nnz, + "SimpleSparseMat: the size of CSR row_ids array must be `m + 1`, and " + "the last element must be equal nnz."); +} + +template +std::ostream& operator<<(std::ostream& os, const SimpleSparseMat& mat) +{ + check_csr(mat, 0); + os << "SimpleSparseMat (CSR)" + << "\n"; + std::vector values(mat.nnz); + std::vector cols(mat.nnz); + std::vector row_ids(mat.m + 1); + raft::update_host(&values[0], mat.values, mat.nnz, rmm::cuda_stream_default); + raft::update_host(&cols[0], mat.cols, mat.nnz, rmm::cuda_stream_default); + raft::update_host(&row_ids[0], mat.row_ids, mat.m + 1, rmm::cuda_stream_default); + raft::interruptible::synchronize(rmm::cuda_stream_view()); + + int i, row_end = 0; + for (int row = 0; row < mat.m; row++) { + i = row_end; + row_end = row_ids[row + 1]; + for (int col = 0; col < mat.n; col++) { + if (i >= row_end || col < cols[i]) { + os << "0"; + } else { + os << values[i]; + i++; + } + if (col < mat.n - 1) os << ","; + } + + os << std::endl; + } + + return os; +} + +}; // namespace raft::solver::detail diff --git a/cpp/include/raft/solver/detail/sgd.cuh b/cpp/include/raft/solver/detail/sgd.cuh new file mode 100644 index 0000000000..c03c64d47f --- /dev/null +++ b/cpp/include/raft/solver/detail/sgd.cuh @@ -0,0 +1,422 @@ +/* + * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
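
To make the CSR layout that check_csr validates concrete, a small hand-constructed example (not part of the patch; the d_* pointers stand for device buffers assumed to be filled with the values shown):

    // A 2 x 3 matrix [[1, 0, 2], [0, 3, 0]] stored in CSR and wrapped as SimpleSparseMat.
    //   values  = {1, 2, 3}   // non-zeros in row-major order
    //   cols    = {0, 2, 1}   // column index of each value, sorted within each row
    //   row_ids = {0, 2, 3}   // m + 1 entries; row_ids[m] == nnz, which check_csr verifies
    int m = 2, n = 3, nnz = 3;
    raft::solver::detail::SimpleSparseMat<float> A(d_values, d_cols, d_row_ids, nnz, m, n);
    // A.gemmb(...) then multiplies a dense matrix by A through cusparseSpMM, using the
    // order/transpose flipping described in the table above.
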
+ */
+
+#pragma once
+
+#include "learning_rate.h"
+#include "shuffle.h"
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace raft::solver::detail {
+
+/**
+ * Fits a linear, lasso, or elastic-net regression model using the mini-batch SGD solver
+ * @param handle
+ *        Reference of raft::handle_t
+ * @param input
+ *        pointer to an array in column-major format (size of n_rows, n_cols)
+ * @param n_rows
+ *        n_samples or rows in input
+ * @param n_cols
+ *        n_features or columns in X
+ * @param labels
+ *        pointer to an array for labels (size of n_rows)
+ * @param coef
+ *        pointer to an array for coefficients (size of n_cols). This will be filled with
+ *        coefficients once the function is executed.
+ * @param intercept
+ *        pointer to a scalar for the intercept. This will be filled
+ *        once the function is executed
+ * @param fit_intercept
+ *        boolean parameter to control if the intercept will be fitted or not
+ * @param batch_size
+ *        number of rows in each minibatch
+ * @param epochs
+ *        number of passes that the solver will run over the data
+ * @param lr_type
+ *        type of the learning rate function (i.e. OPTIMAL, CONSTANT, INVSCALING, ADAPTIVE)
+ * @param eta0
+ *        learning rate for the constant lr_type. It is also used to calculate the learning
+ *        rate for the other lr_type values
+ * @param power_t
+ *        power value in the INVSCALING lr_type
+ * @param loss
+ *        enum to use different loss functions.
+ * @param penalty
+ *        None, L1, L2, or Elastic-net penalty
+ * @param alpha
+ *        regularization strength (the alpha value used in the penalty)
+ * @param l1_ratio
+ *        ratio of alpha that will be used for L1; (1 - l1_ratio) * alpha will be used for L2
+ * @param shuffle
+ *        boolean parameter to control whether the training rows are shuffled between epochs
+ * @param tol + * tolerance to stop the solver + * @param n_iter_no_change + * solver stops if there is no update greater than tol after n_iter_no_change iterations + * @param stream + * cuda stream + */ +template +void sgdFit(const raft::handle_t& handle, + math_t* input, + int n_rows, + int n_cols, + math_t* labels, + math_t* coef, + math_t* intercept, + bool fit_intercept, + int batch_size, + int epochs, + ML::lr_type lr_type, + math_t eta0, + math_t power_t, + ML::loss_funct loss, + Functions::penalty penalty, + math_t alpha, + math_t l1_ratio, + bool shuffle, + math_t tol, + int n_iter_no_change, + cudaStream_t stream) +{ + ASSERT(n_cols > 0, "Parameter n_cols: number of columns cannot be less than one"); + ASSERT(n_rows > 1, "Parameter n_rows: number of rows cannot be less than two"); + + cublasHandle_t cublas_handle = handle.get_cublas_handle(); + + rmm::device_uvector mu_input(0, stream); + rmm::device_uvector mu_labels(0, stream); + rmm::device_uvector norm2_input(0, stream); + + if (fit_intercept) { + mu_input.resize(n_cols, stream); + mu_labels.resize(1, stream); + + GLM::preProcessData(handle, + input, + n_rows, + n_cols, + labels, + intercept, + mu_input.data(), + mu_labels.data(), + norm2_input.data(), + fit_intercept, + false); + } + + rmm::device_uvector grads(n_cols, stream); + rmm::device_uvector indices(batch_size, stream); + rmm::device_uvector input_batch(batch_size * n_cols, stream); + rmm::device_uvector labels_batch(batch_size, stream); + rmm::device_scalar loss_value(stream); + + math_t prev_loss_value = math_t(0); + math_t curr_loss_value = math_t(0); + + std::vector rand_indices(n_rows); + std::mt19937 g(rand()); + initShuffle(rand_indices, g); + + math_t t = math_t(1); + math_t learning_rate = math_t(0); + if (lr_type == ML::lr_type::ADAPTIVE) { + learning_rate = eta0; + } else if (lr_type == ML::lr_type::OPTIMAL) { + eta0 = calOptimalInit(alpha); + } + + int n_iter_no_change_curr = 0; + + for (int i = 0; i < epochs; i++) { + int cbs = 0; + int j = 0; + + if (i > 0 && shuffle) { Solver::shuffle(rand_indices, g); } + + while (j < n_rows) { + if ((j + batch_size) > n_rows) { + cbs = n_rows - j; + } else { + cbs = batch_size; + } + + if (cbs == 0) break; + + raft::update_device(indices.data(), &rand_indices[j], cbs, stream); + raft::matrix::copyRows( + input, n_rows, n_cols, input_batch.data(), indices.data(), cbs, stream); + raft::matrix::copyRows(labels, n_rows, 1, labels_batch.data(), indices.data(), cbs, stream); + + if (loss == ML::loss_funct::SQRD_LOSS) { + Functions::linearRegLossGrads(handle, + input_batch.data(), + cbs, + n_cols, + labels_batch.data(), + coef, + grads.data(), + penalty, + alpha, + l1_ratio, + stream); + } else if (loss == ML::loss_funct::LOG) { + Functions::logisticRegLossGrads(handle, + input_batch.data(), + cbs, + n_cols, + labels_batch.data(), + coef, + grads.data(), + penalty, + alpha, + l1_ratio, + stream); + } else if (loss == ML::loss_funct::HINGE) { + Functions::hingeLossGrads(handle, + input_batch.data(), + cbs, + n_cols, + labels_batch.data(), + coef, + grads.data(), + penalty, + alpha, + l1_ratio, + stream); + } else { + ASSERT(false, "sgd.cuh: Other loss functions have not been implemented yet!"); + } + + if (lr_type != ML::lr_type::ADAPTIVE) + learning_rate = calLearningRate(lr_type, eta0, power_t, alpha, t); + + raft::linalg::scalarMultiply(grads.data(), grads.data(), learning_rate, n_cols, stream); + raft::linalg::subtract(coef, coef, grads.data(), n_cols, stream); + + j = j + cbs; + t = t + 1; + } + + if (tol > math_t(0)) 
{ + if (loss == ML::loss_funct::SQRD_LOSS) { + Functions::linearRegLoss(handle, + input, + n_rows, + n_cols, + labels, + coef, + loss_value.data(), + penalty, + alpha, + l1_ratio, + stream); + } else if (loss == ML::loss_funct::LOG) { + Functions::logisticRegLoss(handle, + input, + n_rows, + n_cols, + labels, + coef, + loss_value.data(), + penalty, + alpha, + l1_ratio, + stream); + } else if (loss == ML::loss_funct::HINGE) { + Functions::hingeLoss(handle, + input, + n_rows, + n_cols, + labels, + coef, + loss_value.data(), + penalty, + alpha, + l1_ratio, + stream); + } + + raft::update_host(&curr_loss_value, loss_value.data(), 1, stream); + handle.sync_stream(stream); + + if (i > 0) { + if (curr_loss_value > (prev_loss_value - tol)) { + n_iter_no_change_curr = n_iter_no_change_curr + 1; + if (n_iter_no_change_curr > n_iter_no_change) { + if (lr_type == ML::lr_type::ADAPTIVE && learning_rate > math_t(1e-6)) { + learning_rate = learning_rate / math_t(5); + n_iter_no_change_curr = 0; + } else { + break; + } + } + } else { + n_iter_no_change_curr = 0; + } + } + + prev_loss_value = curr_loss_value; + } + } + + if (fit_intercept) { + GLM::postProcessData(handle, + input, + n_rows, + n_cols, + labels, + coef, + intercept, + mu_input.data(), + mu_labels.data(), + norm2_input.data(), + fit_intercept, + false); + } else { + *intercept = math_t(0); + } +} + +/** + * Make predictions + * @param handle + * Reference of raft::handle_t + * @param input + * pointer to an array in column-major format (size of n_rows, n_cols) + * @param n_rows + * n_samples or rows in input + * @param n_cols + * n_features or columns in X + * @param coef + * pointer to an array for coefficients (size of n_cols). Calculated in cdFit function. + * @param intercept + * intercept value calculated in cdFit function + * @param preds + * pointer to an array for predictions (size of n_rows). This will be fitted once functions + * is executed. + * @param loss + * enum to use different loss functions. Only linear regression loss functions is supported + * right now. + * @param stream + * cuda stream + */ +template +void sgdPredict(const raft::handle_t& handle, + const math_t* input, + int n_rows, + int n_cols, + const math_t* coef, + math_t intercept, + math_t* preds, + ML::loss_funct loss, + cudaStream_t stream) +{ + ASSERT(n_cols > 0, "Parameter n_cols: number of columns cannot be less than one"); + ASSERT(n_rows > 1, "Parameter n_rows: number of rows cannot be less than two"); + + if (loss == ML::loss_funct::SQRD_LOSS) { + Functions::linearRegH(handle, input, n_rows, n_cols, coef, preds, intercept, stream); + } else if (loss == ML::loss_funct::LOG) { + Functions::logisticRegH(handle, input, n_rows, n_cols, coef, preds, intercept, stream); + } else if (loss == ML::loss_funct::HINGE) { + Functions::hingeH(handle, input, n_rows, n_cols, coef, preds, intercept, stream); + } +} + +/** + * Make binary classifications + * @param handle + * Reference of raft::handle_t + * @param input + * pointer to an array in column-major format (size of n_rows, n_cols) + * @param n_rows + * n_samples or rows in input + * @param n_cols + * n_features or columns in X + * @param coef + * pointer to an array for coefficients (size of n_cols). Calculated in cdFit function. + * @param intercept + * intercept value calculated in cdFit function + * @param preds + * pointer to an array for predictions (size of n_rows). This will be fitted once functions + * is executed. + * @param loss + * enum to use different loss functions. 
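
For context, a hedged sketch of how sgdFit above is typically invoked for a least-squares model. The device buffers (d_X, d_y, d_coef) are assumptions, the enums still carry their ML:: / Functions:: namespaces at this point in the series, and with INVSCALING the step size is expected to follow the usual eta0 / t^power_t schedule from learning_rate.h.

    // Sketch only: fit coefficients with mini-batch SGD on column-major input.
    float intercept = 0;
    raft::solver::detail::sgdFit(handle,
                                 d_X, n_rows, n_cols,  // column-major [n_rows x n_cols]
                                 d_y,                  // labels [n_rows]
                                 d_coef, &intercept,   // outputs
                                 /*fit_intercept=*/true,
                                 /*batch_size=*/32,
                                 /*epochs=*/100,
                                 ML::lr_type::INVSCALING,
                                 /*eta0=*/0.01f,
                                 /*power_t=*/0.25f,
                                 ML::loss_funct::SQRD_LOSS,
                                 Functions::penalty::L2,
                                 /*alpha=*/1e-4f,
                                 /*l1_ratio=*/0.0f,
                                 /*shuffle=*/true,
                                 /*tol=*/1e-3f,
                                 /*n_iter_no_change=*/5,
                                 handle.get_stream());
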
Only linear regression loss functions is supported + * right now. + * @param stream + * cuda stream + */ +template +void sgdPredictBinaryClass(const raft::handle_t& handle, + const math_t* input, + int n_rows, + int n_cols, + const math_t* coef, + math_t intercept, + math_t* preds, + ML::loss_funct loss, + cudaStream_t stream) +{ + sgdPredict(handle, input, n_rows, n_cols, coef, intercept, preds, loss, stream); + + math_t scalar = math_t(1); + if (loss == ML::loss_funct::SQRD_LOSS || loss == ML::loss_funct::LOG) { + raft::linalg::unaryOp( + preds, + preds, + n_rows, + [scalar] __device__(math_t in) { + if (in >= math_t(0.5)) + return math_t(1); + else + return math_t(0); + }, + stream); + } else if (loss == ML::loss_funct::HINGE) { + raft::linalg::unaryOp( + preds, + preds, + n_rows, + [scalar] __device__(math_t in) { + if (in >= math_t(0.0)) + return math_t(1); + else + return math_t(0); + }, + stream); + } +} + +}; // namespace raft::solver::detail diff --git a/cpp/include/raft/solver/detail/shuffle.h b/cpp/include/raft/solver/detail/shuffle.h new file mode 100644 index 0000000000..4c131163c3 --- /dev/null +++ b/cpp/include/raft/solver/detail/shuffle.h @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2018-2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +namespace raft::solver::detail { + +template +void initShuffle(std::vector& rand_indices, std::mt19937& g, math_t random_state = 0) +{ + g.seed((int)random_state); + for (std::size_t i = 0; i < rand_indices.size(); ++i) + rand_indices[i] = i; +} + +template +void shuffle(std::vector& rand_indices, std::mt19937& g) +{ + std::shuffle(rand_indices.begin(), rand_indices.end(), g); +} + +}; // namespace raft::solver::detail diff --git a/cpp/include/raft/solver/gradient_descent.cuh b/cpp/include/raft/solver/gradient_descent.cuh new file mode 100644 index 0000000000..07e49b3cd1 --- /dev/null +++ b/cpp/include/raft/solver/gradient_descent.cuh @@ -0,0 +1,23 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace raft::solver::gradient_descent { + +} \ No newline at end of file diff --git a/cpp/include/raft/solver/least_angle_regression.cuh b/cpp/include/raft/solver/least_angle_regression.cuh new file mode 100644 index 0000000000..9f58a3c1ba --- /dev/null +++ b/cpp/include/raft/solver/least_angle_regression.cuh @@ -0,0 +1,23 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. 
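
The thresholding performed by sgdPredictBinaryClass above reduces to a simple per-element rule; an equivalent host-side statement of it (the patch applies this on device through raft::linalg::unaryOp):

    // Squared and logistic losses threshold the raw score at 0.5; hinge thresholds at 0.
    template <typename T>
    T binary_decision(T score, ML::loss_funct loss)
    {
      const T threshold = (loss == ML::loss_funct::HINGE) ? T(0) : T(0.5);
      return score >= threshold ? T(1) : T(0);
    }
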
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace raft::solver::least_angle_regression { + +} \ No newline at end of file diff --git a/cpp/include/raft/solver/quasi_newton.cuh b/cpp/include/raft/solver/quasi_newton.cuh new file mode 100644 index 0000000000..1de75b7c5a --- /dev/null +++ b/cpp/include/raft/solver/quasi_newton.cuh @@ -0,0 +1,79 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace raft::solver::quasi_newton { + + using raft::solver::quasi_newton::detail::objectives::AbsLoss; + using raft::solver::quasi_newton::detail::objectives::HingeLoss; + using raft::solver::quasi_newton::detail::objectives::LogisticLoss; + using raft::solver::quasi_newton::detail::objectives::LinearDims; + using raft::solver::quasi_newton::detail::objectives::SqHingeLoss; + using raft::solver::quasi_newton::detail::objectives::SqEpsInsHingeLoss; + using raft::solver::quasi_newton::detail::objectives::EpsInsHingeLoss; + using raft::solver::quasi_newton::detail::LBFGSParam + + /** + * + * @tparam T + * @tparam Loss + * @tparam Reg + */ + template + class RegularizedQN : public detail::objectives::RegularizedQN { + RegularizedQN(Loss* loss, Reg* reg): detail::objectives::RegularizedQN(loss, reg) {} + }; + + /** + * + * @tparam T + * @tparam Loss + */ + template + struct QNLinearBase : detail::objectives::QNLinearBase { + QNLinearBase(const raft::handle_t &handle, int D, int C, bool fit_intercept) + : detail::objectives::QNLinearBase(C, D, fit_intercept) {} + } + + + using raft::solver::quasi_newton::detail::objectives::Softmax; + + using raft::solver::quasi_newton::detail::objectives::QNWithData; + using raft::solver::quasi_newton::detail::objectives::QuasiNewtonBase; + + template + inline int qn_minimize(const raft::handle_t& handle, + T *x, + T* fx, + int* num_iters, + LossFunction& loss, + const T l1, + const detail::LBFGSParam& opt_param) { + + } + +} \ No newline at end of file diff --git a/cpp/include/raft/solver/solver_types.hpp b/cpp/include/raft/solver/solver_types.hpp new file mode 100644 index 0000000000..e44b05cb58 --- /dev/null +++ b/cpp/include/raft/solver/solver_types.hpp @@ -0,0 +1,238 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + + + +namespace raft::solver { + + enum lr_type { + OPTIMAL, + CONSTANT, + INVSCALING, + ADAPTIVE, + }; + + enum loss_funct { + SQRD_LOSS, + HINGE, + LOG, + }; + + enum penalty { NONE, L1, L2, ELASTICNET }; + +/** Loss function types supported by the Quasi-Newton solvers. */ + enum qn_loss_type { + /** Logistic classification. + * Expected target: {0, 1}. + */ + QN_LOSS_LOGISTIC = 0, + /** L2 regression. + * Expected target: R. + */ + QN_LOSS_SQUARED = 1, + /** Softmax classification.. + * Expected target: {0, 1, ...}. + */ + QN_LOSS_SOFTMAX = 2, + /** Hinge. + * Expected target: {0, 1}. + */ + QN_LOSS_HINGE = 3, + /** Squared-hinge. + * Expected target: {0, 1}. + */ + QN_LOSS_SQUARED_HINGE = 4, + /** Epsilon-insensitive. + * Expected target: R. + */ + QN_LOSS_EPS_INS_HINGE = 5, + /** Epsilon-insensitive-squared. + * Expected target: R. + */ + QN_LOSS_SQ_EPS_INS_HINGE = 6, + /** L1 regression. + * Expected target: R. + */ + QN_LOSS_ABS = 7, + /** Someone forgot to set the loss type! */ + QN_LOSS_UNKNOWN = 99 + }; + + struct qn_params { + /** Loss type. */ + qn_loss_type loss; + /** Regularization: L1 component. */ + double penalty_l1; + /** Regularization: L2 component. */ + double penalty_l2; + /** Convergence criteria: the threshold on the gradient. */ + double grad_tol; + /** Convergence criteria: the threshold on the function change. */ + double change_tol; + /** Maximum number of iterations. */ + int max_iter; + /** Maximum number of linesearch (inner loop) iterations. */ + int linesearch_max_iter; + /** Number of vectors approximating the hessian (l-bfgs). */ + int lbfgs_memory; + /** Triggers extra output when greater than zero. */ + int verbose; + /** Whether to fit the bias term. */ + bool fit_intercept; + /** + * Whether to divide the L1 and L2 regularization parameters by the sample size. + * + * Note, the defined QN loss functions normally are scaled for the sample size, + * e.g. the average across the data rows is calculated. + * Enabling `penalty_normalized` makes this solver's behavior compatible to those solvers, + * which do not scale the loss functions (like sklearn.LogisticRegression()). + */ + bool penalty_normalized; + +qn_params() + : loss(QN_LOSS_UNKNOWN), + penalty_l1(0), + penalty_l2(0), + grad_tol(1e-4), + change_tol(1e-5), + max_iter(1000), + linesearch_max_iter(50), + lbfgs_memory(5), + verbose(0), + fit_intercept(true), + penalty_normalized(true) {} +}; + + + enum class LarsFitStatus { kOk, kCollinear, kError, kStop }; + + + namespace quasi_newton { + + enum LINE_SEARCH_ALGORITHM { + LBFGS_LS_BT_ARMIJO = 1, + LBFGS_LS_BT = 2, // Default. 
Alias for Wolfe + LBFGS_LS_BT_WOLFE = 2, + LBFGS_LS_BT_STRONG_WOLFE = 3 + }; + + enum LINE_SEARCH_RETCODE { + LS_SUCCESS = 0, + LS_INVALID_STEP_MIN = 1, + LS_INVALID_STEP_MAX = 2, + LS_MAX_ITERS_REACHED = 3, + LS_INVALID_DIR = 4, + LS_INVALID_STEP = 5 + }; + + enum OPT_RETCODE { + OPT_SUCCESS = 0, + OPT_NUMERIC_ERROR = 1, + OPT_LS_FAILED = 2, + OPT_MAX_ITERS_REACHED = 3, + OPT_INVALID_ARGS = 4 + }; + + template + class LBFGSParam { + public: + int m; // lbfgs memory limit + T epsilon; // controls convergence + int past; // lookback for function value based convergence test + T delta; // controls fun val based conv test + int max_iterations; + int linesearch; // see enum above + int max_linesearch; + T min_step; // min. allowed step length + T max_step; // max. allowed step length + T ftol; // line search tolerance + T wolfe; // wolfe parameter + T ls_dec; // line search decrease factor + T ls_inc; // line search increase factor + + public: + LBFGSParam() + { + m = 6; + epsilon = T(1e-5); + past = 0; + delta = T(0); + max_iterations = 0; + linesearch = LBFGS_LS_BT_ARMIJO; + max_linesearch = 20; + min_step = T(1e-20); + max_step = T(1e+20); + ftol = T(1e-4); + wolfe = T(0.9); + ls_dec = T(0.5); + ls_inc = T(2.1); + } + + explicit LBFGSParam(const qn_params& pams) : LBFGSParam() + { + m = pams.lbfgs_memory; + epsilon = T(pams.grad_tol); + // sometimes even number works better - to detect zig-zags; + past = pams.change_tol > 0 ? 10 : 0; + delta = T(pams.change_tol); + max_iterations = pams.max_iter; + max_linesearch = pams.linesearch_max_iter; + ftol = pams.change_tol > 0 ? T(pams.change_tol * 0.1) : T(1e-4); + } + + inline int check_param() const + { // TODO exceptions + int ret = 1; + if (m <= 0) return ret; + ret++; + if (epsilon <= 0) return ret; + ret++; + if (past < 0) return ret; + ret++; + if (delta < 0) return ret; + ret++; + if (max_iterations < 0) return ret; + ret++; + if (linesearch < LBFGS_LS_BT_ARMIJO || linesearch > LBFGS_LS_BT_STRONG_WOLFE) return ret; + ret++; + if (max_linesearch <= 0) return ret; + ret++; + if (min_step < 0) return ret; + ret++; + if (max_step < min_step) return ret; + ret++; + if (ftol <= 0 || ftol >= 0.5) return ret; + ret++; + if (wolfe <= ftol || wolfe >= 1) return ret; + ret++; + return 0; + } + }; + + struct LinearDims { + bool fit_intercept; + int C, D, dims, n_param; + LinearDims(int C, int D, bool fit_intercept) : C(C), D(D), fit_intercept(fit_intercept) + { + dims = D + fit_intercept; + n_param = dims * C; + } + }; + } + +} From 8d2bd0b6957cece0df61b1d282380bf0fa59b2c5 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Mon, 17 Oct 2022 20:30:36 -0400 Subject: [PATCH 19/35] Updates --- .../raft/solver/detail/qn/qn_decision.cuh | 51 ------- .../raft/solver/detail/qn/qn_solvers.cuh | 12 +- .../solver/detail/qn/simple_mat/sparse.hpp | 2 +- cpp/include/raft/solver/quasi_newton.cuh | 90 ++++++++++-- cpp/include/raft/solver/solver_types.hpp | 128 ++++++------------ 5 files changed, 131 insertions(+), 152 deletions(-) delete mode 100644 cpp/include/raft/solver/detail/qn/qn_decision.cuh diff --git a/cpp/include/raft/solver/detail/qn/qn_decision.cuh b/cpp/include/raft/solver/detail/qn/qn_decision.cuh deleted file mode 100644 index c7bf123e55..0000000000 --- a/cpp/include/raft/solver/detail/qn/qn_decision.cuh +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Copyright (c) 2018-2022, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
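
To make the qn_params-to-LBFGSParam translation above concrete, a short configuration sketch (illustrative values, namespaces as they stand at this point in the series): a positive change_tol enables the function-value convergence test with a lookback of 10 iterations and derives ftol as change_tol * 0.1, while grad_tol becomes epsilon.

    // Sketch: configure an L1-regularized logistic objective for the QN solvers
    // and derive the L-BFGS hyper-parameters from it.
    #include <cassert>

    void qn_config_example()
    {
      raft::solver::qn_params pams;
      pams.loss       = raft::solver::QN_LOSS_LOGISTIC;
      pams.penalty_l1 = 1e-3;  // a positive L1 term selects the OWL-QN path rather than plain L-BFGS
      pams.grad_tol   = 1e-4;  // becomes LBFGSParam::epsilon
      pams.change_tol = 1e-5;  // enables the `past` window and sets delta / ftol
      pams.max_iter   = 200;

      raft::solver::quasi_newton::LBFGSParam<double> opt(pams);
      assert(opt.check_param() == 0);  // 0 means every hyper-parameter is in range
    }
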
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "objectives/base.cuh" -#include "objectives/linear.cuh" -#include "objectives/logistic.cuh" -#include "objectives/regularizer.cuh" -#include "objectives/softmax.cuh" -#include "objectives/hinge.cuh" -#include "qn_solvers.cuh" -#include "qn_util.cuh" - -#include -#include -#include - -namespace raft::solver::quasi_newton::detail { - -template -void linear_decision_function(const raft::handle_t& handle, - const qn_params& pams, - SimpleMat& X, - int C, - T* params, - T* scores, - cudaStream_t stream) { - // NOTE: While gtests pass X as row-major, and python API passes X as - // col-major, no extensive testing has been done to ensure that - // this function works correctly for both input types - int n_targets = qn_is_classification(pams.loss) && C == 2 ? 1 : C; - LinearDims dims(n_targets, X.n, pams.fit_intercept); - SimpleDenseMat W(params, n_targets, dims.dims); - SimpleDenseMat Z(scores, n_targets, X.m); - linearFwd(handle, Z, X, W, stream); -} -}; // namespace raft::solver::quasi_newton::detail diff --git a/cpp/include/raft/solver/detail/qn/qn_solvers.cuh b/cpp/include/raft/solver/detail/qn/qn_solvers.cuh index 48c698ee4e..89d067f4e2 100644 --- a/cpp/include/raft/solver/detail/qn/qn_solvers.cuh +++ b/cpp/include/raft/solver/detail/qn/qn_solvers.cuh @@ -75,10 +75,10 @@ inline bool update_and_check(const char* solver, T& fx, T& fxp, const T& gnorm, - ML::SimpleVec& x, - ML::SimpleVec& xp, - ML::SimpleVec& grad, - ML::SimpleVec& gradp, + raft::solver::quasi_newton::SimpleVec& x, + raft::solver::quasi_newton::SimpleVec& xp, + raft::solver::quasi_newton::SimpleVec& grad, + raft::solver::quasi_newton::SimpleVec& gradp, std::vector& fx_hist, T* dev_scalar, OPT_RETCODE& outcode, @@ -174,7 +174,7 @@ inline OPT_RETCODE min_lbfgs(const LBFGSParam& param, std::vector fx_hist(param.past > 0 ? 
param.past : 0); *k = 0; - ML::Logger::get().setLevel(verbosity); + raft::solver::quasi_newton::Logger::get().setLevel(verbosity); RAFT_LOG_DEBUG("Running L-BFGS"); // Evaluate function and compute gradient @@ -300,7 +300,7 @@ inline OPT_RETCODE min_owlqn(const LBFGSParam& param, p_ws += vec_size; T* dev_scalar = p_ws; - ML::Logger::get().setLevel(verbosity); + raft::solver::quasi_newton::Logger::get().setLevel(verbosity); SimpleVec svec, yvec; // mask vectors diff --git a/cpp/include/raft/solver/detail/qn/simple_mat/sparse.hpp b/cpp/include/raft/solver/detail/qn/simple_mat/sparse.hpp index 1d2a025ccd..cc79922267 100644 --- a/cpp/include/raft/solver/detail/qn/simple_mat/sparse.hpp +++ b/cpp/include/raft/solver/detail/qn/simple_mat/sparse.hpp @@ -19,7 +19,7 @@ #include #include "base.hpp" -#include +#include #include #include #include diff --git a/cpp/include/raft/solver/quasi_newton.cuh b/cpp/include/raft/solver/quasi_newton.cuh index 1de75b7c5a..06aa4b0520 100644 --- a/cpp/include/raft/solver/quasi_newton.cuh +++ b/cpp/include/raft/solver/quasi_newton.cuh @@ -28,14 +28,63 @@ namespace raft::solver::quasi_newton { - using raft::solver::quasi_newton::detail::objectives::AbsLoss; - using raft::solver::quasi_newton::detail::objectives::HingeLoss; - using raft::solver::quasi_newton::detail::objectives::LogisticLoss; - using raft::solver::quasi_newton::detail::objectives::LinearDims; - using raft::solver::quasi_newton::detail::objectives::SqHingeLoss; + /** + * The follow loss functions are wrapped so they will be included in the docs + * @tparam T + */ + + /** + * + * @tparam T + */ + template + struct AbsLoss : detail::objectives::AbsLoss { + AbsLoss(const raft::handle_t &handle, int D, bool has_bias) + : detail::objectives::AbsLoss(handle, D, has_bias) {} + } + + /** + * + * @tparam T + */ + template + struct HingeLoss : detail::objectives::HingeLoss { + HingeLoss(const raft::handle_t &handle, int D, bool has_bias) + : detail::objectives::HingeLoss(handle, D, has_bias) {} + } + + /** + * + * @tparam T + */ + template + struct LogisticLoss : detail::objectives::LogisticLoss { + LogisticLoss(const raft::handle_t &handle, int D, bool has_bias) + : detail::objectives::LogisticLoss(handle, D, has_bias) {} + } + + /** + * + * @tparam T + */ + template + struct SqHingeLoss : detail::objectives::SqHingeLoss { + SqHingeLoss(const raft::handle_t &handle, int D, bool has_bias) + : detail::objectives::SqHingeLoss(handle, D, has_bias) {} + } + + + template + struct EpsInsHingeLoss : QNLinearBase> { + typedef QNLinearBase> Super; + EpsInsHingeLoss(const raft::handle_t& handle, int D, bool has_bias, T sensitivity) + : Super(handle, D, 1, has_bias), lz{sensitivity}, dlz{sensitivity} + { + } + using raft::solver::quasi_newton::detail::objectives::SqEpsInsHingeLoss; using raft::solver::quasi_newton::detail::objectives::EpsInsHingeLoss; - using raft::solver::quasi_newton::detail::LBFGSParam + using raft::solver::quasi_newton::detail::LBFGSParam; /** * @@ -56,17 +105,36 @@ namespace raft::solver::quasi_newton { template struct QNLinearBase : detail::objectives::QNLinearBase { QNLinearBase(const raft::handle_t &handle, int D, int C, bool fit_intercept) - : detail::objectives::QNLinearBase(C, D, fit_intercept) {} + : detail::objectives::QNLinearBase(C, D, fit_intercept) {} } + using raft::solver::quasi_newton::detail::objectives::Softmax; + using raft::solver::quasi_newton::detail::objectives::QNWithData; + using raft::solver::quasi_newton::detail::objectives::QNLinearBase; - using 
raft::solver::quasi_newton::detail::objectives::Softmax; + template + OPT_RETCODE lbfgs_minimize(const LBFGSParam& param, + Function& f, // function to minimize + SimpleVec& x, // initial point, holds result + T& fx, // output function value + int* k, // output iterations + SimpleVec& workspace, // scratch space + cudaStream_t stream, + int verbosity = 0) - using raft::solver::quasi_newton::detail::objectives::QNWithData; - using raft::solver::quasi_newton::detail::objectives::QuasiNewtonBase; + template + OPT_RETCODE owl_minimize(const LBFGSParam& param, + Function& f, + const T l1_penalty, + const int pg_limit, + SimpleVec& x, + T& fx, + int* k) { + + } template - inline int qn_minimize(const raft::handle_t& handle, + int qn_minimize(const raft::handle_t& handle, T *x, T* fx, int* num_iters, diff --git a/cpp/include/raft/solver/solver_types.hpp b/cpp/include/raft/solver/solver_types.hpp index e44b05cb58..2fa60ce4b3 100644 --- a/cpp/include/raft/solver/solver_types.hpp +++ b/cpp/include/raft/solver/solver_types.hpp @@ -35,95 +35,57 @@ namespace raft::solver { enum penalty { NONE, L1, L2, ELASTICNET }; -/** Loss function types supported by the Quasi-Newton solvers. */ - enum qn_loss_type { - /** Logistic classification. - * Expected target: {0, 1}. - */ - QN_LOSS_LOGISTIC = 0, - /** L2 regression. - * Expected target: R. - */ - QN_LOSS_SQUARED = 1, - /** Softmax classification.. - * Expected target: {0, 1, ...}. - */ - QN_LOSS_SOFTMAX = 2, - /** Hinge. - * Expected target: {0, 1}. - */ - QN_LOSS_HINGE = 3, - /** Squared-hinge. - * Expected target: {0, 1}. - */ - QN_LOSS_SQUARED_HINGE = 4, - /** Epsilon-insensitive. - * Expected target: R. - */ - QN_LOSS_EPS_INS_HINGE = 5, - /** Epsilon-insensitive-squared. - * Expected target: R. - */ - QN_LOSS_SQ_EPS_INS_HINGE = 6, - /** L1 regression. - * Expected target: R. - */ - QN_LOSS_ABS = 7, - /** Someone forgot to set the loss type! */ - QN_LOSS_UNKNOWN = 99 - }; - - struct qn_params { - /** Loss type. */ - qn_loss_type loss; - /** Regularization: L1 component. */ - double penalty_l1; - /** Regularization: L2 component. */ - double penalty_l2; - /** Convergence criteria: the threshold on the gradient. */ - double grad_tol; - /** Convergence criteria: the threshold on the function change. */ - double change_tol; - /** Maximum number of iterations. */ - int max_iter; - /** Maximum number of linesearch (inner loop) iterations. */ - int linesearch_max_iter; - /** Number of vectors approximating the hessian (l-bfgs). */ - int lbfgs_memory; - /** Triggers extra output when greater than zero. */ - int verbose; - /** Whether to fit the bias term. */ - bool fit_intercept; - /** - * Whether to divide the L1 and L2 regularization parameters by the sample size. - * - * Note, the defined QN loss functions normally are scaled for the sample size, - * e.g. the average across the data rows is calculated. - * Enabling `penalty_normalized` makes this solver's behavior compatible to those solvers, - * which do not scale the loss functions (like sklearn.LogisticRegression()). - */ - bool penalty_normalized; - -qn_params() - : loss(QN_LOSS_UNKNOWN), - penalty_l1(0), - penalty_l2(0), - grad_tol(1e-4), - change_tol(1e-5), - max_iter(1000), - linesearch_max_iter(50), - lbfgs_memory(5), - verbose(0), - fit_intercept(true), - penalty_normalized(true) {} -}; - enum class LarsFitStatus { kOk, kCollinear, kError, kStop }; namespace quasi_newton { + struct qn_params { + /** Loss type. */ + qn_loss_type loss; + /** Regularization: L1 component. 
*/ + double penalty_l1; + /** Regularization: L2 component. */ + double penalty_l2; + /** Convergence criteria: the threshold on the gradient. */ + double grad_tol; + /** Convergence criteria: the threshold on the function change. */ + double change_tol; + /** Maximum number of iterations. */ + int max_iter; + /** Maximum number of linesearch (inner loop) iterations. */ + int linesearch_max_iter; + /** Number of vectors approximating the hessian (l-bfgs). */ + int lbfgs_memory; + /** Triggers extra output when greater than zero. */ + int verbose; + /** Whether to fit the bias term. */ + bool fit_intercept; + /** + * Whether to divide the L1 and L2 regularization parameters by the sample size. + * + * Note, the defined QN loss functions normally are scaled for the sample size, + * e.g. the average across the data rows is calculated. + * Enabling `penalty_normalized` makes this solver's behavior compatible to those solvers, + * which do not scale the loss functions (like sklearn.LogisticRegression()). + */ + bool penalty_normalized; + + qn_params() + : loss(QN_LOSS_UNKNOWN), + penalty_l1(0), + penalty_l2(0), + grad_tol(1e-4), + change_tol(1e-5), + max_iter(1000), + linesearch_max_iter(50), + lbfgs_memory(5), + verbose(0), + fit_intercept(true), + penalty_normalized(true) {} + }; + enum LINE_SEARCH_ALGORITHM { LBFGS_LS_BT_ARMIJO = 1, LBFGS_LS_BT = 2, // Default. Alias for Wolfe From 879e85f76e761795dcea6382d5211ffe0dc037c7 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Thu, 20 Oct 2022 11:20:50 -0400 Subject: [PATCH 20/35] Fixing style --- cpp/include/raft/solver/detail/lars.cuh | 7 +- .../raft/solver/detail/qn/objectives/base.cuh | 11 +- .../solver/detail/qn/objectives/hinge.cuh | 4 +- .../solver/detail/qn/objectives/linear.cuh | 6 +- .../solver/detail/qn/objectives/logistic.cuh | 6 +- .../detail/qn/objectives/regularizer.cuh | 8 +- .../solver/detail/qn/objectives/softmax.cuh | 6 +- .../raft/solver/detail/qn/qn_linesearch.cuh | 4 +- .../raft/solver/detail/qn/qn_solvers.cuh | 2 +- cpp/include/raft/solver/detail/qn/qn_util.cuh | 9 +- .../solver/detail/qn/simple_mat/dense.hpp | 4 +- .../solver/detail/qn/simple_mat/sparse.hpp | 4 +- cpp/include/raft/solver/quasi_newton.cuh | 232 ++++++------ cpp/include/raft/solver/solver_types.hpp | 358 +++++++++--------- 14 files changed, 333 insertions(+), 328 deletions(-) diff --git a/cpp/include/raft/solver/detail/lars.cuh b/cpp/include/raft/solver/detail/lars.cuh index 1c2bd04285..d753dd8253 100644 --- a/cpp/include/raft/solver/detail/lars.cuh +++ b/cpp/include/raft/solver/detail/lars.cuh @@ -27,13 +27,13 @@ #include #include #include -#include -#include -#include #include #include #include #include +#include +#include +#include #include #include #include @@ -47,7 +47,6 @@ namespace raft::solver::detail { - /** * @brief Select the largest element from the inactive working set. 
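
Since the public qn_minimize entry point above is still a stub at this point in the series, the following is only a sketch of the intended call shape, assuming the wrapped loss classes compile as declared; the parameter count comes from the LinearDims base (n_param = (D + fit_intercept) * C).

    // Intended call shape for the public quasi-Newton entry point (sketch, not the final API).
    void qn_fit_example(const raft::handle_t& handle, int D)
    {
      auto stream        = handle.get_stream();
      bool fit_intercept = true;

      raft::solver::quasi_newton::LogisticLoss<float> loss(handle, D, fit_intercept);

      raft::solver::quasi_newton::detail::LBFGSParam<float> opt;  // default hyper-parameters
      opt.max_iterations = 100;

      rmm::device_uvector<float> x(loss.n_param, stream);  // parameters being optimized
      float fx      = 0;
      int num_iters = 0;
      raft::solver::quasi_newton::qn_minimize(
        handle, x.data(), &fx, &num_iters, loss, /*l1=*/1e-3f, opt);
    }
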
* diff --git a/cpp/include/raft/solver/detail/qn/objectives/base.cuh b/cpp/include/raft/solver/detail/qn/objectives/base.cuh index 1edc1904a5..0dbc79807e 100644 --- a/cpp/include/raft/solver/detail/qn/objectives/base.cuh +++ b/cpp/include/raft/solver/detail/qn/objectives/base.cuh @@ -17,13 +17,13 @@ #pragma once #include "../simple_mat.cuh" -#include -#include #include #include #include #include #include +#include +#include #include #include @@ -90,8 +90,6 @@ inline void linearBwd(const raft::handle_t& handle, } } - - template struct QNLinearBase : LinearDims { typedef SimpleDenseMat Mat; @@ -199,7 +197,10 @@ struct QNWithData : LinearDims { SimpleDenseMat* Z; QuasiNewtonObjective* objective; - QNWithData(QuasiNewtonObjective* obj, const SimpleMat& X, const SimpleVec& y, SimpleDenseMat& Z) + QNWithData(QuasiNewtonObjective* obj, + const SimpleMat& X, + const SimpleVec& y, + SimpleDenseMat& Z) : objective(obj), X(&X), y(&y), Z(&Z), LinearDims(obj->C, obj->D, obj->fit_intercept) { } diff --git a/cpp/include/raft/solver/detail/qn/objectives/hinge.cuh b/cpp/include/raft/solver/detail/qn/objectives/hinge.cuh index 6a43babd6c..c8effc6b7a 100644 --- a/cpp/include/raft/solver/detail/qn/objectives/hinge.cuh +++ b/cpp/include/raft/solver/detail/qn/objectives/hinge.cuh @@ -16,10 +16,10 @@ #pragma once -#include "base.cuh" #include "../simple_mat.cuh" -#include +#include "base.cuh" #include +#include namespace raft::solver::quasi_newton::detail::objectives { diff --git a/cpp/include/raft/solver/detail/qn/objectives/linear.cuh b/cpp/include/raft/solver/detail/qn/objectives/linear.cuh index 93c4363f78..731a47b886 100644 --- a/cpp/include/raft/solver/detail/qn/objectives/linear.cuh +++ b/cpp/include/raft/solver/detail/qn/objectives/linear.cuh @@ -16,12 +16,12 @@ #pragma once -#include "base.cuh" #include "../simple_mat.cuh" -#include +#include "base.cuh" #include +#include -namespace raft::solver::quasi_newton::detail::objectives { +namespace raft::solver::quasi_newton::detail::objectives { template struct SquaredLoss : QNLinearBase> { diff --git a/cpp/include/raft/solver/detail/qn/objectives/logistic.cuh b/cpp/include/raft/solver/detail/qn/objectives/logistic.cuh index 064cf4b793..ea2c25bf6f 100644 --- a/cpp/include/raft/solver/detail/qn/objectives/logistic.cuh +++ b/cpp/include/raft/solver/detail/qn/objectives/logistic.cuh @@ -16,12 +16,12 @@ #pragma once -#include "base.cuh" #include "../simple_mat.cuh" -#include +#include "base.cuh" #include +#include -namespace raft::solver::quasi_newton::detail::objectives { +namespace raft::solver::quasi_newton::detail::objectives { template struct LogisticLoss : QNLinearBase> { diff --git a/cpp/include/raft/solver/detail/qn/objectives/regularizer.cuh b/cpp/include/raft/solver/detail/qn/objectives/regularizer.cuh index 45395cbd0e..e4acf76672 100644 --- a/cpp/include/raft/solver/detail/qn/objectives/regularizer.cuh +++ b/cpp/include/raft/solver/detail/qn/objectives/regularizer.cuh @@ -16,15 +16,15 @@ #pragma once -#include "base.cuh" #include "../simple_mat.cuh" -#include -#include +#include "base.cuh" #include #include #include +#include +#include -namespace raft::solver::quasi_newton::detail::objectives { +namespace raft::solver::quasi_newton::detail::objectives { template struct Tikhonov { diff --git a/cpp/include/raft/solver/detail/qn/objectives/softmax.cuh b/cpp/include/raft/solver/detail/qn/objectives/softmax.cuh index 642800a7ff..70eeb9d6e6 100644 --- a/cpp/include/raft/solver/detail/qn/objectives/softmax.cuh +++ 
b/cpp/include/raft/solver/detail/qn/objectives/softmax.cuh @@ -16,12 +16,12 @@ #pragma once -#include "base.cuh" #include "../simple_mat.cuh" -#include +#include "base.cuh" #include +#include -namespace raft::solver::quasi_newton::detail::objectives { +namespace raft::solver::quasi_newton::detail::objectives { using raft::ceildiv; using raft::myExp; using raft::myLog; diff --git a/cpp/include/raft/solver/detail/qn/qn_linesearch.cuh b/cpp/include/raft/solver/detail/qn/qn_linesearch.cuh index 70ed2a471f..26445fbed9 100644 --- a/cpp/include/raft/solver/detail/qn/qn_linesearch.cuh +++ b/cpp/include/raft/solver/detail/qn/qn_linesearch.cuh @@ -16,14 +16,14 @@ #pragma once -#include #include "qn_util.cuh" +#include /* * Linesearch functions */ -namespace raft::solver::quasi_newton::detail { +namespace raft::solver::quasi_newton::detail { template struct LSProjectedStep { diff --git a/cpp/include/raft/solver/detail/qn/qn_solvers.cuh b/cpp/include/raft/solver/detail/qn/qn_solvers.cuh index 89d067f4e2..ce64590127 100644 --- a/cpp/include/raft/solver/detail/qn/qn_solvers.cuh +++ b/cpp/include/raft/solver/detail/qn/qn_solvers.cuh @@ -47,7 +47,7 @@ #include #include -namespace raft::solver::quasi_newton::detail { +namespace raft::solver::quasi_newton::detail { // TODO better way to deal with alignment? Smaller aligne possible? constexpr size_t qn_align = 256; diff --git a/cpp/include/raft/solver/detail/qn/qn_util.cuh b/cpp/include/raft/solver/detail/qn/qn_util.cuh index 0124d248c9..1081fef123 100644 --- a/cpp/include/raft/solver/detail/qn/qn_util.cuh +++ b/cpp/include/raft/solver/detail/qn/qn_util.cuh @@ -16,13 +16,11 @@ #pragma once -#include #include +#include #include -namespace raft::solver::quasi_newton::detail { - - +namespace raft::solver::quasi_newton::detail { inline bool qn_is_classification(qn_loss_type t) { @@ -36,7 +34,8 @@ inline bool qn_is_classification(qn_loss_type t) } template -HDI T project_orth(T x, T y) { +HDI T project_orth(T x, T y) +{ return x * y <= T(0) ? 
T(0) : x; } diff --git a/cpp/include/raft/solver/detail/qn/simple_mat/dense.hpp b/cpp/include/raft/solver/detail/qn/simple_mat/dense.hpp index c915586cc2..971737a259 100644 --- a/cpp/include/raft/solver/detail/qn/simple_mat/dense.hpp +++ b/cpp/include/raft/solver/detail/qn/simple_mat/dense.hpp @@ -19,11 +19,11 @@ #include #include "base.hpp" -#include #include -#include #include #include +#include +#include // #TODO: Replace with public header when ready #include #include diff --git a/cpp/include/raft/solver/detail/qn/simple_mat/sparse.hpp b/cpp/include/raft/solver/detail/qn/simple_mat/sparse.hpp index cc79922267..83734b5b7f 100644 --- a/cpp/include/raft/solver/detail/qn/simple_mat/sparse.hpp +++ b/cpp/include/raft/solver/detail/qn/simple_mat/sparse.hpp @@ -19,10 +19,10 @@ #include #include "base.hpp" -#include #include -#include #include +#include +#include #include #include diff --git a/cpp/include/raft/solver/quasi_newton.cuh b/cpp/include/raft/solver/quasi_newton.cuh index 06aa4b0520..090a1a13db 100644 --- a/cpp/include/raft/solver/quasi_newton.cuh +++ b/cpp/include/raft/solver/quasi_newton.cuh @@ -23,125 +23,133 @@ #include #include -#include #include +#include namespace raft::solver::quasi_newton { - /** - * The follow loss functions are wrapped so they will be included in the docs - * @tparam T - */ - - /** - * - * @tparam T - */ - template - struct AbsLoss : detail::objectives::AbsLoss { - AbsLoss(const raft::handle_t &handle, int D, bool has_bias) - : detail::objectives::AbsLoss(handle, D, has_bias) {} - } - - /** - * - * @tparam T - */ - template - struct HingeLoss : detail::objectives::HingeLoss { - HingeLoss(const raft::handle_t &handle, int D, bool has_bias) - : detail::objectives::HingeLoss(handle, D, has_bias) {} - } - - /** - * - * @tparam T - */ - template - struct LogisticLoss : detail::objectives::LogisticLoss { - LogisticLoss(const raft::handle_t &handle, int D, bool has_bias) - : detail::objectives::LogisticLoss(handle, D, has_bias) {} - } - - /** - * - * @tparam T - */ - template - struct SqHingeLoss : detail::objectives::SqHingeLoss { - SqHingeLoss(const raft::handle_t &handle, int D, bool has_bias) - : detail::objectives::SqHingeLoss(handle, D, has_bias) {} - } - +/** + * The follow loss functions are wrapped so they will be included in the docs + * @tparam T + */ - template - struct EpsInsHingeLoss : QNLinearBase> { - typedef QNLinearBase> Super; - EpsInsHingeLoss(const raft::handle_t& handle, int D, bool has_bias, T sensitivity) - : Super(handle, D, 1, has_bias), lz{sensitivity}, dlz{sensitivity} - { - } - - using raft::solver::quasi_newton::detail::objectives::SqEpsInsHingeLoss; - using raft::solver::quasi_newton::detail::objectives::EpsInsHingeLoss; - using raft::solver::quasi_newton::detail::LBFGSParam; - - /** - * - * @tparam T - * @tparam Loss - * @tparam Reg - */ - template - class RegularizedQN : public detail::objectives::RegularizedQN { - RegularizedQN(Loss* loss, Reg* reg): detail::objectives::RegularizedQN(loss, reg) {} - }; - - /** - * - * @tparam T - * @tparam Loss - */ - template - struct QNLinearBase : detail::objectives::QNLinearBase { - QNLinearBase(const raft::handle_t &handle, int D, int C, bool fit_intercept) - : detail::objectives::QNLinearBase(C, D, fit_intercept) {} +/** + * + * @tparam T + */ +template +struct AbsLoss : detail::objectives::AbsLoss { + AbsLoss(const raft::handle_t& handle, int D, bool has_bias) + : detail::objectives::AbsLoss(handle, D, has_bias) + { + } +} + +/** + * + * @tparam T + */ +template +struct HingeLoss : 
detail::objectives::HingeLoss { + HingeLoss(const raft::handle_t& handle, int D, bool has_bias) + : detail::objectives::HingeLoss(handle, D, has_bias) + { + } +} + +/** + * + * @tparam T + */ +template +struct LogisticLoss : detail::objectives::LogisticLoss { + LogisticLoss(const raft::handle_t& handle, int D, bool has_bias) + : detail::objectives::LogisticLoss(handle, D, has_bias) + { + } +} + +/** + * + * @tparam T + */ +template +struct SqHingeLoss : detail::objectives::SqHingeLoss { + SqHingeLoss(const raft::handle_t& handle, int D, bool has_bias) + : detail::objectives::SqHingeLoss(handle, D, has_bias) + { + } +} + +template +struct EpsInsHingeLoss : QNLinearBase> { + typedef QNLinearBase> Super; + EpsInsHingeLoss(const raft::handle_t& handle, int D, bool has_bias, T sensitivity) + : Super(handle, D, 1, has_bias), lz{sensitivity}, dlz{sensitivity} + { + } + + using raft::solver::quasi_newton::detail::LBFGSParam; + using raft::solver::quasi_newton::detail::objectives::EpsInsHingeLoss; + using raft::solver::quasi_newton::detail::objectives::SqEpsInsHingeLoss; + + /** + * + * @tparam T + * @tparam Loss + * @tparam Reg + */ + template + class RegularizedQN : public detail::objectives::RegularizedQN { + RegularizedQN(Loss* loss, Reg* reg) : detail::objectives::RegularizedQN(loss, reg) {} + }; + + /** + * + * @tparam T + * @tparam Loss + */ + template + struct QNLinearBase : detail::objectives::QNLinearBase { + QNLinearBase(const raft::handle_t& handle, int D, int C, bool fit_intercept) + : detail::objectives::QNLinearBase(C, D, fit_intercept) + { } - - using raft::solver::quasi_newton::detail::objectives::Softmax; - using raft::solver::quasi_newton::detail::objectives::QNWithData; - using raft::solver::quasi_newton::detail::objectives::QNLinearBase; - - template - OPT_RETCODE lbfgs_minimize(const LBFGSParam& param, - Function& f, // function to minimize - SimpleVec& x, // initial point, holds result - T& fx, // output function value - int* k, // output iterations - SimpleVec& workspace, // scratch space - cudaStream_t stream, - int verbosity = 0) + } + + using raft::solver::quasi_newton::detail::objectives::Softmax; + using raft::solver::quasi_newton::detail::objectives::QNLinearBase; + using raft::solver::quasi_newton::detail::objectives::QNWithData; + + template + OPT_RETCODE lbfgs_minimize(const LBFGSParam& param, + Function& f, // function to minimize + SimpleVec& x, // initial point, holds result + T& fx, // output function value + int* k, // output iterations + SimpleVec& workspace, // scratch space + cudaStream_t stream, + int verbosity = 0) template OPT_RETCODE owl_minimize(const LBFGSParam& param, - Function& f, - const T l1_penalty, - const int pg_limit, - SimpleVec& x, - T& fx, - int* k) { - - } - - template - int qn_minimize(const raft::handle_t& handle, - T *x, - T* fx, - int* num_iters, - LossFunction& loss, - const T l1, - const detail::LBFGSParam& opt_param) { - - } - + Function& f, + const T l1_penalty, + const int pg_limit, + SimpleVec& x, + T& fx, + int* k) + { + } + + template + int qn_minimize(const raft::handle_t& handle, + T* x, + T* fx, + int* num_iters, + LossFunction& loss, + const T l1, + const detail::LBFGSParam& opt_param) + { + } } \ No newline at end of file diff --git a/cpp/include/raft/solver/solver_types.hpp b/cpp/include/raft/solver/solver_types.hpp index 2fa60ce4b3..f3e7d66a27 100644 --- a/cpp/include/raft/solver/solver_types.hpp +++ b/cpp/include/raft/solver/solver_types.hpp @@ -16,185 +16,183 @@ #pragma once - - namespace raft::solver { - enum lr_type 
{ - OPTIMAL, - CONSTANT, - INVSCALING, - ADAPTIVE, - }; - - enum loss_funct { - SQRD_LOSS, - HINGE, - LOG, - }; - - enum penalty { NONE, L1, L2, ELASTICNET }; - - - enum class LarsFitStatus { kOk, kCollinear, kError, kStop }; - - - namespace quasi_newton { - - struct qn_params { - /** Loss type. */ - qn_loss_type loss; - /** Regularization: L1 component. */ - double penalty_l1; - /** Regularization: L2 component. */ - double penalty_l2; - /** Convergence criteria: the threshold on the gradient. */ - double grad_tol; - /** Convergence criteria: the threshold on the function change. */ - double change_tol; - /** Maximum number of iterations. */ - int max_iter; - /** Maximum number of linesearch (inner loop) iterations. */ - int linesearch_max_iter; - /** Number of vectors approximating the hessian (l-bfgs). */ - int lbfgs_memory; - /** Triggers extra output when greater than zero. */ - int verbose; - /** Whether to fit the bias term. */ - bool fit_intercept; - /** - * Whether to divide the L1 and L2 regularization parameters by the sample size. - * - * Note, the defined QN loss functions normally are scaled for the sample size, - * e.g. the average across the data rows is calculated. - * Enabling `penalty_normalized` makes this solver's behavior compatible to those solvers, - * which do not scale the loss functions (like sklearn.LogisticRegression()). - */ - bool penalty_normalized; - - qn_params() - : loss(QN_LOSS_UNKNOWN), - penalty_l1(0), - penalty_l2(0), - grad_tol(1e-4), - change_tol(1e-5), - max_iter(1000), - linesearch_max_iter(50), - lbfgs_memory(5), - verbose(0), - fit_intercept(true), - penalty_normalized(true) {} - }; - - enum LINE_SEARCH_ALGORITHM { - LBFGS_LS_BT_ARMIJO = 1, - LBFGS_LS_BT = 2, // Default. Alias for Wolfe - LBFGS_LS_BT_WOLFE = 2, - LBFGS_LS_BT_STRONG_WOLFE = 3 - }; - - enum LINE_SEARCH_RETCODE { - LS_SUCCESS = 0, - LS_INVALID_STEP_MIN = 1, - LS_INVALID_STEP_MAX = 2, - LS_MAX_ITERS_REACHED = 3, - LS_INVALID_DIR = 4, - LS_INVALID_STEP = 5 - }; - - enum OPT_RETCODE { - OPT_SUCCESS = 0, - OPT_NUMERIC_ERROR = 1, - OPT_LS_FAILED = 2, - OPT_MAX_ITERS_REACHED = 3, - OPT_INVALID_ARGS = 4 - }; - - template - class LBFGSParam { - public: - int m; // lbfgs memory limit - T epsilon; // controls convergence - int past; // lookback for function value based convergence test - T delta; // controls fun val based conv test - int max_iterations; - int linesearch; // see enum above - int max_linesearch; - T min_step; // min. allowed step length - T max_step; // max. allowed step length - T ftol; // line search tolerance - T wolfe; // wolfe parameter - T ls_dec; // line search decrease factor - T ls_inc; // line search increase factor - - public: - LBFGSParam() - { - m = 6; - epsilon = T(1e-5); - past = 0; - delta = T(0); - max_iterations = 0; - linesearch = LBFGS_LS_BT_ARMIJO; - max_linesearch = 20; - min_step = T(1e-20); - max_step = T(1e+20); - ftol = T(1e-4); - wolfe = T(0.9); - ls_dec = T(0.5); - ls_inc = T(2.1); - } - - explicit LBFGSParam(const qn_params& pams) : LBFGSParam() - { - m = pams.lbfgs_memory; - epsilon = T(pams.grad_tol); - // sometimes even number works better - to detect zig-zags; - past = pams.change_tol > 0 ? 10 : 0; - delta = T(pams.change_tol); - max_iterations = pams.max_iter; - max_linesearch = pams.linesearch_max_iter; - ftol = pams.change_tol > 0 ? 
T(pams.change_tol * 0.1) : T(1e-4); - } - - inline int check_param() const - { // TODO exceptions - int ret = 1; - if (m <= 0) return ret; - ret++; - if (epsilon <= 0) return ret; - ret++; - if (past < 0) return ret; - ret++; - if (delta < 0) return ret; - ret++; - if (max_iterations < 0) return ret; - ret++; - if (linesearch < LBFGS_LS_BT_ARMIJO || linesearch > LBFGS_LS_BT_STRONG_WOLFE) return ret; - ret++; - if (max_linesearch <= 0) return ret; - ret++; - if (min_step < 0) return ret; - ret++; - if (max_step < min_step) return ret; - ret++; - if (ftol <= 0 || ftol >= 0.5) return ret; - ret++; - if (wolfe <= ftol || wolfe >= 1) return ret; - ret++; - return 0; - } - }; - - struct LinearDims { - bool fit_intercept; - int C, D, dims, n_param; - LinearDims(int C, int D, bool fit_intercept) : C(C), D(D), fit_intercept(fit_intercept) - { - dims = D + fit_intercept; - n_param = dims * C; - } - }; - } - -} +enum lr_type { + OPTIMAL, + CONSTANT, + INVSCALING, + ADAPTIVE, +}; + +enum loss_funct { + SQRD_LOSS, + HINGE, + LOG, +}; + +enum penalty { NONE, L1, L2, ELASTICNET }; + +enum class LarsFitStatus { kOk, kCollinear, kError, kStop }; + +namespace quasi_newton { + +struct qn_params { + /** Loss type. */ + qn_loss_type loss; + /** Regularization: L1 component. */ + double penalty_l1; + /** Regularization: L2 component. */ + double penalty_l2; + /** Convergence criteria: the threshold on the gradient. */ + double grad_tol; + /** Convergence criteria: the threshold on the function change. */ + double change_tol; + /** Maximum number of iterations. */ + int max_iter; + /** Maximum number of linesearch (inner loop) iterations. */ + int linesearch_max_iter; + /** Number of vectors approximating the hessian (l-bfgs). */ + int lbfgs_memory; + /** Triggers extra output when greater than zero. */ + int verbose; + /** Whether to fit the bias term. */ + bool fit_intercept; + /** + * Whether to divide the L1 and L2 regularization parameters by the sample size. + * + * Note, the defined QN loss functions normally are scaled for the sample size, + * e.g. the average across the data rows is calculated. + * Enabling `penalty_normalized` makes this solver's behavior compatible to those solvers, + * which do not scale the loss functions (like sklearn.LogisticRegression()). + */ + bool penalty_normalized; + + qn_params() + : loss(QN_LOSS_UNKNOWN), + penalty_l1(0), + penalty_l2(0), + grad_tol(1e-4), + change_tol(1e-5), + max_iter(1000), + linesearch_max_iter(50), + lbfgs_memory(5), + verbose(0), + fit_intercept(true), + penalty_normalized(true) + { + } +}; + +enum LINE_SEARCH_ALGORITHM { + LBFGS_LS_BT_ARMIJO = 1, + LBFGS_LS_BT = 2, // Default. Alias for Wolfe + LBFGS_LS_BT_WOLFE = 2, + LBFGS_LS_BT_STRONG_WOLFE = 3 +}; + +enum LINE_SEARCH_RETCODE { + LS_SUCCESS = 0, + LS_INVALID_STEP_MIN = 1, + LS_INVALID_STEP_MAX = 2, + LS_MAX_ITERS_REACHED = 3, + LS_INVALID_DIR = 4, + LS_INVALID_STEP = 5 +}; + +enum OPT_RETCODE { + OPT_SUCCESS = 0, + OPT_NUMERIC_ERROR = 1, + OPT_LS_FAILED = 2, + OPT_MAX_ITERS_REACHED = 3, + OPT_INVALID_ARGS = 4 +}; + +template +class LBFGSParam { + public: + int m; // lbfgs memory limit + T epsilon; // controls convergence + int past; // lookback for function value based convergence test + T delta; // controls fun val based conv test + int max_iterations; + int linesearch; // see enum above + int max_linesearch; + T min_step; // min. allowed step length + T max_step; // max. 
allowed step length + T ftol; // line search tolerance + T wolfe; // wolfe parameter + T ls_dec; // line search decrease factor + T ls_inc; // line search increase factor + + public: + LBFGSParam() + { + m = 6; + epsilon = T(1e-5); + past = 0; + delta = T(0); + max_iterations = 0; + linesearch = LBFGS_LS_BT_ARMIJO; + max_linesearch = 20; + min_step = T(1e-20); + max_step = T(1e+20); + ftol = T(1e-4); + wolfe = T(0.9); + ls_dec = T(0.5); + ls_inc = T(2.1); + } + + explicit LBFGSParam(const qn_params& pams) : LBFGSParam() + { + m = pams.lbfgs_memory; + epsilon = T(pams.grad_tol); + // sometimes even number works better - to detect zig-zags; + past = pams.change_tol > 0 ? 10 : 0; + delta = T(pams.change_tol); + max_iterations = pams.max_iter; + max_linesearch = pams.linesearch_max_iter; + ftol = pams.change_tol > 0 ? T(pams.change_tol * 0.1) : T(1e-4); + } + + inline int check_param() const + { // TODO exceptions + int ret = 1; + if (m <= 0) return ret; + ret++; + if (epsilon <= 0) return ret; + ret++; + if (past < 0) return ret; + ret++; + if (delta < 0) return ret; + ret++; + if (max_iterations < 0) return ret; + ret++; + if (linesearch < LBFGS_LS_BT_ARMIJO || linesearch > LBFGS_LS_BT_STRONG_WOLFE) return ret; + ret++; + if (max_linesearch <= 0) return ret; + ret++; + if (min_step < 0) return ret; + ret++; + if (max_step < min_step) return ret; + ret++; + if (ftol <= 0 || ftol >= 0.5) return ret; + ret++; + if (wolfe <= ftol || wolfe >= 1) return ret; + ret++; + return 0; + } +}; + +struct LinearDims { + bool fit_intercept; + int C, D, dims, n_param; + LinearDims(int C, int D, bool fit_intercept) : C(C), D(D), fit_intercept(fit_intercept) + { + dims = D + fit_intercept; + n_param = dims * C; + } +}; +} // namespace quasi_newton + +} // namespace raft::solver From 66e128168a6d7d7626e60ddf5e9deee2ec63688d Mon Sep 17 00:00:00 2001 From: "Corey J. 
Nolet" Date: Thu, 20 Oct 2022 13:45:40 -0400 Subject: [PATCH 21/35] Checking in --- .../raft/solver/detail/qn/objectives/base.cuh | 2 +- .../solver/detail/qn/objectives/hinge.cuh | 2 +- .../solver/detail/qn/objectives/linear.cuh | 2 +- .../solver/detail/qn/objectives/logistic.cuh | 2 +- .../detail/qn/objectives/regularizer.cuh | 2 +- .../solver/detail/qn/objectives/softmax.cuh | 2 +- .../raft/solver/detail/qn/qn_solvers.cuh | 2 +- .../raft/solver/detail/qn/simple_mat.cuh | 20 - .../raft/solver/detail/qn/simple_mat/base.hpp | 54 -- .../solver/detail/qn/simple_mat/dense.hpp | 413 --------- .../solver/detail/qn/simple_mat/sparse.hpp | 216 ----- cpp/include/raft/solver/quasi_newton.cuh | 198 ++++- cpp/include/raft/solver/simple_mat.cuh | 622 +++++++++++++ cpp/test/CMakeLists.txt | 3 +- cpp/test/{lap => solver}/lap.cu | 0 cpp/test/solver/quasi_newton.cu | 833 ++++++++++++++++++ 16 files changed, 1626 insertions(+), 747 deletions(-) delete mode 100644 cpp/include/raft/solver/detail/qn/simple_mat.cuh delete mode 100644 cpp/include/raft/solver/detail/qn/simple_mat/base.hpp delete mode 100644 cpp/include/raft/solver/detail/qn/simple_mat/dense.hpp delete mode 100644 cpp/include/raft/solver/detail/qn/simple_mat/sparse.hpp create mode 100644 cpp/include/raft/solver/simple_mat.cuh rename cpp/test/{lap => solver}/lap.cu (100%) create mode 100644 cpp/test/solver/quasi_newton.cu diff --git a/cpp/include/raft/solver/detail/qn/objectives/base.cuh b/cpp/include/raft/solver/detail/qn/objectives/base.cuh index 0dbc79807e..b3d60637a1 100644 --- a/cpp/include/raft/solver/detail/qn/objectives/base.cuh +++ b/cpp/include/raft/solver/detail/qn/objectives/base.cuh @@ -16,7 +16,7 @@ #pragma once -#include "../simple_mat.cuh" +#include #include #include #include diff --git a/cpp/include/raft/solver/detail/qn/objectives/hinge.cuh b/cpp/include/raft/solver/detail/qn/objectives/hinge.cuh index c8effc6b7a..dcab7543fc 100644 --- a/cpp/include/raft/solver/detail/qn/objectives/hinge.cuh +++ b/cpp/include/raft/solver/detail/qn/objectives/hinge.cuh @@ -16,7 +16,7 @@ #pragma once -#include "../simple_mat.cuh" +#include #include "base.cuh" #include #include diff --git a/cpp/include/raft/solver/detail/qn/objectives/linear.cuh b/cpp/include/raft/solver/detail/qn/objectives/linear.cuh index 731a47b886..af6094e00f 100644 --- a/cpp/include/raft/solver/detail/qn/objectives/linear.cuh +++ b/cpp/include/raft/solver/detail/qn/objectives/linear.cuh @@ -16,7 +16,7 @@ #pragma once -#include "../simple_mat.cuh" +#include #include "base.cuh" #include #include diff --git a/cpp/include/raft/solver/detail/qn/objectives/logistic.cuh b/cpp/include/raft/solver/detail/qn/objectives/logistic.cuh index ea2c25bf6f..acb7c1ac55 100644 --- a/cpp/include/raft/solver/detail/qn/objectives/logistic.cuh +++ b/cpp/include/raft/solver/detail/qn/objectives/logistic.cuh @@ -16,7 +16,7 @@ #pragma once -#include "../simple_mat.cuh" +#include #include "base.cuh" #include #include diff --git a/cpp/include/raft/solver/detail/qn/objectives/regularizer.cuh b/cpp/include/raft/solver/detail/qn/objectives/regularizer.cuh index e4acf76672..7bb509a934 100644 --- a/cpp/include/raft/solver/detail/qn/objectives/regularizer.cuh +++ b/cpp/include/raft/solver/detail/qn/objectives/regularizer.cuh @@ -16,7 +16,7 @@ #pragma once -#include "../simple_mat.cuh" +#include #include "base.cuh" #include #include diff --git a/cpp/include/raft/solver/detail/qn/objectives/softmax.cuh b/cpp/include/raft/solver/detail/qn/objectives/softmax.cuh index 70eeb9d6e6..2e53881c2a 100644 --- 
a/cpp/include/raft/solver/detail/qn/objectives/softmax.cuh +++ b/cpp/include/raft/solver/detail/qn/objectives/softmax.cuh @@ -16,7 +16,7 @@ #pragma once -#include "../simple_mat.cuh" +#include #include "base.cuh" #include #include diff --git a/cpp/include/raft/solver/detail/qn/qn_solvers.cuh b/cpp/include/raft/solver/detail/qn/qn_solvers.cuh index ce64590127..833c1170ef 100644 --- a/cpp/include/raft/solver/detail/qn/qn_solvers.cuh +++ b/cpp/include/raft/solver/detail/qn/qn_solvers.cuh @@ -42,7 +42,7 @@ #include "qn_linesearch.cuh" #include "qn_util.cuh" -#include "simple_mat.cuh" +#include #include #include #include diff --git a/cpp/include/raft/solver/detail/qn/simple_mat.cuh b/cpp/include/raft/solver/detail/qn/simple_mat.cuh deleted file mode 100644 index f455f6a1e1..0000000000 --- a/cpp/include/raft/solver/detail/qn/simple_mat.cuh +++ /dev/null @@ -1,20 +0,0 @@ -/* - * Copyright (c) 2018-2021, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include "simple_mat/base.hpp" -#include "simple_mat/dense.hpp" -#include "simple_mat/sparse.hpp" diff --git a/cpp/include/raft/solver/detail/qn/simple_mat/base.hpp b/cpp/include/raft/solver/detail/qn/simple_mat/base.hpp deleted file mode 100644 index 2e9ae5dfcd..0000000000 --- a/cpp/include/raft/solver/detail/qn/simple_mat/base.hpp +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Copyright (c) 2018-2022, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include -#include -#include - -namespace raft::solver::detail { - -template -struct SimpleDenseMat; - -template -struct SimpleMat { - int m, n; - - SimpleMat(int m, int n) : m(m), n(n) {} - - void operator=(const SimpleMat& other) = delete; - - virtual void print(std::ostream& oss) const = 0; - - /** - * GEMM assigning to C where `this` refers to B. - * - * ``` - * C <- alpha * A^transA * (*this)^transB + beta * C - * ``` - */ - virtual void gemmb(const raft::handle_t& handle, - const T alpha, - const SimpleDenseMat& A, - const bool transA, - const bool transB, - const T beta, - SimpleDenseMat& C, - cudaStream_t stream) const = 0; -}; - -}; // namespace raft::solver::detail diff --git a/cpp/include/raft/solver/detail/qn/simple_mat/dense.hpp b/cpp/include/raft/solver/detail/qn/simple_mat/dense.hpp deleted file mode 100644 index 971737a259..0000000000 --- a/cpp/include/raft/solver/detail/qn/simple_mat/dense.hpp +++ /dev/null @@ -1,413 +0,0 @@ -/* - * Copyright (c) 2018-2022, NVIDIA CORPORATION. 
- * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include -#include - -#include "base.hpp" -#include -#include -#include -#include -#include -// #TODO: Replace with public header when ready -#include -#include -#include -#include -#include - -namespace raft::solver::detail { - -enum STORAGE_ORDER { COL_MAJOR = 0, ROW_MAJOR = 1 }; - -template -struct SimpleDenseMat : SimpleMat { - typedef SimpleMat Super; - int len; - T* data; - - STORAGE_ORDER ord; // storage order: runtime param for compile time sake - - SimpleDenseMat(STORAGE_ORDER order = COL_MAJOR) : Super(0, 0), data(nullptr), len(0), ord(order) - { - } - - SimpleDenseMat(T* data, int m, int n, STORAGE_ORDER order = COL_MAJOR) - : Super(m, n), data(data), len(m * n), ord(order) - { - } - - void reset(T* data_, int m_, int n_) - { - this->m = m_; - this->n = n_; - data = data_; - len = m_ * n_; - } - - // Implemented GEMM as a static method here to improve readability - inline static void gemm(const raft::handle_t& handle, - const T alpha, - const SimpleDenseMat& A, - const bool transA, - const SimpleDenseMat& B, - const bool transB, - const T beta, - SimpleDenseMat& C, - cudaStream_t stream) - { - int kA = A.n; - int kB = B.m; - - if (transA) { - ASSERT(A.n == C.m, "GEMM invalid dims: m"); - kA = A.m; - } else { - ASSERT(A.m == C.m, "GEMM invalid dims: m"); - } - - if (transB) { - ASSERT(B.m == C.n, "GEMM invalid dims: n"); - kB = B.n; - } else { - ASSERT(B.n == C.n, "GEMM invalid dims: n"); - } - ASSERT(kA == kB, "GEMM invalid dims: k"); - - if (A.ord == COL_MAJOR && B.ord == COL_MAJOR && C.ord == COL_MAJOR) { - // #TODO: Call from public API when ready - raft::linalg::detail::cublasgemm(handle.get_cublas_handle(), // handle - transA ? CUBLAS_OP_T : CUBLAS_OP_N, // transA - transB ? CUBLAS_OP_T : CUBLAS_OP_N, // transB - C.m, - C.n, - kA, // dimensions m,n,k - &alpha, - A.data, - A.m, // lda - B.data, - B.m, // ldb - &beta, - C.data, - C.m, // ldc, - stream); - return; - } - if (A.ord == ROW_MAJOR) { - const SimpleDenseMat Acm(A.data, A.n, A.m, COL_MAJOR); - gemm(handle, alpha, Acm, !transA, B, transB, beta, C, stream); - return; - } - if (B.ord == ROW_MAJOR) { - const SimpleDenseMat Bcm(B.data, B.n, B.m, COL_MAJOR); - gemm(handle, alpha, A, transA, Bcm, !transB, beta, C, stream); - return; - } - if (C.ord == ROW_MAJOR) { - SimpleDenseMat Ccm(C.data, C.n, C.m, COL_MAJOR); - gemm(handle, alpha, B, !transB, A, !transA, beta, Ccm, stream); - return; - } - } - - inline void gemmb(const raft::handle_t& handle, - const T alpha, - const SimpleDenseMat& A, - const bool transA, - const bool transB, - const T beta, - SimpleDenseMat& C, - cudaStream_t stream) const override - { - SimpleDenseMat::gemm(handle, alpha, A, transA, *this, transB, beta, C, stream); - } - - /** - * GEMM assigning to C where `this` refers to C. 
- * - * ``` - * *this <- alpha * A^transA * B^transB + beta * (*this) - * ``` - */ - inline void assign_gemm(const raft::handle_t& handle, - const T alpha, - const SimpleDenseMat& A, - const bool transA, - const SimpleMat& B, - const bool transB, - const T beta, - cudaStream_t stream) - { - B.gemmb(handle, alpha, A, transA, transB, beta, *this, stream); - } - - // this = a*x - inline void ax(const T a, const SimpleDenseMat& x, cudaStream_t stream) - { - ASSERT(ord == x.ord, "SimpleDenseMat::ax: Storage orders must match"); - - auto scale = [a] __device__(const T x) { return a * x; }; - raft::linalg::unaryOp(data, x.data, len, scale, stream); - } - - // this = a*x + y - inline void axpy(const T a, - const SimpleDenseMat& x, - const SimpleDenseMat& y, - cudaStream_t stream) - { - ASSERT(ord == x.ord, "SimpleDenseMat::axpy: Storage orders must match"); - ASSERT(ord == y.ord, "SimpleDenseMat::axpy: Storage orders must match"); - - auto axpy = [a] __device__(const T x, const T y) { return a * x + y; }; - raft::linalg::binaryOp(data, x.data, y.data, len, axpy, stream); - } - - template - inline void assign_unary(const SimpleDenseMat& other, Lambda f, cudaStream_t stream) - { - ASSERT(ord == other.ord, "SimpleDenseMat::assign_unary: Storage orders must match"); - - raft::linalg::unaryOp(data, other.data, len, f, stream); - } - - template - inline void assign_binary(const SimpleDenseMat& other1, - const SimpleDenseMat& other2, - Lambda& f, - cudaStream_t stream) - { - ASSERT(ord == other1.ord, "SimpleDenseMat::assign_binary: Storage orders must match"); - ASSERT(ord == other2.ord, "SimpleDenseMat::assign_binary: Storage orders must match"); - - raft::linalg::binaryOp(data, other1.data, other2.data, len, f, stream); - } - - template - inline void assign_ternary(const SimpleDenseMat& other1, - const SimpleDenseMat& other2, - const SimpleDenseMat& other3, - Lambda& f, - cudaStream_t stream) - { - ASSERT(ord == other1.ord, "SimpleDenseMat::assign_ternary: Storage orders must match"); - ASSERT(ord == other2.ord, "SimpleDenseMat::assign_ternary: Storage orders must match"); - ASSERT(ord == other3.ord, "SimpleDenseMat::assign_ternary: Storage orders must match"); - - raft::linalg::ternaryOp(data, other1.data, other2.data, other3.data, len, f, stream); - } - - inline void fill(const T val, cudaStream_t stream) - { - // TODO this reads data unnecessary, though it's mostly used for testing - auto f = [val] __device__(const T x) { return val; }; - raft::linalg::unaryOp(data, data, len, f, stream); - } - - inline void copy_async(const SimpleDenseMat& other, cudaStream_t stream) - { - ASSERT((ord == other.ord) && (this->m == other.m) && (this->n == other.n), - "SimpleDenseMat::copy: matrices not compatible"); - - RAFT_CUDA_TRY( - cudaMemcpyAsync(data, other.data, len * sizeof(T), cudaMemcpyDeviceToDevice, stream)); - } - - void print(std::ostream& oss) const override { oss << (*this) << std::endl; } - - void operator=(const SimpleDenseMat& other) = delete; -}; - -template -struct SimpleVec : SimpleDenseMat { - typedef SimpleDenseMat Super; - - SimpleVec(T* data, const int n) : Super(data, n, 1, COL_MAJOR) {} - // this = alpha * A * x + beta * this - void assign_gemv(const raft::handle_t& handle, - const T alpha, - const SimpleDenseMat& A, - bool transA, - const SimpleVec& x, - const T beta, - cudaStream_t stream) - { - Super::assign_gemm(handle, alpha, A, transA, x, false, beta, stream); - } - - SimpleVec() : Super(COL_MAJOR) {} - - inline void reset(T* new_data, int n) { Super::reset(new_data, n, 1); } -}; - 
-template -inline void col_ref(const SimpleDenseMat& mat, SimpleVec& mask_vec, int c) -{ - ASSERT(mat.ord == COL_MAJOR, "col_ref only available for column major mats"); - T* tmp = &mat.data[mat.m * c]; - mask_vec.reset(tmp, mat.m); -} - -template -inline void col_slice(const SimpleDenseMat& mat, - SimpleDenseMat& mask_mat, - int c_from, - int c_to) -{ - ASSERT(c_from >= 0 && c_from < mat.n, "col_slice: invalid from"); - ASSERT(c_to >= 0 && c_to <= mat.n, "col_slice: invalid to"); - - ASSERT(mat.ord == COL_MAJOR, "col_ref only available for column major mats"); - ASSERT(mask_mat.ord == COL_MAJOR, "col_ref only available for column major mask"); - T* tmp = &mat.data[mat.m * c_from]; - mask_mat.reset(tmp, mat.m, c_to - c_from); -} - -// Reductions such as dot or norm require an additional location in dev mem -// to hold the result. We don't want to deal with this in the SimpleVec class -// as it impedes thread safety and constness - -template -inline T dot(const SimpleVec& u, const SimpleVec& v, T* tmp_dev, cudaStream_t stream) -{ - auto f = [] __device__(const T x, const T y) { return x * y; }; - raft::linalg::mapThenSumReduce(tmp_dev, u.len, f, stream, u.data, v.data); - T tmp_host; - raft::update_host(&tmp_host, tmp_dev, 1, stream); - - raft::interruptible::synchronize(stream); - return tmp_host; -} - -template -inline T squaredNorm(const SimpleVec& u, T* tmp_dev, cudaStream_t stream) -{ - return dot(u, u, tmp_dev, stream); -} - -template -inline T nrmMax(const SimpleVec& u, T* tmp_dev, cudaStream_t stream) -{ - auto f = [] __device__(const T x) { return raft::myAbs(x); }; - auto r = [] __device__(const T x, const T y) { return raft::myMax(x, y); }; - raft::linalg::mapThenReduce(tmp_dev, u.len, T(0), f, r, stream, u.data); - T tmp_host; - raft::update_host(&tmp_host, tmp_dev, 1, stream); - raft::interruptible::synchronize(stream); - return tmp_host; -} - -template -inline T nrm2(const SimpleVec& u, T* tmp_dev, cudaStream_t stream) -{ - return raft::mySqrt(squaredNorm(u, tmp_dev, stream)); -} - -template -inline T nrm1(const SimpleVec& u, T* tmp_dev, cudaStream_t stream) -{ - raft::linalg::rowNorm( - tmp_dev, u.data, u.len, 1, raft::linalg::L1Norm, true, stream, raft::Nop()); - T tmp_host; - raft::update_host(&tmp_host, tmp_dev, 1, stream); - raft::interruptible::synchronize(stream); - return tmp_host; -} - -template -std::ostream& operator<<(std::ostream& os, const SimpleVec& v) -{ - std::vector out(v.len); - raft::update_host(&out[0], v.data, v.len, 0); - raft::interruptible::synchronize(rmm::cuda_stream_view()); - int it = 0; - for (; it < v.len - 1;) { - os << out[it] << " "; - it++; - } - os << out[it]; - return os; -} - -template -std::ostream& operator<<(std::ostream& os, const SimpleDenseMat& mat) -{ - os << "ord=" << (mat.ord == COL_MAJOR ? 
"CM" : "RM") << "\n"; - std::vector out(mat.len); - raft::update_host(&out[0], mat.data, mat.len, rmm::cuda_stream_default); - raft::interruptible::synchronize(rmm::cuda_stream_view()); - if (mat.ord == COL_MAJOR) { - for (int r = 0; r < mat.m; r++) { - int idx = r; - for (int c = 0; c < mat.n - 1; c++) { - os << out[idx] << ","; - idx += mat.m; - } - os << out[idx] << std::endl; - } - } else { - for (int c = 0; c < mat.m; c++) { - int idx = c * mat.n; - for (int r = 0; r < mat.n - 1; r++) { - os << out[idx] << ","; - idx += 1; - } - os << out[idx] << std::endl; - } - } - - return os; -} - -template -struct SimpleVecOwning : SimpleVec { - typedef SimpleVec Super; - typedef rmm::device_uvector Buffer; - Buffer buf; - - SimpleVecOwning() = delete; - - SimpleVecOwning(int n, cudaStream_t stream) : Super(), buf(n, stream) - { - Super::reset(buf.data(), n); - } - - void operator=(const SimpleVec& other) = delete; -}; - -template -struct SimpleMatOwning : SimpleDenseMat { - typedef SimpleDenseMat Super; - typedef rmm::device_uvector Buffer; - Buffer buf; - using Super::m; - using Super::n; - using Super::ord; - - SimpleMatOwning() = delete; - - SimpleMatOwning(int m, int n, cudaStream_t stream, STORAGE_ORDER order = COL_MAJOR) - : Super(order), buf(m * n, stream) - { - Super::reset(buf.data(), m, n); - } - - void operator=(const SimpleVec& other) = delete; -}; - -}; // namespace raft::solver::detail diff --git a/cpp/include/raft/solver/detail/qn/simple_mat/sparse.hpp b/cpp/include/raft/solver/detail/qn/simple_mat/sparse.hpp deleted file mode 100644 index 83734b5b7f..0000000000 --- a/cpp/include/raft/solver/detail/qn/simple_mat/sparse.hpp +++ /dev/null @@ -1,216 +0,0 @@ -/* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include -#include - -#include "base.hpp" -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include - -#include - -namespace raft::solver::detail { - -/** - * Sparse matrix in CSR format. - * - * Note, we use cuSPARSE to manimulate matrices, and it guarantees: - * - * 1. row_ids[m] == nnz - * 2. cols are sorted within rows. - * - * However, when the data comes from the outside, we cannot guarantee that. 
- */ -template -struct SimpleSparseMat : SimpleMat { - typedef SimpleMat Super; - T* values; - int* cols; - int* row_ids; - int nnz; - - SimpleSparseMat() : Super(0, 0), values(nullptr), cols(nullptr), row_ids(nullptr), nnz(0) {} - - SimpleSparseMat(T* values, int* cols, int* row_ids, int nnz, int m, int n) - : Super(m, n), values(values), cols(cols), row_ids(row_ids), nnz(nnz) - { - check_csr(*this, 0); - } - - void print(std::ostream& oss) const override { oss << (*this) << std::endl; } - - void operator=(const SimpleSparseMat& other) = delete; - - inline void gemmb(const raft::handle_t& handle, - const T alpha, - const SimpleDenseMat& A, - const bool transA, - const bool transB, - const T beta, - SimpleDenseMat& C, - cudaStream_t stream) const override - { - const SimpleSparseMat& B = *this; - int kA = A.n; - int kB = B.m; - - if (transA) { - ASSERT(A.n == C.m, "GEMM invalid dims: m"); - kA = A.m; - } else { - ASSERT(A.m == C.m, "GEMM invalid dims: m"); - } - - if (transB) { - ASSERT(B.m == C.n, "GEMM invalid dims: n"); - kB = B.n; - } else { - ASSERT(B.n == C.n, "GEMM invalid dims: n"); - } - ASSERT(kA == kB, "GEMM invalid dims: k"); - - // matrix C must change the order and be transposed, because we need - // to swap arguments A and B in cusparseSpMM. - cusparseDnMatDescr_t descrC; - auto order = C.ord == COL_MAJOR ? CUSPARSE_ORDER_ROW : CUSPARSE_ORDER_COL; - RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsecreatednmat( - &descrC, C.n, C.m, order == CUSPARSE_ORDER_COL ? C.n : C.m, C.data, order)); - - /* - The matrix A must have the same order as the matrix C in the input - of function cusparseSpMM (i.e. swapped order w.r.t. original C). - To account this requirement, I may need to flip transA (whether to transpose A). - - C C' rowsC' colsC' ldC' A A' rowsA' colsA' ldA' flipTransA - c r n m m c r n m m x - c r n m m r r m n n o - r c n m n c c m n m o - r c n m n r c n m n x - - where: - c/r - column/row major order - A,C - input to gemmb - A', C' - input to cusparseSpMM - ldX' - leading dimension - m or n, depending on order and transX - */ - cusparseDnMatDescr_t descrA; - RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsecreatednmat(&descrA, - C.ord == A.ord ? A.n : A.m, - C.ord == A.ord ? A.m : A.n, - A.ord == COL_MAJOR ? A.m : A.n, - A.data, - order)); - auto opA = - transA ^ (C.ord == A.ord) ? CUSPARSE_OPERATION_NON_TRANSPOSE : CUSPARSE_OPERATION_TRANSPOSE; - - cusparseSpMatDescr_t descrB; - RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsecreatecsr( - &descrB, B.m, B.n, B.nnz, B.row_ids, B.cols, B.values)); - auto opB = transB ? CUSPARSE_OPERATION_NON_TRANSPOSE : CUSPARSE_OPERATION_TRANSPOSE; - - auto alg = order == CUSPARSE_ORDER_COL ? 
CUSPARSE_SPMM_CSR_ALG1 : CUSPARSE_SPMM_CSR_ALG2; - - size_t bufferSize; - RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsespmm_bufferSize(handle.get_cusparse_handle(), - opB, - opA, - &alpha, - descrB, - descrA, - &beta, - descrC, - alg, - &bufferSize, - stream)); - - raft::interruptible::synchronize(stream); - rmm::device_uvector tmp(bufferSize, stream); - - RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsespmm(handle.get_cusparse_handle(), - opB, - opA, - &alpha, - descrB, - descrA, - &beta, - descrC, - alg, - tmp.data(), - stream)); - - RAFT_CUSPARSE_TRY(cusparseDestroyDnMat(descrA)); - RAFT_CUSPARSE_TRY(cusparseDestroySpMat(descrB)); - RAFT_CUSPARSE_TRY(cusparseDestroyDnMat(descrC)); - } -}; - -template -inline void check_csr(const SimpleSparseMat& mat, cudaStream_t stream) -{ - int row_ids_nnz; - raft::update_host(&row_ids_nnz, &mat.row_ids[mat.m], 1, stream); - raft::interruptible::synchronize(stream); - ASSERT(row_ids_nnz == mat.nnz, - "SimpleSparseMat: the size of CSR row_ids array must be `m + 1`, and " - "the last element must be equal nnz."); -} - -template -std::ostream& operator<<(std::ostream& os, const SimpleSparseMat& mat) -{ - check_csr(mat, 0); - os << "SimpleSparseMat (CSR)" - << "\n"; - std::vector values(mat.nnz); - std::vector cols(mat.nnz); - std::vector row_ids(mat.m + 1); - raft::update_host(&values[0], mat.values, mat.nnz, rmm::cuda_stream_default); - raft::update_host(&cols[0], mat.cols, mat.nnz, rmm::cuda_stream_default); - raft::update_host(&row_ids[0], mat.row_ids, mat.m + 1, rmm::cuda_stream_default); - raft::interruptible::synchronize(rmm::cuda_stream_view()); - - int i, row_end = 0; - for (int row = 0; row < mat.m; row++) { - i = row_end; - row_end = row_ids[row + 1]; - for (int col = 0; col < mat.n; col++) { - if (i >= row_end || col < cols[i]) { - os << "0"; - } else { - os << values[i]; - i++; - } - if (col < mat.n - 1) os << ","; - } - - os << std::endl; - } - - return os; -} - -}; // namespace raft::solver::detail diff --git a/cpp/include/raft/solver/quasi_newton.cuh b/cpp/include/raft/solver/quasi_newton.cuh index 090a1a13db..3c3f7b9aa7 100644 --- a/cpp/include/raft/solver/quasi_newton.cuh +++ b/cpp/include/raft/solver/quasi_newton.cuh @@ -23,18 +23,18 @@ #include #include -#include +#include +#include #include namespace raft::solver::quasi_newton { /** - * The follow loss functions are wrapped so they will be included in the docs - * @tparam T + * The following loss functions are wrapped only so they will be included in the docs */ /** - * + * Absolute difference loss function specification * @tparam T */ template @@ -45,8 +45,19 @@ struct AbsLoss : detail::objectives::AbsLoss { } } + /** + * Squared loss function specification + * @tparam T + */ + template + struct SquaredLoss : detail::objectives::SquaredLoss { + SquaredLoss(const raft::handle_t &handle, int D, bool has_bias) + : detail::objectives::SquaredLoss(handle, D, 1, has_bias), lz{}, dlz{} {} + } + + /** - * + * Standard hinge loss function specification * @tparam T */ template @@ -70,7 +81,7 @@ struct LogisticLoss : detail::objectives::LogisticLoss { } /** - * + * Squared hinge loss function specification * @tparam T */ template @@ -81,19 +92,50 @@ struct SqHingeLoss : detail::objectives::SqHingeLoss { } } +/** + * Epsilon insensitive (regression) hinge loss function specification + * @tparam T + */ template -struct EpsInsHingeLoss : QNLinearBase> { - typedef QNLinearBase> Super; - EpsInsHingeLoss(const raft::handle_t& handle, int D, bool has_bias, T sensitivity) - : Super(handle, D, 1, 
has_bias), lz{sensitivity}, dlz{sensitivity} - { - } +struct EpsInsHingeLoss : detail::objectives::EpsInsHingeLoss> { + EpsInsHingeLoss(const raft::handle_t &handle, int D, bool has_bias, T sensitivity) + : detail::objectives::EpsInsHingeLoss(handle, D, 1, has_bias), lz{sensitivity}, dlz{sensitivity} { + } +} - using raft::solver::quasi_newton::detail::LBFGSParam; - using raft::solver::quasi_newton::detail::objectives::EpsInsHingeLoss; - using raft::solver::quasi_newton::detail::objectives::SqEpsInsHingeLoss; +/** + * Squared Epsilon insensitive (regression) hinge loss function specification + * @tparam T + */ + template + struct SqEpsInsHingeLoss : detail::objectives::SqEpsInsHingeLoss { + SqEpsInsHingeLoss(const raft::handle_t &handle, int D, bool has_bias, T sensitivity) + : detail::objectives::SqEpsInsHingeLoss(handle, D, 1, has_bias), lz{sensitivity}, dlz{sensitivity} { + } + }; /** + * Tikhonov (l2) penalty function + * @tparam T + */ +template +struct Tikhonov : detail::objectives::Tikhonov { + Tikhonov(T l2) : detail::objectives::Tikhonov(l2) {} + + Tikhonov(const Tikhonov &other) : detail::objectives::Tikhonov(other.l2_penalty) {} +}; + + + + /** + * Loss function wrapper that add a penalty to another loss function + * + * Example: + * + * raft::handle_t handle; + * AbsLoss abs_loss(handle, 5, true); + * Tikhonov l2_reg(0.3); + * RegularizedQN(&abs_loss, ®); * * @tparam T * @tparam Loss @@ -105,7 +147,7 @@ struct EpsInsHingeLoss : QNLinearBase> { }; /** - * + * Base loss function that constrains the solution to a linear system * @tparam T * @tparam Loss */ @@ -117,39 +159,123 @@ struct EpsInsHingeLoss : QNLinearBase> { } } - using raft::solver::quasi_newton::detail::objectives::Softmax; - using raft::solver::quasi_newton::detail::objectives::QNLinearBase; - using raft::solver::quasi_newton::detail::objectives::QNWithData; + /** + * Softmax loss function specification + * @tparam T + */ + template + struct Softmax : detail::objectives::Softmax { + Softmax(const raft::handle_t &handle, int D, int C, bool has_bias) + : detail::objectives::Softmax(handle, D, C, has_bias) { + } + } + + /** + * Constructs a end-to-end quasi-newton objective function to solve the system + * AX = b (where each row in X contains the coefficients for each target) + * + * Example: + * + * @tparam T + * @tparam QuasiNewtonObjective + */ + template + struct ObjectiveWithData : detail::objectives::QNWithData { + + ObjectiveWithData(QuasiNewtonObjective *obj, + const SimpleMat &A, + const SimpleVec &b, + SimpleDenseMat &X) + : detail::objectives::QNWithData(obj->C, obj->D, obj->fit_intercept) { + } + } + /** + * @brief Minimize the given `raft::solver::quasi_newton::ObjectiveWithData` using + * the Limited-Memory Broyden-Fletcher-Goldfarb-Shanno algorithm. This algorithm + * estimates the inverse of the Hessian matrix, minimizing the memory footprint from + * the original BFGS algorithm by maintaining only a subset of the update history. 
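+ *
+ * A minimal usage sketch (illustrative only, not a verbatim API example; `lossWith`
+ * stands for an ObjectiveWithData built as shown above, `pams` for a filled qn_params,
+ * and `w_data`/`n_param` are placeholder names for the parameter buffer and its size):
+ *
+ *   LBFGSParam<double> opt_param(pams);
+ *   SimpleVec<double> w(w_data, n_param);  // initial point, overwritten with the solution
+ *   double fx;                             // final objective value
+ *   int n_iters = 0;                       // iterations taken
+ *   OPT_RETCODE ret = lbfgs_minimize(handle, opt_param, lossWith, w, fx, &n_iters);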
+ * + * @tparam T + * @tparam Function + * @param param + * @param f + * @param x + * @param fx + * @param k + * @param workspace + * @param stream + * @param verbosity + * @return + */ template - OPT_RETCODE lbfgs_minimize(const LBFGSParam& param, + OPT_RETCODE lbfgs_minimize(raft::handle_t &handle, + const LBFGSParam& param, Function& f, // function to minimize SimpleVec& x, // initial point, holds result T& fx, // output function value - int* k, // output iterations - SimpleVec& workspace, // scratch space - cudaStream_t stream, - int verbosity = 0) + int* k) { // output iterations + rmm::device_uvector tmp(detail::lbfgs_workspace_size(param, x.len), handle.get_stream()); + SimpleVec workspace(tmp.data(), tmp.size()); + return detail::min_lbfgs(param, f, x, fx, k, workspace, handle.get_stream(), 0); + } + /** + * @brief Minimize the given `ObjectiveWithData` using the Orthant-wise + * Limited-Memory Quasi-Newton algorithm, an L-BFGS variant for fitting + * models with lasso (l1) penalties, enabling it to exploit the sparsity + * of the models. + * + * @tparam T + * @tparam Function + * @param param + * @param f + * @param l1_penalty + * @param pg_limit + * @param x + * @param fx + * @param k + * @return + */ template - OPT_RETCODE owl_minimize(const LBFGSParam& param, + OPT_RETCODE owl_minimize(raft::handle_t &handle, + const LBFGSParam& param, Function& f, const T l1_penalty, const int pg_limit, SimpleVec& x, T& fx, - int* k) - { - } + int* k) { + rmm::device_uvector tmp(detail::owlqn_workspace_size(opt_param, x.len), stream); + SimpleVec workspace(tmp.data(), tmp.size()); + return detail::min_owlqn(param, f, l1_penalty, pg_limit, x, fx, k, workspace, handle.get_stream(), 0); + } + + /** + * @brief Simple wrapper function that chooses the quasi-newton solver to use + * based on the presence of the L1 penalty term. + * @tparam T + * @tparam LossFunction + * @param handle + * @param x + * @param fx + * @param num_iters + * @param loss + * @param l1 + * @param opt_param + * @return + */ template - int qn_minimize(const raft::handle_t& handle, - T* x, - T* fx, - int* num_iters, - LossFunction& loss, - const T l1, - const detail::LBFGSParam& opt_param) - { + inline int qn_minimize(const raft::handle_t& handle, + SimpleVec& x, + T* fx, + int* num_iters, + LossFunction& loss, + const T l1, + const LBFGSParam& opt_param, + cudaStream_t stream, + const int verbosity = 0) { + return detail::qn_minimize(handle, x, fx, num_iters, loss, l1, opt_param, handle.get_stream(), 0); } } \ No newline at end of file diff --git a/cpp/include/raft/solver/simple_mat.cuh b/cpp/include/raft/solver/simple_mat.cuh new file mode 100644 index 0000000000..1a80dbc78e --- /dev/null +++ b/cpp/include/raft/solver/simple_mat.cuh @@ -0,0 +1,622 @@ +/* + * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include +#include +#include + +#include +#include +#include +#include +// #TODO: Replace with public header when ready +#include +#include +#include +#include +#include + +/** + * NOTE: This will eventually get replaced with mdspan/mdarray + */ + +namespace raft::solver { + + template + struct SimpleMat { + int m, n; + + SimpleMat(int m, int n) : m(m), n(n) {} + + void operator=(const SimpleMat& other) = delete; + + virtual void print(std::ostream& oss) const = 0; + + /** + * GEMM assigning to C where `this` refers to B. + * + * ``` + * C <- alpha * A^transA * (*this)^transB + beta * C + * ``` + */ + virtual void gemmb(const raft::handle_t& handle, + const T alpha, + const SimpleDenseMat& A, + const bool transA, + const bool transB, + const T beta, + SimpleDenseMat& C, + cudaStream_t stream) const = 0; + }; + + enum STORAGE_ORDER { COL_MAJOR = 0, ROW_MAJOR = 1 }; + + template + struct SimpleDenseMat : SimpleMat { + typedef SimpleMat Super; + int len; + T* data; + + STORAGE_ORDER ord; // storage order: runtime param for compile time sake + + SimpleDenseMat(STORAGE_ORDER order = COL_MAJOR) : Super(0, 0), data(nullptr), len(0), ord(order) + { + } + + SimpleDenseMat(T* data, int m, int n, STORAGE_ORDER order = COL_MAJOR) + : Super(m, n), data(data), len(m * n), ord(order) + { + } + + void reset(T* data_, int m_, int n_) + { + this->m = m_; + this->n = n_; + data = data_; + len = m_ * n_; + } + + // Implemented GEMM as a static method here to improve readability + inline static void gemm(const raft::handle_t& handle, + const T alpha, + const SimpleDenseMat& A, + const bool transA, + const SimpleDenseMat& B, + const bool transB, + const T beta, + SimpleDenseMat& C, + cudaStream_t stream) + { + int kA = A.n; + int kB = B.m; + + if (transA) { + ASSERT(A.n == C.m, "GEMM invalid dims: m"); + kA = A.m; + } else { + ASSERT(A.m == C.m, "GEMM invalid dims: m"); + } + + if (transB) { + ASSERT(B.m == C.n, "GEMM invalid dims: n"); + kB = B.n; + } else { + ASSERT(B.n == C.n, "GEMM invalid dims: n"); + } + ASSERT(kA == kB, "GEMM invalid dims: k"); + + if (A.ord == COL_MAJOR && B.ord == COL_MAJOR && C.ord == COL_MAJOR) { + // #TODO: Call from public API when ready + raft::linalg::detail::cublasgemm(handle.get_cublas_handle(), // handle + transA ? CUBLAS_OP_T : CUBLAS_OP_N, // transA + transB ? CUBLAS_OP_T : CUBLAS_OP_N, // transB + C.m, + C.n, + kA, // dimensions m,n,k + &alpha, + A.data, + A.m, // lda + B.data, + B.m, // ldb + &beta, + C.data, + C.m, // ldc, + stream); + return; + } + if (A.ord == ROW_MAJOR) { + const SimpleDenseMat Acm(A.data, A.n, A.m, COL_MAJOR); + gemm(handle, alpha, Acm, !transA, B, transB, beta, C, stream); + return; + } + if (B.ord == ROW_MAJOR) { + const SimpleDenseMat Bcm(B.data, B.n, B.m, COL_MAJOR); + gemm(handle, alpha, A, transA, Bcm, !transB, beta, C, stream); + return; + } + if (C.ord == ROW_MAJOR) { + SimpleDenseMat Ccm(C.data, C.n, C.m, COL_MAJOR); + gemm(handle, alpha, B, !transB, A, !transA, beta, Ccm, stream); + return; + } + } + + inline void gemmb(const raft::handle_t& handle, + const T alpha, + const SimpleDenseMat& A, + const bool transA, + const bool transB, + const T beta, + SimpleDenseMat& C, + cudaStream_t stream) const override + { + SimpleDenseMat::gemm(handle, alpha, A, transA, *this, transB, beta, C, stream); + } + + /** + * GEMM assigning to C where `this` refers to C. 
+ * + * ``` + * *this <- alpha * A^transA * B^transB + beta * (*this) + * ``` + */ + inline void assign_gemm(const raft::handle_t& handle, + const T alpha, + const SimpleDenseMat& A, + const bool transA, + const SimpleMat& B, + const bool transB, + const T beta, + cudaStream_t stream) + { + B.gemmb(handle, alpha, A, transA, transB, beta, *this, stream); + } + + // this = a*x + inline void ax(const T a, const SimpleDenseMat& x, cudaStream_t stream) + { + ASSERT(ord == x.ord, "SimpleDenseMat::ax: Storage orders must match"); + + auto scale = [a] __device__(const T x) { return a * x; }; + raft::linalg::unaryOp(data, x.data, len, scale, stream); + } + + // this = a*x + y + inline void axpy(const T a, + const SimpleDenseMat& x, + const SimpleDenseMat& y, + cudaStream_t stream) + { + ASSERT(ord == x.ord, "SimpleDenseMat::axpy: Storage orders must match"); + ASSERT(ord == y.ord, "SimpleDenseMat::axpy: Storage orders must match"); + + auto axpy = [a] __device__(const T x, const T y) { return a * x + y; }; + raft::linalg::binaryOp(data, x.data, y.data, len, axpy, stream); + } + + template + inline void assign_unary(const SimpleDenseMat& other, Lambda f, cudaStream_t stream) + { + ASSERT(ord == other.ord, "SimpleDenseMat::assign_unary: Storage orders must match"); + + raft::linalg::unaryOp(data, other.data, len, f, stream); + } + + template + inline void assign_binary(const SimpleDenseMat& other1, + const SimpleDenseMat& other2, + Lambda& f, + cudaStream_t stream) + { + ASSERT(ord == other1.ord, "SimpleDenseMat::assign_binary: Storage orders must match"); + ASSERT(ord == other2.ord, "SimpleDenseMat::assign_binary: Storage orders must match"); + + raft::linalg::binaryOp(data, other1.data, other2.data, len, f, stream); + } + + template + inline void assign_ternary(const SimpleDenseMat& other1, + const SimpleDenseMat& other2, + const SimpleDenseMat& other3, + Lambda& f, + cudaStream_t stream) + { + ASSERT(ord == other1.ord, "SimpleDenseMat::assign_ternary: Storage orders must match"); + ASSERT(ord == other2.ord, "SimpleDenseMat::assign_ternary: Storage orders must match"); + ASSERT(ord == other3.ord, "SimpleDenseMat::assign_ternary: Storage orders must match"); + + raft::linalg::ternaryOp(data, other1.data, other2.data, other3.data, len, f, stream); + } + + inline void fill(const T val, cudaStream_t stream) + { + // TODO this reads data unnecessary, though it's mostly used for testing + auto f = [val] __device__(const T x) { return val; }; + raft::linalg::unaryOp(data, data, len, f, stream); + } + + inline void copy_async(const SimpleDenseMat& other, cudaStream_t stream) + { + ASSERT((ord == other.ord) && (this->m == other.m) && (this->n == other.n), + "SimpleDenseMat::copy: matrices not compatible"); + + RAFT_CUDA_TRY( + cudaMemcpyAsync(data, other.data, len * sizeof(T), cudaMemcpyDeviceToDevice, stream)); + } + + void print(std::ostream& oss) const override { oss << (*this) << std::endl; } + + void operator=(const SimpleDenseMat& other) = delete; + }; + + template + struct SimpleVec : SimpleDenseMat { + typedef SimpleDenseMat Super; + + SimpleVec(T* data, const int n) : Super(data, n, 1, COL_MAJOR) {} + // this = alpha * A * x + beta * this + void assign_gemv(const raft::handle_t& handle, + const T alpha, + const SimpleDenseMat& A, + bool transA, + const SimpleVec& x, + const T beta, + cudaStream_t stream) + { + Super::assign_gemm(handle, alpha, A, transA, x, false, beta, stream); + } + + SimpleVec() : Super(COL_MAJOR) {} + + inline void reset(T* new_data, int n) { Super::reset(new_data, n, 1); } + 
}; + + template + inline void col_ref(const SimpleDenseMat& mat, SimpleVec& mask_vec, int c) + { + ASSERT(mat.ord == COL_MAJOR, "col_ref only available for column major mats"); + T* tmp = &mat.data[mat.m * c]; + mask_vec.reset(tmp, mat.m); + } + + template + inline void col_slice(const SimpleDenseMat& mat, + SimpleDenseMat& mask_mat, + int c_from, + int c_to) + { + ASSERT(c_from >= 0 && c_from < mat.n, "col_slice: invalid from"); + ASSERT(c_to >= 0 && c_to <= mat.n, "col_slice: invalid to"); + + ASSERT(mat.ord == COL_MAJOR, "col_ref only available for column major mats"); + ASSERT(mask_mat.ord == COL_MAJOR, "col_ref only available for column major mask"); + T* tmp = &mat.data[mat.m * c_from]; + mask_mat.reset(tmp, mat.m, c_to - c_from); + } + +// Reductions such as dot or norm require an additional location in dev mem +// to hold the result. We don't want to deal with this in the SimpleVec class +// as it impedes thread safety and constness + + template + inline T dot(const SimpleVec& u, const SimpleVec& v, T* tmp_dev, cudaStream_t stream) + { + auto f = [] __device__(const T x, const T y) { return x * y; }; + raft::linalg::mapThenSumReduce(tmp_dev, u.len, f, stream, u.data, v.data); + T tmp_host; + raft::update_host(&tmp_host, tmp_dev, 1, stream); + + raft::interruptible::synchronize(stream); + return tmp_host; + } + + template + inline T squaredNorm(const SimpleVec& u, T* tmp_dev, cudaStream_t stream) + { + return dot(u, u, tmp_dev, stream); + } + + template + inline T nrmMax(const SimpleVec& u, T* tmp_dev, cudaStream_t stream) + { + auto f = [] __device__(const T x) { return raft::myAbs(x); }; + auto r = [] __device__(const T x, const T y) { return raft::myMax(x, y); }; + raft::linalg::mapThenReduce(tmp_dev, u.len, T(0), f, r, stream, u.data); + T tmp_host; + raft::update_host(&tmp_host, tmp_dev, 1, stream); + raft::interruptible::synchronize(stream); + return tmp_host; + } + + template + inline T nrm2(const SimpleVec& u, T* tmp_dev, cudaStream_t stream) + { + return raft::mySqrt(squaredNorm(u, tmp_dev, stream)); + } + + template + inline T nrm1(const SimpleVec& u, T* tmp_dev, cudaStream_t stream) + { + raft::linalg::rowNorm( + tmp_dev, u.data, u.len, 1, raft::linalg::L1Norm, true, stream, raft::Nop()); + T tmp_host; + raft::update_host(&tmp_host, tmp_dev, 1, stream); + raft::interruptible::synchronize(stream); + return tmp_host; + } + + template + std::ostream& operator<<(std::ostream& os, const SimpleVec& v) + { + std::vector out(v.len); + raft::update_host(&out[0], v.data, v.len, 0); + raft::interruptible::synchronize(rmm::cuda_stream_view()); + int it = 0; + for (; it < v.len - 1;) { + os << out[it] << " "; + it++; + } + os << out[it]; + return os; + } + + template + std::ostream& operator<<(std::ostream& os, const SimpleDenseMat& mat) + { + os << "ord=" << (mat.ord == COL_MAJOR ? 
"CM" : "RM") << "\n"; + std::vector out(mat.len); + raft::update_host(&out[0], mat.data, mat.len, rmm::cuda_stream_default); + raft::interruptible::synchronize(rmm::cuda_stream_view()); + if (mat.ord == COL_MAJOR) { + for (int r = 0; r < mat.m; r++) { + int idx = r; + for (int c = 0; c < mat.n - 1; c++) { + os << out[idx] << ","; + idx += mat.m; + } + os << out[idx] << std::endl; + } + } else { + for (int c = 0; c < mat.m; c++) { + int idx = c * mat.n; + for (int r = 0; r < mat.n - 1; r++) { + os << out[idx] << ","; + idx += 1; + } + os << out[idx] << std::endl; + } + } + + return os; + } + + template + struct SimpleVecOwning : SimpleVec { + typedef SimpleVec Super; + typedef rmm::device_uvector Buffer; + Buffer buf; + + SimpleVecOwning() = delete; + + SimpleVecOwning(int n, cudaStream_t stream) : Super(), buf(n, stream) + { + Super::reset(buf.data(), n); + } + + void operator=(const SimpleVec& other) = delete; + }; + + template + struct SimpleMatOwning : SimpleDenseMat { + typedef SimpleDenseMat Super; + typedef rmm::device_uvector Buffer; + Buffer buf; + using Super::m; + using Super::n; + using Super::ord; + + SimpleMatOwning() = delete; + + SimpleMatOwning(int m, int n, cudaStream_t stream, STORAGE_ORDER order = COL_MAJOR) + : Super(order), buf(m * n, stream) + { + Super::reset(buf.data(), m, n); + } + + void operator=(const SimpleVec& other) = delete; + }; + + /** + * Sparse matrix in CSR format. + * + * Note, we use cuSPARSE to manimulate matrices, and it guarantees: + * + * 1. row_ids[m] == nnz + * 2. cols are sorted within rows. + * + * However, when the data comes from the outside, we cannot guarantee that. + */ + template + struct SimpleSparseMat : SimpleMat { + typedef SimpleMat Super; + T* values; + int* cols; + int* row_ids; + int nnz; + + SimpleSparseMat() : Super(0, 0), values(nullptr), cols(nullptr), row_ids(nullptr), nnz(0) {} + + SimpleSparseMat(T* values, int* cols, int* row_ids, int nnz, int m, int n) + : Super(m, n), values(values), cols(cols), row_ids(row_ids), nnz(nnz) + { + check_csr(*this, 0); + } + + void print(std::ostream& oss) const override { oss << (*this) << std::endl; } + + void operator=(const SimpleSparseMat& other) = delete; + + inline void gemmb(const raft::handle_t& handle, + const T alpha, + const SimpleDenseMat& A, + const bool transA, + const bool transB, + const T beta, + SimpleDenseMat& C, + cudaStream_t stream) const override + { + const SimpleSparseMat& B = *this; + int kA = A.n; + int kB = B.m; + + if (transA) { + ASSERT(A.n == C.m, "GEMM invalid dims: m"); + kA = A.m; + } else { + ASSERT(A.m == C.m, "GEMM invalid dims: m"); + } + + if (transB) { + ASSERT(B.m == C.n, "GEMM invalid dims: n"); + kB = B.n; + } else { + ASSERT(B.n == C.n, "GEMM invalid dims: n"); + } + ASSERT(kA == kB, "GEMM invalid dims: k"); + + // matrix C must change the order and be transposed, because we need + // to swap arguments A and B in cusparseSpMM. + cusparseDnMatDescr_t descrC; + auto order = C.ord == COL_MAJOR ? CUSPARSE_ORDER_ROW : CUSPARSE_ORDER_COL; + RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsecreatednmat( + &descrC, C.n, C.m, order == CUSPARSE_ORDER_COL ? C.n : C.m, C.data, order)); + + /* + The matrix A must have the same order as the matrix C in the input + of function cusparseSpMM (i.e. swapped order w.r.t. original C). + To account this requirement, I may need to flip transA (whether to transpose A). 
+ + C C' rowsC' colsC' ldC' A A' rowsA' colsA' ldA' flipTransA + c r n m m c r n m m x + c r n m m r r m n n o + r c n m n c c m n m o + r c n m n r c n m n x + + where: + c/r - column/row major order + A,C - input to gemmb + A', C' - input to cusparseSpMM + ldX' - leading dimension - m or n, depending on order and transX + */ + cusparseDnMatDescr_t descrA; + RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsecreatednmat(&descrA, + C.ord == A.ord ? A.n : A.m, + C.ord == A.ord ? A.m : A.n, + A.ord == COL_MAJOR ? A.m : A.n, + A.data, + order)); + auto opA = + transA ^ (C.ord == A.ord) ? CUSPARSE_OPERATION_NON_TRANSPOSE : CUSPARSE_OPERATION_TRANSPOSE; + + cusparseSpMatDescr_t descrB; + RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsecreatecsr( + &descrB, B.m, B.n, B.nnz, B.row_ids, B.cols, B.values)); + auto opB = transB ? CUSPARSE_OPERATION_NON_TRANSPOSE : CUSPARSE_OPERATION_TRANSPOSE; + + auto alg = order == CUSPARSE_ORDER_COL ? CUSPARSE_SPMM_CSR_ALG1 : CUSPARSE_SPMM_CSR_ALG2; + + size_t bufferSize; + RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsespmm_bufferSize(handle.get_cusparse_handle(), + opB, + opA, + &alpha, + descrB, + descrA, + &beta, + descrC, + alg, + &bufferSize, + stream)); + + raft::interruptible::synchronize(stream); + rmm::device_uvector tmp(bufferSize, stream); + + RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsespmm(handle.get_cusparse_handle(), + opB, + opA, + &alpha, + descrB, + descrA, + &beta, + descrC, + alg, + tmp.data(), + stream)); + + RAFT_CUSPARSE_TRY(cusparseDestroyDnMat(descrA)); + RAFT_CUSPARSE_TRY(cusparseDestroySpMat(descrB)); + RAFT_CUSPARSE_TRY(cusparseDestroyDnMat(descrC)); + } + }; + + template + inline void check_csr(const SimpleSparseMat& mat, cudaStream_t stream) + { + int row_ids_nnz; + raft::update_host(&row_ids_nnz, &mat.row_ids[mat.m], 1, stream); + raft::interruptible::synchronize(stream); + ASSERT(row_ids_nnz == mat.nnz, + "SimpleSparseMat: the size of CSR row_ids array must be `m + 1`, and " + "the last element must be equal nnz."); + } + + template + std::ostream& operator<<(std::ostream& os, const SimpleSparseMat& mat) + { + check_csr(mat, 0); + os << "SimpleSparseMat (CSR)" + << "\n"; + std::vector values(mat.nnz); + std::vector cols(mat.nnz); + std::vector row_ids(mat.m + 1); + raft::update_host(&values[0], mat.values, mat.nnz, rmm::cuda_stream_default); + raft::update_host(&cols[0], mat.cols, mat.nnz, rmm::cuda_stream_default); + raft::update_host(&row_ids[0], mat.row_ids, mat.m + 1, rmm::cuda_stream_default); + raft::interruptible::synchronize(rmm::cuda_stream_view()); + + int i, row_end = 0; + for (int row = 0; row < mat.m; row++) { + i = row_end; + row_end = row_ids[row + 1]; + for (int col = 0; col < mat.n; col++) { + if (i >= row_end || col < cols[i]) { + os << "0"; + } else { + os << values[i]; + i++; + } + if (col < mat.n - 1) os << ","; + } + + os << std::endl; + } + + return os; + } + + +}; // namespace raft::solver diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 07ec85bf1e..3098643cb7 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -182,7 +182,8 @@ if(BUILD_TESTS) PATH test/cluster_solvers_deprecated.cu test/eigen_solvers.cu - test/lap/lap.cu + test/solver/lap.cu + test/solver/quasi_newton.cu test/mst.cu ) diff --git a/cpp/test/lap/lap.cu b/cpp/test/solver/lap.cu similarity index 100% rename from cpp/test/lap/lap.cu rename to cpp/test/solver/lap.cu diff --git a/cpp/test/solver/quasi_newton.cu b/cpp/test/solver/quasi_newton.cu new file mode 100644 index 0000000000..72f9ad9c6d --- /dev/null 
+++ b/cpp/test/solver/quasi_newton.cu
@@ -0,0 +1,833 @@
+/*
+ * Copyright (c) 2020-2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace raft::solver::quasi_newton {
+
+  template
+  int qn_fit(const raft::handle_t& handle,
+             const qn_params& pams,
+             LossFunction& loss,
+             const SimpleMat& X,
+             const SimpleVec& y,
+             SimpleDenseMat& Z,
+             T* w0_data,  // initial value and result
+             T* fx,
+             int* num_iters)
+  {
+    LBFGSParam opt_param(pams);
+    SimpleVec w0(w0_data, loss.n_param);
+
+    // Scale the regularization strength with the number of samples.
+    T l1 = pams.penalty_l1;
+    T l2 = pams.penalty_l2;
+    if (pams.penalty_normalized) {
+      l1 /= X.m;
+      l2 /= X.m;
+    }
+
+    if (l2 == 0) {
+      ObjectiveWithData lossWith(&loss, X, y, Z);
+
+      return qn_minimize(handle, w0, fx, num_iters, lossWith, l1, opt_param);
+
+    } else {
+      Tikhonov reg(l2);
+      RegularizedQN obj(&loss, &reg);
+      ObjectiveWithData lossWith(&obj, X, y, Z);
+
+      return qn_minimize(handle, w0, fx, num_iters, lossWith, l1, opt_param);
+    }
+  }
+
+  template
+  inline void qn_fit_x(const raft::handle_t& handle,
+                       const qn_params& pams,
+                       SimpleMat& X,
+                       T* y_data,
+                       int C,
+                       T* w0_data,
+                       T* f,
+                       int* num_iters,
+                       cudaStream_t stream,
+                       T* sample_weight = nullptr,
+                       T svr_eps = 0)
+  {
+    /*
+     NB:
+      N - number of data rows
+      D - number of data columns (features)
+      C - number of output classes
+
+      X in R^[N, D]
+      w in R^[D, C]
+      y in {0, 1}^[N, C] or {cat}^N
+
+      Dimensionality of w0 depends on loss, so we initialize it later.
+    */
+    int N = X.m;
+    int D = X.n;
+    int n_targets = qn_is_classification(pams.loss) && C == 2 ?
1 : C; + rmm::device_uvector tmp(n_targets * N, stream); + SimpleDenseMat Z(tmp.data(), n_targets, N); + SimpleVec y(y_data, N); + + switch (pams.loss) { + case QN_LOSS_LOGISTIC: { + ASSERT(C == 2, "qn.h: logistic loss invalid C"); + LogisticLoss loss(handle, D, pams.fit_intercept); + if (sample_weight) loss.add_sample_weights(sample_weight, N, stream); + qn_fit(handle, pams, loss, X, y, Z, w0_data, f, num_iters, stream); + } break; + case QN_LOSS_SQUARED: { + ASSERT(C == 1, "qn.h: squared loss invalid C"); + SquaredLoss loss(handle, D, pams.fit_intercept); + if (sample_weight) loss.add_sample_weights(sample_weight, N, stream); + qn_fit(handle, pams, loss, X, y, Z, w0_data, f, num_iters, stream); + } break; + case QN_LOSS_SOFTMAX: { + ASSERT(C > 2, "qn.h: softmax invalid C"); + Softmax loss(handle, D, C, pams.fit_intercept); + if (sample_weight) loss.add_sample_weights(sample_weight, N, stream); + qn_fit(handle, pams, loss, X, y, Z, w0_data, f, num_iters, stream); + } break; + case QN_LOSS_SVC_L1: { + ASSERT(C == 2, "qn.h: SVC-L1 loss invalid C"); + HingeLoss loss(handle, D, pams.fit_intercept); + if (sample_weight) loss.add_sample_weights(sample_weight, N, stream); + qn_fit(handle, pams, loss, X, y, Z, w0_data, f, num_iters, stream); + } break; + case QN_LOSS_SVC_L2: { + ASSERT(C == 2, "qn.h: SVC-L2 loss invalid C"); + SqHingeLoss loss(handle, D, pams.fit_intercept); + if (sample_weight) loss.add_sample_weights(sample_weight, N, stream); + qn_fit(handle, pams, loss, X, y, Z, w0_data, f, num_iters, stream); + } break; + case QN_LOSS_SVR_L1: { + ASSERT(C == 1, "qn.h: SVR-L1 loss invalid C"); + EpsInsHingeLoss loss(handle, D, pams.fit_intercept, svr_eps); + if (sample_weight) loss.add_sample_weights(sample_weight, N, stream); + qn_fit(handle, pams, loss, X, y, Z, w0_data, f, num_iters, stream); + } break; + case QN_LOSS_SVR_L2: { + ASSERT(C == 1, "qn.h: SVR-L2 loss invalid C"); + SqEpsInsHingeLoss loss(handle, D, pams.fit_intercept, svr_eps); + if (sample_weight) loss.add_sample_weights(sample_weight, N, stream); + qn_fit(handle, pams, loss, X, y, Z, w0_data, f, num_iters, stream); + } break; + case QN_LOSS_ABS: { + ASSERT(C == 1, "qn.h: abs loss (L1) invalid C"); + AbsLoss loss(handle, D, pams.fit_intercept); + if (sample_weight) loss.add_sample_weights(sample_weight, N, stream); + qn_fit(handle, pams, loss, X, y, Z, w0_data, f, num_iters, stream); + } break; + default: { + ASSERT(false, "qn.h: unknown loss function type (id = %d).", pams.loss); + } + } + } + + + + struct QuasiNewtonTest : ::testing::Test { + static constexpr int N = 10; + static constexpr int D = 2; + + const static double* nobptr; + const static double tol; + const static double X[N][D]; + const raft::handle_t& handle; + cudaStream_t stream = 0; + std::shared_ptr> Xdev; + std::shared_ptr> ydev; + + QuasiNewtonTest() {} + void SetUp() + { + stream = handle.get_stream(); + Xdev.reset(new SimpleMatOwning(N, D, stream, ROW_MAJOR)); + raft::update_device(Xdev->data, &X[0][0], Xdev->len, stream); + + ydev.reset(new SimpleVecOwning(N, stream)); + handle.sync_stream(stream); + } + void TearDown() {} + }; + + const double* QuasiNewtonTest::nobptr = 0; + const double QuasiNewtonTest::tol = 5e-6; + const double QuasiNewtonTest::X[QuasiNewtonTest::N][QuasiNewtonTest::D] = { + {-0.2047076594847130, 0.4789433380575482}, + {-0.5194387150567381, -0.5557303043474900}, + {1.9657805725027142, 1.3934058329729904}, + {0.0929078767437177, 0.2817461528302025}, + {0.7690225676118387, 1.2464347363862822}, + {1.0071893575830049, 
-1.2962211091122635}, + {0.2749916334321240, 0.2289128789353159}, + {1.3529168351654497, 0.8864293405915888}, + {-2.0016373096603974, -0.3718425371402544}, + {1.6690253095248706, -0.4385697358355719}}; + + template + ::testing::AssertionResult checkParamsEqual(const raft::handle_t& handle, + const T* host_weights, + const T* host_bias, + const T* w, + const GLMDims& dims, + Comp& comp, + cudaStream_t stream) + { + int C = dims.C; + int D = dims.D; + bool fit_intercept = dims.fit_intercept; + std::vector w_ref_cm(C * D); + int idx = 0; + for (int d = 0; d < D; d++) + for (int c = 0; c < C; c++) { + w_ref_cm[idx++] = host_weights[c * D + d]; + } + + SimpleVecOwning w_ref(dims.n_param, stream); + raft::update_device(w_ref.data, &w_ref_cm[0], C * D, stream); + if (fit_intercept) { raft::update_device(&w_ref.data[C * D], host_bias, C, stream); } + handle.sync_stream(stream); + return raft::devArrMatch(w_ref.data, w, w_ref.len, comp); + } + + template + T run(const raft::handle_t& handle, + LossFunction& loss, + const SimpleMat& X, + const SimpleVec& y, + T l1, + T l2, + T* w, + SimpleDenseMat& z) + { + qn_params pams; + pams.max_iter = 100; + pams.grad_tol = 1e-16; + pams.change_tol = 1e-16; + pams.linesearch_max_iter = 50; + pams.lbfgs_memory = 5; + pams.penalty_l1 = l1; + pams.penalty_l2 = l2; + pams.verbose = verbosity; + + int num_iters = 0; + + T fx; + + qn_fit(handle, pams, loss, X, y, z, w, &fx, &num_iters); + + return fx; + } + + template + T run_api(const raft::handle_t& cuml_handle, + qn_loss_type loss_type, + int C, + bool fit_intercept, + const SimpleMat& X, + const SimpleVec& y, + T l1, + T l2, + T* w, + SimpleDenseMat& z, + int verbosity, + cudaStream_t stream) + { + qn_params pams; + + pams.max_iter = 100; + pams.grad_tol = 1e-8; + pams.change_tol = 1e-8; + pams.linesearch_max_iter = 50; + pams.lbfgs_memory = 5; + pams.penalty_l1 = l1; + pams.penalty_l2 = l2; + pams.verbose = verbosity; + pams.fit_intercept = fit_intercept; + pams.loss = loss_type; + + int num_iters = 0; + + SimpleVec w0(w, X.n + fit_intercept); + w0.fill(T(0), stream); + T fx; + + qn_fit_on_x(cuml_handle, + pams, + X_dense->data, + X_dense->ord == COL_MAJOR, + y.data, + X_dense->m, + X_dense->n, + C, + w, + &fx, + &num_iters); + } else { + ADD_FAILURE(); + } + + return fx; + } + + TEST_F(QuasiNewtonTest, binary_logistic_vs_sklearn) + { +#if CUDART_VERSION >= 11020 + GTEST_SKIP(); +#endif + raft::CompareApprox compApprox(tol); + // Test case generated in python and solved with sklearn + double y[N] = {1, 1, 1, 0, 1, 0, 1, 0, 1, 0}; + raft::update_device(ydev->data, &y[0], ydev->len, stream); + handle.sync_stream(stream); + + double alpha = 0.01 * N; + + LogisticLoss loss_b(handle, D, true); + LogisticLoss loss_no_b(handle, D, false); + + SimpleVecOwning w0(D + 1, stream); + SimpleMatOwning z(1, N, stream); + + double l1, l2, fx; + + double w_l1_b[2] = {-1.6899370396155091, 1.9021577534928300}; + double b_l1_b = 0.8057670813749118; + double obj_l1_b = 0.44295941481024703; + + l1 = alpha; + l2 = 0.0; + fx = run(handle, loss_b, *Xdev, *ydev, l1, l2, w0.data, z, 0, stream); + ASSERT_TRUE(compApprox(obj_l1_b, fx)); + ASSERT_TRUE(checkParamsEqual(handle, &w_l1_b[0], &b_l1_b, w0.data, loss_b, compApprox, stream)); + + fx = run_api(handle, + QN_LOSS_LOGISTIC, + 2, + loss_b.fit_intercept, + *Xdev, + *ydev, + l1, + l2, + w0.data, + z, + 0, + stream); + ASSERT_TRUE(compApprox(obj_l1_b, fx)); + + double w_l2_b[2] = {-1.5339880402781370, 1.6788639581350926}; + double b_l2_b = 0.806087868102401; + double obj_l2_b = 
0.4378085369889721; + + l1 = 0; + l2 = alpha; + fx = run(handle, loss_b, *Xdev, *ydev, l1, l2, w0.data, z, 0, stream); + + ASSERT_TRUE(compApprox(obj_l2_b, fx)); + ASSERT_TRUE(checkParamsEqual(handle, &w_l2_b[0], &b_l2_b, w0.data, loss_b, compApprox, stream)); + + fx = run_api(cuml_handle, + QN_LOSS_LOGISTIC, + 2, + loss_b.fit_intercept, + *Xdev, + *ydev, + l1, + l2, + w0.data, + z, + 0, + stream); + ASSERT_TRUE(compApprox(obj_l2_b, fx)); + + double w_l1_no_b[2] = {-1.6215035298864591, 2.3650868394981086}; + double obj_l1_no_b = 0.4769896009200278; + + l1 = alpha; + l2 = 0.0; + fx = run(handle, loss_no_b, *Xdev, *ydev, l1, l2, w0.data, z, 0, stream); + ASSERT_TRUE(compApprox(obj_l1_no_b, fx)); + ASSERT_TRUE( + checkParamsEqual(handle, &w_l1_no_b[0], nobptr, w0.data, loss_no_b, compApprox, stream)); + + fx = run_api(cuml_handle, + QN_LOSS_LOGISTIC, + 2, + loss_no_b.fit_intercept, + *Xdev, + *ydev, + l1, + l2, + w0.data, + z, + 0, + stream); + ASSERT_TRUE(compApprox(obj_l1_no_b, fx)); + + double w_l2_no_b[2] = {-1.3931049893764620, 2.0140103094119621}; + double obj_l2_no_b = 0.47502098062114273; + + l1 = 0; + l2 = alpha; + fx = run(handle, loss_no_b, *Xdev, *ydev, l1, l2, w0.data, z, 0, stream); + ASSERT_TRUE(compApprox(obj_l2_no_b, fx)); + ASSERT_TRUE( + checkParamsEqual(handle, &w_l2_no_b[0], nobptr, w0.data, loss_no_b, compApprox, stream)); + + fx = run_api(cuml_handle, + QN_LOSS_LOGISTIC, + 2, + loss_no_b.fit_intercept, + *Xdev, + *ydev, + l1, + l2, + w0.data, + z, + 0, + stream); + ASSERT_TRUE(compApprox(obj_l2_no_b, fx)); + } + + TEST_F(QuasiNewtonTest, multiclass_logistic_vs_sklearn) +{ +#if CUDART_VERSION >= 11020 + GTEST_SKIP(); +#endif + // The data seems to small for the objective to be strongly convex + // leaving out exact param checks + + raft::CompareApprox compApprox(tol); + double y[N] = {2, 2, 0, 3, 3, 0, 0, 0, 1, 0}; + raft::update_device(ydev->data, &y[0], ydev->len, stream); + handle.sync_stream(stream); + + double fx, l1, l2; + int C = 4; + + double alpha = 0.016 * N; + + SimpleMatOwning z(C, N, stream); + SimpleVecOwning w0(C * (D + 1), stream); + + Softmax loss_b(handle, D, C, true); + Softmax loss_no_b(handle, D, C, false); + + l1 = alpha; + l2 = 0.0; + double obj_l1_b = 0.5407911382311313; + + fx = run(handle, loss_b, *Xdev, *ydev, l1, l2, w0.data, z, 0, stream); + ASSERT_TRUE(compApprox(obj_l1_b, fx)); + + fx = run_api(cuml_handle, + QN_LOSS_SOFTMAX, + C, + loss_b.fit_intercept, + *Xdev, + *ydev, + l1, + l2, + w0.data, + z, + 0, + stream); + ASSERT_TRUE(compApprox(obj_l1_b, fx)); + + l1 = 0.0; + l2 = alpha; + double obj_l2_b = 0.5721784062720949; + + fx = run(handle, loss_b, *Xdev, *ydev, l1, l2, w0.data, z, 0, stream); + ASSERT_TRUE(compApprox(obj_l2_b, fx)); + + fx = run_api(cuml_handle, + QN_LOSS_SOFTMAX, + C, + loss_b.fit_intercept, + *Xdev, + *ydev, + l1, + l2, + w0.data, + z, + 0, + stream); + ASSERT_TRUE(compApprox(obj_l2_b, fx)); + + l1 = alpha; + l2 = 0.0; + double obj_l1_no_b = 0.6606929813245878; + + fx = run(handle, loss_no_b, *Xdev, *ydev, l1, l2, w0.data, z, 0, stream); + ASSERT_TRUE(compApprox(obj_l1_no_b, fx)); + + fx = run_api(cuml_handle, + QN_LOSS_SOFTMAX, + C, + loss_no_b.fit_intercept, + *Xdev, + *ydev, + l1, + l2, + w0.data, + z, + 0, + stream); + ASSERT_TRUE(compApprox(obj_l1_no_b, fx)); + + l1 = 0.0; + l2 = alpha; + + double obj_l2_no_b = 0.6597171282106854; + + fx = run(handle, loss_no_b, *Xdev, *ydev, l1, l2, w0.data, z, 0, stream); + ASSERT_TRUE(compApprox(obj_l2_no_b, fx)); + + fx = run_api(cuml_handle, + QN_LOSS_SOFTMAX, + C, + 
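// C = 4 classes for the softmax objective in this test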
loss_no_b.fit_intercept, + *Xdev, + *ydev, + l1, + l2, + w0.data, + z, + 0, + stream); + ASSERT_TRUE(compApprox(obj_l2_no_b, fx)); +} + +TEST_F(QuasiNewtonTest, linear_regression_vs_sklearn) +{ +raft::CompareApprox compApprox(tol); +double y[N] = {0.2675836026202781, + -0.0678277759663704, + -0.6334027174275105, + -0.1018336189077367, + 0.0933815935886932, + -1.1058853496996381, + -0.1658298189619160, + -0.2954290675648911, + 0.7966520536712608, + -1.0767450516284769}; +raft::update_device(ydev->data, &y[0], ydev->len, stream); +handle.sync_stream(stream); + +double fx, l1, l2; +double alpha = 0.01 * N; + +SimpleVecOwning w0(D + 1, stream); +SimpleMatOwning z(1, N, stream); +SquaredLoss loss_b(handle, D, true); +SquaredLoss loss_no_b(handle, D, false); + +l1 = alpha; +l2 = 0.0; +double w_l1_b[2] = {-0.4952397281519840, 0.3813315300180231}; +double b_l1_b = -0.08140861819001188; +double obj_l1_b = 0.011136986298775138; +fx = run(handle, loss_b, *Xdev, *ydev, l1, l2, w0.data, z, 0, stream); +ASSERT_TRUE(compApprox(obj_l1_b, fx)); +ASSERT_TRUE(checkParamsEqual(handle, &w_l1_b[0], &b_l1_b, w0.data, loss_b, compApprox, stream)); + +fx = run_api(cuml_handle, + QN_LOSS_SQUARED, + 1, + loss_b.fit_intercept, + *Xdev, + *ydev, + l1, + l2, + w0.data, + z, + 0, + stream); +ASSERT_TRUE(compApprox(obj_l1_b, fx)); + +l1 = 0.0; +l2 = alpha; +double w_l2_b[2] = {-0.5022384743587150, 0.3937352417485087}; +double b_l2_b = -0.08062397391797513; +double obj_l2_b = 0.004268621967866347; + +fx = run(handle, loss_b, *Xdev, *ydev, l1, l2, w0.data, z, 0, stream); +ASSERT_TRUE(compApprox(obj_l2_b, fx)); +ASSERT_TRUE(checkParamsEqual(handle, &w_l2_b[0], &b_l2_b, w0.data, loss_b, compApprox, stream)); + +fx = run_api(cuml_handle, + QN_LOSS_SQUARED, + 1, + loss_b.fit_intercept, + *Xdev, + *ydev, + l1, + l2, + w0.data, + z, + 0, + stream); +ASSERT_TRUE(compApprox(obj_l2_b, fx)); + +l1 = alpha; +l2 = 0.0; +double w_l1_no_b[2] = {-0.5175178128147135, 0.3720844589831813}; +double obj_l1_no_b = 0.013981355746112447; + +fx = run(handle, loss_no_b, *Xdev, *ydev, l1, l2, w0.data, z, 0, stream); +ASSERT_TRUE(compApprox(obj_l1_no_b, fx)); +ASSERT_TRUE( + checkParamsEqual(handle, &w_l1_no_b[0], nobptr, w0.data, loss_no_b, compApprox, stream)); + +fx = run_api(cuml_handle, + QN_LOSS_SQUARED, + 1, + loss_no_b.fit_intercept, + *Xdev, + *ydev, + l1, + l2, + w0.data, + z, + 0, + stream); +ASSERT_TRUE(compApprox(obj_l1_no_b, fx)); + +l1 = 0.0; +l2 = alpha; +double w_l2_no_b[2] = {-0.5241651041233270, 0.3846317886627560}; +double obj_l2_no_b = 0.007061261366969662; + +fx = run(handle, loss_no_b, *Xdev, *ydev, l1, l2, w0.data, z, 0, stream); +ASSERT_TRUE(compApprox(obj_l2_no_b, fx)); +ASSERT_TRUE( + checkParamsEqual(handle, &w_l2_no_b[0], nobptr, w0.data, loss_no_b, compApprox, stream)); + +fx = run_api(cuml_handle, + QN_LOSS_SQUARED, + 1, + loss_no_b.fit_intercept, + *Xdev, + *ydev, + l1, + l2, + w0.data, + z, + 0, + stream); +ASSERT_TRUE(compApprox(obj_l2_no_b, fx)); +} + +TEST_F(QuasiNewtonTest, predict) +{ +raft::CompareApprox compApprox(1e-8); +std::vector w_host(D); +w_host[0] = 1; +std::vector preds_host(N); +SimpleVecOwning w(D, stream); +SimpleVecOwning preds(N, stream); + +raft::update_device(w.data, &w_host[0], w.len, stream); +qn_params pams; +pams.loss = QN_LOSS_LOGISTIC; +pams.fit_intercept = false; + +qnPredict(handle, pams, Xdev->data, false, N, D, 2, w.data, preds.data, stream); +raft::update_host(&preds_host[0], preds.data, preds.len, stream); +handle.sync_stream(stream); + +for (int it = 0; it < N; it++) { 
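// w = [1, 0] with no intercept, so the logistic score of sample `it` is X[it][0]:
// a positive first feature should be predicted as class 1, a negative one as class 0.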
+ASSERT_TRUE(X[it][0] > 0 ? compApprox(preds_host[it], 1) : compApprox(preds_host[it], 0)); +} + +pams.loss = QN_LOSS_SQUARED; +pams.fit_intercept = false; +qnPredict(handle, pams, Xdev->data, false, N, D, 1, w.data, preds.data, stream); +raft::update_host(&preds_host[0], preds.data, preds.len, stream); +handle.sync_stream(stream); + +for (int it = 0; it < N; it++) { +ASSERT_TRUE(compApprox(X[it][0], preds_host[it])); +} +} + +TEST_F(QuasiNewtonTest, predict_softmax) +{ +raft::CompareApprox compApprox(1e-8); +int C = 4; +std::vector w_host(C * D); +w_host[0] = 1; +w_host[D * C - 1] = 1; + +std::vector preds_host(N); +SimpleVecOwning w(w_host.size(), stream); +SimpleVecOwning preds(N, stream); + +raft::update_device(w.data, &w_host[0], w.len, stream); + +qn_params pams; +pams.loss = QN_LOSS_SOFTMAX; +pams.fit_intercept = false; +qnPredict(handle, pams, Xdev->data, false, N, D, C, w.data, preds.data, stream); +raft::update_host(&preds_host[0], preds.data, preds.len, stream); +handle.sync_stream(stream); + +for (int it = 0; it < N; it++) { +if (X[it][0] < 0 && X[it][1] < 0) { +ASSERT_TRUE(compApprox(1, preds_host[it])); +} else if (X[it][0] > X[it][1]) { +ASSERT_TRUE(compApprox(0, preds_host[it])); +} else { +ASSERT_TRUE(compApprox(C - 1, preds_host[it])); +} +} +} + +TEST_F(QuasiNewtonTest, dense_vs_sparse_logistic) +{ +#if CUDART_VERSION >= 11020 +GTEST_SKIP(); +#endif +// Prepare a sparse input matrix from the dense matrix X. +// Yes, it's not sparse at all, yet the test does check whether the behaviour +// of dense and sparse variants is the same. +rmm::device_uvector mem_X_cols(N * D, stream); +rmm::device_uvector mem_X_row_ids(N + 1, stream); +int host_X_cols[N][D]; +int host_X_row_ids[N + 1]; +for (int i = 0; i < N; i++) { +for (int j = 0; j < D; j++) { +host_X_cols[i][j] = j; +} +} +for (int i = 0; i < N + 1; i++) { +host_X_row_ids[i] = i * D; +} +raft::update_device(mem_X_cols.data(), &host_X_cols[0][0], mem_X_cols.size(), stream); +raft::update_device(mem_X_row_ids.data(), &host_X_row_ids[0], mem_X_row_ids.size(), stream); +SimpleSparseMat X_sparse( + Xdev->data, mem_X_cols.data(), mem_X_row_ids.data(), N * D, N, D); + +raft::CompareApprox compApprox(tol); +double y[N] = {2, 2, 0, 3, 3, 0, 0, 0, 1, 0}; +raft::update_device(ydev->data, &y[0], ydev->len, stream); +handle.sync_stream(stream); + +int C = 4; +qn_loss_type loss_type = QN_LOSS_SOFTMAX; // Softmax (loss_b, loss_no_b) +double alpha = 0.016 * N; +Softmax loss_b(handle, D, C, true); +Softmax loss_no_b(handle, D, C, false); + +SimpleMatOwning z_dense(C, N, stream); +SimpleMatOwning z_sparse(C, N, stream); +SimpleVecOwning w0_dense(C * (D + 1), stream); +SimpleVecOwning w0_sparse(C * (D + 1), stream); + +std::vector preds_dense_host(N); +std::vector preds_sparse_host(N); +SimpleVecOwning preds_dense(N, stream); +SimpleVecOwning preds_sparse(N, stream); + +auto test_run = [&](double l1, double l2, Softmax loss) { + qn_params pams; + pams.penalty_l1 = l1; + pams.penalty_l2 = l2; + pams.loss = loss_type; + pams.fit_intercept = loss.fit_intercept; + + double f_dense, f_sparse; + f_dense = run(handle, loss, *Xdev, *ydev, l1, l2, w0_dense.data, z_dense, 0, stream); + f_sparse = run(handle, loss, X_sparse, *ydev, l1, l2, w0_sparse.data, z_sparse, 0, stream); + ASSERT_TRUE(compApprox(f_dense, f_sparse)); + + qnPredict(handle, + pams, + Xdev->data, + Xdev->ord == COL_MAJOR, + N, + D, + C, + w0_dense.data, + preds_dense.data, + stream); + qnPredictSparse(handle, + pams, + X_sparse.values, + X_sparse.cols, + X_sparse.row_ids, + 
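/* CSR arrays above (values, column indices, row offsets); nnz, N, D and C follow */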
X_sparse.nnz, + N, + D, + C, + w0_sparse.data, + preds_sparse.data, + stream); + + raft::update_host(&preds_dense_host[0], preds_dense.data, preds_dense.len, stream); + raft::update_host(&preds_sparse_host[0], preds_sparse.data, preds_sparse.len, stream); + handle.sync_stream(stream); + for (int i = 0; i < N; i++) { + ASSERT_TRUE(compApprox(preds_dense_host[i], preds_sparse_host[i])); + } + + f_dense = run_api(cuml_handle, + QN_LOSS_SOFTMAX, + C, + loss.fit_intercept, + *Xdev, + *ydev, + l1, + l2, + w0_dense.data, + z_dense, + 0, + stream); + f_sparse = run_api(cuml_handle, + QN_LOSS_SOFTMAX, + C, + loss.fit_intercept, + X_sparse, + *ydev, + l1, + l2, + w0_sparse.data, + z_sparse, + 0, + stream); + ASSERT_TRUE(compApprox(f_dense, f_sparse)); +}; + +test_run(alpha, 0.0, loss_b); +test_run(0.0, alpha, loss_b); +test_run(alpha, 0.0, loss_no_b); +test_run(0.0, alpha, loss_no_b); +} + +} // namespace raft::solver::quasi_newton From 4131c0ece6e8e044a19bd24d328994bd8d36ca31 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Thu, 20 Oct 2022 17:19:43 -0400 Subject: [PATCH 22/35] Exposing solver APIs. It needs some work but it's getting there. --- .../raft/solver/coordinate_descent.cuh | 44 +++++ cpp/include/raft/solver/detail/cd.cuh | 22 +-- cpp/include/raft/solver/detail/lars.cuh | 2 +- cpp/include/raft/solver/detail/preprocess.cuh | 171 ++++++++++++++++++ cpp/include/raft/solver/detail/sgd.cuh | 22 +-- cpp/include/raft/solver/gradient_descent.cuh | 34 ++++ .../raft/solver/least_angle_regression.cuh | 73 ++++++++ cpp/include/raft/solver/quasi_newton.cuh | 2 +- cpp/include/raft/solver/solver_types.hpp | 54 +++++- cpp/test/solver/quasi_newton.cu | 4 +- 10 files changed, 401 insertions(+), 27 deletions(-) create mode 100644 cpp/include/raft/solver/detail/preprocess.cuh diff --git a/cpp/include/raft/solver/coordinate_descent.cuh b/cpp/include/raft/solver/coordinate_descent.cuh index 39b255f524..b7c0e1bbc0 100644 --- a/cpp/include/raft/solver/coordinate_descent.cuh +++ b/cpp/include/raft/solver/coordinate_descent.cuh @@ -16,8 +16,52 @@ #pragma once +#include +#include #include namespace raft::solver::coordinate_descent { + /** + * @brief Minimizes an objective function using the Coordinate Descent solver. + * + * Note: Currently only least squares loss is supported w/ optional lasso and elastic-net penalties: + * f(coef) = 1/2 * || b - Ax ||^2 + * + 1/2 * alpha * (1 - l1_ratio) * ||coef||^2 + * + alpha * l1_ratio * ||coef||_1 + * + * @param[in] handle: Reference of raft::handle_t + * @param[in] A: Input matrix in column-major format (size of n_rows, n_cols) + * @param[in] b: Input vector of labels (size of n_rows) + * @param[in] sample_weights: Optional input vector for sample weights (size n_rows) + * @param[out] x: Output vector of learned coefficients (size of n_cols) + * @param[out] intercept: Optional scalar to hold intercept if desired + */ + template + void minimize(const raft::handle_t& handle, + raft::device_matrix_view A, + raft::device_vector_view b, + std::optional sample_weights, + raft::device_vector_view x, + std::optional> intercept, + cd_params ¶ms) { + + RAFT_EXPECTS(A.extent(0) == b.extent(0), "Number of labels must match the number of rows in input matrix"); + + if(sample_weights.has_value()) { + RAFT_EXPECTS(A.extent(0) == sample_weights.value().extent(0), "Number of sample weights must match number of rows in input matrix"); + } + + RAFT_EXPECTS(x.extent(0) == A.extent(1), "Objective is linear. 
The number of coefficients must match the number features in the input matrix"); + RAFT_EXPECTS(lossFunct == loss_funct::SQRD_LOSS, "Only squared loss is supported in the current implementation."); + + math_t *intercept_ptr = intercept.has_value() ? intercept.value().data_handle() : nullptr; + math_t *sample_weight_ptr = sample_weights.has_value() ? sample_weights.value().data_handle() : nullptr; + + detail::cdFit(handle, A.data_handle(), A.extent(0), A.extent(1), + b.data_handle(), x.data_handle(), intercept_ptr, + intercept.has_value(), params.normalize, params.epochs, + params.loss, params.alpha, params.l1_ratio, params.shuffle, + params.tol, sample_weight_ptr); + } } \ No newline at end of file diff --git a/cpp/include/raft/solver/detail/cd.cuh b/cpp/include/raft/solver/detail/cd.cuh index bd23f39850..db5b20e90b 100644 --- a/cpp/include/raft/solver/detail/cd.cuh +++ b/cpp/include/raft/solver/detail/cd.cuh @@ -16,15 +16,15 @@ #pragma once -#include "shuffle.h" -#include -#include -#include -#include -#include -#include +#include +#include +#include +#include +#include +#include +#include #include -#include +#include #include #include #include @@ -85,7 +85,7 @@ __global__ void __launch_bounds__(1, 1) cdUpdateCoefKernel(math_t* coefLoc, * * i.e. finds coefficients that minimize the following loss function: * - * f(coef) = 1/2 * || labels - input * coef ||^2 + * f(coef) = 1/2 * || b - A * x ||^2 * + 1/2 * alpha * (1 - l1_ratio) * ||coef||^2 * + alpha * l1_ratio * ||coef||_1 * @@ -174,7 +174,7 @@ void cdFit(const raft::handle_t& handle, mu_labels.resize(1, stream); if (normalize) { norm2_input.resize(n_cols, stream); } - GLM::preProcessData(handle, + preProcessData(handle, input, n_rows, n_cols, @@ -295,7 +295,7 @@ void cdFit(const raft::handle_t& handle, } if (fit_intercept) { - GLM::postProcessData(handle, + postProcessData(handle, input, n_rows, n_cols, diff --git a/cpp/include/raft/solver/detail/lars.cuh b/cpp/include/raft/solver/detail/lars.cuh index d753dd8253..6ee77ec6f8 100644 --- a/cpp/include/raft/solver/detail/lars.cuh +++ b/cpp/include/raft/solver/detail/lars.cuh @@ -809,7 +809,7 @@ void updateCoef(const raft::handle_t& handle, } /** - * @brief Train a regressor using Least Angre Regression. + * @brief Train a regressor using Least Angle Regression. * * Least Angle Regression (LAR or LARS) is a model selection algorithm. It * builds up the model using the following algorithm: diff --git a/cpp/include/raft/solver/detail/preprocess.cuh b/cpp/include/raft/solver/detail/preprocess.cuh new file mode 100644 index 0000000000..5832f6d1d9 --- /dev/null +++ b/cpp/include/raft/solver/detail/preprocess.cuh @@ -0,0 +1,171 @@ +/* + * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace raft::solver::detail { + +/** + * @brief Center and scale the data, depending on the flags fit_intercept and normalize + * + * @tparam math_t the element type + * @param [inout] input the column-major data of size [n_rows, n_cols] + * @param [in] n_rows + * @param [in] n_cols + * @param [inout] labels vector of size [n_rows] + * @param [out] intercept + * @param [out] mu_input the column-wise means of the input of size [n_cols] + * @param [out] mu_labels the scalar mean of the target (labels vector) + * @param [out] norm2_input the column-wise standard deviations of the input of size [n_cols]; + * note, the biased estimator is used to match sklearn's StandardScaler + * (dividing by n_rows, not by (n_rows - 1)). + * @param [in] fit_intercept whether to center the data / to fit the intercept + * @param [in] normalize whether to normalize the data + * @param [in] stream + */ + template + void preProcessData(const raft::handle_t& handle, + math_t* input, + int n_rows, + int n_cols, + math_t* labels, + math_t* intercept, + math_t* mu_input, + math_t* mu_labels, + math_t* norm2_input, + bool fit_intercept, + bool normalize, + math_t* sample_weight = nullptr) + { + cudaStream_t stream = handle.get_stream(); + raft::common::nvtx::range fun_scope("ML::GLM::preProcessData-%d-%d", n_rows, n_cols); + ASSERT(n_cols > 0, "Parameter n_cols: number of columns cannot be less than one"); + ASSERT(n_rows > 1, "Parameter n_rows: number of rows cannot be less than two"); + + if (fit_intercept) { + if (normalize && sample_weight == nullptr) { + raft::stats::meanvar(mu_input, norm2_input, input, n_cols, n_rows, false, false, stream); + raft::linalg::unaryOp( + norm2_input, + norm2_input, + n_cols, + [] __device__(math_t v) { return raft::mySqrt(v); }, + stream); + raft::matrix::linewiseOp( + input, + input, + n_rows, + n_cols, + false, + [] __device__(math_t x, math_t m, math_t s) { return s > 1e-10 ? 
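/* guard against a near-zero standard deviation: constant columns are mapped to zero instead of dividing by s */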
(x - m) / s : 0; }, + stream, + mu_input, + norm2_input); + } else { + if (sample_weight != nullptr) { + raft::stats::weightedMean( + mu_input, input, sample_weight, n_cols, n_rows, false, false, stream); + } else { + raft::stats::mean(mu_input, input, n_cols, n_rows, false, false, stream); + } + raft::stats::meanCenter(input, input, mu_input, n_cols, n_rows, false, true, stream); + if (normalize) { + raft::linalg::colNorm(norm2_input, + input, + n_cols, + n_rows, + raft::linalg::L2Norm, + false, + stream, + [] __device__(math_t v) { return raft::mySqrt(v); }); + raft::matrix::matrixVectorBinaryDivSkipZero( + input, norm2_input, n_rows, n_cols, false, true, stream, true); + } + } + + if (sample_weight != nullptr) { + raft::stats::weightedMean(mu_labels, labels, sample_weight, 1, n_rows, true, false, stream); + } else { + raft::stats::mean(mu_labels, labels, 1, n_rows, false, false, stream); + } + raft::stats::meanCenter(labels, labels, mu_labels, 1, n_rows, false, true, stream); + } + } + + template + void postProcessData(const raft::handle_t& handle, + math_t* input, + int n_rows, + int n_cols, + math_t* labels, + math_t* coef, + math_t* intercept, + math_t* mu_input, + math_t* mu_labels, + math_t* norm2_input, + bool fit_intercept, + bool normalize) + { + cudaStream_t stream = handle.get_stream(); + raft::common::nvtx::range fun_scope("ML::GLM::postProcessData-%d-%d", n_rows, n_cols); + ASSERT(n_cols > 0, "Parameter n_cols: number of columns cannot be less than one"); + ASSERT(n_rows > 1, "Parameter n_rows: number of rows cannot be less than two"); + + cublasHandle_t cublas_handle = handle.get_cublas_handle(); + rmm::device_scalar d_intercept(stream); + + if (normalize) { + raft::matrix::matrixVectorBinaryDivSkipZero( + coef, norm2_input, 1, n_cols, false, true, stream, true); + } + + raft::linalg::gemm( + handle, mu_input, 1, n_cols, coef, d_intercept.data(), 1, 1, CUBLAS_OP_N, CUBLAS_OP_N, stream); + + raft::linalg::subtract(d_intercept.data(), mu_labels, d_intercept.data(), 1, stream); + *intercept = d_intercept.value(stream); + + if (normalize) { + raft::matrix::linewiseOp( + input, + input, + n_rows, + n_cols, + false, + [] __device__(math_t x, math_t m, math_t s) { return s * x + m; }, + stream, + mu_input, + norm2_input); + } else { + raft::stats::meanAdd(input, input, mu_input, n_cols, n_rows, false, true, stream); + } + raft::stats::meanAdd(labels, labels, mu_labels, 1, n_rows, false, true, stream); + } + + }; // end namespace raft::solver::detail \ No newline at end of file diff --git a/cpp/include/raft/solver/detail/sgd.cuh b/cpp/include/raft/solver/detail/sgd.cuh index c03c64d47f..e7a41f7b27 100644 --- a/cpp/include/raft/solver/detail/sgd.cuh +++ b/cpp/include/raft/solver/detail/sgd.cuh @@ -16,15 +16,15 @@ #pragma once -#include "learning_rate.h" -#include "shuffle.h" -#include -#include -#include -#include -#include -#include -#include +#include +#include +#include +#include +#include +#include +#include +#include +#include #include #include #include @@ -40,7 +40,7 @@ namespace raft::solver::detail { /** - * Fits a linear, lasso, and elastic-net regression model using Coordinate Descent solver + * Fits a linear, lasso, and elastic-net regression model using Gradient Descent solver * @param handle * Reference of raft::handle_t * @param input @@ -123,7 +123,7 @@ void sgdFit(const raft::handle_t& handle, mu_input.resize(n_cols, stream); mu_labels.resize(1, stream); - GLM::preProcessData(handle, + preProcessData(handle, input, n_rows, n_cols, diff --git 
a/cpp/include/raft/solver/gradient_descent.cuh b/cpp/include/raft/solver/gradient_descent.cuh index 07e49b3cd1..9ba8b16d3a 100644 --- a/cpp/include/raft/solver/gradient_descent.cuh +++ b/cpp/include/raft/solver/gradient_descent.cuh @@ -16,8 +16,42 @@ #pragma once + +#include +#include +#include #include namespace raft::solver::gradient_descent { + /** + * @brief Minimizes an objective function using the Gradient Descent solver and optional + * lasso or elastic-net penalties. + * + * @param[in] handle: Reference of raft::handle_t + * @param[in] A: Input matrix in column-major format (size of n_rows, n_cols) + * @param[in] b: Input vector of labels (size of n_rows) + * @param[out] x: Output vector of coefficients (size of n_cols) + * @param[out] intercept: Optional scalar if fitting the intercept + * @param[in] params: solver hyper-parameters + */ + template + void minimize(const raft::handle_t& handle, + raft::device_matrix_view A, + raft::device_vector_view b, + raft::device_vector_view x, + std::optional intercept, + sgd_params ¶ms) { + + RAFT_EXPECTS(A.extent(0) == b.extent(0), "Number of labels must match the number of rows in input matrix"); + RAFT_EXPECTS(x.extent(0) == A.extent(1), "Objective is linear. The number of coefficients must match the number features in the input matrix"); + + auto intercept_ptr = intercept.has_value() ? intercept.data_handle() ? nullptr; + detail::sgdFit(handle, A.data_handle(), A.extent(0), A.extent(1), b.data_handle(), x.data_handle(), + intercept_ptr, intercept.has_value(), params.batch_size, params.epochs, params.lr_type, + params.eta0, params.power_t, params.loss, params.penalty, params.alpha, params.l1_ratio, + params.shuffle, params.tol, params.n_iter_no_change, handle.get_stream()); + + } + } \ No newline at end of file diff --git a/cpp/include/raft/solver/least_angle_regression.cuh b/cpp/include/raft/solver/least_angle_regression.cuh index 9f58a3c1ba..c449d51e20 100644 --- a/cpp/include/raft/solver/least_angle_regression.cuh +++ b/cpp/include/raft/solver/least_angle_regression.cuh @@ -16,8 +16,81 @@ #pragma once +#include +#include +#include +#include #include namespace raft::solver::least_angle_regression { +/** + * @brief Train a regression model using Least Angle Regression (LARS). + * + * Least Angle Regression (LAR or LARS) is a model selection algorithm. It + * builds up the model using the following algorithm: + * + * 1. We start with all the coefficients equal to zero. + * 2. At each step we select the predictor that has the largest absolute + * correlation with the residual. + * 3. We take the largest step possible in the direction which is equiangular + * with all the predictors selected so far. The largest step is determined + * such that using this step a new predictor will have as much correlation + * with the residual as any of the currently active predictors. + * 4. Stop if max_iter reached or all the predictors are used, or if the + * correlation between any unused predictor and the residual is lower than + * a tolerance. + * + * The solver is based on [1]. The equations referred in the comments correspond + * to the equations in the paper. + * + * Note: this algorithm assumes that the offset is removed from X and y, and + * each feature is normalized: + * - sum_i y_i = 0, + * - sum_i x_{i,j} = 0, sum_i x_{i,j}^2=1 for j=0..n_col-1 + * + * References: + * [1] B. Efron, T. Hastie, I. 
Johnstone, R Tibshirani, Least Angle Regression + * The Annals of Statistics (2004) Vol 32, No 2, 407-499 + * http://statweb.stanford.edu/~tibs/ftp/lars.pdf + * + * @param handle RAFT handle + * @param[in] A device array of training vectors in column major format, + * size [n_rows * n_cols]. Note that the columns of X will be permuted if + * the Gram matrix is not specified. It is expected that X is normalized so + * that each column has zero mean and unit variance. + * @param[in] b device array of the regression targets, size [n_rows]. y should + * be normalized to have zero mean. + * @param[in] Gram device array containing Gram matrix containing X.T * X. Can be + * nullptr. + * @param[out] x: device array of regression coefficients, has to be allocated on + * entry, size [max_iter] + * @param[in] active_idx device vector containing the indices of active variables. + * Must be allocated on entry. Size [max_iter] + * @param[out] alphas device array to return the maximum correlation along the + * regularization path. Must be allocated on entry, size [max_iter+1]. + * @param[out] n_active host pointer to return the number of active elements (scalar) + * @param[out] coef_path coefficients along the regularization path are returned + * here. Must be nullptr, or a device array already allocated on entry. + * Size [max_iter * (max_iter+1)]. + * @param[in] params: lars hyper-parameters + * @param[in] ld_X leading dimension of A (stride of columns) + * @param[in] ld_G leading dimesion of G + */ +template +void minimize(const raft::handle_t& handle, + raft::device_matrix_view A, + raft::device_vector_view b, + std::optional> Gram, + raft::device_vector_view x, + raft::device_vector_view active_idx, + raft::device_vector_view alphas, + raft::host_scalar_view n_active, + std::optional> coef_path, + lars_params ¶ms, + idx_t ld_X = 0, + idx_t ld_G = 0) { + + + } } \ No newline at end of file diff --git a/cpp/include/raft/solver/quasi_newton.cuh b/cpp/include/raft/solver/quasi_newton.cuh index 3c3f7b9aa7..d16839a506 100644 --- a/cpp/include/raft/solver/quasi_newton.cuh +++ b/cpp/include/raft/solver/quasi_newton.cuh @@ -267,7 +267,7 @@ struct Tikhonov : detail::objectives::Tikhonov { * @return */ template - inline int qn_minimize(const raft::handle_t& handle, + inline int minimize(const raft::handle_t& handle, SimpleVec& x, T* fx, int* num_iters, diff --git a/cpp/include/raft/solver/solver_types.hpp b/cpp/include/raft/solver/solver_types.hpp index f3e7d66a27..cf7f3bcd8b 100644 --- a/cpp/include/raft/solver/solver_types.hpp +++ b/cpp/include/raft/solver/solver_types.hpp @@ -26,13 +26,65 @@ enum lr_type { }; enum loss_funct { - SQRD_LOSS, + SQUARED, HINGE, LOG, }; enum penalty { NONE, L1, L2, ELASTICNET }; +namespace gradient_descent { + template + struct sgd_params { + int batch_size; + int epochs; + lr_type lr_type; + math_t eta0; + math_t power_t; + loss_funct loss; + penalty penalty; + math_t alpha; + math_t l1_ratio; + bool shuffle; + math_t tol; + int n_iter_no_change; + + sgd_params() : batch_size(100), epochs(100), lr_type(lr_type::OPTIMAL), eta0(0.5), power_t(0.5), + loss(loss_funct::SQUARED), penalty(penalty::L1), alpha(0.5), l1_ratio(0.2), shuffle(true), tol(1e-8), n_iter_no_change(5){} + }; +} +namespace coordinate_descent { + template + struct cd_params { + bool normalize; // whether to normalize the data to zero-mean and unit std + int epochs; // number of iterations + loss_funct loss; // loss function to minimize + math_t alpha; // l1 penalty parameter + math_t l1_ratio; // ratio of alpha 
that will be used for l1 penalty. (1 - l1_ratio) * alpha will be used for l2 penalty + bool shuffle; // randomly pick coordinates + math_t tol; // early-stopping convergence tolerance + + cd_params() : + normalize(true), + epochs(100), + alpha(0.3), + l1_ratio(0.5), + shuffle(true), + tol(1e-8), + loss(loss_funct::SQRD_LOSS) {} + }; +} + +namespace least_angle_regression { + template + struct lars_params { + int max_iter; + math_t eps; + + lars_params(): max_iter(500), eps(-1) {} + }; +} + enum class LarsFitStatus { kOk, kCollinear, kError, kStop }; namespace quasi_newton { diff --git a/cpp/test/solver/quasi_newton.cu b/cpp/test/solver/quasi_newton.cu index 72f9ad9c6d..5a72eb4772 100644 --- a/cpp/test/solver/quasi_newton.cu +++ b/cpp/test/solver/quasi_newton.cu @@ -50,14 +50,14 @@ namespace raft::solver::quasi_newton { if (l2 == 0) { ObjectiveWithData lossWith(&loss, X, y, Z); - return qn_minimize(handle, w0, fx, num_iters, lossWith, l1, opt_param); + return minimize(handle, w0, fx, num_iters, lossWith, l1, opt_param); } else { Tikhonov reg(l2); RegularizedQN obj(&loss, ®); ObjectiveWithData lossWith(&obj, X, y, Z); - return qn_minimize(handle, w0, fx, num_iters, lossWith, l1, opt_param); + return minimize(handle, w0, fx, num_iters, lossWith, l1, opt_param); } } From 87d620ff15de25086d83be8dbfa19756e906cb61 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Thu, 20 Oct 2022 19:06:27 -0400 Subject: [PATCH 23/35] Correcting spatial::knn docs to raft::neighbors --- docs/source/cpp_api.rst | 6 +-- .../cpp_api/{clustering.rst => cluster.rst} | 6 +-- docs/source/cpp_api/distance.rst | 10 +++++ docs/source/cpp_api/neighbors.rst | 43 +++++++++++++++++++ .../cpp_api/{optimization.rst => solver.rst} | 0 docs/source/cpp_api/spatial.rst | 31 ------------- 6 files changed, 59 insertions(+), 37 deletions(-) rename docs/source/cpp_api/{clustering.rst => cluster.rst} (77%) create mode 100644 docs/source/cpp_api/distance.rst create mode 100644 docs/source/cpp_api/neighbors.rst rename docs/source/cpp_api/{optimization.rst => solver.rst} (100%) delete mode 100644 docs/source/cpp_api/spatial.rst diff --git a/docs/source/cpp_api.rst b/docs/source/cpp_api.rst index db139031a2..05d3686dc3 100644 --- a/docs/source/cpp_api.rst +++ b/docs/source/cpp_api.rst @@ -9,11 +9,11 @@ RAFT C++ API Reference :maxdepth: 4 cpp_api/core.rst - cpp_api/clustering.rst + cpp_api/cluster.rst cpp_api/linalg.rst cpp_api/matrix.rst - cpp_api/optimization.rst + cpp_api/solver.rst cpp_api/random.rst - cpp_api/spatial.rst + cpp_api/distance.rst cpp_api/sparse.rst cpp_api/stats.rst \ No newline at end of file diff --git a/docs/source/cpp_api/clustering.rst b/docs/source/cpp_api/cluster.rst similarity index 77% rename from docs/source/cpp_api/clustering.rst rename to docs/source/cpp_api/cluster.rst index 90ca786cc1..781180a72a 100644 --- a/docs/source/cpp_api/clustering.rst +++ b/docs/source/cpp_api/cluster.rst @@ -1,7 +1,7 @@ -Clustering -========== +Cluster +======= -This page provides C++ class references for the publicly-exposed elements of the clustering package. +This page provides C++ class references for the publicly-exposed elements of the cluster package. .. doxygennamespace:: raft::cluster :project: RAFT diff --git a/docs/source/cpp_api/distance.rst b/docs/source/cpp_api/distance.rst new file mode 100644 index 0000000000..c2bce860d5 --- /dev/null +++ b/docs/source/cpp_api/distance.rst @@ -0,0 +1,10 @@ +Distance +======== + +This page provides C++ class references for the publicly-exposed elements of the distance package. 
+ +Distance +######## + +.. doxygennamespace:: raft::distance + :project: RAFT diff --git a/docs/source/cpp_api/neighbors.rst b/docs/source/cpp_api/neighbors.rst new file mode 100644 index 0000000000..962bbd1efe --- /dev/null +++ b/docs/source/cpp_api/neighbors.rst @@ -0,0 +1,43 @@ +Neighbors +========= + +This page provides C++ class references for the publicly-exposed elements of the neighbors package. + + +Brute-force +----------- + +.. doxygennamespace:: raft::neighbors::brute_force + :project: RAFT + + +IVF-Flat +-------- + +.. doxygennamespace:: raft::neighbors::ivf_flat + :project: RAFT + :members: + + +IVF-PQ +-------- + +.. doxygennamespace:: raft::neighbors::ivf_pq + :project: RAFT + :members: + + +Epsilon Neighborhood +-------------------- + +.. doxygennamespace:: raft::neighbors::epsilon_neighborhood + :project: RAFT + :members: + + +Random Ball Cover +----------------- + +.. doxygennamespace:: raft::neighbors::ball_cover + :project: RAFT + :members: diff --git a/docs/source/cpp_api/optimization.rst b/docs/source/cpp_api/solver.rst similarity index 100% rename from docs/source/cpp_api/optimization.rst rename to docs/source/cpp_api/solver.rst diff --git a/docs/source/cpp_api/spatial.rst b/docs/source/cpp_api/spatial.rst deleted file mode 100644 index 9bda00dab7..0000000000 --- a/docs/source/cpp_api/spatial.rst +++ /dev/null @@ -1,31 +0,0 @@ -Spatial -======= - -This page provides C++ class references for the publicly-exposed elements of the spatial package. - -Distance -######## - -.. doxygennamespace:: raft::distance - :project: RAFT - - -Nearest Neighbors -################# - -.. doxygenfunction:: raft::spatial::knn::brute_force_knn - :project: RAFT - -.. doxygenfunction:: raft::spatial::knn::select_k - :project: RAFT - -.. doxygenfunction:: raft::spatial::knn::knn_merge_parts - :project: RAFT - - -IVF-Flat --------- - -.. doxygennamespace:: raft::spatial::knn::ivf_flat - :project: RAFT - :members: From febd1d40eb2bc511731b9335912e8bbf9cbeb3f8 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Thu, 20 Oct 2022 19:08:48 -0400 Subject: [PATCH 24/35] Adding neighbors to index --- docs/source/cpp_api.rst | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/source/cpp_api.rst b/docs/source/cpp_api.rst index 05d3686dc3..d10d9773a5 100644 --- a/docs/source/cpp_api.rst +++ b/docs/source/cpp_api.rst @@ -10,10 +10,11 @@ RAFT C++ API Reference cpp_api/core.rst cpp_api/cluster.rst + cpp_api/distance.rst cpp_api/linalg.rst cpp_api/matrix.rst + cpp_api/neighbors.rst cpp_api/solver.rst cpp_api/random.rst - cpp_api/distance.rst cpp_api/sparse.rst cpp_api/stats.rst \ No newline at end of file From 63e2e8bdaa1bc71f41bc59c46f86bf61e9511b6a Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Thu, 20 Oct 2022 20:01:34 -0400 Subject: [PATCH 25/35] Updating docs --- cpp/include/raft/cluster/kmeans.cuh | 470 ++++++++++++++++++ cpp/include/raft/cluster/kmeans_types.hpp | 13 +- cpp/include/raft/cluster/single_linkage.cuh | 81 +++ .../raft/cluster/single_linkage_types.hpp | 8 + docs/source/cpp_api/cluster.rst | 17 +- docs/source/cpp_api/core.rst | 63 ++- docs/source/cpp_api/solver.rst | 5 +- 7 files changed, 647 insertions(+), 10 deletions(-) diff --git a/cpp/include/raft/cluster/kmeans.cuh b/cpp/include/raft/cluster/kmeans.cuh index 0ce35da4a5..d737b1b736 100644 --- a/cpp/include/raft/cluster/kmeans.cuh +++ b/cpp/include/raft/cluster/kmeans.cuh @@ -21,6 +21,476 @@ #include #include +namespace raft::cluster::kmeans { + +/** + * @brief Find clusters with k-means algorithm. 
+ * Initial centroids are chosen with k-means++ algorithm. Empty + * clusters are reinitialized by choosing new centroids with + * k-means++ algorithm. + * @tparam DataT the type of data used for weights, distances. + * @tparam IndexT the type of data used for indexing. + * @param[in] handle The raft handle. + * @param[in] params Parameters for KMeans model. + * @param[in] X Training instances to cluster. The data must + * be in row-major format. + * [dim = n_samples x n_features] + * @param[in] sample_weight Optional weights for each observation in X. + * [len = n_samples] + * @param[inout] centroids [in] When init is InitMethod::Array, use + * centroids as the initial cluster centers. + * [out] The generated centroids from the + * kmeans algorithm are stored at the address + * pointed by 'centroids'. + * [dim = n_clusters x n_features] + * @param[out] inertia Sum of squared distances of samples to their + * closest cluster center. + * @param[out] n_iter Number of iterations run. + */ +template +void fit(handle_t const& handle, + const KMeansParams& params, + raft::device_matrix_view X, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) +{ + detail::kmeans_fit(handle, params, X, sample_weight, centroids, inertia, n_iter); +} + +template +void fit(handle_t const& handle, + const KMeansParams& params, + const DataT* X, + const DataT* sample_weight, + DataT* centroids, + IndexT n_samples, + IndexT n_features, + DataT& inertia, + IndexT& n_iter) +{ + detail::kmeans_fit( + handle, params, X, sample_weight, centroids, n_samples, n_features, inertia, n_iter); +} + +/** + * @brief Predict the closest cluster each sample in X belongs to. + * @tparam DataT the type of data used for weights, distances. + * @tparam IndexT the type of data used for indexing. + * @param[in] handle The raft handle. + * @param[in] params Parameters for KMeans model. + * @param[in] X New data to predict. + * [dim = n_samples x n_features] + * @param[in] sample_weight Optional weights for each observation in X. + * [len = n_samples] + * @param[in] centroids Cluster centroids. The data must be in + * row-major format. + * [dim = n_clusters x n_features] + * @param[in] normalize_weight True if the weights should be normalized + * @param[out] labels Index of the cluster each sample in X + * belongs to. + * [len = n_samples] + * @param[out] inertia Sum of squared distances of samples to + * their closest cluster center. + */ +template +void predict(handle_t const& handle, + const KMeansParams& params, + raft::device_matrix_view X, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::device_vector_view labels, + bool normalize_weight, + raft::host_scalar_view inertia) +{ + detail::kmeans_predict( + handle, params, X, sample_weight, centroids, labels, normalize_weight, inertia); +} + +template +void predict(handle_t const& handle, + const KMeansParams& params, + const DataT* X, + const DataT* sample_weight, + const DataT* centroids, + IndexT n_samples, + IndexT n_features, + IndexT* labels, + bool normalize_weight, + DataT& inertia) +{ + detail::kmeans_predict(handle, + params, + X, + sample_weight, + centroids, + n_samples, + n_features, + labels, + normalize_weight, + inertia); +} + +/** + * @brief Compute k-means clustering and predicts cluster index for each sample + * in the input. + * + * @tparam DataT the type of data used for weights, distances. + * @tparam IndexT the type of data used for indexing. 
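 *
 * A minimal usage sketch (shapes and parameter values are illustrative; the
 * mdspan factories raft::make_device_matrix, raft::make_device_vector,
 * raft::make_device_matrix_view and raft::make_host_scalar_view from raft/core
 * are assumed to be available, as is the KMeansParams::n_clusters field):
 * @code{.cpp}
 *   raft::handle_t handle;
 *   raft::cluster::kmeans::KMeansParams params;
 *   params.n_clusters = 8;
 *
 *   int n_samples = 1000, n_features = 16;
 *   auto X         = raft::make_device_matrix<float, int>(handle, n_samples, n_features);
 *   auto centroids = raft::make_device_matrix<float, int>(handle, params.n_clusters, n_features);
 *   auto labels    = raft::make_device_vector<int, int>(handle, n_samples);
 *   float inertia  = 0;
 *   int   n_iter   = 0;
 *
 *   // ... fill X with the samples to cluster, then label them in one call ...
 *   raft::cluster::kmeans::fit_predict<float, int>(
 *     handle, params,
 *     raft::make_device_matrix_view<const float, int>(X.data_handle(), n_samples, n_features),
 *     std::nullopt,                          // no sample weights
 *     std::make_optional(centroids.view()),  // learned centroids are written here
 *     labels.view(),
 *     raft::make_host_scalar_view(&inertia),
 *     raft::make_host_scalar_view(&n_iter));
 * @endcode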
+ * @param[in] handle The raft handle. + * @param[in] params Parameters for KMeans model. + * @param[in] X Training instances to cluster. The data must be + * in row-major format. + * [dim = n_samples x n_features] + * @param[in] sample_weight Optional weights for each observation in X. + * [len = n_samples] + * @param[inout] centroids Optional + * [in] When init is InitMethod::Array, use + * centroids as the initial cluster centers + * [out] The generated centroids from the + * kmeans algorithm are stored at the address + * pointed by 'centroids'. + * [dim = n_clusters x n_features] + * @param[out] labels Index of the cluster each sample in X belongs + * to. + * [len = n_samples] + * @param[out] inertia Sum of squared distances of samples to their + * closest cluster center. + * @param[out] n_iter Number of iterations run. + */ +template +void fit_predict(handle_t const& handle, + const KMeansParams& params, + raft::device_matrix_view X, + std::optional> sample_weight, + std::optional> centroids, + raft::device_vector_view labels, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) +{ + detail::kmeans_fit_predict( + handle, params, X, sample_weight, centroids, labels, inertia, n_iter); +} + +template +void fit_predict(handle_t const& handle, + const KMeansParams& params, + const DataT* X, + const DataT* sample_weight, + DataT* centroids, + IndexT n_samples, + IndexT n_features, + IndexT* labels, + DataT& inertia, + IndexT& n_iter) +{ + detail::kmeans_fit_predict( + handle, params, X, sample_weight, centroids, n_samples, n_features, labels, inertia, n_iter); +} + +/** + * @brief Transform X to a cluster-distance space. + * + * @tparam DataT the type of data used for weights, distances. + * @tparam IndexT the type of data used for indexing. + * @param[in] handle The raft handle. + * @param[in] params Parameters for KMeans model. + * @param[in] X Training instances to cluster. The data must + * be in row-major format + * [dim = n_samples x n_features] + * @param[in] centroids Cluster centroids. The data must be in row-major format. + * [dim = n_clusters x n_features] + * @param[out] X_new X transformed in the new space. + * [dim = n_samples x n_features] + */ +template +void transform(const raft::handle_t& handle, + const KMeansParams& params, + raft::device_matrix_view X, + raft::device_matrix_view centroids, + raft::device_matrix_view X_new) +{ + detail::kmeans_transform(handle, params, X, centroids, X_new); +} + +template +void transform(const raft::handle_t& handle, + const KMeansParams& params, + const DataT* X, + const DataT* centroids, + IndexT n_samples, + IndexT n_features, + DataT* X_new) +{ + detail::kmeans_transform( + handle, params, X, centroids, n_samples, n_features, X_new); +} + +template +using SamplingOp = detail::SamplingOp; + +template +using KeyValueIndexOp = detail::KeyValueIndexOp; + +/** + * @brief Select centroids according to a sampling operation + * + * @tparam DataT the type of data used for weights, distances. + * @tparam IndexT the type of data used for indexing. 
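 *
 * @note select_op is a device-side functor invoked for every sample; only the
 * samples it selects are copied into inRankCp, so inRankCp holds one row per
 * selected sample.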
+ * + * @param[in] handle The raft handle + * @param[in] X The data in row-major format + * [dim = n_samples x n_features] + * @param[in] minClusterDistance Distance for every sample to it's nearest centroid + * [dim = n_samples] + * @param[in] isSampleCentroid Flag the sample choosen as initial centroid + * [dim = n_samples] + * @param[in] select_op The sampling operation used to select the centroids + * @param[out] inRankCp The sampled centroids + * [dim = n_selected_centroids x n_features] + * @param[in] workspace Temporary workspace buffer which can get resized + * + */ +template +void sample_centroids(const raft::handle_t& handle, + const raft::device_matrix_view& X, + const raft::device_vector_view& minClusterDistance, + const raft::device_vector_view& isSampleCentroid, + SamplingOp& select_op, + rmm::device_uvector& inRankCp, + rmm::device_uvector& workspace) +{ + detail::sampleCentroids( + handle, X, minClusterDistance, isSampleCentroid, select_op, inRankCp, workspace); +} + +/** + * @brief Compute cluster cost + * + * @tparam DataT the type of data used for weights, distances. + * @tparam ReductionOpT the type of data used for the reduction operation. + * + * @param[in] handle The raft handle + * @param[in] minClusterDistance Distance for every sample to it's nearest centroid + * [dim = n_samples] + * @param[in] workspace Temporary workspace buffer which can get resized + * @param[out] clusterCost Resulting cluster cost + * @param[in] reduction_op The reduction operation used for the cost + * + */ +template +void cluster_cost(const raft::handle_t& handle, + const raft::device_vector_view& minClusterDistance, + rmm::device_uvector& workspace, + const raft::device_scalar_view& clusterCost, + ReductionOpT reduction_op) +{ + detail::computeClusterCost( + handle, minClusterDistance, workspace, clusterCost, reduction_op); +} + +/** + * @brief Compute distance for every sample to it's nearest centroid + * + * @tparam DataT the type of data used for weights, distances. + * @tparam IndexT the type of data used for indexing. + * + * @param[in] handle The raft handle + * @param[in] params The parameters for KMeans + * @param[in] X The data in row-major format + * [dim = n_samples x n_features] + * @param[in] centroids Centroids data + * [dim = n_cluster x n_features] + * @param[out] minClusterDistance Distance for every sample to it's nearest centroid + * [dim = n_samples] + * @param[in] L2NormX L2 norm of X : ||x||^2 + * [dim = n_samples] + * @param[out] L2NormBuf_OR_DistBuf Resizable buffer to store L2 norm of centroids or distance + * matrix + * @param[in] workspace Temporary workspace buffer which can get resized + * + */ +template +void min_cluster_distance_compute(const raft::handle_t& handle, + const KMeansParams& params, + const raft::device_matrix_view& X, + const raft::device_matrix_view& centroids, + const raft::device_vector_view& minClusterDistance, + const raft::device_vector_view& L2NormX, + rmm::device_uvector& L2NormBuf_OR_DistBuf, + rmm::device_uvector& workspace) +{ + detail::minClusterDistanceCompute( + handle, params, X, centroids, minClusterDistance, L2NormX, L2NormBuf_OR_DistBuf, workspace); +} + +/** + * @brief Calculates a pair for every sample in input 'X' where key is an + * index of one of the 'centroids' (index of the nearest centroid) and 'value' + * is the distance between the sample and the 'centroid[key]' + * + * @tparam DataT the type of data used for weights, distances. + * @tparam IndexT the type of data used for indexing. 
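 *
 * @note Each element of minClusterAndDistance is a raft::KeyValuePair: the key
 * is the index of the sample's nearest centroid and the value is the distance
 * to that centroid.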
+ * + * @param[in] handle The raft handle + * @param[in] params The parameters for KMeans + * @param[in] X The data in row-major format + * [dim = n_samples x n_features] + * @param[in] centroids Centroids data + * [dim = n_cluster x n_features] + * @param[out] minClusterAndDistance Distance vector that contains for every sample, the nearest + * centroid and it's distance + * [dim = n_samples] + * @param[in] L2NormX L2 norm of X : ||x||^2 + * [dim = n_samples] + * @param[out] L2NormBuf_OR_DistBuf Resizable buffer to store L2 norm of centroids or distance + * matrix + * @param[in] workspace Temporary workspace buffer which can get resized + * + */ +template +void min_cluster_and_distance( + const raft::handle_t& handle, + const KMeansParams& params, + const raft::device_matrix_view X, + const raft::device_matrix_view centroids, + const raft::device_vector_view, IndexT>& minClusterAndDistance, + const raft::device_vector_view& L2NormX, + rmm::device_uvector& L2NormBuf_OR_DistBuf, + rmm::device_uvector& workspace) +{ + detail::minClusterAndDistanceCompute( + handle, params, X, centroids, minClusterAndDistance, L2NormX, L2NormBuf_OR_DistBuf, workspace); +} + +/** + * @brief Shuffle and randomly select 'n_samples_to_gather' from input 'in' and stores + * in 'out' does not modify the input + * + * @tparam DataT the type of data used for weights, distances. + * @tparam IndexT the type of data used for indexing. + * + * @param[in] handle The raft handle + * @param[in] in The data to shuffle and gather + * [dim = n_samples x n_features] + * @param[out] out The sampled data + * [dim = n_samples_to_gather x n_features] + * @param[in] n_samples_to_gather Number of sample to gather + * @param[in] seed Seed for the shuffle + * @param[in] workspace Temporary workspace buffer which can get resized + * + */ +template +void shuffle_and_gather(const raft::handle_t& handle, + const raft::device_matrix_view& in, + const raft::device_matrix_view& out, + uint32_t n_samples_to_gather, + uint64_t seed, + rmm::device_uvector* workspace = nullptr) +{ + detail::shuffleAndGather(handle, in, out, n_samples_to_gather, seed, workspace); +} + +/** + * @brief Count the number of samples in each cluster + * + * @tparam DataT the type of data used for weights, distances. + * @tparam IndexT the type of data used for indexing. + * + * @param[in] handle The raft handle + * @param[in] params The parameters for KMeans + * @param[in] X The data in row-major format + * [dim = n_samples x n_features] + * @param[in] L2NormX L2 norm of X : ||x||^2 + * [dim = n_samples] + * @param[in] centroids Centroids data + * [dim = n_cluster x n_features] + * @param[in] workspace Temporary workspace buffer which can get resized + * @param[out] sampleCountInCluster The count for each centroid + * [dim = n_cluster] + * + */ +template +void count_samples_in_cluster(const raft::handle_t& handle, + const KMeansParams& params, + const raft::device_matrix_view& X, + const raft::device_vector_view& L2NormX, + const raft::device_matrix_view& centroids, + rmm::device_uvector& workspace, + const raft::device_vector_view& sampleCountInCluster) +{ + detail::countSamplesInCluster( + handle, params, X, L2NormX, centroids, workspace, sampleCountInCluster); +} + +/* + * @brief Selects 'n_clusters' samples from the input X using kmeans++ algorithm. + + * @note This is the algorithm described in + * "k-means++: the advantages of careful seeding". 2007, Arthur, D. and Vassilvitskii, S. + * ACM-SIAM symposium on Discrete algorithms. 
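 * In short, the first center is sampled uniformly at random and each further
 * center is sampled with probability proportional to D(x)^2, where D(x) is the
 * distance of sample x to its closest already-selected center.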
+ * + * @tparam DataT the type of data used for weights, distances. + * @tparam IndexT the type of data used for indexing. + * + * @param[in] handle The raft handle + * @param[in] params The parameters for KMeans + * @param[in] X The data in row-major format + * [dim = n_samples x n_features] + * @param[out] centroids Centroids data + * [dim = n_cluster x n_features] + * @param[in] workspace Temporary workspace buffer which can get resized + */ +template +void init_plus_plus(const raft::handle_t& handle, + const KMeansParams& params, + const raft::device_matrix_view& X, + const raft::device_matrix_view& centroidsRawData, + rmm::device_uvector& workspace) +{ + detail::kmeansPlusPlus(handle, params, X, centroidsRawData, workspace); +} + +/* + * @brief Main function used to fit KMeans (after cluster initialization) + * + * @tparam DataT the type of data used for weights, distances. + * @tparam IndexT the type of data used for indexing. + * + * @param[in] handle The raft handle. + * @param[in] params Parameters for KMeans model. + * @param[in] X Training instances to cluster. The data must + * be in row-major format. + * [dim = n_samples x n_features] + * @param[in] sample_weight Weights for each observation in X. + * [len = n_samples] + * @param[inout] centroids [in] Initial cluster centers. + * [out] The generated centroids from the + * kmeans algorithm are stored at the address + * pointed by 'centroids'. + * [dim = n_clusters x n_features] + * @param[out] inertia Sum of squared distances of samples to their + * closest cluster center. + * @param[out] n_iter Number of iterations run. + * @param[in] workspace Temporary workspace buffer which can get resized + */ +template +void fit_main(const raft::handle_t& handle, + const KMeansParams& params, + const raft::device_matrix_view& X, + const raft::device_vector_view& weight, + const raft::device_matrix_view& centroidsRawData, + const raft::host_scalar_view& inertia, + const raft::host_scalar_view& n_iter, + rmm::device_uvector& workspace) +{ + detail::kmeans_fit_main( + handle, params, X, weight, centroidsRawData, inertia, n_iter, workspace); +} + +} // end namespace raft::cluster::kmeans + namespace raft::cluster { /** * @brief Find clusters with k-means algorithm. diff --git a/cpp/include/raft/cluster/kmeans_types.hpp b/cpp/include/raft/cluster/kmeans_types.hpp index 87fc7c1880..bb8e1a2b73 100644 --- a/cpp/include/raft/cluster/kmeans_types.hpp +++ b/cpp/include/raft/cluster/kmeans_types.hpp @@ -18,8 +18,9 @@ #include #include -namespace raft { -namespace cluster { +namespace raft::cluster { + +namespace kmeans { struct KMeansParams { enum InitMethod { KMeansPlusPlus, Random, Array }; @@ -69,5 +70,9 @@ struct KMeansParams { bool inertia_check = false; }; -} // namespace cluster -} // namespace raft + +} // namespace kmeans + +using kmeans::KMeansParams; + +} // namespace raft::cluster diff --git a/cpp/include/raft/cluster/single_linkage.cuh b/cpp/include/raft/cluster/single_linkage.cuh index 8e33b8389d..ca2234f01f 100644 --- a/cpp/include/raft/cluster/single_linkage.cuh +++ b/cpp/include/raft/cluster/single_linkage.cuh @@ -21,6 +21,87 @@ namespace raft::cluster { +namespace hierarchy { +constexpr int DEFAULT_CONST_C = 15; + +/** + * Single-linkage clustering, capable of constructing a KNN graph to + * scale the algorithm beyond the n^2 memory consumption of implementations + * that use the fully-connected graph of pairwise distances by connecting + * a knn graph when k is not large enough to connect it. 
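 * In other words, the fully-connected (pairwise) formulation materializes
 * O(n^2) distances, while the KNN-graph formulation stores only O(n * k) edges
 * with k = log(n) + c, reconnecting the graph when the chosen k leaves it
 * disconnected.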
+ + * @tparam value_idx + * @tparam value_t + * @tparam dist_type method to use for constructing connectivities graph + * @param[in] handle raft handle + * @param[in] X dense input matrix in row-major layout + * @param[in] m number of rows in X + * @param[in] n number of columns in X + * @param[in] metric distance metrix to use when constructing connectivities graph + * @param[out] out struct containing output dendrogram and cluster assignments + * @param[in] c a constant used when constructing connectivities from knn graph. Allows the indirect + control + * of k. The algorithm will set `k = log(n) + c` + * @param[in] n_clusters number of clusters to assign data samples + */ +template +void single_linkage(const raft::handle_t& handle, + const value_t* X, + size_t m, + size_t n, + raft::distance::DistanceType metric, + linkage_output* out, + int c, + size_t n_clusters) +{ + detail::single_linkage( + handle, X, m, n, metric, out, c, n_clusters); +} + +/** + * Single-linkage clustering, capable of constructing a KNN graph to + * scale the algorithm beyond the n^2 memory consumption of implementations + * that use the fully-connected graph of pairwise distances by connecting + * a knn graph when k is not large enough to connect it. + + * @tparam value_idx + * @tparam value_t + * @tparam dist_type method to use for constructing connectivities graph + * @param[in] handle raft handle + * @param[in] X dense input matrix in row-major layout + * @param[out] dendrogram output dendrogram (size [n_rows - 1] * 2) + * @param[out] labels output labels vector (size n_rows) + * @param[in] metric distance metrix to use when constructing connectivities graph + * @param[in] n_clusters number of clusters to assign data samples + * @param[in] c a constant used when constructing connectivities from knn graph. Allows the indirect + control of k. The algorithm will set `k = log(n) + c` + */ +template +void single_linkage(const raft::handle_t& handle, + raft::device_matrix_view X, + raft::device_matrix_view dendrogram, + raft::device_vector_view labels, + raft::distance::DistanceType metric, + size_t n_clusters, + std::optional c = std::make_optional(DEFAULT_CONST_C)) +{ + linkage_output out_arrs; + out_arrs.children = dendrogram.data_handle(); + out_arrs.labels = labels.data_handle(); + + single_linkage(handle, + X.data_handle(), + static_cast(X.extent(0)), + static_cast(X.extent(1)), + metric, + &out_arrs, + c.has_value() ? 
c.value() : DEFAULT_CONST_C, + n_clusters); +} +} // namespace hierarchy + constexpr int DEFAULT_CONST_C = 15; /** diff --git a/cpp/include/raft/cluster/single_linkage_types.hpp b/cpp/include/raft/cluster/single_linkage_types.hpp index 79f2ede482..28b245a2cf 100644 --- a/cpp/include/raft/cluster/single_linkage_types.hpp +++ b/cpp/include/raft/cluster/single_linkage_types.hpp @@ -20,6 +20,7 @@ namespace raft::cluster { +namespace hierarchy { enum LinkageDistance { PAIRWISE = 0, KNN_GRAPH = 1 }; /** @@ -58,4 +59,11 @@ class linkage_output_int : public linkage_output { class linkage_output_int64 : public linkage_output { }; +} // end namespace hierarchy + +using hierarchy::linkage_output; +using hierarchy::linkage_output_int; +using hierarchy::linkage_output_int64; +using hierarchy::LinkageDistance; + }; // namespace raft::cluster \ No newline at end of file diff --git a/docs/source/cpp_api/cluster.rst b/docs/source/cpp_api/cluster.rst index 781180a72a..41816482cc 100644 --- a/docs/source/cpp_api/cluster.rst +++ b/docs/source/cpp_api/cluster.rst @@ -3,10 +3,25 @@ Cluster This page provides C++ class references for the publicly-exposed elements of the cluster package. -.. doxygennamespace:: raft::cluster +K-Means +------- + +.. doxygennamespace:: raft::cluster::kmeans + :project: RAFT + :members: + + +Hierarchical Clustering +----------------------- + +.. doxygennamespace:: raft::cluster::hierarchy :project: RAFT :members: + +Spectral Clustering +------------------- + .. doxygennamespace:: raft::spectral :project: RAFT :members: \ No newline at end of file diff --git a/docs/source/cpp_api/core.rst b/docs/source/cpp_api/core.rst index ef6270556e..d4891bf0b3 100644 --- a/docs/source/cpp_api/core.rst +++ b/docs/source/cpp_api/core.rst @@ -4,7 +4,6 @@ Core This page provides C++ class references for the publicly-exposed elements of the core package. - handle_t ######## @@ -20,6 +19,13 @@ interruptible :project: RAFT :members: +NVTX +#### + +.. doxygennamespace:: raft::common::nvtx + :project: RAFT + :members: + mdarray ####### @@ -28,11 +34,64 @@ mdarray :project: RAFT :members: +.. doxygenclass:: raft::make_device_matrix + :project: RAFT + +.. doxygenclass:: raft::make_device_vector + :project: RAFT + +.. doxygenclass:: raft::make_device_scalar + :project: RAFT + +.. doxygenclass:: raft::make_host_matrix + :project: RAFT + +.. doxygenclass:: raft::make_host_vector + :project: RAFT + +.. doxygenclass:: raft::make_device_scalar + :project: RAFT + + +mdspan +####### + +.. doxygenfunction:: raft::make_device_mdspan + :project: RAFT + +.. doxygenfunction:: raft::make_device_matrix_view + :project: RAFT + +.. doxygenfunction:: raft::make_device_vector_view + :project: RAFT + +.. doxygenfunction:: raft::make_device_scalar_view + :project: RAFT + +.. doxygenfunction:: raft::make_host_matrix_view + :project: RAFT + +.. doxygenfunction:: raft::make_host_vector_view + :project: RAFT + +.. doxygenfunction:: raft::make_device_scalar_view + :project: RAFT span #### -.. doxygenclass:: raft::span +.. doxygenclass:: raft::device_span + :project: RAFT + :members: + +.. doxygenclass:: raft::host_span + :project: RAFT + :members: + +Key-Value Pair +############## + +.. 
doxygenclass:: raft::KeyValuePair :project: RAFT :members: diff --git a/docs/source/cpp_api/solver.rst b/docs/source/cpp_api/solver.rst index 75cec2494e..a8b93ca046 100644 --- a/docs/source/cpp_api/solver.rst +++ b/docs/source/cpp_api/solver.rst @@ -7,13 +7,12 @@ This page provides C++ class references for the publicly-exposed elements of the Linear Assignment Problem ######################### -.. doxygenclass:: raft::lap::LinearAssignmentProblem +.. doxygenclass:: raft::solver::LinearAssignmentProblem :project: RAFT :members: Minimum Spanning Tree ##################### -.. doxygennamespace:: raft::mst +.. doxygenfunction:: raft::sparse::solver::mst :project: RAFT - :members: From 831c3d244dcb7c428a5f2e63ad964e265c87575d Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Thu, 20 Oct 2022 20:29:15 -0400 Subject: [PATCH 26/35] Making sure we call new cluster namespaced code from deprecated code to spot any bugs/syntax issues --- cpp/include/raft/cluster/kmeans.cuh | 71 ++++++++++----------- cpp/include/raft/cluster/single_linkage.cuh | 20 ++---- 2 files changed, 39 insertions(+), 52 deletions(-) diff --git a/cpp/include/raft/cluster/kmeans.cuh b/cpp/include/raft/cluster/kmeans.cuh index d737b1b736..6384cfdeaa 100644 --- a/cpp/include/raft/cluster/kmeans.cuh +++ b/cpp/include/raft/cluster/kmeans.cuh @@ -312,14 +312,14 @@ void cluster_cost(const raft::handle_t& handle, * */ template -void min_cluster_distance_compute(const raft::handle_t& handle, - const KMeansParams& params, - const raft::device_matrix_view& X, - const raft::device_matrix_view& centroids, - const raft::device_vector_view& minClusterDistance, - const raft::device_vector_view& L2NormX, - rmm::device_uvector& L2NormBuf_OR_DistBuf, - rmm::device_uvector& workspace) +void min_cluster_distance(const raft::handle_t& handle, + const KMeansParams& params, + const raft::device_matrix_view& X, + const raft::device_matrix_view& centroids, + const raft::device_vector_view& minClusterDistance, + const raft::device_vector_view& L2NormX, + rmm::device_uvector& L2NormBuf_OR_DistBuf, + rmm::device_uvector& workspace) { detail::minClusterDistanceCompute( handle, params, X, centroids, minClusterDistance, L2NormX, L2NormBuf_OR_DistBuf, workspace); @@ -525,7 +525,7 @@ void kmeans_fit(handle_t const& handle, raft::host_scalar_view inertia, raft::host_scalar_view n_iter) { - detail::kmeans_fit(handle, params, X, sample_weight, centroids, inertia, n_iter); + kmeans::fit(handle, params, X, sample_weight, centroids, inertia, n_iter); } template @@ -539,7 +539,7 @@ void kmeans_fit(handle_t const& handle, DataT& inertia, IndexT& n_iter) { - detail::kmeans_fit( + kmeans::fit( handle, params, X, sample_weight, centroids, n_samples, n_features, inertia, n_iter); } @@ -573,7 +573,7 @@ void kmeans_predict(handle_t const& handle, bool normalize_weight, raft::host_scalar_view inertia) { - detail::kmeans_predict( + kmeans::predict( handle, params, X, sample_weight, centroids, labels, normalize_weight, inertia); } @@ -589,16 +589,16 @@ void kmeans_predict(handle_t const& handle, bool normalize_weight, DataT& inertia) { - detail::kmeans_predict(handle, - params, - X, - sample_weight, - centroids, - n_samples, - n_features, - labels, - normalize_weight, - inertia); + kmeans::predict(handle, + params, + X, + sample_weight, + centroids, + n_samples, + n_features, + labels, + normalize_weight, + inertia); } /** @@ -638,7 +638,7 @@ void kmeans_fit_predict(handle_t const& handle, raft::host_scalar_view inertia, raft::host_scalar_view n_iter) { - 
detail::kmeans_fit_predict( + kmeans::fit_predict( handle, params, X, sample_weight, centroids, labels, inertia, n_iter); } @@ -654,7 +654,7 @@ void kmeans_fit_predict(handle_t const& handle, DataT& inertia, IndexT& n_iter) { - detail::kmeans_fit_predict( + kmeans::fit_predict( handle, params, X, sample_weight, centroids, n_samples, n_features, labels, inertia, n_iter); } @@ -680,7 +680,7 @@ void kmeans_transform(const raft::handle_t& handle, raft::device_matrix_view centroids, raft::device_matrix_view X_new) { - detail::kmeans_transform(handle, params, X, centroids, X_new); + kmeans::transform(handle, params, X, centroids, X_new); } template @@ -692,15 +692,14 @@ void kmeans_transform(const raft::handle_t& handle, IndexT n_features, DataT* X_new) { - detail::kmeans_transform( - handle, params, X, centroids, n_samples, n_features, X_new); + kmeans::transform(handle, params, X, centroids, n_samples, n_features, X_new); } template -using SamplingOp = detail::SamplingOp; +using SamplingOp = kmeans::SamplingOp; template -using KeyValueIndexOp = detail::KeyValueIndexOp; +using KeyValueIndexOp = kmeans::KeyValueIndexOp; /** * @brief Select centroids according to a sampling operation @@ -730,7 +729,7 @@ void sampleCentroids(const raft::handle_t& handle, rmm::device_uvector& inRankCp, rmm::device_uvector& workspace) { - detail::sampleCentroids( + kmeans::sample_entroids( handle, X, minClusterDistance, isSampleCentroid, select_op, inRankCp, workspace); } @@ -755,7 +754,7 @@ void computeClusterCost(const raft::handle_t& handle, const raft::device_scalar_view& clusterCost, ReductionOpT reduction_op) { - detail::computeClusterCost( + kmeans::cluster_cost( handle, minClusterDistance, workspace, clusterCost, reduction_op); } @@ -790,7 +789,7 @@ void minClusterDistanceCompute(const raft::handle_t& handle, rmm::device_uvector& L2NormBuf_OR_DistBuf, rmm::device_uvector& workspace) { - detail::minClusterDistanceCompute( + kmeans::min_cluster_distance( handle, params, X, centroids, minClusterDistance, L2NormX, L2NormBuf_OR_DistBuf, workspace); } @@ -829,7 +828,7 @@ void minClusterAndDistanceCompute( rmm::device_uvector& L2NormBuf_OR_DistBuf, rmm::device_uvector& workspace) { - detail::minClusterAndDistanceCompute( + kmeans::min_cluster_and_distance( handle, params, X, centroids, minClusterAndDistance, L2NormX, L2NormBuf_OR_DistBuf, workspace); } @@ -858,7 +857,7 @@ void shuffleAndGather(const raft::handle_t& handle, uint64_t seed, rmm::device_uvector* workspace = nullptr) { - detail::shuffleAndGather(handle, in, out, n_samples_to_gather, seed, workspace); + kmeans::shuffle_and_gather(handle, in, out, n_samples_to_gather, seed, workspace); } /** @@ -889,7 +888,7 @@ void countSamplesInCluster(const raft::handle_t& handle, rmm::device_uvector& workspace, const raft::device_vector_view& sampleCountInCluster) { - detail::countSamplesInCluster( + kmeans::count_samples_in_cluster( handle, params, X, L2NormX, centroids, workspace, sampleCountInCluster); } @@ -918,7 +917,7 @@ void kmeansPlusPlus(const raft::handle_t& handle, const raft::device_matrix_view& centroidsRawData, rmm::device_uvector& workspace) { - detail::kmeansPlusPlus(handle, params, X, centroidsRawData, workspace); + kmeans::init_plus_plus(handle, params, X, centroidsRawData, workspace); } /* @@ -954,7 +953,7 @@ void kmeans_fit_main(const raft::handle_t& handle, const raft::host_scalar_view& n_iter, rmm::device_uvector& workspace) { - detail::kmeans_fit_main( + kmeans::fit_main( handle, params, X, weight, centroidsRawData, inertia, n_iter, workspace); 
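  // The deprecated wrappers in this header now simply forward to the raft::cluster::kmeans::
  // namespace. A minimal migration sketch for callers (hedged: element/index types, extents,
  // and the exact mdspan-based overload arguments are assumed from the declarations above,
  // not verified against the final API; n_samples/n_features are placeholders):
  //
  //   raft::handle_t handle;
  //   raft::cluster::KMeansParams params;
  //   auto X         = raft::make_device_matrix<float, int>(handle, n_samples, n_features);
  //   auto centroids = raft::make_device_matrix<float, int>(handle, params.n_clusters, n_features);
  //   auto inertia   = raft::make_host_scalar<float>(0);
  //   auto n_iter    = raft::make_host_scalar<int>(0);
  //   // preferred over the deprecated kmeans_fit(...) free function:
  //   raft::cluster::kmeans::fit<float, int>(
  //     handle, params, X.view(), std::nullopt, centroids.view(), inertia.view(), n_iter.view());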
} } // namespace raft::cluster diff --git a/cpp/include/raft/cluster/single_linkage.cuh b/cpp/include/raft/cluster/single_linkage.cuh index ca2234f01f..36fc812445 100644 --- a/cpp/include/raft/cluster/single_linkage.cuh +++ b/cpp/include/raft/cluster/single_linkage.cuh @@ -102,8 +102,6 @@ void single_linkage(const raft::handle_t& handle, } } // namespace hierarchy -constexpr int DEFAULT_CONST_C = 15; - /** * Single-linkage clustering, capable of constructing a KNN graph to * scale the algorithm beyond the n^2 memory consumption of implementations @@ -136,7 +134,7 @@ void single_linkage(const raft::handle_t& handle, int c, size_t n_clusters) { - detail::single_linkage( + hierarchy::single_linkage( handle, X, m, n, metric, out, c, n_clusters); } @@ -165,20 +163,10 @@ void single_linkage(const raft::handle_t& handle, raft::device_vector_view labels, raft::distance::DistanceType metric, size_t n_clusters, - std::optional c = std::make_optional(DEFAULT_CONST_C)) + std::optional c = std::make_optional(hierarchy::DEFAULT_CONST_C)) { - linkage_output out_arrs; - out_arrs.children = dendrogram.data_handle(); - out_arrs.labels = labels.data_handle(); - - single_linkage(handle, - X.data_handle(), - static_cast(X.extent(0)), - static_cast(X.extent(1)), - metric, - &out_arrs, - c.has_value() ? c.value() : DEFAULT_CONST_C, - n_clusters); + hierarchy::single_linkage( + handle, X, dendrogram, labels, metric, n_clusters, c); } }; // namespace raft::cluster From 3c8bad164bc336b4c2f4fe4adde459dbda7ae0fa Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Thu, 20 Oct 2022 20:32:26 -0400 Subject: [PATCH 27/35] Fixing style --- .../raft/solver/coordinate_descent.cuh | 97 +- cpp/include/raft/solver/detail/cd.cuh | 60 +- cpp/include/raft/solver/detail/preprocess.cuh | 222 +-- .../raft/solver/detail/qn/objectives/base.cuh | 2 +- .../solver/detail/qn/objectives/hinge.cuh | 2 +- .../solver/detail/qn/objectives/linear.cuh | 2 +- .../solver/detail/qn/objectives/logistic.cuh | 2 +- .../detail/qn/objectives/regularizer.cuh | 2 +- .../solver/detail/qn/objectives/softmax.cuh | 2 +- .../raft/solver/detail/qn/qn_linesearch.cuh | 2 +- .../raft/solver/detail/qn/qn_solvers.cuh | 2 +- cpp/include/raft/solver/detail/sgd.cuh | 38 +- cpp/include/raft/solver/detail/shuffle.h | 2 +- cpp/include/raft/solver/gradient_descent.cuh | 82 +- .../raft/solver/least_angle_regression.cuh | 33 +- cpp/include/raft/solver/quasi_newton.cuh | 359 +++-- cpp/include/raft/solver/simple_mat.cuh | 1085 +++++++------ cpp/include/raft/solver/solver_types.hpp | 110 +- cpp/test/solver/quasi_newton.cu | 1425 ++++++++--------- 19 files changed, 1786 insertions(+), 1743 deletions(-) diff --git a/cpp/include/raft/solver/coordinate_descent.cuh b/cpp/include/raft/solver/coordinate_descent.cuh index b7c0e1bbc0..aa2086d3b3 100644 --- a/cpp/include/raft/solver/coordinate_descent.cuh +++ b/cpp/include/raft/solver/coordinate_descent.cuh @@ -16,52 +16,69 @@ #pragma once -#include #include #include +#include namespace raft::solver::coordinate_descent { - /** - * @brief Minimizes an objective function using the Coordinate Descent solver. 
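 *
 * A minimal usage sketch for the mdspan-based overload below (hedged: the element/index
 * types, the cd_params fields, and the column-major layout tag are assumed from context
 * rather than verified; n_rows/n_cols are placeholders):
 *
 *   #include <raft/core/handle.hpp>
 *   #include <raft/solver/coordinate_descent.cuh>
 *
 *   raft::handle_t handle;
 *   auto A = raft::make_device_matrix<float, int, raft::col_major>(handle, n_rows, n_cols);
 *   auto b = raft::make_device_vector<float, int>(handle, n_rows);
 *   auto x = raft::make_device_vector<float, int>(handle, n_cols);
 *   cd_params<float> params;  // epochs, alpha, l1_ratio, tol, shuffle, normalize, ...
 *   minimize(handle, A.view(), b.view(), std::nullopt, x.view(), std::nullopt, params);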
- * - * Note: Currently only least squares loss is supported w/ optional lasso and elastic-net penalties: - * f(coef) = 1/2 * || b - Ax ||^2 - * + 1/2 * alpha * (1 - l1_ratio) * ||coef||^2 - * + alpha * l1_ratio * ||coef||_1 - * - * @param[in] handle: Reference of raft::handle_t - * @param[in] A: Input matrix in column-major format (size of n_rows, n_cols) - * @param[in] b: Input vector of labels (size of n_rows) - * @param[in] sample_weights: Optional input vector for sample weights (size n_rows) - * @param[out] x: Output vector of learned coefficients (size of n_cols) - * @param[out] intercept: Optional scalar to hold intercept if desired - */ - template - void minimize(const raft::handle_t& handle, - raft::device_matrix_view A, - raft::device_vector_view b, - std::optional sample_weights, - raft::device_vector_view x, - std::optional> intercept, - cd_params ¶ms) { - - RAFT_EXPECTS(A.extent(0) == b.extent(0), "Number of labels must match the number of rows in input matrix"); +/** + * @brief Minimizes an objective function using the Coordinate Descent solver. + * + * Note: Currently only least squares loss is supported w/ optional lasso and elastic-net penalties: + * f(coef) = 1/2 * || b - Ax ||^2 + * + 1/2 * alpha * (1 - l1_ratio) * ||coef||^2 + * + alpha * l1_ratio * ||coef||_1 + * + * @param[in] handle: Reference of raft::handle_t + * @param[in] A: Input matrix in column-major format (size of n_rows, n_cols) + * @param[in] b: Input vector of labels (size of n_rows) + * @param[in] sample_weights: Optional input vector for sample weights (size n_rows) + * @param[out] x: Output vector of learned coefficients (size of n_cols) + * @param[out] intercept: Optional scalar to hold intercept if desired + */ +template +void minimize(const raft::handle_t& handle, + raft::device_matrix_view A, + raft::device_vector_view b, + std::optional < raft::device_vector_view sample_weights, + raft::device_vector_view x, + std::optional> intercept, + cd_params& params) +{ + RAFT_EXPECTS(A.extent(0) == b.extent(0), + "Number of labels must match the number of rows in input matrix"); - if(sample_weights.has_value()) { - RAFT_EXPECTS(A.extent(0) == sample_weights.value().extent(0), "Number of sample weights must match number of rows in input matrix"); - } + if (sample_weights.has_value()) { + RAFT_EXPECTS(A.extent(0) == sample_weights.value().extent(0), + "Number of sample weights must match number of rows in input matrix"); + } - RAFT_EXPECTS(x.extent(0) == A.extent(1), "Objective is linear. The number of coefficients must match the number features in the input matrix"); - RAFT_EXPECTS(lossFunct == loss_funct::SQRD_LOSS, "Only squared loss is supported in the current implementation."); + RAFT_EXPECTS(x.extent(0) == A.extent(1), + "Objective is linear. The number of coefficients must match the number features in " + "the input matrix"); + RAFT_EXPECTS(lossFunct == loss_funct::SQRD_LOSS, + "Only squared loss is supported in the current implementation."); - math_t *intercept_ptr = intercept.has_value() ? intercept.value().data_handle() : nullptr; - math_t *sample_weight_ptr = sample_weights.has_value() ? sample_weights.value().data_handle() : nullptr; + math_t* intercept_ptr = intercept.has_value() ? intercept.value().data_handle() : nullptr; + math_t* sample_weight_ptr = + sample_weights.has_value() ? 
sample_weights.value().data_handle() : nullptr; - detail::cdFit(handle, A.data_handle(), A.extent(0), A.extent(1), - b.data_handle(), x.data_handle(), intercept_ptr, - intercept.has_value(), params.normalize, params.epochs, - params.loss, params.alpha, params.l1_ratio, params.shuffle, - params.tol, sample_weight_ptr); - } -} \ No newline at end of file + detail::cdFit(handle, + A.data_handle(), + A.extent(0), + A.extent(1), + b.data_handle(), + x.data_handle(), + intercept_ptr, + intercept.has_value(), + params.normalize, + params.epochs, + params.loss, + params.alpha, + params.l1_ratio, + params.shuffle, + params.tol, + sample_weight_ptr); +} +} // namespace raft::solver::coordinate_descent \ No newline at end of file diff --git a/cpp/include/raft/solver/detail/cd.cuh b/cpp/include/raft/solver/detail/cd.cuh index db5b20e90b..ad6092c929 100644 --- a/cpp/include/raft/solver/detail/cd.cuh +++ b/cpp/include/raft/solver/detail/cd.cuh @@ -16,15 +16,7 @@ #pragma once -#include -#include -#include -#include -#include -#include -#include #include -#include #include #include #include @@ -38,7 +30,15 @@ #include #include #include +#include +#include +#include +#include +#include +#include #include +#include +#include namespace raft::solver::detail { @@ -175,17 +175,17 @@ void cdFit(const raft::handle_t& handle, if (normalize) { norm2_input.resize(n_cols, stream); } preProcessData(handle, - input, - n_rows, - n_cols, - labels, - intercept, - mu_input.data(), - mu_labels.data(), - norm2_input.data(), - fit_intercept, - normalize, - sample_weight); + input, + n_rows, + n_cols, + labels, + intercept, + mu_input.data(), + mu_labels.data(), + norm2_input.data(), + fit_intercept, + normalize, + sample_weight); } if (sample_weight != nullptr) { raft::linalg::sqrt(sample_weight, sample_weight, n_rows, stream); @@ -296,17 +296,17 @@ void cdFit(const raft::handle_t& handle, if (fit_intercept) { postProcessData(handle, - input, - n_rows, - n_cols, - labels, - coef, - intercept, - mu_input.data(), - mu_labels.data(), - norm2_input.data(), - fit_intercept, - normalize); + input, + n_rows, + n_cols, + labels, + coef, + intercept, + mu_input.data(), + mu_labels.data(), + norm2_input.data(), + fit_intercept, + normalize); } else { *intercept = math_t(0); diff --git a/cpp/include/raft/solver/detail/preprocess.cuh b/cpp/include/raft/solver/detail/preprocess.cuh index 5832f6d1d9..3fd863df42 100644 --- a/cpp/include/raft/solver/detail/preprocess.cuh +++ b/cpp/include/raft/solver/detail/preprocess.cuh @@ -49,123 +49,123 @@ namespace raft::solver::detail { * @param [in] normalize whether to normalize the data * @param [in] stream */ - template - void preProcessData(const raft::handle_t& handle, - math_t* input, - int n_rows, - int n_cols, - math_t* labels, - math_t* intercept, - math_t* mu_input, - math_t* mu_labels, - math_t* norm2_input, - bool fit_intercept, - bool normalize, - math_t* sample_weight = nullptr) - { - cudaStream_t stream = handle.get_stream(); - raft::common::nvtx::range fun_scope("ML::GLM::preProcessData-%d-%d", n_rows, n_cols); - ASSERT(n_cols > 0, "Parameter n_cols: number of columns cannot be less than one"); - ASSERT(n_rows > 1, "Parameter n_rows: number of rows cannot be less than two"); +template +void preProcessData(const raft::handle_t& handle, + math_t* input, + int n_rows, + int n_cols, + math_t* labels, + math_t* intercept, + math_t* mu_input, + math_t* mu_labels, + math_t* norm2_input, + bool fit_intercept, + bool normalize, + math_t* sample_weight = nullptr) +{ + cudaStream_t stream = 
handle.get_stream(); + raft::common::nvtx::range fun_scope("ML::GLM::preProcessData-%d-%d", n_rows, n_cols); + ASSERT(n_cols > 0, "Parameter n_cols: number of columns cannot be less than one"); + ASSERT(n_rows > 1, "Parameter n_rows: number of rows cannot be less than two"); - if (fit_intercept) { - if (normalize && sample_weight == nullptr) { - raft::stats::meanvar(mu_input, norm2_input, input, n_cols, n_rows, false, false, stream); - raft::linalg::unaryOp( - norm2_input, - norm2_input, - n_cols, - [] __device__(math_t v) { return raft::mySqrt(v); }, - stream); - raft::matrix::linewiseOp( - input, - input, - n_rows, - n_cols, - false, - [] __device__(math_t x, math_t m, math_t s) { return s > 1e-10 ? (x - m) / s : 0; }, - stream, - mu_input, - norm2_input); - } else { - if (sample_weight != nullptr) { - raft::stats::weightedMean( - mu_input, input, sample_weight, n_cols, n_rows, false, false, stream); - } else { - raft::stats::mean(mu_input, input, n_cols, n_rows, false, false, stream); - } - raft::stats::meanCenter(input, input, mu_input, n_cols, n_rows, false, true, stream); - if (normalize) { - raft::linalg::colNorm(norm2_input, - input, - n_cols, - n_rows, - raft::linalg::L2Norm, - false, - stream, - [] __device__(math_t v) { return raft::mySqrt(v); }); - raft::matrix::matrixVectorBinaryDivSkipZero( - input, norm2_input, n_rows, n_cols, false, true, stream, true); - } - } + if (fit_intercept) { + if (normalize && sample_weight == nullptr) { + raft::stats::meanvar(mu_input, norm2_input, input, n_cols, n_rows, false, false, stream); + raft::linalg::unaryOp( + norm2_input, + norm2_input, + n_cols, + [] __device__(math_t v) { return raft::mySqrt(v); }, + stream); + raft::matrix::linewiseOp( + input, + input, + n_rows, + n_cols, + false, + [] __device__(math_t x, math_t m, math_t s) { return s > 1e-10 ? 
(x - m) / s : 0; }, + stream, + mu_input, + norm2_input); + } else { + if (sample_weight != nullptr) { + raft::stats::weightedMean( + mu_input, input, sample_weight, n_cols, n_rows, false, false, stream); + } else { + raft::stats::mean(mu_input, input, n_cols, n_rows, false, false, stream); + } + raft::stats::meanCenter(input, input, mu_input, n_cols, n_rows, false, true, stream); + if (normalize) { + raft::linalg::colNorm(norm2_input, + input, + n_cols, + n_rows, + raft::linalg::L2Norm, + false, + stream, + [] __device__(math_t v) { return raft::mySqrt(v); }); + raft::matrix::matrixVectorBinaryDivSkipZero( + input, norm2_input, n_rows, n_cols, false, true, stream, true); + } + } - if (sample_weight != nullptr) { - raft::stats::weightedMean(mu_labels, labels, sample_weight, 1, n_rows, true, false, stream); - } else { - raft::stats::mean(mu_labels, labels, 1, n_rows, false, false, stream); - } - raft::stats::meanCenter(labels, labels, mu_labels, 1, n_rows, false, true, stream); - } - } + if (sample_weight != nullptr) { + raft::stats::weightedMean(mu_labels, labels, sample_weight, 1, n_rows, true, false, stream); + } else { + raft::stats::mean(mu_labels, labels, 1, n_rows, false, false, stream); + } + raft::stats::meanCenter(labels, labels, mu_labels, 1, n_rows, false, true, stream); + } +} - template - void postProcessData(const raft::handle_t& handle, - math_t* input, - int n_rows, - int n_cols, - math_t* labels, - math_t* coef, - math_t* intercept, - math_t* mu_input, - math_t* mu_labels, - math_t* norm2_input, - bool fit_intercept, - bool normalize) - { - cudaStream_t stream = handle.get_stream(); - raft::common::nvtx::range fun_scope("ML::GLM::postProcessData-%d-%d", n_rows, n_cols); - ASSERT(n_cols > 0, "Parameter n_cols: number of columns cannot be less than one"); - ASSERT(n_rows > 1, "Parameter n_rows: number of rows cannot be less than two"); +template +void postProcessData(const raft::handle_t& handle, + math_t* input, + int n_rows, + int n_cols, + math_t* labels, + math_t* coef, + math_t* intercept, + math_t* mu_input, + math_t* mu_labels, + math_t* norm2_input, + bool fit_intercept, + bool normalize) +{ + cudaStream_t stream = handle.get_stream(); + raft::common::nvtx::range fun_scope("ML::GLM::postProcessData-%d-%d", n_rows, n_cols); + ASSERT(n_cols > 0, "Parameter n_cols: number of columns cannot be less than one"); + ASSERT(n_rows > 1, "Parameter n_rows: number of rows cannot be less than two"); - cublasHandle_t cublas_handle = handle.get_cublas_handle(); - rmm::device_scalar d_intercept(stream); + cublasHandle_t cublas_handle = handle.get_cublas_handle(); + rmm::device_scalar d_intercept(stream); - if (normalize) { - raft::matrix::matrixVectorBinaryDivSkipZero( - coef, norm2_input, 1, n_cols, false, true, stream, true); - } + if (normalize) { + raft::matrix::matrixVectorBinaryDivSkipZero( + coef, norm2_input, 1, n_cols, false, true, stream, true); + } - raft::linalg::gemm( - handle, mu_input, 1, n_cols, coef, d_intercept.data(), 1, 1, CUBLAS_OP_N, CUBLAS_OP_N, stream); + raft::linalg::gemm( + handle, mu_input, 1, n_cols, coef, d_intercept.data(), 1, 1, CUBLAS_OP_N, CUBLAS_OP_N, stream); - raft::linalg::subtract(d_intercept.data(), mu_labels, d_intercept.data(), 1, stream); - *intercept = d_intercept.value(stream); + raft::linalg::subtract(d_intercept.data(), mu_labels, d_intercept.data(), 1, stream); + *intercept = d_intercept.value(stream); - if (normalize) { - raft::matrix::linewiseOp( - input, - input, - n_rows, - n_cols, - false, - [] __device__(math_t x, math_t m, 
math_t s) { return s * x + m; }, - stream, - mu_input, - norm2_input); - } else { - raft::stats::meanAdd(input, input, mu_input, n_cols, n_rows, false, true, stream); - } - raft::stats::meanAdd(labels, labels, mu_labels, 1, n_rows, false, true, stream); - } + if (normalize) { + raft::matrix::linewiseOp( + input, + input, + n_rows, + n_cols, + false, + [] __device__(math_t x, math_t m, math_t s) { return s * x + m; }, + stream, + mu_input, + norm2_input); + } else { + raft::stats::meanAdd(input, input, mu_input, n_cols, n_rows, false, true, stream); + } + raft::stats::meanAdd(labels, labels, mu_labels, 1, n_rows, false, true, stream); +} - }; // end namespace raft::solver::detail \ No newline at end of file +}; // end namespace raft::solver::detail \ No newline at end of file diff --git a/cpp/include/raft/solver/detail/qn/objectives/base.cuh b/cpp/include/raft/solver/detail/qn/objectives/base.cuh index b3d60637a1..288bfac1c8 100644 --- a/cpp/include/raft/solver/detail/qn/objectives/base.cuh +++ b/cpp/include/raft/solver/detail/qn/objectives/base.cuh @@ -16,11 +16,11 @@ #pragma once -#include #include #include #include #include +#include #include #include #include diff --git a/cpp/include/raft/solver/detail/qn/objectives/hinge.cuh b/cpp/include/raft/solver/detail/qn/objectives/hinge.cuh index dcab7543fc..d90c30dc1c 100644 --- a/cpp/include/raft/solver/detail/qn/objectives/hinge.cuh +++ b/cpp/include/raft/solver/detail/qn/objectives/hinge.cuh @@ -16,9 +16,9 @@ #pragma once -#include #include "base.cuh" #include +#include #include namespace raft::solver::quasi_newton::detail::objectives { diff --git a/cpp/include/raft/solver/detail/qn/objectives/linear.cuh b/cpp/include/raft/solver/detail/qn/objectives/linear.cuh index af6094e00f..dfaf83abf0 100644 --- a/cpp/include/raft/solver/detail/qn/objectives/linear.cuh +++ b/cpp/include/raft/solver/detail/qn/objectives/linear.cuh @@ -16,9 +16,9 @@ #pragma once -#include #include "base.cuh" #include +#include #include namespace raft::solver::quasi_newton::detail::objectives { diff --git a/cpp/include/raft/solver/detail/qn/objectives/logistic.cuh b/cpp/include/raft/solver/detail/qn/objectives/logistic.cuh index acb7c1ac55..ed52069bc6 100644 --- a/cpp/include/raft/solver/detail/qn/objectives/logistic.cuh +++ b/cpp/include/raft/solver/detail/qn/objectives/logistic.cuh @@ -16,9 +16,9 @@ #pragma once -#include #include "base.cuh" #include +#include #include namespace raft::solver::quasi_newton::detail::objectives { diff --git a/cpp/include/raft/solver/detail/qn/objectives/regularizer.cuh b/cpp/include/raft/solver/detail/qn/objectives/regularizer.cuh index 7bb509a934..68c79ab15d 100644 --- a/cpp/include/raft/solver/detail/qn/objectives/regularizer.cuh +++ b/cpp/include/raft/solver/detail/qn/objectives/regularizer.cuh @@ -16,10 +16,10 @@ #pragma once -#include #include "base.cuh" #include #include +#include #include #include #include diff --git a/cpp/include/raft/solver/detail/qn/objectives/softmax.cuh b/cpp/include/raft/solver/detail/qn/objectives/softmax.cuh index 2e53881c2a..74b78e1158 100644 --- a/cpp/include/raft/solver/detail/qn/objectives/softmax.cuh +++ b/cpp/include/raft/solver/detail/qn/objectives/softmax.cuh @@ -16,9 +16,9 @@ #pragma once -#include #include "base.cuh" #include +#include #include namespace raft::solver::quasi_newton::detail::objectives { diff --git a/cpp/include/raft/solver/detail/qn/qn_linesearch.cuh b/cpp/include/raft/solver/detail/qn/qn_linesearch.cuh index 26445fbed9..28e37da2fb 100644 --- 
a/cpp/include/raft/solver/detail/qn/qn_linesearch.cuh +++ b/cpp/include/raft/solver/detail/qn/qn_linesearch.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2021, NVIDIA CORPORATION. + * Copyright (c) 2018-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/include/raft/solver/detail/qn/qn_solvers.cuh b/cpp/include/raft/solver/detail/qn/qn_solvers.cuh index 833c1170ef..a9f26096cd 100644 --- a/cpp/include/raft/solver/detail/qn/qn_solvers.cuh +++ b/cpp/include/raft/solver/detail/qn/qn_solvers.cuh @@ -42,8 +42,8 @@ #include "qn_linesearch.cuh" #include "qn_util.cuh" -#include #include +#include #include #include diff --git a/cpp/include/raft/solver/detail/sgd.cuh b/cpp/include/raft/solver/detail/sgd.cuh index e7a41f7b27..8a5372dc33 100644 --- a/cpp/include/raft/solver/detail/sgd.cuh +++ b/cpp/include/raft/solver/detail/sgd.cuh @@ -16,15 +16,6 @@ #pragma once -#include -#include -#include -#include -#include -#include -#include -#include -#include #include #include #include @@ -33,8 +24,17 @@ #include #include #include +#include +#include +#include +#include +#include +#include +#include #include #include +#include +#include #include namespace raft::solver::detail { @@ -124,16 +124,16 @@ void sgdFit(const raft::handle_t& handle, mu_labels.resize(1, stream); preProcessData(handle, - input, - n_rows, - n_cols, - labels, - intercept, - mu_input.data(), - mu_labels.data(), - norm2_input.data(), - fit_intercept, - false); + input, + n_rows, + n_cols, + labels, + intercept, + mu_input.data(), + mu_labels.data(), + norm2_input.data(), + fit_intercept, + false); } rmm::device_uvector grads(n_cols, stream); diff --git a/cpp/include/raft/solver/detail/shuffle.h b/cpp/include/raft/solver/detail/shuffle.h index 4c131163c3..1a815822b4 100644 --- a/cpp/include/raft/solver/detail/shuffle.h +++ b/cpp/include/raft/solver/detail/shuffle.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2021, NVIDIA CORPORATION. + * Copyright (c) 2018-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/include/raft/solver/gradient_descent.cuh b/cpp/include/raft/solver/gradient_descent.cuh index 9ba8b16d3a..5188d81ae2 100644 --- a/cpp/include/raft/solver/gradient_descent.cuh +++ b/cpp/include/raft/solver/gradient_descent.cuh @@ -16,42 +16,60 @@ #pragma once - -#include #include -#include +#include #include +#include namespace raft::solver::gradient_descent { - /** - * @brief Minimizes an objective function using the Gradient Descent solver and optional - * lasso or elastic-net penalties. - * - * @param[in] handle: Reference of raft::handle_t - * @param[in] A: Input matrix in column-major format (size of n_rows, n_cols) - * @param[in] b: Input vector of labels (size of n_rows) - * @param[out] x: Output vector of coefficients (size of n_cols) - * @param[out] intercept: Optional scalar if fitting the intercept - * @param[in] params: solver hyper-parameters - */ - template - void minimize(const raft::handle_t& handle, - raft::device_matrix_view A, - raft::device_vector_view b, - raft::device_vector_view x, - std::optional intercept, - sgd_params ¶ms) { - - RAFT_EXPECTS(A.extent(0) == b.extent(0), "Number of labels must match the number of rows in input matrix"); - RAFT_EXPECTS(x.extent(0) == A.extent(1), "Objective is linear. 
The number of coefficients must match the number features in the input matrix"); - - auto intercept_ptr = intercept.has_value() ? intercept.data_handle() ? nullptr; - detail::sgdFit(handle, A.data_handle(), A.extent(0), A.extent(1), b.data_handle(), x.data_handle(), - intercept_ptr, intercept.has_value(), params.batch_size, params.epochs, params.lr_type, - params.eta0, params.power_t, params.loss, params.penalty, params.alpha, params.l1_ratio, - params.shuffle, params.tol, params.n_iter_no_change, handle.get_stream()); +/** + * @brief Minimizes an objective function using the Gradient Descent solver and optional + * lasso or elastic-net penalties. + * + * @param[in] handle: Reference of raft::handle_t + * @param[in] A: Input matrix in column-major format (size of n_rows, n_cols) + * @param[in] b: Input vector of labels (size of n_rows) + * @param[out] x: Output vector of coefficients (size of n_cols) + * @param[out] intercept: Optional scalar if fitting the intercept + * @param[in] params: solver hyper-parameters + */ +template +void minimize(const raft::handle_t& handle, + raft::device_matrix_view A, + raft::device_vector_view b, + raft::device_vector_view x, + std::optional < raft::device_scalar_view intercept, + sgd_params& params) +{ + RAFT_EXPECTS(A.extent(0) == b.extent(0), + "Number of labels must match the number of rows in input matrix"); + RAFT_EXPECTS(x.extent(0) == A.extent(1), + "Objective is linear. The number of coefficients must match the number features in " + "the input matrix"); - } + auto intercept_ptr = intercept.has_value() ? intercept.data_handle() ? nullptr; + detail::sgdFit(handle, + A.data_handle(), + A.extent(0), + A.extent(1), + b.data_handle(), + x.data_handle(), + intercept_ptr, + intercept.has_value(), + params.batch_size, + params.epochs, + params.lr_type, + params.eta0, + params.power_t, + params.loss, + params.penalty, + params.alpha, + params.l1_ratio, + params.shuffle, + params.tol, + params.n_iter_no_change, + handle.get_stream()); +} -} \ No newline at end of file +} // namespace raft::solver::gradient_descent \ No newline at end of file diff --git a/cpp/include/raft/solver/least_angle_regression.cuh b/cpp/include/raft/solver/least_angle_regression.cuh index c449d51e20..6484e40d7d 100644 --- a/cpp/include/raft/solver/least_angle_regression.cuh +++ b/cpp/include/raft/solver/least_angle_regression.cuh @@ -16,11 +16,11 @@ #pragma once -#include #include +#include #include -#include #include +#include namespace raft::solver::least_angle_regression { @@ -79,18 +79,17 @@ namespace raft::solver::least_angle_regression { */ template void minimize(const raft::handle_t& handle, - raft::device_matrix_view A, - raft::device_vector_view b, - std::optional> Gram, - raft::device_vector_view x, - raft::device_vector_view active_idx, - raft::device_vector_view alphas, - raft::host_scalar_view n_active, - std::optional> coef_path, - lars_params ¶ms, - idx_t ld_X = 0, - idx_t ld_G = 0) { - - - } -} \ No newline at end of file + raft::device_matrix_view A, + raft::device_vector_view b, + std::optional> Gram, + raft::device_vector_view x, + raft::device_vector_view active_idx, + raft::device_vector_view alphas, + raft::host_scalar_view n_active, + std::optional> coef_path, + lars_params& params, + idx_t ld_X = 0, + idx_t ld_G = 0) +{ +} +} // namespace raft::solver::least_angle_regression \ No newline at end of file diff --git a/cpp/include/raft/solver/quasi_newton.cuh b/cpp/include/raft/solver/quasi_newton.cuh index d16839a506..dee73a010b 100644 --- 
a/cpp/include/raft/solver/quasi_newton.cuh +++ b/cpp/include/raft/solver/quasi_newton.cuh @@ -23,8 +23,8 @@ #include #include -#include #include +#include #include namespace raft::solver::quasi_newton { @@ -45,16 +45,17 @@ struct AbsLoss : detail::objectives::AbsLoss { } } - /** - * Squared loss function specification - * @tparam T - */ - template - struct SquaredLoss : detail::objectives::SquaredLoss { - SquaredLoss(const raft::handle_t &handle, int D, bool has_bias) - : detail::objectives::SquaredLoss(handle, D, 1, has_bias), lz{}, dlz{} {} - } - +/** + * Squared loss function specification + * @tparam T + */ +template +struct SquaredLoss : detail::objectives::SquaredLoss { + SquaredLoss(const raft::handle_t& handle, int D, bool has_bias) + : detail::objectives::SquaredLoss(handle, D, 1, has_bias), lz{}, dlz{} + { + } +} /** * Standard hinge loss function specification @@ -92,190 +93,196 @@ struct SqHingeLoss : detail::objectives::SqHingeLoss { } } + /** + * Epsilon insensitive (regression) hinge loss function specification + * @tparam T + */ + template + struct EpsInsHingeLoss : detail::objectives::EpsInsHingeLoss > { + EpsInsHingeLoss(const raft::handle_t& handle, int D, bool has_bias, T sensitivity) + : detail::objectives::EpsInsHingeLoss(handle, D, 1, has_bias), lz{sensitivity}, dlz{sensitivity} + { + } +} + /** - * Epsilon insensitive (regression) hinge loss function specification + * Squared Epsilon insensitive (regression) hinge loss function specification * @tparam T */ template -struct EpsInsHingeLoss : detail::objectives::EpsInsHingeLoss> { - EpsInsHingeLoss(const raft::handle_t &handle, int D, bool has_bias, T sensitivity) - : detail::objectives::EpsInsHingeLoss(handle, D, 1, has_bias), lz{sensitivity}, dlz{sensitivity} { - } -} +struct SqEpsInsHingeLoss : detail::objectives::SqEpsInsHingeLoss { + SqEpsInsHingeLoss(const raft::handle_t& handle, int D, bool has_bias, T sensitivity) + : detail::objectives::SqEpsInsHingeLoss(handle, D, 1, has_bias), + lz{sensitivity}, + dlz{sensitivity} + { + } +}; /** - * Squared Epsilon insensitive (regression) hinge loss function specification + * Tikhonov (l2) penalty function * @tparam T */ - template - struct SqEpsInsHingeLoss : detail::objectives::SqEpsInsHingeLoss { - SqEpsInsHingeLoss(const raft::handle_t &handle, int D, bool has_bias, T sensitivity) - : detail::objectives::SqEpsInsHingeLoss(handle, D, 1, has_bias), lz{sensitivity}, dlz{sensitivity} { - } - }; - - /** - * Tikhonov (l2) penalty function - * @tparam T - */ template struct Tikhonov : detail::objectives::Tikhonov { - Tikhonov(T l2) : detail::objectives::Tikhonov(l2) {} + Tikhonov(T l2) : detail::objectives::Tikhonov(l2) {} - Tikhonov(const Tikhonov &other) : detail::objectives::Tikhonov(other.l2_penalty) {} + Tikhonov(const Tikhonov& other) : detail::objectives::Tikhonov(other.l2_penalty) {} }; +/** + * Loss function wrapper that add a penalty to another loss function + * + * Example: + * + * raft::handle_t handle; + * AbsLoss abs_loss(handle, 5, true); + * Tikhonov l2_reg(0.3); + * RegularizedQN(&abs_loss, ®); + * + * @tparam T + * @tparam Loss + * @tparam Reg + */ +template +class RegularizedQN : public detail::objectives::RegularizedQN { + RegularizedQN(Loss* loss, Reg* reg) : detail::objectives::RegularizedQN(loss, reg) {} +}; - - /** - * Loss function wrapper that add a penalty to another loss function - * - * Example: - * - * raft::handle_t handle; - * AbsLoss abs_loss(handle, 5, true); - * Tikhonov l2_reg(0.3); - * RegularizedQN(&abs_loss, ®); - * - * @tparam T - * 
@tparam Loss - * @tparam Reg - */ - template - class RegularizedQN : public detail::objectives::RegularizedQN { - RegularizedQN(Loss* loss, Reg* reg) : detail::objectives::RegularizedQN(loss, reg) {} - }; - - /** - * Base loss function that constrains the solution to a linear system - * @tparam T - * @tparam Loss - */ - template - struct QNLinearBase : detail::objectives::QNLinearBase { - QNLinearBase(const raft::handle_t& handle, int D, int C, bool fit_intercept) - : detail::objectives::QNLinearBase(C, D, fit_intercept) - { - } - } - - /** - * Softmax loss function specification - * @tparam T - */ - template - struct Softmax : detail::objectives::Softmax { - Softmax(const raft::handle_t &handle, int D, int C, bool has_bias) - : detail::objectives::Softmax(handle, D, C, has_bias) { - } +/** + * Base loss function that constrains the solution to a linear system + * @tparam T + * @tparam Loss + */ +template +struct QNLinearBase : detail::objectives::QNLinearBase { + QNLinearBase(const raft::handle_t& handle, int D, int C, bool fit_intercept) + : detail::objectives::QNLinearBase(C, D, fit_intercept) + { } +} - /** - * Constructs a end-to-end quasi-newton objective function to solve the system - * AX = b (where each row in X contains the coefficients for each target) - * - * Example: - * - * @tparam T - * @tparam QuasiNewtonObjective - */ - template - struct ObjectiveWithData : detail::objectives::QNWithData { - - ObjectiveWithData(QuasiNewtonObjective *obj, - const SimpleMat &A, - const SimpleVec &b, - SimpleDenseMat &X) - : detail::objectives::QNWithData(obj->C, obj->D, obj->fit_intercept) { - } +/** + * Softmax loss function specification + * @tparam T + */ +template +struct Softmax : detail::objectives::Softmax { + Softmax(const raft::handle_t& handle, int D, int C, bool has_bias) + : detail::objectives::Softmax(handle, D, C, has_bias) + { } +} - /** - * @brief Minimize the given `raft::solver::quasi_newton::ObjectiveWithData` using - * the Limited-Memory Broyden-Fletcher-Goldfarb-Shanno algorithm. This algorithm - * estimates the inverse of the Hessian matrix, minimizing the memory footprint from - * the original BFGS algorithm by maintaining only a subset of the update history. 
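 *
 * A calling sketch (hedged: the objective construction and the LBFGSParam defaults are
 * assumed from the declarations above; A, b, X, w_dev and the sizes are placeholders,
 * not a verified end-to-end example):
 *
 *   raft::handle_t handle;
 *   SquaredLoss<float> loss(handle, n_features, /*has_bias=*/true);
 *   ObjectiveWithData<float, SquaredLoss<float>> f(&loss, A, b, X);
 *   LBFGSParam<float> opt_param;                 // max iterations, history size m, tolerances
 *   SimpleVec<float> w(w_dev, n_features + 1);   // coefficients (+ bias), holds the result
 *   float fx    = 0;
 *   int n_iters = 0;
 *   OPT_RETCODE code = lbfgs_minimize(handle, opt_param, f, w, fx, &n_iters);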
- * - * @tparam T - * @tparam Function - * @param param - * @param f - * @param x - * @param fx - * @param k - * @param workspace - * @param stream - * @param verbosity - * @return - */ - template - OPT_RETCODE lbfgs_minimize(raft::handle_t &handle, - const LBFGSParam& param, - Function& f, // function to minimize - SimpleVec& x, // initial point, holds result - T& fx, // output function value - int* k) { // output iterations - rmm::device_uvector tmp(detail::lbfgs_workspace_size(param, x.len), handle.get_stream()); - SimpleVec workspace(tmp.data(), tmp.size()); - return detail::min_lbfgs(param, f, x, fx, k, workspace, handle.get_stream(), 0); +/** + * Constructs a end-to-end quasi-newton objective function to solve the system + * AX = b (where each row in X contains the coefficients for each target) + * + * Example: + * + * @tparam T + * @tparam QuasiNewtonObjective + */ +template +struct ObjectiveWithData : detail::objectives::QNWithData { + ObjectiveWithData(QuasiNewtonObjective* obj, + const SimpleMat& A, + const SimpleVec& b, + SimpleDenseMat& X) + : detail::objectives::QNWithData(obj->C, obj->D, obj->fit_intercept) + { } +} - /** - * @brief Minimize the given `ObjectiveWithData` using the Orthant-wise - * Limited-Memory Quasi-Newton algorithm, an L-BFGS variant for fitting - * models with lasso (l1) penalties, enabling it to exploit the sparsity - * of the models. - * - * @tparam T - * @tparam Function - * @param param - * @param f - * @param l1_penalty - * @param pg_limit - * @param x - * @param fx - * @param k - * @return - */ - template - OPT_RETCODE owl_minimize(raft::handle_t &handle, - const LBFGSParam& param, - Function& f, - const T l1_penalty, - const int pg_limit, - SimpleVec& x, - T& fx, - int* k) { - rmm::device_uvector tmp(detail::owlqn_workspace_size(opt_param, x.len), stream); - SimpleVec workspace(tmp.data(), tmp.size()); - return detail::min_owlqn(param, f, l1_penalty, pg_limit, x, fx, k, workspace, handle.get_stream(), 0); - } - +/** + * @brief Minimize the given `raft::solver::quasi_newton::ObjectiveWithData` using + * the Limited-Memory Broyden-Fletcher-Goldfarb-Shanno algorithm. This algorithm + * estimates the inverse of the Hessian matrix, minimizing the memory footprint from + * the original BFGS algorithm by maintaining only a subset of the update history. + * + * @tparam T + * @tparam Function + * @param param + * @param f + * @param x + * @param fx + * @param k + * @param workspace + * @param stream + * @param verbosity + * @return + */ +template +OPT_RETCODE lbfgs_minimize(raft::handle_t& handle, + const LBFGSParam& param, + Function& f, // function to minimize + SimpleVec& x, // initial point, holds result + T& fx, // output function value + int* k) +{ // output iterations + rmm::device_uvector tmp(detail::lbfgs_workspace_size(param, x.len), handle.get_stream()); + SimpleVec workspace(tmp.data(), tmp.size()); + return detail::min_lbfgs(param, f, x, fx, k, workspace, handle.get_stream(), 0); +} - /** - * @brief Simple wrapper function that chooses the quasi-newton solver to use - * based on the presence of the L1 penalty term. 
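 *
 * Dispatch sketch (hedged: this mirrors the delegation to detail::qn_minimize rather than
 * restating its internals; pg_limit is a caller-chosen bound on the projected-gradient set
 * used by OWL-QN):
 *
 *   if (l1 == 0) {
 *     // smooth objective: plain L-BFGS on the loss (plus any l2 term)
 *     lbfgs_minimize(handle, opt_param, loss, x, *fx, num_iters);
 *   } else {
 *     // non-smooth |w|_1 term present: OWL-QN handles it via pseudo-gradients
 *     owl_minimize(handle, opt_param, loss, l1, pg_limit, x, *fx, num_iters);
 *   }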
- * @tparam T - * @tparam LossFunction - * @param handle - * @param x - * @param fx - * @param num_iters - * @param loss - * @param l1 - * @param opt_param - * @return - */ - template - inline int minimize(const raft::handle_t& handle, +/** + * @brief Minimize the given `ObjectiveWithData` using the Orthant-wise + * Limited-Memory Quasi-Newton algorithm, an L-BFGS variant for fitting + * models with lasso (l1) penalties, enabling it to exploit the sparsity + * of the models. + * + * @tparam T + * @tparam Function + * @param param + * @param f + * @param l1_penalty + * @param pg_limit + * @param x + * @param fx + * @param k + * @return + */ +template +OPT_RETCODE owl_minimize(raft::handle_t& handle, + const LBFGSParam& param, + Function& f, + const T l1_penalty, + const int pg_limit, SimpleVec& x, - T* fx, - int* num_iters, - LossFunction& loss, - const T l1, - const LBFGSParam& opt_param, - cudaStream_t stream, - const int verbosity = 0) { - return detail::qn_minimize(handle, x, fx, num_iters, loss, l1, opt_param, handle.get_stream(), 0); - } -} \ No newline at end of file + T& fx, + int* k) +{ + rmm::device_uvector tmp(detail::owlqn_workspace_size(opt_param, x.len), stream); + SimpleVec workspace(tmp.data(), tmp.size()); + return detail::min_owlqn( + param, f, l1_penalty, pg_limit, x, fx, k, workspace, handle.get_stream(), 0); +} + +/** + * @brief Simple wrapper function that chooses the quasi-newton solver to use + * based on the presence of the L1 penalty term. + * @tparam T + * @tparam LossFunction + * @param handle + * @param x + * @param fx + * @param num_iters + * @param loss + * @param l1 + * @param opt_param + * @return + */ +template +inline int minimize(const raft::handle_t& handle, + SimpleVec& x, + T* fx, + int* num_iters, + LossFunction& loss, + const T l1, + const LBFGSParam& opt_param, + cudaStream_t stream, + const int verbosity = 0) +{ + return detail::qn_minimize(handle, x, fx, num_iters, loss, l1, opt_param, handle.get_stream(), 0); +} +} // namespace raft::solver::quasi_newton \ No newline at end of file diff --git a/cpp/include/raft/solver/simple_mat.cuh b/cpp/include/raft/solver/simple_mat.cuh index 1a80dbc78e..5d20e171dd 100644 --- a/cpp/include/raft/solver/simple_mat.cuh +++ b/cpp/include/raft/solver/simple_mat.cuh @@ -36,411 +36,411 @@ namespace raft::solver { - template - struct SimpleMat { - int m, n; - - SimpleMat(int m, int n) : m(m), n(n) {} - - void operator=(const SimpleMat& other) = delete; - - virtual void print(std::ostream& oss) const = 0; - - /** - * GEMM assigning to C where `this` refers to B. 
- * - * ``` - * C <- alpha * A^transA * (*this)^transB + beta * C - * ``` - */ - virtual void gemmb(const raft::handle_t& handle, - const T alpha, - const SimpleDenseMat& A, - const bool transA, - const bool transB, - const T beta, - SimpleDenseMat& C, - cudaStream_t stream) const = 0; - }; - - enum STORAGE_ORDER { COL_MAJOR = 0, ROW_MAJOR = 1 }; - - template - struct SimpleDenseMat : SimpleMat { - typedef SimpleMat Super; - int len; - T* data; - - STORAGE_ORDER ord; // storage order: runtime param for compile time sake - - SimpleDenseMat(STORAGE_ORDER order = COL_MAJOR) : Super(0, 0), data(nullptr), len(0), ord(order) - { - } - - SimpleDenseMat(T* data, int m, int n, STORAGE_ORDER order = COL_MAJOR) - : Super(m, n), data(data), len(m * n), ord(order) - { - } - - void reset(T* data_, int m_, int n_) - { - this->m = m_; - this->n = n_; - data = data_; - len = m_ * n_; - } - - // Implemented GEMM as a static method here to improve readability - inline static void gemm(const raft::handle_t& handle, - const T alpha, - const SimpleDenseMat& A, - const bool transA, - const SimpleDenseMat& B, - const bool transB, - const T beta, - SimpleDenseMat& C, - cudaStream_t stream) - { - int kA = A.n; - int kB = B.m; - - if (transA) { - ASSERT(A.n == C.m, "GEMM invalid dims: m"); - kA = A.m; - } else { - ASSERT(A.m == C.m, "GEMM invalid dims: m"); - } - - if (transB) { - ASSERT(B.m == C.n, "GEMM invalid dims: n"); - kB = B.n; - } else { - ASSERT(B.n == C.n, "GEMM invalid dims: n"); - } - ASSERT(kA == kB, "GEMM invalid dims: k"); - - if (A.ord == COL_MAJOR && B.ord == COL_MAJOR && C.ord == COL_MAJOR) { - // #TODO: Call from public API when ready - raft::linalg::detail::cublasgemm(handle.get_cublas_handle(), // handle - transA ? CUBLAS_OP_T : CUBLAS_OP_N, // transA - transB ? CUBLAS_OP_T : CUBLAS_OP_N, // transB - C.m, - C.n, - kA, // dimensions m,n,k - &alpha, - A.data, - A.m, // lda - B.data, - B.m, // ldb - &beta, - C.data, - C.m, // ldc, - stream); - return; - } - if (A.ord == ROW_MAJOR) { - const SimpleDenseMat Acm(A.data, A.n, A.m, COL_MAJOR); - gemm(handle, alpha, Acm, !transA, B, transB, beta, C, stream); - return; - } - if (B.ord == ROW_MAJOR) { - const SimpleDenseMat Bcm(B.data, B.n, B.m, COL_MAJOR); - gemm(handle, alpha, A, transA, Bcm, !transB, beta, C, stream); - return; - } - if (C.ord == ROW_MAJOR) { - SimpleDenseMat Ccm(C.data, C.n, C.m, COL_MAJOR); - gemm(handle, alpha, B, !transB, A, !transA, beta, Ccm, stream); - return; - } - } - - inline void gemmb(const raft::handle_t& handle, +template +struct SimpleMat { + int m, n; + + SimpleMat(int m, int n) : m(m), n(n) {} + + void operator=(const SimpleMat& other) = delete; + + virtual void print(std::ostream& oss) const = 0; + + /** + * GEMM assigning to C where `this` refers to B. 
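 *
 * A usage sketch via SimpleDenseMat::assign_gemm, which dispatches to this gemmb on B
 * (hedged: the pointers, extents, handle and stream are placeholders):
 *
 *   // C (m x n) <- 1.0 * A (m x k) * B (k x n) + 0.0 * C, all column-major
 *   SimpleDenseMat<float> A(a_ptr, m, k);
 *   SimpleDenseMat<float> B(b_ptr, k, n);
 *   SimpleDenseMat<float> C(c_ptr, m, n);
 *   C.assign_gemm(handle, 1.0f, A, false, B, false, 0.0f, stream);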
+ * + * ``` + * C <- alpha * A^transA * (*this)^transB + beta * C + * ``` + */ + virtual void gemmb(const raft::handle_t& handle, + const T alpha, + const SimpleDenseMat& A, + const bool transA, + const bool transB, + const T beta, + SimpleDenseMat& C, + cudaStream_t stream) const = 0; +}; + +enum STORAGE_ORDER { COL_MAJOR = 0, ROW_MAJOR = 1 }; + +template +struct SimpleDenseMat : SimpleMat { + typedef SimpleMat Super; + int len; + T* data; + + STORAGE_ORDER ord; // storage order: runtime param for compile time sake + + SimpleDenseMat(STORAGE_ORDER order = COL_MAJOR) : Super(0, 0), data(nullptr), len(0), ord(order) + { + } + + SimpleDenseMat(T* data, int m, int n, STORAGE_ORDER order = COL_MAJOR) + : Super(m, n), data(data), len(m * n), ord(order) + { + } + + void reset(T* data_, int m_, int n_) + { + this->m = m_; + this->n = n_; + data = data_; + len = m_ * n_; + } + + // Implemented GEMM as a static method here to improve readability + inline static void gemm(const raft::handle_t& handle, const T alpha, const SimpleDenseMat& A, const bool transA, + const SimpleDenseMat& B, const bool transB, const T beta, SimpleDenseMat& C, - cudaStream_t stream) const override - { - SimpleDenseMat::gemm(handle, alpha, A, transA, *this, transB, beta, C, stream); - } - - /** - * GEMM assigning to C where `this` refers to C. - * - * ``` - * *this <- alpha * A^transA * B^transB + beta * (*this) - * ``` - */ - inline void assign_gemm(const raft::handle_t& handle, - const T alpha, - const SimpleDenseMat& A, - const bool transA, - const SimpleMat& B, - const bool transB, - const T beta, - cudaStream_t stream) - { - B.gemmb(handle, alpha, A, transA, transB, beta, *this, stream); - } - - // this = a*x - inline void ax(const T a, const SimpleDenseMat& x, cudaStream_t stream) - { - ASSERT(ord == x.ord, "SimpleDenseMat::ax: Storage orders must match"); - - auto scale = [a] __device__(const T x) { return a * x; }; - raft::linalg::unaryOp(data, x.data, len, scale, stream); - } - - // this = a*x + y - inline void axpy(const T a, - const SimpleDenseMat& x, - const SimpleDenseMat& y, - cudaStream_t stream) - { - ASSERT(ord == x.ord, "SimpleDenseMat::axpy: Storage orders must match"); - ASSERT(ord == y.ord, "SimpleDenseMat::axpy: Storage orders must match"); - - auto axpy = [a] __device__(const T x, const T y) { return a * x + y; }; - raft::linalg::binaryOp(data, x.data, y.data, len, axpy, stream); - } - - template - inline void assign_unary(const SimpleDenseMat& other, Lambda f, cudaStream_t stream) - { - ASSERT(ord == other.ord, "SimpleDenseMat::assign_unary: Storage orders must match"); - - raft::linalg::unaryOp(data, other.data, len, f, stream); - } - - template - inline void assign_binary(const SimpleDenseMat& other1, - const SimpleDenseMat& other2, - Lambda& f, - cudaStream_t stream) - { - ASSERT(ord == other1.ord, "SimpleDenseMat::assign_binary: Storage orders must match"); - ASSERT(ord == other2.ord, "SimpleDenseMat::assign_binary: Storage orders must match"); - - raft::linalg::binaryOp(data, other1.data, other2.data, len, f, stream); - } - - template - inline void assign_ternary(const SimpleDenseMat& other1, - const SimpleDenseMat& other2, - const SimpleDenseMat& other3, - Lambda& f, - cudaStream_t stream) - { - ASSERT(ord == other1.ord, "SimpleDenseMat::assign_ternary: Storage orders must match"); - ASSERT(ord == other2.ord, "SimpleDenseMat::assign_ternary: Storage orders must match"); - ASSERT(ord == other3.ord, "SimpleDenseMat::assign_ternary: Storage orders must match"); - - raft::linalg::ternaryOp(data, 
other1.data, other2.data, other3.data, len, f, stream); - } - - inline void fill(const T val, cudaStream_t stream) - { - // TODO this reads data unnecessary, though it's mostly used for testing - auto f = [val] __device__(const T x) { return val; }; - raft::linalg::unaryOp(data, data, len, f, stream); - } - - inline void copy_async(const SimpleDenseMat& other, cudaStream_t stream) - { - ASSERT((ord == other.ord) && (this->m == other.m) && (this->n == other.n), - "SimpleDenseMat::copy: matrices not compatible"); - - RAFT_CUDA_TRY( - cudaMemcpyAsync(data, other.data, len * sizeof(T), cudaMemcpyDeviceToDevice, stream)); - } - - void print(std::ostream& oss) const override { oss << (*this) << std::endl; } - - void operator=(const SimpleDenseMat& other) = delete; - }; - - template - struct SimpleVec : SimpleDenseMat { - typedef SimpleDenseMat Super; - - SimpleVec(T* data, const int n) : Super(data, n, 1, COL_MAJOR) {} - // this = alpha * A * x + beta * this - void assign_gemv(const raft::handle_t& handle, - const T alpha, - const SimpleDenseMat& A, - bool transA, - const SimpleVec& x, - const T beta, - cudaStream_t stream) - { - Super::assign_gemm(handle, alpha, A, transA, x, false, beta, stream); - } - - SimpleVec() : Super(COL_MAJOR) {} - - inline void reset(T* new_data, int n) { Super::reset(new_data, n, 1); } - }; - - template - inline void col_ref(const SimpleDenseMat& mat, SimpleVec& mask_vec, int c) - { - ASSERT(mat.ord == COL_MAJOR, "col_ref only available for column major mats"); - T* tmp = &mat.data[mat.m * c]; - mask_vec.reset(tmp, mat.m); + cudaStream_t stream) + { + int kA = A.n; + int kB = B.m; + + if (transA) { + ASSERT(A.n == C.m, "GEMM invalid dims: m"); + kA = A.m; + } else { + ASSERT(A.m == C.m, "GEMM invalid dims: m"); } - template - inline void col_slice(const SimpleDenseMat& mat, - SimpleDenseMat& mask_mat, - int c_from, - int c_to) - { - ASSERT(c_from >= 0 && c_from < mat.n, "col_slice: invalid from"); - ASSERT(c_to >= 0 && c_to <= mat.n, "col_slice: invalid to"); - - ASSERT(mat.ord == COL_MAJOR, "col_ref only available for column major mats"); - ASSERT(mask_mat.ord == COL_MAJOR, "col_ref only available for column major mask"); - T* tmp = &mat.data[mat.m * c_from]; - mask_mat.reset(tmp, mat.m, c_to - c_from); + if (transB) { + ASSERT(B.m == C.n, "GEMM invalid dims: n"); + kB = B.n; + } else { + ASSERT(B.n == C.n, "GEMM invalid dims: n"); } - -// Reductions such as dot or norm require an additional location in dev mem -// to hold the result. We don't want to deal with this in the SimpleVec class -// as it impedes thread safety and constness - - template - inline T dot(const SimpleVec& u, const SimpleVec& v, T* tmp_dev, cudaStream_t stream) - { - auto f = [] __device__(const T x, const T y) { return x * y; }; - raft::linalg::mapThenSumReduce(tmp_dev, u.len, f, stream, u.data, v.data); - T tmp_host; - raft::update_host(&tmp_host, tmp_dev, 1, stream); - - raft::interruptible::synchronize(stream); - return tmp_host; + ASSERT(kA == kB, "GEMM invalid dims: k"); + + if (A.ord == COL_MAJOR && B.ord == COL_MAJOR && C.ord == COL_MAJOR) { + // #TODO: Call from public API when ready + raft::linalg::detail::cublasgemm(handle.get_cublas_handle(), // handle + transA ? CUBLAS_OP_T : CUBLAS_OP_N, // transA + transB ? 
CUBLAS_OP_T : CUBLAS_OP_N, // transB + C.m, + C.n, + kA, // dimensions m,n,k + &alpha, + A.data, + A.m, // lda + B.data, + B.m, // ldb + &beta, + C.data, + C.m, // ldc, + stream); + return; } - - template - inline T squaredNorm(const SimpleVec& u, T* tmp_dev, cudaStream_t stream) - { - return dot(u, u, tmp_dev, stream); + if (A.ord == ROW_MAJOR) { + const SimpleDenseMat Acm(A.data, A.n, A.m, COL_MAJOR); + gemm(handle, alpha, Acm, !transA, B, transB, beta, C, stream); + return; } - - template - inline T nrmMax(const SimpleVec& u, T* tmp_dev, cudaStream_t stream) - { - auto f = [] __device__(const T x) { return raft::myAbs(x); }; - auto r = [] __device__(const T x, const T y) { return raft::myMax(x, y); }; - raft::linalg::mapThenReduce(tmp_dev, u.len, T(0), f, r, stream, u.data); - T tmp_host; - raft::update_host(&tmp_host, tmp_dev, 1, stream); - raft::interruptible::synchronize(stream); - return tmp_host; + if (B.ord == ROW_MAJOR) { + const SimpleDenseMat Bcm(B.data, B.n, B.m, COL_MAJOR); + gemm(handle, alpha, A, transA, Bcm, !transB, beta, C, stream); + return; } - - template - inline T nrm2(const SimpleVec& u, T* tmp_dev, cudaStream_t stream) - { - return raft::mySqrt(squaredNorm(u, tmp_dev, stream)); + if (C.ord == ROW_MAJOR) { + SimpleDenseMat Ccm(C.data, C.n, C.m, COL_MAJOR); + gemm(handle, alpha, B, !transB, A, !transA, beta, Ccm, stream); + return; } + } + + inline void gemmb(const raft::handle_t& handle, + const T alpha, + const SimpleDenseMat& A, + const bool transA, + const bool transB, + const T beta, + SimpleDenseMat& C, + cudaStream_t stream) const override + { + SimpleDenseMat::gemm(handle, alpha, A, transA, *this, transB, beta, C, stream); + } + + /** + * GEMM assigning to C where `this` refers to C. + * + * ``` + * *this <- alpha * A^transA * B^transB + beta * (*this) + * ``` + */ + inline void assign_gemm(const raft::handle_t& handle, + const T alpha, + const SimpleDenseMat& A, + const bool transA, + const SimpleMat& B, + const bool transB, + const T beta, + cudaStream_t stream) + { + B.gemmb(handle, alpha, A, transA, transB, beta, *this, stream); + } + + // this = a*x + inline void ax(const T a, const SimpleDenseMat& x, cudaStream_t stream) + { + ASSERT(ord == x.ord, "SimpleDenseMat::ax: Storage orders must match"); + + auto scale = [a] __device__(const T x) { return a * x; }; + raft::linalg::unaryOp(data, x.data, len, scale, stream); + } + + // this = a*x + y + inline void axpy(const T a, + const SimpleDenseMat& x, + const SimpleDenseMat& y, + cudaStream_t stream) + { + ASSERT(ord == x.ord, "SimpleDenseMat::axpy: Storage orders must match"); + ASSERT(ord == y.ord, "SimpleDenseMat::axpy: Storage orders must match"); + + auto axpy = [a] __device__(const T x, const T y) { return a * x + y; }; + raft::linalg::binaryOp(data, x.data, y.data, len, axpy, stream); + } + + template + inline void assign_unary(const SimpleDenseMat& other, Lambda f, cudaStream_t stream) + { + ASSERT(ord == other.ord, "SimpleDenseMat::assign_unary: Storage orders must match"); + + raft::linalg::unaryOp(data, other.data, len, f, stream); + } + + template + inline void assign_binary(const SimpleDenseMat& other1, + const SimpleDenseMat& other2, + Lambda& f, + cudaStream_t stream) + { + ASSERT(ord == other1.ord, "SimpleDenseMat::assign_binary: Storage orders must match"); + ASSERT(ord == other2.ord, "SimpleDenseMat::assign_binary: Storage orders must match"); + + raft::linalg::binaryOp(data, other1.data, other2.data, len, f, stream); + } + + template + inline void assign_ternary(const SimpleDenseMat& 
other1, + const SimpleDenseMat& other2, + const SimpleDenseMat& other3, + Lambda& f, + cudaStream_t stream) + { + ASSERT(ord == other1.ord, "SimpleDenseMat::assign_ternary: Storage orders must match"); + ASSERT(ord == other2.ord, "SimpleDenseMat::assign_ternary: Storage orders must match"); + ASSERT(ord == other3.ord, "SimpleDenseMat::assign_ternary: Storage orders must match"); + + raft::linalg::ternaryOp(data, other1.data, other2.data, other3.data, len, f, stream); + } + + inline void fill(const T val, cudaStream_t stream) + { + // TODO this reads data unnecessary, though it's mostly used for testing + auto f = [val] __device__(const T x) { return val; }; + raft::linalg::unaryOp(data, data, len, f, stream); + } + + inline void copy_async(const SimpleDenseMat& other, cudaStream_t stream) + { + ASSERT((ord == other.ord) && (this->m == other.m) && (this->n == other.n), + "SimpleDenseMat::copy: matrices not compatible"); + + RAFT_CUDA_TRY( + cudaMemcpyAsync(data, other.data, len * sizeof(T), cudaMemcpyDeviceToDevice, stream)); + } + + void print(std::ostream& oss) const override { oss << (*this) << std::endl; } + + void operator=(const SimpleDenseMat& other) = delete; +}; + +template +struct SimpleVec : SimpleDenseMat { + typedef SimpleDenseMat Super; + + SimpleVec(T* data, const int n) : Super(data, n, 1, COL_MAJOR) {} + // this = alpha * A * x + beta * this + void assign_gemv(const raft::handle_t& handle, + const T alpha, + const SimpleDenseMat& A, + bool transA, + const SimpleVec& x, + const T beta, + cudaStream_t stream) + { + Super::assign_gemm(handle, alpha, A, transA, x, false, beta, stream); + } + + SimpleVec() : Super(COL_MAJOR) {} + + inline void reset(T* new_data, int n) { Super::reset(new_data, n, 1); } +}; + +template +inline void col_ref(const SimpleDenseMat& mat, SimpleVec& mask_vec, int c) +{ + ASSERT(mat.ord == COL_MAJOR, "col_ref only available for column major mats"); + T* tmp = &mat.data[mat.m * c]; + mask_vec.reset(tmp, mat.m); +} + +template +inline void col_slice(const SimpleDenseMat& mat, + SimpleDenseMat& mask_mat, + int c_from, + int c_to) +{ + ASSERT(c_from >= 0 && c_from < mat.n, "col_slice: invalid from"); + ASSERT(c_to >= 0 && c_to <= mat.n, "col_slice: invalid to"); + + ASSERT(mat.ord == COL_MAJOR, "col_ref only available for column major mats"); + ASSERT(mask_mat.ord == COL_MAJOR, "col_ref only available for column major mask"); + T* tmp = &mat.data[mat.m * c_from]; + mask_mat.reset(tmp, mat.m, c_to - c_from); +} - template - inline T nrm1(const SimpleVec& u, T* tmp_dev, cudaStream_t stream) - { - raft::linalg::rowNorm( - tmp_dev, u.data, u.len, 1, raft::linalg::L1Norm, true, stream, raft::Nop()); - T tmp_host; - raft::update_host(&tmp_host, tmp_dev, 1, stream); - raft::interruptible::synchronize(stream); - return tmp_host; - } +// Reductions such as dot or norm require an additional location in dev mem +// to hold the result. 
We don't want to deal with this in the SimpleVec class +// as it impedes thread safety and constness - template - std::ostream& operator<<(std::ostream& os, const SimpleVec& v) - { - std::vector out(v.len); - raft::update_host(&out[0], v.data, v.len, 0); - raft::interruptible::synchronize(rmm::cuda_stream_view()); - int it = 0; - for (; it < v.len - 1;) { - os << out[it] << " "; - it++; - } - os << out[it]; - return os; +template +inline T dot(const SimpleVec& u, const SimpleVec& v, T* tmp_dev, cudaStream_t stream) +{ + auto f = [] __device__(const T x, const T y) { return x * y; }; + raft::linalg::mapThenSumReduce(tmp_dev, u.len, f, stream, u.data, v.data); + T tmp_host; + raft::update_host(&tmp_host, tmp_dev, 1, stream); + + raft::interruptible::synchronize(stream); + return tmp_host; +} + +template +inline T squaredNorm(const SimpleVec& u, T* tmp_dev, cudaStream_t stream) +{ + return dot(u, u, tmp_dev, stream); +} + +template +inline T nrmMax(const SimpleVec& u, T* tmp_dev, cudaStream_t stream) +{ + auto f = [] __device__(const T x) { return raft::myAbs(x); }; + auto r = [] __device__(const T x, const T y) { return raft::myMax(x, y); }; + raft::linalg::mapThenReduce(tmp_dev, u.len, T(0), f, r, stream, u.data); + T tmp_host; + raft::update_host(&tmp_host, tmp_dev, 1, stream); + raft::interruptible::synchronize(stream); + return tmp_host; +} + +template +inline T nrm2(const SimpleVec& u, T* tmp_dev, cudaStream_t stream) +{ + return raft::mySqrt(squaredNorm(u, tmp_dev, stream)); +} + +template +inline T nrm1(const SimpleVec& u, T* tmp_dev, cudaStream_t stream) +{ + raft::linalg::rowNorm( + tmp_dev, u.data, u.len, 1, raft::linalg::L1Norm, true, stream, raft::Nop()); + T tmp_host; + raft::update_host(&tmp_host, tmp_dev, 1, stream); + raft::interruptible::synchronize(stream); + return tmp_host; +} + +template +std::ostream& operator<<(std::ostream& os, const SimpleVec& v) +{ + std::vector out(v.len); + raft::update_host(&out[0], v.data, v.len, 0); + raft::interruptible::synchronize(rmm::cuda_stream_view()); + int it = 0; + for (; it < v.len - 1;) { + os << out[it] << " "; + it++; + } + os << out[it]; + return os; +} + +template +std::ostream& operator<<(std::ostream& os, const SimpleDenseMat& mat) +{ + os << "ord=" << (mat.ord == COL_MAJOR ? "CM" : "RM") << "\n"; + std::vector out(mat.len); + raft::update_host(&out[0], mat.data, mat.len, rmm::cuda_stream_default); + raft::interruptible::synchronize(rmm::cuda_stream_view()); + if (mat.ord == COL_MAJOR) { + for (int r = 0; r < mat.m; r++) { + int idx = r; + for (int c = 0; c < mat.n - 1; c++) { + os << out[idx] << ","; + idx += mat.m; + } + os << out[idx] << std::endl; } - - template - std::ostream& operator<<(std::ostream& os, const SimpleDenseMat& mat) - { - os << "ord=" << (mat.ord == COL_MAJOR ? 
"CM" : "RM") << "\n"; - std::vector out(mat.len); - raft::update_host(&out[0], mat.data, mat.len, rmm::cuda_stream_default); - raft::interruptible::synchronize(rmm::cuda_stream_view()); - if (mat.ord == COL_MAJOR) { - for (int r = 0; r < mat.m; r++) { - int idx = r; - for (int c = 0; c < mat.n - 1; c++) { - os << out[idx] << ","; - idx += mat.m; - } - os << out[idx] << std::endl; - } - } else { - for (int c = 0; c < mat.m; c++) { - int idx = c * mat.n; - for (int r = 0; r < mat.n - 1; r++) { - os << out[idx] << ","; - idx += 1; - } - os << out[idx] << std::endl; - } - } - - return os; + } else { + for (int c = 0; c < mat.m; c++) { + int idx = c * mat.n; + for (int r = 0; r < mat.n - 1; r++) { + os << out[idx] << ","; + idx += 1; + } + os << out[idx] << std::endl; } + } + + return os; +} - template - struct SimpleVecOwning : SimpleVec { - typedef SimpleVec Super; - typedef rmm::device_uvector Buffer; - Buffer buf; +template +struct SimpleVecOwning : SimpleVec { + typedef SimpleVec Super; + typedef rmm::device_uvector Buffer; + Buffer buf; - SimpleVecOwning() = delete; + SimpleVecOwning() = delete; - SimpleVecOwning(int n, cudaStream_t stream) : Super(), buf(n, stream) - { - Super::reset(buf.data(), n); - } + SimpleVecOwning(int n, cudaStream_t stream) : Super(), buf(n, stream) + { + Super::reset(buf.data(), n); + } - void operator=(const SimpleVec& other) = delete; - }; + void operator=(const SimpleVec& other) = delete; +}; - template - struct SimpleMatOwning : SimpleDenseMat { - typedef SimpleDenseMat Super; - typedef rmm::device_uvector Buffer; - Buffer buf; - using Super::m; - using Super::n; - using Super::ord; +template +struct SimpleMatOwning : SimpleDenseMat { + typedef SimpleDenseMat Super; + typedef rmm::device_uvector Buffer; + Buffer buf; + using Super::m; + using Super::n; + using Super::ord; - SimpleMatOwning() = delete; + SimpleMatOwning() = delete; - SimpleMatOwning(int m, int n, cudaStream_t stream, STORAGE_ORDER order = COL_MAJOR) - : Super(order), buf(m * n, stream) - { - Super::reset(buf.data(), m, n); - } + SimpleMatOwning(int m, int n, cudaStream_t stream, STORAGE_ORDER order = COL_MAJOR) + : Super(order), buf(m * n, stream) + { + Super::reset(buf.data(), m, n); + } - void operator=(const SimpleVec& other) = delete; - }; + void operator=(const SimpleVec& other) = delete; +}; - /** +/** * Sparse matrix in CSR format. * * Note, we use cuSPARSE to manimulate matrices, and it guarantees: @@ -450,173 +450,172 @@ namespace raft::solver { * * However, when the data comes from the outside, we cannot guarantee that. 
*/ - template - struct SimpleSparseMat : SimpleMat { - typedef SimpleMat Super; - T* values; - int* cols; - int* row_ids; - int nnz; - - SimpleSparseMat() : Super(0, 0), values(nullptr), cols(nullptr), row_ids(nullptr), nnz(0) {} - - SimpleSparseMat(T* values, int* cols, int* row_ids, int nnz, int m, int n) - : Super(m, n), values(values), cols(cols), row_ids(row_ids), nnz(nnz) - { - check_csr(*this, 0); - } - - void print(std::ostream& oss) const override { oss << (*this) << std::endl; } - - void operator=(const SimpleSparseMat& other) = delete; - - inline void gemmb(const raft::handle_t& handle, - const T alpha, - const SimpleDenseMat& A, - const bool transA, - const bool transB, - const T beta, - SimpleDenseMat& C, - cudaStream_t stream) const override - { - const SimpleSparseMat& B = *this; - int kA = A.n; - int kB = B.m; - - if (transA) { - ASSERT(A.n == C.m, "GEMM invalid dims: m"); - kA = A.m; - } else { - ASSERT(A.m == C.m, "GEMM invalid dims: m"); - } - - if (transB) { - ASSERT(B.m == C.n, "GEMM invalid dims: n"); - kB = B.n; - } else { - ASSERT(B.n == C.n, "GEMM invalid dims: n"); - } - ASSERT(kA == kB, "GEMM invalid dims: k"); - - // matrix C must change the order and be transposed, because we need - // to swap arguments A and B in cusparseSpMM. - cusparseDnMatDescr_t descrC; - auto order = C.ord == COL_MAJOR ? CUSPARSE_ORDER_ROW : CUSPARSE_ORDER_COL; - RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsecreatednmat( - &descrC, C.n, C.m, order == CUSPARSE_ORDER_COL ? C.n : C.m, C.data, order)); - - /* - The matrix A must have the same order as the matrix C in the input - of function cusparseSpMM (i.e. swapped order w.r.t. original C). - To account this requirement, I may need to flip transA (whether to transpose A). - - C C' rowsC' colsC' ldC' A A' rowsA' colsA' ldA' flipTransA - c r n m m c r n m m x - c r n m m r r m n n o - r c n m n c c m n m o - r c n m n r c n m n x - - where: - c/r - column/row major order - A,C - input to gemmb - A', C' - input to cusparseSpMM - ldX' - leading dimension - m or n, depending on order and transX - */ - cusparseDnMatDescr_t descrA; - RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsecreatednmat(&descrA, - C.ord == A.ord ? A.n : A.m, - C.ord == A.ord ? A.m : A.n, - A.ord == COL_MAJOR ? A.m : A.n, - A.data, - order)); - auto opA = - transA ^ (C.ord == A.ord) ? CUSPARSE_OPERATION_NON_TRANSPOSE : CUSPARSE_OPERATION_TRANSPOSE; - - cusparseSpMatDescr_t descrB; - RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsecreatecsr( - &descrB, B.m, B.n, B.nnz, B.row_ids, B.cols, B.values)); - auto opB = transB ? CUSPARSE_OPERATION_NON_TRANSPOSE : CUSPARSE_OPERATION_TRANSPOSE; - - auto alg = order == CUSPARSE_ORDER_COL ? 
CUSPARSE_SPMM_CSR_ALG1 : CUSPARSE_SPMM_CSR_ALG2; - - size_t bufferSize; - RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsespmm_bufferSize(handle.get_cusparse_handle(), - opB, - opA, - &alpha, - descrB, - descrA, - &beta, - descrC, - alg, - &bufferSize, - stream)); - - raft::interruptible::synchronize(stream); - rmm::device_uvector tmp(bufferSize, stream); - - RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsespmm(handle.get_cusparse_handle(), - opB, - opA, - &alpha, - descrB, - descrA, - &beta, - descrC, - alg, - tmp.data(), - stream)); - - RAFT_CUSPARSE_TRY(cusparseDestroyDnMat(descrA)); - RAFT_CUSPARSE_TRY(cusparseDestroySpMat(descrB)); - RAFT_CUSPARSE_TRY(cusparseDestroyDnMat(descrC)); - } - }; - - template - inline void check_csr(const SimpleSparseMat& mat, cudaStream_t stream) - { - int row_ids_nnz; - raft::update_host(&row_ids_nnz, &mat.row_ids[mat.m], 1, stream); - raft::interruptible::synchronize(stream); - ASSERT(row_ids_nnz == mat.nnz, - "SimpleSparseMat: the size of CSR row_ids array must be `m + 1`, and " - "the last element must be equal nnz."); +template +struct SimpleSparseMat : SimpleMat { + typedef SimpleMat Super; + T* values; + int* cols; + int* row_ids; + int nnz; + + SimpleSparseMat() : Super(0, 0), values(nullptr), cols(nullptr), row_ids(nullptr), nnz(0) {} + + SimpleSparseMat(T* values, int* cols, int* row_ids, int nnz, int m, int n) + : Super(m, n), values(values), cols(cols), row_ids(row_ids), nnz(nnz) + { + check_csr(*this, 0); + } + + void print(std::ostream& oss) const override { oss << (*this) << std::endl; } + + void operator=(const SimpleSparseMat& other) = delete; + + inline void gemmb(const raft::handle_t& handle, + const T alpha, + const SimpleDenseMat& A, + const bool transA, + const bool transB, + const T beta, + SimpleDenseMat& C, + cudaStream_t stream) const override + { + const SimpleSparseMat& B = *this; + int kA = A.n; + int kB = B.m; + + if (transA) { + ASSERT(A.n == C.m, "GEMM invalid dims: m"); + kA = A.m; + } else { + ASSERT(A.m == C.m, "GEMM invalid dims: m"); } - template - std::ostream& operator<<(std::ostream& os, const SimpleSparseMat& mat) - { - check_csr(mat, 0); - os << "SimpleSparseMat (CSR)" - << "\n"; - std::vector values(mat.nnz); - std::vector cols(mat.nnz); - std::vector row_ids(mat.m + 1); - raft::update_host(&values[0], mat.values, mat.nnz, rmm::cuda_stream_default); - raft::update_host(&cols[0], mat.cols, mat.nnz, rmm::cuda_stream_default); - raft::update_host(&row_ids[0], mat.row_ids, mat.m + 1, rmm::cuda_stream_default); - raft::interruptible::synchronize(rmm::cuda_stream_view()); - - int i, row_end = 0; - for (int row = 0; row < mat.m; row++) { - i = row_end; - row_end = row_ids[row + 1]; - for (int col = 0; col < mat.n; col++) { - if (i >= row_end || col < cols[i]) { - os << "0"; - } else { - os << values[i]; - i++; - } - if (col < mat.n - 1) os << ","; - } - - os << std::endl; - } - - return os; + if (transB) { + ASSERT(B.m == C.n, "GEMM invalid dims: n"); + kB = B.n; + } else { + ASSERT(B.n == C.n, "GEMM invalid dims: n"); } + ASSERT(kA == kB, "GEMM invalid dims: k"); + + // matrix C must change the order and be transposed, because we need + // to swap arguments A and B in cusparseSpMM. + cusparseDnMatDescr_t descrC; + auto order = C.ord == COL_MAJOR ? CUSPARSE_ORDER_ROW : CUSPARSE_ORDER_COL; + RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsecreatednmat( + &descrC, C.n, C.m, order == CUSPARSE_ORDER_COL ? 
C.n : C.m, C.data, order)); + + /* + The matrix A must have the same order as the matrix C in the input + of function cusparseSpMM (i.e. swapped order w.r.t. original C). + To account this requirement, I may need to flip transA (whether to transpose A). + + C C' rowsC' colsC' ldC' A A' rowsA' colsA' ldA' flipTransA + c r n m m c r n m m x + c r n m m r r m n n o + r c n m n c c m n m o + r c n m n r c n m n x + + where: + c/r - column/row major order + A,C - input to gemmb + A', C' - input to cusparseSpMM + ldX' - leading dimension - m or n, depending on order and transX + */ + cusparseDnMatDescr_t descrA; + RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsecreatednmat(&descrA, + C.ord == A.ord ? A.n : A.m, + C.ord == A.ord ? A.m : A.n, + A.ord == COL_MAJOR ? A.m : A.n, + A.data, + order)); + auto opA = + transA ^ (C.ord == A.ord) ? CUSPARSE_OPERATION_NON_TRANSPOSE : CUSPARSE_OPERATION_TRANSPOSE; + + cusparseSpMatDescr_t descrB; + RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsecreatecsr( + &descrB, B.m, B.n, B.nnz, B.row_ids, B.cols, B.values)); + auto opB = transB ? CUSPARSE_OPERATION_NON_TRANSPOSE : CUSPARSE_OPERATION_TRANSPOSE; + + auto alg = order == CUSPARSE_ORDER_COL ? CUSPARSE_SPMM_CSR_ALG1 : CUSPARSE_SPMM_CSR_ALG2; + + size_t bufferSize; + RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsespmm_bufferSize(handle.get_cusparse_handle(), + opB, + opA, + &alpha, + descrB, + descrA, + &beta, + descrC, + alg, + &bufferSize, + stream)); + + raft::interruptible::synchronize(stream); + rmm::device_uvector tmp(bufferSize, stream); + + RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsespmm(handle.get_cusparse_handle(), + opB, + opA, + &alpha, + descrB, + descrA, + &beta, + descrC, + alg, + tmp.data(), + stream)); + + RAFT_CUSPARSE_TRY(cusparseDestroyDnMat(descrA)); + RAFT_CUSPARSE_TRY(cusparseDestroySpMat(descrB)); + RAFT_CUSPARSE_TRY(cusparseDestroyDnMat(descrC)); + } +}; + +template +inline void check_csr(const SimpleSparseMat& mat, cudaStream_t stream) +{ + int row_ids_nnz; + raft::update_host(&row_ids_nnz, &mat.row_ids[mat.m], 1, stream); + raft::interruptible::synchronize(stream); + ASSERT(row_ids_nnz == mat.nnz, + "SimpleSparseMat: the size of CSR row_ids array must be `m + 1`, and " + "the last element must be equal nnz."); +} + +template +std::ostream& operator<<(std::ostream& os, const SimpleSparseMat& mat) +{ + check_csr(mat, 0); + os << "SimpleSparseMat (CSR)" + << "\n"; + std::vector values(mat.nnz); + std::vector cols(mat.nnz); + std::vector row_ids(mat.m + 1); + raft::update_host(&values[0], mat.values, mat.nnz, rmm::cuda_stream_default); + raft::update_host(&cols[0], mat.cols, mat.nnz, rmm::cuda_stream_default); + raft::update_host(&row_ids[0], mat.row_ids, mat.m + 1, rmm::cuda_stream_default); + raft::interruptible::synchronize(rmm::cuda_stream_view()); + + int i, row_end = 0; + for (int row = 0; row < mat.m; row++) { + i = row_end; + row_end = row_ids[row + 1]; + for (int col = 0; col < mat.n; col++) { + if (i >= row_end || col < cols[i]) { + os << "0"; + } else { + os << values[i]; + i++; + } + if (col < mat.n - 1) os << ","; + } + + os << std::endl; + } + return os; +} }; // namespace raft::solver diff --git a/cpp/include/raft/solver/solver_types.hpp b/cpp/include/raft/solver/solver_types.hpp index cf7f3bcd8b..fb454e2745 100644 --- a/cpp/include/raft/solver/solver_types.hpp +++ b/cpp/include/raft/solver/solver_types.hpp @@ -34,56 +34,72 @@ enum loss_funct { enum penalty { NONE, L1, L2, ELASTICNET }; namespace gradient_descent { - template - struct sgd_params { - int batch_size; - 
int epochs; - lr_type lr_type; - math_t eta0; - math_t power_t; - loss_funct loss; - penalty penalty; - math_t alpha; - math_t l1_ratio; - bool shuffle; - math_t tol; - int n_iter_no_change; - - sgd_params() : batch_size(100), epochs(100), lr_type(lr_type::OPTIMAL), eta0(0.5), power_t(0.5), - loss(loss_funct::SQUARED), penalty(penalty::L1), alpha(0.5), l1_ratio(0.2), shuffle(true), tol(1e-8), n_iter_no_change(5){} - }; -} +template +struct sgd_params { + int batch_size; + int epochs; + lr_type lr_type; + math_t eta0; + math_t power_t; + loss_funct loss; + penalty penalty; + math_t alpha; + math_t l1_ratio; + bool shuffle; + math_t tol; + int n_iter_no_change; + + sgd_params() + : batch_size(100), + epochs(100), + lr_type(lr_type::OPTIMAL), + eta0(0.5), + power_t(0.5), + loss(loss_funct::SQUARED), + penalty(penalty::L1), + alpha(0.5), + l1_ratio(0.2), + shuffle(true), + tol(1e-8), + n_iter_no_change(5) + { + } +}; +} // namespace gradient_descent namespace coordinate_descent { - template - struct cd_params { - bool normalize; // whether to normalize the data to zero-mean and unit std - int epochs; // number of iterations - loss_funct loss; // loss function to minimize - math_t alpha; // l1 penalty parameter - math_t l1_ratio; // ratio of alpha that will be used for l1 penalty. (1 - l1_ratio) * alpha will be used for l2 penalty - bool shuffle; // randomly pick coordinates - math_t tol; // early-stopping convergence tolerance - - cd_params() : - normalize(true), - epochs(100), - alpha(0.3), - l1_ratio(0.5), - shuffle(true), - tol(1e-8), - loss(loss_funct::SQRD_LOSS) {} - }; -} +template +struct cd_params { + bool normalize; // whether to normalize the data to zero-mean and unit std + int epochs; // number of iterations + loss_funct loss; // loss function to minimize + math_t alpha; // l1 penalty parameter + math_t l1_ratio; // ratio of alpha that will be used for l1 penalty. (1 - l1_ratio) * alpha will + // be used for l2 penalty + bool shuffle; // randomly pick coordinates + math_t tol; // early-stopping convergence tolerance + + cd_params() + : normalize(true), + epochs(100), + alpha(0.3), + l1_ratio(0.5), + shuffle(true), + tol(1e-8), + loss(loss_funct::SQRD_LOSS) + { + } +}; +} // namespace coordinate_descent namespace least_angle_regression { - template - struct lars_params { - int max_iter; - math_t eps; - - lars_params(): max_iter(500), eps(-1) {} - }; -} +template +struct lars_params { + int max_iter; + math_t eps; + + lars_params() : max_iter(500), eps(-1) {} +}; +} // namespace least_angle_regression enum class LarsFitStatus { kOk, kCollinear, kError, kStop }; diff --git a/cpp/test/solver/quasi_newton.cu b/cpp/test/solver/quasi_newton.cu index 5a72eb4772..f0a67d9eb9 100644 --- a/cpp/test/solver/quasi_newton.cu +++ b/cpp/test/solver/quasi_newton.cu @@ -15,747 +15,734 @@ */ #include -#include #include #include #include +#include #include #include #include namespace raft::solver::quasi_newton { - template - int qn_fit(const raft::handle_t& handle, - const qn_params& pams, - LossFunction& loss, - const SimpleMat& X, - const SimpleVec& y, - SimpleDenseMat& Z, - T* w0_data, // initial value and result - T* fx, - int* num_iters) - { - LBFGSParam opt_param(pams); - SimpleVec w0(w0_data, loss.n_param); - - // Scale the regularization strenght with the number of samples. 
- T l1 = pams.penalty_l1; - T l2 = pams.penalty_l2; - if (pams.penalty_normalized) { - l1 /= X.m; - l2 /= X.m; - } - - if (l2 == 0) { - ObjectiveWithData lossWith(&loss, X, y, Z); - - return minimize(handle, w0, fx, num_iters, lossWith, l1, opt_param); - - } else { - Tikhonov reg(l2); - RegularizedQN obj(&loss, ®); - ObjectiveWithData lossWith(&obj, X, y, Z); - - return minimize(handle, w0, fx, num_iters, lossWith, l1, opt_param); - } - } +template +int qn_fit(const raft::handle_t& handle, + const qn_params& pams, + LossFunction& loss, + const SimpleMat& X, + const SimpleVec& y, + SimpleDenseMat& Z, + T* w0_data, // initial value and result + T* fx, + int* num_iters) +{ + LBFGSParam opt_param(pams); + SimpleVec w0(w0_data, loss.n_param); - template - inline void qn_fit_x(const raft::handle_t& handle, - const qn_params& pams, - SimpleMat& X, - T* y_data, - int C, - T* w0_data, - T* f, - int* num_iters, - cudaStream_t stream, - T* sample_weight = nullptr, - T svr_eps = 0) - { - /* - NB: - N - number of data rows - D - number of data columns (features) - C - number of output classes - - X in R^[N, D] - w in R^[D, C] - y in {0, 1}^[N, C] or {cat}^N - - Dimensionality of w0 depends on loss, so we initialize it later. - */ - int N = X.m; - int D = X.n; - int n_targets = qn_is_classification(pams.loss) && C == 2 ? 1 : C; - rmm::device_uvector tmp(n_targets * N, stream); - SimpleDenseMat Z(tmp.data(), n_targets, N); - SimpleVec y(y_data, N); - - switch (pams.loss) { - case QN_LOSS_LOGISTIC: { - ASSERT(C == 2, "qn.h: logistic loss invalid C"); - LogisticLoss loss(handle, D, pams.fit_intercept); - if (sample_weight) loss.add_sample_weights(sample_weight, N, stream); - qn_fit(handle, pams, loss, X, y, Z, w0_data, f, num_iters, stream); - } break; - case QN_LOSS_SQUARED: { - ASSERT(C == 1, "qn.h: squared loss invalid C"); - SquaredLoss loss(handle, D, pams.fit_intercept); - if (sample_weight) loss.add_sample_weights(sample_weight, N, stream); - qn_fit(handle, pams, loss, X, y, Z, w0_data, f, num_iters, stream); - } break; - case QN_LOSS_SOFTMAX: { - ASSERT(C > 2, "qn.h: softmax invalid C"); - Softmax loss(handle, D, C, pams.fit_intercept); - if (sample_weight) loss.add_sample_weights(sample_weight, N, stream); - qn_fit(handle, pams, loss, X, y, Z, w0_data, f, num_iters, stream); - } break; - case QN_LOSS_SVC_L1: { - ASSERT(C == 2, "qn.h: SVC-L1 loss invalid C"); - HingeLoss loss(handle, D, pams.fit_intercept); - if (sample_weight) loss.add_sample_weights(sample_weight, N, stream); - qn_fit(handle, pams, loss, X, y, Z, w0_data, f, num_iters, stream); - } break; - case QN_LOSS_SVC_L2: { - ASSERT(C == 2, "qn.h: SVC-L2 loss invalid C"); - SqHingeLoss loss(handle, D, pams.fit_intercept); - if (sample_weight) loss.add_sample_weights(sample_weight, N, stream); - qn_fit(handle, pams, loss, X, y, Z, w0_data, f, num_iters, stream); - } break; - case QN_LOSS_SVR_L1: { - ASSERT(C == 1, "qn.h: SVR-L1 loss invalid C"); - EpsInsHingeLoss loss(handle, D, pams.fit_intercept, svr_eps); - if (sample_weight) loss.add_sample_weights(sample_weight, N, stream); - qn_fit(handle, pams, loss, X, y, Z, w0_data, f, num_iters, stream); - } break; - case QN_LOSS_SVR_L2: { - ASSERT(C == 1, "qn.h: SVR-L2 loss invalid C"); - SqEpsInsHingeLoss loss(handle, D, pams.fit_intercept, svr_eps); - if (sample_weight) loss.add_sample_weights(sample_weight, N, stream); - qn_fit(handle, pams, loss, X, y, Z, w0_data, f, num_iters, stream); - } break; - case QN_LOSS_ABS: { - ASSERT(C == 1, "qn.h: abs loss (L1) invalid C"); - AbsLoss loss(handle, 
D, pams.fit_intercept); - if (sample_weight) loss.add_sample_weights(sample_weight, N, stream); - qn_fit(handle, pams, loss, X, y, Z, w0_data, f, num_iters, stream); - } break; - default: { - ASSERT(false, "qn.h: unknown loss function type (id = %d).", pams.loss); - } - } - } + // Scale the regularization strenght with the number of samples. + T l1 = pams.penalty_l1; + T l2 = pams.penalty_l2; + if (pams.penalty_normalized) { + l1 /= X.m; + l2 /= X.m; + } + if (l2 == 0) { + ObjectiveWithData lossWith(&loss, X, y, Z); + return minimize(handle, w0, fx, num_iters, lossWith, l1, opt_param); - struct QuasiNewtonTest : ::testing::Test { - static constexpr int N = 10; - static constexpr int D = 2; - - const static double* nobptr; - const static double tol; - const static double X[N][D]; - const raft::handle_t& handle; - cudaStream_t stream = 0; - std::shared_ptr> Xdev; - std::shared_ptr> ydev; - - QuasiNewtonTest() {} - void SetUp() - { - stream = handle.get_stream(); - Xdev.reset(new SimpleMatOwning(N, D, stream, ROW_MAJOR)); - raft::update_device(Xdev->data, &X[0][0], Xdev->len, stream); - - ydev.reset(new SimpleVecOwning(N, stream)); - handle.sync_stream(stream); - } - void TearDown() {} - }; - - const double* QuasiNewtonTest::nobptr = 0; - const double QuasiNewtonTest::tol = 5e-6; - const double QuasiNewtonTest::X[QuasiNewtonTest::N][QuasiNewtonTest::D] = { - {-0.2047076594847130, 0.4789433380575482}, - {-0.5194387150567381, -0.5557303043474900}, - {1.9657805725027142, 1.3934058329729904}, - {0.0929078767437177, 0.2817461528302025}, - {0.7690225676118387, 1.2464347363862822}, - {1.0071893575830049, -1.2962211091122635}, - {0.2749916334321240, 0.2289128789353159}, - {1.3529168351654497, 0.8864293405915888}, - {-2.0016373096603974, -0.3718425371402544}, - {1.6690253095248706, -0.4385697358355719}}; - - template - ::testing::AssertionResult checkParamsEqual(const raft::handle_t& handle, - const T* host_weights, - const T* host_bias, - const T* w, - const GLMDims& dims, - Comp& comp, - cudaStream_t stream) - { - int C = dims.C; - int D = dims.D; - bool fit_intercept = dims.fit_intercept; - std::vector w_ref_cm(C * D); - int idx = 0; - for (int d = 0; d < D; d++) - for (int c = 0; c < C; c++) { - w_ref_cm[idx++] = host_weights[c * D + d]; - } - - SimpleVecOwning w_ref(dims.n_param, stream); - raft::update_device(w_ref.data, &w_ref_cm[0], C * D, stream); - if (fit_intercept) { raft::update_device(&w_ref.data[C * D], host_bias, C, stream); } - handle.sync_stream(stream); - return raft::devArrMatch(w_ref.data, w, w_ref.len, comp); - } - - template - T run(const raft::handle_t& handle, - LossFunction& loss, - const SimpleMat& X, - const SimpleVec& y, - T l1, - T l2, - T* w, - SimpleDenseMat& z) - { - qn_params pams; - pams.max_iter = 100; - pams.grad_tol = 1e-16; - pams.change_tol = 1e-16; - pams.linesearch_max_iter = 50; - pams.lbfgs_memory = 5; - pams.penalty_l1 = l1; - pams.penalty_l2 = l2; - pams.verbose = verbosity; - - int num_iters = 0; - - T fx; - - qn_fit(handle, pams, loss, X, y, z, w, &fx, &num_iters); - - return fx; - } - - template - T run_api(const raft::handle_t& cuml_handle, - qn_loss_type loss_type, - int C, - bool fit_intercept, - const SimpleMat& X, - const SimpleVec& y, - T l1, - T l2, - T* w, - SimpleDenseMat& z, - int verbosity, - cudaStream_t stream) - { - qn_params pams; - - pams.max_iter = 100; - pams.grad_tol = 1e-8; - pams.change_tol = 1e-8; - pams.linesearch_max_iter = 50; - pams.lbfgs_memory = 5; - pams.penalty_l1 = l1; - pams.penalty_l2 = l2; - pams.verbose = verbosity; - 
pams.fit_intercept = fit_intercept; - pams.loss = loss_type; - - int num_iters = 0; - - SimpleVec w0(w, X.n + fit_intercept); - w0.fill(T(0), stream); - T fx; - - qn_fit_on_x(cuml_handle, - pams, - X_dense->data, - X_dense->ord == COL_MAJOR, - y.data, - X_dense->m, - X_dense->n, - C, - w, - &fx, - &num_iters); - } else { - ADD_FAILURE(); - } - - return fx; - } - - TEST_F(QuasiNewtonTest, binary_logistic_vs_sklearn) - { -#if CUDART_VERSION >= 11020 - GTEST_SKIP(); -#endif - raft::CompareApprox compApprox(tol); - // Test case generated in python and solved with sklearn - double y[N] = {1, 1, 1, 0, 1, 0, 1, 0, 1, 0}; - raft::update_device(ydev->data, &y[0], ydev->len, stream); - handle.sync_stream(stream); - - double alpha = 0.01 * N; - - LogisticLoss loss_b(handle, D, true); - LogisticLoss loss_no_b(handle, D, false); - - SimpleVecOwning w0(D + 1, stream); - SimpleMatOwning z(1, N, stream); - - double l1, l2, fx; - - double w_l1_b[2] = {-1.6899370396155091, 1.9021577534928300}; - double b_l1_b = 0.8057670813749118; - double obj_l1_b = 0.44295941481024703; - - l1 = alpha; - l2 = 0.0; - fx = run(handle, loss_b, *Xdev, *ydev, l1, l2, w0.data, z, 0, stream); - ASSERT_TRUE(compApprox(obj_l1_b, fx)); - ASSERT_TRUE(checkParamsEqual(handle, &w_l1_b[0], &b_l1_b, w0.data, loss_b, compApprox, stream)); - - fx = run_api(handle, - QN_LOSS_LOGISTIC, - 2, - loss_b.fit_intercept, - *Xdev, - *ydev, - l1, - l2, - w0.data, - z, - 0, - stream); - ASSERT_TRUE(compApprox(obj_l1_b, fx)); - - double w_l2_b[2] = {-1.5339880402781370, 1.6788639581350926}; - double b_l2_b = 0.806087868102401; - double obj_l2_b = 0.4378085369889721; - - l1 = 0; - l2 = alpha; - fx = run(handle, loss_b, *Xdev, *ydev, l1, l2, w0.data, z, 0, stream); - - ASSERT_TRUE(compApprox(obj_l2_b, fx)); - ASSERT_TRUE(checkParamsEqual(handle, &w_l2_b[0], &b_l2_b, w0.data, loss_b, compApprox, stream)); - - fx = run_api(cuml_handle, - QN_LOSS_LOGISTIC, - 2, - loss_b.fit_intercept, - *Xdev, - *ydev, - l1, - l2, - w0.data, - z, - 0, - stream); - ASSERT_TRUE(compApprox(obj_l2_b, fx)); - - double w_l1_no_b[2] = {-1.6215035298864591, 2.3650868394981086}; - double obj_l1_no_b = 0.4769896009200278; - - l1 = alpha; - l2 = 0.0; - fx = run(handle, loss_no_b, *Xdev, *ydev, l1, l2, w0.data, z, 0, stream); - ASSERT_TRUE(compApprox(obj_l1_no_b, fx)); - ASSERT_TRUE( - checkParamsEqual(handle, &w_l1_no_b[0], nobptr, w0.data, loss_no_b, compApprox, stream)); - - fx = run_api(cuml_handle, - QN_LOSS_LOGISTIC, - 2, - loss_no_b.fit_intercept, - *Xdev, - *ydev, - l1, - l2, - w0.data, - z, - 0, - stream); - ASSERT_TRUE(compApprox(obj_l1_no_b, fx)); - - double w_l2_no_b[2] = {-1.3931049893764620, 2.0140103094119621}; - double obj_l2_no_b = 0.47502098062114273; - - l1 = 0; - l2 = alpha; - fx = run(handle, loss_no_b, *Xdev, *ydev, l1, l2, w0.data, z, 0, stream); - ASSERT_TRUE(compApprox(obj_l2_no_b, fx)); - ASSERT_TRUE( - checkParamsEqual(handle, &w_l2_no_b[0], nobptr, w0.data, loss_no_b, compApprox, stream)); - - fx = run_api(cuml_handle, - QN_LOSS_LOGISTIC, - 2, - loss_no_b.fit_intercept, - *Xdev, - *ydev, - l1, - l2, - w0.data, - z, - 0, - stream); - ASSERT_TRUE(compApprox(obj_l2_no_b, fx)); - } + } else { + Tikhonov reg(l2); + RegularizedQN obj(&loss, ®); + ObjectiveWithData lossWith(&obj, X, y, Z); + + return minimize(handle, w0, fx, num_iters, lossWith, l1, opt_param); + } +} - TEST_F(QuasiNewtonTest, multiclass_logistic_vs_sklearn) +template +inline void qn_fit_x(const raft::handle_t& handle, + const qn_params& pams, + SimpleMat& X, + T* y_data, + int C, + T* w0_data, + 
T* f, + int* num_iters, + cudaStream_t stream, + T* sample_weight = nullptr, + T svr_eps = 0) { -#if CUDART_VERSION >= 11020 - GTEST_SKIP(); -#endif - // The data seems to small for the objective to be strongly convex - // leaving out exact param checks + /* + NB: + N - number of data rows + D - number of data columns (features) + C - number of output classes + + X in R^[N, D] + w in R^[D, C] + y in {0, 1}^[N, C] or {cat}^N + + Dimensionality of w0 depends on loss, so we initialize it later. + */ + int N = X.m; + int D = X.n; + int n_targets = qn_is_classification(pams.loss) && C == 2 ? 1 : C; + rmm::device_uvector tmp(n_targets * N, stream); + SimpleDenseMat Z(tmp.data(), n_targets, N); + SimpleVec y(y_data, N); + + switch (pams.loss) { + case QN_LOSS_LOGISTIC: { + ASSERT(C == 2, "qn.h: logistic loss invalid C"); + LogisticLoss loss(handle, D, pams.fit_intercept); + if (sample_weight) loss.add_sample_weights(sample_weight, N, stream); + qn_fit(handle, pams, loss, X, y, Z, w0_data, f, num_iters, stream); + } break; + case QN_LOSS_SQUARED: { + ASSERT(C == 1, "qn.h: squared loss invalid C"); + SquaredLoss loss(handle, D, pams.fit_intercept); + if (sample_weight) loss.add_sample_weights(sample_weight, N, stream); + qn_fit(handle, pams, loss, X, y, Z, w0_data, f, num_iters, stream); + } break; + case QN_LOSS_SOFTMAX: { + ASSERT(C > 2, "qn.h: softmax invalid C"); + Softmax loss(handle, D, C, pams.fit_intercept); + if (sample_weight) loss.add_sample_weights(sample_weight, N, stream); + qn_fit(handle, pams, loss, X, y, Z, w0_data, f, num_iters, stream); + } break; + case QN_LOSS_SVC_L1: { + ASSERT(C == 2, "qn.h: SVC-L1 loss invalid C"); + HingeLoss loss(handle, D, pams.fit_intercept); + if (sample_weight) loss.add_sample_weights(sample_weight, N, stream); + qn_fit(handle, pams, loss, X, y, Z, w0_data, f, num_iters, stream); + } break; + case QN_LOSS_SVC_L2: { + ASSERT(C == 2, "qn.h: SVC-L2 loss invalid C"); + SqHingeLoss loss(handle, D, pams.fit_intercept); + if (sample_weight) loss.add_sample_weights(sample_weight, N, stream); + qn_fit(handle, pams, loss, X, y, Z, w0_data, f, num_iters, stream); + } break; + case QN_LOSS_SVR_L1: { + ASSERT(C == 1, "qn.h: SVR-L1 loss invalid C"); + EpsInsHingeLoss loss(handle, D, pams.fit_intercept, svr_eps); + if (sample_weight) loss.add_sample_weights(sample_weight, N, stream); + qn_fit(handle, pams, loss, X, y, Z, w0_data, f, num_iters, stream); + } break; + case QN_LOSS_SVR_L2: { + ASSERT(C == 1, "qn.h: SVR-L2 loss invalid C"); + SqEpsInsHingeLoss loss(handle, D, pams.fit_intercept, svr_eps); + if (sample_weight) loss.add_sample_weights(sample_weight, N, stream); + qn_fit(handle, pams, loss, X, y, Z, w0_data, f, num_iters, stream); + } break; + case QN_LOSS_ABS: { + ASSERT(C == 1, "qn.h: abs loss (L1) invalid C"); + AbsLoss loss(handle, D, pams.fit_intercept); + if (sample_weight) loss.add_sample_weights(sample_weight, N, stream); + qn_fit(handle, pams, loss, X, y, Z, w0_data, f, num_iters, stream); + } break; + default: { + ASSERT(false, "qn.h: unknown loss function type (id = %d).", pams.loss); + } + } +} - raft::CompareApprox compApprox(tol); - double y[N] = {2, 2, 0, 3, 3, 0, 0, 0, 1, 0}; - raft::update_device(ydev->data, &y[0], ydev->len, stream); +struct QuasiNewtonTest : ::testing::Test { + static constexpr int N = 10; + static constexpr int D = 2; + + const static double* nobptr; + const static double tol; + const static double X[N][D]; + const raft::handle_t& handle; + cudaStream_t stream = 0; + std::shared_ptr> Xdev; + std::shared_ptr> ydev; + + 
QuasiNewtonTest() {} + void SetUp() + { + stream = handle.get_stream(); + Xdev.reset(new SimpleMatOwning(N, D, stream, ROW_MAJOR)); + raft::update_device(Xdev->data, &X[0][0], Xdev->len, stream); + + ydev.reset(new SimpleVecOwning(N, stream)); handle.sync_stream(stream); + } + void TearDown() {} +}; + +const double* QuasiNewtonTest::nobptr = 0; +const double QuasiNewtonTest::tol = 5e-6; +const double QuasiNewtonTest::X[QuasiNewtonTest::N][QuasiNewtonTest::D] = { + {-0.2047076594847130, 0.4789433380575482}, + {-0.5194387150567381, -0.5557303043474900}, + {1.9657805725027142, 1.3934058329729904}, + {0.0929078767437177, 0.2817461528302025}, + {0.7690225676118387, 1.2464347363862822}, + {1.0071893575830049, -1.2962211091122635}, + {0.2749916334321240, 0.2289128789353159}, + {1.3529168351654497, 0.8864293405915888}, + {-2.0016373096603974, -0.3718425371402544}, + {1.6690253095248706, -0.4385697358355719}}; + +template +::testing::AssertionResult checkParamsEqual(const raft::handle_t& handle, + const T* host_weights, + const T* host_bias, + const T* w, + const GLMDims& dims, + Comp& comp, + cudaStream_t stream) +{ + int C = dims.C; + int D = dims.D; + bool fit_intercept = dims.fit_intercept; + std::vector w_ref_cm(C * D); + int idx = 0; + for (int d = 0; d < D; d++) + for (int c = 0; c < C; c++) { + w_ref_cm[idx++] = host_weights[c * D + d]; + } - double fx, l1, l2; - int C = 4; - - double alpha = 0.016 * N; - - SimpleMatOwning z(C, N, stream); - SimpleVecOwning w0(C * (D + 1), stream); - - Softmax loss_b(handle, D, C, true); - Softmax loss_no_b(handle, D, C, false); - - l1 = alpha; - l2 = 0.0; - double obj_l1_b = 0.5407911382311313; - - fx = run(handle, loss_b, *Xdev, *ydev, l1, l2, w0.data, z, 0, stream); - ASSERT_TRUE(compApprox(obj_l1_b, fx)); - - fx = run_api(cuml_handle, - QN_LOSS_SOFTMAX, - C, - loss_b.fit_intercept, - *Xdev, - *ydev, - l1, - l2, - w0.data, - z, - 0, - stream); - ASSERT_TRUE(compApprox(obj_l1_b, fx)); - - l1 = 0.0; - l2 = alpha; - double obj_l2_b = 0.5721784062720949; - - fx = run(handle, loss_b, *Xdev, *ydev, l1, l2, w0.data, z, 0, stream); - ASSERT_TRUE(compApprox(obj_l2_b, fx)); - - fx = run_api(cuml_handle, - QN_LOSS_SOFTMAX, - C, - loss_b.fit_intercept, - *Xdev, - *ydev, - l1, - l2, - w0.data, - z, - 0, - stream); - ASSERT_TRUE(compApprox(obj_l2_b, fx)); - - l1 = alpha; - l2 = 0.0; - double obj_l1_no_b = 0.6606929813245878; - - fx = run(handle, loss_no_b, *Xdev, *ydev, l1, l2, w0.data, z, 0, stream); - ASSERT_TRUE(compApprox(obj_l1_no_b, fx)); - - fx = run_api(cuml_handle, - QN_LOSS_SOFTMAX, - C, - loss_no_b.fit_intercept, - *Xdev, - *ydev, - l1, - l2, - w0.data, - z, - 0, - stream); - ASSERT_TRUE(compApprox(obj_l1_no_b, fx)); - - l1 = 0.0; - l2 = alpha; - - double obj_l2_no_b = 0.6597171282106854; - - fx = run(handle, loss_no_b, *Xdev, *ydev, l1, l2, w0.data, z, 0, stream); - ASSERT_TRUE(compApprox(obj_l2_no_b, fx)); - - fx = run_api(cuml_handle, - QN_LOSS_SOFTMAX, - C, - loss_no_b.fit_intercept, - *Xdev, - *ydev, - l1, - l2, - w0.data, - z, - 0, - stream); - ASSERT_TRUE(compApprox(obj_l2_no_b, fx)); + SimpleVecOwning w_ref(dims.n_param, stream); + raft::update_device(w_ref.data, &w_ref_cm[0], C * D, stream); + if (fit_intercept) { raft::update_device(&w_ref.data[C * D], host_bias, C, stream); } + handle.sync_stream(stream); + return raft::devArrMatch(w_ref.data, w, w_ref.len, comp); } -TEST_F(QuasiNewtonTest, linear_regression_vs_sklearn) +template +T run(const raft::handle_t& handle, + LossFunction& loss, + const SimpleMat& X, + const SimpleVec& y, + T l1, + T l2, 
+ T* w, + SimpleDenseMat& z) { -raft::CompareApprox compApprox(tol); -double y[N] = {0.2675836026202781, - -0.0678277759663704, - -0.6334027174275105, - -0.1018336189077367, - 0.0933815935886932, - -1.1058853496996381, - -0.1658298189619160, - -0.2954290675648911, - 0.7966520536712608, - -1.0767450516284769}; -raft::update_device(ydev->data, &y[0], ydev->len, stream); -handle.sync_stream(stream); - -double fx, l1, l2; -double alpha = 0.01 * N; - -SimpleVecOwning w0(D + 1, stream); -SimpleMatOwning z(1, N, stream); -SquaredLoss loss_b(handle, D, true); -SquaredLoss loss_no_b(handle, D, false); - -l1 = alpha; -l2 = 0.0; -double w_l1_b[2] = {-0.4952397281519840, 0.3813315300180231}; -double b_l1_b = -0.08140861819001188; -double obj_l1_b = 0.011136986298775138; -fx = run(handle, loss_b, *Xdev, *ydev, l1, l2, w0.data, z, 0, stream); -ASSERT_TRUE(compApprox(obj_l1_b, fx)); -ASSERT_TRUE(checkParamsEqual(handle, &w_l1_b[0], &b_l1_b, w0.data, loss_b, compApprox, stream)); - -fx = run_api(cuml_handle, - QN_LOSS_SQUARED, - 1, - loss_b.fit_intercept, - *Xdev, - *ydev, - l1, - l2, - w0.data, - z, - 0, - stream); -ASSERT_TRUE(compApprox(obj_l1_b, fx)); - -l1 = 0.0; -l2 = alpha; -double w_l2_b[2] = {-0.5022384743587150, 0.3937352417485087}; -double b_l2_b = -0.08062397391797513; -double obj_l2_b = 0.004268621967866347; - -fx = run(handle, loss_b, *Xdev, *ydev, l1, l2, w0.data, z, 0, stream); -ASSERT_TRUE(compApprox(obj_l2_b, fx)); -ASSERT_TRUE(checkParamsEqual(handle, &w_l2_b[0], &b_l2_b, w0.data, loss_b, compApprox, stream)); - -fx = run_api(cuml_handle, - QN_LOSS_SQUARED, - 1, - loss_b.fit_intercept, - *Xdev, - *ydev, - l1, - l2, - w0.data, - z, - 0, - stream); -ASSERT_TRUE(compApprox(obj_l2_b, fx)); - -l1 = alpha; -l2 = 0.0; -double w_l1_no_b[2] = {-0.5175178128147135, 0.3720844589831813}; -double obj_l1_no_b = 0.013981355746112447; - -fx = run(handle, loss_no_b, *Xdev, *ydev, l1, l2, w0.data, z, 0, stream); -ASSERT_TRUE(compApprox(obj_l1_no_b, fx)); -ASSERT_TRUE( - checkParamsEqual(handle, &w_l1_no_b[0], nobptr, w0.data, loss_no_b, compApprox, stream)); - -fx = run_api(cuml_handle, - QN_LOSS_SQUARED, - 1, - loss_no_b.fit_intercept, - *Xdev, - *ydev, - l1, - l2, - w0.data, - z, - 0, - stream); -ASSERT_TRUE(compApprox(obj_l1_no_b, fx)); - -l1 = 0.0; -l2 = alpha; -double w_l2_no_b[2] = {-0.5241651041233270, 0.3846317886627560}; -double obj_l2_no_b = 0.007061261366969662; - -fx = run(handle, loss_no_b, *Xdev, *ydev, l1, l2, w0.data, z, 0, stream); -ASSERT_TRUE(compApprox(obj_l2_no_b, fx)); -ASSERT_TRUE( - checkParamsEqual(handle, &w_l2_no_b[0], nobptr, w0.data, loss_no_b, compApprox, stream)); - -fx = run_api(cuml_handle, - QN_LOSS_SQUARED, - 1, - loss_no_b.fit_intercept, - *Xdev, - *ydev, - l1, - l2, - w0.data, - z, - 0, - stream); -ASSERT_TRUE(compApprox(obj_l2_no_b, fx)); + qn_params pams; + pams.max_iter = 100; + pams.grad_tol = 1e-16; + pams.change_tol = 1e-16; + pams.linesearch_max_iter = 50; + pams.lbfgs_memory = 5; + pams.penalty_l1 = l1; + pams.penalty_l2 = l2; + pams.verbose = verbosity; + + int num_iters = 0; + + T fx; + + qn_fit(handle, pams, loss, X, y, z, w, &fx, &num_iters); + + return fx; } -TEST_F(QuasiNewtonTest, predict) +template +T run_api(const raft::handle_t& cuml_handle, + qn_loss_type loss_type, + int C, + bool fit_intercept, + const SimpleMat& X, + const SimpleVec& y, + T l1, + T l2, + T* w, + SimpleDenseMat& z, + int verbosity, + cudaStream_t stream) { -raft::CompareApprox compApprox(1e-8); -std::vector w_host(D); -w_host[0] = 1; -std::vector preds_host(N); -SimpleVecOwning 
w(D, stream); -SimpleVecOwning preds(N, stream); - -raft::update_device(w.data, &w_host[0], w.len, stream); -qn_params pams; -pams.loss = QN_LOSS_LOGISTIC; -pams.fit_intercept = false; - -qnPredict(handle, pams, Xdev->data, false, N, D, 2, w.data, preds.data, stream); -raft::update_host(&preds_host[0], preds.data, preds.len, stream); -handle.sync_stream(stream); - -for (int it = 0; it < N; it++) { -ASSERT_TRUE(X[it][0] > 0 ? compApprox(preds_host[it], 1) : compApprox(preds_host[it], 0)); + qn_params pams; + + pams.max_iter = 100; + pams.grad_tol = 1e-8; + pams.change_tol = 1e-8; + pams.linesearch_max_iter = 50; + pams.lbfgs_memory = 5; + pams.penalty_l1 = l1; + pams.penalty_l2 = l2; + pams.verbose = verbosity; + pams.fit_intercept = fit_intercept; + pams.loss = loss_type; + + int num_iters = 0; + + SimpleVec w0(w, X.n + fit_intercept); + w0.fill(T(0), stream); + T fx; + + qn_fit_on_x(cuml_handle, + pams, + X_dense->data, + X_dense->ord == COL_MAJOR, + y.data, + X_dense->m, + X_dense->n, + C, + w, + &fx, + &num_iters); } +else { ADD_FAILURE(); } -pams.loss = QN_LOSS_SQUARED; -pams.fit_intercept = false; -qnPredict(handle, pams, Xdev->data, false, N, D, 1, w.data, preds.data, stream); -raft::update_host(&preds_host[0], preds.data, preds.len, stream); -handle.sync_stream(stream); +return fx; +} -for (int it = 0; it < N; it++) { -ASSERT_TRUE(compApprox(X[it][0], preds_host[it])); +TEST_F(QuasiNewtonTest, binary_logistic_vs_sklearn) +{ +#if CUDART_VERSION >= 11020 + GTEST_SKIP(); +#endif + raft::CompareApprox compApprox(tol); + // Test case generated in python and solved with sklearn + double y[N] = {1, 1, 1, 0, 1, 0, 1, 0, 1, 0}; + raft::update_device(ydev->data, &y[0], ydev->len, stream); + handle.sync_stream(stream); + + double alpha = 0.01 * N; + + LogisticLoss loss_b(handle, D, true); + LogisticLoss loss_no_b(handle, D, false); + + SimpleVecOwning w0(D + 1, stream); + SimpleMatOwning z(1, N, stream); + + double l1, l2, fx; + + double w_l1_b[2] = {-1.6899370396155091, 1.9021577534928300}; + double b_l1_b = 0.8057670813749118; + double obj_l1_b = 0.44295941481024703; + + l1 = alpha; + l2 = 0.0; + fx = run(handle, loss_b, *Xdev, *ydev, l1, l2, w0.data, z, 0, stream); + ASSERT_TRUE(compApprox(obj_l1_b, fx)); + ASSERT_TRUE(checkParamsEqual(handle, &w_l1_b[0], &b_l1_b, w0.data, loss_b, compApprox, stream)); + + fx = run_api( + handle, QN_LOSS_LOGISTIC, 2, loss_b.fit_intercept, *Xdev, *ydev, l1, l2, w0.data, z, 0, stream); + ASSERT_TRUE(compApprox(obj_l1_b, fx)); + + double w_l2_b[2] = {-1.5339880402781370, 1.6788639581350926}; + double b_l2_b = 0.806087868102401; + double obj_l2_b = 0.4378085369889721; + + l1 = 0; + l2 = alpha; + fx = run(handle, loss_b, *Xdev, *ydev, l1, l2, w0.data, z, 0, stream); + + ASSERT_TRUE(compApprox(obj_l2_b, fx)); + ASSERT_TRUE(checkParamsEqual(handle, &w_l2_b[0], &b_l2_b, w0.data, loss_b, compApprox, stream)); + + fx = run_api(cuml_handle, + QN_LOSS_LOGISTIC, + 2, + loss_b.fit_intercept, + *Xdev, + *ydev, + l1, + l2, + w0.data, + z, + 0, + stream); + ASSERT_TRUE(compApprox(obj_l2_b, fx)); + + double w_l1_no_b[2] = {-1.6215035298864591, 2.3650868394981086}; + double obj_l1_no_b = 0.4769896009200278; + + l1 = alpha; + l2 = 0.0; + fx = run(handle, loss_no_b, *Xdev, *ydev, l1, l2, w0.data, z, 0, stream); + ASSERT_TRUE(compApprox(obj_l1_no_b, fx)); + ASSERT_TRUE( + checkParamsEqual(handle, &w_l1_no_b[0], nobptr, w0.data, loss_no_b, compApprox, stream)); + + fx = run_api(cuml_handle, + QN_LOSS_LOGISTIC, + 2, + loss_no_b.fit_intercept, + *Xdev, + *ydev, + l1, + l2, + 
w0.data, + z, + 0, + stream); + ASSERT_TRUE(compApprox(obj_l1_no_b, fx)); + + double w_l2_no_b[2] = {-1.3931049893764620, 2.0140103094119621}; + double obj_l2_no_b = 0.47502098062114273; + + l1 = 0; + l2 = alpha; + fx = run(handle, loss_no_b, *Xdev, *ydev, l1, l2, w0.data, z, 0, stream); + ASSERT_TRUE(compApprox(obj_l2_no_b, fx)); + ASSERT_TRUE( + checkParamsEqual(handle, &w_l2_no_b[0], nobptr, w0.data, loss_no_b, compApprox, stream)); + + fx = run_api(cuml_handle, + QN_LOSS_LOGISTIC, + 2, + loss_no_b.fit_intercept, + *Xdev, + *ydev, + l1, + l2, + w0.data, + z, + 0, + stream); + ASSERT_TRUE(compApprox(obj_l2_no_b, fx)); } + +TEST_F(QuasiNewtonTest, multiclass_logistic_vs_sklearn) +{ +#if CUDART_VERSION >= 11020 + GTEST_SKIP(); +#endif + // The data seems to small for the objective to be strongly convex + // leaving out exact param checks + + raft::CompareApprox compApprox(tol); + double y[N] = {2, 2, 0, 3, 3, 0, 0, 0, 1, 0}; + raft::update_device(ydev->data, &y[0], ydev->len, stream); + handle.sync_stream(stream); + + double fx, l1, l2; + int C = 4; + + double alpha = 0.016 * N; + + SimpleMatOwning z(C, N, stream); + SimpleVecOwning w0(C * (D + 1), stream); + + Softmax loss_b(handle, D, C, true); + Softmax loss_no_b(handle, D, C, false); + + l1 = alpha; + l2 = 0.0; + double obj_l1_b = 0.5407911382311313; + + fx = run(handle, loss_b, *Xdev, *ydev, l1, l2, w0.data, z, 0, stream); + ASSERT_TRUE(compApprox(obj_l1_b, fx)); + + fx = run_api(cuml_handle, + QN_LOSS_SOFTMAX, + C, + loss_b.fit_intercept, + *Xdev, + *ydev, + l1, + l2, + w0.data, + z, + 0, + stream); + ASSERT_TRUE(compApprox(obj_l1_b, fx)); + + l1 = 0.0; + l2 = alpha; + double obj_l2_b = 0.5721784062720949; + + fx = run(handle, loss_b, *Xdev, *ydev, l1, l2, w0.data, z, 0, stream); + ASSERT_TRUE(compApprox(obj_l2_b, fx)); + + fx = run_api(cuml_handle, + QN_LOSS_SOFTMAX, + C, + loss_b.fit_intercept, + *Xdev, + *ydev, + l1, + l2, + w0.data, + z, + 0, + stream); + ASSERT_TRUE(compApprox(obj_l2_b, fx)); + + l1 = alpha; + l2 = 0.0; + double obj_l1_no_b = 0.6606929813245878; + + fx = run(handle, loss_no_b, *Xdev, *ydev, l1, l2, w0.data, z, 0, stream); + ASSERT_TRUE(compApprox(obj_l1_no_b, fx)); + + fx = run_api(cuml_handle, + QN_LOSS_SOFTMAX, + C, + loss_no_b.fit_intercept, + *Xdev, + *ydev, + l1, + l2, + w0.data, + z, + 0, + stream); + ASSERT_TRUE(compApprox(obj_l1_no_b, fx)); + + l1 = 0.0; + l2 = alpha; + + double obj_l2_no_b = 0.6597171282106854; + + fx = run(handle, loss_no_b, *Xdev, *ydev, l1, l2, w0.data, z, 0, stream); + ASSERT_TRUE(compApprox(obj_l2_no_b, fx)); + + fx = run_api(cuml_handle, + QN_LOSS_SOFTMAX, + C, + loss_no_b.fit_intercept, + *Xdev, + *ydev, + l1, + l2, + w0.data, + z, + 0, + stream); + ASSERT_TRUE(compApprox(obj_l2_no_b, fx)); } -TEST_F(QuasiNewtonTest, predict_softmax) +TEST_F(QuasiNewtonTest, linear_regression_vs_sklearn) { -raft::CompareApprox compApprox(1e-8); -int C = 4; -std::vector w_host(C * D); -w_host[0] = 1; -w_host[D * C - 1] = 1; - -std::vector preds_host(N); -SimpleVecOwning w(w_host.size(), stream); -SimpleVecOwning preds(N, stream); - -raft::update_device(w.data, &w_host[0], w.len, stream); - -qn_params pams; -pams.loss = QN_LOSS_SOFTMAX; -pams.fit_intercept = false; -qnPredict(handle, pams, Xdev->data, false, N, D, C, w.data, preds.data, stream); -raft::update_host(&preds_host[0], preds.data, preds.len, stream); -handle.sync_stream(stream); - -for (int it = 0; it < N; it++) { -if (X[it][0] < 0 && X[it][1] < 0) { -ASSERT_TRUE(compApprox(1, preds_host[it])); -} else if (X[it][0] > X[it][1]) { 
-ASSERT_TRUE(compApprox(0, preds_host[it])); -} else { -ASSERT_TRUE(compApprox(C - 1, preds_host[it])); + raft::CompareApprox compApprox(tol); + double y[N] = {0.2675836026202781, + -0.0678277759663704, + -0.6334027174275105, + -0.1018336189077367, + 0.0933815935886932, + -1.1058853496996381, + -0.1658298189619160, + -0.2954290675648911, + 0.7966520536712608, + -1.0767450516284769}; + raft::update_device(ydev->data, &y[0], ydev->len, stream); + handle.sync_stream(stream); + + double fx, l1, l2; + double alpha = 0.01 * N; + + SimpleVecOwning w0(D + 1, stream); + SimpleMatOwning z(1, N, stream); + SquaredLoss loss_b(handle, D, true); + SquaredLoss loss_no_b(handle, D, false); + + l1 = alpha; + l2 = 0.0; + double w_l1_b[2] = {-0.4952397281519840, 0.3813315300180231}; + double b_l1_b = -0.08140861819001188; + double obj_l1_b = 0.011136986298775138; + fx = run(handle, loss_b, *Xdev, *ydev, l1, l2, w0.data, z, 0, stream); + ASSERT_TRUE(compApprox(obj_l1_b, fx)); + ASSERT_TRUE(checkParamsEqual(handle, &w_l1_b[0], &b_l1_b, w0.data, loss_b, compApprox, stream)); + + fx = run_api(cuml_handle, + QN_LOSS_SQUARED, + 1, + loss_b.fit_intercept, + *Xdev, + *ydev, + l1, + l2, + w0.data, + z, + 0, + stream); + ASSERT_TRUE(compApprox(obj_l1_b, fx)); + + l1 = 0.0; + l2 = alpha; + double w_l2_b[2] = {-0.5022384743587150, 0.3937352417485087}; + double b_l2_b = -0.08062397391797513; + double obj_l2_b = 0.004268621967866347; + + fx = run(handle, loss_b, *Xdev, *ydev, l1, l2, w0.data, z, 0, stream); + ASSERT_TRUE(compApprox(obj_l2_b, fx)); + ASSERT_TRUE(checkParamsEqual(handle, &w_l2_b[0], &b_l2_b, w0.data, loss_b, compApprox, stream)); + + fx = run_api(cuml_handle, + QN_LOSS_SQUARED, + 1, + loss_b.fit_intercept, + *Xdev, + *ydev, + l1, + l2, + w0.data, + z, + 0, + stream); + ASSERT_TRUE(compApprox(obj_l2_b, fx)); + + l1 = alpha; + l2 = 0.0; + double w_l1_no_b[2] = {-0.5175178128147135, 0.3720844589831813}; + double obj_l1_no_b = 0.013981355746112447; + + fx = run(handle, loss_no_b, *Xdev, *ydev, l1, l2, w0.data, z, 0, stream); + ASSERT_TRUE(compApprox(obj_l1_no_b, fx)); + ASSERT_TRUE( + checkParamsEqual(handle, &w_l1_no_b[0], nobptr, w0.data, loss_no_b, compApprox, stream)); + + fx = run_api(cuml_handle, + QN_LOSS_SQUARED, + 1, + loss_no_b.fit_intercept, + *Xdev, + *ydev, + l1, + l2, + w0.data, + z, + 0, + stream); + ASSERT_TRUE(compApprox(obj_l1_no_b, fx)); + + l1 = 0.0; + l2 = alpha; + double w_l2_no_b[2] = {-0.5241651041233270, 0.3846317886627560}; + double obj_l2_no_b = 0.007061261366969662; + + fx = run(handle, loss_no_b, *Xdev, *ydev, l1, l2, w0.data, z, 0, stream); + ASSERT_TRUE(compApprox(obj_l2_no_b, fx)); + ASSERT_TRUE( + checkParamsEqual(handle, &w_l2_no_b[0], nobptr, w0.data, loss_no_b, compApprox, stream)); + + fx = run_api(cuml_handle, + QN_LOSS_SQUARED, + 1, + loss_no_b.fit_intercept, + *Xdev, + *ydev, + l1, + l2, + w0.data, + z, + 0, + stream); + ASSERT_TRUE(compApprox(obj_l2_no_b, fx)); } + +TEST_F(QuasiNewtonTest, predict) +{ + raft::CompareApprox compApprox(1e-8); + std::vector w_host(D); + w_host[0] = 1; + std::vector preds_host(N); + SimpleVecOwning w(D, stream); + SimpleVecOwning preds(N, stream); + + raft::update_device(w.data, &w_host[0], w.len, stream); + qn_params pams; + pams.loss = QN_LOSS_LOGISTIC; + pams.fit_intercept = false; + + qnPredict(handle, pams, Xdev->data, false, N, D, 2, w.data, preds.data, stream); + raft::update_host(&preds_host[0], preds.data, preds.len, stream); + handle.sync_stream(stream); + + for (int it = 0; it < N; it++) { + ASSERT_TRUE(X[it][0] > 0 ? 
compApprox(preds_host[it], 1) : compApprox(preds_host[it], 0)); + } + + pams.loss = QN_LOSS_SQUARED; + pams.fit_intercept = false; + qnPredict(handle, pams, Xdev->data, false, N, D, 1, w.data, preds.data, stream); + raft::update_host(&preds_host[0], preds.data, preds.len, stream); + handle.sync_stream(stream); + + for (int it = 0; it < N; it++) { + ASSERT_TRUE(compApprox(X[it][0], preds_host[it])); + } } + +TEST_F(QuasiNewtonTest, predict_softmax) +{ + raft::CompareApprox compApprox(1e-8); + int C = 4; + std::vector w_host(C * D); + w_host[0] = 1; + w_host[D * C - 1] = 1; + + std::vector preds_host(N); + SimpleVecOwning w(w_host.size(), stream); + SimpleVecOwning preds(N, stream); + + raft::update_device(w.data, &w_host[0], w.len, stream); + + qn_params pams; + pams.loss = QN_LOSS_SOFTMAX; + pams.fit_intercept = false; + qnPredict(handle, pams, Xdev->data, false, N, D, C, w.data, preds.data, stream); + raft::update_host(&preds_host[0], preds.data, preds.len, stream); + handle.sync_stream(stream); + + for (int it = 0; it < N; it++) { + if (X[it][0] < 0 && X[it][1] < 0) { + ASSERT_TRUE(compApprox(1, preds_host[it])); + } else if (X[it][0] > X[it][1]) { + ASSERT_TRUE(compApprox(0, preds_host[it])); + } else { + ASSERT_TRUE(compApprox(C - 1, preds_host[it])); + } + } } TEST_F(QuasiNewtonTest, dense_vs_sparse_logistic) { #if CUDART_VERSION >= 11020 -GTEST_SKIP(); + GTEST_SKIP(); #endif -// Prepare a sparse input matrix from the dense matrix X. -// Yes, it's not sparse at all, yet the test does check whether the behaviour -// of dense and sparse variants is the same. -rmm::device_uvector mem_X_cols(N * D, stream); -rmm::device_uvector mem_X_row_ids(N + 1, stream); -int host_X_cols[N][D]; -int host_X_row_ids[N + 1]; -for (int i = 0; i < N; i++) { -for (int j = 0; j < D; j++) { -host_X_cols[i][j] = j; -} -} -for (int i = 0; i < N + 1; i++) { -host_X_row_ids[i] = i * D; -} -raft::update_device(mem_X_cols.data(), &host_X_cols[0][0], mem_X_cols.size(), stream); -raft::update_device(mem_X_row_ids.data(), &host_X_row_ids[0], mem_X_row_ids.size(), stream); -SimpleSparseMat X_sparse( - Xdev->data, mem_X_cols.data(), mem_X_row_ids.data(), N * D, N, D); - -raft::CompareApprox compApprox(tol); -double y[N] = {2, 2, 0, 3, 3, 0, 0, 0, 1, 0}; -raft::update_device(ydev->data, &y[0], ydev->len, stream); -handle.sync_stream(stream); - -int C = 4; -qn_loss_type loss_type = QN_LOSS_SOFTMAX; // Softmax (loss_b, loss_no_b) -double alpha = 0.016 * N; -Softmax loss_b(handle, D, C, true); -Softmax loss_no_b(handle, D, C, false); - -SimpleMatOwning z_dense(C, N, stream); -SimpleMatOwning z_sparse(C, N, stream); -SimpleVecOwning w0_dense(C * (D + 1), stream); -SimpleVecOwning w0_sparse(C * (D + 1), stream); - -std::vector preds_dense_host(N); -std::vector preds_sparse_host(N); -SimpleVecOwning preds_dense(N, stream); -SimpleVecOwning preds_sparse(N, stream); - -auto test_run = [&](double l1, double l2, Softmax loss) { + // Prepare a sparse input matrix from the dense matrix X. + // Yes, it's not sparse at all, yet the test does check whether the behaviour + // of dense and sparse variants is the same. 
+ rmm::device_uvector mem_X_cols(N * D, stream); + rmm::device_uvector mem_X_row_ids(N + 1, stream); + int host_X_cols[N][D]; + int host_X_row_ids[N + 1]; + for (int i = 0; i < N; i++) { + for (int j = 0; j < D; j++) { + host_X_cols[i][j] = j; + } + } + for (int i = 0; i < N + 1; i++) { + host_X_row_ids[i] = i * D; + } + raft::update_device(mem_X_cols.data(), &host_X_cols[0][0], mem_X_cols.size(), stream); + raft::update_device(mem_X_row_ids.data(), &host_X_row_ids[0], mem_X_row_ids.size(), stream); + SimpleSparseMat X_sparse( + Xdev->data, mem_X_cols.data(), mem_X_row_ids.data(), N * D, N, D); + + raft::CompareApprox compApprox(tol); + double y[N] = {2, 2, 0, 3, 3, 0, 0, 0, 1, 0}; + raft::update_device(ydev->data, &y[0], ydev->len, stream); + handle.sync_stream(stream); + + int C = 4; + qn_loss_type loss_type = QN_LOSS_SOFTMAX; // Softmax (loss_b, loss_no_b) + double alpha = 0.016 * N; + Softmax loss_b(handle, D, C, true); + Softmax loss_no_b(handle, D, C, false); + + SimpleMatOwning z_dense(C, N, stream); + SimpleMatOwning z_sparse(C, N, stream); + SimpleVecOwning w0_dense(C * (D + 1), stream); + SimpleVecOwning w0_sparse(C * (D + 1), stream); + + std::vector preds_dense_host(N); + std::vector preds_sparse_host(N); + SimpleVecOwning preds_dense(N, stream); + SimpleVecOwning preds_sparse(N, stream); + + auto test_run = [&](double l1, double l2, Softmax loss) { qn_params pams; pams.penalty_l1 = l1; pams.penalty_l2 = l2; @@ -794,21 +781,21 @@ auto test_run = [&](double l1, double l2, Softmax loss) { raft::update_host(&preds_sparse_host[0], preds_sparse.data, preds_sparse.len, stream); handle.sync_stream(stream); for (int i = 0; i < N; i++) { - ASSERT_TRUE(compApprox(preds_dense_host[i], preds_sparse_host[i])); + ASSERT_TRUE(compApprox(preds_dense_host[i], preds_sparse_host[i])); } f_dense = run_api(cuml_handle, - QN_LOSS_SOFTMAX, - C, - loss.fit_intercept, - *Xdev, - *ydev, - l1, - l2, - w0_dense.data, - z_dense, - 0, - stream); + QN_LOSS_SOFTMAX, + C, + loss.fit_intercept, + *Xdev, + *ydev, + l1, + l2, + w0_dense.data, + z_dense, + 0, + stream); f_sparse = run_api(cuml_handle, QN_LOSS_SOFTMAX, C, @@ -822,12 +809,12 @@ auto test_run = [&](double l1, double l2, Softmax loss) { 0, stream); ASSERT_TRUE(compApprox(f_dense, f_sparse)); -}; + }; -test_run(alpha, 0.0, loss_b); -test_run(0.0, alpha, loss_b); -test_run(alpha, 0.0, loss_no_b); -test_run(0.0, alpha, loss_no_b); + test_run(alpha, 0.0, loss_b); + test_run(0.0, alpha, loss_b); + test_run(alpha, 0.0, loss_no_b); + test_run(0.0, alpha, loss_no_b); } } // namespace raft::solver::quasi_newton From 31047e638b94316af12d8fcd1b7057bf1f9caf1d Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Thu, 20 Oct 2022 20:41:11 -0400 Subject: [PATCH 28/35] Deprecation warnings --- cpp/include/raft/cluster/single_linkage.cuh | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/cpp/include/raft/cluster/single_linkage.cuh b/cpp/include/raft/cluster/single_linkage.cuh index 36fc812445..0a79601647 100644 --- a/cpp/include/raft/cluster/single_linkage.cuh +++ b/cpp/include/raft/cluster/single_linkage.cuh @@ -102,6 +102,12 @@ void single_linkage(const raft::handle_t& handle, } } // namespace hierarchy +/** + * Note: All of the functions below in the raft::cluster namespace are deprecated + * and will be removed in a future release. Please use raft::cluster::hierarchy + * instead. 
+ */ + /** * Single-linkage clustering, capable of constructing a KNN graph to * scale the algorithm beyond the n^2 memory consumption of implementations From fd8899ce7a2f5afa0066cacf83c919db20a9cee0 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Thu, 20 Oct 2022 21:10:38 -0400 Subject: [PATCH 29/35] Fixing typo --- cpp/include/raft/cluster/kmeans.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/include/raft/cluster/kmeans.cuh b/cpp/include/raft/cluster/kmeans.cuh index 10004b69bf..cfd47d4058 100644 --- a/cpp/include/raft/cluster/kmeans.cuh +++ b/cpp/include/raft/cluster/kmeans.cuh @@ -735,7 +735,7 @@ void sampleCentroids(const raft::handle_t& handle, rmm::device_uvector& inRankCp, rmm::device_uvector& workspace) { - kmeans::sample_entroids( + kmeans::sample_centroids( handle, X, minClusterDistance, isSampleCentroid, select_op, inRankCp, workspace); } From 4968ce88263150065d0d13dead09a0b5b73f85e3 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Thu, 20 Oct 2022 21:32:52 -0400 Subject: [PATCH 30/35] Fixing hierarchical compile error --- cpp/include/raft/cluster/single_linkage.cuh | 106 ++++-------------- .../raft/cluster/single_linkage_types.hpp | 16 ++- cpp/test/CMakeLists.txt | 2 +- cpp/test/{sparse => cluster}/linkage.cu | 21 ++-- 4 files changed, 39 insertions(+), 106 deletions(-) rename cpp/test/{sparse => cluster}/linkage.cu (98%) diff --git a/cpp/include/raft/cluster/single_linkage.cuh b/cpp/include/raft/cluster/single_linkage.cuh index 0a79601647..2d74c364b2 100644 --- a/cpp/include/raft/cluster/single_linkage.cuh +++ b/cpp/include/raft/cluster/single_linkage.cuh @@ -21,8 +21,11 @@ namespace raft::cluster { -namespace hierarchy { -constexpr int DEFAULT_CONST_C = 15; +/** + * Note: All of the functions below in the raft::cluster namespace are deprecated + * and will be removed in a future release. Please use raft::cluster::hierarchy + * instead. + */ /** * Single-linkage clustering, capable of constructing a KNN graph to @@ -59,6 +62,11 @@ void single_linkage(const raft::handle_t& handle, detail::single_linkage( handle, X, m, n, metric, out, c, n_clusters); } +}; // namespace raft::cluster + +namespace raft::cluster::hierarchy { + +constexpr int DEFAULT_CONST_C = 15; /** * Single-linkage clustering, capable of constructing a KNN graph to @@ -91,88 +99,14 @@ void single_linkage(const raft::handle_t& handle, out_arrs.children = dendrogram.data_handle(); out_arrs.labels = labels.data_handle(); - single_linkage(handle, - X.data_handle(), - static_cast(X.extent(0)), - static_cast(X.extent(1)), - metric, - &out_arrs, - c.has_value() ? c.value() : DEFAULT_CONST_C, - n_clusters); + raft::cluster::single_linkage( + handle, + X.data_handle(), + static_cast(X.extent(0)), + static_cast(X.extent(1)), + metric, + &out_arrs, + c.has_value() ? c.value() : DEFAULT_CONST_C, + n_clusters); } -} // namespace hierarchy - -/** - * Note: All of the functions below in the raft::cluster namespace are deprecated - * and will be removed in a future release. Please use raft::cluster::hierarchy - * instead. - */ - -/** - * Single-linkage clustering, capable of constructing a KNN graph to - * scale the algorithm beyond the n^2 memory consumption of implementations - * that use the fully-connected graph of pairwise distances by connecting - * a knn graph when k is not large enough to connect it. 
- - * @tparam value_idx - * @tparam value_t - * @tparam dist_type method to use for constructing connectivities graph - * @param[in] handle raft handle - * @param[in] X dense input matrix in row-major layout - * @param[in] m number of rows in X - * @param[in] n number of columns in X - * @param[in] metric distance metrix to use when constructing connectivities graph - * @param[out] out struct containing output dendrogram and cluster assignments - * @param[in] c a constant used when constructing connectivities from knn graph. Allows the indirect - control - * of k. The algorithm will set `k = log(n) + c` - * @param[in] n_clusters number of clusters to assign data samples - */ -template -void single_linkage(const raft::handle_t& handle, - const value_t* X, - size_t m, - size_t n, - raft::distance::DistanceType metric, - linkage_output* out, - int c, - size_t n_clusters) -{ - hierarchy::single_linkage( - handle, X, m, n, metric, out, c, n_clusters); -} - -/** - * Single-linkage clustering, capable of constructing a KNN graph to - * scale the algorithm beyond the n^2 memory consumption of implementations - * that use the fully-connected graph of pairwise distances by connecting - * a knn graph when k is not large enough to connect it. - - * @tparam value_idx - * @tparam value_t - * @tparam dist_type method to use for constructing connectivities graph - * @param[in] handle raft handle - * @param[in] X dense input matrix in row-major layout - * @param[out] dendrogram output dendrogram (size [n_rows - 1] * 2) - * @param[out] labels output labels vector (size n_rows) - * @param[in] metric distance metrix to use when constructing connectivities graph - * @param[in] n_clusters number of clusters to assign data samples - * @param[in] c a constant used when constructing connectivities from knn graph. Allows the indirect - control of k. The algorithm will set `k = log(n) + c` - */ -template -void single_linkage(const raft::handle_t& handle, - raft::device_matrix_view X, - raft::device_matrix_view dendrogram, - raft::device_vector_view labels, - raft::distance::DistanceType metric, - size_t n_clusters, - std::optional c = std::make_optional(hierarchy::DEFAULT_CONST_C)) -{ - hierarchy::single_linkage( - handle, X, dendrogram, labels, metric, n_clusters, c); -} - -}; // namespace raft::cluster +}; // namespace raft::cluster::hierarchy diff --git a/cpp/include/raft/cluster/single_linkage_types.hpp b/cpp/include/raft/cluster/single_linkage_types.hpp index 28b245a2cf..55239ff6d6 100644 --- a/cpp/include/raft/cluster/single_linkage_types.hpp +++ b/cpp/include/raft/cluster/single_linkage_types.hpp @@ -18,10 +18,15 @@ #include +namespace raft::cluster::hierarchy { +enum LinkageDistance { PAIRWISE = 0, KNN_GRAPH = 1 }; + +}; // end namespace raft::cluster::hierarchy + +// The code below is legacy namespace raft::cluster { -namespace hierarchy { -enum LinkageDistance { PAIRWISE = 0, KNN_GRAPH = 1 }; +using hierarchy::LinkageDistance; /** * Simple POCO for consolidating linkage results. 
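The legacy linkage_output struct here pairs with the deprecated raw-pointer overloads, while the mdspan-based raft::cluster::hierarchy::single_linkage shown earlier in this patch takes matrix and vector views directly. A minimal sketch of the new API, modeled on the updated test in cpp/test/cluster/linkage.cu; the template-parameter order and the dendrogram extent of (n_rows - 1) x 2 are assumptions taken from the surrounding documentation:

@code{.cpp}
#include <raft/cluster/single_linkage.cuh>
#include <raft/core/handle.hpp>

void cluster_rows(raft::handle_t const& handle,
                  raft::device_matrix_view<const float, int, raft::row_major> X,
                  int n_rows,
                  int n_clusters)
{
  auto dendrogram = raft::make_device_matrix<int, int>(handle, n_rows - 1, 2);
  auto labels     = raft::make_device_vector<int, int>(handle, n_rows);

  raft::cluster::hierarchy::
    single_linkage<float, int, raft::cluster::hierarchy::LinkageDistance::KNN_GRAPH>(
      handle,
      X,
      dendrogram.view(),
      labels.view(),
      raft::distance::DistanceType::L2SqrtExpanded,
      n_clusters);
}
@endcode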
This closely @@ -59,11 +64,4 @@ class linkage_output_int : public linkage_output { class linkage_output_int64 : public linkage_output { }; -} // end namespace hierarchy - -using hierarchy::linkage_output; -using hierarchy::linkage_output_int; -using hierarchy::linkage_output_int64; -using hierarchy::LinkageDistance; - }; // namespace raft::cluster \ No newline at end of file diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 07ec85bf1e..0d5af9be5c 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -82,7 +82,7 @@ if(BUILD_TESTS) PATH test/cluster/kmeans.cu test/cluster_solvers.cu - test/sparse/linkage.cu + test/cluster/linkage.cu OPTIONAL DIST NN ) diff --git a/cpp/test/sparse/linkage.cu b/cpp/test/cluster/linkage.cu similarity index 98% rename from cpp/test/sparse/linkage.cu rename to cpp/test/cluster/linkage.cu index ce5741d06b..5533f552bd 100644 --- a/cpp/test/sparse/linkage.cu +++ b/cpp/test/cluster/linkage.cu @@ -180,20 +180,21 @@ class LinkageTest : public ::testing::TestWithParam> { raft::handle_t handle; - auto data_view = - raft::make_device_matrix_view(data.data(), params.n_row, params.n_col); + auto data_view = raft::make_device_matrix_view( + data.data(), params.n_row, params.n_col); auto dendrogram_view = raft::make_device_matrix_view(out_children.data(), params.n_row, 2); auto labels_view = raft::make_device_vector_view(labels.data(), params.n_row); - raft::cluster::single_linkage( - handle, - data_view, - dendrogram_view, - labels_view, - raft::distance::DistanceType::L2SqrtExpanded, - params.n_clusters, - std::make_optional(params.c)); + raft::cluster::hierarchy:: + single_linkage( + handle, + data_view, + dendrogram_view, + labels_view, + raft::distance::DistanceType::L2SqrtExpanded, + params.n_clusters, + std::make_optional(params.c)); handle.sync_stream(stream); From 6ccf61fbac6ae3d6eb6a098b9eb1e7ea4c77d27e Mon Sep 17 00:00:00 2001 From: "Corey J. 
Nolet" Date: Thu, 20 Oct 2022 22:03:46 -0400 Subject: [PATCH 31/35] Removing namespace conflict --- cpp/include/raft/cluster/kmeans.cuh | 4 +- cpp/include/raft/cluster/kmeans_types.hpp | 8 ++-- cpp/test/cluster_solvers_deprecated.cu | 48 ----------------------- 3 files changed, 6 insertions(+), 54 deletions(-) diff --git a/cpp/include/raft/cluster/kmeans.cuh b/cpp/include/raft/cluster/kmeans.cuh index cfd47d4058..2025a15ecf 100644 --- a/cpp/include/raft/cluster/kmeans.cuh +++ b/cpp/include/raft/cluster/kmeans.cuh @@ -489,7 +489,7 @@ void fit_main(const raft::handle_t& handle, handle, params, X, weight, centroidsRawData, inertia, n_iter, workspace); } -} // end namespace raft::cluster::kmeans +}; // end namespace raft::cluster::kmeans namespace raft::cluster { @@ -962,4 +962,4 @@ void kmeans_fit_main(const raft::handle_t& handle, kmeans::fit_main( handle, params, X, weight, centroidsRawData, inertia, n_iter, workspace); } -} // namespace raft::cluster +}; // namespace raft::cluster diff --git a/cpp/include/raft/cluster/kmeans_types.hpp b/cpp/include/raft/cluster/kmeans_types.hpp index bb8e1a2b73..d6eadd1ba6 100644 --- a/cpp/include/raft/cluster/kmeans_types.hpp +++ b/cpp/include/raft/cluster/kmeans_types.hpp @@ -18,9 +18,7 @@ #include #include -namespace raft::cluster { - -namespace kmeans { +namespace raft::cluster::kmeans { struct KMeansParams { enum InitMethod { KMeansPlusPlus, Random, Array }; @@ -71,7 +69,9 @@ struct KMeansParams { bool inertia_check = false; }; -} // namespace kmeans +} // namespace raft::cluster::kmeans + +namespace raft::cluster { using kmeans::KMeansParams; diff --git a/cpp/test/cluster_solvers_deprecated.cu b/cpp/test/cluster_solvers_deprecated.cu index 1e9ec0c15b..167a710b34 100644 --- a/cpp/test/cluster_solvers_deprecated.cu +++ b/cpp/test/cluster_solvers_deprecated.cu @@ -20,7 +20,6 @@ #include #include -#include namespace raft { namespace spectral { @@ -54,52 +53,5 @@ TEST(Raft, ClusterSolvers) EXPECT_ANY_THROW(cluster_solver.solve(h, n, d, eigvecs, codes)); } -TEST(Raft, ModularitySolvers) -{ - using namespace matrix; - using index_type = int; - using value_type = double; - - handle_t h; - ASSERT_EQ(0, - h. - - get_device() - - ); - - index_type neigvs{10}; - index_type maxiter{100}; - index_type restart_iter{10}; - value_type tol{1.0e-10}; - bool reorthog{true}; - - // nullptr expected to trigger exceptions: - // - index_type* clusters{nullptr}; - value_type* eigvals{nullptr}; - value_type* eigvecs{nullptr}; - - unsigned long long seed{100110021003}; - - eigen_solver_config_t eig_cfg{ - neigvs, maxiter, restart_iter, tol, reorthog, seed}; - lanczos_solver_t eig_solver{eig_cfg}; - - index_type k{5}; - - cluster_solver_config_deprecated_t clust_cfg{k, maxiter, tol, seed}; - kmeans_solver_deprecated_t cluster_solver{clust_cfg}; - - auto stream = h.get_stream(); - sparse_matrix_t sm{h, nullptr, nullptr, nullptr, 0, 0}; - - EXPECT_ANY_THROW(spectral::modularity_maximization( - h, sm, eig_solver, cluster_solver, clusters, eigvals, eigvecs)); - - value_type modularity{0}; - EXPECT_ANY_THROW(spectral::analyzeModularity(h, sm, k, clusters, modularity)); -} - } // namespace spectral } // namespace raft From a36f26e8e378856946578258e7b7a40e56628e97 Mon Sep 17 00:00:00 2001 From: "Corey J. 
Nolet" Date: Fri, 21 Oct 2022 18:41:31 -0400 Subject: [PATCH 32/35] Many many updates to the docs, including a quick-start --- build.sh | 44 ++++- cpp/doxygen/Doxyfile.in | 4 +- cpp/include/raft/cluster/kmeans.cuh | 161 ++++++++++-------- cpp/include/raft/cluster/kmeans_params.hpp | 11 +- cpp/include/raft/cluster/kmeans_types.hpp | 66 +++++-- .../raft/cluster/single_linkage_types.hpp | 24 ++- cpp/include/raft/neighbors/ball_cover.cuh | 60 +++++++ cpp/include/raft/neighbors/brute_force.cuh | 91 +++++++--- .../raft/neighbors/epsilon_neighborhood.cuh | 17 +- cpp/include/raft/neighbors/ivf_flat.cuh | 102 ++++++----- cpp/include/raft/neighbors/ivf_pq.cuh | 2 +- cpp/include/raft/solver/linear_assignment.cuh | 49 +++++- cpp/include/raft/sparse/solver/mst.cuh | 23 +++ cpp/include/raft/spatial/knn/ivf_flat.cuh | 1 - .../raft/stats/adjusted_rand_index.cuh | 8 +- cpp/include/raft/stats/common.hpp | 59 +------ cpp/include/raft/stats/detail/histogram.cuh | 6 + cpp/include/raft/stats/histogram.cuh | 8 + cpp/include/raft/stats/stats_types.hpp | 62 +++++++ cpp/test/neighbors/ann_ivf_flat.cu | 2 +- cpp/test/neighbors/knn.cu | 3 +- docs/source/cpp_api.rst | 7 +- docs/source/cpp_api/cluster.rst | 3 +- docs/source/cpp_api/core.rst | 75 ++++++-- docs/source/cpp_api/distance.rst | 3 +- docs/source/cpp_api/linalg.rst | 5 +- docs/source/cpp_api/matrix.rst | 3 +- docs/source/cpp_api/solver.rst | 6 +- docs/source/cpp_api/sparse.rst | 15 +- docs/source/index.rst | 37 +++- docs/source/quick_start.md | 129 ++++++++++++++ 31 files changed, 811 insertions(+), 275 deletions(-) create mode 100644 cpp/include/raft/stats/stats_types.hpp create mode 100644 docs/source/quick_start.md diff --git a/build.sh b/build.sh index 9548fbec44..61e6d1a007 100755 --- a/build.sh +++ b/build.sh @@ -227,18 +227,50 @@ fi if hasArg tests || (( ${NUMARGS} == 0 )); then BUILD_TESTS=ON - COMPILE_DIST_LIBRARY=ON - ENABLE_NN_DEPENDENCIES=ON - COMPILE_NN_LIBRARY=ON CMAKE_TARGET="${CMAKE_TARGET};${TEST_TARGETS}" + + # Force compile nn library when needed test targets are specified + if [[ $CMAKE_TARGET == *"CLUSTER_TEST"* || \ + $CMAKE_TARGET == *"SPARSE_DIST_TEST"* || \ + $CMAKE_TARGET == *"SPARSE_NEIGHBORS_TEST"* || \ + $CMAKE_TARGET == *"NEIGHBORS_TEST"* || \ + $CMAKE_TARGET == *"STATS_TEST"* ]]; then + echo "-- Enabling nearest neighbors lib for gtests" + ENABLE_NN_DEPENDENCIES=ON + COMPILE_NN_LIBRARY=ON + fi + + # Force compile distance library when needed test targets are specified + if [[ $CMAKE_TARGET == *"CLUSTER_TEST"* || \ + $CMAKE_TARGET == *"DISTANCE_TEST"* || \ + $CMAKE_TARGET == *"SPARSE_DIST_TEST" || \ + $CMAKE_TARGET == *"SPARSE_NEIGHBORS_TEST"* || \ + $CMAKE_TARGET == *"NEIGHBORS_TEST" || \ + $CMAKE_TARGET == *"STATS_TEST"* ]]; then + echo "-- Enabling distance lib for gtests" + COMPILE_DIST_LIBRARY=ON + fi fi if hasArg bench || (( ${NUMARGS} == 0 )); then BUILD_BENCH=ON - COMPILE_DIST_LIBRARY=ON - ENABLE_NN_DEPENDENCIES=ON - COMPILE_NN_LIBRARY=ON CMAKE_TARGET="${CMAKE_TARGET};${BENCH_TARGETS}" + + # Force compile nn library when needed benchmark targets are specified + if [[ $CMAKE_TARGET == *"CLUSTER_BENCH"* || \ + $CMAKE_TARGET == *"NEIGHBORS_BENCH"* ]]; then + echo "-- Enabling nearest neighbors lib for benchmarks" + ENABLE_NN_DEPENDENCIES=ON + COMPILE_NN_LIBRARY=ON + fi + + # Force compile distance library when needed benchmark targets are specified + if [[ $CMAKE_TARGET == *"CLUSTER_BENCH"* || \ + $CMAKE_TARGET == *"NEIGHBORS_BENCH"* ]]; then + echo "-- Enabling distance lib for benchmarks" + COMPILE_DIST_LIBRARY=ON + fi + fi 
if hasArg --buildfaiss; then diff --git a/cpp/doxygen/Doxyfile.in b/cpp/doxygen/Doxyfile.in index 5517562a9f..07056e503d 100644 --- a/cpp/doxygen/Doxyfile.in +++ b/cpp/doxygen/Doxyfile.in @@ -900,7 +900,9 @@ EXCLUDE = @CMAKE_CURRENT_SOURCE_DIR@/include/raft/sparse/linalg/s @CMAKE_CURRENT_SOURCE_DIR@/include/raft/span.hpp \ @CMAKE_CURRENT_SOURCE_DIR@/include/raft/vectorized.cuh \ @CMAKE_CURRENT_SOURCE_DIR@/include/raft/raft.hpp \ - @CMAKE_CURRENT_SOURCE_DIR@/include/raft/core/cudart_utils.hpp + @CMAKE_CURRENT_SOURCE_DIR@/include/raft/core/cudart_utils.hpp \ + @CMAKE_CURRENT_SOURCE_DIR@/include/raft/matrix/math.cuh \ + @CMAKE_CURRENT_SOURCE_DIR@/include/raft/matrix/matrix.cuh # The EXCLUDE_SYMLINKS tag can be used to select whether or not files or # directories that are symbolic links (a Unix file system feature) are excluded diff --git a/cpp/include/raft/cluster/kmeans.cuh b/cpp/include/raft/cluster/kmeans.cuh index 2025a15ecf..a85fd1b38b 100644 --- a/cpp/include/raft/cluster/kmeans.cuh +++ b/cpp/include/raft/cluster/kmeans.cuh @@ -28,6 +28,27 @@ namespace raft::cluster::kmeans { * Initial centroids are chosen with k-means++ algorithm. Empty * clusters are reinitialized by choosing new centroids with * k-means++ algorithm. + * + * @code{.cpp} + * #include + * #include + * #include + * using namespace raft::cluster; + * ... + * raft::handle_t handle; + * raft::cluster::KMeansParams params; + * int n_features = 15, inertia, n_iter; + * auto centroids = raft::make_device_matrix(handle, params.n_clusters, n_features); + * + * kmeans::fit(handle, + * params, + * X, + * std::nullopt, + * centroids, + * raft::make_scalar_view(&inertia), + * raft::make_scalar_view(&n_iter)); + * @endcode + * * @tparam DataT the type of data used for weights, distances. * @tparam IndexT the type of data used for indexing. * @param[in] handle The raft handle. @@ -47,7 +68,7 @@ namespace raft::cluster::kmeans { * closest cluster center. * @param[out] n_iter Number of iterations run. */ -template +template void fit(handle_t const& handle, const KMeansParams& params, raft::device_matrix_view X, @@ -59,23 +80,40 @@ void fit(handle_t const& handle, detail::kmeans_fit(handle, params, X, sample_weight, centroids, inertia, n_iter); } -template -void fit(handle_t const& handle, - const KMeansParams& params, - const DataT* X, - const DataT* sample_weight, - DataT* centroids, - IndexT n_samples, - IndexT n_features, - DataT& inertia, - IndexT& n_iter) -{ - detail::kmeans_fit( - handle, params, X, sample_weight, centroids, n_samples, n_features, inertia, n_iter); -} - /** * @brief Predict the closest cluster each sample in X belongs to. + * + * @code{.cpp} + * #include + * #include + * #include + * using namespace raft::cluster; + * ... + * raft::handle_t handle; + * raft::cluster::KMeansParams params; + * int n_features = 15, inertia, n_iter; + * auto centroids = raft::make_device_matrix(handle, params.n_clusters, n_features); + * + * kmeans::fit(handle, + * params, + * X, + * std::nullopt, + * centroids.view(), + * raft::make_scalar_view(&inertia), + * raft::make_scalar_view(&n_iter)); + * ... + * auto labels = raft::make_device_vector(handle, X.extent(0)); + * + * kmeans::predict(handle, + * params, + * X, + * std::nullopt, + * centroids.view(), + * false, + * labels.view(), + * raft::make_scalar_view(&ineratia)); + * @endcode + * * @tparam DataT the type of data used for weights, distances. * @tparam IndexT the type of data used for indexing. * @param[in] handle The raft handle. 
@@ -94,7 +132,7 @@ void fit(handle_t const& handle, * @param[out] inertia Sum of squared distances of samples to * their closest cluster center. */ -template +template void predict(handle_t const& handle, const KMeansParams& params, raft::device_matrix_view X, @@ -108,34 +146,32 @@ void predict(handle_t const& handle, handle, params, X, sample_weight, centroids, labels, normalize_weight, inertia); } -template -void predict(handle_t const& handle, - const KMeansParams& params, - const DataT* X, - const DataT* sample_weight, - const DataT* centroids, - IndexT n_samples, - IndexT n_features, - IndexT* labels, - bool normalize_weight, - DataT& inertia) -{ - detail::kmeans_predict(handle, - params, - X, - sample_weight, - centroids, - n_samples, - n_features, - labels, - normalize_weight, - inertia); -} - /** * @brief Compute k-means clustering and predicts cluster index for each sample * in the input. * + * @code{.cpp} + * #include + * #include + * #include + * using namespace raft::cluster; + * ... + * raft::handle_t handle; + * raft::cluster::KMeansParams params; + * int n_features = 15, inertia, n_iter; + * auto centroids = raft::make_device_matrix(handle, params.n_clusters, n_features); + * auto labels = raft::make_device_vector(handle, X.extent(0)); + * + * kmeans::fit_predict(handle, + * params, + * X, + * std::nullopt, + * centroids.view(), + * labels.view(), + * raft::make_scalar_view(&inertia), + * raft::make_scalar_view(&n_iter)); + * @endcode + * * @tparam DataT the type of data used for weights, distances. * @tparam IndexT the type of data used for indexing. * @param[in] handle The raft handle. @@ -159,7 +195,7 @@ void predict(handle_t const& handle, * closest cluster center. * @param[out] n_iter Number of iterations run. */ -template +template void fit_predict(handle_t const& handle, const KMeansParams& params, raft::device_matrix_view X, @@ -173,22 +209,6 @@ void fit_predict(handle_t const& handle, handle, params, X, sample_weight, centroids, labels, inertia, n_iter); } -template -void fit_predict(handle_t const& handle, - const KMeansParams& params, - const DataT* X, - const DataT* sample_weight, - DataT* centroids, - IndexT n_samples, - IndexT n_features, - IndexT* labels, - DataT& inertia, - IndexT& n_iter) -{ - detail::kmeans_fit_predict( - handle, params, X, sample_weight, centroids, n_samples, n_features, labels, inertia, n_iter); -} - /** * @brief Transform X to a cluster-distance space. * @@ -204,7 +224,7 @@ void fit_predict(handle_t const& handle, * @param[out] X_new X transformed in the new space. 
* [dim = n_samples x n_features] */ -template +template void transform(const raft::handle_t& handle, const KMeansParams& params, raft::device_matrix_view X, @@ -214,7 +234,7 @@ void transform(const raft::handle_t& handle, detail::kmeans_transform(handle, params, X, centroids, X_new); } -template +template void transform(const raft::handle_t& handle, const KMeansParams& params, const DataT* X, @@ -227,7 +247,7 @@ void transform(const raft::handle_t& handle, handle, params, X, centroids, n_samples, n_features, X_new); } -template +template using SamplingOp = detail::SamplingOp; template @@ -252,7 +272,7 @@ using KeyValueIndexOp = detail::KeyValueIndexOp; * @param[in] workspace Temporary workspace buffer which can get resized * */ -template +template void sample_centroids(const raft::handle_t& handle, raft::device_matrix_view X, raft::device_vector_view minClusterDistance, @@ -279,7 +299,7 @@ void sample_centroids(const raft::handle_t& handle, * @param[in] reduction_op The reduction operation used for the cost * */ -template +template void cluster_cost(const raft::handle_t& handle, raft::device_vector_view minClusterDistance, rmm::device_uvector workspace, @@ -424,11 +444,10 @@ void count_samples_in_cluster(const raft::handle_t& handle, handle, params, X, L2NormX, centroids, workspace, sampleCountInCluster); } -/* +/** * @brief Selects 'n_clusters' samples from the input X using kmeans++ algorithm. - - * @note This is the algorithm described in - * "k-means++: the advantages of careful seeding". 2007, Arthur, D. and Vassilvitskii, S. + * + * @see "k-means++: the advantages of careful seeding". 2007, Arthur, D. and Vassilvitskii, S. * ACM-SIAM symposium on Discrete algorithms. * * @tparam DataT the type of data used for weights, distances. @@ -446,10 +465,10 @@ template void init_plus_plus(const raft::handle_t& handle, const KMeansParams& params, raft::device_matrix_view X, - raft::device_matrix_view centroidsRawData, + raft::device_matrix_view centroids, rmm::device_uvector& workspace) { - detail::kmeansPlusPlus(handle, params, X, centroidsRawData, workspace); + detail::kmeansPlusPlus(handle, params, X, centroids, workspace); } /* @@ -480,13 +499,13 @@ void fit_main(const raft::handle_t& handle, const KMeansParams& params, raft::device_matrix_view X, raft::device_vector_view weight, - raft::device_matrix_view centroidsRawData, + raft::device_matrix_view centroids, raft::host_scalar_view inertia, raft::host_scalar_view n_iter, rmm::device_uvector& workspace) { detail::kmeans_fit_main( - handle, params, X, weight, centroidsRawData, inertia, n_iter, workspace); + handle, params, X, weight, centroids, inertia, n_iter, workspace); } }; // end namespace raft::cluster::kmeans diff --git a/cpp/include/raft/cluster/kmeans_params.hpp b/cpp/include/raft/cluster/kmeans_params.hpp index 433e32f5ff..a1532d9dd4 100644 --- a/cpp/include/raft/cluster/kmeans_params.hpp +++ b/cpp/include/raft/cluster/kmeans_params.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. + * Copyright (c) 2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,15 +13,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -/** - * This file is deprecated and will be removed in release 22.06. - * Please use the cuh version instead. 
- */ - -/** - * DISCLAIMER: this file is deprecated: use lap.cuh instead - */ - #pragma once #pragma message(__FILE__ \ diff --git a/cpp/include/raft/cluster/kmeans_types.hpp b/cpp/include/raft/cluster/kmeans_types.hpp index d6eadd1ba6..f411b12b5c 100644 --- a/cpp/include/raft/cluster/kmeans_types.hpp +++ b/cpp/include/raft/cluster/kmeans_types.hpp @@ -20,14 +20,34 @@ namespace raft::cluster::kmeans { +/** + * Simple object to specify hyper-parameters to the kmeans algorithm. + */ struct KMeansParams { - enum InitMethod { KMeansPlusPlus, Random, Array }; - - // The number of clusters to form as well as the number of centroids to - // generate (default:8). + enum InitMethod { + + /** + * Sample the centroids using the kmeans++ strategy + */ + KMeansPlusPlus, + + /** + * Sample the centroids uniformly at random + */ + Random, + + /** + * User provides the array of initial centroids + */ + Array + }; + + /** + * The number of clusters to form as well as the number of centroids to generate (default:8). + */ int n_clusters = 8; - /* + /** * Method for initialization, defaults to k-means++: * - InitMethod::KMeansPlusPlus (k-means++): Use scalable k-means++ algorithm * to select the initial cluster centers. @@ -37,34 +57,52 @@ struct KMeansParams { */ InitMethod init = KMeansPlusPlus; - // Maximum number of iterations of the k-means algorithm for a single run. + /** + * Maximum number of iterations of the k-means algorithm for a single run. + */ int max_iter = 300; - // Relative tolerance with regards to inertia to declare convergence. + /** + * Relative tolerance with regards to inertia to declare convergence. + */ double tol = 1e-4; - // verbosity level. + /** + * verbosity level. + */ int verbosity = RAFT_LEVEL_INFO; - // Seed to the random number generator. + /** + * Seed to the random number generator. + */ raft::random::RngState rng_state = raft::random::RngState(0, raft::random::GeneratorType::GenPhilox); - // Metric to use for distance computation. + /** + * Metric to use for distance computation. + */ raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded; - // Number of instance k-means algorithm will be run with different seeds. + /** + * Number of instance k-means algorithm will be run with different seeds. + */ int n_init = 1; - // Oversampling factor for use in the k-means|| algorithm. + /** + * Oversampling factor for use in the k-means|| algorithm + */ double oversampling_factor = 2.0; // batch_samples and batch_centroids are used to tile 1NN computation which is // useful to optimize/control the memory footprint // Default tile is [batch_samples x n_clusters] i.e. when batch_centroids is 0 // then don't tile the centroids - int batch_samples = 1 << 15; - int batch_centroids = 0; // if 0 then batch_centroids = n_clusters + int batch_samples = 1 << 15; + + /** + * if 0 then batch_centroids = n_clusters + */ + int batch_centroids = 0; // bool inertia_check = false; }; diff --git a/cpp/include/raft/cluster/single_linkage_types.hpp b/cpp/include/raft/cluster/single_linkage_types.hpp index 55239ff6d6..c8d6e4c8a6 100644 --- a/cpp/include/raft/cluster/single_linkage_types.hpp +++ b/cpp/include/raft/cluster/single_linkage_types.hpp @@ -19,17 +19,35 @@ #include namespace raft::cluster::hierarchy { -enum LinkageDistance { PAIRWISE = 0, KNN_GRAPH = 1 }; + +/** + * Determines the method for computing the minimum spanning tree (MST) + */ +enum LinkageDistance { + + /** + * Use a pairwise distance matrix as input to the mst. 
This + * is very fast and the best option for fairly small datasets (~50k data points) + */ + PAIRWISE = 0, + + /** + * Construct a KNN graph as input to the mst and provide additional + * edges if the mst does not converge. This is slower but scales + * to very large datasets. + */ + KNN_GRAPH = 1 +}; }; // end namespace raft::cluster::hierarchy -// The code below is legacy +// The code below is now considered legacy namespace raft::cluster { using hierarchy::LinkageDistance; /** - * Simple POCO for consolidating linkage results. This closely + * Simple container object for consolidating linkage results. This closely * mirrors the trained instance variables populated in * Scikit-learn's AgglomerativeClustering estimator. * @tparam value_idx diff --git a/cpp/include/raft/neighbors/ball_cover.cuh b/cpp/include/raft/neighbors/ball_cover.cuh index 780a9cfce2..28ff8491b6 100644 --- a/cpp/include/raft/neighbors/ball_cover.cuh +++ b/cpp/include/raft/neighbors/ball_cover.cuh @@ -30,6 +30,23 @@ namespace raft::neighbors::ball_cover { /** * Builds and populates a previously unbuilt BallCoverIndex + * + * Usage example: + * @code{.cpp} + * + * #include + * #include + * #include + * using namespace raft::neighbors; + * + * raft::handle_t handle; + * ... + * auto metric = raft::distance::DistanceType::L2Expanded; + * BallCoverIndex index(handle, X, metric); + * + * ball_cover::build_index(handle, index); + * @endcode + * * @tparam idx_t knn index type * @tparam value_t knn value type * @tparam int_t integral type for knn params @@ -130,10 +147,31 @@ void all_knn_query(const raft::handle_t& handle, * the index and query are the same array. This function will * build the index and assumes rbc_build_index() has not already * been called. + * + * Usage example: + * @code{.cpp} + * + * #include + * #include + * #include + * using namespace raft::neighbors; + * + * raft::handle_t handle; + * ... + * auto metric = raft::distance::DistanceType::L2Expanded; + * + * // Construct a ball cover index + * BallCoverIndex index(handle, X, metric); + * + * // Perform all neighbors knn query + * ball_cover::all_knn_query(handle, index, inds, dists, k); + * @endcode + * * @tparam idx_t knn index type * @tparam value_t knn distance type * @tparam int_t type for integers, such as number of rows/cols * @tparam matrix_idx_t matrix indexing type + * * @param[in] handle raft handle for resource management * @param[in] index ball cover index which has not yet been built * @param[out] inds output knn indices @@ -250,6 +288,28 @@ void knn_query(const raft::handle_t& handle, * function does not build the index and assumes rbc_build_index() has * already been called. Use this function when the index and * query arrays are different, otherwise use rbc_all_knn_query(). + * + * Usage example: + * @code{.cpp} + * + * #include + * #include + * #include + * using namespace raft::neighbors; + * + * raft::handle_t handle; + * ... 
+ * auto metric = raft::distance::DistanceType::L2Expanded; + * + * // Build a ball cover index + * BallCoverIndex index(handle, X, metric); + * ball_cover::build_index(handle, index); + * + * // Perform all neighbors knn query + * ball_cover::knn_query(handle, index, inds, dists, k); + * @endcode + + * * @tparam idx_t index type * @tparam value_t distances type * @tparam int_t integer type for size info diff --git a/cpp/include/raft/neighbors/brute_force.cuh b/cpp/include/raft/neighbors/brute_force.cuh index 3641a38991..772ccb67d2 100644 --- a/cpp/include/raft/neighbors/brute_force.cuh +++ b/cpp/include/raft/neighbors/brute_force.cuh @@ -23,26 +23,52 @@ namespace raft::neighbors::brute_force { /** - * @brief Performs a k-select across row partitioned index/distance + * @brief Performs a k-select across several (contiguous) row-partitioned index/distance * matrices formatted like the following: - * row1: k0, k1, k2 - * row2: k0, k1, k2 - * row3: k0, k1, k2 - * row1: k0, k1, k2 - * row2: k0, k1, k2 - * row3: k0, k1, k2 * + * part1row1: k0, k1, k2, k3 + * part1row2: k0, k1, k2, k3 + * part1row3: k0, k1, k2, k3 + * part2row1: k0, k1, k2, k3 + * part2row2: k0, k1, k2, k3 + * part2row3: k0, k1, k2, k3 * etc... * + * The example above shows what an aggregated index/distance matrix + * would look like with two partitions when n_samples=3 and k=4. + * + * When working with extremely large data sets that have been broken + * over multiple indexes, such as when computing over multiple GPUs, + * the ids will often start at 0 for each local knn index but the + * global ids need to be used when merging them together. An optional + * translations vector can be supplied to map the starting id of + * each partition to its global id so that the final merged knn + * is based on the global ids. + * + * Usage example: + * @code{.cpp} + * #include + * #include + * using namespace raft::neighbors; + * + * raft::handle_t handle; + * ... + * compute multiple knn graphs and aggregate row-wise + * (see detailed description above) + * ... + * brute_force::knn_merge_parts(handle, in_keys, in_values, out_keys, out_values, n_samples); + * @endcode + * * @tparam idx_t * @tparam value_t + * * @param[in] handle * @param[in] in_keys matrix of input keys (size n_samples * n_parts * k) * @param[in] in_values matrix of input values (size n_samples * n_parts * k) * @param[out] out_keys matrix of output keys (size n_samples * k) * @param[out] out_values matrix of output values (size n_samples * k) - * @param[in] n_samples number of rows in each part - * @param[in] translations optional vector of starting index mappings for each partition + * @param[in] n_samples number of rows in each partition + * @param[in] translations optional vector of starting global id mappings for each local partition */ template inline void knn_merge_parts( @@ -81,17 +107,31 @@ inline void knn_merge_parts( * row- or column-major but the output matrices will always be in * row-major format. * - * @param[in] handle the cuml handle to use - * @param[in] index vector of device matrices (each size m_i*d) to be used as the knn index - * @param[in] search matrix (size n*d) to be used for searching the index - * @param[out] indices matrix (size n*k) to store output knn indices - * @param[out] distances matrix (size n*k) to store the output knn distance - * @param[in] k the number of nearest neighbors to return - * @param[in] metric distance metric to use. 
Euclidean (L2) is used by default - * @param[in] metric_arg the value of `p` for Minkowski (l-p) distances. This + * Usage example: + * @code{.cpp} + * #include + * #include + * #include + * using namespace raft::neighbors; + * + * raft::handle_t handle; + * ... + * int k = 10; + * auto metric = raft::distance::DistanceType::L2SqrtExpanded; + * brute_force::knn(handle, index, search, indices, distances, k, metric); + * @endcode + * + * @param[in] handle: the cuml handle to use + * @param[in] index: vector of device matrices (each size m_i*d) to be used as the knn index + * @param[in] search: matrix (size n*d) to be used for searching the index + * @param[out] indices: matrix (size n*k) to store output knn indices + * @param[out] distances: matrix (size n*k) to store the output knn distance + * @param[in] k: the number of nearest neighbors to return + * @param[in] metric: distance metric to use. Euclidean (L2) is used by default + * @param[in] metric_arg: the value of `p` for Minkowski (l-p) distances. This * is ignored if the metric_type is not Minkowski. - * @param[in] translations starting offsets for partitions. should be the same size - * as input vector. + * @param[in] global_id_offset: optional starting global id mapping for the local partition + * (assumes the index contains contiguous ids in the global id space) */ template indices, raft::device_matrix_view distances, value_int k, - distance::DistanceType metric = distance::DistanceType::L2Unexpanded, - std::optional metric_arg = std::make_optional(2.0f), - std::optional> translations = std::nullopt) + distance::DistanceType metric = distance::DistanceType::L2Unexpanded, + std::optional metric_arg = std::make_optional(2.0f), + std::optional global_id_offset = std::nullopt) { RAFT_EXPECTS(index[0].extent(1) == search.extent(1), "Number of dimensions for both index and search matrices must be equal"); @@ -129,7 +169,10 @@ void knn(raft::handle_t const& handle, sizes.push_back(index[i].extent(0)); } - std::vector* trans = translations.has_value() ? &(*translations) : nullptr; + std::vector trans; + if (global_id_offset.has_value()) { trans.push_back(global_id_offset.value()); } + + std::vector* trans_arg = global_id_offset.has_value() ? &trans : nullptr; raft::spatial::knn::detail::brute_force_knn_impl(handle, inputs, @@ -143,7 +186,7 @@ void knn(raft::handle_t const& handle, k, rowMajorIndex, rowMajorQuery, - trans, + trans_arg, metric, metric_arg.value_or(2.0f)); } diff --git a/cpp/include/raft/neighbors/epsilon_neighborhood.cuh b/cpp/include/raft/neighbors/epsilon_neighborhood.cuh index b0e9b842ec..114216fc50 100644 --- a/cpp/include/raft/neighbors/epsilon_neighborhood.cuh +++ b/cpp/include/raft/neighbors/epsilon_neighborhood.cuh @@ -60,7 +60,22 @@ void epsUnexpL2SqNeighborhood(bool* adj, } /** - * @brief Computes epsilon neighborhood for the L2-Squared distance metric + * @brief Computes epsilon neighborhood for the L2-Squared distance metric and given ball size. + * The epsilon neighbors is represented by a dense boolean adjacency matrix of size m * n and + * an array of degrees for each vertex, which can be used as a compressed sparse row (CSR) + * indptr array. + * + * @code{.cpp} + * #include + * #include + * #include + * using namespace raft::neighbors; + * raft::handle_t handle; + * ... 
+ * auto adj = raft::make_device_matrix(handle, m * n); + * auto vd = raft::make_device_vector(handle, m+1); + * epsilon_neighborhood::eps_neighbors_l2sq(handle, x, y, adj.view(), vd.view(), eps); + * @endcode * * @tparam value_t IO and math type * @tparam idx_t Index type diff --git a/cpp/include/raft/neighbors/ivf_flat.cuh b/cpp/include/raft/neighbors/ivf_flat.cuh index 23ae6c42bf..87400a9b93 100644 --- a/cpp/include/raft/neighbors/ivf_flat.cuh +++ b/cpp/include/raft/neighbors/ivf_flat.cuh @@ -38,7 +38,7 @@ namespace raft::neighbors::ivf_flat { * * Usage example: * @code{.cpp} - * using namespace raft::spatial::knn; + * using namespace raft::neighbors; * // use default index parameters * ivf_flat::index_params index_params; * // create and fill the index from a [N, D] dataset @@ -61,7 +61,7 @@ namespace raft::neighbors::ivf_flat { * @return the constructed ivf-flat index */ template -inline auto build( +auto build( const handle_t& handle, const index_params& params, const T* dataset, IdxT n_rows, uint32_t dim) -> index { @@ -78,15 +78,15 @@ inline auto build( * * Usage example: * @code{.cpp} - * using namespace raft::spatial::knn; + * using namespace raft::neighbors; * // use default index parameters * ivf_flat::index_params index_params; * // create and fill the index from a [N, D] dataset - * auto index = ivf_flat::build(handle, index_params, dataset, N, D); + * auto index = ivf_flat::build(handle, dataset, index_params); * // use default search parameters * ivf_flat::search_params search_params; * // search K nearest neighbours for each of the N queries - * ivf_flat::search(handle, search_params, index, queries, N, K, out_inds, out_dists); + * ivf_flat::search(handle, index, queries, out_inds, out_dists, search_params, k); * @endcode * * @tparam value_t data element type @@ -101,9 +101,9 @@ inline auto build( * @return the constructed ivf-flat index */ template -auto build_index(const handle_t& handle, - raft::device_matrix_view dataset, - const index_params& params) -> index +auto build(const handle_t& handle, + raft::device_matrix_view dataset, + const index_params& params) -> index { return raft::spatial::knn::ivf_flat::detail::build(handle, params, @@ -145,11 +145,11 @@ auto build_index(const handle_t& handle, * @return the constructed extended ivf-flat index */ template -inline auto extend(const handle_t& handle, - const index& orig_index, - const T* new_vectors, - const IdxT* new_indices, - IdxT n_rows) -> index +auto extend(const handle_t& handle, + const index& orig_index, + const T* new_vectors, + const IdxT* new_indices, + IdxT n_rows) -> index { return raft::spatial::knn::ivf_flat::detail::extend( handle, orig_index, new_vectors, new_indices, n_rows); @@ -169,9 +169,9 @@ inline auto extend(const handle_t& handle, * index_params.add_data_on_build = false; // don't populate index on build * index_params.kmeans_trainset_fraction = 1.0; // use whole dataset for kmeans training * // train the index from a [N, D] dataset - * auto index_empty = ivf_flat::build(handle, index_params, dataset, N, D); + * auto index_empty = ivf_flat::build(handle, dataset, index_params, dataset); * // fill the index with the data - * auto index = ivf_flat::extend(handle, index_empty, dataset, nullptr, N); + * auto index = ivf_flat::extend(handle, index_empty, dataset); * @endcode * * @tparam value_t data element type @@ -204,8 +204,20 @@ auto extend(const handle_t& handle, } /** - * @brief Extend the index with the new data. - * * + * @brief Extend the index in-place with the new data. 
+ * + * Usage example: + * @code{.cpp} + * using namespace raft::spatial::knn; + * ivf_flat::index_params index_params; + * index_params.add_data_on_build = false; // don't populate index on build + * index_params.kmeans_trainset_fraction = 1.0; // use whole dataset for kmeans training + * // train the index from a [N, D] dataset + * auto index_empty = ivf_flat::build(handle, index_params, dataset, N, D); + * // fill the index with the data + * ivf_flat::extend(handle, index_empty, dataset, nullptr, N); + * @endcode + * * @tparam T data element type * @tparam IdxT type of the indices in the source dataset * @@ -218,18 +230,30 @@ auto extend(const handle_t& handle, * @param[in] n_rows the number of samples */ template -inline void extend(const handle_t& handle, - index* index, - const T* new_vectors, - const IdxT* new_indices, - IdxT n_rows) +void extend(const handle_t& handle, + index* index, + const T* new_vectors, + const IdxT* new_indices, + IdxT n_rows) { *index = extend(handle, *index, new_vectors, new_indices, n_rows); } /** - * @brief Extend the index with the new data. - * * + * @brief Extend the index in-place with the new data. + * + * Usage example: + * @code{.cpp} + * using namespace raft::spatial::knn; + * ivf_flat::index_params index_params; + * index_params.add_data_on_build = false; // don't populate index on build + * index_params.kmeans_trainset_fraction = 1.0; // use whole dataset for kmeans training + * // train the index from a [N, D] dataset + * auto index_empty = ivf_flat::build(handle, dataset, index_params, dataset); + * // fill the index with the data + * ivf_flat::extend(handle, index_empty, dataset); + * @endcode + * * @tparam value_t data element type * @tparam idx_t type of the indices in the source dataset * @tparam int_t precision / type of integral arguments @@ -298,15 +322,15 @@ void extend(const handle_t& handle, * enough memory pool here to avoid memory allocations within search). */ template -inline void search(const handle_t& handle, - const search_params& params, - const index& index, - const T* queries, - uint32_t n_queries, - uint32_t k, - IdxT* neighbors, - float* distances, - rmm::mr::device_memory_resource* mr = nullptr) +void search(const handle_t& handle, + const search_params& params, + const index& index, + const T* queries, + uint32_t n_queries, + uint32_t k, + IdxT* neighbors, + float* distances, + rmm::mr::device_memory_resource* mr = nullptr) { return raft::spatial::knn::ivf_flat::detail::search( handle, params, index, queries, n_queries, k, neighbors, distances, mr); @@ -323,21 +347,15 @@ inline void search(const handle_t& handle, * eliminate entirely allocations happening within `search`: * @code{.cpp} * ... - * // Create a pooling memory resource with a pre-defined initial size. 
- * rmm::mr::pool_memory_resource mr( - * rmm::mr::get_current_device_resource(), 1024 * 1024); * // use default search parameters * ivf_flat::search_params search_params; * // Use the same allocator across multiple searches to reduce the number of * // cuda memory allocations - * ivf_flat::search(handle, search_params, index, queries1, N1, K, out_inds1, out_dists1, &mr); - * ivf_flat::search(handle, search_params, index, queries2, N2, K, out_inds2, out_dists2, &mr); - * ivf_flat::search(handle, search_params, index, queries3, N3, K, out_inds3, out_dists3, &mr); + * ivf_flat::search(handle, index, queries1, out_inds1, out_dists1, search_params, K); + * ivf_flat::search(handle, index, queries2, out_inds2, out_dists2, search_params, K); + * ivf_flat::search(handle, index, queries3, out_inds3, out_dists3, search_params, K); * ... * @endcode - * The exact size of the temporary buffer depends on multiple factors and is an implementation - * detail. However, you can safely specify a small initial size for the memory pool, so that only a - * few allocations happen to grow it during the first invocations of the `search`. * * @tparam value_t data element type * @tparam idx_t type of the indices diff --git a/cpp/include/raft/neighbors/ivf_pq.cuh b/cpp/include/raft/neighbors/ivf_pq.cuh index 1e32d5d7ba..5d619c5bec 100644 --- a/cpp/include/raft/neighbors/ivf_pq.cuh +++ b/cpp/include/raft/neighbors/ivf_pq.cuh @@ -37,7 +37,7 @@ namespace raft::neighbors::ivf_pq { * * Usage example: * @code{.cpp} - * using namespace raft::spatial::knn; + * using namespace raft::neighbors; * // use default index parameters * ivf_pq::index_params index_params; * // create and fill the index from a [N, D] dataset diff --git a/cpp/include/raft/solver/linear_assignment.cuh b/cpp/include/raft/solver/linear_assignment.cuh index 4c24dcbc29..3e17b557f2 100644 --- a/cpp/include/raft/solver/linear_assignment.cuh +++ b/cpp/include/raft/solver/linear_assignment.cuh @@ -39,8 +39,19 @@ namespace raft::solver { +/** + * @brief CUDA Implementation of O(n^3) alternating tree Hungarian Algorithm + * @note This is a port to RAFT from original authors Ketan Date and Rakesh Nagi + * + * @see Date, Ketan, and Rakesh Nagi. "GPU-accelerated Hungarian algorithms + * for the Linear Assignment Problem." Parallel Computing 57 (2016): 52-72. + * + * @tparam vertex_t + * @tparam weight_t + */ template class LinearAssignmentProblem { + private: vertex_t size_; vertex_t batchsize_; weight_t epsilon_; @@ -66,6 +77,13 @@ class LinearAssignmentProblem { rmm::device_uvector obj_val_dual_v; public: + /** + * @brief Constructor + * @param handle raft handle for managing resources + * @param size size of square matrix + * @param batchsize + * @param epsilon + */ LinearAssignmentProblem(raft::handle_t const& handle, vertex_t size, vertex_t batchsize, @@ -91,7 +109,12 @@ class LinearAssignmentProblem { { } - // Executes Hungarian algorithm on the input cost matrix. + /** + * Executes Hungarian algorithm on the input cost matrix. + * @param d_cost_matrix + * @param d_row_assignment + * @param d_col_assignment + */ void solve(weight_t const* d_cost_matrix, vertex_t* d_row_assignment, vertex_t* d_col_assignment) { initializeDevice(); @@ -118,19 +141,31 @@ class LinearAssignmentProblem { d_costs_ = nullptr; } - // Function for getting optimal row dual vector for subproblem spId. + /** + * Function for getting optimal row dual vector for subproblem spId. 
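Since the newly documented LinearAssignmentProblem class carries no usage example in this header, a brief sketch of how the pieces documented above fit together may help; the problem size, cost contents, and epsilon below are arbitrary assumptions rather than recommended values:

@code{.cpp}
#include <raft/core/handle.hpp>
#include <raft/solver/linear_assignment.cuh>

#include <rmm/device_uvector.hpp>

void assign(raft::handle_t const& handle)
{
  int n       = 128;  // size of the square cost matrix (assumed)
  auto stream = handle.get_stream();

  rmm::device_uvector<float> costs(n * n, stream);  // fill with the actual costs
  rmm::device_uvector<int> row_assignment(n, stream);
  rmm::device_uvector<int> col_assignment(n, stream);

  // one subproblem (batchsize = 1); epsilon is the solver's numerical tolerance
  raft::solver::LinearAssignmentProblem<int, float> lap(handle, n, 1, 1e-6f);
  lap.solve(costs.data(), row_assignment.data(), col_assignment.data());

  // optimal primal objective of subproblem 0
  float objective = lap.getPrimalObjectiveValue(0);
  (void)objective;
}
@endcode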
+ * @param spId + * @return + */ std::pair getRowDualVector(int spId) const { return std::make_pair(row_duals_v.data() + spId * size_, size_); } - // Function for getting optimal col dual vector for subproblem spId. + /** + * Function for getting optimal col dual vector for subproblem spId. + * @param spId + * @return + */ std::pair getColDualVector(int spId) { return std::make_pair(col_duals_v.data() + spId * size_, size_); } - // Function for getting optimal primal objective value for subproblem spId. + /** + * Function for getting optimal primal objective value for subproblem spId. + * @param spId + * @return + */ weight_t getPrimalObjectiveValue(int spId) { weight_t result; @@ -139,7 +174,11 @@ class LinearAssignmentProblem { return result; } - // Function for getting optimal dual objective value for subproblem spId. + /** + * Function for getting optimal dual objective value for subproblem spId. + * @param spId + * @return + */ weight_t getDualObjectiveValue(int spId) { weight_t result; diff --git a/cpp/include/raft/sparse/solver/mst.cuh b/cpp/include/raft/sparse/solver/mst.cuh index 33beeb1915..5f55a567ca 100644 --- a/cpp/include/raft/sparse/solver/mst.cuh +++ b/cpp/include/raft/sparse/solver/mst.cuh @@ -20,6 +20,29 @@ namespace raft::sparse::solver { +/** + * Compute the minimium spanning tree (MST) or minimum spanning forest (MSF) depending on + * the connected components of the given graph. + * + * @tparam vertex_t integral type for precision of vertex indexing + * @tparam edge_t integral type for precision of edge indexing + * @tparam weight_t type of weights array + * @tparam alteration_t type to use for random alteration + * + * @param handle + * @param offsets csr inptr array of row offsets (size v+1) + * @param indices csr array of column indices (size e) + * @param weights csr array of weights (size e) + * @param v number of vertices in graph + * @param e number of edges in graph + * @param color array to store resulting colors for MSF + * @param stream cuda stream for ordering operations + * @param symmetrize_output should the resulting output edge list should be symmetrized? + * @param initialize_colors should the colors array be initialized inside the MST? 
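For the MST solver documented here there is likewise no inline example; a hedged sketch of a call that follows the parameter list above (the graph contents, the integer/float types, and the trailing option values are assumptions):

@code{.cpp}
#include <raft/core/handle.hpp>
#include <raft/sparse/solver/mst.cuh>

#include <rmm/device_uvector.hpp>

void spanning_tree(raft::handle_t const& handle,
                   const int* offsets,    // CSR indptr, size v + 1
                   const int* indices,    // CSR column indices, size e
                   const float* weights,  // CSR edge weights, size e
                   int v,
                   int e)
{
  auto stream = handle.get_stream();
  rmm::device_uvector<int> color(v, stream);  // resulting per-vertex colors (MSF case)

  // symmetrize the output edge list and let the solver initialize the colors
  auto mst_coo = raft::sparse::solver::mst<int, int, float>(
    handle, offsets, indices, weights, v, e, color.data(), stream,
    /*symmetrize_output=*/true,
    /*initialize_colors=*/true,
    /*iterations=*/0);

  // mst_coo now holds the (symmetrized) MST/MSF edge list in COO form
}
@endcode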
+ * @param iterations maximum number of iterations to perform + * @return a list of edges containing the mst (or a subset of the edges guaranteed to be in the mst + * when an msf is encountered) + */ template Graph_COO mst(const raft::handle_t& handle, edge_t const* offsets, diff --git a/cpp/include/raft/spatial/knn/ivf_flat.cuh b/cpp/include/raft/spatial/knn/ivf_flat.cuh index d7c3d80fb5..65b6f5ed4b 100644 --- a/cpp/include/raft/spatial/knn/ivf_flat.cuh +++ b/cpp/include/raft/spatial/knn/ivf_flat.cuh @@ -33,7 +33,6 @@ namespace raft::spatial::knn::ivf_flat { using raft::neighbors::ivf_flat::build; -using raft::neighbors::ivf_flat::build_index; using raft::neighbors::ivf_flat::extend; using raft::neighbors::ivf_flat::search; diff --git a/cpp/include/raft/stats/adjusted_rand_index.cuh b/cpp/include/raft/stats/adjusted_rand_index.cuh index e1b6a241c4..93fd07eb0b 100644 --- a/cpp/include/raft/stats/adjusted_rand_index.cuh +++ b/cpp/include/raft/stats/adjusted_rand_index.cuh @@ -31,8 +31,8 @@ namespace raft { namespace stats { /** - * @brief Function to calculate Adjusted RandIndex as described - * here + * @brief Function to calculate Adjusted RandIndex + * @see https://en.wikipedia.org/wiki/Rand_index * @tparam T data-type for input label arrays * @tparam MathT integral data-type used for computing n-choose-r * @param firstClusterArray: the array of classes @@ -50,8 +50,8 @@ double adjusted_rand_index(const T* firstClusterArray, } /** - * @brief Function to calculate Adjusted RandIndex as described - * here + * @brief Function to calculate Adjusted RandIndex + * @see https://en.wikipedia.org/wiki/Rand_index * @tparam value_t data-type for input label arrays * @tparam math_t integral data-type used for computing n-choose-r * @tparam idx_t Index type of matrix extent. diff --git a/cpp/include/raft/stats/common.hpp b/cpp/include/raft/stats/common.hpp index 8392bd50fe..724ca224c6 100644 --- a/cpp/include/raft/stats/common.hpp +++ b/cpp/include/raft/stats/common.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2022, NVIDIA CORPORATION. + * Copyright (c) 2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,59 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #pragma once -#include - -// This file is a shameless amalgamation of independent works done by -// Lars Nyland and Andy Adinets - -///@todo: add cub's histogram as another option - -namespace raft { -namespace stats { - -/** Default mapper which just returns the value of the data itself */ -template -struct IdentityBinner { - DI int operator()(DataT val, IdxT row, IdxT col) { return int(val); } -}; - -/** Types of support histogram implementations */ -enum HistType { - /** shared mem atomics but with bins to be 1b int's */ - HistTypeSmemBits1 = 1, - /** shared mem atomics but with bins to be 2b int's */ - HistTypeSmemBits2 = 2, - /** shared mem atomics but with bins to be 4b int's */ - HistTypeSmemBits4 = 4, - /** shared mem atomics but with bins to ba 1B int's */ - HistTypeSmemBits8 = 8, - /** shared mem atomics but with bins to be 2B int's */ - HistTypeSmemBits16 = 16, - /** use only global atomics */ - HistTypeGmem, - /** uses shared mem atomics to reduce global traffic */ - HistTypeSmem, - /** - * uses shared mem atomics with match_any intrinsic to further reduce shared - * memory traffic. This can only be enabled on Volta and later architectures. 
- * If one tries to enable this for older arch's, it will fall back to - * `HistTypeSmem`. - * @note This is to be used only when the input dataset leads to a lot of - * repetitions in a given warp, else, this algo can be much slower than - * `HistTypeSmem`! - */ - HistTypeSmemMatchAny, - /** builds a hashmap of active bins in shared mem */ - HistTypeSmemHash, - /** decide at runtime the best algo for the given inputs */ - HistTypeAuto -}; - -/// Supported types of information criteria -enum IC_Type { AIC, AICc, BIC }; +#pragma message(__FILE__ \ + " is deprecated and will be removed in a future release." \ + " Please use the raft/stats/stats_types.hpp version instead.") -}; // end namespace stats -}; // end namespace raft +#include diff --git a/cpp/include/raft/stats/detail/histogram.cuh b/cpp/include/raft/stats/detail/histogram.cuh index 777e0b7816..69bd721ded 100644 --- a/cpp/include/raft/stats/detail/histogram.cuh +++ b/cpp/include/raft/stats/detail/histogram.cuh @@ -32,6 +32,12 @@ namespace raft { namespace stats { namespace detail { +/** Default mapper which just returns the value of the data itself */ +template +struct IdentityBinner { + DI int operator()(DataT val, IdxT row, IdxT col) { return int(val); } +}; + static const int ThreadsPerBlock = 256; template diff --git a/cpp/include/raft/stats/histogram.cuh b/cpp/include/raft/stats/histogram.cuh index df1c2772f1..8efb2e8df8 100644 --- a/cpp/include/raft/stats/histogram.cuh +++ b/cpp/include/raft/stats/histogram.cuh @@ -31,6 +31,14 @@ namespace raft { namespace stats { +/** + * Default mapper which just returns the value of the data itself + */ +template +struct IdentityBinner : public detail::IdentityBinner { + IdentityBinner() : detail::IdentityBinner() {} +}; + /** * @brief Perform histogram on the input data. It chooses the right load size * based on the input data vector length. It also supports large-bin cases diff --git a/cpp/include/raft/stats/stats_types.hpp b/cpp/include/raft/stats/stats_types.hpp new file mode 100644 index 0000000000..5db5ef1c57 --- /dev/null +++ b/cpp/include/raft/stats/stats_types.hpp @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2019-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace raft::stats { + +/** + * @brief Types of support histogram implementations + */ +enum HistType { + /** shared mem atomics but with bins to be 1b int's */ + HistTypeSmemBits1 = 1, + /** shared mem atomics but with bins to be 2b int's */ + HistTypeSmemBits2 = 2, + /** shared mem atomics but with bins to be 4b int's */ + HistTypeSmemBits4 = 4, + /** shared mem atomics but with bins to ba 1B int's */ + HistTypeSmemBits8 = 8, + /** shared mem atomics but with bins to be 2B int's */ + HistTypeSmemBits16 = 16, + /** use only global atomics */ + HistTypeGmem, + /** uses shared mem atomics to reduce global traffic */ + HistTypeSmem, + /** + * uses shared mem atomics with match_any intrinsic to further reduce shared + * memory traffic. 
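IdentityBinner above shows the full contract a binning functor has to satisfy: operator()(value, row, col) returning the target bin. A hypothetical equal-width binner written against that same contract (the struct name, the value range, and the host-side check are all illustrative; DI in the detail header is RAFT's device/host inline macro, omitted here to keep the sketch host-only):

@code{.cpp}
#include <cstdio>

// Maps values in [min_val, max_val) onto n_bins equal-width bins,
// mirroring IdentityBinner's operator()(val, row, col) interface.
template <typename DataT, typename IdxT>
struct EqualWidthBinner {
  DataT min_val;
  DataT max_val;
  int n_bins;

  int operator()(DataT val, IdxT /*row*/, IdxT /*col*/) const
  {
    int bin = static_cast<int>((val - min_val) / (max_val - min_val) * n_bins);
    return bin < 0 ? 0 : (bin >= n_bins ? n_bins - 1 : bin);  // clamp out-of-range values
  }
};

int main()
{
  EqualWidthBinner<float, int> binner{0.0f, 1.0f, 4};
  printf("0.30 -> bin %d, 0.99 -> bin %d\n", binner(0.30f, 0, 0), binner(0.99f, 0, 0));
  return 0;
}
@endcode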
This can only be enabled on Volta and later architectures. + * If one tries to enable this for older arch's, it will fall back to + * `HistTypeSmem`. + * @note This is to be used only when the input dataset leads to a lot of + * repetitions in a given warp, else, this algo can be much slower than + * `HistTypeSmem`! + */ + HistTypeSmemMatchAny, + /** builds a hashmap of active bins in shared mem */ + HistTypeSmemHash, + /** decide at runtime the best algo for the given inputs */ + HistTypeAuto +}; + +/** + * @brief Supported types of information criteria + */ +enum IC_Type { AIC, AICc, BIC }; + +}; // end namespace raft::stats diff --git a/cpp/test/neighbors/ann_ivf_flat.cu b/cpp/test/neighbors/ann_ivf_flat.cu index 01af7ea0bd..3a5daff4bb 100644 --- a/cpp/test/neighbors/ann_ivf_flat.cu +++ b/cpp/test/neighbors/ann_ivf_flat.cu @@ -154,7 +154,7 @@ class AnnIVFFlatTest : public ::testing::TestWithParam> { auto database_view = raft::make_device_matrix_view( (const DataT*)database.data(), ps.num_db_vecs, ps.dim); - auto index = ivf_flat::build_index(handle_, database_view, index_params); + auto index = ivf_flat::build(handle_, database_view, index_params); rmm::device_uvector vector_indices(ps.num_db_vecs, stream_); thrust::sequence(handle_.get_thrust_policy(), diff --git a/cpp/test/neighbors/knn.cu b/cpp/test/neighbors/knn.cu index 710950e312..eb5ecf663f 100644 --- a/cpp/test/neighbors/knn.cu +++ b/cpp/test/neighbors/knn.cu @@ -94,7 +94,8 @@ class KNNTest : public ::testing::TestWithParam { auto distances = raft::make_device_matrix_view(distances_.data(), rows_, k_); - knn(handle, index, search, indices, distances, k_); + auto metric = raft::distance::DistanceType::L2Unexpanded; + knn(handle, index, search, indices, distances, k_, metric, std::make_optional(0)); build_actual_output<<>>( actual_labels_.data(), rows_, k_, search_labels_.data(), indices_.data()); diff --git a/docs/source/cpp_api.rst b/docs/source/cpp_api.rst index d10d9773a5..e3f650563d 100644 --- a/docs/source/cpp_api.rst +++ b/docs/source/cpp_api.rst @@ -1,6 +1,7 @@ -~~~~~~~~~~~~~~~~~~~~~~ -RAFT C++ API Reference -~~~~~~~~~~~~~~~~~~~~~~ +~~~~~~~~~~~~~~~~~ +C++ API Reference +~~~~~~~~~~~~~~~~~ + .. _api: diff --git a/docs/source/cpp_api/cluster.rst b/docs/source/cpp_api/cluster.rst index 41816482cc..0ecfe81bc3 100644 --- a/docs/source/cpp_api/cluster.rst +++ b/docs/source/cpp_api/cluster.rst @@ -1,7 +1,8 @@ Cluster ======= -This page provides C++ class references for the publicly-exposed elements of the cluster package. +This page provides C++ class references for the publicly-exposed elements of the `raft/cluster` headers. RAFT provides +fundamental clustering algorithms which are, themselves, considered reusable building blocks for other algorithms. K-Means ------- diff --git a/docs/source/cpp_api/core.rst b/docs/source/cpp_api/core.rst index d4891bf0b3..7228213d9b 100644 --- a/docs/source/cpp_api/core.rst +++ b/docs/source/cpp_api/core.rst @@ -1,7 +1,10 @@ Core ==== -This page provides C++ class references for the publicly-exposed elements of the core package. +This page provides C++ class references for the publicly-exposed elements of the `raft/core` package. The `raft/core` headers +require minimal dependencies, can be compiled without `nvcc`, and thus are safe to expose on your own public APIs. Aside from +the headers in the `raft/core` include directory, any headers in the codebase with the suffix `_types.hpp` are also safe to +expose in public APIs. handle_t @@ -34,30 +37,67 @@ mdarray :project: RAFT :members: -.. 
doxygenclass:: raft::make_device_matrix + +Device Factories +---------------- + +.. doxygenfunction:: raft::make_device_matrix :project: RAFT -.. doxygenclass:: raft::make_device_vector +.. doxygenfunction:: raft::make_device_vector :project: RAFT -.. doxygenclass:: raft::make_device_scalar +.. doxygenfunction:: raft::make_device_scalar :project: RAFT -.. doxygenclass:: raft::make_host_matrix +Host Factories +---------------- + +.. doxygenfunction:: raft::make_host_matrix :project: RAFT -.. doxygenclass:: raft::make_host_vector +.. doxygenfunction:: raft::make_host_vector :project: RAFT -.. doxygenclass:: raft::make_device_scalar +.. doxygenfunction:: raft::make_device_scalar :project: RAFT mdspan ####### -.. doxygenfunction:: raft::make_device_mdspan - :project: RAFT +Device Vocabulary +----------------- + +.. doxygentypedef:: raft::device_mdspan + :project: RAFT + +.. doxygentypedef:: raft::device_matrix_view + :project: RAFT + +.. doxygentypedef:: raft::device_vector_view + :project: RAFT + +.. doxygentypedef:: raft::device_scalar_view + :project: RAFT + +Host Vocabulary +--------------- + +.. doxygentypedef:: raft::host_mdspan + :project: RAFT + +.. doxygentypedef:: raft::host_matrix_view + :project: RAFT + +.. doxygentypedef:: raft::host_vector_view + :project: RAFT + +.. doxygentypedef:: raft::host_scalar_view + :project: RAFT + +Device Factories +---------------- .. doxygenfunction:: raft::make_device_matrix_view :project: RAFT @@ -68,6 +108,9 @@ mdspan .. doxygenfunction:: raft::make_device_scalar_view :project: RAFT +Host Factories +-------------- + .. doxygenfunction:: raft::make_host_matrix_view :project: RAFT @@ -80,18 +123,22 @@ mdspan span #### -.. doxygenclass:: raft::device_span - :project: RAFT - :members: +.. doxygentypedef:: raft::device_span + :project: RAFT + +.. doxygentypedef:: raft::host_span + :project: RAFT -.. doxygenclass:: raft::host_span +.. doxygenclass:: raft::span :project: RAFT :members: + + Key-Value Pair ############## -.. doxygenclass:: raft::KeyValuePair +.. doxygenstruct:: raft::KeyValuePair :project: RAFT :members: diff --git a/docs/source/cpp_api/distance.rst b/docs/source/cpp_api/distance.rst index c2bce860d5..2596361f6a 100644 --- a/docs/source/cpp_api/distance.rst +++ b/docs/source/cpp_api/distance.rst @@ -1,7 +1,8 @@ Distance ======== -This page provides C++ class references for the publicly-exposed elements of the distance package. +This page provides C++ class references for the publicly-exposed elements of the `raft/distance` package. RAFT's +distances have been highly optimized and support a wide assortment of different distance measures. Distance ######## diff --git a/docs/source/cpp_api/linalg.rst b/docs/source/cpp_api/linalg.rst index f9986fd2ce..5664e5b3dc 100644 --- a/docs/source/cpp_api/linalg.rst +++ b/docs/source/cpp_api/linalg.rst @@ -1,7 +1,10 @@ Linear Algebra ============== -This page provides C++ class references for the publicly-exposed elements of the (dense) linear algebra package. +This page provides C++ class references for the publicly-exposed elements of the `raft/linalg` (dense) linear algebra headers. +In addition to providing highly optimized arithmetic and matrix/vector operations, RAFT provides a consistent user experience +by providing common BLAS routines, standard linear system solvers, factorization and eigenvalue solvers. Some of these routines +hide the complexities of lower-level C-based libraries provided in the CUDA toolkit .. 
doxygennamespace:: raft::linalg :project: RAFT diff --git a/docs/source/cpp_api/matrix.rst b/docs/source/cpp_api/matrix.rst index 65534aa6ee..945658eb7b 100644 --- a/docs/source/cpp_api/matrix.rst +++ b/docs/source/cpp_api/matrix.rst @@ -1,7 +1,8 @@ Matrix ====== -This page provides C++ class references for the publicly-exposed elements of the matrix package. +This page provides C++ class references for the publicly-exposed elements of the `raft/matrix` headers. The `raft/matrix` +headers cover many operations on matrices that are otherwise not covered by `raft/linalg`. .. doxygennamespace:: raft::matrix :project: RAFT diff --git a/docs/source/cpp_api/solver.rst b/docs/source/cpp_api/solver.rst index a8b93ca046..f7ca244dc8 100644 --- a/docs/source/cpp_api/solver.rst +++ b/docs/source/cpp_api/solver.rst @@ -1,7 +1,7 @@ -Optimization -============ +Solvers +======= -This page provides C++ class references for the publicly-exposed elements of the optimization package. +This page provides C++ class references for the publicly-exposed elements of the iterative and combinatorial solvers package. Linear Assignment Problem diff --git a/docs/source/cpp_api/sparse.rst b/docs/source/cpp_api/sparse.rst index c0ea61c6f7..a7c32cc65d 100644 --- a/docs/source/cpp_api/sparse.rst +++ b/docs/source/cpp_api/sparse.rst @@ -4,7 +4,6 @@ Sparse This page provides C++ class references for the publicly-exposed elements of the sparse package. - Conversion ########## @@ -26,20 +25,16 @@ Linear Algebra :project: RAFT :members: -Misc Operations -############### +Matrix Operations +################# .. doxygennamespace:: raft::sparse::op :project: RAFT :members: -Selection -######### - -.. doxygennamespace:: raft::sparse::selection - :project: RAFT - :members: +Nearest Neighbors +################# -.. doxygennamespace:: raft::linkage +.. doxygennamespace:: raft::sparse::neighbors :project: RAFT :members: diff --git a/docs/source/index.rst b/docs/source/index.rst index 0d7ab295f4..fb7ce310c8 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -1,15 +1,48 @@ Welcome to RAFT's documentation! ================================= -RAFT contains fundamental widely-used algorithms and primitives for data science and machine learning. +RAFT contains fundamental widely-used algorithms and primitives for scientific computing, data science and machine learning. The algorithms are CUDA-accelerated and form building-blocks for rapidly composing analytics. + +By taking a primitives-based approach to algorithm development, RAFT + +- accelerates algorithm construction time +- reduces the maintenance burden by maximizing reuse across projects, and +- centralizes core reusable computations, allowing future optimizations to benefit all algorithms that use them. + + +While not exhaustive, the following general categories help summarize the accelerated building blocks that RAFT contains: + +.. 
list-table:: + :widths: 25 50 + :header-rows: 1 + + * - Category + - Examples + * - Data Formats + - sparse & dense, conversions, data generation + * - Dense Operations + - linear algebra, matrix and vector operations, slicing, norms, factorization, least squares, svd & eigenvalue problems + * - Sparse Operations + - linear algebra, arithmetic, eigenvalue problems, slicing, symmetrization, components & labeling + * - Spatial + - pairwise distances, nearest neighbors, neighborhood graph construction + * - Basic Clustering + - spectral clustering, hierarchical clustering, k-means + * - Solvers + - combinatorial optimization, iterative solvers + * - Statistics + - sampling, moments and summary statistics, metrics + * - Tools & Utilities + - common utilities for developing CUDA applications, multi-node multi-gpu infrastructure .. toctree:: :maxdepth: 2 :caption: Contents: + quick_start.md cpp_api.rst - raft_dask_api.rst pylibraft_api.rst + raft_dask_api.rst Indices and tables diff --git a/docs/source/quick_start.md b/docs/source/quick_start.md new file mode 100644 index 0000000000..a09f7beceb --- /dev/null +++ b/docs/source/quick_start.md @@ -0,0 +1,129 @@ +# Quick Start + + +This guide is meant to provide a quick-start tutorial for interacting with RAFT's C++ APIs. + +## RAPIDS Memory Manager (RMM) + +RAFT relies heavily on RMM which eases the burden of configuring different allocation strategies globally across the libraries that use it. + +## Multi-dimensional Spans and Arrays + +The APIs in RAFT currently accept raw pointers to device memory and we are in the process of simplifying the APIs with the [mdspan](https://arxiv.org/abs/2010.06474) multi-dimensional array view for representing data in higher dimensions similar to the `ndarray` in the Numpy Python library. RAFT also contains the corresponding owning `mdarray` structure, which simplifies the allocation and management of multi-dimensional data in both host and device (GPU) memory. + +The `mdarray` forms a convenience layer over RMM and can be constructed in RAFT using a number of different helper functions: + +```c++ +#include + +int n_rows = 10; +int n_cols = 10; + +auto scalar = raft::make_device_scalar(handle, 1.0); +auto vector = raft::make_device_vector(handle, n_cols); +auto matrix = raft::make_device_matrix(handle, n_rows, n_cols); +``` + +The `mdspan` is a lightweight non-owning view that can wrap around any pointer, maintaining shape, layout, and indexing information for accessing elements. + + +We can construct `mdspan` instances directly from the above `mdarray` instances: + +```c++ +// Scalar mdspan on device +auto scalar_view = scalar.view(); + +// Vector mdspan on device +auto vector_view = vector.view(); + +// Matrix mdspan on device +auto matrix_view = matrix.view(); +``` +Since the `mdspan` is just a lightweight wrapper, we can also construct it from the underlying data handles in the `mdarray` instances above. We use the extent to get information about the `mdarray` or `mdspan`'s shape. + +```c++ +#include + +auto scalar_view = raft::make_device_scalar_view(scalar.data_handle()); +auto vector_view = raft::make_device_vector_view(vector.data_handle(), vector.extent(0)); +auto matrix_view = raft::make_device_matrix_view(matrix.data_handle(), matrix.extent(0), matrix.extent(1)); +``` + +Of course, RAFT's `mdspan`/`mdarray` APIs aren't just limited to the `device`. 
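Before moving on to the `host` variants shown next, here is a minimal sketch of how the shape information carried by a view is typically consumed. The `print_shape` helper below is hypothetical (not part of RAFT) and simply illustrates that the extents travel with the view, reusing the `matrix_view` created in the snippet above:

```c++
#include <cstdio>

// Hypothetical helper: any mdspan-like view exposes extent(i) and size(),
// so callers never need to pass shape information separately.
template <typename ViewT>
void print_shape(ViewT view)
{
  std::printf("rows=%d, cols=%d, total=%zu\n",
              static_cast<int>(view.extent(0)),
              static_cast<int>(view.extent(1)),
              static_cast<std::size_t>(view.size()));
}

print_shape(matrix_view);  // uses the device matrix view constructed above
```

Because the view is a lightweight, non-owning handle, it can be copied and passed by value into functions like this without copying the underlying device memory.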
You can also create `host` variants: + +```c++ +#include +#include + +int n_rows = 10; +int n_cols = 10; + +auto scalar = raft::make_host_scalar(handle, 1.0); +auto vector = raft::make_host_vector(handle, n_cols); +auto matrix = raft::make_host_matrix(handle, n_rows, n_cols); + +auto scalar_view = raft::make_host_scalar_view(scalar.data_handle()); +auto vector_view = raft::make_host_vector_view(vector.data_handle(), vector.extent(0)); +auto matrix_view = raft::make_host_matrix_view(matrix.data_handle(), matrix.extent(0), matrix.extent(1)); +``` + +And `managed` variants: + +```c++ +#include + +int n_rows = 10; +int n_cols = 10; + +auto matrix = raft::make_managed_mdspan(managed_ptr, raft::make_matrix_extents(n_rows, n_cols)); +``` + + +## C++ Example + +Most of the primitives in RAFT accept a `raft::handle_t` object for the management of resources which are expensive to create, such CUDA streams, stream pools, and handles to other CUDA libraries like `cublas` and `cusolver`. + +The example below demonstrates creating a RAFT handle and using it with `device_matrix` and `device_vector` to allocate memory, generating random clusters, and computing +pairwise Euclidean distances: + +```c++ +#include +#include +#include +#include + +raft::handle_t handle; + +int n_samples = 5000; +int n_features = 50; + +auto input = raft::make_device_matrix(handle, n_samples, n_features); +auto labels = raft::make_device_vector(handle, n_samples); +auto output = raft::make_device_matrix(handle, n_samples, n_samples); + +raft::random::make_blobs(handle, input.view(), labels.view()); + +auto metric = raft::distance::DistanceType::L2SqrtExpanded; +raft::distance::pairwise_distance(handle, input.view(), input.view(), output.view(), metric); +``` + +## Python Example + +The `pylibraft` package contains a Python API for RAFT algorithms and primitives. `pylibraft` integrates nicely into other libraries by being very lightweight with minimal dependencies and accepting any object that supports the `__cuda_array_interface__`, such as [CuPy's ndarray](https://docs.cupy.dev/en/stable/user_guide/interoperability.html#rmm). The package is currently limited to pairwise distances and RMAT graph generation, but we will continue adding more in future releases. + +The example below demonstrates computing the pairwise Euclidean distances between CuPy arrays. `pylibraft` is a low-level API that prioritizes efficiency and simplicity over being pythonic, which is shown here by pre-allocating the output memory before invoking the `pairwise_distance` function. Note that CuPy is not a required dependency for `pylibraft`. + +```python +import cupy as cp + +from pylibraft.distance import pairwise_distance + +n_samples = 5000 +n_features = 50 + +in1 = cp.random.random_sample((n_samples, n_features), dtype=cp.float32) +in2 = cp.random.random_sample((n_samples, n_features), dtype=cp.float32) +output = cp.empty((n_samples, n_samples), dtype=cp.float32) + +pairwise_distance(in1, in2, output, metric="euclidean") +``` From 004af044465a1be73c5c0e56b680fee771845605 Mon Sep 17 00:00:00 2001 From: "Corey J. 
Nolet" Date: Fri, 21 Oct 2022 18:53:50 -0400 Subject: [PATCH 33/35] Fixing style --- cpp/include/raft/cluster/kmeans.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/include/raft/cluster/kmeans.cuh b/cpp/include/raft/cluster/kmeans.cuh index 9d19e24806..c109cba713 100644 --- a/cpp/include/raft/cluster/kmeans.cuh +++ b/cpp/include/raft/cluster/kmeans.cuh @@ -982,7 +982,7 @@ void fit_main(const raft::handle_t& handle, handle, params, X, weight, centroidsRawData, inertia, n_iter, workspace); } -}; // end namespace raft::cluster::kmeans +}; // namespace raft::cluster namespace raft::cluster { From 337dd023a6e228852376e8e5178586e6c1d0f939 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Thu, 27 Oct 2022 16:31:50 -0400 Subject: [PATCH 34/35] Fixing style --- cpp/include/raft/cluster/kmeans.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/include/raft/cluster/kmeans.cuh b/cpp/include/raft/cluster/kmeans.cuh index f64de1c8de..4eddc784e4 100644 --- a/cpp/include/raft/cluster/kmeans.cuh +++ b/cpp/include/raft/cluster/kmeans.cuh @@ -514,7 +514,7 @@ void fit_main(const raft::handle_t& handle, handle, params, X, weight, centroids, inertia, n_iter, workspace); } -}; // namespace raft::cluster +}; // namespace raft::cluster::kmeans namespace raft::cluster { From 77a7c0b88bdcbc85377ea6f2f770c755286d3f25 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Thu, 27 Oct 2022 18:15:53 -0400 Subject: [PATCH 35/35] A little more cleanup --- cpp/include/raft/solver/detail/qn/qn_util.cuh | 4 +- cpp/include/raft/solver/quasi_newton.cuh | 20 +++--- cpp/include/raft/solver/simple_mat.cuh | 15 +++-- cpp/include/raft/solver/solver_types.hpp | 45 ++++++++++++- cpp/test/solver/quasi_newton.cu | 66 ++++++++++--------- 5 files changed, 99 insertions(+), 51 deletions(-) diff --git a/cpp/include/raft/solver/detail/qn/qn_util.cuh b/cpp/include/raft/solver/detail/qn/qn_util.cuh index 1081fef123..a8df31df8f 100644 --- a/cpp/include/raft/solver/detail/qn/qn_util.cuh +++ b/cpp/include/raft/solver/detail/qn/qn_util.cuh @@ -27,8 +27,8 @@ inline bool qn_is_classification(qn_loss_type t) switch (t) { case QN_LOSS_LOGISTIC: case QN_LOSS_SOFTMAX: - case QN_LOSS_SVC_L1: - case QN_LOSS_SVC_L2: return true; + case QN_LOSS_HINGE: + case QN_LOSS_SQ_HINGE: return true; default: return false; } } diff --git a/cpp/include/raft/solver/quasi_newton.cuh b/cpp/include/raft/solver/quasi_newton.cuh index dee73a010b..0c731cebd6 100644 --- a/cpp/include/raft/solver/quasi_newton.cuh +++ b/cpp/include/raft/solver/quasi_newton.cuh @@ -43,7 +43,7 @@ struct AbsLoss : detail::objectives::AbsLoss { : detail::objectives::AbsLoss(handle, D, has_bias) { } -} +}; /** * Squared loss function specification @@ -55,7 +55,7 @@ struct SquaredLoss : detail::objectives::SquaredLoss { : detail::objectives::SquaredLoss(handle, D, 1, has_bias), lz{}, dlz{} { } -} +}; /** * Standard hinge loss function specification @@ -67,7 +67,7 @@ struct HingeLoss : detail::objectives::HingeLoss { : detail::objectives::HingeLoss(handle, D, has_bias) { } -} +}; /** * @@ -79,7 +79,7 @@ struct LogisticLoss : detail::objectives::LogisticLoss { : detail::objectives::LogisticLoss(handle, D, has_bias) { } -} +}; /** * Squared hinge loss function specification @@ -91,19 +91,19 @@ struct SqHingeLoss : detail::objectives::SqHingeLoss { : detail::objectives::SqHingeLoss(handle, D, has_bias) { } -} +}; /** * Epsilon insensitive (regression) hinge loss function specification * @tparam T */ template - struct EpsInsHingeLoss : 
detail::objectives::EpsInsHingeLoss > { + struct EpsInsHingeLoss : detail::objectives::EpsInsHingeLoss { EpsInsHingeLoss(const raft::handle_t& handle, int D, bool has_bias, T sensitivity) : detail::objectives::EpsInsHingeLoss(handle, D, 1, has_bias), lz{sensitivity}, dlz{sensitivity} { } -} +}; /** * Squared Epsilon insensitive (regression) hinge loss function specification @@ -160,7 +160,7 @@ struct QNLinearBase : detail::objectives::QNLinearBase { : detail::objectives::QNLinearBase(C, D, fit_intercept) { } -} +}; /** * Softmax loss function specification @@ -172,7 +172,7 @@ struct Softmax : detail::objectives::Softmax { : detail::objectives::Softmax(handle, D, C, has_bias) { } -} +}; /** * Constructs a end-to-end quasi-newton objective function to solve the system @@ -192,7 +192,7 @@ struct ObjectiveWithData : detail::objectives::QNWithDataC, obj->D, obj->fit_intercept) { } -} +}; /** * @brief Minimize the given `raft::solver::quasi_newton::ObjectiveWithData` using diff --git a/cpp/include/raft/solver/simple_mat.cuh b/cpp/include/raft/solver/simple_mat.cuh index 5d20e171dd..69bd0acdd8 100644 --- a/cpp/include/raft/solver/simple_mat.cuh +++ b/cpp/include/raft/solver/simple_mat.cuh @@ -18,23 +18,27 @@ #include #include #include +#include -#include -#include -#include #include // #TODO: Replace with public header when ready #include +#include + +#include #include #include +#include #include -#include /** * NOTE: This will eventually get replaced with mdspan/mdarray */ -namespace raft::solver { +namespace raft::solver::quasi_newton { + +template +struct SimpleDenseMat; template struct SimpleMat { @@ -63,7 +67,6 @@ struct SimpleMat { cudaStream_t stream) const = 0; }; -enum STORAGE_ORDER { COL_MAJOR = 0, ROW_MAJOR = 1 }; template struct SimpleDenseMat : SimpleMat { diff --git a/cpp/include/raft/solver/solver_types.hpp b/cpp/include/raft/solver/solver_types.hpp index fb454e2745..db2f63ee65 100644 --- a/cpp/include/raft/solver/solver_types.hpp +++ b/cpp/include/raft/solver/solver_types.hpp @@ -18,7 +18,9 @@ namespace raft::solver { -enum lr_type { +enum STORAGE_ORDER { COL_MAJOR = 0, ROW_MAJOR = 1 }; + + enum lr_type { OPTIMAL, CONSTANT, INVSCALING, @@ -105,7 +107,46 @@ enum class LarsFitStatus { kOk, kCollinear, kError, kStop }; namespace quasi_newton { -struct qn_params { +/** Loss function types supported by the Quasi-Newton solvers. */ +enum qn_loss_type { + /** Logistic classification. + * Expected target: {0, 1}. + */ + QN_LOSS_LOGISTIC = 0, + /** L2 regression. + * Expected target: R. + */ + QN_LOSS_SQUARED = 1, + /** Softmax classification.. + * Expected target: {0, 1, ...}. + */ + QN_LOSS_SOFTMAX = 2, + /** Hinge. + * Expected target: {0, 1}. + */ + QN_LOSS_HINGE = 3, + /** Squared-hinge. + * Expected target: {0, 1}. + */ + QN_LOSS_SQ_HINGE = 4, + /** Epsilon-insensitive. + * Expected target: R. + */ + QN_LOSS_HINGE_EPS_INS = 5, + /** Epsilon-insensitive-squared. + * Expected target: R. + */ + QN_LOSS_HINGE_SQ_EPS_INS = 6, + /** L1 regression. + * Expected target: R. + */ + QN_LOSS_ABS = 7, + /** Someone forgot to set the loss type! */ + QN_LOSS_UNKNOWN = 99 +}; + + + struct qn_params { /** Loss type. */ qn_loss_type loss; /** Regularization: L1 component. */ diff --git a/cpp/test/solver/quasi_newton.cu b/cpp/test/solver/quasi_newton.cu index f0a67d9eb9..ff2d2316aa 100644 --- a/cpp/test/solver/quasi_newton.cu +++ b/cpp/test/solver/quasi_newton.cu @@ -14,12 +14,16 @@ * limitations under the License. 
*/ -#include -#include #include + #include + +#include #include + #include + +#include #include #include @@ -149,16 +153,6 @@ inline void qn_fit_x(const raft::handle_t& handle, } struct QuasiNewtonTest : ::testing::Test { - static constexpr int N = 10; - static constexpr int D = 2; - - const static double* nobptr; - const static double tol; - const static double X[N][D]; - const raft::handle_t& handle; - cudaStream_t stream = 0; - std::shared_ptr> Xdev; - std::shared_ptr> ydev; QuasiNewtonTest() {} void SetUp() @@ -171,6 +165,18 @@ struct QuasiNewtonTest : ::testing::Test { handle.sync_stream(stream); } void TearDown() {} + + static constexpr int N = 10; + static constexpr int D = 2; + + const static double* nobptr; + const static double tol; + const static double X[N][D]; + const raft::handle_t handle; + cudaStream_t stream = 0; + std::shared_ptr> Xdev; + std::shared_ptr> ydev; + }; const double* QuasiNewtonTest::nobptr = 0; @@ -192,7 +198,7 @@ template const T* host_weights, const T* host_bias, const T* w, - const GLMDims& dims, + const LinearDims& dims, Comp& comp, cudaStream_t stream) { @@ -243,7 +249,7 @@ T run(const raft::handle_t& handle, } template -T run_api(const raft::handle_t& cuml_handle, +T run_api(const raft::handle_t& handle, qn_loss_type loss_type, int C, bool fit_intercept, @@ -275,7 +281,7 @@ T run_api(const raft::handle_t& cuml_handle, w0.fill(T(0), stream); T fx; - qn_fit_on_x(cuml_handle, + qn_fit_on_x(handle, pams, X_dense->data, X_dense->ord == COL_MAJOR, @@ -286,10 +292,8 @@ T run_api(const raft::handle_t& cuml_handle, w, &fx, &num_iters); -} -else { ADD_FAILURE(); } -return fx; + return fx; } TEST_F(QuasiNewtonTest, binary_logistic_vs_sklearn) @@ -338,7 +342,7 @@ TEST_F(QuasiNewtonTest, binary_logistic_vs_sklearn) ASSERT_TRUE(compApprox(obj_l2_b, fx)); ASSERT_TRUE(checkParamsEqual(handle, &w_l2_b[0], &b_l2_b, w0.data, loss_b, compApprox, stream)); - fx = run_api(cuml_handle, + fx = run_api(handle, QN_LOSS_LOGISTIC, 2, loss_b.fit_intercept, @@ -362,7 +366,7 @@ TEST_F(QuasiNewtonTest, binary_logistic_vs_sklearn) ASSERT_TRUE( checkParamsEqual(handle, &w_l1_no_b[0], nobptr, w0.data, loss_no_b, compApprox, stream)); - fx = run_api(cuml_handle, + fx = run_api(handle, QN_LOSS_LOGISTIC, 2, loss_no_b.fit_intercept, @@ -386,7 +390,7 @@ TEST_F(QuasiNewtonTest, binary_logistic_vs_sklearn) ASSERT_TRUE( checkParamsEqual(handle, &w_l2_no_b[0], nobptr, w0.data, loss_no_b, compApprox, stream)); - fx = run_api(cuml_handle, + fx = run_api(handle, QN_LOSS_LOGISTIC, 2, loss_no_b.fit_intercept, @@ -432,7 +436,7 @@ TEST_F(QuasiNewtonTest, multiclass_logistic_vs_sklearn) fx = run(handle, loss_b, *Xdev, *ydev, l1, l2, w0.data, z, 0, stream); ASSERT_TRUE(compApprox(obj_l1_b, fx)); - fx = run_api(cuml_handle, + fx = run_api(handle, QN_LOSS_SOFTMAX, C, loss_b.fit_intercept, @@ -453,7 +457,7 @@ TEST_F(QuasiNewtonTest, multiclass_logistic_vs_sklearn) fx = run(handle, loss_b, *Xdev, *ydev, l1, l2, w0.data, z, 0, stream); ASSERT_TRUE(compApprox(obj_l2_b, fx)); - fx = run_api(cuml_handle, + fx = run_api(handle, QN_LOSS_SOFTMAX, C, loss_b.fit_intercept, @@ -474,7 +478,7 @@ TEST_F(QuasiNewtonTest, multiclass_logistic_vs_sklearn) fx = run(handle, loss_no_b, *Xdev, *ydev, l1, l2, w0.data, z, 0, stream); ASSERT_TRUE(compApprox(obj_l1_no_b, fx)); - fx = run_api(cuml_handle, + fx = run_api(handle, QN_LOSS_SOFTMAX, C, loss_no_b.fit_intercept, @@ -496,7 +500,7 @@ TEST_F(QuasiNewtonTest, multiclass_logistic_vs_sklearn) fx = run(handle, loss_no_b, *Xdev, *ydev, l1, l2, w0.data, z, 0, stream); 
ASSERT_TRUE(compApprox(obj_l2_no_b, fx)); - fx = run_api(cuml_handle, + fx = run_api(handle, QN_LOSS_SOFTMAX, C, loss_no_b.fit_intercept, @@ -544,7 +548,7 @@ TEST_F(QuasiNewtonTest, linear_regression_vs_sklearn) ASSERT_TRUE(compApprox(obj_l1_b, fx)); ASSERT_TRUE(checkParamsEqual(handle, &w_l1_b[0], &b_l1_b, w0.data, loss_b, compApprox, stream)); - fx = run_api(cuml_handle, + fx = run_api(handle, QN_LOSS_SQUARED, 1, loss_b.fit_intercept, @@ -568,7 +572,7 @@ TEST_F(QuasiNewtonTest, linear_regression_vs_sklearn) ASSERT_TRUE(compApprox(obj_l2_b, fx)); ASSERT_TRUE(checkParamsEqual(handle, &w_l2_b[0], &b_l2_b, w0.data, loss_b, compApprox, stream)); - fx = run_api(cuml_handle, + fx = run_api(handle, QN_LOSS_SQUARED, 1, loss_b.fit_intercept, @@ -592,7 +596,7 @@ TEST_F(QuasiNewtonTest, linear_regression_vs_sklearn) ASSERT_TRUE( checkParamsEqual(handle, &w_l1_no_b[0], nobptr, w0.data, loss_no_b, compApprox, stream)); - fx = run_api(cuml_handle, + fx = run_api(handle, QN_LOSS_SQUARED, 1, loss_no_b.fit_intercept, @@ -616,7 +620,7 @@ TEST_F(QuasiNewtonTest, linear_regression_vs_sklearn) ASSERT_TRUE( checkParamsEqual(handle, &w_l2_no_b[0], nobptr, w0.data, loss_no_b, compApprox, stream)); - fx = run_api(cuml_handle, + fx = run_api(handle, QN_LOSS_SQUARED, 1, loss_no_b.fit_intercept, @@ -784,7 +788,7 @@ TEST_F(QuasiNewtonTest, dense_vs_sparse_logistic) ASSERT_TRUE(compApprox(preds_dense_host[i], preds_sparse_host[i])); } - f_dense = run_api(cuml_handle, + f_dense = run_api(handle, QN_LOSS_SOFTMAX, C, loss.fit_intercept, @@ -796,7 +800,7 @@ TEST_F(QuasiNewtonTest, dense_vs_sparse_logistic) z_dense, 0, stream); - f_sparse = run_api(cuml_handle, + f_sparse = run_api(handle, QN_LOSS_SOFTMAX, C, loss.fit_intercept,