Skip to content

Commit

Permalink
fix case of f(dict_array, dict_array) invocation (#64)
Browse files Browse the repository at this point in the history
  • Loading branch information
davidhewitt authored Jan 15, 2025
1 parent 38caf97 commit cb0e669
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 47 deletions.
60 changes: 52 additions & 8 deletions src/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,20 @@ use std::sync::Arc;

use datafusion::arrow::array::{
Array, ArrayAccessor, ArrayRef, AsArray, DictionaryArray, Int64Array, LargeStringArray, PrimitiveArray,
StringArray, StringViewArray, UInt64Array, UnionArray,
StringArray, StringViewArray, UInt64Array,
};
use datafusion::arrow::compute::take;
use datafusion::arrow::datatypes::{
ArrowDictionaryKeyType, ArrowNativeType, ArrowPrimitiveType, DataType, Int64Type, UInt64Type,
ArrowDictionaryKeyType, ArrowNativeType, ArrowNativeTypeOp, DataType, Int64Type, UInt64Type,
};
use datafusion::arrow::downcast_dictionary_array;
use datafusion::common::{exec_err, plan_err, Result as DataFusionResult, ScalarValue};
use datafusion::logical_expr::ColumnarValue;
use jiter::{Jiter, JiterError, Peek};

use crate::common_union::{is_json_union, json_from_union_scalar, nested_json_array, TYPE_ID_NULL};
use crate::common_union::{
is_json_union, json_from_union_scalar, nested_json_array, nested_json_array_ref, TYPE_ID_NULL,
};

/// General implementation of `ScalarUDFImpl::return_type`.
///
Expand Down Expand Up @@ -95,6 +97,7 @@ impl From<i64> for JsonPath<'_> {
}
}

#[derive(Debug)]
enum JsonPathArgs<'a> {
Array(&'a ArrayRef),
Scalars(Vec<JsonPath<'a>>),
Expand Down Expand Up @@ -175,9 +178,48 @@ fn invoke_array_array<C: FromIterator<Option<I>> + 'static, I>(
) -> DataFusionResult<ArrayRef> {
downcast_dictionary_array!(
json_array => {
let values = invoke_array_array(json_array.values(), path_array, to_array, jiter_find, return_dict)?;
post_process_dict(json_array, values, return_dict)
}
fn wrap_as_dictionary<K: ArrowDictionaryKeyType>(original: &DictionaryArray<K>, new_values: ArrayRef) -> DictionaryArray<K> {
assert_eq!(original.keys().len(), new_values.len());
let mut key = K::Native::ZERO;
let key_range = std::iter::from_fn(move || {
let next = key;
key = key.add_checked(K::Native::ONE).expect("keys exhausted");
Some(next)
}).take(new_values.len());
let mut keys = PrimitiveArray::<K>::from_iter_values(key_range);
if is_json_union(new_values.data_type()) {
// JSON union: post-process the array to set keys to null where the union member is null
let type_ids = new_values.as_union().type_ids();
keys = mask_dictionary_keys(&keys, type_ids);
}
DictionaryArray::<K>::new(keys, new_values)
}

// TODO: in theory if path_array is _also_ a dictionary we could work out the unique key
// combinations and do less work, but this can be left as a future optimization
let output = match json_array.values().data_type() {
DataType::Utf8 => zip_apply(json_array.downcast_dict::<StringArray>().unwrap(), path_array, to_array, jiter_find),
DataType::LargeUtf8 => zip_apply(json_array.downcast_dict::<LargeStringArray>().unwrap(), path_array, to_array, jiter_find),
DataType::Utf8View => zip_apply(json_array.downcast_dict::<StringViewArray>().unwrap(), path_array, to_array, jiter_find),
other => if let Some(child_array) = nested_json_array_ref(json_array.values(), is_object_lookup_array(path_array.data_type())) {
// Horrible case: dict containing union as input with array for paths, figure
// out from the path type which union members we should access, repack the
// dictionary and then recurse.
//
// Use direct return because if return_dict applies, the recursion will handle it.
return invoke_array_array(&(Arc::new(json_array.with_values(child_array.clone())) as _), path_array, to_array, jiter_find, return_dict)
} else {
exec_err!("unexpected json array type {:?}", other)
}
}?;

if return_dict {
// ensure return is a dictionary to satisfy the declaration above in return_type_check
Ok(Arc::new(wrap_as_dictionary(json_array, output)))
} else {
Ok(output)
}
},
DataType::Utf8 => zip_apply(json_array.as_string::<i32>().iter(), path_array, to_array, jiter_find),
DataType::LargeUtf8 => zip_apply(json_array.as_string::<i64>().iter(), path_array, to_array, jiter_find),
DataType::Utf8View => zip_apply(json_array.as_string_view().iter(), path_array, to_array, jiter_find),
Expand Down Expand Up @@ -239,6 +281,7 @@ fn invoke_scalar_array<C: FromIterator<Option<I>> + 'static, I>(
to_array,
jiter_find,
)
// FIXME edge cases where scalar is wrapped in a dictionary, should return a dictionary?
.map(ColumnarValue::Array)
}

