Skip to content

Commit

Permalink
remove specialized canonicalize, just use take
Browse files Browse the repository at this point in the history
  • Loading branch information
a10y committed Oct 23, 2024
1 parent 53b9f4d commit 8d3ed4d
Showing 6 changed files with 66 additions and 52 deletions.
4 changes: 2 additions & 2 deletions encodings/dict/benches/dict_canonical.rs
Original file line number Diff line number Diff line change
@@ -23,10 +23,10 @@ fn fixture(len: usize) -> DictArray {
}

fn bench_canonical(c: &mut Criterion) {
let dict_array = fixture(1024).into_array();
let dict_array = fixture(1024 * 1024).into_array();

c.bench_function("canonical", |b| {
b.iter(|| dict_array.clone().into_canonical().unwrap())
b.iter(|| dict_array.clone().into_canonical())
});
}

55 changes: 7 additions & 48 deletions encodings/dict/src/array.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
use std::fmt::{Debug, Display};

use arrow_buffer::{BooleanBuffer, ScalarBuffer};
use arrow_buffer::BooleanBuffer;
use serde::{Deserialize, Serialize};
use vortex::array::visitor::{AcceptArrayVisitor, ArrayVisitor};
use vortex::array::{BoolArray, ConstantArray, VarBinViewArray};
use vortex::compute::unary::{scalar_at, try_cast};
use vortex::compute::{compare, take, Operator};
use vortex::array::BoolArray;
use vortex::compute::take;
use vortex::compute::unary::scalar_at;
use vortex::encoding::ids;
use vortex::stats::StatsSet;
use vortex::validity::{ArrayValidity, LogicalValidity, Validity};
use vortex::validity::{ArrayValidity, LogicalValidity};
use vortex::{
impl_encoding, Array, ArrayDType, ArrayTrait, Canonical, IntoArray, IntoArrayVariant,
IntoCanonical,
@@ -68,52 +68,11 @@ impl ArrayTrait for DictArray {}

impl IntoCanonical for DictArray {
fn into_canonical(self) -> VortexResult<Canonical> {
match self.dtype() {
DType::Utf8(_) | DType::Binary(_) => canonicalize_string(self),
_ => canonicalize_primitive(self),
}
let canonical_values: Array = self.values().into_canonical()?.into();
take(canonical_values, self.codes())?.into_canonical()
}
}

/// Canonicalize a set of codes and values.
fn canonicalize_string(array: DictArray) -> VortexResult<Canonical> {
let values = array.values().into_varbinview()?;
let codes = try_cast(array.codes(), PType::U64.into())?.into_primitive()?;

let value_views = ScalarBuffer::<u128>::from(values.views().clone().into_arrow());

// Gather the views from value_views into full_views using the dictionary codes.
let full_views: Vec<u128> = codes
.maybe_null_slice::<u64>()
.iter()
.map(|code| value_views[*code as usize])
.collect();

let validity = if array.dtype().is_nullable() {
// For nullable arrays, a code of 0 indicates null value.
Validity::Array(compare(
codes.as_ref(),
ConstantArray::new(0u64, codes.len()).as_ref(),
Operator::Eq,
)?)
} else {
Validity::NonNullable
};

VarBinViewArray::try_new(
full_views.into(),
values.buffers().collect(),
array.dtype().clone(),
validity,
)
.map(Canonical::VarBinView)
}

fn canonicalize_primitive(array: DictArray) -> VortexResult<Canonical> {
let canonical_values: Array = array.values().into_canonical()?.into();
take(canonical_values, array.codes())?.into_canonical()
}

impl ArrayValidity for DictArray {
fn is_valid(&self, index: usize) -> bool {
let values_index = scalar_at(self.codes(), index)
35 changes: 35 additions & 0 deletions vortex-array/src/array/bool/compute/cast.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
use vortex_dtype::{DType, Nullability};
use vortex_error::{vortex_bail, VortexResult};

use crate::array::BoolArray;
use crate::compute::unary::CastFn;
use crate::validity::Validity;
use crate::{Array, ArrayDType, IntoArray};

/// Casting support for [`BoolArray`].
///
/// Only casts to a boolean `DType` are accepted, so the only thing a cast can
/// change is the nullability of the array.
impl CastFn for BoolArray {
    /// Cast this array to `dtype`.
    ///
    /// # Errors
    ///
    /// Fails when `dtype` is not a boolean type, when a nullable array whose
    /// validity is not `Validity::AllValid` is cast to a non-nullable type, or
    /// when `BoolArray::try_new` rejects the rebuilt array.
    fn cast(&self, dtype: &DType) -> VortexResult<Array> {
        if !dtype.is_boolean() {
            vortex_bail!("Cannot cast BoolArray to non-Bool type");
        }

        // Work out which validity the result must carry; the bit buffer itself is reused as-is.
        let target_validity = match (self.dtype().nullability(), dtype.nullability()) {
            // Nullability is unchanged, so the cast is a no-op.
            (Nullability::NonNullable, Nullability::NonNullable)
            | (Nullability::Nullable, Nullability::Nullable) => {
                return Ok(self.clone().into_array());
            }

            // Widening: a non-nullable array becomes nullable with every entry valid.
            (Nullability::NonNullable, Nullability::Nullable) => Validity::AllValid,

            // Narrowing: only permitted when the array reports no nulls.
            // NOTE(review): this compares against the `AllValid` variant; whether an
            // array-backed validity that happens to be all-true passes depends on
            // `Validity`'s equality implementation -- confirm.
            (Nullability::Nullable, Nullability::NonNullable) => {
                if self.validity() != Validity::AllValid {
                    vortex_bail!("cannot cast bool array with nulls as non-nullable");
                }
                Validity::NonNullable
            }
        };

        let recast = BoolArray::try_new(self.boolean_buffer(), target_validity)?;
        Ok(recast.into_array())
    }
}
7 changes: 6 additions & 1 deletion vortex-array/src/array/bool/compute/mod.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use crate::array::BoolArray;
use crate::compute::unary::{FillForwardFn, ScalarAtFn};
use crate::compute::unary::{CastFn, FillForwardFn, ScalarAtFn};
use crate::compute::{AndFn, ArrayCompute, OrFn, SliceFn, TakeFn};

mod boolean;

mod cast;
mod fill;
mod filter;
mod flatten;
@@ -12,6 +13,10 @@ mod slice;
mod take;

impl ArrayCompute for BoolArray {
fn cast(&self) -> Option<&dyn CastFn> {
Some(self)
}

fn fill_forward(&self) -> Option<&dyn FillForwardFn> {
Some(self)
}
2 changes: 1 addition & 1 deletion vortex-array/src/array/varbinview/mod.rs
Original file line number Diff line number Diff line change
@@ -2,12 +2,12 @@ use std::fmt::{Debug, Display, Formatter};
use std::slice;
use std::sync::Arc;

use ::serde::{Deserialize, Serialize};
use arrow_array::builder::{BinaryViewBuilder, GenericByteViewBuilder, StringViewBuilder};
use arrow_array::types::{BinaryViewType, ByteViewType, StringViewType};
use arrow_array::{ArrayRef, BinaryViewArray, GenericByteViewArray, StringViewArray};
use arrow_buffer::ScalarBuffer;
use itertools::Itertools;
use ::serde::{Deserialize, Serialize};
use static_assertions::{assert_eq_align, assert_eq_size};
use vortex_buffer::Buffer;
use vortex_dtype::{DType, PType};
15 changes: 15 additions & 0 deletions vortex-array/src/compute/compare.rs
Original file line number Diff line number Diff line change
@@ -78,6 +78,21 @@ pub trait MaybeCompareFn {
fn maybe_compare(&self, other: &Array, operator: Operator) -> Option<VortexResult<Array>>;
}

/// Binary comparison operation between two arrays.
///
/// The result of the comparison is a `Bool` typed Array holding `true` at each position where the
/// pair of operands satisfies the operator, and `false` otherwise.
///
/// Nullability of the result is the union of the nullabilities of the operands.
///
/// ## Null semantics
///
/// All binary comparison operations where one of the operands is `NULL` will result in a `NULL`
/// value being placed in the output.
///
/// This semantic is derived from [Apache Arrow's handling of nulls].
///
/// [Apache Arrow's handling of nulls]: https://arrow.apache.org/rust/arrow/compute/kernels/cmp/fn.eq.html
pub fn compare(
left: impl AsRef<Array>,
right: impl AsRef<Array>,

0 comments on commit 8d3ed4d

Please sign in to comment.