diff --git a/Cargo.lock b/Cargo.lock
index 51669a1ea3..8daa63e3ce 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -6524,6 +6524,7 @@ dependencies = [
 name = "query"
 version = "0.2.0"
 dependencies = [
+ "ahash 0.8.3",
  "approx_eq",
  "arc-swap",
  "arrow-schema",
diff --git a/src/cmd/src/cli/repl.rs b/src/cmd/src/cli/repl.rs
index aa77df8e22..1ae27a8476 100644
--- a/src/cmd/src/cli/repl.rs
+++ b/src/cmd/src/cli/repl.rs
@@ -267,7 +267,11 @@ async fn create_query_engine(meta_addr: &str) -> Result<DatafusionQueryEngine> {
         partition_manager,
         datanode_clients,
     ));
-    let state = Arc::new(QueryEngineState::new(catalog_list, Default::default()));
+    let state = Arc::new(QueryEngineState::new(
+        catalog_list,
+        false,
+        Default::default(),
+    ));
 
     Ok(DatafusionQueryEngine::new(state))
 }
diff --git a/src/datanode/src/instance.rs b/src/datanode/src/instance.rs
index 862ebffdd3..2ba9f02eae 100644
--- a/src/datanode/src/instance.rs
+++ b/src/datanode/src/instance.rs
@@ -200,7 +200,7 @@ impl Instance {
             }
         };
 
-        let factory = QueryEngineFactory::new(catalog_manager.clone());
+        let factory = QueryEngineFactory::new(catalog_manager.clone(), false);
         let query_engine = factory.query_engine();
 
         let handlder_executor =
diff --git a/src/frontend/src/instance.rs b/src/frontend/src/instance.rs
index e3e5efd42b..e230a5fcd5 100644
--- a/src/frontend/src/instance.rs
+++ b/src/frontend/src/instance.rs
@@ -151,7 +151,7 @@ impl Instance {
         let catalog_manager = Arc::new(catalog_manager);
 
         let query_engine =
-            QueryEngineFactory::new_with_plugins(catalog_manager.clone(), plugins.clone())
+            QueryEngineFactory::new_with_plugins(catalog_manager.clone(), false, plugins.clone())
                 .query_engine();
 
         let script_executor =
@@ -249,7 +249,7 @@
         catalog_manager: CatalogManagerRef,
         dist_instance: Arc<DistInstance>,
     ) -> Self {
-        let query_engine = QueryEngineFactory::new(catalog_manager.clone()).query_engine();
+        let query_engine = QueryEngineFactory::new(catalog_manager.clone(), false).query_engine();
 
         let script_executor = Arc::new(
             ScriptExecutor::new(catalog_manager.clone(), query_engine.clone())
                 .await
diff --git a/src/frontend/src/table/scan.rs b/src/frontend/src/table/scan.rs
index 136565b53c..b8f2918066 100644
--- a/src/frontend/src/table/scan.rs
+++ b/src/frontend/src/table/scan.rs
@@ -67,8 +67,8 @@ impl DatanodeInstance {
             .logical_plan(substrait_plan.to_vec())
             .await
             .context(error::RequestDatanodeSnafu)?;
-        let Output::RecordBatches(recordbatches) = result else { unreachable!() };
-        Ok(recordbatches)
+        let Output::RecordBatches(record_batches) = result else { unreachable!() };
+        Ok(record_batches)
     }
 
     fn build_logical_plan(&self, table_scan: &TableScanPlan) -> Result<LogicalPlan> {
diff --git a/src/promql/src/extension_plan/planner.rs b/src/promql/src/extension_plan/planner.rs
index 3beedf3d2c..d0c5929fd9 100644
--- a/src/promql/src/extension_plan/planner.rs
+++ b/src/promql/src/extension_plan/planner.rs
@@ -25,7 +25,7 @@ use crate::extension_plan::{
     EmptyMetric, InstantManipulate, RangeManipulate, SeriesDivide, SeriesNormalize,
 };
 
-pub struct PromExtensionPlanner {}
+pub struct PromExtensionPlanner;
 
 #[async_trait]
 impl ExtensionPlanner for PromExtensionPlanner {
diff --git a/src/query/Cargo.toml b/src/query/Cargo.toml
index ea7cf3998c..1a9195d16c 100644
--- a/src/query/Cargo.toml
+++ b/src/query/Cargo.toml
@@ -5,6 +5,7 @@ edition.workspace = true
 license.workspace = true
 
 [dependencies]
+ahash = { version = "0.8", features = ["compile-time-rng"] }
 arc-swap = "1.0"
 arrow-schema.workspace = true
 async-trait = "0.1"
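The new `ahash` dependency (with the `compile-time-rng` feature) backs the plan hashing in `dist_plan::utils` below: the feature bakes the hasher's random keys in at build time, so `AHasher::default()` works without a runtime entropy source. A minimal sketch of the usage pattern (illustrative, not part of this diff):

    use std::hash::Hasher;

    use ahash::AHasher;

    // With compile-time-rng enabled, Default derives its keys from
    // compile-time randomness; hashes are stable within one build.
    fn demo_hash(bytes: &[u8]) -> u64 {
        let mut hasher = AHasher::default();
        hasher.write(bytes);
        hasher.finish()
    }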
diff --git a/src/query/src/datafusion.rs b/src/query/src/datafusion.rs
index 8dfbf05faf..19e7cf915a 100644
--- a/src/query/src/datafusion.rs
+++ b/src/query/src/datafusion.rs
@@ -407,7 +407,7 @@ mod tests {
             .register_catalog_sync(DEFAULT_CATALOG_NAME.to_string(), default_catalog)
             .unwrap();
 
-        QueryEngineFactory::new(catalog_list).query_engine()
+        QueryEngineFactory::new(catalog_list, false).query_engine()
     }
 
     #[tokio::test]
diff --git a/src/query/src/dist_plan.rs b/src/query/src/dist_plan.rs
new file mode 100644
index 0000000000..0aa086eb43
--- /dev/null
+++ b/src/query/src/dist_plan.rs
@@ -0,0 +1,22 @@
+// Copyright 2023 Greptime Team
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+mod analyzer;
+mod commutativity;
+mod merge_scan;
+mod planner;
+mod utils;
+
+pub use analyzer::DistPlannerAnalyzer;
+pub use planner::DistExtensionPlanner;
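The analyzer in the next file is built on DataFusion's `TreeNode` rewrite API: `transform` walks the plan and calls a closure on every node, and the closure returns either a rewritten node (`Transformed::Yes`) or the untouched one (`Transformed::No`). A self-contained sketch of that contract (illustrative only, not part of this diff):

    use datafusion_common::tree_node::{Transformed, TreeNode};
    use datafusion_expr::LogicalPlan;

    // Rewrites a plan, replacing every Limit node with its input subtree.
    fn strip_limits(plan: LogicalPlan) -> datafusion_common::Result<LogicalPlan> {
        plan.transform(&|plan| {
            Ok(match plan {
                LogicalPlan::Limit(limit) => Transformed::Yes((*limit.input).clone()),
                other => Transformed::No(other),
            })
        })
    }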
diff --git a/src/query/src/dist_plan/analyzer.rs b/src/query/src/dist_plan/analyzer.rs
new file mode 100644
index 0000000000..eb8c46cc80
--- /dev/null
+++ b/src/query/src/dist_plan/analyzer.rs
@@ -0,0 +1,224 @@
+// Copyright 2023 Greptime Team
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::sync::{Arc, Mutex};
+
+use datafusion_common::config::ConfigOptions;
+use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeVisitor, VisitRecursion};
+use datafusion_expr::{Extension, LogicalPlan};
+use datafusion_optimizer::analyzer::AnalyzerRule;
+
+use crate::dist_plan::commutativity::{
+    partial_commutative_transformer, Categorizer, Commutativity,
+};
+use crate::dist_plan::merge_scan::MergeScanLogicalPlan;
+use crate::dist_plan::utils;
+
+pub struct DistPlannerAnalyzer;
+
+impl AnalyzerRule for DistPlannerAnalyzer {
+    fn name(&self) -> &str {
+        "DistPlannerAnalyzer"
+    }
+
+    fn analyze(
+        &self,
+        plan: LogicalPlan,
+        _config: &ConfigOptions,
+    ) -> datafusion_common::Result<LogicalPlan> {
+        // (1) add merge scan
+        let plan = plan.transform(&Self::add_merge_scan)?;
+
+        // (2) transform up from the merge scan
+        let mut visitor = CommutativeVisitor::new();
+        plan.visit(&mut visitor)?;
+        let state = ExpandState::new();
+        let plan = plan.transform_down(&|plan| Self::expand(plan, &visitor, &state))?;
+
+        Ok(plan)
+    }
+}
+
+impl DistPlannerAnalyzer {
+    /// Add a [MergeScanLogicalPlan] placeholder above the table scan
+    fn add_merge_scan(plan: LogicalPlan) -> datafusion_common::Result<Transformed<LogicalPlan>> {
+        Ok(match plan {
+            LogicalPlan::TableScan(table_scan) => {
+                let ext_plan = LogicalPlan::Extension(Extension {
+                    node: Arc::new(MergeScanLogicalPlan::new(
+                        LogicalPlan::TableScan(table_scan),
+                        true,
+                    )),
+                });
+                Transformed::Yes(ext_plan)
+            }
+            _ => Transformed::No(plan),
+        })
+    }
+
+    /// Expand stages on the stop node
+    fn expand(
+        mut plan: LogicalPlan,
+        visitor: &CommutativeVisitor,
+        state: &ExpandState,
+    ) -> datafusion_common::Result<Transformed<LogicalPlan>> {
+        if state.is_transformed() {
+            // only transform once
+            return Ok(Transformed::No(plan));
+        }
+        if let Some(stop_node) = visitor.stop_node && utils::hash_plan(&plan) != stop_node {
+            // only act on the stop node, or on the root (the first node seen by this closure) when there is no stop node
+            return Ok(Transformed::No(plan));
+        }
+
+        // add merge scan
+        plan = MergeScanLogicalPlan::new(plan, false).into_logical_plan();
+
+        // add stages
+        for new_stage in &visitor.next_stage {
+            plan = new_stage.with_new_inputs(&[plan])?
+        }
+
+        state.set_transformed();
+        Ok(Transformed::Yes(plan))
+    }
+}
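Taken together, the two passes rewrite the plan in two steps: `add_merge_scan` wraps every `TableScan` in a placeholder `MergeScan`, then `expand` re-roots the commutative prefix of the plan under a real `MergeScan` at the stop node (or at the root when no stop node was found). For the query in the test at the end of this file, the shape evolves roughly like this (a sketch, not verbatim EXPLAIN output):

    -- after pass (1): every TableScan gains a placeholder above it
    Distinct:
      Projection: t.number
        Filter: t.number < Int32(10)
          MergeScan [is_placeholder=true]
            TableScan: t

    -- after pass (2): stages collected by the visitor are replayed on top
    Distinct:
      MergeScan [is_placeholder=false]
        Distinct:
          Projection: t.number
            Filter: t.number < Int32(10)
              MergeScan [is_placeholder=true]
                TableScan: t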
+
+struct ExpandState {
+    transformed: Mutex<bool>,
+}
+
+impl ExpandState {
+    pub fn new() -> Self {
+        Self {
+            transformed: Mutex::new(false),
+        }
+    }
+
+    pub fn is_transformed(&self) -> bool {
+        *self.transformed.lock().unwrap()
+    }
+
+    /// Set the state to transformed
+    pub fn set_transformed(&self) {
+        *self.transformed.lock().unwrap() = true;
+    }
+}
+
+struct CommutativeVisitor {
+    next_stage: Vec<LogicalPlan>,
+    // hash of the stop node
+    stop_node: Option<u64>,
+}
+
+impl TreeNodeVisitor for CommutativeVisitor {
+    type N = LogicalPlan;
+
+    fn pre_visit(&mut self, plan: &LogicalPlan) -> datafusion_common::Result<VisitRecursion> {
+        // find the first merge scan and stop traversing down
+        // todo: check if it works for join
+        Ok(match plan {
+            LogicalPlan::Extension(ext) => {
+                if ext.node.name() == MergeScanLogicalPlan::name() {
+                    VisitRecursion::Skip
+                } else {
+                    VisitRecursion::Continue
+                }
+            }
+            _ => VisitRecursion::Continue,
+        })
+    }
+
+    fn post_visit(&mut self, plan: &LogicalPlan) -> datafusion_common::Result<VisitRecursion> {
+        match Categorizer::check_plan(plan) {
+            Commutativity::Commutative => {}
+            Commutativity::PartialCommutative => {
+                if let Some(plan) = partial_commutative_transformer(plan) {
+                    self.next_stage.push(plan)
+                }
+            }
+            Commutativity::ConditionalCommutative(transformer) => {
+                if let Some(transformer) = transformer
+                    && let Some(plan) = transformer(plan) {
+                    self.next_stage.push(plan)
+                }
+            },
+            Commutativity::TransformedCommutative(transformer) => {
+                if let Some(transformer) = transformer
+                    && let Some(plan) = transformer(plan) {
+                    self.next_stage.push(plan)
+                }
+            },
+            Commutativity::NonCommutative
+            | Commutativity::Unimplemented
+            | Commutativity::Unsupported => {
+                self.stop_node = Some(utils::hash_plan(plan));
+                return Ok(VisitRecursion::Stop);
+            }
+        }
+
+        Ok(VisitRecursion::Continue)
+    }
+}
+
+impl CommutativeVisitor {
+    pub fn new() -> Self {
+        Self {
+            next_stage: vec![],
+            stop_node: None,
+        }
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use datafusion::datasource::DefaultTableSource;
+    use datafusion_expr::{col, lit, LogicalPlanBuilder};
+    use table::table::adapter::DfTableProviderAdapter;
+    use table::table::numbers::NumbersTable;
+
+    use super::*;
+
+    #[test]
+    fn see_how_analyzer_works() {
+        let numbers_table = Arc::new(NumbersTable::new(0)) as _;
+        let table_source = Arc::new(DefaultTableSource::new(Arc::new(
+            DfTableProviderAdapter::new(numbers_table),
+        )));
+
+        let plan = LogicalPlanBuilder::scan_with_filters("t", table_source, None, vec![])
+            .unwrap()
+            .filter(col("number").lt(lit(10)))
+            .unwrap()
+            .project(vec![col("number")])
+            .unwrap()
+            .distinct()
+            .unwrap()
+            .build()
+            .unwrap();
+
+        let config = ConfigOptions::default();
+        let result = DistPlannerAnalyzer {}.analyze(plan, &config).unwrap();
+        let expected = String::from(
+            "Distinct:\
+            \n  MergeScan [is_placeholder=false]\
+            \n    Distinct:\
+            \n      Projection: t.number\
+            \n        Filter: t.number < Int32(10)\
+            \n          MergeScan [is_placeholder=true]\
+            \n            TableScan: t",
+        );
+        assert_eq!(expected, format!("{:?}", result));
+    }
+}
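`ExpandState` exists because the closure handed to `transform_down` is a `&Fn`, so the "only transform once" flag needs interior mutability. A `Mutex<bool>` works; an `AtomicBool` would do the same job without a lock — an alternative sketch (not part of this diff):

    use std::sync::atomic::{AtomicBool, Ordering};

    struct ExpandState {
        transformed: AtomicBool,
    }

    impl ExpandState {
        fn new() -> Self {
            Self { transformed: AtomicBool::new(false) }
        }

        fn is_transformed(&self) -> bool {
            self.transformed.load(Ordering::Relaxed)
        }

        // Same contract as the Mutex version: flips the flag exactly once.
        fn set_transformed(&self) {
            self.transformed.store(true, Ordering::Relaxed);
        }
    }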
diff --git a/src/query/src/dist_plan/commutativity.rs b/src/query/src/dist_plan/commutativity.rs
new file mode 100644
index 0000000000..28723f3350
--- /dev/null
+++ b/src/query/src/dist_plan/commutativity.rs
@@ -0,0 +1,87 @@
+// Copyright 2023 Greptime Team
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::sync::Arc;
+
+use datafusion_expr::{LogicalPlan, UserDefinedLogicalNode};
+
+#[allow(dead_code)]
+pub enum Commutativity {
+    Commutative,
+    PartialCommutative,
+    ConditionalCommutative(Option<Transformer>),
+    TransformedCommutative(Option<Transformer>),
+    NonCommutative,
+    Unimplemented,
+    /// For unrelated plans like DDL
+    Unsupported,
+}
+
+pub struct Categorizer {}
+
+impl Categorizer {
+    pub fn check_plan(plan: &LogicalPlan) -> Commutativity {
+        match plan {
+            LogicalPlan::Projection(_) => Commutativity::Commutative,
+            LogicalPlan::Filter(_) => Commutativity::Commutative,
+            LogicalPlan::Window(_) => Commutativity::Unimplemented,
+            LogicalPlan::Aggregate(_) => {
+                // check all children exprs and use the strictest level
+                Commutativity::Unimplemented
+            }
+            LogicalPlan::Sort(_) => Commutativity::NonCommutative,
+            LogicalPlan::Join(_) => Commutativity::NonCommutative,
+            LogicalPlan::CrossJoin(_) => Commutativity::NonCommutative,
+            LogicalPlan::Repartition(_) => {
+                // unsupported? or non-commutative
+                Commutativity::Unimplemented
+            }
+            LogicalPlan::Union(_) => Commutativity::Unimplemented,
+            LogicalPlan::TableScan(_) => Commutativity::NonCommutative,
+            LogicalPlan::EmptyRelation(_) => Commutativity::NonCommutative,
+            LogicalPlan::Subquery(_) => Commutativity::Unimplemented,
+            LogicalPlan::SubqueryAlias(_) => Commutativity::Unimplemented,
+            LogicalPlan::Limit(_) => Commutativity::PartialCommutative,
+            LogicalPlan::Extension(extension) => {
+                Self::check_extension_plan(extension.node.as_ref() as _)
+            }
+            LogicalPlan::Distinct(_) => Commutativity::PartialCommutative,
+            LogicalPlan::Unnest(_) => Commutativity::Commutative,
+            LogicalPlan::Statement(_) => Commutativity::Unsupported,
+            LogicalPlan::CreateExternalTable(_) => Commutativity::Unsupported,
+            LogicalPlan::CreateMemoryTable(_) => Commutativity::Unsupported,
+            LogicalPlan::CreateView(_) => Commutativity::Unsupported,
+            LogicalPlan::CreateCatalogSchema(_) => Commutativity::Unsupported,
+            LogicalPlan::CreateCatalog(_) => Commutativity::Unsupported,
+            LogicalPlan::DropTable(_) => Commutativity::Unsupported,
+            LogicalPlan::DropView(_) => Commutativity::Unsupported,
+            LogicalPlan::Values(_) => Commutativity::Unsupported,
+            LogicalPlan::Explain(_) => Commutativity::Unsupported,
+            LogicalPlan::Analyze(_) => Commutativity::Unsupported,
+            LogicalPlan::Prepare(_) => Commutativity::Unsupported,
+            LogicalPlan::DescribeTable(_) => Commutativity::Unsupported,
+            LogicalPlan::Dml(_) => Commutativity::Unsupported,
+        }
+    }
+
+    pub fn check_extension_plan(_plan: &dyn UserDefinedLogicalNode) -> Commutativity {
+        todo!("enumerate all the extension plans here")
+    }
+}
+
+pub type Transformer = Arc<dyn Fn(&LogicalPlan) -> Option<LogicalPlan>>;
+
+pub fn partial_commutative_transformer(plan: &LogicalPlan) -> Option<LogicalPlan> {
+    Some(plan.clone())
+}
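`PartialCommutative` (used for `Limit` and `Distinct`) means the operator can be pushed below the merge point but must also be re-applied above it: `partial_commutative_transformer` clones the node so the visitor can replay it as a post-merge stage. For a `LIMIT 10`, the intent is roughly (a sketch of the resulting shape, not verbatim plan output):

    Limit: fetch=10                      -- replayed above the merge, because the
      MergeScan [is_placeholder=false]   -- union of per-node results can exceed 10 rows
        Limit: fetch=10                  -- each datanode still limits locally
          TableScan: t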
diff --git a/src/query/src/dist_plan/merge_scan.rs b/src/query/src/dist_plan/merge_scan.rs
new file mode 100644
index 0000000000..46c9ec30bb
--- /dev/null
+++ b/src/query/src/dist_plan/merge_scan.rs
@@ -0,0 +1,84 @@
+// Copyright 2023 Greptime Team
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::sync::Arc;
+
+use datafusion_expr::{Extension, LogicalPlan, UserDefinedLogicalNodeCore};
+
+#[derive(Debug, Hash, PartialEq, Eq)]
+pub struct MergeScanLogicalPlan {
+    /// In the logical plan phase it only contains one input
+    input: LogicalPlan,
+    /// Whether this plan is a placeholder
+    is_placeholder: bool,
+}
+
+impl UserDefinedLogicalNodeCore for MergeScanLogicalPlan {
+    fn name(&self) -> &str {
+        Self::name()
+    }
+
+    fn inputs(&self) -> Vec<&LogicalPlan> {
+        vec![&self.input]
+    }
+
+    fn schema(&self) -> &datafusion_common::DFSchemaRef {
+        self.input.schema()
+    }
+
+    fn expressions(&self) -> Vec<datafusion_expr::Expr> {
+        self.input.expressions()
+    }
+
+    fn fmt_for_explain(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        write!(f, "MergeScan [is_placeholder={}]", self.is_placeholder)
+    }
+
+    // todo: maybe carrying the exprs over will be useful
+    // todo: add a check for the inputs' length
+    fn from_template(&self, _exprs: &[datafusion_expr::Expr], inputs: &[LogicalPlan]) -> Self {
+        Self {
+            input: inputs[0].clone(),
+            is_placeholder: self.is_placeholder,
+        }
+    }
+}
+
+impl MergeScanLogicalPlan {
+    pub fn new(input: LogicalPlan, is_placeholder: bool) -> Self {
+        Self {
+            input,
+            is_placeholder,
+        }
+    }
+
+    pub fn name() -> &'static str {
+        "MergeScan"
+    }
+
+    /// Create a [LogicalPlan::Extension] node from this merge scan plan
+    pub fn into_logical_plan(self) -> LogicalPlan {
+        LogicalPlan::Extension(Extension {
+            node: Arc::new(self),
+        })
+    }
+
+    pub fn is_placeholder(&self) -> bool {
+        self.is_placeholder
+    }
+
+    pub fn input(&self) -> &LogicalPlan {
+        &self.input
+    }
+}
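The two flavors of `MergeScanLogicalPlan` are created in different passes and consumed differently: the placeholder marks "everything below runs on a datanode", while the real node is what the physical planner in the next file turns into an execution plan. A crate-internal sketch mirroring what the analyzer does (intermediate stages omitted; not part of this diff):

    use datafusion_expr::LogicalPlan;

    use crate::dist_plan::merge_scan::MergeScanLogicalPlan;

    fn wrap(scan: LogicalPlan) -> LogicalPlan {
        // pass (1): placeholder inserted directly above a TableScan
        let placeholder = MergeScanLogicalPlan::new(scan, true).into_logical_plan();
        // pass (2): the real merge scan, re-rooted at the stop node
        MergeScanLogicalPlan::new(placeholder, false).into_logical_plan()
    }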
diff --git a/src/query/src/dist_plan/planner.rs b/src/query/src/dist_plan/planner.rs
new file mode 100644
index 0000000000..3005016abb
--- /dev/null
+++ b/src/query/src/dist_plan/planner.rs
@@ -0,0 +1,56 @@
+// Copyright 2023 Greptime Team
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//! [ExtensionPlanner] implementation for the distributed planner
+
+use std::sync::Arc;
+
+use async_trait::async_trait;
+use datafusion::common::Result;
+use datafusion::execution::context::SessionState;
+use datafusion::physical_plan::planner::ExtensionPlanner;
+use datafusion::physical_plan::{ExecutionPlan, PhysicalPlanner};
+use datafusion_common::DataFusionError;
+use datafusion_expr::{LogicalPlan, UserDefinedLogicalNode};
+
+use crate::dist_plan::merge_scan::MergeScanLogicalPlan;
+
+pub struct DistExtensionPlanner;
+
+#[async_trait]
+impl ExtensionPlanner for DistExtensionPlanner {
+    async fn plan_extension(
+        &self,
+        planner: &dyn PhysicalPlanner,
+        node: &dyn UserDefinedLogicalNode,
+        _logical_inputs: &[&LogicalPlan],
+        _physical_inputs: &[Arc<dyn ExecutionPlan>],
+        session_state: &SessionState,
+    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
+        let maybe_merge_scan = { node.as_any().downcast_ref::<MergeScanLogicalPlan>() };
+        if let Some(merge_scan) = maybe_merge_scan {
+            if merge_scan.is_placeholder() {
+                let input = merge_scan.input().clone();
+                planner
+                    .create_physical_plan(&input, session_state)
+                    .await
+                    .map(Some)
+            } else {
+                Err(DataFusionError::NotImplemented("MergeScan".to_string()))
+            }
+        } else {
+            Ok(None)
+        }
+    }
+}
diff --git a/src/query/src/dist_plan/utils.rs b/src/query/src/dist_plan/utils.rs
new file mode 100644
index 0000000000..78d2a85b1e
--- /dev/null
+++ b/src/query/src/dist_plan/utils.rs
@@ -0,0 +1,45 @@
+// Copyright 2023 Greptime Team
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::hash::{Hash, Hasher};
+
+use ahash::AHasher;
+use datafusion_expr::LogicalPlan;
+
+/// Calculate a u64 hash for a [LogicalPlan].
+pub fn hash_plan(plan: &LogicalPlan) -> u64 {
+    let mut hasher = AHasher::default();
+    plan.hash(&mut hasher);
+    hasher.finish()
+}
+
+#[cfg(test)]
+mod test {
+    use datafusion_expr::LogicalPlanBuilder;
+
+    use super::*;
+
+    #[test]
+    fn hash_two_plan() {
+        let plan1 = LogicalPlanBuilder::empty(false).build().unwrap();
+        let plan2 = LogicalPlanBuilder::empty(false)
+            .explain(false, false)
+            .unwrap()
+            .build()
+            .unwrap();
+
+        assert_eq!(hash_plan(&plan1), hash_plan(&plan1));
+        assert_ne!(hash_plan(&plan1), hash_plan(&plan2));
+    }
+}
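One caveat worth noting: `hash_plan` identifies the stop node by structure, not identity, so two separately built but identical subtrees collide. The analyzer tolerates this because `ExpandState` lets the expansion fire only once. A sketch of the collision (not part of this diff, assuming crate-internal access to `dist_plan::utils`):

    use datafusion_expr::LogicalPlanBuilder;

    use crate::dist_plan::utils::hash_plan;

    fn identical_plans_collide() -> datafusion_common::Result<()> {
        let a = LogicalPlanBuilder::empty(false).build()?;
        let b = LogicalPlanBuilder::empty(false).build()?;
        // same structure => same hash, even though they are distinct values
        assert_eq!(hash_plan(&a), hash_plan(&b));
        Ok(())
    }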
diff --git a/src/query/src/lib.rs b/src/query/src/lib.rs
index ca150b5c39..7689a29726 100644
--- a/src/query/src/lib.rs
+++ b/src/query/src/lib.rs
@@ -12,7 +12,10 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#![feature(let_chains)]
+
 pub mod datafusion;
+pub mod dist_plan;
 pub mod error;
 pub mod executor;
 pub mod logical_optimizer;
diff --git a/src/query/src/query_engine.rs b/src/query/src/query_engine.rs
index 0843ec5321..52f5ac5218 100644
--- a/src/query/src/query_engine.rs
+++ b/src/query/src/query_engine.rs
@@ -65,12 +65,20 @@ pub struct QueryEngineFactory {
 }
 
 impl QueryEngineFactory {
-    pub fn new(catalog_manager: CatalogManagerRef) -> Self {
-        Self::new_with_plugins(catalog_manager, Default::default())
+    pub fn new(catalog_manager: CatalogManagerRef, with_dist_planner: bool) -> Self {
+        Self::new_with_plugins(catalog_manager, with_dist_planner, Default::default())
     }
 
-    pub fn new_with_plugins(catalog_manager: CatalogManagerRef, plugins: Arc<Plugins>) -> Self {
-        let state = Arc::new(QueryEngineState::new(catalog_manager, plugins));
+    pub fn new_with_plugins(
+        catalog_manager: CatalogManagerRef,
+        with_dist_planner: bool,
+        plugins: Arc<Plugins>,
+    ) -> Self {
+        let state = Arc::new(QueryEngineState::new(
+            catalog_manager,
+            with_dist_planner,
+            plugins,
+        ));
         let query_engine = Arc::new(DatafusionQueryEngine::new(state));
         register_functions(&query_engine);
         Self { query_engine }
@@ -100,7 +108,7 @@ mod tests {
     #[test]
     fn test_query_engine_factory() {
         let catalog_list = catalog::local::new_memory_catalog_list().unwrap();
-        let factory = QueryEngineFactory::new(catalog_list);
+        let factory = QueryEngineFactory::new(catalog_list, false);
 
         let engine = factory.query_engine();
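Every call site updated in this diff passes `false`, so the distributed planner stays dormant; presumably a distributed frontend flips the flag in a follow-up. A sketch of what enabling it would look like (the module paths and the `true` call are assumptions, not wired up anywhere in this PR):

    use catalog::CatalogManagerRef;
    use query::query_engine::{QueryEngineFactory, QueryEngineRef};

    // Builds an engine with DistPlannerAnalyzer and DistExtensionPlanner active.
    fn dist_engine(catalog_manager: CatalogManagerRef) -> QueryEngineRef {
        QueryEngineFactory::new(catalog_manager, true).query_engine()
    }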
diff --git a/src/query/src/query_engine/state.rs b/src/query/src/query_engine/state.rs
index 4348543c9d..33ee3a602e 100644
--- a/src/query/src/query_engine/state.rs
+++ b/src/query/src/query_engine/state.rs
@@ -32,6 +32,7 @@ use datafusion_expr::LogicalPlan as DfLogicalPlan;
 use datafusion_optimizer::analyzer::Analyzer;
 use promql::extension_plan::PromExtensionPlanner;
 
+use crate::dist_plan::{DistExtensionPlanner, DistPlannerAnalyzer};
 use crate::optimizer::TypeConversionRule;
 use crate::query_engine::options::QueryOptions;
 
@@ -55,11 +56,18 @@ impl fmt::Debug for QueryEngineState {
 }
 
 impl QueryEngineState {
-    pub fn new(catalog_list: CatalogManagerRef, plugins: Arc<Plugins>) -> Self {
+    pub fn new(
+        catalog_list: CatalogManagerRef,
+        with_dist_planner: bool,
+        plugins: Arc<Plugins>,
+    ) -> Self {
         let runtime_env = Arc::new(RuntimeEnv::default());
         let session_config = SessionConfig::new().with_create_default_catalog_and_schema(false);
 
         // Apply the type conversion rule first.
         let mut analyzer = Analyzer::new();
+        if with_dist_planner {
+            analyzer.rules.insert(0, Arc::new(DistPlannerAnalyzer));
+        }
         analyzer.rules.insert(0, Arc::new(TypeConversionRule));
 
         let session_state = SessionState::with_config_rt_and_catalog_list(
@@ -139,9 +147,10 @@ impl QueryPlanner for DfQueryPlanner {
 
 impl DfQueryPlanner {
     fn new() -> Self {
         Self {
-            physical_planner: DefaultPhysicalPlanner::with_extension_planners(vec![Arc::new(
-                PromExtensionPlanner {},
-            )]),
+            physical_planner: DefaultPhysicalPlanner::with_extension_planners(vec![
+                Arc::new(PromExtensionPlanner),
+                Arc::new(DistExtensionPlanner),
+            ]),
         }
     }
 }
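Note the rule ordering above: both analyzer rules are inserted at index 0, and `TypeConversionRule` is inserted second, so it still runs before `DistPlannerAnalyzer` and the "type conversion rule first" comment stays true. A toy sketch of that insert-at-zero behavior:

    fn rule_order() {
        let mut rules = vec!["defaults"];
        rules.insert(0, "DistPlannerAnalyzer");
        rules.insert(0, "TypeConversionRule");
        assert_eq!(rules, ["TypeConversionRule", "DistPlannerAnalyzer", "defaults"]);
    }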
diff --git a/src/query/src/tests/function.rs b/src/query/src/tests/function.rs
index 91759ab150..767fb5fbd4 100644
--- a/src/query/src/tests/function.rs
+++ b/src/query/src/tests/function.rs
@@ -66,7 +66,7 @@ pub fn create_query_engine() -> Arc<dyn QueryEngine> {
         .register_catalog_sync(DEFAULT_CATALOG_NAME.to_string(), catalog_provider)
         .unwrap();
 
-    QueryEngineFactory::new(catalog_list).query_engine()
+    QueryEngineFactory::new(catalog_list, false).query_engine()
 }
 
 pub async fn get_numbers_from_table<'s, T>(
diff --git a/src/query/src/tests/my_sum_udaf_example.rs b/src/query/src/tests/my_sum_udaf_example.rs
index 22104c5572..2ae4c6cea6 100644
--- a/src/query/src/tests/my_sum_udaf_example.rs
+++ b/src/query/src/tests/my_sum_udaf_example.rs
@@ -243,5 +243,5 @@ fn new_query_engine_factory(table: MemTable) -> QueryEngineFactory {
         .register_catalog_sync(DEFAULT_CATALOG_NAME.to_string(), catalog_provider)
         .unwrap();
 
-    QueryEngineFactory::new(catalog_list)
+    QueryEngineFactory::new(catalog_list, false)
 }
diff --git a/src/query/src/tests/percentile_test.rs b/src/query/src/tests/percentile_test.rs
index 10fa1df574..7fa9e6e903 100644
--- a/src/query/src/tests/percentile_test.rs
+++ b/src/query/src/tests/percentile_test.rs
@@ -111,5 +111,5 @@ fn create_correctness_engine() -> Arc<dyn QueryEngine> {
         .register_catalog_sync(DEFAULT_CATALOG_NAME.to_string(), catalog_provider)
         .unwrap();
 
-    QueryEngineFactory::new(catalog_list).query_engine()
+    QueryEngineFactory::new(catalog_list, false).query_engine()
 }
diff --git a/src/query/src/tests/query_engine_test.rs b/src/query/src/tests/query_engine_test.rs
index 08d81e53c6..9f13bc789f 100644
--- a/src/query/src/tests/query_engine_test.rs
+++ b/src/query/src/tests/query_engine_test.rs
@@ -46,7 +46,7 @@ async fn test_datafusion_query_engine() -> Result<()> {
     let catalog_list = catalog::local::new_memory_catalog_list()
         .map_err(BoxedError::new)
         .context(QueryExecutionSnafu)?;
-    let factory = QueryEngineFactory::new(catalog_list);
+    let factory = QueryEngineFactory::new(catalog_list, false);
     let engine = factory.query_engine();
 
     let column_schemas = vec![ColumnSchema::new(
@@ -133,7 +133,7 @@ async fn test_query_validate() -> Result<()> {
     });
     let plugins = Arc::new(plugins);
 
-    let factory = QueryEngineFactory::new_with_plugins(catalog_list, plugins);
+    let factory = QueryEngineFactory::new_with_plugins(catalog_list, false, plugins);
     let engine = factory.query_engine();
 
     let stmt = QueryLanguageParser::parse_sql("select number from public.numbers").unwrap();
@@ -157,7 +157,7 @@ async fn test_udf() -> Result<()> {
     common_telemetry::init_default_ut_logging();
 
     let catalog_list = catalog_list()?;
-    let factory = QueryEngineFactory::new(catalog_list);
+    let factory = QueryEngineFactory::new(catalog_list, false);
     let engine = factory.query_engine();
 
     let pow = make_scalar_function(pow);
diff --git a/src/query/src/tests/time_range_filter_test.rs b/src/query/src/tests/time_range_filter_test.rs
index 249d32e678..e133ce4032 100644
--- a/src/query/src/tests/time_range_filter_test.rs
+++ b/src/query/src/tests/time_range_filter_test.rs
@@ -119,7 +119,7 @@ fn create_test_engine() -> TimeRangeTester {
         .register_catalog_sync("greptime".to_string(), default_catalog)
         .unwrap();
 
-    let engine = QueryEngineFactory::new(catalog_list).query_engine();
+    let engine = QueryEngineFactory::new(catalog_list, false).query_engine();
     TimeRangeTester { engine, table }
 }
diff --git a/src/script/benches/py_benchmark.rs b/src/script/benches/py_benchmark.rs
index 44c7538207..0b5eede342 100644
--- a/src/script/benches/py_benchmark.rs
+++ b/src/script/benches/py_benchmark.rs
@@ -64,7 +64,7 @@ pub(crate) fn sample_script_engine() -> PyEngine {
         .register_catalog_sync(DEFAULT_CATALOG_NAME.to_string(), default_catalog)
         .unwrap();
 
-    let factory = QueryEngineFactory::new(catalog_list);
+    let factory = QueryEngineFactory::new(catalog_list, false);
     let query_engine = factory.query_engine();
 
     PyEngine::new(query_engine.clone())
diff --git a/src/script/src/manager.rs b/src/script/src/manager.rs
index 771fd295f6..84e9ad87a2 100644
--- a/src/script/src/manager.rs
+++ b/src/script/src/manager.rs
@@ -171,7 +171,7 @@ mod tests {
             .unwrap(),
         );
 
-        let factory = QueryEngineFactory::new(catalog_manager.clone());
+        let factory = QueryEngineFactory::new(catalog_manager.clone(), false);
         let query_engine = factory.query_engine();
 
         let mgr = ScriptManager::new(catalog_manager.clone(), query_engine)
             .await
diff --git a/src/script/src/python/engine.rs b/src/script/src/python/engine.rs
index f5e01e9ec5..152e529b8d 100644
--- a/src/script/src/python/engine.rs
+++ b/src/script/src/python/engine.rs
@@ -377,7 +377,7 @@ mod tests {
             .register_catalog_sync(DEFAULT_CATALOG_NAME.to_string(), default_catalog)
             .unwrap();
 
-        let factory = QueryEngineFactory::new(catalog_list);
+        let factory = QueryEngineFactory::new(catalog_list, false);
         let query_engine = factory.query_engine();
 
         PyEngine::new(query_engine.clone())
diff --git a/src/servers/tests/mod.rs b/src/servers/tests/mod.rs
index a01436fadd..171ef03e11 100644
--- a/src/servers/tests/mod.rs
+++ b/src/servers/tests/mod.rs
@@ -211,7 +211,7 @@ fn create_testing_instance(table: MemTable) -> DummyInstance {
         .register_catalog_sync(DEFAULT_CATALOG_NAME.to_string(), catalog_provider)
         .unwrap();
 
-    let factory = QueryEngineFactory::new(catalog_list);
+    let factory = QueryEngineFactory::new(catalog_list, false);
     let query_engine = factory.query_engine();
     DummyInstance::new(query_engine)
 }