fix: further ease the restriction on executing SQLs in the new GRPC interface (#797)

* fix: carry non-recordbatch results in FlightData, to allow executing SQLs other than selections in the new GRPC interface

* Update src/datanode/src/instance/flight/stream.rs

Co-authored-by: Jiachun Feng <jiachun_feng@proton.me>
Author: LFC
Date: 2022-12-28 16:43:21 +08:00 (committed by GitHub)
Parent: 76236646ef
Commit: 04df80e640
22 changed files with 257 additions and 172 deletions
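
For context before the per-file changes: with this commit, the gRPC client can submit any SQL statement over Arrow Flight, and a non-query statement comes back as a single affected-rows message instead of being rejected. A minimal sketch of the intended client-side usage (types and method names as introduced in this diff; the endpoint address is a placeholder):

```rust
use client::{Client, Database, ObjectResult};
use common_grpc::flight::FlightMessage;

#[tokio::main]
async fn main() {
    // Placeholder endpoint; any running datanode gRPC address works.
    let client = Client::with_urls(vec!["127.0.0.1:3001"]);
    let db = Database::new("greptime", client);

    // Non-SELECT statements now travel over the same Flight interface
    // and come back as a single AffectedRows message.
    let result = db
        .sql("INSERT INTO demo(host, cpu, memory, ts) VALUES ('host1', 66.6, 1024, 1672201025000)")
        .await
        .unwrap();
    if let ObjectResult::FlightData(messages) = result {
        if let Some(FlightMessage::AffectedRows(rows)) = messages.first() {
            println!("affected rows: {rows}");
        }
    }
}
```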

Cargo.lock (generated)

@@ -2137,6 +2137,7 @@ version = "0.1.0"
dependencies = [
"api",
"arrow-flight",
"async-stream",
"async-trait",
"axum",
"axum-macros",
@@ -2157,6 +2158,7 @@ dependencies = [
"datafusion",
"datafusion-common",
"datatypes",
"flatbuffers",
"futures",
"hyper",
"log-store",


@@ -48,6 +48,7 @@ license = "Apache-2.0"
arrow = "29.0"
arrow-flight = "29.0"
arrow-schema = { version = "29.0", features = ["serde"] }
async-stream = "0.3"
async-trait = "0.1"
# TODO(LFC): Use released Datafusion when it officially depends on Arrow 29.0
datafusion = { git = "https://github.com/apache/arrow-datafusion.git", rev = "4917235a398ae20145c87d20984e6367dc1a0c1e" }


@@ -28,7 +28,7 @@ use clap::Parser;
use client::admin::Admin;
use client::api::v1::column::Values;
use client::api::v1::{Column, ColumnDataType, ColumnDef, CreateTableExpr, InsertExpr, TableId};
use client::{Client, Database, Select};
use client::{Client, Database};
use indicatif::{MultiProgress, ProgressBar, ProgressStyle};
use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;
use tokio::task::JoinSet;
@@ -405,7 +405,7 @@ async fn do_query(num_iter: usize, db: &Database) {
println!("Running query: {query}");
for i in 0..num_iter {
let now = Instant::now();
let _res = db.select(Select::Sql(query.clone())).await.unwrap();
let _res = db.sql(&query).await.unwrap();
let elapsed = now.elapsed();
println!(
"query {}, iteration {}: {}ms",


@@ -56,3 +56,7 @@ message ObjectResult {
message FlightDataRaw {
repeated bytes raw_data = 1;
}
message FlightDataExt {
uint32 affected_rows = 1;
}
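
`FlightDataExt` is deliberately tiny: it is prost-encoded into the `data_body` of a `FlightData` whose IPC header type is `NONE` (see the datanode changes below). A hedged sketch of the round-trip, with the struct hand-written here to mirror what prost-build would generate from the message above:

```rust
use prost::Message;

// Hand-written mirror of the prost-generated FlightDataExt (sketch).
#[derive(Clone, PartialEq, Message)]
struct FlightDataExt {
    #[prost(uint32, tag = "1")]
    affected_rows: u32,
}

fn main() {
    // Datanode side: encode the affected-rows count...
    let body = FlightDataExt { affected_rows: 2 }.encode_to_vec();
    // ...client side: decode it straight from FlightData::data_body.
    let decoded = FlightDataExt::decode(body.as_slice()).unwrap();
    assert_eq!(decoded.affected_rows, 2);
}
```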


@@ -7,7 +7,7 @@ license.workspace = true
[dependencies]
api = { path = "../api" }
arc-swap = "1.0"
async-stream = "0.3"
async-stream.workspace = true
async-trait = "0.1"
backoff = { version = "0.4", features = ["tokio"] }
common-catalog = { path = "../common/catalog" }


@@ -6,7 +6,7 @@ license.workspace = true
[dependencies]
api = { path = "../api" }
async-stream = "0.3"
async-stream.workspace = true
common-base = { path = "../common/base" }
common-error = { path = "../common/error" }
common-grpc = { path = "../common/grpc" }


@@ -1,34 +0,0 @@
// Copyright 2022 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use client::{Client, Database, Select};
use tracing::{event, Level};
fn main() {
tracing::subscriber::set_global_default(tracing_subscriber::FmtSubscriber::builder().finish())
.unwrap();
run();
}
#[tokio::main]
async fn run() {
let client = Client::with_urls(vec!["127.0.0.1:3001"]);
let db = Database::new("greptime", client);
let sql = Select::Sql("select * from demo".to_string());
let result = db.select(sql).await.unwrap();
event!(Level::INFO, "result: {:#?}", result);
}


@@ -73,13 +73,11 @@ impl Database {
.collect()
}
pub async fn select(&self, expr: Select) -> Result<ObjectResult> {
let select_expr = match expr {
Select::Sql(sql) => QueryRequest {
query: Some(query_request::Query::Sql(sql)),
},
pub async fn sql(&self, sql: &str) -> Result<ObjectResult> {
let query = QueryRequest {
query: Some(query_request::Query::Sql(sql.to_string())),
};
self.do_select(select_expr).await
self.do_select(query).await
}
pub async fn logical_plan(&self, logical_plan: Vec<u8>) -> Result<ObjectResult> {
@@ -166,10 +164,6 @@ impl TryFrom<api::v1::ObjectResult> for ObjectResult {
}
}
pub enum Select {
Sql(String),
}
impl TryFrom<ObjectResult> for Output {
type Error = error::Error;
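
With the `Select` enum gone, call sites shrink accordingly; a minimal sketch of the narrowed API (assuming an already connected `Database` handle):

```rust
use client::{Database, ObjectResult, Result};

// Before: db.select(Select::Sql("select * from demo".to_string())).await
// After: no enum wrapper, the SQL text is passed directly.
async fn run(db: &Database) -> Result<ObjectResult> {
    db.sql("select * from demo").await
}
```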


@@ -21,5 +21,5 @@ pub mod load_balance;
pub use api;
pub use self::client::Client;
pub use self::database::{Database, ObjectResult, Select};
pub use self::database::{Database, ObjectResult};
pub use self::error::{Error, Result};


@@ -17,7 +17,7 @@ use std::pin::Pin;
use std::sync::Arc;
use api::result::ObjectResultBuilder;
use api::v1::ObjectResult;
use api::v1::{FlightDataExt, ObjectResult};
use arrow_flight::utils::flight_data_to_arrow_batch;
use arrow_flight::FlightData;
use common_error::prelude::StatusCode;
@@ -31,7 +31,10 @@ use snafu::{OptionExt, ResultExt};
use tonic::codegen::futures_core::Stream;
use tonic::Response;
use crate::error::{self, InvalidFlightDataSnafu, Result};
use crate::error::{
ConvertArrowSchemaSnafu, CreateRecordBatchSnafu, DecodeFlightDataSnafu, InvalidFlightDataSnafu,
Result,
};
type TonicResult<T> = std::result::Result<T, tonic::Status>;
type TonicStream<T> = Pin<Box<dyn Stream<Item = TonicResult<T>> + Send + Sync + 'static>>;
@@ -40,6 +43,7 @@ type TonicStream<T> = Pin<Box<dyn Stream<Item = TonicResult<T>> + Send + Sync +
pub enum FlightMessage {
Schema(SchemaRef),
Recordbatch(RecordBatch),
AffectedRows(usize),
}
#[derive(Default)]
@@ -56,6 +60,11 @@ impl FlightDecoder {
.build()
})?;
match message.header_type() {
MessageHeader::NONE => {
let ext_data = FlightDataExt::decode(flight_data.data_body.as_slice())
.context(DecodeFlightDataSnafu)?;
Ok(FlightMessage::AffectedRows(ext_data.affected_rows as _))
}
MessageHeader::Schema => {
let arrow_schema = ArrowSchema::try_from(&flight_data).map_err(|e| {
InvalidFlightDataSnafu {
@@ -63,9 +72,8 @@ impl FlightDecoder {
}
.build()
})?;
let schema = Arc::new(
Schema::try_from(arrow_schema).context(error::ConvertArrowSchemaSnafu)?,
);
let schema =
Arc::new(Schema::try_from(arrow_schema).context(ConvertArrowSchemaSnafu)?);
self.schema = Some(schema.clone());
@@ -86,7 +94,7 @@ impl FlightDecoder {
.build()
})?;
let recordbatch = RecordBatch::try_from_df_record_batch(schema, arrow_batch)
.context(error::CreateRecordBatchSnafu)?;
.context(CreateRecordBatchSnafu)?;
Ok(FlightMessage::Recordbatch(recordbatch))
}
other => {
@@ -127,7 +135,7 @@ pub async fn flight_data_to_object_result(
pub fn raw_flight_data_to_message(raw_data: Vec<Vec<u8>>) -> Result<Vec<FlightMessage>> {
let flight_data = raw_data
.into_iter()
.map(|x| FlightData::decode(x.as_slice()).context(error::DecodeFlightDataSnafu))
.map(|x| FlightData::decode(x.as_slice()).context(DecodeFlightDataSnafu))
.collect::<Result<Vec<FlightData>>>()?;
let decoder = &mut FlightDecoder::default();
@@ -165,7 +173,7 @@ pub fn flight_messages_to_recordbatches(messages: Vec<FlightMessage>) -> Result<
}
}
RecordBatches::try_new(schema, recordbatches).context(error::CreateRecordBatchSnafu)
RecordBatches::try_new(schema, recordbatches).context(CreateRecordBatchSnafu)
}
}
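
A decoded payload can now contain three message kinds; a hedged sketch of how a consumer might drain one (helper names as exported by this module):

```rust
use common_grpc::flight::{self, FlightMessage};

// Sketch: a SELECT yields one Schema followed by Recordbatch messages,
// while a DML statement yields a single AffectedRows.
fn summarize(raw_data: Vec<Vec<u8>>) {
    let messages = flight::raw_flight_data_to_message(raw_data).unwrap();
    for message in messages {
        match message {
            FlightMessage::Schema(_) => println!("schema"),
            FlightMessage::Recordbatch(_) => println!("recordbatch"),
            FlightMessage::AffectedRows(rows) => println!("affected rows: {rows}"),
        }
    }
}
```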


@@ -9,6 +9,7 @@ default = ["python"]
python = ["dep:script"]
[dependencies]
async-stream.workspace = true
async-trait.workspace = true
api = { path = "../api" }
arrow-flight.workspace = true
@@ -28,6 +29,7 @@ common-telemetry = { path = "../common/telemetry" }
common-time = { path = "../common/time" }
datafusion.workspace = true
datatypes = { path = "../datatypes" }
flatbuffers = "22"
futures = "0.3"
hyper = { version = "0.14", features = ["full"] }
log-store = { path = "../log-store" }


@@ -18,14 +18,16 @@ use std::pin::Pin;
use api::v1::object_expr::Expr;
use api::v1::query_request::Query;
use api::v1::ObjectExpr;
use api::v1::{FlightDataExt, ObjectExpr};
use arrow_flight::flight_service_server::FlightService;
use arrow_flight::{
Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo,
HandshakeRequest, HandshakeResponse, PutResult, SchemaResult, Ticket,
HandshakeRequest, HandshakeResponse, IpcMessage, PutResult, SchemaResult, Ticket,
};
use async_trait::async_trait;
use common_query::Output;
use datatypes::arrow;
use flatbuffers::FlatBufferBuilder;
use futures::Stream;
use prost::Message;
use session::context::QueryContext;
@@ -33,7 +35,7 @@ use snafu::{OptionExt, ResultExt};
use tonic::{Request, Response, Streaming};
use crate::error::{self, Result};
use crate::instance::flight::stream::GetStream;
use crate::instance::flight::stream::FlightRecordBatchStream;
use crate::instance::Instance;
type TonicResult<T> = std::result::Result<T, tonic::Status>;
@@ -85,7 +87,7 @@ impl FlightService for Instance {
.query
.context(error::MissingRequiredFieldSnafu { name: "expr" })?;
let stream = self.handle_query(query).await?;
Ok(Response::new(Box::pin(stream) as TonicStream<FlightData>))
Ok(Response::new(stream))
}
// TODO(LFC): Implement Insertion Flight interface.
Expr::Insert(_) => Err(tonic::Status::unimplemented("Not yet implemented")),
@@ -130,7 +132,7 @@ impl FlightService for Instance {
}
impl Instance {
async fn handle_query(&self, query: Query) -> Result<GetStream> {
async fn handle_query(&self, query: Query) -> Result<TonicStream<FlightData>> {
let output = match query {
Query::Sql(sql) => {
let stmt = self
@@ -141,14 +143,125 @@ impl Instance {
}
Query::LogicalPlan(plan) => self.execute_logical(plan).await?,
};
let recordbatch_stream = match output {
Output::Stream(stream) => stream,
Output::RecordBatches(x) => x.as_stream(),
Output::AffectedRows(_) => {
unreachable!("SELECT should not have returned affected rows!")
Ok(match output {
Output::Stream(stream) => {
let stream = FlightRecordBatchStream::new(stream);
Box::pin(stream) as _
}
};
Ok(GetStream::new(recordbatch_stream))
Output::RecordBatches(x) => {
let stream = FlightRecordBatchStream::new(x.as_stream());
Box::pin(stream) as _
}
Output::AffectedRows(rows) => {
let stream = async_stream::stream! {
let ext_data = FlightDataExt {
affected_rows: rows as _,
}.encode_to_vec();
yield Ok(FlightData::new(None, IpcMessage(build_none_flight_msg()), vec![], ext_data))
};
Box::pin(stream) as _
}
})
}
}
fn build_none_flight_msg() -> Vec<u8> {
let mut builder = FlatBufferBuilder::new();
let mut message = arrow::ipc::MessageBuilder::new(&mut builder);
message.add_version(arrow::ipc::MetadataVersion::V5);
message.add_header_type(arrow::ipc::MessageHeader::NONE);
message.add_bodyLength(0);
let data = message.finish();
builder.finish(data, None);
builder.finished_data().to_vec()
}
#[cfg(test)]
mod test {
use api::v1::{object_result, FlightDataRaw, QueryRequest};
use common_grpc::flight;
use common_grpc::flight::FlightMessage;
use datatypes::prelude::*;
use super::*;
use crate::tests::test_util::{self, MockInstance};
#[tokio::test(flavor = "multi_thread")]
async fn test_handle_query() {
let instance = MockInstance::new("test_handle_query").await;
test_util::create_test_table(
&instance,
ConcreteDataType::timestamp_millisecond_datatype(),
)
.await
.unwrap();
let ticket = Request::new(Ticket {
ticket: ObjectExpr {
header: None,
expr: Some(Expr::Query(QueryRequest {
query: Some(Query::Sql(
"INSERT INTO demo(host, cpu, memory, ts) VALUES \
('host1', 66.6, 1024, 1672201025000),\
('host2', 88.8, 333.3, 1672201026000)"
.to_string(),
)),
})),
}
.encode_to_vec(),
});
let response = instance.inner().do_get(ticket).await.unwrap();
let result = flight::flight_data_to_object_result(response)
.await
.unwrap();
let result = result.result.unwrap();
assert!(matches!(result, object_result::Result::FlightData(_)));
let object_result::Result::FlightData(FlightDataRaw { raw_data }) = result else { unreachable!() };
let mut messages = flight::raw_flight_data_to_message(raw_data).unwrap();
assert_eq!(messages.len(), 1);
let message = messages.remove(0);
assert!(matches!(message, FlightMessage::AffectedRows(_)));
let FlightMessage::AffectedRows(affected_rows) = message else { unreachable!() };
assert_eq!(affected_rows, 2);
let ticket = Request::new(Ticket {
ticket: ObjectExpr {
header: None,
expr: Some(Expr::Query(QueryRequest {
query: Some(Query::Sql(
"SELECT ts, host, cpu, memory FROM demo".to_string(),
)),
})),
}
.encode_to_vec(),
});
let response = instance.inner().do_get(ticket).await.unwrap();
let result = flight::flight_data_to_object_result(response)
.await
.unwrap();
let result = result.result.unwrap();
assert!(matches!(result, object_result::Result::FlightData(_)));
let object_result::Result::FlightData(FlightDataRaw { raw_data }) = result else { unreachable!() };
let messages = flight::raw_flight_data_to_message(raw_data).unwrap();
assert_eq!(messages.len(), 2);
let recordbatch = flight::flight_messages_to_recordbatches(messages).unwrap();
let expected = "\
+---------------------+-------+------+--------+
| ts | host | cpu | memory |
+---------------------+-------+------+--------+
| 2022-12-28T04:17:05 | host1 | 66.6 | 1024 |
| 2022-12-28T04:17:06 | host2 | 88.8 | 333.3 |
+---------------------+-------+------+--------+";
let actual = recordbatch.pretty_print().unwrap();
assert_eq!(actual, expected);
}
}
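
The `build_none_flight_msg` helper above fabricates an Arrow IPC `Message` envelope whose header type is `NONE` and whose body length is zero, so the affected-rows payload in `data_body` cannot be mistaken for a schema or a record batch. A hedged sketch of the check that the decoder side (common-grpc, above) effectively performs on the header, assuming arrow's generated `root_as_message` accessor:

```rust
use datatypes::arrow::ipc;

// Sketch: parse the flatbuffer header built by build_none_flight_msg and
// confirm it is the NONE variant that FlightDecoder dispatches on.
fn is_none_header(data_header: &[u8]) -> bool {
    ipc::root_as_message(data_header)
        .map(|message| message.header_type() == ipc::MessageHeader::NONE)
        .unwrap_or(false)
}
```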


@@ -31,14 +31,14 @@ use crate::error;
use crate::instance::flight::TonicResult;
#[pin_project(PinnedDrop)]
pub(super) struct GetStream {
pub(super) struct FlightRecordBatchStream {
#[pin]
rx: mpsc::Receiver<Result<FlightData, tonic::Status>>,
join_handle: JoinHandle<()>,
done: bool,
}
impl GetStream {
impl FlightRecordBatchStream {
pub(super) fn new(recordbatches: SendableRecordBatchStream) -> Self {
let (tx, rx) = mpsc::channel::<TonicResult<FlightData>>(1);
let join_handle =
@@ -96,13 +96,13 @@ impl GetStream {
}
#[pinned_drop]
impl PinnedDrop for GetStream {
impl PinnedDrop for FlightRecordBatchStream {
fn drop(self: Pin<&mut Self>) {
self.join_handle.abort();
}
}
impl Stream for GetStream {
impl Stream for FlightRecordBatchStream {
type Item = TonicResult<FlightData>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
@@ -139,7 +139,7 @@ mod test {
use super::*;
#[tokio::test]
async fn test_get_stream() {
async fn test_flight_record_batch_stream() {
let schema = Arc::new(Schema::new(vec![ColumnSchema::new(
"a",
ConcreteDataType::int32_datatype(),
@@ -152,13 +152,13 @@ mod test {
let recordbatches = RecordBatches::try_new(schema.clone(), vec![recordbatch.clone()])
.unwrap()
.as_stream();
let mut get_stream = GetStream::new(recordbatches);
let mut stream = FlightRecordBatchStream::new(recordbatches);
let mut raw_data = Vec::with_capacity(2);
raw_data.push(get_stream.next().await.unwrap().unwrap());
raw_data.push(get_stream.next().await.unwrap().unwrap());
assert!(get_stream.next().await.is_none());
assert!(get_stream.done);
raw_data.push(stream.next().await.unwrap().unwrap());
raw_data.push(stream.next().await.unwrap().unwrap());
assert!(stream.next().await.is_none());
assert!(stream.done);
let decoder = &mut FlightDecoder::default();
let mut flight_messages = raw_data


@@ -21,17 +21,11 @@ use datatypes::data_type::ConcreteDataType;
use datatypes::vectors::{Int64Vector, StringVector, UInt64Vector, VectorRef};
use session::context::QueryContext;
use crate::instance::Instance;
use crate::tests::test_util;
use crate::tests::test_util::{self, MockInstance};
#[tokio::test(flavor = "multi_thread")]
async fn test_create_database_and_insert_query() {
common_telemetry::init_default_ut_logging();
let (opts, _guard) =
test_util::create_tmp_dir_and_datanode_opts("create_database_and_insert_query");
let instance = Instance::with_mock_meta_client(&opts).await.unwrap();
instance.start().await.unwrap();
let instance = MockInstance::new("create_database_and_insert_query").await;
let output = execute_sql(&instance, "create database test").await;
assert!(matches!(output, Output::AffectedRows(1)));
@@ -77,12 +71,7 @@ async fn test_create_database_and_insert_query() {
}
#[tokio::test(flavor = "multi_thread")]
async fn test_issue477_same_table_name_in_different_databases() {
common_telemetry::init_default_ut_logging();
let (opts, _guard) =
test_util::create_tmp_dir_and_datanode_opts("create_database_and_insert_query");
let instance = Instance::with_mock_meta_client(&opts).await.unwrap();
instance.start().await.unwrap();
let instance = MockInstance::new("test_issue477_same_table_name_in_different_databases").await;
// Create database a and b
let output = execute_sql(&instance, "create database a").await;
@@ -149,7 +138,7 @@ async fn test_issue477_same_table_name_in_different_databases() {
.await;
}
async fn assert_query_result(instance: &Instance, sql: &str, ts: i64, host: &str) {
async fn assert_query_result(instance: &MockInstance, sql: &str, ts: i64, host: &str) {
let query_output = execute_sql(instance, sql).await;
match query_output {
Output::Stream(s) => {
@@ -169,16 +158,11 @@ async fn assert_query_result(instance: &Instance, sql: &str, ts: i64, host: &str
}
}
async fn setup_test_instance(test_name: &str) -> Instance {
common_telemetry::init_default_ut_logging();
let (opts, _guard) = test_util::create_tmp_dir_and_datanode_opts(test_name);
let instance = Instance::with_mock_meta_client(&opts).await.unwrap();
instance.start().await.unwrap();
async fn setup_test_instance(test_name: &str) -> MockInstance {
let instance = MockInstance::new(test_name).await;
test_util::create_test_table(
instance.catalog_manager(),
instance.sql_handler(),
&instance,
ConcreteDataType::timestamp_millisecond_datatype(),
)
.await
@@ -203,19 +187,11 @@ async fn test_execute_insert() {
#[tokio::test(flavor = "multi_thread")]
async fn test_execute_insert_query_with_i64_timestamp() {
common_telemetry::init_default_ut_logging();
let instance = MockInstance::new("insert_query_i64_timestamp").await;
let (opts, _guard) = test_util::create_tmp_dir_and_datanode_opts("insert_query_i64_timestamp");
let instance = Instance::with_mock_meta_client(&opts).await.unwrap();
instance.start().await.unwrap();
test_util::create_test_table(
instance.catalog_manager(),
instance.sql_handler(),
ConcreteDataType::int64_datatype(),
)
.await
.unwrap();
test_util::create_test_table(&instance, ConcreteDataType::int64_datatype())
.await
.unwrap();
let output = execute_sql(
&instance,
@@ -262,9 +238,7 @@ async fn test_execute_insert_query_with_i64_timestamp() {
#[tokio::test(flavor = "multi_thread")]
async fn test_execute_query() {
let (opts, _guard) = test_util::create_tmp_dir_and_datanode_opts("execute_query");
let instance = Instance::with_mock_meta_client(&opts).await.unwrap();
instance.start().await.unwrap();
let instance = MockInstance::new("execute_query").await;
let output = execute_sql(&instance, "select sum(number) from numbers limit 20").await;
match output {
@@ -284,10 +258,7 @@ async fn test_execute_query() {
#[tokio::test(flavor = "multi_thread")]
async fn test_execute_show_databases_tables() {
let (opts, _guard) =
test_util::create_tmp_dir_and_datanode_opts("execute_show_databases_tables");
let instance = Instance::with_mock_meta_client(&opts).await.unwrap();
instance.start().await.unwrap();
let instance = MockInstance::new("execute_show_databases_tables").await;
let output = execute_sql(&instance, "show databases").await;
match output {
@@ -331,8 +302,7 @@ async fn test_execute_show_databases_tables() {
// create a table
test_util::create_test_table(
instance.catalog_manager(),
instance.sql_handler(),
&instance,
ConcreteDataType::timestamp_millisecond_datatype(),
)
.await
@@ -367,11 +337,7 @@ async fn test_execute_show_databases_tables() {
#[tokio::test(flavor = "multi_thread")]
pub async fn test_execute_create() {
common_telemetry::init_default_ut_logging();
let (opts, _guard) = test_util::create_tmp_dir_and_datanode_opts("execute_create");
let instance = Instance::with_mock_meta_client(&opts).await.unwrap();
instance.start().await.unwrap();
let instance = MockInstance::new("execute_create").await;
let output = execute_sql(
&instance,
@@ -480,9 +446,7 @@ async fn test_alter_table() {
}
async fn test_insert_with_default_value_for_type(type_name: &str) {
let (opts, _guard) = test_util::create_tmp_dir_and_datanode_opts("execute_create");
let instance = Instance::with_mock_meta_client(&opts).await.unwrap();
instance.start().await.unwrap();
let instance = MockInstance::new("execute_create").await;
let create_sql = format!(
r#"create table test_table(
@@ -527,17 +491,13 @@ async fn test_insert_with_default_value_for_type(type_name: &str) {
#[tokio::test(flavor = "multi_thread")]
async fn test_insert_with_default_value() {
common_telemetry::init_default_ut_logging();
test_insert_with_default_value_for_type("timestamp").await;
test_insert_with_default_value_for_type("bigint").await;
}
#[tokio::test(flavor = "multi_thread")]
async fn test_use_database() {
let (opts, _guard) = test_util::create_tmp_dir_and_datanode_opts("use_database");
let instance = Instance::with_mock_meta_client(&opts).await.unwrap();
instance.start().await.unwrap();
let instance = MockInstance::new("test_use_database").await;
let output = execute_sql(&instance, "create database db1").await;
assert!(matches!(output, Output::AffectedRows(1)));
@@ -594,11 +554,11 @@ async fn test_use_database() {
check_output_stream(output, expected).await;
}
async fn execute_sql(instance: &Instance, sql: &str) -> Output {
async fn execute_sql(instance: &MockInstance, sql: &str) -> Output {
execute_sql_in_db(instance, sql, DEFAULT_SCHEMA_NAME).await
}
async fn execute_sql_in_db(instance: &Instance, sql: &str, db: &str) -> Output {
async fn execute_sql_in_db(instance: &MockInstance, sql: &str, db: &str) -> Output {
let query_ctx = Arc::new(QueryContext::with_current_schema(db.to_string()));
instance.execute_sql(sql, query_ctx).await.unwrap()
instance.inner().execute_sql(sql, query_ctx).await.unwrap()
}


@@ -15,7 +15,6 @@
use std::collections::HashMap;
use std::sync::Arc;
use catalog::CatalogManagerRef;
use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME, MIN_USER_TABLE_ID};
use datatypes::data_type::ConcreteDataType;
use datatypes::schema::{ColumnSchema, SchemaBuilder};
@@ -30,16 +29,35 @@ use tempdir::TempDir;
use crate::datanode::{DatanodeOptions, ObjectStoreConfig};
use crate::error::{CreateTableSnafu, Result};
use crate::instance::Instance;
use crate::sql::SqlHandler;
/// Create a tmp dir (deleted once it goes out of scope) and a default `DatanodeOptions`.
/// Only for tests.
pub struct TestGuard {
pub(crate) struct MockInstance {
instance: Instance,
_guard: TestGuard,
}
impl MockInstance {
pub(crate) async fn new(name: &str) -> Self {
let (opts, _guard) = create_tmp_dir_and_datanode_opts(name);
let instance = Instance::with_mock_meta_client(&opts).await.unwrap();
instance.start().await.unwrap();
MockInstance { instance, _guard }
}
pub(crate) fn inner(&self) -> &Instance {
&self.instance
}
}
struct TestGuard {
_wal_tmp_dir: TempDir,
_data_tmp_dir: TempDir,
}
pub fn create_tmp_dir_and_datanode_opts(name: &str) -> (DatanodeOptions, TestGuard) {
fn create_tmp_dir_and_datanode_opts(name: &str) -> (DatanodeOptions, TestGuard) {
let wal_tmp_dir = TempDir::new(&format!("gt_wal_{name}")).unwrap();
let data_tmp_dir = TempDir::new(&format!("gt_data_{name}")).unwrap();
let opts = DatanodeOptions {
@@ -59,9 +77,8 @@ pub fn create_tmp_dir_and_datanode_opts(name: &str) -> (DatanodeOptions, TestGua
)
}
pub async fn create_test_table(
catalog_manager: &CatalogManagerRef,
sql_handler: &SqlHandler,
pub(crate) async fn create_test_table(
instance: &MockInstance,
ts_type: ConcreteDataType,
) -> Result<()> {
let column_schemas = vec![
@@ -72,7 +89,7 @@ pub async fn create_test_table(
];
let table_name = "demo";
let table_engine: TableEngineRef = sql_handler.table_engine();
let table_engine: TableEngineRef = instance.inner().sql_handler().table_engine();
let table = table_engine
.create_table(
&EngineContext::default(),
@@ -97,11 +114,10 @@ pub async fn create_test_table(
.await
.context(CreateTableSnafu { table_name })?;
let schema_provider = catalog_manager
.catalog(DEFAULT_CATALOG_NAME)
.unwrap()
.unwrap()
.schema(DEFAULT_SCHEMA_NAME)
let schema_provider = instance
.inner()
.catalog_manager
.schema(DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME)
.unwrap()
.unwrap();
schema_provider
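
Test setup collapses into the mock; a sketch mirroring the call sites updated elsewhere in this diff (the test name is illustrative):

```rust
use datatypes::data_type::ConcreteDataType;

use crate::tests::test_util::{self, MockInstance};

// Sketch: one constructor starts a fully wired Instance over temp dirs;
// create_test_table now takes the mock instead of separate handles.
#[tokio::test(flavor = "multi_thread")]
async fn uses_mock_instance() {
    let instance = MockInstance::new("uses_mock_instance").await;
    test_util::create_test_table(&instance, ConcreteDataType::timestamp_millisecond_datatype())
        .await
        .unwrap();
    // ...exercise instance.inner()...
}
```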


@@ -7,7 +7,7 @@ license.workspace = true
[dependencies]
anymap = "1.0.0-beta.2"
api = { path = "../api" }
async-stream = "0.3"
async-stream.workspace = true
async-trait = "0.1"
catalog = { path = "../catalog" }
chrono = "0.4"


@@ -7,7 +7,7 @@ license.workspace = true
[dependencies]
async-trait.workspace = true
arc-swap = "1.5"
async-stream = "0.3"
async-stream.workspace = true
base64 = "0.13"
byteorder = "1.4"
bytes = "1.1"


@@ -10,7 +10,7 @@ test = ["tempdir"]
[dependencies]
arc-swap = "1.0"
async-stream = "0.3"
async-stream.workspace = true
async-trait = "0.1"
chrono = { version = "0.4", features = ["serde"] }
common-catalog = { path = "../common/catalog" }


@@ -7,7 +7,7 @@ license.workspace = true
[dependencies]
arc-swap = "1.0"
async-compat = "0.2"
async-stream = "0.3"
async-stream.workspace = true
async-trait = "0.1"
bytes = "1.1"
common-base = { path = "../common/base" }


@@ -18,6 +18,6 @@ serde.workspace = true
snafu.workspace = true
[dev-dependencies]
async-stream = "0.3"
async-stream.workspace = true
serde_json = "1.0"
tokio = { version = "1.0", features = ["full"] }


@@ -20,7 +20,7 @@ use api::v1::{
use client::admin::Admin;
use client::{Client, Database, ObjectResult};
use common_catalog::consts::MIN_USER_TABLE_ID;
use common_grpc::flight::flight_messages_to_recordbatches;
use common_grpc::flight::{flight_messages_to_recordbatches, FlightMessage};
use servers::server::Server;
use tests_integration::test_util::{setup_grpc_server, StorageType};
@@ -167,7 +167,7 @@ pub async fn test_insert_and_select(store_type: StorageType) {
kind: Some(kind),
};
let result = admin.alter(expr).await.unwrap();
assert_eq!(result.result, None);
assert!(result.result.is_none());
// insert
insert_and_assert(&db).await;
@@ -195,11 +195,21 @@ async fn insert_and_assert(db: &Database) {
let result = db.insert(expr).await;
result.unwrap();
// select
let result = db
.select(client::Select::Sql("select * from demo".to_string()))
.sql(
"INSERT INTO demo(host, cpu, memory, ts) VALUES \
('host5', 66.6, 1024, 1672201027000),\
('host6', 88.8, 333.3, 1672201028000)",
)
.await
.unwrap();
assert!(matches!(result, ObjectResult::FlightData(_)));
let ObjectResult::FlightData(mut messages) = result else { unreachable!() };
assert_eq!(messages.len(), 1);
assert!(matches!(messages.remove(0), FlightMessage::AffectedRows(2)));
// select
let result = db.sql("SELECT * FROM demo").await.unwrap();
match result {
ObjectResult::FlightData(flight_messages) => {
let recordbatches = flight_messages_to_recordbatches(flight_messages).unwrap();
@@ -212,6 +222,8 @@ async fn insert_and_assert(db: &Database) {
| host2 | | 0.2 | 1970-01-01T00:00:00.101 |
| host3 | 0.41 | | 1970-01-01T00:00:00.102 |
| host4 | 0.2 | 0.3 | 1970-01-01T00:00:00.103 |
| host5 | 66.6 | 1024 | 2022-12-28T04:17:07 |
| host6 | 88.8 | 333.3 | 2022-12-28T04:17:08 |
+-------+------+--------+-------------------------+\
";
assert_eq!(pretty, expected);


@@ -18,8 +18,9 @@ use std::process::Stdio;
use std::time::Duration;
use async_trait::async_trait;
use client::{Client, Database as DB, Error as ClientError, ObjectResult, Select};
use client::{Client, Database as DB, Error as ClientError, ObjectResult};
use common_grpc::flight;
use common_grpc::flight::FlightMessage;
use sqlness::{Database, Environment};
use tokio::process::{Child, Command};
@@ -112,8 +113,7 @@ pub struct GreptimeDB {
#[async_trait]
impl Database for GreptimeDB {
async fn query(&self, query: String) -> Box<dyn Display> {
let sql = Select::Sql(query);
let result = self.db.select(sql).await;
let result = self.db.sql(&query).await;
Box::new(ResultDisplayer { result }) as _
}
}
@@ -130,12 +130,19 @@ impl Display for ResultDisplayer {
write!(f, "{mutate_result:?}")
}
ObjectResult::FlightData(messages) => {
let pretty = flight::flight_messages_to_recordbatches(messages.clone())
.map_err(|e| e.to_string())
.and_then(|x| x.pretty_print().map_err(|e| e.to_string()));
match pretty {
Ok(s) => write!(f, "{s}"),
Err(e) => write!(f, "format result error: {e}"),
if let Some(FlightMessage::AffectedRows(rows)) = messages.get(0) {
write!(f, "Affected Rows: {rows}")
} else {
let pretty = flight::flight_messages_to_recordbatches(messages.clone())
.map_err(|e| e.to_string())
.and_then(|x| x.pretty_print().map_err(|e| e.to_string()));
match pretty {
Ok(s) => write!(f, "{s}"),
Err(e) => write!(
f,
"Failed to convert Flight messages {messages:?} to Recordbatches, error: {e}"
),
}
}
}
},