diff --git a/src/common/recordbatch/src/adapter.rs b/src/common/recordbatch/src/adapter.rs index 2b8436ec4e..14b8fba0dd 100644 --- a/src/common/recordbatch/src/adapter.rs +++ b/src/common/recordbatch/src/adapter.rs @@ -121,7 +121,8 @@ impl Stream for RecordBatchStreamAdapter { enum AsyncRecordBatchStreamAdapterState { Uninit(FutureStream), - Inited(std::result::Result), + Ready(DfSendableRecordBatchStream), + Failed, } pub struct AsyncRecordBatchStreamAdapter { @@ -151,28 +152,26 @@ impl Stream for AsyncRecordBatchStreamAdapter { loop { match &mut self.state { AsyncRecordBatchStreamAdapterState::Uninit(stream_future) => { - self.state = AsyncRecordBatchStreamAdapterState::Inited(ready!(Pin::new( - stream_future - ) - .poll(cx))); - continue; + match ready!(Pin::new(stream_future).poll(cx)) { + Ok(stream) => { + self.state = AsyncRecordBatchStreamAdapterState::Ready(stream); + continue; + } + Err(e) => { + self.state = AsyncRecordBatchStreamAdapterState::Failed; + return Poll::Ready(Some( + Err(e).context(error::InitRecordbatchStreamSnafu), + )); + } + }; } - AsyncRecordBatchStreamAdapterState::Inited(stream) => match stream { - Ok(stream) => { - return Poll::Ready(ready!(Pin::new(stream).poll_next(cx)).map(|df| { - let df_record_batch = df.context(error::PollStreamSnafu)?; - RecordBatch::try_from_df_record_batch(self.schema(), df_record_batch) - })); - } - Err(e) => { - return Poll::Ready(Some( - error::CreateRecordBatchesSnafu { - reason: format!("Read error {:?} from stream", e), - } - .fail(), - )) - } - }, + AsyncRecordBatchStreamAdapterState::Ready(stream) => { + return Poll::Ready(ready!(Pin::new(stream).poll_next(cx)).map(|x| { + let df_record_batch = x.context(error::PollStreamSnafu)?; + RecordBatch::try_from_df_record_batch(self.schema(), df_record_batch) + })) + } + AsyncRecordBatchStreamAdapterState::Failed => return Poll::Ready(None), } } } @@ -183,3 +182,104 @@ impl Stream for AsyncRecordBatchStreamAdapter { (0, None) } } + +#[cfg(test)] +mod test { + use common_error::mock::MockError; + use common_error::prelude::{BoxedError, StatusCode}; + use datatypes::prelude::ConcreteDataType; + use datatypes::schema::ColumnSchema; + use datatypes::vectors::Int32Vector; + + use super::*; + use crate::RecordBatches; + + #[tokio::test] + async fn test_async_recordbatch_stream_adaptor() { + struct MaybeErrorRecordBatchStream { + items: Vec>, + } + + impl RecordBatchStream for MaybeErrorRecordBatchStream { + fn schema(&self) -> SchemaRef { + unimplemented!() + } + } + + impl Stream for MaybeErrorRecordBatchStream { + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + _: &mut Context<'_>, + ) -> Poll> { + if let Some(batch) = self.items.pop() { + Poll::Ready(Some(Ok(batch?))) + } else { + Poll::Ready(None) + } + } + } + + fn new_future_stream( + maybe_recordbatches: Result>>, + ) -> FutureStream { + Box::pin(async move { + maybe_recordbatches + .map(|items| { + Box::pin(DfRecordBatchStreamAdapter::new(Box::pin( + MaybeErrorRecordBatchStream { items }, + ))) as _ + }) + .map_err(|e| DataFusionError::External(Box::new(e))) + }) + } + + let schema = Arc::new(Schema::new(vec![ColumnSchema::new( + "a", + ConcreteDataType::int32_datatype(), + false, + )])); + let batch1 = RecordBatch::new( + schema.clone(), + vec![Arc::new(Int32Vector::from_slice(&[1])) as _], + ) + .unwrap(); + let batch2 = RecordBatch::new( + schema.clone(), + vec![Arc::new(Int32Vector::from_slice(&[2])) as _], + ) + .unwrap(); + + let success_stream = new_future_stream(Ok(vec![Ok(batch1.clone()), Ok(batch2.clone())])); + let adapter = AsyncRecordBatchStreamAdapter::new(schema.clone(), success_stream); + let collected = RecordBatches::try_collect(Box::pin(adapter)).await.unwrap(); + assert_eq!( + collected, + RecordBatches::try_new(schema.clone(), vec![batch2.clone(), batch1.clone()]).unwrap() + ); + + let poll_err_stream = new_future_stream(Ok(vec![ + Ok(batch1.clone()), + Err(error::Error::External { + source: BoxedError::new(MockError::new(StatusCode::Unknown)), + }), + ])); + let adapter = AsyncRecordBatchStreamAdapter::new(schema.clone(), poll_err_stream); + let result = RecordBatches::try_collect(Box::pin(adapter)).await; + assert_eq!( + result.unwrap_err().to_string(), + "Failed to poll stream, source: External error: External error, source: Unknown" + ); + + let failed_to_init_stream = new_future_stream(Err(error::Error::External { + source: BoxedError::new(MockError::new(StatusCode::Internal)), + })); + let adapter = AsyncRecordBatchStreamAdapter::new(schema.clone(), failed_to_init_stream); + let result = RecordBatches::try_collect(Box::pin(adapter)).await; + assert_eq!( + result.unwrap_err().to_string(), + "Failed to init Recordbatch stream, source: External error: External error, source: Internal" + ); + } +} diff --git a/src/common/recordbatch/src/error.rs b/src/common/recordbatch/src/error.rs index 0937441338..c77e2f3f48 100644 --- a/src/common/recordbatch/src/error.rs +++ b/src/common/recordbatch/src/error.rs @@ -64,6 +64,12 @@ pub enum Error { source: datatypes::arrow::error::ArrowError, backtrace: Backtrace, }, + + #[snafu(display("Failed to init Recordbatch stream, source: {}", source))] + InitRecordbatchStream { + source: datafusion_common::DataFusionError, + backtrace: Backtrace, + }, } impl ErrorExt for Error { @@ -74,7 +80,8 @@ impl ErrorExt for Error { Error::DataTypes { .. } | Error::CreateRecordBatches { .. } | Error::PollStream { .. } - | Error::Format { .. } => StatusCode::Internal, + | Error::Format { .. } + | Error::InitRecordbatchStream { .. } => StatusCode::Internal, Error::External { source } => source.status_code(),