diff --git a/Cargo.toml b/Cargo.toml index fe3ea020..5351475a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,7 +30,7 @@ integer-encoding = "3.0.4" strum_macros = "0.24" ordered-float = "3.0" petgraph = "0.6.3" -futures-async-stream = "0.2.6" +futures-async-stream = "0.2.9" futures = "0.3.25" ahash = "0.8.3" lazy_static = "1.4.0" @@ -39,6 +39,7 @@ bytes = "1.5.0" kip_db = "0.1.2-alpha.17" rust_decimal = "1" csv = "1" +regex = "1.10.2" [dev-dependencies] tokio-test = "0.4.2" diff --git a/README.md b/README.md index 62638a13..501a2494 100755 --- a/README.md +++ b/README.md @@ -1,12 +1,27 @@ -# KipSQL +
+Built by @KipData
+
+██╗  ██╗██╗██████╗ ███████╗ ██████╗ ██╗
+██║ ██╔╝██║██╔══██╗██╔════╝██╔═══██╗██║
+█████╔╝ ██║██████╔╝███████╗██║   ██║██║
+██╔═██╗ ██║██╔═══╝ ╚════██║██║▄▄ ██║██║
+██║  ██╗██║██║     ███████║╚██████╔╝███████╗
+╚═╝  ╚═╝╚═╝╚═╝     ╚══════╝ ╚══▀▀═╝ ╚══════╝
+-----------------------------------
+Embedded SQL DBMS
+
+
+
+### Architecture
+Welcome to our website, powered by KipSQL:
+**http://www.kipdata.site/**

 > Lightweight SQL calculation engine, as the SQL layer of KipDB, implemented with TalentPlan's TinySQL as the reference standard

-### Architecture
+
 ![architecture](./static/images/architecture.png)

 ### Get Started
-#### Component Import
 ``` toml
 kip-sql = "0.0.1-alpha.0"
 ```
@@ -79,6 +94,12 @@ implement_from_tuple!(Post, (
 - not null
 - null
 - unique
+ - primary key
+- SQL where options
+ - is null
+ - is not null
+ - like
+ - not like
 - Supports index type
 - Unique Index
 - Supports multiple primary key types
diff --git a/rust-toolchain b/rust-toolchain
index c34ab853..07ade694 100644
--- a/rust-toolchain
+++ b/rust-toolchain
@@ -1 +1 @@
-nightly-2023-10-13
\ No newline at end of file
+nightly
\ No newline at end of file
diff --git a/src/binder/aggregate.rs b/src/binder/aggregate.rs
index 9e65e0ea..b8523ac7 100644
--- a/src/binder/aggregate.rs
+++ b/src/binder/aggregate.rs
@@ -92,6 +92,8 @@ impl<'a, T: Transaction> Binder<'a, T> {
         expr: &mut ScalarExpression,
         is_select: bool,
     ) -> Result<(), BindError> {
+        let ref_columns = expr.referenced_columns();
+
         match expr {
             ScalarExpression::AggCall {
                 ty: return_type, ..
@@ -99,7 +101,11 @@
             } => {
                 let ty = return_type.clone();
                 if is_select {
                     let index = self.context.input_ref_index(InputRefType::AggCall);
-                    let input_ref = ScalarExpression::InputRef { index, ty };
+                    let input_ref = ScalarExpression::InputRef {
+                        index,
+                        ty,
+                        ref_columns,
+                    };
                     match std::mem::replace(expr, input_ref) {
                         ScalarExpression::AggCall {
                             kind,
@@ -124,14 +130,21 @@
                         .find_position(|agg_expr| agg_expr == &expr)
                         .ok_or_else(|| BindError::AggMiss(format!("{:?}", expr)))?;

-                    let _ = std::mem::replace(expr, ScalarExpression::InputRef { index, ty });
+                    let _ = std::mem::replace(
+                        expr,
+                        ScalarExpression::InputRef {
+                            index,
+                            ty,
+                            ref_columns,
+                        },
+                    );
                 }
             }
             ScalarExpression::TypeCast { expr, .. } => {
                 self.visit_column_agg_expr(expr, is_select)?
             }
-            ScalarExpression::IsNull { expr } => self.visit_column_agg_expr(expr, is_select)?,
+            ScalarExpression::IsNull { expr, .. } => self.visit_column_agg_expr(expr, is_select)?,
             ScalarExpression::Unary { expr, .. } => self.visit_column_agg_expr(expr, is_select)?,
             ScalarExpression::Alias { expr, .. } => self.visit_column_agg_expr(expr, is_select)?,
             ScalarExpression::Binary {
                 left_expr,
@@ -228,6 +241,7 @@ impl<'a, T: Transaction> Binder<'a, T> {
         }) {
             let index = self.context.input_ref_index(InputRefType::GroupBy);
             let mut select_item = &mut select_list[i];
+            let ref_columns = select_item.referenced_columns();
             let return_type = select_item.return_type();

             self.context.group_by_exprs.push(std::mem::replace(
@@ -235,6 +249,7 @@
                 ScalarExpression::InputRef {
                     index,
                     ty: return_type,
+                    ref_columns,
                 },
             ));
             return;
@@ -243,6 +258,8 @@
         if let Some(i) = select_list.iter().position(|column| column == expr) {
             let expr = &mut select_list[i];

+            let ref_columns = expr.referenced_columns();
+
             match expr {
                 ScalarExpression::Constant(_) | ScalarExpression::ColumnRef { .. } => {
                     self.context.group_by_exprs.push(expr.clone())
@@ -255,6 +272,7 @@
                     ScalarExpression::InputRef {
                         index,
                         ty: expr.return_type(),
+                        ref_columns,
                     },
                 ))
             }
@@ -300,7 +318,7 @@
             }
             ScalarExpression::TypeCast { expr, ..
} => self.validate_having_orderby(expr), - ScalarExpression::IsNull { expr } => self.validate_having_orderby(expr), + ScalarExpression::IsNull { expr, .. } => self.validate_having_orderby(expr), ScalarExpression::Unary { expr, .. } => self.validate_having_orderby(expr), ScalarExpression::Binary { left_expr, diff --git a/src/binder/create_table.rs b/src/binder/create_table.rs index bba6497f..65aeb8f3 100644 --- a/src/binder/create_table.rs +++ b/src/binder/create_table.rs @@ -80,13 +80,13 @@ mod tests { match plan1.operator { Operator::CreateTable(op) => { assert_eq!(op.table_name, Arc::new("t1".to_string())); - assert_eq!(op.columns[0].name, "id".to_string()); + assert_eq!(op.columns[0].name(), "id"); assert_eq!(op.columns[0].nullable, false); assert_eq!( op.columns[0].desc, ColumnDesc::new(LogicalType::Integer, true, false) ); - assert_eq!(op.columns[1].name, "name".to_string()); + assert_eq!(op.columns[1].name(), "name"); assert_eq!(op.columns[1].nullable, true); assert_eq!( op.columns[1].desc, diff --git a/src/binder/expr.rs b/src/binder/expr.rs index 628a1db6..01bd21c2 100644 --- a/src/binder/expr.rs +++ b/src/binder/expr.rs @@ -1,4 +1,5 @@ use crate::binder::BindError; +use crate::expression; use crate::expression::agg::AggKind; use itertools::Itertools; use sqlparser::ast::{ @@ -25,13 +26,41 @@ impl<'a, T: Transaction> Binder<'a, T> { Expr::Function(func) => self.bind_agg_call(func), Expr::Nested(expr) => self.bind_expr(expr), Expr::UnaryOp { expr, op } => self.bind_unary_op_internal(expr, op), - Expr::IsNull(expr) => self.bind_is_null(expr), + Expr::Like { + negated, + expr, + pattern, + .. + } => self.bind_like(*negated, expr, pattern), + Expr::IsNull(expr) => self.bind_is_null(expr, false), + Expr::IsNotNull(expr) => self.bind_is_null(expr, true), _ => { todo!() } } } + pub fn bind_like( + &mut self, + negated: bool, + expr: &Expr, + pattern: &Expr, + ) -> Result { + let left_expr = Box::new(self.bind_expr(expr)?); + let right_expr = Box::new(self.bind_expr(pattern)?); + let op = if negated { + expression::BinaryOperator::NotLike + } else { + expression::BinaryOperator::Like + }; + Ok(ScalarExpression::Binary { + op, + left_expr, + right_expr, + ty: LogicalType::Boolean, + }) + } + pub fn bind_column_ref_from_identifiers( &mut self, idents: &[Ident], @@ -199,8 +228,9 @@ impl<'a, T: Transaction> Binder<'a, T> { }) } - fn bind_is_null(&mut self, expr: &Expr) -> Result { + fn bind_is_null(&mut self, expr: &Expr, negated: bool) -> Result { Ok(ScalarExpression::IsNull { + negated, expr: Box::new(self.bind_expr(expr)?), }) } diff --git a/src/binder/select.rs b/src/binder/select.rs index 2001bd48..cb8868d8 100644 --- a/src/binder/select.rs +++ b/src/binder/select.rs @@ -343,9 +343,7 @@ impl<'a, T: Transaction> Binder<'a, T> { select_list: Vec, ) -> LogicalPlan { LogicalPlan { - operator: Operator::Project(ProjectOperator { - columns: select_list, - }), + operator: Operator::Project(ProjectOperator { exprs: select_list }), childrens: vec![children], } } @@ -431,7 +429,8 @@ impl<'a, T: Transaction> Binder<'a, T> { for column in select_items { if let ScalarExpression::ColumnRef(col) = column { - if let Some(nullable) = table_force_nullable.get(col.table_name.as_ref().unwrap()) { + if let Some(nullable) = table_force_nullable.get(col.table_name().as_ref().unwrap()) + { let mut new_col = ColumnCatalog::clone(col); new_col.nullable = *nullable; @@ -504,12 +503,12 @@ impl<'a, T: Transaction> Binder<'a, T> { // example: foo = bar (ScalarExpression::ColumnRef(l), ScalarExpression::ColumnRef(r)) 
=> {
                     // reorder left and right join keys to the pattern: (left, right)
-                    if left_schema.contains_column(&l.name)
-                        && right_schema.contains_column(&r.name)
+                    if left_schema.contains_column(l.name())
+                        && right_schema.contains_column(r.name())
                     {
                         accum.push((left, right));
-                    } else if left_schema.contains_column(&r.name)
-                        && right_schema.contains_column(&l.name)
+                    } else if left_schema.contains_column(r.name())
+                        && right_schema.contains_column(l.name())
                     {
                         accum.push((right, left));
                     } else {
diff --git a/src/catalog/column.rs b/src/catalog/column.rs
index c3206160..4e73de6c 100644
--- a/src/catalog/column.rs
+++ b/src/catalog/column.rs
@@ -2,6 +2,7 @@ use crate::catalog::TableName;
 use crate::expression::ScalarExpression;
 use serde::{Deserialize, Serialize};
 use sqlparser::ast::{ColumnDef, ColumnOption};
+use std::hash::Hash;
 use std::sync::Arc;

 use crate::types::{ColumnId, LogicalType};
@@ -10,14 +11,19 @@
 pub type ColumnRef = Arc<ColumnCatalog>;

 #[derive(Debug, Clone, Serialize, Deserialize, Hash, Eq, PartialEq)]
 pub struct ColumnCatalog {
-    pub id: Option<ColumnId>,
-    pub name: String,
-    pub table_name: Option<TableName>,
+    pub summary: ColumnSummary,
     pub nullable: bool,
     pub desc: ColumnDesc,
     pub ref_expr: Option<ScalarExpression>,
 }

+#[derive(Debug, Clone, Serialize, Deserialize, Hash, Eq, PartialEq)]
+pub struct ColumnSummary {
+    pub id: Option<ColumnId>,
+    pub name: String,
+    pub table_name: Option<TableName>,
+}
+
 impl ColumnCatalog {
     pub(crate) fn new(
         column_name: String,
@@ -26,9 +32,11 @@
         ref_expr: Option<ScalarExpression>,
     ) -> ColumnCatalog {
         ColumnCatalog {
-            id: None,
-            name: column_name,
-            table_name: None,
+            summary: ColumnSummary {
+                id: None,
+                name: column_name,
+                table_name: None,
+            },
             nullable,
             desc: column_desc,
             ref_expr,
@@ -37,20 +45,39 @@
     pub(crate) fn new_dummy(column_name: String) -> ColumnCatalog {
         ColumnCatalog {
-            id: Some(0),
-            name: column_name,
-            table_name: None,
+            summary: ColumnSummary {
+                id: Some(0),
+                name: column_name,
+                table_name: None,
+            },
             nullable: false,
             desc: ColumnDesc::new(LogicalType::Varchar(None), false, false),
             ref_expr: None,
         }
     }

+    pub(crate) fn summary(&self) -> &ColumnSummary {
+        &self.summary
+    }
+
+    pub(crate) fn id(&self) -> Option<ColumnId> {
+        self.summary.id
+    }
+
+    pub(crate) fn table_name(&self) -> Option<TableName> {
+        self.summary.table_name.clone()
+    }
+
+    pub(crate) fn name(&self) -> &str {
+        &self.summary.name
+    }
+
     pub(crate) fn datatype(&self) -> &LogicalType {
         &self.desc.column_datatype
     }

-    pub fn desc(&self) -> &ColumnDesc {
+    #[allow(dead_code)]
+    pub(crate) fn desc(&self) -> &ColumnDesc {
         &self.desc
     }
 }
diff --git a/src/catalog/table.rs b/src/catalog/table.rs
index 215e2438..92268f61 100644
--- a/src/catalog/table.rs
+++ b/src/catalog/table.rs
@@ -38,7 +38,7 @@ impl TableCatalog {
         self.columns.get(id)
     }

-    pub(crate) fn contains_column(&self, name: &String) -> bool {
+    pub(crate) fn contains_column(&self, name: &str) -> bool {
         self.column_idxs.contains_key(name)
     }

@@ -55,15 +55,15 @@
     /// Add a column to the table catalog.
pub(crate) fn add_column(&mut self, mut col: ColumnCatalog) -> Result { - if self.column_idxs.contains_key(&col.name) { - return Err(CatalogError::Duplicated("column", col.name.clone())); + if self.column_idxs.contains_key(col.name()) { + return Err(CatalogError::Duplicated("column", col.name().to_string())); } let col_id = self.columns.len() as u32; - col.id = Some(col_id); - col.table_name = Some(self.name.clone()); - self.column_idxs.insert(col.name.clone(), col_id); + col.summary.id = Some(col_id); + col.summary.table_name = Some(self.name.clone()); + self.column_idxs.insert(col.name().to_string(), col_id); self.columns.insert(col_id, Arc::new(col)); Ok(col_id) @@ -148,11 +148,11 @@ mod tests { assert!(col_a_id < col_b_id); let column_catalog = table_catalog.get_column_by_id(&col_a_id).unwrap(); - assert_eq!(column_catalog.name, "a"); + assert_eq!(column_catalog.name(), "a"); assert_eq!(*column_catalog.datatype(), LogicalType::Integer,); let column_catalog = table_catalog.get_column_by_id(&col_b_id).unwrap(); - assert_eq!(column_catalog.name, "b"); + assert_eq!(column_catalog.name(), "b"); assert_eq!(*column_catalog.datatype(), LogicalType::Boolean,); } } diff --git a/src/db.rs b/src/db.rs index 813168bf..1b94a885 100644 --- a/src/db.rs +++ b/src/db.rs @@ -44,7 +44,6 @@ impl Database { if stmts.is_empty() { return Ok(vec![]); } - let binder = Binder::new(BinderContext::new(&transaction)); /// Build a logical plan. @@ -71,10 +70,15 @@ impl Database { fn default_optimizer(source_plan: LogicalPlan) -> HepOptimizer { HepOptimizer::new(source_plan) + .batch( + "Column Pruning".to_string(), + HepBatchStrategy::once_topdown(), + vec![RuleImpl::ColumnPruning], + ) .batch( "Simplify Filter".to_string(), HepBatchStrategy::fix_point_topdown(10), - vec![RuleImpl::SimplifyFilter], + vec![RuleImpl::SimplifyFilter, RuleImpl::ConstantCalculation], ) .batch( "Predicate Pushdown".to_string(), @@ -89,14 +93,6 @@ impl Database { HepBatchStrategy::fix_point_topdown(10), vec![RuleImpl::CollapseProject, RuleImpl::CombineFilter], ) - .batch( - "Column Pruning".to_string(), - HepBatchStrategy::fix_point_topdown(10), - vec![ - RuleImpl::PushProjectThroughChild, - RuleImpl::PushProjectIntoScan, - ], - ) .batch( "Limit Pushdown".to_string(), HepBatchStrategy::fix_point_topdown(10), @@ -199,7 +195,7 @@ mod test { let _ = kipsql .run("create table t2 (c int primary key, d int unsigned null, e datetime)") .await?; - let _ = kipsql.run("insert into t1 (a, b, k, z) values (-99, 1, 1, 'k'), (-1, 2, 2, 'i'), (5, 3, 2, 'p')").await?; + let _ = kipsql.run("insert into t1 (a, b, k, z) values (-99, 1, 1, 'k'), (-1, 2, 2, 'i'), (5, 3, 2, 'p'), (29, 4, 2, 'db')").await?; let _ = kipsql.run("insert into t2 (d, c, e) values (2, 1, '2021-05-20 21:00:00'), (3, 4, '2023-09-10 00:00:00')").await?; let _ = kipsql .run("create table t3 (a int primary key, b decimal(4,2))") @@ -231,6 +227,18 @@ mod test { let tuples_projection_and_sort = kipsql.run("select * from t1 order by a, b").await?; println!("{}", create_table(&tuples_projection_and_sort)); + println!("like t1 1:"); + let tuples_like_1_t1 = kipsql.run("select * from t1 where z like '%k'").await?; + println!("{}", create_table(&tuples_like_1_t1)); + + println!("like t1 2:"); + let tuples_like_2_t1 = kipsql.run("select * from t1 where z like '_b'").await?; + println!("{}", create_table(&tuples_like_2_t1)); + + println!("not like t1:"); + let tuples_not_like_t1 = kipsql.run("select * from t1 where z not like '_b'").await?; + println!("{}", create_table(&tuples_not_like_t1)); + 
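+        // An extra illustrative probe (a hypothetical addition, not required for
+        // the feature): `%` matches any run of characters and `_` exactly one,
+        // per the rewrite in value_compute.rs (`%` -> `.*`, `_` -> `.`), so
+        // '%i%' should match the single-character value 'i'.
+        println!("like t1 3:");
+        let tuples_like_3_t1 = kipsql.run("select * from t1 where z like '%i%'").await?;
+        println!("{}", create_table(&tuples_like_3_t1));
+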
println!("limit:"); let tuples_limit = kipsql.run("select * from t1 limit 1 offset 1").await?; println!("{}", create_table(&tuples_limit)); diff --git a/src/execution/executor/dml/copy_from_file.rs b/src/execution/executor/dml/copy_from_file.rs index 341fb974..10e68ac5 100644 --- a/src/execution/executor/dml/copy_from_file.rs +++ b/src/execution/executor/dml/copy_from_file.rs @@ -111,7 +111,7 @@ fn return_result(size: usize, tx: Sender) -> Result<(), ExecutorError> { #[cfg(test)] mod tests { - use crate::catalog::{ColumnCatalog, ColumnDesc}; + use crate::catalog::{ColumnCatalog, ColumnDesc, ColumnSummary}; use crate::db::{Database, DatabaseError}; use futures::StreamExt; use std::io::Write; @@ -132,25 +132,31 @@ mod tests { let columns = vec![ Arc::new(ColumnCatalog { - id: Some(0), - name: "a".to_string(), - table_name: None, + summary: ColumnSummary { + id: Some(0), + name: "a".to_string(), + table_name: None, + }, nullable: false, desc: ColumnDesc::new(LogicalType::Integer, true, false), ref_expr: None, }), Arc::new(ColumnCatalog { - id: Some(1), - name: "b".to_string(), - table_name: None, + summary: ColumnSummary { + id: Some(1), + name: "b".to_string(), + table_name: None, + }, nullable: false, desc: ColumnDesc::new(LogicalType::Float, false, false), ref_expr: None, }), Arc::new(ColumnCatalog { - id: Some(1), - name: "c".to_string(), - table_name: None, + summary: ColumnSummary { + id: Some(1), + name: "c".to_string(), + table_name: None, + }, nullable: false, desc: ColumnDesc::new(LogicalType::Varchar(Some(10)), false, false), ref_expr: None, diff --git a/src/execution/executor/dml/delete.rs b/src/execution/executor/dml/delete.rs index d85c2549..a6b60540 100644 --- a/src/execution/executor/dml/delete.rs +++ b/src/execution/executor/dml/delete.rs @@ -39,7 +39,7 @@ impl Delete { col.desc .is_unique .then(|| { - col.id.and_then(|col_id| { + col.id().and_then(|col_id| { table_catalog .get_unique_index(&col_id) .map(|index_meta| (i, index_meta.clone())) diff --git a/src/execution/executor/dml/insert.rs b/src/execution/executor/dml/insert.rs index a4f5e1a0..34746825 100644 --- a/src/execution/executor/dml/insert.rs +++ b/src/execution/executor/dml/insert.rs @@ -62,7 +62,7 @@ impl Insert { for (i, value) in values.into_iter().enumerate() { let col = &columns[i]; - if let Some(col_id) = col.id { + if let Some(col_id) = col.id() { tuple_map.insert(col_id, value); } } @@ -70,7 +70,7 @@ impl Insert { columns .iter() .find(|col| col.desc.is_primary) - .map(|col| col.id.unwrap()) + .map(|col| col.id().unwrap()) .unwrap() }); let all_columns = table_catalog.all_columns_with_id(); @@ -87,7 +87,7 @@ impl Insert { if col.desc.is_unique && !value.is_null() { unique_values - .entry(col.id) + .entry(col.id()) .or_insert_with(|| vec![]) .push((tuple_id.clone(), value.clone())) } diff --git a/src/execution/executor/dml/update.rs b/src/execution/executor/dml/update.rs index ee48a87a..411753c6 100644 --- a/src/execution/executor/dml/update.rs +++ b/src/execution/executor/dml/update.rs @@ -56,7 +56,7 @@ impl Update { columns, values, .. 
} = tuple?; for i in 0..columns.len() { - value_map.insert(columns[i].id, values[i].clone()); + value_map.insert(columns[i].id(), values[i].clone()); } } #[for_await] @@ -65,7 +65,7 @@ impl Update { let mut is_overwrite = true; for (i, column) in tuple.columns.iter().enumerate() { - if let Some(value) = value_map.get(&column.id) { + if let Some(value) = value_map.get(&column.id()) { if column.desc.is_primary { let old_key = tuple.id.replace(value.clone()).unwrap(); @@ -74,7 +74,7 @@ impl Update { } if column.desc.is_unique && value != &tuple.values[i] { if let Some(index_meta) = - table_catalog.get_unique_index(&column.id.unwrap()) + table_catalog.get_unique_index(&column.id().unwrap()) { let mut index = Index { id: index_meta.id, diff --git a/src/execution/executor/dql/projection.rs b/src/execution/executor/dql/projection.rs index 8285896a..d6c21b92 100644 --- a/src/execution/executor/dql/projection.rs +++ b/src/execution/executor/dql/projection.rs @@ -13,11 +13,8 @@ pub struct Projection { } impl From<(ProjectOperator, BoxedExecutor)> for Projection { - fn from((ProjectOperator { columns }, input): (ProjectOperator, BoxedExecutor)) -> Self { - Projection { - exprs: columns, - input, - } + fn from((ProjectOperator { exprs }, input): (ProjectOperator, BoxedExecutor)) -> Self { + Projection { exprs, input } } } diff --git a/src/expression/evaluator.rs b/src/expression/evaluator.rs index c233ed23..9288ec43 100644 --- a/src/expression/evaluator.rs +++ b/src/expression/evaluator.rs @@ -16,7 +16,7 @@ impl ScalarExpression { match &self { ScalarExpression::Constant(val) => Ok(val.clone()), ScalarExpression::ColumnRef(col) => { - let value = Self::eval_with_name(&tuple, &col.name) + let value = Self::eval_with_name(&tuple, col.name()) .unwrap_or(&NULL_VALUE) .clone(); @@ -46,10 +46,13 @@ impl ScalarExpression { Ok(Arc::new(binary_op(&left, &right, op)?)) } - ScalarExpression::IsNull { expr } => { - let value = expr.eval_column(tuple)?; + ScalarExpression::IsNull { expr, negated } => { + let mut value = expr.eval_column(tuple)?.is_null(); - Ok(Arc::new(DataValue::Boolean(Some(value.is_null())))) + if *negated { + value = !value; + } + Ok(Arc::new(DataValue::Boolean(Some(value)))) } ScalarExpression::Unary { expr, op, .. } => { let value = expr.eval_column(tuple)?; @@ -60,11 +63,11 @@ impl ScalarExpression { } } - fn eval_with_name<'a>(tuple: &'a Tuple, name: &String) -> Option<&'a ValueRef> { + fn eval_with_name<'a>(tuple: &'a Tuple, name: &str) -> Option<&'a ValueRef> { tuple .columns .iter() - .find_position(|tul_col| &tul_col.name == name) + .find_position(|tul_col| tul_col.name() == name) .map(|(i, _)| &tuple.values[i]) } } diff --git a/src/expression/mod.rs b/src/expression/mod.rs index 6c617fd8..282649af 100644 --- a/src/expression/mod.rs +++ b/src/expression/mod.rs @@ -30,6 +30,7 @@ pub enum ScalarExpression { InputRef { index: usize, ty: LogicalType, + ref_columns: Vec, }, Alias { expr: Box, @@ -40,6 +41,7 @@ pub enum ScalarExpression { ty: LogicalType, }, IsNull { + negated: bool, expr: Box, }, Unary { @@ -70,6 +72,23 @@ impl ScalarExpression { } } + pub fn has_count_star(&self) -> bool { + match self { + ScalarExpression::InputRef { ref_columns, .. } => ref_columns.is_empty(), + ScalarExpression::Alias { expr, .. } => expr.has_count_star(), + ScalarExpression::TypeCast { expr, .. } => expr.has_count_star(), + ScalarExpression::IsNull { expr, .. } => expr.has_count_star(), + ScalarExpression::Unary { expr, .. 
} => expr.has_count_star(), + ScalarExpression::Binary { + left_expr, + right_expr, + .. + } => left_expr.has_count_star() || right_expr.has_count_star(), + ScalarExpression::AggCall { args, .. } => args.iter().any(Self::has_count_star), + _ => false, + } + } + pub fn nullable(&self) -> bool { match self { ScalarExpression::Constant(_) => false, @@ -77,7 +96,7 @@ impl ScalarExpression { ScalarExpression::InputRef { .. } => unreachable!(), ScalarExpression::Alias { expr, .. } => expr.nullable(), ScalarExpression::TypeCast { expr, .. } => expr.nullable(), - ScalarExpression::IsNull { expr } => expr.nullable(), + ScalarExpression::IsNull { expr, .. } => expr.nullable(), ScalarExpression::Unary { expr, .. } => expr.nullable(), ScalarExpression::Binary { left_expr, @@ -135,6 +154,9 @@ impl ScalarExpression { columns_collect(expr, vec) } } + ScalarExpression::InputRef { ref_columns, .. } => { + vec.extend_from_slice(ref_columns); + } _ => (), } } @@ -187,7 +209,7 @@ impl ScalarExpression { } => { let args_str = args .iter() - .map(|expr| expr.output_columns(tuple).name.clone()) + .map(|expr| expr.output_columns(tuple).name().to_string()) .join(", "); let op = |allow_distinct, distinct| { if allow_distinct && distinct { @@ -219,9 +241,9 @@ impl ScalarExpression { } => { let column_name = format!( "({} {} {})", - left_expr.output_columns(tuple).name, + left_expr.output_columns(tuple).name(), op, - right_expr.output_columns(tuple).name, + right_expr.output_columns(tuple).name(), ); Arc::new(ColumnCatalog::new( @@ -232,7 +254,7 @@ impl ScalarExpression { )) } ScalarExpression::Unary { expr, op, ty } => { - let column_name = format!("{} {}", op, expr.output_columns(tuple).name,); + let column_name = format!("{} {}", op, expr.output_columns(tuple).name()); Arc::new(ColumnCatalog::new( column_name, true, @@ -280,6 +302,8 @@ pub enum BinaryOperator { Spaceship, Eq, NotEq, + Like, + NotLike, And, Or, @@ -305,6 +329,8 @@ impl fmt::Display for BinaryOperator { BinaryOperator::And => write!(f, "&&"), BinaryOperator::Or => write!(f, "||"), BinaryOperator::Xor => write!(f, "^"), + BinaryOperator::Like => write!(f, "like"), + BinaryOperator::NotLike => write!(f, "not like"), } } } diff --git a/src/expression/simplify.rs b/src/expression/simplify.rs index 25a7be94..53537f01 100644 --- a/src/expression/simplify.rs +++ b/src/expression/simplify.rs @@ -266,10 +266,10 @@ struct ReplaceUnary { impl ScalarExpression { pub fn exist_column(&self, col_id: &ColumnId) -> bool { match self { - ScalarExpression::ColumnRef(col) => col.id == Some(*col_id), + ScalarExpression::ColumnRef(col) => col.id() == Some(*col_id), ScalarExpression::Alias { expr, .. } => expr.exist_column(col_id), ScalarExpression::TypeCast { expr, .. } => expr.exist_column(col_id), - ScalarExpression::IsNull { expr } => expr.exist_column(col_id), + ScalarExpression::IsNull { expr, .. } => expr.exist_column(col_id), ScalarExpression::Unary { expr, .. } => expr.exist_column(col_id), ScalarExpression::Binary { left_expr, @@ -287,7 +287,7 @@ impl ScalarExpression { ScalarExpression::TypeCast { expr, ty, .. } => expr .unpack_val() .and_then(|val| DataValue::clone(&val).cast(ty).ok().map(Arc::new)), - ScalarExpression::IsNull { expr } => { + ScalarExpression::IsNull { expr, .. 
} => { let is_null = expr.unpack_val().map(|val| val.is_null()); Some(Arc::new(DataValue::Boolean(is_null))) @@ -338,6 +338,48 @@ impl ScalarExpression { self._simplify(&mut Vec::new()) } + pub fn constant_calculation(&mut self) -> Result<(), TypeError> { + match self { + ScalarExpression::Unary { expr, op, .. } => { + expr.constant_calculation()?; + + if let ScalarExpression::Constant(unary_val) = expr.as_ref() { + let value = unary_op(unary_val, op)?; + let _ = mem::replace(self, ScalarExpression::Constant(Arc::new(value))); + } + } + ScalarExpression::Binary { + left_expr, + right_expr, + op, + .. + } => { + left_expr.constant_calculation()?; + right_expr.constant_calculation()?; + + if let ( + ScalarExpression::Constant(left_val), + ScalarExpression::Constant(right_val), + ) = (left_expr.as_ref(), right_expr.as_ref()) + { + let value = binary_op(left_val, right_val, op)?; + let _ = mem::replace(self, ScalarExpression::Constant(Arc::new(value))); + } + } + ScalarExpression::Alias { expr, .. } => expr.constant_calculation()?, + ScalarExpression::TypeCast { expr, .. } => expr.constant_calculation()?, + ScalarExpression::IsNull { expr, .. } => expr.constant_calculation()?, + ScalarExpression::AggCall { args, .. } => { + for expr in args { + expr.constant_calculation()?; + } + } + _ => (), + } + + Ok(()) + } + // Tips: Indirect expressions like `ScalarExpression::Alias` will be lost fn _simplify(&mut self, replaces: &mut Vec) -> Result<(), TypeError> { match self { @@ -458,7 +500,6 @@ impl ScalarExpression { if Self::is_arithmetic(op) { return Ok(()); } - while let Some(replace) = replaces.pop() { match replace { Replace::Binary(binary) => Self::fix_binary(binary, left_expr, right_expr, op), @@ -638,7 +679,7 @@ impl ScalarExpression { } ScalarExpression::Alias { expr, .. } => expr.convert_binary(col_id), ScalarExpression::TypeCast { expr, .. } => expr.convert_binary(col_id), - ScalarExpression::IsNull { expr } => expr.convert_binary(col_id), + ScalarExpression::IsNull { expr, .. } => expr.convert_binary(col_id), ScalarExpression::Unary { expr, .. 
} => expr.convert_binary(col_id), _ => Ok(None), } @@ -666,7 +707,7 @@ impl ScalarExpression { val: ValueRef, is_flip: bool, ) -> Option { - if col.id.unwrap() != *col_id { + if col.id() != Some(*col_id) { return None; } @@ -706,7 +747,7 @@ impl ScalarExpression { #[cfg(test)] mod test { - use crate::catalog::{ColumnCatalog, ColumnDesc}; + use crate::catalog::{ColumnCatalog, ColumnDesc, ColumnSummary}; use crate::expression::simplify::ConstantBinary; use crate::expression::{BinaryOperator, ScalarExpression}; use crate::types::errors::TypeError; @@ -718,9 +759,11 @@ mod test { #[test] fn test_convert_binary_simple() -> Result<(), TypeError> { let col_1 = Arc::new(ColumnCatalog { - id: Some(0), - name: "c1".to_string(), - table_name: None, + summary: ColumnSummary { + id: Some(0), + name: "c1".to_string(), + table_name: None, + }, nullable: false, desc: ColumnDesc { column_datatype: LogicalType::Integer, diff --git a/src/expression/value_compute.rs b/src/expression/value_compute.rs index fe4efdc1..a47d62ec 100644 --- a/src/expression/value_compute.rs +++ b/src/expression/value_compute.rs @@ -2,6 +2,7 @@ use crate::expression::{BinaryOperator, UnaryOperator}; use crate::types::errors::TypeError; use crate::types::value::DataValue; use crate::types::LogicalType; +use regex::Regex; fn unpack_i32(value: DataValue) -> Option { match value { @@ -114,6 +115,22 @@ pub fn binary_op( right: &DataValue, op: &BinaryOperator, ) -> Result { + if matches!(op, BinaryOperator::Like | BinaryOperator::NotLike) { + let value_option = unpack_utf8(left.clone().cast(&LogicalType::Varchar(None))?); + let pattern_option = unpack_utf8(right.clone().cast(&LogicalType::Varchar(None))?); + + let mut is_match = if let (Some(value), Some(pattern)) = (value_option, pattern_option) { + let regex_pattern = pattern.replace("%", ".*").replace("_", "."); + + Regex::new(®ex_pattern).unwrap().is_match(&value) + } else { + unreachable!("The left and right values calculated by Like cannot be Null values.") + }; + if op == &BinaryOperator::NotLike { + is_match = !is_match; + } + return Ok(DataValue::Boolean(Some(is_match))); + } let unified_type = LogicalType::max_logical_type(&left.logical_type(), &right.logical_type())?; let value = match &unified_type { diff --git a/src/lib.rs b/src/lib.rs index e7d9c23b..a1478ad2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,11 +1,10 @@ #![feature(error_generic_member_access)] #![allow(unused_doc_comments)] #![feature(result_flattening)] -#![feature(generators)] +#![feature(coroutines)] #![feature(iterator_try_collect)] #![feature(slice_pattern)] #![feature(bound_map)] -#![feature(async_fn_in_trait)] extern crate core; pub mod binder; pub mod catalog; diff --git a/src/marco/mod.rs b/src/marco/mod.rs index f9ca70e9..1f6ab659 100644 --- a/src/marco/mod.rs +++ b/src/marco/mod.rs @@ -31,7 +31,7 @@ macro_rules! implement_from_tuple { let (idx, _) = tuple.columns .iter() .enumerate() - .find(|(_, col)| &col.name == field_name)?; + .find(|(_, col)| col.name() == field_name)?; DataValue::clone(&tuple.values[idx]) .cast(&ty) diff --git a/src/optimizer/core/opt_expr.rs b/src/optimizer/core/opt_expr.rs index 0b7136fc..c29e7b8e 100644 --- a/src/optimizer/core/opt_expr.rs +++ b/src/optimizer/core/opt_expr.rs @@ -3,36 +3,17 @@ use crate::planner::LogicalPlan; use std::fmt::Debug; pub type OptExprNodeId = usize; - -#[derive(Clone, PartialEq)] -pub enum OptExprNode { - /// Raw plan node with dummy children. - OperatorRef(Operator), - #[allow(dead_code)] - /// Existing OptExprNode in graph. 
-    OptExpr(OptExprNodeId),
-}
-
-impl Debug for OptExprNode {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        match self {
-            Self::OperatorRef(op) => write!(f, "LogicOperator({:?})", op),
-            Self::OptExpr(id) => write!(f, "OptExpr({})", id),
-        }
-    }
-}
-
 #[derive(Clone, Debug)]
 pub struct OptExpr {
     /// The root of the tree.
-    pub root: OptExprNode,
+    pub root: Operator,
     /// The root's children expressions.
     pub childrens: Vec<OptExpr>,
 }

 impl OptExpr {
     #[allow(dead_code)]
-    pub fn new(root: OptExprNode, childrens: Vec<OptExpr>) -> Self {
+    pub fn new(root: Operator, childrens: Vec<OptExpr>) -> Self {
         Self { root, childrens }
     }

@@ -43,7 +24,7 @@
     #[allow(dead_code)]
     fn build_opt_expr_internal(input: &LogicalPlan) -> OptExpr {
-        let root = OptExprNode::OperatorRef(input.operator.clone());
+        let root = input.operator.clone();
         let childrens = input
             .childrens
             .iter()
@@ -54,22 +35,14 @@
     #[allow(dead_code)]
     pub fn to_plan_ref(&self) -> LogicalPlan {
-        match &self.root {
-            OptExprNode::OperatorRef(op) => {
-                let childrens = self
-                    .childrens
-                    .iter()
-                    .map(|c| c.to_plan_ref())
-                    .collect::<Vec<_>>();
-                LogicalPlan {
-                    operator: op.clone(),
-                    childrens,
-                }
-            }
-            OptExprNode::OptExpr(_) => LogicalPlan {
-                operator: Operator::Dummy,
-                childrens: vec![],
-            },
+        let childrens = self
+            .childrens
+            .iter()
+            .map(|c| c.to_plan_ref())
+            .collect::<Vec<_>>();
+        LogicalPlan {
+            operator: self.root.clone(),
+            childrens,
         }
     }
 }
diff --git a/src/optimizer/heuristic/batch.rs b/src/optimizer/heuristic/batch.rs
index 81f6ed78..2161bec7 100644
--- a/src/optimizer/heuristic/batch.rs
+++ b/src/optimizer/heuristic/batch.rs
@@ -30,7 +30,6 @@ pub struct HepBatchStrategy {
 }

 impl HepBatchStrategy {
-    #[allow(dead_code)]
     pub fn once_topdown() -> Self {
         HepBatchStrategy {
             max_iteration: 1,
diff --git a/src/optimizer/heuristic/graph.rs b/src/optimizer/heuristic/graph.rs
index 4f4b9e97..16ca4ef4 100644
--- a/src/optimizer/heuristic/graph.rs
+++ b/src/optimizer/heuristic/graph.rs
@@ -1,4 +1,4 @@
-use crate::optimizer::core::opt_expr::{OptExprNode, OptExprNodeId};
+use crate::optimizer::core::opt_expr::OptExprNodeId;
 use crate::optimizer::heuristic::batch::HepMatchOrder;
 use crate::planner::operator::Operator;
 use crate::planner::LogicalPlan;
@@ -12,7 +12,7 @@
 pub type HepNodeId = NodeIndex<usize>;

 #[derive(Debug)]
 pub struct HepGraph {
-    graph: StableDiGraph<OptExprNode, usize>,
+    graph: StableDiGraph<Operator, usize>,
     root_index: HepNodeId,
     pub version: usize,
 }
@@ -20,13 +20,13 @@
 impl HepGraph {
     pub fn new(root: LogicalPlan) -> Self {
         fn graph_filling(
-            graph: &mut StableDiGraph<OptExprNode, usize>,
+            graph: &mut StableDiGraph<Operator, usize>,
             LogicalPlan {
                 operator,
                 childrens,
             }: LogicalPlan,
         ) -> HepNodeId {
-            let index = graph.add_node(OptExprNode::OperatorRef(operator));
+            let index = graph.add_node(operator);

             for (order, child) in childrens.into_iter().enumerate() {
                 let child_index = graph_filling(graph, child);
@@ -36,7 +36,7 @@
             index
         }

-        let mut graph = StableDiGraph::<OptExprNode, usize>::default();
+        let mut graph = StableDiGraph::<Operator, usize>::default();

         let root_index = graph_filling(&mut graph, root);

@@ -54,7 +54,7 @@
     }

     #[allow(dead_code)]
-    pub fn add_root(&mut self, new_node: OptExprNode) {
+    pub fn add_root(&mut self, new_node: Operator) {
         let old_root_id = mem::replace(&mut self.root_index, self.graph.add_node(new_node));

         self.graph.add_edge(self.root_index, old_root_id, 0);
@@ -65,7 +65,7 @@
         &mut self,
         source_id: HepNodeId,
         children_option: Option<HepNodeId>,
-        new_node: OptExprNode,
+        new_node: Operator,
     ) {
         let new_index = self.graph.add_node(new_node);
@@ -85,7 +85,7 @@
         self.version += 1;
     }

-    pub fn replace_node(&mut self, source_id: HepNodeId, new_node: OptExprNode) {
+    pub fn replace_node(&mut self, source_id: HepNodeId, new_node: Operator) {
         self.graph[source_id] = new_node;
         self.version += 1;
     }
@@ -97,11 +97,7 @@
         self.version += 1;
     }

-    pub fn remove_node(
-        &mut self,
-        source_id: HepNodeId,
-        with_childrens: bool,
-    ) -> Option<OptExprNode> {
+    pub fn remove_node(&mut self, source_id: HepNodeId, with_childrens: bool) -> Option<Operator> {
         if !with_childrens {
             let children_ids = self
                 .graph
@@ -152,15 +148,16 @@
     }

     #[allow(dead_code)]
-    pub fn node(&self, node_id: HepNodeId) -> Option<&OptExprNode> {
+    pub fn node(&self, node_id: HepNodeId) -> Option<&Operator> {
         self.graph.node_weight(node_id)
     }

     pub fn operator(&self, node_id: HepNodeId) -> &Operator {
-        match &self.graph[node_id] {
-            OptExprNode::OperatorRef(op) => op,
-            OptExprNode::OptExpr(node_id) => self.operator(HepNodeId::new(*node_id)),
-        }
+        &self.graph[node_id]
+    }
+
+    pub fn operator_mut(&mut self, node_id: HepNodeId) -> &mut Operator {
+        &mut self.graph[node_id]
     }

     pub fn to_plan(&self) -> LogicalPlan {
@@ -204,7 +201,6 @@
 mod tests {
     use crate::binder::test::select_sql_run;
     use crate::execution::ExecutorError;
-    use crate::optimizer::core::opt_expr::OptExprNode;
     use crate::optimizer::heuristic::graph::{HepGraph, HepNodeId};
     use crate::planner::operator::Operator;
     use petgraph::stable_graph::{EdgeIndex, NodeIndex};
@@ -236,23 +232,11 @@
         let plan = select_sql_run("select * from t1 left join t2 on c1 = c3").await?;
         let mut graph = HepGraph::new(plan);

-        graph.add_node(
-            HepNodeId::new(1),
-            None,
-            OptExprNode::OperatorRef(Operator::Dummy),
-        );
+        graph.add_node(HepNodeId::new(1), None, Operator::Dummy);

-        graph.add_node(
-            HepNodeId::new(1),
-            Some(HepNodeId::new(4)),
-            OptExprNode::OperatorRef(Operator::Dummy),
-        );
+        graph.add_node(HepNodeId::new(1), Some(HepNodeId::new(4)), Operator::Dummy);

-        graph.add_node(
-            HepNodeId::new(5),
-            None,
-            OptExprNode::OperatorRef(Operator::Dummy),
-        );
+        graph.add_node(HepNodeId::new(5), None, Operator::Dummy);

         assert!(graph
             .graph
@@ -276,7 +260,7 @@
         let plan = select_sql_run("select * from t1 left join t2 on c1 = c3").await?;
         let mut graph = HepGraph::new(plan);

-        graph.replace_node(HepNodeId::new(1), OptExprNode::OperatorRef(Operator::Dummy));
+        graph.replace_node(HepNodeId::new(1), Operator::Dummy);

         assert!(matches!(graph.operator(HepNodeId::new(1)), Operator::Dummy));

@@ -338,7 +322,7 @@
         let plan = select_sql_run("select * from t1 left join t2 on c1 = c3").await?;
         let mut graph = HepGraph::new(plan);

-        graph.add_root(OptExprNode::OperatorRef(Operator::Dummy));
+        graph.add_root(Operator::Dummy);

         assert_eq!(graph.graph.edge_count(), 4);
         assert!(graph
diff --git a/src/optimizer/rule/column_pruning.rs b/src/optimizer/rule/column_pruning.rs
index f9a657db..0ae347dc 100644
--- a/src/optimizer/rule/column_pruning.rs
+++ b/src/optimizer/rule/column_pruning.rs
@@ -1,191 +1,153 @@
-use crate::catalog::ColumnRef;
+use crate::catalog::{ColumnRef, ColumnSummary};
+use crate::expression::agg::AggKind;
 use crate::expression::ScalarExpression;
-use crate::optimizer::core::opt_expr::OptExprNode;
 use crate::optimizer::core::pattern::{Pattern, PatternChildrenPredicate};
 use crate::optimizer::core::rule::Rule;
 use crate::optimizer::heuristic::graph::{HepGraph, HepNodeId};
 use crate::optimizer::OptimizerError;
-use crate::planner::operator::aggregate::AggregateOperator;
-use crate::planner::operator::project::ProjectOperator;
 use crate::planner::operator::Operator;
-use itertools::Itertools;
+use crate::types::value::DataValue;
+use crate::types::LogicalType;
 use lazy_static::lazy_static;
+use std::collections::HashSet;
+use std::sync::Arc;

 lazy_static! {
-    static ref PUSH_PROJECT_INTO_SCAN_RULE: Pattern = {
+    static ref COLUMN_PRUNING_RULE: Pattern = {
         Pattern {
-            predicate: |op| matches!(op, Operator::Project(_)),
-            children: PatternChildrenPredicate::Predicate(vec![Pattern {
-                predicate: |op| matches!(op, Operator::Scan(_)),
-                children: PatternChildrenPredicate::None,
-            }]),
-        }
-    };
-    static ref PUSH_PROJECT_THROUGH_CHILD_RULE: Pattern = {
-        Pattern {
-            predicate: |op| matches!(op, Operator::Project(_)),
-            children: PatternChildrenPredicate::Predicate(vec![Pattern {
-                predicate: |op| !matches!(op, Operator::Scan(_) | Operator::Project(_)),
-                children: PatternChildrenPredicate::Predicate(vec![Pattern {
-                    predicate: |op| !matches!(op, Operator::Project(_)),
-                    children: PatternChildrenPredicate::None,
-                }]),
-            }]),
+            predicate: |_| true,
+            children: PatternChildrenPredicate::None,
         }
     };
 }

-#[derive(Copy, Clone)]
-pub struct PushProjectIntoScan;
-
-impl Rule for PushProjectIntoScan {
-    fn pattern(&self) -> &Pattern {
-        &PUSH_PROJECT_INTO_SCAN_RULE
-    }
-
-    fn apply(&self, node_id: HepNodeId, graph: &mut HepGraph) -> Result<(), OptimizerError> {
-        if let Operator::Project(project_op) = graph.operator(node_id) {
-            let child_index = graph.children_at(node_id)[0];
-            if let Operator::Scan(scan_op) = graph.operator(child_index) {
-                let mut new_scan_op = scan_op.clone();
-
-                new_scan_op.columns = project_op
-                    .columns
-                    .iter()
-                    .map(ScalarExpression::unpack_alias)
-                    .cloned()
-                    .collect_vec();
-
-                graph.remove_node(node_id, false);
-                graph.replace_node(
-                    child_index,
-                    OptExprNode::OperatorRef(Operator::Scan(new_scan_op)),
-                );
-            }
-        }
-
-        Ok(())
-    }
-}
-
 #[derive(Clone)]
-pub struct PushProjectThroughChild;
+pub struct ColumnPruning;

-impl Rule for PushProjectThroughChild {
-    fn pattern(&self) -> &Pattern {
-        &PUSH_PROJECT_THROUGH_CHILD_RULE
+impl ColumnPruning {
+    fn clear_exprs(
+        column_references: &mut HashSet<ColumnSummary>,
+        exprs: &mut Vec<ScalarExpression>,
+    ) {
+        exprs.retain(|expr| {
+            expr.referenced_columns()
+                .iter()
+                .any(|column| column_references.contains(column.summary()))
+        })
     }

-    fn apply(&self, node_id: HepNodeId, graph: &mut HepGraph) -> Result<(), OptimizerError> {
-        let node_operator = graph.operator(node_id);
-
-        if let Operator::Project(_) = node_operator {
-            let child_index = graph.children_at(node_id)[0];
-            let node_referenced_columns = node_operator.referenced_columns();
-            let child_operator = graph.operator(child_index);
-            let child_referenced_columns = child_operator.referenced_columns();
-            let op = |col: &ColumnRef| format!("{:?}.{:?}", col.table_name, col.id);
-
-            match child_operator {
-                // When the aggregate function is a child node,
-                // the pushdown will lose the corresponding ColumnRef due to `InputRef`.
-                // Therefore, it is necessary to map the InputRef to the corresponding ColumnRef
-                // and push it down.
-                Operator::Aggregate(AggregateOperator { agg_calls, .. }) => {
-                    let grandson_id = graph.children_at(child_index)[0];
-                    let columns = node_operator
-                        .project_input_refs()
-                        .iter()
-                        .filter_map(|expr| {
-                            if agg_calls.is_empty() {
-                                return None;
-                            }
-
-                            if let ScalarExpression::InputRef { index, .. } = expr {
-                                agg_calls.get(*index).cloned()
-                            } else {
-                                None
-                            }
-                        })
-                        .map(|expr| expr.referenced_columns())
-                        .flatten()
-                        .chain(node_referenced_columns.into_iter())
-                        .chain(child_referenced_columns.into_iter())
-                        .unique_by(op)
-                        .map(|col| ScalarExpression::ColumnRef(col))
-                        .collect_vec();
-
-                    Self::add_project_node(graph, child_index, columns, grandson_id);
-                }
-                Operator::Join(_) => {
-                    let parent_referenced_columns = node_referenced_columns
-                        .into_iter()
-                        .chain(child_referenced_columns.into_iter())
-                        .unique_by(op)
-                        .collect_vec();
-
-                    for grandson_id in graph.children_at(child_index) {
-                        let grandson_referenced_column =
-                            graph.operator(grandson_id).referenced_columns();
-
-                        // for PushLimitThroughJoin
-                        if grandson_referenced_column.is_empty() {
-                            return Ok(());
-                        }
-                        let grandson_table_name = grandson_referenced_column[0].table_name.clone();
-                        let columns = parent_referenced_columns
-                            .iter()
-                            .filter(|col| col.table_name == grandson_table_name)
-                            .cloned()
-                            .map(|col| ScalarExpression::ColumnRef(col))
-                            .collect_vec();
-
-                        Self::add_project_node(graph, child_index, columns, grandson_id);
-                    }
-                }
-                _ => {
-                    let grandson_ids = graph.children_at(child_index);
-
-                    if grandson_ids.is_empty() {
-                        return Ok(());
-                    }
-                    let grandson_id = grandson_ids[0];
-                    let mut columns = node_operator.project_input_refs();
-                    let mut referenced_columns = node_referenced_columns
-                        .into_iter()
-                        .chain(child_referenced_columns.into_iter())
-                        .unique_by(op)
-                        .map(|col| ScalarExpression::ColumnRef(col))
-                        .collect_vec();
-
-                    columns.append(&mut referenced_columns);
-
-                    Self::add_project_node(graph, child_index, columns, grandson_id);
-                }
-            }
-        }
-
-        Ok(())
+    fn _apply(
+        column_references: &mut HashSet<ColumnSummary>,
+        all_referenced: bool,
+        node_id: HepNodeId,
+        graph: &mut HepGraph,
+    ) {
+        let operator = graph.operator_mut(node_id);
+
+        match operator {
+            Operator::Aggregate(op) => {
+                if !all_referenced {
+                    Self::clear_exprs(column_references, &mut op.agg_calls);
+
+                    if op.agg_calls.is_empty() && op.groupby_exprs.is_empty() {
+                        let value = Arc::new(DataValue::Utf8(Some("*".to_string())));
+                        // a lone COUNT(*) does not depend on any column, so when
+                        // pruning has removed every aggregate expression we push a
+                        // COUNT(*) to keep the aggregate well-formed
+                        op.agg_calls.push(ScalarExpression::AggCall {
+                            distinct: false,
+                            kind: AggKind::Count,
+                            args: vec![ScalarExpression::Constant(value)],
+                            ty: LogicalType::Integer,
+                        })
+                    }
+                }
+                let op_ref_columns = operator.referenced_columns();
+
+                Self::recollect_apply(op_ref_columns, false, node_id, graph);
+            }
+            Operator::Project(op) => {
+                let has_count_star = op.exprs.iter().any(ScalarExpression::has_count_star);
+                if !has_count_star {
+                    if !all_referenced {
+                        Self::clear_exprs(column_references, &mut op.exprs);
+                    }
+                    let op_ref_columns = operator.referenced_columns();
+
+                    Self::recollect_apply(op_ref_columns, false, node_id, graph);
+                }
+            }
+            Operator::Sort(_op) => {
+                if !all_referenced {
+                    // Todo: Order Project
+                    // https://github.com/duckdb/duckdb/blob/main/src/optimizer/remove_unused_columns.cpp#L174
+                }
+                for child_id in graph.children_at(node_id) {
+                    Self::_apply(column_references, true, child_id, graph);
+                }
+            }
+            Operator::Scan(op) => {
+                if !all_referenced {
+                    Self::clear_exprs(column_references, &mut op.columns);
+                }
+            }
+            Operator::Limit(_) | Operator::Join(_) | Operator::Filter(_) => {
+                for column in operator.referenced_columns() {
+                    column_references.insert(column.summary().clone());
+                }
+                for child_id in graph.children_at(node_id) {
+                    Self::_apply(column_references, all_referenced, child_id, graph);
+                }
+            }
+            // Last Operator
+            Operator::Dummy | Operator::Values(_) => (),
+            // DDL Based on Other Plan
+            Operator::Insert(_) | Operator::Update(_) | Operator::Delete(_) => {
+                let op_ref_columns = operator.referenced_columns();
+
+                Self::recollect_apply(op_ref_columns, true, graph.children_at(node_id)[0], graph);
+            }
+            // DDL Single Plan
+            Operator::CreateTable(_)
+            | Operator::DropTable(_)
+            | Operator::Truncate(_)
+            | Operator::Show(_)
+            | Operator::CopyFromFile(_)
+            | Operator::CopyToFile(_) => (),
+        }
     }
-}

-impl PushProjectThroughChild {
-    fn add_project_node(
+    fn recollect_apply(
+        referenced_columns: Vec<ScalarExpression>,
+        all_referenced: bool,
+        node_id: HepNodeId,
         graph: &mut HepGraph,
-        child_index: HepNodeId,
-        columns: Vec<ScalarExpression>,
-        grandson_id: HepNodeId,
     ) {
-        if !columns.is_empty() {
-            graph.add_node(
-                child_index,
-                Some(grandson_id),
-                OptExprNode::OperatorRef(Operator::Project(ProjectOperator { columns })),
-            );
+        for child_id in graph.children_at(node_id) {
+            let mut new_references: HashSet<ColumnSummary> = referenced_columns
+                .iter()
+                .map(|column| column.summary())
+                .cloned()
+                .collect();
+
+            Self::_apply(&mut new_references, all_referenced, child_id, graph);
         }
     }
 }

+impl Rule for ColumnPruning {
+    fn pattern(&self) -> &Pattern {
+        &COLUMN_PRUNING_RULE
+    }
+
+    fn apply(&self, node_id: HepNodeId, graph: &mut HepGraph) -> Result<(), OptimizerError> {
+        Self::_apply(&mut HashSet::new(), true, node_id, graph);
+        // mark changed to skip this rule batch
+        graph.version += 1;
+
+        Ok(())
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use crate::binder::test::select_sql_run;
@@ -197,47 +159,21 @@
     use crate::planner::operator::Operator;

     #[tokio::test]
-    async fn test_project_into_table_scan() -> Result<(), DatabaseError> {
-        let plan = select_sql_run("select * from t1").await?;
-
-        let best_plan = HepOptimizer::new(plan.clone())
-            .batch(
-                "test_project_into_table_scan".to_string(),
-                HepBatchStrategy::once_topdown(),
-                vec![RuleImpl::PushProjectIntoScan],
-            )
-            .find_best()?;
-
-        assert_eq!(best_plan.childrens.len(), 0);
-        match best_plan.operator {
-            Operator::Scan(op) => {
-                assert_eq!(op.columns.len(), 2);
-            }
-            _ => unreachable!("Should be a scan operator"),
-        }
-
-        Ok(())
-    }
-
-    #[tokio::test]
-    async fn test_project_through_child_on_join() -> Result<(), DatabaseError> {
+    async fn test_column_pruning() -> Result<(), DatabaseError> {
         let plan = select_sql_run("select c1, c3 from t1 left join t2 on c1 = c3").await?;

         let best_plan = HepOptimizer::new(plan.clone())
             .batch(
-                "test_project_through_child_on_join".to_string(),
-                HepBatchStrategy::fix_point_topdown(10),
-                vec![
-                    RuleImpl::PushProjectThroughChild,
-                    RuleImpl::PushProjectIntoScan,
-                ],
+                "test_column_pruning".to_string(),
+                HepBatchStrategy::once_topdown(),
+                vec![RuleImpl::ColumnPruning],
             )
             .find_best()?;

         assert_eq!(best_plan.childrens.len(), 1);
         match best_plan.operator {
             Operator::Project(op) => {
-                assert_eq!(op.columns.len(), 2);
+                assert_eq!(op.exprs.len(), 2);
             }
             _ => unreachable!("Should be a project operator"),
         }
diff --git a/src/optimizer/rule/combine_operators.rs b/src/optimizer/rule/combine_operators.rs
index 7105dada..a4daa1f1 100644
--- a/src/optimizer/rule/combine_operators.rs
+++ b/src/optimizer/rule/combine_operators.rs
@@ -1,5 +1,4 @@
 use crate::expression::{BinaryOperator, ScalarExpression};
-use crate::optimizer::core::opt_expr::OptExprNode;
 use crate::optimizer::core::pattern::{Pattern, PatternChildrenPredicate};
 use crate::optimizer::core::rule::Rule;
 use crate::optimizer::heuristic::graph::{HepGraph, HepNodeId};
@@ -43,7 +42,7 @@
if let Operator::Project(op) = graph.operator(node_id) { let child_id = graph.children_at(node_id)[0]; if let Operator::Project(child_op) = graph.operator(child_id) { - if is_subset_exprs(&op.columns, &child_op.columns) { + if is_subset_exprs(&op.exprs, &child_op.exprs) { graph.remove_node(child_id, false); } else { graph.remove_node(node_id, false); @@ -76,10 +75,7 @@ impl Rule for CombineFilter { }, having: op.having || child_op.having, }; - graph.replace_node( - node_id, - OptExprNode::OperatorRef(Operator::Filter(new_filter_op)), - ); + graph.replace_node(node_id, Operator::Filter(new_filter_op)); graph.remove_node(child_id, false); } } @@ -94,7 +90,6 @@ mod tests { use crate::db::DatabaseError; use crate::expression::ScalarExpression::Constant; use crate::expression::{BinaryOperator, ScalarExpression}; - use crate::optimizer::core::opt_expr::OptExprNode; use crate::optimizer::heuristic::batch::HepBatchStrategy; use crate::optimizer::heuristic::graph::HepNodeId; use crate::optimizer::heuristic::optimizer::HepOptimizer; @@ -117,19 +112,17 @@ mod tests { let mut new_project_op = optimizer.graph.operator(HepNodeId::new(0)).clone(); if let Operator::Project(op) = &mut new_project_op { - op.columns.remove(0); + op.exprs.remove(0); } else { unreachable!("Should be a project operator") } - optimizer - .graph - .add_root(OptExprNode::OperatorRef(new_project_op)); + optimizer.graph.add_root(new_project_op); let best_plan = optimizer.find_best()?; if let Operator::Project(op) = &best_plan.operator { - assert_eq!(op.columns.len(), 1); + assert_eq!(op.exprs.len(), 1); } else { unreachable!("Should be a project operator") } @@ -166,11 +159,9 @@ mod tests { unreachable!("Should be a filter operator") } - optimizer.graph.add_node( - HepNodeId::new(0), - Some(HepNodeId::new(1)), - OptExprNode::OperatorRef(new_filter_op), - ); + optimizer + .graph + .add_node(HepNodeId::new(0), Some(HepNodeId::new(1)), new_filter_op); let best_plan = optimizer.find_best()?; diff --git a/src/optimizer/rule/mod.rs b/src/optimizer/rule/mod.rs index 846aa27c..1c9bbbed 100644 --- a/src/optimizer/rule/mod.rs +++ b/src/optimizer/rule/mod.rs @@ -2,13 +2,14 @@ use crate::expression::ScalarExpression; use crate::optimizer::core::pattern::Pattern; use crate::optimizer::core::rule::Rule; use crate::optimizer::heuristic::graph::{HepGraph, HepNodeId}; -use crate::optimizer::rule::column_pruning::{PushProjectIntoScan, PushProjectThroughChild}; +use crate::optimizer::rule::column_pruning::ColumnPruning; use crate::optimizer::rule::combine_operators::{CollapseProject, CombineFilter}; use crate::optimizer::rule::pushdown_limit::{ EliminateLimits, LimitProjectTranspose, PushLimitIntoScan, PushLimitThroughJoin, }; use crate::optimizer::rule::pushdown_predicates::PushPredicateIntoScan; use crate::optimizer::rule::pushdown_predicates::PushPredicateThroughJoin; +use crate::optimizer::rule::simplification::ConstantCalculation; use crate::optimizer::rule::simplification::SimplifyFilter; use crate::optimizer::OptimizerError; @@ -20,9 +21,7 @@ mod simplification; #[derive(Debug, Copy, Clone)] pub enum RuleImpl { - // Column pruning - PushProjectIntoScan, - PushProjectThroughChild, + ColumnPruning, // Combine operators CollapseProject, CombineFilter, @@ -37,38 +36,39 @@ pub enum RuleImpl { PushPredicateIntoScan, // Simplification SimplifyFilter, + ConstantCalculation, } impl Rule for RuleImpl { fn pattern(&self) -> &Pattern { match self { - RuleImpl::PushProjectIntoScan => PushProjectIntoScan {}.pattern(), - RuleImpl::PushProjectThroughChild => 
PushProjectThroughChild {}.pattern(), - RuleImpl::CollapseProject => CollapseProject {}.pattern(), - RuleImpl::CombineFilter => CombineFilter {}.pattern(), - RuleImpl::LimitProjectTranspose => LimitProjectTranspose {}.pattern(), - RuleImpl::EliminateLimits => EliminateLimits {}.pattern(), - RuleImpl::PushLimitThroughJoin => PushLimitThroughJoin {}.pattern(), - RuleImpl::PushLimitIntoTableScan => PushLimitIntoScan {}.pattern(), - RuleImpl::PushPredicateThroughJoin => PushPredicateThroughJoin {}.pattern(), - RuleImpl::PushPredicateIntoScan => PushPredicateIntoScan {}.pattern(), - RuleImpl::SimplifyFilter => SimplifyFilter {}.pattern(), + RuleImpl::ColumnPruning => ColumnPruning.pattern(), + RuleImpl::CollapseProject => CollapseProject.pattern(), + RuleImpl::CombineFilter => CombineFilter.pattern(), + RuleImpl::LimitProjectTranspose => LimitProjectTranspose.pattern(), + RuleImpl::EliminateLimits => EliminateLimits.pattern(), + RuleImpl::PushLimitThroughJoin => PushLimitThroughJoin.pattern(), + RuleImpl::PushLimitIntoTableScan => PushLimitIntoScan.pattern(), + RuleImpl::PushPredicateThroughJoin => PushPredicateThroughJoin.pattern(), + RuleImpl::PushPredicateIntoScan => PushPredicateIntoScan.pattern(), + RuleImpl::SimplifyFilter => SimplifyFilter.pattern(), + RuleImpl::ConstantCalculation => ConstantCalculation.pattern(), } } fn apply(&self, node_id: HepNodeId, graph: &mut HepGraph) -> Result<(), OptimizerError> { match self { - RuleImpl::PushProjectIntoScan => PushProjectIntoScan {}.apply(node_id, graph), - RuleImpl::PushProjectThroughChild => PushProjectThroughChild {}.apply(node_id, graph), - RuleImpl::CollapseProject => CollapseProject {}.apply(node_id, graph), - RuleImpl::CombineFilter => CombineFilter {}.apply(node_id, graph), - RuleImpl::LimitProjectTranspose => LimitProjectTranspose {}.apply(node_id, graph), - RuleImpl::EliminateLimits => EliminateLimits {}.apply(node_id, graph), - RuleImpl::PushLimitThroughJoin => PushLimitThroughJoin {}.apply(node_id, graph), - RuleImpl::PushLimitIntoTableScan => PushLimitIntoScan {}.apply(node_id, graph), - RuleImpl::PushPredicateThroughJoin => PushPredicateThroughJoin {}.apply(node_id, graph), - RuleImpl::SimplifyFilter => SimplifyFilter {}.apply(node_id, graph), - RuleImpl::PushPredicateIntoScan => PushPredicateIntoScan {}.apply(node_id, graph), + RuleImpl::ColumnPruning => ColumnPruning.apply(node_id, graph), + RuleImpl::CollapseProject => CollapseProject.apply(node_id, graph), + RuleImpl::CombineFilter => CombineFilter.apply(node_id, graph), + RuleImpl::LimitProjectTranspose => LimitProjectTranspose.apply(node_id, graph), + RuleImpl::EliminateLimits => EliminateLimits.apply(node_id, graph), + RuleImpl::PushLimitThroughJoin => PushLimitThroughJoin.apply(node_id, graph), + RuleImpl::PushLimitIntoTableScan => PushLimitIntoScan.apply(node_id, graph), + RuleImpl::PushPredicateThroughJoin => PushPredicateThroughJoin.apply(node_id, graph), + RuleImpl::SimplifyFilter => SimplifyFilter.apply(node_id, graph), + RuleImpl::PushPredicateIntoScan => PushPredicateIntoScan.apply(node_id, graph), + RuleImpl::ConstantCalculation => ConstantCalculation.apply(node_id, graph), } } } diff --git a/src/optimizer/rule/pushdown_limit.rs b/src/optimizer/rule/pushdown_limit.rs index d2504f3c..2884b84f 100644 --- a/src/optimizer/rule/pushdown_limit.rs +++ b/src/optimizer/rule/pushdown_limit.rs @@ -1,4 +1,3 @@ -use crate::optimizer::core::opt_expr::OptExprNode; use crate::optimizer::core::pattern::Pattern; use crate::optimizer::core::pattern::PatternChildrenPredicate; use 
crate::optimizer::core::rule::Rule; @@ -81,10 +80,7 @@ impl Rule for EliminateLimits { let new_limit_op = LimitOperator { offset, limit }; graph.remove_node(child_id, false); - graph.replace_node( - node_id, - OptExprNode::OperatorRef(Operator::Limit(new_limit_op)), - ); + graph.replace_node(node_id, Operator::Limit(new_limit_op)); } } @@ -135,11 +131,7 @@ impl Rule for PushLimitThroughJoin { JoinType::Right => Some(graph.children_at(child_id)[1]), _ => None, } { - graph.add_node( - child_id, - Some(grandson_id), - OptExprNode::OperatorRef(Operator::Limit(op.clone())), - ); + graph.add_node(child_id, Some(grandson_id), Operator::Limit(op.clone())); } } } @@ -165,10 +157,7 @@ impl Rule for PushLimitIntoScan { new_scan_op.limit = (limit_op.offset, limit_op.limit); graph.remove_node(node_id, false); - graph.replace_node( - child_index, - OptExprNode::OperatorRef(Operator::Scan(new_scan_op)), - ); + graph.replace_node(child_index, Operator::Scan(new_scan_op)); } } @@ -180,7 +169,6 @@ impl Rule for PushLimitIntoScan { mod tests { use crate::binder::test::select_sql_run; use crate::db::DatabaseError; - use crate::optimizer::core::opt_expr::OptExprNode; use crate::optimizer::heuristic::batch::HepBatchStrategy; use crate::optimizer::heuristic::optimizer::HepOptimizer; use crate::optimizer::rule::RuleImpl; @@ -227,9 +215,7 @@ mod tests { limit: Some(1), }; - optimizer - .graph - .add_root(OptExprNode::OperatorRef(Operator::Limit(new_limit_op))); + optimizer.graph.add_root(Operator::Limit(new_limit_op)); let best_plan = optimizer.find_best()?; diff --git a/src/optimizer/rule/pushdown_predicates.rs b/src/optimizer/rule/pushdown_predicates.rs index c3268eeb..d5132148 100644 --- a/src/optimizer/rule/pushdown_predicates.rs +++ b/src/optimizer/rule/pushdown_predicates.rs @@ -1,6 +1,5 @@ use crate::catalog::ColumnRef; use crate::expression::{BinaryOperator, ScalarExpression}; -use crate::optimizer::core::opt_expr::OptExprNode; use crate::optimizer::core::pattern::Pattern; use crate::optimizer::core::pattern::PatternChildrenPredicate; use crate::optimizer::core::rule::Rule; @@ -177,23 +176,15 @@ impl Rule for PushPredicateThroughJoin { } if let Some(left_op) = new_ops.0 { - graph.add_node( - child_id, - Some(join_childs[0]), - OptExprNode::OperatorRef(left_op), - ); + graph.add_node(child_id, Some(join_childs[0]), left_op); } if let Some(right_op) = new_ops.1 { - graph.add_node( - child_id, - Some(join_childs[1]), - OptExprNode::OperatorRef(right_op), - ); + graph.add_node(child_id, Some(join_childs[1]), right_op); } if let Some(common_op) = new_ops.2 { - graph.replace_node(node_id, OptExprNode::OperatorRef(common_op)); + graph.replace_node(node_id, common_op); } else { graph.remove_node(node_id, false); } @@ -203,7 +194,7 @@ impl Rule for PushPredicateThroughJoin { } } -pub struct PushPredicateIntoScan {} +pub struct PushPredicateIntoScan; impl Rule for PushPredicateIntoScan { fn pattern(&self) -> &Pattern { @@ -232,10 +223,7 @@ impl Rule for PushPredicateIntoScan { // The constant expression extracted in prewhere is used to // reduce the data scanning range and cannot replace the role of Filter. 
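 // e.g. for `where c1 > 5 and c2 > 1` (an illustrative predicate, not from the
 // original comment): only the `c1 > 5` bound can narrow the index scan, while
 // the Filter above still evaluates the full predicate on each returned tuple.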
diff --git a/src/optimizer/rule/pushdown_predicates.rs b/src/optimizer/rule/pushdown_predicates.rs
index c3268eeb..d5132148 100644
--- a/src/optimizer/rule/pushdown_predicates.rs
+++ b/src/optimizer/rule/pushdown_predicates.rs
@@ -1,6 +1,5 @@
 use crate::catalog::ColumnRef;
 use crate::expression::{BinaryOperator, ScalarExpression};
-use crate::optimizer::core::opt_expr::OptExprNode;
 use crate::optimizer::core::pattern::Pattern;
 use crate::optimizer::core::pattern::PatternChildrenPredicate;
 use crate::optimizer::core::rule::Rule;
@@ -177,23 +176,15 @@ impl Rule for PushPredicateThroughJoin {
             }

             if let Some(left_op) = new_ops.0 {
-                graph.add_node(
-                    child_id,
-                    Some(join_childs[0]),
-                    OptExprNode::OperatorRef(left_op),
-                );
+                graph.add_node(child_id, Some(join_childs[0]), left_op);
             }

             if let Some(right_op) = new_ops.1 {
-                graph.add_node(
-                    child_id,
-                    Some(join_childs[1]),
-                    OptExprNode::OperatorRef(right_op),
-                );
+                graph.add_node(child_id, Some(join_childs[1]), right_op);
             }

             if let Some(common_op) = new_ops.2 {
-                graph.replace_node(node_id, OptExprNode::OperatorRef(common_op));
+                graph.replace_node(node_id, common_op);
             } else {
                 graph.remove_node(node_id, false);
             }
@@ -203,7 +194,7 @@
     }
 }

-pub struct PushPredicateIntoScan {}
+pub struct PushPredicateIntoScan;

 impl Rule for PushPredicateIntoScan {
     fn pattern(&self) -> &Pattern {
@@ -232,10 +223,7 @@
                 // The constant expression extracted in prewhere is used to
                 // reduce the data scanning range and cannot replace the role of Filter.
-                graph.replace_node(
-                    child_id,
-                    OptExprNode::OperatorRef(Operator::Scan(scan_by_index)),
-                );
+                graph.replace_node(child_id, Operator::Scan(scan_by_index));

                 return Ok(());
             }
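
`PushPredicateThroughJoin` fills `new_ops` by splitting the filter's conjuncts into left-only, right-only, and mixed buckets; only the single-side buckets may sink below the join. A standalone sketch of that classification, with a toy `Expr` type rather than the crate's `ScalarExpression`:

```rust
use std::collections::HashSet;

#[derive(Clone, Debug)]
enum Expr {
    Column(String),
    Gt(Box<Expr>, i64),
}

fn referenced_columns(expr: &Expr, out: &mut HashSet<String>) {
    match expr {
        Expr::Column(name) => {
            out.insert(name.clone());
        }
        Expr::Gt(inner, _) => referenced_columns(inner, out),
    }
}

/// Bucket each conjunct: left-only and right-only conjuncts can be pushed
/// below the join; mixed ones must stay in the join-level filter.
fn classify(
    conjuncts: Vec<Expr>,
    left_cols: &HashSet<String>,
    right_cols: &HashSet<String>,
) -> (Vec<Expr>, Vec<Expr>, Vec<Expr>) {
    let (mut left, mut right, mut common) = (vec![], vec![], vec![]);
    for expr in conjuncts {
        let mut used = HashSet::new();
        referenced_columns(&expr, &mut used);
        if used.is_subset(left_cols) {
            left.push(expr);
        } else if used.is_subset(right_cols) {
            right.push(expr);
        } else {
            common.push(expr);
        }
    }
    (left, right, common)
}

fn main() {
    let left_cols: HashSet<String> = ["t1.c1".to_string()].into();
    let right_cols: HashSet<String> = ["t2.c1".to_string()].into();

    let conjuncts = vec![
        Expr::Gt(Box::new(Expr::Column("t1.c1".into())), 1),
        Expr::Gt(Box::new(Expr::Column("t2.c1".into())), 2),
    ];
    let (left, right, common) = classify(conjuncts, &left_cols, &right_cols);
    assert_eq!((left.len(), right.len(), common.len()), (1, 1, 0));
}
```
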
diff --git a/src/optimizer/rule/simplification.rs b/src/optimizer/rule/simplification.rs
index fc5e324d..3f004451 100644
--- a/src/optimizer/rule/simplification.rs
+++ b/src/optimizer/rule/simplification.rs
@@ -1,11 +1,17 @@
-use crate::optimizer::core::opt_expr::OptExprNode;
 use crate::optimizer::core::pattern::{Pattern, PatternChildrenPredicate};
 use crate::optimizer::core::rule::Rule;
 use crate::optimizer::heuristic::graph::{HepGraph, HepNodeId};
 use crate::optimizer::OptimizerError;
+use crate::planner::operator::join::JoinCondition;
 use crate::planner::operator::Operator;
 use lazy_static::lazy_static;

 lazy_static! {
+    static ref CONSTANT_CALCULATION_RULE: Pattern = {
+        Pattern {
+            predicate: |_| true,
+            children: PatternChildrenPredicate::None,
+        }
+    };
     static ref SIMPLIFY_FILTER_RULE: Pattern = {
         Pattern {
             predicate: |op| matches!(op, Operator::Filter(_)),
@@ -17,6 +23,72 @@ lazy_static! {
     };
 }

+#[derive(Copy, Clone)]
+pub struct ConstantCalculation;
+
+impl ConstantCalculation {
+    fn _apply(node_id: HepNodeId, graph: &mut HepGraph) -> Result<(), OptimizerError> {
+        let operator = graph.operator_mut(node_id);
+
+        match operator {
+            Operator::Aggregate(op) => {
+                for expr in op.agg_calls.iter_mut().chain(op.groupby_exprs.iter_mut()) {
+                    expr.constant_calculation()?;
+                }
+            }
+            Operator::Filter(op) => {
+                op.predicate.constant_calculation()?;
+            }
+            Operator::Join(op) => {
+                if let JoinCondition::On { on, filter } = &mut op.on {
+                    for (left_expr, right_expr) in on {
+                        left_expr.constant_calculation()?;
+                        right_expr.constant_calculation()?;
+                    }
+                    if let Some(expr) = filter {
+                        expr.constant_calculation()?;
+                    }
+                }
+            }
+            Operator::Project(op) => {
+                for expr in &mut op.exprs {
+                    expr.constant_calculation()?;
+                }
+            }
+            Operator::Scan(op) => {
+                for expr in &mut op.columns {
+                    expr.constant_calculation()?;
+                }
+            }
+            Operator::Sort(op) => {
+                for field in &mut op.sort_fields {
+                    field.expr.constant_calculation()?;
+                }
+            }
+            _ => (),
+        }
+        for child_id in graph.children_at(node_id) {
+            Self::_apply(child_id, graph)?;
+        }
+
+        Ok(())
+    }
+}
+
+impl Rule for ConstantCalculation {
+    fn pattern(&self) -> &Pattern {
+        &CONSTANT_CALCULATION_RULE
+    }
+
+    fn apply(&self, node_id: HepNodeId, graph: &mut HepGraph) -> Result<(), OptimizerError> {
+        Self::_apply(node_id, graph)?;
+        // mark changed to skip this rule batch
+        graph.version += 1;
+
+        Ok(())
+    }
+}
+
 #[derive(Copy, Clone)]
 pub struct SimplifyFilter;
@@ -28,11 +100,9 @@ impl Rule for SimplifyFilter {
     fn apply(&self, node_id: HepNodeId, graph: &mut HepGraph) -> Result<(), OptimizerError> {
         if let Operator::Filter(mut filter_op) = graph.operator(node_id).clone() {
             filter_op.predicate.simplify()?;
+            filter_op.predicate.constant_calculation()?;

-            graph.replace_node(
-                node_id,
-                OptExprNode::OperatorRef(Operator::Filter(filter_op)),
-            )
+            graph.replace_node(node_id, Operator::Filter(filter_op))
         }

         Ok(())
@@ -42,7 +112,7 @@
 #[cfg(test)]
 mod test {
     use crate::binder::test::select_sql_run;
-    use crate::catalog::{ColumnCatalog, ColumnDesc};
+    use crate::catalog::{ColumnCatalog, ColumnDesc, ColumnSummary};
     use crate::db::DatabaseError;
     use crate::expression::simplify::ConstantBinary;
     use crate::expression::{BinaryOperator, ScalarExpression, UnaryOperator};
@@ -57,6 +127,45 @@ mod test {
     use std::collections::Bound;
     use std::sync::Arc;

+    #[tokio::test]
+    async fn test_constant_calculation_omitted() -> Result<(), DatabaseError> {
+        // (2 + (-1)) < -(c1 + 1)
+        let plan =
+            select_sql_run("select c1 + (2 + 1), 2 + 1 from t1 where (2 + (-1)) < -(c1 + 1)")
+                .await?;
+
+        let best_plan = HepOptimizer::new(plan)
+            .batch(
+                "test_simplification".to_string(),
+                HepBatchStrategy::once_topdown(),
+                vec![RuleImpl::SimplifyFilter, RuleImpl::ConstantCalculation],
+            )
+            .find_best()?;
+        if let Operator::Project(project_op) = best_plan.clone().operator {
+            let constant_expr = ScalarExpression::Constant(Arc::new(DataValue::Int32(Some(3))));
+            if let ScalarExpression::Binary { right_expr, .. } = &project_op.exprs[0] {
+                assert_eq!(right_expr.as_ref(), &constant_expr);
+            } else {
+                unreachable!();
+            }
+            assert_eq!(&project_op.exprs[1], &constant_expr);
+        } else {
+            unreachable!();
+        }
+        if let Operator::Filter(filter_op) = best_plan.childrens[0].clone().operator {
+            let column_binary = filter_op.predicate.convert_binary(&0).unwrap();
+            let final_binary = ConstantBinary::Scope {
+                min: Bound::Unbounded,
+                max: Bound::Excluded(Arc::new(DataValue::Int32(Some(-2)))),
+            };
+            assert_eq!(column_binary, Some(final_binary));
+        } else {
+            unreachable!();
+        }
+
+        Ok(())
+    }
+
     #[tokio::test]
     async fn test_simplify_filter_single_column() -> Result<(), DatabaseError> {
         // c1 + 1 < -1 => c1 < -2
@@ -137,9 +246,11 @@ mod test {
             .find_best()?;
         if let Operator::Filter(filter_op) = best_plan.childrens[0].clone().operator {
             let c1_col = ColumnCatalog {
-                id: Some(0),
-                name: "c1".to_string(),
-                table_name: Some(Arc::new("t1".to_string())),
+                summary: ColumnSummary {
+                    id: Some(0),
+                    name: "c1".to_string(),
+                    table_name: Some(Arc::new("t1".to_string())),
+                },
                 nullable: false,
                 desc: ColumnDesc {
                     column_datatype: LogicalType::Integer,
@@ -149,9 +260,11 @@ mod test {
                 ref_expr: None,
             };
             let c2_col = ColumnCatalog {
-                id: Some(1),
-                name: "c2".to_string(),
-                table_name: Some(Arc::new("t1".to_string())),
+                summary: ColumnSummary {
+                    id: Some(1),
+                    name: "c2".to_string(),
+                    table_name: Some(Arc::new("t1".to_string())),
+                },
                 nullable: false,
                 desc: ColumnDesc {
                     column_datatype: LogicalType::Integer,
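
The `constant_calculation` pass invoked throughout `ConstantCalculation::_apply` is not itself part of this diff; conceptually it is bottom-up constant folding, which is why `2 + 1` in the projection of the test above collapses to `3`. A self-contained sketch over a toy expression type, not the crate's `ScalarExpression`/`DataValue`:

```rust
#[derive(Debug, PartialEq, Clone)]
enum Expr {
    Const(i32),
    Column(String),
    Add(Box<Expr>, Box<Expr>),
    Neg(Box<Expr>),
}

// Fold children first, then collapse this node if all operands are constants.
fn fold(expr: &mut Expr) {
    let folded = match expr {
        Expr::Add(l, r) => {
            fold(l);
            fold(r);
            if let (Expr::Const(a), Expr::Const(b)) = (&**l, &**r) {
                Some(Expr::Const(a + b))
            } else {
                None
            }
        }
        Expr::Neg(inner) => {
            fold(inner);
            if let Expr::Const(v) = &**inner {
                Some(Expr::Const(-v))
            } else {
                None
            }
        }
        Expr::Const(_) | Expr::Column(_) => None,
    };
    if let Some(value) = folded {
        *expr = value;
    }
}

fn main() {
    // c1 + (2 + 1)  =>  c1 + 3, mirroring the projection in the test above.
    let mut expr = Expr::Add(
        Box::new(Expr::Column("c1".into())),
        Box::new(Expr::Add(Box::new(Expr::Const(2)), Box::new(Expr::Const(1)))),
    );
    fold(&mut expr);
    assert_eq!(
        expr,
        Expr::Add(Box::new(Expr::Column("c1".into())), Box::new(Expr::Const(3)))
    );
}
```
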
diff --git a/src/planner/operator/mod.rs b/src/planner/operator/mod.rs
index 4be5b0af..633012db 100644
--- a/src/planner/operator/mod.rs
+++ b/src/planner/operator/mod.rs
@@ -67,7 +67,7 @@ impl Operator {
     pub fn project_input_refs(&self) -> Vec<ScalarExpression> {
         match self {
             Operator::Project(op) => op
-                .columns
+                .exprs
                 .iter()
                 .map(ScalarExpression::unpack_alias)
                 .filter(|expr| matches!(expr, ScalarExpression::InputRef { .. }))
@@ -125,7 +125,7 @@ impl Operator {
                 exprs
             }
             Operator::Project(op) => op
-                .columns
+                .exprs
                 .iter()
                 .flat_map(|expr| expr.referenced_columns())
                 .collect_vec(),
diff --git a/src/planner/operator/project.rs b/src/planner/operator/project.rs
index 9d345e5f..ca7811e9 100644
--- a/src/planner/operator/project.rs
+++ b/src/planner/operator/project.rs
@@ -2,5 +2,5 @@ use crate::expression::ScalarExpression;

 #[derive(Debug, PartialEq, Clone)]
 pub struct ProjectOperator {
-    pub columns: Vec<ScalarExpression>,
+    pub exprs: Vec<ScalarExpression>,
 }
diff --git a/src/storage/kip.rs b/src/storage/kip.rs
index b6fffcec..73b61379 100644
--- a/src/storage/kip.rs
+++ b/src/storage/kip.rs
@@ -450,11 +450,11 @@ impl KipTransaction {
             .into_iter()
             .filter(|col| col.desc.is_unique)
         {
-            if let Some(col_id) = col.id {
+            if let Some(col_id) = col.id() {
                 let meta = IndexMeta {
                     id: 0,
                     column_ids: vec![col_id],
-                    name: format!("uk_{}", col.name),
+                    name: format!("uk_{}", col.name()),
                     is_unique: true,
                 };
                 let meta_ref = table.add_index_meta(meta);
@@ -584,6 +584,7 @@ mod test {
             vec![ScalarExpression::InputRef {
                 index: 0,
                 ty: LogicalType::Integer,
+                ref_columns: vec![],
             }],
         )?;
diff --git a/src/storage/table_codec.rs b/src/storage/table_codec.rs
index d03beb01..483eab2e 100644
--- a/src/storage/table_codec.rs
+++ b/src/storage/table_codec.rs
@@ -211,10 +211,10 @@ impl TableCodec {
     /// Tips: the `0` for bound range
     pub fn encode_column(col: &ColumnCatalog) -> Result<(Bytes, Bytes), TypeError> {
         let bytes = bincode::serialize(col)?;
-        let mut key_prefix = Self::key_prefix(CodecType::Column, col.table_name.as_ref().unwrap());
+        let mut key_prefix = Self::key_prefix(CodecType::Column, &col.table_name().unwrap());
         key_prefix.push(BOUND_MIN_TAG);

-        key_prefix.append(&mut col.id.unwrap().to_be_bytes().to_vec());
+        key_prefix.append(&mut col.id().unwrap().to_be_bytes().to_vec());

         Ok((Bytes::from(key_prefix), Bytes::from(bytes)))
     }
@@ -222,7 +222,7 @@ impl TableCodec {
     pub fn decode_column(bytes: &[u8]) -> Result<(TableName, ColumnCatalog), TypeError> {
         let column = bincode::deserialize::<ColumnCatalog>(bytes)?;

-        Ok((column.table_name.clone().unwrap(), column))
+        Ok((column.table_name().unwrap(), column))
     }

     /// Key: RootCatalog_0_TableName
@@ -369,8 +369,8 @@ mod tests {
                 None,
             );

-            col.table_name = Some(Arc::new(table_name.to_string()));
-            col.id = Some(col_id as u32);
+            col.summary.table_name = Some(Arc::new(table_name.to_string()));
+            col.summary.id = Some(col_id as u32);

             let (key, _) = TableCodec::encode_column(&col).unwrap();
             key
diff --git a/src/types/tuple.rs b/src/types/tuple.rs
index 0359d569..0e75aaf3 100644
--- a/src/types/tuple.rs
+++ b/src/types/tuple.rs
@@ -101,7 +101,7 @@ pub fn create_table(tuples: &[Tuple]) -> Table {
     let mut header = Vec::new();

     for col in &tuples[0].columns {
-        header.push(Cell::new(col.name.clone()));
+        header.push(Cell::new(col.name().to_string()));
     }
     table.set_header(header);
diff --git a/src/types/tuple_builder.rs b/src/types/tuple_builder.rs
index aa6f0778..8544cd1a 100644
--- a/src/types/tuple_builder.rs
+++ b/src/types/tuple_builder.rs
@@ -53,7 +53,7 @@ impl TupleBuilder {
                 let cast_data_value = data_value.cast(&self.data_types[i])?;
                 self.data_values.push(Arc::new(cast_data_value.clone()));
                 let col = &columns[i];
-                col.id
+                col.id()
                     .map(|col_id| tuple_map.insert(col_id, Arc::new(cast_data_value.clone())));
                 if primary_key_index.is_none() && col.desc.is_primary {
                     primary_key_index = Some(i);
@@ -61,7 +61,7 @@ impl TupleBuilder {
         }

         let primary_col_id = primary_key_index
-            .map(|i| columns[i].id().unwrap())
+            .map(|i| columns[i].id().unwrap())
             .ok_or_else(|| TypeError::PrimaryKeyNotFound)?;

         let tuple_id = tuple_map
diff --git a/tests/slt/filter.slt b/tests/slt/filter.slt
index ccd46821..14646fe2 100644
--- a/tests/slt/filter.slt
+++ b/tests/slt/filter.slt
@@ -68,4 +68,40 @@ select v2 from t where ((v2 >= -8 and -4 >= v1) or (v1 >= 0 and 5 > v2)) and ((v
 1

 statement ok
-drop table t
\ No newline at end of file
+create table t1(id int primary key, v1 varchar)
+
+statement ok
+insert into t1 values (0, 'KipSQL'), (1, 'KipDB'), (2, 'KipBlog'), (3, 'Cool!');
+
+query II
+select * from t1 where v1 like 'Kip%'
+----
+0 KipSQL
+1 KipDB
+2 KipBlog
+
+query II
+select * from t1 where v1 not like 'Kip%'
+----
+3 Cool!
+
+query II
+select * from t1 where v1 like 'KipD_'
+----
+1 KipDB
+
+query II
+select * from t1 where v1 like 'KipS_L'
+----
+0 KipSQL
+
+query II
+select * from t1 where v1 like 'K%L'
+----
+0 KipSQL
+
+statement ok
+drop table t
+
+statement ok
+drop table t1
\ No newline at end of file
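
The new cases above exercise `%` (any run of characters) and `_` (exactly one character). A common way to evaluate LIKE is to compile the pattern into an anchored regex; the sketch below assumes the `regex` crate and illustrates the idea, not KipSQL's actual evaluation path:

```rust
use regex::Regex;

/// Translate a SQL LIKE pattern into an anchored regex:
/// '%' becomes ".*", '_' becomes ".".
fn like_to_regex(pattern: &str) -> Regex {
    let mut out = String::from("^");
    for ch in pattern.chars() {
        match ch {
            '%' => out.push_str(".*"),
            '_' => out.push('.'),
            // Everything else is a literal; escape regex metacharacters.
            _ => out.push_str(&regex::escape(&ch.to_string())),
        }
    }
    out.push('$');
    Regex::new(&out).unwrap()
}

fn main() {
    let re = like_to_regex("Kip%");
    assert!(re.is_match("KipSQL") && re.is_match("KipDB") && re.is_match("KipBlog"));
    assert!(!re.is_match("Cool!"));

    // 'KipD_' requires exactly one trailing character.
    assert!(like_to_regex("KipD_").is_match("KipDB"));
    assert!(!like_to_regex("KipD_").is_match("KipDBX"));
}
```
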
diff --git a/tests/slt/filter_null.slt b/tests/slt/filter_null.slt
index db45712f..1bfd0ad3 100644
--- a/tests/slt/filter_null.slt
+++ b/tests/slt/filter_null.slt
@@ -32,5 +32,17 @@ select * from t where v1 > 1
 2 3 4
 3 4 3

+query II
+select * from t where v1 is null
+----
+1 null 3
+
+query II
+select * from t where v1 is not null
+----
+0 2 4
+2 3 4
+3 4 3
+
 statement ok
 drop table t
\ No newline at end of file
diff --git a/tests/sqllogictest/src/main.rs b/tests/sqllogictest/src/main.rs
index 58776d61..b65bbd0d 100644
--- a/tests/sqllogictest/src/main.rs
+++ b/tests/sqllogictest/src/main.rs
@@ -6,16 +6,15 @@ use tempfile::TempDir;

 #[tokio::main]
 async fn main() {
+    const SLT_PATTERN: &str = "tests/slt/**/*.slt";
+
     let path = Path::new(env!("CARGO_MANIFEST_DIR")).join("..").join("..");
     std::env::set_current_dir(path).unwrap();

     println!("KipSQL Test Start!\n");

-    const SLT_PATTERN: &str = "tests/slt/**/*.slt";
-
-    let slt_files = glob::glob(SLT_PATTERN).expect("failed to find slt files");
-    for slt_file in slt_files {
+    for slt_file in glob::glob(SLT_PATTERN).expect("failed to find slt files") {
         let temp_dir = TempDir::new().expect("unable to create temporary working directory");
-
         let filepath = slt_file
             .expect("failed to read slt file")
             .to_str()