Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/412 add faer rs backend #549

Open
wants to merge 37 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
2f62998
start with impl: add
geo-ant Dec 30, 2024
80eccfa
more and more things implemented
geo-ant Dec 30, 2024
0d95253
coming to grips with component wise operations
geo-ant Dec 30, 2024
2a56af4
I think I understand zipped now... kinda
geo-ant Dec 30, 2024
0a7181e
iterate to better solutions, expand scalar stuff
geo-ant Dec 30, 2024
b580fd2
fix division code
geo-ant Dec 30, 2024
f70510a
fix dot product
geo-ant Dec 30, 2024
38fc00c
more fixes, I think
geo-ant Dec 30, 2024
e8fb561
add argmin random impl
geo-ant Dec 30, 2024
56bf173
more modules
geo-ant Dec 31, 2024
692e722
I think I've implemented everything that nalgebra has implemented. No…
geo-ant Dec 31, 2024
8893c5b
add tests for zero
geo-ant Dec 31, 2024
0f19649
translate tests for transposition from nalgebra
geo-ant Dec 31, 2024
52e06a6
fix some clippy lints
geo-ant Dec 31, 2024
3896fbc
fix more clippy lints
geo-ant Dec 31, 2024
5384c09
add tests for subtraction
geo-ant Dec 31, 2024
d14faf7
more tests
geo-ant Dec 31, 2024
40b7fa1
add tests for random
geo-ant Dec 31, 2024
7f33575
add forgotten multiplication impl
geo-ant Dec 31, 2024
c086960
fix errors in matrix multiplication
geo-ant Dec 31, 2024
44a2eca
add tests for minmax
geo-ant Dec 31, 2024
f095d5c
add tests for l2 norm
geo-ant Dec 31, 2024
42250e7
add tests for l1 norm
geo-ant Dec 31, 2024
a3b33ea
start with tests for inverse (partially failing)
geo-ant Jan 1, 2025
e06be9f
different matrix inversion logic and tests
geo-ant Jan 1, 2025
43baf01
tests for identity matrix
geo-ant Jan 1, 2025
b032a77
start with dot prod
geo-ant Jan 1, 2025
e92b0c5
work on scalar product
geo-ant Jan 2, 2025
6dcedee
further work on dot product, failing tests
geo-ant Jan 2, 2025
2d6d276
finish argmin-dot
geo-ant Jan 3, 2025
0607688
add tests for pointwise div
geo-ant Jan 3, 2025
271ef94
fix conj and add tests
geo-ant Jan 3, 2025
8f45102
docs, all tests passing, almost finished
geo-ant Jan 3, 2025
d8ffa1f
docs
geo-ant Jan 3, 2025
c7c27ef
add forgotten implementation
geo-ant Jan 3, 2025
47e182a
conditionally enable faer
geo-ant Jan 3, 2025
146db74
add faer tests to ci
geo-ant Jan 7, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,11 @@ jobs:
run: cargo test -p argmin-math --no-default-features --features "nalgebra_v0_30"
- name: argmin-math (nalgebra_v0_29)
run: cargo test -p argmin-math --no-default-features --features "nalgebra_v0_29"
# faer
- name: argmin-math (faer_latest)
run: cargo test -p argmin-math --no-default-features --features "faer_latest"
- name: argmin-math (faer_v0_20)
run: cargo test -p argmin-math --no-default-features --features "faer_v0_20"

clippy:
runs-on: ubuntu-latest
Expand Down
8 changes: 8 additions & 0 deletions crates/argmin-math/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ ndarray_0_14 = { package = "ndarray", version = "0.14", optional = true }
## v0.13
ndarray_0_13 = { package = "ndarray", version = "0.13", optional = true }

#faer
faer_0_20 = { package = "faer", version = "0.20", optional = true}

# general
num-complex_0_4 = { package = "num-complex", version = "0.4", optional = true, default-features = false, features = ["std"] }
num-complex_0_3 = { package = "num-complex", version = "0.3", optional = true, default-features = false, features = ["std"] }
Expand Down Expand Up @@ -69,6 +72,11 @@ nalgebra_v0_29 = ["nalgebra_0_29", "num-complex_0_4", "nalgebra_all"]
ndarray_all = ["primitives"]
ndarray_latest = ["ndarray_v0_15"]

