Skip to content

Commit

Permalink
switch order of generic parameters on IterState to make F last
Browse files Browse the repository at this point in the history
  • Loading branch information
gmilleramilar committed Apr 6, 2023
1 parent 5afc596 commit 00822bb
Show file tree
Hide file tree
Showing 27 changed files with 147 additions and 156 deletions.
6 changes: 3 additions & 3 deletions argmin/src/core/state/iterstate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ use std::collections::HashMap;
/// * termination status
#[derive(Clone, Default, Debug, Eq, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct IterState<P, G, J, H, F, R> {
pub struct IterState<P, G, J, H, R, F> {
/// Current parameter vector
pub param: Option<P>,
/// Previous parameter vector
Expand Down Expand Up @@ -86,7 +86,7 @@ pub struct IterState<P, G, J, H, F, R> {
pub termination_status: TerminationStatus,
}

impl<P, G, J, H, F, R> IterState<P, G, J, H, F, R>
impl<P, G, J, H, R, F> IterState<P, G, J, H, R, F>
where
Self: State<Float = F>,
F: ArgminFloat,
Expand Down Expand Up @@ -904,7 +904,7 @@ where
}
}

impl<P, G, J, H, F, R> State for IterState<P, G, J, H, F, R>
impl<P, G, J, H, R, F> State for IterState<P, G, J, H, R, F>
where
P: Clone,
F: ArgminFloat,
Expand Down
6 changes: 3 additions & 3 deletions argmin/src/core/test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -326,14 +326,14 @@ impl TestSolver {
}
}

impl<O> Solver<O, IterState<Vec<f64>, (), (), (), f64, ()>> for TestSolver {
impl<O> Solver<O, IterState<Vec<f64>, (), (), (), (), f64>> for TestSolver {
const NAME: &'static str = "TestSolver";

fn next_iter(
&mut self,
_problem: &mut Problem<O>,
state: IterState<Vec<f64>, (), (), (), f64, ()>,
) -> Result<(IterState<Vec<f64>, (), (), (), f64, ()>, Option<KV>), Error> {
state: IterState<Vec<f64>, (), (), (), (), f64>,
) -> Result<(IterState<Vec<f64>, (), (), (), (), f64>, Option<KV>), Error> {
Ok((state, None))
}
}
10 changes: 5 additions & 5 deletions argmin/src/solver/brent/brentopt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ impl<F: ArgminFloat> BrentOpt<F> {
}
}

impl<O, F, R> Solver<O, IterState<F, (), (), (), F, R>> for BrentOpt<F>
impl<O, F> Solver<O, IterState<F, (), (), (), (), F>> for BrentOpt<F>
where
O: CostFunction<Param = F, Output = F>,
F: ArgminFloat,
Expand All @@ -108,8 +108,8 @@ where
&mut self,
problem: &mut Problem<O>,
// BrentOpt maintains its own state
state: IterState<F, (), (), (), F, R>,
) -> Result<(IterState<F, (), (), (), F, R>, Option<KV>), Error> {
state: IterState<F, (), (), (), (), F>,
) -> Result<(IterState<F, (), (), (), (), F>, Option<KV>), Error> {
let u = self.a + self.c * (self.b - self.a);
self.v = u;
self.w = u;
Expand All @@ -125,8 +125,8 @@ where
&mut self,
problem: &mut Problem<O>,
// BrentOpt maintains its own state
state: IterState<F, (), (), (), F, R>,
) -> Result<(IterState<F, (), (), (), F, R>, Option<KV>), Error> {
state: IterState<F, (), (), (), (), F>,
) -> Result<(IterState<F, (), (), (), (), F>, Option<KV>), Error> {
let two = float!(2f64);
let tol = self.eps * self.x.abs() + self.t;
let m = (self.a + self.b) / two;
Expand Down
10 changes: 5 additions & 5 deletions argmin/src/solver/brent/brentroot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ impl<F: ArgminFloat> BrentRoot<F> {
}
}

