feat(expr): support avg functions on vector (#7146)

* feat(expr): support vec_elem_avg function

Signed-off-by: Alan Tang <jmtangcs@gmail.com>

* feat: support vec_avg function

Signed-off-by: Alan Tang <jmtangcs@gmail.com>

* test: add more query test for avg aggregator

Signed-off-by: Alan Tang <jmtangcs@gmail.com>

* fix: fix the merge batch mode

Signed-off-by: Alan Tang <jmtangcs@gmail.com>

* refactor: use sum and count as state for avg function

Signed-off-by: Alan Tang <jmtangcs@gmail.com>

* refactor: refactor merge batch mode for avg function

Signed-off-by: Alan Tang <jmtangcs@gmail.com>

* feat: add additional vector restrictions for validation

Signed-off-by: Alan Tang <jmtangcs@gmail.com>

---------

Signed-off-by: Alan Tang <jmtangcs@gmail.com>
Co-authored-by: Yingwen <realevenyag@gmail.com>
This commit is contained in:
Alan Tang
2025-11-07 21:42:14 +08:00
committed by GitHub
parent af6bbacc8c
commit 910a383420
8 changed files with 528 additions and 0 deletions

View File

@@ -12,10 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::aggrs::vector::avg::VectorAvg;
use crate::aggrs::vector::product::VectorProduct;
use crate::aggrs::vector::sum::VectorSum;
use crate::function_registry::FunctionRegistry;
mod avg;
mod product;
mod sum;
@@ -25,5 +27,6 @@ impl VectorFunction {
pub fn register(registry: &FunctionRegistry) {
registry.register_aggr(VectorSum::uadf_impl());
registry.register_aggr(VectorProduct::uadf_impl());
registry.register_aggr(VectorAvg::uadf_impl());
}
}

View File

@@ -0,0 +1,270 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::borrow::Cow;
use std::sync::Arc;
use arrow::array::{Array, ArrayRef, AsArray, BinaryArray, LargeStringArray, StringArray};
use arrow::compute::sum;
use arrow::datatypes::UInt64Type;
use arrow_schema::{DataType, Field};
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::{
Accumulator, AggregateUDF, Signature, SimpleAggregateUDF, TypeSignature, Volatility,
};
use datafusion_functions_aggregate_common::accumulator::AccumulatorArgs;
use nalgebra::{Const, DVector, DVectorView, Dyn, OVector};
use crate::scalars::vector::impl_conv::{
binlit_as_veclit, parse_veclit_from_strlit, veclit_to_binlit,
};
/// The accumulator for the `vec_avg` aggregate function.
#[derive(Debug, Default)]
pub struct VectorAvg {
sum: Option<OVector<f32, Dyn>>,
count: u64,
}
impl VectorAvg {
/// Create a new `AggregateUDF` for the `vec_avg` aggregate function.
pub fn uadf_impl() -> AggregateUDF {
let signature = Signature::one_of(
vec![
TypeSignature::Exact(vec![DataType::Utf8]),
TypeSignature::Exact(vec![DataType::LargeUtf8]),
TypeSignature::Exact(vec![DataType::Binary]),
],
Volatility::Immutable,
);
let udaf = SimpleAggregateUDF::new_with_signature(
"vec_avg",
signature,
DataType::Binary,
Arc::new(Self::accumulator),
vec![
Arc::new(Field::new("sum", DataType::Binary, true)),
Arc::new(Field::new("count", DataType::UInt64, true)),
],
);
AggregateUDF::from(udaf)
}
fn accumulator(args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
if args.schema.fields().len() != 1 {
return Err(datafusion_common::DataFusionError::Internal(format!(
"expect creating `VEC_AVG` with only one input field, actual {}",
args.schema.fields().len()
)));
}
let t = args.schema.field(0).data_type();
if !matches!(t, DataType::Utf8 | DataType::LargeUtf8 | DataType::Binary) {
return Err(datafusion_common::DataFusionError::Internal(format!(
"unexpected input datatype {t} when creating `VEC_AVG`"
)));
}
Ok(Box::new(VectorAvg::default()))
}
fn inner(&mut self, len: usize) -> &mut OVector<f32, Dyn> {
self.sum
.get_or_insert_with(|| OVector::zeros_generic(Dyn(len), Const::<1>))
}
fn update(&mut self, values: &[ArrayRef], is_update: bool) -> Result<()> {
if values.is_empty() {
return Ok(());
};
let vectors = match values[0].data_type() {
DataType::Utf8 => {
let arr: &StringArray = values[0].as_string();
arr.iter()
.filter_map(|x| x.map(|s| parse_veclit_from_strlit(s).map_err(Into::into)))
.map(|x| x.map(Cow::Owned))
.collect::<Result<Vec<_>>>()?
}
DataType::LargeUtf8 => {
let arr: &LargeStringArray = values[0].as_string();
arr.iter()
.filter_map(|x| x.map(|s| parse_veclit_from_strlit(s).map_err(Into::into)))
.map(|x: Result<Vec<f32>>| x.map(Cow::Owned))
.collect::<Result<Vec<_>>>()?
}
DataType::Binary => {
let arr: &BinaryArray = values[0].as_binary();
arr.iter()
.filter_map(|x| x.map(|b| binlit_as_veclit(b).map_err(Into::into)))
.collect::<Result<Vec<_>>>()?
}
_ => {
return Err(datafusion_common::DataFusionError::NotImplemented(format!(
"unsupported data type {} for `VEC_AVG`",
values[0].data_type()
)));
}
};
if vectors.is_empty() {
return Ok(());
}
let len = if is_update {
vectors.len() as u64
} else {
sum(values[1].as_primitive::<UInt64Type>()).unwrap_or_default()
};
let dims = vectors[0].len();
let mut sum = DVector::zeros(dims);
for v in vectors {
if v.len() != dims {
return Err(datafusion_common::DataFusionError::Execution(
"vectors length not match: VEC_AVG".to_string(),
));
}
let v_view = DVectorView::from_slice(&v, dims);
sum += &v_view;
}
*self.inner(dims) += sum;
self.count += len;
Ok(())
}
}
impl Accumulator for VectorAvg {
fn state(&mut self) -> Result<Vec<ScalarValue>> {
let vector = match &self.sum {
None => ScalarValue::Binary(None),
Some(sum) => ScalarValue::Binary(Some(veclit_to_binlit(sum.as_slice()))),
};
Ok(vec![vector, ScalarValue::from(self.count)])
}
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
self.update(values, true)
}
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
self.update(states, false)
}
fn evaluate(&mut self) -> Result<ScalarValue> {
match &self.sum {
None => Ok(ScalarValue::Binary(None)),
Some(sum) => Ok(ScalarValue::Binary(Some(veclit_to_binlit(
(sum / self.count as f32).as_slice(),
)))),
}
}
fn size(&self) -> usize {
size_of_val(self)
}
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use arrow::array::StringArray;
use datatypes::scalars::ScalarVector;
use datatypes::vectors::{ConstantVector, StringVector, Vector};
use super::*;
#[test]
fn test_update_batch() {
// test update empty batch, expect not updating anything
let mut vec_avg = VectorAvg::default();
vec_avg.update_batch(&[]).unwrap();
assert!(vec_avg.sum.is_none());
assert_eq!(ScalarValue::Binary(None), vec_avg.evaluate().unwrap());
// test update one not-null value
let mut vec_avg = VectorAvg::default();
let v: Vec<ArrayRef> = vec![Arc::new(StringArray::from(vec![
Some("[1.0,2.0,3.0]".to_string()),
Some("[4.0,5.0,6.0]".to_string()),
]))];
vec_avg.update_batch(&v).unwrap();
assert_eq!(
ScalarValue::Binary(Some(veclit_to_binlit(&[2.5, 3.5, 4.5]))),
vec_avg.evaluate().unwrap()
);
// test update one null value
let mut vec_avg = VectorAvg::default();
let v: Vec<ArrayRef> = vec![Arc::new(StringArray::from(vec![Option::<String>::None]))];
vec_avg.update_batch(&v).unwrap();
assert_eq!(ScalarValue::Binary(None), vec_avg.evaluate().unwrap());
// test update no null-value batch
let mut vec_avg = VectorAvg::default();
let v: Vec<ArrayRef> = vec![Arc::new(StringArray::from(vec![
Some("[1.0,2.0,3.0]".to_string()),
Some("[4.0,5.0,6.0]".to_string()),
Some("[7.0,8.0,9.0]".to_string()),
]))];
vec_avg.update_batch(&v).unwrap();
assert_eq!(
ScalarValue::Binary(Some(veclit_to_binlit(&[4.0, 5.0, 6.0]))),
vec_avg.evaluate().unwrap()
);
// test update null-value batch
let mut vec_avg = VectorAvg::default();
let v: Vec<ArrayRef> = vec![Arc::new(StringArray::from(vec![
Some("[1.0,2.0,3.0]".to_string()),
None,
Some("[7.0,8.0,9.0]".to_string()),
]))];
vec_avg.update_batch(&v).unwrap();
assert_eq!(
ScalarValue::Binary(Some(veclit_to_binlit(&[4.0, 5.0, 6.0]))),
vec_avg.evaluate().unwrap()
);
let mut vec_avg = VectorAvg::default();
let v: Vec<ArrayRef> = vec![Arc::new(StringArray::from(vec![
None,
Some("[4.0,5.0,6.0]".to_string()),
Some("[7.0,8.0,9.0]".to_string()),
]))];
vec_avg.update_batch(&v).unwrap();
assert_eq!(
ScalarValue::Binary(Some(veclit_to_binlit(&[5.5, 6.5, 7.5]))),
vec_avg.evaluate().unwrap()
);
// test update with constant vector
let mut vec_avg = VectorAvg::default();
let v: Vec<ArrayRef> = vec![
Arc::new(ConstantVector::new(
Arc::new(StringVector::from_vec(vec!["[1.0,2.0,3.0]".to_string()])),
4,
))
.to_arrow_array(),
];
vec_avg.update_batch(&v).unwrap();
assert_eq!(
ScalarValue::Binary(Some(veclit_to_binlit(&[1.0, 2.0, 3.0]))),
vec_avg.evaluate().unwrap()
);
}
}

