mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2026-01-07 22:02:56 +00:00
feat: implement uddsketch function to calculate percentile (#5574)
* basic impl Signed-off-by: Ruihang Xia <waynestxia@gmail.com> * more tests Signed-off-by: Ruihang Xia <waynestxia@gmail.com> * sqlness test Signed-off-by: Ruihang Xia <waynestxia@gmail.com> * fix clippy Signed-off-by: Ruihang Xia <waynestxia@gmail.com> * update with more test and logs Signed-off-by: Ruihang Xia <waynestxia@gmail.com> --------- Signed-off-by: Ruihang Xia <waynestxia@gmail.com>
This commit is contained in:
@@ -15,6 +15,7 @@ geo = ["geohash", "h3o", "s2", "wkt", "geo-types", "dep:geo"]
|
||||
api.workspace = true
|
||||
arc-swap = "1.0"
|
||||
async-trait.workspace = true
|
||||
bincode = "1.3"
|
||||
common-base.workspace = true
|
||||
common-catalog.workspace = true
|
||||
common-error.workspace = true
|
||||
@@ -47,6 +48,7 @@ sql.workspace = true
|
||||
statrs = "0.16"
|
||||
store-api.workspace = true
|
||||
table.workspace = true
|
||||
uddsketch = { git = "https://github.com/GreptimeTeam/timescaledb-toolkit.git", rev = "84828fe8fb494a6a61412a3da96517fc80f7bb20" }
|
||||
wkt = { version = "0.11", optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
|
||||
17
src/common/function/src/aggr.rs
Normal file
17
src/common/function/src/aggr.rs
Normal file
@@ -0,0 +1,17 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
mod uddsketch_state;
|
||||
|
||||
pub use uddsketch_state::{UddSketchState, UDDSKETCH_STATE_NAME};
|
||||
307
src/common/function/src/aggr/uddsketch_state.rs
Normal file
307
src/common/function/src/aggr/uddsketch_state.rs
Normal file
@@ -0,0 +1,307 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use common_query::prelude::*;
|
||||
use common_telemetry::trace;
|
||||
use datafusion::common::cast::{as_binary_array, as_primitive_array};
|
||||
use datafusion::common::not_impl_err;
|
||||
use datafusion::error::{DataFusionError, Result as DfResult};
|
||||
use datafusion::logical_expr::function::AccumulatorArgs;
|
||||
use datafusion::logical_expr::{Accumulator as DfAccumulator, AggregateUDF};
|
||||
use datafusion::physical_plan::expressions::Literal;
|
||||
use datafusion::prelude::create_udaf;
|
||||
use datatypes::arrow::array::ArrayRef;
|
||||
use datatypes::arrow::datatypes::{DataType, Float64Type};
|
||||
use uddsketch::{SketchHashKey, UDDSketch};
|
||||
|
||||
pub const UDDSKETCH_STATE_NAME: &str = "uddsketch_state";
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct UddSketchState {
|
||||
uddsketch: UDDSketch,
|
||||
}
|
||||
|
||||
impl UddSketchState {
|
||||
pub fn new(bucket_size: u64, error_rate: f64) -> Self {
|
||||
Self {
|
||||
uddsketch: UDDSketch::new(bucket_size, error_rate),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn udf_impl() -> AggregateUDF {
|
||||
create_udaf(
|
||||
UDDSKETCH_STATE_NAME,
|
||||
vec![DataType::Int64, DataType::Float64, DataType::Float64],
|
||||
Arc::new(DataType::Binary),
|
||||
Volatility::Immutable,
|
||||
Arc::new(|args| {
|
||||
let (bucket_size, error_rate) = downcast_accumulator_args(args)?;
|
||||
Ok(Box::new(UddSketchState::new(bucket_size, error_rate)))
|
||||
}),
|
||||
Arc::new(vec![DataType::Binary]),
|
||||
)
|
||||
}
|
||||
|
||||
fn update(&mut self, value: f64) {
|
||||
self.uddsketch.add_value(value);
|
||||
}
|
||||
|
||||
fn merge(&mut self, raw: &[u8]) {
|
||||
if let Ok(uddsketch) = bincode::deserialize::<UDDSketch>(raw) {
|
||||
if uddsketch.count() != 0 {
|
||||
self.uddsketch.merge_sketch(&uddsketch);
|
||||
}
|
||||
} else {
|
||||
trace!("Warning: Failed to deserialize UDDSketch from {:?}", raw);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn downcast_accumulator_args(args: AccumulatorArgs) -> DfResult<(u64, f64)> {
|
||||
let bucket_size = match args.exprs[0]
|
||||
.as_any()
|
||||
.downcast_ref::<Literal>()
|
||||
.map(|lit| lit.value())
|
||||
{
|
||||
Some(ScalarValue::Int64(Some(value))) => *value as u64,
|
||||
_ => {
|
||||
return not_impl_err!(
|
||||
"{} not supported for bucket size: {}",
|
||||
UDDSKETCH_STATE_NAME,
|
||||
&args.exprs[0]
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
let error_rate = match args.exprs[1]
|
||||
.as_any()
|
||||
.downcast_ref::<Literal>()
|
||||
.map(|lit| lit.value())
|
||||
{
|
||||
Some(ScalarValue::Float64(Some(value))) => *value,
|
||||
_ => {
|
||||
return not_impl_err!(
|
||||
"{} not supported for error rate: {}",
|
||||
UDDSKETCH_STATE_NAME,
|
||||
&args.exprs[1]
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
Ok((bucket_size, error_rate))
|
||||
}
|
||||
|
||||
impl DfAccumulator for UddSketchState {
|
||||
fn update_batch(&mut self, values: &[ArrayRef]) -> DfResult<()> {
|
||||
let array = &values[2]; // the third column is data value
|
||||
let f64_array = as_primitive_array::<Float64Type>(array)?;
|
||||
for v in f64_array.iter().flatten() {
|
||||
self.update(v);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn evaluate(&mut self) -> DfResult<ScalarValue> {
|
||||
Ok(ScalarValue::Binary(Some(
|
||||
bincode::serialize(&self.uddsketch).map_err(|e| {
|
||||
DataFusionError::Internal(format!("Failed to serialize UDDSketch: {}", e))
|
||||
})?,
|
||||
)))
|
||||
}
|
||||
|
||||
fn size(&self) -> usize {
|
||||
// Base size of UDDSketch struct fields
|
||||
let mut total_size = std::mem::size_of::<f64>() * 3 + // alpha, gamma, values_sum
|
||||
std::mem::size_of::<u32>() + // compactions
|
||||
std::mem::size_of::<u64>() * 2; // max_buckets, num_values
|
||||
|
||||
// Size of buckets (SketchHashMap)
|
||||
// Each bucket entry contains:
|
||||
// - SketchHashKey (enum with i64/Zero/Invalid variants)
|
||||
// - SketchHashEntry (count: u64, next: SketchHashKey)
|
||||
let bucket_entry_size = std::mem::size_of::<SketchHashKey>() + // key
|
||||
std::mem::size_of::<u64>() + // count
|
||||
std::mem::size_of::<SketchHashKey>(); // next
|
||||
|
||||
total_size += self.uddsketch.current_buckets_count() * bucket_entry_size;
|
||||
|
||||
total_size
|
||||
}
|
||||
|
||||
fn state(&mut self) -> DfResult<Vec<ScalarValue>> {
|
||||
Ok(vec![ScalarValue::Binary(Some(
|
||||
bincode::serialize(&self.uddsketch).map_err(|e| {
|
||||
DataFusionError::Internal(format!("Failed to serialize UDDSketch: {}", e))
|
||||
})?,
|
||||
))])
|
||||
}
|
||||
|
||||
fn merge_batch(&mut self, states: &[ArrayRef]) -> DfResult<()> {
|
||||
let array = &states[0];
|
||||
let binary_array = as_binary_array(array)?;
|
||||
for v in binary_array.iter().flatten() {
|
||||
self.merge(v);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use datafusion::arrow::array::{BinaryArray, Float64Array};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_uddsketch_state_basic() {
|
||||
let mut state = UddSketchState::new(10, 0.01);
|
||||
state.update(1.0);
|
||||
state.update(2.0);
|
||||
state.update(3.0);
|
||||
|
||||
let result = state.evaluate().unwrap();
|
||||
if let ScalarValue::Binary(Some(bytes)) = result {
|
||||
let deserialized: UDDSketch = bincode::deserialize(&bytes).unwrap();
|
||||
assert_eq!(deserialized.count(), 3);
|
||||
} else {
|
||||
panic!("Expected binary scalar value");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_uddsketch_state_roundtrip() {
|
||||
let mut state = UddSketchState::new(10, 0.01);
|
||||
state.update(1.0);
|
||||
state.update(2.0);
|
||||
|
||||
// Serialize
|
||||
let serialized = state.evaluate().unwrap();
|
||||
|
||||
// Create new state and merge the serialized data
|
||||
let mut new_state = UddSketchState::new(10, 0.01);
|
||||
if let ScalarValue::Binary(Some(bytes)) = &serialized {
|
||||
new_state.merge(bytes);
|
||||
|
||||
// Verify the merged state matches original by comparing deserialized values
|
||||
let original_sketch: UDDSketch = bincode::deserialize(bytes).unwrap();
|
||||
let new_result = new_state.evaluate().unwrap();
|
||||
if let ScalarValue::Binary(Some(new_bytes)) = new_result {
|
||||
let new_sketch: UDDSketch = bincode::deserialize(&new_bytes).unwrap();
|
||||
assert_eq!(original_sketch.count(), new_sketch.count());
|
||||
assert_eq!(original_sketch.sum(), new_sketch.sum());
|
||||
assert_eq!(original_sketch.mean(), new_sketch.mean());
|
||||
assert_eq!(original_sketch.max_error(), new_sketch.max_error());
|
||||
// Compare a few quantiles to ensure statistical equivalence
|
||||
for q in [0.1, 0.5, 0.9].iter() {
|
||||
assert!(
|
||||
(original_sketch.estimate_quantile(*q) - new_sketch.estimate_quantile(*q))
|
||||
.abs()
|
||||
< 1e-10,
|
||||
"Quantile {} mismatch: original={}, new={}",
|
||||
q,
|
||||
original_sketch.estimate_quantile(*q),
|
||||
new_sketch.estimate_quantile(*q)
|
||||
);
|
||||
}
|
||||
} else {
|
||||
panic!("Expected binary scalar value");
|
||||
}
|
||||
} else {
|
||||
panic!("Expected binary scalar value");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_uddsketch_state_batch_update() {
|
||||
let mut state = UddSketchState::new(10, 0.01);
|
||||
let values = vec![1.0f64, 2.0, 3.0];
|
||||
let array = Arc::new(Float64Array::from(values)) as ArrayRef;
|
||||
|
||||
state
|
||||
.update_batch(&[array.clone(), array.clone(), array])
|
||||
.unwrap();
|
||||
|
||||
let result = state.evaluate().unwrap();
|
||||
if let ScalarValue::Binary(Some(bytes)) = result {
|
||||
let deserialized: UDDSketch = bincode::deserialize(&bytes).unwrap();
|
||||
assert_eq!(deserialized.count(), 3);
|
||||
} else {
|
||||
panic!("Expected binary scalar value");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_uddsketch_state_merge_batch() {
|
||||
let mut state1 = UddSketchState::new(10, 0.01);
|
||||
state1.update(1.0);
|
||||
let state1_binary = state1.evaluate().unwrap();
|
||||
|
||||
let mut state2 = UddSketchState::new(10, 0.01);
|
||||
state2.update(2.0);
|
||||
let state2_binary = state2.evaluate().unwrap();
|
||||
|
||||
let mut merged_state = UddSketchState::new(10, 0.01);
|
||||
if let (ScalarValue::Binary(Some(bytes1)), ScalarValue::Binary(Some(bytes2))) =
|
||||
(&state1_binary, &state2_binary)
|
||||
{
|
||||
let binary_array = Arc::new(BinaryArray::from(vec![
|
||||
bytes1.as_slice(),
|
||||
bytes2.as_slice(),
|
||||
])) as ArrayRef;
|
||||
merged_state.merge_batch(&[binary_array]).unwrap();
|
||||
|
||||
let result = merged_state.evaluate().unwrap();
|
||||
if let ScalarValue::Binary(Some(bytes)) = result {
|
||||
let deserialized: UDDSketch = bincode::deserialize(&bytes).unwrap();
|
||||
assert_eq!(deserialized.count(), 2);
|
||||
} else {
|
||||
panic!("Expected binary scalar value");
|
||||
}
|
||||
} else {
|
||||
panic!("Expected binary scalar values");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_uddsketch_state_size() {
|
||||
let mut state = UddSketchState::new(10, 0.01);
|
||||
let initial_size = state.size();
|
||||
|
||||
// Add some values to create buckets
|
||||
state.update(1.0);
|
||||
state.update(2.0);
|
||||
state.update(3.0);
|
||||
|
||||
let size_with_values = state.size();
|
||||
assert!(
|
||||
size_with_values > initial_size,
|
||||
"Size should increase after adding values: initial={}, with_values={}",
|
||||
initial_size,
|
||||
size_with_values
|
||||
);
|
||||
|
||||
// Verify size increases with more buckets
|
||||
state.update(10.0); // This should create a new bucket
|
||||
assert!(
|
||||
state.size() > size_with_values,
|
||||
"Size should increase after adding new bucket: prev={}, new={}",
|
||||
size_with_values,
|
||||
state.size()
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -26,6 +26,7 @@ use crate::scalars::json::JsonFunction;
|
||||
use crate::scalars::matches::MatchesFunction;
|
||||
use crate::scalars::math::MathFunction;
|
||||
use crate::scalars::timestamp::TimestampFunction;
|
||||
use crate::scalars::uddsketch_calc::UddSketchCalcFunction;
|
||||
use crate::scalars::vector::VectorFunction;
|
||||
use crate::system::SystemFunction;
|
||||
use crate::table::TableFunction;
|
||||
@@ -105,6 +106,7 @@ pub static FUNCTION_REGISTRY: Lazy<Arc<FunctionRegistry>> = Lazy::new(|| {
|
||||
TimestampFunction::register(&function_registry);
|
||||
DateFunction::register(&function_registry);
|
||||
ExpressionFunction::register(&function_registry);
|
||||
UddSketchCalcFunction::register(&function_registry);
|
||||
|
||||
// Aggregate functions
|
||||
AggregateFunctions::register(&function_registry);
|
||||
|
||||
@@ -21,6 +21,7 @@ pub mod scalars;
|
||||
mod system;
|
||||
mod table;
|
||||
|
||||
pub mod aggr;
|
||||
pub mod function;
|
||||
pub mod function_registry;
|
||||
pub mod handlers;
|
||||
|
||||
@@ -25,4 +25,5 @@ pub mod vector;
|
||||
#[cfg(test)]
|
||||
pub(crate) mod test;
|
||||
pub(crate) mod timestamp;
|
||||
pub(crate) mod uddsketch_calc;
|
||||
pub mod udf;
|
||||
|
||||
211
src/common/function/src/scalars/uddsketch_calc.rs
Normal file
211
src/common/function/src/scalars/uddsketch_calc.rs
Normal file
@@ -0,0 +1,211 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Implementation of the scalar function `uddsketch_calc`.
|
||||
|
||||
use std::fmt;
|
||||
use std::fmt::Display;
|
||||
use std::sync::Arc;
|
||||
|
||||
use common_query::error::{DowncastVectorSnafu, InvalidFuncArgsSnafu, Result};
|
||||
use common_query::prelude::{Signature, Volatility};
|
||||
use datatypes::data_type::ConcreteDataType;
|
||||
use datatypes::prelude::Vector;
|
||||
use datatypes::scalars::{ScalarVector, ScalarVectorBuilder};
|
||||
use datatypes::vectors::{BinaryVector, Float64VectorBuilder, MutableVector, VectorRef};
|
||||
use snafu::OptionExt;
|
||||
use uddsketch::UDDSketch;
|
||||
|
||||
use crate::function::{Function, FunctionContext};
|
||||
use crate::function_registry::FunctionRegistry;
|
||||
|
||||
const NAME: &str = "uddsketch_calc";
|
||||
|
||||
/// UddSketchCalcFunction implements the scalar function `uddsketch_calc`.
|
||||
///
|
||||
/// It accepts two arguments:
|
||||
/// 1. A percentile (as f64) for which to compute the estimated quantile (e.g. 0.95 for p95).
|
||||
/// 2. The serialized UDDSketch state, as produced by the aggregator (binary).
|
||||
///
|
||||
/// For each row, it deserializes the sketch and returns the computed quantile value.
|
||||
#[derive(Debug, Default)]
|
||||
pub struct UddSketchCalcFunction;
|
||||
|
||||
impl UddSketchCalcFunction {
|
||||
pub fn register(registry: &FunctionRegistry) {
|
||||
registry.register(Arc::new(UddSketchCalcFunction));
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for UddSketchCalcFunction {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(f, "{}", NAME.to_ascii_uppercase())
|
||||
}
|
||||
}
|
||||
|
||||
impl Function for UddSketchCalcFunction {
|
||||
fn name(&self) -> &str {
|
||||
NAME
|
||||
}
|
||||
|
||||
fn return_type(&self, _input_types: &[ConcreteDataType]) -> Result<ConcreteDataType> {
|
||||
Ok(ConcreteDataType::float64_datatype())
|
||||
}
|
||||
|
||||
fn signature(&self) -> Signature {
|
||||
// First argument: percentile (float64)
|
||||
// Second argument: UDDSketch state (binary)
|
||||
Signature::exact(
|
||||
vec![
|
||||
ConcreteDataType::float64_datatype(),
|
||||
ConcreteDataType::binary_datatype(),
|
||||
],
|
||||
Volatility::Immutable,
|
||||
)
|
||||
}
|
||||
|
||||
fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
|
||||
if columns.len() != 2 {
|
||||
return InvalidFuncArgsSnafu {
|
||||
err_msg: format!("uddsketch_calc expects 2 arguments, got {}", columns.len()),
|
||||
}
|
||||
.fail();
|
||||
}
|
||||
|
||||
let perc_vec = &columns[0];
|
||||
let sketch_vec = columns[1]
|
||||
.as_any()
|
||||
.downcast_ref::<BinaryVector>()
|
||||
.with_context(|| DowncastVectorSnafu {
|
||||
err_msg: format!("expect BinaryVector, got {}", columns[1].vector_type_name()),
|
||||
})?;
|
||||
let len = sketch_vec.len();
|
||||
let mut builder = Float64VectorBuilder::with_capacity(len);
|
||||
|
||||
for i in 0..len {
|
||||
let perc_opt = perc_vec.get(i).as_f64_lossy();
|
||||
let sketch_opt = sketch_vec.get_data(i);
|
||||
|
||||
if sketch_opt.is_none() || perc_opt.is_none() {
|
||||
builder.push_null();
|
||||
continue;
|
||||
}
|
||||
|
||||
let sketch_bytes = sketch_opt.unwrap();
|
||||
let perc = perc_opt.unwrap();
|
||||
|
||||
// Deserialize the UDDSketch from its bincode representation
|
||||
let sketch: UDDSketch = match bincode::deserialize(sketch_bytes) {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
common_telemetry::trace!("Failed to deserialize UDDSketch: {}", e);
|
||||
builder.push_null();
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// Compute the estimated quantile from the sketch
|
||||
let result = sketch.estimate_quantile(perc);
|
||||
builder.push(Some(result));
|
||||
}
|
||||
|
||||
Ok(builder.to_vector())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use datatypes::vectors::{BinaryVector, Float64Vector};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_uddsketch_calc_function() {
|
||||
let function = UddSketchCalcFunction;
|
||||
assert_eq!("uddsketch_calc", function.name());
|
||||
assert_eq!(
|
||||
ConcreteDataType::float64_datatype(),
|
||||
function
|
||||
.return_type(&[ConcreteDataType::float64_datatype()])
|
||||
.unwrap()
|
||||
);
|
||||
|
||||
// Create a test sketch
|
||||
let mut sketch = UDDSketch::new(128, 0.01);
|
||||
sketch.add_value(10.0);
|
||||
sketch.add_value(20.0);
|
||||
sketch.add_value(30.0);
|
||||
sketch.add_value(40.0);
|
||||
sketch.add_value(50.0);
|
||||
sketch.add_value(60.0);
|
||||
sketch.add_value(70.0);
|
||||
sketch.add_value(80.0);
|
||||
sketch.add_value(90.0);
|
||||
sketch.add_value(100.0);
|
||||
|
||||
// Get expected values directly from the sketch
|
||||
let expected_p50 = sketch.estimate_quantile(0.5);
|
||||
let expected_p90 = sketch.estimate_quantile(0.9);
|
||||
let expected_p95 = sketch.estimate_quantile(0.95);
|
||||
|
||||
let serialized = bincode::serialize(&sketch).unwrap();
|
||||
let percentiles = vec![0.5, 0.9, 0.95];
|
||||
|
||||
let args: Vec<VectorRef> = vec![
|
||||
Arc::new(Float64Vector::from_vec(percentiles.clone())),
|
||||
Arc::new(BinaryVector::from(vec![Some(serialized.clone()); 3])),
|
||||
];
|
||||
|
||||
let result = function.eval(FunctionContext::default(), &args).unwrap();
|
||||
assert_eq!(result.len(), 3);
|
||||
|
||||
// Test median (p50)
|
||||
assert!(
|
||||
matches!(result.get(0), datatypes::value::Value::Float64(v) if (v - expected_p50).abs() < 1e-10)
|
||||
);
|
||||
// Test p90
|
||||
assert!(
|
||||
matches!(result.get(1), datatypes::value::Value::Float64(v) if (v - expected_p90).abs() < 1e-10)
|
||||
);
|
||||
// Test p95
|
||||
assert!(
|
||||
matches!(result.get(2), datatypes::value::Value::Float64(v) if (v - expected_p95).abs() < 1e-10)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_uddsketch_calc_function_errors() {
|
||||
let function = UddSketchCalcFunction;
|
||||
|
||||
// Test with invalid number of arguments
|
||||
let args: Vec<VectorRef> = vec![Arc::new(Float64Vector::from_vec(vec![0.95]))];
|
||||
let result = function.eval(FunctionContext::default(), &args);
|
||||
assert!(result.is_err());
|
||||
assert!(result
|
||||
.unwrap_err()
|
||||
.to_string()
|
||||
.contains("uddsketch_calc expects 2 arguments"));
|
||||
|
||||
// Test with invalid binary data
|
||||
let args: Vec<VectorRef> = vec![
|
||||
Arc::new(Float64Vector::from_vec(vec![0.95])),
|
||||
Arc::new(BinaryVector::from(vec![Some(vec![1, 2, 3])])), // Invalid binary data
|
||||
];
|
||||
let result = function.eval(FunctionContext::default(), &args).unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert!(matches!(result.get(0), datatypes::value::Value::Null));
|
||||
}
|
||||
}
|
||||
@@ -18,6 +18,7 @@ use std::sync::Arc;
|
||||
|
||||
use arrow_schema::DataType;
|
||||
use catalog::table_source::DfTableSourceProvider;
|
||||
use common_function::aggr::{UddSketchState, UDDSKETCH_STATE_NAME};
|
||||
use common_function::scalars::udf::create_udf;
|
||||
use common_query::logical_plan::create_aggregate_function;
|
||||
use datafusion::common::TableReference;
|
||||
@@ -165,6 +166,10 @@ impl ContextProvider for DfContextProviderAdapter {
|
||||
}
|
||||
|
||||
fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
|
||||
if name == UDDSKETCH_STATE_NAME {
|
||||
return Some(Arc::new(UddSketchState::udf_impl()));
|
||||
}
|
||||
|
||||
self.engine_state.aggregate_function(name).map_or_else(
|
||||
|| self.session_state.aggregate_functions().get(name).cloned(),
|
||||
|func| {
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use common_error::ext::BoxedError;
|
||||
use common_function::aggr::UddSketchState;
|
||||
use common_function::function_registry::FUNCTION_REGISTRY;
|
||||
use common_function::scalars::udf::create_udf;
|
||||
use common_query::error::RegisterUdfSnafu;
|
||||
@@ -125,6 +126,7 @@ impl SubstraitPlanDecoder for DefaultPlanDecoder {
|
||||
session_state
|
||||
.register_udf(udf)
|
||||
.context(RegisterUdfSnafu { name: func.name() })?;
|
||||
let _ = session_state.register_udaf(Arc::new(UddSketchState::udf_impl()));
|
||||
}
|
||||
let logical_plan = DFLogicalSubstraitConvertor
|
||||
.decode(message, session_state)
|
||||
|
||||
Reference in New Issue
Block a user