feat: add more MySQL-compatible string functions (#7454)

* feat: add more mysql string functions

Signed-off-by: Dennis Zhuang <killme2008@gmail.com>

* refactor: use datafusion aliasing mechanism, close #7415

Signed-off-by: Dennis Zhuang <killme2008@gmail.com>

* chore: comment

Signed-off-by: Dennis Zhuang <killme2008@gmail.com>

* fix: comment and style

Signed-off-by: Dennis Zhuang <killme2008@gmail.com>

---------

Signed-off-by: Dennis Zhuang <killme2008@gmail.com>
This commit is contained in:
dennis zhuang
2025-12-25 11:28:57 +08:00
committed by GitHub
parent d4870ee2af
commit 3866512cf6
12 changed files with 2461 additions and 120 deletions

View File

@@ -14,13 +14,31 @@
//! String scalar functions
mod elt;
mod field;
mod format;
mod insert;
mod locate;
mod regexp_extract;
mod space;
pub(crate) use elt::EltFunction;
pub(crate) use field::FieldFunction;
pub(crate) use format::FormatFunction;
pub(crate) use insert::InsertFunction;
pub(crate) use locate::LocateFunction;
pub(crate) use regexp_extract::RegexpExtractFunction;
pub(crate) use space::SpaceFunction;
use crate::function_registry::FunctionRegistry;
/// Register all string functions
pub fn register_string_functions(registry: &FunctionRegistry) {
EltFunction::register(registry);
FieldFunction::register(registry);
FormatFunction::register(registry);
InsertFunction::register(registry);
LocateFunction::register(registry);
RegexpExtractFunction::register(registry);
SpaceFunction::register(registry);
}

View File

@@ -0,0 +1,252 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! MySQL-compatible ELT function implementation.
//!
//! ELT(N, str1, str2, str3, ...) - Returns the Nth string from the list.
//! Returns NULL if N < 1 or N > number of strings.
use std::fmt;
use std::sync::Arc;
use datafusion_common::DataFusionError;
use datafusion_common::arrow::array::{Array, ArrayRef, AsArray, LargeStringBuilder};
use datafusion_common::arrow::compute::cast;
use datafusion_common::arrow::datatypes::DataType;
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, Volatility};
use crate::function::Function;
use crate::function_registry::FunctionRegistry;
const NAME: &str = "elt";
/// MySQL-compatible ELT function.
///
/// Syntax: ELT(N, str1, str2, str3, ...)
/// Returns the Nth string argument. N is 1-based.
/// Returns NULL if N is NULL, N < 1, or N > number of string arguments.
#[derive(Debug)]
pub struct EltFunction {
signature: Signature,
}
impl EltFunction {
pub fn register(registry: &FunctionRegistry) {
registry.register_scalar(EltFunction::default());
}
}
impl Default for EltFunction {
fn default() -> Self {
Self {
// ELT takes a variable number of arguments: (Int64, String, String, ...)
signature: Signature::variadic_any(Volatility::Immutable),
}
}
}
impl fmt::Display for EltFunction {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", NAME.to_ascii_uppercase())
}
}
impl Function for EltFunction {
fn name(&self) -> &str {
NAME
}
fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
Ok(DataType::LargeUtf8)
}
fn signature(&self) -> &Signature {
&self.signature
}
fn invoke_with_args(
&self,
args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
if args.args.len() < 2 {
return Err(DataFusionError::Execution(
"ELT requires at least 2 arguments: ELT(N, str1, ...)".to_string(),
));
}
let arrays = ColumnarValue::values_to_arrays(&args.args)?;
let len = arrays[0].len();
let num_strings = arrays.len() - 1;
// First argument is the index (N) - try to cast to Int64
let index_array = if arrays[0].data_type() == &DataType::Null {
// All NULLs - return all NULLs
let mut builder = LargeStringBuilder::with_capacity(len, 0);
for _ in 0..len {
builder.append_null();
}
return Ok(ColumnarValue::Array(Arc::new(builder.finish())));
} else {
cast(arrays[0].as_ref(), &DataType::Int64).map_err(|e| {
DataFusionError::Execution(format!("ELT: index argument cast failed: {}", e))
})?
};
// Cast string arguments to LargeUtf8
let string_arrays: Vec<ArrayRef> = arrays[1..]
.iter()
.enumerate()
.map(|(i, arr)| {
cast(arr.as_ref(), &DataType::LargeUtf8).map_err(|e| {
DataFusionError::Execution(format!(
"ELT: string argument {} cast failed: {}",
i + 1,
e
))
})
})
.collect::<datafusion_common::Result<Vec<_>>>()?;
let mut builder = LargeStringBuilder::with_capacity(len, len * 32);
for i in 0..len {
if index_array.is_null(i) {
builder.append_null();
continue;
}
let n = index_array
.as_primitive::<datafusion_common::arrow::datatypes::Int64Type>()
.value(i);
// N is 1-based, check bounds
if n < 1 || n as usize > num_strings {
builder.append_null();
continue;
}
let str_idx = (n - 1) as usize;
let str_array = string_arrays[str_idx].as_string::<i64>();
if str_array.is_null(i) {
builder.append_null();
} else {
builder.append_value(str_array.value(i));
}
}
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
}
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use datafusion_common::arrow::array::{Int64Array, StringArray};
use datafusion_common::arrow::datatypes::Field;
use datafusion_expr::ScalarFunctionArgs;
use super::*;
fn create_args(arrays: Vec<ArrayRef>) -> ScalarFunctionArgs {
let arg_fields: Vec<_> = arrays
.iter()
.enumerate()
.map(|(i, arr)| {
Arc::new(Field::new(
format!("arg_{}", i),
arr.data_type().clone(),
true,
))
})
.collect();
ScalarFunctionArgs {
args: arrays.iter().cloned().map(ColumnarValue::Array).collect(),
arg_fields,
return_field: Arc::new(Field::new("result", DataType::LargeUtf8, true)),
number_rows: arrays[0].len(),
config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
}
}
#[test]
fn test_elt_basic() {
let function = EltFunction::default();
let n = Arc::new(Int64Array::from(vec![1, 2, 3]));
let s1 = Arc::new(StringArray::from(vec!["a", "a", "a"]));
let s2 = Arc::new(StringArray::from(vec!["b", "b", "b"]));
let s3 = Arc::new(StringArray::from(vec!["c", "c", "c"]));
let args = create_args(vec![n, s1, s2, s3]);
let result = function.invoke_with_args(args).unwrap();
if let ColumnarValue::Array(array) = result {
let str_array = array.as_string::<i64>();
assert_eq!(str_array.value(0), "a");
assert_eq!(str_array.value(1), "b");
assert_eq!(str_array.value(2), "c");
} else {
panic!("Expected array result");
}
}
#[test]
fn test_elt_out_of_bounds() {
let function = EltFunction::default();
let n = Arc::new(Int64Array::from(vec![0, 4, -1]));
let s1 = Arc::new(StringArray::from(vec!["a", "a", "a"]));
let s2 = Arc::new(StringArray::from(vec!["b", "b", "b"]));
let s3 = Arc::new(StringArray::from(vec!["c", "c", "c"]));
let args = create_args(vec![n, s1, s2, s3]);
let result = function.invoke_with_args(args).unwrap();
if let ColumnarValue::Array(array) = result {
let str_array = array.as_string::<i64>();
assert!(str_array.is_null(0)); // 0 is out of bounds
assert!(str_array.is_null(1)); // 4 is out of bounds
assert!(str_array.is_null(2)); // -1 is out of bounds
} else {
panic!("Expected array result");
}
}
#[test]
fn test_elt_with_nulls() {
let function = EltFunction::default();
// Row 0: n=1, select s1="a" -> "a"
// Row 1: n=NULL -> NULL
// Row 2: n=1, select s1=NULL -> NULL
let n = Arc::new(Int64Array::from(vec![Some(1), None, Some(1)]));
let s1 = Arc::new(StringArray::from(vec![Some("a"), Some("a"), None]));
let s2 = Arc::new(StringArray::from(vec![Some("b"), Some("b"), Some("b")]));
let args = create_args(vec![n, s1, s2]);
let result = function.invoke_with_args(args).unwrap();
if let ColumnarValue::Array(array) = result {
let str_array = array.as_string::<i64>();
assert_eq!(str_array.value(0), "a");
assert!(str_array.is_null(1)); // N is NULL
assert!(str_array.is_null(2)); // Selected string is NULL
} else {
panic!("Expected array result");
}
}
}

