From 2189631efd54fa0ead07aaf9a4865b9ef371ba80 Mon Sep 17 00:00:00 2001 From: Zhenchi Date: Tue, 15 Apr 2025 14:45:56 +0800 Subject: [PATCH] feat: optimize `matches_term` with constant term pre-compilation (#5886) * feat: precompile finder for `matches_term` Signed-off-by: Zhenchi * fix sqlness Signed-off-by: Zhenchi --------- Signed-off-by: Zhenchi --- src/query/src/optimizer.rs | 1 + src/query/src/optimizer/constant_term.rs | 454 ++++++++++++++++++ src/query/src/query_engine/state.rs | 4 + .../common/tql-explain-analyze/explain.result | 1 + 4 files changed, 460 insertions(+) create mode 100644 src/query/src/optimizer/constant_term.rs diff --git a/src/query/src/optimizer.rs b/src/query/src/optimizer.rs index c98ad0c634..e6596e923a 100644 --- a/src/query/src/optimizer.rs +++ b/src/query/src/optimizer.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +pub mod constant_term; pub mod count_wildcard; pub mod parallelize_scan; pub mod pass_distribution; diff --git a/src/query/src/optimizer/constant_term.rs b/src/query/src/optimizer/constant_term.rs new file mode 100644 index 0000000000..60e5b76d9d --- /dev/null +++ b/src/query/src/optimizer/constant_term.rs @@ -0,0 +1,454 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::fmt; +use std::hash::{Hash, Hasher}; +use std::sync::Arc; + +use arrow::array::{AsArray, BooleanArray}; +use common_function::scalars::matches_term::MatchesTermFinder; +use datafusion::config::ConfigOptions; +use datafusion::error::Result as DfResult; +use datafusion::physical_optimizer::PhysicalOptimizerRule; +use datafusion::physical_plan::filter::FilterExec; +use datafusion::physical_plan::ExecutionPlan; +use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::ScalarValue; +use datafusion_expr::ColumnarValue; +use datafusion_physical_expr::expressions::Literal; +use datafusion_physical_expr::{PhysicalExpr, ScalarFunctionExpr}; + +/// A physical expression that uses a pre-compiled term finder for the `matches_term` function. +/// +/// This expression optimizes the `matches_term` function by pre-compiling the term +/// when the term is a constant value. This avoids recompiling the term for each row +/// during execution. +#[derive(Debug)] +pub struct PreCompiledMatchesTermExpr { + /// The text column expression to search in + text: Arc, + /// The constant term to search for + term: String, + /// The pre-compiled term finder + finder: MatchesTermFinder, +} + +impl fmt::Display for PreCompiledMatchesTermExpr { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "MatchesConstTerm({}, \"{}\")", self.text, self.term) + } +} + +impl Hash for PreCompiledMatchesTermExpr { + fn hash(&self, state: &mut H) { + self.text.hash(state); + self.term.hash(state); + } +} + +impl PartialEq for PreCompiledMatchesTermExpr { + fn eq(&self, other: &Self) -> bool { + self.text.eq(&other.text) && self.term.eq(&other.term) + } +} + +impl Eq for PreCompiledMatchesTermExpr {} + +impl PhysicalExpr for PreCompiledMatchesTermExpr { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn data_type( + &self, + _input_schema: &arrow_schema::Schema, + ) -> datafusion::error::Result { + Ok(arrow_schema::DataType::Boolean) + } + + fn nullable(&self, input_schema: &arrow_schema::Schema) -> datafusion::error::Result { + self.text.nullable(input_schema) + } + + fn evaluate( + &self, + batch: &common_recordbatch::DfRecordBatch, + ) -> datafusion::error::Result { + let num_rows = batch.num_rows(); + + let text_value = self.text.evaluate(batch)?; + let array = text_value.into_array(num_rows)?; + let str_array = array.as_string::(); + + let mut result = BooleanArray::builder(num_rows); + for text in str_array { + match text { + Some(text) => { + result.append_value(self.finder.find(text)); + } + None => { + result.append_null(); + } + } + } + + Ok(ColumnarValue::Array(Arc::new(result.finish()))) + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.text] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> datafusion::error::Result> { + Ok(Arc::new(PreCompiledMatchesTermExpr { + text: children[0].clone(), + term: self.term.clone(), + finder: self.finder.clone(), + })) + } +} + +/// Optimizer rule that pre-compiles constant term in `matches_term` function. +/// +/// This optimizer looks for `matches_term` function calls where the second argument +/// (the term to match) is a constant value. When found, it replaces the function +/// call with a specialized `PreCompiledMatchesTermExpr` that uses a pre-compiled +/// term finder. +/// +/// Example: +/// ```sql +/// -- Before optimization: +/// matches_term(text_column, 'constant_term') +/// +/// -- After optimization: +/// PreCompiledMatchesTermExpr(text_column, 'constant_term') +/// ``` +/// +/// This optimization improves performance by: +/// 1. Pre-compiling the term once instead of for each row +/// 2. Using a specialized expression that avoids function call overhead +#[derive(Debug)] +pub struct MatchesConstantTermOptimizer; + +impl PhysicalOptimizerRule for MatchesConstantTermOptimizer { + fn optimize( + &self, + plan: Arc, + _config: &ConfigOptions, + ) -> DfResult> { + let res = plan + .transform_down(&|plan: Arc| { + if let Some(filter) = plan.as_any().downcast_ref::() { + let pred = filter.predicate().clone(); + let new_pred = pred.transform_down(&|expr: Arc| { + if let Some(func) = expr.as_any().downcast_ref::() { + if !func.name().eq_ignore_ascii_case("matches_term") { + return Ok(Transformed::no(expr)); + } + let args = func.args(); + if args.len() != 2 { + return Ok(Transformed::no(expr)); + } + + if let Some(lit) = args[1].as_any().downcast_ref::() { + if let ScalarValue::Utf8(Some(term)) = lit.value() { + let finder = MatchesTermFinder::new(term); + let expr = PreCompiledMatchesTermExpr { + text: args[0].clone(), + term: term.to_string(), + finder, + }; + + return Ok(Transformed::yes(Arc::new(expr))); + } + } + } + + Ok(Transformed::no(expr)) + })?; + + if new_pred.transformed { + let exec = FilterExec::try_new(new_pred.data, filter.input().clone())? + .with_default_selectivity(filter.default_selectivity())? + .with_projection(filter.projection().cloned())?; + return Ok(Transformed::yes(Arc::new(exec) as _)); + } + } + + Ok(Transformed::no(plan)) + })? + .data; + + Ok(res) + } + + fn name(&self) -> &str { + "MatchesConstantTerm" + } + + fn schema_check(&self) -> bool { + false + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use arrow::array::{ArrayRef, StringArray}; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow::record_batch::RecordBatch; + use catalog::memory::MemoryCatalogManager; + use catalog::RegisterTableRequest; + use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; + use common_function::scalars::matches_term::MatchesTermFunction; + use common_function::scalars::udf::create_udf; + use common_function::state::FunctionState; + use datafusion::physical_optimizer::PhysicalOptimizerRule; + use datafusion::physical_plan::filter::FilterExec; + use datafusion::physical_plan::get_plan_string; + use datafusion::physical_plan::memory::MemoryExec; + use datafusion_common::{Column, DFSchema, ScalarValue}; + use datafusion_expr::expr::ScalarFunction; + use datafusion_expr::{Expr, ScalarUDF}; + use datafusion_physical_expr::{create_physical_expr, ScalarFunctionExpr}; + use datatypes::prelude::ConcreteDataType; + use datatypes::schema::ColumnSchema; + use session::context::QueryContext; + use table::metadata::{TableInfoBuilder, TableMetaBuilder}; + use table::test_util::EmptyTable; + + use super::*; + use crate::parser::QueryLanguageParser; + use crate::{QueryEngineFactory, QueryEngineRef}; + + fn create_test_batch() -> RecordBatch { + let schema = Schema::new(vec![Field::new("text", DataType::Utf8, true)]); + + let text_array = StringArray::from(vec![ + Some("hello world"), + Some("greeting"), + Some("hello there"), + None, + ]); + + RecordBatch::try_new(Arc::new(schema), vec![Arc::new(text_array) as ArrayRef]).unwrap() + } + + fn create_test_engine() -> QueryEngineRef { + let table_name = "test".to_string(); + let columns = vec![ + ColumnSchema::new( + "text".to_string(), + ConcreteDataType::string_datatype(), + false, + ), + ColumnSchema::new( + "timestamp".to_string(), + ConcreteDataType::timestamp_millisecond_datatype(), + false, + ) + .with_time_index(true), + ]; + + let schema = Arc::new(datatypes::schema::Schema::new(columns)); + let table_meta = TableMetaBuilder::empty() + .schema(schema) + .primary_key_indices(vec![]) + .value_indices(vec![0]) + .next_column_id(2) + .build() + .unwrap(); + let table_info = TableInfoBuilder::default() + .name(&table_name) + .meta(table_meta) + .build() + .unwrap(); + let table = EmptyTable::from_table_info(&table_info); + let catalog_list = MemoryCatalogManager::with_default_setup(); + assert!(catalog_list + .register_table_sync(RegisterTableRequest { + catalog: DEFAULT_CATALOG_NAME.to_string(), + schema: DEFAULT_SCHEMA_NAME.to_string(), + table_name, + table_id: 1024, + table, + }) + .is_ok()); + QueryEngineFactory::new( + catalog_list, + None, + None, + None, + None, + false, + Default::default(), + ) + .query_engine() + } + + fn matches_term_udf() -> Arc { + Arc::new(create_udf( + Arc::new(MatchesTermFunction), + QueryContext::arc(), + Arc::new(FunctionState::default()), + )) + } + + #[test] + fn test_matches_term_optimization() { + let batch = create_test_batch(); + + // Create a predicate with a constant pattern + let predicate = create_physical_expr( + &Expr::ScalarFunction(ScalarFunction::new_udf( + matches_term_udf(), + vec![ + Expr::Column(Column::from_name("text")), + Expr::Literal(ScalarValue::Utf8(Some("hello".to_string()))), + ], + )), + &DFSchema::try_from(batch.schema().clone()).unwrap(), + &Default::default(), + ) + .unwrap(); + + let input = + Arc::new(MemoryExec::try_new(&[vec![batch.clone()]], batch.schema(), None).unwrap()); + let filter = FilterExec::try_new(predicate, input).unwrap(); + + // Apply the optimizer + let optimizer = MatchesConstantTermOptimizer; + let optimized_plan = optimizer + .optimize(Arc::new(filter), &Default::default()) + .unwrap(); + + let optimized_filter = optimized_plan + .as_any() + .downcast_ref::() + .unwrap(); + let predicate = optimized_filter.predicate(); + + // The predicate should be a PreCompiledMatchesTermExpr + assert!( + std::any::TypeId::of::() == predicate.as_any().type_id() + ); + } + + #[test] + fn test_matches_term_no_optimization() { + let batch = create_test_batch(); + + // Create a predicate with a non-constant pattern + let predicate = create_physical_expr( + &Expr::ScalarFunction(ScalarFunction::new_udf( + matches_term_udf(), + vec![ + Expr::Column(Column::from_name("text")), + Expr::Column(Column::from_name("text")), + ], + )), + &DFSchema::try_from(batch.schema().clone()).unwrap(), + &Default::default(), + ) + .unwrap(); + + let input = + Arc::new(MemoryExec::try_new(&[vec![batch.clone()]], batch.schema(), None).unwrap()); + let filter = FilterExec::try_new(predicate, input).unwrap(); + + let optimizer = MatchesConstantTermOptimizer; + let optimized_plan = optimizer + .optimize(Arc::new(filter), &Default::default()) + .unwrap(); + + let optimized_filter = optimized_plan + .as_any() + .downcast_ref::() + .unwrap(); + let predicate = optimized_filter.predicate(); + + // The predicate should still be a ScalarFunctionExpr + assert!(std::any::TypeId::of::() == predicate.as_any().type_id()); + } + + #[tokio::test] + async fn test_matches_term_optimization_from_sql() { + let sql = "WITH base AS ( + SELECT text, timestamp FROM test + WHERE MATCHES_TERM(text, 'hello') + AND timestamp > '2025-01-01 00:00:00' + ), + subquery1 AS ( + SELECT * FROM base + WHERE MATCHES_TERM(text, 'world') + ), + subquery2 AS ( + SELECT * FROM test + WHERE MATCHES_TERM(text, 'greeting') + AND timestamp < '2025-01-02 00:00:00' + ), + union_result AS ( + SELECT * FROM subquery1 + UNION ALL + SELECT * FROM subquery2 + ), + joined_data AS ( + SELECT a.text, a.timestamp, b.text as other_text + FROM union_result a + JOIN test b ON a.timestamp = b.timestamp + WHERE MATCHES_TERM(a.text, 'there') + ) + SELECT text, other_text + FROM joined_data + WHERE MATCHES_TERM(text, '42') + AND MATCHES_TERM(other_text, 'foo')"; + + let query_ctx = QueryContext::arc(); + + let stmt = QueryLanguageParser::parse_sql(sql, &query_ctx).unwrap(); + let engine = create_test_engine(); + let logical_plan = engine + .planner() + .plan(&stmt, query_ctx.clone()) + .await + .unwrap(); + + let engine_ctx = engine.engine_context(query_ctx); + let state = engine_ctx.state(); + + let analyzed_plan = state + .analyzer() + .execute_and_check(logical_plan.clone(), state.config_options(), |_, _| {}) + .unwrap(); + + let optimized_plan = state + .optimizer() + .optimize(analyzed_plan, state, |_, _| {}) + .unwrap(); + + let physical_plan = state + .query_planner() + .create_physical_plan(&optimized_plan, state) + .await + .unwrap(); + + let plan_str = get_plan_string(&physical_plan).join("\n"); + assert!(plan_str.contains("MatchesConstTerm")); + assert!(!plan_str.contains("matches_term")) + } +} diff --git a/src/query/src/query_engine/state.rs b/src/query/src/query_engine/state.rs index 75e1ed84a7..03f3a2a13d 100644 --- a/src/query/src/query_engine/state.rs +++ b/src/query/src/query_engine/state.rs @@ -45,6 +45,7 @@ use table::table::adapter::DfTableProviderAdapter; use table::TableRef; use crate::dist_plan::{DistExtensionPlanner, DistPlannerAnalyzer, MergeSortExtensionPlanner}; +use crate::optimizer::constant_term::MatchesConstantTermOptimizer; use crate::optimizer::count_wildcard::CountWildcardToTimeIndexRule; use crate::optimizer::parallelize_scan::ParallelizeScan; use crate::optimizer::pass_distribution::PassDistribution; @@ -143,6 +144,9 @@ impl QueryEngineState { physical_optimizer .rules .push(Arc::new(WindowedSortPhysicalRule)); + physical_optimizer + .rules + .push(Arc::new(MatchesConstantTermOptimizer)); // Add rule to remove duplicate nodes generated by other rules. Run this in the last. physical_optimizer.rules.push(Arc::new(RemoveDuplicate)); // Place SanityCheckPlan at the end of the list to ensure that it runs after all other rules. diff --git a/tests/cases/standalone/common/tql-explain-analyze/explain.result b/tests/cases/standalone/common/tql-explain-analyze/explain.result index e1bbaa89e3..8b4952ed3d 100644 --- a/tests/cases/standalone/common/tql-explain-analyze/explain.result +++ b/tests/cases/standalone/common/tql-explain-analyze/explain.result @@ -167,6 +167,7 @@ TQL EXPLAIN VERBOSE (0, 10, '5s') test; | physical_plan after ProjectionPushdown_| SAME TEXT AS ABOVE_| | physical_plan after LimitPushdown_| SAME TEXT AS ABOVE_| | physical_plan after WindowedSortRule_| SAME TEXT AS ABOVE_| +| physical_plan after MatchesConstantTerm_| SAME TEXT AS ABOVE_| | physical_plan after RemoveDuplicateRule_| SAME TEXT AS ABOVE_| | physical_plan after SanityCheckPlan_| SAME TEXT AS ABOVE_| | physical_plan_| PromInstantManipulateExec: range=[0..0], lookback=[300000], interval=[300000], time index=[j]_|