diff --git a/Cargo.lock b/Cargo.lock index 730192bf..525ce448 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2560,6 +2560,7 @@ dependencies = [ "prost", "pyo3", "pyo3-build-config", + "rand", "ref-cast", "regex", "risinglight_proto", diff --git a/Cargo.toml b/Cargo.toml index 4e83e4e8..0dfa8e2a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -57,6 +57,7 @@ pin-project = "1" pretty-xmlish = "0.1" prost = "0.13" pyo3 = { version = "0.22", features = ["extension-module"], optional = true } +rand = "0.8" ref-cast = "1.0" regex = "1" risinglight_proto = "0.2" diff --git a/src/array/internal_ext.rs b/src/array/internal_ext.rs index 8d3bf963..650c1c81 100644 --- a/src/array/internal_ext.rs +++ b/src/array/internal_ext.rs @@ -12,6 +12,11 @@ use crate::for_all_variants; pub trait ArrayValidExt: Array { fn get_valid_bitmap(&self) -> &BitVec; fn get_valid_bitmap_mut(&mut self) -> &mut BitVec; + + /// Returns the number of null values in this array. + fn null_count(&self) -> usize { + self.get_valid_bitmap().count_zeros() + } } pub trait ArrayImplValidExt { diff --git a/src/array/ops.rs b/src/array/ops.rs index 581783d6..6c041bd2 100644 --- a/src/array/ops.rs +++ b/src/array/ops.rs @@ -3,6 +3,7 @@ //! Array operations. use std::borrow::Borrow; +use std::hash::{Hash, Hasher}; use num_traits::ToPrimitive; use regex::Regex; @@ -204,6 +205,26 @@ impl ArrayImpl { Ok(A::new_bool(clear_null(unary_op(a.as_ref(), |b| !b)))) } + /// Hash the array into the given hasher. + pub fn hash_to(&self, hasher: &mut [impl Hasher]) { + assert_eq!(hasher.len(), self.len()); + match self { + A::Null(a) => a.iter().zip(hasher).for_each(|(v, h)| v.hash(h)), + A::Bool(a) => a.iter().zip(hasher).for_each(|(v, h)| v.hash(h)), + A::Int16(a) => a.iter().zip(hasher).for_each(|(v, h)| v.hash(h)), + A::Int32(a) => a.iter().zip(hasher).for_each(|(v, h)| v.hash(h)), + A::Int64(a) => a.iter().zip(hasher).for_each(|(v, h)| v.hash(h)), + A::Float64(a) => a.iter().zip(hasher).for_each(|(v, h)| v.hash(h)), + A::Decimal(a) => a.iter().zip(hasher).for_each(|(v, h)| v.hash(h)), + A::String(a) => a.iter().zip(hasher).for_each(|(v, h)| v.hash(h)), + A::Date(a) => a.iter().zip(hasher).for_each(|(v, h)| v.hash(h)), + A::Timestamp(a) => a.iter().zip(hasher).for_each(|(v, h)| v.hash(h)), + A::TimestampTz(a) => a.iter().zip(hasher).for_each(|(v, h)| v.hash(h)), + A::Interval(a) => a.iter().zip(hasher).for_each(|(v, h)| v.hash(h)), + A::Blob(a) => a.iter().zip(hasher).for_each(|(v, h)| v.hash(h)), + } + } + pub fn like(&self, pattern: &str) -> Result { /// Converts a SQL LIKE pattern to a regex pattern. fn like_to_regex(pattern: &str) -> String { @@ -600,12 +621,48 @@ impl ArrayImpl { /// Returns the sum of values. 
pub fn sum(&self) -> DataValue { match self { - Self::Int16(a) => DataValue::Int16(a.raw_iter().sum()), - Self::Int32(a) => DataValue::Int32(a.raw_iter().sum()), - Self::Int64(a) => DataValue::Int64(a.raw_iter().sum()), - Self::Float64(a) => DataValue::Float64(a.raw_iter().sum()), - Self::Decimal(a) => DataValue::Decimal(a.raw_iter().sum()), - Self::Interval(a) => DataValue::Interval(a.raw_iter().sum()), + Self::Int16(a) => { + if a.null_count() == a.len() { + DataValue::Null + } else { + DataValue::Int16(a.raw_iter().sum()) + } + } + Self::Int32(a) => { + if a.null_count() == a.len() { + DataValue::Null + } else { + DataValue::Int32(a.raw_iter().sum()) + } + } + Self::Int64(a) => { + if a.null_count() == a.len() { + DataValue::Null + } else { + DataValue::Int64(a.raw_iter().sum()) + } + } + Self::Float64(a) => { + if a.null_count() == a.len() { + DataValue::Null + } else { + DataValue::Float64(a.raw_iter().sum()) + } + } + Self::Decimal(a) => { + if a.null_count() == a.len() { + DataValue::Null + } else { + DataValue::Decimal(a.raw_iter().sum()) + } + } + Self::Interval(a) => { + if a.null_count() == a.len() { + DataValue::Null + } else { + DataValue::Interval(a.raw_iter().sum()) + } + } _ => panic!("can not sum array"), } } diff --git a/src/binder/create_function.rs b/src/binder/create_function.rs index 71cfd2fa..2bf10ffb 100644 --- a/src/binder/create_function.rs +++ b/src/binder/create_function.rs @@ -10,7 +10,7 @@ use serde::{Deserialize, Serialize}; use super::*; #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)] -pub struct CreateFunction { +pub struct FunctionDef { pub schema_name: String, pub name: String, pub arg_types: Vec, @@ -20,14 +20,14 @@ pub struct CreateFunction { pub body: String, } -impl fmt::Display for CreateFunction { +impl fmt::Display for Box { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let explainer = Pretty::childless_record("CreateFunction", self.pretty_function()); + let explainer = Pretty::childless_record("FunctionDef", self.pretty_function()); delegate_fmt(&explainer, f, String::with_capacity(1000)) } } -impl FromStr for CreateFunction { +impl FromStr for Box { type Err = (); fn from_str(_s: &str) -> std::result::Result { @@ -35,7 +35,7 @@ impl FromStr for CreateFunction { } } -impl CreateFunction { +impl FunctionDef { pub fn pretty_function<'a>(&self) -> Vec<(&'a str, Pretty<'a>)> { vec![ ("name", Pretty::display(&self.name)), @@ -102,7 +102,7 @@ impl Binder { arg_names.push(arg.name.map_or("".to_string(), |n| n.to_string())); } - let f = self.egraph.add(Node::CreateFunction(CreateFunction { + let func_def = self.egraph.add(Node::FunctionDef(Box::new(FunctionDef { schema_name, name, arg_types, @@ -110,8 +110,8 @@ impl Binder { return_type, language, body, - })); - - Ok(f) + }))); + let id = self.egraph.add(Node::CreateFunction(func_def)); + Ok(id) } } diff --git a/src/binder/create_table.rs b/src/binder/create_table.rs index 6c86213c..2bf58fba 100644 --- a/src/binder/create_table.rs +++ b/src/binder/create_table.rs @@ -12,22 +12,22 @@ use super::*; use crate::catalog::{ColumnCatalog, ColumnDesc, ColumnId, SchemaId}; #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)] -pub struct CreateTable { +pub struct TableDef { pub schema_id: SchemaId, pub table_name: String, pub columns: Vec, pub ordered_pk_ids: Vec, } -impl fmt::Display for CreateTable { +impl fmt::Display for TableDef { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let explainer = 
Pretty::childless_record("CreateTable", self.pretty_table()); + let explainer = Pretty::childless_record("TableDef", self.pretty_table()); delegate_fmt(&explainer, f, String::with_capacity(1000)) } } -impl CreateTable { - pub fn pretty_table<'a>(&self) -> Vec<(&'a str, Pretty<'a>)> { +impl TableDef { + pub fn pretty_table(&self) -> Vec<(&str, Pretty)> { let cols = Pretty::Array(self.columns.iter().map(|c| c.desc().pretty()).collect()); let ids = Pretty::Array(self.ordered_pk_ids.iter().map(Pretty::display).collect()); vec![ @@ -39,7 +39,7 @@ impl CreateTable { } } -impl FromStr for Box { +impl FromStr for Box { type Err = (); fn from_str(_s: &str) -> std::result::Result { @@ -119,13 +119,14 @@ impl Binder { columns[index as usize].set_nullable(false); } - let create = self.egraph.add(Node::CreateTable(Box::new(CreateTable { + let table_def = self.egraph.add(Node::TableDef(Box::new(TableDef { schema_id: schema.id(), table_name: table_name.into(), columns, ordered_pk_ids, }))); - Ok(create) + let id = self.egraph.add(Node::CreateTable(table_def)); + Ok(id) } /// get primary keys' id in declared order。 diff --git a/src/binder/create_view.rs b/src/binder/create_view.rs index 2b151307..30fd0278 100644 --- a/src/binder/create_view.rs +++ b/src/binder/create_view.rs @@ -49,13 +49,13 @@ impl Binder { }) .collect(); - let table = self.egraph.add(Node::CreateTable(Box::new(CreateTable { + let table_def = self.egraph.add(Node::TableDef(Box::new(TableDef { schema_id: schema.id(), table_name: table_name.into(), columns, ordered_pk_ids: vec![], }))); - let create_view = self.egraph.add(Node::CreateView([table, query])); + let create_view = self.egraph.add(Node::CreateView([table_def, query])); Ok(create_view) } } diff --git a/src/binder/expr.rs b/src/binder/expr.rs index 82c984f1..74e56f54 100644 --- a/src/binder/expr.rs +++ b/src/binder/expr.rs @@ -238,7 +238,7 @@ impl Binder { fn bind_extract(&mut self, field: DateTimeField, expr: Expr) -> Result { let expr = self.bind_expr(expr)?; - let field = self.egraph.add(Node::Field(field.into())); + let field = self.egraph.add(Node::Field(Box::new(field.into()))); Ok(self.egraph.add(Node::Extract([field, expr]))) } diff --git a/src/db.rs b/src/db.rs index 181cb4f8..665aca10 100644 --- a/src/db.rs +++ b/src/db.rs @@ -27,8 +27,12 @@ pub struct Database { /// The configuration of the database. #[derive(Debug, Default)] struct Config { + /// If true, no optimization will be applied to the query. disable_optimizer: bool, mock_stat: Option, + /// The number of partitions of each operator. + /// If set to 0, it will be automatically determined by the number of worker threads. + parallelism: usize, } impl Database { @@ -93,6 +97,11 @@ impl Database { crate::planner::Config { enable_range_filter_scan: self.storage.support_range_filter_scan(), table_is_sorted_by_primary_key: self.storage.table_is_sorted_by_primary_key(), + parallelism: if self.config.lock().unwrap().parallelism > 0 { + self.config.lock().unwrap().parallelism + } else { + tokio::runtime::Handle::current().metrics().num_workers() + }, }, ); @@ -158,19 +167,13 @@ impl Database { /// Mock the row count of a table for planner test. fn handle_set(&self, stmt: &Statement) -> Result { if let Statement::Pragma { name, .. 
} = stmt { + let mut config = self.config.lock().unwrap(); match name.to_string().as_str() { - "enable_optimizer" => { - self.config.lock().unwrap().disable_optimizer = false; - return Ok(true); - } - "disable_optimizer" => { - self.config.lock().unwrap().disable_optimizer = true; - return Ok(true); - } - name => { - return Err(crate::binder::BindError::NoPragma(name.into()).into()); - } + "enable_optimizer" => config.disable_optimizer = false, + "disable_optimizer" => config.disable_optimizer = true, + name => return Err(crate::binder::BindError::NoPragma(name.into()).into()), } + return Ok(true); } let Statement::SetVariable { variable, value, .. @@ -178,6 +181,14 @@ impl Database { else { return Ok(false); }; + if variable.0[0].value == "parallelism" { + let mut config = self.config.lock().unwrap(); + config.parallelism = value[0] + .to_string() + .parse::() + .map_err(|_| Error::Internal("invalid parallelism".into()))?; + return Ok(true); + } let Some(table_name) = variable.0[0].value.strip_prefix("mock_rowcount_") else { return Ok(false); }; @@ -202,6 +213,11 @@ impl Database { fn pragma_options() -> &'static [&'static str] { &["enable_optimizer", "disable_optimizer"] } + + /// Return all available set variables. + fn set_variables() -> &'static [&'static str] { + &["parallelism"] + } } /// The error type of database operations. @@ -268,6 +284,19 @@ impl rustyline::completion::Completer for &Database { return Ok((pos - last_word.len(), candidates)); } + // completion for set variable + if prefix.trim().eq_ignore_ascii_case("set") { + let candidates = Database::set_variables() + .iter() + .filter(|option| option.starts_with(last_word)) + .map(|option| rustyline::completion::Pair { + display: option.to_string(), + replacement: option.to_string(), + }) + .collect(); + return Ok((pos - last_word.len(), candidates)); + } + // TODO: complete table and column names // completion for keywords diff --git a/src/executor/analyze.rs b/src/executor/analyze.rs index 1521c99b..03a21095 100644 --- a/src/executor/analyze.rs +++ b/src/executor/analyze.rs @@ -1,7 +1,5 @@ // Copyright 2024 RisingLight Project Authors. Licensed under Apache-2.0. -use std::sync::atomic::{AtomicU64, Ordering}; - use pretty_xmlish::PrettyConfig; use super::*; @@ -26,10 +24,16 @@ impl AnalyzeExecutor { // explain the plan let get_metadata = |id| { - vec![ - ("rows", self.metrics.get_rows(id).to_string()), - ("time", format!("{:?}", self.metrics.get_time(id))), - ] + let mut metadata = Vec::new(); + if let Some(rows) = self.metrics.get_rows(id) { + let total = rows.iter().sum::(); + metadata.push(("rows", format!("{total} = {rows:?}"))); + } + if let Some(time) = self.metrics.get_time(id) { + let max = time.iter().max().unwrap(); + metadata.push(("time", format!("{max:?} = {time:?}"))); + } + metadata }; let explain_obj = Explain::of(&self.plan) .with_catalog(&self.catalog) @@ -50,44 +54,38 @@ impl AnalyzeExecutor { } /// A collection of profiling information for a query. -#[derive(Default)] +#[derive(Default, Debug)] pub struct Metrics { - spans: HashMap, - rows: HashMap, + spans: HashMap>, + rows: HashMap>, } impl Metrics { - /// Register metrics for a node. - pub fn register(&mut self, id: Id, span: TimeSpan, rows: Counter) { - self.spans.insert(id, span); - self.rows.insert(id, rows); + /// Create metrics for a node. 
+ pub fn add( + &mut self, + id: Id, + num_spans: usize, + num_counters: usize, + ) -> (Vec, Vec) { + let spans = (0..num_spans).map(|_| TimeSpan::default()).collect_vec(); + let counters = (0..num_counters).map(|_| Counter::default()).collect_vec(); + self.spans.insert(id, spans.clone()); + self.rows.insert(id, counters.clone()); + (spans, counters) } /// Get the running time for a node. - pub fn get_time(&self, id: Id) -> Duration { - self.spans.get(&id).map(|span| span.busy_time()).unwrap() + pub fn get_time(&self, id: Id) -> Option> { + self.spans + .get(&id) + .map(|spans| spans.iter().map(|span| span.busy_time()).collect()) } /// Get the number of rows produced by a node. - pub fn get_rows(&self, id: Id) -> u64 { - self.rows.get(&id).map(|rows| rows.get()).unwrap() - } -} - -/// A counter. -#[derive(Default, Clone)] -pub struct Counter { - count: Arc, -} - -impl Counter { - /// Increments the counter. - pub fn inc(&self, value: u64) { - self.count.fetch_add(value, Ordering::Relaxed); - } - - /// Gets the current value of the counter. - pub fn get(&self) -> u64 { - self.count.load(Ordering::Relaxed) + pub fn get_rows(&self, id: Id) -> Option> { + self.rows + .get(&id) + .map(|rows| rows.iter().map(|counter| counter.get()).collect()) } } diff --git a/src/executor/create_function.rs b/src/executor/create_function.rs index cde51bcf..d0d8ba58 100644 --- a/src/executor/create_function.rs +++ b/src/executor/create_function.rs @@ -1,19 +1,19 @@ // Copyright 2024 RisingLight Project Authors. Licensed under Apache-2.0. use super::*; -use crate::binder::CreateFunction; +use crate::binder::FunctionDef; use crate::catalog::RootCatalogRef; /// The executor of `create function` statement. pub struct CreateFunctionExecutor { - pub f: CreateFunction, + pub function: Box, pub catalog: RootCatalogRef, } impl CreateFunctionExecutor { #[try_stream(boxed, ok = DataChunk, error = ExecutorError)] pub async fn execute(self) { - let CreateFunction { + let FunctionDef { schema_name, name, arg_types, @@ -21,7 +21,7 @@ impl CreateFunctionExecutor { return_type, language, body, - } = self.f; + } = *self.function; self.catalog.create_function( schema_name.clone(), diff --git a/src/executor/create_table.rs b/src/executor/create_table.rs index e0848b4d..40064d67 100644 --- a/src/executor/create_table.rs +++ b/src/executor/create_table.rs @@ -3,12 +3,12 @@ use std::sync::Arc; use super::*; -use crate::binder::CreateTable; +use crate::binder::TableDef; use crate::storage::Storage; /// The executor of `create table` statement. pub struct CreateTableExecutor { - pub table: Box, + pub table: Box, pub storage: Arc, } diff --git a/src/executor/create_view.rs b/src/executor/create_view.rs index 60d5d794..9966d37b 100644 --- a/src/executor/create_view.rs +++ b/src/executor/create_view.rs @@ -1,12 +1,12 @@ // Copyright 2024 RisingLight Project Authors. Licensed under Apache-2.0. use super::*; -use crate::binder::CreateTable; +use crate::binder::TableDef; use crate::catalog::RootCatalogRef; /// The executor of `create view` statement. pub struct CreateViewExecutor { - pub table: Box, + pub table: Box, pub query: RecExpr, pub catalog: RootCatalogRef, } diff --git a/src/executor/evaluator.rs b/src/executor/evaluator.rs index f39690d0..5a278b72 100644 --- a/src/executor/evaluator.rs +++ b/src/executor/evaluator.rs @@ -305,14 +305,18 @@ trait Ext { } impl Ext for DataValue { + /// Add two values. The result is null only if both values are null. 
fn add(self, other: Self) -> Self { if self.is_null() { other + } else if other.is_null() { + self } else { self + other } } + /// Returns the first non-null value. fn or(self, other: Self) -> Self { if self.is_null() { other diff --git a/src/executor/exchange.rs b/src/executor/exchange.rs new file mode 100644 index 00000000..145c61db --- /dev/null +++ b/src/executor/exchange.rs @@ -0,0 +1,75 @@ +// Copyright 2024 RisingLight Project Authors. Licensed under Apache-2.0. + +use std::hash::Hasher; + +use ahash::AHasher; +use rand::{Rng, SeedableRng}; + +use super::*; + +/// Distribute the input data to multiple partitions by hash partitioning. +pub struct HashPartitionProducer { + /// The expression to extract the keys. + /// e.g. `(list #0 #1)` + pub keys: RecExpr, + /// The number of partitions to produce. + pub num_partitions: usize, +} + +impl HashPartitionProducer { + #[try_stream(boxed, ok = (DataChunk, usize), error = ExecutorError)] + pub async fn execute(self, child: BoxedExecutor) { + // preallocate buffers for reuse + let mut hashers = vec![AHasher::default(); PROCESSING_WINDOW_SIZE]; + let mut partition_indices = vec![0; PROCESSING_WINDOW_SIZE]; + let mut visibility = vec![false; PROCESSING_WINDOW_SIZE]; + + #[for_await] + for batch in child { + let batch = batch?; + + // reset buffers + hashers.clear(); + hashers.resize(batch.cardinality(), AHasher::default()); + partition_indices.resize(batch.cardinality(), 0); + visibility.resize(batch.cardinality(), false); + + // calculate the hash + let keys_chunk = Evaluator::new(&self.keys).eval_list(&batch)?; + for column in keys_chunk.arrays() { + column.hash_to(&mut hashers); + } + for (hasher, target) in hashers.iter().zip(&mut partition_indices) { + *target = hasher.finish() as usize % self.num_partitions; + } + + // send the batch to the corresponding partition + for partition in 0..self.num_partitions { + for (row, p) in partition_indices.iter().enumerate() { + visibility[row] = *p == partition; + } + let chunk = batch.filter(&visibility); + yield (chunk, partition); + } + } + } +} + +/// Randomly distribute the input data to multiple partitions. +pub struct RandomPartitionProducer { + /// The number of partitions. 
+ pub num_partitions: usize, +} + +impl RandomPartitionProducer { + #[try_stream(boxed, ok = (DataChunk, usize), error = ExecutorError)] + pub async fn execute(self, child: BoxedExecutor) { + let mut rng = rand::rngs::SmallRng::from_seed([0; 32]); + #[for_await] + for batch in child { + let batch = batch?; + let partition = rng.gen_range(0..self.num_partitions); + yield (batch, partition); + } + } +} diff --git a/src/executor/mod.rs b/src/executor/mod.rs index c32900d6..7d538203 100644 --- a/src/executor/mod.rs +++ b/src/executor/mod.rs @@ -20,7 +20,6 @@ use egg::{Id, Language}; use futures::stream::{BoxStream, StreamExt}; use futures_async_stream::try_stream; use itertools::Itertools; -use tracing::Instrument; // use minitrace::prelude::*; use self::analyze::*; @@ -34,6 +33,7 @@ use self::drop::*; pub use self::error::Error as ExecutorError; use self::error::*; use self::evaluator::*; +use self::exchange::*; use self::explain::*; use self::filter::*; use self::hash_agg::*; @@ -58,7 +58,8 @@ use crate::catalog::{RootCatalog, RootCatalogRef, TableRefId}; use crate::planner::{Expr, ExprAnalysis, Optimizer, RecExpr, TypeSchemaAnalysis}; use crate::storage::Storage; use crate::types::{ColumnIndex, DataType}; -use crate::utils::timed::{FutureExt as _, Span as TimeSpan}; +use crate::utils::counted::{Counter, StreamExt as _}; +use crate::utils::timed::{Span as TimeSpan, StreamExt as _}; mod analyze; mod copy_from_file; @@ -69,6 +70,7 @@ mod create_view; mod delete; mod drop; mod evaluator; +mod exchange; mod explain; mod filter; mod hash_agg; @@ -99,6 +101,8 @@ const PROCESSING_WINDOW_SIZE: usize = 1024; /// It consumes one or more streams from its child executors, /// and produces a stream to its parent. pub type BoxedExecutor = BoxStream<'static, Result>; +/// A boxed dispatcher that distributes data to multiple partitions. +pub type BoxedDispatcher = BoxStream<'static, Result<(DataChunk, usize)>>; pub fn build(optimizer: Optimizer, storage: Arc, plan: &RecExpr) -> BoxedExecutor { Builder::new(optimizer, storage, plan).build() @@ -112,7 +116,7 @@ struct Builder { root: Id, /// For scans on views, we prebuild their executors and store them here. /// Multiple scans on the same view will share the same executor. - views: HashMap, + views: HashMap, metrics: Metrics, } @@ -197,21 +201,16 @@ impl Builder { /// Builds the executor. fn build(mut self) -> BoxedExecutor { - self.build_id(self.root) + self.build_id(self.root).spawn_merge() } /// Builds the executor and returns its subscriber. - fn build_subscriber(mut self) -> StreamSubscriber { - self.build_id_subscriber(self.root) + fn build_subscriber(mut self) -> PartitionedStreamSubscriber { + self.build_id(self.root).spawn() } - /// Builds the executor for the given id. - fn build_id(&mut self, id: Id) -> BoxedExecutor { - self.build_id_subscriber(id).subscribe() - } - - /// Builds the executor for the given id and returns its subscriber. - fn build_id_subscriber(&mut self, id: Id) -> StreamSubscriber { + /// Builds stream for the given plan. 
+ fn build_id(&mut self, id: Id) -> PartitionedStream { use Expr::*; let stream = match self.node(id).clone() { Scan([table, list, filter]) => { @@ -257,7 +256,12 @@ impl Builder { .collect(); projs.add(List(lists)); - ProjectionExecutor { projs }.execute(subscriber.subscribe()) + subscriber.subscribe().map(|c| { + ProjectionExecutor { + projs: projs.clone(), + } + .execute(c) + }) } else if table_id.schema_id == RootCatalog::SYSTEM_SCHEMA_ID { SystemTableScan { catalog: self.catalog().clone(), @@ -266,6 +270,7 @@ impl Builder { columns, } .execute() + .into() } else { TableScanExecutor { table_id, @@ -274,6 +279,7 @@ impl Builder { storage: self.storage.clone(), } .execute() + .into() } } @@ -289,121 +295,155 @@ impl Builder { .collect() }, } - .execute(), - - Proj([projs, child]) => ProjectionExecutor { - projs: self.resolve_column_index(projs, child), - } - .execute(self.build_id(child)), - - Filter([cond, child]) => FilterExecutor { - condition: self.resolve_column_index(cond, child), - } - .execute(self.build_id(child)), + .execute() + .into(), - Order([order_keys, child]) => OrderExecutor { - order_keys: self.resolve_column_index(order_keys, child), - types: self.plan_types(id).to_vec(), - } - .execute(self.build_id(child)), + Proj([projs, child]) => self.build_id(child).map(|c| { + ProjectionExecutor { + projs: self.resolve_column_index(projs, child), + } + .execute(c) + }), - Limit([limit, offset, child]) => LimitExecutor { - limit: (self.node(limit).as_const().as_usize().unwrap()).unwrap_or(usize::MAX / 2), - offset: self.node(offset).as_const().as_usize().unwrap().unwrap(), - } - .execute(self.build_id(child)), + Filter([cond, child]) => self.build_id(child).map(|c| { + FilterExecutor { + condition: self.resolve_column_index(cond, child), + } + .execute(c) + }), - TopN([limit, offset, order_keys, child]) => TopNExecutor { - limit: (self.node(limit).as_const().as_usize().unwrap()).unwrap_or(usize::MAX / 2), - offset: self.node(offset).as_const().as_usize().unwrap().unwrap(), - order_keys: self.resolve_column_index(order_keys, child), - types: self.plan_types(id).to_vec(), - } - .execute(self.build_id(child)), - - Join([op, on, left, right]) => match self.node(op) { - Inner | LeftOuter | RightOuter | FullOuter => NestedLoopJoinExecutor { - op: self.node(op).clone(), - condition: self.resolve_column_index2(on, left, right), - left_types: self.plan_types(left).to_vec(), - right_types: self.plan_types(right).to_vec(), + Order([order_keys, child]) => self.build_id(child).map(|c| { + OrderExecutor { + order_keys: self.resolve_column_index(order_keys, child), + types: self.plan_types(id).to_vec(), } - .execute(self.build_id(left), self.build_id(right)), - op @ Semi | op @ Anti => NestedLoopSemiJoinExecutor { - anti: matches!(op, Anti), - condition: self.resolve_column_index2(on, left, right), - left_types: self.plan_types(left).to_vec(), + .execute(c) + }), + + Limit([limit, offset, child]) => self.build_id(child).map(|c| { + LimitExecutor { + limit: (self.node(limit).as_const().as_usize().unwrap()) + .unwrap_or(usize::MAX / 2), + offset: self.node(offset).as_const().as_usize().unwrap().unwrap(), } - .execute(self.build_id(left), self.build_id(right)), - t => panic!("invalid join type: {t:?}"), - }, - - HashJoin(args @ [op, ..]) => match self.node(op) { - Inner => self.build_hashjoin::<{ JoinType::Inner }>(args), - LeftOuter => self.build_hashjoin::<{ JoinType::LeftOuter }>(args), - RightOuter => self.build_hashjoin::<{ JoinType::RightOuter }>(args), - FullOuter => 
self.build_hashjoin::<{ JoinType::FullOuter }>(args), - Semi => self.build_hashsemijoin(args, false), - Anti => self.build_hashsemijoin(args, true), - t => panic!("invalid join type: {t:?}"), - }, + .execute(c) + }), + + TopN([limit, offset, order_keys, child]) => self.build_id(child).map(|c| { + TopNExecutor { + limit: (self.node(limit).as_const().as_usize().unwrap()) + .unwrap_or(usize::MAX / 2), + offset: self.node(offset).as_const().as_usize().unwrap().unwrap(), + order_keys: self.resolve_column_index(order_keys, child), + types: self.plan_types(id).to_vec(), + } + .execute(c) + }), + + Join([op, on, left, right]) => { + self.build_id(left) + .zip(self.build_id(right)) + .map(|l, r| match self.node(op) { + Inner | LeftOuter | RightOuter | FullOuter => NestedLoopJoinExecutor { + op: self.node(op).clone(), + condition: self.resolve_column_index2(on, left, right), + left_types: self.plan_types(left).to_vec(), + right_types: self.plan_types(right).to_vec(), + } + .execute(l, r), + op @ Semi | op @ Anti => NestedLoopSemiJoinExecutor { + anti: matches!(op, Anti), + condition: self.resolve_column_index2(on, left, right), + left_types: self.plan_types(left).to_vec(), + } + .execute(l, r), + t => panic!("invalid join type: {t:?}"), + }) + } - MergeJoin(args @ [op, ..]) => match self.node(op) { - Inner => self.build_mergejoin::<{ JoinType::Inner }>(args), - LeftOuter => self.build_mergejoin::<{ JoinType::LeftOuter }>(args), - RightOuter => self.build_mergejoin::<{ JoinType::RightOuter }>(args), - FullOuter => self.build_mergejoin::<{ JoinType::FullOuter }>(args), - t => panic!("invalid join type: {t:?}"), - }, + HashJoin(args @ [op, _, _, _, left, right]) => self + .build_id(left) + .zip(self.build_id(right)) + .map(|l, r| match self.node(op) { + Inner => self.build_hashjoin::<{ JoinType::Inner }>(args, l, r), + LeftOuter => self.build_hashjoin::<{ JoinType::LeftOuter }>(args, l, r), + RightOuter => self.build_hashjoin::<{ JoinType::RightOuter }>(args, l, r), + FullOuter => self.build_hashjoin::<{ JoinType::FullOuter }>(args, l, r), + Semi => self.build_hashsemijoin(args, false, l, r), + Anti => self.build_hashsemijoin(args, true, l, r), + t => panic!("invalid join type: {t:?}"), + }), + + MergeJoin(args @ [op, _, _, _, left, right]) => self + .build_id(left) + .zip(self.build_id(right)) + .map(|l, r| match self.node(op) { + Inner => self.build_mergejoin::<{ JoinType::Inner }>(args, l, r), + LeftOuter => self.build_mergejoin::<{ JoinType::LeftOuter }>(args, l, r), + RightOuter => self.build_mergejoin::<{ JoinType::RightOuter }>(args, l, r), + FullOuter => self.build_mergejoin::<{ JoinType::FullOuter }>(args, l, r), + t => panic!("invalid join type: {t:?}"), + }), Apply(_) => { panic!("Apply is not supported in executor. 
It should be rewritten to join by optimizer.") } - Agg([aggs, child]) => SimpleAggExecutor { - aggs: self.resolve_column_index(aggs, child), - types: self.plan_types(id).to_vec(), - } - .execute(self.build_id(child)), - - HashAgg([keys, aggs, child]) => HashAggExecutor { - keys: self.resolve_column_index(keys, child), - aggs: self.resolve_column_index(aggs, child), - types: self.plan_types(id).to_vec(), - } - .execute(self.build_id(child)), - - SortAgg([keys, aggs, child]) => SortAggExecutor { - keys: self.resolve_column_index(keys, child), - aggs: self.resolve_column_index(aggs, child), - types: self.plan_types(id).to_vec(), - } - .execute(self.build_id(child)), + Agg([aggs, child]) => self.build_id(child).map(|c| { + SimpleAggExecutor { + aggs: self.resolve_column_index(aggs, child), + types: self.plan_types(id).to_vec(), + } + .execute(c) + }), + + HashAgg([keys, aggs, child]) => self.build_id(child).map(|c| { + HashAggExecutor { + keys: self.resolve_column_index(keys, child), + aggs: self.resolve_column_index(aggs, child), + types: self.plan_types(id).to_vec(), + } + .execute(c) + }), + + SortAgg([keys, aggs, child]) => self.build_id(child).map(|c| { + SortAggExecutor { + keys: self.resolve_column_index(keys, child), + aggs: self.resolve_column_index(aggs, child), + types: self.plan_types(id).to_vec(), + } + .execute(c) + }), - Window([exprs, child]) => WindowExecutor { - exprs: self.resolve_column_index(exprs, child), - types: self.plan_types(exprs).to_vec(), - } - .execute(self.build_id(child)), + Window([exprs, child]) => self.build_id(child).map(|c| { + WindowExecutor { + exprs: self.resolve_column_index(exprs, child), + types: self.plan_types(exprs).to_vec(), + } + .execute(c) + }), CreateTable(table) => CreateTableExecutor { - table, + table: self.node(table).as_table_def(), storage: self.storage.clone(), } - .execute(), + .execute() + .into(), CreateView([table, query]) => CreateViewExecutor { - table: self.node(table).as_create_table(), + table: self.node(table).as_table_def(), query: self.recexpr(query), catalog: self.catalog().clone(), } - .execute(), + .execute() + .into(), CreateFunction(f) => CreateFunctionExecutor { - f, + function: self.node(f).as_function_def(), catalog: self.optimizer.catalog().clone(), } - .execute(), + .execute() + .into(), Drop(tables) => DropExecutor { tables: (self.node(tables).as_list().iter()) @@ -412,7 +452,8 @@ impl Builder { catalog: self.catalog().clone(), storage: self.storage.clone(), } - .execute(), + .execute() + .into(), Insert([table, cols, child]) => InsertExecutor { table_id: self.node(table).as_table(), @@ -421,69 +462,127 @@ impl Builder { .collect(), storage: self.storage.clone(), } - .execute(self.build_id(child)), + .execute(self.build_id(child).spawn_merge()) + .into(), Delete([table, child]) => DeleteExecutor { table_id: self.node(table).as_table(), storage: self.storage.clone(), } - .execute(self.build_id(child)), + .execute(self.build_id(child).spawn_merge()) + .into(), CopyFrom([src, types]) => CopyFromFileExecutor { source: self.node(src).as_ext_source(), types: self.node(types).as_type().as_struct().to_vec(), } - .execute(), + .execute() + .into(), CopyTo([src, child]) => CopyToFileExecutor { source: self.node(src).as_ext_source(), } - .execute(self.build_id(child)), + .execute(self.build_id(child).spawn_merge()) + .into(), Explain(plan) => ExplainExecutor { plan: self.recexpr(plan), optimizer: self.optimizer.clone(), } - .execute(), + .execute() + .into(), Analyze(child) => { - let stream = self.build_id(child); + let 
stream = self.build_id(child).spawn_merge(); AnalyzeExecutor { plan: self.recexpr(child), catalog: self.optimizer.catalog().clone(), + // note: make sure to take the metrics after building the child stream metrics: std::mem::take(&mut self.metrics), } .execute(stream) + .into() } - Empty(_) => futures::stream::empty().boxed(), + Empty(_) => futures::stream::empty().boxed().into(), + + Schema([_, child]) => self.build_id(child), // schema node is just pass-through + + Exchange([dist, child]) => match self.node(dist).clone() { + Single => self.build_id(child).spawn_merge().into(), + Broadcast => self + .build_id(child) + .spawn_broadcast(self.optimizer.config().parallelism), + Random => { + let stream = self.build_id(child); + let num_partitions = self.optimizer.config().parallelism; + let (spans, counters) = self.metrics.add(id, stream.len(), num_partitions); + return stream + .dispatch(num_partitions, |c| { + RandomPartitionProducer { num_partitions }.execute(c) + }) + .instrument(spans) + .spawn() + .subscribe() + .counted(counters); + } + Hash(keys) => { + let keys = self.resolve_column_index(keys, child); + let num_partitions = self.optimizer.config().parallelism; + let stream = self.build_id(child); + let (spans, counters) = self.metrics.add(id, stream.len(), num_partitions); + return stream + .dispatch(num_partitions, |c| { + HashPartitionProducer { + keys: keys.clone(), + num_partitions, + } + .execute(c) + }) + .instrument(spans) + .spawn() + .subscribe() + .counted(counters); + } + node => panic!("invalid exchange type: {node:?}"), + }, - node => panic!("not a plan: {node:?}"), + node => panic!("not a plan: {node:?}\n{:?}", self.egraph.dump()), }; - self.spawn(id, stream) + let (spans, counters) = self.metrics.add(id, stream.len(), stream.len()); + stream.instrument(spans, counters) } - fn build_hashjoin(&mut self, args: [Id; 6]) -> BoxedExecutor { - let [_, cond, lkeys, rkeys, left, right] = args; + fn build_hashjoin( + &self, + [_, cond, lkey, rkey, left, right]: [Id; 6], + l: BoxedExecutor, + r: BoxedExecutor, + ) -> BoxedExecutor { assert_eq!(self.node(cond), &Expr::true_()); HashJoinExecutor:: { - left_keys: self.resolve_column_index(lkeys, left), - right_keys: self.resolve_column_index(rkeys, right), + left_keys: self.resolve_column_index(lkey, left), + right_keys: self.resolve_column_index(rkey, right), left_types: self.plan_types(left).to_vec(), right_types: self.plan_types(right).to_vec(), } - .execute(self.build_id(left), self.build_id(right)) + .execute(l, r) } - fn build_hashsemijoin(&mut self, args: [Id; 6], anti: bool) -> BoxedExecutor { - let [_, cond, lkeys, rkeys, left, right] = args; + fn build_hashsemijoin( + &self, + [_, cond, lkeys, rkeys, left, right]: [Id; 6], + anti: bool, + l: BoxedExecutor, + r: BoxedExecutor, + ) -> BoxedExecutor { if self.node(cond) == &Expr::true_() { HashSemiJoinExecutor { left_keys: self.resolve_column_index(lkeys, left), right_keys: self.resolve_column_index(rkeys, right), anti, } - .execute(self.build_id(left), self.build_id(right)) + .execute(l, r) } else { HashSemiJoinExecutor2 { left_keys: self.resolve_column_index(lkeys, left), @@ -493,12 +592,16 @@ impl Builder { right_types: self.plan_types(right).to_vec(), anti, } - .execute(self.build_id(left), self.build_id(right)) + .execute(l, r) } } - fn build_mergejoin(&mut self, args: [Id; 6]) -> BoxedExecutor { - let [_, cond, lkeys, rkeys, left, right] = args; + fn build_mergejoin( + &self, + [_, cond, lkeys, rkeys, left, right]: [Id; 6], + l: BoxedExecutor, + r: BoxedExecutor, + ) 
-> BoxedExecutor { assert_eq!(self.node(cond), &Expr::true_()); MergeJoinExecutor:: { left_keys: self.resolve_column_index(lkeys, left), @@ -506,43 +609,245 @@ impl Builder { left_types: self.plan_types(left).to_vec(), right_types: self.plan_types(right).to_vec(), } - .execute(self.build_id(left), self.build_id(right)) + .execute(l, r) + } +} + +/// Spawn a new task to execute the given stream. +fn spawn(mut stream: BoxedExecutor) -> StreamSubscriber { + let (tx, rx) = async_broadcast::broadcast(16); + let handle = tokio::task::Builder::default() + .spawn(async move { + while let Some(item) = stream.next().await { + if tx.broadcast(item).await.is_err() { + // all receivers are dropped, stop the task. + return; + } + } + }) + .expect("failed to spawn task"); + + StreamSubscriber { + rx: rx.deactivate(), + task_handle: Arc::new(AbortOnDropHandle(handle)), + } +} + +/// A set of partitioned output streams. +struct PartitionedStream { + streams: Vec, +} + +/// Creates from a single stream. +impl From for PartitionedStream { + fn from(stream: BoxedExecutor) -> Self { + PartitionedStream { + streams: vec![stream], + } } +} - /// Spawn a new task to execute the given stream. - fn spawn(&mut self, id: Id, mut stream: BoxedExecutor) -> StreamSubscriber { - let name = self.node(id).to_string(); - let span = TimeSpan::default(); - let output_row_counter = Counter::default(); +impl PartitionedStream { + /// Returns the number of partitions. + fn len(&self) -> usize { + self.streams.len() + } - self.metrics - .register(id, span.clone(), output_row_counter.clone()); + /// Merges the partitioned streams into a single stream. + /// + /// ```text + /// A0 -++-> A + /// A1 -+| + /// A2 --+ + /// ``` + fn spawn_merge(self) -> BoxedExecutor { + futures::stream::select_all(self.spawn().subscribe().streams).boxed() + } + + /// Broadcasts each stream to `num_partitions` partitions. + /// + /// ```text + /// A0 -+-> A + /// A1 -+-> A + /// +-> A + /// ``` + fn spawn_broadcast(self, num_partitions: usize) -> PartitionedStream { + let subscriber = self.spawn(); + PartitionedStream { + streams: (0..num_partitions) + .map(|_| subscriber.subscribe_merge()) + .collect(), + } + } + + /// Maps each stream with the given function. + /// + /// ```text + /// A0 --f-> B0 + /// A1 --f-> B1 + /// A2 --f-> B2 + /// ``` + fn map(self, f: impl Fn(BoxedExecutor) -> BoxedExecutor) -> PartitionedStream { + PartitionedStream { + streams: self.streams.into_iter().map(f).collect(), + } + } + + /// Dispatches each stream to `num_partitions` partitions with the given function. + fn dispatch( + self, + num_partitions: usize, + f: impl Fn(BoxedExecutor) -> BoxedDispatcher, + ) -> PartitionedDispatcher { + PartitionedDispatcher { + streams: self.streams.into_iter().map(f).collect(), + num_partitions, + } + } - let (tx, rx) = async_broadcast::broadcast(16); - let handle = tokio::task::Builder::default() - .name(&format!("{id}.{name}")) - .spawn( - async move { + /// Zips up two sets of partitioned streams. + /// + /// ```text + /// A0 -+---> (A0,B0) + /// A1 -|+--> (A1,B1) + /// A2 -||+-> (A2,B2) + /// ||| + /// B0 -+|| + /// B1 --+| + /// B2 ---+ + /// ``` + fn zip(self, other: PartitionedStream) -> ZippedPartitionedStream { + ZippedPartitionedStream { + left: self.streams, + right: other.streams, + } + } + + /// Spawns each partitioned stream as a tokio task. 
+ fn spawn(self) -> PartitionedStreamSubscriber { + PartitionedStreamSubscriber { + subscribers: self.streams.into_iter().map(spawn).collect(), + } + } + + /// Attaches metrics to the streams. + fn instrument(self, spans: Vec, counters: Vec) -> Self { + assert_eq!(self.streams.len(), spans.len()); + assert_eq!(self.streams.len(), counters.len()); + PartitionedStream { + streams: (self.streams.into_iter().zip(spans).zip(counters)) + .map(|((stream, span), counter)| stream.timed(span).counted(counter).boxed()) + .collect(), + } + } + + /// Attaches metrics to the streams. + fn counted(self, counters: Vec) -> Self { + assert_eq!(self.streams.len(), counters.len()); + PartitionedStream { + streams: (self.streams.into_iter().zip(counters)) + .map(|(stream, counter)| stream.counted(counter).boxed()) + .collect(), + } + } +} + +/// The return type of `PartitionedStream::dispatch`. +/// +/// This is the end of the pipeline. Call `spawn` to execute the streams and collect the results. +struct PartitionedDispatcher { + streams: Vec, + num_partitions: usize, +} + +impl PartitionedDispatcher { + /// Attaches metrics to the streams. + fn instrument(self, spans: Vec) -> Self { + assert_eq!(self.streams.len(), spans.len()); + PartitionedDispatcher { + streams: (self.streams.into_iter().zip(spans)) + .map(|(stream, span)| stream.timed(span).boxed()) + .collect(), + num_partitions: self.num_partitions, + } + } + + /// Spawn new tasks to execute the given dispatchers. + /// Dispatch the output to multiple partitions by the associated partition index. + fn spawn(self) -> PartitionedStreamSubscriber { + let (txs, rxs): (Vec<_>, Vec<_>) = (0..self.num_partitions) + .map(|_| async_broadcast::broadcast(16)) + .unzip(); + let mut handles = Vec::with_capacity(self.streams.len()); + for mut stream in self.streams { + let txs = txs.clone(); + let handle = tokio::task::Builder::default() + .spawn(async move { while let Some(item) = stream.next().await { - if let Ok(chunk) = &item { - output_row_counter.inc(chunk.cardinality() as _); - } - if tx.broadcast(item).await.is_err() { - // all receivers are dropped, stop the task. - return; + match item { + // send the chunk to the corresponding partition (ignore error) + Ok((chunk, partition)) => _ = txs[partition].broadcast(Ok(chunk)).await, + // broadcast the error to all partitions + Err(e) => { + for tx in &txs { + tx.broadcast(Err(e.clone())).await.unwrap(); + } + } } } - } - .instrument(tracing::info_span!("executor", id = usize::from(id), name)) - .timed(span), - ) - .expect("failed to spawn task"); - - StreamSubscriber { - rx: rx.deactivate(), - handle: Arc::new(AbortOnDropHandle(handle)), + }) + .expect("failed to spawn task"); + handles.push(handle); + } + let handles = Arc::new(handles); + PartitionedStreamSubscriber { + subscribers: rxs + .into_iter() + .map(|rx| StreamSubscriber { + rx: rx.deactivate(), + task_handle: handles.clone(), // all task handles are shared by all subscribers + }) + .collect(), + } + } +} + +/// The return type of `PartitionedStream::zip`. +struct ZippedPartitionedStream { + left: Vec, + right: Vec, +} + +impl ZippedPartitionedStream { + /// Maps each stream pair with the given function. + fn map(self, f: impl Fn(BoxedExecutor, BoxedExecutor) -> BoxedExecutor) -> PartitionedStream { + assert_eq!(self.left.len(), self.right.len()); + PartitionedStream { + streams: self + .left + .into_iter() + .zip(self.right) + .map(|(l, r)| f(l, r)) + .collect(), + } + } +} + +/// A set of partitioned stream subscribers. 
+struct PartitionedStreamSubscriber { + subscribers: Vec, +} + +impl PartitionedStreamSubscriber { + fn subscribe(&self) -> PartitionedStream { + PartitionedStream { + streams: self.subscribers.iter().map(|s| s.subscribe()).collect(), } } + + fn subscribe_merge(&self) -> BoxedExecutor { + futures::stream::select_all(self.subscribe().streams).boxed() + } } /// A subscriber of an executor's output stream. @@ -550,7 +855,7 @@ impl Builder { /// New streams can be created by calling `subscribe`. struct StreamSubscriber { rx: async_broadcast::InactiveReceiver>, - handle: Arc, + task_handle: Arc, } impl StreamSubscriber { @@ -559,15 +864,15 @@ impl StreamSubscriber { #[try_stream(boxed, ok = DataChunk, error = ExecutorError)] async fn to_stream( rx: async_broadcast::Receiver>, - handle: Arc, + task_handle: Arc, ) { #[for_await] for chunk in rx { yield chunk?; } - drop(handle); + drop(task_handle); } - to_stream(self.rx.activate_cloned(), self.handle.clone()) + to_stream(self.rx.activate_cloned(), self.task_handle.clone()) } } diff --git a/src/executor/nested_loop_join.rs b/src/executor/nested_loop_join.rs index 5f33c4da..a9ace38f 100644 --- a/src/executor/nested_loop_join.rs +++ b/src/executor/nested_loop_join.rs @@ -25,7 +25,12 @@ impl NestedLoopJoinExecutor { if !matches!(self.op, Expr::Inner | Expr::LeftOuter) { todo!("unsupported join type: {:?}", self.op); } + + // materialize left child let left_chunks = left_child.try_collect::>().await?; + if left_chunks.is_empty() { + return Ok(()); + } let left_rows = || left_chunks.iter().flat_map(|chunk| chunk.rows()); diff --git a/src/planner/cost.rs b/src/planner/cost.rs index 019f450b..facd661b 100644 --- a/src/planner/cost.rs +++ b/src/planner/cost.rs @@ -69,6 +69,7 @@ impl egg::CostFunction for CostFn<'_> { Insert([_, _, c]) | CopyTo([_, c]) => rows(c) * cols(c) + costs(c), Empty(_) => 0.0, Max1Row(c) => costs(c), + Exchange([_, c]) => costs(c), // expressions Column(_) | Ref(_) => 0.01, // column reference is almost free List(_) => enode.fold(0.01, |sum, id| sum + costs(&id)), // list is almost free diff --git a/src/planner/explain.rs b/src/planner/explain.rs index 0cb2a03f..c602d865 100644 --- a/src/planner/explain.rs +++ b/src/planner/explain.rs @@ -322,19 +322,34 @@ impl<'a> Explain<'a> { with_meta(vec![("windows", self.expr(windows).pretty())]), vec![self.child(child).pretty()], ), - CreateTable(t) => { - let fields = with_meta(t.pretty_table()); - Pretty::childless_record("CreateTable", fields) + Exchange([dist, child]) => Pretty::simple_record( + "Exchange", + with_meta(vec![("dist", self.expr(dist).pretty())]), + vec![self.child(child).pretty()], + ), + ToParallel(child) => { + Pretty::simple_record("ToParallel", vec![], vec![self.child(child).pretty()]) } + Single | Broadcast | Random => Pretty::display(enode), + Hash(keys) => { + Pretty::childless_record("Hash", vec![("keys", self.expr(keys).pretty())]) + } + + CreateTable(table) => Pretty::childless_record( + "CreateTable", + with_meta(vec![("table", self.expr(table).pretty())]), + ), CreateView([table, query]) => Pretty::simple_record( "CreateView", with_meta(vec![("table", self.expr(table).pretty())]), vec![self.expr(query).pretty()], ), - CreateFunction(f) => { - let v = f.pretty_function(); - Pretty::childless_record("CreateFunction", v) - } + TableDef(t) => Pretty::childless_record("TableDef", t.pretty_table()), + CreateFunction(f) => Pretty::childless_record( + "CreateFunction", + with_meta(vec![("function", self.expr(f).pretty())]), + ), + FunctionDef(f) => 
Pretty::childless_record("FunctionDef", f.pretty_function()), Drop(tables) => { let fields = with_meta(vec![("objects", self.expr(tables).pretty())]); Pretty::childless_record("Drop", fields) @@ -373,6 +388,7 @@ impl<'a> Explain<'a> { ), Empty(_) => Pretty::childless_record("Empty", with_meta(vec![])), Max1Row(child) => Pretty::fieldless_record("Max1Row", vec![self.expr(child).pretty()]), + Schema([_, child]) => self.child(child).pretty(), } } } diff --git a/src/planner/mod.rs b/src/planner/mod.rs index 8d4a07e5..0e56a5cd 100644 --- a/src/planner/mod.rs +++ b/src/planner/mod.rs @@ -3,7 +3,7 @@ use egg::{define_language, Id, Symbol}; use crate::binder::copy::ExtSource; -use crate::binder::{CreateFunction, CreateTable}; +use crate::binder::{FunctionDef, TableDef}; use crate::catalog::{ColumnRefId, TableRefId}; use crate::parser::{BinaryOperator, UnaryOperator}; use crate::types::{ColumnIndex, DataType, DataValue, DateTimeField}; @@ -65,7 +65,7 @@ define_language! { // functions "extract" = Extract([Id; 2]), // (extract field expr) - Field(DateTimeField), + Field(Box), "replace" = Replace([Id; 3]), // (replace expr pattern replacement) "substring" = Substring([Id; 3]), // (substring expr start length) @@ -118,9 +118,20 @@ define_language! { // child must be ordered by keys "window" = Window([Id; 2]), // (window [over..] child) // output = child || exprs - CreateTable(Box), - "create_view" = CreateView([Id; 2]), // (create_view create_table child) - CreateFunction(CreateFunction), + + // parallelism + "to_parallel" = ToParallel(Id), // (to_parallel child) + "exchange" = Exchange([Id; 2]), // (exchange dist child) + "single" = Single, // (single) merge all to one + "broadcast" = Broadcast, // (broadcast) broadcast to all + "random" = Random, // (random) random partition + "hash" = Hash(Id), // (hash key=[expr..]) partition by hash of key + + "create_table" = CreateTable(Id), // (create_table table_def) + "create_view" = CreateView([Id; 2]), // (create_view table_def child) + TableDef(Box), + "create_function" = CreateFunction(Id), // (create_function func_def) + FunctionDef(Box), "drop" = Drop(Id), // (drop [table..]) "insert" = Insert([Id; 3]), // (insert table [column..] child) "delete" = Delete([Id; 2]), // (delete table child) @@ -136,6 +147,9 @@ define_language! { // with the same schema as `child` "max1row" = Max1Row(Id), // (max1row child) // convert table to scalar + "schema" = Schema([Id; 2]), // (schema [expr..] child) + // reset schema of child to [expr..] 
+ // this node is just pass-through in execution Symbol(Symbol), } @@ -189,13 +203,20 @@ impl Expr { t } - pub fn as_create_table(&self) -> Box { - let Self::CreateTable(v) = self else { + pub fn as_table_def(&self) -> Box { + let Self::TableDef(v) = self else { panic!("not a create table: {self}") }; v.clone() } + pub fn as_function_def(&self) -> Box { + let Self::FunctionDef(v) = self else { + panic!("not a function definition: {self}") + }; + v.clone() + } + pub fn as_ext_source(&self) -> ExtSource { let Self::ExtSource(v) = self else { panic!("not an external source: {self}") @@ -282,3 +303,14 @@ impl ExprExt for egg::EClass { .expect("not a column") } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_expr_size() { + // the size of Expr should be as small as possible + assert_eq!(std::mem::size_of::(), 32); + } +} diff --git a/src/planner/optimizer.rs b/src/planner/optimizer.rs index 202509bc..04bbd7d0 100644 --- a/src/planner/optimizer.rs +++ b/src/planner/optimizer.rs @@ -4,10 +4,11 @@ use std::sync::LazyLock; use egg::CostFunction; +use self::rules::partition::to_parallel_plan; use super::*; use crate::catalog::RootCatalogRef; -/// Plan optimizer. +/// Optimizer transforms the query plan into a more efficient one. #[derive(Clone)] pub struct Optimizer { analysis: ExprAnalysis, @@ -15,9 +16,13 @@ pub struct Optimizer { /// Optimizer configurations. #[derive(Debug, Clone, Default)] +#[non_exhaustive] pub struct Config { pub enable_range_filter_scan: bool, pub table_is_sorted_by_primary_key: bool, + /// The number of partitions of each operator. + /// If set to >1, exchange operators will be inserted into the plan. + pub parallelism: usize, } impl Optimizer { @@ -49,6 +54,10 @@ impl Optimizer { self.optimize_stage(&mut expr, &mut cost, rules, 4, 6); // 3. join reorder and hashjoin self.optimize_stage(&mut expr, &mut cost, STAGE3_RULES.iter(), 3, 8); + + if self.analysis.config.parallelism > 1 { + expr = to_parallel_plan(expr); + } expr } @@ -108,6 +117,11 @@ impl Optimizer { pub fn catalog(&self) -> &RootCatalogRef { &self.analysis.catalog } + + /// Returns the configurations. + pub fn config(&self) -> &Config { + &self.analysis.config + } } /// Stage1 rules in the optimizer. diff --git a/src/planner/rules/mod.rs b/src/planner/rules/mod.rs index f683df77..ede94520 100644 --- a/src/planner/rules/mod.rs +++ b/src/planner/rules/mod.rs @@ -14,7 +14,8 @@ //! | [`schema`] | column id to index | output schema of a plan | [`Schema`] | //! | [`type_`] | | data type | [`Type`] | //! | [`rows`] | | estimated rows | [`Rows`] | -//! | [`order`] | merge join | ordered keys | [`OrderKey`] | +//! | [`order`] | merge join | ordered keys | [`OrderKey`] | +//! | [`partition`] | to parallel plan | data partition | [`Partition`] | //! //! It would be best if you have a background in program analysis. //! Here is a recommended course: . @@ -27,6 +28,7 @@ //! [`Type`]: type_::Type //! [`Rows`]: rows::Rows //! [`OrderKey`]: order::OrderKey +//! 
[`Partition`]: partition::Partition use std::collections::HashSet; use std::hash::Hash; @@ -40,6 +42,7 @@ use crate::types::F32; pub mod agg; pub mod expr; pub mod order; +pub mod partition; pub mod plan; pub mod range; pub mod rows; diff --git a/src/planner/rules/order.rs b/src/planner/rules/order.rs index d6b4e153..7378e5c7 100644 --- a/src/planner/rules/order.rs +++ b/src/planner/rules/order.rs @@ -36,6 +36,8 @@ pub fn analyze_order(egraph: &EGraph, enode: &Expr) -> OrderKey { Proj([_, c]) | Filter([_, c]) | Window([_, c]) | Limit([_, _, c]) => x(c).clone(), MergeJoin([_, _, _, _, _, r]) => x(r).clone(), SortAgg([_, _, c]) => x(c).clone(), + Exchange([_, c]) => x(c).clone(), + Schema([_, c]) => x(c).clone(), // unordered for other plans _ => Box::new([]), } diff --git a/src/planner/rules/partition.rs b/src/planner/rules/partition.rs new file mode 100644 index 00000000..aea2e622 --- /dev/null +++ b/src/planner/rules/partition.rs @@ -0,0 +1,389 @@ +// Copyright 2024 RisingLight Project Authors. Licensed under Apache-2.0. + +//! This module converts physical plans into parallel plans. +//! +//! In a parallel query plan, each node represents one or more physical operators, with each +//! operator processing data from one partition. Each node has a [`Partition`] property that +//! describes how the data is partitioned. +//! +//! After the conversion, [`Exchange`](Expr::Exchange) nodes will be inserted when necessary to +//! redistribute data between partitions. + +use std::sync::LazyLock; + +use egg::{rewrite as rw, Analysis, EGraph, Language}; + +use super::*; +use crate::planner::RecExpr; + +/// Converts a physical plan into a parallel plan. +pub fn to_parallel_plan(mut plan: RecExpr) -> RecExpr { + // add to_parallel to the root node + let root_id = Id::from(plan.as_ref().len() - 1); + plan.add(Expr::ToParallel(root_id)); + + let runner = egg::Runner::<_, _, ()>::new(PartitionAnalysis) + .with_expr(&plan) + .run(TO_PARALLEL_RULES.iter()); + let extractor = egg::Extractor::new(&runner.egraph, NoToParallel); + let (_, expr) = extractor.find_best(runner.roots[0]); + + assert!( + expr.as_ref() + .iter() + .all(|node| !matches!(node, Expr::ToParallel(_))), + "unexpected ToParallel in the parallel plan:\n{}", + expr.pretty(60) + ); + expr +} + +struct NoToParallel; + +impl egg::CostFunction for NoToParallel { + type Cost = usize; + fn cost(&mut self, enode: &Expr, mut costs: C) -> Self::Cost + where + C: FnMut(Id) -> Self::Cost, + { + let cost = enode.fold(1usize, |sum, id| sum.saturating_add(costs(id))); + // if all candidates contain ToParallel, the one with the deepest ToParallel will be chosen. 
+ if let Expr::ToParallel(_) = enode { + return cost * 1024; + } + cost + } +} + +type Rewrite = egg::Rewrite; + +static TO_PARALLEL_RULES: LazyLock> = LazyLock::new(|| { + vec![ + // scan is not partitioned + rw!("scan-to-parallel"; + "(to_parallel (scan ?table ?columns ?filter))" => + "(exchange random (scan ?table ?columns ?filter))" + ), + // values and empty are not partitioned + rw!("values-to-parallel"; + "(to_parallel ?child)" => + "(exchange random ?child)" + if node_is("?child", &["values", "empty"]) + ), + // projection does not change distribution + rw!("proj-to-parallel"; + "(to_parallel (proj ?projs ?child))" => + "(proj ?projs (to_parallel ?child))" + ), + // filter does not change distribution + rw!("filter-to-parallel"; + "(to_parallel (filter ?cond ?child))" => + "(filter ?cond (to_parallel ?child))" + ), + // order can not be partitioned + rw!("order-to-parallel"; + "(to_parallel (order ?key ?child))" => + "(order ?key (exchange single (to_parallel ?child)))" + // TODO: 2-phase ordering + // "(order ?key (exchange single (order ?key (to_parallel ?child))))" + // TODO: merge sort in the second phase? + ), + // limit can not be partitioned + rw!("limit-to-parallel"; + "(to_parallel (limit ?limit ?offset ?child))" => + "(limit ?limit ?offset (exchange single (to_parallel ?child)))" + ), + // topn can not be partitioned + rw!("topn-to-parallel"; + "(to_parallel (topn ?limit ?offset ?key ?child))" => + "(topn ?limit ?offset ?key (exchange single (to_parallel ?child)))" + ), + // join is partitioned by left + rw!("join-to-parallel"; + "(to_parallel (join ?type ?cond ?left ?right))" => + "(join ?type ?cond + (exchange random (to_parallel ?left)) + (exchange broadcast (to_parallel ?right)))" + if node_is("?type", &["inner", "left_outer", "semi", "anti"]) + ), + // hash join can be partitioned by join key + rw!("hashjoin-to-parallel"; + "(to_parallel (hashjoin ?type ?cond ?lkey ?rkey ?left ?right))" => + "(hashjoin ?type ?cond ?lkey ?rkey + (exchange (hash ?lkey) (to_parallel ?left)) + (exchange (hash ?rkey) (to_parallel ?right)))" + ), + // merge join can be partitioned by join key + rw!("mergejoin-to-parallel"; + "(to_parallel (mergejoin ?type ?cond ?lkey ?rkey ?left ?right))" => + "(mergejoin ?type ?cond ?lkey ?rkey + (exchange (hash ?lkey) (to_parallel ?left)) + (exchange (hash ?rkey) (to_parallel ?right)))" + ), + // 2-phase aggregation + rw!("agg-to-parallel"; + "(to_parallel (agg ?aggs ?child))" => + { apply_global_aggs(" + (schema ?aggs (agg ?global_aggs (exchange single + (agg ?aggs (exchange random (to_parallel ?child)))))) + ") } + // to keep the schema unchanged, we add a `schema` node + // FIXME: check if all aggs are supported in 2-phase aggregation + ), + // hash aggregation can be partitioned by group key + rw!("hashagg-to-parallel"; + "(to_parallel (hashagg ?keys ?aggs ?child))" => + "(hashagg ?keys ?aggs (exchange (hash ?keys) (to_parallel ?child)))" + ), + // sort aggregation can be partitioned by group key + rw!("sortagg-to-parallel"; + "(to_parallel (sortagg ?keys ?aggs ?child))" => + "(sortagg ?keys ?aggs (exchange (hash ?keys) (to_parallel ?child)))" + ), + // window function can not be partitioned for now + rw!("window-to-parallel"; + "(to_parallel (window ?exprs ?child))" => + "(window ?exprs (exchange single (to_parallel ?child)))" + ), + // insert + rw!("insert-to-parallel"; + "(to_parallel (insert ?table ?columns ?child))" => + "(insert ?table ?columns (to_parallel ?child))" + ), + // delete + rw!("delete-to-parallel"; + "(to_parallel (delete ?table 
?child))" => + "(delete ?table (to_parallel ?child))" + ), + // copy_from + rw!("copy_from-to-parallel"; + "(to_parallel (copy_from ?dest ?types))" => + "(copy_from ?dest ?types)" + ), + // copy_to + rw!("copy_to-to-parallel"; + "(to_parallel (copy_to ?dest ?child))" => + "(copy_to ?dest (to_parallel ?child))" + ), + // explain + rw!("explain-to-parallel"; + "(to_parallel (explain ?child))" => + "(explain (to_parallel ?child))" + ), + // analyze + rw!("analyze-to-parallel"; + "(to_parallel (analyze ?child))" => + "(analyze (to_parallel ?child))" + ), + // no parallel for DDL + rw!("ddl-to-parallel"; + "(to_parallel ?child)" => "?child" + if node_is("?child", &["create_table", "create_view", "create_function", "drop"]) + ), + // unnecessary exchange can be removed + rw!("remove-exchange"; + "(exchange ?dist ?child)" => "?child" + if partition_is_same("?child", "?dist") + ), + ] +}); + +/// Returns true if the distribution of the used columns is the same as the produced columns. +fn partition_is_same( + a: &str, + b: &str, +) -> impl Fn(&mut EGraph, Id, &Subst) -> bool { + let a = var(a); + let b = var(b); + move |egraph, _, subst| egraph[subst[a]].data == egraph[subst[b]].data +} + +/// Returns true if the given node is one of the candidates. +fn node_is( + a: &str, + candidates: &'static [&'static str], +) -> impl Fn(&mut EGraph, Id, &Subst) -> bool { + let a = var(a); + move |egraph, _, subst| candidates.contains(&egraph[subst[a]].nodes[0].to_string().as_str()) +} + +/// Returns an applier that replaces `?global_aggs` with the nested `?aggs`. +/// +/// ```text +/// ?aggs = (list (sum a) (count b)) +/// ?global_aggs = (list (sum (ref (sum a))) (count (ref (count b)))) +/// ``` +fn apply_global_aggs(pattern_str: &str) -> impl Applier { + struct ApplyGlobalAggs { + pattern: Pattern, + aggs: Var, + global_aggs: Var, + } + impl Applier for ApplyGlobalAggs { + fn apply_one( + &self, + egraph: &mut EGraph, + eclass: Id, + subst: &Subst, + searcher_ast: Option<&PatternAst>, + rule_name: Symbol, + ) -> Vec { + let aggs = egraph[subst[self.aggs]].as_list().to_vec(); + let mut global_aggs = vec![]; + for agg in aggs { + use Expr::*; + let ref_id = egraph.add(Expr::Ref(agg)); + let global_agg = match &egraph[agg].nodes[0] { + Max(_) => Max(ref_id), + Min(_) => Min(ref_id), + Sum(_) => Sum(ref_id), + Avg(_) => panic!("avg is not supported in 2-phase aggregation"), + RowCount => Sum(ref_id), + Count(_) => Sum(ref_id), + CountDistinct(_) => { + panic!("count distinct is not supported in 2-phase aggregation") + } + First(_) => First(ref_id), + Last(_) => Last(ref_id), + node => panic!("invalid agg: {}", node), + }; + global_aggs.push(egraph.add(global_agg)); + } + let id = egraph.add(Expr::List(global_aggs.into())); + let mut subst = subst.clone(); + subst.insert(self.global_aggs, id); + self.pattern + .apply_one(egraph, eclass, &subst, searcher_ast, rule_name) + } + } + ApplyGlobalAggs { + pattern: pattern(pattern_str), + aggs: var("?aggs"), + global_aggs: var("?global_aggs"), + } +} + +/// Describes how data is partitioned. +#[derive(Debug, Default, Clone, PartialEq, Eq)] +pub enum Partition { + /// Distribution is unknown. + #[default] + Unknown, + /// Data is not partitioned. + Single, + /// Data is randomly partitioned. + Random, + /// Data is broadcasted to all partitions. + Broadcast, + /// Data is partitioned by hash of keys. 
+ Hash(Box<[Id]>), +} + +struct PartitionAnalysis; + +impl Analysis for PartitionAnalysis { + type Data = Partition; + + fn make(egraph: &EGraph, enode: &Expr) -> Self::Data { + let x = |id: &Id| egraph[*id].data.clone(); + analyze_partition(enode, x) + } + + fn merge(&mut self, a: &mut Self::Data, b: Self::Data) -> egg::DidMerge { + merge_partition(a, b) + } +} + +/// Returns partition of the given plan node. +pub fn analyze_partition(enode: &Expr, x: impl Fn(&Id) -> Partition) -> Partition { + use Expr::*; + match enode { + // partition nodes + Single => Partition::Single, + Random => Partition::Random, + Broadcast => Partition::Broadcast, + Hash(list) => x(list), + List(ids) => Partition::Hash(ids.clone()), + + // exchange node changes distribution + Exchange([dist, _]) => x(dist), + + // leaf nodes + Scan(_) | Values(_) => Partition::Single, + + // equal to child or left child + Proj([_, c]) + | Filter([_, c]) + | Order([_, c]) + | Limit([_, _, c]) + | TopN([_, _, _, c]) + | Empty(c) + | Window([_, c]) + | Agg([_, c]) + | HashAgg([_, _, c]) + | SortAgg([_, _, c]) + | Join([_, _, c, _]) + | Apply([_, c, _]) + | HashJoin([_, _, _, _, c, _]) + | MergeJoin([_, _, _, _, c, _]) => x(c), + + // not a plan node + _ => Partition::Unknown, + } +} + +fn merge_partition(a: &mut Partition, b: Partition) -> egg::DidMerge { + if *a == Partition::Unknown && b != Partition::Unknown { + *a = b; + egg::DidMerge(true, false) + } else { + egg::DidMerge(false, true) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_hash_join_to_parallel() { + let input = " + (hashjoin inner true (list a) (list b) + (scan t1 (list a) true) + (scan t2 (list b) true) + ) + "; + let distributed = " + (hashjoin inner true (list a) (list b) + (exchange (hash (list a)) + (exchange random + (scan t1 (list a) true))) + (exchange (hash (list b)) + (exchange random + (scan t2 (list b) true))) + ) + "; + let output = to_parallel_plan(input.parse().unwrap()); + let expected: RecExpr = distributed.parse().unwrap(); + assert_eq!(output.to_string(), expected.to_string()); + } + + #[test] + fn test_two_phase_agg() { + let input = " + (agg (list (sum a)) + (scan t1 (list a) true)) + "; + let distributed = " + (schema (list (sum a)) + (agg (list (sum (ref (sum a)))) + (exchange single + (agg (list (sum a)) + (exchange random + (scan t1 (list a) true)))))) + "; + let output = to_parallel_plan(input.parse().unwrap()); + let expected: RecExpr = distributed.parse().unwrap(); + assert_eq!(output.to_string(), expected.to_string()); + } +} diff --git a/src/planner/rules/rows.rs b/src/planner/rules/rows.rs index f483413d..c3dc2000 100644 --- a/src/planner/rules/rows.rs +++ b/src/planner/rules/rows.rs @@ -76,6 +76,9 @@ pub fn analyze_rows(egraph: &EGraph, enode: &Expr) -> Rows { }, Empty(_) => 0.0, Max1Row(_) => 1.0, + Schema([_, c]) => x(c), + // FIXME: broadcast distribution should multiply the number of rows + Exchange([_, c]) => x(c), // for boolean expressions, the result represents selectivity Ref(a) => x(a), diff --git a/src/planner/rules/schema.rs b/src/planner/rules/schema.rs index 7806136b..4e8719ff 100644 --- a/src/planner/rules/schema.rs +++ b/src/planner/rules/schema.rs @@ -17,7 +17,8 @@ pub fn analyze_schema( let concat = |v1: Vec, v2: Vec| v1.into_iter().chain(v2).collect(); match enode { // equal to child - Filter([_, c]) | Order([_, c]) | Limit([_, _, c]) | TopN([_, _, _, c]) | Empty(c) => x(c), + Filter([_, c]) | Order([_, c]) | Limit([_, _, c]) | TopN([_, _, _, c]) | Empty(c) + | Exchange([_, c]) => x(c), // 
concat 2 children Join([t, _, l, r]) @@ -37,6 +38,7 @@ pub fn analyze_schema( Proj([exprs, _]) | Agg([exprs, _]) => x(exprs), Window([exprs, child]) => concat(x(child), x(exprs)), HashAgg([keys, aggs, _]) | SortAgg([keys, aggs, _]) => concat(x(keys), x(aggs)), + Schema([exprs, _]) => x(exprs), // not plan node _ => vec![], diff --git a/src/planner/rules/type_.rs b/src/planner/rules/type_.rs index 21cd4c2f..59a9f0f3 100644 --- a/src/planner/rules/type_.rs +++ b/src/planner/rules/type_.rs @@ -142,7 +142,8 @@ pub fn analyze_type( }), // equal to child - Filter([_, c]) | Order([_, c]) | Limit([_, _, c]) | TopN([_, _, _, c]) | Empty(c) => x(c), + Filter([_, c]) | Order([_, c]) | Limit([_, _, c]) | TopN([_, _, _, c]) | Empty(c) + | Exchange([_, c]) | Schema([_, c]) => x(c), // concat 2 children Join([t, _, l, r]) | HashJoin([t, _, _, _, l, r]) | MergeJoin([t, _, _, _, l, r]) => { diff --git a/src/types/date.rs b/src/types/date.rs index 96819647..152bcbb3 100644 --- a/src/types/date.rs +++ b/src/types/date.rs @@ -193,3 +193,11 @@ impl FromStr for DateTimeField { })) } } + +impl FromStr for Box { + type Err = (); + + fn from_str(s: &str) -> Result { + DateTimeField::from_str(s).map(Box::new) + } +} diff --git a/src/utils/counted.rs b/src/utils/counted.rs new file mode 100644 index 00000000..a97cd9d8 --- /dev/null +++ b/src/utils/counted.rs @@ -0,0 +1,89 @@ +use std::fmt::{self, Debug}; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::Arc; +use std::task::Poll; + +use futures::Stream; + +use crate::array::DataChunk; + +impl StreamExt for T {} + +/// An extension trait for `Streams` that provides counting instrument adapters. +pub trait StreamExt: Stream + Sized { + /// Binds a [`Counter`] to the [`Stream`] that counts the number of rows. + #[inline] + fn counted(self, counter: Counter) -> Counted { + Counted { + inner: self, + counter, + } + } +} + +/// Adapter for [`StreamExt::counted()`](StreamExt::counted). +#[pin_project::pin_project] +pub struct Counted { + #[pin] + inner: T, + counter: Counter, +} + +impl>> Stream for Counted { + type Item = T::Item; + + fn poll_next( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + let this = self.project(); + let result = this.inner.poll_next(cx); + if let Poll::Ready(Some(Ok(chunk))) = &result { + this.counter.inc(chunk.as_data_chunk().cardinality() as u64); + } + result + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + self.inner.size_hint() + } +} + +pub trait AsDataChunk { + fn as_data_chunk(&self) -> &DataChunk; +} +impl AsDataChunk for DataChunk { + fn as_data_chunk(&self) -> &DataChunk { + self + } +} +impl AsDataChunk for (DataChunk, usize) { + fn as_data_chunk(&self) -> &DataChunk { + &self.0 + } +} + +/// A counter. +#[derive(Default, Clone)] +pub struct Counter { + count: Arc, +} + +impl Debug for Counter { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.get()) + } +} + +impl Counter { + /// Increments the counter. + pub fn inc(&self, value: u64) { + self.count.fetch_add(value, Ordering::Relaxed); + } + + /// Gets the current value of the counter. + pub fn get(&self) -> u64 { + self.count.load(Ordering::Relaxed) + } +} diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 5c461a12..d327e9e0 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -1,4 +1,5 @@ // Copyright 2024 RisingLight Project Authors. Licensed under Apache-2.0. 
+pub mod counted; pub mod time; pub mod timed; diff --git a/src/utils/timed.rs b/src/utils/timed.rs index 5c234280..b6790e0b 100644 --- a/src/utils/timed.rs +++ b/src/utils/timed.rs @@ -1,15 +1,16 @@ -use std::future::Future; +use std::fmt::{self, Debug}; use std::sync::Arc; use std::task::Poll; use std::time::{Duration, Instant}; +use futures::Stream; use parking_lot::Mutex; -impl FutureExt for T {} +impl StreamExt for T {} -/// An extension trait for `Futures` that provides tracing instrument adapters. -pub trait FutureExt: Future + Sized { - /// Binds a [`Span`] to the [`Future`] that continues to record until the future is dropped. +/// An extension trait for `Streams` that provides tracing instrument adapters. +pub trait StreamExt: Stream + Sized { + /// Binds a [`Span`] to the [`Stream`] that continues to record until the Stream is dropped. #[inline] fn timed(self, span: Span) -> Timed { Timed { @@ -19,7 +20,7 @@ pub trait FutureExt: Future + Sized { } } -/// Adapter for [`FutureExt::timed()`](FutureExt::timed). +/// Adapter for [`StreamExt::timed()`](StreamExt::timed). #[pin_project::pin_project] pub struct Timed { #[pin] @@ -27,33 +28,49 @@ pub struct Timed { span: Option, } -impl Future for Timed { - type Output = T::Output; +impl Stream for Timed { + type Item = T::Item; - fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { + fn poll_next( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { let this = self.project(); let _guard = this.span.as_ref().map(|s| s.enter()); - match this.inner.poll(cx) { - r @ Poll::Pending => r, - other => { - drop(_guard); - this.span.take(); - other - } + let result = this.inner.poll_next(cx); + if let Poll::Ready(None) = result { + // stream is finished + drop(_guard); + this.span.take(); } + result + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + self.inner.size_hint() } } -#[derive(Debug, Default, Clone)] +#[derive(Default, Clone)] pub struct Span { inner: Arc>, } +impl Debug for Span { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Span") + .field("busy_time", &self.busy_time()) + .field("finish_time", &self.finish_time()) + .finish() + } +} + #[derive(Debug, Default)] struct SpanInner { busy_time: Duration, - last_poll_time: Option, + finish_time: Option, } impl Span { @@ -68,8 +85,8 @@ impl Span { self.inner.lock().busy_time } - pub fn last_poll_time(&self) -> Option { - self.inner.lock().last_poll_time + pub fn finish_time(&self) -> Option { + self.inner.lock().finish_time } } @@ -83,6 +100,6 @@ impl Drop for Guard<'_> { let now = Instant::now(); let mut span = self.span.inner.lock(); span.busy_time += now - self.start_time; - span.last_poll_time = Some(now); + span.finish_time = Some(now); } } diff --git a/tests/sql/join_left_inner.slt b/tests/sql/join_left_inner.slt index ceeab1f5..d9600014 100644 --- a/tests/sql/join_left_inner.slt +++ b/tests/sql/join_left_inner.slt @@ -71,12 +71,12 @@ select v1, v2, v3, v4 from a, b; statement ok insert into b values (1, 100), (3, 300), (4, 400); -query IIII +query IIII rowsort select v1, v2, v3, v4 from a left join b on v1 = v3; ---- 1 1 1 100 -3 3 3 300 2 2 NULL NULL +3 3 3 300 statement ok drop table a; @@ -96,19 +96,19 @@ insert into a values (1, 1), (2, 2), (3, 3); statement ok insert into b values (1, 1, 1), (2, 2, 2), (3, 3, 4), (1, 1, 5); -query IIIII +query IIIII rowsort select v1, v2, v3, v4, v5 from a join b on v1 = v3 and v2 = v4; ---- 1 1 1 1 1 +1 1 1 1 5 2 2 2 2 2 3 3 3 3 4 -1 1 1 1 5 
-query IIIII +query IIIII rowsort select v1, v2, v3, v4, v5 from a join b on v1 = v3 and v2 = v4 and v1 < v5; ---- -3 3 3 3 4 1 1 1 1 5 +3 3 3 3 4 statement ok drop table a; diff --git a/tests/sql/merge_join.slt b/tests/sql/merge_join.slt index 4eb543e1..b41df195 100644 --- a/tests/sql/merge_join.slt +++ b/tests/sql/merge_join.slt @@ -10,7 +10,7 @@ insert into t1 values (1, 10), (1, 11), (2, 20); statement ok insert into t2 values (1, -10), (1, -11), (3, -30); -query IIII +query IIII rowsort select * from (select a, b from t1 order by a, b) join (select c, d from t2 order by c, d desc) on a = c; @@ -20,7 +20,7 @@ join (select c, d from t2 order by c, d desc) on a = c; 1 11 1 -10 1 11 1 -11 -query IIII +query IIII rowsort select * from (select a, b from t1 order by a) left join (select c, d from t2 order by c) on a = c; @@ -31,7 +31,7 @@ left join (select c, d from t2 order by c) on a = c; 1 11 1 -11 2 20 NULL NULL -query IIII +query IIII rowsort select * from (select a, b from t1 order by a) right join (select c, d from t2 order by c) on a = c; @@ -42,7 +42,7 @@ right join (select c, d from t2 order by c) on a = c; 1 11 1 -11 NULL NULL 3 -30 -query IIII +query IIII rowsort select * from (select a, b from t1 order by a) full join (select c, d from t2 order by c) on a = c; diff --git a/tests/sqllogictest.rs b/tests/sqllogictest.rs index 215bdd17..9295989a 100644 --- a/tests/sqllogictest.rs +++ b/tests/sqllogictest.rs @@ -79,6 +79,7 @@ async fn test(filename: impl AsRef, engine: Engine) -> Result<()> { Engine::Disk => Database::new_on_disk(SecondaryStorageOptions::default_for_test()).await, Engine::Mem => Database::new_in_memory(), }; + db.run("set parallelism = 2;").await?; // enable data partitioning let db = DatabaseWrapper(db); let mut tester = sqllogictest::Runner::new(|| async { Ok(&db) });
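
The `hashjoin-to-parallel`, `mergejoin-to-parallel`, and `hashagg-to-parallel` rules above wrap each input in `(exchange (hash ?keys) ...)` so that rows carrying equal key values land in the same partition and can be joined or grouped locally. A minimal, self-contained sketch of that routing decision follows; the hasher, the partition count, and the sample rows are illustrative assumptions, not taken from this patch:

```rust
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

/// Routes a key to one of `n` partitions by hashing it (illustrative only).
fn partition_of<K: Hash>(key: &K, n: u64) -> u64 {
    let mut hasher = DefaultHasher::new();
    key.hash(&mut hasher);
    hasher.finish() % n
}

fn main() {
    let n = 2;
    let left = [(1, "l1"), (3, "l3"), (4, "l4")];
    let right = [(1, "r1"), (3, "r3")];

    // Both sides hash the join key with the same function, so matching keys
    // are co-located and each partition can probe only its local build side.
    for (key, row) in left.iter().chain(right.iter()) {
        println!("key {key} ({row}) -> partition {}", partition_of(key, n));
    }
    assert_eq!(partition_of(&3, n), partition_of(&3, n));
}
```

Because the routing depends only on the key, no coordination between partitions is needed once the exchange has redistributed the rows.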
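
The `agg-to-parallel` rule together with `apply_global_aggs` splits a plain aggregation into a per-partition phase and a global phase: `sum`, `count`, and `rowcount` partials are re-aggregated with `sum`, while `max`/`min`/`first`/`last` reuse the same function (and `avg`/`count distinct` are rejected). A small sketch of why that decomposition gives the same answer, using plain vectors in place of partitions (the data and the partitioning are made up for illustration):

```rust
// Two-phase aggregation over hypothetical partitions: local partials first,
// then a global combine that mirrors the `apply_global_aggs` mapping.
fn main() {
    let partitions: Vec<Vec<i64>> = vec![vec![1, 2, 3], vec![4, 5], vec![]];

    // Phase 1: per-partition (sum, count, max).
    let partials: Vec<(i64, i64, Option<i64>)> = partitions
        .iter()
        .map(|p| (p.iter().sum(), p.len() as i64, p.iter().copied().max()))
        .collect();

    // Phase 2: sum of sums, sum of counts, max of maxes.
    let sum: i64 = partials.iter().map(|(s, _, _)| s).sum();
    let count: i64 = partials.iter().map(|(_, c, _)| c).sum();
    let max: Option<i64> = partials.iter().filter_map(|(_, _, m)| *m).max();

    // Matches aggregating the unpartitioned input.
    let all: Vec<i64> = partitions.concat();
    assert_eq!(sum, all.iter().sum::<i64>());
    assert_eq!(count, all.len() as i64);
    assert_eq!(max, all.iter().copied().max());
    println!("sum={sum} count={count} max={max:?}");
}
```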
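
src/utils/counted.rs adds a `Counter` that executors clone into a `Counted` stream adapter (`stream.counted(counter.clone())`), which bumps the count by each chunk's cardinality as the stream is polled. The chunk and stream types are project-specific, so the sketch below only reproduces the shared-counter part of the pattern and simulates the per-chunk increments by hand:

```rust
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;

/// Minimal stand-in for the `Counter` in src/utils/counted.rs: a cheaply
/// cloneable handle around an atomic row count.
#[derive(Default, Clone)]
struct Counter {
    count: Arc<AtomicU64>,
}

impl Counter {
    fn inc(&self, value: u64) {
        self.count.fetch_add(value, Ordering::Relaxed);
    }
    fn get(&self) -> u64 {
        self.count.load(Ordering::Relaxed)
    }
}

fn main() {
    let rows = Counter::default();

    // In the executor, this handle would live inside `Counted` and be bumped
    // on every `Poll::Ready(Some(Ok(chunk)))` with the chunk's cardinality.
    let handle = rows.clone();
    for cardinality in [3u64, 5, 2] {
        handle.inc(cardinality);
    }

    assert_eq!(rows.get(), 10);
    println!("rows seen: {}", rows.get());
}
```

Relaxed ordering is sufficient here because the count is only read for reporting, not used to synchronize other memory accesses.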
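
src/utils/timed.rs now instruments a `Stream` instead of a `Future`: every `poll_next` runs under a guard that adds its elapsed time to the span's `busy_time` when it is dropped, and the span is released once the stream yields `None`. A simplified, std-only sketch of that guard-on-drop accounting (it uses `std::sync::Mutex` rather than `parking_lot` and leaves out the stream plumbing, so it approximates the pattern rather than reproducing the actual adapter):

```rust
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};

/// Simplified stand-in for `Span`/`Guard` in src/utils/timed.rs: each "poll"
/// holds a guard, and the guard adds its elapsed time to `busy_time` on drop.
#[derive(Default, Clone)]
struct Span {
    busy_time: Arc<Mutex<Duration>>,
}

struct Guard<'a> {
    span: &'a Span,
    start: Instant,
}

impl Span {
    fn enter(&self) -> Guard<'_> {
        Guard { span: self, start: Instant::now() }
    }
    fn busy_time(&self) -> Duration {
        *self.busy_time.lock().unwrap()
    }
}

impl Drop for Guard<'_> {
    fn drop(&mut self) {
        *self.span.busy_time.lock().unwrap() += self.start.elapsed();
    }
}

fn main() {
    let span = Span::default();
    for _ in 0..3 {
        let _guard = span.enter(); // held for the duration of one simulated poll
        std::thread::sleep(Duration::from_millis(5));
    }
    std::thread::sleep(Duration::from_millis(20)); // idle time is not counted
    assert!(span.busy_time() >= Duration::from_millis(15));
    println!("busy time: {:?}", span.busy_time());
}
```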