From bf24dc064d4f4869102837d862fe0684dc489472 Mon Sep 17 00:00:00 2001 From: discord9 Date: Tue, 28 Apr 2026 17:33:34 +0800 Subject: [PATCH] feat: OutputMetrics for inc query Signed-off-by: discord9 --- src/client/src/database.rs | 512 +++++++++++++++++- src/client/src/lib.rs | 2 +- src/flow/src/batching_mode/frontend_client.rs | 301 +++++++++- src/frontend/src/error.rs | 2 +- src/servers/src/grpc/flight.rs | 39 +- tests-integration/src/grpc/flight.rs | 106 +++- 6 files changed, 919 insertions(+), 43 deletions(-) diff --git a/src/client/src/database.rs b/src/client/src/database.rs index e12c2ec0fc..5114fa2a7e 100644 --- a/src/client/src/database.rs +++ b/src/client/src/database.rs @@ -15,6 +15,8 @@ use std::pin::Pin; use std::str::FromStr; use std::sync::Arc; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::task::{Context, Poll}; use api::v1::auth_header::AuthScheme; use api::v1::ddl_request::Expr as DdlExpr; @@ -25,6 +27,7 @@ use api::v1::{ AlterTableExpr, AuthHeader, Basic, CreateTableExpr, DdlRequest, GreptimeRequest, InsertRequests, QueryRequest, RequestHeader, RowInsertRequests, }; +use arc_swap::ArcSwapOption; use arrow_flight::{FlightData, Ticket}; use async_stream::stream; use base64::Engine; @@ -35,8 +38,9 @@ use common_error::ext::BoxedError; use common_grpc::flight::do_put::DoPutResponse; use common_grpc::flight::{FlightDecoder, FlightMessage}; use common_query::Output; +use common_recordbatch::adapter::RecordBatchMetrics; use common_recordbatch::error::ExternalSnafu; -use common_recordbatch::{RecordBatch, RecordBatchStreamWrapper}; +use common_recordbatch::{OrderOption, RecordBatch, RecordBatchStream, RecordBatchStreamWrapper}; use common_telemetry::tracing::Span; use common_telemetry::tracing_context::W3cTrace; use common_telemetry::{error, warn}; @@ -57,6 +61,172 @@ type FlightDataStream = Pin + Send>>; type DoPutResponseStream = Pin>>>; +#[derive(Debug, Clone, Default)] +pub struct OutputMetrics { + metrics: Arc>, + ready: Arc, +} + +impl OutputMetrics { + fn new() -> Self { + Self::default() + } + + pub fn update(&self, metrics: Option) { + self.metrics.swap(metrics.map(Arc::new)); + } + + pub fn mark_ready(&self) { + self.ready.store(true, Ordering::Release); + } + + pub fn is_ready(&self) -> bool { + self.ready.load(Ordering::Acquire) + } + + pub fn get(&self) -> Option { + self.metrics.load().as_ref().map(|m| m.as_ref().clone()) + } + + /// Returns proved per-region watermarks. + /// + /// Entries whose watermark is `None` are intentionally omitted because they + /// represent participating regions whose terminal sequence bound was not + /// provable. + pub fn region_watermark_map(&self) -> Option> { + Some( + self.get()? + .region_watermarks + .into_iter() + .filter_map(|entry| entry.watermark.map(|seq| (entry.region_id, seq))) + .collect::>(), + ) + } + + /// Returns all regions that participated in terminal metric collection, + /// including entries whose watermark is `None`. + pub fn participating_regions(&self) -> Option> { + Some( + self.get()? + .region_watermarks + .into_iter() + .map(|entry| entry.region_id) + .collect::>(), + ) + } +} + +#[derive(Debug)] +pub struct OutputWithMetrics { + pub output: Output, + pub metrics: OutputMetrics, +} + +impl OutputWithMetrics { + pub fn from_output(output: Output) -> Self { + let terminal_metrics = OutputMetrics::new(); + let output = attach_terminal_metrics(output, &terminal_metrics); + Self { + output, + metrics: terminal_metrics, + } + } + + pub fn region_watermark_map(&self) -> Option> { + self.metrics.region_watermark_map() + } + + pub fn participating_regions(&self) -> Option> { + self.metrics.participating_regions() + } + + pub fn into_output(self) -> Output { + self.output + } +} + +fn parse_terminal_metrics(metrics_json: &str) -> Result { + serde_json::from_str(metrics_json).map_err(|e| { + IllegalFlightMessagesSnafu { + reason: format!("Invalid terminal metrics message: {e}"), + } + .build() + }) +} + +struct StreamWithMetrics { + stream: common_recordbatch::SendableRecordBatchStream, + metrics: OutputMetrics, +} + +impl StreamWithMetrics { + fn new(stream: common_recordbatch::SendableRecordBatchStream, metrics: OutputMetrics) -> Self { + Self { stream, metrics } + } + + fn sync_terminal_metrics(&self) { + self.metrics.update(self.stream.metrics()); + } +} + +impl RecordBatchStream for StreamWithMetrics { + fn name(&self) -> &str { + self.stream.name() + } + + fn schema(&self) -> datatypes::schema::SchemaRef { + self.stream.schema() + } + + fn output_ordering(&self) -> Option<&[OrderOption]> { + self.stream.output_ordering() + } + + fn metrics(&self) -> Option { + self.sync_terminal_metrics(); + self.metrics.get() + } +} + +impl Stream for StreamWithMetrics { + type Item = common_recordbatch::error::Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let polled = Pin::new(&mut self.stream).poll_next(cx); + match &polled { + Poll::Ready(Some(_)) => self.sync_terminal_metrics(), + Poll::Ready(None) => { + self.sync_terminal_metrics(); + self.metrics.mark_ready(); + } + Poll::Pending => {} + } + polled + } + + fn size_hint(&self) -> (usize, Option) { + self.stream.size_hint() + } +} + +fn attach_terminal_metrics(output: Output, terminal_metrics: &OutputMetrics) -> Output { + let Output { data, meta } = output; + let data = match data { + common_query::OutputData::Stream(stream) => { + terminal_metrics.update(stream.metrics()); + common_query::OutputData::Stream(Box::pin(StreamWithMetrics::new( + stream, + terminal_metrics.clone(), + ))) + } + other => { + terminal_metrics.mark_ready(); + other + } + }; + Output::new(data, meta) +} + #[derive(Clone, Debug, Default)] pub struct Database { // The "catalog" and "schema" to be used in processing the requests at the server side. @@ -224,6 +394,9 @@ impl Database { } fn put_hints(metadata: &mut MetadataMap, hints: &[(&str, &str)]) -> Result<()> { + // Keep this helper for simple ASCII hint values only. The wire format is the + // existing comma-separated `x-greptime-hints` metadata value and does not + // escape commas inside individual values. let Some(value) = hints .iter() .map(|(k, v)| format!("{}={}", k, v)) @@ -333,15 +506,46 @@ impl Database { let request = Request::Query(QueryRequest { query: Some(Query::Sql(sql.as_ref().to_string())), }); - self.do_get(request, hints).await + self.do_get(request, hints) + .await + .map(OutputWithMetrics::into_output) + } + + pub async fn sql_with_terminal_metrics( + &self, + sql: S, + hints: &[(&str, &str)], + ) -> Result + where + S: AsRef, + { + self.query_with_terminal_metrics( + QueryRequest { + query: Some(Query::Sql(sql.as_ref().to_string())), + }, + hints, + ) + .await } /// Executes a logical plan directly without SQL parsing. pub async fn logical_plan(&self, logical_plan: Vec) -> Result { - let request = Request::Query(QueryRequest { - query: Some(Query::LogicalPlan(logical_plan)), - }); - self.do_get(request, &[]).await + self.query_with_terminal_metrics( + QueryRequest { + query: Some(Query::LogicalPlan(logical_plan)), + }, + &[], + ) + .await + .map(OutputWithMetrics::into_output) + } + + pub async fn query_with_terminal_metrics( + &self, + request: QueryRequest, + hints: &[(&str, &str)], + ) -> Result { + self.do_get(Request::Query(request), hints).await } /// Creates a new table using the provided table expression. @@ -349,7 +553,9 @@ impl Database { let request = Request::Ddl(DdlRequest { expr: Some(DdlExpr::CreateTable(expr)), }); - self.do_get(request, &[]).await + self.do_get(request, &[]) + .await + .map(OutputWithMetrics::into_output) } /// Alters an existing table using the provided alter expression. @@ -357,10 +563,12 @@ impl Database { let request = Request::Ddl(DdlRequest { expr: Some(DdlExpr::AlterTable(expr)), }); - self.do_get(request, &[]).await + self.do_get(request, &[]) + .await + .map(OutputWithMetrics::into_output) } - async fn do_get(&self, request: Request, hints: &[(&str, &str)]) -> Result { + async fn do_get(&self, request: Request, hints: &[(&str, &str)]) -> Result { let request = self.to_rpc_request(request); let request = Ticket { ticket: request.encode_to_vec().into(), @@ -409,13 +617,33 @@ impl Database { match first_flight_message { FlightMessage::AffectedRows(rows) => { - ensure!( - flight_message_stream.next().await.is_none(), - IllegalFlightMessagesSnafu { - reason: "Expect 'AffectedRows' Flight messages to be the one and the only!" + let terminal_metrics = OutputMetrics::new(); + let next_message = flight_message_stream.next().await.transpose()?; + match next_message { + None => terminal_metrics.mark_ready(), + Some(FlightMessage::Metrics(s)) => { + terminal_metrics.update(Some(parse_terminal_metrics(&s)?)); + terminal_metrics.mark_ready(); + ensure!( + flight_message_stream.next().await.is_none(), + IllegalFlightMessagesSnafu { + reason: "Expect 'AffectedRows' Flight messages to be followed by at most one Metrics message" + } + ); } - ); - Ok(Output::new_with_affected_rows(rows)) + Some(other) => { + return IllegalFlightMessagesSnafu { + reason: format!( + "'AffectedRows' Flight message can only be followed by a Metrics message, got {other:?}" + ), + } + .fail(); + } + } + Ok(OutputWithMetrics { + output: Output::new_with_affected_rows(rows), + metrics: terminal_metrics, + }) } FlightMessage::RecordBatch(_) | FlightMessage::Metrics(_) => { IllegalFlightMessagesSnafu { @@ -424,24 +652,88 @@ impl Database { .fail() } FlightMessage::Schema(schema) => { + let metrics = Arc::new(ArcSwapOption::from(None)); + let metrics_ref = metrics.clone(); let schema = Arc::new( datatypes::schema::Schema::try_from(schema) .context(error::ConvertSchemaSnafu)?, ); let schema_cloned = schema.clone(); let stream = Box::pin(stream!({ - while let Some(flight_message) = flight_message_stream.next().await { - let flight_message = flight_message - .map_err(BoxedError::new) - .context(ExternalSnafu)?; + let mut buffered_message: Option = None; + let mut stream_ended = false; + + while !stream_ended { + let flight_message_item = if let Some(msg) = buffered_message.take() { + Some(Ok(msg)) + } else { + flight_message_stream.next().await + }; + + let flight_message = match flight_message_item { + Some(Ok(message)) => message, + Some(Err(e)) => { + yield Err(BoxedError::new(e)).context(ExternalSnafu); + break; + } + None => break, + }; + match flight_message { FlightMessage::RecordBatch(arrow_batch) => { - yield Ok(RecordBatch::from_df_record_batch( + let result_to_yield = RecordBatch::from_df_record_batch( schema_cloned.clone(), arrow_batch, - )) + ); + + if let Some(next_flight_message_result) = + flight_message_stream.next().await + { + match next_flight_message_result { + Ok(FlightMessage::Metrics(s)) => { + match parse_terminal_metrics(&s) { + Ok(m) => { + metrics_ref.swap(Some(Arc::new(m))); + } + Err(e) => { + yield Err(BoxedError::new(e)) + .context(ExternalSnafu); + break; + } + }; + } + Ok(FlightMessage::RecordBatch(rb)) => { + buffered_message = Some(FlightMessage::RecordBatch(rb)); + } + Ok(_) => { + yield IllegalFlightMessagesSnafu {reason: "A RecordBatch message can only be succeeded by a Metrics message or another RecordBatch message"} + .fail() + .map_err(BoxedError::new) + .context(ExternalSnafu); + break; + } + Err(e) => { + yield Err(BoxedError::new(e)).context(ExternalSnafu); + break; + } + } + } else { + stream_ended = true; + } + + yield Ok(result_to_yield) + } + FlightMessage::Metrics(s) => { + match parse_terminal_metrics(&s) { + Ok(m) => { + metrics_ref.swap(Some(Arc::new(m))); + } + Err(e) => { + yield Err(BoxedError::new(e)).context(ExternalSnafu); + } + }; + break; } - FlightMessage::Metrics(_) => {} FlightMessage::AffectedRows(_) | FlightMessage::Schema(_) => { yield IllegalFlightMessagesSnafu {reason: format!("A Schema message must be succeeded exclusively by a set of RecordBatch messages, flight_message: {:?}", flight_message)} .fail() @@ -456,10 +748,12 @@ impl Database { schema, stream, output_ordering: None, - metrics: Default::default(), + metrics, span: Span::current(), }; - Ok(Output::new_with_stream(Box::pin(record_batch_stream))) + Ok(OutputWithMetrics::from_output(Output::new_with_stream( + Box::pin(record_batch_stream), + ))) } } } @@ -512,16 +806,59 @@ struct FlightContext { #[cfg(test)] mod tests { - use std::assert_matches; + use std::sync::Arc; + use std::task::{Context, Poll}; use api::v1::auth_header::AuthScheme; use api::v1::{AuthHeader, Basic}; use common_error::status_code::StatusCode; + use common_query::OutputData; + use common_recordbatch::{OrderOption, RecordBatch, RecordBatchStream}; + use datatypes::prelude::{ConcreteDataType, VectorRef}; + use datatypes::schema::{ColumnSchema, Schema}; + use datatypes::vectors::Int32Vector; + use futures_util::StreamExt; use tonic::{Code, Status}; use super::*; use crate::error::TonicSnafu; + struct MockMetricsStream { + schema: datatypes::schema::SchemaRef, + batch: Option, + metrics: RecordBatchMetrics, + terminal_metrics_only: bool, + } + + impl Stream for MockMetricsStream { + type Item = common_recordbatch::error::Result; + + fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(self.batch.take().map(Ok)) + } + } + + impl RecordBatchStream for MockMetricsStream { + fn name(&self) -> &str { + "MockMetricsStream" + } + + fn schema(&self) -> datatypes::schema::SchemaRef { + self.schema.clone() + } + + fn output_ordering(&self) -> Option<&[OrderOption]> { + None + } + + fn metrics(&self) -> Option { + if self.terminal_metrics_only && self.batch.is_some() { + return None; + } + Some(self.metrics.clone()) + } + } + #[test] fn test_flight_ctx() { let mut ctx = FlightContext::default(); @@ -536,12 +873,12 @@ mod tests { auth_scheme: Some(basic), }); - assert_matches!( + assert!(matches!( ctx.auth_header, Some(AuthHeader { auth_scheme: Some(AuthScheme::Basic(_)), }) - ) + )); } #[test] @@ -558,4 +895,125 @@ mod tests { assert_eq!(expected.to_string(), actual.to_string()); } + + #[tokio::test] + async fn test_query_with_terminal_metrics_tracks_terminal_only_metrics() { + let schema = Arc::new(Schema::new(vec![ColumnSchema::new( + "v", + ConcreteDataType::int32_datatype(), + false, + )])); + let batch = RecordBatch::new( + schema.clone(), + vec![Arc::new(Int32Vector::from_slice([1, 2])) as VectorRef], + ) + .unwrap(); + let output = Output::new_with_stream(Box::pin(MockMetricsStream { + schema, + batch: Some(batch), + metrics: RecordBatchMetrics { + region_watermarks: vec![common_recordbatch::adapter::RegionWatermarkEntry { + region_id: 7, + watermark: Some(42), + }], + ..Default::default() + }, + terminal_metrics_only: true, + })); + + let result = OutputWithMetrics::from_output(output); + let terminal_metrics = result.metrics.clone(); + assert!(!terminal_metrics.is_ready()); + assert!(terminal_metrics.get().is_none()); + + let OutputData::Stream(mut stream) = result.output.data else { + panic!("expected stream output"); + }; + while stream.next().await.is_some() {} + + assert!(terminal_metrics.is_ready()); + assert_eq!( + terminal_metrics.participating_regions(), + Some(std::collections::BTreeSet::from([7_u64])) + ); + assert_eq!( + terminal_metrics.region_watermark_map(), + Some(std::collections::HashMap::from([(7_u64, 42_u64)])) + ); + } + + #[test] + fn test_parse_terminal_metrics_rejects_invalid_json() { + assert!(parse_terminal_metrics("{not-json}").is_err()); + } + + #[tokio::test] + async fn test_invalid_terminal_metrics_after_record_batch_fails_before_yielding_batch() { + let schema = Arc::new(Schema::new(vec![ColumnSchema::new( + "v", + ConcreteDataType::int32_datatype(), + false, + )])); + let batch = RecordBatch::new( + schema.clone(), + vec![Arc::new(Int32Vector::from_slice([1])) as VectorRef], + ) + .unwrap(); + let metrics = Arc::new(ArcSwapOption::from(None)); + let metrics_ref = metrics.clone(); + let schema_cloned = schema.clone(); + let mut flight_message_stream = futures_util::stream::iter(vec![ + Ok(FlightMessage::RecordBatch(batch.into_df_record_batch())), + Ok(FlightMessage::Metrics("{not-json}".to_string())), + ] + as Vec>); + + let mut record_batch_stream = Box::pin(stream!({ + let Some(Ok(FlightMessage::RecordBatch(arrow_batch))) = + flight_message_stream.next().await + else { + return; + }; + let result_to_yield = + RecordBatch::from_df_record_batch(schema_cloned.clone(), arrow_batch); + + if let Some(Ok(FlightMessage::Metrics(s))) = flight_message_stream.next().await { + match parse_terminal_metrics(&s) { + Ok(m) => { + metrics_ref.swap(Some(Arc::new(m))); + } + Err(e) => { + yield Err(BoxedError::new(e)).context(ExternalSnafu); + return; + } + } + } + + yield Ok(result_to_yield); + })); + + let err = record_batch_stream.next().await.unwrap().unwrap_err(); + assert_eq!("External error", err.to_string()); + assert!( + format!("{err:?}").contains("Invalid terminal metrics message"), + "unexpected error: {err:?}" + ); + assert!(record_batch_stream.next().await.is_none()); + assert!(metrics.load().is_none()); + } + + #[test] + fn test_output_metrics_distinguishes_empty_region_watermarks_from_absence() { + let metrics = OutputMetrics::default(); + metrics.update(Some(RecordBatchMetrics::default())); + + assert_eq!( + metrics.participating_regions(), + Some(std::collections::BTreeSet::new()) + ); + assert_eq!( + metrics.region_watermark_map(), + Some(std::collections::HashMap::new()) + ); + } } diff --git a/src/client/src/lib.rs b/src/client/src/lib.rs index 0c9334b7d4..147dffc145 100644 --- a/src/client/src/lib.rs +++ b/src/client/src/lib.rs @@ -32,7 +32,7 @@ pub use common_recordbatch::{RecordBatches, SendableRecordBatchStream}; use snafu::OptionExt; pub use self::client::Client; -pub use self::database::Database; +pub use self::database::{Database, OutputMetrics, OutputWithMetrics}; pub use self::error::{Error, Result}; use crate::error::{IllegalDatabaseResponseSnafu, ServerSnafu}; diff --git a/src/flow/src/batching_mode/frontend_client.rs b/src/flow/src/batching_mode/frontend_client.rs index 9875564c78..faa1988306 100644 --- a/src/flow/src/batching_mode/frontend_client.rs +++ b/src/flow/src/batching_mode/frontend_client.rs @@ -20,15 +20,16 @@ use std::sync::{Arc, Mutex, Weak}; use api::v1::greptime_request::Request; use api::v1::query_request::Query; use api::v1::{CreateTableExpr, QueryRequest}; -use client::{Client, Database}; +use client::{Client, Database, OutputWithMetrics}; use common_error::ext::BoxedError; use common_grpc::channel_manager::{ChannelConfig, ChannelManager, load_client_tls_config}; use common_meta::peer::{Peer, PeerDiscovery}; -use common_query::Output; +use common_query::{Output, OutputData}; use common_telemetry::warn; use meta_client::client::MetaClient; use query::datafusion::QUERY_PARALLELISM_HINT; -use query::options::QueryOptions; +use query::metrics::terminal_recordbatch_metrics_from_plan; +use query::options::{FlowQueryExtensions, QueryOptions}; use rand::rng; use rand::seq::SliceRandom; use servers::query_handler::grpc::GrpcQueryHandler; @@ -342,6 +343,83 @@ impl FrontendClient { } } + pub async fn query_with_terminal_metrics( + &self, + catalog: &str, + schema: &str, + request: QueryRequest, + extensions: &[(&str, &str)], + ) -> Result { + let flow_extensions = build_flow_extensions(extensions)?; + match self { + FrontendClient::Distributed { + query, batch_opts, .. + } => { + let query_parallelism = query.parallelism.to_string(); + let mut hints = vec![ + (QUERY_PARALLELISM_HINT, query_parallelism.as_str()), + (READ_PREFERENCE_HINT, batch_opts.read_preference.as_ref()), + ]; + // PR2b only sends simple flow hint values such as + // `flow.return_region_seq=true`. The distributed client forwards + // hints through `x-greptime-hints`, whose existing comma-separated + // encoding is not suitable for comma-bearing values. + hints.extend_from_slice(extensions); + let db = self.get_random_active_frontend(catalog, schema).await?; + db.database + .query_with_terminal_metrics(request, &hints) + .await + .map_err(BoxedError::new) + .context(ExternalSnafu) + } + FrontendClient::Standalone { + database_client, + query, + } => { + let mut extensions_map = HashMap::from([( + QUERY_PARALLELISM_HINT.to_string(), + query.parallelism.to_string(), + )]); + for (key, value) in extensions { + extensions_map.insert((*key).to_string(), (*value).to_string()); + } + let ctx = QueryContextBuilder::default() + .current_catalog(catalog.to_string()) + .current_schema(schema.to_string()) + .extensions(extensions_map) + .build(); + let ctx = Arc::new(ctx); + let database_client = { + database_client + .handler + .lock() + .map_err(|e| { + UnexpectedSnafu { + reason: format!("Failed to lock database client: {e}"), + } + .build() + })? + .as_ref() + .context(UnexpectedSnafu { + reason: "Standalone's frontend instance is not set", + })? + .upgrade() + .context(UnexpectedSnafu { + reason: "Failed to upgrade database client", + })? + }; + database_client + .do_query(Request::Query(request), ctx) + .await + .map(|output| { + wrap_standalone_output_with_terminal_metrics(output, &flow_extensions) + }) + .map_err(BoxedError::new) + .context(ExternalSnafu) + } + } + } + /// Handle a request to frontend pub(crate) async fn handle( &self, @@ -432,6 +510,40 @@ impl FrontendClient { } } +fn build_flow_extensions(extensions: &[(&str, &str)]) -> Result { + let flow_extensions = HashMap::from_iter( + extensions + .iter() + .map(|(key, value)| ((*key).to_string(), (*value).to_string())), + ); + FlowQueryExtensions::parse_flow_extensions(&flow_extensions) + .map_err(BoxedError::new) + .context(ExternalSnafu) + .map(|extensions| extensions.unwrap_or_default()) +} + +fn wrap_standalone_output_with_terminal_metrics( + output: Output, + flow_extensions: &FlowQueryExtensions, +) -> OutputWithMetrics { + let should_collect_region_watermark = flow_extensions.should_collect_region_watermark(); + let terminal_metrics = + if should_collect_region_watermark && !matches!(&output.data, OutputData::Stream(_)) { + output + .meta + .plan + .clone() + .and_then(terminal_recordbatch_metrics_from_plan) + } else { + None + }; + let result = OutputWithMetrics::from_output(output); + if let Some(metrics) = terminal_metrics { + result.metrics.update(Some(metrics)); + } + result +} + /// Describe a peer of frontend #[derive(Debug, Default)] pub(crate) enum PeerDesc { @@ -456,9 +568,20 @@ impl std::fmt::Display for PeerDesc { #[cfg(test)] mod tests { + use std::pin::Pin; + use std::task::{Context, Poll}; use std::time::Duration; - use common_query::Output; + use common_error::ext::PlainError; + use common_error::status_code::StatusCode; + use common_query::{Output, OutputData}; + use common_recordbatch::adapter::RecordBatchMetrics; + use common_recordbatch::{OrderOption, RecordBatch, RecordBatchStream}; + use datatypes::prelude::{ConcreteDataType, VectorRef}; + use datatypes::schema::{ColumnSchema, Schema}; + use datatypes::vectors::Int32Vector; + use futures::StreamExt; + use snafu::GenerateImplicitData; use tokio::time::timeout; use super::*; @@ -466,6 +589,55 @@ mod tests { #[derive(Debug)] struct NoopHandler; + struct MockMetricsStream { + schema: datatypes::schema::SchemaRef, + batch: Option, + metrics: RecordBatchMetrics, + terminal_metrics_only: bool, + } + + impl futures::Stream for MockMetricsStream { + type Item = common_recordbatch::error::Result; + + fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(self.batch.take().map(Ok)) + } + + fn size_hint(&self) -> (usize, Option) { + ( + usize::from(self.batch.is_some()), + Some(usize::from(self.batch.is_some())), + ) + } + } + + impl RecordBatchStream for MockMetricsStream { + fn name(&self) -> &str { + "MockMetricsStream" + } + + fn schema(&self) -> datatypes::schema::SchemaRef { + self.schema.clone() + } + + fn output_ordering(&self) -> Option<&[OrderOption]> { + None + } + + fn metrics(&self) -> Option { + if self.terminal_metrics_only && self.batch.is_some() { + return None; + } + Some(self.metrics.clone()) + } + } + + #[derive(Debug)] + struct MetricsHandler; + + #[derive(Debug)] + struct ExtensionAwareHandler; + #[async_trait::async_trait] impl GrpcQueryHandlerWithBoxedError for NoopHandler { async fn do_query( @@ -477,6 +649,50 @@ mod tests { } } + #[async_trait::async_trait] + impl GrpcQueryHandlerWithBoxedError for MetricsHandler { + async fn do_query( + &self, + _query: Request, + _ctx: QueryContextRef, + ) -> std::result::Result { + let schema = Arc::new(Schema::new(vec![ColumnSchema::new( + "v", + ConcreteDataType::int32_datatype(), + false, + )])); + let batch = RecordBatch::new( + schema.clone(), + vec![Arc::new(Int32Vector::from_slice([1, 2])) as VectorRef], + ) + .unwrap(); + Ok(Output::new_with_stream(Box::pin(MockMetricsStream { + schema, + batch: Some(batch), + metrics: RecordBatchMetrics { + region_watermarks: vec![common_recordbatch::adapter::RegionWatermarkEntry { + region_id: 42, + watermark: Some(99), + }], + ..Default::default() + }, + terminal_metrics_only: true, + }))) + } + } + + #[async_trait::async_trait] + impl GrpcQueryHandlerWithBoxedError for ExtensionAwareHandler { + async fn do_query( + &self, + _query: Request, + ctx: QueryContextRef, + ) -> std::result::Result { + assert_eq!(ctx.extension("flow.return_region_seq"), Some("true")); + Ok(Output::new_with_affected_rows(1)) + } + } + #[tokio::test] async fn wait_initialized() { let (client, handler_mut) = @@ -522,4 +738,81 @@ mod tests { .is_ok() ); } + + #[tokio::test] + async fn test_query_with_terminal_metrics_tracks_watermark_in_standalone_mode() { + let handler: Arc = Arc::new(MetricsHandler); + let client = + FrontendClient::from_grpc_handler(Arc::downgrade(&handler), QueryOptions::default()); + + let result = client + .query_with_terminal_metrics( + "greptime", + "public", + QueryRequest { + query: Some(Query::Sql("select 1".to_string())), + }, + &[], + ) + .await + .unwrap(); + + let terminal_metrics = result.metrics.clone(); + assert!(!result.metrics.is_ready()); + assert!(terminal_metrics.get().is_none()); + + let OutputData::Stream(mut stream) = result.output.data else { + panic!("expected stream output"); + }; + while stream.next().await.is_some() {} + + assert!(terminal_metrics.is_ready()); + assert_eq!( + terminal_metrics.region_watermark_map(), + Some(HashMap::from([(42_u64, 99_u64)])) + ); + } + + #[tokio::test] + async fn test_query_with_terminal_metrics_forwards_flow_extensions_in_standalone_mode() { + let handler: Arc = Arc::new(ExtensionAwareHandler); + let client = + FrontendClient::from_grpc_handler(Arc::downgrade(&handler), QueryOptions::default()); + + let result = client + .query_with_terminal_metrics( + "greptime", + "public", + QueryRequest { + query: Some(Query::Sql("insert into t select 1".to_string())), + }, + &[("flow.return_region_seq", "true")], + ) + .await + .unwrap(); + + assert!(result.metrics.is_ready()); + assert!(result.region_watermark_map().is_none()); + } + + #[tokio::test] + async fn test_query_with_terminal_metrics_rejects_invalid_flow_extensions() { + let handler: Arc = Arc::new(NoopHandler); + let client = + FrontendClient::from_grpc_handler(Arc::downgrade(&handler), QueryOptions::default()); + + let err = client + .query_with_terminal_metrics( + "greptime", + "public", + QueryRequest { + query: Some(Query::Sql("select 1".to_string())), + }, + &[("flow.return_region_seq", "not-a-bool")], + ) + .await + .unwrap_err(); + + assert!(format!("{err:?}").contains("Invalid value for flow.return_region_seq")); + } } diff --git a/src/frontend/src/error.rs b/src/frontend/src/error.rs index d148e6aa1b..6f78d23e14 100644 --- a/src/frontend/src/error.rs +++ b/src/frontend/src/error.rs @@ -399,7 +399,7 @@ impl ErrorExt for Error { Error::PrometheusLabelValuesQueryPlan { source, .. } => source.status_code(), - Error::CollectRecordbatch { .. } => StatusCode::EngineExecuteQuery, + Error::CollectRecordbatch { source, .. } => source.status_code(), Error::SqlExecIntercepted { source, .. } => source.status_code(), Error::StartServer { source, .. } => source.status_code(), diff --git a/src/servers/src/grpc/flight.rs b/src/servers/src/grpc/flight.rs index 364ce8ce26..4f262c53aa 100644 --- a/src/servers/src/grpc/flight.rs +++ b/src/servers/src/grpc/flight.rs @@ -38,7 +38,9 @@ use datatypes::arrow::datatypes::SchemaRef; use futures::{Stream, future, ready}; use futures_util::{StreamExt, TryStreamExt}; use prost::Message; -use session::context::{QueryContext, QueryContextRef}; +use query::metrics::terminal_recordbatch_metrics_from_plan; +use query::options::FlowQueryExtensions; +use session::context::{Channel, QueryContextRef}; use snafu::{IntoError, ResultExt, ensure}; use table::table_name::TableName; use tokio::sync::mpsc; @@ -47,7 +49,9 @@ use tonic::{Request, Response, Status, Streaming}; use crate::error::{InvalidParameterSnafu, Result, ToJsonSnafu}; pub use crate::grpc::flight::stream::FlightRecordBatchStream; -use crate::grpc::greptime_handler::{GreptimeRequestHandler, get_request_type}; +use crate::grpc::greptime_handler::{ + GreptimeRequestHandler, create_query_context, get_request_type, +}; use crate::grpc::{FlightCompression, TonicResult, context_auth}; use crate::request_memory_limiter::ServerMemoryLimiter; use crate::request_memory_metrics::RequestMemoryMetrics; @@ -191,6 +195,13 @@ impl FlightCraft for GreptimeRequestHandler { let ticket = request.into_inner().ticket; let request = GreptimeRequest::decode(ticket.as_ref()).context(error::InvalidFlightTicketSnafu)?; + let query_ctx = + create_query_context(Channel::Grpc, request.header.as_ref(), hints.clone())?; + // Validate flow hint syntax at the transport boundary before dispatching the request. + // This does not authorize or execute anything; `handle_request()` below still performs + // the normal frontend handling and auth checks before query execution. + FlowQueryExtensions::parse_flow_extensions(&query_ctx.extensions()) + .map_err(|e| Status::invalid_argument(e.to_string()))?; // The Grpc protocol pass query by Flight. It needs to be wrapped under a span, in order to record stream let span = info_span!( @@ -205,7 +216,7 @@ impl FlightCraft for GreptimeRequestHandler { output, TracingContext::from_current_span(), flight_compression, - QueryContext::arc(), + query_ctx, ); Ok(Response::new(stream)) } @@ -538,12 +549,22 @@ fn to_flight_data_stream( Box::pin(stream) as _ } OutputData::AffectedRows(rows) => { - let stream = tokio_stream::iter( - FlightEncoder::default() - .encode(FlightMessage::AffectedRows(rows)) - .into_iter() - .map(Ok), - ); + let affected_rows = FlightEncoder::default().encode(FlightMessage::AffectedRows(rows)); + let should_emit_terminal_metrics = + FlowQueryExtensions::parse_flow_extensions(&query_ctx.extensions()) + .expect("flow extensions must be validated before Flight serialization") + .is_some_and(|extensions| extensions.should_collect_region_watermark()); + let terminal_metrics = should_emit_terminal_metrics + .then_some(output.meta.plan) + .flatten() + .and_then(terminal_recordbatch_metrics_from_plan) + .and_then(|metrics| serde_json::to_string(&metrics).ok()) + .map(FlightMessage::Metrics) + .map(|message| FlightEncoder::default().encode(message)) + .into_iter() + .flatten(); + let stream = + tokio_stream::iter(affected_rows.into_iter().chain(terminal_metrics).map(Ok)); Box::pin(stream) as _ } } diff --git a/tests-integration/src/grpc/flight.rs b/tests-integration/src/grpc/flight.rs index 9ed6b8176f..6149dcadce 100644 --- a/tests-integration/src/grpc/flight.rs +++ b/tests-integration/src/grpc/flight.rs @@ -23,10 +23,12 @@ mod test { use auth::user_provider_from_option; use client::{Client, Database}; use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; + use common_error::ext::ErrorExt; use common_grpc::flight::do_put::DoPutMetadata; use common_grpc::flight::{FlightEncoder, FlightMessage}; use common_query::OutputData; use common_recordbatch::RecordBatch; + use common_recordbatch::adapter::RegionWatermarkEntry; use datatypes::prelude::{ConcreteDataType, ScalarVector, VectorRef}; use datatypes::schema::{ColumnSchema, Schema}; use datatypes::vectors::{Int32Vector, StringVector, TimestampMillisecondVector}; @@ -129,6 +131,104 @@ mod test { | 1970-01-01T00:00:00.009 | -9 | s9 | +-------------------------+----+----+"; query_and_expect(db.fe_instance().as_ref(), sql, expected).await; + + let output = client.sql(sql).await.unwrap(); + let OutputData::Stream(mut stream) = output.data else { + panic!("expected stream output"); + }; + while let Some(batch) = stream.next().await { + batch.unwrap(); + } + let metrics = stream.metrics().expect("expected terminal metrics"); + assert!(metrics.region_watermarks.is_empty()); + + let result = client + .sql_with_terminal_metrics(sql, &[("flow.return_region_seq", "true")]) + .await + .unwrap(); + let terminal_metrics = result.metrics.clone(); + let OutputData::Stream(mut stream) = result.output.data else { + panic!("expected stream output"); + }; + while let Some(batch) = stream.next().await { + batch.unwrap(); + } + assert!(terminal_metrics.is_ready()); + let regions = db.list_all_regions().await; + assert_eq!(regions.len(), 1); + let (region_id, region) = regions.into_iter().next().unwrap(); + let expected_watermark = (region_id.as_u64(), region.find_committed_sequence()); + assert_eq!( + terminal_metrics.region_watermark_map(), + Some(std::collections::HashMap::from([expected_watermark])) + ); + + let output = client + .sql_with_hint(sql, &[("flow.return_region_seq", "true")]) + .await + .unwrap(); + let OutputData::Stream(mut stream) = output.data else { + panic!("expected stream output"); + }; + + let mut row_count = 0; + while let Some(batch) = stream.next().await { + let batch = batch.unwrap(); + row_count += batch.num_rows(); + } + assert_eq!(row_count, 9); + + let metrics = stream.metrics().expect("expected terminal metrics"); + let region_watermarks = metrics.region_watermarks; + assert_eq!( + region_watermarks, + vec![RegionWatermarkEntry { + region_id: expected_watermark.0, + watermark: Some(expected_watermark.1), + }] + ); + + let previous_watermark = expected_watermark; + + create_table_named(&client, "bar").await; + let result = client + .sql_with_terminal_metrics("insert into bar select ts, a, `B` from foo", &[]) + .await + .unwrap(); + let OutputData::AffectedRows(affected_rows) = result.output.data else { + panic!("expected affected rows output"); + }; + assert_eq!(affected_rows, 9); + assert!(result.metrics.is_ready()); + assert!(result.region_watermark_map().is_none()); + + let err = client + .sql_with_terminal_metrics( + "insert into bar select ts, a, `B` from foo", + &[("flow.return_region_seq", "not-a-bool")], + ) + .await + .unwrap_err(); + let err_msg = format!("{err:?}"); + assert!(err_msg.contains("Invalid value for flow.return_region_seq")); + + client.sql("truncate table bar").await.unwrap(); + + let result = client + .sql_with_terminal_metrics( + "insert into bar select ts, a, `B` from foo", + &[("flow.return_region_seq", "true")], + ) + .await + .unwrap(); + let OutputData::AffectedRows(affected_rows) = result.output.data else { + panic!("expected affected rows output"); + }; + assert_eq!(affected_rows, 9); + assert_eq!( + result.region_watermark_map(), + Some(std::collections::HashMap::from([previous_watermark])) + ); } async fn test_put_record_batches(client: &Database, record_batches: Vec) { @@ -224,6 +324,10 @@ mod test { } async fn create_table(client: &Database) { + create_table_named(client, "foo").await; + } + + async fn create_table_named(client: &Database, table_name: &str) { // create table foo ( // ts timestamp time index, // a int primary key, @@ -232,7 +336,7 @@ mod test { let output = client .create(CreateTableExpr { schema_name: "public".to_string(), - table_name: "foo".to_string(), + table_name: table_name.to_string(), column_defs: vec![ ColumnDef { name: "ts".to_string(),