Skip to content

Commit

Permalink
Merge branch 'main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
stefan-k authored Nov 4, 2023
2 parents 00822bb + 13a39ea commit bdda07f
Show file tree
Hide file tree
Showing 14 changed files with 317 additions and 50 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
|
<a href="https://argmin-rs.github.io/argmin/argmin/">Docs (main branch)</a>
|
<a href="https://github.com/argmin-rs/argmin/tree/v0.5.0/examples">Examples (latest release)</a>
<a href="https://github.com/argmin-rs/argmin/tree/argmin-v0.8.1/argmin/examples">Examples (latest release)</a>
|
<a href="https://github.com/argmin-rs/argmin/tree/main/argmin/examples">Examples (main branch)</a>
</p>
Expand Down
18 changes: 18 additions & 0 deletions argmin-math/src/ndarray_m/inv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,14 @@ macro_rules! make_inv {
Ok(<Self as Inverse>::inv(&self)?)
}
}

// inverse for scalars (1d solvers)
impl ArgminInv<$t> for $t {
#[inline]
fn inv(&self) -> Result<$t, Error> {
Ok(1.0 / self)
}
}
};
}

Expand Down Expand Up @@ -60,6 +68,16 @@ mod tests {
}
}
}

item! {
#[test]
fn [<test_inv_scalar_ $t>]() {
let a = 2.0;
let target = 0.5;
let res = <$t as ArgminInv<$t>>::inv(&a).unwrap();
assert!(((res - target) as f64).abs() < 0.000001);
}
}
};
}

