fix: correctly convert Value::Null to ScalarValue (#187)

* fix: correctly convert Value::Null to ScalarValue

* address PR comments

* refactor: make code robust

Co-authored-by: luofucong <luofucong@greptime.com>
This commit is contained in:
LFC
2022-08-19 10:37:30 +08:00
committed by GitHub
parent 5c9b46fbf8
commit 9a68e4ca88
7 changed files with 347 additions and 115 deletions

1
Cargo.lock generated
View File

@@ -759,6 +759,7 @@ name = "common-query"
version = "0.1.0"
dependencies = [
"arrow2",
"common-base",
"common-error",
"datafusion",
"datafusion-common",

View File

@@ -17,3 +17,4 @@ snafu = { version = "0.7", features = ["backtraces"] }
[dev-dependencies]
tokio = { version = "1.0", features = ["full"] }
common-base = {path = "../base"}

View File

@@ -34,6 +34,12 @@ pub enum InnerError {
#[snafu(display("Failed to downcast vector: {}", err_msg))]
DowncastVector { err_msg: String },
#[snafu(display("Bad accumulator implementation: {}", err_msg))]
BadAccumulatorImpl {
err_msg: String,
backtrace: Backtrace,
},
}
pub type Result<T> = std::result::Result<T, Error>;
@@ -43,7 +49,8 @@ impl ErrorExt for InnerError {
match self {
InnerError::ExecuteFunction { .. }
| InnerError::CreateAccumulator { .. }
| InnerError::DowncastVector { .. } => StatusCode::EngineExecuteQuery,
| InnerError::DowncastVector { .. }
| InnerError::BadAccumulatorImpl { .. } => StatusCode::EngineExecuteQuery,
InnerError::IntoVector { source, .. } => source.status_code(),
InnerError::FromScalarValue { source } => source.status_code(),
}

View File

@@ -7,11 +7,12 @@ use arrow::array::ArrayRef;
use datafusion_common::Result as DfResult;
use datafusion_expr::Accumulator as DfAccumulator;
use datatypes::prelude::*;
use datatypes::value::ListValue;
use datatypes::vectors::Helper as VectorHelper;
use datatypes::vectors::VectorRef;
use snafu::ResultExt;
use crate::error::{Error, FromScalarValueSnafu, IntoVectorSnafu, Result};
use crate::error::{self, Error, FromScalarValueSnafu, IntoVectorSnafu, Result};
use crate::prelude::*;
pub type AggregateFunctionCreatorRef = Arc<dyn AggregateFunctionCreator>;
@@ -87,22 +88,49 @@ pub fn make_state_function(creator: Arc<dyn AggregateFunctionCreator>) -> StateT
Arc::new(move |_| Ok(Arc::new(creator.state_types()?)))
}
/// A wrapper newtype for our Accumulator to DataFusion's Accumulator,
/// A wrapper type for our Accumulator to DataFusion's Accumulator,
/// so to make our Accumulator able to be executed by DataFusion query engine.
#[derive(Debug)]
pub struct DfAccumulatorAdaptor(pub Box<dyn Accumulator>);
pub struct DfAccumulatorAdaptor {
accumulator: Box<dyn Accumulator>,
creator: AggregateFunctionCreatorRef,
}
impl DfAccumulatorAdaptor {
pub fn new(accumulator: Box<dyn Accumulator>, creator: AggregateFunctionCreatorRef) -> Self {
Self {
accumulator,
creator,
}
}
}
impl DfAccumulator for DfAccumulatorAdaptor {
fn state(&self) -> DfResult<Vec<ScalarValue>> {
let state = self.0.state()?;
Ok(state.into_iter().map(ScalarValue::from).collect())
let state_values = self.accumulator.state()?;
let state_types = self.creator.state_types()?;
if state_values.len() != state_types.len() {
return error::BadAccumulatorImplSnafu {
err_msg: format!("Accumulator {:?} returned state values size do not match its state types size.", self),
}
.fail()
.map_err(Error::from)?;
}
Ok(state_values
.into_iter()
.zip(state_types.iter())
.map(|(v, t)| try_into_scalar_value(v, t))
.collect::<Result<Vec<_>>>()
.map_err(Error::from)?)
}
fn update_batch(&mut self, values: &[ArrayRef]) -> DfResult<()> {
let vectors = VectorHelper::try_into_vectors(values)
.context(FromScalarValueSnafu)
.map_err(Error::from)?;
self.0.update_batch(&vectors).map_err(|e| e.into())
self.accumulator
.update_batch(&vectors)
.map_err(|e| e.into())
}
fn merge_batch(&mut self, states: &[ArrayRef]) -> DfResult<()> {
@@ -116,10 +144,287 @@ impl DfAccumulator for DfAccumulatorAdaptor {
.map_err(Error::from)?,
);
}
self.0.merge_batch(&vectors).map_err(|e| e.into())
self.accumulator.merge_batch(&vectors).map_err(|e| e.into())
}
fn evaluate(&self) -> DfResult<ScalarValue> {
Ok(ScalarValue::from(self.0.evaluate()?))
let value = self.accumulator.evaluate()?;
let output_type = self.creator.output_type()?;
Ok(try_into_scalar_value(value, &output_type)?)
}
}
fn try_into_scalar_value(value: Value, datatype: &ConcreteDataType) -> Result<ScalarValue> {
if !matches!(value, Value::Null) && datatype != &value.data_type() {
return error::BadAccumulatorImplSnafu {
err_msg: format!(
"expect value to return datatype {:?}, actual: {:?}",
datatype,
value.data_type()
),
}
.fail()?;
}
Ok(match value {
Value::Boolean(v) => ScalarValue::Boolean(Some(v)),
Value::UInt8(v) => ScalarValue::UInt8(Some(v)),
Value::UInt16(v) => ScalarValue::UInt16(Some(v)),
Value::UInt32(v) => ScalarValue::UInt32(Some(v)),
Value::UInt64(v) => ScalarValue::UInt64(Some(v)),
Value::Int8(v) => ScalarValue::Int8(Some(v)),
Value::Int16(v) => ScalarValue::Int16(Some(v)),
Value::Int32(v) => ScalarValue::Int32(Some(v)),
Value::Int64(v) => ScalarValue::Int64(Some(v)),
Value::Float32(v) => ScalarValue::Float32(Some(v.0)),
Value::Float64(v) => ScalarValue::Float64(Some(v.0)),
Value::String(v) => ScalarValue::LargeUtf8(Some(v.as_utf8().to_string())),
Value::Binary(v) => ScalarValue::LargeBinary(Some(v.to_vec())),
Value::Date(v) => ScalarValue::Date32(Some(v)),
Value::DateTime(v) => ScalarValue::Date64(Some(v)),
Value::Null => try_convert_null_value(datatype)?,
Value::List(list) => try_convert_list_value(list)?,
})
}
fn try_convert_null_value(datatype: &ConcreteDataType) -> Result<ScalarValue> {
Ok(match datatype {
ConcreteDataType::Boolean(_) => ScalarValue::Boolean(None),
ConcreteDataType::Int8(_) => ScalarValue::Int8(None),
ConcreteDataType::Int16(_) => ScalarValue::Int16(None),
ConcreteDataType::Int32(_) => ScalarValue::Int32(None),
ConcreteDataType::Int64(_) => ScalarValue::Int64(None),
ConcreteDataType::UInt8(_) => ScalarValue::UInt8(None),
ConcreteDataType::UInt16(_) => ScalarValue::UInt16(None),
ConcreteDataType::UInt32(_) => ScalarValue::UInt32(None),
ConcreteDataType::UInt64(_) => ScalarValue::UInt64(None),
ConcreteDataType::Float32(_) => ScalarValue::Float32(None),
ConcreteDataType::Float64(_) => ScalarValue::Float64(None),
ConcreteDataType::Binary(_) => ScalarValue::LargeBinary(None),
ConcreteDataType::String(_) => ScalarValue::LargeUtf8(None),
_ => {
return error::BadAccumulatorImplSnafu {
err_msg: format!(
"undefined transition from null value to datatype {:?}",
datatype
),
}
.fail()?
}
})
}
fn try_convert_list_value(list: ListValue) -> Result<ScalarValue> {
let vs = if let Some(items) = list.items() {
Some(Box::new(
items
.iter()
.map(|v| try_into_scalar_value(v.clone(), list.datatype()))
.collect::<Result<Vec<_>>>()?,
))
} else {
None
};
Ok(ScalarValue::List(
vs,
Box::new(list.datatype().as_arrow_type()),
))
}
#[cfg(test)]
mod tests {
use arrow::datatypes::DataType;
use common_base::bytes::{Bytes, StringBytes};
use datafusion_common::ScalarValue;
use datatypes::value::{ListValue, OrderedFloat};
use super::*;
#[test]
fn test_not_null_value_to_scalar_value() {
assert_eq!(
ScalarValue::Boolean(Some(true)),
try_into_scalar_value(Value::Boolean(true), &ConcreteDataType::boolean_datatype())
.unwrap()
);
assert_eq!(
ScalarValue::Boolean(Some(false)),
try_into_scalar_value(Value::Boolean(false), &ConcreteDataType::boolean_datatype())
.unwrap()
);
assert_eq!(
ScalarValue::UInt8(Some(u8::MIN + 1)),
try_into_scalar_value(
Value::UInt8(u8::MIN + 1),
&ConcreteDataType::uint8_datatype()
)
.unwrap()
);
assert_eq!(
ScalarValue::UInt16(Some(u16::MIN + 2)),
try_into_scalar_value(
Value::UInt16(u16::MIN + 2),
&ConcreteDataType::uint16_datatype()
)
.unwrap()
);
assert_eq!(
ScalarValue::UInt32(Some(u32::MIN + 3)),
try_into_scalar_value(
Value::UInt32(u32::MIN + 3),
&ConcreteDataType::uint32_datatype()
)
.unwrap()
);
assert_eq!(
ScalarValue::UInt64(Some(u64::MIN + 4)),
try_into_scalar_value(
Value::UInt64(u64::MIN + 4),
&ConcreteDataType::uint64_datatype()
)
.unwrap()
);
assert_eq!(
ScalarValue::Int8(Some(i8::MIN + 4)),
try_into_scalar_value(Value::Int8(i8::MIN + 4), &ConcreteDataType::int8_datatype())
.unwrap()
);
assert_eq!(
ScalarValue::Int16(Some(i16::MIN + 5)),
try_into_scalar_value(
Value::Int16(i16::MIN + 5),
&ConcreteDataType::int16_datatype()
)
.unwrap()
);
assert_eq!(
ScalarValue::Int32(Some(i32::MIN + 6)),
try_into_scalar_value(
Value::Int32(i32::MIN + 6),
&ConcreteDataType::int32_datatype()
)
.unwrap()
);
assert_eq!(
ScalarValue::Int64(Some(i64::MIN + 7)),
try_into_scalar_value(
Value::Int64(i64::MIN + 7),
&ConcreteDataType::int64_datatype()
)
.unwrap()
);
assert_eq!(
ScalarValue::Float32(Some(8.0f32)),
try_into_scalar_value(
Value::Float32(OrderedFloat(8.0f32)),
&ConcreteDataType::float32_datatype()
)
.unwrap()
);
assert_eq!(
ScalarValue::Float64(Some(9.0f64)),
try_into_scalar_value(
Value::Float64(OrderedFloat(9.0f64)),
&ConcreteDataType::float64_datatype()
)
.unwrap()
);
assert_eq!(
ScalarValue::LargeUtf8(Some("hello".to_string())),
try_into_scalar_value(
Value::String(StringBytes::from("hello")),
&ConcreteDataType::string_datatype()
)
.unwrap()
);
assert_eq!(
ScalarValue::LargeBinary(Some("world".as_bytes().to_vec())),
try_into_scalar_value(
Value::Binary(Bytes::from("world".as_bytes())),
&ConcreteDataType::binary_datatype()
)
.unwrap()
);
}
#[test]
fn test_null_value_to_scalar_value() {
assert_eq!(
ScalarValue::Boolean(None),
try_into_scalar_value(Value::Null, &ConcreteDataType::boolean_datatype()).unwrap()
);
assert_eq!(
ScalarValue::UInt8(None),
try_into_scalar_value(Value::Null, &ConcreteDataType::uint8_datatype()).unwrap()
);
assert_eq!(
ScalarValue::UInt16(None),
try_into_scalar_value(Value::Null, &ConcreteDataType::uint16_datatype()).unwrap()
);
assert_eq!(
ScalarValue::UInt32(None),
try_into_scalar_value(Value::Null, &ConcreteDataType::uint32_datatype()).unwrap()
);
assert_eq!(
ScalarValue::UInt64(None),
try_into_scalar_value(Value::Null, &ConcreteDataType::uint64_datatype()).unwrap()
);
assert_eq!(
ScalarValue::Int8(None),
try_into_scalar_value(Value::Null, &ConcreteDataType::int8_datatype()).unwrap()
);
assert_eq!(
ScalarValue::Int16(None),
try_into_scalar_value(Value::Null, &ConcreteDataType::int16_datatype()).unwrap()
);
assert_eq!(
ScalarValue::Int32(None),
try_into_scalar_value(Value::Null, &ConcreteDataType::int32_datatype()).unwrap()
);
assert_eq!(
ScalarValue::Int64(None),
try_into_scalar_value(Value::Null, &ConcreteDataType::int64_datatype()).unwrap()
);
assert_eq!(
ScalarValue::Float32(None),
try_into_scalar_value(Value::Null, &ConcreteDataType::float32_datatype()).unwrap()
);
assert_eq!(
ScalarValue::Float64(None),
try_into_scalar_value(Value::Null, &ConcreteDataType::float64_datatype()).unwrap()
);
assert_eq!(
ScalarValue::LargeUtf8(None),
try_into_scalar_value(Value::Null, &ConcreteDataType::string_datatype()).unwrap()
);
assert_eq!(
ScalarValue::LargeBinary(None),
try_into_scalar_value(Value::Null, &ConcreteDataType::binary_datatype()).unwrap()
);
}
#[test]
fn test_list_value_to_scalar_value() {
let items = Some(Box::new(vec![Value::Int32(-1), Value::Null]));
let list = Value::List(ListValue::new(items, ConcreteDataType::int32_datatype()));
let df_list = try_into_scalar_value(
list,
&ConcreteDataType::list_datatype(ConcreteDataType::int32_datatype()),
)
.unwrap();
assert!(matches!(df_list, ScalarValue::List(_, _)));
match df_list {
ScalarValue::List(vs, datatype) => {
assert_eq!(*datatype, DataType::Int32);
assert!(vs.is_some());
let vs = *vs.unwrap();
assert_eq!(
vs,
vec![ScalarValue::Int32(Some(-1)), ScalarValue::Int32(None)]
);
}
_ => unreachable!(),
}
}
}

View File

@@ -49,6 +49,7 @@ pub fn create_aggregate_function(
return_type,
accumulator,
state_type,
creator,
)
}

