Skip to content

Commit

Permalink
Add optional timeout to Executor
Browse files Browse the repository at this point in the history
  • Loading branch information
stefan-k committed Jan 21, 2024
1 parent b74620c commit 346f1b7
Show file tree
Hide file tree
Showing 4 changed files with 166 additions and 2 deletions.
48 changes: 46 additions & 2 deletions argmin/src/core/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ pub struct Executor<O, S, I> {
observers: Observers<I>,
/// Checkpoint
checkpoint: Option<Box<dyn Checkpoint<S, I>>>,
/// Timeout
timeout: Option<std::time::Duration>,
/// Indicates whether Ctrl-C functionality should be active or not
ctrlc: bool,
/// Indicates whether to time execution or not
Expand Down Expand Up @@ -66,6 +68,7 @@ where
state,
observers: Observers::new(),
checkpoint: None,
timeout: None,
ctrlc: true,
timer: true,
}
Expand Down Expand Up @@ -250,10 +253,18 @@ where
}

if self.timer {
// Increase accumulated total_time
total_time.map(|total_time| state.time(Some(total_time.elapsed())));

// If a timeout is set, check if timeout is reached
if let (Some(timeout), Some(total_time)) = (self.timeout, total_time) {
if total_time.elapsed() > timeout {
state = state.terminate_with(TerminationReason::Timeout);
}
}
}

// Check if termination occurred inside next_iter()
// Check if termination occurred in the meantime
if state.terminated() {
break;
}
Expand Down Expand Up @@ -374,6 +385,8 @@ where

/// Enables or disables timing of individual iterations (default: enabled).
///
/// Setting this to false will silently be ignored in case a timeout is set.
///
/// # Example
///
/// ```
Expand All @@ -391,7 +404,38 @@ where
/// ```
#[must_use]
pub fn timer(mut self, timer: bool) -> Self {
self.timer = timer;
if self.timeout.is_none() {
self.timer = timer;
}
self
}

/// Sets a timeout for the run.
///
/// The optimization run is stopped once the timeout is exceeded. Note that the check is
/// performed after each iteration, therefore the acutal runtime can exceed the the set

