mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2026-01-09 06:42:57 +00:00
refactor: rewrite some UDFs to DataFusion style (part 4) (#7011)
Signed-off-by: luofucong <luofc@foxmail.com>
This commit is contained in:
@@ -14,18 +14,21 @@
|
||||
|
||||
use std::net::{Ipv4Addr, Ipv6Addr};
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
|
||||
use common_query::error::{InvalidFuncArgsSnafu, Result};
|
||||
use datafusion_common::types;
|
||||
use datafusion_expr::{Coercion, Signature, TypeSignature, TypeSignatureClass, Volatility};
|
||||
use datatypes::arrow::datatypes::DataType;
|
||||
use datatypes::prelude::Value;
|
||||
use datatypes::scalars::ScalarVectorBuilder;
|
||||
use datatypes::vectors::{MutableVector, StringVectorBuilder, VectorRef};
|
||||
use datafusion_common::arrow::array::{Array, AsArray, StringViewBuilder};
|
||||
use datafusion_common::arrow::compute;
|
||||
use datafusion_common::arrow::datatypes::{DataType, UInt8Type};
|
||||
use datafusion_common::{DataFusionError, types};
|
||||
use datafusion_expr::{
|
||||
Coercion, ColumnarValue, ScalarFunctionArgs, Signature, TypeSignature, TypeSignatureClass,
|
||||
Volatility,
|
||||
};
|
||||
use derive_more::Display;
|
||||
use snafu::ensure;
|
||||
|
||||
use crate::function::{Function, FunctionContext};
|
||||
use crate::function::Function;
|
||||
|
||||
/// Function that converts an IPv4 address string to CIDR notation.
|
||||
///
|
||||
@@ -46,7 +49,7 @@ impl Function for Ipv4ToCidr {
|
||||
}
|
||||
|
||||
fn return_type(&self, _: &[DataType]) -> Result<DataType> {
|
||||
Ok(DataType::Utf8)
|
||||
Ok(DataType::Utf8View)
|
||||
}
|
||||
|
||||
fn signature(&self) -> Signature {
|
||||
@@ -62,19 +65,29 @@ impl Function for Ipv4ToCidr {
|
||||
)
|
||||
}
|
||||
|
||||
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
|
||||
ensure!(
|
||||
columns.len() == 1 || columns.len() == 2,
|
||||
InvalidFuncArgsSnafu {
|
||||
err_msg: format!("Expected 1 or 2 arguments, got {}", columns.len())
|
||||
}
|
||||
);
|
||||
fn invoke_with_args(
|
||||
&self,
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
if args.args.len() != 1 && args.args.len() != 2 {
|
||||
return Err(DataFusionError::Execution(format!(
|
||||
"expecting 1 or 2 arguments, got {}",
|
||||
args.args.len()
|
||||
)));
|
||||
}
|
||||
let columns = ColumnarValue::values_to_arrays(&args.args)?;
|
||||
|
||||
let ip_vec = &columns[0];
|
||||
let mut results = StringVectorBuilder::with_capacity(ip_vec.len());
|
||||
let mut builder = StringViewBuilder::with_capacity(ip_vec.len());
|
||||
let arg0 = compute::cast(ip_vec, &DataType::Utf8View)?;
|
||||
let ip_vec = arg0.as_string_view();
|
||||
|
||||
let has_subnet_arg = columns.len() == 2;
|
||||
let subnet_vec = if has_subnet_arg {
|
||||
let maybe_arg1 = if columns.len() > 1 {
|
||||
Some(compute::cast(&columns[1], &DataType::UInt8)?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let subnets = if let Some(arg1) = maybe_arg1.as_ref() {
|
||||
ensure!(
|
||||
columns[1].len() == ip_vec.len(),
|
||||
InvalidFuncArgsSnafu {
|
||||
@@ -83,23 +96,19 @@ impl Function for Ipv4ToCidr {
|
||||
.to_string()
|
||||
}
|
||||
);
|
||||
Some(&columns[1])
|
||||
Some(arg1.as_primitive::<UInt8Type>())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
for i in 0..ip_vec.len() {
|
||||
let ip_str = ip_vec.get(i);
|
||||
let subnet = subnet_vec.map(|v| v.get(i));
|
||||
let ip_str = ip_vec.is_valid(i).then(|| ip_vec.value(i));
|
||||
let subnet = subnets.and_then(|v| v.is_valid(i).then(|| v.value(i)));
|
||||
|
||||
let cidr = match (ip_str, subnet) {
|
||||
(Value::String(s), Some(Value::UInt8(mask))) => {
|
||||
let ip_str = s.as_utf8().trim();
|
||||
(Some(ip_str), Some(mask)) => {
|
||||
if ip_str.is_empty() {
|
||||
return InvalidFuncArgsSnafu {
|
||||
err_msg: "Empty IPv4 address".to_string(),
|
||||
}
|
||||
.fail();
|
||||
return Err(DataFusionError::Execution("empty IPv4 address".to_string()));
|
||||
}
|
||||
|
||||
let ip_addr = complete_and_parse_ipv4(ip_str)?;
|
||||
@@ -109,13 +118,9 @@ impl Function for Ipv4ToCidr {
|
||||
|
||||
Some(format!("{}/{}", masked_ip, mask))
|
||||
}
|
||||
(Value::String(s), None) => {
|
||||
let ip_str = s.as_utf8().trim();
|
||||
(Some(ip_str), None) => {
|
||||
if ip_str.is_empty() {
|
||||
return InvalidFuncArgsSnafu {
|
||||
err_msg: "Empty IPv4 address".to_string(),
|
||||
}
|
||||
.fail();
|
||||
return Err(DataFusionError::Execution("empty IPv4 address".to_string()));
|
||||
}
|
||||
|
||||
let ip_addr = complete_and_parse_ipv4(ip_str)?;
|
||||
@@ -149,10 +154,10 @@ impl Function for Ipv4ToCidr {
|
||||
_ => None,
|
||||
};
|
||||
|
||||
results.push(cidr.as_deref());
|
||||
builder.append_option(cidr.as_deref());
|
||||
}
|
||||
|
||||
Ok(results.to_vector())
|
||||
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -175,7 +180,7 @@ impl Function for Ipv6ToCidr {
|
||||
}
|
||||
|
||||
fn return_type(&self, _: &[DataType]) -> Result<DataType> {
|
||||
Ok(DataType::Utf8)
|
||||
Ok(DataType::Utf8View)
|
||||
}
|
||||
|
||||
fn signature(&self) -> Signature {
|
||||
@@ -188,37 +193,41 @@ impl Function for Ipv6ToCidr {
|
||||
)
|
||||
}
|
||||
|
||||
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
|
||||
ensure!(
|
||||
columns.len() == 1 || columns.len() == 2,
|
||||
InvalidFuncArgsSnafu {
|
||||
err_msg: format!("Expected 1 or 2 arguments, got {}", columns.len())
|
||||
}
|
||||
);
|
||||
fn invoke_with_args(
|
||||
&self,
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
if args.args.len() != 1 && args.args.len() != 2 {
|
||||
return Err(DataFusionError::Execution(format!(
|
||||
"expecting 1 or 2 arguments, got {}",
|
||||
args.args.len()
|
||||
)));
|
||||
}
|
||||
let columns = ColumnarValue::values_to_arrays(&args.args)?;
|
||||
|
||||
let ip_vec = &columns[0];
|
||||
let size = ip_vec.len();
|
||||
let mut results = StringVectorBuilder::with_capacity(size);
|
||||
let mut builder = StringViewBuilder::with_capacity(size);
|
||||
let arg0 = compute::cast(ip_vec, &DataType::Utf8View)?;
|
||||
let ip_vec = arg0.as_string_view();
|
||||
|
||||
let has_subnet_arg = columns.len() == 2;
|
||||
let subnet_vec = if has_subnet_arg {
|
||||
Some(&columns[1])
|
||||
let maybe_arg1 = if columns.len() > 1 {
|
||||
Some(compute::cast(&columns[1], &DataType::UInt8)?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let subnets = maybe_arg1
|
||||
.as_ref()
|
||||
.map(|arg1| arg1.as_primitive::<UInt8Type>());
|
||||
|
||||
for i in 0..size {
|
||||
let ip_str = ip_vec.get(i);
|
||||
let subnet = subnet_vec.map(|v| v.get(i));
|
||||
let ip_str = ip_vec.is_valid(i).then(|| ip_vec.value(i));
|
||||
let subnet = subnets.and_then(|v| v.is_valid(i).then(|| v.value(i)));
|
||||
|
||||
let cidr = match (ip_str, subnet) {
|
||||
(Value::String(s), Some(Value::UInt8(mask))) => {
|
||||
let ip_str = s.as_utf8().trim();
|
||||
(Some(ip_str), Some(mask)) => {
|
||||
if ip_str.is_empty() {
|
||||
return InvalidFuncArgsSnafu {
|
||||
err_msg: "Empty IPv6 address".to_string(),
|
||||
}
|
||||
.fail();
|
||||
return Err(DataFusionError::Execution("empty IPv6 address".to_string()));
|
||||
}
|
||||
|
||||
let ip_addr = complete_and_parse_ipv6(ip_str)?;
|
||||
@@ -228,13 +237,9 @@ impl Function for Ipv6ToCidr {
|
||||
|
||||
Some(format!("{}/{}", masked_ip, mask))
|
||||
}
|
||||
(Value::String(s), None) => {
|
||||
let ip_str = s.as_utf8().trim();
|
||||
(Some(ip_str), None) => {
|
||||
if ip_str.is_empty() {
|
||||
return InvalidFuncArgsSnafu {
|
||||
err_msg: "Empty IPv6 address".to_string(),
|
||||
}
|
||||
.fail();
|
||||
return Err(DataFusionError::Execution("empty IPv6 address".to_string()));
|
||||
}
|
||||
|
||||
let ip_addr = complete_and_parse_ipv6(ip_str)?;
|
||||
@@ -250,10 +255,10 @@ impl Function for Ipv6ToCidr {
|
||||
_ => None,
|
||||
};
|
||||
|
||||
results.push(cidr.as_deref());
|
||||
builder.append_option(cidr.as_deref());
|
||||
}
|
||||
|
||||
Ok(results.to_vector())
|
||||
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -375,108 +380,148 @@ fn auto_detect_ipv6_subnet(addr: &Ipv6Addr) -> u8 {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use datatypes::scalars::ScalarVector;
|
||||
use datatypes::vectors::{StringVector, UInt8Vector};
|
||||
use arrow_schema::Field;
|
||||
use datafusion_common::arrow::array::{StringViewArray, UInt8Array};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_ipv4_to_cidr_auto() {
|
||||
let func = Ipv4ToCidr;
|
||||
let ctx = FunctionContext::default();
|
||||
|
||||
// Test data with auto subnet detection
|
||||
let values = vec!["192.168.1.0", "10.0.0.0", "172.16", "192"];
|
||||
let input = Arc::new(StringVector::from_slice(&values)) as VectorRef;
|
||||
let arg0 = ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(&values)));
|
||||
|
||||
let result = func.eval(&ctx, &[input]).unwrap();
|
||||
let result = result.as_any().downcast_ref::<StringVector>().unwrap();
|
||||
let args = ScalarFunctionArgs {
|
||||
args: vec![arg0],
|
||||
arg_fields: vec![],
|
||||
number_rows: 4,
|
||||
return_field: Arc::new(Field::new("x", DataType::Utf8View, false)),
|
||||
config_options: Arc::new(Default::default()),
|
||||
};
|
||||
let result = func.invoke_with_args(args).unwrap();
|
||||
let result = result.to_array(4).unwrap();
|
||||
let result = result.as_string_view();
|
||||
|
||||
assert_eq!(result.get_data(0).unwrap(), "192.168.1.0/24");
|
||||
assert_eq!(result.get_data(1).unwrap(), "10.0.0.0/8");
|
||||
assert_eq!(result.get_data(2).unwrap(), "172.16.0.0/16");
|
||||
assert_eq!(result.get_data(3).unwrap(), "192.0.0.0/8");
|
||||
assert_eq!(result.value(0), "192.168.1.0/24");
|
||||
assert_eq!(result.value(1), "10.0.0.0/8");
|
||||
assert_eq!(result.value(2), "172.16.0.0/16");
|
||||
assert_eq!(result.value(3), "192.0.0.0/8");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ipv4_to_cidr_with_subnet() {
|
||||
let func = Ipv4ToCidr;
|
||||
let ctx = FunctionContext::default();
|
||||
|
||||
// Test data with explicit subnet
|
||||
let ip_values = vec!["192.168.1.1", "10.0.0.1", "172.16.5.5"];
|
||||
let subnet_values = vec![24u8, 16u8, 12u8];
|
||||
let ip_input = Arc::new(StringVector::from_slice(&ip_values)) as VectorRef;
|
||||
let subnet_input = Arc::new(UInt8Vector::from_vec(subnet_values)) as VectorRef;
|
||||
let arg0 = ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(&ip_values)));
|
||||
let arg1 = ColumnarValue::Array(Arc::new(UInt8Array::from(subnet_values)));
|
||||
|
||||
let result = func.eval(&ctx, &[ip_input, subnet_input]).unwrap();
|
||||
let result = result.as_any().downcast_ref::<StringVector>().unwrap();
|
||||
let args = ScalarFunctionArgs {
|
||||
args: vec![arg0, arg1],
|
||||
arg_fields: vec![],
|
||||
number_rows: 3,
|
||||
return_field: Arc::new(Field::new("x", DataType::Utf8View, false)),
|
||||
config_options: Arc::new(Default::default()),
|
||||
};
|
||||
let result = func.invoke_with_args(args).unwrap();
|
||||
let result = result.to_array(3).unwrap();
|
||||
let result = result.as_string_view();
|
||||
|
||||
assert_eq!(result.get_data(0).unwrap(), "192.168.1.0/24");
|
||||
assert_eq!(result.get_data(1).unwrap(), "10.0.0.0/16");
|
||||
assert_eq!(result.get_data(2).unwrap(), "172.16.0.0/12");
|
||||
assert_eq!(result.value(0), "192.168.1.0/24");
|
||||
assert_eq!(result.value(1), "10.0.0.0/16");
|
||||
assert_eq!(result.value(2), "172.16.0.0/12");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ipv6_to_cidr_auto() {
|
||||
let func = Ipv6ToCidr;
|
||||
let ctx = FunctionContext::default();
|
||||
|
||||
// Test data with auto subnet detection
|
||||
let values = vec!["2001:db8::", "2001:db8", "fe80::1", "::1"];
|
||||
let input = Arc::new(StringVector::from_slice(&values)) as VectorRef;
|
||||
let arg0 = ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(&values)));
|
||||
|
||||
let result = func.eval(&ctx, &[input]).unwrap();
|
||||
let result = result.as_any().downcast_ref::<StringVector>().unwrap();
|
||||
let args = ScalarFunctionArgs {
|
||||
args: vec![arg0],
|
||||
arg_fields: vec![],
|
||||
number_rows: 4,
|
||||
return_field: Arc::new(Field::new("x", DataType::Utf8View, false)),
|
||||
config_options: Arc::new(Default::default()),
|
||||
};
|
||||
let result = func.invoke_with_args(args).unwrap();
|
||||
let result = result.to_array(4).unwrap();
|
||||
let result = result.as_string_view();
|
||||
|
||||
assert_eq!(result.get_data(0).unwrap(), "2001:db8::/32");
|
||||
assert_eq!(result.get_data(1).unwrap(), "2001:db8::/32");
|
||||
assert_eq!(result.get_data(2).unwrap(), "fe80::/16");
|
||||
assert_eq!(result.get_data(3).unwrap(), "::1/128"); // Special case for ::1
|
||||
assert_eq!(result.value(0), "2001:db8::/32");
|
||||
assert_eq!(result.value(1), "2001:db8::/32");
|
||||
assert_eq!(result.value(2), "fe80::/16");
|
||||
assert_eq!(result.value(3), "::1/128"); // Special case for ::1
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ipv6_to_cidr_with_subnet() {
|
||||
let func = Ipv6ToCidr;
|
||||
let ctx = FunctionContext::default();
|
||||
|
||||
// Test data with explicit subnet
|
||||
let ip_values = vec!["2001:db8::", "fe80::1", "2001:db8:1234::"];
|
||||
let subnet_values = vec![48u8, 10u8, 56u8];
|
||||
let ip_input = Arc::new(StringVector::from_slice(&ip_values)) as VectorRef;
|
||||
let subnet_input = Arc::new(UInt8Vector::from_vec(subnet_values)) as VectorRef;
|
||||
let arg0 = ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(&ip_values)));
|
||||
let arg1 = ColumnarValue::Array(Arc::new(UInt8Array::from(subnet_values)));
|
||||
|
||||
let result = func.eval(&ctx, &[ip_input, subnet_input]).unwrap();
|
||||
let result = result.as_any().downcast_ref::<StringVector>().unwrap();
|
||||
let args = ScalarFunctionArgs {
|
||||
args: vec![arg0, arg1],
|
||||
arg_fields: vec![],
|
||||
number_rows: 3,
|
||||
return_field: Arc::new(Field::new("x", DataType::Utf8View, false)),
|
||||
config_options: Arc::new(Default::default()),
|
||||
};
|
||||
let result = func.invoke_with_args(args).unwrap();
|
||||
let result = result.to_array(3).unwrap();
|
||||
let result = result.as_string_view();
|
||||
|
||||
assert_eq!(result.get_data(0).unwrap(), "2001:db8::/48");
|
||||
assert_eq!(result.get_data(1).unwrap(), "fe80::/10");
|
||||
assert_eq!(result.get_data(2).unwrap(), "2001:db8:1234::/56");
|
||||
assert_eq!(result.value(0), "2001:db8::/48");
|
||||
assert_eq!(result.value(1), "fe80::/10");
|
||||
assert_eq!(result.value(2), "2001:db8:1234::/56");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_inputs() {
|
||||
let ipv4_func = Ipv4ToCidr;
|
||||
let ipv6_func = Ipv6ToCidr;
|
||||
let ctx = FunctionContext::default();
|
||||
|
||||
// Empty string should fail
|
||||
let empty_values = vec![""];
|
||||
let empty_input = Arc::new(StringVector::from_slice(&empty_values)) as VectorRef;
|
||||
let arg0 = ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(&empty_values)));
|
||||
|
||||
let ipv4_result = ipv4_func.eval(&ctx, std::slice::from_ref(&empty_input));
|
||||
let ipv6_result = ipv6_func.eval(&ctx, std::slice::from_ref(&empty_input));
|
||||
let args = ScalarFunctionArgs {
|
||||
args: vec![arg0],
|
||||
arg_fields: vec![],
|
||||
number_rows: 1,
|
||||
return_field: Arc::new(Field::new("x", DataType::Utf8View, false)),
|
||||
config_options: Arc::new(Default::default()),
|
||||
};
|
||||
let ipv4_result = ipv4_func.invoke_with_args(args.clone());
|
||||
let ipv6_result = ipv6_func.invoke_with_args(args);
|
||||
|
||||
assert!(ipv4_result.is_err());
|
||||
assert!(ipv6_result.is_err());
|
||||
|
||||
// Invalid IP formats should fail
|
||||
let invalid_values = vec!["not an ip", "192.168.1.256", "zzzz::ffff"];
|
||||
let invalid_input = Arc::new(StringVector::from_slice(&invalid_values)) as VectorRef;
|
||||
let arg0 =
|
||||
ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(&invalid_values)));
|
||||
|
||||
let ipv4_result = ipv4_func.eval(&ctx, std::slice::from_ref(&invalid_input));
|
||||
let args = ScalarFunctionArgs {
|
||||
args: vec![arg0],
|
||||
arg_fields: vec![],
|
||||
number_rows: 3,
|
||||
return_field: Arc::new(Field::new("x", DataType::Utf8View, false)),
|
||||
config_options: Arc::new(Default::default()),
|
||||
};
|
||||
let ipv4_result = ipv4_func.invoke_with_args(args);
|
||||
|
||||
assert!(ipv4_result.is_err());
|
||||
}
|
||||
|
||||
@@ -14,16 +14,16 @@
|
||||
|
||||
use std::net::Ipv4Addr;
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
|
||||
use common_query::error::{InvalidFuncArgsSnafu, Result};
|
||||
use datafusion_expr::{Signature, TypeSignature, Volatility};
|
||||
use datatypes::arrow::datatypes::DataType;
|
||||
use datatypes::scalars::ScalarVectorBuilder;
|
||||
use datatypes::vectors::{MutableVector, StringVectorBuilder, UInt32VectorBuilder, VectorRef};
|
||||
use datafusion_common::arrow::array::{Array, AsArray, StringViewBuilder, UInt32Builder};
|
||||
use datafusion_common::arrow::compute;
|
||||
use datafusion_common::arrow::datatypes::{DataType, UInt32Type};
|
||||
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, TypeSignature, Volatility};
|
||||
use derive_more::Display;
|
||||
use snafu::ensure;
|
||||
|
||||
use crate::function::{Function, FunctionContext};
|
||||
use crate::function::{Function, extract_args};
|
||||
|
||||
/// Function that converts a UInt32 number to an IPv4 address string.
|
||||
///
|
||||
@@ -53,7 +53,7 @@ impl Function for Ipv4NumToString {
|
||||
}
|
||||
|
||||
fn return_type(&self, _: &[DataType]) -> Result<DataType> {
|
||||
Ok(DataType::Utf8)
|
||||
Ok(DataType::Utf8View)
|
||||
}
|
||||
|
||||
fn signature(&self) -> Signature {
|
||||
@@ -63,22 +63,20 @@ impl Function for Ipv4NumToString {
|
||||
)
|
||||
}
|
||||
|
||||
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
|
||||
ensure!(
|
||||
columns.len() == 1,
|
||||
InvalidFuncArgsSnafu {
|
||||
err_msg: format!("Expected 1 argument, got {}", columns.len())
|
||||
}
|
||||
);
|
||||
fn invoke_with_args(
|
||||
&self,
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
let [arg0] = extract_args(self.name(), &args)?;
|
||||
let uint_vec = arg0.as_primitive::<UInt32Type>();
|
||||
|
||||
let uint_vec = &columns[0];
|
||||
let size = uint_vec.len();
|
||||
let mut results = StringVectorBuilder::with_capacity(size);
|
||||
let mut builder = StringViewBuilder::with_capacity(size);
|
||||
|
||||
for i in 0..size {
|
||||
let ip_num = uint_vec.get(i);
|
||||
let ip_num = uint_vec.is_valid(i).then(|| uint_vec.value(i));
|
||||
let ip_str = match ip_num {
|
||||
datatypes::value::Value::UInt32(num) => {
|
||||
Some(num) => {
|
||||
// Convert UInt32 to IPv4 string (A.B.C.D format)
|
||||
let a = (num >> 24) & 0xFF;
|
||||
let b = (num >> 16) & 0xFF;
|
||||
@@ -89,10 +87,10 @@ impl Function for Ipv4NumToString {
|
||||
_ => None,
|
||||
};
|
||||
|
||||
results.push(ip_str.as_deref());
|
||||
builder.append_option(ip_str.as_deref());
|
||||
}
|
||||
|
||||
Ok(results.to_vector())
|
||||
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
|
||||
}
|
||||
|
||||
fn aliases(&self) -> &[String] {
|
||||
@@ -123,23 +121,21 @@ impl Function for Ipv4StringToNum {
|
||||
Signature::string(1, Volatility::Immutable)
|
||||
}
|
||||
|
||||
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
|
||||
ensure!(
|
||||
columns.len() == 1,
|
||||
InvalidFuncArgsSnafu {
|
||||
err_msg: format!("Expected 1 argument, got {}", columns.len())
|
||||
}
|
||||
);
|
||||
fn invoke_with_args(
|
||||
&self,
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
let [arg0] = extract_args(self.name(), &args)?;
|
||||
|
||||
let ip_vec = &columns[0];
|
||||
let arg0 = compute::cast(&arg0, &DataType::Utf8View)?;
|
||||
let ip_vec = arg0.as_string_view();
|
||||
let size = ip_vec.len();
|
||||
let mut results = UInt32VectorBuilder::with_capacity(size);
|
||||
let mut builder = UInt32Builder::with_capacity(size);
|
||||
|
||||
for i in 0..size {
|
||||
let ip_str = ip_vec.get(i);
|
||||
let ip_str = ip_vec.is_valid(i).then(|| ip_vec.value(i));
|
||||
let ip_num = match ip_str {
|
||||
datatypes::value::Value::String(s) => {
|
||||
let ip_str = s.as_utf8();
|
||||
Some(ip_str) => {
|
||||
let ip_addr = Ipv4Addr::from_str(ip_str).map_err(|_| {
|
||||
InvalidFuncArgsSnafu {
|
||||
err_msg: format!("Invalid IPv4 address format: {}", ip_str),
|
||||
@@ -151,10 +147,10 @@ impl Function for Ipv4StringToNum {
|
||||
_ => None,
|
||||
};
|
||||
|
||||
results.push(ip_num);
|
||||
builder.append_option(ip_num);
|
||||
}
|
||||
|
||||
Ok(results.to_vector())
|
||||
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -162,66 +158,92 @@ impl Function for Ipv4StringToNum {
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use datatypes::scalars::ScalarVector;
|
||||
use datatypes::vectors::{StringVector, UInt32Vector};
|
||||
use arrow_schema::Field;
|
||||
use datafusion_common::arrow::array::{StringViewArray, UInt32Array};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_ipv4_num_to_string() {
|
||||
let func = Ipv4NumToString::default();
|
||||
let ctx = FunctionContext::default();
|
||||
|
||||
// Test data
|
||||
let values = vec![167772161u32, 3232235521u32, 0u32, 4294967295u32];
|
||||
let input = Arc::new(UInt32Vector::from_vec(values)) as VectorRef;
|
||||
let input = ColumnarValue::Array(Arc::new(UInt32Array::from(values)));
|
||||
|
||||
let result = func.eval(&ctx, &[input]).unwrap();
|
||||
let result = result.as_any().downcast_ref::<StringVector>().unwrap();
|
||||
let args = ScalarFunctionArgs {
|
||||
args: vec![input],
|
||||
arg_fields: vec![],
|
||||
number_rows: 4,
|
||||
return_field: Arc::new(Field::new("x", DataType::Utf8View, false)),
|
||||
config_options: Arc::new(Default::default()),
|
||||
};
|
||||
let result = func.invoke_with_args(args).unwrap();
|
||||
let result = result.to_array(4).unwrap();
|
||||
let result = result.as_string_view();
|
||||
|
||||
assert_eq!(result.get_data(0).unwrap(), "10.0.0.1");
|
||||
assert_eq!(result.get_data(1).unwrap(), "192.168.0.1");
|
||||
assert_eq!(result.get_data(2).unwrap(), "0.0.0.0");
|
||||
assert_eq!(result.get_data(3).unwrap(), "255.255.255.255");
|
||||
assert_eq!(result.value(0), "10.0.0.1");
|
||||
assert_eq!(result.value(1), "192.168.0.1");
|
||||
assert_eq!(result.value(2), "0.0.0.0");
|
||||
assert_eq!(result.value(3), "255.255.255.255");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ipv4_string_to_num() {
|
||||
let func = Ipv4StringToNum;
|
||||
let ctx = FunctionContext::default();
|
||||
|
||||
// Test data
|
||||
let values = vec!["10.0.0.1", "192.168.0.1", "0.0.0.0", "255.255.255.255"];
|
||||
let input = Arc::new(StringVector::from_slice(&values)) as VectorRef;
|
||||
let input = ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(&values)));
|
||||
|
||||
let result = func.eval(&ctx, &[input]).unwrap();
|
||||
let result = result.as_any().downcast_ref::<UInt32Vector>().unwrap();
|
||||
let args = ScalarFunctionArgs {
|
||||
args: vec![input],
|
||||
arg_fields: vec![],
|
||||
number_rows: 4,
|
||||
return_field: Arc::new(Field::new("x", DataType::UInt32, false)),
|
||||
config_options: Arc::new(Default::default()),
|
||||
};
|
||||
let result = func.invoke_with_args(args).unwrap();
|
||||
let result = result.to_array(4).unwrap();
|
||||
let result = result.as_primitive::<UInt32Type>();
|
||||
|
||||
assert_eq!(result.get_data(0).unwrap(), 167772161);
|
||||
assert_eq!(result.get_data(1).unwrap(), 3232235521);
|
||||
assert_eq!(result.get_data(2).unwrap(), 0);
|
||||
assert_eq!(result.get_data(3).unwrap(), 4294967295);
|
||||
assert_eq!(result.value(0), 167772161);
|
||||
assert_eq!(result.value(1), 3232235521);
|
||||
assert_eq!(result.value(2), 0);
|
||||
assert_eq!(result.value(3), 4294967295);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ipv4_conversions_roundtrip() {
|
||||
let to_num = Ipv4StringToNum;
|
||||
let to_string = Ipv4NumToString::default();
|
||||
let ctx = FunctionContext::default();
|
||||
|
||||
// Test data for string to num to string
|
||||
let values = vec!["10.0.0.1", "192.168.0.1", "0.0.0.0", "255.255.255.255"];
|
||||
let input = Arc::new(StringVector::from_slice(&values)) as VectorRef;
|
||||
let input = ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(&values)));
|
||||
|
||||
let num_result = to_num.eval(&ctx, &[input]).unwrap();
|
||||
let back_to_string = to_string.eval(&ctx, &[num_result]).unwrap();
|
||||
let str_result = back_to_string
|
||||
.as_any()
|
||||
.downcast_ref::<StringVector>()
|
||||
.unwrap();
|
||||
let args = ScalarFunctionArgs {
|
||||
args: vec![input],
|
||||
arg_fields: vec![],
|
||||
number_rows: 4,
|
||||
return_field: Arc::new(Field::new("x", DataType::UInt32, false)),
|
||||
config_options: Arc::new(Default::default()),
|
||||
};
|
||||
let result = to_num.invoke_with_args(args).unwrap();
|
||||
|
||||
let args = ScalarFunctionArgs {
|
||||
args: vec![result],
|
||||
arg_fields: vec![],
|
||||
number_rows: 4,
|
||||
return_field: Arc::new(Field::new("x", DataType::Utf8View, false)),
|
||||
config_options: Arc::new(Default::default()),
|
||||
};
|
||||
let result = to_string.invoke_with_args(args).unwrap();
|
||||
let result = result.to_array(4).unwrap();
|
||||
let result = result.as_string_view();
|
||||
|
||||
for (i, expected) in values.iter().enumerate() {
|
||||
assert_eq!(str_result.get_data(i).unwrap(), *expected);
|
||||
assert_eq!(result.value(i), *expected);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,17 +14,17 @@
|
||||
|
||||
use std::net::{Ipv4Addr, Ipv6Addr};
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
|
||||
use common_query::error::{InvalidFuncArgsSnafu, Result};
|
||||
use datafusion::arrow::datatypes::DataType;
|
||||
use datafusion_expr::{Signature, Volatility};
|
||||
use datatypes::prelude::Value;
|
||||
use datatypes::scalars::ScalarVectorBuilder;
|
||||
use datatypes::vectors::{BinaryVectorBuilder, MutableVector, StringVectorBuilder, VectorRef};
|
||||
use datafusion_common::DataFusionError;
|
||||
use datafusion_common::arrow::array::{Array, AsArray, BinaryViewBuilder, StringViewBuilder};
|
||||
use datafusion_common::arrow::compute;
|
||||
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, Volatility};
|
||||
use derive_more::Display;
|
||||
use snafu::ensure;
|
||||
|
||||
use crate::function::{Function, FunctionContext};
|
||||
use crate::function::{Function, extract_args};
|
||||
|
||||
/// Function that converts a hex string representation of an IPv6 address to a formatted string.
|
||||
///
|
||||
@@ -41,30 +41,29 @@ impl Function for Ipv6NumToString {
|
||||
}
|
||||
|
||||
fn return_type(&self, _: &[DataType]) -> Result<DataType> {
|
||||
Ok(DataType::Utf8)
|
||||
Ok(DataType::Utf8View)
|
||||
}
|
||||
|
||||
fn signature(&self) -> Signature {
|
||||
Signature::string(1, Volatility::Immutable)
|
||||
}
|
||||
|
||||
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
|
||||
ensure!(
|
||||
columns.len() == 1,
|
||||
InvalidFuncArgsSnafu {
|
||||
err_msg: format!("Expected 1 argument, got {}", columns.len())
|
||||
}
|
||||
);
|
||||
fn invoke_with_args(
|
||||
&self,
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
let [arg0] = extract_args(self.name(), &args)?;
|
||||
|
||||
let hex_vec = &columns[0];
|
||||
let arg0 = compute::cast(&arg0, &DataType::Utf8View)?;
|
||||
let hex_vec = arg0.as_string_view();
|
||||
let size = hex_vec.len();
|
||||
let mut results = StringVectorBuilder::with_capacity(size);
|
||||
let mut builder = StringViewBuilder::with_capacity(size);
|
||||
|
||||
for i in 0..size {
|
||||
let hex_str = hex_vec.get(i);
|
||||
let hex_str = hex_vec.is_valid(i).then(|| hex_vec.value(i));
|
||||
let ip_str = match hex_str {
|
||||
Value::String(s) => {
|
||||
let hex_str = s.as_utf8().to_lowercase();
|
||||
Some(s) => {
|
||||
let hex_str = s.to_lowercase();
|
||||
|
||||
// Validate and convert hex string to bytes
|
||||
let bytes = if hex_str.len() == 32 {
|
||||
@@ -80,10 +79,10 @@ impl Function for Ipv6NumToString {
|
||||
}
|
||||
bytes
|
||||
} else {
|
||||
return InvalidFuncArgsSnafu {
|
||||
err_msg: format!("Expected 32 hex characters, got {}", hex_str.len()),
|
||||
}
|
||||
.fail();
|
||||
return Err(DataFusionError::Execution(format!(
|
||||
"expecting 32 hex characters, got {}",
|
||||
hex_str.len()
|
||||
)));
|
||||
};
|
||||
|
||||
// Convert bytes to IPv6 address
|
||||
@@ -106,10 +105,10 @@ impl Function for Ipv6NumToString {
|
||||
_ => None,
|
||||
};
|
||||
|
||||
results.push(ip_str.as_deref());
|
||||
builder.append_option(ip_str.as_deref());
|
||||
}
|
||||
|
||||
Ok(results.to_vector())
|
||||
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -130,31 +129,28 @@ impl Function for Ipv6StringToNum {
|
||||
}
|
||||
|
||||
fn return_type(&self, _: &[DataType]) -> Result<DataType> {
|
||||
Ok(DataType::Binary)
|
||||
Ok(DataType::BinaryView)
|
||||
}
|
||||
|
||||
fn signature(&self) -> Signature {
|
||||
Signature::string(1, Volatility::Immutable)
|
||||
}
|
||||
|
||||
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
|
||||
ensure!(
|
||||
columns.len() == 1,
|
||||
InvalidFuncArgsSnafu {
|
||||
err_msg: format!("Expected 1 argument, got {}", columns.len())
|
||||
}
|
||||
);
|
||||
fn invoke_with_args(
|
||||
&self,
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
let [arg0] = extract_args(self.name(), &args)?;
|
||||
let arg0 = compute::cast(&arg0, &DataType::Utf8View)?;
|
||||
let ip_vec = arg0.as_string_view();
|
||||
|
||||
let ip_vec = &columns[0];
|
||||
let size = ip_vec.len();
|
||||
let mut results = BinaryVectorBuilder::with_capacity(size);
|
||||
let mut builder = BinaryViewBuilder::with_capacity(size);
|
||||
|
||||
for i in 0..size {
|
||||
let ip_str = ip_vec.get(i);
|
||||
let ip_str = ip_vec.is_valid(i).then(|| ip_vec.value(i));
|
||||
let ip_binary = match ip_str {
|
||||
Value::String(s) => {
|
||||
let addr_str = s.as_utf8();
|
||||
|
||||
Some(addr_str) => {
|
||||
let addr = if let Ok(ipv6) = Ipv6Addr::from_str(addr_str) {
|
||||
// Direct IPv6 address
|
||||
ipv6
|
||||
@@ -163,10 +159,10 @@ impl Function for Ipv6StringToNum {
|
||||
ipv4.to_ipv6_mapped()
|
||||
} else {
|
||||
// Invalid format
|
||||
return InvalidFuncArgsSnafu {
|
||||
err_msg: format!("Invalid IPv6 address format: {}", addr_str),
|
||||
}
|
||||
.fail();
|
||||
return Err(DataFusionError::Execution(format!(
|
||||
"Invalid IPv6 address format: {}",
|
||||
addr_str
|
||||
)));
|
||||
};
|
||||
|
||||
// Convert IPv6 address to binary (16 bytes)
|
||||
@@ -176,10 +172,10 @@ impl Function for Ipv6StringToNum {
|
||||
_ => None,
|
||||
};
|
||||
|
||||
results.push(ip_binary.as_deref());
|
||||
builder.append_option(ip_binary.as_deref());
|
||||
}
|
||||
|
||||
Ok(results.to_vector())
|
||||
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -188,15 +184,14 @@ mod tests {
|
||||
use std::fmt::Write;
|
||||
use std::sync::Arc;
|
||||
|
||||
use datatypes::scalars::ScalarVector;
|
||||
use datatypes::vectors::{BinaryVector, StringVector, Vector};
|
||||
use arrow_schema::Field;
|
||||
use datafusion_common::arrow::array::StringViewArray;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_ipv6_num_to_string() {
|
||||
let func = Ipv6NumToString;
|
||||
let ctx = FunctionContext::default();
|
||||
|
||||
// Hex string for "2001:db8::1"
|
||||
let hex_str1 = "20010db8000000000000000000000001";
|
||||
@@ -205,62 +200,93 @@ mod tests {
|
||||
let hex_str2 = "00000000000000000000ffffc0a80001";
|
||||
|
||||
let values = vec![hex_str1, hex_str2];
|
||||
let input = Arc::new(StringVector::from_slice(&values)) as VectorRef;
|
||||
let arg0 = ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(&values)));
|
||||
|
||||
let result = func.eval(&ctx, &[input]).unwrap();
|
||||
let result = result.as_any().downcast_ref::<StringVector>().unwrap();
|
||||
let args = ScalarFunctionArgs {
|
||||
args: vec![arg0],
|
||||
arg_fields: vec![],
|
||||
number_rows: 2,
|
||||
return_field: Arc::new(Field::new("x", DataType::Utf8View, false)),
|
||||
config_options: Arc::new(Default::default()),
|
||||
};
|
||||
let result = func.invoke_with_args(args).unwrap();
|
||||
let result = result.to_array(2).unwrap();
|
||||
let result = result.as_string_view();
|
||||
|
||||
assert_eq!(result.get_data(0).unwrap(), "2001:db8::1");
|
||||
assert_eq!(result.get_data(1).unwrap(), "::ffff:192.168.0.1");
|
||||
assert_eq!(result.value(0), "2001:db8::1");
|
||||
assert_eq!(result.value(1), "::ffff:192.168.0.1");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ipv6_num_to_string_uppercase() {
|
||||
let func = Ipv6NumToString;
|
||||
let ctx = FunctionContext::default();
|
||||
|
||||
// Uppercase hex string for "2001:db8::1"
|
||||
let hex_str = "20010DB8000000000000000000000001";
|
||||
|
||||
let values = vec![hex_str];
|
||||
let input = Arc::new(StringVector::from_slice(&values)) as VectorRef;
|
||||
let arg0 = ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(&values)));
|
||||
|
||||
let result = func.eval(&ctx, &[input]).unwrap();
|
||||
let result = result.as_any().downcast_ref::<StringVector>().unwrap();
|
||||
let args = ScalarFunctionArgs {
|
||||
args: vec![arg0],
|
||||
arg_fields: vec![],
|
||||
number_rows: 1,
|
||||
return_field: Arc::new(Field::new("x", DataType::Utf8View, false)),
|
||||
config_options: Arc::new(Default::default()),
|
||||
};
|
||||
let result = func.invoke_with_args(args).unwrap();
|
||||
let result = result.to_array(1).unwrap();
|
||||
let result = result.as_string_view();
|
||||
|
||||
assert_eq!(result.get_data(0).unwrap(), "2001:db8::1");
|
||||
assert_eq!(result.value(0), "2001:db8::1");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ipv6_num_to_string_error() {
|
||||
let func = Ipv6NumToString;
|
||||
let ctx = FunctionContext::default();
|
||||
|
||||
// Invalid hex string - wrong length
|
||||
let hex_str = "20010db8";
|
||||
|
||||
let values = vec![hex_str];
|
||||
let input = Arc::new(StringVector::from_slice(&values)) as VectorRef;
|
||||
let arg0 = ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(&values)));
|
||||
|
||||
// Should return an error
|
||||
let result = func.eval(&ctx, &[input]);
|
||||
let args = ScalarFunctionArgs {
|
||||
args: vec![arg0],
|
||||
arg_fields: vec![],
|
||||
number_rows: 2,
|
||||
return_field: Arc::new(Field::new("x", DataType::Utf8View, false)),
|
||||
config_options: Arc::new(Default::default()),
|
||||
};
|
||||
let result = func.invoke_with_args(args);
|
||||
assert!(result.is_err());
|
||||
|
||||
// Check that the error message contains expected text
|
||||
let error_msg = result.unwrap_err().to_string();
|
||||
assert!(error_msg.contains("Expected 32 hex characters"));
|
||||
assert_eq!(
|
||||
error_msg,
|
||||
"Execution error: expecting 32 hex characters, got 8"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ipv6_string_to_num() {
|
||||
let func = Ipv6StringToNum;
|
||||
let ctx = FunctionContext::default();
|
||||
|
||||
let values = vec!["2001:db8::1", "::ffff:192.168.0.1", "192.168.0.1"];
|
||||
let input = Arc::new(StringVector::from_slice(&values)) as VectorRef;
|
||||
let arg0 = ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(&values)));
|
||||
|
||||
let result = func.eval(&ctx, &[input]).unwrap();
|
||||
let result = result.as_any().downcast_ref::<BinaryVector>().unwrap();
|
||||
let args = ScalarFunctionArgs {
|
||||
args: vec![arg0],
|
||||
arg_fields: vec![],
|
||||
number_rows: 3,
|
||||
return_field: Arc::new(Field::new("x", DataType::Utf8View, false)),
|
||||
config_options: Arc::new(Default::default()),
|
||||
};
|
||||
let result = func.invoke_with_args(args).unwrap();
|
||||
let result = result.to_array(3).unwrap();
|
||||
let result = result.as_binary_view();
|
||||
|
||||
// Expected binary for "2001:db8::1"
|
||||
let expected_1 = [
|
||||
@@ -272,33 +298,37 @@ mod tests {
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xFF, 0xFF, 0xC0, 0xA8, 0, 0x01,
|
||||
];
|
||||
|
||||
assert_eq!(result.get_data(0).unwrap(), &expected_1);
|
||||
assert_eq!(result.get_data(1).unwrap(), &expected_2);
|
||||
assert_eq!(result.get_data(2).unwrap(), &expected_2);
|
||||
assert_eq!(result.value(0), &expected_1);
|
||||
assert_eq!(result.value(1), &expected_2);
|
||||
assert_eq!(result.value(2), &expected_2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ipv6_conversions_roundtrip() {
|
||||
let to_num = Ipv6StringToNum;
|
||||
let to_string = Ipv6NumToString;
|
||||
let ctx = FunctionContext::default();
|
||||
|
||||
// Test data
|
||||
let values = vec!["2001:db8::1", "::ffff:192.168.0.1"];
|
||||
let input = Arc::new(StringVector::from_slice(&values)) as VectorRef;
|
||||
let arg0 = ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(&values)));
|
||||
|
||||
// Convert IPv6 addresses to binary
|
||||
let binary_result = to_num.eval(&ctx, std::slice::from_ref(&input)).unwrap();
|
||||
let args = ScalarFunctionArgs {
|
||||
args: vec![arg0],
|
||||
arg_fields: vec![],
|
||||
number_rows: 2,
|
||||
return_field: Arc::new(Field::new("x", DataType::BinaryView, false)),
|
||||
config_options: Arc::new(Default::default()),
|
||||
};
|
||||
let result = to_num.invoke_with_args(args).unwrap();
|
||||
|
||||
// Convert binary to hex string representation (for ipv6_num_to_string)
|
||||
let mut hex_strings = Vec::new();
|
||||
let binary_vector = binary_result
|
||||
.as_any()
|
||||
.downcast_ref::<BinaryVector>()
|
||||
.unwrap();
|
||||
let result = result.to_array(2).unwrap();
|
||||
let binary_vector = result.as_binary_view();
|
||||
|
||||
for i in 0..binary_vector.len() {
|
||||
let bytes = binary_vector.get_data(i).unwrap();
|
||||
let bytes = binary_vector.value(i);
|
||||
let hex = bytes.iter().fold(String::new(), |mut acc, b| {
|
||||
write!(&mut acc, "{:02x}", b).unwrap();
|
||||
acc
|
||||
@@ -307,18 +337,23 @@ mod tests {
|
||||
}
|
||||
|
||||
let hex_str_refs: Vec<&str> = hex_strings.iter().map(|s| s.as_str()).collect();
|
||||
let hex_input = Arc::new(StringVector::from_slice(&hex_str_refs)) as VectorRef;
|
||||
let arg0 = ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(&hex_str_refs)));
|
||||
|
||||
// Now convert hex to formatted string
|
||||
let string_result = to_string.eval(&ctx, &[hex_input]).unwrap();
|
||||
let str_result = string_result
|
||||
.as_any()
|
||||
.downcast_ref::<StringVector>()
|
||||
.unwrap();
|
||||
let args = ScalarFunctionArgs {
|
||||
args: vec![arg0],
|
||||
arg_fields: vec![],
|
||||
number_rows: 2,
|
||||
return_field: Arc::new(Field::new("x", DataType::Utf8View, false)),
|
||||
config_options: Arc::new(Default::default()),
|
||||
};
|
||||
let result = to_string.invoke_with_args(args).unwrap();
|
||||
let result = result.to_array(2).unwrap();
|
||||
let result = result.as_string_view();
|
||||
|
||||
// Compare with original input
|
||||
assert_eq!(str_result.get_data(0).unwrap(), values[0]);
|
||||
assert_eq!(str_result.get_data(1).unwrap(), values[1]);
|
||||
assert_eq!(result.value(0), values[0]);
|
||||
assert_eq!(result.value(1), values[1]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -327,24 +362,35 @@ mod tests {
|
||||
// can be converted back using ipv6_string_to_num
|
||||
let to_string = Ipv6NumToString;
|
||||
let to_binary = Ipv6StringToNum;
|
||||
let ctx = FunctionContext::default();
|
||||
|
||||
// Hex representation of IPv6 addresses
|
||||
let hex_values = vec![
|
||||
"20010db8000000000000000000000001",
|
||||
"00000000000000000000ffffc0a80001",
|
||||
];
|
||||
let hex_input = Arc::new(StringVector::from_slice(&hex_values)) as VectorRef;
|
||||
let arg0 = ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(&hex_values)));
|
||||
|
||||
// Convert hex to string representation
|
||||
let string_result = to_string.eval(&ctx, &[hex_input]).unwrap();
|
||||
let args = ScalarFunctionArgs {
|
||||
args: vec![arg0],
|
||||
arg_fields: vec![],
|
||||
number_rows: 2,
|
||||
return_field: Arc::new(Field::new("x", DataType::Utf8View, false)),
|
||||
config_options: Arc::new(Default::default()),
|
||||
};
|
||||
let result = to_string.invoke_with_args(args).unwrap();
|
||||
|
||||
// Then convert string representation back to binary
|
||||
let binary_result = to_binary.eval(&ctx, &[string_result]).unwrap();
|
||||
let bin_result = binary_result
|
||||
.as_any()
|
||||
.downcast_ref::<BinaryVector>()
|
||||
.unwrap();
|
||||
let args = ScalarFunctionArgs {
|
||||
args: vec![result],
|
||||
arg_fields: vec![],
|
||||
number_rows: 2,
|
||||
return_field: Arc::new(Field::new("x", DataType::BinaryView, false)),
|
||||
config_options: Arc::new(Default::default()),
|
||||
};
|
||||
let result = to_binary.invoke_with_args(args).unwrap();
|
||||
let result = result.to_array(2).unwrap();
|
||||
let result = result.as_binary_view();
|
||||
|
||||
// Expected binary values
|
||||
let expected_bin1 = [
|
||||
@@ -354,7 +400,7 @@ mod tests {
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xFF, 0xFF, 0xC0, 0xA8, 0, 0x01,
|
||||
];
|
||||
|
||||
assert_eq!(bin_result.get_data(0).unwrap(), &expected_bin1);
|
||||
assert_eq!(bin_result.get_data(1).unwrap(), &expected_bin2);
|
||||
assert_eq!(result.value(0), &expected_bin1);
|
||||
assert_eq!(result.value(1), &expected_bin2);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,17 +14,18 @@
|
||||
|
||||
use std::net::{Ipv4Addr, Ipv6Addr};
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
|
||||
use common_query::error::{InvalidFuncArgsSnafu, Result};
|
||||
use datafusion::arrow::datatypes::DataType;
|
||||
use datafusion_expr::{Signature, Volatility};
|
||||
use datatypes::prelude::Value;
|
||||
use datatypes::scalars::ScalarVectorBuilder;
|
||||
use datatypes::vectors::{BooleanVectorBuilder, MutableVector, VectorRef};
|
||||
use datafusion_common::DataFusionError;
|
||||
use datafusion_common::arrow::array::{Array, AsArray, BooleanBuilder};
|
||||
use datafusion_common::arrow::compute;
|
||||
use datafusion_common::arrow::datatypes::DataType;
|
||||
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, Volatility};
|
||||
use derive_more::Display;
|
||||
use snafu::ensure;
|
||||
|
||||
use crate::function::{Function, FunctionContext};
|
||||
use crate::function::{Function, extract_args};
|
||||
|
||||
/// Function that checks if an IPv4 address is within a specified CIDR range.
|
||||
///
|
||||
@@ -52,42 +53,31 @@ impl Function for Ipv4InRange {
|
||||
Signature::string(2, Volatility::Immutable)
|
||||
}
|
||||
|
||||
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
|
||||
ensure!(
|
||||
columns.len() == 2,
|
||||
InvalidFuncArgsSnafu {
|
||||
err_msg: format!("Expected 2 arguments, got {}", columns.len())
|
||||
}
|
||||
);
|
||||
fn invoke_with_args(
|
||||
&self,
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
let [arg0, arg1] = extract_args(self.name(), &args)?;
|
||||
|
||||
let arg0 = compute::cast(&arg0, &DataType::Utf8View)?;
|
||||
let ip_vec = arg0.as_string_view();
|
||||
let arg1 = compute::cast(&arg1, &DataType::Utf8View)?;
|
||||
let ranges = arg1.as_string_view();
|
||||
|
||||
let ip_vec = &columns[0];
|
||||
let range_vec = &columns[1];
|
||||
let size = ip_vec.len();
|
||||
|
||||
ensure!(
|
||||
range_vec.len() == size,
|
||||
InvalidFuncArgsSnafu {
|
||||
err_msg: "IP addresses and CIDR ranges must have the same number of rows"
|
||||
.to_string()
|
||||
}
|
||||
);
|
||||
|
||||
let mut results = BooleanVectorBuilder::with_capacity(size);
|
||||
let mut builder = BooleanBuilder::with_capacity(size);
|
||||
|
||||
for i in 0..size {
|
||||
let ip = ip_vec.get(i);
|
||||
let range = range_vec.get(i);
|
||||
let ip = ip_vec.is_valid(i).then(|| ip_vec.value(i));
|
||||
let range = ranges.is_valid(i).then(|| ranges.value(i));
|
||||
|
||||
let in_range = match (ip, range) {
|
||||
(Value::String(ip_str), Value::String(range_str)) => {
|
||||
let ip_str = ip_str.as_utf8().trim();
|
||||
let range_str = range_str.as_utf8().trim();
|
||||
|
||||
(Some(ip_str), Some(range_str)) => {
|
||||
if ip_str.is_empty() || range_str.is_empty() {
|
||||
return InvalidFuncArgsSnafu {
|
||||
err_msg: "IP address and CIDR range cannot be empty".to_string(),
|
||||
}
|
||||
.fail();
|
||||
return Err(DataFusionError::Execution(
|
||||
"IP address or CIDR range cannot be empty".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
// Parse the IP address
|
||||
@@ -107,10 +97,10 @@ impl Function for Ipv4InRange {
|
||||
_ => None,
|
||||
};
|
||||
|
||||
results.push(in_range);
|
||||
builder.append_option(in_range);
|
||||
}
|
||||
|
||||
Ok(results.to_vector())
|
||||
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -141,42 +131,29 @@ impl Function for Ipv6InRange {
|
||||
Signature::string(2, Volatility::Immutable)
|
||||
}
|
||||
|
||||
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
|
||||
ensure!(
|
||||
columns.len() == 2,
|
||||
InvalidFuncArgsSnafu {
|
||||
err_msg: format!("Expected 2 arguments, got {}", columns.len())
|
||||
}
|
||||
);
|
||||
fn invoke_with_args(
|
||||
&self,
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
let [arg0, arg1] = extract_args(self.name(), &args)?;
|
||||
|
||||
let ip_vec = &columns[0];
|
||||
let range_vec = &columns[1];
|
||||
let arg0 = compute::cast(&arg0, &DataType::Utf8View)?;
|
||||
let ip_vec = arg0.as_string_view();
|
||||
let arg1 = compute::cast(&arg1, &DataType::Utf8View)?;
|
||||
let ranges = arg1.as_string_view();
|
||||
let size = ip_vec.len();
|
||||
|
||||
ensure!(
|
||||
range_vec.len() == size,
|
||||
InvalidFuncArgsSnafu {
|
||||
err_msg: "IP addresses and CIDR ranges must have the same number of rows"
|
||||
.to_string()
|
||||
}
|
||||
);
|
||||
|
||||
let mut results = BooleanVectorBuilder::with_capacity(size);
|
||||
let mut builder = BooleanBuilder::with_capacity(size);
|
||||
|
||||
for i in 0..size {
|
||||
let ip = ip_vec.get(i);
|
||||
let range = range_vec.get(i);
|
||||
let ip = ip_vec.is_valid(i).then(|| ip_vec.value(i));
|
||||
let range = ranges.is_valid(i).then(|| ranges.value(i));
|
||||
|
||||
let in_range = match (ip, range) {
|
||||
(Value::String(ip_str), Value::String(range_str)) => {
|
||||
let ip_str = ip_str.as_utf8().trim();
|
||||
let range_str = range_str.as_utf8().trim();
|
||||
|
||||
(Some(ip_str), Some(range_str)) => {
|
||||
if ip_str.is_empty() || range_str.is_empty() {
|
||||
return InvalidFuncArgsSnafu {
|
||||
err_msg: "IP address and CIDR range cannot be empty".to_string(),
|
||||
}
|
||||
.fail();
|
||||
return Err(DataFusionError::Execution(
|
||||
"IP address or CIDR range cannot be empty".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
// Parse the IP address
|
||||
@@ -196,10 +173,10 @@ impl Function for Ipv6InRange {
|
||||
_ => None,
|
||||
};
|
||||
|
||||
results.push(in_range);
|
||||
builder.append_option(in_range);
|
||||
}
|
||||
|
||||
Ok(results.to_vector())
|
||||
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -329,15 +306,14 @@ fn is_ipv6_in_range(ip: &Ipv6Addr, cidr_base: &Ipv6Addr, prefix_len: u8) -> Opti
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use datatypes::scalars::ScalarVector;
|
||||
use datatypes::vectors::{BooleanVector, StringVector};
|
||||
use arrow_schema::Field;
|
||||
use datafusion_common::arrow::array::StringViewArray;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_ipv4_in_range() {
|
||||
let func = Ipv4InRange;
|
||||
let ctx = FunctionContext::default();
|
||||
|
||||
// Test IPs
|
||||
let ip_values = vec![
|
||||
@@ -357,24 +333,31 @@ mod tests {
|
||||
"172.16.0.0/16",
|
||||
];
|
||||
|
||||
let ip_input = Arc::new(StringVector::from_slice(&ip_values)) as VectorRef;
|
||||
let cidr_input = Arc::new(StringVector::from_slice(&cidr_values)) as VectorRef;
|
||||
let arg0 = ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(&ip_values)));
|
||||
let arg1 = ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(&cidr_values)));
|
||||
|
||||
let result = func.eval(&ctx, &[ip_input, cidr_input]).unwrap();
|
||||
let result = result.as_any().downcast_ref::<BooleanVector>().unwrap();
|
||||
let args = ScalarFunctionArgs {
|
||||
args: vec![arg0, arg1],
|
||||
arg_fields: vec![],
|
||||
number_rows: 5,
|
||||
return_field: Arc::new(Field::new("x", DataType::Boolean, false)),
|
||||
config_options: Arc::new(Default::default()),
|
||||
};
|
||||
let result = func.invoke_with_args(args).unwrap();
|
||||
let result = result.to_array(5).unwrap();
|
||||
let result = result.as_boolean();
|
||||
|
||||
// Expected results
|
||||
assert!(result.get_data(0).unwrap()); // 192.168.1.5 is in 192.168.1.0/24
|
||||
assert!(!result.get_data(1).unwrap()); // 192.168.2.1 is not in 192.168.1.0/24
|
||||
assert!(result.get_data(2).unwrap()); // 10.0.0.1 is in 10.0.0.0/8
|
||||
assert!(result.get_data(3).unwrap()); // 10.1.0.1 is in 10.0.0.0/8
|
||||
assert!(result.get_data(4).unwrap()); // 172.16.0.1 is in 172.16.0.0/16
|
||||
assert!(result.value(0)); // 192.168.1.5 is in 192.168.1.0/24
|
||||
assert!(!result.value(1)); // 192.168.2.1 is not in 192.168.1.0/24
|
||||
assert!(result.value(2)); // 10.0.0.1 is in 10.0.0.0/8
|
||||
assert!(result.value(3)); // 10.1.0.1 is in 10.0.0.0/8
|
||||
assert!(result.value(4)); // 172.16.0.1 is in 172.16.0.0/16
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ipv6_in_range() {
|
||||
let func = Ipv6InRange;
|
||||
let ctx = FunctionContext::default();
|
||||
|
||||
// Test IPs
|
||||
let ip_values = vec![
|
||||
@@ -394,46 +377,70 @@ mod tests {
|
||||
"fe80::/16",
|
||||
];
|
||||
|
||||
let ip_input = Arc::new(StringVector::from_slice(&ip_values)) as VectorRef;
|
||||
let cidr_input = Arc::new(StringVector::from_slice(&cidr_values)) as VectorRef;
|
||||
let arg0 = ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(&ip_values)));
|
||||
let arg1 = ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(&cidr_values)));
|
||||
|
||||
let result = func.eval(&ctx, &[ip_input, cidr_input]).unwrap();
|
||||
let result = result.as_any().downcast_ref::<BooleanVector>().unwrap();
|
||||
let args = ScalarFunctionArgs {
|
||||
args: vec![arg0, arg1],
|
||||
arg_fields: vec![],
|
||||
number_rows: 5,
|
||||
return_field: Arc::new(Field::new("x", DataType::Boolean, false)),
|
||||
config_options: Arc::new(Default::default()),
|
||||
};
|
||||
let result = func.invoke_with_args(args).unwrap();
|
||||
let result = result.to_array(5).unwrap();
|
||||
let result = result.as_boolean();
|
||||
|
||||
// Expected results
|
||||
assert!(result.get_data(0).unwrap()); // 2001:db8::1 is in 2001:db8::/32
|
||||
assert!(result.get_data(1).unwrap()); // 2001:db8:1:: is in 2001:db8::/32
|
||||
assert!(!result.get_data(2).unwrap()); // 2001:db9::1 is not in 2001:db8::/32
|
||||
assert!(result.get_data(3).unwrap()); // ::1 is in ::1/128
|
||||
assert!(result.get_data(4).unwrap()); // fe80::1 is in fe80::/16
|
||||
assert!(result.value(0)); // 2001:db8::1 is in 2001:db8::/32
|
||||
assert!(result.value(1)); // 2001:db8:1:: is in 2001:db8::/32
|
||||
assert!(!result.value(2)); // 2001:db9::1 is not in 2001:db8::/32
|
||||
assert!(result.value(3)); // ::1 is in ::1/128
|
||||
assert!(result.value(4)); // fe80::1 is in fe80::/16
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_inputs() {
|
||||
let ipv4_func = Ipv4InRange;
|
||||
let ipv6_func = Ipv6InRange;
|
||||
let ctx = FunctionContext::default();
|
||||
|
||||
// Invalid IPv4 address
|
||||
let invalid_ip_values = vec!["not-an-ip", "192.168.1.300"];
|
||||
let cidr_values = vec!["192.168.1.0/24", "192.168.1.0/24"];
|
||||
|
||||
let invalid_ip_input = Arc::new(StringVector::from_slice(&invalid_ip_values)) as VectorRef;
|
||||
let cidr_input = Arc::new(StringVector::from_slice(&cidr_values)) as VectorRef;
|
||||
let arg0 = ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(
|
||||
&invalid_ip_values,
|
||||
)));
|
||||
let arg1 = ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(&cidr_values)));
|
||||
|
||||
let result = ipv4_func.eval(&ctx, &[invalid_ip_input, cidr_input]);
|
||||
let args = ScalarFunctionArgs {
|
||||
args: vec![arg0, arg1],
|
||||
arg_fields: vec![],
|
||||
number_rows: 2,
|
||||
return_field: Arc::new(Field::new("x", DataType::Boolean, false)),
|
||||
config_options: Arc::new(Default::default()),
|
||||
};
|
||||
let result = ipv4_func.invoke_with_args(args);
|
||||
assert!(result.is_err());
|
||||
|
||||
// Invalid CIDR notation
|
||||
let ip_values = vec!["192.168.1.1", "2001:db8::1"];
|
||||
let invalid_cidr_values = vec!["192.168.1.0", "2001:db8::/129"];
|
||||
|
||||
let ip_input = Arc::new(StringVector::from_slice(&ip_values)) as VectorRef;
|
||||
let invalid_cidr_input =
|
||||
Arc::new(StringVector::from_slice(&invalid_cidr_values)) as VectorRef;
|
||||
let arg0 = ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(&ip_values)));
|
||||
let arg1 = ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(
|
||||
&invalid_cidr_values,
|
||||
)));
|
||||
|
||||
let ipv4_result = ipv4_func.eval(&ctx, &[ip_input.clone(), invalid_cidr_input.clone()]);
|
||||
let ipv6_result = ipv6_func.eval(&ctx, &[ip_input, invalid_cidr_input]);
|
||||
let args = ScalarFunctionArgs {
|
||||
args: vec![arg0, arg1],
|
||||
arg_fields: vec![],
|
||||
number_rows: 2,
|
||||
return_field: Arc::new(Field::new("x", DataType::Boolean, false)),
|
||||
config_options: Arc::new(Default::default()),
|
||||
};
|
||||
let ipv4_result = ipv4_func.invoke_with_args(args.clone());
|
||||
let ipv6_result = ipv6_func.invoke_with_args(args);
|
||||
|
||||
assert!(ipv4_result.is_err());
|
||||
assert!(ipv6_result.is_err());
|
||||
@@ -442,20 +449,27 @@ mod tests {
|
||||
#[test]
|
||||
fn test_edge_cases() {
|
||||
let ipv4_func = Ipv4InRange;
|
||||
let ctx = FunctionContext::default();
|
||||
|
||||
// Edge cases like prefix length 0 (matches everything) and 32 (exact match)
|
||||
let ip_values = vec!["8.8.8.8", "192.168.1.1", "192.168.1.1"];
|
||||
let cidr_values = vec!["0.0.0.0/0", "192.168.1.1/32", "192.168.1.0/32"];
|
||||
|
||||
let ip_input = Arc::new(StringVector::from_slice(&ip_values)) as VectorRef;
|
||||
let cidr_input = Arc::new(StringVector::from_slice(&cidr_values)) as VectorRef;
|
||||
let arg0 = ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(&ip_values)));
|
||||
let arg1 = ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(&cidr_values)));
|
||||
|
||||
let result = ipv4_func.eval(&ctx, &[ip_input, cidr_input]).unwrap();
|
||||
let result = result.as_any().downcast_ref::<BooleanVector>().unwrap();
|
||||
let args = ScalarFunctionArgs {
|
||||
args: vec![arg0, arg1],
|
||||
arg_fields: vec![],
|
||||
number_rows: 3,
|
||||
return_field: Arc::new(Field::new("x", DataType::Boolean, false)),
|
||||
config_options: Arc::new(Default::default()),
|
||||
};
|
||||
let result = ipv4_func.invoke_with_args(args).unwrap();
|
||||
let result = result.to_array(3).unwrap();
|
||||
let result = result.as_boolean();
|
||||
|
||||
assert!(result.get_data(0).unwrap()); // 8.8.8.8 is in 0.0.0.0/0 (matches everything)
|
||||
assert!(result.get_data(1).unwrap()); // 192.168.1.1 is in 192.168.1.1/32 (exact match)
|
||||
assert!(!result.get_data(2).unwrap()); // 192.168.1.1 is not in 192.168.1.0/32 (no match)
|
||||
assert!(result.value(0)); // 8.8.8.8 is in 0.0.0.0/0 (matches everything)
|
||||
assert!(result.value(1)); // 192.168.1.1 is in 192.168.1.1/32 (exact match)
|
||||
assert!(!result.value(2)); // 192.168.1.1 is not in 192.168.1.0/32 (no match)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,19 +15,14 @@
|
||||
use std::fmt;
|
||||
use std::fmt::Display;
|
||||
|
||||
use common_query::error;
|
||||
use common_query::error::{ArrowComputeSnafu, InvalidFuncArgsSnafu, Result};
|
||||
use datafusion::arrow::datatypes::DataType;
|
||||
use common_query::error::Result;
|
||||
use datafusion_common::arrow::compute;
|
||||
use datafusion_common::arrow::compute::kernels::numeric;
|
||||
use datafusion_common::arrow::datatypes::DataType;
|
||||
use datafusion_expr::type_coercion::aggregates::NUMERICS;
|
||||
use datafusion_expr::{Signature, Volatility};
|
||||
use datatypes::arrow::compute;
|
||||
use datatypes::arrow::compute::kernels::numeric;
|
||||
use datatypes::arrow::datatypes::DataType as ArrowDataType;
|
||||
use datatypes::prelude::ConcreteDataType;
|
||||
use datatypes::vectors::{Helper, VectorRef};
|
||||
use snafu::{ResultExt, ensure};
|
||||
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, Volatility};
|
||||
|
||||
use crate::function::{Function, FunctionContext};
|
||||
use crate::function::{Function, extract_args};
|
||||
|
||||
const NAME: &str = "mod";
|
||||
|
||||
@@ -60,38 +55,24 @@ impl Function for ModuloFunction {
|
||||
Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable)
|
||||
}
|
||||
|
||||
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
|
||||
ensure!(
|
||||
columns.len() == 2,
|
||||
InvalidFuncArgsSnafu {
|
||||
err_msg: format!(
|
||||
"The length of the args is not correct, expect exactly two, have: {}",
|
||||
columns.len()
|
||||
),
|
||||
}
|
||||
);
|
||||
let nums = &columns[0];
|
||||
let divs = &columns[1];
|
||||
let nums_arrow_array = &nums.to_arrow_array();
|
||||
let divs_arrow_array = &divs.to_arrow_array();
|
||||
let array = numeric::rem(nums_arrow_array, divs_arrow_array).context(ArrowComputeSnafu)?;
|
||||
fn invoke_with_args(
|
||||
&self,
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
let [nums, divs] = extract_args(self.name(), &args)?;
|
||||
let array = numeric::rem(&nums, &divs)?;
|
||||
|
||||
let result = match nums.data_type() {
|
||||
ConcreteDataType::Int8(_)
|
||||
| ConcreteDataType::Int16(_)
|
||||
| ConcreteDataType::Int32(_)
|
||||
| ConcreteDataType::Int64(_) => compute::cast(&array, &ArrowDataType::Int64),
|
||||
ConcreteDataType::UInt8(_)
|
||||
| ConcreteDataType::UInt16(_)
|
||||
| ConcreteDataType::UInt32(_)
|
||||
| ConcreteDataType::UInt64(_) => compute::cast(&array, &ArrowDataType::UInt64),
|
||||
ConcreteDataType::Float32(_) | ConcreteDataType::Float64(_) => {
|
||||
compute::cast(&array, &ArrowDataType::Float64)
|
||||
DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => {
|
||||
compute::cast(&array, &DataType::Int64)
|
||||
}
|
||||
DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => {
|
||||
compute::cast(&array, &DataType::UInt64)
|
||||
}
|
||||
DataType::Float32 | DataType::Float64 => compute::cast(&array, &DataType::Float64),
|
||||
_ => unreachable!("unexpected datatype: {:?}", nums.data_type()),
|
||||
}
|
||||
.context(ArrowComputeSnafu)?;
|
||||
Helper::try_into_vector(&result).context(error::FromArrowArraySnafu)
|
||||
}?;
|
||||
Ok(ColumnarValue::Array(result))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -99,9 +80,11 @@ impl Function for ModuloFunction {
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use common_error::ext::ErrorExt;
|
||||
use datatypes::value::Value;
|
||||
use datatypes::vectors::{Float64Vector, Int32Vector, StringVector, UInt32Vector};
|
||||
use arrow_schema::Field;
|
||||
use datafusion_common::arrow::array::{
|
||||
AsArray, Float64Array, Int32Array, StringViewArray, UInt32Array,
|
||||
};
|
||||
use datafusion_common::arrow::datatypes::{Float64Type, Int64Type, UInt64Type};
|
||||
|
||||
use super::*;
|
||||
#[test]
|
||||
@@ -120,15 +103,23 @@ mod tests {
|
||||
let nums = vec![18, -17, 5, -6];
|
||||
let divs = vec![4, 8, -5, -5];
|
||||
|
||||
let args: Vec<VectorRef> = vec![
|
||||
Arc::new(Int32Vector::from_vec(nums.clone())),
|
||||
Arc::new(Int32Vector::from_vec(divs.clone())),
|
||||
];
|
||||
let result = function.eval(&FunctionContext::default(), &args).unwrap();
|
||||
let args = ScalarFunctionArgs {
|
||||
args: vec![
|
||||
ColumnarValue::Array(Arc::new(Int32Array::from(nums.clone()))),
|
||||
ColumnarValue::Array(Arc::new(Int32Array::from(divs.clone()))),
|
||||
],
|
||||
arg_fields: vec![],
|
||||
number_rows: 4,
|
||||
return_field: Arc::new(Field::new("x", DataType::Int64, false)),
|
||||
config_options: Arc::new(Default::default()),
|
||||
};
|
||||
let result = function.invoke_with_args(args).unwrap();
|
||||
let result = result.to_array(4).unwrap();
|
||||
let result = result.as_primitive::<Int64Type>();
|
||||
assert_eq!(result.len(), 4);
|
||||
for i in 0..4 {
|
||||
let p: i64 = (nums[i] % divs[i]) as i64;
|
||||
assert!(matches!(result.get(i), Value::Int64(v) if v == p));
|
||||
assert_eq!(result.value(i), p);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -148,15 +139,23 @@ mod tests {
|
||||
let nums: Vec<u32> = vec![18, 17, 5, 6];
|
||||
let divs: Vec<u32> = vec![4, 8, 5, 5];
|
||||
|
||||
let args: Vec<VectorRef> = vec![
|
||||
Arc::new(UInt32Vector::from_vec(nums.clone())),
|
||||
Arc::new(UInt32Vector::from_vec(divs.clone())),
|
||||
];
|
||||
let result = function.eval(&FunctionContext::default(), &args).unwrap();
|
||||
let args = ScalarFunctionArgs {
|
||||
args: vec![
|
||||
ColumnarValue::Array(Arc::new(UInt32Array::from(nums.clone()))),
|
||||
ColumnarValue::Array(Arc::new(UInt32Array::from(divs.clone()))),
|
||||
],
|
||||
arg_fields: vec![],
|
||||
number_rows: 4,
|
||||
return_field: Arc::new(Field::new("x", DataType::UInt64, false)),
|
||||
config_options: Arc::new(Default::default()),
|
||||
};
|
||||
let result = function.invoke_with_args(args).unwrap();
|
||||
let result = result.to_array(4).unwrap();
|
||||
let result = result.as_primitive::<UInt64Type>();
|
||||
assert_eq!(result.len(), 4);
|
||||
for i in 0..4 {
|
||||
let p: u64 = (nums[i] % divs[i]) as u64;
|
||||
assert!(matches!(result.get(i), Value::UInt64(v) if v == p));
|
||||
assert_eq!(result.value(i), p);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -176,15 +175,23 @@ mod tests {
|
||||
let nums = vec![18.0, 17.0, 5.0, 6.0];
|
||||
let divs = vec![4.0, 8.0, 5.0, 5.0];
|
||||
|
||||
let args: Vec<VectorRef> = vec![
|
||||
Arc::new(Float64Vector::from_vec(nums.clone())),
|
||||
Arc::new(Float64Vector::from_vec(divs.clone())),
|
||||
];
|
||||
let result = function.eval(&FunctionContext::default(), &args).unwrap();
|
||||
let args = ScalarFunctionArgs {
|
||||
args: vec![
|
||||
ColumnarValue::Array(Arc::new(Float64Array::from(nums.clone()))),
|
||||
ColumnarValue::Array(Arc::new(Float64Array::from(divs.clone()))),
|
||||
],
|
||||
arg_fields: vec![],
|
||||
number_rows: 4,
|
||||
return_field: Arc::new(Field::new("x", DataType::Float64, false)),
|
||||
config_options: Arc::new(Default::default()),
|
||||
};
|
||||
let result = function.invoke_with_args(args).unwrap();
|
||||
let result = result.to_array(4).unwrap();
|
||||
let result = result.as_primitive::<Float64Type>();
|
||||
assert_eq!(result.len(), 4);
|
||||
for i in 0..4 {
|
||||
let p: f64 = nums[i] % divs[i];
|
||||
assert!(matches!(result.get(i), Value::Float64(v) if v == p));
|
||||
assert_eq!(result.value(i), p);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -195,37 +202,53 @@ mod tests {
|
||||
let nums = vec![27];
|
||||
let divs = vec![0];
|
||||
|
||||
let args: Vec<VectorRef> = vec![
|
||||
Arc::new(Int32Vector::from_vec(nums.clone())),
|
||||
Arc::new(Int32Vector::from_vec(divs.clone())),
|
||||
];
|
||||
let result = function.eval(&FunctionContext::default(), &args);
|
||||
let args = ScalarFunctionArgs {
|
||||
args: vec![
|
||||
ColumnarValue::Array(Arc::new(Int32Array::from(nums))),
|
||||
ColumnarValue::Array(Arc::new(Int32Array::from(divs))),
|
||||
],
|
||||
arg_fields: vec![],
|
||||
number_rows: 1,
|
||||
return_field: Arc::new(Field::new("x", DataType::Int64, false)),
|
||||
config_options: Arc::new(Default::default()),
|
||||
};
|
||||
let result = function.invoke_with_args(args);
|
||||
assert!(result.is_err());
|
||||
let err_msg = result.unwrap_err().output_msg();
|
||||
assert_eq!(
|
||||
err_msg,
|
||||
"Failed to perform compute operation on arrow arrays: Divide by zero error"
|
||||
);
|
||||
let err_msg = result.unwrap_err().to_string();
|
||||
assert_eq!(err_msg, "Arrow error: Divide by zero error");
|
||||
|
||||
let nums = vec![27];
|
||||
|
||||
let args: Vec<VectorRef> = vec![Arc::new(Int32Vector::from_vec(nums.clone()))];
|
||||
let result = function.eval(&FunctionContext::default(), &args);
|
||||
let args = ScalarFunctionArgs {
|
||||
args: vec![ColumnarValue::Array(Arc::new(Int32Array::from(nums)))],
|
||||
arg_fields: vec![],
|
||||
number_rows: 1,
|
||||
return_field: Arc::new(Field::new("x", DataType::Int64, false)),
|
||||
config_options: Arc::new(Default::default()),
|
||||
};
|
||||
let result = function.invoke_with_args(args);
|
||||
assert!(result.is_err());
|
||||
let err_msg = result.unwrap_err().output_msg();
|
||||
assert!(
|
||||
err_msg.contains("The length of the args is not correct, expect exactly two, have: 1")
|
||||
let err_msg = result.unwrap_err().to_string();
|
||||
assert_eq!(
|
||||
err_msg,
|
||||
"Execution error: mod function requires 2 arguments, got 1"
|
||||
);
|
||||
|
||||
let nums = vec!["27"];
|
||||
let divs = vec!["4"];
|
||||
let args: Vec<VectorRef> = vec![
|
||||
Arc::new(StringVector::from(nums.clone())),
|
||||
Arc::new(StringVector::from(divs.clone())),
|
||||
];
|
||||
let result = function.eval(&FunctionContext::default(), &args);
|
||||
let args = ScalarFunctionArgs {
|
||||
args: vec![
|
||||
ColumnarValue::Array(Arc::new(StringViewArray::from(nums))),
|
||||
ColumnarValue::Array(Arc::new(StringViewArray::from(divs))),
|
||||
],
|
||||
arg_fields: vec![],
|
||||
number_rows: 1,
|
||||
return_field: Arc::new(Field::new("x", DataType::Int64, false)),
|
||||
config_options: Arc::new(Default::default()),
|
||||
};
|
||||
let result = function.invoke_with_args(args);
|
||||
assert!(result.is_err());
|
||||
let err_msg = result.unwrap_err().output_msg();
|
||||
let err_msg = result.unwrap_err().to_string();
|
||||
assert!(err_msg.contains("Invalid arithmetic operation"));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,14 +13,14 @@
|
||||
// limitations under the License.
|
||||
|
||||
use std::fmt;
|
||||
use std::sync::Arc;
|
||||
|
||||
use common_query::error::Result;
|
||||
use datafusion_expr::{Signature, Volatility};
|
||||
use datafusion::logical_expr::ColumnarValue;
|
||||
use datafusion_expr::{ScalarFunctionArgs, Signature, Volatility};
|
||||
use datatypes::arrow::datatypes::DataType;
|
||||
use datatypes::prelude::VectorRef;
|
||||
use datatypes::vectors::{Helper, Vector};
|
||||
|
||||
use crate::function::{Function, FunctionContext};
|
||||
use crate::function::{Function, extract_args};
|
||||
use crate::scalars::expression::{EvalContext, scalar_binary_op};
|
||||
|
||||
#[derive(Clone, Default)]
|
||||
@@ -42,14 +42,19 @@ impl Function for TestAndFunction {
|
||||
)
|
||||
}
|
||||
|
||||
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
|
||||
fn invoke_with_args(
|
||||
&self,
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
let [arg0, arg1] = extract_args(self.name(), &args)?;
|
||||
let columns = Helper::try_into_vectors(&[arg0, arg1]).unwrap();
|
||||
let col = scalar_binary_op::<bool, bool, bool, _>(
|
||||
&columns[0],
|
||||
&columns[1],
|
||||
scalar_and,
|
||||
&mut EvalContext::default(),
|
||||
)?;
|
||||
Ok(Arc::new(col))
|
||||
Ok(ColumnarValue::Array(col.to_arrow_array()))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -107,12 +107,11 @@ mod tests {
|
||||
|
||||
use common_query::prelude::ScalarValue;
|
||||
use datafusion::arrow::array::BooleanArray;
|
||||
use datafusion_common::arrow::array::AsArray;
|
||||
use datafusion_common::arrow::datatypes::DataType as ArrowDataType;
|
||||
use datafusion_common::config::ConfigOptions;
|
||||
use datatypes::arrow::datatypes::Field;
|
||||
use datatypes::data_type::{ConcreteDataType, DataType};
|
||||
use datatypes::prelude::VectorRef;
|
||||
use datatypes::value::Value;
|
||||
use datatypes::vectors::{BooleanVector, ConstantVector};
|
||||
use session::context::QueryContextBuilder;
|
||||
|
||||
use super::*;
|
||||
@@ -124,20 +123,31 @@ mod tests {
|
||||
let f = Arc::new(TestAndFunction);
|
||||
let query_ctx = QueryContextBuilder::default().build().into();
|
||||
|
||||
let args: Vec<VectorRef> = vec![
|
||||
Arc::new(ConstantVector::new(
|
||||
Arc::new(BooleanVector::from(vec![true])),
|
||||
3,
|
||||
)),
|
||||
Arc::new(BooleanVector::from(vec![true, false, true])),
|
||||
];
|
||||
let args = ScalarFunctionArgs {
|
||||
args: vec![
|
||||
datafusion_expr::ColumnarValue::Array(Arc::new(BooleanArray::from(vec![
|
||||
true, true, true,
|
||||
]))),
|
||||
datafusion_expr::ColumnarValue::Array(Arc::new(BooleanArray::from(vec![
|
||||
true, false, true,
|
||||
]))),
|
||||
],
|
||||
arg_fields: vec![],
|
||||
number_rows: 3,
|
||||
return_field: Arc::new(Field::new("x", ArrowDataType::Boolean, true)),
|
||||
config_options: Arc::new(Default::default()),
|
||||
};
|
||||
|
||||
let vector = f.eval(&FunctionContext::default(), &args).unwrap();
|
||||
let result = f
|
||||
.invoke_with_args(args)
|
||||
.and_then(|x| x.to_array(3))
|
||||
.unwrap();
|
||||
let vector = result.as_boolean();
|
||||
assert_eq!(3, vector.len());
|
||||
|
||||
for i in 0..3 {
|
||||
assert!(matches!(vector.get(i), Value::Boolean(b) if b == (i == 0 || i == 2)));
|
||||
}
|
||||
assert!(vector.value(0));
|
||||
assert!(!vector.value(1));
|
||||
assert!(vector.value(2));
|
||||
|
||||
// create a udf and test it again
|
||||
let udf = create_udf(f.clone(), query_ctx, Arc::new(FunctionState::default()));
|
||||
|
||||
@@ -13,15 +13,13 @@
|
||||
// limitations under the License.
|
||||
|
||||
use std::fmt::{self};
|
||||
use std::sync::Arc;
|
||||
|
||||
use common_query::error::Result;
|
||||
use datafusion::arrow::datatypes::DataType;
|
||||
use datafusion_expr::{Signature, Volatility};
|
||||
use datatypes::prelude::ScalarVector;
|
||||
use datatypes::vectors::{StringVector, VectorRef};
|
||||
use datafusion_common::ScalarValue;
|
||||
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, Volatility};
|
||||
|
||||
use crate::function::{Function, FunctionContext};
|
||||
use crate::function::{Function, find_function_context};
|
||||
|
||||
/// A function to return current session timezone.
|
||||
#[derive(Clone, Debug, Default)]
|
||||
@@ -35,17 +33,21 @@ impl Function for TimezoneFunction {
|
||||
}
|
||||
|
||||
fn return_type(&self, _: &[DataType]) -> Result<DataType> {
|
||||
Ok(DataType::Utf8)
|
||||
Ok(DataType::Utf8View)
|
||||
}
|
||||
|
||||
fn signature(&self) -> Signature {
|
||||
Signature::nullary(Volatility::Immutable)
|
||||
}
|
||||
|
||||
fn eval(&self, func_ctx: &FunctionContext, _columns: &[VectorRef]) -> Result<VectorRef> {
|
||||
fn invoke_with_args(
|
||||
&self,
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
let func_ctx = find_function_context(&args)?;
|
||||
let tz = func_ctx.query_ctx.timezone().to_string();
|
||||
|
||||
Ok(Arc::new(StringVector::from_slice(&[&tz])) as _)
|
||||
Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(Some(tz))))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -59,14 +61,18 @@ impl fmt::Display for TimezoneFunction {
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use arrow_schema::Field;
|
||||
use datafusion_common::config::ConfigOptions;
|
||||
use session::context::QueryContextBuilder;
|
||||
|
||||
use super::*;
|
||||
use crate::function::FunctionContext;
|
||||
|
||||
#[test]
|
||||
fn test_build_function() {
|
||||
let build = TimezoneFunction;
|
||||
assert_eq!("timezone", build.name());
|
||||
assert_eq!(DataType::Utf8, build.return_type(&[]).unwrap());
|
||||
assert_eq!(DataType::Utf8View, build.return_type(&[]).unwrap());
|
||||
assert_eq!(build.signature(), Signature::nullary(Volatility::Immutable));
|
||||
|
||||
let query_ctx = QueryContextBuilder::default().build().into();
|
||||
@@ -75,8 +81,21 @@ mod tests {
|
||||
query_ctx,
|
||||
..Default::default()
|
||||
};
|
||||
let vector = build.eval(&func_ctx, &[]).unwrap();
|
||||
let expect: VectorRef = Arc::new(StringVector::from(vec!["UTC"]));
|
||||
assert_eq!(expect, vector);
|
||||
let mut config_options = ConfigOptions::default();
|
||||
config_options.extensions.insert(func_ctx);
|
||||
let config_options = Arc::new(config_options);
|
||||
|
||||
let args = ScalarFunctionArgs {
|
||||
args: vec![],
|
||||
arg_fields: vec![],
|
||||
number_rows: 0,
|
||||
return_field: Arc::new(Field::new("x", DataType::Utf8View, false)),
|
||||
config_options,
|
||||
};
|
||||
let result = build.invoke_with_args(args).unwrap();
|
||||
let ColumnarValue::Scalar(ScalarValue::Utf8View(Some(s))) = result else {
|
||||
unreachable!()
|
||||
};
|
||||
assert_eq!(s, "UTC");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,15 +13,14 @@
|
||||
// limitations under the License.
|
||||
|
||||
use std::fmt;
|
||||
use std::sync::Arc;
|
||||
|
||||
use common_query::error::Result;
|
||||
use datafusion::arrow::datatypes::DataType;
|
||||
use datafusion_expr::{Signature, Volatility};
|
||||
use datatypes::vectors::{StringVector, VectorRef};
|
||||
use datafusion_common::ScalarValue;
|
||||
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, Volatility};
|
||||
use session::context::Channel;
|
||||
|
||||
use crate::function::{Function, FunctionContext};
|
||||
use crate::function::{Function, find_function_context};
|
||||
|
||||
#[derive(Clone, Debug, Default)]
|
||||
pub(crate) struct VersionFunction;
|
||||
@@ -38,14 +37,18 @@ impl Function for VersionFunction {
|
||||
}
|
||||
|
||||
fn return_type(&self, _: &[DataType]) -> Result<DataType> {
|
||||
Ok(DataType::Utf8)
|
||||
Ok(DataType::Utf8View)
|
||||
}
|
||||
|
||||
fn signature(&self) -> Signature {
|
||||
Signature::nullary(Volatility::Immutable)
|
||||
}
|
||||
|
||||
fn eval(&self, func_ctx: &FunctionContext, _columns: &[VectorRef]) -> Result<VectorRef> {
|
||||
fn invoke_with_args(
|
||||
&self,
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
let func_ctx = find_function_context(&args)?;
|
||||
let version = match func_ctx.query_ctx.channel() {
|
||||
Channel::Mysql => {
|
||||
format!(
|
||||
@@ -60,7 +63,6 @@ impl Function for VersionFunction {
|
||||
}
|
||||
_ => common_version::version().to_string(),
|
||||
};
|
||||
let result = StringVector::from(vec![version]);
|
||||
Ok(Arc::new(result))
|
||||
Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(Some(version))))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,6 +17,9 @@ use std::slice;
|
||||
use std::sync::Arc;
|
||||
|
||||
use datafusion::arrow::util::pretty::pretty_format_batches;
|
||||
use datafusion_common::arrow::array::ArrayRef;
|
||||
use datafusion_common::arrow::compute;
|
||||
use datafusion_common::arrow::datatypes::{DataType as ArrowDataType, SchemaRef as ArrowSchemaRef};
|
||||
use datatypes::arrow::array::RecordBatchOptions;
|
||||
use datatypes::prelude::DataType;
|
||||
use datatypes::schema::SchemaRef;
|
||||
@@ -28,8 +31,8 @@ use snafu::{OptionExt, ResultExt, ensure};
|
||||
|
||||
use crate::DfRecordBatch;
|
||||
use crate::error::{
|
||||
self, CastVectorSnafu, ColumnNotExistsSnafu, DataTypesSnafu, ProjectArrowRecordBatchSnafu,
|
||||
Result,
|
||||
self, ArrowComputeSnafu, CastVectorSnafu, ColumnNotExistsSnafu, DataTypesSnafu,
|
||||
ProjectArrowRecordBatchSnafu, Result,
|
||||
};
|
||||
|
||||
/// A two-dimensional batch of column-oriented data with a defined schema.
|
||||
@@ -49,6 +52,15 @@ impl RecordBatch {
|
||||
let columns: Vec<_> = columns.into_iter().collect();
|
||||
let arrow_arrays = columns.iter().map(|v| v.to_arrow_array()).collect();
|
||||
|
||||
// Casting the arrays here to match the schema, is a temporary solution to support Arrow's
|
||||
// view array types (`StringViewArray` and `BinaryViewArray`).
|
||||
// As to "support": the arrays here are created from vectors, which do not have types
|
||||
// corresponding to view arrays. What we can do is to only cast them.
|
||||
// As to "temporary": we are planning to use Arrow's RecordBatch directly in the read path.
|
||||
// The casting here will be removed in the end.
|
||||
// TODO(LFC): Remove the casting here once `Batch` is no longer used.
|
||||
let arrow_arrays = Self::cast_view_arrays(schema.arrow_schema(), arrow_arrays)?;
|
||||
|
||||
let df_record_batch = DfRecordBatch::try_new(schema.arrow_schema().clone(), arrow_arrays)
|
||||
.context(error::NewDfRecordBatchSnafu)?;
|
||||
|
||||
@@ -59,6 +71,24 @@ impl RecordBatch {
|
||||
})
|
||||
}
|
||||
|
||||
fn cast_view_arrays(
|
||||
schema: &ArrowSchemaRef,
|
||||
mut arrays: Vec<ArrayRef>,
|
||||
) -> Result<Vec<ArrayRef>> {
|
||||
for (f, a) in schema.fields().iter().zip(arrays.iter_mut()) {
|
||||
let expected = f.data_type();
|
||||
let actual = a.data_type();
|
||||
if matches!(
|
||||
(expected, actual),
|
||||
(ArrowDataType::Utf8View, ArrowDataType::Utf8)
|
||||
| (ArrowDataType::BinaryView, ArrowDataType::Binary)
|
||||
) {
|
||||
*a = compute::cast(a, expected).context(ArrowComputeSnafu)?;
|
||||
}
|
||||
}
|
||||
Ok(arrays)
|
||||
}
|
||||
|
||||
/// Create an empty [`RecordBatch`] from `schema`.
|
||||
pub fn new_empty(schema: SchemaRef) -> RecordBatch {
|
||||
let df_record_batch = DfRecordBatch::new_empty(schema.arrow_schema().clone());
|
||||
|
||||
@@ -16,14 +16,12 @@
|
||||
use std::collections::BTreeMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use common_error::ext::BoxedError;
|
||||
use common_function::function::{FunctionContext, FunctionRef};
|
||||
use common_function::function::FunctionRef;
|
||||
use datafusion::arrow::datatypes::{DataType, TimeUnit};
|
||||
use datafusion_expr::{Signature, Volatility};
|
||||
use datafusion_substrait::extensions::Extensions;
|
||||
use query::QueryEngine;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use snafu::ResultExt;
|
||||
/// note here we are using the `substrait_proto_df` crate from the `substrait` module and
|
||||
/// rename it to `substrait_proto`
|
||||
use substrait::substrait_proto_df as substrait_proto;
|
||||
@@ -31,7 +29,7 @@ use substrait_proto::proto::extensions::SimpleExtensionDeclaration;
|
||||
use substrait_proto::proto::extensions::simple_extension_declaration::MappingType;
|
||||
|
||||
use crate::adapter::FlownodeContext;
|
||||
use crate::error::{Error, NotImplementedSnafu, UnexpectedSnafu};
|
||||
use crate::error::{Error, NotImplementedSnafu};
|
||||
use crate::expr::{TUMBLE_END, TUMBLE_START};
|
||||
/// a simple macro to generate a not implemented error
|
||||
macro_rules! not_impl_err {
|
||||
@@ -149,19 +147,6 @@ impl common_function::function::Function for TumbleFunction {
|
||||
fn signature(&self) -> Signature {
|
||||
Signature::variadic_any(Volatility::Immutable)
|
||||
}
|
||||
|
||||
fn eval(
|
||||
&self,
|
||||
_func_ctx: &FunctionContext,
|
||||
_columns: &[datatypes::prelude::VectorRef],
|
||||
) -> common_query::error::Result<datatypes::prelude::VectorRef> {
|
||||
UnexpectedSnafu {
|
||||
reason: "Tumbler function is not implemented for datafusion executor",
|
||||
}
|
||||
.fail()
|
||||
.map_err(BoxedError::new)
|
||||
.context(common_query::error::ExecuteSnafu)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -114,6 +114,8 @@ impl AnalyzerRule for DistPlannerAnalyzer {
|
||||
// seems to lose track of it: the `ConfigOptions` is recreated with its default values again.
|
||||
// So we create a custom `OptimizerConfig` with the desired `ConfigOptions`
|
||||
// to work around the issue.
|
||||
// TODO(LFC): Maybe use DataFusion's `OptimizerContext` again
|
||||
// once https://github.com/apache/datafusion/pull/17742 is merged.
|
||||
struct OptimizerContext {
|
||||
inner: datafusion_optimizer::OptimizerContext,
|
||||
config: Arc<ConfigOptions>,
|
||||
|
||||
Reference in New Issue
Block a user