refactor: simplify parquet writer (#4112)

* refactor: simplify parquet writer

* chore: fix clippy

* refactor: use AsyncArrowWriter instead of BufferedWriter

* refactor: remove BufferedWriter

* fix: add chunk parameter to avoid entity too small issue

* refactor: use AtomicUsize instead of Mutex

* fix: add chunk argument to stream_to_parquet

* chore: fmt

* wip: fail check

* fix: check

* fmt

* refactor: use impl Future instead of async_trait

* fmt

* refactor: use associated types
Lei, HUANG authored on 2024-06-13 15:32:47 +08:00, committed by GitHub
parent 14a2d83594
commit f8ec46493f
8 changed files with 258 additions and 275 deletions
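
Two of the bullets above ("use impl Future instead of async_trait" and "use associated types") describe the trait pattern the writer refactor below is built on. A minimal sketch of that pattern, not taken from the diff itself (names are illustrative; assumes tokio and Rust 1.75+, where impl-Trait-in-trait is stable):

use std::future::Future;
use tokio::io::AsyncWrite;

/// A factory that asynchronously produces some writer.
/// Returning `impl Future` from a trait method removes the need for the
/// `async_trait` macro and its boxed futures.
trait MakeWriter {
    /// The concrete writer type is an associated type, so each factory
    /// can name its own writer without dynamic dispatch.
    type Writer: AsyncWrite + Send + Unpin;

    fn make(&mut self) -> impl Future<Output = std::io::Result<Self::Writer>>;
}

struct VecFactory;

impl MakeWriter for VecFactory {
    // tokio implements `AsyncWrite` for `Vec<u8>`, so it can serve as a writer.
    type Writer = Vec<u8>;

    // An `async fn` satisfies the `impl Future`-returning trait signature.
    async fn make(&mut self) -> std::io::Result<Self::Writer> {
        Ok(Vec::new())
    }
}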

View File

@@ -16,7 +16,7 @@ use std::result;
use std::sync::Arc;
use arrow::record_batch::RecordBatch;
use arrow_schema::{Schema, SchemaRef};
use arrow_schema::Schema;
use async_trait::async_trait;
use datafusion::datasource::physical_plan::{FileMeta, ParquetFileReaderFactory};
use datafusion::error::Result as DatafusionResult;
@@ -30,13 +30,14 @@ use datafusion::physical_plan::SendableRecordBatchStream;
use futures::future::BoxFuture;
use futures::StreamExt;
use object_store::{FuturesAsyncReader, ObjectStore};
use parquet::arrow::AsyncArrowWriter;
use parquet::basic::{Compression, ZstdLevel};
use parquet::file::properties::WriterProperties;
use snafu::ResultExt;
use tokio_util::compat::{Compat, FuturesAsyncReadCompatExt, FuturesAsyncWriteCompatExt};
use crate::buffered_writer::{ArrowWriterCloser, DfRecordBatchEncoder, LazyBufferedWriter};
use crate::error::{self, Result};
use crate::buffered_writer::{ArrowWriterCloser, DfRecordBatchEncoder};
use crate::error::{self, Result, WriteObjectSnafu, WriteParquetSnafu};
use crate::file_format::FileFormat;
use crate::share_buffer::SharedBuffer;
use crate::DEFAULT_WRITE_BUFFER_SIZE;
@@ -174,75 +175,6 @@ impl ArrowWriterCloser for ArrowWriter<SharedBuffer> {
}
}
/// Parquet writer that buffers row groups in memory and writes buffered data to an underlying
/// storage by chunks to reduce memory consumption.
pub struct BufferedWriter {
inner: InnerBufferedWriter,
}
type InnerBufferedWriter = LazyBufferedWriter<
Compat<object_store::FuturesAsyncWriter>,
ArrowWriter<SharedBuffer>,
impl Fn(String) -> BoxFuture<'static, Result<Compat<object_store::FuturesAsyncWriter>>>,
>;
impl BufferedWriter {
fn make_write_factory(
store: ObjectStore,
concurrency: usize,
) -> impl Fn(String) -> BoxFuture<'static, Result<Compat<object_store::FuturesAsyncWriter>>>
{
move |path| {
let store = store.clone();
Box::pin(async move {
store
.writer_with(&path)
.concurrent(concurrency)
.chunk(DEFAULT_WRITE_BUFFER_SIZE.as_bytes() as usize)
.await
.map(|v| v.into_futures_async_write().compat_write())
.context(error::WriteObjectSnafu { path })
})
}
}
pub async fn try_new(
path: String,
store: ObjectStore,
arrow_schema: SchemaRef,
props: Option<WriterProperties>,
buffer_threshold: usize,
concurrency: usize,
) -> error::Result<Self> {
let buffer = SharedBuffer::with_capacity(buffer_threshold);
let arrow_writer = ArrowWriter::try_new(buffer.clone(), arrow_schema.clone(), props)
.context(error::WriteParquetSnafu { path: &path })?;
Ok(Self {
inner: LazyBufferedWriter::new(
buffer_threshold,
buffer,
arrow_writer,
&path,
Self::make_write_factory(store, concurrency),
),
})
}
/// Write a record batch to stream writer.
pub async fn write(&mut self, arrow_batch: &RecordBatch) -> error::Result<()> {
self.inner.write(arrow_batch).await
}
/// Close parquet writer.
///
/// Return file metadata and bytes written.
pub async fn close(self) -> error::Result<(FileMetaData, u64)> {
self.inner.close_with_arrow_writer().await
}
}
/// Output the stream to a parquet file.
///
/// Returns number of rows written.
@@ -250,47 +182,41 @@ pub async fn stream_to_parquet(
mut stream: SendableRecordBatchStream,
store: ObjectStore,
path: &str,
threshold: usize,
concurrency: usize,
) -> Result<usize> {
let write_props = WriterProperties::builder()
.set_compression(Compression::ZSTD(ZstdLevel::default()))
.build();
let schema = stream.schema();
let mut buffered_writer = BufferedWriter::try_new(
path.to_string(),
store,
schema,
Some(write_props),
threshold,
concurrency,
)
.await?;
let inner_writer = store
.writer_with(path)
.concurrent(concurrency)
.chunk(DEFAULT_WRITE_BUFFER_SIZE.as_bytes() as usize)
.await
.map(|w| w.into_futures_async_write().compat_write())
.context(WriteObjectSnafu { path })?;
let mut writer = AsyncArrowWriter::try_new(inner_writer, schema, Some(write_props))
.context(WriteParquetSnafu { path })?;
let mut rows_written = 0;
while let Some(batch) = stream.next().await {
let batch = batch.context(error::ReadRecordBatchSnafu)?;
buffered_writer.write(&batch).await?;
writer
.write(&batch)
.await
.context(WriteParquetSnafu { path })?;
rows_written += batch.num_rows();
}
buffered_writer.close().await?;
writer.close().await.context(WriteParquetSnafu { path })?;
Ok(rows_written)
}
#[cfg(test)]
mod tests {
use std::env;
use std::sync::Arc;
use common_telemetry::warn;
use common_test_util::find_workspace_path;
use datatypes::arrow::array::{ArrayRef, Int64Array, RecordBatch};
use datatypes::arrow::datatypes::{DataType, Field, Schema};
use object_store::services::S3;
use object_store::ObjectStore;
use rand::{thread_rng, Rng};
use super::*;
use crate::file_format::parquet::BufferedWriter;
use crate::test_util::{format_schema, test_store};
fn test_data_root() -> String {
@@ -308,64 +234,4 @@ mod tests {
assert_eq!(vec!["num: Int64: NULL", "str: Utf8: NULL"], formatted);
}
#[tokio::test]
async fn test_parquet_writer() {
common_telemetry::init_default_ut_logging();
let _ = dotenv::dotenv();
let Ok(bucket) = env::var("GT_MINIO_BUCKET") else {
warn!("ignoring test parquet writer");
return;
};
let mut builder = S3::default();
let _ = builder
.root(&uuid::Uuid::new_v4().to_string())
.access_key_id(&env::var("GT_MINIO_ACCESS_KEY_ID").unwrap())
.secret_access_key(&env::var("GT_MINIO_ACCESS_KEY").unwrap())
.bucket(&bucket)
.region(&env::var("GT_MINIO_REGION").unwrap())
.endpoint(&env::var("GT_MINIO_ENDPOINT_URL").unwrap());
let object_store = ObjectStore::new(builder).unwrap().finish();
let file_path = uuid::Uuid::new_v4().to_string();
let fields = vec![
Field::new("field1", DataType::Int64, true),
Field::new("field0", DataType::Int64, true),
];
let arrow_schema = Arc::new(Schema::new(fields));
let mut buffered_writer = BufferedWriter::try_new(
file_path.clone(),
object_store.clone(),
arrow_schema.clone(),
None,
// Sets a small value.
128,
8,
)
.await
.unwrap();
let rows = 200000;
let generator = || {
let columns: Vec<ArrayRef> = vec![
Arc::new(Int64Array::from(
(0..rows)
.map(|_| thread_rng().gen::<i64>())
.collect::<Vec<_>>(),
)),
Arc::new(Int64Array::from(
(0..rows)
.map(|_| thread_rng().gen::<i64>())
.collect::<Vec<_>>(),
)),
];
RecordBatch::try_new(arrow_schema.clone(), columns).unwrap()
};
let batch = generator();
// Writes about ~30Mi
for _ in 0..10 {
buffered_writer.write(&batch).await.unwrap();
}
buffered_writer.close().await.unwrap();
}
}
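
The net effect of the changes in this file: stream_to_parquet now writes through parquet's AsyncArrowWriter over an OpenDAL writer whose chunk size is set explicitly, which keeps every uploaded part above the backend's minimum and avoids the "entity too small" multipart error mentioned in the commit list. A condensed sketch of that write path, under the same assumptions as the diff (ObjectStore is the OpenDAL operator re-exported by the repo's object_store crate); the 8 MiB chunk and boxed errors are simplifications, where the real function uses DEFAULT_WRITE_BUFFER_SIZE and the crate's snafu error types:

use datafusion::physical_plan::SendableRecordBatchStream;
use futures::StreamExt;
use object_store::ObjectStore;
use parquet::arrow::AsyncArrowWriter;
use parquet::basic::{Compression, ZstdLevel};
use parquet::file::properties::WriterProperties;
use tokio_util::compat::FuturesAsyncWriteCompatExt;

async fn write_stream(
    mut stream: SendableRecordBatchStream,
    store: ObjectStore,
    path: &str,
) -> Result<usize, Box<dyn std::error::Error>> {
    let props = WriterProperties::builder()
        .set_compression(Compression::ZSTD(ZstdLevel::default()))
        .build();
    let schema = stream.schema();
    // `chunk` buffers data so each uploaded part has at least this size;
    // `concurrent` allows several parts to be uploaded in parallel.
    let sink = store
        .writer_with(path)
        .chunk(8 * 1024 * 1024)
        .concurrent(8)
        .await?
        .into_futures_async_write()
        .compat_write();
    let mut writer = AsyncArrowWriter::try_new(sink, schema, Some(props))?;
    let mut rows = 0;
    while let Some(batch) = stream.next().await {
        let batch = batch?;
        writer.write(&batch).await?;
        rows += batch.num_rows();
    }
    writer.close().await?;
    Ok(rows)
}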

View File

@@ -146,10 +146,10 @@ impl AccessLayer {
index_options: request.index_options,
}
.build();
let mut writer = ParquetWriter::new(
let mut writer = ParquetWriter::new_with_object_store(
self.object_store.clone(),
file_path,
request.metadata,
self.object_store.clone(),
indexer,
);
writer.write_all(request.source, write_opts).await?

View File

@@ -127,10 +127,10 @@ impl WriteCache {
.build();
// Write to FileCache.
let mut writer = ParquetWriter::new(
let mut writer = ParquetWriter::new_with_object_store(
self.file_cache.local_store(),
self.file_cache.cache_file_path(parquet_key),
write_request.metadata,
self.file_cache.local_store(),
indexer,
);
@@ -246,7 +246,6 @@ pub struct SstUploadRequest {
#[cfg(test)]
mod tests {
use common_test_util::temp_dir::create_temp_dir;
use super::*;

View File

@@ -151,13 +151,6 @@ pub enum Error {
error: ArrowError,
},
#[snafu(display("Failed to write to buffer"))]
WriteBuffer {
#[snafu(implicit)]
location: Location,
source: common_datasource::error::Error,
},
#[snafu(display("Failed to read parquet file, path: {}", path))]
ReadParquet {
path: String,
@@ -167,6 +160,14 @@ pub enum Error {
location: Location,
},
#[snafu(display("Failed to write parquet file"))]
WriteParquet {
#[snafu(source)]
error: parquet::errors::ParquetError,
#[snafu(implicit)]
location: Location,
},
#[snafu(display("Region {} not found", region_id))]
RegionNotFound {
region_id: RegionId,
@@ -808,7 +809,7 @@ impl ErrorExt for Error {
| BuildEntry { .. } => StatusCode::Internal,
OpenRegion { source, .. } => source.status_code(),
WriteBuffer { source, .. } => source.status_code(),
WriteParquet { .. } => StatusCode::Internal,
WriteGroup { source, .. } => source.status_code(),
FieldTypeMismatch { source, .. } => source.status_code(),
SerializeField { .. } => StatusCode::Internal,

View File

@@ -14,6 +14,14 @@
//! SST in parquet format.
use std::sync::Arc;
use common_base::readable_size::ReadableSize;
use parquet::file::metadata::ParquetMetaData;
use crate::sst::file::FileTimeRange;
use crate::sst::DEFAULT_WRITE_BUFFER_SIZE;
pub(crate) mod file_range;
mod format;
pub(crate) mod helper;
@@ -25,14 +33,6 @@ mod row_selection;
mod stats;
pub mod writer;
use std::sync::Arc;
use common_base::readable_size::ReadableSize;
use parquet::file::metadata::ParquetMetaData;
use crate::sst::file::FileTimeRange;
use crate::sst::DEFAULT_WRITE_BUFFER_SIZE;
/// Key of metadata in parquet SST.
pub const PARQUET_METADATA_KEY: &str = "greptime:metadata";
@@ -79,17 +79,18 @@ pub struct SstInfo {
mod tests {
use std::sync::Arc;
use common_datasource::file_format::parquet::BufferedWriter;
use common_time::Timestamp;
use datafusion_common::{Column, ScalarValue};
use datafusion_expr::{BinaryExpr, Expr, Operator};
use datatypes::arrow;
use datatypes::arrow::array::RecordBatch;
use datatypes::arrow::datatypes::{DataType, Field, Schema};
use parquet::arrow::AsyncArrowWriter;
use parquet::basic::{Compression, Encoding, ZstdLevel};
use parquet::file::metadata::KeyValue;
use parquet::file::properties::WriterProperties;
use table::predicate::Predicate;
use tokio_util::compat::FuturesAsyncWriteCompatExt;
use super::*;
use crate::cache::{CacheManager, PageKey};
@@ -123,13 +124,13 @@ mod tests {
row_group_size: 50,
..Default::default()
};
let mut writer = ParquetWriter::new(
let mut writer = ParquetWriter::new_with_object_store(
object_store.clone(),
file_path,
metadata,
object_store.clone(),
Indexer::default(),
);
let info = writer
.write_all(source, &write_opts)
.await
@@ -178,12 +179,13 @@ mod tests {
..Default::default()
};
// Prepare data.
let mut writer = ParquetWriter::new(
let mut writer = ParquetWriter::new_with_object_store(
object_store.clone(),
file_path,
metadata.clone(),
object_store.clone(),
Indexer::default(),
);
writer
.write_all(source, &write_opts)
.await
@@ -252,12 +254,13 @@ mod tests {
// write the sst file and get sst info
// sst info contains the parquet metadata, which is converted from FileMetaData
let mut writer = ParquetWriter::new(
let mut writer = ParquetWriter::new_with_object_store(
object_store.clone(),
file_path,
metadata.clone(),
object_store.clone(),
Indexer::default(),
);
let sst_info = writer
.write_all(source, &write_opts)
.await
@@ -291,10 +294,10 @@ mod tests {
..Default::default()
};
// Prepare data.
let mut writer = ParquetWriter::new(
let mut writer = ParquetWriter::new_with_object_store(
object_store.clone(),
file_path,
metadata.clone(),
object_store.clone(),
Indexer::default(),
);
writer
@@ -344,10 +347,10 @@ mod tests {
..Default::default()
};
// Prepare data.
let mut writer = ParquetWriter::new(
let mut writer = ParquetWriter::new_with_object_store(
object_store.clone(),
file_path,
metadata.clone(),
object_store.clone(),
Indexer::default(),
);
writer
@@ -379,12 +382,13 @@ mod tests {
..Default::default()
};
// Prepare data.
let mut writer = ParquetWriter::new(
let mut writer = ParquetWriter::new_with_object_store(
object_store.clone(),
file_path,
metadata.clone(),
object_store.clone(),
Indexer::default(),
);
writer
.write_all(source, &write_opts)
.await
@@ -453,15 +457,16 @@ mod tests {
&DataType::LargeBinary,
arrow_schema.field_with_name("field_0").unwrap().data_type()
);
let mut buffered_writer = BufferedWriter::try_new(
file_path.clone(),
object_store.clone(),
let mut writer = AsyncArrowWriter::try_new(
object_store
.writer_with(&file_path)
.concurrent(DEFAULT_WRITE_CONCURRENCY)
.await
.map(|w| w.into_futures_async_write().compat_write())
.unwrap(),
arrow_schema.clone(),
Some(writer_props),
write_opts.write_buffer_size.as_bytes() as usize,
DEFAULT_WRITE_CONCURRENCY,
)
.await
.unwrap();
let batch = new_batch_with_binary(&["a"], 0, 60);
@@ -480,8 +485,8 @@ mod tests {
.collect();
let result = RecordBatch::try_new(arrow_schema, arrays).unwrap();
buffered_writer.write(&result).await.unwrap();
buffered_writer.close().await.unwrap();
writer.write(&result).await.unwrap();
writer.close().await.unwrap();
let builder = ParquetReaderBuilder::new(FILE_DIR.to_string(), handle.clone(), object_store);
let mut reader = builder.build().await.unwrap();

View File

@@ -77,8 +77,8 @@ impl WriteFormat {
}
/// Gets the arrow schema to store in parquet.
pub(crate) fn arrow_schema(&self) -> SchemaRef {
self.arrow_schema.clone()
pub(crate) fn arrow_schema(&self) -> &SchemaRef {
&self.arrow_schema
}
/// Convert `batch` to an arrow record batch to store in parquet.
@@ -700,7 +700,7 @@ mod tests {
fn test_to_sst_arrow_schema() {
let metadata = build_test_region_metadata();
let write_format = WriteFormat::new(metadata);
assert_eq!(build_test_arrow_schema(), write_format.arrow_schema());
assert_eq!(&build_test_arrow_schema(), write_format.arrow_schema());
}
#[test]

View File

@@ -14,13 +14,16 @@
//! Parquet writer.
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll};
use common_datasource::file_format::parquet::BufferedWriter;
use common_telemetry::debug;
use common_time::Timestamp;
use futures::TryFutureExt;
use object_store::ObjectStore;
use datatypes::arrow::datatypes::SchemaRef;
use object_store::{FuturesAsyncWriter, ObjectStore};
use parquet::arrow::AsyncArrowWriter;
use parquet::basic::{Compression, Encoding, ZstdLevel};
use parquet::file::metadata::KeyValue;
use parquet::file::properties::{WriterProperties, WriterPropertiesBuilder};
@@ -28,38 +31,78 @@ use parquet::schema::types::ColumnPath;
use snafu::ResultExt;
use store_api::metadata::RegionMetadataRef;
use store_api::storage::consts::SEQUENCE_COLUMN_NAME;
use tokio::io::AsyncWrite;
use tokio_util::compat::{Compat, FuturesAsyncWriteCompatExt};
use crate::error::{InvalidMetadataSnafu, Result, WriteBufferSnafu};
use crate::error::{InvalidMetadataSnafu, OpenDalSnafu, Result, WriteParquetSnafu};
use crate::read::{Batch, Source};
use crate::sst::index::Indexer;
use crate::sst::parquet::format::WriteFormat;
use crate::sst::parquet::helper::parse_parquet_metadata;
use crate::sst::parquet::{SstInfo, WriteOptions, PARQUET_METADATA_KEY};
use crate::sst::DEFAULT_WRITE_CONCURRENCY;
use crate::sst::{DEFAULT_WRITE_BUFFER_SIZE, DEFAULT_WRITE_CONCURRENCY};
/// Parquet SST writer.
pub struct ParquetWriter {
/// SST output file path.
file_path: String,
pub struct ParquetWriter<F: WriterFactory> {
writer: Option<AsyncArrowWriter<SizeAwareWriter<F::Writer>>>,
writer_factory: F,
/// Region metadata of the source and the target SST.
metadata: RegionMetadataRef,
object_store: ObjectStore,
indexer: Indexer,
bytes_written: Arc<AtomicUsize>,
}
impl ParquetWriter {
/// Creates a new parquet SST writer.
pub fn new(
file_path: String,
metadata: RegionMetadataRef,
pub trait WriterFactory {
type Writer: AsyncWrite + Send + Unpin;
fn create(&mut self) -> impl Future<Output = Result<Self::Writer>>;
}
pub struct ObjectStoreWriterFactory {
path: String,
object_store: ObjectStore,
}
impl WriterFactory for ObjectStoreWriterFactory {
type Writer = Compat<FuturesAsyncWriter>;
async fn create(&mut self) -> Result<Self::Writer> {
self.object_store
.writer_with(&self.path)
.chunk(DEFAULT_WRITE_BUFFER_SIZE.as_bytes() as usize)
.concurrent(DEFAULT_WRITE_CONCURRENCY)
.await
.map(|v| v.into_futures_async_write().compat_write())
.context(OpenDalSnafu)
}
}
impl ParquetWriter<ObjectStoreWriterFactory> {
pub fn new_with_object_store(
object_store: ObjectStore,
path: String,
metadata: RegionMetadataRef,
indexer: Indexer,
) -> ParquetWriter {
ParquetWriter {
file_path,
) -> ParquetWriter<ObjectStoreWriterFactory> {
ParquetWriter::new(
ObjectStoreWriterFactory { path, object_store },
metadata,
object_store,
indexer,
)
}
}
impl<F> ParquetWriter<F>
where
F: WriterFactory,
{
/// Creates a new parquet SST writer.
pub fn new(factory: F, metadata: RegionMetadataRef, indexer: Indexer) -> ParquetWriter<F> {
ParquetWriter {
writer: None,
writer_factory: factory,
metadata,
indexer,
bytes_written: Arc::new(AtomicUsize::new(0)),
}
}
@@ -71,42 +114,24 @@ impl ParquetWriter {
mut source: Source,
opts: &WriteOptions,
) -> Result<Option<SstInfo>> {
let json = self.metadata.to_json().context(InvalidMetadataSnafu)?;
let key_value_meta = KeyValue::new(PARQUET_METADATA_KEY.to_string(), json);
// TODO(yingwen): Find and set proper column encoding for internal columns: op type and tsid.
let props_builder = WriterProperties::builder()
.set_key_value_metadata(Some(vec![key_value_meta]))
.set_compression(Compression::ZSTD(ZstdLevel::default()))
.set_encoding(Encoding::PLAIN)
.set_max_row_group_size(opts.row_group_size);
let props_builder = Self::customize_column_config(props_builder, &self.metadata);
let writer_props = props_builder.build();
let write_format = WriteFormat::new(self.metadata.clone());
let mut buffered_writer = BufferedWriter::try_new(
self.file_path.clone(),
self.object_store.clone(),
write_format.arrow_schema(),
Some(writer_props),
opts.write_buffer_size.as_bytes() as usize,
DEFAULT_WRITE_CONCURRENCY,
)
.await
.context(WriteBufferSnafu)?;
let mut stats = SourceStats::default();
while let Some(batch) = write_next_batch(&mut source, &write_format, &mut buffered_writer)
.or_else(|err| async {
// abort index creation if error occurs.
self.indexer.abort().await;
Err(err)
})
.await?
while let Some(res) = self
.write_next_batch(&mut source, &write_format, opts)
.await
.transpose()
{
stats.update(&batch);
self.indexer.update(&batch).await;
match res {
Ok(batch) => {
stats.update(&batch);
self.indexer.update(&batch).await;
}
Err(e) => {
self.indexer.abort().await;
return Err(e);
}
}
}
let index_size = self.indexer.finish().await;
@@ -114,16 +139,18 @@ impl ParquetWriter {
let index_file_size = index_size.unwrap_or(0) as u64;
if stats.num_rows == 0 {
debug!(
"No data written, try to stop the writer: {}",
self.file_path
);
buffered_writer.close().await.context(WriteBufferSnafu)?;
return Ok(None);
}
let (file_meta, file_size) = buffered_writer.close().await.context(WriteBufferSnafu)?;
let Some(mut arrow_writer) = self.writer.take() else {
// No batch actually written.
return Ok(None);
};
arrow_writer.flush().await.context(WriteParquetSnafu)?;
let file_meta = arrow_writer.close().await.context(WriteParquetSnafu)?;
let file_size = self.bytes_written.load(Ordering::Relaxed) as u64;
// Safety: num rows > 0 so we must have min/max.
let time_range = stats.time_range.unwrap();
@@ -160,24 +187,59 @@ impl ParquetWriter {
.set_column_encoding(ts_col.clone(), Encoding::DELTA_BINARY_PACKED)
.set_column_dictionary_enabled(ts_col, false)
}
}
async fn write_next_batch(
source: &mut Source,
write_format: &WriteFormat,
buffered_writer: &mut BufferedWriter,
) -> Result<Option<Batch>> {
let Some(batch) = source.next_batch().await? else {
return Ok(None);
};
async fn write_next_batch(
&mut self,
source: &mut Source,
write_format: &WriteFormat,
opts: &WriteOptions,
) -> Result<Option<Batch>> {
let Some(batch) = source.next_batch().await? else {
return Ok(None);
};
let arrow_batch = write_format.convert_batch(&batch)?;
buffered_writer
.write(&arrow_batch)
.await
.context(WriteBufferSnafu)?;
let arrow_batch = write_format.convert_batch(&batch)?;
self.maybe_init_writer(write_format.arrow_schema(), opts)
.await?
.write(&arrow_batch)
.await
.context(WriteParquetSnafu)?;
Ok(Some(batch))
}
Ok(Some(batch))
async fn maybe_init_writer(
&mut self,
schema: &SchemaRef,
opts: &WriteOptions,
) -> Result<&mut AsyncArrowWriter<SizeAwareWriter<F::Writer>>> {
if let Some(ref mut w) = self.writer {
Ok(w)
} else {
let json = self.metadata.to_json().context(InvalidMetadataSnafu)?;
let key_value_meta = KeyValue::new(PARQUET_METADATA_KEY.to_string(), json);
// TODO(yingwen): Find and set proper column encoding for internal columns: op type and tsid.
let props_builder = WriterProperties::builder()
.set_key_value_metadata(Some(vec![key_value_meta]))
.set_compression(Compression::ZSTD(ZstdLevel::default()))
.set_encoding(Encoding::PLAIN)
.set_max_row_group_size(opts.row_group_size);
let props_builder = Self::customize_column_config(props_builder, &self.metadata);
let writer_props = props_builder.build();
let writer = SizeAwareWriter::new(
self.writer_factory.create().await?,
self.bytes_written.clone(),
);
let arrow_writer =
AsyncArrowWriter::try_new(writer, schema.clone(), Some(writer_props))
.context(WriteParquetSnafu)?;
self.writer = Some(arrow_writer);
// safety: self.writer is assigned above
Ok(self.writer.as_mut().unwrap())
}
}
}
#[derive(Default)]
@@ -208,3 +270,54 @@ impl SourceStats {
}
}
}
/// Workaround for the fact that [AsyncArrowWriter] does not provide a method
/// to get total bytes written after close.
struct SizeAwareWriter<W> {
inner: W,
size: Arc<AtomicUsize>,
}
impl<W> SizeAwareWriter<W> {
fn new(inner: W, size: Arc<AtomicUsize>) -> Self {
Self {
inner,
size: size.clone(),
}
}
}
impl<W> AsyncWrite for SizeAwareWriter<W>
where
W: AsyncWrite + Unpin,
{
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<std::result::Result<usize, std::io::Error>> {
let this = self.as_mut().get_mut();
match Pin::new(&mut this.inner).poll_write(cx, buf) {
Poll::Ready(Ok(bytes_written)) => {
this.size.fetch_add(bytes_written, Ordering::Relaxed);
Poll::Ready(Ok(bytes_written))
}
other => other,
}
}
fn poll_flush(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<std::result::Result<(), std::io::Error>> {
Pin::new(&mut self.inner).poll_flush(cx)
}
fn poll_shutdown(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<std::result::Result<(), std::io::Error>> {
Pin::new(&mut self.inner).poll_shutdown(cx)
}
}
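
Because the writer is now generic over WriterFactory with an associated Writer type, alternative sinks can be supplied without touching ParquetWriter itself. A hypothetical factory, not part of this change, that writes SST output to a local file via tokio might look as follows (it assumes access to the WriterFactory trait and the Result alias defined in this module):

struct LocalFileWriterFactory {
    path: std::path::PathBuf,
}

impl WriterFactory for LocalFileWriterFactory {
    // `tokio::fs::File` implements `AsyncWrite + Send + Unpin`,
    // so it satisfies the associated-type bound.
    type Writer = tokio::fs::File;

    async fn create(&mut self) -> Result<Self::Writer> {
        // The sketch does not pick a concrete error variant for I/O failures,
        // so it unwraps here; a real factory would map the error properly.
        Ok(tokio::fs::File::create(&self.path)
            .await
            .expect("create local parquet file"))
    }
}

// Illustrative usage with the generic constructor:
// let mut writer = ParquetWriter::new(
//     LocalFileWriterFactory { path: "/tmp/example.parquet".into() },
//     metadata,
//     Indexer::default(),
// );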

View File

@@ -79,7 +79,6 @@ impl StatementExecutor {
Box::pin(DfRecordBatchStreamAdapter::new(stream)),
object_store,
path,
threshold,
WRITE_CONCURRENCY,
)
.await