From 3df3fe179bea137025413c35fcde2469e37d36fd Mon Sep 17 00:00:00 2001 From: Stefan Kroboth Date: Wed, 7 Feb 2024 19:31:35 +0100 Subject: [PATCH] Added derivative and Hessian for Himmelblau test function --- crates/argmin-testfunctions/src/himmelblau.rs | 94 +++++++++++++++++++ 1 file changed, 94 insertions(+) diff --git a/crates/argmin-testfunctions/src/himmelblau.rs b/crates/argmin-testfunctions/src/himmelblau.rs index 468f49681..fa6b13c78 100644 --- a/crates/argmin-testfunctions/src/himmelblau.rs +++ b/crates/argmin-testfunctions/src/himmelblau.rs @@ -44,10 +44,51 @@ where (x1.powi(2) + x2 - n11).powi(2) + (x1 + x2.powi(2) - n7).powi(2) } +/// Derivative of Himmelblau test function +pub fn himmelblau_derivative(param: &[T; 2]) -> [T; 2] +where + T: Float + FromPrimitive, +{ + let [x1, x2] = *param; + + let n2 = T::from_f64(2.0).unwrap(); + let n4 = T::from_f64(4.0).unwrap(); + let n7 = T::from_f64(7.0).unwrap(); + let n11 = T::from_f64(11.0).unwrap(); + + [ + n4 * x1 * (x1.powi(2) + x2 - n11) + n2 * (x1 + x2.powi(2) - n7), + n4 * x2 * (x2.powi(2) + x1 - n7) + n2 * (x2 + x1.powi(2) - n11), + ] +} + +/// Hessian of Himmelblau test function +pub fn himmelblau_hessian(param: &[T; 2]) -> [[T; 2]; 2] +where + T: Float + FromPrimitive, +{ + let [x1, x2] = *param; + + let n2 = T::from_f64(2.0).unwrap(); + let n4 = T::from_f64(4.0).unwrap(); + let n7 = T::from_f64(7.0).unwrap(); + let n8 = T::from_f64(8.0).unwrap(); + let n11 = T::from_f64(11.0).unwrap(); + + let offdiag = n4 * (x1 + x2); + + [ + [n4 * (x1.powi(2) + x2 - n11) + n8 * x1.powi(2) + n2, offdiag], + [offdiag, n4 * (x2.powi(2) + x1 - n7) + n8 * x2.powi(2) + n2], + ] +} + #[cfg(test)] mod tests { use super::*; use approx::assert_relative_eq; + use finitediff::FiniteDiff; + use proptest::prelude::*; use std::{f32, f64}; #[test] @@ -87,5 +128,58 @@ mod tests { 0.0, epsilon = f32::EPSILON.into() ); + + let deriv = himmelblau_derivative(&[3.0_f32, 2.0_f32]); + for i in 0..2 { + assert_relative_eq!(deriv[i], 0.0, epsilon = f32::EPSILON); + } + + let deriv = himmelblau_derivative(&[-2.805118_f32, 3.131312_f32]); + for i in 0..2 { + assert_relative_eq!(deriv[i], 0.0, epsilon = 1e-4); + } + + let deriv = himmelblau_derivative(&[-3.779310_f64, -3.283186_f64]); + for i in 0..2 { + assert_relative_eq!(deriv[i], 0.0, epsilon = 1e-4); + } + + let deriv = himmelblau_derivative(&[3.584428_f64, -1.848126_f64]); + for i in 0..2 { + assert_relative_eq!(deriv[i], 0.0, epsilon = 1e-4); + } + } + + proptest! { + #[test] + fn test_himmelblau_derivative_finitediff(a in -5.0..5.0, b in -5.0..5.0) { + let param = [a, b]; + let derivative = himmelblau_derivative(¶m); + let derivative_fd = Vec::from(param).central_diff(&|x| himmelblau(&[x[0], x[1]])); + for i in 0..derivative.len() { + assert_relative_eq!(derivative[i], derivative_fd[i], epsilon = 1e-4); + } + } + } + + proptest! { + #[test] + fn test_himmelblau_hessian_finitediff(a in -5.0..5.0, b in -5.0..5.0) { + let param = [a, b]; + let hessian = himmelblau_hessian(¶m); + let hessian_fd = + Vec::from(param).central_hessian(&|x| himmelblau_derivative(&[x[0], x[1]]).to_vec()); + let n = hessian.len(); + println!("1: {hessian:?} at {a}/{b}"); + println!("2: {hessian_fd:?} at {a}/{b}"); + for i in 0..n { + assert_eq!(hessian[i].len(), n); + for j in 0..n { + if hessian_fd[i][j].is_finite() { + assert_relative_eq!(hessian[i][j], hessian_fd[i][j], epsilon = 1e-5); + } + } + } + } } }