From d374859e243eaf6c3d8bd5cf9bb656f27bbaea63 Mon Sep 17 00:00:00 2001 From: Weny Xu Date: Sun, 23 Apr 2023 17:31:54 +0900 Subject: [PATCH] refactor: replace Copy Format with datasource Format (#1435) * refactor: replace Copy Format with datasource Format * chore: apply suggestions from CR * chore: apply suggestions from CR --- src/common/datasource/src/error.rs | 46 ++++++--- src/common/datasource/src/file_format.rs | 32 ++++++- .../datasource/src/file_format/parquet.rs | 3 +- .../datasource/src/file_format/tests.rs | 39 ++++++++ src/common/datasource/src/lib.rs | 2 + src/frontend/src/statement.rs | 5 +- src/sql/src/error.rs | 6 +- src/sql/src/parsers/copy_parser.rs | 93 +++++-------------- src/sql/src/statements/copy.rs | 27 +++--- src/table/src/requests.rs | 1 + 10 files changed, 148 insertions(+), 106 deletions(-) diff --git a/src/common/datasource/src/error.rs b/src/common/datasource/src/error.rs index 519c42b4a2..663b94df67 100644 --- a/src/common/datasource/src/error.rs +++ b/src/common/datasource/src/error.rs @@ -22,19 +22,32 @@ use url::ParseError; #[snafu(visibility(pub))] pub enum Error { #[snafu(display("Unsupported compression type: {}", compression_type))] - UnsupportedCompressionType { compression_type: String }, + UnsupportedCompressionType { + compression_type: String, + location: Location, + }, #[snafu(display("Unsupported backend protocol: {}", protocol))] - UnsupportedBackendProtocol { protocol: String }, + UnsupportedBackendProtocol { + protocol: String, + location: Location, + }, + + #[snafu(display("Unsupported format protocol: {}", format))] + UnsupportedFormat { format: String, location: Location }, #[snafu(display("empty host: {}", url))] - EmptyHostPath { url: String }, + EmptyHostPath { url: String, location: Location }, #[snafu(display("Invalid path: {}", path))] - InvalidPath { path: String }, + InvalidPath { path: String, location: Location }, #[snafu(display("Invalid url: {}, error :{}", url, source))] - InvalidUrl { url: String, source: ParseError }, + InvalidUrl { + url: String, + source: ParseError, + location: Location, + }, #[snafu(display("Failed to decompression, source: {}", source))] Decompression { @@ -82,7 +95,7 @@ pub enum Error { }, #[snafu(display("Invalid connection: {}", msg))] - InvalidConnection { msg: String }, + InvalidConnection { msg: String, location: Location }, #[snafu(display("Failed to join handle: {}", source))] JoinHandle { @@ -102,6 +115,9 @@ pub enum Error { source: arrow_schema::ArrowError, location: Location, }, + + #[snafu(display("Missing required field: {}", name))] + MissingRequiredField { name: String, location: Location }, } pub type Result = std::result::Result; @@ -116,6 +132,7 @@ impl ErrorExt for Error { UnsupportedBackendProtocol { .. } | UnsupportedCompressionType { .. } + | UnsupportedFormat { .. } | InvalidConnection { .. } | InvalidUrl { .. } | EmptyHostPath { .. } @@ -124,7 +141,8 @@ impl ErrorExt for Error { | ReadParquetSnafu { .. } | ParquetToSchema { .. } | ParseFormat { .. } - | MergeSchema { .. } => StatusCode::InvalidArguments, + | MergeSchema { .. } + | MissingRequiredField { .. } => StatusCode::InvalidArguments, Decompression { .. } | JoinHandle { .. } => StatusCode::Unexpected, } @@ -147,13 +165,15 @@ impl ErrorExt for Error { JoinHandle { location, .. } => Some(*location), ParseFormat { location, .. } => Some(*location), MergeSchema { location, .. } => Some(*location), + MissingRequiredField { location, .. } => Some(*location), - UnsupportedBackendProtocol { .. } - | EmptyHostPath { .. 
} - | InvalidPath { .. } - | InvalidUrl { .. } - | InvalidConnection { .. } - | UnsupportedCompressionType { .. } => None, + UnsupportedBackendProtocol { location, .. } => Some(*location), + EmptyHostPath { location, .. } => Some(*location), + InvalidPath { location, .. } => Some(*location), + InvalidUrl { location, .. } => Some(*location), + InvalidConnection { location, .. } => Some(*location), + UnsupportedCompressionType { location, .. } => Some(*location), + UnsupportedFormat { location, .. } => Some(*location), } } } diff --git a/src/common/datasource/src/file_format.rs b/src/common/datasource/src/file_format.rs index 7a34d4cae2..9f53f1c210 100644 --- a/src/common/datasource/src/file_format.rs +++ b/src/common/datasource/src/file_format.rs @@ -20,6 +20,7 @@ pub mod tests; pub const DEFAULT_SCHEMA_INFER_MAX_RECORD: usize = 1000; +use std::collections::HashMap; use std::result; use std::sync::Arc; use std::task::Poll; @@ -33,13 +34,42 @@ use datafusion::physical_plan::file_format::FileOpenFuture; use futures::StreamExt; use object_store::ObjectStore; +use self::csv::CsvFormat; +use self::json::JsonFormat; +use self::parquet::ParquetFormat; use crate::compression::CompressionType; -use crate::error::Result; +use crate::error::{self, Result}; pub const FORMAT_COMPRESSION_TYPE: &str = "COMPRESSION_TYPE"; pub const FORMAT_DELIMTERL: &str = "DELIMTERL"; pub const FORMAT_SCHEMA_INFER_MAX_RECORD: &str = "SCHEMA_INFER_MAX_RECORD"; pub const FORMAT_HAS_HEADER: &str = "FORMAT_HAS_HEADER"; +pub const FORMAT_TYPE: &str = "FORMAT"; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum Format { + Csv(CsvFormat), + Json(JsonFormat), + Parquet(ParquetFormat), +} + +impl TryFrom<&HashMap> for Format { + type Error = error::Error; + + fn try_from(options: &HashMap) -> Result { + let format = options + .get(FORMAT_TYPE) + .map(|format| format.to_ascii_uppercase()) + .unwrap_or_else(|| "PARQUET".to_string()); + + match format.as_str() { + "CSV" => Ok(Self::Csv(CsvFormat::try_from(options)?)), + "JSON" => Ok(Self::Json(JsonFormat::try_from(options)?)), + "PARQUET" => Ok(Self::Parquet(ParquetFormat::default())), + _ => error::UnsupportedFormatSnafu { format: &format }.fail(), + } + } +} #[async_trait] pub trait FileFormat: Send + Sync + std::fmt::Debug { diff --git a/src/common/datasource/src/file_format/parquet.rs b/src/common/datasource/src/file_format/parquet.rs index 1c79f4aca6..edb5b98431 100644 --- a/src/common/datasource/src/file_format/parquet.rs +++ b/src/common/datasource/src/file_format/parquet.rs @@ -31,7 +31,7 @@ use snafu::ResultExt; use crate::error::{self, Result}; use crate::file_format::FileFormat; -#[derive(Debug, Default)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] pub struct ParquetFormat {} #[async_trait] @@ -142,7 +142,6 @@ impl AsyncFileReader for LazyParquetFileReader { #[cfg(test)] mod tests { use super::*; - use crate::file_format::FileFormat; use crate::test_util::{self, format_schema, test_store}; fn test_data_root() -> String { diff --git a/src/common/datasource/src/file_format/tests.rs b/src/common/datasource/src/file_format/tests.rs index 36036c73d3..873209da2f 100644 --- a/src/common/datasource/src/file_format/tests.rs +++ b/src/common/datasource/src/file_format/tests.rs @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+use std::assert_matches::assert_matches; +use std::collections::HashMap; use std::sync::Arc; use std::vec; @@ -26,10 +28,13 @@ use datafusion::physical_plan::ExecutionPlan; use datafusion::prelude::SessionContext; use futures::StreamExt; +use super::FORMAT_TYPE; use crate::compression::CompressionType; +use crate::error; use crate::file_format::csv::{CsvConfigBuilder, CsvOpener}; use crate::file_format::json::JsonOpener; use crate::file_format::parquet::DefaultParquetFileReaderFactory; +use crate::file_format::Format; use crate::test_util::{self, test_basic_schema, test_store}; fn scan_config(file_schema: SchemaRef, limit: Option, filename: &str) -> FileScanConfig { @@ -204,3 +209,37 @@ async fn test_parquet_exec() { &result ); } + +#[test] +fn test_format() { + let value = [(FORMAT_TYPE.to_string(), "csv".to_string())] + .into_iter() + .collect::>(); + + assert_matches!(Format::try_from(&value).unwrap(), Format::Csv(_)); + + let value = [(FORMAT_TYPE.to_string(), "Parquet".to_string())] + .into_iter() + .collect::>(); + + assert_matches!(Format::try_from(&value).unwrap(), Format::Parquet(_)); + + let value = [(FORMAT_TYPE.to_string(), "JSON".to_string())] + .into_iter() + .collect::>(); + + assert_matches!(Format::try_from(&value).unwrap(), Format::Json(_)); + + let value = [(FORMAT_TYPE.to_string(), "Foobar".to_string())] + .into_iter() + .collect::>(); + + assert_matches!( + Format::try_from(&value).unwrap_err(), + error::Error::UnsupportedFormat { .. } + ); + + let value = HashMap::new(); + + assert_matches!(Format::try_from(&value).unwrap(), Format::Parquet(_)); +} diff --git a/src/common/datasource/src/lib.rs b/src/common/datasource/src/lib.rs index 458516cd42..12503d49c2 100644 --- a/src/common/datasource/src/lib.rs +++ b/src/common/datasource/src/lib.rs @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +#![feature(assert_matches)] + pub mod compression; pub mod error; pub mod file_format; diff --git a/src/frontend/src/statement.rs b/src/frontend/src/statement.rs index 45b23dbeb3..6af1756bba 100644 --- a/src/frontend/src/statement.rs +++ b/src/frontend/src/statement.rs @@ -166,7 +166,7 @@ fn to_copy_table_request(stmt: CopyTable, query_ctx: QueryContextRef) -> Result< let CopyTableArgument { location, connection, - pattern, + with, table_name, .. } = match stmt { @@ -177,11 +177,14 @@ fn to_copy_table_request(stmt: CopyTable, query_ctx: QueryContextRef) -> Result< .map_err(BoxedError::new) .context(ExternalSnafu)?; + let pattern = with.get("PATTERN").cloned(); + Ok(CopyTableRequest { catalog_name, schema_name, table_name, location, + with, connection, pattern, direction, diff --git a/src/sql/src/error.rs b/src/sql/src/error.rs index 4e5ca063af..6d1f54e015 100644 --- a/src/sql/src/error.rs +++ b/src/sql/src/error.rs @@ -137,9 +137,6 @@ pub enum Error { target_unit: TimeUnit, }, - #[snafu(display("Unsupported format option: {}", name))] - UnsupportedCopyFormatOption { name: String }, - #[snafu(display("Unable to convert statement {} to DataFusion statement", statement))] ConvertToDfStatement { statement: String, @@ -178,8 +175,7 @@ impl ErrorExt for Error { | ColumnTypeMismatch { .. } | InvalidTableName { .. } | InvalidSqlValue { .. } - | TimestampOverflow { .. } - | UnsupportedCopyFormatOption { .. } => StatusCode::InvalidArguments, + | TimestampOverflow { .. } => StatusCode::InvalidArguments, UnsupportedAlterTableStatement { .. } => StatusCode::InvalidSyntax, SerializeColumnDefaultConstraint { source, .. 
} => source.status_code(), diff --git a/src/sql/src/parsers/copy_parser.rs b/src/sql/src/parsers/copy_parser.rs index a619a15afd..da0683cb29 100644 --- a/src/sql/src/parsers/copy_parser.rs +++ b/src/sql/src/parsers/copy_parser.rs @@ -18,7 +18,7 @@ use sqlparser::keywords::Keyword; use crate::error::{self, Result}; use crate::parser::ParserContext; -use crate::statements::copy::{CopyTable, CopyTableArgument, Format}; +use crate::statements::copy::{CopyTable, CopyTableArgument}; use crate::statements::statement::Statement; use crate::util::parse_option_string; @@ -65,25 +65,12 @@ impl<'a> ParserContext<'a> { .parse_options(Keyword::WITH) .context(error::SyntaxSnafu { sql: self.sql })?; - // default format is parquet - let mut format = Format::Parquet; - let mut pattern = None; - for option in options { - match option.name.value.to_ascii_uppercase().as_str() { - "FORMAT" => { - if let Some(fmt_str) = parse_option_string(option.value) { - format = Format::try_from(fmt_str)?; - } - } - "PATTERN" => { - if let Some(v) = parse_option_string(option.value) { - pattern = Some(v); - } - } - //TODO: throws warnings? - _ => (), - } - } + let with = options + .into_iter() + .filter_map(|option| { + parse_option_string(option.value).map(|v| (option.name.to_string(), v)) + }) + .collect(); let connection_options = self .parser @@ -93,17 +80,12 @@ impl<'a> ParserContext<'a> { let connection = connection_options .into_iter() .filter_map(|option| { - if let Some(v) = parse_option_string(option.value) { - Some((option.name.value.to_uppercase(), v)) - } else { - None - } + parse_option_string(option.value).map(|v| (option.name.to_string(), v)) }) .collect(); Ok(CopyTableArgument { table_name, - format, - pattern, + with, connection, location, }) @@ -124,15 +106,12 @@ impl<'a> ParserContext<'a> { .parse_options(Keyword::WITH) .context(error::SyntaxSnafu { sql: self.sql })?; - // default format is parquet - let mut format = Format::Parquet; - for option in options { - if option.name.value.eq_ignore_ascii_case("FORMAT") { - if let Some(fmt_str) = parse_option_string(option.value) { - format = Format::try_from(fmt_str)?; - } - } - } + let with = options + .into_iter() + .filter_map(|option| { + parse_option_string(option.value).map(|v| (option.name.to_string(), v)) + }) + .collect(); let connection_options = self .parser @@ -142,19 +121,14 @@ impl<'a> ParserContext<'a> { let connection = connection_options .into_iter() .filter_map(|option| { - if let Some(v) = parse_option_string(option.value) { - Some((option.name.value.to_uppercase(), v)) - } else { - None - } + parse_option_string(option.value).map(|v| (option.name.to_string(), v)) }) .collect(); Ok(CopyTableArgument { table_name, - format, + with, connection, - pattern: None, location, }) } @@ -198,11 +172,11 @@ mod tests { assert_eq!("schema0", schema); assert_eq!("tbl", table); - let file_name = copy_table.location; + let file_name = ©_table.location; assert_eq!("tbl_file.parquet", file_name); - let format = copy_table.format; - assert_eq!(Format::Parquet, format); + let format = copy_table.format().unwrap(); + assert_eq!("parquet", format.to_lowercase()); } _ => unreachable!(), } @@ -241,11 +215,11 @@ mod tests { assert_eq!("schema0", schema); assert_eq!("tbl", table); - let file_name = copy_table.location; + let file_name = ©_table.location; assert_eq!("tbl_file.parquet", file_name); - let format = copy_table.format; - assert_eq!(Format::Parquet, format); + let format = copy_table.format().unwrap(); + assert_eq!("parquet", format.to_lowercase()); } _ => 
unreachable!(), } @@ -283,7 +257,7 @@ mod tests { match statement { Statement::Copy(CopyTable::From(copy_table)) => { if let Some(expected_pattern) = test.expected_pattern { - assert_eq!(copy_table.pattern.clone().unwrap(), expected_pattern); + assert_eq!(copy_table.pattern().unwrap(), expected_pattern); } assert_eq!(copy_table.connection.clone(), test.expected_connection); } @@ -329,23 +303,4 @@ mod tests { } } } - - #[test] - fn test_parse_copy_table_with_unsupopoted_format() { - let results = [ - "COPY catalog0.schema0.tbl TO 'tbl_file.parquet' WITH (FORMAT = 'unknow_format')", - "COPY catalog0.schema0.tbl FROM 'tbl_file.parquet' WITH (FORMAT = 'unknow_format')", - ] - .iter() - .map(|sql| ParserContext::create_with_dialect(sql, &GenericDialect {})) - .collect::>(); - - for result in results { - assert!(result.is_err()); - assert_matches!( - result.err().unwrap(), - error::Error::UnsupportedCopyFormatOption { .. } - ); - } - } } diff --git a/src/sql/src/statements/copy.rs b/src/sql/src/statements/copy.rs index b56d48e9cb..db7e69d5b9 100644 --- a/src/sql/src/statements/copy.rs +++ b/src/sql/src/statements/copy.rs @@ -16,8 +16,6 @@ use std::collections::HashMap; use sqlparser::ast::ObjectName; -use crate::error::{self, Result}; - #[derive(Debug, Clone, PartialEq, Eq)] pub enum CopyTable { To(CopyTableArgument), @@ -27,25 +25,24 @@ pub enum CopyTable { #[derive(Debug, Clone, PartialEq, Eq)] pub struct CopyTableArgument { pub table_name: ObjectName, - pub format: Format, + pub with: HashMap, pub connection: HashMap, - pub pattern: Option, /// Copy tbl [To|From] 'location'. pub location: String, } -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum Format { - Parquet, -} +#[cfg(test)] +impl CopyTableArgument { + const FORMAT: &str = "FORMAT"; -impl TryFrom for Format { - type Error = error::Error; + pub fn format(&self) -> Option { + self.with + .get(Self::FORMAT) + .cloned() + .or_else(|| Some("PARQUET".to_string())) + } - fn try_from(name: String) -> Result { - if name.eq_ignore_ascii_case("PARQUET") { - return Ok(Format::Parquet); - } - error::UnsupportedCopyFormatOptionSnafu { name }.fail() + pub fn pattern(&self) -> Option { + self.with.get("PATTERN").cloned() } } diff --git a/src/table/src/requests.rs b/src/table/src/requests.rs index b346f2ea65..3cebfbe006 100644 --- a/src/table/src/requests.rs +++ b/src/table/src/requests.rs @@ -259,6 +259,7 @@ pub struct CopyTableRequest { pub schema_name: String, pub table_name: String, pub location: String, + pub with: HashMap, pub connection: HashMap, pub pattern: Option, pub direction: CopyDirection,
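
Example: resolving a COPY statement's format after this change. The sketch below is illustrative only; the common_datasource crate path and the standalone main harness are assumptions based on the src/common/datasource layout, not part of this patch. It shows how the option map that the COPY parser now stores in CopyTableArgument::with is handed to Format::try_from, which matches case-insensitively, falls back to Parquet when no FORMAT key is present, and rejects unknown values with Error::UnsupportedFormat.

// Illustrative sketch only (not part of the patch): assumes the crate is
// importable as `common_datasource`, mirroring the src/common/datasource path.
use std::collections::HashMap;

use common_datasource::file_format::{Format, FORMAT_TYPE};

fn main() {
    // Options as the COPY parser now collects them into `CopyTableArgument::with`.
    let mut with: HashMap<String, String> = HashMap::new();
    with.insert(FORMAT_TYPE.to_string(), "csv".to_string());

    // `Format::try_from` uppercases the value before matching, defaults to
    // Parquet when the FORMAT key is absent, and fails with
    // `Error::UnsupportedFormat` for anything it does not recognize.
    match Format::try_from(&with) {
        Ok(Format::Csv(_)) => println!("COPY resolves to CSV"),
        Ok(other) => println!("COPY resolves to {other:?}"),
        Err(e) => eprintln!("invalid FORMAT option: {e}"),
    }
}

The net effect of the patch is that the SQL layer no longer validates formats: every WITH (...) option is passed through verbatim in the with map, and the datasource layer decides whether the requested format is supported.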