feat: flow terminal metrics handling

Signed-off-by: discord9 <discord9@163.com>
discord9
2026-03-17 21:48:53 +08:00
parent bc768617fb
commit 8376150c81
7 changed files with 700 additions and 37 deletions


@@ -15,6 +15,8 @@
use std::pin::Pin;
use std::str::FromStr;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::task::{Context, Poll};
use api::v1::auth_header::AuthScheme;
use api::v1::ddl_request::Expr as DdlExpr;
@@ -38,7 +40,7 @@ use common_grpc::flight::{FlightDecoder, FlightMessage};
use common_query::Output;
use common_recordbatch::adapter::RecordBatchMetrics;
use common_recordbatch::error::ExternalSnafu;
use common_recordbatch::{RecordBatch, RecordBatchStreamWrapper};
use common_recordbatch::{OrderOption, RecordBatch, RecordBatchStream, RecordBatchStreamWrapper};
use common_telemetry::tracing::Span;
use common_telemetry::tracing_context::W3cTrace;
use common_telemetry::{error, warn};
@@ -59,6 +61,148 @@ type FlightDataStream = Pin<Box<dyn Stream<Item = FlightData> + Send>>;
type DoPutResponseStream = Pin<Box<dyn Stream<Item = Result<DoPutResponse>>>>;
/// Shared slot for a query's terminal metrics: `metrics` holds the latest
/// snapshot, and `ready` flips once the underlying stream has terminated and
/// the snapshot is final.
#[derive(Debug, Clone, Default)]
pub struct OutputMetrics {
metrics: Arc<ArcSwapOption<RecordBatchMetrics>>,
ready: Arc<AtomicBool>,
}
impl OutputMetrics {
fn new() -> Self {
Self::default()
}
pub fn update(&self, metrics: Option<RecordBatchMetrics>) {
self.metrics.swap(metrics.map(Arc::new));
}
pub fn mark_ready(&self) {
self.ready.store(true, Ordering::Release);
}
pub fn is_ready(&self) -> bool {
self.ready.load(Ordering::Acquire)
}
pub fn get(&self) -> Option<RecordBatchMetrics> {
self.metrics.load().as_ref().map(|m| m.as_ref().clone())
}
pub fn region_watermark_map(&self) -> Option<std::collections::HashMap<u64, u64>> {
self.get()?.region_latest_sequences.map(|sequences| {
sequences
.into_iter()
.collect::<std::collections::HashMap<_, _>>()
})
}
}
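For orientation, the handshake `OutputMetrics` builds on can be reduced to a standalone sketch: an `arc_swap::ArcSwapOption` holds the latest snapshot while an `AtomicBool` publishes finality, the writer's `Release` store pairing with the reader's `Acquire` load. `TerminalSlot` and `SimpleMetrics` below are illustrative stand-ins (assuming only the `arc_swap` crate), not types from this commit.

```rust
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};

use arc_swap::ArcSwapOption;

/// Illustrative stand-in for `RecordBatchMetrics`.
#[derive(Debug, Clone)]
struct SimpleMetrics {
    region_latest_sequences: Vec<(u64, u64)>,
}

#[derive(Default)]
struct TerminalSlot {
    metrics: ArcSwapOption<SimpleMetrics>,
    ready: AtomicBool,
}

impl TerminalSlot {
    /// Writer side: overwrite the snapshot with the latest observation.
    fn update(&self, metrics: Option<SimpleMetrics>) {
        self.metrics.swap(metrics.map(Arc::new));
    }

    /// Writer side: publish that the snapshot is now final.
    fn mark_ready(&self) {
        // `Release` pairs with the reader's `Acquire` load below.
        self.ready.store(true, Ordering::Release);
    }

    /// Reader side: only trust the snapshot once readiness is observed.
    fn get_if_ready(&self) -> Option<SimpleMetrics> {
        if !self.ready.load(Ordering::Acquire) {
            return None;
        }
        self.metrics.load().as_ref().map(|m| m.as_ref().clone())
    }
}

fn main() {
    let slot = TerminalSlot::default();
    assert!(slot.get_if_ready().is_none());
    slot.update(Some(SimpleMetrics {
        region_latest_sequences: vec![(7, 42)],
    }));
    slot.mark_ready();
    assert_eq!(
        slot.get_if_ready().unwrap().region_latest_sequences,
        vec![(7, 42)]
    );
}
```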
#[derive(Debug)]
pub struct OutputWithMetrics {
pub output: Output,
pub metrics: OutputMetrics,
}
impl OutputWithMetrics {
pub fn from_output(output: Output) -> Self {
let terminal_metrics = OutputMetrics::new();
let output = attach_terminal_metrics(output, &terminal_metrics);
Self {
output,
metrics: terminal_metrics,
}
}
pub fn region_watermark_map(&self) -> Option<std::collections::HashMap<u64, u64>> {
self.metrics.region_watermark_map()
}
pub fn into_output(self) -> Output {
self.output
}
}
/// Parses the JSON payload of a terminal `Metrics` Flight message; invalid
/// JSON is tolerated and treated as absent metrics.
fn parse_terminal_metrics(metrics_json: &str) -> Option<Arc<RecordBatchMetrics>> {
serde_json::from_str(metrics_json).ok().map(Arc::new)
}
struct StreamWithMetrics {
stream: common_recordbatch::SendableRecordBatchStream,
metrics: OutputMetrics,
}
impl StreamWithMetrics {
fn new(stream: common_recordbatch::SendableRecordBatchStream, metrics: OutputMetrics) -> Self {
Self { stream, metrics }
}
fn sync_terminal_metrics(&self) {
self.metrics.update(self.stream.metrics());
}
/// Invoked once the inner stream has terminated; from that point on the
/// published metrics snapshot is final.
fn mark_ready_if_terminated(&self) {
self.metrics.mark_ready();
}
}
impl RecordBatchStream for StreamWithMetrics {
fn name(&self) -> &str {
self.stream.name()
}
fn schema(&self) -> datatypes::schema::SchemaRef {
self.stream.schema()
}
fn output_ordering(&self) -> Option<&[OrderOption]> {
self.stream.output_ordering()
}
fn metrics(&self) -> Option<RecordBatchMetrics> {
self.sync_terminal_metrics();
self.metrics.get()
}
}
impl Stream for StreamWithMetrics {
type Item = common_recordbatch::error::Result<RecordBatch>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let polled = Pin::new(&mut self.stream).poll_next(cx);
match &polled {
Poll::Ready(Some(_)) => self.sync_terminal_metrics(),
Poll::Ready(None) => {
self.sync_terminal_metrics();
self.mark_ready_if_terminated();
}
Poll::Pending => {}
}
polled
}
fn size_hint(&self) -> (usize, Option<usize>) {
self.stream.size_hint()
}
}
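`StreamWithMetrics` follows a general adapter shape: forward `poll_next` to an inner stream, run a side effect per item, and fire a one-shot hook at termination. A minimal generic sketch of that shape; `Observed` and its field names are illustrative, not from this commit.

```rust
use std::pin::Pin;
use std::task::{Context, Poll};

use futures::{Stream, StreamExt};

/// Forwards to `inner`, running `on_item` per element and `on_done`
/// exactly once when the inner stream terminates.
struct Observed<S, F, G> {
    inner: S,
    on_item: F,
    on_done: Option<G>,
}

impl<S, F, G> Stream for Observed<S, F, G>
where
    S: Stream + Unpin,
    F: FnMut(&S::Item) + Unpin,
    G: FnOnce() + Unpin,
{
    type Item = S::Item;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        match Pin::new(&mut this.inner).poll_next(cx) {
            Poll::Ready(Some(item)) => {
                (this.on_item)(&item);
                Poll::Ready(Some(item))
            }
            Poll::Ready(None) => {
                // Terminal position: fire the completion hook once.
                if let Some(done) = this.on_done.take() {
                    done();
                }
                Poll::Ready(None)
            }
            Poll::Pending => Poll::Pending,
        }
    }
}

fn main() {
    futures::executor::block_on(async {
        let mut seen = 0;
        let observed = Observed {
            inner: futures::stream::iter([1, 2, 3]),
            on_item: |_: &i32| seen += 1,
            on_done: Some(|| println!("stream terminated; metrics would be final here")),
        };
        let items: Vec<i32> = observed.collect().await;
        assert_eq!(items, vec![1, 2, 3]);
        assert_eq!(seen, 3);
    });
}
```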
fn attach_terminal_metrics(output: Output, terminal_metrics: &OutputMetrics) -> Output {
let Output { data, meta } = output;
let data = match data {
common_query::OutputData::Stream(stream) => {
terminal_metrics.update(stream.metrics());
common_query::OutputData::Stream(Box::pin(StreamWithMetrics::new(
stream,
terminal_metrics.clone(),
)))
}
other => {
terminal_metrics.mark_ready();
other
}
};
Output::new(data, meta)
}
#[derive(Clone, Debug, Default)]
pub struct Database {
// The "catalog" and "schema" to be used in processing the requests at the server side.
@@ -335,15 +479,46 @@ impl Database {
let request = Request::Query(QueryRequest {
query: Some(Query::Sql(sql.as_ref().to_string())),
});
self.do_get(request, hints).await
self.do_get(request, hints)
.await
.map(OutputWithMetrics::into_output)
}
pub async fn sql_with_terminal_metrics<S>(
&self,
sql: S,
hints: &[(&str, &str)],
) -> Result<OutputWithMetrics>
where
S: AsRef<str>,
{
self.query_with_terminal_metrics(
QueryRequest {
query: Some(Query::Sql(sql.as_ref().to_string())),
},
hints,
)
.await
}
/// Executes a logical plan directly without SQL parsing.
pub async fn logical_plan(&self, logical_plan: Vec<u8>) -> Result<Output> {
let request = Request::Query(QueryRequest {
query: Some(Query::LogicalPlan(logical_plan)),
});
self.do_get(request, &[]).await
self.query_with_terminal_metrics(
QueryRequest {
query: Some(Query::LogicalPlan(logical_plan)),
},
&[],
)
.await
.map(OutputWithMetrics::into_output)
}
pub async fn query_with_terminal_metrics(
&self,
request: QueryRequest,
hints: &[(&str, &str)],
) -> Result<OutputWithMetrics> {
self.do_get(Request::Query(request), hints).await
}
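A hedged end-to-end sketch of the new client API, built only from items this diff introduces (`sql_with_terminal_metrics`, `OutputWithMetrics`, `region_watermark_map`) plus the `flow.return_region_seq` hint used in the tests below; it assumes an already-connected `Database` handle and swallows errors for brevity.

```rust
use std::collections::HashMap;

use client::Database;
use common_query::OutputData;
use futures_util::StreamExt;

/// Run a query with the region-watermark hint, drain the stream, then
/// read the terminal watermark map. Sketch only; errors become `None`.
async fn region_watermarks(db: &Database, sql: &str) -> Option<HashMap<u64, u64>> {
    let result = db
        .sql_with_terminal_metrics(sql, &[("flow.return_region_seq", "true")])
        .await
        .ok()?;
    let metrics = result.metrics.clone();
    if let OutputData::Stream(mut stream) = result.output.data {
        // The watermark map is only final once the stream terminates.
        while let Some(batch) = stream.next().await {
            batch.ok()?;
        }
    }
    debug_assert!(metrics.is_ready());
    metrics.region_watermark_map()
}
```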
/// Creates a new table using the provided table expression.
@@ -351,7 +526,9 @@ impl Database {
let request = Request::Ddl(DdlRequest {
expr: Some(DdlExpr::CreateTable(expr)),
});
self.do_get(request, &[]).await
self.do_get(request, &[])
.await
.map(OutputWithMetrics::into_output)
}
/// Alters an existing table using the provided alter expression.
@@ -359,10 +536,12 @@ impl Database {
let request = Request::Ddl(DdlRequest {
expr: Some(DdlExpr::AlterTable(expr)),
});
self.do_get(request, &[]).await
self.do_get(request, &[])
.await
.map(OutputWithMetrics::into_output)
}
async fn do_get(&self, request: Request, hints: &[(&str, &str)]) -> Result<Output> {
async fn do_get(&self, request: Request, hints: &[(&str, &str)]) -> Result<OutputWithMetrics> {
let request = self.to_rpc_request(request);
let request = Ticket {
ticket: request.encode_to_vec().into(),
@@ -411,13 +590,35 @@ impl Database {
match first_flight_message {
FlightMessage::AffectedRows(rows) => {
ensure!(
flight_message_stream.next().await.is_none(),
IllegalFlightMessagesSnafu {
reason: "Expect 'AffectedRows' Flight messages to be the one and the only!"
let terminal_metrics = OutputMetrics::new();
let next_message = flight_message_stream.next().await.transpose()?;
match next_message {
None => terminal_metrics.mark_ready(),
Some(FlightMessage::Metrics(s)) => {
terminal_metrics.update(
parse_terminal_metrics(&s).map(|metrics| metrics.as_ref().clone()),
);
terminal_metrics.mark_ready();
ensure!(
flight_message_stream.next().await.is_none(),
IllegalFlightMessagesSnafu {
reason: "Expect 'AffectedRows' Flight messages to be followed by at most one Metrics message"
}
);
}
);
Ok(Output::new_with_affected_rows(rows))
Some(other) => {
return IllegalFlightMessagesSnafu {
reason: format!(
"'AffectedRows' Flight message can only be followed by a Metrics message, got {other:?}"
),
}
.fail();
}
}
Ok(OutputWithMetrics {
output: Output::new_with_affected_rows(rows),
metrics: terminal_metrics,
})
}
FlightMessage::RecordBatch(_) | FlightMessage::Metrics(_) => {
IllegalFlightMessagesSnafu {
@@ -465,8 +666,7 @@ impl Database {
{
match next_flight_message_result {
Ok(FlightMessage::Metrics(s)) => {
let m: Option<Arc<RecordBatchMetrics>> =
serde_json::from_str(&s).ok().map(Arc::new);
let m = parse_terminal_metrics(&s);
metrics_ref.swap(m);
}
Ok(FlightMessage::RecordBatch(rb)) => {
@@ -491,8 +691,7 @@ impl Database {
yield Ok(result_to_yield)
}
FlightMessage::Metrics(s) => {
let m: Option<Arc<RecordBatchMetrics>> =
serde_json::from_str(&s).ok().map(Arc::new);
let m = parse_terminal_metrics(&s);
metrics_ref.swap(m);
break;
}
@@ -513,7 +712,9 @@ impl Database {
metrics,
span: Span::current(),
};
Ok(Output::new_with_stream(Box::pin(record_batch_stream)))
Ok(OutputWithMetrics::from_output(Output::new_with_stream(
Box::pin(record_batch_stream),
)))
}
}
}
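The sequencing rule `do_get` now enforces after an `AffectedRows` message (either end-of-stream, or exactly one `Metrics` message then end-of-stream) can be checked in isolation. `Msg` is an illustrative stand-in for `FlightMessage`:

```rust
/// Illustrative stand-in for the Flight messages involved.
#[derive(Debug)]
enum Msg {
    AffectedRows(usize),
    Metrics(String),
}

/// Accepts `AffectedRows` followed by nothing, or by exactly one
/// `Metrics` message; anything else is a protocol violation.
fn parse_affected_rows(
    mut msgs: impl Iterator<Item = Msg>,
) -> Result<(usize, Option<String>), String> {
    let rows = match msgs.next() {
        Some(Msg::AffectedRows(rows)) => rows,
        other => return Err(format!("expected AffectedRows, got {other:?}")),
    };
    let metrics = match msgs.next() {
        None => None,
        Some(Msg::Metrics(json)) => {
            if msgs.next().is_some() {
                return Err("at most one Metrics message may follow".to_string());
            }
            Some(json)
        }
        Some(other) => {
            return Err(format!(
                "AffectedRows may only be followed by Metrics, got {other:?}"
            ));
        }
    };
    Ok((rows, metrics))
}

fn main() {
    let ok = parse_affected_rows([Msg::AffectedRows(9), Msg::Metrics("{}".into())].into_iter());
    assert!(matches!(ok, Ok((9, Some(_)))));

    let err = parse_affected_rows([Msg::AffectedRows(9), Msg::AffectedRows(1)].into_iter());
    assert!(err.is_err());
}
```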
@@ -567,15 +768,59 @@ struct FlightContext {
#[cfg(test)]
mod tests {
use std::assert_matches::assert_matches;
use std::sync::Arc;
use std::task::{Context, Poll};
use api::v1::auth_header::AuthScheme;
use api::v1::{AuthHeader, Basic};
use common_error::status_code::StatusCode;
use common_query::OutputData;
use common_recordbatch::{OrderOption, RecordBatch, RecordBatchStream};
use datatypes::prelude::{ConcreteDataType, VectorRef};
use datatypes::schema::{ColumnSchema, Schema};
use datatypes::vectors::Int32Vector;
use futures_util::StreamExt;
use tonic::{Code, Status};
use super::*;
use crate::error::TonicSnafu;
struct MockMetricsStream {
schema: datatypes::schema::SchemaRef,
batch: Option<RecordBatch>,
metrics: RecordBatchMetrics,
terminal_metrics_only: bool,
}
impl Stream for MockMetricsStream {
type Item = common_recordbatch::error::Result<RecordBatch>;
fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
Poll::Ready(self.batch.take().map(Ok))
}
}
impl RecordBatchStream for MockMetricsStream {
fn name(&self) -> &str {
"MockMetricsStream"
}
fn schema(&self) -> datatypes::schema::SchemaRef {
self.schema.clone()
}
fn output_ordering(&self) -> Option<&[OrderOption]> {
None
}
fn metrics(&self) -> Option<RecordBatchMetrics> {
if self.terminal_metrics_only && self.batch.is_some() {
return None;
}
Some(self.metrics.clone())
}
}
#[test]
fn test_flight_ctx() {
let mut ctx = FlightContext::default();
@@ -612,4 +857,48 @@ mod tests {
assert_eq!(expected.to_string(), actual.to_string());
}
#[tokio::test]
async fn test_query_with_terminal_metrics_tracks_terminal_only_metrics() {
let schema = Arc::new(Schema::new(vec![ColumnSchema::new(
"v",
ConcreteDataType::int32_datatype(),
false,
)]));
let batch = RecordBatch::new(
schema.clone(),
vec![Arc::new(Int32Vector::from_slice([1, 2])) as VectorRef],
)
.unwrap();
let output = Output::new_with_stream(Box::pin(MockMetricsStream {
schema,
batch: Some(batch),
metrics: RecordBatchMetrics {
region_latest_sequences: Some(vec![(7, 42)]),
..Default::default()
},
terminal_metrics_only: true,
}));
let result = OutputWithMetrics::from_output(output);
let terminal_metrics = result.metrics.clone();
assert!(!terminal_metrics.is_ready());
assert!(terminal_metrics.get().is_none());
let OutputData::Stream(mut stream) = result.output.data else {
panic!("expected stream output");
};
while stream.next().await.is_some() {}
assert!(terminal_metrics.is_ready());
assert_eq!(
terminal_metrics.region_watermark_map(),
Some(std::collections::HashMap::from([(7_u64, 42_u64)]))
);
}
#[test]
fn test_parse_terminal_metrics_rejects_invalid_json() {
assert!(parse_terminal_metrics("{not-json}").is_none());
}
}


@@ -34,7 +34,7 @@ pub use common_recordbatch::{RecordBatches, SendableRecordBatchStream};
use snafu::OptionExt;
pub use self::client::Client;
pub use self::database::Database;
pub use self::database::{Database, OutputMetrics, OutputWithMetrics};
pub use self::error::{Error, Result};
use crate::error::{IllegalDatabaseResponseSnafu, ServerSnafu};


@@ -21,17 +21,18 @@ use std::time::SystemTime;
use api::v1::greptime_request::Request;
use api::v1::query_request::Query;
use api::v1::{CreateTableExpr, QueryRequest};
use client::{Client, Database};
use client::{Client, Database, OutputWithMetrics};
use common_error::ext::BoxedError;
use common_grpc::channel_manager::{ChannelConfig, ChannelManager, load_client_tls_config};
use common_meta::cluster::{NodeInfo, NodeInfoKey, Role};
use common_meta::peer::Peer;
use common_meta::rpc::store::RangeRequest;
use common_query::Output;
use common_query::{Output, OutputData};
use common_telemetry::warn;
use meta_client::client::MetaClient;
use query::datafusion::QUERY_PARALLELISM_HINT;
use query::options::QueryOptions;
use query::metrics::terminal_recordbatch_metrics_from_plan;
use query::options::{FlowQueryExtensions, QueryOptions};
use rand::rng;
use rand::seq::SliceRandom;
use servers::query_handler::grpc::GrpcQueryHandler;
@@ -407,6 +408,79 @@ impl FrontendClient {
}
}
pub async fn query_with_terminal_metrics(
&self,
catalog: &str,
schema: &str,
request: QueryRequest,
extensions: &[(&str, &str)],
) -> Result<OutputWithMetrics, Error> {
let flow_extensions = build_flow_extensions(extensions)?;
match self {
FrontendClient::Distributed {
query, batch_opts, ..
} => {
let query_parallelism = query.parallelism.to_string();
let mut hints = vec![
(QUERY_PARALLELISM_HINT, query_parallelism.as_str()),
(READ_PREFERENCE_HINT, batch_opts.read_preference.as_ref()),
];
hints.extend_from_slice(extensions);
let db = self.get_random_active_frontend(catalog, schema).await?;
db.database
.query_with_terminal_metrics(request, &hints)
.await
.map_err(BoxedError::new)
.context(ExternalSnafu)
}
FrontendClient::Standalone {
database_client,
query,
} => {
let mut extensions_map = HashMap::from([(
QUERY_PARALLELISM_HINT.to_string(),
query.parallelism.to_string(),
)]);
for (key, value) in extensions {
extensions_map.insert((*key).to_string(), (*value).to_string());
}
let ctx = QueryContextBuilder::default()
.current_catalog(catalog.to_string())
.current_schema(schema.to_string())
.extensions(extensions_map)
.build();
let ctx = Arc::new(ctx);
let database_client = {
database_client
.handler
.lock()
.map_err(|e| {
UnexpectedSnafu {
reason: format!("Failed to lock database client: {e}"),
}
.build()
})?
.as_ref()
.context(UnexpectedSnafu {
reason: "Standalone's frontend instance is not set",
})?
.upgrade()
.context(UnexpectedSnafu {
reason: "Failed to upgrade database client",
})?
};
database_client
.do_query(Request::Query(request), ctx)
.await
.map(|output| {
wrap_standalone_output_with_terminal_metrics(output, &flow_extensions)
})
.map_err(BoxedError::new)
.context(ExternalSnafu)
}
}
}
/// Handle a request to frontend
pub(crate) async fn handle(
&self,
@@ -497,6 +571,39 @@ impl FrontendClient {
}
}
fn build_flow_extensions(extensions: &[(&str, &str)]) -> Result<FlowQueryExtensions, Error> {
let flow_extensions = HashMap::from_iter(
extensions
.iter()
.map(|(key, value)| ((*key).to_string(), (*value).to_string())),
);
FlowQueryExtensions::from_extensions(&flow_extensions)
.map_err(BoxedError::new)
.context(ExternalSnafu)
}
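`build_flow_extensions` defers the actual validation to `FlowQueryExtensions::from_extensions`. A standalone sketch of the typed parsing involved; `FlowExts` is an illustrative stand-in, with the `flow.return_region_seq` key and the "Invalid value" error text taken from the tests in this diff (the real struct may parse more fields).

```rust
use std::collections::HashMap;

/// Illustrative stand-in for `FlowQueryExtensions`.
#[derive(Debug, Default)]
struct FlowExts {
    return_region_seq: bool,
}

impl FlowExts {
    /// Parse the hint into a typed flag, rejecting non-boolean values.
    fn from_extensions(exts: &HashMap<String, String>) -> Result<Self, String> {
        let mut out = Self::default();
        if let Some(raw) = exts.get("flow.return_region_seq") {
            out.return_region_seq = raw
                .parse::<bool>()
                .map_err(|_| format!("Invalid value for flow.return_region_seq: {raw}"))?;
        }
        Ok(out)
    }
}

fn main() {
    let good = HashMap::from([("flow.return_region_seq".to_string(), "true".to_string())]);
    assert!(FlowExts::from_extensions(&good).unwrap().return_region_seq);

    let bad = HashMap::from([("flow.return_region_seq".to_string(), "not-a-bool".to_string())]);
    assert!(FlowExts::from_extensions(&bad).is_err());
}
```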
fn wrap_standalone_output_with_terminal_metrics(
output: Output,
flow_extensions: &FlowQueryExtensions,
) -> OutputWithMetrics {
let should_collect_region_watermark = flow_extensions.should_collect_region_watermark();
let terminal_metrics =
if should_collect_region_watermark && !matches!(&output.data, OutputData::Stream(_)) {
output
.meta
.plan
.clone()
.and_then(terminal_recordbatch_metrics_from_plan)
} else {
None
};
let result = OutputWithMetrics::from_output(output);
if let Some(metrics) = terminal_metrics {
result.metrics.update(Some(metrics));
}
result
}
/// Describe a peer of frontend
#[derive(Debug, Default)]
pub(crate) enum PeerDesc {
@@ -544,11 +651,19 @@ fn extract_u64_segment(message: &str, start: &str, end: &str) -> Option<u64> {
#[cfg(test)]
mod tests {
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;
use common_error::ext::PlainError;
use common_error::status_code::StatusCode;
use common_query::Output;
use common_query::{Output, OutputData};
use common_recordbatch::adapter::RecordBatchMetrics;
use common_recordbatch::{OrderOption, RecordBatch, RecordBatchStream};
use datatypes::prelude::{ConcreteDataType, VectorRef};
use datatypes::schema::{ColumnSchema, Schema};
use datatypes::vectors::Int32Vector;
use futures::StreamExt;
use snafu::GenerateImplicitData;
use tokio::time::timeout;
@@ -557,6 +672,55 @@ mod tests {
#[derive(Debug)]
struct NoopHandler;
struct MockMetricsStream {
schema: datatypes::schema::SchemaRef,
batch: Option<RecordBatch>,
metrics: RecordBatchMetrics,
terminal_metrics_only: bool,
}
impl futures::Stream for MockMetricsStream {
type Item = common_recordbatch::error::Result<RecordBatch>;
fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
Poll::Ready(self.batch.take().map(Ok))
}
fn size_hint(&self) -> (usize, Option<usize>) {
(
usize::from(self.batch.is_some()),
Some(usize::from(self.batch.is_some())),
)
}
}
impl RecordBatchStream for MockMetricsStream {
fn name(&self) -> &str {
"MockMetricsStream"
}
fn schema(&self) -> datatypes::schema::SchemaRef {
self.schema.clone()
}
fn output_ordering(&self) -> Option<&[OrderOption]> {
None
}
fn metrics(&self) -> Option<RecordBatchMetrics> {
if self.terminal_metrics_only && self.batch.is_some() {
return None;
}
Some(self.metrics.clone())
}
}
#[derive(Debug)]
struct MetricsHandler;
#[derive(Debug)]
struct ExtensionAwareHandler;
#[async_trait::async_trait]
impl GrpcQueryHandlerWithBoxedError for NoopHandler {
async fn do_query(
@@ -568,6 +732,47 @@ mod tests {
}
}
#[async_trait::async_trait]
impl GrpcQueryHandlerWithBoxedError for MetricsHandler {
async fn do_query(
&self,
_query: Request,
_ctx: QueryContextRef,
) -> std::result::Result<Output, BoxedError> {
let schema = Arc::new(Schema::new(vec![ColumnSchema::new(
"v",
ConcreteDataType::int32_datatype(),
false,
)]));
let batch = RecordBatch::new(
schema.clone(),
vec![Arc::new(Int32Vector::from_slice([1, 2])) as VectorRef],
)
.unwrap();
Ok(Output::new_with_stream(Box::pin(MockMetricsStream {
schema,
batch: Some(batch),
metrics: RecordBatchMetrics {
region_latest_sequences: Some(vec![(42, 99)]),
..Default::default()
},
terminal_metrics_only: true,
})))
}
}
#[async_trait::async_trait]
impl GrpcQueryHandlerWithBoxedError for ExtensionAwareHandler {
async fn do_query(
&self,
_query: Request,
ctx: QueryContextRef,
) -> std::result::Result<Output, BoxedError> {
assert_eq!(ctx.extension("flow.return_region_seq"), Some("true"));
Ok(Output::new_with_affected_rows(1))
}
}
#[tokio::test]
async fn wait_initialized() {
let (client, handler_mut) =
@@ -650,4 +855,81 @@ mod tests {
assert!(!failure.is_stale_cursor());
assert_eq!(failure.stale_cursor, None);
}
#[tokio::test]
async fn test_query_with_terminal_metrics_tracks_watermark_in_standalone_mode() {
let handler: Arc<dyn GrpcQueryHandlerWithBoxedError> = Arc::new(MetricsHandler);
let client =
FrontendClient::from_grpc_handler(Arc::downgrade(&handler), QueryOptions::default());
let result = client
.query_with_terminal_metrics(
"greptime",
"public",
QueryRequest {
query: Some(Query::Sql("select 1".to_string())),
},
&[],
)
.await
.unwrap();
let terminal_metrics = result.metrics.clone();
assert!(!result.metrics.is_ready());
assert!(terminal_metrics.get().is_none());
let OutputData::Stream(mut stream) = result.output.data else {
panic!("expected stream output");
};
while stream.next().await.is_some() {}
assert!(terminal_metrics.is_ready());
assert_eq!(
terminal_metrics.region_watermark_map(),
Some(HashMap::from([(42_u64, 99_u64)]))
);
}
#[tokio::test]
async fn test_query_with_terminal_metrics_forwards_flow_extensions_in_standalone_mode() {
let handler: Arc<dyn GrpcQueryHandlerWithBoxedError> = Arc::new(ExtensionAwareHandler);
let client =
FrontendClient::from_grpc_handler(Arc::downgrade(&handler), QueryOptions::default());
let result = client
.query_with_terminal_metrics(
"greptime",
"public",
QueryRequest {
query: Some(Query::Sql("insert into t select 1".to_string())),
},
&[("flow.return_region_seq", "true")],
)
.await
.unwrap();
assert!(result.metrics.is_ready());
assert!(result.region_watermark_map().is_none());
}
#[tokio::test]
async fn test_query_with_terminal_metrics_rejects_invalid_flow_extensions() {
let handler: Arc<dyn GrpcQueryHandlerWithBoxedError> = Arc::new(NoopHandler);
let client =
FrontendClient::from_grpc_handler(Arc::downgrade(&handler), QueryOptions::default());
let err = client
.query_with_terminal_metrics(
"greptime",
"public",
QueryRequest {
query: Some(Query::Sql("select 1".to_string())),
},
&[("flow.return_region_seq", "not-a-bool")],
)
.await
.unwrap_err();
assert!(format!("{err:?}").contains("Invalid value for flow.return_region_seq"));
}
}


@@ -132,7 +132,9 @@ impl DatafusionQueryEngine {
let output = self
.exec_query_plan((*dml.input).clone(), query_ctx.clone())
.await?;
let mut stream = match output.data {
let common_query::Output { data, meta } = output;
let input_plan = meta.plan;
let mut stream = match data {
OutputData::RecordBatches(batches) => batches.as_stream(),
OutputData::Stream(stream) => stream,
_ => unreachable!(),
@@ -168,7 +170,7 @@ impl DatafusionQueryEngine {
}
Ok(Output::new(
OutputData::AffectedRows(affected_rows),
OutputMeta::new_with_cost(insert_cost),
OutputMeta::new(input_plan, insert_cost),
))
}


@@ -180,6 +180,20 @@ impl Stream for RegionWatermarkMetricsStream {
}
}
pub fn terminal_recordbatch_metrics_from_plan(
plan: Arc<dyn ExecutionPlan>,
) -> Option<RecordBatchMetrics> {
let region_latest_sequences = collect_region_latest_sequences(plan);
if region_latest_sequences.is_empty() {
None
} else {
Some(RecordBatchMetrics {
region_latest_sequences: Some(region_latest_sequences),
..Default::default()
})
}
}
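`collect_region_latest_sequences` (whose opening lines follow) walks the plan tree with an explicit stack, merging per-region sequences into a map. A toy-node sketch of that traversal shape; merging duplicates with `max` is an assumption here, as the merge rule falls outside this hunk.

```rust
use std::collections::HashMap;

/// Toy stand-in for a plan node (`Arc<dyn ExecutionPlan>` in the real code):
/// each node may report (region_id, latest_sequence) pairs and has children.
struct Node {
    region_seqs: Vec<(u64, u64)>,
    children: Vec<Node>,
}

/// Iterative, stack-based traversal. Duplicate region ids are merged
/// with `max` -- an assumption for this sketch.
fn collect(root: &Node) -> Vec<(u64, u64)> {
    let mut merged: HashMap<u64, u64> = HashMap::new();
    let mut stack = vec![root];
    while let Some(node) = stack.pop() {
        for &(region, seq) in &node.region_seqs {
            merged
                .entry(region)
                .and_modify(|s| *s = (*s).max(seq))
                .or_insert(seq);
        }
        stack.extend(node.children.iter());
    }
    merged.into_iter().collect()
}

fn main() {
    let plan = Node {
        region_seqs: vec![(1, 10)],
        children: vec![
            Node { region_seqs: vec![(1, 12), (2, 7)], children: vec![] },
            Node { region_seqs: vec![(3, 5)], children: vec![] },
        ],
    };
    let mut seqs = collect(&plan);
    seqs.sort_unstable();
    assert_eq!(seqs, vec![(1, 12), (2, 7), (3, 5)]);
}
```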
fn collect_region_latest_sequences(plan: Arc<dyn ExecutionPlan>) -> Vec<(u64, u64)> {
let mut merged = std::collections::HashMap::new();
let mut stack = vec![plan];


@@ -39,7 +39,9 @@ use datatypes::arrow::datatypes::SchemaRef;
use futures::{Stream, future, ready};
use futures_util::{StreamExt, TryStreamExt};
use prost::Message;
use session::context::{QueryContext, QueryContextRef};
use query::metrics::terminal_recordbatch_metrics_from_plan;
use query::options::FlowQueryExtensions;
use session::context::{Channel, QueryContextRef};
use snafu::{IntoError, ResultExt, ensure};
use table::table_name::TableName;
use tokio::sync::mpsc;
@@ -48,7 +50,9 @@ use tonic::{Request, Response, Status, Streaming};
use crate::error::{InvalidParameterSnafu, Result, ToJsonSnafu};
pub use crate::grpc::flight::stream::FlightRecordBatchStream;
use crate::grpc::greptime_handler::{GreptimeRequestHandler, get_request_type};
use crate::grpc::greptime_handler::{
GreptimeRequestHandler, create_query_context, get_request_type,
};
use crate::grpc::{FlightCompression, TonicResult, context_auth};
use crate::request_memory_limiter::ServerMemoryLimiter;
use crate::request_memory_metrics::RequestMemoryMetrics;
@@ -192,6 +196,10 @@ impl FlightCraft for GreptimeRequestHandler {
let ticket = request.into_inner().ticket;
let request =
GreptimeRequest::decode(ticket.as_ref()).context(error::InvalidFlightTicketSnafu)?;
let query_ctx =
create_query_context(Channel::Grpc, request.header.as_ref(), hints.clone())?;
FlowQueryExtensions::from_extensions(&query_ctx.extensions())
.map_err(|e| Status::invalid_argument(e.to_string()))?;
// The Grpc protocol pass query by Flight. It needs to be wrapped under a span, in order to record stream
let span = info_span!(
@@ -206,7 +214,7 @@ impl FlightCraft for GreptimeRequestHandler {
output,
TracingContext::from_current_span(),
flight_compression,
QueryContext::arc(),
query_ctx,
);
Ok(Response::new(stream))
}
@@ -539,12 +547,22 @@ fn to_flight_data_stream(
Box::pin(stream) as _
}
OutputData::AffectedRows(rows) => {
let stream = tokio_stream::iter(
FlightEncoder::default()
.encode(FlightMessage::AffectedRows(rows))
.into_iter()
.map(Ok),
);
let affected_rows = FlightEncoder::default().encode(FlightMessage::AffectedRows(rows));
let should_emit_terminal_metrics =
FlowQueryExtensions::from_extensions(&query_ctx.extensions())
.expect("flow extensions must be validated before Flight serialization")
.should_collect_region_watermark();
let terminal_metrics = should_emit_terminal_metrics
.then_some(output.meta.plan)
.flatten()
.and_then(terminal_recordbatch_metrics_from_plan)
.and_then(|metrics| serde_json::to_string(&metrics).ok())
.map(FlightMessage::Metrics)
.map(|message| FlightEncoder::default().encode(message))
.into_iter()
.flatten();
let stream =
tokio_stream::iter(affected_rows.into_iter().chain(terminal_metrics).map(Ok));
Box::pin(stream) as _
}
}
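The server-side emission order reduces to chaining two iterators: the affected-rows frame, then zero or one metrics frame, exactly the sequence the client-side `do_get` accepts. A stand-in sketch with plain types (the real code encodes `FlightMessage`s via `FlightEncoder` and wraps the chain in `tokio_stream::iter`):

```rust
/// Illustrative stand-in for an encoded Flight frame.
#[derive(Debug, PartialEq)]
enum Frame {
    AffectedRows(usize),
    Metrics(String),
}

/// Emit the affected-rows frame, then at most one metrics frame.
fn emit(rows: usize, metrics_json: Option<String>) -> Vec<Frame> {
    let affected = std::iter::once(Frame::AffectedRows(rows));
    let metrics = metrics_json.map(Frame::Metrics).into_iter();
    affected.chain(metrics).collect()
}

fn main() {
    assert_eq!(emit(9, None), vec![Frame::AffectedRows(9)]);
    assert_eq!(
        emit(9, Some(r#"{"region_latest_sequences":[[7,42]]}"#.into())),
        vec![
            Frame::AffectedRows(9),
            Frame::Metrics(r#"{"region_latest_sequences":[[7,42]]}"#.into()),
        ]
    );
}
```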


@@ -143,6 +143,23 @@ mod test {
let metrics = stream.metrics().expect("expected terminal metrics");
assert!(metrics.region_latest_sequences.is_none());
let result = client
.sql_with_terminal_metrics(sql, &[("flow.return_region_seq", "true")])
.await
.unwrap();
let terminal_metrics = result.metrics.clone();
let OutputData::Stream(mut stream) = result.output.data else {
panic!("expected stream output");
};
while let Some(batch) = stream.next().await {
batch.unwrap();
}
assert!(
terminal_metrics
.region_watermark_map()
.is_some_and(|m| !m.is_empty())
);
let output = client
.sql_with_hint(sql, &[("flow.return_region_seq", "true")])
.await
@@ -170,6 +187,43 @@ mod test {
let previous_watermark = region_latest_sequences[0];
create_table_named(&client, "bar").await;
let result = client
.sql_with_terminal_metrics("insert into bar select ts, a, `B` from foo", &[])
.await
.unwrap();
let OutputData::AffectedRows(affected_rows) = result.output.data else {
panic!("expected affected rows output");
};
assert_eq!(affected_rows, 9);
assert!(result.metrics.is_ready());
assert!(result.region_watermark_map().is_none());
let err = client
.sql_with_terminal_metrics(
"insert into bar select ts, a, `B` from foo",
&[("flow.return_region_seq", "not-a-bool")],
)
.await
.unwrap_err();
let err_msg = format!("{err:?}");
assert!(err_msg.contains("Invalid value for flow.return_region_seq"));
client.sql("truncate table bar").await.unwrap();
let result = client
.sql_with_terminal_metrics(
"insert into bar select ts, a, `B` from foo",
&[("flow.return_region_seq", "true")],
)
.await
.unwrap();
let OutputData::AffectedRows(affected_rows) = result.output.data else {
panic!("expected affected rows output");
};
assert_eq!(affected_rows, 9);
assert!(result.region_watermark_map().is_some_and(|m| !m.is_empty()));
let incremental_batches = create_record_batches(10);
test_put_record_batches(&client, incremental_batches).await;
@@ -362,6 +416,10 @@ mod test {
}
async fn create_table(client: &Database) {
create_table_named(client, "foo").await;
}
async fn create_table_named(client: &Database, table_name: &str) {
// create table foo (
// ts timestamp time index,
// a int primary key,
@@ -370,7 +428,7 @@ mod test {
let output = client
.create(CreateTableExpr {
schema_name: "public".to_string(),
table_name: "foo".to_string(),
table_name: table_name.to_string(),
column_defs: vec![
ColumnDef {
name: "ts".to_string(),