diff --git a/encodings/runend/src/compute.rs b/encodings/runend/src/compute.rs index 89251e8011..b7863294eb 100644 --- a/encodings/runend/src/compute.rs +++ b/encodings/runend/src/compute.rs @@ -1,13 +1,16 @@ -use vortex_array::array::{ConstantArray, PrimitiveArray, SparseArray}; +use std::ops::AddAssign; + +use num_traits::AsPrimitive; +use vortex_array::array::{BoolArray, BooleanBuffer, ConstantArray, PrimitiveArray, SparseArray}; use vortex_array::compute::unary::{scalar_at, scalar_at_unchecked, ScalarAtFn}; use vortex_array::compute::{ - compare, filter, slice, take, ArrayCompute, MaybeCompareFn, Operator, SliceFn, TakeFn, + compare, filter, slice, take, ArrayCompute, FilterFn, MaybeCompareFn, Operator, SliceFn, TakeFn, }; use vortex_array::stats::{ArrayStatistics, Stat}; use vortex_array::validity::Validity; use vortex_array::variants::PrimitiveArrayTrait; use vortex_array::{ArrayDType, ArrayData, IntoArrayData, IntoArrayVariant}; -use vortex_dtype::match_each_integer_ptype; +use vortex_dtype::{match_each_integer_ptype, match_each_unsigned_integer_ptype, NativePType}; use vortex_error::{VortexExpect as _, VortexResult}; use vortex_scalar::{Scalar, ScalarValue}; @@ -18,6 +21,10 @@ impl ArrayCompute for RunEndArray { MaybeCompareFn::maybe_compare(self, other, operator) } + fn filter(&self) -> Option<&dyn FilterFn> { + Some(self) + } + fn scalar_at(&self) -> Option<&dyn ScalarAtFn> { Some(self) } @@ -142,11 +149,58 @@ impl SliceFn for RunEndArray { } } +impl FilterFn for RunEndArray { + fn filter(&self, predicate: &ArrayData) -> VortexResult { + let primitive_run_ends = self.ends().into_primitive()?; + let (run_ends, pred) = match_each_unsigned_integer_ptype!(primitive_run_ends.ptype(), |$P| { + filter_run_ends(primitive_run_ends.maybe_null_slice::<$P>(), predicate)? + }); + let values = filter(self.values(), &pred)?; + let validity = self.validity().filter(predicate)?; + + RunEndArray::try_new(run_ends.into_array(), values, validity).map(|a| a.into_array()) + } +} + +// Code adapted from apache arrow-rs https://github.com/apache/arrow-rs/blob/b1f5c250ebb6c1252b4e7c51d15b8e77f4c361fa/arrow-select/src/filter.rs#L425 +fn filter_run_ends + AsPrimitive>( + run_ends: &[R], + predicate: &ArrayData, +) -> VortexResult<(PrimitiveArray, BoolArray)> { + let mut new_run_ends = vec![R::zero(); run_ends.len()]; + + let mut start = 0u64; + let mut j = 0; + let mut count = R::zero(); + let filter_values = predicate.clone().into_bool()?.boolean_buffer(); + + let pred: BoolArray = BooleanBuffer::collect_bool(run_ends.len(), |i| { + let mut keep = false; + let end = run_ends[i].as_(); + + // Safety: predicate must be the same length as the array the ends have been taken from + for pred in (start..end).map(|i| unsafe { filter_values.value_unchecked(i as usize) }) { + count += >::from(pred); + keep |= pred + } + // this is to avoid branching + new_run_ends[j] = count; + j += keep as usize; + + start = end; + keep + }) + .into(); + + new_run_ends.truncate(j); + Ok((PrimitiveArray::from(new_run_ends), pred)) +} + #[cfg(test)] mod test { use vortex_array::array::{BoolArray, PrimitiveArray}; use vortex_array::compute::unary::{scalar_at, try_cast}; - use vortex_array::compute::{slice, take}; + use vortex_array::compute::{filter, slice, take}; use vortex_array::validity::{ArrayValidity, Validity}; use vortex_array::{ArrayDType, IntoArrayData, IntoArrayVariant, ToArrayData}; use vortex_dtype::{DType, Nullability, PType}; @@ -380,4 +434,35 @@ mod test { assert_eq!(scalar_at(taken.as_ref(), 1).unwrap(), 2.into()); assert_eq!(scalar_at(taken.as_ref(), 2).unwrap(), 5.into()); } + + #[test] + fn filter_run_end() { + let arr = ree_array(); + let filtered = filter( + arr, + BoolArray::from_iter([ + true, true, false, false, false, false, false, false, false, false, true, true, + ]) + .into_array(), + ) + .unwrap(); + let filtered_run_end = RunEndArray::try_from(filtered).unwrap(); + + assert_eq!( + filtered_run_end + .ends() + .into_primitive() + .unwrap() + .maybe_null_slice::(), + [2, 4] + ); + assert_eq!( + filtered_run_end + .values() + .into_primitive() + .unwrap() + .maybe_null_slice::(), + [1, 5] + ); + } }