Skip to content

Commit

Permalink
Implement filter for RunEnd array (#1342)
Browse files Browse the repository at this point in the history
  • Loading branch information
robert3005 authored Nov 17, 2024
1 parent ff6c440 commit eee490e
Showing 1 changed file with 89 additions and 4 deletions.
93 changes: 89 additions & 4 deletions encodings/runend/src/compute.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
use vortex_array::array::{ConstantArray, PrimitiveArray, SparseArray};
use std::ops::AddAssign;

use num_traits::AsPrimitive;
use vortex_array::array::{BoolArray, BooleanBuffer, ConstantArray, PrimitiveArray, SparseArray};
use vortex_array::compute::unary::{scalar_at, scalar_at_unchecked, ScalarAtFn};
use vortex_array::compute::{
compare, filter, slice, take, ArrayCompute, MaybeCompareFn, Operator, SliceFn, TakeFn,
compare, filter, slice, take, ArrayCompute, FilterFn, MaybeCompareFn, Operator, SliceFn, TakeFn,
};
use vortex_array::stats::{ArrayStatistics, Stat};
use vortex_array::validity::Validity;
use vortex_array::variants::PrimitiveArrayTrait;
use vortex_array::{ArrayDType, ArrayData, IntoArrayData, IntoArrayVariant};
use vortex_dtype::match_each_integer_ptype;
use vortex_dtype::{match_each_integer_ptype, match_each_unsigned_integer_ptype, NativePType};
use vortex_error::{VortexExpect as _, VortexResult};
use vortex_scalar::{Scalar, ScalarValue};

Expand All @@ -18,6 +21,10 @@ impl ArrayCompute for RunEndArray {
MaybeCompareFn::maybe_compare(self, other, operator)
}

/// Returns this array as its own [`FilterFn`] implementation; always `Some`.
fn filter(&self) -> Option<&dyn FilterFn> {
Some(self)
}

/// Returns this array as its own [`ScalarAtFn`] implementation; always `Some`.
fn scalar_at(&self) -> Option<&dyn ScalarAtFn> {
Some(self)
}
Expand Down Expand Up @@ -142,11 +149,58 @@ impl SliceFn for RunEndArray {
}
}

impl FilterFn for RunEndArray {
    /// Filters this run-end encoded array with `predicate`, producing another run-end
    /// encoded array.
    ///
    /// The run ends are recomputed by `filter_run_ends`, which also reports which runs
    /// still contain at least one selected element. That per-run mask is used to filter
    /// the run values, while the validity is filtered with the original `predicate`.
    ///
    /// # Errors
    ///
    /// Propagates any failure from canonicalising the run ends, recomputing them,
    /// filtering the values or validity, or constructing the resulting array.
    fn filter(&self, predicate: &ArrayData) -> VortexResult<ArrayData> {
        let ends = self.ends().into_primitive()?;
        let (filtered_ends, kept_runs) = match_each_unsigned_integer_ptype!(ends.ptype(), |$P| {
            filter_run_ends(ends.maybe_null_slice::<$P>(), predicate)?
        });
        let filtered_values = filter(self.values(), &kept_runs)?;
        let filtered_validity = self.validity().filter(predicate)?;

        let filtered = RunEndArray::try_new(
            filtered_ends.into_array(),
            filtered_values,
            filtered_validity,
        )?;
        Ok(filtered.into_array())
    }
}

// Code adapted from apache arrow-rs https://github.com/apache/arrow-rs/blob/b1f5c250ebb6c1252b4e7c51d15b8e77f4c361fa/arrow-select/src/filter.rs#L425
fn filter_run_ends<R: NativePType + AddAssign + From<bool> + AsPrimitive<u64>>(
run_ends: &[R],
predicate: &ArrayData,
) -> VortexResult<(PrimitiveArray, BoolArray)> {
let mut new_run_ends = vec![R::zero(); run_ends.len()];

let mut start = 0u64;
let mut j = 0;
let mut count = R::zero();
let filter_values = predicate.clone().into_bool()?.boolean_buffer();

let pred: BoolArray = BooleanBuffer::collect_bool(run_ends.len(), |i| {
let mut keep = false;
let end = run_ends[i].as_();

// Safety: predicate must be the same length as the array the ends have been taken from
for pred in (start..end).map(|i| unsafe { filter_values.value_unchecked(i as usize) }) {
count += <R as From<bool>>::from(pred);
keep |= pred
}
// this is to avoid branching
new_run_ends[j] = count;
j += keep as usize;

start = end;
keep
})
.into();

new_run_ends.truncate(j);
Ok((PrimitiveArray::from(new_run_ends), pred))
}

#[cfg(test)]
mod test {
use vortex_array::array::{BoolArray, PrimitiveArray};
use vortex_array::compute::unary::{scalar_at, try_cast};
use vortex_array::compute::{slice, take};
use vortex_array::compute::{filter, slice, take};
use vortex_array::validity::{ArrayValidity, Validity};
use vortex_array::{ArrayDType, IntoArrayData, IntoArrayVariant, ToArrayData};
use vortex_dtype::{DType, Nullability, PType};
Expand Down Expand Up @@ -380,4 +434,35 @@ mod test {
assert_eq!(scalar_at(taken.as_ref(), 1).unwrap(), 2.into());
assert_eq!(scalar_at(taken.as_ref(), 2).unwrap(), 5.into());
}

/// Filtering a run-end encoded array keeps only the runs that still have selected
/// elements and rewrites their ends as cumulative counts of the selected elements.
#[test]
fn filter_run_end() {
    let encoded = ree_array();
    // Select the first two and the last two of the twelve elements.
    let mask = BoolArray::from_iter([
        true, true, false, false, false, false, false, false, false, false, true, true,
    ])
    .into_array();

    let filtered = filter(encoded, mask).unwrap();
    let filtered_run_end = RunEndArray::try_from(filtered).unwrap();

    let ends = filtered_run_end.ends().into_primitive().unwrap();
    assert_eq!(ends.maybe_null_slice::<u64>(), [2, 4]);

    let values = filtered_run_end.values().into_primitive().unwrap();
    assert_eq!(values.maybe_null_slice::<i32>(), [1, 5]);
}
}

0 comments on commit eee490e

Please sign in to comment.