View File

@@ -14,6 +14,7 @@
mod convert;
mod distance;
mod elem_avg;
mod elem_product;
mod elem_sum;
pub mod impl_conv;
@@ -64,6 +65,7 @@ impl VectorFunction {
registry.register_scalar(vector_subvector::VectorSubvectorFunction::default());
registry.register_scalar(elem_sum::ElemSumFunction::default());
registry.register_scalar(elem_product::ElemProductFunction::default());
registry.register_scalar(elem_avg::ElemAvgFunction::default());
}
}

View File

@@ -0,0 +1,128 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::fmt::Display;
use datafusion::arrow::datatypes::DataType;
use datafusion::logical_expr::ColumnarValue;
use datafusion_common::ScalarValue;
use datafusion_expr::type_coercion::aggregates::{BINARYS, STRINGS};
use datafusion_expr::{ScalarFunctionArgs, Signature, TypeSignature, Volatility};
use nalgebra::DVectorView;
use crate::function::Function;
use crate::scalars::vector::{VectorCalculator, impl_conv};
const NAME: &str = "vec_elem_avg";
#[derive(Debug, Clone)]
pub(crate) struct ElemAvgFunction {
signature: Signature,
}
impl Default for ElemAvgFunction {
fn default() -> Self {
Self {
signature: Signature::one_of(
vec![
TypeSignature::Uniform(1, STRINGS.to_vec()),
TypeSignature::Uniform(1, BINARYS.to_vec()),
TypeSignature::Uniform(1, vec![DataType::BinaryView]),
],
Volatility::Immutable,
),
}
}
}
impl Function for ElemAvgFunction {
fn name(&self) -> &str {
NAME
}
fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
Ok(DataType::Float32)
}
fn signature(&self) -> &Signature {
&self.signature
}
fn invoke_with_args(
&self,
args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
let body = |v0: &ScalarValue| -> datafusion_common::Result<ScalarValue> {
let v0 =
impl_conv::as_veclit(v0)?.map(|v0| DVectorView::from_slice(&v0, v0.len()).mean());
Ok(ScalarValue::Float32(v0))
};
let calculator = VectorCalculator {
name: self.name(),
func: body,
};
calculator.invoke_with_single_argument(args)
}
}
impl Display for ElemAvgFunction {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", NAME.to_ascii_uppercase())
}
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use arrow::array::StringViewArray;
use arrow_schema::Field;
use datafusion::arrow::array::{Array, AsArray};
use datafusion::arrow::datatypes::Float32Type;
use datafusion_common::config::ConfigOptions;
use super::*;
#[test]
fn test_elem_avg() {
let func = ElemAvgFunction::default();
let input = Arc::new(StringViewArray::from(vec![
Some("[1.0,2.0,3.0]".to_string()),
Some("[4.0,5.0,6.0]".to_string()),
Some("[7.0,8.0,9.0]".to_string()),
None,
]));
let result = func
.invoke_with_args(ScalarFunctionArgs {
args: vec![ColumnarValue::Array(input.clone())],
arg_fields: vec![],
number_rows: input.len(),
return_field: Arc::new(Field::new("x", DataType::Float32, true)),
config_options: Arc::new(ConfigOptions::new()),
})
.and_then(|v| ColumnarValue::values_to_arrays(&[v]))
.map(|mut a| a.remove(0))
.unwrap();
let result = result.as_primitive::<Float32Type>();
assert_eq!(result.len(), 4);
assert_eq!(result.value(0), 2.0);
assert_eq!(result.value(1), 5.0);
assert_eq!(result.value(2), 8.0);
assert!(result.is_null(3));
}
}

