From 40464a55d97bc4222f285489c74df584bdec67ad Mon Sep 17 00:00:00 2001 From: Michael Baikov Date: Wed, 13 Mar 2024 15:36:06 -0400 Subject: [PATCH] Disable counting by default --- crates/argmin/src/core/state/iterstate.rs | 14 ++++++++++ .../src/core/state/linearprogramstate.rs | 27 ++++++++++++++++--- .../argmin/src/core/state/populationstate.rs | 27 ++++++++++++++++--- crates/argmin/src/solver/brent/brentopt.rs | 2 +- .../src/solver/linesearch/backtracking.rs | 14 ++++++++-- 5 files changed, 73 insertions(+), 11 deletions(-) diff --git a/crates/argmin/src/core/state/iterstate.rs b/crates/argmin/src/core/state/iterstate.rs index cc7899625..26e99affe 100644 --- a/crates/argmin/src/core/state/iterstate.rs +++ b/crates/argmin/src/core/state/iterstate.rs @@ -972,6 +972,20 @@ where pub fn take_prev_residuals(&mut self) -> Option { self.prev_residuals.take() } + + /// Overrides state of counting function executions (default: false) + /// ``` + /// # use argmin::core::{IterState, State}; + /// # let mut state: IterState<(), (), (), (), Vec, f64> = IterState::new(); + /// # assert!(!state.counting_enabled); + /// let state = state.counting(true); + /// # assert!(state.counting_enabled); + /// ``` + #[must_use] + pub fn counting(mut self, mode: bool) -> Self { + self.counting_enabled = mode; + self + } } impl State for IterState diff --git a/crates/argmin/src/core/state/linearprogramstate.rs b/crates/argmin/src/core/state/linearprogramstate.rs index 414f56f23..fb25e8e27 100644 --- a/crates/argmin/src/core/state/linearprogramstate.rs +++ b/crates/argmin/src/core/state/linearprogramstate.rs @@ -57,6 +57,8 @@ pub struct LinearProgramState { pub max_iters: u64, /// Evaluation counts pub counts: HashMap, + /// Update evaluation counts? + pub counting_enabled: bool, /// Time required so far pub time: Option, /// Status of optimization execution @@ -151,6 +153,20 @@ impl LinearProgramState { self.cost = cost; self } + + /// Overrides state of counting function executions (default: false) + /// ``` + /// # use argmin::core::{State, LinearProgramState}; + /// # let mut state: LinearProgramState, f64> = LinearProgramState::new(); + /// # assert!(!state.counting_enabled); + /// let state = state.counting(true); + /// # assert!(state.counting_enabled); + /// ``` + #[must_use] + pub fn counting(mut self, mode: bool) -> Self { + self.counting_enabled = mode; + self + } } impl State for LinearProgramState @@ -206,6 +222,7 @@ where last_best_iter: 0, max_iters: std::u64::MAX, counts: HashMap::new(), + counting_enabled: false, time: Some(instant::Duration::new(0, 0)), termination_status: TerminationStatus::NotTerminated, } @@ -504,7 +521,7 @@ where /// ``` /// # use std::collections::HashMap; /// # use argmin::core::{Problem, LinearProgramState, State, ArgminFloat}; - /// # let mut state: LinearProgramState, f64> = LinearProgramState::new(); + /// # let mut state: LinearProgramState, f64> = LinearProgramState::new().counting(true); /// # assert_eq!(state.counts, HashMap::new()); /// # state.counts.insert("test2".to_string(), 10u64); /// # @@ -521,9 +538,11 @@ where /// # assert_eq!(state.counts, hm); /// ``` fn func_counts(&mut self, problem: &Problem) { - for (k, &v) in problem.counts.iter() { - let count = self.counts.entry(k.to_string()).or_insert(0); - *count = v + if self.counting_enabled { + for (k, &v) in problem.counts.iter() { + let count = self.counts.entry(k.to_string()).or_insert(0); + *count = v + } } } diff --git a/crates/argmin/src/core/state/populationstate.rs b/crates/argmin/src/core/state/populationstate.rs index 7fefecf8e..53bf98cea 100644 --- a/crates/argmin/src/core/state/populationstate.rs +++ b/crates/argmin/src/core/state/populationstate.rs @@ -59,6 +59,8 @@ pub struct PopulationState { pub max_iters: u64, /// Evaluation counts pub counts: HashMap, + /// Update evaluation counts? + pub counting_enabled: bool, /// Time required so far pub time: Option, /// Status of optimization execution @@ -430,6 +432,20 @@ where pub fn take_population(&mut self) -> Option> { self.population.take() } + + /// Overrides state of counting function executions (default: false) + /// ``` + /// # use argmin::core::{State, PopulationState}; + /// # let mut state: PopulationState, f64> = PopulationState::new(); + /// # assert!(!state.counting_enabled); + /// let state = state.counting(true); + /// # assert!(state.counting_enabled); + /// ``` + #[must_use] + pub fn counting(mut self, mode: bool) -> Self { + self.counting_enabled = mode; + self + } } impl State for PopulationState @@ -484,6 +500,7 @@ where last_best_iter: 0, max_iters: std::u64::MAX, counts: HashMap::new(), + counting_enabled: false, time: Some(instant::Duration::new(0, 0)), termination_status: TerminationStatus::NotTerminated, } @@ -783,7 +800,7 @@ where /// ``` /// # use std::collections::HashMap; /// # use argmin::core::{Problem, PopulationState, State, ArgminFloat}; - /// # let mut state: PopulationState, f64> = PopulationState::new(); + /// # let mut state: PopulationState, f64> = PopulationState::new().counting(true); /// # assert_eq!(state.counts, HashMap::new()); /// # state.counts.insert("test2".to_string(), 10u64); /// # @@ -800,9 +817,11 @@ where /// # assert_eq!(state.counts, hm); /// ``` fn func_counts(&mut self, problem: &Problem) { - for (k, &v) in problem.counts.iter() { - let count = self.counts.entry(k.to_string()).or_insert(0); - *count = v + if self.counting_enabled { + for (k, &v) in problem.counts.iter() { + let count = self.counts.entry(k.to_string()).or_insert(0); + *count = v + } } } diff --git a/crates/argmin/src/solver/brent/brentopt.rs b/crates/argmin/src/solver/brent/brentopt.rs index 4b685ab50..703bc6fa5 100644 --- a/crates/argmin/src/solver/brent/brentopt.rs +++ b/crates/argmin/src/solver/brent/brentopt.rs @@ -230,7 +230,7 @@ mod tests { let cost = TestFunc {}; let solver = BrentOpt::new(-10., 10.); let res = Executor::new(cost, solver) - .configure(|state| state.max_iters(13)) + .configure(|state| state.counting(true).max_iters(13)) .run() .unwrap(); assert_eq!( diff --git a/crates/argmin/src/solver/linesearch/backtracking.rs b/crates/argmin/src/solver/linesearch/backtracking.rs index 313341870..ed21926e1 100644 --- a/crates/argmin/src/solver/linesearch/backtracking.rs +++ b/crates/argmin/src/solver/linesearch/backtracking.rs @@ -640,7 +640,12 @@ mod tests { ls.search_direction(vec![2.0f64, 0.0]); let data = Executor::new(prob, ls.clone()) - .configure(|config| config.param(ls.init_param.clone().unwrap()).max_iters(10)) + .configure(|config| { + config + .counting(true) + .param(ls.init_param.clone().unwrap()) + .max_iters(10) + }) .run(); assert!(data.is_ok()); @@ -689,7 +694,12 @@ mod tests { ls.search_direction(vec![2.0f64, 0.0]); let data = Executor::new(prob, ls.clone()) - .configure(|config| config.param(ls.init_param.clone().unwrap()).max_iters(10)) + .configure(|config| { + config + .counting(true) + .param(ls.init_param.clone().unwrap()) + .max_iters(10) + }) .run(); assert!(data.is_ok());