#faer
faer_all = ["primitives"]
faer_latest = ["faer_v0_20"]
faer_v0_20 = ["faer_0_20", "num-complex_0_4", "faer_all"]

## With `ndarray-linalg`
ndarray_v0_15 = ["ndarray_0_15", "ndarray-linalg_0_16", "num-complex_0_4", "ndarray_all"]

Expand Down
2 changes: 1 addition & 1 deletion crates/argmin-math/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@


This create provides a abstractions for mathematical operations needed in [argmin](https://argmin-rs.org).
The supported math backends so far are basic `Vec`s, `ndarray` and `nalgebra`.
The supported math backends so far are basic `Vec`s, `ndarray`, `nalgebra`, and `faer`.
Please consult the documentation for details.


Expand Down
278 changes: 278 additions & 0 deletions crates/argmin-math/src/faer_m/add.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,278 @@
use crate::ArgminAdd;
use faer::{
mat::{AsMatMut, AsMatRef},
reborrow::{IntoConst, Reborrow, ReborrowMut},
unzipped, zipped, zipped_rw, ComplexField, Conjugate, Entity, Mat, MatMut, MatRef,
SimpleEntity,
};
use std::ops::{Add, AddAssign};

/// MatRef + Scalar -> Mat
impl<E, R, C> ArgminAdd<E, Mat<E, R, C>> for MatRef<'_, E, R, C>
where
E: Entity + Add<E, Output = E>,
R: faer::Shape,
C: faer::Shape,
{
#[inline]
fn add(&self, other: &E) -> Mat<E, R, C> {
zipped_rw!(self).map(|unzipped!(this)| this.read() + *other)
}
}

/// Scaler + MatRef-> Mat
impl<'a, E, R, C> ArgminAdd<MatRef<'a, E, R, C>, Mat<E, R, C>> for E
where
E: Entity + Add<E, Output = E>,
R: faer::Shape,
C: faer::Shape,
{
#[inline]
fn add(&self, other: &MatRef<'a, E, R, C>) -> Mat<E, R, C> {
// commutative with MatRef + Scalar so we can fall back on that case
<_ as ArgminAdd<_, _>>::add(other, self)
}
}

/// Mat + Scalar -> Mat
impl<E, R, C> ArgminAdd<E, Mat<E, R, C>> for Mat<E, R, C>
where
E: Entity + Add<E, Output = E>,
R: faer::Shape,
C: faer::Shape,
{
#[inline]
fn add(&self, other: &E) -> Mat<E, R, C> {
//@note(geo-ant) because we are taking self by reference we
// cannot mutate the matrix in place, so we can just as well
// reuse the reference code
<_ as ArgminAdd<_, _>>::add(&self.as_mat_ref(), other)
}
}

/// Scalar + Mat -> Mat
impl<E, R, C> ArgminAdd<Mat<E, R, C>, Mat<E, R, C>> for E
where
E: Entity + Add<E, Output = E>,
R: faer::Shape,
C: faer::Shape,
{
#[inline]
fn add(&self, other: &Mat<E, R, C>) -> Mat<E, R, C> {
// commutative with Mat + Scalar so we can fall back on that case
<_ as ArgminAdd<_, _>>::add(other, self)
}
}

/// MatRef + MatRef -> Mat
impl<'a, E> ArgminAdd<MatRef<'a, E>, Mat<E>> for MatRef<'_, E>
where
E: Entity + ComplexField,
{
#[inline]
fn add(&self, other: &MatRef<'a, E>) -> Mat<E> {
self + other
}
}

/// MatRef + Mat -> Mat
impl<E: Entity + ComplexField> ArgminAdd<Mat<E>, Mat<E>> for MatRef<'_, E> {
#[inline]
fn add(&self, other: &Mat<E>) -> Mat<E> {
self + other
}
}

/// Mat + MatRef -> Mat
impl<E: Entity + ComplexField> ArgminAdd<MatRef<'_, E>, Mat<E>> for Mat<E> {
#[inline]
fn add(&self, other: &MatRef<'_, E>) -> Mat<E> {
self + other
}
}

/// Mat + Mat -> Mat
impl<E: Entity + ComplexField> ArgminAdd<Mat<E>, Mat<E>> for Mat<E> {
#[inline]
fn add(&self, other: &Mat<E>) -> Mat<E> {
self + other
}
}

