From 11ecb7a28a6101eda8d6bba184d96c66f62f11ae Mon Sep 17 00:00:00 2001
From: "Lei, HUANG" <6406592+v0y4g3r@users.noreply.github.com>
Date: Wed, 3 Dec 2025 23:08:02 -0800
Subject: [PATCH] refactor(servers): bulk insert service (#7329)

* refactor/bulk-insert-service: refactor: decode FlightData early in put_record_batch pipeline

  - Move FlightDecoder usage from Inserter up to PutRecordBatchRequestStream,
    passing the decoded RecordBatch and schema bytes instead of raw FlightData.
  - Eliminate redundant per-request decoding/encoding in Inserter; encode once
    and reuse for all region requests.
  - Streamline the GrpcQueryHandler trait and its implementations to accept a
    PutRecordBatchRequest containing pre-decoded data.

  Signed-off-by: Lei, HUANG

* refactor/bulk-insert-service: feat: stream-based bulk insert with per-batch responses

  - Introduce handle_put_record_batch_stream() to process Flight DoPut streams
  - Resolve the table and permissions once, yield (request_id, AffectedRows) per batch
  - Replace the loop-over-requests with async-stream in frontend & server
  - Make PutRecordBatchRequestStream public for cross-crate usage

  Signed-off-by: Lei, HUANG

* refactor/bulk-insert-service: fix: propagate request_id with errors in bulk insert stream

  Changes the bulk-insert stream item type from Result<(i64, AffectedRows), E>
  to (i64, Result<AffectedRows, E>) so every emitted tuple carries the
  request_id even on failure, letting callers correlate errors with the
  originating request.

  Signed-off-by: Lei, HUANG

* refactor/bulk-insert-service: refactor: unify DoPut response stream to return DoPutResponse

  Replace the tuple (i64, Result<AffectedRows, E>) with Result<DoPutResponse, E>
  throughout the gRPC bulk-insert path so the handler, adapter and server all
  speak the same type.

  Signed-off-by: Lei, HUANG

* refactor/bulk-insert-service: feat: add elapsed_secs to DoPutResponse for bulk-insert timing

  - DoPutResponse now carries an elapsed_secs field
  - Frontend measures and attaches the insert duration
  - Server observes the GRPC_BULK_INSERT_ELAPSED metric from the response

  Signed-off-by: Lei, HUANG

* refactor/bulk-insert-service: refactor: unify Bytes import in flight module

  - Replace `bytes::Bytes` with the `Bytes` alias for consistency
  - Remove the redundant `ProstBytes` alias

  Signed-off-by: Lei, HUANG

* refactor/bulk-insert-service: fix: terminate gRPC stream on error and optimize FlightData handling

  - Stop retrying on stream errors in the gRPC handler
  - Replace Vec1 indexing with into_iter().next() for FlightData
  - Remove redundant clones in the bulk_insert and flight modules

  Signed-off-by: Lei, HUANG

* refactor/bulk-insert-service: Improve permission check placement in `grpc.rs`

  - Moved the permission check for `BulkInsert` so that it occurs before
    resolving the table reference in the `GrpcQueryHandler` implementation.
  - Ensures permission validation is performed earlier in the process,
    avoiding unnecessary work when permission is denied.

  Signed-off-by: Lei, HUANG

* refactor/bulk-insert-service: **Refactor Bulk Insert Handling in gRPC**

  - **`grpc.rs`**:
    - Switched from `async_stream::stream` to `async_stream::try_stream` for error handling.
    - Removed the `body_size` parameter and added `flight_data` to `handle_bulk_insert`.
    - Simplified error handling and permission checks in `GrpcQueryHandler`.
  - **`bulk_insert.rs`**:
    - Added a `raw_flight_data` parameter to `handle_bulk_insert`.
    - Calculated `body_size` from `raw_flight_data` and removed redundant encoding logic.
  - **`flight.rs`**:
    - Replaced `body_size` with `flight_data` in `PutRecordBatchRequest`.
    - Updated the memory usage calculation to include the `flight_data` components.
  Signed-off-by: Lei, HUANG

* refactor/bulk-insert-service: perf(bulk_insert): encode record batch once per datanode

  Move FlightData encoding outside the per-region loop so the same encoded
  bytes are reused when mask.select_all() is true, eliminating redundant
  serialisation work.

  Signed-off-by: Lei, HUANG

---------

Signed-off-by: Lei, HUANG
---
 src/common/grpc/src/flight/do_put.rs     |  16 +-
 src/frontend/src/instance/grpc.rs        |  89 ++++++++-
 src/operator/src/bulk_insert.rs          |  48 ++---
 src/servers/src/grpc/flight.rs           | 240 +++++++++++++++--------
 src/servers/src/grpc/greptime_handler.rs |  53 ++---
 src/servers/src/query_handler/grpc.rs    |  40 +++-
 src/servers/tests/mod.rs                 |  18 +-
 7 files changed, 330 insertions(+), 174 deletions(-)

diff --git a/src/common/grpc/src/flight/do_put.rs b/src/common/grpc/src/flight/do_put.rs
index 15011fc74b..7997b7ba79 100644
--- a/src/common/grpc/src/flight/do_put.rs
+++ b/src/common/grpc/src/flight/do_put.rs
@@ -46,13 +46,16 @@ pub struct DoPutResponse {
     request_id: i64,
     /// The successfully ingested rows number.
     affected_rows: AffectedRows,
+    /// The elapsed time in seconds for handling the bulk insert.
+    elapsed_secs: f64,
 }
 
 impl DoPutResponse {
-    pub fn new(request_id: i64, affected_rows: AffectedRows) -> Self {
+    pub fn new(request_id: i64, affected_rows: AffectedRows, elapsed_secs: f64) -> Self {
         Self {
             request_id,
             affected_rows,
+            elapsed_secs,
         }
     }
 
@@ -63,6 +66,10 @@ impl DoPutResponse {
     pub fn affected_rows(&self) -> AffectedRows {
         self.affected_rows
     }
+
+    pub fn elapsed_secs(&self) -> f64 {
+        self.elapsed_secs
+    }
 }
 
 impl TryFrom for DoPutResponse {
@@ -86,8 +93,11 @@ mod tests {
 
     #[test]
     fn test_serde_do_put_response() {
-        let x = DoPutResponse::new(42, 88);
+        let x = DoPutResponse::new(42, 88, 0.123);
         let serialized = serde_json::to_string(&x).unwrap();
-        assert_eq!(serialized, r#"{"request_id":42,"affected_rows":88}"#);
+        assert_eq!(
+            serialized,
+            r#"{"request_id":42,"affected_rows":88,"elapsed_secs":0.123}"#
+        );
     }
 }
diff --git a/src/frontend/src/instance/grpc.rs b/src/frontend/src/instance/grpc.rs
index 9eeb57ce01..09736c5c7f 100644
--- a/src/frontend/src/instance/grpc.rs
+++ b/src/frontend/src/instance/grpc.rs
@@ -12,7 +12,9 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+use std::pin::Pin; use std::sync::Arc; +use std::time::Instant; use api::helper::from_pb_time_ranges; use api::v1::ddl_request::{Expr as DdlExpr, Expr}; @@ -22,16 +24,18 @@ use api::v1::{ DeleteRequests, DropFlowExpr, InsertIntoPlan, InsertRequests, RowDeleteRequests, RowInsertRequests, }; +use async_stream::try_stream; use async_trait::async_trait; use auth::{PermissionChecker, PermissionCheckerRef, PermissionReq}; use common_base::AffectedRows; use common_error::ext::BoxedError; -use common_grpc::FlightData; -use common_grpc::flight::FlightDecoder; +use common_grpc::flight::do_put::DoPutResponse; use common_query::Output; use common_query::logical_plan::add_insert_to_logical_plan; use common_telemetry::tracing::{self}; use datafusion::datasource::DefaultTableSource; +use futures::Stream; +use futures::stream::StreamExt; use query::parser::PromQuery; use servers::interceptor::{GrpcQueryInterceptor, GrpcQueryInterceptorRef}; use servers::query_handler::grpc::GrpcQueryHandler; @@ -240,10 +244,8 @@ impl GrpcQueryHandler for Instance { async fn put_record_batch( &self, - table_name: &TableName, + request: servers::grpc::flight::PutRecordBatchRequest, table_ref: &mut Option, - decoder: &mut FlightDecoder, - data: FlightData, ctx: QueryContextRef, ) -> Result { let table = if let Some(table) = table_ref { @@ -252,15 +254,15 @@ impl GrpcQueryHandler for Instance { let table = self .catalog_manager() .table( - &table_name.catalog_name, - &table_name.schema_name, - &table_name.table_name, + &request.table_name.catalog_name, + &request.table_name.schema_name, + &request.table_name.table_name, None, ) .await .context(CatalogSnafu)? .with_context(|| TableNotFoundSnafu { - table_name: table_name.to_string(), + table_name: request.table_name.to_string(), })?; *table_ref = Some(table.clone()); table @@ -279,10 +281,77 @@ impl GrpcQueryHandler for Instance { // do we check limit for bulk insert? self.inserter - .handle_bulk_insert(table, decoder, data) + .handle_bulk_insert( + table, + request.flight_data, + request.record_batch, + request.schema_bytes, + ) .await .context(TableOperationSnafu) } + + fn handle_put_record_batch_stream( + &self, + mut stream: servers::grpc::flight::PutRecordBatchRequestStream, + ctx: QueryContextRef, + ) -> Pin> + Send>> { + // Resolve table once for the stream + // Clone all necessary data to make it 'static + let catalog_manager = self.catalog_manager().clone(); + let plugins = self.plugins.clone(); + let inserter = self.inserter.clone(); + let table_name = stream.table_name().clone(); + let ctx = ctx.clone(); + + Box::pin(try_stream! { + plugins + .get::() + .as_ref() + .check_permission(ctx.current_user(), PermissionReq::BulkInsert) + .context(PermissionSnafu)?; + // Cache for resolved table reference - resolve once and reuse + let table_ref = catalog_manager + .table( + &table_name.catalog_name, + &table_name.schema_name, + &table_name.table_name, + None, + ) + .await + .context(CatalogSnafu)? 
+ .with_context(|| TableNotFoundSnafu { + table_name: table_name.to_string(), + })?; + + // Check permissions once for the stream + let interceptor_ref = plugins.get::>(); + let interceptor = interceptor_ref.as_ref(); + interceptor.pre_bulk_insert(table_ref.clone(), ctx.clone())?; + + // Process each request in the stream + while let Some(request_result) = stream.next().await { + let request = request_result.map_err(|e| { + let error_msg = format!("Stream error: {:?}", e); + IncompleteGrpcRequestSnafu { err_msg: error_msg }.build() + })?; + + let request_id = request.request_id; + let start = Instant::now(); + let rows = inserter + .handle_bulk_insert( + table_ref.clone(), + request.flight_data, + request.record_batch, + request.schema_bytes, + ) + .await + .context(TableOperationSnafu)?; + let elapsed_secs = start.elapsed().as_secs_f64(); + yield DoPutResponse::new(request_id, rows, elapsed_secs); + } + }) + } } fn fill_catalog_and_schema_from_context(ddl_expr: &mut DdlExpr, ctx: &QueryContextRef) { diff --git a/src/operator/src/bulk_insert.rs b/src/operator/src/bulk_insert.rs index a06cc9503c..cfc427e19c 100644 --- a/src/operator/src/bulk_insert.rs +++ b/src/operator/src/bulk_insert.rs @@ -22,9 +22,10 @@ use api::v1::region::{ }; use arrow::array::Array; use arrow::record_batch::RecordBatch; +use bytes::Bytes; use common_base::AffectedRows; use common_grpc::FlightData; -use common_grpc::flight::{FlightDecoder, FlightEncoder, FlightMessage}; +use common_grpc::flight::{FlightEncoder, FlightMessage}; use common_telemetry::error; use common_telemetry::tracing_context::TracingContext; use snafu::{OptionExt, ResultExt, ensure}; @@ -40,33 +41,20 @@ impl Inserter { pub async fn handle_bulk_insert( &self, table: TableRef, - decoder: &mut FlightDecoder, - data: FlightData, + raw_flight_data: FlightData, + record_batch: RecordBatch, + schema_bytes: Bytes, ) -> error::Result { let table_info = table.table_info(); let table_id = table_info.table_id(); let db_name = table_info.get_db_string(); - let decode_timer = metrics::HANDLE_BULK_INSERT_ELAPSED - .with_label_values(&["decode_request"]) - .start_timer(); - let body_size = data.data_body.len(); - // Build region server requests - let message = decoder - .try_decode(&data) - .context(error::DecodeFlightDataSnafu)? - .context(error::NotSupportedSnafu { - feat: "bulk insert RecordBatch with dictionary arrays", - })?; - let FlightMessage::RecordBatch(record_batch) = message else { - return Ok(0); - }; - decode_timer.observe_duration(); if record_batch.num_rows() == 0 { return Ok(0); } - // TODO(yingwen): Fill record batch impure default values. + let body_size = raw_flight_data.data_body.len(); + // TODO(yingwen): Fill record batch impure default values. Note that we should override `raw_flight_data` if we have to fill defaults. // notify flownode to update dirty timestamps if flow is configured. self.maybe_update_flow_dirty_window(table_info.clone(), record_batch.clone()); @@ -75,8 +63,6 @@ impl Inserter { .with_label_values(&["raw"]) .observe(record_batch.num_rows() as f64); - // safety: when reach here schema must be present. 
- let schema_bytes = decoder.schema_bytes().unwrap(); let partition_timer = metrics::HANDLE_BULK_INSERT_ELAPSED .with_label_values(&["partition"]) .start_timer(); @@ -106,6 +92,7 @@ impl Inserter { .find_region_leader(region_id) .await .context(error::FindRegionLeaderSnafu)?; + let request = RegionRequest { header: Some(RegionRequestHeader { tracing_context: TracingContext::from_current_span().to_w3c(), @@ -114,9 +101,9 @@ impl Inserter { body: Some(region_request::Body::BulkInsert(BulkInsertRequest { region_id: region_id.as_u64(), body: Some(bulk_insert_request::Body::ArrowIpc(ArrowIpc { - schema: schema_bytes, - data_header: data.data_header, - payload: data.data_body, + schema: schema_bytes.clone(), + data_header: raw_flight_data.data_header, + payload: raw_flight_data.data_body, })), })), }; @@ -158,8 +145,6 @@ impl Inserter { let mut handles = Vec::with_capacity(mask_per_datanode.len()); - // raw daya header and payload bytes. - let mut raw_data_bytes = None; for (peer, masks) in mask_per_datanode { for (region_id, mask) in masks { if mask.select_none() { @@ -170,13 +155,10 @@ impl Inserter { let node_manager = self.node_manager.clone(); let peer = peer.clone(); let raw_header_and_data = if mask.select_all() { - Some( - raw_data_bytes - .get_or_insert_with(|| { - (data.data_header.clone(), data.data_body.clone()) - }) - .clone(), - ) + Some(( + raw_flight_data.data_header.clone(), + raw_flight_data.data_body.clone(), + )) } else { None }; diff --git a/src/servers/src/grpc/flight.rs b/src/servers/src/grpc/flight.rs index 8cabcb7fec..0cc3ddd7f8 100644 --- a/src/servers/src/grpc/flight.rs +++ b/src/servers/src/grpc/flight.rs @@ -25,12 +25,15 @@ use arrow_flight::{ HandshakeRequest, HandshakeResponse, PollInfo, PutResult, SchemaResult, Ticket, }; use async_trait::async_trait; +use bytes; use bytes::Bytes; use common_grpc::flight::do_put::{DoPutMetadata, DoPutResponse}; -use common_grpc::flight::{FlightEncoder, FlightMessage}; +use common_grpc::flight::{FlightDecoder, FlightEncoder, FlightMessage}; use common_query::{Output, OutputData}; +use common_recordbatch::DfRecordBatch; use common_telemetry::tracing::info_span; use common_telemetry::tracing_context::{FutureExt, TracingContext}; +use datatypes::arrow::datatypes::SchemaRef; use futures::{Stream, future, ready}; use futures_util::{StreamExt, TryStreamExt}; use prost::Message; @@ -41,7 +44,7 @@ use tokio::sync::mpsc; use tokio_stream::wrappers::ReceiverStream; use tonic::{Request, Response, Status, Streaming}; -use crate::error::{InvalidParameterSnafu, ParseJsonSnafu, Result, ToJsonSnafu}; +use crate::error::{InvalidParameterSnafu, Result, ToJsonSnafu}; pub use crate::grpc::flight::stream::FlightRecordBatchStream; use crate::grpc::greptime_handler::{GreptimeRequestHandler, get_request_type}; use crate::grpc::{FlightCompression, TonicResult, context_auth}; @@ -223,14 +226,15 @@ impl FlightCraft for GreptimeRequestHandler { const MAX_PENDING_RESPONSES: usize = 32; let (tx, rx) = mpsc::channel::>(MAX_PENDING_RESPONSES); - let stream = PutRecordBatchRequestStream { - flight_data_stream: stream, - state: PutRecordBatchRequestStreamState::Init( - query_ctx.current_catalog().to_string(), - query_ctx.current_schema(), - ), + let stream = PutRecordBatchRequestStream::new( + stream, + query_ctx.current_catalog().to_string(), + query_ctx.current_schema(), limiter, - }; + ) + .await?; + // Ack to the first schema message when we successfully built the stream. 
+ let _ = tx.send(Ok(DoPutResponse::new(0, 0, 0.0))).await; self.put_record_batches(stream, tx, query_ctx).await; let response = ReceiverStream::new(rx) @@ -252,30 +256,30 @@ impl FlightCraft for GreptimeRequestHandler { pub struct PutRecordBatchRequest { pub table_name: TableName, pub request_id: i64, - pub data: FlightData, - pub _guard: Option, + pub record_batch: DfRecordBatch, + pub schema_bytes: Bytes, + pub flight_data: FlightData, + pub(crate) _guard: Option, } impl PutRecordBatchRequest { fn try_new( table_name: TableName, + record_batch: DfRecordBatch, + request_id: i64, + schema_bytes: Bytes, flight_data: FlightData, limiter: Option<&RequestMemoryLimiter>, ) -> Result { - let request_id = if !flight_data.app_metadata.is_empty() { - let metadata: DoPutMetadata = - serde_json::from_slice(&flight_data.app_metadata).context(ParseJsonSnafu)?; - metadata.request_id() - } else { - 0 - }; + let memory_usage = flight_data.data_body.len() + + flight_data.app_metadata.len() + + flight_data.data_header.len(); let _guard = limiter .filter(|limiter| limiter.is_enabled()) .map(|limiter| { - let message_size = flight_data.encoded_len(); limiter - .try_acquire(message_size) + .try_acquire(memory_usage) .map(|guard| { guard.inspect(|g| { METRIC_GRPC_MEMORY_USAGE_BYTES.set(g.current_usage() as i64); @@ -291,27 +295,32 @@ impl PutRecordBatchRequest { Ok(Self { table_name, request_id, - data: flight_data, + record_batch, + schema_bytes, + flight_data, _guard, }) } } pub struct PutRecordBatchRequestStream { - pub flight_data_stream: Streaming, - pub state: PutRecordBatchRequestStreamState, - pub limiter: Option, + flight_data_stream: Streaming, + table_name: TableName, + schema: SchemaRef, + schema_bytes: Bytes, + decoder: FlightDecoder, + limiter: Option, } -pub enum PutRecordBatchRequestStreamState { - Init(String, String), - Started(TableName), -} - -impl Stream for PutRecordBatchRequestStream { - type Item = TonicResult; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { +impl PutRecordBatchRequestStream { + /// Creates a new `PutRecordBatchRequestStream` by waiting for the first message, + /// extracting the table name from the flight descriptor, and decoding the schema. 
+ pub async fn new( + mut flight_data_stream: Streaming, + catalog: String, + schema: String, + limiter: Option, + ) -> TonicResult { fn extract_table_name(mut descriptor: FlightDescriptor) -> Result { ensure!( descriptor.r#type == arrow_flight::flight_descriptor::DescriptorType::Path as i32, @@ -328,56 +337,131 @@ impl Stream for PutRecordBatchRequestStream { Ok(descriptor.path.remove(0)) } - let poll = ready!(self.flight_data_stream.poll_next_unpin(cx)); - let limiter = self.limiter.clone(); + // Wait for the first message which must be a Schema message + let first_message = flight_data_stream.next().await.ok_or_else(|| { + Status::failed_precondition("flight data stream ended unexpectedly") + })??; - let result = match &mut self.state { - PutRecordBatchRequestStreamState::Init(catalog, schema) => match poll { - Some(Ok(mut flight_data)) => { - let flight_descriptor = flight_data.flight_descriptor.take(); - let result = if let Some(descriptor) = flight_descriptor { - let table_name = extract_table_name(descriptor) - .map(|x| TableName::new(catalog.clone(), schema.clone(), x)); - let table_name = match table_name { - Ok(table_name) => table_name, - Err(e) => return Poll::Ready(Some(Err(e.into()))), - }; + let flight_descriptor = first_message + .flight_descriptor + .as_ref() + .ok_or_else(|| { + Status::failed_precondition("table to put is not found in flight descriptor") + })? + .clone(); - let request = PutRecordBatchRequest::try_new( - table_name.clone(), - flight_data, - limiter.as_ref(), - ); - let request = match request { - Ok(request) => request, - Err(e) => return Poll::Ready(Some(Err(e.into()))), - }; + let table_name_str = extract_table_name(flight_descriptor) + .map_err(|e| Status::invalid_argument(e.to_string()))?; + let table_name = TableName::new(catalog, schema, table_name_str); - self.state = PutRecordBatchRequestStreamState::Started(table_name); + // Decode the first message as schema + let mut decoder = FlightDecoder::default(); + let schema_message = decoder + .try_decode(&first_message) + .map_err(|e| Status::invalid_argument(format!("Failed to decode schema: {}", e)))?; - Ok(request) - } else { - Err(Status::failed_precondition( - "table to put is not found in flight descriptor", - )) - }; - Some(result) - } - Some(Err(e)) => Some(Err(e)), - None => None, - }, - PutRecordBatchRequestStreamState::Started(table_name) => poll.map(|x| { - x.and_then(|flight_data| { - PutRecordBatchRequest::try_new( - table_name.clone(), - flight_data, - limiter.as_ref(), - ) - .map_err(Into::into) - }) - }), + let (schema, schema_bytes) = match schema_message { + Some(FlightMessage::Schema(schema)) => { + let schema_bytes = decoder.schema_bytes().ok_or_else(|| { + Status::internal("decoder should have schema bytes after decoding schema") + })?; + (schema, schema_bytes) + } + _ => { + return Err(Status::failed_precondition( + "first message must be a Schema message", + )); + } }; - Poll::Ready(result) + + Ok(Self { + flight_data_stream, + table_name, + schema, + schema_bytes, + decoder, + limiter, + }) + } + + /// Returns the table name extracted from the flight descriptor. + pub fn table_name(&self) -> &TableName { + &self.table_name + } + + /// Returns the Arrow schema decoded from the first flight message. + pub fn schema(&self) -> &SchemaRef { + &self.schema + } + + /// Returns the raw schema bytes in IPC format. 
+ pub fn schema_bytes(&self) -> &Bytes { + &self.schema_bytes + } +} + +impl Stream for PutRecordBatchRequestStream { + type Item = TonicResult; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + loop { + let poll = ready!(self.flight_data_stream.poll_next_unpin(cx)); + + match poll { + Some(Ok(flight_data)) => { + // Extract request_id and body_size from FlightData before decoding + let request_id = if !flight_data.app_metadata.is_empty() { + match serde_json::from_slice::(&flight_data.app_metadata) { + Ok(metadata) => metadata.request_id(), + Err(_) => 0, + } + } else { + 0 + }; + + // Decode FlightData to RecordBatch + match self.decoder.try_decode(&flight_data) { + Ok(Some(FlightMessage::RecordBatch(record_batch))) => { + let limiter = self.limiter.clone(); + let table_name = self.table_name.clone(); + let schema_bytes = self.schema_bytes.clone(); + return Poll::Ready(Some( + PutRecordBatchRequest::try_new( + table_name, + record_batch, + request_id, + schema_bytes, + flight_data, + limiter.as_ref(), + ) + .map_err(|e| Status::invalid_argument(e.to_string())), + )); + } + Ok(Some(_)) => { + return Poll::Ready(Some(Err(Status::invalid_argument( + "Expected RecordBatch message, got other message type", + )))); + } + Ok(None) => { + // Dictionary batch - processed internally by decoder, continue polling + continue; + } + Err(e) => { + return Poll::Ready(Some(Err(Status::invalid_argument(format!( + "Failed to decode RecordBatch: {}", + e + ))))); + } + } + } + Some(Err(e)) => { + return Poll::Ready(Some(Err(e))); + } + None => { + return Poll::Ready(None); + } + } + } } } diff --git a/src/servers/src/grpc/greptime_handler.rs b/src/servers/src/grpc/greptime_handler.rs index 095c36abb1..c1f146db6d 100644 --- a/src/servers/src/grpc/greptime_handler.rs +++ b/src/servers/src/grpc/greptime_handler.rs @@ -24,7 +24,6 @@ use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; use common_catalog::parse_catalog_and_schema_from_db_string; use common_error::ext::ErrorExt; use common_error::status_code::StatusCode; -use common_grpc::flight::FlightDecoder; use common_grpc::flight::do_put::DoPutResponse; use common_query::Output; use common_runtime::Runtime; @@ -37,15 +36,14 @@ use futures_util::StreamExt; use session::context::{Channel, QueryContextBuilder, QueryContextRef}; use session::hints::READ_PREFERENCE_HINT; use snafu::{OptionExt, ResultExt}; -use table::TableRef; use tokio::sync::mpsc; use tokio::sync::mpsc::error::TrySendError; +use tonic::Status; use crate::error::{InvalidQuerySnafu, JoinTaskSnafu, Result, UnknownHintSnafu}; -use crate::grpc::flight::{PutRecordBatchRequest, PutRecordBatchRequestStream}; +use crate::grpc::flight::PutRecordBatchRequestStream; use crate::grpc::{FlightCompression, TonicResult, context_auth}; -use crate::metrics; -use crate::metrics::METRIC_SERVER_GRPC_DB_REQUEST_TIMER; +use crate::metrics::{self, METRIC_SERVER_GRPC_DB_REQUEST_TIMER}; use crate::query_handler::grpc::ServerGrpcQueryHandlerRef; #[derive(Clone)] @@ -134,7 +132,7 @@ impl GreptimeRequestHandler { pub(crate) async fn put_record_batches( &self, - mut stream: PutRecordBatchRequestStream, + stream: PutRecordBatchRequestStream, result_sender: mpsc::Sender>, query_ctx: QueryContextRef, ) { @@ -144,37 +142,24 @@ impl GreptimeRequestHandler { .clone() .unwrap_or_else(common_runtime::global_runtime); runtime.spawn(async move { - // Cached table ref - let mut table_ref: Option = None; + let mut result_stream = handler.handle_put_record_batch_stream(stream, query_ctx); - 
let mut decoder = FlightDecoder::default(); - while let Some(request) = stream.next().await { - let request = match request { - Ok(request) => request, - Err(e) => { - let _ = result_sender.try_send(Err(e)); - break; + while let Some(result) = result_stream.next().await { + match &result { + Ok(response) => { + // Record the elapsed time metric from the response + metrics::GRPC_BULK_INSERT_ELAPSED.observe(response.elapsed_secs()); } - }; - let PutRecordBatchRequest { - table_name, - request_id, - data, - _guard, - } = request; + Err(e) => { + error!(e; "Failed to handle flight record batches"); + } + } - let timer = metrics::GRPC_BULK_INSERT_ELAPSED.start_timer(); - let result = handler - .put_record_batch(&table_name, &mut table_ref, &mut decoder, data, query_ctx.clone()) - .await - .inspect_err(|e| error!(e; "Failed to handle flight record batches")); - timer.observe_duration(); - let result = result - .map(|x| DoPutResponse::new(request_id, x)) - .map_err(Into::into); - if let Err(e)= result_sender.try_send(result) - && let TrySendError::Closed(_) = e { - warn!(r#""DoPut" client with request_id {} maybe unreachable, abort handling its message"#, request_id); + if let Err(e) = + result_sender.try_send(result.map_err(|e| Status::from_error(Box::new(e)))) + && let TrySendError::Closed(_) = e + { + warn!(r#""DoPut" client maybe unreachable, abort handling its message"#); break; } } diff --git a/src/servers/src/query_handler/grpc.rs b/src/servers/src/query_handler/grpc.rs index 305fde4448..2403c82905 100644 --- a/src/servers/src/query_handler/grpc.rs +++ b/src/servers/src/query_handler/grpc.rs @@ -12,21 +12,22 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::pin::Pin; use std::sync::Arc; use api::v1::greptime_request::Request; -use arrow_flight::FlightData; use async_trait::async_trait; use common_base::AffectedRows; use common_error::ext::{BoxedError, ErrorExt}; -use common_grpc::flight::FlightDecoder; +use common_grpc::flight::do_put::DoPutResponse; use common_query::Output; +use futures::Stream; use session::context::QueryContextRef; use snafu::ResultExt; use table::TableRef; -use table::table_name::TableName; use crate::error::{self, Result}; +use crate::grpc::flight::{PutRecordBatchRequest, PutRecordBatchRequestStream}; pub type GrpcQueryHandlerRef = Arc + Send + Sync>; pub type ServerGrpcQueryHandlerRef = GrpcQueryHandlerRef; @@ -45,12 +46,16 @@ pub trait GrpcQueryHandler { async fn put_record_batch( &self, - table_name: &TableName, + request: PutRecordBatchRequest, table_ref: &mut Option, - decoder: &mut FlightDecoder, - flight_data: FlightData, ctx: QueryContextRef, ) -> std::result::Result; + + fn handle_put_record_batch_stream( + &self, + stream: PutRecordBatchRequestStream, + ctx: QueryContextRef, + ) -> Pin> + Send>>; } pub struct ServerGrpcQueryHandlerAdapter(GrpcQueryHandlerRef); @@ -78,16 +83,31 @@ where async fn put_record_batch( &self, - table_name: &TableName, + request: PutRecordBatchRequest, table_ref: &mut Option, - decoder: &mut FlightDecoder, - data: FlightData, ctx: QueryContextRef, ) -> Result { self.0 - .put_record_batch(table_name, table_ref, decoder, data, ctx) + .put_record_batch(request, table_ref, ctx) .await .map_err(BoxedError::new) .context(error::ExecuteGrpcRequestSnafu) } + + fn handle_put_record_batch_stream( + &self, + stream: PutRecordBatchRequestStream, + ctx: QueryContextRef, + ) -> Pin> + Send>> { + use futures_util::StreamExt; + Box::pin( + self.0 + 
.handle_put_record_batch_stream(stream, ctx) + .map(|result| { + result + .map_err(|e| BoxedError::new(e)) + .context(error::ExecuteGrpcRequestSnafu) + }), + ) + } } diff --git a/src/servers/tests/mod.rs b/src/servers/tests/mod.rs index 7d6268215c..3f85b6d3ad 100644 --- a/src/servers/tests/mod.rs +++ b/src/servers/tests/mod.rs @@ -16,12 +16,11 @@ use std::sync::Arc; use api::v1::greptime_request::Request; use api::v1::query_request::Query; -use arrow_flight::FlightData; use async_trait::async_trait; use catalog::memory::MemoryCatalogManager; use common_base::AffectedRows; use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; -use common_grpc::flight::FlightDecoder; +use common_grpc::flight::do_put::DoPutResponse; use common_query::Output; use datafusion_expr::LogicalPlan; use query::options::QueryOptions; @@ -35,7 +34,6 @@ use session::context::QueryContextRef; use snafu::ensure; use sql::statements::statement::Statement; use table::TableRef; -use table::table_name::TableName; mod http; mod interceptor; @@ -165,14 +163,22 @@ impl GrpcQueryHandler for DummyInstance { async fn put_record_batch( &self, - _table_name: &TableName, + _request: servers::grpc::flight::PutRecordBatchRequest, _table_ref: &mut Option, - _decoder: &mut FlightDecoder, - _data: FlightData, _ctx: QueryContextRef, ) -> std::result::Result { unimplemented!() } + + fn handle_put_record_batch_stream( + &self, + _stream: servers::grpc::flight::PutRecordBatchRequestStream, + _ctx: QueryContextRef, + ) -> std::pin::Pin< + Box> + Send>, + > { + unimplemented!() + } } fn create_testing_instance(table: TableRef) -> DummyInstance {
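
Not part of the patch: a minimal, self-contained sketch of the DoPut response metadata shape introduced above, assuming a mirror struct (`DoPutResponseShape` is a hypothetical name) with serde derives. The field names and JSON layout follow the updated test in `do_put.rs`.

```rust
use serde::{Deserialize, Serialize};

// Hypothetical mirror of the response metadata a DoPut client receives per batch.
#[derive(Debug, PartialEq, Serialize, Deserialize)]
struct DoPutResponseShape {
    request_id: i64,
    affected_rows: usize,
    elapsed_secs: f64,
}

fn main() -> serde_json::Result<()> {
    let response = DoPutResponseShape {
        request_id: 42,
        affected_rows: 88,
        elapsed_secs: 0.123,
    };

    // Matches the expectation in the patched test:
    // {"request_id":42,"affected_rows":88,"elapsed_secs":0.123}
    let serialized = serde_json::to_string(&response)?;
    assert_eq!(
        serialized,
        r#"{"request_id":42,"affected_rows":88,"elapsed_secs":0.123}"#
    );

    // A Flight client would typically obtain this JSON from the app_metadata
    // of each PutResult message returned on the DoPut result stream.
    let parsed: DoPutResponseShape = serde_json::from_str(&serialized)?;
    assert_eq!(parsed, response);
    Ok(())
}
```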
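Also not part of the patch: a rough sketch of the FlightData sequence the reworked handler expects, based on the descriptor and metadata handling in `flight.rs`. The first message names the target table via a Path descriptor and carries the IPC schema; each following message carries one record batch plus a per-batch request id in `app_metadata`. The `schema_message`/`batch_message` helpers are hypothetical, the IPC header/body bytes are left as opaque parameters, and the `request_id` field name inside the metadata JSON is an assumption mirroring `DoPutResponse`.

```rust
use arrow_flight::flight_descriptor::DescriptorType;
use arrow_flight::{FlightData, FlightDescriptor};
use bytes::Bytes;

// First message of the DoPut stream: table name in a Path descriptor plus the
// IPC-encoded schema header. The stream constructor decodes this message to
// obtain the Arrow schema and the raw schema bytes it later forwards to regions.
fn schema_message(table: &str, ipc_schema_header: Bytes) -> FlightData {
    FlightData {
        flight_descriptor: Some(FlightDescriptor {
            r#type: DescriptorType::Path as i32,
            cmd: Bytes::new(),
            path: vec![table.to_string()],
        }),
        data_header: ipc_schema_header,
        app_metadata: Bytes::new(),
        data_body: Bytes::new(),
    }
}

// Subsequent messages: one record batch each. The request_id is echoed back in
// the per-batch DoPutResponse, which is how callers correlate results/errors.
fn batch_message(request_id: i64, ipc_header: Bytes, ipc_body: Bytes) -> FlightData {
    // Field name assumed; the server parses this JSON as DoPutMetadata.
    let metadata = serde_json::json!({ "request_id": request_id });
    FlightData {
        flight_descriptor: None,
        data_header: ipc_header,
        app_metadata: Bytes::from(serde_json::to_vec(&metadata).unwrap()),
        data_body: ipc_body,
    }
}
```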