diff --git a/Cargo.lock b/Cargo.lock
index fdad3a09be..3176657706 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -443,6 +443,7 @@ dependencies = [
  "arrow-schema",
  "flatbuffers",
  "lz4_flex 0.11.3",
+ "zstd 0.13.1",
 ]

 [[package]]
diff --git a/Cargo.toml b/Cargo.toml
index 046bf82478..c1eea12a53 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -90,7 +90,7 @@ aquamarine = "0.3"
 arrow = { version = "51.0.0", features = ["prettyprint"] }
 arrow-array = { version = "51.0.0", default-features = false, features = ["chrono-tz"] }
 arrow-flight = "51.0"
-arrow-ipc = { version = "51.0.0", default-features = false, features = ["lz4"] }
+arrow-ipc = { version = "51.0.0", default-features = false, features = ["lz4", "zstd"] }
 arrow-schema = { version = "51.0", features = ["serde"] }
 async-stream = "0.3"
 async-trait = "0.1"
diff --git a/src/servers/src/http.rs b/src/servers/src/http.rs
index 2313d19bbe..956a650fcc 100644
--- a/src/servers/src/http.rs
+++ b/src/servers/src/http.rs
@@ -1136,7 +1136,7 @@ mod test {
             RecordBatches::try_new(schema.clone(), vec![recordbatch.clone()]).unwrap();
         let outputs = vec![Ok(Output::new_with_record_batches(recordbatches))];
         let json_resp = match format {
-            ResponseFormat::Arrow => ArrowResponse::from_output(outputs).await,
+            ResponseFormat::Arrow => ArrowResponse::from_output(outputs, None).await,
             ResponseFormat::Csv => CsvResponse::from_output(outputs).await,
             ResponseFormat::Table => TableResponse::from_output(outputs).await,
             ResponseFormat::GreptimedbV1 => GreptimedbV1Response::from_output(outputs).await,
diff --git a/src/servers/src/http/arrow_result.rs b/src/servers/src/http/arrow_result.rs
index e6d2441ee2..6a739fee04 100644
--- a/src/servers/src/http/arrow_result.rs
+++ b/src/servers/src/http/arrow_result.rs
@@ -16,7 +16,8 @@ use std::pin::Pin;
 use std::sync::Arc;

 use arrow::datatypes::Schema;
-use arrow_ipc::writer::FileWriter;
+use arrow_ipc::writer::{FileWriter, IpcWriteOptions};
+use arrow_ipc::CompressionType;
 use axum::http::{header, HeaderValue};
 use axum::response::{IntoResponse, Response};
 use common_error::status_code::StatusCode;
@@ -41,10 +42,15 @@ pub struct ArrowResponse {
 async fn write_arrow_bytes(
     mut recordbatches: Pin<Box<dyn RecordBatchStream + Send>>,
     schema: &Arc<Schema>,
+    compression: Option<CompressionType>,
 ) -> Result<Vec<u8>, Error> {
     let mut bytes = Vec::new();
     {
-        let mut writer = FileWriter::try_new(&mut bytes, schema).context(error::ArrowSnafu)?;
+        let options = IpcWriteOptions::default()
+            .try_with_compression(compression)
+            .context(error::ArrowSnafu)?;
+        let mut writer = FileWriter::try_new_with_options(&mut bytes, schema, options)
+            .context(error::ArrowSnafu)?;

         while let Some(rb) = recordbatches.next().await {
             let rb = rb.context(error::CollectRecordbatchSnafu)?;
@@ -59,8 +65,22 @@ async fn write_arrow_bytes(
     Ok(bytes)
 }

+fn compression_type(compression: Option<String>) -> Option<CompressionType> {
+    match compression
+        .map(|compression| compression.to_lowercase())
+        .as_deref()
+    {
+        Some("zstd") => Some(CompressionType::ZSTD),
+        Some("lz4") => Some(CompressionType::LZ4_FRAME),
+        _ => None,
+    }
+}
+
 impl ArrowResponse {
-    pub async fn from_output(mut outputs: Vec<Result<Output>>) -> HttpResponse {
+    pub async fn from_output(
+        mut outputs: Vec<Result<Output>>,
+        compression: Option<String>,
+    ) -> HttpResponse {
         if outputs.len() > 1 {
             return HttpResponse::Error(ErrorResponse::from_error_message(
                 StatusCode::InvalidArguments,
@@ -68,6 +88,8 @@ impl ArrowResponse {
             ));
         }

+        let compression = compression_type(compression);
+
         match outputs.pop() {
             None => HttpResponse::Arrow(ArrowResponse {
                 data: vec![],
@@ -80,7 +102,9 @@ impl ArrowResponse {
                 execution_time_ms: 0,
             }),
             Some(Ok(output)) => match output.data {
                 OutputData::RecordBatches(batches) => {
                     let schema = batches.schema();
-                    match write_arrow_bytes(batches.as_stream(), schema.arrow_schema()).await {
+                    match write_arrow_bytes(batches.as_stream(), schema.arrow_schema(), compression)
+                        .await
+                    {
                         Ok(payload) => HttpResponse::Arrow(ArrowResponse {
                             data: payload,
                             execution_time_ms: 0,
@@ -90,7 +114,7 @@ impl ArrowResponse {
                 }
                 OutputData::Stream(batches) => {
                     let schema = batches.schema();
-                    match write_arrow_bytes(batches, schema.arrow_schema()).await {
+                    match write_arrow_bytes(batches, schema.arrow_schema(), compression).await {
                         Ok(payload) => HttpResponse::Arrow(ArrowResponse {
                             data: payload,
                             execution_time_ms: 0,
@@ -136,3 +160,64 @@ impl IntoResponse for ArrowResponse {
         .into_response()
     }
 }
+
+#[cfg(test)]
+mod test {
+    use std::io::Cursor;
+
+    use arrow_ipc::reader::FileReader;
+    use arrow_schema::DataType;
+    use common_recordbatch::{RecordBatch, RecordBatches};
+    use datatypes::prelude::*;
+    use datatypes::schema::{ColumnSchema, Schema};
+    use datatypes::vectors::{StringVector, UInt32Vector};
+
+    use super::*;
+
+    #[tokio::test]
+    async fn test_arrow_output() {
+        let column_schemas = vec![
+            ColumnSchema::new("numbers", ConcreteDataType::uint32_datatype(), false),
+            ColumnSchema::new("strings", ConcreteDataType::string_datatype(), true),
+        ];
+        let schema = Arc::new(Schema::new(column_schemas));
+        let columns: Vec<VectorRef> = vec![
+            Arc::new(UInt32Vector::from_slice(vec![1, 2, 3, 4])),
+            Arc::new(StringVector::from(vec![
+                None,
+                Some("hello"),
+                Some("greptime"),
+                None,
+            ])),
+        ];
+
+        for compression in [None, Some("zstd".to_string()), Some("lz4".to_string())].into_iter() {
+            let recordbatch = RecordBatch::new(schema.clone(), columns.clone()).unwrap();
+            let recordbatches =
+                RecordBatches::try_new(schema.clone(), vec![recordbatch.clone()]).unwrap();
+            let outputs = vec![Ok(Output::new_with_record_batches(recordbatches))];
+
+            let http_resp = ArrowResponse::from_output(outputs, compression).await;
+            match http_resp {
+                HttpResponse::Arrow(resp) => {
+                    let output = resp.data;
+                    let mut reader =
+                        FileReader::try_new(Cursor::new(output), None).expect("Arrow reader error");
+                    let schema = reader.schema();
+                    assert_eq!(schema.fields[0].name(), "numbers");
+                    assert_eq!(schema.fields[0].data_type(), &DataType::UInt32);
+                    assert_eq!(schema.fields[1].name(), "strings");
+                    assert_eq!(schema.fields[1].data_type(), &DataType::Utf8);
+
+                    let rb = reader.next().unwrap().expect("read record batch failed");
+                    assert_eq!(rb.num_columns(), 2);
+                    assert_eq!(rb.num_rows(), 4);
+                }
+                HttpResponse::Error(e) => {
+                    panic!("unexpected {:?}", e);
+                }
+                _ => unreachable!(),
+            }
+        }
+    }
+}
diff --git a/src/servers/src/http/handler.rs b/src/servers/src/http/handler.rs
index 4d5ca58461..1befc22240 100644
--- a/src/servers/src/http/handler.rs
+++ b/src/servers/src/http/handler.rs
@@ -51,7 +51,8 @@ use crate::query_handler::sql::ServerSqlQueryHandlerRef;
 pub struct SqlQuery {
     pub db: Option<String>,
     pub sql: Option<String>,
-    // (Optional) result format: [`greptimedb_v1`, `influxdb_v1`, `csv`],
+    // (Optional) result format: [`greptimedb_v1`, `influxdb_v1`, `csv`,
+    // `arrow`],
     // the default value is `greptimedb_v1`
     pub format: Option<String>,
     // Returns epoch timestamps with the specified precision.
@@ -64,6 +65,8 @@ pub struct SqlQuery {
     // param too.
     pub epoch: Option<String>,
     pub limit: Option<usize>,
+    // For arrow output
+    pub compression: Option<String>,
 }

 /// Handler to execute sql
@@ -128,7 +131,9 @@ pub async fn sql(
     };

     let mut resp = match format {
-        ResponseFormat::Arrow => ArrowResponse::from_output(outputs).await,
+        ResponseFormat::Arrow => {
+            ArrowResponse::from_output(outputs, query_params.compression).await
+        }
         ResponseFormat::Csv => CsvResponse::from_output(outputs).await,
        ResponseFormat::Table => TableResponse::from_output(outputs).await,
        ResponseFormat::GreptimedbV1 => GreptimedbV1Response::from_output(outputs).await,
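
Reviewer note: with this patch a client opts in per request, e.g. GET /v1/sql?sql=...&format=arrow&compression=zstd; per compression_type, anything other than "zstd" or "lz4" (case-insensitive) silently falls back to uncompressed output. The standalone sketch below reproduces the same arrow-ipc round trip that write_arrow_bytes performs, so the feature can be exercised outside the server. It assumes the workspace's arrow/arrow-ipc 51.0 crates with the "zstd" feature enabled; the schema, column values, and main wrapper are illustrative only.

use std::io::Cursor;
use std::sync::Arc;

use arrow::array::{ArrayRef, Int32Array};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use arrow_ipc::reader::FileReader;
use arrow_ipc::writer::{FileWriter, IpcWriteOptions};
use arrow_ipc::CompressionType;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Illustrative schema and data; not taken from the patch.
    let schema = Arc::new(Schema::new(vec![Field::new("n", DataType::Int32, false)]));
    let column: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));
    let batch = RecordBatch::try_new(schema.clone(), vec![column])?;

    // Write an IPC file whose record-batch buffers are zstd-compressed,
    // mirroring what write_arrow_bytes does when ?compression=zstd is passed.
    let options =
        IpcWriteOptions::default().try_with_compression(Some(CompressionType::ZSTD))?;
    let mut bytes = Vec::new();
    {
        let mut writer = FileWriter::try_new_with_options(&mut bytes, &schema, options)?;
        writer.write(&batch)?;
        writer.finish()?;
    }

    // Reading needs no special handling: the codec is recorded in the IPC
    // message metadata, so FileReader decompresses transparently.
    let mut reader = FileReader::try_new(Cursor::new(bytes), None)?;
    let decoded = reader.next().unwrap()?;
    assert_eq!(decoded.num_rows(), 3);
    Ok(())
}

Note that only the writer side carries the compression choice; that is why the test added in arrow_result.rs can decode the uncompressed, zstd, and lz4 variants with identical FileReader code.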