Expand All @@ -250,6 +293,7 @@ fn invoke_scalar_scalars<I>(
) -> DataFusionResult<ColumnarValue> {
let s = extract_json_scalar(scalar)?;
let v = jiter_find(s, path).ok();
// FIXME edge cases where scalar is wrapped in a dictionary, should return a dictionary?
Ok(ColumnarValue::Scalar(to_scalar(v)))
}

Expand Down Expand Up @@ -321,7 +365,7 @@ fn post_process_dict<T: ArrowDictionaryKeyType>(
if return_dict {
if is_json_union(result_values.data_type()) {
// JSON union: post-process the array to set keys to null where the union member is null
let type_ids = result_values.as_any().downcast_ref::<UnionArray>().unwrap().type_ids();
let type_ids = result_values.as_union().type_ids();
Ok(Arc::new(DictionaryArray::new(
mask_dictionary_keys(dict_array.keys(), type_ids),
result_values,
Expand Down Expand Up @@ -413,7 +457,7 @@ impl From<Utf8Error> for GetError {
///
/// That said, doing this might also be an optimization for cases like null-checking without needing
/// to check the value union array.
fn mask_dictionary_keys<K: ArrowPrimitiveType>(keys: &PrimitiveArray<K>, type_ids: &[i8]) -> PrimitiveArray<K> {
fn mask_dictionary_keys<K: ArrowDictionaryKeyType>(keys: &PrimitiveArray<K>, type_ids: &[i8]) -> PrimitiveArray<K> {
let mut null_mask = vec![true; keys.len()];
for (i, k) in keys.iter().enumerate() {
match k {
Expand Down
6 changes: 5 additions & 1 deletion src/common_union.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,13 @@ pub fn is_json_union(data_type: &DataType) -> bool {
/// * `object_lookup` - If `true`, extract from the "object" member of the union,
/// otherwise extract from the "array" member
pub(crate) fn nested_json_array(array: &ArrayRef, object_lookup: bool) -> Option<&StringArray> {
nested_json_array_ref(array, object_lookup).map(AsArray::as_string)
}

pub(crate) fn nested_json_array_ref(array: &ArrayRef, object_lookup: bool) -> Option<&ArrayRef> {
let union_array: &UnionArray = array.as_any().downcast_ref::<UnionArray>()?;
let type_id = if object_lookup { TYPE_ID_OBJECT } else { TYPE_ID_ARRAY };
union_array.child(type_id).as_any().downcast_ref()
Some(union_array.child(type_id))
}

/// Extract a JSON string from a `JsonUnion` scalar
Expand Down
Loading

0 comments on commit cb0e669

Please sign in to comment.