Skip to content

Commit

Permalink
Continued restructuring, removed FiniteDiff trait, added array impl
Browse files Browse the repository at this point in the history
  • Loading branch information
stefan-k committed Mar 5, 2024
1 parent 1f17b97 commit a55fa5d
Show file tree
Hide file tree
Showing 13 changed files with 2,120 additions and 777 deletions.
100 changes: 100 additions & 0 deletions crates/finitediff/src/array/diff.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
// Copyright 2018-2020 argmin developers
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.

use num::Float;
use num::FromPrimitive;

use crate::utils::mod_and_calc_const;

pub fn forward_diff_const<const N: usize, F>(x: &[F; N], f: &dyn Fn(&[F; N]) -> F) -> [F; N]
where
F: Float + FromPrimitive,
{
let fx = (f)(x);
let mut xt = *x;
let eps_sqrt = F::epsilon().sqrt();
let mut out = [F::from_f64(0.0).unwrap(); N];
out.iter_mut()
.enumerate()
.map(|(i, o)| {
let fx1 = mod_and_calc_const(&mut xt, f, i, eps_sqrt);
*o = (fx1 - fx) / eps_sqrt;
})
.count();
out
}

pub fn central_diff_const<const N: usize, F>(x: &[F; N], f: &dyn Fn(&[F; N]) -> F) -> [F; N]
where
F: Float + FromPrimitive,
{
let mut xt = *x;
let eps_cbrt = F::epsilon().cbrt();
let mut out = [F::from_f64(0.0).unwrap(); N];
out.iter_mut()
.enumerate()
.map(|(i, o)| {
let fx1 = mod_and_calc_const(&mut xt, f, i, eps_cbrt);
let fx2 = mod_and_calc_const(&mut xt, f, i, -eps_cbrt);
*o = (fx1 - fx2) / (F::from_f64(2.0).unwrap() * eps_cbrt);
})
.count();
out
}

#[cfg(test)]
mod tests {
use super::*;

const COMP_ACC: f64 = 1e-6;

fn f(x: &[f64; 2]) -> f64 {
x[0] + x[1].powi(2)
}

fn f2(x: &[f64; 2]) -> f64 {
x[0] + x[1].powi(2)
}

#[test]
fn test_forward_diff_const_f64() {
let p = [1.0f64, 1.0f64];
let grad = forward_diff_const(&p, &f2);
let res = [1.0f64, 2.0];

(0..2)
.map(|i| assert!((res[i] - grad[i]).abs() < COMP_ACC))
.count();

let p = [1.0f64, 2.0f64];
let grad = forward_diff_const(&p, &f2);
let res = [1.0f64, 4.0];

(0..2)
.map(|i| assert!((res[i] - grad[i]).abs() < COMP_ACC))
.count();
}

#[test]
fn test_central_diff_vec_f64() {
let p = [1.0f64, 1.0f64];
let grad = central_diff_const(&p, &f);
let res = [1.0f64, 2.0];

(0..2)
.map(|i| assert!((res[i] - grad[i]).abs() < COMP_ACC))
.count();

let p = [1.0f64, 2.0f64];
let grad = central_diff_const(&p, &f);
let res = [1.0f64, 4.0];

(0..2)
.map(|i| assert!((res[i] - grad[i]).abs() < COMP_ACC))
.count();
}
}
Loading

0 comments on commit a55fa5d

Please sign in to comment.