#[cfg(test)]
mod tests {
use super::super::test_helper::*;
use super::*;
use approx::assert_relative_eq;
use faer::mat::AsMatRef;
use paste::item;

macro_rules! make_test {
($t:ty) => {
item! {
#[test]
fn [<test_add_vec_scalar_ $t>]() {
let a = vector3_new(1 as $t, 4 as $t, 8 as $t);
let b = 34 as $t;
let target = vector3_new(35 as $t, 38 as $t, 42 as $t);
let res1 = <_ as ArgminAdd<$t, _>>::add(&a, &b);
let res2 = <_ as ArgminAdd<$t, _>>::add(&a.as_mat_ref(), &b);
assert_eq!(res1, res2);
assert_eq!(res1.nrows(), 3);
assert_eq!(res1.ncols(), 1);
for i in 0..3 {
assert_relative_eq!(target[(i,0)] as f64, res1[(i,0)] as f64, epsilon = f64::EPSILON);
}
}
}

item! {
#[test]
fn [<test_add_scalar_vec_ $t>]() {
let a = vector3_new(1 as $t, 4 as $t, 8 as $t);
let b = 34 as $t;
let target = vector3_new(35 as $t, 38 as $t, 42 as $t);
let res1 = <_ as ArgminAdd<_, _>>::add(&b, &a);
let res2 = <_ as ArgminAdd<_, _>>::add(&b, &a.as_mat_ref());
assert_eq!(res1, res2);
assert_eq!(res1.nrows(), 3);
assert_eq!(res1.ncols(), 1);
for i in 0..3 {
assert_relative_eq!(target[(i,0)] as f64, res1[(i,0)] as f64, epsilon = f64::EPSILON);
}
}
}

item! {
#[test]
fn [<test_add_vec_vec_ $t>]() {
let a = vector3_new(1 as $t, 4 as $t, 8 as $t);
let b = vector3_new(41 as $t, 38 as $t, 34 as $t);
let target = vector3_new(42 as $t, 42 as $t, 42 as $t);
let res = <_ as ArgminAdd<_, _>>::add(&a, &b);
for i in 0..3 {
assert_relative_eq!(target[(i,0)] as f64, res[(i,0)] as f64, epsilon = f64::EPSILON);
}
}
}

item! {
#[test]
#[should_panic]
fn [<test_add_vec_vec_panic_ $t>]() {
let a = column_vector_from_vec(vec![1 as $t, 4 as $t]);
let b = column_vector_from_vec(vec![41 as $t, 38 as $t, 34 as $t]);
<_ as ArgminAdd<_,_>>::add(&a, &b);
}
}

item! {
#[test]
#[should_panic]
fn [<test_add_vec_vec_panic_2_ $t>]() {
let a = column_vector_from_vec(vec![]);
let b = column_vector_from_vec(vec![41 as $t, 38 as $t, 34 as $t]);
<_ as ArgminAdd<_, _>>::add(&a, &b);
}
}

item! {
#[test]
#[should_panic]
fn [<test_add_vec_vec_panic_3_ $t>]() {
let a = column_vector_from_vec(vec![41 as $t, 38 as $t, 34 as $t]);
let b = column_vector_from_vec(vec![]);
<_ as ArgminAdd<_, _>>::add(&a, &b);
}
}

item! {
#[test]
fn [<test_add_mat_mat_ $t>]() {
let a = matrix2x3_new(
1 as $t, 4 as $t, 8 as $t,
2 as $t, 5 as $t, 9 as $t
);
let b = matrix2x3_new(
41 as $t, 38 as $t, 34 as $t,
40 as $t, 37 as $t, 33 as $t
);
let target = matrix2x3_new(
42 as $t, 42 as $t, 42 as $t,
42 as $t, 42 as $t, 42 as $t
);
let res1 = <_ as ArgminAdd<_, _>>::add(&a, &b);
let res2 = <_ as ArgminAdd<_, _>>::add(&a.as_mat_ref(), &b);
let res3 = <_ as ArgminAdd<_, _>>::add(&a, &b.as_mat_ref());
let res4 = <_ as ArgminAdd<_, _>>::add(&a.as_mat_ref(), &b.as_mat_ref());
assert_eq!(res1, res2);
assert_eq!(res1, res3);
assert_eq!(res1, res4);
assert_eq!(res1.nrows(), 2);
assert_eq!(res1.ncols(), 3);
for i in 0..3 {
for j in 0..2 {
assert_relative_eq!(target[(j, i)] as f64, res1[(j, i)] as f64, epsilon = f64::EPSILON);
}
}
}
}

item! {
#[test]
fn [<test_add_mat_scalar_ $t>]() {
let a = matrix2x3_new(
1 as $t, 4 as $t, 8 as $t,
2 as $t, 5 as $t, 9 as $t
);
let b = 2 as $t;
let target = matrix2x3_new(
3 as $t, 6 as $t, 10 as $t,
4 as $t, 7 as $t, 11 as $t
);
let res1 = <_ as ArgminAdd<$t, _>>::add(&a, &b);
let res2 = <_ as ArgminAdd<$t, _>>::add(&a.as_mat_ref(), &b);
assert_eq!(res1, res2);
assert_eq!(res1.nrows(), 2);
assert_eq!(res1.ncols(), 3);
for i in 0..3 {
for j in 0..2 {
assert_relative_eq!(target[(j, i)] as f64, res1[(j, i)] as f64, epsilon = f64::EPSILON);
}
}
}
}

item! {
#[test]
#[should_panic]
fn [<test_add_mat_mat_panic_2_ $t>]() {
let a = faer::mat![
[1 as $t, 4 as $t, 8 as $t],
[2 as $t, 5 as $t, 9 as $t]
];
let b = faer::mat![
[41 as $t, 38 as $t]
];
<_ as ArgminAdd<_, _>>::add(&a, &b);
}
}

item! {
#[test]
#[should_panic]
fn [<test_add_mat_mat_panic_3_ $t>]() {
let a = faer::mat![
[1 as $t, 4 as $t, 8 as $t],
[2 as $t, 5 as $t, 9 as $t]
];
let b = faer::Mat::new();
<_ as ArgminAdd<_, _>>::add(&a, &b);
}
}
};
}

make_test!(f32);
make_test!(f64);
}
68 changes: 68 additions & 0 deletions crates/argmin-math/src/faer_m/conj.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
use crate::ArgminConj;
use faer::{
mat::AsMatRef, reborrow::ReborrowMut, unzipped, Conjugate, Entity, Mat, MatMut, MatRef,
SimpleEntity,
};
use num_complex::ComplexFloat;

