From b7319fe2b1f448f83b7ba03669f16fb97a41655b Mon Sep 17 00:00:00 2001 From: WU Jingdi Date: Thu, 10 Aug 2023 10:53:20 +0800 Subject: [PATCH] feat: Support RangeSelect LogicalPlan rewrite (#2058) * feat: Support RangeSelect LogicalPlan rewrite * chore: fix code advice * fix: change format of range_fn * chore: optimize project plan rewrite * chore: fix code advice --- Cargo.lock | 36 +- Cargo.toml | 4 +- src/common/function/src/scalars/math.rs | 48 ++- src/frontend/Cargo.toml | 1 + src/frontend/src/instance.rs | 2 +- src/query/Cargo.toml | 1 + src/query/src/error.rs | 11 + src/query/src/lib.rs | 1 + src/query/src/planner.rs | 13 +- src/query/src/query_engine/state.rs | 3 +- src/query/src/range_select.rs | 17 + src/query/src/range_select/plan.rs | 263 ++++++++++++ src/query/src/range_select/plan_rewrite.rs | 477 +++++++++++++++++++++ src/query/src/range_select/planner.rs | 48 +++ src/sql/src/statements/statement.rs | 2 +- 15 files changed, 914 insertions(+), 13 deletions(-) create mode 100644 src/query/src/range_select.rs create mode 100644 src/query/src/range_select/plan.rs create mode 100644 src/query/src/range_select/plan_rewrite.rs create mode 100644 src/query/src/range_select/planner.rs diff --git a/Cargo.lock b/Cargo.lock index 1d0ad5ca78..23a94b6f8d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2451,7 +2451,7 @@ dependencies = [ "pin-project-lite", "rand", "smallvec", - "sqlparser", + "sqlparser 0.35.0", "tempfile", "tokio", "tokio-util", @@ -2472,7 +2472,7 @@ dependencies = [ "num_cpus", "object_store", "parquet", - "sqlparser", + "sqlparser 0.35.0", ] [[package]] @@ -2501,7 +2501,7 @@ dependencies = [ "arrow", "datafusion-common", "lazy_static", - "sqlparser", + "sqlparser 0.35.0", "strum 0.25.0", "strum_macros 0.25.1", ] @@ -2579,7 +2579,7 @@ dependencies = [ "datafusion-common", "datafusion-expr", "log", - "sqlparser", + "sqlparser 0.35.0", ] [[package]] @@ -3291,6 +3291,7 @@ dependencies = [ "session", "snafu", "sql", + "sqlparser 0.34.0", "storage", "store-api", "strfmt", @@ -7241,6 +7242,7 @@ dependencies = [ "arc-swap", "arrow", "arrow-schema", + "async-recursion", "async-stream", "async-trait", "catalog", @@ -9142,7 +9144,7 @@ dependencies = [ "itertools 0.10.5", "once_cell", "snafu", - "sqlparser", + "sqlparser 0.34.0", ] [[package]] @@ -9188,6 +9190,18 @@ dependencies = [ "tokio", ] +[[package]] +name = "sqlparser" +version = "0.34.0" +source = "git+https://github.com/GreptimeTeam/sqlparser-rs.git?rev=c3814f08afa19786b13d72b1731a1e8b3cac4ab9#c3814f08afa19786b13d72b1731a1e8b3cac4ab9" +dependencies = [ + "lazy_static", + "log", + "regex", + "sqlparser 0.35.0", + "sqlparser_derive 0.1.1 (git+https://github.com/GreptimeTeam/sqlparser-rs.git?rev=c3814f08afa19786b13d72b1731a1e8b3cac4ab9)", +] + [[package]] name = "sqlparser" version = "0.35.0" @@ -9195,7 +9209,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ca597d77c98894be1f965f2e4e2d2a61575d4998088e655476c73715c54b2b43" dependencies = [ "log", - "sqlparser_derive", + "sqlparser_derive 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)", ] [[package]] @@ -9209,6 +9223,16 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "sqlparser_derive" +version = "0.1.1" +source = "git+https://github.com/GreptimeTeam/sqlparser-rs.git?rev=c3814f08afa19786b13d72b1731a1e8b3cac4ab9#c3814f08afa19786b13d72b1731a1e8b3cac4ab9" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "sqlx" version = "0.6.3" diff --git a/Cargo.toml b/Cargo.toml index e5a2862f77..3ff5ee0cdb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -90,7 +90,9 @@ regex = "1.8" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" snafu = { version = "0.7", features = ["backtraces"] } -sqlparser = "0.35" +sqlparser = { git = "https://github.com/GreptimeTeam/sqlparser-rs.git", rev = "c3814f08afa19786b13d72b1731a1e8b3cac4ab9", features = [ + "visitor", +] } tempfile = "3" tokio = { version = "1.28", features = ["full"] } tokio-util = { version = "0.7", features = ["io-util", "compat"] } diff --git a/src/common/function/src/scalars/math.rs b/src/common/function/src/scalars/math.rs index 9329bc9448..cf68b8ff37 100644 --- a/src/common/function/src/scalars/math.rs +++ b/src/common/function/src/scalars/math.rs @@ -15,11 +15,21 @@ mod pow; mod rate; +use std::fmt; use std::sync::Arc; +use common_query::error::{GeneralDataFusionSnafu, Result}; +use common_query::prelude::Signature; +use datafusion::error::DataFusionError; +use datafusion::logical_expr::Volatility; +use datatypes::prelude::ConcreteDataType; +use datatypes::vectors::VectorRef; pub use pow::PowFunction; pub use rate::RateFunction; +use snafu::ResultExt; +use super::function::FunctionContext; +use super::Function; use crate::scalars::function_registry::FunctionRegistry; pub(crate) struct MathFunction; @@ -27,6 +37,42 @@ pub(crate) struct MathFunction; impl MathFunction { pub fn register(registry: &FunctionRegistry) { registry.register(Arc::new(PowFunction)); - registry.register(Arc::new(RateFunction)) + registry.register(Arc::new(RateFunction)); + registry.register(Arc::new(RangeFunction)) + } +} + +/// `RangeFunction` will never be used as a normal function, +/// just for datafusion to generate logical plan for RangeSelect +#[derive(Clone, Debug, Default)] +struct RangeFunction; + +impl fmt::Display for RangeFunction { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "RANGE_FN") + } +} + +impl Function for RangeFunction { + fn name(&self) -> &str { + "range_fn" + } + + // range_fn will never been used, return_type could be arbitrary value, is not important + fn return_type(&self, _input_types: &[ConcreteDataType]) -> Result { + Ok(ConcreteDataType::float64_datatype()) + } + + /// `range_fn` will never been used. As long as a legal signature is returned, the specific content of the signature does not matter. + /// In fact, the arguments loaded by `range_fn` are very complicated, and it is difficult to use `Signature` to describe + fn signature(&self) -> Signature { + Signature::any(0, Volatility::Immutable) + } + + fn eval(&self, _func_ctx: FunctionContext, _columns: &[VectorRef]) -> Result { + Err(DataFusionError::Internal( + "range_fn just a empty function used in range select, It should not be eval!".into(), + )) + .context(GeneralDataFusionSnafu) } } diff --git a/src/frontend/Cargo.toml b/src/frontend/Cargo.toml index cb2380a463..d30ebdda7c 100644 --- a/src/frontend/Cargo.toml +++ b/src/frontend/Cargo.toml @@ -61,6 +61,7 @@ servers = { workspace = true } session = { workspace = true } snafu.workspace = true sql = { workspace = true } +sqlparser = { workspace = true } storage = { workspace = true } store-api = { workspace = true } substrait = { workspace = true } diff --git a/src/frontend/src/instance.rs b/src/frontend/src/instance.rs index a467b7e2da..829f2f53e5 100644 --- a/src/frontend/src/instance.rs +++ b/src/frontend/src/instance.rs @@ -44,7 +44,6 @@ use common_meta::key::TableMetadataManager; use common_query::Output; use common_telemetry::logging::{debug, info}; use common_telemetry::timer; -use datafusion::sql::sqlparser::ast::ObjectName; use datanode::instance::sql::table_idents_to_full_name; use datanode::instance::InstanceRef as DnInstanceRef; use datatypes::schema::Schema; @@ -75,6 +74,7 @@ use sql::dialect::Dialect; use sql::parser::ParserContext; use sql::statements::copy::CopyTable; use sql::statements::statement::Statement; +use sqlparser::ast::ObjectName; use crate::catalog::FrontendCatalogManager; use crate::error::{ diff --git a/src/query/Cargo.toml b/src/query/Cargo.toml index 1667622b5d..27fd7fcb1a 100644 --- a/src/query/Cargo.toml +++ b/src/query/Cargo.toml @@ -8,6 +8,7 @@ license.workspace = true ahash = { version = "0.8", features = ["compile-time-rng"] } arc-swap = "1.0" arrow-schema.workspace = true +async-recursion = "1.0" async-stream.workspace = true async-trait = "0.1" catalog = { workspace = true } diff --git a/src/query/src/error.rs b/src/query/src/error.rs index 9e519c3cb7..2d38ba15e4 100644 --- a/src/query/src/error.rs +++ b/src/query/src/error.rs @@ -225,6 +225,15 @@ pub enum Error { source: datatypes::error::Error, location: Location, }, + #[snafu(display("Unknown table type, downcast failed, location: {}", location))] + UnknownTable { location: Location }, + + #[snafu(display( + "Cannot find time index column in table {}, location: {}", + table, + location + ))] + TimeIndexNotFound { table: String, location: Location }, } impl ErrorExt for Error { @@ -238,6 +247,8 @@ impl ErrorExt for Error { | CatalogNotFound { .. } | SchemaNotFound { .. } | TableNotFound { .. } + | UnknownTable { .. } + | TimeIndexNotFound { .. } | ParseTimestamp { .. } | ParseFloat { .. } | MissingRequiredField { .. } diff --git a/src/query/src/lib.rs b/src/query/src/lib.rs index 5b04c43623..891edbbe7d 100644 --- a/src/query/src/lib.rs +++ b/src/query/src/lib.rs @@ -29,6 +29,7 @@ pub mod physical_wrapper; pub mod plan; pub mod planner; pub mod query_engine; +mod range_select; pub mod sql; pub use crate::datafusion::DfContextProviderAdapter; diff --git a/src/query/src/planner.rs b/src/query/src/planner.rs index b64d20b9ff..df54405a0e 100644 --- a/src/query/src/planner.rs +++ b/src/query/src/planner.rs @@ -29,6 +29,7 @@ use crate::error::{PlanSqlSnafu, QueryPlanSnafu, Result, SqlSnafu}; use crate::parser::QueryStatement; use crate::plan::LogicalPlan; use crate::query_engine::QueryEngineState; +use crate::range_select::plan_rewrite::RangePlanRewriter; use crate::DfContextProviderAdapter; #[async_trait] @@ -53,6 +54,12 @@ impl DfLogicalPlanner { async fn plan_sql(&self, stmt: Statement, query_ctx: QueryContextRef) -> Result { let df_stmt = (&stmt).try_into().context(SqlSnafu)?; + let table_provider = DfTableSourceProvider::new( + self.engine_state.catalog_manager().clone(), + self.engine_state.disallow_cross_schema_query(), + query_ctx.as_ref(), + ); + let context_provider = DfContextProviderAdapter::try_new( self.engine_state.clone(), self.session_state.clone(), @@ -77,8 +84,10 @@ impl DfLogicalPlanner { }; PlanSqlSnafu { sql } })?; - - Ok(LogicalPlan::DfPlan(result)) + let plan = RangePlanRewriter::new(table_provider, context_provider) + .rewrite(result) + .await?; + Ok(LogicalPlan::DfPlan(plan)) } async fn plan_pql(&self, stmt: EvalStmt, query_ctx: QueryContextRef) -> Result { diff --git a/src/query/src/query_engine/state.rs b/src/query/src/query_engine/state.rs index ab49a7dd7b..e7343815cb 100644 --- a/src/query/src/query_engine/state.rs +++ b/src/query/src/query_engine/state.rs @@ -48,6 +48,7 @@ use crate::optimizer::order_hint::OrderHintRule; use crate::optimizer::string_normalization::StringNormalizationRule; use crate::optimizer::type_conversion::TypeConversionRule; use crate::query_engine::options::QueryOptions; +use crate::range_select::planner::RangeSelectPlanner; /// Query engine global state // TODO(yingwen): This QueryEngineState still relies on datafusion, maybe we can define a trait for it, @@ -227,7 +228,7 @@ impl DfQueryPlanner { catalog_manager: CatalogManagerRef, ) -> Self { let mut planners: Vec> = - vec![Arc::new(PromExtensionPlanner)]; + vec![Arc::new(PromExtensionPlanner), Arc::new(RangeSelectPlanner)]; if let Some(partition_manager) = partition_manager && let Some(datanode_clients) = datanode_clients { planners.push(Arc::new(DistExtensionPlanner::new(partition_manager, datanode_clients, catalog_manager))); diff --git a/src/query/src/range_select.rs b/src/query/src/range_select.rs new file mode 100644 index 0000000000..90b41f2af1 --- /dev/null +++ b/src/query/src/range_select.rs @@ -0,0 +1,17 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +pub mod plan; +pub mod plan_rewrite; +pub mod planner; diff --git a/src/query/src/range_select/plan.rs b/src/query/src/range_select/plan.rs new file mode 100644 index 0000000000..d3bda025d2 --- /dev/null +++ b/src/query/src/range_select/plan.rs @@ -0,0 +1,263 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::fmt::Display; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; +use std::time::Duration; + +use arrow_schema::{Field, Schema, SchemaRef}; +use common_query::DfPhysicalPlan; +use datafusion::common::{Result as DataFusionResult, Statistics}; +use datafusion::error::Result as DfResult; +use datafusion::execution::context::SessionState; +use datafusion::physical_plan::{ + DisplayAs, DisplayFormatType, ExecutionPlan, RecordBatchStream, SendableRecordBatchStream, +}; +use datafusion_common::{DFField, DFSchema, DFSchemaRef}; +use datafusion_expr::utils::exprlist_to_fields; +use datafusion_expr::{Expr, ExprSchemable, LogicalPlan, UserDefinedLogicalNodeCore}; +use datatypes::arrow::record_batch::RecordBatch; +use futures::{Stream, StreamExt}; +use snafu::ResultExt; + +use crate::error::{DataFusionSnafu, Result}; + +#[derive(PartialEq, Eq, Hash, Clone, Debug)] +pub struct RangeFn { + pub expr: Expr, + pub range: Duration, + pub fill: String, +} + +impl Display for RangeFn { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "RangeFn {{ expr:{} range:{}s fill:{} }}", + self.expr.display_name().unwrap_or("?".into()), + self.range.as_secs(), + self.fill, + ) + } +} + +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct RangeSelect { + /// The incoming logical plan + pub input: Arc, + /// all range expressions + pub range_expr: Vec, + pub align: Duration, + pub time_index: String, + pub by: Vec, + pub schema: DFSchemaRef, +} + +impl RangeSelect { + pub fn try_new( + input: Arc, + range_expr: Vec, + align: Duration, + time_index: Expr, + by: Vec, + ) -> Result { + let mut fields = range_expr + .iter() + .map(|RangeFn { expr, .. }| { + Ok(DFField::new_unqualified( + &expr.display_name()?, + expr.get_type(input.schema())?, + expr.nullable(input.schema())?, + )) + }) + .collect::>>() + .context(DataFusionSnafu)?; + // add align_ts + let ts_field = time_index + .to_field(input.schema().as_ref()) + .context(DataFusionSnafu)?; + let time_index_name = ts_field.name().clone(); + fields.push(ts_field); + // add by + fields.extend( + exprlist_to_fields(by.iter().collect::>(), &input).context(DataFusionSnafu)?, + ); + let schema = DFSchema::new_with_metadata(fields, input.schema().metadata().clone()) + .context(DataFusionSnafu)?; + Ok(Self { + input, + range_expr, + align, + time_index: time_index_name, + schema: Arc::new(schema), + by, + }) + } +} + +impl UserDefinedLogicalNodeCore for RangeSelect { + fn name(&self) -> &str { + "RangeSelect" + } + + fn inputs(&self) -> Vec<&LogicalPlan> { + vec![&self.input] + } + + fn schema(&self) -> &DFSchemaRef { + &self.schema + } + + fn expressions(&self) -> Vec { + self.range_expr + .iter() + .map(|expr| expr.expr.clone()) + .collect() + } + + fn fmt_for_explain(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!( + f, + "RangeSelect: range_exprs=[{}], align={}s time_index={}", + self.range_expr + .iter() + .map(ToString::to_string) + .collect::>() + .join(", "), + self.align.as_secs(), + self.time_index + ) + } + + fn from_template(&self, _exprs: &[Expr], inputs: &[LogicalPlan]) -> Self { + assert!(!inputs.is_empty()); + + Self { + align: self.align, + range_expr: self.range_expr.clone(), + input: Arc::new(inputs[0].clone()), + time_index: self.time_index.clone(), + schema: self.schema.clone(), + by: self.by.clone(), + } + } +} + +impl RangeSelect { + pub fn to_execution_plan( + &self, + _logical_input: &LogicalPlan, + exec_input: Arc, + _session_state: &SessionState, + ) -> DfResult> { + let fields: Vec<_> = self + .schema + .fields() + .iter() + .map(|field| Field::new(field.name(), field.data_type().clone(), field.is_nullable())) + .collect(); + Ok(Arc::new(RangeSelectExec { + input: exec_input, + schema: Arc::new(Schema::new(fields)), + })) + } +} + +#[derive(Debug)] +pub struct RangeSelectExec { + input: Arc, + schema: SchemaRef, +} + +impl DisplayAs for RangeSelectExec { + fn fmt_as(&self, _t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "RangeSelectExec: ") + } +} + +impl ExecutionPlan for RangeSelectExec { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn schema(&self) -> SchemaRef { + self.schema.clone() + } + + fn output_partitioning(&self) -> datafusion::physical_plan::Partitioning { + self.input.output_partitioning() + } + + fn output_ordering(&self) -> Option<&[datafusion_physical_expr::PhysicalSortExpr]> { + self.input.output_ordering() + } + + fn children(&self) -> Vec> { + vec![self.input.clone()] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> datafusion_common::Result> { + assert!(!children.is_empty()); + Ok(Arc::new(Self { + input: children[0].clone(), + schema: self.schema.clone(), + })) + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> datafusion_common::Result { + let input = self.input.execute(partition, context)?; + Ok(Box::pin(RangeSelectStream { + schema: self.schema.clone(), + input, + })) + } + + fn statistics(&self) -> Statistics { + self.input.statistics() + } +} + +pub struct RangeSelectStream { + schema: SchemaRef, + input: SendableRecordBatchStream, +} + +impl RecordBatchStream for RangeSelectStream { + fn schema(&self) -> SchemaRef { + self.schema.clone() + } +} + +impl Stream for RangeSelectStream { + type Item = DataFusionResult; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.input.poll_next_unpin(cx) { + Poll::Ready(Some(Ok(_batch))) => { + Poll::Ready(Some(Ok(RecordBatch::new_empty(self.schema.clone())))) + } + Poll::Ready(other) => Poll::Ready(other), + Poll::Pending => Poll::Pending, + } + } +} diff --git a/src/query/src/range_select/plan_rewrite.rs b/src/query/src/range_select/plan_rewrite.rs new file mode 100644 index 0000000000..643280cf3c --- /dev/null +++ b/src/query/src/range_select/plan_rewrite.rs @@ -0,0 +1,477 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::str::FromStr; +use std::sync::Arc; +use std::time::Duration; + +use async_recursion::async_recursion; +use catalog::table_source::DfTableSourceProvider; +use datafusion::datasource::DefaultTableSource; +use datafusion::prelude::Column; +use datafusion::scalar::ScalarValue; +use datafusion_common::tree_node::{TreeNode, TreeNodeRewriter, VisitRecursion}; +use datafusion_common::{DFSchema, DataFusionError, Result as DFResult}; +use datafusion_expr::expr::{AggregateFunction, AggregateUDF, ScalarUDF}; +use datafusion_expr::{ + AggregateFunction as AggregateFn, Expr, Extension, LogicalPlan, LogicalPlanBuilder, Projection, +}; +use datafusion_sql::planner::ContextProvider; +use datatypes::prelude::ConcreteDataType; +use promql_parser::util::parse_duration; +use snafu::{OptionExt, ResultExt}; +use table::table::adapter::DfTableProviderAdapter; + +use crate::error::{ + CatalogSnafu, DataFusionSnafu, Result, TimeIndexNotFoundSnafu, UnknownTableSnafu, +}; +use crate::range_select::plan::{RangeFn, RangeSelect}; +use crate::DfContextProviderAdapter; + +/// `RangeExprRewriter` will recursively search certain `Expr`, find all `range_fn` scalar udf contained in `Expr`, +/// and collect the information required by the RangeSelect query, +/// and finally modify the `range_fn` scalar udf to an ordinary column field. +pub struct RangeExprRewriter<'a> { + align: Duration, + by: Vec, + range_fn: Vec, + context_provider: &'a DfContextProviderAdapter, +} + +impl<'a> RangeExprRewriter<'a> { + pub fn gen_range_expr(&self, func_name: &str, args: Vec) -> DFResult { + match AggregateFn::from_str(func_name) { + Ok(agg_fn) => Ok(Expr::AggregateFunction(AggregateFunction::new( + agg_fn, args, false, None, None, + ))), + Err(_) => match self.context_provider.get_aggregate_meta(func_name) { + Some(agg_udf) => Ok(Expr::AggregateUDF(AggregateUDF::new( + agg_udf, args, None, None, + ))), + None => Err(DataFusionError::Plan(format!( + "{} is not a Aggregate function or a Aggregate UDF", + func_name + ))), + }, + } + } +} + +fn parse_str_expr(args: &[Expr], i: usize) -> DFResult<&str> { + match args.get(i) { + Some(Expr::Literal(ScalarValue::Utf8(Some(str)))) => Ok(str.as_str()), + _ => Err(DataFusionError::Plan("Illegal str expr in range_fn".into())), + } +} + +fn parse_expr_list(args: &[Expr], start: usize, len: usize) -> DFResult> { + let mut outs = Vec::with_capacity(len); + for i in start..start + len { + outs.push(match &args.get(i) { + Some(Expr::Column(_)) | Some(Expr::BinaryExpr(_)) => args[i].clone(), + _ => return Err(DataFusionError::Plan("Illegal expr in range_fn".into())), + }); + } + Ok(outs) +} + +impl<'a> TreeNodeRewriter for RangeExprRewriter<'a> { + type N = Expr; + + fn mutate(&mut self, node: Expr) -> DFResult { + if let Expr::ScalarUDF(func) = &node { + if func.fun.name == "range_fn" { + // `range_fn(func_name, argc, [argv], range, fill, byc, [byv], align)` + // `argsv` and `byv` are variadic arguments, argc/byc indicate the length of arguments + let func_name = parse_str_expr(&func.args, 0)?; + let argc = str::parse::(parse_str_expr(&func.args, 1)?) + .map_err(|e| DataFusionError::Plan(e.to_string()))?; + let byc = str::parse::(parse_str_expr(&func.args, argc + 4)?) + .map_err(|e| DataFusionError::Plan(e.to_string()))?; + let mut range_fn = RangeFn { + expr: Expr::Wildcard, + range: parse_duration(parse_str_expr(&func.args, argc + 2)?) + .map_err(DataFusionError::Plan)?, + fill: parse_str_expr(&func.args, argc + 3)?.to_string(), + }; + let args = parse_expr_list(&func.args, 2, argc)?; + let by = parse_expr_list(&func.args, argc + 5, byc)?; + let align = parse_duration(parse_str_expr(&func.args, argc + byc + 5)?) + .map_err(DataFusionError::Plan)?; + if !self.by.is_empty() && self.by != by { + return Err(DataFusionError::Plan( + "Inconsistent by given in Range Function Rewrite".into(), + )); + } else { + self.by = by; + } + if self.align != Duration::default() && self.align != align { + return Err(DataFusionError::Plan( + "Inconsistent align given in Range Function Rewrite".into(), + )); + } else { + self.align = align; + } + range_fn.expr = self.gen_range_expr(func_name, args)?; + let alias = Expr::Column(Column::from_name(range_fn.expr.display_name()?)); + self.range_fn.push(range_fn); + return Ok(alias); + } + } + Ok(node) + } +} + +/// In order to implement RangeSelect query like `avg(field_0) RANGE '5m' FILL NULL`, +/// All RangeSelect query items are converted into udf scalar function in sql parse stage, with format like `range_fn('avg', .....)`. +/// `range_fn` contains all the parameters we need to execute RangeSelect. +/// In order to correctly execute the query process of range select, we need to modify the query plan generated by datafusion. +/// We need to recursively find the entire LogicalPlan, and find all `range_fn` scalar udf contained in the project plan, +/// collecting info we need to generate RangeSelect Query LogicalPlan and rewrite th original LogicalPlan. +pub struct RangePlanRewriter { + table_provider: DfTableSourceProvider, + context_provider: DfContextProviderAdapter, +} + +impl RangePlanRewriter { + pub fn new( + table_provider: DfTableSourceProvider, + context_provider: DfContextProviderAdapter, + ) -> Self { + Self { + table_provider, + context_provider, + } + } + + pub async fn rewrite(&mut self, plan: LogicalPlan) -> Result { + match self.rewrite_logical_plan(&plan).await? { + Some(new_plan) => Ok(new_plan), + None => Ok(plan), + } + } + + #[async_recursion] + async fn rewrite_logical_plan(&mut self, plan: &LogicalPlan) -> Result> { + let inputs = plan.inputs(); + let mut new_inputs = Vec::with_capacity(inputs.len()); + for input in &inputs { + new_inputs.push(self.rewrite_logical_plan(input).await?) + } + match plan { + LogicalPlan::Projection(Projection { expr, input, .. }) + if have_range_in_exprs(expr) => + { + let input = if let Some(new_input) = new_inputs[0].take() { + Arc::new(new_input) + } else { + input.clone() + }; + let (time_index, default_by) = self.get_index_by(input.schema().clone()).await?; + let mut range_rewriter = RangeExprRewriter { + align: Duration::default(), + by: vec![], + range_fn: vec![], + context_provider: &self.context_provider, + }; + let new_expr = expr + .iter() + .map(|expr| expr.clone().rewrite(&mut range_rewriter)) + .collect::>>() + .context(DataFusionSnafu)?; + if range_rewriter.by.is_empty() { + range_rewriter.by = default_by; + } + let range_plan = LogicalPlan::Extension(Extension { + node: Arc::new(RangeSelect::try_new( + input.clone(), + range_rewriter.range_fn, + range_rewriter.align, + time_index, + range_rewriter.by, + )?), + }); + // If the result of the project plan happens to be the schema of the range plan, no project plan is required + // that need project is identical to range plan schema. + // 1. all exprs in project must belong to range schema + // 2. range schema and project exprs must have same size + let all_in_range_schema = new_expr.iter().all(|expr| { + if let Expr::Column(column) = expr { + range_plan.schema().has_column(column) + } else { + false + } + }); + let no_additional_project = + all_in_range_schema && new_expr.len() == range_plan.schema().fields().len(); + if no_additional_project { + Ok(Some(range_plan)) + } else { + let project_plan = LogicalPlanBuilder::from(range_plan) + .project(new_expr) + .context(DataFusionSnafu)? + .build() + .context(DataFusionSnafu)?; + Ok(Some(project_plan)) + } + } + _ => { + if new_inputs.iter().any(|x| x.is_some()) { + let inputs: Vec = new_inputs + .into_iter() + .zip(inputs) + .map(|(x, y)| match x { + Some(plan) => plan, + None => y.clone(), + }) + .collect(); + Ok(Some( + plan.with_new_inputs(&inputs).context(DataFusionSnafu)?, + )) + } else { + Ok(None) + } + } + } + } + + /// this function use to find the time_index column and row columns from input schema, + /// return `(time_index, [row_columns])` to the rewriter. + /// If the user does not explicitly use the `by` keyword to indicate time series, + /// `[row_columns]` will be use as default time series + async fn get_index_by(&mut self, schema: Arc) -> Result<(Expr, Vec)> { + let mut time_index_expr = Expr::Wildcard; + let mut default_by = vec![]; + for field in schema.fields() { + if let Some(table_ref) = field.qualifier() { + let table = self + .table_provider + .resolve_table(table_ref.clone()) + .await + .context(CatalogSnafu)? + .as_any() + .downcast_ref::() + .context(UnknownTableSnafu)? + .table_provider + .as_any() + .downcast_ref::() + .context(UnknownTableSnafu)? + .table(); + let schema = table.schema(); + let time_index_column = + schema + .timestamp_column() + .with_context(|| TimeIndexNotFoundSnafu { + table: table_ref.to_string(), + })?; + // assert time_index's datatype is timestamp + if let ConcreteDataType::Timestamp(datatypes::types::TimestampType::Millisecond( + _, + )) = time_index_column.data_type + { + default_by = table + .table_info() + .meta + .row_key_column_names() + .map(|key| Expr::Column(Column::new(Some(table_ref.clone()), key))) + .collect(); + time_index_expr = Expr::Column(Column::new( + Some(table_ref.clone()), + time_index_column.name.clone(), + )); + } + } + } + if time_index_expr == Expr::Wildcard { + TimeIndexNotFoundSnafu { + table: schema.to_string(), + } + .fail() + } else { + Ok((time_index_expr, default_by)) + } + } +} + +fn have_range_in_exprs(exprs: &Vec) -> bool { + let mut have = false; + for expr in exprs { + let _ = expr.apply(&mut |expr| { + if let Expr::ScalarUDF(ScalarUDF { fun, .. }) = expr { + if fun.name == "range_fn" { + have = true; + return Ok(VisitRecursion::Stop); + } + } + Ok(VisitRecursion::Continue) + }); + if have { + break; + } + } + have +} + +#[cfg(test)] +mod test { + + use catalog::local::MemoryCatalogManager; + use catalog::{CatalogManager, RegisterTableRequest}; + use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; + use datatypes::prelude::ConcreteDataType; + use datatypes::schema::{ColumnSchema, Schema}; + use session::context::QueryContext; + use table::metadata::{TableInfoBuilder, TableMetaBuilder}; + use table::test_util::EmptyTable; + + use super::*; + use crate::parser::QueryLanguageParser; + use crate::plan::LogicalPlan as GreptimeLogicalPlan; + use crate::{QueryEngineFactory, QueryEngineRef}; + + async fn create_test_engine() -> QueryEngineRef { + let table_name = "test".to_string(); + let mut columns = vec![]; + for i in 0..5 { + columns.push(ColumnSchema::new( + format!("tag_{i}"), + ConcreteDataType::string_datatype(), + false, + )); + } + columns.push( + ColumnSchema::new( + "timestamp".to_string(), + ConcreteDataType::timestamp_millisecond_datatype(), + false, + ) + .with_time_index(true), + ); + for i in 0..5 { + columns.push(ColumnSchema::new( + format!("field_{i}"), + ConcreteDataType::float64_datatype(), + true, + )); + } + let schema = Arc::new(Schema::new(columns)); + let table_meta = TableMetaBuilder::default() + .schema(schema) + .primary_key_indices((0..5).collect()) + .value_indices((6..11).collect()) + .next_column_id(1024) + .build() + .unwrap(); + let table_info = TableInfoBuilder::default() + .name(&table_name) + .meta(table_meta) + .build() + .unwrap(); + let table = Arc::new(EmptyTable::from_table_info(&table_info)); + let catalog_list = Arc::new(MemoryCatalogManager::default()); + assert!(catalog_list + .register_table(RegisterTableRequest { + catalog: DEFAULT_CATALOG_NAME.to_string(), + schema: DEFAULT_SCHEMA_NAME.to_string(), + table_name, + table_id: 1024, + table, + }) + .await + .is_ok()); + QueryEngineFactory::new(catalog_list, false).query_engine() + } + + async fn query_plan_compare(sql: &str, expected: String) { + let stmt = QueryLanguageParser::parse_sql(sql).unwrap(); + let engine = create_test_engine().await; + let GreptimeLogicalPlan::DfPlan(plan) = engine + .planner() + .plan(stmt, QueryContext::arc()) + .await + .unwrap(); + assert_eq!(plan.display_indent_schema().to_string(), expected); + } + + #[tokio::test] + async fn range_no_project() { + let query = r#"SELECT timestamp, tag_0, tag_1, avg(field_0 + field_1) RANGE '5m' FROM test ALIGN '1h' by (tag_0,tag_1);"#; + let expected = String::from( + "RangeSelect: range_exprs=[RangeFn { expr:AVG(test.field_0 + test.field_1) range:300s fill: }], align=3600s time_index=timestamp [AVG(test.field_0 + test.field_1):Float64;N, timestamp:Timestamp(Millisecond, None), tag_0:Utf8, tag_1:Utf8]\ + \n TableScan: test [tag_0:Utf8, tag_1:Utf8, tag_2:Utf8, tag_3:Utf8, tag_4:Utf8, timestamp:Timestamp(Millisecond, None), field_0:Float64;N, field_1:Float64;N, field_2:Float64;N, field_3:Float64;N, field_4:Float64;N]" + ); + query_plan_compare(query, expected).await; + } + + #[tokio::test] + async fn range_expr_calculation() { + let query = + r#"SELECT avg(field_0 + field_1)/4 RANGE '5m' FROM test ALIGN '1h' by (tag_0,tag_1);"#; + let expected = String::from( + "Projection: AVG(test.field_0 + test.field_1) / Int64(4) [AVG(test.field_0 + test.field_1) / Int64(4):Float64;N]\ + \n RangeSelect: range_exprs=[RangeFn { expr:AVG(test.field_0 + test.field_1) range:300s fill: }], align=3600s time_index=timestamp [AVG(test.field_0 + test.field_1):Float64;N, timestamp:Timestamp(Millisecond, None), tag_0:Utf8, tag_1:Utf8]\ + \n TableScan: test [tag_0:Utf8, tag_1:Utf8, tag_2:Utf8, tag_3:Utf8, tag_4:Utf8, timestamp:Timestamp(Millisecond, None), field_0:Float64;N, field_1:Float64;N, field_2:Float64;N, field_3:Float64;N, field_4:Float64;N]" + ); + query_plan_compare(query, expected).await; + } + + #[tokio::test] + async fn range_multi_args() { + let query = + r#"SELECT covar(field_0 + field_1, field_1)/4 RANGE '5m' FROM test ALIGN '1h';"#; + let expected = String::from( + "Projection: COVARIANCE(test.field_0 + test.field_1,test.field_1) / Int64(4) [COVARIANCE(test.field_0 + test.field_1,test.field_1) / Int64(4):Float64;N]\ + \n RangeSelect: range_exprs=[RangeFn { expr:COVARIANCE(test.field_0 + test.field_1,test.field_1) range:300s fill: }], align=3600s time_index=timestamp [COVARIANCE(test.field_0 + test.field_1,test.field_1):Float64;N, timestamp:Timestamp(Millisecond, None), tag_0:Utf8, tag_1:Utf8, tag_2:Utf8, tag_3:Utf8, tag_4:Utf8]\ + \n TableScan: test [tag_0:Utf8, tag_1:Utf8, tag_2:Utf8, tag_3:Utf8, tag_4:Utf8, timestamp:Timestamp(Millisecond, None), field_0:Float64;N, field_1:Float64;N, field_2:Float64;N, field_3:Float64;N, field_4:Float64;N]" + ); + query_plan_compare(query, expected).await; + } + + #[tokio::test] + async fn range_calculation() { + let query = r#"SELECT (avg(field_0)+sum(field_1))/4 RANGE '5m' FROM test ALIGN '1h' by (tag_0,tag_1) FILL NULL;"#; + let expected = String::from( + "Projection: (AVG(test.field_0) + SUM(test.field_1)) / Int64(4) [AVG(test.field_0) + SUM(test.field_1) / Int64(4):Float64;N]\ + \n RangeSelect: range_exprs=[RangeFn { expr:AVG(test.field_0) range:300s fill:NULL }, RangeFn { expr:SUM(test.field_1) range:300s fill:NULL }], align=3600s time_index=timestamp [AVG(test.field_0):Float64;N, SUM(test.field_1):Float64;N, timestamp:Timestamp(Millisecond, None), tag_0:Utf8, tag_1:Utf8]\ + \n TableScan: test [tag_0:Utf8, tag_1:Utf8, tag_2:Utf8, tag_3:Utf8, tag_4:Utf8, timestamp:Timestamp(Millisecond, None), field_0:Float64;N, field_1:Float64;N, field_2:Float64;N, field_3:Float64;N, field_4:Float64;N]" + ); + query_plan_compare(query, expected).await; + } + + #[tokio::test] + async fn range_as_sub_query() { + let query = r#"SELECT foo + 1 from (SELECT (avg(field_0)+sum(field_1))/4 RANGE '5m' as foo FROM test ALIGN '1h' by (tag_0,tag_1) FILL NULL) where foo > 1;"#; + let expected = String::from( + "Projection: foo + Int64(1) [foo + Int64(1):Float64;N]\ + \n Filter: foo > Int64(1) [foo:Float64;N]\ + \n Projection: (AVG(test.field_0) + SUM(test.field_1)) / Int64(4) AS foo [foo:Float64;N]\ + \n RangeSelect: range_exprs=[RangeFn { expr:AVG(test.field_0) range:300s fill:NULL }, RangeFn { expr:SUM(test.field_1) range:300s fill:NULL }], align=3600s time_index=timestamp [AVG(test.field_0):Float64;N, SUM(test.field_1):Float64;N, timestamp:Timestamp(Millisecond, None), tag_0:Utf8, tag_1:Utf8]\ + \n TableScan: test [tag_0:Utf8, tag_1:Utf8, tag_2:Utf8, tag_3:Utf8, tag_4:Utf8, timestamp:Timestamp(Millisecond, None), field_0:Float64;N, field_1:Float64;N, field_2:Float64;N, field_3:Float64;N, field_4:Float64;N]" + ); + query_plan_compare(query, expected).await; + } + + #[tokio::test] + async fn range_from_nest_query() { + let query = r#"SELECT (avg(a)+sum(b))/4 RANGE '5m' FROM (SELECT field_0 as a, field_1 as b, tag_0 as c, tag_1 as d, timestamp from test where field_0 > 1.0) ALIGN '1h' by (c, d) FILL NULL;"#; + let expected = String::from( + "Projection: (AVG(a) + SUM(b)) / Int64(4) [AVG(a) + SUM(b) / Int64(4):Float64;N]\ + \n RangeSelect: range_exprs=[RangeFn { expr:AVG(a) range:300s fill:NULL }, RangeFn { expr:SUM(b) range:300s fill:NULL }], align=3600s time_index=timestamp [AVG(a):Float64;N, SUM(b):Float64;N, timestamp:Timestamp(Millisecond, None), c:Utf8, d:Utf8]\ + \n Projection: test.field_0 AS a, test.field_1 AS b, test.tag_0 AS c, test.tag_1 AS d, test.timestamp [a:Float64;N, b:Float64;N, c:Utf8, d:Utf8, timestamp:Timestamp(Millisecond, None)]\ + \n Filter: test.field_0 > Float64(1) [tag_0:Utf8, tag_1:Utf8, tag_2:Utf8, tag_3:Utf8, tag_4:Utf8, timestamp:Timestamp(Millisecond, None), field_0:Float64;N, field_1:Float64;N, field_2:Float64;N, field_3:Float64;N, field_4:Float64;N]\ + \n TableScan: test [tag_0:Utf8, tag_1:Utf8, tag_2:Utf8, tag_3:Utf8, tag_4:Utf8, timestamp:Timestamp(Millisecond, None), field_0:Float64;N, field_1:Float64;N, field_2:Float64;N, field_3:Float64;N, field_4:Float64;N]" + ); + query_plan_compare(query, expected).await; + } +} diff --git a/src/query/src/range_select/planner.rs b/src/query/src/range_select/planner.rs new file mode 100644 index 0000000000..549a867e52 --- /dev/null +++ b/src/query/src/range_select/planner.rs @@ -0,0 +1,48 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::sync::Arc; + +use async_trait::async_trait; +use datafusion::error::Result as DfResult; +use datafusion::execution::context::SessionState; +use datafusion::logical_expr::{LogicalPlan, UserDefinedLogicalNode}; +use datafusion::physical_plan::ExecutionPlan; +use datafusion::physical_planner::{ExtensionPlanner, PhysicalPlanner}; + +use super::plan::RangeSelect; + +pub struct RangeSelectPlanner; + +#[async_trait] +impl ExtensionPlanner for RangeSelectPlanner { + async fn plan_extension( + &self, + _planner: &dyn PhysicalPlanner, + node: &dyn UserDefinedLogicalNode, + logical_inputs: &[&LogicalPlan], + physical_inputs: &[Arc], + session_state: &SessionState, + ) -> DfResult>> { + if let Some(node) = node.as_any().downcast_ref::() { + Ok(Some(node.to_execution_plan( + logical_inputs[0], + physical_inputs[0].clone(), + session_state, + )?)) + } else { + Ok(None) + } + } +} diff --git a/src/sql/src/statements/statement.rs b/src/sql/src/statements/statement.rs index ee8072331d..5263f5558d 100644 --- a/src/sql/src/statements/statement.rs +++ b/src/sql/src/statements/statement.rs @@ -91,6 +91,6 @@ impl TryFrom<&Statement> for DfStatement { .fail(); } }; - Ok(DfStatement::Statement(Box::new(s))) + Ok(DfStatement::Statement(Box::new(s.into()))) } }