Skip to content

Commit

Permalink
Added derivative and Hessian for eggholder test function
Browse files Browse the repository at this point in the history
  • Loading branch information
stefan-k committed Feb 7, 2024
1 parent 3562ec7 commit 224bb46
Showing 1 changed file with 191 additions and 0 deletions.
191 changes: 191 additions & 0 deletions crates/argmin-testfunctions/src/eggholder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,168 @@ where
- x1 * (x1 - (x2 + n47)).abs().sqrt().sin()
}

/// Derivative of Eggholder test function
pub fn eggholder_derivative<T>(param: &[T; 2]) -> [T; 2]
where
T: Float + FromPrimitive,
{
let [x1, x2] = *param;

let eps = T::epsilon();
let n0 = T::from_f64(0.0).unwrap();
let n2 = T::from_f64(2.0).unwrap();
let n4 = T::from_f64(4.0).unwrap();
let n47 = T::from_f64(47.0).unwrap();

let x1mx2m47 = x1 - x2 - n47;
let x1mx2m47abs = x1mx2m47.abs();
let x1mx2m47abssqrt = x1mx2m47abs.sqrt();
let x1mx2m47abssqrtsin = x1mx2m47abssqrt.sin();
let x1mx2m47abssqrtcos = x1mx2m47abssqrt.cos();
let x1hpx2p47 = x1 / n2 + x2 + n47;
let x1hpx2p47abs = x1hpx2p47.abs();
let x1hpx2p47abssqrt = x1hpx2p47abs.sqrt();
let x1hpx2p47abssqrtsin = x1hpx2p47abssqrt.sin();
let x1hpx2p47abssqrtcos = x1hpx2p47abssqrt.cos();
let x2mx1p47 = x2 - x1 + n47;
let x2mx1p47abs = x2mx1p47.abs();
let x2mx1p47abssqrt = x2mx1p47abs.sqrt();
let x2mx1p47abssqrtcos = x2mx1p47abssqrt.cos();

[
-x1mx2m47abssqrtsin
- if x1mx2m47abs <= eps {
n0
} else {
(x1 * x1mx2m47 * x1mx2m47abssqrtcos) / (n2 * x1mx2m47abssqrt.powi(3))
}
- if x1hpx2p47abs <= eps {
n0
} else {
((x2 + n47) * x1hpx2p47abssqrtcos * x1hpx2p47) / (n4 * x1hpx2p47abssqrt.powi(3))
},
-x1hpx2p47abssqrtsin
- if x1hpx2p47abs <= eps {
n0
} else {
((x2 + n47) * x1hpx2p47 * x1hpx2p47abssqrtcos) / (n2 * x1hpx2p47abssqrt.powi(3))
}
- if x2mx1p47abs <= eps {
n0
} else {
(x1 * x2mx1p47 * x2mx1p47abssqrtcos) / (n2 * x2mx1p47abssqrt.powi(3))
},
]
}