View File

@@ -0,0 +1,224 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! MySQL-compatible FIELD function implementation.
//!
//! FIELD(str, str1, str2, str3, ...) - Returns the 1-based index of str in the list.
//! Returns 0 if str is not found or is NULL.
use std::fmt;
use std::sync::Arc;
use datafusion_common::DataFusionError;
use datafusion_common::arrow::array::{Array, ArrayRef, AsArray, Int64Builder};
use datafusion_common::arrow::compute::cast;
use datafusion_common::arrow::datatypes::DataType;
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, Volatility};
use crate::function::Function;
use crate::function_registry::FunctionRegistry;
const NAME: &str = "field";
/// MySQL-compatible FIELD function.
///
/// Syntax: FIELD(str, str1, str2, str3, ...)
/// Returns the 1-based index of str in the argument list (str1, str2, str3, ...).
/// Returns 0 if str is not found or is NULL.
#[derive(Debug)]
pub struct FieldFunction {
signature: Signature,
}
impl FieldFunction {
pub fn register(registry: &FunctionRegistry) {
registry.register_scalar(FieldFunction::default());
}
}
impl Default for FieldFunction {
fn default() -> Self {
Self {
// FIELD takes a variable number of arguments: (String, String, String, ...)
signature: Signature::variadic_any(Volatility::Immutable),
}
}
}
impl fmt::Display for FieldFunction {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", NAME.to_ascii_uppercase())
}
}
impl Function for FieldFunction {
fn name(&self) -> &str {
NAME
}
fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
Ok(DataType::Int64)
}
fn signature(&self) -> &Signature {
&self.signature
}
fn invoke_with_args(
&self,
args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
if args.args.len() < 2 {
return Err(DataFusionError::Execution(
"FIELD requires at least 2 arguments: FIELD(str, str1, ...)".to_string(),
));
}
let arrays = ColumnarValue::values_to_arrays(&args.args)?;
let len = arrays[0].len();
// Cast all arguments to LargeUtf8
let string_arrays: Vec<ArrayRef> = arrays
.iter()
.enumerate()
.map(|(i, arr)| {
cast(arr.as_ref(), &DataType::LargeUtf8).map_err(|e| {
DataFusionError::Execution(format!("FIELD: argument {} cast failed: {}", i, e))
})
})
.collect::<datafusion_common::Result<Vec<_>>>()?;
let search_str = string_arrays[0].as_string::<i64>();
let mut builder = Int64Builder::with_capacity(len);
for i in 0..len {
// If search string is NULL, return 0
if search_str.is_null(i) {
builder.append_value(0);
continue;
}
let needle = search_str.value(i);
let mut found_idx = 0i64;
// Search through the list (starting from index 1 in string_arrays)
for (j, str_arr) in string_arrays[1..].iter().enumerate() {
let str_array = str_arr.as_string::<i64>();
if !str_array.is_null(i) && str_array.value(i) == needle {
found_idx = (j + 1) as i64; // 1-based index
break;
}
}
builder.append_value(found_idx);
}
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
}
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use datafusion_common::arrow::array::StringArray;
use datafusion_common::arrow::datatypes::Field;
use datafusion_expr::ScalarFunctionArgs;
use super::*;
fn create_args(arrays: Vec<ArrayRef>) -> ScalarFunctionArgs {
let arg_fields: Vec<_> = arrays
.iter()
.enumerate()
.map(|(i, arr)| {
Arc::new(Field::new(
format!("arg_{}", i),
arr.data_type().clone(),
true,
))
})
.collect();
ScalarFunctionArgs {
args: arrays.iter().cloned().map(ColumnarValue::Array).collect(),
arg_fields,
return_field: Arc::new(Field::new("result", DataType::Int64, true)),
number_rows: arrays[0].len(),
config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
}
}
#[test]
fn test_field_basic() {
let function = FieldFunction::default();
let search = Arc::new(StringArray::from(vec!["b", "d", "a"]));
let s1 = Arc::new(StringArray::from(vec!["a", "a", "a"]));
let s2 = Arc::new(StringArray::from(vec!["b", "b", "b"]));
let s3 = Arc::new(StringArray::from(vec!["c", "c", "c"]));
let args = create_args(vec![search, s1, s2, s3]);
let result = function.invoke_with_args(args).unwrap();
if let ColumnarValue::Array(array) = result {
let int_array = array.as_primitive::<datafusion_common::arrow::datatypes::Int64Type>();
assert_eq!(int_array.value(0), 2); // "b" is at index 2
assert_eq!(int_array.value(1), 0); // "d" not found
assert_eq!(int_array.value(2), 1); // "a" is at index 1
} else {
panic!("Expected array result");
}
}
#[test]
fn test_field_with_null_search() {
let function = FieldFunction::default();
let search = Arc::new(StringArray::from(vec![Some("a"), None]));
let s1 = Arc::new(StringArray::from(vec!["a", "a"]));
let s2 = Arc::new(StringArray::from(vec!["b", "b"]));
let args = create_args(vec![search, s1, s2]);
let result = function.invoke_with_args(args).unwrap();
if let ColumnarValue::Array(array) = result {
let int_array = array.as_primitive::<datafusion_common::arrow::datatypes::Int64Type>();
assert_eq!(int_array.value(0), 1); // "a" found at index 1
assert_eq!(int_array.value(1), 0); // NULL search returns 0
} else {
panic!("Expected array result");
}
}
#[test]
fn test_field_case_sensitive() {
let function = FieldFunction::default();
let search = Arc::new(StringArray::from(vec!["A", "a"]));
let s1 = Arc::new(StringArray::from(vec!["a", "a"]));
let s2 = Arc::new(StringArray::from(vec!["A", "A"]));
let args = create_args(vec![search, s1, s2]);
let result = function.invoke_with_args(args).unwrap();
if let ColumnarValue::Array(array) = result {
let int_array = array.as_primitive::<datafusion_common::arrow::datatypes::Int64Type>();
assert_eq!(int_array.value(0), 2); // "A" matches at index 2
assert_eq!(int_array.value(1), 1); // "a" matches at index 1
} else {
panic!("Expected array result");
}
}
}

