From ddbdf7ceff12e95c71fd3c5e8d4860cda0f6ade6 Mon Sep 17 00:00:00 2001 From: Stefan Kroboth Date: Wed, 7 Feb 2024 16:32:15 +0100 Subject: [PATCH] Added derivative and Hessian for eggholder test function --- crates/argmin-testfunctions/src/eggholder.rs | 191 +++++++++++++++++++ 1 file changed, 191 insertions(+) diff --git a/crates/argmin-testfunctions/src/eggholder.rs b/crates/argmin-testfunctions/src/eggholder.rs index 86d6057b0..ad8cc0525 100644 --- a/crates/argmin-testfunctions/src/eggholder.rs +++ b/crates/argmin-testfunctions/src/eggholder.rs @@ -42,10 +42,168 @@ where - x1 * (x1 - (x2 + n47)).abs().sqrt().sin() } +/// Derivative of Eggholder test function +pub fn eggholder_derivative(param: &[T; 2]) -> [T; 2] +where + T: Float + FromPrimitive, +{ + let [x1, x2] = *param; + + let eps = T::epsilon(); + let n0 = T::from_f64(0.0).unwrap(); + let n2 = T::from_f64(2.0).unwrap(); + let n4 = T::from_f64(4.0).unwrap(); + let n47 = T::from_f64(47.0).unwrap(); + + let x1mx2m47 = x1 - x2 - n47; + let x1mx2m47abs = x1mx2m47.abs(); + let x1mx2m47abssqrt = x1mx2m47abs.sqrt(); + let x1mx2m47abssqrtsin = x1mx2m47abssqrt.sin(); + let x1mx2m47abssqrtcos = x1mx2m47abssqrt.cos(); + let x1hpx2p47 = x1 / n2 + x2 + n47; + let x1hpx2p47abs = x1hpx2p47.abs(); + let x1hpx2p47abssqrt = x1hpx2p47abs.sqrt(); + let x1hpx2p47abssqrtsin = x1hpx2p47abssqrt.sin(); + let x1hpx2p47abssqrtcos = x1hpx2p47abssqrt.cos(); + let x2mx1p47 = x2 - x1 + n47; + let x2mx1p47abs = x2mx1p47.abs(); + let x2mx1p47abssqrt = x2mx1p47abs.sqrt(); + let x2mx1p47abssqrtcos = x2mx1p47abssqrt.cos(); + + [ + -x1mx2m47abssqrtsin + - if x1mx2m47abs <= eps { + n0 + } else { + (x1 * x1mx2m47 * x1mx2m47abssqrtcos) / (n2 * x1mx2m47abssqrt.powi(3)) + } + - if x1hpx2p47abs <= eps { + n0 + } else { + ((x2 + n47) * x1hpx2p47abssqrtcos * x1hpx2p47) / (n4 * x1hpx2p47abssqrt.powi(3)) + }, + -x1hpx2p47abssqrtsin + - if x1hpx2p47abs <= eps { + n0 + } else { + ((x2 + n47) * x1hpx2p47 * x1hpx2p47abssqrtcos) / (n2 * x1hpx2p47abssqrt.powi(3)) + } + - if x2mx1p47abs <= eps { + n0 + } else { + (x1 * x2mx1p47 * x2mx1p47abssqrtcos) / (n2 * x2mx1p47abssqrt.powi(3)) + }, + ] +} + +/// Hessian of Eggholder test function +pub fn eggholder_hessian(param: &[T; 2]) -> [[T; 2]; 2] +where + T: Float + FromPrimitive, +{ + let [x1, x2] = *param; + + let eps = T::epsilon(); + let n0 = T::from_f64(0.0).unwrap(); + let n2 = T::from_f64(2.0).unwrap(); + let n3 = T::from_f64(3.0).unwrap(); + let n4 = T::from_f64(4.0).unwrap(); + let n8 = T::from_f64(8.0).unwrap(); + let n16 = T::from_f64(16.0).unwrap(); + let n47 = T::from_f64(47.0).unwrap(); + + let x1mx2m47 = x1 - x2 - n47; + let x1mx2m47abs = x1mx2m47.abs(); + let x1mx2m47abssqrt = x1mx2m47abs.sqrt(); + let x1mx2m47abssqrtsin = x1mx2m47abssqrt.sin(); + let x1mx2m47abssqrtcos = x1mx2m47abssqrt.cos(); + let x1hpx2p47 = x1 / n2 + x2 + n47; + let x1hpx2p47abs = x1hpx2p47.abs(); + let x1hpx2p47abssqrt = x1hpx2p47abs.sqrt(); + let x1hpx2p47abssqrtsin = x1hpx2p47abssqrt.sin(); + let x1hpx2p47abssqrtcos = x1hpx2p47abssqrt.cos(); + let x2mx1p47 = x2 - x1 + n47; + let x2mx1p47abs = x2mx1p47.abs(); + let x2mx1p47abssqrt = x2mx1p47abs.sqrt(); + let x2mx1p47abssqrtcos = x2mx1p47abssqrt.cos(); + let x2mx1p47abssqrtsin = x2mx1p47abssqrt.sin(); + + let a = if x1mx2m47abs <= eps { + n0 + } else { + (x1 * x1mx2m47abssqrtsin) / (n4 * x1mx2m47abs) + - (x1mx2m47 * x1mx2m47abssqrtcos) / (x1mx2m47abssqrt.powi(3)) + + (n3 * x1 * x1mx2m47.powi(2) * x1mx2m47abssqrtcos) / (n4 * x1mx2m47abssqrt.powi(7)) + } + if x1mx2m47abs <= eps && x1.abs() <= eps { + n0 + } else { + -(x1 * x1mx2m47abssqrtcos) / (n2 * x1mx2m47abssqrt.powi(3)) + } + if x1hpx2p47abs <= eps { + n0 + } else { + (n3 * (x2 + n47) * x1hpx2p47abssqrtcos * x1hpx2p47.powi(2)) + / (n16 * x1hpx2p47abssqrt.powi(7)) + + ((x2 + n47) * x1hpx2p47abssqrtsin) / (n16 * x1hpx2p47abs) + } - if x1hpx2p47abs <= eps && (x1 + n47).abs() <= eps { + n0 + } else { + ((x2 + n47) * x1hpx2p47abssqrtcos) / (n8 * x1hpx2p47abssqrt.powi(3)) + }; + + let b = if x1hpx2p47abs <= eps { + n0 + } else { + ((x2 + n47) * x1hpx2p47abssqrtsin) / (n4 * x1hpx2p47abs) + - (x1hpx2p47 * x1hpx2p47abssqrtcos) / (x1hpx2p47abssqrt.powi(3)) + + (n3 * (x2 + n47) * x1hpx2p47.powi(2) * x1hpx2p47abssqrtcos) + / (n4 * x1hpx2p47abssqrt.powi(7)) + } - if x1hpx2p47abs <= eps && (x2 + n47).abs() <= eps { + n0 + } else { + ((x2 + n47) * x1hpx2p47abssqrtcos) / (n2 * x1hpx2p47abssqrt.powi(3)) + } + if x2mx1p47abs <= eps { + n0 + } else { + (x1 * x2mx1p47abssqrtsin) / (n4 * x2mx1p47abs) + + (n3 * x1 * x2mx1p47.powi(2) * x2mx1p47abssqrtcos) / (n4 * x2mx1p47abssqrt.powi(7)) + } - if x2mx1p47abs <= eps && x1.abs() <= eps { + n0 + } else { + (x1 * x2mx1p47abssqrtcos) / (n2 * x2mx1p47abssqrt.powi(3)) + }; + + let offdiag = if x1hpx2p47abs <= eps { + n0 + } else { + ((x2 + n47) * x1hpx2p47abssqrtsin) / (n8 * x1hpx2p47abs) + - (x1hpx2p47 * x1hpx2p47abssqrtcos) / (n4 * x1hpx2p47abssqrt.powi(3)) + + (n3 * (x2 + n47) * x1hpx2p47.powi(2) * x1hpx2p47abssqrtcos) + / (n8 * x1hpx2p47abssqrt.powi(7)) + } - if x1hpx2p47abs <= eps && (x2 + n47).abs() <= eps { + n0 + } else { + ((x2 + n47) * x1hpx2p47abssqrtcos) / (n4 * x1hpx2p47abssqrt.powi(3)) + } + if x2mx1p47abs <= eps { + n0 + } else { + -(x1 * x2mx1p47 * x2mx1p47abssqrtsin) / (n4 * x2mx1p47 * x2mx1p47abs) + - (x2mx1p47 * x2mx1p47abssqrtcos) / (n2 * x2mx1p47abssqrt.powi(3)) + - (n3 * x1 * x2mx1p47.powi(2) * x2mx1p47abssqrtcos) / (n4 * x2mx1p47abssqrt.powi(7)) + } + if x2mx1p47abs <= eps && x1.abs() <= eps { + n0 + } else { + (x1 * x2mx1p47abssqrtcos) / (n2 * x2mx1p47abssqrt.powi(3)) + }; + + [[a, offdiag], [offdiag, b]] +} + #[cfg(test)] mod tests { use super::*; use approx::assert_relative_eq; + use finitediff::FiniteDiff; + use proptest::prelude::*; use std::{f32, f64}; #[test] @@ -61,4 +219,37 @@ mod tests { epsilon = f64::EPSILON ); } + + proptest! { + #[test] + fn test_eggholder_derivative_finitediff(a in -512.0..512.0, b in -512.0..512.0) { + let param = [a, b]; + let derivative = eggholder_derivative(¶m); + let derivative_fd = Vec::from(param).central_diff(&|x| eggholder(&[x[0], x[1]])); + for i in 0..derivative.len() { + assert_relative_eq!(derivative[i], derivative_fd[i], epsilon = 1e-4); + } + } + } + + proptest! { + #[test] + fn test_eggholder_hessian_finitediff(a in -512.0..512.0, b in -512.0..512.0) { + let param = [a, b]; + let hessian = eggholder_hessian(¶m); + let hessian_fd = + Vec::from(param).central_hessian(&|x| eggholder_derivative(&[x[0], x[1]]).to_vec()); + let n = hessian.len(); + println!("1: {hessian:?} at {a}/{b}"); + println!("2: {hessian_fd:?} at {a}/{b}"); + for i in 0..n { + assert_eq!(hessian[i].len(), n); + for j in 0..n { + if hessian_fd[i][j].is_finite() { + assert_relative_eq!(hessian[i][j], hessian_fd[i][j], epsilon = 1e-5); + } + } + } + } + } }