diff --git a/src/promql/src/extension_plan/series_divide.rs b/src/promql/src/extension_plan/series_divide.rs
index 4085f0d44f..3fdc1ddf6e 100644
--- a/src/promql/src/extension_plan/series_divide.rs
+++ b/src/promql/src/extension_plan/series_divide.rs
@@ -17,8 +17,10 @@ use std::pin::Pin;
 use std::sync::Arc;
 use std::task::{Context, Poll};
 
-use datafusion::arrow::array::{Array, StringArray};
-use datafusion::arrow::datatypes::SchemaRef;
+use datafusion::arrow::array::{
+    Array, LargeStringArray, StringArray, StringViewArray, UInt64Array,
+};
+use datafusion::arrow::datatypes::{DataType, SchemaRef};
 use datafusion::arrow::record_batch::RecordBatch;
 use datafusion::common::{DFSchema, DFSchemaRef};
 use datafusion::error::Result as DataFusionResult;
@@ -45,6 +47,149 @@ use crate::error::{DeserializeSnafu, Result};
 use crate::extension_plan::{METRIC_NUM_SERIES, resolve_column_name, serialize_column_index};
 use crate::metrics::PROMQL_SERIES_COUNT;
 
+enum TagIdentifier<'a> {
+    /// A group of raw string tag columns.
+    Raw(Vec<RawTagColumn<'a>>),
+    /// A single UInt64 identifier (tsid).
+    Id(&'a UInt64Array),
+}
+
+impl<'a> TagIdentifier<'a> {
+    fn try_new(batch: &'a RecordBatch, tag_indices: &[usize]) -> DataFusionResult<Self> {
+        match tag_indices {
+            [] => Ok(Self::Raw(Vec::new())),
+            [index] => {
+                let array = batch.column(*index);
+                if array.data_type() == &DataType::UInt64 {
+                    let array = array
+                        .as_any()
+                        .downcast_ref::<UInt64Array>()
+                        .ok_or_else(|| {
+                            datafusion::error::DataFusionError::Internal(
+                                "Failed to downcast tag column to UInt64Array".to_string(),
+                            )
+                        })?;
+                    Ok(Self::Id(array))
+                } else {
+                    Ok(Self::Raw(vec![RawTagColumn::try_new(array.as_ref())?]))
+                }
+            }
+            indices => Ok(Self::Raw(
+                indices
+                    .iter()
+                    .map(|index| RawTagColumn::try_new(batch.column(*index).as_ref()))
+                    .collect::<DataFusionResult<Vec<_>>>()?,
+            )),
+        }
+    }
+
+    fn equal_at(&self, left_row: usize, other: &Self, right_row: usize) -> DataFusionResult<bool> {
+        match (self, other) {
+            (Self::Id(left), Self::Id(right)) => {
+                if left.is_null(left_row) || right.is_null(right_row) {
+                    return Ok(left.is_null(left_row) && right.is_null(right_row));
+                }
+                Ok(left.value(left_row) == right.value(right_row))
+            }
+            (Self::Raw(left), Self::Raw(right)) => {
+                if left.len() != right.len() {
+                    return Err(datafusion::error::DataFusionError::Internal(format!(
+                        "Mismatched tag column count: left={}, right={}",
+                        left.len(),
+                        right.len()
+                    )));
+                }
+
+                for (left_column, right_column) in left.iter().zip(right.iter()) {
+                    if !left_column.equal_at(left_row, right_column, right_row)? {
+                        return Ok(false);
+                    }
+                }
+                Ok(true)
+            }
+            _ => Err(datafusion::error::DataFusionError::Internal(format!(
+                "Mismatched tag identifier types: left={:?}, right={:?}",
+                self.data_type(),
+                other.data_type()
+            ))),
+        }
+    }
+
+    fn data_type(&self) -> &'static str {
+        match self {
+            Self::Raw(_) => "Raw",
+            Self::Id(_) => "Id",
+        }
+    }
+}
+
+enum RawTagColumn<'a> {
+    Utf8(&'a StringArray),
+    LargeUtf8(&'a LargeStringArray),
+    Utf8View(&'a StringViewArray),
+}
+
+impl<'a> RawTagColumn<'a> {
+    fn try_new(array: &'a dyn Array) -> DataFusionResult<Self> {
+        match array.data_type() {
+            DataType::Utf8 => array
+                .as_any()
+                .downcast_ref::<StringArray>()
+                .map(Self::Utf8)
+                .ok_or_else(|| {
+                    datafusion::error::DataFusionError::Internal(
+                        "Failed to downcast tag column to StringArray".to_string(),
+                    )
+                }),
+            DataType::LargeUtf8 => array
+                .as_any()
+                .downcast_ref::<LargeStringArray>()
+                .map(Self::LargeUtf8)
+                .ok_or_else(|| {
+                    datafusion::error::DataFusionError::Internal(
+                        "Failed to downcast tag column to LargeStringArray".to_string(),
+                    )
+                }),
+            DataType::Utf8View => array
+                .as_any()
+                .downcast_ref::<StringViewArray>()
+                .map(Self::Utf8View)
+                .ok_or_else(|| {
+                    datafusion::error::DataFusionError::Internal(
+                        "Failed to downcast tag column to StringViewArray".to_string(),
+                    )
+                }),
+            other => Err(datafusion::error::DataFusionError::Internal(format!(
+                "Unsupported tag column type: {other:?}"
+            ))),
+        }
+    }
+
+    fn is_null(&self, row: usize) -> bool {
+        match self {
+            Self::Utf8(array) => array.is_null(row),
+            Self::LargeUtf8(array) => array.is_null(row),
+            Self::Utf8View(array) => array.is_null(row),
+        }
+    }
+
+    fn value(&self, row: usize) -> &str {
+        match self {
+            Self::Utf8(array) => array.value(row),
+            Self::LargeUtf8(array) => array.value(row),
+            Self::Utf8View(array) => array.value(row),
+        }
+    }
+
+    fn equal_at(&self, left_row: usize, other: &Self, right_row: usize) -> DataFusionResult<bool> {
+        if self.is_null(left_row) || other.is_null(right_row) {
+            return Ok(self.is_null(left_row) && other.is_null(right_row));
+        }
+
+        Ok(self.value(left_row) == other.value(right_row))
+    }
+}
+
 #[derive(Debug, PartialEq, Eq, Hash, PartialOrd)]
 pub struct SeriesDivide {
     tag_columns: Vec<String>,
@@ -481,90 +626,37 @@ impl SeriesDivideStream {
         for batch in &self.buffer[resumed_batch_index..] {
             let num_rows = batch.num_rows();
-            let mut result_index = num_rows;
+            let tags = TagIdentifier::try_new(batch, &self.tag_indices)?;
 
             // check if the first row is the same with last batch's last row
             if resumed_batch_index > self.inspect_start.checked_sub(1).unwrap_or_default() {
                 let last_batch = &self.buffer[resumed_batch_index - 1];
                 let last_row = last_batch.num_rows() - 1;
-                for index in &self.tag_indices {
-                    let current_array = batch.column(*index);
-                    let last_array = last_batch.column(*index);
-                    let current_string_array = current_array
-                        .as_any()
-                        .downcast_ref::<StringArray>()
-                        .ok_or_else(|| {
-                            datafusion::error::DataFusionError::Internal(
-                                "Failed to downcast tag column to StringArray".to_string(),
-                            )
-                        })?;
-                    let last_string_array = last_array
-                        .as_any()
-                        .downcast_ref::<StringArray>()
-                        .ok_or_else(|| {
-                            datafusion::error::DataFusionError::Internal(
-                                "Failed to downcast tag column to StringArray".to_string(),
-                            )
-                        })?;
-                    let current_value = current_string_array.value(0);
-                    let last_value = last_string_array.value(last_row);
-                    if current_value != last_value {
-                        return Ok(Some((resumed_batch_index - 1, last_batch.num_rows() - 1)));
-                    }
+                let last_tags = TagIdentifier::try_new(last_batch, &self.tag_indices)?;
+                if !tags.equal_at(0, &last_tags, last_row)? {
+                    return Ok(Some((resumed_batch_index - 1, last_row)));
                 }
             }
 
             // quick check if all rows are the same by comparing the first and last row in this batch
-            let mut all_same = true;
-            for index in &self.tag_indices {
-                let array = batch.column(*index);
-                let string_array =
-                    array
-                        .as_any()
-                        .downcast_ref::<StringArray>()
-                        .ok_or_else(|| {
-                            datafusion::error::DataFusionError::Internal(
-                                "Failed to downcast tag column to StringArray".to_string(),
-                            )
-                        })?;
-                if string_array.value(0) != string_array.value(num_rows - 1) {
-                    all_same = false;
-                    break;
-                }
-            }
-            if all_same {
+            if tags.equal_at(0, &tags, num_rows - 1)? {
                 resumed_batch_index += 1;
                 continue;
             }
 
-            // check column by column
-            for index in &self.tag_indices {
-                let array = batch.column(*index);
-                let string_array =
-                    array
-                        .as_any()
-                        .downcast_ref::<StringArray>()
-                        .ok_or_else(|| {
-                            datafusion::error::DataFusionError::Internal(
-                                "Failed to downcast tag column to StringArray".to_string(),
-                            )
-                        })?;
-                // the first row number that not equal to the next row.
-                let mut same_until = 0;
-                while same_until < num_rows - 1 {
-                    if string_array.value(same_until) != string_array.value(same_until + 1) {
-                        break;
-                    }
-                    same_until += 1;
+            let mut same_until = 0;
+            while same_until < num_rows - 1 {
+                if !tags.equal_at(same_until, &tags, same_until + 1)? {
+                    break;
                 }
-                result_index = result_index.min(same_until);
+                same_until += 1;
             }
 
-            if result_index + 1 >= num_rows {
+            if same_until + 1 >= num_rows {
                 // all rows are the same, inspect next batch
                 resumed_batch_index += 1;
             } else {
-                return Ok(Some((resumed_batch_index, result_index)));
+                return Ok(Some((resumed_batch_index, same_until)));
             }
         }
 
@@ -1030,4 +1122,142 @@ mod test {
         // No more batches should be produced
         assert!(divide_stream.next().await.is_none());
     }
+
+    #[tokio::test]
+    async fn test_string_tag_column_types() {
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("tag_large", DataType::LargeUtf8, false),
+            Field::new("tag_view", DataType::Utf8View, false),
+            Field::new(
+                "time_index",
+                DataType::Timestamp(datafusion::arrow::datatypes::TimeUnit::Millisecond, None),
+                false,
+            ),
+        ]));
+
+        let batch1 = RecordBatch::try_new(
+            schema.clone(),
+            vec![
+                Arc::new(LargeStringArray::from(vec!["a", "a", "a", "a"])),
+                Arc::new(StringViewArray::from(vec!["x", "x", "y", "y"])),
+                Arc::new(datafusion::arrow::array::TimestampMillisecondArray::from(
+                    vec![1000, 2000, 1000, 2000],
+                )),
+            ],
+        )
+        .unwrap();
+
+        let batch2 = RecordBatch::try_new(
+            schema.clone(),
+            vec![
+                Arc::new(LargeStringArray::from(vec!["b", "b"])),
+                Arc::new(StringViewArray::from(vec!["x", "x"])),
+                Arc::new(datafusion::arrow::array::TimestampMillisecondArray::from(
+                    vec![1000, 2000],
+                )),
+            ],
+        )
+        .unwrap();
+
+        let memory_exec: Arc<dyn ExecutionPlan> = Arc::new(DataSourceExec::new(Arc::new(
+            MemorySourceConfig::try_new(&[vec![batch1, batch2]], schema.clone(), None).unwrap(),
+        )));
+
+        let divide_exec = Arc::new(SeriesDivideExec {
+            tag_columns: vec!["tag_large".to_string(), "tag_view".to_string()],
+            time_index_column: "time_index".to_string(),
+            input: memory_exec,
+            metric: ExecutionPlanMetricsSet::new(),
+        });
+
+        let session_context = SessionContext::default();
+        let result = datafusion::physical_plan::collect(divide_exec, session_context.task_ctx())
+            .await
+            .unwrap();
+
+        assert_eq!(result.len(), 3);
+        for ((expected_large, expected_view), batch) in [("a", "x"), ("a", "y"), ("b", "x")]
+            .into_iter()
+            .zip(result.iter())
+        {
+            assert_eq!(batch.num_rows(), 2);
+
+            let tag_large_array = batch
+                .column(0)
+                .as_any()
+                .downcast_ref::<LargeStringArray>()
+                .unwrap();
+            let tag_view_array = batch
+                .column(1)
+                .as_any()
+                .downcast_ref::<StringViewArray>()
+                .unwrap();
+
+            for row in 0..batch.num_rows() {
+                assert_eq!(tag_large_array.value(row), expected_large);
+                assert_eq!(tag_view_array.value(row), expected_view);
+            }
+        }
+    }
+
+    #[tokio::test]
+    async fn test_u64_tag_column() {
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("tsid", DataType::UInt64, false),
+            Field::new(
+                "time_index",
+                DataType::Timestamp(datafusion::arrow::datatypes::TimeUnit::Millisecond, None),
+                false,
+            ),
+        ]));
+
+        let batch1 = RecordBatch::try_new(
+            schema.clone(),
+            vec![
+                Arc::new(UInt64Array::from(vec![1, 1, 2, 2])),
+                Arc::new(datafusion::arrow::array::TimestampMillisecondArray::from(
+                    vec![1000, 2000, 1000, 2000],
+                )),
+            ],
+        )
+        .unwrap();
+
+        let batch2 = RecordBatch::try_new(
+            schema.clone(),
+            vec![
+                Arc::new(UInt64Array::from(vec![3, 3])),
+                Arc::new(datafusion::arrow::array::TimestampMillisecondArray::from(
+                    vec![1000, 2000],
+                )),
+            ],
+        )
+        .unwrap();
+
+        let memory_exec: Arc<dyn ExecutionPlan> = Arc::new(DataSourceExec::new(Arc::new(
+            MemorySourceConfig::try_new(&[vec![batch1, batch2]], schema.clone(), None).unwrap(),
+        )));
+
+        let divide_exec = Arc::new(SeriesDivideExec {
+            tag_columns: vec!["tsid".to_string()],
+            time_index_column: "time_index".to_string(),
+            input: memory_exec,
+            metric: ExecutionPlanMetricsSet::new(),
+        });
+
+        let session_context = SessionContext::default();
+        let result = datafusion::physical_plan::collect(divide_exec, session_context.task_ctx())
+            .await
+            .unwrap();
+
+        assert_eq!(result.len(), 3);
+        for (expected_tsid, batch) in [1u64, 2u64, 3u64].into_iter().zip(result.iter()) {
+            assert_eq!(batch.num_rows(), 2);
+            let tsid_array = batch
+                .column(0)
+                .as_any()
+                .downcast_ref::<UInt64Array>()
+                .unwrap();
+            assert!(tsid_array.iter().all(|v| v == Some(expected_tsid)));
+        }
+    }
 }