View File

@@ -26,6 +26,7 @@ mod query_engine_test;
mod time_range_filter_test;
mod function;
mod vec_avg_test;
mod vec_product_test;
mod vec_sum_test;

View File

@@ -0,0 +1,60 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::ops::AddAssign;
use common_function::scalars::vector::impl_conv::{as_veclit, veclit_to_binlit};
use datafusion_common::ScalarValue;
use datatypes::prelude::Value;
use nalgebra::{Const, DVectorView, Dyn, OVector};
use crate::tests::{exec_selection, function};
#[tokio::test]
async fn test_vec_avg_aggregator() -> Result<(), common_query::error::Error> {
common_telemetry::init_default_ut_logging();
let engine = function::create_query_engine_for_vector10x3();
let sql = "select VEC_AVG(vector) as vec_avg from vectors";
let result = exec_selection(engine.clone(), sql).await;
let value = function::get_value_from_batches("vec_avg", result);
let mut expected_value = None;
let sql = "SELECT vector FROM vectors";
let vectors = exec_selection(engine, sql).await;
let column = vectors[0].column(0).to_arrow_array();
let len = column.len();
for i in 0..column.len() {
let v = ScalarValue::try_from_array(&column, i)?;
let vector = as_veclit(&v)?;
let Some(vector) = vector else {
expected_value = None;
break;
};
expected_value
.get_or_insert_with(|| OVector::zeros_generic(Dyn(3), Const::<1>))
.add_assign(&DVectorView::from_slice(&vector, vector.len()));
}
let expected_value = match expected_value.map(|mut v| {
v /= len as f32;
veclit_to_binlit(v.as_slice())
}) {
None => Value::Null,
Some(bytes) => Value::from(bytes),
};
assert_eq!(value, expected_value);
Ok(())
}

