feat: add Arrow IPC output format for http rest api (#3177)

* feat: add arrow format output for sql api

* refactor: remove unwraps

* test: add test for arrow format

* chore: update cargo toml format

* fix: resolve lint warnings

* fix: ensure outputs size is one
This commit is contained in:
Ning Sun
2024-01-24 14:10:05 +08:00
committed by GitHub
parent f81e37f508
commit 1711ad4631
7 changed files with 191 additions and 1 deletions

3
Cargo.lock generated
View File

@@ -8495,7 +8495,10 @@ version = "0.6.0"
dependencies = [
"aide",
"api",
"arrow",
"arrow-flight",
"arrow-ipc",
"arrow-schema",
"async-trait",
"auth",
"axum",

View File

@@ -70,6 +70,7 @@ aquamarine = "0.3"
arrow = { version = "47.0" }
arrow-array = "47.0"
arrow-flight = "47.0"
arrow-ipc = "47.0"
arrow-schema = { version = "47.0", features = ["serde"] }
async-stream = "0.3"
async-trait = "0.1"

View File

@@ -14,6 +14,9 @@ testing = []
aide = { version = "0.9", features = ["axum"] }
api.workspace = true
arrow-flight.workspace = true
arrow-ipc.workspace = true
arrow-schema.workspace = true
arrow.workspace = true
async-trait = "0.1"
auth.workspace = true
axum-macros = "0.3.8"

View File

@@ -35,6 +35,12 @@ use tonic::Code;
#[snafu(visibility(pub))]
#[stack_trace_debug]
pub enum Error {
#[snafu(display("Arrow error"))]
Arrow {
#[snafu(source)]
error: arrow_schema::ArrowError,
},
#[snafu(display("Internal error: {}", err_msg))]
Internal { err_msg: String },
@@ -455,7 +461,8 @@ impl ErrorExt for Error {
| TcpIncoming { .. }
| CatalogError { .. }
| GrpcReflectionService { .. }
| BuildHttpResponse { .. } => StatusCode::Internal,
| BuildHttpResponse { .. }
| Arrow { .. } => StatusCode::Internal,
UnsupportedDataType { .. } => StatusCode::Unsupported,

View File

@@ -50,6 +50,7 @@ use tower_http::trace::TraceLayer;
use self::authorize::AuthState;
use crate::configurator::ConfiguratorRef;
use crate::error::{AlreadyStartedSnafu, Error, Result, StartHttpSnafu, ToJsonSnafu};
use crate::http::arrow_result::ArrowResponse;
use crate::http::csv_result::CsvResponse;
use crate::http::error_result::ErrorResponse;
use crate::http::greptime_result_v1::GreptimedbV1Response;
@@ -82,6 +83,7 @@ pub mod prom_store;
pub mod prometheus;
pub mod script;
pub mod arrow_result;
pub mod csv_result;
#[cfg(feature = "dashboard")]
mod dashboard;
@@ -247,6 +249,7 @@ pub enum GreptimeQueryOutput {
/// It allows the results of SQL queries to be presented in different formats.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
pub enum ResponseFormat {
Arrow,
Csv,
#[default]
GreptimedbV1,
@@ -256,6 +259,7 @@ pub enum ResponseFormat {
impl ResponseFormat {
pub fn parse(s: &str) -> Option<Self> {
match s {
"arrow" => Some(ResponseFormat::Arrow),
"csv" => Some(ResponseFormat::Csv),
"greptimedb_v1" => Some(ResponseFormat::GreptimedbV1),
"influxdb_v1" => Some(ResponseFormat::InfluxdbV1),
@@ -265,6 +269,7 @@ impl ResponseFormat {
pub fn as_str(&self) -> &'static str {
match self {
ResponseFormat::Arrow => "arrow",
ResponseFormat::Csv => "csv",
ResponseFormat::GreptimedbV1 => "greptimedb_v1",
ResponseFormat::InfluxdbV1 => "influxdb_v1",
@@ -318,6 +323,7 @@ impl Display for Epoch {
#[derive(Serialize, Deserialize, Debug, JsonSchema)]
pub enum HttpResponse {
Arrow(ArrowResponse),
Csv(CsvResponse),
Error(ErrorResponse),
GreptimedbV1(GreptimedbV1Response),
@@ -327,6 +333,7 @@ pub enum HttpResponse {
impl HttpResponse {
pub fn with_execution_time(self, execution_time: u64) -> Self {
match self {
HttpResponse::Arrow(resp) => resp.with_execution_time(execution_time).into(),
HttpResponse::Csv(resp) => resp.with_execution_time(execution_time).into(),
HttpResponse::GreptimedbV1(resp) => resp.with_execution_time(execution_time).into(),
HttpResponse::InfluxdbV1(resp) => resp.with_execution_time(execution_time).into(),
@@ -338,6 +345,7 @@ impl HttpResponse {
impl IntoResponse for HttpResponse {
fn into_response(self) -> Response {
match self {
HttpResponse::Arrow(resp) => resp.into_response(),
HttpResponse::Csv(resp) => resp.into_response(),
HttpResponse::GreptimedbV1(resp) => resp.into_response(),
HttpResponse::InfluxdbV1(resp) => resp.into_response(),
@@ -350,6 +358,12 @@ impl OperationOutput for HttpResponse {
type Inner = Response;
}
impl From<ArrowResponse> for HttpResponse {
fn from(value: ArrowResponse) -> Self {
HttpResponse::Arrow(value)
}
}
impl From<CsvResponse> for HttpResponse {
fn from(value: CsvResponse) -> Self {
HttpResponse::Csv(value)
@@ -801,9 +815,12 @@ async fn handle_error(err: BoxError) -> Json<HttpResponse> {
#[cfg(test)]
mod test {
use std::future::pending;
use std::io::Cursor;
use std::sync::Arc;
use api::v1::greptime_request::Request;
use arrow_ipc::reader::FileReader;
use arrow_schema::DataType;
use axum::handler::Handler;
use axum::http::StatusCode;
use axum::routing::get;
@@ -942,11 +959,13 @@ mod test {
ResponseFormat::GreptimedbV1,
ResponseFormat::InfluxdbV1,
ResponseFormat::Csv,
ResponseFormat::Arrow,
] {
let recordbatches =
RecordBatches::try_new(schema.clone(), vec![recordbatch.clone()]).unwrap();
let outputs = vec![Ok(Output::RecordBatches(recordbatches))];
let json_resp = match format {
ResponseFormat::Arrow => ArrowResponse::from_output(outputs).await,
ResponseFormat::Csv => CsvResponse::from_output(outputs).await,
ResponseFormat::GreptimedbV1 => GreptimedbV1Response::from_output(outputs).await,
ResponseFormat::InfluxdbV1 => InfluxdbV1Response::from_output(outputs, None).await,
@@ -992,6 +1011,20 @@ mod test {
panic!("invalid output type");
}
}
HttpResponse::Arrow(resp) => {
let output = resp.data;
let mut reader =
FileReader::try_new(Cursor::new(output), None).expect("Arrow reader error");
let schema = reader.schema();
assert_eq!(schema.fields[0].name(), "numbers");
assert_eq!(schema.fields[0].data_type(), &DataType::UInt32);
assert_eq!(schema.fields[1].name(), "strings");
assert_eq!(schema.fields[1].data_type(), &DataType::Utf8);
let rb = reader.next().unwrap().expect("read record batch failed");
assert_eq!(rb.num_columns(), 2);
assert_eq!(rb.num_rows(), 4);
}
HttpResponse::Error(err) => unreachable!("{err:?}"),
}
}

View File

@@ -0,0 +1,141 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::pin::Pin;
use std::sync::Arc;
use arrow::datatypes::Schema;
use arrow_ipc::writer::FileWriter;
use axum::http::{header, HeaderName, HeaderValue};
use axum::response::{IntoResponse, Response};
use common_error::status_code::StatusCode;
use common_query::Output;
use common_recordbatch::RecordBatchStream;
use futures::StreamExt;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use snafu::ResultExt;
use crate::error::{self, Error};
use crate::http::error_result::ErrorResponse;
use crate::http::header::{GREPTIME_DB_HEADER_EXECUTION_TIME, GREPTIME_DB_HEADER_FORMAT};
use crate::http::{HttpResponse, ResponseFormat};
/// HTTP response body for the `arrow` response format.
///
/// Built by `ArrowResponse::from_output` and turned into an HTTP response
/// through its `IntoResponse` implementation.
#[derive(Serialize, Deserialize, Debug, JsonSchema)]
pub struct ArrowResponse {
// Bytes produced by the Arrow IPC `FileWriter`; left empty when the
// statement only reported affected rows.
pub(crate) data: Vec<u8>,
// Execution time in milliseconds. `from_output` initialises it to 0 and
// `with_execution_time` overwrites it.
pub(crate) execution_time_ms: u64,
}
async fn write_arrow_bytes(
mut recordbatches: Pin<Box<dyn RecordBatchStream + Send>>,
schema: &Arc<Schema>,
) -> Result<Vec<u8>, Error> {
let mut bytes = Vec::new();
{
let mut writer = FileWriter::try_new(&mut bytes, schema).context(error::ArrowSnafu)?;
while let Some(rb) = recordbatches.next().await {
let rb = rb.context(error::CollectRecordbatchSnafu)?;
writer
.write(&rb.into_df_record_batch())
.context(error::ArrowSnafu)?;
}
writer.finish().context(error::ArrowSnafu)?;
}
Ok(bytes)
}
impl ArrowResponse {
    /// Converts the outputs of a query into an Arrow-formatted [`HttpResponse`].
    ///
    /// Exactly one output is accepted. Any other count yields an
    /// `InvalidArguments` error response. An affected-rows output produces an
    /// empty payload, while record batches and streams are encoded through
    /// `write_arrow_bytes`. Failures are reported as an [`ErrorResponse`]
    /// tagged with [`ResponseFormat::Arrow`].
    ///
    /// The returned response always carries `execution_time_ms == 0`; callers
    /// set the real value with [`ArrowResponse::with_execution_time`].
    pub async fn from_output(mut outputs: Vec<crate::error::Result<Output>>) -> HttpResponse {
        // Accept the popped element only if it was the sole one in the vector.
        let result = match outputs.pop() {
            Some(result) if outputs.is_empty() => result,
            _ => {
                return HttpResponse::Error(ErrorResponse::from_error_message(
                    ResponseFormat::Arrow,
                    StatusCode::InvalidArguments,
                    "Multi-statements and empty query are not allowed".to_string(),
                ));
            }
        };

        // Reduce every case to "payload bytes or error" so the response is
        // assembled in a single place below.
        let payload = match result {
            Ok(Output::AffectedRows(_rows)) => Ok(Vec::new()),
            Ok(Output::RecordBatches(batches)) => {
                let schema = batches.schema();
                write_arrow_bytes(batches.as_stream(), schema.arrow_schema()).await
            }
            Ok(Output::Stream(stream)) => {
                let schema = stream.schema();
                write_arrow_bytes(stream, schema.arrow_schema()).await
            }
            Err(e) => Err(e),
        };

        match payload {
            Ok(data) => HttpResponse::Arrow(Self {
                data,
                execution_time_ms: 0,
            }),
            Err(e) => HttpResponse::Error(ErrorResponse::from_error(ResponseFormat::Arrow, e)),
        }
    }

    /// Returns the response with its execution time replaced by `execution_time`.
    pub fn with_execution_time(self, execution_time: u64) -> Self {
        Self {
            execution_time_ms: execution_time,
            ..self
        }
    }

    /// Execution time, in milliseconds, currently recorded on this response.
    pub fn execution_time_ms(&self) -> u64 {
        self.execution_time_ms
    }
}
impl IntoResponse for ArrowResponse {
    /// Builds the HTTP response: the payload bytes as the body, plus the
    /// content-type, format and execution-time headers.
    fn into_response(self) -> Response {
        let Self {
            data,
            execution_time_ms,
        } = self;

        // NOTE(review): the Arrow project registers
        // `application/vnd.apache.arrow.file` / `.stream` as media types;
        // confirm whether `application/arrow` is intentional before changing it.
        let content_type = (
            header::CONTENT_TYPE,
            HeaderValue::from_static("application/arrow"),
        );
        let format = (
            HeaderName::from_static(GREPTIME_DB_HEADER_FORMAT),
            HeaderValue::from_static("ARROW"),
        );
        let elapsed = (
            HeaderName::from_static(GREPTIME_DB_HEADER_EXECUTION_TIME),
            HeaderValue::from(execution_time_ms),
        );

        ([content_type, format, elapsed], data).into_response()
    }
}

View File

@@ -29,6 +29,7 @@ use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use session::context::QueryContextRef;
use crate::http::arrow_result::ArrowResponse;
use crate::http::csv_result::CsvResponse;
use crate::http::error_result::ErrorResponse;
use crate::http::greptime_result_v1::GreptimedbV1Response;
@@ -111,6 +112,7 @@ pub async fn sql(
};
let resp = match format {
ResponseFormat::Arrow => ArrowResponse::from_output(outputs).await,
ResponseFormat::Csv => CsvResponse::from_output(outputs).await,
ResponseFormat::GreptimedbV1 => GreptimedbV1Response::from_output(outputs).await,
ResponseFormat::InfluxdbV1 => InfluxdbV1Response::from_output(outputs, epoch).await,