From c7273efb36e926bbcd2eb76e09e844088bca8a42 Mon Sep 17 00:00:00 2001 From: jeremyhi Date: Fri, 3 Jul 2026 13:30:36 +0800 Subject: [PATCH] fix: global limit for distributed inspect streams Signed-off-by: jeremyhi --- src/catalog/src/information_extension.rs | 11 +- src/common/recordbatch/src/util.rs | 146 ++++++++ src/query/src/optimizer.rs | 1 + src/query/src/optimizer/global_limit.rs | 312 ++++++++++++++++++ src/query/src/query_engine/state.rs | 4 + .../information_schema/ssts_limit.result | 69 ++++ .../information_schema/ssts_limit.sql | 38 +++ .../common/tql-explain-analyze/explain.result | 3 + 8 files changed, 582 insertions(+), 2 deletions(-) create mode 100644 src/query/src/optimizer/global_limit.rs create mode 100644 tests/cases/distributed/information_schema/ssts_limit.result create mode 100644 tests/cases/distributed/information_schema/ssts_limit.sql diff --git a/src/catalog/src/information_extension.rs b/src/catalog/src/information_extension.rs index bcb5056bad..08d81ac1f0 100644 --- a/src/catalog/src/information_extension.rs +++ b/src/catalog/src/information_extension.rs @@ -23,7 +23,7 @@ use common_meta::rpc::procedure; use common_procedure::{ProcedureInfo, ProcedureState}; use common_query::request::QueryRequest; use common_recordbatch::SendableRecordBatchStream; -use common_recordbatch::util::ChainedRecordBatchStream; +use common_recordbatch::util::{ChainedRecordBatchStream, LimitedRecordBatchStream}; use meta_client::MetaClientRef; use snafu::ResultExt; use store_api::storage::RegionId; @@ -120,6 +120,7 @@ impl InformationExtension for DistributedInformationExtension { .map_err(BoxedError::new) .context(crate::error::ListNodesSnafu)?; + let limit = request.scan.limit; let plan = request .build_plan() .context(crate::error::DatafusionSnafu)?; @@ -140,6 +141,12 @@ impl InformationExtension for DistributedInformationExtension { let chained = ChainedRecordBatchStream::new(streams).context(crate::error::CreateRecordBatchSnafu)?; - Ok(Box::pin(chained)) + match limit { + Some(limit) => Ok(Box::pin(LimitedRecordBatchStream::new( + Box::pin(chained), + limit, + ))), + None => Ok(Box::pin(chained)), + } } } diff --git a/src/common/recordbatch/src/util.rs b/src/common/recordbatch/src/util.rs index 0a587e303d..abe8403546 100644 --- a/src/common/recordbatch/src/util.rs +++ b/src/common/recordbatch/src/util.rs @@ -98,6 +98,83 @@ impl ChainedRecordBatchStream { } } +/// A stream that stops after yielding at most `remaining` rows. +pub struct LimitedRecordBatchStream { + input: Option, + remaining: usize, + schema: SchemaRef, + output_ordering: Option>, +} + +impl LimitedRecordBatchStream { + pub fn new(input: SendableRecordBatchStream, limit: usize) -> Self { + let schema = input.schema(); + let output_ordering = input.output_ordering().map(|o| o.to_vec()); + Self { + input: Some(input), + remaining: limit, + schema, + output_ordering, + } + } +} + +impl RecordBatchStream for LimitedRecordBatchStream { + fn name(&self) -> &str { + "LimitedRecordBatchStream" + } + + fn schema(&self) -> SchemaRef { + self.schema.clone() + } + + fn output_ordering(&self) -> Option<&[OrderOption]> { + self.output_ordering.as_deref() + } + + fn metrics(&self) -> Option { + self.input.as_ref().and_then(|input| input.metrics()) + } +} + +impl Stream for LimitedRecordBatchStream { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll> { + if self.remaining == 0 { + self.input.take(); + return Poll::Ready(None); + } + + let Some(input) = self.input.as_mut() else { + return Poll::Ready(None); + }; + + match input.poll_next_unpin(ctx) { + Poll::Ready(Some(Ok(batch))) => { + let num_rows = batch.num_rows(); + if num_rows > self.remaining { + let remaining = self.remaining; + self.remaining = 0; + self.input.take(); + Poll::Ready(Some(batch.slice(0, remaining))) + } else { + self.remaining -= num_rows; + if self.remaining == 0 { + self.input.take(); + } + Poll::Ready(Some(Ok(batch))) + } + } + Poll::Ready(None) => { + self.input.take(); + Poll::Ready(None) + } + other => other, + } + } +} + impl RecordBatchStream for ChainedRecordBatchStream { fn name(&self) -> &str { "ChainedRecordBatchStream" @@ -172,6 +249,75 @@ mod tests { } } + #[tokio::test] + async fn test_limited_chained_stream() { + let column_schemas = vec![ColumnSchema::new( + "number", + ConcreteDataType::uint32_datatype(), + false, + )]; + + let schema = Arc::new(Schema::try_new(column_schemas).unwrap()); + let first = RecordBatch::new( + schema.clone(), + [Arc::new(UInt32Vector::from_vec(vec![0, 1, 2])) as _], + ) + .unwrap(); + let second = RecordBatch::new( + schema.clone(), + [Arc::new(UInt32Vector::from_vec(vec![3, 4, 5])) as _], + ) + .unwrap(); + let chained = ChainedRecordBatchStream::new(vec![ + Box::pin(MockRecordBatchStream { + schema: schema.clone(), + batch: Some(first), + }), + Box::pin(MockRecordBatchStream { + schema: schema.clone(), + batch: Some(second), + }), + ]) + .unwrap(); + + let batches = collect(Box::pin(LimitedRecordBatchStream::new( + Box::pin(chained), + 4, + ))) + .await + .unwrap(); + + assert_eq!(2, batches.len()); + assert_eq!(3, batches[0].num_rows()); + assert_eq!(1, batches[1].num_rows()); + } + + #[tokio::test] + async fn test_limited_stream_with_zero_limit() { + let column_schemas = vec![ColumnSchema::new( + "number", + ConcreteDataType::uint32_datatype(), + false, + )]; + + let schema = Arc::new(Schema::try_new(column_schemas).unwrap()); + let batch = RecordBatch::new( + schema.clone(), + [Arc::new(UInt32Vector::from_vec(vec![0])) as _], + ) + .unwrap(); + let stream = MockRecordBatchStream { + schema, + batch: Some(batch), + }; + + let batches = collect(Box::pin(LimitedRecordBatchStream::new(Box::pin(stream), 0))) + .await + .unwrap(); + + assert!(batches.is_empty()); + } + #[tokio::test] async fn test_collect() { let column_schemas = vec![ColumnSchema::new( diff --git a/src/query/src/optimizer.rs b/src/query/src/optimizer.rs index ffbfff5ee2..8827d48ed6 100644 --- a/src/query/src/optimizer.rs +++ b/src/query/src/optimizer.rs @@ -16,6 +16,7 @@ pub mod const_normalization; pub mod constant_term; pub mod count_nest_aggr; pub mod count_wildcard; +pub mod global_limit; pub(crate) mod json_type_concretize; pub mod parallelize_scan; pub mod pass_distribution; diff --git a/src/query/src/optimizer/global_limit.rs b/src/query/src/optimizer/global_limit.rs new file mode 100644 index 0000000000..cba2939b79 --- /dev/null +++ b/src/query/src/optimizer/global_limit.rs @@ -0,0 +1,312 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::sync::Arc; + +use datafusion::config::ConfigOptions; +use datafusion::physical_optimizer::PhysicalOptimizerRule; +use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec; +use datafusion::physical_plan::filter::FilterExec; +use datafusion::physical_plan::limit::GlobalLimitExec; +use datafusion::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; +use datafusion::physical_plan::{ExecutionPlan, ExecutionPlanProperties}; +use datafusion_common::Result as DfResult; +use datafusion_physical_expr::OrderingRequirements; + +#[derive(Debug)] +pub struct EnsureGlobalLimitForFetch; + +impl PhysicalOptimizerRule for EnsureGlobalLimitForFetch { + fn optimize( + &self, + plan: Arc, + _config: &ConfigOptions, + ) -> DfResult> { + Self::optimize_plan(plan, ParentContext::default()) + } + + fn name(&self) -> &str { + "EnsureGlobalLimitForFetch" + } + + fn schema_check(&self) -> bool { + true + } +} + +impl EnsureGlobalLimitForFetch { + fn optimize_plan( + plan: Arc, + parent: ParentContext, + ) -> DfResult> { + let children = plan.children(); + let plan = if children.is_empty() { + plan + } else { + let required_input_ordering = plan.required_input_ordering(); + let maintains_input_order = plan.maintains_input_order(); + let child_parent = ParentContext { + has_global_fetch: provides_global_fetch(&plan), + required_ordering: None, + }; + let children = children + .into_iter() + .enumerate() + .map(|(idx, child)| { + let required_ordering = required_input_ordering + .get(idx) + .cloned() + .unwrap_or(None) + .or_else(|| { + maintains_input_order + .get(idx) + .copied() + .unwrap_or(false) + .then(|| parent.required_ordering.clone()) + .flatten() + }); + let parent = ParentContext { + required_ordering, + ..child_parent.clone() + }; + Self::optimize_plan(Arc::clone(child), parent) + }) + .collect::>>()?; + plan.with_new_children(children)? + }; + + let Some(fetch) = plan.fetch() else { + return Ok(plan); + }; + + if parent.has_global_fetch + || !plan.as_any().is::() + || plan.output_partitioning().partition_count() <= 1 + { + return Ok(plan); + } + + Ok(add_global_fetch(plan, fetch, parent.required_ordering)) + } +} + +#[derive(Clone, Default)] +struct ParentContext { + has_global_fetch: bool, + required_ordering: Option, +} + +fn provides_global_fetch(plan: &Arc) -> bool { + if plan.fetch().is_none() { + return false; + } + + plan.as_any().is::() + || plan.as_any().is::() + || plan.as_any().is::() +} + +fn add_global_fetch( + plan: Arc, + fetch: usize, + required_ordering: Option, +) -> Arc { + if required_ordering.is_some() + && let Some(ordering) = plan.output_ordering().cloned() + { + Arc::new(SortPreservingMergeExec::new(ordering, plan).with_fetch(Some(fetch))) + } else { + Arc::new(CoalescePartitionsExec::new(plan).with_fetch(Some(fetch))) + } +} + +#[cfg(test)] +mod tests { + use datafusion::arrow::array::Int32Array; + use datafusion::arrow::compute::SortOptions; + use datafusion::arrow::datatypes::{DataType, Field, Schema}; + use datafusion::arrow::record_batch::RecordBatch; + use datafusion::physical_expr::expressions::{col, lit}; + use datafusion::physical_plan::filter::FilterExecBuilder; + use datafusion::physical_plan::limit::GlobalLimitExec; + use datafusion::physical_plan::projection::ProjectionExec; + use datafusion::physical_plan::test::TestMemoryExec; + use datafusion_physical_expr::{LexOrdering, PhysicalSortExpr}; + + use super::*; + + #[test] + fn adds_global_limit_for_multi_partition_filter_fetch() { + let filter = filter_fetch(unordered_input(), 1); + + let optimized = + EnsureGlobalLimitForFetch::optimize_plan(filter, ParentContext::default()).unwrap(); + + assert!(optimized.as_any().is::()); + assert_eq!(optimized.fetch(), Some(1)); + assert_eq!(optimized.output_partitioning().partition_count(), 1); + } + + #[test] + fn still_visits_subtree_under_global_limit() { + let filter = filter_fetch(unordered_input(), 5); + let projection = Arc::new( + ProjectionExec::try_new( + vec![(col("a", filter.schema().as_ref()).unwrap(), "a".to_string())], + filter, + ) + .unwrap(), + ); + let limit = + Arc::new(GlobalLimitExec::new(projection, 0, Some(10))) as Arc; + + let optimized = + EnsureGlobalLimitForFetch::optimize_plan(limit, ParentContext::default()).unwrap(); + let projection = optimized.children()[0]; + let coalesce = projection.children()[0]; + + assert!(coalesce.as_any().is::()); + assert_eq!(coalesce.fetch(), Some(5)); + } + + #[test] + fn keeps_filter_under_parent_global_fetch() { + let (input, ordering) = ordered_input(); + let filter = filter_fetch(input, 1); + let merge = Arc::new(SortPreservingMergeExec::new(ordering, filter).with_fetch(Some(1))) + as Arc; + + let optimized = + EnsureGlobalLimitForFetch::optimize_plan(merge, ParentContext::default()).unwrap(); + let child = optimized.children()[0]; + + assert!(optimized.as_any().is::()); + assert!(child.as_any().is::()); + } + + #[test] + fn preserves_parent_ordering_requirement() { + let (input, ordering) = ordered_input(); + let filter = filter_fetch(input, 1); + let merge = + Arc::new(SortPreservingMergeExec::new(ordering, filter)) as Arc; + + let optimized = + EnsureGlobalLimitForFetch::optimize_plan(merge, ParentContext::default()).unwrap(); + let child = optimized.children()[0]; + + assert!(optimized.as_any().is::()); + assert!(child.as_any().is::()); + assert_eq!(child.fetch(), Some(1)); + } + + #[test] + fn uses_child_output_ordering_for_merge() { + let schema = schema(); + let required_ordering = ordering(schema.as_ref(), false); + let actual_ordering = ordering(schema.as_ref(), true); + let batch = batch(schema.clone()); + let partitions = vec![vec![batch.clone()], vec![batch.clone()], vec![batch]]; + let input = TestMemoryExec::try_new(&partitions, schema, None) + .unwrap() + .try_with_sort_information(vec![actual_ordering.clone()]) + .unwrap(); + let filter = filter_fetch(Arc::new(input), 1); + + let optimized = add_global_fetch( + filter, + 1, + Some(OrderingRequirements::from(required_ordering)), + ); + let merge = optimized + .as_any() + .downcast_ref::() + .unwrap(); + + assert_eq!(merge.expr(), &actual_ordering); + } + + #[test] + fn preserves_inherited_ordering_requirement_through_projection() { + let (input, ordering) = ordered_input(); + let filter = filter_fetch(input, 1); + let projection = Arc::new( + ProjectionExec::try_new( + vec![(col("a", filter.schema().as_ref()).unwrap(), "a".to_string())], + filter, + ) + .unwrap(), + ); + let merge = + Arc::new(SortPreservingMergeExec::new(ordering, projection)) as Arc; + + let optimized = + EnsureGlobalLimitForFetch::optimize_plan(merge, ParentContext::default()).unwrap(); + let projection = optimized.children()[0]; + let child = projection.children()[0]; + + assert!(optimized.as_any().is::()); + assert!(projection.as_any().is::()); + assert!(child.as_any().is::()); + assert_eq!(child.fetch(), Some(1)); + } + + fn unordered_input() -> Arc { + let schema = schema(); + let batch = batch(schema.clone()); + let partitions = vec![vec![batch.clone()], vec![batch.clone()], vec![batch]]; + Arc::new(TestMemoryExec::try_new(&partitions, schema, None).unwrap()) + } + + fn ordered_input() -> (Arc, LexOrdering) { + let schema = schema(); + let ordering = ordering(schema.as_ref(), false); + let batch = batch(schema.clone()); + let partitions = vec![vec![batch.clone()], vec![batch.clone()], vec![batch]]; + let input = TestMemoryExec::try_new(&partitions, schema, None) + .unwrap() + .try_with_sort_information(vec![ordering.clone()]) + .unwrap(); + + (Arc::new(input), ordering) + } + + fn filter_fetch(input: Arc, fetch: usize) -> Arc { + Arc::new( + FilterExecBuilder::new(lit(true), input) + .with_fetch(Some(fetch)) + .build() + .unwrap(), + ) + } + + fn schema() -> Arc { + Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])) + } + + fn batch(schema: Arc) -> RecordBatch { + RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from(vec![1, 2, 3]))]).unwrap() + } + + fn ordering(schema: &Schema, descending: bool) -> LexOrdering { + LexOrdering::new([PhysicalSortExpr::new( + col("a", schema).unwrap(), + SortOptions { + descending, + nulls_first: descending, + }, + )]) + .unwrap() + } +} diff --git a/src/query/src/query_engine/state.rs b/src/query/src/query_engine/state.rs index 50785c4a13..e86c5916b3 100644 --- a/src/query/src/query_engine/state.rs +++ b/src/query/src/query_engine/state.rs @@ -66,6 +66,7 @@ use crate::optimizer::const_normalization::ConstNormalizationRule; use crate::optimizer::constant_term::MatchesConstantTermOptimizer; use crate::optimizer::count_nest_aggr::CountNestAggrRule; use crate::optimizer::count_wildcard::CountWildcardToTimeIndexRule; +use crate::optimizer::global_limit::EnsureGlobalLimitForFetch; use crate::optimizer::json_type_concretize::JsonTypeConcretizeRule; use crate::optimizer::parallelize_scan::ParallelizeScan; use crate::optimizer::pass_distribution::PassDistribution; @@ -214,6 +215,9 @@ impl QueryEngineState { physical_optimizer .rules .push(Arc::new(MatchesConstantTermOptimizer)); + physical_optimizer + .rules + .push(Arc::new(EnsureGlobalLimitForFetch)); // Add rule to remove duplicate nodes generated by other rules. Run this in the last. physical_optimizer.rules.push(Arc::new(RemoveDuplicate)); // Place SanityCheckPlan at the end of the list to ensure that it runs after all other rules. diff --git a/tests/cases/distributed/information_schema/ssts_limit.result b/tests/cases/distributed/information_schema/ssts_limit.result new file mode 100644 index 0000000000..98a1984299 --- /dev/null +++ b/tests/cases/distributed/information_schema/ssts_limit.result @@ -0,0 +1,69 @@ +CREATE TABLE ssts_limit_case ( + a INT PRIMARY KEY INVERTED INDEX, + b STRING SKIPPING INDEX, + c STRING FULLTEXT INDEX, + ts TIMESTAMP TIME INDEX, +) +PARTITION ON COLUMNS (a) ( + a < 1000, + a >= 1000 AND a < 2000, + a >= 2000 +); + +Affected Rows: 0 + +INSERT INTO ssts_limit_case VALUES + (500, 'a', 'a', 1), + (1500, 'b', 'b', 2), + (2500, 'c', 'c', 3); + +Affected Rows: 3 + +ADMIN FLUSH_TABLE('ssts_limit_case'); + ++--------------------------------------+ +| ADMIN FLUSH_TABLE('ssts_limit_case') | ++--------------------------------------+ +| 0 | ++--------------------------------------+ + +SELECT COUNT(DISTINCT node_id) > 1 AS has_multi_datanodes +FROM information_schema.ssts_manifest; + ++---------------------+ +| has_multi_datanodes | ++---------------------+ +| true | ++---------------------+ + +SELECT COUNT(*) AS limited_rows +FROM ( + SELECT region_id + FROM information_schema.ssts_manifest + LIMIT 1 +); + ++--------------+ +| limited_rows | ++--------------+ +| 1 | ++--------------+ + +SELECT COUNT(*) AS filtered_limited_rows +FROM ( + SELECT region_id + FROM information_schema.ssts_manifest + WHERE region_id > 0 + LIMIT 1 +); + ++-----------------------+ +| filtered_limited_rows | ++-----------------------+ +| 1 | ++-----------------------+ + +DROP TABLE ssts_limit_case; + +Affected Rows: 0 + diff --git a/tests/cases/distributed/information_schema/ssts_limit.sql b/tests/cases/distributed/information_schema/ssts_limit.sql new file mode 100644 index 0000000000..c02f31d63d --- /dev/null +++ b/tests/cases/distributed/information_schema/ssts_limit.sql @@ -0,0 +1,38 @@ +CREATE TABLE ssts_limit_case ( + a INT PRIMARY KEY INVERTED INDEX, + b STRING SKIPPING INDEX, + c STRING FULLTEXT INDEX, + ts TIMESTAMP TIME INDEX, +) +PARTITION ON COLUMNS (a) ( + a < 1000, + a >= 1000 AND a < 2000, + a >= 2000 +); + +INSERT INTO ssts_limit_case VALUES + (500, 'a', 'a', 1), + (1500, 'b', 'b', 2), + (2500, 'c', 'c', 3); + +ADMIN FLUSH_TABLE('ssts_limit_case'); + +SELECT COUNT(DISTINCT node_id) > 1 AS has_multi_datanodes +FROM information_schema.ssts_manifest; + +SELECT COUNT(*) AS limited_rows +FROM ( + SELECT region_id + FROM information_schema.ssts_manifest + LIMIT 1 +); + +SELECT COUNT(*) AS filtered_limited_rows +FROM ( + SELECT region_id + FROM information_schema.ssts_manifest + WHERE region_id > 0 + LIMIT 1 +); + +DROP TABLE ssts_limit_case; diff --git a/tests/cases/standalone/common/tql-explain-analyze/explain.result b/tests/cases/standalone/common/tql-explain-analyze/explain.result index 810dd27644..ff455c8cd7 100644 --- a/tests/cases/standalone/common/tql-explain-analyze/explain.result +++ b/tests/cases/standalone/common/tql-explain-analyze/explain.result @@ -202,6 +202,7 @@ TQL EXPLAIN VERBOSE (0, 10, '5s') test; | physical_plan after FilterPushdown(Post)_| SAME TEXT AS ABOVE_| | physical_plan after WindowedSortRule_| SAME TEXT AS ABOVE_| | physical_plan after MatchesConstantTerm_| SAME TEXT AS ABOVE_| +| physical_plan after EnsureGlobalLimitForFetch_| SAME TEXT AS ABOVE_| | physical_plan after RemoveDuplicateRule_| SAME TEXT AS ABOVE_| | physical_plan after SanityCheckPlan_| SAME TEXT AS ABOVE_| | physical_plan_| CooperativeExec_| @@ -353,6 +354,7 @@ TQL EXPLAIN VERBOSE (0, 10, '5s') test AS series; | physical_plan after FilterPushdown(Post)_| SAME TEXT AS ABOVE_| | physical_plan after WindowedSortRule_| SAME TEXT AS ABOVE_| | physical_plan after MatchesConstantTerm_| SAME TEXT AS ABOVE_| +| physical_plan after EnsureGlobalLimitForFetch_| SAME TEXT AS ABOVE_| | physical_plan after RemoveDuplicateRule_| SAME TEXT AS ABOVE_| | physical_plan after SanityCheckPlan_| SAME TEXT AS ABOVE_| | physical_plan_| CooperativeExec_| @@ -578,6 +580,7 @@ TQL EXPLAIN VERBOSE (0, 10, '5s') test_nano; | physical_plan after FilterPushdown(Post)_| SAME TEXT AS ABOVE_| | physical_plan after WindowedSortRule_| SAME TEXT AS ABOVE_| | physical_plan after MatchesConstantTerm_| SAME TEXT AS ABOVE_| +| physical_plan after EnsureGlobalLimitForFetch_| SAME TEXT AS ABOVE_| | physical_plan after RemoveDuplicateRule_| SAME TEXT AS ABOVE_| | physical_plan after SanityCheckPlan_| SAME TEXT AS ABOVE_| | physical_plan_| PromInstantManipulateExec: range=[0..10000], lookback=[300000], interval=[5000], time index=[j]_|