From d08743fc644fd78d2c9f5d6e971d5228730e37a4 Mon Sep 17 00:00:00 2001 From: discord9 Date: Wed, 29 Apr 2026 16:54:24 +0800 Subject: [PATCH] refactor: per review Signed-off-by: discord9 --- src/client/src/database.rs | 373 +++++++++++++++--------------- src/common/grpc/src/flight.rs | 59 ++++- src/datanode/src/region_server.rs | 143 +++++++++++- src/query/src/datafusion.rs | 34 ++- src/query/src/metrics.rs | 76 +++--- src/query/src/options.rs | 41 +++- src/servers/src/grpc/flight.rs | 17 +- 7 files changed, 497 insertions(+), 246 deletions(-) diff --git a/src/client/src/database.rs b/src/client/src/database.rs index 5114fa2a7e..a9ab7fd888 100644 --- a/src/client/src/database.rs +++ b/src/client/src/database.rs @@ -14,8 +14,8 @@ use std::pin::Pin; use std::str::FromStr; -use std::sync::Arc; use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::{Arc, RwLock}; use std::task::{Context, Poll}; use api::v1::auth_header::AuthScheme; @@ -63,7 +63,7 @@ type DoPutResponseStream = Pin>>>; #[derive(Debug, Clone, Default)] pub struct OutputMetrics { - metrics: Arc>, + metrics: Arc>>, ready: Arc, } @@ -73,7 +73,7 @@ impl OutputMetrics { } pub fn update(&self, metrics: Option) { - self.metrics.swap(metrics.map(Arc::new)); + *self.metrics.write().expect("metrics lock poisoned") = metrics; } pub fn mark_ready(&self) { @@ -85,7 +85,7 @@ impl OutputMetrics { } pub fn get(&self) -> Option { - self.metrics.load().as_ref().map(|m| m.as_ref().clone()) + self.metrics.read().expect("metrics lock poisoned").clone() } /// Returns proved per-region watermarks. @@ -227,6 +227,123 @@ fn attach_terminal_metrics(output: Output, terminal_metrics: &OutputMetrics) -> Output::new(data, meta) } +async fn output_from_flight_message_stream( + mut flight_message_stream: S, +) -> Result +where + S: Stream> + Send + Unpin + 'static, +{ + let Some(first_flight_message) = flight_message_stream.next().await else { + return IllegalFlightMessagesSnafu { + reason: "Expect the response not to be empty", + } + .fail(); + }; + + let first_flight_message = first_flight_message?; + + match first_flight_message { + FlightMessage::AffectedRows { rows, metrics } => { + let terminal_metrics = OutputMetrics::new(); + if let Some(metrics) = metrics { + terminal_metrics.update(Some(parse_terminal_metrics(&metrics)?)); + } + let next_message = flight_message_stream.next().await.transpose()?; + match next_message { + None => terminal_metrics.mark_ready(), + Some(FlightMessage::Metrics(s)) if terminal_metrics.get().is_none() => { + terminal_metrics.update(Some(parse_terminal_metrics(&s)?)); + terminal_metrics.mark_ready(); + ensure!( + flight_message_stream.next().await.is_none(), + IllegalFlightMessagesSnafu { + reason: "Expect 'AffectedRows' Flight messages to be followed by at most one Metrics message" + } + ); + } + Some(FlightMessage::Metrics(_)) => { + return IllegalFlightMessagesSnafu { + reason: "'AffectedRows' Flight metadata already carries Metrics and cannot be followed by another Metrics message".to_string(), + } + .fail(); + } + Some(other) => { + return IllegalFlightMessagesSnafu { + reason: format!( + "'AffectedRows' Flight message can only be followed by a Metrics message, got {other:?}" + ), + } + .fail(); + } + } + Ok(OutputWithMetrics { + output: Output::new_with_affected_rows(rows), + metrics: terminal_metrics, + }) + } + FlightMessage::RecordBatch(_) | FlightMessage::Metrics(_) => IllegalFlightMessagesSnafu { + reason: "The first flight message cannot be a RecordBatch or Metrics message", + } + .fail(), + FlightMessage::Schema(schema) => { + let metrics = Arc::new(ArcSwapOption::from(None)); + let metrics_ref = metrics.clone(); + let schema = Arc::new( + datatypes::schema::Schema::try_from(schema).context(error::ConvertSchemaSnafu)?, + ); + let schema_cloned = schema.clone(); + let stream = Box::pin(stream!({ + while let Some(flight_message_item) = flight_message_stream.next().await { + let flight_message = match flight_message_item { + Ok(message) => message, + Err(e) => { + yield Err(BoxedError::new(e)).context(ExternalSnafu); + break; + } + }; + + match flight_message { + FlightMessage::RecordBatch(arrow_batch) => { + yield Ok(RecordBatch::from_df_record_batch( + schema_cloned.clone(), + arrow_batch, + )) + } + FlightMessage::Metrics(s) => { + match parse_terminal_metrics(&s) { + Ok(m) => { + metrics_ref.swap(Some(Arc::new(m))); + } + Err(e) => { + yield Err(BoxedError::new(e)).context(ExternalSnafu); + } + }; + break; + } + FlightMessage::AffectedRows { .. } | FlightMessage::Schema(_) => { + yield IllegalFlightMessagesSnafu {reason: format!("A Schema message must be succeeded exclusively by a set of RecordBatch messages, flight_message: {:?}", flight_message)} + .fail() + .map_err(BoxedError::new) + .context(ExternalSnafu); + break; + } + } + } + })); + let record_batch_stream = RecordBatchStreamWrapper { + schema, + stream, + output_ordering: None, + metrics, + span: Span::current(), + }; + Ok(OutputWithMetrics::from_output(Output::new_with_stream( + Box::pin(record_batch_stream), + ))) + } + } +} + #[derive(Clone, Debug, Default)] pub struct Database { // The "catalog" and "schema" to be used in processing the requests at the server side. @@ -597,7 +714,7 @@ impl Database { let flight_data_stream = response.into_inner(); let mut decoder = FlightDecoder::default(); - let mut flight_message_stream = flight_data_stream.map(move |flight_data| { + let flight_message_stream = flight_data_stream.map(move |flight_data| { flight_data .map_err(Error::from) .and_then(|data| decoder.try_decode(&data).context(ConvertFlightDataSnafu))? @@ -606,156 +723,7 @@ impl Database { }) }); - let Some(first_flight_message) = flight_message_stream.next().await else { - return IllegalFlightMessagesSnafu { - reason: "Expect the response not to be empty", - } - .fail(); - }; - - let first_flight_message = first_flight_message?; - - match first_flight_message { - FlightMessage::AffectedRows(rows) => { - let terminal_metrics = OutputMetrics::new(); - let next_message = flight_message_stream.next().await.transpose()?; - match next_message { - None => terminal_metrics.mark_ready(), - Some(FlightMessage::Metrics(s)) => { - terminal_metrics.update(Some(parse_terminal_metrics(&s)?)); - terminal_metrics.mark_ready(); - ensure!( - flight_message_stream.next().await.is_none(), - IllegalFlightMessagesSnafu { - reason: "Expect 'AffectedRows' Flight messages to be followed by at most one Metrics message" - } - ); - } - Some(other) => { - return IllegalFlightMessagesSnafu { - reason: format!( - "'AffectedRows' Flight message can only be followed by a Metrics message, got {other:?}" - ), - } - .fail(); - } - } - Ok(OutputWithMetrics { - output: Output::new_with_affected_rows(rows), - metrics: terminal_metrics, - }) - } - FlightMessage::RecordBatch(_) | FlightMessage::Metrics(_) => { - IllegalFlightMessagesSnafu { - reason: "The first flight message cannot be a RecordBatch or Metrics message", - } - .fail() - } - FlightMessage::Schema(schema) => { - let metrics = Arc::new(ArcSwapOption::from(None)); - let metrics_ref = metrics.clone(); - let schema = Arc::new( - datatypes::schema::Schema::try_from(schema) - .context(error::ConvertSchemaSnafu)?, - ); - let schema_cloned = schema.clone(); - let stream = Box::pin(stream!({ - let mut buffered_message: Option = None; - let mut stream_ended = false; - - while !stream_ended { - let flight_message_item = if let Some(msg) = buffered_message.take() { - Some(Ok(msg)) - } else { - flight_message_stream.next().await - }; - - let flight_message = match flight_message_item { - Some(Ok(message)) => message, - Some(Err(e)) => { - yield Err(BoxedError::new(e)).context(ExternalSnafu); - break; - } - None => break, - }; - - match flight_message { - FlightMessage::RecordBatch(arrow_batch) => { - let result_to_yield = RecordBatch::from_df_record_batch( - schema_cloned.clone(), - arrow_batch, - ); - - if let Some(next_flight_message_result) = - flight_message_stream.next().await - { - match next_flight_message_result { - Ok(FlightMessage::Metrics(s)) => { - match parse_terminal_metrics(&s) { - Ok(m) => { - metrics_ref.swap(Some(Arc::new(m))); - } - Err(e) => { - yield Err(BoxedError::new(e)) - .context(ExternalSnafu); - break; - } - }; - } - Ok(FlightMessage::RecordBatch(rb)) => { - buffered_message = Some(FlightMessage::RecordBatch(rb)); - } - Ok(_) => { - yield IllegalFlightMessagesSnafu {reason: "A RecordBatch message can only be succeeded by a Metrics message or another RecordBatch message"} - .fail() - .map_err(BoxedError::new) - .context(ExternalSnafu); - break; - } - Err(e) => { - yield Err(BoxedError::new(e)).context(ExternalSnafu); - break; - } - } - } else { - stream_ended = true; - } - - yield Ok(result_to_yield) - } - FlightMessage::Metrics(s) => { - match parse_terminal_metrics(&s) { - Ok(m) => { - metrics_ref.swap(Some(Arc::new(m))); - } - Err(e) => { - yield Err(BoxedError::new(e)).context(ExternalSnafu); - } - }; - break; - } - FlightMessage::AffectedRows(_) | FlightMessage::Schema(_) => { - yield IllegalFlightMessagesSnafu {reason: format!("A Schema message must be succeeded exclusively by a set of RecordBatch messages, flight_message: {:?}", flight_message)} - .fail() - .map_err(BoxedError::new) - .context(ExternalSnafu); - break; - } - } - } - })); - let record_batch_stream = RecordBatchStreamWrapper { - schema, - stream, - output_ordering: None, - metrics, - span: Span::current(), - }; - Ok(OutputWithMetrics::from_output(Output::new_with_stream( - Box::pin(record_batch_stream), - ))) - } - } + output_from_flight_message_stream(flight_message_stream).await } /// Ingest a stream of [RecordBatch]es that belong to a table, using Arrow Flight's "`DoPut`" @@ -859,6 +827,17 @@ mod tests { } } + fn terminal_metrics_json() -> String { + serde_json::to_string(&RecordBatchMetrics { + region_watermarks: vec![common_recordbatch::adapter::RegionWatermarkEntry { + region_id: 7, + watermark: Some(42), + }], + ..Default::default() + }) + .unwrap() + } + #[test] fn test_flight_ctx() { let mut ctx = FlightContext::default(); @@ -948,7 +927,47 @@ mod tests { } #[tokio::test] - async fn test_invalid_terminal_metrics_after_record_batch_fails_before_yielding_batch() { + async fn test_affected_rows_inline_metrics_are_parsed() { + let output = output_from_flight_message_stream(futures_util::stream::iter(vec![Ok( + FlightMessage::AffectedRows { + rows: 3, + metrics: Some(terminal_metrics_json()), + }, + )] + as Vec>)) + .await + .unwrap(); + + assert!(matches!(output.output.data, OutputData::AffectedRows(3))); + assert!(output.metrics.is_ready()); + assert_eq!( + output.metrics.region_watermark_map(), + Some(std::collections::HashMap::from([(7, 42)])) + ); + } + + #[tokio::test] + async fn test_affected_rows_inline_metrics_rejects_trailing_metrics() { + let metrics_json = terminal_metrics_json(); + let err = output_from_flight_message_stream(futures_util::stream::iter(vec![ + Ok(FlightMessage::AffectedRows { + rows: 3, + metrics: Some(metrics_json.clone()), + }), + Ok(FlightMessage::Metrics(metrics_json)), + ] + as Vec>)) + .await + .unwrap_err(); + + assert!( + err.to_string().contains("already carries Metrics"), + "unexpected error: {err:?}" + ); + } + + #[tokio::test] + async fn test_invalid_terminal_metrics_after_record_batch_yields_batch_then_error() { let schema = Arc::new(Schema::new(vec![ColumnSchema::new( "v", ConcreteDataType::int32_datatype(), @@ -959,38 +978,21 @@ mod tests { vec![Arc::new(Int32Vector::from_slice([1])) as VectorRef], ) .unwrap(); - let metrics = Arc::new(ArcSwapOption::from(None)); - let metrics_ref = metrics.clone(); - let schema_cloned = schema.clone(); - let mut flight_message_stream = futures_util::stream::iter(vec![ + let output = output_from_flight_message_stream(futures_util::stream::iter(vec![ + Ok(FlightMessage::Schema(schema.arrow_schema().clone())), Ok(FlightMessage::RecordBatch(batch.into_df_record_batch())), Ok(FlightMessage::Metrics("{not-json}".to_string())), ] - as Vec>); + as Vec>)) + .await + .unwrap(); + let terminal_metrics = output.metrics.clone(); + let OutputData::Stream(mut record_batch_stream) = output.output.data else { + panic!("expected stream output"); + }; - let mut record_batch_stream = Box::pin(stream!({ - let Some(Ok(FlightMessage::RecordBatch(arrow_batch))) = - flight_message_stream.next().await - else { - return; - }; - let result_to_yield = - RecordBatch::from_df_record_batch(schema_cloned.clone(), arrow_batch); - - if let Some(Ok(FlightMessage::Metrics(s))) = flight_message_stream.next().await { - match parse_terminal_metrics(&s) { - Ok(m) => { - metrics_ref.swap(Some(Arc::new(m))); - } - Err(e) => { - yield Err(BoxedError::new(e)).context(ExternalSnafu); - return; - } - } - } - - yield Ok(result_to_yield); - })); + let batch = record_batch_stream.next().await.unwrap().unwrap(); + assert_eq!(batch.num_rows(), 1); let err = record_batch_stream.next().await.unwrap().unwrap_err(); assert_eq!("External error", err.to_string()); @@ -999,7 +1001,8 @@ mod tests { "unexpected error: {err:?}" ); assert!(record_batch_stream.next().await.is_none()); - assert!(metrics.load().is_none()); + assert!(terminal_metrics.is_ready()); + assert!(terminal_metrics.get().is_none()); } #[test] diff --git a/src/common/grpc/src/flight.rs b/src/common/grpc/src/flight.rs index 5fc115a60e..f09400d9b4 100644 --- a/src/common/grpc/src/flight.rs +++ b/src/common/grpc/src/flight.rs @@ -41,7 +41,10 @@ use crate::error::{DecodeFlightDataSnafu, InvalidFlightDataSnafu, Result}; pub enum FlightMessage { Schema(SchemaRef), RecordBatch(DfRecordBatch), - AffectedRows(usize), + AffectedRows { + rows: usize, + metrics: Option, + }, Metrics(String), } @@ -116,10 +119,12 @@ impl FlightEncoder { encoded_batch.into(), ) } - FlightMessage::AffectedRows(rows) => { + FlightMessage::AffectedRows { rows, metrics } => { let metadata = FlightMetadata { affected_rows: Some(AffectedRows { value: rows as _ }), - metrics: None, + metrics: metrics.map(|s| Metrics { + metrics: s.into_bytes(), + }), } .encode_to_vec(); vec1![FlightData { @@ -223,7 +228,12 @@ impl FlightDecoder { let metadata = FlightMetadata::decode(flight_data.app_metadata.clone()) .context(DecodeFlightDataSnafu)?; if let Some(AffectedRows { value }) = metadata.affected_rows { - return Ok(Some(FlightMessage::AffectedRows(value as _))); + return Ok(Some(FlightMessage::AffectedRows { + rows: value as _, + metrics: metadata + .metrics + .map(|m| String::from_utf8_lossy(&m.metrics).to_string()), + })); } if let Some(Metrics { metrics }) = metadata.metrics { return Ok(Some(FlightMessage::Metrics( @@ -426,6 +436,47 @@ mod test { Ok(()) } + #[test] + fn test_affected_rows_metrics_encode_decode() -> Result<()> { + let metrics = r#"{"region_watermarks":[{"region_id":42,"watermark":7}]}"#; + let mut encoder = FlightEncoder::default(); + let encoded = encoder.encode(FlightMessage::AffectedRows { + rows: 3, + metrics: Some(metrics.to_string()), + }); + + assert_eq!(encoded.len(), 1); + + let mut decoder = FlightDecoder::default(); + let decoded = decoder.try_decode(encoded.first())?.unwrap(); + let FlightMessage::AffectedRows { + rows, + metrics: decoded_metrics, + } = decoded + else { + unreachable!() + }; + assert_eq!(rows, 3); + assert_eq!(decoded_metrics.as_deref(), Some(metrics)); + + let encoded = encoder.encode(FlightMessage::AffectedRows { + rows: 5, + metrics: None, + }); + let decoded = decoder.try_decode(encoded.first())?.unwrap(); + let FlightMessage::AffectedRows { + rows, + metrics: decoded_metrics, + } = decoded + else { + unreachable!() + }; + assert_eq!(rows, 5); + assert!(decoded_metrics.is_none()); + + Ok(()) + } + #[test] fn test_flight_messages_to_recordbatches() { let schema = Arc::new(Schema::new(vec![Field::new("m", DataType::Int32, true)])); diff --git a/src/datanode/src/region_server.rs b/src/datanode/src/region_server.rs index aa3ffbfe3a..29dc0f4e03 100644 --- a/src/datanode/src/region_server.rs +++ b/src/datanode/src/region_server.rs @@ -17,8 +17,10 @@ mod catalog; use std::collections::HashMap; use std::fmt::Debug; use std::ops::Deref; +use std::pin::Pin; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::{Arc, RwLock}; +use std::task::{Context, Poll}; use std::time::Duration; use api::region::RegionResponse; @@ -36,7 +38,8 @@ use common_error::status_code::StatusCode; use common_meta::datanode::TopicStatsReporter; use common_query::OutputData; use common_query::request::QueryRequest; -use common_recordbatch::SendableRecordBatchStream; +use common_recordbatch::adapter::RecordBatchMetrics; +use common_recordbatch::{OrderOption, RecordBatch, RecordBatchStream, SendableRecordBatchStream}; use common_runtime::Runtime; use common_telemetry::tracing::{self, info_span}; use common_telemetry::tracing_context::{FutureExt, TracingContext}; @@ -45,6 +48,7 @@ use dashmap::DashMap; use datafusion::datasource::TableProvider; use datafusion_common::tree_node::TreeNode; use either::Either; +use futures_util::Stream; use futures_util::future::try_join_all; use metric_engine::engine::MetricEngine; use mito2::engine::{MITO_ENGINE_NAME, MitoEngine}; @@ -53,6 +57,7 @@ use query::QueryEngineRef; pub use query::dummy_catalog::{ DummyCatalogList, DummyTableProviderFactory, TableProviderFactoryRef, }; +use query::options::should_collect_region_watermark_from_extensions; use serde_json; use servers::error::{ self as servers_error, ExecuteGrpcRequestSnafu, Result as ServerResult, SuspendedSnafu, @@ -278,16 +283,31 @@ impl RegionServer { .await .context(DecodeLogicalPlanSnafu)?; - self.inner + let stream = self + .inner .handle_read( QueryRequest { header: request.header, region_id, plan, }, - query_ctx, + query_ctx.clone(), ) - .await + .await?; + + let region_latest_seq = + if should_collect_region_watermark_from_extensions(&query_ctx.extensions()) { + query_ctx.get_snapshot(region_id.as_u64()) + } else { + None + }; + + if let Some(seq) = region_latest_seq { + Ok(Box::pin(RegionWatermarkStream::new(stream, region_id, seq)) + as SendableRecordBatchStream) + } else { + Ok(stream) + } } #[tracing::instrument(skip_all)] @@ -749,6 +769,83 @@ impl RegionServer { } } +/// Wraps a region read stream so terminal metrics can carry the scan-open watermark. +struct RegionWatermarkStream { + stream: SendableRecordBatchStream, + region_id: u64, + snapshot_seq: u64, + finished: AtomicBool, +} + +impl RegionWatermarkStream { + fn new(stream: SendableRecordBatchStream, region_id: RegionId, snapshot_seq: u64) -> Self { + Self { + stream, + region_id: region_id.as_u64(), + snapshot_seq, + finished: AtomicBool::new(false), + } + } + + fn merged_metrics(&self, mut metrics: RecordBatchMetrics) -> RecordBatchMetrics { + let entry = if let Some(entry) = metrics + .region_watermarks + .iter_mut() + .find(|entry| entry.region_id == self.region_id) + { + entry + } else { + metrics + .region_watermarks + .push(common_recordbatch::adapter::RegionWatermarkEntry { + region_id: self.region_id, + watermark: None, + }); + metrics.region_watermarks.last_mut().unwrap() + }; + + entry.watermark = Some(self.snapshot_seq); + metrics + } +} + +impl RecordBatchStream for RegionWatermarkStream { + fn name(&self) -> &str { + self.stream.name() + } + + fn schema(&self) -> datatypes::schema::SchemaRef { + self.stream.schema() + } + + fn output_ordering(&self) -> Option<&[OrderOption]> { + self.stream.output_ordering() + } + + fn metrics(&self) -> Option { + let base = self.stream.metrics(); + if !self.finished.load(Ordering::Relaxed) { + return base; + } + + Some(self.merged_metrics(base.unwrap_or_default())) + } +} + +impl Stream for RegionWatermarkStream { + type Item = common_recordbatch::error::Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match Pin::new(&mut self.stream).poll_next(cx) { + Poll::Ready(None) => { + self.finished.store(true, Ordering::Relaxed); + Poll::Ready(None) + } + other => other, + } + } +} + #[async_trait] impl RegionServerHandler for RegionServer { async fn handle(&self, request: region_request::Body) -> ServerResult { @@ -1669,10 +1766,16 @@ impl RegionAttribute { mod tests { use std::assert_matches; + use std::sync::Arc; use api::v1::SemanticType; use common_error::ext::ErrorExt; - use datatypes::prelude::ConcreteDataType; + use common_recordbatch::RecordBatches; + use common_recordbatch::adapter::RegionWatermarkEntry; + use datatypes::prelude::{ConcreteDataType, VectorRef}; + use datatypes::schema::{ColumnSchema, Schema}; + use datatypes::vectors::Int32Vector; + use futures_util::StreamExt; use mito2::test_util::CreateRequestBuilder; use store_api::metadata::{ColumnMetadata, RegionMetadata, RegionMetadataBuilder}; use store_api::region_engine::RegionEngine; @@ -1685,6 +1788,36 @@ mod tests { use crate::error::Result; use crate::tests::{MockRegionEngine, mock_region_server}; + #[tokio::test] + async fn test_region_watermark_stream_only_sets_terminal_metrics() { + let schema = Arc::new(Schema::new(vec![ColumnSchema::new( + "v", + ConcreteDataType::int32_datatype(), + false, + )])); + let values: VectorRef = Arc::new(Int32Vector::from_slice([1, 2])); + let batch = RecordBatch::new(schema.clone(), vec![values]).unwrap(); + let stream = RecordBatches::try_new(schema, vec![batch]) + .unwrap() + .as_stream(); + + let region_id = RegionId::new(42, 7); + let wrapped = RegionWatermarkStream::new(stream, region_id, 99); + let mut pinned = Box::pin(wrapped); + + assert!(pinned.as_ref().get_ref().metrics().is_none()); + while pinned.next().await.is_some() {} + + let metrics = pinned.as_ref().get_ref().metrics().unwrap(); + assert_eq!( + metrics.region_watermarks, + vec![RegionWatermarkEntry { + region_id: region_id.as_u64(), + watermark: Some(99), + }] + ); + } + #[tokio::test] async fn test_region_registering() { common_telemetry::init_default_ut_logging(); diff --git a/src/query/src/datafusion.rs b/src/query/src/datafusion.rs index 6fc78c59e5..8f648ccfcb 100644 --- a/src/query/src/datafusion.rs +++ b/src/query/src/datafusion.rs @@ -59,7 +59,8 @@ use crate::error::{ TableNotFoundSnafu, TableReadOnlySnafu, UnsupportedExprSnafu, }; use crate::executor::QueryExecutor; -use crate::metrics::{OnDone, QUERY_STAGE_ELAPSED}; +use crate::metrics::{OnDone, QUERY_STAGE_ELAPSED, RegionWatermarkMetricsStream}; +use crate::options::FlowQueryExtensions; use crate::physical_wrapper::PhysicalPlanWrapperRef; use crate::planner::{DfLogicalPlanner, LogicalPlanner}; use crate::query_engine::{DescribeResult, QueryEngineContext, QueryEngineState}; @@ -100,8 +101,10 @@ impl DatafusionQueryEngine { optimized_physical_plan }; + let stream = self.execute_stream(&ctx, &physical_plan)?; + Ok(Output::new( - OutputData::Stream(self.execute_stream(&ctx, &physical_plan)?), + OutputData::Stream(stream), OutputMeta::new_with_plan(physical_plan), )) } @@ -128,10 +131,10 @@ impl DatafusionQueryEngine { let table_name = dml.table_name.resolve(default_catalog, default_schema); let table = self.find_table(&table_name, &query_ctx).await?; - let output = self + let Output { data, meta } = self .exec_query_plan((*dml.input).clone(), query_ctx.clone()) .await?; - let mut stream = match output.data { + let mut stream = match data { OutputData::RecordBatches(batches) => batches.as_stream(), OutputData::Stream(stream) => stream, _ => unreachable!(), @@ -167,7 +170,7 @@ impl DatafusionQueryEngine { } Ok(Output::new( OutputData::AffectedRows(affected_rows), - OutputMeta::new_with_cost(insert_cost), + OutputMeta::new(meta.plan, insert_cost), )) } @@ -544,6 +547,9 @@ impl QueryExecutor for DatafusionQueryEngine { plan: &Arc, ) -> Result { let explain_verbose = ctx.query_ctx().explain_verbose(); + let should_collect_region_watermark = + FlowQueryExtensions::parse_flow_extensions(&ctx.query_ctx().extensions())? + .is_some_and(|extensions| extensions.should_collect_region_watermark()); let output_partitions = plan.properties().output_partitioning().partition_count(); if explain_verbose { common_telemetry::info!("Executing query plan, output_partitions: {output_partitions}"); @@ -579,7 +585,14 @@ impl QueryExecutor for DatafusionQueryEngine { ); } }); - Ok(Box::pin(stream)) + if should_collect_region_watermark { + Ok(Box::pin(RegionWatermarkMetricsStream::new( + Box::pin(stream), + plan.clone(), + ))) + } else { + Ok(Box::pin(stream)) + } } _ => { // merge into a single partition @@ -608,7 +621,14 @@ impl QueryExecutor for DatafusionQueryEngine { ); } }); - Ok(Box::pin(stream)) + if should_collect_region_watermark { + Ok(Box::pin(RegionWatermarkMetricsStream::new( + Box::pin(stream), + plan.clone(), + ))) + } else { + Ok(Box::pin(stream)) + } } } } diff --git a/src/query/src/metrics.rs b/src/query/src/metrics.rs index 9a376d748c..02bfd43b5d 100644 --- a/src/query/src/metrics.rs +++ b/src/query/src/metrics.rs @@ -230,13 +230,52 @@ fn collect_region_watermarks(plan: Arc) -> Vec, + entries: impl IntoIterator, +) { + for entry in entries { + merged + .entry(entry.region_id) + .and_modify(|existing| match entry.watermark { + None => match existing { + MergeState::Participated | MergeState::Proved(_) => { + *existing = MergeState::Unproved; + } + MergeState::Unproved | MergeState::Conflict { .. } => {} + }, + Some(seq) => match existing { + MergeState::Participated => { + *existing = MergeState::Proved(seq); + } + MergeState::Unproved => {} + MergeState::Proved(existing_seq) if *existing_seq == seq => {} + MergeState::Proved(existing_seq) => { + let old_seq = *existing_seq; + *existing = MergeState::Conflict { + watermarks: vec![old_seq, seq], + }; + } + MergeState::Conflict { watermarks } => { + if !watermarks.contains(&seq) { + watermarks.push(seq); + } + } + }, + }) + .or_insert(match entry.watermark { + Some(seq) => MergeState::Proved(seq), + None => MergeState::Unproved, + }); + } +} + fn merge_merge_scan_region_watermarks( merged: &mut BTreeMap, regions: impl IntoIterator, @@ -247,40 +286,7 @@ fn merge_merge_scan_region_watermarks( } for metrics in sub_stage_metrics { - for entry in metrics.region_watermarks { - merged - .entry(entry.region_id) - .and_modify(|existing| match entry.watermark { - None => match existing { - MergeState::Participated | MergeState::Proved(_) => { - *existing = MergeState::Unproved; - } - MergeState::Unproved | MergeState::Conflict { .. } => {} - }, - Some(seq) => match existing { - MergeState::Participated => { - *existing = MergeState::Proved(seq); - } - MergeState::Unproved => {} - MergeState::Proved(existing_seq) if *existing_seq == seq => {} - MergeState::Proved(existing_seq) => { - let old_seq = *existing_seq; - *existing = MergeState::Conflict { - watermarks: vec![old_seq, seq], - }; - } - MergeState::Conflict { watermarks } => { - if !watermarks.contains(&seq) { - watermarks.push(seq); - } - } - }, - }) - .or_insert(match entry.watermark { - Some(seq) => MergeState::Proved(seq), - None => MergeState::Unproved, - }); - } + merge_region_watermark_entries(merged, metrics.region_watermarks); } } diff --git a/src/query/src/options.rs b/src/query/src/options.rs index 46b8f1e413..6cb74f0305 100644 --- a/src/query/src/options.rs +++ b/src/query/src/options.rs @@ -177,10 +177,31 @@ impl FlowQueryExtensions { } pub fn should_collect_region_watermark(&self) -> bool { - self.return_region_seq || self.incremental_after_seqs.is_some() + should_collect_region_watermark( + self.return_region_seq, + self.incremental_after_seqs.is_some(), + ) } } +pub fn should_collect_region_watermark_from_extensions( + extensions: &HashMap, +) -> bool { + let return_region_seq = extensions + .get(FLOW_RETURN_REGION_SEQ) + .is_some_and(|value| value.eq_ignore_ascii_case("true")); + let has_incremental_after_seqs = extensions.contains_key(FLOW_INCREMENTAL_AFTER_SEQS); + + should_collect_region_watermark(return_region_seq, has_incremental_after_seqs) +} + +fn should_collect_region_watermark( + return_region_seq: bool, + has_incremental_after_seqs: bool, +) -> bool { + return_region_seq || has_incremental_after_seqs +} + fn parse_incremental_after_seqs(value: &str) -> Result> { let raw = serde_json::from_str::>(value).map_err(|e| { invalid_query_context_extension(format!( @@ -420,6 +441,24 @@ mod flow_extension_tests { assert!(parsed.should_collect_region_watermark()); } + #[test] + fn test_should_collect_region_watermark_from_extensions() { + let exts = HashMap::from([(FLOW_RETURN_REGION_SEQ.to_string(), "true".to_string())]); + assert!(should_collect_region_watermark_from_extensions(&exts)); + + let exts = HashMap::from([( + FLOW_INCREMENTAL_AFTER_SEQS.to_string(), + r#"{"1":10}"#.to_string(), + )]); + assert!(should_collect_region_watermark_from_extensions(&exts)); + + let exts = HashMap::from([(FLOW_RETURN_REGION_SEQ.to_string(), "false".to_string())]); + assert!(!should_collect_region_watermark_from_extensions(&exts)); + assert!(!should_collect_region_watermark_from_extensions( + &HashMap::new() + )); + } + #[test] fn test_parse_flow_extensions_return_region_seq_only_returns_some() { let exts = HashMap::from([(FLOW_RETURN_REGION_SEQ.to_string(), "true".to_string())]); diff --git a/src/servers/src/grpc/flight.rs b/src/servers/src/grpc/flight.rs index 4f262c53aa..ddd0b694a8 100644 --- a/src/servers/src/grpc/flight.rs +++ b/src/servers/src/grpc/flight.rs @@ -26,6 +26,7 @@ use arrow_flight::{ }; use async_trait::async_trait; use bytes::{self, Bytes}; +use common_error::ext::ErrorExt; use common_grpc::flight::do_put::{DoPutMetadata, DoPutResponse}; use common_grpc::flight::{FlightDecoder, FlightEncoder, FlightMessage}; use common_memory_manager::MemoryGuard; @@ -201,7 +202,7 @@ impl FlightCraft for GreptimeRequestHandler { // This does not authorize or execute anything; `handle_request()` below still performs // the normal frontend handling and auth checks before query execution. FlowQueryExtensions::parse_flow_extensions(&query_ctx.extensions()) - .map_err(|e| Status::invalid_argument(e.to_string()))?; + .map_err(|e| Status::invalid_argument(e.output_msg()))?; // The Grpc protocol pass query by Flight. It needs to be wrapped under a span, in order to record stream let span = info_span!( @@ -549,7 +550,6 @@ fn to_flight_data_stream( Box::pin(stream) as _ } OutputData::AffectedRows(rows) => { - let affected_rows = FlightEncoder::default().encode(FlightMessage::AffectedRows(rows)); let should_emit_terminal_metrics = FlowQueryExtensions::parse_flow_extensions(&query_ctx.extensions()) .expect("flow extensions must be validated before Flight serialization") @@ -558,13 +558,12 @@ fn to_flight_data_stream( .then_some(output.meta.plan) .flatten() .and_then(terminal_recordbatch_metrics_from_plan) - .and_then(|metrics| serde_json::to_string(&metrics).ok()) - .map(FlightMessage::Metrics) - .map(|message| FlightEncoder::default().encode(message)) - .into_iter() - .flatten(); - let stream = - tokio_stream::iter(affected_rows.into_iter().chain(terminal_metrics).map(Ok)); + .and_then(|metrics| serde_json::to_string(&metrics).ok()); + let affected_rows = FlightEncoder::default().encode(FlightMessage::AffectedRows { + rows, + metrics: terminal_metrics, + }); + let stream = tokio_stream::iter(affected_rows.into_iter().map(Ok)); Box::pin(stream) as _ } }