Skip to content

Commit

Permalink
Filter mask (#1327)
Browse files Browse the repository at this point in the history
Add a memoized filter mask (`FilterMask`). `FilterFn::filter` now takes a `&FilterMask` instead of a predicate `&ArrayData`.

Follow-ups (FLUPs):
- [ ] Swap to a single `iter()` function that returns an enum of either slices
or indices, chosen based on selectivity. This forces all consumers to support
both.
  • Loading branch information
gatesn authored Nov 18, 2024
1 parent f952450 commit dbbfd56
Show file tree
Hide file tree
Showing 34 changed files with 403 additions and 484 deletions.
10 changes: 5 additions & 5 deletions encodings/alp/src/alp/compute.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use vortex_array::array::ConstantArray;
use vortex_array::compute::unary::{scalar_at, scalar_at_unchecked, ScalarAtFn};
use vortex_array::compute::{
compare, filter, slice, take, ArrayCompute, FilterFn, MaybeCompareFn, Operator, SliceFn,
TakeFn, TakeOptions,
compare, filter, slice, take, ArrayCompute, FilterFn, FilterMask, MaybeCompareFn, Operator,
SliceFn, TakeFn, TakeOptions,
};
use vortex_array::stats::{ArrayStatistics, Stat};
use vortex_array::variants::PrimitiveArrayTrait;
Expand Down Expand Up @@ -86,11 +86,11 @@ impl SliceFn for ALPArray {
}

impl FilterFn for ALPArray {
fn filter(&self, predicate: &ArrayData) -> VortexResult<ArrayData> {
fn filter(&self, mask: &FilterMask) -> VortexResult<ArrayData> {
Ok(Self::try_new(
filter(self.encoded(), predicate)?,
filter(&self.encoded(), mask)?,
self.exponents(),
self.patches().map(|p| filter(&p, predicate)).transpose()?,
self.patches().map(|p| filter(&p, mask)).transpose()?,
)?
.into_array())
}
Expand Down
25 changes: 14 additions & 11 deletions encodings/alp/src/alp_rd/compute/filter.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
use vortex_array::compute::{filter, FilterFn};
use vortex_array::compute::{filter, FilterFn, FilterMask};
use vortex_array::{ArrayDType, ArrayData, IntoArrayData};
use vortex_error::VortexResult;

use crate::ALPRDArray;

impl FilterFn for ALPRDArray {
fn filter(&self, predicate: &ArrayData) -> VortexResult<ArrayData> {
fn filter(&self, mask: &FilterMask) -> VortexResult<ArrayData> {
let left_parts_exceptions = self
.left_parts_exceptions()
.map(|array| filter(&array, predicate))
.map(|array| filter(&array, mask))
.transpose()?;

Ok(ALPRDArray::try_new(
self.dtype().clone(),
filter(self.left_parts(), predicate)?,
filter(&self.left_parts(), mask)?,
self.left_parts_dict(),
filter(self.right_parts(), predicate)?,
filter(&self.right_parts(), mask)?,
self.right_bit_width(),
left_parts_exceptions,
)?
Expand All @@ -26,8 +26,8 @@ impl FilterFn for ALPRDArray {
#[cfg(test)]
mod test {
use rstest::rstest;
use vortex_array::array::{BoolArray, PrimitiveArray};
use vortex_array::compute::filter;
use vortex_array::array::PrimitiveArray;
use vortex_array::compute::{filter, FilterMask};
use vortex_array::IntoArrayVariant;

use crate::{ALPRDFloat, RDEncoder};
Expand All @@ -43,10 +43,13 @@ mod test {
assert!(encoded.left_parts_exceptions().is_some());

// The first two values need no patching
let filtered = filter(encoded.as_ref(), BoolArray::from_iter([true, false, true]))
.unwrap()
.into_primitive()
.unwrap();
let filtered = filter(
encoded.as_ref(),
&FilterMask::from_iter([true, false, true]),
)
.unwrap()
.into_primitive()
.unwrap();
assert_eq!(filtered.maybe_null_slice::<T>(), &[a, outlier]);
}
}
8 changes: 0 additions & 8 deletions encodings/bytebool/src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,14 +99,6 @@ impl BoolArrayTrait for ByteBoolArray {
)
.map(|a| a.into_array())
}

fn maybe_null_indices_iter<'a>(&'a self) -> Box<dyn Iterator<Item = usize> + 'a> {
todo!()
}

fn maybe_null_slices_iter<'a>(&'a self) -> Box<dyn Iterator<Item = (usize, usize)> + 'a> {
todo!()
}
}

impl From<Vec<bool>> for ByteBoolArray {
Expand Down
8 changes: 4 additions & 4 deletions encodings/dict/src/compute.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use vortex_array::compute::unary::{scalar_at, scalar_at_unchecked, ScalarAtFn};
use vortex_array::compute::{
compare, filter, slice, take, ArrayCompute, FilterFn, MaybeCompareFn, Operator, SliceFn,
TakeFn, TakeOptions,
compare, filter, slice, take, ArrayCompute, FilterFn, FilterMask, MaybeCompareFn, Operator,
SliceFn, TakeFn, TakeOptions,
};
use vortex_array::stats::{ArrayStatistics, Stat};
use vortex_array::{ArrayData, IntoArrayData};
Expand Down Expand Up @@ -86,8 +86,8 @@ impl TakeFn for DictArray {
}

impl FilterFn for DictArray {
fn filter(&self, predicate: &ArrayData) -> VortexResult<ArrayData> {
let codes = filter(self.codes(), predicate)?;
fn filter(&self, mask: &FilterMask) -> VortexResult<ArrayData> {
let codes = filter(&self.codes(), mask)?;
Self::try_new(codes, self.values()).map(|a| a.into_array())
}
}
Expand Down
8 changes: 0 additions & 8 deletions encodings/dict/src/variants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,6 @@ impl BoolArrayTrait for DictArray {
fn invert(&self) -> VortexResult<ArrayData> {
todo!()
}

fn maybe_null_indices_iter<'a>(&'a self) -> Box<dyn Iterator<Item = usize> + 'a> {
todo!()
}

fn maybe_null_slices_iter<'a>(&'a self) -> Box<dyn Iterator<Item = (usize, usize)> + 'a> {
todo!()
}
}