View File

@@ -0,0 +1,512 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! MySQL-compatible FORMAT function implementation.
//!
//! FORMAT(X, D) - Formats the number X with D decimal places using thousand separators.
use std::fmt;
use std::sync::Arc;
use datafusion_common::DataFusionError;
use datafusion_common::arrow::array::{Array, AsArray, LargeStringBuilder};
use datafusion_common::arrow::datatypes as arrow_types;
use datafusion_common::arrow::datatypes::DataType;
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, TypeSignature, Volatility};
use crate::function::Function;
use crate::function_registry::FunctionRegistry;
const NAME: &str = "format";
/// MySQL-compatible FORMAT function.
///
/// Syntax: FORMAT(X, D)
/// Formats the number X to a format like '#,###,###.##', rounded to D decimal places.
/// D can be 0 to 30.
///
/// Note: This implementation uses the en_US locale (comma as thousand separator,
/// period as decimal separator).
#[derive(Debug)]
pub struct FormatFunction {
signature: Signature,
}
impl FormatFunction {
pub fn register(registry: &FunctionRegistry) {
registry.register_scalar(FormatFunction::default());
}
}
impl Default for FormatFunction {
fn default() -> Self {
let mut signatures = Vec::new();
// Support various numeric types for X
let numeric_types = [
DataType::Float64,
DataType::Float32,
DataType::Int64,
DataType::Int32,
DataType::Int16,
DataType::Int8,
DataType::UInt64,
DataType::UInt32,
DataType::UInt16,
DataType::UInt8,
];
// D can be various integer types
let int_types = [
DataType::Int64,
DataType::Int32,
DataType::Int16,
DataType::Int8,
DataType::UInt64,
DataType::UInt32,
DataType::UInt16,
DataType::UInt8,
];
for x_type in &numeric_types {
for d_type in &int_types {
signatures.push(TypeSignature::Exact(vec![x_type.clone(), d_type.clone()]));
}
}
Self {
signature: Signature::one_of(signatures, Volatility::Immutable),
}
}
}
impl fmt::Display for FormatFunction {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", NAME.to_ascii_uppercase())
}
}
impl Function for FormatFunction {
fn name(&self) -> &str {
NAME
}
fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
Ok(DataType::LargeUtf8)
}
fn signature(&self) -> &Signature {
&self.signature
}
fn invoke_with_args(
&self,
args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
if args.args.len() != 2 {
return Err(DataFusionError::Execution(
"FORMAT requires exactly 2 arguments: FORMAT(X, D)".to_string(),
));
}
let arrays = ColumnarValue::values_to_arrays(&args.args)?;
let len = arrays[0].len();
let x_array = &arrays[0];
let d_array = &arrays[1];
let mut builder = LargeStringBuilder::with_capacity(len, len * 20);
for i in 0..len {
if x_array.is_null(i) || d_array.is_null(i) {
builder.append_null();
continue;
}
let decimal_places = get_decimal_places(d_array, i)?.clamp(0, 30) as usize;
let formatted = match x_array.data_type() {
DataType::Float64 | DataType::Float32 => {
format_number_float(get_float_value(x_array, i)?, decimal_places)
}
DataType::Int64
| DataType::Int32
| DataType::Int16
| DataType::Int8
| DataType::UInt64
| DataType::UInt32
| DataType::UInt16
| DataType::UInt8 => format_number_integer(x_array, i, decimal_places)?,
_ => {
return Err(DataFusionError::Execution(format!(
"FORMAT: unsupported type {:?}",
x_array.data_type()
)));
}
};
builder.append_value(&formatted);
}
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
}
}
/// Get float value from various numeric types.
fn get_float_value(
array: &datafusion_common::arrow::array::ArrayRef,
index: usize,
) -> datafusion_common::Result<f64> {
match array.data_type() {
DataType::Float64 => Ok(array
.as_primitive::<arrow_types::Float64Type>()
.value(index)),
DataType::Float32 => Ok(array
.as_primitive::<arrow_types::Float32Type>()
.value(index) as f64),
_ => Err(DataFusionError::Execution(format!(
"FORMAT: unsupported type {:?}",
array.data_type()
))),
}
}
/// Get decimal places from various integer types.
///
/// MySQL clamps decimal places to `0..=30`. This function returns an `i64` so the caller can clamp.
fn get_decimal_places(
array: &datafusion_common::arrow::array::ArrayRef,
index: usize,
) -> datafusion_common::Result<i64> {
match array.data_type() {
DataType::Int64 => Ok(array.as_primitive::<arrow_types::Int64Type>().value(index)),
DataType::Int32 => Ok(array.as_primitive::<arrow_types::Int32Type>().value(index) as i64),
DataType::Int16 => Ok(array.as_primitive::<arrow_types::Int16Type>().value(index) as i64),
DataType::Int8 => Ok(array.as_primitive::<arrow_types::Int8Type>().value(index) as i64),
DataType::UInt64 => {
let v = array.as_primitive::<arrow_types::UInt64Type>().value(index);
Ok(if v > i64::MAX as u64 {
i64::MAX
} else {
v as i64
})
}
DataType::UInt32 => Ok(array.as_primitive::<arrow_types::UInt32Type>().value(index) as i64),
DataType::UInt16 => Ok(array.as_primitive::<arrow_types::UInt16Type>().value(index) as i64),
DataType::UInt8 => Ok(array.as_primitive::<arrow_types::UInt8Type>().value(index) as i64),
_ => Err(DataFusionError::Execution(format!(
"FORMAT: unsupported type {:?}",
array.data_type()
))),
}
}
fn format_number_integer(
array: &datafusion_common::arrow::array::ArrayRef,
index: usize,
decimal_places: usize,
) -> datafusion_common::Result<String> {
let (is_negative, abs_digits) = match array.data_type() {
DataType::Int64 => {
let v = array.as_primitive::<arrow_types::Int64Type>().value(index) as i128;
(v.is_negative(), v.unsigned_abs().to_string())
}
DataType::Int32 => {
let v = array.as_primitive::<arrow_types::Int32Type>().value(index) as i128;
(v.is_negative(), v.unsigned_abs().to_string())
}
DataType::Int16 => {
let v = array.as_primitive::<arrow_types::Int16Type>().value(index) as i128;
(v.is_negative(), v.unsigned_abs().to_string())
}
DataType::Int8 => {
let v = array.as_primitive::<arrow_types::Int8Type>().value(index) as i128;
(v.is_negative(), v.unsigned_abs().to_string())
}
DataType::UInt64 => {
let v = array.as_primitive::<arrow_types::UInt64Type>().value(index) as u128;
(false, v.to_string())
}
DataType::UInt32 => {
let v = array.as_primitive::<arrow_types::UInt32Type>().value(index) as u128;
(false, v.to_string())
}
DataType::UInt16 => {
let v = array.as_primitive::<arrow_types::UInt16Type>().value(index) as u128;
(false, v.to_string())
}
DataType::UInt8 => {
let v = array.as_primitive::<arrow_types::UInt8Type>().value(index) as u128;
(false, v.to_string())
}
_ => {
return Err(DataFusionError::Execution(format!(
"FORMAT: unsupported type {:?}",
array.data_type()
)));
}
};
let mut result = String::new();
if is_negative {
result.push('-');
}
result.push_str(&add_thousand_separators(&abs_digits));
if decimal_places > 0 {
result.push('.');
result.push_str(&"0".repeat(decimal_places));
}
Ok(result)
}
/// Format a float with thousand separators and `decimal_places` digits after decimal point.
fn format_number_float(x: f64, decimal_places: usize) -> String {
// Handle special cases
if x.is_nan() {
return "NaN".to_string();
}
if x.is_infinite() {
return if x.is_sign_positive() {
"Infinity".to_string()
} else {
"-Infinity".to_string()
};
}
// Round to decimal_places
let multiplier = 10f64.powi(decimal_places as i32);
let rounded = (x * multiplier).round() / multiplier;
// Split into integer and fractional parts
let is_negative = rounded < 0.0;
let abs_value = rounded.abs();
// Format with the specified decimal places
let formatted = if decimal_places == 0 {
format!("{:.0}", abs_value)
} else {
format!("{:.prec$}", abs_value, prec = decimal_places)
};
// Split at decimal point
let parts: Vec<&str> = formatted.split('.').collect();
let int_part = parts[0];
let dec_part = parts.get(1).copied();
// Add thousand separators to integer part
let int_with_sep = add_thousand_separators(int_part);
// Build result
let mut result = String::new();
if is_negative {
result.push('-');
}
result.push_str(&int_with_sep);
if let Some(dec) = dec_part {
result.push('.');
result.push_str(dec);
}
result
}
/// Add thousand separators (commas) to an integer string.
fn add_thousand_separators(s: &str) -> String {
let chars: Vec<char> = s.chars().collect();
let len = chars.len();
if len <= 3 {
return s.to_string();
}
let mut result = String::with_capacity(len + len / 3);
let first_group_len = len % 3;
let first_group_len = if first_group_len == 0 {
3
} else {
first_group_len
};
for (i, ch) in chars.iter().enumerate() {
if i > 0 && i >= first_group_len && (i - first_group_len) % 3 == 0 {
result.push(',');
}
result.push(*ch);
}
result
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use datafusion_common::arrow::array::{Float64Array, Int64Array};
use datafusion_common::arrow::datatypes::Field;
use datafusion_expr::ScalarFunctionArgs;
use super::*;
fn create_args(arrays: Vec<datafusion_common::arrow::array::ArrayRef>) -> ScalarFunctionArgs {
let arg_fields: Vec<_> = arrays
.iter()
.enumerate()
.map(|(i, arr)| {
Arc::new(Field::new(
format!("arg_{}", i),
arr.data_type().clone(),
true,
))
})
.collect();
ScalarFunctionArgs {
args: arrays.iter().cloned().map(ColumnarValue::Array).collect(),
arg_fields,
return_field: Arc::new(Field::new("result", DataType::LargeUtf8, true)),
number_rows: arrays[0].len(),
config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
}
}
#[test]
fn test_format_basic() {
let function = FormatFunction::default();
let x = Arc::new(Float64Array::from(vec![1234567.891, 1234.5, 1234567.0]));
let d = Arc::new(Int64Array::from(vec![2, 0, 3]));
let args = create_args(vec![x, d]);
let result = function.invoke_with_args(args).unwrap();
if let ColumnarValue::Array(array) = result {
let str_array = array.as_string::<i64>();
assert_eq!(str_array.value(0), "1,234,567.89");
assert_eq!(str_array.value(1), "1,235"); // rounded
assert_eq!(str_array.value(2), "1,234,567.000");
} else {
panic!("Expected array result");
}
}
#[test]
fn test_format_negative() {
let function = FormatFunction::default();
let x = Arc::new(Float64Array::from(vec![-1234567.891]));
let d = Arc::new(Int64Array::from(vec![2]));
let args = create_args(vec![x, d]);
let result = function.invoke_with_args(args).unwrap();
if let ColumnarValue::Array(array) = result {
let str_array = array.as_string::<i64>();
assert_eq!(str_array.value(0), "-1,234,567.89");
} else {
panic!("Expected array result");
}
}
#[test]
fn test_format_small_numbers() {
let function = FormatFunction::default();
let x = Arc::new(Float64Array::from(vec![0.5, 12.345, 123.0]));
let d = Arc::new(Int64Array::from(vec![2, 2, 0]));
let args = create_args(vec![x, d]);
let result = function.invoke_with_args(args).unwrap();
if let ColumnarValue::Array(array) = result {
let str_array = array.as_string::<i64>();
assert_eq!(str_array.value(0), "0.50");
assert_eq!(str_array.value(1), "12.35"); // rounded
assert_eq!(str_array.value(2), "123");
} else {
panic!("Expected array result");
}
}
#[test]
fn test_format_with_nulls() {
let function = FormatFunction::default();
let x = Arc::new(Float64Array::from(vec![Some(1234.5), None]));
let d = Arc::new(Int64Array::from(vec![2, 2]));
let args = create_args(vec![x, d]);
let result = function.invoke_with_args(args).unwrap();
if let ColumnarValue::Array(array) = result {
let str_array = array.as_string::<i64>();
assert_eq!(str_array.value(0), "1,234.50");
assert!(str_array.is_null(1));
} else {
panic!("Expected array result");
}
}
#[test]
fn test_add_thousand_separators() {
assert_eq!(add_thousand_separators("1"), "1");
assert_eq!(add_thousand_separators("12"), "12");
assert_eq!(add_thousand_separators("123"), "123");
assert_eq!(add_thousand_separators("1234"), "1,234");
assert_eq!(add_thousand_separators("12345"), "12,345");
assert_eq!(add_thousand_separators("123456"), "123,456");
assert_eq!(add_thousand_separators("1234567"), "1,234,567");
assert_eq!(add_thousand_separators("12345678"), "12,345,678");
assert_eq!(add_thousand_separators("123456789"), "123,456,789");
}
#[test]
fn test_format_large_int_no_float_precision_loss() {
let function = FormatFunction::default();
// 2^53 + 1 cannot be represented exactly as f64.
let x = Arc::new(Int64Array::from(vec![9_007_199_254_740_993i64]));
let d = Arc::new(Int64Array::from(vec![0]));
let args = create_args(vec![x, d]);
let result = function.invoke_with_args(args).unwrap();
if let ColumnarValue::Array(array) = result {
let str_array = array.as_string::<i64>();
assert_eq!(str_array.value(0), "9,007,199,254,740,993");
} else {
panic!("Expected array result");
}
}
#[test]
fn test_format_decimal_places_u64_overflow_clamps() {
use datafusion_common::arrow::array::UInt64Array;
let function = FormatFunction::default();
let x = Arc::new(Int64Array::from(vec![1]));
let d = Arc::new(UInt64Array::from(vec![u64::MAX]));
let args = create_args(vec![x, d]);
let result = function.invoke_with_args(args).unwrap();
if let ColumnarValue::Array(array) = result {
let str_array = array.as_string::<i64>();
assert_eq!(str_array.value(0), format!("1.{}", "0".repeat(30)));
} else {
panic!("Expected array result");
}
}
}

