diff --git a/src/common/function/src/scalars/ip/cidr.rs b/src/common/function/src/scalars/ip/cidr.rs index cbd5fbd922..73db7e91a7 100644 --- a/src/common/function/src/scalars/ip/cidr.rs +++ b/src/common/function/src/scalars/ip/cidr.rs @@ -14,18 +14,21 @@ use std::net::{Ipv4Addr, Ipv6Addr}; use std::str::FromStr; +use std::sync::Arc; use common_query::error::{InvalidFuncArgsSnafu, Result}; -use datafusion_common::types; -use datafusion_expr::{Coercion, Signature, TypeSignature, TypeSignatureClass, Volatility}; -use datatypes::arrow::datatypes::DataType; -use datatypes::prelude::Value; -use datatypes::scalars::ScalarVectorBuilder; -use datatypes::vectors::{MutableVector, StringVectorBuilder, VectorRef}; +use datafusion_common::arrow::array::{Array, AsArray, StringViewBuilder}; +use datafusion_common::arrow::compute; +use datafusion_common::arrow::datatypes::{DataType, UInt8Type}; +use datafusion_common::{DataFusionError, types}; +use datafusion_expr::{ + Coercion, ColumnarValue, ScalarFunctionArgs, Signature, TypeSignature, TypeSignatureClass, + Volatility, +}; use derive_more::Display; use snafu::ensure; -use crate::function::{Function, FunctionContext}; +use crate::function::Function; /// Function that converts an IPv4 address string to CIDR notation. 
/// @@ -46,7 +49,7 @@ impl Function for Ipv4ToCidr { } fn return_type(&self, _: &[DataType]) -> Result { - Ok(DataType::Utf8) + Ok(DataType::Utf8View) } fn signature(&self) -> Signature { @@ -62,19 +65,29 @@ impl Function for Ipv4ToCidr { ) } - fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { - ensure!( - columns.len() == 1 || columns.len() == 2, - InvalidFuncArgsSnafu { - err_msg: format!("Expected 1 or 2 arguments, got {}", columns.len()) - } - ); + fn invoke_with_args( + &self, + args: ScalarFunctionArgs, + ) -> datafusion_common::Result { + if args.args.len() != 1 && args.args.len() != 2 { + return Err(DataFusionError::Execution(format!( + "expecting 1 or 2 arguments, got {}", + args.args.len() + ))); + } + let columns = ColumnarValue::values_to_arrays(&args.args)?; let ip_vec = &columns[0]; - let mut results = StringVectorBuilder::with_capacity(ip_vec.len()); + let mut builder = StringViewBuilder::with_capacity(ip_vec.len()); + let arg0 = compute::cast(ip_vec, &DataType::Utf8View)?; + let ip_vec = arg0.as_string_view(); - let has_subnet_arg = columns.len() == 2; - let subnet_vec = if has_subnet_arg { + let maybe_arg1 = if columns.len() > 1 { + Some(compute::cast(&columns[1], &DataType::UInt8)?) 
+ } else { + None + }; + let subnets = if let Some(arg1) = maybe_arg1.as_ref() { ensure!( columns[1].len() == ip_vec.len(), InvalidFuncArgsSnafu { @@ -83,23 +96,19 @@ impl Function for Ipv4ToCidr { .to_string() } ); - Some(&columns[1]) + Some(arg1.as_primitive::()) } else { None }; for i in 0..ip_vec.len() { - let ip_str = ip_vec.get(i); - let subnet = subnet_vec.map(|v| v.get(i)); + let ip_str = ip_vec.is_valid(i).then(|| ip_vec.value(i)); + let subnet = subnets.and_then(|v| v.is_valid(i).then(|| v.value(i))); let cidr = match (ip_str, subnet) { - (Value::String(s), Some(Value::UInt8(mask))) => { - let ip_str = s.as_utf8().trim(); + (Some(ip_str), Some(mask)) => { if ip_str.is_empty() { - return InvalidFuncArgsSnafu { - err_msg: "Empty IPv4 address".to_string(), - } - .fail(); + return Err(DataFusionError::Execution("empty IPv4 address".to_string())); } let ip_addr = complete_and_parse_ipv4(ip_str)?; @@ -109,13 +118,9 @@ impl Function for Ipv4ToCidr { Some(format!("{}/{}", masked_ip, mask)) } - (Value::String(s), None) => { - let ip_str = s.as_utf8().trim(); + (Some(ip_str), None) => { if ip_str.is_empty() { - return InvalidFuncArgsSnafu { - err_msg: "Empty IPv4 address".to_string(), - } - .fail(); + return Err(DataFusionError::Execution("empty IPv4 address".to_string())); } let ip_addr = complete_and_parse_ipv4(ip_str)?; @@ -149,10 +154,10 @@ impl Function for Ipv4ToCidr { _ => None, }; - results.push(cidr.as_deref()); + builder.append_option(cidr.as_deref()); } - Ok(results.to_vector()) + Ok(ColumnarValue::Array(Arc::new(builder.finish()))) } } @@ -175,7 +180,7 @@ impl Function for Ipv6ToCidr { } fn return_type(&self, _: &[DataType]) -> Result { - Ok(DataType::Utf8) + Ok(DataType::Utf8View) } fn signature(&self) -> Signature { @@ -188,37 +193,41 @@ impl Function for Ipv6ToCidr { ) } - fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { - ensure!( - columns.len() == 1 || columns.len() == 2, - InvalidFuncArgsSnafu { - err_msg: 
format!("Expected 1 or 2 arguments, got {}", columns.len()) - } - ); + fn invoke_with_args( + &self, + args: ScalarFunctionArgs, + ) -> datafusion_common::Result { + if args.args.len() != 1 && args.args.len() != 2 { + return Err(DataFusionError::Execution(format!( + "expecting 1 or 2 arguments, got {}", + args.args.len() + ))); + } + let columns = ColumnarValue::values_to_arrays(&args.args)?; let ip_vec = &columns[0]; let size = ip_vec.len(); - let mut results = StringVectorBuilder::with_capacity(size); + let mut builder = StringViewBuilder::with_capacity(size); + let arg0 = compute::cast(ip_vec, &DataType::Utf8View)?; + let ip_vec = arg0.as_string_view(); - let has_subnet_arg = columns.len() == 2; - let subnet_vec = if has_subnet_arg { - Some(&columns[1]) + let maybe_arg1 = if columns.len() > 1 { + Some(compute::cast(&columns[1], &DataType::UInt8)?) } else { None }; + let subnets = maybe_arg1 + .as_ref() + .map(|arg1| arg1.as_primitive::()); for i in 0..size { - let ip_str = ip_vec.get(i); - let subnet = subnet_vec.map(|v| v.get(i)); + let ip_str = ip_vec.is_valid(i).then(|| ip_vec.value(i)); + let subnet = subnets.and_then(|v| v.is_valid(i).then(|| v.value(i))); let cidr = match (ip_str, subnet) { - (Value::String(s), Some(Value::UInt8(mask))) => { - let ip_str = s.as_utf8().trim(); + (Some(ip_str), Some(mask)) => { if ip_str.is_empty() { - return InvalidFuncArgsSnafu { - err_msg: "Empty IPv6 address".to_string(), - } - .fail(); + return Err(DataFusionError::Execution("empty IPv6 address".to_string())); } let ip_addr = complete_and_parse_ipv6(ip_str)?; @@ -228,13 +237,9 @@ impl Function for Ipv6ToCidr { Some(format!("{}/{}", masked_ip, mask)) } - (Value::String(s), None) => { - let ip_str = s.as_utf8().trim(); + (Some(ip_str), None) => { if ip_str.is_empty() { - return InvalidFuncArgsSnafu { - err_msg: "Empty IPv6 address".to_string(), - } - .fail(); + return Err(DataFusionError::Execution("empty IPv6 address".to_string())); } let ip_addr = 
complete_and_parse_ipv6(ip_str)?; @@ -250,10 +255,10 @@ impl Function for Ipv6ToCidr { _ => None, }; - results.push(cidr.as_deref()); + builder.append_option(cidr.as_deref()); } - Ok(results.to_vector()) + Ok(ColumnarValue::Array(Arc::new(builder.finish()))) } } @@ -375,108 +380,148 @@ fn auto_detect_ipv6_subnet(addr: &Ipv6Addr) -> u8 { #[cfg(test)] mod tests { - use std::sync::Arc; - - use datatypes::scalars::ScalarVector; - use datatypes::vectors::{StringVector, UInt8Vector}; + use arrow_schema::Field; + use datafusion_common::arrow::array::{StringViewArray, UInt8Array}; use super::*; #[test] fn test_ipv4_to_cidr_auto() { let func = Ipv4ToCidr; - let ctx = FunctionContext::default(); // Test data with auto subnet detection let values = vec!["192.168.1.0", "10.0.0.0", "172.16", "192"]; - let input = Arc::new(StringVector::from_slice(&values)) as VectorRef; + let arg0 = ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(&values))); - let result = func.eval(&ctx, &[input]).unwrap(); - let result = result.as_any().downcast_ref::().unwrap(); + let args = ScalarFunctionArgs { + args: vec![arg0], + arg_fields: vec![], + number_rows: 4, + return_field: Arc::new(Field::new("x", DataType::Utf8View, false)), + config_options: Arc::new(Default::default()), + }; + let result = func.invoke_with_args(args).unwrap(); + let result = result.to_array(4).unwrap(); + let result = result.as_string_view(); - assert_eq!(result.get_data(0).unwrap(), "192.168.1.0/24"); - assert_eq!(result.get_data(1).unwrap(), "10.0.0.0/8"); - assert_eq!(result.get_data(2).unwrap(), "172.16.0.0/16"); - assert_eq!(result.get_data(3).unwrap(), "192.0.0.0/8"); + assert_eq!(result.value(0), "192.168.1.0/24"); + assert_eq!(result.value(1), "10.0.0.0/8"); + assert_eq!(result.value(2), "172.16.0.0/16"); + assert_eq!(result.value(3), "192.0.0.0/8"); } #[test] fn test_ipv4_to_cidr_with_subnet() { let func = Ipv4ToCidr; - let ctx = FunctionContext::default(); // Test data with explicit subnet let 
ip_values = vec!["192.168.1.1", "10.0.0.1", "172.16.5.5"]; let subnet_values = vec![24u8, 16u8, 12u8]; - let ip_input = Arc::new(StringVector::from_slice(&ip_values)) as VectorRef; - let subnet_input = Arc::new(UInt8Vector::from_vec(subnet_values)) as VectorRef; + let arg0 = ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(&ip_values))); + let arg1 = ColumnarValue::Array(Arc::new(UInt8Array::from(subnet_values))); - let result = func.eval(&ctx, &[ip_input, subnet_input]).unwrap(); - let result = result.as_any().downcast_ref::().unwrap(); + let args = ScalarFunctionArgs { + args: vec![arg0, arg1], + arg_fields: vec![], + number_rows: 3, + return_field: Arc::new(Field::new("x", DataType::Utf8View, false)), + config_options: Arc::new(Default::default()), + }; + let result = func.invoke_with_args(args).unwrap(); + let result = result.to_array(3).unwrap(); + let result = result.as_string_view(); - assert_eq!(result.get_data(0).unwrap(), "192.168.1.0/24"); - assert_eq!(result.get_data(1).unwrap(), "10.0.0.0/16"); - assert_eq!(result.get_data(2).unwrap(), "172.16.0.0/12"); + assert_eq!(result.value(0), "192.168.1.0/24"); + assert_eq!(result.value(1), "10.0.0.0/16"); + assert_eq!(result.value(2), "172.16.0.0/12"); } #[test] fn test_ipv6_to_cidr_auto() { let func = Ipv6ToCidr; - let ctx = FunctionContext::default(); // Test data with auto subnet detection let values = vec!["2001:db8::", "2001:db8", "fe80::1", "::1"]; - let input = Arc::new(StringVector::from_slice(&values)) as VectorRef; + let arg0 = ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(&values))); - let result = func.eval(&ctx, &[input]).unwrap(); - let result = result.as_any().downcast_ref::().unwrap(); + let args = ScalarFunctionArgs { + args: vec![arg0], + arg_fields: vec![], + number_rows: 4, + return_field: Arc::new(Field::new("x", DataType::Utf8View, false)), + config_options: Arc::new(Default::default()), + }; + let result = func.invoke_with_args(args).unwrap(); + let result 
= result.to_array(4).unwrap(); + let result = result.as_string_view(); - assert_eq!(result.get_data(0).unwrap(), "2001:db8::/32"); - assert_eq!(result.get_data(1).unwrap(), "2001:db8::/32"); - assert_eq!(result.get_data(2).unwrap(), "fe80::/16"); - assert_eq!(result.get_data(3).unwrap(), "::1/128"); // Special case for ::1 + assert_eq!(result.value(0), "2001:db8::/32"); + assert_eq!(result.value(1), "2001:db8::/32"); + assert_eq!(result.value(2), "fe80::/16"); + assert_eq!(result.value(3), "::1/128"); // Special case for ::1 } #[test] fn test_ipv6_to_cidr_with_subnet() { let func = Ipv6ToCidr; - let ctx = FunctionContext::default(); // Test data with explicit subnet let ip_values = vec!["2001:db8::", "fe80::1", "2001:db8:1234::"]; let subnet_values = vec![48u8, 10u8, 56u8]; - let ip_input = Arc::new(StringVector::from_slice(&ip_values)) as VectorRef; - let subnet_input = Arc::new(UInt8Vector::from_vec(subnet_values)) as VectorRef; + let arg0 = ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(&ip_values))); + let arg1 = ColumnarValue::Array(Arc::new(UInt8Array::from(subnet_values))); - let result = func.eval(&ctx, &[ip_input, subnet_input]).unwrap(); - let result = result.as_any().downcast_ref::().unwrap(); + let args = ScalarFunctionArgs { + args: vec![arg0, arg1], + arg_fields: vec![], + number_rows: 3, + return_field: Arc::new(Field::new("x", DataType::Utf8View, false)), + config_options: Arc::new(Default::default()), + }; + let result = func.invoke_with_args(args).unwrap(); + let result = result.to_array(3).unwrap(); + let result = result.as_string_view(); - assert_eq!(result.get_data(0).unwrap(), "2001:db8::/48"); - assert_eq!(result.get_data(1).unwrap(), "fe80::/10"); - assert_eq!(result.get_data(2).unwrap(), "2001:db8:1234::/56"); + assert_eq!(result.value(0), "2001:db8::/48"); + assert_eq!(result.value(1), "fe80::/10"); + assert_eq!(result.value(2), "2001:db8:1234::/56"); } #[test] fn test_invalid_inputs() { let ipv4_func = Ipv4ToCidr; let 
ipv6_func = Ipv6ToCidr; - let ctx = FunctionContext::default(); // Empty string should fail let empty_values = vec![""]; - let empty_input = Arc::new(StringVector::from_slice(&empty_values)) as VectorRef; + let arg0 = ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(&empty_values))); - let ipv4_result = ipv4_func.eval(&ctx, std::slice::from_ref(&empty_input)); - let ipv6_result = ipv6_func.eval(&ctx, std::slice::from_ref(&empty_input)); + let args = ScalarFunctionArgs { + args: vec![arg0], + arg_fields: vec![], + number_rows: 1, + return_field: Arc::new(Field::new("x", DataType::Utf8View, false)), + config_options: Arc::new(Default::default()), + }; + let ipv4_result = ipv4_func.invoke_with_args(args.clone()); + let ipv6_result = ipv6_func.invoke_with_args(args); assert!(ipv4_result.is_err()); assert!(ipv6_result.is_err()); // Invalid IP formats should fail let invalid_values = vec!["not an ip", "192.168.1.256", "zzzz::ffff"]; - let invalid_input = Arc::new(StringVector::from_slice(&invalid_values)) as VectorRef; + let arg0 = + ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(&invalid_values))); - let ipv4_result = ipv4_func.eval(&ctx, std::slice::from_ref(&invalid_input)); + let args = ScalarFunctionArgs { + args: vec![arg0], + arg_fields: vec![], + number_rows: 3, + return_field: Arc::new(Field::new("x", DataType::Utf8View, false)), + config_options: Arc::new(Default::default()), + }; + let ipv4_result = ipv4_func.invoke_with_args(args); assert!(ipv4_result.is_err()); } diff --git a/src/common/function/src/scalars/ip/ipv4.rs b/src/common/function/src/scalars/ip/ipv4.rs index 56a1409456..bd1770cf10 100644 --- a/src/common/function/src/scalars/ip/ipv4.rs +++ b/src/common/function/src/scalars/ip/ipv4.rs @@ -14,16 +14,16 @@ use std::net::Ipv4Addr; use std::str::FromStr; +use std::sync::Arc; use common_query::error::{InvalidFuncArgsSnafu, Result}; -use datafusion_expr::{Signature, TypeSignature, Volatility}; -use 
datatypes::arrow::datatypes::DataType; -use datatypes::scalars::ScalarVectorBuilder; -use datatypes::vectors::{MutableVector, StringVectorBuilder, UInt32VectorBuilder, VectorRef}; +use datafusion_common::arrow::array::{Array, AsArray, StringViewBuilder, UInt32Builder}; +use datafusion_common::arrow::compute; +use datafusion_common::arrow::datatypes::{DataType, UInt32Type}; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, TypeSignature, Volatility}; use derive_more::Display; -use snafu::ensure; -use crate::function::{Function, FunctionContext}; +use crate::function::{Function, extract_args}; /// Function that converts a UInt32 number to an IPv4 address string. /// @@ -53,7 +53,7 @@ impl Function for Ipv4NumToString { } fn return_type(&self, _: &[DataType]) -> Result { - Ok(DataType::Utf8) + Ok(DataType::Utf8View) } fn signature(&self) -> Signature { @@ -63,22 +63,20 @@ impl Function for Ipv4NumToString { ) } - fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { - ensure!( - columns.len() == 1, - InvalidFuncArgsSnafu { - err_msg: format!("Expected 1 argument, got {}", columns.len()) - } - ); + fn invoke_with_args( + &self, + args: ScalarFunctionArgs, + ) -> datafusion_common::Result { + let [arg0] = extract_args(self.name(), &args)?; + let uint_vec = arg0.as_primitive::(); - let uint_vec = &columns[0]; let size = uint_vec.len(); - let mut results = StringVectorBuilder::with_capacity(size); + let mut builder = StringViewBuilder::with_capacity(size); for i in 0..size { - let ip_num = uint_vec.get(i); + let ip_num = uint_vec.is_valid(i).then(|| uint_vec.value(i)); let ip_str = match ip_num { - datatypes::value::Value::UInt32(num) => { + Some(num) => { // Convert UInt32 to IPv4 string (A.B.C.D format) let a = (num >> 24) & 0xFF; let b = (num >> 16) & 0xFF; @@ -89,10 +87,10 @@ impl Function for Ipv4NumToString { _ => None, }; - results.push(ip_str.as_deref()); + builder.append_option(ip_str.as_deref()); } - 
Ok(results.to_vector()) + Ok(ColumnarValue::Array(Arc::new(builder.finish()))) } fn aliases(&self) -> &[String] { @@ -123,23 +121,21 @@ impl Function for Ipv4StringToNum { Signature::string(1, Volatility::Immutable) } - fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { - ensure!( - columns.len() == 1, - InvalidFuncArgsSnafu { - err_msg: format!("Expected 1 argument, got {}", columns.len()) - } - ); + fn invoke_with_args( + &self, + args: ScalarFunctionArgs, + ) -> datafusion_common::Result { + let [arg0] = extract_args(self.name(), &args)?; - let ip_vec = &columns[0]; + let arg0 = compute::cast(&arg0, &DataType::Utf8View)?; + let ip_vec = arg0.as_string_view(); let size = ip_vec.len(); - let mut results = UInt32VectorBuilder::with_capacity(size); + let mut builder = UInt32Builder::with_capacity(size); for i in 0..size { - let ip_str = ip_vec.get(i); + let ip_str = ip_vec.is_valid(i).then(|| ip_vec.value(i)); let ip_num = match ip_str { - datatypes::value::Value::String(s) => { - let ip_str = s.as_utf8(); + Some(ip_str) => { let ip_addr = Ipv4Addr::from_str(ip_str).map_err(|_| { InvalidFuncArgsSnafu { err_msg: format!("Invalid IPv4 address format: {}", ip_str), @@ -151,10 +147,10 @@ impl Function for Ipv4StringToNum { _ => None, }; - results.push(ip_num); + builder.append_option(ip_num); } - Ok(results.to_vector()) + Ok(ColumnarValue::Array(Arc::new(builder.finish()))) } } @@ -162,66 +158,92 @@ impl Function for Ipv4StringToNum { mod tests { use std::sync::Arc; - use datatypes::scalars::ScalarVector; - use datatypes::vectors::{StringVector, UInt32Vector}; + use arrow_schema::Field; + use datafusion_common::arrow::array::{StringViewArray, UInt32Array}; use super::*; #[test] fn test_ipv4_num_to_string() { let func = Ipv4NumToString::default(); - let ctx = FunctionContext::default(); // Test data let values = vec![167772161u32, 3232235521u32, 0u32, 4294967295u32]; - let input = Arc::new(UInt32Vector::from_vec(values)) as VectorRef; + let 
input = ColumnarValue::Array(Arc::new(UInt32Array::from(values))); - let result = func.eval(&ctx, &[input]).unwrap(); - let result = result.as_any().downcast_ref::().unwrap(); + let args = ScalarFunctionArgs { + args: vec![input], + arg_fields: vec![], + number_rows: 4, + return_field: Arc::new(Field::new("x", DataType::Utf8View, false)), + config_options: Arc::new(Default::default()), + }; + let result = func.invoke_with_args(args).unwrap(); + let result = result.to_array(4).unwrap(); + let result = result.as_string_view(); - assert_eq!(result.get_data(0).unwrap(), "10.0.0.1"); - assert_eq!(result.get_data(1).unwrap(), "192.168.0.1"); - assert_eq!(result.get_data(2).unwrap(), "0.0.0.0"); - assert_eq!(result.get_data(3).unwrap(), "255.255.255.255"); + assert_eq!(result.value(0), "10.0.0.1"); + assert_eq!(result.value(1), "192.168.0.1"); + assert_eq!(result.value(2), "0.0.0.0"); + assert_eq!(result.value(3), "255.255.255.255"); } #[test] fn test_ipv4_string_to_num() { let func = Ipv4StringToNum; - let ctx = FunctionContext::default(); // Test data let values = vec!["10.0.0.1", "192.168.0.1", "0.0.0.0", "255.255.255.255"]; - let input = Arc::new(StringVector::from_slice(&values)) as VectorRef; + let input = ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(&values))); - let result = func.eval(&ctx, &[input]).unwrap(); - let result = result.as_any().downcast_ref::().unwrap(); + let args = ScalarFunctionArgs { + args: vec![input], + arg_fields: vec![], + number_rows: 4, + return_field: Arc::new(Field::new("x", DataType::UInt32, false)), + config_options: Arc::new(Default::default()), + }; + let result = func.invoke_with_args(args).unwrap(); + let result = result.to_array(4).unwrap(); + let result = result.as_primitive::(); - assert_eq!(result.get_data(0).unwrap(), 167772161); - assert_eq!(result.get_data(1).unwrap(), 3232235521); - assert_eq!(result.get_data(2).unwrap(), 0); - assert_eq!(result.get_data(3).unwrap(), 4294967295); + 
assert_eq!(result.value(0), 167772161); + assert_eq!(result.value(1), 3232235521); + assert_eq!(result.value(2), 0); + assert_eq!(result.value(3), 4294967295); } #[test] fn test_ipv4_conversions_roundtrip() { let to_num = Ipv4StringToNum; let to_string = Ipv4NumToString::default(); - let ctx = FunctionContext::default(); // Test data for string to num to string let values = vec!["10.0.0.1", "192.168.0.1", "0.0.0.0", "255.255.255.255"]; - let input = Arc::new(StringVector::from_slice(&values)) as VectorRef; + let input = ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(&values))); - let num_result = to_num.eval(&ctx, &[input]).unwrap(); - let back_to_string = to_string.eval(&ctx, &[num_result]).unwrap(); - let str_result = back_to_string - .as_any() - .downcast_ref::() - .unwrap(); + let args = ScalarFunctionArgs { + args: vec![input], + arg_fields: vec![], + number_rows: 4, + return_field: Arc::new(Field::new("x", DataType::UInt32, false)), + config_options: Arc::new(Default::default()), + }; + let result = to_num.invoke_with_args(args).unwrap(); + + let args = ScalarFunctionArgs { + args: vec![result], + arg_fields: vec![], + number_rows: 4, + return_field: Arc::new(Field::new("x", DataType::Utf8View, false)), + config_options: Arc::new(Default::default()), + }; + let result = to_string.invoke_with_args(args).unwrap(); + let result = result.to_array(4).unwrap(); + let result = result.as_string_view(); for (i, expected) in values.iter().enumerate() { - assert_eq!(str_result.get_data(i).unwrap(), *expected); + assert_eq!(result.value(i), *expected); } } } diff --git a/src/common/function/src/scalars/ip/ipv6.rs b/src/common/function/src/scalars/ip/ipv6.rs index 57bf2c2082..d90d91103d 100644 --- a/src/common/function/src/scalars/ip/ipv6.rs +++ b/src/common/function/src/scalars/ip/ipv6.rs @@ -14,17 +14,17 @@ use std::net::{Ipv4Addr, Ipv6Addr}; use std::str::FromStr; +use std::sync::Arc; use common_query::error::{InvalidFuncArgsSnafu, Result}; use 
datafusion::arrow::datatypes::DataType; -use datafusion_expr::{Signature, Volatility}; -use datatypes::prelude::Value; -use datatypes::scalars::ScalarVectorBuilder; -use datatypes::vectors::{BinaryVectorBuilder, MutableVector, StringVectorBuilder, VectorRef}; +use datafusion_common::DataFusionError; +use datafusion_common::arrow::array::{Array, AsArray, BinaryViewBuilder, StringViewBuilder}; +use datafusion_common::arrow::compute; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, Volatility}; use derive_more::Display; -use snafu::ensure; -use crate::function::{Function, FunctionContext}; +use crate::function::{Function, extract_args}; /// Function that converts a hex string representation of an IPv6 address to a formatted string. /// @@ -41,30 +41,29 @@ impl Function for Ipv6NumToString { } fn return_type(&self, _: &[DataType]) -> Result { - Ok(DataType::Utf8) + Ok(DataType::Utf8View) } fn signature(&self) -> Signature { Signature::string(1, Volatility::Immutable) } - fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { - ensure!( - columns.len() == 1, - InvalidFuncArgsSnafu { - err_msg: format!("Expected 1 argument, got {}", columns.len()) - } - ); + fn invoke_with_args( + &self, + args: ScalarFunctionArgs, + ) -> datafusion_common::Result { + let [arg0] = extract_args(self.name(), &args)?; - let hex_vec = &columns[0]; + let arg0 = compute::cast(&arg0, &DataType::Utf8View)?; + let hex_vec = arg0.as_string_view(); let size = hex_vec.len(); - let mut results = StringVectorBuilder::with_capacity(size); + let mut builder = StringViewBuilder::with_capacity(size); for i in 0..size { - let hex_str = hex_vec.get(i); + let hex_str = hex_vec.is_valid(i).then(|| hex_vec.value(i)); let ip_str = match hex_str { - Value::String(s) => { - let hex_str = s.as_utf8().to_lowercase(); + Some(s) => { + let hex_str = s.to_lowercase(); // Validate and convert hex string to bytes let bytes = if hex_str.len() == 32 { @@ -80,10 +79,10 @@ impl 
Function for Ipv6NumToString { } bytes } else { - return InvalidFuncArgsSnafu { - err_msg: format!("Expected 32 hex characters, got {}", hex_str.len()), - } - .fail(); + return Err(DataFusionError::Execution(format!( + "expecting 32 hex characters, got {}", + hex_str.len() + ))); }; // Convert bytes to IPv6 address @@ -106,10 +105,10 @@ impl Function for Ipv6NumToString { _ => None, }; - results.push(ip_str.as_deref()); + builder.append_option(ip_str.as_deref()); } - Ok(results.to_vector()) + Ok(ColumnarValue::Array(Arc::new(builder.finish()))) } } @@ -130,31 +129,28 @@ impl Function for Ipv6StringToNum { } fn return_type(&self, _: &[DataType]) -> Result { - Ok(DataType::Binary) + Ok(DataType::BinaryView) } fn signature(&self) -> Signature { Signature::string(1, Volatility::Immutable) } - fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { - ensure!( - columns.len() == 1, - InvalidFuncArgsSnafu { - err_msg: format!("Expected 1 argument, got {}", columns.len()) - } - ); + fn invoke_with_args( + &self, + args: ScalarFunctionArgs, + ) -> datafusion_common::Result { + let [arg0] = extract_args(self.name(), &args)?; + let arg0 = compute::cast(&arg0, &DataType::Utf8View)?; + let ip_vec = arg0.as_string_view(); - let ip_vec = &columns[0]; let size = ip_vec.len(); - let mut results = BinaryVectorBuilder::with_capacity(size); + let mut builder = BinaryViewBuilder::with_capacity(size); for i in 0..size { - let ip_str = ip_vec.get(i); + let ip_str = ip_vec.is_valid(i).then(|| ip_vec.value(i)); let ip_binary = match ip_str { - Value::String(s) => { - let addr_str = s.as_utf8(); - + Some(addr_str) => { let addr = if let Ok(ipv6) = Ipv6Addr::from_str(addr_str) { // Direct IPv6 address ipv6 @@ -163,10 +159,10 @@ impl Function for Ipv6StringToNum { ipv4.to_ipv6_mapped() } else { // Invalid format - return InvalidFuncArgsSnafu { - err_msg: format!("Invalid IPv6 address format: {}", addr_str), - } - .fail(); + return Err(DataFusionError::Execution(format!( 
+ "Invalid IPv6 address format: {}", + addr_str + ))); }; // Convert IPv6 address to binary (16 bytes) @@ -176,10 +172,10 @@ impl Function for Ipv6StringToNum { _ => None, }; - results.push(ip_binary.as_deref()); + builder.append_option(ip_binary.as_deref()); } - Ok(results.to_vector()) + Ok(ColumnarValue::Array(Arc::new(builder.finish()))) } } @@ -188,15 +184,14 @@ mod tests { use std::fmt::Write; use std::sync::Arc; - use datatypes::scalars::ScalarVector; - use datatypes::vectors::{BinaryVector, StringVector, Vector}; + use arrow_schema::Field; + use datafusion_common::arrow::array::StringViewArray; use super::*; #[test] fn test_ipv6_num_to_string() { let func = Ipv6NumToString; - let ctx = FunctionContext::default(); // Hex string for "2001:db8::1" let hex_str1 = "20010db8000000000000000000000001"; @@ -205,62 +200,93 @@ mod tests { let hex_str2 = "00000000000000000000ffffc0a80001"; let values = vec![hex_str1, hex_str2]; - let input = Arc::new(StringVector::from_slice(&values)) as VectorRef; + let arg0 = ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(&values))); - let result = func.eval(&ctx, &[input]).unwrap(); - let result = result.as_any().downcast_ref::().unwrap(); + let args = ScalarFunctionArgs { + args: vec![arg0], + arg_fields: vec![], + number_rows: 2, + return_field: Arc::new(Field::new("x", DataType::Utf8View, false)), + config_options: Arc::new(Default::default()), + }; + let result = func.invoke_with_args(args).unwrap(); + let result = result.to_array(2).unwrap(); + let result = result.as_string_view(); - assert_eq!(result.get_data(0).unwrap(), "2001:db8::1"); - assert_eq!(result.get_data(1).unwrap(), "::ffff:192.168.0.1"); + assert_eq!(result.value(0), "2001:db8::1"); + assert_eq!(result.value(1), "::ffff:192.168.0.1"); } #[test] fn test_ipv6_num_to_string_uppercase() { let func = Ipv6NumToString; - let ctx = FunctionContext::default(); // Uppercase hex string for "2001:db8::1" let hex_str = "20010DB8000000000000000000000001"; let 
values = vec![hex_str]; - let input = Arc::new(StringVector::from_slice(&values)) as VectorRef; + let arg0 = ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(&values))); - let result = func.eval(&ctx, &[input]).unwrap(); - let result = result.as_any().downcast_ref::().unwrap(); + let args = ScalarFunctionArgs { + args: vec![arg0], + arg_fields: vec![], + number_rows: 1, + return_field: Arc::new(Field::new("x", DataType::Utf8View, false)), + config_options: Arc::new(Default::default()), + }; + let result = func.invoke_with_args(args).unwrap(); + let result = result.to_array(1).unwrap(); + let result = result.as_string_view(); - assert_eq!(result.get_data(0).unwrap(), "2001:db8::1"); + assert_eq!(result.value(0), "2001:db8::1"); } #[test] fn test_ipv6_num_to_string_error() { let func = Ipv6NumToString; - let ctx = FunctionContext::default(); // Invalid hex string - wrong length let hex_str = "20010db8"; let values = vec![hex_str]; - let input = Arc::new(StringVector::from_slice(&values)) as VectorRef; + let arg0 = ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(&values))); // Should return an error - let result = func.eval(&ctx, &[input]); + let args = ScalarFunctionArgs { + args: vec![arg0], + arg_fields: vec![], + number_rows: 2, + return_field: Arc::new(Field::new("x", DataType::Utf8View, false)), + config_options: Arc::new(Default::default()), + }; + let result = func.invoke_with_args(args); assert!(result.is_err()); // Check that the error message contains expected text let error_msg = result.unwrap_err().to_string(); - assert!(error_msg.contains("Expected 32 hex characters")); + assert_eq!( + error_msg, + "Execution error: expecting 32 hex characters, got 8" + ); } #[test] fn test_ipv6_string_to_num() { let func = Ipv6StringToNum; - let ctx = FunctionContext::default(); let values = vec!["2001:db8::1", "::ffff:192.168.0.1", "192.168.0.1"]; - let input = Arc::new(StringVector::from_slice(&values)) as VectorRef; + let arg0 = 
ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(&values))); - let result = func.eval(&ctx, &[input]).unwrap(); - let result = result.as_any().downcast_ref::().unwrap(); + let args = ScalarFunctionArgs { + args: vec![arg0], + arg_fields: vec![], + number_rows: 3, + return_field: Arc::new(Field::new("x", DataType::Utf8View, false)), + config_options: Arc::new(Default::default()), + }; + let result = func.invoke_with_args(args).unwrap(); + let result = result.to_array(3).unwrap(); + let result = result.as_binary_view(); // Expected binary for "2001:db8::1" let expected_1 = [ @@ -272,33 +298,37 @@ mod tests { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xFF, 0xFF, 0xC0, 0xA8, 0, 0x01, ]; - assert_eq!(result.get_data(0).unwrap(), &expected_1); - assert_eq!(result.get_data(1).unwrap(), &expected_2); - assert_eq!(result.get_data(2).unwrap(), &expected_2); + assert_eq!(result.value(0), &expected_1); + assert_eq!(result.value(1), &expected_2); + assert_eq!(result.value(2), &expected_2); } #[test] fn test_ipv6_conversions_roundtrip() { let to_num = Ipv6StringToNum; let to_string = Ipv6NumToString; - let ctx = FunctionContext::default(); // Test data let values = vec!["2001:db8::1", "::ffff:192.168.0.1"]; - let input = Arc::new(StringVector::from_slice(&values)) as VectorRef; + let arg0 = ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(&values))); // Convert IPv6 addresses to binary - let binary_result = to_num.eval(&ctx, std::slice::from_ref(&input)).unwrap(); + let args = ScalarFunctionArgs { + args: vec![arg0], + arg_fields: vec![], + number_rows: 2, + return_field: Arc::new(Field::new("x", DataType::BinaryView, false)), + config_options: Arc::new(Default::default()), + }; + let result = to_num.invoke_with_args(args).unwrap(); // Convert binary to hex string representation (for ipv6_num_to_string) let mut hex_strings = Vec::new(); - let binary_vector = binary_result - .as_any() - .downcast_ref::() - .unwrap(); + let result = result.to_array(2).unwrap(); + 
let binary_vector = result.as_binary_view(); for i in 0..binary_vector.len() { - let bytes = binary_vector.get_data(i).unwrap(); + let bytes = binary_vector.value(i); let hex = bytes.iter().fold(String::new(), |mut acc, b| { write!(&mut acc, "{:02x}", b).unwrap(); acc @@ -307,18 +337,23 @@ mod tests { } let hex_str_refs: Vec<&str> = hex_strings.iter().map(|s| s.as_str()).collect(); - let hex_input = Arc::new(StringVector::from_slice(&hex_str_refs)) as VectorRef; + let arg0 = ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(&hex_str_refs))); // Now convert hex to formatted string - let string_result = to_string.eval(&ctx, &[hex_input]).unwrap(); - let str_result = string_result - .as_any() - .downcast_ref::() - .unwrap(); + let args = ScalarFunctionArgs { + args: vec![arg0], + arg_fields: vec![], + number_rows: 2, + return_field: Arc::new(Field::new("x", DataType::Utf8View, false)), + config_options: Arc::new(Default::default()), + }; + let result = to_string.invoke_with_args(args).unwrap(); + let result = result.to_array(2).unwrap(); + let result = result.as_string_view(); // Compare with original input - assert_eq!(str_result.get_data(0).unwrap(), values[0]); - assert_eq!(str_result.get_data(1).unwrap(), values[1]); + assert_eq!(result.value(0), values[0]); + assert_eq!(result.value(1), values[1]); } #[test] @@ -327,24 +362,35 @@ mod tests { // can be converted back using ipv6_string_to_num let to_string = Ipv6NumToString; let to_binary = Ipv6StringToNum; - let ctx = FunctionContext::default(); // Hex representation of IPv6 addresses let hex_values = vec![ "20010db8000000000000000000000001", "00000000000000000000ffffc0a80001", ]; - let hex_input = Arc::new(StringVector::from_slice(&hex_values)) as VectorRef; + let arg0 = ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(&hex_values))); // Convert hex to string representation - let string_result = to_string.eval(&ctx, &[hex_input]).unwrap(); + let args = ScalarFunctionArgs { + args: 
vec![arg0], + arg_fields: vec![], + number_rows: 2, + return_field: Arc::new(Field::new("x", DataType::Utf8View, false)), + config_options: Arc::new(Default::default()), + }; + let result = to_string.invoke_with_args(args).unwrap(); // Then convert string representation back to binary - let binary_result = to_binary.eval(&ctx, &[string_result]).unwrap(); - let bin_result = binary_result - .as_any() - .downcast_ref::() - .unwrap(); + let args = ScalarFunctionArgs { + args: vec![result], + arg_fields: vec![], + number_rows: 2, + return_field: Arc::new(Field::new("x", DataType::BinaryView, false)), + config_options: Arc::new(Default::default()), + }; + let result = to_binary.invoke_with_args(args).unwrap(); + let result = result.to_array(2).unwrap(); + let result = result.as_binary_view(); // Expected binary values let expected_bin1 = [ @@ -354,7 +400,7 @@ mod tests { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xFF, 0xFF, 0xC0, 0xA8, 0, 0x01, ]; - assert_eq!(bin_result.get_data(0).unwrap(), &expected_bin1); - assert_eq!(bin_result.get_data(1).unwrap(), &expected_bin2); + assert_eq!(result.value(0), &expected_bin1); + assert_eq!(result.value(1), &expected_bin2); } } diff --git a/src/common/function/src/scalars/ip/range.rs b/src/common/function/src/scalars/ip/range.rs index ebda747e9b..373fb0832f 100644 --- a/src/common/function/src/scalars/ip/range.rs +++ b/src/common/function/src/scalars/ip/range.rs @@ -14,17 +14,18 @@ use std::net::{Ipv4Addr, Ipv6Addr}; use std::str::FromStr; +use std::sync::Arc; use common_query::error::{InvalidFuncArgsSnafu, Result}; -use datafusion::arrow::datatypes::DataType; -use datafusion_expr::{Signature, Volatility}; -use datatypes::prelude::Value; -use datatypes::scalars::ScalarVectorBuilder; -use datatypes::vectors::{BooleanVectorBuilder, MutableVector, VectorRef}; +use datafusion_common::DataFusionError; +use datafusion_common::arrow::array::{Array, AsArray, BooleanBuilder}; +use datafusion_common::arrow::compute; +use 
datafusion_common::arrow::datatypes::DataType; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, Volatility}; use derive_more::Display; use snafu::ensure; -use crate::function::{Function, FunctionContext}; +use crate::function::{Function, extract_args}; /// Function that checks if an IPv4 address is within a specified CIDR range. /// @@ -52,42 +53,31 @@ impl Function for Ipv4InRange { Signature::string(2, Volatility::Immutable) } - fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { - ensure!( - columns.len() == 2, - InvalidFuncArgsSnafu { - err_msg: format!("Expected 2 arguments, got {}", columns.len()) - } - ); + fn invoke_with_args( + &self, + args: ScalarFunctionArgs, + ) -> datafusion_common::Result { + let [arg0, arg1] = extract_args(self.name(), &args)?; + + let arg0 = compute::cast(&arg0, &DataType::Utf8View)?; + let ip_vec = arg0.as_string_view(); + let arg1 = compute::cast(&arg1, &DataType::Utf8View)?; + let ranges = arg1.as_string_view(); - let ip_vec = &columns[0]; - let range_vec = &columns[1]; let size = ip_vec.len(); - ensure!( - range_vec.len() == size, - InvalidFuncArgsSnafu { - err_msg: "IP addresses and CIDR ranges must have the same number of rows" - .to_string() - } - ); - - let mut results = BooleanVectorBuilder::with_capacity(size); + let mut builder = BooleanBuilder::with_capacity(size); for i in 0..size { - let ip = ip_vec.get(i); - let range = range_vec.get(i); + let ip = ip_vec.is_valid(i).then(|| ip_vec.value(i)); + let range = ranges.is_valid(i).then(|| ranges.value(i)); let in_range = match (ip, range) { - (Value::String(ip_str), Value::String(range_str)) => { - let ip_str = ip_str.as_utf8().trim(); - let range_str = range_str.as_utf8().trim(); - + (Some(ip_str), Some(range_str)) => { if ip_str.is_empty() || range_str.is_empty() { - return InvalidFuncArgsSnafu { - err_msg: "IP address and CIDR range cannot be empty".to_string(), - } - .fail(); + return Err(DataFusionError::Execution( + "IP 
address or CIDR range cannot be empty".to_string(), + )); } // Parse the IP address @@ -107,10 +97,10 @@ impl Function for Ipv4InRange { _ => None, }; - results.push(in_range); + builder.append_option(in_range); } - Ok(results.to_vector()) + Ok(ColumnarValue::Array(Arc::new(builder.finish()))) } } @@ -141,42 +131,29 @@ impl Function for Ipv6InRange { Signature::string(2, Volatility::Immutable) } - fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { - ensure!( - columns.len() == 2, - InvalidFuncArgsSnafu { - err_msg: format!("Expected 2 arguments, got {}", columns.len()) - } - ); + fn invoke_with_args( + &self, + args: ScalarFunctionArgs, + ) -> datafusion_common::Result { + let [arg0, arg1] = extract_args(self.name(), &args)?; - let ip_vec = &columns[0]; - let range_vec = &columns[1]; + let arg0 = compute::cast(&arg0, &DataType::Utf8View)?; + let ip_vec = arg0.as_string_view(); + let arg1 = compute::cast(&arg1, &DataType::Utf8View)?; + let ranges = arg1.as_string_view(); let size = ip_vec.len(); - - ensure!( - range_vec.len() == size, - InvalidFuncArgsSnafu { - err_msg: "IP addresses and CIDR ranges must have the same number of rows" - .to_string() - } - ); - - let mut results = BooleanVectorBuilder::with_capacity(size); + let mut builder = BooleanBuilder::with_capacity(size); for i in 0..size { - let ip = ip_vec.get(i); - let range = range_vec.get(i); + let ip = ip_vec.is_valid(i).then(|| ip_vec.value(i)); + let range = ranges.is_valid(i).then(|| ranges.value(i)); let in_range = match (ip, range) { - (Value::String(ip_str), Value::String(range_str)) => { - let ip_str = ip_str.as_utf8().trim(); - let range_str = range_str.as_utf8().trim(); - + (Some(ip_str), Some(range_str)) => { if ip_str.is_empty() || range_str.is_empty() { - return InvalidFuncArgsSnafu { - err_msg: "IP address and CIDR range cannot be empty".to_string(), - } - .fail(); + return Err(DataFusionError::Execution( + "IP address or CIDR range cannot be empty".to_string(), + 
)); } // Parse the IP address @@ -196,10 +173,10 @@ impl Function for Ipv6InRange { _ => None, }; - results.push(in_range); + builder.append_option(in_range); } - Ok(results.to_vector()) + Ok(ColumnarValue::Array(Arc::new(builder.finish()))) } } @@ -329,15 +306,14 @@ fn is_ipv6_in_range(ip: &Ipv6Addr, cidr_base: &Ipv6Addr, prefix_len: u8) -> Opti mod tests { use std::sync::Arc; - use datatypes::scalars::ScalarVector; - use datatypes::vectors::{BooleanVector, StringVector}; + use arrow_schema::Field; + use datafusion_common::arrow::array::StringViewArray; use super::*; #[test] fn test_ipv4_in_range() { let func = Ipv4InRange; - let ctx = FunctionContext::default(); // Test IPs let ip_values = vec![ @@ -357,24 +333,31 @@ mod tests { "172.16.0.0/16", ]; - let ip_input = Arc::new(StringVector::from_slice(&ip_values)) as VectorRef; - let cidr_input = Arc::new(StringVector::from_slice(&cidr_values)) as VectorRef; + let arg0 = ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(&ip_values))); + let arg1 = ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(&cidr_values))); - let result = func.eval(&ctx, &[ip_input, cidr_input]).unwrap(); - let result = result.as_any().downcast_ref::().unwrap(); + let args = ScalarFunctionArgs { + args: vec![arg0, arg1], + arg_fields: vec![], + number_rows: 5, + return_field: Arc::new(Field::new("x", DataType::Boolean, false)), + config_options: Arc::new(Default::default()), + }; + let result = func.invoke_with_args(args).unwrap(); + let result = result.to_array(5).unwrap(); + let result = result.as_boolean(); // Expected results - assert!(result.get_data(0).unwrap()); // 192.168.1.5 is in 192.168.1.0/24 - assert!(!result.get_data(1).unwrap()); // 192.168.2.1 is not in 192.168.1.0/24 - assert!(result.get_data(2).unwrap()); // 10.0.0.1 is in 10.0.0.0/8 - assert!(result.get_data(3).unwrap()); // 10.1.0.1 is in 10.0.0.0/8 - assert!(result.get_data(4).unwrap()); // 172.16.0.1 is in 172.16.0.0/16 + 
assert!(result.value(0)); // 192.168.1.5 is in 192.168.1.0/24 + assert!(!result.value(1)); // 192.168.2.1 is not in 192.168.1.0/24 + assert!(result.value(2)); // 10.0.0.1 is in 10.0.0.0/8 + assert!(result.value(3)); // 10.1.0.1 is in 10.0.0.0/8 + assert!(result.value(4)); // 172.16.0.1 is in 172.16.0.0/16 } #[test] fn test_ipv6_in_range() { let func = Ipv6InRange; - let ctx = FunctionContext::default(); // Test IPs let ip_values = vec![ @@ -394,46 +377,70 @@ mod tests { "fe80::/16", ]; - let ip_input = Arc::new(StringVector::from_slice(&ip_values)) as VectorRef; - let cidr_input = Arc::new(StringVector::from_slice(&cidr_values)) as VectorRef; + let arg0 = ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(&ip_values))); + let arg1 = ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(&cidr_values))); - let result = func.eval(&ctx, &[ip_input, cidr_input]).unwrap(); - let result = result.as_any().downcast_ref::().unwrap(); + let args = ScalarFunctionArgs { + args: vec![arg0, arg1], + arg_fields: vec![], + number_rows: 5, + return_field: Arc::new(Field::new("x", DataType::Boolean, false)), + config_options: Arc::new(Default::default()), + }; + let result = func.invoke_with_args(args).unwrap(); + let result = result.to_array(5).unwrap(); + let result = result.as_boolean(); // Expected results - assert!(result.get_data(0).unwrap()); // 2001:db8::1 is in 2001:db8::/32 - assert!(result.get_data(1).unwrap()); // 2001:db8:1:: is in 2001:db8::/32 - assert!(!result.get_data(2).unwrap()); // 2001:db9::1 is not in 2001:db8::/32 - assert!(result.get_data(3).unwrap()); // ::1 is in ::1/128 - assert!(result.get_data(4).unwrap()); // fe80::1 is in fe80::/16 + assert!(result.value(0)); // 2001:db8::1 is in 2001:db8::/32 + assert!(result.value(1)); // 2001:db8:1:: is in 2001:db8::/32 + assert!(!result.value(2)); // 2001:db9::1 is not in 2001:db8::/32 + assert!(result.value(3)); // ::1 is in ::1/128 + assert!(result.value(4)); // fe80::1 is in fe80::/16 } 
#[test] fn test_invalid_inputs() { let ipv4_func = Ipv4InRange; let ipv6_func = Ipv6InRange; - let ctx = FunctionContext::default(); // Invalid IPv4 address let invalid_ip_values = vec!["not-an-ip", "192.168.1.300"]; let cidr_values = vec!["192.168.1.0/24", "192.168.1.0/24"]; - let invalid_ip_input = Arc::new(StringVector::from_slice(&invalid_ip_values)) as VectorRef; - let cidr_input = Arc::new(StringVector::from_slice(&cidr_values)) as VectorRef; + let arg0 = ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values( + &invalid_ip_values, + ))); + let arg1 = ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(&cidr_values))); - let result = ipv4_func.eval(&ctx, &[invalid_ip_input, cidr_input]); + let args = ScalarFunctionArgs { + args: vec![arg0, arg1], + arg_fields: vec![], + number_rows: 2, + return_field: Arc::new(Field::new("x", DataType::Boolean, false)), + config_options: Arc::new(Default::default()), + }; + let result = ipv4_func.invoke_with_args(args); assert!(result.is_err()); // Invalid CIDR notation let ip_values = vec!["192.168.1.1", "2001:db8::1"]; let invalid_cidr_values = vec!["192.168.1.0", "2001:db8::/129"]; - let ip_input = Arc::new(StringVector::from_slice(&ip_values)) as VectorRef; - let invalid_cidr_input = - Arc::new(StringVector::from_slice(&invalid_cidr_values)) as VectorRef; + let arg0 = ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(&ip_values))); + let arg1 = ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values( + &invalid_cidr_values, + ))); - let ipv4_result = ipv4_func.eval(&ctx, &[ip_input.clone(), invalid_cidr_input.clone()]); - let ipv6_result = ipv6_func.eval(&ctx, &[ip_input, invalid_cidr_input]); + let args = ScalarFunctionArgs { + args: vec![arg0, arg1], + arg_fields: vec![], + number_rows: 2, + return_field: Arc::new(Field::new("x", DataType::Boolean, false)), + config_options: Arc::new(Default::default()), + }; + let ipv4_result = ipv4_func.invoke_with_args(args.clone()); + let 
ipv6_result = ipv6_func.invoke_with_args(args); assert!(ipv4_result.is_err()); assert!(ipv6_result.is_err()); @@ -442,20 +449,27 @@ mod tests { #[test] fn test_edge_cases() { let ipv4_func = Ipv4InRange; - let ctx = FunctionContext::default(); // Edge cases like prefix length 0 (matches everything) and 32 (exact match) let ip_values = vec!["8.8.8.8", "192.168.1.1", "192.168.1.1"]; let cidr_values = vec!["0.0.0.0/0", "192.168.1.1/32", "192.168.1.0/32"]; - let ip_input = Arc::new(StringVector::from_slice(&ip_values)) as VectorRef; - let cidr_input = Arc::new(StringVector::from_slice(&cidr_values)) as VectorRef; + let arg0 = ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(&ip_values))); + let arg1 = ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(&cidr_values))); - let result = ipv4_func.eval(&ctx, &[ip_input, cidr_input]).unwrap(); - let result = result.as_any().downcast_ref::().unwrap(); + let args = ScalarFunctionArgs { + args: vec![arg0, arg1], + arg_fields: vec![], + number_rows: 3, + return_field: Arc::new(Field::new("x", DataType::Boolean, false)), + config_options: Arc::new(Default::default()), + }; + let result = ipv4_func.invoke_with_args(args).unwrap(); + let result = result.to_array(3).unwrap(); + let result = result.as_boolean(); - assert!(result.get_data(0).unwrap()); // 8.8.8.8 is in 0.0.0.0/0 (matches everything) - assert!(result.get_data(1).unwrap()); // 192.168.1.1 is in 192.168.1.1/32 (exact match) - assert!(!result.get_data(2).unwrap()); // 192.168.1.1 is not in 192.168.1.0/32 (no match) + assert!(result.value(0)); // 8.8.8.8 is in 0.0.0.0/0 (matches everything) + assert!(result.value(1)); // 192.168.1.1 is in 192.168.1.1/32 (exact match) + assert!(!result.value(2)); // 192.168.1.1 is not in 192.168.1.0/32 (no match) } } diff --git a/src/common/function/src/scalars/math/modulo.rs b/src/common/function/src/scalars/math/modulo.rs index 89df6e99bb..5d79ee25cf 100644 --- a/src/common/function/src/scalars/math/modulo.rs 
+++ b/src/common/function/src/scalars/math/modulo.rs @@ -15,19 +15,14 @@ use std::fmt; use std::fmt::Display; -use common_query::error; -use common_query::error::{ArrowComputeSnafu, InvalidFuncArgsSnafu, Result}; -use datafusion::arrow::datatypes::DataType; +use common_query::error::Result; +use datafusion_common::arrow::compute; +use datafusion_common::arrow::compute::kernels::numeric; +use datafusion_common::arrow::datatypes::DataType; use datafusion_expr::type_coercion::aggregates::NUMERICS; -use datafusion_expr::{Signature, Volatility}; -use datatypes::arrow::compute; -use datatypes::arrow::compute::kernels::numeric; -use datatypes::arrow::datatypes::DataType as ArrowDataType; -use datatypes::prelude::ConcreteDataType; -use datatypes::vectors::{Helper, VectorRef}; -use snafu::{ResultExt, ensure}; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, Volatility}; -use crate::function::{Function, FunctionContext}; +use crate::function::{Function, extract_args}; const NAME: &str = "mod"; @@ -60,38 +55,24 @@ impl Function for ModuloFunction { Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable) } - fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { - ensure!( - columns.len() == 2, - InvalidFuncArgsSnafu { - err_msg: format!( - "The length of the args is not correct, expect exactly two, have: {}", - columns.len() - ), - } - ); - let nums = &columns[0]; - let divs = &columns[1]; - let nums_arrow_array = &nums.to_arrow_array(); - let divs_arrow_array = &divs.to_arrow_array(); - let array = numeric::rem(nums_arrow_array, divs_arrow_array).context(ArrowComputeSnafu)?; + fn invoke_with_args( + &self, + args: ScalarFunctionArgs, + ) -> datafusion_common::Result { + let [nums, divs] = extract_args(self.name(), &args)?; + let array = numeric::rem(&nums, &divs)?; let result = match nums.data_type() { - ConcreteDataType::Int8(_) - | ConcreteDataType::Int16(_) - | ConcreteDataType::Int32(_) - | ConcreteDataType::Int64(_) => 
compute::cast(&array, &ArrowDataType::Int64), - ConcreteDataType::UInt8(_) - | ConcreteDataType::UInt16(_) - | ConcreteDataType::UInt32(_) - | ConcreteDataType::UInt64(_) => compute::cast(&array, &ArrowDataType::UInt64), - ConcreteDataType::Float32(_) | ConcreteDataType::Float64(_) => { - compute::cast(&array, &ArrowDataType::Float64) + DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => { + compute::cast(&array, &DataType::Int64) } + DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => { + compute::cast(&array, &DataType::UInt64) + } + DataType::Float32 | DataType::Float64 => compute::cast(&array, &DataType::Float64), _ => unreachable!("unexpected datatype: {:?}", nums.data_type()), - } - .context(ArrowComputeSnafu)?; - Helper::try_into_vector(&result).context(error::FromArrowArraySnafu) + }?; + Ok(ColumnarValue::Array(result)) } } @@ -99,9 +80,11 @@ impl Function for ModuloFunction { mod tests { use std::sync::Arc; - use common_error::ext::ErrorExt; - use datatypes::value::Value; - use datatypes::vectors::{Float64Vector, Int32Vector, StringVector, UInt32Vector}; + use arrow_schema::Field; + use datafusion_common::arrow::array::{ + AsArray, Float64Array, Int32Array, StringViewArray, UInt32Array, + }; + use datafusion_common::arrow::datatypes::{Float64Type, Int64Type, UInt64Type}; use super::*; #[test] @@ -120,15 +103,23 @@ mod tests { let nums = vec![18, -17, 5, -6]; let divs = vec![4, 8, -5, -5]; - let args: Vec = vec![ - Arc::new(Int32Vector::from_vec(nums.clone())), - Arc::new(Int32Vector::from_vec(divs.clone())), - ]; - let result = function.eval(&FunctionContext::default(), &args).unwrap(); + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Array(Arc::new(Int32Array::from(nums.clone()))), + ColumnarValue::Array(Arc::new(Int32Array::from(divs.clone()))), + ], + arg_fields: vec![], + number_rows: 4, + return_field: Arc::new(Field::new("x", DataType::Int64, false)), + config_options: 
Arc::new(Default::default()), + }; + let result = function.invoke_with_args(args).unwrap(); + let result = result.to_array(4).unwrap(); + let result = result.as_primitive::(); assert_eq!(result.len(), 4); for i in 0..4 { let p: i64 = (nums[i] % divs[i]) as i64; - assert!(matches!(result.get(i), Value::Int64(v) if v == p)); + assert_eq!(result.value(i), p); } } @@ -148,15 +139,23 @@ mod tests { let nums: Vec = vec![18, 17, 5, 6]; let divs: Vec = vec![4, 8, 5, 5]; - let args: Vec = vec![ - Arc::new(UInt32Vector::from_vec(nums.clone())), - Arc::new(UInt32Vector::from_vec(divs.clone())), - ]; - let result = function.eval(&FunctionContext::default(), &args).unwrap(); + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Array(Arc::new(UInt32Array::from(nums.clone()))), + ColumnarValue::Array(Arc::new(UInt32Array::from(divs.clone()))), + ], + arg_fields: vec![], + number_rows: 4, + return_field: Arc::new(Field::new("x", DataType::UInt64, false)), + config_options: Arc::new(Default::default()), + }; + let result = function.invoke_with_args(args).unwrap(); + let result = result.to_array(4).unwrap(); + let result = result.as_primitive::(); assert_eq!(result.len(), 4); for i in 0..4 { let p: u64 = (nums[i] % divs[i]) as u64; - assert!(matches!(result.get(i), Value::UInt64(v) if v == p)); + assert_eq!(result.value(i), p); } } @@ -176,15 +175,23 @@ mod tests { let nums = vec![18.0, 17.0, 5.0, 6.0]; let divs = vec![4.0, 8.0, 5.0, 5.0]; - let args: Vec = vec![ - Arc::new(Float64Vector::from_vec(nums.clone())), - Arc::new(Float64Vector::from_vec(divs.clone())), - ]; - let result = function.eval(&FunctionContext::default(), &args).unwrap(); + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Array(Arc::new(Float64Array::from(nums.clone()))), + ColumnarValue::Array(Arc::new(Float64Array::from(divs.clone()))), + ], + arg_fields: vec![], + number_rows: 4, + return_field: Arc::new(Field::new("x", DataType::Float64, false)), + config_options: 
Arc::new(Default::default()), + }; + let result = function.invoke_with_args(args).unwrap(); + let result = result.to_array(4).unwrap(); + let result = result.as_primitive::(); assert_eq!(result.len(), 4); for i in 0..4 { let p: f64 = nums[i] % divs[i]; - assert!(matches!(result.get(i), Value::Float64(v) if v == p)); + assert_eq!(result.value(i), p); } } @@ -195,37 +202,53 @@ mod tests { let nums = vec![27]; let divs = vec![0]; - let args: Vec = vec![ - Arc::new(Int32Vector::from_vec(nums.clone())), - Arc::new(Int32Vector::from_vec(divs.clone())), - ]; - let result = function.eval(&FunctionContext::default(), &args); + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Array(Arc::new(Int32Array::from(nums))), + ColumnarValue::Array(Arc::new(Int32Array::from(divs))), + ], + arg_fields: vec![], + number_rows: 1, + return_field: Arc::new(Field::new("x", DataType::Int64, false)), + config_options: Arc::new(Default::default()), + }; + let result = function.invoke_with_args(args); assert!(result.is_err()); - let err_msg = result.unwrap_err().output_msg(); - assert_eq!( - err_msg, - "Failed to perform compute operation on arrow arrays: Divide by zero error" - ); + let err_msg = result.unwrap_err().to_string(); + assert_eq!(err_msg, "Arrow error: Divide by zero error"); let nums = vec![27]; - let args: Vec = vec![Arc::new(Int32Vector::from_vec(nums.clone()))]; - let result = function.eval(&FunctionContext::default(), &args); + let args = ScalarFunctionArgs { + args: vec![ColumnarValue::Array(Arc::new(Int32Array::from(nums)))], + arg_fields: vec![], + number_rows: 1, + return_field: Arc::new(Field::new("x", DataType::Int64, false)), + config_options: Arc::new(Default::default()), + }; + let result = function.invoke_with_args(args); assert!(result.is_err()); - let err_msg = result.unwrap_err().output_msg(); - assert!( - err_msg.contains("The length of the args is not correct, expect exactly two, have: 1") + let err_msg = result.unwrap_err().to_string(); + 
assert_eq!( + err_msg, + "Execution error: mod function requires 2 arguments, got 1" ); let nums = vec!["27"]; let divs = vec!["4"]; - let args: Vec = vec![ - Arc::new(StringVector::from(nums.clone())), - Arc::new(StringVector::from(divs.clone())), - ]; - let result = function.eval(&FunctionContext::default(), &args); + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Array(Arc::new(StringViewArray::from(nums))), + ColumnarValue::Array(Arc::new(StringViewArray::from(divs))), + ], + arg_fields: vec![], + number_rows: 1, + return_field: Arc::new(Field::new("x", DataType::Int64, false)), + config_options: Arc::new(Default::default()), + }; + let result = function.invoke_with_args(args); assert!(result.is_err()); - let err_msg = result.unwrap_err().output_msg(); + let err_msg = result.unwrap_err().to_string(); assert!(err_msg.contains("Invalid arithmetic operation")); } } diff --git a/src/common/function/src/scalars/test.rs b/src/common/function/src/scalars/test.rs index 0623060350..e85ac9062a 100644 --- a/src/common/function/src/scalars/test.rs +++ b/src/common/function/src/scalars/test.rs @@ -13,14 +13,14 @@ // limitations under the License. 
use std::fmt; -use std::sync::Arc; use common_query::error::Result; -use datafusion_expr::{Signature, Volatility}; +use datafusion::logical_expr::ColumnarValue; +use datafusion_expr::{ScalarFunctionArgs, Signature, Volatility}; use datatypes::arrow::datatypes::DataType; -use datatypes::prelude::VectorRef; +use datatypes::vectors::{Helper, Vector}; -use crate::function::{Function, FunctionContext}; +use crate::function::{Function, extract_args}; use crate::scalars::expression::{EvalContext, scalar_binary_op}; #[derive(Clone, Default)] @@ -42,14 +42,19 @@ impl Function for TestAndFunction { ) } - fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { + fn invoke_with_args( + &self, + args: ScalarFunctionArgs, + ) -> datafusion_common::Result { + let [arg0, arg1] = extract_args(self.name(), &args)?; + let columns = Helper::try_into_vectors(&[arg0, arg1]).unwrap(); let col = scalar_binary_op::( &columns[0], &columns[1], scalar_and, &mut EvalContext::default(), )?; - Ok(Arc::new(col)) + Ok(ColumnarValue::Array(col.to_arrow_array())) } } diff --git a/src/common/function/src/scalars/udf.rs b/src/common/function/src/scalars/udf.rs index b0aa19d9be..25efcc4cde 100644 --- a/src/common/function/src/scalars/udf.rs +++ b/src/common/function/src/scalars/udf.rs @@ -107,12 +107,11 @@ mod tests { use common_query::prelude::ScalarValue; use datafusion::arrow::array::BooleanArray; + use datafusion_common::arrow::array::AsArray; + use datafusion_common::arrow::datatypes::DataType as ArrowDataType; use datafusion_common::config::ConfigOptions; use datatypes::arrow::datatypes::Field; use datatypes::data_type::{ConcreteDataType, DataType}; - use datatypes::prelude::VectorRef; - use datatypes::value::Value; - use datatypes::vectors::{BooleanVector, ConstantVector}; use session::context::QueryContextBuilder; use super::*; @@ -124,20 +123,31 @@ mod tests { let f = Arc::new(TestAndFunction); let query_ctx = QueryContextBuilder::default().build().into(); - let args: 
Vec = vec![ - Arc::new(ConstantVector::new( - Arc::new(BooleanVector::from(vec![true])), - 3, - )), - Arc::new(BooleanVector::from(vec![true, false, true])), - ]; + let args = ScalarFunctionArgs { + args: vec![ + datafusion_expr::ColumnarValue::Array(Arc::new(BooleanArray::from(vec![ + true, true, true, + ]))), + datafusion_expr::ColumnarValue::Array(Arc::new(BooleanArray::from(vec![ + true, false, true, + ]))), + ], + arg_fields: vec![], + number_rows: 3, + return_field: Arc::new(Field::new("x", ArrowDataType::Boolean, true)), + config_options: Arc::new(Default::default()), + }; - let vector = f.eval(&FunctionContext::default(), &args).unwrap(); + let result = f + .invoke_with_args(args) + .and_then(|x| x.to_array(3)) + .unwrap(); + let vector = result.as_boolean(); assert_eq!(3, vector.len()); - for i in 0..3 { - assert!(matches!(vector.get(i), Value::Boolean(b) if b == (i == 0 || i == 2))); - } + assert!(vector.value(0)); + assert!(!vector.value(1)); + assert!(vector.value(2)); // create a udf and test it again let udf = create_udf(f.clone(), query_ctx, Arc::new(FunctionState::default())); diff --git a/src/common/function/src/system/timezone.rs b/src/common/function/src/system/timezone.rs index 22bc395453..6fd6f9c359 100644 --- a/src/common/function/src/system/timezone.rs +++ b/src/common/function/src/system/timezone.rs @@ -13,15 +13,13 @@ // limitations under the License. use std::fmt::{self}; -use std::sync::Arc; use common_query::error::Result; use datafusion::arrow::datatypes::DataType; -use datafusion_expr::{Signature, Volatility}; -use datatypes::prelude::ScalarVector; -use datatypes::vectors::{StringVector, VectorRef}; +use datafusion_common::ScalarValue; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, Volatility}; -use crate::function::{Function, FunctionContext}; +use crate::function::{Function, find_function_context}; /// A function to return current session timezone. 
#[derive(Clone, Debug, Default)] @@ -35,17 +33,21 @@ impl Function for TimezoneFunction { } fn return_type(&self, _: &[DataType]) -> Result { - Ok(DataType::Utf8) + Ok(DataType::Utf8View) } fn signature(&self) -> Signature { Signature::nullary(Volatility::Immutable) } - fn eval(&self, func_ctx: &FunctionContext, _columns: &[VectorRef]) -> Result { + fn invoke_with_args( + &self, + args: ScalarFunctionArgs, + ) -> datafusion_common::Result { + let func_ctx = find_function_context(&args)?; let tz = func_ctx.query_ctx.timezone().to_string(); - Ok(Arc::new(StringVector::from_slice(&[&tz])) as _) + Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(Some(tz)))) } } @@ -59,14 +61,18 @@ impl fmt::Display for TimezoneFunction { mod tests { use std::sync::Arc; + use arrow_schema::Field; + use datafusion_common::config::ConfigOptions; use session::context::QueryContextBuilder; use super::*; + use crate::function::FunctionContext; + #[test] fn test_build_function() { let build = TimezoneFunction; assert_eq!("timezone", build.name()); - assert_eq!(DataType::Utf8, build.return_type(&[]).unwrap()); + assert_eq!(DataType::Utf8View, build.return_type(&[]).unwrap()); assert_eq!(build.signature(), Signature::nullary(Volatility::Immutable)); let query_ctx = QueryContextBuilder::default().build().into(); @@ -75,8 +81,21 @@ mod tests { query_ctx, ..Default::default() }; - let vector = build.eval(&func_ctx, &[]).unwrap(); - let expect: VectorRef = Arc::new(StringVector::from(vec!["UTC"])); - assert_eq!(expect, vector); + let mut config_options = ConfigOptions::default(); + config_options.extensions.insert(func_ctx); + let config_options = Arc::new(config_options); + + let args = ScalarFunctionArgs { + args: vec![], + arg_fields: vec![], + number_rows: 0, + return_field: Arc::new(Field::new("x", DataType::Utf8View, false)), + config_options, + }; + let result = build.invoke_with_args(args).unwrap(); + let ColumnarValue::Scalar(ScalarValue::Utf8View(Some(s))) = result else { + unreachable!() + 
}; + assert_eq!(s, "UTC"); } } diff --git a/src/common/function/src/system/version.rs b/src/common/function/src/system/version.rs index d1ef84baea..fe9382c886 100644 --- a/src/common/function/src/system/version.rs +++ b/src/common/function/src/system/version.rs @@ -13,15 +13,14 @@ // limitations under the License. use std::fmt; -use std::sync::Arc; use common_query::error::Result; use datafusion::arrow::datatypes::DataType; -use datafusion_expr::{Signature, Volatility}; -use datatypes::vectors::{StringVector, VectorRef}; +use datafusion_common::ScalarValue; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, Volatility}; use session::context::Channel; -use crate::function::{Function, FunctionContext}; +use crate::function::{Function, find_function_context}; #[derive(Clone, Debug, Default)] pub(crate) struct VersionFunction; @@ -38,14 +37,18 @@ impl Function for VersionFunction { } fn return_type(&self, _: &[DataType]) -> Result { - Ok(DataType::Utf8) + Ok(DataType::Utf8View) } fn signature(&self) -> Signature { Signature::nullary(Volatility::Immutable) } - fn eval(&self, func_ctx: &FunctionContext, _columns: &[VectorRef]) -> Result { + fn invoke_with_args( + &self, + args: ScalarFunctionArgs, + ) -> datafusion_common::Result { + let func_ctx = find_function_context(&args)?; let version = match func_ctx.query_ctx.channel() { Channel::Mysql => { format!( @@ -60,7 +63,6 @@ impl Function for VersionFunction { } _ => common_version::version().to_string(), }; - let result = StringVector::from(vec![version]); - Ok(Arc::new(result)) + Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(Some(version)))) } } diff --git a/src/common/recordbatch/src/recordbatch.rs b/src/common/recordbatch/src/recordbatch.rs index f83f51e584..3cc30ce1ba 100644 --- a/src/common/recordbatch/src/recordbatch.rs +++ b/src/common/recordbatch/src/recordbatch.rs @@ -17,6 +17,9 @@ use std::slice; use std::sync::Arc; use datafusion::arrow::util::pretty::pretty_format_batches; +use 
datafusion_common::arrow::array::ArrayRef; +use datafusion_common::arrow::compute; +use datafusion_common::arrow::datatypes::{DataType as ArrowDataType, SchemaRef as ArrowSchemaRef}; use datatypes::arrow::array::RecordBatchOptions; use datatypes::prelude::DataType; use datatypes::schema::SchemaRef; @@ -28,8 +31,8 @@ use snafu::{OptionExt, ResultExt, ensure}; use crate::DfRecordBatch; use crate::error::{ - self, CastVectorSnafu, ColumnNotExistsSnafu, DataTypesSnafu, ProjectArrowRecordBatchSnafu, - Result, + self, ArrowComputeSnafu, CastVectorSnafu, ColumnNotExistsSnafu, DataTypesSnafu, + ProjectArrowRecordBatchSnafu, Result, }; /// A two-dimensional batch of column-oriented data with a defined schema. @@ -49,6 +52,15 @@ impl RecordBatch { let columns: Vec<_> = columns.into_iter().collect(); let arrow_arrays = columns.iter().map(|v| v.to_arrow_array()).collect(); + // Casting the arrays here to match the schema is a temporary solution to support Arrow's + // view array types (`StringViewArray` and `BinaryViewArray`). + // As to "support": the arrays here are created from vectors, which do not have types + // corresponding to view arrays. What we can do is to only cast them. + // As to "temporary": we are planning to use Arrow's RecordBatch directly in the read path. + // The casting here will be removed in the end. + // TODO(LFC): Remove the casting here once `Batch` is no longer used.
+ let arrow_arrays = Self::cast_view_arrays(schema.arrow_schema(), arrow_arrays)?; + let df_record_batch = DfRecordBatch::try_new(schema.arrow_schema().clone(), arrow_arrays) .context(error::NewDfRecordBatchSnafu)?; @@ -59,6 +71,24 @@ impl RecordBatch { }) } + fn cast_view_arrays( + schema: &ArrowSchemaRef, + mut arrays: Vec, + ) -> Result> { + for (f, a) in schema.fields().iter().zip(arrays.iter_mut()) { + let expected = f.data_type(); + let actual = a.data_type(); + if matches!( + (expected, actual), + (ArrowDataType::Utf8View, ArrowDataType::Utf8) + | (ArrowDataType::BinaryView, ArrowDataType::Binary) + ) { + *a = compute::cast(a, expected).context(ArrowComputeSnafu)?; + } + } + Ok(arrays) + } + /// Create an empty [`RecordBatch`] from `schema`. pub fn new_empty(schema: SchemaRef) -> RecordBatch { let df_record_batch = DfRecordBatch::new_empty(schema.arrow_schema().clone()); diff --git a/src/flow/src/transform.rs b/src/flow/src/transform.rs index 8d2c1ec8a4..c17fc9f1b0 100644 --- a/src/flow/src/transform.rs +++ b/src/flow/src/transform.rs @@ -16,14 +16,12 @@ use std::collections::BTreeMap; use std::sync::Arc; -use common_error::ext::BoxedError; -use common_function::function::{FunctionContext, FunctionRef}; +use common_function::function::FunctionRef; use datafusion::arrow::datatypes::{DataType, TimeUnit}; use datafusion_expr::{Signature, Volatility}; use datafusion_substrait::extensions::Extensions; use query::QueryEngine; use serde::{Deserialize, Serialize}; -use snafu::ResultExt; /// note here we are using the `substrait_proto_df` crate from the `substrait` module and /// rename it to `substrait_proto` use substrait::substrait_proto_df as substrait_proto; @@ -31,7 +29,7 @@ use substrait_proto::proto::extensions::SimpleExtensionDeclaration; use substrait_proto::proto::extensions::simple_extension_declaration::MappingType; use crate::adapter::FlownodeContext; -use crate::error::{Error, NotImplementedSnafu, UnexpectedSnafu}; +use crate::error::{Error, 
NotImplementedSnafu}; use crate::expr::{TUMBLE_END, TUMBLE_START}; /// a simple macro to generate a not implemented error macro_rules! not_impl_err { @@ -149,19 +147,6 @@ impl common_function::function::Function for TumbleFunction { fn signature(&self) -> Signature { Signature::variadic_any(Volatility::Immutable) } - - fn eval( - &self, - _func_ctx: &FunctionContext, - _columns: &[datatypes::prelude::VectorRef], - ) -> common_query::error::Result { - UnexpectedSnafu { - reason: "Tumbler function is not implemented for datafusion executor", - } - .fail() - .map_err(BoxedError::new) - .context(common_query::error::ExecuteSnafu) - } } #[cfg(test)] diff --git a/src/query/src/dist_plan/analyzer.rs b/src/query/src/dist_plan/analyzer.rs index c7d73e70ae..e7353a8e4b 100644 --- a/src/query/src/dist_plan/analyzer.rs +++ b/src/query/src/dist_plan/analyzer.rs @@ -114,6 +114,8 @@ impl AnalyzerRule for DistPlannerAnalyzer { // seems to lost track on it: the `ConfigOptions` is recreated with its default values again. // So we create a custom `OptimizerConfig` with the desired `ConfigOptions` // to walk around the issue. + // TODO(LFC): Maybe use DataFusion's `OptimizerContext` again + // once https://github.com/apache/datafusion/pull/17742 is merged. struct OptimizerContext { inner: datafusion_optimizer::OptimizerContext, config: Arc,