Skip to content

Commit

Permalink
Error handling, other improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
stefan-k committed Mar 8, 2024
1 parent fb810f6 commit 98557bc
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 61 deletions.
3 changes: 2 additions & 1 deletion python/finitediff-py/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,5 @@ crate-type = ["cdylib"]
[dependencies]
finitediff_rust = { package = "finitediff", version = "0.1.4", path = "../../crates/finitediff", features = ["ndarray"] }
numpy = "0.20.0"
pyo3 = { version = "0.20.0", features = ["extension-module"] }
pyo3 = { version = "0.20.0", features = ["extension-module", "anyhow"] }
anyhow = "1.0"
105 changes: 45 additions & 60 deletions python/finitediff-py/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,29 @@
use anyhow::Error;
use finitediff_rust::ndarr;
use numpy::ndarray::Array1;
use numpy::{IntoPyArray, PyArray1, PyArray2};
use numpy::{ndarray::Array1, IntoPyArray, PyArray1, PyArray2};
use pyo3::{
exceptions::PyTypeError,
prelude::*,
types::{PyCFunction, PyDict, PyTuple},
};

fn process_args(args: &PyTuple) -> PyResult<Array1<f64>> {
Ok(args
.get_item(0)
.map_err(|_| PyErr::new::<PyTypeError, _>("Insufficient number of arguments"))?
.downcast::<PyArray1<f64>>()?
.to_owned_array())
}

macro_rules! not_callable {
($py:ident, $f:ident) => {
Err(PyErr::new::<PyTypeError, _>(format!(
"object {} not callable",
$f.as_ref($py).get_type()
)))
};
}

/// Forward diff
#[pyfunction]
fn forward_diff<'py>(py: Python<'py>, f: Py<PyAny>) -> PyResult<&'py PyCFunction> {
Expand All @@ -17,24 +34,16 @@ fn forward_diff<'py>(py: Python<'py>, f: Py<PyAny>) -> PyResult<&'py PyCFunction
None,
move |args: &PyTuple, _kwargs: Option<&PyDict>| -> PyResult<Py<PyArray1<f64>>> {
Python::with_gil(|py| {
let out = (ndarr::forward_diff(|x: &Array1<f64>| -> f64 {
let out = (ndarr::forward_diff(&|x: &Array1<f64>| -> Result<f64, Error> {
let x = PyArray1::from_array(py, x);
f.call(py, (x,), None).unwrap().extract(py).unwrap()
}))(
&args
.get_item(0)?
.downcast::<PyArray1<f64>>()?
.to_owned_array(),
);
Ok(f.call(py, (x,), None)?.extract(py)?)
}))(&process_args(args)?)?;
Ok(out.into_pyarray(py).into())
})
},
)
} else {
Err(PyErr::new::<PyTypeError, _>(format!(
"object {} not callable",
f.as_ref(py).get_type()
)))
not_callable!(py, f)
}
}

Expand All @@ -48,24 +57,16 @@ fn central_diff<'py>(py: Python<'py>, f: Py<PyAny>) -> PyResult<&'py PyCFunction
None,
move |args: &PyTuple, _kwargs: Option<&PyDict>| -> PyResult<Py<PyArray1<f64>>> {
Python::with_gil(|py| {
let out = (ndarr::central_diff(|x: &Array1<f64>| -> f64 {
let out = (ndarr::central_diff(&|x: &Array1<f64>| -> Result<f64, Error> {
let x = PyArray1::from_array(py, x);
f.call(py, (x,), None).unwrap().extract(py).unwrap()
}))(
&args
.get_item(0)?
.downcast::<PyArray1<f64>>()?
.to_owned_array(),
);
Ok(f.call(py, (x,), None)?.extract(py)?)
}))(&process_args(args)?)?;
Ok(out.into_pyarray(py).into())
})
},
)
} else {
Err(PyErr::new::<PyTypeError, _>(format!(
"object {} not callable",
f.as_ref(py).get_type()
)))
not_callable!(py, f)
}
}

Expand All @@ -79,28 +80,20 @@ fn forward_jacobian<'py>(py: Python<'py>, f: Py<PyAny>) -> PyResult<&'py PyCFunc
None,
move |args: &PyTuple, _kwargs: Option<&PyDict>| -> PyResult<Py<PyArray2<f64>>> {
Python::with_gil(|py| {
let out = (ndarr::forward_jacobian(|x: &Array1<f64>| -> Array1<f64> {
let x = PyArray1::from_array(py, x);
f.call(py, (x,), None)
.unwrap()
.extract::<&PyArray1<f64>>(py)
.unwrap()
.to_owned_array()
}))(
&args
.get_item(0)?
.downcast::<PyArray1<f64>>()?
.to_owned_array(),
);
let out = (ndarr::forward_jacobian(
&|x: &Array1<f64>| -> Result<Array1<f64>, Error> {
let x = PyArray1::from_array(py, x);
Ok(f.call(py, (x,), None)?
.extract::<&PyArray1<f64>>(py)?
.to_owned_array())
},
))(&process_args(args)?)?;
Ok(out.into_pyarray(py).into())
})
},
)
} else {
Err(PyErr::new::<PyTypeError, _>(format!(
"object {} not callable",
f.as_ref(py).get_type()
)))
not_callable!(py, f)
}
}

Expand All @@ -114,28 +107,20 @@ fn central_jacobian<'py>(py: Python<'py>, f: Py<PyAny>) -> PyResult<&'py PyCFunc
None,
move |args: &PyTuple, _kwargs: Option<&PyDict>| -> PyResult<Py<PyArray2<f64>>> {
Python::with_gil(|py| {
let out = (ndarr::central_jacobian(|x: &Array1<f64>| -> Array1<f64> {
let x = PyArray1::from_array(py, x);
f.call(py, (x,), None)
.unwrap()
.extract::<&PyArray1<f64>>(py)
.unwrap()
.to_owned_array()
}))(
&args
.get_item(0)?
.downcast::<PyArray1<f64>>()?
.to_owned_array(),
);
let out = (ndarr::central_jacobian(
&|x: &Array1<f64>| -> Result<Array1<f64>, Error> {
let x = PyArray1::from_array(py, x);
Ok(f.call(py, (x,), None)?
.extract::<&PyArray1<f64>>(py)?
.to_owned_array())
},
))(&process_args(args)?)?;
Ok(out.into_pyarray(py).into())
})
},
)
} else {
Err(PyErr::new::<PyTypeError, _>(format!(
"object {} not callable",
f.as_ref(py).get_type()
)))
not_callable!(py, f)
}
}

Expand Down
1 change: 1 addition & 0 deletions python/finitediff-py/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def op(x):
x = np.array([1.0, 1.0, 1.0, 1.0, 1.0, 1.0])
print(j(x))


# class NotCallable:
# pass

Expand Down

0 comments on commit 98557bc

Please sign in to comment.