Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PoC: Python bindings #340

Draft
wants to merge 15 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -181,11 +181,11 @@ jobs:
- name: Install wasm-pack
run: curl https://rustwasm.github.io/wasm-pack/installer/init.sh -sSf | sh
- name: Build target wasm32-unknown-unknown
run: cargo build --workspace --exclude argmin-observer-spectator --exclude spectator --exclude argmin-observer-paramwriter --exclude "example-*" --target wasm32-unknown-unknown --features wasm-bindgen
run: cargo build --workspace --exclude argmin-observer-spectator --exclude spectator --exclude argmin-observer-paramwriter --exclude "example-*" --exclude argmin-py --target wasm32-unknown-unknown --features wasm-bindgen
- name: Build target wasm32-wasi with feature wasm-bindgen
run: cargo build --workspace --exclude argmin-observer-spectator --exclude spectator --exclude argmin-observer-paramwriter --exclude "example-*" --target wasm32-wasi --features wasm-bindgen
run: cargo build --workspace --exclude argmin-observer-spectator --exclude spectator --exclude argmin-observer-paramwriter --exclude "example-*" --exclude argmin-py --target wasm32-wasi --features wasm-bindgen
- name: Build target wasm32-unknown-emscripten
run: cargo build --workspace --exclude argmin-observer-spectator --exclude spectator --exclude argmin-observer-paramwriter --exclude "example-*" --target wasm32-unknown-emscripten --no-default-features --features wasm-bindgen
run: cargo build --workspace --exclude argmin-observer-spectator --exclude spectator --exclude argmin-observer-paramwriter --exclude "example-*" --exclude argmin-py --target wasm32-unknown-emscripten --no-default-features --features wasm-bindgen

