Skip to content

Commit

Permalink
Added const generics versions of Rosenbrock derivative and Hessian
Browse files Browse the repository at this point in the history
  • Loading branch information
stefan-k committed Jan 24, 2024
1 parent dd6a9e4 commit 1438820
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 0 deletions.
2 changes: 2 additions & 0 deletions tools/testfunctions/benches/testfunctions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,9 @@ mod tests {

make_bench!(rosenbrock(&[-43.0, 53.0, 3.4], 1_f64, 100_f64));
make_bench!(rosenbrock_derivative(&[-43.0, 53.0], 1_f64, 100_f64));
make_bench!(rosenbrock_derivative_const(&[-43.0, 53.0], 1_f64, 100_f64));
make_bench!(rosenbrock_hessian(&[-43.0, 53.0], 1_f64, 100_f64));
make_bench!(rosenbrock_hessian_const(&[-43.0, 53.0], 1_f64, 100_f64));

make_bench!(sphere(&vec![-43.0, 53.0]));
make_bench!(sphere_derivative(&[-43.0, 53.0]));
Expand Down
51 changes: 51 additions & 0 deletions tools/testfunctions/src/rosenbrock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,57 @@ where
hessian
}

/// Derivative of the multidimensional Rosenbrock test function
///
/// This is the const generics version, which requires the number of parameters to be known
/// at compile time.
pub fn rosenbrock_derivative_const<const N: usize, T>(param: &[T; N], a: T, b: T) -> [T; N]
where
T: Float + FromPrimitive + AddAssign,
{
let n0 = T::from_f64(0.0).unwrap();
let n2 = T::from_f64(2.0).unwrap();
let n4 = T::from_f64(4.0).unwrap();

let mut result = [n0; N];

for i in 0..(N - 1) {
let xi = param[i];
let xi1 = param[i + 1];

let t1 = -n4 * b * xi * (xi1 - xi.powi(2));
let t2 = n2 * b * (xi1 - xi.powi(2));

result[i] += t1 + n2 * (xi - a);
result[i + 1] += t2;
}
result
}

/// Hessian of the multidimensional Rosenbrock test function
pub fn rosenbrock_hessian_const<const N: usize, T>(x: &[T; N], a: T, b: T) -> [[T; N]; N]
where
T: Float + FromPrimitive + AddAssign,
{
let n0 = T::from_f64(0.0).unwrap();
let n2 = T::from_f64(2.0).unwrap();
let n4 = T::from_f64(4.0).unwrap();
let n12 = T::from_f64(12.0).unwrap();

let mut hessian = [[n0; N]; N];

for i in 0..(N - 1) {
let xi = x[i];
let xi1 = x[i + 1];

hessian[i][i] += n12 * b * xi.powi(2) - n4 * b * xi1 + n2 * a;
hessian[i + 1][i + 1] = n2 * b;
hessian[i][i + 1] = -n4 * b * xi;
hessian[i + 1][i] = -n4 * b * xi;
}
hessian
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down

0 comments on commit 1438820

Please sign in to comment.