View File

@@ -0,0 +1,345 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! MySQL-compatible INSERT function implementation.
//!
//! INSERT(str, pos, len, newstr) - Inserts newstr into str at position pos,
//! replacing len characters.
use std::fmt;
use std::sync::Arc;
use datafusion_common::DataFusionError;
use datafusion_common::arrow::array::{Array, ArrayRef, AsArray, LargeStringBuilder};
use datafusion_common::arrow::compute::cast;
use datafusion_common::arrow::datatypes::DataType;
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, TypeSignature, Volatility};
use crate::function::Function;
use crate::function_registry::FunctionRegistry;
const NAME: &str = "insert";
/// MySQL-compatible INSERT function.
///
/// Syntax: INSERT(str, pos, len, newstr)
/// Returns str with the substring beginning at position pos and len characters long
/// replaced by newstr.
///
/// - pos is 1-based
/// - If pos is out of range, returns the original string
/// - If len is out of range, replaces from pos to end of string
#[derive(Debug)]
pub struct InsertFunction {
signature: Signature,
}
impl InsertFunction {
pub fn register(registry: &FunctionRegistry) {
registry.register_scalar(InsertFunction::default());
}
}
impl Default for InsertFunction {
fn default() -> Self {
let mut signatures = Vec::new();
let string_types = [DataType::Utf8, DataType::LargeUtf8, DataType::Utf8View];
let int_types = [
DataType::Int64,
DataType::Int32,
DataType::Int16,
DataType::Int8,
DataType::UInt64,
DataType::UInt32,
DataType::UInt16,
DataType::UInt8,
];
for str_type in &string_types {
for newstr_type in &string_types {
for pos_type in &int_types {
for len_type in &int_types {
signatures.push(TypeSignature::Exact(vec![
str_type.clone(),
pos_type.clone(),
len_type.clone(),
newstr_type.clone(),
]));
}
}
}
}
Self {
signature: Signature::one_of(signatures, Volatility::Immutable),
}
}
}
impl fmt::Display for InsertFunction {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", NAME.to_ascii_uppercase())
}
}
impl Function for InsertFunction {
fn name(&self) -> &str {
NAME
}
fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
Ok(DataType::LargeUtf8)
}
fn signature(&self) -> &Signature {
&self.signature
}
fn invoke_with_args(
&self,
args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
if args.args.len() != 4 {
return Err(DataFusionError::Execution(
"INSERT requires exactly 4 arguments: INSERT(str, pos, len, newstr)".to_string(),
));
}
let arrays = ColumnarValue::values_to_arrays(&args.args)?;
let len = arrays[0].len();
// Cast string arguments to LargeUtf8
let str_array = cast_to_large_utf8(&arrays[0], "str")?;
let newstr_array = cast_to_large_utf8(&arrays[3], "newstr")?;
let pos_array = cast_to_int64(&arrays[1], "pos")?;
let replace_len_array = cast_to_int64(&arrays[2], "len")?;
let str_arr = str_array.as_string::<i64>();
let pos_arr = pos_array.as_primitive::<datafusion_common::arrow::datatypes::Int64Type>();
let len_arr =
replace_len_array.as_primitive::<datafusion_common::arrow::datatypes::Int64Type>();
let newstr_arr = newstr_array.as_string::<i64>();
let mut builder = LargeStringBuilder::with_capacity(len, len * 32);
for i in 0..len {
// Check for NULLs
if str_arr.is_null(i)
|| pos_array.is_null(i)
|| replace_len_array.is_null(i)
|| newstr_arr.is_null(i)
{
builder.append_null();
continue;
}
let original = str_arr.value(i);
let pos = pos_arr.value(i);
let replace_len = len_arr.value(i);
let new_str = newstr_arr.value(i);
let result = insert_string(original, pos, replace_len, new_str);
builder.append_value(&result);
}
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
}
}
/// Cast array to LargeUtf8 for uniform string access.
fn cast_to_large_utf8(array: &ArrayRef, name: &str) -> datafusion_common::Result<ArrayRef> {
cast(array.as_ref(), &DataType::LargeUtf8)
.map_err(|e| DataFusionError::Execution(format!("INSERT: {} cast failed: {}", name, e)))
}
fn cast_to_int64(array: &ArrayRef, name: &str) -> datafusion_common::Result<ArrayRef> {
cast(array.as_ref(), &DataType::Int64)
.map_err(|e| DataFusionError::Execution(format!("INSERT: {} cast failed: {}", name, e)))
}
/// Perform the INSERT string operation.
/// pos is 1-based. If pos < 1 or pos > len(str) + 1, returns original string.
fn insert_string(original: &str, pos: i64, replace_len: i64, new_str: &str) -> String {
let char_count = original.chars().count();
// MySQL behavior: if pos < 1 or pos > string length + 1, return original
if pos < 1 || pos as usize > char_count + 1 {
return original.to_string();
}
let start_idx = (pos - 1) as usize; // Convert to 0-based
// Calculate end index for replacement
let replace_len = if replace_len < 0 {
0
} else {
replace_len as usize
};
let end_idx = (start_idx + replace_len).min(char_count);
let start_byte = char_to_byte_idx(original, start_idx);
let end_byte = char_to_byte_idx(original, end_idx);
let mut result = String::with_capacity(original.len() + new_str.len());
result.push_str(&original[..start_byte]);
result.push_str(new_str);
result.push_str(&original[end_byte..]);
result
}
fn char_to_byte_idx(s: &str, char_idx: usize) -> usize {
s.char_indices()
.nth(char_idx)
.map(|(idx, _)| idx)
.unwrap_or(s.len())
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use datafusion_common::arrow::array::{Int64Array, StringArray};
use datafusion_common::arrow::datatypes::Field;
use datafusion_expr::ScalarFunctionArgs;
use super::*;
fn create_args(arrays: Vec<ArrayRef>) -> ScalarFunctionArgs {
let arg_fields: Vec<_> = arrays
.iter()
.enumerate()
.map(|(i, arr)| {
Arc::new(Field::new(
format!("arg_{}", i),
arr.data_type().clone(),
true,
))
})
.collect();
ScalarFunctionArgs {
args: arrays.iter().cloned().map(ColumnarValue::Array).collect(),
arg_fields,
return_field: Arc::new(Field::new("result", DataType::LargeUtf8, true)),
number_rows: arrays[0].len(),
config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
}
}
#[test]
fn test_insert_basic() {
let function = InsertFunction::default();
// INSERT('Quadratic', 3, 4, 'What') => 'QuWhattic'
let str_arr = Arc::new(StringArray::from(vec!["Quadratic"]));
let pos = Arc::new(Int64Array::from(vec![3]));
let len = Arc::new(Int64Array::from(vec![4]));
let newstr = Arc::new(StringArray::from(vec!["What"]));
let args = create_args(vec![str_arr, pos, len, newstr]);
let result = function.invoke_with_args(args).unwrap();
if let ColumnarValue::Array(array) = result {
let str_array = array.as_string::<i64>();
assert_eq!(str_array.value(0), "QuWhattic");
} else {
panic!("Expected array result");
}
}
#[test]
fn test_insert_out_of_range_pos() {
let function = InsertFunction::default();
// INSERT('Quadratic', 0, 4, 'What') => 'Quadratic' (pos < 1)
let str_arr = Arc::new(StringArray::from(vec!["Quadratic", "Quadratic"]));
let pos = Arc::new(Int64Array::from(vec![0, 100]));
let len = Arc::new(Int64Array::from(vec![4, 4]));
let newstr = Arc::new(StringArray::from(vec!["What", "What"]));
let args = create_args(vec![str_arr, pos, len, newstr]);
let result = function.invoke_with_args(args).unwrap();
if let ColumnarValue::Array(array) = result {
let str_array = array.as_string::<i64>();
assert_eq!(str_array.value(0), "Quadratic"); // pos < 1
assert_eq!(str_array.value(1), "Quadratic"); // pos > length
} else {
panic!("Expected array result");
}
}
#[test]
fn test_insert_replace_to_end() {
let function = InsertFunction::default();
// INSERT('Quadratic', 3, 100, 'What') => 'QuWhat' (len exceeds remaining)
let str_arr = Arc::new(StringArray::from(vec!["Quadratic"]));
let pos = Arc::new(Int64Array::from(vec![3]));
let len = Arc::new(Int64Array::from(vec![100]));
let newstr = Arc::new(StringArray::from(vec!["What"]));
let args = create_args(vec![str_arr, pos, len, newstr]);
let result = function.invoke_with_args(args).unwrap();
if let ColumnarValue::Array(array) = result {
let str_array = array.as_string::<i64>();
assert_eq!(str_array.value(0), "QuWhat");
} else {
panic!("Expected array result");
}
}
#[test]
fn test_insert_unicode() {
let function = InsertFunction::default();
// INSERT('hello世界', 6, 1, 'の') => 'helloの界'
let str_arr = Arc::new(StringArray::from(vec!["hello世界"]));
let pos = Arc::new(Int64Array::from(vec![6]));
let len = Arc::new(Int64Array::from(vec![1]));
let newstr = Arc::new(StringArray::from(vec![""]));
let args = create_args(vec![str_arr, pos, len, newstr]);
let result = function.invoke_with_args(args).unwrap();
if let ColumnarValue::Array(array) = result {
let str_array = array.as_string::<i64>();
assert_eq!(str_array.value(0), "helloの界");
} else {
panic!("Expected array result");
}
}
#[test]
fn test_insert_with_nulls() {
let function = InsertFunction::default();
let str_arr = Arc::new(StringArray::from(vec![Some("hello"), None]));
let pos = Arc::new(Int64Array::from(vec![1, 1]));
let len = Arc::new(Int64Array::from(vec![1, 1]));
let newstr = Arc::new(StringArray::from(vec!["X", "X"]));
let args = create_args(vec![str_arr, pos, len, newstr]);
let result = function.invoke_with_args(args).unwrap();
if let ColumnarValue::Array(array) = result {
let str_array = array.as_string::<i64>();
assert_eq!(str_array.value(0), "Xello");
assert!(str_array.is_null(1));
} else {
panic!("Expected array result");
}
}
}

