refactor: rewrite some UDFs to DataFusion style (part 4) (#7011)

Signed-off-by: luofucong <luofc@foxmail.com>
This commit is contained in:
LFC
2025-09-25 03:50:58 +08:00
committed by GitHub
parent a14c01a807
commit 6d0dd2540e
12 changed files with 710 additions and 507 deletions

View File

@@ -14,18 +14,21 @@
use std::net::{Ipv4Addr, Ipv6Addr};
use std::str::FromStr;
use std::sync::Arc;
use common_query::error::{InvalidFuncArgsSnafu, Result};
use datafusion_common::types;
use datafusion_expr::{Coercion, Signature, TypeSignature, TypeSignatureClass, Volatility};
use datatypes::arrow::datatypes::DataType;
use datatypes::prelude::Value;
use datatypes::scalars::ScalarVectorBuilder;
use datatypes::vectors::{MutableVector, StringVectorBuilder, VectorRef};
use datafusion_common::arrow::array::{Array, AsArray, StringViewBuilder};
use datafusion_common::arrow::compute;
use datafusion_common::arrow::datatypes::{DataType, UInt8Type};
use datafusion_common::{DataFusionError, types};
use datafusion_expr::{
Coercion, ColumnarValue, ScalarFunctionArgs, Signature, TypeSignature, TypeSignatureClass,
Volatility,
};
use derive_more::Display;
use snafu::ensure;
use crate::function::{Function, FunctionContext};
use crate::function::Function;
/// Function that converts an IPv4 address string to CIDR notation.
///
@@ -46,7 +49,7 @@ impl Function for Ipv4ToCidr {
}
fn return_type(&self, _: &[DataType]) -> Result<DataType> {
Ok(DataType::Utf8)
Ok(DataType::Utf8View)
}
fn signature(&self) -> Signature {
@@ -62,19 +65,29 @@ impl Function for Ipv4ToCidr {
)
}
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure!(
columns.len() == 1 || columns.len() == 2,
InvalidFuncArgsSnafu {
err_msg: format!("Expected 1 or 2 arguments, got {}", columns.len())
}
);
fn invoke_with_args(
&self,
args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
if args.args.len() != 1 && args.args.len() != 2 {
return Err(DataFusionError::Execution(format!(
"expecting 1 or 2 arguments, got {}",
args.args.len()
)));
}
let columns = ColumnarValue::values_to_arrays(&args.args)?;
let ip_vec = &columns[0];
let mut results = StringVectorBuilder::with_capacity(ip_vec.len());
let mut builder = StringViewBuilder::with_capacity(ip_vec.len());
let arg0 = compute::cast(ip_vec, &DataType::Utf8View)?;
let ip_vec = arg0.as_string_view();
let has_subnet_arg = columns.len() == 2;
let subnet_vec = if has_subnet_arg {
let maybe_arg1 = if columns.len() > 1 {
Some(compute::cast(&columns[1], &DataType::UInt8)?)
} else {
None
};
let subnets = if let Some(arg1) = maybe_arg1.as_ref() {
ensure!(
columns[1].len() == ip_vec.len(),
InvalidFuncArgsSnafu {
@@ -83,23 +96,19 @@ impl Function for Ipv4ToCidr {
.to_string()
}
);
Some(&columns[1])
Some(arg1.as_primitive::<UInt8Type>())
} else {
None
};
for i in 0..ip_vec.len() {
let ip_str = ip_vec.get(i);
let subnet = subnet_vec.map(|v| v.get(i));
let ip_str = ip_vec.is_valid(i).then(|| ip_vec.value(i));
let subnet = subnets.and_then(|v| v.is_valid(i).then(|| v.value(i)));
let cidr = match (ip_str, subnet) {
(Value::String(s), Some(Value::UInt8(mask))) => {
let ip_str = s.as_utf8().trim();
(Some(ip_str), Some(mask)) => {
if ip_str.is_empty() {
return InvalidFuncArgsSnafu {
err_msg: "Empty IPv4 address".to_string(),
}
.fail();
return Err(DataFusionError::Execution("empty IPv4 address".to_string()));
}
let ip_addr = complete_and_parse_ipv4(ip_str)?;
@@ -109,13 +118,9 @@ impl Function for Ipv4ToCidr {
Some(format!("{}/{}", masked_ip, mask))
}
(Value::String(s), None) => {
let ip_str = s.as_utf8().trim();
(Some(ip_str), None) => {
if ip_str.is_empty() {
return InvalidFuncArgsSnafu {
err_msg: "Empty IPv4 address".to_string(),
}
.fail();
return Err(DataFusionError::Execution("empty IPv4 address".to_string()));
}
let ip_addr = complete_and_parse_ipv4(ip_str)?;
@@ -149,10 +154,10 @@ impl Function for Ipv4ToCidr {
_ => None,
};
results.push(cidr.as_deref());
builder.append_option(cidr.as_deref());
}
Ok(results.to_vector())
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
}
}
@@ -175,7 +180,7 @@ impl Function for Ipv6ToCidr {
}
fn return_type(&self, _: &[DataType]) -> Result<DataType> {
Ok(DataType::Utf8)
Ok(DataType::Utf8View)
}
fn signature(&self) -> Signature {
@@ -188,37 +193,41 @@ impl Function for Ipv6ToCidr {
)
}
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure!(
columns.len() == 1 || columns.len() == 2,
InvalidFuncArgsSnafu {
err_msg: format!("Expected 1 or 2 arguments, got {}", columns.len())
}
);
fn invoke_with_args(
&self,
args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
if args.args.len() != 1 && args.args.len() != 2 {
return Err(DataFusionError::Execution(format!(
"expecting 1 or 2 arguments, got {}",
args.args.len()
)));
}
let columns = ColumnarValue::values_to_arrays(&args.args)?;
let ip_vec = &columns[0];
let size = ip_vec.len();
let mut results = StringVectorBuilder::with_capacity(size);
let mut builder = StringViewBuilder::with_capacity(size);
let arg0 = compute::cast(ip_vec, &DataType::Utf8View)?;
let ip_vec = arg0.as_string_view();
let has_subnet_arg = columns.len() == 2;
let subnet_vec = if has_subnet_arg {
Some(&columns[1])
let maybe_arg1 = if columns.len() > 1 {
Some(compute::cast(&columns[1], &DataType::UInt8)?)
} else {
None
};
let subnets = maybe_arg1
.as_ref()
.map(|arg1| arg1.as_primitive::<UInt8Type>());
for i in 0..size {
let ip_str = ip_vec.get(i);
let subnet = subnet_vec.map(|v| v.get(i));
let ip_str = ip_vec.is_valid(i).then(|| ip_vec.value(i));
let subnet = subnets.and_then(|v| v.is_valid(i).then(|| v.value(i)));
let cidr = match (ip_str, subnet) {
(Value::String(s), Some(Value::UInt8(mask))) => {
let ip_str = s.as_utf8().trim();
(Some(ip_str), Some(mask)) => {
if ip_str.is_empty() {
return InvalidFuncArgsSnafu {
err_msg: "Empty IPv6 address".to_string(),
}
.fail();
return Err(DataFusionError::Execution("empty IPv6 address".to_string()));
}
let ip_addr = complete_and_parse_ipv6(ip_str)?;
@@ -228,13 +237,9 @@ impl Function for Ipv6ToCidr {
Some(format!("{}/{}", masked_ip, mask))
}
(Value::String(s), None) => {
let ip_str = s.as_utf8().trim();
(Some(ip_str), None) => {
if ip_str.is_empty() {
return InvalidFuncArgsSnafu {
err_msg: "Empty IPv6 address".to_string(),
}
.fail();
return Err(DataFusionError::Execution("empty IPv6 address".to_string()));
}
let ip_addr = complete_and_parse_ipv6(ip_str)?;
@@ -250,10 +255,10 @@ impl Function for Ipv6ToCidr {
_ => None,
};
results.push(cidr.as_deref());
builder.append_option(cidr.as_deref());
}
Ok(results.to_vector())
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
}
}
@@ -375,108 +380,148 @@ fn auto_detect_ipv6_subnet(addr: &Ipv6Addr) -> u8 {
#[cfg(test)]
mod tests {
use std::sync::Arc;
use datatypes::scalars::ScalarVector;
use datatypes::vectors::{StringVector, UInt8Vector};
use arrow_schema::Field;
use datafusion_common::arrow::array::{StringViewArray, UInt8Array};
use super::*;
#[test]
fn test_ipv4_to_cidr_auto() {
let func = Ipv4ToCidr;
let ctx = FunctionContext::default();
// Test data with auto subnet detection
let values = vec!["192.168.1.0", "10.0.0.0", "172.16", "192"];
let input = Arc::new(StringVector::from_slice(&values)) as VectorRef;
let arg0 = ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(&values)));
let result = func.eval(&ctx, &[input]).unwrap();
let result = result.as_any().downcast_ref::<StringVector>().unwrap();
let args = ScalarFunctionArgs {
args: vec![arg0],
arg_fields: vec![],
number_rows: 4,
return_field: Arc::new(Field::new("x", DataType::Utf8View, false)),
config_options: Arc::new(Default::default()),
};
let result = func.invoke_with_args(args).unwrap();
let result = result.to_array(4).unwrap();
let result = result.as_string_view();
assert_eq!(result.get_data(0).unwrap(), "192.168.1.0/24");
assert_eq!(result.get_data(1).unwrap(), "10.0.0.0/8");
assert_eq!(result.get_data(2).unwrap(), "172.16.0.0/16");
assert_eq!(result.get_data(3).unwrap(), "192.0.0.0/8");
assert_eq!(result.value(0), "192.168.1.0/24");
assert_eq!(result.value(1), "10.0.0.0/8");
assert_eq!(result.value(2), "172.16.0.0/16");
assert_eq!(result.value(3), "192.0.0.0/8");
}
#[test]
fn test_ipv4_to_cidr_with_subnet() {
let func = Ipv4ToCidr;
let ctx = FunctionContext::default();
// Test data with explicit subnet
let ip_values = vec!["192.168.1.1", "10.0.0.1", "172.16.5.5"];
let subnet_values = vec![24u8, 16u8, 12u8];
let ip_input = Arc::new(StringVector::from_slice(&ip_values)) as VectorRef;
let subnet_input = Arc::new(UInt8Vector::from_vec(subnet_values)) as VectorRef;
let arg0 = ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(&ip_values)));
let arg1 = ColumnarValue::Array(Arc::new(UInt8Array::from(subnet_values)));
let result = func.eval(&ctx, &[ip_input, subnet_input]).unwrap();
let result = result.as_any().downcast_ref::<StringVector>().unwrap();
let args = ScalarFunctionArgs {
args: vec![arg0, arg1],
arg_fields: vec![],
number_rows: 3,
return_field: Arc::new(Field::new("x", DataType::Utf8View, false)),
config_options: Arc::new(Default::default()),
};
let result = func.invoke_with_args(args).unwrap();
let result = result.to_array(3).unwrap();
let result = result.as_string_view();
assert_eq!(result.get_data(0).unwrap(), "192.168.1.0/24");
assert_eq!(result.get_data(1).unwrap(), "10.0.0.0/16");
assert_eq!(result.get_data(2).unwrap(), "172.16.0.0/12");
assert_eq!(result.value(0), "192.168.1.0/24");
assert_eq!(result.value(1), "10.0.0.0/16");
assert_eq!(result.value(2), "172.16.0.0/12");
}
#[test]
fn test_ipv6_to_cidr_auto() {
let func = Ipv6ToCidr;
let ctx = FunctionContext::default();
// Test data with auto subnet detection
let values = vec!["2001:db8::", "2001:db8", "fe80::1", "::1"];
let input = Arc::new(StringVector::from_slice(&values)) as VectorRef;
let arg0 = ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(&values)));
let result = func.eval(&ctx, &[input]).unwrap();
let result = result.as_any().downcast_ref::<StringVector>().unwrap();
let args = ScalarFunctionArgs {
args: vec![arg0],
arg_fields: vec![],
number_rows: 4,
return_field: Arc::new(Field::new("x", DataType::Utf8View, false)),
config_options: Arc::new(Default::default()),
};
let result = func.invoke_with_args(args).unwrap();
let result = result.to_array(4).unwrap();
let result = result.as_string_view();
assert_eq!(result.get_data(0).unwrap(), "2001:db8::/32");
assert_eq!(result.get_data(1).unwrap(), "2001:db8::/32");
assert_eq!(result.get_data(2).unwrap(), "fe80::/16");
assert_eq!(result.get_data(3).unwrap(), "::1/128"); // Special case for ::1
assert_eq!(result.value(0), "2001:db8::/32");
assert_eq!(result.value(1), "2001:db8::/32");
assert_eq!(result.value(2), "fe80::/16");
assert_eq!(result.value(3), "::1/128"); // Special case for ::1
}
#[test]
fn test_ipv6_to_cidr_with_subnet() {
let func = Ipv6ToCidr;
let ctx = FunctionContext::default();
// Test data with explicit subnet
let ip_values = vec!["2001:db8::", "fe80::1", "2001:db8:1234::"];
let subnet_values = vec![48u8, 10u8, 56u8];
let ip_input = Arc::new(StringVector::from_slice(&ip_values)) as VectorRef;
let subnet_input = Arc::new(UInt8Vector::from_vec(subnet_values)) as VectorRef;
let arg0 = ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(&ip_values)));
let arg1 = ColumnarValue::Array(Arc::new(UInt8Array::from(subnet_values)));
let result = func.eval(&ctx, &[ip_input, subnet_input]).unwrap();
let result = result.as_any().downcast_ref::<StringVector>().unwrap();
let args = ScalarFunctionArgs {
args: vec![arg0, arg1],
arg_fields: vec![],
number_rows: 3,
return_field: Arc::new(Field::new("x", DataType::Utf8View, false)),
config_options: Arc::new(Default::default()),
};
let result = func.invoke_with_args(args).unwrap();
let result = result.to_array(3).unwrap();
let result = result.as_string_view();
assert_eq!(result.get_data(0).unwrap(), "2001:db8::/48");
assert_eq!(result.get_data(1).unwrap(), "fe80::/10");
assert_eq!(result.get_data(2).unwrap(), "2001:db8:1234::/56");
assert_eq!(result.value(0), "2001:db8::/48");
assert_eq!(result.value(1), "fe80::/10");
assert_eq!(result.value(2), "2001:db8:1234::/56");
}
#[test]
fn test_invalid_inputs() {
let ipv4_func = Ipv4ToCidr;
let ipv6_func = Ipv6ToCidr;
let ctx = FunctionContext::default();
// Empty string should fail
let empty_values = vec![""];
let empty_input = Arc::new(StringVector::from_slice(&empty_values)) as VectorRef;
let arg0 = ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(&empty_values)));
let ipv4_result = ipv4_func.eval(&ctx, std::slice::from_ref(&empty_input));
let ipv6_result = ipv6_func.eval(&ctx, std::slice::from_ref(&empty_input));
let args = ScalarFunctionArgs {
args: vec![arg0],
arg_fields: vec![],
number_rows: 1,
return_field: Arc::new(Field::new("x", DataType::Utf8View, false)),
config_options: Arc::new(Default::default()),
};
let ipv4_result = ipv4_func.invoke_with_args(args.clone());
let ipv6_result = ipv6_func.invoke_with_args(args);
assert!(ipv4_result.is_err());
assert!(ipv6_result.is_err());
// Invalid IP formats should fail
let invalid_values = vec!["not an ip", "192.168.1.256", "zzzz::ffff"];
let invalid_input = Arc::new(StringVector::from_slice(&invalid_values)) as VectorRef;
let arg0 =
ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(&invalid_values)));
let ipv4_result = ipv4_func.eval(&ctx, std::slice::from_ref(&invalid_input));
let args = ScalarFunctionArgs {
args: vec![arg0],
arg_fields: vec![],
number_rows: 3,
return_field: Arc::new(Field::new("x", DataType::Utf8View, false)),
config_options: Arc::new(Default::default()),
};
let ipv4_result = ipv4_func.invoke_with_args(args);
assert!(ipv4_result.is_err());
}

View File

@@ -14,16 +14,16 @@
use std::net::Ipv4Addr;
use std::str::FromStr;
use std::sync::Arc;
use common_query::error::{InvalidFuncArgsSnafu, Result};
use datafusion_expr::{Signature, TypeSignature, Volatility};
use datatypes::arrow::datatypes::DataType;
use datatypes::scalars::ScalarVectorBuilder;
use datatypes::vectors::{MutableVector, StringVectorBuilder, UInt32VectorBuilder, VectorRef};
use datafusion_common::arrow::array::{Array, AsArray, StringViewBuilder, UInt32Builder};
use datafusion_common::arrow::compute;
use datafusion_common::arrow::datatypes::{DataType, UInt32Type};
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, TypeSignature, Volatility};
use derive_more::Display;
use snafu::ensure;
use crate::function::{Function, FunctionContext};
use crate::function::{Function, extract_args};
/// Function that converts a UInt32 number to an IPv4 address string.
///
@@ -53,7 +53,7 @@ impl Function for Ipv4NumToString {
}
fn return_type(&self, _: &[DataType]) -> Result<DataType> {
Ok(DataType::Utf8)
Ok(DataType::Utf8View)
}
fn signature(&self) -> Signature {
@@ -63,22 +63,20 @@ impl Function for Ipv4NumToString {
)
}
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure!(
columns.len() == 1,
InvalidFuncArgsSnafu {
err_msg: format!("Expected 1 argument, got {}", columns.len())
}
);
fn invoke_with_args(
&self,
args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
let [arg0] = extract_args(self.name(), &args)?;
let uint_vec = arg0.as_primitive::<UInt32Type>();
let uint_vec = &columns[0];
let size = uint_vec.len();
let mut results = StringVectorBuilder::with_capacity(size);
let mut builder = StringViewBuilder::with_capacity(size);
for i in 0..size {
let ip_num = uint_vec.get(i);
let ip_num = uint_vec.is_valid(i).then(|| uint_vec.value(i));
let ip_str = match ip_num {
datatypes::value::Value::UInt32(num) => {
Some(num) => {
// Convert UInt32 to IPv4 string (A.B.C.D format)
let a = (num >> 24) & 0xFF;
let b = (num >> 16) & 0xFF;
@@ -89,10 +87,10 @@ impl Function for Ipv4NumToString {
_ => None,
};
results.push(ip_str.as_deref());
builder.append_option(ip_str.as_deref());
}
Ok(results.to_vector())
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
}
fn aliases(&self) -> &[String] {
@@ -123,23 +121,21 @@ impl Function for Ipv4StringToNum {
Signature::string(1, Volatility::Immutable)
}
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure!(
columns.len() == 1,
InvalidFuncArgsSnafu {
err_msg: format!("Expected 1 argument, got {}", columns.len())
}
);
fn invoke_with_args(
&self,
args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
let [arg0] = extract_args(self.name(), &args)?;
let ip_vec = &columns[0];
let arg0 = compute::cast(&arg0, &DataType::Utf8View)?;
let ip_vec = arg0.as_string_view();
let size = ip_vec.len();
let mut results = UInt32VectorBuilder::with_capacity(size);
let mut builder = UInt32Builder::with_capacity(size);
for i in 0..size {
let ip_str = ip_vec.get(i);
let ip_str = ip_vec.is_valid(i).then(|| ip_vec.value(i));
let ip_num = match ip_str {
datatypes::value::Value::String(s) => {
let ip_str = s.as_utf8();
Some(ip_str) => {
let ip_addr = Ipv4Addr::from_str(ip_str).map_err(|_| {
InvalidFuncArgsSnafu {
err_msg: format!("Invalid IPv4 address format: {}", ip_str),
@@ -151,10 +147,10 @@ impl Function for Ipv4StringToNum {
_ => None,
};
results.push(ip_num);
builder.append_option(ip_num);
}
Ok(results.to_vector())
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
}
}
@@ -162,66 +158,92 @@ impl Function for Ipv4StringToNum {
mod tests {
use std::sync::Arc;
use datatypes::scalars::ScalarVector;
use datatypes::vectors::{StringVector, UInt32Vector};
use arrow_schema::Field;
use datafusion_common::arrow::array::{StringViewArray, UInt32Array};
use super::*;
#[test]
fn test_ipv4_num_to_string() {
let func = Ipv4NumToString::default();
let ctx = FunctionContext::default();
// Test data
let values = vec![167772161u32, 3232235521u32, 0u32, 4294967295u32];
let input = Arc::new(UInt32Vector::from_vec(values)) as VectorRef;
let input = ColumnarValue::Array(Arc::new(UInt32Array::from(values)));
let result = func.eval(&ctx, &[input]).unwrap();
let result = result.as_any().downcast_ref::<StringVector>().unwrap();
let args = ScalarFunctionArgs {
args: vec![input],
arg_fields: vec![],
number_rows: 4,
return_field: Arc::new(Field::new("x", DataType::Utf8View, false)),
config_options: Arc::new(Default::default()),
};
let result = func.invoke_with_args(args).unwrap();
let result = result.to_array(4).unwrap();
let result = result.as_string_view();
assert_eq!(result.get_data(0).unwrap(), "10.0.0.1");
assert_eq!(result.get_data(1).unwrap(), "192.168.0.1");
assert_eq!(result.get_data(2).unwrap(), "0.0.0.0");
assert_eq!(result.get_data(3).unwrap(), "255.255.255.255");
assert_eq!(result.value(0), "10.0.0.1");
assert_eq!(result.value(1), "192.168.0.1");
assert_eq!(result.value(2), "0.0.0.0");
assert_eq!(result.value(3), "255.255.255.255");
}
#[test]
fn test_ipv4_string_to_num() {
let func = Ipv4StringToNum;
let ctx = FunctionContext::default();
// Test data
let values = vec!["10.0.0.1", "192.168.0.1", "0.0.0.0", "255.255.255.255"];
let input = Arc::new(StringVector::from_slice(&values)) as VectorRef;
let input = ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(&values)));
let result = func.eval(&ctx, &[input]).unwrap();
let result = result.as_any().downcast_ref::<UInt32Vector>().unwrap();
let args = ScalarFunctionArgs {
args: vec![input],
arg_fields: vec![],
number_rows: 4,
return_field: Arc::new(Field::new("x", DataType::UInt32, false)),
config_options: Arc::new(Default::default()),
};
let result = func.invoke_with_args(args).unwrap();
let result = result.to_array(4).unwrap();
let result = result.as_primitive::<UInt32Type>();
assert_eq!(result.get_data(0).unwrap(), 167772161);
assert_eq!(result.get_data(1).unwrap(), 3232235521);
assert_eq!(result.get_data(2).unwrap(), 0);
assert_eq!(result.get_data(3).unwrap(), 4294967295);
assert_eq!(result.value(0), 167772161);
assert_eq!(result.value(1), 3232235521);
assert_eq!(result.value(2), 0);
assert_eq!(result.value(3), 4294967295);
}
#[test]
fn test_ipv4_conversions_roundtrip() {
let to_num = Ipv4StringToNum;
let to_string = Ipv4NumToString::default();
let ctx = FunctionContext::default();
// Test data for string to num to string
let values = vec!["10.0.0.1", "192.168.0.1", "0.0.0.0", "255.255.255.255"];
let input = Arc::new(StringVector::from_slice(&values)) as VectorRef;
let input = ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(&values)));
let num_result = to_num.eval(&ctx, &[input]).unwrap();
let back_to_string = to_string.eval(&ctx, &[num_result]).unwrap();
let str_result = back_to_string
.as_any()
.downcast_ref::<StringVector>()
.unwrap();
let args = ScalarFunctionArgs {
args: vec![input],
arg_fields: vec![],
number_rows: 4,
return_field: Arc::new(Field::new("x", DataType::UInt32, false)),
config_options: Arc::new(Default::default()),
};
let result = to_num.invoke_with_args(args).unwrap();
let args = ScalarFunctionArgs {
args: vec![result],
arg_fields: vec![],
number_rows: 4,
return_field: Arc::new(Field::new("x", DataType::Utf8View, false)),
config_options: Arc::new(Default::default()),
};
let result = to_string.invoke_with_args(args).unwrap();
let result = result.to_array(4).unwrap();
let result = result.as_string_view();
for (i, expected) in values.iter().enumerate() {
assert_eq!(str_result.get_data(i).unwrap(), *expected);
assert_eq!(result.value(i), *expected);
}
}
}

View File

@@ -14,17 +14,17 @@
use std::net::{Ipv4Addr, Ipv6Addr};
use std::str::FromStr;
use std::sync::Arc;
use common_query::error::{InvalidFuncArgsSnafu, Result};
use datafusion::arrow::datatypes::DataType;
use datafusion_expr::{Signature, Volatility};
use datatypes::prelude::Value;
use datatypes::scalars::ScalarVectorBuilder;
use datatypes::vectors::{BinaryVectorBuilder, MutableVector, StringVectorBuilder, VectorRef};
use datafusion_common::DataFusionError;
use datafusion_common::arrow::array::{Array, AsArray, BinaryViewBuilder, StringViewBuilder};
use datafusion_common::arrow::compute;
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, Volatility};
use derive_more::Display;
use snafu::ensure;
use crate::function::{Function, FunctionContext};
use crate::function::{Function, extract_args};
/// Function that converts a hex string representation of an IPv6 address to a formatted string.
///
@@ -41,30 +41,29 @@ impl Function for Ipv6NumToString {
}
fn return_type(&self, _: &[DataType]) -> Result<DataType> {
Ok(DataType::Utf8)
Ok(DataType::Utf8View)
}
fn signature(&self) -> Signature {
Signature::string(1, Volatility::Immutable)
}
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure!(
columns.len() == 1,
InvalidFuncArgsSnafu {
err_msg: format!("Expected 1 argument, got {}", columns.len())
}
);
fn invoke_with_args(
&self,
args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
let [arg0] = extract_args(self.name(), &args)?;
let hex_vec = &columns[0];
let arg0 = compute::cast(&arg0, &DataType::Utf8View)?;
let hex_vec = arg0.as_string_view();
let size = hex_vec.len();
let mut results = StringVectorBuilder::with_capacity(size);
let mut builder = StringViewBuilder::with_capacity(size);
for i in 0..size {
let hex_str = hex_vec.get(i);
let hex_str = hex_vec.is_valid(i).then(|| hex_vec.value(i));
let ip_str = match hex_str {
Value::String(s) => {
let hex_str = s.as_utf8().to_lowercase();
Some(s) => {
let hex_str = s.to_lowercase();
// Validate and convert hex string to bytes
let bytes = if hex_str.len() == 32 {
@@ -80,10 +79,10 @@ impl Function for Ipv6NumToString {
}
bytes
} else {
return InvalidFuncArgsSnafu {
err_msg: format!("Expected 32 hex characters, got {}", hex_str.len()),
}
.fail();
return Err(DataFusionError::Execution(format!(
"expecting 32 hex characters, got {}",
hex_str.len()
)));
};
// Convert bytes to IPv6 address
@@ -106,10 +105,10 @@ impl Function for Ipv6NumToString {
_ => None,
};
results.push(ip_str.as_deref());
builder.append_option(ip_str.as_deref());
}
Ok(results.to_vector())
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
}
}
@@ -130,31 +129,28 @@ impl Function for Ipv6StringToNum {
}
fn return_type(&self, _: &[DataType]) -> Result<DataType> {
Ok(DataType::Binary)
Ok(DataType::BinaryView)
}
fn signature(&self) -> Signature {
Signature::string(1, Volatility::Immutable)
}
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure!(
columns.len() == 1,
InvalidFuncArgsSnafu {
err_msg: format!("Expected 1 argument, got {}", columns.len())
}
);
fn invoke_with_args(
&self,
args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
let [arg0] = extract_args(self.name(), &args)?;
let arg0 = compute::cast(&arg0, &DataType::Utf8View)?;
let ip_vec = arg0.as_string_view();
let ip_vec = &columns[0];
let size = ip_vec.len();
let mut results = BinaryVectorBuilder::with_capacity(size);
let mut builder = BinaryViewBuilder::with_capacity(size);
for i in 0..size {
let ip_str = ip_vec.get(i);
let ip_str = ip_vec.is_valid(i).then(|| ip_vec.value(i));
let ip_binary = match ip_str {
Value::String(s) => {
let addr_str = s.as_utf8();
Some(addr_str) => {
let addr = if let Ok(ipv6) = Ipv6Addr::from_str(addr_str) {
// Direct IPv6 address
ipv6
@@ -163,10 +159,10 @@ impl Function for Ipv6StringToNum {
ipv4.to_ipv6_mapped()
} else {
// Invalid format
return InvalidFuncArgsSnafu {
err_msg: format!("Invalid IPv6 address format: {}", addr_str),
}
.fail();
return Err(DataFusionError::Execution(format!(
"Invalid IPv6 address format: {}",
addr_str
)));
};
// Convert IPv6 address to binary (16 bytes)
@@ -176,10 +172,10 @@ impl Function for Ipv6StringToNum {
_ => None,
};
results.push(ip_binary.as_deref());
builder.append_option(ip_binary.as_deref());
}
Ok(results.to_vector())
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
}
}
@@ -188,15 +184,14 @@ mod tests {
use std::fmt::Write;
use std::sync::Arc;
use datatypes::scalars::ScalarVector;
use datatypes::vectors::{BinaryVector, StringVector, Vector};
use arrow_schema::Field;
use datafusion_common::arrow::array::StringViewArray;
use super::*;
#[test]
fn test_ipv6_num_to_string() {
let func = Ipv6NumToString;
let ctx = FunctionContext::default();
// Hex string for "2001:db8::1"
let hex_str1 = "20010db8000000000000000000000001";
@@ -205,62 +200,93 @@ mod tests {
let hex_str2 = "00000000000000000000ffffc0a80001";
let values = vec![hex_str1, hex_str2];
let input = Arc::new(StringVector::from_slice(&values)) as VectorRef;
let arg0 = ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(&values)));
let result = func.eval(&ctx, &[input]).unwrap();
let result = result.as_any().downcast_ref::<StringVector>().unwrap();
let args = ScalarFunctionArgs {
args: vec![arg0],
arg_fields: vec![],
number_rows: 2,
return_field: Arc::new(Field::new("x", DataType::Utf8View, false)),
config_options: Arc::new(Default::default()),
};
let result = func.invoke_with_args(args).unwrap();
let result = result.to_array(2).unwrap();
let result = result.as_string_view();
assert_eq!(result.get_data(0).unwrap(), "2001:db8::1");
assert_eq!(result.get_data(1).unwrap(), "::ffff:192.168.0.1");
assert_eq!(result.value(0), "2001:db8::1");
assert_eq!(result.value(1), "::ffff:192.168.0.1");
}
#[test]
fn test_ipv6_num_to_string_uppercase() {
let func = Ipv6NumToString;
let ctx = FunctionContext::default();
// Uppercase hex string for "2001:db8::1"
let hex_str = "20010DB8000000000000000000000001";
let values = vec![hex_str];
let input = Arc::new(StringVector::from_slice(&values)) as VectorRef;
let arg0 = ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(&values)));
let result = func.eval(&ctx, &[input]).unwrap();
let result = result.as_any().downcast_ref::<StringVector>().unwrap();
let args = ScalarFunctionArgs {
args: vec![arg0],
arg_fields: vec![],
number_rows: 1,
return_field: Arc::new(Field::new("x", DataType::Utf8View, false)),
config_options: Arc::new(Default::default()),
};
let result = func.invoke_with_args(args).unwrap();
let result = result.to_array(1).unwrap();
let result = result.as_string_view();
assert_eq!(result.get_data(0).unwrap(), "2001:db8::1");
assert_eq!(result.value(0), "2001:db8::1");
}
#[test]
fn test_ipv6_num_to_string_error() {
let func = Ipv6NumToString;
let ctx = FunctionContext::default();
// Invalid hex string - wrong length
let hex_str = "20010db8";
let values = vec![hex_str];
let input = Arc::new(StringVector::from_slice(&values)) as VectorRef;
let arg0 = ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(&values)));
// Should return an error
let result = func.eval(&ctx, &[input]);
let args = ScalarFunctionArgs {
args: vec![arg0],
arg_fields: vec![],
number_rows: 2,
return_field: Arc::new(Field::new("x", DataType::Utf8View, false)),
config_options: Arc::new(Default::default()),
};
let result = func.invoke_with_args(args);
assert!(result.is_err());
// Check that the error message contains expected text
let error_msg = result.unwrap_err().to_string();
assert!(error_msg.contains("Expected 32 hex characters"));
assert_eq!(
error_msg,
"Execution error: expecting 32 hex characters, got 8"
);
}
#[test]
fn test_ipv6_string_to_num() {
let func = Ipv6StringToNum;
let ctx = FunctionContext::default();
let values = vec!["2001:db8::1", "::ffff:192.168.0.1", "192.168.0.1"];
let input = Arc::new(StringVector::from_slice(&values)) as VectorRef;
let arg0 = ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(&values)));
let result = func.eval(&ctx, &[input]).unwrap();
let result = result.as_any().downcast_ref::<BinaryVector>().unwrap();
let args = ScalarFunctionArgs {
args: vec![arg0],
arg_fields: vec![],
number_rows: 3,
return_field: Arc::new(Field::new("x", DataType::Utf8View, false)),
config_options: Arc::new(Default::default()),
};
let result = func.invoke_with_args(args).unwrap();
let result = result.to_array(3).unwrap();
let result = result.as_binary_view();
// Expected binary for "2001:db8::1"
let expected_1 = [
@@ -272,33 +298,37 @@ mod tests {
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xFF, 0xFF, 0xC0, 0xA8, 0, 0x01,
];
assert_eq!(result.get_data(0).unwrap(), &expected_1);
assert_eq!(result.get_data(1).unwrap(), &expected_2);
assert_eq!(result.get_data(2).unwrap(), &expected_2);
assert_eq!(result.value(0), &expected_1);
assert_eq!(result.value(1), &expected_2);
assert_eq!(result.value(2), &expected_2);
}
#[test]
fn test_ipv6_conversions_roundtrip() {
let to_num = Ipv6StringToNum;
let to_string = Ipv6NumToString;
let ctx = FunctionContext::default();
// Test data
let values = vec!["2001:db8::1", "::ffff:192.168.0.1"];
let input = Arc::new(StringVector::from_slice(&values)) as VectorRef;
let arg0 = ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(&values)));
// Convert IPv6 addresses to binary
let binary_result = to_num.eval(&ctx, std::slice::from_ref(&input)).unwrap();
let args = ScalarFunctionArgs {
args: vec![arg0],
arg_fields: vec![],
number_rows: 2,
return_field: Arc::new(Field::new("x", DataType::BinaryView, false)),
config_options: Arc::new(Default::default()),
};
let result = to_num.invoke_with_args(args).unwrap();
// Convert binary to hex string representation (for ipv6_num_to_string)
let mut hex_strings = Vec::new();
let binary_vector = binary_result
.as_any()
.downcast_ref::<BinaryVector>()
.unwrap();
let result = result.to_array(2).unwrap();
let binary_vector = result.as_binary_view();
for i in 0..binary_vector.len() {
let bytes = binary_vector.get_data(i).unwrap();
let bytes = binary_vector.value(i);
let hex = bytes.iter().fold(String::new(), |mut acc, b| {
write!(&mut acc, "{:02x}", b).unwrap();
acc
@@ -307,18 +337,23 @@ mod tests {
}
let hex_str_refs: Vec<&str> = hex_strings.iter().map(|s| s.as_str()).collect();
let hex_input = Arc::new(StringVector::from_slice(&hex_str_refs)) as VectorRef;
let arg0 = ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(&hex_str_refs)));
// Now convert hex to formatted string
let string_result = to_string.eval(&ctx, &[hex_input]).unwrap();
let str_result = string_result
.as_any()
.downcast_ref::<StringVector>()
.unwrap();
let args = ScalarFunctionArgs {
args: vec![arg0],
arg_fields: vec![],
number_rows: 2,
return_field: Arc::new(Field::new("x", DataType::Utf8View, false)),
config_options: Arc::new(Default::default()),
};
let result = to_string.invoke_with_args(args).unwrap();
let result = result.to_array(2).unwrap();
let result = result.as_string_view();
// Compare with original input
assert_eq!(str_result.get_data(0).unwrap(), values[0]);
assert_eq!(str_result.get_data(1).unwrap(), values[1]);
assert_eq!(result.value(0), values[0]);
assert_eq!(result.value(1), values[1]);
}
#[test]
@@ -327,24 +362,35 @@ mod tests {
// can be converted back using ipv6_string_to_num
let to_string = Ipv6NumToString;
let to_binary = Ipv6StringToNum;
let ctx = FunctionContext::default();
// Hex representation of IPv6 addresses
let hex_values = vec![
"20010db8000000000000000000000001",
"00000000000000000000ffffc0a80001",
];
let hex_input = Arc::new(StringVector::from_slice(&hex_values)) as VectorRef;
let arg0 = ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(&hex_values)));
// Convert hex to string representation
let string_result = to_string.eval(&ctx, &[hex_input]).unwrap();
let args = ScalarFunctionArgs {
args: vec![arg0],
arg_fields: vec![],
number_rows: 2,
return_field: Arc::new(Field::new("x", DataType::Utf8View, false)),
config_options: Arc::new(Default::default()),
};
let result = to_string.invoke_with_args(args).unwrap();
// Then convert string representation back to binary
let binary_result = to_binary.eval(&ctx, &[string_result]).unwrap();
let bin_result = binary_result
.as_any()
.downcast_ref::<BinaryVector>()
.unwrap();
let args = ScalarFunctionArgs {
args: vec![result],
arg_fields: vec![],
number_rows: 2,
return_field: Arc::new(Field::new("x", DataType::BinaryView, false)),
config_options: Arc::new(Default::default()),
};
let result = to_binary.invoke_with_args(args).unwrap();
let result = result.to_array(2).unwrap();
let result = result.as_binary_view();
// Expected binary values
let expected_bin1 = [
@@ -354,7 +400,7 @@ mod tests {
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xFF, 0xFF, 0xC0, 0xA8, 0, 0x01,
];
assert_eq!(bin_result.get_data(0).unwrap(), &expected_bin1);
assert_eq!(bin_result.get_data(1).unwrap(), &expected_bin2);
assert_eq!(result.value(0), &expected_bin1);
assert_eq!(result.value(1), &expected_bin2);
}
}

View File

@@ -14,17 +14,18 @@
use std::net::{Ipv4Addr, Ipv6Addr};
use std::str::FromStr;
use std::sync::Arc;
use common_query::error::{InvalidFuncArgsSnafu, Result};
use datafusion::arrow::datatypes::DataType;
use datafusion_expr::{Signature, Volatility};
use datatypes::prelude::Value;
use datatypes::scalars::ScalarVectorBuilder;
use datatypes::vectors::{BooleanVectorBuilder, MutableVector, VectorRef};
use datafusion_common::DataFusionError;
use datafusion_common::arrow::array::{Array, AsArray, BooleanBuilder};
use datafusion_common::arrow::compute;
use datafusion_common::arrow::datatypes::DataType;
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, Volatility};
use derive_more::Display;
use snafu::ensure;
use crate::function::{Function, FunctionContext};
use crate::function::{Function, extract_args};
/// Function that checks if an IPv4 address is within a specified CIDR range.
///
@@ -52,42 +53,31 @@ impl Function for Ipv4InRange {
Signature::string(2, Volatility::Immutable)
}
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure!(
columns.len() == 2,
InvalidFuncArgsSnafu {
err_msg: format!("Expected 2 arguments, got {}", columns.len())
}
);
fn invoke_with_args(
&self,
args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
let [arg0, arg1] = extract_args(self.name(), &args)?;
let arg0 = compute::cast(&arg0, &DataType::Utf8View)?;
let ip_vec = arg0.as_string_view();
let arg1 = compute::cast(&arg1, &DataType::Utf8View)?;
let ranges = arg1.as_string_view();
let ip_vec = &columns[0];
let range_vec = &columns[1];
let size = ip_vec.len();
ensure!(
range_vec.len() == size,
InvalidFuncArgsSnafu {
err_msg: "IP addresses and CIDR ranges must have the same number of rows"
.to_string()
}
);
let mut results = BooleanVectorBuilder::with_capacity(size);
let mut builder = BooleanBuilder::with_capacity(size);
for i in 0..size {
let ip = ip_vec.get(i);
let range = range_vec.get(i);
let ip = ip_vec.is_valid(i).then(|| ip_vec.value(i));
let range = ranges.is_valid(i).then(|| ranges.value(i));
let in_range = match (ip, range) {
(Value::String(ip_str), Value::String(range_str)) => {
let ip_str = ip_str.as_utf8().trim();
let range_str = range_str.as_utf8().trim();
(Some(ip_str), Some(range_str)) => {
if ip_str.is_empty() || range_str.is_empty() {
return InvalidFuncArgsSnafu {
err_msg: "IP address and CIDR range cannot be empty".to_string(),
}
.fail();
return Err(DataFusionError::Execution(
"IP address or CIDR range cannot be empty".to_string(),
));
}
// Parse the IP address
@@ -107,10 +97,10 @@ impl Function for Ipv4InRange {
_ => None,
};
results.push(in_range);
builder.append_option(in_range);
}
Ok(results.to_vector())
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
}
}
@@ -141,42 +131,29 @@ impl Function for Ipv6InRange {
Signature::string(2, Volatility::Immutable)
}
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure!(
columns.len() == 2,
InvalidFuncArgsSnafu {
err_msg: format!("Expected 2 arguments, got {}", columns.len())
}
);
fn invoke_with_args(
&self,
args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
let [arg0, arg1] = extract_args(self.name(), &args)?;
let ip_vec = &columns[0];
let range_vec = &columns[1];
let arg0 = compute::cast(&arg0, &DataType::Utf8View)?;
let ip_vec = arg0.as_string_view();
let arg1 = compute::cast(&arg1, &DataType::Utf8View)?;
let ranges = arg1.as_string_view();
let size = ip_vec.len();
ensure!(
range_vec.len() == size,
InvalidFuncArgsSnafu {
err_msg: "IP addresses and CIDR ranges must have the same number of rows"
.to_string()
}
);
let mut results = BooleanVectorBuilder::with_capacity(size);
let mut builder = BooleanBuilder::with_capacity(size);
for i in 0..size {
let ip = ip_vec.get(i);
let range = range_vec.get(i);
let ip = ip_vec.is_valid(i).then(|| ip_vec.value(i));
let range = ranges.is_valid(i).then(|| ranges.value(i));
let in_range = match (ip, range) {
(Value::String(ip_str), Value::String(range_str)) => {
let ip_str = ip_str.as_utf8().trim();
let range_str = range_str.as_utf8().trim();
(Some(ip_str), Some(range_str)) => {
if ip_str.is_empty() || range_str.is_empty() {
return InvalidFuncArgsSnafu {
err_msg: "IP address and CIDR range cannot be empty".to_string(),
}
.fail();
return Err(DataFusionError::Execution(
"IP address or CIDR range cannot be empty".to_string(),
));
}
// Parse the IP address
@@ -196,10 +173,10 @@ impl Function for Ipv6InRange {
_ => None,
};
results.push(in_range);
builder.append_option(in_range);
}
Ok(results.to_vector())
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
}
}
@@ -329,15 +306,14 @@ fn is_ipv6_in_range(ip: &Ipv6Addr, cidr_base: &Ipv6Addr, prefix_len: u8) -> Opti
mod tests {
use std::sync::Arc;
use datatypes::scalars::ScalarVector;
use datatypes::vectors::{BooleanVector, StringVector};
use arrow_schema::Field;
use datafusion_common::arrow::array::StringViewArray;
use super::*;
#[test]
fn test_ipv4_in_range() {
let func = Ipv4InRange;
let ctx = FunctionContext::default();
// Test IPs
let ip_values = vec![
@@ -357,24 +333,31 @@ mod tests {
"172.16.0.0/16",
];
let ip_input = Arc::new(StringVector::from_slice(&ip_values)) as VectorRef;
let cidr_input = Arc::new(StringVector::from_slice(&cidr_values)) as VectorRef;
let arg0 = ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(&ip_values)));
let arg1 = ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(&cidr_values)));
let result = func.eval(&ctx, &[ip_input, cidr_input]).unwrap();
let result = result.as_any().downcast_ref::<BooleanVector>().unwrap();
let args = ScalarFunctionArgs {
args: vec![arg0, arg1],
arg_fields: vec![],
number_rows: 5,
return_field: Arc::new(Field::new("x", DataType::Boolean, false)),
config_options: Arc::new(Default::default()),
};
let result = func.invoke_with_args(args).unwrap();
let result = result.to_array(5).unwrap();
let result = result.as_boolean();
// Expected results
assert!(result.get_data(0).unwrap()); // 192.168.1.5 is in 192.168.1.0/24
assert!(!result.get_data(1).unwrap()); // 192.168.2.1 is not in 192.168.1.0/24
assert!(result.get_data(2).unwrap()); // 10.0.0.1 is in 10.0.0.0/8
assert!(result.get_data(3).unwrap()); // 10.1.0.1 is in 10.0.0.0/8
assert!(result.get_data(4).unwrap()); // 172.16.0.1 is in 172.16.0.0/16
assert!(result.value(0)); // 192.168.1.5 is in 192.168.1.0/24
assert!(!result.value(1)); // 192.168.2.1 is not in 192.168.1.0/24
assert!(result.value(2)); // 10.0.0.1 is in 10.0.0.0/8
assert!(result.value(3)); // 10.1.0.1 is in 10.0.0.0/8
assert!(result.value(4)); // 172.16.0.1 is in 172.16.0.0/16
}
#[test]
fn test_ipv6_in_range() {
let func = Ipv6InRange;
let ctx = FunctionContext::default();
// Test IPs
let ip_values = vec![
@@ -394,46 +377,70 @@ mod tests {
"fe80::/16",
];
let ip_input = Arc::new(StringVector::from_slice(&ip_values)) as VectorRef;
let cidr_input = Arc::new(StringVector::from_slice(&cidr_values)) as VectorRef;
let arg0 = ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(&ip_values)));
let arg1 = ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(&cidr_values)));
let result = func.eval(&ctx, &[ip_input, cidr_input]).unwrap();
let result = result.as_any().downcast_ref::<BooleanVector>().unwrap();
let args = ScalarFunctionArgs {
args: vec![arg0, arg1],
arg_fields: vec![],
number_rows: 5,
return_field: Arc::new(Field::new("x", DataType::Boolean, false)),
config_options: Arc::new(Default::default()),
};
let result = func.invoke_with_args(args).unwrap();
let result = result.to_array(5).unwrap();
let result = result.as_boolean();
// Expected results
assert!(result.get_data(0).unwrap()); // 2001:db8::1 is in 2001:db8::/32
assert!(result.get_data(1).unwrap()); // 2001:db8:1:: is in 2001:db8::/32
assert!(!result.get_data(2).unwrap()); // 2001:db9::1 is not in 2001:db8::/32
assert!(result.get_data(3).unwrap()); // ::1 is in ::1/128
assert!(result.get_data(4).unwrap()); // fe80::1 is in fe80::/16
assert!(result.value(0)); // 2001:db8::1 is in 2001:db8::/32
assert!(result.value(1)); // 2001:db8:1:: is in 2001:db8::/32
assert!(!result.value(2)); // 2001:db9::1 is not in 2001:db8::/32
assert!(result.value(3)); // ::1 is in ::1/128
assert!(result.value(4)); // fe80::1 is in fe80::/16
}
#[test]
fn test_invalid_inputs() {
let ipv4_func = Ipv4InRange;
let ipv6_func = Ipv6InRange;
let ctx = FunctionContext::default();
// Invalid IPv4 address
let invalid_ip_values = vec!["not-an-ip", "192.168.1.300"];
let cidr_values = vec!["192.168.1.0/24", "192.168.1.0/24"];
let invalid_ip_input = Arc::new(StringVector::from_slice(&invalid_ip_values)) as VectorRef;
let cidr_input = Arc::new(StringVector::from_slice(&cidr_values)) as VectorRef;
let arg0 = ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(
&invalid_ip_values,
)));
let arg1 = ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(&cidr_values)));
let result = ipv4_func.eval(&ctx, &[invalid_ip_input, cidr_input]);
let args = ScalarFunctionArgs {
args: vec![arg0, arg1],
arg_fields: vec![],
number_rows: 2,
return_field: Arc::new(Field::new("x", DataType::Boolean, false)),
config_options: Arc::new(Default::default()),
};
let result = ipv4_func.invoke_with_args(args);
assert!(result.is_err());
// Invalid CIDR notation
let ip_values = vec!["192.168.1.1", "2001:db8::1"];
let invalid_cidr_values = vec!["192.168.1.0", "2001:db8::/129"];
let ip_input = Arc::new(StringVector::from_slice(&ip_values)) as VectorRef;
let invalid_cidr_input =
Arc::new(StringVector::from_slice(&invalid_cidr_values)) as VectorRef;
let arg0 = ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(&ip_values)));
let arg1 = ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(
&invalid_cidr_values,
)));
let ipv4_result = ipv4_func.eval(&ctx, &[ip_input.clone(), invalid_cidr_input.clone()]);
let ipv6_result = ipv6_func.eval(&ctx, &[ip_input, invalid_cidr_input]);
let args = ScalarFunctionArgs {
args: vec![arg0, arg1],
arg_fields: vec![],
number_rows: 2,
return_field: Arc::new(Field::new("x", DataType::Boolean, false)),
config_options: Arc::new(Default::default()),
};
let ipv4_result = ipv4_func.invoke_with_args(args.clone());
let ipv6_result = ipv6_func.invoke_with_args(args);
assert!(ipv4_result.is_err());
assert!(ipv6_result.is_err());
@@ -442,20 +449,27 @@ mod tests {
#[test]
fn test_edge_cases() {
let ipv4_func = Ipv4InRange;
let ctx = FunctionContext::default();
// Edge cases like prefix length 0 (matches everything) and 32 (exact match)
let ip_values = vec!["8.8.8.8", "192.168.1.1", "192.168.1.1"];
let cidr_values = vec!["0.0.0.0/0", "192.168.1.1/32", "192.168.1.0/32"];
let ip_input = Arc::new(StringVector::from_slice(&ip_values)) as VectorRef;
let cidr_input = Arc::new(StringVector::from_slice(&cidr_values)) as VectorRef;
let arg0 = ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(&ip_values)));
let arg1 = ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(&cidr_values)));
let result = ipv4_func.eval(&ctx, &[ip_input, cidr_input]).unwrap();
let result = result.as_any().downcast_ref::<BooleanVector>().unwrap();
let args = ScalarFunctionArgs {
args: vec![arg0, arg1],
arg_fields: vec![],
number_rows: 3,
return_field: Arc::new(Field::new("x", DataType::Boolean, false)),
config_options: Arc::new(Default::default()),
};
let result = ipv4_func.invoke_with_args(args).unwrap();
let result = result.to_array(3).unwrap();
let result = result.as_boolean();
assert!(result.get_data(0).unwrap()); // 8.8.8.8 is in 0.0.0.0/0 (matches everything)
assert!(result.get_data(1).unwrap()); // 192.168.1.1 is in 192.168.1.1/32 (exact match)
assert!(!result.get_data(2).unwrap()); // 192.168.1.1 is not in 192.168.1.0/32 (no match)
assert!(result.value(0)); // 8.8.8.8 is in 0.0.0.0/0 (matches everything)
assert!(result.value(1)); // 192.168.1.1 is in 192.168.1.1/32 (exact match)
assert!(!result.value(2)); // 192.168.1.1 is not in 192.168.1.0/32 (no match)
}
}

