Skip to content

Commit

Permalink
Be even more generic
Browse files Browse the repository at this point in the history
  • Loading branch information
stefan-k committed Mar 2, 2024
1 parent 234c131 commit 5bdde3e
Show file tree
Hide file tree
Showing 6 changed files with 136 additions and 93 deletions.
28 changes: 14 additions & 14 deletions crates/finitediff/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -651,51 +651,51 @@ where
type OperatorOutput = ndarray::Array1<f64>;

fn forward_diff(&self, f: &dyn Fn(&Self) -> f64) -> Self {
forward_diff_ndarray_f64(self, f)
forward_diff_ndarray(self, f)
}

fn central_diff(&self, f: &dyn Fn(&ndarray::Array1<f64>) -> f64) -> Self {
central_diff_ndarray_f64(self, f)
central_diff_ndarray(self, f)
}

fn forward_jacobian(&self, fs: &dyn Fn(&Self) -> Self::OperatorOutput) -> Self::Jacobian {
forward_jacobian_ndarray_f64(self, fs)
forward_jacobian_ndarray(self, fs)
}

fn central_jacobian(&self, fs: &dyn Fn(&Self) -> Self::OperatorOutput) -> Self::Jacobian {
central_jacobian_ndarray_f64(self, fs)
central_jacobian_ndarray(self, fs)
}

fn forward_jacobian_vec_prod(
&self,
fs: &dyn Fn(&Self) -> Self::OperatorOutput,
p: &Self,
) -> Self {
forward_jacobian_vec_prod_ndarray_f64(self, fs, p)
forward_jacobian_vec_prod_ndarray(self, fs, p)
}

fn central_jacobian_vec_prod(
&self,
fs: &dyn Fn(&Self) -> Self::OperatorOutput,
p: &Self,
) -> Self {
central_jacobian_vec_prod_ndarray_f64(self, fs, p)
central_jacobian_vec_prod_ndarray(self, fs, p)
}

fn forward_jacobian_pert(
&self,
fs: &dyn Fn(&Self) -> Self::OperatorOutput,
pert: &PerturbationVectors,
) -> Self::Jacobian {
forward_jacobian_pert_ndarray_f64(self, fs, pert)
forward_jacobian_pert_ndarray(self, fs, pert)
}

fn central_jacobian_pert(
&self,
fs: &dyn Fn(&Self) -> Self::OperatorOutput,
pert: &PerturbationVectors,
) -> Self::Jacobian {
central_jacobian_pert_ndarray_f64(self, fs, pert)
central_jacobian_pert_ndarray(self, fs, pert)
}

