diff --git a/argmin/src/core/executor.rs b/argmin/src/core/executor.rs index 4b529ddca..6e1ba14b7 100644 --- a/argmin/src/core/executor.rs +++ b/argmin/src/core/executor.rs @@ -26,6 +26,8 @@ pub struct Executor { observers: Observers, /// Checkpoint checkpoint: Option>>, + /// Timeout + timeout: Option, /// Indicates whether Ctrl-C functionality should be active or not ctrlc: bool, /// Indicates whether to time execution or not @@ -66,6 +68,7 @@ where state, observers: Observers::new(), checkpoint: None, + timeout: None, ctrlc: true, timer: true, } @@ -250,10 +253,18 @@ where } if self.timer { + // Increase accumulated total_time total_time.map(|total_time| state.time(Some(total_time.elapsed()))); + + // If a timeout is set, check if timeout is reached + if let (Some(timeout), Some(total_time)) = (self.timeout, total_time) { + if total_time.elapsed() > timeout { + state = state.terminate_with(TerminationReason::Timeout); + } + } } - // Check if termination occurred inside next_iter() + // Check if termination occurred in the meantime if state.terminated() { break; } @@ -374,6 +385,8 @@ where /// Enables or disables timing of individual iterations (default: enabled). /// + /// Setting this to false will silently be ignored in case a timeout is set. + /// /// # Example /// /// ``` @@ -391,7 +404,38 @@ where /// ``` #[must_use] pub fn timer(mut self, timer: bool) -> Self { - self.timer = timer; + if self.timeout.is_none() { + self.timer = timer; + } + self + } + + /// Sets a timeout for the run. + /// + /// The optimization run is stopped once the timeout is exceeded. Note that the check is + /// performed after each iteration, therefore the acutal runtime can exceed the the set + /// duration. + /// This also enables time measurements. + /// + /// # Example + /// + /// ``` + /// # use argmin::core::{Error, Executor}; + /// # use argmin::core::test_utils::{TestSolver, TestProblem}; + /// # + /// # fn main() -> Result<(), Error> { + /// # let solver = TestSolver::new(); + /// # let problem = TestProblem::new(); + /// # + /// // Create instance of `Executor` with `problem` and `solver` + /// let executor = Executor::new(problem, solver).timeout(std::time::Duration::from_secs(30)); + /// # Ok(()) + /// # } + /// ``` + #[must_use] + pub fn timeout(mut self, timeout: std::time::Duration) -> Self { + self.timer = true; + self.timeout = Some(timeout); self } } diff --git a/argmin/src/core/termination.rs b/argmin/src/core/termination.rs index ced9c9714..8abe04dad 100644 --- a/argmin/src/core/termination.rs +++ b/argmin/src/core/termination.rs @@ -31,6 +31,7 @@ impl TerminationStatus { /// assert!(TerminationStatus::Terminated(TerminationReason::TargetCostReached).terminated()); /// assert!(TerminationStatus::Terminated(TerminationReason::SolverConverged).terminated()); /// assert!(TerminationStatus::Terminated(TerminationReason::KeyboardInterrupt).terminated()); + /// assert!(TerminationStatus::Terminated(TerminationReason::Timeout).terminated()); /// assert!(TerminationStatus::Terminated(TerminationReason::SolverExit("Exit reason".to_string())).terminated()); /// ``` pub fn terminated(&self) -> bool { @@ -59,6 +60,8 @@ pub enum TerminationReason { KeyboardInterrupt, /// Converged SolverConverged, + /// Timeout reached + Timeout, /// Solver exit with given reason SolverExit(String), } @@ -88,6 +91,10 @@ impl TerminationReason { /// "Solver converged" /// ); /// assert_eq!( + /// TerminationReason::Timeout.text(), + /// "Timeout reached" + /// ); + /// assert_eq!( /// TerminationReason::SolverExit("Aborted".to_string()).text(), /// "Aborted" /// ); @@ -98,6 +105,7 @@ impl TerminationReason { TerminationReason::TargetCostReached => "Target cost value reached", TerminationReason::KeyboardInterrupt => "Keyboard interrupt", TerminationReason::SolverConverged => "Solver converged", + TerminationReason::Timeout => "Timeout reached", TerminationReason::SolverExit(reason) => reason.as_ref(), } } diff --git a/examples/timeout/Cargo.toml b/examples/timeout/Cargo.toml new file mode 100644 index 000000000..1cb73e66a --- /dev/null +++ b/examples/timeout/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "example-timeout" +version = "0.1.0" +edition = "2021" +license = "MIT OR Apache-2.0" +publish = false + +[dependencies] +argmin = { version = "*", path = "../../argmin" } +argmin-math = { version = "*", features = ["vec"], path = "../../argmin-math" } +argmin-observer-slog = { version = "*", path = "../../observers/slog/" } +argmin_testfunctions = "*" +rand = "0.8.5" +rand_xoshiro = "0.6.0" diff --git a/examples/timeout/src/main.rs b/examples/timeout/src/main.rs new file mode 100644 index 000000000..cbfa9d07b --- /dev/null +++ b/examples/timeout/src/main.rs @@ -0,0 +1,98 @@ +// Copyright 2018-2024 argmin developers +// +// Licensed under the Apache License, Version 2.0 or the MIT license , at your option. This file may not be +// copied, modified, or distributed except according to those terms. + +// This example shows how to add a timeout ot an optimization run. The optimization will be +// terminated once the 3 seconds timeout is reached. + +use argmin::{ + core::{observers::ObserverMode, CostFunction, Error, Executor}, + solver::simulatedannealing::{Anneal, SATempFunc, SimulatedAnnealing}, +}; +use argmin_observer_slog::SlogLogger; +use argmin_testfunctions::rosenbrock; +use rand::{distributions::Uniform, prelude::*}; +use rand_xoshiro::Xoshiro256PlusPlus; +use std::sync::{Arc, Mutex}; + +struct Rosenbrock { + a: f64, + b: f64, + lower_bound: Vec, + upper_bound: Vec, + rng: Arc>, +} + +impl Rosenbrock { + pub fn new(a: f64, b: f64, lower_bound: Vec, upper_bound: Vec) -> Self { + Rosenbrock { + a, + b, + lower_bound, + upper_bound, + rng: Arc::new(Mutex::new(Xoshiro256PlusPlus::from_entropy())), + } + } +} + +impl CostFunction for Rosenbrock { + type Param = Vec; + type Output = f64; + + fn cost(&self, param: &Self::Param) -> Result { + Ok(rosenbrock(param, self.a, self.b)) + } +} + +impl Anneal for Rosenbrock { + type Param = Vec; + type Output = Vec; + type Float = f64; + + fn anneal(&self, param: &Vec, temp: f64) -> Result, Error> { + let mut param_n = param.clone(); + let mut rng = self.rng.lock().unwrap(); + let distr = Uniform::from(0..param.len()); + for _ in 0..(temp.floor() as u64 + 1) { + let idx = rng.sample(distr); + let val = rng.sample(Uniform::new_inclusive(-0.1, 0.1)); + param_n[idx] += val; + param_n[idx] = param_n[idx].clamp(self.lower_bound[idx], self.upper_bound[idx]); + } + Ok(param_n) + } +} + +fn run() -> Result<(), Error> { + let lower_bound: Vec = vec![-5.0, -5.0]; + let upper_bound: Vec = vec![5.0, 5.0]; + let operator = Rosenbrock::new(1.0, 100.0, lower_bound, upper_bound); + let init_param: Vec = vec![1.0, 1.2]; + let temp = 15.0; + let solver = SimulatedAnnealing::new(temp)?.with_temp_func(SATempFunc::Boltzmann); + + let res = Executor::new(operator, solver) + .configure(|state| state.param(init_param).max_iters(10_000_000)) + .add_observer(SlogLogger::term(), ObserverMode::Always) + ///////////////////////////////////////////////////////////////////////////////////////// + // // + // Add a timeout of 3 seconds // + // // + ///////////////////////////////////////////////////////////////////////////////////////// + .timeout(std::time::Duration::from_secs(3)) + .run()?; + + // Print result + println!("{res}"); + Ok(()) +} + +fn main() { + if let Err(ref e) = run() { + println!("{e}"); + std::process::exit(1); + } +}