From 550c494d25f536137fbb1b0adededbb4608b67d3 Mon Sep 17 00:00:00 2001
From: Weny Xu
Date: Tue, 9 May 2023 18:46:16 +0900
Subject: [PATCH] fix: Copy from must follow the order of table fields issue (#1521)

* fix: Copy from must follow the order of table fields issue

* chore: apply suggestion from CR
---
 src/frontend/src/error.rs                     |  10 +-
 src/frontend/src/statement/copy_table_from.rs | 257 ++++++++++++++++--
 src/sql/src/parsers/copy_parser.rs            |   8 +-
 3 files changed, 248 insertions(+), 27 deletions(-)

diff --git a/src/frontend/src/error.rs b/src/frontend/src/error.rs
index a34661ed29..c2b6c7d49c 100644
--- a/src/frontend/src/error.rs
+++ b/src/frontend/src/error.rs
@@ -16,6 +16,7 @@ use std::any::Any;
 
 use common_error::prelude::*;
 use datafusion::parquet;
+use datatypes::arrow::error::ArrowError;
 use datatypes::value::Value;
 use snafu::Location;
 use store_api::storage::RegionId;
@@ -523,6 +524,12 @@ pub enum Error {
         location: Location,
     },
 
+    #[snafu(display("Failed to project schema: {}", source))]
+    ProjectSchema {
+        source: ArrowError,
+        location: Location,
+    },
+
     #[snafu(display("Failed to encode object into json, source: {}", source))]
     EncodeJson {
         source: serde_json::error::Error,
@@ -556,7 +563,8 @@ impl ErrorExt for Error {
             | Error::BuildRegex { .. }
             | Error::InvalidSchema { .. }
             | Error::PrepareImmutableTable { .. }
-            | Error::BuildCsvConfig { .. } => StatusCode::InvalidArguments,
+            | Error::BuildCsvConfig { .. }
+            | Error::ProjectSchema { .. } => StatusCode::InvalidArguments,
 
             Error::NotSupported { .. } => StatusCode::Unsupported,
 
diff --git a/src/frontend/src/statement/copy_table_from.rs b/src/frontend/src/statement/copy_table_from.rs
index 7c4e442719..9ad4b7b06e 100644
--- a/src/frontend/src/statement/copy_table_from.rs
+++ b/src/frontend/src/statement/copy_table_from.rs
@@ -32,7 +32,7 @@ use datafusion::datasource::object_store::ObjectStoreUrl;
 use datafusion::parquet::arrow::ParquetRecordBatchStreamBuilder;
 use datafusion::physical_plan::file_format::{FileOpener, FileScanConfig, FileStream};
 use datafusion::physical_plan::metrics::ExecutionPlanMetricsSet;
-use datatypes::arrow::datatypes::{DataType, SchemaRef};
+use datatypes::arrow::datatypes::{DataType, Schema, SchemaRef};
 use datatypes::vectors::Helper;
 use futures_util::StreamExt;
 use object_store::{Entry, EntryMode, Metakey, ObjectStore};
@@ -146,12 +146,14 @@ impl StatementExecutor {
         object_store: ObjectStore,
         path: &str,
         schema: SchemaRef,
+        projection: Vec<usize>,
     ) -> Result<DfSendableRecordBatchStream> {
         match format {
             Format::Csv(format) => {
                 let csv_conf = CsvConfigBuilder::default()
                     .batch_size(DEFAULT_BATCH_SIZE)
                     .file_schema(schema.clone())
+                    .file_projection(Some(projection))
                     .build()
                     .context(error::BuildCsvConfigSnafu)?;
 
@@ -163,10 +165,16 @@ impl StatementExecutor {
                 .await
             }
             Format::Json(format) => {
+                let projected_schema = Arc::new(
+                    schema
+                        .project(&projection)
+                        .context(error::ProjectSchemaSnafu)?,
+                );
+
                 self.build_file_stream(
                     JsonOpener::new(
                         DEFAULT_BATCH_SIZE,
-                        schema.clone(),
+                        projected_schema,
                         object_store,
                         format.compression_type,
                     ),
@@ -206,17 +214,10 @@ impl StatementExecutor {
 
         let format = Format::try_from(&req.with).context(error::ParseFileFormatSnafu)?;
 
-        let fields = table
-            .schema()
-            .arrow_schema()
-            .fields()
-            .iter()
-            .map(|f| f.name().to_string())
-            .collect::<Vec<_>>();
-
         let (object_store, entries) = self.list_copy_from_entries(&req).await?;
 
         let mut files = Vec::with_capacity(entries.len());
+        let table_schema = table.schema().arrow_schema().clone();
 
         for entry in entries.iter() {
             let metadata = object_store
@@ -230,18 +231,52 @@ impl StatementExecutor {
             let file_schema = self
                 .infer_schema(&format, object_store.clone(), path)
                 .await?;
+            let (file_schema_projection, table_schema_projection, compat_schema) =
+                generated_schema_projection_and_compatible_file_schema(&file_schema, &table_schema);
 
-            ensure_schema_matches_ignore_timezone(&file_schema, table.schema().arrow_schema())?;
+            let projected_file_schema = Arc::new(
+                file_schema
+                    .project(&file_schema_projection)
+                    .context(error::ProjectSchemaSnafu)?,
+            );
+            let projected_table_schema = Arc::new(
+                table_schema
+                    .project(&table_schema_projection)
+                    .context(error::ProjectSchemaSnafu)?,
+            );
 
-            files.push((file_schema, path))
+            ensure_schema_matches_ignore_timezone(
+                &projected_file_schema,
+                &projected_table_schema,
+                true,
+            )?;
+
+            files.push((
+                Arc::new(compat_schema),
+                file_schema_projection,
+                projected_table_schema,
+                path,
+            ))
         }
 
         let mut rows_inserted = 0;
-        for (schema, path) in files {
+        for (schema, file_schema_projection, projected_table_schema, path) in files {
             let mut stream = self
-                .build_read_stream(&format, object_store.clone(), path, schema)
+                .build_read_stream(
+                    &format,
+                    object_store.clone(),
+                    path,
+                    schema,
+                    file_schema_projection,
+                )
                 .await?;
 
+            let fields = projected_table_schema
+                .fields()
+                .iter()
+                .map(|f| f.name().to_string())
+                .collect::<Vec<_>>();
+
             // TODO(hl): make this configurable through options.
             let pending_mem_threshold = ReadableSize::mb(32).as_bytes();
             let mut pending_mem_size = 0;
@@ -301,14 +336,18 @@ async fn batch_insert(
     Ok(res)
 }
 
-fn ensure_schema_matches_ignore_timezone(left: &SchemaRef, right: &SchemaRef) -> Result<()> {
+fn ensure_schema_matches_ignore_timezone(
+    left: &SchemaRef,
+    right: &SchemaRef,
+    ts_cast: bool,
+) -> Result<()> {
     let not_match = left
         .fields
         .iter()
         .zip(right.fields.iter())
        .map(|(l, r)| (l.data_type(), r.data_type()))
         .enumerate()
-        .find(|(_, (l, r))| !data_type_equals_ignore_timezone(l, r));
+        .find(|(_, (l, r))| !data_type_equals_ignore_timezone_with_options(l, r, ts_cast));
 
     if let Some((index, _)) = not_match {
         error::InvalidSchemaSnafu {
@@ -322,33 +361,78 @@ fn ensure_schema_matches_ignore_timezone(left: &SchemaRef, right: &SchemaRef) ->
     }
 }
 
-fn data_type_equals_ignore_timezone(l: &DataType, r: &DataType) -> bool {
+fn data_type_equals_ignore_timezone_with_options(
+    l: &DataType,
+    r: &DataType,
+    ts_cast: bool,
+) -> bool {
     match (l, r) {
         (DataType::List(a), DataType::List(b))
         | (DataType::LargeList(a), DataType::LargeList(b)) => {
             a.is_nullable() == b.is_nullable()
-                && data_type_equals_ignore_timezone(a.data_type(), b.data_type())
+                && data_type_equals_ignore_timezone_with_options(
+                    a.data_type(),
+                    b.data_type(),
+                    ts_cast,
+                )
         }
         (DataType::FixedSizeList(a, a_size), DataType::FixedSizeList(b, b_size)) => {
             a_size == b_size
                 && a.is_nullable() == b.is_nullable()
-                && data_type_equals_ignore_timezone(a.data_type(), b.data_type())
+                && data_type_equals_ignore_timezone_with_options(
+                    a.data_type(),
+                    b.data_type(),
+                    ts_cast,
+                )
         }
         (DataType::Struct(a), DataType::Struct(b)) => {
             a.len() == b.len()
                 && a.iter().zip(b).all(|(a, b)| {
                     a.is_nullable() == b.is_nullable()
-                        && data_type_equals_ignore_timezone(a.data_type(), b.data_type())
+                        && data_type_equals_ignore_timezone_with_options(
+                            a.data_type(),
+                            b.data_type(),
+                            ts_cast,
+                        )
                 })
         }
         (DataType::Map(a_field, a_is_sorted), DataType::Map(b_field, b_is_sorted)) => {
             a_field == b_field && a_is_sorted == b_is_sorted
         }
-        (DataType::Timestamp(l_unit, _), DataType::Timestamp(r_unit, _)) => l_unit == r_unit,
+        (DataType::Timestamp(l_unit, _), DataType::Timestamp(r_unit, _)) => {
+            l_unit == r_unit || ts_cast
+        }
+        (&DataType::Utf8, DataType::Timestamp(_, _))
+        | (DataType::Timestamp(_, _), &DataType::Utf8) => ts_cast,
         _ => l == r,
     }
 }
 
+/// Allows the file schema to be a subset of the table schema.
+fn generated_schema_projection_and_compatible_file_schema(
+    file: &SchemaRef,
+    table: &SchemaRef,
+) -> (Vec<usize>, Vec<usize>, Schema) {
+    let mut file_projection = Vec::with_capacity(file.fields.len());
+    let mut table_projection = Vec::with_capacity(file.fields.len());
+    let mut compatible_fields = file.fields.iter().cloned().collect::<Vec<_>>();
+    for (file_idx, file_field) in file.fields.iter().enumerate() {
+        if let Some((table_idx, table_field)) = table.fields.find(file_field.name()) {
+            file_projection.push(file_idx);
+            table_projection.push(table_idx);
+
+            // Safety: `compatible_fields` has the same length as the file schema.
+            compatible_fields[file_idx] = table_field.clone();
+        }
+    }
+
+    (
+        file_projection,
+        table_projection,
+        Schema::new(compatible_fields),
+    )
+}
+
 #[cfg(test)]
 mod tests {
     use std::sync::Arc;
@@ -358,9 +442,18 @@ mod tests {
     use super::*;
 
     fn test_schema_matches(l: (DataType, bool), r: (DataType, bool), matches: bool) {
+        test_schema_matches_with_options(l, r, false, matches)
+    }
+
+    fn test_schema_matches_with_options(
+        l: (DataType, bool),
+        r: (DataType, bool),
+        ts_cast: bool,
+        matches: bool,
+    ) {
         let s1 = Arc::new(Schema::new(vec![Field::new("col", l.0, l.1)]));
         let s2 = Arc::new(Schema::new(vec![Field::new("col", r.0, r.1)]));
-        let res = ensure_schema_matches_ignore_timezone(&s1, &s2);
+        let res = ensure_schema_matches_ignore_timezone(&s1, &s2, ts_cast);
         assert_eq!(matches, res.is_ok())
     }
 
@@ -433,4 +526,124 @@ mod tests {
 
         test_schema_matches((DataType::Int8, true), (DataType::Int16, true), false);
     }
+
+    #[test]
+    fn test_data_type_equals_ignore_timezone_with_options() {
+        test_schema_matches_with_options(
+            (
+                DataType::Timestamp(
+                    datatypes::arrow::datatypes::TimeUnit::Microsecond,
+                    Some("UTC".into()),
+                ),
+                true,
+            ),
+            (
+                DataType::Timestamp(
+                    datatypes::arrow::datatypes::TimeUnit::Millisecond,
+                    Some("PDT".into()),
+                ),
+                true,
+            ),
+            true,
+            true,
+        );
+
+        test_schema_matches_with_options(
+            (DataType::Utf8, true),
+            (
+                DataType::Timestamp(
+                    datatypes::arrow::datatypes::TimeUnit::Millisecond,
+                    Some("PDT".into()),
+                ),
+                true,
+            ),
+            true,
+            true,
+        );
+
+        test_schema_matches_with_options(
+            (
+                DataType::Timestamp(
+                    datatypes::arrow::datatypes::TimeUnit::Millisecond,
+                    Some("PDT".into()),
+                ),
+                true,
+            ),
+            (DataType::Utf8, true),
+            true,
+            true,
+        );
+    }
+
+    fn make_test_schema(v: &[Field]) -> Arc<Schema> {
+        Arc::new(Schema::new(v.to_vec()))
+    }
+
+    #[test]
+    fn test_compatible_file_schema() {
+        let file_schema0 = make_test_schema(&[
+            Field::new("c1", DataType::UInt8, true),
+            Field::new("c2", DataType::UInt8, true),
+        ]);
+
+        let table_schema = make_test_schema(&[
+            Field::new("c1", DataType::Int16, true),
+            Field::new("c2", DataType::Int16, true),
+            Field::new("c3", DataType::Int16, true),
+        ]);
+
+        let compat_schema = make_test_schema(&[
+            Field::new("c1", DataType::Int16, true),
+            Field::new("c2", DataType::Int16, true),
+        ]);
+
+        let (_, tp, _) =
+            generated_schema_projection_and_compatible_file_schema(&file_schema0, &table_schema);
+
+        assert_eq!(table_schema.project(&tp).unwrap(), *compat_schema);
+    }
+
+    #[test]
+    fn test_schema_projection() {
+        let file_schema0 = make_test_schema(&[
+            Field::new("c1", DataType::UInt8, true),
+            Field::new("c2", DataType::UInt8, true),
Field::new("c3", DataType::UInt8, true), + ]); + + let file_schema1 = make_test_schema(&[ + Field::new("c3", DataType::UInt8, true), + Field::new("c4", DataType::UInt8, true), + ]); + + let file_schema2 = make_test_schema(&[ + Field::new("c3", DataType::UInt8, true), + Field::new("c4", DataType::UInt8, true), + Field::new("c5", DataType::UInt8, true), + ]); + + let file_schema3 = make_test_schema(&[ + Field::new("c1", DataType::UInt8, true), + Field::new("c2", DataType::UInt8, true), + ]); + + let table_schema = make_test_schema(&[ + Field::new("c3", DataType::UInt8, true), + Field::new("c4", DataType::UInt8, true), + Field::new("c5", DataType::UInt8, true), + ]); + + let tests = [ + (&file_schema0, &table_schema, true), // intersection + (&file_schema1, &table_schema, true), // subset + (&file_schema2, &table_schema, true), // full-eq + (&file_schema3, &table_schema, true), // non-intersection + ]; + + for test in tests { + let (fp, tp, _) = + generated_schema_projection_and_compatible_file_schema(test.0, test.1); + assert_eq!(test.0.project(&fp).unwrap(), test.1.project(&tp).unwrap()); + } + } } diff --git a/src/sql/src/parsers/copy_parser.rs b/src/sql/src/parsers/copy_parser.rs index da0683cb29..0405ddac03 100644 --- a/src/sql/src/parsers/copy_parser.rs +++ b/src/sql/src/parsers/copy_parser.rs @@ -68,7 +68,7 @@ impl<'a> ParserContext<'a> { let with = options .into_iter() .filter_map(|option| { - parse_option_string(option.value).map(|v| (option.name.to_string(), v)) + parse_option_string(option.value).map(|v| (option.name.value.to_uppercase(), v)) }) .collect(); @@ -80,7 +80,7 @@ impl<'a> ParserContext<'a> { let connection = connection_options .into_iter() .filter_map(|option| { - parse_option_string(option.value).map(|v| (option.name.to_string(), v)) + parse_option_string(option.value).map(|v| (option.name.value.to_uppercase(), v)) }) .collect(); Ok(CopyTableArgument { @@ -109,7 +109,7 @@ impl<'a> ParserContext<'a> { let with = options .into_iter() .filter_map(|option| { - parse_option_string(option.value).map(|v| (option.name.to_string(), v)) + parse_option_string(option.value).map(|v| (option.name.value.to_uppercase(), v)) }) .collect(); @@ -121,7 +121,7 @@ impl<'a> ParserContext<'a> { let connection = connection_options .into_iter() .filter_map(|option| { - parse_option_string(option.value).map(|v| (option.name.to_string(), v)) + parse_option_string(option.value).map(|v| (option.name.value.to_uppercase(), v)) }) .collect();