View File

@@ -15,19 +15,14 @@
use std::fmt;
use std::fmt::Display;
use common_query::error;
use common_query::error::{ArrowComputeSnafu, InvalidFuncArgsSnafu, Result};
use datafusion::arrow::datatypes::DataType;
use common_query::error::Result;
use datafusion_common::arrow::compute;
use datafusion_common::arrow::compute::kernels::numeric;
use datafusion_common::arrow::datatypes::DataType;
use datafusion_expr::type_coercion::aggregates::NUMERICS;
use datafusion_expr::{Signature, Volatility};
use datatypes::arrow::compute;
use datatypes::arrow::compute::kernels::numeric;
use datatypes::arrow::datatypes::DataType as ArrowDataType;
use datatypes::prelude::ConcreteDataType;
use datatypes::vectors::{Helper, VectorRef};
use snafu::{ResultExt, ensure};
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, Volatility};
use crate::function::{Function, FunctionContext};
use crate::function::{Function, extract_args};
const NAME: &str = "mod";
@@ -60,38 +55,24 @@ impl Function for ModuloFunction {
Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable)
}
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure!(
columns.len() == 2,
InvalidFuncArgsSnafu {
err_msg: format!(
"The length of the args is not correct, expect exactly two, have: {}",
columns.len()
),
}
);
let nums = &columns[0];
let divs = &columns[1];
let nums_arrow_array = &nums.to_arrow_array();
let divs_arrow_array = &divs.to_arrow_array();
let array = numeric::rem(nums_arrow_array, divs_arrow_array).context(ArrowComputeSnafu)?;
fn invoke_with_args(
&self,
args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
let [nums, divs] = extract_args(self.name(), &args)?;
let array = numeric::rem(&nums, &divs)?;
let result = match nums.data_type() {
ConcreteDataType::Int8(_)
| ConcreteDataType::Int16(_)
| ConcreteDataType::Int32(_)
| ConcreteDataType::Int64(_) => compute::cast(&array, &ArrowDataType::Int64),
ConcreteDataType::UInt8(_)
| ConcreteDataType::UInt16(_)
| ConcreteDataType::UInt32(_)
| ConcreteDataType::UInt64(_) => compute::cast(&array, &ArrowDataType::UInt64),
ConcreteDataType::Float32(_) | ConcreteDataType::Float64(_) => {
compute::cast(&array, &ArrowDataType::Float64)
DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => {
compute::cast(&array, &DataType::Int64)
}
DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => {
compute::cast(&array, &DataType::UInt64)
}
DataType::Float32 | DataType::Float64 => compute::cast(&array, &DataType::Float64),
_ => unreachable!("unexpected datatype: {:?}", nums.data_type()),
}
.context(ArrowComputeSnafu)?;
Helper::try_into_vector(&result).context(error::FromArrowArraySnafu)
}?;
Ok(ColumnarValue::Array(result))
}
}
@@ -99,9 +80,11 @@ impl Function for ModuloFunction {
mod tests {
use std::sync::Arc;
use common_error::ext::ErrorExt;
use datatypes::value::Value;
use datatypes::vectors::{Float64Vector, Int32Vector, StringVector, UInt32Vector};
use arrow_schema::Field;
use datafusion_common::arrow::array::{
AsArray, Float64Array, Int32Array, StringViewArray, UInt32Array,
};
use datafusion_common::arrow::datatypes::{Float64Type, Int64Type, UInt64Type};
use super::*;
#[test]
@@ -120,15 +103,23 @@ mod tests {
let nums = vec![18, -17, 5, -6];
let divs = vec![4, 8, -5, -5];
let args: Vec<VectorRef> = vec![
Arc::new(Int32Vector::from_vec(nums.clone())),
Arc::new(Int32Vector::from_vec(divs.clone())),
];
let result = function.eval(&FunctionContext::default(), &args).unwrap();
let args = ScalarFunctionArgs {
args: vec![
ColumnarValue::Array(Arc::new(Int32Array::from(nums.clone()))),
ColumnarValue::Array(Arc::new(Int32Array::from(divs.clone()))),
],
arg_fields: vec![],
number_rows: 4,
return_field: Arc::new(Field::new("x", DataType::Int64, false)),
config_options: Arc::new(Default::default()),
};
let result = function.invoke_with_args(args).unwrap();
let result = result.to_array(4).unwrap();
let result = result.as_primitive::<Int64Type>();
assert_eq!(result.len(), 4);
for i in 0..4 {
let p: i64 = (nums[i] % divs[i]) as i64;
assert!(matches!(result.get(i), Value::Int64(v) if v == p));
assert_eq!(result.value(i), p);
}
}
@@ -148,15 +139,23 @@ mod tests {
let nums: Vec<u32> = vec![18, 17, 5, 6];
let divs: Vec<u32> = vec![4, 8, 5, 5];
let args: Vec<VectorRef> = vec![
Arc::new(UInt32Vector::from_vec(nums.clone())),
Arc::new(UInt32Vector::from_vec(divs.clone())),
];
let result = function.eval(&FunctionContext::default(), &args).unwrap();
let args = ScalarFunctionArgs {
args: vec![
ColumnarValue::Array(Arc::new(UInt32Array::from(nums.clone()))),
ColumnarValue::Array(Arc::new(UInt32Array::from(divs.clone()))),
],
arg_fields: vec![],
number_rows: 4,
return_field: Arc::new(Field::new("x", DataType::UInt64, false)),
config_options: Arc::new(Default::default()),
};
let result = function.invoke_with_args(args).unwrap();
let result = result.to_array(4).unwrap();
let result = result.as_primitive::<UInt64Type>();
assert_eq!(result.len(), 4);
for i in 0..4 {
let p: u64 = (nums[i] % divs[i]) as u64;
assert!(matches!(result.get(i), Value::UInt64(v) if v == p));
assert_eq!(result.value(i), p);
}
}
@@ -176,15 +175,23 @@ mod tests {
let nums = vec![18.0, 17.0, 5.0, 6.0];
let divs = vec![4.0, 8.0, 5.0, 5.0];
let args: Vec<VectorRef> = vec![
Arc::new(Float64Vector::from_vec(nums.clone())),
Arc::new(Float64Vector::from_vec(divs.clone())),
];
let result = function.eval(&FunctionContext::default(), &args).unwrap();
let args = ScalarFunctionArgs {
args: vec![
ColumnarValue::Array(Arc::new(Float64Array::from(nums.clone()))),
ColumnarValue::Array(Arc::new(Float64Array::from(divs.clone()))),
],
arg_fields: vec![],
number_rows: 4,
return_field: Arc::new(Field::new("x", DataType::Float64, false)),
config_options: Arc::new(Default::default()),
};
let result = function.invoke_with_args(args).unwrap();
let result = result.to_array(4).unwrap();
let result = result.as_primitive::<Float64Type>();
assert_eq!(result.len(), 4);
for i in 0..4 {
let p: f64 = nums[i] % divs[i];
assert!(matches!(result.get(i), Value::Float64(v) if v == p));
assert_eq!(result.value(i), p);
}
}
@@ -195,37 +202,53 @@ mod tests {
let nums = vec![27];
let divs = vec![0];
let args: Vec<VectorRef> = vec![
Arc::new(Int32Vector::from_vec(nums.clone())),
Arc::new(Int32Vector::from_vec(divs.clone())),
];
let result = function.eval(&FunctionContext::default(), &args);
let args = ScalarFunctionArgs {
args: vec![
ColumnarValue::Array(Arc::new(Int32Array::from(nums))),
ColumnarValue::Array(Arc::new(Int32Array::from(divs))),
],
arg_fields: vec![],
number_rows: 1,
return_field: Arc::new(Field::new("x", DataType::Int64, false)),
config_options: Arc::new(Default::default()),
};
let result = function.invoke_with_args(args);
assert!(result.is_err());
let err_msg = result.unwrap_err().output_msg();
assert_eq!(
err_msg,
"Failed to perform compute operation on arrow arrays: Divide by zero error"
);
let err_msg = result.unwrap_err().to_string();
assert_eq!(err_msg, "Arrow error: Divide by zero error");
let nums = vec![27];
let args: Vec<VectorRef> = vec![Arc::new(Int32Vector::from_vec(nums.clone()))];
let result = function.eval(&FunctionContext::default(), &args);
let args = ScalarFunctionArgs {
args: vec![ColumnarValue::Array(Arc::new(Int32Array::from(nums)))],
arg_fields: vec![],
number_rows: 1,
return_field: Arc::new(Field::new("x", DataType::Int64, false)),
config_options: Arc::new(Default::default()),
};
let result = function.invoke_with_args(args);
assert!(result.is_err());
let err_msg = result.unwrap_err().output_msg();
assert!(
err_msg.contains("The length of the args is not correct, expect exactly two, have: 1")
let err_msg = result.unwrap_err().to_string();
assert_eq!(
err_msg,
"Execution error: mod function requires 2 arguments, got 1"
);
let nums = vec!["27"];
let divs = vec!["4"];
let args: Vec<VectorRef> = vec![
Arc::new(StringVector::from(nums.clone())),
Arc::new(StringVector::from(divs.clone())),
];
let result = function.eval(&FunctionContext::default(), &args);
let args = ScalarFunctionArgs {
args: vec![
ColumnarValue::Array(Arc::new(StringViewArray::from(nums))),
ColumnarValue::Array(Arc::new(StringViewArray::from(divs))),
],
arg_fields: vec![],
number_rows: 1,
return_field: Arc::new(Field::new("x", DataType::Int64, false)),
config_options: Arc::new(Default::default()),
};
let result = function.invoke_with_args(args);
assert!(result.is_err());
let err_msg = result.unwrap_err().output_msg();
let err_msg = result.unwrap_err().to_string();
assert!(err_msg.contains("Invalid arithmetic operation"));
}
}

View File

@@ -13,14 +13,14 @@
// limitations under the License.
use std::fmt;
use std::sync::Arc;
use common_query::error::Result;
use datafusion_expr::{Signature, Volatility};
use datafusion::logical_expr::ColumnarValue;
use datafusion_expr::{ScalarFunctionArgs, Signature, Volatility};
use datatypes::arrow::datatypes::DataType;
use datatypes::prelude::VectorRef;
use datatypes::vectors::{Helper, Vector};
use crate::function::{Function, FunctionContext};
use crate::function::{Function, extract_args};
use crate::scalars::expression::{EvalContext, scalar_binary_op};
#[derive(Clone, Default)]
@@ -42,14 +42,19 @@ impl Function for TestAndFunction {
)
}
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
fn invoke_with_args(
&self,
args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
let [arg0, arg1] = extract_args(self.name(), &args)?;
let columns = Helper::try_into_vectors(&[arg0, arg1]).unwrap();
let col = scalar_binary_op::<bool, bool, bool, _>(
&columns[0],
&columns[1],
scalar_and,
&mut EvalContext::default(),
)?;
Ok(Arc::new(col))
Ok(ColumnarValue::Array(col.to_arrow_array()))
}
}

