diff --git a/src/frontend/src/server.rs b/src/frontend/src/server.rs index d70aa3dd49..bc9399fe60 100644 --- a/src/frontend/src/server.rs +++ b/src/frontend/src/server.rs @@ -22,6 +22,7 @@ use common_telemetry::info; use meta_client::MetaClientOptions; use servers::error::Error as ServerError; use servers::grpc::builder::GrpcServerBuilder; +use servers::grpc::flight::FlightCraftRef; use servers::grpc::frontend_grpc_handler::FrontendGrpcHandler; use servers::grpc::greptime_handler::GreptimeRequestHandler; use servers::grpc::{GrpcOptions, GrpcServer}; @@ -52,6 +53,7 @@ where grpc_server_builder: Option, http_server_builder: Option, plugins: Plugins, + flight_handler: Option, } impl Services @@ -65,6 +67,7 @@ where grpc_server_builder: None, http_server_builder: None, plugins, + flight_handler: None, } } @@ -139,6 +142,13 @@ where } } + pub fn with_flight_handler(self, flight_handler: FlightCraftRef) -> Self { + Self { + flight_handler: Some(flight_handler), + ..self + } + } + fn build_grpc_server( &mut self, grpc: &GrpcOptions, @@ -173,6 +183,12 @@ where grpc.flight_compression, ); + // Use custom flight handler if provided, otherwise use the default GreptimeRequestHandler + let flight_handler = self + .flight_handler + .clone() + .unwrap_or_else(|| Arc::new(greptime_request_handler.clone()) as FlightCraftRef); + let grpc_server = builder .name(name) .database_handler(greptime_request_handler.clone()) @@ -181,7 +197,7 @@ where self.instance.clone(), user_provider.clone(), )) - .flight_handler(Arc::new(greptime_request_handler)); + .flight_handler(flight_handler); let grpc_server = if !external { let frontend_grpc_handler = diff --git a/src/servers/src/grpc/flight.rs b/src/servers/src/grpc/flight.rs index 44b307fe71..8cabcb7fec 100644 --- a/src/servers/src/grpc/flight.rs +++ b/src/servers/src/grpc/flight.rs @@ -249,11 +249,11 @@ impl FlightCraft for GreptimeRequestHandler { } } -pub(crate) struct PutRecordBatchRequest { - pub(crate) table_name: TableName, - pub(crate) request_id: i64, - pub(crate) data: FlightData, - pub(crate) _guard: Option, +pub struct PutRecordBatchRequest { + pub table_name: TableName, + pub request_id: i64, + pub data: FlightData, + pub _guard: Option, } impl PutRecordBatchRequest { @@ -297,13 +297,13 @@ impl PutRecordBatchRequest { } } -pub(crate) struct PutRecordBatchRequestStream { - flight_data_stream: Streaming, - state: PutRecordBatchRequestStreamState, - limiter: Option, +pub struct PutRecordBatchRequestStream { + pub flight_data_stream: Streaming, + pub state: PutRecordBatchRequestStreamState, + pub limiter: Option, } -enum PutRecordBatchRequestStreamState { +pub enum PutRecordBatchRequestStreamState { Init(String, String), Started(TableName), }