impl PrimitiveArrayTrait for DictArray {}
Expand Down
20 changes: 1 addition & 19 deletions encodings/fastlanes/src/bitpacking/compute/mod.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
use vortex_array::compute::unary::ScalarAtFn;
use vortex_array::compute::{filter, ArrayCompute, FilterFn, SearchSortedFn, SliceFn, TakeFn};
use vortex_array::stats::ArrayStatistics;
use vortex_array::{ArrayData, IntoCanonical};
use vortex_error::{vortex_err, VortexResult};
use vortex_array::compute::{ArrayCompute, SearchSortedFn, SliceFn, TakeFn};

use crate::BitPackedArray;

Expand All @@ -12,10 +9,6 @@ mod slice;
mod take;

impl ArrayCompute for BitPackedArray {
fn filter(&self) -> Option<&dyn FilterFn> {
Some(self)
}

fn scalar_at(&self) -> Option<&dyn ScalarAtFn> {
Some(self)
}
Expand All @@ -32,14 +25,3 @@ impl ArrayCompute for BitPackedArray {
Some(self)
}
}

impl FilterFn for BitPackedArray {
fn filter(&self, predicate: &ArrayData) -> VortexResult<ArrayData> {
let _predicate_true_count = predicate
.statistics()
.compute_true_count()
.ok_or_else(|| vortex_err!("Cannot compute true count of predicate"))?;

filter(self.clone().into_canonical()?.as_ref(), predicate)
}
}
6 changes: 5 additions & 1 deletion encodings/fastlanes/src/bitpacking/compute/search_sorted.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,11 @@ impl<'a, T: BitPacking + NativePType> BitPackedSearch<'a, T> {
Validity::AllInvalid => 0,
Validity::Array(varray) => {
// In sorted order, nulls come after all the non-null values.
varray.with_dyn(|a| a.as_bool_array_unchecked().true_count())
varray.with_dyn(|a| {
a.statistics()
.compute_true_count()
.vortex_expect("Failed to compute true count")
})
}
};

Expand Down
8 changes: 4 additions & 4 deletions encodings/fastlanes/src/for/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ use std::ops::{AddAssign, Shl, Shr};
use num_traits::{WrappingAdd, WrappingSub};
use vortex_array::compute::unary::{scalar_at_unchecked, ScalarAtFn};
use vortex_array::compute::{
filter, search_sorted, slice, take, ArrayCompute, FilterFn, SearchResult, SearchSortedFn,
SearchSortedSide, SliceFn, TakeFn, TakeOptions,
filter, search_sorted, slice, take, ArrayCompute, FilterFn, FilterMask, SearchResult,
SearchSortedFn, SearchSortedSide, SliceFn, TakeFn, TakeOptions,
};
use vortex_array::variants::PrimitiveArrayTrait;
use vortex_array::{ArrayDType, ArrayData, IntoArrayData};
Expand Down Expand Up @@ -48,9 +48,9 @@ impl TakeFn for FoRArray {
}

