refactor: refactor BufferedWriter (#1439)

* feat: implement ApproximateBufWriter

* refactor: refactor BufferedWriter

* refactor: remove ApproximateBufWriter

* fix: fix losing pending writes issue

* chore: fmt

* chore: remove unused import

* chore: rename method name

* feat: return written row count

* chore: apply suggestions from CR

* fix: fix counting the bytes_written twice issue
Author: Weny Xu
Date: 2023-04-27 15:45:33 +09:00 (committed by GitHub)
Parent: 09f55e3cd8
Commit: bf35620904
17 changed files with 571 additions and 146 deletions

Cargo.lock (generated)

@@ -1612,8 +1612,10 @@ dependencies = [
"async-compression",
"async-trait",
"bytes",
"common-base",
"common-error",
"common-runtime",
"common-test-util",
"datafusion",
"derive_builder 0.12.0",
"futures",
@@ -8322,6 +8324,7 @@ dependencies = [
"atomic_float",
"bytes",
"common-base",
"common-datasource",
"common-error",
"common-query",
"common-recordbatch",


@@ -80,7 +80,7 @@ snafu = { version = "0.7", features = ["backtraces"] }
sqlparser = "0.33"
tempfile = "3"
tokio = { version = "1.24.2", features = ["full"] }
tokio-util = { version = "0.7", features = ["io-util"] }
tokio-util = { version = "0.7", features = ["io-util", "compat"] }
tonic = { version = "0.9", features = ["tls"] }
uuid = { version = "1", features = ["serde", "v4", "fast-rng"] }
metrics = "0.20"


@@ -17,6 +17,7 @@ async-compression = { version = "0.3", features = [
] }
async-trait.workspace = true
bytes = "1.1"
common-base = { path = "../base" }
common-error = { path = "../error" }
common-runtime = { path = "../runtime" }
datafusion.workspace = true
@@ -28,3 +29,6 @@ snafu.workspace = true
tokio.workspace = true
tokio-util.workspace = true
url = "2.3"
[dev-dependencies]
common-test-util = { path = "../test-util" }


@@ -0,0 +1,138 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use arrow::record_batch::RecordBatch;
use async_trait::async_trait;
use datafusion::parquet::format::FileMetaData;
use object_store::Writer;
use snafu::{OptionExt, ResultExt};
use tokio::io::{AsyncWrite, AsyncWriteExt};
use tokio_util::compat::Compat;
use crate::error::{self, Result};
use crate::share_buffer::SharedBuffer;
pub struct BufferedWriter<T, U> {
writer: T,
/// `None` indicates that the [`BufferedWriter`] has been closed.
encoder: Option<U>,
buffer: SharedBuffer,
bytes_written: u64,
flushed: bool,
threshold: usize,
}
pub trait DfRecordBatchEncoder {
fn write(&mut self, batch: &RecordBatch) -> Result<()>;
}
#[async_trait]
pub trait ArrowWriterCloser {
async fn close(mut self) -> Result<FileMetaData>;
}
pub type DefaultBufferedWriter<E> = BufferedWriter<Compat<Writer>, E>;
impl<T: AsyncWrite + Send + Unpin, U: DfRecordBatchEncoder + ArrowWriterCloser>
BufferedWriter<T, U>
{
pub async fn close_with_arrow_writer(mut self) -> Result<(FileMetaData, u64)> {
let encoder = self
.encoder
.take()
.context(error::BufferedWriterClosedSnafu)?;
let metadata = encoder.close().await?;
let written = self.try_flush(true).await?;
// It's important to shut down: it flushes all pending writes.
self.close().await?;
Ok((metadata, written))
}
}
impl<T: AsyncWrite + Send + Unpin, U: DfRecordBatchEncoder> BufferedWriter<T, U> {
pub async fn close(&mut self) -> Result<()> {
self.writer.shutdown().await.context(error::AsyncWriteSnafu)
}
pub fn new(threshold: usize, buffer: SharedBuffer, encoder: U, writer: T) -> Self {
Self {
threshold,
writer,
encoder: Some(encoder),
buffer,
bytes_written: 0,
flushed: false,
}
}
pub fn bytes_written(&self) -> u64 {
self.bytes_written
}
pub async fn write(&mut self, batch: &RecordBatch) -> Result<()> {
let encoder = self
.encoder
.as_mut()
.context(error::BufferedWriterClosedSnafu)?;
encoder.write(batch)?;
self.try_flush(false).await?;
Ok(())
}
pub fn flushed(&self) -> bool {
self.flushed
}
pub async fn try_flush(&mut self, all: bool) -> Result<u64> {
let mut bytes_written: u64 = 0;
// Once the buffered data size reaches the threshold, split the data into chunks
// (typically 4MB) and write them to the underlying storage.
while self.buffer.buffer.lock().unwrap().len() >= self.threshold {
let chunk = {
let mut buffer = self.buffer.buffer.lock().unwrap();
buffer.split_to(self.threshold)
};
let size = chunk.len();
self.writer
.write_all(&chunk)
.await
.context(error::AsyncWriteSnafu)?;
bytes_written += size as u64;
}
if all {
bytes_written += self.try_flush_all().await?;
}
self.flushed = bytes_written > 0;
self.bytes_written += bytes_written;
Ok(bytes_written)
}
async fn try_flush_all(&mut self) -> Result<u64> {
let remain = self.buffer.buffer.lock().unwrap().split();
let size = remain.len();
self.writer
.write_all(&remain)
.await
.context(error::AsyncWriteSnafu)?;
Ok(size as u64)
}
}
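The new writer is generic over any `AsyncWrite` sink and any `DfRecordBatchEncoder`, so the flushing logic above can be exercised outside of object storage. Below is a minimal usage sketch, not part of this diff: it reuses the csv encoder impl added later in this commit, stands in `tokio::io::sink()` for the object-store writer, and picks an illustrative 4 MiB threshold.

// Sketch only: names from this commit (BufferedWriter, SharedBuffer, the csv encoder impl)
// plus tokio::io::sink() as a stand-in for the real object-store writer.
use arrow::record_batch::RecordBatch;
use common_datasource::buffered_writer::BufferedWriter;
use common_datasource::error::Result;
use common_datasource::share_buffer::SharedBuffer;

async fn write_batches_to_sink(batches: &[RecordBatch]) -> Result<u64> {
    let threshold = 4 * 1024 * 1024; // illustrative 4 MiB chunk size
    let buffer = SharedBuffer::with_capacity(threshold);
    // The encoder writes into the shared buffer; the BufferedWriter drains it in chunks.
    let encoder = arrow::csv::Writer::new(buffer.clone());
    let mut writer = BufferedWriter::new(threshold, buffer, encoder, tokio::io::sink());

    for batch in batches {
        // Encodes the batch and flushes whenever the buffered data exceeds the threshold.
        writer.write(batch).await?;
    }
    // Flush whatever is still buffered, then shut the underlying writer down.
    writer.try_flush(true).await?;
    writer.close().await?;
    Ok(writer.bytes_written())
}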


@@ -14,7 +14,9 @@
use std::any::Any;
use arrow_schema::ArrowError;
use common_error::prelude::*;
use datafusion::parquet::errors::ParquetError;
use snafu::Location;
use url::ParseError;
@@ -68,6 +70,37 @@ pub enum Error {
source: object_store::Error,
},
#[snafu(display("Failed to write object to path: {}, source: {}", path, source))]
WriteObject {
path: String,
location: Location,
source: object_store::Error,
},
#[snafu(display("Failed to write: {}", source))]
AsyncWrite {
source: std::io::Error,
location: Location,
},
#[snafu(display("Failed to write record batch: {}", source))]
WriteRecordBatch {
location: Location,
source: ArrowError,
},
#[snafu(display("Failed to encode record batch: {}", source))]
EncodeRecordBatch {
location: Location,
source: ParquetError,
},
#[snafu(display("Failed to read record batch: {}", source))]
ReadRecordBatch {
location: Location,
source: datafusion::error::DataFusionError,
},
#[snafu(display("Failed to read parquet source: {}", source))]
ReadParquetSnafu {
location: Location,
@@ -118,6 +151,9 @@ pub enum Error {
#[snafu(display("Missing required field: {}", name))]
MissingRequiredField { name: String, location: Location },
#[snafu(display("Buffered writer closed"))]
BufferedWriterClosed { location: Location },
}
pub type Result<T> = std::result::Result<T, Error>;
@@ -126,9 +162,11 @@ impl ErrorExt for Error {
fn status_code(&self) -> StatusCode {
use Error::*;
match self {
BuildBackend { .. } | ListObjects { .. } | ReadObject { .. } => {
StatusCode::StorageUnavailable
}
BuildBackend { .. }
| ListObjects { .. }
| ReadObject { .. }
| WriteObject { .. }
| AsyncWrite { .. } => StatusCode::StorageUnavailable,
UnsupportedBackendProtocol { .. }
| UnsupportedCompressionType { .. }
@@ -144,7 +182,12 @@ impl ErrorExt for Error {
| MergeSchema { .. }
| MissingRequiredField { .. } => StatusCode::InvalidArguments,
Decompression { .. } | JoinHandle { .. } => StatusCode::Unexpected,
Decompression { .. }
| JoinHandle { .. }
| ReadRecordBatch { .. }
| WriteRecordBatch { .. }
| EncodeRecordBatch { .. }
| BufferedWriterClosed { .. } => StatusCode::Unexpected,
}
}
@@ -166,6 +209,12 @@ impl ErrorExt for Error {
ParseFormat { location, .. } => Some(*location),
MergeSchema { location, .. } => Some(*location),
MissingRequiredField { location, .. } => Some(*location),
WriteObject { location, .. } => Some(*location),
ReadRecordBatch { location, .. } => Some(*location),
WriteRecordBatch { location, .. } => Some(*location),
AsyncWrite { location, .. } => Some(*location),
EncodeRecordBatch { location, .. } => Some(*location),
BufferedWriterClosed { location, .. } => Some(*location),
UnsupportedBackendProtocol { location, .. } => Some(*location),
EmptyHostPath { location, .. } => Some(*location),


@@ -31,15 +31,19 @@ use async_trait::async_trait;
use bytes::{Buf, Bytes};
use datafusion::error::{DataFusionError, Result as DataFusionResult};
use datafusion::physical_plan::file_format::FileOpenFuture;
use datafusion::physical_plan::SendableRecordBatchStream;
use futures::StreamExt;
use object_store::ObjectStore;
use snafu::ResultExt;
use tokio_util::compat::FuturesAsyncWriteCompatExt;
use self::csv::CsvFormat;
use self::json::JsonFormat;
use self::parquet::ParquetFormat;
use crate::buffered_writer::{BufferedWriter, DfRecordBatchEncoder};
use crate::compression::CompressionType;
use crate::error::{self, Result};
use crate::share_buffer::SharedBuffer;
pub const FORMAT_COMPRESSION_TYPE: &str = "COMPRESSION_TYPE";
pub const FORMAT_DELIMTERL: &str = "DELIMTERL";
@@ -167,3 +171,36 @@ pub async fn infer_schemas(
}
ArrowSchema::try_merge(schemas).context(error::MergeSchemaSnafu)
}
pub async fn stream_to_file<T: DfRecordBatchEncoder, U: Fn(SharedBuffer) -> T>(
mut stream: SendableRecordBatchStream,
store: ObjectStore,
path: String,
threshold: usize,
encoder_factory: U,
) -> Result<usize> {
let writer = store
.writer(&path)
.await
.context(error::WriteObjectSnafu { path: &path })?
.compat_write();
let buffer = SharedBuffer::with_capacity(threshold);
let encoder = encoder_factory(buffer.clone());
let mut writer = BufferedWriter::new(threshold, buffer, encoder, writer);
let mut rows = 0;
while let Some(batch) = stream.next().await {
let batch = batch.context(error::ReadRecordBatchSnafu)?;
writer.write(&batch).await?;
rows += batch.num_rows();
}
// Flushes all pending writes
writer.try_flush(true).await?;
writer.close().await?;
Ok(rows)
}
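`stream_to_file` is the piece the format modules build on: it wires an `ObjectStore` writer, a `SharedBuffer`, and a format-specific encoder together and reports the number of rows written. A hedged sketch of calling the json wrapper built on it follows; the `Fs`-backed store mirrors `test_util::test_store`, and the `/tmp` root, output path, and 8 MiB threshold are illustrative values, not taken from the diff.

// Sketch only: exercises the public stream_to_json wrapper added in this commit.
use datafusion::physical_plan::SendableRecordBatchStream;
use object_store::services::Fs;
use object_store::ObjectStore;

async fn export_to_json(stream: SendableRecordBatchStream) -> common_datasource::error::Result<usize> {
    // A local-filesystem store, as in the tests; any ObjectStore backend works the same way.
    let mut builder = Fs::default();
    builder.root("/tmp");
    let store = ObjectStore::new(builder).unwrap().finish();
    // Returns how many rows were streamed into the output file.
    common_datasource::file_format::json::stream_to_json(
        stream,
        store,
        "export/output.json".to_string(),
        8 * 1024 * 1024, // illustrative flush threshold
    )
    .await
}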


@@ -18,19 +18,24 @@ use std::sync::Arc;
use arrow::csv;
use arrow::csv::reader::infer_reader_schema as infer_csv_schema;
use arrow::record_batch::RecordBatch;
use arrow_schema::{Schema, SchemaRef};
use async_trait::async_trait;
use common_runtime;
use datafusion::error::Result as DataFusionResult;
use datafusion::physical_plan::file_format::{FileMeta, FileOpenFuture, FileOpener};
use datafusion::physical_plan::SendableRecordBatchStream;
use derive_builder::Builder;
use object_store::ObjectStore;
use snafu::ResultExt;
use tokio_util::io::SyncIoBridge;
use super::stream_to_file;
use crate::buffered_writer::DfRecordBatchEncoder;
use crate::compression::CompressionType;
use crate::error::{self, Result};
use crate::file_format::{self, open_with_decoder, FileFormat};
use crate::share_buffer::SharedBuffer;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CsvFormat {
@@ -182,6 +187,24 @@ impl FileFormat for CsvFormat {
}
}
pub async fn stream_to_csv(
stream: SendableRecordBatchStream,
store: ObjectStore,
path: String,
threshold: usize,
) -> Result<usize> {
stream_to_file(stream, store, path, threshold, |buffer| {
csv::Writer::new(buffer)
})
.await
}
impl DfRecordBatchEncoder for csv::Writer<SharedBuffer> {
fn write(&mut self, batch: &RecordBatch) -> Result<()> {
self.write(batch).context(error::WriteRecordBatchSnafu)
}
}
#[cfg(test)]
mod tests {


@@ -19,19 +19,25 @@ use std::sync::Arc;
use arrow::datatypes::SchemaRef;
use arrow::json::reader::{infer_json_schema_from_iterator, ValueIter};
use arrow::json::RawReaderBuilder;
use arrow::json::writer::LineDelimited;
use arrow::json::{self, RawReaderBuilder};
use arrow::record_batch::RecordBatch;
use arrow_schema::Schema;
use async_trait::async_trait;
use common_runtime;
use datafusion::error::{DataFusionError, Result as DataFusionResult};
use datafusion::physical_plan::file_format::{FileMeta, FileOpenFuture, FileOpener};
use datafusion::physical_plan::SendableRecordBatchStream;
use object_store::ObjectStore;
use snafu::ResultExt;
use tokio_util::io::SyncIoBridge;
use super::stream_to_file;
use crate::buffered_writer::DfRecordBatchEncoder;
use crate::compression::CompressionType;
use crate::error::{self, Result};
use crate::file_format::{self, open_with_decoder, FileFormat};
use crate::share_buffer::SharedBuffer;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct JsonFormat {
@@ -140,6 +146,25 @@ impl FileOpener for JsonOpener {
}
}
pub async fn stream_to_json(
stream: SendableRecordBatchStream,
store: ObjectStore,
path: String,
threshold: usize,
) -> Result<usize> {
stream_to_file(stream, store, path, threshold, |buffer| {
json::LineDelimitedWriter::new(buffer)
})
.await
}
impl DfRecordBatchEncoder for json::Writer<SharedBuffer, LineDelimited> {
fn write(&mut self, batch: &RecordBatch) -> Result<()> {
self.write(batch.clone())
.context(error::WriteRecordBatchSnafu)
}
}
#[cfg(test)]
mod tests {
use super::*;


@@ -15,21 +15,25 @@
use std::result;
use std::sync::Arc;
use arrow::record_batch::RecordBatch;
use arrow_schema::Schema;
use async_trait::async_trait;
use datafusion::error::Result as DatafusionResult;
use datafusion::parquet::arrow::async_reader::AsyncFileReader;
use datafusion::parquet::arrow::parquet_to_arrow_schema;
use datafusion::parquet::arrow::{parquet_to_arrow_schema, ArrowWriter};
use datafusion::parquet::errors::{ParquetError, Result as ParquetResult};
use datafusion::parquet::file::metadata::ParquetMetaData;
use datafusion::parquet::format::FileMetaData;
use datafusion::physical_plan::file_format::{FileMeta, ParquetFileReaderFactory};
use datafusion::physical_plan::metrics::ExecutionPlanMetricsSet;
use futures::future::BoxFuture;
use object_store::{ObjectStore, Reader};
use snafu::ResultExt;
use crate::buffered_writer::{ArrowWriterCloser, DfRecordBatchEncoder};
use crate::error::{self, Result};
use crate::file_format::FileFormat;
use crate::share_buffer::SharedBuffer;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct ParquetFormat {}
@@ -139,6 +143,19 @@ impl AsyncFileReader for LazyParquetFileReader {
}
}
impl DfRecordBatchEncoder for ArrowWriter<SharedBuffer> {
fn write(&mut self, batch: &RecordBatch) -> Result<()> {
self.write(batch).context(error::EncodeRecordBatchSnafu)
}
}
#[async_trait]
impl ArrowWriterCloser for ArrowWriter<SharedBuffer> {
async fn close(self) -> Result<FileMetaData> {
self.close().context(error::EncodeRecordBatchSnafu)
}
}
#[cfg(test)]
mod tests {
use super::*;


@@ -17,10 +17,7 @@ use std::collections::HashMap;
use std::sync::Arc;
use std::vec;
use arrow_schema::SchemaRef;
use datafusion::assert_batches_eq;
use datafusion::datasource::listing::PartitionedFile;
use datafusion::datasource::object_store::ObjectStoreUrl;
use datafusion::execution::context::TaskContext;
use datafusion::physical_plan::file_format::{FileOpener, FileScanConfig, FileStream, ParquetExec};
use datafusion::physical_plan::metrics::ExecutionPlanMetricsSet;
@@ -35,21 +32,7 @@ use crate::file_format::csv::{CsvConfigBuilder, CsvOpener};
use crate::file_format::json::JsonOpener;
use crate::file_format::parquet::DefaultParquetFileReaderFactory;
use crate::file_format::Format;
use crate::test_util::{self, test_basic_schema, test_store};
fn scan_config(file_schema: SchemaRef, limit: Option<usize>, filename: &str) -> FileScanConfig {
FileScanConfig {
object_store_url: ObjectStoreUrl::parse("empty://").unwrap(), // won't be used
file_schema,
file_groups: vec![vec![PartitionedFile::new(filename.to_string(), 10)]],
statistics: Default::default(),
projection: None,
limit,
table_partition_cols: vec![],
output_ordering: None,
infinite_source: false,
}
}
use crate::test_util::{self, scan_config, test_basic_schema, test_store};
struct Test<'a, T: FileOpener> {
config: FileScanConfig,


@@ -14,11 +14,15 @@
#![feature(assert_matches)]
pub mod buffered_writer;
pub mod compression;
pub mod error;
pub mod file_format;
pub mod lister;
pub mod object_store;
pub mod share_buffer;
#[cfg(test)]
pub mod test_util;
#[cfg(test)]
pub mod tests;
pub mod util;


@@ -0,0 +1,46 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::io::Write;
use std::sync::{Arc, Mutex};
use bytes::{BufMut, BytesMut};
#[derive(Clone, Default)]
pub struct SharedBuffer {
pub buffer: Arc<Mutex<BytesMut>>,
}
impl SharedBuffer {
pub fn with_capacity(size: usize) -> Self {
Self {
buffer: Arc::new(Mutex::new(BytesMut::with_capacity(size))),
}
}
}
impl Write for SharedBuffer {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
let len = buf.len();
let mut buffer = self.buffer.lock().unwrap();
buffer.put_slice(buf);
Ok(len)
}
fn flush(&mut self) -> std::io::Result<()> {
// This flush implementation is intentionally left blank;
// the actual flush happens in `BufferedWriter::try_flush`.
Ok(())
}
}
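`SharedBuffer` is the bridge between the synchronous encoders (which only know `std::io::Write`) and the asynchronous flush loop in `BufferedWriter`, which locks the same `BytesMut` and splits chunks off it. A small sketch, not part of this diff, of that handoff using only the API above:

// Sketch only: models the encoder side and the flush side sharing one buffer.
use std::io::Write;

use common_datasource::share_buffer::SharedBuffer;

fn shared_buffer_handoff() {
    let shared = SharedBuffer::with_capacity(16);
    // Clones share the same Arc<Mutex<BytesMut>>, so this models encoder vs. flusher handles.
    let mut encoder_side = shared.clone();
    encoder_side.write_all(b"hello world").unwrap();

    // Drain everything the encoder produced, as try_flush(true) would.
    let chunk = shared.buffer.lock().unwrap().split();
    assert_eq!(&chunk[..], b"hello world");
}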


@@ -16,9 +16,19 @@ use std::path::PathBuf;
use std::sync::Arc;
use arrow_schema::{DataType, Field, Schema, SchemaRef};
use common_test_util::temp_dir::{create_temp_dir, TempDir};
use datafusion::datasource::listing::PartitionedFile;
use datafusion::datasource::object_store::ObjectStoreUrl;
use datafusion::physical_plan::file_format::{FileScanConfig, FileStream};
use datafusion::physical_plan::metrics::ExecutionPlanMetricsSet;
use object_store::services::Fs;
use object_store::ObjectStore;
use crate::compression::CompressionType;
use crate::file_format::csv::{stream_to_csv, CsvConfigBuilder, CsvOpener};
use crate::file_format::json::{stream_to_json, JsonOpener};
use crate::test_util;
pub const TEST_BATCH_SIZE: usize = 100;
pub fn get_data_dir(path: &str) -> PathBuf {
@@ -50,6 +60,15 @@ pub fn test_store(root: &str) -> ObjectStore {
ObjectStore::new(builder).unwrap().finish()
}
pub fn test_tmp_store(root: &str) -> (ObjectStore, TempDir) {
let dir = create_temp_dir(root);
let mut builder = Fs::default();
builder.root("/");
(ObjectStore::new(builder).unwrap().finish(), dir)
}
pub fn test_basic_schema() -> SchemaRef {
let schema = Schema::new(vec![
Field::new("num", DataType::Int64, false),
@@ -57,3 +76,100 @@ pub fn test_basic_schema() -> SchemaRef {
]);
Arc::new(schema)
}
pub fn scan_config(file_schema: SchemaRef, limit: Option<usize>, filename: &str) -> FileScanConfig {
FileScanConfig {
object_store_url: ObjectStoreUrl::parse("empty://").unwrap(), // won't be used
file_schema,
file_groups: vec![vec![PartitionedFile::new(filename.to_string(), 10)]],
statistics: Default::default(),
projection: None,
limit,
table_partition_cols: vec![],
output_ordering: None,
infinite_source: false,
}
}
pub async fn setup_stream_to_json_test(origin_path: &str, threshold: impl Fn(usize) -> usize) {
let store = test_store("/");
let schema = test_basic_schema();
let json_opener = JsonOpener::new(
test_util::TEST_BATCH_SIZE,
schema.clone(),
store.clone(),
CompressionType::UNCOMPRESSED,
);
let size = store.read(origin_path).await.unwrap().len();
let config = scan_config(schema.clone(), None, origin_path);
let stream = FileStream::new(&config, 0, json_opener, &ExecutionPlanMetricsSet::new()).unwrap();
let (tmp_store, dir) = test_tmp_store("test_stream_to_json");
let output_path = format!("{}/{}", dir.path().display(), "output");
stream_to_json(
Box::pin(stream),
tmp_store.clone(),
output_path.clone(),
threshold(size),
)
.await
.unwrap();
let written = tmp_store.read(&output_path).await.unwrap();
let origin = store.read(origin_path).await.unwrap();
// ignores `\n`
assert_eq!(
String::from_utf8_lossy(&written).trim_end_matches('\n'),
String::from_utf8_lossy(&origin).trim_end_matches('\n'),
)
}
pub async fn setup_stream_to_csv_test(origin_path: &str, threshold: impl Fn(usize) -> usize) {
let store = test_store("/");
let schema = test_basic_schema();
let csv_conf = CsvConfigBuilder::default()
.batch_size(test_util::TEST_BATCH_SIZE)
.file_schema(schema.clone())
.build()
.unwrap();
let csv_opener = CsvOpener::new(csv_conf, store.clone(), CompressionType::UNCOMPRESSED);
let size = store.read(origin_path).await.unwrap().len();
let config = scan_config(schema.clone(), None, origin_path);
let stream = FileStream::new(&config, 0, csv_opener, &ExecutionPlanMetricsSet::new()).unwrap();
let (tmp_store, dir) = test_tmp_store("test_stream_to_csv");
let output_path = format!("{}/{}", dir.path().display(), "output");
stream_to_csv(
Box::pin(stream),
tmp_store.clone(),
output_path.clone(),
threshold(size),
)
.await
.unwrap();
let written = tmp_store.read(&output_path).await.unwrap();
let origin = store.read(origin_path).await.unwrap();
// ignores `\n`
assert_eq!(
String::from_utf8_lossy(&written).trim_end_matches('\n'),
String::from_utf8_lossy(&origin).trim_end_matches('\n'),
)
}


@@ -0,0 +1,61 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::test_util;
#[tokio::test]
async fn test_stream_to_json() {
// A small threshold triggers a flush on every write.
test_util::setup_stream_to_json_test(
&test_util::get_data_dir("tests/json/basic.json")
.display()
.to_string(),
|size| size / 2,
)
.await;
// A large threshold only triggers a flush at the end.
test_util::setup_stream_to_json_test(
&test_util::get_data_dir("tests/json/basic.json")
.display()
.to_string(),
|size| size * 2,
)
.await;
}
#[tokio::test]
async fn test_stream_to_csv() {
// A small threshold triggers a flush on every write.
test_util::setup_stream_to_csv_test(
&test_util::get_data_dir("tests/csv/basic.csv")
.display()
.to_string(),
|size| size / 2,
)
.await;
// A large threshold only triggers a flush at the end.
test_util::setup_stream_to_csv_test(
&test_util::get_data_dir("tests/csv/basic.csv")
.display()
.to_string(),
|size| size * 2,
)
.await;
}


@@ -13,6 +13,7 @@ arrow.workspace = true
arrow-array.workspace = true
bytes = "1.1"
common-base = { path = "../common/base" }
common-datasource = { path = "../common/datasource" }
common-error = { path = "../common/error" }
common-query = { path = "../common/query" }
common-recordbatch = { path = "../common/recordbatch" }


@@ -49,6 +49,12 @@ pub enum Error {
location: Location,
},
#[snafu(display("Failed to write to buffer, source: {}", source))]
WriteBuffer {
#[snafu(backtrace)]
source: common_datasource::error::Error,
},
#[snafu(display("Failed to create RecordBatch from vectors, source: {}", source))]
NewRecordBatch {
location: Location,
@@ -514,6 +520,7 @@ impl ErrorExt for Error {
InvalidAlterRequest { source, .. } | InvalidRegionDesc { source, .. } => {
source.status_code()
}
WriteBuffer { source, .. } => source.status_code(),
PushBatch { source, .. } => source.status_code(),
CreateDefault { source, .. } => source.status_code(),
ConvertChunk { source, .. } => source.status_code(),


@@ -12,65 +12,31 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::io::Write;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
use arrow_array::RecordBatch;
use bytes::{BufMut, BytesMut};
use common_datasource::buffered_writer::{
BufferedWriter as DatasourceBufferedWriter, DefaultBufferedWriter,
};
use common_datasource::share_buffer::SharedBuffer;
use datatypes::schema::SchemaRef;
use object_store::{ObjectStore, Writer};
use object_store::ObjectStore;
use parquet::arrow::ArrowWriter;
use parquet::file::properties::WriterProperties;
use parquet::format::FileMetaData;
use snafu::ResultExt;
use tokio_util::compat::FuturesAsyncWriteCompatExt;
use crate::error;
use crate::error::{NewRecordBatchSnafu, WriteObjectSnafu, WriteParquetSnafu};
use crate::read::Batch;
#[derive(Clone, Default)]
struct Buffer {
// It's lightweight since writer/flusher never tries to contend this mutex.
buffer: Arc<Mutex<BytesMut>>,
}
impl Buffer {
pub fn with_capacity(size: usize) -> Self {
Self {
buffer: Arc::new(Mutex::new(BytesMut::with_capacity(size))),
}
}
}
impl Write for Buffer {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
let len = buf.len();
let mut buffer = self.buffer.lock().unwrap();
buffer.put_slice(buf);
Ok(len)
}
fn flush(&mut self) -> std::io::Result<()> {
// This flush implementation is intentionally left to blank.
// The actual flush is in `BufferedWriter::try_flush`
Ok(())
}
}
/// Parquet writer that buffers row groups in memory and writes buffered data to an underlying
/// storage by chunks to reduce memory consumption.
pub struct BufferedWriter {
path: String,
arrow_writer: ArrowWriter<Buffer>,
object_writer: Writer,
buffer: Buffer,
bytes_written: AtomicU64,
flushed: bool,
threshold: usize,
inner: InnerBufferedWriter,
arrow_schema: arrow::datatypes::SchemaRef,
}
type InnerBufferedWriter = DefaultBufferedWriter<ArrowWriter<SharedBuffer>>;
impl BufferedWriter {
pub async fn try_new(
path: String,
@@ -80,7 +46,7 @@ impl BufferedWriter {
buffer_threshold: usize,
) -> error::Result<Self> {
let arrow_schema = schema.arrow_schema();
let buffer = Buffer::with_capacity(buffer_threshold);
let buffer = SharedBuffer::with_capacity(buffer_threshold);
let writer = store
.writer(&path)
.await
@@ -89,14 +55,15 @@ impl BufferedWriter {
let arrow_writer = ArrowWriter::try_new(buffer.clone(), arrow_schema.clone(), props)
.context(WriteParquetSnafu)?;
let writer = writer.compat_write();
Ok(Self {
path,
arrow_writer,
object_writer: writer,
buffer,
bytes_written: Default::default(),
flushed: false,
threshold: buffer_threshold,
inner: DatasourceBufferedWriter::new(
buffer_threshold,
buffer.clone(),
arrow_writer,
writer,
),
arrow_schema: arrow_schema.clone(),
})
}
@@ -112,19 +79,16 @@ impl BufferedWriter {
.collect::<Vec<_>>(),
)
.context(NewRecordBatchSnafu)?;
self.arrow_writer
self.inner
.write(&arrow_batch)
.context(WriteParquetSnafu)?;
let written = Self::try_flush(
&self.path,
&self.buffer,
&mut self.object_writer,
false,
&mut self.flushed,
self.threshold,
)
.await?;
self.bytes_written.fetch_add(written, Ordering::Relaxed);
.await
.context(error::WriteBufferSnafu)?;
self.inner
.try_flush(false)
.await
.context(error::WriteBufferSnafu)?;
Ok(())
}
@@ -132,67 +96,14 @@ impl BufferedWriter {
pub async fn abort(self) -> bool {
// TODO(hl): Currently we can do nothing if file's parts have been uploaded to remote storage
// on abortion, we need to find a way to abort the upload. see https://help.aliyun.com/document_detail/31996.htm?spm=a2c4g.11186623.0.0.3eb42cb7b2mwUz#reference-txp-bvx-wdb
!self.flushed
!self.inner.flushed()
}
/// Close parquet writer and ensure all buffered data are written into underlying storage.
pub async fn close(mut self) -> error::Result<(FileMetaData, u64)> {
let metadata = self.arrow_writer.close().context(WriteParquetSnafu)?;
let written = Self::try_flush(
&self.path,
&self.buffer,
&mut self.object_writer,
true,
&mut self.flushed,
self.threshold,
)
.await?;
self.bytes_written.fetch_add(written, Ordering::Relaxed);
self.object_writer
.close()
pub async fn close(self) -> error::Result<(FileMetaData, u64)> {
self.inner
.close_with_arrow_writer()
.await
.context(WriteObjectSnafu { path: &self.path })?;
Ok((metadata, self.bytes_written.load(Ordering::Relaxed)))
}
/// Try to flush buffered data to underlying storage if it's size exceeds threshold.
/// Set `all` to true if all buffered data should be flushed regardless of it's size.
async fn try_flush(
file_name: &str,
shared_buffer: &Buffer,
object_writer: &mut Writer,
all: bool,
flushed: &mut bool,
threshold: usize,
) -> error::Result<u64> {
let mut bytes_written = 0;
// Once buffered data size reaches threshold, split the data in chunks (typically 4MB)
// and write to underlying storage.
while shared_buffer.buffer.lock().unwrap().len() >= threshold {
let chunk = {
let mut buffer = shared_buffer.buffer.lock().unwrap();
buffer.split_to(threshold)
};
let size = chunk.len();
object_writer
.write(chunk)
.await
.context(WriteObjectSnafu { path: file_name })?;
*flushed = true;
bytes_written += size;
}
if all {
let remain = shared_buffer.buffer.lock().unwrap().split();
let size = remain.len();
object_writer
.write(remain)
.await
.context(WriteObjectSnafu { path: file_name })?;
*flushed = true;
bytes_written += size;
}
Ok(bytes_written as u64)
.context(error::WriteBufferSnafu)
}
}