fix multi-field join case

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>
This commit is contained in:
Ruihang Xia
2026-04-01 06:50:17 +08:00
parent 6035a9bf7b
commit dcd5ab82d2
2 changed files with 284 additions and 57 deletions

View File

@@ -21,7 +21,6 @@ use arrow_schema::{Schema as ArrowSchema, SchemaRef as ArrowSchemaRef, SortOptio
use async_stream::stream;
use common_catalog::parse_catalog_and_schema_from_db_string;
use common_plugins::GREPTIME_EXEC_READ_COST;
use common_query::prelude::greptime_timestamp;
use common_query::request::QueryRequest;
use common_recordbatch::adapter::RecordBatchMetrics;
use common_telemetry::tracing_context::TracingContext;
@@ -45,6 +44,9 @@ use futures_util::StreamExt;
use greptime_proto::v1::region::RegionRequestHeader;
use meter_core::data::ReadItem;
use meter_macros::read_meter;
use promql::extension_plan::{
InstantManipulate, RangeManipulate, ScalarCalculate, SeriesDivide, SeriesNormalize,
};
use session::context::QueryContextRef;
use store_api::metric_engine_consts::DATA_SCHEMA_TSID_COLUMN_NAME;
use store_api::storage::RegionId;
@@ -194,73 +196,92 @@ impl MergeScanExec {
plan: &LogicalPlan,
partition_cols: &AliasMapping,
) -> Option<Vec<Arc<dyn datafusion_physical_expr::PhysicalExpr>>> {
if !Self::is_promql_tsid_ordered_plan(plan) {
return None;
}
Self::promql_tsid_ordered_time_index(plan)?;
let tsid_aliases = partition_cols.get(DATA_SCHEMA_TSID_COLUMN_NAME)?;
let tsid_expr = Self::partition_expr_for_alias(session_state, plan, tsid_aliases.first())?;
Some(vec![tsid_expr])
}
fn is_promql_tsid_ordered_plan(plan: &LogicalPlan) -> bool {
match plan {
fn promql_tsid_ordered_time_index(plan: &LogicalPlan) -> Option<String> {
let time_index_column = match plan {
LogicalPlan::Sort(sort) => {
if sort.expr.len() != 2 {
return false;
return None;
}
let [tsid_sort, time_sort] = sort.expr.as_slice() else {
return false;
return None;
};
Self::is_ascending_nulls_first_sort(tsid_sort, DATA_SCHEMA_TSID_COLUMN_NAME)
&& Self::is_ascending_nulls_first_sort(time_sort, greptime_timestamp())
let tsid_column = Self::ascending_nulls_first_sort_column(tsid_sort)?;
let time_column = Self::ascending_nulls_first_sort_column(time_sort)?;
(tsid_column == DATA_SCHEMA_TSID_COLUMN_NAME).then_some(time_column)
}
LogicalPlan::Projection(projection) => {
Self::is_promql_tsid_ordered_plan(projection.input.as_ref())
Self::promql_tsid_ordered_time_index(projection.input.as_ref())
}
LogicalPlan::Filter(filter) => {
Self::promql_tsid_ordered_time_index(filter.input.as_ref())
}
LogicalPlan::Filter(filter) => Self::is_promql_tsid_ordered_plan(filter.input.as_ref()),
LogicalPlan::SubqueryAlias(alias) => {
Self::is_promql_tsid_ordered_plan(alias.input.as_ref())
Self::promql_tsid_ordered_time_index(alias.input.as_ref())
}
LogicalPlan::Extension(extension)
if matches!(
extension.node.name(),
"PromInstantManipulate"
| "PromSeriesDivide"
| "PromNormalize"
| "PromScalarCalculate"
| "PromRangeManipulate"
) =>
{
LogicalPlan::Extension(extension) if Self::is_promql_passthrough_node(extension) => {
extension
.node
.inputs()
.first()
.is_some_and(|input| Self::is_promql_tsid_ordered_plan(input))
.and_then(|input| Self::promql_tsid_ordered_time_index(input))
}
_ => false,
}
_ => None,
}?;
let schema = plan.schema();
let has_tsid = schema
.index_of_column_by_name(None, DATA_SCHEMA_TSID_COLUMN_NAME)
.is_some();
let has_time_index = schema
.index_of_column_by_name(None, &time_index_column)
.is_some();
(has_tsid && has_time_index).then_some(time_index_column)
}
fn is_ascending_nulls_first_sort(
/// Returns `true` when the concrete type behind `extension.node` is one of the
/// PromQL extension plan nodes `InstantManipulate`, `SeriesDivide`,
/// `SeriesNormalize`, `ScalarCalculate` or `RangeManipulate`.
///
/// The check is done by downcasting through `as_any()`; any other node type
/// yields `false`.
fn is_promql_passthrough_node(extension: &Extension) -> bool {
    let any_node = extension.node.as_any();
    // The type checks are side-effect free, so evaluating all of them up front
    // is equivalent to a short-circuiting `||` chain.
    let type_matches = [
        any_node.is::<InstantManipulate>(),
        any_node.is::<SeriesDivide>(),
        any_node.is::<SeriesNormalize>(),
        any_node.is::<ScalarCalculate>(),
        any_node.is::<RangeManipulate>(),
    ];
    type_matches.into_iter().any(|is_match| is_match)
}
fn ascending_nulls_first_sort_column(
sort_expr: &datafusion_expr::expr::Sort,
column: &str,
) -> bool {
sort_expr.asc
&& sort_expr.nulls_first
&& matches!(
sort_expr.expr.try_as_col(),
Some(col) if col.name == column
)
) -> Option<String> {
(sort_expr.asc && sort_expr.nulls_first)
.then(|| sort_expr.expr.try_as_col().map(|col| col.name.clone()))
.flatten()
}
/// Reports whether the output schema of `plan` resolves `column_name` when it
/// is looked up without a qualifier.
fn schema_exposes_column(plan: &LogicalPlan, column_name: &str) -> bool {
    let output_schema = plan.schema();
    let position = output_schema.index_of_column_by_name(None, column_name);
    position.is_some()
}
pub(crate) fn logical_sort_ordering(
session_state: &SessionState,
plan: &LogicalPlan,
) -> Result<Option<LexOrdering>> {
if Self::is_promql_tsid_ordered_plan(plan) {
if let Some(time_index_column) = Self::promql_tsid_ordered_time_index(plan) {
if !Self::schema_exposes_column(plan, DATA_SCHEMA_TSID_COLUMN_NAME)
|| !Self::schema_exposes_column(plan, &time_index_column)
{
return Ok(None);
}
let tsid_expr = session_state.create_physical_expr(
Expr::Column(ColumnExpr::new_unqualified(
DATA_SCHEMA_TSID_COLUMN_NAME.to_string(),
@@ -268,9 +289,7 @@ impl MergeScanExec {
plan.schema(),
)?;
let time_expr = session_state.create_physical_expr(
Expr::Column(ColumnExpr::new_unqualified(
greptime_timestamp().to_string(),
)),
Expr::Column(ColumnExpr::new_unqualified(time_index_column)),
plan.schema(),
)?;
return Ok(LexOrdering::new(vec![
@@ -660,8 +679,9 @@ mod tests {
use datafusion::execution::SessionStateBuilder;
use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType};
use datafusion_common::ToDFSchema;
use datafusion_expr::{EmptyRelation, LogicalPlan, LogicalPlanBuilder, col};
use datafusion_expr::{EmptyRelation, Extension, LogicalPlan, LogicalPlanBuilder, col};
use datafusion_physical_expr::{EquivalenceProperties, PhysicalExpr};
use promql::extension_plan::{InstantManipulate, SeriesDivide};
use session::ReadPreference;
use session::context::QueryContext;
@@ -714,7 +734,7 @@ mod tests {
};
assert_eq!(partition_count, 32);
assert_eq!(
column_names(&exprs),
column_names(exprs),
vec![DATA_SCHEMA_TSID_COLUMN_NAME, "greptime_timestamp"]
);
}
@@ -741,7 +761,7 @@ mod tests {
panic!("expected hash partitioning");
};
assert_eq!(partition_count, 32);
assert_eq!(column_names(&exprs), vec!["host"]);
assert_eq!(column_names(exprs), vec!["host"]);
}
#[test]
@@ -751,7 +771,7 @@ mod tests {
Field::new("host", DataType::Utf8, true),
Field::new(DATA_SCHEMA_TSID_COLUMN_NAME, DataType::UInt64, false),
Field::new(
greptime_timestamp(),
"ts",
DataType::Timestamp(TimeUnit::Millisecond, None),
false,
),
@@ -767,7 +787,11 @@ mod tests {
BTreeSet::from([ColumnExpr::from_name(DATA_SCHEMA_TSID_COLUMN_NAME)]),
),
]);
let plan = promql_tsid_sorted_plan(schema.clone());
let plan = promql_tsid_sorted_plan(schema.clone(), "ts");
let ordering = MergeScanExec::logical_sort_ordering(&session_state, &plan)
.unwrap()
.unwrap();
let merge_scan = MergeScanExec::new(
&session_state,
@@ -789,6 +813,45 @@ mod tests {
};
assert_eq!(partition_count, 32);
assert_eq!(column_names(exprs), vec![DATA_SCHEMA_TSID_COLUMN_NAME]);
assert_eq!(
ordering_column_names(&ordering),
vec![DATA_SCHEMA_TSID_COLUMN_NAME, "ts"]
);
}
/// `logical_sort_ordering` must report no ordering when the tsid column has
/// been projected away between the tsid-sorted plan and the outermost node.
#[test]
fn logical_sort_ordering_ignores_projected_away_tsid_columns() {
    let fields = vec![
        Field::new("host", DataType::Utf8, true),
        Field::new(DATA_SCHEMA_TSID_COLUMN_NAME, DataType::UInt64, false),
        Field::new(
            "ts",
            DataType::Timestamp(TimeUnit::Millisecond, None),
            false,
        ),
        Field::new("greptime_value", DataType::Float64, true),
    ];
    let schema = Arc::new(Schema::new(fields));
    let session_state = SessionStateBuilder::new().with_default_features().build();

    // Project everything except the tsid column on top of the sorted plan.
    let sorted = promql_tsid_sorted_plan(schema, "ts");
    let kept_columns = vec![col("host"), col("ts"), col("greptime_value")];
    let without_tsid = LogicalPlanBuilder::from(sorted)
        .project(kept_columns)
        .unwrap()
        .build()
        .unwrap();

    // Wrap the projection in an `InstantManipulate` extension node so the
    // plan under test is rooted at a PromQL extension.
    let instant_manipulate = InstantManipulate::new(
        0,
        10,
        1,
        1,
        "ts".to_string(),
        Some("greptime_value".to_string()),
        without_tsid,
    );
    let plan = LogicalPlan::Extension(Extension {
        node: Arc::new(instant_manipulate),
    });

    let ordering = MergeScanExec::logical_sort_ordering(&session_state, &plan).unwrap();
    assert!(ordering.is_none());
}
fn test_merge_scan_exec(
@@ -839,20 +902,41 @@ mod tests {
.collect()
}
fn promql_tsid_sorted_plan(schema: Arc<Schema>) -> LogicalPlan {
/// Collects the column name of every sort expression in `ordering`, in order.
///
/// Panics if any sort expression is not a plain physical `Column`.
fn ordering_column_names(ordering: &LexOrdering) -> Vec<&str> {
    let mut names = Vec::new();
    for sort_expr in ordering.iter() {
        let column = sort_expr
            .expr
            .as_any()
            .downcast_ref::<Column>()
            .unwrap();
        names.push(column.name());
    }
    names
}
fn promql_tsid_sorted_plan(schema: Arc<Schema>, time_index: &str) -> LogicalPlan {
let input = LogicalPlan::EmptyRelation(EmptyRelation {
produce_one_row: false,
schema: schema.to_dfschema_ref().unwrap(),
});
LogicalPlanBuilder::from(input)
let sorted = LogicalPlanBuilder::from(input)
.sort(vec![
col(DATA_SCHEMA_TSID_COLUMN_NAME).sort(true, true),
col(greptime_timestamp()).sort(true, true),
col(time_index).sort(true, true),
])
.unwrap()
.build()
.unwrap()
.unwrap();
LogicalPlan::Extension(Extension {
node: Arc::new(SeriesDivide::new(
vec!["host".to_string()],
time_index.to_string(),
sorted,
)),
})
}
}

View File

@@ -727,7 +727,18 @@ impl PromPlanner {
self.ctx.table_name = Some("rhs".to_string());
}
}
let mut field_columns = left_field_columns.iter().zip(right_field_columns.iter());
let field_columns = left_field_columns
.iter()
.zip(right_field_columns.iter())
.collect::<Vec<_>>();
// PromQL binary arithmetic only combines the shared prefix of value columns.
// Keep the output field count aligned with that zipped prefix so planning
// remains stable even when the two sides have uneven multi-field schemas.
self.ctx.field_columns = field_columns
.iter()
.map(|(left_col_name, _)| (*left_col_name).clone())
.collect();
let mut field_columns = field_columns.into_iter();
let join_plan = self.join_on_non_field_columns(
left_input,
@@ -849,9 +860,17 @@ impl PromPlanner {
}
}
let mut field_columns = repeated_field_columns
let field_columns = repeated_field_columns
.iter()
.zip(other_field_columns.iter());
.zip(other_field_columns.iter())
.collect::<Vec<_>>();
// The collapsed fast path must preserve the same zipped-field semantics as the
// original two-step plan: only the shared prefix of value columns participates.
self.ctx.field_columns = field_columns
.iter()
.map(|(repeated_col_name, _)| (*repeated_col_name).clone())
.collect();
let mut field_columns = field_columns.into_iter();
let join_plan = self.join_on_non_field_columns(
repeated_input,
@@ -4452,11 +4471,27 @@ mod test {
table_name_tuples: &[(String, String)],
num_tag: usize,
num_field: usize,
) -> DfTableSourceProvider {
let table_specs = table_name_tuples
.iter()
.map(|(schema_name, table_name)| ((schema_name.clone(), table_name.clone()), num_field))
.collect::<Vec<_>>();
build_test_table_provider_with_tsid_fields(&table_specs, num_tag).await
}
async fn build_test_table_provider_with_tsid_fields(
table_specs: &[((String, String), usize)],
num_tag: usize,
) -> DfTableSourceProvider {
let catalog_list = MemoryCatalogManager::with_default_setup();
let physical_table_name = "phy";
let physical_table_id = 999u32;
let physical_num_field = table_specs
.iter()
.map(|(_, num_field)| *num_field)
.max()
.unwrap_or(0);
// Register a metric engine physical table with internal columns.
{
@@ -4487,7 +4522,7 @@ mod test {
)
.with_time_index(true),
);
for i in 0..num_field {
for i in 0..physical_num_field {
columns.push(ColumnSchema::new(
format!("field_{i}"),
ConcreteDataType::float64_datatype(),
@@ -4500,7 +4535,7 @@ mod test {
let table_meta = TableMetaBuilder::empty()
.schema(schema)
.primary_key_indices(primary_key_indices)
.value_indices((2 + num_tag..2 + num_tag + 1 + num_field).collect())
.value_indices((2 + num_tag..2 + num_tag + 1 + physical_num_field).collect())
.engine(METRIC_ENGINE_NAME.to_string())
.next_column_id(1024)
.build()
@@ -4527,7 +4562,7 @@ mod test {
}
// Register metric engine logical tables without `__tsid`, referencing the physical table.
for (idx, (schema_name, table_name)) in table_name_tuples.iter().enumerate() {
for (idx, ((schema_name, table_name), num_field)) in table_specs.iter().enumerate() {
let mut columns = vec![];
for i in 0..num_tag {
columns.push(ColumnSchema::new(
@@ -4544,7 +4579,7 @@ mod test {
)
.with_time_index(true),
);
for i in 0..num_field {
for i in 0..*num_field {
columns.push(ColumnSchema::new(
format!("field_{i}"),
ConcreteDataType::float64_datatype(),
@@ -4562,7 +4597,7 @@ mod test {
let table_meta = TableMetaBuilder::empty()
.schema(schema)
.primary_key_indices((0..num_tag).collect())
.value_indices((num_tag + 1..num_tag + 1 + num_field).collect())
.value_indices((num_tag + 1..num_tag + 1 + *num_field).collect())
.engine(METRIC_ENGINE_NAME.to_string())
.options(options)
.next_column_id(1024)
@@ -5110,6 +5145,114 @@ mod test {
assert!(!plan_str.contains("tag_0 ="), "{plan_str}");
}
/// Plans `((two_field_metric - one_field_metric) / one_field_metric) * 100`
/// against tables registered with two and one field columns respectively, and
/// checks that the resulting schema carries exactly one column besides
/// `tag_0`, `timestamp` and the tsid column.
#[tokio::test]
async fn repeated_tsid_binary_operand_uses_shorter_field_side() {
    let query = "((two_field_metric - one_field_metric) / one_field_metric) * 100";
    let eval_stmt = EvalStmt {
        expr: parser::parse(query).unwrap(),
        start: UNIX_EPOCH,
        end: UNIX_EPOCH
            .checked_add(Duration::from_secs(100_000))
            .unwrap(),
        interval: Duration::from_secs(5),
        lookback_delta: Duration::from_secs(1),
    };

    // Each entry is ((schema, table), number of field columns).
    let table_specs = [
        (
            (
                DEFAULT_SCHEMA_NAME.to_string(),
                "two_field_metric".to_string(),
            ),
            2,
        ),
        (
            (
                DEFAULT_SCHEMA_NAME.to_string(),
                "one_field_metric".to_string(),
            ),
            1,
        ),
    ];
    let table_provider = build_test_table_provider_with_tsid_fields(&table_specs, 1).await;

    let plan =
        PromPlanner::stmt_to_plan(table_provider, &eval_stmt, &build_query_engine_state())
            .await
            .unwrap();

    let field_names: Vec<String> = plan
        .schema()
        .fields()
        .iter()
        .map(|field| field.name().clone())
        .collect();
    // Everything that is not a tag, the time index or the tsid counts as a value column.
    let non_value_columns = ["tag_0", "timestamp", DATA_SCHEMA_TSID_COLUMN_NAME];
    let value_columns = field_names
        .iter()
        .filter(|name| !non_value_columns.contains(&name.as_str()))
        .count();
    assert_eq!(value_columns, 1, "{field_names:?}");
}
/// Plans `one_field_metric / two_field_metric` against tables registered with
/// one and two field columns respectively, and checks that the resulting
/// schema carries exactly one column besides `tag_0`, `timestamp` and the
/// tsid column.
#[tokio::test]
async fn tsid_binary_join_uses_shorter_field_side() {
    let query = "one_field_metric / two_field_metric";
    let eval_stmt = EvalStmt {
        expr: parser::parse(query).unwrap(),
        start: UNIX_EPOCH,
        end: UNIX_EPOCH
            .checked_add(Duration::from_secs(100_000))
            .unwrap(),
        interval: Duration::from_secs(5),
        lookback_delta: Duration::from_secs(1),
    };

    // Each entry is ((schema, table), number of field columns).
    let table_specs = [
        (
            (
                DEFAULT_SCHEMA_NAME.to_string(),
                "one_field_metric".to_string(),
            ),
            1,
        ),
        (
            (
                DEFAULT_SCHEMA_NAME.to_string(),
                "two_field_metric".to_string(),
            ),
            2,
        ),
    ];
    let table_provider = build_test_table_provider_with_tsid_fields(&table_specs, 1).await;

    let plan =
        PromPlanner::stmt_to_plan(table_provider, &eval_stmt, &build_query_engine_state())
            .await
            .unwrap();

    let field_names: Vec<String> = plan
        .schema()
        .fields()
        .iter()
        .map(|field| field.name().clone())
        .collect();
    // Everything that is not a tag, the time index or the tsid counts as a value column.
    let non_value_columns = ["tag_0", "timestamp", DATA_SCHEMA_TSID_COLUMN_NAME];
    let value_columns = field_names
        .iter()
        .filter(|name| !non_value_columns.contains(&name.as_str()))
        .count();
    assert_eq!(value_columns, 1, "{field_names:?}");
}
#[tokio::test]
async fn label_matching_modifier_disables_tsid_binary_join() {
let prom_expr = parser::parse("some_metric / ignoring(tag_0) some_alt_metric").unwrap();