From a9c8584c987318a41edd7fe4e2ebc382538db7af Mon Sep 17 00:00:00 2001 From: dennis zhuang Date: Fri, 17 Feb 2023 17:56:12 +0800 Subject: [PATCH] feat: impl insert data from query (#1025) * feat: refactor insertion in datanode * feat: supports inserting data by select query * feat: impl cast operation for vector * feat: streaming insert from select query results * chore: minor changes * fix: remove unwrap * test: insert_to_requsts * test: test_execute_insert_by_select * fix: cast operation for vectors * fix: test * fix: typo * chore: by CR comments * fix: test_statement_to_request --- Cargo.lock | 1 + src/datanode/Cargo.toml | 1 + src/datanode/src/error.rs | 131 +++++++----- src/datanode/src/instance/sql.rs | 38 +++- src/datanode/src/sql.rs | 70 ++++++- src/datanode/src/sql/insert.rs | 209 ++++++++++++++++-- src/datanode/src/tests/instance_test.rs | 69 ++++++ src/datanode/src/tests/test_util.rs | 2 +- src/datatypes/src/error.rs | 7 + src/datatypes/src/vectors/constant.rs | 7 + src/datatypes/src/vectors/operations.rs | 32 ++- src/datatypes/src/vectors/operations/cast.rs | 210 +++++++++++++++++++ src/frontend/src/error.rs | 20 +- src/frontend/src/sql.rs | 5 +- src/sql/src/lib.rs | 2 +- src/sql/src/statements/insert.rs | 84 ++++++-- 16 files changed, 769 insertions(+), 119 deletions(-) create mode 100644 src/datatypes/src/vectors/operations/cast.rs diff --git a/Cargo.lock b/Cargo.lock index 5f6dd175fc..8d3f5a56c4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2187,6 +2187,7 @@ dependencies = [ "common-time", "datafusion", "datafusion-common", + "datafusion-expr", "datatypes", "futures", "humantime-serde", diff --git a/src/datanode/Cargo.toml b/src/datanode/Cargo.toml index 3482f1c54e..10c6430c02 100644 --- a/src/datanode/Cargo.toml +++ b/src/datanode/Cargo.toml @@ -27,6 +27,7 @@ common-runtime = { path = "../common/runtime" } common-telemetry = { path = "../common/telemetry" } common-time = { path = "../common/time" } datafusion.workspace = true +datafusion-expr.workspace = true datatypes = { path = "../datatypes" } futures = "0.3" hyper = { version = "0.14", features = ["full"] } diff --git a/src/datanode/src/error.rs b/src/datanode/src/error.rs index 208d9eca25..052ee7075c 100644 --- a/src/datanode/src/error.rs +++ b/src/datanode/src/error.rs @@ -15,6 +15,8 @@ use std::any::Any; use common_error::prelude::*; +use common_recordbatch::error::Error as RecordBatchError; +use datatypes::prelude::ConcreteDataType; use storage::error::Error as StorageError; use table::error::Error as TableError; @@ -101,12 +103,33 @@ pub enum Error { ))] ColumnValuesNumberMismatch { columns: usize, values: usize }, + #[snafu(display( + "Column type mismatch, column: {}, expected type: {:?}, actual: {:?}", + column, + expected, + actual, + ))] + ColumnTypeMismatch { + column: String, + expected: ConcreteDataType, + actual: ConcreteDataType, + }, + + #[snafu(display("Failed to collect record batch, source: {}", source))] + CollectRecords { + #[snafu(backtrace)] + source: RecordBatchError, + }, + #[snafu(display("Failed to parse sql value, source: {}", source))] ParseSqlValue { #[snafu(backtrace)] source: sql::error::Error, }, + #[snafu(display("Missing insert body"))] + MissingInsertBody { backtrace: Backtrace }, + #[snafu(display("Failed to insert value to table: {}, source: {}", table_name, source))] Insert { table_name: String, @@ -338,73 +361,71 @@ pub type Result = std::result::Result; impl ErrorExt for Error { fn status_code(&self) -> StatusCode { + use Error::*; match self { - Error::ExecuteSql { source } | Error::DescribeStatement { source } => { + ExecuteSql { source } | DescribeStatement { source } => source.status_code(), + DecodeLogicalPlan { source } => source.status_code(), + NewCatalog { source } | RegisterSchema { source } => source.status_code(), + FindTable { source, .. } => source.status_code(), + CreateTable { source, .. } | GetTable { source, .. } | AlterTable { source, .. } => { source.status_code() } - Error::DecodeLogicalPlan { source } => source.status_code(), - Error::NewCatalog { source } | Error::RegisterSchema { source } => source.status_code(), - Error::FindTable { source, .. } => source.status_code(), - Error::CreateTable { source, .. } - | Error::GetTable { source, .. } - | Error::AlterTable { source, .. } => source.status_code(), - Error::DropTable { source, .. } => source.status_code(), + DropTable { source, .. } => source.status_code(), - Error::Insert { source, .. } => source.status_code(), - Error::Delete { source, .. } => source.status_code(), + Insert { source, .. } => source.status_code(), + Delete { source, .. } => source.status_code(), + CollectRecords { source, .. } => source.status_code(), - Error::TableNotFound { .. } => StatusCode::TableNotFound, - Error::ColumnNotFound { .. } => StatusCode::TableColumnNotFound, + TableNotFound { .. } => StatusCode::TableNotFound, + ColumnNotFound { .. } => StatusCode::TableColumnNotFound, - Error::ParseSqlValue { source, .. } | Error::ParseSql { source, .. } => { - source.status_code() - } + ParseSqlValue { source, .. } | ParseSql { source, .. } => source.status_code(), - Error::AlterExprToRequest { source, .. } - | Error::CreateExprToRequest { source } - | Error::InsertData { source } => source.status_code(), + AlterExprToRequest { source, .. } + | CreateExprToRequest { source } + | InsertData { source } => source.status_code(), - Error::ConvertSchema { source, .. } | Error::VectorComputation { source } => { - source.status_code() - } + ConvertSchema { source, .. } | VectorComputation { source } => source.status_code(), - Error::ColumnValuesNumberMismatch { .. } - | Error::InvalidSql { .. } - | Error::NotSupportSql { .. } - | Error::KeyColumnNotFound { .. } - | Error::IllegalPrimaryKeysDef { .. } - | Error::MissingTimestampColumn { .. } - | Error::CatalogNotFound { .. } - | Error::SchemaNotFound { .. } - | Error::ConstraintNotSupported { .. } - | Error::SchemaExists { .. } - | Error::ParseTimestamp { .. } - | Error::DatabaseNotFound { .. } => StatusCode::InvalidArguments, + ColumnValuesNumberMismatch { .. } + | ColumnTypeMismatch { .. } + | InvalidSql { .. } + | NotSupportSql { .. } + | KeyColumnNotFound { .. } + | IllegalPrimaryKeysDef { .. } + | MissingTimestampColumn { .. } + | CatalogNotFound { .. } + | SchemaNotFound { .. } + | ConstraintNotSupported { .. } + | SchemaExists { .. } + | ParseTimestamp { .. } + | MissingInsertBody { .. } + | DatabaseNotFound { .. } + | MissingNodeId { .. } + | MissingMetasrvOpts { .. } + | ColumnNoneDefaultValue { .. } => StatusCode::InvalidArguments, // TODO(yingwen): Further categorize http error. - Error::StartServer { .. } - | Error::ParseAddr { .. } - | Error::TcpBind { .. } - | Error::StartGrpc { .. } - | Error::CreateDir { .. } - | Error::InsertSystemCatalog { .. } - | Error::RenameTable { .. } - | Error::Catalog { .. } - | Error::MissingRequiredField { .. } - | Error::IncorrectInternalState { .. } => StatusCode::Internal, + StartServer { .. } + | ParseAddr { .. } + | TcpBind { .. } + | StartGrpc { .. } + | CreateDir { .. } + | InsertSystemCatalog { .. } + | RenameTable { .. } + | Catalog { .. } + | MissingRequiredField { .. } + | IncorrectInternalState { .. } => StatusCode::Internal, - Error::InitBackend { .. } => StatusCode::StorageUnavailable, - Error::OpenLogStore { source } => source.status_code(), - Error::StartScriptManager { source } => source.status_code(), - Error::OpenStorageEngine { source } => source.status_code(), - Error::RuntimeResource { .. } => StatusCode::RuntimeResourcesExhausted, - Error::MetaClientInit { source, .. } => source.status_code(), - Error::TableIdProviderNotFound { .. } => StatusCode::Unsupported, - Error::BumpTableId { source, .. } => source.status_code(), - Error::MissingNodeId { .. } => StatusCode::InvalidArguments, - Error::MissingMetasrvOpts { .. } => StatusCode::InvalidArguments, - Error::ColumnDefaultValue { source, .. } => source.status_code(), - Error::ColumnNoneDefaultValue { .. } => StatusCode::InvalidArguments, + InitBackend { .. } => StatusCode::StorageUnavailable, + OpenLogStore { source } => source.status_code(), + StartScriptManager { source } => source.status_code(), + OpenStorageEngine { source } => source.status_code(), + RuntimeResource { .. } => StatusCode::RuntimeResourcesExhausted, + MetaClientInit { source, .. } => source.status_code(), + TableIdProviderNotFound { .. } => StatusCode::Unsupported, + BumpTableId { source, .. } => source.status_code(), + ColumnDefaultValue { source, .. } => source.status_code(), } } diff --git a/src/datanode/src/instance/sql.rs b/src/datanode/src/instance/sql.rs index 367f2abba2..28441efa55 100644 --- a/src/datanode/src/instance/sql.rs +++ b/src/datanode/src/instance/sql.rs @@ -21,6 +21,7 @@ use common_recordbatch::RecordBatches; use common_telemetry::logging::info; use common_telemetry::timer; use datatypes::schema::Schema; +use futures::StreamExt; use query::parser::{PromQuery, QueryLanguageParser, QueryStatement}; use servers::error as server_error; use servers::promql::PromqlHandler; @@ -35,6 +36,7 @@ use table::requests::{CreateDatabaseRequest, DropTableRequest}; use crate::error::{self, BumpTableIdSnafu, ExecuteSqlSnafu, Result, TableIdProviderNotFoundSnafu}; use crate::instance::Instance; use crate::metric; +use crate::sql::insert::InsertRequests; use crate::sql::SqlRequest; impl Instance { @@ -56,15 +58,33 @@ impl Instance { .context(ExecuteSqlSnafu) } QueryStatement::Sql(Statement::Insert(i)) => { - let (catalog, schema, table) = - table_idents_to_full_name(i.table_name(), query_ctx.clone())?; - let table_ref = TableReference::full(&catalog, &schema, &table); - let request = self.sql_handler.insert_to_request( - self.catalog_manager.clone(), - *i, - table_ref, - )?; - self.sql_handler.execute(request, query_ctx).await + let requests = self + .sql_handler + .insert_to_requests(self.catalog_manager.clone(), *i, query_ctx.clone()) + .await?; + + match requests { + InsertRequests::Request(request) => { + self.sql_handler.execute(request, query_ctx.clone()).await + } + + InsertRequests::Stream(mut s) => { + let mut rows = 0; + while let Some(request) = s.next().await { + match self + .sql_handler + .execute(request?, query_ctx.clone()) + .await? + { + Output::AffectedRows(n) => { + rows += n; + } + _ => unreachable!(), + } + } + Ok(Output::AffectedRows(rows)) + } + } } QueryStatement::Sql(Statement::Delete(d)) => { let request = SqlRequest::Delete(*d); diff --git a/src/datanode/src/sql.rs b/src/datanode/src/sql.rs index bd08b873d5..beea10e72a 100644 --- a/src/datanode/src/sql.rs +++ b/src/datanode/src/sql.rs @@ -34,7 +34,7 @@ mod alter; mod create; mod delete; mod drop_table; -mod insert; +pub(crate) mod insert; #[derive(Debug)] pub enum SqlRequest { @@ -142,6 +142,7 @@ mod tests { use datatypes::prelude::ConcreteDataType; use datatypes::schema::{ColumnSchema, SchemaBuilder, SchemaRef}; use datatypes::value::Value; + use futures::StreamExt; use log_store::NoopLogStore; use mito::config::EngineConfig as TableEngineConfig; use mito::engine::MitoEngine; @@ -149,17 +150,19 @@ mod tests { use object_store::ObjectStore; use query::parser::{QueryLanguageParser, QueryStatement}; use query::QueryEngineFactory; + use session::context::QueryContext; use sql::statements::statement::Statement; use storage::compaction::noop::NoopCompactionScheduler; use storage::config::EngineConfig as StorageEngineConfig; use storage::EngineImpl; - use table::engine::TableReference; use table::error::Result as TableResult; use table::metadata::TableInfoRef; use table::Table; use tempdir::TempDir; use super::*; + use crate::error::Error; + use crate::sql::insert::InsertRequests; struct DemoTable; @@ -255,11 +258,12 @@ mod tests { } }; let request = sql_handler - .insert_to_request(catalog_list.clone(), *stmt, TableReference::bare("demo")) + .insert_to_requests(catalog_list.clone(), *stmt, QueryContext::arc()) + .await .unwrap(); match request { - SqlRequest::Insert(req) => { + InsertRequests::Request(SqlRequest::Insert(req)) => { assert_eq!(req.table_name, "demo"); let columns_values = req.columns_values; assert_eq!(4, columns_values.len()); @@ -294,5 +298,63 @@ mod tests { panic!("Not supposed to reach here") } } + + // test inert into select + + // type mismatch + let sql = "insert into demo(ts) select number from numbers limit 3"; + + let stmt = match QueryLanguageParser::parse_sql(sql).unwrap() { + QueryStatement::Sql(Statement::Insert(i)) => i, + _ => { + unreachable!() + } + }; + let request = sql_handler + .insert_to_requests(catalog_list.clone(), *stmt, QueryContext::arc()) + .await + .unwrap(); + + match request { + InsertRequests::Stream(mut stream) => { + assert!(matches!( + stream.next().await.unwrap().unwrap_err(), + Error::ColumnTypeMismatch { .. } + )); + } + _ => unreachable!(), + } + + let sql = "insert into demo(cpu) select cast(number as double) from numbers limit 3"; + let stmt = match QueryLanguageParser::parse_sql(sql).unwrap() { + QueryStatement::Sql(Statement::Insert(i)) => i, + _ => { + unreachable!() + } + }; + let request = sql_handler + .insert_to_requests(catalog_list.clone(), *stmt, QueryContext::arc()) + .await + .unwrap(); + + match request { + InsertRequests::Stream(mut stream) => { + let mut times = 0; + while let Some(Ok(SqlRequest::Insert(req))) = stream.next().await { + times += 1; + assert_eq!(req.table_name, "demo"); + let columns_values = req.columns_values; + assert_eq!(1, columns_values.len()); + + let memories = &columns_values["cpu"]; + assert_eq!(3, memories.len()); + assert_eq!(Value::from(0.0f64), memories.get(0)); + assert_eq!(Value::from(1.0f64), memories.get(1)); + assert_eq!(Value::from(2.0f64), memories.get(2)); + } + assert_eq!(1, times); + } + _ => unreachable!(), + } } } diff --git a/src/datanode/src/sql/insert.rs b/src/datanode/src/sql/insert.rs index 7b739bd260..3ee2ba5a96 100644 --- a/src/datanode/src/sql/insert.rs +++ b/src/datanode/src/sql/insert.rs @@ -11,28 +11,48 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. +use std::collections::HashMap; +use std::pin::Pin; use catalog::CatalogManagerRef; use common_query::Output; +use common_recordbatch::RecordBatch; +use datafusion_expr::type_coercion::binary::coerce_types; +use datafusion_expr::Operator; use datatypes::data_type::DataType; use datatypes::schema::ColumnSchema; use datatypes::vectors::MutableVector; +use futures::stream::{self, StreamExt}; +use futures::Stream; +use query::parser::QueryStatement; +use session::context::QueryContextRef; use snafu::{ensure, OptionExt, ResultExt}; use sql::ast::Value as SqlValue; use sql::statements::insert::Insert; +use sql::statements::statement::Statement; use sql::statements::{self}; use table::engine::TableReference; use table::requests::*; +use table::TableRef; use crate::error::{ - CatalogSnafu, ColumnDefaultValueSnafu, ColumnNoneDefaultValueSnafu, ColumnNotFoundSnafu, - ColumnValuesNumberMismatchSnafu, InsertSnafu, ParseSqlSnafu, ParseSqlValueSnafu, Result, - TableNotFoundSnafu, + CatalogSnafu, CollectRecordsSnafu, ColumnDefaultValueSnafu, ColumnNoneDefaultValueSnafu, + ColumnNotFoundSnafu, ColumnTypeMismatchSnafu, ColumnValuesNumberMismatchSnafu, Error, + ExecuteSqlSnafu, InsertSnafu, MissingInsertBodySnafu, ParseSqlSnafu, ParseSqlValueSnafu, + Result, TableNotFoundSnafu, }; -use crate::sql::{SqlHandler, SqlRequest}; +use crate::sql::{table_idents_to_full_name, SqlHandler, SqlRequest}; const DEFAULT_PLACEHOLDER_VALUE: &str = "default"; +type InsertRequestStream = Pin> + Send>>; +pub(crate) enum InsertRequests { + // Single request + Request(SqlRequest), + // Streaming requests + Stream(InsertRequestStream), +} + impl SqlHandler { pub(crate) async fn insert(&self, req: InsertRequest) -> Result { // FIXME(dennis): table_ref is used in InsertSnafu and the req is consumed @@ -52,21 +72,16 @@ impl SqlHandler { Ok(Output::AffectedRows(affected_rows)) } - pub(crate) fn insert_to_request( - &self, - catalog_manager: CatalogManagerRef, - stmt: Insert, + fn build_request_from_values( table_ref: TableReference, + table: &TableRef, + stmt: Insert, ) -> Result { + let values = stmt + .values_body() + .context(ParseSqlValueSnafu)? + .context(MissingInsertBodySnafu)?; let columns = stmt.columns(); - let values = stmt.values().context(ParseSqlValueSnafu)?; - - let table = catalog_manager - .table(table_ref.catalog, table_ref.schema, table_ref.table) - .context(CatalogSnafu)? - .context(TableNotFoundSnafu { - table_name: table_ref.table, - })?; let schema = table.schema(); let columns_num = if columns.is_empty() { schema.column_schemas().len() @@ -78,6 +93,7 @@ impl SqlHandler { let mut columns_builders: Vec<(&ColumnSchema, Box)> = Vec::with_capacity(columns_num); + // Initialize vectors if columns.is_empty() { for column_schema in schema.column_schemas() { let data_type = &column_schema.data_type; @@ -123,6 +139,167 @@ impl SqlHandler { region_number: 0, })) } + + fn build_request_from_batch( + stmt: Insert, + table: TableRef, + batch: RecordBatch, + query_ctx: QueryContextRef, + ) -> Result { + let (catalog_name, schema_name, table_name) = + table_idents_to_full_name(stmt.table_name(), query_ctx)?; + + let schema = table.schema(); + let columns: Vec<_> = if stmt.columns().is_empty() { + schema + .column_schemas() + .iter() + .map(|c| c.name.to_string()) + .collect() + } else { + stmt.columns().iter().map(|c| (*c).clone()).collect() + }; + let columns_num = columns.len(); + + ensure!( + batch.num_columns() == columns_num, + ColumnValuesNumberMismatchSnafu { + columns: columns_num, + values: batch.num_columns(), + } + ); + + let batch_schema = &batch.schema; + let batch_columns = batch_schema.column_schemas(); + assert_eq!(batch_columns.len(), columns_num); + let mut columns_values = HashMap::with_capacity(columns_num); + + for (i, column_name) in columns.into_iter().enumerate() { + let column_schema = schema + .column_schema_by_name(&column_name) + .with_context(|| ColumnNotFoundSnafu { + table_name: &table_name, + column_name: &column_name, + })?; + let expect_datatype = column_schema.data_type.as_arrow_type(); + // It's safe to retrieve the column schema by index, we already + // check columns number is the same above. + let batch_datatype = batch_columns[i].data_type.as_arrow_type(); + let coerced_type = coerce_types(&expect_datatype, &Operator::Eq, &batch_datatype) + .map_err(|_| Error::ColumnTypeMismatch { + column: column_name.clone(), + expected: column_schema.data_type.clone(), + actual: batch_columns[i].data_type.clone(), + })?; + + ensure!( + expect_datatype == coerced_type, + ColumnTypeMismatchSnafu { + column: column_name, + expected: column_schema.data_type.clone(), + actual: batch_columns[i].data_type.clone(), + } + ); + let vector = batch + .column(i) + .cast(&column_schema.data_type) + .map_err(|_| Error::ColumnTypeMismatch { + column: column_name.clone(), + expected: column_schema.data_type.clone(), + actual: batch_columns[i].data_type.clone(), + })?; + + columns_values.insert(column_name, vector); + } + + Ok(SqlRequest::Insert(InsertRequest { + catalog_name, + schema_name, + table_name, + columns_values, + region_number: 0, + })) + } + + // FIXME(dennis): move it to frontend when refactor is done. + async fn build_stream_from_query( + &self, + table: TableRef, + stmt: Insert, + query_ctx: QueryContextRef, + ) -> Result { + let query = stmt + .query_body() + .context(ParseSqlValueSnafu)? + .context(MissingInsertBodySnafu)?; + + let logical_plan = self + .query_engine + .statement_to_plan( + QueryStatement::Sql(Statement::Query(Box::new(query))), + query_ctx.clone(), + ) + .context(ExecuteSqlSnafu)?; + + let output = self + .query_engine + .execute(&logical_plan) + .await + .context(ExecuteSqlSnafu)?; + + let stream: InsertRequestStream = match output { + Output::RecordBatches(batches) => { + Box::pin(stream::iter(batches.take()).map(move |batch| { + Self::build_request_from_batch( + stmt.clone(), + table.clone(), + batch, + query_ctx.clone(), + ) + })) + } + + Output::Stream(stream) => Box::pin(stream.map(move |batch| { + Self::build_request_from_batch( + stmt.clone(), + table.clone(), + batch.context(CollectRecordsSnafu)?, + query_ctx.clone(), + ) + })), + _ => unreachable!(), + }; + + Ok(stream) + } + + pub(crate) async fn insert_to_requests( + &self, + catalog_manager: CatalogManagerRef, + stmt: Insert, + query_ctx: QueryContextRef, + ) -> Result { + let (catalog_name, schema_name, table_name) = + table_idents_to_full_name(stmt.table_name(), query_ctx.clone())?; + + let table = catalog_manager + .table(&catalog_name, &schema_name, &table_name) + .context(CatalogSnafu)? + .with_context(|| TableNotFoundSnafu { + table_name: table_name.clone(), + })?; + + if stmt.is_insert_select() { + Ok(InsertRequests::Stream( + self.build_stream_from_query(table, stmt, query_ctx).await?, + )) + } else { + let table_ref = TableReference::full(&catalog_name, &schema_name, &table_name); + Ok(InsertRequests::Request(Self::build_request_from_values( + table_ref, &table, stmt, + )?)) + } + } } fn add_row_to_vector( diff --git a/src/datanode/src/tests/instance_test.rs b/src/datanode/src/tests/instance_test.rs index 9060815811..c81e5f0aff 100644 --- a/src/datanode/src/tests/instance_test.rs +++ b/src/datanode/src/tests/instance_test.rs @@ -21,6 +21,7 @@ use datatypes::data_type::ConcreteDataType; use datatypes::vectors::{Int64Vector, StringVector, UInt64Vector, VectorRef}; use session::context::QueryContext; +use crate::error::Error; use crate::tests::test_util::{self, check_output_stream, setup_test_instance, MockInstance}; #[tokio::test(flavor = "multi_thread")] @@ -181,6 +182,67 @@ async fn test_execute_insert() { assert!(matches!(output, Output::AffectedRows(2))); } +#[tokio::test(flavor = "multi_thread")] +async fn test_execute_insert_by_select() { + let instance = setup_test_instance("test_execute_insert_by_select").await; + + // create table + execute_sql( + &instance, + "create table demo1(host string, cpu double, memory double, ts timestamp time index);", + ) + .await; + execute_sql( + &instance, + "create table demo2(host string, cpu double, memory double, ts timestamp time index);", + ) + .await; + + let output = execute_sql( + &instance, + r#"insert into demo1(host, cpu, memory, ts) values + ('host1', 66.6, 1024, 1655276557000), + ('host2', 88.8, 333.3, 1655276558000) + "#, + ) + .await; + assert!(matches!(output, Output::AffectedRows(2))); + + assert!(matches!( + try_execute_sql(&instance, "insert into demo2(host) select * from demo1") + .await + .unwrap_err(), + Error::ColumnValuesNumberMismatch { .. } + )); + assert!(matches!( + try_execute_sql(&instance, "insert into demo2 select cpu,memory from demo1") + .await + .unwrap_err(), + Error::ColumnValuesNumberMismatch { .. } + )); + + assert!(matches!( + try_execute_sql(&instance, "insert into demo2(ts) select memory from demo1") + .await + .unwrap_err(), + Error::ColumnTypeMismatch { .. } + )); + + let output = execute_sql(&instance, "insert into demo2 select * from demo1").await; + assert!(matches!(output, Output::AffectedRows(2))); + + let output = execute_sql(&instance, "select * from demo2 order by ts").await; + let expected = "\ ++-------+------+--------+---------------------+ +| host | cpu | memory | ts | ++-------+------+--------+---------------------+ +| host1 | 66.6 | 1024 | 2022-06-15T07:02:37 | +| host2 | 88.8 | 333.3 | 2022-06-15T07:02:38 | ++-------+------+--------+---------------------+" + .to_string(); + check_output_stream(output, expected).await; +} + #[tokio::test(flavor = "multi_thread")] async fn test_execute_insert_query_with_i64_timestamp() { let instance = MockInstance::new("insert_query_i64_timestamp").await; @@ -707,6 +769,13 @@ async fn execute_sql(instance: &MockInstance, sql: &str) -> Output { execute_sql_in_db(instance, sql, DEFAULT_SCHEMA_NAME).await } +async fn try_execute_sql( + instance: &MockInstance, + sql: &str, +) -> Result { + try_execute_sql_in_db(instance, sql, DEFAULT_SCHEMA_NAME).await +} + async fn try_execute_sql_in_db( instance: &MockInstance, sql: &str, diff --git a/src/datanode/src/tests/test_util.rs b/src/datanode/src/tests/test_util.rs index 5365ad103c..1084886fea 100644 --- a/src/datanode/src/tests/test_util.rs +++ b/src/datanode/src/tests/test_util.rs @@ -155,7 +155,7 @@ pub async fn check_output_stream(output: Output, expected: String) { _ => unreachable!(), }; let pretty_print = recordbatches.pretty_print().unwrap(); - assert_eq!(pretty_print, expected); + assert_eq!(pretty_print, expected, "{}", pretty_print); } pub async fn check_unordered_output_stream(output: Output, expected: String) { diff --git a/src/datatypes/src/error.rs b/src/datatypes/src/error.rs index 1b4c034c9d..53b81d7dd6 100644 --- a/src/datatypes/src/error.rs +++ b/src/datatypes/src/error.rs @@ -52,6 +52,13 @@ pub enum Error { backtrace: Backtrace, }, + #[snafu(display("Unsupported operation: {} for vector: {}", op, vector_type))] + UnsupportedOperation { + op: String, + vector_type: String, + backtrace: Backtrace, + }, + #[snafu(display("Timestamp column {} not found", name,))] TimestampNotFound { name: String, backtrace: Backtrace }, diff --git a/src/datatypes/src/vectors/constant.rs b/src/datatypes/src/vectors/constant.rs index da5ac16f24..a2e7bc76dc 100644 --- a/src/datatypes/src/vectors/constant.rs +++ b/src/datatypes/src/vectors/constant.rs @@ -76,6 +76,13 @@ impl ConstantVector { } Ok(Arc::new(ConstantVector::new(self.inner().clone(), length))) } + + pub(crate) fn cast_vector(&self, to_type: &ConcreteDataType) -> Result { + Ok(Arc::new(ConstantVector::new( + self.inner().cast(to_type)?, + self.length, + ))) + } } impl Vector for ConstantVector { diff --git a/src/datatypes/src/vectors/operations.rs b/src/datatypes/src/vectors/operations.rs index adb430c96a..11ff506bb8 100644 --- a/src/datatypes/src/vectors/operations.rs +++ b/src/datatypes/src/vectors/operations.rs @@ -12,18 +12,19 @@ // See the License for the specific language governing permissions and // limitations under the License. +mod cast; mod filter; mod find_unique; mod replicate; use common_base::BitVec; -use crate::error::Result; +use crate::error::{self, Result}; use crate::types::LogicalPrimitiveType; use crate::vectors::constant::ConstantVector; use crate::vectors::{ - BinaryVector, BooleanVector, ListVector, NullVector, PrimitiveVector, StringVector, Vector, - VectorRef, + BinaryVector, BooleanVector, ConcreteDataType, ListVector, NullVector, PrimitiveVector, + StringVector, Vector, VectorRef, }; /// Vector compute operations. @@ -57,6 +58,11 @@ pub trait VectorOp { /// /// Note that the nulls of `filter` are interpreted as `false` will lead to these elements being masked out. fn filter(&self, filter: &BooleanVector) -> Result; + + /// Cast vector to the provided data type and return a new vector with type to_type, if possible. + /// + /// TODO(dennis) describe behaviors in details. + fn cast(&self, to_type: &ConcreteDataType) -> Result; } macro_rules! impl_scalar_vector_op { @@ -74,6 +80,10 @@ macro_rules! impl_scalar_vector_op { fn filter(&self, filter: &BooleanVector) -> Result { filter::filter_non_constant!(self, $VectorType, filter) } + + fn cast(&self, to_type: &ConcreteDataType) -> Result { + cast::cast_non_constant!(self, to_type) + } } )+}; } @@ -94,6 +104,10 @@ impl VectorOp for PrimitiveVector { fn filter(&self, filter: &BooleanVector) -> Result { filter::filter_non_constant!(self, PrimitiveVector, filter) } + + fn cast(&self, to_type: &ConcreteDataType) -> Result { + cast::cast_non_constant!(self, to_type) + } } impl VectorOp for NullVector { @@ -109,6 +123,14 @@ impl VectorOp for NullVector { fn filter(&self, filter: &BooleanVector) -> Result { filter::filter_non_constant!(self, NullVector, filter) } + fn cast(&self, _to_type: &ConcreteDataType) -> Result { + // TODO(dennis): impl it when NullVector has other datatype. + error::UnsupportedOperationSnafu { + op: "cast", + vector_type: self.vector_type_name(), + } + .fail() + } } impl VectorOp for ConstantVector { @@ -124,4 +146,8 @@ impl VectorOp for ConstantVector { fn filter(&self, filter: &BooleanVector) -> Result { self.filter_vector(filter) } + + fn cast(&self, to_type: &ConcreteDataType) -> Result { + self.cast_vector(to_type) + } } diff --git a/src/datatypes/src/vectors/operations/cast.rs b/src/datatypes/src/vectors/operations/cast.rs new file mode 100644 index 0000000000..8929b54a96 --- /dev/null +++ b/src/datatypes/src/vectors/operations/cast.rs @@ -0,0 +1,210 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +macro_rules! cast_non_constant { + ($vector: expr, $to_type: expr) => {{ + use arrow::compute; + use snafu::ResultExt; + + use crate::data_type::DataType; + use crate::vectors::helper::Helper; + + let arrow_array = $vector.to_arrow_array(); + let casted = compute::cast(&arrow_array, &$to_type.as_arrow_type()) + .context(crate::error::ArrowComputeSnafu)?; + Helper::try_into_vector(casted) + }}; +} + +pub(crate) use cast_non_constant; + +/// There are already many test cases in arrow: +/// https://github.com/apache/arrow-rs/blob/59016e53e5cfa1d368009ed640d1f3dce326e7bb/arrow-cast/src/cast.rs#L3349-L7584 +/// So we don't(can't) want to copy these cases, just test some cases which are important for us. +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use common_time::date::Date; + use common_time::timestamp::{TimeUnit, Timestamp}; + + use crate::types::{LogicalPrimitiveType, *}; + use crate::vectors::{ConcreteDataType, *}; + + fn get_cast_values(vector: &VectorRef, dt: &ConcreteDataType) -> Vec + where + T: LogicalPrimitiveType, + { + let c = vector.cast(dt).unwrap(); + let a = c.as_any().downcast_ref::>().unwrap(); + let mut v: Vec = vec![]; + for i in 0..vector.len() { + if a.is_null(i) { + v.push("null".to_string()) + } else { + v.push(format!("{}", a.get(i))); + } + } + v + } + + #[test] + fn test_cast_from_f64() { + let f64_values: Vec = vec![ + i64::MIN as f64, + i32::MIN as f64, + i16::MIN as f64, + i8::MIN as f64, + 0_f64, + u8::MAX as f64, + u16::MAX as f64, + u32::MAX as f64, + u64::MAX as f64, + ]; + let f64_vector: VectorRef = Arc::new(Float64Vector::from_slice(&f64_values)); + + let f64_expected = vec![ + -9223372036854776000.0, + -2147483648.0, + -32768.0, + -128.0, + 0.0, + 255.0, + 65535.0, + 4294967295.0, + 18446744073709552000.0, + ]; + assert_eq!( + f64_expected, + get_cast_values::(&f64_vector, &ConcreteDataType::float64_datatype()) + .iter() + .map(|i| i.parse::().unwrap()) + .collect::>() + ); + + let i64_expected = vec![ + "-9223372036854775808", + "-2147483648", + "-32768", + "-128", + "0", + "255", + "65535", + "4294967295", + "null", + ]; + assert_eq!( + i64_expected, + get_cast_values::(&f64_vector, &ConcreteDataType::int64_datatype()) + ); + + let u64_expected = vec![ + "null", + "null", + "null", + "null", + "0", + "255", + "65535", + "4294967295", + "null", + ]; + assert_eq!( + u64_expected, + get_cast_values::(&f64_vector, &ConcreteDataType::uint64_datatype()) + ); + } + + #[test] + fn test_cast_from_date() { + let i32_values: Vec = vec![ + i32::MIN, + i16::MIN as i32, + i8::MIN as i32, + 0, + i8::MAX as i32, + i16::MAX as i32, + i32::MAX, + ]; + let date32_vector: VectorRef = Arc::new(DateVector::from_slice(&i32_values)); + + let i32_expected = vec![ + "-2147483648", + "-32768", + "-128", + "0", + "127", + "32767", + "2147483647", + ]; + assert_eq!( + i32_expected, + get_cast_values::(&date32_vector, &ConcreteDataType::int32_datatype()), + ); + + let i64_expected = vec![ + "-2147483648", + "-32768", + "-128", + "0", + "127", + "32767", + "2147483647", + ]; + assert_eq!( + i64_expected, + get_cast_values::(&date32_vector, &ConcreteDataType::int64_datatype()), + ); + } + + #[test] + fn test_cast_timestamp_to_date32() { + let vector = + TimestampMillisecondVector::from(vec![Some(864000000005), Some(1545696000001), None]); + let b = vector.cast(&ConcreteDataType::date_datatype()).unwrap(); + let c = b.as_any().downcast_ref::().unwrap(); + assert_eq!(Value::Date(Date::from(10000)), c.get(0)); + assert_eq!(Value::Date(Date::from(17890)), c.get(1)); + assert!(c.is_null(2)); + } + + #[test] + fn test_cast_string_to_timestamp() { + let a1 = Arc::new(StringVector::from(vec![ + Some("2020-09-08T12:00:00+00:00"), + Some("Not a valid date"), + None, + ])) as VectorRef; + let a2 = Arc::new(StringVector::from(vec![ + Some("2020-09-08T12:00:00+00:00"), + Some("Not a valid date"), + None, + ])) as VectorRef; + + for array in &[a1, a2] { + let to_type = ConcreteDataType::timestamp_nanosecond_datatype(); + let b = array.cast(&to_type).unwrap(); + let c = b + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!( + Value::Timestamp(Timestamp::new(1599566400000000000, TimeUnit::Nanosecond)), + c.get(0) + ); + assert!(c.is_null(1)); + assert!(c.is_null(2)); + } + } +} diff --git a/src/frontend/src/error.rs b/src/frontend/src/error.rs index cea0a91d74..d06ba38489 100644 --- a/src/frontend/src/error.rs +++ b/src/frontend/src/error.rs @@ -56,6 +56,9 @@ pub enum Error { source: sql::error::Error, }, + #[snafu(display("Missing insert values"))] + MissingInsertValues { backtrace: Backtrace }, + #[snafu(display("Column datatype error, source: {}", source))] ColumnDataType { #[snafu(backtrace)] @@ -356,7 +359,14 @@ impl ErrorExt for Error { | Error::InvalidSql { .. } | Error::InvalidInsertRequest { .. } | Error::ColumnValuesNumberMismatch { .. } - | Error::IllegalPrimaryKeysDef { .. } => StatusCode::InvalidArguments, + | Error::IllegalPrimaryKeysDef { .. } + | Error::CatalogNotFound { .. } + | Error::SchemaNotFound { .. } + | Error::SchemaExists { .. } + | Error::MissingInsertValues { .. } + | Error::PrimaryKeyNotFound { .. } + | Error::MissingMetasrvOpts { .. } + | Error::ColumnNoneDefaultValue { .. } => StatusCode::InvalidArguments, Error::NotSupported { .. } => StatusCode::Unsupported, @@ -399,26 +409,20 @@ impl ErrorExt for Error { Error::StartMetaClient { source } | Error::RequestMeta { source } => { source.status_code() } - Error::CatalogNotFound { .. } - | Error::SchemaNotFound { .. } - | Error::SchemaExists { .. } => StatusCode::InvalidArguments, - Error::BuildCreateExprOnInsertion { source } | Error::ToTableInsertRequest { source } | Error::FindNewColumnsOnInsertion { source } => source.status_code(), - Error::PrimaryKeyNotFound { .. } => StatusCode::InvalidArguments, Error::ExecuteStatement { source, .. } | Error::DescribeStatement { source } => { source.status_code() } - Error::MissingMetasrvOpts { .. } => StatusCode::InvalidArguments, Error::AlterExprToRequest { source, .. } => source.status_code(), Error::LeaderNotFound { .. } => StatusCode::StorageUnavailable, Error::TableAlreadyExist { .. } => StatusCode::TableAlreadyExists, Error::EncodeSubstraitLogicalPlan { source } => source.status_code(), Error::InvokeDatanode { source } => source.status_code(), Error::ColumnDefaultValue { source, .. } => source.status_code(), - Error::ColumnNoneDefaultValue { .. } => StatusCode::InvalidArguments, + Error::External { source } => source.status_code(), Error::DeserializePartition { source, .. } | Error::FindTableRoute { source, .. } => { source.status_code() diff --git a/src/frontend/src/sql.rs b/src/frontend/src/sql.rs index d1f74e0221..f2a766f663 100644 --- a/src/frontend/src/sql.rs +++ b/src/frontend/src/sql.rs @@ -38,7 +38,10 @@ pub(crate) fn insert_to_request( query_ctx: QueryContextRef, ) -> Result { let columns = stmt.columns(); - let values = stmt.values().context(error::ParseSqlSnafu)?; + let values = stmt + .values_body() + .context(error::ParseSqlSnafu)? + .context(error::MissingInsertValuesSnafu)?; let (catalog_name, schema_name, table_name) = table_idents_to_full_name(stmt.table_name(), query_ctx) diff --git a/src/sql/src/lib.rs b/src/sql/src/lib.rs index 2be9b3d31b..82e2dec217 100644 --- a/src/sql/src/lib.rs +++ b/src/sql/src/lib.rs @@ -11,7 +11,7 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. - +#![feature(box_patterns)] #![feature(assert_matches)] pub mod ast; diff --git a/src/sql/src/statements/insert.rs b/src/sql/src/statements/insert.rs index 6381c3686a..dd42db1120 100644 --- a/src/sql/src/statements/insert.rs +++ b/src/sql/src/statements/insert.rs @@ -11,12 +11,12 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. - -use sqlparser::ast::{ObjectName, SetExpr, Statement, UnaryOperator, Values}; +use sqlparser::ast::{ObjectName, Query, SetExpr, Statement, UnaryOperator, Values}; use sqlparser::parser::ParserError; use crate::ast::{Expr, Value}; use crate::error::{self, Result}; +use crate::statements::query::Query as GtQuery; #[derive(Debug, Clone, PartialEq, Eq)] pub struct Insert { @@ -39,16 +39,43 @@ impl Insert { } } - pub fn values(&self) -> Result>> { + pub fn values_body(&self) -> Result>>> { let values = match &self.inner { - Statement::Insert { source, .. } => match &*source.body { - SetExpr::Values(Values { rows, .. }) => sql_exprs_to_values(rows)?, - _ => unreachable!(), - }, - _ => unreachable!(), + Statement::Insert { + source: + box Query { + body: box SetExpr::Values(Values { rows, .. }), + .. + }, + .. + } => Some(sql_exprs_to_values(rows)?), + _ => None, }; + Ok(values) } + + pub fn query_body(&self) -> Result> { + Ok(match &self.inner { + Statement::Insert { + source: box query, .. + } => Some(query.clone().try_into()?), + _ => None, + }) + } + + pub fn is_insert_select(&self) -> bool { + matches!( + self.inner, + Statement::Insert { + source: box Query { + body: box SetExpr::Select { .. }, + .. + }, + .. + } + ) + } } fn sql_exprs_to_values(exprs: &Vec>) -> Result>> { @@ -113,11 +140,10 @@ mod tests { use super::*; use crate::parser::ParserContext; + use crate::statements::statement::Statement; #[test] fn test_insert_value_with_unary_op() { - use crate::statements::statement::Statement; - // insert "-1" let sql = "INSERT INTO my_table VALUES(-1)"; let stmt = ParserContext::create_with_dialect(sql, &GenericDialect {}) @@ -125,7 +151,7 @@ mod tests { .remove(0); match stmt { Statement::Insert(insert) => { - let values = insert.values().unwrap(); + let values = insert.values_body().unwrap().unwrap(); assert_eq!(values, vec![vec![Value::Number("-1".to_string(), false)]]); } _ => unreachable!(), @@ -138,7 +164,7 @@ mod tests { .remove(0); match stmt { Statement::Insert(insert) => { - let values = insert.values().unwrap(); + let values = insert.values_body().unwrap().unwrap(); assert_eq!(values, vec![vec![Value::Number("1".to_string(), false)]]); } _ => unreachable!(), @@ -147,8 +173,6 @@ mod tests { #[test] fn test_insert_value_with_default() { - use crate::statements::statement::Statement; - // insert "default" let sql = "INSERT INTO my_table VALUES(default)"; let stmt = ParserContext::create_with_dialect(sql, &GenericDialect {}) @@ -156,7 +180,7 @@ mod tests { .remove(0); match stmt { Statement::Insert(insert) => { - let values = insert.values().unwrap(); + let values = insert.values_body().unwrap().unwrap(); assert_eq!(values, vec![vec![Value::Placeholder("default".to_owned())]]); } _ => unreachable!(), @@ -165,8 +189,6 @@ mod tests { #[test] fn test_insert_value_with_default_uppercase() { - use crate::statements::statement::Statement; - // insert "DEFAULT" let sql = "INSERT INTO my_table VALUES(DEFAULT)"; let stmt = ParserContext::create_with_dialect(sql, &GenericDialect {}) @@ -174,7 +196,7 @@ mod tests { .remove(0); match stmt { Statement::Insert(insert) => { - let values = insert.values().unwrap(); + let values = insert.values_body().unwrap().unwrap(); assert_eq!(values, vec![vec![Value::Placeholder("DEFAULT".to_owned())]]); } _ => unreachable!(), @@ -183,8 +205,6 @@ mod tests { #[test] fn test_insert_value_with_quoted_string() { - use crate::statements::statement::Statement; - // insert "'default'" let sql = "INSERT INTO my_table VALUES('default')"; let stmt = ParserContext::create_with_dialect(sql, &GenericDialect {}) @@ -192,7 +212,7 @@ mod tests { .remove(0); match stmt { Statement::Insert(insert) => { - let values = insert.values().unwrap(); + let values = insert.values_body().unwrap().unwrap(); assert_eq!( values, vec![vec![Value::SingleQuotedString("default".to_owned())]] @@ -201,4 +221,26 @@ mod tests { _ => unreachable!(), } } + + #[test] + fn test_insert_select() { + let sql = "INSERT INTO my_table select * from other_table"; + let stmt = ParserContext::create_with_dialect(sql, &GenericDialect {}) + .unwrap() + .remove(0); + match stmt { + Statement::Insert(insert) => { + assert!(insert.is_insert_select()); + let q = insert.query_body().unwrap().unwrap(); + assert!(matches!( + q.inner, + Query { + body: box SetExpr::Select { .. }, + .. + } + )); + } + _ => unreachable!(), + } + } }