impl<O, F, R> Solver<O, IterState<F, (), (), (), F, R>> for BrentRoot<F>
impl<O, F> Solver<O, IterState<F, (), (), (), (), F>> for BrentRoot<F>
where
O: CostFunction<Param = F, Output = F>,
F: ArgminFloat,
Expand All @@ -86,8 +86,8 @@ where
&mut self,
problem: &mut Problem<O>,
// BrentRoot maintains its own state
state: IterState<F, (), (), (), F, R>,
) -> Result<(IterState<F, (), (), (), F, R>, Option<KV>), Error> {
state: IterState<F, (), (), (), (), F>,
) -> Result<(IterState<F, (), (), (), (), F>, Option<KV>), Error> {
self.fa = problem.cost(&self.a)?;
self.fb = problem.cost(&self.b)?;
if self.fa * self.fb > float!(0.0) {
Expand All @@ -101,8 +101,8 @@ where
&mut self,
problem: &mut Problem<O>,
// BrentRoot maintains its own state
state: IterState<F, (), (), (), F, R>,
) -> Result<(IterState<F, (), (), (), F, R>, Option<KV>), Error> {
state: IterState<F, (), (), (), (), F>,
) -> Result<(IterState<F, (), (), (), (), F>, Option<KV>), Error> {
if (self.fb > float!(0.0) && self.fc > float!(0.0))
|| self.fb < float!(0.0) && self.fc < float!(0.0)
{
Expand Down
10 changes: 5 additions & 5 deletions argmin/src/solver/conjugategradient/cg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ where
}
}

impl<P, O, F, R> Solver<O, IterState<P, (), (), (), F, R>> for ConjugateGradient<P, F>
impl<P, O, F> Solver<O, IterState<P, (), (), (), (), F>> for ConjugateGradient<P, F>
where
O: Operator<Param = P, Output = P>,
P: Clone
Expand All @@ -106,8 +106,8 @@ where
fn init(
&mut self,
problem: &mut Problem<O>,
state: IterState<P, (), (), (), F, R>,
) -> Result<(IterState<P, (), (), (), F, R>, Option<KV>), Error> {
state: IterState<P, (), (), (), (), F>,
) -> Result<(IterState<P, (), (), (), (), F>, Option<KV>), Error> {
let init_param = state.get_param().ok_or_else(argmin_error_closure!(
NotInitialized,
concat!(
Expand All @@ -127,8 +127,8 @@ where
fn next_iter(
&mut self,
problem: &mut Problem<O>,
state: IterState<P, (), (), (), F, R>,
) -> Result<(IterState<P, (), (), (), F, R>, Option<KV>), Error> {
state: IterState<P, (), (), (), (), F>,
) -> Result<(IterState<P, (), (), (), (), F>, Option<KV>), Error> {
let p = self.p.take().ok_or_else(argmin_error_closure!(
PotentialBug,
"`ConjugateGradient`: Field `p` not set"
Expand Down
13 changes: 6 additions & 7 deletions argmin/src/solver/conjugategradient/nonlinear_cg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ where
}
}

impl<O, P, G, L, B, F, R> Solver<O, IterState<P, G, (), (), F, R>>
impl<O, P, G, L, B, F> Solver<O, IterState<P, G, (), (), (), F>>
for NonlinearConjugateGradient<P, L, B, F>
where
O: CostFunction<Param = P, Output = F> + Gradient<Param = P, Gradient = G>,
Expand All @@ -129,18 +129,17 @@ where
+ ArgminMul<F, P>
+ ArgminDot<G, F>
+ ArgminL2Norm<F>,
L: Clone + LineSearch<P, F> + Solver<O, IterState<P, G, (), (), F, R>>,
L: Clone + LineSearch<P, F> + Solver<O, IterState<P, G, (), (), (), F>>,
B: NLCGBetaUpdate<G, P, F>,
F: ArgminFloat,
R: Clone + SerializeAlias + DeserializeOwnedAlias,
{
const NAME: &'static str = "Nonlinear Conjugate Gradient";

fn init(
&mut self,
problem: &mut Problem<O>,
state: IterState<P, G, (), (), F, R>,
) -> Result<(IterState<P, G, (), (), F, R>, Option<KV>), Error> {
state: IterState<P, G, (), (), (), F>,
) -> Result<(IterState<P, G, (), (), (), F>, Option<KV>), Error> {
let param = state.get_param().ok_or_else(argmin_error_closure!(
NotInitialized,
concat!(
Expand All @@ -157,8 +156,8 @@ where
fn next_iter(
&mut self,
problem: &mut Problem<O>,
mut state: IterState<P, G, (), (), F, R>,
) -> Result<(IterState<P, G, (), (), F, R>, Option<KV>), Error> {
mut state: IterState<P, G, (), (), (), F>,
) -> Result<(IterState<P, G, (), (), (), F>, Option<KV>), Error> {
let p = self.p.as_ref().ok_or_else(argmin_error_closure!(
PotentialBug,
"`NonlinearConjugateGradient`: Field `p` not set"
Expand Down
10 changes: 5 additions & 5 deletions argmin/src/solver/gaussnewton/gaussnewton_linesearch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ impl<L, F: ArgminFloat> GaussNewtonLS<L, F> {
}
}

impl<O, L, F, P, G, J, U, R> Solver<O, IterState<P, G, J, (), F, R>> for GaussNewtonLS<L, F>
impl<O, L, F, P, G, J, U, R> Solver<O, IterState<P, G, J, (), R, F>> for GaussNewtonLS<L, F>
where
O: Operator<Param = P, Output = U> + Jacobian<Param = P, Jacobian = J>,
P: Clone + SerializeAlias + DeserializeOwnedAlias + ArgminMul<F, P>,
Expand All @@ -95,7 +95,7 @@ where
+ ArgminDot<J, J>
+ ArgminDot<G, P>
+ ArgminDot<U, G>,
L: Clone + LineSearch<P, F> + Solver<LineSearchProblem<O, F>, IterState<P, G, (), (), F, R>>,
L: Clone + LineSearch<P, F> + Solver<LineSearchProblem<O, F>, IterState<P, G, (), (), R, F>>,
F: ArgminFloat,
R: Clone + SerializeAlias + DeserializeOwnedAlias,
{
Expand All @@ -104,8 +104,8 @@ where
fn next_iter(
&mut self,
problem: &mut Problem<O>,
mut state: IterState<P, G, J, (), F, R>,
) -> Result<(IterState<P, G, J, (), F, R>, Option<KV>), Error> {
mut state: IterState<P, G, J, (), R, F>,
) -> Result<(IterState<P, G, J, (), R, F>, Option<KV>), Error> {
let param = state.take_param().ok_or_else(argmin_error_closure!(
NotInitialized,
concat!(
Expand Down Expand Up @@ -167,7 +167,7 @@ where
))
}

fn terminate(&mut self, state: &IterState<P, G, J, (), F, R>) -> TerminationStatus {
fn terminate(&mut self, state: &IterState<P, G, J, (), R, F>) -> TerminationStatus {
if (state.get_prev_cost() - state.get_cost()).abs() < self.tol {
return TerminationStatus::Terminated(TerminationReason::SolverConverged);
}
Expand Down
8 changes: 4 additions & 4 deletions argmin/src/solver/gaussnewton/gaussnewton_method.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ impl<F: ArgminFloat> Default for GaussNewton<F> {
}
}

impl<O, F, P, J, U, R> Solver<O, IterState<P, (), J, (), F, R>> for GaussNewton<F>
impl<O, F, P, J, U, R> Solver<O, IterState<P, (), J, (), R, F>> for GaussNewton<F>
where
O: Operator<Param = P, Output = U> + Jacobian<Param = P, Jacobian = J>,
P: Clone + ArgminSub<P, P> + ArgminMul<F, P>,
Expand All @@ -127,8 +127,8 @@ where
fn next_iter(
&mut self,
problem: &mut Problem<O>,
state: IterState<P, (), J, (), F, R>,
) -> Result<(IterState<P, (), J, (), F, R>, Option<KV>), Error> {
state: IterState<P, (), J, (), R, F>,
) -> Result<(IterState<P, (), J, (), R, F>, Option<KV>), Error> {
let param = state.get_param().ok_or_else(argmin_error_closure!(
NotInitialized,
concat!(
Expand All @@ -151,7 +151,7 @@ where
Ok((state.param(new_param).cost(residuals.l2_norm()), None))
}

fn terminate(&mut self, state: &IterState<P, (), J, (), F, R>) -> TerminationStatus {
fn terminate(&mut self, state: &IterState<P, (), J, (), R, F>) -> TerminationStatus {
if (state.get_prev_cost() - state.get_cost()).abs() < self.tol {
return TerminationStatus::Terminated(TerminationReason::SolverConverged);
}
Expand Down
12 changes: 6 additions & 6 deletions argmin/src/solver/goldensectionsearch/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ where
}
}

impl<O, F, R> Solver<O, IterState<F, (), (), (), F, R>> for GoldenSectionSearch<F>
impl<O, F> Solver<O, IterState<F, (), (), (), (), F>> for GoldenSectionSearch<F>
where
O: CostFunction<Param = F, Output = F>,
F: ArgminFloat,
Expand All @@ -144,8 +144,8 @@ where
fn init(
&mut self,
problem: &mut Problem<O>,
mut state: IterState<F, (), (), (), F, R>,
) -> Result<(IterState<F, (), (), (), F, R>, Option<KV>), Error> {
mut state: IterState<F, (), (), (), (), F>,
) -> Result<(IterState<F, (), (), (), (), F>, Option<KV>), Error> {
let init_estimate = state.take_param().ok_or_else(argmin_error_closure!(
NotInitialized,
concat!(
Expand Down Expand Up @@ -181,8 +181,8 @@ where
fn next_iter(
&mut self,
problem: &mut Problem<O>,
state: IterState<F, (), (), (), F, R>,
) -> Result<(IterState<F, (), (), (), F, R>, Option<KV>), Error> {
state: IterState<F, (), (), (), (), F>,
) -> Result<(IterState<F, (), (), (), (), F>, Option<KV>), Error> {
if self.f2 < self.f1 {
self.x0 = self.x1;
self.x1 = self.x2;
Expand All @@ -203,7 +203,7 @@ where
}
}

fn terminate(&mut self, _state: &IterState<F, (), (), (), F, R>) -> TerminationStatus {
fn terminate(&mut self, _state: &IterState<F, (), (), (), (), F>) -> TerminationStatus {
if self.tolerance * (self.x1.abs() + self.x2.abs()) >= (self.x3 - self.x0).abs() {
return TerminationStatus::Terminated(TerminationReason::SolverConverged);
}
Expand Down
9 changes: 4 additions & 5 deletions argmin/src/solver/gradientdescent/steepestdescent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,22 +50,21 @@ impl<L> SteepestDescent<L> {
}
}

impl<O, L, P, G, F, R> Solver<O, IterState<P, G, (), (), F, R>> for SteepestDescent<L>
impl<O, L, P, G, F> Solver<O, IterState<P, G, (), (), (), F>> for SteepestDescent<L>
where
O: CostFunction<Param = P, Output = F> + Gradient<Param = P, Gradient = G>,
P: Clone + SerializeAlias + DeserializeOwnedAlias,
G: Clone + SerializeAlias + DeserializeOwnedAlias + ArgminMul<F, P>,
L: Clone + LineSearch<P, F> + Solver<O, IterState<P, G, (), (), F, R>>,
L: Clone + LineSearch<P, F> + Solver<O, IterState<P, G, (), (), (), F>>,
F: ArgminFloat,
R: Clone + SerializeAlias + DeserializeOwnedAlias,
{
const NAME: &'static str = "Steepest Descent";

fn next_iter(
&mut self,
problem: &mut Problem<O>,
mut state: IterState<P, G, (), (), F, R>,
) -> Result<(IterState<P, G, (), (), F, R>, Option<KV>), Error> {
mut state: IterState<P, G, (), (), (), F>,
) -> Result<(IterState<P, G, (), (), (), F>, Option<KV>), Error> {
let param_new = state.take_param().ok_or_else(argmin_error_closure!(
NotInitialized,
concat!(
Expand Down
6 changes: 3 additions & 3 deletions argmin/src/solver/landweber/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ impl<F> Landweber<F> {
}
}

impl<O, F, P, G, R> Solver<O, IterState<P, G, (), (), F, R>> for Landweber<F>
impl<O, F, P, G> Solver<O, IterState<P, G, (), (), (), F>> for Landweber<F>
where
O: Gradient<Param = P, Gradient = G>,
P: Clone + ArgminScaledSub<G, F, P>,
Expand All @@ -74,8 +74,8 @@ where
fn next_iter(
&mut self,
problem: &mut Problem<O>,
mut state: IterState<P, G, (), (), F, R>,
) -> Result<(IterState<P, G, (), (), F, R>, Option<KV>), Error> {
mut state: IterState<P, G, (), (), (), F>,
) -> Result<(IterState<P, G, (), (), (), F>, Option<KV>), Error> {
let param = state.take_param().ok_or_else(argmin_error_closure!(
NotInitialized,
concat!(
Expand Down
21 changes: 10 additions & 11 deletions argmin/src/solver/linesearch/backtracking.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,14 +133,14 @@ where
F: ArgminFloat,
{
/// Perform a single backtracking step
fn backtracking_step<O, R>(
fn backtracking_step<O>(
&self,
problem: &mut Problem<O>,
state: IterState<P, G, (), (), F, R>,
) -> Result<IterState<P, G, (), (), F, R>, Error>
state: IterState<P, G, (), (), (), F>,
) -> Result<IterState<P, G, (), (), (), F>, Error>
where
O: CostFunction<Param = P, Output = F> + Gradient<Param = P, Gradient = G>,
IterState<P, G, (), (), F, R>: State<Float = F>,
IterState<P, G, (), (), (), F>: State<Float = F>,
{
let new_param = self
.init_param
Expand Down Expand Up @@ -174,8 +174,7 @@ where
}
}

impl<O, P, G, L, F, R> Solver<O, IterState<P, G, (), (), F, R>>
for BacktrackingLineSearch<P, G, L, F>
impl<O, P, G, L, F> Solver<O, IterState<P, G, (), (), (), F>> for BacktrackingLineSearch<P, G, L, F>
where
P: Clone + SerializeAlias + ArgminScaledAdd<P, F, P>,
G: SerializeAlias + ArgminScaledAdd<P, F, P>,
Expand All @@ -188,8 +187,8 @@ where
fn init(
&mut self,
problem: &mut Problem<O>,
mut state: IterState<P, G, (), (), F, R>,
) -> Result<(IterState<P, G, (), (), F, R>, Option<KV>), Error> {
mut state: IterState<P, G, (), (), (), F>,
) -> Result<(IterState<P, G, (), (), (), F>, Option<KV>), Error> {
if self.search_direction.is_none() {
return Err(argmin_error!(
NotInitialized,
Expand Down Expand Up @@ -227,14 +226,14 @@ where
fn next_iter(
&mut self,
problem: &mut Problem<O>,
state: IterState<P, G, (), (), F, R>,
) -> Result<(IterState<P, G, (), (), F, R>, Option<KV>), Error> {
state: IterState<P, G, (), (), (), F>,
) -> Result<(IterState<P, G, (), (), (), F>, Option<KV>), Error> {
self.alpha = self.alpha * self.rho;
let state = self.backtracking_step(problem, state)?;
Ok((state, None))
}

fn terminate(&mut self, state: &IterState<P, G, (), (), F, R>) -> TerminationStatus {
fn terminate(&mut self, state: &IterState<P, G, (), (), (), F>) -> TerminationStatus {
if self.condition.evaluate_condition(
state.cost,
state.get_gradient(),
Expand Down
Loading

0 comments on commit 00822bb

Please sign in to comment.