diff --git a/src/promql/src/extension_plan/histogram_fold.rs b/src/promql/src/extension_plan/histogram_fold.rs
index f43184a0a7..b8c04e597a 100644
--- a/src/promql/src/extension_plan/histogram_fold.rs
+++ b/src/promql/src/extension_plan/histogram_fold.rs
@@ -106,11 +106,15 @@ impl UserDefinedLogicalNodeCore for HistogramFold {
     }
 
     fn necessary_children_exprs(&self, output_columns: &[usize]) -> Option<Vec<Vec<usize>>> {
-        let le_column_index = self
-            .input
-            .schema()
-            .index_of_column_by_name(None, &self.le_column)?;
-        let necessary_indices = output_columns
+        let input_schema = self.input.schema();
+        let le_column_index = input_schema.index_of_column_by_name(None, &self.le_column)?;
+
+        if output_columns.is_empty() {
+            let indices = (0..input_schema.fields().len()).collect::<Vec<_>>();
+            return Some(vec![indices]);
+        }
+
+        let mut necessary_indices = output_columns
             .iter()
             .map(|&output_column| {
                 if output_column < le_column_index {
@@ -119,7 +123,10 @@ impl UserDefinedLogicalNodeCore for HistogramFold {
                     output_column + 1
                 }
             })
-            .collect();
+            .collect::<Vec<_>>();
+        necessary_indices.push(le_column_index);
+        necessary_indices.sort_unstable();
+        necessary_indices.dedup();
 
         Some(vec![necessary_indices])
     }
@@ -1030,11 +1037,26 @@ mod test {
     use datafusion::common::ToDFSchema;
     use datafusion::datasource::memory::MemorySourceConfig;
     use datafusion::datasource::source::DataSourceExec;
+    use datafusion::logical_expr::EmptyRelation;
     use datafusion::prelude::SessionContext;
     use datatypes::arrow_array::StringArray;
+    use futures::FutureExt;
 
     use super::*;
 
+    fn project_batch(batch: &RecordBatch, indices: &[usize]) -> RecordBatch {
+        let fields = indices
+            .iter()
+            .map(|&idx| batch.schema().field(idx).clone())
+            .collect::<Vec<_>>();
+        let columns = indices
+            .iter()
+            .map(|&idx| batch.column(idx).clone())
+            .collect::<Vec<_>>();
+        let schema = Arc::new(Schema::new(fields));
+        RecordBatch::try_new(schema, columns).unwrap()
+    }
+
     fn prepare_test_data() -> DataSourceExec {
         let schema = Arc::new(Schema::new(vec![
             Field::new("host", DataType::Utf8, true),
@@ -1222,6 +1244,100 @@
         assert_eq!(result_literal, expected);
     }
 
+    #[tokio::test]
+    async fn pruning_should_keep_le_column_for_exec() {
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("ts", DataType::Timestamp(TimeUnit::Millisecond, None), true),
+            Field::new("le", DataType::Utf8, true),
+            Field::new("val", DataType::Float64, true),
+        ]));
+        let df_schema = schema.clone().to_dfschema_ref().unwrap();
+        let input = LogicalPlan::EmptyRelation(EmptyRelation {
+            produce_one_row: false,
+            schema: df_schema,
+        });
+        let plan = HistogramFold::new(
+            "le".to_string(),
+            "val".to_string(),
+            "ts".to_string(),
+            0.5,
+            input,
+        )
+        .unwrap();
+
+        let output_columns = [0usize, 1usize];
+        let required = plan.necessary_children_exprs(&output_columns).unwrap();
+        let required = &required[0];
+        assert_eq!(required.as_slice(), &[0, 1, 2]);
+
+        let input_batch = RecordBatch::try_new(
+            schema,
+            vec![
+                Arc::new(TimestampMillisecondArray::from(vec![0, 0])),
+                Arc::new(StringArray::from(vec!["0.1", "+Inf"])),
+                Arc::new(Float64Array::from(vec![1.0, 2.0])),
+            ],
+        )
+        .unwrap();
+        let projected = project_batch(&input_batch, required);
+        let projected_schema = projected.schema();
+        let memory_exec = Arc::new(DataSourceExec::new(Arc::new(
+            MemorySourceConfig::try_new(&[vec![projected]], projected_schema, None).unwrap(),
+        )));
+
+        let fold_exec = plan.to_execution_plan(memory_exec);
+        let session_context = SessionContext::default();
+        let output_batches =
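Note on the histogram_fold.rs change above: `HistogramFold`'s output schema drops the `le` column, so an output index at or beyond `le_column_index` is shifted by one relative to the child schema, and the child must keep producing `le` itself for the fold to work. The corrected mapping can be viewed in isolation with the sketch below; the free function and `main` are illustrative, not part of the patch.

```rust
/// Map output-column indices (expressed against the folded schema, which
/// no longer contains `le`) back to child-input indices, always retaining
/// the `le` column itself.
fn child_indices(output_columns: &[usize], le_column_index: usize, input_len: usize) -> Vec<usize> {
    // Mirror of the guard added in the patch: an empty projection list
    // means the child must keep every column.
    if output_columns.is_empty() {
        return (0..input_len).collect();
    }
    let mut indices: Vec<usize> = output_columns
        .iter()
        .map(|&i| if i < le_column_index { i } else { i + 1 })
        .collect();
    indices.push(le_column_index); // the exec cannot fold buckets without `le`
    indices.sort_unstable();
    indices.dedup();
    indices
}

fn main() {
    // Child schema: [ts, le, val]; folded output schema: [ts, val].
    // Requesting output columns 0 and 1 must keep child columns 0, 1 and 2.
    assert_eq!(child_indices(&[0, 1], 1, 3), vec![0, 1, 2]);
}
```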
+            datafusion::physical_plan::collect(fold_exec, session_context.task_ctx())
+                .await
+                .unwrap();
+        assert_eq!(output_batches.len(), 1);
+
+        let output_batch = &output_batches[0];
+        assert_eq!(output_batch.num_rows(), 1);
+
+        let ts = output_batch
+            .column(0)
+            .as_any()
+            .downcast_ref::<TimestampMillisecondArray>()
+            .unwrap();
+        assert_eq!(ts.values(), &[0i64]);
+
+        let values = output_batch
+            .column(1)
+            .as_any()
+            .downcast_ref::<Float64Array>()
+            .unwrap();
+        assert!((values.value(0) - 0.1).abs() < 1e-12);
+
+        // Simulate the pre-fix pruning behavior: omit the `le` column from the child input.
+        let le_index = 1usize;
+        let broken_required = output_columns
+            .iter()
+            .map(|&output_column| {
+                if output_column < le_index {
+                    output_column
+                } else {
+                    output_column + 1
+                }
+            })
+            .collect::<Vec<_>>();
+
+        let broken = project_batch(&input_batch, &broken_required);
+        let broken_schema = broken.schema();
+        let broken_exec = Arc::new(DataSourceExec::new(Arc::new(
+            MemorySourceConfig::try_new(&[vec![broken]], broken_schema, None).unwrap(),
+        )));
+        let broken_fold_exec = plan.to_execution_plan(broken_exec);
+        let session_context = SessionContext::default();
+        let broken_result = std::panic::AssertUnwindSafe(async {
+            datafusion::physical_plan::collect(broken_fold_exec, session_context.task_ctx()).await
+        })
+        .catch_unwind()
+        .await;
+        assert!(broken_result.is_err());
+    }
+
     #[test]
     fn confirm_schema() {
         let input_schema = Schema::new(vec![
diff --git a/src/promql/src/extension_plan/range_manipulate.rs b/src/promql/src/extension_plan/range_manipulate.rs
index 15f0c9d4e4..dcbcaafd45 100644
--- a/src/promql/src/extension_plan/range_manipulate.rs
+++ b/src/promql/src/extension_plan/range_manipulate.rs
@@ -18,7 +18,7 @@ use std::pin::Pin;
 use std::sync::Arc;
 use std::task::{Context, Poll};
 
-use common_telemetry::debug;
+use common_telemetry::{debug, warn};
 use datafusion::arrow::array::{Array, ArrayRef, Int64Array, TimestampMillisecondArray};
 use datafusion::arrow::compute;
 use datafusion::arrow::datatypes::{Field, SchemaRef};
@@ -316,6 +316,10 @@ impl UserDefinedLogicalNodeCore for RangeManipulate {
             // Derived timestamp range column.
             required.push(time_index_idx);
         } else {
+            warn!(
+                "Output column index {} is out of bounds for input schema with length {}",
+                idx, input_len
+            );
             return None;
         }
     }
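The range_manipulate.rs hunk above is purely diagnostic: the out-of-bounds branch previously returned `None` silently, which made it hard to tell why projection pushdown was being skipped. A minimal standalone sketch of the same guard-then-decline pattern, with `eprintln!` standing in for `common_telemetry::warn!` and all names illustrative:

```rust
/// Validate requested output columns against a child schema of `input_len`
/// columns, logging context before declining to prune.
fn checked_indices(output_columns: &[usize], input_len: usize) -> Option<Vec<usize>> {
    let mut required = Vec::with_capacity(output_columns.len());
    for &idx in output_columns {
        if idx >= input_len {
            // Returning `None` tells the optimizer to keep all columns;
            // the log line records why pruning was skipped.
            eprintln!(
                "Output column index {idx} is out of bounds for input schema with length {input_len}"
            );
            return None;
        }
        required.push(idx);
    }
    Some(required)
}

fn main() {
    assert_eq!(checked_indices(&[0, 2], 3), Some(vec![0, 2]));
    assert_eq!(checked_indices(&[5], 3), None); // logs, then skips pruning
}
```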
diff --git a/src/promql/src/extension_plan/scalar_calculate.rs b/src/promql/src/extension_plan/scalar_calculate.rs
index dd64c47245..6a5f0fb0d8 100644
--- a/src/promql/src/extension_plan/scalar_calculate.rs
+++ b/src/promql/src/extension_plan/scalar_calculate.rs
@@ -279,7 +279,7 @@ impl UserDefinedLogicalNodeCore for ScalarCalculate {
             .collect()
     }
 
-    fn necessary_children_exprs(&self, output_columns: &[usize]) -> Option<Vec<Vec<usize>>> {
+    fn necessary_children_exprs(&self, _output_columns: &[usize]) -> Option<Vec<Vec<usize>>> {
         if self.unfix.is_some() {
             return None;
         }
@@ -288,15 +288,10 @@
         let time_index_idx = input_schema.index_of_column_by_name(None, &self.time_index)?;
         let field_column_idx = input_schema.index_of_column_by_name(None, &self.field_column)?;
 
-        let mut required = Vec::with_capacity(2);
-        if output_columns.contains(&0) {
-            required.push(time_index_idx);
-        }
-        if output_columns.contains(&1) {
-            required.push(field_column_idx);
-        }
-        if required.is_empty() {
-            required.extend([time_index_idx, field_column_idx]);
+        let mut required = Vec::with_capacity(2 + self.tag_columns.len());
+        required.extend([time_index_idx, field_column_idx]);
+        for tag in &self.tag_columns {
+            required.push(input_schema.index_of_column_by_name(None, tag)?);
         }
 
         required.sort_unstable();
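Requesting the tag columns unconditionally here is deliberate: `ScalarCalculate`'s output is only `[time index, scalar value]`, but at runtime the exec presumably needs the tag values to tell whether more than one series is present (PromQL's `scalar()` must yield NaN in that case), so the tags have to survive pruning even though they never appear in the output. A standalone sketch of the index computation, using a plain name-to-index map in place of the `DFSchema` lookup (all names illustrative):

```rust
use std::collections::HashMap;

/// Collect the child indices ScalarCalculate always needs: the time index,
/// the field column, and every tag column, sorted and deduplicated.
fn required_indices(
    name_to_index: &HashMap<&str, usize>,
    time_index: &str,
    field_column: &str,
    tag_columns: &[&str],
) -> Option<Vec<usize>> {
    let mut required = Vec::with_capacity(2 + tag_columns.len());
    required.push(*name_to_index.get(time_index)?);
    required.push(*name_to_index.get(field_column)?);
    for tag in tag_columns {
        required.push(*name_to_index.get(tag)?);
    }
    required.sort_unstable();
    required.dedup();
    Some(required)
}

fn main() {
    let schema: HashMap<&str, usize> =
        [("ts", 0), ("tag1", 1), ("tag2", 2), ("val", 3), ("extra", 4)].into();
    // `extra` is the only prunable column; everything else is required.
    assert_eq!(
        required_indices(&schema, "ts", "val", &["tag1", "tag2"]),
        Some(vec![0, 1, 2, 3])
    );
}
```

The test hunk that follows exercises exactly this expectation against the real plan node.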
@@ -653,6 +648,109 @@ mod test {
 
     use super::*;
 
+    fn project_batch(batch: &RecordBatch, indices: &[usize]) -> RecordBatch {
+        let fields = indices
+            .iter()
+            .map(|&idx| batch.schema().field(idx).clone())
+            .collect::<Vec<_>>();
+        let columns = indices
+            .iter()
+            .map(|&idx| batch.column(idx).clone())
+            .collect::<Vec<_>>();
+        let schema = Arc::new(Schema::new(fields));
+        RecordBatch::try_new(schema, columns).unwrap()
+    }
+
+    #[test]
+    fn necessary_children_exprs_preserve_tag_columns() {
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("ts", DataType::Timestamp(TimeUnit::Millisecond, None), true),
+            Field::new("tag1", DataType::Utf8, true),
+            Field::new("tag2", DataType::Utf8, true),
+            Field::new("val", DataType::Float64, true),
+            Field::new("extra", DataType::Utf8, true),
+        ]));
+        let schema = Arc::new(DFSchema::try_from(schema).unwrap());
+        let input = LogicalPlan::EmptyRelation(EmptyRelation {
+            produce_one_row: false,
+            schema,
+        });
+        let tag_columns = vec!["tag1".to_string(), "tag2".to_string()];
+        let plan = ScalarCalculate::new(0, 1, 1, input, "ts", &tag_columns, "val", None).unwrap();
+
+        let required = plan.necessary_children_exprs(&[0, 1]).unwrap();
+        assert_eq!(required, vec![vec![0, 1, 2, 3]]);
+    }
+
+    #[tokio::test]
+    async fn pruning_should_keep_tag_columns_for_exec() {
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("ts", DataType::Timestamp(TimeUnit::Millisecond, None), true),
+            Field::new("tag1", DataType::Utf8, true),
+            Field::new("tag2", DataType::Utf8, true),
+            Field::new("val", DataType::Float64, true),
+            Field::new("extra", DataType::Utf8, true),
+        ]));
+        let df_schema = Arc::new(DFSchema::try_from(schema.clone()).unwrap());
+        let input = LogicalPlan::EmptyRelation(EmptyRelation {
+            produce_one_row: false,
+            schema: df_schema,
+        });
+        let tag_columns = vec!["tag1".to_string(), "tag2".to_string()];
+        let plan =
+            ScalarCalculate::new(0, 15_000, 5000, input, "ts", &tag_columns, "val", None).unwrap();
+
+        let required = plan.necessary_children_exprs(&[0, 1]).unwrap();
+        let required = &required[0];
+
+        let batch = RecordBatch::try_new(
+            schema,
+            vec![
+                Arc::new(TimestampMillisecondArray::from(vec![
+                    0, 5_000, 10_000, 15_000,
+                ])),
+                Arc::new(StringArray::from(vec!["foo", "foo", "foo", "foo"])),
+                Arc::new(StringArray::from(vec!["bar", "bar", "bar", "bar"])),
+                Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])),
+                Arc::new(StringArray::from(vec!["x", "x", "x", "x"])),
+            ],
+        )
+        .unwrap();
+
+        let projected_batch = project_batch(&batch, required);
+        let projected_schema = projected_batch.schema();
+        let memory_exec = Arc::new(DataSourceExec::new(Arc::new(
+            MemorySourceConfig::try_new(&[vec![projected_batch]], projected_schema, None).unwrap(),
+        )));
+        let scalar_exec = plan.to_execution_plan(memory_exec).unwrap();
+
+        let session_context = SessionContext::default();
+        let result = datafusion::physical_plan::collect(scalar_exec, session_context.task_ctx())
+            .await
+            .unwrap();
+
+        assert_eq!(result.len(), 1);
+        let batch = &result[0];
+        assert_eq!(batch.num_columns(), 2);
+        assert_eq!(batch.num_rows(), 4);
+        assert_eq!(batch.schema().field(0).name(), "ts");
+        assert_eq!(batch.schema().field(1).name(), "scalar(val)");
+
+        let ts = batch
+            .column(0)
+            .as_any()
+            .downcast_ref::<TimestampMillisecondArray>()
+            .unwrap();
+        assert_eq!(ts.values(), &[0i64, 5_000, 10_000, 15_000]);
+
+        let values = batch
+            .column(1)
+            .as_any()
+            .downcast_ref::<Float64Array>()
+            .unwrap();
+        assert_eq!(values.values(), &[1.0f64, 2.0, 3.0, 4.0]);
+    }
+
     fn prepare_test_data(series: Vec<RecordBatch>) -> DataSourceExec {
         let schema = Arc::new(Schema::new(vec![
             Field::new("ts", DataType::Timestamp(TimeUnit::Millisecond, None), true),
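A note on the testing approach shared by both test modules: `project_batch` re-projects the input `RecordBatch` to emulate whatever column set the optimizer hands the child, and the "broken" variant in histogram_fold asserts that executing against a mis-pruned child fails. Because that failure surfaces as a panic inside the operator rather than as an `Err`, the test wraps the future in `std::panic::AssertUnwindSafe` and uses `futures::FutureExt::catch_unwind`. Here is that pattern in isolation; the panicking future is a stand-in for collecting the mis-pruned plan, and the sketch assumes `tokio` and `futures` as dependencies (both already in scope for these tests):

```rust
use futures::FutureExt;

#[tokio::main]
async fn main() {
    // Stand-in for collecting a plan whose child lost a required column.
    let doomed = async {
        panic!("missing `le` column");
    };

    // `AssertUnwindSafe` is needed because the captured state is not
    // statically unwind-safe; `catch_unwind` converts a panic raised while
    // polling the future into an `Err` the test can assert on.
    let result = std::panic::AssertUnwindSafe(doomed).catch_unwind().await;
    assert!(result.is_err());
}
```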