View File

@@ -107,12 +107,11 @@ mod tests {
use common_query::prelude::ScalarValue;
use datafusion::arrow::array::BooleanArray;
use datafusion_common::arrow::array::AsArray;
use datafusion_common::arrow::datatypes::DataType as ArrowDataType;
use datafusion_common::config::ConfigOptions;
use datatypes::arrow::datatypes::Field;
use datatypes::data_type::{ConcreteDataType, DataType};
use datatypes::prelude::VectorRef;
use datatypes::value::Value;
use datatypes::vectors::{BooleanVector, ConstantVector};
use session::context::QueryContextBuilder;
use super::*;
@@ -124,20 +123,31 @@ mod tests {
let f = Arc::new(TestAndFunction);
let query_ctx = QueryContextBuilder::default().build().into();
let args: Vec<VectorRef> = vec![
Arc::new(ConstantVector::new(
Arc::new(BooleanVector::from(vec![true])),
3,
)),
Arc::new(BooleanVector::from(vec![true, false, true])),
];
let args = ScalarFunctionArgs {
args: vec![
datafusion_expr::ColumnarValue::Array(Arc::new(BooleanArray::from(vec![
true, true, true,
]))),
datafusion_expr::ColumnarValue::Array(Arc::new(BooleanArray::from(vec![
true, false, true,
]))),
],
arg_fields: vec![],
number_rows: 3,
return_field: Arc::new(Field::new("x", ArrowDataType::Boolean, true)),
config_options: Arc::new(Default::default()),
};
let vector = f.eval(&FunctionContext::default(), &args).unwrap();
let result = f
.invoke_with_args(args)
.and_then(|x| x.to_array(3))
.unwrap();
let vector = result.as_boolean();
assert_eq!(3, vector.len());
for i in 0..3 {
assert!(matches!(vector.get(i), Value::Boolean(b) if b == (i == 0 || i == 2)));
}
assert!(vector.value(0));
assert!(!vector.value(1));
assert!(vector.value(2));
// create a udf and test it again
let udf = create_udf(f.clone(), query_ctx, Arc::new(FunctionState::default()));

View File

@@ -13,15 +13,13 @@
// limitations under the License.
use std::fmt::{self};
use std::sync::Arc;
use common_query::error::Result;
use datafusion::arrow::datatypes::DataType;
use datafusion_expr::{Signature, Volatility};
use datatypes::prelude::ScalarVector;
use datatypes::vectors::{StringVector, VectorRef};
use datafusion_common::ScalarValue;
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, Volatility};
use crate::function::{Function, FunctionContext};
use crate::function::{Function, find_function_context};
/// A function to return current session timezone.
#[derive(Clone, Debug, Default)]
@@ -35,17 +33,21 @@ impl Function for TimezoneFunction {
}
fn return_type(&self, _: &[DataType]) -> Result<DataType> {
Ok(DataType::Utf8)
Ok(DataType::Utf8View)
}
/// Declares a zero-argument (`nullary`) signature marked [`Volatility::Immutable`].
fn signature(&self) -> Signature {
Signature::nullary(Volatility::Immutable)
}
fn eval(&self, func_ctx: &FunctionContext, _columns: &[VectorRef]) -> Result<VectorRef> {
fn invoke_with_args(
&self,
args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
let func_ctx = find_function_context(&args)?;
let tz = func_ctx.query_ctx.timezone().to_string();
Ok(Arc::new(StringVector::from_slice(&[&tz])) as _)
Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(Some(tz))))
}
}
@@ -59,14 +61,18 @@ impl fmt::Display for TimezoneFunction {
mod tests {
use std::sync::Arc;
use arrow_schema::Field;
use datafusion_common::config::ConfigOptions;
use session::context::QueryContextBuilder;
use super::*;
use crate::function::FunctionContext;
#[test]
fn test_build_function() {
let build = TimezoneFunction;
assert_eq!("timezone", build.name());
assert_eq!(DataType::Utf8, build.return_type(&[]).unwrap());
assert_eq!(DataType::Utf8View, build.return_type(&[]).unwrap());
assert_eq!(build.signature(), Signature::nullary(Volatility::Immutable));
let query_ctx = QueryContextBuilder::default().build().into();
@@ -75,8 +81,21 @@ mod tests {
query_ctx,
..Default::default()
};
let vector = build.eval(&func_ctx, &[]).unwrap();
let expect: VectorRef = Arc::new(StringVector::from(vec!["UTC"]));
assert_eq!(expect, vector);
let mut config_options = ConfigOptions::default();
config_options.extensions.insert(func_ctx);
let config_options = Arc::new(config_options);
let args = ScalarFunctionArgs {
args: vec![],
arg_fields: vec![],
number_rows: 0,
return_field: Arc::new(Field::new("x", DataType::Utf8View, false)),
config_options,
};
let result = build.invoke_with_args(args).unwrap();
let ColumnarValue::Scalar(ScalarValue::Utf8View(Some(s))) = result else {
unreachable!()
};
assert_eq!(s, "UTC");
}
}

View File

@@ -13,15 +13,14 @@
// limitations under the License.
use std::fmt;
use std::sync::Arc;
use common_query::error::Result;
use datafusion::arrow::datatypes::DataType;
use datafusion_expr::{Signature, Volatility};
use datatypes::vectors::{StringVector, VectorRef};
use datafusion_common::ScalarValue;
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, Volatility};
use session::context::Channel;
use crate::function::{Function, FunctionContext};
use crate::function::{Function, find_function_context};
#[derive(Clone, Debug, Default)]
pub(crate) struct VersionFunction;
@@ -38,14 +37,18 @@ impl Function for VersionFunction {
}
fn return_type(&self, _: &[DataType]) -> Result<DataType> {
Ok(DataType::Utf8)
Ok(DataType::Utf8View)
}
/// Declares a zero-argument (`nullary`) signature marked [`Volatility::Immutable`].
fn signature(&self) -> Signature {
Signature::nullary(Volatility::Immutable)
}
fn eval(&self, func_ctx: &FunctionContext, _columns: &[VectorRef]) -> Result<VectorRef> {
fn invoke_with_args(
&self,
args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
let func_ctx = find_function_context(&args)?;
let version = match func_ctx.query_ctx.channel() {
Channel::Mysql => {
format!(
@@ -60,7 +63,6 @@ impl Function for VersionFunction {
}
_ => common_version::version().to_string(),
};
let result = StringVector::from(vec![version]);
Ok(Arc::new(result))
Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(Some(version))))
}
}

