diff --git a/encodings/alp/src/alp/compute.rs b/encodings/alp/src/alp/compute.rs index 897326eb37..f242bcdc67 100644 --- a/encodings/alp/src/alp/compute.rs +++ b/encodings/alp/src/alp/compute.rs @@ -1,7 +1,8 @@ use vortex_array::array::ConstantArray; use vortex_array::compute::unary::{scalar_at, scalar_at_unchecked, ScalarAtFn}; use vortex_array::compute::{ - compare, filter, slice, take, ArrayCompute, FilterFn, MaybeCompareFn, Operator, SliceFn, TakeFn, + compare, filter, slice, take, ArrayCompute, FilterFn, MaybeCompareFn, Operator, SliceFn, + TakeFn, TakeOptions, }; use vortex_array::stats::{ArrayStatistics, Stat}; use vortex_array::variants::PrimitiveArrayTrait; @@ -60,12 +61,14 @@ impl ScalarAtFn for ALPArray { } impl TakeFn for ALPArray { - fn take(&self, indices: &ArrayData) -> VortexResult { + fn take(&self, indices: &ArrayData, options: TakeOptions) -> VortexResult { // TODO(ngates): wrap up indices in an array that caches decompression? Ok(Self::try_new( - take(self.encoded(), indices)?, + take(self.encoded(), indices, options)?, self.exponents(), - self.patches().map(|p| take(&p, indices)).transpose()?, + self.patches() + .map(|p| take(&p, indices, options)) + .transpose()?, )? 
.into_array()) } diff --git a/encodings/alp/src/alp_rd/compute/take.rs b/encodings/alp/src/alp_rd/compute/take.rs index 017b1dbdfe..5011a65c3f 100644 --- a/encodings/alp/src/alp_rd/compute/take.rs +++ b/encodings/alp/src/alp_rd/compute/take.rs @@ -1,21 +1,21 @@ -use vortex_array::compute::{take, TakeFn}; +use vortex_array::compute::{take, TakeFn, TakeOptions}; use vortex_array::{ArrayDType, ArrayData, IntoArrayData}; use vortex_error::VortexResult; use crate::ALPRDArray; impl TakeFn for ALPRDArray { - fn take(&self, indices: &ArrayData) -> VortexResult { + fn take(&self, indices: &ArrayData, options: TakeOptions) -> VortexResult { let left_parts_exceptions = self .left_parts_exceptions() - .map(|array| take(&array, indices)) + .map(|array| take(&array, indices, options)) .transpose()?; Ok(ALPRDArray::try_new( self.dtype().clone(), - take(self.left_parts(), indices)?, + take(self.left_parts(), indices, options)?, self.left_parts_dict(), - take(self.right_parts(), indices)?, + take(self.right_parts(), indices, options)?, self.right_bit_width(), left_parts_exceptions, )? 
@@ -27,7 +27,7 @@ impl TakeFn for ALPRDArray { mod test { use rstest::rstest; use vortex_array::array::PrimitiveArray; - use vortex_array::compute::take; + use vortex_array::compute::{take, TakeOptions}; use vortex_array::IntoArrayVariant; use crate::{ALPRDFloat, RDEncoder}; @@ -41,10 +41,14 @@ mod test { assert!(encoded.left_parts_exceptions().is_some()); - let taken = take(encoded.as_ref(), PrimitiveArray::from(vec![0, 2]).as_ref()) - .unwrap() - .into_primitive() - .unwrap(); + let taken = take( + encoded.as_ref(), + PrimitiveArray::from(vec![0, 2]).as_ref(), + TakeOptions::default(), + ) + .unwrap() + .into_primitive() + .unwrap(); assert_eq!(taken.maybe_null_slice::(), &[a, outlier]); } diff --git a/encodings/bytebool/src/compute.rs b/encodings/bytebool/src/compute.rs index 0c37bc9a9c..1c9ae3377d 100644 --- a/encodings/bytebool/src/compute.rs +++ b/encodings/bytebool/src/compute.rs @@ -1,6 +1,6 @@ use num_traits::AsPrimitive; use vortex_array::compute::unary::{FillForwardFn, ScalarAtFn}; -use vortex_array::compute::{ArrayCompute, SliceFn, TakeFn}; +use vortex_array::compute::{ArrayCompute, SliceFn, TakeFn, TakeOptions}; use vortex_array::validity::{ArrayValidity, Validity}; use vortex_array::variants::PrimitiveArrayTrait; use vortex_array::{ArrayDType, ArrayData, IntoArrayData}; @@ -49,7 +49,7 @@ impl SliceFn for ByteBoolArray { } impl TakeFn for ByteBoolArray { - fn take(&self, indices: &ArrayData) -> VortexResult { + fn take(&self, indices: &ArrayData, _options: TakeOptions) -> VortexResult { let validity = self.validity(); let indices = indices.clone().as_primitive(); let bools = self.maybe_null_slice(); diff --git a/encodings/datetime-parts/src/compute.rs b/encodings/datetime-parts/src/compute.rs index f31e2022a8..0eb235c4a7 100644 --- a/encodings/datetime-parts/src/compute.rs +++ b/encodings/datetime-parts/src/compute.rs @@ -1,7 +1,7 @@ use itertools::Itertools as _; use vortex_array::array::{PrimitiveArray, TemporalArray}; use 
vortex_array::compute::unary::{scalar_at, ScalarAtFn}; -use vortex_array::compute::{slice, take, ArrayCompute, SliceFn, TakeFn}; +use vortex_array::compute::{slice, take, ArrayCompute, SliceFn, TakeFn, TakeOptions}; use vortex_array::validity::ArrayValidity; use vortex_array::{ArrayDType, ArrayData, IntoArrayData, IntoArrayVariant}; use vortex_datetime_dtype::{TemporalMetadata, TimeUnit}; @@ -26,12 +26,12 @@ impl ArrayCompute for DateTimePartsArray { } impl TakeFn for DateTimePartsArray { - fn take(&self, indices: &ArrayData) -> VortexResult { + fn take(&self, indices: &ArrayData, options: TakeOptions) -> VortexResult { Ok(Self::try_new( self.dtype().clone(), - take(self.days(), indices)?, - take(self.seconds(), indices)?, - take(self.subsecond(), indices)?, + take(self.days(), indices, options)?, + take(self.seconds(), indices, options)?, + take(self.subsecond(), indices, options)?, )? .into_array()) } diff --git a/encodings/dict/src/array.rs b/encodings/dict/src/array.rs index 23affd31a7..310e7f502e 100644 --- a/encodings/dict/src/array.rs +++ b/encodings/dict/src/array.rs @@ -4,8 +4,8 @@ use arrow_buffer::BooleanBuffer; use serde::{Deserialize, Serialize}; use vortex_array::array::visitor::{AcceptArrayVisitor, ArrayVisitor}; use vortex_array::array::BoolArray; -use vortex_array::compute::take; use vortex_array::compute::unary::scalar_at; +use vortex_array::compute::{take, TakeOptions}; use vortex_array::encoding::ids; use vortex_array::stats::StatsSet; use vortex_array::validity::{ArrayValidity, LogicalValidity}; @@ -75,10 +75,10 @@ impl IntoCanonical for DictArray { // copies of the view pointers. 
DType::Utf8(_) | DType::Binary(_) => { let canonical_values: ArrayData = self.values().into_canonical()?.into(); - take(canonical_values, self.codes())?.into_canonical() + take(canonical_values, self.codes(), TakeOptions::default())?.into_canonical() } // Non-string case: take and then canonicalize - _ => take(self.values(), self.codes())?.into_canonical(), + _ => take(self.values(), self.codes(), TakeOptions::default())?.into_canonical(), } } } diff --git a/encodings/dict/src/compute.rs b/encodings/dict/src/compute.rs index b2debc36fc..75a20245b3 100644 --- a/encodings/dict/src/compute.rs +++ b/encodings/dict/src/compute.rs @@ -1,6 +1,7 @@ use vortex_array::compute::unary::{scalar_at, scalar_at_unchecked, ScalarAtFn}; use vortex_array::compute::{ - compare, filter, slice, take, ArrayCompute, FilterFn, MaybeCompareFn, Operator, SliceFn, TakeFn, + compare, filter, slice, take, ArrayCompute, FilterFn, MaybeCompareFn, Operator, SliceFn, + TakeFn, TakeOptions, }; use vortex_array::stats::{ArrayStatistics, Stat}; use vortex_array::{ArrayData, IntoArrayData}; @@ -75,11 +76,11 @@ impl ScalarAtFn for DictArray { } impl TakeFn for DictArray { - fn take(&self, indices: &ArrayData) -> VortexResult { + fn take(&self, indices: &ArrayData, options: TakeOptions) -> VortexResult { // Dict // codes: 0 0 1 // dict: a b c d e f g h - let codes = take(self.codes(), indices)?; + let codes = take(self.codes(), indices, options)?; Self::try_new(codes, self.values()).map(|a| a.into_array()) } } diff --git a/encodings/dict/src/variants.rs b/encodings/dict/src/variants.rs index 7a2199aa36..e729f13c31 100644 --- a/encodings/dict/src/variants.rs +++ b/encodings/dict/src/variants.rs @@ -1,12 +1,17 @@ use vortex_array::variants::{ - ArrayVariants, BinaryArrayTrait, PrimitiveArrayTrait, Utf8ArrayTrait, + ArrayVariants, BinaryArrayTrait, BoolArrayTrait, PrimitiveArrayTrait, Utf8ArrayTrait, }; -use vortex_array::ArrayDType; +use vortex_array::{ArrayDType, ArrayData}; use vortex_dtype::DType; +use 
vortex_error::VortexResult; use crate::DictArray; impl ArrayVariants for DictArray { + fn as_bool_array(&self) -> Option<&dyn BoolArrayTrait> { + matches!(self.dtype(), DType::Bool(..)).then_some(self) + } + fn as_primitive_array(&self) -> Option<&dyn PrimitiveArrayTrait> { matches!(self.dtype(), DType::Primitive(..)).then_some(self) } @@ -20,6 +25,20 @@ impl ArrayVariants for DictArray { } } +impl BoolArrayTrait for DictArray { + fn invert(&self) -> VortexResult { + todo!() + } + + fn maybe_null_indices_iter<'a>(&'a self) -> Box + 'a> { + todo!() + } + + fn maybe_null_slices_iter<'a>(&'a self) -> Box + 'a> { + todo!() + } +} + impl PrimitiveArrayTrait for DictArray {} impl Utf8ArrayTrait for DictArray {} diff --git a/encodings/fastlanes/benches/bitpacking_take.rs b/encodings/fastlanes/benches/bitpacking_take.rs index 627d3c2e0f..06d004a700 100644 --- a/encodings/fastlanes/benches/bitpacking_take.rs +++ b/encodings/fastlanes/benches/bitpacking_take.rs @@ -5,7 +5,7 @@ use itertools::Itertools; use rand::distributions::Uniform; use rand::{thread_rng, Rng}; use vortex_array::array::{PrimitiveArray, SparseArray}; -use vortex_array::compute::take; +use vortex_array::compute::{take, TakeOptions}; use vortex_fastlanes::{find_best_bit_width, BitPackedArray}; fn values(len: usize, bits: usize) -> Vec { @@ -26,12 +26,30 @@ fn bench_take(c: &mut Criterion) { let stratified_indices: PrimitiveArray = (0..10).map(|i| i * 10_000).collect::>().into(); c.bench_function("take_10_stratified", |b| { - b.iter(|| black_box(take(packed.as_ref(), stratified_indices.as_ref()).unwrap())); + b.iter(|| { + black_box( + take( + packed.as_ref(), + stratified_indices.as_ref(), + TakeOptions::default(), + ) + .unwrap(), + ) + }); }); let contiguous_indices: PrimitiveArray = (0..10).collect::>().into(); c.bench_function("take_10_contiguous", |b| { - b.iter(|| black_box(take(packed.as_ref(), contiguous_indices.as_ref()).unwrap())); + b.iter(|| { + black_box( + take( + packed.as_ref(), + 
contiguous_indices.as_ref(), + TakeOptions::default(), + ) + .unwrap(), + ) + }); }); let rng = thread_rng(); @@ -43,12 +61,30 @@ fn bench_take(c: &mut Criterion) { .collect_vec() .into(); c.bench_function("take_10K_random", |b| { - b.iter(|| black_box(take(packed.as_ref(), random_indices.as_ref()).unwrap())); + b.iter(|| { + black_box( + take( + packed.as_ref(), + random_indices.as_ref(), + TakeOptions::default(), + ) + .unwrap(), + ) + }); }); let contiguous_indices: PrimitiveArray = (0..10_000).collect::>().into(); c.bench_function("take_10K_contiguous", |b| { - b.iter(|| black_box(take(packed.as_ref(), contiguous_indices.as_ref()).unwrap())); + b.iter(|| { + black_box( + take( + packed.as_ref(), + contiguous_indices.as_ref(), + TakeOptions::default(), + ) + .unwrap(), + ) + }); }); let lots_of_indices: PrimitiveArray = (0..200_000) @@ -56,7 +92,16 @@ fn bench_take(c: &mut Criterion) { .collect::>() .into(); c.bench_function("take_200K_dispersed", |b| { - b.iter(|| black_box(take(packed.as_ref(), lots_of_indices.as_ref()).unwrap())); + b.iter(|| { + black_box( + take( + packed.as_ref(), + lots_of_indices.as_ref(), + TakeOptions::default(), + ) + .unwrap(), + ) + }); }); let lots_of_indices: PrimitiveArray = (0..200_000) @@ -64,7 +109,16 @@ fn bench_take(c: &mut Criterion) { .collect::>() .into(); c.bench_function("take_200K_first_chunk_only", |b| { - b.iter(|| black_box(take(packed.as_ref(), lots_of_indices.as_ref()).unwrap())); + b.iter(|| { + black_box( + take( + packed.as_ref(), + lots_of_indices.as_ref(), + TakeOptions::default(), + ) + .unwrap(), + ) + }); }); } @@ -90,12 +144,30 @@ fn bench_patched_take(c: &mut Criterion) { let stratified_indices: PrimitiveArray = (0..10).map(|i| i * 10_000).collect::>().into(); c.bench_function("patched_take_10_stratified", |b| { - b.iter(|| black_box(take(packed.as_ref(), stratified_indices.as_ref()).unwrap())); + b.iter(|| { + black_box( + take( + packed.as_ref(), + stratified_indices.as_ref(), + TakeOptions::default(), 
+ ) + .unwrap(), + ) + }); }); let contiguous_indices: PrimitiveArray = (0..10).collect::>().into(); c.bench_function("patched_take_10_contiguous", |b| { - b.iter(|| black_box(take(packed.as_ref(), contiguous_indices.as_ref()).unwrap())); + b.iter(|| { + black_box( + take( + packed.as_ref(), + contiguous_indices.as_ref(), + TakeOptions::default(), + ) + .unwrap(), + ) + }); }); let rng = thread_rng(); @@ -107,7 +179,16 @@ fn bench_patched_take(c: &mut Criterion) { .collect_vec() .into(); c.bench_function("patched_take_10K_random", |b| { - b.iter(|| black_box(take(packed.as_ref(), random_indices.as_ref()).unwrap())); + b.iter(|| { + black_box( + take( + packed.as_ref(), + random_indices.as_ref(), + TakeOptions::default(), + ) + .unwrap(), + ) + }); }); let not_patch_indices: PrimitiveArray = (0u32..num_exceptions) @@ -116,7 +197,16 @@ fn bench_patched_take(c: &mut Criterion) { .collect_vec() .into(); c.bench_function("patched_take_10K_contiguous_not_patches", |b| { - b.iter(|| black_box(take(packed.as_ref(), not_patch_indices.as_ref()).unwrap())); + b.iter(|| { + black_box( + take( + packed.as_ref(), + not_patch_indices.as_ref(), + TakeOptions::default(), + ) + .unwrap(), + ) + }); }); let patch_indices: PrimitiveArray = (big_base2..big_base2 + num_exceptions) @@ -125,7 +215,16 @@ fn bench_patched_take(c: &mut Criterion) { .collect_vec() .into(); c.bench_function("patched_take_10K_contiguous_patches", |b| { - b.iter(|| black_box(take(packed.as_ref(), patch_indices.as_ref()).unwrap())); + b.iter(|| { + black_box( + take( + packed.as_ref(), + patch_indices.as_ref(), + TakeOptions::default(), + ) + .unwrap(), + ) + }); }); let lots_of_indices: PrimitiveArray = (0..200_000) @@ -133,7 +232,16 @@ fn bench_patched_take(c: &mut Criterion) { .collect::>() .into(); c.bench_function("patched_take_200K_dispersed", |b| { - b.iter(|| black_box(take(packed.as_ref(), lots_of_indices.as_ref()).unwrap())); + b.iter(|| { + black_box( + take( + packed.as_ref(), + 
lots_of_indices.as_ref(), + TakeOptions::default(), + ) + .unwrap(), + ) + }); }); let lots_of_indices: PrimitiveArray = (0..200_000) @@ -141,7 +249,16 @@ fn bench_patched_take(c: &mut Criterion) { .collect::>() .into(); c.bench_function("patched_take_200K_first_chunk_only", |b| { - b.iter(|| black_box(take(packed.as_ref(), lots_of_indices.as_ref()).unwrap())); + b.iter(|| { + black_box( + take( + packed.as_ref(), + lots_of_indices.as_ref(), + TakeOptions::default(), + ) + .unwrap(), + ) + }); }); // There are currently 2 magic parameters of note: @@ -165,7 +282,16 @@ fn bench_patched_take(c: &mut Criterion) { .collect_vec() .into(); c.bench_function("patched_take_10K_adversarial", |b| { - b.iter(|| black_box(take(packed.as_ref(), adversarial_indices.as_ref()).unwrap())); + b.iter(|| { + black_box( + take( + packed.as_ref(), + adversarial_indices.as_ref(), + TakeOptions::default(), + ) + .unwrap(), + ) + }); }); } diff --git a/encodings/fastlanes/src/bitpacking/compute/mod.rs b/encodings/fastlanes/src/bitpacking/compute/mod.rs index 98e4eb7c02..ebf35ce791 100644 --- a/encodings/fastlanes/src/bitpacking/compute/mod.rs +++ b/encodings/fastlanes/src/bitpacking/compute/mod.rs @@ -1,5 +1,8 @@ use vortex_array::compute::unary::ScalarAtFn; -use vortex_array::compute::{ArrayCompute, SearchSortedFn, SliceFn, TakeFn}; +use vortex_array::compute::{filter, ArrayCompute, FilterFn, SearchSortedFn, SliceFn, TakeFn}; +use vortex_array::stats::ArrayStatistics; +use vortex_array::{ArrayData, IntoCanonical}; +use vortex_error::{vortex_err, VortexResult}; use crate::BitPackedArray; @@ -9,6 +12,10 @@ mod slice; mod take; impl ArrayCompute for BitPackedArray { + fn filter(&self) -> Option<&dyn FilterFn> { + Some(self) + } + fn scalar_at(&self) -> Option<&dyn ScalarAtFn> { Some(self) } @@ -25,3 +32,14 @@ impl ArrayCompute for BitPackedArray { Some(self) } } + +impl FilterFn for BitPackedArray { + fn filter(&self, predicate: &ArrayData) -> VortexResult { + let _predicate_true_count = 
predicate + .statistics() + .compute_true_count() + .ok_or_else(|| vortex_err!("Cannot compute true count of predicate"))?; + + filter(self.clone().into_canonical()?.as_ref(), predicate) + } +} diff --git a/encodings/fastlanes/src/bitpacking/compute/slice.rs b/encodings/fastlanes/src/bitpacking/compute/slice.rs index 1aa29f7624..436fdbab0e 100644 --- a/encodings/fastlanes/src/bitpacking/compute/slice.rs +++ b/encodings/fastlanes/src/bitpacking/compute/slice.rs @@ -47,7 +47,7 @@ mod test { use itertools::Itertools; use vortex_array::array::{PrimitiveArray, SparseArray}; use vortex_array::compute::unary::scalar_at; - use vortex_array::compute::{slice, take}; + use vortex_array::compute::{slice, take, TakeOptions}; use vortex_array::IntoArrayData; use crate::BitPackedArray; @@ -201,6 +201,7 @@ mod test { let taken = take( &sliced, PrimitiveArray::from(vec![101i64, 1125i64, 1138i64]).as_ref(), + TakeOptions::default(), ) .unwrap(); assert_eq!(taken.len(), 3); diff --git a/encodings/fastlanes/src/bitpacking/compute/take.rs b/encodings/fastlanes/src/bitpacking/compute/take.rs index 22e7cb59a7..c587703962 100644 --- a/encodings/fastlanes/src/bitpacking/compute/take.rs +++ b/encodings/fastlanes/src/bitpacking/compute/take.rs @@ -3,7 +3,7 @@ use std::cmp::min; use fastlanes::BitPacking; use itertools::Itertools; use vortex_array::array::{PrimitiveArray, SparseArray}; -use vortex_array::compute::{slice, take, TakeFn}; +use vortex_array::compute::{slice, take, TakeFn, TakeOptions}; use vortex_array::variants::PrimitiveArrayTrait; use vortex_array::{ArrayDType, ArrayData, IntoArrayData, IntoArrayVariant, IntoCanonical}; use vortex_dtype::{ @@ -20,24 +20,24 @@ const UNPACK_CHUNK_THRESHOLD: usize = 8; const BULK_PATCH_THRESHOLD: usize = 64; impl TakeFn for BitPackedArray { - fn take(&self, indices: &ArrayData) -> VortexResult { + fn take(&self, indices: &ArrayData, options: TakeOptions) -> VortexResult { // If the indices are large enough, it's faster to flatten and take the 
primitive array. if indices.len() * UNPACK_CHUNK_THRESHOLD > self.len() { return self .clone() .into_canonical()? .into_primitive()? - .take(indices); + .take(indices, options); } let ptype: PType = self.dtype().try_into()?; let validity = self.validity(); - let taken_validity = validity.take(indices)?; + let taken_validity = validity.take(indices, options)?; let indices = indices.clone().into_primitive()?; let taken = match_each_unsigned_integer_ptype!(ptype, |$T| { match_each_integer_ptype!(indices.ptype(), |$I| { - PrimitiveArray::from_vec(take_primitive::<$T, $I>(self, &indices)?, taken_validity) + PrimitiveArray::from_vec(take_primitive::<$T, $I>(self, &indices, options)?, taken_validity) }) }); Ok(taken.reinterpret_cast(ptype).into_array()) @@ -47,6 +47,7 @@ impl TakeFn for BitPackedArray { fn take_primitive( array: &BitPackedArray, indices: &PrimitiveArray, + options: TakeOptions, ) -> VortexResult> { if indices.is_empty() { return Ok(vec![]); @@ -119,7 +120,14 @@ fn take_primitive( } if let Some(ref patches) = patches { - patch_for_take_primitive::(patches, indices, offset, batch_count, &mut output)?; + patch_for_take_primitive::( + patches, + indices, + offset, + batch_count, + &mut output, + options, + )?; } Ok(output) @@ -131,14 +139,16 @@ fn patch_for_take_primitive( offset: usize, batch_count: usize, output: &mut [T], + options: TakeOptions, ) -> VortexResult<()> { #[inline] fn inner_patch( patches: &SparseArray, indices: &PrimitiveArray, output: &mut [T], + options: TakeOptions, ) -> VortexResult<()> { - let taken_patches = take(patches.as_ref(), indices.as_ref())?; + let taken_patches = take(patches.as_ref(), indices.as_ref(), options)?; let taken_patches = SparseArray::try_from(taken_patches)?; let base_index = output.len() - indices.len(); @@ -163,7 +173,7 @@ fn patch_for_take_primitive( // roughly, if we have an average of less than 64 elements per batch, we prefer bulk patching let prefer_bulk_patch = batch_count * BULK_PATCH_THRESHOLD > 
indices.len(); if prefer_bulk_patch { - return inner_patch(patches, indices, output); + return inner_patch(patches, indices, output, options); } let min_index = patches.min_index().unwrap_or_default(); @@ -207,7 +217,12 @@ fn patch_for_take_primitive( continue; } - inner_patch(&patches_slice, &PrimitiveArray::from(offsets), output)?; + inner_patch( + &patches_slice, + &PrimitiveArray::from(offsets), + output, + options, + )?; } Ok(()) @@ -220,7 +235,7 @@ mod test { use rand::{thread_rng, Rng}; use vortex_array::array::{PrimitiveArray, SparseArray}; use vortex_array::compute::unary::scalar_at; - use vortex_array::compute::{slice, take}; + use vortex_array::compute::{slice, take, TakeOptions}; use vortex_array::{IntoArrayData, IntoArrayVariant}; use crate::BitPackedArray; @@ -233,7 +248,7 @@ mod test { let unpacked = PrimitiveArray::from((0..4096).map(|i| (i % 63) as u8).collect::>()); let bitpacked = BitPackedArray::encode(unpacked.as_ref(), 6).unwrap(); - let primitive_result = take(bitpacked.as_ref(), &indices) + let primitive_result = take(bitpacked.as_ref(), &indices, TakeOptions::default()) .unwrap() .into_primitive() .unwrap(); @@ -250,7 +265,10 @@ mod test { let bitpacked = BitPackedArray::encode(unpacked.as_ref(), 6).unwrap(); let sliced = slice(bitpacked.as_ref(), 128, 2050).unwrap(); - let primitive_result = take(&sliced, &indices).unwrap().into_primitive().unwrap(); + let primitive_result = take(&sliced, &indices, TakeOptions::default()) + .unwrap() + .into_primitive() + .unwrap(); let res_bytes = primitive_result.maybe_null_slice::(); assert_eq!(res_bytes, &[31, 33]); } @@ -278,7 +296,12 @@ mod test { .map(|i| i as u32) .collect_vec() .into(); - let taken = take(packed.as_ref(), random_indices.as_ref()).unwrap(); + let taken = take( + packed.as_ref(), + random_indices.as_ref(), + TakeOptions::default(), + ) + .unwrap(); // sanity check random_indices diff --git a/encodings/fastlanes/src/for/compute.rs b/encodings/fastlanes/src/for/compute.rs index 
d8056652cc..c751cd00e1 100644 --- a/encodings/fastlanes/src/for/compute.rs +++ b/encodings/fastlanes/src/for/compute.rs @@ -4,7 +4,7 @@ use num_traits::{WrappingAdd, WrappingSub}; use vortex_array::compute::unary::{scalar_at_unchecked, ScalarAtFn}; use vortex_array::compute::{ filter, search_sorted, slice, take, ArrayCompute, FilterFn, SearchResult, SearchSortedFn, - SearchSortedSide, SliceFn, TakeFn, + SearchSortedSide, SliceFn, TakeFn, TakeOptions, }; use vortex_array::variants::PrimitiveArrayTrait; use vortex_array::{ArrayDType, ArrayData, IntoArrayData}; @@ -37,9 +37,9 @@ impl ArrayCompute for FoRArray { } impl TakeFn for FoRArray { - fn take(&self, indices: &ArrayData) -> VortexResult { + fn take(&self, indices: &ArrayData, options: TakeOptions) -> VortexResult { Self::try_new( - take(self.encoded(), indices)?, + take(self.encoded(), indices, options)?, self.owned_reference_scalar(), self.shift(), ) diff --git a/encodings/fsst/src/compute.rs b/encodings/fsst/src/compute.rs index 3fd8ed76db..f8c91e9dc4 100644 --- a/encodings/fsst/src/compute.rs +++ b/encodings/fsst/src/compute.rs @@ -2,7 +2,8 @@ use fsst::Symbol; use vortex_array::array::{varbin_scalar, ConstantArray}; use vortex_array::compute::unary::{scalar_at_unchecked, ScalarAtFn}; use vortex_array::compute::{ - compare, filter, slice, take, ArrayCompute, FilterFn, MaybeCompareFn, Operator, SliceFn, TakeFn, + compare, filter, slice, take, ArrayCompute, FilterFn, MaybeCompareFn, Operator, SliceFn, + TakeFn, TakeOptions, }; use vortex_array::{ArrayDType, ArrayData, IntoArrayData, IntoArrayVariant}; use vortex_buffer::Buffer; @@ -117,13 +118,13 @@ impl SliceFn for FSSTArray { impl TakeFn for FSSTArray { // Take on an FSSTArray is a simple take on the codes array. 
- fn take(&self, indices: &ArrayData) -> VortexResult { + fn take(&self, indices: &ArrayData, options: TakeOptions) -> VortexResult { Ok(Self::try_new( self.dtype().clone(), self.symbols(), self.symbol_lengths(), - take(self.codes(), indices)?, - take(self.uncompressed_lengths(), indices)?, + take(self.codes(), indices, options)?, + take(self.uncompressed_lengths(), indices, options)?, )? .into_array()) } diff --git a/encodings/fsst/tests/fsst_tests.rs b/encodings/fsst/tests/fsst_tests.rs index 30ada95b8e..bff0ab99e4 100644 --- a/encodings/fsst/tests/fsst_tests.rs +++ b/encodings/fsst/tests/fsst_tests.rs @@ -3,7 +3,7 @@ use vortex_array::array::builder::VarBinBuilder; use vortex_array::array::{BoolArray, PrimitiveArray}; use vortex_array::compute::unary::scalar_at; -use vortex_array::compute::{filter, slice, take}; +use vortex_array::compute::{filter, slice, take, TakeOptions}; use vortex_array::validity::Validity; use vortex_array::{ArrayData, ArrayDef, IntoArrayData, IntoCanonical}; use vortex_dtype::{DType, Nullability}; @@ -71,7 +71,7 @@ fn test_fsst_array_ops() { // test take let indices = PrimitiveArray::from_vec(vec![0, 2], Validity::NonNullable).into_array(); - let fsst_taken = take(&fsst_array, &indices).unwrap(); + let fsst_taken = take(&fsst_array, &indices, TakeOptions::default()).unwrap(); assert_eq!(fsst_taken.len(), 2); assert_nth_scalar!( fsst_taken, diff --git a/encodings/runend-bool/src/array.rs b/encodings/runend-bool/src/array.rs index 7f4a506591..4bcb47c232 100644 --- a/encodings/runend-bool/src/array.rs +++ b/encodings/runend-bool/src/array.rs @@ -238,7 +238,7 @@ mod test { use rstest::rstest; use vortex_array::array::BoolArray; use vortex_array::compute::unary::scalar_at; - use vortex_array::compute::{slice, take}; + use vortex_array::compute::{slice, take, TakeOptions}; use vortex_array::stats::{ArrayStatistics as _, ArrayStatisticsCompute}; use vortex_array::validity::Validity; use vortex_array::{ArrayDType, ArrayData, IntoArrayData, 
IntoCanonical, ToArrayData}; @@ -326,6 +326,7 @@ mod test { ) .unwrap(), vec![0, 0, 6, 4].into_array(), + TakeOptions::default(), ) .unwrap(); diff --git a/encodings/runend-bool/src/compute.rs b/encodings/runend-bool/src/compute.rs index e9750b7896..c48d132fbf 100644 --- a/encodings/runend-bool/src/compute.rs +++ b/encodings/runend-bool/src/compute.rs @@ -1,6 +1,6 @@ use vortex_array::array::BoolArray; use vortex_array::compute::unary::ScalarAtFn; -use vortex_array::compute::{slice, ArrayCompute, SliceFn, TakeFn}; +use vortex_array::compute::{slice, ArrayCompute, SliceFn, TakeFn, TakeOptions}; use vortex_array::variants::PrimitiveArrayTrait; use vortex_array::{ArrayData, IntoArrayData, IntoArrayVariant, ToArrayData}; use vortex_dtype::match_each_integer_ptype; @@ -44,7 +44,7 @@ impl ScalarAtFn for RunEndBoolArray { } impl TakeFn for RunEndBoolArray { - fn take(&self, indices: &ArrayData) -> VortexResult { + fn take(&self, indices: &ArrayData, _options: TakeOptions) -> VortexResult { let primitive_indices = indices.clone().into_primitive()?; let physical_indices = match_each_integer_ptype!(primitive_indices.ptype(), |$P| { primitive_indices diff --git a/encodings/runend/src/compute.rs b/encodings/runend/src/compute.rs index b7863294eb..078d71ad6e 100644 --- a/encodings/runend/src/compute.rs +++ b/encodings/runend/src/compute.rs @@ -4,7 +4,8 @@ use num_traits::AsPrimitive; use vortex_array::array::{BoolArray, BooleanBuffer, ConstantArray, PrimitiveArray, SparseArray}; use vortex_array::compute::unary::{scalar_at, scalar_at_unchecked, ScalarAtFn}; use vortex_array::compute::{ - compare, filter, slice, take, ArrayCompute, FilterFn, MaybeCompareFn, Operator, SliceFn, TakeFn, + compare, filter, slice, take, ArrayCompute, FilterFn, MaybeCompareFn, Operator, SliceFn, + TakeFn, TakeOptions, }; use vortex_array::stats::{ArrayStatistics, Stat}; use vortex_array::validity::Validity; @@ -76,7 +77,7 @@ impl ScalarAtFn for RunEndArray { } impl TakeFn for RunEndArray { - fn 
take(&self, indices: &ArrayData) -> VortexResult { + fn take(&self, indices: &ArrayData, options: TakeOptions) -> VortexResult { let primitive_indices = indices.clone().into_primitive()?; let u64_indices = match_each_integer_ptype!(primitive_indices.ptype(), |$P| { primitive_indices @@ -99,7 +100,7 @@ impl TakeFn for RunEndArray { .map(|idx| *idx as u64) .collect(); let physical_indices_array = PrimitiveArray::from(physical_indices).into_array(); - let dense_values = take(self.values(), &physical_indices_array)?; + let dense_values = take(self.values(), &physical_indices_array, options)?; Ok(match self.validity() { Validity::NonNullable => dense_values, @@ -108,7 +109,7 @@ impl TakeFn for RunEndArray { ConstantArray::new(Scalar::null(self.dtype().clone()), indices.len()).into_array() } Validity::Array(original_validity) => { - let dense_validity = take(&original_validity, indices)?; + let dense_validity = take(&original_validity, indices, options)?; let filtered_values = filter(&dense_values, &dense_validity)?; let length = dense_validity.len(); let dense_nonnull_indices = PrimitiveArray::from( @@ -200,7 +201,7 @@ fn filter_run_ends + AsPrimitive>( mod test { use vortex_array::array::{BoolArray, PrimitiveArray}; use vortex_array::compute::unary::{scalar_at, try_cast}; - use vortex_array::compute::{filter, slice, take}; + use vortex_array::compute::{filter, slice, take, TakeOptions}; use vortex_array::validity::{ArrayValidity, Validity}; use vortex_array::{ArrayDType, IntoArrayData, IntoArrayVariant, ToArrayData}; use vortex_dtype::{DType, Nullability, PType}; @@ -220,6 +221,7 @@ mod test { let taken = take( ree_array().as_ref(), PrimitiveArray::from(vec![9, 8, 1, 3]).as_ref(), + TakeOptions::default(), ) .unwrap(); assert_eq!( @@ -233,6 +235,7 @@ mod test { let taken = take( ree_array().as_ref(), PrimitiveArray::from(vec![11]).as_ref(), + TakeOptions::default(), ) .unwrap(); assert_eq!( @@ -247,6 +250,7 @@ mod test { take( ree_array().as_ref(), 
PrimitiveArray::from(vec![12]).as_ref(), + TakeOptions::default(), ) .unwrap(); } @@ -407,7 +411,7 @@ mod test { .unwrap(); let test_indices = PrimitiveArray::from_vec(vec![0, 2, 4, 6], Validity::NonNullable); - let taken = take(arr.as_ref(), test_indices.as_ref()).unwrap(); + let taken = take(arr.as_ref(), test_indices.as_ref(), TakeOptions::default()).unwrap(); assert_eq!(taken.len(), test_indices.len()); @@ -426,6 +430,7 @@ mod test { let taken = take( sliced.as_ref(), PrimitiveArray::from(vec![1, 3, 4]).as_ref(), + TakeOptions::default(), ) .unwrap(); diff --git a/fuzz/fuzz_targets/array_ops.rs b/fuzz/fuzz_targets/array_ops.rs index 6f0be25aae..9ff7f367a1 100644 --- a/fuzz/fuzz_targets/array_ops.rs +++ b/fuzz/fuzz_targets/array_ops.rs @@ -6,7 +6,9 @@ use vortex_array::array::{ BoolEncoding, PrimitiveEncoding, StructEncoding, VarBinEncoding, VarBinViewEncoding, }; use vortex_array::compute::unary::scalar_at; -use vortex_array::compute::{filter, search_sorted, slice, take, SearchResult, SearchSortedSide}; +use vortex_array::compute::{ + filter, search_sorted, slice, take, SearchResult, SearchSortedSide, TakeOptions, +}; use vortex_array::encoding::EncodingRef; use vortex_array::{ArrayData, IntoCanonical}; use vortex_fuzz::{sort_canonical_array, Action, FuzzArrayAction}; @@ -35,7 +37,7 @@ fuzz_target!(|fuzz_action: FuzzArrayAction| -> Corpus { if indices.is_empty() { return Corpus::Reject; } - current_array = take(&current_array, &indices).unwrap(); + current_array = take(&current_array, &indices, TakeOptions::default()).unwrap(); assert_array_eq(&expected.array(), &current_array, i); } Action::SearchSorted(s, side) => { diff --git a/pyvortex/src/array.rs b/pyvortex/src/array.rs index 9f24336f45..fb8d3566e2 100644 --- a/pyvortex/src/array.rs +++ b/pyvortex/src/array.rs @@ -5,7 +5,7 @@ use pyo3::prelude::*; use pyo3::types::{IntoPyDict, PyInt, PyList}; use vortex::array::ChunkedArray; use vortex::compute::unary::{fill_forward, scalar_at}; -use vortex::compute::{compare, slice, take, 
Operator}; +use vortex::compute::{compare, slice, take, Operator, TakeOptions}; use vortex::{ArrayDType, ArrayData, IntoCanonical}; use crate::dtype::PyDType; @@ -441,7 +441,7 @@ impl PyArray { ))); } - let inner = take(&self.inner, indices)?; + let inner = take(&self.inner, indices, TakeOptions::default())?; Ok(PyArray { inner }) } diff --git a/vortex-array/benches/take_strings.rs b/vortex-array/benches/take_strings.rs index 69f651970e..23accc3cec 100644 --- a/vortex-array/benches/take_strings.rs +++ b/vortex-array/benches/take_strings.rs @@ -2,7 +2,7 @@ use criterion::{criterion_group, criterion_main, Criterion}; use vortex_array::array::{PrimitiveArray, VarBinArray}; -use vortex_array::compute::take; +use vortex_array::compute::{take, TakeOptions}; use vortex_array::validity::Validity; use vortex_array::{ArrayData, IntoArrayData, IntoArrayVariant}; use vortex_dtype::{DType, Nullability}; @@ -33,7 +33,9 @@ fn bench_varbin(c: &mut Criterion) { let array = fixture(65_535); let indices = indices(1024); - c.bench_function("varbin", |b| b.iter(|| take(&array, &indices).unwrap())); + c.bench_function("varbin", |b| { + b.iter(|| take(&array, &indices, TakeOptions::default()).unwrap()) + }); } fn bench_varbinview(c: &mut Criterion) { @@ -41,7 +43,7 @@ fn bench_varbinview(c: &mut Criterion) { let indices = indices(1024); c.bench_function("varbinview", |b| { - b.iter(|| take(array.as_ref(), &indices).unwrap()) + b.iter(|| take(array.as_ref(), &indices, TakeOptions::default()).unwrap()) }); } diff --git a/vortex-array/src/array/bool/compute/take.rs b/vortex-array/src/array/bool/compute/take.rs index 6e49b53e6d..6a01bc3381 100644 --- a/vortex-array/src/array/bool/compute/take.rs +++ b/vortex-array/src/array/bool/compute/take.rs @@ -1,35 +1,81 @@ use arrow_buffer::BooleanBuffer; +use itertools::Itertools; use num_traits::AsPrimitive; use vortex_dtype::match_each_integer_ptype; use vortex_error::VortexResult; use crate::array::BoolArray; -use crate::compute::TakeFn; +use 
crate::compute::{TakeFn, TakeOptions}; use crate::variants::PrimitiveArrayTrait; use crate::{ArrayData, IntoArrayData, IntoArrayVariant}; impl TakeFn for BoolArray { - fn take(&self, indices: &ArrayData) -> VortexResult { + fn take(&self, indices: &ArrayData, options: TakeOptions) -> VortexResult { let validity = self.validity(); let indices = indices.clone().into_primitive()?; - match_each_integer_ptype!(indices.ptype(), |$I| { - Ok(BoolArray::try_new( - take_bool(&self.boolean_buffer(), indices.maybe_null_slice::<$I>()), - validity.take(indices.as_ref())?, - )?.into_array()) - }) + + // For boolean arrays that roughly fit into a single page (at least, on Linux), it's worth + // the overhead to convert to a Vec. + let buffer = if self.len() <= 4096 { + let bools = self.boolean_buffer().into_iter().collect_vec(); + match_each_integer_ptype!(indices.ptype(), |$I| { + if options.skip_bounds_check { + take_byte_bool_unchecked(bools, indices.maybe_null_slice::<$I>()) + } else { + take_byte_bool(bools, indices.maybe_null_slice::<$I>()) + } + }) + } else { + match_each_integer_ptype!(indices.ptype(), |$I| { + if options.skip_bounds_check { + take_bool_unchecked(&self.boolean_buffer(), indices.maybe_null_slice::<$I>()) + } else { + take_bool(&self.boolean_buffer(), indices.maybe_null_slice::<$I>()) + } + }) + }; + + Ok(BoolArray::try_new(buffer, validity.take(indices.as_ref(), options)?)?.into_array()) } } +fn take_byte_bool>(bools: Vec, indices: &[I]) -> BooleanBuffer { + BooleanBuffer::collect_bool(indices.len(), |idx| { + bools[unsafe { (*indices.get_unchecked(idx)).as_() }] + }) +} + +fn take_byte_bool_unchecked>( + bools: Vec, + indices: &[I], +) -> BooleanBuffer { + BooleanBuffer::collect_bool(indices.len(), |idx| unsafe { + *bools.get_unchecked((*indices.get_unchecked(idx)).as_()) + }) +} + fn take_bool>(bools: &BooleanBuffer, indices: &[I]) -> BooleanBuffer { - BooleanBuffer::collect_bool(indices.len(), |idx| bools.value(indices[idx].as_())) + 
BooleanBuffer::collect_bool(indices.len(), |idx| { + // We can always take from the indices unchecked since collect_bool just iterates len. + bools.value(unsafe { (*indices.get_unchecked(idx)).as_() }) + }) +} + +fn take_bool_unchecked>( + bools: &BooleanBuffer, + indices: &[I], +) -> BooleanBuffer { + BooleanBuffer::collect_bool(indices.len(), |idx| unsafe { + // We can always take from the indices unchecked since collect_bool just iterates len. + bools.value_unchecked((*indices.get_unchecked(idx)).as_()) + }) } #[cfg(test)] mod test { use crate::array::primitive::PrimitiveArray; use crate::array::BoolArray; - use crate::compute::take; + use crate::compute::{take, TakeOptions}; #[test] fn take_nullable() { @@ -41,8 +87,15 @@ mod test { Some(false), ]); - let b = BoolArray::try_from(take(&reference, PrimitiveArray::from(vec![0, 3, 4])).unwrap()) - .unwrap(); + let b = BoolArray::try_from( + take( + &reference, + PrimitiveArray::from(vec![0, 3, 4]), + TakeOptions::default(), + ) + .unwrap(), + ) + .unwrap(); assert_eq!( b.boolean_buffer(), BoolArray::from_iter(vec![Some(false), None, Some(false)]).boolean_buffer() diff --git a/vortex-array/src/array/chunked/compute/filter.rs b/vortex-array/src/array/chunked/compute/filter.rs index b65c2a33bb..64c0e58016 100644 --- a/vortex-array/src/array/chunked/compute/filter.rs +++ b/vortex-array/src/array/chunked/compute/filter.rs @@ -2,7 +2,7 @@ use arrow_buffer::BooleanBufferBuilder; use vortex_error::{VortexExpect, VortexResult}; use crate::array::{BoolArray, ChunkedArray, PrimitiveArray}; -use crate::compute::{filter, take, FilterFn, SearchSorted, SearchSortedSide}; +use crate::compute::{filter, take, FilterFn, SearchSorted, SearchSortedSide, TakeOptions}; use crate::{ArrayDType, ArrayData, IntoArrayData, IntoCanonical}; // This is modeled after the constant with the equivalent name in arrow-rs. 
@@ -174,6 +174,7 @@ fn filter_indices<'a>( let filtered_chunk = take( chunk, PrimitiveArray::from(chunk_indices.clone()).into_array(), + TakeOptions::default(), )?; result.push(filtered_chunk); } @@ -193,6 +194,7 @@ fn filter_indices<'a>( let filtered_chunk = take( &chunk, PrimitiveArray::from(chunk_indices.clone()).into_array(), + TakeOptions::default(), )?; result.push(filtered_chunk); } diff --git a/vortex-array/src/array/chunked/compute/take.rs b/vortex-array/src/array/chunked/compute/take.rs index 3b65dd917b..c4222ab115 100644 --- a/vortex-array/src/array/chunked/compute/take.rs +++ b/vortex-array/src/array/chunked/compute/take.rs @@ -5,12 +5,12 @@ use vortex_scalar::Scalar; use crate::array::chunked::ChunkedArray; use crate::compute::unary::{scalar_at, subtract_scalar, try_cast}; -use crate::compute::{search_sorted, slice, take, SearchSortedSide, TakeFn}; +use crate::compute::{search_sorted, slice, take, SearchSortedSide, TakeFn, TakeOptions}; use crate::stats::ArrayStatistics; use crate::{ArrayDType, ArrayData, IntoArrayData, IntoArrayVariant, ToArrayData}; impl TakeFn for ChunkedArray { - fn take(&self, indices: &ArrayData) -> VortexResult { + fn take(&self, indices: &ArrayData, options: TakeOptions) -> VortexResult { // Fast path for strict sorted indices. 
if indices .statistics() @@ -21,7 +21,7 @@ impl TakeFn for ChunkedArray { return Ok(self.to_array()); } - return take_strict_sorted(self, indices); + return take_strict_sorted(self, indices, options); } let indices = try_cast(indices, PType::U64.into())?.into_primitive()?; @@ -38,7 +38,11 @@ impl TakeFn for ChunkedArray { if chunk_idx != prev_chunk_idx { // Start a new chunk let indices_in_chunk_array = indices_in_chunk.clone().into_array(); - chunks.push(take(&self.chunk(prev_chunk_idx)?, &indices_in_chunk_array)?); + chunks.push(take( + &self.chunk(prev_chunk_idx)?, + &indices_in_chunk_array, + options, + )?); indices_in_chunk = Vec::new(); } @@ -48,7 +52,11 @@ impl TakeFn for ChunkedArray { if !indices_in_chunk.is_empty() { let indices_in_chunk_array = indices_in_chunk.into_array(); - chunks.push(take(&self.chunk(prev_chunk_idx)?, &indices_in_chunk_array)?); + chunks.push(take( + &self.chunk(prev_chunk_idx)?, + &indices_in_chunk_array, + options, + )?); } Ok(Self::try_new(chunks, self.dtype().clone())?.into_array()) @@ -56,7 +64,11 @@ impl TakeFn for ChunkedArray { } /// When the indices are non-null and strict-sorted, we can do better -fn take_strict_sorted(chunked: &ChunkedArray, indices: &ArrayData) -> VortexResult { +fn take_strict_sorted( + chunked: &ChunkedArray, + indices: &ArrayData, + options: TakeOptions, +) -> VortexResult { let mut indices_by_chunk = vec![None; chunked.nchunks()]; // Track our position in the indices array @@ -100,7 +112,7 @@ fn take_strict_sorted(chunked: &ChunkedArray, indices: &ArrayData) -> VortexResu .into_iter() .enumerate() .filter_map(|(chunk_idx, indices)| indices.map(|i| (chunk_idx, i))) - .map(|(chunk_idx, chunk_indices)| take(&chunked.chunk(chunk_idx)?, &chunk_indices)) + .map(|(chunk_idx, chunk_indices)| take(&chunked.chunk(chunk_idx)?, &chunk_indices, options)) .try_collect()?; Ok(ChunkedArray::try_new(chunks, chunked.dtype().clone())?.into_array()) @@ -109,7 +121,7 @@ fn take_strict_sorted(chunked: &ChunkedArray, 
indices: &ArrayData) -> VortexResu #[cfg(test)] mod test { use crate::array::chunked::ChunkedArray; - use crate::compute::take; + use crate::compute::{take, TakeOptions}; use crate::{ArrayDType, IntoArrayData, IntoArrayVariant}; #[test] @@ -121,11 +133,12 @@ mod test { assert_eq!(arr.len(), 9); let indices = vec![0u64, 0, 6, 4].into_array(); - let result = &ChunkedArray::try_from(take(arr.as_ref(), &indices).unwrap()) - .unwrap() - .into_array() - .into_primitive() - .unwrap(); + let result = + &ChunkedArray::try_from(take(arr.as_ref(), &indices, TakeOptions::default()).unwrap()) + .unwrap() + .into_array() + .into_primitive() + .unwrap(); assert_eq!(result.maybe_null_slice::(), &[1, 1, 1, 2]); } } diff --git a/vortex-array/src/array/constant/compute.rs b/vortex-array/src/array/constant/compute.rs index c18aaa3d0e..e888016c63 100644 --- a/vortex-array/src/array/constant/compute.rs +++ b/vortex-array/src/array/constant/compute.rs @@ -8,7 +8,7 @@ use crate::array::constant::ConstantArray; use crate::compute::unary::{scalar_at, ScalarAtFn}; use crate::compute::{ scalar_cmp, AndFn, ArrayCompute, FilterFn, MaybeCompareFn, Operator, OrFn, SearchResult, - SearchSortedFn, SearchSortedSide, SliceFn, TakeFn, + SearchSortedFn, SearchSortedSide, SliceFn, TakeFn, TakeOptions, }; use crate::stats::{ArrayStatistics, Stat}; use crate::{ArrayDType, ArrayData, IntoArrayData}; @@ -58,7 +58,7 @@ impl ScalarAtFn for ConstantArray { } impl TakeFn for ConstantArray { - fn take(&self, indices: &ArrayData) -> VortexResult { + fn take(&self, indices: &ArrayData, _options: TakeOptions) -> VortexResult { Ok(Self::new(self.owned_scalar(), indices.len()).into_array()) } } @@ -242,8 +242,10 @@ mod test { } #[rstest] - #[case(ConstantArray::new(true, 4).into_array(), BoolArray::from_iter([Some(true), Some(false), Some(true), Some(false)].into_iter()).into_array())] - #[case(BoolArray::from_iter([Some(true), Some(false), Some(true), Some(false)].into_iter()).into_array(), ConstantArray::new(true, 
4).into_array())] + #[case(ConstantArray::new(true, 4).into_array(), BoolArray::from_iter([Some(true), Some(false), Some(true), Some(false)].into_iter()).into_array() + )] + #[case(BoolArray::from_iter([Some(true), Some(false), Some(true), Some(false)].into_iter()).into_array(), ConstantArray::new(true, 4).into_array() + )] fn test_or(#[case] lhs: ArrayData, #[case] rhs: ArrayData) { let r = or(&lhs, &rhs).unwrap().into_bool().unwrap().into_array(); @@ -259,7 +261,8 @@ mod test { } #[rstest] - #[case(ConstantArray::new(true, 4).into_array(), BoolArray::from_iter([Some(true), Some(false), Some(true), Some(false)].into_iter()).into_array())] + #[case(ConstantArray::new(true, 4).into_array(), BoolArray::from_iter([Some(true), Some(false), Some(true), Some(false)].into_iter()).into_array() + )] #[case(BoolArray::from_iter([Some(true), Some(false), Some(true), Some(false)].into_iter()).into_array(), ConstantArray::new(true, 4).into_array())] fn test_and(#[case] lhs: ArrayData, #[case] rhs: ArrayData) { diff --git a/vortex-array/src/array/extension/compute.rs b/vortex-array/src/array/extension/compute.rs index b4fb574c60..e514f8b782 100644 --- a/vortex-array/src/array/extension/compute.rs +++ b/vortex-array/src/array/extension/compute.rs @@ -5,7 +5,7 @@ use crate::array::extension::ExtensionArray; use crate::array::ConstantArray; use crate::compute::unary::{scalar_at, scalar_at_unchecked, CastFn, ScalarAtFn}; use crate::compute::{ - compare, slice, take, ArrayCompute, MaybeCompareFn, Operator, SliceFn, TakeFn, + compare, slice, take, ArrayCompute, MaybeCompareFn, Operator, SliceFn, TakeFn, TakeOptions, }; use crate::variants::ExtensionArrayTrait; use crate::{ArrayDType, ArrayData, IntoArrayData}; @@ -87,7 +87,11 @@ impl SliceFn for ExtensionArray { } impl TakeFn for ExtensionArray { - fn take(&self, indices: &ArrayData) -> VortexResult { - Ok(Self::new(self.ext_dtype().clone(), take(self.storage(), indices)?).into_array()) + fn take(&self, indices: &ArrayData, options: 
TakeOptions) -> VortexResult { + Ok(Self::new( + self.ext_dtype().clone(), + take(self.storage(), indices, options)?, + ) + .into_array()) } } diff --git a/vortex-array/src/array/null/compute.rs b/vortex-array/src/array/null/compute.rs index 1017a68c8f..f9cb3ed7f7 100644 --- a/vortex-array/src/array/null/compute.rs +++ b/vortex-array/src/array/null/compute.rs @@ -4,7 +4,7 @@ use vortex_scalar::Scalar; use crate::array::null::NullArray; use crate::compute::unary::ScalarAtFn; -use crate::compute::{ArrayCompute, SliceFn, TakeFn}; +use crate::compute::{ArrayCompute, SliceFn, TakeFn, TakeOptions}; use crate::variants::PrimitiveArrayTrait; use crate::{ArrayData, IntoArrayData, IntoArrayVariant}; @@ -39,17 +39,19 @@ impl ScalarAtFn for NullArray { } impl TakeFn for NullArray { - fn take(&self, indices: &ArrayData) -> VortexResult { + fn take(&self, indices: &ArrayData, options: TakeOptions) -> VortexResult { let indices = indices.clone().into_primitive()?; // Enforce all indices are valid - match_each_integer_ptype!(indices.ptype(), |$T| { - for index in indices.maybe_null_slice::<$T>() { - if !((*index as usize) < self.len()) { - vortex_bail!(OutOfBounds: *index as usize, 0, self.len()); + if !options.skip_bounds_check { + match_each_integer_ptype!(indices.ptype(), |$T| { + for index in indices.maybe_null_slice::<$T>() { + if !((*index as usize) < self.len()) { + vortex_bail!(OutOfBounds: *index as usize, 0, self.len()); + } } - } - }); + }); + } Ok(NullArray::new(indices.len()).into_array()) } @@ -61,7 +63,7 @@ mod test { use crate::array::null::NullArray; use crate::compute::unary::scalar_at; - use crate::compute::{SliceFn, TakeFn}; + use crate::compute::{SliceFn, TakeFn, TakeOptions}; use crate::validity::{ArrayValidity, LogicalValidity}; use crate::IntoArrayData; @@ -80,8 +82,12 @@ mod test { #[test] fn test_take_nulls() { let nulls = NullArray::new(10); - let taken = - NullArray::try_from(nulls.take(&vec![0u64, 2, 4, 6, 8].into_array()).unwrap()).unwrap(); + let 
taken = NullArray::try_from( + nulls + .take(&vec![0u64, 2, 4, 6, 8].into_array(), TakeOptions::default()) + .unwrap(), + ) + .unwrap(); assert_eq!(taken.len(), 5); assert!(matches!( diff --git a/vortex-array/src/array/primitive/compute/take.rs b/vortex-array/src/array/primitive/compute/take.rs index 38abedf680..5d791e9de5 100644 --- a/vortex-array/src/array/primitive/compute/take.rs +++ b/vortex-array/src/array/primitive/compute/take.rs @@ -1,35 +1,49 @@ -use num_traits::PrimInt; +use num_traits::AsPrimitive; use vortex_dtype::{match_each_integer_ptype, match_each_native_ptype, NativePType}; -use vortex_error::{vortex_panic, VortexResult}; +use vortex_error::VortexResult; use crate::array::primitive::PrimitiveArray; -use crate::compute::TakeFn; +use crate::compute::{TakeFn, TakeOptions}; use crate::variants::PrimitiveArrayTrait; use crate::{ArrayData, IntoArrayData, IntoArrayVariant}; impl TakeFn for PrimitiveArray { - fn take(&self, indices: &ArrayData) -> VortexResult { - let validity = self.validity(); + #[allow(clippy::cognitive_complexity)] + fn take(&self, indices: &ArrayData, options: TakeOptions) -> VortexResult { let indices = indices.clone().into_primitive()?; + let validity = self.validity().take(indices.as_ref(), options)?; + match_each_native_ptype!(self.ptype(), |$T| { match_each_integer_ptype!(indices.ptype(), |$I| { - Ok(PrimitiveArray::from_vec( - take_primitive(self.maybe_null_slice::<$T>(), indices.maybe_null_slice::<$I>()), - validity.take(indices.as_ref())?, - ).into_array()) + let values = if options.skip_bounds_check { + take_primitive_unchecked(self.maybe_null_slice::<$T>(), indices.into_maybe_null_slice::<$I>()) + } else { + take_primitive(self.maybe_null_slice::<$T>(), indices.into_maybe_null_slice::<$I>()) + }; + Ok(PrimitiveArray::from_vec(values,validity).into_array()) }) }) } } -fn take_primitive(array: &[T], indices: &[I]) -> Vec { +// We pass a Vec in case we're T == u64. +// In which case, Rust should reuse the same Vec the result. 
+fn take_primitive>( + array: &[T], + indices: Vec, +) -> Vec { + indices.into_iter().map(|idx| array[idx.as_()]).collect() +} + +// We pass a Vec in case we're T == u64. +// In which case, Rust should reuse the same Vec the result. +fn take_primitive_unchecked>( + array: &[T], + indices: Vec, +) -> Vec { indices - .iter() - .map(|&idx| { - array[idx.to_usize().unwrap_or_else(|| { - vortex_panic!("Failed to convert index to usize: {}", idx); - })] - }) + .into_iter() + .map(|idx| unsafe { *array.get_unchecked(idx.as_()) }) .collect() } @@ -40,7 +54,7 @@ mod test { #[test] fn test_take() { let a = vec![1i32, 2, 3, 4, 5]; - let result = take_primitive(&a, &[0, 0, 4, 2]); + let result = take_primitive(&a, vec![0, 0, 4, 2]); assert_eq!(result, vec![1i32, 1, 5, 3]); } } diff --git a/vortex-array/src/array/sparse/compute/mod.rs b/vortex-array/src/array/sparse/compute/mod.rs index 2fa24e3087..e4816cafc8 100644 --- a/vortex-array/src/array/sparse/compute/mod.rs +++ b/vortex-array/src/array/sparse/compute/mod.rs @@ -7,7 +7,7 @@ use crate::array::PrimitiveArray; use crate::compute::unary::{scalar_at, scalar_at_unchecked, ScalarAtFn}; use crate::compute::{ search_sorted, take, ArrayCompute, FilterFn, SearchResult, SearchSortedFn, SearchSortedSide, - SliceFn, TakeFn, + SliceFn, TakeFn, TakeOptions, }; use crate::variants::PrimitiveArrayTrait; use crate::{ArrayData, IntoArrayData, IntoArrayVariant}; @@ -103,7 +103,11 @@ impl FilterFn for SparseArray { Ok(SparseArray::try_new( PrimitiveArray::from(coordinate_indices).into_array(), - take(self.values(), PrimitiveArray::from(value_indices))?, + take( + self.values(), + PrimitiveArray::from(value_indices), + TakeOptions::default(), + )?, buffer.count_set_bits(), self.fill_value().clone(), )? 
diff --git a/vortex-array/src/array/sparse/compute/take.rs b/vortex-array/src/array/sparse/compute/take.rs index 888607ee20..a04ad97c38 100644 --- a/vortex-array/src/array/sparse/compute/take.rs +++ b/vortex-array/src/array/sparse/compute/take.rs @@ -7,12 +7,12 @@ use vortex_error::VortexResult; use crate::aliases::hash_map::HashMap; use crate::array::primitive::PrimitiveArray; use crate::array::sparse::SparseArray; -use crate::compute::{take, TakeFn}; +use crate::compute::{take, TakeFn, TakeOptions}; use crate::variants::PrimitiveArrayTrait; use crate::{ArrayData, IntoArrayData, IntoArrayVariant}; impl TakeFn for SparseArray { - fn take(&self, indices: &ArrayData) -> VortexResult { + fn take(&self, indices: &ArrayData, options: TakeOptions) -> VortexResult { let flat_indices = indices.clone().into_primitive()?; // if we are taking a lot of values we should build a hashmap let (positions, physical_take_indices) = if indices.len() > 128 { @@ -21,7 +21,7 @@ impl TakeFn for SparseArray { take_search_sorted(self, &flat_indices)? 
}; - let taken_values = take(self.values(), physical_take_indices)?; + let taken_values = take(self.values(), physical_take_indices, options)?; Ok(Self::try_new( positions.into_array(), @@ -96,7 +96,7 @@ mod test { use crate::array::primitive::PrimitiveArray; use crate::array::sparse::compute::take::take_map; use crate::array::sparse::SparseArray; - use crate::compute::take; + use crate::compute::{take, TakeOptions}; use crate::validity::Validity; use crate::{ArrayData, IntoArrayData, IntoArrayVariant}; @@ -115,9 +115,15 @@ mod test { #[test] fn sparse_take() { let sparse = sparse_array(); - let taken = - SparseArray::try_from(take(sparse, vec![0, 47, 47, 0, 99].into_array()).unwrap()) - .unwrap(); + let taken = SparseArray::try_from( + take( + sparse, + vec![0, 47, 47, 0, 99].into_array(), + TakeOptions::default(), + ) + .unwrap(), + ) + .unwrap(); assert_eq!( taken .indices() @@ -139,7 +145,10 @@ mod test { #[test] fn nonexistent_take() { let sparse = sparse_array(); - let taken = SparseArray::try_from(take(sparse, vec![69].into_array()).unwrap()).unwrap(); + let taken = SparseArray::try_from( + take(sparse, vec![69].into_array(), TakeOptions::default()).unwrap(), + ) + .unwrap(); assert!(taken .indices() .into_primitive() @@ -157,8 +166,10 @@ mod test { #[test] fn ordered_take() { let sparse = sparse_array(); - let taken = - SparseArray::try_from(take(&sparse, vec![69, 37].into_array()).unwrap()).unwrap(); + let taken = SparseArray::try_from( + take(&sparse, vec![69, 37].into_array(), TakeOptions::default()).unwrap(), + ) + .unwrap(); assert_eq!( taken .indices() diff --git a/vortex-array/src/array/struct_/compute.rs b/vortex-array/src/array/struct_/compute.rs index ac2bdf444a..58d2719abf 100644 --- a/vortex-array/src/array/struct_/compute.rs +++ b/vortex-array/src/array/struct_/compute.rs @@ -4,7 +4,7 @@ use vortex_scalar::Scalar; use crate::array::struct_::StructArray; use crate::compute::unary::{scalar_at, scalar_at_unchecked, ScalarAtFn}; -use 
crate::compute::{filter, slice, take, ArrayCompute, FilterFn, SliceFn, TakeFn}; +use crate::compute::{filter, slice, take, ArrayCompute, FilterFn, SliceFn, TakeFn, TakeOptions}; use crate::stats::ArrayStatistics; use crate::variants::StructArrayTrait; use crate::{ArrayDType, ArrayData, IntoArrayData}; @@ -48,14 +48,14 @@ impl ScalarAtFn for StructArray { } impl TakeFn for StructArray { - fn take(&self, indices: &ArrayData) -> VortexResult { + fn take(&self, indices: &ArrayData, options: TakeOptions) -> VortexResult { Self::try_new( self.names().clone(), self.children() - .map(|field| take(&field, indices)) + .map(|field| take(&field, indices, options)) .try_collect()?, indices.len(), - self.validity().take(indices)?, + self.validity().take(indices, options)?, ) .map(|a| a.into_array()) } diff --git a/vortex-array/src/array/varbin/compute/take.rs b/vortex-array/src/array/varbin/compute/take.rs index c0ce9630c7..693419477e 100644 --- a/vortex-array/src/array/varbin/compute/take.rs +++ b/vortex-array/src/array/varbin/compute/take.rs @@ -4,13 +4,13 @@ use vortex_error::{vortex_err, vortex_panic, VortexResult}; use crate::array::varbin::builder::VarBinBuilder; use crate::array::varbin::VarBinArray; -use crate::compute::TakeFn; +use crate::compute::{TakeFn, TakeOptions}; use crate::validity::Validity; use crate::variants::PrimitiveArrayTrait; use crate::{ArrayDType, ArrayData, IntoArrayData, IntoArrayVariant}; impl TakeFn for VarBinArray { - fn take(&self, indices: &ArrayData) -> VortexResult { + fn take(&self, indices: &ArrayData, _options: TakeOptions) -> VortexResult { let offsets = self.offsets().into_primitive()?; let data = self.bytes().into_primitive()?; let indices = indices.clone().into_primitive()?; diff --git a/vortex-array/src/array/varbinview/compute.rs b/vortex-array/src/array/varbinview/compute.rs index 23e755729c..7848881971 100644 --- a/vortex-array/src/array/varbinview/compute.rs +++ b/vortex-array/src/array/varbinview/compute.rs @@ -14,7 +14,7 @@ use 
crate::array::varbinview::{VarBinViewArray, VIEW_SIZE_BYTES}; use crate::array::{varbinview_as_arrow, ConstantArray}; use crate::arrow::FromArrowArray; use crate::compute::unary::ScalarAtFn; -use crate::compute::{slice, ArrayCompute, MaybeCompareFn, Operator, SliceFn, TakeFn}; +use crate::compute::{slice, ArrayCompute, MaybeCompareFn, Operator, SliceFn, TakeFn, TakeOptions}; use crate::{ArrayDType, ArrayData, IntoArrayData, IntoCanonical}; impl ArrayCompute for VarBinViewArray { @@ -66,11 +66,17 @@ impl SliceFn for VarBinViewArray { /// Take involves creating a new array that references the old array, just with the given set of views. impl TakeFn for VarBinViewArray { - fn take(&self, indices: &ArrayData) -> VortexResult { + fn take(&self, indices: &ArrayData, options: TakeOptions) -> VortexResult { let array_ref = varbinview_as_arrow(self); let indices_arrow = indices.clone().into_canonical()?.into_arrow()?; - let take_arrow = arrow_select::take::take(&array_ref, &indices_arrow, None)?; + let take_arrow = arrow_select::take::take( + &array_ref, + &indices_arrow, + Some(arrow_select::take::TakeOptions { + check_bounds: !options.skip_bounds_check, + }), + )?; Ok(ArrayData::from_arrow( take_arrow, self.dtype().is_nullable(), @@ -138,7 +144,7 @@ mod tests { use crate::accessor::ArrayAccessor; use crate::array::varbinview::compute::compare_constant; use crate::array::{ConstantArray, PrimitiveArray, VarBinViewArray}; - use crate::compute::{take, Operator}; + use crate::compute::{take, Operator, TakeOptions}; use crate::{ArrayDType, IntoArrayData, IntoArrayVariant}; #[test] @@ -175,7 +181,12 @@ mod tests { Some("six"), ]); - let taken = take(arr, PrimitiveArray::from(vec![0, 3]).into_array()).unwrap(); + let taken = take( + arr, + PrimitiveArray::from(vec![0, 3]).into_array(), + TakeOptions::default(), + ) + .unwrap(); assert!(taken.dtype().is_nullable()); assert_eq!( diff --git a/vortex-array/src/compute/mod.rs b/vortex-array/src/compute/mod.rs index 
9ced04dccd..fed55fa357 100644 --- a/vortex-array/src/compute/mod.rs +++ b/vortex-array/src/compute/mod.rs @@ -13,7 +13,7 @@ pub use compare::{compare, scalar_cmp, CompareFn, MaybeCompareFn, Operator}; pub use filter::{filter, FilterFn}; pub use search_sorted::*; pub use slice::{slice, SliceFn}; -pub use take::{take, TakeFn}; +pub use take::*; use unary::{CastFn, FillForwardFn, ScalarAtFn, SubtractScalarFn}; use vortex_error::VortexResult; diff --git a/vortex-array/src/compute/take.rs b/vortex-array/src/compute/take.rs index 735f233cfe..9be95dfd96 100644 --- a/vortex-array/src/compute/take.rs +++ b/vortex-array/src/compute/take.rs @@ -1,15 +1,22 @@ use log::info; use vortex_error::{vortex_bail, vortex_err, VortexResult}; +use crate::stats::{ArrayStatistics, Stat}; use crate::{ArrayDType as _, ArrayData, IntoCanonical as _}; +#[derive(Default, Debug, Clone, Copy)] +pub struct TakeOptions { + pub skip_bounds_check: bool, +} + pub trait TakeFn { - fn take(&self, indices: &ArrayData) -> VortexResult; + fn take(&self, indices: &ArrayData, options: TakeOptions) -> VortexResult; } pub fn take( array: impl AsRef, indices: impl AsRef, + mut options: TakeOptions, ) -> VortexResult { let array = array.as_ref(); let indices = indices.as_ref(); @@ -21,16 +28,28 @@ pub fn take( ); } + // If the indices are all within bounds, we can skip bounds checking. + if indices + .statistics() + .get_as::(Stat::Max) + .is_some_and(|max| max < array.len()) + { + options.skip_bounds_check = true; + } + + // TODO(ngates): if indices min is quite high, we could slice self and offset the indices + // such that canonicalize does less work. + array.with_dyn(|a| { if let Some(take) = a.take() { - return take.take(indices); + return take.take(indices, options); } // Otherwise, flatten and try again. 
info!("TakeFn not implemented for {}, flattening", array); ArrayData::from(array.clone().into_canonical()?).with_dyn(|a| { a.take() - .map(|t| t.take(indices)) + .map(|t| t.take(indices, options)) .unwrap_or_else(|| Err(vortex_err!(NotImplemented: "take", array.encoding().id()))) }) }) diff --git a/vortex-array/src/stream/take_rows.rs b/vortex-array/src/stream/take_rows.rs index b3e06d875a..2ccc2b5b6c 100644 --- a/vortex-array/src/stream/take_rows.rs +++ b/vortex-array/src/stream/take_rows.rs @@ -8,7 +8,7 @@ use vortex_error::{vortex_bail, VortexResult}; use vortex_scalar::Scalar; use crate::compute::unary::subtract_scalar; -use crate::compute::{search_sorted, slice, take, SearchSortedSide}; +use crate::compute::{search_sorted, slice, take, SearchSortedSide, TakeOptions}; use crate::stats::{ArrayStatistics, Stat}; use crate::stream::ArrayStream; use crate::variants::PrimitiveArrayTrait; @@ -93,7 +93,11 @@ impl Stream for TakeRows { let shifted_arr = match_each_integer_ptype!(indices_for_batch.ptype(), |$T| { subtract_scalar(&indices_for_batch.into_array(), &Scalar::from(curr_offset as $T))? 
}); - return Poll::Ready(take(&batch, &shifted_arr).map(Some).transpose()); + return Poll::Ready( + take(&batch, &shifted_arr, TakeOptions::default()) + .map(Some) + .transpose(), + ); } Poll::Ready(None) diff --git a/vortex-array/src/validity.rs b/vortex-array/src/validity.rs index 4e4bbdf709..68f06e7823 100644 --- a/vortex-array/src/validity.rs +++ b/vortex-array/src/validity.rs @@ -13,7 +13,7 @@ use vortex_error::{ use crate::array::{BoolArray, ConstantArray}; use crate::compute::unary::scalar_at_unchecked; -use crate::compute::{filter, slice, take}; +use crate::compute::{filter, slice, take, TakeOptions}; use crate::stats::ArrayStatistics; use crate::{ArrayDType, ArrayData, IntoArrayData, IntoArrayVariant}; @@ -161,12 +161,12 @@ impl Validity { } } - pub fn take(&self, indices: &ArrayData) -> VortexResult { + pub fn take(&self, indices: &ArrayData, options: TakeOptions) -> VortexResult { match self { Self::NonNullable => Ok(Self::NonNullable), Self::AllValid => Ok(Self::AllValid), Self::AllInvalid => Ok(Self::AllInvalid), - Self::Array(a) => Ok(Self::Array(take(a, indices)?)), + Self::Array(a) => Ok(Self::Array(take(a, indices, options)?)), } } @@ -467,20 +467,31 @@ mod tests { #[rstest] #[case(Validity::NonNullable, 5, &[2, 4], Validity::NonNullable, Validity::NonNullable)] #[case(Validity::NonNullable, 5, &[2, 4], Validity::AllValid, Validity::NonNullable)] - #[case(Validity::NonNullable, 5, &[2, 4], Validity::AllInvalid, Validity::Array(BoolArray::from_iter([true, true, false, true, false]).into_array()))] - #[case(Validity::NonNullable, 5, &[2, 4], Validity::Array(BoolArray::from_iter([true, false]).into_array()), Validity::Array(BoolArray::from_iter([true, true, true, true, false]).into_array()))] + #[case(Validity::NonNullable, 5, &[2, 4], Validity::AllInvalid, Validity::Array(BoolArray::from_iter([true, true, false, true, false]).into_array()) + )] + #[case(Validity::NonNullable, 5, &[2, 4], Validity::Array(BoolArray::from_iter([true, 
false]).into_array()), Validity::Array(BoolArray::from_iter([true, true, true, true, false]).into_array()) + )] #[case(Validity::AllValid, 5, &[2, 4], Validity::NonNullable, Validity::AllValid)] #[case(Validity::AllValid, 5, &[2, 4], Validity::AllValid, Validity::AllValid)] - #[case(Validity::AllValid, 5, &[2, 4], Validity::AllInvalid, Validity::Array(BoolArray::from_iter([true, true, false, true, false]).into_array()))] - #[case(Validity::AllValid, 5, &[2, 4], Validity::Array(BoolArray::from_iter([true, false]).into_array()), Validity::Array(BoolArray::from_iter([true, true, true, true, false]).into_array()))] - #[case(Validity::AllInvalid, 5, &[2, 4], Validity::NonNullable, Validity::Array(BoolArray::from_iter([false, false, true, false, true]).into_array()))] - #[case(Validity::AllInvalid, 5, &[2, 4], Validity::AllValid, Validity::Array(BoolArray::from_iter([false, false, true, false, true]).into_array()))] + #[case(Validity::AllValid, 5, &[2, 4], Validity::AllInvalid, Validity::Array(BoolArray::from_iter([true, true, false, true, false]).into_array()) + )] + #[case(Validity::AllValid, 5, &[2, 4], Validity::Array(BoolArray::from_iter([true, false]).into_array()), Validity::Array(BoolArray::from_iter([true, true, true, true, false]).into_array()) + )] + #[case(Validity::AllInvalid, 5, &[2, 4], Validity::NonNullable, Validity::Array(BoolArray::from_iter([false, false, true, false, true]).into_array()) + )] + #[case(Validity::AllInvalid, 5, &[2, 4], Validity::AllValid, Validity::Array(BoolArray::from_iter([false, false, true, false, true]).into_array()) + )] #[case(Validity::AllInvalid, 5, &[2, 4], Validity::AllInvalid, Validity::AllInvalid)] - #[case(Validity::AllInvalid, 5, &[2, 4], Validity::Array(BoolArray::from_iter([true, false]).into_array()), Validity::Array(BoolArray::from_iter([false, false, true, false, false]).into_array()))] - #[case(Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()), 5, &[2, 4], 
Validity::NonNullable, Validity::Array(BoolArray::from_iter([false, true, true, true, true]).into_array()))] - #[case(Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()), 5, &[2, 4], Validity::AllValid, Validity::Array(BoolArray::from_iter([false, true, true, true, true]).into_array()))] - #[case(Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()), 5, &[2, 4], Validity::AllInvalid, Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()))] - #[case(Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()), 5, &[2, 4], Validity::Array(BoolArray::from_iter([true, false]).into_array()), Validity::Array(BoolArray::from_iter([false, true, true, true, false]).into_array()))] + #[case(Validity::AllInvalid, 5, &[2, 4], Validity::Array(BoolArray::from_iter([true, false]).into_array()), Validity::Array(BoolArray::from_iter([false, false, true, false, false]).into_array()) + )] + #[case(Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()), 5, &[2, 4], Validity::NonNullable, Validity::Array(BoolArray::from_iter([false, true, true, true, true]).into_array()) + )] + #[case(Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()), 5, &[2, 4], Validity::AllValid, Validity::Array(BoolArray::from_iter([false, true, true, true, true]).into_array()) + )] + #[case(Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()), 5, &[2, 4], Validity::AllInvalid, Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()) + )] + #[case(Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()), 5, &[2, 4], Validity::Array(BoolArray::from_iter([true, false]).into_array()), Validity::Array(BoolArray::from_iter([false, true, true, true, false]).into_array()) + )] fn patch_validity( #[case] validity: Validity, #[case] len: usize, diff 
--git a/vortex-datafusion/src/memory/plans.rs b/vortex-datafusion/src/memory/plans.rs index 5cd5ffb67f..c71479aa0b 100644 --- a/vortex-datafusion/src/memory/plans.rs +++ b/vortex-datafusion/src/memory/plans.rs @@ -20,7 +20,7 @@ use futures::{ready, Stream}; use pin_project::pin_project; use vortex_array::array::ChunkedArray; use vortex_array::arrow::FromArrowArray; -use vortex_array::compute::take; +use vortex_array::compute::{take, TakeOptions}; use vortex_array::{ArrayData, IntoArrayVariant, IntoCanonical}; use vortex_dtype::field::Field; use vortex_error::{vortex_err, vortex_panic, VortexError}; @@ -350,7 +350,7 @@ where // We should find a way to avoid decoding the filter columns and only decode the other // columns, then stitch the StructArray back together from those. let projected_for_output = chunk.project(this.output_projection)?; - let decoded = take(projected_for_output, &row_indices)? + let decoded = take(projected_for_output, &row_indices, TakeOptions::default())? .into_canonical()? .into_arrow()?; diff --git a/vortex-file/src/chunked_reader/take_rows.rs b/vortex-file/src/chunked_reader/take_rows.rs index 9ddc513c80..d236d98ae4 100644 --- a/vortex-file/src/chunked_reader/take_rows.rs +++ b/vortex-file/src/chunked_reader/take_rows.rs @@ -6,7 +6,7 @@ use itertools::Itertools; use vortex_array::aliases::hash_map::HashMap; use vortex_array::array::{ChunkedArray, PrimitiveArray}; use vortex_array::compute::unary::{subtract_scalar, try_cast}; -use vortex_array::compute::{search_sorted, slice, take, SearchSortedSide}; +use vortex_array::compute::{search_sorted, slice, take, SearchSortedSide, TakeOptions}; use vortex_array::stats::ArrayStatistics; use vortex_array::stream::{ArrayStream, ArrayStreamExt}; use vortex_array::{ArrayDType, ArrayData, IntoArrayData, IntoArrayVariant}; @@ -74,12 +74,16 @@ impl ChunkedArrayReader { // Grab the row and byte offsets for each chunk range. 
let start_chunks = PrimitiveArray::from(start_chunks).into_array(); - let start_rows = take(&self.row_offsets, &start_chunks)?.into_primitive()?; - let start_bytes = take(&self.byte_offsets, &start_chunks)?.into_primitive()?; + let start_rows = + take(&self.row_offsets, &start_chunks, TakeOptions::default())?.into_primitive()?; + let start_bytes = + take(&self.byte_offsets, &start_chunks, TakeOptions::default())?.into_primitive()?; let stop_chunks = PrimitiveArray::from(stop_chunks).into_array(); - let stop_rows = take(&self.row_offsets, &stop_chunks)?.into_primitive()?; - let stop_bytes = take(&self.byte_offsets, &stop_chunks)?.into_primitive()?; + let stop_rows = + take(&self.row_offsets, &stop_chunks, TakeOptions::default())?.into_primitive()?; + let stop_bytes = + take(&self.byte_offsets, &stop_chunks, TakeOptions::default())?.into_primitive()?; // For each chunk-range, read the data as an ArrayStream and call take on it. let chunks = stream::iter(0..coalesced_chunks.len()) diff --git a/vortex-file/src/read/mask.rs b/vortex-file/src/read/mask.rs index deea46ca9e..f3eeb1b8b2 100644 --- a/vortex-file/src/read/mask.rs +++ b/vortex-file/src/read/mask.rs @@ -4,7 +4,7 @@ use std::fmt::{Display, Formatter}; use arrow_buffer::{BooleanBuffer, MutableBuffer}; use croaring::Bitmap; use vortex_array::array::{BoolArray, PrimitiveArray, SparseArray}; -use vortex_array::compute::{filter, slice, take}; +use vortex_array::compute::{filter, slice, take, TakeOptions}; use vortex_array::validity::{LogicalValidity, Validity}; use vortex_array::{iterate_integer_array, ArrayData, IntoArrayData, IntoArrayVariant}; use vortex_dtype::PType; @@ -213,7 +213,7 @@ impl RowMask { if (true_count as f64 / sliced.len() as f64) < PREFER_TAKE_TO_FILTER_DENSITY { let indices = self.to_indices_array()?; - take(sliced, indices).map(Some) + take(sliced, indices, TakeOptions::default()).map(Some) } else { let mask = self.to_mask_array()?; filter(sliced, mask).map(Some) diff --git 
a/vortex-ipc/benches/ipc_take.rs b/vortex-ipc/benches/ipc_take.rs index 095fc42845..8103c28820 100644 --- a/vortex-ipc/benches/ipc_take.rs +++ b/vortex-ipc/benches/ipc_take.rs @@ -15,7 +15,7 @@ use futures_util::{pin_mut, TryStreamExt}; use itertools::Itertools; use vortex_array::array::PrimitiveArray; use vortex_array::compress::CompressionStrategy; -use vortex_array::compute::take; +use vortex_array::compute::{take, TakeOptions}; use vortex_array::{Context, IntoArrayData}; use vortex_io::FuturesAdapter; use vortex_ipc::stream_reader::StreamArrayReader; @@ -81,7 +81,7 @@ fn ipc_take(c: &mut Criterion) { let reader = stream_reader.into_array_stream(); pin_mut!(reader); let array_view = reader.try_next().await?.unwrap(); - black_box(take(&array_view, indices_ref)) + black_box(take(&array_view, indices_ref, TakeOptions::default())) }); }); }