From 7bd108e2be5e94f0a4a4154b83b4de59e19bae6b Mon Sep 17 00:00:00 2001 From: Ruihang Xia Date: Mon, 24 Feb 2025 11:07:37 -0800 Subject: [PATCH] feat: impl `hll_state`, `hll_merge` and `hll_calc` for incremental distinct counting (#5579) * basic impl Signed-off-by: Ruihang Xia * more tests Signed-off-by: Ruihang Xia * sqlness test Signed-off-by: Ruihang Xia * fix clippy Signed-off-by: Ruihang Xia * update with more test and logs Signed-off-by: Ruihang Xia * impl Signed-off-by: Ruihang Xia * impl merge fn Signed-off-by: Ruihang Xia * rename function names Signed-off-by: Ruihang Xia --------- Signed-off-by: Ruihang Xia --- Cargo.lock | 11 + src/common/function/Cargo.toml | 2 + src/common/function/src/aggr.rs | 3 + src/common/function/src/aggr/hll.rs | 319 ++++++++++++++++++ src/common/function/src/function_registry.rs | 2 + src/common/function/src/scalars.rs | 1 + src/common/function/src/scalars/hll_count.rs | 175 ++++++++++ src/common/function/src/utils.rs | 70 ++++ src/query/src/datafusion/planner.rs | 10 +- .../src/query_engine/default_serializer.rs | 4 +- .../standalone/common/aggregate/hll.result | 84 +++++ .../cases/standalone/common/aggregate/hll.sql | 49 +++ 12 files changed, 728 insertions(+), 2 deletions(-) create mode 100644 src/common/function/src/aggr/hll.rs create mode 100644 src/common/function/src/scalars/hll_count.rs create mode 100644 tests/cases/standalone/common/aggregate/hll.result create mode 100644 tests/cases/standalone/common/aggregate/hll.sql diff --git a/Cargo.lock b/Cargo.lock index 83832853b8..8f6b4a17f5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2009,6 +2009,7 @@ dependencies = [ name = "common-function" version = "0.12.0" dependencies = [ + "ahash 0.8.11", "api", "approx 0.5.1", "arc-swap", @@ -2031,6 +2032,7 @@ dependencies = [ "geo-types", "geohash", "h3o", + "hyperloglogplus", "jsonb", "nalgebra 0.33.2", "num", @@ -5289,6 +5291,15 @@ dependencies = [ "tracing", ] +[[package]] +name = "hyperloglogplus" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "621debdf94dcac33e50475fdd76d34d5ea9c0362a834b9db08c3024696c1fbe3" +dependencies = [ + "serde", +] + [[package]] name = "i_float" version = "1.3.1" diff --git a/src/common/function/Cargo.toml b/src/common/function/Cargo.toml index 851703da26..d2aa4a86c3 100644 --- a/src/common/function/Cargo.toml +++ b/src/common/function/Cargo.toml @@ -12,6 +12,7 @@ default = ["geo"] geo = ["geohash", "h3o", "s2", "wkt", "geo-types", "dep:geo"] [dependencies] +ahash = "0.8" api.workspace = true arc-swap = "1.0" async-trait.workspace = true @@ -33,6 +34,7 @@ geo = { version = "0.29", optional = true } geo-types = { version = "0.7", optional = true } geohash = { version = "0.13", optional = true } h3o = { version = "0.6", optional = true } +hyperloglogplus = "0.4" jsonb.workspace = true nalgebra.workspace = true num = "0.4" diff --git a/src/common/function/src/aggr.rs b/src/common/function/src/aggr.rs index ab9281fbb7..be271d4d20 100644 --- a/src/common/function/src/aggr.rs +++ b/src/common/function/src/aggr.rs @@ -12,6 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. +mod hll; mod uddsketch_state; +pub(crate) use hll::HllStateType; +pub use hll::{HllState, HLL_MERGE_NAME, HLL_NAME}; pub use uddsketch_state::{UddSketchState, UDDSKETCH_STATE_NAME}; diff --git a/src/common/function/src/aggr/hll.rs b/src/common/function/src/aggr/hll.rs new file mode 100644 index 0000000000..b4df0d77f8 --- /dev/null +++ b/src/common/function/src/aggr/hll.rs @@ -0,0 +1,319 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::sync::Arc; + +use common_query::prelude::*; +use common_telemetry::trace; +use datafusion::arrow::array::ArrayRef; +use datafusion::common::cast::{as_binary_array, as_string_array}; +use datafusion::common::not_impl_err; +use datafusion::error::{DataFusionError, Result as DfResult}; +use datafusion::logical_expr::function::AccumulatorArgs; +use datafusion::logical_expr::{Accumulator as DfAccumulator, AggregateUDF}; +use datafusion::prelude::create_udaf; +use datatypes::arrow::datatypes::DataType; +use hyperloglogplus::{HyperLogLog, HyperLogLogPlus}; + +use crate::utils::FixedRandomState; + +pub const HLL_NAME: &str = "hll"; +pub const HLL_MERGE_NAME: &str = "hll_merge"; + +const DEFAULT_PRECISION: u8 = 14; + +pub(crate) type HllStateType = HyperLogLogPlus; + +pub struct HllState { + hll: HllStateType, +} + +impl std::fmt::Debug for HllState { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "HllState") + } +} + +impl Default for HllState { + fn default() -> Self { + Self::new() + } +} + +impl HllState { + pub fn new() -> Self { + Self { + // Safety: the DEFAULT_PRECISION is fixed and valid + hll: HllStateType::new(DEFAULT_PRECISION, FixedRandomState::new()).unwrap(), + } + } + + /// Create a UDF for the `hll` function. + /// + /// `hll` accepts a string column and aggregates the + /// values into a HyperLogLog state. + pub fn state_udf_impl() -> AggregateUDF { + create_udaf( + HLL_NAME, + vec![DataType::Utf8], + Arc::new(DataType::Binary), + Volatility::Immutable, + Arc::new(Self::create_accumulator), + Arc::new(vec![DataType::Binary]), + ) + } + + /// Create a UDF for the `hll_merge` function. + /// + /// `hll_merge` accepts a binary column of states generated by `hll` + /// and merges them into a single state. + pub fn merge_udf_impl() -> AggregateUDF { + create_udaf( + HLL_MERGE_NAME, + vec![DataType::Binary], + Arc::new(DataType::Binary), + Volatility::Immutable, + Arc::new(Self::create_merge_accumulator), + Arc::new(vec![DataType::Binary]), + ) + } + + fn update(&mut self, value: &str) { + self.hll.insert(value); + } + + fn merge(&mut self, raw: &[u8]) { + if let Ok(serialized) = bincode::deserialize::(raw) { + if let Ok(()) = self.hll.merge(&serialized) { + return; + } + } + trace!("Warning: Failed to merge HyperLogLog from {:?}", raw); + } + + fn create_accumulator(acc_args: AccumulatorArgs) -> DfResult> { + let data_type = acc_args.exprs[0].data_type(acc_args.schema)?; + + match data_type { + DataType::Utf8 => Ok(Box::new(HllState::new())), + other => not_impl_err!("{HLL_NAME} does not support data type: {other}"), + } + } + + fn create_merge_accumulator(acc_args: AccumulatorArgs) -> DfResult> { + let data_type = acc_args.exprs[0].data_type(acc_args.schema)?; + + match data_type { + DataType::Binary => Ok(Box::new(HllState::new())), + other => not_impl_err!("{HLL_MERGE_NAME} does not support data type: {other}"), + } + } +} + +impl DfAccumulator for HllState { + fn update_batch(&mut self, values: &[ArrayRef]) -> DfResult<()> { + let array = &values[0]; + + match array.data_type() { + DataType::Utf8 => { + let string_array = as_string_array(array)?; + for value in string_array.iter().flatten() { + self.update(value); + } + } + DataType::Binary => { + let binary_array = as_binary_array(array)?; + for v in binary_array.iter().flatten() { + self.merge(v); + } + } + _ => { + return not_impl_err!( + "HLL functions do not support data type: {}", + array.data_type() + ) + } + } + + Ok(()) + } + + fn evaluate(&mut self) -> DfResult { + Ok(ScalarValue::Binary(Some( + bincode::serialize(&self.hll).map_err(|e| { + DataFusionError::Internal(format!("Failed to serialize HyperLogLog: {}", e)) + })?, + ))) + } + + fn size(&self) -> usize { + std::mem::size_of_val(&self.hll) + } + + fn state(&mut self) -> DfResult> { + Ok(vec![ScalarValue::Binary(Some( + bincode::serialize(&self.hll).map_err(|e| { + DataFusionError::Internal(format!("Failed to serialize HyperLogLog: {}", e)) + })?, + ))]) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> DfResult<()> { + let array = &states[0]; + let binary_array = as_binary_array(array)?; + for v in binary_array.iter().flatten() { + self.merge(v); + } + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use datafusion::arrow::array::{BinaryArray, StringArray}; + + use super::*; + + #[test] + fn test_hll_basic() { + let mut state = HllState::new(); + state.update("1"); + state.update("2"); + state.update("3"); + + let result = state.evaluate().unwrap(); + if let ScalarValue::Binary(Some(bytes)) = result { + let mut hll: HllStateType = bincode::deserialize(&bytes).unwrap(); + assert_eq!(hll.count().trunc() as u32, 3); + } else { + panic!("Expected binary scalar value"); + } + } + + #[test] + fn test_hll_roundtrip() { + let mut state = HllState::new(); + state.update("1"); + state.update("2"); + + // Serialize + let serialized = state.evaluate().unwrap(); + + // Create new state and merge the serialized data + let mut new_state = HllState::new(); + if let ScalarValue::Binary(Some(bytes)) = &serialized { + new_state.merge(bytes); + + // Verify the merged state matches original + let result = new_state.evaluate().unwrap(); + if let ScalarValue::Binary(Some(new_bytes)) = result { + let mut original: HllStateType = bincode::deserialize(bytes).unwrap(); + let mut merged: HllStateType = bincode::deserialize(&new_bytes).unwrap(); + assert_eq!(original.count(), merged.count()); + } else { + panic!("Expected binary scalar value"); + } + } else { + panic!("Expected binary scalar value"); + } + } + + #[test] + fn test_hll_batch_update() { + let mut state = HllState::new(); + + // Test string values + let str_values = vec!["a", "b", "c", "d", "e", "f", "g", "h", "i"]; + let str_array = Arc::new(StringArray::from(str_values)) as ArrayRef; + state.update_batch(&[str_array]).unwrap(); + + let result = state.evaluate().unwrap(); + if let ScalarValue::Binary(Some(bytes)) = result { + let mut hll: HllStateType = bincode::deserialize(&bytes).unwrap(); + assert_eq!(hll.count().trunc() as u32, 9); + } else { + panic!("Expected binary scalar value"); + } + } + + #[test] + fn test_hll_merge_batch() { + let mut state1 = HllState::new(); + state1.update("1"); + let state1_binary = state1.evaluate().unwrap(); + + let mut state2 = HllState::new(); + state2.update("2"); + let state2_binary = state2.evaluate().unwrap(); + + let mut merged_state = HllState::new(); + if let (ScalarValue::Binary(Some(bytes1)), ScalarValue::Binary(Some(bytes2))) = + (&state1_binary, &state2_binary) + { + let binary_array = Arc::new(BinaryArray::from(vec![ + bytes1.as_slice(), + bytes2.as_slice(), + ])) as ArrayRef; + merged_state.merge_batch(&[binary_array]).unwrap(); + + let result = merged_state.evaluate().unwrap(); + if let ScalarValue::Binary(Some(bytes)) = result { + let mut hll: HllStateType = bincode::deserialize(&bytes).unwrap(); + assert_eq!(hll.count().trunc() as u32, 2); + } else { + panic!("Expected binary scalar value"); + } + } else { + panic!("Expected binary scalar values"); + } + } + + #[test] + fn test_hll_merge_function() { + // Create two HLL states with different values + let mut state1 = HllState::new(); + state1.update("1"); + state1.update("2"); + let state1_binary = state1.evaluate().unwrap(); + + let mut state2 = HllState::new(); + state2.update("2"); + state2.update("3"); + let state2_binary = state2.evaluate().unwrap(); + + // Create a merge state and merge both states + let mut merge_state = HllState::new(); + if let (ScalarValue::Binary(Some(bytes1)), ScalarValue::Binary(Some(bytes2))) = + (&state1_binary, &state2_binary) + { + let binary_array = Arc::new(BinaryArray::from(vec![ + bytes1.as_slice(), + bytes2.as_slice(), + ])) as ArrayRef; + merge_state.update_batch(&[binary_array]).unwrap(); + + let result = merge_state.evaluate().unwrap(); + if let ScalarValue::Binary(Some(bytes)) = result { + let mut hll: HllStateType = bincode::deserialize(&bytes).unwrap(); + // Should have 3 unique values: "1", "2", "3" + assert_eq!(hll.count().trunc() as u32, 3); + } else { + panic!("Expected binary scalar value"); + } + } else { + panic!("Expected binary scalar values"); + } + } +} diff --git a/src/common/function/src/function_registry.rs b/src/common/function/src/function_registry.rs index e4a3f66b2b..1761f6ef50 100644 --- a/src/common/function/src/function_registry.rs +++ b/src/common/function/src/function_registry.rs @@ -22,6 +22,7 @@ use crate::function::{AsyncFunctionRef, FunctionRef}; use crate::scalars::aggregate::{AggregateFunctionMetaRef, AggregateFunctions}; use crate::scalars::date::DateFunction; use crate::scalars::expression::ExpressionFunction; +use crate::scalars::hll_count::HllCalcFunction; use crate::scalars::json::JsonFunction; use crate::scalars::matches::MatchesFunction; use crate::scalars::math::MathFunction; @@ -107,6 +108,7 @@ pub static FUNCTION_REGISTRY: Lazy> = Lazy::new(|| { DateFunction::register(&function_registry); ExpressionFunction::register(&function_registry); UddSketchCalcFunction::register(&function_registry); + HllCalcFunction::register(&function_registry); // Aggregate functions AggregateFunctions::register(&function_registry); diff --git a/src/common/function/src/scalars.rs b/src/common/function/src/scalars.rs index c6b9d5dc9d..cd39880b90 100644 --- a/src/common/function/src/scalars.rs +++ b/src/common/function/src/scalars.rs @@ -22,6 +22,7 @@ pub mod matches; pub mod math; pub mod vector; +pub(crate) mod hll_count; #[cfg(test)] pub(crate) mod test; pub(crate) mod timestamp; diff --git a/src/common/function/src/scalars/hll_count.rs b/src/common/function/src/scalars/hll_count.rs new file mode 100644 index 0000000000..e2a00d9d49 --- /dev/null +++ b/src/common/function/src/scalars/hll_count.rs @@ -0,0 +1,175 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Implementation of the scalar function `hll_count`. + +use std::fmt; +use std::fmt::Display; +use std::sync::Arc; + +use common_query::error::{DowncastVectorSnafu, InvalidFuncArgsSnafu, Result}; +use common_query::prelude::{Signature, Volatility}; +use datatypes::data_type::ConcreteDataType; +use datatypes::prelude::Vector; +use datatypes::scalars::{ScalarVector, ScalarVectorBuilder}; +use datatypes::vectors::{BinaryVector, MutableVector, UInt64VectorBuilder, VectorRef}; +use hyperloglogplus::HyperLogLog; +use snafu::OptionExt; + +use crate::aggr::HllStateType; +use crate::function::{Function, FunctionContext}; +use crate::function_registry::FunctionRegistry; + +const NAME: &str = "hll_count"; + +/// HllCalcFunction implements the scalar function `hll_count`. +/// +/// It accepts one argument: +/// 1. The serialized HyperLogLogPlus state, as produced by the aggregator (binary). +/// +/// For each row, it deserializes the sketch and returns the estimated cardinality. +#[derive(Debug, Default)] +pub struct HllCalcFunction; + +impl HllCalcFunction { + pub fn register(registry: &FunctionRegistry) { + registry.register(Arc::new(HllCalcFunction)); + } +} + +impl Display for HllCalcFunction { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", NAME.to_ascii_uppercase()) + } +} + +impl Function for HllCalcFunction { + fn name(&self) -> &str { + NAME + } + + fn return_type(&self, _input_types: &[ConcreteDataType]) -> Result { + Ok(ConcreteDataType::uint64_datatype()) + } + + fn signature(&self) -> Signature { + // Only argument: HyperLogLogPlus state (binary) + Signature::exact( + vec![ConcreteDataType::binary_datatype()], + Volatility::Immutable, + ) + } + + fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result { + if columns.len() != 1 { + return InvalidFuncArgsSnafu { + err_msg: format!("hll_count expects 1 argument, got {}", columns.len()), + } + .fail(); + } + + let hll_vec = columns[0] + .as_any() + .downcast_ref::() + .with_context(|| DowncastVectorSnafu { + err_msg: format!("expect BinaryVector, got {}", columns[0].vector_type_name()), + })?; + let len = hll_vec.len(); + let mut builder = UInt64VectorBuilder::with_capacity(len); + + for i in 0..len { + let hll_opt = hll_vec.get_data(i); + + if hll_opt.is_none() { + builder.push_null(); + continue; + } + + let hll_bytes = hll_opt.unwrap(); + + // Deserialize the HyperLogLogPlus from its bincode representation + let mut hll: HllStateType = match bincode::deserialize(hll_bytes) { + Ok(h) => h, + Err(e) => { + common_telemetry::trace!("Failed to deserialize HyperLogLogPlus: {}", e); + builder.push_null(); + continue; + } + }; + + builder.push(Some(hll.count().round() as u64)); + } + + Ok(builder.to_vector()) + } +} + +#[cfg(test)] +mod tests { + use datatypes::vectors::BinaryVector; + + use super::*; + use crate::utils::FixedRandomState; + + #[test] + fn test_hll_count_function() { + let function = HllCalcFunction; + assert_eq!("hll_count", function.name()); + assert_eq!( + ConcreteDataType::uint64_datatype(), + function + .return_type(&[ConcreteDataType::uint64_datatype()]) + .unwrap() + ); + + // Create a test HLL + let mut hll = HllStateType::new(14, FixedRandomState::new()).unwrap(); + for i in 1..=10 { + hll.insert(&i.to_string()); + } + + let serialized_bytes = bincode::serialize(&hll).unwrap(); + let args: Vec = vec![Arc::new(BinaryVector::from(vec![Some(serialized_bytes)]))]; + + let result = function.eval(FunctionContext::default(), &args).unwrap(); + assert_eq!(result.len(), 1); + + // Test cardinality estimate + if let datatypes::value::Value::UInt64(v) = result.get(0) { + assert_eq!(v, 10); + } else { + panic!("Expected uint64 value"); + } + } + + #[test] + fn test_hll_count_function_errors() { + let function = HllCalcFunction; + + // Test with invalid number of arguments + let args: Vec = vec![]; + let result = function.eval(FunctionContext::default(), &args); + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("hll_count expects 1 argument")); + + // Test with invalid binary data + let args: Vec = vec![Arc::new(BinaryVector::from(vec![Some(vec![1, 2, 3])]))]; // Invalid binary data + let result = function.eval(FunctionContext::default(), &args).unwrap(); + assert_eq!(result.len(), 1); + assert!(matches!(result.get(0), datatypes::value::Value::Null)); + } +} diff --git a/src/common/function/src/utils.rs b/src/common/function/src/utils.rs index f2c18d5f6c..b2daac35b8 100644 --- a/src/common/function/src/utils.rs +++ b/src/common/function/src/utils.rs @@ -12,6 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::hash::BuildHasher; + +use ahash::RandomState; +use serde::{Deserialize, Serialize}; + /// Escapes special characters in the provided pattern string for `LIKE`. /// /// Specifically, it prefixes the backslash (`\`), percent (`%`), and underscore (`_`) @@ -32,6 +37,71 @@ pub fn escape_like_pattern(pattern: &str) -> String { }) .collect::() } + +/// A random state with fixed seeds. +/// +/// This is used to ensure that the hash values are consistent across +/// different processes, and easy to serialize and deserialize. +#[derive(Debug)] +pub struct FixedRandomState { + state: RandomState, +} + +impl FixedRandomState { + // some random seeds + const RANDOM_SEED_0: u64 = 0x517cc1b727220a95; + const RANDOM_SEED_1: u64 = 0x428a2f98d728ae22; + const RANDOM_SEED_2: u64 = 0x7137449123ef65cd; + const RANDOM_SEED_3: u64 = 0xb5c0fbcfec4d3b2f; + + pub fn new() -> Self { + Self { + state: ahash::RandomState::with_seeds( + Self::RANDOM_SEED_0, + Self::RANDOM_SEED_1, + Self::RANDOM_SEED_2, + Self::RANDOM_SEED_3, + ), + } + } +} + +impl Default for FixedRandomState { + fn default() -> Self { + Self::new() + } +} + +impl BuildHasher for FixedRandomState { + type Hasher = ahash::AHasher; + + fn build_hasher(&self) -> Self::Hasher { + self.state.build_hasher() + } + + fn hash_one(&self, x: T) -> u64 { + self.state.hash_one(x) + } +} + +impl Serialize for FixedRandomState { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + serializer.serialize_unit() + } +} + +impl<'de> Deserialize<'de> for FixedRandomState { + fn deserialize(_deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + Ok(Self::new()) + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/query/src/datafusion/planner.rs b/src/query/src/datafusion/planner.rs index 909aa5460e..25f1015735 100644 --- a/src/query/src/datafusion/planner.rs +++ b/src/query/src/datafusion/planner.rs @@ -18,7 +18,9 @@ use std::sync::Arc; use arrow_schema::DataType; use catalog::table_source::DfTableSourceProvider; -use common_function::aggr::{UddSketchState, UDDSKETCH_STATE_NAME}; +use common_function::aggr::{ + HllState, UddSketchState, HLL_MERGE_NAME, HLL_NAME, UDDSKETCH_STATE_NAME, +}; use common_function::scalars::udf::create_udf; use common_query::logical_plan::create_aggregate_function; use datafusion::common::TableReference; @@ -169,6 +171,12 @@ impl ContextProvider for DfContextProviderAdapter { if name == UDDSKETCH_STATE_NAME { return Some(Arc::new(UddSketchState::udf_impl())); } + if name == HLL_NAME { + return Some(Arc::new(HllState::state_udf_impl())); + } + if name == HLL_MERGE_NAME { + return Some(Arc::new(HllState::merge_udf_impl())); + } self.engine_state.aggregate_function(name).map_or_else( || self.session_state.aggregate_functions().get(name).cloned(), diff --git a/src/query/src/query_engine/default_serializer.rs b/src/query/src/query_engine/default_serializer.rs index 60ca46e1fd..d35feeb1a2 100644 --- a/src/query/src/query_engine/default_serializer.rs +++ b/src/query/src/query_engine/default_serializer.rs @@ -15,7 +15,7 @@ use std::sync::Arc; use common_error::ext::BoxedError; -use common_function::aggr::UddSketchState; +use common_function::aggr::{HllState, UddSketchState}; use common_function::function_registry::FUNCTION_REGISTRY; use common_function::scalars::udf::create_udf; use common_query::error::RegisterUdfSnafu; @@ -127,6 +127,8 @@ impl SubstraitPlanDecoder for DefaultPlanDecoder { .register_udf(udf) .context(RegisterUdfSnafu { name: func.name() })?; let _ = session_state.register_udaf(Arc::new(UddSketchState::udf_impl())); + let _ = session_state.register_udaf(Arc::new(HllState::state_udf_impl())); + let _ = session_state.register_udaf(Arc::new(HllState::merge_udf_impl())); } let logical_plan = DFLogicalSubstraitConvertor .decode(message, session_state) diff --git a/tests/cases/standalone/common/aggregate/hll.result b/tests/cases/standalone/common/aggregate/hll.result new file mode 100644 index 0000000000..092fe069bd --- /dev/null +++ b/tests/cases/standalone/common/aggregate/hll.result @@ -0,0 +1,84 @@ +CREATE TABLE test_hll ( + `id` INT PRIMARY KEY, + `value` STRING, + `ts` timestamp time index default now() +); + +Affected Rows: 0 + +INSERT INTO test_hll (`id`, `value`) VALUES + (1, "a"), + (2, "b"), + (5, "e"), + (6, "f"), + (7, "g"), + (8, "h"), + (9, "i"), + (10, "j"), + (11, "i"), + (12, "j"), + (13, "i"), + (14, "n"), + (15, "o"); + +Affected Rows: 13 + +select hll_count(hll(`value`)) from test_hll; + ++--------------------------------+ +| hll_count(hll(test_hll.value)) | ++--------------------------------+ +| 10 | ++--------------------------------+ + +INSERT INTO test_hll (`id`, `value`) VALUES + (16, "b"), + (17, "i"), + (18, "j"), + (19, "s"), + (20, "t"); + +Affected Rows: 5 + +select hll_count(hll(`value`)) from test_hll; + ++--------------------------------+ +| hll_count(hll(test_hll.value)) | ++--------------------------------+ +| 12 | ++--------------------------------+ + +create table test_hll_merge ( + `id` INT PRIMARY KEY, + `state` BINARY, + `ts` timestamp time index default now() +); + +Affected Rows: 0 + +insert into test_hll_merge (`id`, `state`) +select 1, hll(`value`) from test_hll; + +Affected Rows: 1 + +insert into test_hll_merge (`id`, `state`) +select 2, hll(`value`) from test_hll; + +Affected Rows: 1 + +select hll_count(hll_merge(`state`)) from test_hll_merge; + ++--------------------------------------------+ +| hll_count(hll_merge(test_hll_merge.state)) | ++--------------------------------------------+ +| 12 | ++--------------------------------------------+ + +drop table test_hll; + +Affected Rows: 0 + +drop table test_hll_merge; + +Affected Rows: 0 + diff --git a/tests/cases/standalone/common/aggregate/hll.sql b/tests/cases/standalone/common/aggregate/hll.sql new file mode 100644 index 0000000000..7aa029bcc7 --- /dev/null +++ b/tests/cases/standalone/common/aggregate/hll.sql @@ -0,0 +1,49 @@ +CREATE TABLE test_hll ( + `id` INT PRIMARY KEY, + `value` STRING, + `ts` timestamp time index default now() +); + +INSERT INTO test_hll (`id`, `value`) VALUES + (1, "a"), + (2, "b"), + (5, "e"), + (6, "f"), + (7, "g"), + (8, "h"), + (9, "i"), + (10, "j"), + (11, "i"), + (12, "j"), + (13, "i"), + (14, "n"), + (15, "o"); + +select hll_count(hll(`value`)) from test_hll; + +INSERT INTO test_hll (`id`, `value`) VALUES + (16, "b"), + (17, "i"), + (18, "j"), + (19, "s"), + (20, "t"); + +select hll_count(hll(`value`)) from test_hll; + +create table test_hll_merge ( + `id` INT PRIMARY KEY, + `state` BINARY, + `ts` timestamp time index default now() +); + +insert into test_hll_merge (`id`, `state`) +select 1, hll(`value`) from test_hll; + +insert into test_hll_merge (`id`, `state`) +select 2, hll(`value`) from test_hll; + +select hll_count(hll_merge(`state`)) from test_hll_merge; + +drop table test_hll; + +drop table test_hll_merge;