From 6127706b5bbd9e9a07171892ff6e0fd2d08cbf46 Mon Sep 17 00:00:00 2001 From: LFC Date: Thu, 1 Dec 2022 17:05:32 +0800 Subject: [PATCH] feat: support "use" stmt part 1 (#672) * feat: a bare sketch of session; support "use" in MySQL server; modify insertion and selection related codes in Datanode --- Cargo.lock | 13 + Cargo.toml | 1 + src/common/catalog/src/helper.rs | 27 +- src/datanode/Cargo.toml | 1 + src/datanode/src/instance/grpc.rs | 7 +- src/datanode/src/instance/sql.rs | 202 +++++++++-- src/datanode/src/server/grpc/ddl.rs | 20 +- src/datanode/src/sql.rs | 29 +- src/datanode/src/sql/alter.rs | 21 +- src/datanode/src/sql/create.rs | 42 ++- src/datanode/src/sql/insert.rs | 23 +- src/datanode/src/tests/instance_test.rs | 358 +++++++++---------- src/frontend/Cargo.toml | 1 + src/frontend/src/instance.rs | 121 ++++--- src/frontend/src/instance/distributed.rs | 33 +- src/frontend/src/instance/opentsdb.rs | 5 +- src/frontend/src/instance/prometheus.rs | 8 +- src/query/Cargo.toml | 1 + src/query/src/datafusion.rs | 37 +- src/query/src/datafusion/planner.rs | 26 +- src/query/src/executor.rs | 4 +- src/query/src/lib.rs | 4 +- src/query/src/logical_optimizer.rs | 4 +- src/query/src/physical_optimizer.rs | 4 +- src/query/src/physical_planner.rs | 4 +- src/query/src/query_engine.rs | 8 +- src/query/src/query_engine/context.rs | 4 +- src/query/src/sql.rs | 25 +- src/query/tests/argmax_test.rs | 5 +- src/query/tests/argmin_test.rs | 5 +- src/query/tests/function.rs | 5 +- src/query/tests/mean_test.rs | 5 +- src/query/tests/my_sum_udaf_example.rs | 3 +- src/query/tests/percentile_test.rs | 9 +- src/query/tests/polyval_test.rs | 5 +- src/query/tests/query_engine_test.rs | 14 +- src/query/tests/scipy_stats_norm_cdf_test.rs | 5 +- src/query/tests/scipy_stats_norm_pdf.rs | 5 +- src/script/Cargo.toml | 1 + src/script/src/python/engine.rs | 5 +- src/script/src/table.rs | 3 +- src/servers/Cargo.toml | 1 + src/servers/src/http.rs | 3 +- src/servers/src/http/handler.rs | 6 +- src/servers/src/mysql/federated.rs | 25 +- src/servers/src/mysql/handler.rs | 62 +++- src/servers/src/postgres/handler.rs | 6 +- src/servers/src/query_handler.rs | 3 +- src/servers/tests/http/influxdb_test.rs | 3 +- src/servers/tests/http/opentsdb_test.rs | 3 +- src/servers/tests/http/prometheus_test.rs | 3 +- src/servers/tests/mod.rs | 6 +- src/session/Cargo.toml | 9 + src/session/src/context.rs | 56 +++ src/session/src/lib.rs | 36 ++ src/sql/src/parser.rs | 15 + src/sql/src/statements.rs | 2 + src/sql/src/statements/insert.rs | 18 +- src/sql/src/statements/statement.rs | 56 +-- src/table/src/engine.rs | 21 ++ 60 files changed, 943 insertions(+), 494 deletions(-) create mode 100644 src/session/Cargo.toml create mode 100644 src/session/src/context.rs create mode 100644 src/session/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index 84c7b25a4e..d3836fbf5a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1834,6 +1834,7 @@ dependencies = [ "serde", "serde_json", "servers", + "session", "snafu", "sql", "storage", @@ -2231,6 +2232,7 @@ dependencies = [ "serde", "serde_json", "servers", + "session", "snafu", "sql", "sqlparser", @@ -4515,6 +4517,7 @@ dependencies = [ "rand 0.8.5", "serde", "serde_json", + "session", "snafu", "sql", "statrs", @@ -5366,6 +5369,7 @@ dependencies = [ "rustpython-parser", "rustpython-vm", "serde", + "session", "snafu", "sql", "storage", @@ -5535,6 +5539,7 @@ dependencies = [ "script", "serde", "serde_json", + "session", "snafu", "snap", "table", @@ -5550,6 +5555,14 @@ dependencies = [ "tower-http", ] +[[package]] 
+name = "session" +version = "0.1.0" +dependencies = [ + "arc-swap", + "common-telemetry", +] + [[package]] name = "sha-1" version = "0.10.0" diff --git a/Cargo.toml b/Cargo.toml index 678ef002ce..4256063365 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,6 +28,7 @@ members = [ "src/query", "src/script", "src/servers", + "src/session", "src/sql", "src/storage", "src/store-api", diff --git a/src/common/catalog/src/helper.rs b/src/common/catalog/src/helper.rs index 3b7f20d959..dcfa08e8a7 100644 --- a/src/common/catalog/src/helper.rs +++ b/src/common/catalog/src/helper.rs @@ -28,31 +28,42 @@ use crate::error::{ DeserializeCatalogEntryValueSnafu, Error, InvalidCatalogSnafu, SerializeCatalogEntryValueSnafu, }; +const ALPHANUMERICS_NAME_PATTERN: &str = "[a-zA-Z_][a-zA-Z0-9_]*"; + lazy_static! { - static ref CATALOG_KEY_PATTERN: Regex = - Regex::new(&format!("^{}-([a-zA-Z_]+)$", CATALOG_KEY_PREFIX)).unwrap(); + static ref CATALOG_KEY_PATTERN: Regex = Regex::new(&format!( + "^{}-({})$", + CATALOG_KEY_PREFIX, ALPHANUMERICS_NAME_PATTERN + )) + .unwrap(); } lazy_static! { static ref SCHEMA_KEY_PATTERN: Regex = Regex::new(&format!( - "^{}-([a-zA-Z_]+)-([a-zA-Z_]+)$", - SCHEMA_KEY_PREFIX + "^{}-({})-({})$", + SCHEMA_KEY_PREFIX, ALPHANUMERICS_NAME_PATTERN, ALPHANUMERICS_NAME_PATTERN )) .unwrap(); } lazy_static! { static ref TABLE_GLOBAL_KEY_PATTERN: Regex = Regex::new(&format!( - "^{}-([a-zA-Z_]+)-([a-zA-Z_]+)-([a-zA-Z0-9_]+)$", - TABLE_GLOBAL_KEY_PREFIX + "^{}-({})-({})-({})$", + TABLE_GLOBAL_KEY_PREFIX, + ALPHANUMERICS_NAME_PATTERN, + ALPHANUMERICS_NAME_PATTERN, + ALPHANUMERICS_NAME_PATTERN )) .unwrap(); } lazy_static! { static ref TABLE_REGIONAL_KEY_PATTERN: Regex = Regex::new(&format!( - "^{}-([a-zA-Z_]+)-([a-zA-Z_]+)-([a-zA-Z0-9_]+)-([0-9]+)$", - TABLE_REGIONAL_KEY_PREFIX + "^{}-({})-({})-({})-([0-9]+)$", + TABLE_REGIONAL_KEY_PREFIX, + ALPHANUMERICS_NAME_PATTERN, + ALPHANUMERICS_NAME_PATTERN, + ALPHANUMERICS_NAME_PATTERN )) .unwrap(); } diff --git a/src/datanode/Cargo.toml b/src/datanode/Cargo.toml index 5bcdca582f..c5d63273a0 100644 --- a/src/datanode/Cargo.toml +++ b/src/datanode/Cargo.toml @@ -37,6 +37,7 @@ meta-srv = { path = "../meta-srv", features = ["mock"] } metrics = "0.20" object-store = { path = "../object-store" } query = { path = "../query" } +session = { path = "../session" } script = { path = "../script", features = ["python"], optional = true } serde = "1.0" serde_json = "1.0" diff --git a/src/datanode/src/instance/grpc.rs b/src/datanode/src/instance/grpc.rs index da7d9f4136..ddc03a6436 100644 --- a/src/datanode/src/instance/grpc.rs +++ b/src/datanode/src/instance/grpc.rs @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+use std::sync::Arc; + use api::result::{build_err_result, AdminResultBuilder, ObjectResultBuilder}; use api::v1::{ admin_expr, object_expr, select_expr, AdminExpr, AdminResult, Column, CreateDatabaseExpr, @@ -26,6 +28,7 @@ use common_grpc_expr::insertion_expr_to_request; use common_query::Output; use query::plan::LogicalPlan; use servers::query_handler::{GrpcAdminHandler, GrpcQueryHandler}; +use session::context::QueryContext; use snafu::prelude::*; use substrait::{DFLogicalSubstraitConvertor, SubstraitPlan}; use table::requests::CreateDatabaseRequest; @@ -110,7 +113,9 @@ impl Instance { async fn do_handle_select(&self, select_expr: SelectExpr) -> Result { let expr = select_expr.expr; match expr { - Some(select_expr::Expr::Sql(sql)) => self.execute_sql(&sql).await, + Some(select_expr::Expr::Sql(sql)) => { + self.execute_sql(&sql, Arc::new(QueryContext::new())).await + } Some(select_expr::Expr::LogicalPlan(plan)) => self.execute_logical(plan).await, Some(select_expr::Expr::PhysicalPlan(api::v1::PhysicalPlan { original_ql, plan })) => { self.physical_planner diff --git a/src/datanode/src/instance/sql.rs b/src/datanode/src/instance/sql.rs index 5737f94a60..80149dda5c 100644 --- a/src/datanode/src/instance/sql.rs +++ b/src/datanode/src/instance/sql.rs @@ -13,25 +13,27 @@ // limitations under the License. use async_trait::async_trait; +use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; use common_error::prelude::BoxedError; use common_query::Output; +use common_recordbatch::RecordBatches; use common_telemetry::logging::{error, info}; use common_telemetry::timer; use servers::query_handler::SqlQueryHandler; +use session::context::QueryContextRef; use snafu::prelude::*; +use sql::ast::ObjectName; use sql::statements::statement::Statement; +use table::engine::TableReference; use table::requests::CreateDatabaseRequest; -use crate::error::{ - BumpTableIdSnafu, CatalogNotFoundSnafu, CatalogSnafu, ExecuteSqlSnafu, ParseSqlSnafu, Result, - SchemaNotFoundSnafu, TableIdProviderNotFoundSnafu, -}; +use crate::error::{self, BumpTableIdSnafu, ExecuteSqlSnafu, Result, TableIdProviderNotFoundSnafu}; use crate::instance::Instance; use crate::metric; use crate::sql::SqlRequest; impl Instance { - pub async fn execute_sql(&self, sql: &str) -> Result { + pub async fn execute_sql(&self, sql: &str, query_ctx: QueryContextRef) -> Result { let stmt = self .query_engine .sql_to_statement(sql) @@ -41,7 +43,7 @@ impl Instance { Statement::Query(_) => { let logical_plan = self .query_engine - .statement_to_plan(stmt) + .statement_to_plan(stmt, query_ctx) .context(ExecuteSqlSnafu)?; self.query_engine @@ -50,20 +52,15 @@ impl Instance { .context(ExecuteSqlSnafu) } Statement::Insert(i) => { - let (catalog_name, schema_name, _table_name) = - i.full_table_name().context(ParseSqlSnafu)?; - - let schema_provider = self - .catalog_manager - .catalog(&catalog_name) - .context(CatalogSnafu)? - .context(CatalogNotFoundSnafu { name: catalog_name })? - .schema(&schema_name) - .context(CatalogSnafu)? 
- .context(SchemaNotFoundSnafu { name: schema_name })?; - - let request = self.sql_handler.insert_to_request(schema_provider, *i)?; - self.sql_handler.execute(request).await + let (catalog, schema, table) = + table_idents_to_full_name(i.table_name(), query_ctx.clone())?; + let table_ref = TableReference::full(&catalog, &schema, &table); + let request = self.sql_handler.insert_to_request( + self.catalog_manager.clone(), + *i, + table_ref, + )?; + self.sql_handler.execute(request, query_ctx).await } Statement::CreateDatabase(c) => { @@ -74,7 +71,7 @@ impl Instance { info!("Creating a new database: {}", request.db_name); self.sql_handler - .execute(SqlRequest::CreateDatabase(request)) + .execute(SqlRequest::CreateDatabase(request), query_ctx) .await } @@ -89,58 +86,116 @@ impl Instance { let _engine_name = c.engine.clone(); // TODO(hl): Select table engine by engine_name - let request = self.sql_handler.create_to_request(table_id, c)?; - let catalog_name = &request.catalog_name; - let schema_name = &request.schema_name; - let table_name = &request.table_name; + let name = c.name.clone(); + let (catalog, schema, table) = table_idents_to_full_name(&name, query_ctx.clone())?; + let table_ref = TableReference::full(&catalog, &schema, &table); + let request = self.sql_handler.create_to_request(table_id, c, table_ref)?; let table_id = request.id; info!( "Creating table, catalog: {:?}, schema: {:?}, table name: {:?}, table id: {}", - catalog_name, schema_name, table_name, table_id + catalog, schema, table, table_id ); self.sql_handler - .execute(SqlRequest::CreateTable(request)) + .execute(SqlRequest::CreateTable(request), query_ctx) .await } Statement::Alter(alter_table) => { - let req = self.sql_handler.alter_to_request(alter_table)?; - self.sql_handler.execute(SqlRequest::Alter(req)).await + let name = alter_table.table_name().clone(); + let (catalog, schema, table) = table_idents_to_full_name(&name, query_ctx.clone())?; + let table_ref = TableReference::full(&catalog, &schema, &table); + let req = self.sql_handler.alter_to_request(alter_table, table_ref)?; + self.sql_handler + .execute(SqlRequest::Alter(req), query_ctx) + .await } Statement::DropTable(drop_table) => { let req = self.sql_handler.drop_table_to_request(drop_table); - self.sql_handler.execute(SqlRequest::DropTable(req)).await + self.sql_handler + .execute(SqlRequest::DropTable(req), query_ctx) + .await } Statement::ShowDatabases(stmt) => { self.sql_handler - .execute(SqlRequest::ShowDatabases(stmt)) + .execute(SqlRequest::ShowDatabases(stmt), query_ctx) .await } Statement::ShowTables(stmt) => { - self.sql_handler.execute(SqlRequest::ShowTables(stmt)).await + self.sql_handler + .execute(SqlRequest::ShowTables(stmt), query_ctx) + .await } Statement::Explain(stmt) => { self.sql_handler - .execute(SqlRequest::Explain(Box::new(stmt))) + .execute(SqlRequest::Explain(Box::new(stmt)), query_ctx) .await } Statement::DescribeTable(stmt) => { self.sql_handler - .execute(SqlRequest::DescribeTable(stmt)) + .execute(SqlRequest::DescribeTable(stmt), query_ctx) .await } Statement::ShowCreateTable(_stmt) => { unimplemented!("SHOW CREATE TABLE is unimplemented yet"); } + Statement::Use(db) => { + ensure!( + self.catalog_manager + .schema(DEFAULT_CATALOG_NAME, &db) + .context(error::CatalogSnafu)? 
+                        .is_some(),
+                    error::SchemaNotFoundSnafu { name: &db }
+                );
+
+                query_ctx.set_current_schema(&db);
+
+                Ok(Output::RecordBatches(RecordBatches::empty()))
+            }
         }
     }
 }
+
+// TODO(LFC): Refactor consideration: move this function to some helper mod,
+// could be done together or after `TableReference`'s refactoring, when issue #559 is resolved.
+/// Converts maybe fully-qualified table name (`<catalog>.<schema>.<table>`) to tuple.
+fn table_idents_to_full_name(
+    obj_name: &ObjectName,
+    query_ctx: QueryContextRef,
+) -> Result<(String, String, String)> {
+    match &obj_name.0[..] {
+        [table] => Ok((
+            DEFAULT_CATALOG_NAME.to_string(),
+            query_ctx
+                .current_schema()
+                .unwrap_or_else(|| DEFAULT_SCHEMA_NAME.to_string()),
+            table.value.clone(),
+        )),
+        [schema, table] => Ok((
+            DEFAULT_CATALOG_NAME.to_string(),
+            schema.value.clone(),
+            table.value.clone(),
+        )),
+        [catalog, schema, table] => Ok((
+            catalog.value.clone(),
+            schema.value.clone(),
+            table.value.clone(),
+        )),
+        _ => error::InvalidSqlSnafu {
+            msg: format!(
+                "expect table name to be <catalog>.<schema>.<table>, <schema>.<table> or <table>
, actual: {}", + obj_name + ), + }.fail(), + } +} + #[async_trait] impl SqlQueryHandler for Instance { - async fn do_query(&self, query: &str) -> servers::error::Result { + async fn do_query( + &self, + query: &str, + query_ctx: QueryContextRef, + ) -> servers::error::Result { let _timer = timer!(metric::METRIC_HANDLE_SQL_ELAPSED); - self.execute_sql(query) + self.execute_sql(query, query_ctx) .await .map_err(|e| { error!(e; "Instance failed to execute sql"); @@ -149,3 +204,78 @@ impl SqlQueryHandler for Instance { .context(servers::error::ExecuteQuerySnafu { query }) } } + +#[cfg(test)] +mod test { + use std::sync::Arc; + + use session::context::QueryContext; + + use super::*; + + #[test] + fn test_table_idents_to_full_name() { + let my_catalog = "my_catalog"; + let my_schema = "my_schema"; + let my_table = "my_table"; + + let full = ObjectName(vec![my_catalog.into(), my_schema.into(), my_table.into()]); + let partial = ObjectName(vec![my_schema.into(), my_table.into()]); + let bare = ObjectName(vec![my_table.into()]); + + let using_schema = "foo"; + let query_ctx = Arc::new(QueryContext::with_current_schema(using_schema.to_string())); + let empty_ctx = Arc::new(QueryContext::new()); + + assert_eq!( + table_idents_to_full_name(&full, query_ctx.clone()).unwrap(), + ( + my_catalog.to_string(), + my_schema.to_string(), + my_table.to_string() + ) + ); + assert_eq!( + table_idents_to_full_name(&full, empty_ctx.clone()).unwrap(), + ( + my_catalog.to_string(), + my_schema.to_string(), + my_table.to_string() + ) + ); + + assert_eq!( + table_idents_to_full_name(&partial, query_ctx.clone()).unwrap(), + ( + DEFAULT_CATALOG_NAME.to_string(), + my_schema.to_string(), + my_table.to_string() + ) + ); + assert_eq!( + table_idents_to_full_name(&partial, empty_ctx.clone()).unwrap(), + ( + DEFAULT_CATALOG_NAME.to_string(), + my_schema.to_string(), + my_table.to_string() + ) + ); + + assert_eq!( + table_idents_to_full_name(&bare, query_ctx).unwrap(), + ( + DEFAULT_CATALOG_NAME.to_string(), + using_schema.to_string(), + my_table.to_string() + ) + ); + assert_eq!( + table_idents_to_full_name(&bare, empty_ctx).unwrap(), + ( + DEFAULT_CATALOG_NAME.to_string(), + DEFAULT_SCHEMA_NAME.to_string(), + my_table.to_string() + ) + ); + } +} diff --git a/src/datanode/src/server/grpc/ddl.rs b/src/datanode/src/server/grpc/ddl.rs index 5629e0fa9d..26108eb020 100644 --- a/src/datanode/src/server/grpc/ddl.rs +++ b/src/datanode/src/server/grpc/ddl.rs @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+use std::sync::Arc; + use api::result::AdminResultBuilder; use api::v1::{AdminResult, AlterExpr, CreateExpr, DropTableExpr}; use common_error::prelude::{ErrorExt, StatusCode}; @@ -19,6 +21,7 @@ use common_grpc_expr::{alter_expr_to_request, create_expr_to_request}; use common_query::Output; use common_telemetry::{error, info}; use futures::TryFutureExt; +use session::context::QueryContext; use snafu::prelude::*; use table::requests::DropTableRequest; @@ -72,7 +75,12 @@ impl Instance { let request = create_expr_to_request(table_id, expr).context(CreateExprToRequestSnafu); let result = futures::future::ready(request) - .and_then(|request| self.sql_handler().execute(SqlRequest::CreateTable(request))) + .and_then(|request| { + self.sql_handler().execute( + SqlRequest::CreateTable(request), + Arc::new(QueryContext::new()), + ) + }) .await; match result { Ok(Output::AffectedRows(rows)) => AdminResultBuilder::default() @@ -103,7 +111,10 @@ impl Instance { }; let result = futures::future::ready(request) - .and_then(|request| self.sql_handler().execute(SqlRequest::Alter(request))) + .and_then(|request| { + self.sql_handler() + .execute(SqlRequest::Alter(request), Arc::new(QueryContext::new())) + }) .await; match result { Ok(Output::AffectedRows(rows)) => AdminResultBuilder::default() @@ -124,7 +135,10 @@ impl Instance { schema_name: expr.schema_name, table_name: expr.table_name, }; - let result = self.sql_handler().execute(SqlRequest::DropTable(req)).await; + let result = self + .sql_handler() + .execute(SqlRequest::DropTable(req), Arc::new(QueryContext::new())) + .await; match result { Ok(Output::AffectedRows(rows)) => AdminResultBuilder::default() .status_code(StatusCode::Success as u32) diff --git a/src/datanode/src/sql.rs b/src/datanode/src/sql.rs index 1c097d61d4..0a3b4a999e 100644 --- a/src/datanode/src/sql.rs +++ b/src/datanode/src/sql.rs @@ -12,13 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -//! sql handler - use catalog::CatalogManagerRef; use common_query::Output; use common_telemetry::error; use query::query_engine::QueryEngineRef; use query::sql::{describe_table, explain, show_databases, show_tables}; +use session::context::QueryContextRef; use snafu::{OptionExt, ResultExt}; use sql::statements::describe::DescribeTable; use sql::statements::explain::Explain; @@ -67,7 +66,11 @@ impl SqlHandler { } } - pub async fn execute(&self, request: SqlRequest) -> Result { + // TODO(LFC): Refactor consideration: a context awareness "Planner". + // Now we have some query related state (like current using database in session context), maybe + // we could create a new struct called `Planner` that stores context and handle these queries + // there, instead of executing here in a "static" fashion. 
+ pub async fn execute(&self, request: SqlRequest, query_ctx: QueryContextRef) -> Result { let result = match request { SqlRequest::Insert(req) => self.insert(req).await, SqlRequest::CreateTable(req) => self.create_table(req).await, @@ -78,12 +81,12 @@ impl SqlHandler { show_databases(stmt, self.catalog_manager.clone()).context(ExecuteSqlSnafu) } SqlRequest::ShowTables(stmt) => { - show_tables(stmt, self.catalog_manager.clone()).context(ExecuteSqlSnafu) + show_tables(stmt, self.catalog_manager.clone(), query_ctx).context(ExecuteSqlSnafu) } SqlRequest::DescribeTable(stmt) => { describe_table(stmt, self.catalog_manager.clone()).context(ExecuteSqlSnafu) } - SqlRequest::Explain(stmt) => explain(stmt, self.query_engine.clone()) + SqlRequest::Explain(stmt) => explain(stmt, self.query_engine.clone(), query_ctx) .await .context(ExecuteSqlSnafu), }; @@ -114,7 +117,8 @@ mod tests { use std::any::Any; use std::sync::Arc; - use catalog::SchemaProvider; + use catalog::{CatalogList, SchemaProvider}; + use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; use common_query::logical_plan::Expr; use common_query::physical_plan::PhysicalPlanRef; use common_time::timestamp::Timestamp; @@ -234,9 +238,17 @@ mod tests { .await .unwrap(), ); + let catalog_provider = catalog_list.catalog(DEFAULT_CATALOG_NAME).unwrap().unwrap(); + catalog_provider + .register_schema( + DEFAULT_SCHEMA_NAME.to_string(), + Arc::new(MockSchemaProvider {}), + ) + .unwrap(); + let factory = QueryEngineFactory::new(catalog_list.clone()); let query_engine = factory.query_engine(); - let sql_handler = SqlHandler::new(table_engine, catalog_list, query_engine.clone()); + let sql_handler = SqlHandler::new(table_engine, catalog_list.clone(), query_engine.clone()); let stmt = match query_engine.sql_to_statement(sql).unwrap() { Statement::Insert(i) => i, @@ -244,9 +256,8 @@ mod tests { unreachable!() } }; - let schema_provider = Arc::new(MockSchemaProvider {}); let request = sql_handler - .insert_to_request(schema_provider, *stmt) + .insert_to_request(catalog_list.clone(), *stmt, TableReference::bare("demo")) .unwrap(); match request { diff --git a/src/datanode/src/sql/alter.rs b/src/datanode/src/sql/alter.rs index 077ebd0a9c..77fada09fd 100644 --- a/src/datanode/src/sql/alter.rs +++ b/src/datanode/src/sql/alter.rs @@ -16,7 +16,7 @@ use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; use common_query::Output; use snafu::prelude::*; use sql::statements::alter::{AlterTable, AlterTableOperation}; -use sql::statements::{column_def_to_schema, table_idents_to_full_name}; +use sql::statements::column_def_to_schema; use table::engine::{EngineContext, TableReference}; use table::requests::{AddColumnRequest, AlterKind, AlterTableRequest}; @@ -53,10 +53,11 @@ impl SqlHandler { Ok(Output::AffectedRows(0)) } - pub(crate) fn alter_to_request(&self, alter_table: AlterTable) -> Result { - let (catalog_name, schema_name, table_name) = - table_idents_to_full_name(alter_table.table_name()).context(error::ParseSqlSnafu)?; - + pub(crate) fn alter_to_request( + &self, + alter_table: AlterTable, + table_ref: TableReference, + ) -> Result { let alter_kind = match alter_table.alter_operation() { AlterTableOperation::AddConstraint(table_constraint) => { return error::InvalidSqlSnafu { @@ -77,9 +78,9 @@ impl SqlHandler { }, }; Ok(AlterTableRequest { - catalog_name: Some(catalog_name), - schema_name: Some(schema_name), - table_name, + catalog_name: Some(table_ref.catalog.to_string()), + schema_name: Some(table_ref.schema.to_string()), 
+ table_name: table_ref.table.to_string(), alter_kind, }) } @@ -112,7 +113,9 @@ mod tests { async fn test_alter_to_request_with_adding_column() { let handler = create_mock_sql_handler().await; let alter_table = parse_sql("ALTER TABLE my_metric_1 ADD tagk_i STRING Null;"); - let req = handler.alter_to_request(alter_table).unwrap(); + let req = handler + .alter_to_request(alter_table, TableReference::bare("my_metric_1")) + .unwrap(); assert_eq!(req.catalog_name, Some("greptime".to_string())); assert_eq!(req.schema_name, Some("public".to_string())); assert_eq!(req.table_name, "my_metric_1"); diff --git a/src/datanode/src/sql/create.rs b/src/datanode/src/sql/create.rs index c82484ec88..ba80682b62 100644 --- a/src/datanode/src/sql/create.rs +++ b/src/datanode/src/sql/create.rs @@ -23,10 +23,10 @@ use common_telemetry::tracing::log::error; use datatypes::schema::SchemaBuilder; use snafu::{ensure, OptionExt, ResultExt}; use sql::ast::TableConstraint; +use sql::statements::column_def_to_schema; use sql::statements::create::CreateTable; -use sql::statements::{column_def_to_schema, table_idents_to_full_name}; use store_api::storage::consts::TIME_INDEX_NAME; -use table::engine::EngineContext; +use table::engine::{EngineContext, TableReference}; use table::metadata::TableId; use table::requests::*; @@ -114,13 +114,11 @@ impl SqlHandler { &self, table_id: TableId, stmt: CreateTable, + table_ref: TableReference, ) -> Result { let mut ts_index = usize::MAX; let mut primary_keys = vec![]; - let (catalog_name, schema_name, table_name) = - table_idents_to_full_name(&stmt.name).context(error::ParseSqlSnafu)?; - let col_map = stmt .columns .iter() @@ -187,8 +185,8 @@ impl SqlHandler { if primary_keys.is_empty() { info!( - "Creating table: {:?}.{:?}.{} but primary key not set, use time index column: {}", - catalog_name, schema_name, table_name, ts_index + "Creating table: {} with time index column: {} upon primary keys absent", + table_ref, ts_index ); primary_keys.push(ts_index); } @@ -211,9 +209,9 @@ impl SqlHandler { let request = CreateTableRequest { id: table_id, - catalog_name, - schema_name, - table_name, + catalog_name: table_ref.catalog.to_string(), + schema_name: table_ref.schema.to_string(), + table_name: table_ref.table.to_string(), desc: None, schema, region_numbers: vec![0], @@ -261,7 +259,9 @@ mod tests { TIME INDEX (ts), PRIMARY KEY(host)) engine=mito with(regions=1);"#, ); - let c = handler.create_to_request(42, parsed_stmt).unwrap(); + let c = handler + .create_to_request(42, parsed_stmt, TableReference::bare("demo_table")) + .unwrap(); assert_eq!("demo_table", c.table_name); assert_eq!(42, c.id); assert!(!c.create_if_not_exists); @@ -282,7 +282,9 @@ mod tests { memory double, PRIMARY KEY(host)) engine=mito with(regions=1);"#, ); - let error = handler.create_to_request(42, parsed_stmt).unwrap_err(); + let error = handler + .create_to_request(42, parsed_stmt, TableReference::bare("demo_table")) + .unwrap_err(); assert_matches!(error, Error::MissingTimestampColumn { .. 
}); } @@ -299,7 +301,9 @@ mod tests { memory double, TIME INDEX (ts)) engine=mito with(regions=1);"#, ); - let c = handler.create_to_request(42, parsed_stmt).unwrap(); + let c = handler + .create_to_request(42, parsed_stmt, TableReference::bare("demo_table")) + .unwrap(); assert_eq!(1, c.primary_key_indices.len()); assert_eq!( c.schema.timestamp_index().unwrap(), @@ -318,7 +322,9 @@ mod tests { TIME INDEX (ts)) engine=mito with(regions=1);"#, ); - let error = handler.create_to_request(42, parsed_stmt).unwrap_err(); + let error = handler + .create_to_request(42, parsed_stmt, TableReference::bare("demo_table")) + .unwrap_err(); assert_matches!(error, Error::KeyColumnNotFound { .. }); } @@ -338,7 +344,9 @@ mod tests { let handler = create_mock_sql_handler().await; - let error = handler.create_to_request(42, create_table).unwrap_err(); + let error = handler + .create_to_request(42, create_table, TableReference::full("c", "s", "demo")) + .unwrap_err(); assert_matches!(error, Error::InvalidPrimaryKey { .. }); } @@ -358,7 +366,9 @@ mod tests { let handler = create_mock_sql_handler().await; - let request = handler.create_to_request(42, create_table).unwrap(); + let request = handler + .create_to_request(42, create_table, TableReference::full("c", "s", "demo")) + .unwrap(); assert_eq!(42, request.id); assert_eq!("c".to_string(), request.catalog_name); diff --git a/src/datanode/src/sql/insert.rs b/src/datanode/src/sql/insert.rs index 00aa59a026..8c2dae5c4a 100644 --- a/src/datanode/src/sql/insert.rs +++ b/src/datanode/src/sql/insert.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use catalog::SchemaProviderRef; +use catalog::CatalogManagerRef; use common_query::Output; use datatypes::prelude::{ConcreteDataType, VectorBuilder}; use snafu::{ensure, OptionExt, ResultExt}; @@ -23,7 +23,7 @@ use table::engine::TableReference; use table::requests::*; use crate::error::{ - CatalogSnafu, ColumnNotFoundSnafu, ColumnValuesNumberMismatchSnafu, InsertSnafu, ParseSqlSnafu, + CatalogSnafu, ColumnNotFoundSnafu, ColumnValuesNumberMismatchSnafu, InsertSnafu, ParseSqlValueSnafu, Result, TableNotFoundSnafu, }; use crate::sql::{SqlHandler, SqlRequest}; @@ -49,19 +49,18 @@ impl SqlHandler { pub(crate) fn insert_to_request( &self, - schema_provider: SchemaProviderRef, + catalog_manager: CatalogManagerRef, stmt: Insert, + table_ref: TableReference, ) -> Result { let columns = stmt.columns(); let values = stmt.values().context(ParseSqlValueSnafu)?; - let (catalog_name, schema_name, table_name) = - stmt.full_table_name().context(ParseSqlSnafu)?; - let table = schema_provider - .table(&table_name) + let table = catalog_manager + .table(table_ref.catalog, table_ref.schema, table_ref.table) .context(CatalogSnafu)? 
.context(TableNotFoundSnafu { - table_name: &table_name, + table_name: table_ref.table, })?; let schema = table.schema(); let columns_num = if columns.is_empty() { @@ -88,7 +87,7 @@ impl SqlHandler { let column_schema = schema.column_schema_by_name(column_name).with_context(|| { ColumnNotFoundSnafu { - table_name: &table_name, + table_name: table_ref.table, column_name: column_name.to_string(), } })?; @@ -119,9 +118,9 @@ impl SqlHandler { } Ok(SqlRequest::Insert(InsertRequest { - catalog_name, - schema_name, - table_name, + catalog_name: table_ref.catalog.to_string(), + schema_name: table_ref.schema.to_string(), + table_name: table_ref.table.to_string(), columns_values: columns_builders .into_iter() .map(|(c, _, mut b)| (c.to_owned(), b.finish())) diff --git a/src/datanode/src/tests/instance_test.rs b/src/datanode/src/tests/instance_test.rs index 6914058ffb..b93759b3c7 100644 --- a/src/datanode/src/tests/instance_test.rs +++ b/src/datanode/src/tests/instance_test.rs @@ -12,6 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::sync::Arc; + +use common_catalog::consts::DEFAULT_SCHEMA_NAME; use common_query::Output; use common_recordbatch::util; use datafusion::arrow_print; @@ -19,6 +22,7 @@ use datafusion_common::record_batch::RecordBatch as DfRecordBatch; use datatypes::arrow::array::{Int64Array, UInt64Array, Utf8Array}; use datatypes::arrow_array::StringArray; use datatypes::prelude::ConcreteDataType; +use session::context::QueryContext; use crate::instance::Instance; use crate::tests::test_util; @@ -32,39 +36,33 @@ async fn test_create_database_and_insert_query() { let instance = Instance::with_mock_meta_client(&opts).await.unwrap(); instance.start().await.unwrap(); - let output = instance.execute_sql("create database test").await.unwrap(); + let output = execute_sql(&instance, "create database test").await; assert!(matches!(output, Output::AffectedRows(1))); - let output = instance - .execute_sql( - r#"create table greptime.test.demo( + let output = execute_sql( + &instance, + r#"create table greptime.test.demo( host STRING, cpu DOUBLE, memory DOUBLE, ts bigint, TIME INDEX(ts) )"#, - ) - .await - .unwrap(); + ) + .await; assert!(matches!(output, Output::AffectedRows(1))); - let output = instance - .execute_sql( - r#"insert into test.demo(host, cpu, memory, ts) values + let output = execute_sql( + &instance, + r#"insert into test.demo(host, cpu, memory, ts) values ('host1', 66.6, 1024, 1655276557000), ('host2', 88.8, 333.3, 1655276558000) "#, - ) - .await - .unwrap(); + ) + .await; assert!(matches!(output, Output::AffectedRows(2))); - let query_output = instance - .execute_sql("select ts from test.demo order by ts") - .await - .unwrap(); - + let query_output = execute_sql(&instance, "select ts from test.demo order by ts").await; match query_output { Output::Stream(s) => { let batches = util::collect(s).await.unwrap(); @@ -88,54 +86,50 @@ async fn test_issue477_same_table_name_in_different_databases() { instance.start().await.unwrap(); // Create database a and b - let output = instance.execute_sql("create database a").await.unwrap(); + let output = execute_sql(&instance, "create database a").await; assert!(matches!(output, Output::AffectedRows(1))); - let output = instance.execute_sql("create database b").await.unwrap(); + let output = execute_sql(&instance, "create database b").await; assert!(matches!(output, Output::AffectedRows(1))); // Create table a.demo and b.demo - let output = instance - .execute_sql( - 
r#"create table a.demo( + let output = execute_sql( + &instance, + r#"create table a.demo( host STRING, ts bigint, TIME INDEX(ts) )"#, - ) - .await - .unwrap(); + ) + .await; assert!(matches!(output, Output::AffectedRows(1))); - let output = instance - .execute_sql( - r#"create table b.demo( + let output = execute_sql( + &instance, + r#"create table b.demo( host STRING, ts bigint, TIME INDEX(ts) )"#, - ) - .await - .unwrap(); + ) + .await; assert!(matches!(output, Output::AffectedRows(1))); // Insert different data into a.demo and b.demo - let output = instance - .execute_sql( - r#"insert into a.demo(host, ts) values + let output = execute_sql( + &instance, + r#"insert into a.demo(host, ts) values ('host1', 1655276557000) "#, - ) - .await - .unwrap(); + ) + .await; assert!(matches!(output, Output::AffectedRows(1))); - let output = instance - .execute_sql( - r#"insert into b.demo(host, ts) values + let output = execute_sql( + &instance, + r#"insert into b.demo(host, ts) values ('host2',1655276558000) "#, - ) - .await - .unwrap(); + ) + .await; assert!(matches!(output, Output::AffectedRows(1))); // Query data and assert @@ -157,7 +151,7 @@ async fn test_issue477_same_table_name_in_different_databases() { } async fn assert_query_result(instance: &Instance, sql: &str, ts: i64, host: &str) { - let query_output = instance.execute_sql(sql).await.unwrap(); + let query_output = execute_sql(instance, sql).await; match query_output { Output::Stream(s) => { let batches = util::collect(s).await.unwrap(); @@ -200,15 +194,14 @@ async fn setup_test_instance() -> Instance { #[tokio::test(flavor = "multi_thread")] async fn test_execute_insert() { let instance = setup_test_instance().await; - let output = instance - .execute_sql( - r#"insert into demo(host, cpu, memory, ts) values + let output = execute_sql( + &instance, + r#"insert into demo(host, cpu, memory, ts) values ('host1', 66.6, 1024, 1655276557000), ('host2', 88.8, 333.3, 1655276558000) "#, - ) - .await - .unwrap(); + ) + .await; assert!(matches!(output, Output::AffectedRows(2))); } @@ -228,22 +221,17 @@ async fn test_execute_insert_query_with_i64_timestamp() { .await .unwrap(); - let output = instance - .execute_sql( - r#"insert into demo(host, cpu, memory, ts) values + let output = execute_sql( + &instance, + r#"insert into demo(host, cpu, memory, ts) values ('host1', 66.6, 1024, 1655276557000), ('host2', 88.8, 333.3, 1655276558000) "#, - ) - .await - .unwrap(); + ) + .await; assert!(matches!(output, Output::AffectedRows(2))); - let query_output = instance - .execute_sql("select ts from demo order by ts") - .await - .unwrap(); - + let query_output = execute_sql(&instance, "select ts from demo order by ts").await; match query_output { Output::Stream(s) => { let batches = util::collect(s).await.unwrap(); @@ -257,11 +245,7 @@ async fn test_execute_insert_query_with_i64_timestamp() { _ => unreachable!(), } - let query_output = instance - .execute_sql("select ts as time from demo order by ts") - .await - .unwrap(); - + let query_output = execute_sql(&instance, "select ts as time from demo order by ts").await; match query_output { Output::Stream(s) => { let batches = util::collect(s).await.unwrap(); @@ -282,10 +266,7 @@ async fn test_execute_query() { let instance = Instance::with_mock_meta_client(&opts).await.unwrap(); instance.start().await.unwrap(); - let output = instance - .execute_sql("select sum(number) from numbers limit 20") - .await - .unwrap(); + let output = execute_sql(&instance, "select sum(number) from numbers limit 20").await; match 
output { Output::Stream(recordbatch) => { let numbers = util::collect(recordbatch).await.unwrap(); @@ -309,7 +290,7 @@ async fn test_execute_show_databases_tables() { let instance = Instance::with_mock_meta_client(&opts).await.unwrap(); instance.start().await.unwrap(); - let output = instance.execute_sql("show databases").await.unwrap(); + let output = execute_sql(&instance, "show databases").await; match output { Output::RecordBatches(databases) => { let databases = databases.take(); @@ -325,10 +306,7 @@ async fn test_execute_show_databases_tables() { _ => unreachable!(), } - let output = instance - .execute_sql("show databases like '%bl%'") - .await - .unwrap(); + let output = execute_sql(&instance, "show databases like '%bl%'").await; match output { Output::RecordBatches(databases) => { let databases = databases.take(); @@ -344,7 +322,7 @@ async fn test_execute_show_databases_tables() { _ => unreachable!(), } - let output = instance.execute_sql("show tables").await.unwrap(); + let output = execute_sql(&instance, "show tables").await; match output { Output::RecordBatches(databases) => { let databases = databases.take(); @@ -364,7 +342,7 @@ async fn test_execute_show_databases_tables() { .await .unwrap(); - let output = instance.execute_sql("show tables").await.unwrap(); + let output = execute_sql(&instance, "show tables").await; match output { Output::RecordBatches(databases) => { let databases = databases.take(); @@ -376,10 +354,7 @@ async fn test_execute_show_databases_tables() { } // show tables like [string] - let output = instance - .execute_sql("show tables like 'de%'") - .await - .unwrap(); + let output = execute_sql(&instance, "show tables like 'de%'").await; match output { Output::RecordBatches(databases) => { let databases = databases.take(); @@ -404,9 +379,9 @@ pub async fn test_execute_create() { let instance = Instance::with_mock_meta_client(&opts).await.unwrap(); instance.start().await.unwrap(); - let output = instance - .execute_sql( - r#"create table test_table( + let output = execute_sql( + &instance, + r#"create table test_table( host string, ts timestamp, cpu double default 0, @@ -414,56 +389,24 @@ pub async fn test_execute_create() { TIME INDEX (ts), PRIMARY KEY(host) ) engine=mito with(regions=1);"#, - ) - .await - .unwrap(); + ) + .await; assert!(matches!(output, Output::AffectedRows(1))); } -#[tokio::test(flavor = "multi_thread")] -pub async fn test_create_table_illegal_timestamp_type() { - common_telemetry::init_default_ut_logging(); - - let (opts, _guard) = - test_util::create_tmp_dir_and_datanode_opts("create_table_illegal_timestamp_type"); - let instance = Instance::with_mock_meta_client(&opts).await.unwrap(); - instance.start().await.unwrap(); - - let output = instance - .execute_sql( - r#"create table test_table( - host string, - ts bigint, - cpu double default 0, - memory double, - TIME INDEX (ts), - PRIMARY KEY(host) - ) engine=mito with(regions=1);"#, - ) - .await - .unwrap(); - match output { - Output::AffectedRows(rows) => { - assert_eq!(1, rows); - } - _ => unreachable!(), - } -} - async fn check_output_stream(output: Output, expected: Vec<&str>) { - match output { - Output::Stream(stream) => { - let recordbatches = util::collect(stream).await.unwrap(); - let recordbatch = recordbatches - .into_iter() - .map(|r| r.df_recordbatch) - .collect::>(); - let pretty_print = arrow_print::write(&recordbatch); - let pretty_print = pretty_print.lines().collect::>(); - assert_eq!(pretty_print, expected); - } + let recordbatches = match output { + 
Output::Stream(stream) => util::collect(stream).await.unwrap(), + Output::RecordBatches(recordbatches) => recordbatches.take(), _ => unreachable!(), - } + }; + let recordbatches = recordbatches + .into_iter() + .map(|r| r.df_recordbatch) + .collect::>(); + let pretty_print = arrow_print::write(&recordbatches); + let pretty_print = pretty_print.lines().collect::>(); + assert_eq!(pretty_print, expected); } #[tokio::test] @@ -479,35 +422,30 @@ async fn test_alter_table() { .await .unwrap(); // make sure table insertion is ok before altering table - instance - .execute_sql("insert into demo(host, cpu, memory, ts) values ('host1', 1.1, 100, 1000)") - .await - .unwrap(); + execute_sql( + &instance, + "insert into demo(host, cpu, memory, ts) values ('host1', 1.1, 100, 1000)", + ) + .await; // Add column - let output = instance - .execute_sql("alter table demo add my_tag string null") - .await - .unwrap(); + let output = execute_sql(&instance, "alter table demo add my_tag string null").await; assert!(matches!(output, Output::AffectedRows(0))); - let output = instance - .execute_sql( - "insert into demo(host, cpu, memory, ts, my_tag) values ('host2', 2.2, 200, 2000, 'hello')", - ) - .await - .unwrap(); + let output = execute_sql( + &instance, + "insert into demo(host, cpu, memory, ts, my_tag) values ('host2', 2.2, 200, 2000, 'hello')", + ) + .await; assert!(matches!(output, Output::AffectedRows(1))); - let output = instance - .execute_sql("insert into demo(host, cpu, memory, ts) values ('host3', 3.3, 300, 3000)") - .await - .unwrap(); + let output = execute_sql( + &instance, + "insert into demo(host, cpu, memory, ts) values ('host3', 3.3, 300, 3000)", + ) + .await; assert!(matches!(output, Output::AffectedRows(1))); - let output = instance - .execute_sql("select * from demo order by ts") - .await - .unwrap(); + let output = execute_sql(&instance, "select * from demo order by ts").await; let expected = vec![ "+-------+-----+--------+---------------------+--------+", "| host | cpu | memory | ts | my_tag |", @@ -520,16 +458,10 @@ async fn test_alter_table() { check_output_stream(output, expected).await; // Drop a column - let output = instance - .execute_sql("alter table demo drop column memory") - .await - .unwrap(); + let output = execute_sql(&instance, "alter table demo drop column memory").await; assert!(matches!(output, Output::AffectedRows(0))); - let output = instance - .execute_sql("select * from demo order by ts") - .await - .unwrap(); + let output = execute_sql(&instance, "select * from demo order by ts").await; let expected = vec![ "+-------+-----+---------------------+--------+", "| host | cpu | ts | my_tag |", @@ -542,16 +474,14 @@ async fn test_alter_table() { check_output_stream(output, expected).await; // insert a new row - let output = instance - .execute_sql("insert into demo(host, cpu, ts, my_tag) values ('host4', 400, 4000, 'world')") - .await - .unwrap(); + let output = execute_sql( + &instance, + "insert into demo(host, cpu, ts, my_tag) values ('host4', 400, 4000, 'world')", + ) + .await; assert!(matches!(output, Output::AffectedRows(1))); - let output = instance - .execute_sql("select * from demo order by ts") - .await - .unwrap(); + let output = execute_sql(&instance, "select * from demo order by ts").await; let expected = vec![ "+-------+-----+---------------------+--------+", "| host | cpu | ts | my_tag |", @@ -580,27 +510,26 @@ async fn test_insert_with_default_value_for_type(type_name: &str) { ) engine=mito with(regions=1);"#, type_name ); - let output = 
instance.execute_sql(&create_sql).await.unwrap(); + let output = execute_sql(&instance, &create_sql).await; assert!(matches!(output, Output::AffectedRows(1))); // Insert with ts. - instance - .execute_sql("insert into test_table(host, cpu, ts) values ('host1', 1.1, 1000)") - .await - .unwrap(); + let output = execute_sql( + &instance, + "insert into test_table(host, cpu, ts) values ('host1', 1.1, 1000)", + ) + .await; assert!(matches!(output, Output::AffectedRows(1))); // Insert without ts, so it should be filled by default value. - let output = instance - .execute_sql("insert into test_table(host, cpu) values ('host2', 2.2)") - .await - .unwrap(); + let output = execute_sql( + &instance, + "insert into test_table(host, cpu) values ('host2', 2.2)", + ) + .await; assert!(matches!(output, Output::AffectedRows(1))); - let output = instance - .execute_sql("select host, cpu from test_table") - .await - .unwrap(); + let output = execute_sql(&instance, "select host, cpu from test_table").await; let expected = vec![ "+-------+-----+", "| host | cpu |", @@ -619,3 +548,70 @@ async fn test_insert_with_default_value() { test_insert_with_default_value_for_type("timestamp").await; test_insert_with_default_value_for_type("bigint").await; } + +#[tokio::test(flavor = "multi_thread")] +async fn test_use_database() { + let (opts, _guard) = test_util::create_tmp_dir_and_datanode_opts("use_database"); + let instance = Instance::with_mock_meta_client(&opts).await.unwrap(); + instance.start().await.unwrap(); + + let output = execute_sql(&instance, "create database db1").await; + assert!(matches!(output, Output::AffectedRows(1))); + + let output = execute_sql_in_db( + &instance, + "create table tb1(col_i32 int, ts bigint, TIME INDEX(ts))", + "db1", + ) + .await; + assert!(matches!(output, Output::AffectedRows(1))); + + let output = execute_sql_in_db(&instance, "show tables", "db1").await; + let expected = vec![ + "+--------+", + "| Tables |", + "+--------+", + "| tb1 |", + "+--------+", + ]; + check_output_stream(output, expected).await; + + let output = execute_sql_in_db( + &instance, + r#"insert into tb1(col_i32, ts) values (1, 1655276557000)"#, + "db1", + ) + .await; + assert!(matches!(output, Output::AffectedRows(1))); + + let output = execute_sql_in_db(&instance, "select col_i32 from tb1", "db1").await; + let expected = vec![ + "+---------+", + "| col_i32 |", + "+---------+", + "| 1 |", + "+---------+", + ]; + check_output_stream(output, expected).await; + + // Making a particular database the default by means of the USE statement does not preclude + // accessing tables in other databases. 
+ let output = execute_sql(&instance, "select number from public.numbers limit 1").await; + let expected = vec![ + "+--------+", + "| number |", + "+--------+", + "| 0 |", + "+--------+", + ]; + check_output_stream(output, expected).await; +} + +async fn execute_sql(instance: &Instance, sql: &str) -> Output { + execute_sql_in_db(instance, sql, DEFAULT_SCHEMA_NAME).await +} + +async fn execute_sql_in_db(instance: &Instance, sql: &str, db: &str) -> Output { + let query_ctx = Arc::new(QueryContext::with_current_schema(db.to_string())); + instance.execute_sql(sql, query_ctx).await.unwrap() +} diff --git a/src/frontend/Cargo.toml b/src/frontend/Cargo.toml index d89124c11d..9678dd2192 100644 --- a/src/frontend/Cargo.toml +++ b/src/frontend/Cargo.toml @@ -38,6 +38,7 @@ prost = "0.11" query = { path = "../query" } serde = "1.0" serde_json = "1.0" +session = { path = "../session" } sqlparser = "0.15" servers = { path = "../servers" } snafu = { version = "0.7", features = ["backtraces"] } diff --git a/src/frontend/src/instance.rs b/src/frontend/src/instance.rs index 25a8e99d2a..1a3aee1deb 100644 --- a/src/frontend/src/instance.rs +++ b/src/frontend/src/instance.rs @@ -38,6 +38,7 @@ use common_error::prelude::{BoxedError, StatusCode}; use common_grpc::channel_manager::{ChannelConfig, ChannelManager}; use common_grpc::select::to_object_result; use common_query::Output; +use common_recordbatch::RecordBatches; use common_telemetry::{debug, error, info}; use distributed::DistInstance; use meta_client::client::MetaClientBuilder; @@ -47,6 +48,7 @@ use servers::query_handler::{ PrometheusProtocolHandler, ScriptHandler, ScriptHandlerRef, SqlQueryHandler, }; use servers::{error as server_error, Mode}; +use session::context::{QueryContext, QueryContextRef}; use snafu::prelude::*; use sql::dialect::GenericDialect; use sql::parser::ParserContext; @@ -211,10 +213,15 @@ impl Instance { self.script_handler = Some(handler); } - pub async fn handle_select(&self, expr: Select, stmt: Statement) -> Result { + async fn handle_select( + &self, + expr: Select, + stmt: Statement, + query_ctx: QueryContextRef, + ) -> Result { if let Some(dist_instance) = &self.dist_instance { let Select::Sql(sql) = expr; - dist_instance.handle_sql(&sql, stmt).await + dist_instance.handle_sql(&sql, stmt, query_ctx).await } else { // TODO(LFC): Refactor consideration: Datanode should directly execute statement in standalone mode to avoid parse SQL again. // Find a better way to execute query between Frontend and Datanode in standalone mode. @@ -298,10 +305,15 @@ impl Instance { } /// Handle explain expr - pub async fn handle_explain(&self, sql: &str, explain_stmt: Explain) -> Result { + pub async fn handle_explain( + &self, + sql: &str, + explain_stmt: Explain, + query_ctx: QueryContextRef, + ) -> Result { if let Some(dist_instance) = &self.dist_instance { dist_instance - .handle_sql(sql, Statement::Explain(explain_stmt)) + .handle_sql(sql, Statement::Explain(explain_stmt), query_ctx) .await } else { Ok(Output::AffectedRows(0)) @@ -505,6 +517,26 @@ impl Instance { let insert_request = insert_to_request(&schema_provider, *insert)?; insert_request_to_insert_batch(&insert_request) } + + fn handle_use(&self, db: String, query_ctx: QueryContextRef) -> Result { + let catalog_manager = &self.catalog_manager; + if let Some(catalog_manager) = catalog_manager { + ensure!( + catalog_manager + .schema(DEFAULT_CATALOG_NAME, &db) + .context(error::CatalogSnafu)? 
+ .is_some(), + error::SchemaNotFoundSnafu { schema_info: &db } + ); + + query_ctx.set_current_schema(&db); + + Ok(Output::RecordBatches(RecordBatches::empty())) + } else { + // TODO(LFC): Handle "use" stmt here. + unimplemented!() + } + } } #[async_trait] @@ -545,17 +577,23 @@ fn parse_stmt(sql: &str) -> Result { #[async_trait] impl SqlQueryHandler for Instance { - async fn do_query(&self, query: &str) -> server_error::Result { + async fn do_query( + &self, + query: &str, + query_ctx: QueryContextRef, + ) -> server_error::Result { let stmt = parse_stmt(query) .map_err(BoxedError::new) .context(server_error::ExecuteQuerySnafu { query })?; match stmt { - Statement::Query(_) => self - .handle_select(Select::Sql(query.to_string()), stmt) - .await - .map_err(BoxedError::new) - .context(server_error::ExecuteQuerySnafu { query }), + Statement::ShowDatabases(_) + | Statement::ShowTables(_) + | Statement::DescribeTable(_) + | Statement::Query(_) => { + self.handle_select(Select::Sql(query.to_string()), stmt, query_ctx) + .await + } Statement::Insert(insert) => match self.mode { Mode::Standalone => { let (catalog_name, schema_name, table_name) = insert @@ -578,10 +616,7 @@ impl SqlQueryHandler for Instance { columns, row_count, }; - self.handle_insert(expr) - .await - .map_err(BoxedError::new) - .context(server_error::ExecuteQuerySnafu { query }) + self.handle_insert(expr).await } Mode::Distributed => { let affected = self @@ -604,55 +639,36 @@ impl SqlQueryHandler for Instance { self.handle_create_table(create_expr, create.partitions) .await - .map_err(BoxedError::new) - .context(server_error::ExecuteQuerySnafu { query }) } - - Statement::ShowDatabases(_) - | Statement::ShowTables(_) - | Statement::DescribeTable(_) => self - .handle_select(Select::Sql(query.to_string()), stmt) - .await - .map_err(BoxedError::new) - .context(server_error::ExecuteQuerySnafu { query }), - Statement::CreateDatabase(c) => { let expr = CreateDatabaseExpr { database_name: c.name.to_string(), }; - self.handle_create_database(expr) - .await - .map_err(BoxedError::new) - .context(server_error::ExecuteQuerySnafu { query }) + self.handle_create_database(expr).await } - Statement::Alter(alter_stmt) => self - .handle_alter( + Statement::Alter(alter_stmt) => { + self.handle_alter( AlterExpr::try_from(alter_stmt) .map_err(BoxedError::new) .context(server_error::ExecuteAlterSnafu { query })?, ) .await - .map_err(BoxedError::new) - .context(server_error::ExecuteQuerySnafu { query }), + } Statement::DropTable(drop_stmt) => { let expr = DropTableExpr { catalog_name: drop_stmt.catalog_name, schema_name: drop_stmt.schema_name, table_name: drop_stmt.table_name, }; - self.handle_drop_table(expr) - .await - .map_err(BoxedError::new) - .context(server_error::ExecuteQuerySnafu { query }) + self.handle_drop_table(expr).await + } + Statement::Explain(explain_stmt) => { + self.handle_explain(query, explain_stmt, query_ctx).await } - Statement::Explain(explain_stmt) => self - .handle_explain(query, explain_stmt) - .await - .map_err(BoxedError::new) - .context(server_error::ExecuteQuerySnafu { query }), Statement::ShowCreateTable(_) => { return server_error::NotSupportedSnafu { feat: query }.fail(); } + Statement::Use(db) => self.handle_use(db, query_ctx), } .map_err(BoxedError::new) .context(server_error::ExecuteQuerySnafu { query }) @@ -716,7 +732,8 @@ impl GrpcQueryHandler for Instance { })?; match select { select_expr::Expr::Sql(sql) => { - let output = SqlQueryHandler::do_query(self, sql).await; + let query_ctx = 
Arc::new(QueryContext::new()); + let output = SqlQueryHandler::do_query(self, sql, query_ctx).await; Ok(to_object_result(output).await) } _ => { @@ -797,6 +814,8 @@ mod tests { #[tokio::test] async fn test_execute_sql() { + let query_ctx = Arc::new(QueryContext::new()); + let instance = tests::create_frontend_instance().await; let sql = r#"CREATE TABLE demo( @@ -808,7 +827,9 @@ mod tests { TIME INDEX (ts), PRIMARY KEY(ts, host) ) engine=mito with(regions=1);"#; - let output = SqlQueryHandler::do_query(&*instance, sql).await.unwrap(); + let output = SqlQueryHandler::do_query(&*instance, sql, query_ctx.clone()) + .await + .unwrap(); match output { Output::AffectedRows(rows) => assert_eq!(rows, 1), _ => unreachable!(), @@ -819,14 +840,18 @@ mod tests { ('frontend.host2', null, null, 2000), ('frontend.host3', 3.3, 300, 3000) "#; - let output = SqlQueryHandler::do_query(&*instance, sql).await.unwrap(); + let output = SqlQueryHandler::do_query(&*instance, sql, query_ctx.clone()) + .await + .unwrap(); match output { Output::AffectedRows(rows) => assert_eq!(rows, 3), _ => unreachable!(), } let sql = "select * from demo"; - let output = SqlQueryHandler::do_query(&*instance, sql).await.unwrap(); + let output = SqlQueryHandler::do_query(&*instance, sql, query_ctx.clone()) + .await + .unwrap(); match output { Output::RecordBatches(recordbatches) => { let pretty_print = recordbatches.pretty_print(); @@ -846,7 +871,9 @@ mod tests { }; let sql = "select * from demo where ts>cast(1000000000 as timestamp)"; // use nanoseconds as where condition - let output = SqlQueryHandler::do_query(&*instance, sql).await.unwrap(); + let output = SqlQueryHandler::do_query(&*instance, sql, query_ctx.clone()) + .await + .unwrap(); match output { Output::RecordBatches(recordbatches) => { let pretty_print = recordbatches.pretty_print(); diff --git a/src/frontend/src/instance/distributed.rs b/src/frontend/src/instance/distributed.rs index 4098f040f2..a96f817035 100644 --- a/src/frontend/src/instance/distributed.rs +++ b/src/frontend/src/instance/distributed.rs @@ -33,6 +33,7 @@ use meta_client::rpc::{ }; use query::sql::{describe_table, explain, show_databases, show_tables}; use query::{QueryEngineFactory, QueryEngineRef}; +use session::context::QueryContextRef; use snafu::{ensure, OptionExt, ResultExt}; use sql::statements::create::Partitions; use sql::statements::sql_value_to_value; @@ -128,29 +129,31 @@ impl DistInstance { Ok(Output::AffectedRows(region_routes.len())) } - pub(crate) async fn handle_sql(&self, sql: &str, stmt: Statement) -> Result { + pub(crate) async fn handle_sql( + &self, + sql: &str, + stmt: Statement, + query_ctx: QueryContextRef, + ) -> Result { match stmt { Statement::Query(_) => { let plan = self .query_engine - .statement_to_plan(stmt) + .statement_to_plan(stmt, query_ctx) .context(error::ExecuteSqlSnafu { sql })?; - self.query_engine - .execute(&plan) - .await - .context(error::ExecuteSqlSnafu { sql }) + self.query_engine.execute(&plan).await + } + Statement::ShowDatabases(stmt) => show_databases(stmt, self.catalog_manager.clone()), + Statement::ShowTables(stmt) => { + show_tables(stmt, self.catalog_manager.clone(), query_ctx) + } + Statement::DescribeTable(stmt) => describe_table(stmt, self.catalog_manager.clone()), + Statement::Explain(stmt) => { + explain(Box::new(stmt), self.query_engine.clone(), query_ctx).await } - Statement::ShowDatabases(stmt) => show_databases(stmt, self.catalog_manager.clone()) - .context(error::ExecuteSqlSnafu { sql }), - Statement::ShowTables(stmt) => show_tables(stmt, 
self.catalog_manager.clone()) - .context(error::ExecuteSqlSnafu { sql }), - Statement::DescribeTable(stmt) => describe_table(stmt, self.catalog_manager.clone()) - .context(error::ExecuteSqlSnafu { sql }), - Statement::Explain(stmt) => explain(Box::new(stmt), self.query_engine.clone()) - .await - .context(error::ExecuteSqlSnafu { sql }), _ => unreachable!(), } + .context(error::ExecuteSqlSnafu { sql }) } /// Handles distributed database creation diff --git a/src/frontend/src/instance/opentsdb.rs b/src/frontend/src/instance/opentsdb.rs index 4d35b1d8ea..66b04b1317 100644 --- a/src/frontend/src/instance/opentsdb.rs +++ b/src/frontend/src/instance/opentsdb.rs @@ -60,9 +60,12 @@ impl Instance { #[cfg(test)] mod tests { + use std::sync::Arc; + use common_query::Output; use datafusion::arrow_print; use servers::query_handler::SqlQueryHandler; + use session::context::QueryContext; use super::*; use crate::tests; @@ -121,7 +124,7 @@ mod tests { assert!(result.is_ok()); let output = instance - .do_query("select * from my_metric_1") + .do_query("select * from my_metric_1", Arc::new(QueryContext::new())) .await .unwrap(); match output { diff --git a/src/frontend/src/instance/prometheus.rs b/src/frontend/src/instance/prometheus.rs index ab9f0dea59..b6f322beb2 100644 --- a/src/frontend/src/instance/prometheus.rs +++ b/src/frontend/src/instance/prometheus.rs @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::sync::Arc; + use api::prometheus::remote::read_request::ResponseType; use api::prometheus::remote::{Query, QueryResult, ReadRequest, ReadResponse, WriteRequest}; use async_trait::async_trait; @@ -25,6 +27,7 @@ use servers::error::{self, Result as ServerResult}; use servers::prometheus::{self, Metrics}; use servers::query_handler::{PrometheusProtocolHandler, PrometheusResponse}; use servers::Mode; +use session::context::QueryContext; use snafu::{OptionExt, ResultExt}; use crate::instance::{parse_stmt, Instance}; @@ -93,7 +96,10 @@ impl Instance { let object_result = if let Some(dist_instance) = &self.dist_instance { let output = futures::future::ready(parse_stmt(&sql)) - .and_then(|stmt| dist_instance.handle_sql(&sql, stmt)) + .and_then(|stmt| { + let query_ctx = Arc::new(QueryContext::with_current_schema(db.to_string())); + dist_instance.handle_sql(&sql, stmt, query_ctx) + }) .await; to_object_result(output).await.try_into() } else { diff --git a/src/query/Cargo.toml b/src/query/Cargo.toml index f4163689f1..9676a81a39 100644 --- a/src/query/Cargo.toml +++ b/src/query/Cargo.toml @@ -27,6 +27,7 @@ metrics = "0.20" once_cell = "1.10" serde = "1.0" serde_json = "1.0" +session = { path = "../session" } snafu = { version = "0.7", features = ["backtraces"] } sql = { path = "../sql" } table = { path = "../table" } diff --git a/src/query/src/datafusion.rs b/src/query/src/datafusion.rs index 7c65c758b3..8dda26a5db 100644 --- a/src/query/src/datafusion.rs +++ b/src/query/src/datafusion.rs @@ -32,6 +32,7 @@ use common_recordbatch::{EmptyRecordBatchStream, SendableRecordBatchStream}; use common_telemetry::timer; use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion::physical_plan::ExecutionPlan; +use session::context::QueryContextRef; use snafu::{OptionExt, ResultExt}; use sql::dialect::GenericDialect; use sql::parser::ParserContext; @@ -46,7 +47,7 @@ use crate::physical_optimizer::PhysicalOptimizer; use crate::physical_planner::PhysicalPlanner; use crate::plan::LogicalPlan; use 
crate::planner::Planner; -use crate::query_engine::{QueryContext, QueryEngineState}; +use crate::query_engine::{QueryEngineContext, QueryEngineState}; use crate::{metric, QueryEngine}; pub(crate) struct DatafusionQueryEngine { @@ -61,6 +62,7 @@ impl DatafusionQueryEngine { } } +// TODO(LFC): Refactor consideration: extract a "Planner" that stores query context and execute queries inside. #[async_trait::async_trait] impl QueryEngine for DatafusionQueryEngine { fn name(&self) -> &str { @@ -75,21 +77,25 @@ impl QueryEngine for DatafusionQueryEngine { Ok(statement.remove(0)) } - fn statement_to_plan(&self, stmt: Statement) -> Result { - let context_provider = DfContextProviderAdapter::new(self.state.clone()); + fn statement_to_plan( + &self, + stmt: Statement, + query_ctx: QueryContextRef, + ) -> Result { + let context_provider = DfContextProviderAdapter::new(self.state.clone(), query_ctx); let planner = DfPlanner::new(&context_provider); planner.statement_to_plan(stmt) } - fn sql_to_plan(&self, sql: &str) -> Result { + fn sql_to_plan(&self, sql: &str, query_ctx: QueryContextRef) -> Result { let _timer = timer!(metric::METRIC_PARSE_SQL_ELAPSED); let stmt = self.sql_to_statement(sql)?; - self.statement_to_plan(stmt) + self.statement_to_plan(stmt, query_ctx) } async fn execute(&self, plan: &LogicalPlan) -> Result { - let mut ctx = QueryContext::new(self.state.clone()); + let mut ctx = QueryEngineContext::new(self.state.clone()); let logical_plan = self.optimize_logical_plan(&mut ctx, plan)?; let physical_plan = self.create_physical_plan(&mut ctx, &logical_plan).await?; let physical_plan = self.optimize_physical_plan(&mut ctx, physical_plan)?; @@ -100,7 +106,7 @@ impl QueryEngine for DatafusionQueryEngine { } async fn execute_physical(&self, plan: &Arc) -> Result { - let ctx = QueryContext::new(self.state.clone()); + let ctx = QueryEngineContext::new(self.state.clone()); Ok(Output::Stream(self.execute_stream(&ctx, plan).await?)) } @@ -127,7 +133,7 @@ impl QueryEngine for DatafusionQueryEngine { impl LogicalOptimizer for DatafusionQueryEngine { fn optimize_logical_plan( &self, - _ctx: &mut QueryContext, + _: &mut QueryEngineContext, plan: &LogicalPlan, ) -> Result { let _timer = timer!(metric::METRIC_OPTIMIZE_LOGICAL_ELAPSED); @@ -151,7 +157,7 @@ impl LogicalOptimizer for DatafusionQueryEngine { impl PhysicalPlanner for DatafusionQueryEngine { async fn create_physical_plan( &self, - _ctx: &mut QueryContext, + _: &mut QueryEngineContext, logical_plan: &LogicalPlan, ) -> Result> { let _timer = timer!(metric::METRIC_CREATE_PHYSICAL_ELAPSED); @@ -183,7 +189,7 @@ impl PhysicalPlanner for DatafusionQueryEngine { impl PhysicalOptimizer for DatafusionQueryEngine { fn optimize_physical_plan( &self, - _ctx: &mut QueryContext, + _: &mut QueryEngineContext, plan: Arc, ) -> Result> { let _timer = timer!(metric::METRIC_OPTIMIZE_PHYSICAL_ELAPSED); @@ -211,7 +217,7 @@ impl PhysicalOptimizer for DatafusionQueryEngine { impl QueryExecutor for DatafusionQueryEngine { async fn execute_stream( &self, - ctx: &QueryContext, + ctx: &QueryEngineContext, plan: &Arc, ) -> Result { let _timer = timer!(metric::METRIC_EXEC_PLAN_ELAPSED); @@ -250,6 +256,7 @@ mod tests { use common_recordbatch::util; use datafusion::field_util::{FieldExt, SchemaExt}; use datatypes::arrow::array::UInt64Array; + use session::context::QueryContext; use table::table::numbers::NumbersTable; use crate::query_engine::{QueryEngineFactory, QueryEngineRef}; @@ -277,7 +284,9 @@ mod tests { let engine = create_test_engine(); let sql = "select sum(number) 
from numbers limit 20"; - let plan = engine.sql_to_plan(sql).unwrap(); + let plan = engine + .sql_to_plan(sql, Arc::new(QueryContext::new())) + .unwrap(); assert_eq!( format!("{:?}", plan), @@ -293,7 +302,9 @@ mod tests { let engine = create_test_engine(); let sql = "select sum(number) from numbers limit 20"; - let plan = engine.sql_to_plan(sql).unwrap(); + let plan = engine + .sql_to_plan(sql, Arc::new(QueryContext::new())) + .unwrap(); let output = engine.execute(&plan).await.unwrap(); match output { diff --git a/src/query/src/datafusion/planner.rs b/src/query/src/datafusion/planner.rs index 0f3d00ea2a..6d70109e74 100644 --- a/src/query/src/datafusion/planner.rs +++ b/src/query/src/datafusion/planner.rs @@ -21,6 +21,7 @@ use datafusion::physical_plan::udaf::AggregateUDF; use datafusion::physical_plan::udf::ScalarUDF; use datafusion::sql::planner::{ContextProvider, SqlToRel}; use datatypes::arrow::datatypes::DataType; +use session::context::QueryContextRef; use snafu::ResultExt; use sql::statements::explain::Explain; use sql::statements::query::Query; @@ -85,18 +86,20 @@ where | Statement::CreateDatabase(_) | Statement::Alter(_) | Statement::Insert(_) - | Statement::DropTable(_) => unreachable!(), + | Statement::DropTable(_) + | Statement::Use(_) => unreachable!(), } } } pub(crate) struct DfContextProviderAdapter { state: QueryEngineState, + query_ctx: QueryContextRef, } impl DfContextProviderAdapter { - pub(crate) fn new(state: QueryEngineState) -> Self { - Self { state } + pub(crate) fn new(state: QueryEngineState, query_ctx: QueryContextRef) -> Self { + Self { state, query_ctx } } } @@ -104,11 +107,18 @@ impl DfContextProviderAdapter { /// manage UDFs, UDAFs, variables by ourself in future. impl ContextProvider for DfContextProviderAdapter { fn get_table_provider(&self, name: TableReference) -> Option> { - self.state - .df_context() - .state - .lock() - .get_table_provider(name) + let schema = self.query_ctx.current_schema(); + let execution_ctx = self.state.df_context().state.lock(); + match name { + TableReference::Bare { table } if schema.is_some() => { + execution_ctx.get_table_provider(TableReference::Partial { + // unwrap safety: checked in this match's arm + schema: &schema.unwrap(), + table, + }) + } + _ => execution_ctx.get_table_provider(name), + } } fn get_function_meta(&self, name: &str) -> Option> { diff --git a/src/query/src/executor.rs b/src/query/src/executor.rs index 46eeaf97fa..52664940fb 100644 --- a/src/query/src/executor.rs +++ b/src/query/src/executor.rs @@ -18,14 +18,14 @@ use common_query::physical_plan::PhysicalPlan; use common_recordbatch::SendableRecordBatchStream; use crate::error::Result; -use crate::query_engine::QueryContext; +use crate::query_engine::QueryEngineContext; /// Executor to run [ExecutionPlan]. 
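Note on the `get_table_provider` change above (a self-contained sketch, not part of the patch): when the session has a current schema set via `USE`, bare table names are widened to partial references before lookup, so `select * from demo` resolves inside the session's database. `TableRef` below is a stand-in for DataFusion's `TableReference`:

    #[derive(Debug, PartialEq)]
    enum TableRef<'a> {
        Bare { table: &'a str },
        Partial { schema: &'a str, table: &'a str },
    }

    // Only bare names pick up the session schema; explicitly qualified
    // references are passed through untouched.
    fn qualify<'a>(name: TableRef<'a>, current_schema: Option<&'a str>) -> TableRef<'a> {
        match (name, current_schema) {
            (TableRef::Bare { table }, Some(schema)) => TableRef::Partial { schema, table },
            (name, _) => name,
        }
    }

    fn main() {
        // With `USE my_db` in effect, `demo` resolves to `my_db.demo`.
        let resolved = qualify(TableRef::Bare { table: "demo" }, Some("my_db"));
        assert_eq!(resolved, TableRef::Partial { schema: "my_db", table: "demo" });
    }
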
#[async_trait::async_trait] pub trait QueryExecutor { async fn execute_stream( &self, - ctx: &QueryContext, + ctx: &QueryEngineContext, plan: &Arc, ) -> Result; } diff --git a/src/query/src/lib.rs b/src/query/src/lib.rs index 14aa4f773c..5b25707dc7 100644 --- a/src/query/src/lib.rs +++ b/src/query/src/lib.rs @@ -26,4 +26,6 @@ pub mod planner; pub mod query_engine; pub mod sql; -pub use crate::query_engine::{QueryContext, QueryEngine, QueryEngineFactory, QueryEngineRef}; +pub use crate::query_engine::{ + QueryEngine, QueryEngineContext, QueryEngineFactory, QueryEngineRef, +}; diff --git a/src/query/src/logical_optimizer.rs b/src/query/src/logical_optimizer.rs index 8c35f856a5..266a1a4233 100644 --- a/src/query/src/logical_optimizer.rs +++ b/src/query/src/logical_optimizer.rs @@ -14,12 +14,12 @@ use crate::error::Result; use crate::plan::LogicalPlan; -use crate::query_engine::QueryContext; +use crate::query_engine::QueryEngineContext; pub trait LogicalOptimizer { fn optimize_logical_plan( &self, - ctx: &mut QueryContext, + ctx: &mut QueryEngineContext, plan: &LogicalPlan, ) -> Result; } diff --git a/src/query/src/physical_optimizer.rs b/src/query/src/physical_optimizer.rs index c96d27a7f7..a75c629057 100644 --- a/src/query/src/physical_optimizer.rs +++ b/src/query/src/physical_optimizer.rs @@ -17,12 +17,12 @@ use std::sync::Arc; use common_query::physical_plan::PhysicalPlan; use crate::error::Result; -use crate::query_engine::QueryContext; +use crate::query_engine::QueryEngineContext; pub trait PhysicalOptimizer { fn optimize_physical_plan( &self, - ctx: &mut QueryContext, + ctx: &mut QueryEngineContext, plan: Arc, ) -> Result>; } diff --git a/src/query/src/physical_planner.rs b/src/query/src/physical_planner.rs index 2118f1cc82..40213a1346 100644 --- a/src/query/src/physical_planner.rs +++ b/src/query/src/physical_planner.rs @@ -18,7 +18,7 @@ use common_query::physical_plan::PhysicalPlan; use crate::error::Result; use crate::plan::LogicalPlan; -use crate::query_engine::QueryContext; +use crate::query_engine::QueryEngineContext; /// Physical query planner that converts a `LogicalPlan` to an /// `ExecutionPlan` suitable for execution. 
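Aside, since the same rename repeats across the next few files: after this patch there are two deliberately distinct contexts. A stubbed sketch (all types reduced to unit structs, signatures trimmed; only the names match the patch) of where each one flows:

    use std::sync::Arc;

    struct QueryContext;       // session::context - per connection, carries the current schema
    struct QueryEngineContext; // query::query_engine - per query, engine-internal state
    struct Statement;
    struct LogicalPlan;

    trait EngineSketch {
        // The session context is consumed once, at planning time...
        fn statement_to_plan(&self, stmt: Statement, query_ctx: Arc<QueryContext>) -> LogicalPlan;
        // ...while optimization and execution only ever see the engine context.
        fn optimize_logical_plan(&self, ctx: &mut QueryEngineContext, plan: &LogicalPlan) -> LogicalPlan;
    }

    fn main() {}
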
@@ -27,7 +27,7 @@ pub trait PhysicalPlanner { /// Create a physical plan from a logical plan async fn create_physical_plan( &self, - ctx: &mut QueryContext, + ctx: &mut QueryEngineContext, logical_plan: &LogicalPlan, ) -> Result>; } diff --git a/src/query/src/query_engine.rs b/src/query/src/query_engine.rs index 36d2c54d1a..110f78e6f3 100644 --- a/src/query/src/query_engine.rs +++ b/src/query/src/query_engine.rs @@ -23,12 +23,13 @@ use common_function::scalars::{FunctionRef, FUNCTION_REGISTRY}; use common_query::physical_plan::PhysicalPlan; use common_query::prelude::ScalarUdf; use common_query::Output; +use session::context::QueryContextRef; use sql::statements::statement::Statement; use crate::datafusion::DatafusionQueryEngine; use crate::error::Result; use crate::plan::LogicalPlan; -pub use crate::query_engine::context::QueryContext; +pub use crate::query_engine::context::QueryEngineContext; pub use crate::query_engine::state::QueryEngineState; #[async_trait::async_trait] @@ -37,9 +38,10 @@ pub trait QueryEngine: Send + Sync { fn sql_to_statement(&self, sql: &str) -> Result; - fn statement_to_plan(&self, stmt: Statement) -> Result; + fn statement_to_plan(&self, stmt: Statement, query_ctx: QueryContextRef) + -> Result; - fn sql_to_plan(&self, sql: &str) -> Result; + fn sql_to_plan(&self, sql: &str, query_ctx: QueryContextRef) -> Result; async fn execute(&self, plan: &LogicalPlan) -> Result; diff --git a/src/query/src/query_engine/context.rs b/src/query/src/query_engine/context.rs index c5b5d20c2d..c54cb8b595 100644 --- a/src/query/src/query_engine/context.rs +++ b/src/query/src/query_engine/context.rs @@ -16,11 +16,11 @@ use crate::query_engine::state::QueryEngineState; #[derive(Debug)] -pub struct QueryContext { +pub struct QueryEngineContext { state: QueryEngineState, } -impl QueryContext { +impl QueryEngineContext { pub fn new(state: QueryEngineState) -> Self { Self { state } } diff --git a/src/query/src/sql.rs b/src/query/src/sql.rs index cfc8a39fd1..2854fed7fc 100644 --- a/src/query/src/sql.rs +++ b/src/query/src/sql.rs @@ -22,6 +22,7 @@ use datatypes::prelude::*; use datatypes::schema::{ColumnSchema, Schema}; use datatypes::vectors::{Helper, StringVector}; use once_cell::sync::Lazy; +use session::context::QueryContextRef; use snafu::{ensure, OptionExt, ResultExt}; use sql::statements::describe::DescribeTable; use sql::statements::explain::Explain; @@ -109,7 +110,11 @@ pub fn show_databases(stmt: ShowDatabases, catalog_manager: CatalogManagerRef) - Ok(Output::RecordBatches(records)) } -pub fn show_tables(stmt: ShowTables, catalog_manager: CatalogManagerRef) -> Result { +pub fn show_tables( + stmt: ShowTables, + catalog_manager: CatalogManagerRef, + query_ctx: QueryContextRef, +) -> Result { // TODO(LFC): supports WHERE ensure!( matches!(stmt.kind, ShowKind::All | ShowKind::Like(_)), @@ -118,9 +123,15 @@ pub fn show_tables(stmt: ShowTables, catalog_manager: CatalogManagerRef) -> Resu } ); - let schema = stmt.database.as_deref().unwrap_or(DEFAULT_SCHEMA_NAME); + let schema = if let Some(database) = stmt.database { + database + } else { + query_ctx + .current_schema() + .unwrap_or_else(|| DEFAULT_SCHEMA_NAME.to_string()) + }; let schema = catalog_manager - .schema(DEFAULT_CATALOG_NAME, schema) + .schema(DEFAULT_CATALOG_NAME, &schema) .context(error::CatalogSnafu)? 
.context(error::SchemaNotFoundSnafu { schema })?; let tables = schema.table_names().context(error::CatalogSnafu)?; @@ -141,8 +152,12 @@ pub fn show_tables(stmt: ShowTables, catalog_manager: CatalogManagerRef) -> Resu Ok(Output::RecordBatches(records)) } -pub async fn explain(stmt: Box, query_engine: QueryEngineRef) -> Result { - let plan = query_engine.statement_to_plan(Statement::Explain(*stmt))?; +pub async fn explain( + stmt: Box, + query_engine: QueryEngineRef, + query_ctx: QueryContextRef, +) -> Result { + let plan = query_engine.statement_to_plan(Statement::Explain(*stmt), query_ctx)?; query_engine.execute(&plan).await } diff --git a/src/query/tests/argmax_test.rs b/src/query/tests/argmax_test.rs index 23ff4785ac..11f0167a09 100644 --- a/src/query/tests/argmax_test.rs +++ b/src/query/tests/argmax_test.rs @@ -24,6 +24,7 @@ use datatypes::types::PrimitiveElement; use function::{create_query_engine, get_numbers_from_table}; use query::error::Result; use query::QueryEngine; +use session::context::QueryContext; #[tokio::test] async fn test_argmax_aggregator() -> Result<()> { @@ -95,7 +96,9 @@ async fn execute_argmax<'a>( "select ARGMAX({}) as argmax from {}", column_name, table_name ); - let plan = engine.sql_to_plan(&sql).unwrap(); + let plan = engine + .sql_to_plan(&sql, Arc::new(QueryContext::new())) + .unwrap(); let output = engine.execute(&plan).await.unwrap(); let recordbatch_stream = match output { diff --git a/src/query/tests/argmin_test.rs b/src/query/tests/argmin_test.rs index 0e02f9e4a2..2a509f05fd 100644 --- a/src/query/tests/argmin_test.rs +++ b/src/query/tests/argmin_test.rs @@ -25,6 +25,7 @@ use datatypes::types::PrimitiveElement; use function::{create_query_engine, get_numbers_from_table}; use query::error::Result; use query::QueryEngine; +use session::context::QueryContext; #[tokio::test] async fn test_argmin_aggregator() -> Result<()> { @@ -96,7 +97,9 @@ async fn execute_argmin<'a>( "select argmin({}) as argmin from {}", column_name, table_name ); - let plan = engine.sql_to_plan(&sql).unwrap(); + let plan = engine + .sql_to_plan(&sql, Arc::new(QueryContext::new())) + .unwrap(); let output = engine.execute(&plan).await.unwrap(); let recordbatch_stream = match output { diff --git a/src/query/tests/function.rs b/src/query/tests/function.rs index f5ecba91ee..040dfa7a6b 100644 --- a/src/query/tests/function.rs +++ b/src/query/tests/function.rs @@ -27,6 +27,7 @@ use datatypes::vectors::PrimitiveVector; use query::query_engine::QueryEngineFactory; use query::QueryEngine; use rand::Rng; +use session::context::QueryContext; use table::test_util::MemTable; pub fn create_query_engine() -> Arc { @@ -80,7 +81,9 @@ where for<'a> T: Scalar = T>, { let sql = format!("SELECT {} FROM {}", column_name, table_name); - let plan = engine.sql_to_plan(&sql).unwrap(); + let plan = engine + .sql_to_plan(&sql, Arc::new(QueryContext::new())) + .unwrap(); let output = engine.execute(&plan).await.unwrap(); let recordbatch_stream = match output { diff --git a/src/query/tests/mean_test.rs b/src/query/tests/mean_test.rs index 1b068f2456..705dea797d 100644 --- a/src/query/tests/mean_test.rs +++ b/src/query/tests/mean_test.rs @@ -28,6 +28,7 @@ use function::{create_query_engine, get_numbers_from_table}; use num_traits::AsPrimitive; use query::error::Result; use query::QueryEngine; +use session::context::QueryContext; #[tokio::test] async fn test_mean_aggregator() -> Result<()> { @@ -89,7 +90,9 @@ async fn execute_mean<'a>( engine: Arc, ) -> RecordResult> { let sql = format!("select MEAN({}) as mean from 
{}", column_name, table_name); - let plan = engine.sql_to_plan(&sql).unwrap(); + let plan = engine + .sql_to_plan(&sql, Arc::new(QueryContext::new())) + .unwrap(); let output = engine.execute(&plan).await.unwrap(); let recordbatch_stream = match output { diff --git a/src/query/tests/my_sum_udaf_example.rs b/src/query/tests/my_sum_udaf_example.rs index dbd0427752..4e05183861 100644 --- a/src/query/tests/my_sum_udaf_example.rs +++ b/src/query/tests/my_sum_udaf_example.rs @@ -36,6 +36,7 @@ use datatypes::with_match_primitive_type_id; use num_traits::AsPrimitive; use query::error::Result; use query::QueryEngineFactory; +use session::context::QueryContext; use table::test_util::MemTable; #[derive(Debug, Default)] @@ -228,7 +229,7 @@ where "select MY_SUM({}) as my_sum from {}", column_name, table_name ); - let plan = engine.sql_to_plan(&sql)?; + let plan = engine.sql_to_plan(&sql, Arc::new(QueryContext::new()))?; let output = engine.execute(&plan).await?; let recordbatch_stream = match output { diff --git a/src/query/tests/percentile_test.rs b/src/query/tests/percentile_test.rs index c504da231a..6e210a0494 100644 --- a/src/query/tests/percentile_test.rs +++ b/src/query/tests/percentile_test.rs @@ -30,6 +30,7 @@ use function::{create_query_engine, get_numbers_from_table}; use num_traits::AsPrimitive; use query::error::Result; use query::{QueryEngine, QueryEngineFactory}; +use session::context::QueryContext; use table::test_util::MemTable; #[tokio::test] @@ -53,7 +54,9 @@ async fn test_percentile_aggregator() -> Result<()> { async fn test_percentile_correctness() -> Result<()> { let engine = create_correctness_engine(); let sql = String::from("select PERCENTILE(corr_number,88.0) as percentile from corr_numbers"); - let plan = engine.sql_to_plan(&sql).unwrap(); + let plan = engine + .sql_to_plan(&sql, Arc::new(QueryContext::new())) + .unwrap(); let output = engine.execute(&plan).await.unwrap(); let recordbatch_stream = match output { @@ -113,7 +116,9 @@ async fn execute_percentile<'a>( "select PERCENTILE({},50.0) as percentile from {}", column_name, table_name ); - let plan = engine.sql_to_plan(&sql).unwrap(); + let plan = engine + .sql_to_plan(&sql, Arc::new(QueryContext::new())) + .unwrap(); let output = engine.execute(&plan).await.unwrap(); let recordbatch_stream = match output { diff --git a/src/query/tests/polyval_test.rs b/src/query/tests/polyval_test.rs index 55285e20e1..f2e60c0217 100644 --- a/src/query/tests/polyval_test.rs +++ b/src/query/tests/polyval_test.rs @@ -26,6 +26,7 @@ use function::{create_query_engine, get_numbers_from_table}; use num_traits::AsPrimitive; use query::error::Result; use query::QueryEngine; +use session::context::QueryContext; #[tokio::test] async fn test_polyval_aggregator() -> Result<()> { @@ -92,7 +93,9 @@ async fn execute_polyval<'a>( "select POLYVAL({}, 0) as polyval from {}", column_name, table_name ); - let plan = engine.sql_to_plan(&sql).unwrap(); + let plan = engine + .sql_to_plan(&sql, Arc::new(QueryContext::new())) + .unwrap(); let output = engine.execute(&plan).await.unwrap(); let recordbatch_stream = match output { diff --git a/src/query/tests/query_engine_test.rs b/src/query/tests/query_engine_test.rs index 26afd8c9cc..cf640afba4 100644 --- a/src/query/tests/query_engine_test.rs +++ b/src/query/tests/query_engine_test.rs @@ -37,6 +37,7 @@ use query::plan::LogicalPlan; use query::query_engine::QueryEngineFactory; use query::QueryEngine; use rand::Rng; +use session::context::QueryContext; use table::table::adapter::DfTableProviderAdapter; use 
table::table::numbers::NumbersTable; use table::test_util::MemTable; @@ -134,7 +135,10 @@ async fn test_udf() -> Result<()> { engine.register_udf(udf); - let plan = engine.sql_to_plan("select pow(number, number) as p from numbers limit 10")?; + let plan = engine.sql_to_plan( + "select pow(number, number) as p from numbers limit 10", + Arc::new(QueryContext::new()), + )?; let output = engine.execute(&plan).await?; let recordbatch = match output { @@ -242,7 +246,9 @@ where for<'a> T: Scalar = T>, { let sql = format!("SELECT {} FROM {}", column_name, table_name); - let plan = engine.sql_to_plan(&sql).unwrap(); + let plan = engine + .sql_to_plan(&sql, Arc::new(QueryContext::new())) + .unwrap(); let output = engine.execute(&plan).await.unwrap(); let recordbatch_stream = match output { @@ -330,7 +336,9 @@ async fn execute_median<'a>( "select MEDIAN({}) as median from {}", column_name, table_name ); - let plan = engine.sql_to_plan(&sql).unwrap(); + let plan = engine + .sql_to_plan(&sql, Arc::new(QueryContext::new())) + .unwrap(); let output = engine.execute(&plan).await.unwrap(); let recordbatch_stream = match output { diff --git a/src/query/tests/scipy_stats_norm_cdf_test.rs b/src/query/tests/scipy_stats_norm_cdf_test.rs index 572a433683..815501a314 100644 --- a/src/query/tests/scipy_stats_norm_cdf_test.rs +++ b/src/query/tests/scipy_stats_norm_cdf_test.rs @@ -26,6 +26,7 @@ use function::{create_query_engine, get_numbers_from_table}; use num_traits::AsPrimitive; use query::error::Result; use query::QueryEngine; +use session::context::QueryContext; use statrs::distribution::{ContinuousCDF, Normal}; use statrs::statistics::Statistics; @@ -94,7 +95,9 @@ async fn execute_scipy_stats_norm_cdf<'a>( "select SCIPYSTATSNORMCDF({},2.0) as scipy_stats_norm_cdf from {}", column_name, table_name ); - let plan = engine.sql_to_plan(&sql).unwrap(); + let plan = engine + .sql_to_plan(&sql, Arc::new(QueryContext::new())) + .unwrap(); let output = engine.execute(&plan).await.unwrap(); let recordbatch_stream = match output { diff --git a/src/query/tests/scipy_stats_norm_pdf.rs b/src/query/tests/scipy_stats_norm_pdf.rs index efbf0ddec3..dd5e0fc7fc 100644 --- a/src/query/tests/scipy_stats_norm_pdf.rs +++ b/src/query/tests/scipy_stats_norm_pdf.rs @@ -26,6 +26,7 @@ use function::{create_query_engine, get_numbers_from_table}; use num_traits::AsPrimitive; use query::error::Result; use query::QueryEngine; +use session::context::QueryContext; use statrs::distribution::{Continuous, Normal}; use statrs::statistics::Statistics; @@ -94,7 +95,9 @@ async fn execute_scipy_stats_norm_pdf<'a>( "select SCIPYSTATSNORMPDF({},2.0) as scipy_stats_norm_pdf from {}", column_name, table_name ); - let plan = engine.sql_to_plan(&sql).unwrap(); + let plan = engine + .sql_to_plan(&sql, Arc::new(QueryContext::new())) + .unwrap(); let output = engine.execute(&plan).await.unwrap(); let recordbatch_stream = match output { diff --git a/src/script/Cargo.toml b/src/script/Cargo.toml index cde4f8391c..57eaeddc98 100644 --- a/src/script/Cargo.toml +++ b/src/script/Cargo.toml @@ -48,6 +48,7 @@ rustpython-vm = { git = "https://github.com/RustPython/RustPython", optional = t "default", "freeze-stdlib", ] } +session = { path = "../session" } snafu = { version = "0.7", features = ["backtraces"] } sql = { path = "../sql" } table = { path = "../table" } diff --git a/src/script/src/python/engine.rs b/src/script/src/python/engine.rs index 7a6c02c858..7ad5390f7b 100644 --- a/src/script/src/python/engine.rs +++ b/src/script/src/python/engine.rs @@ -26,6 +26,7 @@ 
use common_recordbatch::{RecordBatch, RecordBatchStream, SendableRecordBatchStre use datatypes::schema::SchemaRef; use futures::Stream; use query::QueryEngineRef; +use session::context::QueryContext; use snafu::{ensure, ResultExt}; use sql::statements::statement::Statement; @@ -93,7 +94,9 @@ impl Script for PyScript { matches!(stmt, Statement::Query { .. }), error::UnsupportedSqlSnafu { sql } ); - let plan = self.query_engine.statement_to_plan(stmt)?; + let plan = self + .query_engine + .statement_to_plan(stmt, Arc::new(QueryContext::new()))?; let res = self.query_engine.execute(&plan).await?; let copr = self.copr.clone(); match res { diff --git a/src/script/src/table.rs b/src/script/src/table.rs index 6eac358244..224c98da94 100644 --- a/src/script/src/table.rs +++ b/src/script/src/table.rs @@ -28,6 +28,7 @@ use datatypes::prelude::{ConcreteDataType, ScalarVector}; use datatypes::schema::{ColumnSchema, Schema, SchemaBuilder}; use datatypes::vectors::{StringVector, TimestampVector, VectorRef}; use query::QueryEngineRef; +use session::context::QueryContext; use snafu::{ensure, OptionExt, ResultExt}; use table::requests::{CreateTableRequest, InsertRequest}; @@ -151,7 +152,7 @@ impl ScriptsTable { let plan = self .query_engine - .sql_to_plan(&sql) + .sql_to_plan(&sql, Arc::new(QueryContext::new())) .context(FindScriptSnafu { name })?; let stream = match self diff --git a/src/servers/Cargo.toml b/src/servers/Cargo.toml index 49163bd682..1315728868 100644 --- a/src/servers/Cargo.toml +++ b/src/servers/Cargo.toml @@ -38,6 +38,7 @@ rand = "0.8" schemars = "0.8" serde = "1.0" serde_json = "1.0" +session = { path = "../session" } snafu = { version = "0.7", features = ["backtraces"] } snap = "1" table = { path = "../table" } diff --git a/src/servers/src/http.rs b/src/servers/src/http.rs index 0c426874d4..d7bb2a5239 100644 --- a/src/servers/src/http.rs +++ b/src/servers/src/http.rs @@ -487,6 +487,7 @@ mod test { use datatypes::prelude::*; use datatypes::schema::{ColumnSchema, Schema}; use datatypes::vectors::{StringVector, UInt32Vector}; + use session::context::QueryContextRef; use tokio::sync::mpsc; use super::*; @@ -498,7 +499,7 @@ mod test { #[async_trait] impl SqlQueryHandler for DummyInstance { - async fn do_query(&self, _query: &str) -> Result { + async fn do_query(&self, _: &str, _: QueryContextRef) -> Result { unimplemented!() } } diff --git a/src/servers/src/http/handler.rs b/src/servers/src/http/handler.rs index 69b41edd1b..77c68d38d8 100644 --- a/src/servers/src/http/handler.rs +++ b/src/servers/src/http/handler.rs @@ -13,6 +13,7 @@ // limitations under the License. use std::collections::HashMap; +use std::sync::Arc; use std::time::Instant; use aide::transform::TransformOperation; @@ -21,6 +22,7 @@ use common_error::status_code::StatusCode; use common_telemetry::metric; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; +use session::context::QueryContext; use crate::http::{ApiState, JsonResponse}; @@ -39,7 +41,9 @@ pub async fn sql( let sql_handler = &state.sql_handler; let start = Instant::now(); let resp = if let Some(sql) = ¶ms.sql { - JsonResponse::from_output(sql_handler.do_query(sql).await).await + // TODO(LFC): Sessions in http server. 
+ let query_ctx = Arc::new(QueryContext::new()); + JsonResponse::from_output(sql_handler.do_query(sql, query_ctx).await).await } else { JsonResponse::with_error( "sql parameter is required.".to_string(), diff --git a/src/servers/src/mysql/federated.rs b/src/servers/src/mysql/federated.rs index 8aa3f369fe..f2f1a8caed 100644 --- a/src/servers/src/mysql/federated.rs +++ b/src/servers/src/mysql/federated.rs @@ -26,21 +26,25 @@ use datatypes::vectors::StringVector; use once_cell::sync::Lazy; use regex::bytes::RegexSet; use regex::Regex; +use session::context::QueryContextRef; // TODO(LFC): Include GreptimeDB's version and git commit tag etc. const MYSQL_VERSION: &str = "8.0.26"; static SELECT_VAR_PATTERN: Lazy = Lazy::new(|| Regex::new("(?i)^(SELECT @@(.*))").unwrap()); static MYSQL_CONN_JAVA_PATTERN: Lazy = - Lazy::new(|| Regex::new("(?i)^(/\\* mysql-connector-java(.*))").unwrap()); + Lazy::new(|| Regex::new("(?i)^(/\\* mysql-connector-j(.*))").unwrap()); static SHOW_LOWER_CASE_PATTERN: Lazy = Lazy::new(|| Regex::new("(?i)^(SHOW VARIABLES LIKE 'lower_case_table_names'(.*))").unwrap()); static SHOW_COLLATION_PATTERN: Lazy = Lazy::new(|| Regex::new("(?i)^(show collation where(.*))").unwrap()); static SHOW_VARIABLES_PATTERN: Lazy = Lazy::new(|| Regex::new("(?i)^(SHOW VARIABLES(.*))").unwrap()); + static SELECT_VERSION_PATTERN: Lazy = Lazy::new(|| Regex::new(r"(?i)^(SELECT VERSION\(\s*\))").unwrap()); +static SELECT_DATABASE_PATTERN: Lazy = + Lazy::new(|| Regex::new(r"(?i)^(SELECT DATABASE\(\s*\))").unwrap()); // SELECT TIMEDIFF(NOW(), UTC_TIMESTAMP()); static SELECT_TIME_DIFF_FUNC_PATTERN: Lazy = @@ -248,13 +252,18 @@ fn check_show_variables(query: &str) -> Option { } // Check for SET or others query, this is the final check of the federated query. -fn check_others(query: &str) -> Option { +fn check_others(query: &str, query_ctx: QueryContextRef) -> Option { if OTHER_NOT_SUPPORTED_STMT.is_match(query.as_bytes()) { return Some(Output::RecordBatches(RecordBatches::empty())); } let recordbatches = if SELECT_VERSION_PATTERN.is_match(query) { Some(select_function("version()", MYSQL_VERSION)) + } else if SELECT_DATABASE_PATTERN.is_match(query) { + let schema = query_ctx + .current_schema() + .unwrap_or_else(|| "NULL".to_string()); + Some(select_function("database()", &schema)) } else if SELECT_TIME_DIFF_FUNC_PATTERN.is_match(query) { Some(select_function( "TIMEDIFF(NOW(), UTC_TIMESTAMP())", @@ -268,7 +277,7 @@ fn check_others(query: &str) -> Option { // Check whether the query is a federated or driver setup command, // and return some faked results if there are any. -pub fn check(query: &str) -> Option { +pub(crate) fn check(query: &str, query_ctx: QueryContextRef) -> Option { // First to check the query is like "select @@variables". let output = check_select_variable(query); if output.is_some() { @@ -282,25 +291,27 @@ pub fn check(query: &str) -> Option { } // Last check. 
- check_others(query) + check_others(query, query_ctx) } #[cfg(test)] mod test { + use session::context::QueryContext; + use super::*; #[test] fn test_check() { let query = "select 1"; - let result = check(query); + let result = check(query, Arc::new(QueryContext::new())); assert!(result.is_none()); let query = "select versiona"; - let output = check(query); + let output = check(query, Arc::new(QueryContext::new())); assert!(output.is_none()); fn test(query: &str, expected: Vec<&str>) { - let output = check(query); + let output = check(query, Arc::new(QueryContext::new())); match output.unwrap() { Output::RecordBatches(r) => { assert_eq!(r.pretty_print().lines().collect::>(), expected) diff --git a/src/servers/src/mysql/handler.rs b/src/servers/src/mysql/handler.rs index c1614377a7..2884b3e4bf 100644 --- a/src/servers/src/mysql/handler.rs +++ b/src/servers/src/mysql/handler.rs @@ -16,11 +16,13 @@ use std::sync::Arc; use std::time::Instant; use async_trait::async_trait; +use common_query::Output; use common_telemetry::{debug, error}; use opensrv_mysql::{ - AsyncMysqlShim, ErrorKind, ParamParser, QueryResultWriter, StatementMetaWriter, + AsyncMysqlShim, ErrorKind, InitWriter, ParamParser, QueryResultWriter, StatementMetaWriter, }; use rand::RngCore; +use session::Session; use tokio::io::AsyncWrite; use tokio::sync::RwLock; @@ -36,7 +38,9 @@ pub struct MysqlInstanceShim { query_handler: SqlQueryHandlerRef, salt: [u8; 20], client_addr: String, + // TODO(LFC): Break `Context` struct into different fields in `Session`, each with its own purpose. ctx: Arc>>, + session: Arc, } impl MysqlInstanceShim { @@ -59,8 +63,33 @@ impl MysqlInstanceShim { salt: scramble, client_addr, ctx: Arc::new(RwLock::new(None)), + session: Arc::new(Session::new()), } } + + async fn do_query(&self, query: &str) -> Result { + debug!("Start executing query: '{}'", query); + let start = Instant::now(); + + // TODO(LFC): Find a better way to deal with these special federated queries: + // `check` uses regex to filter out unsupported statements emitted by MySQL's federated + // components, this is quick and dirty, there must be a better way to do it. + let output = + if let Some(output) = crate::mysql::federated::check(query, self.session.context()) { + Ok(output) + } else { + self.query_handler + .do_query(query, self.session.context()) + .await + }; + + debug!( + "Finished executing query: '{}', total time costs in microseconds: {}", + query, + start.elapsed().as_micros() + ); + output + } } #[async_trait] @@ -144,25 +173,20 @@ impl AsyncMysqlShim for MysqlInstanceShi query: &'a str, writer: QueryResultWriter<'a, W>, ) -> Result<()> { - debug!("Start executing query: '{}'", query); - let start = Instant::now(); - - // TODO(LFC): Find a better way: - // `check` uses regex to filter out unsupported statements emitted by MySQL's federated - // components, this is quick and dirty, there must be a better way to do it. 
- let output = if let Some(output) = crate::mysql::federated::check(query) { - Ok(output) - } else { - self.query_handler.do_query(query).await - }; - - debug!( - "Finished executing query: '{}', total time costs in microseconds: {}", - query, - start.elapsed().as_micros() - ); - + let output = self.do_query(query).await; let mut writer = MysqlResultWriter::new(writer); writer.write(query, output).await } + + async fn on_init<'a>(&'a mut self, database: &'a str, w: InitWriter<'a, W>) -> Result<()> { + let query = format!("USE {}", database.trim()); + let output = self.do_query(&query).await; + if let Err(e) = output { + w.error(ErrorKind::ER_UNKNOWN_ERROR, e.to_string().as_bytes()) + .await + } else { + w.ok().await + } + .map_err(|e| e.into()) + } } diff --git a/src/servers/src/postgres/handler.rs b/src/servers/src/postgres/handler.rs index 59c6cc2ea8..66d7099522 100644 --- a/src/servers/src/postgres/handler.rs +++ b/src/servers/src/postgres/handler.rs @@ -13,6 +13,7 @@ // limitations under the License. use std::ops::Deref; +use std::sync::Arc; use async_trait::async_trait; use common_query::Output; @@ -26,6 +27,7 @@ use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler}; use pgwire::api::results::{text_query_response, FieldInfo, Response, Tag, TextDataRowEncoder}; use pgwire::api::{ClientInfo, Type}; use pgwire::error::{PgWireError, PgWireResult}; +use session::context::QueryContext; use crate::error::{self, Error, Result}; use crate::query_handler::SqlQueryHandlerRef; @@ -46,9 +48,11 @@ impl SimpleQueryHandler for PostgresServerHandler { where C: ClientInfo + Unpin + Send + Sync, { + // TODO(LFC): Sessions in pg server. + let query_ctx = Arc::new(QueryContext::new()); let output = self .query_handler - .do_query(query) + .do_query(query, query_ctx) .await .map_err(|e| PgWireError::ApiError(Box::new(e)))?; diff --git a/src/servers/src/query_handler.rs b/src/servers/src/query_handler.rs index ff76bebdc5..d9a48ba30f 100644 --- a/src/servers/src/query_handler.rs +++ b/src/servers/src/query_handler.rs @@ -18,6 +18,7 @@ use api::prometheus::remote::{ReadRequest, WriteRequest}; use api::v1::{AdminExpr, AdminResult, ObjectExpr, ObjectResult}; use async_trait::async_trait; use common_query::Output; +use session::context::QueryContextRef; use crate::error::Result; use crate::influxdb::InfluxdbRequest; @@ -44,7 +45,7 @@ pub type ScriptHandlerRef = Arc; #[async_trait] pub trait SqlQueryHandler { - async fn do_query(&self, query: &str) -> Result; + async fn do_query(&self, query: &str, query_ctx: QueryContextRef) -> Result; } #[async_trait] diff --git a/src/servers/tests/http/influxdb_test.rs b/src/servers/tests/http/influxdb_test.rs index 16472d3aeb..e81df37e66 100644 --- a/src/servers/tests/http/influxdb_test.rs +++ b/src/servers/tests/http/influxdb_test.rs @@ -23,6 +23,7 @@ use servers::error::Result; use servers::http::{HttpOptions, HttpServer}; use servers::influxdb::InfluxdbRequest; use servers::query_handler::{InfluxdbLineProtocolHandler, SqlQueryHandler}; +use session::context::QueryContextRef; use tokio::sync::mpsc; struct DummyInstance { @@ -44,7 +45,7 @@ impl InfluxdbLineProtocolHandler for DummyInstance { #[async_trait] impl SqlQueryHandler for DummyInstance { - async fn do_query(&self, _query: &str) -> Result { + async fn do_query(&self, _: &str, _: QueryContextRef) -> Result { unimplemented!() } } diff --git a/src/servers/tests/http/opentsdb_test.rs b/src/servers/tests/http/opentsdb_test.rs index 0ce4465e2a..3b51f66965 100644 --- a/src/servers/tests/http/opentsdb_test.rs 
+++ b/src/servers/tests/http/opentsdb_test.rs @@ -22,6 +22,7 @@ use servers::error::{self, Result}; use servers::http::{HttpOptions, HttpServer}; use servers::opentsdb::codec::DataPoint; use servers::query_handler::{OpentsdbProtocolHandler, SqlQueryHandler}; +use session::context::QueryContextRef; use tokio::sync::mpsc; struct DummyInstance { @@ -44,7 +45,7 @@ impl OpentsdbProtocolHandler for DummyInstance { #[async_trait] impl SqlQueryHandler for DummyInstance { - async fn do_query(&self, _query: &str) -> Result { + async fn do_query(&self, _: &str, _: QueryContextRef) -> Result { unimplemented!() } } diff --git a/src/servers/tests/http/prometheus_test.rs b/src/servers/tests/http/prometheus_test.rs index ce1cd017a5..b7df350505 100644 --- a/src/servers/tests/http/prometheus_test.rs +++ b/src/servers/tests/http/prometheus_test.rs @@ -27,6 +27,7 @@ use servers::http::{HttpOptions, HttpServer}; use servers::prometheus; use servers::prometheus::{snappy_compress, Metrics}; use servers::query_handler::{PrometheusProtocolHandler, PrometheusResponse, SqlQueryHandler}; +use session::context::QueryContextRef; use tokio::sync::mpsc; struct DummyInstance { @@ -69,7 +70,7 @@ impl PrometheusProtocolHandler for DummyInstance { #[async_trait] impl SqlQueryHandler for DummyInstance { - async fn do_query(&self, _query: &str) -> Result { + async fn do_query(&self, _: &str, _: QueryContextRef) -> Result { unimplemented!() } } diff --git a/src/servers/tests/mod.rs b/src/servers/tests/mod.rs index a5663dddd1..63c8e2ebe2 100644 --- a/src/servers/tests/mod.rs +++ b/src/servers/tests/mod.rs @@ -31,6 +31,8 @@ mod http; mod mysql; use script::engine::{CompileContext, EvalContext, Script, ScriptEngine}; use script::python::{PyEngine, PyScript}; +use session::context::QueryContextRef; + mod opentsdb; mod postgres; @@ -52,8 +54,8 @@ impl DummyInstance { #[async_trait] impl SqlQueryHandler for DummyInstance { - async fn do_query(&self, query: &str) -> Result { - let plan = self.query_engine.sql_to_plan(query).unwrap(); + async fn do_query(&self, query: &str, query_ctx: QueryContextRef) -> Result { + let plan = self.query_engine.sql_to_plan(query, query_ctx).unwrap(); Ok(self.query_engine.execute(&plan).await.unwrap()) } } diff --git a/src/session/Cargo.toml b/src/session/Cargo.toml new file mode 100644 index 0000000000..0e4e0b1591 --- /dev/null +++ b/src/session/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "session" +version = "0.1.0" +edition = "2021" +license = "Apache-2.0" + +[dependencies] +arc-swap = "1.5" +common-telemetry = { path = "../common/telemetry" } diff --git a/src/session/src/context.rs b/src/session/src/context.rs new file mode 100644 index 0000000000..aec55ac941 --- /dev/null +++ b/src/session/src/context.rs @@ -0,0 +1,56 @@ +// Copyright 2022 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
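Before the new `session` crate sources below, a quick usage sketch of why `ArcSwapOption` fits here (not part of the patch): the current schema is read on every statement but written only by `USE`, and `ArcSwapOption` gives lock-free reads of an `Option<Arc<String>>` that can be swapped atomically.

    use std::sync::Arc;
    use arc_swap::ArcSwapOption;

    fn main() {
        let current_schema: ArcSwapOption<String> = ArcSwapOption::new(None);
        assert!(current_schema.load().is_none());

        // `USE my_db` swaps in the new schema and returns the old one.
        let old = current_schema.swap(Some(Arc::new("my_db".to_string())));
        assert!(old.is_none());

        // Readers clone the String out, matching QueryContext::current_schema below.
        let schema: Option<String> = current_schema.load().as_deref().cloned();
        assert_eq!(schema.as_deref(), Some("my_db"));
    }
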
+ +use std::sync::Arc; + +use arc_swap::ArcSwapOption; +use common_telemetry::info; + +pub type QueryContextRef = Arc; + +pub struct QueryContext { + current_schema: ArcSwapOption, +} + +impl Default for QueryContext { + fn default() -> Self { + Self::new() + } +} + +impl QueryContext { + pub fn new() -> Self { + Self { + current_schema: ArcSwapOption::new(None), + } + } + + pub fn with_current_schema(schema: String) -> Self { + Self { + current_schema: ArcSwapOption::new(Some(Arc::new(schema))), + } + } + + pub fn current_schema(&self) -> Option { + self.current_schema.load().as_deref().cloned() + } + + pub fn set_current_schema(&self, schema: &str) { + let last = self.current_schema.swap(Some(Arc::new(schema.to_string()))); + info!( + "set new session default schema: {:?}, swap old: {:?}", + schema, last + ) + } +} diff --git a/src/session/src/lib.rs b/src/session/src/lib.rs new file mode 100644 index 0000000000..57437c3057 --- /dev/null +++ b/src/session/src/lib.rs @@ -0,0 +1,36 @@ +// Copyright 2022 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +pub mod context; + +use std::sync::Arc; + +use crate::context::{QueryContext, QueryContextRef}; + +#[derive(Default)] +pub struct Session { + query_ctx: QueryContextRef, +} + +impl Session { + pub fn new() -> Self { + Session { + query_ctx: Arc::new(QueryContext::new()), + } + } + + pub fn context(&self) -> QueryContextRef { + self.query_ctx.clone() + } +} diff --git a/src/sql/src/parser.rs b/src/sql/src/parser.rs index 4665ec6243..254982e88e 100644 --- a/src/sql/src/parser.rs +++ b/src/sql/src/parser.rs @@ -102,6 +102,21 @@ impl<'a> ParserContext<'a> { Keyword::DROP => self.parse_drop(), + // TODO(LFC): Use "Keyword::USE" when we can upgrade to newer version of crate sqlparser. + Keyword::NoKeyword if w.value.to_lowercase() == "use" => { + self.parser.next_token(); + + let database_name = + self.parser + .parse_identifier() + .context(error::UnexpectedSnafu { + sql: self.sql, + expected: "a database name", + actual: self.peek_token_as_string(), + })?; + Ok(Statement::Use(database_name.value)) + } + // todo(hl) support more statements. _ => self.unsupported(self.peek_token_as_string()), } diff --git a/src/sql/src/statements.rs b/src/sql/src/statements.rs index 89b9831b5d..bcdc099265 100644 --- a/src/sql/src/statements.rs +++ b/src/sql/src/statements.rs @@ -42,6 +42,8 @@ use crate::error::{ SerializeColumnDefaultConstraintSnafu, UnsupportedDefaultValueSnafu, }; +// TODO(LFC): Get rid of this function, use session context aware version of "table_idents_to_full_name" instead. +// Current obstacles remain in some usage in Frontend, and other SQLs like "describe", "drop" etc. /// Converts maybe fully-qualified table name (`..
` or `<table>
` when /// catalog and schema are default) to tuple. pub fn table_idents_to_full_name(obj_name: &ObjectName) -> Result<(String, String, String)> { diff --git a/src/sql/src/statements/insert.rs b/src/sql/src/statements/insert.rs index e94e512d15..410c0d09cb 100644 --- a/src/sql/src/statements/insert.rs +++ b/src/sql/src/statements/insert.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use sqlparser::ast::{SetExpr, Statement, UnaryOperator, Values}; +use sqlparser::ast::{ObjectName, SetExpr, Statement, UnaryOperator, Values}; use sqlparser::parser::ParserError; use crate::ast::{Expr, Value}; @@ -33,6 +33,13 @@ impl Insert { } } + pub fn table_name(&self) -> &ObjectName { + match &self.inner { + Statement::Insert { table_name, .. } => table_name, + _ => unreachable!(), + } + } + pub fn columns(&self) -> Vec<&String> { match &self.inner { Statement::Insert { columns, .. } => columns.iter().map(|ident| &ident.value).collect(), @@ -110,15 +117,6 @@ mod tests { use super::*; use crate::parser::ParserContext; - #[test] - pub fn test_insert_convert() { - let sql = r"INSERT INTO tables_0 VALUES ( 'field_0', 0) "; - let mut stmts = ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap(); - assert_eq!(1, stmts.len()); - let insert = stmts.pop().unwrap(); - let _stmt: Statement = insert.try_into().unwrap(); - } - #[test] fn test_insert_value_with_unary_op() { use crate::statements::statement::Statement; diff --git a/src/sql/src/statements/statement.rs b/src/sql/src/statements/statement.rs index a350f5e22d..e1c8d731bb 100644 --- a/src/sql/src/statements/statement.rs +++ b/src/sql/src/statements/statement.rs @@ -12,9 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -use sqlparser::ast::Statement as SpStatement; -use sqlparser::parser::ParserError; - use crate::statements::alter::AlterTable; use crate::statements::create::{CreateDatabase, CreateTable}; use crate::statements::describe::DescribeTable; @@ -50,37 +47,7 @@ pub enum Statement { DescribeTable(DescribeTable), // EXPLAIN QUERY Explain(Explain), -} - -/// Converts Statement to sqlparser statement -impl TryFrom for SpStatement { - type Error = sqlparser::parser::ParserError; - - fn try_from(value: Statement) -> Result { - match value { - Statement::ShowDatabases(_) => Err(ParserError::ParserError( - "sqlparser does not support SHOW DATABASE query.".to_string(), - )), - Statement::ShowTables(_) => Err(ParserError::ParserError( - "sqlparser does not support SHOW TABLES query.".to_string(), - )), - Statement::ShowCreateTable(_) => Err(ParserError::ParserError( - "sqlparser does not support SHOW CREATE TABLE query.".to_string(), - )), - Statement::DescribeTable(_) => Err(ParserError::ParserError( - "sqlparser does not support DESCRIBE TABLE query.".to_string(), - )), - Statement::DropTable(_) => Err(ParserError::ParserError( - "sqlparser does not support DROP TABLE query.".to_string(), - )), - Statement::Query(s) => Ok(SpStatement::Query(Box::new(s.inner))), - Statement::Insert(i) => Ok(i.inner), - Statement::CreateDatabase(_) | Statement::CreateTable(_) | Statement::Alter(_) => { - unimplemented!() - } - Statement::Explain(e) => Ok(e.inner), - } - } + Use(String), } /// Comment hints from SQL. 
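The `Keyword::NoKeyword` workaround in parser.rs above parses `USE <db>` by hand until the sqlparser crate can be upgraded. A hedged test sketch for that branch (module and test names are hypothetical; the parser entry point matches the patch):

    #[cfg(test)]
    mod use_stmt_tests {
        use sqlparser::dialect::GenericDialect;

        use crate::parser::ParserContext;
        use crate::statements::statement::Statement;

        #[test]
        fn parse_use_statement() {
            let mut stmts =
                ParserContext::create_with_dialect("USE my_db", &GenericDialect {}).unwrap();
            assert_eq!(1, stmts.len());
            assert!(matches!(stmts.remove(0), Statement::Use(db) if db == "my_db"));
        }
    }
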
@@ -92,24 +59,3 @@ pub struct Hint { pub comment: String, pub prefix: String, } - -#[cfg(test)] -mod tests { - use std::assert_matches::assert_matches; - - use sqlparser::dialect::GenericDialect; - - use super::*; - use crate::parser::ParserContext; - - #[test] - pub fn test_statement_convert() { - let sql = "SELECT * FROM table_0"; - let mut stmts = ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap(); - assert_eq!(1, stmts.len()); - let x = stmts.remove(0); - let statement = SpStatement::try_from(x).unwrap(); - - assert_matches!(statement, SpStatement::Query { .. }); - } -} diff --git a/src/table/src/engine.rs b/src/table/src/engine.rs index 0929e29b45..55f68c31cf 100644 --- a/src/table/src/engine.rs +++ b/src/table/src/engine.rs @@ -26,6 +26,27 @@ pub struct TableReference<'a> { pub table: &'a str, } +// TODO(LFC): Find a better place for `TableReference`, +// so that we can reuse the default catalog and schema consts. +// Could be done together with issue #559. +impl<'a> TableReference<'a> { + pub fn bare(table: &'a str) -> Self { + TableReference { + catalog: "greptime", + schema: "public", + table, + } + } + + pub fn full(catalog: &'a str, schema: &'a str, table: &'a str) -> Self { + TableReference { + catalog, + schema, + table, + } + } +} + impl<'a> Display for TableReference<'a> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{}.{}.{}", self.catalog, self.schema, self.table)
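Finally, a usage sketch for the two constructors added above (not part of the patch): `bare` falls back to the hard-coded "greptime"/"public" defaults that the TODO wants to deduplicate, and the existing `Display` impl renders the dotted form.

    use table::engine::TableReference;

    fn main() {
        let bare = TableReference::bare("demo");
        assert_eq!(bare.to_string(), "greptime.public.demo");

        let full = TableReference::full("greptime", "my_db", "demo");
        assert_eq!(full.to_string(), "greptime.my_db.demo");
    }
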