From 04df80e6402b339bdd32c88ec92d59c851a39413 Mon Sep 17 00:00:00 2001 From: LFC Date: Wed, 28 Dec 2022 16:43:21 +0800 Subject: [PATCH] fix: further ease the restriction of executing SQLs in new GRPC interface (#797) * fix: carry not recordbatch result in FlightData, to allow executing SQLs other than selection in new GRPC interface * Update src/datanode/src/instance/flight/stream.rs Co-authored-by: Jiachun Feng --- Cargo.lock | 2 + Cargo.toml | 1 + benchmarks/src/bin/nyc-taxi.rs | 4 +- src/api/greptime/v1/database.proto | 4 + src/catalog/Cargo.toml | 2 +- src/client/Cargo.toml | 2 +- src/client/examples/select.rs | 34 ----- src/client/src/database.rs | 14 +-- src/client/src/lib.rs | 2 +- src/common/grpc/src/flight.rs | 24 ++-- src/datanode/Cargo.toml | 2 + src/datanode/src/instance/flight.rs | 139 +++++++++++++++++++-- src/datanode/src/instance/flight/stream.rs | 20 +-- src/datanode/src/tests/instance_test.rs | 80 +++--------- src/datanode/src/tests/test_util.rs | 44 ++++--- src/frontend/Cargo.toml | 2 +- src/log-store/Cargo.toml | 2 +- src/mito/Cargo.toml | 2 +- src/storage/Cargo.toml | 2 +- src/store-api/Cargo.toml | 2 +- tests-integration/tests/grpc.rs | 20 ++- tests/runner/src/env.rs | 25 ++-- 22 files changed, 257 insertions(+), 172 deletions(-) delete mode 100644 src/client/examples/select.rs diff --git a/Cargo.lock b/Cargo.lock index 6dc7e947de..f314919393 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2137,6 +2137,7 @@ version = "0.1.0" dependencies = [ "api", "arrow-flight", + "async-stream", "async-trait", "axum", "axum-macros", @@ -2157,6 +2158,7 @@ dependencies = [ "datafusion", "datafusion-common", "datatypes", + "flatbuffers", "futures", "hyper", "log-store", diff --git a/Cargo.toml b/Cargo.toml index f88be085c8..99df1a56f0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -48,6 +48,7 @@ license = "Apache-2.0" arrow = "29.0" arrow-flight = "29.0" arrow-schema = { version = "29.0", features = ["serde"] } +async-stream = "0.3" async-trait = "0.1" # TODO(LFC): Use released Datafusion when it officially dpendent on Arrow 29.0 datafusion = { git = "https://github.com/apache/arrow-datafusion.git", rev = "4917235a398ae20145c87d20984e6367dc1a0c1e" } diff --git a/benchmarks/src/bin/nyc-taxi.rs b/benchmarks/src/bin/nyc-taxi.rs index 885f1827e1..bbc1a4462a 100644 --- a/benchmarks/src/bin/nyc-taxi.rs +++ b/benchmarks/src/bin/nyc-taxi.rs @@ -28,7 +28,7 @@ use clap::Parser; use client::admin::Admin; use client::api::v1::column::Values; use client::api::v1::{Column, ColumnDataType, ColumnDef, CreateTableExpr, InsertExpr, TableId}; -use client::{Client, Database, Select}; +use client::{Client, Database}; use indicatif::{MultiProgress, ProgressBar, ProgressStyle}; use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder; use tokio::task::JoinSet; @@ -405,7 +405,7 @@ async fn do_query(num_iter: usize, db: &Database) { println!("Running query: {query}"); for i in 0..num_iter { let now = Instant::now(); - let _res = db.select(Select::Sql(query.clone())).await.unwrap(); + let _res = db.sql(&query).await.unwrap(); let elapsed = now.elapsed(); println!( "query {}, iteration {}: {}ms", diff --git a/src/api/greptime/v1/database.proto b/src/api/greptime/v1/database.proto index 9bfd4f4bcf..0db3aa7abe 100644 --- a/src/api/greptime/v1/database.proto +++ b/src/api/greptime/v1/database.proto @@ -56,3 +56,7 @@ message ObjectResult { message FlightDataRaw { repeated bytes raw_data = 1; } + +message FlightDataExt { + uint32 affected_rows = 1; +} diff --git a/src/catalog/Cargo.toml b/src/catalog/Cargo.toml index 59b5c56e65..73bb2990d9 100644 --- a/src/catalog/Cargo.toml +++ b/src/catalog/Cargo.toml @@ -7,7 +7,7 @@ license.workspace = true [dependencies] api = { path = "../api" } arc-swap = "1.0" -async-stream = "0.3" +async-stream.workspace = true async-trait = "0.1" backoff = { version = "0.4", features = ["tokio"] } common-catalog = { path = "../common/catalog" } diff --git a/src/client/Cargo.toml b/src/client/Cargo.toml index 645e22a9bb..abf9b7c108 100644 --- a/src/client/Cargo.toml +++ b/src/client/Cargo.toml @@ -6,7 +6,7 @@ license.workspace = true [dependencies] api = { path = "../api" } -async-stream = "0.3" +async-stream.workspace = true common-base = { path = "../common/base" } common-error = { path = "../common/error" } common-grpc = { path = "../common/grpc" } diff --git a/src/client/examples/select.rs b/src/client/examples/select.rs deleted file mode 100644 index 01516e732a..0000000000 --- a/src/client/examples/select.rs +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright 2022 Greptime Team -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use client::{Client, Database, Select}; -use tracing::{event, Level}; - -fn main() { - tracing::subscriber::set_global_default(tracing_subscriber::FmtSubscriber::builder().finish()) - .unwrap(); - - run(); -} - -#[tokio::main] -async fn run() { - let client = Client::with_urls(vec!["127.0.0.1:3001"]); - let db = Database::new("greptime", client); - - let sql = Select::Sql("select * from demo".to_string()); - let result = db.select(sql).await.unwrap(); - - event!(Level::INFO, "result: {:#?}", result); -} diff --git a/src/client/src/database.rs b/src/client/src/database.rs index 8595dcd852..29ecc07b56 100644 --- a/src/client/src/database.rs +++ b/src/client/src/database.rs @@ -73,13 +73,11 @@ impl Database { .collect() } - pub async fn select(&self, expr: Select) -> Result { - let select_expr = match expr { - Select::Sql(sql) => QueryRequest { - query: Some(query_request::Query::Sql(sql)), - }, + pub async fn sql(&self, sql: &str) -> Result { + let query = QueryRequest { + query: Some(query_request::Query::Sql(sql.to_string())), }; - self.do_select(select_expr).await + self.do_select(query).await } pub async fn logical_plan(&self, logical_plan: Vec) -> Result { @@ -166,10 +164,6 @@ impl TryFrom for ObjectResult { } } -pub enum Select { - Sql(String), -} - impl TryFrom for Output { type Error = error::Error; diff --git a/src/client/src/lib.rs b/src/client/src/lib.rs index 931f45c802..3a64c6b962 100644 --- a/src/client/src/lib.rs +++ b/src/client/src/lib.rs @@ -21,5 +21,5 @@ pub mod load_balance; pub use api; pub use self::client::Client; -pub use self::database::{Database, ObjectResult, Select}; +pub use self::database::{Database, ObjectResult}; pub use self::error::{Error, Result}; diff --git a/src/common/grpc/src/flight.rs b/src/common/grpc/src/flight.rs index 1ed8030d70..d08e53ba9e 100644 --- a/src/common/grpc/src/flight.rs +++ b/src/common/grpc/src/flight.rs @@ -17,7 +17,7 @@ use std::pin::Pin; use std::sync::Arc; use api::result::ObjectResultBuilder; -use api::v1::ObjectResult; +use api::v1::{FlightDataExt, ObjectResult}; use arrow_flight::utils::flight_data_to_arrow_batch; use arrow_flight::FlightData; use common_error::prelude::StatusCode; @@ -31,7 +31,10 @@ use snafu::{OptionExt, ResultExt}; use tonic::codegen::futures_core::Stream; use tonic::Response; -use crate::error::{self, InvalidFlightDataSnafu, Result}; +use crate::error::{ + ConvertArrowSchemaSnafu, CreateRecordBatchSnafu, DecodeFlightDataSnafu, InvalidFlightDataSnafu, + Result, +}; type TonicResult = std::result::Result; type TonicStream = Pin> + Send + Sync + 'static>>; @@ -40,6 +43,7 @@ type TonicStream = Pin> + Send + Sync + pub enum FlightMessage { Schema(SchemaRef), Recordbatch(RecordBatch), + AffectedRows(usize), } #[derive(Default)] @@ -56,6 +60,11 @@ impl FlightDecoder { .build() })?; match message.header_type() { + MessageHeader::NONE => { + let ext_data = FlightDataExt::decode(flight_data.data_body.as_slice()) + .context(DecodeFlightDataSnafu)?; + Ok(FlightMessage::AffectedRows(ext_data.affected_rows as _)) + } MessageHeader::Schema => { let arrow_schema = ArrowSchema::try_from(&flight_data).map_err(|e| { InvalidFlightDataSnafu { @@ -63,9 +72,8 @@ impl FlightDecoder { } .build() })?; - let schema = Arc::new( - Schema::try_from(arrow_schema).context(error::ConvertArrowSchemaSnafu)?, - ); + let schema = + Arc::new(Schema::try_from(arrow_schema).context(ConvertArrowSchemaSnafu)?); self.schema = Some(schema.clone()); @@ -86,7 +94,7 @@ impl FlightDecoder { .build() })?; let recordbatch = RecordBatch::try_from_df_record_batch(schema, arrow_batch) - .context(error::CreateRecordBatchSnafu)?; + .context(CreateRecordBatchSnafu)?; Ok(FlightMessage::Recordbatch(recordbatch)) } other => { @@ -127,7 +135,7 @@ pub async fn flight_data_to_object_result( pub fn raw_flight_data_to_message(raw_data: Vec>) -> Result> { let flight_data = raw_data .into_iter() - .map(|x| FlightData::decode(x.as_slice()).context(error::DecodeFlightDataSnafu)) + .map(|x| FlightData::decode(x.as_slice()).context(DecodeFlightDataSnafu)) .collect::>>()?; let decoder = &mut FlightDecoder::default(); @@ -165,7 +173,7 @@ pub fn flight_messages_to_recordbatches(messages: Vec) -> Result< } } - RecordBatches::try_new(schema, recordbatches).context(error::CreateRecordBatchSnafu) + RecordBatches::try_new(schema, recordbatches).context(CreateRecordBatchSnafu) } } diff --git a/src/datanode/Cargo.toml b/src/datanode/Cargo.toml index da8cc205b0..9d4ae86113 100644 --- a/src/datanode/Cargo.toml +++ b/src/datanode/Cargo.toml @@ -9,6 +9,7 @@ default = ["python"] python = ["dep:script"] [dependencies] +async-stream.workspace = true async-trait.workspace = true api = { path = "../api" } arrow-flight.workspace = true @@ -28,6 +29,7 @@ common-telemetry = { path = "../common/telemetry" } common-time = { path = "../common/time" } datafusion.workspace = true datatypes = { path = "../datatypes" } +flatbuffers = "22" futures = "0.3" hyper = { version = "0.14", features = ["full"] } log-store = { path = "../log-store" } diff --git a/src/datanode/src/instance/flight.rs b/src/datanode/src/instance/flight.rs index da22ffd2ac..27eb13d7d9 100644 --- a/src/datanode/src/instance/flight.rs +++ b/src/datanode/src/instance/flight.rs @@ -18,14 +18,16 @@ use std::pin::Pin; use api::v1::object_expr::Expr; use api::v1::query_request::Query; -use api::v1::ObjectExpr; +use api::v1::{FlightDataExt, ObjectExpr}; use arrow_flight::flight_service_server::FlightService; use arrow_flight::{ Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo, - HandshakeRequest, HandshakeResponse, PutResult, SchemaResult, Ticket, + HandshakeRequest, HandshakeResponse, IpcMessage, PutResult, SchemaResult, Ticket, }; use async_trait::async_trait; use common_query::Output; +use datatypes::arrow; +use flatbuffers::FlatBufferBuilder; use futures::Stream; use prost::Message; use session::context::QueryContext; @@ -33,7 +35,7 @@ use snafu::{OptionExt, ResultExt}; use tonic::{Request, Response, Streaming}; use crate::error::{self, Result}; -use crate::instance::flight::stream::GetStream; +use crate::instance::flight::stream::FlightRecordBatchStream; use crate::instance::Instance; type TonicResult = std::result::Result; @@ -85,7 +87,7 @@ impl FlightService for Instance { .query .context(error::MissingRequiredFieldSnafu { name: "expr" })?; let stream = self.handle_query(query).await?; - Ok(Response::new(Box::pin(stream) as TonicStream)) + Ok(Response::new(stream)) } // TODO(LFC): Implement Insertion Flight interface. Expr::Insert(_) => Err(tonic::Status::unimplemented("Not yet implemented")), @@ -130,7 +132,7 @@ impl FlightService for Instance { } impl Instance { - async fn handle_query(&self, query: Query) -> Result { + async fn handle_query(&self, query: Query) -> Result> { let output = match query { Query::Sql(sql) => { let stmt = self @@ -141,14 +143,125 @@ impl Instance { } Query::LogicalPlan(plan) => self.execute_logical(plan).await?, }; - - let recordbatch_stream = match output { - Output::Stream(stream) => stream, - Output::RecordBatches(x) => x.as_stream(), - Output::AffectedRows(_) => { - unreachable!("SELECT should not have returned affected rows!") + Ok(match output { + Output::Stream(stream) => { + let stream = FlightRecordBatchStream::new(stream); + Box::pin(stream) as _ } - }; - Ok(GetStream::new(recordbatch_stream)) + Output::RecordBatches(x) => { + let stream = FlightRecordBatchStream::new(x.as_stream()); + Box::pin(stream) as _ + } + Output::AffectedRows(rows) => { + let stream = async_stream::stream! { + let ext_data = FlightDataExt { + affected_rows: rows as _, + }.encode_to_vec(); + yield Ok(FlightData::new(None, IpcMessage(build_none_flight_msg()), vec![], ext_data)) + }; + Box::pin(stream) as _ + } + }) + } +} + +fn build_none_flight_msg() -> Vec { + let mut builder = FlatBufferBuilder::new(); + + let mut message = arrow::ipc::MessageBuilder::new(&mut builder); + message.add_version(arrow::ipc::MetadataVersion::V5); + message.add_header_type(arrow::ipc::MessageHeader::NONE); + message.add_bodyLength(0); + + let data = message.finish(); + builder.finish(data, None); + + builder.finished_data().to_vec() +} + +#[cfg(test)] +mod test { + use api::v1::{object_result, FlightDataRaw, QueryRequest}; + use common_grpc::flight; + use common_grpc::flight::FlightMessage; + use datatypes::prelude::*; + + use super::*; + use crate::tests::test_util::{self, MockInstance}; + + #[tokio::test(flavor = "multi_thread")] + async fn test_handle_query() { + let instance = MockInstance::new("test_handle_query").await; + test_util::create_test_table( + &instance, + ConcreteDataType::timestamp_millisecond_datatype(), + ) + .await + .unwrap(); + + let ticket = Request::new(Ticket { + ticket: ObjectExpr { + header: None, + expr: Some(Expr::Query(QueryRequest { + query: Some(Query::Sql( + "INSERT INTO demo(host, cpu, memory, ts) VALUES \ + ('host1', 66.6, 1024, 1672201025000),\ + ('host2', 88.8, 333.3, 1672201026000)" + .to_string(), + )), + })), + } + .encode_to_vec(), + }); + + let response = instance.inner().do_get(ticket).await.unwrap(); + let result = flight::flight_data_to_object_result(response) + .await + .unwrap(); + let result = result.result.unwrap(); + assert!(matches!(result, object_result::Result::FlightData(_))); + + let object_result::Result::FlightData(FlightDataRaw { raw_data }) = result else { unreachable!() }; + let mut messages = flight::raw_flight_data_to_message(raw_data).unwrap(); + assert_eq!(messages.len(), 1); + + let message = messages.remove(0); + assert!(matches!(message, FlightMessage::AffectedRows(_))); + let FlightMessage::AffectedRows(affected_rows) = message else { unreachable!() }; + assert_eq!(affected_rows, 2); + + let ticket = Request::new(Ticket { + ticket: ObjectExpr { + header: None, + expr: Some(Expr::Query(QueryRequest { + query: Some(Query::Sql( + "SELECT ts, host, cpu, memory FROM demo".to_string(), + )), + })), + } + .encode_to_vec(), + }); + + let response = instance.inner().do_get(ticket).await.unwrap(); + let result = flight::flight_data_to_object_result(response) + .await + .unwrap(); + let result = result.result.unwrap(); + assert!(matches!(result, object_result::Result::FlightData(_))); + + let object_result::Result::FlightData(FlightDataRaw { raw_data }) = result else { unreachable!() }; + let messages = flight::raw_flight_data_to_message(raw_data).unwrap(); + assert_eq!(messages.len(), 2); + + let recordbatch = flight::flight_messages_to_recordbatches(messages).unwrap(); + let expected = "\ ++---------------------+-------+------+--------+ +| ts | host | cpu | memory | ++---------------------+-------+------+--------+ +| 2022-12-28T04:17:05 | host1 | 66.6 | 1024 | +| 2022-12-28T04:17:06 | host2 | 88.8 | 333.3 | ++---------------------+-------+------+--------+"; + let actual = recordbatch.pretty_print().unwrap(); + assert_eq!(actual, expected); } } diff --git a/src/datanode/src/instance/flight/stream.rs b/src/datanode/src/instance/flight/stream.rs index 1f851c76d3..0cb3b35da4 100644 --- a/src/datanode/src/instance/flight/stream.rs +++ b/src/datanode/src/instance/flight/stream.rs @@ -31,14 +31,14 @@ use crate::error; use crate::instance::flight::TonicResult; #[pin_project(PinnedDrop)] -pub(super) struct GetStream { +pub(super) struct FlightRecordBatchStream { #[pin] rx: mpsc::Receiver>, join_handle: JoinHandle<()>, done: bool, } -impl GetStream { +impl FlightRecordBatchStream { pub(super) fn new(recordbatches: SendableRecordBatchStream) -> Self { let (tx, rx) = mpsc::channel::>(1); let join_handle = @@ -96,13 +96,13 @@ impl GetStream { } #[pinned_drop] -impl PinnedDrop for GetStream { +impl PinnedDrop for FlightRecordBatchStream { fn drop(self: Pin<&mut Self>) { self.join_handle.abort(); } } -impl Stream for GetStream { +impl Stream for FlightRecordBatchStream { type Item = TonicResult; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { @@ -139,7 +139,7 @@ mod test { use super::*; #[tokio::test] - async fn test_get_stream() { + async fn test_flight_record_batch_stream() { let schema = Arc::new(Schema::new(vec![ColumnSchema::new( "a", ConcreteDataType::int32_datatype(), @@ -152,13 +152,13 @@ mod test { let recordbatches = RecordBatches::try_new(schema.clone(), vec![recordbatch.clone()]) .unwrap() .as_stream(); - let mut get_stream = GetStream::new(recordbatches); + let mut stream = FlightRecordBatchStream::new(recordbatches); let mut raw_data = Vec::with_capacity(2); - raw_data.push(get_stream.next().await.unwrap().unwrap()); - raw_data.push(get_stream.next().await.unwrap().unwrap()); - assert!(get_stream.next().await.is_none()); - assert!(get_stream.done); + raw_data.push(stream.next().await.unwrap().unwrap()); + raw_data.push(stream.next().await.unwrap().unwrap()); + assert!(stream.next().await.is_none()); + assert!(stream.done); let decoder = &mut FlightDecoder::default(); let mut flight_messages = raw_data diff --git a/src/datanode/src/tests/instance_test.rs b/src/datanode/src/tests/instance_test.rs index 09771f5eda..daf21216d9 100644 --- a/src/datanode/src/tests/instance_test.rs +++ b/src/datanode/src/tests/instance_test.rs @@ -21,17 +21,11 @@ use datatypes::data_type::ConcreteDataType; use datatypes::vectors::{Int64Vector, StringVector, UInt64Vector, VectorRef}; use session::context::QueryContext; -use crate::instance::Instance; -use crate::tests::test_util; +use crate::tests::test_util::{self, MockInstance}; #[tokio::test(flavor = "multi_thread")] async fn test_create_database_and_insert_query() { - common_telemetry::init_default_ut_logging(); - - let (opts, _guard) = - test_util::create_tmp_dir_and_datanode_opts("create_database_and_insert_query"); - let instance = Instance::with_mock_meta_client(&opts).await.unwrap(); - instance.start().await.unwrap(); + let instance = MockInstance::new("create_database_and_insert_query").await; let output = execute_sql(&instance, "create database test").await; assert!(matches!(output, Output::AffectedRows(1))); @@ -77,12 +71,7 @@ async fn test_create_database_and_insert_query() { } #[tokio::test(flavor = "multi_thread")] async fn test_issue477_same_table_name_in_different_databases() { - common_telemetry::init_default_ut_logging(); - - let (opts, _guard) = - test_util::create_tmp_dir_and_datanode_opts("create_database_and_insert_query"); - let instance = Instance::with_mock_meta_client(&opts).await.unwrap(); - instance.start().await.unwrap(); + let instance = MockInstance::new("test_issue477_same_table_name_in_different_databases").await; // Create database a and b let output = execute_sql(&instance, "create database a").await; @@ -149,7 +138,7 @@ async fn test_issue477_same_table_name_in_different_databases() { .await; } -async fn assert_query_result(instance: &Instance, sql: &str, ts: i64, host: &str) { +async fn assert_query_result(instance: &MockInstance, sql: &str, ts: i64, host: &str) { let query_output = execute_sql(instance, sql).await; match query_output { Output::Stream(s) => { @@ -169,16 +158,11 @@ async fn assert_query_result(instance: &Instance, sql: &str, ts: i64, host: &str } } -async fn setup_test_instance(test_name: &str) -> Instance { - common_telemetry::init_default_ut_logging(); - - let (opts, _guard) = test_util::create_tmp_dir_and_datanode_opts(test_name); - let instance = Instance::with_mock_meta_client(&opts).await.unwrap(); - instance.start().await.unwrap(); +async fn setup_test_instance(test_name: &str) -> MockInstance { + let instance = MockInstance::new(test_name).await; test_util::create_test_table( - instance.catalog_manager(), - instance.sql_handler(), + &instance, ConcreteDataType::timestamp_millisecond_datatype(), ) .await @@ -203,19 +187,11 @@ async fn test_execute_insert() { #[tokio::test(flavor = "multi_thread")] async fn test_execute_insert_query_with_i64_timestamp() { - common_telemetry::init_default_ut_logging(); + let instance = MockInstance::new("insert_query_i64_timestamp").await; - let (opts, _guard) = test_util::create_tmp_dir_and_datanode_opts("insert_query_i64_timestamp"); - let instance = Instance::with_mock_meta_client(&opts).await.unwrap(); - instance.start().await.unwrap(); - - test_util::create_test_table( - instance.catalog_manager(), - instance.sql_handler(), - ConcreteDataType::int64_datatype(), - ) - .await - .unwrap(); + test_util::create_test_table(&instance, ConcreteDataType::int64_datatype()) + .await + .unwrap(); let output = execute_sql( &instance, @@ -262,9 +238,7 @@ async fn test_execute_insert_query_with_i64_timestamp() { #[tokio::test(flavor = "multi_thread")] async fn test_execute_query() { - let (opts, _guard) = test_util::create_tmp_dir_and_datanode_opts("execute_query"); - let instance = Instance::with_mock_meta_client(&opts).await.unwrap(); - instance.start().await.unwrap(); + let instance = MockInstance::new("execute_query").await; let output = execute_sql(&instance, "select sum(number) from numbers limit 20").await; match output { @@ -284,10 +258,7 @@ async fn test_execute_query() { #[tokio::test(flavor = "multi_thread")] async fn test_execute_show_databases_tables() { - let (opts, _guard) = - test_util::create_tmp_dir_and_datanode_opts("execute_show_databases_tables"); - let instance = Instance::with_mock_meta_client(&opts).await.unwrap(); - instance.start().await.unwrap(); + let instance = MockInstance::new("execute_show_databases_tables").await; let output = execute_sql(&instance, "show databases").await; match output { @@ -331,8 +302,7 @@ async fn test_execute_show_databases_tables() { // creat a table test_util::create_test_table( - instance.catalog_manager(), - instance.sql_handler(), + &instance, ConcreteDataType::timestamp_millisecond_datatype(), ) .await @@ -367,11 +337,7 @@ async fn test_execute_show_databases_tables() { #[tokio::test(flavor = "multi_thread")] pub async fn test_execute_create() { - common_telemetry::init_default_ut_logging(); - - let (opts, _guard) = test_util::create_tmp_dir_and_datanode_opts("execute_create"); - let instance = Instance::with_mock_meta_client(&opts).await.unwrap(); - instance.start().await.unwrap(); + let instance = MockInstance::new("execute_create").await; let output = execute_sql( &instance, @@ -480,9 +446,7 @@ async fn test_alter_table() { } async fn test_insert_with_default_value_for_type(type_name: &str) { - let (opts, _guard) = test_util::create_tmp_dir_and_datanode_opts("execute_create"); - let instance = Instance::with_mock_meta_client(&opts).await.unwrap(); - instance.start().await.unwrap(); + let instance = MockInstance::new("execute_create").await; let create_sql = format!( r#"create table test_table( @@ -527,17 +491,13 @@ async fn test_insert_with_default_value_for_type(type_name: &str) { #[tokio::test(flavor = "multi_thread")] async fn test_insert_with_default_value() { - common_telemetry::init_default_ut_logging(); - test_insert_with_default_value_for_type("timestamp").await; test_insert_with_default_value_for_type("bigint").await; } #[tokio::test(flavor = "multi_thread")] async fn test_use_database() { - let (opts, _guard) = test_util::create_tmp_dir_and_datanode_opts("use_database"); - let instance = Instance::with_mock_meta_client(&opts).await.unwrap(); - instance.start().await.unwrap(); + let instance = MockInstance::new("test_use_database").await; let output = execute_sql(&instance, "create database db1").await; assert!(matches!(output, Output::AffectedRows(1))); @@ -594,11 +554,11 @@ async fn test_use_database() { check_output_stream(output, expected).await; } -async fn execute_sql(instance: &Instance, sql: &str) -> Output { +async fn execute_sql(instance: &MockInstance, sql: &str) -> Output { execute_sql_in_db(instance, sql, DEFAULT_SCHEMA_NAME).await } -async fn execute_sql_in_db(instance: &Instance, sql: &str, db: &str) -> Output { +async fn execute_sql_in_db(instance: &MockInstance, sql: &str, db: &str) -> Output { let query_ctx = Arc::new(QueryContext::with_current_schema(db.to_string())); - instance.execute_sql(sql, query_ctx).await.unwrap() + instance.inner().execute_sql(sql, query_ctx).await.unwrap() } diff --git a/src/datanode/src/tests/test_util.rs b/src/datanode/src/tests/test_util.rs index 9fba0ff5f1..292402cf1f 100644 --- a/src/datanode/src/tests/test_util.rs +++ b/src/datanode/src/tests/test_util.rs @@ -15,7 +15,6 @@ use std::collections::HashMap; use std::sync::Arc; -use catalog::CatalogManagerRef; use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME, MIN_USER_TABLE_ID}; use datatypes::data_type::ConcreteDataType; use datatypes::schema::{ColumnSchema, SchemaBuilder}; @@ -30,16 +29,35 @@ use tempdir::TempDir; use crate::datanode::{DatanodeOptions, ObjectStoreConfig}; use crate::error::{CreateTableSnafu, Result}; +use crate::instance::Instance; use crate::sql::SqlHandler; -/// Create a tmp dir(will be deleted once it goes out of scope.) and a default `DatanodeOptions`, -/// Only for test. -pub struct TestGuard { +pub(crate) struct MockInstance { + instance: Instance, + _guard: TestGuard, +} + +impl MockInstance { + pub(crate) async fn new(name: &str) -> Self { + let (opts, _guard) = create_tmp_dir_and_datanode_opts(name); + + let instance = Instance::with_mock_meta_client(&opts).await.unwrap(); + instance.start().await.unwrap(); + + MockInstance { instance, _guard } + } + + pub(crate) fn inner(&self) -> &Instance { + &self.instance + } +} + +struct TestGuard { _wal_tmp_dir: TempDir, _data_tmp_dir: TempDir, } -pub fn create_tmp_dir_and_datanode_opts(name: &str) -> (DatanodeOptions, TestGuard) { +fn create_tmp_dir_and_datanode_opts(name: &str) -> (DatanodeOptions, TestGuard) { let wal_tmp_dir = TempDir::new(&format!("gt_wal_{name}")).unwrap(); let data_tmp_dir = TempDir::new(&format!("gt_data_{name}")).unwrap(); let opts = DatanodeOptions { @@ -59,9 +77,8 @@ pub fn create_tmp_dir_and_datanode_opts(name: &str) -> (DatanodeOptions, TestGua ) } -pub async fn create_test_table( - catalog_manager: &CatalogManagerRef, - sql_handler: &SqlHandler, +pub(crate) async fn create_test_table( + instance: &MockInstance, ts_type: ConcreteDataType, ) -> Result<()> { let column_schemas = vec![ @@ -72,7 +89,7 @@ pub async fn create_test_table( ]; let table_name = "demo"; - let table_engine: TableEngineRef = sql_handler.table_engine(); + let table_engine: TableEngineRef = instance.inner().sql_handler().table_engine(); let table = table_engine .create_table( &EngineContext::default(), @@ -97,11 +114,10 @@ pub async fn create_test_table( .await .context(CreateTableSnafu { table_name })?; - let schema_provider = catalog_manager - .catalog(DEFAULT_CATALOG_NAME) - .unwrap() - .unwrap() - .schema(DEFAULT_SCHEMA_NAME) + let schema_provider = instance + .inner() + .catalog_manager + .schema(DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME) .unwrap() .unwrap(); schema_provider diff --git a/src/frontend/Cargo.toml b/src/frontend/Cargo.toml index dc6b503cdc..00433eaf1e 100644 --- a/src/frontend/Cargo.toml +++ b/src/frontend/Cargo.toml @@ -7,7 +7,7 @@ license.workspace = true [dependencies] anymap = "1.0.0-beta.2" api = { path = "../api" } -async-stream = "0.3" +async-stream.workspace = true async-trait = "0.1" catalog = { path = "../catalog" } chrono = "0.4" diff --git a/src/log-store/Cargo.toml b/src/log-store/Cargo.toml index 829f769c86..b675b9549d 100644 --- a/src/log-store/Cargo.toml +++ b/src/log-store/Cargo.toml @@ -7,7 +7,7 @@ license.workspace = true [dependencies] async-trait.workspace = true arc-swap = "1.5" -async-stream = "0.3" +async-stream.workspace = true base64 = "0.13" byteorder = "1.4" bytes = "1.1" diff --git a/src/mito/Cargo.toml b/src/mito/Cargo.toml index 7ba19c6aae..af99129a05 100644 --- a/src/mito/Cargo.toml +++ b/src/mito/Cargo.toml @@ -10,7 +10,7 @@ test = ["tempdir"] [dependencies] arc-swap = "1.0" -async-stream = "0.3" +async-stream.workspace = true async-trait = "0.1" chrono = { version = "0.4", features = ["serde"] } common-catalog = { path = "../common/catalog" } diff --git a/src/storage/Cargo.toml b/src/storage/Cargo.toml index 794807a881..243eab5724 100644 --- a/src/storage/Cargo.toml +++ b/src/storage/Cargo.toml @@ -7,7 +7,7 @@ license.workspace = true [dependencies] arc-swap = "1.0" async-compat = "0.2" -async-stream = "0.3" +async-stream.workspace = true async-trait = "0.1" bytes = "1.1" common-base = { path = "../common/base" } diff --git a/src/store-api/Cargo.toml b/src/store-api/Cargo.toml index af44ba75e0..92da1d5c0d 100644 --- a/src/store-api/Cargo.toml +++ b/src/store-api/Cargo.toml @@ -18,6 +18,6 @@ serde.workspace = true snafu.workspace = true [dev-dependencies] -async-stream = "0.3" +async-stream.workspace = true serde_json = "1.0" tokio = { version = "1.0", features = ["full"] } diff --git a/tests-integration/tests/grpc.rs b/tests-integration/tests/grpc.rs index 0893e38703..7982706395 100644 --- a/tests-integration/tests/grpc.rs +++ b/tests-integration/tests/grpc.rs @@ -20,7 +20,7 @@ use api::v1::{ use client::admin::Admin; use client::{Client, Database, ObjectResult}; use common_catalog::consts::MIN_USER_TABLE_ID; -use common_grpc::flight::flight_messages_to_recordbatches; +use common_grpc::flight::{flight_messages_to_recordbatches, FlightMessage}; use servers::server::Server; use tests_integration::test_util::{setup_grpc_server, StorageType}; @@ -167,7 +167,7 @@ pub async fn test_insert_and_select(store_type: StorageType) { kind: Some(kind), }; let result = admin.alter(expr).await.unwrap(); - assert_eq!(result.result, None); + assert!(result.result.is_none()); // insert insert_and_assert(&db).await; @@ -195,11 +195,21 @@ async fn insert_and_assert(db: &Database) { let result = db.insert(expr).await; result.unwrap(); - // select let result = db - .select(client::Select::Sql("select * from demo".to_string())) + .sql( + "INSERT INTO demo(host, cpu, memory, ts) VALUES \ + ('host5', 66.6, 1024, 1672201027000),\ + ('host6', 88.8, 333.3, 1672201028000)", + ) .await .unwrap(); + assert!(matches!(result, ObjectResult::FlightData(_))); + let ObjectResult::FlightData(mut messages) = result else { unreachable!() }; + assert_eq!(messages.len(), 1); + assert!(matches!(messages.remove(0), FlightMessage::AffectedRows(2))); + + // select + let result = db.sql("SELECT * FROM demo").await.unwrap(); match result { ObjectResult::FlightData(flight_messages) => { let recordbatches = flight_messages_to_recordbatches(flight_messages).unwrap(); @@ -212,6 +222,8 @@ async fn insert_and_assert(db: &Database) { | host2 | | 0.2 | 1970-01-01T00:00:00.101 | | host3 | 0.41 | | 1970-01-01T00:00:00.102 | | host4 | 0.2 | 0.3 | 1970-01-01T00:00:00.103 | +| host5 | 66.6 | 1024 | 2022-12-28T04:17:07 | +| host6 | 88.8 | 333.3 | 2022-12-28T04:17:08 | +-------+------+--------+-------------------------+\ "; assert_eq!(pretty, expected); diff --git a/tests/runner/src/env.rs b/tests/runner/src/env.rs index 67eb0fd822..5fb0073a02 100644 --- a/tests/runner/src/env.rs +++ b/tests/runner/src/env.rs @@ -18,8 +18,9 @@ use std::process::Stdio; use std::time::Duration; use async_trait::async_trait; -use client::{Client, Database as DB, Error as ClientError, ObjectResult, Select}; +use client::{Client, Database as DB, Error as ClientError, ObjectResult}; use common_grpc::flight; +use common_grpc::flight::FlightMessage; use sqlness::{Database, Environment}; use tokio::process::{Child, Command}; @@ -112,8 +113,7 @@ pub struct GreptimeDB { #[async_trait] impl Database for GreptimeDB { async fn query(&self, query: String) -> Box { - let sql = Select::Sql(query); - let result = self.db.select(sql).await; + let result = self.db.sql(&query).await; Box::new(ResultDisplayer { result }) as _ } } @@ -130,12 +130,19 @@ impl Display for ResultDisplayer { write!(f, "{mutate_result:?}") } ObjectResult::FlightData(messages) => { - let pretty = flight::flight_messages_to_recordbatches(messages.clone()) - .map_err(|e| e.to_string()) - .and_then(|x| x.pretty_print().map_err(|e| e.to_string())); - match pretty { - Ok(s) => write!(f, "{s}"), - Err(e) => write!(f, "format result error: {e}"), + if let Some(FlightMessage::AffectedRows(rows)) = messages.get(0) { + write!(f, "Affected Rows: {rows}") + } else { + let pretty = flight::flight_messages_to_recordbatches(messages.clone()) + .map_err(|e| e.to_string()) + .and_then(|x| x.pretty_print().map_err(|e| e.to_string())); + match pretty { + Ok(s) => write!(f, "{s}"), + Err(e) => write!( + f, + "Failed to convert Flight messages {messages:?} to Recordbatches, error: {e}" + ), + } } } },