diff --git a/Cargo.lock b/Cargo.lock
index 42bbadab06..fba8c8582a 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -8495,7 +8495,10 @@ version = "0.6.0"
 dependencies = [
  "aide",
  "api",
+ "arrow",
  "arrow-flight",
+ "arrow-ipc",
+ "arrow-schema",
  "async-trait",
  "auth",
  "axum",
diff --git a/Cargo.toml b/Cargo.toml
index 617c66bb65..a63e0bedcc 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -70,6 +70,7 @@ aquamarine = "0.3"
 arrow = { version = "47.0" }
 arrow-array = "47.0"
 arrow-flight = "47.0"
+arrow-ipc = "47.0"
 arrow-schema = { version = "47.0", features = ["serde"] }
 async-stream = "0.3"
 async-trait = "0.1"
diff --git a/src/servers/Cargo.toml b/src/servers/Cargo.toml
index b7da8935f1..295b1b2811 100644
--- a/src/servers/Cargo.toml
+++ b/src/servers/Cargo.toml
@@ -14,6 +14,9 @@ testing = []
 aide = { version = "0.9", features = ["axum"] }
 api.workspace = true
 arrow-flight.workspace = true
+arrow-ipc.workspace = true
+arrow-schema.workspace = true
+arrow.workspace = true
 async-trait = "0.1"
 auth.workspace = true
 axum-macros = "0.3.8"
diff --git a/src/servers/src/error.rs b/src/servers/src/error.rs
index a19d39fe3c..905e4fe26f 100644
--- a/src/servers/src/error.rs
+++ b/src/servers/src/error.rs
@@ -35,6 +35,12 @@ use tonic::Code;
 #[snafu(visibility(pub))]
 #[stack_trace_debug]
 pub enum Error {
+    #[snafu(display("Arrow error"))]
+    Arrow {
+        #[snafu(source)]
+        error: arrow_schema::ArrowError,
+    },
+
     #[snafu(display("Internal error: {}", err_msg))]
     Internal { err_msg: String },
 
@@ -455,7 +461,8 @@ impl ErrorExt for Error {
             | TcpIncoming { .. }
             | CatalogError { .. }
             | GrpcReflectionService { .. }
-            | BuildHttpResponse { .. } => StatusCode::Internal,
+            | BuildHttpResponse { .. }
+            | Arrow { .. } => StatusCode::Internal,
 
             UnsupportedDataType { .. } => StatusCode::Unsupported,
 
diff --git a/src/servers/src/http.rs b/src/servers/src/http.rs
index d84b991fa0..d9371cab99 100644
--- a/src/servers/src/http.rs
+++ b/src/servers/src/http.rs
@@ -50,6 +50,7 @@ use tower_http::trace::TraceLayer;
 use self::authorize::AuthState;
 use crate::configurator::ConfiguratorRef;
 use crate::error::{AlreadyStartedSnafu, Error, Result, StartHttpSnafu, ToJsonSnafu};
+use crate::http::arrow_result::ArrowResponse;
 use crate::http::csv_result::CsvResponse;
 use crate::http::error_result::ErrorResponse;
 use crate::http::greptime_result_v1::GreptimedbV1Response;
@@ -82,6 +83,7 @@ pub mod prom_store;
 pub mod prometheus;
 pub mod script;
 
+pub mod arrow_result;
 pub mod csv_result;
 #[cfg(feature = "dashboard")]
 mod dashboard;
@@ -247,6 +249,7 @@ pub enum GreptimeQueryOutput {
 /// It allows the results of SQL queries to be presented in different formats.
 #[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
 pub enum ResponseFormat {
+    Arrow,
     Csv,
     #[default]
     GreptimedbV1,
@@ -256,6 +259,7 @@ pub enum ResponseFormat {
 impl ResponseFormat {
     pub fn parse(s: &str) -> Option<ResponseFormat> {
         match s {
+            "arrow" => Some(ResponseFormat::Arrow),
             "csv" => Some(ResponseFormat::Csv),
             "greptimedb_v1" => Some(ResponseFormat::GreptimedbV1),
             "influxdb_v1" => Some(ResponseFormat::InfluxdbV1),
@@ -265,6 +269,7 @@ impl ResponseFormat {
 
     pub fn as_str(&self) -> &'static str {
         match self {
+            ResponseFormat::Arrow => "arrow",
             ResponseFormat::Csv => "csv",
             ResponseFormat::GreptimedbV1 => "greptimedb_v1",
             ResponseFormat::InfluxdbV1 => "influxdb_v1",
@@ -318,6 +323,7 @@ impl Display for Epoch {
 
 #[derive(Serialize, Deserialize, Debug, JsonSchema)]
 pub enum HttpResponse {
+    Arrow(ArrowResponse),
     Csv(CsvResponse),
     Error(ErrorResponse),
     GreptimedbV1(GreptimedbV1Response),
@@ -327,6 +333,7 @@ pub enum HttpResponse {
 impl HttpResponse {
     pub fn with_execution_time(self, execution_time: u64) -> Self {
         match self {
+            HttpResponse::Arrow(resp) => resp.with_execution_time(execution_time).into(),
             HttpResponse::Csv(resp) => resp.with_execution_time(execution_time).into(),
             HttpResponse::GreptimedbV1(resp) => resp.with_execution_time(execution_time).into(),
             HttpResponse::InfluxdbV1(resp) => resp.with_execution_time(execution_time).into(),
@@ -338,6 +345,7 @@ impl HttpResponse {
 impl IntoResponse for HttpResponse {
     fn into_response(self) -> Response {
         match self {
+            HttpResponse::Arrow(resp) => resp.into_response(),
             HttpResponse::Csv(resp) => resp.into_response(),
             HttpResponse::GreptimedbV1(resp) => resp.into_response(),
             HttpResponse::InfluxdbV1(resp) => resp.into_response(),
@@ -350,6 +358,12 @@ impl OperationOutput for HttpResponse {
     type Inner = Response;
 }
 
+impl From<ArrowResponse> for HttpResponse {
+    fn from(value: ArrowResponse) -> Self {
+        HttpResponse::Arrow(value)
+    }
+}
+
 impl From<CsvResponse> for HttpResponse {
     fn from(value: CsvResponse) -> Self {
         HttpResponse::Csv(value)
@@ -801,9 +815,12 @@ async fn handle_error(err: BoxError) -> Json<HttpResponse> {
 #[cfg(test)]
 mod test {
     use std::future::pending;
+    use std::io::Cursor;
     use std::sync::Arc;
 
     use api::v1::greptime_request::Request;
+    use arrow_ipc::reader::FileReader;
+    use arrow_schema::DataType;
     use axum::handler::Handler;
     use axum::http::StatusCode;
     use axum::routing::get;
@@ -942,11 +959,13 @@ mod test {
             ResponseFormat::GreptimedbV1,
             ResponseFormat::InfluxdbV1,
             ResponseFormat::Csv,
+            ResponseFormat::Arrow,
         ] {
             let recordbatches =
                 RecordBatches::try_new(schema.clone(), vec![recordbatch.clone()]).unwrap();
             let outputs = vec![Ok(Output::RecordBatches(recordbatches))];
             let json_resp = match format {
+                ResponseFormat::Arrow => ArrowResponse::from_output(outputs).await,
                 ResponseFormat::Csv => CsvResponse::from_output(outputs).await,
                 ResponseFormat::GreptimedbV1 => GreptimedbV1Response::from_output(outputs).await,
                 ResponseFormat::InfluxdbV1 => InfluxdbV1Response::from_output(outputs, None).await,
@@ -992,6 +1011,20 @@ mod test {
                         panic!("invalid output type");
                     }
                 }
+                HttpResponse::Arrow(resp) => {
+                    let output = resp.data;
+                    let mut reader =
+                        FileReader::try_new(Cursor::new(output), None).expect("Arrow reader error");
+                    let schema = reader.schema();
+                    assert_eq!(schema.fields[0].name(), "numbers");
+                    assert_eq!(schema.fields[0].data_type(), &DataType::UInt32);
+                    assert_eq!(schema.fields[1].name(), "strings");
+                    assert_eq!(schema.fields[1].data_type(), &DataType::Utf8);
+
+                    let rb = reader.next().unwrap().expect("read record batch failed");
+                    assert_eq!(rb.num_columns(), 2);
+                    assert_eq!(rb.num_rows(), 4);
+                }
                 HttpResponse::Error(err) => unreachable!("{err:?}"),
             }
         }
diff --git a/src/servers/src/http/arrow_result.rs b/src/servers/src/http/arrow_result.rs
new file mode 100644
index 0000000000..78d22b20c5
--- /dev/null
+++ b/src/servers/src/http/arrow_result.rs
@@ -0,0 +1,141 @@
+// Copyright 2023 Greptime Team
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::pin::Pin;
+use std::sync::Arc;
+
+use arrow::datatypes::Schema;
+use arrow_ipc::writer::FileWriter;
+use axum::http::{header, HeaderName, HeaderValue};
+use axum::response::{IntoResponse, Response};
+use common_error::status_code::StatusCode;
+use common_query::Output;
+use common_recordbatch::RecordBatchStream;
+use futures::StreamExt;
+use schemars::JsonSchema;
+use serde::{Deserialize, Serialize};
+use snafu::ResultExt;
+
+use crate::error::{self, Error};
+use crate::http::error_result::ErrorResponse;
+use crate::http::header::{GREPTIME_DB_HEADER_EXECUTION_TIME, GREPTIME_DB_HEADER_FORMAT};
+use crate::http::{HttpResponse, ResponseFormat};
+
+#[derive(Serialize, Deserialize, Debug, JsonSchema)]
+pub struct ArrowResponse {
+    pub(crate) data: Vec<u8>,
+    pub(crate) execution_time_ms: u64,
+}
+
+async fn write_arrow_bytes(
+    mut recordbatches: Pin<Box<dyn RecordBatchStream + Send>>,
+    schema: &Arc<Schema>,
+) -> Result<Vec<u8>, Error> {
+    let mut bytes = Vec::new();
+    {
+        let mut writer = FileWriter::try_new(&mut bytes, schema).context(error::ArrowSnafu)?;
+
+        while let Some(rb) = recordbatches.next().await {
+            let rb = rb.context(error::CollectRecordbatchSnafu)?;
+            writer
+                .write(&rb.into_df_record_batch())
+                .context(error::ArrowSnafu)?;
+        }
+
+        writer.finish().context(error::ArrowSnafu)?;
+    }
+
+    Ok(bytes)
+}
+
+impl ArrowResponse {
+    pub async fn from_output(mut outputs: Vec<error::Result<Output>>) -> HttpResponse {
+        if outputs.len() != 1 {
+            return HttpResponse::Error(ErrorResponse::from_error_message(
+                ResponseFormat::Arrow,
+                StatusCode::InvalidArguments,
+                "Multi-statements and empty query are not allowed".to_string(),
+            ));
+        }
+
+        match outputs.remove(0) {
+            Ok(output) => match output {
+                Output::AffectedRows(_rows) => HttpResponse::Arrow(ArrowResponse {
+                    data: vec![],
+                    execution_time_ms: 0,
+                }),
+                Output::RecordBatches(recordbatches) => {
+                    let schema = recordbatches.schema();
+                    match write_arrow_bytes(recordbatches.as_stream(), schema.arrow_schema()).await
+                    {
+                        Ok(payload) => HttpResponse::Arrow(ArrowResponse {
+                            data: payload,
+                            execution_time_ms: 0,
+                        }),
+                        Err(e) => {
+                            HttpResponse::Error(ErrorResponse::from_error(ResponseFormat::Arrow, e))
+                        }
+                    }
+                }
+
+                Output::Stream(recordbatches) => {
+                    let schema = recordbatches.schema();
+                    match write_arrow_bytes(recordbatches, schema.arrow_schema()).await {
+                        Ok(payload) => HttpResponse::Arrow(ArrowResponse {
+                            data: payload,
+                            execution_time_ms: 0,
+                        }),
+                        Err(e) => {
+                            HttpResponse::Error(ErrorResponse::from_error(ResponseFormat::Arrow, e))
+                        }
+                    }
+                }
+            },
+            Err(e) => HttpResponse::Error(ErrorResponse::from_error(ResponseFormat::Arrow, e)),
+        }
+    }
+
+    pub fn with_execution_time(mut self, execution_time: u64) -> Self {
+        self.execution_time_ms = execution_time;
+        self
+    }
+
+    pub fn execution_time_ms(&self) -> u64 {
+        self.execution_time_ms
+    }
+}
+
+impl IntoResponse for ArrowResponse {
+    fn into_response(self) -> Response {
+        let execution_time = self.execution_time_ms;
+        (
+            [
+                (
+                    header::CONTENT_TYPE,
+                    HeaderValue::from_static("application/arrow"),
+                ),
+                (
+                    HeaderName::from_static(GREPTIME_DB_HEADER_FORMAT),
+                    HeaderValue::from_static("ARROW"),
+                ),
+                (
+                    HeaderName::from_static(GREPTIME_DB_HEADER_EXECUTION_TIME),
+                    HeaderValue::from(execution_time),
+                ),
+            ],
+            self.data,
+        )
+            .into_response()
+    }
+}
diff --git a/src/servers/src/http/handler.rs b/src/servers/src/http/handler.rs
index 3bcea4595d..88b3242755 100644
--- a/src/servers/src/http/handler.rs
+++ b/src/servers/src/http/handler.rs
@@ -29,6 +29,7 @@ use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use session::context::QueryContextRef;
 
+use crate::http::arrow_result::ArrowResponse;
 use crate::http::csv_result::CsvResponse;
 use crate::http::error_result::ErrorResponse;
 use crate::http::greptime_result_v1::GreptimedbV1Response;
@@ -111,6 +112,7 @@ pub async fn sql(
     };
 
     let resp = match format {
+        ResponseFormat::Arrow => ArrowResponse::from_output(outputs).await,
         ResponseFormat::Csv => CsvResponse::from_output(outputs).await,
         ResponseFormat::GreptimedbV1 => GreptimedbV1Response::from_output(outputs).await,
         ResponseFormat::InfluxdbV1 => InfluxdbV1Response::from_output(outputs, epoch).await,
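
A minimal sketch of how a client could decode the new response body, mirroring what the test above does with arrow_ipc::reader::FileReader. It assumes the body bytes were fetched separately (with any HTTP client) from the SQL endpoint using the `format=arrow` query parameter, and that the payload is the Arrow IPC file produced by write_arrow_bytes; the function name `decode_arrow_body` is illustrative only, not part of this patch:

use std::io::Cursor;

use arrow_array::RecordBatch;
use arrow_ipc::reader::FileReader;
use arrow_schema::ArrowError;

/// Decode an `application/arrow` response body into record batches.
/// The server writes the payload with `arrow_ipc::writer::FileWriter`,
/// so it is a complete IPC file rather than a bare stream, and the
/// whole body is buffered here because `FileReader` needs `Read + Seek`.
fn decode_arrow_body(body: Vec<u8>) -> Result<Vec<RecordBatch>, ArrowError> {
    let reader = FileReader::try_new(Cursor::new(body), None)?;
    reader.collect()
}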