mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2026-05-19 14:30:43 +00:00
feat: add more MySQL-compatible string functions (#7454)
* feat: add more mysql string functions Signed-off-by: Dennis Zhuang <killme2008@gmail.com> * refactor: use datafusion aliasing mechanism, close #7415 Signed-off-by: Dennis Zhuang <killme2008@gmail.com> * chore: comment Signed-off-by: Dennis Zhuang <killme2008@gmail.com> * fix: comment and style Signed-off-by: Dennis Zhuang <killme2008@gmail.com> --------- Signed-off-by: Dennis Zhuang <killme2008@gmail.com>
This commit is contained in:
@@ -14,13 +14,31 @@
|
||||
|
||||
//! String scalar functions
|
||||
|
||||
mod elt;
|
||||
mod field;
|
||||
mod format;
|
||||
mod insert;
|
||||
mod locate;
|
||||
mod regexp_extract;
|
||||
mod space;
|
||||
|
||||
pub(crate) use elt::EltFunction;
|
||||
pub(crate) use field::FieldFunction;
|
||||
pub(crate) use format::FormatFunction;
|
||||
pub(crate) use insert::InsertFunction;
|
||||
pub(crate) use locate::LocateFunction;
|
||||
pub(crate) use regexp_extract::RegexpExtractFunction;
|
||||
pub(crate) use space::SpaceFunction;
|
||||
|
||||
use crate::function_registry::FunctionRegistry;
|
||||
|
||||
/// Register all string functions
|
||||
pub fn register_string_functions(registry: &FunctionRegistry) {
|
||||
EltFunction::register(registry);
|
||||
FieldFunction::register(registry);
|
||||
FormatFunction::register(registry);
|
||||
InsertFunction::register(registry);
|
||||
LocateFunction::register(registry);
|
||||
RegexpExtractFunction::register(registry);
|
||||
SpaceFunction::register(registry);
|
||||
}
|
||||
|
||||
252
src/common/function/src/scalars/string/elt.rs
Normal file
252
src/common/function/src/scalars/string/elt.rs
Normal file
@@ -0,0 +1,252 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! MySQL-compatible ELT function implementation.
|
||||
//!
|
||||
//! ELT(N, str1, str2, str3, ...) - Returns the Nth string from the list.
|
||||
//! Returns NULL if N < 1 or N > number of strings.
|
||||
|
||||
use std::fmt;
|
||||
use std::sync::Arc;
|
||||
|
||||
use datafusion_common::DataFusionError;
|
||||
use datafusion_common::arrow::array::{Array, ArrayRef, AsArray, LargeStringBuilder};
|
||||
use datafusion_common::arrow::compute::cast;
|
||||
use datafusion_common::arrow::datatypes::DataType;
|
||||
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, Volatility};
|
||||
|
||||
use crate::function::Function;
|
||||
use crate::function_registry::FunctionRegistry;
|
||||
|
||||
const NAME: &str = "elt";
|
||||
|
||||
/// MySQL-compatible ELT function.
|
||||
///
|
||||
/// Syntax: ELT(N, str1, str2, str3, ...)
|
||||
/// Returns the Nth string argument. N is 1-based.
|
||||
/// Returns NULL if N is NULL, N < 1, or N > number of string arguments.
|
||||
#[derive(Debug)]
|
||||
pub struct EltFunction {
|
||||
signature: Signature,
|
||||
}
|
||||
|
||||
impl EltFunction {
|
||||
pub fn register(registry: &FunctionRegistry) {
|
||||
registry.register_scalar(EltFunction::default());
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for EltFunction {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
// ELT takes a variable number of arguments: (Int64, String, String, ...)
|
||||
signature: Signature::variadic_any(Volatility::Immutable),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for EltFunction {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(f, "{}", NAME.to_ascii_uppercase())
|
||||
}
|
||||
}
|
||||
|
||||
impl Function for EltFunction {
|
||||
fn name(&self) -> &str {
|
||||
NAME
|
||||
}
|
||||
|
||||
fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
|
||||
Ok(DataType::LargeUtf8)
|
||||
}
|
||||
|
||||
fn signature(&self) -> &Signature {
|
||||
&self.signature
|
||||
}
|
||||
|
||||
fn invoke_with_args(
|
||||
&self,
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
if args.args.len() < 2 {
|
||||
return Err(DataFusionError::Execution(
|
||||
"ELT requires at least 2 arguments: ELT(N, str1, ...)".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let arrays = ColumnarValue::values_to_arrays(&args.args)?;
|
||||
let len = arrays[0].len();
|
||||
let num_strings = arrays.len() - 1;
|
||||
|
||||
// First argument is the index (N) - try to cast to Int64
|
||||
let index_array = if arrays[0].data_type() == &DataType::Null {
|
||||
// All NULLs - return all NULLs
|
||||
let mut builder = LargeStringBuilder::with_capacity(len, 0);
|
||||
for _ in 0..len {
|
||||
builder.append_null();
|
||||
}
|
||||
return Ok(ColumnarValue::Array(Arc::new(builder.finish())));
|
||||
} else {
|
||||
cast(arrays[0].as_ref(), &DataType::Int64).map_err(|e| {
|
||||
DataFusionError::Execution(format!("ELT: index argument cast failed: {}", e))
|
||||
})?
|
||||
};
|
||||
|
||||
// Cast string arguments to LargeUtf8
|
||||
let string_arrays: Vec<ArrayRef> = arrays[1..]
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, arr)| {
|
||||
cast(arr.as_ref(), &DataType::LargeUtf8).map_err(|e| {
|
||||
DataFusionError::Execution(format!(
|
||||
"ELT: string argument {} cast failed: {}",
|
||||
i + 1,
|
||||
e
|
||||
))
|
||||
})
|
||||
})
|
||||
.collect::<datafusion_common::Result<Vec<_>>>()?;
|
||||
|
||||
let mut builder = LargeStringBuilder::with_capacity(len, len * 32);
|
||||
|
||||
for i in 0..len {
|
||||
if index_array.is_null(i) {
|
||||
builder.append_null();
|
||||
continue;
|
||||
}
|
||||
|
||||
let n = index_array
|
||||
.as_primitive::<datafusion_common::arrow::datatypes::Int64Type>()
|
||||
.value(i);
|
||||
|
||||
// N is 1-based, check bounds
|
||||
if n < 1 || n as usize > num_strings {
|
||||
builder.append_null();
|
||||
continue;
|
||||
}
|
||||
|
||||
let str_idx = (n - 1) as usize;
|
||||
let str_array = string_arrays[str_idx].as_string::<i64>();
|
||||
|
||||
if str_array.is_null(i) {
|
||||
builder.append_null();
|
||||
} else {
|
||||
builder.append_value(str_array.value(i));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use datafusion_common::arrow::array::{Int64Array, StringArray};
|
||||
use datafusion_common::arrow::datatypes::Field;
|
||||
use datafusion_expr::ScalarFunctionArgs;
|
||||
|
||||
use super::*;
|
||||
|
||||
fn create_args(arrays: Vec<ArrayRef>) -> ScalarFunctionArgs {
|
||||
let arg_fields: Vec<_> = arrays
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, arr)| {
|
||||
Arc::new(Field::new(
|
||||
format!("arg_{}", i),
|
||||
arr.data_type().clone(),
|
||||
true,
|
||||
))
|
||||
})
|
||||
.collect();
|
||||
|
||||
ScalarFunctionArgs {
|
||||
args: arrays.iter().cloned().map(ColumnarValue::Array).collect(),
|
||||
arg_fields,
|
||||
return_field: Arc::new(Field::new("result", DataType::LargeUtf8, true)),
|
||||
number_rows: arrays[0].len(),
|
||||
config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_elt_basic() {
|
||||
let function = EltFunction::default();
|
||||
|
||||
let n = Arc::new(Int64Array::from(vec![1, 2, 3]));
|
||||
let s1 = Arc::new(StringArray::from(vec!["a", "a", "a"]));
|
||||
let s2 = Arc::new(StringArray::from(vec!["b", "b", "b"]));
|
||||
let s3 = Arc::new(StringArray::from(vec!["c", "c", "c"]));
|
||||
|
||||
let args = create_args(vec![n, s1, s2, s3]);
|
||||
let result = function.invoke_with_args(args).unwrap();
|
||||
|
||||
if let ColumnarValue::Array(array) = result {
|
||||
let str_array = array.as_string::<i64>();
|
||||
assert_eq!(str_array.value(0), "a");
|
||||
assert_eq!(str_array.value(1), "b");
|
||||
assert_eq!(str_array.value(2), "c");
|
||||
} else {
|
||||
panic!("Expected array result");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_elt_out_of_bounds() {
|
||||
let function = EltFunction::default();
|
||||
|
||||
let n = Arc::new(Int64Array::from(vec![0, 4, -1]));
|
||||
let s1 = Arc::new(StringArray::from(vec!["a", "a", "a"]));
|
||||
let s2 = Arc::new(StringArray::from(vec!["b", "b", "b"]));
|
||||
let s3 = Arc::new(StringArray::from(vec!["c", "c", "c"]));
|
||||
|
||||
let args = create_args(vec![n, s1, s2, s3]);
|
||||
let result = function.invoke_with_args(args).unwrap();
|
||||
|
||||
if let ColumnarValue::Array(array) = result {
|
||||
let str_array = array.as_string::<i64>();
|
||||
assert!(str_array.is_null(0)); // 0 is out of bounds
|
||||
assert!(str_array.is_null(1)); // 4 is out of bounds
|
||||
assert!(str_array.is_null(2)); // -1 is out of bounds
|
||||
} else {
|
||||
panic!("Expected array result");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_elt_with_nulls() {
|
||||
let function = EltFunction::default();
|
||||
|
||||
// Row 0: n=1, select s1="a" -> "a"
|
||||
// Row 1: n=NULL -> NULL
|
||||
// Row 2: n=1, select s1=NULL -> NULL
|
||||
let n = Arc::new(Int64Array::from(vec![Some(1), None, Some(1)]));
|
||||
let s1 = Arc::new(StringArray::from(vec![Some("a"), Some("a"), None]));
|
||||
let s2 = Arc::new(StringArray::from(vec![Some("b"), Some("b"), Some("b")]));
|
||||
|
||||
let args = create_args(vec![n, s1, s2]);
|
||||
let result = function.invoke_with_args(args).unwrap();
|
||||
|
||||
if let ColumnarValue::Array(array) = result {
|
||||
let str_array = array.as_string::<i64>();
|
||||
assert_eq!(str_array.value(0), "a");
|
||||
assert!(str_array.is_null(1)); // N is NULL
|
||||
assert!(str_array.is_null(2)); // Selected string is NULL
|
||||
} else {
|
||||
panic!("Expected array result");
|
||||
}
|
||||
}
|
||||
}
|
||||
224
src/common/function/src/scalars/string/field.rs
Normal file
224
src/common/function/src/scalars/string/field.rs
Normal file
@@ -0,0 +1,224 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! MySQL-compatible FIELD function implementation.
|
||||
//!
|
||||
//! FIELD(str, str1, str2, str3, ...) - Returns the 1-based index of str in the list.
|
||||
//! Returns 0 if str is not found or is NULL.
|
||||
|
||||
use std::fmt;
|
||||
use std::sync::Arc;
|
||||
|
||||
use datafusion_common::DataFusionError;
|
||||
use datafusion_common::arrow::array::{Array, ArrayRef, AsArray, Int64Builder};
|
||||
use datafusion_common::arrow::compute::cast;
|
||||
use datafusion_common::arrow::datatypes::DataType;
|
||||
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, Volatility};
|
||||
|
||||
use crate::function::Function;
|
||||
use crate::function_registry::FunctionRegistry;
|
||||
|
||||
const NAME: &str = "field";
|
||||
|
||||
/// MySQL-compatible FIELD function.
|
||||
///
|
||||
/// Syntax: FIELD(str, str1, str2, str3, ...)
|
||||
/// Returns the 1-based index of str in the argument list (str1, str2, str3, ...).
|
||||
/// Returns 0 if str is not found or is NULL.
|
||||
#[derive(Debug)]
|
||||
pub struct FieldFunction {
|
||||
signature: Signature,
|
||||
}
|
||||
|
||||
impl FieldFunction {
|
||||
pub fn register(registry: &FunctionRegistry) {
|
||||
registry.register_scalar(FieldFunction::default());
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for FieldFunction {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
// FIELD takes a variable number of arguments: (String, String, String, ...)
|
||||
signature: Signature::variadic_any(Volatility::Immutable),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for FieldFunction {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(f, "{}", NAME.to_ascii_uppercase())
|
||||
}
|
||||
}
|
||||
|
||||
impl Function for FieldFunction {
|
||||
fn name(&self) -> &str {
|
||||
NAME
|
||||
}
|
||||
|
||||
fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
|
||||
Ok(DataType::Int64)
|
||||
}
|
||||
|
||||
fn signature(&self) -> &Signature {
|
||||
&self.signature
|
||||
}
|
||||
|
||||
fn invoke_with_args(
|
||||
&self,
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
if args.args.len() < 2 {
|
||||
return Err(DataFusionError::Execution(
|
||||
"FIELD requires at least 2 arguments: FIELD(str, str1, ...)".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let arrays = ColumnarValue::values_to_arrays(&args.args)?;
|
||||
let len = arrays[0].len();
|
||||
|
||||
// Cast all arguments to LargeUtf8
|
||||
let string_arrays: Vec<ArrayRef> = arrays
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, arr)| {
|
||||
cast(arr.as_ref(), &DataType::LargeUtf8).map_err(|e| {
|
||||
DataFusionError::Execution(format!("FIELD: argument {} cast failed: {}", i, e))
|
||||
})
|
||||
})
|
||||
.collect::<datafusion_common::Result<Vec<_>>>()?;
|
||||
|
||||
let search_str = string_arrays[0].as_string::<i64>();
|
||||
let mut builder = Int64Builder::with_capacity(len);
|
||||
|
||||
for i in 0..len {
|
||||
// If search string is NULL, return 0
|
||||
if search_str.is_null(i) {
|
||||
builder.append_value(0);
|
||||
continue;
|
||||
}
|
||||
|
||||
let needle = search_str.value(i);
|
||||
let mut found_idx = 0i64;
|
||||
|
||||
// Search through the list (starting from index 1 in string_arrays)
|
||||
for (j, str_arr) in string_arrays[1..].iter().enumerate() {
|
||||
let str_array = str_arr.as_string::<i64>();
|
||||
if !str_array.is_null(i) && str_array.value(i) == needle {
|
||||
found_idx = (j + 1) as i64; // 1-based index
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
builder.append_value(found_idx);
|
||||
}
|
||||
|
||||
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use datafusion_common::arrow::array::StringArray;
|
||||
use datafusion_common::arrow::datatypes::Field;
|
||||
use datafusion_expr::ScalarFunctionArgs;
|
||||
|
||||
use super::*;
|
||||
|
||||
fn create_args(arrays: Vec<ArrayRef>) -> ScalarFunctionArgs {
|
||||
let arg_fields: Vec<_> = arrays
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, arr)| {
|
||||
Arc::new(Field::new(
|
||||
format!("arg_{}", i),
|
||||
arr.data_type().clone(),
|
||||
true,
|
||||
))
|
||||
})
|
||||
.collect();
|
||||
|
||||
ScalarFunctionArgs {
|
||||
args: arrays.iter().cloned().map(ColumnarValue::Array).collect(),
|
||||
arg_fields,
|
||||
return_field: Arc::new(Field::new("result", DataType::Int64, true)),
|
||||
number_rows: arrays[0].len(),
|
||||
config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_field_basic() {
|
||||
let function = FieldFunction::default();
|
||||
|
||||
let search = Arc::new(StringArray::from(vec!["b", "d", "a"]));
|
||||
let s1 = Arc::new(StringArray::from(vec!["a", "a", "a"]));
|
||||
let s2 = Arc::new(StringArray::from(vec!["b", "b", "b"]));
|
||||
let s3 = Arc::new(StringArray::from(vec!["c", "c", "c"]));
|
||||
|
||||
let args = create_args(vec![search, s1, s2, s3]);
|
||||
let result = function.invoke_with_args(args).unwrap();
|
||||
|
||||
if let ColumnarValue::Array(array) = result {
|
||||
let int_array = array.as_primitive::<datafusion_common::arrow::datatypes::Int64Type>();
|
||||
assert_eq!(int_array.value(0), 2); // "b" is at index 2
|
||||
assert_eq!(int_array.value(1), 0); // "d" not found
|
||||
assert_eq!(int_array.value(2), 1); // "a" is at index 1
|
||||
} else {
|
||||
panic!("Expected array result");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_field_with_null_search() {
|
||||
let function = FieldFunction::default();
|
||||
|
||||
let search = Arc::new(StringArray::from(vec![Some("a"), None]));
|
||||
let s1 = Arc::new(StringArray::from(vec!["a", "a"]));
|
||||
let s2 = Arc::new(StringArray::from(vec!["b", "b"]));
|
||||
|
||||
let args = create_args(vec![search, s1, s2]);
|
||||
let result = function.invoke_with_args(args).unwrap();
|
||||
|
||||
if let ColumnarValue::Array(array) = result {
|
||||
let int_array = array.as_primitive::<datafusion_common::arrow::datatypes::Int64Type>();
|
||||
assert_eq!(int_array.value(0), 1); // "a" found at index 1
|
||||
assert_eq!(int_array.value(1), 0); // NULL search returns 0
|
||||
} else {
|
||||
panic!("Expected array result");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_field_case_sensitive() {
|
||||
let function = FieldFunction::default();
|
||||
|
||||
let search = Arc::new(StringArray::from(vec!["A", "a"]));
|
||||
let s1 = Arc::new(StringArray::from(vec!["a", "a"]));
|
||||
let s2 = Arc::new(StringArray::from(vec!["A", "A"]));
|
||||
|
||||
let args = create_args(vec![search, s1, s2]);
|
||||
let result = function.invoke_with_args(args).unwrap();
|
||||
|
||||
if let ColumnarValue::Array(array) = result {
|
||||
let int_array = array.as_primitive::<datafusion_common::arrow::datatypes::Int64Type>();
|
||||
assert_eq!(int_array.value(0), 2); // "A" matches at index 2
|
||||
assert_eq!(int_array.value(1), 1); // "a" matches at index 1
|
||||
} else {
|
||||
panic!("Expected array result");
|
||||
}
|
||||
}
|
||||
}
|
||||
512
src/common/function/src/scalars/string/format.rs
Normal file
512
src/common/function/src/scalars/string/format.rs
Normal file
@@ -0,0 +1,512 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! MySQL-compatible FORMAT function implementation.
|
||||
//!
|
||||
//! FORMAT(X, D) - Formats the number X with D decimal places using thousand separators.
|
||||
|
||||
use std::fmt;
|
||||
use std::sync::Arc;
|
||||
|
||||
use datafusion_common::DataFusionError;
|
||||
use datafusion_common::arrow::array::{Array, AsArray, LargeStringBuilder};
|
||||
use datafusion_common::arrow::datatypes as arrow_types;
|
||||
use datafusion_common::arrow::datatypes::DataType;
|
||||
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, TypeSignature, Volatility};
|
||||
|
||||
use crate::function::Function;
|
||||
use crate::function_registry::FunctionRegistry;
|
||||
|
||||
const NAME: &str = "format";
|
||||
|
||||
/// MySQL-compatible FORMAT function.
|
||||
///
|
||||
/// Syntax: FORMAT(X, D)
|
||||
/// Formats the number X to a format like '#,###,###.##', rounded to D decimal places.
|
||||
/// D can be 0 to 30.
|
||||
///
|
||||
/// Note: This implementation uses the en_US locale (comma as thousand separator,
|
||||
/// period as decimal separator).
|
||||
#[derive(Debug)]
|
||||
pub struct FormatFunction {
|
||||
signature: Signature,
|
||||
}
|
||||
|
||||
impl FormatFunction {
|
||||
pub fn register(registry: &FunctionRegistry) {
|
||||
registry.register_scalar(FormatFunction::default());
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for FormatFunction {
|
||||
fn default() -> Self {
|
||||
let mut signatures = Vec::new();
|
||||
|
||||
// Support various numeric types for X
|
||||
let numeric_types = [
|
||||
DataType::Float64,
|
||||
DataType::Float32,
|
||||
DataType::Int64,
|
||||
DataType::Int32,
|
||||
DataType::Int16,
|
||||
DataType::Int8,
|
||||
DataType::UInt64,
|
||||
DataType::UInt32,
|
||||
DataType::UInt16,
|
||||
DataType::UInt8,
|
||||
];
|
||||
|
||||
// D can be various integer types
|
||||
let int_types = [
|
||||
DataType::Int64,
|
||||
DataType::Int32,
|
||||
DataType::Int16,
|
||||
DataType::Int8,
|
||||
DataType::UInt64,
|
||||
DataType::UInt32,
|
||||
DataType::UInt16,
|
||||
DataType::UInt8,
|
||||
];
|
||||
|
||||
for x_type in &numeric_types {
|
||||
for d_type in &int_types {
|
||||
signatures.push(TypeSignature::Exact(vec![x_type.clone(), d_type.clone()]));
|
||||
}
|
||||
}
|
||||
|
||||
Self {
|
||||
signature: Signature::one_of(signatures, Volatility::Immutable),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for FormatFunction {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(f, "{}", NAME.to_ascii_uppercase())
|
||||
}
|
||||
}
|
||||
|
||||
impl Function for FormatFunction {
|
||||
fn name(&self) -> &str {
|
||||
NAME
|
||||
}
|
||||
|
||||
fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
|
||||
Ok(DataType::LargeUtf8)
|
||||
}
|
||||
|
||||
fn signature(&self) -> &Signature {
|
||||
&self.signature
|
||||
}
|
||||
|
||||
fn invoke_with_args(
|
||||
&self,
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
if args.args.len() != 2 {
|
||||
return Err(DataFusionError::Execution(
|
||||
"FORMAT requires exactly 2 arguments: FORMAT(X, D)".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let arrays = ColumnarValue::values_to_arrays(&args.args)?;
|
||||
let len = arrays[0].len();
|
||||
|
||||
let x_array = &arrays[0];
|
||||
let d_array = &arrays[1];
|
||||
|
||||
let mut builder = LargeStringBuilder::with_capacity(len, len * 20);
|
||||
|
||||
for i in 0..len {
|
||||
if x_array.is_null(i) || d_array.is_null(i) {
|
||||
builder.append_null();
|
||||
continue;
|
||||
}
|
||||
|
||||
let decimal_places = get_decimal_places(d_array, i)?.clamp(0, 30) as usize;
|
||||
|
||||
let formatted = match x_array.data_type() {
|
||||
DataType::Float64 | DataType::Float32 => {
|
||||
format_number_float(get_float_value(x_array, i)?, decimal_places)
|
||||
}
|
||||
DataType::Int64
|
||||
| DataType::Int32
|
||||
| DataType::Int16
|
||||
| DataType::Int8
|
||||
| DataType::UInt64
|
||||
| DataType::UInt32
|
||||
| DataType::UInt16
|
||||
| DataType::UInt8 => format_number_integer(x_array, i, decimal_places)?,
|
||||
_ => {
|
||||
return Err(DataFusionError::Execution(format!(
|
||||
"FORMAT: unsupported type {:?}",
|
||||
x_array.data_type()
|
||||
)));
|
||||
}
|
||||
};
|
||||
builder.append_value(&formatted);
|
||||
}
|
||||
|
||||
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
|
||||
}
|
||||
}
|
||||
|
||||
/// Get float value from various numeric types.
|
||||
fn get_float_value(
|
||||
array: &datafusion_common::arrow::array::ArrayRef,
|
||||
index: usize,
|
||||
) -> datafusion_common::Result<f64> {
|
||||
match array.data_type() {
|
||||
DataType::Float64 => Ok(array
|
||||
.as_primitive::<arrow_types::Float64Type>()
|
||||
.value(index)),
|
||||
DataType::Float32 => Ok(array
|
||||
.as_primitive::<arrow_types::Float32Type>()
|
||||
.value(index) as f64),
|
||||
_ => Err(DataFusionError::Execution(format!(
|
||||
"FORMAT: unsupported type {:?}",
|
||||
array.data_type()
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get decimal places from various integer types.
|
||||
///
|
||||
/// MySQL clamps decimal places to `0..=30`. This function returns an `i64` so the caller can clamp.
|
||||
fn get_decimal_places(
|
||||
array: &datafusion_common::arrow::array::ArrayRef,
|
||||
index: usize,
|
||||
) -> datafusion_common::Result<i64> {
|
||||
match array.data_type() {
|
||||
DataType::Int64 => Ok(array.as_primitive::<arrow_types::Int64Type>().value(index)),
|
||||
DataType::Int32 => Ok(array.as_primitive::<arrow_types::Int32Type>().value(index) as i64),
|
||||
DataType::Int16 => Ok(array.as_primitive::<arrow_types::Int16Type>().value(index) as i64),
|
||||
DataType::Int8 => Ok(array.as_primitive::<arrow_types::Int8Type>().value(index) as i64),
|
||||
DataType::UInt64 => {
|
||||
let v = array.as_primitive::<arrow_types::UInt64Type>().value(index);
|
||||
Ok(if v > i64::MAX as u64 {
|
||||
i64::MAX
|
||||
} else {
|
||||
v as i64
|
||||
})
|
||||
}
|
||||
DataType::UInt32 => Ok(array.as_primitive::<arrow_types::UInt32Type>().value(index) as i64),
|
||||
DataType::UInt16 => Ok(array.as_primitive::<arrow_types::UInt16Type>().value(index) as i64),
|
||||
DataType::UInt8 => Ok(array.as_primitive::<arrow_types::UInt8Type>().value(index) as i64),
|
||||
_ => Err(DataFusionError::Execution(format!(
|
||||
"FORMAT: unsupported type {:?}",
|
||||
array.data_type()
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
fn format_number_integer(
|
||||
array: &datafusion_common::arrow::array::ArrayRef,
|
||||
index: usize,
|
||||
decimal_places: usize,
|
||||
) -> datafusion_common::Result<String> {
|
||||
let (is_negative, abs_digits) = match array.data_type() {
|
||||
DataType::Int64 => {
|
||||
let v = array.as_primitive::<arrow_types::Int64Type>().value(index) as i128;
|
||||
(v.is_negative(), v.unsigned_abs().to_string())
|
||||
}
|
||||
DataType::Int32 => {
|
||||
let v = array.as_primitive::<arrow_types::Int32Type>().value(index) as i128;
|
||||
(v.is_negative(), v.unsigned_abs().to_string())
|
||||
}
|
||||
DataType::Int16 => {
|
||||
let v = array.as_primitive::<arrow_types::Int16Type>().value(index) as i128;
|
||||
(v.is_negative(), v.unsigned_abs().to_string())
|
||||
}
|
||||
DataType::Int8 => {
|
||||
let v = array.as_primitive::<arrow_types::Int8Type>().value(index) as i128;
|
||||
(v.is_negative(), v.unsigned_abs().to_string())
|
||||
}
|
||||
DataType::UInt64 => {
|
||||
let v = array.as_primitive::<arrow_types::UInt64Type>().value(index) as u128;
|
||||
(false, v.to_string())
|
||||
}
|
||||
DataType::UInt32 => {
|
||||
let v = array.as_primitive::<arrow_types::UInt32Type>().value(index) as u128;
|
||||
(false, v.to_string())
|
||||
}
|
||||
DataType::UInt16 => {
|
||||
let v = array.as_primitive::<arrow_types::UInt16Type>().value(index) as u128;
|
||||
(false, v.to_string())
|
||||
}
|
||||
DataType::UInt8 => {
|
||||
let v = array.as_primitive::<arrow_types::UInt8Type>().value(index) as u128;
|
||||
(false, v.to_string())
|
||||
}
|
||||
_ => {
|
||||
return Err(DataFusionError::Execution(format!(
|
||||
"FORMAT: unsupported type {:?}",
|
||||
array.data_type()
|
||||
)));
|
||||
}
|
||||
};
|
||||
|
||||
let mut result = String::new();
|
||||
if is_negative {
|
||||
result.push('-');
|
||||
}
|
||||
result.push_str(&add_thousand_separators(&abs_digits));
|
||||
|
||||
if decimal_places > 0 {
|
||||
result.push('.');
|
||||
result.push_str(&"0".repeat(decimal_places));
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Format a float with thousand separators and `decimal_places` digits after decimal point.
|
||||
fn format_number_float(x: f64, decimal_places: usize) -> String {
|
||||
// Handle special cases
|
||||
if x.is_nan() {
|
||||
return "NaN".to_string();
|
||||
}
|
||||
if x.is_infinite() {
|
||||
return if x.is_sign_positive() {
|
||||
"Infinity".to_string()
|
||||
} else {
|
||||
"-Infinity".to_string()
|
||||
};
|
||||
}
|
||||
|
||||
// Round to decimal_places
|
||||
let multiplier = 10f64.powi(decimal_places as i32);
|
||||
let rounded = (x * multiplier).round() / multiplier;
|
||||
|
||||
// Split into integer and fractional parts
|
||||
let is_negative = rounded < 0.0;
|
||||
let abs_value = rounded.abs();
|
||||
|
||||
// Format with the specified decimal places
|
||||
let formatted = if decimal_places == 0 {
|
||||
format!("{:.0}", abs_value)
|
||||
} else {
|
||||
format!("{:.prec$}", abs_value, prec = decimal_places)
|
||||
};
|
||||
|
||||
// Split at decimal point
|
||||
let parts: Vec<&str> = formatted.split('.').collect();
|
||||
let int_part = parts[0];
|
||||
let dec_part = parts.get(1).copied();
|
||||
|
||||
// Add thousand separators to integer part
|
||||
let int_with_sep = add_thousand_separators(int_part);
|
||||
|
||||
// Build result
|
||||
let mut result = String::new();
|
||||
if is_negative {
|
||||
result.push('-');
|
||||
}
|
||||
result.push_str(&int_with_sep);
|
||||
if let Some(dec) = dec_part {
|
||||
result.push('.');
|
||||
result.push_str(dec);
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
/// Add thousand separators (commas) to an integer string.
|
||||
fn add_thousand_separators(s: &str) -> String {
|
||||
let chars: Vec<char> = s.chars().collect();
|
||||
let len = chars.len();
|
||||
|
||||
if len <= 3 {
|
||||
return s.to_string();
|
||||
}
|
||||
|
||||
let mut result = String::with_capacity(len + len / 3);
|
||||
let first_group_len = len % 3;
|
||||
let first_group_len = if first_group_len == 0 {
|
||||
3
|
||||
} else {
|
||||
first_group_len
|
||||
};
|
||||
|
||||
for (i, ch) in chars.iter().enumerate() {
|
||||
if i > 0 && i >= first_group_len && (i - first_group_len) % 3 == 0 {
|
||||
result.push(',');
|
||||
}
|
||||
result.push(*ch);
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use datafusion_common::arrow::array::{Float64Array, Int64Array};
|
||||
use datafusion_common::arrow::datatypes::Field;
|
||||
use datafusion_expr::ScalarFunctionArgs;
|
||||
|
||||
use super::*;
|
||||
|
||||
fn create_args(arrays: Vec<datafusion_common::arrow::array::ArrayRef>) -> ScalarFunctionArgs {
|
||||
let arg_fields: Vec<_> = arrays
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, arr)| {
|
||||
Arc::new(Field::new(
|
||||
format!("arg_{}", i),
|
||||
arr.data_type().clone(),
|
||||
true,
|
||||
))
|
||||
})
|
||||
.collect();
|
||||
|
||||
ScalarFunctionArgs {
|
||||
args: arrays.iter().cloned().map(ColumnarValue::Array).collect(),
|
||||
arg_fields,
|
||||
return_field: Arc::new(Field::new("result", DataType::LargeUtf8, true)),
|
||||
number_rows: arrays[0].len(),
|
||||
config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_basic() {
|
||||
let function = FormatFunction::default();
|
||||
|
||||
let x = Arc::new(Float64Array::from(vec![1234567.891, 1234.5, 1234567.0]));
|
||||
let d = Arc::new(Int64Array::from(vec![2, 0, 3]));
|
||||
|
||||
let args = create_args(vec![x, d]);
|
||||
let result = function.invoke_with_args(args).unwrap();
|
||||
|
||||
if let ColumnarValue::Array(array) = result {
|
||||
let str_array = array.as_string::<i64>();
|
||||
assert_eq!(str_array.value(0), "1,234,567.89");
|
||||
assert_eq!(str_array.value(1), "1,235"); // rounded
|
||||
assert_eq!(str_array.value(2), "1,234,567.000");
|
||||
} else {
|
||||
panic!("Expected array result");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_negative() {
|
||||
let function = FormatFunction::default();
|
||||
|
||||
let x = Arc::new(Float64Array::from(vec![-1234567.891]));
|
||||
let d = Arc::new(Int64Array::from(vec![2]));
|
||||
|
||||
let args = create_args(vec![x, d]);
|
||||
let result = function.invoke_with_args(args).unwrap();
|
||||
|
||||
if let ColumnarValue::Array(array) = result {
|
||||
let str_array = array.as_string::<i64>();
|
||||
assert_eq!(str_array.value(0), "-1,234,567.89");
|
||||
} else {
|
||||
panic!("Expected array result");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_small_numbers() {
|
||||
let function = FormatFunction::default();
|
||||
|
||||
let x = Arc::new(Float64Array::from(vec![0.5, 12.345, 123.0]));
|
||||
let d = Arc::new(Int64Array::from(vec![2, 2, 0]));
|
||||
|
||||
let args = create_args(vec![x, d]);
|
||||
let result = function.invoke_with_args(args).unwrap();
|
||||
|
||||
if let ColumnarValue::Array(array) = result {
|
||||
let str_array = array.as_string::<i64>();
|
||||
assert_eq!(str_array.value(0), "0.50");
|
||||
assert_eq!(str_array.value(1), "12.35"); // rounded
|
||||
assert_eq!(str_array.value(2), "123");
|
||||
} else {
|
||||
panic!("Expected array result");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_with_nulls() {
|
||||
let function = FormatFunction::default();
|
||||
|
||||
let x = Arc::new(Float64Array::from(vec![Some(1234.5), None]));
|
||||
let d = Arc::new(Int64Array::from(vec![2, 2]));
|
||||
|
||||
let args = create_args(vec![x, d]);
|
||||
let result = function.invoke_with_args(args).unwrap();
|
||||
|
||||
if let ColumnarValue::Array(array) = result {
|
||||
let str_array = array.as_string::<i64>();
|
||||
assert_eq!(str_array.value(0), "1,234.50");
|
||||
assert!(str_array.is_null(1));
|
||||
} else {
|
||||
panic!("Expected array result");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_add_thousand_separators() {
|
||||
assert_eq!(add_thousand_separators("1"), "1");
|
||||
assert_eq!(add_thousand_separators("12"), "12");
|
||||
assert_eq!(add_thousand_separators("123"), "123");
|
||||
assert_eq!(add_thousand_separators("1234"), "1,234");
|
||||
assert_eq!(add_thousand_separators("12345"), "12,345");
|
||||
assert_eq!(add_thousand_separators("123456"), "123,456");
|
||||
assert_eq!(add_thousand_separators("1234567"), "1,234,567");
|
||||
assert_eq!(add_thousand_separators("12345678"), "12,345,678");
|
||||
assert_eq!(add_thousand_separators("123456789"), "123,456,789");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_large_int_no_float_precision_loss() {
|
||||
let function = FormatFunction::default();
|
||||
|
||||
// 2^53 + 1 cannot be represented exactly as f64.
|
||||
let x = Arc::new(Int64Array::from(vec![9_007_199_254_740_993i64]));
|
||||
let d = Arc::new(Int64Array::from(vec![0]));
|
||||
|
||||
let args = create_args(vec![x, d]);
|
||||
let result = function.invoke_with_args(args).unwrap();
|
||||
|
||||
if let ColumnarValue::Array(array) = result {
|
||||
let str_array = array.as_string::<i64>();
|
||||
assert_eq!(str_array.value(0), "9,007,199,254,740,993");
|
||||
} else {
|
||||
panic!("Expected array result");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_decimal_places_u64_overflow_clamps() {
|
||||
use datafusion_common::arrow::array::UInt64Array;
|
||||
|
||||
let function = FormatFunction::default();
|
||||
|
||||
let x = Arc::new(Int64Array::from(vec![1]));
|
||||
let d = Arc::new(UInt64Array::from(vec![u64::MAX]));
|
||||
|
||||
let args = create_args(vec![x, d]);
|
||||
let result = function.invoke_with_args(args).unwrap();
|
||||
|
||||
if let ColumnarValue::Array(array) = result {
|
||||
let str_array = array.as_string::<i64>();
|
||||
assert_eq!(str_array.value(0), format!("1.{}", "0".repeat(30)));
|
||||
} else {
|
||||
panic!("Expected array result");
|
||||
}
|
||||
}
|
||||
}
|
||||
345
src/common/function/src/scalars/string/insert.rs
Normal file
345
src/common/function/src/scalars/string/insert.rs
Normal file
@@ -0,0 +1,345 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! MySQL-compatible INSERT function implementation.
|
||||
//!
|
||||
//! INSERT(str, pos, len, newstr) - Inserts newstr into str at position pos,
|
||||
//! replacing len characters.
|
||||
|
||||
use std::fmt;
|
||||
use std::sync::Arc;
|
||||
|
||||
use datafusion_common::DataFusionError;
|
||||
use datafusion_common::arrow::array::{Array, ArrayRef, AsArray, LargeStringBuilder};
|
||||
use datafusion_common::arrow::compute::cast;
|
||||
use datafusion_common::arrow::datatypes::DataType;
|
||||
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, TypeSignature, Volatility};
|
||||
|
||||
use crate::function::Function;
|
||||
use crate::function_registry::FunctionRegistry;
|
||||
|
||||
const NAME: &str = "insert";
|
||||
|
||||
/// MySQL-compatible INSERT function.
|
||||
///
|
||||
/// Syntax: INSERT(str, pos, len, newstr)
|
||||
/// Returns str with the substring beginning at position pos and len characters long
|
||||
/// replaced by newstr.
|
||||
///
|
||||
/// - pos is 1-based
|
||||
/// - If pos is out of range, returns the original string
|
||||
/// - If len is out of range, replaces from pos to end of string
|
||||
#[derive(Debug)]
|
||||
pub struct InsertFunction {
|
||||
signature: Signature,
|
||||
}
|
||||
|
||||
impl InsertFunction {
|
||||
pub fn register(registry: &FunctionRegistry) {
|
||||
registry.register_scalar(InsertFunction::default());
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for InsertFunction {
|
||||
fn default() -> Self {
|
||||
let mut signatures = Vec::new();
|
||||
let string_types = [DataType::Utf8, DataType::LargeUtf8, DataType::Utf8View];
|
||||
let int_types = [
|
||||
DataType::Int64,
|
||||
DataType::Int32,
|
||||
DataType::Int16,
|
||||
DataType::Int8,
|
||||
DataType::UInt64,
|
||||
DataType::UInt32,
|
||||
DataType::UInt16,
|
||||
DataType::UInt8,
|
||||
];
|
||||
|
||||
for str_type in &string_types {
|
||||
for newstr_type in &string_types {
|
||||
for pos_type in &int_types {
|
||||
for len_type in &int_types {
|
||||
signatures.push(TypeSignature::Exact(vec![
|
||||
str_type.clone(),
|
||||
pos_type.clone(),
|
||||
len_type.clone(),
|
||||
newstr_type.clone(),
|
||||
]));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Self {
|
||||
signature: Signature::one_of(signatures, Volatility::Immutable),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for InsertFunction {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(f, "{}", NAME.to_ascii_uppercase())
|
||||
}
|
||||
}
|
||||
|
||||
impl Function for InsertFunction {
|
||||
fn name(&self) -> &str {
|
||||
NAME
|
||||
}
|
||||
|
||||
fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
|
||||
Ok(DataType::LargeUtf8)
|
||||
}
|
||||
|
||||
fn signature(&self) -> &Signature {
|
||||
&self.signature
|
||||
}
|
||||
|
||||
fn invoke_with_args(
|
||||
&self,
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
if args.args.len() != 4 {
|
||||
return Err(DataFusionError::Execution(
|
||||
"INSERT requires exactly 4 arguments: INSERT(str, pos, len, newstr)".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let arrays = ColumnarValue::values_to_arrays(&args.args)?;
|
||||
let len = arrays[0].len();
|
||||
|
||||
// Cast string arguments to LargeUtf8
|
||||
let str_array = cast_to_large_utf8(&arrays[0], "str")?;
|
||||
let newstr_array = cast_to_large_utf8(&arrays[3], "newstr")?;
|
||||
let pos_array = cast_to_int64(&arrays[1], "pos")?;
|
||||
let replace_len_array = cast_to_int64(&arrays[2], "len")?;
|
||||
|
||||
let str_arr = str_array.as_string::<i64>();
|
||||
let pos_arr = pos_array.as_primitive::<datafusion_common::arrow::datatypes::Int64Type>();
|
||||
let len_arr =
|
||||
replace_len_array.as_primitive::<datafusion_common::arrow::datatypes::Int64Type>();
|
||||
let newstr_arr = newstr_array.as_string::<i64>();
|
||||
|
||||
let mut builder = LargeStringBuilder::with_capacity(len, len * 32);
|
||||
|
||||
for i in 0..len {
|
||||
// Check for NULLs
|
||||
if str_arr.is_null(i)
|
||||
|| pos_array.is_null(i)
|
||||
|| replace_len_array.is_null(i)
|
||||
|| newstr_arr.is_null(i)
|
||||
{
|
||||
builder.append_null();
|
||||
continue;
|
||||
}
|
||||
|
||||
let original = str_arr.value(i);
|
||||
let pos = pos_arr.value(i);
|
||||
let replace_len = len_arr.value(i);
|
||||
let new_str = newstr_arr.value(i);
|
||||
|
||||
let result = insert_string(original, pos, replace_len, new_str);
|
||||
builder.append_value(&result);
|
||||
}
|
||||
|
||||
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
|
||||
}
|
||||
}
|
||||
|
||||
/// Cast array to LargeUtf8 for uniform string access.
|
||||
fn cast_to_large_utf8(array: &ArrayRef, name: &str) -> datafusion_common::Result<ArrayRef> {
|
||||
cast(array.as_ref(), &DataType::LargeUtf8)
|
||||
.map_err(|e| DataFusionError::Execution(format!("INSERT: {} cast failed: {}", name, e)))
|
||||
}
|
||||
|
||||
fn cast_to_int64(array: &ArrayRef, name: &str) -> datafusion_common::Result<ArrayRef> {
|
||||
cast(array.as_ref(), &DataType::Int64)
|
||||
.map_err(|e| DataFusionError::Execution(format!("INSERT: {} cast failed: {}", name, e)))
|
||||
}
|
||||
|
||||
/// Perform the INSERT string operation.
|
||||
/// pos is 1-based. If pos < 1 or pos > len(str) + 1, returns original string.
|
||||
fn insert_string(original: &str, pos: i64, replace_len: i64, new_str: &str) -> String {
|
||||
let char_count = original.chars().count();
|
||||
|
||||
// MySQL behavior: if pos < 1 or pos > string length + 1, return original
|
||||
if pos < 1 || pos as usize > char_count + 1 {
|
||||
return original.to_string();
|
||||
}
|
||||
|
||||
let start_idx = (pos - 1) as usize; // Convert to 0-based
|
||||
|
||||
// Calculate end index for replacement
|
||||
let replace_len = if replace_len < 0 {
|
||||
0
|
||||
} else {
|
||||
replace_len as usize
|
||||
};
|
||||
let end_idx = (start_idx + replace_len).min(char_count);
|
||||
|
||||
let start_byte = char_to_byte_idx(original, start_idx);
|
||||
let end_byte = char_to_byte_idx(original, end_idx);
|
||||
|
||||
let mut result = String::with_capacity(original.len() + new_str.len());
|
||||
result.push_str(&original[..start_byte]);
|
||||
result.push_str(new_str);
|
||||
result.push_str(&original[end_byte..]);
|
||||
result
|
||||
}
|
||||
|
||||
fn char_to_byte_idx(s: &str, char_idx: usize) -> usize {
|
||||
s.char_indices()
|
||||
.nth(char_idx)
|
||||
.map(|(idx, _)| idx)
|
||||
.unwrap_or(s.len())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use datafusion_common::arrow::array::{Int64Array, StringArray};
|
||||
use datafusion_common::arrow::datatypes::Field;
|
||||
use datafusion_expr::ScalarFunctionArgs;
|
||||
|
||||
use super::*;
|
||||
|
||||
fn create_args(arrays: Vec<ArrayRef>) -> ScalarFunctionArgs {
|
||||
let arg_fields: Vec<_> = arrays
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, arr)| {
|
||||
Arc::new(Field::new(
|
||||
format!("arg_{}", i),
|
||||
arr.data_type().clone(),
|
||||
true,
|
||||
))
|
||||
})
|
||||
.collect();
|
||||
|
||||
ScalarFunctionArgs {
|
||||
args: arrays.iter().cloned().map(ColumnarValue::Array).collect(),
|
||||
arg_fields,
|
||||
return_field: Arc::new(Field::new("result", DataType::LargeUtf8, true)),
|
||||
number_rows: arrays[0].len(),
|
||||
config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_insert_basic() {
|
||||
let function = InsertFunction::default();
|
||||
|
||||
// INSERT('Quadratic', 3, 4, 'What') => 'QuWhattic'
|
||||
let str_arr = Arc::new(StringArray::from(vec!["Quadratic"]));
|
||||
let pos = Arc::new(Int64Array::from(vec![3]));
|
||||
let len = Arc::new(Int64Array::from(vec![4]));
|
||||
let newstr = Arc::new(StringArray::from(vec!["What"]));
|
||||
|
||||
let args = create_args(vec![str_arr, pos, len, newstr]);
|
||||
let result = function.invoke_with_args(args).unwrap();
|
||||
|
||||
if let ColumnarValue::Array(array) = result {
|
||||
let str_array = array.as_string::<i64>();
|
||||
assert_eq!(str_array.value(0), "QuWhattic");
|
||||
} else {
|
||||
panic!("Expected array result");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_insert_out_of_range_pos() {
|
||||
let function = InsertFunction::default();
|
||||
|
||||
// INSERT('Quadratic', 0, 4, 'What') => 'Quadratic' (pos < 1)
|
||||
let str_arr = Arc::new(StringArray::from(vec!["Quadratic", "Quadratic"]));
|
||||
let pos = Arc::new(Int64Array::from(vec![0, 100]));
|
||||
let len = Arc::new(Int64Array::from(vec![4, 4]));
|
||||
let newstr = Arc::new(StringArray::from(vec!["What", "What"]));
|
||||
|
||||
let args = create_args(vec![str_arr, pos, len, newstr]);
|
||||
let result = function.invoke_with_args(args).unwrap();
|
||||
|
||||
if let ColumnarValue::Array(array) = result {
|
||||
let str_array = array.as_string::<i64>();
|
||||
assert_eq!(str_array.value(0), "Quadratic"); // pos < 1
|
||||
assert_eq!(str_array.value(1), "Quadratic"); // pos > length
|
||||
} else {
|
||||
panic!("Expected array result");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_insert_replace_to_end() {
|
||||
let function = InsertFunction::default();
|
||||
|
||||
// INSERT('Quadratic', 3, 100, 'What') => 'QuWhat' (len exceeds remaining)
|
||||
let str_arr = Arc::new(StringArray::from(vec!["Quadratic"]));
|
||||
let pos = Arc::new(Int64Array::from(vec![3]));
|
||||
let len = Arc::new(Int64Array::from(vec![100]));
|
||||
let newstr = Arc::new(StringArray::from(vec!["What"]));
|
||||
|
||||
let args = create_args(vec![str_arr, pos, len, newstr]);
|
||||
let result = function.invoke_with_args(args).unwrap();
|
||||
|
||||
if let ColumnarValue::Array(array) = result {
|
||||
let str_array = array.as_string::<i64>();
|
||||
assert_eq!(str_array.value(0), "QuWhat");
|
||||
} else {
|
||||
panic!("Expected array result");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_insert_unicode() {
|
||||
let function = InsertFunction::default();
|
||||
|
||||
// INSERT('hello世界', 6, 1, 'の') => 'helloの界'
|
||||
let str_arr = Arc::new(StringArray::from(vec!["hello世界"]));
|
||||
let pos = Arc::new(Int64Array::from(vec![6]));
|
||||
let len = Arc::new(Int64Array::from(vec![1]));
|
||||
let newstr = Arc::new(StringArray::from(vec!["の"]));
|
||||
|
||||
let args = create_args(vec![str_arr, pos, len, newstr]);
|
||||
let result = function.invoke_with_args(args).unwrap();
|
||||
|
||||
if let ColumnarValue::Array(array) = result {
|
||||
let str_array = array.as_string::<i64>();
|
||||
assert_eq!(str_array.value(0), "helloの界");
|
||||
} else {
|
||||
panic!("Expected array result");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_insert_with_nulls() {
|
||||
let function = InsertFunction::default();
|
||||
|
||||
let str_arr = Arc::new(StringArray::from(vec![Some("hello"), None]));
|
||||
let pos = Arc::new(Int64Array::from(vec![1, 1]));
|
||||
let len = Arc::new(Int64Array::from(vec![1, 1]));
|
||||
let newstr = Arc::new(StringArray::from(vec!["X", "X"]));
|
||||
|
||||
let args = create_args(vec![str_arr, pos, len, newstr]);
|
||||
let result = function.invoke_with_args(args).unwrap();
|
||||
|
||||
if let ColumnarValue::Array(array) = result {
|
||||
let str_array = array.as_string::<i64>();
|
||||
assert_eq!(str_array.value(0), "Xello");
|
||||
assert!(str_array.is_null(1));
|
||||
} else {
|
||||
panic!("Expected array result");
|
||||
}
|
||||
}
|
||||
}
|
||||
373
src/common/function/src/scalars/string/locate.rs
Normal file
373
src/common/function/src/scalars/string/locate.rs
Normal file
@@ -0,0 +1,373 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! MySQL-compatible LOCATE function implementation.
|
||||
//!
|
||||
//! LOCATE(substr, str) - Returns the position of the first occurrence of substr in str (1-based).
|
||||
//! LOCATE(substr, str, pos) - Returns the position of the first occurrence of substr in str,
|
||||
//! starting from position pos.
|
||||
//! Returns 0 if substr is not found.
|
||||
|
||||
use std::fmt;
|
||||
use std::sync::Arc;
|
||||
|
||||
use datafusion_common::DataFusionError;
|
||||
use datafusion_common::arrow::array::{Array, ArrayRef, AsArray, Int64Builder};
|
||||
use datafusion_common::arrow::compute::cast;
|
||||
use datafusion_common::arrow::datatypes::DataType;
|
||||
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, TypeSignature, Volatility};
|
||||
|
||||
use crate::function::Function;
|
||||
use crate::function_registry::FunctionRegistry;
|
||||
|
||||
const NAME: &str = "locate";
|
||||
|
||||
/// MySQL-compatible LOCATE function.
|
||||
///
|
||||
/// Syntax:
|
||||
/// - LOCATE(substr, str) - Returns 1-based position of substr in str, or 0 if not found.
|
||||
/// - LOCATE(substr, str, pos) - Same, but starts searching from position pos.
|
||||
#[derive(Debug)]
|
||||
pub struct LocateFunction {
|
||||
signature: Signature,
|
||||
}
|
||||
|
||||
impl LocateFunction {
|
||||
pub fn register(registry: &FunctionRegistry) {
|
||||
registry.register_scalar(LocateFunction::default());
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for LocateFunction {
|
||||
fn default() -> Self {
|
||||
// Support 2 or 3 arguments with various string types
|
||||
let mut signatures = Vec::new();
|
||||
let string_types = [DataType::Utf8, DataType::LargeUtf8, DataType::Utf8View];
|
||||
let int_types = [
|
||||
DataType::Int64,
|
||||
DataType::Int32,
|
||||
DataType::Int16,
|
||||
DataType::Int8,
|
||||
DataType::UInt64,
|
||||
DataType::UInt32,
|
||||
DataType::UInt16,
|
||||
DataType::UInt8,
|
||||
];
|
||||
|
||||
// 2-argument form: LOCATE(substr, str)
|
||||
for substr_type in &string_types {
|
||||
for str_type in &string_types {
|
||||
signatures.push(TypeSignature::Exact(vec![
|
||||
substr_type.clone(),
|
||||
str_type.clone(),
|
||||
]));
|
||||
}
|
||||
}
|
||||
|
||||
// 3-argument form: LOCATE(substr, str, pos)
|
||||
for substr_type in &string_types {
|
||||
for str_type in &string_types {
|
||||
for pos_type in &int_types {
|
||||
signatures.push(TypeSignature::Exact(vec![
|
||||
substr_type.clone(),
|
||||
str_type.clone(),
|
||||
pos_type.clone(),
|
||||
]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Self {
|
||||
signature: Signature::one_of(signatures, Volatility::Immutable),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for LocateFunction {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(f, "{}", NAME.to_ascii_uppercase())
|
||||
}
|
||||
}
|
||||
|
||||
impl Function for LocateFunction {
|
||||
fn name(&self) -> &str {
|
||||
NAME
|
||||
}
|
||||
|
||||
fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
|
||||
Ok(DataType::Int64)
|
||||
}
|
||||
|
||||
fn signature(&self) -> &Signature {
|
||||
&self.signature
|
||||
}
|
||||
|
||||
fn invoke_with_args(
|
||||
&self,
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
let arg_count = args.args.len();
|
||||
if !(2..=3).contains(&arg_count) {
|
||||
return Err(DataFusionError::Execution(
|
||||
"LOCATE requires 2 or 3 arguments: LOCATE(substr, str) or LOCATE(substr, str, pos)"
|
||||
.to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let arrays = ColumnarValue::values_to_arrays(&args.args)?;
|
||||
|
||||
// Cast string arguments to LargeUtf8 for uniform access
|
||||
let substr_array = cast_to_large_utf8(&arrays[0], "substr")?;
|
||||
let str_array = cast_to_large_utf8(&arrays[1], "str")?;
|
||||
|
||||
let substr = substr_array.as_string::<i64>();
|
||||
let str_arr = str_array.as_string::<i64>();
|
||||
let len = substr.len();
|
||||
|
||||
// Handle optional pos argument
|
||||
let pos_array: Option<ArrayRef> = if arg_count == 3 {
|
||||
Some(cast_to_int64(&arrays[2], "pos")?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let mut builder = Int64Builder::with_capacity(len);
|
||||
|
||||
for i in 0..len {
|
||||
if substr.is_null(i) || str_arr.is_null(i) {
|
||||
builder.append_null();
|
||||
continue;
|
||||
}
|
||||
|
||||
let needle = substr.value(i);
|
||||
let haystack = str_arr.value(i);
|
||||
|
||||
// Get starting position (1-based in MySQL, convert to 0-based)
|
||||
let start_pos = if let Some(ref pos_arr) = pos_array {
|
||||
if pos_arr.is_null(i) {
|
||||
builder.append_null();
|
||||
continue;
|
||||
}
|
||||
let pos = pos_arr
|
||||
.as_primitive::<datafusion_common::arrow::datatypes::Int64Type>()
|
||||
.value(i);
|
||||
if pos < 1 {
|
||||
// MySQL returns 0 for pos < 1
|
||||
builder.append_value(0);
|
||||
continue;
|
||||
}
|
||||
(pos - 1) as usize
|
||||
} else {
|
||||
0
|
||||
};
|
||||
|
||||
// Find position using character-based indexing (for Unicode support)
|
||||
let result = locate_substr(haystack, needle, start_pos);
|
||||
builder.append_value(result);
|
||||
}
|
||||
|
||||
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
|
||||
}
|
||||
}
|
||||
|
||||
/// Cast array to LargeUtf8 for uniform string access.
|
||||
fn cast_to_large_utf8(array: &ArrayRef, name: &str) -> datafusion_common::Result<ArrayRef> {
|
||||
cast(array.as_ref(), &DataType::LargeUtf8)
|
||||
.map_err(|e| DataFusionError::Execution(format!("LOCATE: {} cast failed: {}", name, e)))
|
||||
}
|
||||
|
||||
fn cast_to_int64(array: &ArrayRef, name: &str) -> datafusion_common::Result<ArrayRef> {
|
||||
cast(array.as_ref(), &DataType::Int64)
|
||||
.map_err(|e| DataFusionError::Execution(format!("LOCATE: {} cast failed: {}", name, e)))
|
||||
}
|
||||
|
||||
/// Find the 1-based position of needle in haystack, starting from start_pos (0-based character index).
|
||||
/// Returns 0 if not found.
|
||||
fn locate_substr(haystack: &str, needle: &str, start_pos: usize) -> i64 {
|
||||
// Handle empty needle - MySQL returns start_pos + 1
|
||||
if needle.is_empty() {
|
||||
let char_count = haystack.chars().count();
|
||||
return if start_pos <= char_count {
|
||||
(start_pos + 1) as i64
|
||||
} else {
|
||||
0
|
||||
};
|
||||
}
|
||||
|
||||
// Convert start_pos (character index) to byte index
|
||||
let byte_start = haystack
|
||||
.char_indices()
|
||||
.nth(start_pos)
|
||||
.map(|(idx, _)| idx)
|
||||
.unwrap_or(haystack.len());
|
||||
|
||||
if byte_start >= haystack.len() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Search in the substring
|
||||
let search_str = &haystack[byte_start..];
|
||||
if let Some(byte_pos) = search_str.find(needle) {
|
||||
// Convert byte position back to character position
|
||||
let char_pos = search_str[..byte_pos].chars().count();
|
||||
// Return 1-based position relative to original string
|
||||
(start_pos + char_pos + 1) as i64
|
||||
} else {
|
||||
0
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use datafusion_common::arrow::array::StringArray;
|
||||
use datafusion_common::arrow::datatypes::Field;
|
||||
use datafusion_expr::ScalarFunctionArgs;
|
||||
|
||||
use super::*;
|
||||
|
||||
fn create_args(arrays: Vec<ArrayRef>) -> ScalarFunctionArgs {
|
||||
let arg_fields: Vec<_> = arrays
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, arr)| {
|
||||
Arc::new(Field::new(
|
||||
format!("arg_{}", i),
|
||||
arr.data_type().clone(),
|
||||
true,
|
||||
))
|
||||
})
|
||||
.collect();
|
||||
|
||||
ScalarFunctionArgs {
|
||||
args: arrays.iter().cloned().map(ColumnarValue::Array).collect(),
|
||||
arg_fields,
|
||||
return_field: Arc::new(Field::new("result", DataType::Int64, true)),
|
||||
number_rows: arrays[0].len(),
|
||||
config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_locate_basic() {
|
||||
let function = LocateFunction::default();
|
||||
|
||||
let substr = Arc::new(StringArray::from(vec!["world", "xyz", "hello"]));
|
||||
let str_arr = Arc::new(StringArray::from(vec![
|
||||
"hello world",
|
||||
"hello world",
|
||||
"hello world",
|
||||
]));
|
||||
|
||||
let args = create_args(vec![substr, str_arr]);
|
||||
let result = function.invoke_with_args(args).unwrap();
|
||||
|
||||
if let ColumnarValue::Array(array) = result {
|
||||
let int_array = array.as_primitive::<datafusion_common::arrow::datatypes::Int64Type>();
|
||||
assert_eq!(int_array.value(0), 7); // "world" at position 7
|
||||
assert_eq!(int_array.value(1), 0); // "xyz" not found
|
||||
assert_eq!(int_array.value(2), 1); // "hello" at position 1
|
||||
} else {
|
||||
panic!("Expected array result");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_locate_with_position() {
|
||||
let function = LocateFunction::default();
|
||||
|
||||
let substr = Arc::new(StringArray::from(vec!["o", "o", "o"]));
|
||||
let str_arr = Arc::new(StringArray::from(vec![
|
||||
"hello world",
|
||||
"hello world",
|
||||
"hello world",
|
||||
]));
|
||||
let pos = Arc::new(datafusion_common::arrow::array::Int64Array::from(vec![
|
||||
1, 5, 8,
|
||||
]));
|
||||
|
||||
let args = create_args(vec![substr, str_arr, pos]);
|
||||
let result = function.invoke_with_args(args).unwrap();
|
||||
|
||||
if let ColumnarValue::Array(array) = result {
|
||||
let int_array = array.as_primitive::<datafusion_common::arrow::datatypes::Int64Type>();
|
||||
assert_eq!(int_array.value(0), 5); // first 'o' at position 5
|
||||
assert_eq!(int_array.value(1), 5); // 'o' at position 5 (start from 5)
|
||||
assert_eq!(int_array.value(2), 8); // 'o' in "world" at position 8
|
||||
} else {
|
||||
panic!("Expected array result");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_locate_unicode() {
|
||||
let function = LocateFunction::default();
|
||||
|
||||
let substr = Arc::new(StringArray::from(vec!["世", "界"]));
|
||||
let str_arr = Arc::new(StringArray::from(vec!["hello世界", "hello世界"]));
|
||||
|
||||
let args = create_args(vec![substr, str_arr]);
|
||||
let result = function.invoke_with_args(args).unwrap();
|
||||
|
||||
if let ColumnarValue::Array(array) = result {
|
||||
let int_array = array.as_primitive::<datafusion_common::arrow::datatypes::Int64Type>();
|
||||
assert_eq!(int_array.value(0), 6); // "世" at position 6
|
||||
assert_eq!(int_array.value(1), 7); // "界" at position 7
|
||||
} else {
|
||||
panic!("Expected array result");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_locate_empty_needle() {
|
||||
let function = LocateFunction::default();
|
||||
|
||||
let substr = Arc::new(StringArray::from(vec!["", ""]));
|
||||
let str_arr = Arc::new(StringArray::from(vec!["hello", "hello"]));
|
||||
let pos = Arc::new(datafusion_common::arrow::array::Int64Array::from(vec![
|
||||
1, 3,
|
||||
]));
|
||||
|
||||
let args = create_args(vec![substr, str_arr, pos]);
|
||||
let result = function.invoke_with_args(args).unwrap();
|
||||
|
||||
if let ColumnarValue::Array(array) = result {
|
||||
let int_array = array.as_primitive::<datafusion_common::arrow::datatypes::Int64Type>();
|
||||
assert_eq!(int_array.value(0), 1); // empty string at pos 1
|
||||
assert_eq!(int_array.value(1), 3); // empty string at pos 3
|
||||
} else {
|
||||
panic!("Expected array result");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_locate_with_nulls() {
|
||||
let function = LocateFunction::default();
|
||||
|
||||
let substr = Arc::new(StringArray::from(vec![Some("o"), None]));
|
||||
let str_arr = Arc::new(StringArray::from(vec![Some("hello"), Some("hello")]));
|
||||
|
||||
let args = create_args(vec![substr, str_arr]);
|
||||
let result = function.invoke_with_args(args).unwrap();
|
||||
|
||||
if let ColumnarValue::Array(array) = result {
|
||||
let int_array = array.as_primitive::<datafusion_common::arrow::datatypes::Int64Type>();
|
||||
assert_eq!(int_array.value(0), 5);
|
||||
assert!(int_array.is_null(1));
|
||||
} else {
|
||||
panic!("Expected array result");
|
||||
}
|
||||
}
|
||||
}
|
||||
252
src/common/function/src/scalars/string/space.rs
Normal file
252
src/common/function/src/scalars/string/space.rs
Normal file
@@ -0,0 +1,252 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! MySQL-compatible SPACE function implementation.
|
||||
//!
|
||||
//! SPACE(N) - Returns a string consisting of N space characters.
|
||||
|
||||
use std::fmt;
|
||||
use std::sync::Arc;
|
||||
|
||||
use datafusion_common::DataFusionError;
|
||||
use datafusion_common::arrow::array::{Array, AsArray, LargeStringBuilder};
|
||||
use datafusion_common::arrow::datatypes::DataType;
|
||||
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, TypeSignature, Volatility};
|
||||
|
||||
use crate::function::Function;
|
||||
use crate::function_registry::FunctionRegistry;
|
||||
|
||||
const NAME: &str = "space";
|
||||
|
||||
// Safety limit for maximum number of spaces
|
||||
const MAX_SPACE_COUNT: i64 = 1024 * 1024; // 1MB of spaces
|
||||
|
||||
/// MySQL-compatible SPACE function.
|
||||
///
|
||||
/// Syntax: SPACE(N)
|
||||
/// Returns a string consisting of N space characters.
|
||||
/// Returns NULL if N is NULL.
|
||||
/// Returns empty string if N < 0.
|
||||
#[derive(Debug)]
|
||||
pub struct SpaceFunction {
|
||||
signature: Signature,
|
||||
}
|
||||
|
||||
impl SpaceFunction {
|
||||
pub fn register(registry: &FunctionRegistry) {
|
||||
registry.register_scalar(SpaceFunction::default());
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for SpaceFunction {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
signature: Signature::one_of(
|
||||
vec![
|
||||
TypeSignature::Exact(vec![DataType::Int64]),
|
||||
TypeSignature::Exact(vec![DataType::Int32]),
|
||||
TypeSignature::Exact(vec![DataType::Int16]),
|
||||
TypeSignature::Exact(vec![DataType::Int8]),
|
||||
TypeSignature::Exact(vec![DataType::UInt64]),
|
||||
TypeSignature::Exact(vec![DataType::UInt32]),
|
||||
TypeSignature::Exact(vec![DataType::UInt16]),
|
||||
TypeSignature::Exact(vec![DataType::UInt8]),
|
||||
],
|
||||
Volatility::Immutable,
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for SpaceFunction {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(f, "{}", NAME.to_ascii_uppercase())
|
||||
}
|
||||
}
|
||||
|
||||
impl Function for SpaceFunction {
|
||||
fn name(&self) -> &str {
|
||||
NAME
|
||||
}
|
||||
|
||||
fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
|
||||
Ok(DataType::LargeUtf8)
|
||||
}
|
||||
|
||||
fn signature(&self) -> &Signature {
|
||||
&self.signature
|
||||
}
|
||||
|
||||
fn invoke_with_args(
|
||||
&self,
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
if args.args.len() != 1 {
|
||||
return Err(DataFusionError::Execution(
|
||||
"SPACE requires exactly 1 argument: SPACE(N)".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let arrays = ColumnarValue::values_to_arrays(&args.args)?;
|
||||
let len = arrays[0].len();
|
||||
let n_array = &arrays[0];
|
||||
|
||||
let mut builder = LargeStringBuilder::with_capacity(len, len * 10);
|
||||
|
||||
for i in 0..len {
|
||||
if n_array.is_null(i) {
|
||||
builder.append_null();
|
||||
continue;
|
||||
}
|
||||
|
||||
let n = get_int_value(n_array, i)?;
|
||||
|
||||
if n < 0 {
|
||||
// MySQL returns empty string for negative values
|
||||
builder.append_value("");
|
||||
} else if n > MAX_SPACE_COUNT {
|
||||
return Err(DataFusionError::Execution(format!(
|
||||
"SPACE: requested {} spaces exceeds maximum allowed ({})",
|
||||
n, MAX_SPACE_COUNT
|
||||
)));
|
||||
} else {
|
||||
let spaces = " ".repeat(n as usize);
|
||||
builder.append_value(&spaces);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract integer value from various integer types.
|
||||
fn get_int_value(
|
||||
array: &datafusion_common::arrow::array::ArrayRef,
|
||||
index: usize,
|
||||
) -> datafusion_common::Result<i64> {
|
||||
use datafusion_common::arrow::datatypes as arrow_types;
|
||||
|
||||
match array.data_type() {
|
||||
DataType::Int64 => Ok(array.as_primitive::<arrow_types::Int64Type>().value(index)),
|
||||
DataType::Int32 => Ok(array.as_primitive::<arrow_types::Int32Type>().value(index) as i64),
|
||||
DataType::Int16 => Ok(array.as_primitive::<arrow_types::Int16Type>().value(index) as i64),
|
||||
DataType::Int8 => Ok(array.as_primitive::<arrow_types::Int8Type>().value(index) as i64),
|
||||
DataType::UInt64 => {
|
||||
let v = array.as_primitive::<arrow_types::UInt64Type>().value(index);
|
||||
if v > i64::MAX as u64 {
|
||||
Err(DataFusionError::Execution(format!(
|
||||
"SPACE: value {} exceeds maximum",
|
||||
v
|
||||
)))
|
||||
} else {
|
||||
Ok(v as i64)
|
||||
}
|
||||
}
|
||||
DataType::UInt32 => Ok(array.as_primitive::<arrow_types::UInt32Type>().value(index) as i64),
|
||||
DataType::UInt16 => Ok(array.as_primitive::<arrow_types::UInt16Type>().value(index) as i64),
|
||||
DataType::UInt8 => Ok(array.as_primitive::<arrow_types::UInt8Type>().value(index) as i64),
|
||||
_ => Err(DataFusionError::Execution(format!(
|
||||
"SPACE: unsupported type {:?}",
|
||||
array.data_type()
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use datafusion_common::arrow::array::Int64Array;
|
||||
use datafusion_common::arrow::datatypes::Field;
|
||||
use datafusion_expr::ScalarFunctionArgs;
|
||||
|
||||
use super::*;
|
||||
|
||||
fn create_args(arrays: Vec<datafusion_common::arrow::array::ArrayRef>) -> ScalarFunctionArgs {
|
||||
let arg_fields: Vec<_> = arrays
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, arr)| {
|
||||
Arc::new(Field::new(
|
||||
format!("arg_{}", i),
|
||||
arr.data_type().clone(),
|
||||
true,
|
||||
))
|
||||
})
|
||||
.collect();
|
||||
|
||||
ScalarFunctionArgs {
|
||||
args: arrays.iter().cloned().map(ColumnarValue::Array).collect(),
|
||||
arg_fields,
|
||||
return_field: Arc::new(Field::new("result", DataType::LargeUtf8, true)),
|
||||
number_rows: arrays[0].len(),
|
||||
config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_space_basic() {
|
||||
let function = SpaceFunction::default();
|
||||
|
||||
let n = Arc::new(Int64Array::from(vec![0, 1, 5]));
|
||||
|
||||
let args = create_args(vec![n]);
|
||||
let result = function.invoke_with_args(args).unwrap();
|
||||
|
||||
if let ColumnarValue::Array(array) = result {
|
||||
let str_array = array.as_string::<i64>();
|
||||
assert_eq!(str_array.value(0), "");
|
||||
assert_eq!(str_array.value(1), " ");
|
||||
assert_eq!(str_array.value(2), " ");
|
||||
} else {
|
||||
panic!("Expected array result");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_space_negative() {
|
||||
let function = SpaceFunction::default();
|
||||
|
||||
let n = Arc::new(Int64Array::from(vec![-1, -100]));
|
||||
|
||||
let args = create_args(vec![n]);
|
||||
let result = function.invoke_with_args(args).unwrap();
|
||||
|
||||
if let ColumnarValue::Array(array) = result {
|
||||
let str_array = array.as_string::<i64>();
|
||||
assert_eq!(str_array.value(0), "");
|
||||
assert_eq!(str_array.value(1), "");
|
||||
} else {
|
||||
panic!("Expected array result");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_space_with_nulls() {
|
||||
let function = SpaceFunction::default();
|
||||
|
||||
let n = Arc::new(Int64Array::from(vec![Some(3), None]));
|
||||
|
||||
let args = create_args(vec![n]);
|
||||
let result = function.invoke_with_args(args).unwrap();
|
||||
|
||||
if let ColumnarValue::Array(array) = result {
|
||||
let str_array = array.as_string::<i64>();
|
||||
assert_eq!(str_array.value(0), " ");
|
||||
assert!(str_array.is_null(1));
|
||||
} else {
|
||||
panic!("Expected array result");
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user