From b77b561bc8aa793a0bfaceebe290de4af680248c Mon Sep 17 00:00:00 2001 From: LFC Date: Thu, 23 Mar 2023 10:38:26 +0800 Subject: [PATCH] refactor: execute insert with select in query engine (#1181) * refactor: execute insert with select in query engine * fix: resolve PR comments --- Cargo.lock | 1 + src/datanode/Cargo.toml | 1 + src/datanode/src/error.rs | 23 -- src/datanode/src/instance.rs | 1 - src/datanode/src/instance/grpc.rs | 10 +- src/datanode/src/instance/sql.rs | 49 ++-- src/datanode/src/sql.rs | 243 +----------------- src/datanode/src/sql/insert.rs | 184 +------------ src/datanode/src/tests/instance_test.rs | 36 ++- src/datanode/src/tests/test_util.rs | 17 +- src/frontend/src/instance.rs | 80 +++--- src/frontend/src/instance/distributed.rs | 15 +- src/frontend/src/instance/grpc.rs | 2 +- src/frontend/src/lib.rs | 1 - src/frontend/src/sql.rs | 130 ---------- src/query/src/datafusion.rs | 111 +++++++- src/query/src/error.rs | 12 +- src/query/src/plan.rs | 4 +- src/query/src/query_engine.rs | 2 +- src/query/src/tests.rs | 2 +- src/query/src/tests/query_engine_test.rs | 2 +- src/script/src/python/engine.rs | 2 +- src/script/src/python/ffi_types/copr.rs | 3 +- src/script/src/table.rs | 2 +- src/servers/tests/mod.rs | 4 +- src/sql/src/statements/statement.rs | 1 + src/table/src/error.rs | 18 +- src/table/src/requests.rs | 13 +- src/table/src/requests/insert.rs | 79 ++++++ .../common/insert/insert_select.result | 45 ++++ .../common/insert/insert_select.sql | 19 ++ 31 files changed, 399 insertions(+), 713 deletions(-) delete mode 100644 src/frontend/src/sql.rs create mode 100644 src/table/src/requests/insert.rs create mode 100644 tests/cases/standalone/common/insert/insert_select.result create mode 100644 tests/cases/standalone/common/insert/insert_select.sql diff --git a/Cargo.lock b/Cargo.lock index fd3e6268f3..efb1f200cf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2294,6 +2294,7 @@ dependencies = [ "common-catalog", "common-datasource", "common-error", + "common-function", "common-grpc", "common-grpc-expr", "common-procedure", diff --git a/src/datanode/Cargo.toml b/src/datanode/Cargo.toml index 0cf7355aba..461de8e8af 100644 --- a/src/datanode/Cargo.toml +++ b/src/datanode/Cargo.toml @@ -21,6 +21,7 @@ common-base = { path = "../common/base" } common-catalog = { path = "../common/catalog" } common-error = { path = "../common/error" } common-datasource = { path = "../common/datasource" } +common-function = { path = "../common/function" } common-grpc = { path = "../common/grpc" } common-grpc-expr = { path = "../common/grpc-expr" } common-procedure = { path = "../common/procedure" } diff --git a/src/datanode/src/error.rs b/src/datanode/src/error.rs index 4b03ee937d..e7a85a58b7 100644 --- a/src/datanode/src/error.rs +++ b/src/datanode/src/error.rs @@ -17,9 +17,7 @@ use std::any::Any; use common_datasource::error::Error as DataSourceError; use common_error::prelude::*; use common_procedure::ProcedureId; -use common_recordbatch::error::Error as RecordBatchError; use datafusion::parquet; -use datatypes::prelude::ConcreteDataType; use storage::error::Error as StorageError; use table::error::Error as TableError; use url::ParseError; @@ -125,24 +123,6 @@ pub enum Error { ))] ColumnValuesNumberMismatch { columns: usize, values: usize }, - #[snafu(display( - "Column type mismatch, column: {}, expected type: {:?}, actual: {:?}", - column, - expected, - actual, - ))] - ColumnTypeMismatch { - column: String, - expected: ConcreteDataType, - actual: ConcreteDataType, - }, - - #[snafu(display("Failed to 
collect record batch, source: {}", source))] - CollectRecords { - #[snafu(backtrace)] - source: RecordBatchError, - }, - #[snafu(display("Failed to parse sql value, source: {}", source))] ParseSqlValue { #[snafu(backtrace)] @@ -556,8 +536,6 @@ impl ErrorExt for Error { Insert { source, .. } => source.status_code(), Delete { source, .. } => source.status_code(), - CollectRecords { source, .. } => source.status_code(), - TableNotFound { .. } => StatusCode::TableNotFound, ColumnNotFound { .. } => StatusCode::TableColumnNotFound, @@ -570,7 +548,6 @@ impl ErrorExt for Error { ConvertSchema { source, .. } | VectorComputation { source } => source.status_code(), ColumnValuesNumberMismatch { .. } - | ColumnTypeMismatch { .. } | InvalidSql { .. } | InvalidUrl { .. } | InvalidPath { .. } diff --git a/src/datanode/src/instance.rs b/src/datanode/src/instance.rs index 62ffc9edee..99d1b7c998 100644 --- a/src/datanode/src/instance.rs +++ b/src/datanode/src/instance.rs @@ -204,7 +204,6 @@ impl Instance { sql_handler: SqlHandler::new( table_engine.clone(), catalog_manager.clone(), - query_engine.clone(), table_engine, procedure_manager, ), diff --git a/src/datanode/src/instance/grpc.rs b/src/datanode/src/instance/grpc.rs index 05592d6104..0397848d6a 100644 --- a/src/datanode/src/instance/grpc.rs +++ b/src/datanode/src/instance/grpc.rs @@ -21,7 +21,7 @@ use common_query::Output; use query::parser::{PromQuery, QueryLanguageParser, QueryStatement}; use query::plan::LogicalPlan; use servers::query_handler::grpc::GrpcQueryHandler; -use session::context::QueryContextRef; +use session::context::{QueryContext, QueryContextRef}; use snafu::prelude::*; use sql::statements::statement::Statement; use substrait::{DFLogicalSubstraitConvertor, SubstraitPlan}; @@ -53,7 +53,7 @@ impl Instance { .context(DecodeLogicalPlanSnafu)?; self.query_engine - .execute(&LogicalPlan::DfPlan(logical_plan)) + .execute(LogicalPlan::DfPlan(logical_plan), QueryContext::arc()) .await .context(ExecuteLogicalPlanSnafu) } @@ -69,11 +69,11 @@ impl Instance { let plan = self .query_engine .planner() - .plan(stmt, ctx) + .plan(stmt, ctx.clone()) .await .context(PlanStatementSnafu)?; self.query_engine - .execute(&plan) + .execute(plan, ctx) .await .context(ExecuteLogicalPlanSnafu) } @@ -175,7 +175,7 @@ mod test { .plan(stmt, QueryContext::arc()) .await .unwrap(); - engine.execute(&plan).await.unwrap() + engine.execute(plan, QueryContext::arc()).await.unwrap() } #[tokio::test(flavor = "multi_thread")] diff --git a/src/datanode/src/instance/sql.rs b/src/datanode/src/instance/sql.rs index ee4d78366e..e757a3ff42 100644 --- a/src/datanode/src/instance/sql.rs +++ b/src/datanode/src/instance/sql.rs @@ -19,7 +19,6 @@ use common_error::prelude::BoxedError; use common_query::Output; use common_telemetry::logging::info; use common_telemetry::timer; -use futures::StreamExt; use query::error::QueryExecutionSnafu; use query::parser::{PromQuery, QueryLanguageParser, QueryStatement}; use query::query_engine::StatementHandler; @@ -39,8 +38,7 @@ use crate::error::{ }; use crate::instance::Instance; use crate::metric; -use crate::sql::insert::InsertRequests; -use crate::sql::SqlRequest; +use crate::sql::{SqlHandler, SqlRequest}; impl Instance { pub async fn execute_stmt( @@ -50,33 +48,10 @@ impl Instance { ) -> Result { match stmt { QueryStatement::Sql(Statement::Insert(insert)) => { - let requests = self - .sql_handler - .insert_to_requests(self.catalog_manager.clone(), *insert, query_ctx.clone()) - .await?; - - match requests { - InsertRequests::Request(request) => 
{ - self.sql_handler.execute(request, query_ctx.clone()).await - } - - InsertRequests::Stream(mut s) => { - let mut rows = 0; - while let Some(request) = s.next().await { - match self - .sql_handler - .execute(request?, query_ctx.clone()) - .await? - { - Output::AffectedRows(n) => { - rows += n; - } - _ => unreachable!(), - } - } - Ok(Output::AffectedRows(rows)) - } - } + let request = + SqlHandler::insert_to_request(self.catalog_manager.clone(), *insert, query_ctx) + .await?; + self.sql_handler.insert(request).await } QueryStatement::Sql(Statement::Delete(delete)) => { let request = SqlRequest::Delete(*delete); @@ -226,10 +201,13 @@ impl Instance { let engine = self.query_engine(); let plan = engine .planner() - .plan(stmt, query_ctx) + .plan(stmt, query_ctx.clone()) .await .context(PlanStatementSnafu)?; - engine.execute(&plan).await.context(ExecuteStatementSnafu) + engine + .execute(plan, query_ctx) + .await + .context(ExecuteStatementSnafu) } // TODO(ruihang): merge this and `execute_promql` after #951 landed @@ -262,10 +240,13 @@ impl Instance { let engine = self.query_engine(); let plan = engine .planner() - .plan(stmt, query_ctx) + .plan(stmt, query_ctx.clone()) .await .context(PlanStatementSnafu)?; - engine.execute(&plan).await.context(ExecuteStatementSnafu) + engine + .execute(plan, query_ctx) + .await + .context(ExecuteStatementSnafu) } } diff --git a/src/datanode/src/sql.rs b/src/datanode/src/sql.rs index 551ba0fca6..a1c4cea4c5 100644 --- a/src/datanode/src/sql.rs +++ b/src/datanode/src/sql.rs @@ -17,7 +17,6 @@ use common_error::prelude::BoxedError; use common_procedure::ProcedureManagerRef; use common_query::Output; use common_telemetry::error; -use query::query_engine::QueryEngineRef; use query::sql::{describe_table, show_databases, show_tables}; use session::context::QueryContextRef; use snafu::{OptionExt, ResultExt}; @@ -44,7 +43,6 @@ pub(crate) mod insert; #[derive(Debug)] pub enum SqlRequest { - Insert(InsertRequest), CreateTable(CreateTableRequest), CreateDatabase(CreateDatabaseRequest), Alter(AlterTableRequest), @@ -58,10 +56,10 @@ pub enum SqlRequest { } // Handler to execute SQL except query +#[derive(Clone)] pub struct SqlHandler { table_engine: TableEngineRef, catalog_manager: CatalogManagerRef, - query_engine: QueryEngineRef, engine_procedure: TableEngineProcedureRef, procedure_manager: Option, } @@ -70,14 +68,12 @@ impl SqlHandler { pub fn new( table_engine: TableEngineRef, catalog_manager: CatalogManagerRef, - query_engine: QueryEngineRef, engine_procedure: TableEngineProcedureRef, procedure_manager: Option, ) -> Self { Self { table_engine, catalog_manager, - query_engine, engine_procedure, procedure_manager, } @@ -89,7 +85,6 @@ impl SqlHandler { // there, instead of executing here in a "static" fashion. 
pub async fn execute(&self, request: SqlRequest, query_ctx: QueryContextRef) -> Result { let result = match request { - SqlRequest::Insert(req) => self.insert(req).await, SqlRequest::CreateTable(req) => self.create_table(req).await, SqlRequest::CreateDatabase(req) => self.create_database(req, query_ctx.clone()).await, SqlRequest::Alter(req) => self.alter(req).await, @@ -150,239 +145,3 @@ impl SqlHandler { .context(CloseTableEngineSnafu) } } - -#[cfg(test)] -mod tests { - use std::any::Any; - use std::sync::Arc; - - use catalog::{CatalogManager, RegisterTableRequest}; - use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; - use common_query::logical_plan::Expr; - use common_query::physical_plan::PhysicalPlanRef; - use common_test_util::temp_dir::create_temp_dir; - use common_time::timestamp::Timestamp; - use datatypes::prelude::ConcreteDataType; - use datatypes::schema::{ColumnSchema, SchemaBuilder, SchemaRef}; - use datatypes::value::Value; - use futures::StreamExt; - use log_store::NoopLogStore; - use mito::config::EngineConfig as TableEngineConfig; - use mito::engine::MitoEngine; - use object_store::services::Fs as Builder; - use object_store::{ObjectStore, ObjectStoreBuilder}; - use query::parser::{QueryLanguageParser, QueryStatement}; - use query::QueryEngineFactory; - use session::context::QueryContext; - use sql::statements::statement::Statement; - use storage::compaction::noop::NoopCompactionScheduler; - use storage::config::EngineConfig as StorageEngineConfig; - use storage::EngineImpl; - use table::error::Result as TableResult; - use table::metadata::TableInfoRef; - use table::Table; - - use super::*; - use crate::error::Error; - use crate::sql::insert::InsertRequests; - - struct DemoTable; - - #[async_trait::async_trait] - impl Table for DemoTable { - fn as_any(&self) -> &dyn Any { - self - } - - fn schema(&self) -> SchemaRef { - let column_schemas = vec![ - ColumnSchema::new("host", ConcreteDataType::string_datatype(), false), - ColumnSchema::new("cpu", ConcreteDataType::float64_datatype(), true), - ColumnSchema::new("memory", ConcreteDataType::float64_datatype(), true), - ColumnSchema::new( - "ts", - ConcreteDataType::timestamp_millisecond_datatype(), - true, - ) - .with_time_index(true), - ]; - - Arc::new( - SchemaBuilder::try_from(column_schemas) - .unwrap() - .build() - .unwrap(), - ) - } - - fn table_info(&self) -> TableInfoRef { - unimplemented!() - } - - async fn scan( - &self, - _projection: Option<&Vec>, - _filters: &[Expr], - _limit: Option, - ) -> TableResult { - unimplemented!(); - } - } - - #[tokio::test] - async fn test_statement_to_request() { - let dir = create_temp_dir("setup_test_engine_and_table"); - let store_dir = dir.path().to_string_lossy(); - let accessor = Builder::default().root(&store_dir).build().unwrap(); - let object_store = ObjectStore::new(accessor).finish(); - let compaction_scheduler = Arc::new(NoopCompactionScheduler::default()); - let sql = r#"insert into demo(host, cpu, memory, ts) values - ('host1', 66.6, 1024, 1655276557000), - ('host2', 88.8, 333.3, 1655276558000) - "#; - - let table_engine = Arc::new(MitoEngine::>::new( - TableEngineConfig::default(), - EngineImpl::new( - StorageEngineConfig::default(), - Arc::new(NoopLogStore::default()), - object_store.clone(), - compaction_scheduler, - ), - object_store, - )); - - let catalog_list = Arc::new( - catalog::local::LocalCatalogManager::try_new(table_engine.clone()) - .await - .unwrap(), - ); - catalog_list.start().await.unwrap(); - assert!(catalog_list - 
.register_table(RegisterTableRequest { - catalog: DEFAULT_CATALOG_NAME.to_string(), - schema: DEFAULT_SCHEMA_NAME.to_string(), - table_name: "demo".to_string(), - table_id: 1, - table: Arc::new(DemoTable), - }) - .await - .unwrap()); - - let factory = QueryEngineFactory::new(catalog_list.clone()); - let query_engine = factory.query_engine(); - let sql_handler = SqlHandler::new( - table_engine.clone(), - catalog_list.clone(), - query_engine.clone(), - table_engine, - None, - ); - - let stmt = match QueryLanguageParser::parse_sql(sql).unwrap() { - QueryStatement::Sql(Statement::Insert(i)) => i, - _ => { - unreachable!() - } - }; - let request = sql_handler - .insert_to_requests(catalog_list.clone(), *stmt, QueryContext::arc()) - .await - .unwrap(); - - match request { - InsertRequests::Request(SqlRequest::Insert(req)) => { - assert_eq!(req.table_name, "demo"); - let columns_values = req.columns_values; - assert_eq!(4, columns_values.len()); - - let hosts = &columns_values["host"]; - assert_eq!(2, hosts.len()); - assert_eq!(Value::from("host1"), hosts.get(0)); - assert_eq!(Value::from("host2"), hosts.get(1)); - - let cpus = &columns_values["cpu"]; - assert_eq!(2, cpus.len()); - assert_eq!(Value::from(66.6f64), cpus.get(0)); - assert_eq!(Value::from(88.8f64), cpus.get(1)); - - let memories = &columns_values["memory"]; - assert_eq!(2, memories.len()); - assert_eq!(Value::from(1024f64), memories.get(0)); - assert_eq!(Value::from(333.3f64), memories.get(1)); - - let ts = &columns_values["ts"]; - assert_eq!(2, ts.len()); - assert_eq!( - Value::from(Timestamp::new_millisecond(1655276557000i64)), - ts.get(0) - ); - assert_eq!( - Value::from(Timestamp::new_millisecond(1655276558000i64)), - ts.get(1) - ); - } - _ => { - panic!("Not supposed to reach here") - } - } - - // test inert into select - - // type mismatch - let sql = "insert into demo(ts) select number from numbers limit 3"; - - let stmt = match QueryLanguageParser::parse_sql(sql).unwrap() { - QueryStatement::Sql(Statement::Insert(i)) => i, - _ => { - unreachable!() - } - }; - let request = sql_handler - .insert_to_requests(catalog_list.clone(), *stmt, QueryContext::arc()) - .await - .unwrap(); - - match request { - InsertRequests::Stream(mut stream) => { - assert!(matches!( - stream.next().await.unwrap().unwrap_err(), - Error::ColumnTypeMismatch { .. 
} - )); - } - _ => unreachable!(), - } - - let sql = "insert into demo(cpu) select cast(number as double) from numbers limit 3"; - let stmt = match QueryLanguageParser::parse_sql(sql).unwrap() { - QueryStatement::Sql(Statement::Insert(i)) => i, - _ => { - unreachable!() - } - }; - let request = sql_handler - .insert_to_requests(catalog_list.clone(), *stmt, QueryContext::arc()) - .await - .unwrap(); - - match request { - InsertRequests::Stream(mut stream) => { - let mut times = 0; - while let Some(Ok(SqlRequest::Insert(req))) = stream.next().await { - times += 1; - assert_eq!(req.table_name, "demo"); - let columns_values = req.columns_values; - assert_eq!(1, columns_values.len()); - - let memories = &columns_values["cpu"]; - assert_eq!(3, memories.len()); - assert_eq!(Value::from(0.0f64), memories.get(0)); - assert_eq!(Value::from(1.0f64), memories.get(1)); - assert_eq!(Value::from(2.0f64), memories.get(2)); - } - assert_eq!(1, times); - } - _ => unreachable!(), - } - } -} diff --git a/src/datanode/src/sql/insert.rs b/src/datanode/src/sql/insert.rs index a60100cabc..508cb230db 100644 --- a/src/datanode/src/sql/insert.rs +++ b/src/datanode/src/sql/insert.rs @@ -11,49 +11,31 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -use std::collections::HashMap; -use std::pin::Pin; use catalog::CatalogManagerRef; use common_catalog::format_full_table_name; use common_query::Output; -use common_recordbatch::RecordBatch; -use datafusion_expr::type_coercion::binary::coerce_types; -use datafusion_expr::Operator; use datatypes::data_type::DataType; use datatypes::schema::ColumnSchema; use datatypes::vectors::MutableVector; -use futures::stream::{self, StreamExt}; -use futures::Stream; -use query::parser::QueryStatement; use session::context::QueryContextRef; use snafu::{ensure, OptionExt, ResultExt}; use sql::ast::Value as SqlValue; use sql::statements::insert::Insert; -use sql::statements::statement::Statement; use sql::statements::{self}; use table::engine::TableReference; use table::requests::*; use table::TableRef; use crate::error::{ - CatalogSnafu, CollectRecordsSnafu, ColumnDefaultValueSnafu, ColumnNoneDefaultValueSnafu, - ColumnNotFoundSnafu, ColumnTypeMismatchSnafu, ColumnValuesNumberMismatchSnafu, Error, - ExecuteLogicalPlanSnafu, InsertSnafu, MissingInsertBodySnafu, ParseSqlSnafu, - ParseSqlValueSnafu, PlanStatementSnafu, Result, TableNotFoundSnafu, + CatalogSnafu, ColumnDefaultValueSnafu, ColumnNoneDefaultValueSnafu, ColumnNotFoundSnafu, + ColumnValuesNumberMismatchSnafu, InsertSnafu, MissingInsertBodySnafu, ParseSqlSnafu, + ParseSqlValueSnafu, Result, TableNotFoundSnafu, }; -use crate::sql::{table_idents_to_full_name, SqlHandler, SqlRequest}; +use crate::sql::{table_idents_to_full_name, SqlHandler}; const DEFAULT_PLACEHOLDER_VALUE: &str = "default"; -type InsertRequestStream = Pin> + Send>>; -pub(crate) enum InsertRequests { - // Single request - Request(SqlRequest), - // Streaming requests - Stream(InsertRequestStream), -} - impl SqlHandler { pub(crate) async fn insert(&self, req: InsertRequest) -> Result { // FIXME(dennis): table_ref is used in InsertSnafu and the req is consumed @@ -77,7 +59,7 @@ impl SqlHandler { table_ref: TableReference, table: &TableRef, stmt: Insert, - ) -> Result { + ) -> Result { let values = stmt .values_body() .context(ParseSqlValueSnafu)? 
@@ -129,7 +111,7 @@ impl SqlHandler { } } - Ok(SqlRequest::Insert(InsertRequest { + Ok(InsertRequest { catalog_name: table_ref.catalog.to_string(), schema_name: table_ref.schema.to_string(), table_name: table_ref.table.to_string(), @@ -138,150 +120,14 @@ impl SqlHandler { .map(|(cs, mut b)| (cs.name.to_string(), b.to_vector())) .collect(), region_number: 0, - })) + }) } - fn build_request_from_batch( - stmt: Insert, - table: TableRef, - batch: RecordBatch, - query_ctx: QueryContextRef, - ) -> Result { - let (catalog_name, schema_name, table_name) = - table_idents_to_full_name(stmt.table_name(), query_ctx)?; - - let schema = table.schema(); - let columns: Vec<_> = if stmt.columns().is_empty() { - schema - .column_schemas() - .iter() - .map(|c| c.name.to_string()) - .collect() - } else { - stmt.columns().iter().map(|c| (*c).clone()).collect() - }; - let columns_num = columns.len(); - - ensure!( - batch.num_columns() == columns_num, - ColumnValuesNumberMismatchSnafu { - columns: columns_num, - values: batch.num_columns(), - } - ); - - let batch_schema = &batch.schema; - let batch_columns = batch_schema.column_schemas(); - assert_eq!(batch_columns.len(), columns_num); - let mut columns_values = HashMap::with_capacity(columns_num); - - for (i, column_name) in columns.into_iter().enumerate() { - let column_schema = schema - .column_schema_by_name(&column_name) - .with_context(|| ColumnNotFoundSnafu { - table_name: &table_name, - column_name: &column_name, - })?; - let expect_datatype = column_schema.data_type.as_arrow_type(); - // It's safe to retrieve the column schema by index, we already - // check columns number is the same above. - let batch_datatype = batch_columns[i].data_type.as_arrow_type(); - let coerced_type = coerce_types(&expect_datatype, &Operator::Eq, &batch_datatype) - .map_err(|_| Error::ColumnTypeMismatch { - column: column_name.clone(), - expected: column_schema.data_type.clone(), - actual: batch_columns[i].data_type.clone(), - })?; - - ensure!( - expect_datatype == coerced_type, - ColumnTypeMismatchSnafu { - column: column_name, - expected: column_schema.data_type.clone(), - actual: batch_columns[i].data_type.clone(), - } - ); - let vector = batch - .column(i) - .cast(&column_schema.data_type) - .map_err(|_| Error::ColumnTypeMismatch { - column: column_name.clone(), - expected: column_schema.data_type.clone(), - actual: batch_columns[i].data_type.clone(), - })?; - - columns_values.insert(column_name, vector); - } - - Ok(SqlRequest::Insert(InsertRequest { - catalog_name, - schema_name, - table_name, - columns_values, - region_number: 0, - })) - } - - // FIXME(dennis): move it to frontend when refactor is done. - async fn build_stream_from_query( - &self, - table: TableRef, - stmt: Insert, - query_ctx: QueryContextRef, - ) -> Result { - let query = stmt - .query_body() - .context(ParseSqlValueSnafu)? 
- .context(MissingInsertBodySnafu)?; - - let logical_plan = self - .query_engine - .planner() - .plan( - QueryStatement::Sql(Statement::Query(Box::new(query))), - query_ctx.clone(), - ) - .await - .context(PlanStatementSnafu)?; - - let output = self - .query_engine - .execute(&logical_plan) - .await - .context(ExecuteLogicalPlanSnafu)?; - - let stream: InsertRequestStream = match output { - Output::RecordBatches(batches) => { - Box::pin(stream::iter(batches.take()).map(move |batch| { - Self::build_request_from_batch( - stmt.clone(), - table.clone(), - batch, - query_ctx.clone(), - ) - })) - } - - Output::Stream(stream) => Box::pin(stream.map(move |batch| { - Self::build_request_from_batch( - stmt.clone(), - table.clone(), - batch.context(CollectRecordsSnafu)?, - query_ctx.clone(), - ) - })), - _ => unreachable!(), - }; - - Ok(stream) - } - - pub(crate) async fn insert_to_requests( - &self, + pub async fn insert_to_request( catalog_manager: CatalogManagerRef, stmt: Insert, query_ctx: QueryContextRef, - ) -> Result { + ) -> Result { let (catalog_name, schema_name, table_name) = table_idents_to_full_name(stmt.table_name(), query_ctx.clone())?; @@ -293,16 +139,8 @@ impl SqlHandler { table_name: format_full_table_name(&catalog_name, &schema_name, &table_name), })?; - if stmt.is_insert_select() { - Ok(InsertRequests::Stream( - self.build_stream_from_query(table, stmt, query_ctx).await?, - )) - } else { - let table_ref = TableReference::full(&catalog_name, &schema_name, &table_name); - Ok(InsertRequests::Request(Self::build_request_from_values( - table_ref, &table, stmt, - )?)) - } + let table_ref = TableReference::full(&catalog_name, &schema_name, &table_name); + Self::build_request_from_values(table_ref, &table, stmt) } } diff --git a/src/datanode/src/tests/instance_test.rs b/src/datanode/src/tests/instance_test.rs index 63f87e283d..d0627c1646 100644 --- a/src/datanode/src/tests/instance_test.rs +++ b/src/datanode/src/tests/instance_test.rs @@ -22,7 +22,7 @@ use common_telemetry::logging; use datatypes::data_type::ConcreteDataType; use datatypes::vectors::{Int64Vector, StringVector, UInt64Vector, VectorRef}; use query::parser::{QueryLanguageParser, QueryStatement}; -use session::context::QueryContext; +use session::context::{QueryContext, QueryContextRef}; use snafu::ResultExt; use sql::statements::statement::Statement; @@ -217,20 +217,20 @@ async fn test_execute_insert_by_select() { try_execute_sql(&instance, "insert into demo2(host) select * from demo1") .await .unwrap_err(), - Error::ColumnValuesNumberMismatch { .. } + Error::PlanStatement { .. } )); assert!(matches!( try_execute_sql(&instance, "insert into demo2 select cpu,memory from demo1") .await .unwrap_err(), - Error::ColumnValuesNumberMismatch { .. } + Error::PlanStatement { .. } )); assert!(matches!( try_execute_sql(&instance, "insert into demo2(ts) select memory from demo1") .await .unwrap_err(), - Error::ColumnTypeMismatch { .. } + Error::PlanStatement { .. 
} )); let output = execute_sql(&instance, "insert into demo2 select * from demo1").await; @@ -962,16 +962,28 @@ async fn try_execute_sql_in_db( ) -> Result { let query_ctx = Arc::new(QueryContext::with(DEFAULT_CATALOG_NAME, db)); + async fn plan_exec( + instance: &MockInstance, + stmt: QueryStatement, + query_ctx: QueryContextRef, + ) -> Result { + let engine = instance.inner().query_engine(); + let plan = engine + .planner() + .plan(stmt, query_ctx.clone()) + .await + .context(PlanStatementSnafu)?; + engine + .execute(plan, query_ctx) + .await + .context(ExecuteLogicalPlanSnafu) + } + let stmt = QueryLanguageParser::parse_sql(sql).unwrap(); match stmt { - QueryStatement::Sql(Statement::Query(_)) => { - let engine = instance.inner().query_engine(); - let plan = engine - .planner() - .plan(stmt, query_ctx) - .await - .context(PlanStatementSnafu)?; - engine.execute(&plan).await.context(ExecuteLogicalPlanSnafu) + QueryStatement::Sql(Statement::Query(_)) => plan_exec(instance, stmt, query_ctx).await, + QueryStatement::Sql(Statement::Insert(ref insert)) if insert.is_insert_select() => { + plan_exec(instance, stmt, query_ctx).await } _ => instance.inner().execute_stmt(stmt, query_ctx).await, } diff --git a/src/datanode/src/tests/test_util.rs b/src/datanode/src/tests/test_util.rs index 2e9d7576e5..774cea550a 100644 --- a/src/datanode/src/tests/test_util.rs +++ b/src/datanode/src/tests/test_util.rs @@ -24,7 +24,6 @@ use datatypes::schema::{ColumnSchema, RawSchema}; use mito::config::EngineConfig; use mito::table::test_util::{new_test_object_store, MockEngine, MockMitoEngine}; use query::parser::{PromQuery, QueryLanguageParser, QueryStatement}; -use query::QueryEngineFactory; use servers::Mode; use session::context::QueryContext; use snafu::ResultExt; @@ -87,7 +86,7 @@ impl MockInstance { match stmt { QueryStatement::Sql(Statement::Query(_)) => { let plan = planner.plan(stmt, QueryContext::arc()).await.unwrap(); - engine.execute(&plan).await.unwrap() + engine.execute(plan, QueryContext::arc()).await.unwrap() } QueryStatement::Sql(Statement::Tql(tql)) => { let plan = match tql { @@ -103,7 +102,7 @@ impl MockInstance { } Tql::Explain(_) => unimplemented!(), }; - engine.execute(&plan).await.unwrap() + engine.execute(plan, QueryContext::arc()).await.unwrap() } _ => self .inner() @@ -201,17 +200,7 @@ pub async fn create_mock_sql_handler() -> SqlHandler { .await .unwrap(), ); - - let catalog_list = catalog::local::new_memory_catalog_list().unwrap(); - let factory = QueryEngineFactory::new(catalog_list); - - SqlHandler::new( - mock_engine.clone(), - catalog_manager, - factory.query_engine(), - mock_engine, - None, - ) + SqlHandler::new(mock_engine.clone(), catalog_manager, mock_engine, None) } pub(crate) async fn setup_test_instance(test_name: &str) -> MockInstance { diff --git a/src/frontend/src/instance.rs b/src/frontend/src/instance.rs index 2d1b2264a1..f07396a5d9 100644 --- a/src/frontend/src/instance.rs +++ b/src/frontend/src/instance.rs @@ -428,45 +428,55 @@ fn parse_stmt(sql: &str) -> Result> { } impl Instance { + async fn plan_exec(&self, stmt: Statement, query_ctx: QueryContextRef) -> Result { + let planner = self.query_engine.planner(); + let plan = planner + .plan(QueryStatement::Sql(stmt), query_ctx.clone()) + .await + .context(PlanStatementSnafu)?; + self.query_engine + .execute(plan, query_ctx) + .await + .context(ExecLogicalPlanSnafu) + } + + async fn execute_tql(&self, tql: Tql, query_ctx: QueryContextRef) -> Result { + let plan = match tql { + Tql::Eval(eval) => { + let promql = PromQuery 
{ + start: eval.start, + end: eval.end, + step: eval.step, + query: eval.query, + }; + let stmt = QueryLanguageParser::parse_promql(&promql).context(ParseQuerySnafu)?; + self.query_engine + .planner() + .plan(stmt, query_ctx.clone()) + .await + .context(PlanStatementSnafu)? + } + Tql::Explain(_) => unimplemented!(), + }; + self.query_engine + .execute(plan, query_ctx) + .await + .context(ExecLogicalPlanSnafu) + } + async fn query_statement(&self, stmt: Statement, query_ctx: QueryContextRef) -> Result { check_permission(self.plugins.clone(), &stmt, &query_ctx)?; - let planner = self.query_engine.planner(); - match stmt { - Statement::Query(_) | Statement::Explain(_) => { - let plan = planner - .plan(QueryStatement::Sql(stmt), query_ctx) - .await - .context(PlanStatementSnafu)?; - self.query_engine - .execute(&plan) - .await - .context(ExecLogicalPlanSnafu) - } - Statement::Tql(tql) => { - let plan = match tql { - Tql::Eval(eval) => { - let promql = PromQuery { - start: eval.start, - end: eval.end, - step: eval.step, - query: eval.query, - }; - let stmt = - QueryLanguageParser::parse_promql(&promql).context(ParseQuerySnafu)?; - planner - .plan(stmt, query_ctx) - .await - .context(PlanStatementSnafu)? - } - Tql::Explain(_) => unimplemented!(), - }; - self.query_engine - .execute(&plan) - .await - .context(ExecLogicalPlanSnafu) + Statement::Query(_) | Statement::Explain(_) => self.plan_exec(stmt, query_ctx).await, + + // For performance consideration, only "insert with select" is executed by query engine. + // Plain insert ("insert with values") is still executed directly in statement. + Statement::Insert(ref insert) if insert.is_insert_select() => { + self.plan_exec(stmt, query_ctx).await } + + Statement::Tql(tql) => self.execute_tql(tql, query_ctx).await, Statement::CreateDatabase(_) | Statement::ShowDatabases(_) | Statement::CreateTable(_) @@ -1086,7 +1096,7 @@ mod tests { .plan(stmt.clone(), QueryContext::arc()) .await .unwrap(); - let output = engine.execute(&plan).await.unwrap(); + let output = engine.execute(plan, QueryContext::arc()).await.unwrap(); let Output::Stream(stream) = output else { unreachable!() }; let recordbatches = RecordBatches::try_collect(stream).await.unwrap(); let actual = recordbatches.pretty_print().unwrap(); diff --git a/src/frontend/src/instance/distributed.rs b/src/frontend/src/instance/distributed.rs index 472b05245b..a82c3021be 100644 --- a/src/frontend/src/instance/distributed.rs +++ b/src/frontend/src/instance/distributed.rs @@ -33,6 +33,7 @@ use common_error::prelude::BoxedError; use common_query::Output; use common_telemetry::{debug, info}; use datanode::instance::sql::table_idents_to_full_name; +use datanode::sql::SqlHandler; use datatypes::prelude::ConcreteDataType; use datatypes::schema::RawSchema; use meta_client::client::MetaClient; @@ -60,13 +61,12 @@ use crate::catalog::FrontendCatalogManager; use crate::datanode::DatanodeClients; use crate::error::{ self, AlterExprToRequestSnafu, CatalogEntrySerdeSnafu, CatalogSnafu, ColumnDataTypeSnafu, - DeserializePartitionSnafu, NotSupportedSnafu, ParseSqlSnafu, PrimaryKeyNotFoundSnafu, - RequestDatanodeSnafu, RequestMetaSnafu, Result, SchemaExistsSnafu, StartMetaClientSnafu, - TableAlreadyExistSnafu, TableNotFoundSnafu, TableSnafu, ToTableInsertRequestSnafu, - UnrecognizedTableOptionSnafu, + DeserializePartitionSnafu, InvokeDatanodeSnafu, NotSupportedSnafu, ParseSqlSnafu, + PrimaryKeyNotFoundSnafu, RequestDatanodeSnafu, RequestMetaSnafu, Result, SchemaExistsSnafu, + StartMetaClientSnafu, 
TableAlreadyExistSnafu, TableNotFoundSnafu, TableSnafu, + ToTableInsertRequestSnafu, UnrecognizedTableOptionSnafu, }; use crate::expr_factory; -use crate::sql::insert_to_request; use crate::table::DistTable; #[derive(Clone)] @@ -374,7 +374,10 @@ impl DistInstance { .context(CatalogSnafu)? .context(TableNotFoundSnafu { table_name: table })?; - let insert_request = insert_to_request(&table, *insert, query_ctx)?; + let insert_request = + SqlHandler::insert_to_request(self.catalog_manager.clone(), *insert, query_ctx) + .await + .context(InvokeDatanodeSnafu)?; return Ok(Output::AffectedRows( table.insert(insert_request).await.context(TableSnafu)?, diff --git a/src/frontend/src/instance/grpc.rs b/src/frontend/src/instance/grpc.rs index ef12c71da3..6c65ecfc8f 100644 --- a/src/frontend/src/instance/grpc.rs +++ b/src/frontend/src/instance/grpc.rs @@ -590,7 +590,7 @@ CREATE TABLE {table_name} ( .plan(stmt, QueryContext::arc()) .await .unwrap(); - let output = engine.execute(&plan).await.unwrap(); + let output = engine.execute(plan, QueryContext::arc()).await.unwrap(); let Output::Stream(stream) = output else { unreachable!() }; let recordbatches = RecordBatches::try_collect(stream).await.unwrap(); let actual = recordbatches.pretty_print().unwrap(); diff --git a/src/frontend/src/lib.rs b/src/frontend/src/lib.rs index b5d82f93b7..71cb0ecf8f 100644 --- a/src/frontend/src/lib.rs +++ b/src/frontend/src/lib.rs @@ -28,7 +28,6 @@ pub mod postgres; pub mod prom; pub mod prometheus; mod server; -mod sql; mod table; #[cfg(test)] mod tests; diff --git a/src/frontend/src/sql.rs b/src/frontend/src/sql.rs deleted file mode 100644 index f2a766f663..0000000000 --- a/src/frontend/src/sql.rs +++ /dev/null @@ -1,130 +0,0 @@ -// Copyright 2023 Greptime Team -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use common_error::ext::BoxedError; -use common_error::snafu::ensure; -use datanode::instance::sql::table_idents_to_full_name; -use datatypes::data_type::DataType; -use datatypes::prelude::MutableVector; -use datatypes::schema::ColumnSchema; -use session::context::QueryContextRef; -use snafu::{OptionExt, ResultExt}; -use sql::ast::Value as SqlValue; -use sql::statements; -use sql::statements::insert::Insert; -use table::requests::InsertRequest; -use table::TableRef; - -use crate::error::{self, ExternalSnafu, Result}; - -const DEFAULT_PLACEHOLDER_VALUE: &str = "default"; - -// TODO(fys): Extract the common logic in datanode and frontend in the future. -// This function convert insert statement to an `InsertRequest` to region 0. -pub(crate) fn insert_to_request( - table: &TableRef, - stmt: Insert, - query_ctx: QueryContextRef, -) -> Result { - let columns = stmt.columns(); - let values = stmt - .values_body() - .context(error::ParseSqlSnafu)? 
- .context(error::MissingInsertValuesSnafu)?; - - let (catalog_name, schema_name, table_name) = - table_idents_to_full_name(stmt.table_name(), query_ctx) - .map_err(BoxedError::new) - .context(ExternalSnafu)?; - - let schema = table.schema(); - let columns_num = if columns.is_empty() { - schema.column_schemas().len() - } else { - columns.len() - }; - let rows_num = values.len(); - - let mut columns_builders: Vec<(&ColumnSchema, Box)> = - Vec::with_capacity(columns_num); - - if columns.is_empty() { - for column_schema in schema.column_schemas() { - let data_type = &column_schema.data_type; - columns_builders.push((column_schema, data_type.create_mutable_vector(rows_num))); - } - } else { - for column_name in columns { - let column_schema = schema.column_schema_by_name(column_name).with_context(|| { - error::ColumnNotFoundSnafu { - table_name: &table_name, - column_name: column_name.to_string(), - } - })?; - let data_type = &column_schema.data_type; - columns_builders.push((column_schema, data_type.create_mutable_vector(rows_num))); - } - } - - for row in values { - ensure!( - row.len() == columns_num, - error::ColumnValuesNumberMismatchSnafu { - columns: columns_num, - values: row.len(), - } - ); - - for (sql_val, (column_schema, builder)) in row.iter().zip(columns_builders.iter_mut()) { - add_row_to_vector(column_schema, sql_val, builder)?; - } - } - - Ok(InsertRequest { - catalog_name, - schema_name, - table_name, - columns_values: columns_builders - .into_iter() - .map(|(cs, mut b)| (cs.name.to_string(), b.to_vector())) - .collect(), - region_number: 0, - }) -} - -fn add_row_to_vector( - column_schema: &ColumnSchema, - sql_val: &SqlValue, - builder: &mut Box, -) -> Result<()> { - let value = if replace_default(sql_val) { - column_schema - .create_default() - .context(error::ColumnDefaultValueSnafu { - column: column_schema.name.to_string(), - })? - .context(error::ColumnNoneDefaultValueSnafu { - column: column_schema.name.to_string(), - })? - } else { - statements::sql_value_to_value(&column_schema.name, &column_schema.data_type, sql_val) - .context(error::ParseSqlSnafu)? 
- }; - builder.push_value_ref(value.as_value_ref()); - Ok(()) -} - -fn replace_default(sql_val: &SqlValue) -> bool { - matches!(sql_val, SqlValue::Placeholder(s) if s.to_lowercase() == DEFAULT_PLACEHOLDER_VALUE) -} diff --git a/src/query/src/datafusion.rs b/src/query/src/datafusion.rs index 057d3a14a5..498a104959 100644 --- a/src/query/src/datafusion.rs +++ b/src/query/src/datafusion.rs @@ -33,12 +33,21 @@ use common_recordbatch::{EmptyRecordBatchStream, SendableRecordBatchStream}; use common_telemetry::timer; use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion::physical_plan::ExecutionPlan; +use datafusion_common::ResolvedTableReference; +use datafusion_expr::{DmlStatement, LogicalPlan as DfLogicalPlan, WriteOp}; use datatypes::schema::Schema; +use futures_util::StreamExt; +use session::context::QueryContextRef; use snafu::{OptionExt, ResultExt}; +use table::requests::InsertRequest; +use table::TableRef; pub use crate::datafusion::catalog_adapter::DfCatalogListAdapter; pub use crate::datafusion::planner::DfContextProviderAdapter; -use crate::error::{DataFusionSnafu, QueryExecutionSnafu, Result}; +use crate::error::{ + CatalogNotFoundSnafu, CatalogSnafu, CreateRecordBatchSnafu, DataFusionSnafu, + QueryExecutionSnafu, Result, SchemaNotFoundSnafu, TableNotFoundSnafu, UnsupportedExprSnafu, +}; use crate::executor::QueryExecutor; use crate::logical_optimizer::LogicalOptimizer; use crate::physical_optimizer::PhysicalOptimizer; @@ -56,6 +65,83 @@ impl DatafusionQueryEngine { pub fn new(state: Arc) -> Self { Self { state } } + + async fn exec_query_plan(&self, plan: LogicalPlan) -> Result { + let mut ctx = QueryEngineContext::new(self.state.session_state()); + + // `create_physical_plan` will optimize logical plan internally + let physical_plan = self.create_physical_plan(&mut ctx, &plan).await?; + let physical_plan = self.optimize_physical_plan(&mut ctx, physical_plan)?; + + Ok(Output::Stream(self.execute_stream(&ctx, &physical_plan)?)) + } + + async fn exec_insert_plan( + &self, + dml: &DmlStatement, + query_ctx: QueryContextRef, + ) -> Result { + let default_catalog = query_ctx.current_catalog(); + let default_schema = query_ctx.current_schema(); + let table_name = dml + .table_name + .as_table_reference() + .resolve(&default_catalog, &default_schema); + let table = self.find_table(&table_name).await?; + + let output = self + .exec_query_plan(LogicalPlan::DfPlan((*dml.input).clone())) + .await?; + let mut stream = match output { + Output::RecordBatches(batches) => batches.as_stream(), + Output::Stream(stream) => stream, + _ => unreachable!(), + }; + + let mut affected_rows = 0; + while let Some(batch) = stream.next().await { + let batch = batch.context(CreateRecordBatchSnafu)?; + let request = InsertRequest::try_from_recordbatch(&table_name, table.schema(), batch) + .map_err(BoxedError::new) + .context(QueryExecutionSnafu)?; + + let rows = table + .insert(request) + .await + .map_err(BoxedError::new) + .context(QueryExecutionSnafu)?; + affected_rows += rows; + } + Ok(Output::AffectedRows(affected_rows)) + } + + async fn find_table(&self, table_name: &ResolvedTableReference<'_>) -> Result { + let catalog_name = table_name.catalog.as_ref(); + let schema_name = table_name.schema.as_ref(); + let table_name = table_name.table.as_ref(); + + let catalog = self + .state + .catalog_list() + .catalog(catalog_name) + .context(CatalogSnafu)? 
+ .context(CatalogNotFoundSnafu { + catalog: catalog_name, + })?; + let schema = + catalog + .schema(schema_name) + .context(CatalogSnafu)? + .context(SchemaNotFoundSnafu { + schema: schema_name, + })?; + let table = schema + .table(table_name) + .await + .context(CatalogSnafu)? + .context(TableNotFoundSnafu { table: table_name })?; + Ok(table) + } } #[async_trait] @@ -75,14 +161,17 @@ impl QueryEngine for DatafusionQueryEngine { optimised_plan.schema() } - async fn execute(&self, plan: &LogicalPlan) -> Result { - let logical_plan = self.optimize(plan)?; - - let mut ctx = QueryEngineContext::new(self.state.session_state()); - let physical_plan = self.create_physical_plan(&mut ctx, &logical_plan).await?; - let physical_plan = self.optimize_physical_plan(&mut ctx, physical_plan)?; - - Ok(Output::Stream(self.execute_stream(&ctx, &physical_plan)?)) + async fn execute(&self, plan: LogicalPlan, query_ctx: QueryContextRef) -> Result { + match plan { + LogicalPlan::DfPlan(DfLogicalPlan::Dml(dml)) => match dml.op { + WriteOp::Insert => self.exec_insert_plan(&dml, query_ctx).await, + _ => UnsupportedExprSnafu { + name: format!("DML op {}", dml.op), + } + .fail(), + }, + _ => self.exec_query_plan(plan).await, + } } fn register_udf(&self, udf: ScalarUdf) { @@ -292,11 +381,11 @@ mod tests { let stmt = QueryLanguageParser::parse_sql(sql).unwrap(); let plan = engine .planner() - .plan(stmt, Arc::new(QueryContext::new())) + .plan(stmt, QueryContext::arc()) .await .unwrap(); - let output = engine.execute(&plan).await.unwrap(); + let output = engine.execute(plan, QueryContext::arc()).await.unwrap(); match output { Output::Stream(recordbatch) => { diff --git a/src/query/src/error.rs b/src/query/src/error.rs index 70c408deec..b373a5d34d 100644 --- a/src/query/src/error.rs +++ b/src/query/src/error.rs @@ -72,8 +72,11 @@ pub enum Error { #[snafu(display("The SQL string has multiple statements, query: {}", query))] MultipleStatements { query: String, backtrace: Backtrace }, - #[snafu(display("Failed to convert datatype: {}", source))] - Datatype { source: datatypes::error::Error }, + #[snafu(display("Failed to convert Datafusion schema: {}", source))] + ConvertDatafusionSchema { + #[snafu(backtrace)] + source: datatypes::error::Error, + }, #[snafu(display("Failed to parse timestamp `{}`: {}", raw, source))] ParseTimestamp { @@ -123,9 +126,10 @@ impl ErrorExt for Error { | ParseFloat { .. } => StatusCode::InvalidArguments, QueryAccessDenied { .. } => StatusCode::AccessDenied, Catalog { source } => source.status_code(), - VectorComputation { source } => source.status_code(), + VectorComputation { source } | ConvertDatafusionSchema { source } => { + source.status_code() + } CreateRecordBatch { source } => source.status_code(), - Datatype { source } => source.status_code(), QueryExecution { source } | QueryPlan { source } => source.status_code(), DataFusion { .. 
} => StatusCode::Internal, Sql { source } => source.status_code(), diff --git a/src/query/src/plan.rs b/src/query/src/plan.rs index 7e73596222..0f406bea17 100644 --- a/src/query/src/plan.rs +++ b/src/query/src/plan.rs @@ -18,7 +18,7 @@ use datafusion_expr::LogicalPlan as DfLogicalPlan; use datatypes::schema::Schema; use snafu::ResultExt; -use crate::error::Result; +use crate::error::{ConvertDatafusionSchemaSnafu, Result}; /// A LogicalPlan represents the different types of relational /// operators (such as Projection, Filter, etc) and can be created by @@ -42,7 +42,7 @@ impl LogicalPlan { df_schema .clone() .try_into() - .context(crate::error::DatatypeSnafu) + .context(ConvertDatafusionSchemaSnafu) } } } diff --git a/src/query/src/query_engine.rs b/src/query/src/query_engine.rs index bf4b5766cd..3e212a816f 100644 --- a/src/query/src/query_engine.rs +++ b/src/query/src/query_engine.rs @@ -56,7 +56,7 @@ pub trait QueryEngine: Send + Sync { async fn describe(&self, plan: LogicalPlan) -> Result; - async fn execute(&self, plan: &LogicalPlan) -> Result; + async fn execute(&self, plan: LogicalPlan, query_ctx: QueryContextRef) -> Result; fn register_udf(&self, udf: ScalarUdf); diff --git a/src/query/src/tests.rs b/src/query/src/tests.rs index 44844fc962..d633ac9d4c 100644 --- a/src/query/src/tests.rs +++ b/src/query/src/tests.rs @@ -41,7 +41,7 @@ async fn exec_selection(engine: QueryEngineRef, sql: &str) -> Vec { .await .unwrap(); let Output::Stream(stream) = engine - .execute(&plan) + .execute(plan, QueryContext::arc()) .await .unwrap() else { unreachable!() }; util::collect(stream).await.unwrap() diff --git a/src/query/src/tests/query_engine_test.rs b/src/query/src/tests/query_engine_test.rs index f0fb8ee6f6..e3688f6e4f 100644 --- a/src/query/src/tests/query_engine_test.rs +++ b/src/query/src/tests/query_engine_test.rs @@ -77,7 +77,7 @@ async fn test_datafusion_query_engine() -> Result<()> { .unwrap(), ); - let output = engine.execute(&plan).await?; + let output = engine.execute(plan, QueryContext::arc()).await?; let recordbatch = match output { Output::Stream(recordbatch) => recordbatch, diff --git a/src/script/src/python/engine.rs b/src/script/src/python/engine.rs index b6f6519929..60e3f02e7d 100644 --- a/src/script/src/python/engine.rs +++ b/src/script/src/python/engine.rs @@ -282,7 +282,7 @@ impl Script for PyScript { .planner() .plan(stmt, QueryContext::arc()) .await?; - let res = self.query_engine.execute(&plan).await?; + let res = self.query_engine.execute(plan, QueryContext::arc()).await?; let copr = self.copr.clone(); match res { Output::Stream(stream) => Ok(Output::Stream(Box::pin(CoprStream::try_new( diff --git a/src/script/src/python/ffi_types/copr.rs b/src/script/src/python/ffi_types/copr.rs index c18303027b..6b95d75e20 100644 --- a/src/script/src/python/ffi_types/copr.rs +++ b/src/script/src/python/ffi_types/copr.rs @@ -35,6 +35,7 @@ use rustpython_compiler_core::CodeObject; use rustpython_vm as vm; #[cfg(test)] use serde::Deserialize; +use session::context::QueryContext; use snafu::{OptionExt, ResultExt}; use vm::builtins::{PyList, PyListRef}; use vm::convert::ToPyObject; @@ -373,7 +374,7 @@ impl PyQueryEngine { .map_err(|e| e.to_string())?; let res = engine .clone() - .execute(&plan) + .execute(plan, QueryContext::arc()) .await .map_err(|e| e.to_string()); match res { diff --git a/src/script/src/table.rs b/src/script/src/table.rs index 82f67fbbd4..fcd9c72473 100644 --- a/src/script/src/table.rs +++ b/src/script/src/table.rs @@ -167,7 +167,7 @@ impl ScriptsTable { let stream = 
match self .query_engine - .execute(&plan) + .execute(plan, QueryContext::arc()) .await .context(FindScriptSnafu { name })? { diff --git a/src/servers/tests/mod.rs b/src/servers/tests/mod.rs index 785614d570..338cc04879 100644 --- a/src/servers/tests/mod.rs +++ b/src/servers/tests/mod.rs @@ -72,10 +72,10 @@ impl SqlQueryHandler for DummyInstance { let plan = self .query_engine .planner() - .plan(stmt, query_ctx) + .plan(stmt, query_ctx.clone()) .await .unwrap(); - let output = self.query_engine.execute(&plan).await.unwrap(); + let output = self.query_engine.execute(plan, query_ctx).await.unwrap(); vec![Ok(output)] } diff --git a/src/sql/src/statements/statement.rs b/src/sql/src/statements/statement.rs index 3fe2f1bad5..9f09d9169e 100644 --- a/src/sql/src/statements/statement.rs +++ b/src/sql/src/statements/statement.rs @@ -79,6 +79,7 @@ impl TryFrom<&Statement> for DfStatement { let s = match s { Statement::Query(query) => SpStatement::Query(Box::new(query.inner.clone())), Statement::Explain(explain) => explain.inner.clone(), + Statement::Insert(insert) => insert.inner.clone(), _ => { return ConvertToDfStatementSnafu { statement: format!("{s:?}"), diff --git a/src/table/src/error.rs b/src/table/src/error.rs index 0376db64cc..5561c5613a 100644 --- a/src/table/src/error.rs +++ b/src/table/src/error.rs @@ -18,6 +18,7 @@ use common_error::prelude::*; use common_recordbatch::error::Error as RecordBatchError; use datafusion::error::DataFusionError; use datatypes::arrow::error::ArrowError; +use datatypes::prelude::ConcreteDataType; pub type Result = std::result::Result; @@ -114,6 +115,19 @@ pub enum Error { value: String, backtrace: Backtrace, }, + + #[snafu(display( + "Failed to cast vector of type '{:?}' to type '{:?}', source: {}", + from_type, + to_type, + source + ))] + CastVector { + from_type: ConcreteDataType, + to_type: ConcreteDataType, + #[snafu(backtrace)] + source: datatypes::error::Error, + }, } impl ErrorExt for Error { @@ -128,7 +142,9 @@ impl ErrorExt for Error { } Error::TablesRecordBatch { .. } => StatusCode::Unexpected, Error::ColumnExists { .. } => StatusCode::TableColumnExists, - Error::SchemaBuild { source, .. } => source.status_code(), + Error::SchemaBuild { source, .. } | Error::CastVector { source, .. } => { + source.status_code() + } Error::TableOperation { source } => source.status_code(), Error::ColumnNotExists { .. } => StatusCode::TableColumnNotFound, Error::RegionSchemaMismatch { .. } => StatusCode::StorageUnavailable, diff --git a/src/table/src/requests.rs b/src/table/src/requests.rs index f5cead9221..957a62eaf0 100644 --- a/src/table/src/requests.rs +++ b/src/table/src/requests.rs @@ -13,6 +13,8 @@ // limitations under the License. //! 
Table and TableEngine requests +mod insert; + use std::collections::HashMap; use std::str::FromStr; use std::time::Duration; @@ -20,6 +22,7 @@ use std::time::Duration; use common_base::readable_size::ReadableSize; use datatypes::prelude::VectorRef; use datatypes::schema::{ColumnSchema, RawSchema}; +pub use insert::InsertRequest; use serde::{Deserialize, Serialize}; use store_api::storage::RegionNumber; @@ -27,16 +30,6 @@ use crate::error; use crate::error::ParseTableOptionSnafu; use crate::metadata::TableId; -/// Insert request -#[derive(Debug)] -pub struct InsertRequest { - pub catalog_name: String, - pub schema_name: String, - pub table_name: String, - pub columns_values: HashMap, - pub region_number: RegionNumber, -} - #[derive(Debug, Clone)] pub struct CreateDatabaseRequest { pub db_name: String, diff --git a/src/table/src/requests/insert.rs b/src/table/src/requests/insert.rs new file mode 100644 index 0000000000..ddf6f8c271 --- /dev/null +++ b/src/table/src/requests/insert.rs @@ -0,0 +1,79 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::collections::HashMap; + +use common_recordbatch::RecordBatch; +use datafusion_common::ResolvedTableReference; +use datatypes::prelude::VectorRef; +use datatypes::schema::SchemaRef; +use snafu::{OptionExt, ResultExt}; +use store_api::storage::RegionNumber; + +use crate::error::{CastVectorSnafu, ColumnNotExistsSnafu, Result}; + +#[derive(Debug)] +pub struct InsertRequest { + pub catalog_name: String, + pub schema_name: String, + pub table_name: String, + pub columns_values: HashMap, + pub region_number: RegionNumber, +} + +impl InsertRequest { + pub fn try_from_recordbatch( + table_name: &ResolvedTableReference, + table_schema: SchemaRef, + recordbatch: RecordBatch, + ) -> Result { + let mut columns_values = HashMap::with_capacity(recordbatch.num_columns()); + + // column schemas in recordbatch must match its vectors, otherwise it's corrupted + for (vector_schema, vector) in recordbatch + .schema + .column_schemas() + .iter() + .zip(recordbatch.columns().iter()) + { + let column_name = &vector_schema.name; + let column_schema = table_schema + .column_schema_by_name(column_name) + .with_context(|| ColumnNotExistsSnafu { + table_name: table_name.table.to_string(), + column_name, + })?; + let vector = if vector_schema.data_type != column_schema.data_type { + vector + .cast(&column_schema.data_type) + .with_context(|_| CastVectorSnafu { + from_type: vector.data_type(), + to_type: column_schema.data_type.clone(), + })? 
+ } else { + vector.clone() + }; + + columns_values.insert(column_name.clone(), vector); + } + + Ok(InsertRequest { + catalog_name: table_name.catalog.to_string(), + schema_name: table_name.schema.to_string(), + table_name: table_name.table.to_string(), + columns_values, + region_number: 0, + }) + } +} diff --git a/tests/cases/standalone/common/insert/insert_select.result b/tests/cases/standalone/common/insert/insert_select.result new file mode 100644 index 0000000000..209502d9b0 --- /dev/null +++ b/tests/cases/standalone/common/insert/insert_select.result @@ -0,0 +1,45 @@ +create table demo1(host string, cpu double, memory double, ts timestamp time index); + +Affected Rows: 0 + +create table demo2(host string, cpu double, memory double, ts timestamp time index); + +Affected Rows: 0 + +insert into demo1(host, cpu, memory, ts) values ('host1', 66.6, 1024, 1655276557000), ('host2', 88.8, 333.3, 1655276558000); + +Affected Rows: 2 + +insert into demo2(host) select * from demo1; + +Error: 3000(PlanQuery), Error during planning: Column count doesn't match insert query! + +insert into demo2 select cpu,memory from demo1; + +Error: 3000(PlanQuery), Error during planning: Column count doesn't match insert query! + +insert into demo2(ts) select memory from demo1; + +Error: 3000(PlanQuery), Error during planning: Cannot automatically convert Float64 to Timestamp(Millisecond, None) + +insert into demo2 select * from demo1; + +Affected Rows: 2 + +select * from demo2 order by ts; + ++-------+------+--------+---------------------+ +| host | cpu | memory | ts | ++-------+------+--------+---------------------+ +| host1 | 66.6 | 1024.0 | 2022-06-15T07:02:37 | +| host2 | 88.8 | 333.3 | 2022-06-15T07:02:38 | ++-------+------+--------+---------------------+ + +drop table demo1; + +Affected Rows: 1 + +drop table demo2; + +Affected Rows: 1 + diff --git a/tests/cases/standalone/common/insert/insert_select.sql b/tests/cases/standalone/common/insert/insert_select.sql new file mode 100644 index 0000000000..27fd9373ee --- /dev/null +++ b/tests/cases/standalone/common/insert/insert_select.sql @@ -0,0 +1,19 @@ +create table demo1(host string, cpu double, memory double, ts timestamp time index); + +create table demo2(host string, cpu double, memory double, ts timestamp time index); + +insert into demo1(host, cpu, memory, ts) values ('host1', 66.6, 1024, 1655276557000), ('host2', 88.8, 333.3, 1655276558000); + +insert into demo2(host) select * from demo1; + +insert into demo2 select cpu,memory from demo1; + +insert into demo2(ts) select memory from demo1; + +insert into demo2 select * from demo1; + +select * from demo2 order by ts; + +drop table demo1; + +drop table demo2;
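
---

Note on the new call path: after this patch, `QueryEngine::execute` consumes the `LogicalPlan` by value and additionally takes a `QueryContextRef`, and an `INSERT ... SELECT` reaches the engine as a DataFusion `Dml` plan with `WriteOp::Insert`. A condensed caller-side sketch, using only names introduced or touched in this patch — the `insert_by_select` helper and the literal SQL are illustrative, not part of the change, and error handling is elided with `unwrap`:

    use common_query::Output;
    use query::parser::QueryLanguageParser;
    use query::QueryEngineRef;
    use session::context::QueryContext;

    // Illustrative helper: drive an "insert ... select" through the
    // refactored engine API end to end.
    async fn insert_by_select(engine: QueryEngineRef) -> Output {
        // An insert statement is now parsed and planned like any query;
        // TryFrom<&Statement> for DfStatement gained an Insert arm above.
        let stmt = QueryLanguageParser::parse_sql(
            "insert into demo2 select * from demo1",
        )
        .unwrap();

        let query_ctx = QueryContext::arc();
        let plan = engine
            .planner()
            .plan(stmt, query_ctx.clone())
            .await
            .unwrap();

        // For a Dml plan with WriteOp::Insert, `execute` runs the SELECT
        // input, converts each RecordBatch into an InsertRequest via
        // InsertRequest::try_from_recordbatch (casting vectors to the target
        // column types where needed), calls Table::insert per batch, and
        // returns Output::AffectedRows with the summed row count.
        engine.execute(plan, query_ctx).await.unwrap()
    }

Plain `INSERT ... VALUES` statements are unaffected by this routing: they still go through `SqlHandler::insert_to_request` and `SqlHandler::insert` on the datanode without touching the query engine, which is why only statements where `insert.is_insert_select()` holds are sent down the planner path.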