diff --git a/Cargo.toml b/Cargo.toml
index fe3ea020..5351475a 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -30,7 +30,7 @@ integer-encoding = "3.0.4"
strum_macros = "0.24"
ordered-float = "3.0"
petgraph = "0.6.3"
-futures-async-stream = "0.2.6"
+futures-async-stream = "0.2.9"
futures = "0.3.25"
ahash = "0.8.3"
lazy_static = "1.4.0"
@@ -39,6 +39,7 @@ bytes = "1.5.0"
kip_db = "0.1.2-alpha.17"
rust_decimal = "1"
csv = "1"
+regex = "1.10.2"
[dev-dependencies]
tokio-test = "0.4.2"
diff --git a/README.md b/README.md
index 62638a13..501a2494 100755
--- a/README.md
+++ b/README.md
@@ -1,12 +1,27 @@
-# KipSQL
+
+Built by @KipData
+
+██╗ ██╗██╗██████╗ ███████╗ ██████╗ ██╗
+██║ ██╔╝██║██╔══██╗██╔════╝██╔═══██╗██║
+█████╔╝ ██║██████╔╝███████╗██║ ██║██║
+██╔═██╗ ██║██╔═══╝ ╚════██║██║▄▄ ██║██║
+ ██║ ██╗██║██║ ███████║╚██████╔╝███████╗
+ ╚═╝ ╚═╝╚═╝╚═╝ ╚══════╝ ╚══▀▀═╝ ╚══════╝
+-----------------------------------
+Embedded SQL DBMS
+
+
+
+### Architecture
+Welcome to our website, powered by KipSQL:
+**http://www.kipdata.site/**
> Lightweight SQL calculation engine, as the SQL layer of KipDB, implemented with TalentPlan's TinySQL as the reference standard
-### Architecture
+
![architecture](./static/images/architecture.png)
### Get Started
-#### 组件引入
``` toml
kip-sql = "0.0.1-alpha.0"
```
@@ -79,6 +94,12 @@ implement_from_tuple!(Post, (
- not null
- null
- unique
+ - primary key
+- SQL where options
+ - is null
+ - is not null
+ - like
+ - not like
- Supports index type
- Unique Index
- Supports multiple primary key types
diff --git a/rust-toolchain b/rust-toolchain
index c34ab853..07ade694 100644
--- a/rust-toolchain
+++ b/rust-toolchain
@@ -1 +1 @@
-nightly-2023-10-13
\ No newline at end of file
+nightly
\ No newline at end of file
diff --git a/src/binder/aggregate.rs b/src/binder/aggregate.rs
index 9e65e0ea..b8523ac7 100644
--- a/src/binder/aggregate.rs
+++ b/src/binder/aggregate.rs
@@ -92,6 +92,8 @@ impl<'a, T: Transaction> Binder<'a, T> {
expr: &mut ScalarExpression,
is_select: bool,
) -> Result<(), BindError> {
+ let ref_columns = expr.referenced_columns();
+
match expr {
ScalarExpression::AggCall {
ty: return_type, ..
@@ -99,7 +101,11 @@ impl<'a, T: Transaction> Binder<'a, T> {
let ty = return_type.clone();
if is_select {
let index = self.context.input_ref_index(InputRefType::AggCall);
- let input_ref = ScalarExpression::InputRef { index, ty };
+ let input_ref = ScalarExpression::InputRef {
+ index,
+ ty,
+ ref_columns,
+ };
match std::mem::replace(expr, input_ref) {
ScalarExpression::AggCall {
kind,
@@ -124,14 +130,21 @@ impl<'a, T: Transaction> Binder<'a, T> {
.find_position(|agg_expr| agg_expr == &expr)
.ok_or_else(|| BindError::AggMiss(format!("{:?}", expr)))?;
- let _ = std::mem::replace(expr, ScalarExpression::InputRef { index, ty });
+ let _ = std::mem::replace(
+ expr,
+ ScalarExpression::InputRef {
+ index,
+ ty,
+ ref_columns,
+ },
+ );
}
}
ScalarExpression::TypeCast { expr, .. } => {
self.visit_column_agg_expr(expr, is_select)?
}
- ScalarExpression::IsNull { expr } => self.visit_column_agg_expr(expr, is_select)?,
+ ScalarExpression::IsNull { expr, .. } => self.visit_column_agg_expr(expr, is_select)?,
ScalarExpression::Unary { expr, .. } => self.visit_column_agg_expr(expr, is_select)?,
ScalarExpression::Alias { expr, .. } => self.visit_column_agg_expr(expr, is_select)?,
ScalarExpression::Binary {
@@ -228,6 +241,7 @@ impl<'a, T: Transaction> Binder<'a, T> {
}) {
let index = self.context.input_ref_index(InputRefType::GroupBy);
let mut select_item = &mut select_list[i];
+ let ref_columns = select_item.referenced_columns();
let return_type = select_item.return_type();
self.context.group_by_exprs.push(std::mem::replace(
@@ -235,6 +249,7 @@ impl<'a, T: Transaction> Binder<'a, T> {
ScalarExpression::InputRef {
index,
ty: return_type,
+ ref_columns,
},
));
return;
@@ -243,6 +258,8 @@ impl<'a, T: Transaction> Binder<'a, T> {
if let Some(i) = select_list.iter().position(|column| column == expr) {
let expr = &mut select_list[i];
+ let ref_columns = expr.referenced_columns();
+
match expr {
ScalarExpression::Constant(_) | ScalarExpression::ColumnRef { .. } => {
self.context.group_by_exprs.push(expr.clone())
@@ -255,6 +272,7 @@ impl<'a, T: Transaction> Binder<'a, T> {
ScalarExpression::InputRef {
index,
ty: expr.return_type(),
+ ref_columns,
},
))
}
@@ -300,7 +318,7 @@ impl<'a, T: Transaction> Binder<'a, T> {
}
ScalarExpression::TypeCast { expr, .. } => self.validate_having_orderby(expr),
- ScalarExpression::IsNull { expr } => self.validate_having_orderby(expr),
+ ScalarExpression::IsNull { expr, .. } => self.validate_having_orderby(expr),
ScalarExpression::Unary { expr, .. } => self.validate_having_orderby(expr),
ScalarExpression::Binary {
left_expr,
diff --git a/src/binder/create_table.rs b/src/binder/create_table.rs
index bba6497f..65aeb8f3 100644
--- a/src/binder/create_table.rs
+++ b/src/binder/create_table.rs
@@ -80,13 +80,13 @@ mod tests {
match plan1.operator {
Operator::CreateTable(op) => {
assert_eq!(op.table_name, Arc::new("t1".to_string()));
- assert_eq!(op.columns[0].name, "id".to_string());
+ assert_eq!(op.columns[0].name(), "id");
assert_eq!(op.columns[0].nullable, false);
assert_eq!(
op.columns[0].desc,
ColumnDesc::new(LogicalType::Integer, true, false)
);
- assert_eq!(op.columns[1].name, "name".to_string());
+ assert_eq!(op.columns[1].name(), "name");
assert_eq!(op.columns[1].nullable, true);
assert_eq!(
op.columns[1].desc,
diff --git a/src/binder/expr.rs b/src/binder/expr.rs
index 628a1db6..01bd21c2 100644
--- a/src/binder/expr.rs
+++ b/src/binder/expr.rs
@@ -1,4 +1,5 @@
use crate::binder::BindError;
+use crate::expression;
use crate::expression::agg::AggKind;
use itertools::Itertools;
use sqlparser::ast::{
@@ -25,13 +26,41 @@ impl<'a, T: Transaction> Binder<'a, T> {
Expr::Function(func) => self.bind_agg_call(func),
Expr::Nested(expr) => self.bind_expr(expr),
Expr::UnaryOp { expr, op } => self.bind_unary_op_internal(expr, op),
- Expr::IsNull(expr) => self.bind_is_null(expr),
+ Expr::Like {
+ negated,
+ expr,
+ pattern,
+ ..
+ } => self.bind_like(*negated, expr, pattern),
+ Expr::IsNull(expr) => self.bind_is_null(expr, false),
+ Expr::IsNotNull(expr) => self.bind_is_null(expr, true),
_ => {
todo!()
}
}
}
+ pub fn bind_like(
+ &mut self,
+ negated: bool,
+ expr: &Expr,
+ pattern: &Expr,
+ ) -> Result {
+ let left_expr = Box::new(self.bind_expr(expr)?);
+ let right_expr = Box::new(self.bind_expr(pattern)?);
+ let op = if negated {
+ expression::BinaryOperator::NotLike
+ } else {
+ expression::BinaryOperator::Like
+ };
+ Ok(ScalarExpression::Binary {
+ op,
+ left_expr,
+ right_expr,
+ ty: LogicalType::Boolean,
+ })
+ }
+
pub fn bind_column_ref_from_identifiers(
&mut self,
idents: &[Ident],
@@ -199,8 +228,9 @@ impl<'a, T: Transaction> Binder<'a, T> {
})
}
- fn bind_is_null(&mut self, expr: &Expr) -> Result {
+ fn bind_is_null(&mut self, expr: &Expr, negated: bool) -> Result {
Ok(ScalarExpression::IsNull {
+ negated,
expr: Box::new(self.bind_expr(expr)?),
})
}
diff --git a/src/binder/select.rs b/src/binder/select.rs
index 2001bd48..cb8868d8 100644
--- a/src/binder/select.rs
+++ b/src/binder/select.rs
@@ -343,9 +343,7 @@ impl<'a, T: Transaction> Binder<'a, T> {
select_list: Vec,
) -> LogicalPlan {
LogicalPlan {
- operator: Operator::Project(ProjectOperator {
- columns: select_list,
- }),
+ operator: Operator::Project(ProjectOperator { exprs: select_list }),
childrens: vec![children],
}
}
@@ -431,7 +429,8 @@ impl<'a, T: Transaction> Binder<'a, T> {
for column in select_items {
if let ScalarExpression::ColumnRef(col) = column {
- if let Some(nullable) = table_force_nullable.get(col.table_name.as_ref().unwrap()) {
+ if let Some(nullable) = table_force_nullable.get(col.table_name().as_ref().unwrap())
+ {
let mut new_col = ColumnCatalog::clone(col);
new_col.nullable = *nullable;
@@ -504,12 +503,12 @@ impl<'a, T: Transaction> Binder<'a, T> {
// example: foo = bar
(ScalarExpression::ColumnRef(l), ScalarExpression::ColumnRef(r)) => {
// reorder left and right joins keys to pattern: (left, right)
- if left_schema.contains_column(&l.name)
- && right_schema.contains_column(&r.name)
+ if left_schema.contains_column(l.name())
+ && right_schema.contains_column(r.name())
{
accum.push((left, right));
- } else if left_schema.contains_column(&r.name)
- && right_schema.contains_column(&l.name)
+ } else if left_schema.contains_column(r.name())
+ && right_schema.contains_column(l.name())
{
accum.push((right, left));
} else {
diff --git a/src/catalog/column.rs b/src/catalog/column.rs
index c3206160..4e73de6c 100644
--- a/src/catalog/column.rs
+++ b/src/catalog/column.rs
@@ -2,6 +2,7 @@ use crate::catalog::TableName;
use crate::expression::ScalarExpression;
use serde::{Deserialize, Serialize};
use sqlparser::ast::{ColumnDef, ColumnOption};
+use std::hash::Hash;
use std::sync::Arc;
use crate::types::{ColumnId, LogicalType};
@@ -10,14 +11,19 @@ pub type ColumnRef = Arc;
#[derive(Debug, Clone, Serialize, Deserialize, Hash, Eq, PartialEq)]
pub struct ColumnCatalog {
- pub id: Option,
- pub name: String,
- pub table_name: Option,
+ pub summary: ColumnSummary,
pub nullable: bool,
pub desc: ColumnDesc,
pub ref_expr: Option,
}
+#[derive(Debug, Clone, Serialize, Deserialize, Hash, Eq, PartialEq)]
+pub struct ColumnSummary {
+ pub id: Option,
+ pub name: String,
+ pub table_name: Option,
+}
+
impl ColumnCatalog {
pub(crate) fn new(
column_name: String,
@@ -26,9 +32,11 @@ impl ColumnCatalog {
ref_expr: Option,
) -> ColumnCatalog {
ColumnCatalog {
- id: None,
- name: column_name,
- table_name: None,
+ summary: ColumnSummary {
+ id: None,
+ name: column_name,
+ table_name: None,
+ },
nullable,
desc: column_desc,
ref_expr,
@@ -37,20 +45,39 @@ impl ColumnCatalog {
pub(crate) fn new_dummy(column_name: String) -> ColumnCatalog {
ColumnCatalog {
- id: Some(0),
- name: column_name,
- table_name: None,
+ summary: ColumnSummary {
+ id: Some(0),
+ name: column_name,
+ table_name: None,
+ },
nullable: false,
desc: ColumnDesc::new(LogicalType::Varchar(None), false, false),
ref_expr: None,
}
}
+ pub(crate) fn summary(&self) -> &ColumnSummary {
+ &self.summary
+ }
+
+ pub(crate) fn id(&self) -> Option {
+ self.summary.id
+ }
+
+ pub(crate) fn table_name(&self) -> Option {
+ self.summary.table_name.clone()
+ }
+
+ pub(crate) fn name(&self) -> &str {
+ &self.summary.name
+ }
+
pub(crate) fn datatype(&self) -> &LogicalType {
&self.desc.column_datatype
}
- pub fn desc(&self) -> &ColumnDesc {
+ #[allow(dead_code)]
+ pub(crate) fn desc(&self) -> &ColumnDesc {
&self.desc
}
}
diff --git a/src/catalog/table.rs b/src/catalog/table.rs
index 215e2438..92268f61 100644
--- a/src/catalog/table.rs
+++ b/src/catalog/table.rs
@@ -38,7 +38,7 @@ impl TableCatalog {
self.columns.get(id)
}
- pub(crate) fn contains_column(&self, name: &String) -> bool {
+ pub(crate) fn contains_column(&self, name: &str) -> bool {
self.column_idxs.contains_key(name)
}
@@ -55,15 +55,15 @@ impl TableCatalog {
/// Add a column to the table catalog.
pub(crate) fn add_column(&mut self, mut col: ColumnCatalog) -> Result {
- if self.column_idxs.contains_key(&col.name) {
- return Err(CatalogError::Duplicated("column", col.name.clone()));
+ if self.column_idxs.contains_key(col.name()) {
+ return Err(CatalogError::Duplicated("column", col.name().to_string()));
}
let col_id = self.columns.len() as u32;
- col.id = Some(col_id);
- col.table_name = Some(self.name.clone());
- self.column_idxs.insert(col.name.clone(), col_id);
+ col.summary.id = Some(col_id);
+ col.summary.table_name = Some(self.name.clone());
+ self.column_idxs.insert(col.name().to_string(), col_id);
self.columns.insert(col_id, Arc::new(col));
Ok(col_id)
@@ -148,11 +148,11 @@ mod tests {
assert!(col_a_id < col_b_id);
let column_catalog = table_catalog.get_column_by_id(&col_a_id).unwrap();
- assert_eq!(column_catalog.name, "a");
+ assert_eq!(column_catalog.name(), "a");
assert_eq!(*column_catalog.datatype(), LogicalType::Integer,);
let column_catalog = table_catalog.get_column_by_id(&col_b_id).unwrap();
- assert_eq!(column_catalog.name, "b");
+ assert_eq!(column_catalog.name(), "b");
assert_eq!(*column_catalog.datatype(), LogicalType::Boolean,);
}
}
diff --git a/src/db.rs b/src/db.rs
index 813168bf..1b94a885 100644
--- a/src/db.rs
+++ b/src/db.rs
@@ -44,7 +44,6 @@ impl Database {
if stmts.is_empty() {
return Ok(vec![]);
}
-
let binder = Binder::new(BinderContext::new(&transaction));
/// Build a logical plan.
@@ -71,10 +70,15 @@ impl Database {
fn default_optimizer(source_plan: LogicalPlan) -> HepOptimizer {
HepOptimizer::new(source_plan)
+ .batch(
+ "Column Pruning".to_string(),
+ HepBatchStrategy::once_topdown(),
+ vec![RuleImpl::ColumnPruning],
+ )
.batch(
"Simplify Filter".to_string(),
HepBatchStrategy::fix_point_topdown(10),
- vec![RuleImpl::SimplifyFilter],
+ vec![RuleImpl::SimplifyFilter, RuleImpl::ConstantCalculation],
)
.batch(
"Predicate Pushdown".to_string(),
@@ -89,14 +93,6 @@ impl Database {
HepBatchStrategy::fix_point_topdown(10),
vec![RuleImpl::CollapseProject, RuleImpl::CombineFilter],
)
- .batch(
- "Column Pruning".to_string(),
- HepBatchStrategy::fix_point_topdown(10),
- vec![
- RuleImpl::PushProjectThroughChild,
- RuleImpl::PushProjectIntoScan,
- ],
- )
.batch(
"Limit Pushdown".to_string(),
HepBatchStrategy::fix_point_topdown(10),
@@ -199,7 +195,7 @@ mod test {
let _ = kipsql
.run("create table t2 (c int primary key, d int unsigned null, e datetime)")
.await?;
- let _ = kipsql.run("insert into t1 (a, b, k, z) values (-99, 1, 1, 'k'), (-1, 2, 2, 'i'), (5, 3, 2, 'p')").await?;
+ let _ = kipsql.run("insert into t1 (a, b, k, z) values (-99, 1, 1, 'k'), (-1, 2, 2, 'i'), (5, 3, 2, 'p'), (29, 4, 2, 'db')").await?;
let _ = kipsql.run("insert into t2 (d, c, e) values (2, 1, '2021-05-20 21:00:00'), (3, 4, '2023-09-10 00:00:00')").await?;
let _ = kipsql
.run("create table t3 (a int primary key, b decimal(4,2))")
@@ -231,6 +227,18 @@ mod test {
let tuples_projection_and_sort = kipsql.run("select * from t1 order by a, b").await?;
println!("{}", create_table(&tuples_projection_and_sort));
+ println!("like t1 1:");
+ let tuples_like_1_t1 = kipsql.run("select * from t1 where z like '%k'").await?;
+ println!("{}", create_table(&tuples_like_1_t1));
+
+ println!("like t1 2:");
+ let tuples_like_2_t1 = kipsql.run("select * from t1 where z like '_b'").await?;
+ println!("{}", create_table(&tuples_like_2_t1));
+
+ println!("not like t1:");
+ let tuples_not_like_t1 = kipsql.run("select * from t1 where z not like '_b'").await?;
+ println!("{}", create_table(&tuples_not_like_t1));
+
println!("limit:");
let tuples_limit = kipsql.run("select * from t1 limit 1 offset 1").await?;
println!("{}", create_table(&tuples_limit));
diff --git a/src/execution/executor/dml/copy_from_file.rs b/src/execution/executor/dml/copy_from_file.rs
index 341fb974..10e68ac5 100644
--- a/src/execution/executor/dml/copy_from_file.rs
+++ b/src/execution/executor/dml/copy_from_file.rs
@@ -111,7 +111,7 @@ fn return_result(size: usize, tx: Sender) -> Result<(), ExecutorError> {
#[cfg(test)]
mod tests {
- use crate::catalog::{ColumnCatalog, ColumnDesc};
+ use crate::catalog::{ColumnCatalog, ColumnDesc, ColumnSummary};
use crate::db::{Database, DatabaseError};
use futures::StreamExt;
use std::io::Write;
@@ -132,25 +132,31 @@ mod tests {
let columns = vec![
Arc::new(ColumnCatalog {
- id: Some(0),
- name: "a".to_string(),
- table_name: None,
+ summary: ColumnSummary {
+ id: Some(0),
+ name: "a".to_string(),
+ table_name: None,
+ },
nullable: false,
desc: ColumnDesc::new(LogicalType::Integer, true, false),
ref_expr: None,
}),
Arc::new(ColumnCatalog {
- id: Some(1),
- name: "b".to_string(),
- table_name: None,
+ summary: ColumnSummary {
+ id: Some(1),
+ name: "b".to_string(),
+ table_name: None,
+ },
nullable: false,
desc: ColumnDesc::new(LogicalType::Float, false, false),
ref_expr: None,
}),
Arc::new(ColumnCatalog {
- id: Some(1),
- name: "c".to_string(),
- table_name: None,
+ summary: ColumnSummary {
+ id: Some(1),
+ name: "c".to_string(),
+ table_name: None,
+ },
nullable: false,
desc: ColumnDesc::new(LogicalType::Varchar(Some(10)), false, false),
ref_expr: None,
diff --git a/src/execution/executor/dml/delete.rs b/src/execution/executor/dml/delete.rs
index d85c2549..a6b60540 100644
--- a/src/execution/executor/dml/delete.rs
+++ b/src/execution/executor/dml/delete.rs
@@ -39,7 +39,7 @@ impl Delete {
col.desc
.is_unique
.then(|| {
- col.id.and_then(|col_id| {
+ col.id().and_then(|col_id| {
table_catalog
.get_unique_index(&col_id)
.map(|index_meta| (i, index_meta.clone()))
diff --git a/src/execution/executor/dml/insert.rs b/src/execution/executor/dml/insert.rs
index a4f5e1a0..34746825 100644
--- a/src/execution/executor/dml/insert.rs
+++ b/src/execution/executor/dml/insert.rs
@@ -62,7 +62,7 @@ impl Insert {
for (i, value) in values.into_iter().enumerate() {
let col = &columns[i];
- if let Some(col_id) = col.id {
+ if let Some(col_id) = col.id() {
tuple_map.insert(col_id, value);
}
}
@@ -70,7 +70,7 @@ impl Insert {
columns
.iter()
.find(|col| col.desc.is_primary)
- .map(|col| col.id.unwrap())
+ .map(|col| col.id().unwrap())
.unwrap()
});
let all_columns = table_catalog.all_columns_with_id();
@@ -87,7 +87,7 @@ impl Insert {
if col.desc.is_unique && !value.is_null() {
unique_values
- .entry(col.id)
+ .entry(col.id())
.or_insert_with(|| vec![])
.push((tuple_id.clone(), value.clone()))
}
diff --git a/src/execution/executor/dml/update.rs b/src/execution/executor/dml/update.rs
index ee48a87a..411753c6 100644
--- a/src/execution/executor/dml/update.rs
+++ b/src/execution/executor/dml/update.rs
@@ -56,7 +56,7 @@ impl Update {
columns, values, ..
} = tuple?;
for i in 0..columns.len() {
- value_map.insert(columns[i].id, values[i].clone());
+ value_map.insert(columns[i].id(), values[i].clone());
}
}
#[for_await]
@@ -65,7 +65,7 @@ impl Update {
let mut is_overwrite = true;
for (i, column) in tuple.columns.iter().enumerate() {
- if let Some(value) = value_map.get(&column.id) {
+ if let Some(value) = value_map.get(&column.id()) {
if column.desc.is_primary {
let old_key = tuple.id.replace(value.clone()).unwrap();
@@ -74,7 +74,7 @@ impl Update {
}
if column.desc.is_unique && value != &tuple.values[i] {
if let Some(index_meta) =
- table_catalog.get_unique_index(&column.id.unwrap())
+ table_catalog.get_unique_index(&column.id().unwrap())
{
let mut index = Index {
id: index_meta.id,
diff --git a/src/execution/executor/dql/projection.rs b/src/execution/executor/dql/projection.rs
index 8285896a..d6c21b92 100644
--- a/src/execution/executor/dql/projection.rs
+++ b/src/execution/executor/dql/projection.rs
@@ -13,11 +13,8 @@ pub struct Projection {
}
impl From<(ProjectOperator, BoxedExecutor)> for Projection {
- fn from((ProjectOperator { columns }, input): (ProjectOperator, BoxedExecutor)) -> Self {
- Projection {
- exprs: columns,
- input,
- }
+ fn from((ProjectOperator { exprs }, input): (ProjectOperator, BoxedExecutor)) -> Self {
+ Projection { exprs, input }
}
}
diff --git a/src/expression/evaluator.rs b/src/expression/evaluator.rs
index c233ed23..9288ec43 100644
--- a/src/expression/evaluator.rs
+++ b/src/expression/evaluator.rs
@@ -16,7 +16,7 @@ impl ScalarExpression {
match &self {
ScalarExpression::Constant(val) => Ok(val.clone()),
ScalarExpression::ColumnRef(col) => {
- let value = Self::eval_with_name(&tuple, &col.name)
+ let value = Self::eval_with_name(&tuple, col.name())
.unwrap_or(&NULL_VALUE)
.clone();
@@ -46,10 +46,13 @@ impl ScalarExpression {
Ok(Arc::new(binary_op(&left, &right, op)?))
}
- ScalarExpression::IsNull { expr } => {
- let value = expr.eval_column(tuple)?;
+ ScalarExpression::IsNull { expr, negated } => {
+ let mut value = expr.eval_column(tuple)?.is_null();
- Ok(Arc::new(DataValue::Boolean(Some(value.is_null()))))
+ if *negated {
+ value = !value;
+ }
+ Ok(Arc::new(DataValue::Boolean(Some(value))))
}
ScalarExpression::Unary { expr, op, .. } => {
let value = expr.eval_column(tuple)?;
@@ -60,11 +63,11 @@ impl ScalarExpression {
}
}
- fn eval_with_name<'a>(tuple: &'a Tuple, name: &String) -> Option<&'a ValueRef> {
+ fn eval_with_name<'a>(tuple: &'a Tuple, name: &str) -> Option<&'a ValueRef> {
tuple
.columns
.iter()
- .find_position(|tul_col| &tul_col.name == name)
+ .find_position(|tul_col| tul_col.name() == name)
.map(|(i, _)| &tuple.values[i])
}
}
diff --git a/src/expression/mod.rs b/src/expression/mod.rs
index 6c617fd8..282649af 100644
--- a/src/expression/mod.rs
+++ b/src/expression/mod.rs
@@ -30,6 +30,7 @@ pub enum ScalarExpression {
InputRef {
index: usize,
ty: LogicalType,
+ ref_columns: Vec,
},
Alias {
expr: Box,
@@ -40,6 +41,7 @@ pub enum ScalarExpression {
ty: LogicalType,
},
IsNull {
+ negated: bool,
expr: Box,
},
Unary {
@@ -70,6 +72,23 @@ impl ScalarExpression {
}
}
+ pub fn has_count_star(&self) -> bool {
+ match self {
+ ScalarExpression::InputRef { ref_columns, .. } => ref_columns.is_empty(),
+ ScalarExpression::Alias { expr, .. } => expr.has_count_star(),
+ ScalarExpression::TypeCast { expr, .. } => expr.has_count_star(),
+ ScalarExpression::IsNull { expr, .. } => expr.has_count_star(),
+ ScalarExpression::Unary { expr, .. } => expr.has_count_star(),
+ ScalarExpression::Binary {
+ left_expr,
+ right_expr,
+ ..
+ } => left_expr.has_count_star() || right_expr.has_count_star(),
+ ScalarExpression::AggCall { args, .. } => args.iter().any(Self::has_count_star),
+ _ => false,
+ }
+ }
+
pub fn nullable(&self) -> bool {
match self {
ScalarExpression::Constant(_) => false,
@@ -77,7 +96,7 @@ impl ScalarExpression {
ScalarExpression::InputRef { .. } => unreachable!(),
ScalarExpression::Alias { expr, .. } => expr.nullable(),
ScalarExpression::TypeCast { expr, .. } => expr.nullable(),
- ScalarExpression::IsNull { expr } => expr.nullable(),
+ ScalarExpression::IsNull { expr, .. } => expr.nullable(),
ScalarExpression::Unary { expr, .. } => expr.nullable(),
ScalarExpression::Binary {
left_expr,
@@ -135,6 +154,9 @@ impl ScalarExpression {
columns_collect(expr, vec)
}
}
+ ScalarExpression::InputRef { ref_columns, .. } => {
+ vec.extend_from_slice(ref_columns);
+ }
_ => (),
}
}
@@ -187,7 +209,7 @@ impl ScalarExpression {
} => {
let args_str = args
.iter()
- .map(|expr| expr.output_columns(tuple).name.clone())
+ .map(|expr| expr.output_columns(tuple).name().to_string())
.join(", ");
let op = |allow_distinct, distinct| {
if allow_distinct && distinct {
@@ -219,9 +241,9 @@ impl ScalarExpression {
} => {
let column_name = format!(
"({} {} {})",
- left_expr.output_columns(tuple).name,
+ left_expr.output_columns(tuple).name(),
op,
- right_expr.output_columns(tuple).name,
+ right_expr.output_columns(tuple).name(),
);
Arc::new(ColumnCatalog::new(
@@ -232,7 +254,7 @@ impl ScalarExpression {
))
}
ScalarExpression::Unary { expr, op, ty } => {
- let column_name = format!("{} {}", op, expr.output_columns(tuple).name,);
+ let column_name = format!("{} {}", op, expr.output_columns(tuple).name());
Arc::new(ColumnCatalog::new(
column_name,
true,
@@ -280,6 +302,8 @@ pub enum BinaryOperator {
Spaceship,
Eq,
NotEq,
+ Like,
+ NotLike,
And,
Or,
@@ -305,6 +329,8 @@ impl fmt::Display for BinaryOperator {
BinaryOperator::And => write!(f, "&&"),
BinaryOperator::Or => write!(f, "||"),
BinaryOperator::Xor => write!(f, "^"),
+ BinaryOperator::Like => write!(f, "like"),
+ BinaryOperator::NotLike => write!(f, "not like"),
}
}
}
diff --git a/src/expression/simplify.rs b/src/expression/simplify.rs
index 25a7be94..53537f01 100644
--- a/src/expression/simplify.rs
+++ b/src/expression/simplify.rs
@@ -266,10 +266,10 @@ struct ReplaceUnary {
impl ScalarExpression {
pub fn exist_column(&self, col_id: &ColumnId) -> bool {
match self {
- ScalarExpression::ColumnRef(col) => col.id == Some(*col_id),
+ ScalarExpression::ColumnRef(col) => col.id() == Some(*col_id),
ScalarExpression::Alias { expr, .. } => expr.exist_column(col_id),
ScalarExpression::TypeCast { expr, .. } => expr.exist_column(col_id),
- ScalarExpression::IsNull { expr } => expr.exist_column(col_id),
+ ScalarExpression::IsNull { expr, .. } => expr.exist_column(col_id),
ScalarExpression::Unary { expr, .. } => expr.exist_column(col_id),
ScalarExpression::Binary {
left_expr,
@@ -287,7 +287,7 @@ impl ScalarExpression {
ScalarExpression::TypeCast { expr, ty, .. } => expr
.unpack_val()
.and_then(|val| DataValue::clone(&val).cast(ty).ok().map(Arc::new)),
- ScalarExpression::IsNull { expr } => {
+ ScalarExpression::IsNull { expr, .. } => {
let is_null = expr.unpack_val().map(|val| val.is_null());
Some(Arc::new(DataValue::Boolean(is_null)))
@@ -338,6 +338,48 @@ impl ScalarExpression {
self._simplify(&mut Vec::new())
}
+ pub fn constant_calculation(&mut self) -> Result<(), TypeError> {
+ match self {
+ ScalarExpression::Unary { expr, op, .. } => {
+ expr.constant_calculation()?;
+
+ if let ScalarExpression::Constant(unary_val) = expr.as_ref() {
+ let value = unary_op(unary_val, op)?;
+ let _ = mem::replace(self, ScalarExpression::Constant(Arc::new(value)));
+ }
+ }
+ ScalarExpression::Binary {
+ left_expr,
+ right_expr,
+ op,
+ ..
+ } => {
+ left_expr.constant_calculation()?;
+ right_expr.constant_calculation()?;
+
+ if let (
+ ScalarExpression::Constant(left_val),
+ ScalarExpression::Constant(right_val),
+ ) = (left_expr.as_ref(), right_expr.as_ref())
+ {
+ let value = binary_op(left_val, right_val, op)?;
+ let _ = mem::replace(self, ScalarExpression::Constant(Arc::new(value)));
+ }
+ }
+ ScalarExpression::Alias { expr, .. } => expr.constant_calculation()?,
+ ScalarExpression::TypeCast { expr, .. } => expr.constant_calculation()?,
+ ScalarExpression::IsNull { expr, .. } => expr.constant_calculation()?,
+ ScalarExpression::AggCall { args, .. } => {
+ for expr in args {
+ expr.constant_calculation()?;
+ }
+ }
+ _ => (),
+ }
+
+ Ok(())
+ }
+
// Tips: Indirect expressions like `ScalarExpression::Alias` will be lost
fn _simplify(&mut self, replaces: &mut Vec) -> Result<(), TypeError> {
match self {
@@ -458,7 +500,6 @@ impl ScalarExpression {
if Self::is_arithmetic(op) {
return Ok(());
}
-
while let Some(replace) = replaces.pop() {
match replace {
Replace::Binary(binary) => Self::fix_binary(binary, left_expr, right_expr, op),
@@ -638,7 +679,7 @@ impl ScalarExpression {
}
ScalarExpression::Alias { expr, .. } => expr.convert_binary(col_id),
ScalarExpression::TypeCast { expr, .. } => expr.convert_binary(col_id),
- ScalarExpression::IsNull { expr } => expr.convert_binary(col_id),
+ ScalarExpression::IsNull { expr, .. } => expr.convert_binary(col_id),
ScalarExpression::Unary { expr, .. } => expr.convert_binary(col_id),
_ => Ok(None),
}
@@ -666,7 +707,7 @@ impl ScalarExpression {
val: ValueRef,
is_flip: bool,
) -> Option {
- if col.id.unwrap() != *col_id {
+ if col.id() != Some(*col_id) {
return None;
}
@@ -706,7 +747,7 @@ impl ScalarExpression {
#[cfg(test)]
mod test {
- use crate::catalog::{ColumnCatalog, ColumnDesc};
+ use crate::catalog::{ColumnCatalog, ColumnDesc, ColumnSummary};
use crate::expression::simplify::ConstantBinary;
use crate::expression::{BinaryOperator, ScalarExpression};
use crate::types::errors::TypeError;
@@ -718,9 +759,11 @@ mod test {
#[test]
fn test_convert_binary_simple() -> Result<(), TypeError> {
let col_1 = Arc::new(ColumnCatalog {
- id: Some(0),
- name: "c1".to_string(),
- table_name: None,
+ summary: ColumnSummary {
+ id: Some(0),
+ name: "c1".to_string(),
+ table_name: None,
+ },
nullable: false,
desc: ColumnDesc {
column_datatype: LogicalType::Integer,
diff --git a/src/expression/value_compute.rs b/src/expression/value_compute.rs
index fe4efdc1..a47d62ec 100644
--- a/src/expression/value_compute.rs
+++ b/src/expression/value_compute.rs
@@ -2,6 +2,7 @@ use crate::expression::{BinaryOperator, UnaryOperator};
use crate::types::errors::TypeError;
use crate::types::value::DataValue;
use crate::types::LogicalType;
+use regex::Regex;
fn unpack_i32(value: DataValue) -> Option {
match value {
@@ -114,6 +115,22 @@ pub fn binary_op(
right: &DataValue,
op: &BinaryOperator,
) -> Result {
+ if matches!(op, BinaryOperator::Like | BinaryOperator::NotLike) {
+ let value_option = unpack_utf8(left.clone().cast(&LogicalType::Varchar(None))?);
+ let pattern_option = unpack_utf8(right.clone().cast(&LogicalType::Varchar(None))?);
+
+ let mut is_match = if let (Some(value), Some(pattern)) = (value_option, pattern_option) {
+        let regex_pattern = format!("^{}$", regex::escape(&pattern).replace('%', ".*").replace('_', "."));
+
+        Regex::new(&regex_pattern).unwrap().is_match(&value)
+ } else {
+ unreachable!("The left and right values calculated by Like cannot be Null values.")
+ };
+ if op == &BinaryOperator::NotLike {
+ is_match = !is_match;
+ }
+ return Ok(DataValue::Boolean(Some(is_match)));
+ }
let unified_type = LogicalType::max_logical_type(&left.logical_type(), &right.logical_type())?;
let value = match &unified_type {
diff --git a/src/lib.rs b/src/lib.rs
index e7d9c23b..a1478ad2 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,11 +1,10 @@
#![feature(error_generic_member_access)]
#![allow(unused_doc_comments)]
#![feature(result_flattening)]
-#![feature(generators)]
+#![feature(coroutines)]
#![feature(iterator_try_collect)]
#![feature(slice_pattern)]
#![feature(bound_map)]
-#![feature(async_fn_in_trait)]
extern crate core;
pub mod binder;
pub mod catalog;
diff --git a/src/marco/mod.rs b/src/marco/mod.rs
index f9ca70e9..1f6ab659 100644
--- a/src/marco/mod.rs
+++ b/src/marco/mod.rs
@@ -31,7 +31,7 @@ macro_rules! implement_from_tuple {
let (idx, _) = tuple.columns
.iter()
.enumerate()
- .find(|(_, col)| &col.name == field_name)?;
+ .find(|(_, col)| col.name() == field_name)?;
DataValue::clone(&tuple.values[idx])
.cast(&ty)
diff --git a/src/optimizer/core/opt_expr.rs b/src/optimizer/core/opt_expr.rs
index 0b7136fc..c29e7b8e 100644
--- a/src/optimizer/core/opt_expr.rs
+++ b/src/optimizer/core/opt_expr.rs
@@ -3,36 +3,17 @@ use crate::planner::LogicalPlan;
use std::fmt::Debug;
pub type OptExprNodeId = usize;
-
-#[derive(Clone, PartialEq)]
-pub enum OptExprNode {
- /// Raw plan node with dummy children.
- OperatorRef(Operator),
- #[allow(dead_code)]
- /// Existing OptExprNode in graph.
- OptExpr(OptExprNodeId),
-}
-
-impl Debug for OptExprNode {
- fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- match self {
- Self::OperatorRef(op) => write!(f, "LogicOperator({:?})", op),
- Self::OptExpr(id) => write!(f, "OptExpr({})", id),
- }
- }
-}
-
#[derive(Clone, Debug)]
pub struct OptExpr {
/// The root of the tree.
- pub root: OptExprNode,
+ pub root: Operator,
/// The root's children expressions.
pub childrens: Vec,
}
impl OptExpr {
#[allow(dead_code)]
- pub fn new(root: OptExprNode, childrens: Vec) -> Self {
+ pub fn new(root: Operator, childrens: Vec) -> Self {
Self { root, childrens }
}
@@ -43,7 +24,7 @@ impl OptExpr {
#[allow(dead_code)]
fn build_opt_expr_internal(input: &LogicalPlan) -> OptExpr {
- let root = OptExprNode::OperatorRef(input.operator.clone());
+ let root = input.operator.clone();
let childrens = input
.childrens
.iter()
@@ -54,22 +35,14 @@ impl OptExpr {
#[allow(dead_code)]
pub fn to_plan_ref(&self) -> LogicalPlan {
- match &self.root {
- OptExprNode::OperatorRef(op) => {
- let childrens = self
- .childrens
- .iter()
- .map(|c| c.to_plan_ref())
- .collect::>();
- LogicalPlan {
- operator: op.clone(),
- childrens,
- }
- }
- OptExprNode::OptExpr(_) => LogicalPlan {
- operator: Operator::Dummy,
- childrens: vec![],
- },
+ let childrens = self
+ .childrens
+ .iter()
+ .map(|c| c.to_plan_ref())
+ .collect::>();
+ LogicalPlan {
+ operator: self.root.clone(),
+ childrens,
}
}
}
diff --git a/src/optimizer/heuristic/batch.rs b/src/optimizer/heuristic/batch.rs
index 81f6ed78..2161bec7 100644
--- a/src/optimizer/heuristic/batch.rs
+++ b/src/optimizer/heuristic/batch.rs
@@ -30,7 +30,6 @@ pub struct HepBatchStrategy {
}
impl HepBatchStrategy {
- #[allow(dead_code)]
pub fn once_topdown() -> Self {
HepBatchStrategy {
max_iteration: 1,
diff --git a/src/optimizer/heuristic/graph.rs b/src/optimizer/heuristic/graph.rs
index 4f4b9e97..16ca4ef4 100644
--- a/src/optimizer/heuristic/graph.rs
+++ b/src/optimizer/heuristic/graph.rs
@@ -1,4 +1,4 @@
-use crate::optimizer::core::opt_expr::{OptExprNode, OptExprNodeId};
+use crate::optimizer::core::opt_expr::OptExprNodeId;
use crate::optimizer::heuristic::batch::HepMatchOrder;
use crate::planner::operator::Operator;
use crate::planner::LogicalPlan;
@@ -12,7 +12,7 @@ pub type HepNodeId = NodeIndex;
#[derive(Debug)]
pub struct HepGraph {
- graph: StableDiGraph,
+ graph: StableDiGraph,
root_index: HepNodeId,
pub version: usize,
}
@@ -20,13 +20,13 @@ pub struct HepGraph {
impl HepGraph {
pub fn new(root: LogicalPlan) -> Self {
fn graph_filling(
- graph: &mut StableDiGraph,
+ graph: &mut StableDiGraph,
LogicalPlan {
operator,
childrens,
}: LogicalPlan,
) -> HepNodeId {
- let index = graph.add_node(OptExprNode::OperatorRef(operator));
+ let index = graph.add_node(operator);
for (order, child) in childrens.into_iter().enumerate() {
let child_index = graph_filling(graph, child);
@@ -36,7 +36,7 @@ impl HepGraph {
index
}
- let mut graph = StableDiGraph::::default();
+ let mut graph = StableDiGraph::::default();
let root_index = graph_filling(&mut graph, root);
@@ -54,7 +54,7 @@ impl HepGraph {
}
#[allow(dead_code)]
- pub fn add_root(&mut self, new_node: OptExprNode) {
+ pub fn add_root(&mut self, new_node: Operator) {
let old_root_id = mem::replace(&mut self.root_index, self.graph.add_node(new_node));
self.graph.add_edge(self.root_index, old_root_id, 0);
@@ -65,7 +65,7 @@ impl HepGraph {
&mut self,
source_id: HepNodeId,
children_option: Option,
- new_node: OptExprNode,
+ new_node: Operator,
) {
let new_index = self.graph.add_node(new_node);
@@ -85,7 +85,7 @@ impl HepGraph {
self.version += 1;
}
- pub fn replace_node(&mut self, source_id: HepNodeId, new_node: OptExprNode) {
+ pub fn replace_node(&mut self, source_id: HepNodeId, new_node: Operator) {
self.graph[source_id] = new_node;
self.version += 1;
}
@@ -97,11 +97,7 @@ impl HepGraph {
self.version += 1;
}
- pub fn remove_node(
- &mut self,
- source_id: HepNodeId,
- with_childrens: bool,
- ) -> Option {
+ pub fn remove_node(&mut self, source_id: HepNodeId, with_childrens: bool) -> Option {
if !with_childrens {
let children_ids = self
.graph
@@ -152,15 +148,16 @@ impl HepGraph {
}
#[allow(dead_code)]
- pub fn node(&self, node_id: HepNodeId) -> Option<&OptExprNode> {
+ pub fn node(&self, node_id: HepNodeId) -> Option<&Operator> {
self.graph.node_weight(node_id)
}
pub fn operator(&self, node_id: HepNodeId) -> &Operator {
- match &self.graph[node_id] {
- OptExprNode::OperatorRef(op) => op,
- OptExprNode::OptExpr(node_id) => self.operator(HepNodeId::new(*node_id)),
- }
+ &self.graph[node_id]
+ }
+
+ pub fn operator_mut(&mut self, node_id: HepNodeId) -> &mut Operator {
+ &mut self.graph[node_id]
}
pub fn to_plan(&self) -> LogicalPlan {
@@ -204,7 +201,6 @@ impl HepGraph {
mod tests {
use crate::binder::test::select_sql_run;
use crate::execution::ExecutorError;
- use crate::optimizer::core::opt_expr::OptExprNode;
use crate::optimizer::heuristic::graph::{HepGraph, HepNodeId};
use crate::planner::operator::Operator;
use petgraph::stable_graph::{EdgeIndex, NodeIndex};
@@ -236,23 +232,11 @@ mod tests {
let plan = select_sql_run("select * from t1 left join t2 on c1 = c3").await?;
let mut graph = HepGraph::new(plan);
- graph.add_node(
- HepNodeId::new(1),
- None,
- OptExprNode::OperatorRef(Operator::Dummy),
- );
+ graph.add_node(HepNodeId::new(1), None, Operator::Dummy);
- graph.add_node(
- HepNodeId::new(1),
- Some(HepNodeId::new(4)),
- OptExprNode::OperatorRef(Operator::Dummy),
- );
+ graph.add_node(HepNodeId::new(1), Some(HepNodeId::new(4)), Operator::Dummy);
- graph.add_node(
- HepNodeId::new(5),
- None,
- OptExprNode::OperatorRef(Operator::Dummy),
- );
+ graph.add_node(HepNodeId::new(5), None, Operator::Dummy);
assert!(graph
.graph
@@ -276,7 +260,7 @@ mod tests {
let plan = select_sql_run("select * from t1 left join t2 on c1 = c3").await?;
let mut graph = HepGraph::new(plan);
- graph.replace_node(HepNodeId::new(1), OptExprNode::OperatorRef(Operator::Dummy));
+ graph.replace_node(HepNodeId::new(1), Operator::Dummy);
assert!(matches!(graph.operator(HepNodeId::new(1)), Operator::Dummy));
@@ -338,7 +322,7 @@ mod tests {
let plan = select_sql_run("select * from t1 left join t2 on c1 = c3").await?;
let mut graph = HepGraph::new(plan);
- graph.add_root(OptExprNode::OperatorRef(Operator::Dummy));
+ graph.add_root(Operator::Dummy);
assert_eq!(graph.graph.edge_count(), 4);
assert!(graph
diff --git a/src/optimizer/rule/column_pruning.rs b/src/optimizer/rule/column_pruning.rs
index f9a657db..0ae347dc 100644
--- a/src/optimizer/rule/column_pruning.rs
+++ b/src/optimizer/rule/column_pruning.rs
@@ -1,191 +1,153 @@
-use crate::catalog::ColumnRef;
+use crate::catalog::{ColumnRef, ColumnSummary};
+use crate::expression::agg::AggKind;
use crate::expression::ScalarExpression;
-use crate::optimizer::core::opt_expr::OptExprNode;
use crate::optimizer::core::pattern::{Pattern, PatternChildrenPredicate};
use crate::optimizer::core::rule::Rule;
use crate::optimizer::heuristic::graph::{HepGraph, HepNodeId};
use crate::optimizer::OptimizerError;
-use crate::planner::operator::aggregate::AggregateOperator;
-use crate::planner::operator::project::ProjectOperator;
use crate::planner::operator::Operator;
-use itertools::Itertools;
+use crate::types::value::DataValue;
+use crate::types::LogicalType;
use lazy_static::lazy_static;
+use std::collections::HashSet;
+use std::sync::Arc;
lazy_static! {
- static ref PUSH_PROJECT_INTO_SCAN_RULE: Pattern = {
+ static ref COLUMN_PRUNING_RULE: Pattern = {
Pattern {
- predicate: |op| matches!(op, Operator::Project(_)),
- children: PatternChildrenPredicate::Predicate(vec![Pattern {
- predicate: |op| matches!(op, Operator::Scan(_)),
- children: PatternChildrenPredicate::None,
- }]),
- }
- };
- static ref PUSH_PROJECT_THROUGH_CHILD_RULE: Pattern = {
- Pattern {
- predicate: |op| matches!(op, Operator::Project(_)),
- children: PatternChildrenPredicate::Predicate(vec![Pattern {
- predicate: |op| !matches!(op, Operator::Scan(_) | Operator::Project(_)),
- children: PatternChildrenPredicate::Predicate(vec![Pattern {
- predicate: |op| !matches!(op, Operator::Project(_)),
- children: PatternChildrenPredicate::None,
- }]),
- }]),
+ predicate: |_| true,
+ children: PatternChildrenPredicate::None,
}
};
}
-#[derive(Copy, Clone)]
-pub struct PushProjectIntoScan;
-
-impl Rule for PushProjectIntoScan {
- fn pattern(&self) -> &Pattern {
- &PUSH_PROJECT_INTO_SCAN_RULE
- }
-
- fn apply(&self, node_id: HepNodeId, graph: &mut HepGraph) -> Result<(), OptimizerError> {
- if let Operator::Project(project_op) = graph.operator(node_id) {
- let child_index = graph.children_at(node_id)[0];
- if let Operator::Scan(scan_op) = graph.operator(child_index) {
- let mut new_scan_op = scan_op.clone();
-
- new_scan_op.columns = project_op
- .columns
- .iter()
- .map(ScalarExpression::unpack_alias)
- .cloned()
- .collect_vec();
-
- graph.remove_node(node_id, false);
- graph.replace_node(
- child_index,
- OptExprNode::OperatorRef(Operator::Scan(new_scan_op)),
- );
- }
- }
-
- Ok(())
- }
-}
-
#[derive(Clone)]
-pub struct PushProjectThroughChild;
+pub struct ColumnPruning;
-impl Rule for PushProjectThroughChild {
- fn pattern(&self) -> &Pattern {
- &PUSH_PROJECT_THROUGH_CHILD_RULE
+impl ColumnPruning {
+ fn clear_exprs(
+ column_references: &mut HashSet,
+ exprs: &mut Vec,
+ ) {
+ exprs.retain(|expr| {
+ expr.referenced_columns()
+ .iter()
+ .any(|column| column_references.contains(column.summary()))
+ })
}
- fn apply(&self, node_id: HepNodeId, graph: &mut HepGraph) -> Result<(), OptimizerError> {
- let node_operator = graph.operator(node_id);
-
- if let Operator::Project(_) = node_operator {
- let child_index = graph.children_at(node_id)[0];
- let node_referenced_columns = node_operator.referenced_columns();
- let child_operator = graph.operator(child_index);
- let child_referenced_columns = child_operator.referenced_columns();
- let op = |col: &ColumnRef| format!("{:?}.{:?}", col.table_name, col.id);
-
- match child_operator {
- // When the aggregate function is a child node,
- // the pushdown will lose the corresponding ColumnRef due to `InputRef`.
- // Therefore, it is necessary to map the InputRef to the corresponding ColumnRef
- // and push it down.
- Operator::Aggregate(AggregateOperator { agg_calls, .. }) => {
- let grandson_id = graph.children_at(child_index)[0];
- let columns = node_operator
- .project_input_refs()
- .iter()
- .filter_map(|expr| {
- if agg_calls.is_empty() {
- return None;
- }
-
- if let ScalarExpression::InputRef { index, .. } = expr {
- agg_calls.get(*index).cloned()
- } else {
- None
- }
+ fn _apply(
+ column_references: &mut HashSet,
+ all_referenced: bool,
+ node_id: HepNodeId,
+ graph: &mut HepGraph,
+ ) {
+ let operator = graph.operator_mut(node_id);
+
+ match operator {
+ Operator::Aggregate(op) => {
+ if !all_referenced {
+ Self::clear_exprs(column_references, &mut op.agg_calls);
+
+ if op.agg_calls.is_empty() && op.groupby_exprs.is_empty() {
+ let value = Arc::new(DataValue::Utf8(Some("*".to_string())));
+                        // a single COUNT(*) is the only aggregate that does not depend on any column;
+                        // all expressions were pruned from the aggregate, so push a COUNT(*) instead
+ op.agg_calls.push(ScalarExpression::AggCall {
+ distinct: false,
+ kind: AggKind::Count,
+ args: vec![ScalarExpression::Constant(value)],
+ ty: LogicalType::Integer,
})
- .map(|expr| expr.referenced_columns())
- .flatten()
- .chain(node_referenced_columns.into_iter())
- .chain(child_referenced_columns.into_iter())
- .unique_by(op)
- .map(|col| ScalarExpression::ColumnRef(col))
- .collect_vec();
-
- Self::add_project_node(graph, child_index, columns, grandson_id);
- }
- Operator::Join(_) => {
- let parent_referenced_columns = node_referenced_columns
- .into_iter()
- .chain(child_referenced_columns.into_iter())
- .unique_by(op)
- .collect_vec();
-
- for grandson_id in graph.children_at(child_index) {
- let grandson_referenced_column =
- graph.operator(grandson_id).referenced_columns();
-
- // for PushLimitThroughJoin
- if grandson_referenced_column.is_empty() {
- return Ok(());
- }
- let grandson_table_name = grandson_referenced_column[0].table_name.clone();
- let columns = parent_referenced_columns
- .iter()
- .filter(|col| col.table_name == grandson_table_name)
- .cloned()
- .map(|col| ScalarExpression::ColumnRef(col))
- .collect_vec();
-
- Self::add_project_node(graph, child_index, columns, grandson_id);
}
}
- _ => {
- let grandson_ids = graph.children_at(child_index);
+ let op_ref_columns = operator.referenced_columns();
- if grandson_ids.is_empty() {
- return Ok(());
+ Self::recollect_apply(op_ref_columns, false, node_id, graph);
+ }
+ Operator::Project(op) => {
+ let has_count_star = op.exprs.iter().any(ScalarExpression::has_count_star);
+ if !has_count_star {
+ if !all_referenced {
+ Self::clear_exprs(column_references, &mut op.exprs);
}
- let grandson_id = grandson_ids[0];
- let mut columns = node_operator.project_input_refs();
- let mut referenced_columns = node_referenced_columns
- .into_iter()
- .chain(child_referenced_columns.into_iter())
- .unique_by(op)
- .map(|col| ScalarExpression::ColumnRef(col))
- .collect_vec();
-
- columns.append(&mut referenced_columns);
+ let op_ref_columns = operator.referenced_columns();
- Self::add_project_node(graph, child_index, columns, grandson_id);
+ Self::recollect_apply(op_ref_columns, false, node_id, graph);
}
}
- }
+ Operator::Sort(_op) => {
+ if !all_referenced {
+                    // TODO: Order Project
+ // https://github.com/duckdb/duckdb/blob/main/src/optimizer/remove_unused_columns.cpp#L174
+ }
+ for child_id in graph.children_at(node_id) {
+ Self::_apply(column_references, true, child_id, graph);
+ }
+ }
+ Operator::Scan(op) => {
+ if !all_referenced {
+ Self::clear_exprs(column_references, &mut op.columns);
+ }
+ }
+ Operator::Limit(_) | Operator::Join(_) | Operator::Filter(_) => {
+ for column in operator.referenced_columns() {
+ column_references.insert(column.summary().clone());
+ }
+ for child_id in graph.children_at(node_id) {
+ Self::_apply(column_references, all_referenced, child_id, graph);
+ }
+ }
+ // Last Operator
+ Operator::Dummy | Operator::Values(_) => (),
+ // DDL Based on Other Plan
+ Operator::Insert(_) | Operator::Update(_) | Operator::Delete(_) => {
+ let op_ref_columns = operator.referenced_columns();
- Ok(())
+ Self::recollect_apply(op_ref_columns, true, graph.children_at(node_id)[0], graph);
+ }
+ // DDL Single Plan
+ Operator::CreateTable(_)
+ | Operator::DropTable(_)
+ | Operator::Truncate(_)
+ | Operator::Show(_)
+ | Operator::CopyFromFile(_)
+ | Operator::CopyToFile(_) => (),
+ }
}
-}
-impl PushProjectThroughChild {
- fn add_project_node(
+ fn recollect_apply(
+ referenced_columns: Vec,
+ all_referenced: bool,
+ node_id: HepNodeId,
graph: &mut HepGraph,
- child_index: HepNodeId,
- columns: Vec,
- grandson_id: HepNodeId,
) {
- if !columns.is_empty() {
- graph.add_node(
- child_index,
- Some(grandson_id),
- OptExprNode::OperatorRef(Operator::Project(ProjectOperator { columns })),
- );
+ for child_id in graph.children_at(node_id) {
+ let mut new_references: HashSet = referenced_columns
+ .iter()
+ .map(|column| column.summary())
+ .cloned()
+ .collect();
+
+ Self::_apply(&mut new_references, all_referenced, child_id, graph);
}
}
}
+impl Rule for ColumnPruning {
+ fn pattern(&self) -> &Pattern {
+ &COLUMN_PRUNING_RULE
+ }
+
+ fn apply(&self, node_id: HepNodeId, graph: &mut HepGraph) -> Result<(), OptimizerError> {
+ Self::_apply(&mut HashSet::new(), true, node_id, graph);
+        // mark the graph as changed so that this rule batch is skipped afterwards
+ graph.version += 1;
+
+ Ok(())
+ }
+}
+
#[cfg(test)]
mod tests {
use crate::binder::test::select_sql_run;
@@ -197,47 +159,21 @@ mod tests {
use crate::planner::operator::Operator;
#[tokio::test]
- async fn test_project_into_table_scan() -> Result<(), DatabaseError> {
- let plan = select_sql_run("select * from t1").await?;
-
- let best_plan = HepOptimizer::new(plan.clone())
- .batch(
- "test_project_into_table_scan".to_string(),
- HepBatchStrategy::once_topdown(),
- vec![RuleImpl::PushProjectIntoScan],
- )
- .find_best()?;
-
- assert_eq!(best_plan.childrens.len(), 0);
- match best_plan.operator {
- Operator::Scan(op) => {
- assert_eq!(op.columns.len(), 2);
- }
- _ => unreachable!("Should be a scan operator"),
- }
-
- Ok(())
- }
-
- #[tokio::test]
- async fn test_project_through_child_on_join() -> Result<(), DatabaseError> {
+ async fn test_column_pruning() -> Result<(), DatabaseError> {
let plan = select_sql_run("select c1, c3 from t1 left join t2 on c1 = c3").await?;
let best_plan = HepOptimizer::new(plan.clone())
.batch(
- "test_project_through_child_on_join".to_string(),
- HepBatchStrategy::fix_point_topdown(10),
- vec![
- RuleImpl::PushProjectThroughChild,
- RuleImpl::PushProjectIntoScan,
- ],
+ "test_column_pruning".to_string(),
+ HepBatchStrategy::once_topdown(),
+ vec![RuleImpl::ColumnPruning],
)
.find_best()?;
assert_eq!(best_plan.childrens.len(), 1);
match best_plan.operator {
Operator::Project(op) => {
- assert_eq!(op.columns.len(), 2);
+ assert_eq!(op.exprs.len(), 2);
}
_ => unreachable!("Should be a project operator"),
}
diff --git a/src/optimizer/rule/combine_operators.rs b/src/optimizer/rule/combine_operators.rs
index 7105dada..a4daa1f1 100644
--- a/src/optimizer/rule/combine_operators.rs
+++ b/src/optimizer/rule/combine_operators.rs
@@ -1,5 +1,4 @@
use crate::expression::{BinaryOperator, ScalarExpression};
-use crate::optimizer::core::opt_expr::OptExprNode;
use crate::optimizer::core::pattern::{Pattern, PatternChildrenPredicate};
use crate::optimizer::core::rule::Rule;
use crate::optimizer::heuristic::graph::{HepGraph, HepNodeId};
@@ -43,7 +42,7 @@ impl Rule for CollapseProject {
if let Operator::Project(op) = graph.operator(node_id) {
let child_id = graph.children_at(node_id)[0];
if let Operator::Project(child_op) = graph.operator(child_id) {
- if is_subset_exprs(&op.columns, &child_op.columns) {
+ if is_subset_exprs(&op.exprs, &child_op.exprs) {
graph.remove_node(child_id, false);
} else {
graph.remove_node(node_id, false);
@@ -76,10 +75,7 @@ impl Rule for CombineFilter {
},
having: op.having || child_op.having,
};
- graph.replace_node(
- node_id,
- OptExprNode::OperatorRef(Operator::Filter(new_filter_op)),
- );
+ graph.replace_node(node_id, Operator::Filter(new_filter_op));
graph.remove_node(child_id, false);
}
}
@@ -94,7 +90,6 @@ mod tests {
use crate::db::DatabaseError;
use crate::expression::ScalarExpression::Constant;
use crate::expression::{BinaryOperator, ScalarExpression};
- use crate::optimizer::core::opt_expr::OptExprNode;
use crate::optimizer::heuristic::batch::HepBatchStrategy;
use crate::optimizer::heuristic::graph::HepNodeId;
use crate::optimizer::heuristic::optimizer::HepOptimizer;
@@ -117,19 +112,17 @@ mod tests {
let mut new_project_op = optimizer.graph.operator(HepNodeId::new(0)).clone();
if let Operator::Project(op) = &mut new_project_op {
- op.columns.remove(0);
+ op.exprs.remove(0);
} else {
unreachable!("Should be a project operator")
}
- optimizer
- .graph
- .add_root(OptExprNode::OperatorRef(new_project_op));
+ optimizer.graph.add_root(new_project_op);
let best_plan = optimizer.find_best()?;
if let Operator::Project(op) = &best_plan.operator {
- assert_eq!(op.columns.len(), 1);
+ assert_eq!(op.exprs.len(), 1);
} else {
unreachable!("Should be a project operator")
}
@@ -166,11 +159,9 @@ mod tests {
unreachable!("Should be a filter operator")
}
- optimizer.graph.add_node(
- HepNodeId::new(0),
- Some(HepNodeId::new(1)),
- OptExprNode::OperatorRef(new_filter_op),
- );
+ optimizer
+ .graph
+ .add_node(HepNodeId::new(0), Some(HepNodeId::new(1)), new_filter_op);
let best_plan = optimizer.find_best()?;
diff --git a/src/optimizer/rule/mod.rs b/src/optimizer/rule/mod.rs
index 846aa27c..1c9bbbed 100644
--- a/src/optimizer/rule/mod.rs
+++ b/src/optimizer/rule/mod.rs
@@ -2,13 +2,14 @@ use crate::expression::ScalarExpression;
use crate::optimizer::core::pattern::Pattern;
use crate::optimizer::core::rule::Rule;
use crate::optimizer::heuristic::graph::{HepGraph, HepNodeId};
-use crate::optimizer::rule::column_pruning::{PushProjectIntoScan, PushProjectThroughChild};
+use crate::optimizer::rule::column_pruning::ColumnPruning;
use crate::optimizer::rule::combine_operators::{CollapseProject, CombineFilter};
use crate::optimizer::rule::pushdown_limit::{
EliminateLimits, LimitProjectTranspose, PushLimitIntoScan, PushLimitThroughJoin,
};
use crate::optimizer::rule::pushdown_predicates::PushPredicateIntoScan;
use crate::optimizer::rule::pushdown_predicates::PushPredicateThroughJoin;
+use crate::optimizer::rule::simplification::ConstantCalculation;
use crate::optimizer::rule::simplification::SimplifyFilter;
use crate::optimizer::OptimizerError;
@@ -20,9 +21,7 @@ mod simplification;
#[derive(Debug, Copy, Clone)]
pub enum RuleImpl {
- // Column pruning
- PushProjectIntoScan,
- PushProjectThroughChild,
+ ColumnPruning,
// Combine operators
CollapseProject,
CombineFilter,
@@ -37,38 +36,39 @@ pub enum RuleImpl {
PushPredicateIntoScan,
// Simplification
SimplifyFilter,
+ ConstantCalculation,
}
impl Rule for RuleImpl {
fn pattern(&self) -> &Pattern {
match self {
- RuleImpl::PushProjectIntoScan => PushProjectIntoScan {}.pattern(),
- RuleImpl::PushProjectThroughChild => PushProjectThroughChild {}.pattern(),
- RuleImpl::CollapseProject => CollapseProject {}.pattern(),
- RuleImpl::CombineFilter => CombineFilter {}.pattern(),
- RuleImpl::LimitProjectTranspose => LimitProjectTranspose {}.pattern(),
- RuleImpl::EliminateLimits => EliminateLimits {}.pattern(),
- RuleImpl::PushLimitThroughJoin => PushLimitThroughJoin {}.pattern(),
- RuleImpl::PushLimitIntoTableScan => PushLimitIntoScan {}.pattern(),
- RuleImpl::PushPredicateThroughJoin => PushPredicateThroughJoin {}.pattern(),
- RuleImpl::PushPredicateIntoScan => PushPredicateIntoScan {}.pattern(),
- RuleImpl::SimplifyFilter => SimplifyFilter {}.pattern(),
+ RuleImpl::ColumnPruning => ColumnPruning.pattern(),
+ RuleImpl::CollapseProject => CollapseProject.pattern(),
+ RuleImpl::CombineFilter => CombineFilter.pattern(),
+ RuleImpl::LimitProjectTranspose => LimitProjectTranspose.pattern(),
+ RuleImpl::EliminateLimits => EliminateLimits.pattern(),
+ RuleImpl::PushLimitThroughJoin => PushLimitThroughJoin.pattern(),
+ RuleImpl::PushLimitIntoTableScan => PushLimitIntoScan.pattern(),
+ RuleImpl::PushPredicateThroughJoin => PushPredicateThroughJoin.pattern(),
+ RuleImpl::PushPredicateIntoScan => PushPredicateIntoScan.pattern(),
+ RuleImpl::SimplifyFilter => SimplifyFilter.pattern(),
+ RuleImpl::ConstantCalculation => ConstantCalculation.pattern(),
}
}
fn apply(&self, node_id: HepNodeId, graph: &mut HepGraph) -> Result<(), OptimizerError> {
match self {
- RuleImpl::PushProjectIntoScan => PushProjectIntoScan {}.apply(node_id, graph),
- RuleImpl::PushProjectThroughChild => PushProjectThroughChild {}.apply(node_id, graph),
- RuleImpl::CollapseProject => CollapseProject {}.apply(node_id, graph),
- RuleImpl::CombineFilter => CombineFilter {}.apply(node_id, graph),
- RuleImpl::LimitProjectTranspose => LimitProjectTranspose {}.apply(node_id, graph),
- RuleImpl::EliminateLimits => EliminateLimits {}.apply(node_id, graph),
- RuleImpl::PushLimitThroughJoin => PushLimitThroughJoin {}.apply(node_id, graph),
- RuleImpl::PushLimitIntoTableScan => PushLimitIntoScan {}.apply(node_id, graph),
- RuleImpl::PushPredicateThroughJoin => PushPredicateThroughJoin {}.apply(node_id, graph),
- RuleImpl::SimplifyFilter => SimplifyFilter {}.apply(node_id, graph),
- RuleImpl::PushPredicateIntoScan => PushPredicateIntoScan {}.apply(node_id, graph),
+ RuleImpl::ColumnPruning => ColumnPruning.apply(node_id, graph),
+ RuleImpl::CollapseProject => CollapseProject.apply(node_id, graph),
+ RuleImpl::CombineFilter => CombineFilter.apply(node_id, graph),
+ RuleImpl::LimitProjectTranspose => LimitProjectTranspose.apply(node_id, graph),
+ RuleImpl::EliminateLimits => EliminateLimits.apply(node_id, graph),
+ RuleImpl::PushLimitThroughJoin => PushLimitThroughJoin.apply(node_id, graph),
+ RuleImpl::PushLimitIntoTableScan => PushLimitIntoScan.apply(node_id, graph),
+ RuleImpl::PushPredicateThroughJoin => PushPredicateThroughJoin.apply(node_id, graph),
+ RuleImpl::SimplifyFilter => SimplifyFilter.apply(node_id, graph),
+ RuleImpl::PushPredicateIntoScan => PushPredicateIntoScan.apply(node_id, graph),
+ RuleImpl::ConstantCalculation => ConstantCalculation.apply(node_id, graph),
}
}
}
diff --git a/src/optimizer/rule/pushdown_limit.rs b/src/optimizer/rule/pushdown_limit.rs
index d2504f3c..2884b84f 100644
--- a/src/optimizer/rule/pushdown_limit.rs
+++ b/src/optimizer/rule/pushdown_limit.rs
@@ -1,4 +1,3 @@
-use crate::optimizer::core::opt_expr::OptExprNode;
use crate::optimizer::core::pattern::Pattern;
use crate::optimizer::core::pattern::PatternChildrenPredicate;
use crate::optimizer::core::rule::Rule;
@@ -81,10 +80,7 @@ impl Rule for EliminateLimits {
let new_limit_op = LimitOperator { offset, limit };
graph.remove_node(child_id, false);
- graph.replace_node(
- node_id,
- OptExprNode::OperatorRef(Operator::Limit(new_limit_op)),
- );
+ graph.replace_node(node_id, Operator::Limit(new_limit_op));
}
}
@@ -135,11 +131,7 @@ impl Rule for PushLimitThroughJoin {
JoinType::Right => Some(graph.children_at(child_id)[1]),
_ => None,
} {
- graph.add_node(
- child_id,
- Some(grandson_id),
- OptExprNode::OperatorRef(Operator::Limit(op.clone())),
- );
+ graph.add_node(child_id, Some(grandson_id), Operator::Limit(op.clone()));
}
}
}
@@ -165,10 +157,7 @@ impl Rule for PushLimitIntoScan {
new_scan_op.limit = (limit_op.offset, limit_op.limit);
graph.remove_node(node_id, false);
- graph.replace_node(
- child_index,
- OptExprNode::OperatorRef(Operator::Scan(new_scan_op)),
- );
+ graph.replace_node(child_index, Operator::Scan(new_scan_op));
}
}
@@ -180,7 +169,6 @@ impl Rule for PushLimitIntoScan {
mod tests {
use crate::binder::test::select_sql_run;
use crate::db::DatabaseError;
- use crate::optimizer::core::opt_expr::OptExprNode;
use crate::optimizer::heuristic::batch::HepBatchStrategy;
use crate::optimizer::heuristic::optimizer::HepOptimizer;
use crate::optimizer::rule::RuleImpl;
@@ -227,9 +215,7 @@ mod tests {
limit: Some(1),
};
- optimizer
- .graph
- .add_root(OptExprNode::OperatorRef(Operator::Limit(new_limit_op)));
+ optimizer.graph.add_root(Operator::Limit(new_limit_op));
let best_plan = optimizer.find_best()?;
diff --git a/src/optimizer/rule/pushdown_predicates.rs b/src/optimizer/rule/pushdown_predicates.rs
index c3268eeb..d5132148 100644
--- a/src/optimizer/rule/pushdown_predicates.rs
+++ b/src/optimizer/rule/pushdown_predicates.rs
@@ -1,6 +1,5 @@
use crate::catalog::ColumnRef;
use crate::expression::{BinaryOperator, ScalarExpression};
-use crate::optimizer::core::opt_expr::OptExprNode;
use crate::optimizer::core::pattern::Pattern;
use crate::optimizer::core::pattern::PatternChildrenPredicate;
use crate::optimizer::core::rule::Rule;
@@ -177,23 +176,15 @@ impl Rule for PushPredicateThroughJoin {
}
if let Some(left_op) = new_ops.0 {
- graph.add_node(
- child_id,
- Some(join_childs[0]),
- OptExprNode::OperatorRef(left_op),
- );
+ graph.add_node(child_id, Some(join_childs[0]), left_op);
}
if let Some(right_op) = new_ops.1 {
- graph.add_node(
- child_id,
- Some(join_childs[1]),
- OptExprNode::OperatorRef(right_op),
- );
+ graph.add_node(child_id, Some(join_childs[1]), right_op);
}
if let Some(common_op) = new_ops.2 {
- graph.replace_node(node_id, OptExprNode::OperatorRef(common_op));
+ graph.replace_node(node_id, common_op);
} else {
graph.remove_node(node_id, false);
}
@@ -203,7 +194,7 @@ impl Rule for PushPredicateThroughJoin {
}
}
-pub struct PushPredicateIntoScan {}
+pub struct PushPredicateIntoScan;
impl Rule for PushPredicateIntoScan {
fn pattern(&self) -> &Pattern {
@@ -232,10 +223,7 @@ impl Rule for PushPredicateIntoScan {
// The constant expression extracted in prewhere is used to
// reduce the data scanning range and cannot replace the role of Filter.
- graph.replace_node(
- child_id,
- OptExprNode::OperatorRef(Operator::Scan(scan_by_index)),
- );
+ graph.replace_node(child_id, Operator::Scan(scan_by_index));
return Ok(());
}
diff --git a/src/optimizer/rule/simplification.rs b/src/optimizer/rule/simplification.rs
index fc5e324d..3f004451 100644
--- a/src/optimizer/rule/simplification.rs
+++ b/src/optimizer/rule/simplification.rs
@@ -1,11 +1,17 @@
-use crate::optimizer::core::opt_expr::OptExprNode;
use crate::optimizer::core::pattern::{Pattern, PatternChildrenPredicate};
use crate::optimizer::core::rule::Rule;
use crate::optimizer::heuristic::graph::{HepGraph, HepNodeId};
use crate::optimizer::OptimizerError;
+use crate::planner::operator::join::JoinCondition;
use crate::planner::operator::Operator;
use lazy_static::lazy_static;
lazy_static! {
+ static ref CONSTANT_CALCULATION_RULE: Pattern = {
+ Pattern {
+ predicate: |_| true,
+ children: PatternChildrenPredicate::None,
+ }
+ };
static ref SIMPLIFY_FILTER_RULE: Pattern = {
Pattern {
predicate: |op| matches!(op, Operator::Filter(_)),
@@ -17,6 +23,72 @@ lazy_static! {
};
}
+#[derive(Copy, Clone)]
+pub struct ConstantCalculation;
+
+impl ConstantCalculation {
+ fn _apply(node_id: HepNodeId, graph: &mut HepGraph) -> Result<(), OptimizerError> {
+ let operator = graph.operator_mut(node_id);
+
+ match operator {
+ Operator::Aggregate(op) => {
+ for expr in op.agg_calls.iter_mut().chain(op.groupby_exprs.iter_mut()) {
+ expr.constant_calculation()?;
+ }
+ }
+ Operator::Filter(op) => {
+ op.predicate.constant_calculation()?;
+ }
+ Operator::Join(op) => {
+ if let JoinCondition::On { on, filter } = &mut op.on {
+ for (left_expr, right_expr) in on {
+ left_expr.constant_calculation()?;
+ right_expr.constant_calculation()?;
+ }
+ if let Some(expr) = filter {
+ expr.constant_calculation()?;
+ }
+ }
+ }
+ Operator::Project(op) => {
+ for expr in &mut op.exprs {
+ expr.constant_calculation()?;
+ }
+ }
+ Operator::Scan(op) => {
+ for expr in &mut op.columns {
+ expr.constant_calculation()?;
+ }
+ }
+ Operator::Sort(op) => {
+ for field in &mut op.sort_fields {
+ field.expr.constant_calculation()?;
+ }
+ }
+ _ => (),
+ }
+ for child_id in graph.children_at(node_id) {
+ Self::_apply(child_id, graph)?;
+ }
+
+ Ok(())
+ }
+}
+
+impl Rule for ConstantCalculation {
+ fn pattern(&self) -> &Pattern {
+ &CONSTANT_CALCULATION_RULE
+ }
+
+ fn apply(&self, node_id: HepNodeId, graph: &mut HepGraph) -> Result<(), OptimizerError> {
+ Self::_apply(node_id, graph)?;
+        // mark the graph as changed so that this rule batch is skipped afterwards
+ graph.version += 1;
+
+ Ok(())
+ }
+}
+
#[derive(Copy, Clone)]
pub struct SimplifyFilter;
@@ -28,11 +100,9 @@ impl Rule for SimplifyFilter {
fn apply(&self, node_id: HepNodeId, graph: &mut HepGraph) -> Result<(), OptimizerError> {
if let Operator::Filter(mut filter_op) = graph.operator(node_id).clone() {
filter_op.predicate.simplify()?;
+ filter_op.predicate.constant_calculation()?;
- graph.replace_node(
- node_id,
- OptExprNode::OperatorRef(Operator::Filter(filter_op)),
- )
+ graph.replace_node(node_id, Operator::Filter(filter_op))
}
Ok(())
@@ -42,7 +112,7 @@ impl Rule for SimplifyFilter {
#[cfg(test)]
mod test {
use crate::binder::test::select_sql_run;
- use crate::catalog::{ColumnCatalog, ColumnDesc};
+ use crate::catalog::{ColumnCatalog, ColumnDesc, ColumnSummary};
use crate::db::DatabaseError;
use crate::expression::simplify::ConstantBinary;
use crate::expression::{BinaryOperator, ScalarExpression, UnaryOperator};
@@ -57,6 +127,45 @@ mod test {
use std::collections::Bound;
use std::sync::Arc;
+ #[tokio::test]
+ async fn test_constant_calculation_omitted() -> Result<(), DatabaseError> {
+ // (2 + (-1)) < -(c1 + 1)
+ let plan =
+ select_sql_run("select c1 + (2 + 1), 2 + 1 from t1 where (2 + (-1)) < -(c1 + 1)")
+ .await?;
+
+ let best_plan = HepOptimizer::new(plan)
+ .batch(
+ "test_simplification".to_string(),
+ HepBatchStrategy::once_topdown(),
+ vec![RuleImpl::SimplifyFilter, RuleImpl::ConstantCalculation],
+ )
+ .find_best()?;
+ if let Operator::Project(project_op) = best_plan.clone().operator {
+ let constant_expr = ScalarExpression::Constant(Arc::new(DataValue::Int32(Some(3))));
+ if let ScalarExpression::Binary { right_expr, .. } = &project_op.exprs[0] {
+ assert_eq!(right_expr.as_ref(), &constant_expr);
+ } else {
+ unreachable!();
+ }
+ assert_eq!(&project_op.exprs[1], &constant_expr);
+ } else {
+ unreachable!();
+ }
+ if let Operator::Filter(filter_op) = best_plan.childrens[0].clone().operator {
+ let column_binary = filter_op.predicate.convert_binary(&0).unwrap();
+ let final_binary = ConstantBinary::Scope {
+ min: Bound::Unbounded,
+ max: Bound::Excluded(Arc::new(DataValue::Int32(Some(-2)))),
+ };
+ assert_eq!(column_binary, Some(final_binary));
+ } else {
+ unreachable!();
+ }
+
+ Ok(())
+ }
+
#[tokio::test]
async fn test_simplify_filter_single_column() -> Result<(), DatabaseError> {
// c1 + 1 < -1 => c1 < -2
@@ -137,9 +246,11 @@ mod test {
.find_best()?;
if let Operator::Filter(filter_op) = best_plan.childrens[0].clone().operator {
let c1_col = ColumnCatalog {
- id: Some(0),
- name: "c1".to_string(),
- table_name: Some(Arc::new("t1".to_string())),
+ summary: ColumnSummary {
+ id: Some(0),
+ name: "c1".to_string(),
+ table_name: Some(Arc::new("t1".to_string())),
+ },
nullable: false,
desc: ColumnDesc {
column_datatype: LogicalType::Integer,
@@ -149,9 +260,11 @@ mod test {
ref_expr: None,
};
let c2_col = ColumnCatalog {
- id: Some(1),
- name: "c2".to_string(),
- table_name: Some(Arc::new("t1".to_string())),
+ summary: ColumnSummary {
+ id: Some(1),
+ name: "c2".to_string(),
+ table_name: Some(Arc::new("t1".to_string())),
+ },
nullable: false,
desc: ColumnDesc {
column_datatype: LogicalType::Integer,
diff --git a/src/planner/operator/mod.rs b/src/planner/operator/mod.rs
index 4be5b0af..633012db 100644
--- a/src/planner/operator/mod.rs
+++ b/src/planner/operator/mod.rs
@@ -67,7 +67,7 @@ impl Operator {
pub fn project_input_refs(&self) -> Vec {
match self {
Operator::Project(op) => op
- .columns
+ .exprs
.iter()
.map(ScalarExpression::unpack_alias)
.filter(|expr| matches!(expr, ScalarExpression::InputRef { .. }))
@@ -125,7 +125,7 @@ impl Operator {
exprs
}
Operator::Project(op) => op
- .columns
+ .exprs
.iter()
.flat_map(|expr| expr.referenced_columns())
.collect_vec(),
diff --git a/src/planner/operator/project.rs b/src/planner/operator/project.rs
index 9d345e5f..ca7811e9 100644
--- a/src/planner/operator/project.rs
+++ b/src/planner/operator/project.rs
@@ -2,5 +2,5 @@ use crate::expression::ScalarExpression;
#[derive(Debug, PartialEq, Clone)]
pub struct ProjectOperator {
- pub columns: Vec<ScalarExpression>,
+ pub exprs: Vec<ScalarExpression>,
}
diff --git a/src/storage/kip.rs b/src/storage/kip.rs
index b6fffcec..73b61379 100644
--- a/src/storage/kip.rs
+++ b/src/storage/kip.rs
@@ -450,11 +450,11 @@ impl KipTransaction {
.into_iter()
.filter(|col| col.desc.is_unique)
{
- if let Some(col_id) = col.id {
+ if let Some(col_id) = col.id() {
let meta = IndexMeta {
id: 0,
column_ids: vec![col_id],
- name: format!("uk_{}", col.name),
+ name: format!("uk_{}", col.name()),
is_unique: true,
};
let meta_ref = table.add_index_meta(meta);
@@ -584,6 +584,7 @@ mod test {
vec![ScalarExpression::InputRef {
index: 0,
ty: LogicalType::Integer,
+ ref_columns: vec![],
}],
)?;
diff --git a/src/storage/table_codec.rs b/src/storage/table_codec.rs
index d03beb01..483eab2e 100644
--- a/src/storage/table_codec.rs
+++ b/src/storage/table_codec.rs
@@ -211,10 +211,10 @@ impl TableCodec {
/// Tips: the `0` for bound range
pub fn encode_column(col: &ColumnCatalog) -> Result<(Bytes, Bytes), TypeError> {
let bytes = bincode::serialize(col)?;
- let mut key_prefix = Self::key_prefix(CodecType::Column, col.table_name.as_ref().unwrap());
+ let mut key_prefix = Self::key_prefix(CodecType::Column, &col.table_name().unwrap());
key_prefix.push(BOUND_MIN_TAG);
- key_prefix.append(&mut col.id.unwrap().to_be_bytes().to_vec());
+ key_prefix.append(&mut col.id().unwrap().to_be_bytes().to_vec());
Ok((Bytes::from(key_prefix), Bytes::from(bytes)))
}
@@ -222,7 +222,7 @@ impl TableCodec {
pub fn decode_column(bytes: &[u8]) -> Result<(TableName, ColumnCatalog), TypeError> {
let column = bincode::deserialize::<ColumnCatalog>(bytes)?;
- Ok((column.table_name.clone().unwrap(), column))
+ Ok((column.table_name().unwrap(), column))
}
/// Key: RootCatalog_0_TableName
@@ -369,8 +369,8 @@ mod tests {
None,
);
- col.table_name = Some(Arc::new(table_name.to_string()));
- col.id = Some(col_id as u32);
+ col.summary.table_name = Some(Arc::new(table_name.to_string()));
+ col.summary.id = Some(col_id as u32);
let (key, _) = TableCodec::encode_column(&col).unwrap();
key
diff --git a/src/types/tuple.rs b/src/types/tuple.rs
index 0359d569..0e75aaf3 100644
--- a/src/types/tuple.rs
+++ b/src/types/tuple.rs
@@ -101,7 +101,7 @@ pub fn create_table(tuples: &[Tuple]) -> Table {
let mut header = Vec::new();
for col in &tuples[0].columns {
- header.push(Cell::new(col.name.clone()));
+ header.push(Cell::new(col.name().to_string()));
}
table.set_header(header);
diff --git a/src/types/tuple_builder.rs b/src/types/tuple_builder.rs
index aa6f0778..8544cd1a 100644
--- a/src/types/tuple_builder.rs
+++ b/src/types/tuple_builder.rs
@@ -53,7 +53,7 @@ impl TupleBuilder {
let cast_data_value = data_value.cast(&self.data_types[i])?;
self.data_values.push(Arc::new(cast_data_value.clone()));
let col = &columns[i];
- col.id
+ col.id()
.map(|col_id| tuple_map.insert(col_id, Arc::new(cast_data_value.clone())));
if primary_key_index.is_none() && col.desc.is_primary {
primary_key_index = Some(i);
@@ -61,7 +61,7 @@ impl TupleBuilder {
}
let primary_col_id = primary_key_index
- .map(|i| columns[i].id.unwrap())
+ .map(|i| columns[i].id().unwrap())
.ok_or_else(|| TypeError::PrimaryKeyNotFound)?;
let tuple_id = tuple_map
diff --git a/tests/slt/filter.slt b/tests/slt/filter.slt
index ccd46821..14646fe2 100644
--- a/tests/slt/filter.slt
+++ b/tests/slt/filter.slt
@@ -68,4 +68,40 @@ select v2 from t where ((v2 >= -8 and -4 >= v1) or (v1 >= 0 and 5 > v2)) and ((v
1
statement ok
-drop table t
\ No newline at end of file
+create table t1(id int primary key, v1 varchar)
+
+statement ok
+insert into t1 values (0, 'KipSQL'), (1, 'KipDB'), (2, 'KipBlog'), (3, 'Cool!');
+
+query II
+select * from t1 where v1 like 'Kip%'
+----
+0 KipSQL
+1 KipDB
+2 KipBlog
+
+query II
+select * from t1 where v1 not like 'Kip%'
+----
+3 Cool!
+
+query II
+select * from t1 where v1 like 'KipD_'
+----
+1 KipDB
+
+query II
+select * from t1 where v1 like 'KipS_L'
+----
+0 KipSQL
+
+query II
+select * from t1 where v1 like 'K%L'
+----
+0 KipSQL
+
+statement ok
+drop table t
+
+statement ok
+drop table t1
\ No newline at end of file
diff --git a/tests/slt/filter_null.slt b/tests/slt/filter_null.slt
index db45712f..1bfd0ad3 100644
--- a/tests/slt/filter_null.slt
+++ b/tests/slt/filter_null.slt
@@ -32,5 +32,17 @@ select * from t where v1 > 1
2 3 4
3 4 3
+query II
+select * from t where v1 is null
+----
+1 null 3
+
+query II
+select * from t where v1 is not null
+----
+0 2 4
+2 3 4
+3 4 3
+
statement ok
drop table t
\ No newline at end of file
diff --git a/tests/sqllogictest/src/main.rs b/tests/sqllogictest/src/main.rs
index 58776d61..b65bbd0d 100644
--- a/tests/sqllogictest/src/main.rs
+++ b/tests/sqllogictest/src/main.rs
@@ -6,16 +6,15 @@ use tempfile::TempDir;
#[tokio::main]
async fn main() {
+ const SLT_PATTERN: &str = "tests/slt/**/*.slt";
+
let path = Path::new(env!("CARGO_MANIFEST_DIR")).join("..").join("..");
std::env::set_current_dir(path).unwrap();
println!("KipSQL Test Start!\n");
- const SLT_PATTERN: &str = "tests/slt/**/*.slt";
- let slt_files = glob::glob(SLT_PATTERN).expect("failed to find slt files");
- for slt_file in slt_files {
+ for slt_file in glob::glob(SLT_PATTERN).expect("failed to find slt files") {
let temp_dir = TempDir::new().expect("unable to create temporary working directory");
-
let filepath = slt_file
.expect("failed to read slt file")
.to_str()