diff --git a/Cargo.lock b/Cargo.lock index 764abd78fb5..b8611faabe6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4658,9 +4658,11 @@ dependencies = [ "spacetimedb-client-api", "spacetimedb-core", "spacetimedb-data-structures", + "spacetimedb-execution", "spacetimedb-lib", "spacetimedb-paths", "spacetimedb-primitives", + "spacetimedb-query", "spacetimedb-sats", "spacetimedb-schema", "spacetimedb-standalone", @@ -4910,13 +4912,16 @@ dependencies = [ "spacetimedb-commitlog", "spacetimedb-data-structures", "spacetimedb-durability", + "spacetimedb-execution", "spacetimedb-expr", "spacetimedb-jsonwebtoken", "spacetimedb-jwks", "spacetimedb-lib", "spacetimedb-metrics", "spacetimedb-paths", + "spacetimedb-physical-plan", "spacetimedb-primitives", + "spacetimedb-query", "spacetimedb-sats", "spacetimedb-schema", "spacetimedb-snapshot", @@ -4975,9 +4980,12 @@ dependencies = [ name = "spacetimedb-execution" version = "1.0.0-rc3" dependencies = [ + "anyhow", "spacetimedb-expr", "spacetimedb-lib", + "spacetimedb-physical-plan", "spacetimedb-primitives", + "spacetimedb-sql-parser", "spacetimedb-table", ] @@ -5091,12 +5099,15 @@ dependencies = [ name = "spacetimedb-physical-plan" version = "1.0.0-rc3" dependencies = [ + "anyhow", "derive_more", + "pretty_assertions", "spacetimedb-expr", "spacetimedb-lib", "spacetimedb-primitives", "spacetimedb-schema", "spacetimedb-sql-parser", + "spacetimedb-table", ] [[package]] @@ -5110,6 +5121,23 @@ dependencies = [ "proptest", ] +[[package]] +name = "spacetimedb-query" +version = "1.0.0-rc3" +dependencies = [ + "anyhow", + "itertools 0.12.1", + "rayon", + "spacetimedb-client-api-messages", + "spacetimedb-execution", + "spacetimedb-expr", + "spacetimedb-lib", + "spacetimedb-physical-plan", + "spacetimedb-primitives", + "spacetimedb-sql-parser", + "spacetimedb-table", +] + [[package]] name = "spacetimedb-quickstart-module" version = "0.1.0" @@ -5320,6 +5348,7 @@ dependencies = [ "log", "smallvec", "spacetimedb-data-structures", + "spacetimedb-execution", "spacetimedb-lib", "spacetimedb-primitives", "spacetimedb-sats", diff --git a/Cargo.toml b/Cargo.toml index acffd380999..77bd1ea48f1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,6 +19,7 @@ members = [ "crates/paths", "crates/physical-plan", "crates/primitives", + "crates/query", "crates/sats", "crates/schema", "crates/sdk", @@ -106,6 +107,7 @@ spacetimedb-metrics = { path = "crates/metrics", version = "1.0.0-rc3" } spacetimedb-paths = { path = "crates/paths", version = "1.0.0-rc3" } spacetimedb-physical-plan = { path = "crates/physical-plan", version = "1.0.0-rc3" } spacetimedb-primitives = { path = "crates/primitives", version = "1.0.0-rc3" } +spacetimedb-query = { path = "crates/query", version = "1.0.0-rc3" } spacetimedb-sats = { path = "crates/sats", version = "1.0.0-rc3" } spacetimedb-schema = { path = "crates/schema", version = "1.0.0-rc3" } spacetimedb-standalone = { path = "crates/standalone", version = "1.0.0-rc3" } diff --git a/crates/bench/Cargo.toml b/crates/bench/Cargo.toml index 53bc3f3298e..b606bef652f 100644 --- a/crates/bench/Cargo.toml +++ b/crates/bench/Cargo.toml @@ -31,9 +31,11 @@ bench = false spacetimedb-client-api = { path = "../client-api" } spacetimedb-core = { path = "../core", features = ["test"] } spacetimedb-data-structures.workspace = true +spacetimedb-execution = { path = "../execution" } spacetimedb-lib = { path = "../lib" } spacetimedb-paths.workspace = true spacetimedb-primitives = { path = "../primitives" } +spacetimedb-query = { path = "../query" } spacetimedb-sats = { path = 
"../sats" } spacetimedb-schema = { workspace = true, features = ["test"] } spacetimedb-standalone = { path = "../standalone" } diff --git a/crates/bench/benches/subscription.rs b/crates/bench/benches/subscription.rs index ae26a40a114..8a889963ebe 100644 --- a/crates/bench/benches/subscription.rs +++ b/crates/bench/benches/subscription.rs @@ -4,12 +4,15 @@ use spacetimedb::execution_context::Workload; use spacetimedb::host::module_host::DatabaseTableUpdate; use spacetimedb::identity::AuthCtx; use spacetimedb::messages::websocket::BsatnFormat; +use spacetimedb::sql::ast::SchemaViewer; use spacetimedb::subscription::query::compile_read_only_queryset; use spacetimedb::subscription::subscription::ExecutionSet; +use spacetimedb::subscription::tx::DeltaTx; use spacetimedb::{db::relational_db::RelationalDB, messages::websocket::Compression}; use spacetimedb_bench::database::BenchDatabase as _; use spacetimedb_bench::spacetime_raw::SpacetimeRaw; use spacetimedb_primitives::{col_list, TableId}; +use spacetimedb_query::SubscribePlan; use spacetimedb_sats::{bsatn, product, AlgebraicType, AlgebraicValue, ProductValue}; fn create_table_location(db: &RelationalDB) -> Result { @@ -107,6 +110,23 @@ fn eval(c: &mut Criterion) { let ins_rhs = insert_op(rhs, "location", new_rhs_row); let update = [&ins_lhs, &ins_rhs]; + // A benchmark runner for the new query engine + let bench_query = |c: &mut Criterion, name, sql| { + c.bench_function(name, |b| { + let tx = raw.db.begin_tx(Workload::Subscribe); + let auth = AuthCtx::for_testing(); + let schema_viewer = &SchemaViewer::new(&tx, &auth); + let plan = SubscribePlan::compile(sql, schema_viewer).unwrap(); + let tx = DeltaTx::from(&tx); + + b.iter(|| { + drop(black_box( + plan.collect_table_update::<_, BsatnFormat>(Compression::None, &tx), + )) + }) + }); + }; + let bench_eval = |c: &mut Criterion, name, sql| { c.bench_function(name, |b| { let tx = raw.db.begin_tx(Workload::Update); @@ -124,6 +144,22 @@ fn eval(c: &mut Criterion) { }); }; + // Join 1M rows on the left with 12K rows on the right. + // Note, this should use an index join so as not to read the entire footprint table. + let semijoin = format!( + r#" + select f.* + from footprint f join location l on f.entity_id = l.entity_id + where l.chunk_index = {chunk_index} + "# + ); + + let index_scan_multi = "select * from location WHERE x = 0 AND z = 10000 AND dimension = 0"; + + bench_query(c, "footprint-scan", "select * from footprint"); + bench_query(c, "footprint-semijoin", &semijoin); + bench_query(c, "index-scan-multi", index_scan_multi); + // To profile this benchmark for 30s // samply record -r 10000000 cargo bench --bench=subscription --profile=profiling -- full-scan --exact --profile-time=30 // Iterate 1M rows. @@ -132,7 +168,7 @@ fn eval(c: &mut Criterion) { // To profile this benchmark for 30s // samply record -r 10000000 cargo bench --bench=subscription --profile=profiling -- full-join --exact --profile-time=30 // Join 1M rows on the left with 12K rows on the right. - // Note, this should use an index join so as not to read the entire lhs table. + // Note, this should use an index join so as not to read the entire footprint table. 
let name = format!( r#" select footprint.* diff --git a/crates/core/Cargo.toml b/crates/core/Cargo.toml index 21976f25cdd..2cf6fb59808 100644 --- a/crates/core/Cargo.toml +++ b/crates/core/Cargo.toml @@ -23,12 +23,15 @@ spacetimedb-durability.workspace = true spacetimedb-metrics.workspace = true spacetimedb-primitives.workspace = true spacetimedb-paths.workspace = true +spacetimedb-physical-plan.workspace = true +spacetimedb-query.workspace = true spacetimedb-sats = { workspace = true, features = ["serde"] } spacetimedb-schema.workspace = true spacetimedb-table.workspace = true spacetimedb-vm.workspace = true spacetimedb-snapshot.workspace = true spacetimedb-expr.workspace = true +spacetimedb-execution.workspace = true anyhow = { workspace = true, features = ["backtrace"] } arrayvec.workspace = true diff --git a/crates/core/src/client/client_connection.rs b/crates/core/src/client/client_connection.rs index 92621c631a0..cdf94aac1b1 100644 --- a/crates/core/src/client/client_connection.rs +++ b/crates/core/src/client/client_connection.rs @@ -13,7 +13,7 @@ use crate::worker_metrics::WORKER_METRICS; use derive_more::From; use futures::prelude::*; use spacetimedb_client_api_messages::websocket::{ - CallReducerFlags, Compression, FormatSwitch, SubscribeSingle, Unsubscribe, + BsatnFormat, CallReducerFlags, Compression, FormatSwitch, JsonFormat, SubscribeSingle, Unsubscribe, WebsocketFormat, }; use spacetimedb_lib::identity::RequestId; use tokio::sync::{mpsc, oneshot, watch}; @@ -314,26 +314,41 @@ impl ClientConnection { .unwrap() } - pub fn one_off_query(&self, query: &str, message_id: &[u8], timer: Instant) -> Result<(), anyhow::Error> { - let result = self.module.one_off_query(self.id.identity, query.to_owned()); + pub fn one_off_query_json(&self, query: &str, message_id: &[u8], timer: Instant) -> Result<(), anyhow::Error> { + let response = self.one_off_query::<JsonFormat>(query, message_id, timer); + self.send_message(response)?; + Ok(()) + } + + pub fn one_off_query_bsatn(&self, query: &str, message_id: &[u8], timer: Instant) -> Result<(), anyhow::Error> { + let response = self.one_off_query::<BsatnFormat>(query, message_id, timer); + self.send_message(response)?; + Ok(()) + } + + fn one_off_query<F: WebsocketFormat>( + &self, + query: &str, + message_id: &[u8], + timer: Instant, + ) -> OneOffQueryResponseMessage<F> { + let result = self.module.one_off_query::<F>(self.id.identity, query.to_owned()); let message_id = message_id.to_owned(); let total_host_execution_duration = timer.elapsed().as_micros() as u64; - let response = match result { + match result { Ok(results) => OneOffQueryResponseMessage { message_id, error: None, - results, + results: vec![results], total_host_execution_duration, }, Err(err) => OneOffQueryResponseMessage { message_id, error: Some(format!("{}", err)), - results: Vec::new(), + results: vec![], total_host_execution_duration, }, - }; - self.send_message(response)?; - Ok(()) + } } pub async fn disconnect(self) { diff --git a/crates/core/src/client/message_handlers.rs b/crates/core/src/client/message_handlers.rs index c2e4312d42f..4e153d19c26 100644 --- a/crates/core/src/client/message_handlers.rs +++ b/crates/core/src/client/message_handlers.rs @@ -1,5 +1,5 @@ use super::messages::{SubscriptionUpdateMessage, SwitchedServerMessage, ToProtocol, TransactionUpdateMessage}; -use super::{ClientConnection, DataMessage}; +use super::{ClientConnection, DataMessage, Protocol}; use crate::energy::EnergyQuanta; use crate::execution_context::WorkloadType; use crate::host::module_host::{EventStatus, ModuleEvent, 
ModuleFunctionCall}; @@ -107,7 +107,10 @@ pub async fn handle(client: &ClientConnection, message: DataMessage, timer: Inst query_string: query, message_id, }) => { - let res = client.one_off_query(&query, &message_id, timer); + let res = match client.config.protocol { + Protocol::Binary => client.one_off_query_bsatn(&query, &message_id, timer), + Protocol::Text => client.one_off_query_json(&query, &message_id, timer), + }; WORKER_METRICS .request_round_trip .with_label_values(&WorkloadType::Sql, &address, "") diff --git a/crates/core/src/client/messages.rs b/crates/core/src/client/messages.rs index 305758d6059..c2c0b94bf95 100644 --- a/crates/core/src/client/messages.rs +++ b/crates/core/src/client/messages.rs @@ -5,15 +5,14 @@ use crate::host::ArgsTuple; use crate::messages::websocket as ws; use derive_more::From; use spacetimedb_client_api_messages::websocket::{ - BsatnFormat, Compression, FormatSwitch, JsonFormat, WebsocketFormat, SERVER_MSG_COMPRESSION_TAG_BROTLI, - SERVER_MSG_COMPRESSION_TAG_GZIP, SERVER_MSG_COMPRESSION_TAG_NONE, + BsatnFormat, Compression, FormatSwitch, JsonFormat, OneOffTable, RowListLen, WebsocketFormat, + SERVER_MSG_COMPRESSION_TAG_BROTLI, SERVER_MSG_COMPRESSION_TAG_GZIP, SERVER_MSG_COMPRESSION_TAG_NONE, }; use spacetimedb_lib::identity::RequestId; use spacetimedb_lib::ser::serde::SerializeWrapper; use spacetimedb_lib::Address; use spacetimedb_primitives::TableId; use spacetimedb_sats::bsatn; -use spacetimedb_vm::relation::MemTable; use std::sync::Arc; use std::time::Instant; @@ -64,7 +63,8 @@ pub fn serialize(msg: impl ToProtocol, config: #[derive(Debug, From)] pub enum SerializableMessage { - Query(OneOffQueryResponseMessage), + QueryBinary(OneOffQueryResponseMessage), + QueryText(OneOffQueryResponseMessage), Identity(IdentityTokenMessage), Subscribe(SubscriptionUpdateMessage), Subscription(SubscriptionMessage), @@ -74,7 +74,8 @@ pub enum SerializableMessage { impl SerializableMessage { pub fn num_rows(&self) -> Option { match self { - Self::Query(msg) => Some(msg.num_rows()), + Self::QueryBinary(msg) => Some(msg.num_rows()), + Self::QueryText(msg) => Some(msg.num_rows()), Self::Subscribe(msg) => Some(msg.num_rows()), Self::Subscription(msg) => Some(msg.num_rows()), Self::TxUpdate(msg) => Some(msg.num_rows()), @@ -84,7 +85,7 @@ impl SerializableMessage { pub fn workload(&self) -> Option { match self { - Self::Query(_) => Some(WorkloadType::Sql), + Self::QueryBinary(_) | Self::QueryText(_) => Some(WorkloadType::Sql), Self::Subscribe(_) => Some(WorkloadType::Subscribe), Self::Subscription(msg) => match &msg.result { SubscriptionResult::Subscribe(_) => Some(WorkloadType::Subscribe), @@ -101,7 +102,8 @@ impl ToProtocol for SerializableMessage { type Encoded = SwitchedServerMessage; fn to_protocol(self, protocol: Protocol) -> Self::Encoded { match self { - SerializableMessage::Query(msg) => msg.to_protocol(protocol), + SerializableMessage::QueryBinary(msg) => msg.to_protocol(protocol), + SerializableMessage::QueryText(msg) => msg.to_protocol(protocol), SerializableMessage::Identity(msg) => msg.to_protocol(protocol), SerializableMessage::Subscribe(msg) => msg.to_protocol(protocol), SerializableMessage::TxUpdate(msg) => msg.to_protocol(protocol), @@ -402,42 +404,38 @@ impl ToProtocol for SubscriptionMessage { } #[derive(Debug)] -pub struct OneOffQueryResponseMessage { +pub struct OneOffQueryResponseMessage { pub message_id: Vec, pub error: Option, - pub results: Vec, + pub results: Vec>, pub total_host_execution_duration: u64, } -impl OneOffQueryResponseMessage { +impl 
OneOffQueryResponseMessage { fn num_rows(&self) -> usize { - self.results.iter().map(|t| t.data.len()).sum() + self.results.iter().map(|table| table.rows.len()).sum() } } -impl ToProtocol for OneOffQueryResponseMessage { +impl ToProtocol for OneOffQueryResponseMessage { type Encoded = SwitchedServerMessage; - fn to_protocol(self, protocol: Protocol) -> Self::Encoded { - fn convert(msg: OneOffQueryResponseMessage) -> ws::ServerMessage { - let tables = msg - .results - .into_iter() - .map(|table| ws::OneOffTable { - table_name: table.head.table_name.clone(), - rows: F::encode_list(table.data.into_iter()).0, - }) - .collect(); - ws::ServerMessage::OneOffQueryResponse(ws::OneOffQueryResponse { - message_id: msg.message_id.into(), - error: msg.error.map(Into::into), - tables, - total_host_execution_duration_micros: msg.total_host_execution_duration, - }) - } + fn to_protocol(self, _: Protocol) -> Self::Encoded { + FormatSwitch::Bsatn(convert(self)) + } +} - match protocol { - Protocol::Text => FormatSwitch::Json(convert(self)), - Protocol::Binary => FormatSwitch::Bsatn(convert(self)), - } +impl ToProtocol for OneOffQueryResponseMessage { + type Encoded = SwitchedServerMessage; + fn to_protocol(self, _: Protocol) -> Self::Encoded { + FormatSwitch::Json(convert(self)) } } + +fn convert(msg: OneOffQueryResponseMessage) -> ws::ServerMessage { + ws::ServerMessage::OneOffQueryResponse(ws::OneOffQueryResponse { + message_id: msg.message_id.into(), + error: msg.error.map(Into::into), + tables: msg.results.into_boxed_slice(), + total_host_execution_duration_micros: msg.total_host_execution_duration, + }) +} diff --git a/crates/core/src/db/datastore/locking_tx_datastore/tx.rs b/crates/core/src/db/datastore/locking_tx_datastore/tx.rs index aef949dd7a9..5057a0b36a5 100644 --- a/crates/core/src/db/datastore/locking_tx_datastore/tx.rs +++ b/crates/core/src/db/datastore/locking_tx_datastore/tx.rs @@ -7,9 +7,12 @@ use super::{ }; use crate::db::datastore::locking_tx_datastore::state_view::IterTx; use crate::execution_context::ExecutionContext; +use spacetimedb_execution::Datastore; use spacetimedb_primitives::{ColList, TableId}; use spacetimedb_sats::AlgebraicValue; use spacetimedb_schema::schema::TableSchema; +use spacetimedb_table::blob_store::BlobStore; +use spacetimedb_table::table::Table; use std::num::NonZeroU64; use std::sync::Arc; use std::{ @@ -24,6 +27,16 @@ pub struct TxId { pub(crate) ctx: ExecutionContext, } +impl Datastore for TxId { + fn blob_store(&self) -> &dyn BlobStore { + &self.committed_state_shared_lock.blob_store + } + + fn table(&self, table_id: TableId) -> Option<&Table> { + self.committed_state_shared_lock.get_table(table_id) + } +} + impl StateView for TxId { type Iter<'a> = IterTx<'a>; type IterByColRange<'a, R: RangeBounds> = IterByColRangeTx<'a, R>; diff --git a/crates/core/src/db/update.rs b/crates/core/src/db/update.rs index 8b040dffca7..5617cfe17d5 100644 --- a/crates/core/src/db/update.rs +++ b/crates/core/src/db/update.rs @@ -227,7 +227,7 @@ fn auto_migrate_database( system_logger.info(&format!("Adding row-level security `{sql_rls}`")); log::info!("Adding row-level security `{sql_rls}`"); let rls = plan.new.lookup_expect(sql_rls); - let rls = RowLevelExpr::build_row_level_expr(stdb, tx, &auth_ctx, rls)?; + let rls = RowLevelExpr::build_row_level_expr(tx, &auth_ctx, rls)?; stdb.create_row_level_security(tx, rls.def)?; } diff --git a/crates/core/src/estimation.rs b/crates/core/src/estimation.rs index 9edb2e1a514..a6f475db17e 100644 --- a/crates/core/src/estimation.rs +++ 
b/crates/core/src/estimation.rs @@ -1,4 +1,6 @@ use crate::db::{datastore::locking_tx_datastore::state_view::StateView as _, relational_db::Tx}; +use spacetimedb_lib::query::Delta; +use spacetimedb_physical_plan::plan::{HashJoin, IxJoin, IxScan, PhysicalPlan, Sarg}; use spacetimedb_primitives::{ColList, TableId}; use spacetimedb_vm::expr::{Query, QueryExpr, SourceExpr}; @@ -7,6 +9,80 @@ pub fn num_rows(tx: &Tx, expr: &QueryExpr) -> u64 { row_est(tx, &expr.source, &expr.query) } +/// Use cardinality estimates to predict the total number of rows scanned by a query +pub fn estimate_rows_scanned(tx: &Tx, plan: &PhysicalPlan) -> u64 { + match plan { + PhysicalPlan::TableScan(..) | PhysicalPlan::IxScan(..) => row_estimate(tx, plan), + PhysicalPlan::Filter(input, _) => estimate_rows_scanned(tx, input).saturating_add(row_estimate(tx, input)), + PhysicalPlan::NLJoin(lhs, rhs) => estimate_rows_scanned(tx, lhs) + .saturating_add(estimate_rows_scanned(tx, rhs)) + .saturating_add(row_estimate(tx, lhs).saturating_mul(row_estimate(tx, rhs))), + PhysicalPlan::IxJoin(IxJoin { lhs, unique: true, .. }, _) => { + estimate_rows_scanned(tx, lhs).saturating_add(row_estimate(tx, lhs)) + } + PhysicalPlan::IxJoin( + IxJoin { + lhs, rhs, rhs_field, .. + }, + _, + ) => estimate_rows_scanned(tx, lhs).saturating_add(row_estimate(tx, lhs).saturating_mul(index_row_est( + tx, + rhs.table_id, + &ColList::from(*rhs_field), + ))), + PhysicalPlan::HashJoin( + HashJoin { + lhs, rhs, unique: true, .. + }, + _, + ) => estimate_rows_scanned(tx, lhs) + .saturating_add(estimate_rows_scanned(tx, rhs)) + .saturating_add(row_estimate(tx, lhs)), + PhysicalPlan::HashJoin(HashJoin { lhs, rhs, .. }, _) => estimate_rows_scanned(tx, lhs) + .saturating_add(estimate_rows_scanned(tx, rhs)) + .saturating_add(row_estimate(tx, lhs).saturating_mul(row_estimate(tx, rhs))), + } +} + +/// Estimate the cardinality of a physical plan +pub fn row_estimate(tx: &Tx, plan: &PhysicalPlan) -> u64 { + match plan { + // Table scans return the number of rows in the table + PhysicalPlan::TableScan(schema, _, None) => tx.table_row_count(schema.table_id).unwrap_or_default(), + PhysicalPlan::TableScan(_, _, Some(Delta::Inserts(n) | Delta::Deletes(n))) => *n as u64, + // The selectivity of a single column index scan is 1 / NDV, + // where NDV is the Number of Distinct Values of a column. + // Note, this assumes a uniform distribution of column values. + PhysicalPlan::IxScan( + ix @ IxScan { + arg: Sarg::Eq(col_id, _), + .. + }, + _, + ) if ix.prefix.is_empty() => index_row_est(tx, ix.schema.table_id, &ColList::from(*col_id)), + // For all other index scans we assume a worst-case scenario. + PhysicalPlan::IxScan(IxScan { schema, .. }, _) => tx.table_row_count(schema.table_id).unwrap_or_default(), + // Same for filters + PhysicalPlan::Filter(input, _) => row_estimate(tx, input), + // Nested loop joins are cross joins + PhysicalPlan::NLJoin(lhs, rhs) => row_estimate(tx, lhs).saturating_mul(row_estimate(tx, rhs)), + // Unique joins return a maximal estimation. + // We assume every lhs row has a matching rhs row. + PhysicalPlan::IxJoin(IxJoin { lhs, unique: true, .. }, _) + | PhysicalPlan::HashJoin(HashJoin { lhs, unique: true, .. }, _) => row_estimate(tx, lhs), + // Otherwise we estimate the rows returned from the rhs + PhysicalPlan::IxJoin( + IxJoin { + lhs, rhs, rhs_field, .. + }, + _, + ) => row_estimate(tx, lhs).saturating_mul(index_row_est(tx, rhs.table_id, &ColList::from(*rhs_field))), + PhysicalPlan::HashJoin(HashJoin { lhs, rhs, .. 
}, _) => { + row_estimate(tx, lhs).saturating_mul(row_estimate(tx, rhs)) + } + } +} + /// The estimated number of rows that a query sub-plan will return. fn row_est(tx: &Tx, src: &SourceExpr, ops: &[Query]) -> u64 { match ops { @@ -68,6 +144,7 @@ fn index_row_est(tx: &Tx, table_id: TableId, cols: &ColList) -> u64 { mod tests { use crate::db::relational_db::tests_utils::insert; use crate::execution_context::Workload; + use crate::sql::ast::SchemaViewer; use crate::{ db::relational_db::{tests_utils::TestDB, RelationalDB}, error::DBError, @@ -75,9 +152,12 @@ mod tests { sql::compiler::compile_sql, }; use spacetimedb_lib::{identity::AuthCtx, AlgebraicType}; + use spacetimedb_query::SubscribePlan; use spacetimedb_sats::product; use spacetimedb_vm::expr::CrudExpr; + use super::row_estimate; + fn in_mem_db() -> TestDB { TestDB::in_memory().expect("failed to make test db") } @@ -90,6 +170,15 @@ mod tests { } } + /// Using the new query plan + fn new_row_estimate(db: &RelationalDB, sql: &str) -> u64 { + let auth = AuthCtx::for_testing(); + let tx = db.begin_tx(Workload::ForTests); + let tx = SchemaViewer::new(&tx, &auth); + let plan = SubscribePlan::compile(sql, &tx).expect("failed to compile sql"); + row_estimate(&tx, &plan) + } + const NUM_T_ROWS: u64 = 10; const NDV_T: u64 = 5; const NUM_S_ROWS: u64 = 2; @@ -141,14 +230,19 @@ mod tests { fn cardinality_estimation_index_lookup() { let db = in_mem_db(); create_table_t(&db, true); - assert_eq!(NUM_T_ROWS / NDV_T, num_rows_for(&db, "select * from T where a = 0")); + let sql = "select * from T where a = 0"; + let est = NUM_T_ROWS / NDV_T; + assert_eq!(est, num_rows_for(&db, sql)); + assert_eq!(est, new_row_estimate(&db, sql)); } #[test] fn cardinality_estimation_0_ndv() { let db = in_mem_db(); create_empty_table_r(&db, true); - assert_eq!(0, num_rows_for(&db, "select * from R where a = 0")); + let sql = "select * from R where a = 0"; + assert_eq!(0, num_rows_for(&db, sql)); + assert_eq!(0, new_row_estimate(&db, sql)); } /// We estimate an index range to return all input rows. @@ -156,7 +250,9 @@ mod tests { fn cardinality_estimation_index_range() { let db = in_mem_db(); create_table_t(&db, true); - assert_eq!(NUM_T_ROWS, num_rows_for(&db, "select * from T where a > 0 and a < 2")); + let sql = "select * from T where a > 0 and a < 2"; + assert_eq!(NUM_T_ROWS, num_rows_for(&db, sql)); + assert_eq!(NUM_T_ROWS, new_row_estimate(&db, sql)); } /// We estimate a selection on a non-indexed column to return all input rows. @@ -164,7 +260,9 @@ mod tests { fn select_cardinality_estimation() { let db = in_mem_db(); create_table_t(&db, true); - assert_eq!(NUM_T_ROWS, num_rows_for(&db, "select * from T where b = 0")); + let sql = "select * from T where b = 0"; + assert_eq!(NUM_T_ROWS, num_rows_for(&db, sql)); + assert_eq!(NUM_T_ROWS, new_row_estimate(&db, sql)); } /// We estimate a projection to return all input rows. @@ -172,7 +270,8 @@ mod tests { fn project_cardinality_estimation() { let db = in_mem_db(); create_table_t(&db, true); - assert_eq!(NUM_T_ROWS, num_rows_for(&db, "select a from T")); + let sql = "select a from T"; + assert_eq!(NUM_T_ROWS, num_rows_for(&db, sql)); } /// We estimate an inner join to return the product of its input sizes. 
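The selectivity rule these estimation tests encode is easy to check by hand: an equality probe on an indexed column is costed at rows / NDV, assuming values are uniformly distributed. A standalone sketch (the constants mirror the test fixtures; `index_lookup_estimate` is a hypothetical helper, not a function from this diff):

```rust
/// 1 / NDV selectivity under the uniform-distribution assumption.
fn index_lookup_estimate(rows: u64, ndv: u64) -> u64 {
    if ndv == 0 { 0 } else { rows / ndv }
}

fn main() {
    const NUM_T_ROWS: u64 = 10;
    const NDV_T: u64 = 5;
    const NUM_S_ROWS: u64 = 2;
    const NDV_S: u64 = 2;

    // `select * from T where a = 0` with an index on `a`: 10 / 5 = 2 rows.
    assert_eq!(2, index_lookup_estimate(NUM_T_ROWS, NDV_T));

    // Index join on `T.a = S.a` with both sides indexed:
    // (10 / 5) * (2 / 2) = 2 rows, matching the index join test below.
    assert_eq!(
        2,
        index_lookup_estimate(NUM_T_ROWS, NDV_T) * index_lookup_estimate(NUM_S_ROWS, NDV_S)
    );
}
```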
@@ -181,10 +280,10 @@ let db = in_mem_db(); create_table_t(&db, false); create_table_s(&db, false); - assert_eq!( - NUM_T_ROWS * NUM_S_ROWS, // => 20 - num_rows_for(&db, "select T.* from T join S on T.a = S.a where S.c = 0") - ); + let sql = "select T.* from T join S on T.a = S.a where S.c = 0"; + let est = NUM_T_ROWS * NUM_S_ROWS; + assert_eq!(est, num_rows_for(&db, sql)); + assert_eq!(est, new_row_estimate(&db, sql)); } /// An index join estimates its output cardinality in the same way. @@ -194,9 +293,9 @@ let db = in_mem_db(); create_table_t(&db, true); create_table_s(&db, true); - assert_eq!( - NUM_T_ROWS / NDV_T * NUM_S_ROWS / NDV_S, // => 2 - num_rows_for(&db, "select T.* from T join S on T.a = S.a where S.c = 0") - ); + let sql = "select T.* from T join S on T.a = S.a where S.c = 0"; + let est = NUM_T_ROWS / NDV_T * NUM_S_ROWS / NDV_S; + assert_eq!(est, num_rows_for(&db, sql)); + assert_eq!(est, new_row_estimate(&db, sql)); } } diff --git a/crates/core/src/host/module_host.rs b/crates/core/src/host/module_host.rs index 83758e86d5e..0e892877a65 100644 --- a/crates/core/src/host/module_host.rs +++ b/crates/core/src/host/module_host.rs @@ -7,14 +7,17 @@ use crate::db::datastore::system_tables::{StClientFields, StClientRow, ST_CLIENT use crate::db::datastore::traits::{IsolationLevel, Program, TxData}; use crate::energy::EnergyQuanta; use crate::error::DBError; +use crate::estimation::estimate_rows_scanned; use crate::execution_context::{ExecutionContext, ReducerContext, Workload}; use crate::hash::Hash; use crate::identity::Identity; use crate::messages::control_db::Database; use crate::replica_context::ReplicaContext; -use crate::sql; +use crate::sql::ast::SchemaViewer; use crate::subscription::module_subscription_actor::ModuleSubscriptions; +use crate::subscription::tx::DeltaTx; use crate::util::lending_pool::{Closed, LendingPool, LentResource, PoolClosed}; +use crate::vm::check_row_limit; use crate::worker_metrics::WORKER_METRICS; use anyhow::Context; use bytes::Bytes; @@ -24,17 +27,18 @@ use indexmap::IndexSet; use itertools::Itertools; use smallvec::SmallVec; use spacetimedb_client_api_messages::timestamp::Timestamp; -use spacetimedb_client_api_messages::websocket::{Compression, QueryUpdate, WebsocketFormat}; +use spacetimedb_client_api_messages::websocket::{Compression, OneOffTable, QueryUpdate, WebsocketFormat}; use spacetimedb_data_structures::error_stream::ErrorStream; use spacetimedb_data_structures::map::{HashCollectionExt as _, IntMap}; use spacetimedb_lib::identity::{AuthCtx, RequestId}; use spacetimedb_lib::Address; use spacetimedb_primitives::{col_list, TableId}; +use spacetimedb_query::SubscribePlan; use spacetimedb_sats::{algebraic_value, ProductValue}; use spacetimedb_schema::auto_migrate::AutoMigrateError; use spacetimedb_schema::def::deserialize::ReducerArgsDeserializeSeed; use spacetimedb_schema::def::{ModuleDef, ReducerDef}; -use spacetimedb_vm::relation::{MemTable, RelValue}; +use spacetimedb_vm::relation::RelValue; use std::fmt; use std::sync::{Arc, Weak}; use std::time::{Duration, Instant}; @@ -836,15 +840,25 @@ impl ModuleHost { } #[tracing::instrument(skip_all)] - pub fn one_off_query(&self, caller_identity: Identity, query: String) -> Result<Vec<MemTable>, anyhow::Error> { + pub fn one_off_query<F: WebsocketFormat>( + &self, + caller_identity: Identity, + query: String, + ) -> Result<OneOffTable<F>, anyhow::Error> { let replica_ctx = self.replica_ctx(); let db = &replica_ctx.relational_db; let auth = AuthCtx::new(replica_ctx.owner_identity, caller_identity); log::debug!("One-off 
query: {query}"); db.with_read_only(Workload::Sql, |tx| { - let ast = sql::compiler::compile_sql(db, &auth, tx, &query)?; - sql::execute::execute_sql_tx(db, tx, &query, ast, auth)? + let tx = SchemaViewer::new(tx, &auth); + let plan = SubscribePlan::compile(&query, &tx)?; + check_row_limit(&plan, db, &tx, |plan, tx| estimate_rows_scanned(tx, plan), &auth)?; + plan.execute::<_, F>(&DeltaTx::from(&*tx)) + .map(|(rows, _)| OneOffTable { + table_name: plan.table_name().to_owned().into_boxed_str(), + rows, + }) .context("One-off queries are not allowed to modify the database") }) } diff --git a/crates/core/src/host/wasm_common/module_host_actor.rs b/crates/core/src/host/wasm_common/module_host_actor.rs index 77f4bd7bf98..34164317904 100644 --- a/crates/core/src/host/wasm_common/module_host_actor.rs +++ b/crates/core/src/host/wasm_common/module_host_actor.rs @@ -282,7 +282,7 @@ impl ModuleInstance for WasmModuleInstance { self.system_logger() .info(&format!("Creating row level security `{}`", rls.sql)); - let rls = RowLevelExpr::build_row_level_expr(stdb, tx, &auth_ctx, rls) + let rls = RowLevelExpr::build_row_level_expr(tx, &auth_ctx, rls) .with_context(|| format!("failed to create row-level security: `{}`", rls.sql))?; let table_id = rls.def.table_id; let sql = rls.def.sql.clone(); diff --git a/crates/core/src/sql/ast.rs b/crates/core/src/sql/ast.rs index 15ab7443faa..4fd4fa4ab32 100644 --- a/crates/core/src/sql/ast.rs +++ b/crates/core/src/sql/ast.rs @@ -1,3 +1,4 @@ +use crate::db::datastore::locking_tx_datastore::state_view::StateView; use crate::db::relational_db::{MutTx, RelationalDB, Tx}; use crate::error::{DBError, PlanError}; use spacetimedb_data_structures::map::{HashCollectionExt as _, IntMap}; @@ -7,7 +8,7 @@ use spacetimedb_lib::db::auth::StAccess; use spacetimedb_lib::db::error::RelationError; use spacetimedb_lib::identity::AuthCtx; use spacetimedb_lib::relation::{ColExpr, FieldName}; -use spacetimedb_primitives::ColId; +use spacetimedb_primitives::{ColId, TableId}; use spacetimedb_sats::{AlgebraicType, AlgebraicValue}; use spacetimedb_schema::schema::{ColumnSchema, TableSchema}; use spacetimedb_vm::errors::ErrorVm; @@ -20,6 +21,7 @@ use sqlparser::ast::{ }; use sqlparser::dialect::PostgreSqlDialect; use sqlparser::parser::Parser; +use std::ops::Deref; use std::sync::Arc; /// Simplify to detect features of the syntax we don't support yet @@ -474,32 +476,43 @@ fn compile_where(table: &From, filter: Option) -> Result { - db: &'a RelationalDB, tx: &'a T, auth: &'a AuthCtx, } -impl SchemaView for SchemaViewer<'_, T> { - fn schema(&self, name: &str) -> Option> { - let name = name.to_owned().into_boxed_str(); - let schema = self - .tx - .find_table(self.db, Table { name }) - .map(Some) - // If there was an error fetching the table schema, - // we swallow it and return None. - // It will be surfaced as a name resolution error instead. 
- .unwrap_or_else(|_| None)?; - if schema.table_access == StAccess::Private && self.auth.caller != self.auth.owner { - return None; - } - Some(schema) +impl Deref for SchemaViewer<'_, T> { + type Target = T; + + fn deref(&self) -> &Self::Target { + self.tx + } +} + +impl SchemaView for SchemaViewer<'_, T> { + fn table_id(&self, name: &str) -> Option { + let AuthCtx { owner, caller } = self.auth; + // Get the schema from the in-memory state instead of fetching from the database for speed + self.tx + .table_id_from_name(name) + .ok() + .flatten() + .and_then(|table_id| self.schema_for_table(table_id)) + .filter(|schema| schema.table_access == StAccess::Public || caller == owner) + .map(|schema| schema.table_id) + } + + fn schema_for_table(&self, table_id: TableId) -> Option> { + let AuthCtx { owner, caller } = self.auth; + self.tx + .get_schema(table_id) + .filter(|schema| schema.table_access == StAccess::Public || caller == owner) + .cloned() } } impl<'a, T> SchemaViewer<'a, T> { - pub fn new(db: &'a RelationalDB, tx: &'a T, auth: &'a AuthCtx) -> Self { - Self { db, tx, auth } + pub fn new(tx: &'a T, auth: &'a AuthCtx) -> Self { + Self { tx, auth } } } @@ -534,7 +547,11 @@ impl TableSchemaView for MutTx { } /// Compiles the `FROM` clause -fn compile_from(db: &RelationalDB, tx: &T, from: &[TableWithJoins]) -> Result { +fn compile_from( + db: &RelationalDB, + tx: &T, + from: &[TableWithJoins], +) -> Result { if from.len() > 1 { return Err(PlanError::Unsupported { feature: "Multiple tables in `FROM`.".into(), @@ -655,7 +672,11 @@ fn compile_select_item(from: &From, select_item: SelectItem) -> Result(db: &RelationalDB, tx: &T, select: Select) -> Result { +fn compile_select( + db: &RelationalDB, + tx: &T, + select: Select, +) -> Result { let from = compile_from(db, tx, &select.from)?; // SELECT ... @@ -675,7 +696,7 @@ fn compile_select(db: &RelationalDB, tx: &T, select: Select) } /// Compiles any `query` clause (currently only `SELECT...`) -fn compile_query(db: &RelationalDB, tx: &T, query: Query) -> Result { +fn compile_query(db: &RelationalDB, tx: &T, query: Query) -> Result { unsupported!( "SELECT", query.order_by, @@ -728,7 +749,7 @@ fn compile_query(db: &RelationalDB, tx: &T, query: Query) -> } /// Compiles the `INSERT ...` clause -fn compile_insert( +fn compile_insert( db: &RelationalDB, tx: &T, table_name: ObjectName, @@ -767,7 +788,7 @@ fn compile_insert( } /// Compiles the `UPDATE ...` clause -fn compile_update( +fn compile_update( db: &RelationalDB, tx: &T, table: Table, @@ -795,7 +816,7 @@ fn compile_update( } /// Compiles the `DELETE ...` clause -fn compile_delete( +fn compile_delete( db: &RelationalDB, tx: &T, table: Table, @@ -856,7 +877,11 @@ fn compile_read_config(name: Vec) -> Result { } /// Compiles a `SQL` clause -fn compile_statement(db: &RelationalDB, tx: &T, statement: Statement) -> Result { +fn compile_statement( + db: &RelationalDB, + tx: &T, + statement: Statement, +) -> Result { match statement { Statement::Query(query) => Ok(compile_query(db, tx, *query)?), Statement::Insert { @@ -944,7 +969,7 @@ fn compile_statement(db: &RelationalDB, tx: &T, statement: S } /// Compiles a `sql` string into a `Vec` using a SQL parser with [PostgreSqlDialect] -pub(crate) fn compile_to_ast( +pub(crate) fn compile_to_ast( db: &RelationalDB, auth: &AuthCtx, tx: &T, @@ -952,7 +977,7 @@ pub(crate) fn compile_to_ast( ) -> Result, DBError> { // NOTE: The following ensures compliance with the 1.0 sql api. // Come 1.0, it will have replaced the current compilation stack. 
- compile_sql_stmt(sql_text, &SchemaViewer::new(db, tx, auth))?; + compile_sql_stmt(sql_text, &SchemaViewer::new(tx, auth))?; let dialect = PostgreSqlDialect {}; let ast = Parser::parse_sql(&dialect, sql_text).map_err(|error| DBError::SqlParser { diff --git a/crates/core/src/sql/compiler.rs b/crates/core/src/sql/compiler.rs index 00cf1e1cec0..90441eed9a6 100644 --- a/crates/core/src/sql/compiler.rs +++ b/crates/core/src/sql/compiler.rs @@ -1,5 +1,6 @@ use super::ast::{compile_to_ast, Column, From, Join, Selection, SqlAst}; use super::type_check::TypeCheck; +use crate::db::datastore::locking_tx_datastore::state_view::StateView; use crate::db::relational_db::RelationalDB; use crate::error::{DBError, PlanError}; use core::ops::Deref; @@ -20,7 +21,7 @@ use super::ast::TableSchemaView; const MAX_SQL_LENGTH: usize = 50_000; /// Compile the `SQL` expression into an `ast` -pub fn compile_sql( +pub fn compile_sql( db: &RelationalDB, auth: &AuthCtx, tx: &T, @@ -268,7 +269,11 @@ mod tests { assert!(matches!(op, Query::Select(_))); } - fn compile_sql(db: &RelationalDB, tx: &T, sql: &str) -> Result, DBError> { + fn compile_sql( + db: &RelationalDB, + tx: &T, + sql: &str, + ) -> Result, DBError> { super::compile_sql(db, &AuthCtx::for_testing(), tx, sql) } diff --git a/crates/core/src/sql/parser.rs b/crates/core/src/sql/parser.rs index b02485aef1b..8d5c096a603 100644 --- a/crates/core/src/sql/parser.rs +++ b/crates/core/src/sql/parser.rs @@ -1,26 +1,24 @@ use crate::db::datastore::locking_tx_datastore::MutTxId; -use crate::db::relational_db::RelationalDB; use crate::sql::ast::SchemaViewer; use spacetimedb_expr::check::parse_and_type_sub; use spacetimedb_expr::errors::TypingError; -use spacetimedb_expr::expr::Project; +use spacetimedb_expr::expr::ProjectName; use spacetimedb_lib::db::raw_def::v9::RawRowLevelSecurityDefV9; use spacetimedb_lib::identity::AuthCtx; use spacetimedb_schema::schema::RowLevelSecuritySchema; pub struct RowLevelExpr { - pub sql: Project, + pub sql: ProjectName, pub def: RowLevelSecuritySchema, } impl RowLevelExpr { pub fn build_row_level_expr( - stdb: &RelationalDB, tx: &mut MutTxId, auth_ctx: &AuthCtx, rls: &RawRowLevelSecurityDefV9, ) -> Result { - let sql = parse_and_type_sub(&rls.sql, &SchemaViewer::new(stdb, tx, auth_ctx))?; + let sql = parse_and_type_sub(&rls.sql, &SchemaViewer::new(tx, auth_ctx))?; Ok(Self { def: RowLevelSecuritySchema { diff --git a/crates/core/src/subscription/delta.rs b/crates/core/src/subscription/delta.rs new file mode 100644 index 00000000000..df160de544c --- /dev/null +++ b/crates/core/src/subscription/delta.rs @@ -0,0 +1,61 @@ +use std::collections::HashMap; + +use anyhow::Result; +use spacetimedb_execution::{Datastore, DeltaStore}; +use spacetimedb_query::delta::DeltaPlanEvaluator; +use spacetimedb_vm::relation::RelValue; + +use crate::host::module_host::UpdatesRelValue; + +/// This utility deduplicates an incremental update. +/// That is, if a row is both inserted and deleted, +/// this method removes it from the result set. +/// +/// Note, the 1.0 api does allow for duplicate rows. +/// Hence this may be removed at any time after 1.0. 
+pub fn eval_delta<'a, Tx: Datastore + DeltaStore>( + tx: &'a Tx, + delta: &'a DeltaPlanEvaluator, +) -> Result> { + if !delta.is_join() { + return Ok(UpdatesRelValue { + inserts: delta.eval_inserts(tx)?.map(RelValue::from).collect(), + deletes: delta.eval_deletes(tx)?.map(RelValue::from).collect(), + }); + } + if delta.has_inserts() && !delta.has_deletes() { + return Ok(UpdatesRelValue { + inserts: delta.eval_inserts(tx)?.map(RelValue::from).collect(), + deletes: vec![], + }); + } + if delta.has_deletes() && !delta.has_inserts() { + return Ok(UpdatesRelValue { + deletes: delta.eval_deletes(tx)?.map(RelValue::from).collect(), + inserts: vec![], + }); + } + let mut inserts = HashMap::new(); + + for row in delta.eval_inserts(tx)?.map(RelValue::from) { + inserts.entry(row).and_modify(|n| *n += 1).or_insert(1); + } + + let deletes = delta + .eval_deletes(tx)? + .map(RelValue::from) + .filter(|row| match inserts.get_mut(row) { + None => true, + Some(1) => inserts.remove(row).is_none(), + Some(n) => { + *n -= 1; + false + } + }) + .collect(); + + Ok(UpdatesRelValue { + inserts: inserts.into_keys().collect(), + deletes, + }) +} diff --git a/crates/core/src/subscription/mod.rs b/crates/core/src/subscription/mod.rs index 761fd103ec6..e9bb8e519e4 100644 --- a/crates/core/src/subscription/mod.rs +++ b/crates/core/src/subscription/mod.rs @@ -1,6 +1,8 @@ +pub mod delta; pub mod execution_unit; pub mod module_subscription_actor; pub mod module_subscription_manager; pub mod query; #[allow(clippy::module_inception)] // it's right this isn't ideal :/ pub mod subscription; +pub mod tx; diff --git a/crates/core/src/subscription/module_subscription_actor.rs b/crates/core/src/subscription/module_subscription_actor.rs index ac7796d287d..d0205e5b7cf 100644 --- a/crates/core/src/subscription/module_subscription_actor.rs +++ b/crates/core/src/subscription/module_subscription_actor.rs @@ -1,32 +1,28 @@ -use super::execution_unit::{ExecutionUnit, QueryHash}; -use super::module_subscription_manager::SubscriptionManager; -use super::query::{compile_read_only_query, compile_read_only_queryset}; -use super::subscription::ExecutionSet; +use super::execution_unit::QueryHash; +use super::module_subscription_manager::{Plan, SubscriptionManager}; +use super::query::compile_read_only_query; +use super::tx::DeltaTx; use crate::client::messages::{ SubscriptionError, SubscriptionMessage, SubscriptionResult, SubscriptionRows, SubscriptionUpdateMessage, TransactionUpdateMessage, }; use crate::client::{ClientActorId, ClientConnectionSender, Protocol}; use crate::db::datastore::locking_tx_datastore::tx::TxId; -use crate::db::datastore::system_tables::StVarTable; use crate::db::relational_db::{MutTx, RelationalDB, Tx}; use crate::error::DBError; +use crate::estimation::estimate_rows_scanned; use crate::execution_context::Workload; use crate::host::module_host::{DatabaseUpdate, EventStatus, ModuleEvent}; use crate::messages::websocket::Subscribe; -use crate::sql::ast::SchemaViewer; use crate::vm::check_row_limit; use crate::worker_metrics::WORKER_METRICS; use parking_lot::RwLock; use spacetimedb_client_api_messages::websocket::{ BsatnFormat, FormatSwitch, JsonFormat, SubscribeSingle, TableUpdate, Unsubscribe, }; -use spacetimedb_expr::check::compile_sql_sub; use spacetimedb_lib::identity::AuthCtx; use spacetimedb_lib::Identity; -use spacetimedb_vm::errors::ErrorVm; -use spacetimedb_vm::expr::AuthAccess; -use std::time::Duration; +use spacetimedb_query::{execute_plans, SubscribePlan}; use std::{sync::Arc, time::Instant}; type 
Subscriptions = Arc>; @@ -55,44 +51,25 @@ impl ModuleSubscriptions { fn evaluate_initial_subscription( &self, sender: Arc, - query: Arc, - auth: AuthCtx, + query: Arc, tx: &TxId, + auth: &AuthCtx, ) -> Result, TableUpdate>, DBError> { - query.check_auth(auth.owner, auth.caller).map_err(ErrorVm::Auth)?; + let comp = sender.config.compression; + let plan = SubscribePlan::from_delta_plan(&query); check_row_limit( - &query, + &plan, &self.relational_db, tx, - |query, tx| query.row_estimate(tx), - &auth, + |plan, tx| estimate_rows_scanned(tx, plan), + auth, )?; - let slow_query_threshold = StVarTable::sub_limit(&self.relational_db, tx)?.map(Duration::from_millis); + let tx = DeltaTx::from(tx); Ok(match sender.config.protocol { - Protocol::Binary => FormatSwitch::Bsatn( - query - .eval( - &self.relational_db, - tx, - &query.sql, - slow_query_threshold, - sender.config.compression, - ) - .unwrap_or(TableUpdate::empty(query.return_table(), query.return_name())), - ), - Protocol::Text => FormatSwitch::Json( - query - .eval( - &self.relational_db, - tx, - &query.sql, - slow_query_threshold, - sender.config.compression, - ) - .unwrap_or(TableUpdate::empty(query.return_table(), query.return_name())), - ), + Protocol::Binary => FormatSwitch::Bsatn(plan.collect_table_update(comp, &tx)?), + Protocol::Text => FormatSwitch::Json(plan.collect_table_update(comp, &tx)?), }) } @@ -115,17 +92,13 @@ impl ModuleSubscriptions { let query = if let Some(unit) = guard.query(&hash) { unit } else { - // NOTE: The following ensures compliance with the 1.0 sql api. - // Come 1.0, it will have replaced the current compilation stack. - compile_sql_sub(sql, &SchemaViewer::new(&self.relational_db, &*tx, &auth))?; - - let compiled = compile_read_only_query(&self.relational_db, &auth, &tx, sql)?; - Arc::new(ExecutionUnit::new(compiled, hash)?) + let compiled = compile_read_only_query(&auth, &tx, sql)?; + Arc::new(compiled) }; drop(guard); - let table_rows = self.evaluate_initial_subscription(sender.clone(), query.clone(), auth, &tx)?; + let table_rows = self.evaluate_initial_subscription(sender.clone(), query.clone(), &tx, &auth)?; // It acquires the subscription lock after `eval`, allowing `add_subscription` to run concurrently. 
// This also makes it possible for `broadcast_event` to get scheduled before the subsequent part here @@ -152,8 +125,8 @@ impl ModuleSubscriptions { query_id: Some(request.query_id), timer: Some(timer), result: SubscriptionResult::Subscribe(SubscriptionRows { - table_id: query.return_table(), - table_name: query.return_name(), + table_id: query.table_id(), + table_name: query.table_name(), table_rows, }), }); @@ -183,11 +156,12 @@ impl ModuleSubscriptions { return Ok(()); } }; - let auth = AuthCtx::new(self.owner_identity, sender.id.identity); + let tx = scopeguard::guard(self.relational_db.begin_tx(Workload::Unsubscribe), |tx| { self.relational_db.release_tx(tx); }); - let table_rows = self.evaluate_initial_subscription(sender.clone(), query.clone(), auth, &tx)?; + let auth = AuthCtx::new(self.owner_identity, sender.id.identity); + let table_rows = self.evaluate_initial_subscription(sender.clone(), query.clone(), &tx, &auth)?; WORKER_METRICS .subscription_queries @@ -198,8 +172,8 @@ impl ModuleSubscriptions { query_id: Some(request.query_id), timer: Some(timer), result: SubscriptionResult::Unsubscribe(SubscriptionRows { - table_id: query.return_table(), - table_name: query.return_name(), + table_id: query.table_id(), + table_name: query.table_name(), table_rows, }), }); @@ -235,11 +209,7 @@ impl ModuleSubscriptions { queries.extend( super::subscription::get_all(&self.relational_db, &tx, &auth)? .into_iter() - .map(|query| { - let hash = QueryHash::from_string(&query.sql); - ExecutionUnit::new(query, hash).map(Arc::new) - }) - .collect::, _>>()?, + .map(Arc::new), ); continue; } @@ -247,58 +217,46 @@ impl ModuleSubscriptions { if let Some(unit) = guard.query(&hash) { queries.push(unit); } else { - // NOTE: The following ensures compliance with the 1.0 sql api. - // Come 1.0, it will have replaced the current compilation stack. - compile_sql_sub(sql, &SchemaViewer::new(&self.relational_db, &*tx, &auth))?; - - let mut compiled = compile_read_only_queryset(&self.relational_db, &auth, &tx, sql)?; - // Note that no error path is needed here. - // We know this vec only has a single element, - // since `parse_and_type_sub` guarantees it. - // This check will be removed come 1.0. 
- if compiled.len() == 1 { - queries.push(Arc::new(ExecutionUnit::new(compiled.remove(0), hash)?)); - } + let compiled = compile_read_only_query(&auth, &tx, sql)?; + queries.push(Arc::new(compiled)); } } drop(guard); - let execution_set: ExecutionSet = queries.into(); - - execution_set - .check_auth(auth.owner, auth.caller) - .map_err(ErrorVm::Auth)?; + let comp = sender.config.compression; + let plans = queries + .iter() + .map(|plan| &***plan) + .map(SubscribePlan::from_delta_plan) + .collect::>(); + + fn rows_scanned(tx: &TxId, plans: &[SubscribePlan]) -> u64 { + plans + .iter() + .map(|plan| estimate_rows_scanned(tx, plan)) + .fold(0, |acc, n| acc.saturating_add(n)) + } check_row_limit( - &execution_set, + &plans, &self.relational_db, &tx, - |execution_set, tx| execution_set.row_estimate(tx), + |plan, tx| rows_scanned(tx, plan), &auth, )?; - let slow_query_threshold = StVarTable::sub_limit(&self.relational_db, &tx)?.map(Duration::from_millis); + let tx = DeltaTx::from(&*tx); let database_update = match sender.config.protocol { - Protocol::Text => FormatSwitch::Json(execution_set.eval( - &self.relational_db, - &tx, - slow_query_threshold, - sender.config.compression, - )), - Protocol::Binary => FormatSwitch::Bsatn(execution_set.eval( - &self.relational_db, - &tx, - slow_query_threshold, - sender.config.compression, - )), + Protocol::Text => FormatSwitch::Json(execute_plans(plans, comp, &tx)?), + Protocol::Binary => FormatSwitch::Bsatn(execute_plans(plans, comp, &tx)?), }; // It acquires the subscription lock after `eval`, allowing `add_subscription` to run concurrently. // This also makes it possible for `broadcast_event` to get scheduled before the subsequent part here // but that should not pose an issue. let mut subscriptions = self.subscriptions.write(); - subscriptions.set_legacy_subscription(sender.clone(), execution_set.into_iter()); + subscriptions.set_legacy_subscription(sender.clone(), queries.into_iter()); let num_queries = subscriptions.num_unique_queries(); WORKER_METRICS @@ -345,30 +303,32 @@ impl ModuleSubscriptions { let stdb = &self.relational_db; // Downgrade mutable tx. // Ensure tx is released/cleaned up once out of scope. - let read_tx = scopeguard::guard( - match &mut event.status { - EventStatus::Committed(db_update) => { - let Some((tx_data, read_tx)) = stdb.commit_tx_downgrade(tx, Workload::Update)? else { - return Ok(Err(WriteConflict)); - }; - *db_update = DatabaseUpdate::from_writes(&tx_data); - read_tx - } - EventStatus::Failed(_) | EventStatus::OutOfEnergy => { - stdb.rollback_mut_tx_downgrade(tx, Workload::Update) - } - }, - |tx| { - self.relational_db.release_tx(tx); - }, - ); + let (read_tx, tx_data) = match &mut event.status { + EventStatus::Committed(db_update) => { + let Some((tx_data, read_tx)) = stdb.commit_tx_downgrade(tx, Workload::Update)? 
else { + return Ok(Err(WriteConflict)); + }; + *db_update = DatabaseUpdate::from_writes(&tx_data); + (read_tx, Some(tx_data)) + } + EventStatus::Failed(_) | EventStatus::OutOfEnergy => { + (stdb.rollback_mut_tx_downgrade(tx, Workload::Update), None) + } + }; + + let read_tx = scopeguard::guard(read_tx, |tx| { + self.relational_db.release_tx(tx); + }); + + let read_tx = tx_data + .as_ref() + .map(|tx_data| DeltaTx::new(&read_tx, tx_data)) + .unwrap_or_else(|| DeltaTx::from(&*read_tx)); + let event = Arc::new(event); match &event.status { - EventStatus::Committed(_) => { - let slow_query_threshold = StVarTable::incr_limit(stdb, &read_tx)?.map(Duration::from_millis); - subscriptions.eval_updates(stdb, &read_tx, event.clone(), caller, slow_query_threshold) - } + EventStatus::Committed(_) => subscriptions.eval_updates(&read_tx, event.clone(), caller), EventStatus::Failed(_) => { if let Some(client) = caller { let message = TransactionUpdateMessage { @@ -398,7 +358,6 @@ mod tests { use crate::error::DBError; use crate::execution_context::Workload; use spacetimedb_client_api_messages::websocket::Subscribe; - use spacetimedb_expr::errors::{TypingError, Unresolved}; use spacetimedb_lib::db::auth::StAccess; use spacetimedb_lib::{error::ResultTest, AlgebraicType, Identity}; use spacetimedb_sats::product; @@ -499,10 +458,7 @@ mod tests { "SELECT public.* FROM public JOIN private ON public.a = private.a WHERE private.a = 1", "SELECT private.* FROM private JOIN public ON private.a = public.a WHERE public.a = 1", ] { - assert!(matches!( - subscribe(sql).unwrap_err(), - DBError::TypeError(TypingError::Unresolved(Unresolved::Table(_))) - )); + assert!(subscribe(sql).is_err(),); } Ok(()) diff --git a/crates/core/src/subscription/module_subscription_manager.rs b/crates/core/src/subscription/module_subscription_manager.rs index 03d96bfc779..369e02b9e62 100644 --- a/crates/core/src/subscription/module_subscription_manager.rs +++ b/crates/core/src/subscription/module_subscription_manager.rs @@ -1,11 +1,11 @@ -use super::execution_unit::{ExecutionUnit, QueryHash}; +use super::execution_unit::QueryHash; +use super::tx::DeltaTx; use crate::client::messages::{SubscriptionUpdateMessage, TransactionUpdateMessage}; use crate::client::{ClientConnectionSender, Protocol}; -use crate::db::relational_db::{RelationalDB, Tx}; use crate::error::DBError; use crate::host::module_host::{DatabaseTableUpdate, ModuleEvent, UpdatesRelValue}; use crate::messages::websocket::{self as ws, TableUpdate}; -use arrayvec::ArrayVec; +use crate::subscription::delta::eval_delta; use hashbrown::hash_map::OccupiedError; use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; use spacetimedb_client_api_messages::websocket::{ @@ -14,14 +14,15 @@ use spacetimedb_client_api_messages::websocket::{ use spacetimedb_data_structures::map::{Entry, HashCollectionExt, HashMap, HashSet, IntMap}; use spacetimedb_lib::{Address, Identity}; use spacetimedb_primitives::TableId; +use spacetimedb_query::delta::DeltaPlan; +use std::ops::Deref; use std::sync::Arc; -use std::time::Duration; /// Clients are uniquely identified by their Identity and Address. /// Identity is insufficient because different Addresses can use the same Identity. /// TODO: Determine if Address is sufficient for uniquely identifying a client. 
type ClientId = (Identity, Address); -type Query = Arc; +type Query = Arc; type Client = Arc; type SwitchedDbUpdate = FormatSwitch, ws::DatabaseUpdate>; @@ -30,6 +31,30 @@ type ClientQueryId = QueryId; /// SubscriptionId is a globally unique identifier for a subscription. type SubscriptionId = (ClientId, ClientQueryId); +#[derive(Debug)] +pub struct Plan { + hash: QueryHash, + plan: DeltaPlan, +} + +impl Deref for Plan { + type Target = DeltaPlan; + + fn deref(&self) -> &Self::Target { + &self.plan + } +} + +impl Plan { + pub fn new(plan: DeltaPlan, hash: QueryHash) -> Self { + Self { plan, hash } + } + + pub fn hash(&self) -> QueryHash { + self.hash + } +} + /// For each client, we hold a handle for sending messages, and we track the queries they are subscribed to. #[derive(Debug)] struct ClientInfo { @@ -198,8 +223,9 @@ impl SubscriptionManager { // If this is new, we need to update the table to query mapping. if !query_state.has_subscribers() { - self.tables.entry(query.return_table()).or_default().insert(hash); - self.tables.entry(query.filter_table()).or_default().insert(hash); + for table_id in query.table_ids() { + self.tables.entry(table_id).or_default().insert(hash); + } } query_state.subscriptions.insert(subscription_id); @@ -230,8 +256,9 @@ impl SubscriptionManager { .queries .entry(hash) .or_insert_with(|| QueryState::new(unit.clone())); - self.tables.entry(unit.return_table()).or_default().insert(hash); - self.tables.entry(unit.filter_table()).or_default().insert(hash); + for table_id in unit.table_ids() { + self.tables.entry(table_id).or_default().insert(hash); + } query_state.legacy_subscribers.insert(client_id); } } @@ -241,9 +268,8 @@ impl SubscriptionManager { // This takes a ref to the table map instead of `self` to avoid borrowing issues. fn remove_query_from_tables(tables: &mut IntMap>, query: &Query) { let hash = query.hash(); - let related_tables = [query.return_table(), query.filter_table()]; - for table_id in related_tables.iter() { - if let Entry::Occupied(mut entry) = tables.entry(*table_id) { + for table_id in query.table_ids() { + if let Entry::Occupied(mut entry) = tables.entry(table_id) { let hashes = entry.get_mut(); if hashes.remove(&hash) && hashes.is_empty() { entry.remove(); @@ -284,47 +310,29 @@ impl SubscriptionManager { /// evaluates only the necessary queries for those delta tables, /// and then sends the results to each client. #[tracing::instrument(skip_all)] - pub fn eval_updates( - &self, - db: &RelationalDB, - tx: &Tx, - event: Arc, - caller: Option<&ClientConnectionSender>, - slow_query_threshold: Option, - ) { + pub fn eval_updates(&self, tx: &DeltaTx, event: Arc, caller: Option<&ClientConnectionSender>) { use FormatSwitch::{Bsatn, Json}; let tables = &event.status.database_update().unwrap().tables; // Put the main work on a rayon compute thread. rayon::scope(|_| { - // Collect the delta tables for each query. - // For selects this is just a single table. - // For joins it's two tables. - let mut units: HashMap<_, ArrayVec<_, 2>> = HashMap::default(); - for table @ DatabaseTableUpdate { table_id, .. } in tables { - if let Some(hashes) = self.tables.get(table_id) { - for hash in hashes { - units.entry(hash).or_insert_with(ArrayVec::new).push(table); - } - } - } - let span = tracing::info_span!("eval_incr").entered(); - let tx = &tx.into(); - let mut eval = units + let mut eval = tables + .iter() + .filter(|table| !table.inserts.is_empty() || !table.deletes.is_empty()) + .map(|DatabaseTableUpdate { table_id, .. 
}| table_id) + .filter_map(|table_id| self.tables.get(table_id)) + .flatten() + .collect::>() .par_iter() - .filter_map(|(&hash, tables)| { - let unit = &self.queries.get(hash)?.query; - unit.eval_incr(db, tx, &unit.sql, tables.iter().copied(), slow_query_threshold) - .map(|table| (hash, table)) - }) + .filter_map(|&hash| self.queries.get(hash).map(|state| (hash, &state.query))) // If N clients are subscribed to a query, // we copy the DatabaseTableUpdate N times, // which involves cloning BSATN (binary) or product values (json). - .flat_map_iter(|(hash, delta)| { - let table_id = delta.table_id; - let table_name = delta.table_name; + .flat_map_iter(|(hash, plan)| { + let table_id = plan.table_id(); + let table_name = plan.table_name(); // Store at most one copy of the serialization to BSATN // and ditto for the "serialization" for JSON. // Each subscriber gets to pick which of these they want, @@ -344,22 +352,32 @@ impl SubscriptionManager { .clone() } - self.queries - .get(hash) - .into_iter() - .flat_map(|query| query.all_clients()) - .map(move |id| { - let client = &self.clients[id].outbound_ref; - let update = match client.config.protocol { - Protocol::Binary => { - Bsatn(memo_encode::(&delta.updates, client, &mut ops_bin)) - } - Protocol::Text => { - Json(memo_encode::(&delta.updates, client, &mut ops_json)) - } - }; - (id, table_id, table_name.clone(), update) + let evaluator = plan.evaluator(tx); + + // TODO: Handle errors instead of skipping them + eval_delta(tx, &evaluator) + .ok() + .filter(|delta_updates| delta_updates.has_updates()) + .map(|delta_updates| { + self.queries + .get(hash) + .into_iter() + .flat_map(|query| query.all_clients()) + .map(move |id| { + let client = &self.clients[id].outbound_ref; + let update = match client.config.protocol { + Protocol::Binary => { + Bsatn(memo_encode::(&delta_updates, client, &mut ops_bin)) + } + Protocol::Text => { + Json(memo_encode::(&delta_updates, client, &mut ops_json)) + } + }; + (id, table_id, table_name.clone(), update) + }) + .collect::>() }) + .unwrap_or_default() }) .collect::>() .into_iter() @@ -453,10 +471,11 @@ mod tests { use spacetimedb_client_api_messages::websocket::QueryId; use spacetimedb_lib::{error::ResultTest, identity::AuthCtx, Address, AlgebraicType, Identity}; use spacetimedb_primitives::TableId; - use spacetimedb_vm::expr::CrudExpr; + use spacetimedb_query::delta::DeltaPlan; - use super::SubscriptionManager; + use super::{Plan, SubscriptionManager}; use crate::execution_context::Workload; + use crate::sql::ast::SchemaViewer; use crate::subscription::module_subscription_manager::ClientQueryId; use crate::{ client::{ClientActorId, ClientConfig, ClientConnectionSender, ClientName}, @@ -466,28 +485,20 @@ mod tests { module_host::{DatabaseUpdate, EventStatus, ModuleEvent, ModuleFunctionCall}, ArgsTuple, }, - sql::compiler::compile_sql, - subscription::{ - execution_unit::{ExecutionUnit, QueryHash}, - subscription::SupportedQuery, - }, + subscription::execution_unit::QueryHash, }; fn create_table(db: &RelationalDB, name: &str) -> ResultTest { Ok(db.create_table_for_test(name, &[("a", AlgebraicType::U8)], &[])?) 
} - fn compile_plan(db: &RelationalDB, sql: &str) -> ResultTest> { + fn compile_plan(db: &RelationalDB, sql: &str) -> ResultTest> { db.with_read_only(Workload::ForTests, |tx| { - let mut exprs = compile_sql(db, &AuthCtx::for_testing(), tx, sql)?; - assert_eq!(1, exprs.len()); - assert!(matches!(exprs[0], CrudExpr::Query(_))); - let CrudExpr::Query(query) = exprs.remove(0) else { - unreachable!(); - }; - let plan = SupportedQuery::new(query, sql.to_owned())?; + let auth = AuthCtx::for_testing(); + let tx = SchemaViewer::new(&*tx, &auth); let hash = QueryHash::from_string(sql); - Ok(Arc::new(ExecutionUnit::new(plan, hash)?)) + let plan = DeltaPlan::compile(sql, &tx).unwrap(); + Ok(Arc::new(Plan::new(plan, hash))) }) } @@ -937,7 +948,7 @@ mod tests { }); db.with_read_only(Workload::Update, |tx| { - subscriptions.eval_updates(&db, tx, event, Some(&client0), None) + subscriptions.eval_updates(&(&*tx).into(), event, Some(&client0)) }); tokio::runtime::Builder::new_current_thread() diff --git a/crates/core/src/subscription/query.rs b/crates/core/src/subscription/query.rs index 422d03aa1a7..2dd0ba71da8 100644 --- a/crates/core/src/subscription/query.rs +++ b/crates/core/src/subscription/query.rs @@ -1,12 +1,17 @@ use crate::db::relational_db::{RelationalDB, Tx}; use crate::error::{DBError, SubscriptionError}; +use crate::sql::ast::SchemaViewer; use crate::sql::compiler::compile_sql; use crate::subscription::subscription::SupportedQuery; use once_cell::sync::Lazy; use regex::Regex; use spacetimedb_lib::identity::AuthCtx; +use spacetimedb_query::delta::DeltaPlan; use spacetimedb_vm::expr::{self, Crud, CrudExpr, QueryExpr}; +use super::execution_unit::QueryHash; +use super::module_subscription_manager::Plan; + pub(crate) static WHITESPACE: Lazy = Lazy::new(|| Regex::new(r"\s+").unwrap()); pub const SUBSCRIBE_TO_ALL_QUERY: &str = "SELECT * FROM *"; @@ -64,12 +69,7 @@ pub fn compile_read_only_queryset( /// Compile a string into a single read-only query. /// This returns an error if the string has multiple queries or mutations. -pub fn compile_read_only_query( - relational_db: &RelationalDB, - auth: &AuthCtx, - tx: &Tx, - input: &str, -) -> Result { +pub fn compile_read_only_query(auth: &AuthCtx, tx: &Tx, input: &str) -> Result { let input = input.trim(); if input.is_empty() { return Err(SubscriptionError::Empty.into()); @@ -78,27 +78,11 @@ pub fn compile_read_only_query( // Remove redundant whitespace, and in particular newlines, for debug info. let input = WHITESPACE.replace_all(input, " "); - let single: CrudExpr = { - let mut compiled = compile_sql(relational_db, auth, tx, &input)?; - // Return an error if this doesn't produce exactly one query. - let first_query = compiled.pop(); - let other_queries = compiled.len(); - match (first_query, other_queries) { - (None, _) => return Err(SubscriptionError::Empty.into()), - (Some(q), 0) => q, - _ => return Err(SubscriptionError::Multiple.into()), - } - }; - - Err(SubscriptionError::SideEffect(match single { - CrudExpr::Query(query) => return SupportedQuery::new(query, input.to_string()), - CrudExpr::Insert { .. } => Crud::Insert, - CrudExpr::Update { .. } => Crud::Update, - CrudExpr::Delete { .. } => Crud::Delete, - CrudExpr::SetVar { .. } => Crud::Config, - CrudExpr::ReadVar { .. } => Crud::Config, - }) - .into()) + let tx = SchemaViewer::new(tx, auth); + let plan = DeltaPlan::compile(&input, &tx)?; + let hash = QueryHash::from_string(&input); + + Ok(Plan::new(plan, hash)) } /// The kind of [`QueryExpr`] currently supported for incremental evaluation. 
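The hunk above is the heart of the migration: a subscription query is now compiled directly to a `DeltaPlan` through a `SchemaViewer`, paired with its `QueryHash` in the new `Plan` wrapper, and registered under every table it reads via `table_ids()`, rather than under just a return table and a filter table. Below is a minimal sketch of that registration and routing idea, using hypothetical stand-in types (the real `QueryHash`, `TableId`, and plan types live in `spacetimedb-query` and the core crate):

```rust
use std::collections::{HashMap, HashSet};

// Hypothetical stand-ins; the real QueryHash, TableId, and compiled plan
// are defined in spacetimedb-query and the core crate.
type QueryHash = u64;
type TableId = u32;

struct Plan {
    hash: QueryHash,
    table_ids: Vec<TableId>, // every table the query reads
}

#[derive(Default)]
struct Registry {
    // One entry per table, pointing at every query that reads it.
    tables: HashMap<TableId, HashSet<QueryHash>>,
}

impl Registry {
    // Register the plan under *all* of its tables, mirroring the change
    // from `return_table`/`filter_table` to `table_ids()`.
    fn add(&mut self, plan: &Plan) {
        for table_id in &plan.table_ids {
            self.tables.entry(*table_id).or_default().insert(plan.hash);
        }
    }

    // Collect the distinct queries affected by a commit, deduplicating
    // hashes when several modified tables feed the same query.
    fn affected(&self, modified: &[TableId]) -> HashSet<QueryHash> {
        modified
            .iter()
            .filter_map(|id| self.tables.get(id))
            .flatten()
            .copied()
            .collect()
    }
}

fn main() {
    let mut reg = Registry::default();
    reg.add(&Plan { hash: 1, table_ids: vec![10, 20] }); // a two-table join
    reg.add(&Plan { hash: 2, table_ids: vec![20] });
    assert_eq!(reg.affected(&[20]).len(), 2); // both queries read table 20
    assert_eq!(reg.affected(&[10]).len(), 1);
}
```

Registering all of `table_ids()` is what lets plans that read more than two tables, such as multi-way joins, be woken up by a commit to any of them.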
@@ -137,7 +121,9 @@ mod tests { use crate::host::module_host::{DatabaseTableUpdate, DatabaseUpdate}; use crate::sql::execute::collect_result; use crate::sql::execute::tests::run_for_testing; - use crate::subscription::subscription::{get_all, ExecutionSet}; + use crate::subscription::delta::eval_delta; + use crate::subscription::subscription::{legacy_get_all, ExecutionSet}; + use crate::subscription::tx::DeltaTx; use crate::vm::tests::create_table_with_rows; use crate::vm::DbProgram; use itertools::Itertools; @@ -574,7 +560,7 @@ mod tests { let row_1 = product!(1u64, "health"); let row_2 = product!(2u64, "jhon doe"); let tx = db.begin_tx(Workload::Subscribe); - let s = get_all(&db, &tx, &AuthCtx::for_testing())?.into(); + let s = legacy_get_all(&db, &tx, &AuthCtx::for_testing())?.into(); check_query_eval(&db, &tx, &s, 2, &[row_1.clone(), row_2.clone()])?; let data1 = DatabaseTableUpdate { @@ -669,7 +655,8 @@ mod tests { let lhs_id = db.create_table_for_test("lhs", &[("id", I32), ("x", I32)], &[0.into()])?; db.with_auto_commit(Workload::ForTests, |tx| { for i in 0..5 { - insert(db, tx, lhs_id, &product!(i, i + 5))?; + let row = product!(i, i + 5); + insert(db, tx, lhs_id, &row)?; } Ok(lhs_id) }) @@ -681,21 +668,20 @@ mod tests { let rhs_id = db.create_table_for_test("rhs", &[("rid", I32), ("id", I32), ("y", I32)], &[1.into()])?; db.with_auto_commit(Workload::ForTests, |tx| { for i in 10..20 { - insert(db, tx, rhs_id, &product!(i, i - 10, i - 8))?; + let row = product!(i, i - 10, i - 8); + insert(db, tx, rhs_id, &row)?; } Ok(rhs_id) }) } - fn compile_query(db: &RelationalDB) -> ResultTest { + fn compile_query(db: &RelationalDB) -> ResultTest { db.with_read_only(Workload::ForTests, |tx| { + let auth = AuthCtx::for_testing(); + let tx = SchemaViewer::new(tx, &auth); // Should be answered using an index semijion let sql = "select lhs.* from lhs join rhs on lhs.id = rhs.id where rhs.y >= 2 and rhs.y <= 4"; - let mut exp = compile_sql(db, &AuthCtx::for_testing(), tx, sql)?; - let Some(CrudExpr::Query(query)) = exp.pop() else { - panic!("unexpected query {:#?}", exp[0]); - }; - singleton_execution_set(query, sql.into()) + Ok(DeltaPlan::compile(sql, &tx).unwrap()) }) } @@ -747,29 +733,48 @@ mod tests { fn eval_incr( db: &RelationalDB, - query: &ExecutionSet, - tables: Vec, + plan: &DeltaPlan, + ops: Vec<(TableId, ProductValue, bool)>, ) -> ResultTest { - let update = DatabaseUpdate { tables }; - db.with_read_only(Workload::ForTests, |tx| { - let tx = (&*tx).into(); - let update = update.tables.iter().collect::>(); - let result = query.eval_incr_for_test(db, &tx, &update, None); - let tables = result - .tables - .iter() - .map(|update| { - let convert = |rvs: &[_]| rvs.iter().cloned().map(RelValue::into_product_value).collect(); - DatabaseTableUpdate { - table_id: update.table_id, - table_name: update.table_name.clone(), - deletes: convert(&update.updates.deletes), - inserts: convert(&update.updates.inserts), - } - }) - .collect(); - Ok(DatabaseUpdate { tables }) - }) + let mut tx = db.begin_mut_tx(IsolationLevel::Serializable, Workload::ForTests); + + for (table_id, row, insert) in ops { + if insert { + insert_row(db, &mut tx, table_id, row)?; + } else { + delete_row(db, &mut tx, table_id, row); + } + } + + let (data, tx) = tx.commit_downgrade(Workload::ForTests); + let table_id = plan.table_id(); + let table_name = plan.table_name(); + let tx = DeltaTx::new(&tx, &data); + let evaluator = plan.evaluator(&tx); + let updates = eval_delta(&tx, &evaluator).unwrap(); + + let inserts = updates + .inserts 
+ .into_iter() + .map(RelValue::into_product_value) + .collect::>(); + let deletes = updates + .deletes + .into_iter() + .map(RelValue::into_product_value) + .collect::>(); + + let tables = if inserts.is_empty() && deletes.is_empty() { + vec![] + } else { + vec![DatabaseTableUpdate { + table_id, + table_name, + inserts, + deletes, + }] + }; + Ok(DatabaseUpdate { tables }) } // Case 1: @@ -783,22 +788,10 @@ mod tests { let r1 = product!(10, 0, 2); let r2 = product!(10, 0, 3); - db.with_auto_commit(Workload::ForTests, |tx| { - delete_row(db, tx, rhs_id, r1.clone()); - insert_row(db, tx, rhs_id, r2.clone()) - })?; - - let result = eval_incr( - db, - &query, - vec![ - delete_op(rhs_id, "rhs", r1.clone()), - insert_op(rhs_id, "rhs", r2.clone()), - ], - )?; + let result = eval_incr(db, &query, vec![(rhs_id, r1, false), (rhs_id, r2, true)])?; // No updates to report - assert_eq!(result.tables.len(), 0); + assert!(result.is_empty()); Ok(()) } @@ -813,22 +806,10 @@ mod tests { let r1 = product!(13, 3, 5); let r2 = product!(13, 3, 6); - db.with_auto_commit(Workload::ForTests, |tx| { - delete_row(db, tx, rhs_id, r1.clone()); - insert_row(db, tx, rhs_id, r2.clone()) - })?; - - let result = eval_incr( - db, - &query, - vec![ - delete_op(rhs_id, "rhs", r1.clone()), - insert_op(rhs_id, "rhs", r2.clone()), - ], - )?; + let result = eval_incr(db, &query, vec![(rhs_id, r1, false), (rhs_id, r2, true)])?; // No updates to report - assert_eq!(result.tables.len(), 0); + assert!(result.is_empty()); Ok(()) } @@ -843,19 +824,7 @@ mod tests { let r1 = product!(10, 0, 2); let r2 = product!(10, 0, 5); - db.with_auto_commit(Workload::ForTests, |tx| { - delete_row(db, tx, rhs_id, r1.clone()); - insert_row(db, tx, rhs_id, r2.clone()) - })?; - - let result = eval_incr( - db, - &query, - vec![ - delete_op(rhs_id, "rhs", r1.clone()), - insert_op(rhs_id, "rhs", r2.clone()), - ], - )?; + let result = eval_incr(db, &query, vec![(rhs_id, r1, false), (rhs_id, r2, true)])?; // A single delete from lhs assert_eq!(result.tables.len(), 1); @@ -874,19 +843,7 @@ mod tests { let r1 = product!(13, 3, 5); let r2 = product!(13, 3, 4); - db.with_auto_commit(Workload::ForTests, |tx| { - delete_row(db, tx, rhs_id, r1.clone()); - insert_row(db, tx, rhs_id, r2.clone()) - })?; - - let result = eval_incr( - db, - &query, - vec![ - delete_op(rhs_id, "rhs", r1.clone()), - insert_op(rhs_id, "rhs", r2.clone()), - ], - )?; + let result = eval_incr(db, &query, vec![(rhs_id, r1, false), (rhs_id, r2, true)])?; // A single insert into lhs assert_eq!(result.tables.len(), 1); @@ -905,19 +862,7 @@ mod tests { let lhs_row = product!(5, 10); let rhs_row = product!(20, 5, 3); - db.with_auto_commit(Workload::ForTests, |tx| { - insert_row(db, tx, lhs_id, lhs_row.clone())?; - insert_row(db, tx, rhs_id, rhs_row.clone()) - })?; - - let result = eval_incr( - db, - &query, - vec![ - insert_op(lhs_id, "lhs", lhs_row.clone()), - insert_op(rhs_id, "rhs", rhs_row.clone()), - ], - )?; + let result = eval_incr(db, &query, vec![(lhs_id, lhs_row, true), (rhs_id, rhs_row, true)])?; // A single insert into lhs assert_eq!(result.tables.len(), 1); @@ -936,19 +881,7 @@ mod tests { let lhs_row = product!(5, 10); let rhs_row = product!(20, 5, 5); - db.with_auto_commit(Workload::ForTests, |tx| { - insert_row(db, tx, lhs_id, lhs_row.clone())?; - insert_row(db, tx, rhs_id, rhs_row.clone()) - })?; - - let result = eval_incr( - db, - &query, - vec![ - insert_op(lhs_id, "lhs", lhs_row.clone()), - insert_op(rhs_id, "rhs", rhs_row.clone()), - ], - )?; + let result = eval_incr(db, &query, 
vec![(lhs_id, lhs_row, true), (rhs_id, rhs_row, true)])?; // No updates to report assert_eq!(result.tables.len(), 0); @@ -966,20 +899,7 @@ mod tests { let lhs_row = product!(0, 5); let rhs_row = product!(10, 0, 2); - db.with_auto_commit(Workload::ForTests, |tx| -> ResultTest<_> { - delete_row(db, tx, lhs_id, lhs_row.clone()); - delete_row(db, tx, rhs_id, rhs_row.clone()); - Ok(()) - })?; - - let result = eval_incr( - db, - &query, - vec![ - delete_op(lhs_id, "lhs", lhs_row.clone()), - delete_op(rhs_id, "rhs", rhs_row.clone()), - ], - )?; + let result = eval_incr(db, &query, vec![(lhs_id, lhs_row, false), (rhs_id, rhs_row, false)])?; // A single delete from lhs assert_eq!(result.tables.len(), 1); @@ -998,20 +918,7 @@ mod tests { let lhs_row = product!(3, 8); let rhs_row = product!(13, 3, 5); - db.with_auto_commit(Workload::ForTests, |tx| -> ResultTest<_> { - delete_row(db, tx, lhs_id, lhs_row.clone()); - delete_row(db, tx, rhs_id, rhs_row.clone()); - Ok(()) - })?; - - let result = eval_incr( - db, - &query, - vec![ - delete_op(lhs_id, "lhs", lhs_row.clone()), - delete_op(rhs_id, "rhs", rhs_row.clone()), - ], - )?; + let result = eval_incr(db, &query, vec![(lhs_id, lhs_row, false), (rhs_id, rhs_row, false)])?; // No updates to report assert_eq!(result.tables.len(), 0); @@ -1031,32 +938,20 @@ mod tests { let rhs_old = product!(11, 1, 3); let rhs_new = product!(11, 1, 4); - db.with_auto_commit(Workload::ForTests, |tx| { - delete_row(db, tx, lhs_id, lhs_old.clone()); - delete_row(db, tx, rhs_id, rhs_old.clone()); - insert_row(db, tx, lhs_id, lhs_new.clone())?; - insert_row(db, tx, rhs_id, rhs_new.clone()) - })?; - let result = eval_incr( db, &query, vec![ - DatabaseTableUpdate { - table_id: lhs_id, - table_name: "lhs".into(), - deletes: [lhs_old.clone()].into(), - inserts: [lhs_new.clone()].into(), - }, - DatabaseTableUpdate { - table_id: rhs_id, - table_name: "rhs".into(), - deletes: [rhs_old.clone()].into(), - inserts: [rhs_new.clone()].into(), - }, + (lhs_id, lhs_old, false), + (rhs_id, rhs_old, false), + (lhs_id, lhs_new, true), + (rhs_id, rhs_new, true), ], )?; + let lhs_old = product!(1, 6); + let lhs_new = product!(1, 7); + // A delete and an insert into lhs assert_eq!(result.tables.len(), 1); assert_eq!( diff --git a/crates/core/src/subscription/subscription.rs b/crates/core/src/subscription/subscription.rs index 83d1f208992..188fdb9f935 100644 --- a/crates/core/src/subscription/subscription.rs +++ b/crates/core/src/subscription/subscription.rs @@ -20,13 +20,15 @@ //! materialized views are necessary. We find, however, that a particular kind //! of join query _can_ be evaluated incrementally without materialized views. 
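The rewritten `eval_incr` test helper above no longer fabricates `DatabaseTableUpdate`s by hand: it applies the ops in a mutable transaction, calls `commit_downgrade` to obtain a read transaction plus the committed `TxData`, wraps the pair in a `DeltaTx`, and runs `eval_delta` over the plan's evaluator. Here is a reduced sketch of the `DeltaTx` shape, with simplified stand-in types for `TxId` and `TxData` (the real wrapper is the new `crates/core/src/subscription/tx.rs` below):

```rust
// Stand-ins for the datastore's read tx and commit data; rows are u64s here.
struct Tx;
struct TxData {
    inserts: Vec<(u32, Vec<u64>)>, // (table_id, inserted rows)
    deletes: Vec<(u32, Vec<u64>)>, // (table_id, deleted rows)
}

// A read tx paired with optional commit data: with data it serves delta
// scans, without it (the From<&Tx> case) every delta scan is empty.
struct DeltaTx<'a> {
    tx: &'a Tx,
    data: Option<&'a TxData>,
}

impl<'a> From<&'a Tx> for DeltaTx<'a> {
    fn from(tx: &'a Tx) -> Self {
        Self { tx, data: None }
    }
}

impl<'a> DeltaTx<'a> {
    fn inserts_for_table(&self, table_id: u32) -> impl Iterator<Item = &'a u64> + 'a {
        self.data
            .into_iter()
            .flat_map(|data| &data.inserts)
            .filter(move |(id, _)| *id == table_id)
            .flat_map(|(_, rows)| rows)
    }
}

fn main() {
    let tx = Tx;
    let data = TxData {
        inserts: vec![(7, vec![1, 2])],
        deletes: vec![],
    };
    let delta = DeltaTx { tx: &tx, data: Some(&data) };
    assert_eq!(delta.inserts_for_table(7).count(), 2);
    assert_eq!(DeltaTx::from(&tx).inserts_for_table(7).count(), 0);
}
```

The `From<&Tx>` impl is the no-delta case used by plain reads; every delta scan simply comes back empty.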
-use super::execution_unit::ExecutionUnit; +use super::execution_unit::{ExecutionUnit, QueryHash}; +use super::module_subscription_manager::Plan; use super::query; use crate::db::datastore::locking_tx_datastore::tx::TxId; use crate::db::relational_db::{RelationalDB, Tx}; use crate::error::{DBError, SubscriptionError}; use crate::host::module_host::{DatabaseTableUpdate, DatabaseUpdateRelValue, UpdatesRelValue}; use crate::messages::websocket as ws; +use crate::sql::ast::SchemaViewer; use crate::vm::{build_query, TxMode}; use anyhow::Context; use itertools::Either; @@ -39,6 +41,7 @@ use spacetimedb_lib::identity::AuthCtx; use spacetimedb_lib::relation::DbTable; use spacetimedb_lib::{Identity, ProductValue}; use spacetimedb_primitives::TableId; +use spacetimedb_query::delta::DeltaPlan; use spacetimedb_vm::expr::{self, AuthAccess, IndexJoin, Query, QueryExpr, SourceExpr, SourceProvider, SourceSet}; use spacetimedb_vm::rel_ops::RelOps; use spacetimedb_vm::relation::{MemTable, RelValue}; @@ -554,6 +557,11 @@ impl ExecutionSet { .map(|unit| unit.row_estimate(tx)) .fold(0, |acc, est| acc.saturating_add(est)) } + + /// Return an iterator over the execution units + pub fn iter(&self) -> impl Iterator { + self.exec_units.iter().map(|arc| &**arc) + } } impl FromIterator for ExecutionSet { @@ -602,7 +610,31 @@ impl AuthAccess for ExecutionSet { /// Queries all the [`StTableType::User`] tables *right now* /// and turns them into [`QueryExpr`], /// the moral equivalent of `SELECT * FROM table`. -pub(crate) fn get_all(relational_db: &RelationalDB, tx: &Tx, auth: &AuthCtx) -> Result, DBError> { +pub(crate) fn get_all(relational_db: &RelationalDB, tx: &Tx, auth: &AuthCtx) -> Result, DBError> { + Ok(relational_db + .get_all_tables(tx)? + .iter() + .map(Deref::deref) + .filter(|t| { + t.table_type == StTableType::User && (auth.owner == auth.caller || t.table_access == StAccess::Public) + }) + .map(|schema| { + let sql = format!("SELECT * FROM {}", schema.table_name); + let hash = QueryHash::from_string(&sql); + DeltaPlan::compile(&sql, &SchemaViewer::new(tx, auth)).map(|plan| Plan::new(plan, hash)) + }) + .collect::>()?) +} + +/// Queries all the [`StTableType::User`] tables *right now* +/// and turns them into [`QueryExpr`], +/// the moral equivalent of `SELECT * FROM table`. +#[cfg(test)] +pub(crate) fn legacy_get_all( + relational_db: &RelationalDB, + tx: &Tx, + auth: &AuthCtx, +) -> Result, DBError> { Ok(relational_db .get_all_tables(tx)? 
.iter() diff --git a/crates/core/src/subscription/tx.rs b/crates/core/src/subscription/tx.rs new file mode 100644 index 00000000000..5355f3118a4 --- /dev/null +++ b/crates/core/src/subscription/tx.rs @@ -0,0 +1,78 @@ +use std::ops::Deref; + +use spacetimedb_execution::{Datastore, DeltaStore}; +use spacetimedb_lib::{query::Delta, ProductValue}; +use spacetimedb_primitives::TableId; +use spacetimedb_table::{blob_store::BlobStore, table::Table}; + +use crate::db::datastore::{locking_tx_datastore::tx::TxId, traits::TxData}; + +/// A wrapper around a read-only tx for delta queries +pub struct DeltaTx<'a> { + tx: &'a TxId, + data: Option<&'a TxData>, +} + +impl<'a> DeltaTx<'a> { + pub fn new(tx: &'a TxId, data: &'a TxData) -> Self { + Self { tx, data: Some(data) } + } +} + +impl<'a> Deref for DeltaTx<'a> { + type Target = TxId; + + fn deref(&self) -> &Self::Target { + self.tx + } +} + +impl<'a> From<&'a TxId> for DeltaTx<'a> { + fn from(tx: &'a TxId) -> Self { + Self { tx, data: None } + } +} + +impl Datastore for DeltaTx<'_> { + fn table(&self, table_id: TableId) -> Option<&Table> { + self.tx.table(table_id) + } + + fn blob_store(&self) -> &dyn BlobStore { + self.tx.blob_store() + } +} + +impl DeltaStore for DeltaTx<'_> { + fn has_inserts(&self, table_id: TableId) -> Option<Delta> { + self.data.and_then(|data| { + data.inserts() + .find(|(id, rows)| **id == table_id && !rows.is_empty()) + .map(|(_, rows)| Delta::Inserts(rows.len())) + }) + } + + fn has_deletes(&self, table_id: TableId) -> Option<Delta> { + self.data.and_then(|data| { + data.deletes() + .find(|(id, rows)| **id == table_id && !rows.is_empty()) + .map(|(_, rows)| Delta::Deletes(rows.len())) + }) + } + + fn inserts_for_table(&self, table_id: TableId) -> Option<std::slice::Iter<'_, ProductValue>> { + self.data.and_then(|data| { + data.inserts() + .find(|(id, rows)| **id == table_id && !rows.is_empty()) + .map(|(_, rows)| rows.iter()) + }) + } + + fn deletes_for_table(&self, table_id: TableId) -> Option<std::slice::Iter<'_, ProductValue>> { + self.data.and_then(|data| { + data.deletes() + .find(|(id, rows)| **id == table_id && !rows.is_empty()) + .map(|(_, rows)| rows.iter()) + }) + } +} diff --git a/crates/execution/Cargo.toml b/crates/execution/Cargo.toml index 09ec6d5f1dc..530ad8e4132 100644 --- a/crates/execution/Cargo.toml +++ b/crates/execution/Cargo.toml @@ -7,7 +7,10 @@ license-file = "LICENSE" description = "The SpacetimeDB query engine" [dependencies] +anyhow.workspace = true spacetimedb-expr.workspace = true spacetimedb-lib.workspace = true +spacetimedb-physical-plan.workspace = true spacetimedb-primitives.workspace = true +spacetimedb-sql-parser.workspace = true spacetimedb-table.workspace = true diff --git a/crates/execution/src/iter.rs b/crates/execution/src/iter.rs index 8db3e2c9c7c..5718b67318e 100644 --- a/crates/execution/src/iter.rs +++ b/crates/execution/src/iter.rs @@ -1,747 +1,1072 @@ -use std::ops::{Bound, RangeBounds}; +use std::collections::{HashMap, HashSet}; -use spacetimedb_lib::{AlgebraicValue, ProductValue}; -use spacetimedb_primitives::{IndexId, TableId}; +use anyhow::{anyhow, bail, Result}; +use spacetimedb_lib::{query::Delta, AlgebraicValue, ProductValue}; +use spacetimedb_physical_plan::plan::{ + HashJoin, IxJoin, IxScan, PhysicalExpr, PhysicalPlan, ProjectField, ProjectPlan, Sarg, Semi, TupleField, +}; use spacetimedb_table::{ blob_store::BlobStore, btree_index::{BTreeIndex, BTreeIndexRangeIter}, - static_assert_size, - table::{IndexScanIter, RowRef, Table, TableScanIter}, + table::{IndexScanIter, Table, TableScanIter}, }; -/// A row from a base table in the form of a pointer or
product value -#[derive(Clone)] -pub enum Row<'a> { - Ptr(RowRef<'a>), - Ref(&'a ProductValue), +use crate::{Datastore, DeltaScanIter, DeltaStore, Row, Tuple}; + +/// The different iterators for evaluating query plans +pub enum PlanIter<'a> { + Table(TableScanIter<'a>), + Index(IndexScanIter<'a>), + Delta(DeltaScanIter<'a>), + RowId(RowRefIter<'a>), + Tuple(ProjectIter<'a>), } -impl Row<'_> { - /// Expect a pointer value, panic otherwise - pub fn expect_ptr(&self) -> &RowRef { - match self { - Self::Ptr(ptr) => ptr, - _ => unreachable!(), - } +impl<'a> PlanIter<'a> { + pub(crate) fn build(plan: &'a ProjectPlan, tx: &'a Tx) -> Result + where + Tx: Datastore + DeltaStore, + { + ProjectIter::build(plan, tx).map(|iter| match iter { + ProjectIter::None(Iter::Row(RowRefIter::TableScan(iter))) => Self::Table(iter), + ProjectIter::None(Iter::Row(RowRefIter::IndexScan(iter))) => Self::Index(iter), + ProjectIter::None(Iter::Row(iter)) => Self::RowId(iter), + _ => Self::Tuple(iter), + }) } +} + +/// Implements a tuple projection for a query plan +pub enum ProjectIter<'a> { + None(Iter<'a>), + Some(Iter<'a>, usize), +} - /// Expect a product value, panic otherwise - pub fn expect_ref(&self) -> &ProductValue { +impl<'a> Iterator for ProjectIter<'a> { + type Item = Row<'a>; + + fn next(&mut self) -> Option { match self { - Self::Ref(r) => r, - _ => unreachable!(), + Self::None(iter) => iter.find_map(|tuple| { + if let Tuple::Row(ptr) = tuple { + return Some(ptr); + } + None + }), + Self::Some(iter, i) => iter.find_map(|tuple| tuple.select(*i)), } } } -static_assert_size!(Row, 32); +impl<'a> ProjectIter<'a> { + pub fn build(plan: &'a ProjectPlan, tx: &'a Tx) -> Result + where + Tx: Datastore + DeltaStore, + { + match plan { + ProjectPlan::None(plan) | ProjectPlan::Name(plan, _, None) => Iter::build(plan, tx).map(Self::None), + ProjectPlan::Name(plan, _, Some(i)) => Iter::build(plan, tx).map(|iter| Self::Some(iter, *i)), + } + } +} -/// A tuple returned by a query iterator -#[derive(Clone)] -pub enum Tuple<'a> { - /// A row from a base table - Row(Row<'a>), - /// A temporary constructed by a query operator - Join(Vec>), +/// A generic tuple-at-a-time iterator for a query plan +pub enum Iter<'a> { + Row(RowRefIter<'a>), + Join(LeftDeepJoinIter<'a>), + Filter(Filter<'a, Iter<'a>>), } -static_assert_size!(Tuple, 40); +impl<'a> Iterator for Iter<'a> { + type Item = Tuple<'a>; -impl Tuple<'_> { - /// Expect a row from a base table, panic otherwise - pub fn expect_row(&self) -> &Row { + fn next(&mut self) -> Option { match self { - Self::Row(row) => row, - _ => unreachable!(), + Self::Row(iter) => iter.next().map(Tuple::Row), + Self::Join(iter) => iter.next(), + Self::Filter(iter) => iter.next(), } } +} - /// Expect a temporary tuple, panic otherwise - pub fn expect_join(&self) -> &[Row] { - match self { - Self::Join(elems) => elems.as_slice(), - _ => unreachable!(), +impl<'a> Iter<'a> { + fn build(plan: &'a PhysicalPlan, tx: &'a Tx) -> Result + where + Tx: Datastore + DeltaStore, + { + match plan { + PhysicalPlan::TableScan(..) | PhysicalPlan::IxScan(..) => RowRefIter::build(plan, tx).map(Self::Row), + PhysicalPlan::Filter(input, expr) => { + // Build a filter iterator + Iter::build(input, tx) + .map(Box::new) + .map(|input| Filter { input, expr }) + .map(Iter::Filter) + } + PhysicalPlan::NLJoin(lhs, rhs) => { + // Build a nested loop join iterator + NLJoin::build_from(lhs, rhs, tx) + .map(LeftDeepJoinIter::NLJoin) + .map(Iter::Join) + } + PhysicalPlan::IxJoin(join @ IxJoin { unique: false, .. 
}, Semi::Lhs) => { + // Build a left index semijoin iterator + IxJoinLhs::build_from(join, tx) + .map(SemiJoin::Lhs) + .map(LeftDeepJoinIter::IxJoin) + .map(Iter::Join) + } + PhysicalPlan::IxJoin(join @ IxJoin { unique: false, .. }, Semi::Rhs) => { + // Build a right index semijoin iterator + IxJoinRhs::build_from(join, tx) + .map(SemiJoin::Rhs) + .map(LeftDeepJoinIter::IxJoin) + .map(Iter::Join) + } + PhysicalPlan::IxJoin(join @ IxJoin { unique: false, .. }, Semi::All) => { + // Build an index join iterator + IxJoinIter::build_from(join, tx) + .map(SemiJoin::All) + .map(LeftDeepJoinIter::IxJoin) + .map(Iter::Join) + } + PhysicalPlan::IxJoin(join @ IxJoin { unique: true, .. }, Semi::Lhs) => { + // Build a unique left index semijoin iterator + UniqueIxJoinLhs::build_from(join, tx) + .map(SemiJoin::Lhs) + .map(LeftDeepJoinIter::UniqueIxJoin) + .map(Iter::Join) + } + PhysicalPlan::IxJoin(join @ IxJoin { unique: true, .. }, Semi::Rhs) => { + // Build a unique right index semijoin iterator + UniqueIxJoinRhs::build_from(join, tx) + .map(SemiJoin::Rhs) + .map(LeftDeepJoinIter::UniqueIxJoin) + .map(Iter::Join) + } + PhysicalPlan::IxJoin(join @ IxJoin { unique: true, .. }, Semi::All) => { + // Build a unique index join iterator + UniqueIxJoin::build_from(join, tx) + .map(SemiJoin::All) + .map(LeftDeepJoinIter::UniqueIxJoin) + .map(Iter::Join) + } + PhysicalPlan::HashJoin(join @ HashJoin { unique: false, .. }, Semi::Lhs) => { + // Build a left hash semijoin iterator + HashJoinLhs::build_from(join, tx) + .map(SemiJoin::Lhs) + .map(LeftDeepJoinIter::HashJoin) + .map(Iter::Join) + } + PhysicalPlan::HashJoin(join @ HashJoin { unique: false, .. }, Semi::Rhs) => { + // Build a right hash semijoin iterator + HashJoinRhs::build_from(join, tx) + .map(SemiJoin::Rhs) + .map(LeftDeepJoinIter::HashJoin) + .map(Iter::Join) + } + PhysicalPlan::HashJoin(join @ HashJoin { unique: false, .. }, Semi::All) => { + // Build a hash join iterator + HashJoinIter::build_from(join, tx) + .map(SemiJoin::All) + .map(LeftDeepJoinIter::HashJoin) + .map(Iter::Join) + } + PhysicalPlan::HashJoin(join @ HashJoin { unique: true, .. }, Semi::Lhs) => { + // Build a unique left hash semijoin iterator + UniqueHashJoinLhs::build_from(join, tx) + .map(SemiJoin::Lhs) + .map(LeftDeepJoinIter::UniqueHashJoin) + .map(Iter::Join) + } + PhysicalPlan::HashJoin(join @ HashJoin { unique: true, .. }, Semi::Rhs) => { + // Build a unique right hash semijoin iterator + UniqueHashJoinRhs::build_from(join, tx) + .map(SemiJoin::Rhs) + .map(LeftDeepJoinIter::UniqueHashJoin) + .map(Iter::Join) + } + PhysicalPlan::HashJoin(join @ HashJoin { unique: true, .. }, Semi::All) => { + // Build a unique hash join iterator + UniqueHashJoin::build_from(join, tx) + .map(SemiJoin::All) + .map(LeftDeepJoinIter::UniqueHashJoin) + .map(Iter::Join) + } } } } -/// An execution plan for a tuple-at-a-time iterator. -/// As the name suggests it is meant to be cached. -/// Building the iterator should incur minimal overhead. 
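Where the old engine lowered a query to a flat opcode program (`CachedIterPlan`, deleted below) and interpreted it against a stack, `Iter::build` above pattern-matches the `PhysicalPlan` tree once and allocates a specialized iterator per node, so the per-tuple path is a plain enum dispatch. A toy version of that build-then-iterate structure, with invented `Plan` and `Iter` shapes rather than the crate's:

```rust
// A toy physical plan: scan a vector, or filter another plan.
enum Plan {
    Scan(Vec<i64>),
    Filter(Box<Plan>, fn(&i64) -> bool),
}

// One iterator variant per plan node, mirroring the shape of Iter::build.
enum Iter<'a> {
    Scan(std::slice::Iter<'a, i64>),
    Filter(Box<Iter<'a>>, fn(&i64) -> bool),
}

impl<'a> Iter<'a> {
    // Walk the plan tree once, allocating the right iterator per node.
    fn build(plan: &'a Plan) -> Iter<'a> {
        match plan {
            Plan::Scan(rows) => Iter::Scan(rows.iter()),
            Plan::Filter(input, pred) => Iter::Filter(Box::new(Iter::build(input)), *pred),
        }
    }
}

impl<'a> Iterator for Iter<'a> {
    type Item = &'a i64;

    // The per-tuple path is a plain enum dispatch, no opcode interpretation.
    fn next(&mut self) -> Option<&'a i64> {
        match self {
            Iter::Scan(iter) => iter.next(),
            Iter::Filter(input, pred) => input.find(|row| pred(row)),
        }
    }
}

fn main() {
    let plan = Plan::Filter(Box::new(Plan::Scan(vec![1, 2, 3, 4])), |n| n % 2 == 0);
    let evens: Vec<&i64> = Iter::build(&plan).collect();
    assert_eq!(evens, [&2, &4]);
}
```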
-pub struct CachedIterPlan { - /// The relational ops - iter_ops: Box<[IterOp]>, - /// The expression ops - expr_ops: Box<[OpCode]>, - /// The constants referenced by the plan - constants: Box<[AlgebraicValue]>, -} - -static_assert_size!(CachedIterPlan, 48); - -impl CachedIterPlan { - /// Returns an interator over the query ops - fn ops(&self) -> impl Iterator + '_ { - self.iter_ops.iter().copied() - } - - /// Lookup a constant in the plan - fn constant(&self, i: u16) -> &AlgebraicValue { - &self.constants[i as usize] - } -} - -/// An opcode for a tuple-at-a-time execution plan -#[derive(Clone, Copy)] -pub enum IterOp { - /// A table scan opcode takes 1 arg: A [TableId] - TableScan(TableId), - /// A delta scan opcode takes 1 arg: A [TableId] - DeltaScan(TableId), - /// An index scan opcode takes 2 args: - /// 1. An [IndexId] - /// 2. A ptr to an [AlgebraicValue] - IxScanEq(IndexId, u16), - /// An index range scan opcode takes 3 args: - /// 1. An [IndexId] - /// 2. A ptr to the lower bound - /// 3. A ptr to the upper bound - IxScanRange(IndexId, Bound, Bound), - /// Pops its 2 args from the stack - NLJoin, - /// An index join opcode takes 2 args: - /// 1. An [IndexId] - /// 2. An instruction ptr - /// 3. A length - IxJoin(IndexId, usize, u16), - /// An index join opcode takes 2 args: - /// 1. An [IndexId] - /// 2. An instruction ptr - /// 3. A length - UniqueIxJoin(IndexId, usize, u16), - /// A filter opcode takes 2 args: - /// 1. An instruction ptr - /// 2. A length - Filter(usize, u32), -} - -static_assert_size!(IterOp, 16); - -pub trait Datastore { - fn delta_scan_iter(&self, table_id: TableId) -> DeltaScanIter; - fn table_scan_iter(&self, table_id: TableId) -> TableScanIter; - fn index_scan_iter(&self, index_id: IndexId, range: &impl RangeBounds) -> IndexScanIter; - fn get_table_for_index(&self, index_id: &IndexId) -> &Table; - fn get_index(&self, index_id: &IndexId) -> &BTreeIndex; - fn get_blob_store(&self) -> &dyn BlobStore; -} - -/// An iterator for a delta table -pub struct DeltaScanIter<'a> { - iter: std::slice::Iter<'a, ProductValue>, -} - -impl<'a> Iterator for DeltaScanIter<'a> { - type Item = &'a ProductValue; +/// An iterator that always returns [RowRef]s +pub enum RowRefIter<'a> { + TableScan(TableScanIter<'a>), + IndexScan(IndexScanIter<'a>), + DeltaScan(DeltaScanIter<'a>), + RowFilter(Filter<'a, RowRefIter<'a>>), +} + +impl<'a> Iterator for RowRefIter<'a> { + type Item = Row<'a>; fn next(&mut self) -> Option { - self.iter.next() + match self { + Self::TableScan(iter) => iter.next().map(Row::Ptr), + Self::IndexScan(iter) => iter.next().map(Row::Ptr), + Self::DeltaScan(iter) => iter.next().map(Row::Ref), + Self::RowFilter(iter) => iter.next(), + } } } -impl CachedIterPlan { - pub fn iter<'a>(&'a self, tx: &'a impl Datastore) -> Iter<'a> { - let mut stack = vec![]; - for op in self.ops() { - match op { - IterOp::TableScan(table_id) => { - // Push table scan - stack.push(Iter::TableScan(tx.table_scan_iter(table_id))); - } - IterOp::DeltaScan(table_id) => { - // Push delta scan - stack.push(Iter::DeltaScan(tx.delta_scan_iter(table_id))); - } - IterOp::IxScanEq(index_id, ptr) => { - // Push index scan - stack.push(Iter::IndexScan(tx.index_scan_iter(index_id, &self.constant(ptr)))); - } - IterOp::IxScanRange(index_id, lower, upper) => { - // Push range scan - let lower = lower.map(|ptr| self.constant(ptr)); - let upper = upper.map(|ptr| self.constant(ptr)); - stack.push(Iter::IndexScan(tx.index_scan_iter(index_id, &(lower, upper)))); - } - IterOp::NLJoin => { - // Pop args and 
push nested loop join - let rhs = stack.pop().unwrap(); - let lhs = stack.pop().unwrap(); - stack.push(Iter::NLJoin(NestedLoopJoin::new(lhs, rhs))); - } - IterOp::IxJoin(index_id, i, n) => { - // Pop arg and push index join - let input = stack.pop().unwrap(); - let index = tx.get_index(&index_id); - let table = tx.get_table_for_index(&index_id); - let blob_store = tx.get_blob_store(); - let ops = &self.expr_ops[i..i + n as usize]; - let program = ExprProgram::new(ops, &self.constants); - let projection = ProgramEvaluator::from(program); - stack.push(Iter::IxJoin(LeftDeepJoin::Eq(IndexJoin::new( - input, index, table, blob_store, projection, - )))); - } - IterOp::UniqueIxJoin(index_id, i, n) => { - // Pop arg and push index join - let input = stack.pop().unwrap(); - let index = tx.get_index(&index_id); - let table = tx.get_table_for_index(&index_id); - let blob_store = tx.get_blob_store(); - let ops = &self.expr_ops[i..i + n as usize]; - let program = ExprProgram::new(ops, &self.constants); - let projection = ProgramEvaluator::from(program); - stack.push(Iter::UniqueIxJoin(LeftDeepJoin::Eq(UniqueIndexJoin::new( - input, index, table, blob_store, projection, - )))); - } - IterOp::Filter(i, n) => { - // Pop arg and push filter - let input = Box::new(stack.pop().unwrap()); - let ops = &self.expr_ops[i..i + n as usize]; - let program = ExprProgram::new(ops, &self.constants); - let program = ProgramEvaluator::from(program); - stack.push(Iter::Filter(Filter { input, program })); - } +impl<'a> RowRefIter<'a> { + /// Instantiate an iterator from a [PhysicalPlan]. + /// The compiler ensures this isn't called on a join. + fn build(plan: &'a PhysicalPlan, tx: &'a Tx) -> Result + where + Tx: Datastore + DeltaStore, + { + let concat = |prefix: &[(_, AlgebraicValue)], v| { + ProductValue::from_iter(prefix.iter().map(|(_, v)| v).chain([v]).cloned()) + }; + match plan { + PhysicalPlan::TableScan(schema, _, None) => tx.table_scan(schema.table_id).map(Self::TableScan), + PhysicalPlan::TableScan(schema, _, Some(Delta::Inserts(..))) => { + tx.delta_scan(schema.table_id, true).map(Self::DeltaScan) + } + PhysicalPlan::TableScan(schema, _, Some(Delta::Deletes(..))) => { + tx.delta_scan(schema.table_id, false).map(Self::DeltaScan) } + PhysicalPlan::IxScan( + scan @ IxScan { + arg: Sarg::Eq(_, v), .. + }, + _, + ) if scan.prefix.is_empty() => tx + .index_scan(scan.schema.table_id, scan.index_id, v) + .map(Self::IndexScan), + PhysicalPlan::IxScan( + scan @ IxScan { + arg: Sarg::Eq(_, v), .. + }, + _, + ) => tx + .index_scan( + scan.schema.table_id, + scan.index_id, + &AlgebraicValue::product(concat(&scan.prefix, v)), + ) + .map(Self::IndexScan), + PhysicalPlan::IxScan( + scan @ IxScan { + arg: Sarg::Range(_, lower, upper), + .. + }, + _, + ) if scan.prefix.is_empty() => tx + .index_scan(scan.schema.table_id, scan.index_id, &(lower.as_ref(), upper.as_ref())) + .map(Self::IndexScan), + PhysicalPlan::IxScan( + scan @ IxScan { + arg: Sarg::Range(_, lower, upper), + .. + }, + _, + ) => tx + .index_scan( + scan.schema.table_id, + scan.index_id, + &( + lower + .as_ref() + .map(|v| concat(&scan.prefix, v)) + .map(AlgebraicValue::Product), + upper + .as_ref() + .map(|v| concat(&scan.prefix, v)) + .map(AlgebraicValue::Product), + ), + ) + .map(Self::IndexScan), + PhysicalPlan::Filter(input, expr) => Self::build(input, tx) + .map(Box::new) + .map(|input| Filter { input, expr }) + .map(Self::RowFilter), + _ => bail!("Plan does not return row ids"), } - stack.pop().unwrap() } } -/// A tuple-at-a-time query iterator. 
-/// Notice there is no explicit projection operation. -/// This is because for applicable plans, -/// the optimizer can remove intermediate projections, -/// implementing a form of late materialization. -pub enum Iter<'a> { - /// A [RowRef] table iterator - TableScan(TableScanIter<'a>), - /// A [ProductValue] ref iterator - DeltaScan(DeltaScanIter<'a>), - /// A [RowRef] index iterator - IndexScan(IndexScanIter<'a>), - /// A nested loop join iterator - NLJoin(NestedLoopJoin<'a>), - /// A non-unique (constraint) index join iterator - IxJoin(LeftDeepJoin>), - /// A unique (constraint) index join iterator - UniqueIxJoin(LeftDeepJoin>), - /// A tuple-at-a-time filter iterator - Filter(Filter<'a>), +/// An iterator for a left deep join tree. +/// +/// ```text +/// x +/// / \ +/// x c +/// / \ +/// a b +/// ``` +pub enum LeftDeepJoinIter<'a> { + /// A nested loop join + NLJoin(NLJoin<'a>), + /// An index join + IxJoin(SemiJoin, IxJoinLhs<'a>, IxJoinRhs<'a>>), + /// An index join for a unique constraint + UniqueIxJoin(SemiJoin, UniqueIxJoinLhs<'a>, UniqueIxJoinRhs<'a>>), + /// A hash join + HashJoin(SemiJoin, HashJoinLhs<'a>, HashJoinRhs<'a>>), + /// A hash join for a unique constraint + UniqueHashJoin(SemiJoin, UniqueHashJoinLhs<'a>, UniqueHashJoinRhs<'a>>), } -impl<'a> Iterator for Iter<'a> { +impl<'a> Iterator for LeftDeepJoinIter<'a> { type Item = Tuple<'a>; fn next(&mut self) -> Option { match self { - Self::TableScan(iter) => { - // Returns row ids - iter.next().map(Row::Ptr).map(Tuple::Row) - } - Self::DeltaScan(iter) => { - // Returns product refs - iter.next().map(Row::Ref).map(Tuple::Row) - } - Self::IndexScan(iter) => { - // Returns row ids - iter.next().map(Row::Ptr).map(Tuple::Row) - } - Self::IxJoin(iter) => { - // Returns row ids for semijoins, (n+1)-tuples otherwise - iter.next() - } - Self::UniqueIxJoin(iter) => { - // Returns row ids for semijoins, (n+1)-tuples otherwise - iter.next() - } - Self::Filter(iter) => { - // Filter is a passthru - iter.next() - } - Self::NLJoin(iter) => { - iter.next().map(|t| { - match t { - // A leaf join - // x - // / \ - // a b - (Tuple::Row(u), Tuple::Row(v)) => { - // Returns a 2-tuple - Tuple::Join(vec![u, v]) - } - // A right deep join - // x - // / \ - // a x - // / \ - // b c - (Tuple::Row(r), Tuple::Join(mut rows)) => { - // Returns an (n+1)-tuple - let mut pointers = vec![r]; - pointers.append(&mut rows); - Tuple::Join(pointers) - } - // A left deep join - // x - // / \ - // x c - // / \ - // a b - (Tuple::Join(mut rows), Tuple::Row(r)) => { - // Returns an (n+1)-tuple - rows.push(r); - Tuple::Join(rows) - } - // A bushy join - // x - // / \ - // / \ - // x x - // / \ / \ - // a b c d - (Tuple::Join(mut lhs), Tuple::Join(mut rhs)) => { - // Returns an (n+m)-tuple - lhs.append(&mut rhs); - Tuple::Join(lhs) - } - } - }) - } + Self::NLJoin(iter) => iter.next().map(|(tuple, rhs)| tuple.append(rhs)), + Self::IxJoin(iter) => iter.next(), + Self::UniqueIxJoin(iter) => iter.next(), + Self::HashJoin(iter) => iter.next(), + Self::UniqueHashJoin(iter) => iter.next(), } } } -/// An iterator for a left deep join tree -pub enum LeftDeepJoin { - /// A standard join - Eq(Iter), - /// A semijoin that returns the lhs - SemiLhs(Iter), - /// A semijion that returns the rhs - SemiRhs(Iter), +/// A semijoin iterator. +/// Returns [RowRef]s if this is a right semijoin. +/// Returns [Tuple]s otherwise. 
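The `SemiJoin` enum declared next unifies the three join outputs behind a single `Iterator` impl: the `All` variant appends the matched rhs row onto the lhs tuple, `Lhs` passes the tuple through untouched, and `Rhs` promotes the bare row. A reduced sketch of the same trick over toy tuple and row types:

```rust
// Toy stand-ins: a "row" is an i64, a "tuple" is a Vec of rows.
type Row = i64;
type Tuple = Vec<Row>;

enum SemiJoin<All, Lhs, Rhs> {
    All(All),
    Lhs(Lhs),
    Rhs(Rhs),
}

impl<All, Lhs, Rhs> Iterator for SemiJoin<All, Lhs, Rhs>
where
    All: Iterator<Item = (Tuple, Row)>,
    Lhs: Iterator<Item = Tuple>,
    Rhs: Iterator<Item = Row>,
{
    type Item = Tuple;

    fn next(&mut self) -> Option<Tuple> {
        match self {
            // Full join: extend the left tuple with the matched right row.
            Self::All(iter) => iter.next().map(|(mut tuple, row)| {
                tuple.push(row);
                tuple
            }),
            // Left semijoin: the tuple is already the result.
            Self::Lhs(iter) => iter.next(),
            // Right semijoin: promote the bare row to a 1-tuple.
            Self::Rhs(iter) => iter.next().map(|row| vec![row]),
        }
    }
}

fn main() {
    type Rows = std::vec::IntoIter<Row>;
    type Tuples = std::vec::IntoIter<Tuple>;
    let all: SemiJoin<_, Tuples, Rows> = SemiJoin::All(vec![(vec![1], 2)].into_iter());
    assert_eq!(all.collect::<Vec<_>>(), vec![vec![1, 2]]);
}
```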
+pub enum SemiJoin { + All(All), + Lhs(Lhs), + Rhs(Rhs), } -impl<'a, Iter> Iterator for LeftDeepJoin +impl<'a, All, Lhs, Rhs> Iterator for SemiJoin where - Iter: Iterator, RowRef<'a>)>, + All: Iterator, Row<'a>)>, + Lhs: Iterator>, + Rhs: Iterator>, { type Item = Tuple<'a>; fn next(&mut self) -> Option { match self { - Self::SemiLhs(iter) => { - // Return the lhs tuple - iter.next().map(|(t, _)| t) - } - Self::SemiRhs(iter) => { - // Return the rhs row - iter.next().map(|(_, ptr)| ptr).map(Row::Ptr).map(Tuple::Row) - } - Self::Eq(iter) => { - iter.next().map(|(tuple, ptr)| { - match (tuple, ptr) { - // A leaf join - // x - // / \ - // a b - (Tuple::Row(u), ptr) => { - // Returns a 2-tuple - Tuple::Join(vec![u, Row::Ptr(ptr)]) - } - // A left deep join - // x - // / \ - // x c - // / \ - // a b - (Tuple::Join(mut rows), ptr) => { - // Returns an (n+1)-tuple - rows.push(Row::Ptr(ptr)); - Tuple::Join(rows) - } - } - }) - } + Self::All(iter) => iter.next().map(|(tuple, ptr)| tuple.append(ptr)), + Self::Lhs(iter) => iter.next(), + Self::Rhs(iter) => iter.next().map(Tuple::Row), } } } -/// A unique (constraint) index join iterator -pub struct UniqueIndexJoin<'a> { +/// An index join that uses a unique constraint index +pub struct UniqueIxJoin<'a> { /// The lhs of the join - input: Box>, + lhs: Box>, /// The rhs index - index: &'a BTreeIndex, + rhs_index: &'a BTreeIndex, /// A handle to the datastore - table: &'a Table, + rhs_table: &'a Table, /// A handle to the blobstore blob_store: &'a dyn BlobStore, - /// The lhs index key projection - projection: ProgramEvaluator<'a>, -} - -impl<'a> UniqueIndexJoin<'a> { - fn new( - input: Iter<'a>, - index: &'a BTreeIndex, - table: &'a Table, - blob_store: &'a dyn BlobStore, - projection: ProgramEvaluator<'a>, - ) -> Self { - Self { - input: Box::new(input), - index, - table, - blob_store, - projection, - } + /// The lhs probe field + lhs_field: &'a TupleField, +} + +impl<'a> UniqueIxJoin<'a> { + fn build_from(join: &'a IxJoin, tx: &'a Tx) -> Result + where + Tx: Datastore + DeltaStore, + { + let lhs = Iter::build(&join.lhs, tx)?; + let rhs_table = tx.table_or_err(join.rhs.table_id)?; + let rhs_index = rhs_table + .get_index(join.rhs_index) + .ok_or_else(|| anyhow!("IndexId `{}` does not exist", join.rhs_index))?; + Ok(Self { + lhs: Box::new(lhs), + rhs_index, + rhs_table, + blob_store: tx.blob_store(), + lhs_field: &join.lhs_field, + }) } } -impl<'a> Iterator for UniqueIndexJoin<'a> { - type Item = (Tuple<'a>, RowRef<'a>); +impl<'a> Iterator for UniqueIxJoin<'a> { + type Item = (Tuple<'a>, Row<'a>); fn next(&mut self) -> Option { - self.input.find_map(|tuple| { - self.index - .seek(&self.projection.eval(&tuple)) + self.lhs.find_map(|tuple| { + self.rhs_index + .seek(&tuple.project(self.lhs_field)) .next() - .and_then(|ptr| self.table.get_row_ref(self.blob_store, ptr)) + .and_then(|ptr| self.rhs_table.get_row_ref(self.blob_store, ptr)) + .map(Row::Ptr) .map(|ptr| (tuple, ptr)) }) } } -/// A non-unique (constraint) index join iterator -pub struct IndexJoin<'a> { +/// A left semijoin that uses a unique constraint index +pub struct UniqueIxJoinLhs<'a> { + /// The lhs of the join + lhs: Box>, + /// The rhs index + rhs: &'a BTreeIndex, + /// The lhs probe field + lhs_field: &'a TupleField, +} + +impl<'a> UniqueIxJoinLhs<'a> { + fn build_from(join: &'a IxJoin, tx: &'a Tx) -> Result + where + Tx: Datastore + DeltaStore, + { + let lhs = Iter::build(&join.lhs, tx)?; + let rhs_table = tx.table_or_err(join.rhs.table_id)?; + let rhs_index = rhs_table + 
.get_index(join.rhs_index) + .ok_or_else(|| anyhow!("IndexId `{}` does not exist", join.rhs_index))?; + Ok(Self { + lhs: Box::new(lhs), + rhs: rhs_index, + lhs_field: &join.lhs_field, + }) + } +} + +impl<'a> Iterator for UniqueIxJoinLhs<'a> { + type Item = Tuple<'a>; + + fn next(&mut self) -> Option { + self.lhs.find(|t| self.rhs.contains_any(&t.project(self.lhs_field))) + } +} + +/// A right semijoin that uses a unique constraint index +pub struct UniqueIxJoinRhs<'a> { /// The lhs of the join - input: Box>, - /// The current tuple from the lhs - tuple: Option>, + lhs: Box>, /// The rhs index - index: &'a BTreeIndex, - /// The current cursor for the rhs index - index_cursor: Option>, + rhs_index: &'a BTreeIndex, /// A handle to the datastore - table: &'a Table, + rhs_table: &'a Table, /// A handle to the blobstore blob_store: &'a dyn BlobStore, - /// The lhs index key projection - projection: ProgramEvaluator<'a>, -} - -impl<'a> IndexJoin<'a> { - fn new( - input: Iter<'a>, - index: &'a BTreeIndex, - table: &'a Table, - blob_store: &'a dyn BlobStore, - projection: ProgramEvaluator<'a>, - ) -> Self { - Self { - input: Box::new(input), - tuple: None, - index, - index_cursor: None, - table, - blob_store, - projection, - } + /// The lhs probe field + lhs_field: &'a TupleField, +} + +impl<'a> UniqueIxJoinRhs<'a> { + fn build_from(join: &'a IxJoin, tx: &'a Tx) -> Result + where + Tx: Datastore + DeltaStore, + { + let lhs = Iter::build(&join.lhs, tx)?; + let rhs_table = tx.table_or_err(join.rhs.table_id)?; + let rhs_index = rhs_table + .get_index(join.rhs_index) + .ok_or_else(|| anyhow!("IndexId `{}` does not exist", join.rhs_index))?; + Ok(Self { + lhs: Box::new(lhs), + rhs_index, + rhs_table, + blob_store: tx.blob_store(), + lhs_field: &join.lhs_field, + }) } } -impl<'a> Iterator for IndexJoin<'a> { - type Item = (Tuple<'a>, RowRef<'a>); +impl<'a> Iterator for UniqueIxJoinRhs<'a> { + type Item = Row<'a>; fn next(&mut self) -> Option { - self.tuple + self.lhs.find_map(|tuple| { + self.rhs_index + .seek(&tuple.project(self.lhs_field)) + .next() + .and_then(|ptr| self.rhs_table.get_row_ref(self.blob_store, ptr)) + .map(Row::Ptr) + }) + } +} + +/// An index join that does not use a unique constraint index +pub struct IxJoinIter<'a> { + /// The lhs of the join + lhs: Box>, + /// The current lhs tuple + lhs_tuple: Option>, + /// The rhs index + rhs_index: &'a BTreeIndex, + /// The current rhs index cursor + rhs_index_cursor: Option>, + /// A handle to the datastore + rhs_table: &'a Table, + /// A handle to the blobstore + blob_store: &'a dyn BlobStore, + /// The lhs probe field + lhs_field: &'a TupleField, +} + +impl<'a> IxJoinIter<'a> { + fn build_from(join: &'a IxJoin, tx: &'a Tx) -> Result + where + Tx: Datastore + DeltaStore, + { + let lhs = Iter::build(&join.lhs, tx)?; + let rhs_table = tx.table_or_err(join.rhs.table_id)?; + let rhs_index = rhs_table + .get_index(join.rhs_index) + .ok_or_else(|| anyhow!("IndexId `{}` does not exist", join.rhs_index))?; + Ok(Self { + lhs: Box::new(lhs), + lhs_tuple: None, + rhs_index, + rhs_index_cursor: None, + rhs_table, + blob_store: tx.blob_store(), + lhs_field: &join.lhs_field, + }) + } +} + +impl<'a> Iterator for IxJoinIter<'a> { + type Item = (Tuple<'a>, Row<'a>); + + fn next(&mut self) -> Option { + self.lhs_tuple .as_ref() .and_then(|tuple| { - self.index_cursor.as_mut().and_then(|cursor| { + self.rhs_index_cursor.as_mut().and_then(|cursor| { cursor.next().and_then(|ptr| { - self.table + self.rhs_table .get_row_ref(self.blob_store, ptr) + .map(Row::Ptr) 
.map(|ptr| (tuple.clone(), ptr)) }) }) }) .or_else(|| { - self.input.find_map(|tuple| { - Some(self.index.seek(&self.projection.eval(&tuple))).and_then(|mut cursor| { - cursor.next().and_then(|ptr| { - self.table.get_row_ref(self.blob_store, ptr).map(|ptr| { - self.tuple = Some(tuple.clone()); - self.index_cursor = Some(cursor); + self.lhs.find_map(|tuple| { + let mut cursor = self.rhs_index.seek(&tuple.project(self.lhs_field)); + cursor.next().and_then(|ptr| { + self.rhs_table + .get_row_ref(self.blob_store, ptr) + .map(Row::Ptr) + .map(|ptr| { + self.lhs_tuple = Some(tuple.clone()); + self.rhs_index_cursor = Some(cursor); (tuple, ptr) }) - }) }) }) }) } } -/// A nested loop join returns the cross product of its inputs -pub struct NestedLoopJoin<'a> { - /// The lhs input +/// A left semijoin that does not use a unique constraint index +pub struct IxJoinLhs<'a> { + /// The lhs of the join lhs: Box>, - /// The rhs input - rhs: Box>, - /// The materialized rhs - build: Vec>, + /// The rhs index + rhs_index: &'a BTreeIndex, /// The current lhs tuple - lhs_row: Option>, - /// The current rhs tuple - rhs_ptr: usize, + lhs_tuple: Option>, + /// The matching rhs row count + rhs_count: usize, + /// The lhs probe field + lhs_field: &'a TupleField, } -impl<'a> NestedLoopJoin<'a> { - fn new(lhs: Iter<'a>, rhs: Iter<'a>) -> Self { - Self { +impl<'a> IxJoinLhs<'a> { + fn build_from(join: &'a IxJoin, tx: &'a Tx) -> Result + where + Tx: Datastore + DeltaStore, + { + let lhs = Iter::build(&join.lhs, tx)?; + let rhs_table = tx.table_or_err(join.rhs.table_id)?; + let rhs_index = rhs_table + .get_index(join.rhs_index) + .ok_or_else(|| anyhow!("IndexId `{}` does not exist", join.rhs_index))?; + Ok(Self { lhs: Box::new(lhs), - rhs: Box::new(rhs), - build: vec![], - lhs_row: None, - rhs_ptr: 0, - } + lhs_tuple: None, + rhs_count: 0, + rhs_index, + lhs_field: &join.lhs_field, + }) } } -impl<'a> Iterator for NestedLoopJoin<'a> { - type Item = (Tuple<'a>, Tuple<'a>); +impl<'a> Iterator for IxJoinLhs<'a> { + type Item = Tuple<'a>; fn next(&mut self) -> Option { - for t in self.rhs.as_mut() { - self.build.push(t); - } - match self.build.get(self.rhs_ptr) { - Some(v) => { - self.rhs_ptr += 1; - self.lhs_row.as_ref().map(|u| (u.clone(), v.clone())) - } - None => { - self.rhs_ptr = 1; - self.lhs_row = self.lhs.next(); - self.lhs_row - .as_ref() - .zip(self.build.first()) - .map(|(u, v)| (u.clone(), v.clone())) + match self.rhs_count { + 0 => self + .lhs + .find_map(|tuple| self.rhs_index.count(&tuple.project(self.lhs_field)).map(|n| (tuple, n))) + .map(|(tuple, n)| { + self.rhs_count = n - 1; + self.lhs_tuple = Some(tuple.clone()); + tuple + }), + _ => { + self.rhs_count -= 1; + self.lhs_tuple.clone() } } } } -/// A tuple-at-a-time filter iterator -pub struct Filter<'a> { - input: Box>, - program: ProgramEvaluator<'a>, +/// A right semijoin that does not use a unique constraint index +pub struct IxJoinRhs<'a> { + /// The lhs of the join + lhs: Box>, + /// The rhs index + rhs_index: &'a BTreeIndex, + /// The current rhs index cursor + rhs_index_cursor: Option>, + /// A handle to the datastore + rhs_table: &'a Table, + /// A handle to the blobstore + blob_store: &'a dyn BlobStore, + /// The lhs probe field + lhs_field: &'a TupleField, } -impl<'a> Iterator for Filter<'a> { - type Item = Tuple<'a>; +impl<'a> IxJoinRhs<'a> { + fn build_from(join: &'a IxJoin, tx: &'a Tx) -> Result + where + Tx: Datastore + DeltaStore, + { + let lhs = Iter::build(&join.lhs, tx)?; + let rhs_table = tx.table_or_err(join.rhs.table_id)?; + let 
rhs_index = rhs_table + .get_index(join.rhs_index) + .ok_or_else(|| anyhow!("IndexId `{}` does not exist", join.rhs_index))?; + Ok(Self { + lhs: Box::new(lhs), + rhs_index, + rhs_index_cursor: None, + rhs_table, + blob_store: tx.blob_store(), + lhs_field: &join.lhs_field, + }) + } +} + +impl<'a> Iterator for IxJoinRhs<'a> { + type Item = Row<'a>; fn next(&mut self) -> Option { - self.input - .find(|tuple| self.program.eval(tuple).as_bool().is_some_and(|ok| *ok)) + self.rhs_index_cursor + .as_mut() + .and_then(|cursor| { + cursor + .next() + .and_then(|ptr| self.rhs_table.get_row_ref(self.blob_store, ptr)) + .map(Row::Ptr) + }) + .or_else(|| { + self.lhs.find_map(|tuple| { + let mut cursor = self.rhs_index.seek(&tuple.project(self.lhs_field)); + cursor.next().and_then(|ptr| { + self.rhs_table + .get_row_ref(self.blob_store, ptr) + .map(Row::Ptr) + .map(|ptr| { + self.rhs_index_cursor = Some(cursor); + ptr + }) + }) + }) + }) } } -/// An opcode for a stack-based expression evaluator -#[derive(Clone, Copy)] -pub enum OpCode { - /// == - Eq, - /// <> - Ne, - /// < - Lt, - /// > - Gt, - /// <= - Lte, - /// <= - Gte, - /// AND - And, - /// OR - Or, - /// 5 - Const(u16), - /// || - Concat(u16), - /// r.0 : [Row::Ptr] - PtrProj(u16), - /// r.0 : [Row::Ref] - RefProj(u16), - /// r.0.1 : [Row::Ptr] - TupPtrProj(u16), - /// r.0.1 : [Row::Ref] - TupRefProj(u16), +/// A hash join that on each probe, +/// returns at most one row from the hash table. +pub struct UniqueHashJoin<'a> { + /// The lhs relation + lhs: Box>, + /// The rhs hash table + rhs: HashMap>, + /// The lhs probe field + lhs_field: &'a TupleField, } -static_assert_size!(OpCode, 4); - -/// A program for evaluating a scalar expression -pub struct ExprProgram<'a> { - /// The instructions or opcodes - ops: &'a [OpCode], - /// The constants in the original expression - constants: &'a [AlgebraicValue], +impl<'a> UniqueHashJoin<'a> { + /// Builds a hash table over the rhs + fn build_from(join: &'a HashJoin, tx: &'a Tx) -> Result + where + Tx: Datastore + DeltaStore, + { + let lhs = Iter::build(&join.lhs, tx)?; + let rhs = RowRefIter::build(&join.rhs, tx)?; + let rhs = rhs.map(|ptr| (ptr.project(&join.rhs_field), ptr)).collect(); + Ok(Self { + lhs: Box::new(lhs), + rhs, + lhs_field: &join.lhs_field, + }) + } } -impl<'a> ExprProgram<'a> { - fn new(ops: &'a [OpCode], constants: &'a [AlgebraicValue]) -> Self { - Self { ops, constants } +impl<'a> Iterator for UniqueHashJoin<'a> { + type Item = (Tuple<'a>, Row<'a>); + + fn next(&mut self) -> Option { + self.lhs.find_map(|tuple| { + self.rhs + .get(&tuple.project(self.lhs_field)) + .cloned() + .map(|ptr| (tuple, ptr)) + }) } +} + +/// A left hash semijoin that on each probe, +/// returns at most one row from the hash table. 
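`UniqueHashJoinLhs`, declared next, probes a `HashSet` because a unique rhs key matches at most once; its non-unique counterpart `HashJoinLhs` instead stores a count per key and re-emits the lhs tuple that many times, preserving bag semantics (`IxJoinLhs` plays the same `rhs_count` trick against an index). A self-contained sketch of both probe styles on plain integers:

```rust
use std::collections::{HashMap, HashSet};

// Unique rhs: one set-membership probe per lhs row suffices.
fn semijoin_unique<'a>(
    lhs: &'a [(u32, i64)],
    rhs_keys: &'a HashSet<u32>,
) -> impl Iterator<Item = &'a (u32, i64)> {
    lhs.iter().filter(move |(key, _)| rhs_keys.contains(key))
}

// Non-unique rhs: count matches per key and emit each lhs row once per
// match, the multiplicity-preserving idea behind the rhs_count fields.
fn semijoin_counted<'a>(
    lhs: &'a [(u32, i64)],
    rhs_counts: &'a HashMap<u32, usize>,
) -> impl Iterator<Item = &'a (u32, i64)> {
    lhs.iter().flat_map(move |row| {
        let n = rhs_counts.get(&row.0).copied().unwrap_or(0);
        std::iter::repeat(row).take(n)
    })
}

fn main() {
    let lhs = [(1, 100), (2, 200)];

    let unique: HashSet<u32> = HashSet::from([1]);
    assert_eq!(semijoin_unique(&lhs, &unique).count(), 1);

    let counts = HashMap::from([(2u32, 3usize)]); // three rhs matches for key 2
    assert_eq!(semijoin_counted(&lhs, &counts).count(), 3);
}
```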
+pub struct UniqueHashJoinLhs<'a> { + /// The lhs relation + lhs: Box>, + /// The rhs hash table + rhs: HashSet, + /// The lhs probe field + lhs_field: &'a TupleField, +} - /// Returns an interator over the opcodes - fn ops(&self) -> impl Iterator + '_ { - self.ops.iter().copied() +impl<'a> UniqueHashJoinLhs<'a> { + /// Builds a hash set over the rhs + fn build_from(join: &'a HashJoin, tx: &'a Tx) -> Result + where + Tx: Datastore + DeltaStore, + { + let lhs = Iter::build(&join.lhs, tx)?; + let rhs = RowRefIter::build(&join.rhs, tx)?; + let rhs = rhs.map(|ptr| ptr.project(&join.rhs_field)).collect(); + Ok(Self { + lhs: Box::new(lhs), + rhs, + lhs_field: &join.lhs_field, + }) } +} + +impl<'a> Iterator for UniqueHashJoinLhs<'a> { + type Item = Tuple<'a>; - /// Lookup a constant in the plan - fn constant(&self, i: u16) -> AlgebraicValue { - self.constants[i as usize].clone() + fn next(&mut self) -> Option { + self.lhs.find(|t| self.rhs.contains(&t.project(self.lhs_field))) } } -/// An evaluator for an [ExprProgram] -pub struct ProgramEvaluator<'a> { - program: ExprProgram<'a>, - stack: Vec, +/// A right hash join that on each probe, +/// returns at most one row from the hash table. +pub struct UniqueHashJoinRhs<'a> { + /// The lhs relation + lhs: Box>, + /// The rhs hash table + rhs: HashMap>, + /// The lhs probe field + lhs_field: &'a TupleField, } -impl<'a> From> for ProgramEvaluator<'a> { - fn from(program: ExprProgram<'a>) -> Self { - Self { program, stack: vec![] } +impl<'a> UniqueHashJoinRhs<'a> { + /// Builds a hash table over the rhs + fn build_from(join: &'a HashJoin, tx: &'a Tx) -> Result + where + Tx: Datastore + DeltaStore, + { + let lhs = Iter::build(&join.lhs, tx)?; + let rhs = RowRefIter::build(&join.rhs, tx)?; + let rhs = rhs.map(|ptr| (ptr.project(&join.rhs_field), ptr)).collect(); + Ok(Self { + lhs: Box::new(lhs), + rhs, + lhs_field: &join.lhs_field, + }) } } -impl ProgramEvaluator<'_> { - pub fn eval(&mut self, tuple: &Tuple) -> AlgebraicValue { - for op in self.program.ops() { - match op { - OpCode::Const(i) => { - self.stack.push(self.program.constant(i)); - } - OpCode::Eq => { - let r = self.stack.pop().unwrap(); - let l = self.stack.pop().unwrap(); - self.stack.push(AlgebraicValue::Bool(l == r)); - } - OpCode::Ne => { - let r = self.stack.pop().unwrap(); - let l = self.stack.pop().unwrap(); - self.stack.push(AlgebraicValue::Bool(l != r)); - } - OpCode::Lt => { - let r = self.stack.pop().unwrap(); - let l = self.stack.pop().unwrap(); - self.stack.push(AlgebraicValue::Bool(l < r)); - } - OpCode::Gt => { - let r = self.stack.pop().unwrap(); - let l = self.stack.pop().unwrap(); - self.stack.push(AlgebraicValue::Bool(l > r)); - } - OpCode::Lte => { - let r = self.stack.pop().unwrap(); - let l = self.stack.pop().unwrap(); - self.stack.push(AlgebraicValue::Bool(l <= r)); - } - OpCode::Gte => { - let r = self.stack.pop().unwrap(); - let l = self.stack.pop().unwrap(); - self.stack.push(AlgebraicValue::Bool(l >= r)); - } - OpCode::And => { - let r = *self.stack.pop().unwrap().as_bool().unwrap(); - let l = *self.stack.pop().unwrap().as_bool().unwrap(); - self.stack.push(AlgebraicValue::Bool(l && r)); - } - OpCode::Or => { - let r = *self.stack.pop().unwrap().as_bool().unwrap(); - let l = *self.stack.pop().unwrap().as_bool().unwrap(); - self.stack.push(AlgebraicValue::Bool(l || r)); - } - OpCode::Concat(n) => { - let mut elems = Vec::with_capacity(n as usize); - // Pop args off stack - for _ in 0..n { - elems.push(self.stack.pop().unwrap()); - } - // Concat and push on stack - 
self.stack.push(AlgebraicValue::Product(ProductValue::from_iter( - elems.into_iter().rev(), - ))); - } - OpCode::PtrProj(i) => { - self.stack.push( - tuple - // Read field from row ref - .expect_row() - .expect_ptr() - .read_col(i as usize) - .unwrap(), - ); +impl<'a> Iterator for UniqueHashJoinRhs<'a> { + type Item = Row<'a>; + + fn next(&mut self) -> Option { + self.lhs.find_map(|t| self.rhs.get(&t.project(self.lhs_field)).cloned()) + } +} + +/// A hash join that on each probe, +/// may return many rows from the hash table. +pub struct HashJoinIter<'a> { + /// The lhs relation + lhs: Box>, + /// The rhs hash table + rhs: HashMap>>, + /// The current lhs tuple + lhs_tuple: Option>, + /// The current rhs row pointer + rhs_ptr: usize, + /// The lhs probe field + lhs_field: &'a TupleField, +} + +impl<'a> HashJoinIter<'a> { + /// Builds a hash table over the rhs + fn build_from(join: &'a HashJoin, tx: &'a Tx) -> Result + where + Tx: Datastore + DeltaStore, + { + let lhs = Iter::build(&join.lhs, tx)?; + let rhs_iter = RowRefIter::build(&join.rhs, tx)?; + let mut rhs = HashMap::new(); + for ptr in rhs_iter { + let val = ptr.project(&join.rhs_field); + match rhs.get_mut(&val) { + None => { + rhs.insert(val, vec![ptr]); } - OpCode::RefProj(i) => { - self.stack.push( - tuple - // Read field from product ref - .expect_row() - .expect_ref() - .elements[i as usize] - .clone(), - ); + Some(ptrs) => { + ptrs.push(ptr); } - OpCode::TupPtrProj(i) => { - let idx = *self - // Pop index off stack - .stack - .pop() - .unwrap() - .as_u16() - .unwrap(); - self.stack.push( - tuple - // Read field from row ref - .expect_join()[idx as usize] - .expect_ptr() - .read_col(i as usize) - .unwrap(), - ); + } + } + Ok(Self { + lhs: Box::new(lhs), + rhs, + lhs_tuple: None, + rhs_ptr: 0, + lhs_field: &join.lhs_field, + }) + } +} + +impl<'a> Iterator for HashJoinIter<'a> { + type Item = (Tuple<'a>, Row<'a>); + + fn next(&mut self) -> Option { + self.lhs_tuple + .as_ref() + .and_then(|tuple| { + self.rhs.get(&tuple.project(self.lhs_field)).and_then(|ptrs| { + let i = self.rhs_ptr; + self.rhs_ptr += 1; + ptrs.get(i).map(|ptr| (tuple.clone(), ptr.clone())) + }) + }) + .or_else(|| { + self.lhs.find_map(|tuple| { + self.rhs.get(&tuple.project(self.lhs_field)).and_then(|ptrs| { + self.rhs_ptr = 1; + self.lhs_tuple = Some(tuple.clone()); + ptrs.first().map(|ptr| (tuple, ptr.clone())) + }) + }) + }) + } +} + +/// A left hash semijoin that on each probe, +/// may return many rows from the hash table. 
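`HashJoinLhs` and `HashJoinRhs`, declared next, cover the non-unique probes where a key can match a whole bucket of build rows, so the iterator must resume mid-bucket across `next()` calls; that is what the `lhs_tuple`/`lhs_value` and `rhs_ptr` cursor fields implement. A minimal model of that resumable probe loop, with toy key and row types:

```rust
use std::collections::HashMap;

// A probe-side iterator over a bucketed hash table. `current` and `pos`
// let `next` resume inside a bucket, like the lhs_tuple/rhs_ptr fields.
struct HashJoin<L: Iterator<Item = u32>> {
    lhs: L,
    rhs: HashMap<u32, Vec<i64>>, // key -> every matching build row
    current: Option<u32>,        // lhs key currently being drained
    pos: usize,                  // cursor into the current bucket
}

impl<L: Iterator<Item = u32>> Iterator for HashJoin<L> {
    type Item = (u32, i64);

    fn next(&mut self) -> Option<(u32, i64)> {
        loop {
            // Keep draining the bucket for the current lhs key.
            if let Some(key) = self.current {
                if let Some(row) = self.rhs.get(&key).and_then(|b| b.get(self.pos)) {
                    self.pos += 1;
                    return Some((key, *row));
                }
                self.current = None;
            }
            // Advance to the next lhs key that has any matches.
            let key = self.lhs.next()?;
            if self.rhs.contains_key(&key) {
                self.current = Some(key);
                self.pos = 0;
            }
        }
    }
}

fn main() {
    let rhs = HashMap::from([(1u32, vec![10i64, 11])]);
    let join = HashJoin { lhs: vec![0, 1].into_iter(), rhs, current: None, pos: 0 };
    assert_eq!(join.collect::<Vec<_>>(), [(1, 10), (1, 11)]);
}
```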
+pub struct HashJoinLhs<'a> { + /// The lhs relation + lhs: Box<Iter<'a>>, + /// The rhs hash table + rhs: HashMap<AlgebraicValue, usize>, + /// The current lhs tuple + lhs_tuple: Option<Tuple<'a>>, + /// The matching number of rhs rows + rhs_count: usize, + /// The lhs probe field + lhs_field: &'a TupleField, +} + +impl<'a> HashJoinLhs<'a> { + /// Instantiates the iterator by building a hash table over the rhs + fn build_from<Tx>(join: &'a HashJoin, tx: &'a Tx) -> Result<Self> + where + Tx: Datastore + DeltaStore, + { + let lhs = Iter::build(&join.lhs, tx)?; + let rhs_iter = RowRefIter::build(&join.rhs, tx)?; + let mut rhs = HashMap::new(); + for ptr in rhs_iter { + rhs.entry(ptr.project(&join.rhs_field)) + .and_modify(|n| *n += 1) + .or_insert_with(|| 1); + } + Ok(Self { + lhs: Box::new(lhs), + rhs, + lhs_tuple: None, + rhs_count: 0, + lhs_field: &join.lhs_field, + }) + } +} + +impl<'a> Iterator for HashJoinLhs<'a> { + type Item = Tuple<'a>; + + fn next(&mut self) -> Option<Self::Item> { + match self.rhs_count { + 0 => self.lhs.find_map(|tuple| { + self.rhs.get(&tuple.project(self.lhs_field)).map(|n| { + self.rhs_count = *n - 1; + self.lhs_tuple = Some(tuple.clone()); + tuple + }) + }), + _ => { + self.rhs_count -= 1; + self.lhs_tuple.clone() + } + } + } +} + +/// A right hash semijoin that on each probe, +/// may return many rows from the hash table. +pub struct HashJoinRhs<'a> { + /// The lhs relation + lhs: Box<Iter<'a>>, + /// The rhs hash table + rhs: HashMap<AlgebraicValue, Vec<Row<'a>>>, + /// The current lhs probe value + lhs_value: Option<AlgebraicValue>, + /// The current rhs row pointer + rhs_ptr: usize, + /// The lhs probe field + lhs_field: &'a TupleField, +} + +impl<'a> HashJoinRhs<'a> { + /// Builds a hash table over the rhs + fn build_from<Tx>(join: &'a HashJoin, tx: &'a Tx) -> Result<Self> + where + Tx: Datastore + DeltaStore, + { + let lhs = Iter::build(&join.lhs, tx)?; + let rhs_iter = RowRefIter::build(&join.rhs, tx)?; + let mut rhs = HashMap::new(); + for ptr in rhs_iter { + let val = ptr.project(&join.rhs_field); + match rhs.get_mut(&val) { + None => { + rhs.insert(val, vec![ptr]); + } + Some(ptrs) => { + ptrs.push(ptr); + } + } + } + Ok(Self { + lhs: Box::new(lhs), + rhs, + lhs_value: None, + rhs_ptr: 0, + lhs_field: &join.lhs_field, + }) + } +} + +impl<'a> Iterator for HashJoinRhs<'a> { + type Item = Row<'a>; + + fn next(&mut self) -> Option<Self::Item> { + self.lhs_value + .as_ref() + .and_then(|value| { + self.rhs.get(value).and_then(|ptrs| { + let i = self.rhs_ptr; + self.rhs_ptr += 1; + ptrs.get(i).cloned() + }) + }) + .or_else(|| { + self.lhs.find_map(|tuple| { + let value = tuple.project(self.lhs_field); + self.rhs.get(&value).and_then(|ptrs| { + self.rhs_ptr = 1; + self.lhs_value = Some(value.clone()); + ptrs.first().cloned() + }) + }) + }) + } +} + +/// A nested loop join iterator +pub struct NLJoin<'a> { + /// The lhs input + lhs: Box<Iter<'a>>, + /// The materialized rhs + rhs: Vec<Row<'a>>, + /// The current lhs tuple + lhs_tuple: Option<Tuple<'a>>, + /// The current rhs row pointer + rhs_ptr: usize, +} + +impl<'a> NLJoin<'a> { + /// Instantiates the iterator by materializing the rhs + fn build_from<Tx>(lhs: &'a PhysicalPlan, rhs: &'a PhysicalPlan, tx: &'a Tx) -> Result<Self> + where + Tx: Datastore + DeltaStore, + { + let lhs = Iter::build(lhs, tx)?; + let rhs = RowRefIter::build(rhs, tx)?; + Ok(Self { + lhs: Box::new(lhs), + rhs:
+
+/// A nested loop join iterator
+pub struct NLJoin<'a> {
+    /// The lhs input
+    lhs: Box<Iter<'a>>,
+    /// The materialized rhs
+    rhs: Vec<Row<'a>>,
+    /// The current lhs tuple
+    lhs_tuple: Option<Tuple<'a>>,
+    /// The current rhs row pointer
+    rhs_ptr: usize,
+}
+
+impl<'a> NLJoin<'a> {
+    /// Instantiates the iterator by materializing the rhs
+    fn build_from<Tx>(lhs: &'a PhysicalPlan, rhs: &'a PhysicalPlan, tx: &'a Tx) -> Result<Self>
+    where
+        Tx: Datastore + DeltaStore,
+    {
+        let lhs = Iter::build(lhs, tx)?;
+        let rhs = RowRefIter::build(rhs, tx)?;
+        Ok(Self {
+            lhs: Box::new(lhs),
+            rhs: rhs.collect(),
+            lhs_tuple: None,
+            rhs_ptr: 0,
+        })
+    }
+}
+
+impl<'a> Iterator for NLJoin<'a> {
+    type Item = (Tuple<'a>, Row<'a>);
+
+    fn next(&mut self) -> Option<Self::Item> {
+        match self.rhs.get(self.rhs_ptr) {
+            Some(v) => {
+                self.rhs_ptr += 1;
+                self.lhs_tuple.as_ref().map(|u| (u.clone(), v.clone()))
+            }
+            None => {
+                self.rhs_ptr = 1;
+                self.lhs_tuple = self.lhs.next();
+                self.lhs_tuple
+                    .as_ref()
+                    .zip(self.rhs.first())
+                    .map(|(u, v)| (u.clone(), v.clone()))
+            }
+        }
+    }
+}
+
+/// A tuple-at-a-time filter iterator
+pub struct Filter<'a, I> {
+    input: Box<I>,
+    expr: &'a PhysicalExpr,
+}
+
+impl<'a> Iterator for Filter<'a, RowRefIter<'a>> {
+    type Item = Row<'a>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        self.input.find(|ptr| self.expr.eval_bool(ptr))
+    }
+}
+
+impl<'a> Iterator for Filter<'a, Iter<'a>> {
+    type Item = Tuple<'a>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        self.input.find(|tuple| self.expr.eval_bool(tuple))
     }
 }
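The `NLJoin` iterator above materializes the rhs once and replays it for every lhs tuple. A compact standalone model of the same cross-product shape, using plain slices in place of the crate's plan iterators:

```rust
// Illustrative sketch only; the real iterator yields (Tuple, Row) pairs
// and tracks its position with `lhs_tuple` and `rhs_ptr` instead of flat_map.
fn nl_join<'a, T: Clone, U: Clone>(lhs: &'a [T], rhs: &'a [U]) -> impl Iterator<Item = (T, U)> + 'a {
    lhs.iter()
        .flat_map(move |u| rhs.iter().map(move |v| (u.clone(), v.clone())))
}

fn main() {
    let pairs: Vec<_> = nl_join(&[1, 2], &["x", "y"]).collect();
    assert_eq!(pairs, vec![(1, "x"), (1, "y"), (2, "x"), (2, "y")]);
}
```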
diff --git a/crates/execution/src/lib.rs b/crates/execution/src/lib.rs
index 9708a97a673..172d5682f8b 100644
--- a/crates/execution/src/lib.rs
+++ b/crates/execution/src/lib.rs
@@ -1 +1,176 @@
+use std::ops::RangeBounds;
+
+use anyhow::{anyhow, Result};
+use iter::PlanIter;
+use spacetimedb_lib::{
+    bsatn::{EncodeError, ToBsatn},
+    query::Delta,
+    ser::Serialize,
+    AlgebraicValue, ProductValue,
+};
+use spacetimedb_physical_plan::plan::{ProjectField, ProjectPlan, TupleField};
+use spacetimedb_primitives::{IndexId, TableId};
+use spacetimedb_table::{
+    blob_store::BlobStore,
+    static_assert_size,
+    table::{IndexScanIter, RowRef, Table, TableScanIter},
+};
+
 pub mod iter;
+
+/// The datastore interface required for building an executor
+pub trait Datastore {
+    fn table(&self, table_id: TableId) -> Option<&Table>;
+    fn blob_store(&self) -> &dyn BlobStore;
+
+    fn table_or_err(&self, table_id: TableId) -> Result<&Table> {
+        self.table(table_id)
+            .ok_or_else(|| anyhow!("TableId `{table_id}` does not exist"))
+    }
+
+    fn table_scan(&self, table_id: TableId) -> Result<TableScanIter> {
+        self.table(table_id)
+            .map(|table| table.scan_rows(self.blob_store()))
+            .ok_or_else(|| anyhow!("TableId `{table_id}` does not exist"))
+    }
+
+    fn index_scan(
+        &self,
+        table_id: TableId,
+        index_id: IndexId,
+        range: &impl RangeBounds<AlgebraicValue>,
+    ) -> Result<IndexScanIter> {
+        self.table(table_id)
+            .ok_or_else(|| anyhow!("TableId `{table_id}` does not exist"))
+            .and_then(|table| {
+                table
+                    .index_seek_by_id(self.blob_store(), index_id, range)
+                    .ok_or_else(|| anyhow!("IndexId `{index_id}` does not exist"))
+            })
+    }
+}
+
+pub trait DeltaStore {
+    fn has_inserts(&self, table_id: TableId) -> Option<Delta>;
+    fn has_deletes(&self, table_id: TableId) -> Option<Delta>;
+
+    fn inserts_for_table(&self, table_id: TableId) -> Option<std::slice::Iter<'_, ProductValue>>;
+    fn deletes_for_table(&self, table_id: TableId) -> Option<std::slice::Iter<'_, ProductValue>>;
+
+    fn delta_scan(&self, table_id: TableId, inserts: bool) -> Result<DeltaScanIter> {
+        match inserts {
+            true => self
+                .inserts_for_table(table_id)
+                .ok_or_else(|| anyhow!("TableId `{table_id}` does not exist"))
+                .map(|iter| DeltaScanIter { iter }),
+            false => self
+                .deletes_for_table(table_id)
+                .ok_or_else(|| anyhow!("TableId `{table_id}` does not exist"))
+                .map(|iter| DeltaScanIter { iter }),
+        }
+    }
+}
+
+#[derive(Clone, Serialize)]
+pub enum Row<'a> {
+    Ptr(RowRef<'a>),
+    Ref(&'a ProductValue),
+}
+
+impl ToBsatn for Row<'_> {
+    fn static_bsatn_size(&self) -> Option<u16> {
+        match self {
+            Self::Ptr(ptr) => ptr.static_bsatn_size(),
+            Self::Ref(val) => val.static_bsatn_size(),
+        }
+    }
+
+    fn to_bsatn_extend(&self, buf: &mut Vec<u8>) -> std::result::Result<(), EncodeError> {
+        match self {
+            Self::Ptr(ptr) => ptr.to_bsatn_extend(buf),
+            Self::Ref(val) => val.to_bsatn_extend(buf),
+        }
+    }
+
+    fn to_bsatn_vec(&self) -> std::result::Result<Vec<u8>, EncodeError> {
+        match self {
+            Self::Ptr(ptr) => ptr.to_bsatn_vec(),
+            Self::Ref(val) => val.to_bsatn_vec(),
+        }
+    }
+}
+
+impl ProjectField for Row<'_> {
+    fn project(&self, field: &TupleField) -> AlgebraicValue {
+        match self {
+            Self::Ptr(ptr) => ptr.read_col(field.field_pos).unwrap(),
+            Self::Ref(val) => val.elements.get(field.field_pos).unwrap().clone(),
+        }
+    }
+}
+
+/// Each query operator returns a tuple of [RowRef]s
+#[derive(Clone)]
+pub enum Tuple<'a> {
+    /// A pointer to a row in a base table
+    Row(Row<'a>),
+    /// A temporary returned by a join operator
+    Join(Vec<Row<'a>>),
+}
+
+static_assert_size!(Tuple, 40);
+
+impl ProjectField for Tuple<'_> {
+    fn project(&self, field: &TupleField) -> AlgebraicValue {
+        match self {
+            Self::Row(row) => row.project(field),
+            Self::Join(ptrs) => field
+                .label_pos
+                .and_then(|i| ptrs.get(i))
+                .map(|ptr| ptr.project(field))
+                .unwrap(),
+        }
+    }
+}
+
+impl<'a> Tuple<'a> {
+    /// Select the tuple element at position `i`
+    fn select(self, i: usize) -> Option<Row<'a>> {
+        match self {
+            Self::Row(_) => None,
+            Self::Join(mut ptrs) => Some(ptrs.swap_remove(i)),
+        }
+    }
+
+    /// Append a [Row] to a tuple
+    fn append(self, ptr: Row<'a>) -> Self {
+        match self {
+            Self::Row(row) => Self::Join(vec![row, ptr]),
+            Self::Join(mut rows) => {
+                rows.push(ptr);
+                Self::Join(rows)
+            }
+        }
+    }
+}
+
+pub struct DeltaScanIter<'a> {
+    iter: std::slice::Iter<'a, ProductValue>,
+}
+
+impl<'a> Iterator for DeltaScanIter<'a> {
+    type Item = &'a ProductValue;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        self.iter.next()
+    }
+}
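`Tuple::project` relies on `label_pos` being a plain positional index into the join tuple. A standalone sketch of that two-level lookup (pick the row, then the column), with simplified stand-in types:

```rust
// Illustrative stand-ins: the real Tuple holds Row values, and TupleField
// also carries a symbolic `label` used by the optimizer.
#[derive(Clone)]
enum Tuple {
    Row(Vec<i64>),       // a single row: columns only
    Join(Vec<Vec<i64>>), // a join tuple: one row per relvar
}

struct TupleField {
    label_pos: Option<usize>, // which row of the join tuple (positional)
    field_pos: usize,         // which column of that row
}

fn project(tuple: &Tuple, field: &TupleField) -> i64 {
    match tuple {
        Tuple::Row(cols) => cols[field.field_pos],
        // label_pos is only available once the optimizer has computed it
        Tuple::Join(rows) => rows[field.label_pos.expect("computed after optimization")][field.field_pos],
    }
}

fn main() {
    let t = Tuple::Join(vec![vec![10, 20], vec![30, 40]]);
    // Column 1 of the second relvar in the tuple
    assert_eq!(project(&t, &TupleField { label_pos: Some(1), field_pos: 1 }), 40);
}
```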
+
+/// Execute a query plan.
+/// The actual execution is driven by `f`.
+pub fn execute_plan<T, R>(plan: &ProjectPlan, tx: &T, f: impl Fn(PlanIter) -> R) -> Result<R>
+where
+    T: Datastore + DeltaStore,
+{
+    PlanIter::build(plan, tx).map(f)
+}
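The callback shape of `execute_plan` keeps the row iterator from outliving the transaction it borrows: rather than returning the iterator, the function hands it to a caller-supplied closure and returns the closure's result. A standalone sketch of the same pattern, with a toy `Tx` standing in for the real transaction type:

```rust
struct Tx {
    rows: Vec<u32>,
}

// The iterator is only valid while `tx` is borrowed; driving execution
// through `f` keeps that borrow contained inside this call.
fn execute<R>(tx: &Tx, f: impl Fn(std::slice::Iter<'_, u32>) -> R) -> R {
    f(tx.rows.iter())
}

fn main() {
    let tx = Tx { rows: vec![1, 2, 3] };
    let sum: u32 = execute(&tx, |iter| iter.sum());
    assert_eq!(sum, 6);
}
```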
diff --git a/crates/expr/src/check.rs b/crates/expr/src/check.rs
index 5fa2cb0d7d8..4e8aed7d478 100644
--- a/crates/expr/src/check.rs
+++ b/crates/expr/src/check.rs
@@ -2,9 +2,10 @@ use std::collections::HashMap;
 use std::ops::{Deref, DerefMut};
 use std::sync::Arc;
 
-use crate::expr::{Expr, Project};
+use crate::expr::{Expr, ProjectList, ProjectName, Relvar};
 use crate::{expr::LeftDeepJoin, statement::Statement};
 use spacetimedb_lib::AlgebraicType;
+use spacetimedb_primitives::TableId;
 use spacetimedb_schema::schema::TableSchema;
 use spacetimedb_sql_parser::ast::BinOp;
 use spacetimedb_sql_parser::{
@@ -23,7 +24,12 @@ pub type TypingResult<T> = core::result::Result<T, TypingError>;
 
 /// A view of the database schema
 pub trait SchemaView {
-    fn schema(&self, name: &str) -> Option<Arc<TableSchema>>;
+    fn table_id(&self, name: &str) -> Option<TableId>;
+    fn schema_for_table(&self, table_id: TableId) -> Option<Arc<TableSchema>>;
+
+    fn schema(&self, name: &str) -> Option<Arc<TableSchema>> {
+        self.table_id(name).and_then(|table_id| self.schema_for_table(table_id))
+    }
 }
 
 #[derive(Default)]
@@ -46,21 +52,29 @@ pub trait TypeChecker {
     type Ast;
     type Set;
 
-    fn type_ast(ast: Self::Ast, tx: &impl SchemaView) -> TypingResult<Project>;
+    fn type_ast(ast: Self::Ast, tx: &impl SchemaView) -> TypingResult<ProjectList>;
 
-    fn type_set(ast: Self::Set, vars: &mut Relvars, tx: &impl SchemaView) -> TypingResult<Project>;
+    fn type_set(ast: Self::Set, vars: &mut Relvars, tx: &impl SchemaView) -> TypingResult<ProjectList>;
 
     fn type_from(from: SqlFrom, vars: &mut Relvars, tx: &impl SchemaView) -> TypingResult<RelExpr> {
         match from {
             SqlFrom::Expr(SqlIdent(name), SqlIdent(alias)) => {
                 let schema = Self::type_relvar(tx, &name)?;
                 vars.insert(alias.clone(), schema.clone());
-                Ok(RelExpr::RelVar(schema, alias))
+                Ok(RelExpr::RelVar(Relvar {
+                    schema,
+                    alias,
+                    delta: None,
+                }))
             }
             SqlFrom::Join(SqlIdent(name), SqlIdent(alias), joins) => {
                 let schema = Self::type_relvar(tx, &name)?;
                 vars.insert(alias.clone(), schema.clone());
-                let mut join = RelExpr::RelVar(schema, alias);
+                let mut join = RelExpr::RelVar(Relvar {
+                    schema,
+                    alias,
+                    delta: None,
+                });
 
                 for SqlJoin {
                     var: SqlIdent(name),
@@ -73,23 +87,26 @@ pub trait TypeChecker {
                         return Err(DuplicateName(alias.into_string()).into());
                     }
 
-                    let rhs = Self::type_relvar(tx, &name)?;
                     let lhs = Box::new(join);
-                    let var = alias;
+                    let rhs = Relvar {
+                        schema: Self::type_relvar(tx, &name)?,
+                        alias,
+                        delta: None,
+                    };
 
-                    vars.insert(var.clone(), rhs.clone());
+                    vars.insert(rhs.alias.clone(), rhs.schema.clone());
 
                     if let Some(on) = on {
                         if let Expr::BinOp(BinOp::Eq, a, b) = type_expr(vars, on, Some(&AlgebraicType::Bool))? {
                             if let (Expr::Field(a), Expr::Field(b)) = (*a, *b) {
-                                join = RelExpr::EqJoin(LeftDeepJoin { lhs, rhs, var }, a, b);
+                                join = RelExpr::EqJoin(LeftDeepJoin { lhs, rhs }, a, b);
                                 continue;
                             }
                         }
                         unreachable!("Unreachability guaranteed by parser")
                     }
 
-                    join = RelExpr::LeftDeepJoin(LeftDeepJoin { lhs, rhs, var });
+                    join = RelExpr::LeftDeepJoin(LeftDeepJoin { lhs, rhs });
                 }
 
                 Ok(join)
@@ -111,11 +128,11 @@ impl TypeChecker for SubChecker {
     type Ast = SqlSelect;
     type Set = SqlSelect;
 
-    fn type_ast(ast: Self::Ast, tx: &impl SchemaView) -> TypingResult<Project> {
+    fn type_ast(ast: Self::Ast, tx: &impl SchemaView) -> TypingResult<ProjectList> {
         Self::type_set(ast, &mut Relvars::default(), tx)
     }
 
-    fn type_set(ast: Self::Set, vars: &mut Relvars, tx: &impl SchemaView) -> TypingResult<Project> {
+    fn type_set(ast: Self::Set, vars: &mut Relvars, tx: &impl SchemaView) -> TypingResult<ProjectList> {
         match ast {
             SqlSelect {
                 project,
@@ -138,26 +155,30 @@ impl TypeChecker for SubChecker {
 }
 
 /// Parse and type check a subscription query
-pub fn parse_and_type_sub(sql: &str, tx: &impl SchemaView) -> TypingResult<Project> {
+pub fn parse_and_type_sub(sql: &str, tx: &impl SchemaView) -> TypingResult<ProjectName> {
     expect_table_type(SubChecker::type_ast(parse_subscription(sql)?, tx)?)
 }
 
+/// Type check a subscription query
+pub fn type_subscription(ast: SqlSelect, tx: &impl SchemaView) -> TypingResult<ProjectName> {
+    expect_table_type(SubChecker::type_ast(ast, tx)?)
+}
+
 /// Parse and type check a *subscription* query into a `StatementCtx`
 pub fn compile_sql_sub<'a>(sql: &'a str, tx: &impl SchemaView) -> TypingResult<StatementCtx<'a>> {
-    let expr = parse_and_type_sub(sql, tx)?;
     Ok(StatementCtx {
-        statement: Statement::Select(expr),
+        statement: Statement::Select(ProjectList::Name(parse_and_type_sub(sql, tx)?)),
         sql,
         source: StatementSource::Subscription,
     })
 }
 
 /// Returns an error if the input type is not a table type or relvar
-fn expect_table_type(expr: Project) -> TypingResult<Project> {
-    if let Project::Fields(..) = expr {
-        return Err(Unsupported::ReturnType.into());
+fn expect_table_type(expr: ProjectList) -> TypingResult<ProjectName> {
+    match expr {
+        ProjectList::Name(proj) => Ok(proj),
+        ProjectList::List(..) => Err(Unsupported::ReturnType.into()),
     }
-    Ok(expr)
 }
 
 pub mod test_utils {
@@ -182,14 +203,24 @@ pub mod test_utils {
     pub struct SchemaViewer(pub ModuleDef);
 
     impl SchemaView for SchemaViewer {
-        fn schema(&self, name: &str) -> Option<Arc<TableSchema>> {
-            self.0.table(name).map(|def| {
-                Arc::new(TableSchema::from_module_def(
-                    &self.0,
-                    def,
-                    (),
-                    TableId(if *def.name == *"t" { 0 } else { 1 }),
-                ))
+        fn table_id(&self, name: &str) -> Option<TableId> {
+            match name {
+                "t" => Some(TableId(0)),
+                "s" => Some(TableId(1)),
+                _ => None,
+            }
+        }
+
+        fn schema_for_table(&self, table_id: TableId) -> Option<Arc<TableSchema>> {
+            match table_id.idx() {
+                0 => Some((TableId(0), "t")),
+                1 => Some((TableId(1), "s")),
+                _ => None,
+            }
+            .and_then(|(table_id, name)| {
+                self.0
+                    .table(name)
+                    .map(|def| Arc::new(TableSchema::from_module_def(&self.0, def, (), table_id)))
             })
         }
     }
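The split of `SchemaView` into `table_id` plus `schema_for_table`, with `schema` as a composing default method, can be mirrored in a standalone sketch. The toy `Catalog` below is illustrative, not the real datastore:

```rust
use std::collections::HashMap;
use std::sync::Arc;

trait SchemaView {
    fn table_id(&self, name: &str) -> Option<u32>;
    fn schema_for_table(&self, table_id: u32) -> Option<Arc<String>>;

    // Name resolution composes the two primitive lookups
    fn schema(&self, name: &str) -> Option<Arc<String>> {
        self.table_id(name).and_then(|id| self.schema_for_table(id))
    }
}

struct Catalog {
    ids: HashMap<&'static str, u32>,
    schemas: HashMap<u32, Arc<String>>,
}

impl SchemaView for Catalog {
    fn table_id(&self, name: &str) -> Option<u32> {
        self.ids.get(name).copied()
    }
    fn schema_for_table(&self, table_id: u32) -> Option<Arc<String>> {
        self.schemas.get(&table_id).cloned()
    }
}

fn main() {
    let catalog = Catalog {
        ids: HashMap::from([("t", 0)]),
        schemas: HashMap::from([(0, Arc::new("schema for t".to_string()))]),
    };
    assert!(catalog.schema("t").is_some());
    assert!(catalog.schema("missing").is_none());
}
```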
diff --git a/crates/expr/src/expr.rs b/crates/expr/src/expr.rs
index 3283d6cd1c2..f8c6455e1ee 100644
--- a/crates/expr/src/expr.rs
+++ b/crates/expr/src/expr.rs
@@ -1,25 +1,66 @@
 use std::sync::Arc;
 
-use spacetimedb_lib::{AlgebraicType, AlgebraicValue};
+use spacetimedb_lib::{query::Delta, AlgebraicType, AlgebraicValue};
 use spacetimedb_primitives::TableId;
 use spacetimedb_schema::schema::TableSchema;
 use spacetimedb_sql_parser::ast::{BinOp, LogOp};
 
-/// A projection is the root of any relation expression
+/// A projection is the root of any relational expression.
+/// This type represents a projection that returns relvars.
+///
+/// For example:
+///
+/// ```sql
+/// select * from t
+/// ```
+///
+/// and
+///
+/// ```sql
+/// select t.* from t join s ...
+/// ```
 #[derive(Debug)]
-pub enum Project {
+pub enum ProjectName {
     None(RelExpr),
-    Relvar(RelExpr, Box<str>),
-    Fields(RelExpr, Vec<(Box<str>, FieldProject)>),
+    Some(RelExpr, Box<str>),
 }
 
-impl Project {
+impl ProjectName {
     /// What is the [TableId] for this projection?
     pub fn table_id(&self) -> Option<TableId> {
         match self {
-            Self::Fields(..) => None,
-            Self::Relvar(input, var) => input.table_id(Some(var.as_ref())),
             Self::None(input) => input.table_id(None),
+            Self::Some(input, var) => input.table_id(Some(var.as_ref())),
+        }
+    }
+}
+
+/// A projection is the root of any relational expression.
+/// This type represents a projection that returns fields.
+///
+/// For example:
+///
+/// ```sql
+/// select a, b from t
+/// ```
+///
+/// and
+///
+/// ```sql
+/// select t.a as x from t join s ...
+/// ```
+#[derive(Debug)]
+pub enum ProjectList {
+    Name(ProjectName),
+    List(RelExpr, Vec<(Box<str>, FieldProject)>),
+}
+
+impl ProjectList {
+    /// What is the [TableId] for this projection?
+    pub fn table_id(&self) -> Option<TableId> {
+        match self {
+            Self::List(..) => None,
+            Self::Name(proj) => proj.table_id(),
         }
     }
 }
@@ -28,7 +69,7 @@ impl Project {
 #[derive(Debug)]
 pub enum RelExpr {
     /// A relvar or table reference
-    RelVar(Arc<TableSchema>, Box<str>),
+    RelVar(Relvar),
     /// A logical select for filter
     Select(Box<RelExpr>, Expr),
     /// A left deep binary cross product
@@ -37,6 +78,15 @@ pub enum RelExpr {
     EqJoin(LeftDeepJoin, FieldProject, FieldProject),
 }
 
+/// A table reference
+#[derive(Debug)]
+pub struct Relvar {
+    pub schema: Arc<TableSchema>,
+    pub alias: Box<str>,
+    /// Does this relvar represent a delta table?
+    pub delta: Option<Delta>,
+}
+
 impl RelExpr {
     /// The number of fields this expression returns
     pub fn nfields(&self) -> usize {
@@ -50,9 +100,9 @@ impl RelExpr {
     /// Does this expression return this field?
     pub fn has_field(&self, field: &str) -> bool {
         match self {
-            Self::RelVar(_, name) => name.as_ref() == field,
+            Self::RelVar(Relvar { alias, .. }) => alias.as_ref() == field,
             Self::LeftDeepJoin(join) | Self::EqJoin(join, ..) => {
-                join.var.as_ref() == field || join.lhs.has_field(field)
+                join.rhs.alias.as_ref() == field || join.lhs.has_field(field)
             }
             Self::Select(input, _) => input.has_field(field),
         }
@@ -61,14 +111,14 @@ impl RelExpr {
     /// What is the [TableId] for this expression or relvar?
     pub fn table_id(&self, var: Option<&str>) -> Option<TableId> {
         match (self, var) {
-            (Self::RelVar(schema, _), None) => Some(schema.table_id),
-            (Self::RelVar(schema, name), Some(var)) if name.as_ref() == var => Some(schema.table_id),
-            (Self::RelVar(schema, _), Some(_)) => Some(schema.table_id),
+            (Self::RelVar(Relvar { schema, .. }), None) => Some(schema.table_id),
+            (Self::RelVar(Relvar { schema, alias, .. }), Some(var)) if alias.as_ref() == var => Some(schema.table_id),
+            (Self::RelVar(Relvar { schema, .. }), Some(_)) => Some(schema.table_id),
             (Self::Select(input, _), _) => input.table_id(var),
            (Self::LeftDeepJoin(..) | Self::EqJoin(..), None) => None,
             (Self::LeftDeepJoin(join) | Self::EqJoin(join, ..), Some(name)) => {
-                if join.var.as_ref() == name {
-                    Some(join.rhs.table_id)
+                if join.rhs.alias.as_ref() == name {
+                    Some(join.rhs.schema.table_id)
                 } else {
                     join.lhs.table_id(var)
                 }
@@ -83,9 +133,7 @@ pub struct LeftDeepJoin {
     /// The lhs is recursive
     pub lhs: Box<RelExpr>,
     /// The rhs is a relvar
-    pub rhs: Arc<TableSchema>,
-    /// The rhs relvar name
-    pub var: Box<str>,
+    pub rhs: Relvar,
 }
 
 /// A typed scalar expression
diff --git a/crates/expr/src/lib.rs b/crates/expr/src/lib.rs
index a763780e28a..fa13b349bf1 100644
--- a/crates/expr/src/lib.rs
+++ b/crates/expr/src/lib.rs
@@ -3,7 +3,7 @@ use std::collections::HashSet;
 use crate::statement::Statement;
 use check::{Relvars, TypingResult};
 use errors::{DuplicateName, InvalidLiteral, InvalidOp, InvalidWildcard, UnexpectedType, Unresolved};
-use expr::{Expr, FieldProject, Project, RelExpr};
+use expr::{Expr, FieldProject, ProjectList, ProjectName, RelExpr};
 use spacetimedb_lib::{from_hex_pad, Address, AlgebraicType, AlgebraicValue, Identity};
 use spacetimedb_schema::schema::ColumnSchema;
 use spacetimedb_sql_parser::ast::{self, BinOp, ProjectElem, SqlExpr, SqlIdent, SqlLiteral};
@@ -22,11 +22,13 @@ pub(crate) fn type_select(input: RelExpr, expr: SqlExpr, vars: &Relvars) -> Typi
 }
 
 /// Type check and lower a [ast::Project]
-pub(crate) fn type_proj(input: RelExpr, proj: ast::Project, vars: &Relvars) -> TypingResult<Project> {
+pub(crate) fn type_proj(input: RelExpr, proj: ast::Project, vars: &Relvars) -> TypingResult<ProjectList> {
     match proj {
         ast::Project::Star(None) if input.nfields() > 1 => Err(InvalidWildcard::Join.into()),
-        ast::Project::Star(None) => Ok(Project::None(input)),
-        ast::Project::Star(Some(SqlIdent(var))) if input.has_field(&var) => Ok(Project::Relvar(input, var)),
+        ast::Project::Star(None) => Ok(ProjectList::Name(ProjectName::None(input))),
+        ast::Project::Star(Some(SqlIdent(var))) if input.has_field(&var) => {
+            Ok(ProjectList::Name(ProjectName::Some(input, var)))
+        }
         ast::Project::Star(Some(SqlIdent(var))) => Err(Unresolved::var(&var).into()),
         ast::Project::Exprs(elems) => {
             let mut projections = vec![];
@@ -42,7 +44,7 @@ pub(crate) fn type_proj(input: RelExpr, proj: ast::Project, vars: &Relvars) -> T
                 }
             }
 
-            Ok(Project::Fields(input, projections))
+            Ok(ProjectList::List(input, projections))
         }
     }
 }
diff --git a/crates/expr/src/statement.rs b/crates/expr/src/statement.rs
index b9fe8ca8069..90f182457f6 100644
--- a/crates/expr/src/statement.rs
+++ b/crates/expr/src/statement.rs
@@ -12,7 +12,7 @@ use spacetimedb_sql_parser::{
 };
 use thiserror::Error;
 
-use crate::{check::Relvars, expr::Project};
+use crate::{check::Relvars, expr::ProjectList};
 
 use super::{
     check::{SchemaView, TypeChecker, TypingResult},
@@ -22,7 +22,7 @@ use super::{
 };
 
 pub enum Statement {
-    Select(Project),
+    Select(ProjectList),
     Insert(TableInsert),
     Update(TableUpdate),
     Delete(TableDelete),
@@ -240,11 +240,11 @@ impl TypeChecker for SqlChecker {
     type Ast = SqlSelect;
     type Set = SqlSelect;
 
-    fn type_ast(ast: Self::Ast, tx: &impl SchemaView) -> TypingResult<Project> {
+    fn type_ast(ast: Self::Ast, tx: &impl SchemaView) -> TypingResult<ProjectList> {
         Self::type_set(ast, &mut Relvars::default(), tx)
    }
 
-    fn type_set(ast: Self::Set, vars: &mut Relvars, tx: &impl SchemaView) -> TypingResult<Project> {
+    fn type_set(ast: Self::Set, vars: &mut Relvars, tx: &impl SchemaView) -> TypingResult<ProjectList> {
         match ast {
             SqlSelect {
                 project,
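The `ProjectName`/`ProjectList` split is what lets `expect_table_type` in check.rs reject field projections for subscriptions by construction rather than by inspection. A standalone model of that gate, with simplified stand-ins for the real types:

```rust
// Illustrative stand-ins; the real variants carry RelExpr inputs.
enum ProjectName {
    None,           // select * from t
    Some(Box<str>), // select t.* from t join s ...
}

enum ProjectList {
    Name(ProjectName),
    List(Vec<Box<str>>), // select a, b from t
}

fn expect_table_type(expr: ProjectList) -> Result<ProjectName, String> {
    match expr {
        ProjectList::Name(proj) => Ok(proj),
        ProjectList::List(..) => Err("subscriptions must return rows of a table type".into()),
    }
}

fn main() {
    assert!(expect_table_type(ProjectList::Name(ProjectName::None)).is_ok());
    assert!(expect_table_type(ProjectList::List(vec!["a".into()])).is_err());
}
```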
diff --git a/crates/lib/src/lib.rs b/crates/lib/src/lib.rs
index 578dd6682b8..643c47cd0de 100644
--- a/crates/lib/src/lib.rs
+++ b/crates/lib/src/lib.rs
@@ -11,6 +11,7 @@ pub mod db;
 pub mod error;
 pub mod identity;
 pub mod operator;
+pub mod query;
 pub mod relation;
 pub mod scheduler;
 pub mod version;
diff --git a/crates/lib/src/query.rs b/crates/lib/src/query.rs
new file mode 100644
index 00000000000..c0df5a87a58
--- /dev/null
+++ b/crates/lib/src/query.rs
@@ -0,0 +1,6 @@
+/// A type used by the query planner for incremental evaluation
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum Delta {
+    Inserts(usize),
+    Deletes(usize),
+}
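The `Delta` tag is what lets a plan distinguish a scan of a table's committed state from a scan of just its inserts or deletes. A standalone sketch of the intended effect on scan cost; the `rows_to_scan` helper is illustrative and not part of the crate:

```rust
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Delta {
    Inserts(usize),
    Deletes(usize),
}

fn rows_to_scan(table_len: usize, delta: Option<Delta>) -> usize {
    match delta {
        // No delta annotation: scan the committed table state
        None => table_len,
        // Delta table: scan only the changed rows
        Some(Delta::Inserts(n)) | Some(Delta::Deletes(n)) => n,
    }
}

fn main() {
    // Incremental evaluation touches 2 rows instead of a million
    assert_eq!(rows_to_scan(1_000_000, Some(Delta::Inserts(2))), 2);
    assert_eq!(rows_to_scan(1_000_000, None), 1_000_000);
}
```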
diff --git a/crates/physical-plan/Cargo.toml b/crates/physical-plan/Cargo.toml
index 247d06a11ca..fb9f62a4a98 100644
--- a/crates/physical-plan/Cargo.toml
+++ b/crates/physical-plan/Cargo.toml
@@ -7,9 +7,14 @@ license-file = "LICENSE"
 description = "The physical query plan for the SpacetimeDB query engine"
 
 [dependencies]
+anyhow.workspace = true
 derive_more.workspace = true
 spacetimedb-lib.workspace = true
 spacetimedb-primitives.workspace = true
 spacetimedb-schema.workspace = true
 spacetimedb-expr.workspace = true
 spacetimedb-sql-parser.workspace = true
+spacetimedb-table.workspace = true
+
+[dev-dependencies]
+pretty_assertions.workspace = true
diff --git a/crates/physical-plan/src/compile.rs b/crates/physical-plan/src/compile.rs
index 160ada0ed8c..d79ed4094c1 100644
--- a/crates/physical-plan/src/compile.rs
+++ b/crates/physical-plan/src/compile.rs
@@ -2,8 +2,11 @@
 
 use std::collections::HashMap;
 
-use crate::plan::{HashJoin, Label, PhysicalCtx, PhysicalExpr, PhysicalPlan, PhysicalProject, ProjectField, Semi};
-use spacetimedb_expr::expr::{Expr, FieldProject, LeftDeepJoin, Project, RelExpr};
+use crate::plan::{
+    HashJoin, Label, PhysicalCtx, PhysicalExpr, PhysicalPlan, ProjectListPlan, ProjectPlan, Semi, TupleField,
+};
+
+use spacetimedb_expr::expr::{Expr, FieldProject, LeftDeepJoin, ProjectList, ProjectName, RelExpr, Relvar};
 use spacetimedb_expr::statement::Statement;
 use spacetimedb_expr::StatementCtx;
 
@@ -24,13 +27,12 @@ fn compile_expr(expr: Expr, var: &mut impl VarLabel) -> PhysicalExpr {
     }
 }
 
-fn compile_project(var: &mut impl VarLabel, expr: Project) -> PhysicalProject {
+fn compile_project_list(var: &mut impl VarLabel, expr: ProjectList) -> ProjectListPlan {
     match expr {
-        Project::None(input) => PhysicalProject::None(compile_rel_expr(var, input)),
-        Project::Relvar(input, name) => PhysicalProject::Relvar(compile_rel_expr(var, input), var.label(&name)),
-        Project::Fields(input, exprs) => PhysicalProject::Fields(
-            compile_rel_expr(var, input),
-            exprs
+        ProjectList::Name(proj) => ProjectListPlan::Name(compile_project_name(var, proj)),
+        ProjectList::List(proj, fields) => ProjectListPlan::List(
+            compile_rel_expr(var, proj),
+            fields
                 .into_iter()
                 .map(|(alias, expr)| (alias, compile_field_project(var, expr)))
                 .collect(),
@@ -38,49 +40,75 @@ fn compile_project(var: &mut impl VarLabel, expr: Project) -> PhysicalProject {
     }
 }
 
-fn compile_field_project(var: &mut impl VarLabel, expr: FieldProject) -> ProjectField {
-    ProjectField {
-        var: var.label(&expr.table),
-        pos: expr.field,
+fn compile_project_name(var: &mut impl VarLabel, proj: ProjectName) -> ProjectPlan {
+    match proj {
+        ProjectName::None(input) => ProjectPlan::None(compile_rel_expr(var, input)),
+        ProjectName::Some(input, name) => ProjectPlan::Name(compile_rel_expr(var, input), var.label(&name), None),
+    }
+}
+
+fn compile_field_project(var: &mut impl VarLabel, expr: FieldProject) -> TupleField {
+    TupleField {
+        label: var.label(&expr.table),
+        label_pos: None,
+        field_pos: expr.field,
     }
 }
 
 fn compile_rel_expr(var: &mut impl VarLabel, ast: RelExpr) -> PhysicalPlan {
     match ast {
-        RelExpr::RelVar(table, name) => {
-            let label = var.label(name.as_ref());
-            PhysicalPlan::TableScan(table, label)
+        RelExpr::RelVar(Relvar { schema, alias, delta }) => {
+            let label = var.label(alias.as_ref());
+            PhysicalPlan::TableScan(schema, label, delta)
         }
         RelExpr::Select(input, expr) => {
             let input = compile_rel_expr(var, *input);
             let input = Box::new(input);
             PhysicalPlan::Filter(input, compile_expr(expr, var))
         }
-        RelExpr::EqJoin(join, FieldProject { table: u, field: a, .. }, FieldProject { table: v, field: b, .. }) => {
-            PhysicalPlan::HashJoin(
-                HashJoin {
-                    lhs: Box::new(compile_rel_expr(var, *join.lhs)),
-                    rhs: Box::new(PhysicalPlan::TableScan(join.rhs, var.label(&join.var))),
-                    lhs_field: ProjectField {
-                        var: var.label(u.as_ref()),
-                        pos: a,
+        RelExpr::EqJoin(
+            LeftDeepJoin {
+                lhs,
+                rhs:
+                    Relvar {
+                        schema: rhs_schema,
+                        alias: rhs_alias,
+                        delta,
+                        ..
                     },
-                    rhs_field: ProjectField {
-                        var: var.label(v.as_ref()),
-                        pos: b,
-                    },
-                    unique: false,
+            },
+            FieldProject { table: u, field: a, .. },
+            FieldProject { table: v, field: b, .. },
+        ) => PhysicalPlan::HashJoin(
+            HashJoin {
+                lhs: Box::new(compile_rel_expr(var, *lhs)),
+                rhs: Box::new(PhysicalPlan::TableScan(rhs_schema, var.label(&rhs_alias), delta)),
+                lhs_field: TupleField {
+                    label: var.label(u.as_ref()),
+                    label_pos: None,
+                    field_pos: a,
                 },
-                Semi::All,
-            )
-        }
+                rhs_field: TupleField {
+                    label: var.label(v.as_ref()),
+                    label_pos: None,
+                    field_pos: b,
+                },
+                unique: false,
+            },
+            Semi::All,
+        ),
         RelExpr::LeftDeepJoin(LeftDeepJoin {
             lhs,
-            rhs,
-            var: rhs_name,
+            rhs:
+                Relvar {
+                    schema: rhs_schema,
+                    alias: rhs_alias,
+                    delta,
+                    ..
+                },
         }) => {
             let lhs = compile_rel_expr(var, *lhs);
-            let rhs = PhysicalPlan::TableScan(rhs, var.label(rhs_name.as_ref()));
+            let rhs = PhysicalPlan::TableScan(rhs_schema, var.label(&rhs_alias), delta);
             let lhs = Box::new(lhs);
             let rhs = Box::new(rhs);
             PhysicalPlan::NLJoin(lhs, rhs)
@@ -88,6 +116,31 @@ fn compile_rel_expr(var: &mut impl VarLabel, ast: RelExpr) -> PhysicalPlan {
     }
 }
 
+/// Compile a logical subscribe expression
+pub fn compile_project_plan(project: ProjectName) -> ProjectPlan {
+    struct Interner {
+        next: usize,
+        names: HashMap<String, usize>,
+    }
+    impl VarLabel for Interner {
+        fn label(&mut self, name: &str) -> Label {
+            if let Some(id) = self.names.get(name) {
+                return Label(*id);
+            }
+            self.next += 1;
+            self.names.insert(name.to_owned(), self.next);
+            self.next.into()
+        }
+    }
+    compile_project_name(
+        &mut Interner {
+            next: 0,
+            names: HashMap::new(),
+        },
+        project,
+    )
+}
+
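`compile_project_plan` interns aliases with a throwaway `Interner` so that equal aliases always map to the same `Label` while distinct aliases never collide. The same idea in a standalone, runnable form (stand-in types; the real interner returns `Label` values):

```rust
use std::collections::HashMap;

struct Interner {
    next: usize,
    names: HashMap<String, usize>,
}

impl Interner {
    fn label(&mut self, name: &str) -> usize {
        // Return the existing label for a known alias
        if let Some(id) = self.names.get(name) {
            return *id;
        }
        // Otherwise mint a fresh one
        self.next += 1;
        self.names.insert(name.to_owned(), self.next);
        self.next
    }
}

fn main() {
    let mut interner = Interner { next: 0, names: HashMap::new() };
    let t = interner.label("t");
    let s = interner.label("s");
    assert_ne!(t, s);                   // distinct aliases, distinct labels
    assert_eq!(t, interner.label("t")); // same alias, same label
}
```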
 /// Compile a SQL statement into a physical plan.
 ///
 /// The input [Statement] is assumed to be valid so the lowering is not expected to fail.
@@ -109,7 +162,7 @@ pub fn compile(ast: StatementCtx<'_>) -> PhysicalCtx<'_> {
         }
     }
     let plan = match ast.statement {
-        Statement::Select(expr) => compile_project(
+        Statement::Select(expr) => compile_project_list(
             &mut Interner {
                 next: 0,
                 names: HashMap::new(),
diff --git a/crates/physical-plan/src/lib.rs b/crates/physical-plan/src/lib.rs
index b79989e66e4..b1ea3c965a4 100644
--- a/crates/physical-plan/src/lib.rs
+++ b/crates/physical-plan/src/lib.rs
@@ -1,2 +1,3 @@
 pub mod compile;
 pub mod plan;
+pub mod rules;
diff --git a/crates/physical-plan/src/plan.rs b/crates/physical-plan/src/plan.rs
index 99e10b1e914..66ca4a703da 100644
--- a/crates/physical-plan/src/plan.rs
+++ b/crates/physical-plan/src/plan.rs
@@ -1,52 +1,155 @@
-use std::{ops::Bound, sync::Arc};
+use std::{
+    borrow::Cow,
+    ops::{Bound, Deref, DerefMut},
+    sync::Arc,
+};
 
 use derive_more::From;
 use spacetimedb_expr::StatementSource;
-use spacetimedb_lib::AlgebraicValue;
+use spacetimedb_lib::{query::Delta, AlgebraicValue};
 use spacetimedb_primitives::{ColId, ColSet, IndexId};
 use spacetimedb_schema::schema::{IndexSchema, TableSchema};
 use spacetimedb_sql_parser::ast::{BinOp, LogOp};
+use spacetimedb_table::table::RowRef;
+
+use crate::rules::{
+    ComputePositions, HashToIxJoin, IxScanAnd, IxScanEq, IxScanEq2Col, IxScanEq3Col, PullFilterAboveHashJoin,
+    PushConstAnd, PushConstEq, ReorderDeltaJoinRhs, ReorderHashJoin, RewriteRule, UniqueHashJoinRule, UniqueIxJoinRule,
+};
 
 /// Table aliases are replaced with labels in the physical plan
 #[derive(Debug, Clone, Copy, PartialEq, Eq, From)]
 pub struct Label(pub usize);
 
-/// Physical query plans always terminate with a projection
-#[derive(Debug, PartialEq, Eq)]
-pub enum PhysicalProject {
+/// Physical plans always terminate with a projection.
+/// This type of projection returns row ids.
+///
+/// It can represent:
+///
+/// ```sql
+/// select * from t
+/// ```
+///
+/// and
+///
+/// ```sql
+/// select t.* from t join ...
+/// ```
+///
+/// but not
+///
+/// ```sql
+/// select a from t
+/// ```
+#[derive(Debug, Clone)]
+pub enum ProjectPlan {
     None(PhysicalPlan),
-    Relvar(PhysicalPlan, Label),
-    Fields(PhysicalPlan, Vec<(Box<str>, ProjectField)>),
+    Name(PhysicalPlan, Label, Option<usize>),
+}
+
+impl Deref for ProjectPlan {
+    type Target = PhysicalPlan;
+
+    fn deref(&self) -> &Self::Target {
+        match self {
+            Self::None(plan) | Self::Name(plan, ..) => plan,
+        }
+    }
+}
+
+impl DerefMut for ProjectPlan {
+    fn deref_mut(&mut self) -> &mut Self::Target {
+        match self {
+            Self::None(plan) | Self::Name(plan, ..) => plan,
+        }
+    }
 }
 
-impl PhysicalProject {
+impl ProjectPlan {
     pub fn optimize(self) -> Self {
         match self {
             Self::None(plan) => Self::None(plan.optimize(vec![])),
-            Self::Relvar(plan, var) => Self::None(plan.optimize(vec![var])),
-            Self::Fields(plan, fields) => {
-                Self::Fields(plan.optimize(fields.iter().map(|(_, proj)| proj.var).collect()), fields)
+            Self::Name(plan, label, _) => {
+                let plan = plan.optimize(vec![label]);
+                let n = plan.nfields();
+                let pos = plan.position(&label);
+                match n {
+                    1 => Self::None(plan),
+                    _ => Self::Name(plan, label, pos),
+                }
             }
         }
     }
 }
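After all rewrites, `ProjectPlan::optimize` resolves the projected label to a tuple position via `plan.position(&label)`, and collapses to `ProjectPlan::None` when the plan returns a single relvar. A standalone sketch of that resolution step (illustrative helper, not the crate's API):

```rust
// Mirrors `plan.position(&label)`: find the label's slot in the output tuple.
fn resolve(labels_in_plan: &[u32], projected: u32) -> Option<usize> {
    labels_in_plan.iter().position(|l| *l == projected)
}

fn main() {
    // A two-table join returns a 2-tuple; label 7 sits at position 1,
    // so projecting label 7 reads tuple element 1 at runtime.
    assert_eq!(resolve(&[3, 7], 7), Some(1));
    // A single-relvar plan: the projection is a no-op and can be dropped.
    assert_eq!(resolve(&[3], 3), Some(0));
}
```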
 
-#[derive(Debug, PartialEq, Eq)]
-pub struct ProjectField {
-    pub var: Label,
-    pub pos: usize,
+/// Physical plans always terminate with a projection.
+/// This type can project fields within a table.
+///
+/// That is, it can represent:
+///
+/// ```sql
+/// select a from t
+/// ```
+///
+/// as well as
+///
+/// ```sql
+/// select t.a, s.b from t join s ...
+/// ```
+#[derive(Debug)]
+pub enum ProjectListPlan {
+    Name(ProjectPlan),
+    List(PhysicalPlan, Vec<(Box<str>, TupleField)>),
+}
+
+impl ProjectListPlan {
+    pub fn optimize(self) -> Self {
+        match self {
+            Self::Name(plan) => Self::Name(plan.optimize()),
+            Self::List(plan, fields) => Self::List(
+                plan.optimize(
+                    fields
+                        .iter()
+                        .map(|(_, TupleField { label, .. })| label)
+                        .copied()
+                        .collect(),
+                ),
+                fields,
+            ),
+        }
+    }
+}
+
+/// Query operators return tuples of rows.
+/// And this type refers to a field of a row within a tuple.
+///
+/// Note that from the perspective of the optimizer,
+/// tuple elements have names or labels,
+/// so as to preserve query semantics across rewrites.
+///
+/// However from the perspective of the query engine,
+/// tuple elements are entirely positional.
+/// Hence the need for both `label` and `label_pos`.
+///
+/// The former is consistent across rewrites.
+/// The latter is only computed once after optimization.
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct TupleField {
+    pub label: Label,
+    pub label_pos: Option<usize>,
+    pub field_pos: usize,
 }
 
 /// A physical plan represents a concrete evaluation strategy.
-#[derive(Debug, PartialEq, Eq)]
+#[derive(Debug, Clone, PartialEq, Eq)]
 pub enum PhysicalPlan {
     /// Scan a table row by row, returning row ids
-    TableScan(Arc<TableSchema>, Label),
+    TableScan(Arc<TableSchema>, Label, Option<Delta>),
     /// Fetch row ids from an index
     IxScan(IxScan, Label),
     /// An index join + projection
     IxJoin(IxJoin, Semi),
-    /// An hash join + projection
+    /// A hash join + projection
     HashJoin(HashJoin, Semi),
     /// A nested loop join
     NLJoin(Box<PhysicalPlan>, Box<PhysicalPlan>),
@@ -70,6 +173,21 @@ impl PhysicalPlan {
         }
     }
 
+    /// Walks the plan tree and calls `f` on every op
+    pub fn visit_mut(&mut self, f: &mut impl FnMut(&mut Self)) {
+        f(self);
+        match self {
+            Self::IxJoin(IxJoin { lhs: input, .. }, _) | Self::Filter(input, _) => {
+                input.visit_mut(f);
+            }
+            Self::NLJoin(lhs, rhs) | Self::HashJoin(HashJoin { lhs, rhs, .. }, _) => {
+                lhs.visit_mut(f);
+                rhs.visit_mut(f);
+            }
+            _ => {}
+        }
+    }
+
     /// Is there any subplan where `f` returns true?
     pub fn any(&self, f: &impl Fn(&Self) -> bool) -> bool {
         let mut ok = false;
@@ -188,22 +306,29 @@ impl PhysicalPlan {
         plan
     }
 
-    /// Optimize a physical plan by applying rewrite rules.
+    /// Optimize a plan using the following rewrites:
     ///
-    /// First we canonicalize the plan.
-    /// Next we push filters to the leaves.
-    /// Then we try to turn those filters into index scans.
-    /// And finally we deterimine the index joins and semijoins.
+    /// 1. Canonicalize the plan
+    /// 2. Push filters to the leaves
+    /// 3. Turn filters into index scans if possible
+    /// 4. Determine index and semijoins
+    /// 5. Compute positions for tuple labels
     pub fn optimize(self, reqs: Vec