View File

@@ -0,0 +1,373 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! MySQL-compatible LOCATE function implementation.
//!
//! LOCATE(substr, str) - Returns the position of the first occurrence of substr in str (1-based).
//! LOCATE(substr, str, pos) - Returns the position of the first occurrence of substr in str,
//! starting from position pos.
//! Returns 0 if substr is not found.
use std::fmt;
use std::sync::Arc;
use datafusion_common::DataFusionError;
use datafusion_common::arrow::array::{Array, ArrayRef, AsArray, Int64Builder};
use datafusion_common::arrow::compute::cast;
use datafusion_common::arrow::datatypes::DataType;
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, TypeSignature, Volatility};
use crate::function::Function;
use crate::function_registry::FunctionRegistry;
const NAME: &str = "locate";
/// MySQL-compatible LOCATE function.
///
/// Syntax:
/// - LOCATE(substr, str) - Returns 1-based position of substr in str, or 0 if not found.
/// - LOCATE(substr, str, pos) - Same, but starts searching from position pos.
#[derive(Debug)]
pub struct LocateFunction {
signature: Signature,
}
impl LocateFunction {
pub fn register(registry: &FunctionRegistry) {
registry.register_scalar(LocateFunction::default());
}
}
impl Default for LocateFunction {
fn default() -> Self {
// Support 2 or 3 arguments with various string types
let mut signatures = Vec::new();
let string_types = [DataType::Utf8, DataType::LargeUtf8, DataType::Utf8View];
let int_types = [
DataType::Int64,
DataType::Int32,
DataType::Int16,
DataType::Int8,
DataType::UInt64,
DataType::UInt32,
DataType::UInt16,
DataType::UInt8,
];
// 2-argument form: LOCATE(substr, str)
for substr_type in &string_types {
for str_type in &string_types {
signatures.push(TypeSignature::Exact(vec![
substr_type.clone(),
str_type.clone(),
]));
}
}
// 3-argument form: LOCATE(substr, str, pos)
for substr_type in &string_types {
for str_type in &string_types {
for pos_type in &int_types {
signatures.push(TypeSignature::Exact(vec![
substr_type.clone(),
str_type.clone(),
pos_type.clone(),
]));
}
}
}
Self {
signature: Signature::one_of(signatures, Volatility::Immutable),
}
}
}
impl fmt::Display for LocateFunction {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", NAME.to_ascii_uppercase())
}
}
impl Function for LocateFunction {
fn name(&self) -> &str {
NAME
}
fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
Ok(DataType::Int64)
}
fn signature(&self) -> &Signature {
&self.signature
}
fn invoke_with_args(
&self,
args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
let arg_count = args.args.len();
if !(2..=3).contains(&arg_count) {
return Err(DataFusionError::Execution(
"LOCATE requires 2 or 3 arguments: LOCATE(substr, str) or LOCATE(substr, str, pos)"
.to_string(),
));
}
let arrays = ColumnarValue::values_to_arrays(&args.args)?;
// Cast string arguments to LargeUtf8 for uniform access
let substr_array = cast_to_large_utf8(&arrays[0], "substr")?;
let str_array = cast_to_large_utf8(&arrays[1], "str")?;
let substr = substr_array.as_string::<i64>();
let str_arr = str_array.as_string::<i64>();
let len = substr.len();
// Handle optional pos argument
let pos_array: Option<ArrayRef> = if arg_count == 3 {
Some(cast_to_int64(&arrays[2], "pos")?)
} else {
None
};
let mut builder = Int64Builder::with_capacity(len);
for i in 0..len {
if substr.is_null(i) || str_arr.is_null(i) {
builder.append_null();
continue;
}
let needle = substr.value(i);
let haystack = str_arr.value(i);
// Get starting position (1-based in MySQL, convert to 0-based)
let start_pos = if let Some(ref pos_arr) = pos_array {
if pos_arr.is_null(i) {
builder.append_null();
continue;
}
let pos = pos_arr
.as_primitive::<datafusion_common::arrow::datatypes::Int64Type>()
.value(i);
if pos < 1 {
// MySQL returns 0 for pos < 1
builder.append_value(0);
continue;
}
(pos - 1) as usize
} else {
0
};
// Find position using character-based indexing (for Unicode support)
let result = locate_substr(haystack, needle, start_pos);
builder.append_value(result);
}
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
}
}
/// Cast array to LargeUtf8 for uniform string access.
fn cast_to_large_utf8(array: &ArrayRef, name: &str) -> datafusion_common::Result<ArrayRef> {
cast(array.as_ref(), &DataType::LargeUtf8)
.map_err(|e| DataFusionError::Execution(format!("LOCATE: {} cast failed: {}", name, e)))
}
fn cast_to_int64(array: &ArrayRef, name: &str) -> datafusion_common::Result<ArrayRef> {
cast(array.as_ref(), &DataType::Int64)
.map_err(|e| DataFusionError::Execution(format!("LOCATE: {} cast failed: {}", name, e)))
}
/// Find the 1-based position of needle in haystack, starting from start_pos (0-based character index).
/// Returns 0 if not found.
fn locate_substr(haystack: &str, needle: &str, start_pos: usize) -> i64 {
// Handle empty needle - MySQL returns start_pos + 1
if needle.is_empty() {
let char_count = haystack.chars().count();
return if start_pos <= char_count {
(start_pos + 1) as i64
} else {
0
};
}
// Convert start_pos (character index) to byte index
let byte_start = haystack
.char_indices()
.nth(start_pos)
.map(|(idx, _)| idx)
.unwrap_or(haystack.len());
if byte_start >= haystack.len() {
return 0;
}
// Search in the substring
let search_str = &haystack[byte_start..];
if let Some(byte_pos) = search_str.find(needle) {
// Convert byte position back to character position
let char_pos = search_str[..byte_pos].chars().count();
// Return 1-based position relative to original string
(start_pos + char_pos + 1) as i64
} else {
0
}
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use datafusion_common::arrow::array::StringArray;
use datafusion_common::arrow::datatypes::Field;
use datafusion_expr::ScalarFunctionArgs;
use super::*;
fn create_args(arrays: Vec<ArrayRef>) -> ScalarFunctionArgs {
let arg_fields: Vec<_> = arrays
.iter()
.enumerate()
.map(|(i, arr)| {
Arc::new(Field::new(
format!("arg_{}", i),
arr.data_type().clone(),
true,
))
})
.collect();
ScalarFunctionArgs {
args: arrays.iter().cloned().map(ColumnarValue::Array).collect(),
arg_fields,
return_field: Arc::new(Field::new("result", DataType::Int64, true)),
number_rows: arrays[0].len(),
config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
}
}
#[test]
fn test_locate_basic() {
let function = LocateFunction::default();
let substr = Arc::new(StringArray::from(vec!["world", "xyz", "hello"]));
let str_arr = Arc::new(StringArray::from(vec![
"hello world",
"hello world",
"hello world",
]));
let args = create_args(vec![substr, str_arr]);
let result = function.invoke_with_args(args).unwrap();
if let ColumnarValue::Array(array) = result {
let int_array = array.as_primitive::<datafusion_common::arrow::datatypes::Int64Type>();
assert_eq!(int_array.value(0), 7); // "world" at position 7
assert_eq!(int_array.value(1), 0); // "xyz" not found
assert_eq!(int_array.value(2), 1); // "hello" at position 1
} else {
panic!("Expected array result");
}
}
#[test]
fn test_locate_with_position() {
let function = LocateFunction::default();
let substr = Arc::new(StringArray::from(vec!["o", "o", "o"]));
let str_arr = Arc::new(StringArray::from(vec![
"hello world",
"hello world",
"hello world",
]));
let pos = Arc::new(datafusion_common::arrow::array::Int64Array::from(vec![
1, 5, 8,
]));
let args = create_args(vec![substr, str_arr, pos]);
let result = function.invoke_with_args(args).unwrap();
if let ColumnarValue::Array(array) = result {
let int_array = array.as_primitive::<datafusion_common::arrow::datatypes::Int64Type>();
assert_eq!(int_array.value(0), 5); // first 'o' at position 5
assert_eq!(int_array.value(1), 5); // 'o' at position 5 (start from 5)
assert_eq!(int_array.value(2), 8); // 'o' in "world" at position 8
} else {
panic!("Expected array result");
}
}
#[test]
fn test_locate_unicode() {
let function = LocateFunction::default();
let substr = Arc::new(StringArray::from(vec!["", ""]));
let str_arr = Arc::new(StringArray::from(vec!["hello世界", "hello世界"]));
let args = create_args(vec![substr, str_arr]);
let result = function.invoke_with_args(args).unwrap();
if let ColumnarValue::Array(array) = result {
let int_array = array.as_primitive::<datafusion_common::arrow::datatypes::Int64Type>();
assert_eq!(int_array.value(0), 6); // "世" at position 6
assert_eq!(int_array.value(1), 7); // "界" at position 7
} else {
panic!("Expected array result");
}
}
#[test]
fn test_locate_empty_needle() {
let function = LocateFunction::default();
let substr = Arc::new(StringArray::from(vec!["", ""]));
let str_arr = Arc::new(StringArray::from(vec!["hello", "hello"]));
let pos = Arc::new(datafusion_common::arrow::array::Int64Array::from(vec![
1, 3,
]));
let args = create_args(vec![substr, str_arr, pos]);
let result = function.invoke_with_args(args).unwrap();
if let ColumnarValue::Array(array) = result {
let int_array = array.as_primitive::<datafusion_common::arrow::datatypes::Int64Type>();
assert_eq!(int_array.value(0), 1); // empty string at pos 1
assert_eq!(int_array.value(1), 3); // empty string at pos 3
} else {
panic!("Expected array result");
}
}
#[test]
fn test_locate_with_nulls() {
let function = LocateFunction::default();
let substr = Arc::new(StringArray::from(vec![Some("o"), None]));
let str_arr = Arc::new(StringArray::from(vec![Some("hello"), Some("hello")]));
let args = create_args(vec![substr, str_arr]);
let result = function.invoke_with_args(args).unwrap();
if let ColumnarValue::Array(array) = result {
let int_array = array.as_primitive::<datafusion_common::arrow::datatypes::Int64Type>();
assert_eq!(int_array.value(0), 5);
assert!(int_array.is_null(1));
} else {
panic!("Expected array result");
}
}
}

