diff --git a/Cargo.lock b/Cargo.lock index dbd4ab135f..827db7ade6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2770,6 +2770,7 @@ dependencies = [ "common-base", "common-catalog", "common-error", + "common-function", "common-grpc", "common-grpc-expr", "common-query", diff --git a/config/datanode.example.toml b/config/datanode.example.toml index fbdc55307c..f7e7193d87 100644 --- a/config/datanode.example.toml +++ b/config/datanode.example.toml @@ -10,10 +10,6 @@ rpc_addr = "127.0.0.1:3001" rpc_hostname = "127.0.0.1" # The number of gRPC server worker threads, 8 by default. rpc_runtime_size = 8 -# MySQL server address, "127.0.0.1:4406" by default. -mysql_addr = "127.0.0.1:4406" -# The number of MySQL server worker threads, 2 by default. -mysql_runtime_size = 2 # Metasrv client options. [meta_client_options] diff --git a/src/cmd/src/cli/repl.rs b/src/cmd/src/cli/repl.rs index 79f3c38864..144f3ac410 100644 --- a/src/cmd/src/cli/repl.rs +++ b/src/cmd/src/cli/repl.rs @@ -32,6 +32,7 @@ use query::datafusion::DatafusionQueryEngine; use query::logical_optimizer::LogicalOptimizer; use query::parser::QueryLanguageParser; use query::plan::LogicalPlan; +use query::query_engine::QueryEngineState; use query::QueryEngine; use rustyline::error::ReadlineError; use rustyline::Editor; @@ -166,12 +167,16 @@ impl Repl { self.database.catalog(), self.database.schema(), )); - let LogicalPlan::DfPlan(plan) = query_engine - .statement_to_plan(stmt, query_ctx) + + let plan = query_engine + .planner() + .plan(stmt, query_ctx) .await - .and_then(|x| query_engine.optimize(&x)) .context(PlanStatementSnafu)?; + let LogicalPlan::DfPlan(plan) = + query_engine.optimize(&plan).context(PlanStatementSnafu)?; + let plan = DFLogicalSubstraitConvertor {} .encode(plan) .context(SubstraitEncodeLogicalPlanSnafu)?; @@ -262,6 +267,7 @@ async fn create_query_engine(meta_addr: &str) -> Result { partition_manager, datanode_clients, )); + let state = Arc::new(QueryEngineState::new(catalog_list, Default::default())); - Ok(DatafusionQueryEngine::new(catalog_list, Default::default())) + Ok(DatafusionQueryEngine::new(state)) } diff --git a/src/cmd/tests/cli.rs b/src/cmd/tests/cli.rs index 2b35c1c209..84b86a23a4 100644 --- a/src/cmd/tests/cli.rs +++ b/src/cmd/tests/cli.rs @@ -46,6 +46,9 @@ mod tests { } } + // TODO(LFC): Un-ignore this REPL test. // It is ignored for now because some logical plans, like CREATE DATABASE, are not yet supported in Datanode.
+ #[ignore] #[test] fn test_repl() { let data_dir = create_temp_dir("data"); diff --git a/src/datanode/src/error.rs b/src/datanode/src/error.rs index 6f6ab2b24a..138deb1679 100644 --- a/src/datanode/src/error.rs +++ b/src/datanode/src/error.rs @@ -35,6 +35,24 @@ pub enum Error { source: query::error::Error, }, + #[snafu(display("Failed to plan statement, source: {}", source))] + PlanStatement { + #[snafu(backtrace)] + source: query::error::Error, + }, + + #[snafu(display("Failed to execute statement, source: {}", source))] + ExecuteStatement { + #[snafu(backtrace)] + source: query::error::Error, + }, + + #[snafu(display("Failed to execute logical plan, source: {}", source))] + ExecuteLogicalPlan { + #[snafu(backtrace)] + source: query::error::Error, + }, + #[snafu(display("Failed to decode logical plan, source: {}", source))] DecodeLogicalPlan { #[snafu(backtrace)] @@ -508,7 +526,12 @@ impl ErrorExt for Error { fn status_code(&self) -> StatusCode { use Error::*; match self { - ExecuteSql { source } | DescribeStatement { source } => source.status_code(), + ExecuteSql { source } + | PlanStatement { source } + | ExecuteStatement { source } + | ExecuteLogicalPlan { source } + | DescribeStatement { source } => source.status_code(), + DecodeLogicalPlan { source } => source.status_code(), NewCatalog { source } | RegisterSchema { source } => source.status_code(), FindTable { source, .. } => source.status_code(), diff --git a/src/datanode/src/instance.rs b/src/datanode/src/instance.rs index 66d1db095c..ed2faa2f39 100644 --- a/src/datanode/src/instance.rs +++ b/src/datanode/src/instance.rs @@ -78,9 +78,6 @@ pub type InstanceRef = Arc; impl Instance { pub async fn new(opts: &DatanodeOptions) -> Result { - let object_store = new_object_store(&opts.storage).await?; - let logstore = Arc::new(create_log_store(&opts.wal).await?); - let meta_client = match opts.mode { Mode::Standalone => None, Mode::Distributed => { @@ -97,11 +94,22 @@ impl Instance { let compaction_scheduler = create_compaction_scheduler(opts); + Self::new_with(opts, meta_client, compaction_scheduler).await + } + + pub(crate) async fn new_with( + opts: &DatanodeOptions, + meta_client: Option>, + compaction_scheduler: CompactionSchedulerRef, + ) -> Result { + let object_store = new_object_store(&opts.storage).await?; + let log_store = Arc::new(create_log_store(&opts.wal).await?); + let table_engine = Arc::new(DefaultEngine::new( TableEngineConfig::default(), EngineImpl::new( StorageEngineConfig::from(opts), - logstore.clone(), + log_store.clone(), object_store.clone(), compaction_scheduler, ), @@ -109,7 +117,7 @@ impl Instance { )); // create remote catalog manager - let (catalog_manager, factory, table_id_provider) = match opts.mode { + let (catalog_manager, table_id_provider) = match opts.mode { Mode::Standalone => { if opts.enable_memory_catalog { let catalog = Arc::new(catalog::local::MemoryCatalogManager::default()); @@ -126,11 +134,8 @@ impl Instance { .await .expect("Failed to register numbers"); - let factory = QueryEngineFactory::new(catalog.clone()); - ( catalog.clone() as CatalogManagerRef, - factory, Some(catalog as TableIdProviderRef), ) } else { @@ -139,11 +144,9 @@ impl Instance { .await .context(CatalogSnafu)?, ); - let factory = QueryEngineFactory::new(catalog.clone()); ( catalog.clone() as CatalogManagerRef, - factory, Some(catalog as TableIdProviderRef), ) } @@ -157,11 +160,11 @@ impl Instance { client: meta_client.as_ref().unwrap().clone(), }), )); - let factory = QueryEngineFactory::new(catalog.clone()); - (catalog as 
CatalogManagerRef, factory, None) + (catalog as CatalogManagerRef, None) } }; + let factory = QueryEngineFactory::new(catalog_manager.clone()); let query_engine = factory.query_engine(); let script_executor = ScriptExecutor::new(catalog_manager.clone(), query_engine.clone()).await?; @@ -244,6 +247,10 @@ impl Instance { pub fn catalog_manager(&self) -> &CatalogManagerRef { &self.catalog_manager } + + pub fn query_engine(&self) -> QueryEngineRef { + self.query_engine.clone() + } } fn create_compaction_scheduler(opts: &DatanodeOptions) -> CompactionSchedulerRef { diff --git a/src/datanode/src/instance/grpc.rs b/src/datanode/src/instance/grpc.rs index 8b010f2ddf..1664a53e9a 100644 --- a/src/datanode/src/instance/grpc.rs +++ b/src/datanode/src/instance/grpc.rs @@ -18,15 +18,19 @@ use api::v1::query_request::Query; use api::v1::{CreateDatabaseExpr, DdlRequest, InsertRequest}; use async_trait::async_trait; use common_query::Output; -use query::parser::{PromQuery, QueryLanguageParser}; +use query::parser::{PromQuery, QueryLanguageParser, QueryStatement}; use query::plan::LogicalPlan; use servers::query_handler::grpc::GrpcQueryHandler; use session::context::QueryContextRef; use snafu::prelude::*; +use sql::statements::statement::Statement; use substrait::{DFLogicalSubstraitConvertor, SubstraitPlan}; use table::requests::CreateDatabaseRequest; -use crate::error::{self, DecodeLogicalPlanSnafu, ExecuteSqlSnafu, Result}; +use crate::error::{ + self, DecodeLogicalPlanSnafu, ExecuteLogicalPlanSnafu, ExecuteSqlSnafu, PlanStatementSnafu, + Result, +}; use crate::instance::Instance; impl Instance { @@ -51,16 +55,32 @@ impl Instance { self.query_engine .execute(&LogicalPlan::DfPlan(logical_plan)) .await - .context(ExecuteSqlSnafu) + .context(ExecuteLogicalPlanSnafu) } async fn handle_query(&self, query: Query, ctx: QueryContextRef) -> Result { - Ok(match query { + match query { Query::Sql(sql) => { let stmt = QueryLanguageParser::parse_sql(&sql).context(ExecuteSqlSnafu)?; - self.execute_stmt(stmt, ctx).await? + match stmt { + // TODO(LFC): Remove SQL execution branch here. + // Keep this because substrait can't handle many SQL statements yet. + QueryStatement::Sql(Statement::Query(_)) => { + let plan = self + .query_engine + .planner() + .plan(stmt, ctx) + .await + .context(PlanStatementSnafu)?; + self.query_engine + .execute(&plan) + .await + .context(ExecuteLogicalPlanSnafu) + } + _ => self.execute_stmt(stmt, ctx).await, + } } - Query::LogicalPlan(plan) => self.execute_logical(plan).await?, + Query::LogicalPlan(plan) => self.execute_logical(plan).await, Query::PromRangeQuery(promql) => { let prom_query = PromQuery { query: promql.query, @@ -68,9 +88,9 @@ start: promql.start, end: promql.end, step: promql.step, }; - self.execute_promql(&prom_query, ctx).await? 
+ self.execute_promql(&prom_query, ctx).await } - }) + } } pub async fn handle_insert( @@ -141,11 +161,23 @@ mod test { }; use common_recordbatch::RecordBatches; use datatypes::prelude::*; + use query::parser::QueryLanguageParser; use session::context::QueryContext; use super::*; use crate::tests::test_util::{self, MockInstance}; + async fn exec_selection(instance: &Instance, sql: &str) -> Output { + let stmt = QueryLanguageParser::parse_sql(sql).unwrap(); + let engine = instance.query_engine(); + let plan = engine + .planner() + .plan(stmt, QueryContext::arc()) + .await + .unwrap(); + engine.execute(&plan).await.unwrap() + } + #[tokio::test(flavor = "multi_thread")] async fn test_handle_ddl() { let instance = MockInstance::new("test_handle_ddl").await; @@ -208,22 +240,17 @@ mod test { let output = instance.do_query(query, QueryContext::arc()).await.unwrap(); assert!(matches!(output, Output::AffectedRows(0))); + let stmt = QueryLanguageParser::parse_sql( + "INSERT INTO my_database.my_table (a, b, ts) VALUES ('s', 1, 1672384140000)", + ) + .unwrap(); let output = instance - .execute_sql( - "INSERT INTO my_database.my_table (a, b, ts) VALUES ('s', 1, 1672384140000)", - QueryContext::arc(), - ) + .execute_stmt(stmt, QueryContext::arc()) .await .unwrap(); assert!(matches!(output, Output::AffectedRows(1))); - let output = instance - .execute_sql( - "SELECT ts, a, b FROM my_database.my_table", - QueryContext::arc(), - ) - .await - .unwrap(); + let output = exec_selection(instance, "SELECT ts, a, b FROM my_database.my_table").await; let Output::Stream(stream) = output else { unreachable!() }; let recordbatches = RecordBatches::try_collect(stream).await.unwrap(); let expected = "\ @@ -289,10 +316,7 @@ mod test { let output = instance.do_query(query, QueryContext::arc()).await.unwrap(); assert!(matches!(output, Output::AffectedRows(3))); - let output = instance - .execute_sql("SELECT ts, host, cpu FROM demo", QueryContext::arc()) - .await - .unwrap(); + let output = exec_selection(instance, "SELECT ts, host, cpu FROM demo").await; let Output::Stream(stream) = output else { unreachable!() }; let recordbatches = RecordBatches::try_collect(stream).await.unwrap(); let expected = "\ diff --git a/src/datanode/src/instance/sql.rs b/src/datanode/src/instance/sql.rs index 45019d2675..becbdbdea3 100644 --- a/src/datanode/src/instance/sql.rs +++ b/src/datanode/src/instance/sql.rs @@ -17,27 +17,28 @@ use std::time::{Duration, SystemTime}; use async_trait::async_trait; use common_error::prelude::BoxedError; use common_query::Output; -use common_recordbatch::RecordBatches; use common_telemetry::logging::info; use common_telemetry::timer; -use datatypes::schema::Schema; use futures::StreamExt; +use query::error::QueryExecutionSnafu; use query::parser::{PromQuery, QueryLanguageParser, QueryStatement}; +use query::query_engine::StatementHandler; use servers::error as server_error; use servers::prom::PromHandler; -use servers::query_handler::sql::SqlQueryHandler; use session::context::{QueryContext, QueryContextRef}; use snafu::prelude::*; use sql::ast::ObjectName; use sql::statements::copy::CopyTable; use sql::statements::statement::Statement; -use sql::statements::tql::Tql; use table::engine::TableReference; use table::requests::{ CopyTableFromRequest, CopyTableRequest, CreateDatabaseRequest, DropTableRequest, }; -use crate::error::{self, BumpTableIdSnafu, ExecuteSqlSnafu, Result, TableIdProviderNotFoundSnafu}; +use crate::error::{ + self, BumpTableIdSnafu, ExecuteSqlSnafu, ExecuteStatementSnafu, 
PlanStatementSnafu, Result, + TableIdProviderNotFoundSnafu, +}; use crate::instance::Instance; use crate::metric; use crate::sql::insert::InsertRequests; @@ -50,18 +51,6 @@ impl Instance { query_ctx: QueryContextRef, ) -> Result { match stmt { - QueryStatement::Sql(Statement::Query(_)) | QueryStatement::Promql(_) => { - let logical_plan = self - .query_engine - .statement_to_plan(stmt, query_ctx) - .await - .context(ExecuteSqlSnafu)?; - - self.query_engine - .execute(&logical_plan) - .await - .context(ExecuteSqlSnafu) - } QueryStatement::Sql(Statement::Insert(insert)) => { let requests = self .sql_handler @@ -163,11 +152,6 @@ impl Instance { .execute(SqlRequest::ShowTables(show_tables), query_ctx) .await } - QueryStatement::Sql(Statement::Explain(explain)) => { - self.sql_handler - .execute(SqlRequest::Explain(Box::new(explain)), query_ctx) - .await - } QueryStatement::Sql(Statement::DescribeTable(describe_table)) => { self.sql_handler .execute(SqlRequest::DescribeTable(describe_table), query_ctx) @@ -176,17 +160,6 @@ impl Instance { QueryStatement::Sql(Statement::ShowCreateTable(_show_create_table)) => { unimplemented!("SHOW CREATE TABLE is unimplemented yet"); } - QueryStatement::Sql(Statement::Use(ref schema)) => { - let catalog = &query_ctx.current_catalog(); - ensure!( - self.is_valid_schema(catalog, schema)?, - error::DatabaseNotFoundSnafu { catalog, schema } - ); - - query_ctx.set_current_schema(schema); - - Ok(Output::RecordBatches(RecordBatches::empty())) - } QueryStatement::Sql(Statement::Copy(copy_table)) => match copy_table { CopyTable::To(copy_table) => { let (catalog_name, schema_name, table_name) = @@ -220,49 +193,30 @@ impl Instance { .await } }, - QueryStatement::Sql(Statement::Tql(tql)) => self.execute_tql(tql, query_ctx).await, + QueryStatement::Sql(Statement::Query(_)) + | QueryStatement::Sql(Statement::Explain(_)) + | QueryStatement::Sql(Statement::Use(_)) + | QueryStatement::Sql(Statement::Tql(_)) + | QueryStatement::Promql(_) => unreachable!(), } } - pub(crate) async fn execute_tql(&self, tql: Tql, query_ctx: QueryContextRef) -> Result { - match tql { - Tql::Eval(eval) => { - let promql = PromQuery { - start: eval.start, - end: eval.end, - step: eval.step, - query: eval.query, - }; - let stmt = QueryLanguageParser::parse_promql(&promql).context(ExecuteSqlSnafu)?; - let logical_plan = self - .query_engine - .statement_to_plan(stmt, query_ctx) - .await - .context(ExecuteSqlSnafu)?; - - self.query_engine - .execute(&logical_plan) - .await - .context(ExecuteSqlSnafu) - } - Tql::Explain(_explain) => { - todo!("waiting for promql-parser ast adding a explain node") - } - } - } - - pub async fn execute_sql(&self, sql: &str, query_ctx: QueryContextRef) -> Result { - let stmt = QueryLanguageParser::parse_sql(sql).context(ExecuteSqlSnafu)?; - self.execute_stmt(stmt, query_ctx).await - } - pub async fn execute_promql( &self, promql: &PromQuery, query_ctx: QueryContextRef, ) -> Result { + let _timer = timer!(metric::METRIC_HANDLE_PROMQL_ELAPSED); + let stmt = QueryLanguageParser::parse_promql(promql).context(ExecuteSqlSnafu)?; - self.execute_stmt(stmt, query_ctx).await + + let engine = self.query_engine(); + let plan = engine + .planner() + .plan(stmt, query_ctx) + .await + .context(PlanStatementSnafu)?; + engine.execute(&plan).await.context(ExecuteStatementSnafu) } // TODO(ruihang): merge this and `execute_promql` after #951 landed @@ -291,7 +245,14 @@ impl Instance { eval_stmt.lookback_delta = lookback } } - self.execute_stmt(stmt, query_ctx).await + + let engine = 
self.query_engine(); + let plan = engine + .planner() + .plan(stmt, query_ctx) + .await + .context(PlanStatementSnafu)?; + engine.execute(&plan).await.context(ExecuteStatementSnafu) } } @@ -327,57 +288,16 @@ pub fn table_idents_to_full_name( } #[async_trait] -impl SqlQueryHandler for Instance { - type Error = error::Error; - - async fn do_query(&self, query: &str, query_ctx: QueryContextRef) -> Vec> { - let _timer = timer!(metric::METRIC_HANDLE_SQL_ELAPSED); - // we assume sql string has only 1 statement in datanode - let result = self.execute_sql(query, query_ctx).await; - vec![result] - } - - async fn do_promql_query( +impl StatementHandler for Instance { + async fn handle_statement( &self, - query: &PromQuery, + stmt: QueryStatement, query_ctx: QueryContextRef, - ) -> Vec> { - let _timer = timer!(metric::METRIC_HANDLE_PROMQL_ELAPSED); - let result = self.execute_promql(query, query_ctx).await; - vec![result] - } - - async fn do_statement_query( - &self, - stmt: Statement, - query_ctx: QueryContextRef, - ) -> Result { - let _timer = timer!(metric::METRIC_HANDLE_SQL_ELAPSED); - self.execute_stmt(QueryStatement::Sql(stmt), query_ctx) + ) -> query::error::Result { + self.execute_stmt(stmt, query_ctx) .await - } - - async fn do_describe( - &self, - stmt: Statement, - query_ctx: QueryContextRef, - ) -> Result> { - if let Statement::Query(_) = stmt { - self.query_engine - .describe(QueryStatement::Sql(stmt), query_ctx) - .await - .map(Some) - .context(error::DescribeStatementSnafu) - } else { - Ok(None) - } - } - - fn is_valid_schema(&self, catalog: &str, schema: &str) -> Result { - self.catalog_manager - .schema(catalog, schema) - .map(|s| s.is_some()) - .context(error::CatalogSnafu) + .map_err(BoxedError::new) + .context(QueryExecutionSnafu) } } diff --git a/src/datanode/src/lib.rs b/src/datanode/src/lib.rs index b88b240325..6acd1fd94a 100644 --- a/src/datanode/src/lib.rs +++ b/src/datanode/src/lib.rs @@ -13,12 +13,13 @@ // limitations under the License. #![feature(assert_matches)] +#![feature(trait_upcasting)] pub mod datanode; pub mod error; mod heartbeat; pub mod instance; -mod metric; +pub mod metric; mod mock; mod script; pub mod server; diff --git a/src/datanode/src/mock.rs b/src/datanode/src/mock.rs index 0a46d2dbcc..179fda6ac2 100644 --- a/src/datanode/src/mock.rs +++ b/src/datanode/src/mock.rs @@ -12,32 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-use std::sync::atomic::{AtomicU32, Ordering}; use std::sync::Arc; -use catalog::remote::MetaKvBackend; -use catalog::CatalogManagerRef; -use common_catalog::consts::MIN_USER_TABLE_ID; use meta_client::client::{MetaClient, MetaClientBuilder}; use meta_srv::mocks::MockInfo; -use mito::config::EngineConfig as TableEngineConfig; -use query::QueryEngineFactory; -use servers::Mode; -use snafu::ResultExt; use storage::compaction::noop::NoopCompactionScheduler; -use storage::config::EngineConfig as StorageEngineConfig; -use storage::EngineImpl; -use table::metadata::TableId; -use table::table::TableIdProvider; use crate::datanode::DatanodeOptions; -use crate::error::{CatalogSnafu, RecoverProcedureSnafu, Result}; -use crate::heartbeat::HeartbeatTask; -use crate::instance::{ - create_log_store, create_procedure_manager, new_object_store, DefaultEngine, Instance, -}; -use crate::script::ScriptExecutor; -use crate::sql::SqlHandler; +use crate::error::Result; +use crate::instance::Instance; impl Instance { pub async fn with_mock_meta_client(opts: &DatanodeOptions) -> Result { @@ -46,98 +29,9 @@ impl Instance { } pub async fn with_mock_meta_server(opts: &DatanodeOptions, meta_srv: MockInfo) -> Result { - let object_store = new_object_store(&opts.storage).await?; - let logstore = Arc::new(create_log_store(&opts.wal).await?); let meta_client = Arc::new(mock_meta_client(meta_srv, opts.node_id.unwrap_or(42)).await); let compaction_scheduler = Arc::new(NoopCompactionScheduler::default()); - let table_engine = Arc::new(DefaultEngine::new( - TableEngineConfig::default(), - EngineImpl::new( - StorageEngineConfig::default(), - logstore.clone(), - object_store.clone(), - compaction_scheduler, - ), - object_store, - )); - - // By default, catalog manager and factory are created in standalone mode - let (catalog_manager, factory, heartbeat_task) = match opts.mode { - Mode::Standalone => { - let catalog = Arc::new( - catalog::local::LocalCatalogManager::try_new(table_engine.clone()) - .await - .context(CatalogSnafu)?, - ); - let factory = QueryEngineFactory::new(catalog.clone()); - (catalog as CatalogManagerRef, factory, None) - } - Mode::Distributed => { - let catalog = Arc::new(catalog::remote::RemoteCatalogManager::new( - table_engine.clone(), - opts.node_id.unwrap_or(42), - Arc::new(MetaKvBackend { - client: meta_client.clone(), - }), - )); - let factory = QueryEngineFactory::new(catalog.clone()); - let heartbeat_task = HeartbeatTask::new( - opts.node_id.unwrap_or(42), - opts.rpc_addr.clone(), - None, - meta_client.clone(), - catalog.clone(), - ); - (catalog as CatalogManagerRef, factory, Some(heartbeat_task)) - } - }; - let query_engine = factory.query_engine(); - let script_executor = - ScriptExecutor::new(catalog_manager.clone(), query_engine.clone()).await?; - - let procedure_manager = create_procedure_manager(&opts.procedure).await?; - if let Some(procedure_manager) = &procedure_manager { - table_engine.register_procedure_loaders(&**procedure_manager); - // Recover procedures. 
- procedure_manager - .recover() - .await - .context(RecoverProcedureSnafu)?; - } - - Ok(Self { - query_engine: query_engine.clone(), - sql_handler: SqlHandler::new( - table_engine.clone(), - catalog_manager.clone(), - query_engine.clone(), - table_engine, - procedure_manager, - ), - catalog_manager, - script_executor, - table_id_provider: Some(Arc::new(LocalTableIdProvider::default())), - heartbeat_task, - }) - } -} - -struct LocalTableIdProvider { - inner: Arc, -} - -impl Default for LocalTableIdProvider { - fn default() -> Self { - Self { - inner: Arc::new(AtomicU32::new(MIN_USER_TABLE_ID)), - } - } -} - -#[async_trait::async_trait] -impl TableIdProvider for LocalTableIdProvider { - async fn next_table_id(&self) -> table::Result { - Ok(self.inner.fetch_add(1, Ordering::Relaxed)) + Instance::new_with(opts, Some(meta_client), compaction_scheduler).await } } diff --git a/src/datanode/src/server.rs b/src/datanode/src/server.rs index b62e20c17d..2417625b0e 100644 --- a/src/datanode/src/server.rs +++ b/src/datanode/src/server.rs @@ -17,19 +17,12 @@ use std::net::SocketAddr; use std::sync::Arc; use common_runtime::Builder as RuntimeBuilder; -use common_telemetry::tracing::log::info; -use servers::error::Error::InternalIo; use servers::grpc::GrpcServer; -use servers::mysql::server::{MysqlServer, MysqlSpawnConfig, MysqlSpawnRef}; use servers::query_handler::grpc::ServerGrpcQueryHandlerAdaptor; -use servers::query_handler::sql::ServerSqlQueryHandlerAdaptor; use servers::server::Server; -use servers::tls::TlsOption; -use servers::Mode; use snafu::ResultExt; use crate::datanode::DatanodeOptions; -use crate::error::Error::StartServer; use crate::error::{ ParseAddrSnafu, Result, RuntimeResourceSnafu, ShutdownServerSnafu, StartServerSnafu, }; @@ -40,7 +33,6 @@ pub mod grpc; /// All rpc services. pub struct Services { grpc_server: GrpcServer, - mysql_server: Option>, } impl Services { @@ -53,48 +45,12 @@ impl Services { .context(RuntimeResourceSnafu)?, ); - let mysql_server = match opts.mode { - Mode::Standalone => { - info!("Disable MySQL server on datanode when running in standalone mode"); - None - } - Mode::Distributed => { - let mysql_io_runtime = Arc::new( - RuntimeBuilder::default() - .worker_threads(opts.mysql_runtime_size) - .thread_name("mysql-io-handlers") - .build() - .context(RuntimeResourceSnafu)?, - ); - let tls = TlsOption::default(); - // default tls config returns None - // but try to think a better way to do this - Some(MysqlServer::create_server( - mysql_io_runtime, - Arc::new(MysqlSpawnRef::new( - ServerSqlQueryHandlerAdaptor::arc(instance.clone()), - None, - )), - Arc::new(MysqlSpawnConfig::new( - tls.should_force_tls(), - tls.setup() - .map_err(|e| StartServer { - source: InternalIo { source: e }, - })? 
- .map(Arc::new), - false, - )), - )) - } - }; - Ok(Self { grpc_server: GrpcServer::new( ServerGrpcQueryHandlerAdaptor::arc(instance), None, grpc_runtime, ), - mysql_server, }) } @@ -102,32 +58,17 @@ impl Services { let grpc_addr: SocketAddr = opts.rpc_addr.parse().context(ParseAddrSnafu { addr: &opts.rpc_addr, })?; - - let mut res = vec![self.grpc_server.start(grpc_addr)]; - if let Some(mysql_server) = &self.mysql_server { - let mysql_addr = &opts.mysql_addr; - let mysql_addr: SocketAddr = mysql_addr - .parse() - .context(ParseAddrSnafu { addr: mysql_addr })?; - res.push(mysql_server.start(mysql_addr)); - }; - - futures::future::try_join_all(res) + self.grpc_server + .start(grpc_addr) .await .context(StartServerSnafu)?; Ok(()) } pub async fn shutdown(&self) -> Result<()> { - let mut res = vec![self.grpc_server.shutdown()]; - if let Some(mysql_server) = &self.mysql_server { - res.push(mysql_server.shutdown()); - } - - futures::future::try_join_all(res) + self.grpc_server + .shutdown() .await - .context(ShutdownServerSnafu)?; - - Ok(()) + .context(ShutdownServerSnafu) } } diff --git a/src/datanode/src/sql.rs b/src/datanode/src/sql.rs index e3e10f5c26..cfa804df47 100644 --- a/src/datanode/src/sql.rs +++ b/src/datanode/src/sql.rs @@ -18,12 +18,11 @@ use common_procedure::ProcedureManagerRef; use common_query::Output; use common_telemetry::error; use query::query_engine::QueryEngineRef; -use query::sql::{describe_table, explain, show_databases, show_tables}; +use query::sql::{describe_table, show_databases, show_tables}; use session::context::QueryContextRef; use snafu::{OptionExt, ResultExt}; use sql::statements::delete::Delete; use sql::statements::describe::DescribeTable; -use sql::statements::explain::Explain; use sql::statements::show::{ShowDatabases, ShowTables}; use table::engine::{EngineContext, TableEngineProcedureRef, TableEngineRef, TableReference}; use table::requests::*; @@ -52,7 +51,6 @@ pub enum SqlRequest { ShowDatabases(ShowDatabases), ShowTables(ShowTables), DescribeTable(DescribeTable), - Explain(Box), Delete(Delete), CopyTable(CopyTableRequest), CopyTableFrom(CopyTableFromRequest), @@ -118,9 +116,6 @@ impl SqlHandler { })?; describe_table(table).context(ExecuteSqlSnafu) } - SqlRequest::Explain(req) => explain(req, self.query_engine.clone(), query_ctx.clone()) - .await - .context(ExecuteSqlSnafu), }; if let Err(e) = &result { error!(e; "{query_ctx}"); diff --git a/src/datanode/src/sql/insert.rs b/src/datanode/src/sql/insert.rs index 2b6c7a044f..a60100cabc 100644 --- a/src/datanode/src/sql/insert.rs +++ b/src/datanode/src/sql/insert.rs @@ -39,8 +39,8 @@ use table::TableRef; use crate::error::{ CatalogSnafu, CollectRecordsSnafu, ColumnDefaultValueSnafu, ColumnNoneDefaultValueSnafu, ColumnNotFoundSnafu, ColumnTypeMismatchSnafu, ColumnValuesNumberMismatchSnafu, Error, - ExecuteSqlSnafu, InsertSnafu, MissingInsertBodySnafu, ParseSqlSnafu, ParseSqlValueSnafu, - Result, TableNotFoundSnafu, + ExecuteLogicalPlanSnafu, InsertSnafu, MissingInsertBodySnafu, ParseSqlSnafu, + ParseSqlValueSnafu, PlanStatementSnafu, Result, TableNotFoundSnafu, }; use crate::sql::{table_idents_to_full_name, SqlHandler, SqlRequest}; @@ -236,18 +236,19 @@ impl SqlHandler { let logical_plan = self .query_engine - .statement_to_plan( + .planner() + .plan( QueryStatement::Sql(Statement::Query(Box::new(query))), query_ctx.clone(), ) .await - .context(ExecuteSqlSnafu)?; + .context(PlanStatementSnafu)?; let output = self .query_engine .execute(&logical_plan) .await - .context(ExecuteSqlSnafu)?; + 
.context(ExecuteLogicalPlanSnafu)?; let stream: InsertRequestStream = match output { Output::RecordBatches(batches) => { diff --git a/src/datanode/src/tests.rs b/src/datanode/src/tests.rs index 1841b00baa..eb3b33b0fc 100644 --- a/src/datanode/src/tests.rs +++ b/src/datanode/src/tests.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +// TODO(LFC): These tests should be moved to frontend crate. They are actually standalone instance tests. mod instance_test; mod promql_test; pub(crate) mod test_util; diff --git a/src/datanode/src/tests/instance_test.rs b/src/datanode/src/tests/instance_test.rs index e311b45a6a..9ec4682b89 100644 --- a/src/datanode/src/tests/instance_test.rs +++ b/src/datanode/src/tests/instance_test.rs @@ -19,9 +19,12 @@ use common_query::Output; use common_recordbatch::util; use datatypes::data_type::ConcreteDataType; use datatypes::vectors::{Int64Vector, StringVector, UInt64Vector, VectorRef}; +use query::parser::{QueryLanguageParser, QueryStatement}; use session::context::QueryContext; +use snafu::ResultExt; +use sql::statements::statement::Statement; -use crate::error::Error; +use crate::error::{Error, ExecuteLogicalPlanSnafu, PlanStatementSnafu}; use crate::tests::test_util::{self, check_output_stream, setup_test_instance, MockInstance}; #[tokio::test(flavor = "multi_thread")] @@ -414,7 +417,6 @@ pub async fn test_execute_create() { #[tokio::test] async fn test_rename_table() { - common_telemetry::init_default_ut_logging(); let instance = MockInstance::new("test_rename_table_local").await; let output = execute_sql(&instance, "create database db").await; @@ -933,7 +935,20 @@ async fn try_execute_sql_in_db( db: &str, ) -> Result { let query_ctx = Arc::new(QueryContext::with(DEFAULT_CATALOG_NAME, db)); - instance.inner().execute_sql(sql, query_ctx).await + + let stmt = QueryLanguageParser::parse_sql(sql).unwrap(); + match stmt { + QueryStatement::Sql(Statement::Query(_)) => { + let engine = instance.inner().query_engine(); + let plan = engine + .planner() + .plan(stmt, query_ctx) + .await + .context(PlanStatementSnafu)?; + engine.execute(&plan).await.context(ExecuteLogicalPlanSnafu) + } + _ => instance.inner().execute_stmt(stmt, query_ctx).await, + } } async fn execute_sql_in_db(instance: &MockInstance, sql: &str, db: &str) -> Output { diff --git a/src/datanode/src/tests/promql_test.rs b/src/datanode/src/tests/promql_test.rs index e7030e9de9..a14bb16ea2 100644 --- a/src/datanode/src/tests/promql_test.rs +++ b/src/datanode/src/tests/promql_test.rs @@ -31,22 +31,14 @@ async fn create_insert_query_assert( expected: &str, ) { let instance = setup_test_instance("test_execute_insert").await; - let query_ctx = QueryContext::arc(); - instance - .inner() - .execute_sql(create, query_ctx.clone()) - .await - .unwrap(); - instance - .inner() - .execute_sql(insert, query_ctx.clone()) - .await - .unwrap(); + instance.execute_sql(create).await; + + instance.execute_sql(insert).await; let query_output = instance .inner() - .execute_promql_statement(promql, start, end, interval, lookback, query_ctx) + .execute_promql_statement(promql, start, end, interval, lookback, QueryContext::arc()) .await .unwrap(); let expected = String::from(expected); @@ -56,24 +48,12 @@ async fn create_insert_query_assert( #[allow(clippy::too_many_arguments)] async fn create_insert_tql_assert(create: &str, insert: &str, tql: &str, expected: &str) { let instance = setup_test_instance("test_execute_insert").await; - let query_ctx = 
QueryContext::arc(); - instance - .inner() - .execute_sql(create, query_ctx.clone()) - .await - .unwrap(); - instance - .inner() - .execute_sql(insert, query_ctx.clone()) - .await - .unwrap(); + instance.execute_sql(create).await; - let query_output = instance - .inner() - .execute_sql(tql, query_ctx.clone()) - .await - .unwrap(); + instance.execute_sql(insert).await; + + let query_output = instance.execute_sql(tql).await; let expected = String::from(expected); check_unordered_output_stream(query_output, expected).await; } diff --git a/src/datanode/src/tests/test_util.rs b/src/datanode/src/tests/test_util.rs index 28591feaf1..da5f0aea9f 100644 --- a/src/datanode/src/tests/test_util.rs +++ b/src/datanode/src/tests/test_util.rs @@ -22,9 +22,13 @@ use datatypes::data_type::ConcreteDataType; use datatypes::schema::{ColumnSchema, RawSchema}; use mito::config::EngineConfig; use mito::table::test_util::{new_test_object_store, MockEngine, MockMitoEngine}; +use query::parser::{PromQuery, QueryLanguageParser, QueryStatement}; use query::QueryEngineFactory; use servers::Mode; +use session::context::QueryContext; use snafu::ResultExt; +use sql::statements::statement::Statement; +use sql::statements::tql::Tql; use table::engine::{EngineContext, TableEngineRef}; use table::requests::{CreateTableRequest, TableOptions}; @@ -72,6 +76,40 @@ impl MockInstance { } } + pub(crate) async fn execute_sql(&self, sql: &str) -> Output { + let engine = self.inner().query_engine(); + let planner = engine.planner(); + + let stmt = QueryLanguageParser::parse_sql(sql).unwrap(); + match stmt { + QueryStatement::Sql(Statement::Query(_)) => { + let plan = planner.plan(stmt, QueryContext::arc()).await.unwrap(); + engine.execute(&plan).await.unwrap() + } + QueryStatement::Sql(Statement::Tql(tql)) => { + let plan = match tql { + Tql::Eval(eval) => { + let promql = PromQuery { + start: eval.start, + end: eval.end, + step: eval.step, + query: eval.query, + }; + let stmt = QueryLanguageParser::parse_promql(&promql).unwrap(); + planner.plan(stmt, QueryContext::arc()).await.unwrap() + } + Tql::Explain(_) => unimplemented!(), + }; + engine.execute(&plan).await.unwrap() + } + _ => self + .inner() + .execute_stmt(stmt, QueryContext::arc()) + .await + .unwrap(), + } + } + pub(crate) fn inner(&self) -> &Instance { &self.instance } diff --git a/src/frontend/Cargo.toml b/src/frontend/Cargo.toml index 401b84c9c2..7a57004f85 100644 --- a/src/frontend/Cargo.toml +++ b/src/frontend/Cargo.toml @@ -14,6 +14,7 @@ client = { path = "../client" } common-base = { path = "../common/base" } common-catalog = { path = "../common/catalog" } common-error = { path = "../common/error" } +common-function = { path = "../common/function" } common-grpc = { path = "../common/grpc" } common-grpc-expr = { path = "../common/grpc-expr" } common-query = { path = "../common/query" } diff --git a/src/frontend/src/error.rs b/src/frontend/src/error.rs index 99a6b1be10..15eaccd965 100644 --- a/src/frontend/src/error.rs +++ b/src/frontend/src/error.rs @@ -247,6 +247,24 @@ pub enum Error { source: query::error::Error, }, + #[snafu(display("Failed to plan statement, source: {}", source))] + PlanStatement { + #[snafu(backtrace)] + source: query::error::Error, + }, + + #[snafu(display("Failed to parse query, source: {}", source))] + ParseQuery { + #[snafu(backtrace)] + source: query::error::Error, + }, + + #[snafu(display("Failed to execute logical plan, source: {}", source))] + ExecLogicalPlan { + #[snafu(backtrace)] + source: query::error::Error, + }, + 
#[snafu(display("Failed to build DataFusion logical plan, source: {}", source))] BuildDfLogicalPlan { source: datafusion_common::DataFusionError, @@ -426,9 +444,12 @@ impl ErrorExt for Error { | Error::ToTableInsertRequest { source } | Error::FindNewColumnsOnInsertion { source } => source.status_code(), - Error::ExecuteStatement { source, .. } | Error::DescribeStatement { source } => { - source.status_code() - } + Error::ExecuteStatement { source, .. } + | Error::PlanStatement { source } + | Error::ParseQuery { source } + | Error::ExecLogicalPlan { source } + | Error::DescribeStatement { source } => source.status_code(), + Error::AlterExprToRequest { source, .. } => source.status_code(), Error::LeaderNotFound { .. } => StatusCode::StorageUnavailable, Error::TableAlreadyExist { .. } => StatusCode::TableAlreadyExists, diff --git a/src/frontend/src/instance.rs b/src/frontend/src/instance.rs index 6dd95414ce..2ccb4fa48a 100644 --- a/src/frontend/src/instance.rs +++ b/src/frontend/src/instance.rs @@ -36,22 +36,26 @@ use common_grpc::channel_manager::{ChannelConfig, ChannelManager}; use common_query::Output; use common_recordbatch::RecordBatches; use common_telemetry::logging::{debug, info}; +use common_telemetry::timer; use datafusion::sql::sqlparser::ast::ObjectName; use datanode::instance::sql::table_idents_to_full_name; use datanode::instance::InstanceRef as DnInstanceRef; +use datanode::metric; use datatypes::schema::Schema; use distributed::DistInstance; use meta_client::client::{MetaClient, MetaClientBuilder}; use meta_client::MetaClientOptions; use partition::manager::PartitionRuleManager; use partition::route::TableRoutes; -use query::parser::PromQuery; +use query::parser::{PromQuery, QueryLanguageParser, QueryStatement}; use query::query_engine::options::{validate_catalog_and_schema, QueryOptions}; +use query::query_engine::StatementHandlerRef; +use query::{QueryEngineFactory, QueryEngineRef}; use servers::error as server_error; use servers::interceptor::{SqlQueryInterceptor, SqlQueryInterceptorRef}; use servers::prom::{PromHandler, PromHandlerRef}; use servers::query_handler::grpc::{GrpcQueryHandler, GrpcQueryHandlerRef}; -use servers::query_handler::sql::{SqlQueryHandler, SqlQueryHandlerRef}; +use servers::query_handler::sql::SqlQueryHandler; use servers::query_handler::{ InfluxdbLineProtocolHandler, OpentsdbProtocolHandler, PrometheusProtocolHandler, ScriptHandler, ScriptHandlerRef, @@ -62,16 +66,18 @@ use sql::dialect::GenericDialect; use sql::parser::ParserContext; use sql::statements::copy::CopyTable; use sql::statements::statement::Statement; +use sql::statements::tql::Tql; use crate::catalog::FrontendCatalogManager; use crate::datanode::DatanodeClients; use crate::error::{ - self, Error, ExecutePromqlSnafu, ExternalSnafu, InvalidInsertRequestSnafu, - MissingMetasrvOptsSnafu, NotSupportedSnafu, ParseSqlSnafu, Result, SqlExecInterceptedSnafu, + self, Error, ExecLogicalPlanSnafu, ExecutePromqlSnafu, ExecuteStatementSnafu, ExternalSnafu, + InvalidInsertRequestSnafu, MissingMetasrvOptsSnafu, NotSupportedSnafu, ParseQuerySnafu, + ParseSqlSnafu, PlanStatementSnafu, Result, SqlExecInterceptedSnafu, }; use crate::expr_factory::{CreateExprFactoryRef, DefaultCreateExprFactory}; use crate::frontend::FrontendOptions; -use crate::instance::standalone::{StandaloneGrpcQueryHandler, StandaloneSqlQueryHandler}; +use crate::instance::standalone::StandaloneGrpcQueryHandler; use crate::server::{start_server, ServerHandlers, Services}; #[async_trait] @@ -98,7 +104,8 @@ pub struct Instance { /// 
Script handler is None in distributed mode, only works on standalone mode. script_handler: Option, - sql_handler: SqlQueryHandlerRef, + statement_handler: StatementHandlerRef, + query_engine: QueryEngineRef, grpc_query_handler: GrpcQueryHandlerRef, promql_handler: Option, @@ -131,19 +138,20 @@ impl Instance { datanode_clients.clone(), )); - let dist_instance = DistInstance::new( - meta_client, - catalog_manager.clone(), - datanode_clients, - plugins.clone(), - ); + let dist_instance = + DistInstance::new(meta_client, catalog_manager.clone(), datanode_clients); let dist_instance = Arc::new(dist_instance); + let query_engine = + QueryEngineFactory::new_with_plugins(catalog_manager.clone(), plugins.clone()) + .query_engine(); + Ok(Instance { catalog_manager, script_handler: None, create_expr_factory: Arc::new(DefaultCreateExprFactory), - sql_handler: dist_instance.clone(), + statement_handler: dist_instance.clone(), + query_engine, grpc_query_handler: dist_instance, promql_handler: None, plugins: plugins.clone(), @@ -186,7 +194,8 @@ impl Instance { catalog_manager: dn_instance.catalog_manager().clone(), script_handler: None, create_expr_factory: Arc::new(DefaultCreateExprFactory), - sql_handler: StandaloneSqlQueryHandler::arc(dn_instance.clone()), + statement_handler: dn_instance.clone(), + query_engine: dn_instance.query_engine(), grpc_query_handler: StandaloneGrpcQueryHandler::arc(dn_instance.clone()), promql_handler: Some(dn_instance.clone()), plugins: Default::default(), @@ -207,11 +216,14 @@ impl Instance { #[cfg(test)] pub(crate) fn new_distributed(dist_instance: Arc) -> Self { + let catalog_manager = dist_instance.catalog_manager(); + let query_engine = QueryEngineFactory::new(catalog_manager.clone()).query_engine(); Instance { - catalog_manager: dist_instance.catalog_manager(), + catalog_manager, script_handler: None, + statement_handler: dist_instance.clone(), + query_engine, create_expr_factory: Arc::new(DefaultCreateExprFactory), - sql_handler: dist_instance.clone(), grpc_query_handler: dist_instance, promql_handler: None, plugins: Default::default(), @@ -418,20 +430,57 @@ fn parse_stmt(sql: &str) -> Result> { impl Instance { async fn query_statement(&self, stmt: Statement, query_ctx: QueryContextRef) -> Result { check_permission(self.plugins.clone(), &stmt, &query_ctx)?; + + let planner = self.query_engine.planner(); + match stmt { + Statement::Query(_) | Statement::Explain(_) => { + let plan = planner + .plan(QueryStatement::Sql(stmt), query_ctx) + .await + .context(PlanStatementSnafu)?; + self.query_engine + .execute(&plan) + .await + .context(ExecLogicalPlanSnafu) + } + Statement::Tql(tql) => { + let plan = match tql { + Tql::Eval(eval) => { + let promql = PromQuery { + start: eval.start, + end: eval.end, + step: eval.step, + query: eval.query, + }; + let stmt = + QueryLanguageParser::parse_promql(&promql).context(ParseQuerySnafu)?; + planner + .plan(stmt, query_ctx) + .await + .context(PlanStatementSnafu)? 
+ } + Tql::Explain(_) => unimplemented!(), + }; + self.query_engine + .execute(&plan) + .await + .context(ExecLogicalPlanSnafu) + } Statement::CreateDatabase(_) | Statement::ShowDatabases(_) | Statement::CreateTable(_) | Statement::ShowTables(_) | Statement::DescribeTable(_) - | Statement::Explain(_) - | Statement::Query(_) | Statement::Insert(_) | Statement::Delete(_) | Statement::Alter(_) | Statement::DropTable(_) - | Statement::Tql(_) - | Statement::Copy(_) => self.sql_handler.do_statement_query(stmt, query_ctx).await, + | Statement::Copy(_) => self + .statement_handler + .handle_statement(QueryStatement::Sql(stmt), query_ctx) + .await + .context(ExecuteStatementSnafu), Statement::Use(db) => self.handle_use(db, query_ctx), Statement::ShowCreateTable(_) => NotSupportedSnafu { feat: format!("{stmt:?}"), @@ -446,6 +495,8 @@ impl SqlQueryHandler for Instance { type Error = Error; async fn do_query(&self, query: &str, query_ctx: QueryContextRef) -> Vec> { + let _timer = timer!(metric::METRIC_HANDLE_SQL_ELAPSED); + let query_interceptor = self.plugins.get::>(); let query = match query_interceptor.pre_parsing(query, query_ctx.clone()) { Ok(q) => q, @@ -502,28 +553,26 @@ impl SqlQueryHandler for Instance { } } - async fn do_statement_query( - &self, - stmt: Statement, - query_ctx: QueryContextRef, - ) -> Result { - let query_interceptor = self.plugins.get::>(); - - // TODO(sunng87): figure out at which stage we can call - // this hook after ArrowFlight adoption. We need to provide - // LogicalPlan as to this hook. - query_interceptor.pre_execute(&stmt, None, query_ctx.clone())?; - self.query_statement(stmt, query_ctx.clone()) - .await - .and_then(|output| query_interceptor.post_execute(output, query_ctx.clone())) - } - async fn do_describe( &self, stmt: Statement, query_ctx: QueryContextRef, ) -> Result> { - self.sql_handler.do_describe(stmt, query_ctx).await + if let Statement::Query(_) = stmt { + let plan = self + .query_engine + .planner() + .plan(QueryStatement::Sql(stmt), query_ctx) + .await + .context(PlanStatementSnafu)?; + self.query_engine + .describe(plan) + .await + .map(Some) + .context(error::DescribeStatementSnafu) + } else { + Ok(None) + } } fn is_valid_schema(&self, catalog: &str, schema: &str) -> Result { @@ -1028,12 +1077,16 @@ mod tests { .collect::>(); assert_eq!(region_to_dn_map.len(), expected_distribution.len()); + let stmt = QueryLanguageParser::parse_sql("SELECT ts, host FROM demo ORDER BY ts").unwrap(); for (region, dn) in region_to_dn_map.iter() { let dn = instance.datanodes.get(dn).unwrap(); - let output = dn - .execute_sql("SELECT ts, host FROM demo ORDER BY ts", QueryContext::arc()) + let engine = dn.query_engine(); + let plan = engine + .planner() + .plan(stmt.clone(), QueryContext::arc()) .await .unwrap(); + let output = engine.execute(&plan).await.unwrap(); let Output::Stream(stream) = output else { unreachable!() }; let recordbatches = RecordBatches::try_collect(stream).await.unwrap(); let actual = recordbatches.pretty_print().unwrap(); diff --git a/src/frontend/src/instance/distributed.rs b/src/frontend/src/instance/distributed.rs index 2907ba0fef..66f98d28cd 100644 --- a/src/frontend/src/instance/distributed.rs +++ b/src/frontend/src/instance/distributed.rs @@ -27,7 +27,6 @@ use catalog::helper::{SchemaKey, SchemaValue}; use catalog::{CatalogManager, DeregisterTableRequest, RegisterTableRequest}; use chrono::DateTime; use client::Database; -use common_base::Plugins; use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; use 
common_catalog::format_full_table_name; use common_error::prelude::BoxedError; @@ -35,7 +34,7 @@ use common_query::Output; use common_telemetry::{debug, info}; use datanode::instance::sql::table_idents_to_full_name; use datatypes::prelude::ConcreteDataType; -use datatypes::schema::{RawSchema, Schema}; +use datatypes::schema::RawSchema; use meta_client::client::MetaClient; use meta_client::rpc::router::DeleteRequest as MetaDeleteRequest; use meta_client::rpc::{ @@ -43,10 +42,10 @@ use meta_client::rpc::{ RouteResponse, TableName, }; use partition::partition::{PartitionBound, PartitionDef}; -use query::parser::{PromQuery, QueryStatement}; -use query::sql::{describe_table, explain, show_databases, show_tables}; -use query::{QueryEngineFactory, QueryEngineRef}; -use servers::query_handler::sql::SqlQueryHandler; +use query::error::QueryExecutionSnafu; +use query::parser::QueryStatement; +use query::query_engine::StatementHandler; +use query::sql::{describe_table, show_databases, show_tables}; use session::context::QueryContextRef; use snafu::{ensure, OptionExt, ResultExt}; use sql::ast::Value as SqlValue; @@ -61,12 +60,12 @@ use crate::catalog::FrontendCatalogManager; use crate::datanode::DatanodeClients; use crate::error::{ self, AlterExprToRequestSnafu, CatalogEntrySerdeSnafu, CatalogSnafu, ColumnDataTypeSnafu, - DeserializePartitionSnafu, ParseSqlSnafu, PrimaryKeyNotFoundSnafu, RequestDatanodeSnafu, - RequestMetaSnafu, Result, SchemaExistsSnafu, StartMetaClientSnafu, TableAlreadyExistSnafu, - TableNotFoundSnafu, TableSnafu, ToTableInsertRequestSnafu, UnrecognizedTableOptionSnafu, + DeserializePartitionSnafu, NotSupportedSnafu, ParseSqlSnafu, PrimaryKeyNotFoundSnafu, + RequestDatanodeSnafu, RequestMetaSnafu, Result, SchemaExistsSnafu, StartMetaClientSnafu, + TableAlreadyExistSnafu, TableNotFoundSnafu, TableSnafu, ToTableInsertRequestSnafu, + UnrecognizedTableOptionSnafu, }; use crate::expr_factory; -use crate::instance::parse_stmt; use crate::sql::insert_to_request; use crate::table::DistTable; @@ -75,7 +74,6 @@ pub(crate) struct DistInstance { meta_client: Arc, catalog_manager: Arc, datanode_clients: Arc, - query_engine: QueryEngineRef, } impl DistInstance { @@ -83,16 +81,11 @@ impl DistInstance { meta_client: Arc, catalog_manager: Arc, datanode_clients: Arc, - plugins: Arc, ) -> Self { - let query_engine = - QueryEngineFactory::new_with_plugins(catalog_manager.clone(), plugins.clone()) - .query_engine(); Self { meta_client, catalog_manager, datanode_clients, - query_engine, } } @@ -272,14 +265,6 @@ impl DistInstance { query_ctx: QueryContextRef, ) -> Result { match stmt { - Statement::Query(_) => { - let plan = self - .query_engine - .statement_to_plan(QueryStatement::Sql(stmt), query_ctx) - .await - .context(error::ExecuteStatementSnafu {})?; - self.query_engine.execute(&plan).await - } Statement::CreateDatabase(stmt) => { let expr = CreateDatabaseExpr { database_name: stmt.name.to_string(), @@ -321,9 +306,6 @@ impl DistInstance { })?; describe_table(table) } - Statement::Explain(stmt) => { - explain(Box::new(stmt), self.query_engine.clone(), query_ctx).await - } Statement::Insert(insert) => { let (catalog, schema, table) = table_idents_to_full_name(insert.table_name(), query_ctx.clone()) @@ -353,29 +335,6 @@ impl DistInstance { .context(error::ExecuteStatementSnafu) } - async fn handle_sql(&self, sql: &str, query_ctx: QueryContextRef) -> Vec> { - let stmts = parse_stmt(sql); - match stmts { - Ok(stmts) => { - let mut results = Vec::with_capacity(stmts.len()); - - for stmt in stmts { - let 
result = self.handle_statement(stmt, query_ctx.clone()).await; - let is_err = result.is_err(); - - results.push(result); - - if is_err { - break; - } - } - - results - } - Err(e) => vec![Err(e)], - } - } - /// Handles distributed database creation async fn handle_create_database( &self, @@ -519,50 +478,21 @@ impl DistInstance { } #[async_trait] -impl SqlQueryHandler for DistInstance { - type Error = error::Error; - - async fn do_query(&self, query: &str, query_ctx: QueryContextRef) -> Vec> { - self.handle_sql(query, query_ctx).await - } - - async fn do_promql_query( +impl StatementHandler for DistInstance { + async fn handle_statement( &self, - _: &PromQuery, - _: QueryContextRef, - ) -> Vec> { - unimplemented!() - } - - async fn do_statement_query( - &self, - stmt: Statement, + stmt: QueryStatement, query_ctx: QueryContextRef, - ) -> Result { - self.handle_statement(stmt, query_ctx).await - } - - async fn do_describe( - &self, - stmt: Statement, - query_ctx: QueryContextRef, - ) -> Result> { - if let Statement::Query(_) = stmt { - self.query_engine - .describe(QueryStatement::Sql(stmt), query_ctx) - .await - .map(Some) - .context(error::DescribeStatementSnafu) - } else { - Ok(None) + ) -> query::error::Result { + match stmt { + QueryStatement::Sql(stmt) => self.handle_statement(stmt, query_ctx).await, + QueryStatement::Promql(_) => NotSupportedSnafu { + feat: "distributed execute promql".to_string(), + } + .fail(), } - } - - fn is_valid_schema(&self, catalog: &str, schema: &str) -> Result { - self.catalog_manager - .schema(catalog, schema) - .map(|s| s.is_some()) - .context(CatalogSnafu) + .map_err(BoxedError::new) + .context(QueryExecutionSnafu) } } @@ -721,14 +651,15 @@ fn find_partition_columns( #[cfg(test)] mod test { use itertools::Itertools; - use servers::query_handler::sql::SqlQueryHandlerRef; + use query::parser::QueryLanguageParser; + use query::query_engine::StatementHandlerRef; use session::context::QueryContext; use sql::dialect::GenericDialect; use sql::parser::ParserContext; use sql::statements::statement::Statement; use super::*; - use crate::instance::standalone::StandaloneSqlQueryHandler; + use crate::instance::parse_stmt; #[tokio::test] async fn test_parse_partitions() { @@ -771,28 +702,28 @@ ENGINE=mito", } } + async fn handle_sql(instance: &Arc, sql: &str) -> Output { + let stmt = parse_stmt(sql).unwrap().remove(0); + instance + .handle_statement(stmt, QueryContext::arc()) + .await + .unwrap() + } + #[tokio::test(flavor = "multi_thread")] async fn test_show_databases() { let instance = crate::tests::create_distributed_instance("test_show_databases").await; let dist_instance = &instance.dist_instance; let sql = "create database test_show_databases"; - let output = dist_instance - .handle_sql(sql, QueryContext::arc()) - .await - .remove(0) - .unwrap(); + let output = handle_sql(dist_instance, sql).await; match output { Output::AffectedRows(rows) => assert_eq!(rows, 1), _ => unreachable!(), } let sql = "show databases"; - let output = dist_instance - .handle_sql(sql, QueryContext::arc()) - .await - .remove(0) - .unwrap(); + let output = handle_sql(dist_instance, sql).await; match output { Output::RecordBatches(r) => { let expected1 = vec![ @@ -829,11 +760,7 @@ ENGINE=mito", let datanode_instances = instance.datanodes; let sql = "create database test_show_tables"; - dist_instance - .handle_sql(sql, QueryContext::arc()) - .await - .remove(0) - .unwrap(); + handle_sql(dist_instance, sql).await; let sql = " CREATE TABLE greptime.test_show_tables.dist_numbers ( @@ -848,18 
+775,14 @@ ENGINE=mito", PARTITION r3 VALUES LESS THAN (MAXVALUE), ) ENGINE=mito"; - dist_instance - .handle_sql(sql, QueryContext::arc()) - .await - .remove(0) - .unwrap(); + handle_sql(dist_instance, sql).await; - async fn assert_show_tables(instance: SqlQueryHandlerRef) { + async fn assert_show_tables(handler: StatementHandlerRef) { let sql = "show tables in test_show_tables"; - let output = instance - .do_query(sql, QueryContext::arc()) + let stmt = QueryLanguageParser::parse_sql(sql).unwrap(); + let output = handler + .handle_statement(stmt, QueryContext::arc()) .await - .remove(0) .unwrap(); match output { Output::RecordBatches(r) => { @@ -878,7 +801,7 @@ ENGINE=mito", // Asserts that new table is created in Datanode as well. for x in datanode_instances.values() { - assert_show_tables(StandaloneSqlQueryHandler::arc(x.clone())).await + assert_show_tables(x.clone()).await } } } diff --git a/src/frontend/src/instance/grpc.rs b/src/frontend/src/instance/grpc.rs index e4c6eee442..27d5236e6a 100644 --- a/src/frontend/src/instance/grpc.rs +++ b/src/frontend/src/instance/grpc.rs @@ -97,6 +97,7 @@ mod test { use catalog::helper::{TableGlobalKey, TableGlobalValue}; use common_query::Output; use common_recordbatch::RecordBatches; + use query::parser::QueryLanguageParser; use session::context::QueryContext; use super::*; @@ -455,14 +456,18 @@ CREATE TABLE {table_name} ( assert_eq!(region_to_dn_map.len(), expected_distribution.len()); for (region, dn) in region_to_dn_map.iter() { + let stmt = QueryLanguageParser::parse_sql(&format!( + "SELECT ts, a FROM {table_name} ORDER BY ts" + )) + .unwrap(); let dn = instance.datanodes.get(dn).unwrap(); - let output = dn - .execute_sql( - &format!("SELECT ts, a FROM {table_name} ORDER BY ts"), - QueryContext::arc(), - ) + let engine = dn.query_engine(); + let plan = engine + .planner() + .plan(stmt, QueryContext::arc()) .await .unwrap(); + let output = engine.execute(&plan).await.unwrap(); let Output::Stream(stream) = output else { unreachable!() }; let recordbatches = RecordBatches::try_collect(stream).await.unwrap(); let actual = recordbatches.pretty_print().unwrap(); diff --git a/src/frontend/src/instance/standalone.rs b/src/frontend/src/instance/standalone.rs index 042519dff6..4e54d073f9 100644 --- a/src/frontend/src/instance/standalone.rs +++ b/src/frontend/src/instance/standalone.rs @@ -18,74 +18,12 @@ use api::v1::greptime_request::Request as GreptimeRequest; use async_trait::async_trait; use common_query::Output; use datanode::error::Error as DatanodeError; -use datatypes::schema::Schema; -use query::parser::PromQuery; use servers::query_handler::grpc::{GrpcQueryHandler, GrpcQueryHandlerRef}; -use servers::query_handler::sql::{SqlQueryHandler, SqlQueryHandlerRef}; use session::context::QueryContextRef; use snafu::ResultExt; -use sql::statements::statement::Statement; use crate::error::{self, Result}; -pub(crate) struct StandaloneSqlQueryHandler(SqlQueryHandlerRef); - -impl StandaloneSqlQueryHandler { - pub(crate) fn arc(handler: SqlQueryHandlerRef) -> Arc { - Arc::new(Self(handler)) - } -} - -#[async_trait] -impl SqlQueryHandler for StandaloneSqlQueryHandler { - type Error = error::Error; - - async fn do_query(&self, query: &str, query_ctx: QueryContextRef) -> Vec> { - self.0 - .do_query(query, query_ctx) - .await - .into_iter() - .map(|x| x.context(error::InvokeDatanodeSnafu)) - .collect() - } - - async fn do_promql_query( - &self, - _: &PromQuery, - _: QueryContextRef, - ) -> Vec> { - unimplemented!() - } - - async fn do_statement_query( - &self, - 
stmt: Statement, - query_ctx: QueryContextRef, - ) -> Result { - self.0 - .do_statement_query(stmt, query_ctx) - .await - .context(error::InvokeDatanodeSnafu) - } - - async fn do_describe( - &self, - stmt: Statement, - query_ctx: QueryContextRef, - ) -> Result> { - self.0 - .do_describe(stmt, query_ctx) - .await - .context(error::InvokeDatanodeSnafu) - } - - fn is_valid_schema(&self, catalog: &str, schema: &str) -> Result { - self.0 - .is_valid_schema(catalog, schema) - .context(error::InvokeDatanodeSnafu) - } -} - pub(crate) struct StandaloneGrpcQueryHandler(GrpcQueryHandlerRef); impl StandaloneGrpcQueryHandler { diff --git a/src/frontend/src/tests.rs b/src/frontend/src/tests.rs index d9e1a8e9b5..631d04ce06 100644 --- a/src/frontend/src/tests.rs +++ b/src/frontend/src/tests.rs @@ -258,7 +258,6 @@ pub(crate) async fn create_distributed_instance(test_name: &str) -> MockDistribu meta_client.clone(), catalog_manager, datanode_clients.clone(), - Default::default(), ); let dist_instance = Arc::new(dist_instance); let frontend = Instance::new_distributed(dist_instance.clone()); diff --git a/src/query/src/datafusion.rs b/src/query/src/datafusion.rs index ab5ce5ea16..057d3a14a5 100644 --- a/src/query/src/datafusion.rs +++ b/src/query/src/datafusion.rs @@ -21,9 +21,6 @@ mod planner; use std::sync::Arc; use async_trait::async_trait; -use catalog::table_source::DfTableSourceProvider; -use catalog::CatalogListRef; -use common_base::Plugins; use common_error::prelude::BoxedError; use common_function::scalars::aggregate::AggregateFunctionMetaRef; use common_function::scalars::udf::create_udf; @@ -36,115 +33,44 @@ use common_recordbatch::{EmptyRecordBatchStream, SendableRecordBatchStream}; use common_telemetry::timer; use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion::physical_plan::ExecutionPlan; -use datafusion_sql::planner::{ParserOptions, SqlToRel}; use datatypes::schema::Schema; -use promql::planner::PromPlanner; -use promql_parser::parser::EvalStmt; -use session::context::QueryContextRef; use snafu::{OptionExt, ResultExt}; -use sql::statements::statement::Statement; pub use crate::datafusion::catalog_adapter::DfCatalogListAdapter; pub use crate::datafusion::planner::DfContextProviderAdapter; -use crate::error::{ - DataFusionSnafu, PlanSqlSnafu, QueryExecutionSnafu, QueryPlanSnafu, Result, SqlSnafu, -}; +use crate::error::{DataFusionSnafu, QueryExecutionSnafu, Result}; use crate::executor::QueryExecutor; use crate::logical_optimizer::LogicalOptimizer; -use crate::parser::QueryStatement; use crate::physical_optimizer::PhysicalOptimizer; use crate::physical_planner::PhysicalPlanner; use crate::plan::LogicalPlan; +use crate::planner::{DfLogicalPlanner, LogicalPlanner}; use crate::query_engine::{QueryEngineContext, QueryEngineState}; use crate::{metric, QueryEngine}; pub struct DatafusionQueryEngine { - state: QueryEngineState, + state: Arc, } impl DatafusionQueryEngine { - pub fn new(catalog_list: CatalogListRef, plugins: Arc) -> Self { - Self { - state: QueryEngineState::new(catalog_list.clone(), plugins), - } - } - - async fn plan_sql_stmt( - &self, - stmt: Statement, - query_ctx: QueryContextRef, - ) -> Result { - let session_state = self.state.session_state(); - - let df_stmt = (&stmt).try_into().context(SqlSnafu)?; - - let config_options = session_state.config().config_options(); - let parser_options = ParserOptions { - enable_ident_normalization: config_options.sql_parser.enable_ident_normalization, - parse_float_as_decimal: 
config_options.sql_parser.parse_float_as_decimal, - }; - - let context_provider = DfContextProviderAdapter::try_new( - self.state.clone(), - session_state, - &df_stmt, - query_ctx, - ) - .await?; - let sql_to_rel = SqlToRel::new_with_options(&context_provider, parser_options); - - let result = sql_to_rel.statement_to_plan(df_stmt).with_context(|_| { - let sql = if let Statement::Query(query) = stmt { - query.inner.to_string() - } else { - format!("{stmt:?}") - }; - PlanSqlSnafu { sql } - })?; - Ok(LogicalPlan::DfPlan(result)) - } - - // TODO(ruihang): test this method once parser is ready. - async fn plan_promql_stmt( - &self, - stmt: EvalStmt, - query_ctx: QueryContextRef, - ) -> Result<LogicalPlan> { - let table_provider = DfTableSourceProvider::new( - self.state.catalog_list().clone(), - self.state.disallow_cross_schema_query(), - query_ctx.as_ref(), - ); - PromPlanner::stmt_to_plan(table_provider, stmt) - .await - .map(LogicalPlan::DfPlan) - .map_err(BoxedError::new) - .context(QueryPlanSnafu) + pub fn new(state: Arc<QueryEngineState>) -> Self { + Self { state } + } } -// TODO(LFC): Refactor consideration: extract a "Planner" that stores query context and execute queries inside. #[async_trait] impl QueryEngine for DatafusionQueryEngine { + fn planner(&self) -> Arc<dyn LogicalPlanner> { + Arc::new(DfLogicalPlanner::new(self.state.clone())) + } + fn name(&self) -> &str { "datafusion" } - async fn statement_to_plan( - &self, - stmt: QueryStatement, - query_ctx: QueryContextRef, - ) -> Result<LogicalPlan> { - match stmt { - QueryStatement::Sql(stmt) => self.plan_sql_stmt(stmt, query_ctx).await, - QueryStatement::Promql(stmt) => self.plan_promql_stmt(stmt, query_ctx).await, - } - } - - async fn describe(&self, stmt: QueryStatement, query_ctx: QueryContextRef) -> Result<Schema> { + async fn describe(&self, plan: LogicalPlan) -> Result<Schema> { // TODO(sunng87): consider caching the optimised logical plan between describe // and execute - let plan = self.statement_to_plan(stmt, query_ctx).await?; let optimised_plan = self.optimize(&plan)?; optimised_plan.schema() } @@ -159,11 +85,6 @@ impl QueryEngine for DatafusionQueryEngine { Ok(Output::Stream(self.execute_stream(&ctx, &physical_plan)?)) } - async fn execute_physical(&self, plan: &Arc<dyn PhysicalPlan>) -> Result<Output> { - let ctx = QueryEngineContext::new(self.state.session_state()); - Ok(Output::Stream(self.execute_stream(&ctx, plan)?)) - } - fn register_udf(&self, udf: ScalarUdf) { self.state.register_udf(udf); } @@ -348,7 +269,8 @@ mod tests { let stmt = QueryLanguageParser::parse_sql(sql).unwrap(); let plan = engine - .statement_to_plan(stmt, Arc::new(QueryContext::new())) + .planner() + .plan(stmt, QueryContext::arc()) .await .unwrap(); @@ -369,7 +291,8 @@ mod tests { let stmt = QueryLanguageParser::parse_sql(sql).unwrap(); let plan = engine - .statement_to_plan(stmt, Arc::new(QueryContext::new())) + .planner() + .plan(stmt, Arc::new(QueryContext::new())) .await .unwrap(); @@ -406,11 +329,14 @@ mod tests { let stmt = QueryLanguageParser::parse_sql(sql).unwrap(); - let schema = engine - .describe(stmt, Arc::new(QueryContext::new())) + let plan = engine + .planner() + .plan(stmt, QueryContext::arc()) .await .unwrap(); + let schema = engine.describe(plan).await.unwrap(); + assert_eq!( schema.column_schemas()[0], ColumnSchema::new( diff --git a/src/query/src/datafusion/planner.rs b/src/query/src/datafusion/planner.rs index b3bcda49dc..c0a4f6a635 100644 --- a/src/query/src/datafusion/planner.rs +++ b/src/query/src/datafusion/planner.rs @@ -37,7 +37,7 @@ use crate::error::{CatalogSnafu, DataFusionSnafu, Result}; use 
crate::query_engine::QueryEngineState; pub struct DfContextProviderAdapter { - engine_state: QueryEngineState, + engine_state: Arc<QueryEngineState>, session_state: SessionState, tables: HashMap>, table_provider: DfTableSourceProvider, @@ -45,7 +45,7 @@ pub struct DfContextProviderAdapter { impl DfContextProviderAdapter { pub(crate) async fn try_new( - engine_state: QueryEngineState, + engine_state: Arc<QueryEngineState>, session_state: SessionState, df_stmt: &DfStatement, query_ctx: QueryContextRef, diff --git a/src/query/src/planner.rs b/src/query/src/planner.rs index 8f59912aab..e1938f858d 100644 --- a/src/query/src/planner.rs +++ b/src/query/src/planner.rs @@ -12,12 +12,94 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::sync::Arc; + +use async_trait::async_trait; +use catalog::table_source::DfTableSourceProvider; +use common_error::prelude::BoxedError; +use datafusion::execution::context::SessionState; +use datafusion_sql::planner::{ParserOptions, SqlToRel}; +use promql::planner::PromPlanner; +use promql_parser::parser::EvalStmt; +use session::context::QueryContextRef; +use snafu::ResultExt; use sql::statements::statement::Statement; -use crate::error::Result; +use crate::error::{PlanSqlSnafu, QueryPlanSnafu, Result, SqlSnafu}; +use crate::parser::QueryStatement; use crate::plan::LogicalPlan; +use crate::query_engine::QueryEngineState; +use crate::DfContextProviderAdapter; -/// SQL logical planner. -pub trait Planner: Send + Sync { - fn statement_to_plan(&self, statement: Statement) -> Result<LogicalPlan>; +#[async_trait] +pub trait LogicalPlanner: Send + Sync { + async fn plan(&self, stmt: QueryStatement, query_ctx: QueryContextRef) -> Result<LogicalPlan>; +} + +pub struct DfLogicalPlanner { + engine_state: Arc<QueryEngineState>, + session_state: SessionState, +} + +impl DfLogicalPlanner { + pub fn new(engine_state: Arc<QueryEngineState>) -> Self { + let session_state = engine_state.session_state(); + Self { + engine_state, + session_state, + } + } + + async fn plan_sql(&self, stmt: Statement, query_ctx: QueryContextRef) -> Result<LogicalPlan> { + let df_stmt = (&stmt).try_into().context(SqlSnafu)?; + + let context_provider = DfContextProviderAdapter::try_new( + self.engine_state.clone(), + self.session_state.clone(), + &df_stmt, + query_ctx, + ) + .await?; + + let config_options = self.session_state.config().config_options(); + let parser_options = ParserOptions { + enable_ident_normalization: config_options.sql_parser.enable_ident_normalization, + parse_float_as_decimal: config_options.sql_parser.parse_float_as_decimal, + }; + + let sql_to_rel = SqlToRel::new_with_options(&context_provider, parser_options); + + let result = sql_to_rel.statement_to_plan(df_stmt).with_context(|_| { + let sql = if let Statement::Query(query) = stmt { + query.inner.to_string() + } else { + format!("{stmt:?}") + }; + PlanSqlSnafu { sql } + })?; + Ok(LogicalPlan::DfPlan(result)) + } + + async fn plan_pql(&self, stmt: EvalStmt, query_ctx: QueryContextRef) -> Result<LogicalPlan> { + let table_provider = DfTableSourceProvider::new( + self.engine_state.catalog_list().clone(), + self.engine_state.disallow_cross_schema_query(), + query_ctx.as_ref(), + ); + PromPlanner::stmt_to_plan(table_provider, stmt) + .await + .map(LogicalPlan::DfPlan) + .map_err(BoxedError::new) + .context(QueryPlanSnafu) + } +} + +#[async_trait] +impl LogicalPlanner for DfLogicalPlanner { + async fn plan(&self, stmt: QueryStatement, query_ctx: QueryContextRef) -> Result<LogicalPlan> { + match stmt { + QueryStatement::Sql(stmt) => self.plan_sql(stmt, query_ctx).await, + QueryStatement::Promql(stmt) => 
self.plan_pql(stmt, query_ctx).await, + } + } } diff --git a/src/query/src/query_engine.rs b/src/query/src/query_engine.rs index 83e95ab775..bf4b5766cd 100644 --- a/src/query/src/query_engine.rs +++ b/src/query/src/query_engine.rs @@ -23,7 +23,6 @@ use catalog::CatalogListRef; use common_base::Plugins; use common_function::scalars::aggregate::AggregateFunctionMetaRef; use common_function::scalars::{FunctionRef, FUNCTION_REGISTRY}; -use common_query::physical_plan::PhysicalPlan; use common_query::prelude::ScalarUdf; use common_query::Output; use datatypes::schema::Schema; @@ -33,25 +32,32 @@ use crate::datafusion::DatafusionQueryEngine; use crate::error::Result; use crate::parser::QueryStatement; use crate::plan::LogicalPlan; +use crate::planner::LogicalPlanner; pub use crate::query_engine::context::QueryEngineContext; pub use crate::query_engine::state::QueryEngineState; -#[async_trait] -pub trait QueryEngine: Send + Sync { - fn name(&self) -> &str; +pub type StatementHandlerRef = Arc<dyn StatementHandler>; - async fn statement_to_plan( +// TODO(LFC): Gradually make more statements execute in the form of logical plans, then remove this trait. Tracked in #1010. +#[async_trait] +pub trait StatementHandler: Send + Sync { + async fn handle_statement( &self, stmt: QueryStatement, query_ctx: QueryContextRef, - ) -> Result<LogicalPlan>; + ) -> Result<Output>; +} - async fn describe(&self, stmt: QueryStatement, query_ctx: QueryContextRef) -> Result<Schema>; +#[async_trait] +pub trait QueryEngine: Send + Sync { + fn planner(&self) -> Arc<dyn LogicalPlanner>; + + fn name(&self) -> &str; + + async fn describe(&self, plan: LogicalPlan) -> Result<Schema>; async fn execute(&self, plan: &LogicalPlan) -> Result<Output>; - async fn execute_physical(&self, plan: &Arc<dyn PhysicalPlan>) -> Result<Output>; - fn register_udf(&self, udf: ScalarUdf); fn register_aggregate_function(&self, func: AggregateFunctionMetaRef); @@ -65,13 +71,12 @@ pub struct QueryEngineFactory { impl QueryEngineFactory { pub fn new(catalog_list: CatalogListRef) -> Self { - let query_engine = Arc::new(DatafusionQueryEngine::new(catalog_list, Default::default())); - register_functions(&query_engine); - Self { query_engine } + Self::new_with_plugins(catalog_list, Default::default()) } pub fn new_with_plugins(catalog_list: CatalogListRef, plugins: Arc<Plugins>) -> Self { - let query_engine = Arc::new(DatafusionQueryEngine::new(catalog_list, plugins)); + let state = Arc::new(QueryEngineState::new(catalog_list, plugins)); + let query_engine = Arc::new(DatafusionQueryEngine::new(state)); register_functions(&query_engine); Self { query_engine } } diff --git a/src/query/src/sql.rs b/src/query/src/sql.rs index fcf21f6e4a..f83c03a664 100644 --- a/src/query/src/sql.rs +++ b/src/query/src/sql.rs @@ -24,14 +24,10 @@ use datatypes::vectors::{Helper, StringVector}; use once_cell::sync::Lazy; use session::context::QueryContextRef; use snafu::{ensure, OptionExt, ResultExt}; -use sql::statements::explain::Explain; use sql::statements::show::{ShowDatabases, ShowKind, ShowTables}; -use sql::statements::statement::Statement; use table::TableRef; use crate::error::{self, Result}; -use crate::parser::QueryStatement; -use crate::QueryEngineRef; const SCHEMAS_COLUMN: &str = "Schemas"; const TABLES_COLUMN: &str = "Tables"; @@ -156,17 +152,6 @@ pub fn show_tables( Ok(Output::RecordBatches(records)) } -pub async fn explain( - stmt: Box<Explain>, - query_engine: QueryEngineRef, - query_ctx: QueryContextRef, -) -> Result<Output> { - let plan = query_engine - .statement_to_plan(QueryStatement::Sql(Statement::Explain(*stmt)), query_ctx) - .await?; - query_engine.execute(&plan).await -} - pub fn 
describe_table(table: TableRef) -> Result<Output> { let table_info = table.table_info(); let columns_schemas = table_info.meta.schema.column_schemas(); diff --git a/src/query/src/tests.rs b/src/query/src/tests.rs index a95678a5fe..44844fc962 100644 --- a/src/query/src/tests.rs +++ b/src/query/src/tests.rs @@ -12,6 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. +use common_query::Output; +use common_recordbatch::{util, RecordBatch}; +use session::context::QueryContext; + +use crate::parser::QueryLanguageParser; +use crate::QueryEngineRef; + mod argmax_test; mod argmin_test; mod mean_test; @@ -25,3 +32,17 @@ mod time_range_filter_test; mod function; mod pow; + +async fn exec_selection(engine: QueryEngineRef, sql: &str) -> Vec<RecordBatch> { + let stmt = QueryLanguageParser::parse_sql(sql).unwrap(); + let plan = engine + .planner() + .plan(stmt, QueryContext::arc()) + .await + .unwrap(); + let Output::Stream(stream) = engine + .execute(&plan) + .await + .unwrap() else { unreachable!() }; + util::collect(stream).await.unwrap() +} diff --git a/src/query/src/tests/argmax_test.rs b/src/query/src/tests/argmax_test.rs index c65f972d66..9f9c86e8e7 100644 --- a/src/query/src/tests/argmax_test.rs +++ b/src/query/src/tests/argmax_test.rs @@ -14,17 +14,12 @@ use std::sync::Arc; -use common_query::Output; -use common_recordbatch::error::Result as RecordResult; -use common_recordbatch::{util, RecordBatch}; use datatypes::for_all_primitive_types; use datatypes::prelude::*; use datatypes::types::WrapperType; -use session::context::QueryContext; use crate::error::Result; -use crate::parser::QueryLanguageParser; -use crate::tests::function; +use crate::tests::{exec_selection, function}; use crate::QueryEngine; #[tokio::test] @@ -52,9 +47,8 @@ async fn test_argmax_success( where T: WrapperType + PartialOrd, { - let result = execute_argmax(column_name, table_name, engine.clone()) - .await - .unwrap(); + let sql = format!("select ARGMAX({column_name}) as argmax from {table_name}"); + let result = exec_selection(engine.clone(), &sql).await; let value = function::get_value_from_batches("argmax", result); let numbers = @@ -77,23 +71,3 @@ where assert_eq!(value, expected_value); Ok(()) } - -async fn execute_argmax<'a>( - column_name: &'a str, - table_name: &'a str, - engine: Arc<dyn QueryEngine>, -) -> RecordResult<Vec<RecordBatch>> { - let sql = format!("select ARGMAX({column_name}) as argmax from {table_name}"); - let stmt = QueryLanguageParser::parse_sql(&sql).unwrap(); - let plan = engine - .statement_to_plan(stmt, Arc::new(QueryContext::new())) - .await - .unwrap(); - - let output = engine.execute(&plan).await.unwrap(); - let recordbatch_stream = match output { - Output::Stream(batch) => batch, - _ => unreachable!(), - }; - util::collect(recordbatch_stream).await -} diff --git a/src/query/src/tests/argmin_test.rs b/src/query/src/tests/argmin_test.rs index 171c387d31..5baa532cc6 100644 --- a/src/query/src/tests/argmin_test.rs +++ b/src/query/src/tests/argmin_test.rs @@ -14,17 +14,12 @@ use std::sync::Arc; -use common_query::Output; -use common_recordbatch::error::Result as RecordResult; -use common_recordbatch::{util, RecordBatch}; use datatypes::for_all_primitive_types; use datatypes::prelude::*; use datatypes::types::WrapperType; -use session::context::QueryContext; use crate::error::Result; -use crate::parser::QueryLanguageParser; -use crate::tests::function; +use crate::tests::{exec_selection, function}; use crate::QueryEngine; #[tokio::test] @@ -52,9 +47,8 @@ async fn test_argmin_success( where 
T: WrapperType + PartialOrd, { - let result = execute_argmin(column_name, table_name, engine.clone()) - .await - .unwrap(); + let sql = format!("select argmin({column_name}) as argmin from {table_name}"); + let result = exec_selection(engine.clone(), &sql).await; let value = function::get_value_from_batches("argmin", result); let numbers = @@ -77,23 +71,3 @@ where assert_eq!(value, expected_value); Ok(()) } - -async fn execute_argmin<'a>( - column_name: &'a str, - table_name: &'a str, - engine: Arc, -) -> RecordResult> { - let sql = format!("select argmin({column_name}) as argmin from {table_name}"); - let stmt = QueryLanguageParser::parse_sql(&sql).unwrap(); - let plan = engine - .statement_to_plan(stmt, Arc::new(QueryContext::new())) - .await - .unwrap(); - - let output = engine.execute(&plan).await.unwrap(); - let recordbatch_stream = match output { - Output::Stream(batch) => batch, - _ => unreachable!(), - }; - util::collect(recordbatch_stream).await -} diff --git a/src/query/src/tests/function.rs b/src/query/src/tests/function.rs index 7560b038ef..a301bc11ad 100644 --- a/src/query/src/tests/function.rs +++ b/src/query/src/tests/function.rs @@ -17,18 +17,16 @@ use std::sync::Arc; use catalog::local::{MemoryCatalogManager, MemoryCatalogProvider, MemorySchemaProvider}; use catalog::{CatalogList, CatalogProvider, SchemaProvider}; use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; -use common_query::Output; -use common_recordbatch::{util, RecordBatch}; +use common_recordbatch::RecordBatch; use datatypes::for_all_primitive_types; use datatypes::prelude::*; use datatypes::schema::{ColumnSchema, Schema}; use datatypes::types::WrapperType; use datatypes::vectors::Helper; use rand::Rng; -use session::context::QueryContext; use table::test_util::MemTable; -use crate::parser::QueryLanguageParser; +use crate::tests::exec_selection; use crate::{QueryEngine, QueryEngineFactory}; pub fn create_query_engine() -> Arc { @@ -81,18 +79,7 @@ where T: WrapperType, { let sql = format!("SELECT {column_name} FROM {table_name}"); - let stmt = QueryLanguageParser::parse_sql(&sql).unwrap(); - let plan = engine - .statement_to_plan(stmt, Arc::new(QueryContext::new())) - .await - .unwrap(); - - let output = engine.execute(&plan).await.unwrap(); - let recordbatch_stream = match output { - Output::Stream(batch) => batch, - _ => unreachable!(), - }; - let numbers = util::collect(recordbatch_stream).await.unwrap(); + let numbers = exec_selection(engine, &sql).await; let column = numbers[0].column(0); let column: &::VectorType = unsafe { Helper::static_cast(column) }; diff --git a/src/query/src/tests/mean_test.rs b/src/query/src/tests/mean_test.rs index 2e044e6d01..5ae9a0a605 100644 --- a/src/query/src/tests/mean_test.rs +++ b/src/query/src/tests/mean_test.rs @@ -14,20 +14,15 @@ use std::sync::Arc; -use common_query::Output; -use common_recordbatch::error::Result as RecordResult; -use common_recordbatch::{util, RecordBatch}; use datatypes::for_all_primitive_types; use datatypes::prelude::*; use datatypes::types::WrapperType; use datatypes::value::OrderedFloat; use format_num::NumberFormat; use num_traits::AsPrimitive; -use session::context::QueryContext; use crate::error::Result; -use crate::parser::QueryLanguageParser; -use crate::tests::function; +use crate::tests::{exec_selection, function}; use crate::QueryEngine; #[tokio::test] @@ -55,9 +50,8 @@ async fn test_mean_success( where T: WrapperType + AsPrimitive, { - let result = execute_mean(column_name, table_name, engine.clone()) - .await - 
.unwrap(); + let sql = format!("select MEAN({column_name}) as mean from {table_name}"); + let result = exec_selection(engine.clone(), &sql).await; let value = function::get_value_from_batches("mean", result); let numbers = @@ -73,23 +67,3 @@ where } Ok(()) } - -async fn execute_mean<'a>( - column_name: &'a str, - table_name: &'a str, - engine: Arc, -) -> RecordResult> { - let sql = format!("select MEAN({column_name}) as mean from {table_name}"); - let stmt = QueryLanguageParser::parse_sql(&sql).unwrap(); - let plan = engine - .statement_to_plan(stmt, Arc::new(QueryContext::new())) - .await - .unwrap(); - - let output = engine.execute(&plan).await.unwrap(); - let recordbatch_stream = match output { - Output::Stream(batch) => batch, - _ => unreachable!(), - }; - util::collect(recordbatch_stream).await -} diff --git a/src/query/src/tests/my_sum_udaf_example.rs b/src/query/src/tests/my_sum_udaf_example.rs index 1975a12b0e..27e4981cc0 100644 --- a/src/query/src/tests/my_sum_udaf_example.rs +++ b/src/query/src/tests/my_sum_udaf_example.rs @@ -24,19 +24,17 @@ use common_function_macro::{as_aggr_func_creator, AggrFuncTypeStore}; use common_query::error::{CreateAccumulatorSnafu, Result as QueryResult}; use common_query::logical_plan::{Accumulator, AggregateFunctionCreator}; use common_query::prelude::*; -use common_query::Output; -use common_recordbatch::{util, RecordBatch}; +use common_recordbatch::{RecordBatch, RecordBatches}; use datatypes::prelude::*; use datatypes::schema::{ColumnSchema, Schema}; use datatypes::types::{LogicalPrimitiveType, WrapperType}; use datatypes::vectors::Helper; use datatypes::with_match_primitive_type_id; use num_traits::AsPrimitive; -use session::context::QueryContext; use table::test_util::MemTable; use crate::error::Result; -use crate::parser::QueryLanguageParser; +use crate::tests::exec_selection; use crate::QueryEngineFactory; #[derive(Debug, Default)] @@ -220,18 +218,8 @@ where ))); let sql = format!("select MY_SUM({column_name}) as my_sum from {table_name}"); - let stmt = QueryLanguageParser::parse_sql(&sql).unwrap(); - let plan = engine - .statement_to_plan(stmt, Arc::new(QueryContext::new())) - .await - .unwrap(); - - let output = engine.execute(&plan).await?; - let recordbatch_stream = match output { - Output::Stream(batch) => batch, - _ => unreachable!(), - }; - let batches = util::collect_batches(recordbatch_stream).await.unwrap(); + let batches = exec_selection(engine, &sql).await; + let batches = RecordBatches::try_new(batches.first().unwrap().schema.clone(), batches).unwrap(); let pretty_print = batches.pretty_print().unwrap(); assert_eq!(expected, pretty_print); diff --git a/src/query/src/tests/percentile_test.rs b/src/query/src/tests/percentile_test.rs index b0aecc3e8c..eefb825d75 100644 --- a/src/query/src/tests/percentile_test.rs +++ b/src/query/src/tests/percentile_test.rs @@ -17,21 +17,17 @@ use std::sync::Arc; use catalog::local::{MemoryCatalogManager, MemoryCatalogProvider, MemorySchemaProvider}; use catalog::{CatalogList, CatalogProvider, SchemaProvider}; use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; -use common_query::Output; -use common_recordbatch::error::Result as RecordResult; -use common_recordbatch::{util, RecordBatch}; +use common_recordbatch::RecordBatch; use datatypes::for_all_primitive_types; use datatypes::prelude::*; use datatypes::schema::{ColumnSchema, Schema}; use datatypes::vectors::Int32Vector; use function::{create_query_engine, get_numbers_from_table}; use num_traits::AsPrimitive; -use 
session::context::QueryContext; use table::test_util::MemTable; use crate::error::Result; -use crate::parser::QueryLanguageParser; -use crate::tests::function; +use crate::tests::{exec_selection, function}; use crate::{QueryEngine, QueryEngineFactory}; #[tokio::test] @@ -55,18 +51,7 @@ async fn test_percentile_aggregator() -> Result<()> { async fn test_percentile_correctness() -> Result<()> { let engine = create_correctness_engine(); let sql = String::from("select PERCENTILE(corr_number,88.0) as percentile from corr_numbers"); - let stmt = QueryLanguageParser::parse_sql(&sql).unwrap(); - let plan = engine - .statement_to_plan(stmt, Arc::new(QueryContext::new())) - .await - .unwrap(); - - let output = engine.execute(&plan).await.unwrap(); - let recordbatch_stream = match output { - Output::Stream(batch) => batch, - _ => unreachable!(), - }; - let record_batch = util::collect(recordbatch_stream).await.unwrap(); + let record_batch = exec_selection(engine, &sql).await; let column = record_batch[0].column(0); let value = column.get(0); assert_eq!(value, Value::from(9.280_000_000_000_001_f64)); @@ -81,9 +66,8 @@ async fn test_percentile_success( where T: WrapperType + AsPrimitive, { - let result = execute_percentile(column_name, table_name, engine.clone()) - .await - .unwrap(); + let sql = format!("select PERCENTILE({column_name},50.0) as percentile from {table_name}"); + let result = exec_selection(engine.clone(), &sql).await; let value = function::get_value_from_batches("percentile", result); let numbers = get_numbers_from_table::(column_name, table_name, engine.clone()).await; @@ -95,26 +79,6 @@ where Ok(()) } -async fn execute_percentile<'a>( - column_name: &'a str, - table_name: &'a str, - engine: Arc, -) -> RecordResult> { - let sql = format!("select PERCENTILE({column_name},50.0) as percentile from {table_name}"); - let stmt = QueryLanguageParser::parse_sql(&sql).unwrap(); - let plan = engine - .statement_to_plan(stmt, Arc::new(QueryContext::new())) - .await - .unwrap(); - - let output = engine.execute(&plan).await.unwrap(); - let recordbatch_stream = match output { - Output::Stream(batch) => batch, - _ => unreachable!(), - }; - util::collect(recordbatch_stream).await -} - fn create_correctness_engine() -> Arc { // create engine let schema_provider = Arc::new(MemorySchemaProvider::new()); diff --git a/src/query/src/tests/polyval_test.rs b/src/query/src/tests/polyval_test.rs index f2f4834edd..5e0f44d559 100644 --- a/src/query/src/tests/polyval_test.rs +++ b/src/query/src/tests/polyval_test.rs @@ -14,18 +14,13 @@ use std::sync::Arc; -use common_query::Output; -use common_recordbatch::error::Result as RecordResult; -use common_recordbatch::{util, RecordBatch}; use datatypes::for_all_primitive_types; use datatypes::prelude::*; use datatypes::types::WrapperType; use num_traits::AsPrimitive; -use session::context::QueryContext; use crate::error::Result; -use crate::parser::QueryLanguageParser; -use crate::tests::function; +use crate::tests::{exec_selection, function}; use crate::QueryEngine; #[tokio::test] @@ -57,9 +52,8 @@ where PolyT::Native: std::ops::Mul + std::iter::Sum, i64: AsPrimitive, { - let result = execute_polyval(column_name, table_name, engine.clone()) - .await - .unwrap(); + let sql = format!("select POLYVAL({column_name}, 0) as polyval from {table_name}"); + let result = exec_selection(engine.clone(), &sql).await; let value = function::get_value_from_batches("polyval", result); let numbers = @@ -74,23 +68,3 @@ where assert_eq!(value, PolyT::from_native(expected_native).into()); 
Ok(()) } - -async fn execute_polyval<'a>( - column_name: &'a str, - table_name: &'a str, - engine: Arc, -) -> RecordResult> { - let sql = format!("select POLYVAL({column_name}, 0) as polyval from {table_name}"); - let stmt = QueryLanguageParser::parse_sql(&sql).unwrap(); - let plan = engine - .statement_to_plan(stmt, Arc::new(QueryContext::new())) - .await - .unwrap(); - - let output = engine.execute(&plan).await.unwrap(); - let recordbatch_stream = match output { - Output::Stream(batch) => batch, - _ => unreachable!(), - }; - util::collect(recordbatch_stream).await -} diff --git a/src/query/src/tests/query_engine_test.rs b/src/query/src/tests/query_engine_test.rs index f4285f4264..f0fb8ee6f6 100644 --- a/src/query/src/tests/query_engine_test.rs +++ b/src/query/src/tests/query_engine_test.rs @@ -38,6 +38,7 @@ use crate::parser::QueryLanguageParser; use crate::plan::LogicalPlan; use crate::query_engine::options::QueryOptions; use crate::query_engine::QueryEngineFactory; +use crate::tests::exec_selection; use crate::tests::pow::pow; #[tokio::test] @@ -138,13 +139,15 @@ async fn test_query_validate() -> Result<()> { let stmt = QueryLanguageParser::parse_sql("select number from public.numbers").unwrap(); assert!(engine - .statement_to_plan(stmt, QueryContext::arc()) + .planner() + .plan(stmt, QueryContext::arc()) .await .is_ok()); let stmt = QueryLanguageParser::parse_sql("select number from wrongschema.numbers").unwrap(); assert!(engine - .statement_to_plan(stmt, QueryContext::arc()) + .planner() + .plan(stmt, QueryContext::arc()) .await .is_err()); Ok(()) @@ -174,21 +177,8 @@ async fn test_udf() -> Result<()> { engine.register_udf(udf); - let stmt = - QueryLanguageParser::parse_sql("select my_pow(number, number) as p from numbers limit 10") - .unwrap(); - let plan = engine - .statement_to_plan(stmt, Arc::new(QueryContext::new())) - .await - .unwrap(); - - let output = engine.execute(&plan).await?; - let recordbatch = match output { - Output::Stream(recordbatch) => recordbatch, - _ => unreachable!(), - }; - - let numbers = util::collect(recordbatch).await.unwrap(); + let sql = "select my_pow(number, number) as p from numbers limit 10"; + let numbers = exec_selection(engine, sql).await; assert_eq!(1, numbers.len()); assert_eq!(numbers[0].num_columns(), 1); assert_eq!(1, numbers[0].schema.num_columns()); diff --git a/src/query/src/tests/scipy_stats_norm_cdf_test.rs b/src/query/src/tests/scipy_stats_norm_cdf_test.rs index 21e3cdf96e..de4015c0b7 100644 --- a/src/query/src/tests/scipy_stats_norm_cdf_test.rs +++ b/src/query/src/tests/scipy_stats_norm_cdf_test.rs @@ -14,19 +14,14 @@ use std::sync::Arc; -use common_query::Output; -use common_recordbatch::error::Result as RecordResult; -use common_recordbatch::{util, RecordBatch}; use datatypes::for_all_primitive_types; use datatypes::types::WrapperType; use num_traits::AsPrimitive; -use session::context::QueryContext; use statrs::distribution::{ContinuousCDF, Normal}; use statrs::statistics::Statistics; use crate::error::Result; -use crate::parser::QueryLanguageParser; -use crate::tests::function; +use crate::tests::{exec_selection, function}; use crate::QueryEngine; #[tokio::test] @@ -54,9 +49,10 @@ async fn test_scipy_stats_norm_cdf_success( where T: WrapperType + AsPrimitive, { - let result = execute_scipy_stats_norm_cdf(column_name, table_name, engine.clone()) - .await - .unwrap(); + let sql = format!( + "select SCIPYSTATSNORMCDF({column_name},2.0) as scipy_stats_norm_cdf from {table_name}", + ); + let result = exec_selection(engine.clone(), 
&sql).await; let value = function::get_value_from_batches("scipy_stats_norm_cdf", result); let numbers = @@ -71,25 +67,3 @@ where assert_eq!(value, expected_value.into()); Ok(()) } - -async fn execute_scipy_stats_norm_cdf<'a>( - column_name: &'a str, - table_name: &'a str, - engine: Arc, -) -> RecordResult> { - let sql = format!( - "select SCIPYSTATSNORMCDF({column_name},2.0) as scipy_stats_norm_cdf from {table_name}", - ); - let stmt = QueryLanguageParser::parse_sql(&sql).unwrap(); - let plan = engine - .statement_to_plan(stmt, Arc::new(QueryContext::new())) - .await - .unwrap(); - - let output = engine.execute(&plan).await.unwrap(); - let recordbatch_stream = match output { - Output::Stream(batch) => batch, - _ => unreachable!(), - }; - util::collect(recordbatch_stream).await -} diff --git a/src/query/src/tests/scipy_stats_norm_pdf.rs b/src/query/src/tests/scipy_stats_norm_pdf.rs index 21b2b04798..85e0cd7771 100644 --- a/src/query/src/tests/scipy_stats_norm_pdf.rs +++ b/src/query/src/tests/scipy_stats_norm_pdf.rs @@ -14,19 +14,14 @@ use std::sync::Arc; -use common_query::Output; -use common_recordbatch::error::Result as RecordResult; -use common_recordbatch::{util, RecordBatch}; use datatypes::for_all_primitive_types; use datatypes::types::WrapperType; use num_traits::AsPrimitive; -use session::context::QueryContext; use statrs::distribution::{Continuous, Normal}; use statrs::statistics::Statistics; use crate::error::Result; -use crate::parser::QueryLanguageParser; -use crate::tests::function; +use crate::tests::{exec_selection, function}; use crate::QueryEngine; #[tokio::test] @@ -54,9 +49,10 @@ async fn test_scipy_stats_norm_pdf_success( where T: WrapperType + AsPrimitive, { - let result = execute_scipy_stats_norm_pdf(column_name, table_name, engine.clone()) - .await - .unwrap(); + let sql = format!( + "select SCIPYSTATSNORMPDF({column_name},2.0) as scipy_stats_norm_pdf from {table_name}" + ); + let result = exec_selection(engine.clone(), &sql).await; let value = function::get_value_from_batches("scipy_stats_norm_pdf", result); let numbers = @@ -71,25 +67,3 @@ where assert_eq!(value, expected_value.into()); Ok(()) } - -async fn execute_scipy_stats_norm_pdf<'a>( - column_name: &'a str, - table_name: &'a str, - engine: Arc, -) -> RecordResult> { - let sql = format!( - "select SCIPYSTATSNORMPDF({column_name},2.0) as scipy_stats_norm_pdf from {table_name}" - ); - let stmt = QueryLanguageParser::parse_sql(&sql).unwrap(); - let plan = engine - .statement_to_plan(stmt, Arc::new(QueryContext::new())) - .await - .unwrap(); - - let output = engine.execute(&plan).await.unwrap(); - let recordbatch_stream = match output { - Output::Stream(batch) => batch, - _ => unreachable!(), - }; - util::collect(recordbatch_stream).await -} diff --git a/src/query/src/tests/time_range_filter_test.rs b/src/query/src/tests/time_range_filter_test.rs index 31bbb3ce2e..3a06e0f525 100644 --- a/src/query/src/tests/time_range_filter_test.rs +++ b/src/query/src/tests/time_range_filter_test.rs @@ -26,14 +26,13 @@ use common_time::Timestamp; use datatypes::data_type::ConcreteDataType; use datatypes::schema::{ColumnSchema, Schema, SchemaRef}; use datatypes::vectors::{Int64Vector, TimestampMillisecondVector}; -use session::context::QueryContext; use table::metadata::{FilterPushDownType, TableInfoRef}; use table::predicate::TimeRangePredicateBuilder; use table::test_util::MemTable; use table::Table; use tokio::sync::RwLock; -use crate::parser::QueryLanguageParser; +use crate::tests::exec_selection; use 
crate::{QueryEngineFactory, QueryEngineRef}; struct MemTableWrapper { @@ -128,18 +127,7 @@ struct TimeRangeTester { impl TimeRangeTester { async fn check(&self, sql: &str, expect: TimestampRange) { - let stmt = QueryLanguageParser::parse_sql(sql).unwrap(); - let _ = self - .engine - .execute( - &self - .engine - .statement_to_plan(stmt, Arc::new(QueryContext::new())) - .await - .unwrap(), - ) - .await - .unwrap(); + let _ = exec_selection(self.engine.clone(), sql).await; let filters = self.table.get_filters().await; let range = TimeRangePredicateBuilder::new("ts", &filters).build(); diff --git a/src/script/src/python/engine.rs b/src/script/src/python/engine.rs index 0f3a01be6b..8634647c7a 100644 --- a/src/script/src/python/engine.rs +++ b/src/script/src/python/engine.rs @@ -279,7 +279,8 @@ impl Script for PyScript { ); let plan = self .query_engine - .statement_to_plan(stmt, Arc::new(QueryContext::new())) + .planner() + .plan(stmt, QueryContext::arc()) .await?; let res = self.query_engine.execute(&plan).await?; let copr = self.copr.clone(); diff --git a/src/script/src/python/ffi_types/copr.rs b/src/script/src/python/ffi_types/copr.rs index 2ec5b4d906..55824cbb78 100644 --- a/src/script/src/python/ffi_types/copr.rs +++ b/src/script/src/python/ffi_types/copr.rs @@ -367,7 +367,8 @@ impl PyQueryEngine { let handle = rt.handle().clone(); let res = handle.block_on(async { let plan = engine - .statement_to_plan(stmt, Default::default()) + .planner() + .plan(stmt, Default::default()) .await .map_err(|e| e.to_string())?; let res = engine diff --git a/src/script/src/table.rs b/src/script/src/table.rs index ea4b2470ae..82f67fbbd4 100644 --- a/src/script/src/table.rs +++ b/src/script/src/table.rs @@ -160,7 +160,8 @@ impl ScriptsTable { let plan = self .query_engine - .statement_to_plan(stmt, Arc::new(QueryContext::new())) + .planner() + .plan(stmt, QueryContext::arc()) .await .unwrap(); diff --git a/src/servers/src/http.rs b/src/servers/src/http.rs index 219d6733a0..4bdd945dc7 100644 --- a/src/servers/src/http.rs +++ b/src/servers/src/http.rs @@ -614,14 +614,6 @@ mod test { unimplemented!() } - async fn do_statement_query( - &self, - _stmt: sql::statements::statement::Statement, - _query_ctx: QueryContextRef, - ) -> Result { - unimplemented!() - } - async fn do_describe( &self, _stmt: sql::statements::statement::Statement, diff --git a/src/servers/src/query_handler/sql.rs b/src/servers/src/query_handler/sql.rs index 59b7920f34..76259bf157 100644 --- a/src/servers/src/query_handler/sql.rs +++ b/src/servers/src/query_handler/sql.rs @@ -43,12 +43,6 @@ pub trait SqlQueryHandler { query_ctx: QueryContextRef, ) -> Vec>; - async fn do_statement_query( - &self, - stmt: Statement, - query_ctx: QueryContextRef, - ) -> std::result::Result; - // TODO(LFC): revisit this for mysql prepared statement async fn do_describe( &self, @@ -110,18 +104,6 @@ where .collect() } - async fn do_statement_query( - &self, - stmt: Statement, - query_ctx: QueryContextRef, - ) -> Result { - self.0 - .do_statement_query(stmt, query_ctx) - .await - .map_err(BoxedError::new) - .context(error::ExecuteStatementSnafu) - } - async fn do_describe( &self, stmt: Statement, diff --git a/src/servers/tests/http/influxdb_test.rs b/src/servers/tests/http/influxdb_test.rs index d56b6c5f3a..123d9e0d02 100644 --- a/src/servers/tests/http/influxdb_test.rs +++ b/src/servers/tests/http/influxdb_test.rs @@ -64,14 +64,6 @@ impl SqlQueryHandler for DummyInstance { unimplemented!() } - async fn do_statement_query( - &self, - _stmt: 
sql::statements::statement::Statement, - _query_ctx: QueryContextRef, - ) -> Result { - unimplemented!() - } - async fn do_describe( &self, _stmt: sql::statements::statement::Statement, diff --git a/src/servers/tests/http/opentsdb_test.rs b/src/servers/tests/http/opentsdb_test.rs index e9894804ec..96e40d8fd0 100644 --- a/src/servers/tests/http/opentsdb_test.rs +++ b/src/servers/tests/http/opentsdb_test.rs @@ -62,14 +62,6 @@ impl SqlQueryHandler for DummyInstance { unimplemented!() } - async fn do_statement_query( - &self, - _stmt: sql::statements::statement::Statement, - _query_ctx: QueryContextRef, - ) -> Result { - unimplemented!() - } - async fn do_describe( &self, _stmt: sql::statements::statement::Statement, diff --git a/src/servers/tests/http/prometheus_test.rs b/src/servers/tests/http/prometheus_test.rs index 7ca0913d90..2a40af43b0 100644 --- a/src/servers/tests/http/prometheus_test.rs +++ b/src/servers/tests/http/prometheus_test.rs @@ -87,14 +87,6 @@ impl SqlQueryHandler for DummyInstance { unimplemented!() } - async fn do_statement_query( - &self, - _stmt: sql::statements::statement::Statement, - _query_ctx: QueryContextRef, - ) -> Result { - unimplemented!() - } - async fn do_describe( &self, _stmt: sql::statements::statement::Statement, diff --git a/src/servers/tests/mod.rs b/src/servers/tests/mod.rs index 6d5d41996f..789b952c5b 100644 --- a/src/servers/tests/mod.rs +++ b/src/servers/tests/mod.rs @@ -71,7 +71,8 @@ impl SqlQueryHandler for DummyInstance { let stmt = QueryLanguageParser::parse_sql(query).unwrap(); let plan = self .query_engine - .statement_to_plan(stmt, query_ctx) + .planner() + .plan(stmt, query_ctx) .await .unwrap(); let output = self.query_engine.execute(&plan).await.unwrap(); @@ -86,25 +87,19 @@ impl SqlQueryHandler for DummyInstance { unimplemented!() } - async fn do_statement_query( - &self, - _stmt: Statement, - _query_ctx: QueryContextRef, - ) -> Result { - unimplemented!() - } - async fn do_describe( &self, stmt: Statement, query_ctx: QueryContextRef, ) -> Result> { if let Statement::Query(_) = stmt { - let schema = self + let plan = self .query_engine - .describe(QueryStatement::Sql(stmt), query_ctx) + .planner() + .plan(QueryStatement::Sql(stmt), query_ctx) .await .unwrap(); + let schema = self.query_engine.describe(plan).await.unwrap(); Ok(Some(schema)) } else { Ok(None) diff --git a/tests-integration/src/test_util.rs b/tests-integration/src/test_util.rs index ff2b804b00..6013010d77 100644 --- a/tests-integration/src/test_util.rs +++ b/tests-integration/src/test_util.rs @@ -257,7 +257,7 @@ pub async fn create_test_table( Ok(()) } -async fn build_frontend_instance(datanode_instance: InstanceRef) -> FeInstance { +fn build_frontend_instance(datanode_instance: InstanceRef) -> FeInstance { let mut frontend_instance = FeInstance::new_standalone(datanode_instance.clone()); frontend_instance.set_script_handler(datanode_instance); frontend_instance @@ -275,7 +275,7 @@ pub async fn setup_test_http_app(store_type: StorageType, name: &str) -> (Router .await .unwrap(); let http_server = HttpServer::new( - ServerSqlQueryHandlerAdaptor::arc(instance), + ServerSqlQueryHandlerAdaptor::arc(Arc::new(build_frontend_instance(instance))), HttpOptions::default(), ); (http_server.make_app(), guard) @@ -287,7 +287,7 @@ pub async fn setup_test_http_app_with_frontend( ) -> (Router, TestGuard) { let (opts, guard) = create_tmp_dir_and_datanode_opts(store_type, name); let instance = Arc::new(Instance::with_mock_meta_client(&opts).await.unwrap()); - let frontend = 
build_frontend_instance(instance.clone()).await; + let frontend = build_frontend_instance(instance.clone()); instance.start().await.unwrap(); create_test_table( frontend.catalog_manager(), @@ -311,7 +311,7 @@ pub async fn setup_test_prom_app_with_frontend( ) -> (Router, TestGuard) { let (opts, guard) = create_tmp_dir_and_datanode_opts(store_type, name); let instance = Arc::new(Instance::with_mock_meta_client(&opts).await.unwrap()); - let frontend = build_frontend_instance(instance.clone()).await; + let frontend = build_frontend_instance(instance.clone()); instance.start().await.unwrap(); create_test_table( frontend.catalog_manager(),