View File

@@ -15,6 +15,7 @@ use crate::function::{
to_df_return_type, AccumulatorFunctionImpl, ReturnTypeFunction, StateTypeFunction,
};
use crate::logical_plan::accumulator::DfAccumulatorAdaptor;
use crate::logical_plan::AggregateFunctionCreatorRef;
use crate::signature::Signature;
/// Logical representation of a user-defined aggregate function (UDAF)
@@ -31,6 +32,8 @@ pub struct AggregateFunction {
pub accumulator: AccumulatorFunctionImpl,
/// the accumulator's state's description as a function of the return type
pub state_type: StateTypeFunction,
/// the creator that creates aggregate functions
creator: AggregateFunctionCreatorRef,
}
impl Debug for AggregateFunction {
@@ -57,6 +60,7 @@ impl AggregateFunction {
return_type: ReturnTypeFunction,
accumulator: AccumulatorFunctionImpl,
state_type: StateTypeFunction,
creator: AggregateFunctionCreatorRef,
) -> Self {
Self {
name,
@@ -64,6 +68,7 @@ impl AggregateFunction {
return_type,
accumulator,
state_type,
creator,
}
}
}
@@ -74,16 +79,20 @@ impl From<AggregateFunction> for DfAggregateUdf {
&udaf.name,
&udaf.signature.into(),
&to_df_return_type(udaf.return_type),
&to_df_accumulator_func(udaf.accumulator),
&to_df_accumulator_func(udaf.accumulator, udaf.creator.clone()),
&to_df_state_type(udaf.state_type),
)
}
}
fn to_df_accumulator_func(func: AccumulatorFunctionImpl) -> DfAccumulatorFunctionImplementation {
fn to_df_accumulator_func(
accumulator: AccumulatorFunctionImpl,
creator: AggregateFunctionCreatorRef,
) -> DfAccumulatorFunctionImplementation {
Arc::new(move || {
let acc = func()?;
Ok(Box::new(DfAccumulatorAdaptor(acc)))
let accumulator = accumulator()?;
let creator = creator.clone();
Ok(Box::new(DfAccumulatorAdaptor::new(accumulator, creator)))
})
}

View File

@@ -1,7 +1,6 @@
use std::cmp::Ordering;
use common_base::bytes::{Bytes, StringBytes};
use datafusion_common::ScalarValue;
pub use ordered_float::OrderedFloat;
use serde::{Deserialize, Serialize};
@@ -64,7 +63,8 @@ impl Value {
Value::Float64(_) => ConcreteDataType::float64_datatype(),
Value::String(_) => ConcreteDataType::string_datatype(),
Value::Binary(_) => ConcreteDataType::binary_datatype(),
Value::Date(_) | Value::DateTime(_) | Value::List(_) => {
Value::List(list) => ConcreteDataType::list_datatype(list.datatype().clone()),
Value::Date(_) | Value::DateTime(_) => {
unimplemented!("Unsupported data type of value {:?}", self)
}
}
@@ -160,34 +160,6 @@ impl TryFrom<Value> for serde_json::Value {
}
}
impl From<Value> for ScalarValue {
fn from(value: Value) -> Self {
match value {
Value::Boolean(v) => ScalarValue::Boolean(Some(v)),
Value::UInt8(v) => ScalarValue::UInt8(Some(v)),
Value::UInt16(v) => ScalarValue::UInt16(Some(v)),
Value::UInt32(v) => ScalarValue::UInt32(Some(v)),
Value::UInt64(v) => ScalarValue::UInt64(Some(v)),
Value::Int8(v) => ScalarValue::Int8(Some(v)),
Value::Int16(v) => ScalarValue::Int16(Some(v)),
Value::Int32(v) => ScalarValue::Int32(Some(v)),
Value::Int64(v) => ScalarValue::Int64(Some(v)),
Value::Float32(v) => ScalarValue::Float32(Some(v.0)),
Value::Float64(v) => ScalarValue::Float64(Some(v.0)),
Value::String(v) => ScalarValue::LargeUtf8(Some(v.as_utf8().to_string())),
Value::Binary(v) => ScalarValue::LargeBinary(Some(v.to_vec())),
Value::Date(v) => ScalarValue::Date32(Some(v)),
Value::DateTime(v) => ScalarValue::Date64(Some(v)),
Value::Null => ScalarValue::Boolean(None),
Value::List(v) => ScalarValue::List(
v.items
.map(|vs| Box::new(vs.into_iter().map(ScalarValue::from).collect())),
Box::new(v.datatype.as_arrow_type()),
),
}
}
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ListValue {
/// List of nested Values (boxed to reduce size_of(Value))
@@ -204,6 +176,14 @@ impl ListValue {
pub fn new(items: Option<Box<Vec<Value>>>, datatype: ConcreteDataType) -> Self {
Self { items, datatype }
}
pub fn items(&self) -> &Option<Box<Vec<Value>>> {
&self.items
}
pub fn datatype(&self) -> &ConcreteDataType {
&self.datatype
}
}
impl PartialOrd for ListValue {
@@ -367,78 +347,6 @@ mod tests {
assert_eq!(Value::Binary(Bytes::from(world)), Value::from(world));
}
#[test]
fn test_value_into_scalar_value() {
assert_eq!(
ScalarValue::Boolean(Some(true)),
Value::Boolean(true).into()
);
assert_eq!(
ScalarValue::Boolean(Some(true)),
Value::Boolean(true).into()
);
assert_eq!(
ScalarValue::UInt8(Some(u8::MIN + 1)),
Value::UInt8(u8::MIN + 1).into()
);
assert_eq!(
ScalarValue::UInt16(Some(u16::MIN + 2)),
Value::UInt16(u16::MIN + 2).into()
);
assert_eq!(
ScalarValue::UInt32(Some(u32::MIN + 3)),
Value::UInt32(u32::MIN + 3).into()
);
assert_eq!(
ScalarValue::UInt64(Some(u64::MIN + 4)),
Value::UInt64(u64::MIN + 4).into()
);
assert_eq!(
ScalarValue::Int8(Some(i8::MIN + 4)),
Value::Int8(i8::MIN + 4).into()
);
assert_eq!(
ScalarValue::Int16(Some(i16::MIN + 5)),
Value::Int16(i16::MIN + 5).into()
);
assert_eq!(
ScalarValue::Int32(Some(i32::MIN + 6)),
Value::Int32(i32::MIN + 6).into()
);
assert_eq!(
ScalarValue::Int64(Some(i64::MIN + 7)),
Value::Int64(i64::MIN + 7).into()
);
assert_eq!(
ScalarValue::Float32(Some(8.0f32)),
Value::Float32(OrderedFloat(8.0f32)).into()
);
assert_eq!(
ScalarValue::Float64(Some(9.0f64)),
Value::Float64(OrderedFloat(9.0f64)).into()
);
assert_eq!(
ScalarValue::LargeUtf8(Some("hello".to_string())),
Value::String(StringBytes::from("hello")).into()
);
assert_eq!(
ScalarValue::LargeBinary(Some("world".as_bytes().to_vec())),
Value::Binary(Bytes::from("world".as_bytes())).into()
);
assert_eq!(ScalarValue::Date32(Some(10i32)), Value::Date(10i32).into());
assert_eq!(
ScalarValue::Date64(Some(20i64)),
Value::DateTime(20i64).into()
);
assert_eq!(ScalarValue::Boolean(None), Value::Null.into());
}
fn to_json(value: Value) -> serde_json::Value {
value.try_into().unwrap()
}