Expand Down
10 changes: 9 additions & 1 deletion argmin/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ serde1 = ["serde", "serde_json", "rand/serde1", "bincode", "slog-json", "rand_xo
_ndarrayl = ["argmin-math/ndarray_latest-serde", "argmin-math/_dev_linalg_latest"]
_nalgebral = ["argmin-math/nalgebra_latest-serde"]
# When adding new features, please consider adding them to either `full` (for users)
# or `_full_dev` (only for local development, tesing and computing test coverage).
# or `_full_dev` (only for local development, testing and computing test coverage).
full = ["default", "slog-logger", "serde1", "ctrlc"]
_full_dev = ["full", "_ndarrayl", "_nalgebral"]

Expand Down Expand Up @@ -129,6 +129,10 @@ required-features = ["argmin-math/nalgebra_latest-serde", "slog-logger"]
name = "morethuente"
required-features = ["slog-logger"]

[[example]]
name = "neldermead-cubic"
required-features = ["slog-logger"]

[[example]]
name = "neldermead"
required-features = ["argmin-math/ndarray_latest-serde", "slog-logger"]
Expand Down Expand Up @@ -177,6 +181,10 @@ required-features = ["argmin-math/ndarray_latest-serde", "slog-logger"]
name = "steepestdescent"
required-features = ["slog-logger"]

[[example]]
name = "steepestdescent_manifold"
required-features = ["slog-logger"]

[[example]]
name = "trustregion_nd"
required-features = ["argmin-math/ndarray_latest-serde", "slog-logger"]
Expand Down
127 changes: 127 additions & 0 deletions argmin/examples/neldermead-cubic.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
// Copyright 2018-2022 argmin developers
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.

//! A (hopefully) simple example of using Nelder-Mead to find the roots of a
//! cubic polynomial.
//!
//! You can run this example with:
//! `cargo run --example neldermead-cubic --features slog-logger`
use argmin::core::observers::{ObserverMode, SlogLogger};
use argmin::core::{CostFunction, Error, Executor, State};
use argmin::solver::neldermead::NelderMead;

/// Coefficients describing a cubic `f(x) = ax^3 + bx^2 + cx + d`
#[derive(Clone, Copy)]
struct Cubic {
    /// Coefficient of the `x^3` term
    a: f64,
    /// Coefficient of the `x^2` term
    b: f64,
    /// Coefficient of the `x` term
    c: f64,
    /// Coefficient of the `x^0` (constant) term
    d: f64,
}

impl Cubic {
/// Evaluate the cubic at `x`.
fn eval(self, x: f64) -> f64 {
self.a * x.powi(3) + self.b * x.powi(2) + self.c * x + self.d
}
}

impl CostFunction for Cubic {
    type Param = f64;
    type Output = f64;

    /// Cost of a guess `p` for a root: the polynomial evaluated at `p`,
    /// squared. Squaring keeps the cost (1) non-negative and (2) growing
    /// the further `p` is from an actual root, so minimizing this cost
    /// drives the guess towards a root of the polynomial.
    fn cost(&self, p: &Self::Param) -> Result<Self::Output, Error> {
        let value = self.eval(*p);
        Ok(value.powi(2))
    }
}

fn run() -> Result<(), Error> {
    // Define the cost function. This needs to be something with an
    // implementation of `CostFunction`; in this case, the impl is right
    // above. Here, our cubic is `(x-2)(x+2)(x-5)`, i.e. its roots are
    // -2, +2 and +5; see
    // <https://www.wolframalpha.com/input?i=%28x-2%29%28x%2B2%29%28x-5%29>
    // for more info.
    let cost = Cubic {
        a: 1.0,
        b: -5.0,
        c: -4.0,
        d: 20.0,
    };

    // Let's find a root of the cubic (+5).
    {
        // Set up solver -- note that the proper choice of the vertices is very
        // important! This example should find 5, because our vertices are 6 and 7.
        let solver = NelderMead::new(vec![6.0, 7.0]).with_sd_tolerance(0.0001)?;

        // Run solver for at most 100 iterations, logging progress to the terminal.
        let res = Executor::new(cost, solver)
            .configure(|state| state.max_iters(100))
            .add_observer(SlogLogger::term(), ObserverMode::Always)
            .run()?;

        // Wait a second (lets the logger flush everything before printing again)
        std::thread::sleep(std::time::Duration::from_secs(1));

        // Print the best parameter found (the root estimate).
        println!(
            "Polynomial root: {}",
            res.state.get_best_param().expect("Found a root")
        );
    }

    // Now find -2, by starting from vertices on its other side (-3 and -4).
    {
        let solver = NelderMead::new(vec![-3.0, -4.0]).with_sd_tolerance(0.0001)?;
        let res = Executor::new(cost, solver)
            .configure(|state| state.max_iters(100))
            .add_observer(SlogLogger::term(), ObserverMode::Always)
            .run()?;
        // Let the logger flush, then print the full result and the root.
        std::thread::sleep(std::time::Duration::from_secs(1));
        println!("{res}");
        println!(
            "Polynomial root: {}",
            res.state.get_best_param().expect("Found a root")
        );
    }

    // This example will find +2, even though it might look like we're trying to
    // find +5.
    {
        let solver = NelderMead::new(vec![4.0, 6.0]).with_sd_tolerance(0.0001)?;
        let res = Executor::new(cost, solver)
            .configure(|state| state.max_iters(100))
            .add_observer(SlogLogger::term(), ObserverMode::Always)
            .run()?;
        std::thread::sleep(std::time::Duration::from_secs(1));
        println!("{res}");
        println!(
            "Polynomial root: {}",
            res.state.get_best_param().expect("Found a root")
        );
    }

    Ok(())
}

/// Entry point: run the example and exit with a non-zero status on error.
fn main() {
    match run() {
        Ok(()) => {}
        Err(ref e) => {
            println!("{e}");
            std::process::exit(1);
        }
    }
}
2 changes: 1 addition & 1 deletion argmin/examples/steepestdescent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ impl Gradient for Rosenbrock {
}

fn run() -> Result<(), Error> {
// Define cost function (must implement `ArgminOperator`)
// Define cost function (must implement `CostFunction` and `Gradient`)
let cost = Rosenbrock { a: 1.0, b: 100.0 };

// Define initial parameter vector
Expand Down
94 changes: 94 additions & 0 deletions argmin/examples/steepestdescent_manifold.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
// Copyright 2018-2022 argmin developers
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.

#![allow(unused_imports)]

use argmin::core::observers::{ObserverMode, SlogLogger};
use argmin::core::{CostFunction, Error, Executor, Gradient};
use argmin::solver::gradientdescent::SteepestDescent;
use argmin::solver::linesearch::condition::{ArmijoCondition, LineSearchCondition};
use argmin::solver::linesearch::BacktrackingLineSearch;
use argmin_math::ArgminScaledAdd;

use serde::{Deserialize, Serialize};

/// Problem definition: find the point on the unit circle closest to the
/// target point `(x, y)` (the `cost` impl below measures the squared
/// distance from a circle point to this target).
#[derive(Clone, Copy, Debug)]
struct ClosestPointOnCircle {
    /// x-coordinate of the target point
    x: f64,
    /// y-coordinate of the target point
    y: f64,
}

/// A point on the unit circle, parameterized by its angle.
#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
struct CirclePoint {
    /// Angle in radians; the corresponding point is `(cos(angle), sin(angle))`.
    angle: f64,
}

impl CostFunction for ClosestPointOnCircle {
    type Param = CirclePoint;
    type Output = f64;

    /// Squared Euclidean distance between the point on the unit circle at
    /// angle `p.angle` and the target point `(self.x, self.y)`.
    fn cost(&self, p: &Self::Param) -> Result<Self::Output, Error> {
        let dx = p.angle.cos() - self.x;
        let dy = p.angle.sin() - self.y;
        Ok(dx.powi(2) + dy.powi(2))
    }
}

impl Gradient for ClosestPointOnCircle {
    type Param = CirclePoint;
    type Gradient = f64;

    /// Derivative of the squared-distance cost with respect to the angle
    /// (chain rule applied to the `cos`/`sin` parameterization).
    fn gradient(&self, p: &Self::Param) -> Result<Self::Gradient, Error> {
        let sin_a = p.angle.sin();
        let cos_a = p.angle.cos();
        Ok(2.0 * (cos_a - self.x) * (-sin_a) + 2.0 * (sin_a - self.y) * cos_a)
    }
}

/// A step on the manifold: moving from a circle point by `delta` scaled
/// with step length `alpha` advances its angle.
impl ArgminScaledAdd<f64, f64, CirclePoint> for CirclePoint {
    fn scaled_add(&self, alpha: &f64, delta: &f64) -> Self {
        let step = alpha * delta;
        Self {
            angle: self.angle + step,
        }
    }
}

fn run() -> Result<(), Error> {
    // Define cost function (must implement `CostFunction` and `Gradient`):
    // find the point on the unit circle closest to (1, 1).
    let cost = ClosestPointOnCircle { x: 1.0, y: 1.0 };

    // Define initial parameter vector: angle 0, i.e. the point (1, 0).
    let init_param = CirclePoint { angle: 0.0 };

    // Pick a line search: backtracking with an Armijo (sufficient decrease)
    // condition using factor 0.5.
    let cond = ArmijoCondition::new(0.5)?;
    let linesearch = BacktrackingLineSearch::new(cond);

    // Set up solver
    let solver = SteepestDescent::new(linesearch);

    // Run solver for at most 10 iterations, logging progress to the terminal.
    let res = Executor::new(cost, solver)
        .configure(|state| state.param(init_param).max_iters(10))
        .add_observer(SlogLogger::term(), ObserverMode::Always)
        .run()?;

    // Wait a second (lets the logger flush everything first)
    std::thread::sleep(std::time::Duration::from_secs(1));

    // print result
    println!("{res}");
    Ok(())
}

/// Entry point: run the example; on failure print the error and exit
/// with a non-zero status code.
fn main() {
    let result = run();
    if let Err(ref e) = result {
        println!("{e}");
        std::process::exit(1);
    }
}
4 changes: 2 additions & 2 deletions argmin/src/solver/conjugategradient/cg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -247,10 +247,10 @@ mod tests {

assert_relative_eq!(b[0], 1.0, epsilon = f64::EPSILON);
assert_relative_eq!(b[1], 2.0, epsilon = f64::EPSILON);
let r0 = vec![2.0f64, 2.0];
let r0 = [2.0f64, 2.0];
assert_relative_eq!(r0[0], r.as_ref().unwrap()[0], epsilon = f64::EPSILON);
assert_relative_eq!(r0[1], r.as_ref().unwrap()[1], epsilon = f64::EPSILON);
let pp = vec![-2.0f64, -2.0];
let pp = [-2.0f64, -2.0];
assert_relative_eq!(pp[0], p.as_ref().unwrap()[0], epsilon = f64::EPSILON);
assert_relative_eq!(pp[1], p.as_ref().unwrap()[1], epsilon = f64::EPSILON);
assert_relative_eq!(rtr, 8.0, epsilon = f64::EPSILON);
Expand Down
39 changes: 28 additions & 11 deletions argmin/src/solver/gradientdescent/steepestdescent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

use crate::core::{
ArgminFloat, CostFunction, DeserializeOwnedAlias, Error, Executor, Gradient, IterState,
LineSearch, OptimizationResult, Problem, SerializeAlias, Solver, KV,
LineSearch, OptimizationResult, Problem, SerializeAlias, Solver, State, KV,
};
use argmin_math::ArgminMul;
#[cfg(feature = "serde1")]
Expand Down Expand Up @@ -54,24 +54,27 @@ impl<O, L, P, G, F> Solver<O, IterState<P, G, (), (), (), F>> for SteepestDescen
where
O: CostFunction<Param = P, Output = F> + Gradient<Param = P, Gradient = G>,
P: Clone + SerializeAlias + DeserializeOwnedAlias,
G: Clone + SerializeAlias + DeserializeOwnedAlias + ArgminMul<F, P>,
L: Clone + LineSearch<P, F> + Solver<O, IterState<P, G, (), (), (), F>>,
G: Clone + SerializeAlias + DeserializeOwnedAlias + ArgminMul<F, G>,
L: Clone + LineSearch<G, F> + Solver<O, IterState<P, G, (), (), (), F>>,
F: ArgminFloat,
{
const NAME: &'static str = "Steepest Descent";

fn next_iter(
&mut self,
problem: &mut Problem<O>,
mut state: IterState<P, G, (), (), (), F>,
state: IterState<P, G, (), (), (), F>,
) -> Result<(IterState<P, G, (), (), (), F>, Option<KV>), Error> {
let param_new = state.take_param().ok_or_else(argmin_error_closure!(
NotInitialized,
concat!(
"`SteepestDescent` requires an initial parameter vector. ",
"Please provide an initial guess via `Executor`s `configure` method."
)
))?;
let param_new = state
.get_param()
.ok_or_else(argmin_error_closure!(
NotInitialized,
concat!(
"`SteepestDescent` requires an initial parameter vector. ",
"Please provide an initial guess via `Executor`s `configure` method."
)
))?
.clone();
let new_cost = problem.cost(&param_new)?;
let new_grad = problem.gradient(&param_new)?;

Expand Down Expand Up @@ -153,6 +156,20 @@ mod tests {
);
}

#[test]
fn test_next_iter_prev_param_not_erased() {
    // Regression test: after one `next_iter`, the previous parameter vector
    // must still be stored in the returned state (i.e. `next_iter` must not
    // take/erase it). The final `unwrap` panics if `prev_param` is `None`.
    let linesearch: BacktrackingLineSearch<Vec<f64>, Vec<f64>, ArmijoCondition<f64>, f64> =
        BacktrackingLineSearch::new(ArmijoCondition::new(0.2).unwrap());
    let mut sd = SteepestDescent::new(linesearch);
    let (state, _kv) = sd
        .next_iter(
            &mut Problem::new(TestProblem::new()),
            IterState::new().param(vec![1.0, 2.0]),
        )
        .unwrap();
    state.prev_param.unwrap();
}

#[test]
fn test_next_iter_regression() {
struct SDProblem {}
Expand Down
Loading

0 comments on commit bdda07f

Please sign in to comment.