PoC: Python bindings #340

Draft
wants to merge 15 commits into
base: main
2 changes: 2 additions & 0 deletions .gitignore
@@ -4,3 +4,5 @@ target/
target/*
*.log
justfile
.vscode
.venv
1 change: 1 addition & 0 deletions Cargo.toml
@@ -3,6 +3,7 @@
members = [
"argmin",
"argmin-math",
"argmin-py",
]

exclude = [
18 changes: 18 additions & 0 deletions argmin-py/Cargo.toml
@@ -0,0 +1,18 @@
[package]
name = "argmin-py"
version = "0.1.0"
edition = "2021"

[lib]
name = "argmin"
crate-type = ["cdylib"]

[dependencies]
anyhow = "1.0.70"
argmin_testfunctions = "0.1.1"
argmin = {path="../argmin", default-features=false, features=[]}
argmin-math = {path="../argmin-math", features=["ndarray_latest-serde"]}
ndarray-linalg = { version = "0.16", features = ["netlib"] }
ndarray = { version = "0.15", features = ["serde-1"] }
Collaborator Author

This section needs cleanup; I am not sure what the best configuration of features is.

Member

This looks mostly fine. I think in the long run at least the serde1 feature of argmin can be enabled, because that would allow checkpointing. But I guess checkpointing will need more work anyway.
At some point we will have to decide which BLAS backend to use. This is probably mostly a platform issue (Intel MKL is the only backend that works on Linux, Windows, and Mac) and a licensing issue, since the compiled code will be packaged into a Python module.

numpy = "0.18.0"
pyo3 = {version="0.18.1", features=["extension-module", "anyhow"]}
26 changes: 26 additions & 0 deletions argmin-py/examples/newton.py
@@ -0,0 +1,26 @@
# Copyright 2018-2023 argmin developers
#
# Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
# http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
# http://opensource.org/licenses/MIT>, at your option. This file may not be
# copied, modified, or distributed except according to those terms.
from argmin import Problem, Solver, Executor
import numpy as np
from scipy.optimize import rosen_der, rosen_hess


def main():
    problem = Problem(
        gradient=rosen_der,
        hessian=rosen_hess,
    )
    solver = Solver.Newton
    executor = Executor(problem, solver)
    executor.configure(param=np.array([-1.2, 1.0]), max_iters=8)

    result = executor.run()
    print(result)


if __name__ == "__main__":
    main()
64 changes: 64 additions & 0 deletions argmin-py/src/executor.rs
@@ -0,0 +1,64 @@
// Copyright 2018-2023 argmin developers
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.

// TODO: docs

use pyo3::{prelude::*, types::PyDict};

use argmin::core;

use crate::problem::Problem;
use crate::solver::{DynamicSolver, Solver};
use crate::types::{IterState, PyArray1};

#[pyclass]
pub struct Executor(Option<core::Executor<Problem, DynamicSolver, IterState>>);

impl Executor {
    /// Consumes the inner executor.
    ///
    /// PyObjects do not allow methods that consume the object itself, so this is a workaround
    /// for using methods like `configure` and `run`.
    fn take(&mut self) -> anyhow::Result<core::Executor<Problem, DynamicSolver, IterState>> {
        let Some(inner) = self.0.take() else {
            return Err(anyhow::anyhow!("Executor was already run."));
        };
        Ok(inner)
    }
}

#[pymethods]
impl Executor {
    #[new]
    fn new(problem: Problem, solver: Solver) -> Self {
        Self(Some(core::Executor::new(problem, solver.into())))
    }

    #[pyo3(signature = (**kwargs))]
    fn configure(&mut self, kwargs: Option<&PyDict>) -> PyResult<()> {
        if let Some(kwargs) = kwargs {
            let new_self = self.take()?.configure(|mut state| {
                if let Some(param) = kwargs.get_item("param") {
                    let param: &PyArray1 = param.extract().unwrap();
                    state = state.param(param.to_owned_array());
                }
                if let Some(max_iters) = kwargs.get_item("max_iters") {
                    state = state.max_iters(max_iters.extract().unwrap());
                }
                state
            });
            self.0 = Some(new_self);
        }
        Ok(())
    }

    fn run(&mut self) -> PyResult<String> {
        // TODO: return usable OptimizationResult
        let res = self.take()?.run();
        Ok(res?.to_string())
    }
}
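
The `take` workaround in executor.rs can be illustrated without PyO3. The following is a minimal, std-only sketch: `Builder` and `Wrapper` are hypothetical stand-ins (not types from this PR) for `core::Executor`, whose methods consume `self`, and for the `#[pyclass]` wrapper, which only ever gets `&mut self`.

```rust
// Hypothetical stand-in for a type whose methods consume `self`,
// similar in spirit to `configure` and `run` on the inner executor.
struct Builder {
    max_iters: u64,
}

impl Builder {
    fn max_iters(mut self, n: u64) -> Self {
        self.max_iters = n;
        self
    }

    fn run(self) -> String {
        format!("ran for {} iterations", self.max_iters)
    }
}

// Hypothetical wrapper that is only reachable through `&mut self`.
struct Wrapper(Option<Builder>);

impl Wrapper {
    fn new() -> Self {
        Wrapper(Some(Builder { max_iters: 0 }))
    }

    // Move the inner value out and leave `None` behind.
    fn take(&mut self) -> Result<Builder, String> {
        self.0.take().ok_or_else(|| "already run".to_string())
    }

    // Consume the inner value, then put the reconfigured one back.
    fn configure(&mut self, n: u64) -> Result<(), String> {
        let inner = self.take()?.max_iters(n);
        self.0 = Some(inner);
        Ok(())
    }

    // Consume the inner value for good; later calls return an error.
    fn run(&mut self) -> Result<String, String> {
        Ok(self.take()?.run())
    }
}

fn main() {
    let mut wrapper = Wrapper::new();
    wrapper.configure(8).unwrap();
    println!("{:?}", wrapper.run());
    println!("{:?}", wrapper.run());
}
```

One consequence of this design is that the wrapper is single-use: once `run` has taken the inner value, both `run` and `configure` fail until a new wrapper is created.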
24 changes: 24 additions & 0 deletions argmin-py/src/lib.rs
@@ -0,0 +1,24 @@
// Copyright 2018-2023 argmin developers
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.

// TODO: docs
mod executor;
mod problem;
mod solver;
mod types;

use pyo3::prelude::*;

#[pymodule]
#[pyo3(name = "argmin")]
fn argmin_py(_py: Python, m: &PyModule) -> PyResult<()> {
    m.add_class::<executor::Executor>()?;
    m.add_class::<problem::Problem>()?;
    m.add_class::<solver::Solver>()?;

    Ok(())
}
68 changes: 68 additions & 0 deletions argmin-py/src/problem.rs
@@ -0,0 +1,68 @@
// Copyright 2018-2023 argmin developers
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.

// TODO: docs

use numpy::ToPyArray;
use pyo3::{prelude::*, types::PyTuple};

use argmin::core;

use crate::types::{Array1, Array2, Scalar};

#[pyclass]
#[derive(Clone)]
pub struct Problem {
    gradient: PyObject,
    hessian: PyObject,
    // TODO: jacobian
}

#[pymethods]
impl Problem {
    #[new]
    fn new(gradient: PyObject, hessian: PyObject) -> Self {
        Self { gradient, hessian }
    }
}

impl core::Gradient for Problem {
    type Param = Array1;
    type Gradient = Array1;

    fn gradient(&self, param: &Self::Param) -> Result<Self::Gradient, argmin::core::Error> {
        call(&self.gradient, param)
    }
}

impl argmin::core::Hessian for Problem {
    type Param = Array1;

    type Hessian = Array2;

    fn hessian(&self, param: &Self::Param) -> Result<Self::Hessian, core::Error> {
        call(&self.hessian, param)
    }
}

fn call<InputDimension, OutputDimension>(
    callable: &PyObject,
    param: &ndarray::Array<Scalar, InputDimension>,
) -> Result<ndarray::Array<Scalar, OutputDimension>, argmin::core::Error>
where
    InputDimension: ndarray::Dimension,
    OutputDimension: ndarray::Dimension,
{
    // TODO: prevent dynamic dispatch for every call
    Python::with_gil(|py| {
        let args = PyTuple::new(py, [param.to_pyarray(py)]);
        let pyresult = callable.call(py, args, Default::default())?;
        let pyarray = pyresult.extract::<&numpy::PyArray<Scalar, OutputDimension>>(py)?;
        // TODO: try to get ownership instead of cloning
        Ok(pyarray.to_owned_array())
Collaborator Author

  1. I am unsure how much overhead calling to_pyarray and extract adds to every evaluation of the gradient, Hessian, etc. This probably needs benchmarks.
  2. to_owned_array makes a copy of the data, which should not be necessary.

Member

I just had a new idea: You're currently using the ndarray backend which requires transitioning between numpy arrays and ndarray arrays. Instead we could add a new math backend based on PyArray, which would mean that numpy would do all the heavy lifting. I'm not sure whether numpy or ndarray is faster and I haven't really thought this through either.

I somehow thought that it would be possible to use the underlying memory in both Rust and Python without copying but that doesn't seem to be the case.

Regarding point 2: I agree. I assumed that there is also into_owned_array but that does not seem to be the case.

Collaborator Author

> I somehow thought that it would be possible to use the underlying memory in both Rust and Python without copying but that doesn't seem to be the case.

I'll give it a try!

    })
}
51 changes: 51 additions & 0 deletions argmin-py/src/solver.rs
@@ -0,0 +1,51 @@
// Copyright 2018-2023 argmin developers
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.

// TODO: docs

use pyo3::prelude::*;

use argmin::{core, solver};

use crate::{
    problem::Problem,
    types::{IterState, Scalar},
};

#[pyclass]
#[derive(Clone)]
pub enum Solver {
    Newton,
}

pub enum DynamicSolver {
    // NOTE: I tried using a Box<dyn Solver<> here, but Solver is not object safe.
Member

Too bad!! That's what I was hoping for, but I tend to forget about object safety. I don't think it'll be possible to make it object safe :(
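
For context, here is a minimal, std-only sketch of why the `Box<dyn ...>` attempt fails and how the enum wrapper gets around it. The `Solver` trait below is a simplified stand-in, not argmin's real trait, and `Halve` and `Decrement` are hypothetical solvers invented for the example. Rust does not allow trait objects for traits with associated constants, so dispatch is done with a `match` instead.

```rust
// Simplified stand-in for a solver trait with an associated constant.
trait Solver {
    const NAME: &'static str;
    fn next_iter(&mut self, x: f64) -> f64;
}

// Hypothetical solvers, only for illustration.
struct Halve;
struct Decrement;

impl Solver for Halve {
    const NAME: &'static str = "Halve";
    fn next_iter(&mut self, x: f64) -> f64 {
        x / 2.0
    }
}

impl Solver for Decrement {
    const NAME: &'static str = "Decrement";
    fn next_iter(&mut self, x: f64) -> f64 {
        x - 1.0
    }
}

// The following would be rejected by the compiler, because the
// associated constant makes the trait unusable as a trait object:
// let s: Box<dyn Solver> = Box::new(Halve);

// Enum dispatch: one variant per concrete solver.
enum DynamicSolver {
    Halve(Halve),
    Decrement(Decrement),
}

impl Solver for DynamicSolver {
    const NAME: &'static str = "Dynamic Solver";

    fn next_iter(&mut self, x: f64) -> f64 {
        match self {
            DynamicSolver::Halve(inner) => inner.next_iter(x),
            DynamicSolver::Decrement(inner) => inner.next_iter(x),
        }
    }
}

// Generic driver that works with any solver, including the enum.
fn run<S: Solver>(solver: &mut S, mut x: f64, iters: usize) -> f64 {
    for _ in 0..iters {
        x = solver.next_iter(x);
    }
    x
}

fn main() {
    let mut solver = DynamicSolver::Halve(Halve);
    println!("{}: {}", DynamicSolver::NAME, run(&mut solver, 8.0, 3));
}
```

The cost of this approach is that every new solver needs a new variant and a new `match` arm in each forwarded method.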

    Newton(solver::newton::Newton<Scalar>),
}

impl From<Solver> for DynamicSolver {
    fn from(solver: Solver) -> Self {
        match solver {
            Solver::Newton => Self::Newton(solver::newton::Newton::new()),
        }
    }
}

impl core::Solver<Problem, IterState> for DynamicSolver {
    // TODO: make this a trait method so we can return a dynamic
Member

Sounds good to me! We could have both for backwards compatibility, right? The default impl of the name method would then just return self.NAME.

Collaborator Author

It seems I was able to solve two problems at once: When I remove the associated constant, Solver becomes object-safe, so we can create trait objects for it. Let me know if you object.

Member

Nice! I thought the generics would also be a problem for object safety, but it's great if this isn't the case. Sounds good to me!
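
A minimal, std-only sketch of the shape discussed in this thread, under the assumption that the name becomes a regular method: the trait below is a simplified stand-in for argmin's `Solver`, and `Problem`, `Bisect`, `Jump`, and `pick` are hypothetical names invented for the example. Generic parameters on the trait itself do not prevent trait objects; only generic methods and associated constants would. That also means the constant has to be removed entirely: a default `name` method returning `Self::NAME` would keep the constant in the trait and with it the restriction.

```rust
// Simplified stand-in: generic over problem `P` and state `S`,
// with the name exposed as a method instead of a constant.
trait Solver<P, S> {
    fn name(&self) -> &str;
    fn next_iter(&mut self, problem: &P, state: S) -> S;
}

// Hypothetical problem: move the state towards `target`.
struct Problem {
    target: f64,
}

// Hypothetical solver that moves halfway to the target each step.
struct Bisect;

impl Solver<Problem, f64> for Bisect {
    fn name(&self) -> &str {
        "Bisect"
    }
    fn next_iter(&mut self, problem: &Problem, state: f64) -> f64 {
        (state + problem.target) / 2.0
    }
}

// Hypothetical solver that jumps straight to the target.
struct Jump;

impl Solver<Problem, f64> for Jump {
    fn name(&self) -> &str {
        "Jump"
    }
    fn next_iter(&mut self, problem: &Problem, _state: f64) -> f64 {
        problem.target
    }
}

// Runtime selection returns a trait object; no wrapper enum is needed.
fn pick(kind: &str) -> Box<dyn Solver<Problem, f64>> {
    match kind {
        "jump" => Box::new(Jump),
        _ => Box::new(Bisect),
    }
}

fn main() {
    let problem = Problem { target: 4.0 };
    let mut solver = pick("bisect");
    let mut state = 0.0;
    for _ in 0..3 {
        state = solver.next_iter(&problem, state);
    }
    println!("{}: {}", solver.name(), state);
}
```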

    const NAME: &'static str = "Dynamic Solver";

    fn next_iter(
        &mut self,
        problem: &mut core::Problem<Problem>,
        state: IterState,
    ) -> Result<(IterState, Option<core::KV>), core::Error> {
        match self {
            DynamicSolver::Newton(inner) => inner.next_iter(problem, state),
        }
    }
}
15 changes: 15 additions & 0 deletions argmin-py/src/types.rs
@@ -0,0 +1,15 @@
// Copyright 2018-2023 argmin developers
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.

//! Base types for the Python extension.

pub type Scalar = f64; // TODO: allow complex numbers
Member

Complex numbers would be great but I wouldn't give this a high priority.

pub type Array1 = ndarray::Array1<Scalar>;
pub type Array2 = ndarray::Array2<Scalar>;
pub type PyArray1 = numpy::PyArray1<Scalar>;

pub type IterState = argmin::core::IterState<Array1, Array1, (), ndarray::Array2<Scalar>, Scalar>;
6 changes: 3 additions & 3 deletions argmin/src/core/executor.rs
@@ -26,7 +26,7 @@ pub struct Executor<O, S, I> {
/// Storage for observers
observers: Observers<I>,
/// Checkpoint
-    checkpoint: Option<Box<dyn Checkpoint<S, I>>>,
+    checkpoint: Option<Box<dyn Checkpoint<S, I> + Send>>,
Collaborator Author

I added these bounds because PyClass must be Send.

Member

Sounds also good to me. Would it make sense to add the Send bound to the Checkpoint trait?
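
The two options in this thread can be compared in a small, std-only sketch. All names here are hypothetical stand-ins rather than argmin's real types: `PerUse` adds `+ Send` at each use site, as the diff does, while `ViaSupertrait` relies on a `Send` supertrait, as suggested in the question above. Either way the containing struct becomes `Send`; the supertrait version additionally forces every implementor to be `Send`.

```rust
use std::thread;

// Option 1: plain trait, `+ Send` added wherever a trait object is stored.
trait Checkpoint {
    fn save(&self) -> String;
}

struct PerUse {
    checkpoint: Option<Box<dyn Checkpoint + Send>>,
}

// Option 2: `Send` as a supertrait, so `dyn SendCheckpoint` is always `Send`.
trait SendCheckpoint: Send {
    fn save(&self) -> String;
}

struct ViaSupertrait {
    checkpoint: Option<Box<dyn SendCheckpoint>>,
}

// Hypothetical checkpoint implementation used for both traits.
struct FileCheckpoint;

impl Checkpoint for FileCheckpoint {
    fn save(&self) -> String {
        "saved".to_string()
    }
}

impl SendCheckpoint for FileCheckpoint {
    fn save(&self) -> String {
        "saved via supertrait".to_string()
    }
}

// Compile-time check: this only type-checks if `T` is `Send`.
fn assert_send<T: Send>() {}

fn main() {
    assert_send::<PerUse>();
    assert_send::<ViaSupertrait>();

    // Both structs can be moved into another thread.
    let a = PerUse { checkpoint: Some(Box::new(FileCheckpoint)) };
    let b = ViaSupertrait { checkpoint: Some(Box::new(FileCheckpoint)) };
    let handle = thread::spawn(move || {
        let first = a.checkpoint.map(|c| c.save());
        let second = b.checkpoint.map(|c| c.save());
        (first, second)
    });
    println!("{:?}", handle.join().unwrap());
}
```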

/// Indicates whether Ctrl-C functionality should be active or not
ctrlc: bool,
/// Indicates whether to time execution or not
@@ -298,7 +298,7 @@ where
/// # }
/// ```
#[must_use]
-    pub fn add_observer<OBS: Observe<I> + 'static>(
+    pub fn add_observer<OBS: Observe<I> + 'static + Send>(
mut self,
observer: OBS,
mode: ObserverMode,
@@ -340,7 +340,7 @@ where
/// # }
/// ```
#[must_use]
-    pub fn checkpointing<C: 'static + Checkpoint<S, I>>(mut self, checkpoint: C) -> Self {
+    pub fn checkpointing<C: 'static + Checkpoint<S, I> + Send>(mut self, checkpoint: C) -> Self {
self.checkpoint = Some(Box::new(checkpoint));
self
}
4 changes: 2 additions & 2 deletions argmin/src/core/observers/mod.rs
@@ -188,7 +188,7 @@ pub trait Observe<I> {
}
}

-type ObserversVec<I> = Vec<(Arc<Mutex<dyn Observe<I>>>, ObserverMode)>;
+type ObserversVec<I> = Vec<(Arc<Mutex<dyn Observe<I> + Send>>, ObserverMode)>;

/// Container for observers.
///
@@ -236,7 +236,7 @@ impl<I> Observers<I> {
/// # #[cfg(feature = "slog-logger")]
/// # assert!(!observers.is_empty());
/// ```
-    pub fn push<OBS: Observe<I> + 'static>(
+    pub fn push<OBS: Observe<I> + 'static + Send>(
&mut self,
observer: OBS,
mode: ObserverMode,