From f44b41935cffde17e9beef8a7ca6b92d38c6e0db Mon Sep 17 00:00:00 2001 From: Stefan Kroboth Date: Mon, 22 Jan 2024 08:18:12 +0100 Subject: [PATCH] Use the new residuals handling in CG --- argmin/src/solver/conjugategradient/cg.rs | 79 +++++++++-------------- 1 file changed, 30 insertions(+), 49 deletions(-) diff --git a/argmin/src/solver/conjugategradient/cg.rs b/argmin/src/solver/conjugategradient/cg.rs index c59d03cd7..ac517930f 100644 --- a/argmin/src/solver/conjugategradient/cg.rs +++ b/argmin/src/solver/conjugategradient/cg.rs @@ -32,8 +32,6 @@ use serde::{Deserialize, Serialize}; pub struct ConjugateGradient { /// b (right hand side of `A * x = b`) b: P, - /// Residual - r: Option

, /// p p: Option

, /// previous p @@ -60,7 +58,6 @@ where pub fn new(b: P) -> Self { ConjugateGradient { b, - r: None, p: None, p_prev: None, rtr: F::nan(), @@ -87,15 +84,11 @@ where } } -impl Solver> for ConjugateGradient +impl Solver> for ConjugateGradient where O: Operator, - P: Clone - + ArgminDot - + ArgminSub - + ArgminScaledAdd - + ArgminConj - + ArgminMul, + P: Clone + ArgminDot + ArgminSub + ArgminScaledAdd + ArgminConj, + R: ArgminMul + ArgminMul + ArgminConj + ArgminDot + ArgminScaledAdd, F: ArgminFloat + ArgminL2Norm, { const NAME: &'static str = "Conjugate Gradient"; @@ -103,8 +96,8 @@ where fn init( &mut self, problem: &mut Problem, - state: IterState, - ) -> Result<(IterState, Option), Error> { + state: IterState, + ) -> Result<(IterState, Option), Error> { let init_param = state.get_param().ok_or_else(argmin_error_closure!( NotInitialized, concat!( @@ -113,26 +106,25 @@ where ) ))?; let ap = problem.apply(init_param)?; - let r0 = self.b.sub(&ap).mul(&(float!(-1.0))); + let r0: R = self.b.sub(&ap).mul(&(float!(-1.0))); self.p = Some(r0.mul(&(float!(-1.0)))); self.rtr = r0.dot(&r0.conj()); - self.r = Some(r0); - Ok((state, None)) + Ok((state.residuals(r0), None)) } /// Perform one iteration of CG algorithm fn next_iter( &mut self, problem: &mut Problem, - state: IterState, - ) -> Result<(IterState, Option), Error> { + mut state: IterState, + ) -> Result<(IterState, Option), Error> { let p = self.p.take().ok_or_else(argmin_error_closure!( PotentialBug, "`ConjugateGradient`: Field `p` not set" ))?; - let r = self.r.as_ref().ok_or_else(argmin_error_closure!( + let r = state.take_residuals().ok_or_else(argmin_error_closure!( PotentialBug, - "`ConjugateGradient`: Field `r` not set" + "`ConjugateGradient`: Residuals in `state` not set" ))?; let apk = problem.apply(&p)?; @@ -146,15 +138,14 @@ where let rtr_n = r.dot(&r.conj()); let beta = rtr_n.div(self.rtr); self.rtr = rtr_n; - let p_n = r.mul(&(float!(-1.0))).scaled_add(&beta, &p); + let p_n = >::mul(&r, &(float!(-1.0))).scaled_add(&beta, &p); let norm = r.dot(&r.conj()).l2_norm(); self.p = Some(p_n); self.p_prev = Some(p); - self.r = Some(r); Ok(( - state.param(new_param).cost(norm), + state.param(new_param).residuals(r).cost(norm), Some(kv!("alpha" => alpha; "beta" => beta;)), )) } @@ -172,16 +163,9 @@ mod tests { #[test] fn test_new() { let cg: ConjugateGradient<_, f64> = ConjugateGradient::new(vec![1.0f64, 2.0]); - let ConjugateGradient { - b, - r, - p, - p_prev, - rtr, - } = cg; + let ConjugateGradient { b, p, p_prev, rtr } = cg; assert_eq!(b[0].to_ne_bytes(), 1.0f64.to_ne_bytes()); assert_eq!(b[1].to_ne_bytes(), 2.0f64.to_ne_bytes()); - assert!(r.is_none()); assert!(p.is_none()); assert!(p_prev.is_none()); assert!(rtr.is_nan()); @@ -226,28 +210,28 @@ mod tests { #[test] fn test_init() { let mut cg: ConjugateGradient<_, f64> = ConjugateGradient::new(vec![1.0f64, 2.0]); - let state: IterState, (), (), (), (), f64> = + let state: IterState, (), (), (), Vec, f64> = IterState::new().param(vec![3.0, 4.0]); let (state_out, kv) = cg .init(&mut Problem::new(TestProblem::new()), state.clone()) .unwrap(); assert!(kv.is_none()); - // State remains unchanged in `init`. - assert_eq!(state_out, state); - let ConjugateGradient { - b, - r, - p, - p_prev, - rtr, - } = cg; + let ConjugateGradient { b, p, p_prev, rtr } = cg; assert_relative_eq!(b[0], 1.0, epsilon = f64::EPSILON); assert_relative_eq!(b[1], 2.0, epsilon = f64::EPSILON); let r0 = [2.0f64, 2.0]; - assert_relative_eq!(r0[0], r.as_ref().unwrap()[0], epsilon = f64::EPSILON); - assert_relative_eq!(r0[1], r.as_ref().unwrap()[1], epsilon = f64::EPSILON); + assert_relative_eq!( + r0[0], + state_out.get_residuals().as_ref().unwrap()[0], + epsilon = f64::EPSILON + ); + assert_relative_eq!( + r0[1], + state_out.get_residuals().as_ref().unwrap()[1], + epsilon = f64::EPSILON + ); let pp = [-2.0f64, -2.0]; assert_relative_eq!(pp[0], p.as_ref().unwrap()[0], epsilon = f64::EPSILON); assert_relative_eq!(pp[1], p.as_ref().unwrap()[1], epsilon = f64::EPSILON); @@ -259,7 +243,6 @@ mod tests { fn test_next_iter_p_not_set() { let mut cg: ConjugateGradient<_, f64> = ConjugateGradient::new(vec![1.0f64, 2.0]); let state = IterState::new().param(vec![1.0f64]); - cg.r = Some(vec![]); assert!(cg.p.is_none()); let res = cg.next_iter(&mut Problem::new(TestProblem::new()), state); assert_error!( @@ -278,14 +261,13 @@ mod tests { let mut cg: ConjugateGradient<_, f64> = ConjugateGradient::new(vec![1.0f64, 2.0]); let state = IterState::new().param(vec![1.0f64]); cg.p = Some(vec![]); - assert!(cg.r.is_none()); let res = cg.next_iter(&mut Problem::new(TestProblem::new()), state); assert_error!( res, ArgminError, concat!( "Potential bug: \"`ConjugateGradient`: ", - "Field `r` not set\". This is potentially a bug. ", + "Residuals in `state` not set\". This is potentially a bug. ", "Please file a report on https://github.com/argmin-rs/argmin/issues" ) ); @@ -294,9 +276,8 @@ mod tests { #[test] fn test_next_iter_state_param_not_set() { let mut cg: ConjugateGradient<_, f64> = ConjugateGradient::new(vec![1.0f64, 2.0]); - let state = IterState::new(); + let state = IterState::new().residuals(vec![]); cg.p = Some(vec![]); - cg.r = Some(vec![]); assert!(state.param.is_none()); let res = cg.next_iter(&mut Problem::new(TestProblem::new()), state); assert_error!( @@ -318,7 +299,7 @@ mod tests { let (state, _) = cg.init(&mut problem, state).unwrap(); let rtr = cg.rtr; let p = cg.p.clone().unwrap()[0]; - let r = cg.r.clone().unwrap()[0]; + let r = state.get_residuals().unwrap()[0]; let apk = p; let alpha = rtr / (p * apk); @@ -332,7 +313,7 @@ mod tests { let (state, kv) = cg.next_iter(&mut problem, state).unwrap(); assert!(kv.is_some()); - assert_relative_eq!(r, cg.r.as_ref().unwrap()[0]); + assert_relative_eq!(r, state.get_residuals().unwrap()[0]); assert_relative_eq!(p_n, cg.p.as_ref().unwrap()[0]); assert_relative_eq!(p, cg.p_prev.as_ref().unwrap()[0]); assert_relative_eq!(rtr_n, cg.rtr);