diff --git a/Cargo.lock b/Cargo.lock index 9675130378..24c6a95375 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -759,6 +759,7 @@ name = "common-query" version = "0.1.0" dependencies = [ "arrow2", + "common-base", "common-error", "datafusion", "datafusion-common", diff --git a/src/common/query/Cargo.toml b/src/common/query/Cargo.toml index 7623c1a332..3c369923e4 100644 --- a/src/common/query/Cargo.toml +++ b/src/common/query/Cargo.toml @@ -17,3 +17,4 @@ snafu = { version = "0.7", features = ["backtraces"] } [dev-dependencies] tokio = { version = "1.0", features = ["full"] } +common-base = {path = "../base"} diff --git a/src/common/query/src/error.rs b/src/common/query/src/error.rs index fc418546ef..8c17fd4c82 100644 --- a/src/common/query/src/error.rs +++ b/src/common/query/src/error.rs @@ -34,6 +34,12 @@ pub enum InnerError { #[snafu(display("Failed to downcast vector: {}", err_msg))] DowncastVector { err_msg: String }, + + #[snafu(display("Bad accumulator implementation: {}", err_msg))] + BadAccumulatorImpl { + err_msg: String, + backtrace: Backtrace, + }, } pub type Result = std::result::Result; @@ -43,7 +49,8 @@ impl ErrorExt for InnerError { match self { InnerError::ExecuteFunction { .. } | InnerError::CreateAccumulator { .. } - | InnerError::DowncastVector { .. } => StatusCode::EngineExecuteQuery, + | InnerError::DowncastVector { .. } + | InnerError::BadAccumulatorImpl { .. } => StatusCode::EngineExecuteQuery, InnerError::IntoVector { source, .. } => source.status_code(), InnerError::FromScalarValue { source } => source.status_code(), } diff --git a/src/common/query/src/logical_plan/accumulator.rs b/src/common/query/src/logical_plan/accumulator.rs index e7a0eeb307..b1d6f077fb 100644 --- a/src/common/query/src/logical_plan/accumulator.rs +++ b/src/common/query/src/logical_plan/accumulator.rs @@ -7,11 +7,12 @@ use arrow::array::ArrayRef; use datafusion_common::Result as DfResult; use datafusion_expr::Accumulator as DfAccumulator; use datatypes::prelude::*; +use datatypes::value::ListValue; use datatypes::vectors::Helper as VectorHelper; use datatypes::vectors::VectorRef; use snafu::ResultExt; -use crate::error::{Error, FromScalarValueSnafu, IntoVectorSnafu, Result}; +use crate::error::{self, Error, FromScalarValueSnafu, IntoVectorSnafu, Result}; use crate::prelude::*; pub type AggregateFunctionCreatorRef = Arc; @@ -87,22 +88,49 @@ pub fn make_state_function(creator: Arc) -> StateT Arc::new(move |_| Ok(Arc::new(creator.state_types()?))) } -/// A wrapper newtype for our Accumulator to DataFusion's Accumulator, +/// A wrapper type for our Accumulator to DataFusion's Accumulator, /// so to make our Accumulator able to be executed by DataFusion query engine. #[derive(Debug)] -pub struct DfAccumulatorAdaptor(pub Box); +pub struct DfAccumulatorAdaptor { + accumulator: Box, + creator: AggregateFunctionCreatorRef, +} + +impl DfAccumulatorAdaptor { + pub fn new(accumulator: Box, creator: AggregateFunctionCreatorRef) -> Self { + Self { + accumulator, + creator, + } + } +} impl DfAccumulator for DfAccumulatorAdaptor { fn state(&self) -> DfResult> { - let state = self.0.state()?; - Ok(state.into_iter().map(ScalarValue::from).collect()) + let state_values = self.accumulator.state()?; + let state_types = self.creator.state_types()?; + if state_values.len() != state_types.len() { + return error::BadAccumulatorImplSnafu { + err_msg: format!("Accumulator {:?} returned state values size do not match its state types size.", self), + } + .fail() + .map_err(Error::from)?; + } + Ok(state_values + .into_iter() + .zip(state_types.iter()) + .map(|(v, t)| try_into_scalar_value(v, t)) + .collect::>>() + .map_err(Error::from)?) } fn update_batch(&mut self, values: &[ArrayRef]) -> DfResult<()> { let vectors = VectorHelper::try_into_vectors(values) .context(FromScalarValueSnafu) .map_err(Error::from)?; - self.0.update_batch(&vectors).map_err(|e| e.into()) + self.accumulator + .update_batch(&vectors) + .map_err(|e| e.into()) } fn merge_batch(&mut self, states: &[ArrayRef]) -> DfResult<()> { @@ -116,10 +144,287 @@ impl DfAccumulator for DfAccumulatorAdaptor { .map_err(Error::from)?, ); } - self.0.merge_batch(&vectors).map_err(|e| e.into()) + self.accumulator.merge_batch(&vectors).map_err(|e| e.into()) } fn evaluate(&self) -> DfResult { - Ok(ScalarValue::from(self.0.evaluate()?)) + let value = self.accumulator.evaluate()?; + let output_type = self.creator.output_type()?; + Ok(try_into_scalar_value(value, &output_type)?) + } +} + +fn try_into_scalar_value(value: Value, datatype: &ConcreteDataType) -> Result { + if !matches!(value, Value::Null) && datatype != &value.data_type() { + return error::BadAccumulatorImplSnafu { + err_msg: format!( + "expect value to return datatype {:?}, actual: {:?}", + datatype, + value.data_type() + ), + } + .fail()?; + } + + Ok(match value { + Value::Boolean(v) => ScalarValue::Boolean(Some(v)), + Value::UInt8(v) => ScalarValue::UInt8(Some(v)), + Value::UInt16(v) => ScalarValue::UInt16(Some(v)), + Value::UInt32(v) => ScalarValue::UInt32(Some(v)), + Value::UInt64(v) => ScalarValue::UInt64(Some(v)), + Value::Int8(v) => ScalarValue::Int8(Some(v)), + Value::Int16(v) => ScalarValue::Int16(Some(v)), + Value::Int32(v) => ScalarValue::Int32(Some(v)), + Value::Int64(v) => ScalarValue::Int64(Some(v)), + Value::Float32(v) => ScalarValue::Float32(Some(v.0)), + Value::Float64(v) => ScalarValue::Float64(Some(v.0)), + Value::String(v) => ScalarValue::LargeUtf8(Some(v.as_utf8().to_string())), + Value::Binary(v) => ScalarValue::LargeBinary(Some(v.to_vec())), + Value::Date(v) => ScalarValue::Date32(Some(v)), + Value::DateTime(v) => ScalarValue::Date64(Some(v)), + Value::Null => try_convert_null_value(datatype)?, + Value::List(list) => try_convert_list_value(list)?, + }) +} + +fn try_convert_null_value(datatype: &ConcreteDataType) -> Result { + Ok(match datatype { + ConcreteDataType::Boolean(_) => ScalarValue::Boolean(None), + ConcreteDataType::Int8(_) => ScalarValue::Int8(None), + ConcreteDataType::Int16(_) => ScalarValue::Int16(None), + ConcreteDataType::Int32(_) => ScalarValue::Int32(None), + ConcreteDataType::Int64(_) => ScalarValue::Int64(None), + ConcreteDataType::UInt8(_) => ScalarValue::UInt8(None), + ConcreteDataType::UInt16(_) => ScalarValue::UInt16(None), + ConcreteDataType::UInt32(_) => ScalarValue::UInt32(None), + ConcreteDataType::UInt64(_) => ScalarValue::UInt64(None), + ConcreteDataType::Float32(_) => ScalarValue::Float32(None), + ConcreteDataType::Float64(_) => ScalarValue::Float64(None), + ConcreteDataType::Binary(_) => ScalarValue::LargeBinary(None), + ConcreteDataType::String(_) => ScalarValue::LargeUtf8(None), + _ => { + return error::BadAccumulatorImplSnafu { + err_msg: format!( + "undefined transition from null value to datatype {:?}", + datatype + ), + } + .fail()? + } + }) +} + +fn try_convert_list_value(list: ListValue) -> Result { + let vs = if let Some(items) = list.items() { + Some(Box::new( + items + .iter() + .map(|v| try_into_scalar_value(v.clone(), list.datatype())) + .collect::>>()?, + )) + } else { + None + }; + Ok(ScalarValue::List( + vs, + Box::new(list.datatype().as_arrow_type()), + )) +} + +#[cfg(test)] +mod tests { + use arrow::datatypes::DataType; + use common_base::bytes::{Bytes, StringBytes}; + use datafusion_common::ScalarValue; + use datatypes::value::{ListValue, OrderedFloat}; + + use super::*; + + #[test] + fn test_not_null_value_to_scalar_value() { + assert_eq!( + ScalarValue::Boolean(Some(true)), + try_into_scalar_value(Value::Boolean(true), &ConcreteDataType::boolean_datatype()) + .unwrap() + ); + assert_eq!( + ScalarValue::Boolean(Some(false)), + try_into_scalar_value(Value::Boolean(false), &ConcreteDataType::boolean_datatype()) + .unwrap() + ); + assert_eq!( + ScalarValue::UInt8(Some(u8::MIN + 1)), + try_into_scalar_value( + Value::UInt8(u8::MIN + 1), + &ConcreteDataType::uint8_datatype() + ) + .unwrap() + ); + assert_eq!( + ScalarValue::UInt16(Some(u16::MIN + 2)), + try_into_scalar_value( + Value::UInt16(u16::MIN + 2), + &ConcreteDataType::uint16_datatype() + ) + .unwrap() + ); + assert_eq!( + ScalarValue::UInt32(Some(u32::MIN + 3)), + try_into_scalar_value( + Value::UInt32(u32::MIN + 3), + &ConcreteDataType::uint32_datatype() + ) + .unwrap() + ); + assert_eq!( + ScalarValue::UInt64(Some(u64::MIN + 4)), + try_into_scalar_value( + Value::UInt64(u64::MIN + 4), + &ConcreteDataType::uint64_datatype() + ) + .unwrap() + ); + assert_eq!( + ScalarValue::Int8(Some(i8::MIN + 4)), + try_into_scalar_value(Value::Int8(i8::MIN + 4), &ConcreteDataType::int8_datatype()) + .unwrap() + ); + assert_eq!( + ScalarValue::Int16(Some(i16::MIN + 5)), + try_into_scalar_value( + Value::Int16(i16::MIN + 5), + &ConcreteDataType::int16_datatype() + ) + .unwrap() + ); + assert_eq!( + ScalarValue::Int32(Some(i32::MIN + 6)), + try_into_scalar_value( + Value::Int32(i32::MIN + 6), + &ConcreteDataType::int32_datatype() + ) + .unwrap() + ); + assert_eq!( + ScalarValue::Int64(Some(i64::MIN + 7)), + try_into_scalar_value( + Value::Int64(i64::MIN + 7), + &ConcreteDataType::int64_datatype() + ) + .unwrap() + ); + assert_eq!( + ScalarValue::Float32(Some(8.0f32)), + try_into_scalar_value( + Value::Float32(OrderedFloat(8.0f32)), + &ConcreteDataType::float32_datatype() + ) + .unwrap() + ); + assert_eq!( + ScalarValue::Float64(Some(9.0f64)), + try_into_scalar_value( + Value::Float64(OrderedFloat(9.0f64)), + &ConcreteDataType::float64_datatype() + ) + .unwrap() + ); + assert_eq!( + ScalarValue::LargeUtf8(Some("hello".to_string())), + try_into_scalar_value( + Value::String(StringBytes::from("hello")), + &ConcreteDataType::string_datatype() + ) + .unwrap() + ); + assert_eq!( + ScalarValue::LargeBinary(Some("world".as_bytes().to_vec())), + try_into_scalar_value( + Value::Binary(Bytes::from("world".as_bytes())), + &ConcreteDataType::binary_datatype() + ) + .unwrap() + ); + } + + #[test] + fn test_null_value_to_scalar_value() { + assert_eq!( + ScalarValue::Boolean(None), + try_into_scalar_value(Value::Null, &ConcreteDataType::boolean_datatype()).unwrap() + ); + assert_eq!( + ScalarValue::UInt8(None), + try_into_scalar_value(Value::Null, &ConcreteDataType::uint8_datatype()).unwrap() + ); + assert_eq!( + ScalarValue::UInt16(None), + try_into_scalar_value(Value::Null, &ConcreteDataType::uint16_datatype()).unwrap() + ); + assert_eq!( + ScalarValue::UInt32(None), + try_into_scalar_value(Value::Null, &ConcreteDataType::uint32_datatype()).unwrap() + ); + assert_eq!( + ScalarValue::UInt64(None), + try_into_scalar_value(Value::Null, &ConcreteDataType::uint64_datatype()).unwrap() + ); + assert_eq!( + ScalarValue::Int8(None), + try_into_scalar_value(Value::Null, &ConcreteDataType::int8_datatype()).unwrap() + ); + assert_eq!( + ScalarValue::Int16(None), + try_into_scalar_value(Value::Null, &ConcreteDataType::int16_datatype()).unwrap() + ); + assert_eq!( + ScalarValue::Int32(None), + try_into_scalar_value(Value::Null, &ConcreteDataType::int32_datatype()).unwrap() + ); + assert_eq!( + ScalarValue::Int64(None), + try_into_scalar_value(Value::Null, &ConcreteDataType::int64_datatype()).unwrap() + ); + assert_eq!( + ScalarValue::Float32(None), + try_into_scalar_value(Value::Null, &ConcreteDataType::float32_datatype()).unwrap() + ); + assert_eq!( + ScalarValue::Float64(None), + try_into_scalar_value(Value::Null, &ConcreteDataType::float64_datatype()).unwrap() + ); + assert_eq!( + ScalarValue::LargeUtf8(None), + try_into_scalar_value(Value::Null, &ConcreteDataType::string_datatype()).unwrap() + ); + assert_eq!( + ScalarValue::LargeBinary(None), + try_into_scalar_value(Value::Null, &ConcreteDataType::binary_datatype()).unwrap() + ); + } + + #[test] + fn test_list_value_to_scalar_value() { + let items = Some(Box::new(vec![Value::Int32(-1), Value::Null])); + let list = Value::List(ListValue::new(items, ConcreteDataType::int32_datatype())); + let df_list = try_into_scalar_value( + list, + &ConcreteDataType::list_datatype(ConcreteDataType::int32_datatype()), + ) + .unwrap(); + assert!(matches!(df_list, ScalarValue::List(_, _))); + match df_list { + ScalarValue::List(vs, datatype) => { + assert_eq!(*datatype, DataType::Int32); + + assert!(vs.is_some()); + let vs = *vs.unwrap(); + assert_eq!( + vs, + vec![ScalarValue::Int32(Some(-1)), ScalarValue::Int32(None)] + ); + } + _ => unreachable!(), + } } } diff --git a/src/common/query/src/logical_plan/mod.rs b/src/common/query/src/logical_plan/mod.rs index ffc0f5f317..15f9f317f1 100644 --- a/src/common/query/src/logical_plan/mod.rs +++ b/src/common/query/src/logical_plan/mod.rs @@ -49,6 +49,7 @@ pub fn create_aggregate_function( return_type, accumulator, state_type, + creator, ) } diff --git a/src/common/query/src/logical_plan/udaf.rs b/src/common/query/src/logical_plan/udaf.rs index ff629a71a8..a4e8d86743 100644 --- a/src/common/query/src/logical_plan/udaf.rs +++ b/src/common/query/src/logical_plan/udaf.rs @@ -15,6 +15,7 @@ use crate::function::{ to_df_return_type, AccumulatorFunctionImpl, ReturnTypeFunction, StateTypeFunction, }; use crate::logical_plan::accumulator::DfAccumulatorAdaptor; +use crate::logical_plan::AggregateFunctionCreatorRef; use crate::signature::Signature; /// Logical representation of a user-defined aggregate function (UDAF) @@ -31,6 +32,8 @@ pub struct AggregateFunction { pub accumulator: AccumulatorFunctionImpl, /// the accumulator's state's description as a function of the return type pub state_type: StateTypeFunction, + /// the creator that creates aggregate functions + creator: AggregateFunctionCreatorRef, } impl Debug for AggregateFunction { @@ -57,6 +60,7 @@ impl AggregateFunction { return_type: ReturnTypeFunction, accumulator: AccumulatorFunctionImpl, state_type: StateTypeFunction, + creator: AggregateFunctionCreatorRef, ) -> Self { Self { name, @@ -64,6 +68,7 @@ impl AggregateFunction { return_type, accumulator, state_type, + creator, } } } @@ -74,16 +79,20 @@ impl From for DfAggregateUdf { &udaf.name, &udaf.signature.into(), &to_df_return_type(udaf.return_type), - &to_df_accumulator_func(udaf.accumulator), + &to_df_accumulator_func(udaf.accumulator, udaf.creator.clone()), &to_df_state_type(udaf.state_type), ) } } -fn to_df_accumulator_func(func: AccumulatorFunctionImpl) -> DfAccumulatorFunctionImplementation { +fn to_df_accumulator_func( + accumulator: AccumulatorFunctionImpl, + creator: AggregateFunctionCreatorRef, +) -> DfAccumulatorFunctionImplementation { Arc::new(move || { - let acc = func()?; - Ok(Box::new(DfAccumulatorAdaptor(acc))) + let accumulator = accumulator()?; + let creator = creator.clone(); + Ok(Box::new(DfAccumulatorAdaptor::new(accumulator, creator))) }) } diff --git a/src/datatypes/src/value.rs b/src/datatypes/src/value.rs index d40e6cd72b..80248980f9 100644 --- a/src/datatypes/src/value.rs +++ b/src/datatypes/src/value.rs @@ -1,7 +1,6 @@ use std::cmp::Ordering; use common_base::bytes::{Bytes, StringBytes}; -use datafusion_common::ScalarValue; pub use ordered_float::OrderedFloat; use serde::{Deserialize, Serialize}; @@ -64,7 +63,8 @@ impl Value { Value::Float64(_) => ConcreteDataType::float64_datatype(), Value::String(_) => ConcreteDataType::string_datatype(), Value::Binary(_) => ConcreteDataType::binary_datatype(), - Value::Date(_) | Value::DateTime(_) | Value::List(_) => { + Value::List(list) => ConcreteDataType::list_datatype(list.datatype().clone()), + Value::Date(_) | Value::DateTime(_) => { unimplemented!("Unsupported data type of value {:?}", self) } } @@ -160,34 +160,6 @@ impl TryFrom for serde_json::Value { } } -impl From for ScalarValue { - fn from(value: Value) -> Self { - match value { - Value::Boolean(v) => ScalarValue::Boolean(Some(v)), - Value::UInt8(v) => ScalarValue::UInt8(Some(v)), - Value::UInt16(v) => ScalarValue::UInt16(Some(v)), - Value::UInt32(v) => ScalarValue::UInt32(Some(v)), - Value::UInt64(v) => ScalarValue::UInt64(Some(v)), - Value::Int8(v) => ScalarValue::Int8(Some(v)), - Value::Int16(v) => ScalarValue::Int16(Some(v)), - Value::Int32(v) => ScalarValue::Int32(Some(v)), - Value::Int64(v) => ScalarValue::Int64(Some(v)), - Value::Float32(v) => ScalarValue::Float32(Some(v.0)), - Value::Float64(v) => ScalarValue::Float64(Some(v.0)), - Value::String(v) => ScalarValue::LargeUtf8(Some(v.as_utf8().to_string())), - Value::Binary(v) => ScalarValue::LargeBinary(Some(v.to_vec())), - Value::Date(v) => ScalarValue::Date32(Some(v)), - Value::DateTime(v) => ScalarValue::Date64(Some(v)), - Value::Null => ScalarValue::Boolean(None), - Value::List(v) => ScalarValue::List( - v.items - .map(|vs| Box::new(vs.into_iter().map(ScalarValue::from).collect())), - Box::new(v.datatype.as_arrow_type()), - ), - } - } -} - #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct ListValue { /// List of nested Values (boxed to reduce size_of(Value)) @@ -204,6 +176,14 @@ impl ListValue { pub fn new(items: Option>>, datatype: ConcreteDataType) -> Self { Self { items, datatype } } + + pub fn items(&self) -> &Option>> { + &self.items + } + + pub fn datatype(&self) -> &ConcreteDataType { + &self.datatype + } } impl PartialOrd for ListValue { @@ -367,78 +347,6 @@ mod tests { assert_eq!(Value::Binary(Bytes::from(world)), Value::from(world)); } - #[test] - fn test_value_into_scalar_value() { - assert_eq!( - ScalarValue::Boolean(Some(true)), - Value::Boolean(true).into() - ); - assert_eq!( - ScalarValue::Boolean(Some(true)), - Value::Boolean(true).into() - ); - - assert_eq!( - ScalarValue::UInt8(Some(u8::MIN + 1)), - Value::UInt8(u8::MIN + 1).into() - ); - assert_eq!( - ScalarValue::UInt16(Some(u16::MIN + 2)), - Value::UInt16(u16::MIN + 2).into() - ); - assert_eq!( - ScalarValue::UInt32(Some(u32::MIN + 3)), - Value::UInt32(u32::MIN + 3).into() - ); - assert_eq!( - ScalarValue::UInt64(Some(u64::MIN + 4)), - Value::UInt64(u64::MIN + 4).into() - ); - - assert_eq!( - ScalarValue::Int8(Some(i8::MIN + 4)), - Value::Int8(i8::MIN + 4).into() - ); - assert_eq!( - ScalarValue::Int16(Some(i16::MIN + 5)), - Value::Int16(i16::MIN + 5).into() - ); - assert_eq!( - ScalarValue::Int32(Some(i32::MIN + 6)), - Value::Int32(i32::MIN + 6).into() - ); - assert_eq!( - ScalarValue::Int64(Some(i64::MIN + 7)), - Value::Int64(i64::MIN + 7).into() - ); - - assert_eq!( - ScalarValue::Float32(Some(8.0f32)), - Value::Float32(OrderedFloat(8.0f32)).into() - ); - assert_eq!( - ScalarValue::Float64(Some(9.0f64)), - Value::Float64(OrderedFloat(9.0f64)).into() - ); - - assert_eq!( - ScalarValue::LargeUtf8(Some("hello".to_string())), - Value::String(StringBytes::from("hello")).into() - ); - assert_eq!( - ScalarValue::LargeBinary(Some("world".as_bytes().to_vec())), - Value::Binary(Bytes::from("world".as_bytes())).into() - ); - - assert_eq!(ScalarValue::Date32(Some(10i32)), Value::Date(10i32).into()); - assert_eq!( - ScalarValue::Date64(Some(20i64)), - Value::DateTime(20i64).into() - ); - - assert_eq!(ScalarValue::Boolean(None), Value::Null.into()); - } - fn to_json(value: Value) -> serde_json::Value { value.try_into().unwrap() }