cargo-deny:
runs-on: ubuntu-latest
Expand Down
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,5 @@ target/
target/*
*.log
justfile
.vscode
.venv
28 changes: 28 additions & 0 deletions crates/argmin-py/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
[package]
name = "argmin-py"
version = "0.1.0"
authors = ["Joris Bayer <[email protected]"]
edition = "2021"
license = "MIT OR Apache-2.0"
description = "argmin python bindings"
homepage = "https://argmin-rs.org"
repository = "https://github.com/argmin-rs/argmin"
readme = "README.md"
keywords = ["optimization", "math", "science"]
categories = ["science"]
exclude = []


[lib]
name = "argmin"
crate-type = ["cdylib"]

[dependencies]
anyhow = "1.0.70"
argmin_testfunctions = { version = "0.1.1", path = "../argmin-testfunctions" }
argmin = {path="../argmin", default-features=false, features=[]}
argmin-math = {path="../argmin-math", features=["ndarray_latest"]}
ndarray = { version = "0.15", features = ["serde-1"] }
ndarray-linalg = { version = "0.16", features = ["intel-mkl-static"] }
numpy = "0.20.0"
pyo3 = {version="0.20.2", features=["extension-module", "anyhow"]}
Empty file added crates/argmin-py/README.md
Empty file.
26 changes: 26 additions & 0 deletions crates/argmin-py/examples/newton.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Copyright 2018-2023 argmin developers
#
# Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
# http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
# http://opensource.org/licenses/MIT>, at your option. This file may not be
# copied, modified, or distributed except according to those terms.
from argmin import Problem, Solver, Executor
import numpy as np
from scipy.optimize import rosen_der, rosen_hess


def main():
problem = Problem(
gradient=rosen_der,
hessian=rosen_hess,
)
solver = Solver.Newton
executor = Executor(problem, solver)
executor.configure(param=np.array([-1.2, 1.0]), max_iters=8)

result = executor.run()
print(result)


if __name__ == "__main__":
main()
71 changes: 71 additions & 0 deletions crates/argmin-py/src/executor.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
// Copyright 2018-2023 argmin developers
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.

// TODO: docs

use pyo3::{prelude::*, types::PyDict};

use argmin::core;

use crate::problem::Problem;
use crate::solver::{DynamicSolver, Solver};
use crate::types::{IterState, PyArray1};

#[pyclass]
pub struct Executor(Option<core::Executor<Problem, DynamicSolver, IterState>>);

impl Executor {
/// Consumes the inner executor.
///
/// PyObjects do not allow methods that consume the object itself, so this is a workaround
/// for using methods like `configure` and `run`.
fn take(&mut self) -> anyhow::Result<core::Executor<Problem, DynamicSolver, IterState>> {
let Some(inner) = self.0.take() else {
return Err(anyhow::anyhow!("Executor was already run."));
};
Ok(inner)
}
}

#[pymethods]
impl Executor {
#[new]
fn new(problem: Problem, solver: Solver) -> Self {
Self(Some(core::Executor::new(problem, solver.into())))
}

#[pyo3(signature = (**kwargs))]
fn configure(&mut self, kwargs: Option<&PyDict>) -> PyResult<()> {
if let Some(kwargs) = kwargs {
let param = kwargs
.get_item("param")?
.map(|x| x.extract::<&PyArray1>())
.map_or(Ok(None), |r| r.map(Some))?;
let max_iters = kwargs
.get_item("max_iters")?
.map(|x| x.extract())
.map_or(Ok(None), |r| r.map(Some))?;

self.0 = Some(self.take()?.configure(|mut state| {
if let Some(param) = param {
state = state.param(param.to_owned_array());
}
if let Some(max_iters) = max_iters {
state = state.max_iters(max_iters);
}
state
}));
}
Ok(())
}

fn run(&mut self) -> PyResult<String> {
// TODO: return usable OptimizationResult
let res = self.take()?.run();
Ok(res?.to_string())
}
}
24 changes: 24 additions & 0 deletions crates/argmin-py/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// Copyright 2018-2023 argmin developers
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.

// TODO: docs
mod executor;
mod problem;
mod solver;
mod types;

use pyo3::prelude::*;

#[pymodule]
#[pyo3(name = "argmin")]
fn argmin_py(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<executor::Executor>()?;
m.add_class::<problem::Problem>()?;
m.add_class::<solver::Solver>()?;

Ok(())
}
68 changes: 68 additions & 0 deletions crates/argmin-py/src/problem.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// Copyright 2018-2023 argmin developers
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.

// TODO: docs

use numpy::ToPyArray;
use pyo3::{prelude::*, types::PyTuple};

use argmin::core;

use crate::types::{Array1, Array2, Scalar};

#[pyclass]
#[derive(Clone)]
pub struct Problem {
gradient: PyObject,
hessian: PyObject,
// TODO: jacobian
}

#[pymethods]
impl Problem {
#[new]
fn new(gradient: PyObject, hessian: PyObject) -> Self {
Self { gradient, hessian }
}
}

impl core::Gradient for Problem {
type Param = Array1;
type Gradient = Array1;

fn gradient(&self, param: &Self::Param) -> Result<Self::Gradient, argmin::core::Error> {
call(&self.gradient, param)
}
}

impl argmin::core::Hessian for Problem {
type Param = Array1;

type Hessian = Array2;

fn hessian(&self, param: &Self::Param) -> Result<Self::Hessian, core::Error> {
call(&self.hessian, param)
}
}

fn call<InputDimension, OutputDimension>(
callable: &PyObject,
param: &ndarray::Array<Scalar, InputDimension>,
) -> Result<ndarray::Array<Scalar, OutputDimension>, argmin::core::Error>
where
InputDimension: ndarray::Dimension,
OutputDimension: ndarray::Dimension,
{
// TODO: prevent dynamic dispatch for every call
Python::with_gil(|py| {
let args = PyTuple::new(py, [param.to_pyarray(py)]);
let pyresult = callable.call(py, args, Default::default())?;
let pyarray = pyresult.extract::<&numpy::PyArray<Scalar, OutputDimension>>(py)?;
// TODO: try to get ownership instead of cloning
Ok(pyarray.to_owned_array())
})
}
46 changes: 46 additions & 0 deletions crates/argmin-py/src/solver.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// Copyright 2018-2023 argmin developers
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.

// TODO: docs

use pyo3::prelude::*;

use argmin::{core, solver};

use crate::{problem::Problem, types::IterState};

#[pyclass]
#[derive(Clone)]
pub enum Solver {
Newton,
}

pub struct DynamicSolver(Box<dyn core::Solver<Problem, IterState> + Send>);

impl From<Solver> for DynamicSolver {
fn from(value: Solver) -> Self {
let inner = match value {
Solver::Newton => solver::newton::Newton::new(),
};
Self(Box::new(inner))
}
}

impl core::Solver<Problem, IterState> for DynamicSolver {
// TODO: make this a trait method so we can return a dynamic
fn name(&self) -> &str {
self.0.name()
}

fn next_iter(
&mut self,
problem: &mut core::Problem<Problem>,
state: IterState,
) -> Result<(IterState, Option<core::KV>), core::Error> {
self.0.next_iter(problem, state)
}
}
16 changes: 16 additions & 0 deletions crates/argmin-py/src/types.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// Copyright 2018-2023 argmin developers
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.

//! Base types for the Python extension.

pub type Scalar = f64; // TODO: allow complex numbers
pub type Array1 = ndarray::Array1<Scalar>;
pub type Array2 = ndarray::Array2<Scalar>;
pub type PyArray1 = numpy::PyArray1<Scalar>;

pub type IterState =
argmin::core::IterState<Array1, Array1, (), ndarray::Array2<Scalar>, (), Scalar>;
2 changes: 1 addition & 1 deletion crates/argmin/src/core/checkpointing/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ use std::fmt::Display;
/// }
/// # fn main() {}
/// ```
pub trait Checkpoint<S, I> {
pub trait Checkpoint<S, I>: Send {
/// Save a checkpoint
///
/// Gets a reference to the current `solver` of type `S` and to the current `state` of type
Expand Down
7 changes: 5 additions & 2 deletions crates/argmin/src/core/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,8 @@ where
let kv = kv.unwrap_or(kv![]);

// Observe after init
self.observers.observe_init(S::NAME, &state, &kv)?;
self.observers
.observe_init(self.solver.name(), &state, &kv)?;
}

state.func_counts(&self.problem);
Expand Down Expand Up @@ -681,7 +682,9 @@ mod tests {
P: Clone,
F: ArgminFloat,
{
const NAME: &'static str = "OptimizationAlgorithm";
fn name(&self) -> &str {
"OptimizationAlgorithm"
}

// Only resets internal_state to 1
fn init(
Expand Down
2 changes: 1 addition & 1 deletion crates/argmin/src/core/observers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ use std::sync::{Arc, Mutex};
/// }
/// }
/// ```
pub trait Observe<I> {
pub trait Observe<I>: Send {
/// Called once after initialization of the solver.
///
/// Has access to the name of the solver via `name`, the initial `state` and to a key-value
Expand Down
2 changes: 1 addition & 1 deletion crates/argmin/src/core/result.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ where
{
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
writeln!(f, "OptimizationResult:")?;
writeln!(f, " Solver: {}", S::NAME)?;
writeln!(f, " Solver: {}", self.solver().name())?;
writeln!(
f,
" param (best): {}",
Expand Down
7 changes: 5 additions & 2 deletions crates/argmin/src/core/solver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ use crate::core::{Error, Problem, State, TerminationReason, TerminationStatus, K
/// P: Clone,
/// F: ArgminFloat
/// {
/// const NAME: &'static str = "OptimizationAlgorithm";
/// fn name(&self) -> &str { "OptimizationAlgorithm" }
///
/// fn init(
/// &mut self,
Expand Down Expand Up @@ -67,7 +67,7 @@ use crate::core::{Error, Problem, State, TerminationReason, TerminationStatus, K
/// ```
pub trait Solver<O, I: State> {
/// Name of the solver. Mainly used in [Observers](`crate::core::observers::Observe`).
const NAME: &'static str;
// const NAME: &'static str;

/// Initializes the algorithm.
///
Expand Down Expand Up @@ -117,4 +117,7 @@ pub trait Solver<O, I: State> {
fn terminate(&mut self, _state: &I) -> TerminationStatus {
TerminationStatus::NotTerminated
}

/// Returns the name of the solver.
fn name(&self) -> &str;
}
4 changes: 3 additions & 1 deletion crates/argmin/src/core/test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,9 @@ impl TestSolver {
}

impl<O> Solver<O, IterState<Vec<f64>, (), (), (), (), f64>> for TestSolver {
const NAME: &'static str = "TestSolver";
fn name(&self) -> &str {
"TestSolver"
}

fn next_iter(
&mut self,
Expand Down
Loading