diff --git a/src/client/src/database.rs b/src/client/src/database.rs index 7a07f478bf..732c9ac9bf 100644 --- a/src/client/src/database.rs +++ b/src/client/src/database.rs @@ -15,6 +15,8 @@ use std::pin::Pin; use std::str::FromStr; use std::sync::Arc; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::task::{Context, Poll}; use api::v1::auth_header::AuthScheme; use api::v1::ddl_request::Expr as DdlExpr; @@ -38,7 +40,7 @@ use common_grpc::flight::{FlightDecoder, FlightMessage}; use common_query::Output; use common_recordbatch::adapter::RecordBatchMetrics; use common_recordbatch::error::ExternalSnafu; -use common_recordbatch::{RecordBatch, RecordBatchStreamWrapper}; +use common_recordbatch::{OrderOption, RecordBatch, RecordBatchStream, RecordBatchStreamWrapper}; use common_telemetry::tracing::Span; use common_telemetry::tracing_context::W3cTrace; use common_telemetry::{error, warn}; @@ -59,6 +61,148 @@ type FlightDataStream = Pin + Send>>; type DoPutResponseStream = Pin>>>; +#[derive(Debug, Clone, Default)] +pub struct OutputMetrics { + metrics: Arc>, + ready: Arc, +} + +impl OutputMetrics { + fn new() -> Self { + Self::default() + } + + pub fn update(&self, metrics: Option) { + self.metrics.swap(metrics.map(Arc::new)); + } + + pub fn mark_ready(&self) { + self.ready.store(true, Ordering::Release); + } + + pub fn is_ready(&self) -> bool { + self.ready.load(Ordering::Acquire) + } + + pub fn get(&self) -> Option { + self.metrics.load().as_ref().map(|m| m.as_ref().clone()) + } + + pub fn region_watermark_map(&self) -> Option> { + self.get()?.region_latest_sequences.map(|sequences| { + sequences + .into_iter() + .collect::>() + }) + } +} + +#[derive(Debug)] +pub struct OutputWithMetrics { + pub output: Output, + pub metrics: OutputMetrics, +} + +impl OutputWithMetrics { + pub fn from_output(output: Output) -> Self { + let terminal_metrics = OutputMetrics::new(); + let output = attach_terminal_metrics(output, &terminal_metrics); + Self { + output, + metrics: terminal_metrics, + } + } + + pub fn region_watermark_map(&self) -> Option> { + self.metrics.region_watermark_map() + } + + pub fn into_output(self) -> Output { + self.output + } +} + +fn parse_terminal_metrics(metrics_json: &str) -> Option> { + serde_json::from_str(metrics_json).ok().map(Arc::new) +} + +struct StreamWithMetrics { + stream: common_recordbatch::SendableRecordBatchStream, + metrics: OutputMetrics, +} + +impl StreamWithMetrics { + fn new(stream: common_recordbatch::SendableRecordBatchStream, metrics: OutputMetrics) -> Self { + Self { stream, metrics } + } + + fn sync_terminal_metrics(&self) { + self.metrics.update(self.stream.metrics()); + } + + fn mark_ready_if_terminated(&self) { + self.metrics.mark_ready(); + } +} + +impl RecordBatchStream for StreamWithMetrics { + fn name(&self) -> &str { + self.stream.name() + } + + fn schema(&self) -> datatypes::schema::SchemaRef { + self.stream.schema() + } + + fn output_ordering(&self) -> Option<&[OrderOption]> { + self.stream.output_ordering() + } + + fn metrics(&self) -> Option { + self.sync_terminal_metrics(); + self.metrics.get() + } +} + +impl Stream for StreamWithMetrics { + type Item = common_recordbatch::error::Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let polled = Pin::new(&mut self.stream).poll_next(cx); + match &polled { + Poll::Ready(Some(_)) => self.sync_terminal_metrics(), + Poll::Ready(None) => { + self.sync_terminal_metrics(); + self.mark_ready_if_terminated(); + } + Poll::Pending => {} + } + polled + } + + fn size_hint(&self) -> (usize, Option) { + self.stream.size_hint() + } +} + +fn attach_terminal_metrics(output: Output, terminal_metrics: &OutputMetrics) -> Output { + let Output { data, meta } = output; + let data = match data { + common_query::OutputData::Stream(stream) => { + terminal_metrics.update(stream.metrics()); + common_query::OutputData::Stream(Box::pin(StreamWithMetrics::new( + stream, + terminal_metrics.clone(), + ))) + } + other => { + terminal_metrics.mark_ready(); + other + } + }; + Output::new(data, meta) +} + #[derive(Clone, Debug, Default)] pub struct Database { // The "catalog" and "schema" to be used in processing the requests at the server side. @@ -335,15 +479,46 @@ impl Database { let request = Request::Query(QueryRequest { query: Some(Query::Sql(sql.as_ref().to_string())), }); - self.do_get(request, hints).await + self.do_get(request, hints) + .await + .map(OutputWithMetrics::into_output) + } + + pub async fn sql_with_terminal_metrics( + &self, + sql: S, + hints: &[(&str, &str)], + ) -> Result + where + S: AsRef, + { + self.query_with_terminal_metrics( + QueryRequest { + query: Some(Query::Sql(sql.as_ref().to_string())), + }, + hints, + ) + .await } /// Executes a logical plan directly without SQL parsing. pub async fn logical_plan(&self, logical_plan: Vec) -> Result { - let request = Request::Query(QueryRequest { - query: Some(Query::LogicalPlan(logical_plan)), - }); - self.do_get(request, &[]).await + self.query_with_terminal_metrics( + QueryRequest { + query: Some(Query::LogicalPlan(logical_plan)), + }, + &[], + ) + .await + .map(OutputWithMetrics::into_output) + } + + pub async fn query_with_terminal_metrics( + &self, + request: QueryRequest, + hints: &[(&str, &str)], + ) -> Result { + self.do_get(Request::Query(request), hints).await } /// Creates a new table using the provided table expression. @@ -351,7 +526,9 @@ impl Database { let request = Request::Ddl(DdlRequest { expr: Some(DdlExpr::CreateTable(expr)), }); - self.do_get(request, &[]).await + self.do_get(request, &[]) + .await + .map(OutputWithMetrics::into_output) } /// Alters an existing table using the provided alter expression. @@ -359,10 +536,12 @@ impl Database { let request = Request::Ddl(DdlRequest { expr: Some(DdlExpr::AlterTable(expr)), }); - self.do_get(request, &[]).await + self.do_get(request, &[]) + .await + .map(OutputWithMetrics::into_output) } - async fn do_get(&self, request: Request, hints: &[(&str, &str)]) -> Result { + async fn do_get(&self, request: Request, hints: &[(&str, &str)]) -> Result { let request = self.to_rpc_request(request); let request = Ticket { ticket: request.encode_to_vec().into(), @@ -411,13 +590,35 @@ impl Database { match first_flight_message { FlightMessage::AffectedRows(rows) => { - ensure!( - flight_message_stream.next().await.is_none(), - IllegalFlightMessagesSnafu { - reason: "Expect 'AffectedRows' Flight messages to be the one and the only!" + let terminal_metrics = OutputMetrics::new(); + let next_message = flight_message_stream.next().await.transpose()?; + match next_message { + None => terminal_metrics.mark_ready(), + Some(FlightMessage::Metrics(s)) => { + terminal_metrics.update( + parse_terminal_metrics(&s).map(|metrics| metrics.as_ref().clone()), + ); + terminal_metrics.mark_ready(); + ensure!( + flight_message_stream.next().await.is_none(), + IllegalFlightMessagesSnafu { + reason: "Expect 'AffectedRows' Flight messages to be followed by at most one Metrics message" + } + ); } - ); - Ok(Output::new_with_affected_rows(rows)) + Some(other) => { + return IllegalFlightMessagesSnafu { + reason: format!( + "'AffectedRows' Flight message can only be followed by a Metrics message, got {other:?}" + ), + } + .fail(); + } + } + Ok(OutputWithMetrics { + output: Output::new_with_affected_rows(rows), + metrics: terminal_metrics, + }) } FlightMessage::RecordBatch(_) | FlightMessage::Metrics(_) => { IllegalFlightMessagesSnafu { @@ -465,8 +666,7 @@ impl Database { { match next_flight_message_result { Ok(FlightMessage::Metrics(s)) => { - let m: Option> = - serde_json::from_str(&s).ok().map(Arc::new); + let m = parse_terminal_metrics(&s); metrics_ref.swap(m); } Ok(FlightMessage::RecordBatch(rb)) => { @@ -491,8 +691,7 @@ impl Database { yield Ok(result_to_yield) } FlightMessage::Metrics(s) => { - let m: Option> = - serde_json::from_str(&s).ok().map(Arc::new); + let m = parse_terminal_metrics(&s); metrics_ref.swap(m); break; } @@ -513,7 +712,9 @@ impl Database { metrics, span: Span::current(), }; - Ok(Output::new_with_stream(Box::pin(record_batch_stream))) + Ok(OutputWithMetrics::from_output(Output::new_with_stream( + Box::pin(record_batch_stream), + ))) } } } @@ -567,15 +768,59 @@ struct FlightContext { #[cfg(test)] mod tests { use std::assert_matches::assert_matches; + use std::sync::Arc; + use std::task::{Context, Poll}; use api::v1::auth_header::AuthScheme; use api::v1::{AuthHeader, Basic}; use common_error::status_code::StatusCode; + use common_query::OutputData; + use common_recordbatch::{OrderOption, RecordBatch, RecordBatchStream}; + use datatypes::prelude::{ConcreteDataType, VectorRef}; + use datatypes::schema::{ColumnSchema, Schema}; + use datatypes::vectors::Int32Vector; + use futures_util::StreamExt; use tonic::{Code, Status}; use super::*; use crate::error::TonicSnafu; + struct MockMetricsStream { + schema: datatypes::schema::SchemaRef, + batch: Option, + metrics: RecordBatchMetrics, + terminal_metrics_only: bool, + } + + impl Stream for MockMetricsStream { + type Item = common_recordbatch::error::Result; + + fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(self.batch.take().map(Ok)) + } + } + + impl RecordBatchStream for MockMetricsStream { + fn name(&self) -> &str { + "MockMetricsStream" + } + + fn schema(&self) -> datatypes::schema::SchemaRef { + self.schema.clone() + } + + fn output_ordering(&self) -> Option<&[OrderOption]> { + None + } + + fn metrics(&self) -> Option { + if self.terminal_metrics_only && self.batch.is_some() { + return None; + } + Some(self.metrics.clone()) + } + } + #[test] fn test_flight_ctx() { let mut ctx = FlightContext::default(); @@ -612,4 +857,48 @@ mod tests { assert_eq!(expected.to_string(), actual.to_string()); } + + #[tokio::test] + async fn test_query_with_terminal_metrics_tracks_terminal_only_metrics() { + let schema = Arc::new(Schema::new(vec![ColumnSchema::new( + "v", + ConcreteDataType::int32_datatype(), + false, + )])); + let batch = RecordBatch::new( + schema.clone(), + vec![Arc::new(Int32Vector::from_slice([1, 2])) as VectorRef], + ) + .unwrap(); + let output = Output::new_with_stream(Box::pin(MockMetricsStream { + schema, + batch: Some(batch), + metrics: RecordBatchMetrics { + region_latest_sequences: Some(vec![(7, 42)]), + ..Default::default() + }, + terminal_metrics_only: true, + })); + + let result = OutputWithMetrics::from_output(output); + let terminal_metrics = result.metrics.clone(); + assert!(!terminal_metrics.is_ready()); + assert!(terminal_metrics.get().is_none()); + + let OutputData::Stream(mut stream) = result.output.data else { + panic!("expected stream output"); + }; + while stream.next().await.is_some() {} + + assert!(terminal_metrics.is_ready()); + assert_eq!( + terminal_metrics.region_watermark_map(), + Some(std::collections::HashMap::from([(7_u64, 42_u64)])) + ); + } + + #[test] + fn test_parse_terminal_metrics_rejects_invalid_json() { + assert!(parse_terminal_metrics("{not-json}").is_none()); + } } diff --git a/src/client/src/lib.rs b/src/client/src/lib.rs index bf383acff9..3d94ceeca2 100644 --- a/src/client/src/lib.rs +++ b/src/client/src/lib.rs @@ -34,7 +34,7 @@ pub use common_recordbatch::{RecordBatches, SendableRecordBatchStream}; use snafu::OptionExt; pub use self::client::Client; -pub use self::database::Database; +pub use self::database::{Database, OutputMetrics, OutputWithMetrics}; pub use self::error::{Error, Result}; use crate::error::{IllegalDatabaseResponseSnafu, ServerSnafu}; diff --git a/src/flow/src/batching_mode/frontend_client.rs b/src/flow/src/batching_mode/frontend_client.rs index baf22e7a20..916647c245 100644 --- a/src/flow/src/batching_mode/frontend_client.rs +++ b/src/flow/src/batching_mode/frontend_client.rs @@ -21,17 +21,18 @@ use std::time::SystemTime; use api::v1::greptime_request::Request; use api::v1::query_request::Query; use api::v1::{CreateTableExpr, QueryRequest}; -use client::{Client, Database}; +use client::{Client, Database, OutputWithMetrics}; use common_error::ext::BoxedError; use common_grpc::channel_manager::{ChannelConfig, ChannelManager, load_client_tls_config}; use common_meta::cluster::{NodeInfo, NodeInfoKey, Role}; use common_meta::peer::Peer; use common_meta::rpc::store::RangeRequest; -use common_query::Output; +use common_query::{Output, OutputData}; use common_telemetry::warn; use meta_client::client::MetaClient; use query::datafusion::QUERY_PARALLELISM_HINT; -use query::options::QueryOptions; +use query::metrics::terminal_recordbatch_metrics_from_plan; +use query::options::{FlowQueryExtensions, QueryOptions}; use rand::rng; use rand::seq::SliceRandom; use servers::query_handler::grpc::GrpcQueryHandler; @@ -407,6 +408,79 @@ impl FrontendClient { } } + pub async fn query_with_terminal_metrics( + &self, + catalog: &str, + schema: &str, + request: QueryRequest, + extensions: &[(&str, &str)], + ) -> Result { + let flow_extensions = build_flow_extensions(extensions)?; + match self { + FrontendClient::Distributed { + query, batch_opts, .. + } => { + let query_parallelism = query.parallelism.to_string(); + let mut hints = vec![ + (QUERY_PARALLELISM_HINT, query_parallelism.as_str()), + (READ_PREFERENCE_HINT, batch_opts.read_preference.as_ref()), + ]; + hints.extend_from_slice(extensions); + let db = self.get_random_active_frontend(catalog, schema).await?; + db.database + .query_with_terminal_metrics(request, &hints) + .await + .map_err(BoxedError::new) + .context(ExternalSnafu) + } + FrontendClient::Standalone { + database_client, + query, + } => { + let mut extensions_map = HashMap::from([( + QUERY_PARALLELISM_HINT.to_string(), + query.parallelism.to_string(), + )]); + for (key, value) in extensions { + extensions_map.insert((*key).to_string(), (*value).to_string()); + } + let ctx = QueryContextBuilder::default() + .current_catalog(catalog.to_string()) + .current_schema(schema.to_string()) + .extensions(extensions_map) + .build(); + let ctx = Arc::new(ctx); + let database_client = { + database_client + .handler + .lock() + .map_err(|e| { + UnexpectedSnafu { + reason: format!("Failed to lock database client: {e}"), + } + .build() + })? + .as_ref() + .context(UnexpectedSnafu { + reason: "Standalone's frontend instance is not set", + })? + .upgrade() + .context(UnexpectedSnafu { + reason: "Failed to upgrade database client", + })? + }; + database_client + .do_query(Request::Query(request), ctx) + .await + .map(|output| { + wrap_standalone_output_with_terminal_metrics(output, &flow_extensions) + }) + .map_err(BoxedError::new) + .context(ExternalSnafu) + } + } + } + /// Handle a request to frontend pub(crate) async fn handle( &self, @@ -497,6 +571,39 @@ impl FrontendClient { } } +fn build_flow_extensions(extensions: &[(&str, &str)]) -> Result { + let flow_extensions = HashMap::from_iter( + extensions + .iter() + .map(|(key, value)| ((*key).to_string(), (*value).to_string())), + ); + FlowQueryExtensions::from_extensions(&flow_extensions) + .map_err(BoxedError::new) + .context(ExternalSnafu) +} + +fn wrap_standalone_output_with_terminal_metrics( + output: Output, + flow_extensions: &FlowQueryExtensions, +) -> OutputWithMetrics { + let should_collect_region_watermark = flow_extensions.should_collect_region_watermark(); + let terminal_metrics = + if should_collect_region_watermark && !matches!(&output.data, OutputData::Stream(_)) { + output + .meta + .plan + .clone() + .and_then(terminal_recordbatch_metrics_from_plan) + } else { + None + }; + let result = OutputWithMetrics::from_output(output); + if let Some(metrics) = terminal_metrics { + result.metrics.update(Some(metrics)); + } + result +} + /// Describe a peer of frontend #[derive(Debug, Default)] pub(crate) enum PeerDesc { @@ -544,11 +651,19 @@ fn extract_u64_segment(message: &str, start: &str, end: &str) -> Option { #[cfg(test)] mod tests { + use std::pin::Pin; + use std::task::{Context, Poll}; use std::time::Duration; use common_error::ext::PlainError; use common_error::status_code::StatusCode; - use common_query::Output; + use common_query::{Output, OutputData}; + use common_recordbatch::adapter::RecordBatchMetrics; + use common_recordbatch::{OrderOption, RecordBatch, RecordBatchStream}; + use datatypes::prelude::{ConcreteDataType, VectorRef}; + use datatypes::schema::{ColumnSchema, Schema}; + use datatypes::vectors::Int32Vector; + use futures::StreamExt; use snafu::GenerateImplicitData; use tokio::time::timeout; @@ -557,6 +672,55 @@ mod tests { #[derive(Debug)] struct NoopHandler; + struct MockMetricsStream { + schema: datatypes::schema::SchemaRef, + batch: Option, + metrics: RecordBatchMetrics, + terminal_metrics_only: bool, + } + + impl futures::Stream for MockMetricsStream { + type Item = common_recordbatch::error::Result; + + fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(self.batch.take().map(Ok)) + } + + fn size_hint(&self) -> (usize, Option) { + ( + usize::from(self.batch.is_some()), + Some(usize::from(self.batch.is_some())), + ) + } + } + + impl RecordBatchStream for MockMetricsStream { + fn name(&self) -> &str { + "MockMetricsStream" + } + + fn schema(&self) -> datatypes::schema::SchemaRef { + self.schema.clone() + } + + fn output_ordering(&self) -> Option<&[OrderOption]> { + None + } + + fn metrics(&self) -> Option { + if self.terminal_metrics_only && self.batch.is_some() { + return None; + } + Some(self.metrics.clone()) + } + } + + #[derive(Debug)] + struct MetricsHandler; + + #[derive(Debug)] + struct ExtensionAwareHandler; + #[async_trait::async_trait] impl GrpcQueryHandlerWithBoxedError for NoopHandler { async fn do_query( @@ -568,6 +732,47 @@ mod tests { } } + #[async_trait::async_trait] + impl GrpcQueryHandlerWithBoxedError for MetricsHandler { + async fn do_query( + &self, + _query: Request, + _ctx: QueryContextRef, + ) -> std::result::Result { + let schema = Arc::new(Schema::new(vec![ColumnSchema::new( + "v", + ConcreteDataType::int32_datatype(), + false, + )])); + let batch = RecordBatch::new( + schema.clone(), + vec![Arc::new(Int32Vector::from_slice([1, 2])) as VectorRef], + ) + .unwrap(); + Ok(Output::new_with_stream(Box::pin(MockMetricsStream { + schema, + batch: Some(batch), + metrics: RecordBatchMetrics { + region_latest_sequences: Some(vec![(42, 99)]), + ..Default::default() + }, + terminal_metrics_only: true, + }))) + } + } + + #[async_trait::async_trait] + impl GrpcQueryHandlerWithBoxedError for ExtensionAwareHandler { + async fn do_query( + &self, + _query: Request, + ctx: QueryContextRef, + ) -> std::result::Result { + assert_eq!(ctx.extension("flow.return_region_seq"), Some("true")); + Ok(Output::new_with_affected_rows(1)) + } + } + #[tokio::test] async fn wait_initialized() { let (client, handler_mut) = @@ -650,4 +855,81 @@ mod tests { assert!(!failure.is_stale_cursor()); assert_eq!(failure.stale_cursor, None); } + + #[tokio::test] + async fn test_query_with_terminal_metrics_tracks_watermark_in_standalone_mode() { + let handler: Arc = Arc::new(MetricsHandler); + let client = + FrontendClient::from_grpc_handler(Arc::downgrade(&handler), QueryOptions::default()); + + let result = client + .query_with_terminal_metrics( + "greptime", + "public", + QueryRequest { + query: Some(Query::Sql("select 1".to_string())), + }, + &[], + ) + .await + .unwrap(); + + let terminal_metrics = result.metrics.clone(); + assert!(!result.metrics.is_ready()); + assert!(terminal_metrics.get().is_none()); + + let OutputData::Stream(mut stream) = result.output.data else { + panic!("expected stream output"); + }; + while stream.next().await.is_some() {} + + assert!(terminal_metrics.is_ready()); + assert_eq!( + terminal_metrics.region_watermark_map(), + Some(HashMap::from([(42_u64, 99_u64)])) + ); + } + + #[tokio::test] + async fn test_query_with_terminal_metrics_forwards_flow_extensions_in_standalone_mode() { + let handler: Arc = Arc::new(ExtensionAwareHandler); + let client = + FrontendClient::from_grpc_handler(Arc::downgrade(&handler), QueryOptions::default()); + + let result = client + .query_with_terminal_metrics( + "greptime", + "public", + QueryRequest { + query: Some(Query::Sql("insert into t select 1".to_string())), + }, + &[("flow.return_region_seq", "true")], + ) + .await + .unwrap(); + + assert!(result.metrics.is_ready()); + assert!(result.region_watermark_map().is_none()); + } + + #[tokio::test] + async fn test_query_with_terminal_metrics_rejects_invalid_flow_extensions() { + let handler: Arc = Arc::new(NoopHandler); + let client = + FrontendClient::from_grpc_handler(Arc::downgrade(&handler), QueryOptions::default()); + + let err = client + .query_with_terminal_metrics( + "greptime", + "public", + QueryRequest { + query: Some(Query::Sql("select 1".to_string())), + }, + &[("flow.return_region_seq", "not-a-bool")], + ) + .await + .unwrap_err(); + + assert!(format!("{err:?}").contains("Invalid value for flow.return_region_seq")); + } } diff --git a/src/query/src/datafusion.rs b/src/query/src/datafusion.rs index 79223d590c..425b19858a 100644 --- a/src/query/src/datafusion.rs +++ b/src/query/src/datafusion.rs @@ -132,7 +132,9 @@ impl DatafusionQueryEngine { let output = self .exec_query_plan((*dml.input).clone(), query_ctx.clone()) .await?; - let mut stream = match output.data { + let common_query::Output { data, meta } = output; + let input_plan = meta.plan; + let mut stream = match data { OutputData::RecordBatches(batches) => batches.as_stream(), OutputData::Stream(stream) => stream, _ => unreachable!(), @@ -168,7 +170,7 @@ impl DatafusionQueryEngine { } Ok(Output::new( OutputData::AffectedRows(affected_rows), - OutputMeta::new_with_cost(insert_cost), + OutputMeta::new(input_plan, insert_cost), )) } diff --git a/src/query/src/metrics.rs b/src/query/src/metrics.rs index a5ed020f38..8ce8a90b28 100644 --- a/src/query/src/metrics.rs +++ b/src/query/src/metrics.rs @@ -180,6 +180,20 @@ impl Stream for RegionWatermarkMetricsStream { } } +pub fn terminal_recordbatch_metrics_from_plan( + plan: Arc, +) -> Option { + let region_latest_sequences = collect_region_latest_sequences(plan); + if region_latest_sequences.is_empty() { + None + } else { + Some(RecordBatchMetrics { + region_latest_sequences: Some(region_latest_sequences), + ..Default::default() + }) + } +} + fn collect_region_latest_sequences(plan: Arc) -> Vec<(u64, u64)> { let mut merged = std::collections::HashMap::new(); let mut stack = vec![plan]; diff --git a/src/servers/src/grpc/flight.rs b/src/servers/src/grpc/flight.rs index 02755fcfd0..a01d7fc456 100644 --- a/src/servers/src/grpc/flight.rs +++ b/src/servers/src/grpc/flight.rs @@ -39,7 +39,9 @@ use datatypes::arrow::datatypes::SchemaRef; use futures::{Stream, future, ready}; use futures_util::{StreamExt, TryStreamExt}; use prost::Message; -use session::context::{QueryContext, QueryContextRef}; +use query::metrics::terminal_recordbatch_metrics_from_plan; +use query::options::FlowQueryExtensions; +use session::context::{Channel, QueryContextRef}; use snafu::{IntoError, ResultExt, ensure}; use table::table_name::TableName; use tokio::sync::mpsc; @@ -48,7 +50,9 @@ use tonic::{Request, Response, Status, Streaming}; use crate::error::{InvalidParameterSnafu, Result, ToJsonSnafu}; pub use crate::grpc::flight::stream::FlightRecordBatchStream; -use crate::grpc::greptime_handler::{GreptimeRequestHandler, get_request_type}; +use crate::grpc::greptime_handler::{ + GreptimeRequestHandler, create_query_context, get_request_type, +}; use crate::grpc::{FlightCompression, TonicResult, context_auth}; use crate::request_memory_limiter::ServerMemoryLimiter; use crate::request_memory_metrics::RequestMemoryMetrics; @@ -192,6 +196,10 @@ impl FlightCraft for GreptimeRequestHandler { let ticket = request.into_inner().ticket; let request = GreptimeRequest::decode(ticket.as_ref()).context(error::InvalidFlightTicketSnafu)?; + let query_ctx = + create_query_context(Channel::Grpc, request.header.as_ref(), hints.clone())?; + FlowQueryExtensions::from_extensions(&query_ctx.extensions()) + .map_err(|e| Status::invalid_argument(e.to_string()))?; // The Grpc protocol pass query by Flight. It needs to be wrapped under a span, in order to record stream let span = info_span!( @@ -206,7 +214,7 @@ impl FlightCraft for GreptimeRequestHandler { output, TracingContext::from_current_span(), flight_compression, - QueryContext::arc(), + query_ctx, ); Ok(Response::new(stream)) } @@ -539,12 +547,22 @@ fn to_flight_data_stream( Box::pin(stream) as _ } OutputData::AffectedRows(rows) => { - let stream = tokio_stream::iter( - FlightEncoder::default() - .encode(FlightMessage::AffectedRows(rows)) - .into_iter() - .map(Ok), - ); + let affected_rows = FlightEncoder::default().encode(FlightMessage::AffectedRows(rows)); + let should_emit_terminal_metrics = + FlowQueryExtensions::from_extensions(&query_ctx.extensions()) + .expect("flow extensions must be validated before Flight serialization") + .should_collect_region_watermark(); + let terminal_metrics = should_emit_terminal_metrics + .then_some(output.meta.plan) + .flatten() + .and_then(terminal_recordbatch_metrics_from_plan) + .and_then(|metrics| serde_json::to_string(&metrics).ok()) + .map(FlightMessage::Metrics) + .map(|message| FlightEncoder::default().encode(message)) + .into_iter() + .flatten(); + let stream = + tokio_stream::iter(affected_rows.into_iter().chain(terminal_metrics).map(Ok)); Box::pin(stream) as _ } } diff --git a/tests-integration/src/grpc/flight.rs b/tests-integration/src/grpc/flight.rs index ec00f0e526..952ff8a5c9 100644 --- a/tests-integration/src/grpc/flight.rs +++ b/tests-integration/src/grpc/flight.rs @@ -143,6 +143,23 @@ mod test { let metrics = stream.metrics().expect("expected terminal metrics"); assert!(metrics.region_latest_sequences.is_none()); + let result = client + .sql_with_terminal_metrics(sql, &[("flow.return_region_seq", "true")]) + .await + .unwrap(); + let terminal_metrics = result.metrics.clone(); + let OutputData::Stream(mut stream) = result.output.data else { + panic!("expected stream output"); + }; + while let Some(batch) = stream.next().await { + batch.unwrap(); + } + assert!( + terminal_metrics + .region_watermark_map() + .is_some_and(|m| !m.is_empty()) + ); + let output = client .sql_with_hint(sql, &[("flow.return_region_seq", "true")]) .await @@ -170,6 +187,43 @@ mod test { let previous_watermark = region_latest_sequences[0]; + create_table_named(&client, "bar").await; + let result = client + .sql_with_terminal_metrics("insert into bar select ts, a, `B` from foo", &[]) + .await + .unwrap(); + let OutputData::AffectedRows(affected_rows) = result.output.data else { + panic!("expected affected rows output"); + }; + assert_eq!(affected_rows, 9); + assert!(result.metrics.is_ready()); + assert!(result.region_watermark_map().is_none()); + + let err = client + .sql_with_terminal_metrics( + "insert into bar select ts, a, `B` from foo", + &[("flow.return_region_seq", "not-a-bool")], + ) + .await + .unwrap_err(); + let err_msg = format!("{err:?}"); + assert!(err_msg.contains("Invalid value for flow.return_region_seq")); + + client.sql("truncate table bar").await.unwrap(); + + let result = client + .sql_with_terminal_metrics( + "insert into bar select ts, a, `B` from foo", + &[("flow.return_region_seq", "true")], + ) + .await + .unwrap(); + let OutputData::AffectedRows(affected_rows) = result.output.data else { + panic!("expected affected rows output"); + }; + assert_eq!(affected_rows, 9); + assert!(result.region_watermark_map().is_some_and(|m| !m.is_empty())); + let incremental_batches = create_record_batches(10); test_put_record_batches(&client, incremental_batches).await; @@ -362,6 +416,10 @@ mod test { } async fn create_table(client: &Database) { + create_table_named(client, "foo").await; + } + + async fn create_table_named(client: &Database, table_name: &str) { // create table foo ( // ts timestamp time index, // a int primary key, @@ -370,7 +428,7 @@ mod test { let output = client .create(CreateTableExpr { schema_name: "public".to_string(), - table_name: "foo".to_string(), + table_name: table_name.to_string(), column_defs: vec![ ColumnDef { name: "ts".to_string(),