Skip to content

Commit

Permalink
Remove select with typed_simplify pass (#1929)
Browse files Browse the repository at this point in the history
  • Loading branch information
joseph-isaacs authored Jan 14, 2025
1 parent e4a9061 commit 6cd94fc
Show file tree
Hide file tree
Showing 6 changed files with 146 additions and 14 deletions.
35 changes: 34 additions & 1 deletion vortex-expr/src/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::sync::Arc;
use itertools::Itertools;
use vortex_array::ArrayData;
use vortex_dtype::FieldNames;
use vortex_error::{vortex_err, VortexResult};
use vortex_error::{vortex_bail, vortex_err, VortexResult};

use crate::field::DisplayFieldNames;
use crate::{ExprRef, VortexExpr};
Expand Down Expand Up @@ -51,6 +51,13 @@ impl Select {
pub fn child(&self) -> &ExprRef {
&self.child
}

pub fn as_include(&self, field_names: &FieldNames) -> VortexResult<ExprRef> {
Ok(Self::new_expr(
SelectField::Include(self.fields.as_include_names(field_names)?),
self.child.clone(),
))
}
}

impl SelectField {
Expand All @@ -64,12 +71,38 @@ impl SelectField {
Self::Exclude(columns)
}

pub fn is_include(&self) -> bool {
matches!(self, Self::Include(_))
}

pub fn is_exclude(&self) -> bool {
matches!(self, Self::Exclude(_))
}

pub fn fields(&self) -> &FieldNames {
match self {
SelectField::Include(fields) => fields,
SelectField::Exclude(fields) => fields,
}
}

pub fn as_include_names(&self, field_names: &FieldNames) -> VortexResult<FieldNames> {
if self.fields().iter().any(|f| !field_names.contains(f)) {
vortex_bail!(
"Field {:?} in select not in field names {:?}",
self,
field_names
);
}
match self {
SelectField::Include(fields) => Ok(fields.clone()),
SelectField::Exclude(exc_fields) => Ok(field_names
.iter()
.filter(|f| exc_fields.contains(f))
.cloned()
.collect()),
}
}
}

impl Display for SelectField {
Expand Down
2 changes: 2 additions & 0 deletions vortex-expr/src/transform/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
//! A collection of transformations that can be applied to a [`crate::ExprRef`].
pub mod partition;
pub(crate) mod remove_select;
pub mod simplify;
pub mod simplify_typed;
21 changes: 10 additions & 11 deletions vortex-expr/src/transform/partition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,7 @@ mod tests {

use super::*;
use crate::transform::simplify::simplify;
use crate::transform::simplify_typed::simplify_typed;
use crate::{and, get_item, ident, lit, pack, select, Pack};

fn struct_dtype() -> StructDType {
Expand Down Expand Up @@ -407,35 +408,33 @@ mod tests {
assert_eq!(partitioned.partitions.len(), 2);
}

// Test that typed_simplify removes select and partition precise
#[test]
fn test_expr_partition_many_occurances_of_field() {
fn test_expr_partition_many_occurrences_of_field() {
let dtype = struct_dtype();

let expr = and(
get_item("b", get_item("a", ident())),
select(vec!["a".into(), "b".into()], ident()),
);
let expr = simplify_typed(expr, DType::Struct(dtype.clone(), NonNullable)).unwrap();
let partitioned = StructFieldExpressionSplitter::split(expr, &dtype).unwrap();

// One for id.a and id.b
assert_eq!(partitioned.partitions.len(), 3);
assert_eq!(partitioned.partitions.len(), 2);

// This fetches [].$c which is unused, however a previous optimisation should replace select
// with get_item and pack removing this field.
assert_eq!(
&partitioned.root,
&and(
get_item("0", get_item("a", ident())),
select(
pack(
vec!["a".into(), "b".into()],
pack(
vec!["a".into(), "b".into(), "c".into()],
vec![
get_item("1", get_item("a", ident())),
get_item("0", get_item("b", ident())),
get_item("0", get_item("c", ident())),
]
)
vec![
get_item("1", get_item("a", ident())),
get_item("0", get_item("b", ident())),
]
)
)
)
Expand Down
84 changes: 84 additions & 0 deletions vortex-expr/src/transform/remove_select.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
use itertools::Itertools;
use vortex_dtype::DType;
use vortex_error::{vortex_err, VortexResult};

use crate::traversal::{MutNodeVisitor, Node, TransformResult};
use crate::{get_item, pack, ExprRef, Select};

/// Select is a useful expression, however it can be defined in terms of get_item & pack,
/// once the expression type is known, this simplifications pass removes the select expression.
pub fn remove_select(e: ExprRef, scope_dt: DType) -> VortexResult<ExprRef> {
let mut transform = RemoveSelectTransform::new(scope_dt);
e.transform(&mut transform).map(|e| e.result)
}

struct RemoveSelectTransform {
ident_dtype: DType,
}

impl RemoveSelectTransform {
fn new(ident_dtype: DType) -> Self {
Self { ident_dtype }
}
}

impl MutNodeVisitor for RemoveSelectTransform {
type NodeTy = ExprRef;

fn visit_up(&mut self, node: ExprRef) -> VortexResult<TransformResult<Self::NodeTy>> {
if let Some(select) = node.as_any().downcast_ref::<Select>() {
let child = select.child();
let child_dtype = child.return_dtype(&self.ident_dtype)?;
let child_dtype = child_dtype.as_struct().ok_or_else(|| {
vortex_err!(
"Select child must return a struct dtype, however it was a {}",
child_dtype
)
})?;

let names = select
.fields()
.as_include_names(child_dtype.names())
.map_err(|e| {
vortex_err!(
"Select fields must be a subset of child fields, however {}",
e
)
})?;

let pack_children = names
.iter()
.map(|name| get_item(name.clone(), child.clone()))
.collect_vec();

Ok(TransformResult::yes(pack(names, pack_children)))
} else {
Ok(TransformResult::no(node))
}
}
}

#[cfg(test)]
mod tests {
use vortex_dtype::Nullability::NonNullable;
use vortex_dtype::PType::I32;
use vortex_dtype::{DType, StructDType};

use crate::transform::remove_select::remove_select;
use crate::{ident, select, Pack};

#[test]
fn test_remove_select() {
let dtype = DType::Struct(
StructDType::new(
["a".into(), "b".into()].into(),
vec![I32.into(), I32.into()],
),
NonNullable,
);
let e = select(["a".into(), "b".into()], ident());
let e = remove_select(e, dtype).unwrap();

assert!(e.as_any().downcast_ref::<Pack>().is_some());
}
}
14 changes: 14 additions & 0 deletions vortex-expr/src/transform/simplify_typed.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
use vortex_dtype::DType;
use vortex_error::VortexResult;

use crate::transform::remove_select::remove_select;
use crate::transform::simplify::simplify;
use crate::ExprRef;

/// This pass simplifies an expression under the assumption that ident()/scope as a fixed DType.
/// There is another pass `simplify` that simplifies an expression without any assumptions.
/// This pass also applies simplify.
pub fn simplify_typed(e: ExprRef, scope_dt: DType) -> VortexResult<ExprRef> {
let e = simplify(e)?;
remove_select(e, scope_dt)
}
4 changes: 2 additions & 2 deletions vortex-expr/src/traversal/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -290,8 +290,8 @@ impl Node for ExprRef {
};

if ord == TraversalOrder::Continue {
let up = visitor.visit_up(self)?;
Ok(TransformResult::yes(up.result.replacing_children(children)))
let up = visitor.visit_up(self.replacing_children(children))?;
Ok(TransformResult::yes(up.result))
} else {
Ok(TransformResult {
result: self.replacing_children(children),
Expand Down

0 comments on commit 6cd94fc

Please sign in to comment.