impl FilterFn for FoRArray {
fn filter(&self, predicate: &ArrayData) -> VortexResult<ArrayData> {
fn filter(&self, mask: &FilterMask) -> VortexResult<ArrayData> {
Self::try_new(
filter(self.encoded(), predicate)?,
filter(&self.encoded(), mask)?,
self.owned_reference_scalar(),
self.shift(),
)
Expand Down
10 changes: 5 additions & 5 deletions encodings/fsst/src/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ use fsst::Symbol;
use vortex_array::array::{varbin_scalar, ConstantArray};
use vortex_array::compute::unary::{scalar_at_unchecked, ScalarAtFn};
use vortex_array::compute::{
compare, filter, slice, take, ArrayCompute, FilterFn, MaybeCompareFn, Operator, SliceFn,
TakeFn, TakeOptions,
compare, filter, slice, take, ArrayCompute, FilterFn, FilterMask, MaybeCompareFn, Operator,
SliceFn, TakeFn, TakeOptions,
};
use vortex_array::{ArrayDType, ArrayData, IntoArrayData, IntoArrayVariant};
use vortex_buffer::Buffer;
Expand Down Expand Up @@ -151,13 +151,13 @@ impl ScalarAtFn for FSSTArray {

impl FilterFn for FSSTArray {
// Filtering an FSSTArray filters the codes array, leaving the symbols array untouched
fn filter(&self, predicate: &ArrayData) -> VortexResult<ArrayData> {
fn filter(&self, mask: &FilterMask) -> VortexResult<ArrayData> {
Ok(Self::try_new(
self.dtype().clone(),
self.symbols(),
self.symbol_lengths(),
filter(self.codes(), predicate)?,
filter(self.uncompressed_lengths(), predicate)?,
filter(&self.codes(), mask)?,
filter(&self.uncompressed_lengths(), mask)?,
)?
.into_array())
}
Expand Down
8 changes: 4 additions & 4 deletions encodings/fsst/tests/fsst_tests.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
#![cfg(test)]

use vortex_array::array::builder::VarBinBuilder;
use vortex_array::array::{BoolArray, PrimitiveArray};
use vortex_array::array::PrimitiveArray;
use vortex_array::compute::unary::scalar_at;
use vortex_array::compute::{filter, slice, take, TakeOptions};
use vortex_array::compute::{filter, slice, take, FilterMask, TakeOptions};
use vortex_array::validity::Validity;
use vortex_array::{ArrayData, ArrayDef, IntoArrayData, IntoCanonical};
use vortex_dtype::{DType, Nullability};
Expand Down Expand Up @@ -85,9 +85,9 @@ fn test_fsst_array_ops() {
);

// test filter
let predicate = BoolArray::from_iter([false, true, false]).into_array();
let mask = FilterMask::from_iter([false, true, false]);

let fsst_filtered = filter(&fsst_array, &predicate).unwrap();
let fsst_filtered = filter(&fsst_array, &mask).unwrap();
assert_eq!(fsst_filtered.encoding().id(), FSST::ENCODING.id());
assert_eq!(fsst_filtered.len(), 1);
assert_nth_scalar!(
Expand Down
8 changes: 0 additions & 8 deletions encodings/roaring/src/boolean/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,14 +91,6 @@ impl BoolArrayTrait for RoaringBoolArray {
RoaringBoolArray::try_new(self.bitmap().flip(0..(self.len() as u32)), self.len())
.map(|a| a.into_array())
}

fn maybe_null_indices_iter<'a>(&'a self) -> Box<dyn Iterator<Item = usize> + 'a> {
todo!()
}

fn maybe_null_slices_iter<'a>(&'a self) -> Box<dyn Iterator<Item = (usize, usize)> + 'a> {
todo!()
}
}

impl AcceptArrayVisitor for RoaringBoolArray {
Expand Down
8 changes: 0 additions & 8 deletions encodings/runend-bool/src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,14 +151,6 @@ impl BoolArrayTrait for RunEndBoolArray {
RunEndBoolArray::try_new(self.ends(), !self.start(), self.validity())
.map(|a| a.into_array())
}

fn maybe_null_indices_iter<'a>(&'a self) -> Box<dyn Iterator<Item = usize> + 'a> {
todo!()
}

fn maybe_null_slices_iter<'a>(&'a self) -> Box<dyn Iterator<Item = (usize, usize)> + 'a> {
todo!()
}
}

