fix: further ease the restriction of executing SQLs in new GRPC interface (#797)

* fix: carry not recordbatch result in FlightData, to allow executing SQLs other than selection in new GRPC interface

* Update src/datanode/src/instance/flight/stream.rs

Co-authored-by: Jiachun Feng <jiachun_feng@proton.me>
This commit is contained in:
LFC
2022-12-28 16:43:21 +08:00
committed by GitHub
parent 76236646ef
commit 04df80e640
22 changed files with 257 additions and 172 deletions

View File

@@ -9,6 +9,7 @@ default = ["python"]
python = ["dep:script"]
[dependencies]
async-stream.workspace = true
async-trait.workspace = true
api = { path = "../api" }
arrow-flight.workspace = true
@@ -28,6 +29,7 @@ common-telemetry = { path = "../common/telemetry" }
common-time = { path = "../common/time" }
datafusion.workspace = true
datatypes = { path = "../datatypes" }
flatbuffers = "22"
futures = "0.3"
hyper = { version = "0.14", features = ["full"] }
log-store = { path = "../log-store" }

View File

@@ -18,14 +18,16 @@ use std::pin::Pin;
use api::v1::object_expr::Expr;
use api::v1::query_request::Query;
use api::v1::ObjectExpr;
use api::v1::{FlightDataExt, ObjectExpr};
use arrow_flight::flight_service_server::FlightService;
use arrow_flight::{
Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo,
HandshakeRequest, HandshakeResponse, PutResult, SchemaResult, Ticket,
HandshakeRequest, HandshakeResponse, IpcMessage, PutResult, SchemaResult, Ticket,
};
use async_trait::async_trait;
use common_query::Output;
use datatypes::arrow;
use flatbuffers::FlatBufferBuilder;
use futures::Stream;
use prost::Message;
use session::context::QueryContext;
@@ -33,7 +35,7 @@ use snafu::{OptionExt, ResultExt};
use tonic::{Request, Response, Streaming};
use crate::error::{self, Result};
use crate::instance::flight::stream::GetStream;
use crate::instance::flight::stream::FlightRecordBatchStream;
use crate::instance::Instance;
type TonicResult<T> = std::result::Result<T, tonic::Status>;
@@ -85,7 +87,7 @@ impl FlightService for Instance {
.query
.context(error::MissingRequiredFieldSnafu { name: "expr" })?;
let stream = self.handle_query(query).await?;
Ok(Response::new(Box::pin(stream) as TonicStream<FlightData>))
Ok(Response::new(stream))
}
// TODO(LFC): Implement Insertion Flight interface.
Expr::Insert(_) => Err(tonic::Status::unimplemented("Not yet implemented")),
@@ -130,7 +132,7 @@ impl FlightService for Instance {
}
impl Instance {
async fn handle_query(&self, query: Query) -> Result<GetStream> {
async fn handle_query(&self, query: Query) -> Result<TonicStream<FlightData>> {
let output = match query {
Query::Sql(sql) => {
let stmt = self
@@ -141,14 +143,125 @@ impl Instance {
}
Query::LogicalPlan(plan) => self.execute_logical(plan).await?,
};
let recordbatch_stream = match output {
Output::Stream(stream) => stream,
Output::RecordBatches(x) => x.as_stream(),
Output::AffectedRows(_) => {
unreachable!("SELECT should not have returned affected rows!")
Ok(match output {
Output::Stream(stream) => {
let stream = FlightRecordBatchStream::new(stream);
Box::pin(stream) as _
}
};
Ok(GetStream::new(recordbatch_stream))
Output::RecordBatches(x) => {
let stream = FlightRecordBatchStream::new(x.as_stream());
Box::pin(stream) as _
}
Output::AffectedRows(rows) => {
let stream = async_stream::stream! {
let ext_data = FlightDataExt {
affected_rows: rows as _,
}.encode_to_vec();
yield Ok(FlightData::new(None, IpcMessage(build_none_flight_msg()), vec![], ext_data))
};
Box::pin(stream) as _
}
})
}
}
fn build_none_flight_msg() -> Vec<u8> {
let mut builder = FlatBufferBuilder::new();
let mut message = arrow::ipc::MessageBuilder::new(&mut builder);
message.add_version(arrow::ipc::MetadataVersion::V5);
message.add_header_type(arrow::ipc::MessageHeader::NONE);
message.add_bodyLength(0);
let data = message.finish();
builder.finish(data, None);
builder.finished_data().to_vec()
}
#[cfg(test)]
mod test {
use api::v1::{object_result, FlightDataRaw, QueryRequest};
use common_grpc::flight;
use common_grpc::flight::FlightMessage;
use datatypes::prelude::*;
use super::*;
use crate::tests::test_util::{self, MockInstance};
#[tokio::test(flavor = "multi_thread")]
async fn test_handle_query() {
let instance = MockInstance::new("test_handle_query").await;
test_util::create_test_table(
&instance,
ConcreteDataType::timestamp_millisecond_datatype(),
)
.await
.unwrap();
let ticket = Request::new(Ticket {
ticket: ObjectExpr {
header: None,
expr: Some(Expr::Query(QueryRequest {
query: Some(Query::Sql(
"INSERT INTO demo(host, cpu, memory, ts) VALUES \
('host1', 66.6, 1024, 1672201025000),\
('host2', 88.8, 333.3, 1672201026000)"
.to_string(),
)),
})),
}
.encode_to_vec(),
});
let response = instance.inner().do_get(ticket).await.unwrap();
let result = flight::flight_data_to_object_result(response)
.await
.unwrap();
let result = result.result.unwrap();
assert!(matches!(result, object_result::Result::FlightData(_)));
let object_result::Result::FlightData(FlightDataRaw { raw_data }) = result else { unreachable!() };
let mut messages = flight::raw_flight_data_to_message(raw_data).unwrap();
assert_eq!(messages.len(), 1);
let message = messages.remove(0);
assert!(matches!(message, FlightMessage::AffectedRows(_)));
let FlightMessage::AffectedRows(affected_rows) = message else { unreachable!() };
assert_eq!(affected_rows, 2);
let ticket = Request::new(Ticket {
ticket: ObjectExpr {
header: None,
expr: Some(Expr::Query(QueryRequest {
query: Some(Query::Sql(
"SELECT ts, host, cpu, memory FROM demo".to_string(),
)),
})),
}
.encode_to_vec(),
});
let response = instance.inner().do_get(ticket).await.unwrap();
let result = flight::flight_data_to_object_result(response)
.await
.unwrap();
let result = result.result.unwrap();
assert!(matches!(result, object_result::Result::FlightData(_)));
let object_result::Result::FlightData(FlightDataRaw { raw_data }) = result else { unreachable!() };
let messages = flight::raw_flight_data_to_message(raw_data).unwrap();
assert_eq!(messages.len(), 2);
let recordbatch = flight::flight_messages_to_recordbatches(messages).unwrap();
let expected = "\
+---------------------+-------+------+--------+
| ts | host | cpu | memory |
+---------------------+-------+------+--------+
| 2022-12-28T04:17:05 | host1 | 66.6 | 1024 |
| 2022-12-28T04:17:06 | host2 | 88.8 | 333.3 |
+---------------------+-------+------+--------+";
let actual = recordbatch.pretty_print().unwrap();
assert_eq!(actual, expected);
}
}

View File

@@ -31,14 +31,14 @@ use crate::error;
use crate::instance::flight::TonicResult;
#[pin_project(PinnedDrop)]
pub(super) struct GetStream {
pub(super) struct FlightRecordBatchStream {
#[pin]
rx: mpsc::Receiver<Result<FlightData, tonic::Status>>,
join_handle: JoinHandle<()>,
done: bool,
}
impl GetStream {
impl FlightRecordBatchStream {
pub(super) fn new(recordbatches: SendableRecordBatchStream) -> Self {
let (tx, rx) = mpsc::channel::<TonicResult<FlightData>>(1);
let join_handle =
@@ -96,13 +96,13 @@ impl GetStream {
}
#[pinned_drop]
impl PinnedDrop for GetStream {
impl PinnedDrop for FlightRecordBatchStream {
fn drop(self: Pin<&mut Self>) {
self.join_handle.abort();
}
}
impl Stream for GetStream {
impl Stream for FlightRecordBatchStream {
type Item = TonicResult<FlightData>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
@@ -139,7 +139,7 @@ mod test {
use super::*;
#[tokio::test]
async fn test_get_stream() {
async fn test_flight_record_batch_stream() {
let schema = Arc::new(Schema::new(vec![ColumnSchema::new(
"a",
ConcreteDataType::int32_datatype(),
@@ -152,13 +152,13 @@ mod test {
let recordbatches = RecordBatches::try_new(schema.clone(), vec![recordbatch.clone()])
.unwrap()
.as_stream();
let mut get_stream = GetStream::new(recordbatches);
let mut stream = FlightRecordBatchStream::new(recordbatches);
let mut raw_data = Vec::with_capacity(2);
raw_data.push(get_stream.next().await.unwrap().unwrap());
raw_data.push(get_stream.next().await.unwrap().unwrap());
assert!(get_stream.next().await.is_none());
assert!(get_stream.done);
raw_data.push(stream.next().await.unwrap().unwrap());
raw_data.push(stream.next().await.unwrap().unwrap());
assert!(stream.next().await.is_none());
assert!(stream.done);
let decoder = &mut FlightDecoder::default();
let mut flight_messages = raw_data

View File

@@ -21,17 +21,11 @@ use datatypes::data_type::ConcreteDataType;
use datatypes::vectors::{Int64Vector, StringVector, UInt64Vector, VectorRef};
use session::context::QueryContext;
use crate::instance::Instance;
use crate::tests::test_util;
use crate::tests::test_util::{self, MockInstance};
#[tokio::test(flavor = "multi_thread")]
async fn test_create_database_and_insert_query() {
common_telemetry::init_default_ut_logging();
let (opts, _guard) =
test_util::create_tmp_dir_and_datanode_opts("create_database_and_insert_query");
let instance = Instance::with_mock_meta_client(&opts).await.unwrap();
instance.start().await.unwrap();
let instance = MockInstance::new("create_database_and_insert_query").await;
let output = execute_sql(&instance, "create database test").await;
assert!(matches!(output, Output::AffectedRows(1)));
@@ -77,12 +71,7 @@ async fn test_create_database_and_insert_query() {
}
#[tokio::test(flavor = "multi_thread")]
async fn test_issue477_same_table_name_in_different_databases() {
common_telemetry::init_default_ut_logging();
let (opts, _guard) =
test_util::create_tmp_dir_and_datanode_opts("create_database_and_insert_query");
let instance = Instance::with_mock_meta_client(&opts).await.unwrap();
instance.start().await.unwrap();
let instance = MockInstance::new("test_issue477_same_table_name_in_different_databases").await;
// Create database a and b
let output = execute_sql(&instance, "create database a").await;
@@ -149,7 +138,7 @@ async fn test_issue477_same_table_name_in_different_databases() {
.await;
}
async fn assert_query_result(instance: &Instance, sql: &str, ts: i64, host: &str) {
async fn assert_query_result(instance: &MockInstance, sql: &str, ts: i64, host: &str) {
let query_output = execute_sql(instance, sql).await;
match query_output {
Output::Stream(s) => {
@@ -169,16 +158,11 @@ async fn assert_query_result(instance: &Instance, sql: &str, ts: i64, host: &str
}
}
async fn setup_test_instance(test_name: &str) -> Instance {
common_telemetry::init_default_ut_logging();
let (opts, _guard) = test_util::create_tmp_dir_and_datanode_opts(test_name);
let instance = Instance::with_mock_meta_client(&opts).await.unwrap();
instance.start().await.unwrap();
async fn setup_test_instance(test_name: &str) -> MockInstance {
let instance = MockInstance::new(test_name).await;
test_util::create_test_table(
instance.catalog_manager(),
instance.sql_handler(),
&instance,
ConcreteDataType::timestamp_millisecond_datatype(),
)
.await
@@ -203,19 +187,11 @@ async fn test_execute_insert() {
#[tokio::test(flavor = "multi_thread")]
async fn test_execute_insert_query_with_i64_timestamp() {
common_telemetry::init_default_ut_logging();
let instance = MockInstance::new("insert_query_i64_timestamp").await;
let (opts, _guard) = test_util::create_tmp_dir_and_datanode_opts("insert_query_i64_timestamp");
let instance = Instance::with_mock_meta_client(&opts).await.unwrap();
instance.start().await.unwrap();
test_util::create_test_table(
instance.catalog_manager(),
instance.sql_handler(),
ConcreteDataType::int64_datatype(),
)
.await
.unwrap();
test_util::create_test_table(&instance, ConcreteDataType::int64_datatype())
.await
.unwrap();
let output = execute_sql(
&instance,
@@ -262,9 +238,7 @@ async fn test_execute_insert_query_with_i64_timestamp() {
#[tokio::test(flavor = "multi_thread")]
async fn test_execute_query() {
let (opts, _guard) = test_util::create_tmp_dir_and_datanode_opts("execute_query");
let instance = Instance::with_mock_meta_client(&opts).await.unwrap();
instance.start().await.unwrap();
let instance = MockInstance::new("execute_query").await;
let output = execute_sql(&instance, "select sum(number) from numbers limit 20").await;
match output {
@@ -284,10 +258,7 @@ async fn test_execute_query() {
#[tokio::test(flavor = "multi_thread")]
async fn test_execute_show_databases_tables() {
let (opts, _guard) =
test_util::create_tmp_dir_and_datanode_opts("execute_show_databases_tables");
let instance = Instance::with_mock_meta_client(&opts).await.unwrap();
instance.start().await.unwrap();
let instance = MockInstance::new("execute_show_databases_tables").await;
let output = execute_sql(&instance, "show databases").await;
match output {
@@ -331,8 +302,7 @@ async fn test_execute_show_databases_tables() {
// creat a table
test_util::create_test_table(
instance.catalog_manager(),
instance.sql_handler(),
&instance,
ConcreteDataType::timestamp_millisecond_datatype(),
)
.await
@@ -367,11 +337,7 @@ async fn test_execute_show_databases_tables() {
#[tokio::test(flavor = "multi_thread")]
pub async fn test_execute_create() {
common_telemetry::init_default_ut_logging();
let (opts, _guard) = test_util::create_tmp_dir_and_datanode_opts("execute_create");
let instance = Instance::with_mock_meta_client(&opts).await.unwrap();
instance.start().await.unwrap();
let instance = MockInstance::new("execute_create").await;
let output = execute_sql(
&instance,
@@ -480,9 +446,7 @@ async fn test_alter_table() {
}
async fn test_insert_with_default_value_for_type(type_name: &str) {
let (opts, _guard) = test_util::create_tmp_dir_and_datanode_opts("execute_create");
let instance = Instance::with_mock_meta_client(&opts).await.unwrap();
instance.start().await.unwrap();
let instance = MockInstance::new("execute_create").await;
let create_sql = format!(
r#"create table test_table(
@@ -527,17 +491,13 @@ async fn test_insert_with_default_value_for_type(type_name: &str) {
#[tokio::test(flavor = "multi_thread")]
async fn test_insert_with_default_value() {
common_telemetry::init_default_ut_logging();
test_insert_with_default_value_for_type("timestamp").await;
test_insert_with_default_value_for_type("bigint").await;
}
#[tokio::test(flavor = "multi_thread")]
async fn test_use_database() {
let (opts, _guard) = test_util::create_tmp_dir_and_datanode_opts("use_database");
let instance = Instance::with_mock_meta_client(&opts).await.unwrap();
instance.start().await.unwrap();
let instance = MockInstance::new("test_use_database").await;
let output = execute_sql(&instance, "create database db1").await;
assert!(matches!(output, Output::AffectedRows(1)));
@@ -594,11 +554,11 @@ async fn test_use_database() {
check_output_stream(output, expected).await;
}
async fn execute_sql(instance: &Instance, sql: &str) -> Output {
async fn execute_sql(instance: &MockInstance, sql: &str) -> Output {
execute_sql_in_db(instance, sql, DEFAULT_SCHEMA_NAME).await
}
async fn execute_sql_in_db(instance: &Instance, sql: &str, db: &str) -> Output {
async fn execute_sql_in_db(instance: &MockInstance, sql: &str, db: &str) -> Output {
let query_ctx = Arc::new(QueryContext::with_current_schema(db.to_string()));
instance.execute_sql(sql, query_ctx).await.unwrap()
instance.inner().execute_sql(sql, query_ctx).await.unwrap()
}

View File

@@ -15,7 +15,6 @@
use std::collections::HashMap;
use std::sync::Arc;
use catalog::CatalogManagerRef;
use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME, MIN_USER_TABLE_ID};
use datatypes::data_type::ConcreteDataType;
use datatypes::schema::{ColumnSchema, SchemaBuilder};
@@ -30,16 +29,35 @@ use tempdir::TempDir;
use crate::datanode::{DatanodeOptions, ObjectStoreConfig};
use crate::error::{CreateTableSnafu, Result};
use crate::instance::Instance;
use crate::sql::SqlHandler;
/// Create a tmp dir(will be deleted once it goes out of scope.) and a default `DatanodeOptions`,
/// Only for test.
pub struct TestGuard {
pub(crate) struct MockInstance {
instance: Instance,
_guard: TestGuard,
}
impl MockInstance {
pub(crate) async fn new(name: &str) -> Self {
let (opts, _guard) = create_tmp_dir_and_datanode_opts(name);
let instance = Instance::with_mock_meta_client(&opts).await.unwrap();
instance.start().await.unwrap();
MockInstance { instance, _guard }
}
pub(crate) fn inner(&self) -> &Instance {
&self.instance
}
}
struct TestGuard {
_wal_tmp_dir: TempDir,
_data_tmp_dir: TempDir,
}
pub fn create_tmp_dir_and_datanode_opts(name: &str) -> (DatanodeOptions, TestGuard) {
fn create_tmp_dir_and_datanode_opts(name: &str) -> (DatanodeOptions, TestGuard) {
let wal_tmp_dir = TempDir::new(&format!("gt_wal_{name}")).unwrap();
let data_tmp_dir = TempDir::new(&format!("gt_data_{name}")).unwrap();
let opts = DatanodeOptions {
@@ -59,9 +77,8 @@ pub fn create_tmp_dir_and_datanode_opts(name: &str) -> (DatanodeOptions, TestGua
)
}
pub async fn create_test_table(
catalog_manager: &CatalogManagerRef,
sql_handler: &SqlHandler,
pub(crate) async fn create_test_table(
instance: &MockInstance,
ts_type: ConcreteDataType,
) -> Result<()> {
let column_schemas = vec![
@@ -72,7 +89,7 @@ pub async fn create_test_table(
];
let table_name = "demo";
let table_engine: TableEngineRef = sql_handler.table_engine();
let table_engine: TableEngineRef = instance.inner().sql_handler().table_engine();
let table = table_engine
.create_table(
&EngineContext::default(),
@@ -97,11 +114,10 @@ pub async fn create_test_table(
.await
.context(CreateTableSnafu { table_name })?;
let schema_provider = catalog_manager
.catalog(DEFAULT_CATALOG_NAME)
.unwrap()
.unwrap()
.schema(DEFAULT_SCHEMA_NAME)
let schema_provider = instance
.inner()
.catalog_manager
.schema(DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME)
.unwrap()
.unwrap();
schema_provider