refactor: per review

Signed-off-by: discord9 <discord9@163.com>
This commit is contained in:
discord9
2026-04-29 16:54:24 +08:00
parent 1a4a79d1eb
commit d08743fc64
7 changed files with 497 additions and 246 deletions

View File

@@ -14,8 +14,8 @@
use std::pin::Pin;
use std::str::FromStr;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, RwLock};
use std::task::{Context, Poll};
use api::v1::auth_header::AuthScheme;
@@ -63,7 +63,7 @@ type DoPutResponseStream = Pin<Box<dyn Stream<Item = Result<DoPutResponse>>>>;
#[derive(Debug, Clone, Default)]
pub struct OutputMetrics {
metrics: Arc<ArcSwapOption<RecordBatchMetrics>>,
metrics: Arc<RwLock<Option<RecordBatchMetrics>>>,
ready: Arc<AtomicBool>,
}
@@ -73,7 +73,7 @@ impl OutputMetrics {
}
pub fn update(&self, metrics: Option<RecordBatchMetrics>) {
self.metrics.swap(metrics.map(Arc::new));
*self.metrics.write().expect("metrics lock poisoned") = metrics;
}
pub fn mark_ready(&self) {
@@ -85,7 +85,7 @@ impl OutputMetrics {
}
pub fn get(&self) -> Option<RecordBatchMetrics> {
self.metrics.load().as_ref().map(|m| m.as_ref().clone())
self.metrics.read().expect("metrics lock poisoned").clone()
}
/// Returns proved per-region watermarks.
@@ -227,6 +227,123 @@ fn attach_terminal_metrics(output: Output, terminal_metrics: &OutputMetrics) ->
Output::new(data, meta)
}
/// Converts a decoded Flight message stream into an [`OutputWithMetrics`].
///
/// The first message decides the shape of the output:
/// - `AffectedRows` produces an affected-rows output. Terminal metrics may arrive either
///   inline with that message or as a single trailing `Metrics` message, but not both.
/// - `Schema` produces a record batch stream built from the remaining messages. The stream
///   stops at the first `Metrics` message, at the first error, or when the input ends.
///
/// # Errors
///
/// Returns `IllegalFlightMessages` when the stream is empty, when it starts with a
/// `RecordBatch` or `Metrics` message, or when an `AffectedRows` message is followed by
/// anything other than at most one `Metrics` message. Errors yielded by the input stream
/// while the `AffectedRows` path is being read are returned unchanged.
async fn output_from_flight_message_stream<S>(
    mut flight_message_stream: S,
) -> Result<OutputWithMetrics>
where
    S: Stream<Item = Result<FlightMessage>> + Send + Unpin + 'static,
{
    let Some(first_flight_message) = flight_message_stream.next().await else {
        return IllegalFlightMessagesSnafu {
            reason: "Expect the response not to be empty",
        }
        .fail();
    };
    let first_flight_message = first_flight_message?;
    match first_flight_message {
        FlightMessage::AffectedRows { rows, metrics } => {
            let terminal_metrics = OutputMetrics::new();
            // Track inline metrics with a flag instead of reading them back through
            // `OutputMetrics::get`, which takes the lock and clones the stored metrics.
            let has_inline_metrics = metrics.is_some();
            if let Some(metrics) = metrics {
                terminal_metrics.update(Some(parse_terminal_metrics(&metrics)?));
            }
            let next_message = flight_message_stream.next().await.transpose()?;
            match next_message {
                None => terminal_metrics.mark_ready(),
                Some(FlightMessage::Metrics(_)) if has_inline_metrics => {
                    return IllegalFlightMessagesSnafu {
                        reason: "'AffectedRows' Flight metadata already carries Metrics and cannot be followed by another Metrics message".to_string(),
                    }
                    .fail();
                }
                Some(FlightMessage::Metrics(s)) => {
                    terminal_metrics.update(Some(parse_terminal_metrics(&s)?));
                    terminal_metrics.mark_ready();
                    // Propagate a stream error as is; only an actual extra message is a
                    // protocol violation.
                    let trailing_message = flight_message_stream.next().await.transpose()?;
                    ensure!(
                        trailing_message.is_none(),
                        IllegalFlightMessagesSnafu {
                            reason: "Expect 'AffectedRows' Flight messages to be followed by at most one Metrics message"
                        }
                    );
                }
                Some(other) => {
                    return IllegalFlightMessagesSnafu {
                        reason: format!(
                            "'AffectedRows' Flight message can only be followed by a Metrics message, got {other:?}"
                        ),
                    }
                    .fail();
                }
            }
            Ok(OutputWithMetrics {
                output: Output::new_with_affected_rows(rows),
                metrics: terminal_metrics,
            })
        }
        FlightMessage::RecordBatch(_) | FlightMessage::Metrics(_) => IllegalFlightMessagesSnafu {
            reason: "The first flight message cannot be a RecordBatch or Metrics message",
        }
        .fail(),
        FlightMessage::Schema(schema) => {
            // Shared slot: the stream body writes the terminal metrics, the wrapper reads them.
            let metrics = Arc::new(ArcSwapOption::from(None));
            let metrics_ref = metrics.clone();
            let schema = Arc::new(
                datatypes::schema::Schema::try_from(schema).context(error::ConvertSchemaSnafu)?,
            );
            let schema_cloned = schema.clone();
            let stream = Box::pin(stream!({
                while let Some(flight_message_item) = flight_message_stream.next().await {
                    let flight_message = match flight_message_item {
                        Ok(message) => message,
                        Err(e) => {
                            yield Err(BoxedError::new(e)).context(ExternalSnafu);
                            break;
                        }
                    };
                    match flight_message {
                        FlightMessage::RecordBatch(arrow_batch) => {
                            yield Ok(RecordBatch::from_df_record_batch(
                                schema_cloned.clone(),
                                arrow_batch,
                            ))
                        }
                        FlightMessage::Metrics(s) => {
                            // A Metrics message ends the stream whether or not it parses.
                            match parse_terminal_metrics(&s) {
                                Ok(m) => {
                                    metrics_ref.swap(Some(Arc::new(m)));
                                }
                                Err(e) => {
                                    yield Err(BoxedError::new(e)).context(ExternalSnafu);
                                }
                            };
                            break;
                        }
                        FlightMessage::AffectedRows { .. } | FlightMessage::Schema(_) => {
                            yield IllegalFlightMessagesSnafu {reason: format!("A Schema message must be succeeded exclusively by a set of RecordBatch messages, flight_message: {:?}", flight_message)}
                                .fail()
                                .map_err(BoxedError::new)
                                .context(ExternalSnafu);
                            break;
                        }
                    }
                }
            }));
            let record_batch_stream = RecordBatchStreamWrapper {
                schema,
                stream,
                output_ordering: None,
                metrics,
                span: Span::current(),
            };
            Ok(OutputWithMetrics::from_output(Output::new_with_stream(
                Box::pin(record_batch_stream),
            )))
        }
    }
}
#[derive(Clone, Debug, Default)]
pub struct Database {
// The "catalog" and "schema" to be used in processing the requests at the server side.
@@ -597,7 +714,7 @@ impl Database {
let flight_data_stream = response.into_inner();
let mut decoder = FlightDecoder::default();
let mut flight_message_stream = flight_data_stream.map(move |flight_data| {
let flight_message_stream = flight_data_stream.map(move |flight_data| {
flight_data
.map_err(Error::from)
.and_then(|data| decoder.try_decode(&data).context(ConvertFlightDataSnafu))?
@@ -606,156 +723,7 @@ impl Database {
})
});
let Some(first_flight_message) = flight_message_stream.next().await else {
return IllegalFlightMessagesSnafu {
reason: "Expect the response not to be empty",
}
.fail();
};
let first_flight_message = first_flight_message?;
match first_flight_message {
FlightMessage::AffectedRows(rows) => {
let terminal_metrics = OutputMetrics::new();
let next_message = flight_message_stream.next().await.transpose()?;
match next_message {
None => terminal_metrics.mark_ready(),
Some(FlightMessage::Metrics(s)) => {
terminal_metrics.update(Some(parse_terminal_metrics(&s)?));
terminal_metrics.mark_ready();
ensure!(
flight_message_stream.next().await.is_none(),
IllegalFlightMessagesSnafu {
reason: "Expect 'AffectedRows' Flight messages to be followed by at most one Metrics message"
}
);
}
Some(other) => {
return IllegalFlightMessagesSnafu {
reason: format!(
"'AffectedRows' Flight message can only be followed by a Metrics message, got {other:?}"
),
}
.fail();
}
}
Ok(OutputWithMetrics {
output: Output::new_with_affected_rows(rows),
metrics: terminal_metrics,
})
}
FlightMessage::RecordBatch(_) | FlightMessage::Metrics(_) => {
IllegalFlightMessagesSnafu {
reason: "The first flight message cannot be a RecordBatch or Metrics message",
}
.fail()
}
FlightMessage::Schema(schema) => {
let metrics = Arc::new(ArcSwapOption::from(None));
let metrics_ref = metrics.clone();
let schema = Arc::new(
datatypes::schema::Schema::try_from(schema)
.context(error::ConvertSchemaSnafu)?,
);
let schema_cloned = schema.clone();
let stream = Box::pin(stream!({
let mut buffered_message: Option<FlightMessage> = None;
let mut stream_ended = false;
while !stream_ended {
let flight_message_item = if let Some(msg) = buffered_message.take() {
Some(Ok(msg))
} else {
flight_message_stream.next().await
};
let flight_message = match flight_message_item {
Some(Ok(message)) => message,
Some(Err(e)) => {
yield Err(BoxedError::new(e)).context(ExternalSnafu);
break;
}
None => break,
};
match flight_message {
FlightMessage::RecordBatch(arrow_batch) => {
let result_to_yield = RecordBatch::from_df_record_batch(
schema_cloned.clone(),
arrow_batch,
);
if let Some(next_flight_message_result) =
flight_message_stream.next().await
{
match next_flight_message_result {
Ok(FlightMessage::Metrics(s)) => {
match parse_terminal_metrics(&s) {
Ok(m) => {
metrics_ref.swap(Some(Arc::new(m)));
}
Err(e) => {
yield Err(BoxedError::new(e))
.context(ExternalSnafu);
break;
}
};
}
Ok(FlightMessage::RecordBatch(rb)) => {
buffered_message = Some(FlightMessage::RecordBatch(rb));
}
Ok(_) => {
yield IllegalFlightMessagesSnafu {reason: "A RecordBatch message can only be succeeded by a Metrics message or another RecordBatch message"}
.fail()
.map_err(BoxedError::new)
.context(ExternalSnafu);
break;
}
Err(e) => {
yield Err(BoxedError::new(e)).context(ExternalSnafu);
break;
}
}
} else {
stream_ended = true;
}
yield Ok(result_to_yield)
}
FlightMessage::Metrics(s) => {
match parse_terminal_metrics(&s) {
Ok(m) => {
metrics_ref.swap(Some(Arc::new(m)));
}
Err(e) => {
yield Err(BoxedError::new(e)).context(ExternalSnafu);
}
};
break;
}
FlightMessage::AffectedRows(_) | FlightMessage::Schema(_) => {
yield IllegalFlightMessagesSnafu {reason: format!("A Schema message must be succeeded exclusively by a set of RecordBatch messages, flight_message: {:?}", flight_message)}
.fail()
.map_err(BoxedError::new)
.context(ExternalSnafu);
break;
}
}
}
}));
let record_batch_stream = RecordBatchStreamWrapper {
schema,
stream,
output_ordering: None,
metrics,
span: Span::current(),
};
Ok(OutputWithMetrics::from_output(Output::new_with_stream(
Box::pin(record_batch_stream),
)))
}
}
output_from_flight_message_stream(flight_message_stream).await
}
/// Ingest a stream of [RecordBatch]es that belong to a table, using Arrow Flight's "`DoPut`"
@@ -859,6 +827,17 @@ mod tests {
}
}
/// Serializes a `RecordBatchMetrics` fixture holding a single watermark entry
/// (region 7, watermark 42) to JSON.
fn terminal_metrics_json() -> String {
    let watermark_entry = common_recordbatch::adapter::RegionWatermarkEntry {
        region_id: 7,
        watermark: Some(42),
    };
    let fixture = RecordBatchMetrics {
        region_watermarks: vec![watermark_entry],
        ..Default::default()
    };
    serde_json::to_string(&fixture).unwrap()
}
#[test]
fn test_flight_ctx() {
let mut ctx = FlightContext::default();
@@ -948,7 +927,47 @@ mod tests {
}
#[tokio::test]
async fn test_invalid_terminal_metrics_after_record_batch_fails_before_yielding_batch() {
async fn test_affected_rows_inline_metrics_are_parsed() {
    // A single AffectedRows message that carries its terminal metrics inline.
    let messages: Vec<Result<FlightMessage>> = vec![Ok(FlightMessage::AffectedRows {
        rows: 3,
        metrics: Some(terminal_metrics_json()),
    })];

    let result = output_from_flight_message_stream(futures_util::stream::iter(messages))
        .await
        .unwrap();

    assert!(matches!(result.output.data, OutputData::AffectedRows(3)));
    assert!(result.metrics.is_ready());
    let expected_watermarks = std::collections::HashMap::from([(7, 42)]);
    assert_eq!(
        result.metrics.region_watermark_map(),
        Some(expected_watermarks)
    );
}
#[tokio::test]
async fn test_affected_rows_inline_metrics_rejects_trailing_metrics() {
    // Metrics are supplied twice: inline with AffectedRows and again as a trailing message.
    let payload = terminal_metrics_json();
    let messages: Vec<Result<FlightMessage>> = vec![
        Ok(FlightMessage::AffectedRows {
            rows: 3,
            metrics: Some(payload.clone()),
        }),
        Ok(FlightMessage::Metrics(payload)),
    ];

    let err = output_from_flight_message_stream(futures_util::stream::iter(messages))
        .await
        .unwrap_err();

    let rendered = err.to_string();
    assert!(
        rendered.contains("already carries Metrics"),
        "unexpected error: {err:?}"
    );
}
#[tokio::test]
async fn test_invalid_terminal_metrics_after_record_batch_yields_batch_then_error() {
let schema = Arc::new(Schema::new(vec![ColumnSchema::new(
"v",
ConcreteDataType::int32_datatype(),
@@ -959,38 +978,21 @@ mod tests {
vec![Arc::new(Int32Vector::from_slice([1])) as VectorRef],
)
.unwrap();
let metrics = Arc::new(ArcSwapOption::from(None));
let metrics_ref = metrics.clone();
let schema_cloned = schema.clone();
let mut flight_message_stream = futures_util::stream::iter(vec![
let output = output_from_flight_message_stream(futures_util::stream::iter(vec![
Ok(FlightMessage::Schema(schema.arrow_schema().clone())),
Ok(FlightMessage::RecordBatch(batch.into_df_record_batch())),
Ok(FlightMessage::Metrics("{not-json}".to_string())),
]
as Vec<Result<FlightMessage>>);
as Vec<Result<FlightMessage>>))
.await
.unwrap();
let terminal_metrics = output.metrics.clone();
let OutputData::Stream(mut record_batch_stream) = output.output.data else {
panic!("expected stream output");
};
let mut record_batch_stream = Box::pin(stream!({
let Some(Ok(FlightMessage::RecordBatch(arrow_batch))) =
flight_message_stream.next().await
else {
return;
};
let result_to_yield =
RecordBatch::from_df_record_batch(schema_cloned.clone(), arrow_batch);
if let Some(Ok(FlightMessage::Metrics(s))) = flight_message_stream.next().await {
match parse_terminal_metrics(&s) {
Ok(m) => {
metrics_ref.swap(Some(Arc::new(m)));
}
Err(e) => {
yield Err(BoxedError::new(e)).context(ExternalSnafu);
return;
}
}
}
yield Ok(result_to_yield);
}));
let batch = record_batch_stream.next().await.unwrap().unwrap();
assert_eq!(batch.num_rows(), 1);
let err = record_batch_stream.next().await.unwrap().unwrap_err();
assert_eq!("External error", err.to_string());
@@ -999,7 +1001,8 @@ mod tests {
"unexpected error: {err:?}"
);
assert!(record_batch_stream.next().await.is_none());
assert!(metrics.load().is_none());
assert!(terminal_metrics.is_ready());
assert!(terminal_metrics.get().is_none());
}
#[test]

View File

@@ -41,7 +41,10 @@ use crate::error::{DecodeFlightDataSnafu, InvalidFlightDataSnafu, Result};
pub enum FlightMessage {
Schema(SchemaRef),
RecordBatch(DfRecordBatch),
AffectedRows(usize),
AffectedRows {
rows: usize,
metrics: Option<String>,
},
Metrics(String),
}
@@ -116,10 +119,12 @@ impl FlightEncoder {
encoded_batch.into(),
)
}
FlightMessage::AffectedRows(rows) => {
FlightMessage::AffectedRows { rows, metrics } => {
let metadata = FlightMetadata {
affected_rows: Some(AffectedRows { value: rows as _ }),
metrics: None,
metrics: metrics.map(|s| Metrics {
metrics: s.into_bytes(),
}),
}
.encode_to_vec();
vec1![FlightData {
@@ -223,7 +228,12 @@ impl FlightDecoder {
let metadata = FlightMetadata::decode(flight_data.app_metadata.clone())
.context(DecodeFlightDataSnafu)?;
if let Some(AffectedRows { value }) = metadata.affected_rows {
return Ok(Some(FlightMessage::AffectedRows(value as _)));
return Ok(Some(FlightMessage::AffectedRows {
rows: value as _,
metrics: metadata
.metrics
.map(|m| String::from_utf8_lossy(&m.metrics).to_string()),
}));
}
if let Some(Metrics { metrics }) = metadata.metrics {
return Ok(Some(FlightMessage::Metrics(
@@ -426,6 +436,47 @@ mod test {
Ok(())
}
#[test]
fn test_affected_rows_metrics_encode_decode() -> Result<()> {
    let expected_metrics = r#"{"region_watermarks":[{"region_id":42,"watermark":7}]}"#;
    let mut encoder = FlightEncoder::default();
    let mut decoder = FlightDecoder::default();

    // Round trip with metrics attached to the AffectedRows message.
    let with_metrics = encoder.encode(FlightMessage::AffectedRows {
        rows: 3,
        metrics: Some(expected_metrics.to_string()),
    });
    assert_eq!(with_metrics.len(), 1);
    match decoder.try_decode(with_metrics.first())?.unwrap() {
        FlightMessage::AffectedRows { rows, metrics } => {
            assert_eq!(rows, 3);
            assert_eq!(metrics.as_deref(), Some(expected_metrics));
        }
        _ => unreachable!(),
    }

    // Round trip without metrics.
    let without_metrics = encoder.encode(FlightMessage::AffectedRows {
        rows: 5,
        metrics: None,
    });
    match decoder.try_decode(without_metrics.first())?.unwrap() {
        FlightMessage::AffectedRows { rows, metrics } => {
            assert_eq!(rows, 5);
            assert!(metrics.is_none());
        }
        _ => unreachable!(),
    }

    Ok(())
}
#[test]
fn test_flight_messages_to_recordbatches() {
let schema = Arc::new(Schema::new(vec![Field::new("m", DataType::Int32, true)]));

View File

@@ -17,8 +17,10 @@ mod catalog;
use std::collections::HashMap;
use std::fmt::Debug;
use std::ops::Deref;
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, RwLock};
use std::task::{Context, Poll};
use std::time::Duration;
use api::region::RegionResponse;
@@ -36,7 +38,8 @@ use common_error::status_code::StatusCode;
use common_meta::datanode::TopicStatsReporter;
use common_query::OutputData;
use common_query::request::QueryRequest;
use common_recordbatch::SendableRecordBatchStream;
use common_recordbatch::adapter::RecordBatchMetrics;
use common_recordbatch::{OrderOption, RecordBatch, RecordBatchStream, SendableRecordBatchStream};
use common_runtime::Runtime;
use common_telemetry::tracing::{self, info_span};
use common_telemetry::tracing_context::{FutureExt, TracingContext};
@@ -45,6 +48,7 @@ use dashmap::DashMap;
use datafusion::datasource::TableProvider;
use datafusion_common::tree_node::TreeNode;
use either::Either;
use futures_util::Stream;
use futures_util::future::try_join_all;
use metric_engine::engine::MetricEngine;
use mito2::engine::{MITO_ENGINE_NAME, MitoEngine};
@@ -53,6 +57,7 @@ use query::QueryEngineRef;
pub use query::dummy_catalog::{
DummyCatalogList, DummyTableProviderFactory, TableProviderFactoryRef,
};
use query::options::should_collect_region_watermark_from_extensions;
use serde_json;
use servers::error::{
self as servers_error, ExecuteGrpcRequestSnafu, Result as ServerResult, SuspendedSnafu,
@@ -278,16 +283,31 @@ impl RegionServer {
.await
.context(DecodeLogicalPlanSnafu)?;
self.inner
let stream = self
.inner
.handle_read(
QueryRequest {
header: request.header,
region_id,
plan,
},
query_ctx,
query_ctx.clone(),
)
.await
.await?;
let region_latest_seq =
if should_collect_region_watermark_from_extensions(&query_ctx.extensions()) {
query_ctx.get_snapshot(region_id.as_u64())
} else {
None
};
if let Some(seq) = region_latest_seq {
Ok(Box::pin(RegionWatermarkStream::new(stream, region_id, seq))
as SendableRecordBatchStream)
} else {
Ok(stream)
}
}
#[tracing::instrument(skip_all)]
@@ -749,6 +769,83 @@ impl RegionServer {
}
}
/// Wraps a region read stream so terminal metrics can carry the scan-open watermark.
///
/// Until the wrapped stream has returned `Poll::Ready(None)`, `metrics()` passes the inner
/// metrics through unchanged. Afterwards it returns them with this region's watermark set
/// to `snapshot_seq`.
struct RegionWatermarkStream {
// The wrapped record batch stream; polling and all trait getters delegate to it.
stream: SendableRecordBatchStream,
// Region id as a raw `u64`, used to locate this region's watermark entry.
region_id: u64,
// Sequence written as the region's watermark once the stream has finished.
snapshot_seq: u64,
// Set to `true` when the wrapped stream returns `Poll::Ready(None)`.
finished: AtomicBool,
}
impl RegionWatermarkStream {
    /// Wraps `stream`, remembering the region and the snapshot sequence to report for it.
    fn new(stream: SendableRecordBatchStream, region_id: RegionId, snapshot_seq: u64) -> Self {
        let region_id = region_id.as_u64();
        Self {
            stream,
            region_id,
            snapshot_seq,
            finished: AtomicBool::new(false),
        }
    }

    /// Returns `metrics` with this region's watermark set to the snapshot sequence.
    ///
    /// An existing entry for the region is overwritten; otherwise a new entry is appended.
    fn merged_metrics(&self, mut metrics: RecordBatchMetrics) -> RecordBatchMetrics {
        let watermark = Some(self.snapshot_seq);
        let existing = metrics
            .region_watermarks
            .iter_mut()
            .find(|candidate| candidate.region_id == self.region_id);
        match existing {
            Some(found) => found.watermark = watermark,
            None => {
                metrics
                    .region_watermarks
                    .push(common_recordbatch::adapter::RegionWatermarkEntry {
                        region_id: self.region_id,
                        watermark,
                    })
            }
        }
        metrics
    }
}
impl RecordBatchStream for RegionWatermarkStream {
    /// Delegates to the wrapped stream.
    fn name(&self) -> &str {
        self.stream.name()
    }

    /// Delegates to the wrapped stream.
    fn schema(&self) -> datatypes::schema::SchemaRef {
        self.stream.schema()
    }

    /// Delegates to the wrapped stream.
    fn output_ordering(&self) -> Option<&[OrderOption]> {
        self.stream.output_ordering()
    }

    /// Returns the wrapped stream's metrics, adding this region's watermark only after the
    /// wrapped stream has been fully drained.
    fn metrics(&self) -> Option<RecordBatchMetrics> {
        let inner_metrics = self.stream.metrics();
        if self.finished.load(Ordering::Relaxed) {
            // Once drained, always report metrics, even if the inner stream had none.
            Some(self.merged_metrics(inner_metrics.unwrap_or_default()))
        } else {
            inner_metrics
        }
    }
}
impl Stream for RegionWatermarkStream {
    type Item = common_recordbatch::error::Result<RecordBatch>;

    /// Polls the wrapped stream and records when it reports exhaustion.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let polled = Pin::new(&mut self.stream).poll_next(cx);
        if matches!(polled, Poll::Ready(None)) {
            self.finished.store(true, Ordering::Relaxed);
        }
        polled
    }
}
#[async_trait]
impl RegionServerHandler for RegionServer {
async fn handle(&self, request: region_request::Body) -> ServerResult<RegionResponseV1> {
@@ -1669,10 +1766,16 @@ impl RegionAttribute {
mod tests {
use std::assert_matches;
use std::sync::Arc;
use api::v1::SemanticType;
use common_error::ext::ErrorExt;
use datatypes::prelude::ConcreteDataType;
use common_recordbatch::RecordBatches;
use common_recordbatch::adapter::RegionWatermarkEntry;
use datatypes::prelude::{ConcreteDataType, VectorRef};
use datatypes::schema::{ColumnSchema, Schema};
use datatypes::vectors::Int32Vector;
use futures_util::StreamExt;
use mito2::test_util::CreateRequestBuilder;
use store_api::metadata::{ColumnMetadata, RegionMetadata, RegionMetadataBuilder};
use store_api::region_engine::RegionEngine;
@@ -1685,6 +1788,36 @@ mod tests {
use crate::error::Result;
use crate::tests::{MockRegionEngine, mock_region_server};
#[tokio::test]
async fn test_region_watermark_stream_only_sets_terminal_metrics() {
    let column = ColumnSchema::new("v", ConcreteDataType::int32_datatype(), false);
    let schema = Arc::new(Schema::new(vec![column]));
    let values: VectorRef = Arc::new(Int32Vector::from_slice([1, 2]));
    let batch = RecordBatch::new(schema.clone(), vec![values]).unwrap();
    let inner = RecordBatches::try_new(schema, vec![batch])
        .unwrap()
        .as_stream();

    let region_id = RegionId::new(42, 7);
    let mut pinned = Box::pin(RegionWatermarkStream::new(inner, region_id, 99));

    // Nothing is reported before the stream has been drained.
    assert!(pinned.as_ref().get_ref().metrics().is_none());

    loop {
        if pinned.next().await.is_none() {
            break;
        }
    }

    // After exhaustion the region's watermark appears in the metrics.
    let expected = vec![RegionWatermarkEntry {
        region_id: region_id.as_u64(),
        watermark: Some(99),
    }];
    let metrics = pinned.as_ref().get_ref().metrics().unwrap();
    assert_eq!(metrics.region_watermarks, expected);
}
#[tokio::test]
async fn test_region_registering() {
common_telemetry::init_default_ut_logging();

View File

@@ -59,7 +59,8 @@ use crate::error::{
TableNotFoundSnafu, TableReadOnlySnafu, UnsupportedExprSnafu,
};
use crate::executor::QueryExecutor;
use crate::metrics::{OnDone, QUERY_STAGE_ELAPSED};
use crate::metrics::{OnDone, QUERY_STAGE_ELAPSED, RegionWatermarkMetricsStream};
use crate::options::FlowQueryExtensions;
use crate::physical_wrapper::PhysicalPlanWrapperRef;
use crate::planner::{DfLogicalPlanner, LogicalPlanner};
use crate::query_engine::{DescribeResult, QueryEngineContext, QueryEngineState};
@@ -100,8 +101,10 @@ impl DatafusionQueryEngine {
optimized_physical_plan
};
let stream = self.execute_stream(&ctx, &physical_plan)?;
Ok(Output::new(
OutputData::Stream(self.execute_stream(&ctx, &physical_plan)?),
OutputData::Stream(stream),
OutputMeta::new_with_plan(physical_plan),
))
}
@@ -128,10 +131,10 @@ impl DatafusionQueryEngine {
let table_name = dml.table_name.resolve(default_catalog, default_schema);
let table = self.find_table(&table_name, &query_ctx).await?;
let output = self
let Output { data, meta } = self
.exec_query_plan((*dml.input).clone(), query_ctx.clone())
.await?;
let mut stream = match output.data {
let mut stream = match data {
OutputData::RecordBatches(batches) => batches.as_stream(),
OutputData::Stream(stream) => stream,
_ => unreachable!(),
@@ -167,7 +170,7 @@ impl DatafusionQueryEngine {
}
Ok(Output::new(
OutputData::AffectedRows(affected_rows),
OutputMeta::new_with_cost(insert_cost),
OutputMeta::new(meta.plan, insert_cost),
))
}
@@ -544,6 +547,9 @@ impl QueryExecutor for DatafusionQueryEngine {
plan: &Arc<dyn ExecutionPlan>,
) -> Result<SendableRecordBatchStream> {
let explain_verbose = ctx.query_ctx().explain_verbose();
let should_collect_region_watermark =
FlowQueryExtensions::parse_flow_extensions(&ctx.query_ctx().extensions())?
.is_some_and(|extensions| extensions.should_collect_region_watermark());
let output_partitions = plan.properties().output_partitioning().partition_count();
if explain_verbose {
common_telemetry::info!("Executing query plan, output_partitions: {output_partitions}");
@@ -579,7 +585,14 @@ impl QueryExecutor for DatafusionQueryEngine {
);
}
});
Ok(Box::pin(stream))
if should_collect_region_watermark {
Ok(Box::pin(RegionWatermarkMetricsStream::new(
Box::pin(stream),
plan.clone(),
)))
} else {
Ok(Box::pin(stream))
}
}
_ => {
// merge into a single partition
@@ -608,7 +621,14 @@ impl QueryExecutor for DatafusionQueryEngine {
);
}
});
Ok(Box::pin(stream))
if should_collect_region_watermark {
Ok(Box::pin(RegionWatermarkMetricsStream::new(
Box::pin(stream),
plan.clone(),
)))
} else {
Ok(Box::pin(stream))
}
}
}
}

View File

@@ -230,13 +230,52 @@ fn collect_region_watermarks(plan: Arc<dyn ExecutionPlan>) -> Vec<RegionWatermar
merge_scan.sub_stage_metrics(),
);
}
stack.extend(plan.children().into_iter().cloned());
}
finalize_region_watermarks(merged)
}
/// Folds `entries` into the per-region merge state in `merged`.
///
/// A region seen for the first time becomes `Proved(seq)` when the entry carries a
/// watermark and `Unproved` otherwise. For a region already present:
/// - an entry without a watermark downgrades `Participated` or `Proved` to `Unproved`;
/// - an entry with a watermark promotes `Participated` to `Proved`, leaves `Unproved` and an
///   equal `Proved` untouched, turns a different `Proved` into `Conflict`, and is appended
///   to an existing `Conflict` when not already listed.
fn merge_region_watermark_entries(
    merged: &mut BTreeMap<u64, MergeState>,
    entries: impl IntoIterator<Item = RegionWatermarkEntry>,
) {
    use std::collections::btree_map::Entry;

    for incoming in entries {
        let state = match merged.entry(incoming.region_id) {
            Entry::Vacant(slot) => {
                let initial = match incoming.watermark {
                    Some(seq) => MergeState::Proved(seq),
                    None => MergeState::Unproved,
                };
                slot.insert(initial);
                continue;
            }
            Entry::Occupied(slot) => slot.into_mut(),
        };

        let Some(seq) = incoming.watermark else {
            // A missing watermark invalidates any previously proved or pending state.
            if matches!(state, MergeState::Participated | MergeState::Proved(_)) {
                *state = MergeState::Unproved;
            }
            continue;
        };

        match state {
            MergeState::Participated => {
                *state = MergeState::Proved(seq);
            }
            MergeState::Unproved => {}
            MergeState::Proved(previous) if *previous == seq => {}
            MergeState::Proved(previous) => {
                let previous = *previous;
                *state = MergeState::Conflict {
                    watermarks: vec![previous, seq],
                };
            }
            MergeState::Conflict { watermarks } => {
                if !watermarks.contains(&seq) {
                    watermarks.push(seq);
                }
            }
        }
    }
}
fn merge_merge_scan_region_watermarks(
merged: &mut BTreeMap<u64, MergeState>,
regions: impl IntoIterator<Item = u64>,
@@ -247,40 +286,7 @@ fn merge_merge_scan_region_watermarks(
}
for metrics in sub_stage_metrics {
for entry in metrics.region_watermarks {
merged
.entry(entry.region_id)
.and_modify(|existing| match entry.watermark {
None => match existing {
MergeState::Participated | MergeState::Proved(_) => {
*existing = MergeState::Unproved;
}
MergeState::Unproved | MergeState::Conflict { .. } => {}
},
Some(seq) => match existing {
MergeState::Participated => {
*existing = MergeState::Proved(seq);
}
MergeState::Unproved => {}
MergeState::Proved(existing_seq) if *existing_seq == seq => {}
MergeState::Proved(existing_seq) => {
let old_seq = *existing_seq;
*existing = MergeState::Conflict {
watermarks: vec![old_seq, seq],
};
}
MergeState::Conflict { watermarks } => {
if !watermarks.contains(&seq) {
watermarks.push(seq);
}
}
},
})
.or_insert(match entry.watermark {
Some(seq) => MergeState::Proved(seq),
None => MergeState::Unproved,
});
}
merge_region_watermark_entries(merged, metrics.region_watermarks);
}
}

View File

@@ -177,10 +177,31 @@ impl FlowQueryExtensions {
}
pub fn should_collect_region_watermark(&self) -> bool {
self.return_region_seq || self.incremental_after_seqs.is_some()
should_collect_region_watermark(
self.return_region_seq,
self.incremental_after_seqs.is_some(),
)
}
}
/// Decides from raw query context extensions whether region watermarks should be collected.
///
/// Collection is requested when `FLOW_RETURN_REGION_SEQ` is set to `true` (ASCII
/// case-insensitive) or when `FLOW_INCREMENTAL_AFTER_SEQS` is present with any value. The
/// values are not otherwise parsed or validated here.
pub fn should_collect_region_watermark_from_extensions(
    extensions: &HashMap<String, String>,
) -> bool {
    let wants_region_seq = matches!(
        extensions.get(FLOW_RETURN_REGION_SEQ),
        Some(flag) if flag.eq_ignore_ascii_case("true")
    );
    should_collect_region_watermark(
        wants_region_seq,
        extensions.contains_key(FLOW_INCREMENTAL_AFTER_SEQS),
    )
}
/// Returns `true` when at least one of the two watermark-collection triggers is set.
fn should_collect_region_watermark(
    return_region_seq: bool,
    has_incremental_after_seqs: bool,
) -> bool {
    [return_region_seq, has_incremental_after_seqs].contains(&true)
}
fn parse_incremental_after_seqs(value: &str) -> Result<HashMap<u64, u64>> {
let raw = serde_json::from_str::<HashMap<String, serde_json::Value>>(value).map_err(|e| {
invalid_query_context_extension(format!(
@@ -420,6 +441,24 @@ mod flow_extension_tests {
assert!(parsed.should_collect_region_watermark());
}
#[test]
fn test_should_collect_region_watermark_from_extensions() {
    // Builds an extension map holding exactly one key/value pair.
    let single = |key: &str, value: &str| HashMap::from([(key.to_string(), value.to_string())]);

    let return_seq_enabled = single(FLOW_RETURN_REGION_SEQ, "true");
    assert!(should_collect_region_watermark_from_extensions(
        &return_seq_enabled
    ));

    let incremental_only = single(FLOW_INCREMENTAL_AFTER_SEQS, r#"{"1":10}"#);
    assert!(should_collect_region_watermark_from_extensions(
        &incremental_only
    ));

    let return_seq_disabled = single(FLOW_RETURN_REGION_SEQ, "false");
    assert!(!should_collect_region_watermark_from_extensions(
        &return_seq_disabled
    ));

    let no_extensions = HashMap::new();
    assert!(!should_collect_region_watermark_from_extensions(
        &no_extensions
    ));
}
#[test]
fn test_parse_flow_extensions_return_region_seq_only_returns_some() {
let exts = HashMap::from([(FLOW_RETURN_REGION_SEQ.to_string(), "true".to_string())]);

View File

@@ -26,6 +26,7 @@ use arrow_flight::{
};
use async_trait::async_trait;
use bytes::{self, Bytes};
use common_error::ext::ErrorExt;
use common_grpc::flight::do_put::{DoPutMetadata, DoPutResponse};
use common_grpc::flight::{FlightDecoder, FlightEncoder, FlightMessage};
use common_memory_manager::MemoryGuard;
@@ -201,7 +202,7 @@ impl FlightCraft for GreptimeRequestHandler {
// This does not authorize or execute anything; `handle_request()` below still performs
// the normal frontend handling and auth checks before query execution.
FlowQueryExtensions::parse_flow_extensions(&query_ctx.extensions())
.map_err(|e| Status::invalid_argument(e.to_string()))?;
.map_err(|e| Status::invalid_argument(e.output_msg()))?;
// The Grpc protocol pass query by Flight. It needs to be wrapped under a span, in order to record stream
let span = info_span!(
@@ -549,7 +550,6 @@ fn to_flight_data_stream(
Box::pin(stream) as _
}
OutputData::AffectedRows(rows) => {
let affected_rows = FlightEncoder::default().encode(FlightMessage::AffectedRows(rows));
let should_emit_terminal_metrics =
FlowQueryExtensions::parse_flow_extensions(&query_ctx.extensions())
.expect("flow extensions must be validated before Flight serialization")
@@ -558,13 +558,12 @@ fn to_flight_data_stream(
.then_some(output.meta.plan)
.flatten()
.and_then(terminal_recordbatch_metrics_from_plan)
.and_then(|metrics| serde_json::to_string(&metrics).ok())
.map(FlightMessage::Metrics)
.map(|message| FlightEncoder::default().encode(message))
.into_iter()
.flatten();
let stream =
tokio_stream::iter(affected_rows.into_iter().chain(terminal_metrics).map(Ok));
.and_then(|metrics| serde_json::to_string(&metrics).ok());
let affected_rows = FlightEncoder::default().encode(FlightMessage::AffectedRows {
rows,
metrics: terminal_metrics,
});
let stream = tokio_stream::iter(affected_rows.into_iter().map(Ok));
Box::pin(stream) as _
}
}