Skip to content

Commit

Permalink
Be even more generic
Browse files Browse the repository at this point in the history
  • Loading branch information
stefan-k committed Mar 2, 2024
1 parent 5bdde3e commit c947cc6
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 70 deletions.
33 changes: 21 additions & 12 deletions crates/finitediff/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -400,11 +400,21 @@ mod utils;
mod vec;

#[cfg(feature = "ndarray")]
use ndarray_m::*;
pub use pert::*;
pub use vec::*;

const EPS_F64: f64 = std::f64::EPSILON;
use ndarray_m::{
central_diff_ndarray, central_hessian_ndarray, central_hessian_vec_prod_ndarray,
central_jacobian_ndarray, central_jacobian_pert_ndarray, central_jacobian_vec_prod_ndarray,
forward_diff_ndarray, forward_hessian_ndarray, forward_hessian_nograd_ndarray,
forward_hessian_nograd_sparse_ndarray, forward_hessian_vec_prod_ndarray,
forward_jacobian_ndarray, forward_jacobian_pert_ndarray, forward_jacobian_vec_prod_ndarray,
};
pub use pert::{PerturbationVector, PerturbationVectors};
pub use vec::{
central_diff_vec, central_hessian_vec, central_hessian_vec_prod_vec, central_jacobian_pert_vec,
central_jacobian_vec, central_jacobian_vec_prod_vec, forward_diff_vec,
forward_hessian_nograd_sparse_vec, forward_hessian_nograd_vec, forward_hessian_vec,
forward_hessian_vec_prod_vec, forward_jacobian_pert_vec, forward_jacobian_vec,
forward_jacobian_vec_prod_vec,
};

pub trait FiniteDiff
where
Expand Down Expand Up @@ -699,39 +709,39 @@ where
}

fn forward_hessian(&self, g: &dyn Fn(&Self) -> Self::OperatorOutput) -> Self::Jacobian {
forward_hessian_ndarray_f64(self, g)
forward_hessian_ndarray(self, g)
}

fn central_hessian(&self, g: &dyn Fn(&Self) -> Self::OperatorOutput) -> Self::Jacobian {
central_hessian_ndarray_f64(self, g)
central_hessian_ndarray(self, g)
}

fn forward_hessian_vec_prod(
&self,
g: &dyn Fn(&Self) -> Self::OperatorOutput,
p: &Self,
) -> Self {
forward_hessian_vec_prod_ndarray_f64(self, g, p)
forward_hessian_vec_prod_ndarray(self, g, p)
}

fn central_hessian_vec_prod(
&self,
g: &dyn Fn(&Self) -> Self::OperatorOutput,
p: &Self,
) -> Self {
central_hessian_vec_prod_ndarray_f64(self, g, p)
central_hessian_vec_prod_ndarray(self, g, p)
}

fn forward_hessian_nograd(&self, f: &dyn Fn(&Self) -> f64) -> Self::Hessian {
forward_hessian_nograd_ndarray_f64(self, f)
forward_hessian_nograd_ndarray(self, f)
}

fn forward_hessian_nograd_sparse(
&self,
f: &dyn Fn(&Self) -> f64,
indices: Vec<[usize; 2]>,
) -> Self::Hessian {
forward_hessian_nograd_sparse_ndarray_f64(self, f, indices)
forward_hessian_nograd_sparse_ndarray(self, f, indices)
}
}

