diff --git a/src/common/datasource/src/error.rs b/src/common/datasource/src/error.rs
index b8f4f61831..519c42b4a2 100644
--- a/src/common/datasource/src/error.rs
+++ b/src/common/datasource/src/error.rs
@@ -89,6 +89,19 @@ pub enum Error {
         location: Location,
         source: tokio::task::JoinError,
     },
+
+    #[snafu(display("Failed to parse format {} with value: {}", key, value))]
+    ParseFormat {
+        key: &'static str,
+        value: String,
+        location: Location,
+    },
+
+    #[snafu(display("Failed to merge schema: {}", source))]
+    MergeSchema {
+        source: arrow_schema::ArrowError,
+        location: Location,
+    },
 }
 
 pub type Result<T> = std::result::Result<T, Error>;
@@ -109,7 +122,9 @@ impl ErrorExt for Error {
             | InvalidPath { .. }
             | InferSchema { .. }
             | ReadParquetSnafu { .. }
-            | ParquetToSchema { .. } => StatusCode::InvalidArguments,
+            | ParquetToSchema { .. }
+            | ParseFormat { .. }
+            | MergeSchema { .. } => StatusCode::InvalidArguments,
 
             Decompression { .. } | JoinHandle { .. } => StatusCode::Unexpected,
         }
@@ -130,6 +145,8 @@ impl ErrorExt for Error {
             ParquetToSchema { location, .. } => Some(*location),
             Decompression { location, .. } => Some(*location),
             JoinHandle { location, .. } => Some(*location),
+            ParseFormat { location, .. } => Some(*location),
+            MergeSchema { location, .. } => Some(*location),
 
             UnsupportedBackendProtocol { .. }
             | EmptyHostPath { .. }
diff --git a/src/common/datasource/src/file_format.rs b/src/common/datasource/src/file_format.rs
index 463d17ae82..7a34d4cae2 100644
--- a/src/common/datasource/src/file_format.rs
+++ b/src/common/datasource/src/file_format.rs
@@ -24,9 +24,8 @@ use std::result;
 use std::sync::Arc;
 use std::task::Poll;
 
-use arrow::datatypes::SchemaRef;
 use arrow::record_batch::RecordBatch;
-use arrow_schema::ArrowError;
+use arrow_schema::{ArrowError, Schema};
 use async_trait::async_trait;
 use bytes::{Buf, Bytes};
 use datafusion::error::{DataFusionError, Result as DataFusionResult};
@@ -37,9 +36,14 @@ use object_store::ObjectStore;
 use crate::compression::CompressionType;
 use crate::error::Result;
 
+pub const FORMAT_COMPRESSION_TYPE: &str = "COMPRESSION_TYPE";
+pub const FORMAT_DELIMTERL: &str = "DELIMTERL";
+pub const FORMAT_SCHEMA_INFER_MAX_RECORD: &str = "SCHEMA_INFER_MAX_RECORD";
+pub const FORMAT_HAS_HEADER: &str = "FORMAT_HAS_HEADER";
+
 #[async_trait]
 pub trait FileFormat: Send + Sync + std::fmt::Debug {
-    async fn infer_schema(&self, store: &ObjectStore, path: String) -> Result<SchemaRef>;
+    async fn infer_schema(&self, store: &ObjectStore, path: String) -> Result<Schema>;
 }
 
 pub trait ArrowDecoder: Send + 'static {
diff --git a/src/common/datasource/src/file_format/csv.rs b/src/common/datasource/src/file_format/csv.rs
index 07f5287cb4..6c8d5fcadc 100644
--- a/src/common/datasource/src/file_format/csv.rs
+++ b/src/common/datasource/src/file_format/csv.rs
@@ -12,11 +12,13 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+use std::collections::HashMap;
+use std::str::FromStr;
 use std::sync::Arc;
 
 use arrow::csv;
 use arrow::csv::reader::infer_reader_schema as infer_csv_schema;
-use arrow_schema::SchemaRef;
+use arrow_schema::{Schema, SchemaRef};
 use async_trait::async_trait;
 use common_runtime;
 use datafusion::error::Result as DataFusionResult;
@@ -30,7 +32,7 @@ use crate::compression::CompressionType;
 use crate::error::{self, Result};
 use crate::file_format::{self, open_with_decoder, FileFormat};
 
-#[derive(Debug)]
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
 pub struct CsvFormat {
     pub has_header: bool,
     pub delimiter: u8,
@@ -38,6 +40,49 @@ pub struct CsvFormat {
     pub compression_type: CompressionType,
 }
 
+impl TryFrom<&HashMap<String, String>> for CsvFormat {
+    type Error = error::Error;
+
+    fn try_from(value: &HashMap<String, String>) -> Result<Self> {
+        let mut format = CsvFormat::default();
+        if let Some(delimiter) = value.get(file_format::FORMAT_DELIMTERL) {
+            // TODO(weny): consider supporting escaped delimiters like "\t", not only raw bytes such as b'\t'
+            format.delimiter = u8::from_str(delimiter).map_err(|_| {
+                error::ParseFormatSnafu {
+                    key: file_format::FORMAT_DELIMTERL,
+                    value: delimiter,
+                }
+                .build()
+            })?;
+        };
+        if let Some(compression_type) = value.get(file_format::FORMAT_COMPRESSION_TYPE) {
+            format.compression_type = CompressionType::from_str(compression_type)?;
+        };
+        if let Some(schema_infer_max_record) =
+            value.get(file_format::FORMAT_SCHEMA_INFER_MAX_RECORD)
+        {
+            format.schema_infer_max_record =
+                Some(schema_infer_max_record.parse::<usize>().map_err(|_| {
+                    error::ParseFormatSnafu {
+                        key: file_format::FORMAT_SCHEMA_INFER_MAX_RECORD,
+                        value: schema_infer_max_record,
+                    }
+                    .build()
+                })?);
+        };
+        if let Some(has_header) = value.get(file_format::FORMAT_HAS_HEADER) {
+            format.has_header = has_header.parse().map_err(|_| {
+                error::ParseFormatSnafu {
+                    key: file_format::FORMAT_HAS_HEADER,
+                    value: has_header,
+                }
+                .build()
+            })?;
+        }
+        Ok(format)
+    }
+}
+
 impl Default for CsvFormat {
     fn default() -> Self {
         Self {
@@ -112,7 +157,7 @@ impl FileOpener for CsvOpener {
 
 #[async_trait]
 impl FileFormat for CsvFormat {
-    async fn infer_schema(&self, store: &ObjectStore, path: String) -> Result<SchemaRef> {
+    async fn infer_schema(&self, store: &ObjectStore, path: String) -> Result<Schema> {
         let reader = store
             .reader(&path)
             .await
@@ -130,8 +175,7 @@ impl FileFormat for CsvFormat {
             let (schema, _records_read) =
                 infer_csv_schema(reader, delimiter, schema_infer_max_record, has_header)
                     .context(error::InferSchemaSnafu { path: &path })?;
-
-            Ok(Arc::new(schema))
+            Ok(schema)
         })
         .await
         .context(error::JoinHandleSnafu)?
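// ---------------------------------------------------------------------------
// Annotation, not part of the diff: a minimal usage sketch for the new
// `TryFrom<&HashMap<String, String>>` impl on CsvFormat. The keys are the
// constants added in file_format.rs; unknown keys are ignored, and a known key
// whose value fails to parse is reported through the new Error::ParseFormat
// variant. The helper function and its in-crate paths are illustrative only.
// ---------------------------------------------------------------------------
use std::collections::HashMap;

use crate::error::Result;
use crate::file_format::csv::CsvFormat;
use crate::file_format::{FORMAT_DELIMTERL, FORMAT_HAS_HEADER};

fn csv_format_from_copy_options(options: &HashMap<String, String>) -> Result<CsvFormat> {
    // e.g. options = { "FORMAT_HAS_HEADER" => "true", "DELIMTERL" => "44" };
    // "44" is the decimal value of b',' because the delimiter is parsed with
    // u8::from_str (see the TODO above about accepting escaped forms like "\t").
    CsvFormat::try_from(options)
}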
@@ -142,7 +186,10 @@ impl FileFormat for CsvFormat {
 mod tests {
     use super::*;
 
-    use crate::file_format::FileFormat;
+    use crate::file_format::{
+        FileFormat, FORMAT_COMPRESSION_TYPE, FORMAT_DELIMTERL, FORMAT_HAS_HEADER,
+        FORMAT_SCHEMA_INFER_MAX_RECORD,
+    };
     use crate::test_util::{self, format_schema, test_store};
 
     fn test_data_root() -> String {
@@ -220,4 +267,33 @@ mod tests {
             formatted
         );
     }
+
+    #[test]
+    fn test_try_from() {
+        let mut map = HashMap::new();
+        let format: CsvFormat = CsvFormat::try_from(&map).unwrap();
+
+        assert_eq!(format, CsvFormat::default());
+
+        map.insert(
+            FORMAT_SCHEMA_INFER_MAX_RECORD.to_string(),
+            "2000".to_string(),
+        );
+
+        map.insert(FORMAT_COMPRESSION_TYPE.to_string(), "zstd".to_string());
+        map.insert(FORMAT_DELIMTERL.to_string(), b'\t'.to_string());
+        map.insert(FORMAT_HAS_HEADER.to_string(), "false".to_string());
+
+        let format = CsvFormat::try_from(&map).unwrap();
+
+        assert_eq!(
+            format,
+            CsvFormat {
+                compression_type: CompressionType::ZSTD,
+                schema_infer_max_record: Some(2000),
+                delimiter: b'\t',
+                has_header: false,
+            }
+        );
+    }
 }
diff --git a/src/common/datasource/src/file_format/json.rs b/src/common/datasource/src/file_format/json.rs
index 26a4b3d561..297b457fa8 100644
--- a/src/common/datasource/src/file_format/json.rs
+++ b/src/common/datasource/src/file_format/json.rs
@@ -12,12 +12,15 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+use std::collections::HashMap;
 use std::io::BufReader;
+use std::str::FromStr;
 use std::sync::Arc;
 
 use arrow::datatypes::SchemaRef;
 use arrow::json::reader::{infer_json_schema_from_iterator, ValueIter};
 use arrow::json::RawReaderBuilder;
+use arrow_schema::Schema;
 use async_trait::async_trait;
 use common_runtime;
 use datafusion::error::{DataFusionError, Result as DataFusionResult};
@@ -30,12 +33,36 @@ use crate::compression::CompressionType;
 use crate::error::{self, Result};
 use crate::file_format::{self, open_with_decoder, FileFormat};
 
-#[derive(Debug)]
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
 pub struct JsonFormat {
     pub schema_infer_max_record: Option<usize>,
     pub compression_type: CompressionType,
 }
 
+impl TryFrom<&HashMap<String, String>> for JsonFormat {
+    type Error = error::Error;
+
+    fn try_from(value: &HashMap<String, String>) -> Result<Self> {
+        let mut format = JsonFormat::default();
+        if let Some(compression_type) = value.get(file_format::FORMAT_COMPRESSION_TYPE) {
+            format.compression_type = CompressionType::from_str(compression_type)?
+        };
+        if let Some(schema_infer_max_record) =
+            value.get(file_format::FORMAT_SCHEMA_INFER_MAX_RECORD)
+        {
+            format.schema_infer_max_record =
+                Some(schema_infer_max_record.parse::<usize>().map_err(|_| {
+                    error::ParseFormatSnafu {
+                        key: file_format::FORMAT_SCHEMA_INFER_MAX_RECORD,
+                        value: schema_infer_max_record,
+                    }
+                    .build()
+                })?);
+        };
+        Ok(format)
+    }
+}
+
 impl Default for JsonFormat {
     fn default() -> Self {
         Self {
@@ -47,7 +74,7 @@ impl Default for JsonFormat {
 
 #[async_trait]
 impl FileFormat for JsonFormat {
-    async fn infer_schema(&self, store: &ObjectStore, path: String) -> Result<SchemaRef> {
+    async fn infer_schema(&self, store: &ObjectStore, path: String) -> Result<Schema> {
         let reader = store
             .reader(&path)
             .await
@@ -65,7 +92,7 @@ impl FileFormat for JsonFormat {
             let schema = infer_json_schema_from_iterator(iter)
                 .context(error::InferSchemaSnafu { path: &path })?;
 
-            Ok(Arc::new(schema))
+            Ok(schema)
         })
         .await
         .context(error::JoinHandleSnafu)?
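// ---------------------------------------------------------------------------
// Annotation, not part of the diff: JsonFormat reads the same option map but
// only honours FORMAT_COMPRESSION_TYPE and FORMAT_SCHEMA_INFER_MAX_RECORD. A
// hypothetical sketch of the error path added by Error::ParseFormat:
// ---------------------------------------------------------------------------
use std::collections::HashMap;

use crate::file_format::json::JsonFormat;
use crate::file_format::FORMAT_SCHEMA_INFER_MAX_RECORD;

fn json_format_rejects_bad_option() {
    let mut options = HashMap::new();
    options.insert(
        FORMAT_SCHEMA_INFER_MAX_RECORD.to_string(),
        "not-a-number".to_string(),
    );
    // "not-a-number" cannot be parsed as usize, so try_from fails; the error is
    // displayed as "Failed to parse format SCHEMA_INFER_MAX_RECORD with value: not-a-number".
    assert!(JsonFormat::try_from(&options).is_err());
}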
@@ -116,7 +143,7 @@ impl FileOpener for JsonOpener {
 #[cfg(test)]
 mod tests {
     use super::*;
-    use crate::file_format::FileFormat;
+    use crate::file_format::{FileFormat, FORMAT_COMPRESSION_TYPE, FORMAT_SCHEMA_INFER_MAX_RECORD};
     use crate::test_util::{self, format_schema, test_store};
 
     fn test_data_root() -> String {
@@ -162,4 +189,29 @@ mod tests {
             formatted
         );
     }
+
+    #[test]
+    fn test_try_from() {
+        let mut map = HashMap::new();
+        let format = JsonFormat::try_from(&map).unwrap();
+
+        assert_eq!(format, JsonFormat::default());
+
+        map.insert(
+            FORMAT_SCHEMA_INFER_MAX_RECORD.to_string(),
+            "2000".to_string(),
+        );
+
+        map.insert(FORMAT_COMPRESSION_TYPE.to_string(), "zstd".to_string());
+
+        let format = JsonFormat::try_from(&map).unwrap();
+
+        assert_eq!(
+            format,
+            JsonFormat {
+                compression_type: CompressionType::ZSTD,
+                schema_infer_max_record: Some(2000),
+            }
+        );
+    }
 }
diff --git a/src/common/datasource/src/file_format/parquet.rs b/src/common/datasource/src/file_format/parquet.rs
index c0fd8e5abf..8847018608 100644
--- a/src/common/datasource/src/file_format/parquet.rs
+++ b/src/common/datasource/src/file_format/parquet.rs
@@ -12,9 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-use std::sync::Arc;
-
-use arrow_schema::SchemaRef;
+use arrow_schema::Schema;
 use async_trait::async_trait;
 use datafusion::parquet::arrow::async_reader::AsyncFileReader;
 use datafusion::parquet::arrow::parquet_to_arrow_schema;
@@ -29,7 +27,7 @@ pub struct ParquetFormat {}
 
 #[async_trait]
 impl FileFormat for ParquetFormat {
-    async fn infer_schema(&self, store: &ObjectStore, path: String) -> Result<SchemaRef> {
+    async fn infer_schema(&self, store: &ObjectStore, path: String) -> Result<Schema> {
         let mut reader = store
             .reader(&path)
             .await
@@ -47,7 +45,7 @@
         )
         .context(error::ParquetToSchemaSnafu)?;
 
-        Ok(Arc::new(schema))
+        Ok(schema)
     }
 }
 
diff --git a/src/common/datasource/src/object_store.rs b/src/common/datasource/src/object_store.rs
index e1f949ceb3..ec068c08bc 100644
--- a/src/common/datasource/src/object_store.rs
+++ b/src/common/datasource/src/object_store.rs
@@ -27,7 +27,7 @@ use crate::error::{self, Result};
 pub const FS_SCHEMA: &str = "FS";
 pub const S3_SCHEMA: &str = "S3";
 
-/// parse url returns (schema, Option<host>, path)
+/// Returns (schema, Option<host>, path)
 pub fn parse_url(url: &str) -> Result<(String, Option<String>, String)> {
     let parsed_url = Url::parse(url);
     match parsed_url {
@@ -43,7 +43,7 @@ pub fn parse_url(url: &str) -> Result<(String, Option<String>, String)> {
     }
 }
 
-pub fn build_backend(url: &str, connection: HashMap<String, String>) -> Result<ObjectStore> {
+pub fn build_backend(url: &str, connection: &HashMap<String, String>) -> Result<ObjectStore> {
     let (schema, host, _path) = parse_url(url)?;
 
     match schema.to_uppercase().as_str() {
diff --git a/src/common/datasource/src/object_store/s3.rs b/src/common/datasource/src/object_store/s3.rs
index 482da1bcef..7688211021 100644
--- a/src/common/datasource/src/object_store/s3.rs
+++ b/src/common/datasource/src/object_store/s3.rs
@@ -30,7 +30,7 @@ const ENABLE_VIRTUAL_HOST_STYLE: &str = "ENABLE_VIRTUAL_HOST_STYLE";
 pub fn build_s3_backend(
     host: &str,
     path: &str,
-    connection: HashMap<String, String>,
+    connection: &HashMap<String, String>,
 ) -> Result<ObjectStore> {
     let mut builder = S3::default();
 
diff --git a/src/common/datasource/src/test_util.rs b/src/common/datasource/src/test_util.rs
index fcf735d175..a71ddd876c 100644
--- a/src/common/datasource/src/test_util.rs
+++ b/src/common/datasource/src/test_util.rs
@@ -28,7 +28,7 @@ pub fn get_data_dir(path: &str) -> PathBuf {
     PathBuf::from(dir).join(path)
 }
 
-pub fn format_schema(schema: SchemaRef) -> Vec<String> {
+pub fn format_schema(schema: Schema) -> Vec<String> {
     schema
         .fields()
         .iter()
diff --git a/src/frontend/src/statement/copy_table_from.rs b/src/frontend/src/statement/copy_table_from.rs
index 99838abdf6..c19e4fd592 100644
--- a/src/frontend/src/statement/copy_table_from.rs
+++ b/src/frontend/src/statement/copy_table_from.rs
@@ -46,7 +46,7 @@ impl StatementExecutor {
         let (_schema, _host, path) = parse_url(&req.location).context(error::ParseUrlSnafu)?;
 
         let object_store =
-            build_backend(&req.location, req.connection).context(error::BuildBackendSnafu)?;
+            build_backend(&req.location, &req.connection).context(error::BuildBackendSnafu)?;
 
         let (dir, filename) = find_dir_and_filename(&path);
         let regex = req
diff --git a/src/frontend/src/statement/copy_table_to.rs b/src/frontend/src/statement/copy_table_to.rs
index 564cfc2011..7f37478511 100644
--- a/src/frontend/src/statement/copy_table_to.rs
+++ b/src/frontend/src/statement/copy_table_to.rs
@@ -46,7 +46,7 @@ impl StatementExecutor {
         let (_schema, _host, path) = parse_url(&req.location).context(error::ParseUrlSnafu)?;
 
         let object_store =
-            build_backend(&req.location, req.connection).context(error::BuildBackendSnafu)?;
+            build_backend(&req.location, &req.connection).context(error::BuildBackendSnafu)?;
 
         let writer = ParquetWriter::new(&path, Source::Stream(stream), object_store);
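// ---------------------------------------------------------------------------
// Annotation, not part of the diff: with connection options now borrowed and
// FileFormat::infer_schema returning an owned Schema, the copy path can reuse
// the same maps and wrap the schema in an Arc only where a SchemaRef is needed.
// A rough sketch of the combined flow; the helper and its paths are
// illustrative, not the actual frontend code.
// ---------------------------------------------------------------------------
use std::collections::HashMap;
use std::sync::Arc;

use arrow_schema::SchemaRef;

use crate::error::Result;
use crate::file_format::csv::CsvFormat;
use crate::file_format::FileFormat;
use crate::object_store::build_backend;

async fn infer_copy_schema(
    url: &str,
    path: String,
    connection: &HashMap<String, String>,
    format_options: &HashMap<String, String>,
) -> Result<SchemaRef> {
    // `connection` is only borrowed, mirroring build_backend(&req.connection) above.
    let store = build_backend(url, connection)?;
    let format = CsvFormat::try_from(format_options)?;
    // infer_schema now yields an owned Schema; wrap it where a SchemaRef is required.
    Ok(Arc::new(format.infer_schema(&store, path).await?))
}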