From c8ed1bbfae3d66298cdb911f2808c47e120d5346 Mon Sep 17 00:00:00 2001 From: Weny Xu Date: Mon, 10 Jul 2023 17:53:38 +0900 Subject: [PATCH] fix: cast orc data against output schema (#1922) fix: cast data against output schema --- src/common/datasource/src/file_format/orc.rs | 29 +++++++++++++++++-- src/frontend/src/statement/copy_table_from.rs | 2 +- 2 files changed, 27 insertions(+), 4 deletions(-) diff --git a/src/common/datasource/src/file_format/orc.rs b/src/common/datasource/src/file_format/orc.rs index fb228ee1db..7b9858661a 100644 --- a/src/common/datasource/src/file_format/orc.rs +++ b/src/common/datasource/src/file_format/orc.rs @@ -15,6 +15,7 @@ use std::pin::Pin; use std::task::{Context, Poll}; +use arrow::compute::cast; use arrow_schema::{Schema, SchemaRef}; use async_trait::async_trait; use datafusion::arrow::record_batch::RecordBatch as DfRecordBatch; @@ -60,12 +61,16 @@ pub async fn infer_orc_schema } pub struct OrcArrowStreamReaderAdapter { + output_schema: SchemaRef, stream: ArrowStreamReader, } impl OrcArrowStreamReaderAdapter { - pub fn new(stream: ArrowStreamReader) -> Self { - Self { stream } + pub fn new(output_schema: SchemaRef, stream: ArrowStreamReader) -> Self { + Self { + stream, + output_schema, + } } } @@ -73,7 +78,7 @@ impl RecordBatchStream for OrcArrowStreamReaderAdapter { fn schema(&self) -> SchemaRef { - self.stream.schema() + self.output_schema.clone() } } @@ -83,6 +88,24 @@ impl Stream for OrcArrowStrea fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { let batch = futures::ready!(Pin::new(&mut self.stream).poll_next(cx)) .map(|r| r.map_err(|e| DataFusionError::External(Box::new(e)))); + + let batch = batch.map(|b| { + b.and_then(|b| { + let mut columns = Vec::with_capacity(b.num_columns()); + for (idx, column) in b.columns().iter().enumerate() { + if column.data_type() != self.output_schema.field(idx).data_type() { + let output = cast(&column, self.output_schema.field(idx).data_type())?; 
columns.push(output) + } else { + columns.push(column.clone()) + } + } + let record_batch = DfRecordBatch::try_new(self.output_schema.clone(), columns)?; + + Ok(record_batch) + }) + }); + Poll::Ready(batch) } } diff --git a/src/frontend/src/statement/copy_table_from.rs b/src/frontend/src/statement/copy_table_from.rs index b79bc1d45d..45807a08f6 100644 --- a/src/frontend/src/statement/copy_table_from.rs +++ b/src/frontend/src/statement/copy_table_from.rs @@ -224,7 +224,7 @@ impl StatementExecutor { let stream = new_orc_stream_reader(reader) .await .context(error::ReadOrcSnafu)?; - let stream = OrcArrowStreamReaderAdapter::new(stream); + let stream = OrcArrowStreamReaderAdapter::new(schema, stream); Ok(Box::pin(stream)) }