diff --git a/Cargo.lock b/Cargo.lock index 158275b73d..d6f13ecf65 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -455,6 +455,28 @@ dependencies = [ "winapi", ] +[[package]] +name = "chrono-tz" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "58549f1842da3080ce63002102d5bc954c7bc843d4f47818e642abdc36253552" +dependencies = [ + "chrono", + "chrono-tz-build", + "phf", +] + +[[package]] +name = "chrono-tz-build" +version = "0.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db058d493fb2f65f41861bfed7e3fe6335264a9f0f92710cab5bdf01fef09069" +dependencies = [ + "parse-zoneinfo", + "phf", + "phf_codegen", +] + [[package]] name = "clap" version = "3.1.17" @@ -518,6 +540,9 @@ dependencies = [ [[package]] name = "common-base" version = "0.1.0" +dependencies = [ + "serde", +] [[package]] name = "common-error" @@ -526,6 +551,21 @@ dependencies = [ "snafu", ] +[[package]] +name = "common-function" +version = "0.1.0" +dependencies = [ + "chrono-tz", + "common-error", + "common-query", + "datatypes", + "num", + "num-traits", + "once_cell", + "paste", + "snafu", +] + [[package]] name = "common-query" version = "0.1.0" @@ -876,6 +916,7 @@ dependencies = [ "arrow2", "common-base", "common-error", + "datafusion-common", "enum_dispatch", "paste", "serde", @@ -1813,6 +1854,20 @@ dependencies = [ "minimal-lexical", ] +[[package]] +name = "num" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43db66d1170d347f9a065114077f7dccb00c1b9478c89384490a3425279a4606" +dependencies = [ + "num-bigint", + "num-complex", + "num-integer", + "num-iter", + "num-rational", + "num-traits", +] + [[package]] name = "num-bigint" version = "0.4.3" @@ -1824,6 +1879,15 @@ dependencies = [ "num-traits", ] +[[package]] +name = "num-complex" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97fbc387afefefd5e9e39493299f3069e14a140dd34dc19b4c1c1a8fddb6a790" +dependencies = [ + "num-traits", +] + [[package]] name = "num-integer" version = "0.1.45" @@ -1834,6 +1898,29 @@ dependencies = [ "num-traits", ] +[[package]] +name = "num-iter" +version = "0.1.43" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7d03e6c028c5dc5cac6e2dec0efda81fc887605bb3d884578bb6d6bf7514e252" +dependencies = [ + "autocfg", + "num-integer", + "num-traits", +] + +[[package]] +name = "num-rational" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d41702bd167c2df5520b384281bc111a4b5efcf7fbc4c9c222c815b07e0a6a6a" +dependencies = [ + "autocfg", + "num-bigint", + "num-integer", + "num-traits", +] + [[package]] name = "num-traits" version = "0.2.15" @@ -2127,6 +2214,15 @@ dependencies = [ "zstd", ] +[[package]] +name = "parse-zoneinfo" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c705f256449c60da65e11ff6626e0c16a0a0b96aaa348de61376b249bc340f41" +dependencies = [ + "regex", +] + [[package]] name = "paste" version = "1.0.7" @@ -2158,6 +2254,45 @@ dependencies = [ "indexmap", ] +[[package]] +name = "phf" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fabbf1ead8a5bcbc20f5f8b939ee3f5b0f6f281b6ad3468b84656b658b455259" +dependencies = [ + "phf_shared", +] + +[[package]] +name = "phf_codegen" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"4fb1c3a8bc4dd4e5cfce29b44ffc14bedd2ee294559a294e2a4d4c9e9a6a13cd" +dependencies = [ + "phf_generator", + "phf_shared", +] + +[[package]] +name = "phf_generator" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d5285893bb5eb82e6aaf5d59ee909a06a16737a8970984dd7746ba9283498d6" +dependencies = [ + "phf_shared", + "rand 0.8.5", +] + +[[package]] +name = "phf_shared" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6796ad771acdc0123d2a88dc428b5e38ef24456743ddb1744ed628f9815c096" +dependencies = [ + "siphasher", + "uncased", +] + [[package]] name = "pin-project" version = "1.0.10" @@ -2300,6 +2435,7 @@ dependencies = [ "arrow2", "async-trait", "common-error", + "common-function", "common-query", "common-recordbatch", "common-telemetry", @@ -2730,6 +2866,12 @@ dependencies = [ "time 0.3.9", ] +[[package]] +name = "siphasher" +version = "0.3.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7bd3e3206899af3f8b12af284fafc038cc1dc2b41d1b89dd17297221c5d225de" + [[package]] name = "sketches-ddsketch" version = "0.1.2" @@ -3375,6 +3517,15 @@ version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dcf81ac59edc17cc8697ff311e8f5ef2d99fcbd9817b34cec66f90b6c3dfd987" +[[package]] +name = "uncased" +version = "0.9.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09b01702b0fd0b3fadcf98e098780badda8742d4f4a7676615cad90e8ac73622" +dependencies = [ + "version_check", +] + [[package]] name = "unicase" version = "2.6.0" diff --git a/Cargo.toml b/Cargo.toml index c031f69862..67408e9658 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,6 +2,7 @@ members = [ "src/common/base", "src/common/error", + "src/common/function", "src/common/telemetry", "src/common/query", "src/common/recordbatch", diff --git a/src/common/base/Cargo.toml b/src/common/base/Cargo.toml index afcd0e6983..087988229e 100644 --- a/src/common/base/Cargo.toml +++ b/src/common/base/Cargo.toml @@ -3,6 +3,5 @@ name = "common-base" version = "0.1.0" edition = "2021" -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - [dependencies] +serde = { version = "1.0", features = ["derive"] } diff --git a/src/common/base/src/bytes.rs b/src/common/base/src/bytes.rs index c0d299023d..5a934fad64 100644 --- a/src/common/base/src/bytes.rs +++ b/src/common/base/src/bytes.rs @@ -1,7 +1,10 @@ +use serde::Serialize; /// Bytes buffer. -#[derive(Debug, Default, Clone, PartialEq)] -pub struct Bytes(Vec); +#[derive(Debug, Default, Clone, PartialEq, Serialize)] +//TODO: impl From and Deref to remove pub declaration +pub struct Bytes(pub Vec); /// String buffer with arbitrary encoding. 
-#[derive(Debug, Default, Clone, PartialEq)] -pub struct StringBytes(Vec); +#[derive(Debug, Default, Clone, PartialEq, Serialize)] +//TODO: impl From and Deref to remove pub declaration +pub struct StringBytes(pub Vec); diff --git a/src/common/function/Cargo.toml b/src/common/function/Cargo.toml new file mode 100644 index 0000000000..1f83deed1f --- /dev/null +++ b/src/common/function/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "common-function" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +chrono-tz = "0.6" +common-error = { path = "../error" } +common-query = { path = "../query" } +datatypes = { path = "../../datatypes" } +num = "0.4.0" +num-traits = "0.2.14" +once_cell = "1.10" +paste = "1.0" +snafu = { version = "0.7", features = ["backtraces"] } \ No newline at end of file diff --git a/src/common/function/src/error.rs b/src/common/function/src/error.rs new file mode 100644 index 0000000000..854526e750 --- /dev/null +++ b/src/common/function/src/error.rs @@ -0,0 +1,69 @@ +use std::any::Any; + +use common_error::prelude::*; +use common_query::error::Error as QueryError; +use datatypes::error::Error as DataTypeError; +use snafu::GenerateImplicitData; + +common_error::define_opaque_error!(Error); + +pub type Result = std::result::Result; + +#[derive(Debug, Snafu)] +#[snafu(visibility(pub))] +pub enum InnerError { + #[snafu(display("Fail to get scalar vector, {}", source))] + GetScalarVector { + source: DataTypeError, + backtrace: Backtrace, + }, +} + +impl ErrorExt for InnerError { + fn backtrace_opt(&self) -> Option<&Backtrace> { + ErrorCompat::backtrace(self) + } + + fn as_any(&self) -> &dyn Any { + self + } +} + +impl From for Error { + fn from(err: InnerError) -> Self { + Self::new(err) + } +} + +impl From for QueryError { + fn from(err: Error) -> Self { + QueryError::External { + msg: err.to_string(), + backtrace: Backtrace::generate(), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn raise_datatype_error() -> std::result::Result<(), DataTypeError> { + Err(DataTypeError::Conversion { + from: "test".to_string(), + backtrace: Backtrace::generate(), + }) + } + + #[test] + fn test_get_scalar_vector_error() { + let err = raise_datatype_error() + .context(GetScalarVectorSnafu) + .err() + .unwrap(); + assert!(err.backtrace_opt().is_some()); + + let query_error = QueryError::from(Error::from(err)); + assert!(matches!(query_error, QueryError::External { .. 
})); + } +} diff --git a/src/common/function/src/lib.rs b/src/common/function/src/lib.rs new file mode 100644 index 0000000000..b7bfa04e41 --- /dev/null +++ b/src/common/function/src/lib.rs @@ -0,0 +1,2 @@ +pub mod error; +pub mod scalars; diff --git a/src/common/function/src/scalars.rs b/src/common/function/src/scalars.rs new file mode 100644 index 0000000000..e171af88ca --- /dev/null +++ b/src/common/function/src/scalars.rs @@ -0,0 +1,11 @@ +pub mod expression; +pub mod function; +pub mod function_registry; +pub mod math; +pub mod numpy; +#[cfg(test)] +pub(crate) mod test; +pub mod udf; + +pub use function::{Function, FunctionRef}; +pub use function_registry::{FunctionRegistry, FUNCTION_REGISTRY}; diff --git a/src/common/function/src/scalars/expression/binary.rs b/src/common/function/src/scalars/expression/binary.rs new file mode 100644 index 0000000000..f0eea0ec2d --- /dev/null +++ b/src/common/function/src/scalars/expression/binary.rs @@ -0,0 +1,80 @@ +use std::iter; + +use datatypes::prelude::*; +use datatypes::vectors::ConstantVector; + +use crate::error::Result; +use crate::scalars::expression::ctx::EvalContext; + +pub fn scalar_binary_op( + l: &VectorRef, + r: &VectorRef, + f: F, + ctx: &mut EvalContext, +) -> Result<::VectorType> +where + F: Fn(Option>, Option>, &mut EvalContext) -> Option, +{ + debug_assert!( + l.len() == r.len(), + "Size of vectors must match to apply binary expression" + ); + + let result = match (l.is_const(), r.is_const()) { + (false, true) => { + let left: &::VectorType = unsafe { VectorHelper::static_cast(l) }; + let right: &ConstantVector = unsafe { VectorHelper::static_cast(r) }; + let right: &::VectorType = + unsafe { VectorHelper::static_cast(right.inner()) }; + let b = right.get_data(0); + + let it = left.iter_data().map(|a| f(a, b, ctx)); + ::VectorType::from_owned_iterator(it) + } + + (false, false) => { + let left: &::VectorType = unsafe { VectorHelper::static_cast(l) }; + let right: &::VectorType = unsafe { VectorHelper::static_cast(r) }; + + let it = left + .iter_data() + .zip(right.iter_data()) + .map(|(a, b)| f(a, b, ctx)); + ::VectorType::from_owned_iterator(it) + } + + (true, false) => { + let left: &ConstantVector = unsafe { VectorHelper::static_cast(l) }; + let left: &::VectorType = + unsafe { VectorHelper::static_cast(left.inner()) }; + let a = left.get_data(0); + + let right: &::VectorType = unsafe { VectorHelper::static_cast(r) }; + let it = right.iter_data().map(|b| f(a, b, ctx)); + ::VectorType::from_owned_iterator(it) + } + + (true, true) => { + let left: &ConstantVector = unsafe { VectorHelper::static_cast(l) }; + let left: &::VectorType = + unsafe { VectorHelper::static_cast(left.inner()) }; + let a = left.get_data(0); + + let right: &ConstantVector = unsafe { VectorHelper::static_cast(r) }; + let right: &::VectorType = + unsafe { VectorHelper::static_cast(right.inner()) }; + let b = right.get_data(0); + + let it = iter::repeat(a) + .zip(iter::repeat(b)) + .map(|(a, b)| f(a, b, ctx)) + .take(left.len()); + ::VectorType::from_owned_iterator(it) + } + }; + + if let Some(error) = ctx.error.take() { + return Err(error); + } + Ok(result) +} diff --git a/src/common/function/src/scalars/expression/ctx.rs b/src/common/function/src/scalars/expression/ctx.rs new file mode 100644 index 0000000000..87f7e5afea --- /dev/null +++ b/src/common/function/src/scalars/expression/ctx.rs @@ -0,0 +1,26 @@ +use chrono_tz::Tz; + +use crate::error::Error; + +pub struct EvalContext { + _tz: Tz, + pub error: Option, +} + +impl Default for EvalContext { + fn 
default() -> Self { + let tz = "UTC".parse::().unwrap(); + Self { + error: None, + _tz: tz, + } + } +} + +impl EvalContext { + pub fn set_error(&mut self, e: Error) { + if self.error.is_none() { + self.error = Some(e); + } + } +} diff --git a/src/common/function/src/scalars/expression/mod.rs b/src/common/function/src/scalars/expression/mod.rs new file mode 100644 index 0000000000..6fe84c189c --- /dev/null +++ b/src/common/function/src/scalars/expression/mod.rs @@ -0,0 +1,7 @@ +mod binary; +mod ctx; +mod unary; + +pub use binary::scalar_binary_op; +pub use ctx::EvalContext; +pub use unary::scalar_unary_op; diff --git a/src/common/function/src/scalars/expression/unary.rs b/src/common/function/src/scalars/expression/unary.rs new file mode 100644 index 0000000000..848b4edcd1 --- /dev/null +++ b/src/common/function/src/scalars/expression/unary.rs @@ -0,0 +1,26 @@ +use datatypes::prelude::*; +use snafu::ResultExt; + +use crate::error::{GetScalarVectorSnafu, Result}; +use crate::scalars::expression::ctx::EvalContext; + +/// TODO: remove the allow_unused when it's used. +#[allow(unused)] +pub fn scalar_unary_op( + l: &VectorRef, + f: F, + ctx: &mut EvalContext, +) -> Result<::VectorType> +where + F: Fn(Option>, &mut EvalContext) -> Option, +{ + let left = VectorHelper::check_get_scalar::(l).context(GetScalarVectorSnafu)?; + let it = left.iter_data().map(|a| f(a, ctx)); + let result = ::VectorType::from_owned_iterator(it); + + if let Some(error) = ctx.error.take() { + return Err(error); + } + + Ok(result) +} diff --git a/src/common/function/src/scalars/function.rs b/src/common/function/src/scalars/function.rs new file mode 100644 index 0000000000..751803eff3 --- /dev/null +++ b/src/common/function/src/scalars/function.rs @@ -0,0 +1,38 @@ +use std::fmt; +use std::sync::Arc; + +use chrono_tz::Tz; +use common_query::prelude::Signature; +use datatypes::data_type::ConcreteDataType; +use datatypes::vectors::VectorRef; + +use crate::error::Result; + +#[derive(Clone)] +pub struct FunctionContext { + pub tz: Tz, +} + +impl Default for FunctionContext { + fn default() -> Self { + Self { + tz: "UTC".parse::().unwrap(), + } + } +} + +/// Scalar function trait, modified from databend to adapt datafusion +/// TODO(dennis): optimize function by it's features such as monotonicity etc. +pub trait Function: fmt::Display + Sync + Send { + /// Returns the name of the function, should be unique. + fn name(&self) -> &str; + + fn return_type(&self, input_types: &[ConcreteDataType]) -> Result; + + fn signature(&self) -> Signature; + + /// Evaluate the function, e.g. run/execute the function. + fn eval(&self, _func_ctx: FunctionContext, _columns: &[VectorRef]) -> Result; +} + +pub type FunctionRef = Arc; diff --git a/src/common/function/src/scalars/function_registry.rs b/src/common/function/src/scalars/function_registry.rs new file mode 100644 index 0000000000..7e9db97609 --- /dev/null +++ b/src/common/function/src/scalars/function_registry.rs @@ -0,0 +1,59 @@ +//! 
functions registry +use std::collections::HashMap; +use std::sync::Arc; +use std::sync::RwLock; + +use once_cell::sync::Lazy; + +use crate::scalars::function::FunctionRef; +use crate::scalars::math::MathFunction; +use crate::scalars::numpy::NumpyFunction; + +#[derive(Default)] +pub struct FunctionRegistry { + functions: RwLock>, +} + +impl FunctionRegistry { + pub fn register(&self, func: FunctionRef) { + self.functions + .write() + .unwrap() + .insert(func.name().to_string(), func); + } + + pub fn get_function(&self, name: &str) -> Option { + self.functions.read().unwrap().get(name).cloned() + } + + pub fn functions(&self) -> Vec { + self.functions.read().unwrap().values().cloned().collect() + } +} + +pub static FUNCTION_REGISTRY: Lazy> = Lazy::new(|| { + let function_registry = FunctionRegistry::default(); + + MathFunction::register(&function_registry); + NumpyFunction::register(&function_registry); + + Arc::new(function_registry) +}); + +#[cfg(test)] +mod tests { + use super::*; + use crate::scalars::test::TestAndFunction; + + #[test] + fn test_function_registry() { + let registry = FunctionRegistry::default(); + let func = Arc::new(TestAndFunction::default()); + + assert!(registry.get_function("test_and").is_none()); + assert!(registry.functions().is_empty()); + registry.register(func); + assert!(registry.get_function("test_and").is_some()); + assert_eq!(1, registry.functions().len()); + } +} diff --git a/src/common/function/src/scalars/math/mod.rs b/src/common/function/src/scalars/math/mod.rs new file mode 100644 index 0000000000..d454b15ee6 --- /dev/null +++ b/src/common/function/src/scalars/math/mod.rs @@ -0,0 +1,15 @@ +mod pow; + +use std::sync::Arc; + +use pow::PowFunction; + +use crate::scalars::function_registry::FunctionRegistry; + +pub(crate) struct MathFunction; + +impl MathFunction { + pub fn register(registry: &FunctionRegistry) { + registry.register(Arc::new(PowFunction::default())); + } +} diff --git a/src/common/function/src/scalars/math/pow.rs b/src/common/function/src/scalars/math/pow.rs new file mode 100644 index 0000000000..05cf157c00 --- /dev/null +++ b/src/common/function/src/scalars/math/pow.rs @@ -0,0 +1,105 @@ +use std::fmt; +use std::sync::Arc; + +use common_query::prelude::{Signature, Volatility}; +use datatypes::data_type::DataType; +use datatypes::prelude::ConcreteDataType; +use datatypes::type_id::LogicalTypeId; +use datatypes::vectors::VectorRef; +use datatypes::with_match_primitive_type_id; +use num::traits::Pow; +use num_traits::AsPrimitive; + +use crate::error::Result; +use crate::scalars::expression::{scalar_binary_op, EvalContext}; +use crate::scalars::function::{Function, FunctionContext}; + +#[derive(Clone, Debug, Default)] +pub struct PowFunction; + +impl Function for PowFunction { + fn name(&self) -> &str { + "pow" + } + + fn return_type(&self, _input_types: &[ConcreteDataType]) -> Result { + Ok(ConcreteDataType::float64_datatype()) + } + + fn signature(&self) -> Signature { + Signature::uniform(2, ConcreteDataType::numerics(), Volatility::Immutable) + } + + fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result { + with_match_primitive_type_id!(columns[0].data_type().logical_type_id(), |$S| { + with_match_primitive_type_id!(columns[1].data_type().logical_type_id(), |$T| { + let col = scalar_binary_op::<$S, $T, f64, _>(&columns[0], &columns[1], scalar_pow, &mut EvalContext::default())?; + Ok(Arc::new(col)) + },{ + unreachable!() + }) + },{ + unreachable!() + }) + } +} + +#[inline] +fn scalar_pow(value: Option, base: Option, 
_ctx: &mut EvalContext) -> Option +where + S: AsPrimitive, + T: AsPrimitive, +{ + match (value, base) { + (Some(value), Some(base)) => Some(value.as_().pow(base.as_())), + _ => None, + } +} + +impl fmt::Display for PowFunction { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "POW") + } +} + +#[cfg(test)] +mod tests { + use common_query::prelude::TypeSignature; + use datatypes::value::Value; + use datatypes::vectors::{Float32Vector, Int8Vector}; + + use super::*; + #[test] + fn test_pow_function() { + let pow = PowFunction::default(); + + assert_eq!("pow", pow.name()); + assert_eq!( + ConcreteDataType::float64_datatype(), + pow.return_type(&[]).unwrap() + ); + + assert!(matches!(pow.signature(), + Signature { + type_signature: TypeSignature::Uniform(2, valid_types), + volatility: Volatility::Immutable + } if valid_types == ConcreteDataType::numerics() + )); + + let values = vec![1.0, 2.0, 3.0]; + let bases = vec![0i8, -1i8, 3i8]; + + let args: Vec = vec![ + Arc::new(Float32Vector::from_vec(values.clone())), + Arc::new(Int8Vector::from_vec(bases.clone())), + ]; + + let vector = pow.eval(FunctionContext::default(), &args).unwrap(); + assert_eq!(3, vector.len()); + + for i in 0..3 { + let p: f64 = (values[i] as f64).pow(bases[i] as f64); + assert!(matches!(vector.get_unchecked(i), Value::Float64(v) if v == p)); + } + } +} diff --git a/src/common/function/src/scalars/numpy/clip.rs b/src/common/function/src/scalars/numpy/clip.rs new file mode 100644 index 0000000000..c1218e1fc8 --- /dev/null +++ b/src/common/function/src/scalars/numpy/clip.rs @@ -0,0 +1,263 @@ +use std::fmt; +use std::sync::Arc; + +use common_query::prelude::{Signature, Volatility}; +use datatypes::data_type::ConcreteDataType; +use datatypes::data_type::DataType; +use datatypes::prelude::{Scalar, VectorRef}; +use datatypes::type_id::LogicalTypeId; +use datatypes::with_match_primitive_type_id; +use num_traits::AsPrimitive; +use paste::paste; + +use crate::error::Result; +use crate::scalars::expression::{scalar_binary_op, EvalContext}; +use crate::scalars::function::{Function, FunctionContext}; + +/// numpy.clip function, +#[derive(Clone, Debug, Default)] +pub struct ClipFunction; + +macro_rules! define_eval { + ($O: ident) => { + paste! 
{ + fn [](columns: &[VectorRef]) -> Result { + with_match_primitive_type_id!(columns[0].data_type().logical_type_id(), |$S| { + with_match_primitive_type_id!(columns[1].data_type().logical_type_id(), |$T| { + with_match_primitive_type_id!(columns[2].data_type().logical_type_id(), |$R| { + // clip(a, min, max) is equals to min(max(a, min), max) + let col: VectorRef = Arc::new(scalar_binary_op::<$S, $T, $O, _>(&columns[0], &columns[1], scalar_max, &mut EvalContext::default())?); + let col = scalar_binary_op::<$O, $R, $O, _>(&col, &columns[2], scalar_min, &mut EvalContext::default())?; + Ok(Arc::new(col)) + }, { + unreachable!() + }) + }, { + unreachable!() + }) + }, { + unreachable!() + }) + } + } + } +} + +define_eval!(i64); +define_eval!(u64); +define_eval!(f64); + +impl Function for ClipFunction { + fn name(&self) -> &str { + "clip" + } + + fn return_type(&self, input_types: &[ConcreteDataType]) -> Result { + if input_types.iter().all(ConcreteDataType::is_signed) { + Ok(ConcreteDataType::int64_datatype()) + } else if input_types.iter().all(ConcreteDataType::is_unsigned) { + Ok(ConcreteDataType::uint64_datatype()) + } else { + Ok(ConcreteDataType::float64_datatype()) + } + } + + fn signature(&self) -> Signature { + Signature::uniform(3, ConcreteDataType::numerics(), Volatility::Immutable) + } + + fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result { + if columns.iter().all(|v| v.data_type().is_signed()) { + eval_i64(columns) + } else if columns.iter().all(|v| v.data_type().is_unsigned()) { + eval_u64(columns) + } else { + eval_f64(columns) + } + } +} + +#[inline] +pub fn min(input: T, min: T) -> T { + if input < min { + input + } else { + min + } +} + +#[inline] +pub fn max(input: T, max: T) -> T { + if input > max { + input + } else { + max + } +} + +#[inline] +fn scalar_min(left: Option, right: Option, _ctx: &mut EvalContext) -> Option +where + S: AsPrimitive, + T: AsPrimitive, + O: Scalar + Copy + PartialOrd, +{ + match (left, right) { + (Some(left), Some(right)) => Some(min(left.as_(), right.as_())), + _ => None, + } +} + +#[inline] +fn scalar_max(left: Option, right: Option, _ctx: &mut EvalContext) -> Option +where + S: AsPrimitive, + T: AsPrimitive, + O: Scalar + Copy + PartialOrd, +{ + match (left, right) { + (Some(left), Some(right)) => Some(max(left.as_(), right.as_())), + _ => None, + } +} + +impl fmt::Display for ClipFunction { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "CLIP") + } +} + +#[cfg(test)] +mod tests { + use common_query::prelude::TypeSignature; + use datatypes::value::Value; + use datatypes::vectors::{ConstantVector, Float32Vector, Int32Vector, UInt32Vector}; + + use super::*; + #[test] + fn test_clip_function() { + let clip = ClipFunction::default(); + + assert_eq!("clip", clip.name()); + assert_eq!( + ConcreteDataType::int64_datatype(), + clip.return_type(&[]).unwrap() + ); + + assert_eq!( + ConcreteDataType::int64_datatype(), + clip.return_type(&[ + ConcreteDataType::int16_datatype(), + ConcreteDataType::int64_datatype(), + ConcreteDataType::int8_datatype() + ]) + .unwrap() + ); + assert_eq!( + ConcreteDataType::uint64_datatype(), + clip.return_type(&[ + ConcreteDataType::uint16_datatype(), + ConcreteDataType::uint64_datatype(), + ConcreteDataType::uint8_datatype() + ]) + .unwrap() + ); + assert_eq!( + ConcreteDataType::float64_datatype(), + clip.return_type(&[ + ConcreteDataType::uint16_datatype(), + ConcreteDataType::int64_datatype(), + ConcreteDataType::uint8_datatype() + ]) + .unwrap() + ); + + 
assert!(matches!(clip.signature(), + Signature { + type_signature: TypeSignature::Uniform(3, valid_types), + volatility: Volatility::Immutable + } if valid_types == ConcreteDataType::numerics() + )); + + // eval with signed integers + let args: Vec = vec![ + Arc::new(Int32Vector::from_values(0..10)), + Arc::new(ConstantVector::new( + Arc::new(Int32Vector::from_vec(vec![3])), + 10, + )), + Arc::new(ConstantVector::new( + Arc::new(Int32Vector::from_vec(vec![6])), + 10, + )), + ]; + + let vector = clip.eval(FunctionContext::default(), &args).unwrap(); + assert_eq!(10, vector.len()); + + // clip([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], 3, 6) = [3, 3, 3, 3, 4, 5, 6, 6, 6, 6] + for i in 0..10 { + if i <= 3 { + assert!(matches!(vector.get_unchecked(i), Value::Int64(v) if v == 3)); + } else if i <= 6 { + assert!(matches!(vector.get_unchecked(i), Value::Int64(v) if v == (i as i64))); + } else { + assert!(matches!(vector.get_unchecked(i), Value::Int64(v) if v == 6)); + } + } + + // eval with unsigned integers + let args: Vec = vec![ + Arc::new(UInt32Vector::from_values(0..10)), + Arc::new(ConstantVector::new( + Arc::new(UInt32Vector::from_vec(vec![3])), + 10, + )), + Arc::new(ConstantVector::new( + Arc::new(UInt32Vector::from_vec(vec![6])), + 10, + )), + ]; + + let vector = clip.eval(FunctionContext::default(), &args).unwrap(); + assert_eq!(10, vector.len()); + + // clip([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], 3, 6) = [3, 3, 3, 3, 4, 5, 6, 6, 6, 6] + for i in 0..10 { + if i <= 3 { + assert!(matches!(vector.get_unchecked(i), Value::UInt64(v) if v == 3)); + } else if i <= 6 { + assert!(matches!(vector.get_unchecked(i), Value::UInt64(v) if v == (i as u64))); + } else { + assert!(matches!(vector.get_unchecked(i), Value::UInt64(v) if v == 6)); + } + } + + // eval with floats + let args: Vec = vec![ + Arc::new(Int32Vector::from_values(0..10)), + Arc::new(ConstantVector::new( + Arc::new(Int32Vector::from_vec(vec![3])), + 10, + )), + Arc::new(ConstantVector::new( + Arc::new(Float32Vector::from_vec(vec![6f32])), + 10, + )), + ]; + + let vector = clip.eval(FunctionContext::default(), &args).unwrap(); + assert_eq!(10, vector.len()); + + // clip([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], 3, 6) = [3, 3, 3, 3, 4, 5, 6, 6, 6, 6] + for i in 0..10 { + if i <= 3 { + assert!(matches!(vector.get_unchecked(i), Value::Float64(v) if v == 3.0)); + } else if i <= 6 { + assert!(matches!(vector.get_unchecked(i), Value::Float64(v) if v == (i as f64))); + } else { + assert!(matches!(vector.get_unchecked(i), Value::Float64(v) if v == 6.0)); + } + } + } +} diff --git a/src/common/function/src/scalars/numpy/mod.rs b/src/common/function/src/scalars/numpy/mod.rs new file mode 100644 index 0000000000..edc754c838 --- /dev/null +++ b/src/common/function/src/scalars/numpy/mod.rs @@ -0,0 +1,15 @@ +mod clip; + +use std::sync::Arc; + +use clip::ClipFunction; + +use crate::scalars::function_registry::FunctionRegistry; + +pub(crate) struct NumpyFunction; + +impl NumpyFunction { + pub fn register(registry: &FunctionRegistry) { + registry.register(Arc::new(ClipFunction::default())); + } +} diff --git a/src/common/function/src/scalars/test.rs b/src/common/function/src/scalars/test.rs new file mode 100644 index 0000000000..fb0dfeb519 --- /dev/null +++ b/src/common/function/src/scalars/test.rs @@ -0,0 +1,57 @@ +use std::fmt; +use std::sync::Arc; + +use common_query::prelude::{Signature, Volatility}; +use datatypes::data_type::ConcreteDataType; +use datatypes::prelude::VectorRef; + +use crate::error::Result; +use crate::scalars::expression::{scalar_binary_op, EvalContext}; 
+use crate::scalars::function::{Function, FunctionContext}; + +#[derive(Clone, Default)] +pub(crate) struct TestAndFunction; + +impl Function for TestAndFunction { + fn name(&self) -> &str { + "test_and" + } + + fn return_type(&self, _input_types: &[ConcreteDataType]) -> Result { + Ok(ConcreteDataType::boolean_datatype()) + } + + fn signature(&self) -> Signature { + Signature::exact( + vec![ + ConcreteDataType::boolean_datatype(), + ConcreteDataType::boolean_datatype(), + ], + Volatility::Immutable, + ) + } + + fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result { + let col = scalar_binary_op::( + &columns[0], + &columns[1], + scalar_and, + &mut EvalContext::default(), + )?; + Ok(Arc::new(col)) + } +} + +impl fmt::Display for TestAndFunction { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "TEST_AND") + } +} + +#[inline] +fn scalar_and(left: Option, right: Option, _ctx: &mut EvalContext) -> Option { + match (left, right) { + (Some(left), Some(right)) => Some(left && right), + _ => None, + } +} diff --git a/src/common/function/src/scalars/udf.rs b/src/common/function/src/scalars/udf.rs new file mode 100644 index 0000000000..a4d8801ed9 --- /dev/null +++ b/src/common/function/src/scalars/udf.rs @@ -0,0 +1,127 @@ +use std::sync::Arc; + +use common_query::error::{ExecuteFunctionSnafu, FromScalarValueSnafu}; +use common_query::prelude::ScalarValue; +use common_query::prelude::{ + ColumnarValue, ReturnTypeFunction, ScalarFunctionImplementation, ScalarUdf, +}; +use datatypes::error::Error as DataTypeError; +use datatypes::prelude::{ConcreteDataType, VectorHelper}; +use snafu::ResultExt; + +use crate::scalars::function::{FunctionContext, FunctionRef}; + +/// Create a ScalarUdf from function. +pub fn create_udf(func: FunctionRef) -> ScalarUdf { + let func_cloned = func.clone(); + let return_type: ReturnTypeFunction = Arc::new(move |input_types: &[ConcreteDataType]| { + Ok(Arc::new(func_cloned.return_type(input_types)?)) + }); + + let func_cloned = func.clone(); + let fun: ScalarFunctionImplementation = Arc::new(move |args: &[ColumnarValue]| { + let func_ctx = FunctionContext::default(); + + let len = args + .iter() + .fold(Option::::None, |acc, arg| match arg { + ColumnarValue::Scalar(_) => acc, + ColumnarValue::Vector(v) => Some(v.len()), + }); + + let rows = len.unwrap_or(1); + + let args: Result, DataTypeError> = args + .iter() + .map(|arg| match arg { + ColumnarValue::Scalar(v) => VectorHelper::try_from_scalar_value(v.clone(), rows), + ColumnarValue::Vector(v) => Ok(v.clone()), + }) + .collect(); + + let result = func_cloned.eval(func_ctx, &args.context(FromScalarValueSnafu)?); + + if len.is_some() { + result.map(ColumnarValue::Vector).map_err(|e| e.into()) + } else { + ScalarValue::try_from_array(&result?.to_arrow_array(), 0) + .map(ColumnarValue::Scalar) + .context(ExecuteFunctionSnafu) + } + }); + + ScalarUdf::new(func.name(), &func.signature(), &return_type, &fun) +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use common_query::prelude::{ColumnarValue, ScalarValue}; + use datatypes::data_type::ConcreteDataType; + use datatypes::prelude::{ScalarVector, Vector, VectorRef}; + use datatypes::value::Value; + use datatypes::vectors::{BooleanVector, ConstantVector}; + + use super::*; + use crate::scalars::function::Function; + use crate::scalars::test::TestAndFunction; + + #[test] + fn test_create_udf() { + let f = Arc::new(TestAndFunction::default()); + + let args: Vec = vec![ + Arc::new(ConstantVector::new( + 
Arc::new(BooleanVector::from(vec![true])), + 3, + )), + Arc::new(BooleanVector::from(vec![true, false, true])), + ]; + + let vector = f.eval(FunctionContext::default(), &args).unwrap(); + assert_eq!(3, vector.len()); + + for i in 0..3 { + assert!( + matches!(vector.get_unchecked(i), Value::Boolean(b) if b == (i == 0 || i == 2)) + ); + } + + // create a udf and test it again + let udf = create_udf(f.clone()); + + assert_eq!("test_and", udf.name); + assert_eq!(f.signature(), udf.signature); + assert_eq!( + Arc::new(ConcreteDataType::boolean_datatype()), + ((udf.return_type)(&[])).unwrap() + ); + + let args = vec![ + ColumnarValue::Scalar(ScalarValue::Boolean(Some(true))), + ColumnarValue::Vector(Arc::new(BooleanVector::from(vec![ + true, false, false, true, + ]))), + ]; + + let vec = (udf.fun)(&args).unwrap(); + + match vec { + ColumnarValue::Vector(vec) => { + let vec = vec.as_any().downcast_ref::().unwrap(); + + assert_eq!(4, vec.len()); + for i in 0..4 { + assert_eq!( + i == 0 || i == 3, + vec.get_data(i).unwrap(), + "failed at {}", + i + ) + } + } + _ => unreachable!(), + } + } +} diff --git a/src/common/query/src/columnar_value.rs b/src/common/query/src/columnar_value.rs index 27b238cff7..8f82b62bef 100644 --- a/src/common/query/src/columnar_value.rs +++ b/src/common/query/src/columnar_value.rs @@ -1,6 +1,6 @@ use datafusion_expr::ColumnarValue as DfColumnarValue; use datatypes::prelude::ConcreteDataType; -use datatypes::vectors; +use datatypes::vectors::Helper; use datatypes::vectors::VectorRef; use snafu::ResultExt; @@ -32,7 +32,7 @@ impl ColumnarValue { ColumnarValue::Scalar(s) => { let v = s.to_array_of_size(num_rows); let data_type = v.data_type().clone(); - vectors::try_into_vector(v).context(IntoVectorSnafu { data_type })? + Helper::try_into_vector(v).context(IntoVectorSnafu { data_type })? } }) } @@ -44,7 +44,7 @@ impl TryFrom<&DfColumnarValue> for ColumnarValue { Ok(match value { DfColumnarValue::Scalar(v) => ColumnarValue::Scalar(v.clone()), DfColumnarValue::Array(v) => { - ColumnarValue::Vector(vectors::try_into_vector(v.clone()).with_context(|_| { + ColumnarValue::Vector(Helper::try_into_vector(v.clone()).with_context(|_| { IntoVectorSnafu { data_type: v.data_type().clone(), } diff --git a/src/common/query/src/error.rs b/src/common/query/src/error.rs index d51f743caf..b34c6d7f48 100644 --- a/src/common/query/src/error.rs +++ b/src/common/query/src/error.rs @@ -13,12 +13,19 @@ pub enum Error { source: DataFusionError, backtrace: Backtrace, }, + #[snafu(display("Fail to cast scalar value into vector: {}", source))] + FromScalarValue { + #[snafu(backtrace)] + source: DataTypeError, + }, #[snafu(display("Fail to cast arrow array into vector: {:?}, {}", data_type, source))] IntoVector { #[snafu(backtrace)] source: DataTypeError, data_type: ArrowDatatype, }, + #[snafu(display("External error: {}, {}", msg, backtrace))] + External { msg: String, backtrace: Backtrace }, } pub type Result = std::result::Result; @@ -28,6 +35,8 @@ impl ErrorExt for Error { match self { Error::ExecuteFunction { .. } => StatusCode::EngineExecuteQuery, Error::IntoVector { source, .. } => source.status_code(), + Error::FromScalarValue { source } => source.status_code(), + Error::External { .. 
} => StatusCode::Internal, } } diff --git a/src/common/query/src/prelude.rs b/src/common/query/src/prelude.rs index 383159daf0..c537b7cee4 100644 --- a/src/common/query/src/prelude.rs +++ b/src/common/query/src/prelude.rs @@ -5,4 +5,4 @@ pub use crate::function::*; pub use crate::logical_plan::create_udf; pub use crate::logical_plan::Expr; pub use crate::logical_plan::ScalarUdf; -pub use crate::signature::Volatility; +pub use crate::signature::{Signature, TypeSignature, Volatility}; diff --git a/src/common/query/src/signature.rs b/src/common/query/src/signature.rs index f6201cd410..4527e09c4d 100644 --- a/src/common/query/src/signature.rs +++ b/src/common/query/src/signature.rs @@ -8,7 +8,7 @@ use datatypes::data_type::DataType; use datatypes::prelude::ConcreteDataType; /// A function's type signature, which defines the function's supported argument types. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq)] pub enum TypeSignature { /// arbitrary number of arguments of an common type out of a list of valid types // A function such as `concat` is `Variadic(vec![ConcreteDataType::String, ConcreteDataType::String])` @@ -30,7 +30,7 @@ pub enum TypeSignature { } ///The Signature of a function defines its supported input types as well as its volatility. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq)] pub struct Signature { /// type_signature - The types that the function accepts. See [TypeSignature] for more information. pub type_signature: TypeSignature, diff --git a/src/common/recordbatch/src/recordbatch.rs b/src/common/recordbatch/src/recordbatch.rs index a2ceea677f..233c775eca 100644 --- a/src/common/recordbatch/src/recordbatch.rs +++ b/src/common/recordbatch/src/recordbatch.rs @@ -2,7 +2,7 @@ use std::sync::Arc; use datafusion_common::record_batch::RecordBatch as DfRecordBatch; use datatypes::schema::Schema; -use datatypes::vectors; +use datatypes::vectors::Helper; use serde::ser::{Error, SerializeStruct}; use serde::{Serialize, Serializer}; @@ -24,7 +24,7 @@ impl Serialize for RecordBatch { let vec = df_columns .iter() - .map(|c| vectors::try_into_vector(c.clone())?.serialize_to_json()) + .map(|c| Helper::try_into_vector(c.clone())?.serialize_to_json()) .collect::, _>>() .map_err(S::Error::custom)?; diff --git a/src/datatypes/Cargo.toml b/src/datatypes/Cargo.toml index 2ad7b8eb1a..b10d34296b 100644 --- a/src/datatypes/Cargo.toml +++ b/src/datatypes/Cargo.toml @@ -11,6 +11,7 @@ features = ["io_csv", "io_json", "io_parquet", "io_parquet_compression", "io_ipc [dependencies] common-base = { path = "../common/base" } common-error = { path = "../common/error" } +datafusion-common = { git = "https://github.com/apache/arrow-datafusion.git" , branch = "arrow2" } enum_dispatch = "0.3" paste = "1.0" serde = { version = "1.0.136", features = ["derive"] } diff --git a/src/datatypes/src/data_type.rs b/src/datatypes/src/data_type.rs index 09f37488c3..54197fca8e 100644 --- a/src/datatypes/src/data_type.rs +++ b/src/datatypes/src/data_type.rs @@ -34,6 +34,48 @@ pub enum ConcreteDataType { } impl ConcreteDataType { + pub fn is_float(&self) -> bool { + matches!( + self, + ConcreteDataType::Float64(_) | ConcreteDataType::Float32(_) + ) + } + + pub fn is_signed(&self) -> bool { + matches!( + self, + ConcreteDataType::Int8(_) + | ConcreteDataType::Int16(_) + | ConcreteDataType::Int32(_) + | ConcreteDataType::Int64(_) + ) + } + + pub fn is_unsigned(&self) -> bool { + matches!( + self, + ConcreteDataType::UInt8(_) + | ConcreteDataType::UInt16(_) + | ConcreteDataType::UInt32(_) + | 
ConcreteDataType::UInt64(_) + ) + } + + pub fn numerics() -> Vec { + vec![ + ConcreteDataType::int8_datatype(), + ConcreteDataType::int16_datatype(), + ConcreteDataType::int32_datatype(), + ConcreteDataType::int64_datatype(), + ConcreteDataType::uint8_datatype(), + ConcreteDataType::uint16_datatype(), + ConcreteDataType::uint32_datatype(), + ConcreteDataType::uint64_datatype(), + ConcreteDataType::float32_datatype(), + ConcreteDataType::float64_datatype(), + ] + } + /// Convert arrow data type to [ConcreteDataType]. /// /// # Panics diff --git a/src/datatypes/src/error.rs b/src/datatypes/src/error.rs index 3671775340..287126dfff 100644 --- a/src/datatypes/src/error.rs +++ b/src/datatypes/src/error.rs @@ -11,9 +11,16 @@ pub enum Error { source: serde_json::Error, backtrace: Backtrace, }, - #[snafu(display("Failed to convert datafusion type: {}", from))] Conversion { from: String, backtrace: Backtrace }, + #[snafu(display("Bad array access, Index out of bounds: {}, size: {}", index, size))] + BadArrayAccess { + index: usize, + size: usize, + backtrace: Backtrace, + }, + #[snafu(display("Unknown vector, {}", msg))] + UnknownVector { msg: String, backtrace: Backtrace }, } impl ErrorExt for Error { diff --git a/src/datatypes/src/lib.rs b/src/datatypes/src/lib.rs index b4cb6a34da..d7bd7ed6ec 100644 --- a/src/datatypes/src/lib.rs +++ b/src/datatypes/src/lib.rs @@ -4,6 +4,7 @@ pub mod arrow_array; pub mod data_type; pub mod deserialize; pub mod error; +pub mod macros; pub mod prelude; mod scalars; pub mod schema; diff --git a/src/datatypes/src/macros.rs b/src/datatypes/src/macros.rs new file mode 100644 index 0000000000..bc8392628e --- /dev/null +++ b/src/datatypes/src/macros.rs @@ -0,0 +1,65 @@ +///! Some helper macros for datatypes, copied from databend. +#[macro_export] +macro_rules! for_all_scalar_types { + ($macro:tt $(, $x:tt)*) => { + $macro! { + [$($x),*], + { i8 }, + { i16 }, + { i32 }, + { i64 }, + { u8 }, + { u16 }, + { u32 }, + { u64 }, + { f32 }, + { f64 }, + { bool }, + } + }; +} + +#[macro_export] +macro_rules! for_all_primitive_types{ + ($macro:tt $(, $x:tt)*) => { + $macro! { + [$($x),*], + { i8 }, + { i16 }, + { i32 }, + { i64 }, + { u8 }, + { u16 }, + { u32 }, + { u64 }, + { f32 }, + { f64 } + } + }; +} + +#[macro_export] +macro_rules! with_match_primitive_type_id { + ($key_type:expr, | $_:tt $T:ident | $body:tt, $nbody:tt) => {{ + macro_rules! __with_ty__ { + ( $_ $T:ident ) => { + $body + }; + } + + match $key_type { + LogicalTypeId::Int8 => __with_ty__! { i8 }, + LogicalTypeId::Int16 => __with_ty__! { i16 }, + LogicalTypeId::Int32 => __with_ty__! { i32 }, + LogicalTypeId::Int64 => __with_ty__! { i64 }, + LogicalTypeId::UInt8 => __with_ty__! { u8 }, + LogicalTypeId::UInt16 => __with_ty__! { u16 }, + LogicalTypeId::UInt32 => __with_ty__! { u32 }, + LogicalTypeId::UInt64 => __with_ty__! { u64 }, + LogicalTypeId::Float32 => __with_ty__! { f32 }, + LogicalTypeId::Float64 => __with_ty__! 
{ f64 }, + + _ => $nbody, + } + }}; +} diff --git a/src/datatypes/src/prelude.rs b/src/datatypes/src/prelude.rs index 6d1a2d5bfb..fc460cd473 100644 --- a/src/datatypes/src/prelude.rs +++ b/src/datatypes/src/prelude.rs @@ -1,5 +1,6 @@ pub use crate::data_type::{ConcreteDataType, DataType, DataTypeRef}; -pub use crate::scalars::{ScalarVector, ScalarVectorBuilder}; +pub use crate::macros::*; +pub use crate::scalars::{Scalar, ScalarRef, ScalarVector, ScalarVectorBuilder}; pub use crate::type_id::LogicalTypeId; pub use crate::value::Value; -pub use crate::vectors::{Validity, Vector, VectorRef}; +pub use crate::vectors::{Helper as VectorHelper, MutableVector, Validity, Vector, VectorRef}; diff --git a/src/datatypes/src/scalars.rs b/src/datatypes/src/scalars.rs index fb9910921b..b27097bf10 100644 --- a/src/datatypes/src/scalars.rs +++ b/src/datatypes/src/scalars.rs @@ -1,11 +1,53 @@ -use crate::vectors::Vector; +use std::any::Any; + +pub mod common; +use crate::prelude::*; +use crate::vectors::*; + +fn get_iter_capacity>(iter: &I) -> usize { + match iter.size_hint() { + (_lower, Some(upper)) => upper, + (0, None) => 1024, + (lower, None) => lower, + } +} + +/// Owned scalar value +/// primitive types, bool, Vec ... +pub trait Scalar: 'static + Sized + Default + Any +where + for<'a> Self::VectorType: ScalarVector = Self::RefType<'a>>, +{ + type VectorType: ScalarVector; + type RefType<'a>: ScalarRef<'a, ScalarType = Self, VectorType = Self::VectorType> + where + Self: 'a; + /// Get a reference of the current value. + fn as_scalar_ref(&self) -> Self::RefType<'_>; + + /// Upcast GAT type's lifetime. + fn upcast_gat<'short, 'long: 'short>(long: Self::RefType<'long>) -> Self::RefType<'short>; +} + +pub trait ScalarRef<'a>: std::fmt::Debug + Clone + Copy + Send + 'a { + type VectorType: ScalarVector = Self>; + /// The corresponding [`Scalar`] type. + type ScalarType: Scalar = Self>; + + /// Convert the reference into an owned value. + fn to_owned_scalar(&self) -> Self::ScalarType; +} /// A sub trait of Vector to add scalar operation support. // This implementation refers to Datebend's [ScalarColumn](https://github.com/datafuselabs/databend/blob/main/common/datavalues/src/scalars/type_.rs) // and skyzh's [type-exercise-in-rust](https://github.com/skyzh/type-exercise-in-rust). -pub trait ScalarVector: Vector { +pub trait ScalarVector: Vector + Send + Sync + Sized + 'static +where + for<'a> Self::OwnedItem: Scalar = Self::RefItem<'a>>, +{ + type OwnedItem: Scalar; /// The reference item of this vector. - type RefItem<'a>: Copy + type RefItem<'a>: ScalarRef<'a, ScalarType = Self::OwnedItem, VectorType = Self> where Self: 'a; @@ -27,10 +69,46 @@ pub trait ScalarVector: Vector { /// Returns iterator of current vector. 
fn iter_data(&self) -> Self::Iter<'_>; + + fn from_slice(data: &[Self::RefItem<'_>]) -> Self { + let mut builder = Self::Builder::with_capacity(data.len()); + for item in data { + builder.push(Some(*item)); + } + builder.finish() + } + + fn from_iterator<'a>(it: impl Iterator>) -> Self { + let mut builder = Self::Builder::with_capacity(get_iter_capacity(&it)); + for item in it { + builder.push(Some(item)); + } + builder.finish() + } + + fn from_owned_iterator(it: impl Iterator>) -> Self { + let mut builder = Self::Builder::with_capacity(get_iter_capacity(&it)); + for item in it { + match item { + Some(item) => builder.push(Some(item.as_scalar_ref())), + None => builder.push(None), + } + } + builder.finish() + } + + fn from_vecs(values: Vec) -> Self { + let it = values.iter(); + let mut builder = Self::Builder::with_capacity(get_iter_capacity(&it)); + for item in it { + builder.push(Some(item.as_scalar_ref())); + } + builder.finish() + } } /// A trait over all vector builders. -pub trait ScalarVectorBuilder { +pub trait ScalarVectorBuilder: MutableVector { type VectorType: ScalarVector; /// Create a new builder with initial `capacity`. @@ -40,7 +118,125 @@ pub trait ScalarVectorBuilder { fn push(&mut self, value: Option<::RefItem<'_>>); /// Finish build and return a new vector. - fn finish(self) -> Self::VectorType; + fn finish(&mut self) -> Self::VectorType; +} + +macro_rules! impl_primitive_scalar_type { + ($native:ident) => { + impl Scalar for $native { + type VectorType = PrimitiveVector<$native>; + type RefType<'a> = $native; + + #[inline] + fn as_scalar_ref(&self) -> $native { + *self + } + + #[allow(clippy::needless_lifetimes)] + #[inline] + fn upcast_gat<'short, 'long: 'short>(long: $native) -> $native { + long + } + } + + /// Implement [`ScalarRef`] for primitive types. Note that primitive types are both [`Scalar`] and [`ScalarRef`]. 
+ impl<'a> ScalarRef<'a> for $native { + type VectorType = PrimitiveVector<$native>; + type ScalarType = $native; + + #[inline] + fn to_owned_scalar(&self) -> $native { + *self + } + } + }; +} + +impl_primitive_scalar_type!(u8); +impl_primitive_scalar_type!(u16); +impl_primitive_scalar_type!(u32); +impl_primitive_scalar_type!(u64); +impl_primitive_scalar_type!(i8); +impl_primitive_scalar_type!(i16); +impl_primitive_scalar_type!(i32); +impl_primitive_scalar_type!(i64); +impl_primitive_scalar_type!(f32); +impl_primitive_scalar_type!(f64); + +impl Scalar for bool { + type VectorType = BooleanVector; + type RefType<'a> = bool; + + #[inline] + fn as_scalar_ref(&self) -> bool { + *self + } + + #[allow(clippy::needless_lifetimes)] + #[inline] + fn upcast_gat<'short, 'long: 'short>(long: bool) -> bool { + long + } +} + +impl<'a> ScalarRef<'a> for bool { + type VectorType = BooleanVector; + type ScalarType = bool; + + #[inline] + fn to_owned_scalar(&self) -> bool { + *self + } +} + +impl Scalar for String { + type VectorType = StringVector; + type RefType<'a> = &'a str; + + #[inline] + fn as_scalar_ref(&self) -> &str { + self + } + + #[inline] + fn upcast_gat<'short, 'long: 'short>(long: &'long str) -> &'short str { + long + } +} + +impl<'a> ScalarRef<'a> for &'a str { + type VectorType = StringVector; + type ScalarType = String; + + #[inline] + fn to_owned_scalar(&self) -> String { + self.to_string() + } +} + +impl Scalar for Vec { + type VectorType = BinaryVector; + type RefType<'a> = &'a [u8]; + + #[inline] + fn as_scalar_ref(&self) -> &[u8] { + self + } + + #[inline] + fn upcast_gat<'short, 'long: 'short>(long: &'long [u8]) -> &'short [u8] { + long + } +} + +impl<'a> ScalarRef<'a> for &'a [u8] { + type VectorType = BinaryVector; + type ScalarType = Vec; + + #[inline] + fn to_owned_scalar(&self) -> Vec { + self.to_vec() + } } #[cfg(test)] diff --git a/src/datatypes/src/scalars/common.rs b/src/datatypes/src/scalars/common.rs new file mode 100644 index 0000000000..29e6b90517 --- /dev/null +++ b/src/datatypes/src/scalars/common.rs @@ -0,0 +1,23 @@ +use crate::prelude::*; + +pub fn replicate_scalar_vector(c: &C, offsets: &[usize]) -> VectorRef { + debug_assert!( + offsets.len() == c.len(), + "Size of offsets must match size of vector" + ); + + if offsets.is_empty() { + return c.slice(0, 0); + } + let mut builder = <::Builder>::with_capacity(c.len()); + + let mut previous_offset = 0; + for (i, offset) in offsets.iter().enumerate() { + let data = c.get_data(i); + for _ in previous_offset..*offset { + builder.push(data); + } + previous_offset = *offset; + } + builder.to_vector() +} diff --git a/src/datatypes/src/types/primitive_type.rs b/src/datatypes/src/types/primitive_type.rs index 45f9823a1e..6c1e5b5bfe 100644 --- a/src/datatypes/src/types/primitive_type.rs +++ b/src/datatypes/src/types/primitive_type.rs @@ -16,6 +16,7 @@ pub struct PrimitiveType { /// Create a new [ConcreteDataType] from a primitive type. pub trait DataTypeBuilder { fn build_data_type() -> ConcreteDataType; + fn type_name() -> String; } macro_rules! impl_build_data_type { @@ -25,6 +26,9 @@ macro_rules! 
impl_build_data_type { fn build_data_type() -> ConcreteDataType { ConcreteDataType::$TypeId(PrimitiveType::<$Type>::default()) } + fn type_name() -> String { + stringify!($TypeId).to_string() + } } } }; diff --git a/src/datatypes/src/value.rs b/src/datatypes/src/value.rs index a0d1504d0c..3a236c5699 100644 --- a/src/datatypes/src/value.rs +++ b/src/datatypes/src/value.rs @@ -1,7 +1,8 @@ use common_base::bytes::{Bytes, StringBytes}; +use serde::{Serialize, Serializer}; /// Value holds a single arbitrary value of any [DataType](crate::data_type::DataType). -#[derive(Debug, PartialEq)] +#[derive(Debug, PartialEq, Clone)] pub enum Value { Null, @@ -59,3 +60,41 @@ impl_from!(Float32, f32); impl_from!(Float64, f64); impl_from!(String, StringBytes); impl_from!(Binary, Bytes); + +impl From<&[u8]> for Value { + fn from(s: &[u8]) -> Self { + Value::Binary(Bytes(s.to_vec())) + } +} + +impl From<&str> for Value { + fn from(s: &str) -> Self { + Value::String(StringBytes(s.to_string().into_bytes())) + } +} + +impl Serialize for Value { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + match self { + Value::Null => serde_json::Value::Null.serialize(serializer), + Value::Boolean(v) => v.serialize(serializer), + Value::UInt8(v) => v.serialize(serializer), + Value::UInt16(v) => v.serialize(serializer), + Value::UInt32(v) => v.serialize(serializer), + Value::UInt64(v) => v.serialize(serializer), + Value::Int8(v) => v.serialize(serializer), + Value::Int16(v) => v.serialize(serializer), + Value::Int32(v) => v.serialize(serializer), + Value::Int64(v) => v.serialize(serializer), + Value::Float32(v) => v.serialize(serializer), + Value::Float64(v) => v.serialize(serializer), + Value::String(bytes) => bytes.serialize(serializer), + Value::Binary(bytes) => bytes.serialize(serializer), + Value::Date(v) => v.serialize(serializer), + Value::DateTime(v) => v.serialize(serializer), + } + } +} diff --git a/src/datatypes/src/vectors.rs b/src/datatypes/src/vectors.rs index 969d96b8f4..2d7c82c741 100644 --- a/src/datatypes/src/vectors.rs +++ b/src/datatypes/src/vectors.rs @@ -1,5 +1,8 @@ pub mod binary; pub mod boolean; +pub mod constant; +mod helper; +pub mod mutable; pub mod null; pub mod primitive; mod string; @@ -9,16 +12,20 @@ use std::sync::Arc; use arrow::array::ArrayRef; use arrow::bitmap::Bitmap; -use arrow::datatypes::DataType as ArrowDataType; pub use binary::*; pub use boolean::*; +pub use constant::*; +pub use helper::Helper; +pub use mutable::MutableVector; pub use null::*; pub use primitive::*; +use snafu::ensure; pub use string::*; use crate::data_type::ConcreteDataType; -use crate::error::Result; +use crate::error::{BadArrayAccessSnafu, Result}; use crate::serialize::Serializable; +use crate::value::Value; pub use crate::vectors::{ BinaryVector, BooleanVector, Float32Vector, Float64Vector, Int16Vector, Int32Vector, Int64Vector, Int8Vector, NullVector, StringVector, UInt16Vector, UInt32Vector, UInt64Vector, @@ -51,6 +58,8 @@ pub trait Vector: Send + Sync + Serializable { /// This may require heap allocation. fn data_type(&self) -> ConcreteDataType; + fn vector_type_name(&self) -> String; + /// Returns the vector as [Any](std::any::Any) so that it can be /// downcast to a specific implementation. fn as_any(&self) -> &dyn Any; @@ -79,38 +88,44 @@ pub trait Vector: Send + Sync + Serializable { Validity::AllNull => self.len(), } } + + /// Returns true when it's a ConstantColumn + fn is_const(&self) -> bool { + false + } + + /// Returns whether row is null. 
+ fn is_null(&self, row: usize) -> bool; + + /// If the only value vector can contain is NULL. + fn only_null(&self) -> bool { + self.null_count() == self.len() + } + + fn slice(&self, offset: usize, length: usize) -> VectorRef; + + /// # Safety + /// Assumes that the `index` is smaller than size. + fn get_unchecked(&self, index: usize) -> Value; + + fn get(&self, index: usize) -> Result { + ensure!( + index < self.len(), + BadArrayAccessSnafu { + index, + size: self.len() + } + ); + Ok(self.get_unchecked(index)) + } + + // Copies each element according offsets parameter. + // (i-th element should be copied offsets[i] - offsets[i - 1] times.) + fn replicate(&self, offsets: &[usize]) -> VectorRef; } pub type VectorRef = Arc; -/// Try to cast an arrow array into vector -/// -/// # Panics -/// Panic if given arrow data type is not supported. -pub fn try_into_vector(array: ArrayRef) -> Result { - Ok(match array.data_type() { - ArrowDataType::Null => Arc::new(NullVector::try_from_arrow_array(array)?), - ArrowDataType::Boolean => Arc::new(BooleanVector::try_from_arrow_array(array)?), - ArrowDataType::Binary | ArrowDataType::LargeBinary => { - Arc::new(BinaryVector::try_from_arrow_array(array)?) - } - ArrowDataType::Int8 => Arc::new(Int8Vector::try_from_arrow_array(array)?), - ArrowDataType::Int16 => Arc::new(Int16Vector::try_from_arrow_array(array)?), - ArrowDataType::Int32 => Arc::new(Int32Vector::try_from_arrow_array(array)?), - ArrowDataType::Int64 => Arc::new(Int64Vector::try_from_arrow_array(array)?), - ArrowDataType::UInt8 => Arc::new(UInt8Vector::try_from_arrow_array(array)?), - ArrowDataType::UInt16 => Arc::new(UInt16Vector::try_from_arrow_array(array)?), - ArrowDataType::UInt32 => Arc::new(UInt32Vector::try_from_arrow_array(array)?), - ArrowDataType::UInt64 => Arc::new(UInt64Vector::try_from_arrow_array(array)?), - ArrowDataType::Float32 => Arc::new(Float32Vector::try_from_arrow_array(array)?), - ArrowDataType::Float64 => Arc::new(Float64Vector::try_from_arrow_array(array)?), - ArrowDataType::Utf8 | ArrowDataType::LargeUtf8 => { - Arc::new(StringVector::try_from_arrow_array(array)?) - } - _ => unimplemented!(), - }) -} - /// Helper to define `try_from_arrow_array(array: arrow::array::ArrayRef)` function. macro_rules! 
impl_try_from_arrow_array_for_vector { ($Array: ident, $Vector: ident) => { @@ -139,6 +154,7 @@ pub mod tests { use arrow::array::{Array, PrimitiveArray}; use serde_json; + use super::helper::Helper; use super::*; use crate::data_type::DataType; use crate::types::DataTypeBuilder; @@ -146,7 +162,7 @@ pub mod tests { #[test] fn test_df_columns_to_vector() { let df_column: Arc = Arc::new(PrimitiveArray::from_slice(vec![1, 2, 3])); - let vector = try_into_vector(df_column).unwrap(); + let vector = Helper::try_into_vector(df_column).unwrap(); assert_eq!( i32::build_data_type().as_arrow_type(), vector.data_type().as_arrow_type() @@ -156,7 +172,7 @@ pub mod tests { #[test] fn test_serialize_i32_vector() { let df_column: Arc = Arc::new(PrimitiveArray::::from_slice(vec![1, 2, 3])); - let json_value = try_into_vector(df_column) + let json_value = Helper::try_into_vector(df_column) .unwrap() .serialize_to_json() .unwrap(); @@ -166,7 +182,7 @@ pub mod tests { #[test] fn test_serialize_i8_vector() { let df_column: Arc = Arc::new(PrimitiveArray::from_slice(vec![1u8, 2u8, 3u8])); - let json_value = try_into_vector(df_column) + let json_value = Helper::try_into_vector(df_column) .unwrap() .serialize_to_json() .unwrap(); diff --git a/src/datatypes/src/vectors/binary.rs b/src/datatypes/src/vectors/binary.rs index 4cb0b09357..9d7695eee2 100644 --- a/src/datatypes/src/vectors/binary.rs +++ b/src/datatypes/src/vectors/binary.rs @@ -1,8 +1,8 @@ use std::any::Any; use std::sync::Arc; -use arrow::array::BinaryValueIter; use arrow::array::{Array, ArrayRef, BinaryArray}; +use arrow::array::{BinaryValueIter, MutableArray}; use arrow::bitmap::utils::ZipValidity; use snafu::OptionExt; use snafu::ResultExt; @@ -11,11 +11,11 @@ use crate::arrow_array::{LargeBinaryArray, MutableLargeBinaryArray}; use crate::data_type::ConcreteDataType; use crate::error::Result; use crate::error::SerializeSnafu; -use crate::scalars::{ScalarVector, ScalarVectorBuilder}; +use crate::scalars::{common, ScalarVector, ScalarVectorBuilder}; use crate::serialize::Serializable; -use crate::types::BinaryType; +use crate::value::Value; use crate::vectors::impl_try_from_arrow_array_for_vector; -use crate::vectors::{Validity, Vector}; +use crate::vectors::{MutableVector, Validity, Vector, VectorRef}; /// Vector of binary strings. 
#[derive(Debug)] @@ -29,9 +29,21 @@ impl From> for BinaryVector { } } +impl From>>> for BinaryVector { + fn from(data: Vec>>) -> Self { + Self { + array: LargeBinaryArray::from(data), + } + } +} + impl Vector for BinaryVector { fn data_type(&self) -> ConcreteDataType { - ConcreteDataType::Binary(BinaryType::default()) + ConcreteDataType::binary_datatype() + } + + fn vector_type_name(&self) -> String { + "BinaryVector".to_string() } fn as_any(&self) -> &dyn Any { @@ -52,9 +64,26 @@ impl Vector for BinaryVector { None => Validity::AllValid, } } + + fn is_null(&self, row: usize) -> bool { + self.array.is_null(row) + } + + fn slice(&self, offset: usize, length: usize) -> VectorRef { + Arc::new(Self::from(self.array.slice(offset, length))) + } + + fn get_unchecked(&self, index: usize) -> Value { + self.array.value(index).into() + } + + fn replicate(&self, offsets: &[usize]) -> VectorRef { + common::replicate_scalar_vector(self, offsets) + } } impl ScalarVector for BinaryVector { + type OwnedItem = Vec; type RefItem<'a> = &'a [u8]; type Iter<'a> = ZipValidity<'a, &'a [u8], BinaryValueIter<'a, i64>>; type Builder = BinaryVectorBuilder; @@ -76,6 +105,28 @@ pub struct BinaryVectorBuilder { mutable_array: MutableLargeBinaryArray, } +impl MutableVector for BinaryVectorBuilder { + fn data_type(&self) -> ConcreteDataType { + ConcreteDataType::binary_datatype() + } + + fn len(&self) -> usize { + self.mutable_array.len() + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn as_mut_any(&mut self) -> &mut dyn Any { + self + } + + fn to_vector(&mut self) -> VectorRef { + Arc::new(self.finish()) + } +} + impl ScalarVectorBuilder for BinaryVectorBuilder { type VectorType = BinaryVector; @@ -89,9 +140,9 @@ impl ScalarVectorBuilder for BinaryVectorBuilder { self.mutable_array.push(value); } - fn finish(self) -> Self::VectorType { + fn finish(&mut self) -> Self::VectorType { BinaryVector { - array: self.mutable_array.into(), + array: std::mem::take(&mut self.mutable_array).into(), } } } @@ -112,12 +163,37 @@ impl_try_from_arrow_array_for_vector!(LargeBinaryArray, BinaryVector); #[cfg(test)] mod tests { + use arrow::datatypes::DataType as ArrowDataType; + use common_base::bytes::Bytes; use serde_json; use super::*; use crate::arrow_array::LargeBinaryArray; use crate::serialize::Serializable; + #[test] + fn test_binary_vector_misc() { + let v = BinaryVector::from(LargeBinaryArray::from_slice(&vec![ + vec![1, 2, 3], + vec![1, 2, 3], + ])); + + assert_eq!(2, v.len()); + assert_eq!("BinaryVector", v.vector_type_name()); + assert!(!v.is_const()); + assert_eq!(Validity::AllValid, v.validity()); + assert!(!v.only_null()); + + for i in 0..2 { + assert!(!v.is_null(i)); + assert_eq!(Value::Binary(Bytes(vec![1, 2, 3])), v.get_unchecked(i)); + } + + let arrow_arr = v.to_arrow_array(); + assert_eq!(2, arrow_arr.len()); + assert_eq!(&ArrowDataType::LargeBinary, arrow_arr.data_type()); + } + #[test] fn test_serialize_binary_vector_to_json() { let vector = BinaryVector::from(LargeBinaryArray::from_slice(&vec![ diff --git a/src/datatypes/src/vectors/boolean.rs b/src/datatypes/src/vectors/boolean.rs index 0f383087a8..a98240e579 100644 --- a/src/datatypes/src/vectors/boolean.rs +++ b/src/datatypes/src/vectors/boolean.rs @@ -2,18 +2,19 @@ use std::any::Any; use std::borrow::Borrow; use std::sync::Arc; -use arrow::array::{Array, ArrayRef, BooleanArray, MutableBooleanArray}; +use arrow::array::{Array, ArrayRef, BooleanArray, MutableArray, MutableBooleanArray}; use arrow::bitmap::utils::{BitmapIter, ZipValidity}; use 
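`finish` switching to `&mut self` (the inner mutable array is taken with `std::mem::take`) is what allows the builders to also act as `MutableVector`. A sketch of both ways to finish a build, assuming `BinaryVectorBuilder` is re-exported from `datatypes::vectors`:

```rust
use datatypes::prelude::*; // assumed to re-export MutableVector, ScalarVectorBuilder, Vector
use datatypes::vectors::{BinaryVector, BinaryVectorBuilder};

fn demo() {
    // Typed path through ScalarVectorBuilder::finish (now &mut self).
    let mut builder = BinaryVectorBuilder::with_capacity(2);
    builder.push(Some("hello".as_bytes()));
    builder.push(None);
    let typed: BinaryVector = builder.finish();
    assert_eq!(2, typed.len());
    // The inner mutable array was taken, so the builder is left empty but reusable.
    assert_eq!(0, builder.len());

    // Type-erased path through MutableVector, handy for generic callers.
    let mut builder = BinaryVectorBuilder::with_capacity(1);
    builder.push(Some(&[1u8, 2, 3][..]));
    let erased: VectorRef = builder.to_vector();
    assert_eq!("BinaryVector", erased.vector_type_name());
}
```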
snafu::OptionExt; use snafu::ResultExt; use crate::data_type::ConcreteDataType; use crate::error::Result; +use crate::scalars::common::replicate_scalar_vector; use crate::scalars::{ScalarVector, ScalarVectorBuilder}; use crate::serialize::Serializable; -use crate::types::BooleanType; +use crate::value::Value; use crate::vectors::impl_try_from_arrow_array_for_vector; -use crate::vectors::{Validity, Vector}; +use crate::vectors::{MutableVector, Validity, Vector, VectorRef}; /// Vector of boolean. #[derive(Debug)] @@ -53,7 +54,11 @@ impl>> FromIterator for BooleanVector { impl Vector for BooleanVector { fn data_type(&self) -> ConcreteDataType { - ConcreteDataType::Boolean(BooleanType::default()) + ConcreteDataType::boolean_datatype() + } + + fn vector_type_name(&self) -> String { + "BooleanVector".to_string() } fn as_any(&self) -> &dyn Any { @@ -74,9 +79,26 @@ impl Vector for BooleanVector { None => Validity::AllValid, } } + + fn is_null(&self, row: usize) -> bool { + self.array.is_null(row) + } + + fn slice(&self, offset: usize, length: usize) -> VectorRef { + Arc::new(Self::from(self.array.slice(offset, length))) + } + + fn get_unchecked(&self, index: usize) -> Value { + self.array.value(index).into() + } + + fn replicate(&self, offsets: &[usize]) -> VectorRef { + replicate_scalar_vector(self, offsets) + } } impl ScalarVector for BooleanVector { + type OwnedItem = bool; type RefItem<'a> = bool; type Iter<'a> = ZipValidity<'a, bool, BitmapIter<'a>>; type Builder = BooleanVectorBuilder; @@ -98,6 +120,28 @@ pub struct BooleanVectorBuilder { mutable_array: MutableBooleanArray, } +impl MutableVector for BooleanVectorBuilder { + fn data_type(&self) -> ConcreteDataType { + ConcreteDataType::boolean_datatype() + } + + fn len(&self) -> usize { + self.mutable_array.len() + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn as_mut_any(&mut self) -> &mut dyn Any { + self + } + + fn to_vector(&mut self) -> VectorRef { + Arc::new(self.finish()) + } +} + impl ScalarVectorBuilder for BooleanVectorBuilder { type VectorType = BooleanVector; @@ -111,9 +155,9 @@ impl ScalarVectorBuilder for BooleanVectorBuilder { self.mutable_array.push(value); } - fn finish(self) -> Self::VectorType { + fn finish(&mut self) -> Self::VectorType { BooleanVector { - array: self.mutable_array.into(), + array: std::mem::take(&mut self.mutable_array).into(), } } } @@ -131,11 +175,32 @@ impl_try_from_arrow_array_for_vector!(BooleanArray, BooleanVector); #[cfg(test)] mod tests { + use arrow::datatypes::DataType as ArrowDataType; use serde_json; use super::*; use crate::serialize::Serializable; + #[test] + fn test_boolean_vector_misc() { + let bools = vec![true, false, true, true, false, false]; + let v = BooleanVector::from(bools.clone()); + assert_eq!(6, v.len()); + assert_eq!("BooleanVector", v.vector_type_name()); + assert!(!v.is_const()); + assert_eq!(Validity::AllValid, v.validity()); + assert!(!v.only_null()); + + for (i, b) in bools.iter().enumerate() { + assert!(!v.is_null(i)); + assert_eq!(Value::Boolean(*b), v.get_unchecked(i)); + } + + let arrow_arr = v.to_arrow_array(); + assert_eq!(6, arrow_arr.len()); + assert_eq!(&ArrowDataType::Boolean, arrow_arr.data_type()); + } + #[test] fn test_serialize_boolean_vector_to_json() { let vector = BooleanVector::from(vec![true, false, true, true, false, false]); diff --git a/src/datatypes/src/vectors/constant.rs b/src/datatypes/src/vectors/constant.rs new file mode 100644 index 0000000000..b696fc1824 --- /dev/null +++ b/src/datatypes/src/vectors/constant.rs @@ -0,0 +1,164 @@ +use 
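`BooleanVector::replicate` delegates to `replicate_scalar_vector`; assuming that helper follows the offsets convention documented on the `Vector` trait (offsets are cumulative end positions, not per-element counts), the behavior looks like this sketch:

```rust
use datatypes::prelude::*; // assumed to re-export Vector, Value
use datatypes::vectors::BooleanVector;

fn demo() {
    let v = BooleanVector::from(vec![true, false]);

    // offsets[i] is the cumulative length after copying element i:
    // element 0 is copied 3 - 0 = 3 times, element 1 is copied 5 - 3 = 2 times.
    let replicated = v.replicate(&[3, 5]);

    assert_eq!(5, replicated.len());
    assert_eq!(Value::Boolean(true), replicated.get_unchecked(0));
    assert_eq!(Value::Boolean(false), replicated.get_unchecked(4));
}
```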
std::any::Any; +use std::fmt; +use std::sync::Arc; + +use arrow::array::ArrayRef; +use snafu::ResultExt; + +use crate::data_type::ConcreteDataType; +use crate::error::{Result, SerializeSnafu}; +use crate::serialize::Serializable; +use crate::value::Value; +use crate::vectors::Helper; +use crate::vectors::{Validity, Vector, VectorRef}; + +#[derive(Clone)] +pub struct ConstantVector { + length: usize, + vector: VectorRef, +} + +impl ConstantVector { + pub fn new(vector: VectorRef, length: usize) -> Self { + // Avoid const recursion. + if vector.is_const() { + let vec: &ConstantVector = unsafe { Helper::static_cast(&vector) }; + return Self::new(vec.inner().clone(), length); + } + Self { vector, length } + } + pub fn inner(&self) -> &VectorRef { + &self.vector + } +} + +impl Vector for ConstantVector { + fn data_type(&self) -> ConcreteDataType { + self.vector.data_type() + } + + fn vector_type_name(&self) -> String { + "ConstantVector".to_string() + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn len(&self) -> usize { + self.length + } + + fn to_arrow_array(&self) -> ArrayRef { + let v = self.vector.replicate(&[self.length]); + v.to_arrow_array() + } + + fn is_const(&self) -> bool { + true + } + + fn validity(&self) -> Validity { + if self.vector.is_null(0) { + Validity::AllNull + } else { + Validity::AllValid + } + } + + fn is_null(&self, _row: usize) -> bool { + self.vector.is_null(0) + } + + fn only_null(&self) -> bool { + self.vector.is_null(0) + } + + fn slice(&self, _offset: usize, length: usize) -> VectorRef { + Arc::new(Self { + vector: self.vector.clone(), + length, + }) + } + + fn get_unchecked(&self, _index: usize) -> Value { + self.vector.get_unchecked(0) + } + + fn replicate(&self, offsets: &[usize]) -> VectorRef { + debug_assert!( + offsets.len() == self.len(), + "Size of offsets must match size of column" + ); + + Arc::new(Self::new(self.vector.clone(), *offsets.last().unwrap())) + } +} + +impl fmt::Debug for ConstantVector { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "ConstantVector([{:?}; {}])", + self.get(0).unwrap_or(Value::Null), + self.len() + ) + } +} + +impl Serializable for ConstantVector { + fn serialize_to_json(&self) -> Result> { + std::iter::repeat(self.get(0)?) 
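`ConstantVector` keeps one stored value plus a logical length and only materializes rows when `to_arrow_array` calls `replicate` on the inner vector. A usage sketch under the same path assumptions as above:

```rust
use std::sync::Arc;

use datatypes::prelude::*; // assumed to re-export Vector, Value
use datatypes::vectors::{ConstantVector, Int32Vector};

fn demo() {
    // One stored value, presented as a column of 1024 rows.
    let c = ConstantVector::new(Arc::new(Int32Vector::from_slice(vec![42])), 1024);

    assert_eq!(1024, c.len());
    assert_eq!(Value::Int32(42), c.get_unchecked(500));

    // Slicing only adjusts the logical length; the inner vector is shared.
    assert_eq!(8, c.slice(0, 8).len());

    // Rows are materialized only here, via replicate(&[1024]) on the inner vector.
    let _expanded = c.to_arrow_array();
}
```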
+ .take(self.len()) + .map(serde_json::to_value) + .collect::>() + .context(SerializeSnafu) + } +} + +#[cfg(test)] +mod tests { + use arrow::datatypes::DataType as ArrowDataType; + + use super::*; + use crate::vectors::Int32Vector; + + #[test] + fn test_constant_vector_misc() { + let a = Int32Vector::from_slice(vec![1]); + let c = ConstantVector::new(Arc::new(a), 10); + + assert_eq!("ConstantVector", c.vector_type_name()); + assert!(c.is_const()); + assert_eq!(10, c.len()); + assert_eq!(Validity::AllValid, c.validity()); + assert!(!c.only_null()); + + for i in 0..10 { + assert!(!c.is_null(i)); + assert_eq!(Value::Int32(1), c.get_unchecked(i)); + } + + let arrow_arr = c.to_arrow_array(); + assert_eq!(10, arrow_arr.len()); + assert_eq!(&ArrowDataType::Int32, arrow_arr.data_type()); + } + + #[test] + fn test_debug_null_array() { + let a = Int32Vector::from_slice(vec![1]); + let c = ConstantVector::new(Arc::new(a), 10); + + let s = format!("{:?}", c); + assert_eq!(s, "ConstantVector([Int32(1); 10])"); + } + + #[test] + fn test_serialize_json() { + let a = Int32Vector::from_slice(vec![1]); + let c = ConstantVector::new(Arc::new(a), 10); + + let s = serde_json::to_string(&c.serialize_to_json().unwrap()).unwrap(); + assert_eq!(s, "[1,1,1,1,1,1,1,1,1,1]"); + } +} diff --git a/src/datatypes/src/vectors/helper.rs b/src/datatypes/src/vectors/helper.rs new file mode 100644 index 0000000000..67ba77747a --- /dev/null +++ b/src/datatypes/src/vectors/helper.rs @@ -0,0 +1,177 @@ +//! Vector helper functions, inspired by databend Series mod + +use std::any::Any; +use std::sync::Arc; + +use arrow::datatypes::DataType as ArrowDataType; +use datafusion_common::ScalarValue; +use snafu::OptionExt; + +use crate::error::{ConversionSnafu, Result, UnknownVectorSnafu}; +use crate::scalars::*; +use crate::vectors::*; + +pub struct Helper; + +impl Helper { + /// Get a pointer to the underlying data of this vectors. + /// Can be useful for fast comparisons. + /// # Safety + /// Assumes that the `vector` is T. 
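`ConstantVector::new` is also the in-tree user of the unsafe `Helper::static_cast` being introduced here: wrapping a constant in another constant re-targets the inner vector instead of nesting. A sketch of the observable effect, under the same assumptions as the previous examples:

```rust
use std::sync::Arc;

use datatypes::prelude::*; // assumed to re-export Vector, Value
use datatypes::vectors::{ConstantVector, Int32Vector};

fn demo() {
    let inner = ConstantVector::new(Arc::new(Int32Vector::from_slice(vec![7])), 3);

    // The constructor unwraps the inner constant, so the new constant points
    // straight at the Int32Vector; replicate never recurses through constants.
    let outer = ConstantVector::new(Arc::new(inner), 10);

    assert_eq!(10, outer.len());
    assert!(!outer.inner().is_const());
    assert_eq!(Value::Int32(7), outer.get_unchecked(9));
}
```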
+ pub unsafe fn static_cast(vector: &VectorRef) -> &T { + let object = vector.as_ref(); + debug_assert!(object.as_any().is::()); + &*(object as *const dyn Vector as *const T) + } + + pub fn check_get_scalar(vector: &VectorRef) -> Result<&::VectorType> { + let arr = vector + .as_any() + .downcast_ref::<::VectorType>() + .with_context(|| UnknownVectorSnafu { + msg: format!( + "downcast vector error, vector type: {:?}, expected vector: {:?}", + vector.vector_type_name(), + std::any::type_name::(), + ), + }); + arr + } + + pub fn check_get(vector: &VectorRef) -> Result<&T> { + let arr = vector + .as_any() + .downcast_ref::() + .with_context(|| UnknownVectorSnafu { + msg: format!( + "downcast vector error, vector type: {:?}, expected vector: {:?}", + vector.vector_type_name(), + std::any::type_name::(), + ), + }); + arr + } + + pub fn check_get_mutable_vector( + vector: &mut dyn MutableVector, + ) -> Result<&mut T> { + let ty = vector.data_type(); + let arr = vector + .as_mut_any() + .downcast_mut() + .with_context(|| UnknownVectorSnafu { + msg: format!( + "downcast vector error, vector type: {:?}, expected vector: {:?}", + ty, + std::any::type_name::(), + ), + }); + arr + } + + pub fn check_get_scalar_vector( + vector: &VectorRef, + ) -> Result<&::VectorType> { + let arr = vector + .as_any() + .downcast_ref::<::VectorType>() + .with_context(|| UnknownVectorSnafu { + msg: format!( + "downcast vector error, vector type: {:?}, expected vector: {:?}", + vector.vector_type_name(), + std::any::type_name::(), + ), + }); + arr + } + + /// Try to cast an arrow scalar value into vector + /// + /// # Panics + /// Panic if given scalar value is not supported. + pub fn try_from_scalar_value(value: ScalarValue, length: usize) -> Result { + let vector = match value { + ScalarValue::Boolean(v) => { + ConstantVector::new(Arc::new(BooleanVector::from(vec![v])), length) + } + ScalarValue::Float32(v) => { + ConstantVector::new(Arc::new(Float32Vector::from(vec![v])), length) + } + ScalarValue::Float64(v) => { + ConstantVector::new(Arc::new(Float64Vector::from(vec![v])), length) + } + ScalarValue::Int8(v) => { + ConstantVector::new(Arc::new(Int8Vector::from(vec![v])), length) + } + ScalarValue::Int16(v) => { + ConstantVector::new(Arc::new(Int16Vector::from(vec![v])), length) + } + ScalarValue::Int32(v) => { + ConstantVector::new(Arc::new(Int32Vector::from(vec![v])), length) + } + ScalarValue::Int64(v) => { + ConstantVector::new(Arc::new(Int64Vector::from(vec![v])), length) + } + ScalarValue::UInt8(v) => { + ConstantVector::new(Arc::new(UInt8Vector::from(vec![v])), length) + } + ScalarValue::UInt16(v) => { + ConstantVector::new(Arc::new(UInt16Vector::from(vec![v])), length) + } + ScalarValue::UInt32(v) => { + ConstantVector::new(Arc::new(UInt32Vector::from(vec![v])), length) + } + ScalarValue::UInt64(v) => { + ConstantVector::new(Arc::new(UInt64Vector::from(vec![v])), length) + } + ScalarValue::Utf8(v) => { + ConstantVector::new(Arc::new(StringVector::from(vec![v])), length) + } + ScalarValue::LargeUtf8(v) => { + ConstantVector::new(Arc::new(StringVector::from(vec![v])), length) + } + ScalarValue::Binary(v) => { + ConstantVector::new(Arc::new(BinaryVector::from(vec![v])), length) + } + ScalarValue::LargeBinary(v) => { + ConstantVector::new(Arc::new(BinaryVector::from(vec![v])), length) + } + _ => { + return ConversionSnafu { + from: format!("Unsupported scalar value: {}", value), + } + .fail() + } + }; + + Ok(Arc::new(vector)) + } + + /// Try to cast an arrow array into vector + /// + /// # Panics + /// Panic if 
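`try_from_scalar_value` turns a DataFusion literal into a constant vector of the requested length, and anything outside the listed variants becomes a `Conversion` error. A hedged sketch; the `Value` variant names are assumed to match the primitive variants used elsewhere in this diff:

```rust
use datafusion_common::ScalarValue;
use datatypes::prelude::*; // assumed to re-export Vector, Value
use datatypes::vectors::Helper;

fn demo() -> datatypes::error::Result<()> {
    // A literal from a DataFusion plan, expanded to cover a batch of 8 rows.
    let v = Helper::try_from_scalar_value(ScalarValue::Int64(Some(42)), 8)?;

    assert_eq!(8, v.len());
    assert!(v.is_const()); // a ConstantVector under the hood
    assert_eq!(Value::Int64(42), v.get_unchecked(3));

    // Variants outside the match (dates, timestamps, ...) return a Conversion error.
    assert!(Helper::try_from_scalar_value(ScalarValue::Date32(Some(1)), 8).is_err());
    Ok(())
}
```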
given arrow data type is not supported. + pub fn try_into_vector(array: ArrayRef) -> Result { + Ok(match array.data_type() { + ArrowDataType::Null => Arc::new(NullVector::try_from_arrow_array(array)?), + ArrowDataType::Boolean => Arc::new(BooleanVector::try_from_arrow_array(array)?), + ArrowDataType::Binary | ArrowDataType::LargeBinary => { + Arc::new(BinaryVector::try_from_arrow_array(array)?) + } + ArrowDataType::Int8 => Arc::new(Int8Vector::try_from_arrow_array(array)?), + ArrowDataType::Int16 => Arc::new(Int16Vector::try_from_arrow_array(array)?), + ArrowDataType::Int32 => Arc::new(Int32Vector::try_from_arrow_array(array)?), + ArrowDataType::Int64 => Arc::new(Int64Vector::try_from_arrow_array(array)?), + ArrowDataType::UInt8 => Arc::new(UInt8Vector::try_from_arrow_array(array)?), + ArrowDataType::UInt16 => Arc::new(UInt16Vector::try_from_arrow_array(array)?), + ArrowDataType::UInt32 => Arc::new(UInt32Vector::try_from_arrow_array(array)?), + ArrowDataType::UInt64 => Arc::new(UInt64Vector::try_from_arrow_array(array)?), + ArrowDataType::Float32 => Arc::new(Float32Vector::try_from_arrow_array(array)?), + ArrowDataType::Float64 => Arc::new(Float64Vector::try_from_arrow_array(array)?), + ArrowDataType::Utf8 | ArrowDataType::LargeUtf8 => { + Arc::new(StringVector::try_from_arrow_array(array)?) + } + _ => unimplemented!("Arrow array datatype: {:?}", array.data_type()), + }) + } +} diff --git a/src/datatypes/src/vectors/mutable.rs b/src/datatypes/src/vectors/mutable.rs new file mode 100644 index 0000000000..ff9ed18a38 --- /dev/null +++ b/src/datatypes/src/vectors/mutable.rs @@ -0,0 +1,19 @@ +use std::any::Any; + +use crate::prelude::*; + +pub trait MutableVector: Send + Sync { + fn data_type(&self) -> ConcreteDataType; + + fn len(&self) -> usize; + + fn is_empty(&self) -> bool { + self.len() == 0 + } + + fn as_any(&self) -> &dyn Any; + + fn as_mut_any(&mut self) -> &mut dyn Any; + + fn to_vector(&mut self) -> VectorRef; +} diff --git a/src/datatypes/src/vectors/null.rs b/src/datatypes/src/vectors/null.rs index 16e7865b9f..9ab28890f6 100644 --- a/src/datatypes/src/vectors/null.rs +++ b/src/datatypes/src/vectors/null.rs @@ -11,8 +11,9 @@ use crate::data_type::ConcreteDataType; use crate::error::Result; use crate::serialize::Serializable; use crate::types::NullType; +use crate::value::Value; use crate::vectors::impl_try_from_arrow_array_for_vector; -use crate::vectors::{Validity, Vector}; +use crate::vectors::{Validity, Vector, VectorRef}; pub struct NullVector { array: NullArray, @@ -37,6 +38,10 @@ impl Vector for NullVector { ConcreteDataType::Null(NullType::default()) } + fn vector_type_name(&self) -> String { + "NullVector".to_string() + } + fn as_any(&self) -> &dyn Any { self } @@ -52,6 +57,33 @@ impl Vector for NullVector { fn validity(&self) -> Validity { Validity::AllNull } + + fn is_null(&self, _row: usize) -> bool { + true + } + + fn get_unchecked(&self, _index: usize) -> Value { + Value::Null + } + + fn only_null(&self) -> bool { + true + } + + fn slice(&self, _offset: usize, length: usize) -> VectorRef { + Arc::new(Self::new(length)) + } + + fn replicate(&self, offsets: &[usize]) -> VectorRef { + debug_assert!( + offsets.len() == self.len(), + "Size of offsets must match size of column" + ); + + Arc::new(Self { + array: NullArray::new(ArrowDataType::Null, *offsets.last().unwrap() as usize), + }) + } } impl fmt::Debug for NullVector { @@ -77,16 +109,26 @@ mod tests { use super::*; #[test] - fn test_null_vector() { - let vector = NullVector::new(32); + fn test_null_vector_misc() { + let 
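The `check_get*` helpers defined above are the safe counterpart to `static_cast`: a failed downcast becomes an `UnknownVector` error rather than undefined behavior. A sketch, assuming `check_get` is generic over a `'static` vector type:

```rust
use std::sync::Arc;

use datatypes::prelude::*; // assumed to re-export Vector, VectorRef
use datatypes::vectors::{BooleanVector, Helper, StringVector};

fn demo() {
    let v: VectorRef = Arc::new(BooleanVector::from(vec![true, false]));

    // Matching type: borrow the concrete vector back out of the trait object.
    let booleans = Helper::check_get::<BooleanVector>(&v).unwrap();
    assert_eq!(2, booleans.len());

    // Mismatched type: an UnknownVector error instead of a panic or UB.
    assert!(Helper::check_get::<StringVector>(&v).is_err());
}
```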
v = NullVector::new(32); - assert_eq!(vector.len(), 32); - let arrow_arr = vector.to_arrow_array(); + assert_eq!(v.len(), 32); + let arrow_arr = v.to_arrow_array(); assert_eq!(arrow_arr.null_count(), 32); let array2 = arrow_arr.slice(8, 16); assert_eq!(array2.len(), 16); assert_eq!(array2.null_count(), 16); + + assert_eq!("NullVector", v.vector_type_name()); + assert!(!v.is_const()); + assert_eq!(Validity::AllNull, v.validity()); + assert!(v.only_null()); + + for i in 0..32 { + assert!(v.is_null(i)); + assert_eq!(Value::Null, v.get_unchecked(i)); + } } #[test] diff --git a/src/datatypes/src/vectors/primitive.rs b/src/datatypes/src/vectors/primitive.rs index 24fc44634b..72cdad8fb1 100644 --- a/src/datatypes/src/vectors/primitive.rs +++ b/src/datatypes/src/vectors/primitive.rs @@ -3,7 +3,7 @@ use std::iter::FromIterator; use std::slice::Iter; use std::sync::Arc; -use arrow::array::{Array, ArrayRef, MutablePrimitiveArray, PrimitiveArray}; +use arrow::array::{Array, ArrayRef, MutableArray, MutablePrimitiveArray, PrimitiveArray}; use arrow::bitmap::utils::ZipValidity; use serde_json::Value as JsonValue; use snafu::{OptionExt, ResultExt}; @@ -11,10 +11,12 @@ use snafu::{OptionExt, ResultExt}; use crate::data_type::ConcreteDataType; use crate::error::ConversionSnafu; use crate::error::{Result, SerializeSnafu}; +use crate::scalars::{Scalar, ScalarRef}; use crate::scalars::{ScalarVector, ScalarVectorBuilder}; use crate::serialize::Serializable; use crate::types::{DataTypeBuilder, Primitive}; -use crate::vectors::{Validity, Vector}; +use crate::value::Value; +use crate::vectors::{MutableVector, Validity, Vector, VectorRef}; /// Vector for primitive data types. #[derive(Debug)] @@ -62,6 +64,10 @@ impl Vector for PrimitiveVector { T::build_data_type() } + fn vector_type_name(&self) -> String { + format!("{}Vector", T::type_name()) + } + fn as_any(&self) -> &dyn Any { self } @@ -80,6 +86,45 @@ impl Vector for PrimitiveVector { None => Validity::AllValid, } } + + fn is_null(&self, row: usize) -> bool { + self.array.is_null(row) + } + + fn slice(&self, offset: usize, length: usize) -> VectorRef { + Arc::new(Self::from(self.array.slice(offset, length))) + } + + fn get_unchecked(&self, index: usize) -> Value { + self.array.value(index).into() + } + + fn replicate(&self, offsets: &[usize]) -> VectorRef { + debug_assert!( + offsets.len() == self.len(), + "Size of offsets must match size of column" + ); + + if offsets.is_empty() { + return self.slice(0, 0); + } + + let mut builder = + PrimitiveVectorBuilder::::with_capacity(*offsets.last().unwrap() as usize); + + let mut previous_offset = 0; + + for (i, offset) in offsets.iter().enumerate() { + let data = unsafe { self.array.value_unchecked(i) }; + builder.mutable_array.extend( + std::iter::repeat(data) + .take(*offset - previous_offset) + .map(Option::Some), + ); + previous_offset = *offset; + } + builder.to_vector() + } } impl From> for PrimitiveVector { @@ -88,6 +133,14 @@ impl From> for PrimitiveVector { } } +impl From>> for PrimitiveVector { + fn from(v: Vec>) -> Self { + Self { + array: PrimitiveArray::::from(v), + } + } +} + impl>> FromIterator for PrimitiveVector { fn from_iter>(iter: I) -> Self { Self { @@ -96,7 +149,13 @@ impl>> FromIterator for Pr } } -impl ScalarVector for PrimitiveVector { +impl ScalarVector for PrimitiveVector +where + T: Scalar + Primitive + DataTypeBuilder, + for<'a> T: ScalarRef<'a, ScalarType = T, VectorType = Self>, + for<'a> T: Scalar = T>, +{ + type OwnedItem = T; type RefItem<'a> = T; type Iter<'a> = PrimitiveIter<'a, 
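The new `From<Vec<Option<T>>>` impl plus the trait methods make nullable primitive columns straightforward to build and inspect. A sketch, with the same assumed `datatypes` paths as above:

```rust
use datatypes::prelude::*; // assumed to re-export Vector, Value
use datatypes::vectors::Int32Vector;

fn demo() {
    let v = Int32Vector::from(vec![Some(1), None, Some(3)]);

    assert_eq!(3, v.len());
    assert_eq!(1, v.null_count());
    assert!(v.is_null(1));
    assert!(!v.only_null());
    assert_eq!(Value::Int32(3), v.get_unchecked(2));
}
```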
T>; type Builder = PrimitiveVectorBuilder; @@ -141,11 +200,48 @@ impl<'a, T: Copy> Iterator for PrimitiveIter<'a, T> { } } -pub struct PrimitiveVectorBuilder { +pub struct PrimitiveVectorBuilder { mutable_array: MutablePrimitiveArray, } -impl ScalarVectorBuilder for PrimitiveVectorBuilder { +impl PrimitiveVectorBuilder { + fn with_capacity(capacity: usize) -> Self { + Self { + mutable_array: MutablePrimitiveArray::with_capacity(capacity), + } + } +} + +impl MutableVector for PrimitiveVectorBuilder { + fn data_type(&self) -> ConcreteDataType { + T::build_data_type() + } + + fn len(&self) -> usize { + self.mutable_array.len() + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn as_mut_any(&mut self) -> &mut dyn Any { + self + } + + fn to_vector(&mut self) -> VectorRef { + Arc::new(PrimitiveVector:: { + array: std::mem::take(&mut self.mutable_array).into(), + }) + } +} + +impl ScalarVectorBuilder for PrimitiveVectorBuilder +where + T: Scalar> + Primitive + DataTypeBuilder, + for<'a> T: ScalarRef<'a, ScalarType = T, VectorType = PrimitiveVector>, + for<'a> T: Scalar = T>, +{ type VectorType = PrimitiveVector; fn with_capacity(capacity: usize) -> Self { @@ -158,16 +254,17 @@ impl ScalarVectorBuilder for PrimitiveVectorBuil self.mutable_array.push(value); } - fn finish(self) -> Self::VectorType { + fn finish(&mut self) -> Self::VectorType { PrimitiveVector { - array: self.mutable_array.into(), + array: std::mem::take(&mut self.mutable_array).into(), } } } impl Serializable for PrimitiveVector { fn serialize_to_json(&self) -> Result> { - self.iter_data() + self.array + .iter() .map(serde_json::to_value) .collect::>() .context(SerializeSnafu) @@ -176,14 +273,30 @@ impl Serializable for PrimitiveVector { #[cfg(test)] mod tests { + use arrow::datatypes::DataType as ArrowDataType; use serde_json; use super::*; use crate::serialize::Serializable; fn check_vec(v: PrimitiveVector) { + assert_eq!(4, v.len()); + assert_eq!("Int32Vector", v.vector_type_name()); + assert!(!v.is_const()); + assert_eq!(Validity::AllValid, v.validity()); + assert!(!v.only_null()); + + for i in 0..4 { + assert!(!v.is_null(i)); + assert_eq!(Value::Int32(i as i32 + 1), v.get_unchecked(i)); + } + let json_value = v.serialize_to_json().unwrap(); assert_eq!("[1,2,3,4]", serde_json::to_string(&json_value).unwrap(),); + + let arrow_arr = v.to_arrow_array(); + assert_eq!(4, arrow_arr.len()); + assert_eq!(&ArrowDataType::Int32, arrow_arr.data_type()); } #[test] @@ -264,4 +377,18 @@ mod tests { assert_eq!(0, vector.null_count()); assert_eq!(Validity::AllValid, vector.validity()); } + + #[test] + fn test_replicate() { + let v = PrimitiveVector::::from_slice((0..5).collect::>()); + + let offsets = [0usize, 1usize, 2usize, 3usize, 4usize]; + + let v = v.replicate(&offsets); + assert_eq!(4, v.len()); + + for i in 0..4 { + assert_eq!(Value::Int32(i as i32 + 1), v.get_unchecked(i)); + } + } } diff --git a/src/datatypes/src/vectors/string.rs b/src/datatypes/src/vectors/string.rs index d445693e52..be387d7481 100644 --- a/src/datatypes/src/vectors/string.rs +++ b/src/datatypes/src/vectors/string.rs @@ -1,19 +1,20 @@ use std::any::Any; use std::sync::Arc; -use arrow::array::{Array, ArrayRef, Utf8ValuesIter}; +use arrow::array::{Array, ArrayRef, MutableArray, Utf8ValuesIter}; use arrow::bitmap::utils::ZipValidity; -use serde_json::Value; +use serde_json::Value as JsonValue; use snafu::OptionExt; use snafu::ResultExt; use crate::arrow_array::{MutableStringArray, StringArray}; use crate::data_type::ConcreteDataType; use 
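Serialization now walks the arrow array's iterator, so validity is preserved in the JSON output: null slots come out as JSON `null`. A sketch, assuming `Serializable` is importable at the path the crate uses internally and that `serde_json` is available as in the crate's tests:

```rust
use datatypes::serialize::Serializable; // internal path; may also be re-exported via the prelude
use datatypes::vectors::Int32Vector;

fn demo() {
    let v = Int32Vector::from(vec![Some(1), None, Some(3)]);

    let json = v.serialize_to_json().unwrap();
    assert_eq!("[1,null,3]", serde_json::to_string(&json).unwrap());
}
```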
crate::error::SerializeSnafu; -use crate::prelude::{ScalarVectorBuilder, Validity, Vector}; -use crate::scalars::ScalarVector; +use crate::prelude::{MutableVector, ScalarVectorBuilder, Validity, Vector, VectorRef}; +use crate::scalars::{common, ScalarVector}; use crate::serialize::Serializable; use crate::types::StringType; +use crate::value::Value; use crate::vectors::impl_try_from_arrow_array_for_vector; /// String array wrapper @@ -28,11 +29,55 @@ impl From for StringVector { } } +impl From>> for StringVector { + fn from(data: Vec>) -> Self { + Self { + array: StringArray::from(data), + } + } +} + +impl From> for StringVector { + fn from(data: Vec) -> Self { + Self { + array: StringArray::from( + data.into_iter() + .map(Option::Some) + .collect::>>(), + ), + } + } +} + +impl From>> for StringVector { + fn from(data: Vec>) -> Self { + Self { + array: StringArray::from(data), + } + } +} + +impl From> for StringVector { + fn from(data: Vec<&str>) -> Self { + Self { + array: StringArray::from( + data.into_iter() + .map(Option::Some) + .collect::>>(), + ), + } + } +} + impl Vector for StringVector { fn data_type(&self) -> ConcreteDataType { ConcreteDataType::String(StringType::default()) } + fn vector_type_name(&self) -> String { + "StringVector".to_string() + } + fn as_any(&self) -> &dyn Any { self } @@ -51,9 +96,26 @@ impl Vector for StringVector { None => Validity::AllValid, } } + + fn is_null(&self, row: usize) -> bool { + self.array.is_null(row) + } + + fn slice(&self, offset: usize, length: usize) -> VectorRef { + Arc::new(Self::from(self.array.slice(offset, length))) + } + + fn get_unchecked(&self, index: usize) -> Value { + self.array.value(index).into() + } + + fn replicate(&self, offsets: &[usize]) -> VectorRef { + common::replicate_scalar_vector(self, offsets) + } } impl ScalarVector for StringVector { + type OwnedItem = String; type RefItem<'a> = &'a str; type Iter<'a> = ZipValidity<'a, &'a str, Utf8ValuesIter<'a, i32>>; type Builder = StringVectorBuilder; @@ -75,6 +137,28 @@ pub struct StringVectorBuilder { buffer: MutableStringArray, } +impl MutableVector for StringVectorBuilder { + fn data_type(&self) -> ConcreteDataType { + ConcreteDataType::string_datatype() + } + + fn len(&self) -> usize { + self.buffer.len() + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn as_mut_any(&mut self) -> &mut dyn Any { + self + } + + fn to_vector(&mut self) -> VectorRef { + Arc::new(self.finish()) + } +} + impl ScalarVectorBuilder for StringVectorBuilder { type VectorType = StringVector; @@ -88,15 +172,15 @@ impl ScalarVectorBuilder for StringVectorBuilder { self.buffer.push(value) } - fn finish(self) -> Self::VectorType { + fn finish(&mut self) -> Self::VectorType { Self::VectorType { - array: self.buffer.into(), + array: std::mem::take(&mut self.buffer).into(), } } } impl Serializable for StringVector { - fn serialize_to_json(&self) -> crate::error::Result> { + fn serialize_to_json(&self) -> crate::error::Result> { self.iter_data() .map(|v| match v { None => Ok(serde_json::Value::Null), @@ -111,10 +195,31 @@ impl_try_from_arrow_array_for_vector!(StringArray, StringVector); #[cfg(test)] mod tests { + use arrow::datatypes::DataType as ArrowDataType; use serde_json; use super::*; + #[test] + fn test_string_vector_misc() { + let strs = vec!["hello", "greptime", "rust"]; + let v = StringVector::from(strs.clone()); + assert_eq!(3, v.len()); + assert_eq!("StringVector", v.vector_type_name()); + assert!(!v.is_const()); + assert_eq!(Validity::AllValid, v.validity()); + assert!(!v.only_null()); + 
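The four new `From` impls make `StringVector` construction symmetric with the other vector types, and values read back through `Value::from(&str)` semantics. A sketch:

```rust
use datatypes::prelude::*; // assumed to re-export Vector, Value
use datatypes::vectors::StringVector;

fn demo() {
    // Owned strings, borrowed strings, and the Option variants all convert now.
    let owned = StringVector::from(vec!["hello".to_string(), "greptime".to_string()]);
    let borrowed = StringVector::from(vec!["hello", "greptime"]);
    let nullable = StringVector::from(vec![Some("hello"), None]);

    assert_eq!(owned.get_unchecked(1), borrowed.get_unchecked(1));
    assert_eq!(Value::from("hello"), borrowed.get_unchecked(0));
    assert!(nullable.is_null(1));
}
```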
+ for (i, s) in strs.iter().enumerate() { + assert_eq!(Value::from(*s), v.get_unchecked(i)); + assert_eq!(Value::from(*s), v.get(i).unwrap()); + } + + let arrow_arr = v.to_arrow_array(); + assert_eq!(3, arrow_arr.len()); + assert_eq!(&ArrowDataType::Utf8, arrow_arr.data_type()); + } + #[test] fn test_serialize_string_vector() { let mut builder = StringVectorBuilder::with_capacity(3); diff --git a/src/query/Cargo.toml b/src/query/Cargo.toml index 95718212c8..c9d2d5eb6a 100644 --- a/src/query/Cargo.toml +++ b/src/query/Cargo.toml @@ -11,6 +11,7 @@ features = ["io_csv", "io_json", "io_parquet", "io_parquet_compression", "io_ipc [dependencies] async-trait = "0.1" common-error = { path = "../common/error" } +common-function = { path = "../common/function" } common-query = { path = "../common/query" } common-recordbatch = {path = "../common/recordbatch" } common-telemetry = { path = "../common/telemetry" } diff --git a/src/query/src/datafusion.rs b/src/query/src/datafusion.rs index 4c558eec5e..5ef481733d 100644 --- a/src/query/src/datafusion.rs +++ b/src/query/src/datafusion.rs @@ -7,6 +7,8 @@ mod planner; use std::sync::Arc; +use common_function::scalars::udf::create_udf; +use common_function::scalars::FunctionRef; use common_query::prelude::ScalarUdf; use common_recordbatch::{EmptyRecordBatchStream, SendableRecordBatchStream}; use common_telemetry::timer; @@ -73,6 +75,10 @@ impl QueryEngine for DatafusionQueryEngine { fn register_udf(&self, udf: ScalarUdf) { self.state.register_udf(udf); } + + fn register_function(&self, func: FunctionRef) { + self.state.register_udf(create_udf(func)); + } } impl LogicalOptimizer for DatafusionQueryEngine { diff --git a/src/query/src/query_engine.rs b/src/query/src/query_engine.rs index 8d14f74e92..e79a76618c 100644 --- a/src/query/src/query_engine.rs +++ b/src/query/src/query_engine.rs @@ -3,6 +3,7 @@ mod state; use std::sync::Arc; +use common_function::scalars::{FunctionRef, FUNCTION_REGISTRY}; use common_query::prelude::ScalarUdf; use common_recordbatch::SendableRecordBatchStream; @@ -28,6 +29,8 @@ pub trait QueryEngine: Send + Sync { async fn execute(&self, plan: &LogicalPlan) -> Result; fn register_udf(&self, udf: ScalarUdf); + + fn register_function(&self, func: FunctionRef); } pub struct QueryEngineFactory { @@ -36,9 +39,13 @@ pub struct QueryEngineFactory { impl QueryEngineFactory { pub fn new(catalog_list: Arc) -> Self { - Self { - query_engine: Arc::new(DatafusionQueryEngine::new(catalog_list)), + let query_engine = Arc::new(DatafusionQueryEngine::new(catalog_list)); + + for func in FUNCTION_REGISTRY.functions() { + query_engine.register_function(func); } + + Self { query_engine } } }
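On the query side, `QueryEngineFactory::new` now walks `FUNCTION_REGISTRY.functions()` and registers each one, while `register_function` remains available for functions created later. A hedged sketch; `my_function` stands in for a `FunctionRef` built by the new `common-function` crate, whose `Function` trait is not shown in this diff:

```rust
use common_function::scalars::FunctionRef;
use query::query_engine::QueryEngine;

// `my_function` is a placeholder FunctionRef; constructing one requires the
// common-function Function trait, which this diff does not show.
fn register_extra(engine: &dyn QueryEngine, my_function: FunctionRef) {
    // Built-in functions are registered by QueryEngineFactory::new; this call
    // is only needed for functions added after the engine is created.
    // Internally the engine wraps the Function into a DataFusion ScalarUdf
    // via common_function::scalars::udf::create_udf and calls register_udf.
    engine.register_function(my_function);
}
```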