From 21044c7339d1821b92d04834b9ccba62f5ad003b Mon Sep 17 00:00:00 2001 From: discord9 <55937128+discord9@users.noreply.github.com> Date: Sun, 27 Apr 2025 20:43:21 +0800 Subject: [PATCH] feat: uddsketch_merge udaf (#5992) --- src/common/function/src/aggr.rs | 2 +- .../function/src/aggr/uddsketch_state.rs | 96 +++++++++++++++---- src/query/src/datafusion/planner.rs | 6 +- .../src/query_engine/default_serializer.rs | 3 +- .../common/aggregate/uddsketch.result | 34 +++++++ .../standalone/common/aggregate/uddsketch.sql | 17 ++++ .../common/flow/flow_step_aggr.result | 51 ++++++++++ .../standalone/common/flow/flow_step_aggr.sql | 28 ++++++ 8 files changed, 214 insertions(+), 23 deletions(-) diff --git a/src/common/function/src/aggr.rs b/src/common/function/src/aggr.rs index 24bcb86618..8b4486906d 100644 --- a/src/common/function/src/aggr.rs +++ b/src/common/function/src/aggr.rs @@ -19,4 +19,4 @@ mod uddsketch_state; pub use geo_path::{GeoPathAccumulator, GEO_PATH_NAME}; pub(crate) use hll::HllStateType; pub use hll::{HllState, HLL_MERGE_NAME, HLL_NAME}; -pub use uddsketch_state::{UddSketchState, UDDSKETCH_STATE_NAME}; +pub use uddsketch_state::{UddSketchState, UDDSKETCH_MERGE_NAME, UDDSKETCH_STATE_NAME}; diff --git a/src/common/function/src/aggr/uddsketch_state.rs b/src/common/function/src/aggr/uddsketch_state.rs index 3ac138736d..1c5741c4a9 100644 --- a/src/common/function/src/aggr/uddsketch_state.rs +++ b/src/common/function/src/aggr/uddsketch_state.rs @@ -31,23 +31,28 @@ use datafusion::physical_plan::expressions::Literal; use datafusion::prelude::create_udaf; use datatypes::arrow::array::ArrayRef; use datatypes::arrow::datatypes::{DataType, Float64Type}; +use serde::{Deserialize, Serialize}; use uddsketch::{SketchHashKey, UDDSketch}; pub const UDDSKETCH_STATE_NAME: &str = "uddsketch_state"; -#[derive(Debug)] +pub const UDDSKETCH_MERGE_NAME: &str = "uddsketch_merge"; + +#[derive(Debug, Serialize, Deserialize)] pub struct UddSketchState { uddsketch: UDDSketch, + error_rate: f64, } impl UddSketchState { pub fn new(bucket_size: u64, error_rate: f64) -> Self { Self { uddsketch: UDDSketch::new(bucket_size, error_rate), + error_rate, } } - pub fn udf_impl() -> AggregateUDF { + pub fn state_udf_impl() -> AggregateUDF { create_udaf( UDDSKETCH_STATE_NAME, vec![DataType::Int64, DataType::Float64, DataType::Float64], @@ -61,18 +66,55 @@ impl UddSketchState { ) } + /// Create a UDF for the `uddsketch_merge` function. + /// + /// `uddsketch_merge` accepts bucket size, error rate, and a binary column of states generated by `uddsketch_state` + /// and merges them into a single state. + /// + /// The bucket size and error rate must be the same as the original state. + pub fn merge_udf_impl() -> AggregateUDF { + create_udaf( + UDDSKETCH_MERGE_NAME, + vec![DataType::Int64, DataType::Float64, DataType::Binary], + Arc::new(DataType::Binary), + Volatility::Immutable, + Arc::new(|args| { + let (bucket_size, error_rate) = downcast_accumulator_args(args)?; + Ok(Box::new(UddSketchState::new(bucket_size, error_rate))) + }), + Arc::new(vec![DataType::Binary]), + ) + } + fn update(&mut self, value: f64) { self.uddsketch.add_value(value); } - fn merge(&mut self, raw: &[u8]) { - if let Ok(uddsketch) = bincode::deserialize::(raw) { - if uddsketch.count() != 0 { - self.uddsketch.merge_sketch(&uddsketch); + fn merge(&mut self, raw: &[u8]) -> DfResult<()> { + if let Ok(uddsketch) = bincode::deserialize::(raw) { + if uddsketch.uddsketch.count() != 0 { + if self.uddsketch.max_allowed_buckets() != uddsketch.uddsketch.max_allowed_buckets() + || (self.error_rate - uddsketch.error_rate).abs() >= 1e-9 + { + return Err(DataFusionError::Plan(format!( + "Merging UDDSketch with different parameters: arguments={:?} vs actual input={:?}", + ( + self.uddsketch.max_allowed_buckets(), + self.error_rate + ), + (uddsketch.uddsketch.max_allowed_buckets(), uddsketch.error_rate) + ))); + } + self.uddsketch.merge_sketch(&uddsketch.uddsketch); } } else { trace!("Warning: Failed to deserialize UDDSketch from {:?}", raw); + return Err(DataFusionError::Plan( + "Failed to deserialize UDDSketch from binary".to_string(), + )); } + + Ok(()) } } @@ -113,9 +155,21 @@ fn downcast_accumulator_args(args: AccumulatorArgs) -> DfResult<(u64, f64)> { impl DfAccumulator for UddSketchState { fn update_batch(&mut self, values: &[ArrayRef]) -> DfResult<()> { let array = &values[2]; // the third column is data value - let f64_array = as_primitive_array::(array)?; - for v in f64_array.iter().flatten() { - self.update(v); + match array.data_type() { + DataType::Float64 => { + let f64_array = as_primitive_array::(array)?; + for v in f64_array.iter().flatten() { + self.update(v); + } + } + // meaning instantiate as `uddsketch_merge` + DataType::Binary => self.merge_batch(&[array.clone()])?, + _ => { + return not_impl_err!( + "UDDSketch functions do not support data type: {}", + array.data_type() + ) + } } Ok(()) @@ -123,7 +177,7 @@ impl DfAccumulator for UddSketchState { fn evaluate(&mut self) -> DfResult { Ok(ScalarValue::Binary(Some( - bincode::serialize(&self.uddsketch).map_err(|e| { + bincode::serialize(&self).map_err(|e| { DataFusionError::Internal(format!("Failed to serialize UDDSketch: {}", e)) })?, ))) @@ -150,7 +204,7 @@ impl DfAccumulator for UddSketchState { fn state(&mut self) -> DfResult> { Ok(vec![ScalarValue::Binary(Some( - bincode::serialize(&self.uddsketch).map_err(|e| { + bincode::serialize(&self).map_err(|e| { DataFusionError::Internal(format!("Failed to serialize UDDSketch: {}", e)) })?, ))]) @@ -160,7 +214,7 @@ impl DfAccumulator for UddSketchState { let array = &states[0]; let binary_array = as_binary_array(array)?; for v in binary_array.iter().flatten() { - self.merge(v); + self.merge(v)?; } Ok(()) @@ -182,8 +236,8 @@ mod tests { let result = state.evaluate().unwrap(); if let ScalarValue::Binary(Some(bytes)) = result { - let deserialized: UDDSketch = bincode::deserialize(&bytes).unwrap(); - assert_eq!(deserialized.count(), 3); + let deserialized: UddSketchState = bincode::deserialize(&bytes).unwrap(); + assert_eq!(deserialized.uddsketch.count(), 3); } else { panic!("Expected binary scalar value"); } @@ -201,13 +255,15 @@ mod tests { // Create new state and merge the serialized data let mut new_state = UddSketchState::new(10, 0.01); if let ScalarValue::Binary(Some(bytes)) = &serialized { - new_state.merge(bytes); + new_state.merge(bytes).unwrap(); // Verify the merged state matches original by comparing deserialized values - let original_sketch: UDDSketch = bincode::deserialize(bytes).unwrap(); + let original_sketch: UddSketchState = bincode::deserialize(bytes).unwrap(); + let original_sketch = original_sketch.uddsketch; let new_result = new_state.evaluate().unwrap(); if let ScalarValue::Binary(Some(new_bytes)) = new_result { - let new_sketch: UDDSketch = bincode::deserialize(&new_bytes).unwrap(); + let new_sketch: UddSketchState = bincode::deserialize(&new_bytes).unwrap(); + let new_sketch = new_sketch.uddsketch; assert_eq!(original_sketch.count(), new_sketch.count()); assert_eq!(original_sketch.sum(), new_sketch.sum()); assert_eq!(original_sketch.mean(), new_sketch.mean()); @@ -244,7 +300,8 @@ mod tests { let result = state.evaluate().unwrap(); if let ScalarValue::Binary(Some(bytes)) = result { - let deserialized: UDDSketch = bincode::deserialize(&bytes).unwrap(); + let deserialized: UddSketchState = bincode::deserialize(&bytes).unwrap(); + let deserialized = deserialized.uddsketch; assert_eq!(deserialized.count(), 3); } else { panic!("Expected binary scalar value"); @@ -273,7 +330,8 @@ mod tests { let result = merged_state.evaluate().unwrap(); if let ScalarValue::Binary(Some(bytes)) = result { - let deserialized: UDDSketch = bincode::deserialize(&bytes).unwrap(); + let deserialized: UddSketchState = bincode::deserialize(&bytes).unwrap(); + let deserialized = deserialized.uddsketch; assert_eq!(deserialized.count(), 2); } else { panic!("Expected binary scalar value"); diff --git a/src/query/src/datafusion/planner.rs b/src/query/src/datafusion/planner.rs index 0ad531541f..6d0d99e296 100644 --- a/src/query/src/datafusion/planner.rs +++ b/src/query/src/datafusion/planner.rs @@ -20,7 +20,7 @@ use arrow_schema::DataType; use catalog::table_source::DfTableSourceProvider; use common_function::aggr::{ GeoPathAccumulator, HllState, UddSketchState, GEO_PATH_NAME, HLL_MERGE_NAME, HLL_NAME, - UDDSKETCH_STATE_NAME, + UDDSKETCH_MERGE_NAME, UDDSKETCH_STATE_NAME, }; use common_function::scalars::udf::create_udf; use common_query::logical_plan::create_aggregate_function; @@ -165,7 +165,9 @@ impl ContextProvider for DfContextProviderAdapter { fn get_aggregate_meta(&self, name: &str) -> Option> { if name == UDDSKETCH_STATE_NAME { - return Some(Arc::new(UddSketchState::udf_impl())); + return Some(Arc::new(UddSketchState::state_udf_impl())); + } else if name == UDDSKETCH_MERGE_NAME { + return Some(Arc::new(UddSketchState::merge_udf_impl())); } else if name == HLL_NAME { return Some(Arc::new(HllState::state_udf_impl())); } else if name == HLL_MERGE_NAME { diff --git a/src/query/src/query_engine/default_serializer.rs b/src/query/src/query_engine/default_serializer.rs index c3feed1d55..a7e5ff2ca1 100644 --- a/src/query/src/query_engine/default_serializer.rs +++ b/src/query/src/query_engine/default_serializer.rs @@ -128,7 +128,8 @@ impl SubstraitPlanDecoder for DefaultPlanDecoder { session_state .register_udf(udf) .context(RegisterUdfSnafu { name: func.name() })?; - let _ = session_state.register_udaf(Arc::new(UddSketchState::udf_impl())); + let _ = session_state.register_udaf(Arc::new(UddSketchState::state_udf_impl())); + let _ = session_state.register_udaf(Arc::new(UddSketchState::merge_udf_impl())); let _ = session_state.register_udaf(Arc::new(HllState::state_udf_impl())); let _ = session_state.register_udaf(Arc::new(HllState::merge_udf_impl())); let _ = session_state.register_udaf(Arc::new(GeoPathAccumulator::udf_impl())); diff --git a/tests/cases/standalone/common/aggregate/uddsketch.result b/tests/cases/standalone/common/aggregate/uddsketch.result index a1cd1bbac7..6304ea9362 100644 --- a/tests/cases/standalone/common/aggregate/uddsketch.result +++ b/tests/cases/standalone/common/aggregate/uddsketch.result @@ -52,7 +52,41 @@ select uddsketch_calc(0.95, uddsketch_state(128, 0.01, `value`)) from test_uddsk | 100.49456770856492 | +----------------------------------------------------------------------------------------------+ +CREATE TABLE grouped_uddsketch ( + `state` BINARY, + id_group INT PRIMARY KEY, + `ts` timestamp time index default now() +); + +Affected Rows: 0 + +INSERT INTO grouped_uddsketch (`state`, id_group) SELECT uddsketch_state(128, 0.01, `value`), `id`/5*5 as id_group FROM test_uddsketch GROUP BY id_group; + +Affected Rows: 3 + +SELECT uddsketch_calc(0.1, uddsketch_merge(128, 0.01, `state`)) FROM grouped_uddsketch; + ++------------------------------------------------------------------------------------------------+ +| uddsketch_calc(Float64(0.1),uddsketch_merge(Int64(128),Float64(0.01),grouped_uddsketch.state)) | ++------------------------------------------------------------------------------------------------+ +| 19.886670240866184 | ++------------------------------------------------------------------------------------------------+ + +-- should fail +SELECT uddsketch_calc(0.1, uddsketch_merge(128, 0.1, `state`)) FROM grouped_uddsketch; + +Error: 3001(EngineExecuteQuery), Error during planning: Merging UDDSketch with different parameters: arguments=(128, 0.1) vs actual input=(128, 0.01) + +-- should fail +SELECT uddsketch_calc(0.1, uddsketch_merge(64, 0.01, `state`)) FROM grouped_uddsketch; + +Error: 3001(EngineExecuteQuery), Error during planning: Merging UDDSketch with different parameters: arguments=(64, 0.01) vs actual input=(128, 0.01) + drop table test_uddsketch; Affected Rows: 0 +drop table grouped_uddsketch; + +Affected Rows: 0 + diff --git a/tests/cases/standalone/common/aggregate/uddsketch.sql b/tests/cases/standalone/common/aggregate/uddsketch.sql index 40931dbbc9..56ce2ccf97 100644 --- a/tests/cases/standalone/common/aggregate/uddsketch.sql +++ b/tests/cases/standalone/common/aggregate/uddsketch.sql @@ -24,4 +24,21 @@ select uddsketch_calc(0.75, uddsketch_state(128, 0.01, `value`)) from test_uddsk select uddsketch_calc(0.95, uddsketch_state(128, 0.01, `value`)) from test_uddsketch; +CREATE TABLE grouped_uddsketch ( + `state` BINARY, + id_group INT PRIMARY KEY, + `ts` timestamp time index default now() +); + +INSERT INTO grouped_uddsketch (`state`, id_group) SELECT uddsketch_state(128, 0.01, `value`), `id`/5*5 as id_group FROM test_uddsketch GROUP BY id_group; + +SELECT uddsketch_calc(0.1, uddsketch_merge(128, 0.01, `state`)) FROM grouped_uddsketch; + +-- should fail +SELECT uddsketch_calc(0.1, uddsketch_merge(128, 0.1, `state`)) FROM grouped_uddsketch; + +-- should fail +SELECT uddsketch_calc(0.1, uddsketch_merge(64, 0.01, `state`)) FROM grouped_uddsketch; + drop table test_uddsketch; +drop table grouped_uddsketch; diff --git a/tests/cases/standalone/common/flow/flow_step_aggr.result b/tests/cases/standalone/common/flow/flow_step_aggr.result index ab76a67617..74113ccccb 100644 --- a/tests/cases/standalone/common/flow/flow_step_aggr.result +++ b/tests/cases/standalone/common/flow/flow_step_aggr.result @@ -201,6 +201,13 @@ CREATE TABLE percentile_5s ( Affected Rows: 0 +CREATE TABLE percentile_10s ( + "percentile_state" BINARY, + time_window timestamp(0) time index +); + +Affected Rows: 0 + CREATE FLOW calc_percentile_5s SINK TO percentile_5s AS SELECT @@ -213,6 +220,18 @@ GROUP BY Affected Rows: 0 +CREATE FLOW calc_percentile_10s SINK TO percentile_10s +AS +SELECT + uddsketch_merge(128, 0.01, percentile_state), + date_bin('10 seconds'::INTERVAL, time_window) AS time_window +FROM + percentile_5s +GROUP BY + date_bin('10 seconds'::INTERVAL, time_window); + +Affected Rows: 0 + INSERT INTO percentile_base ("id", "value", ts) VALUES (1, 10.0, 1), (2, 20.0, 2), @@ -236,6 +255,15 @@ ADMIN FLUSH_FLOW('calc_percentile_5s'); | FLOW_FLUSHED | +----------------------------------------+ +-- SQLNESS REPLACE (ADMIN\sFLUSH_FLOW\('\w+'\)\s+\|\n\+-+\+\n\|\s+)[0-9]+\s+\| $1 FLOW_FLUSHED | +ADMIN FLUSH_FLOW('calc_percentile_10s'); + ++-----------------------------------------+ +| ADMIN FLUSH_FLOW('calc_percentile_10s') | ++-----------------------------------------+ +| FLOW_FLUSHED | ++-----------------------------------------+ + SELECT time_window, uddsketch_calc(0.99, percentile_state) AS p99 @@ -252,14 +280,37 @@ ORDER BY | 1970-01-01T00:00:10 | | +---------------------+--------------------+ +SELECT + time_window, + uddsketch_calc(0.99, percentile_state) AS p99 +FROM + percentile_10s +ORDER BY + time_window; + ++---------------------+--------------------+ +| time_window | p99 | ++---------------------+--------------------+ +| 1970-01-01T00:00:00 | 59.745049810145126 | +| 1970-01-01T00:00:10 | | ++---------------------+--------------------+ + DROP FLOW calc_percentile_5s; Affected Rows: 0 +DROP FLOW calc_percentile_10s; + +Affected Rows: 0 + DROP TABLE percentile_5s; Affected Rows: 0 +DROP TABLE percentile_10s; + +Affected Rows: 0 + DROP TABLE percentile_base; Affected Rows: 0 diff --git a/tests/cases/standalone/common/flow/flow_step_aggr.sql b/tests/cases/standalone/common/flow/flow_step_aggr.sql index 44dde88912..92698d8de6 100644 --- a/tests/cases/standalone/common/flow/flow_step_aggr.sql +++ b/tests/cases/standalone/common/flow/flow_step_aggr.sql @@ -123,6 +123,11 @@ CREATE TABLE percentile_5s ( time_window timestamp(0) time index ); +CREATE TABLE percentile_10s ( + "percentile_state" BINARY, + time_window timestamp(0) time index +); + CREATE FLOW calc_percentile_5s SINK TO percentile_5s AS SELECT @@ -133,6 +138,16 @@ FROM GROUP BY time_window; +CREATE FLOW calc_percentile_10s SINK TO percentile_10s +AS +SELECT + uddsketch_merge(128, 0.01, percentile_state), + date_bin('10 seconds'::INTERVAL, time_window) AS time_window +FROM + percentile_5s +GROUP BY + date_bin('10 seconds'::INTERVAL, time_window); + INSERT INTO percentile_base ("id", "value", ts) VALUES (1, 10.0, 1), (2, 20.0, 2), @@ -148,6 +163,9 @@ INSERT INTO percentile_base ("id", "value", ts) VALUES -- SQLNESS REPLACE (ADMIN\sFLUSH_FLOW\('\w+'\)\s+\|\n\+-+\+\n\|\s+)[0-9]+\s+\| $1 FLOW_FLUSHED | ADMIN FLUSH_FLOW('calc_percentile_5s'); +-- SQLNESS REPLACE (ADMIN\sFLUSH_FLOW\('\w+'\)\s+\|\n\+-+\+\n\|\s+)[0-9]+\s+\| $1 FLOW_FLUSHED | +ADMIN FLUSH_FLOW('calc_percentile_10s'); + SELECT time_window, uddsketch_calc(0.99, percentile_state) AS p99 @@ -156,6 +174,16 @@ FROM ORDER BY time_window; +SELECT + time_window, + uddsketch_calc(0.99, percentile_state) AS p99 +FROM + percentile_10s +ORDER BY + time_window; + DROP FLOW calc_percentile_5s; +DROP FLOW calc_percentile_10s; DROP TABLE percentile_5s; +DROP TABLE percentile_10s; DROP TABLE percentile_base;