Skip to content

Commit

Permalink
Jacobian stuff
Browse files Browse the repository at this point in the history
  • Loading branch information
stefan-k committed Mar 8, 2024
1 parent da8f173 commit fb810f6
Show file tree
Hide file tree
Showing 2 changed files with 137 additions and 8 deletions.
106 changes: 105 additions & 1 deletion python/finitediff-py/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use finitediff_rust::ndarr;
use numpy::ndarray::Array1;
use numpy::{IntoPyArray, PyArray1};
use numpy::{IntoPyArray, PyArray1, PyArray2};
use pyo3::{
exceptions::PyTypeError,
prelude::*,
Expand Down Expand Up @@ -38,9 +38,113 @@ fn forward_diff<'py>(py: Python<'py>, f: Py<PyAny>) -> PyResult<&'py PyCFunction
}
}

/// Central diff
#[pyfunction]
fn central_diff<'py>(py: Python<'py>, f: Py<PyAny>) -> PyResult<&'py PyCFunction> {
if f.as_ref(py).is_callable() {
PyCFunction::new_closure(
py,
None,
None,
move |args: &PyTuple, _kwargs: Option<&PyDict>| -> PyResult<Py<PyArray1<f64>>> {
Python::with_gil(|py| {
let out = (ndarr::central_diff(|x: &Array1<f64>| -> f64 {
let x = PyArray1::from_array(py, x);
f.call(py, (x,), None).unwrap().extract(py).unwrap()
}))(
&args
.get_item(0)?
.downcast::<PyArray1<f64>>()?
.to_owned_array(),
);
Ok(out.into_pyarray(py).into())
})
},
)
} else {
Err(PyErr::new::<PyTypeError, _>(format!(
"object {} not callable",
f.as_ref(py).get_type()
)))
}
}

/// Forward Jacobian
#[pyfunction]
fn forward_jacobian<'py>(py: Python<'py>, f: Py<PyAny>) -> PyResult<&'py PyCFunction> {
if f.as_ref(py).is_callable() {
PyCFunction::new_closure(
py,
None,
None,
move |args: &PyTuple, _kwargs: Option<&PyDict>| -> PyResult<Py<PyArray2<f64>>> {
Python::with_gil(|py| {
let out = (ndarr::forward_jacobian(|x: &Array1<f64>| -> Array1<f64> {
let x = PyArray1::from_array(py, x);
f.call(py, (x,), None)
.unwrap()
.extract::<&PyArray1<f64>>(py)
.unwrap()
.to_owned_array()
}))(
&args
.get_item(0)?
.downcast::<PyArray1<f64>>()?
.to_owned_array(),
);
Ok(out.into_pyarray(py).into())
})
},
)
} else {
Err(PyErr::new::<PyTypeError, _>(format!(
"object {} not callable",
f.as_ref(py).get_type()
)))
}
}

/// Central Jacobian
#[pyfunction]
fn central_jacobian<'py>(py: Python<'py>, f: Py<PyAny>) -> PyResult<&'py PyCFunction> {
if f.as_ref(py).is_callable() {
PyCFunction::new_closure(
py,
None,
None,
move |args: &PyTuple, _kwargs: Option<&PyDict>| -> PyResult<Py<PyArray2<f64>>> {
Python::with_gil(|py| {
let out = (ndarr::central_jacobian(|x: &Array1<f64>| -> Array1<f64> {
let x = PyArray1::from_array(py, x);
f.call(py, (x,), None)
.unwrap()
.extract::<&PyArray1<f64>>(py)
.unwrap()
.to_owned_array()
}))(
&args
.get_item(0)?
.downcast::<PyArray1<f64>>()?
.to_owned_array(),
);
Ok(out.into_pyarray(py).into())
})
},
)
} else {
Err(PyErr::new::<PyTypeError, _>(format!(
"object {} not callable",
f.as_ref(py).get_type()
)))
}
}

/// A Python module implemented in Rust.
#[pymodule]
fn finitediff(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(forward_diff, m)?)?;
m.add_function(wrap_pyfunction!(central_diff, m)?)?;
m.add_function(wrap_pyfunction!(forward_jacobian, m)?)?;
m.add_function(wrap_pyfunction!(central_jacobian, m)?)?;
Ok(())
}
39 changes: 32 additions & 7 deletions python/finitediff-py/test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from finitediff import forward_diff
from finitediff import forward_diff, central_diff, forward_jacobian, central_jacobian
import numpy as np


Expand Down Expand Up @@ -30,13 +30,38 @@ def blaah(self, x):
x = np.array([1.0, 2.0])
print(g(x))

g = central_diff(f)
x = np.array([1.0, 2.0])
print(g(x))

class NotCallable:
pass

def op(x):
return np.array(
[
2.0 * (x[1] ** 3 - x[0] ** 2),
3.0 * (x[1] ** 3 - x[0] ** 2) + 2.0 * (x[2] ** 3 - x[1] ** 2),
3.0 * (x[2] ** 3 - x[1] ** 2) + 2.0 * (x[3] ** 3 - x[2] ** 2),
3.0 * (x[3] ** 3 - x[2] ** 2) + 2.0 * (x[4] ** 3 - x[3] ** 2),
3.0 * (x[4] ** 3 - x[3] ** 2) + 2.0 * (x[5] ** 3 - x[4] ** 2),
3.0 * (x[5] ** 3 - x[4] ** 2),
]
)

notcallable = NotCallable()

g = forward_diff(notcallable)
x = np.array([1.0, 2.0])
print(g(x))
j = forward_jacobian(op)
x = np.array([1.0, 1.0, 1.0, 1.0, 1.0, 1.0])
print(j(x))

j = central_jacobian(op)
x = np.array([1.0, 1.0, 1.0, 1.0, 1.0, 1.0])
print(j(x))

# class NotCallable:
# pass


# notcallable = NotCallable()

# g = forward_diff(notcallable)
# x = np.array([1.0, 2.0])
# print(g(x))

0 comments on commit fb810f6

Please sign in to comment.