View File

@@ -150,6 +150,38 @@ SELECT vec_elem_sum(parse_vec('[-1.0, -2.0, -3.0]'));
| -6.0 |
+-----------------------------------------------------+
SELECT vec_elem_avg('[1.0, 2.0, 3.0]');
+---------------------------------------+
| vec_elem_avg(Utf8("[1.0, 2.0, 3.0]")) |
+---------------------------------------+
| 2.0 |
+---------------------------------------+
SELECT vec_elem_avg('[-1.0, -2.0, -3.0]');
+------------------------------------------+
| vec_elem_avg(Utf8("[-1.0, -2.0, -3.0]")) |
+------------------------------------------+
| -2.0 |
+------------------------------------------+
SELECT vec_elem_avg(parse_vec('[1.0, 2.0, 3.0]'));
+--------------------------------------------------+
| vec_elem_avg(parse_vec(Utf8("[1.0, 2.0, 3.0]"))) |
+--------------------------------------------------+
| 2.0 |
+--------------------------------------------------+
SELECT vec_elem_avg(parse_vec('[-1.0, -2.0, -3.0]'));
+-----------------------------------------------------+
| vec_elem_avg(parse_vec(Utf8("[-1.0, -2.0, -3.0]"))) |
+-----------------------------------------------------+
| -2.0 |
+-----------------------------------------------------+
SELECT vec_to_string(vec_div('[1.0, 2.0]', '[3.0, 4.0]'));
+---------------------------------------------------------------+
@@ -269,6 +301,21 @@ FROM (
| [4,5,6] |
+---------------------------+
SELECT vec_to_string(vec_avg(v))
FROM (
SELECT '[1.0, 2.0, 3.0]' AS v
UNION ALL
SELECT '[10.0, 11.0, 12.0]' AS v
UNION ALL
SELECT '[4.0, 5.0, 6.0]' AS v
);
+---------------------------+
| vec_to_string(vec_avg(v)) |
+---------------------------+
| [5,6,7] |
+---------------------------+
SELECT vec_to_string(vec_product(v))
FROM (
SELECT '[1.0, 2.0, 3.0]' AS v

View File

@@ -36,6 +36,14 @@ SELECT vec_elem_sum(parse_vec('[1.0, 2.0, 3.0]'));
SELECT vec_elem_sum(parse_vec('[-1.0, -2.0, -3.0]'));
SELECT vec_elem_avg('[1.0, 2.0, 3.0]');
SELECT vec_elem_avg('[-1.0, -2.0, -3.0]');
SELECT vec_elem_avg(parse_vec('[1.0, 2.0, 3.0]'));
SELECT vec_elem_avg(parse_vec('[-1.0, -2.0, -3.0]'));
SELECT vec_to_string(vec_div('[1.0, 2.0]', '[3.0, 4.0]'));
SELECT vec_to_string(vec_div(parse_vec('[1.0, 2.0]'), '[3.0, 4.0]'));
@@ -71,6 +79,15 @@ FROM (
SELECT '[4.0, 5.0, 6.0]' AS v
);
SELECT vec_to_string(vec_avg(v))
FROM (
SELECT '[1.0, 2.0, 3.0]' AS v
UNION ALL
SELECT '[10.0, 11.0, 12.0]' AS v
UNION ALL
SELECT '[4.0, 5.0, 6.0]' AS v
);
SELECT vec_to_string(vec_product(v))
FROM (
SELECT '[1.0, 2.0, 3.0]' AS v