From d27b9fc3a1f74450edfef7fc7ecb102a3436e7ca Mon Sep 17 00:00:00 2001 From: LFC <990479+MichaelScofield@users.noreply.github.com> Date: Thu, 17 Apr 2025 11:46:19 +0800 Subject: [PATCH] feat: implement Arrow Flight "DoPut" in Frontend (#5836) * feat: implement Arrow Flight "DoPut" in Frontend * support auth for "do_put" * set request_id in DoPut requests and responses * set "db" in request header --- Cargo.lock | 4 + src/client/Cargo.toml | 2 + src/client/src/database.rs | 76 ++++++- src/client/src/error.rs | 13 +- src/client/src/lib.rs | 2 +- src/common/grpc/Cargo.toml | 2 + src/common/grpc/src/error.rs | 11 +- src/common/grpc/src/flight.rs | 2 + src/common/grpc/src/flight/do_put.rs | 93 +++++++++ src/frontend/src/instance/grpc.rs | 36 +++- src/servers/src/grpc/flight.rs | 193 ++++++++++++++++- src/servers/src/grpc/greptime_handler.rs | 107 +++++++++- src/servers/src/http.rs | 15 -- src/servers/src/query_handler/grpc.rs | 22 ++ src/servers/tests/http/influxdb_test.rs | 15 -- src/servers/tests/http/opentsdb_test.rs | 15 -- src/servers/tests/http/prom_store_test.rs | 15 -- src/servers/tests/mod.rs | 14 +- src/table/src/error.rs | 6 +- src/table/src/table_name.rs | 38 ++++ tests-integration/src/cluster.rs | 39 ++-- tests-integration/src/grpc.rs | 56 ++--- tests-integration/src/grpc/flight.rs | 242 ++++++++++++++++++++++ tests-integration/src/test_util.rs | 36 +++- tests-integration/src/tests/test_util.rs | 9 +- tests-integration/tests/grpc.rs | 66 +++--- 26 files changed, 944 insertions(+), 185 deletions(-) create mode 100644 src/common/grpc/src/flight/do_put.rs create mode 100644 tests-integration/src/grpc/flight.rs diff --git a/Cargo.lock b/Cargo.lock index b133daa143..cb9ad31c15 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1718,6 +1718,7 @@ dependencies = [ "arrow-flight", "async-stream", "async-trait", + "base64 0.22.1", "common-catalog", "common-error", "common-grpc", @@ -1728,6 +1729,7 @@ dependencies = [ "common-recordbatch", "common-telemetry", "enum_dispatch", + "futures", "futures-util", "lazy_static", "moka", @@ -2097,6 +2099,8 @@ dependencies = [ "lazy_static", "prost 0.13.5", "rand 0.9.0", + "serde", + "serde_json", "snafu 0.8.5", "tokio", "tokio-util", diff --git a/src/client/Cargo.toml b/src/client/Cargo.toml index f8702fe6ac..99d0c97806 100644 --- a/src/client/Cargo.toml +++ b/src/client/Cargo.toml @@ -16,6 +16,7 @@ arc-swap = "1.6" arrow-flight.workspace = true async-stream.workspace = true async-trait.workspace = true +base64.workspace = true common-catalog.workspace = true common-error.workspace = true common-grpc.workspace = true @@ -25,6 +26,7 @@ common-query.workspace = true common-recordbatch.workspace = true common-telemetry.workspace = true enum_dispatch = "0.3" +futures.workspace = true futures-util.workspace = true lazy_static.workspace = true moka = { workspace = true, features = ["future"] } diff --git a/src/client/src/database.rs b/src/client/src/database.rs index 2479240562..c9dc9b08e5 100644 --- a/src/client/src/database.rs +++ b/src/client/src/database.rs @@ -12,36 +12,49 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+use std::pin::Pin;
+use std::str::FromStr;
+
 use api::v1::auth_header::AuthScheme;
 use api::v1::ddl_request::Expr as DdlExpr;
 use api::v1::greptime_database_client::GreptimeDatabaseClient;
 use api::v1::greptime_request::Request;
 use api::v1::query_request::Query;
 use api::v1::{
-    AlterTableExpr, AuthHeader, CreateTableExpr, DdlRequest, GreptimeRequest, InsertRequests,
-    QueryRequest, RequestHeader,
+    AlterTableExpr, AuthHeader, Basic, CreateTableExpr, DdlRequest, GreptimeRequest,
+    InsertRequests, QueryRequest, RequestHeader,
 };
-use arrow_flight::Ticket;
+use arrow_flight::{FlightData, Ticket};
 use async_stream::stream;
+use base64::prelude::BASE64_STANDARD;
+use base64::Engine;
+use common_catalog::build_db_string;
+use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
 use common_error::ext::{BoxedError, ErrorExt};
+use common_grpc::flight::do_put::DoPutResponse;
 use common_grpc::flight::{FlightDecoder, FlightMessage};
 use common_query::Output;
 use common_recordbatch::error::ExternalSnafu;
 use common_recordbatch::RecordBatchStreamWrapper;
 use common_telemetry::error;
 use common_telemetry::tracing_context::W3cTrace;
-use futures_util::StreamExt;
+use futures::future;
+use futures_util::{Stream, StreamExt, TryStreamExt};
 use prost::Message;
 use snafu::{ensure, ResultExt};
-use tonic::metadata::AsciiMetadataKey;
+use tonic::metadata::{AsciiMetadataKey, MetadataValue};
 use tonic::transport::Channel;
 
 use crate::error::{
     ConvertFlightDataSnafu, Error, FlightGetSnafu, IllegalFlightMessagesSnafu, InvalidAsciiSnafu,
-    ServerSnafu,
+    InvalidTonicMetadataValueSnafu, ServerSnafu,
 };
 use crate::{from_grpc_response, Client, Result};
 
+type FlightDataStream = Pin<Box<dyn Stream<Item = FlightData> + Send>>;
+
+type DoPutResponseStream = Pin<Box<dyn Stream<Item = Result<DoPutResponse>>>>;
+
 #[derive(Clone, Debug, Default)]
 pub struct Database {
     // The "catalog" and "schema" to be used in processing the requests at the server side.
@@ -108,16 +121,24 @@ impl Database {
         self.catalog = catalog.into();
     }
 
-    pub fn catalog(&self) -> &String {
-        &self.catalog
+    fn catalog_or_default(&self) -> &str {
+        if self.catalog.is_empty() {
+            DEFAULT_CATALOG_NAME
+        } else {
+            &self.catalog
+        }
     }
 
     pub fn set_schema(&mut self, schema: impl Into<String>) {
         self.schema = schema.into();
     }
 
-    pub fn schema(&self) -> &String {
-        &self.schema
+    fn schema_or_default(&self) -> &str {
+        if self.schema.is_empty() {
+            DEFAULT_SCHEMA_NAME
+        } else {
+            &self.schema
+        }
     }
 
     pub fn set_timezone(&mut self, timezone: impl Into<String>) {
@@ -310,6 +331,41 @@ impl Database {
             }
         }
     }
+
+    /// Ingest a stream of [RecordBatch]es that belong to a table, using Arrow Flight's "`DoPut`"
+    /// method. The return value is also a stream, which produces [DoPutResponse]s.
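+    ///
+    /// The first [FlightData] in the stream must carry the target table name in its flight
+    /// descriptor (descriptor type "Path", with exactly one path element). Each message can
+    /// also carry a JSON-encoded [DoPutMetadata] in its `app_metadata`, whose "request_id" is
+    /// echoed back in the corresponding [DoPutResponse].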
+    pub async fn do_put(&self, stream: FlightDataStream) -> Result<DoPutResponseStream> {
+        let mut request = tonic::Request::new(stream);
+
+        if let Some(AuthHeader {
+            auth_scheme: Some(AuthScheme::Basic(Basic { username, password })),
+        }) = &self.ctx.auth_header
+        {
+            let encoded = BASE64_STANDARD.encode(format!("{username}:{password}"));
+            let value =
+                MetadataValue::from_str(&encoded).context(InvalidTonicMetadataValueSnafu)?;
+            request.metadata_mut().insert("x-greptime-auth", value);
+        }
+
+        let db_to_put = if !self.dbname.is_empty() {
+            &self.dbname
+        } else {
+            &build_db_string(self.catalog_or_default(), self.schema_or_default())
+        };
+        request.metadata_mut().insert(
+            "x-greptime-db-name",
+            MetadataValue::from_str(db_to_put).context(InvalidTonicMetadataValueSnafu)?,
+        );
+
+        let mut client = self.client.make_flight_client()?;
+        let response = client.mut_inner().do_put(request).await?;
+        let response = response
+            .into_inner()
+            .map_err(Into::into)
+            .and_then(|x| future::ready(DoPutResponse::try_from(x).context(ConvertFlightDataSnafu)))
+            .boxed();
+        Ok(response)
+    }
 }
 
 #[derive(Default, Debug, Clone)]
diff --git a/src/client/src/error.rs b/src/client/src/error.rs
index ed0c1b5ccb..3f680b1427 100644
--- a/src/client/src/error.rs
+++ b/src/client/src/error.rs
@@ -19,6 +19,7 @@ use common_error::status_code::{convert_tonic_code_to_status_code, StatusCode};
 use common_error::{GREPTIME_DB_HEADER_ERROR_CODE, GREPTIME_DB_HEADER_ERROR_MSG};
 use common_macro::stack_trace_debug;
 use snafu::{location, Location, Snafu};
+use tonic::metadata::errors::InvalidMetadataValue;
 use tonic::{Code, Status};
 
 #[derive(Snafu)]
@@ -115,6 +116,14 @@ pub enum Error {
         #[snafu(implicit)]
         location: Location,
     },
+
+    #[snafu(display("Invalid Tonic metadata value"))]
+    InvalidTonicMetadataValue {
+        #[snafu(source)]
+        error: InvalidMetadataValue,
+        #[snafu(implicit)]
+        location: Location,
+    },
 }
 
 pub type Result<T> = std::result::Result<T, Error>;
 
@@ -135,7 +144,9 @@ impl ErrorExt for Error {
             | Error::CreateTlsChannel { source, .. } => source.status_code(),
 
             Error::IllegalGrpcClientState { .. } => StatusCode::Unexpected,
 
-            Error::InvalidAscii { .. } => StatusCode::InvalidArguments,
+            Error::InvalidAscii { .. } | Error::InvalidTonicMetadataValue { .. } => {
+                StatusCode::InvalidArguments
+            }
         }
     }
 
diff --git a/src/client/src/lib.rs b/src/client/src/lib.rs
index 125c185a5a..7078e71795 100644
--- a/src/client/src/lib.rs
+++ b/src/client/src/lib.rs
@@ -16,7 +16,7 @@
 
 mod client;
 pub mod client_manager;
-mod database;
+pub mod database;
 pub mod error;
 pub mod flow;
 pub mod load_balance;
diff --git a/src/common/grpc/Cargo.toml b/src/common/grpc/Cargo.toml
index 4dadf0571b..f15d0761d1 100644
--- a/src/common/grpc/Cargo.toml
+++ b/src/common/grpc/Cargo.toml
@@ -23,6 +23,8 @@ flatbuffers = "24"
 hyper.workspace = true
 lazy_static.workspace = true
 prost.workspace = true
+serde.workspace = true
+serde_json.workspace = true
 snafu.workspace = true
 tokio.workspace = true
 tokio-util.workspace = true
diff --git a/src/common/grpc/src/error.rs b/src/common/grpc/src/error.rs
index d0ca7d970c..af194f2501 100644
--- a/src/common/grpc/src/error.rs
+++ b/src/common/grpc/src/error.rs
@@ -97,6 +97,14 @@ pub enum Error {
 
     #[snafu(display("Not supported: {}", feat))]
     NotSupported { feat: String },
+
+    #[snafu(display("Failed to serde Json"))]
+    SerdeJson {
+        #[snafu(source)]
+        error: serde_json::error::Error,
+        #[snafu(implicit)]
+        location: Location,
+    },
 }
 
 impl ErrorExt for Error {
@@ -110,7 +118,8 @@ impl ErrorExt for Error {
 
             Error::CreateChannel { .. }
             | Error::Conversion { .. }
-            | Error::DecodeFlightData { .. } => StatusCode::Internal,
+            | Error::DecodeFlightData { .. }
+            | Error::SerdeJson { .. } => StatusCode::Internal,
 
             Error::CreateRecordBatch { source, .. } => source.status_code(),
             Error::ConvertArrowSchema { source, .. } => source.status_code(),
diff --git a/src/common/grpc/src/flight.rs b/src/common/grpc/src/flight.rs
index 26f3676ce1..872897ccbf 100644
--- a/src/common/grpc/src/flight.rs
+++ b/src/common/grpc/src/flight.rs
@@ -12,6 +12,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+pub mod do_put;
+
 use std::collections::HashMap;
 use std::sync::Arc;
 
diff --git a/src/common/grpc/src/flight/do_put.rs b/src/common/grpc/src/flight/do_put.rs
new file mode 100644
index 0000000000..15011fc74b
--- /dev/null
+++ b/src/common/grpc/src/flight/do_put.rs
@@ -0,0 +1,93 @@
+// Copyright 2023 Greptime Team
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use arrow_flight::PutResult;
+use common_base::AffectedRows;
+use serde::{Deserialize, Serialize};
+use snafu::ResultExt;
+
+use crate::error::{Error, SerdeJsonSnafu};
+
+/// The metadata for "DoPut" requests and responses.
+///
+/// Currently there's only a "request_id", used to correlate requests and responses in the
+/// streams. A client can set a unique request id in this metadata, and the server will return
+/// the same id in the corresponding response. In doing so, a client can match each response to
+/// its pending requests.
+#[derive(Serialize, Deserialize)]
+pub struct DoPutMetadata {
+    request_id: i64,
+}
+
+impl DoPutMetadata {
+    pub fn new(request_id: i64) -> Self {
+        Self { request_id }
+    }
+
+    pub fn request_id(&self) -> i64 {
+        self.request_id
+    }
+}
+
+/// The response in the "DoPut" returned stream.
+#[derive(Serialize, Deserialize)]
+pub struct DoPutResponse {
+    /// The same "request_id" as in the request; see [DoPutMetadata].
+    request_id: i64,
+    /// The number of successfully ingested rows.
+    affected_rows: AffectedRows,
+}
+
+impl DoPutResponse {
+    pub fn new(request_id: i64, affected_rows: AffectedRows) -> Self {
+        Self {
+            request_id,
+            affected_rows,
+        }
+    }
+
+    pub fn request_id(&self) -> i64 {
+        self.request_id
+    }
+
+    pub fn affected_rows(&self) -> AffectedRows {
+        self.affected_rows
+    }
+}
+
+impl TryFrom<PutResult> for DoPutResponse {
+    type Error = Error;
+
+    fn try_from(value: PutResult) -> Result<Self, Self::Error> {
+        serde_json::from_slice(&value.app_metadata).context(SerdeJsonSnafu)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_serde_do_put_metadata() {
+        let serialized = r#"{"request_id":42}"#;
+        let metadata = serde_json::from_str::<DoPutMetadata>(serialized).unwrap();
+        assert_eq!(metadata.request_id(), 42);
+    }
+
+    #[test]
+    fn test_serde_do_put_response() {
+        let x = DoPutResponse::new(42, 88);
+        let serialized = serde_json::to_string(&x).unwrap();
+        assert_eq!(serialized, r#"{"request_id":42,"affected_rows":88}"#);
+    }
+}
diff --git a/src/frontend/src/instance/grpc.rs b/src/frontend/src/instance/grpc.rs
index decd713555..915d884d7e 100644
--- a/src/frontend/src/instance/grpc.rs
+++ b/src/frontend/src/instance/grpc.rs
@@ -18,12 +18,13 @@ use api::v1::query_request::Query;
 use api::v1::{DeleteRequests, DropFlowExpr, InsertRequests, RowDeleteRequests, RowInsertRequests};
 use async_trait::async_trait;
 use auth::{PermissionChecker, PermissionCheckerRef, PermissionReq};
+use common_base::AffectedRows;
 use common_query::Output;
 use common_telemetry::tracing::{self};
 use datafusion::execution::SessionStateBuilder;
 use query::parser::PromQuery;
 use servers::interceptor::{GrpcQueryInterceptor, GrpcQueryInterceptorRef};
-use servers::query_handler::grpc::GrpcQueryHandler;
+use servers::query_handler::grpc::{GrpcQueryHandler, RawRecordBatch};
 use servers::query_handler::sql::SqlQueryHandler;
 use session::context::QueryContextRef;
 use snafu::{ensure, OptionExt, ResultExt};
@@ -31,8 +32,9 @@ use substrait::{DFLogicalSubstraitConvertor, SubstraitPlan};
 use table::table_name::TableName;
 
 use crate::error::{
-    Error, InFlightWriteBytesExceededSnafu, IncompleteGrpcRequestSnafu, NotSupportedSnafu,
-    PermissionSnafu, Result, SubstraitDecodeLogicalPlanSnafu, TableOperationSnafu,
+    CatalogSnafu, Error, InFlightWriteBytesExceededSnafu, IncompleteGrpcRequestSnafu,
+    NotSupportedSnafu, PermissionSnafu, Result, SubstraitDecodeLogicalPlanSnafu,
+    TableNotFoundSnafu, TableOperationSnafu,
 };
 use crate::instance::{attach_timer, Instance};
 use crate::metrics::{
@@ -203,6 +205,34 @@ impl GrpcQueryHandler for Instance {
         let output = interceptor.post_execute(output, ctx)?;
         Ok(output)
     }
+
+    async fn put_record_batch(
+        &self,
+        table: &TableName,
+        record_batch: RawRecordBatch,
+    ) -> Result<AffectedRows> {
+        let _table = self
+            .catalog_manager()
+            .table(
+                &table.catalog_name,
+                &table.schema_name,
+                &table.table_name,
+                None,
+            )
+            .await
+            .context(CatalogSnafu)?
+            .with_context(|| TableNotFoundSnafu {
+                table_name: table.to_string(),
+            })?;
+
+        // TODO(LFC): Implement it.
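+        // For now this is a placeholder: the table's existence is checked above, while the
+        // data itself is only logged, and the raw payload size is echoed back as the
+        // "affected rows".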
+        common_telemetry::debug!(
+            "calling put_record_batch with table: {:?} and record_batch size: {}",
+            table,
+            record_batch.len()
+        );
+        Ok(record_batch.len())
+    }
 }
 
 fn fill_catalog_and_schema_from_context(ddl_expr: &mut DdlExpr, ctx: &QueryContextRef) {
diff --git a/src/servers/src/grpc/flight.rs b/src/servers/src/grpc/flight.rs
index 76a6cc00ce..648cfff377 100644
--- a/src/servers/src/grpc/flight.rs
+++ b/src/servers/src/grpc/flight.rs
@@ -16,6 +16,7 @@ mod stream;
 
 use std::pin::Pin;
 use std::sync::Arc;
+use std::task::{Context, Poll};
 
 use api::v1::GreptimeRequest;
 use arrow_flight::flight_service_server::FlightService;
@@ -24,21 +25,33 @@ use arrow_flight::{
     HandshakeRequest, HandshakeResponse, PollInfo, PutResult, SchemaResult, Ticket,
 };
 use async_trait::async_trait;
+use bytes::Bytes;
+use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
+use common_catalog::parse_catalog_and_schema_from_db_string;
+use common_grpc::flight::do_put::{DoPutMetadata, DoPutResponse};
 use common_grpc::flight::{FlightEncoder, FlightMessage};
 use common_query::{Output, OutputData};
 use common_telemetry::tracing::info_span;
 use common_telemetry::tracing_context::{FutureExt, TracingContext};
-use futures::Stream;
+use futures::{future, ready, Stream};
+use futures_util::{StreamExt, TryStreamExt};
 use prost::Message;
-use snafu::ResultExt;
+use snafu::{ensure, ResultExt};
+use table::table_name::TableName;
+use tokio::sync::mpsc;
+use tokio_stream::wrappers::ReceiverStream;
 use tonic::{Request, Response, Status, Streaming};
 
 use crate::error;
+use crate::error::{InvalidParameterSnafu, ParseJsonSnafu, Result, ToJsonSnafu};
 pub use crate::grpc::flight::stream::FlightRecordBatchStream;
 use crate::grpc::greptime_handler::{get_request_type, GreptimeRequestHandler};
 use crate::grpc::TonicResult;
+use crate::http::header::constants::GREPTIME_DB_HEADER_NAME;
+use crate::http::AUTHORIZATION_HEADER;
+use crate::query_handler::grpc::RawRecordBatch;
 
-pub type TonicStream<T> = Pin<Box<dyn Stream<Item = TonicResult<T>> + Send + Sync + 'static>>;
+pub type TonicStream<T> = Pin<Box<dyn Stream<Item = TonicResult<T>> + Send + 'static>>;
 
 /// A subset of [FlightService]
 #[async_trait]
 pub trait FlightCraft: Send + Sync + 'static {
     async fn do_get(
         &self,
         request: Request<Ticket>,
     ) -> TonicResult<Response<TonicStream<FlightData>>>;
+
+    async fn do_put(
+        &self,
+        request: Request<Streaming<FlightData>>,
+    ) -> TonicResult<Response<TonicStream<PutResult>>> {
+        let _ = request;
+        Err(Status::unimplemented("Not yet implemented"))
+    }
 }
 
 pub type FlightCraftRef = Arc<dyn FlightCraft>;
 
 #[async_trait]
@@ -67,6 +88,13 @@ impl FlightCraft for FlightCraftRef {
     ) -> TonicResult<Response<TonicStream<FlightData>>> {
         (**self).do_get(request).await
     }
+
+    async fn do_put(
+        &self,
+        request: Request<Streaming<FlightData>>,
+    ) -> TonicResult<Response<TonicStream<PutResult>>> {
+        self.as_ref().do_put(request).await
+    }
 }
 
 #[async_trait]
@@ -120,9 +148,9 @@ impl<T: FlightCraft> FlightService for FlightCraftWrapper<T> {
 
     async fn do_put(
         &self,
-        _: Request<Streaming<FlightData>>,
+        request: Request<Streaming<FlightData>>,
     ) -> TonicResult<Response<Self::DoPutStream>> {
-        Err(Status::unimplemented("Not yet implemented"))
+        self.0.do_put(request).await
     }
 
     type DoExchangeStream = TonicStream<FlightData>;
@@ -168,13 +196,164 @@ impl FlightCraft for GreptimeRequestHandler {
         );
         async {
             let output = self.handle_request(request, Default::default()).await?;
-            let stream: Pin<Box<dyn Stream<Item = Result<FlightData, Status>> + Send + Sync>> =
-                to_flight_data_stream(output, TracingContext::from_current_span());
+            let stream = to_flight_data_stream(output, TracingContext::from_current_span());
             Ok(Response::new(stream))
         }
         .trace(span)
         .await
     }
+
+    async fn do_put(
+        &self,
+        request: Request<Streaming<FlightData>>,
+    ) -> TonicResult<Response<TonicStream<PutResult>>> {
+        let (headers, _, stream) = request.into_parts();
+
+        let header = |key: &str| -> TonicResult<Option<&str>> {
+            let Some(v) = headers.get(key) else {
+                return Ok(None);
+            };
+            let Ok(v) = std::str::from_utf8(v.as_bytes()) else {
+                return Err(InvalidParameterSnafu {
+                    reason: "expect valid UTF-8 value",
+                }
+                .build()
+                .into());
+            };
+            Ok(Some(v))
+        };
+
+        let username_and_password = header(AUTHORIZATION_HEADER)?;
+        let db = header(GREPTIME_DB_HEADER_NAME)?;
+        if !self.validate_auth(username_and_password, db).await? {
+            return Err(Status::unauthenticated("auth failed"));
+        }
+
+        const MAX_PENDING_RESPONSES: usize = 32;
+        let (tx, rx) = mpsc::channel::<TonicResult<DoPutResponse>>(MAX_PENDING_RESPONSES);
+
+        let stream = PutRecordBatchRequestStream {
+            flight_data_stream: stream,
+            state: PutRecordBatchRequestStreamState::Init(db.map(ToString::to_string)),
+        };
+        self.put_record_batches(stream, tx).await;
+
+        let response = ReceiverStream::new(rx)
+            .and_then(|response| {
+                future::ready({
+                    serde_json::to_vec(&response)
+                        .context(ToJsonSnafu)
+                        .map(|x| PutResult {
+                            app_metadata: Bytes::from(x),
+                        })
+                        .map_err(Into::into)
+                })
+            })
+            .boxed();
+        Ok(Response::new(response))
+    }
+}
+
+pub(crate) struct PutRecordBatchRequest {
+    pub(crate) table_name: TableName,
+    pub(crate) request_id: i64,
+    pub(crate) record_batch: RawRecordBatch,
+}
+
+impl PutRecordBatchRequest {
+    fn try_new(table_name: TableName, flight_data: FlightData) -> Result<Self> {
+        let metadata: DoPutMetadata =
+            serde_json::from_slice(&flight_data.app_metadata).context(ParseJsonSnafu)?;
+        Ok(Self {
+            table_name,
+            request_id: metadata.request_id(),
+            record_batch: flight_data.data_body,
+        })
+    }
+}
+
+pub(crate) struct PutRecordBatchRequestStream {
+    flight_data_stream: Streaming<FlightData>,
+    state: PutRecordBatchRequestStreamState,
+}
+
+enum PutRecordBatchRequestStreamState {
+    Init(Option<String>),
+    Started(TableName),
+}
+
+impl Stream for PutRecordBatchRequestStream {
+    type Item = TonicResult<PutRecordBatchRequest>;
+
+    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        fn extract_table_name(mut descriptor: FlightDescriptor) -> Result<String> {
+            ensure!(
+                descriptor.r#type == arrow_flight::flight_descriptor::DescriptorType::Path as i32,
+                InvalidParameterSnafu {
+                    reason: "expect FlightDescriptor::type == 'Path' only",
+                }
+            );
+            ensure!(
+                descriptor.path.len() == 1,
+                InvalidParameterSnafu {
+                    reason: "expect FlightDescriptor::path has only one table name",
+                }
+            );
+            Ok(descriptor.path.remove(0))
+        }
+
+        let poll = ready!(self.flight_data_stream.poll_next_unpin(cx));
+
+        let result = match &mut self.state {
+            PutRecordBatchRequestStreamState::Init(db) => match poll {
+                Some(Ok(mut flight_data)) => {
+                    let flight_descriptor = flight_data.flight_descriptor.take();
+                    let result = if let Some(descriptor) = flight_descriptor {
+                        let table_name = extract_table_name(descriptor).map(|x| {
+                            let (catalog, schema) = if let Some(db) = db {
+                                parse_catalog_and_schema_from_db_string(db)
+                            } else {
+                                (
+                                    DEFAULT_CATALOG_NAME.to_string(),
+                                    DEFAULT_SCHEMA_NAME.to_string(),
+                                )
+                            };
+                            TableName::new(catalog, schema, x)
+                        });
+                        let table_name = match table_name {
+                            Ok(table_name) => table_name,
+                            Err(e) => return Poll::Ready(Some(Err(e.into()))),
+                        };
+
+                        let request =
+                            PutRecordBatchRequest::try_new(table_name.clone(), flight_data);
+                        let request = match request {
+                            Ok(request) => request,
+                            Err(e) => return Poll::Ready(Some(Err(e.into()))),
+                        };
+
+                        self.state = PutRecordBatchRequestStreamState::Started(table_name);
+
+                        Ok(request)
+                    } else {
+                        Err(Status::failed_precondition(
+                            "table to put is not found in flight descriptor",
+                        ))
+                    };
+                    Some(result)
+                }
+                Some(Err(e)) => Some(Err(e)),
+                None => None,
+            },
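+            // Subsequent messages reuse the table name captured from the first one.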
+            PutRecordBatchRequestStreamState::Started(table_name) => poll.map(|x| {
+                x.and_then(|flight_data| {
+                    PutRecordBatchRequest::try_new(table_name.clone(), flight_data)
+                        .map_err(Into::into)
+                })
+            }),
+        };
+        Poll::Ready(result)
+    }
 }
 
 fn to_flight_data_stream(
diff --git a/src/servers/src/grpc/greptime_handler.rs b/src/servers/src/grpc/greptime_handler.rs
index b032ffc847..73e1a768c8 100644
--- a/src/servers/src/grpc/greptime_handler.rs
+++ b/src/servers/src/grpc/greptime_handler.rs
@@ -18,23 +18,33 @@ use std::time::Instant;
 
 use api::helper::request_type;
 use api::v1::auth_header::AuthScheme;
-use api::v1::{Basic, GreptimeRequest, RequestHeader};
+use api::v1::{AuthHeader, Basic, GreptimeRequest, RequestHeader};
 use auth::{Identity, Password, UserInfoRef, UserProviderRef};
+use base64::prelude::BASE64_STANDARD;
+use base64::Engine;
 use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
 use common_catalog::parse_catalog_and_schema_from_db_string;
 use common_error::ext::ErrorExt;
 use common_error::status_code::StatusCode;
+use common_grpc::flight::do_put::DoPutResponse;
 use common_query::Output;
 use common_runtime::runtime::RuntimeTrait;
 use common_runtime::Runtime;
 use common_telemetry::tracing_context::{FutureExt, TracingContext};
-use common_telemetry::{debug, error, tracing};
+use common_telemetry::{debug, error, tracing, warn};
 use common_time::timezone::parse_timezone;
-use session::context::{QueryContextBuilder, QueryContextRef};
+use futures_util::StreamExt;
+use session::context::{QueryContext, QueryContextBuilder, QueryContextRef};
 use snafu::{OptionExt, ResultExt};
+use tokio::sync::mpsc;
 
 use crate::error::Error::UnsupportedAuthScheme;
-use crate::error::{AuthSnafu, InvalidQuerySnafu, JoinTaskSnafu, NotFoundAuthHeaderSnafu, Result};
+use crate::error::{
+    AuthSnafu, InvalidAuthHeaderInvalidUtf8ValueSnafu, InvalidBase64ValueSnafu, InvalidQuerySnafu,
+    JoinTaskSnafu, NotFoundAuthHeaderSnafu, Result,
+};
+use crate::grpc::flight::{PutRecordBatchRequest, PutRecordBatchRequestStream};
+use crate::grpc::TonicResult;
 use crate::metrics::{METRIC_AUTH_FAILURE, METRIC_SERVER_GRPC_DB_REQUEST_TIMER};
 use crate::query_handler::grpc::ServerGrpcQueryHandlerRef;
 
@@ -118,6 +128,95 @@ impl GreptimeRequestHandler {
             None => result_future.await,
         }
     }
+
+    pub(crate) async fn put_record_batches(
+        &self,
+        mut stream: PutRecordBatchRequestStream,
+        result_sender: mpsc::Sender<TonicResult<DoPutResponse>>,
+    ) {
+        let handler = self.handler.clone();
+        let runtime = self
+            .runtime
+            .clone()
+            .unwrap_or_else(common_runtime::global_runtime);
+        runtime.spawn(async move {
+            while let Some(request) = stream.next().await {
+                let request = match request {
+                    Ok(request) => request,
+                    Err(e) => {
+                        let _ = result_sender.try_send(Err(e));
+                        break;
+                    }
+                };
+
+                let PutRecordBatchRequest {
+                    table_name,
+                    request_id,
+                    record_batch,
+                } = request;
+                let result = handler.put_record_batch(&table_name, record_batch).await;
+                let result = result
+                    .map(|x| DoPutResponse::new(request_id, x))
+                    .map_err(Into::into);
+                if result_sender.try_send(result).is_err() {
+                    warn!(r#""DoPut" client may be unreachable, abort handling its message"#);
+                    break;
+                }
+            }
+        });
+    }
+
+    pub(crate) async fn validate_auth(
+        &self,
+        username_and_password: Option<&str>,
+        db: Option<&str>,
+    ) -> Result<bool> {
+        if self.user_provider.is_none() {
+            return Ok(true);
+        }
+
+        let username_and_password = username_and_password.context(NotFoundAuthHeaderSnafu)?;
+        let username_and_password = BASE64_STANDARD
+            .decode(username_and_password)
+            .context(InvalidBase64ValueSnafu)
+            .and_then(|x| String::from_utf8(x).context(InvalidAuthHeaderInvalidUtf8ValueSnafu))?;
+
+        let mut split = username_and_password.splitn(2, ':');
+        let (username, password) = match (split.next(), split.next()) {
+            (Some(username), Some(password)) => (username, password),
+            (Some(username), None) => (username, ""),
+            (None, None) => return Ok(false),
+            _ => unreachable!(), // because this iterator won't yield Some after None
+        };
+
+        let (catalog, schema) = if let Some(db) = db {
+            parse_catalog_and_schema_from_db_string(db)
+        } else {
+            (
+                DEFAULT_CATALOG_NAME.to_string(),
+                DEFAULT_SCHEMA_NAME.to_string(),
+            )
+        };
+        let header = RequestHeader {
+            authorization: Some(AuthHeader {
+                auth_scheme: Some(AuthScheme::Basic(Basic {
+                    username: username.to_string(),
+                    password: password.to_string(),
+                })),
+            }),
+            catalog,
+            schema,
+            ..Default::default()
+        };
+
+        Ok(auth(
+            self.user_provider.clone(),
+            Some(&header),
+            &QueryContext::arc(),
+        )
+        .await
+        .is_ok())
+    }
 }
 
 pub fn get_request_type(request: &GreptimeRequest) -> &'static str {
diff --git a/src/servers/src/http.rs b/src/servers/src/http.rs
index 07d1c7ea24..bff38d980f 100644
--- a/src/servers/src/http.rs
+++ b/src/servers/src/http.rs
@@ -1169,7 +1169,6 @@ mod test {
     use std::io::Cursor;
     use std::sync::Arc;
 
-    use api::v1::greptime_request::Request;
     use arrow_ipc::reader::FileReader;
     use arrow_schema::DataType;
     use axum::handler::Handler;
@@ -1191,26 +1190,12 @@ mod test {
     use super::*;
     use crate::error::Error;
     use crate::http::test_helpers::TestClient;
-    use crate::query_handler::grpc::GrpcQueryHandler;
     use crate::query_handler::sql::{ServerSqlQueryHandlerAdapter, SqlQueryHandler};
 
     struct DummyInstance {
         _tx: mpsc::Sender<(String, Vec<u8>)>,
     }
 
-    #[async_trait]
-    impl GrpcQueryHandler for DummyInstance {
-        type Error = Error;
-
-        async fn do_query(
-            &self,
-            _query: Request,
-            _ctx: QueryContextRef,
-        ) -> std::result::Result<Output, Self::Error> {
-            unimplemented!()
-        }
-    }
-
     #[async_trait]
     impl SqlQueryHandler for DummyInstance {
         type Error = Error;
diff --git a/src/servers/src/query_handler/grpc.rs b/src/servers/src/query_handler/grpc.rs
index 01464012d6..7af5c9935a 100644
--- a/src/servers/src/query_handler/grpc.rs
+++ b/src/servers/src/query_handler/grpc.rs
@@ -16,16 +16,20 @@ use std::sync::Arc;
 
 use api::v1::greptime_request::Request;
 use async_trait::async_trait;
+use common_base::AffectedRows;
 use common_error::ext::{BoxedError, ErrorExt};
 use common_query::Output;
 use session::context::QueryContextRef;
 use snafu::ResultExt;
+use table::table_name::TableName;
 
 use crate::error::{self, Result};
 
 pub type GrpcQueryHandlerRef<E> = Arc<dyn GrpcQueryHandler<Error = E> + Send + Sync>;
 
 pub type ServerGrpcQueryHandlerRef = GrpcQueryHandlerRef<error::Error>;
 
+pub type RawRecordBatch = bytes::Bytes;
+
 #[async_trait]
 pub trait GrpcQueryHandler {
     type Error: ErrorExt;
 
     async fn do_query(
         &self,
         query: Request,
         ctx: QueryContextRef,
     ) -> std::result::Result<Output, Self::Error>;
+
+    async fn put_record_batch(
+        &self,
+        table: &TableName,
+        record_batch: RawRecordBatch,
+    ) -> std::result::Result<AffectedRows, Self::Error>;
 }
 
 pub struct ServerGrpcQueryHandlerAdapter<E>(GrpcQueryHandlerRef<E>);
@@ -59,4 +69,16 @@ where
             .map_err(BoxedError::new)
             .context(error::ExecuteGrpcQuerySnafu)
     }
+
+    async fn put_record_batch(
+        &self,
+        table: &TableName,
+        record_batch: RawRecordBatch,
+    ) -> Result<AffectedRows> {
+        self.0
+            .put_record_batch(table, record_batch)
+            .await
+            .map_err(BoxedError::new)
+            .context(error::ExecuteGrpcRequestSnafu)
+    }
 }
diff --git a/src/servers/tests/http/influxdb_test.rs b/src/servers/tests/http/influxdb_test.rs
index 1a251763ae..93932252fb 100644
--- a/src/servers/tests/http/influxdb_test.rs
+++ b/src/servers/tests/http/influxdb_test.rs
@@ -14,7 +14,6 @@
 
 use std::sync::Arc;
 
-use api::v1::greptime_request::Request;
 use api::v1::RowInsertRequests;
 use async_trait::async_trait;
 use auth::tests::{DatabaseAuthInfo, MockUserProvider};
@@ -29,7 +28,6 @@ use servers::http::header::constants::GREPTIME_DB_HEADER_NAME;
 use servers::http::test_helpers::TestClient;
 use servers::http::{HttpOptions, HttpServerBuilder};
 use servers::influxdb::InfluxdbRequest;
-use servers::query_handler::grpc::GrpcQueryHandler;
 use servers::query_handler::sql::SqlQueryHandler;
 use servers::query_handler::InfluxdbLineProtocolHandler;
 use session::context::QueryContextRef;
@@ -39,19 +37,6 @@ struct DummyInstance {
     tx: Arc>,
 }
 
-#[async_trait]
-impl GrpcQueryHandler for DummyInstance {
-    type Error = Error;
-
-    async fn do_query(
-        &self,
-        _query: Request,
-        _ctx: QueryContextRef,
-    ) -> std::result::Result<Output, Self::Error> {
-        unimplemented!()
-    }
-}
-
 #[async_trait]
 impl InfluxdbLineProtocolHandler for DummyInstance {
     async fn exec(&self, request: InfluxdbRequest, ctx: QueryContextRef) -> Result {
diff --git a/src/servers/tests/http/opentsdb_test.rs b/src/servers/tests/http/opentsdb_test.rs
index 358af19dc8..6ac835e72d 100644
--- a/src/servers/tests/http/opentsdb_test.rs
+++ b/src/servers/tests/http/opentsdb_test.rs
@@ -14,7 +14,6 @@
 
 use std::sync::Arc;
 
-use api::v1::greptime_request::Request;
 use async_trait::async_trait;
 use axum::Router;
 use common_query::Output;
@@ -26,7 +25,6 @@ use servers::error::{self, Result};
 use servers::http::test_helpers::TestClient;
 use servers::http::{HttpOptions, HttpServerBuilder};
 use servers::opentsdb::codec::DataPoint;
-use servers::query_handler::grpc::GrpcQueryHandler;
 use servers::query_handler::sql::SqlQueryHandler;
 use servers::query_handler::OpentsdbProtocolHandler;
 use session::context::QueryContextRef;
@@ -36,19 +34,6 @@ struct DummyInstance {
     tx: mpsc::Sender,
 }
 
-#[async_trait]
-impl GrpcQueryHandler for DummyInstance {
-    type Error = crate::Error;
-
-    async fn do_query(
-        &self,
-        _query: Request,
-        _ctx: QueryContextRef,
-    ) -> std::result::Result<Output, Self::Error> {
-        unimplemented!()
-    }
-}
-
 #[async_trait]
 impl OpentsdbProtocolHandler for DummyInstance {
     async fn exec(&self, data_points: Vec<DataPoint>, _ctx: QueryContextRef) -> Result {
diff --git a/src/servers/tests/http/prom_store_test.rs b/src/servers/tests/http/prom_store_test.rs
index 77a06db079..c8c5671b8c 100644
--- a/src/servers/tests/http/prom_store_test.rs
+++ b/src/servers/tests/http/prom_store_test.rs
@@ -17,7 +17,6 @@ use std::sync::Arc;
 
 use api::prom_store::remote::{
     LabelMatcher, Query, QueryResult, ReadRequest, ReadResponse, WriteRequest,
 };
-use api::v1::greptime_request::Request;
 use api::v1::RowInsertRequests;
 use async_trait::async_trait;
 use axum::Router;
@@ -33,7 +32,6 @@ use servers::http::test_helpers::TestClient;
 use servers::http::{HttpOptions, HttpServerBuilder};
 use servers::prom_store;
 use servers::prom_store::{snappy_compress, Metrics};
-use servers::query_handler::grpc::GrpcQueryHandler;
 use servers::query_handler::sql::SqlQueryHandler;
 use servers::query_handler::{PromStoreProtocolHandler, PromStoreResponse};
 use session::context::QueryContextRef;
@@ -43,19 +41,6 @@ struct DummyInstance {
     tx: mpsc::Sender<(String, Vec<u8>)>,
 }
 
-#[async_trait]
-impl GrpcQueryHandler for DummyInstance {
-    type Error = Error;
-
-    async fn do_query(
-        &self,
-        _query: Request,
-        _ctx: QueryContextRef,
-    ) -> std::result::Result<Output, Self::Error> {
-        unimplemented!()
-    }
-}
-
 #[async_trait]
 impl PromStoreProtocolHandler for DummyInstance {
     async fn write(
diff --git a/src/servers/tests/mod.rs b/src/servers/tests/mod.rs
index 13c78a293f..43aeb362fa 100644
--- a/src/servers/tests/mod.rs
+++ b/src/servers/tests/mod.rs
@@ -18,6 +18,7 @@ use api::v1::greptime_request::Request;
 use api::v1::query_request::Query;
 use async_trait::async_trait;
 use catalog::memory::MemoryCatalogManager;
+use common_base::AffectedRows;
 use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
 use common_query::Output;
 use datafusion_expr::LogicalPlan;
@@ -26,11 +27,12 @@ use query::parser::{PromQuery, QueryLanguageParser, QueryStatement};
 use query::query_engine::DescribeResult;
 use query::{QueryEngineFactory, QueryEngineRef};
 use servers::error::{Error, NotSupportedSnafu, Result};
-use servers::query_handler::grpc::{GrpcQueryHandler, ServerGrpcQueryHandlerRef};
+use servers::query_handler::grpc::{GrpcQueryHandler, RawRecordBatch, ServerGrpcQueryHandlerRef};
 use servers::query_handler::sql::{ServerSqlQueryHandlerRef, SqlQueryHandler};
 use session::context::QueryContextRef;
 use snafu::ensure;
 use sql::statements::statement::Statement;
+use table::table_name::TableName;
 use table::TableRef;
 
 mod grpc;
@@ -155,6 +157,16 @@ impl GrpcQueryHandler for DummyInstance {
         };
         Ok(output)
     }
+
+    async fn put_record_batch(
+        &self,
+        table: &TableName,
+        record_batch: RawRecordBatch,
+    ) -> std::result::Result<AffectedRows, Self::Error> {
+        let _ = table;
+        let _ = record_batch;
+        unimplemented!()
+    }
 }
 
 fn create_testing_instance(table: TableRef) -> DummyInstance {
diff --git a/src/table/src/error.rs b/src/table/src/error.rs
index ef08ebc4a1..6cd79fd61c 100644
--- a/src/table/src/error.rs
+++ b/src/table/src/error.rs
@@ -172,6 +172,9 @@ pub enum Error {
         #[snafu(implicit)]
         location: Location,
     },
+
+    #[snafu(display("Invalid table name: '{s}'"))]
+    InvalidTableName { s: String },
 }
 
 impl ErrorExt for Error {
@@ -197,7 +200,8 @@ impl ErrorExt for Error {
             Error::MissingTimeIndexColumn { .. } => StatusCode::IllegalState,
 
             Error::InvalidTableOptionValue { .. }
             | Error::SetSkippingOptions { .. }
-            | Error::UnsetSkippingOptions { .. } => StatusCode::InvalidArguments,
+            | Error::UnsetSkippingOptions { .. }
+            | Error::InvalidTableName { .. } => StatusCode::InvalidArguments,
         }
     }
diff --git a/src/table/src/table_name.rs b/src/table/src/table_name.rs
index f999e013f2..d2d1c1e48b 100644
--- a/src/table/src/table_name.rs
+++ b/src/table/src/table_name.rs
@@ -15,8 +15,12 @@
 use std::fmt::{Display, Formatter};
 
 use api::v1::TableName as PbTableName;
+use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
 use serde::{Deserialize, Serialize};
+use snafu::ensure;
 
+use crate::error;
+use crate::error::InvalidTableNameSnafu;
 use crate::table_reference::TableReference;
 
 #[derive(Debug, Clone, Hash, Eq, PartialEq, Deserialize, Serialize)]
@@ -83,3 +87,37 @@ impl From<TableReference<'_>> for TableName {
         Self::new(table_ref.catalog, table_ref.schema, table_ref.table)
     }
 }
+
+impl TryFrom<Vec<String>> for TableName {
+    type Error = error::Error;
+
+    fn try_from(v: Vec<String>) -> Result<Self, Self::Error> {
+        ensure!(
+            !v.is_empty() && v.len() <= 3,
+            InvalidTableNameSnafu {
+                s: format!("{v:?}")
+            }
+        );
+        let mut v = v.into_iter();
+        match (v.next(), v.next(), v.next()) {
+            (Some(catalog_name), Some(schema_name), Some(table_name)) => Ok(Self {
+                catalog_name,
+                schema_name,
+                table_name,
+            }),
+            (Some(schema_name), Some(table_name), None) => Ok(Self {
+                catalog_name: DEFAULT_CATALOG_NAME.to_string(),
+                schema_name,
+                table_name,
+            }),
+            (Some(table_name), None, None) => Ok(Self {
+                catalog_name: DEFAULT_CATALOG_NAME.to_string(),
+                schema_name: DEFAULT_SCHEMA_NAME.to_string(),
+                table_name,
+            }),
+            // Unreachable because it's ensured that "v" is not empty,
+            // and its iterator will not yield `Some` after `None`.
+            _ => unreachable!(),
+        }
+    }
+}
diff --git a/tests-integration/src/cluster.rs b/tests-integration/src/cluster.rs
index c1159e18d5..836c3a5483 100644
--- a/tests-integration/src/cluster.rs
+++ b/tests-integration/src/cluster.rs
@@ -67,13 +67,12 @@ use tower::service_fn;
 use uuid::Uuid;
 
 use crate::test_util::{
-    self, create_datanode_opts, create_tmp_dir_and_datanode_opts, FileDirGuard, StorageGuard,
-    StorageType, PEER_PLACEHOLDER_ADDR,
+    self, create_datanode_opts, create_tmp_dir_and_datanode_opts, FileDirGuard, StorageType,
+    TestGuard, PEER_PLACEHOLDER_ADDR,
 };
 
 pub struct GreptimeDbCluster {
-    pub storage_guards: Vec<StorageGuard>,
-    pub dir_guards: Vec<FileDirGuard>,
+    pub guards: Vec<TestGuard>,
 
     pub datanode_options: Vec<DatanodeOptions>,
     pub datanode_instances: HashMap<DatanodeId, Datanode>,
@@ -177,8 +176,7 @@ impl GreptimeDbClusterBuilder {
     pub async fn build_with(
         &self,
         datanode_options: Vec<DatanodeOptions>,
-        storage_guards: Vec<StorageGuard>,
-        dir_guards: Vec<FileDirGuard>,
+        guards: Vec<TestGuard>,
     ) -> GreptimeDbCluster {
         let datanodes = datanode_options.len();
         let channel_config = ChannelConfig::new().timeout(Duration::from_secs(20));
@@ -224,8 +222,7 @@ impl GreptimeDbClusterBuilder {
 
         GreptimeDbCluster {
             datanode_options,
-            storage_guards,
-            dir_guards,
+            guards,
             datanode_instances,
             kv_backend: self.kv_backend.clone(),
             metasrv: metasrv.metasrv,
@@ -235,19 +232,16 @@ impl GreptimeDbClusterBuilder {
 
     pub async fn build(&self) -> GreptimeDbCluster {
         let datanodes = self.datanodes.unwrap_or(4);
-        let (datanode_options, storage_guards, dir_guards) =
-            self.build_datanode_options_and_guards(datanodes).await;
-        self.build_with(datanode_options, storage_guards, dir_guards)
-            .await
+        let (datanode_options, guards) = self.build_datanode_options_and_guards(datanodes).await;
+        self.build_with(datanode_options, guards).await
     }
 
     async fn build_datanode_options_and_guards(
         &self,
         datanodes: u32,
-    ) -> (Vec<DatanodeOptions>, Vec<StorageGuard>, Vec<FileDirGuard>) {
+    ) -> (Vec<DatanodeOptions>, Vec<TestGuard>) {
         let mut options = Vec::with_capacity(datanodes as usize);
-        let mut storage_guards = Vec::with_capacity(datanodes as usize);
-        let mut dir_guards =
Vec::with_capacity(datanodes as usize); + let mut guards = Vec::with_capacity(datanodes as usize); for i in 0..datanodes { let datanode_id = i as u64 + 1; @@ -257,7 +251,10 @@ impl GreptimeDbClusterBuilder { } else { let home_tmp_dir = create_temp_dir(&format!("gt_home_{}", &self.cluster_name)); let home_dir = home_tmp_dir.path().to_str().unwrap().to_string(); - dir_guards.push(FileDirGuard::new(home_tmp_dir)); + guards.push(TestGuard { + home_guard: FileDirGuard::new(home_tmp_dir), + storage_guards: Vec::new(), + }); home_dir }; @@ -275,9 +272,7 @@ impl GreptimeDbClusterBuilder { &format!("{}-dn-{}", self.cluster_name, datanode_id), self.datanode_wal_config.clone(), ); - - storage_guards.push(guard.storage_guards); - dir_guards.push(guard.home_guard); + guards.push(guard); opts }; @@ -285,11 +280,7 @@ impl GreptimeDbClusterBuilder { options.push(opts); } - ( - options, - storage_guards.into_iter().flatten().collect(), - dir_guards, - ) + (options, guards) } async fn build_datanodes_with_options( diff --git a/tests-integration/src/grpc.rs b/tests-integration/src/grpc.rs index 501c41d0c8..d09bbc3761 100644 --- a/tests-integration/src/grpc.rs +++ b/tests-integration/src/grpc.rs @@ -12,6 +12,33 @@ // See the License for the specific language governing permissions and // limitations under the License. +mod flight; + +use api::v1::greptime_request::Request; +use api::v1::query_request::Query; +use api::v1::QueryRequest; +use common_query::OutputData; +use common_recordbatch::RecordBatches; +use frontend::instance::Instance; +use servers::query_handler::grpc::GrpcQueryHandler; +use session::context::QueryContext; + +#[allow(unused)] +async fn query_and_expect(instance: &Instance, sql: &str, expected: &str) { + let request = Request::Query(QueryRequest { + query: Some(Query::Sql(sql.to_string())), + }); + let output = GrpcQueryHandler::do_query(instance, request, QueryContext::arc()) + .await + .unwrap(); + let OutputData::Stream(stream) = output.data else { + unreachable!() + }; + let recordbatches = RecordBatches::try_collect(stream).await.unwrap(); + let actual = recordbatches.pretty_print().unwrap(); + assert_eq!(actual, expected, "actual: {}", actual); +} + #[cfg(test)] mod test { use std::collections::HashMap; @@ -41,6 +68,7 @@ mod test { use store_api::storage::RegionId; use substrait::{DFLogicalSubstraitConvertor, SubstraitPlan}; + use super::*; use crate::standalone::GreptimeDbStandaloneBuilder; use crate::tests; use crate::tests::MockDistributedInstance; @@ -219,24 +247,14 @@ mod test { let output = query(instance, request).await; assert!(matches!(output.data, OutputData::AffectedRows(1))); - let request = Request::Query(QueryRequest { - query: Some(Query::Sql( - "SELECT ts, a, b FROM database_created_through_grpc.table_created_through_grpc" - .to_string(), - )), - }); - let output = query(instance, request).await; - let OutputData::Stream(stream) = output.data else { - unreachable!() - }; - let recordbatches = RecordBatches::try_collect(stream).await.unwrap(); + let sql = "SELECT ts, a, b FROM database_created_through_grpc.table_created_through_grpc"; let expected = "\ +---------------------+---+---+ | ts | a | b | +---------------------+---+---+ | 2023-01-04T07:14:26 | s | 1 | +---------------------+---+---+"; - assert_eq!(recordbatches.pretty_print().unwrap(), expected); + query_and_expect(instance, sql, expected).await; let request = Request::Ddl(DdlRequest { expr: Some(DdlExpr::DropTable(DropTableExpr { @@ -323,24 +341,14 @@ mod test { let output = query(instance, request).await; 
         assert!(matches!(output.data, OutputData::AffectedRows(1)));
 
-        let request = Request::Query(QueryRequest {
-            query: Some(Query::Sql(
-                "SELECT ts, a, b FROM database_created_through_grpc.table_created_through_grpc"
-                    .to_string(),
-            )),
-        });
-        let output = query(instance, request).await;
-        let OutputData::Stream(stream) = output.data else {
-            unreachable!()
-        };
-        let recordbatches = RecordBatches::try_collect(stream).await.unwrap();
+        let sql = "SELECT ts, a, b FROM database_created_through_grpc.table_created_through_grpc";
         let expected = "\
 +---------------------+---+---+
 | ts                  | a | b |
 +---------------------+---+---+
 | 2023-01-04T07:14:26 | s | 1 |
 +---------------------+---+---+";
-        assert_eq!(recordbatches.pretty_print().unwrap(), expected);
+        query_and_expect(instance, sql, expected).await;
 
         let request = Request::Ddl(DdlRequest {
             expr: Some(DdlExpr::DropTable(DropTableExpr {
diff --git a/tests-integration/src/grpc/flight.rs b/tests-integration/src/grpc/flight.rs
new file mode 100644
index 0000000000..e97165f16c
--- /dev/null
+++ b/tests-integration/src/grpc/flight.rs
@@ -0,0 +1,242 @@
+// Copyright 2023 Greptime Team
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#[cfg(test)]
+mod test {
+    use std::net::SocketAddr;
+    use std::sync::Arc;
+    use std::time::Duration;
+
+    use api::v1::auth_header::AuthScheme;
+    use api::v1::{Basic, ColumnDataType, ColumnDef, CreateTableExpr, SemanticType};
+    use arrow_flight::FlightDescriptor;
+    use auth::user_provider_from_option;
+    use client::{Client, Database};
+    use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
+    use common_grpc::flight::do_put::DoPutMetadata;
+    use common_grpc::flight::{FlightEncoder, FlightMessage};
+    use common_query::OutputData;
+    use common_recordbatch::RecordBatch;
+    use datatypes::prelude::{ConcreteDataType, ScalarVector, VectorRef};
+    use datatypes::schema::{ColumnSchema, Schema};
+    use datatypes::vectors::{Int32Vector, StringVector, TimestampMillisecondVector};
+    use futures_util::StreamExt;
+    use itertools::Itertools;
+    use servers::grpc::builder::GrpcServerBuilder;
+    use servers::grpc::greptime_handler::GreptimeRequestHandler;
+    use servers::grpc::GrpcServerConfig;
+    use servers::query_handler::grpc::ServerGrpcQueryHandlerAdapter;
+    use servers::server::Server;
+
+    use crate::cluster::GreptimeDbClusterBuilder;
+    use crate::grpc::query_and_expect;
+    use crate::test_util::{setup_grpc_server, StorageType};
+    use crate::tests::test_util::MockInstance;
+
+    #[tokio::test(flavor = "multi_thread")]
+    async fn test_standalone_flight_do_put() {
+        common_telemetry::init_default_ut_logging();
+
+        let (addr, db, _server) =
+            setup_grpc_server(StorageType::File, "test_standalone_flight_do_put").await;
+
+        let client = Client::with_urls(vec![addr]);
+        let client = Database::new_with_dbname("greptime-public", client);
+
+        create_table(&client).await;
+
+        let record_batches = create_record_batches(1);
+        test_put_record_batches(&client, record_batches).await;
+
+        let sql = "select ts, a, b from foo order by ts";
by ts"; + let expected = "\ +++ +++"; + query_and_expect(db.frontend().as_ref(), sql, expected).await; + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_distributed_flight_do_put() { + common_telemetry::init_default_ut_logging(); + + let db = GreptimeDbClusterBuilder::new("test_distributed_flight_do_put") + .await + .build() + .await; + + let runtime = common_runtime::global_runtime().clone(); + let greptime_request_handler = GreptimeRequestHandler::new( + ServerGrpcQueryHandlerAdapter::arc(db.frontend.instance.clone()), + user_provider_from_option( + &"static_user_provider:cmd:greptime_user=greptime_pwd".to_string(), + ) + .ok(), + Some(runtime.clone()), + ); + let grpc_server = GrpcServerBuilder::new(GrpcServerConfig::default(), runtime) + .flight_handler(Arc::new(greptime_request_handler)) + .build(); + let addr = grpc_server + .start("127.0.0.1:0".parse::().unwrap()) + .await + .unwrap() + .to_string(); + + // wait for GRPC server to start + tokio::time::sleep(Duration::from_secs(1)).await; + + let client = Client::with_urls(vec![addr]); + let mut client = Database::new(DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME, client); + client.set_auth(AuthScheme::Basic(Basic { + username: "greptime_user".to_string(), + password: "greptime_pwd".to_string(), + })); + + create_table(&client).await; + + let record_batches = create_record_batches(1); + test_put_record_batches(&client, record_batches).await; + + let sql = "select ts, a, b from foo order by ts"; + let expected = "\ +++ +++"; + query_and_expect(db.fe_instance().as_ref(), sql, expected).await; + } + + async fn test_put_record_batches(client: &Database, record_batches: Vec) { + let requests_count = record_batches.len(); + + let stream = tokio_stream::iter(record_batches) + .enumerate() + .map(|(i, x)| { + let mut encoder = FlightEncoder::default(); + let message = FlightMessage::Recordbatch(x); + let mut data = encoder.encode(message); + + let metadata = DoPutMetadata::new(i as i64); + data.app_metadata = serde_json::to_vec(&metadata).unwrap().into(); + + // first message in "DoPut" stream should carry table name in flight descriptor + if i == 0 { + data.flight_descriptor = Some(FlightDescriptor { + r#type: arrow_flight::flight_descriptor::DescriptorType::Path as i32, + path: vec!["foo".to_string()], + ..Default::default() + }); + } + data + }) + .boxed(); + + let response_stream = client.do_put(stream).await.unwrap(); + + let responses = response_stream.collect::>().await; + let responses_count = responses.len(); + for (i, response) in responses.into_iter().enumerate() { + assert!(response.is_ok(), "{}", response.err().unwrap()); + let response = response.unwrap(); + assert_eq!(response.request_id(), i as i64); + assert_eq!(response.affected_rows(), 448); + } + assert_eq!(requests_count, responses_count); + } + + fn create_record_batches(start: i64) -> Vec { + let schema = Arc::new(Schema::new(vec![ + ColumnSchema::new( + "ts", + ConcreteDataType::timestamp_millisecond_datatype(), + false, + ) + .with_time_index(true), + ColumnSchema::new("a", ConcreteDataType::int32_datatype(), false), + ColumnSchema::new("b", ConcreteDataType::string_datatype(), true), + ])); + + let mut record_batches = Vec::with_capacity(3); + for chunk in &(start..start + 9).chunks(3) { + let vs = chunk.collect_vec(); + let x1 = vs[0]; + let x2 = vs[1]; + let x3 = vs[2]; + + record_batches.push( + RecordBatch::new( + schema.clone(), + vec![ + Arc::new(TimestampMillisecondVector::from_vec(vec![x1, x2, x3])) + as VectorRef, + 
+                        Arc::new(Int32Vector::from_vec(vec![
+                            -x1 as i32, -x2 as i32, -x3 as i32,
+                        ])),
+                        Arc::new(StringVector::from_vec(vec![
+                            format!("s{x1}"),
+                            format!("s{x2}"),
+                            format!("s{x3}"),
+                        ])),
+                    ],
+                )
+                .unwrap(),
+            );
+        }
+        record_batches
+    }
+
+    async fn create_table(client: &Database) {
+        // create table foo (
+        //     ts timestamp time index,
+        //     a int primary key,
+        //     b string,
+        // )
+        let output = client
+            .create(CreateTableExpr {
+                schema_name: "public".to_string(),
+                table_name: "foo".to_string(),
+                column_defs: vec![
+                    ColumnDef {
+                        name: "ts".to_string(),
+                        data_type: ColumnDataType::TimestampMillisecond as i32,
+                        semantic_type: SemanticType::Timestamp as i32,
+                        is_nullable: false,
+                        ..Default::default()
+                    },
+                    ColumnDef {
+                        name: "a".to_string(),
+                        data_type: ColumnDataType::Int32 as i32,
+                        semantic_type: SemanticType::Tag as i32,
+                        is_nullable: false,
+                        ..Default::default()
+                    },
+                    ColumnDef {
+                        name: "b".to_string(),
+                        data_type: ColumnDataType::String as i32,
+                        semantic_type: SemanticType::Field as i32,
+                        is_nullable: true,
+                        ..Default::default()
+                    },
+                ],
+                time_index: "ts".to_string(),
+                primary_keys: vec!["a".to_string()],
+                engine: "mito".to_string(),
+                ..Default::default()
+            })
+            .await
+            .unwrap();
+        let OutputData::AffectedRows(affected_rows) = output.data else {
+            unreachable!()
+        };
+        assert_eq!(affected_rows, 0);
+    }
+}
diff --git a/tests-integration/src/test_util.rs b/tests-integration/src/test_util.rs
index 4b31b941b6..b842c9412c 100644
--- a/tests-integration/src/test_util.rs
+++ b/tests-integration/src/test_util.rs
@@ -299,6 +299,34 @@ impl TestGuard {
     }
 }
 
+impl Drop for TestGuard {
+    fn drop(&mut self) {
+        let (tx, rx) = std::sync::mpsc::channel();
+
+        let guards = std::mem::take(&mut self.storage_guards);
+        common_runtime::spawn_global(async move {
+            let mut errors = vec![];
+            for guard in guards {
+                if let TempDirGuard::S3(guard)
+                | TempDirGuard::Oss(guard)
+                | TempDirGuard::Azblob(guard)
+                | TempDirGuard::Gcs(guard) = guard.0
+                {
+                    if let Err(e) = guard.remove_all().await {
+                        errors.push(e);
+                    }
+                }
+            }
+            if errors.is_empty() {
+                tx.send(Ok(())).unwrap();
+            } else {
+                tx.send(Err(errors)).unwrap();
+            }
+        });
+        rx.recv().unwrap().unwrap_or_else(|e| panic!("{:?}", e));
+    }
+}
+
 pub fn create_tmp_dir_and_datanode_opts(
     default_store_type: StorageType,
     store_provider_types: Vec<StorageType>,
@@ -504,7 +532,7 @@ pub async fn setup_test_prom_app_with_frontend(
 pub async fn setup_grpc_server(
     store_type: StorageType,
     name: &str,
-) -> (String, TestGuard, Arc<GrpcServer>) {
+) -> (String, GreptimeDbStandalone, Arc<GrpcServer>) {
     setup_grpc_server_with(store_type, name, None, None).await
 }
 
 pub async fn setup_grpc_server_with_user_provider(
     store_type: StorageType,
     name: &str,
     user_provider: Option<UserProviderRef>,
-) -> (String, TestGuard, Arc<GrpcServer>) {
+) -> (String, GreptimeDbStandalone, Arc<GrpcServer>) {
     setup_grpc_server_with(store_type, name, user_provider, None).await
 }
 
 pub async fn setup_grpc_server_with(
     store_type: StorageType,
     name: &str,
     user_provider: Option<UserProviderRef>,
     grpc_config: Option<GrpcServerConfig>,
-) -> (String, TestGuard, Arc<GrpcServer>) {
+) -> (String, GreptimeDbStandalone, Arc<GrpcServer>) {
     let instance = setup_standalone_instance(name, store_type).await;
 
     let runtime: Runtime = RuntimeBuilder::default()
@@ -560,7 +588,7 @@ pub async fn setup_grpc_server_with(
     // wait for GRPC server to start
     tokio::time::sleep(Duration::from_secs(1)).await;
 
-    (fe_grpc_addr, instance.guard, fe_grpc_server)
+    (fe_grpc_addr, instance, fe_grpc_server)
 }
 
 pub async fn setup_mysql_server(
diff --git a/tests-integration/src/tests/test_util.rs
b/tests-integration/src/tests/test_util.rs index 605ed2c178..d8df68afaa 100644 --- a/tests-integration/src/tests/test_util.rs +++ b/tests-integration/src/tests/test_util.rs @@ -126,17 +126,12 @@ impl MockInstanceBuilder { unreachable!() }; let GreptimeDbCluster { - storage_guards, - dir_guards, + guards, datanode_options, .. } = instance; - MockInstanceImpl::Distributed( - builder - .build_with(datanode_options, storage_guards, dir_guards) - .await, - ) + MockInstanceImpl::Distributed(builder.build_with(datanode_options, guards).await) } } } diff --git a/tests-integration/tests/grpc.rs b/tests-integration/tests/grpc.rs index 11db34acb8..0a7fffa82d 100644 --- a/tests-integration/tests/grpc.rs +++ b/tests-integration/tests/grpc.rs @@ -90,8 +90,7 @@ macro_rules! grpc_tests { } pub async fn test_invalid_dbname(store_type: StorageType) { - let (addr, mut guard, fe_grpc_server) = - setup_grpc_server(store_type, "auto_create_table").await; + let (addr, _db, fe_grpc_server) = setup_grpc_server(store_type, "test_invalid_dbname").await; let grpc_client = Client::with_urls(vec![addr]); let db = Database::new_with_dbname("tom", grpc_client); @@ -115,12 +114,10 @@ pub async fn test_invalid_dbname(store_type: StorageType) { assert!(result.is_err()); let _ = fe_grpc_server.shutdown().await; - guard.remove_all().await; } pub async fn test_dbname(store_type: StorageType) { - let (addr, mut guard, fe_grpc_server) = - setup_grpc_server(store_type, "auto_create_table").await; + let (addr, _db, fe_grpc_server) = setup_grpc_server(store_type, "test_dbname").await; let grpc_client = Client::with_urls(vec![addr]); let db = Database::new_with_dbname( @@ -129,7 +126,6 @@ pub async fn test_dbname(store_type: StorageType) { ); insert_and_assert(&db).await; let _ = fe_grpc_server.shutdown().await; - guard.remove_all().await; } pub async fn test_grpc_message_size_ok(store_type: StorageType) { @@ -138,8 +134,8 @@ pub async fn test_grpc_message_size_ok(store_type: StorageType) { max_send_message_size: 1024, ..Default::default() }; - let (addr, mut guard, fe_grpc_server) = - setup_grpc_server_with(store_type, "auto_create_table", None, Some(config)).await; + let (addr, _db, fe_grpc_server) = + setup_grpc_server_with(store_type, "test_grpc_message_size_ok", None, Some(config)).await; let grpc_client = Client::with_urls(vec![addr]); let db = Database::new_with_dbname( @@ -148,7 +144,6 @@ pub async fn test_grpc_message_size_ok(store_type: StorageType) { ); db.sql("show tables;").await.unwrap(); let _ = fe_grpc_server.shutdown().await; - guard.remove_all().await; } pub async fn test_grpc_zstd_compression(store_type: StorageType) { @@ -158,8 +153,8 @@ pub async fn test_grpc_zstd_compression(store_type: StorageType) { max_send_message_size: 1024, ..Default::default() }; - let (addr, mut guard, fe_grpc_server) = - setup_grpc_server_with(store_type, "auto_create_table", None, Some(config)).await; + let (addr, _db, fe_grpc_server) = + setup_grpc_server_with(store_type, "test_grpc_zstd_compression", None, Some(config)).await; let grpc_client = Client::with_urls(vec![addr]); let db = Database::new_with_dbname( @@ -168,7 +163,6 @@ pub async fn test_grpc_zstd_compression(store_type: StorageType) { ); db.sql("show tables;").await.unwrap(); let _ = fe_grpc_server.shutdown().await; - guard.remove_all().await; } pub async fn test_grpc_message_size_limit_send(store_type: StorageType) { @@ -177,8 +171,13 @@ pub async fn test_grpc_message_size_limit_send(store_type: StorageType) { max_send_message_size: 50, ..Default::default() }; - let 
(addr, mut guard, fe_grpc_server) = - setup_grpc_server_with(store_type, "auto_create_table", None, Some(config)).await; + let (addr, _db, fe_grpc_server) = setup_grpc_server_with( + store_type, + "test_grpc_message_size_limit_send", + None, + Some(config), + ) + .await; let grpc_client = Client::with_urls(vec![addr]); let db = Database::new_with_dbname( @@ -188,7 +187,6 @@ pub async fn test_grpc_message_size_limit_send(store_type: StorageType) { let err_msg = db.sql("show tables;").await.unwrap_err().to_string(); assert!(err_msg.contains("message length too large"), "{}", err_msg); let _ = fe_grpc_server.shutdown().await; - guard.remove_all().await; } pub async fn test_grpc_message_size_limit_recv(store_type: StorageType) { @@ -197,8 +195,13 @@ pub async fn test_grpc_message_size_limit_recv(store_type: StorageType) { max_send_message_size: 1024, ..Default::default() }; - let (addr, mut guard, fe_grpc_server) = - setup_grpc_server_with(store_type, "auto_create_table", None, Some(config)).await; + let (addr, _db, fe_grpc_server) = setup_grpc_server_with( + store_type, + "test_grpc_message_size_limit_recv", + None, + Some(config), + ) + .await; let grpc_client = Client::with_urls(vec![addr]); let db = Database::new_with_dbname( @@ -212,7 +215,6 @@ pub async fn test_grpc_message_size_limit_recv(store_type: StorageType) { err_msg ); let _ = fe_grpc_server.shutdown().await; - guard.remove_all().await; } pub async fn test_grpc_auth(store_type: StorageType) { @@ -220,7 +222,7 @@ pub async fn test_grpc_auth(store_type: StorageType) { &"static_user_provider:cmd:greptime_user=greptime_pwd".to_string(), ) .unwrap(); - let (addr, mut guard, fe_grpc_server) = + let (addr, _db, fe_grpc_server) = setup_grpc_server_with_user_provider(store_type, "auto_create_table", Some(user_provider)) .await; @@ -265,29 +267,25 @@ pub async fn test_grpc_auth(store_type: StorageType) { assert!(re.is_ok()); let _ = fe_grpc_server.shutdown().await; - guard.remove_all().await; } pub async fn test_auto_create_table(store_type: StorageType) { - let (addr, mut guard, fe_grpc_server) = - setup_grpc_server(store_type, "auto_create_table").await; + let (addr, _db, fe_grpc_server) = setup_grpc_server(store_type, "test_auto_create_table").await; let grpc_client = Client::with_urls(vec![addr]); let db = Database::new(DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME, grpc_client); insert_and_assert(&db).await; let _ = fe_grpc_server.shutdown().await; - guard.remove_all().await; } pub async fn test_auto_create_table_with_hints(store_type: StorageType) { - let (addr, mut guard, fe_grpc_server) = + let (addr, _db, fe_grpc_server) = setup_grpc_server(store_type, "auto_create_table_with_hints").await; let grpc_client = Client::with_urls(vec![addr]); let db = Database::new(DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME, grpc_client); insert_with_hints_and_assert(&db).await; let _ = fe_grpc_server.shutdown().await; - guard.remove_all().await; } fn expect_data() -> (Column, Column, Column, Column) { @@ -348,8 +346,7 @@ fn expect_data() -> (Column, Column, Column, Column) { pub async fn test_insert_and_select(store_type: StorageType) { common_telemetry::init_default_ut_logging(); - let (addr, mut guard, fe_grpc_server) = - setup_grpc_server(store_type, "insert_and_select").await; + let (addr, _db, fe_grpc_server) = setup_grpc_server(store_type, "test_insert_and_select").await; let grpc_client = Client::with_urls(vec![addr]); let db = Database::new(DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME, grpc_client); @@ -388,7 +385,6 @@ pub async fn 
test_insert_and_select(store_type: StorageType) { insert_and_assert(&db).await; let _ = fe_grpc_server.shutdown().await; - guard.remove_all().await; } async fn insert_with_hints_and_assert(db: &Database) { @@ -591,21 +587,20 @@ fn testing_create_expr() -> CreateTableExpr { } pub async fn test_health_check(store_type: StorageType) { - let (addr, mut guard, fe_grpc_server) = - setup_grpc_server(store_type, "auto_create_table").await; + let (addr, _db, fe_grpc_server) = setup_grpc_server(store_type, "test_health_check").await; let grpc_client = Client::with_urls(vec![addr]); grpc_client.health_check().await.unwrap(); let _ = fe_grpc_server.shutdown().await; - guard.remove_all().await; } pub async fn test_prom_gateway_query(store_type: StorageType) { common_telemetry::init_default_ut_logging(); // prepare connection - let (addr, mut guard, fe_grpc_server) = setup_grpc_server(store_type, "prom_gateway").await; + let (addr, _db, fe_grpc_server) = + setup_grpc_server(store_type, "test_prom_gateway_query").await; let grpc_client = Client::with_urls(vec![addr]); let db = Database::new( DEFAULT_CATALOG_NAME, @@ -772,7 +767,6 @@ pub async fn test_prom_gateway_query(store_type: StorageType) { // clean up let _ = fe_grpc_server.shutdown().await; - guard.remove_all().await; } pub async fn test_grpc_timezone(store_type: StorageType) { @@ -781,7 +775,7 @@ pub async fn test_grpc_timezone(store_type: StorageType) { max_send_message_size: 1024, ..Default::default() }; - let (addr, mut guard, fe_grpc_server) = + let (addr, _db, fe_grpc_server) = setup_grpc_server_with(store_type, "auto_create_table", None, Some(config)).await; let grpc_client = Client::with_urls(vec![addr]); @@ -824,7 +818,6 @@ pub async fn test_grpc_timezone(store_type: StorageType) { +-----------+" ); let _ = fe_grpc_server.shutdown().await; - guard.remove_all().await; } async fn to_batch(output: Output) -> String { @@ -856,7 +849,7 @@ pub async fn test_grpc_tls_config(store_type: StorageType) { max_send_message_size: 1024, tls, }; - let (addr, mut guard, fe_grpc_server) = + let (addr, _db, fe_grpc_server) = setup_grpc_server_with(store_type, "tls_create_table", None, Some(config)).await; let mut client_tls = ClientTlsOption { @@ -902,5 +895,4 @@ pub async fn test_grpc_tls_config(store_type: StorageType) { } let _ = fe_grpc_server.shutdown().await; - guard.remove_all().await; }
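
A minimal client-side sketch of the "DoPut" flow introduced by this patch, distilled from the
`test_put_record_batches` helper above. It is not part of the patch itself: the server address,
database name "greptime-public", and table "foo" are placeholders, and `serde_json` and
`tokio_stream` are assumed to be available (as they are in the tests above).

    use arrow_flight::FlightDescriptor;
    use client::{Client, Database};
    use common_grpc::flight::do_put::DoPutMetadata;
    use common_grpc::flight::{FlightEncoder, FlightMessage};
    use common_recordbatch::RecordBatch;
    use futures_util::StreamExt;

    async fn put_batches(addr: &str, batches: Vec<RecordBatch>) -> client::Result<()> {
        let client = Client::with_urls(vec![addr.to_string()]);
        let db = Database::new_with_dbname("greptime-public", client);

        let stream = tokio_stream::iter(batches)
            .enumerate()
            .map(|(i, batch)| {
                // Encode each record batch into FlightData, tagging it with a request id
                // that the server echoes back in the matching DoPutResponse.
                let mut data = FlightEncoder::default().encode(FlightMessage::Recordbatch(batch));
                data.app_metadata = serde_json::to_vec(&DoPutMetadata::new(i as i64))
                    .unwrap()
                    .into();
                if i == 0 {
                    // The first message must name the target table in its flight descriptor.
                    data.flight_descriptor = Some(FlightDescriptor {
                        r#type: arrow_flight::flight_descriptor::DescriptorType::Path as i32,
                        path: vec!["foo".to_string()],
                        ..Default::default()
                    });
                }
                data
            })
            .boxed();

        // "DoPut" is stream-in, stream-out: responses arrive as the server ingests.
        let mut responses = db.do_put(stream).await?;
        while let Some(response) = responses.next().await {
            let response = response?;
            println!(
                "request {} done, affected rows: {}",
                response.request_id(),
                response.affected_rows()
            );
        }
        Ok(())
    }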