impl ArrayVariants for RunEndBoolArray {
Expand Down
42 changes: 20 additions & 22 deletions encodings/runend/src/compute.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use std::ops::AddAssign;

use num_traits::AsPrimitive;
use vortex_array::array::{BoolArray, BooleanBuffer, ConstantArray, PrimitiveArray, SparseArray};
use vortex_array::array::{BooleanBuffer, ConstantArray, PrimitiveArray, SparseArray};
use vortex_array::compute::unary::{scalar_at, scalar_at_unchecked, ScalarAtFn};
use vortex_array::compute::{
compare, filter, slice, take, ArrayCompute, FilterFn, MaybeCompareFn, Operator, SliceFn,
TakeFn, TakeOptions,
compare, filter, slice, take, ArrayCompute, FilterFn, FilterMask, MaybeCompareFn, Operator,
SliceFn, TakeFn, TakeOptions,
};
use vortex_array::stats::{ArrayStatistics, Stat};
use vortex_array::validity::Validity;
Expand Down Expand Up @@ -109,14 +109,13 @@ impl TakeFn for RunEndArray {
ConstantArray::new(Scalar::null(self.dtype().clone()), indices.len()).into_array()
}
Validity::Array(original_validity) => {
let dense_validity = take(&original_validity, indices, options)?;
let dense_validity =
FilterMask::try_from(take(&original_validity, indices, options)?)?;
let filtered_values = filter(&dense_values, &dense_validity)?;
let length = dense_validity.len();
let dense_nonnull_indices = PrimitiveArray::from(
dense_validity
.into_bool()?
.boolean_buffer()
.set_indices()
.iter_indices()?
.map(|idx| idx as u64)
.collect::<Vec<_>>(),
)
Expand Down Expand Up @@ -151,13 +150,13 @@ impl SliceFn for RunEndArray {
}

impl FilterFn for RunEndArray {
fn filter(&self, predicate: &ArrayData) -> VortexResult<ArrayData> {
fn filter(&self, mask: &FilterMask) -> VortexResult<ArrayData> {
let primitive_run_ends = self.ends().into_primitive()?;
let (run_ends, pred) = match_each_unsigned_integer_ptype!(primitive_run_ends.ptype(), |$P| {
filter_run_ends(primitive_run_ends.maybe_null_slice::<$P>(), predicate)?
let (run_ends, mask) = match_each_unsigned_integer_ptype!(primitive_run_ends.ptype(), |$P| {
filter_run_ends(primitive_run_ends.maybe_null_slice::<$P>(), mask)?
});
let values = filter(self.values(), &pred)?;
let validity = self.validity().filter(predicate)?;
let values = filter(&self.values(), &mask)?;
let validity = self.validity().filter(&mask)?;

RunEndArray::try_new(run_ends.into_array(), values, validity).map(|a| a.into_array())
}
Expand All @@ -166,16 +165,16 @@ impl FilterFn for RunEndArray {
// Code adapted from apache arrow-rs https://github.com/apache/arrow-rs/blob/b1f5c250ebb6c1252b4e7c51d15b8e77f4c361fa/arrow-select/src/filter.rs#L425
fn filter_run_ends<R: NativePType + AddAssign + From<bool> + AsPrimitive<u64>>(
run_ends: &[R],
predicate: &ArrayData,
) -> VortexResult<(PrimitiveArray, BoolArray)> {
mask: &FilterMask,
) -> VortexResult<(PrimitiveArray, FilterMask)> {
let mut new_run_ends = vec![R::zero(); run_ends.len()];

let mut start = 0u64;
let mut j = 0;
let mut count = R::zero();
let filter_values = predicate.clone().into_bool()?.boolean_buffer();
let filter_values = mask.to_boolean_buffer()?;

let pred: BoolArray = BooleanBuffer::collect_bool(run_ends.len(), |i| {
let new_mask: FilterMask = BooleanBuffer::collect_bool(run_ends.len(), |i| {
let mut keep = false;
let end = run_ends[i].as_();

Expand All @@ -194,14 +193,14 @@ fn filter_run_ends<R: NativePType + AddAssign + From<bool> + AsPrimitive<u64>>(
.into();

new_run_ends.truncate(j);
Ok((PrimitiveArray::from(new_run_ends), pred))
Ok((PrimitiveArray::from(new_run_ends), new_mask))
}

#[cfg(test)]
mod test {
use vortex_array::array::{BoolArray, PrimitiveArray};
use vortex_array::compute::unary::{scalar_at, try_cast};
use vortex_array::compute::{filter, slice, take, TakeOptions};
use vortex_array::compute::{filter, slice, take, FilterMask, TakeOptions};
use vortex_array::validity::{ArrayValidity, Validity};
use vortex_array::{ArrayDType, IntoArrayData, IntoArrayVariant, ToArrayData};
use vortex_dtype::{DType, Nullability, PType};
Expand Down Expand Up @@ -444,11 +443,10 @@ mod test {
fn filter_run_end() {
let arr = ree_array();
let filtered = filter(
arr,
BoolArray::from_iter([
arr.as_ref(),
&FilterMask::from_iter([
true, true, false, false, false, false, false, false, false, false, true, true,
])
.into_array(),
]),
)
.unwrap();
let filtered_run_end = RunEndArray::try_from(filtered).unwrap();
Expand Down
Loading

0 comments on commit dbbfd56

Please sign in to comment.