From fb810f6060d70413c162f009ea2867ab120a161b Mon Sep 17 00:00:00 2001 From: Stefan Kroboth Date: Thu, 7 Mar 2024 21:28:12 +0100 Subject: [PATCH] Jacobian stuff --- python/finitediff-py/src/lib.rs | 106 +++++++++++++++++++++++++++++++- python/finitediff-py/test.py | 39 +++++++++--- 2 files changed, 137 insertions(+), 8 deletions(-) diff --git a/python/finitediff-py/src/lib.rs b/python/finitediff-py/src/lib.rs index 948c17c29..d72aab2cc 100644 --- a/python/finitediff-py/src/lib.rs +++ b/python/finitediff-py/src/lib.rs @@ -1,6 +1,6 @@ use finitediff_rust::ndarr; use numpy::ndarray::Array1; -use numpy::{IntoPyArray, PyArray1}; +use numpy::{IntoPyArray, PyArray1, PyArray2}; use pyo3::{ exceptions::PyTypeError, prelude::*, @@ -38,9 +38,113 @@ fn forward_diff<'py>(py: Python<'py>, f: Py) -> PyResult<&'py PyCFunction } } +/// Central diff +#[pyfunction] +fn central_diff<'py>(py: Python<'py>, f: Py) -> PyResult<&'py PyCFunction> { + if f.as_ref(py).is_callable() { + PyCFunction::new_closure( + py, + None, + None, + move |args: &PyTuple, _kwargs: Option<&PyDict>| -> PyResult>> { + Python::with_gil(|py| { + let out = (ndarr::central_diff(|x: &Array1| -> f64 { + let x = PyArray1::from_array(py, x); + f.call(py, (x,), None).unwrap().extract(py).unwrap() + }))( + &args + .get_item(0)? + .downcast::>()? + .to_owned_array(), + ); + Ok(out.into_pyarray(py).into()) + }) + }, + ) + } else { + Err(PyErr::new::(format!( + "object {} not callable", + f.as_ref(py).get_type() + ))) + } +} + +/// Forward Jacobian +#[pyfunction] +fn forward_jacobian<'py>(py: Python<'py>, f: Py) -> PyResult<&'py PyCFunction> { + if f.as_ref(py).is_callable() { + PyCFunction::new_closure( + py, + None, + None, + move |args: &PyTuple, _kwargs: Option<&PyDict>| -> PyResult>> { + Python::with_gil(|py| { + let out = (ndarr::forward_jacobian(|x: &Array1| -> Array1 { + let x = PyArray1::from_array(py, x); + f.call(py, (x,), None) + .unwrap() + .extract::<&PyArray1>(py) + .unwrap() + .to_owned_array() + }))( + &args + .get_item(0)? + .downcast::>()? + .to_owned_array(), + ); + Ok(out.into_pyarray(py).into()) + }) + }, + ) + } else { + Err(PyErr::new::(format!( + "object {} not callable", + f.as_ref(py).get_type() + ))) + } +} + +/// Central Jacobian +#[pyfunction] +fn central_jacobian<'py>(py: Python<'py>, f: Py) -> PyResult<&'py PyCFunction> { + if f.as_ref(py).is_callable() { + PyCFunction::new_closure( + py, + None, + None, + move |args: &PyTuple, _kwargs: Option<&PyDict>| -> PyResult>> { + Python::with_gil(|py| { + let out = (ndarr::central_jacobian(|x: &Array1| -> Array1 { + let x = PyArray1::from_array(py, x); + f.call(py, (x,), None) + .unwrap() + .extract::<&PyArray1>(py) + .unwrap() + .to_owned_array() + }))( + &args + .get_item(0)? + .downcast::>()? + .to_owned_array(), + ); + Ok(out.into_pyarray(py).into()) + }) + }, + ) + } else { + Err(PyErr::new::(format!( + "object {} not callable", + f.as_ref(py).get_type() + ))) + } +} + /// A Python module implemented in Rust. #[pymodule] fn finitediff(_py: Python, m: &PyModule) -> PyResult<()> { m.add_function(wrap_pyfunction!(forward_diff, m)?)?; + m.add_function(wrap_pyfunction!(central_diff, m)?)?; + m.add_function(wrap_pyfunction!(forward_jacobian, m)?)?; + m.add_function(wrap_pyfunction!(central_jacobian, m)?)?; Ok(()) } diff --git a/python/finitediff-py/test.py b/python/finitediff-py/test.py index a3d116c86..e9d0f2898 100644 --- a/python/finitediff-py/test.py +++ b/python/finitediff-py/test.py @@ -1,4 +1,4 @@ -from finitediff import forward_diff +from finitediff import forward_diff, central_diff, forward_jacobian, central_jacobian import numpy as np @@ -30,13 +30,38 @@ def blaah(self, x): x = np.array([1.0, 2.0]) print(g(x)) +g = central_diff(f) +x = np.array([1.0, 2.0]) +print(g(x)) -class NotCallable: - pass +def op(x): + return np.array( + [ + 2.0 * (x[1] ** 3 - x[0] ** 2), + 3.0 * (x[1] ** 3 - x[0] ** 2) + 2.0 * (x[2] ** 3 - x[1] ** 2), + 3.0 * (x[2] ** 3 - x[1] ** 2) + 2.0 * (x[3] ** 3 - x[2] ** 2), + 3.0 * (x[3] ** 3 - x[2] ** 2) + 2.0 * (x[4] ** 3 - x[3] ** 2), + 3.0 * (x[4] ** 3 - x[3] ** 2) + 2.0 * (x[5] ** 3 - x[4] ** 2), + 3.0 * (x[5] ** 3 - x[4] ** 2), + ] + ) -notcallable = NotCallable() -g = forward_diff(notcallable) -x = np.array([1.0, 2.0]) -print(g(x)) +j = forward_jacobian(op) +x = np.array([1.0, 1.0, 1.0, 1.0, 1.0, 1.0]) +print(j(x)) + +j = central_jacobian(op) +x = np.array([1.0, 1.0, 1.0, 1.0, 1.0, 1.0]) +print(j(x)) + +# class NotCallable: +# pass + + +# notcallable = NotCallable() + +# g = forward_diff(notcallable) +# x = np.array([1.0, 2.0]) +# print(g(x))