From 89d530674009e63469b68cb9bbe35a2860a97f7c Mon Sep 17 00:00:00 2001
From: LFC
Date: Thu, 5 Jan 2023 14:17:57 +0800
Subject: [PATCH] feat: Impl Query and DDL functionality of Arrow Flight
 service for Frontend Instance (#827)

* feat: Implement Query and DDL functionality of Arrow Flight service for
  Frontend Instance
---
 src/catalog/src/helper.rs                        |   1 -
 src/frontend/src/error.rs                        |   7 +-
 src/frontend/src/instance.rs                     | 154 +++----
 src/frontend/src/instance/distributed.rs         | 113 +++---
 .../src/instance/distributed/flight.rs           | 143 +++++++
 src/frontend/src/instance/flight.rs              | 368 ++++++++++++++++--
 src/frontend/src/instance/grpc.rs                |  37 +-
 src/frontend/src/instance/opentsdb.rs            |   7 +-
 src/frontend/src/instance/prometheus.rs          |   4 +-
 src/frontend/src/sql.rs                          |  13 +-
 src/frontend/src/table.rs                        |   7 +-
 src/frontend/src/tests.rs                        |  24 +-
 12 files changed, 624 insertions(+), 254 deletions(-)
 create mode 100644 src/frontend/src/instance/distributed/flight.rs

diff --git a/src/catalog/src/helper.rs b/src/catalog/src/helper.rs
index ab3eb854ac..5d2d39ed71 100644
--- a/src/catalog/src/helper.rs
+++ b/src/catalog/src/helper.rs
@@ -132,7 +132,6 @@ impl TableGlobalKey {
 pub struct TableGlobalValue {
     /// Id of datanode that created the global table info kv. only for debugging.
     pub node_id: u64,
-    // TODO(LFC): Maybe remove it?
     /// Allocation of region ids across all datanodes.
     pub regions_id_map: HashMap<u64, Vec<u32>>,
     pub table_info: RawTableInfo,
diff --git a/src/frontend/src/error.rs b/src/frontend/src/error.rs
index 386f5d6e1e..6223f64a2e 100644
--- a/src/frontend/src/error.rs
+++ b/src/frontend/src/error.rs
@@ -23,9 +23,8 @@ use store_api::storage::RegionId;
 #[derive(Debug, Snafu)]
 #[snafu(visibility(pub))]
 pub enum Error {
-    #[snafu(display("Failed to connect Datanode at {}, source: {}", addr, source))]
-    ConnectDatanode {
-        addr: String,
+    #[snafu(display("Invalid ObjectResult, source: {}", source))]
+    InvalidObjectResult {
         #[snafu(backtrace)]
         source: client::Error,
     },
@@ -488,7 +487,7 @@ impl ErrorExt for Error {
             | Error::VectorComputation { source }
             | Error::ConvertArrowSchema { source } => source.status_code(),

-            Error::ConnectDatanode { source, .. } | Error::RequestDatanode { source } => {
+            Error::InvalidObjectResult { source, .. } | Error::RequestDatanode { source } => {
                source.status_code()
            }
diff --git a/src/frontend/src/instance.rs b/src/frontend/src/instance.rs
index 286d9f9605..ff04e61093 100644
--- a/src/frontend/src/instance.rs
+++ b/src/frontend/src/instance.rs
@@ -26,8 +26,7 @@ use api::v1::alter_expr::Kind;
 use api::v1::ddl_request::Expr as DdlExpr;
 use api::v1::object_expr::Request;
 use api::v1::{
-    AddColumns, AlterExpr, Column, CreateTableExpr, DdlRequest, DropTableExpr, InsertRequest,
-    ObjectExpr,
+    AddColumns, AlterExpr, Column, DdlRequest, DropTableExpr, InsertRequest, ObjectExpr,
 };
 use async_trait::async_trait;
 use catalog::remote::MetaKvBackend;
@@ -54,20 +53,16 @@ use session::context::QueryContextRef;
 use snafu::prelude::*;
 use sql::dialect::GenericDialect;
 use sql::parser::ParserContext;
-use sql::statements::create::Partitions;
-use sql::statements::insert::Insert;
 use sql::statements::statement::Statement;
-use table::TableRef;

 use crate::catalog::FrontendCatalogManager;
 use crate::datanode::DatanodeClients;
 use crate::error::{
-    self, CatalogSnafu, FindNewColumnsOnInsertionSnafu, InsertSnafu, MissingMetasrvOptsSnafu,
-    RequestDatanodeSnafu, Result,
+    self, CatalogSnafu, FindNewColumnsOnInsertionSnafu, InsertSnafu, InvalidObjectResultSnafu,
+    InvokeGrpcServerSnafu, MissingMetasrvOptsSnafu, Result,
 };
 use crate::expr_factory::{CreateExprFactoryRef, DefaultCreateExprFactory};
 use crate::frontend::FrontendOptions;
-use crate::sql::insert_to_request;
 use crate::table::route::TableRoutes;
 use crate::Plugins;

@@ -92,21 +87,17 @@ pub type FrontendInstanceRef = Arc<dyn FrontendInstance>;

 pub struct Instance {
     catalog_manager: CatalogManagerRef,
-    // TODO(LFC): Revisit script_handler here, maybe merge with sql_handler?
     /// Script handler is None in distributed mode, only works on standalone mode.
     script_handler: Option<ScriptHandlerRef>,
+    sql_handler: SqlQueryHandlerRef,
+    grpc_query_handler: GrpcQueryHandlerRef,
+
     create_expr_factory: CreateExprFactoryRef,

     // TODO(fys): it should be a trait that corresponds to two implementations:
     // Standalone and Distributed, then the code behind it doesn't need to use so
     // many match statements.
     mode: Mode,
-    // TODO(LFC): Remove `dist_instance` together with Arrow Flight adoption refactor.
-    pub(crate) dist_instance: Option<DistInstance>,
-
-    sql_handler: SqlQueryHandlerRef,
-    grpc_query_handler: GrpcQueryHandlerRef,
-
     /// plugins: this map holds extensions to customize query or auth
     /// behaviours.
    plugins: Arc<Plugins>,

@@ -129,16 +120,15 @@ impl Instance {
         let dist_instance =
             DistInstance::new(meta_client, catalog_manager.clone(), datanode_clients);
-        let dist_instance_ref = Arc::new(dist_instance.clone());
+        let dist_instance = Arc::new(dist_instance);

         Ok(Instance {
             catalog_manager,
             script_handler: None,
             create_expr_factory: Arc::new(DefaultCreateExprFactory),
             mode: Mode::Distributed,
-            dist_instance: Some(dist_instance),
-            sql_handler: dist_instance_ref.clone(),
-            grpc_query_handler: dist_instance_ref,
+            sql_handler: dist_instance.clone(),
+            grpc_query_handler: dist_instance,
             plugins: Default::default(),
         })
     }
@@ -179,7 +169,6 @@ impl Instance {
             script_handler: None,
             create_expr_factory: Arc::new(DefaultCreateExprFactory),
             mode: Mode::Standalone,
-            dist_instance: None,
             sql_handler: dn_instance.clone(),
             grpc_query_handler: dn_instance.clone(),
             plugins: Default::default(),
         }
     }

     #[cfg(test)]
-    pub(crate) fn new_distributed(dist_instance: DistInstance) -> Self {
-        let dist_instance_ref = Arc::new(dist_instance.clone());
+    pub(crate) fn new_distributed(dist_instance: Arc<DistInstance>) -> Self {
         Instance {
             catalog_manager: dist_instance.catalog_manager(),
             script_handler: None,
             create_expr_factory: Arc::new(DefaultCreateExprFactory),
             mode: Mode::Distributed,
-            dist_instance: Some(dist_instance),
-            sql_handler: dist_instance_ref.clone(),
-            grpc_query_handler: dist_instance_ref,
+            sql_handler: dist_instance.clone(),
+            grpc_query_handler: dist_instance,
             plugins: Default::default(),
         }
     }
@@ -213,29 +200,6 @@ impl Instance {
         self.script_handler = Some(handler);
     }

-    /// Handle create expr.
-    pub async fn handle_create_table(
-        &self,
-        mut expr: CreateTableExpr,
-        partitions: Option<Partitions>,
-    ) -> Result<Output> {
-        if let Some(v) = &self.dist_instance {
-            v.create_table(&mut expr, partitions).await
-        } else {
-            let result = self
-                .grpc_query_handler
-                .do_query(ObjectExpr {
-                    request: Some(Request::Ddl(DdlRequest {
-                        expr: Some(DdlExpr::CreateTable(expr)),
-                    })),
-                })
-                .await
-                .context(error::InvokeGrpcServerSnafu)?;
-            let output: RpcOutput = result.try_into().context(RequestDatanodeSnafu)?;
-            Ok(output.into())
-        }
-    }
-
     /// Handle batch inserts
     pub async fn handle_inserts(&self, requests: Vec<InsertRequest>) -> Result<Output> {
         let mut success = 0;
@@ -263,7 +227,7 @@
         };
         let result = GrpcQueryHandler::do_query(&*self.grpc_query_handler, query)
             .await
-            .context(error::InvokeGrpcServerSnafu)?;
+            .context(InvokeGrpcServerSnafu)?;
         let result: RpcOutput = result.try_into().context(InsertSnafu)?;
         Ok(result.into())
     }
@@ -278,7 +242,11 @@
         table_name: &str,
         columns: &[Column],
     ) -> Result<()> {
-        match self.find_table(catalog_name, schema_name, table_name)? {
+        let table = self
+            .catalog_manager
+            .table(catalog_name, schema_name, table_name)
+            .context(CatalogSnafu)?;
+        match table {
             None => {
                 info!(
                     "Table {}.{}.{} does not exist, try create table",
@@ -336,8 +304,18 @@
             "Try to create table: {} automatically with request: {:?}",
             table_name, create_expr,
         );
-        // Create-on-insert does support partition by other columns now
-        self.handle_create_table(create_expr, None).await
+
+        let result = self
+            .grpc_query_handler
+            .do_query(ObjectExpr {
+                request: Some(Request::Ddl(DdlRequest {
+                    expr: Some(DdlExpr::CreateTable(create_expr)),
+                })),
+            })
+            .await
+            .context(InvokeGrpcServerSnafu)?;
+        let output: RpcOutput = result.try_into().context(InvalidObjectResultSnafu)?;
+        Ok(output.into())
     }

     async fn add_new_columns_to_table(
@@ -366,8 +344,8 @@
                 })),
             })
             .await
-            .context(error::InvokeGrpcServerSnafu)?;
-        let output: RpcOutput = result.try_into().context(RequestDatanodeSnafu)?;
+            .context(InvokeGrpcServerSnafu)?;
+        let output: RpcOutput = result.try_into().context(InvalidObjectResultSnafu)?;
         Ok(output.into())
     }

@@ -387,37 +365,6 @@
         })
     }

-    fn find_table(&self, catalog: &str, schema: &str, table: &str) -> Result<Option<TableRef>> {
-        self.catalog_manager
-            .table(catalog, schema, table)
-            .context(CatalogSnafu)
-    }
-
-    async fn sql_dist_insert(&self, insert: Box<Insert>) -> Result<usize> {
-        let (catalog, schema, table) = insert.full_table_name().context(error::ParseSqlSnafu)?;
-
-        let catalog_provider = self.get_catalog(&catalog)?;
-        let schema_provider = Self::get_schema(catalog_provider, &schema)?;
-
-        let insert_request = insert_to_request(&schema_provider, *insert)?;
-
-        let (columns, _row_count) =
-            crate::table::insert::insert_request_to_insert_batch(&insert_request)?;
-
-        self.create_or_alter_table_on_demand(&catalog, &schema, &table, &columns)
-            .await?;
-
-        let table = schema_provider
-            .table(&table)
-            .context(error::CatalogSnafu)?
-            .context(error::TableNotFoundSnafu { table_name: &table })?;
-
-        table
-            .insert(insert_request)
-            .await
-            .context(error::TableSnafu)
-    }
-
     fn handle_use(&self, db: String, query_ctx: QueryContextRef) -> Result<Output> {
         ensure!(
             self.catalog_manager
@@ -468,24 +415,10 @@
             | Statement::ShowTables(_)
             | Statement::DescribeTable(_)
             | Statement::Explain(_)
-            | Statement::Query(_) => {
+            | Statement::Query(_)
+            | Statement::Insert(_) => {
                 return self.sql_handler.do_statement_query(stmt, query_ctx).await;
             }
-            Statement::Insert(insert) => match self.mode {
-                Mode::Standalone => {
-                    return self.sql_handler.do_statement_query(stmt, query_ctx).await
-                }
-                Mode::Distributed => {
-                    let affected = self
-                        .sql_dist_insert(insert)
-                        .await
-                        .map_err(BoxedError::new)
-                        .context(server_error::ExecuteInsertSnafu {
-                            msg: "execute insert failed",
-                        })?;
-                    Ok(Output::AffectedRows(affected))
-                }
-            },
             Statement::Alter(alter_stmt) => {
                 let expr = AlterExpr::try_from(alter_stmt)
                     .map_err(BoxedError::new)
@@ -639,7 +572,8 @@ mod tests {
     use api::v1::column::SemanticType;
     use api::v1::{
-        column, query_request, Column, ColumnDataType, ColumnDef as GrpcColumnDef, QueryRequest,
+        column, query_request, Column, ColumnDataType, ColumnDef as GrpcColumnDef, CreateTableExpr,
+        QueryRequest,
     };
     use common_grpc::flight::{raw_flight_data_to_message, FlightMessage};
     use common_recordbatch::RecordBatch;
@@ -656,7 +590,8 @@
     async fn test_execute_sql() {
         let query_ctx = Arc::new(QueryContext::new());

-        let (instance, _guard) = tests::create_standalone_instance("test_execute_sql").await;
+        let standalone = tests::create_standalone_instance("test_execute_sql").await;
+        let instance = standalone.instance;

         let sql = r#"CREATE TABLE demo(
                             host STRING,
@@ -749,7 +684,8 @@

     #[tokio::test(flavor = "multi_thread")]
     async fn test_execute_grpc() {
-        let (instance, _guard) = tests::create_standalone_instance("test_execute_grpc").await;
+        let standalone = tests::create_standalone_instance("test_execute_grpc").await;
+        let instance = standalone.instance;

         // testing data:
         let expected_host_col = Column {
@@ -1015,8 +951,8 @@
         }
     }

-        let query_ctx = Arc::new(QueryContext::new());
-        let (mut instance, _guard) = tests::create_standalone_instance("test_hook").await;
+        let standalone = tests::create_standalone_instance("test_hook").await;
+        let mut instance = standalone.instance;

         let mut plugins = Plugins::new();
         let counter_hook = Arc::new(AssertionHook::default());
@@ -1032,7 +968,7 @@
                             TIME INDEX (ts),
                             PRIMARY KEY(host)
                     ) engine=mito with(regions=1);"#;
-        let output = SqlQueryHandler::do_query(&*instance, sql, query_ctx.clone())
+        let output = SqlQueryHandler::do_query(&*instance, sql, QueryContext::arc())
             .await
             .remove(0)
             .unwrap();
@@ -1072,7 +1008,9 @@
         }

         let query_ctx = Arc::new(QueryContext::new());
-        let (mut instance, _guard) = tests::create_standalone_instance("test_db_hook").await;
+
+        let standalone = tests::create_standalone_instance("test_db_hook").await;
+        let mut instance = standalone.instance;

         let mut plugins = Plugins::new();
         let hook = Arc::new(DisableDBOpHook::default());
diff --git a/src/frontend/src/instance/distributed.rs b/src/frontend/src/instance/distributed.rs
index a503138886..c801d36c15 100644
--- a/src/frontend/src/instance/distributed.rs
+++ b/src/frontend/src/instance/distributed.rs
@@ -12,17 +12,18 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
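+//! [`DistInstance`] handles SQL and gRPC execution in distributed mode: DDL is
+//! coordinated through the meta server, while inserts and queries are routed to
+//! the datanodes owning the table regions. The `flight` submodule below is the
+//! Arrow Flight entry point for gRPC requests.
+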
+mod flight; + use std::collections::HashMap; use std::sync::Arc; use api::helper::ColumnDataTypeWrapper; -use api::result::ObjectResultBuilder; -use api::v1::ddl_request::Expr as DdlExpr; -use api::v1::object_expr::Request as GrpcRequest; use api::v1::{ AlterExpr, CreateDatabaseExpr, CreateTableExpr, InsertRequest, ObjectExpr, ObjectResult, TableId, }; +use arrow_flight::flight_service_server::FlightService; +use arrow_flight::Ticket; use async_trait::async_trait; use catalog::helper::{SchemaKey, SchemaValue, TableGlobalKey, TableGlobalValue}; use catalog::{CatalogList, CatalogManager}; @@ -30,7 +31,7 @@ use chrono::DateTime; use client::Database; use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; use common_error::prelude::BoxedError; -use common_grpc::flight::{FlightEncoder, FlightMessage}; +use common_grpc::flight::flight_data_to_object_result; use common_query::Output; use common_telemetry::{debug, error, info}; use datatypes::prelude::ConcreteDataType; @@ -40,6 +41,7 @@ use meta_client::rpc::{ CreateRequest as MetaCreateRequest, Partition as MetaPartition, PutRequest, RouteResponse, TableName, TableRoute, }; +use prost::Message; use query::sql::{describe_table, explain, show_databases, show_tables}; use query::{QueryEngineFactory, QueryEngineRef}; use servers::error as server_error; @@ -52,18 +54,20 @@ use sql::statements::sql_value_to_value; use sql::statements::statement::Statement; use table::metadata::{RawTableInfo, RawTableMeta, TableIdent, TableType}; use table::table::AlterContext; +use tonic::Request; use crate::catalog::FrontendCatalogManager; use crate::datanode::DatanodeClients; use crate::error::{ self, AlterExprToRequestSnafu, CatalogEntrySerdeSnafu, CatalogNotFoundSnafu, CatalogSnafu, - ColumnDataTypeSnafu, PrimaryKeyNotFoundSnafu, RequestDatanodeSnafu, RequestMetaSnafu, Result, - SchemaNotFoundSnafu, StartMetaClientSnafu, TableNotFoundSnafu, TableSnafu, - ToTableInsertRequestSnafu, + ColumnDataTypeSnafu, FlightGetSnafu, InvalidFlightDataSnafu, ParseSqlSnafu, + PrimaryKeyNotFoundSnafu, RequestDatanodeSnafu, RequestMetaSnafu, Result, SchemaNotFoundSnafu, + StartMetaClientSnafu, TableNotFoundSnafu, TableSnafu, ToTableInsertRequestSnafu, }; use crate::expr_factory::{CreateExprFactory, DefaultCreateExprFactory}; use crate::instance::parse_stmt; use crate::partitioning::{PartitionBound, PartitionDef}; +use crate::sql::insert_to_request; #[derive(Clone)] pub(crate) struct DistInstance { @@ -162,8 +166,7 @@ impl DistInstance { let expr = CreateDatabaseExpr { database_name: stmt.name.to_string(), }; - self.handle_create_database(expr).await?; - Ok(Output::AffectedRows(1)) + Ok(self.handle_create_database(expr).await?) } Statement::CreateTable(stmt) => { let create_expr = &mut DefaultCreateExprFactory.create_expr_by_stmt(&stmt).await?; @@ -177,6 +180,21 @@ impl DistInstance { Statement::Explain(stmt) => { explain(Box::new(stmt), self.query_engine.clone(), query_ctx).await } + Statement::Insert(insert) => { + let (catalog, schema, table) = insert.full_table_name().context(ParseSqlSnafu)?; + + let table = self + .catalog_manager + .table(&catalog, &schema, &table) + .context(CatalogSnafu)? 
+                    .context(TableNotFoundSnafu { table_name: table })?;
+
+                let insert_request = insert_to_request(&table, *insert)?;
+
+                return Ok(Output::AffectedRows(
+                    table.insert(insert_request).await.context(TableSnafu)?,
+                ));
+            }
             _ => unreachable!(),
         }
         .context(error::ExecuteStatementSnafu)
@@ -206,7 +224,7 @@ impl DistInstance {
     }

     /// Handles distributed database creation
-    async fn handle_create_database(&self, expr: CreateDatabaseExpr) -> Result<()> {
+    async fn handle_create_database(&self, expr: CreateDatabaseExpr) -> Result<Output> {
         let key = SchemaKey {
             catalog_name: DEFAULT_CATALOG_NAME.to_string(),
             schema_name: expr.database_name,
@@ -221,10 +239,11 @@ impl DistInstance {
             .with_key(key.to_string())
             .with_value(value.as_bytes().context(CatalogEntrySerdeSnafu)?);
         client.put(request.into()).await.context(RequestMetaSnafu)?;
-        Ok(())
+
+        Ok(Output::AffectedRows(1))
     }

-    async fn handle_alter_table(&self, expr: AlterExpr) -> Result<()> {
+    async fn handle_alter_table(&self, expr: AlterExpr) -> Result<Output> {
         let catalog_name = if expr.catalog_name.is_empty() {
             DEFAULT_CATALOG_NAME
         } else {
@@ -258,7 +277,9 @@ impl DistInstance {
         let mut context = AlterContext::with_capacity(1);
         context.insert(expr);

-        table.alter(context, request).await.context(TableSnafu)
+        table.alter(context, request).await.context(TableSnafu)?;
+
+        Ok(Output::AffectedRows(0))
     }

     async fn create_table_in_meta(
@@ -332,7 +353,7 @@
     // GRPC InsertRequest to Table InsertRequest, then split Table InsertRequest, then assemble each GRPC InsertRequest, is rather inefficient,
     // should operate on GRPC InsertRequest directly.
     // Also remember to check the "region_number" carried in InsertRequest, too.
-    async fn handle_dist_insert(&self, request: InsertRequest) -> Result<usize> {
+    async fn handle_dist_insert(&self, request: InsertRequest) -> Result<Output> {
         let table_name = &request.table_name;
         // TODO(LFC): InsertRequest should carry catalog name, too.
         let table = self
@@ -344,7 +365,15 @@
         let request = common_grpc_expr::insert::to_table_insert_request(request)
             .context(ToTableInsertRequestSnafu)?;

-        table.insert(request).await.context(TableSnafu)
+        let affected_rows = table.insert(request).await.context(TableSnafu)?;
+        Ok(Output::AffectedRows(affected_rows))
+    }
+
+    async fn boarding(&self, ticket: Request<Ticket>) -> Result<ObjectResult> {
+        let response = self.do_get(ticket).await.context(FlightGetSnafu)?;
+        flight_data_to_object_result(response)
+            .await
+            .context(InvalidFlightDataSnafu)
     }

     #[cfg(test)]
@@ -391,43 +420,17 @@ impl SqlQueryHandler for DistInstance {

 #[async_trait]
 impl GrpcQueryHandler for DistInstance {
-    async fn do_query(&self, expr: ObjectExpr) -> server_error::Result<ObjectResult> {
-        let request = expr
-            .clone()
-            .request
-            .context(server_error::InvalidQuerySnafu {
-                reason: "empty expr",
-            })?;
-        let flight_messages = match request {
-            GrpcRequest::Ddl(request) => {
-                let expr = request.expr.context(server_error::InvalidQuerySnafu {
-                    reason: "empty DDL expr",
-                })?;
-                let result = match expr {
-                    DdlExpr::CreateDatabase(expr) => self.handle_create_database(expr).await,
-                    DdlExpr::Alter(expr) => self.handle_alter_table(expr).await,
-                    DdlExpr::CreateTable(_) | DdlExpr::DropTable(_) => unimplemented!(),
-                };
-                result.map(|_| vec![FlightMessage::AffectedRows(1)])
-            }
-            GrpcRequest::Insert(request) => self
-                .handle_dist_insert(request)
-                .await
-                .map(|x| vec![FlightMessage::AffectedRows(x)]),
-            // TODO(LFC): Implement Flight for DistInstance.
-            GrpcRequest::Query(_) => unimplemented!(),
-        }
-        .map_err(BoxedError::new)
-        .with_context(|_| server_error::ExecuteQuerySnafu {
-            query: format!("{expr:?}"),
-        })?;
-
-        let encoder = FlightEncoder::default();
-        let flight_data = flight_messages
-            .into_iter()
-            .map(|x| encoder.encode(x))
-            .collect();
-        Ok(ObjectResultBuilder::new().flight_data(flight_data).build())
+    async fn do_query(&self, query: ObjectExpr) -> server_error::Result<ObjectResult> {
+        let ticket = Request::new(Ticket {
+            ticket: query.encode_to_vec(),
+        });
+        // TODO(LFC): Temporarily use old GRPC interface here, will get rid of them near the end of Arrow Flight adoption.
+        self.boarding(ticket)
+            .await
+            .map_err(BoxedError::new)
+            .with_context(|_| servers::error::ExecuteQuerySnafu {
+                query: format!("{query:?}"),
+            })
     }
 }

@@ -677,7 +680,7 @@ ENGINE=mito",
     #[tokio::test(flavor = "multi_thread")]
     async fn test_show_databases() {
         let instance = crate::tests::create_distributed_instance("test_show_databases").await;
-        let dist_instance = instance.frontend.dist_instance.as_ref().unwrap();
+        let dist_instance = &instance.dist_instance;

         let sql = "create database test_show_databases";
         let output = dist_instance
@@ -728,7 +731,7 @@
     #[tokio::test(flavor = "multi_thread")]
     async fn test_show_tables() {
         let instance = crate::tests::create_distributed_instance("test_show_tables").await;
-        let dist_instance = instance.frontend.dist_instance.as_ref().unwrap();
+        let dist_instance = &instance.dist_instance;
         let datanode_instances = instance.datanodes;

         let sql = "create database test_show_tables";
@@ -777,7 +780,7 @@
         }
     }

-        assert_show_tables(Arc::new(dist_instance.clone())).await;
+        assert_show_tables(dist_instance.clone()).await;

         // Asserts that new table is created in Datanode as well.
         for x in datanode_instances.values() {
diff --git a/src/frontend/src/instance/distributed/flight.rs b/src/frontend/src/instance/distributed/flight.rs
new file mode 100644
index 0000000000..f1bc2e7f7a
--- /dev/null
+++ b/src/frontend/src/instance/distributed/flight.rs
@@ -0,0 +1,143 @@
+// Copyright 2023 Greptime Team
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
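+
+//! Arrow Flight service implementation for [`DistInstance`]. Only `do_get` is
+//! routable so far: the Flight [`Ticket`] carries a protobuf-encoded
+//! [`ObjectExpr`], and Insert/DDL requests are dispatched to the corresponding
+//! distributed handlers; all other Flight methods answer `unimplemented`.
+//!
+//! A minimal client-side sketch of the ticket protocol — the endpoint address
+//! and the `send` helper are illustrative, not part of this crate:
+//!
+//! ```ignore
+//! use arrow_flight::flight_service_client::FlightServiceClient;
+//! use arrow_flight::Ticket;
+//! use prost::Message;
+//!
+//! async fn send(expr: api::v1::ObjectExpr) -> Result<(), Box<dyn std::error::Error>> {
+//!     let mut client = FlightServiceClient::connect("http://127.0.0.1:4001").await?;
+//!     // The ticket bytes are simply the protobuf encoding of an ObjectExpr.
+//!     let ticket = Ticket { ticket: expr.encode_to_vec() };
+//!     let mut stream = client.do_get(ticket).await?.into_inner();
+//!     while let Some(flight_data) = stream.message().await? {
+//!         // Each FlightData decodes to a schema, a record batch, or an
+//!         // "affected rows" message.
+//!         let _ = flight_data;
+//!     }
+//!     Ok(())
+//! }
+//! ```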
+
+use std::pin::Pin;
+
+use api::v1::ddl_request::Expr as DdlExpr;
+use api::v1::object_expr::Request as GrpcRequest;
+use api::v1::ObjectExpr;
+use arrow_flight::flight_service_server::FlightService;
+use arrow_flight::{
+    Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo,
+    HandshakeRequest, HandshakeResponse, PutResult, SchemaResult, Ticket,
+};
+use async_trait::async_trait;
+use datanode::instance::flight::to_flight_data_stream;
+use futures::Stream;
+use prost::Message;
+use snafu::{OptionExt, ResultExt};
+use tonic::{Request, Response, Status, Streaming};
+
+use crate::error::{IncompleteGrpcResultSnafu, InvalidFlightTicketSnafu};
+use crate::instance::distributed::DistInstance;
+
+type TonicResult<T> = Result<T, Status>;
+type TonicStream<T> = Pin<Box<dyn Stream<Item = TonicResult<T>> + Send + Sync + 'static>>;
+
+#[async_trait]
+impl FlightService for DistInstance {
+    type HandshakeStream = TonicStream<HandshakeResponse>;
+
+    async fn handshake(
+        &self,
+        _: Request<Streaming<HandshakeRequest>>,
+    ) -> TonicResult<Response<Self::HandshakeStream>> {
+        Err(Status::unimplemented("Not yet implemented"))
+    }
+
+    type ListFlightsStream = TonicStream<FlightInfo>;
+
+    async fn list_flights(
+        &self,
+        _: Request<Criteria>,
+    ) -> TonicResult<Response<Self::ListFlightsStream>> {
+        Err(Status::unimplemented("Not yet implemented"))
+    }
+
+    async fn get_flight_info(
+        &self,
+        _: Request<FlightDescriptor>,
+    ) -> TonicResult<Response<FlightInfo>> {
+        Err(Status::unimplemented("Not yet implemented"))
+    }
+
+    async fn get_schema(
+        &self,
+        _: Request<FlightDescriptor>,
+    ) -> TonicResult<Response<SchemaResult>> {
+        Err(Status::unimplemented("Not yet implemented"))
+    }
+
+    type DoGetStream = TonicStream<FlightData>;
+
+    async fn do_get(&self, request: Request<Ticket>) -> TonicResult<Response<Self::DoGetStream>> {
+        let ticket = request.into_inner().ticket;
+        let request = ObjectExpr::decode(ticket.as_slice())
+            .context(InvalidFlightTicketSnafu)?
+            .request
+            .context(IncompleteGrpcResultSnafu {
+                err_msg: "Missing 'request' in ObjectExpr",
+            })?;
+        let output = match request {
+            GrpcRequest::Insert(request) => self.handle_dist_insert(request).await?,
+            GrpcRequest::Query(_) => {
+                unreachable!("Query should have been handled directly in Frontend Instance!")
+            }
+            GrpcRequest::Ddl(request) => {
+                let expr = request.expr.context(IncompleteGrpcResultSnafu {
+                    err_msg: "Missing 'expr' in DDL request",
+                })?;
+                match expr {
+                    DdlExpr::CreateDatabase(expr) => self.handle_create_database(expr).await?,
+                    DdlExpr::CreateTable(mut expr) => {
+                        // TODO(LFC): Support creating distributed table through GRPC interface.
+                        // Currently only SQL supports it; how to design the fields in CreateTableExpr?
+                        self.create_table(&mut expr, None).await?
+                    }
+                    DdlExpr::Alter(expr) => self.handle_alter_table(expr).await?,
+                    DdlExpr::DropTable(_) => {
+                        // TODO(LFC): Implement distributed drop table.
+                        // Seems the whole "drop table through GRPC interface" feature is not implemented?
+                        return Err(Status::unimplemented("Not yet implemented"));
+                    }
+                }
+            }
+        };
+        let stream = to_flight_data_stream(output);
+        Ok(Response::new(stream))
+    }
+
+    type DoPutStream = TonicStream<PutResult>;
+
+    async fn do_put(
+        &self,
+        _: Request<Streaming<FlightData>>,
+    ) -> TonicResult<Response<Self::DoPutStream>> {
+        Err(Status::unimplemented("Not yet implemented"))
+    }
+
+    type DoExchangeStream = TonicStream<FlightData>;
+
+    async fn do_exchange(
+        &self,
+        _: Request<Streaming<FlightData>>,
+    ) -> TonicResult<Response<Self::DoExchangeStream>> {
+        Err(Status::unimplemented("Not yet implemented"))
+    }
+
+    type DoActionStream = TonicStream<arrow_flight::Result>;
+
+    async fn do_action(&self, _: Request<Action>) -> TonicResult<Response<Self::DoActionStream>> {
+        Err(Status::unimplemented("Not yet implemented"))
+    }
+
+    type ListActionsStream = TonicStream<ActionType>;
+
+    async fn list_actions(
+        &self,
+        _: Request<Empty>,
+    ) -> TonicResult<Response<Self::ListActionsStream>> {
+        Err(Status::unimplemented("Not yet implemented"))
+    }
+}
diff --git a/src/frontend/src/instance/flight.rs b/src/frontend/src/instance/flight.rs
index 38b310d3b6..69e2ea7ab5 100644
--- a/src/frontend/src/instance/flight.rs
+++ b/src/frontend/src/instance/flight.rs
@@ -23,14 +23,19 @@ use arrow_flight::{
     HandshakeRequest, HandshakeResponse, PutResult, SchemaResult, Ticket,
 };
 use async_trait::async_trait;
+use client::RpcOutput;
 use datanode::instance::flight::to_flight_data_stream;
 use futures::Stream;
 use prost::Message;
+use servers::query_handler::GrpcQueryHandler;
 use session::context::QueryContext;
 use snafu::{ensure, OptionExt, ResultExt};
 use tonic::{Request, Response, Status, Streaming};

-use crate::error::{IncompleteGrpcResultSnafu, InvalidFlightTicketSnafu, InvalidSqlSnafu};
+use crate::error::{
+    IncompleteGrpcResultSnafu, InvalidFlightTicketSnafu, InvalidObjectResultSnafu, InvalidSqlSnafu,
+    InvokeGrpcServerSnafu,
+};
 use crate::instance::{parse_stmt, Instance};

 type TonicResult<T> = Result<T, Status>;
@@ -104,9 +109,15 @@
                     }
                 }
             }
-            GrpcRequest::Ddl(_request) => {
-                // TODO(LFC): Implement it.
-                unimplemented!()
+            GrpcRequest::Ddl(request) => {
+                let query = ObjectExpr {
+                    request: Some(GrpcRequest::Ddl(request)),
+                };
+                let result = GrpcQueryHandler::do_query(&*self.grpc_query_handler, query)
+                    .await
+                    .context(InvokeGrpcServerSnafu)?;
+                let result: RpcOutput = result.try_into().context(InvalidObjectResultSnafu)?;
+                result.into()
             }
         };
         let stream = to_flight_data_stream(output);
@@ -149,43 +160,82 @@

 #[cfg(test)]
 mod test {
+    use std::collections::HashMap;
     use std::sync::Arc;

     use api::v1::column::{SemanticType, Values};
-    use api::v1::{Column, ColumnDataType, InsertRequest, QueryRequest};
+    use api::v1::ddl_request::Expr as DdlExpr;
+    use api::v1::{
+        alter_expr, AddColumn, AddColumns, AlterExpr, Column, ColumnDataType, ColumnDef,
+        CreateDatabaseExpr, CreateTableExpr, DdlRequest, InsertRequest, QueryRequest,
+    };
+    use catalog::helper::{TableGlobalKey, TableGlobalValue};
     use client::RpcOutput;
     use common_grpc::flight;
+    use common_query::Output;
+    use common_recordbatch::RecordBatches;

     use super::*;
+    use crate::table::DistTable;
     use crate::tests;
+    use crate::tests::MockDistributedInstance;

     #[tokio::test(flavor = "multi_thread")]
-    async fn test_distributed_insert_and_query() {
-        common_telemetry::init_default_ut_logging();
-
+    async fn test_distributed_handle_ddl_request() {
         let instance =
-            tests::create_distributed_instance("test_distributed_insert_and_query").await;
+            tests::create_distributed_instance("test_distributed_handle_ddl_request").await;
+        let frontend = &instance.frontend;

-        test_insert_and_query(&instance.frontend).await
+        test_handle_ddl_request(frontend).await
     }

     #[tokio::test(flavor = "multi_thread")]
-    async fn test_standalone_insert_and_query() {
-        common_telemetry::init_default_ut_logging();
+    async fn test_standalone_handle_ddl_request() {
+        let standalone =
+            tests::create_standalone_instance("test_standalone_handle_ddl_request").await;
+        let instance = &standalone.instance;

-        let (instance, _) =
-            tests::create_standalone_instance("test_standalone_insert_and_query").await;
-
-        test_insert_and_query(&instance).await
+        test_handle_ddl_request(instance).await
     }

-    async fn test_insert_and_query(instance: &Arc<Instance>) {
+    async fn test_handle_ddl_request(instance: &Arc<Instance>) {
         let ticket = Request::new(Ticket {
             ticket: ObjectExpr {
-                request: Some(GrpcRequest::Query(QueryRequest {
-                    query: Some(Query::Sql(
-                        "CREATE TABLE my_table (a INT, ts TIMESTAMP, TIME INDEX (ts))".to_string(),
-                    )),
+                request: Some(GrpcRequest::Ddl(DdlRequest {
+                    expr: Some(DdlExpr::CreateDatabase(CreateDatabaseExpr {
+                        database_name: "database_created_through_grpc".to_string(),
+                    })),
                 })),
             }
             .encode_to_vec(),
         });
+        let output = boarding(instance, ticket).await;
+        assert!(matches!(output, RpcOutput::AffectedRows(1)));
+
+        let ticket = Request::new(Ticket {
+            ticket: ObjectExpr {
+                request: Some(GrpcRequest::Ddl(DdlRequest {
+                    expr: Some(DdlExpr::CreateTable(CreateTableExpr {
+                        catalog_name: "greptime".to_string(),
+                        schema_name: "database_created_through_grpc".to_string(),
+                        table_name: "table_created_through_grpc".to_string(),
+                        column_defs: vec![
+                            ColumnDef {
+                                name: "a".to_string(),
+                                datatype: ColumnDataType::String as _,
+                                is_nullable: true,
+                                default_constraint: vec![],
+                            },
+                            ColumnDef {
+                                name: "ts".to_string(),
+                                datatype: ColumnDataType::TimestampMillisecond as _,
+                                is_nullable: false,
+                                default_constraint: vec![],
+                            },
+                        ],
+                        time_index: "ts".to_string(),
+                        ..Default::default()
+                    })),
+                })),
+            }
+            .encode_to_vec(),
+        });
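+        // Then create a table inside the database we just created; a DDL
+        // request replies with zero affected rows.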
let output = boarding(instance, ticket).await; assert!(matches!(output, RpcOutput::AffectedRows(0))); + let ticket = Request::new(Ticket { + ticket: ObjectExpr { + request: Some(GrpcRequest::Ddl(DdlRequest { + expr: Some(DdlExpr::Alter(AlterExpr { + catalog_name: "greptime".to_string(), + schema_name: "database_created_through_grpc".to_string(), + table_name: "table_created_through_grpc".to_string(), + kind: Some(alter_expr::Kind::AddColumns(AddColumns { + add_columns: vec![AddColumn { + column_def: Some(ColumnDef { + name: "b".to_string(), + datatype: ColumnDataType::Int32 as _, + is_nullable: true, + default_constraint: vec![], + }), + is_key: false, + }], + })), + })), + })), + } + .encode_to_vec(), + }); + let output = boarding(instance, ticket).await; + assert!(matches!(output, RpcOutput::AffectedRows(0))); + + let ticket = Request::new(Ticket { + ticket: ObjectExpr { + request: Some(GrpcRequest::Query(QueryRequest { + query: Some(Query::Sql("INSERT INTO database_created_through_grpc.table_created_through_grpc (a, b, ts) VALUES ('s', 1, 1672816466000)".to_string())) + })) + }.encode_to_vec() + }); + let output = boarding(instance, ticket).await; + assert!(matches!(output, RpcOutput::AffectedRows(1))); + + let ticket = Request::new(Ticket { + ticket: ObjectExpr { + request: Some(GrpcRequest::Query(QueryRequest { + query: Some(Query::Sql("SELECT ts, a, b FROM database_created_through_grpc.table_created_through_grpc".to_string())) + })) + }.encode_to_vec() + }); + let output = boarding(instance, ticket).await; + let RpcOutput::RecordBatches(recordbatches) = output else { unreachable!() }; + let expected = "\ ++---------------------+---+---+ +| ts | a | b | ++---------------------+---+---+ +| 2023-01-04T07:14:26 | s | 1 | ++---------------------+---+---+"; + assert_eq!(recordbatches.pretty_print().unwrap(), expected); + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_distributed_insert_and_query() { + common_telemetry::init_default_ut_logging(); + + let instance = + tests::create_distributed_instance("test_distributed_insert_and_query").await; + let frontend = &instance.frontend; + + let table_name = "my_dist_table"; + let sql = format!( + r" +CREATE TABLE {table_name} ( + a INT, + ts TIMESTAMP, + TIME INDEX (ts) +) PARTITION BY RANGE COLUMNS(a) ( + PARTITION r0 VALUES LESS THAN (10), + PARTITION r1 VALUES LESS THAN (20), + PARTITION r2 VALUES LESS THAN (50), + PARTITION r3 VALUES LESS THAN (MAXVALUE), +)" + ); + create_table(frontend, sql).await; + + test_insert_and_query_on_existing_table(frontend, table_name).await; + + verify_data_distribution( + &instance, + table_name, + HashMap::from([ + ( + 0u32, + "\ ++---------------------+---+ +| ts | a | ++---------------------+---+ +| 2023-01-01T07:26:12 | 1 | +| 2023-01-01T07:26:14 | | ++---------------------+---+", + ), + ( + 1u32, + "\ ++---------------------+----+ +| ts | a | ++---------------------+----+ +| 2023-01-01T07:26:13 | 11 | ++---------------------+----+", + ), + ( + 2u32, + "\ ++---------------------+----+ +| ts | a | ++---------------------+----+ +| 2023-01-01T07:26:15 | 20 | +| 2023-01-01T07:26:16 | 22 | ++---------------------+----+", + ), + ( + 3u32, + "\ ++---------------------+----+ +| ts | a | ++---------------------+----+ +| 2023-01-01T07:26:17 | 50 | +| 2023-01-01T07:26:18 | 55 | +| 2023-01-01T07:26:19 | 99 | ++---------------------+----+", + ), + ]), + ) + .await; + + test_insert_and_query_on_auto_created_table(frontend).await; + + // Auto created table has only one region. 
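+        // Create-on-insert attaches no partition rule, so the auto created
+        // table gets a single region: every row must land in region 0.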
+        verify_data_distribution(
+            &instance,
+            "auto_created_table",
+            HashMap::from([(
+                0u32,
+                "\
++---------------------+---+
+| ts                  | a |
++---------------------+---+
+| 2023-01-01T07:26:15 | 4 |
+| 2023-01-01T07:26:16 |   |
+| 2023-01-01T07:26:17 | 6 |
+| 2023-01-01T07:26:18 |   |
+| 2023-01-01T07:26:19 |   |
+| 2023-01-01T07:26:20 |   |
++---------------------+---+",
+            )]),
+        )
+        .await;
+    }
+
+    #[tokio::test(flavor = "multi_thread")]
+    async fn test_standalone_insert_and_query() {
+        common_telemetry::init_default_ut_logging();
+
+        let standalone =
+            tests::create_standalone_instance("test_standalone_insert_and_query").await;
+        let instance = &standalone.instance;
+
+        let table_name = "my_table";
+        let sql = format!("CREATE TABLE {table_name} (a INT, ts TIMESTAMP, TIME INDEX (ts))");
+        create_table(instance, sql).await;
+
+        test_insert_and_query_on_existing_table(instance, table_name).await;
+
+        test_insert_and_query_on_auto_created_table(instance).await
+    }
+
+    async fn create_table(frontend: &Arc<Instance>, sql: String) {
+        let ticket = Request::new(Ticket {
+            ticket: ObjectExpr {
+                request: Some(GrpcRequest::Query(QueryRequest {
+                    query: Some(Query::Sql(sql)),
+                })),
+            }
+            .encode_to_vec(),
+        });
+        let output = boarding(frontend, ticket).await;
+        assert!(matches!(output, RpcOutput::AffectedRows(0)));
+    }
+
+    async fn test_insert_and_query_on_existing_table(instance: &Arc<Instance>, table_name: &str) {
         let insert = InsertRequest {
             schema_name: "public".to_string(),
-            table_name: "my_table".to_string(),
+            table_name: table_name.to_string(),
             columns: vec![
                 Column {
                     column_name: "a".to_string(),
                     values: Some(Values {
-                        i32_values: vec![1, 3],
+                        i32_values: vec![1, 11, 20, 22, 50, 55, 99],
                         ..Default::default()
                     }),
-                    null_mask: vec![2],
+                    null_mask: vec![4],
                     semantic_type: SemanticType::Field as i32,
                     datatype: ColumnDataType::Int32 as i32,
                 },
                 Column {
                     column_name: "ts".to_string(),
                     values: Some(Values {
-                        ts_millisecond_values: vec![1672557972000, 1672557973000, 1672557974000],
+                        ts_millisecond_values: vec![
+                            1672557972000,
+                            1672557973000,
+                            1672557974000,
+                            1672557975000,
+                            1672557976000,
+                            1672557977000,
+                            1672557978000,
+                            1672557979000,
+                        ],
                         ..Default::default()
                     }),
                     semantic_type: SemanticType::Timestamp as i32,
@@ -218,7 +460,7 @@
                 ..Default::default()
             },
         ],
-            row_count: 3,
+            row_count: 8,
             ..Default::default()
         };

         let ticket = Request::new(Ticket {
             ticket: ObjectExpr {
                 request: Some(GrpcRequest::Insert(insert)),
             }
             .encode_to_vec(),
         });

-        // Test inserting to exist table.
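+        // Send the insert for the existing table: the batch above carries
+        // eight rows, so the reply should be AffectedRows(8).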
         let output = boarding(instance, ticket).await;
-        assert!(matches!(output, RpcOutput::AffectedRows(3)));
+        assert!(matches!(output, RpcOutput::AffectedRows(8)));

         let ticket = Request::new(Ticket {
             ticket: ObjectExpr {
                 request: Some(GrpcRequest::Query(QueryRequest {
-                    query: Some(Query::Sql("SELECT ts, a FROM my_table".to_string())),
+                    query: Some(Query::Sql(format!(
+                        "SELECT ts, a FROM {table_name} ORDER BY ts"
+                    ))),
                 })),
             }
             .encode_to_vec(),
         });

@@ -245,15 +488,68 @@
         let output = boarding(instance, ticket).await;
         let RpcOutput::RecordBatches(recordbatches) = output else { unreachable!() };
         let expected = "\
-+---------------------+---+
-| ts                  | a |
-+---------------------+---+
-| 2023-01-01T07:26:12 | 1 |
-| 2023-01-01T07:26:13 |   |
-| 2023-01-01T07:26:14 | 3 |
-+---------------------+---+";
++---------------------+----+
+| ts                  | a  |
++---------------------+----+
+| 2023-01-01T07:26:12 | 1  |
+| 2023-01-01T07:26:13 | 11 |
+| 2023-01-01T07:26:14 |    |
+| 2023-01-01T07:26:15 | 20 |
+| 2023-01-01T07:26:16 | 22 |
+| 2023-01-01T07:26:17 | 50 |
+| 2023-01-01T07:26:18 | 55 |
+| 2023-01-01T07:26:19 | 99 |
++---------------------+----+";
         assert_eq!(recordbatches.pretty_print().unwrap(), expected);
     }

+    async fn verify_data_distribution(
+        instance: &MockDistributedInstance,
+        table_name: &str,
+        expected_distribution: HashMap<u32, &str>,
+    ) {
+        let table = instance
+            .frontend
+            .catalog_manager()
+            .table("greptime", "public", table_name)
+            .unwrap()
+            .unwrap();
+        let table = table.as_any().downcast_ref::<DistTable>().unwrap();
+
+        let TableGlobalValue { regions_id_map, .. } = table
+            .table_global_value(&TableGlobalKey {
+                catalog_name: "greptime".to_string(),
+                schema_name: "public".to_string(),
+                table_name: table_name.to_string(),
+            })
+            .await
+            .unwrap()
+            .unwrap();
+        let region_to_dn_map = regions_id_map
+            .iter()
+            .map(|(k, v)| (v[0], *k))
+            .collect::<HashMap<u32, u64>>();
+        assert_eq!(region_to_dn_map.len(), expected_distribution.len());
+
+        for (region, dn) in region_to_dn_map.iter() {
+            let dn = instance.datanodes.get(dn).unwrap();
+            let output = dn
+                .execute_sql(
+                    &format!("SELECT ts, a FROM {table_name} ORDER BY ts"),
+                    QueryContext::arc(),
+                )
+                .await
+                .unwrap();
+            let Output::Stream(stream) = output else { unreachable!() };
+            let recordbatches = RecordBatches::try_collect(stream).await.unwrap();
+            let actual = recordbatches.pretty_print().unwrap();
+
+            let expected = expected_distribution.get(region).unwrap();
+            assert_eq!(&actual, expected);
+        }
+    }
+
+    async fn test_insert_and_query_on_auto_created_table(instance: &Arc<Instance>) {
         let insert = InsertRequest {
             schema_name: "public".to_string(),
             table_name: "auto_created_table".to_string(),
diff --git a/src/frontend/src/instance/grpc.rs b/src/frontend/src/instance/grpc.rs
index ecac1e5058..40169b173d 100644
--- a/src/frontend/src/instance/grpc.rs
+++ b/src/frontend/src/instance/grpc.rs
@@ -12,7 +12,6 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
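+//! The frontend [`Instance`]'s gRPC entry: each incoming [`ObjectExpr`] is
+//! protobuf-encoded into a Flight [`Ticket`] and handed to the `boarding`
+//! helper, so queries, inserts and DDL all travel the same Arrow Flight
+//! `do_get` path.
+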
-use api::v1::object_expr::Request as GrpcRequest;
 use api::v1::{ObjectExpr, ObjectResult};
 use arrow_flight::flight_service_server::FlightService;
 use arrow_flight::Ticket;
@@ -22,7 +21,7 @@
 use common_grpc::flight;
 use prost::Message;
 use servers::error as server_error;
 use servers::query_handler::GrpcQueryHandler;
-use snafu::{OptionExt, ResultExt};
+use snafu::ResultExt;
 use tonic::Request;

 use crate::error::{FlightGetSnafu, InvalidFlightDataSnafu, Result};
@@ -40,29 +39,15 @@ impl Instance {
 #[async_trait]
 impl GrpcQueryHandler for Instance {
     async fn do_query(&self, query: ObjectExpr) -> server_error::Result<ObjectResult> {
-        let request = query
-            .clone()
-            .request
-            .context(server_error::InvalidQuerySnafu {
-                reason: "empty expr",
-            })?;
-        match request {
-            // TODO(LFC): Unify to "boarding" when do_get supports DDL requests.
-            GrpcRequest::Ddl(_) => {
-                GrpcQueryHandler::do_query(&*self.grpc_query_handler, query).await
-            }
-            _ => {
-                let ticket = Request::new(Ticket {
-                    ticket: query.encode_to_vec(),
-                });
-                // TODO(LFC): Temporarily use old GRPC interface here, will get rid of them near the end of Arrow Flight adoption.
-                self.boarding(ticket)
-                    .await
-                    .map_err(BoxedError::new)
-                    .with_context(|_| servers::error::ExecuteQuerySnafu {
-                        query: format!("{query:?}"),
-                    })
-            }
-        }
+        let ticket = Request::new(Ticket {
+            ticket: query.encode_to_vec(),
+        });
+        // TODO(LFC): Temporarily use old GRPC interface here, will get rid of them near the end of Arrow Flight adoption.
+        self.boarding(ticket)
+            .await
+            .map_err(BoxedError::new)
+            .with_context(|_| servers::error::ExecuteQuerySnafu {
+                query: format!("{query:?}"),
+            })
     }
 }
diff --git a/src/frontend/src/instance/opentsdb.rs b/src/frontend/src/instance/opentsdb.rs
index e8d965b5c9..2545b2c4e6 100644
--- a/src/frontend/src/instance/opentsdb.rs
+++ b/src/frontend/src/instance/opentsdb.rs
@@ -72,7 +72,8 @@ mod tests {
     #[tokio::test(flavor = "multi_thread")]
     async fn test_exec() {
-        let (instance, _guard) = tests::create_standalone_instance("test_exec").await;
+        let standalone = tests::create_standalone_instance("test_exec").await;
+        let instance = standalone.instance;
         instance
             .exec(
                 &DataPoint::try_create(
@@ -90,8 +91,8 @@
     #[tokio::test(flavor = "multi_thread")]
     async fn test_insert_opentsdb_metric() {
-        let (instance, _guard) =
-            tests::create_standalone_instance("test_insert_opentsdb_metric").await;
+        let standalone = tests::create_standalone_instance("test_insert_opentsdb_metric").await;
+        let instance = standalone.instance;

         let data_point1 = DataPoint::new(
             "my_metric_1".to_string(),
diff --git a/src/frontend/src/instance/prometheus.rs b/src/frontend/src/instance/prometheus.rs
index 58b071d311..dc4cc71e62 100644
--- a/src/frontend/src/instance/prometheus.rs
+++ b/src/frontend/src/instance/prometheus.rs
@@ -177,9 +177,9 @@ mod tests {
     #[tokio::test(flavor = "multi_thread")]
     async fn test_prometheus_remote_write_and_read() {
-        common_telemetry::init_default_ut_logging();
-        let (instance, _guard) =
+        let standalone =
             tests::create_standalone_instance("test_prometheus_remote_write_and_read").await;
+        let instance = standalone.instance;

         let write_request = WriteRequest {
             timeseries: prometheus::mock_timeseries(),
diff --git a/src/frontend/src/sql.rs b/src/frontend/src/sql.rs
index 0e85d8f0a7..8051c49a41 100644
--- a/src/frontend/src/sql.rs
+++ b/src/frontend/src/sql.rs
@@ -12,7 +12,6 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
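+//! SQL-to-request helpers shared by the frontend. `insert_to_request` builds a
+//! table [`InsertRequest`] from a parsed SQL INSERT; the caller resolves the
+//! target [`TableRef`] through the catalog first (see
+//! `DistInstance::handle_statement`).
+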
-use catalog::SchemaProviderRef;
 use common_error::snafu::ensure;
 use datatypes::data_type::DataType;
 use datatypes::prelude::{ConcreteDataType, MutableVector};
@@ -21,26 +20,18 @@ use sql::ast::Value as SqlValue;
 use sql::statements;
 use sql::statements::insert::Insert;
 use table::requests::InsertRequest;
+use table::TableRef;

 use crate::error::{self, BuildVectorSnafu, Result};

 // TODO(fys): Extract the common logic in datanode and frontend in the future.
 #[allow(dead_code)]
-pub(crate) fn insert_to_request(
-    schema_provider: &SchemaProviderRef,
-    stmt: Insert,
-) -> Result<InsertRequest> {
+pub(crate) fn insert_to_request(table: &TableRef, stmt: Insert) -> Result<InsertRequest> {
     let columns = stmt.columns();
     let values = stmt.values().context(error::ParseSqlSnafu)?;
     let (catalog_name, schema_name, table_name) =
         stmt.full_table_name().context(error::ParseSqlSnafu)?;

-    let table = schema_provider
-        .table(&table_name)
-        .context(error::CatalogSnafu)?
-        .context(error::TableNotFoundSnafu {
-            table_name: &table_name,
-        })?;
     let schema = table.schema();
     let columns_num = if columns.is_empty() {
         schema.column_schemas().len()
diff --git a/src/frontend/src/table.rs b/src/frontend/src/table.rs
index b3bc609c4b..9cb5773a2c 100644
--- a/src/frontend/src/table.rs
+++ b/src/frontend/src/table.rs
@@ -385,7 +385,10 @@ impl DistTable {
         Ok(partition_rule)
     }

-    async fn table_global_value(&self, key: &TableGlobalKey) -> Result<Option<TableGlobalValue>> {
+    pub(crate) async fn table_global_value(
+        &self,
+        key: &TableGlobalKey,
+    ) -> Result<Option<TableGlobalValue>> {
         let raw = self
             .backend
             .get(key.to_string().as_bytes())
@@ -1027,7 +1030,7 @@ mod test {
         let schema = Arc::new(Schema::new(column_schemas.clone()));

         let instance = crate::tests::create_distributed_instance(test_name).await;
-        let dist_instance = instance.frontend.dist_instance.as_ref().unwrap();
+        let dist_instance = &instance.dist_instance;
         let datanode_instances = instance.datanodes;

         let catalog_manager = dist_instance.catalog_manager();
diff --git a/src/frontend/src/tests.rs b/src/frontend/src/tests.rs
index 2274839e81..6a39671034 100644
--- a/src/frontend/src/tests.rs
+++ b/src/frontend/src/tests.rs
@@ -47,19 +47,29 @@ pub struct TestGuard {
     _data_tmp_dir: TempDir,
 }

-pub(crate) struct MockDistributedInstances {
+pub(crate) struct MockDistributedInstance {
     pub(crate) frontend: Arc<Instance>,
+    pub(crate) dist_instance: Arc<DistInstance>,
     pub(crate) datanodes: HashMap<u64, Arc<DatanodeInstance>>,
     _guards: Vec<TestGuard>,
 }

-pub(crate) async fn create_standalone_instance(test_name: &str) -> (Arc<Instance>, TestGuard) {
+pub(crate) struct MockStandaloneInstance {
+    pub(crate) instance: Arc<Instance>,
+    _guard: TestGuard,
+}
+
+pub(crate) async fn create_standalone_instance(test_name: &str) -> MockStandaloneInstance {
     let (opts, guard) = create_tmp_dir_and_datanode_opts(test_name);
     let datanode_instance = DatanodeInstance::new(&opts).await.unwrap();
     datanode_instance.start().await.unwrap();

     let frontend_instance = Instance::new_standalone(Arc::new(datanode_instance));
-    (Arc::new(frontend_instance), guard)
+
+    MockStandaloneInstance {
+        instance: Arc::new(frontend_instance),
+        _guard: guard,
+    }
 }

 fn create_tmp_dir_and_datanode_opts(name: &str) -> (DatanodeOptions, TestGuard) {
@@ -182,7 +192,7 @@ async fn wait_datanodes_alive(kv_store: KvStoreRef) {
     panic!()
 }

-pub(crate) async fn create_distributed_instance(test_name: &str) -> MockDistributedInstances {
+pub(crate) async fn create_distributed_instance(test_name: &str) -> MockDistributedInstance {
     let kv_store: KvStoreRef = Arc::new(MemStore::default()) as _;
     let meta_srv = meta_srv::mocks::mock(MetaSrvOptions::default(), kv_store.clone(), None).await;
@@ -233,10 +243,12 @@
         catalog_manager,
         datanode_clients.clone(),
     );
-    let frontend = Instance::new_distributed(dist_instance);
+    let dist_instance = Arc::new(dist_instance);
+    let frontend = Instance::new_distributed(dist_instance.clone());

-    MockDistributedInstances {
+    MockDistributedInstance {
         frontend: Arc::new(frontend),
+        dist_instance,
         datanodes: datanode_instances,
         _guards: test_guards,
     }