View File

@@ -17,6 +17,9 @@ use std::slice;
use std::sync::Arc;
use datafusion::arrow::util::pretty::pretty_format_batches;
use datafusion_common::arrow::array::ArrayRef;
use datafusion_common::arrow::compute;
use datafusion_common::arrow::datatypes::{DataType as ArrowDataType, SchemaRef as ArrowSchemaRef};
use datatypes::arrow::array::RecordBatchOptions;
use datatypes::prelude::DataType;
use datatypes::schema::SchemaRef;
@@ -28,8 +31,8 @@ use snafu::{OptionExt, ResultExt, ensure};
use crate::DfRecordBatch;
use crate::error::{
self, CastVectorSnafu, ColumnNotExistsSnafu, DataTypesSnafu, ProjectArrowRecordBatchSnafu,
Result,
self, ArrowComputeSnafu, CastVectorSnafu, ColumnNotExistsSnafu, DataTypesSnafu,
ProjectArrowRecordBatchSnafu, Result,
};
/// A two-dimensional batch of column-oriented data with a defined schema.
@@ -49,6 +52,15 @@ impl RecordBatch {
let columns: Vec<_> = columns.into_iter().collect();
let arrow_arrays = columns.iter().map(|v| v.to_arrow_array()).collect();
// Casting the arrays here to match the schema is a temporary solution to support Arrow's
// view array types (`StringViewArray` and `BinaryViewArray`).
// As to "support": the arrays here are created from vectors, which do not have types
// corresponding to view arrays. All we can do is cast them.
// As to "temporary": we are planning to use Arrow's RecordBatch directly in the read path.
// The casting here will be removed in the end.
// TODO(LFC): Remove the casting here once `Batch` is no longer used.
let arrow_arrays = Self::cast_view_arrays(schema.arrow_schema(), arrow_arrays)?;
let df_record_batch = DfRecordBatch::try_new(schema.arrow_schema().clone(), arrow_arrays)
.context(error::NewDfRecordBatchSnafu)?;
@@ -59,6 +71,24 @@ impl RecordBatch {
})
}
/// Casts columns to the Arrow view types declared by `schema`.
///
/// Fields of `schema` and entries of `arrays` are visited pairwise. A `Utf8`
/// array whose field is declared `Utf8View`, or a `Binary` array whose field is
/// declared `BinaryView`, is replaced in place by the result of casting it to
/// the field's type. Every other array is returned untouched.
///
/// # Errors
///
/// Fails with the `ArrowComputeSnafu` context when a cast is rejected.
fn cast_view_arrays(
    schema: &ArrowSchemaRef,
    mut arrays: Vec<ArrayRef>,
) -> Result<Vec<ArrayRef>> {
    let fields = schema.fields().iter();
    for (field, array) in fields.zip(arrays.iter_mut()) {
        let target = field.data_type();
        // Only the two "plain -> view" pairs need a cast; skip everything else.
        match (target, array.data_type()) {
            (ArrowDataType::Utf8View, ArrowDataType::Utf8)
            | (ArrowDataType::BinaryView, ArrowDataType::Binary) => {}
            _ => continue,
        }
        *array = compute::cast(array, target).context(ArrowComputeSnafu)?;
    }
    Ok(arrays)
}
/// Create an empty [`RecordBatch`] from `schema`.
pub fn new_empty(schema: SchemaRef) -> RecordBatch {
let df_record_batch = DfRecordBatch::new_empty(schema.arrow_schema().clone());

View File

@@ -16,14 +16,12 @@
use std::collections::BTreeMap;
use std::sync::Arc;
use common_error::ext::BoxedError;
use common_function::function::{FunctionContext, FunctionRef};
use common_function::function::FunctionRef;
use datafusion::arrow::datatypes::{DataType, TimeUnit};
use datafusion_expr::{Signature, Volatility};
use datafusion_substrait::extensions::Extensions;
use query::QueryEngine;
use serde::{Deserialize, Serialize};
use snafu::ResultExt;
/// note here we are using the `substrait_proto_df` crate from the `substrait` module and
/// rename it to `substrait_proto`
use substrait::substrait_proto_df as substrait_proto;
@@ -31,7 +29,7 @@ use substrait_proto::proto::extensions::SimpleExtensionDeclaration;
use substrait_proto::proto::extensions::simple_extension_declaration::MappingType;
use crate::adapter::FlownodeContext;
use crate::error::{Error, NotImplementedSnafu, UnexpectedSnafu};
use crate::error::{Error, NotImplementedSnafu};
use crate::expr::{TUMBLE_END, TUMBLE_START};
/// a simple macro to generate a not implemented error
macro_rules! not_impl_err {
@@ -149,19 +147,6 @@ impl common_function::function::Function for TumbleFunction {
/// Returns the function's signature: a variadic-any signature with
/// [`Volatility::Immutable`].
fn signature(&self) -> Signature {
Signature::variadic_any(Volatility::Immutable)
}
fn eval(
&self,
_func_ctx: &FunctionContext,
_columns: &[datatypes::prelude::VectorRef],
) -> common_query::error::Result<datatypes::prelude::VectorRef> {
UnexpectedSnafu {
reason: "Tumbler function is not implemented for datafusion executor",
}
.fail()
.map_err(BoxedError::new)
.context(common_query::error::ExecuteSnafu)
}
}
#[cfg(test)]

View File

@@ -114,6 +114,8 @@ impl AnalyzerRule for DistPlannerAnalyzer {
// seems to lose track of it: the `ConfigOptions` is recreated with its default values again.
// So we create a custom `OptimizerConfig` with the desired `ConfigOptions`
// to work around the issue.
// TODO(LFC): Maybe use DataFusion's `OptimizerContext` again
// once https://github.com/apache/datafusion/pull/17742 is merged.
struct OptimizerContext {
inner: datafusion_optimizer::OptimizerContext,
config: Arc<ConfigOptions>,