diff --git a/src/common/function/src/aggrs/vector.rs b/src/common/function/src/aggrs/vector.rs index 5af064d002..03489a51d4 100644 --- a/src/common/function/src/aggrs/vector.rs +++ b/src/common/function/src/aggrs/vector.rs @@ -12,10 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. +use crate::aggrs::vector::avg::VectorAvg; use crate::aggrs::vector::product::VectorProduct; use crate::aggrs::vector::sum::VectorSum; use crate::function_registry::FunctionRegistry; +mod avg; mod product; mod sum; @@ -25,5 +27,6 @@ impl VectorFunction { pub fn register(registry: &FunctionRegistry) { registry.register_aggr(VectorSum::uadf_impl()); registry.register_aggr(VectorProduct::uadf_impl()); + registry.register_aggr(VectorAvg::uadf_impl()); } } diff --git a/src/common/function/src/aggrs/vector/avg.rs b/src/common/function/src/aggrs/vector/avg.rs new file mode 100644 index 0000000000..ddf1823d28 --- /dev/null +++ b/src/common/function/src/aggrs/vector/avg.rs @@ -0,0 +1,270 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+use std::borrow::Cow;
+use std::sync::Arc;
+
+use arrow::array::{Array, ArrayRef, AsArray, BinaryArray, LargeStringArray, StringArray};
+use arrow::compute::sum;
+use arrow::datatypes::UInt64Type;
+use arrow_schema::{DataType, Field};
+use datafusion_common::{Result, ScalarValue};
+use datafusion_expr::{
+    Accumulator, AggregateUDF, Signature, SimpleAggregateUDF, TypeSignature, Volatility,
+};
+use datafusion_functions_aggregate_common::accumulator::AccumulatorArgs;
+use nalgebra::{Const, DVector, DVectorView, Dyn, OVector};
+
+use crate::scalars::vector::impl_conv::{
+    binlit_as_veclit, parse_veclit_from_strlit, veclit_to_binlit,
+};
+
+/// The accumulator for the `vec_avg` aggregate function.
+///
+/// It keeps the element-wise running sum of every non-null input vector together with the
+/// number of vectors that went into that sum; `evaluate` divides the former by the latter.
+#[derive(Debug, Default)]
+pub struct VectorAvg {
+    /// Element-wise running sum; `None` until the first non-null vector is seen.
+    sum: Option<OVector<f32, Dyn>>,
+    /// Number of vectors accumulated into `sum`.
+    count: u64,
+}
+
+impl VectorAvg {
+    /// Create a new `AggregateUDF` for the `vec_avg` aggregate function.
+    ///
+    /// The function accepts a single `Utf8`, `LargeUtf8` or `Binary` argument and returns
+    /// `Binary`. Its intermediate state is the pair (`sum`: `Binary`, `count`: `UInt64`).
+    pub fn uadf_impl() -> AggregateUDF {
+        let signature = Signature::one_of(
+            vec![
+                TypeSignature::Exact(vec![DataType::Utf8]),
+                TypeSignature::Exact(vec![DataType::LargeUtf8]),
+                TypeSignature::Exact(vec![DataType::Binary]),
+            ],
+            Volatility::Immutable,
+        );
+        let udaf = SimpleAggregateUDF::new_with_signature(
+            "vec_avg",
+            signature,
+            DataType::Binary,
+            Arc::new(Self::accumulator),
+            vec![
+                Arc::new(Field::new("sum", DataType::Binary, true)),
+                Arc::new(Field::new("count", DataType::UInt64, true)),
+            ],
+        );
+        AggregateUDF::from(udaf)
+    }
+
+    /// Accumulator factory handed to `SimpleAggregateUDF`.
+    ///
+    /// # Errors
+    ///
+    /// Returns an internal error when the input schema does not consist of exactly one
+    /// `Utf8`, `LargeUtf8` or `Binary` field.
+    fn accumulator(args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
+        if args.schema.fields().len() != 1 {
+            return Err(datafusion_common::DataFusionError::Internal(format!(
+                "expect creating `VEC_AVG` with only one input field, actual {}",
+                args.schema.fields().len()
+            )));
+        }
+
+        let t = args.schema.field(0).data_type();
+        if !matches!(t, DataType::Utf8 | DataType::LargeUtf8 | DataType::Binary) {
+            return Err(datafusion_common::DataFusionError::Internal(format!(
+                "unexpected input datatype {t} when creating `VEC_AVG`"
+            )));
+        }
+
+        Ok(Box::new(VectorAvg::default()))
+    }
+
+    /// Returns the running sum, initializing it to a zero vector of `len` elements on first use.
+    fn inner(&mut self, len: usize) -> &mut OVector<f32, Dyn> {
+        self.sum
+            .get_or_insert_with(|| OVector::zeros_generic(Dyn(len), Const::<1>))
+    }
+
+    /// Folds one batch into the accumulator.
+    ///
+    /// `values[0]` carries the vectors (string literals or binary encoded); null entries are
+    /// skipped. When `is_update` is `true` every non-null vector counts once. Otherwise the
+    /// batch is a set of partial states and `values[1]` carries their `UInt64` counts, which
+    /// are summed.
+    ///
+    /// # Errors
+    ///
+    /// Fails when a vector cannot be decoded, when the column type is unsupported, or when a
+    /// vector's dimension differs from the other vectors or from the running sum.
+    fn update(&mut self, values: &[ArrayRef], is_update: bool) -> Result<()> {
+        if values.is_empty() {
+            return Ok(());
+        };
+
+        let vectors = match values[0].data_type() {
+            DataType::Utf8 => {
+                let arr: &StringArray = values[0].as_string();
+                arr.iter()
+                    .filter_map(|x| x.map(|s| parse_veclit_from_strlit(s).map_err(Into::into)))
+                    .map(|x| x.map(Cow::Owned))
+                    .collect::<Result<Vec<_>>>()?
+            }
+            DataType::LargeUtf8 => {
+                let arr: &LargeStringArray = values[0].as_string();
+                arr.iter()
+                    .filter_map(|x| x.map(|s| parse_veclit_from_strlit(s).map_err(Into::into)))
+                    .map(|x: Result<Vec<f32>>| x.map(Cow::Owned))
+                    .collect::<Result<Vec<_>>>()?
+            }
+            DataType::Binary => {
+                let arr: &BinaryArray = values[0].as_binary();
+                arr.iter()
+                    .filter_map(|x| x.map(|b| binlit_as_veclit(b).map_err(Into::into)))
+                    .collect::<Result<Vec<_>>>()?
+            }
+            _ => {
+                return Err(datafusion_common::DataFusionError::NotImplemented(format!(
+                    "unsupported data type {} for `VEC_AVG`",
+                    values[0].data_type()
+                )));
+            }
+        };
+
+        // Only null entries in this batch: nothing to accumulate.
+        if vectors.is_empty() {
+            return Ok(());
+        }
+
+        let len = if is_update {
+            vectors.len() as u64
+        } else {
+            sum(values[1].as_primitive::<UInt64Type>()).unwrap_or_default()
+        };
+
+        // A previously accumulated sum fixes the expected dimension; otherwise the first
+        // vector of this batch does. Validating against the running sum here reports a
+        // mismatch as an execution error instead of letting the `+=` below panic.
+        let dims = self.sum.as_ref().map_or(vectors[0].len(), |acc| acc.len());
+        let mut sum = DVector::zeros(dims);
+        for v in vectors {
+            if v.len() != dims {
+                return Err(datafusion_common::DataFusionError::Execution(
+                    "vectors length not match: VEC_AVG".to_string(),
+                ));
+            }
+            let v_view = DVectorView::from_slice(&v, dims);
+            sum += &v_view;
+        }
+
+        // State is only touched after the whole batch validated.
+        *self.inner(dims) += sum;
+        self.count += len;
+
+        Ok(())
+    }
+}
+
+impl Accumulator for VectorAvg {
+    /// Serializes the partial state as `[sum (Binary), count (UInt64)]`.
+    fn state(&mut self) -> Result<Vec<ScalarValue>> {
+        let vector = match &self.sum {
+            None => ScalarValue::Binary(None),
+            Some(sum) => ScalarValue::Binary(Some(veclit_to_binlit(sum.as_slice()))),
+        };
+        Ok(vec![vector, ScalarValue::from(self.count)])
+    }
+
+    /// Accumulates raw input vectors.
+    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        self.update(values, true)
+    }
+
+    /// Accumulates partial states produced by [`Accumulator::state`].
+    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+        self.update(states, false)
+    }
+
+    /// Returns the element-wise average as a binary vector, or a null binary when no
+    /// vector has been accumulated.
+    fn evaluate(&mut self) -> Result<ScalarValue> {
+        match &self.sum {
+            None => Ok(ScalarValue::Binary(None)),
+            Some(sum) => Ok(ScalarValue::Binary(Some(veclit_to_binlit(
+                (sum / self.count as f32).as_slice(),
+            )))),
+        }
+    }
+
+    /// Size of the accumulator in bytes, including the heap buffer of the running sum.
+    fn size(&self) -> usize {
+        size_of_val(self) + self.sum.as_ref().map_or(0, |acc| size_of_val(acc.as_slice()))
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::sync::Arc;
+
+    use arrow::array::StringArray;
+    use datatypes::scalars::ScalarVector;
+    use datatypes::vectors::{ConstantVector, StringVector, Vector};
+
+    use super::*;
+
+    #[test]
+    fn test_update_batch() {
+        // test update empty batch, expect not updating anything
+        let mut vec_avg = VectorAvg::default();
+        vec_avg.update_batch(&[]).unwrap();
+        assert!(vec_avg.sum.is_none());
+        assert_eq!(ScalarValue::Binary(None), vec_avg.evaluate().unwrap());
+
+        // test update a batch of two not-null values
+        let mut vec_avg = VectorAvg::default();
+        let v: Vec<ArrayRef> = vec![Arc::new(StringArray::from(vec![
+            Some("[1.0,2.0,3.0]".to_string()),
+            Some("[4.0,5.0,6.0]".to_string()),
+        ]))];
+        vec_avg.update_batch(&v).unwrap();
+        assert_eq!(
+            ScalarValue::Binary(Some(veclit_to_binlit(&[2.5, 3.5, 4.5]))),
+            vec_avg.evaluate().unwrap()
+        );
+
+        // test update one null value
+        let mut vec_avg = VectorAvg::default();
+        let v: Vec<ArrayRef> = vec![Arc::new(StringArray::from(vec![Option::<String>::None]))];
+        vec_avg.update_batch(&v).unwrap();
+        assert_eq!(ScalarValue::Binary(None), vec_avg.evaluate().unwrap());
+
+        // test update no null-value batch
+        let mut vec_avg = VectorAvg::default();
+        let v: Vec<ArrayRef> = vec![Arc::new(StringArray::from(vec![
+            Some("[1.0,2.0,3.0]".to_string()),
+            Some("[4.0,5.0,6.0]".to_string()),
+            Some("[7.0,8.0,9.0]".to_string()),
+        ]))];
+        vec_avg.update_batch(&v).unwrap();
+        assert_eq!(
+            ScalarValue::Binary(Some(veclit_to_binlit(&[4.0, 5.0, 6.0]))),
+            vec_avg.evaluate().unwrap()
+        );
+
+        // test update null-value batch
+        let mut vec_avg = VectorAvg::default();
+        let v: Vec<ArrayRef> = vec![Arc::new(StringArray::from(vec![
+            Some("[1.0,2.0,3.0]".to_string()),
+            None,
+            Some("[7.0,8.0,9.0]".to_string()),
+        ]))];
+        vec_avg.update_batch(&v).unwrap();
+        assert_eq!(
+            ScalarValue::Binary(Some(veclit_to_binlit(&[4.0, 5.0, 6.0]))),
+            vec_avg.evaluate().unwrap()
+        );
+
+        let mut vec_avg = VectorAvg::default();
+        let v: Vec<ArrayRef> = vec![Arc::new(StringArray::from(vec![
+            None,
+            Some("[4.0,5.0,6.0]".to_string()),
+            Some("[7.0,8.0,9.0]".to_string()),
+        ]))];
+        vec_avg.update_batch(&v).unwrap();
+        assert_eq!(
+            ScalarValue::Binary(Some(veclit_to_binlit(&[5.5, 6.5, 7.5]))),
+            vec_avg.evaluate().unwrap()
+        );
+
+        // test update with constant vector
+        let mut vec_avg = VectorAvg::default();
+        let v: Vec<ArrayRef> = vec![
+            Arc::new(ConstantVector::new(
+                Arc::new(StringVector::from_vec(vec!["[1.0,2.0,3.0]".to_string()])),
+                4,
+            ))
+            .to_arrow_array(),
+        ];
+        vec_avg.update_batch(&v).unwrap();
+        assert_eq!(
+            ScalarValue::Binary(Some(veclit_to_binlit(&[1.0, 2.0, 3.0]))),
+            vec_avg.evaluate().unwrap()
+        );
+
+        // test dimension mismatch across batches, expect an error rather than a panic
+        let mut vec_avg = VectorAvg::default();
+        let v: Vec<ArrayRef> = vec![Arc::new(StringArray::from(vec![Some(
+            "[1.0,2.0,3.0]".to_string(),
+        )]))];
+        vec_avg.update_batch(&v).unwrap();
+        let v: Vec<ArrayRef> = vec![Arc::new(StringArray::from(vec![Some(
+            "[1.0,2.0]".to_string(),
+        )]))];
+        assert!(vec_avg.update_batch(&v).is_err());
+    }
+}
diff --git a/src/common/function/src/scalars/vector.rs b/src/common/function/src/scalars/vector.rs index 75d66f03c5..f265cfe53a 100644 --- a/src/common/function/src/scalars/vector.rs +++ b/src/common/function/src/scalars/vector.rs @@ -14,6 +14,7 @@ mod convert; mod distance; +mod elem_avg; mod elem_product; mod elem_sum; pub mod impl_conv; @@ -64,6 +65,7 @@ impl VectorFunction { registry.register_scalar(vector_subvector::VectorSubvectorFunction::default()); registry.register_scalar(elem_sum::ElemSumFunction::default()); registry.register_scalar(elem_product::ElemProductFunction::default()); + registry.register_scalar(elem_avg::ElemAvgFunction::default()); } } diff --git a/src/common/function/src/scalars/vector/elem_avg.rs b/src/common/function/src/scalars/vector/elem_avg.rs new file mode 100644 index 0000000000..7ebee3ad41 --- /dev/null +++ b/src/common/function/src/scalars/vector/elem_avg.rs @@ -0,0 +1,128 @@ +// 
Copyright 2023 Greptime Team
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::fmt::Display;
+
+use datafusion::arrow::datatypes::DataType;
+use datafusion::logical_expr::ColumnarValue;
+use datafusion_common::ScalarValue;
+use datafusion_expr::type_coercion::aggregates::{BINARYS, STRINGS};
+use datafusion_expr::{ScalarFunctionArgs, Signature, TypeSignature, Volatility};
+use nalgebra::DVectorView;
+
+use crate::function::Function;
+use crate::scalars::vector::{VectorCalculator, impl_conv};
+
+/// SQL name under which the function is registered.
+const NAME: &str = "vec_elem_avg";
+
+/// Scalar function `vec_elem_avg(vector)`: the mean of the elements of one vector,
+/// returned as `Float32`. A null input yields a null output.
+#[derive(Debug, Clone)]
+pub(crate) struct ElemAvgFunction {
+    signature: Signature,
+}
+
+impl Default for ElemAvgFunction {
+    /// Builds the function with a one-argument signature accepting any of the `STRINGS`
+    /// types, any of the `BINARYS` types, or `BinaryView`.
+    fn default() -> Self {
+        Self {
+            signature: Signature::one_of(
+                vec![
+                    TypeSignature::Uniform(1, STRINGS.to_vec()),
+                    TypeSignature::Uniform(1, BINARYS.to_vec()),
+                    TypeSignature::Uniform(1, vec![DataType::BinaryView]),
+                ],
+                Volatility::Immutable,
+            ),
+        }
+    }
+}
+
+impl Function for ElemAvgFunction {
+    fn name(&self) -> &str {
+        NAME
+    }
+
+    /// The result is always `Float32`, whatever the argument type.
+    fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
+        Ok(DataType::Float32)
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    /// Decodes the argument as a vector literal and returns the mean of its elements.
+    ///
+    /// # Errors
+    ///
+    /// Propagates the error from `impl_conv::as_veclit` when the argument cannot be
+    /// decoded, and any error raised by `VectorCalculator::invoke_with_single_argument`.
+    fn invoke_with_args(
+        &self,
+        args: ScalarFunctionArgs,
+    ) -> datafusion_common::Result<ColumnarValue> {
+        // Per-value kernel: `None` (null input) stays `None`, otherwise take the mean.
+        let body = |v0: &ScalarValue| -> datafusion_common::Result<ScalarValue> {
+            let v0 =
+                impl_conv::as_veclit(v0)?.map(|v0| DVectorView::from_slice(&v0, v0.len()).mean());
+            Ok(ScalarValue::Float32(v0))
+        };
+
+        let calculator = VectorCalculator {
+            name: self.name(),
+            func: body,
+        };
+        calculator.invoke_with_single_argument(args)
+    }
+}
+
+impl Display for ElemAvgFunction {
+    /// Formats the function as its upper-cased SQL name.
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "{}", NAME.to_ascii_uppercase())
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::sync::Arc;
+
+    use arrow::array::StringViewArray;
+    use arrow_schema::Field;
+    use datafusion::arrow::array::{Array, AsArray};
+    use datafusion::arrow::datatypes::Float32Type;
+    use datafusion_common::config::ConfigOptions;
+
+    use super::*;
+
+    #[test]
+    fn test_elem_avg() {
+        let func = ElemAvgFunction::default();
+
+        // Three vectors plus one null row; the null must come back as null.
+        let input = Arc::new(StringViewArray::from(vec![
+            Some("[1.0,2.0,3.0]".to_string()),
+            Some("[4.0,5.0,6.0]".to_string()),
+            Some("[7.0,8.0,9.0]".to_string()),
+            None,
+        ]));
+
+        let result = func
+            .invoke_with_args(ScalarFunctionArgs {
+                args: vec![ColumnarValue::Array(input.clone())],
+                arg_fields: vec![],
+                number_rows: input.len(),
+                return_field: Arc::new(Field::new("x", DataType::Float32, true)),
+                config_options: Arc::new(ConfigOptions::new()),
+            })
+            .and_then(|v| ColumnarValue::values_to_arrays(&[v]))
+            .map(|mut a| a.remove(0))
+            .unwrap();
+        let result = result.as_primitive::<Float32Type>();
+
+        assert_eq!(result.len(), 4);
+        assert_eq!(result.value(0), 2.0);
+        assert_eq!(result.value(1), 5.0);
+        assert_eq!(result.value(2), 8.0);
+        assert!(result.is_null(3));
+    }
+}
diff --git a/src/query/src/tests.rs b/src/query/src/tests.rs index c70381d32f..4b12464b73 100644 --- a/src/query/src/tests.rs +++ b/src/query/src/tests.rs @@ -26,6 +26,7 @@ mod query_engine_test; mod time_range_filter_test; mod function; +mod vec_avg_test; mod vec_product_test; mod vec_sum_test; diff --git a/src/query/src/tests/vec_avg_test.rs b/src/query/src/tests/vec_avg_test.rs new file mode 100644 index 0000000000..46bb3528a9 --- /dev/null +++ b/src/query/src/tests/vec_avg_test.rs @@ -0,0 +1,60 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the 
Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::ops::AddAssign;
+
+use common_function::scalars::vector::impl_conv::{as_veclit, veclit_to_binlit};
+use datafusion_common::ScalarValue;
+use datatypes::prelude::Value;
+use nalgebra::{Const, DVectorView, Dyn, OVector};
+
+use crate::tests::{exec_selection, function};
+
+/// Runs `VEC_AVG` through the query engine and compares it with an average computed
+/// directly from the rows of the `vectors` table.
+///
+/// The reference value mirrors the aggregate's semantics: NULL rows are skipped, the sum
+/// is divided by the number of non-null vectors, and the result is NULL only when no
+/// vector was seen.
+#[tokio::test]
+async fn test_vec_avg_aggregator() -> Result<(), common_query::error::Error> {
+    common_telemetry::init_default_ut_logging();
+    let engine = function::create_query_engine_for_vector10x3();
+    let sql = "select VEC_AVG(vector) as vec_avg from vectors";
+    let result = exec_selection(engine.clone(), sql).await;
+    let value = function::get_value_from_batches("vec_avg", result);
+
+    let mut expected_sum: Option<OVector<f32, Dyn>> = None;
+    let mut non_null = 0usize;
+
+    let sql = "SELECT vector FROM vectors";
+    let vectors = exec_selection(engine, sql).await;
+
+    let column = vectors[0].column(0).to_arrow_array();
+    for i in 0..column.len() {
+        let v = ScalarValue::try_from_array(&column, i)?;
+        // `VEC_AVG` ignores NULL inputs, so the reference computation does too.
+        let Some(vector) = as_veclit(&v)? else {
+            continue;
+        };
+        non_null += 1;
+        // Size the accumulator from the data instead of assuming three dimensions.
+        expected_sum
+            .get_or_insert_with(|| OVector::zeros_generic(Dyn(vector.len()), Const::<1>))
+            .add_assign(&DVectorView::from_slice(&vector, vector.len()));
+    }
+    let expected_value = match expected_sum.map(|mut v| {
+        v /= non_null as f32;
+        veclit_to_binlit(v.as_slice())
+    }) {
+        None => Value::Null,
+        Some(bytes) => Value::from(bytes),
+    };
+    assert_eq!(value, expected_value);
+
+    Ok(())
+}
diff --git a/tests/cases/standalone/common/function/vector/vector.result b/tests/cases/standalone/common/function/vector/vector.result index 57a37d638d..c546f9ad25 100644 --- a/tests/cases/standalone/common/function/vector/vector.result +++ b/tests/cases/standalone/common/function/vector/vector.result @@ -150,6 +150,38 @@ SELECT vec_elem_sum(parse_vec('[-1.0, -2.0, -3.0]')); | -6.0 | +-----------------------------------------------------+ +SELECT vec_elem_avg('[1.0, 2.0, 3.0]'); + ++---------------------------------------+ +| vec_elem_avg(Utf8("[1.0, 2.0, 3.0]")) | ++---------------------------------------+ +| 2.0 | ++---------------------------------------+ + +SELECT vec_elem_avg('[-1.0, -2.0, -3.0]'); + ++------------------------------------------+ +| vec_elem_avg(Utf8("[-1.0, -2.0, -3.0]")) | ++------------------------------------------+ +| -2.0 | ++------------------------------------------+ + +SELECT vec_elem_avg(parse_vec('[1.0, 2.0, 3.0]')); + ++--------------------------------------------------+ +| vec_elem_avg(parse_vec(Utf8("[1.0, 2.0, 3.0]"))) | ++--------------------------------------------------+ +| 2.0 | ++--------------------------------------------------+ + +SELECT vec_elem_avg(parse_vec('[-1.0, -2.0, -3.0]')); + ++-----------------------------------------------------+ +| vec_elem_avg(parse_vec(Utf8("[-1.0, -2.0, -3.0]"))) | ++-----------------------------------------------------+ +| -2.0 | ++-----------------------------------------------------+ + SELECT vec_to_string(vec_div('[1.0, 2.0]', '[3.0, 4.0]')); +---------------------------------------------------------------+ @@ -269,6 +301,21 @@ FROM ( | [4,5,6] | +---------------------------+ +SELECT vec_to_string(vec_avg(v)) +FROM ( + SELECT '[1.0, 2.0, 3.0]' AS v + UNION ALL + SELECT '[10.0, 11.0, 12.0]' AS v + UNION ALL + SELECT '[4.0, 5.0, 6.0]' AS v +); + ++---------------------------+ +| vec_to_string(vec_avg(v)) | ++---------------------------+ +| [5,6,7] | ++---------------------------+ + 
SELECT vec_to_string(vec_product(v)) FROM ( SELECT '[1.0, 2.0, 3.0]' AS v diff --git a/tests/cases/standalone/common/function/vector/vector.sql b/tests/cases/standalone/common/function/vector/vector.sql index c441fc1480..9bbf1583f5 100644 --- a/tests/cases/standalone/common/function/vector/vector.sql +++ b/tests/cases/standalone/common/function/vector/vector.sql @@ -36,6 +36,14 @@ SELECT vec_elem_sum(parse_vec('[1.0, 2.0, 3.0]')); SELECT vec_elem_sum(parse_vec('[-1.0, -2.0, -3.0]')); +SELECT vec_elem_avg('[1.0, 2.0, 3.0]'); + +SELECT vec_elem_avg('[-1.0, -2.0, -3.0]'); + +SELECT vec_elem_avg(parse_vec('[1.0, 2.0, 3.0]')); + +SELECT vec_elem_avg(parse_vec('[-1.0, -2.0, -3.0]')); + SELECT vec_to_string(vec_div('[1.0, 2.0]', '[3.0, 4.0]')); SELECT vec_to_string(vec_div(parse_vec('[1.0, 2.0]'), '[3.0, 4.0]')); @@ -71,6 +79,15 @@ FROM ( SELECT '[4.0, 5.0, 6.0]' AS v ); +SELECT vec_to_string(vec_avg(v)) +FROM ( + SELECT '[1.0, 2.0, 3.0]' AS v + UNION ALL + SELECT '[10.0, 11.0, 12.0]' AS v + UNION ALL + SELECT '[4.0, 5.0, 6.0]' AS v +); + SELECT vec_to_string(vec_product(v)) FROM ( SELECT '[1.0, 2.0, 3.0]' AS v