Check warning on line 416 in argmin/src/core/executor.rs

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"acutal" should be "actual".
/// duration.
/// This also enables time measurements.
///
/// # Example
///
/// ```
/// # use argmin::core::{Error, Executor};
/// # use argmin::core::test_utils::{TestSolver, TestProblem};
/// #
/// # fn main() -> Result<(), Error> {
/// # let solver = TestSolver::new();
/// # let problem = TestProblem::new();
/// #
/// // Create instance of `Executor` with `problem` and `solver`
/// let executor = Executor::new(problem, solver).timeout(std::time::Duration::from_secs(30));
/// # Ok(())
/// # }
/// ```
#[must_use]
pub fn timeout(mut self, timeout: std::time::Duration) -> Self {
self.timer = true;
self.timeout = Some(timeout);
self
}
}
Expand Down
8 changes: 8 additions & 0 deletions argmin/src/core/termination.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ impl TerminationStatus {
/// assert!(TerminationStatus::Terminated(TerminationReason::TargetCostReached).terminated());
/// assert!(TerminationStatus::Terminated(TerminationReason::SolverConverged).terminated());
/// assert!(TerminationStatus::Terminated(TerminationReason::KeyboardInterrupt).terminated());
/// assert!(TerminationStatus::Terminated(TerminationReason::Timeout).terminated());
/// assert!(TerminationStatus::Terminated(TerminationReason::SolverExit("Exit reason".to_string())).terminated());
/// ```
pub fn terminated(&self) -> bool {
Expand Down Expand Up @@ -59,6 +60,8 @@ pub enum TerminationReason {
KeyboardInterrupt,
/// Converged
SolverConverged,
/// Timeout reached
Timeout,
/// Solver exit with given reason
SolverExit(String),
}
Expand Down Expand Up @@ -88,6 +91,10 @@ impl TerminationReason {
/// "Solver converged"
/// );
/// assert_eq!(
/// TerminationReason::Timeout.text(),
/// "Timeout reached"
/// );
/// assert_eq!(
/// TerminationReason::SolverExit("Aborted".to_string()).text(),
/// "Aborted"
/// );
Expand All @@ -98,6 +105,7 @@ impl TerminationReason {
TerminationReason::TargetCostReached => "Target cost value reached",
TerminationReason::KeyboardInterrupt => "Keyboard interrupt",
TerminationReason::SolverConverged => "Solver converged",
TerminationReason::Timeout => "Timeout reached",
TerminationReason::SolverExit(reason) => reason.as_ref(),
}
}
Expand Down
14 changes: 14 additions & 0 deletions examples/timeout/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
[package]
name = "example-timeout"
version = "0.1.0"
edition = "2021"
license = "MIT OR Apache-2.0"
publish = false

[dependencies]
argmin = { version = "*", path = "../../argmin" }
argmin-math = { version = "*", features = ["vec"], path = "../../argmin-math" }
argmin-observer-slog = { version = "*", path = "../../observers/slog/" }
argmin_testfunctions = "*"
rand = "0.8.5"
rand_xoshiro = "0.6.0"
98 changes: 98 additions & 0 deletions examples/timeout/src/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
// Copyright 2018-2024 argmin developers
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.

// This example shows how to add a timeout ot an optimization run. The optimization will be

Check warning on line 8 in examples/timeout/src/main.rs

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"ot" should be "to" or "of" or "or" or "not".
// terminated once the 3 seconds timeout is reached.

use argmin::{
core::{observers::ObserverMode, CostFunction, Error, Executor},
solver::simulatedannealing::{Anneal, SATempFunc, SimulatedAnnealing},
};
use argmin_observer_slog::SlogLogger;
use argmin_testfunctions::rosenbrock;
use rand::{distributions::Uniform, prelude::*};
use rand_xoshiro::Xoshiro256PlusPlus;
use std::sync::{Arc, Mutex};

struct Rosenbrock {
a: f64,
b: f64,
lower_bound: Vec<f64>,
upper_bound: Vec<f64>,
rng: Arc<Mutex<Xoshiro256PlusPlus>>,
}

impl Rosenbrock {
pub fn new(a: f64, b: f64, lower_bound: Vec<f64>, upper_bound: Vec<f64>) -> Self {
Rosenbrock {
a,
b,
lower_bound,
upper_bound,
rng: Arc::new(Mutex::new(Xoshiro256PlusPlus::from_entropy())),
}
}
}

impl CostFunction for Rosenbrock {
type Param = Vec<f64>;
type Output = f64;

fn cost(&self, param: &Self::Param) -> Result<Self::Output, Error> {
Ok(rosenbrock(param, self.a, self.b))
}
}

impl Anneal for Rosenbrock {
type Param = Vec<f64>;
type Output = Vec<f64>;
type Float = f64;

fn anneal(&self, param: &Vec<f64>, temp: f64) -> Result<Vec<f64>, Error> {
let mut param_n = param.clone();
let mut rng = self.rng.lock().unwrap();
let distr = Uniform::from(0..param.len());
for _ in 0..(temp.floor() as u64 + 1) {
let idx = rng.sample(distr);
let val = rng.sample(Uniform::new_inclusive(-0.1, 0.1));
param_n[idx] += val;
param_n[idx] = param_n[idx].clamp(self.lower_bound[idx], self.upper_bound[idx]);
}
Ok(param_n)
}
}

fn run() -> Result<(), Error> {
let lower_bound: Vec<f64> = vec![-5.0, -5.0];
let upper_bound: Vec<f64> = vec![5.0, 5.0];
let operator = Rosenbrock::new(1.0, 100.0, lower_bound, upper_bound);
let init_param: Vec<f64> = vec![1.0, 1.2];
let temp = 15.0;
let solver = SimulatedAnnealing::new(temp)?.with_temp_func(SATempFunc::Boltzmann);

let res = Executor::new(operator, solver)
.configure(|state| state.param(init_param).max_iters(10_000_000))
.add_observer(SlogLogger::term(), ObserverMode::Always)
/////////////////////////////////////////////////////////////////////////////////////////
// //
// Add a timeout of 3 seconds //
// //
/////////////////////////////////////////////////////////////////////////////////////////
.timeout(std::time::Duration::from_secs(3))
.run()?;

// Print result
println!("{res}");
Ok(())
}

fn main() {
if let Err(ref e) = run() {
println!("{e}");
std::process::exit(1);
}
}

0 comments on commit 346f1b7

Please sign in to comment.