feat: flight aboard (#840)

feat: replace old GRPC interface with Arrow Flight
This commit is contained in:
LFC
2023-01-09 17:06:24 +08:00
committed by GitHub
parent 9e58311ecd
commit 72f05a3137
56 changed files with 1268 additions and 2210 deletions

View File

@@ -12,7 +12,6 @@ python = ["dep:script"]
async-stream.workspace = true
async-trait.workspace = true
api = { path = "../api" }
arrow-flight.workspace = true
axum = "0.6"
axum-macros = "0.3"
backon = "0.2"
@@ -38,7 +37,7 @@ metrics = "0.20"
mito = { path = "../mito", features = ["test"] }
object-store = { path = "../object-store" }
pin-project = "1.0"
prost = "0.11"
prost.workspace = true
query = { path = "../query" }
script = { path = "../script", features = ["python"], optional = true }
serde = "1.0"

View File

@@ -281,32 +281,8 @@ pub enum Error {
#[snafu(display("Missing node id option in distributed mode"))]
MissingMetasrvOpts { backtrace: Backtrace },
#[snafu(display("Invalid Flight ticket, source: {}", source))]
InvalidFlightTicket {
source: api::DecodeError,
backtrace: Backtrace,
},
#[snafu(display("Missing required field: {}", name))]
MissingRequiredField { name: String, backtrace: Backtrace },
#[snafu(display("Failed to poll recordbatch stream, source: {}", source))]
PollRecordbatchStream {
#[snafu(backtrace)]
source: common_recordbatch::error::Error,
},
#[snafu(display("Invalid FlightData, source: {}", source))]
InvalidFlightData {
#[snafu(backtrace)]
source: common_grpc::Error,
},
#[snafu(display("Failed to do Flight get, source: {}", source))]
FlightGet {
source: tonic::Status,
backtrace: Backtrace,
},
}
pub type Result<T> = std::result::Result<T, Error>;
@@ -336,8 +312,6 @@ impl ErrorExt for Error {
| Error::CreateExprToRequest { source }
| Error::InsertData { source } => source.status_code(),
Error::InvalidFlightData { source } => source.status_code(),
Error::CreateSchema { source, .. }
| Error::ConvertSchema { source, .. }
| Error::VectorComputation { source } => source.status_code(),
@@ -362,8 +336,6 @@ impl ErrorExt for Error {
| Error::RegisterSchema { .. }
| Error::Catalog { .. }
| Error::MissingRequiredField { .. }
| Error::FlightGet { .. }
| Error::InvalidFlightTicket { .. }
| Error::IncorrectInternalState { .. } => StatusCode::Internal,
Error::InitBackend { .. } => StatusCode::StorageUnavailable,
@@ -376,7 +348,6 @@ impl ErrorExt for Error {
Error::BumpTableId { source, .. } => source.status_code(),
Error::MissingNodeId { .. } => StatusCode::InvalidArguments,
Error::MissingMetasrvOpts { .. } => StatusCode::InvalidArguments,
Error::PollRecordbatchStream { source } => source.status_code(),
}
}

View File

@@ -47,7 +47,6 @@ use crate::heartbeat::HeartbeatTask;
use crate::script::ScriptExecutor;
use crate::sql::SqlHandler;
pub mod flight;
mod grpc;
mod script;
mod sql;

View File

@@ -1,470 +0,0 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
mod stream;
use std::pin::Pin;
use api::v1::ddl_request::Expr as DdlExpr;
use api::v1::object_expr::Request as GrpcRequest;
use api::v1::query_request::Query;
use api::v1::{DdlRequest, InsertRequest, ObjectExpr};
use arrow_flight::flight_service_server::FlightService;
use arrow_flight::{
Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo,
HandshakeRequest, HandshakeResponse, PutResult, SchemaResult, Ticket,
};
use async_trait::async_trait;
use common_catalog::consts::DEFAULT_CATALOG_NAME;
use common_grpc::flight::{FlightEncoder, FlightMessage};
use common_query::Output;
use futures::Stream;
use prost::Message;
use query::parser::QueryLanguageParser;
use session::context::QueryContext;
use snafu::{OptionExt, ResultExt};
use tonic::{Request, Response, Streaming};
use crate::error::{
CatalogSnafu, ExecuteSqlSnafu, InsertDataSnafu, InsertSnafu, InvalidFlightTicketSnafu,
MissingRequiredFieldSnafu, Result, TableNotFoundSnafu,
};
use crate::instance::flight::stream::FlightRecordBatchStream;
use crate::instance::Instance;
type TonicResult<T> = std::result::Result<T, tonic::Status>;
type TonicStream<T> = Pin<Box<dyn Stream<Item = TonicResult<T>> + Send + Sync + 'static>>;
#[async_trait]
impl FlightService for Instance {
type HandshakeStream = TonicStream<HandshakeResponse>;
async fn handshake(
&self,
_request: Request<Streaming<HandshakeRequest>>,
) -> TonicResult<Response<Self::HandshakeStream>> {
Err(tonic::Status::unimplemented("Not yet implemented"))
}
type ListFlightsStream = TonicStream<FlightInfo>;
async fn list_flights(
&self,
_request: Request<Criteria>,
) -> TonicResult<Response<Self::ListFlightsStream>> {
Err(tonic::Status::unimplemented("Not yet implemented"))
}
async fn get_flight_info(
&self,
_request: Request<FlightDescriptor>,
) -> TonicResult<Response<FlightInfo>> {
Err(tonic::Status::unimplemented("Not yet implemented"))
}
async fn get_schema(
&self,
_request: Request<FlightDescriptor>,
) -> TonicResult<Response<SchemaResult>> {
Err(tonic::Status::unimplemented("Not yet implemented"))
}
type DoGetStream = TonicStream<FlightData>;
async fn do_get(&self, request: Request<Ticket>) -> TonicResult<Response<Self::DoGetStream>> {
let ticket = request.into_inner().ticket;
let request = ObjectExpr::decode(ticket.as_slice())
.context(InvalidFlightTicketSnafu)?
.request
.context(MissingRequiredFieldSnafu { name: "request" })?;
let output = match request {
GrpcRequest::Insert(request) => self.handle_insert(request).await?,
GrpcRequest::Query(query_request) => {
let query = query_request
.query
.context(MissingRequiredFieldSnafu { name: "query" })?;
self.handle_query(query).await?
}
GrpcRequest::Ddl(request) => self.handle_ddl(request).await?,
};
let stream = to_flight_data_stream(output);
Ok(Response::new(stream))
}
type DoPutStream = TonicStream<PutResult>;
async fn do_put(
&self,
_request: Request<Streaming<FlightData>>,
) -> TonicResult<Response<Self::DoPutStream>> {
Err(tonic::Status::unimplemented("Not yet implemented"))
}
type DoExchangeStream = TonicStream<FlightData>;
async fn do_exchange(
&self,
_request: Request<Streaming<FlightData>>,
) -> TonicResult<Response<Self::DoExchangeStream>> {
Err(tonic::Status::unimplemented("Not yet implemented"))
}
type DoActionStream = TonicStream<arrow_flight::Result>;
async fn do_action(
&self,
_request: Request<Action>,
) -> TonicResult<Response<Self::DoActionStream>> {
Err(tonic::Status::unimplemented("Not yet implemented"))
}
type ListActionsStream = TonicStream<ActionType>;
async fn list_actions(
&self,
_request: Request<Empty>,
) -> TonicResult<Response<Self::ListActionsStream>> {
Err(tonic::Status::unimplemented("Not yet implemented"))
}
}
impl Instance {
async fn handle_query(&self, query: Query) -> Result<Output> {
Ok(match query {
Query::Sql(sql) => {
let stmt = QueryLanguageParser::parse_sql(&sql).context(ExecuteSqlSnafu)?;
self.execute_stmt(stmt, QueryContext::arc()).await?
}
Query::LogicalPlan(plan) => self.execute_logical(plan).await?,
})
}
pub async fn handle_insert(&self, request: InsertRequest) -> Result<Output> {
let table_name = &request.table_name.clone();
// TODO(LFC): InsertRequest should carry catalog name, too.
let table = self
.catalog_manager
.table(DEFAULT_CATALOG_NAME, &request.schema_name, table_name)
.context(CatalogSnafu)?
.context(TableNotFoundSnafu { table_name })?;
let request =
common_grpc_expr::insert::to_table_insert_request(request).context(InsertDataSnafu)?;
let affected_rows = table
.insert(request)
.await
.context(InsertSnafu { table_name })?;
Ok(Output::AffectedRows(affected_rows))
}
async fn handle_ddl(&self, request: DdlRequest) -> Result<Output> {
let expr = request
.expr
.context(MissingRequiredFieldSnafu { name: "expr" })?;
match expr {
DdlExpr::CreateTable(expr) => self.handle_create(expr).await,
DdlExpr::Alter(expr) => self.handle_alter(expr).await,
DdlExpr::CreateDatabase(expr) => self.handle_create_database(expr).await,
DdlExpr::DropTable(expr) => self.handle_drop_table(expr).await,
}
}
}
pub fn to_flight_data_stream(output: Output) -> TonicStream<FlightData> {
match output {
Output::Stream(stream) => {
let stream = FlightRecordBatchStream::new(stream);
Box::pin(stream) as _
}
Output::RecordBatches(x) => {
let stream = FlightRecordBatchStream::new(x.as_stream());
Box::pin(stream) as _
}
Output::AffectedRows(rows) => {
let stream = tokio_stream::once(Ok(
FlightEncoder::default().encode(FlightMessage::AffectedRows(rows))
));
Box::pin(stream) as _
}
}
}
#[cfg(test)]
mod test {
use api::v1::column::{SemanticType, Values};
use api::v1::{
alter_expr, AddColumn, AddColumns, AlterExpr, Column, ColumnDataType, ColumnDef,
CreateDatabaseExpr, CreateTableExpr, QueryRequest,
};
use client::RpcOutput;
use common_grpc::flight;
use common_recordbatch::RecordBatches;
use datatypes::prelude::*;
use super::*;
use crate::tests::test_util::{self, MockInstance};
async fn boarding(instance: &MockInstance, ticket: Request<Ticket>) -> RpcOutput {
let response = instance.inner().do_get(ticket).await.unwrap();
let result = flight::flight_data_to_object_result(response)
.await
.unwrap();
result.try_into().unwrap()
}
#[tokio::test(flavor = "multi_thread")]
async fn test_handle_ddl() {
let instance = MockInstance::new("test_handle_ddl").await;
let ticket = Request::new(Ticket {
ticket: ObjectExpr {
request: Some(GrpcRequest::Ddl(DdlRequest {
expr: Some(DdlExpr::CreateDatabase(CreateDatabaseExpr {
database_name: "my_database".to_string(),
})),
})),
}
.encode_to_vec(),
});
let output = boarding(&instance, ticket).await;
assert!(matches!(output, RpcOutput::AffectedRows(1)));
let ticket = Request::new(Ticket {
ticket: ObjectExpr {
request: Some(GrpcRequest::Ddl(DdlRequest {
expr: Some(DdlExpr::CreateTable(CreateTableExpr {
catalog_name: "greptime".to_string(),
schema_name: "my_database".to_string(),
table_name: "my_table".to_string(),
desc: "blabla".to_string(),
column_defs: vec![
ColumnDef {
name: "a".to_string(),
datatype: ColumnDataType::String as i32,
is_nullable: true,
default_constraint: vec![],
},
ColumnDef {
name: "ts".to_string(),
datatype: ColumnDataType::TimestampMillisecond as i32,
is_nullable: false,
default_constraint: vec![],
},
],
time_index: "ts".to_string(),
..Default::default()
})),
})),
}
.encode_to_vec(),
});
let output = boarding(&instance, ticket).await;
assert!(matches!(output, RpcOutput::AffectedRows(0)));
let ticket = Request::new(Ticket {
ticket: ObjectExpr {
request: Some(GrpcRequest::Ddl(DdlRequest {
expr: Some(DdlExpr::Alter(AlterExpr {
catalog_name: "greptime".to_string(),
schema_name: "my_database".to_string(),
table_name: "my_table".to_string(),
kind: Some(alter_expr::Kind::AddColumns(AddColumns {
add_columns: vec![AddColumn {
column_def: Some(ColumnDef {
name: "b".to_string(),
datatype: ColumnDataType::Int32 as i32,
is_nullable: true,
default_constraint: vec![],
}),
is_key: true,
}],
})),
})),
})),
}
.encode_to_vec(),
});
let output = boarding(&instance, ticket).await;
assert!(matches!(output, RpcOutput::AffectedRows(0)));
let output = instance
.inner()
.execute_sql(
"INSERT INTO my_database.my_table (a, b, ts) VALUES ('s', 1, 1672384140000)",
QueryContext::arc(),
)
.await
.unwrap();
assert!(matches!(output, Output::AffectedRows(1)));
let output = instance
.inner()
.execute_sql(
"SELECT ts, a, b FROM my_database.my_table",
QueryContext::arc(),
)
.await
.unwrap();
let Output::Stream(stream) = output else { unreachable!() };
let recordbatches = RecordBatches::try_collect(stream).await.unwrap();
let expected = "\
+---------------------+---+---+
| ts | a | b |
+---------------------+---+---+
| 2022-12-30T07:09:00 | s | 1 |
+---------------------+---+---+";
assert_eq!(recordbatches.pretty_print().unwrap(), expected);
}
#[tokio::test(flavor = "multi_thread")]
async fn test_handle_insert() {
let instance = MockInstance::new("test_handle_insert").await;
test_util::create_test_table(
&instance,
ConcreteDataType::timestamp_millisecond_datatype(),
)
.await
.unwrap();
let insert = InsertRequest {
schema_name: "public".to_string(),
table_name: "demo".to_string(),
columns: vec![
Column {
column_name: "host".to_string(),
values: Some(Values {
string_values: vec![
"host1".to_string(),
"host2".to_string(),
"host3".to_string(),
],
..Default::default()
}),
semantic_type: SemanticType::Tag as i32,
datatype: ColumnDataType::String as i32,
..Default::default()
},
Column {
column_name: "cpu".to_string(),
values: Some(Values {
f64_values: vec![1.0, 3.0],
..Default::default()
}),
null_mask: vec![2],
semantic_type: SemanticType::Field as i32,
datatype: ColumnDataType::Float64 as i32,
},
Column {
column_name: "ts".to_string(),
values: Some(Values {
ts_millisecond_values: vec![1672384140000, 1672384141000, 1672384142000],
..Default::default()
}),
semantic_type: SemanticType::Timestamp as i32,
datatype: ColumnDataType::TimestampMillisecond as i32,
..Default::default()
},
],
row_count: 3,
..Default::default()
};
let ticket = Request::new(Ticket {
ticket: ObjectExpr {
request: Some(GrpcRequest::Insert(insert)),
}
.encode_to_vec(),
});
let output = boarding(&instance, ticket).await;
assert!(matches!(output, RpcOutput::AffectedRows(3)));
let output = instance
.inner()
.execute_sql("SELECT ts, host, cpu FROM demo", QueryContext::arc())
.await
.unwrap();
let Output::Stream(stream) = output else { unreachable!() };
let recordbatches = RecordBatches::try_collect(stream).await.unwrap();
let expected = "\
+---------------------+-------+-----+
| ts | host | cpu |
+---------------------+-------+-----+
| 2022-12-30T07:09:00 | host1 | 1 |
| 2022-12-30T07:09:01 | host2 | |
| 2022-12-30T07:09:02 | host3 | 3 |
+---------------------+-------+-----+";
assert_eq!(recordbatches.pretty_print().unwrap(), expected);
}
#[tokio::test(flavor = "multi_thread")]
async fn test_handle_query() {
let instance = MockInstance::new("test_handle_query").await;
test_util::create_test_table(
&instance,
ConcreteDataType::timestamp_millisecond_datatype(),
)
.await
.unwrap();
let ticket = Request::new(Ticket {
ticket: ObjectExpr {
request: Some(GrpcRequest::Query(QueryRequest {
query: Some(Query::Sql(
"INSERT INTO demo(host, cpu, memory, ts) VALUES \
('host1', 66.6, 1024, 1672201025000),\
('host2', 88.8, 333.3, 1672201026000)"
.to_string(),
)),
})),
}
.encode_to_vec(),
});
let output = boarding(&instance, ticket).await;
assert!(matches!(output, RpcOutput::AffectedRows(2)));
let ticket = Request::new(Ticket {
ticket: ObjectExpr {
request: Some(GrpcRequest::Query(QueryRequest {
query: Some(Query::Sql(
"SELECT ts, host, cpu, memory FROM demo".to_string(),
)),
})),
}
.encode_to_vec(),
});
let response = instance.inner().do_get(ticket).await.unwrap();
let result = flight::flight_data_to_object_result(response)
.await
.unwrap();
let raw_data = result.flight_data;
let messages = flight::raw_flight_data_to_message(raw_data).unwrap();
assert_eq!(messages.len(), 2);
let recordbatch = flight::flight_messages_to_recordbatches(messages).unwrap();
let expected = "\
+---------------------+-------+------+--------+
| ts | host | cpu | memory |
+---------------------+-------+------+--------+
| 2022-12-28T04:17:05 | host1 | 66.6 | 1024 |
| 2022-12-28T04:17:06 | host2 | 88.8 | 333.3 |
+---------------------+-------+------+--------+";
let actual = recordbatch.pretty_print().unwrap();
assert_eq!(actual, expected);
}
}

View File

@@ -1,178 +0,0 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::pin::Pin;
use std::task::{Context, Poll};
use arrow_flight::FlightData;
use common_grpc::flight::{FlightEncoder, FlightMessage};
use common_recordbatch::SendableRecordBatchStream;
use common_telemetry::warn;
use futures::channel::mpsc;
use futures::channel::mpsc::Sender;
use futures::{SinkExt, Stream, StreamExt};
use pin_project::{pin_project, pinned_drop};
use snafu::ResultExt;
use tokio::task::JoinHandle;
use crate::error;
use crate::instance::flight::TonicResult;
#[pin_project(PinnedDrop)]
pub(super) struct FlightRecordBatchStream {
#[pin]
rx: mpsc::Receiver<Result<FlightMessage, tonic::Status>>,
join_handle: JoinHandle<()>,
done: bool,
encoder: FlightEncoder,
}
impl FlightRecordBatchStream {
pub(super) fn new(recordbatches: SendableRecordBatchStream) -> Self {
let (tx, rx) = mpsc::channel::<TonicResult<FlightMessage>>(1);
let join_handle =
common_runtime::spawn_read(
async move { Self::flight_data_stream(recordbatches, tx).await },
);
Self {
rx,
join_handle,
done: false,
encoder: FlightEncoder::default(),
}
}
async fn flight_data_stream(
mut recordbatches: SendableRecordBatchStream,
mut tx: Sender<TonicResult<FlightMessage>>,
) {
let schema = recordbatches.schema();
if let Err(e) = tx.send(Ok(FlightMessage::Schema(schema))).await {
warn!("stop sending Flight data, err: {e}");
return;
}
while let Some(batch_or_err) = recordbatches.next().await {
match batch_or_err {
Ok(recordbatch) => {
if let Err(e) = tx.send(Ok(FlightMessage::Recordbatch(recordbatch))).await {
warn!("stop sending Flight data, err: {e}");
return;
}
}
Err(e) => {
let e = Err(e).context(error::PollRecordbatchStreamSnafu);
if let Err(e) = tx.send(e.map_err(|x| x.into())).await {
warn!("stop sending Flight data, err: {e}");
}
return;
}
}
}
}
}
#[pinned_drop]
impl PinnedDrop for FlightRecordBatchStream {
fn drop(self: Pin<&mut Self>) {
self.join_handle.abort();
}
}
impl Stream for FlightRecordBatchStream {
type Item = TonicResult<FlightData>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.project();
if *this.done {
Poll::Ready(None)
} else {
match this.rx.poll_next(cx) {
Poll::Ready(None) => {
*this.done = true;
Poll::Ready(None)
}
Poll::Ready(Some(result)) => match result {
Ok(flight_message) => {
let flight_data = this.encoder.encode(flight_message);
Poll::Ready(Some(Ok(flight_data)))
}
Err(e) => {
*this.done = true;
Poll::Ready(Some(Err(e)))
}
},
Poll::Pending => Poll::Pending,
}
}
}
}
#[cfg(test)]
mod test {
use std::sync::Arc;
use common_grpc::flight::{FlightDecoder, FlightMessage};
use common_recordbatch::{RecordBatch, RecordBatches};
use datatypes::prelude::*;
use datatypes::schema::{ColumnSchema, Schema};
use datatypes::vectors::Int32Vector;
use futures::StreamExt;
use super::*;
#[tokio::test]
async fn test_flight_record_batch_stream() {
let schema = Arc::new(Schema::new(vec![ColumnSchema::new(
"a",
ConcreteDataType::int32_datatype(),
false,
)]));
let v: VectorRef = Arc::new(Int32Vector::from_slice(&[1, 2]));
let recordbatch = RecordBatch::new(schema.clone(), vec![v]).unwrap();
let recordbatches = RecordBatches::try_new(schema.clone(), vec![recordbatch.clone()])
.unwrap()
.as_stream();
let mut stream = FlightRecordBatchStream::new(recordbatches);
let mut raw_data = Vec::with_capacity(2);
raw_data.push(stream.next().await.unwrap().unwrap());
raw_data.push(stream.next().await.unwrap().unwrap());
assert!(stream.next().await.is_none());
assert!(stream.done);
let decoder = &mut FlightDecoder::default();
let mut flight_messages = raw_data
.into_iter()
.map(|x| decoder.try_decode(x).unwrap())
.collect::<Vec<FlightMessage>>();
assert_eq!(flight_messages.len(), 2);
match flight_messages.remove(0) {
FlightMessage::Schema(actual_schema) => {
assert_eq!(actual_schema, schema);
}
_ => unreachable!(),
}
match flight_messages.remove(0) {
FlightMessage::Recordbatch(actual_recordbatch) => {
assert_eq!(actual_recordbatch, recordbatch);
}
_ => unreachable!(),
}
}
}

View File

@@ -12,60 +12,26 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::error::Error;
use api::v1::{CreateDatabaseExpr, ObjectExpr, ObjectResult, ResultHeader};
use arrow_flight::flight_service_server::FlightService;
use arrow_flight::Ticket;
use api::v1::ddl_request::Expr as DdlExpr;
use api::v1::greptime_request::Request as GrpcRequest;
use api::v1::query_request::Query;
use api::v1::{CreateDatabaseExpr, DdlRequest, GreptimeRequest, InsertRequest};
use async_trait::async_trait;
use common_error::prelude::{BoxedError, ErrorExt, StatusCode};
use common_grpc::flight;
use common_catalog::consts::DEFAULT_CATALOG_NAME;
use common_error::prelude::BoxedError;
use common_query::Output;
use prost::Message;
use query::parser::QueryLanguageParser;
use query::plan::LogicalPlan;
use servers::query_handler::GrpcQueryHandler;
use session::context::QueryContext;
use snafu::prelude::*;
use substrait::{DFLogicalSubstraitConvertor, SubstraitPlan};
use table::requests::CreateDatabaseRequest;
use tonic::Request;
use crate::error::{
DecodeLogicalPlanSnafu, Error as DatanodeError, ExecuteSqlSnafu, InvalidFlightDataSnafu, Result,
};
use crate::error::{self, DecodeLogicalPlanSnafu, ExecuteSqlSnafu, Result};
use crate::instance::Instance;
impl Instance {
async fn boarding(&self, ticket: Request<Ticket>) -> Result<ObjectResult> {
let response = self.do_get(ticket).await;
let response = match response {
Ok(response) => response,
Err(e) => {
let status_code = e
.source()
.and_then(|s| s.downcast_ref::<DatanodeError>())
.map(|s| s.status_code())
.unwrap_or(StatusCode::Internal);
let err_msg = e.source().map(|s| s.to_string()).unwrap_or(e.to_string());
// TODO(LFC): Further formalize error message when Arrow Flight adoption is done,
// and don't forget to change "test runner"'s error msg accordingly.
return Ok(ObjectResult {
header: Some(ResultHeader {
version: 1,
code: status_code as _,
err_msg,
}),
flight_data: vec![],
});
}
};
flight::flight_data_to_object_result(response)
.await
.context(InvalidFlightDataSnafu)
}
pub(crate) async fn handle_create_database(&self, expr: CreateDatabaseExpr) -> Result<Output> {
let req = CreateDatabaseRequest {
db_name: expr.database_name,
@@ -83,20 +49,298 @@ impl Instance {
.await
.context(ExecuteSqlSnafu)
}
async fn handle_query(&self, query: Query) -> Result<Output> {
Ok(match query {
Query::Sql(sql) => {
let stmt = QueryLanguageParser::parse_sql(&sql).context(ExecuteSqlSnafu)?;
self.execute_stmt(stmt, QueryContext::arc()).await?
}
Query::LogicalPlan(plan) => self.execute_logical(plan).await?,
})
}
pub async fn handle_insert(&self, request: InsertRequest) -> Result<Output> {
let table_name = &request.table_name.clone();
// TODO(LFC): InsertRequest should carry catalog name, too.
let table = self
.catalog_manager
.table(DEFAULT_CATALOG_NAME, &request.schema_name, table_name)
.context(error::CatalogSnafu)?
.context(error::TableNotFoundSnafu { table_name })?;
let request = common_grpc_expr::insert::to_table_insert_request(request)
.context(error::InsertDataSnafu)?;
let affected_rows = table
.insert(request)
.await
.context(error::InsertSnafu { table_name })?;
Ok(Output::AffectedRows(affected_rows))
}
async fn handle_ddl(&self, request: DdlRequest) -> Result<Output> {
let expr = request.expr.context(error::MissingRequiredFieldSnafu {
name: "DdlRequest.expr",
})?;
match expr {
DdlExpr::CreateTable(expr) => self.handle_create(expr).await,
DdlExpr::Alter(expr) => self.handle_alter(expr).await,
DdlExpr::CreateDatabase(expr) => self.handle_create_database(expr).await,
DdlExpr::DropTable(expr) => self.handle_drop_table(expr).await,
}
}
async fn handle_grpc_query(&self, query: GreptimeRequest) -> Result<Output> {
let request = query.request.context(error::MissingRequiredFieldSnafu {
name: "GreptimeRequest.request",
})?;
let output = match request {
GrpcRequest::Insert(request) => self.handle_insert(request).await?,
GrpcRequest::Query(query_request) => {
let query = query_request
.query
.context(error::MissingRequiredFieldSnafu {
name: "QueryRequest.query",
})?;
self.handle_query(query).await?
}
GrpcRequest::Ddl(request) => self.handle_ddl(request).await?,
};
Ok(output)
}
}
#[async_trait]
impl GrpcQueryHandler for Instance {
async fn do_query(&self, query: ObjectExpr) -> servers::error::Result<ObjectResult> {
let ticket = Request::new(Ticket {
ticket: query.encode_to_vec(),
});
// TODO(LFC): Temporarily use old GRPC interface here, will get rid of them near the end of Arrow Flight adoption.
self.boarding(ticket)
async fn do_query(&self, query: GreptimeRequest) -> servers::error::Result<Output> {
self.handle_grpc_query(query)
.await
.map_err(BoxedError::new)
.with_context(|_| servers::error::ExecuteQuerySnafu {
query: format!("{query:?}"),
})
.context(servers::error::ExecuteGrpcQuerySnafu)
}
}
#[cfg(test)]
mod test {
use api::v1::column::{SemanticType, Values};
use api::v1::{
alter_expr, AddColumn, AddColumns, AlterExpr, Column, ColumnDataType, ColumnDef,
CreateDatabaseExpr, CreateTableExpr, QueryRequest,
};
use common_recordbatch::RecordBatches;
use datatypes::prelude::*;
use super::*;
use crate::tests::test_util::{self, MockInstance};
#[tokio::test(flavor = "multi_thread")]
async fn test_handle_ddl() {
let instance = MockInstance::new("test_handle_ddl").await;
let instance = instance.inner();
let query = GreptimeRequest {
request: Some(GrpcRequest::Ddl(DdlRequest {
expr: Some(DdlExpr::CreateDatabase(CreateDatabaseExpr {
database_name: "my_database".to_string(),
})),
})),
};
let output = instance.do_query(query).await.unwrap();
assert!(matches!(output, Output::AffectedRows(1)));
let query = GreptimeRequest {
request: Some(GrpcRequest::Ddl(DdlRequest {
expr: Some(DdlExpr::CreateTable(CreateTableExpr {
catalog_name: "greptime".to_string(),
schema_name: "my_database".to_string(),
table_name: "my_table".to_string(),
desc: "blabla".to_string(),
column_defs: vec![
ColumnDef {
name: "a".to_string(),
datatype: ColumnDataType::String as i32,
is_nullable: true,
default_constraint: vec![],
},
ColumnDef {
name: "ts".to_string(),
datatype: ColumnDataType::TimestampMillisecond as i32,
is_nullable: false,
default_constraint: vec![],
},
],
time_index: "ts".to_string(),
..Default::default()
})),
})),
};
let output = instance.do_query(query).await.unwrap();
assert!(matches!(output, Output::AffectedRows(0)));
let query = GreptimeRequest {
request: Some(GrpcRequest::Ddl(DdlRequest {
expr: Some(DdlExpr::Alter(AlterExpr {
catalog_name: "greptime".to_string(),
schema_name: "my_database".to_string(),
table_name: "my_table".to_string(),
kind: Some(alter_expr::Kind::AddColumns(AddColumns {
add_columns: vec![AddColumn {
column_def: Some(ColumnDef {
name: "b".to_string(),
datatype: ColumnDataType::Int32 as i32,
is_nullable: true,
default_constraint: vec![],
}),
is_key: true,
}],
})),
})),
})),
};
let output = instance.do_query(query).await.unwrap();
assert!(matches!(output, Output::AffectedRows(0)));
let output = instance
.execute_sql(
"INSERT INTO my_database.my_table (a, b, ts) VALUES ('s', 1, 1672384140000)",
QueryContext::arc(),
)
.await
.unwrap();
assert!(matches!(output, Output::AffectedRows(1)));
let output = instance
.execute_sql(
"SELECT ts, a, b FROM my_database.my_table",
QueryContext::arc(),
)
.await
.unwrap();
let Output::Stream(stream) = output else { unreachable!() };
let recordbatches = RecordBatches::try_collect(stream).await.unwrap();
let expected = "\
+---------------------+---+---+
| ts | a | b |
+---------------------+---+---+
| 2022-12-30T07:09:00 | s | 1 |
+---------------------+---+---+";
assert_eq!(recordbatches.pretty_print().unwrap(), expected);
}
#[tokio::test(flavor = "multi_thread")]
async fn test_handle_insert() {
let instance = MockInstance::new("test_handle_insert").await;
let instance = instance.inner();
test_util::create_test_table(instance, ConcreteDataType::timestamp_millisecond_datatype())
.await
.unwrap();
let insert = InsertRequest {
schema_name: "public".to_string(),
table_name: "demo".to_string(),
columns: vec![
Column {
column_name: "host".to_string(),
values: Some(Values {
string_values: vec![
"host1".to_string(),
"host2".to_string(),
"host3".to_string(),
],
..Default::default()
}),
semantic_type: SemanticType::Tag as i32,
datatype: ColumnDataType::String as i32,
..Default::default()
},
Column {
column_name: "cpu".to_string(),
values: Some(Values {
f64_values: vec![1.0, 3.0],
..Default::default()
}),
null_mask: vec![2],
semantic_type: SemanticType::Field as i32,
datatype: ColumnDataType::Float64 as i32,
},
Column {
column_name: "ts".to_string(),
values: Some(Values {
ts_millisecond_values: vec![1672384140000, 1672384141000, 1672384142000],
..Default::default()
}),
semantic_type: SemanticType::Timestamp as i32,
datatype: ColumnDataType::TimestampMillisecond as i32,
..Default::default()
},
],
row_count: 3,
..Default::default()
};
let query = GreptimeRequest {
request: Some(GrpcRequest::Insert(insert)),
};
let output = instance.do_query(query).await.unwrap();
assert!(matches!(output, Output::AffectedRows(3)));
let output = instance
.execute_sql("SELECT ts, host, cpu FROM demo", QueryContext::arc())
.await
.unwrap();
let Output::Stream(stream) = output else { unreachable!() };
let recordbatches = RecordBatches::try_collect(stream).await.unwrap();
let expected = "\
+---------------------+-------+-----+
| ts | host | cpu |
+---------------------+-------+-----+
| 2022-12-30T07:09:00 | host1 | 1 |
| 2022-12-30T07:09:01 | host2 | |
| 2022-12-30T07:09:02 | host3 | 3 |
+---------------------+-------+-----+";
assert_eq!(recordbatches.pretty_print().unwrap(), expected);
}
#[tokio::test(flavor = "multi_thread")]
async fn test_handle_query() {
let instance = MockInstance::new("test_handle_query").await;
let instance = instance.inner();
test_util::create_test_table(instance, ConcreteDataType::timestamp_millisecond_datatype())
.await
.unwrap();
let query = GreptimeRequest {
request: Some(GrpcRequest::Query(QueryRequest {
query: Some(Query::Sql(
"INSERT INTO demo(host, cpu, memory, ts) VALUES \
('host1', 66.6, 1024, 1672201025000),\
('host2', 88.8, 333.3, 1672201026000)"
.to_string(),
)),
})),
};
let output = instance.do_query(query).await.unwrap();
assert!(matches!(output, Output::AffectedRows(2)));
let query = GreptimeRequest {
request: Some(GrpcRequest::Query(QueryRequest {
query: Some(Query::Sql(
"SELECT ts, host, cpu, memory FROM demo".to_string(),
)),
})),
};
let output = instance.do_query(query).await.unwrap();
let Output::Stream(stream) = output else { unreachable!() };
let recordbatch = RecordBatches::try_collect(stream).await.unwrap();
let expected = "\
+---------------------+-------+------+--------+
| ts | host | cpu | memory |
+---------------------+-------+------+--------+
| 2022-12-28T04:17:05 | host1 | 66.6 | 1024 |
| 2022-12-28T04:17:06 | host2 | 88.8 | 333.3 |
+---------------------+-------+------+--------+";
let actual = recordbatch.pretty_print().unwrap();
assert_eq!(actual, expected);
}
}

View File

@@ -162,7 +162,7 @@ async fn setup_test_instance(test_name: &str) -> MockInstance {
let instance = MockInstance::new(test_name).await;
test_util::create_test_table(
&instance,
instance.inner(),
ConcreteDataType::timestamp_millisecond_datatype(),
)
.await
@@ -189,7 +189,7 @@ async fn test_execute_insert() {
async fn test_execute_insert_query_with_i64_timestamp() {
let instance = MockInstance::new("insert_query_i64_timestamp").await;
test_util::create_test_table(&instance, ConcreteDataType::int64_datatype())
test_util::create_test_table(instance.inner(), ConcreteDataType::int64_datatype())
.await
.unwrap();
@@ -302,7 +302,7 @@ async fn test_execute_show_databases_tables() {
// creat a table
test_util::create_test_table(
&instance,
instance.inner(),
ConcreteDataType::timestamp_millisecond_datatype(),
)
.await

View File

@@ -78,7 +78,7 @@ fn create_tmp_dir_and_datanode_opts(name: &str) -> (DatanodeOptions, TestGuard)
}
pub(crate) async fn create_test_table(
instance: &MockInstance,
instance: &Instance,
ts_type: ConcreteDataType,
) -> Result<()> {
let column_schemas = vec![
@@ -89,7 +89,7 @@ pub(crate) async fn create_test_table(
];
let table_name = "demo";
let table_engine: TableEngineRef = instance.inner().sql_handler().table_engine();
let table_engine: TableEngineRef = instance.sql_handler().table_engine();
let table = table_engine
.create_table(
&EngineContext::default(),
@@ -115,7 +115,6 @@ pub(crate) async fn create_test_table(
.context(CreateTableSnafu { table_name })?;
let schema_provider = instance
.inner()
.catalog_manager
.schema(DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME)
.unwrap()