From 177036475aa772c63386faed30d1853a78e28e8c Mon Sep 17 00:00:00 2001
From: Weny Xu
Date: Fri, 18 Aug 2023 11:09:54 +0800
Subject: [PATCH] fix: support to copy from parquet with typecast (#2201)

---
 src/common/recordbatch/src/adapter.rs         | 44 ++++++++++++++++++-
 src/frontend/src/statement/copy_table_from.rs |  6 ++-
 tests-integration/src/tests/instance_test.rs  | 33 ++++++++++++++
 3 files changed, 80 insertions(+), 3 deletions(-)

diff --git a/src/common/recordbatch/src/adapter.rs b/src/common/recordbatch/src/adapter.rs
index a4ed408c67..93fe9a3a5b 100644
--- a/src/common/recordbatch/src/adapter.rs
+++ b/src/common/recordbatch/src/adapter.rs
@@ -17,6 +17,7 @@ use std::pin::Pin;
 use std::sync::Arc;
 use std::task::{Context, Poll};
 
+use datafusion::arrow::compute::cast;
 use datafusion::arrow::datatypes::SchemaRef as DfSchemaRef;
 use datafusion::error::Result as DfResult;
 use datafusion::parquet::arrow::async_reader::{AsyncFileReader, ParquetRecordBatchStream};
@@ -44,11 +45,27 @@ type FutureStream = Pin<
 /// ParquetRecordBatchStream -> DataFusion RecordBatchStream
 pub struct ParquetRecordBatchStreamAdapter<T> {
     stream: ParquetRecordBatchStream<T>,
+    output_schema: DfSchemaRef,
+    projection: Vec<usize>,
 }
 
 impl<T: Unpin + AsyncFileReader + Send + 'static> ParquetRecordBatchStreamAdapter<T> {
-    pub fn new(stream: ParquetRecordBatchStream<T>) -> Self {
-        Self { stream }
+    pub fn new(
+        output_schema: DfSchemaRef,
+        stream: ParquetRecordBatchStream<T>,
+        projection: Option<Vec<usize>>,
+    ) -> Self {
+        let projection = if let Some(projection) = projection {
+            projection
+        } else {
+            (0..output_schema.fields().len()).collect()
+        };
+
+        Self {
+            stream,
+            output_schema,
+            projection,
+        }
     }
 }
 
@@ -66,6 +83,29 @@ impl<T: Unpin + AsyncFileReader + Send + 'static> Stream for ParquetRecordBatchS
     fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
         let batch = futures::ready!(Pin::new(&mut self.stream).poll_next(cx))
             .map(|r| r.map_err(|e| DataFusionError::External(Box::new(e))));
+
+        let projected_schema = self.output_schema.project(&self.projection)?;
+        let batch = batch.map(|b| {
+            b.and_then(|b| {
+                let mut columns = Vec::with_capacity(self.projection.len());
+                for idx in self.projection.iter() {
+                    let column = b.column(*idx);
+                    let field = self.output_schema.field(*idx);
+
+                    if column.data_type() != field.data_type() {
+                        let output = cast(&column, field.data_type())?;
+                        columns.push(output)
+                    } else {
+                        columns.push(column.clone())
+                    }
+                }
+
+                let record_batch = DfRecordBatch::try_new(projected_schema.into(), columns)?;
+
+                Ok(record_batch)
+            })
+        });
+
         Poll::Ready(batch)
     }
 
diff --git a/src/frontend/src/statement/copy_table_from.rs b/src/frontend/src/statement/copy_table_from.rs
index 148f3cf80e..17dbc28934 100644
--- a/src/frontend/src/statement/copy_table_from.rs
+++ b/src/frontend/src/statement/copy_table_from.rs
@@ -214,7 +214,11 @@ impl StatementExecutor {
                     .build()
                     .context(error::BuildParquetRecordBatchStreamSnafu)?;
 
-                Ok(Box::pin(ParquetRecordBatchStreamAdapter::new(upstream)))
+                Ok(Box::pin(ParquetRecordBatchStreamAdapter::new(
+                    schema,
+                    upstream,
+                    Some(projection),
+                )))
             }
             Format::Orc(_) => {
                 let reader = object_store
diff --git a/tests-integration/src/tests/instance_test.rs b/tests-integration/src/tests/instance_test.rs
index 9345814ae4..b3cea11fc2 100644
--- a/tests-integration/src/tests/instance_test.rs
+++ b/tests-integration/src/tests/instance_test.rs
@@ -1305,6 +1305,39 @@ async fn test_execute_copy_from_s3(instance: Arc<dyn MockInstance>) {
     }
 }
 
+#[apply(both_instances_cases)]
+async fn test_execute_copy_from_orc_with_cast(instance: Arc<dyn MockInstance>) {
+    logging::init_default_ut_logging();
+    let instance = instance.frontend();
+
+    // setups
+    assert!(matches!(execute_sql(
+        &instance,
+        "create table demo(bigint_direct timestamp(9), bigint_neg_direct timestamp(6), bigint_other timestamp(3), timestamp_simple timestamp(9), time index (bigint_other));",
+    )
+    .await, Output::AffectedRows(0)));
+
+    let filepath = find_testing_resource("/src/common/datasource/tests/orc/test.orc");
+
+    let output = execute_sql(
+        &instance,
+        &format!("copy demo from '{}' WITH(FORMAT='orc');", &filepath),
+    )
+    .await;
+
+    assert!(matches!(output, Output::AffectedRows(5)));
+
+    let output = execute_sql(&instance, "select * from demo;").await;
+    let expected = r#"+-------------------------------+----------------------------+-------------------------+----------------------------+
+| bigint_direct                 | bigint_neg_direct          | bigint_other            | timestamp_simple           |
++-------------------------------+----------------------------+-------------------------+----------------------------+
+| 1970-01-01T00:00:00.000000006 | 1969-12-31T23:59:59.999994 | 1969-12-31T23:59:59.995 | 2021-08-22T07:26:44.525777 |
+|                               |                            | 1970-01-01T00:00:00.001 | 2023-01-01T00:00:00        |
+| 1970-01-01T00:00:00.000000002 | 1969-12-31T23:59:59.999998 | 1970-01-01T00:00:00.005 | 2023-03-01T00:00:00        |
++-------------------------------+----------------------------+-------------------------+----------------------------+"#;
+    check_output_stream(output, expected).await;
+}
+
 #[apply(both_instances_cases)]
 async fn test_execute_copy_from_orc(instance: Arc<dyn MockInstance>) {
     logging::init_default_ut_logging();
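---

Editor's note, not part of the patch: below is a minimal standalone sketch of the
cast-on-read step the adapter gains above. It calls the `cast` kernel from the
arrow crate directly (the patch reaches the same kernel via
`datafusion::arrow::compute::cast`); the source values and the target type are
hypothetical, and only the compare-then-cast shape mirrors the patched
`poll_next`.

use std::sync::Arc;

use arrow::array::{ArrayRef, Int64Array};
use arrow::compute::cast;
use arrow::datatypes::{DataType, TimeUnit};

fn main() -> Result<(), arrow::error::ArrowError> {
    // A bigint column as it might arrive from a Parquet or ORC file
    // (hypothetical values, including a null).
    let source: ArrayRef = Arc::new(Int64Array::from(vec![Some(6), None, Some(2)]));

    // The table schema declares a nanosecond timestamp, so cast only when the
    // file's type differs from the schema's type, as the adapter does for
    // each projected column.
    let target = DataType::Timestamp(TimeUnit::Nanosecond, None);
    let column = if source.data_type() != &target {
        cast(&source, &target)?
    } else {
        source.clone()
    };

    assert_eq!(column.data_type(), &target);
    Ok(())
}

Casting lazily per batch inside the stream keeps the copy path allocation-free
when file and table schemas already agree, since unchanged columns are reused
via a cheap Arc clone.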