Expand Down Expand Up @@ -1009,7 +1019,6 @@ mod tests_vec {
#[cfg(test)]
mod tests_ndarray {
use super::*;
use ndarray;
use ndarray::{array, Array1};

const COMP_ACC: f64 = 1e-6;
Expand Down
1 change: 0 additions & 1 deletion crates/finitediff/src/ndarray_m/diff.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ where
#[cfg(test)]
mod tests {
use super::*;
use ndarray;

const COMP_ACC: f64 = 1e-6;

Expand Down
145 changes: 89 additions & 56 deletions crates/finitediff/src/ndarray_m/hessian.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,88 +5,115 @@
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.

use std::ops::AddAssign;

use ndarray::Array2;
use ndarray::ScalarOperand;
use num::{Float, FromPrimitive};

use crate::utils::*;
use crate::EPS_F64;
use crate::utils::{mod_and_calc, restore_symmetry_ndarray, KV};

/// I wish this wasn't necessary!
const EPS_F64_NOGRAD: f64 = EPS_F64 * 2.0;
pub fn forward_hessian_ndarray<F>(
x: &ndarray::Array1<F>,
grad: &dyn Fn(&ndarray::Array1<F>) -> ndarray::Array1<F>,
) -> ndarray::Array2<F>
where
F: Float + FromPrimitive,
{
let eps_sqrt = F::epsilon().sqrt();

pub fn forward_hessian_ndarray_f64(
x: &ndarray::Array1<f64>,
grad: &dyn Fn(&ndarray::Array1<f64>) -> ndarray::Array1<f64>,
) -> ndarray::Array2<f64> {
let mut xt = x.clone();
let fx = (grad)(x);
let rn = fx.len();
let n = x.len();
let mut out = Array2::zeros((n, rn));
for i in 0..n {
let fx1 = mod_and_calc(&mut xt, grad, i, EPS_F64.sqrt());
let fx1 = mod_and_calc(&mut xt, grad, i, eps_sqrt);
for j in 0..rn {
out[(i, j)] = (fx1[j] - fx[j]) / EPS_F64.sqrt();
out[(i, j)] = (fx1[j] - fx[j]) / eps_sqrt;
}
}
// restore symmetry
restore_symmetry_ndarray(out)
}

pub fn central_hessian_ndarray_f64(
x: &ndarray::Array1<f64>,
grad: &dyn Fn(&ndarray::Array1<f64>) -> ndarray::Array1<f64>,
) -> ndarray::Array2<f64> {
pub fn central_hessian_ndarray<F>(
x: &ndarray::Array1<F>,
grad: &dyn Fn(&ndarray::Array1<F>) -> ndarray::Array1<F>,
) -> ndarray::Array2<F>
where
F: Float + FromPrimitive,
{
let eps_sqrt = F::epsilon().sqrt();

let mut xt = x.clone();
// TODO: get rid of this!
let fx = (grad)(x);
let rn = fx.len();
let n = x.len();
let mut out = ndarray::Array2::zeros((n, rn));
for i in 0..n {
let fx1 = mod_and_calc(&mut xt, grad, i, EPS_F64.sqrt());
let fx2 = mod_and_calc(&mut xt, grad, i, -EPS_F64.sqrt());
let fx1 = mod_and_calc(&mut xt, grad, i, eps_sqrt);
let fx2 = mod_and_calc(&mut xt, grad, i, -eps_sqrt);
for j in 0..rn {
out[(i, j)] = (fx1[j] - fx2[j]) / (2.0 * EPS_F64.sqrt());
out[(i, j)] = (fx1[j] - fx2[j]) / (F::from_f64(2.0).unwrap() * eps_sqrt);
}
}
// restore symmetry
restore_symmetry_ndarray(out)
}

pub fn forward_hessian_vec_prod_ndarray_f64(
x: &ndarray::Array1<f64>,
grad: &dyn Fn(&ndarray::Array1<f64>) -> ndarray::Array1<f64>,
p: &ndarray::Array1<f64>,
) -> ndarray::Array1<f64> {
pub fn forward_hessian_vec_prod_ndarray<F>(
x: &ndarray::Array1<F>,
grad: &dyn Fn(&ndarray::Array1<F>) -> ndarray::Array1<F>,
p: &ndarray::Array1<F>,
) -> ndarray::Array1<F>
where
F: Float + ScalarOperand,
{
let eps_sqrt = F::epsilon().sqrt();

let fx = (grad)(x);
let x1 = x + &(p.mapv(|pi| pi * EPS_F64.sqrt()));
let x1 = x + &(p.mapv(|pi| pi * eps_sqrt));
let fx1 = (grad)(&x1);
(fx1 - fx) / EPS_F64.sqrt()
(fx1 - fx) / eps_sqrt
}

pub fn central_hessian_vec_prod_ndarray_f64(
x: &ndarray::Array1<f64>,
grad: &dyn Fn(&ndarray::Array1<f64>) -> ndarray::Array1<f64>,
p: &ndarray::Array1<f64>,
) -> ndarray::Array1<f64> {
let x1 = x + &(p.mapv(|pi| pi * EPS_F64.sqrt()));
let x2 = x - &(p.mapv(|pi| pi * EPS_F64.sqrt()));
pub fn central_hessian_vec_prod_ndarray<F>(
x: &ndarray::Array1<F>,
grad: &dyn Fn(&ndarray::Array1<F>) -> ndarray::Array1<F>,
p: &ndarray::Array1<F>,
) -> ndarray::Array1<F>
where
F: Float + FromPrimitive + ScalarOperand,
{
let eps_sqrt = F::epsilon().sqrt();

let x1 = x + &(p.mapv(|pi| pi * eps_sqrt));
let x2 = x - &(p.mapv(|pi| pi * eps_sqrt));
let fx1 = (grad)(&x1);
let fx2 = (grad)(&x2);
(fx1 - fx2) / (2.0 * EPS_F64.sqrt())
(fx1 - fx2) / (F::from_f64(2.0).unwrap() * eps_sqrt)
}

pub fn forward_hessian_nograd_ndarray_f64(
x: &ndarray::Array1<f64>,
f: &dyn Fn(&ndarray::Array1<f64>) -> f64,
) -> ndarray::Array2<f64> {
pub fn forward_hessian_nograd_ndarray<F>(
x: &ndarray::Array1<F>,
f: &dyn Fn(&ndarray::Array1<F>) -> F,
) -> ndarray::Array2<F>
where
F: Float + FromPrimitive + AddAssign,
{
// TODO: Check why this is necessary
let eps_nograd = F::from_f64(2.0).unwrap() * F::epsilon();
let eps_sqrt_nograd = eps_nograd.sqrt();

let fx = (f)(x);
let n = x.len();
let mut xt = x.clone();

// Precompute f(x + sqrt(EPS) * e_i) for all i
let fxei: Vec<f64> = (0..n)
.map(|i| mod_and_calc(&mut xt, f, i, EPS_F64_NOGRAD.sqrt()))
let fxei: Vec<F> = (0..n)
.map(|i| mod_and_calc(&mut xt, f, i, eps_sqrt_nograd))
.collect();

let mut out = ndarray::Array2::zeros((n, n));
Expand All @@ -95,12 +122,12 @@ pub fn forward_hessian_nograd_ndarray_f64(
let t = {
let xti = xt[i];
let xtj = xt[j];
xt[i] += EPS_F64_NOGRAD.sqrt();
xt[j] += EPS_F64_NOGRAD.sqrt();
xt[i] += eps_sqrt_nograd;
xt[j] += eps_sqrt_nograd;
let fxij = (f)(&xt);
xt[i] = xti;
xt[j] = xtj;
(fxij - fxei[i] - fxei[j] + fx) / EPS_F64_NOGRAD
(fxij - fxei[i] - fxei[j] + fx) / eps_nograd
};
out[(i, j)] = t;
out[(j, i)] = t;
Expand All @@ -109,11 +136,18 @@ pub fn forward_hessian_nograd_ndarray_f64(
out
}

pub fn forward_hessian_nograd_sparse_ndarray_f64(
x: &ndarray::Array1<f64>,
f: &dyn Fn(&ndarray::Array1<f64>) -> f64,
pub fn forward_hessian_nograd_sparse_ndarray<F>(
x: &ndarray::Array1<F>,
f: &dyn Fn(&ndarray::Array1<F>) -> F,
indices: Vec<[usize; 2]>,
) -> ndarray::Array2<f64> {
) -> ndarray::Array2<F>
where
F: Float + FromPrimitive + AddAssign,
{
// TODO: Check why this is necessary
let eps_nograd = F::from_f64(2.0).unwrap() * F::epsilon();
let eps_sqrt_nograd = eps_nograd.sqrt();

let fx = (f)(x);
let n = x.len();
let mut xt = x.clone();
Expand All @@ -129,24 +163,24 @@ pub fn forward_hessian_nograd_sparse_ndarray_f64(
let mut fxei = KV::new(idxs.len());

for idx in idxs.iter() {
fxei.set(*idx, mod_and_calc(&mut xt, f, *idx, EPS_F64_NOGRAD.sqrt()));
fxei.set(*idx, mod_and_calc(&mut xt, f, *idx, eps_sqrt_nograd));
}

let mut out = ndarray::Array2::zeros((n, n));
for [i, j] in indices {
let t = {
let xti = xt[i];
let xtj = xt[j];
xt[i] += EPS_F64_NOGRAD.sqrt();
xt[j] += EPS_F64_NOGRAD.sqrt();
xt[i] += eps_sqrt_nograd;
xt[j] += eps_sqrt_nograd;
let fxij = (f)(&xt);
xt[i] = xti;
xt[j] = xtj;

let fxi = fxei.get(i).unwrap();
let fxj = fxei.get(j).unwrap();

(fxij - fxi - fxj + fx) / EPS_F64_NOGRAD
(fxij - fxi - fxj + fx) / eps_nograd
};
out[(i, j)] = t;
out[(j, i)] = t;
Expand All @@ -157,7 +191,6 @@ pub fn forward_hessian_nograd_sparse_ndarray_f64(
#[cfg(test)]
mod tests {
use super::*;
use ndarray;
use ndarray::{array, Array1};

const COMP_ACC: f64 = 1e-6;
Expand Down Expand Up @@ -193,7 +226,7 @@ mod tests {

#[test]
fn test_forward_hessian_ndarray_f64() {
let hessian = forward_hessian_ndarray_f64(&x(), &g);
let hessian = forward_hessian_ndarray(&x(), &g);
let res = res1();
// println!("hessian:\n{:#?}", hessian);
// println!("diff:\n{:#?}", diff);
Expand All @@ -206,7 +239,7 @@ mod tests {

#[test]
fn test_central_hessian_ndarray_f64() {
let hessian = central_hessian_ndarray_f64(&x(), &g);
let hessian = central_hessian_ndarray(&x(), &g);
let res = res1();
// println!("hessian:\n{:#?}", hessian);
// println!("diff:\n{:#?}", diff);
Expand All @@ -219,7 +252,7 @@ mod tests {

#[test]
fn test_forward_hessian_vec_prod_ndarray_f64() {
let hessian = forward_hessian_vec_prod_ndarray_f64(&x(), &g, &p());
let hessian = forward_hessian_vec_prod_ndarray(&x(), &g, &p());
let res = res2();
// println!("hessian:\n{:#?}", hessian);
// println!("diff:\n{:#?}", diff);
Expand All @@ -230,7 +263,7 @@ mod tests {

#[test]
fn test_central_hessian_vec_prod_ndarray_f64() {
let hessian = central_hessian_vec_prod_ndarray_f64(&x(), &g, &p());
let hessian = central_hessian_vec_prod_ndarray(&x(), &g, &p());
let res = res2();
// println!("hessian:\n{:#?}", hessian);
// println!("diff:\n{:#?}", diff);
Expand All @@ -241,7 +274,7 @@ mod tests {

#[test]
fn test_forward_hessian_nograd_ndarray_f64() {
let hessian = forward_hessian_nograd_ndarray_f64(&x(), &f);
let hessian = forward_hessian_nograd_ndarray(&x(), &f);
let res = res1();
// println!("hessian:\n{:#?}", hessian);
for i in 0..4 {
Expand All @@ -254,7 +287,7 @@ mod tests {
#[test]
fn test_forward_hessian_nograd_sparse_ndarray_f64() {
let indices = vec![[1, 1], [2, 3], [3, 3]];
let hessian = forward_hessian_nograd_sparse_ndarray_f64(&x(), &f, indices);
let hessian = forward_hessian_nograd_sparse_ndarray(&x(), &f, indices);
let res = res1();
// println!("hessian:\n{:#?}", hessian);
// println!("diff:\n{:#?}", diff);
Expand Down
1 change: 0 additions & 1 deletion crates/finitediff/src/ndarray_m/jacobian.rs
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,6 @@ mod tests {
use crate::PerturbationVector;

use super::*;
use ndarray;
use ndarray::{array, Array1};

const COMP_ACC: f64 = 1e-6;
Expand Down

0 comments on commit c947cc6

Please sign in to comment.