/// Hessian of Eggholder test function
pub fn eggholder_hessian<T>(param: &[T; 2]) -> [[T; 2]; 2]
where
T: Float + FromPrimitive,
{
let [x1, x2] = *param;

let eps = T::epsilon();
let n0 = T::from_f64(0.0).unwrap();
let n2 = T::from_f64(2.0).unwrap();
let n3 = T::from_f64(3.0).unwrap();
let n4 = T::from_f64(4.0).unwrap();
let n8 = T::from_f64(8.0).unwrap();
let n16 = T::from_f64(16.0).unwrap();
let n47 = T::from_f64(47.0).unwrap();

let x1mx2m47 = x1 - x2 - n47;
let x1mx2m47abs = x1mx2m47.abs();
let x1mx2m47abssqrt = x1mx2m47abs.sqrt();
let x1mx2m47abssqrtsin = x1mx2m47abssqrt.sin();
let x1mx2m47abssqrtcos = x1mx2m47abssqrt.cos();
let x1hpx2p47 = x1 / n2 + x2 + n47;
let x1hpx2p47abs = x1hpx2p47.abs();
let x1hpx2p47abssqrt = x1hpx2p47abs.sqrt();
let x1hpx2p47abssqrtsin = x1hpx2p47abssqrt.sin();
let x1hpx2p47abssqrtcos = x1hpx2p47abssqrt.cos();
let x2mx1p47 = x2 - x1 + n47;
let x2mx1p47abs = x2mx1p47.abs();
let x2mx1p47abssqrt = x2mx1p47abs.sqrt();
let x2mx1p47abssqrtcos = x2mx1p47abssqrt.cos();
let x2mx1p47abssqrtsin = x2mx1p47abssqrt.sin();

let a = if x1mx2m47abs <= eps {
n0
} else {
(x1 * x1mx2m47abssqrtsin) / (n4 * x1mx2m47abs)
- (x1mx2m47 * x1mx2m47abssqrtcos) / (x1mx2m47abssqrt.powi(3))
+ (n3 * x1 * x1mx2m47.powi(2) * x1mx2m47abssqrtcos) / (n4 * x1mx2m47abssqrt.powi(7))
} + if x1mx2m47abs <= eps && x1.abs() <= eps {
n0
} else {
-(x1 * x1mx2m47abssqrtcos) / (n2 * x1mx2m47abssqrt.powi(3))
} + if x1hpx2p47abs <= eps {
n0
} else {
(n3 * (x2 + n47) * x1hpx2p47abssqrtcos * x1hpx2p47.powi(2))
/ (n16 * x1hpx2p47abssqrt.powi(7))
+ ((x2 + n47) * x1hpx2p47abssqrtsin) / (n16 * x1hpx2p47abs)
} - if x1hpx2p47abs <= eps && (x1 + n47).abs() <= eps {
n0
} else {
((x2 + n47) * x1hpx2p47abssqrtcos) / (n8 * x1hpx2p47abssqrt.powi(3))
};

let b = if x1hpx2p47abs <= eps {
n0
} else {
((x2 + n47) * x1hpx2p47abssqrtsin) / (n4 * x1hpx2p47abs)
- (x1hpx2p47 * x1hpx2p47abssqrtcos) / (x1hpx2p47abssqrt.powi(3))
+ (n3 * (x2 + n47) * x1hpx2p47.powi(2) * x1hpx2p47abssqrtcos)
/ (n4 * x1hpx2p47abssqrt.powi(7))
} - if x1hpx2p47abs <= eps && (x2 + n47).abs() <= eps {
n0
} else {
((x2 + n47) * x1hpx2p47abssqrtcos) / (n2 * x1hpx2p47abssqrt.powi(3))
} + if x2mx1p47abs <= eps {
n0
} else {
(x1 * x2mx1p47abssqrtsin) / (n4 * x2mx1p47abs)
+ (n3 * x1 * x2mx1p47.powi(2) * x2mx1p47abssqrtcos) / (n4 * x2mx1p47abssqrt.powi(7))
} - if x2mx1p47abs <= eps && x1.abs() <= eps {
n0
} else {
(x1 * x2mx1p47abssqrtcos) / (n2 * x2mx1p47abssqrt.powi(3))
};

let offdiag = if x1hpx2p47abs <= eps {
n0
} else {
((x2 + n47) * x1hpx2p47abssqrtsin) / (n8 * x1hpx2p47abs)
- (x1hpx2p47 * x1hpx2p47abssqrtcos) / (n4 * x1hpx2p47abssqrt.powi(3))
+ (n3 * (x2 + n47) * x1hpx2p47.powi(2) * x1hpx2p47abssqrtcos)
/ (n8 * x1hpx2p47abssqrt.powi(7))
} - if x1hpx2p47abs <= eps && (x2 + n47).abs() <= eps {
n0
} else {
((x2 + n47) * x1hpx2p47abssqrtcos) / (n4 * x1hpx2p47abssqrt.powi(3))
} + if x2mx1p47abs <= eps {
n0
} else {
-(x1 * x2mx1p47 * x2mx1p47abssqrtsin) / (n4 * x2mx1p47 * x2mx1p47abs)
- (x2mx1p47 * x2mx1p47abssqrtcos) / (n2 * x2mx1p47abssqrt.powi(3))
- (n3 * x1 * x2mx1p47.powi(2) * x2mx1p47abssqrtcos) / (n4 * x2mx1p47abssqrt.powi(7))
} + if x2mx1p47abs <= eps && x1.abs() <= eps {
n0
} else {
(x1 * x2mx1p47abssqrtcos) / (n2 * x2mx1p47abssqrt.powi(3))
};

[[a, offdiag], [offdiag, b]]
}

#[cfg(test)]
mod tests {
use super::*;
use approx::assert_relative_eq;
use finitediff::FiniteDiff;
use proptest::prelude::*;
use std::{f32, f64};

#[test]
Expand All @@ -61,4 +219,37 @@ mod tests {
epsilon = f64::EPSILON
);
}

proptest! {
#[test]
fn test_eggholder_derivative_finitediff(a in -512.0..512.0, b in -512.0..512.0) {
let param = [a, b];
let derivative = eggholder_derivative(&param);
let derivative_fd = Vec::from(param).central_diff(&|x| eggholder(&[x[0], x[1]]));
for i in 0..derivative.len() {
assert_relative_eq!(derivative[i], derivative_fd[i], epsilon = 1e-4);
}
}
}

proptest! {
#[test]
fn test_eggholder_hessian_finitediff(a in -512.0..512.0, b in -512.0..512.0) {
let param = [a, b];
let hessian = eggholder_hessian(&param);
let hessian_fd =
Vec::from(param).central_hessian(&|x| eggholder_derivative(&[x[0], x[1]]).to_vec());
let n = hessian.len();
println!("1: {hessian:?} at {a}/{b}");
println!("2: {hessian_fd:?} at {a}/{b}");
for i in 0..n {
assert_eq!(hessian[i].len(), n);
for j in 0..n {
if hessian_fd[i][j].is_finite() {
assert_relative_eq!(hessian[i][j], hessian_fd[i][j], epsilon = 1e-5);
}
}
}
}
}
}

0 comments on commit 224bb46

Please sign in to comment.