From 18250c48034235b332a602dfd98a6fa55b42b62e Mon Sep 17 00:00:00 2001 From: Ruihang Xia Date: Tue, 22 Aug 2023 08:30:09 -0500 Subject: [PATCH] feat: implement Flight and gRPC services for RegionServer (#2226) * extract FlightCraft trait Signed-off-by: Ruihang Xia * split service handler in GrpcServer Signed-off-by: Ruihang Xia * left grpc server implement Signed-off-by: Ruihang Xia * start region server if configured Signed-off-by: Ruihang Xia --------- Signed-off-by: Ruihang Xia --- Cargo.lock | 1 + src/api/src/helper.rs | 16 ++ src/datanode/Cargo.toml | 1 + src/datanode/src/region_server.rs | 82 +++++++- src/datanode/src/server.rs | 18 ++ src/frontend/src/server.rs | 2 + src/servers/Cargo.toml | 1 + src/servers/src/grpc.rs | 94 ++++++--- src/servers/src/grpc/database.rs | 8 +- src/servers/src/grpc/flight.rs | 64 ++++-- src/servers/src/grpc/flight/stream.rs | 4 +- .../grpc/{handler.rs => greptime_handler.rs} | 9 +- src/servers/src/grpc/prom_query_gateway.rs | 2 +- src/servers/src/grpc/region_server.rs | 195 ++++++++++++++++++ src/servers/tests/grpc/mod.rs | 9 +- tests-integration/Cargo.toml | 2 +- tests-integration/src/cluster.rs | 8 + tests-integration/src/test_util.rs | 8 + 18 files changed, 453 insertions(+), 71 deletions(-) rename src/servers/src/grpc/{handler.rs => greptime_handler.rs} (96%) create mode 100644 src/servers/src/grpc/region_server.rs diff --git a/Cargo.lock b/Cargo.lock index d30ce2bfc8..80c932f7e5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2623,6 +2623,7 @@ name = "datanode" version = "0.3.2" dependencies = [ "api", + "arrow-flight", "async-compat", "async-stream", "async-trait", diff --git a/src/api/src/helper.rs b/src/api/src/helper.rs index bae965c8e8..bb049ecb26 100644 --- a/src/api/src/helper.rs +++ b/src/api/src/helper.rs @@ -37,6 +37,7 @@ use greptime_proto::v1; use greptime_proto::v1::ddl_request::Expr; use greptime_proto::v1::greptime_request::Request; use greptime_proto::v1::query_request::Query; +use 
greptime_proto::v1::region::region_request; use greptime_proto::v1::value::ValueData; use greptime_proto::v1::{DdlRequest, IntervalMonthDayNano, QueryRequest, SemanticType}; use snafu::prelude::*; @@ -328,6 +329,21 @@ fn query_request_type(request: &QueryRequest) -> &'static str { } } +/// Returns the type name of the [RegionRequest]. +pub fn region_request_type(request: &region_request::Request) -> &'static str { + match request { + region_request::Request::Inserts(_) => "region.inserts", + region_request::Request::Deletes(_) => "region.deletes", + region_request::Request::Create(_) => "region.create", + region_request::Request::Drop(_) => "region.drop ", + region_request::Request::Open(_) => "region.open", + region_request::Request::Close(_) => "region.close", + region_request::Request::Alter(_) => "region.alter", + region_request::Request::Flush(_) => "region.flush", + region_request::Request::Compact(_) => "region.compact", + } +} + /// Returns the type name of the [DdlRequest]. fn ddl_request_type(request: &DdlRequest) -> &'static str { match request.expr { diff --git a/src/datanode/Cargo.toml b/src/datanode/Cargo.toml index 1f7589a901..caa71d44e6 100644 --- a/src/datanode/Cargo.toml +++ b/src/datanode/Cargo.toml @@ -9,6 +9,7 @@ testing = ["meta-srv/mock"] [dependencies] api = { workspace = true } +arrow-flight.workspace = true async-compat = "0.2" async-stream.workspace = true async-trait.workspace = true diff --git a/src/datanode/src/region_server.rs b/src/datanode/src/region_server.rs index 34c2264279..ae41a00f80 100644 --- a/src/datanode/src/region_server.rs +++ b/src/datanode/src/region_server.rs @@ -14,9 +14,11 @@ use std::any::Any; use std::collections::HashMap; -use std::sync::{Arc, Mutex}; +use std::sync::{Arc, Mutex, RwLock}; -use api::v1::region::QueryRequest; +use api::v1::region::region_request::Request as RequestBody; +use api::v1::region::{QueryRequest, RegionResponse}; +use arrow_flight::{FlightData, Ticket}; use async_trait::async_trait; use 
bytes::Bytes; use common_query::logical_plan::Expr; @@ -33,7 +35,12 @@ use datafusion::execution::context::SessionState; use datafusion_common::DataFusionError; use datafusion_expr::{Expr as DfExpr, TableType}; use datatypes::arrow::datatypes::SchemaRef; +use prost::Message; use query::QueryEngineRef; +use servers::error as servers_error; +use servers::error::Result as ServerResult; +use servers::grpc::flight::{FlightCraft, FlightRecordBatchStream, TonicStream}; +use servers::grpc::region_server::RegionServerHandler; use session::context::QueryContext; use snafu::{OptionExt, ResultExt}; use store_api::metadata::RegionMetadataRef; @@ -42,6 +49,7 @@ use store_api::region_request::RegionRequest; use store_api::storage::{RegionId, ScanRequest}; use substrait::{DFLogicalSubstraitConvertor, SubstraitPlan}; use table::table::scan::StreamScanAdapter; +use tonic::{Request, Response, Result as TonicResult}; use crate::error::{ DecodeLogicalPlanSnafu, ExecuteLogicalPlanSnafu, GetRegionMetadataSnafu, @@ -49,24 +57,80 @@ use crate::error::{ UnsupportedOutputSnafu, }; +#[derive(Clone)] pub struct RegionServer { - engines: HashMap, - region_map: DashMap, - query_engine: QueryEngineRef, + inner: Arc, } impl RegionServer { pub fn new(query_engine: QueryEngineRef) -> Self { Self { - engines: HashMap::new(), + inner: Arc::new(RegionServerInner::new(query_engine)), + } + } + + pub fn register_engine(&mut self, engine: RegionEngineRef) { + self.inner.register_engine(engine); + } + + pub async fn handle_request( + &self, + region_id: RegionId, + request: RegionRequest, + ) -> Result { + self.inner.handle_request(region_id, request).await + } + + pub async fn handle_read(&self, request: QueryRequest) -> Result { + self.inner.handle_read(request).await + } +} + +#[async_trait] +impl RegionServerHandler for RegionServer { + async fn handle(&self, _request: RequestBody) -> ServerResult { + todo!() + } +} + +#[async_trait] +impl FlightCraft for RegionServer { + async fn do_get( + &self, + 
request: Request, + ) -> TonicResult>> { + let ticket = request.into_inner().ticket; + let request = QueryRequest::decode(ticket.as_ref()) + .context(servers_error::InvalidFlightTicketSnafu)?; + + let result = self.handle_read(request).await?; + + let stream = Box::pin(FlightRecordBatchStream::new(result)); + Ok(Response::new(stream)) + } +} + +struct RegionServerInner { + engines: RwLock>, + region_map: DashMap, + query_engine: QueryEngineRef, +} + +impl RegionServerInner { + pub fn new(query_engine: QueryEngineRef) -> Self { + Self { + engines: RwLock::new(HashMap::new()), region_map: DashMap::new(), query_engine, } } - pub fn register_engine(&mut self, engine: RegionEngineRef) { + pub fn register_engine(&self, engine: RegionEngineRef) { let engine_name = engine.name(); - self.engines.insert(engine_name.to_string(), engine); + self.engines + .write() + .unwrap() + .insert(engine_name.to_string(), engine); } pub async fn handle_request( @@ -90,6 +154,8 @@ impl RegionServer { let engine = match &region_change { RegionChange::Register(engine_type) => self .engines + .read() + .unwrap() .get(engine_type) .with_context(|| RegionEngineNotFoundSnafu { name: engine_type })? .clone(), diff --git a/src/datanode/src/server.rs b/src/datanode/src/server.rs index 0b037f4d59..dbc2ececc9 100644 --- a/src/datanode/src/server.rs +++ b/src/datanode/src/server.rs @@ -31,6 +31,7 @@ use crate::error::{ WaitForGrpcServingSnafu, }; use crate::instance::InstanceRef; +use crate::region_server::RegionServer; pub mod grpc; @@ -42,6 +43,9 @@ pub struct Services { impl Services { pub async fn try_new(instance: InstanceRef, opts: &DatanodeOptions) -> Result { + // TODO(ruihang): remove database service once region server is ready. 
+ let enable_region_server = option_env!("ENABLE_REGION_SERVER").is_some(); + let grpc_runtime = Arc::new( RuntimeBuilder::default() .worker_threads(opts.rpc_runtime_size) @@ -50,10 +54,24 @@ impl Services { .context(RuntimeResourceSnafu)?, ); + let region_server = RegionServer::new(instance.query_engine()); + let flight_handler = if enable_region_server { + Some(Arc::new(region_server.clone()) as _) + } else { + None + }; + let region_server_handler = if enable_region_server { + Some(Arc::new(region_server.clone()) as _) + } else { + None + }; + Ok(Self { grpc_server: GrpcServer::new( ServerGrpcQueryHandlerAdaptor::arc(instance), None, + flight_handler, + region_server_handler, None, grpc_runtime, ), diff --git a/src/frontend/src/server.rs b/src/frontend/src/server.rs index 2810436468..91d709b8d8 100644 --- a/src/frontend/src/server.rs +++ b/src/frontend/src/server.rs @@ -71,6 +71,8 @@ impl Services { let grpc_server = GrpcServer::new( ServerGrpcQueryHandlerAdaptor::arc(instance.clone()), Some(instance.clone()), + None, + None, user_provider.clone(), grpc_runtime, ); diff --git a/src/servers/Cargo.toml b/src/servers/Cargo.toml index f13fe7ee33..7ad63ea38a 100644 --- a/src/servers/Cargo.toml +++ b/src/servers/Cargo.toml @@ -8,6 +8,7 @@ license.workspace = true dashboard = [] mem-prof = ["dep:common-mem-prof"] pprof = ["dep:pprof"] +testing = [] [dependencies] aide = { version = "0.9", features = ["axum"] } diff --git a/src/servers/src/grpc.rs b/src/servers/src/grpc.rs index 337add7b76..8ead91f6b3 100644 --- a/src/servers/src/grpc.rs +++ b/src/servers/src/grpc.rs @@ -14,8 +14,9 @@ mod database; pub mod flight; -pub mod handler; +pub mod greptime_handler; pub mod prom_query_gateway; +pub mod region_server; use std::net::SocketAddr; use std::sync::Arc; @@ -23,6 +24,7 @@ use std::sync::Arc; use api::v1::greptime_database_server::{GreptimeDatabase, GreptimeDatabaseServer}; use api::v1::health_check_server::{HealthCheck, HealthCheckServer}; use 
api::v1::prometheus_gateway_server::{PrometheusGateway, PrometheusGatewayServer}; +use api::v1::region::region_server_server::RegionServerServer; use api::v1::{HealthCheckRequest, HealthCheckResponse}; use arrow_flight::flight_service_server::{FlightService, FlightServiceServer}; use async_trait::async_trait; @@ -37,15 +39,14 @@ use tokio::sync::oneshot::{self, Receiver, Sender}; use tokio::sync::Mutex; use tokio_stream::wrappers::TcpListenerStream; use tonic::{Request, Response, Status}; +use tonic_reflection::server::{ServerReflection, ServerReflectionServer}; +use self::flight::{FlightCraftRef, FlightCraftWrapper}; use self::prom_query_gateway::PrometheusGatewayService; -use crate::error::{ - AlreadyStartedSnafu, GrpcReflectionServiceSnafu, InternalSnafu, Result, StartGrpcSnafu, - TcpBindSnafu, -}; +use self::region_server::{RegionServerHandlerRef, RegionServerRequestHandler}; +use crate::error::{AlreadyStartedSnafu, InternalSnafu, Result, StartGrpcSnafu, TcpBindSnafu}; use crate::grpc::database::DatabaseService; -use crate::grpc::flight::FlightHandler; -use crate::grpc::handler::GreptimeRequestHandler; +use crate::grpc::greptime_handler::GreptimeRequestHandler; use crate::prometheus::PrometheusHandlerRef; use crate::query_handler::grpc::ServerGrpcQueryHandlerRef; use crate::server::Server; @@ -53,50 +54,74 @@ use crate::server::Server; type TonicResult = std::result::Result; pub struct GrpcServer { + // states shutdown_tx: Mutex>>, - request_handler: Arc, user_provider: Option, - /// Handler for Prometheus-compatible PromQL queries. Only present for frontend server. - prometheus_handler: Option, /// gRPC serving state receiver. Only present if the gRPC server is started. /// Used to wait for the server to stop, performing the old blocking fashion. serve_state: Mutex>>>, + + // handlers + /// Handler for [GreptimeDatabase] service. + database_handler: Option, + /// Handler for Prometheus-compatible PromQL queries ([PrometheusGateway]). 
Only present for frontend server. + prometheus_handler: Option, + /// Handler for [FlightService]. + flight_handler: Option, + /// Handler for [RegionServer]. + region_server_handler: Option, } impl GrpcServer { pub fn new( query_handler: ServerGrpcQueryHandlerRef, prometheus_handler: Option, + flight_handler: Option, + region_server_handler: Option, user_provider: Option, runtime: Arc, ) -> Self { - let request_handler = Arc::new(GreptimeRequestHandler::new( - query_handler, - user_provider.clone(), - runtime, - )); + let database_handler = + GreptimeRequestHandler::new(query_handler, user_provider.clone(), runtime.clone()); + let region_server_handler = region_server_handler.map(|handler| { + RegionServerRequestHandler::new(handler, user_provider.clone(), runtime.clone()) + }); Self { shutdown_tx: Mutex::new(None), - request_handler, user_provider, - prometheus_handler, serve_state: Mutex::new(None), + database_handler: Some(database_handler), + prometheus_handler, + flight_handler, + region_server_handler, } } + #[cfg(feature = "testing")] pub fn create_flight_service(&self) -> FlightServiceServer { - FlightServiceServer::new(FlightHandler::new(self.request_handler.clone())) + FlightServiceServer::new(FlightCraftWrapper(self.database_handler.clone().unwrap())) } + #[cfg(feature = "testing")] pub fn create_database_service(&self) -> GreptimeDatabaseServer { - GreptimeDatabaseServer::new(DatabaseService::new(self.request_handler.clone())) + GreptimeDatabaseServer::new(DatabaseService::new(self.database_handler.clone().unwrap())) } pub fn create_healthcheck_service(&self) -> HealthCheckServer { HealthCheckServer::new(HealthCheckHandler) } + pub fn create_reflection_service(&self) -> ServerReflectionServer { + tonic_reflection::server::Builder::configure() + .register_encoded_file_descriptor_set(api::v1::GREPTIME_GRPC_DESC) + .with_service_name("greptime.v1.GreptimeDatabase") + .with_service_name("greptime.v1.HealthCheck") + 
.with_service_name("greptime.v1.RegionServer") + .build() + .unwrap() + } + pub fn create_prom_query_gateway_service( &self, handler: PrometheusHandlerRef, @@ -172,22 +197,31 @@ impl Server for GrpcServer { (listener, addr) }; - let reflection_service = tonic_reflection::server::Builder::configure() - .register_encoded_file_descriptor_set(api::v1::GREPTIME_GRPC_DESC) - .with_service_name("greptime.v1.GreptimeDatabase") - .with_service_name("greptime.v1.HealthCheck") - .build() - .context(GrpcReflectionServiceSnafu)?; - let mut builder = tonic::transport::Server::builder() - .add_service(self.create_flight_service()) - .add_service(self.create_database_service()) - .add_service(self.create_healthcheck_service()); + .add_service(self.create_healthcheck_service()) + .add_service(self.create_reflection_service()); + if let Some(database_handler) = &self.database_handler { + builder = builder.add_service(GreptimeDatabaseServer::new(DatabaseService::new( + database_handler.clone(), + ))) + } if let Some(prometheus_handler) = &self.prometheus_handler { builder = builder .add_service(self.create_prom_query_gateway_service(prometheus_handler.clone())) } - let builder = builder.add_service(reflection_service); + if let Some(flight_handler) = &self.flight_handler { + builder = builder.add_service(FlightServiceServer::new(FlightCraftWrapper( + flight_handler.clone(), + ))) + } else { + // TODO(ruihang): this is a temporary workaround before region server is ready. 
+ builder = builder.add_service(FlightServiceServer::new(FlightCraftWrapper( + self.database_handler.clone().unwrap(), + ))) + } + if let Some(region_server_handler) = &self.region_server_handler { + builder = builder.add_service(RegionServerServer::new(region_server_handler.clone())) + } let (serve_state_tx, serve_state_rx) = oneshot::channel(); let mut serve_state = self.serve_state.lock().await; diff --git a/src/servers/src/grpc/database.rs b/src/servers/src/grpc/database.rs index 1c832e6efe..5be5d76254 100644 --- a/src/servers/src/grpc/database.rs +++ b/src/servers/src/grpc/database.rs @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::sync::Arc; - use api::v1::greptime_database_server::GreptimeDatabase; use api::v1::greptime_response::Response as RawResponse; use api::v1::{AffectedRows, GreptimeRequest, GreptimeResponse, ResponseHeader}; @@ -23,15 +21,15 @@ use common_query::Output; use futures::StreamExt; use tonic::{Request, Response, Status, Streaming}; -use crate::grpc::handler::GreptimeRequestHandler; +use crate::grpc::greptime_handler::GreptimeRequestHandler; use crate::grpc::TonicResult; pub(crate) struct DatabaseService { - handler: Arc, + handler: GreptimeRequestHandler, } impl DatabaseService { - pub(crate) fn new(handler: Arc) -> Self { + pub(crate) fn new(handler: GreptimeRequestHandler) -> Self { Self { handler } } } diff --git a/src/servers/src/grpc/flight.rs b/src/servers/src/grpc/flight.rs index 0b793d9855..6f4f7cbca7 100644 --- a/src/servers/src/grpc/flight.rs +++ b/src/servers/src/grpc/flight.rs @@ -32,24 +32,43 @@ use snafu::ResultExt; use tonic::{Request, Response, Status, Streaming}; use crate::error; -use crate::grpc::flight::stream::FlightRecordBatchStream; -use crate::grpc::handler::GreptimeRequestHandler; +pub use crate::grpc::flight::stream::FlightRecordBatchStream; +use crate::grpc::greptime_handler::GreptimeRequestHandler; use crate::grpc::TonicResult; 
-type TonicStream = Pin> + Send + Sync + 'static>>; +pub type TonicStream = Pin> + Send + Sync + 'static>>; -pub struct FlightHandler { - handler: Arc, +/// A subset of [FlightService] +#[async_trait] +pub trait FlightCraft: Send + Sync + 'static { + async fn do_get( + &self, + request: Request, + ) -> TonicResult>>; } -impl FlightHandler { - pub fn new(handler: Arc) -> Self { - Self { handler } +pub type FlightCraftRef = Arc; + +pub struct FlightCraftWrapper(pub T); + +impl From for FlightCraftWrapper { + fn from(t: T) -> Self { + Self(t) } } #[async_trait] -impl FlightService for FlightHandler { +impl FlightCraft for FlightCraftRef { + async fn do_get( + &self, + request: Request, + ) -> TonicResult>> { + (**self).do_get(request).await + } +} + +#[async_trait] +impl FlightService for FlightCraftWrapper { type HandshakeStream = TonicStream; async fn handshake( @@ -85,14 +104,7 @@ impl FlightService for FlightHandler { type DoGetStream = TonicStream; async fn do_get(&self, request: Request) -> TonicResult> { - let ticket = request.into_inner().ticket; - let request = - GreptimeRequest::decode(ticket.as_ref()).context(error::InvalidFlightTicketSnafu)?; - - let output = self.handler.handle_request(request).await?; - - let stream = to_flight_data_stream(output); - Ok(Response::new(stream)) + self.0.do_get(request).await } type DoPutStream = TonicStream; @@ -129,6 +141,24 @@ impl FlightService for FlightHandler { } } +#[async_trait] +impl FlightCraft for GreptimeRequestHandler { + async fn do_get( + &self, + request: Request, + ) -> TonicResult>> { + let ticket = request.into_inner().ticket; + let request = + GreptimeRequest::decode(ticket.as_ref()).context(error::InvalidFlightTicketSnafu)?; + + let output = self.handle_request(request).await?; + + let stream: Pin> + Send + Sync>> = + to_flight_data_stream(output); + Ok(Response::new(stream)) + } +} + fn to_flight_data_stream(output: Output) -> TonicStream { match output { Output::Stream(stream) => { diff --git 
a/src/servers/src/grpc/flight/stream.rs b/src/servers/src/grpc/flight/stream.rs index 0048da2ed8..5ff570608e 100644 --- a/src/servers/src/grpc/flight/stream.rs +++ b/src/servers/src/grpc/flight/stream.rs @@ -30,7 +30,7 @@ use super::TonicResult; use crate::error; #[pin_project(PinnedDrop)] -pub(super) struct FlightRecordBatchStream { +pub struct FlightRecordBatchStream { #[pin] rx: mpsc::Receiver>, join_handle: JoinHandle<()>, @@ -39,7 +39,7 @@ pub(super) struct FlightRecordBatchStream { } impl FlightRecordBatchStream { - pub(super) fn new(recordbatches: SendableRecordBatchStream) -> Self { + pub fn new(recordbatches: SendableRecordBatchStream) -> Self { let (tx, rx) = mpsc::channel::>(1); let join_handle = common_runtime::spawn_read( diff --git a/src/servers/src/grpc/handler.rs b/src/servers/src/grpc/greptime_handler.rs similarity index 96% rename from src/servers/src/grpc/handler.rs rename to src/servers/src/grpc/greptime_handler.rs index ed6e36ad2e..873a6293fb 100644 --- a/src/servers/src/grpc/handler.rs +++ b/src/servers/src/grpc/greptime_handler.rs @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +//! Handler for Greptime Database service. It's implemented by frontend. + use std::sync::Arc; use std::time::Instant; @@ -38,6 +40,7 @@ use crate::metrics::{ }; use crate::query_handler::grpc::ServerGrpcQueryHandlerRef; +#[derive(Clone)] pub struct GreptimeRequestHandler { handler: ServerGrpcQueryHandlerRef, user_provider: Option, @@ -174,7 +177,7 @@ pub(crate) fn create_query_context(header: Option<&RequestHeader>) -> QueryConte /// Histogram timer for handling gRPC request. /// /// The timer records the elapsed time with [StatusCode::Success] on drop. -struct RequestTimer { +pub(crate) struct RequestTimer { start: Instant, db: String, request_type: &'static str, @@ -183,7 +186,7 @@ struct RequestTimer { impl RequestTimer { /// Returns a new timer. 
- fn new(db: String, request_type: &'static str) -> RequestTimer { + pub fn new(db: String, request_type: &'static str) -> RequestTimer { RequestTimer { start: Instant::now(), db, @@ -193,7 +196,7 @@ impl RequestTimer { } /// Consumes the timer and record the elapsed time with specific `status_code`. - fn record(mut self, status_code: StatusCode) { + pub fn record(mut self, status_code: StatusCode) { self.status_code = status_code; } } diff --git a/src/servers/src/grpc/prom_query_gateway.rs b/src/servers/src/grpc/prom_query_gateway.rs index 1cae3b3a45..02d74839c4 100644 --- a/src/servers/src/grpc/prom_query_gateway.rs +++ b/src/servers/src/grpc/prom_query_gateway.rs @@ -33,7 +33,7 @@ use snafu::OptionExt; use tonic::{Request, Response}; use crate::error::InvalidQuerySnafu; -use crate::grpc::handler::{auth, create_query_context}; +use crate::grpc::greptime_handler::{auth, create_query_context}; use crate::grpc::TonicResult; use crate::prometheus::{ retrieve_metric_name_and_result_type, PrometheusHandlerRef, PrometheusJsonResponse, diff --git a/src/servers/src/grpc/region_server.rs b/src/servers/src/grpc/region_server.rs new file mode 100644 index 0000000000..e3a7c06673 --- /dev/null +++ b/src/servers/src/grpc/region_server.rs @@ -0,0 +1,195 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +use std::sync::Arc; + +use api::helper::region_request_type; +use api::v1::auth_header::AuthScheme; +use api::v1::region::region_request::Request as RequestBody; +use api::v1::region::region_server_server::RegionServer as RegionServerService; +use api::v1::region::{RegionRequest, RegionResponse}; +use api::v1::{Basic, RequestHeader}; +use async_trait::async_trait; +use auth::{Identity, Password, UserInfoRef, UserProviderRef}; +use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; +use common_catalog::parse_catalog_and_schema_from_db_string; +use common_error::ext::ErrorExt; +use common_runtime::Runtime; +use common_telemetry::{debug, error}; +use metrics::increment_counter; +use session::context::{QueryContextBuilder, QueryContextRef}; +use snafu::{OptionExt, ResultExt}; +use tonic::{Request, Response}; + +use crate::error::{ + AuthSnafu, InvalidQuerySnafu, JoinTaskSnafu, NotFoundAuthHeaderSnafu, Result, + UnsupportedAuthSchemeSnafu, +}; +use crate::grpc::greptime_handler::RequestTimer; +use crate::grpc::TonicResult; +use crate::metrics::{METRIC_AUTH_FAILURE, METRIC_CODE_LABEL}; + +#[async_trait] +pub trait RegionServerHandler: Send + Sync { + async fn handle(&self, request: RequestBody) -> Result; +} + +pub type RegionServerHandlerRef = Arc; + +#[derive(Clone)] +pub struct RegionServerRequestHandler { + handler: Arc, + user_provider: Option, + runtime: Arc, +} + +impl RegionServerRequestHandler { + pub fn new( + handler: Arc, + user_provider: Option, + runtime: Arc, + ) -> Self { + Self { + handler, + user_provider, + runtime, + } + } + + async fn handle(&self, request: RegionRequest) -> Result { + let query = request.request.context(InvalidQuerySnafu { + reason: "Expecting non-empty GreptimeRequest.", + })?; + + let header = request.header.as_ref(); + let query_ctx = create_query_context(header); + let user_info = self.auth(header, &query_ctx).await?; + query_ctx.set_current_user(user_info); + + let handler = self.handler.clone(); + let 
request_type = region_request_type(&query); + let db = query_ctx.get_db_string(); + let timer = RequestTimer::new(db.clone(), request_type); + + // Executes requests in another runtime to + // 1. prevent the execution from being cancelled unexpected by Tonic runtime; + // - Refer to our blog for the rational behind it: + // https://www.greptime.com/blogs/2023-01-12-hidden-control-flow.html + // - Obtaining a `JoinHandle` to get the panic message (if there's any). + // From its docs, `JoinHandle` is cancel safe. The task keeps running even it's handle been dropped. + // 2. avoid the handler blocks the gRPC runtime incidentally. + let handle = self.runtime.spawn(async move { + handler.handle(query).await.map_err(|e| { + if e.status_code().should_log_error() { + error!(e; "Failed to handle request"); + } else { + // Currently, we still print a debug log. + debug!("Failed to handle request, err: {}", e); + } + e + }) + }); + + handle.await.context(JoinTaskSnafu).map_err(|e| { + timer.record(e.status_code()); + e + })? 
+ } + + async fn auth( + &self, + header: Option<&RequestHeader>, + query_ctx: &QueryContextRef, + ) -> Result> { + let Some(user_provider) = self.user_provider.as_ref() else { + return Ok(None); + }; + + let auth_scheme = header + .and_then(|header| { + header + .authorization + .as_ref() + .and_then(|x| x.auth_scheme.clone()) + }) + .context(NotFoundAuthHeaderSnafu)?; + + match auth_scheme { + AuthScheme::Basic(Basic { username, password }) => user_provider + .auth( + Identity::UserId(&username, None), + Password::PlainText(password.into()), + query_ctx.current_catalog(), + query_ctx.current_schema(), + ) + .await + .context(AuthSnafu), + AuthScheme::Token(_) => UnsupportedAuthSchemeSnafu { + name: "Token AuthScheme".to_string(), + } + .fail(), + } + .map(Some) + .map_err(|e| { + increment_counter!( + METRIC_AUTH_FAILURE, + &[(METRIC_CODE_LABEL, format!("{}", e.status_code()))] + ); + e + }) + } +} + +pub(crate) fn create_query_context(header: Option<&RequestHeader>) -> QueryContextRef { + let (catalog, schema) = header + .map(|header| { + // We provide dbname field in newer versions of protos/sdks + // parse dbname from header in priority + if !header.dbname.is_empty() { + parse_catalog_and_schema_from_db_string(&header.dbname) + } else { + ( + if !header.catalog.is_empty() { + &header.catalog + } else { + DEFAULT_CATALOG_NAME + }, + if !header.schema.is_empty() { + &header.schema + } else { + DEFAULT_SCHEMA_NAME + }, + ) + } + }) + .unwrap_or((DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME)); + + QueryContextBuilder::default() + .current_catalog(catalog.to_string()) + .current_schema(schema.to_string()) + .try_trace_id(header.and_then(|h: &RequestHeader| h.trace_id)) + .build() +} + +#[async_trait] +impl RegionServerService for RegionServerRequestHandler { + async fn handle( + &self, + request: Request, + ) -> TonicResult> { + let request = request.into_inner(); + let response = self.handle(request).await?; + Ok(Response::new(response)) + } +} diff --git 
a/src/servers/tests/grpc/mod.rs b/src/servers/tests/grpc/mod.rs index 82a98f1938..96becf5f94 100644 --- a/src/servers/tests/grpc/mod.rs +++ b/src/servers/tests/grpc/mod.rs @@ -24,8 +24,8 @@ use auth::UserProviderRef; use client::{Client, Database, DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; use common_runtime::{Builder as RuntimeBuilder, Runtime}; use servers::error::{Result, StartGrpcSnafu, TcpBindSnafu}; -use servers::grpc::flight::FlightHandler; -use servers::grpc::handler::GreptimeRequestHandler; +use servers::grpc::flight::FlightCraftWrapper; +use servers::grpc::greptime_handler::GreptimeRequestHandler; use servers::query_handler::grpc::ServerGrpcQueryHandlerRef; use servers::server::Server; use snafu::ResultExt; @@ -55,11 +55,12 @@ impl MockGrpcServer { } fn create_service(&self) -> FlightServiceServer { - let service = FlightHandler::new(Arc::new(GreptimeRequestHandler::new( + let service: FlightCraftWrapper<_> = GreptimeRequestHandler::new( self.query_handler.clone(), self.user_provider.clone(), self.runtime.clone(), - ))); + ) + .into(); FlightServiceServer::new(service) } } diff --git a/tests-integration/Cargo.toml b/tests-integration/Cargo.toml index 4baed8119c..ea69c1af25 100644 --- a/tests-integration/Cargo.toml +++ b/tests-integration/Cargo.toml @@ -43,7 +43,7 @@ rstest_reuse = "0.5" secrecy = "0.8" serde.workspace = true serde_json = "1.0" -servers = { workspace = true } +servers = { workspace = true, features = ["testing"] } session = { workspace = true } snafu.workspace = true sql = { workspace = true } diff --git a/tests-integration/src/cluster.rs b/tests-integration/src/cluster.rs index b792e6b59c..350a6a27c1 100644 --- a/tests-integration/src/cluster.rs +++ b/tests-integration/src/cluster.rs @@ -38,6 +38,7 @@ use meta_srv::metasrv::{MetaSrv, MetaSrvOptions}; use meta_srv::mocks::MockInfo; use meta_srv::service::store::kv::{KvBackendAdapter, KvStoreRef}; use meta_srv::service::store::memory::MemStore; +use 
servers::grpc::greptime_handler::GreptimeRequestHandler; use servers::grpc::GrpcServer; use servers::query_handler::grpc::ServerGrpcQueryHandlerAdaptor; use servers::Mode; @@ -287,9 +288,16 @@ async fn create_datanode_client(datanode_instance: Arc) -> (St // create a mock datanode grpc service, see example here: // https://github.com/hyperium/tonic/blob/master/examples/src/mock/mock.rs + let query_handler = Arc::new(GreptimeRequestHandler::new( + ServerGrpcQueryHandlerAdaptor::arc(datanode_instance.clone()), + None, + runtime.clone(), + )); let grpc_server = GrpcServer::new( ServerGrpcQueryHandlerAdaptor::arc(datanode_instance), None, + Some(query_handler), + None, None, runtime, ); diff --git a/tests-integration/src/test_util.rs b/tests-integration/src/test_util.rs index c963d06a42..c152ec6acb 100644 --- a/tests-integration/src/test_util.rs +++ b/tests-integration/src/test_util.rs @@ -49,6 +49,7 @@ use object_store::services::{Azblob, Gcs, Oss, S3}; use object_store::test_util::TempFolder; use object_store::ObjectStore; use secrecy::ExposeSecret; +use servers::grpc::greptime_handler::GreptimeRequestHandler; use servers::grpc::GrpcServer; use servers::http::{HttpOptions, HttpServerBuilder}; use servers::metrics_handler::MetricsHandler; @@ -583,9 +584,16 @@ pub async fn setup_grpc_server_with_user_provider( heartbeat.start().await.unwrap(); } let fe_instance_ref = Arc::new(fe_instance); + let flight_handler = Arc::new(GreptimeRequestHandler::new( + ServerGrpcQueryHandlerAdaptor::arc(fe_instance_ref.clone()), + user_provider.clone(), + runtime.clone(), + )); let fe_grpc_server = Arc::new(GrpcServer::new( ServerGrpcQueryHandlerAdaptor::arc(fe_instance_ref.clone()), Some(fe_instance_ref.clone()), + Some(flight_handler), + None, user_provider, runtime, ));