View File

@@ -0,0 +1,252 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! MySQL-compatible SPACE function implementation.
//!
//! SPACE(N) - Returns a string consisting of N space characters.
use std::fmt;
use std::sync::Arc;
use datafusion_common::DataFusionError;
use datafusion_common::arrow::array::{Array, AsArray, LargeStringBuilder};
use datafusion_common::arrow::datatypes::DataType;
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, TypeSignature, Volatility};
use crate::function::Function;
use crate::function_registry::FunctionRegistry;
const NAME: &str = "space";
// Safety limit for maximum number of spaces
const MAX_SPACE_COUNT: i64 = 1024 * 1024; // 1MB of spaces
/// MySQL-compatible SPACE function.
///
/// Syntax: SPACE(N)
/// Returns a string consisting of N space characters.
/// Returns NULL if N is NULL.
/// Returns empty string if N < 0.
#[derive(Debug)]
pub struct SpaceFunction {
signature: Signature,
}
impl SpaceFunction {
pub fn register(registry: &FunctionRegistry) {
registry.register_scalar(SpaceFunction::default());
}
}
impl Default for SpaceFunction {
fn default() -> Self {
Self {
signature: Signature::one_of(
vec![
TypeSignature::Exact(vec![DataType::Int64]),
TypeSignature::Exact(vec![DataType::Int32]),
TypeSignature::Exact(vec![DataType::Int16]),
TypeSignature::Exact(vec![DataType::Int8]),
TypeSignature::Exact(vec![DataType::UInt64]),
TypeSignature::Exact(vec![DataType::UInt32]),
TypeSignature::Exact(vec![DataType::UInt16]),
TypeSignature::Exact(vec![DataType::UInt8]),
],
Volatility::Immutable,
),
}
}
}
impl fmt::Display for SpaceFunction {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", NAME.to_ascii_uppercase())
}
}
impl Function for SpaceFunction {
fn name(&self) -> &str {
NAME
}
fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
Ok(DataType::LargeUtf8)
}
fn signature(&self) -> &Signature {
&self.signature
}
fn invoke_with_args(
&self,
args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
if args.args.len() != 1 {
return Err(DataFusionError::Execution(
"SPACE requires exactly 1 argument: SPACE(N)".to_string(),
));
}
let arrays = ColumnarValue::values_to_arrays(&args.args)?;
let len = arrays[0].len();
let n_array = &arrays[0];
let mut builder = LargeStringBuilder::with_capacity(len, len * 10);
for i in 0..len {
if n_array.is_null(i) {
builder.append_null();
continue;
}
let n = get_int_value(n_array, i)?;
if n < 0 {
// MySQL returns empty string for negative values
builder.append_value("");
} else if n > MAX_SPACE_COUNT {
return Err(DataFusionError::Execution(format!(
"SPACE: requested {} spaces exceeds maximum allowed ({})",
n, MAX_SPACE_COUNT
)));
} else {
let spaces = " ".repeat(n as usize);
builder.append_value(&spaces);
}
}
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
}
}
/// Extract integer value from various integer types.
fn get_int_value(
array: &datafusion_common::arrow::array::ArrayRef,
index: usize,
) -> datafusion_common::Result<i64> {
use datafusion_common::arrow::datatypes as arrow_types;
match array.data_type() {
DataType::Int64 => Ok(array.as_primitive::<arrow_types::Int64Type>().value(index)),
DataType::Int32 => Ok(array.as_primitive::<arrow_types::Int32Type>().value(index) as i64),
DataType::Int16 => Ok(array.as_primitive::<arrow_types::Int16Type>().value(index) as i64),
DataType::Int8 => Ok(array.as_primitive::<arrow_types::Int8Type>().value(index) as i64),
DataType::UInt64 => {
let v = array.as_primitive::<arrow_types::UInt64Type>().value(index);
if v > i64::MAX as u64 {
Err(DataFusionError::Execution(format!(
"SPACE: value {} exceeds maximum",
v
)))
} else {
Ok(v as i64)
}
}
DataType::UInt32 => Ok(array.as_primitive::<arrow_types::UInt32Type>().value(index) as i64),
DataType::UInt16 => Ok(array.as_primitive::<arrow_types::UInt16Type>().value(index) as i64),
DataType::UInt8 => Ok(array.as_primitive::<arrow_types::UInt8Type>().value(index) as i64),
_ => Err(DataFusionError::Execution(format!(
"SPACE: unsupported type {:?}",
array.data_type()
))),
}
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use datafusion_common::arrow::array::Int64Array;
use datafusion_common::arrow::datatypes::Field;
use datafusion_expr::ScalarFunctionArgs;
use super::*;
fn create_args(arrays: Vec<datafusion_common::arrow::array::ArrayRef>) -> ScalarFunctionArgs {
let arg_fields: Vec<_> = arrays
.iter()
.enumerate()
.map(|(i, arr)| {
Arc::new(Field::new(
format!("arg_{}", i),
arr.data_type().clone(),
true,
))
})
.collect();
ScalarFunctionArgs {
args: arrays.iter().cloned().map(ColumnarValue::Array).collect(),
arg_fields,
return_field: Arc::new(Field::new("result", DataType::LargeUtf8, true)),
number_rows: arrays[0].len(),
config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
}
}
#[test]
fn test_space_basic() {
let function = SpaceFunction::default();
let n = Arc::new(Int64Array::from(vec![0, 1, 5]));
let args = create_args(vec![n]);
let result = function.invoke_with_args(args).unwrap();
if let ColumnarValue::Array(array) = result {
let str_array = array.as_string::<i64>();
assert_eq!(str_array.value(0), "");
assert_eq!(str_array.value(1), " ");
assert_eq!(str_array.value(2), " ");
} else {
panic!("Expected array result");
}
}
#[test]
fn test_space_negative() {
let function = SpaceFunction::default();
let n = Arc::new(Int64Array::from(vec![-1, -100]));
let args = create_args(vec![n]);
let result = function.invoke_with_args(args).unwrap();
if let ColumnarValue::Array(array) = result {
let str_array = array.as_string::<i64>();
assert_eq!(str_array.value(0), "");
assert_eq!(str_array.value(1), "");
} else {
panic!("Expected array result");
}
}
#[test]
fn test_space_with_nulls() {
let function = SpaceFunction::default();
let n = Arc::new(Int64Array::from(vec![Some(3), None]));
let args = create_args(vec![n]);
let result = function.invoke_with_args(args).unwrap();
if let ColumnarValue::Array(array) = result {
let str_array = array.as_string::<i64>();
assert_eq!(str_array.value(0), " ");
assert!(str_array.is_null(1));
} else {
panic!("Expected array result");
}
}
}