From 98557bc3372a356a95fa33e907efd53795fa9958 Mon Sep 17 00:00:00 2001 From: Stefan Kroboth Date: Fri, 8 Mar 2024 09:11:54 +0100 Subject: [PATCH] Error handling, other improvements --- python/finitediff-py/Cargo.toml | 3 +- python/finitediff-py/src/lib.rs | 105 ++++++++++++++------------------ python/finitediff-py/test.py | 1 + 3 files changed, 48 insertions(+), 61 deletions(-) diff --git a/python/finitediff-py/Cargo.toml b/python/finitediff-py/Cargo.toml index 367d0da45..72a62f080 100644 --- a/python/finitediff-py/Cargo.toml +++ b/python/finitediff-py/Cargo.toml @@ -21,4 +21,5 @@ crate-type = ["cdylib"] [dependencies] finitediff_rust = { package = "finitediff", version = "0.1.4", path = "../../crates/finitediff", features = ["ndarray"] } numpy = "0.20.0" -pyo3 = { version = "0.20.0", features = ["extension-module"] } +pyo3 = { version = "0.20.0", features = ["extension-module", "anyhow"] } +anyhow = "1.0" diff --git a/python/finitediff-py/src/lib.rs b/python/finitediff-py/src/lib.rs index d72aab2cc..a9324b284 100644 --- a/python/finitediff-py/src/lib.rs +++ b/python/finitediff-py/src/lib.rs @@ -1,12 +1,29 @@ +use anyhow::Error; use finitediff_rust::ndarr; -use numpy::ndarray::Array1; -use numpy::{IntoPyArray, PyArray1, PyArray2}; +use numpy::{ndarray::Array1, IntoPyArray, PyArray1, PyArray2}; use pyo3::{ exceptions::PyTypeError, prelude::*, types::{PyCFunction, PyDict, PyTuple}, }; +fn process_args(args: &PyTuple) -> PyResult> { + Ok(args + .get_item(0) + .map_err(|_| PyErr::new::("Insufficient number of arguments"))? + .downcast::>()? + .to_owned_array()) +} + +macro_rules! not_callable { + ($py:ident, $f:ident) => { + Err(PyErr::new::(format!( + "object {} not callable", + $f.as_ref($py).get_type() + ))) + }; +} + /// Forward diff #[pyfunction] fn forward_diff<'py>(py: Python<'py>, f: Py) -> PyResult<&'py PyCFunction> { @@ -17,24 +34,16 @@ fn forward_diff<'py>(py: Python<'py>, f: Py) -> PyResult<&'py PyCFunction None, move |args: &PyTuple, _kwargs: Option<&PyDict>| -> PyResult>> { Python::with_gil(|py| { - let out = (ndarr::forward_diff(|x: &Array1| -> f64 { + let out = (ndarr::forward_diff(&|x: &Array1| -> Result { let x = PyArray1::from_array(py, x); - f.call(py, (x,), None).unwrap().extract(py).unwrap() - }))( - &args - .get_item(0)? - .downcast::>()? - .to_owned_array(), - ); + Ok(f.call(py, (x,), None)?.extract(py)?) + }))(&process_args(args)?)?; Ok(out.into_pyarray(py).into()) }) }, ) } else { - Err(PyErr::new::(format!( - "object {} not callable", - f.as_ref(py).get_type() - ))) + not_callable!(py, f) } } @@ -48,24 +57,16 @@ fn central_diff<'py>(py: Python<'py>, f: Py) -> PyResult<&'py PyCFunction None, move |args: &PyTuple, _kwargs: Option<&PyDict>| -> PyResult>> { Python::with_gil(|py| { - let out = (ndarr::central_diff(|x: &Array1| -> f64 { + let out = (ndarr::central_diff(&|x: &Array1| -> Result { let x = PyArray1::from_array(py, x); - f.call(py, (x,), None).unwrap().extract(py).unwrap() - }))( - &args - .get_item(0)? - .downcast::>()? - .to_owned_array(), - ); + Ok(f.call(py, (x,), None)?.extract(py)?) + }))(&process_args(args)?)?; Ok(out.into_pyarray(py).into()) }) }, ) } else { - Err(PyErr::new::(format!( - "object {} not callable", - f.as_ref(py).get_type() - ))) + not_callable!(py, f) } } @@ -79,28 +80,20 @@ fn forward_jacobian<'py>(py: Python<'py>, f: Py) -> PyResult<&'py PyCFunc None, move |args: &PyTuple, _kwargs: Option<&PyDict>| -> PyResult>> { Python::with_gil(|py| { - let out = (ndarr::forward_jacobian(|x: &Array1| -> Array1 { - let x = PyArray1::from_array(py, x); - f.call(py, (x,), None) - .unwrap() - .extract::<&PyArray1>(py) - .unwrap() - .to_owned_array() - }))( - &args - .get_item(0)? - .downcast::>()? - .to_owned_array(), - ); + let out = (ndarr::forward_jacobian( + &|x: &Array1| -> Result, Error> { + let x = PyArray1::from_array(py, x); + Ok(f.call(py, (x,), None)? + .extract::<&PyArray1>(py)? + .to_owned_array()) + }, + ))(&process_args(args)?)?; Ok(out.into_pyarray(py).into()) }) }, ) } else { - Err(PyErr::new::(format!( - "object {} not callable", - f.as_ref(py).get_type() - ))) + not_callable!(py, f) } } @@ -114,28 +107,20 @@ fn central_jacobian<'py>(py: Python<'py>, f: Py) -> PyResult<&'py PyCFunc None, move |args: &PyTuple, _kwargs: Option<&PyDict>| -> PyResult>> { Python::with_gil(|py| { - let out = (ndarr::central_jacobian(|x: &Array1| -> Array1 { - let x = PyArray1::from_array(py, x); - f.call(py, (x,), None) - .unwrap() - .extract::<&PyArray1>(py) - .unwrap() - .to_owned_array() - }))( - &args - .get_item(0)? - .downcast::>()? - .to_owned_array(), - ); + let out = (ndarr::central_jacobian( + &|x: &Array1| -> Result, Error> { + let x = PyArray1::from_array(py, x); + Ok(f.call(py, (x,), None)? + .extract::<&PyArray1>(py)? + .to_owned_array()) + }, + ))(&process_args(args)?)?; Ok(out.into_pyarray(py).into()) }) }, ) } else { - Err(PyErr::new::(format!( - "object {} not callable", - f.as_ref(py).get_type() - ))) + not_callable!(py, f) } } diff --git a/python/finitediff-py/test.py b/python/finitediff-py/test.py index e9d0f2898..2197efc18 100644 --- a/python/finitediff-py/test.py +++ b/python/finitediff-py/test.py @@ -56,6 +56,7 @@ def op(x): x = np.array([1.0, 1.0, 1.0, 1.0, 1.0, 1.0]) print(j(x)) + # class NotCallable: # pass