impl<E: Entity + num_complex::ComplexFloat> ArgminConj for Mat<E> {
#[inline]
fn conj(&self) -> Self {
//@note(geo-ant): we can't directly use the `conjugate()' function
// on the MatRef struct since it's not guaranteed to return matrix same type.
// Thus, we implement the conjugation using the num-complex trait manually
faer::zipped_rw!(self).map(|unzipped!(this)| ComplexFloat::conj(this.read()))
}
}

#[cfg(test)]
mod tests {
use super::super::test_helper::*;
use super::*;
use approx::assert_relative_eq;
use faer::linalg::entity::complex_split::ComplexConj;
use num_complex::Complex;
use paste::item;

macro_rules! make_test {
($t:ty) => {
item! {
#[test]
fn [<test_conj_complex_faer_ $t>]() {
let a : Mat<Complex<$t>> = vector3_new(
Complex::new(1 as $t, 2 as $t),
Complex::new(4 as $t, -3 as $t),
Complex::new(8 as $t, 0 as $t)
);
let b = vector3_new(
Complex::new(1 as $t, -2 as $t),
Complex::new(4 as $t, 3 as $t),
Complex::new(8 as $t, 0 as $t)
);
let res: Mat<_> = <Mat<Complex<$t>> as ArgminConj>::conj(&a);
assert_eq!(res.nrows(),3);
assert_eq!(res.ncols(),1);
for i in 0..3 {
assert_relative_eq!(b.read(i,0).re(), res.read(i,0).re(), epsilon = $t::EPSILON);
assert_relative_eq!(b.read(i,0).im(), res.read(i,0).im(), epsilon = $t::EPSILON);
}
}
}

item! {
#[test]
fn [<test_conj_faer_ $t>]() {
let a = vector3_new(1 as $t, 4 as $t, 8 as $t);
let b = vector3_new(1 as $t, 4 as $t, 8 as $t);
let res = <_ as ArgminConj>::conj(&a);
for i in 0..3 {
assert_relative_eq!(b[(i,0)], res[(i,0)], epsilon = $t::EPSILON);
}
}
}
};
}

make_test!(f32);
make_test!(f64);
}
Loading
Loading