fn forward_hessian(&self, g: &dyn Fn(&Self) -> Self::OperatorOutput) -> Self::Jacobian {
Expand Down Expand Up @@ -825,15 +825,15 @@ mod tests_vec {
#[test]
fn test_forward_diff_vec_f64_trait() {
let grad = x1().forward_diff(&f1);
let res = vec![1.0f64, 2.0];
let res = [1.0f64, 2.0];

for i in 0..2 {
assert!((res[i] - grad[i]).abs() < COMP_ACC)
}

let p = vec![1.0f64, 2.0f64];
let grad = p.forward_diff(&f1);
let res = vec![1.0f64, 4.0];
let res = [1.0f64, 4.0];

for i in 0..2 {
assert!((res[i] - grad[i]).abs() < COMP_ACC)
Expand All @@ -843,15 +843,15 @@ mod tests_vec {
#[test]
fn test_central_diff_vec_f64_trait() {
let grad = x1().central_diff(&f1);
let res = vec![1.0f64, 2.0];
let res = [1.0f64, 2.0];

for i in 0..2 {
assert!((res[i] - grad[i]).abs() < COMP_ACC)
}

let p = vec![1.0f64, 2.0f64];
let grad = p.central_diff(&f1);
let res = vec![1.0f64, 4.0];
let res = [1.0f64, 4.0];

for i in 0..2 {
assert!((res[i] - grad[i]).abs() < COMP_ACC)
Expand Down Expand Up @@ -958,7 +958,7 @@ mod tests_vec {
#[test]
fn test_forward_hessian_vec_prod_vec_f64_trait() {
let hessian = x3().forward_hessian_vec_prod(&g, &p2());
let res = vec![0.0, 6.0, 10.0, 18.0];
let res = [0.0, 6.0, 10.0, 18.0];
// println!("hessian:\n{:#?}", hessian);
// println!("diff:\n{:#?}", diff);
for i in 0..4 {
Expand All @@ -969,7 +969,7 @@ mod tests_vec {
#[test]
fn test_central_hessian_vec_prod_vec_f64_trait() {
let hessian = x3().central_hessian_vec_prod(&g, &p2());
let res = vec![0.0, 6.0, 10.0, 18.0];
let res = [0.0, 6.0, 10.0, 18.0];
// println!("hessian:\n{:#?}", hessian);
// println!("diff:\n{:#?}", diff);
for i in 0..4 {
Expand Down
47 changes: 29 additions & 18 deletions crates/finitediff/src/ndarray_m/diff.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,33 +5,44 @@
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.

use num::{Float, FromPrimitive};

use crate::utils::*;
use crate::EPS_F64;

pub fn forward_diff_ndarray_f64(
x: &ndarray::Array1<f64>,
f: &dyn Fn(&ndarray::Array1<f64>) -> f64,
) -> ndarray::Array1<f64> {
pub fn forward_diff_ndarray<F>(
x: &ndarray::Array1<F>,
f: &dyn Fn(&ndarray::Array1<F>) -> F,
) -> ndarray::Array1<F>
where
F: Float,
{
let eps_sqrt = F::epsilon().sqrt();

let fx = (f)(x);
let mut xt = x.clone();
(0..x.len())
.map(|i| {
let fx1 = mod_and_calc(&mut xt, f, i, EPS_F64.sqrt());
(fx1 - fx) / (EPS_F64.sqrt())
let fx1 = mod_and_calc(&mut xt, f, i, eps_sqrt);
(fx1 - fx) / eps_sqrt
})
.collect()
}

pub fn central_diff_ndarray_f64(
x: &ndarray::Array1<f64>,
f: &dyn Fn(&ndarray::Array1<f64>) -> f64,
) -> ndarray::Array1<f64> {
pub fn central_diff_ndarray<F>(
x: &ndarray::Array1<F>,
f: &dyn Fn(&ndarray::Array1<F>) -> F,
) -> ndarray::Array1<F>
where
F: Float + FromPrimitive,
{
let eps_sqrt = F::epsilon().sqrt();

let mut xt = x.clone();
(0..x.len())
.map(|i| {
let fx1 = mod_and_calc(&mut xt, f, i, EPS_F64.sqrt());
let fx2 = mod_and_calc(&mut xt, f, i, -EPS_F64.sqrt());
(fx1 - fx2) / (2.0 * EPS_F64.sqrt())
let fx1 = mod_and_calc(&mut xt, f, i, eps_sqrt);
let fx2 = mod_and_calc(&mut xt, f, i, -eps_sqrt);
(fx1 - fx2) / (F::from_f64(2.0).unwrap() * eps_sqrt)
})
.collect()
}
Expand All @@ -51,15 +62,15 @@ mod tests {
fn test_forward_diff_ndarray_f64() {
let p = ndarray::Array1::from(vec![1.0f64, 1.0f64]);

let grad = forward_diff_ndarray_f64(&p, &f);
let grad = forward_diff_ndarray(&p, &f);
let res = vec![1.0f64, 2.0];

(0..2)
.map(|i| assert!((res[i] - grad[i]).abs() < COMP_ACC))
.count();

let p = ndarray::Array1::from(vec![1.0f64, 2.0f64]);
let grad = forward_diff_ndarray_f64(&p, &f);
let grad = forward_diff_ndarray(&p, &f);
let res = vec![1.0f64, 4.0];

(0..2)
Expand All @@ -70,15 +81,15 @@ mod tests {
fn test_central_diff_ndarray_f64() {
let p = ndarray::Array1::from(vec![1.0f64, 1.0f64]);

let grad = central_diff_ndarray_f64(&p, &f);
let grad = central_diff_ndarray(&p, &f);
let res = vec![1.0f64, 2.0];

(0..2)
.map(|i| assert!((res[i] - grad[i]).abs() < COMP_ACC))
.count();

let p = ndarray::Array1::from(vec![1.0f64, 2.0f64]);
let grad = central_diff_ndarray_f64(&p, &f);
let grad = central_diff_ndarray(&p, &f);
let res = vec![1.0f64, 4.0];

(0..2)
Expand Down
Loading

0 comments on commit 5bdde3e

Please sign in to comment.