-
-
Notifications
You must be signed in to change notification settings - Fork 84
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
PoC: Python bindings #340
base: main
Are you sure you want to change the base?
PoC: Python bindings #340
Changes from 1 commit
4c94c51
bc2e840
4c0c79c
007a13c
f24ea5a
b669e48
f096301
839ef5f
72677e6
1be6014
d430569
4cbf7a0
61f157e
52030e0
ee48f38
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,3 +4,5 @@ target/ | |
target/* | ||
*.log | ||
justfile | ||
.vscode | ||
.venv |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,6 +3,7 @@ | |
members = [ | ||
"argmin", | ||
"argmin-math", | ||
"argmin-py", | ||
] | ||
|
||
exclude = [ | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
[package] | ||
name = "argmin-py" | ||
version = "0.1.0" | ||
edition = "2021" | ||
|
||
[lib] | ||
name = "argmin" | ||
crate-type = ["cdylib"] | ||
|
||
[dependencies] | ||
anyhow = "1.0.70" | ||
argmin_testfunctions = "0.1.1" | ||
argmin = {path="../argmin", default-features=false, features=[]} | ||
argmin-math = {path="../argmin-math", features=["ndarray_latest-serde"]} | ||
ndarray-linalg = { version = "0.16", features = ["netlib"] } | ||
ndarray = { version = "0.15", features = ["serde-1"] } | ||
numpy = "0.18.0" | ||
pyo3 = {version="0.18.1", features=["extension-module", "anyhow"]} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
# Copyright 2018-2023 argmin developers | ||
# | ||
# Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or | ||
# http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or | ||
# http://opensource.org/licenses/MIT>, at your option. This file may not be | ||
# copied, modified, or distributed except according to those terms. | ||
from argmin import Problem, Solver, Executor | ||
import numpy as np | ||
from scipy.optimize import rosen_der, rosen_hess | ||
|
||
|
||
def main(): | ||
problem = Problem( | ||
gradient=rosen_der, | ||
hessian=rosen_hess, | ||
) | ||
solver = Solver.Newton | ||
executor = Executor(problem, solver) | ||
executor.configure(param=np.array([-1.2, 1.0]), max_iters=8) | ||
|
||
result = executor.run() | ||
print(result) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
// Copyright 2018-2023 argmin developers | ||
// | ||
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or | ||
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or | ||
// http://opensource.org/licenses/MIT>, at your option. This file may not be | ||
// copied, modified, or distributed except according to those terms. | ||
|
||
// TODO: docs | ||
|
||
use pyo3::{prelude::*, types::PyDict}; | ||
|
||
use argmin::core; | ||
|
||
use crate::problem::Problem; | ||
use crate::solver::{DynamicSolver, Solver}; | ||
use crate::types::{IterState, PyArray1}; | ||
|
||
#[pyclass] | ||
pub struct Executor(Option<core::Executor<Problem, DynamicSolver, IterState>>); | ||
|
||
impl Executor { | ||
/// Consumes the inner executor. | ||
/// | ||
/// PyObjects do not allow methods that consume the object itself, so this is a workaround | ||
/// for using methods like `configure` and `run`. | ||
fn take(&mut self) -> anyhow::Result<core::Executor<Problem, DynamicSolver, IterState>> { | ||
stefan-k marked this conversation as resolved.
Show resolved
Hide resolved
|
||
let Some(inner) = self.0.take() else { | ||
return Err(anyhow::anyhow!("Executor was already run.")); | ||
}; | ||
Ok(inner) | ||
} | ||
} | ||
|
||
#[pymethods] | ||
impl Executor { | ||
#[new] | ||
fn new(problem: Problem, solver: Solver) -> Self { | ||
Self(Some(core::Executor::new(problem, solver.into()))) | ||
} | ||
|
||
#[pyo3(signature = (**kwargs))] | ||
fn configure(&mut self, kwargs: Option<&PyDict>) -> PyResult<()> { | ||
if let Some(kwargs) = kwargs { | ||
let new_self = self.take()?.configure(|mut state| { | ||
if let Some(param) = kwargs.get_item("param") { | ||
let param: &PyArray1 = param.extract().unwrap(); | ||
state = state.param(param.to_owned_array()); | ||
} | ||
if let Some(max_iters) = kwargs.get_item("max_iters") { | ||
state = state.max_iters(max_iters.extract().unwrap()); | ||
} | ||
state | ||
}); | ||
self.0 = Some(new_self); | ||
} | ||
Ok(()) | ||
} | ||
|
||
fn run(&mut self) -> PyResult<String> { | ||
// TODO: return usable OptimizationResult | ||
let res = self.take()?.run(); | ||
Ok(res?.to_string()) | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
// Copyright 2018-2023 argmin developers | ||
// | ||
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or | ||
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or | ||
// http://opensource.org/licenses/MIT>, at your option. This file may not be | ||
// copied, modified, or distributed except according to those terms. | ||
|
||
// TODO: docs | ||
mod executor; | ||
mod problem; | ||
mod solver; | ||
mod types; | ||
|
||
use pyo3::prelude::*; | ||
|
||
#[pymodule] | ||
#[pyo3(name = "argmin")] | ||
fn argmin_py(_py: Python, m: &PyModule) -> PyResult<()> { | ||
m.add_class::<executor::Executor>()?; | ||
m.add_class::<problem::Problem>()?; | ||
m.add_class::<solver::Solver>()?; | ||
|
||
Ok(()) | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
// Copyright 2018-2023 argmin developers | ||
// | ||
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or | ||
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or | ||
// http://opensource.org/licenses/MIT>, at your option. This file may not be | ||
// copied, modified, or distributed except according to those terms. | ||
|
||
// TODO: docs | ||
|
||
use numpy::ToPyArray; | ||
use pyo3::{prelude::*, types::PyTuple}; | ||
|
||
use argmin::core; | ||
|
||
use crate::types::{Array1, Array2, Scalar}; | ||
|
||
#[pyclass] | ||
#[derive(Clone)] | ||
pub struct Problem { | ||
gradient: PyObject, | ||
hessian: PyObject, | ||
// TODO: jacobian | ||
} | ||
|
||
#[pymethods] | ||
impl Problem { | ||
#[new] | ||
fn new(gradient: PyObject, hessian: PyObject) -> Self { | ||
Self { gradient, hessian } | ||
} | ||
} | ||
|
||
impl core::Gradient for Problem { | ||
type Param = Array1; | ||
type Gradient = Array1; | ||
|
||
fn gradient(&self, param: &Self::Param) -> Result<Self::Gradient, argmin::core::Error> { | ||
call(&self.gradient, param) | ||
} | ||
} | ||
|
||
impl argmin::core::Hessian for Problem { | ||
type Param = Array1; | ||
|
||
type Hessian = Array2; | ||
|
||
fn hessian(&self, param: &Self::Param) -> Result<Self::Hessian, core::Error> { | ||
call(&self.hessian, param) | ||
} | ||
} | ||
|
||
fn call<InputDimension, OutputDimension>( | ||
callable: &PyObject, | ||
param: &ndarray::Array<Scalar, InputDimension>, | ||
) -> Result<ndarray::Array<Scalar, OutputDimension>, argmin::core::Error> | ||
where | ||
InputDimension: ndarray::Dimension, | ||
OutputDimension: ndarray::Dimension, | ||
{ | ||
// TODO: prevent dynamic dispatch for every call | ||
Python::with_gil(|py| { | ||
let args = PyTuple::new(py, [param.to_pyarray(py)]); | ||
let pyresult = callable.call(py, args, Default::default())?; | ||
let pyarray = pyresult.extract::<&numpy::PyArray<Scalar, OutputDimension>>(py)?; | ||
// TODO: try to get ownership instead of cloning | ||
Ok(pyarray.to_owned_array()) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I just had a new idea: You're currently using the I somehow thought that it would be possible to use the underlying memory in both Rust and Python without copying but that doesn't seem to be the case. Regarding point 2: I agree. I assumed that there is also There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I'll give it a try! |
||
}) | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
// Copyright 2018-2023 argmin developers | ||
// | ||
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or | ||
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or | ||
// http://opensource.org/licenses/MIT>, at your option. This file may not be | ||
// copied, modified, or distributed except according to those terms. | ||
|
||
// TODO: docs | ||
|
||
use pyo3::prelude::*; | ||
|
||
use argmin::{core, solver}; | ||
|
||
use crate::{ | ||
problem::Problem, | ||
types::{IterState, Scalar}, | ||
}; | ||
|
||
#[pyclass] | ||
#[derive(Clone)] | ||
pub enum Solver { | ||
Newton, | ||
} | ||
|
||
pub enum DynamicSolver { | ||
// NOTE: I tried using a Box<dyn Solver<> here, but Solver is not object safe. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Too bad!! That's what I was hoping for, but I tend to forget about object safety. I don't think it'll be possible to make it object safe :( |
||
Newton(solver::newton::Newton<Scalar>), | ||
} | ||
|
||
impl From<Solver> for DynamicSolver { | ||
fn from(solver: Solver) -> Self { | ||
match solver { | ||
Solver::Newton => Self::Newton(solver::newton::Newton::new()), | ||
} | ||
} | ||
} | ||
|
||
impl core::Solver<Problem, IterState> for DynamicSolver { | ||
// TODO: make this a trait method so we can return a dynamic | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sounds good to me! We could have both for backwards compatibility, right? The default impl of the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It seems I was able to solve two problems at once: When I remove the associated constant, There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nice! I thought the generics would also be a problem for object safety, but it's great if this isn't the case. Sounds good to me! |
||
const NAME: &'static str = "Dynamic Solver"; | ||
|
||
fn next_iter( | ||
&mut self, | ||
problem: &mut core::Problem<Problem>, | ||
state: IterState, | ||
) -> Result<(IterState, Option<core::KV>), core::Error> { | ||
match self { | ||
DynamicSolver::Newton(inner) => inner.next_iter(problem, state), | ||
} | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
// Copyright 2018-2023 argmin developers | ||
// | ||
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or | ||
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or | ||
// http://opensource.org/licenses/MIT>, at your option. This file may not be | ||
// copied, modified, or distributed except according to those terms. | ||
|
||
//! Base types for the Python extension. | ||
|
||
pub type Scalar = f64; // TODO: allow complex numbers | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Complex numbers would be great but I wouldn't give this a high priority. |
||
pub type Array1 = ndarray::Array1<Scalar>; | ||
pub type Array2 = ndarray::Array2<Scalar>; | ||
pub type PyArray1 = numpy::PyArray1<Scalar>; | ||
|
||
pub type IterState = argmin::core::IterState<Array1, Array1, (), ndarray::Array2<Scalar>, Scalar>; |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -26,7 +26,7 @@ pub struct Executor<O, S, I> { | |
/// Storage for observers | ||
observers: Observers<I>, | ||
/// Checkpoint | ||
checkpoint: Option<Box<dyn Checkpoint<S, I>>>, | ||
checkpoint: Option<Box<dyn Checkpoint<S, I> + Send>>, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I added these bounds because There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sounds also good to me. Would it make sense to add the |
||
/// Indicates whether Ctrl-C functionality should be active or not | ||
ctrlc: bool, | ||
/// Indicates whether to time execution or not | ||
|
@@ -298,7 +298,7 @@ where | |
/// # } | ||
/// ``` | ||
#[must_use] | ||
pub fn add_observer<OBS: Observe<I> + 'static>( | ||
pub fn add_observer<OBS: Observe<I> + 'static + Send>( | ||
mut self, | ||
observer: OBS, | ||
mode: ObserverMode, | ||
|
@@ -340,7 +340,7 @@ where | |
/// # } | ||
/// ``` | ||
#[must_use] | ||
pub fn checkpointing<C: 'static + Checkpoint<S, I>>(mut self, checkpoint: C) -> Self { | ||
pub fn checkpointing<C: 'static + Checkpoint<S, I> + Send>(mut self, checkpoint: C) -> Self { | ||
self.checkpoint = Some(Box::new(checkpoint)); | ||
self | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This section requires cleanup, I am not sure what's the best configuration of features.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks mostly fine. I think in the long run at least the
serde1
feature of argmin can be enabled because that would allow checkpointing. But I guess checkpointing will need more work anyways.At some point we will have to decide which BLAS backend to use. This is probably mostly a platform issue (only Intel-MKL works on Linux, Windows and Mac) and a licensing issue since the compiled code will be packaged into a python module.