From 3866512cf6a1a63d41f48ab80b73b9c2ec8ba6cf Mon Sep 17 00:00:00 2001 From: dennis zhuang Date: Thu, 25 Dec 2025 11:28:57 +0800 Subject: [PATCH] feat: add more MySQL-compatible string functions (#7454) * feat: add more mysql string functions Signed-off-by: Dennis Zhuang * refactor: use datafusion aliasing mechanism, close #7415 Signed-off-by: Dennis Zhuang * chore: comment Signed-off-by: Dennis Zhuang * fix: comment and style Signed-off-by: Dennis Zhuang --------- Signed-off-by: Dennis Zhuang --- src/common/function/src/scalars/string.rs | 18 + src/common/function/src/scalars/string/elt.rs | 252 +++++++++ .../function/src/scalars/string/field.rs | 224 ++++++++ .../function/src/scalars/string/format.rs | 512 ++++++++++++++++++ .../function/src/scalars/string/insert.rs | 345 ++++++++++++ .../function/src/scalars/string/locate.rs | 373 +++++++++++++ .../function/src/scalars/string/space.rs | 252 +++++++++ src/query/src/datafusion/planner.rs | 39 +- .../src/datafusion/planner/function_alias.rs | 86 --- src/query/src/query_engine/state.rs | 36 ++ .../function/string/mysql_compat.result | 347 ++++++++++++ .../common/function/string/mysql_compat.sql | 97 ++++ 12 files changed, 2461 insertions(+), 120 deletions(-) create mode 100644 src/common/function/src/scalars/string/elt.rs create mode 100644 src/common/function/src/scalars/string/field.rs create mode 100644 src/common/function/src/scalars/string/format.rs create mode 100644 src/common/function/src/scalars/string/insert.rs create mode 100644 src/common/function/src/scalars/string/locate.rs create mode 100644 src/common/function/src/scalars/string/space.rs delete mode 100644 src/query/src/datafusion/planner/function_alias.rs create mode 100644 tests/cases/standalone/common/function/string/mysql_compat.result create mode 100644 tests/cases/standalone/common/function/string/mysql_compat.sql diff --git a/src/common/function/src/scalars/string.rs b/src/common/function/src/scalars/string.rs index 95c6201ee2..e71dca660b 100644 --- a/src/common/function/src/scalars/string.rs +++ b/src/common/function/src/scalars/string.rs @@ -14,13 +14,31 @@ //! String scalar functions +mod elt; +mod field; +mod format; +mod insert; +mod locate; mod regexp_extract; +mod space; +pub(crate) use elt::EltFunction; +pub(crate) use field::FieldFunction; +pub(crate) use format::FormatFunction; +pub(crate) use insert::InsertFunction; +pub(crate) use locate::LocateFunction; pub(crate) use regexp_extract::RegexpExtractFunction; +pub(crate) use space::SpaceFunction; use crate::function_registry::FunctionRegistry; /// Register all string functions pub fn register_string_functions(registry: &FunctionRegistry) { + EltFunction::register(registry); + FieldFunction::register(registry); + FormatFunction::register(registry); + InsertFunction::register(registry); + LocateFunction::register(registry); RegexpExtractFunction::register(registry); + SpaceFunction::register(registry); } diff --git a/src/common/function/src/scalars/string/elt.rs b/src/common/function/src/scalars/string/elt.rs new file mode 100644 index 0000000000..8061febd37 --- /dev/null +++ b/src/common/function/src/scalars/string/elt.rs @@ -0,0 +1,252 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! MySQL-compatible ELT function implementation. +//! +//! ELT(N, str1, str2, str3, ...) - Returns the Nth string from the list. +//! Returns NULL if N < 1 or N > number of strings. + +use std::fmt; +use std::sync::Arc; + +use datafusion_common::DataFusionError; +use datafusion_common::arrow::array::{Array, ArrayRef, AsArray, LargeStringBuilder}; +use datafusion_common::arrow::compute::cast; +use datafusion_common::arrow::datatypes::DataType; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, Volatility}; + +use crate::function::Function; +use crate::function_registry::FunctionRegistry; + +const NAME: &str = "elt"; + +/// MySQL-compatible ELT function. +/// +/// Syntax: ELT(N, str1, str2, str3, ...) +/// Returns the Nth string argument. N is 1-based. +/// Returns NULL if N is NULL, N < 1, or N > number of string arguments. +#[derive(Debug)] +pub struct EltFunction { + signature: Signature, +} + +impl EltFunction { + pub fn register(registry: &FunctionRegistry) { + registry.register_scalar(EltFunction::default()); + } +} + +impl Default for EltFunction { + fn default() -> Self { + Self { + // ELT takes a variable number of arguments: (Int64, String, String, ...) + signature: Signature::variadic_any(Volatility::Immutable), + } + } +} + +impl fmt::Display for EltFunction { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", NAME.to_ascii_uppercase()) + } +} + +impl Function for EltFunction { + fn name(&self) -> &str { + NAME + } + + fn return_type(&self, _: &[DataType]) -> datafusion_common::Result { + Ok(DataType::LargeUtf8) + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn invoke_with_args( + &self, + args: ScalarFunctionArgs, + ) -> datafusion_common::Result { + if args.args.len() < 2 { + return Err(DataFusionError::Execution( + "ELT requires at least 2 arguments: ELT(N, str1, ...)".to_string(), + )); + } + + let arrays = ColumnarValue::values_to_arrays(&args.args)?; + let len = arrays[0].len(); + let num_strings = arrays.len() - 1; + + // First argument is the index (N) - try to cast to Int64 + let index_array = if arrays[0].data_type() == &DataType::Null { + // All NULLs - return all NULLs + let mut builder = LargeStringBuilder::with_capacity(len, 0); + for _ in 0..len { + builder.append_null(); + } + return Ok(ColumnarValue::Array(Arc::new(builder.finish()))); + } else { + cast(arrays[0].as_ref(), &DataType::Int64).map_err(|e| { + DataFusionError::Execution(format!("ELT: index argument cast failed: {}", e)) + })? + }; + + // Cast string arguments to LargeUtf8 + let string_arrays: Vec = arrays[1..] + .iter() + .enumerate() + .map(|(i, arr)| { + cast(arr.as_ref(), &DataType::LargeUtf8).map_err(|e| { + DataFusionError::Execution(format!( + "ELT: string argument {} cast failed: {}", + i + 1, + e + )) + }) + }) + .collect::>>()?; + + let mut builder = LargeStringBuilder::with_capacity(len, len * 32); + + for i in 0..len { + if index_array.is_null(i) { + builder.append_null(); + continue; + } + + let n = index_array + .as_primitive::() + .value(i); + + // N is 1-based, check bounds + if n < 1 || n as usize > num_strings { + builder.append_null(); + continue; + } + + let str_idx = (n - 1) as usize; + let str_array = string_arrays[str_idx].as_string::(); + + if str_array.is_null(i) { + builder.append_null(); + } else { + builder.append_value(str_array.value(i)); + } + } + + Ok(ColumnarValue::Array(Arc::new(builder.finish()))) + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use datafusion_common::arrow::array::{Int64Array, StringArray}; + use datafusion_common::arrow::datatypes::Field; + use datafusion_expr::ScalarFunctionArgs; + + use super::*; + + fn create_args(arrays: Vec) -> ScalarFunctionArgs { + let arg_fields: Vec<_> = arrays + .iter() + .enumerate() + .map(|(i, arr)| { + Arc::new(Field::new( + format!("arg_{}", i), + arr.data_type().clone(), + true, + )) + }) + .collect(); + + ScalarFunctionArgs { + args: arrays.iter().cloned().map(ColumnarValue::Array).collect(), + arg_fields, + return_field: Arc::new(Field::new("result", DataType::LargeUtf8, true)), + number_rows: arrays[0].len(), + config_options: Arc::new(datafusion_common::config::ConfigOptions::default()), + } + } + + #[test] + fn test_elt_basic() { + let function = EltFunction::default(); + + let n = Arc::new(Int64Array::from(vec![1, 2, 3])); + let s1 = Arc::new(StringArray::from(vec!["a", "a", "a"])); + let s2 = Arc::new(StringArray::from(vec!["b", "b", "b"])); + let s3 = Arc::new(StringArray::from(vec!["c", "c", "c"])); + + let args = create_args(vec![n, s1, s2, s3]); + let result = function.invoke_with_args(args).unwrap(); + + if let ColumnarValue::Array(array) = result { + let str_array = array.as_string::(); + assert_eq!(str_array.value(0), "a"); + assert_eq!(str_array.value(1), "b"); + assert_eq!(str_array.value(2), "c"); + } else { + panic!("Expected array result"); + } + } + + #[test] + fn test_elt_out_of_bounds() { + let function = EltFunction::default(); + + let n = Arc::new(Int64Array::from(vec![0, 4, -1])); + let s1 = Arc::new(StringArray::from(vec!["a", "a", "a"])); + let s2 = Arc::new(StringArray::from(vec!["b", "b", "b"])); + let s3 = Arc::new(StringArray::from(vec!["c", "c", "c"])); + + let args = create_args(vec![n, s1, s2, s3]); + let result = function.invoke_with_args(args).unwrap(); + + if let ColumnarValue::Array(array) = result { + let str_array = array.as_string::(); + assert!(str_array.is_null(0)); // 0 is out of bounds + assert!(str_array.is_null(1)); // 4 is out of bounds + assert!(str_array.is_null(2)); // -1 is out of bounds + } else { + panic!("Expected array result"); + } + } + + #[test] + fn test_elt_with_nulls() { + let function = EltFunction::default(); + + // Row 0: n=1, select s1="a" -> "a" + // Row 1: n=NULL -> NULL + // Row 2: n=1, select s1=NULL -> NULL + let n = Arc::new(Int64Array::from(vec![Some(1), None, Some(1)])); + let s1 = Arc::new(StringArray::from(vec![Some("a"), Some("a"), None])); + let s2 = Arc::new(StringArray::from(vec![Some("b"), Some("b"), Some("b")])); + + let args = create_args(vec![n, s1, s2]); + let result = function.invoke_with_args(args).unwrap(); + + if let ColumnarValue::Array(array) = result { + let str_array = array.as_string::(); + assert_eq!(str_array.value(0), "a"); + assert!(str_array.is_null(1)); // N is NULL + assert!(str_array.is_null(2)); // Selected string is NULL + } else { + panic!("Expected array result"); + } + } +} diff --git a/src/common/function/src/scalars/string/field.rs b/src/common/function/src/scalars/string/field.rs new file mode 100644 index 0000000000..39321bf6a8 --- /dev/null +++ b/src/common/function/src/scalars/string/field.rs @@ -0,0 +1,224 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! MySQL-compatible FIELD function implementation. +//! +//! FIELD(str, str1, str2, str3, ...) - Returns the 1-based index of str in the list. +//! Returns 0 if str is not found or is NULL. + +use std::fmt; +use std::sync::Arc; + +use datafusion_common::DataFusionError; +use datafusion_common::arrow::array::{Array, ArrayRef, AsArray, Int64Builder}; +use datafusion_common::arrow::compute::cast; +use datafusion_common::arrow::datatypes::DataType; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, Volatility}; + +use crate::function::Function; +use crate::function_registry::FunctionRegistry; + +const NAME: &str = "field"; + +/// MySQL-compatible FIELD function. +/// +/// Syntax: FIELD(str, str1, str2, str3, ...) +/// Returns the 1-based index of str in the argument list (str1, str2, str3, ...). +/// Returns 0 if str is not found or is NULL. +#[derive(Debug)] +pub struct FieldFunction { + signature: Signature, +} + +impl FieldFunction { + pub fn register(registry: &FunctionRegistry) { + registry.register_scalar(FieldFunction::default()); + } +} + +impl Default for FieldFunction { + fn default() -> Self { + Self { + // FIELD takes a variable number of arguments: (String, String, String, ...) + signature: Signature::variadic_any(Volatility::Immutable), + } + } +} + +impl fmt::Display for FieldFunction { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", NAME.to_ascii_uppercase()) + } +} + +impl Function for FieldFunction { + fn name(&self) -> &str { + NAME + } + + fn return_type(&self, _: &[DataType]) -> datafusion_common::Result { + Ok(DataType::Int64) + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn invoke_with_args( + &self, + args: ScalarFunctionArgs, + ) -> datafusion_common::Result { + if args.args.len() < 2 { + return Err(DataFusionError::Execution( + "FIELD requires at least 2 arguments: FIELD(str, str1, ...)".to_string(), + )); + } + + let arrays = ColumnarValue::values_to_arrays(&args.args)?; + let len = arrays[0].len(); + + // Cast all arguments to LargeUtf8 + let string_arrays: Vec = arrays + .iter() + .enumerate() + .map(|(i, arr)| { + cast(arr.as_ref(), &DataType::LargeUtf8).map_err(|e| { + DataFusionError::Execution(format!("FIELD: argument {} cast failed: {}", i, e)) + }) + }) + .collect::>>()?; + + let search_str = string_arrays[0].as_string::(); + let mut builder = Int64Builder::with_capacity(len); + + for i in 0..len { + // If search string is NULL, return 0 + if search_str.is_null(i) { + builder.append_value(0); + continue; + } + + let needle = search_str.value(i); + let mut found_idx = 0i64; + + // Search through the list (starting from index 1 in string_arrays) + for (j, str_arr) in string_arrays[1..].iter().enumerate() { + let str_array = str_arr.as_string::(); + if !str_array.is_null(i) && str_array.value(i) == needle { + found_idx = (j + 1) as i64; // 1-based index + break; + } + } + + builder.append_value(found_idx); + } + + Ok(ColumnarValue::Array(Arc::new(builder.finish()))) + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use datafusion_common::arrow::array::StringArray; + use datafusion_common::arrow::datatypes::Field; + use datafusion_expr::ScalarFunctionArgs; + + use super::*; + + fn create_args(arrays: Vec) -> ScalarFunctionArgs { + let arg_fields: Vec<_> = arrays + .iter() + .enumerate() + .map(|(i, arr)| { + Arc::new(Field::new( + format!("arg_{}", i), + arr.data_type().clone(), + true, + )) + }) + .collect(); + + ScalarFunctionArgs { + args: arrays.iter().cloned().map(ColumnarValue::Array).collect(), + arg_fields, + return_field: Arc::new(Field::new("result", DataType::Int64, true)), + number_rows: arrays[0].len(), + config_options: Arc::new(datafusion_common::config::ConfigOptions::default()), + } + } + + #[test] + fn test_field_basic() { + let function = FieldFunction::default(); + + let search = Arc::new(StringArray::from(vec!["b", "d", "a"])); + let s1 = Arc::new(StringArray::from(vec!["a", "a", "a"])); + let s2 = Arc::new(StringArray::from(vec!["b", "b", "b"])); + let s3 = Arc::new(StringArray::from(vec!["c", "c", "c"])); + + let args = create_args(vec![search, s1, s2, s3]); + let result = function.invoke_with_args(args).unwrap(); + + if let ColumnarValue::Array(array) = result { + let int_array = array.as_primitive::(); + assert_eq!(int_array.value(0), 2); // "b" is at index 2 + assert_eq!(int_array.value(1), 0); // "d" not found + assert_eq!(int_array.value(2), 1); // "a" is at index 1 + } else { + panic!("Expected array result"); + } + } + + #[test] + fn test_field_with_null_search() { + let function = FieldFunction::default(); + + let search = Arc::new(StringArray::from(vec![Some("a"), None])); + let s1 = Arc::new(StringArray::from(vec!["a", "a"])); + let s2 = Arc::new(StringArray::from(vec!["b", "b"])); + + let args = create_args(vec![search, s1, s2]); + let result = function.invoke_with_args(args).unwrap(); + + if let ColumnarValue::Array(array) = result { + let int_array = array.as_primitive::(); + assert_eq!(int_array.value(0), 1); // "a" found at index 1 + assert_eq!(int_array.value(1), 0); // NULL search returns 0 + } else { + panic!("Expected array result"); + } + } + + #[test] + fn test_field_case_sensitive() { + let function = FieldFunction::default(); + + let search = Arc::new(StringArray::from(vec!["A", "a"])); + let s1 = Arc::new(StringArray::from(vec!["a", "a"])); + let s2 = Arc::new(StringArray::from(vec!["A", "A"])); + + let args = create_args(vec![search, s1, s2]); + let result = function.invoke_with_args(args).unwrap(); + + if let ColumnarValue::Array(array) = result { + let int_array = array.as_primitive::(); + assert_eq!(int_array.value(0), 2); // "A" matches at index 2 + assert_eq!(int_array.value(1), 1); // "a" matches at index 1 + } else { + panic!("Expected array result"); + } + } +} diff --git a/src/common/function/src/scalars/string/format.rs b/src/common/function/src/scalars/string/format.rs new file mode 100644 index 0000000000..e6a6044cb1 --- /dev/null +++ b/src/common/function/src/scalars/string/format.rs @@ -0,0 +1,512 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! MySQL-compatible FORMAT function implementation. +//! +//! FORMAT(X, D) - Formats the number X with D decimal places using thousand separators. + +use std::fmt; +use std::sync::Arc; + +use datafusion_common::DataFusionError; +use datafusion_common::arrow::array::{Array, AsArray, LargeStringBuilder}; +use datafusion_common::arrow::datatypes as arrow_types; +use datafusion_common::arrow::datatypes::DataType; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, TypeSignature, Volatility}; + +use crate::function::Function; +use crate::function_registry::FunctionRegistry; + +const NAME: &str = "format"; + +/// MySQL-compatible FORMAT function. +/// +/// Syntax: FORMAT(X, D) +/// Formats the number X to a format like '#,###,###.##', rounded to D decimal places. +/// D can be 0 to 30. +/// +/// Note: This implementation uses the en_US locale (comma as thousand separator, +/// period as decimal separator). +#[derive(Debug)] +pub struct FormatFunction { + signature: Signature, +} + +impl FormatFunction { + pub fn register(registry: &FunctionRegistry) { + registry.register_scalar(FormatFunction::default()); + } +} + +impl Default for FormatFunction { + fn default() -> Self { + let mut signatures = Vec::new(); + + // Support various numeric types for X + let numeric_types = [ + DataType::Float64, + DataType::Float32, + DataType::Int64, + DataType::Int32, + DataType::Int16, + DataType::Int8, + DataType::UInt64, + DataType::UInt32, + DataType::UInt16, + DataType::UInt8, + ]; + + // D can be various integer types + let int_types = [ + DataType::Int64, + DataType::Int32, + DataType::Int16, + DataType::Int8, + DataType::UInt64, + DataType::UInt32, + DataType::UInt16, + DataType::UInt8, + ]; + + for x_type in &numeric_types { + for d_type in &int_types { + signatures.push(TypeSignature::Exact(vec![x_type.clone(), d_type.clone()])); + } + } + + Self { + signature: Signature::one_of(signatures, Volatility::Immutable), + } + } +} + +impl fmt::Display for FormatFunction { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", NAME.to_ascii_uppercase()) + } +} + +impl Function for FormatFunction { + fn name(&self) -> &str { + NAME + } + + fn return_type(&self, _: &[DataType]) -> datafusion_common::Result { + Ok(DataType::LargeUtf8) + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn invoke_with_args( + &self, + args: ScalarFunctionArgs, + ) -> datafusion_common::Result { + if args.args.len() != 2 { + return Err(DataFusionError::Execution( + "FORMAT requires exactly 2 arguments: FORMAT(X, D)".to_string(), + )); + } + + let arrays = ColumnarValue::values_to_arrays(&args.args)?; + let len = arrays[0].len(); + + let x_array = &arrays[0]; + let d_array = &arrays[1]; + + let mut builder = LargeStringBuilder::with_capacity(len, len * 20); + + for i in 0..len { + if x_array.is_null(i) || d_array.is_null(i) { + builder.append_null(); + continue; + } + + let decimal_places = get_decimal_places(d_array, i)?.clamp(0, 30) as usize; + + let formatted = match x_array.data_type() { + DataType::Float64 | DataType::Float32 => { + format_number_float(get_float_value(x_array, i)?, decimal_places) + } + DataType::Int64 + | DataType::Int32 + | DataType::Int16 + | DataType::Int8 + | DataType::UInt64 + | DataType::UInt32 + | DataType::UInt16 + | DataType::UInt8 => format_number_integer(x_array, i, decimal_places)?, + _ => { + return Err(DataFusionError::Execution(format!( + "FORMAT: unsupported type {:?}", + x_array.data_type() + ))); + } + }; + builder.append_value(&formatted); + } + + Ok(ColumnarValue::Array(Arc::new(builder.finish()))) + } +} + +/// Get float value from various numeric types. +fn get_float_value( + array: &datafusion_common::arrow::array::ArrayRef, + index: usize, +) -> datafusion_common::Result { + match array.data_type() { + DataType::Float64 => Ok(array + .as_primitive::() + .value(index)), + DataType::Float32 => Ok(array + .as_primitive::() + .value(index) as f64), + _ => Err(DataFusionError::Execution(format!( + "FORMAT: unsupported type {:?}", + array.data_type() + ))), + } +} + +/// Get decimal places from various integer types. +/// +/// MySQL clamps decimal places to `0..=30`. This function returns an `i64` so the caller can clamp. +fn get_decimal_places( + array: &datafusion_common::arrow::array::ArrayRef, + index: usize, +) -> datafusion_common::Result { + match array.data_type() { + DataType::Int64 => Ok(array.as_primitive::().value(index)), + DataType::Int32 => Ok(array.as_primitive::().value(index) as i64), + DataType::Int16 => Ok(array.as_primitive::().value(index) as i64), + DataType::Int8 => Ok(array.as_primitive::().value(index) as i64), + DataType::UInt64 => { + let v = array.as_primitive::().value(index); + Ok(if v > i64::MAX as u64 { + i64::MAX + } else { + v as i64 + }) + } + DataType::UInt32 => Ok(array.as_primitive::().value(index) as i64), + DataType::UInt16 => Ok(array.as_primitive::().value(index) as i64), + DataType::UInt8 => Ok(array.as_primitive::().value(index) as i64), + _ => Err(DataFusionError::Execution(format!( + "FORMAT: unsupported type {:?}", + array.data_type() + ))), + } +} + +fn format_number_integer( + array: &datafusion_common::arrow::array::ArrayRef, + index: usize, + decimal_places: usize, +) -> datafusion_common::Result { + let (is_negative, abs_digits) = match array.data_type() { + DataType::Int64 => { + let v = array.as_primitive::().value(index) as i128; + (v.is_negative(), v.unsigned_abs().to_string()) + } + DataType::Int32 => { + let v = array.as_primitive::().value(index) as i128; + (v.is_negative(), v.unsigned_abs().to_string()) + } + DataType::Int16 => { + let v = array.as_primitive::().value(index) as i128; + (v.is_negative(), v.unsigned_abs().to_string()) + } + DataType::Int8 => { + let v = array.as_primitive::().value(index) as i128; + (v.is_negative(), v.unsigned_abs().to_string()) + } + DataType::UInt64 => { + let v = array.as_primitive::().value(index) as u128; + (false, v.to_string()) + } + DataType::UInt32 => { + let v = array.as_primitive::().value(index) as u128; + (false, v.to_string()) + } + DataType::UInt16 => { + let v = array.as_primitive::().value(index) as u128; + (false, v.to_string()) + } + DataType::UInt8 => { + let v = array.as_primitive::().value(index) as u128; + (false, v.to_string()) + } + _ => { + return Err(DataFusionError::Execution(format!( + "FORMAT: unsupported type {:?}", + array.data_type() + ))); + } + }; + + let mut result = String::new(); + if is_negative { + result.push('-'); + } + result.push_str(&add_thousand_separators(&abs_digits)); + + if decimal_places > 0 { + result.push('.'); + result.push_str(&"0".repeat(decimal_places)); + } + + Ok(result) +} + +/// Format a float with thousand separators and `decimal_places` digits after decimal point. +fn format_number_float(x: f64, decimal_places: usize) -> String { + // Handle special cases + if x.is_nan() { + return "NaN".to_string(); + } + if x.is_infinite() { + return if x.is_sign_positive() { + "Infinity".to_string() + } else { + "-Infinity".to_string() + }; + } + + // Round to decimal_places + let multiplier = 10f64.powi(decimal_places as i32); + let rounded = (x * multiplier).round() / multiplier; + + // Split into integer and fractional parts + let is_negative = rounded < 0.0; + let abs_value = rounded.abs(); + + // Format with the specified decimal places + let formatted = if decimal_places == 0 { + format!("{:.0}", abs_value) + } else { + format!("{:.prec$}", abs_value, prec = decimal_places) + }; + + // Split at decimal point + let parts: Vec<&str> = formatted.split('.').collect(); + let int_part = parts[0]; + let dec_part = parts.get(1).copied(); + + // Add thousand separators to integer part + let int_with_sep = add_thousand_separators(int_part); + + // Build result + let mut result = String::new(); + if is_negative { + result.push('-'); + } + result.push_str(&int_with_sep); + if let Some(dec) = dec_part { + result.push('.'); + result.push_str(dec); + } + + result +} + +/// Add thousand separators (commas) to an integer string. +fn add_thousand_separators(s: &str) -> String { + let chars: Vec = s.chars().collect(); + let len = chars.len(); + + if len <= 3 { + return s.to_string(); + } + + let mut result = String::with_capacity(len + len / 3); + let first_group_len = len % 3; + let first_group_len = if first_group_len == 0 { + 3 + } else { + first_group_len + }; + + for (i, ch) in chars.iter().enumerate() { + if i > 0 && i >= first_group_len && (i - first_group_len) % 3 == 0 { + result.push(','); + } + result.push(*ch); + } + + result +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use datafusion_common::arrow::array::{Float64Array, Int64Array}; + use datafusion_common::arrow::datatypes::Field; + use datafusion_expr::ScalarFunctionArgs; + + use super::*; + + fn create_args(arrays: Vec) -> ScalarFunctionArgs { + let arg_fields: Vec<_> = arrays + .iter() + .enumerate() + .map(|(i, arr)| { + Arc::new(Field::new( + format!("arg_{}", i), + arr.data_type().clone(), + true, + )) + }) + .collect(); + + ScalarFunctionArgs { + args: arrays.iter().cloned().map(ColumnarValue::Array).collect(), + arg_fields, + return_field: Arc::new(Field::new("result", DataType::LargeUtf8, true)), + number_rows: arrays[0].len(), + config_options: Arc::new(datafusion_common::config::ConfigOptions::default()), + } + } + + #[test] + fn test_format_basic() { + let function = FormatFunction::default(); + + let x = Arc::new(Float64Array::from(vec![1234567.891, 1234.5, 1234567.0])); + let d = Arc::new(Int64Array::from(vec![2, 0, 3])); + + let args = create_args(vec![x, d]); + let result = function.invoke_with_args(args).unwrap(); + + if let ColumnarValue::Array(array) = result { + let str_array = array.as_string::(); + assert_eq!(str_array.value(0), "1,234,567.89"); + assert_eq!(str_array.value(1), "1,235"); // rounded + assert_eq!(str_array.value(2), "1,234,567.000"); + } else { + panic!("Expected array result"); + } + } + + #[test] + fn test_format_negative() { + let function = FormatFunction::default(); + + let x = Arc::new(Float64Array::from(vec![-1234567.891])); + let d = Arc::new(Int64Array::from(vec![2])); + + let args = create_args(vec![x, d]); + let result = function.invoke_with_args(args).unwrap(); + + if let ColumnarValue::Array(array) = result { + let str_array = array.as_string::(); + assert_eq!(str_array.value(0), "-1,234,567.89"); + } else { + panic!("Expected array result"); + } + } + + #[test] + fn test_format_small_numbers() { + let function = FormatFunction::default(); + + let x = Arc::new(Float64Array::from(vec![0.5, 12.345, 123.0])); + let d = Arc::new(Int64Array::from(vec![2, 2, 0])); + + let args = create_args(vec![x, d]); + let result = function.invoke_with_args(args).unwrap(); + + if let ColumnarValue::Array(array) = result { + let str_array = array.as_string::(); + assert_eq!(str_array.value(0), "0.50"); + assert_eq!(str_array.value(1), "12.35"); // rounded + assert_eq!(str_array.value(2), "123"); + } else { + panic!("Expected array result"); + } + } + + #[test] + fn test_format_with_nulls() { + let function = FormatFunction::default(); + + let x = Arc::new(Float64Array::from(vec![Some(1234.5), None])); + let d = Arc::new(Int64Array::from(vec![2, 2])); + + let args = create_args(vec![x, d]); + let result = function.invoke_with_args(args).unwrap(); + + if let ColumnarValue::Array(array) = result { + let str_array = array.as_string::(); + assert_eq!(str_array.value(0), "1,234.50"); + assert!(str_array.is_null(1)); + } else { + panic!("Expected array result"); + } + } + + #[test] + fn test_add_thousand_separators() { + assert_eq!(add_thousand_separators("1"), "1"); + assert_eq!(add_thousand_separators("12"), "12"); + assert_eq!(add_thousand_separators("123"), "123"); + assert_eq!(add_thousand_separators("1234"), "1,234"); + assert_eq!(add_thousand_separators("12345"), "12,345"); + assert_eq!(add_thousand_separators("123456"), "123,456"); + assert_eq!(add_thousand_separators("1234567"), "1,234,567"); + assert_eq!(add_thousand_separators("12345678"), "12,345,678"); + assert_eq!(add_thousand_separators("123456789"), "123,456,789"); + } + + #[test] + fn test_format_large_int_no_float_precision_loss() { + let function = FormatFunction::default(); + + // 2^53 + 1 cannot be represented exactly as f64. + let x = Arc::new(Int64Array::from(vec![9_007_199_254_740_993i64])); + let d = Arc::new(Int64Array::from(vec![0])); + + let args = create_args(vec![x, d]); + let result = function.invoke_with_args(args).unwrap(); + + if let ColumnarValue::Array(array) = result { + let str_array = array.as_string::(); + assert_eq!(str_array.value(0), "9,007,199,254,740,993"); + } else { + panic!("Expected array result"); + } + } + + #[test] + fn test_format_decimal_places_u64_overflow_clamps() { + use datafusion_common::arrow::array::UInt64Array; + + let function = FormatFunction::default(); + + let x = Arc::new(Int64Array::from(vec![1])); + let d = Arc::new(UInt64Array::from(vec![u64::MAX])); + + let args = create_args(vec![x, d]); + let result = function.invoke_with_args(args).unwrap(); + + if let ColumnarValue::Array(array) = result { + let str_array = array.as_string::(); + assert_eq!(str_array.value(0), format!("1.{}", "0".repeat(30))); + } else { + panic!("Expected array result"); + } + } +} diff --git a/src/common/function/src/scalars/string/insert.rs b/src/common/function/src/scalars/string/insert.rs new file mode 100644 index 0000000000..4816ac0df4 --- /dev/null +++ b/src/common/function/src/scalars/string/insert.rs @@ -0,0 +1,345 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! MySQL-compatible INSERT function implementation. +//! +//! INSERT(str, pos, len, newstr) - Inserts newstr into str at position pos, +//! replacing len characters. + +use std::fmt; +use std::sync::Arc; + +use datafusion_common::DataFusionError; +use datafusion_common::arrow::array::{Array, ArrayRef, AsArray, LargeStringBuilder}; +use datafusion_common::arrow::compute::cast; +use datafusion_common::arrow::datatypes::DataType; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, TypeSignature, Volatility}; + +use crate::function::Function; +use crate::function_registry::FunctionRegistry; + +const NAME: &str = "insert"; + +/// MySQL-compatible INSERT function. +/// +/// Syntax: INSERT(str, pos, len, newstr) +/// Returns str with the substring beginning at position pos and len characters long +/// replaced by newstr. +/// +/// - pos is 1-based +/// - If pos is out of range, returns the original string +/// - If len is out of range, replaces from pos to end of string +#[derive(Debug)] +pub struct InsertFunction { + signature: Signature, +} + +impl InsertFunction { + pub fn register(registry: &FunctionRegistry) { + registry.register_scalar(InsertFunction::default()); + } +} + +impl Default for InsertFunction { + fn default() -> Self { + let mut signatures = Vec::new(); + let string_types = [DataType::Utf8, DataType::LargeUtf8, DataType::Utf8View]; + let int_types = [ + DataType::Int64, + DataType::Int32, + DataType::Int16, + DataType::Int8, + DataType::UInt64, + DataType::UInt32, + DataType::UInt16, + DataType::UInt8, + ]; + + for str_type in &string_types { + for newstr_type in &string_types { + for pos_type in &int_types { + for len_type in &int_types { + signatures.push(TypeSignature::Exact(vec![ + str_type.clone(), + pos_type.clone(), + len_type.clone(), + newstr_type.clone(), + ])); + } + } + } + } + + Self { + signature: Signature::one_of(signatures, Volatility::Immutable), + } + } +} + +impl fmt::Display for InsertFunction { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", NAME.to_ascii_uppercase()) + } +} + +impl Function for InsertFunction { + fn name(&self) -> &str { + NAME + } + + fn return_type(&self, _: &[DataType]) -> datafusion_common::Result { + Ok(DataType::LargeUtf8) + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn invoke_with_args( + &self, + args: ScalarFunctionArgs, + ) -> datafusion_common::Result { + if args.args.len() != 4 { + return Err(DataFusionError::Execution( + "INSERT requires exactly 4 arguments: INSERT(str, pos, len, newstr)".to_string(), + )); + } + + let arrays = ColumnarValue::values_to_arrays(&args.args)?; + let len = arrays[0].len(); + + // Cast string arguments to LargeUtf8 + let str_array = cast_to_large_utf8(&arrays[0], "str")?; + let newstr_array = cast_to_large_utf8(&arrays[3], "newstr")?; + let pos_array = cast_to_int64(&arrays[1], "pos")?; + let replace_len_array = cast_to_int64(&arrays[2], "len")?; + + let str_arr = str_array.as_string::(); + let pos_arr = pos_array.as_primitive::(); + let len_arr = + replace_len_array.as_primitive::(); + let newstr_arr = newstr_array.as_string::(); + + let mut builder = LargeStringBuilder::with_capacity(len, len * 32); + + for i in 0..len { + // Check for NULLs + if str_arr.is_null(i) + || pos_array.is_null(i) + || replace_len_array.is_null(i) + || newstr_arr.is_null(i) + { + builder.append_null(); + continue; + } + + let original = str_arr.value(i); + let pos = pos_arr.value(i); + let replace_len = len_arr.value(i); + let new_str = newstr_arr.value(i); + + let result = insert_string(original, pos, replace_len, new_str); + builder.append_value(&result); + } + + Ok(ColumnarValue::Array(Arc::new(builder.finish()))) + } +} + +/// Cast array to LargeUtf8 for uniform string access. +fn cast_to_large_utf8(array: &ArrayRef, name: &str) -> datafusion_common::Result { + cast(array.as_ref(), &DataType::LargeUtf8) + .map_err(|e| DataFusionError::Execution(format!("INSERT: {} cast failed: {}", name, e))) +} + +fn cast_to_int64(array: &ArrayRef, name: &str) -> datafusion_common::Result { + cast(array.as_ref(), &DataType::Int64) + .map_err(|e| DataFusionError::Execution(format!("INSERT: {} cast failed: {}", name, e))) +} + +/// Perform the INSERT string operation. +/// pos is 1-based. If pos < 1 or pos > len(str) + 1, returns original string. +fn insert_string(original: &str, pos: i64, replace_len: i64, new_str: &str) -> String { + let char_count = original.chars().count(); + + // MySQL behavior: if pos < 1 or pos > string length + 1, return original + if pos < 1 || pos as usize > char_count + 1 { + return original.to_string(); + } + + let start_idx = (pos - 1) as usize; // Convert to 0-based + + // Calculate end index for replacement + let replace_len = if replace_len < 0 { + 0 + } else { + replace_len as usize + }; + let end_idx = (start_idx + replace_len).min(char_count); + + let start_byte = char_to_byte_idx(original, start_idx); + let end_byte = char_to_byte_idx(original, end_idx); + + let mut result = String::with_capacity(original.len() + new_str.len()); + result.push_str(&original[..start_byte]); + result.push_str(new_str); + result.push_str(&original[end_byte..]); + result +} + +fn char_to_byte_idx(s: &str, char_idx: usize) -> usize { + s.char_indices() + .nth(char_idx) + .map(|(idx, _)| idx) + .unwrap_or(s.len()) +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use datafusion_common::arrow::array::{Int64Array, StringArray}; + use datafusion_common::arrow::datatypes::Field; + use datafusion_expr::ScalarFunctionArgs; + + use super::*; + + fn create_args(arrays: Vec) -> ScalarFunctionArgs { + let arg_fields: Vec<_> = arrays + .iter() + .enumerate() + .map(|(i, arr)| { + Arc::new(Field::new( + format!("arg_{}", i), + arr.data_type().clone(), + true, + )) + }) + .collect(); + + ScalarFunctionArgs { + args: arrays.iter().cloned().map(ColumnarValue::Array).collect(), + arg_fields, + return_field: Arc::new(Field::new("result", DataType::LargeUtf8, true)), + number_rows: arrays[0].len(), + config_options: Arc::new(datafusion_common::config::ConfigOptions::default()), + } + } + + #[test] + fn test_insert_basic() { + let function = InsertFunction::default(); + + // INSERT('Quadratic', 3, 4, 'What') => 'QuWhattic' + let str_arr = Arc::new(StringArray::from(vec!["Quadratic"])); + let pos = Arc::new(Int64Array::from(vec![3])); + let len = Arc::new(Int64Array::from(vec![4])); + let newstr = Arc::new(StringArray::from(vec!["What"])); + + let args = create_args(vec![str_arr, pos, len, newstr]); + let result = function.invoke_with_args(args).unwrap(); + + if let ColumnarValue::Array(array) = result { + let str_array = array.as_string::(); + assert_eq!(str_array.value(0), "QuWhattic"); + } else { + panic!("Expected array result"); + } + } + + #[test] + fn test_insert_out_of_range_pos() { + let function = InsertFunction::default(); + + // INSERT('Quadratic', 0, 4, 'What') => 'Quadratic' (pos < 1) + let str_arr = Arc::new(StringArray::from(vec!["Quadratic", "Quadratic"])); + let pos = Arc::new(Int64Array::from(vec![0, 100])); + let len = Arc::new(Int64Array::from(vec![4, 4])); + let newstr = Arc::new(StringArray::from(vec!["What", "What"])); + + let args = create_args(vec![str_arr, pos, len, newstr]); + let result = function.invoke_with_args(args).unwrap(); + + if let ColumnarValue::Array(array) = result { + let str_array = array.as_string::(); + assert_eq!(str_array.value(0), "Quadratic"); // pos < 1 + assert_eq!(str_array.value(1), "Quadratic"); // pos > length + } else { + panic!("Expected array result"); + } + } + + #[test] + fn test_insert_replace_to_end() { + let function = InsertFunction::default(); + + // INSERT('Quadratic', 3, 100, 'What') => 'QuWhat' (len exceeds remaining) + let str_arr = Arc::new(StringArray::from(vec!["Quadratic"])); + let pos = Arc::new(Int64Array::from(vec![3])); + let len = Arc::new(Int64Array::from(vec![100])); + let newstr = Arc::new(StringArray::from(vec!["What"])); + + let args = create_args(vec![str_arr, pos, len, newstr]); + let result = function.invoke_with_args(args).unwrap(); + + if let ColumnarValue::Array(array) = result { + let str_array = array.as_string::(); + assert_eq!(str_array.value(0), "QuWhat"); + } else { + panic!("Expected array result"); + } + } + + #[test] + fn test_insert_unicode() { + let function = InsertFunction::default(); + + // INSERT('hello世界', 6, 1, 'の') => 'helloの界' + let str_arr = Arc::new(StringArray::from(vec!["hello世界"])); + let pos = Arc::new(Int64Array::from(vec![6])); + let len = Arc::new(Int64Array::from(vec![1])); + let newstr = Arc::new(StringArray::from(vec!["の"])); + + let args = create_args(vec![str_arr, pos, len, newstr]); + let result = function.invoke_with_args(args).unwrap(); + + if let ColumnarValue::Array(array) = result { + let str_array = array.as_string::(); + assert_eq!(str_array.value(0), "helloの界"); + } else { + panic!("Expected array result"); + } + } + + #[test] + fn test_insert_with_nulls() { + let function = InsertFunction::default(); + + let str_arr = Arc::new(StringArray::from(vec![Some("hello"), None])); + let pos = Arc::new(Int64Array::from(vec![1, 1])); + let len = Arc::new(Int64Array::from(vec![1, 1])); + let newstr = Arc::new(StringArray::from(vec!["X", "X"])); + + let args = create_args(vec![str_arr, pos, len, newstr]); + let result = function.invoke_with_args(args).unwrap(); + + if let ColumnarValue::Array(array) = result { + let str_array = array.as_string::(); + assert_eq!(str_array.value(0), "Xello"); + assert!(str_array.is_null(1)); + } else { + panic!("Expected array result"); + } + } +} diff --git a/src/common/function/src/scalars/string/locate.rs b/src/common/function/src/scalars/string/locate.rs new file mode 100644 index 0000000000..7aa421bc64 --- /dev/null +++ b/src/common/function/src/scalars/string/locate.rs @@ -0,0 +1,373 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! MySQL-compatible LOCATE function implementation. +//! +//! LOCATE(substr, str) - Returns the position of the first occurrence of substr in str (1-based). +//! LOCATE(substr, str, pos) - Returns the position of the first occurrence of substr in str, +//! starting from position pos. +//! Returns 0 if substr is not found. + +use std::fmt; +use std::sync::Arc; + +use datafusion_common::DataFusionError; +use datafusion_common::arrow::array::{Array, ArrayRef, AsArray, Int64Builder}; +use datafusion_common::arrow::compute::cast; +use datafusion_common::arrow::datatypes::DataType; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, TypeSignature, Volatility}; + +use crate::function::Function; +use crate::function_registry::FunctionRegistry; + +const NAME: &str = "locate"; + +/// MySQL-compatible LOCATE function. +/// +/// Syntax: +/// - LOCATE(substr, str) - Returns 1-based position of substr in str, or 0 if not found. +/// - LOCATE(substr, str, pos) - Same, but starts searching from position pos. +#[derive(Debug)] +pub struct LocateFunction { + signature: Signature, +} + +impl LocateFunction { + pub fn register(registry: &FunctionRegistry) { + registry.register_scalar(LocateFunction::default()); + } +} + +impl Default for LocateFunction { + fn default() -> Self { + // Support 2 or 3 arguments with various string types + let mut signatures = Vec::new(); + let string_types = [DataType::Utf8, DataType::LargeUtf8, DataType::Utf8View]; + let int_types = [ + DataType::Int64, + DataType::Int32, + DataType::Int16, + DataType::Int8, + DataType::UInt64, + DataType::UInt32, + DataType::UInt16, + DataType::UInt8, + ]; + + // 2-argument form: LOCATE(substr, str) + for substr_type in &string_types { + for str_type in &string_types { + signatures.push(TypeSignature::Exact(vec![ + substr_type.clone(), + str_type.clone(), + ])); + } + } + + // 3-argument form: LOCATE(substr, str, pos) + for substr_type in &string_types { + for str_type in &string_types { + for pos_type in &int_types { + signatures.push(TypeSignature::Exact(vec![ + substr_type.clone(), + str_type.clone(), + pos_type.clone(), + ])); + } + } + } + + Self { + signature: Signature::one_of(signatures, Volatility::Immutable), + } + } +} + +impl fmt::Display for LocateFunction { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", NAME.to_ascii_uppercase()) + } +} + +impl Function for LocateFunction { + fn name(&self) -> &str { + NAME + } + + fn return_type(&self, _: &[DataType]) -> datafusion_common::Result { + Ok(DataType::Int64) + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn invoke_with_args( + &self, + args: ScalarFunctionArgs, + ) -> datafusion_common::Result { + let arg_count = args.args.len(); + if !(2..=3).contains(&arg_count) { + return Err(DataFusionError::Execution( + "LOCATE requires 2 or 3 arguments: LOCATE(substr, str) or LOCATE(substr, str, pos)" + .to_string(), + )); + } + + let arrays = ColumnarValue::values_to_arrays(&args.args)?; + + // Cast string arguments to LargeUtf8 for uniform access + let substr_array = cast_to_large_utf8(&arrays[0], "substr")?; + let str_array = cast_to_large_utf8(&arrays[1], "str")?; + + let substr = substr_array.as_string::(); + let str_arr = str_array.as_string::(); + let len = substr.len(); + + // Handle optional pos argument + let pos_array: Option = if arg_count == 3 { + Some(cast_to_int64(&arrays[2], "pos")?) + } else { + None + }; + + let mut builder = Int64Builder::with_capacity(len); + + for i in 0..len { + if substr.is_null(i) || str_arr.is_null(i) { + builder.append_null(); + continue; + } + + let needle = substr.value(i); + let haystack = str_arr.value(i); + + // Get starting position (1-based in MySQL, convert to 0-based) + let start_pos = if let Some(ref pos_arr) = pos_array { + if pos_arr.is_null(i) { + builder.append_null(); + continue; + } + let pos = pos_arr + .as_primitive::() + .value(i); + if pos < 1 { + // MySQL returns 0 for pos < 1 + builder.append_value(0); + continue; + } + (pos - 1) as usize + } else { + 0 + }; + + // Find position using character-based indexing (for Unicode support) + let result = locate_substr(haystack, needle, start_pos); + builder.append_value(result); + } + + Ok(ColumnarValue::Array(Arc::new(builder.finish()))) + } +} + +/// Cast array to LargeUtf8 for uniform string access. +fn cast_to_large_utf8(array: &ArrayRef, name: &str) -> datafusion_common::Result { + cast(array.as_ref(), &DataType::LargeUtf8) + .map_err(|e| DataFusionError::Execution(format!("LOCATE: {} cast failed: {}", name, e))) +} + +fn cast_to_int64(array: &ArrayRef, name: &str) -> datafusion_common::Result { + cast(array.as_ref(), &DataType::Int64) + .map_err(|e| DataFusionError::Execution(format!("LOCATE: {} cast failed: {}", name, e))) +} + +/// Find the 1-based position of needle in haystack, starting from start_pos (0-based character index). +/// Returns 0 if not found. +fn locate_substr(haystack: &str, needle: &str, start_pos: usize) -> i64 { + // Handle empty needle - MySQL returns start_pos + 1 + if needle.is_empty() { + let char_count = haystack.chars().count(); + return if start_pos <= char_count { + (start_pos + 1) as i64 + } else { + 0 + }; + } + + // Convert start_pos (character index) to byte index + let byte_start = haystack + .char_indices() + .nth(start_pos) + .map(|(idx, _)| idx) + .unwrap_or(haystack.len()); + + if byte_start >= haystack.len() { + return 0; + } + + // Search in the substring + let search_str = &haystack[byte_start..]; + if let Some(byte_pos) = search_str.find(needle) { + // Convert byte position back to character position + let char_pos = search_str[..byte_pos].chars().count(); + // Return 1-based position relative to original string + (start_pos + char_pos + 1) as i64 + } else { + 0 + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use datafusion_common::arrow::array::StringArray; + use datafusion_common::arrow::datatypes::Field; + use datafusion_expr::ScalarFunctionArgs; + + use super::*; + + fn create_args(arrays: Vec) -> ScalarFunctionArgs { + let arg_fields: Vec<_> = arrays + .iter() + .enumerate() + .map(|(i, arr)| { + Arc::new(Field::new( + format!("arg_{}", i), + arr.data_type().clone(), + true, + )) + }) + .collect(); + + ScalarFunctionArgs { + args: arrays.iter().cloned().map(ColumnarValue::Array).collect(), + arg_fields, + return_field: Arc::new(Field::new("result", DataType::Int64, true)), + number_rows: arrays[0].len(), + config_options: Arc::new(datafusion_common::config::ConfigOptions::default()), + } + } + + #[test] + fn test_locate_basic() { + let function = LocateFunction::default(); + + let substr = Arc::new(StringArray::from(vec!["world", "xyz", "hello"])); + let str_arr = Arc::new(StringArray::from(vec![ + "hello world", + "hello world", + "hello world", + ])); + + let args = create_args(vec![substr, str_arr]); + let result = function.invoke_with_args(args).unwrap(); + + if let ColumnarValue::Array(array) = result { + let int_array = array.as_primitive::(); + assert_eq!(int_array.value(0), 7); // "world" at position 7 + assert_eq!(int_array.value(1), 0); // "xyz" not found + assert_eq!(int_array.value(2), 1); // "hello" at position 1 + } else { + panic!("Expected array result"); + } + } + + #[test] + fn test_locate_with_position() { + let function = LocateFunction::default(); + + let substr = Arc::new(StringArray::from(vec!["o", "o", "o"])); + let str_arr = Arc::new(StringArray::from(vec![ + "hello world", + "hello world", + "hello world", + ])); + let pos = Arc::new(datafusion_common::arrow::array::Int64Array::from(vec![ + 1, 5, 8, + ])); + + let args = create_args(vec![substr, str_arr, pos]); + let result = function.invoke_with_args(args).unwrap(); + + if let ColumnarValue::Array(array) = result { + let int_array = array.as_primitive::(); + assert_eq!(int_array.value(0), 5); // first 'o' at position 5 + assert_eq!(int_array.value(1), 5); // 'o' at position 5 (start from 5) + assert_eq!(int_array.value(2), 8); // 'o' in "world" at position 8 + } else { + panic!("Expected array result"); + } + } + + #[test] + fn test_locate_unicode() { + let function = LocateFunction::default(); + + let substr = Arc::new(StringArray::from(vec!["世", "界"])); + let str_arr = Arc::new(StringArray::from(vec!["hello世界", "hello世界"])); + + let args = create_args(vec![substr, str_arr]); + let result = function.invoke_with_args(args).unwrap(); + + if let ColumnarValue::Array(array) = result { + let int_array = array.as_primitive::(); + assert_eq!(int_array.value(0), 6); // "世" at position 6 + assert_eq!(int_array.value(1), 7); // "界" at position 7 + } else { + panic!("Expected array result"); + } + } + + #[test] + fn test_locate_empty_needle() { + let function = LocateFunction::default(); + + let substr = Arc::new(StringArray::from(vec!["", ""])); + let str_arr = Arc::new(StringArray::from(vec!["hello", "hello"])); + let pos = Arc::new(datafusion_common::arrow::array::Int64Array::from(vec![ + 1, 3, + ])); + + let args = create_args(vec![substr, str_arr, pos]); + let result = function.invoke_with_args(args).unwrap(); + + if let ColumnarValue::Array(array) = result { + let int_array = array.as_primitive::(); + assert_eq!(int_array.value(0), 1); // empty string at pos 1 + assert_eq!(int_array.value(1), 3); // empty string at pos 3 + } else { + panic!("Expected array result"); + } + } + + #[test] + fn test_locate_with_nulls() { + let function = LocateFunction::default(); + + let substr = Arc::new(StringArray::from(vec![Some("o"), None])); + let str_arr = Arc::new(StringArray::from(vec![Some("hello"), Some("hello")])); + + let args = create_args(vec![substr, str_arr]); + let result = function.invoke_with_args(args).unwrap(); + + if let ColumnarValue::Array(array) = result { + let int_array = array.as_primitive::(); + assert_eq!(int_array.value(0), 5); + assert!(int_array.is_null(1)); + } else { + panic!("Expected array result"); + } + } +} diff --git a/src/common/function/src/scalars/string/space.rs b/src/common/function/src/scalars/string/space.rs new file mode 100644 index 0000000000..a35779159f --- /dev/null +++ b/src/common/function/src/scalars/string/space.rs @@ -0,0 +1,252 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! MySQL-compatible SPACE function implementation. +//! +//! SPACE(N) - Returns a string consisting of N space characters. + +use std::fmt; +use std::sync::Arc; + +use datafusion_common::DataFusionError; +use datafusion_common::arrow::array::{Array, AsArray, LargeStringBuilder}; +use datafusion_common::arrow::datatypes::DataType; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, TypeSignature, Volatility}; + +use crate::function::Function; +use crate::function_registry::FunctionRegistry; + +const NAME: &str = "space"; + +// Safety limit for maximum number of spaces +const MAX_SPACE_COUNT: i64 = 1024 * 1024; // 1MB of spaces + +/// MySQL-compatible SPACE function. +/// +/// Syntax: SPACE(N) +/// Returns a string consisting of N space characters. +/// Returns NULL if N is NULL. +/// Returns empty string if N < 0. +#[derive(Debug)] +pub struct SpaceFunction { + signature: Signature, +} + +impl SpaceFunction { + pub fn register(registry: &FunctionRegistry) { + registry.register_scalar(SpaceFunction::default()); + } +} + +impl Default for SpaceFunction { + fn default() -> Self { + Self { + signature: Signature::one_of( + vec![ + TypeSignature::Exact(vec![DataType::Int64]), + TypeSignature::Exact(vec![DataType::Int32]), + TypeSignature::Exact(vec![DataType::Int16]), + TypeSignature::Exact(vec![DataType::Int8]), + TypeSignature::Exact(vec![DataType::UInt64]), + TypeSignature::Exact(vec![DataType::UInt32]), + TypeSignature::Exact(vec![DataType::UInt16]), + TypeSignature::Exact(vec![DataType::UInt8]), + ], + Volatility::Immutable, + ), + } + } +} + +impl fmt::Display for SpaceFunction { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", NAME.to_ascii_uppercase()) + } +} + +impl Function for SpaceFunction { + fn name(&self) -> &str { + NAME + } + + fn return_type(&self, _: &[DataType]) -> datafusion_common::Result { + Ok(DataType::LargeUtf8) + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn invoke_with_args( + &self, + args: ScalarFunctionArgs, + ) -> datafusion_common::Result { + if args.args.len() != 1 { + return Err(DataFusionError::Execution( + "SPACE requires exactly 1 argument: SPACE(N)".to_string(), + )); + } + + let arrays = ColumnarValue::values_to_arrays(&args.args)?; + let len = arrays[0].len(); + let n_array = &arrays[0]; + + let mut builder = LargeStringBuilder::with_capacity(len, len * 10); + + for i in 0..len { + if n_array.is_null(i) { + builder.append_null(); + continue; + } + + let n = get_int_value(n_array, i)?; + + if n < 0 { + // MySQL returns empty string for negative values + builder.append_value(""); + } else if n > MAX_SPACE_COUNT { + return Err(DataFusionError::Execution(format!( + "SPACE: requested {} spaces exceeds maximum allowed ({})", + n, MAX_SPACE_COUNT + ))); + } else { + let spaces = " ".repeat(n as usize); + builder.append_value(&spaces); + } + } + + Ok(ColumnarValue::Array(Arc::new(builder.finish()))) + } +} + +/// Extract integer value from various integer types. +fn get_int_value( + array: &datafusion_common::arrow::array::ArrayRef, + index: usize, +) -> datafusion_common::Result { + use datafusion_common::arrow::datatypes as arrow_types; + + match array.data_type() { + DataType::Int64 => Ok(array.as_primitive::().value(index)), + DataType::Int32 => Ok(array.as_primitive::().value(index) as i64), + DataType::Int16 => Ok(array.as_primitive::().value(index) as i64), + DataType::Int8 => Ok(array.as_primitive::().value(index) as i64), + DataType::UInt64 => { + let v = array.as_primitive::().value(index); + if v > i64::MAX as u64 { + Err(DataFusionError::Execution(format!( + "SPACE: value {} exceeds maximum", + v + ))) + } else { + Ok(v as i64) + } + } + DataType::UInt32 => Ok(array.as_primitive::().value(index) as i64), + DataType::UInt16 => Ok(array.as_primitive::().value(index) as i64), + DataType::UInt8 => Ok(array.as_primitive::().value(index) as i64), + _ => Err(DataFusionError::Execution(format!( + "SPACE: unsupported type {:?}", + array.data_type() + ))), + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use datafusion_common::arrow::array::Int64Array; + use datafusion_common::arrow::datatypes::Field; + use datafusion_expr::ScalarFunctionArgs; + + use super::*; + + fn create_args(arrays: Vec) -> ScalarFunctionArgs { + let arg_fields: Vec<_> = arrays + .iter() + .enumerate() + .map(|(i, arr)| { + Arc::new(Field::new( + format!("arg_{}", i), + arr.data_type().clone(), + true, + )) + }) + .collect(); + + ScalarFunctionArgs { + args: arrays.iter().cloned().map(ColumnarValue::Array).collect(), + arg_fields, + return_field: Arc::new(Field::new("result", DataType::LargeUtf8, true)), + number_rows: arrays[0].len(), + config_options: Arc::new(datafusion_common::config::ConfigOptions::default()), + } + } + + #[test] + fn test_space_basic() { + let function = SpaceFunction::default(); + + let n = Arc::new(Int64Array::from(vec![0, 1, 5])); + + let args = create_args(vec![n]); + let result = function.invoke_with_args(args).unwrap(); + + if let ColumnarValue::Array(array) = result { + let str_array = array.as_string::(); + assert_eq!(str_array.value(0), ""); + assert_eq!(str_array.value(1), " "); + assert_eq!(str_array.value(2), " "); + } else { + panic!("Expected array result"); + } + } + + #[test] + fn test_space_negative() { + let function = SpaceFunction::default(); + + let n = Arc::new(Int64Array::from(vec![-1, -100])); + + let args = create_args(vec![n]); + let result = function.invoke_with_args(args).unwrap(); + + if let ColumnarValue::Array(array) = result { + let str_array = array.as_string::(); + assert_eq!(str_array.value(0), ""); + assert_eq!(str_array.value(1), ""); + } else { + panic!("Expected array result"); + } + } + + #[test] + fn test_space_with_nulls() { + let function = SpaceFunction::default(); + + let n = Arc::new(Int64Array::from(vec![Some(3), None])); + + let args = create_args(vec![n]); + let result = function.invoke_with_args(args).unwrap(); + + if let ColumnarValue::Array(array) = result { + let str_array = array.as_string::(); + assert_eq!(str_array.value(0), " "); + assert!(str_array.is_null(1)); + } else { + panic!("Expected array result"); + } + } +} diff --git a/src/query/src/datafusion/planner.rs b/src/query/src/datafusion/planner.rs index 43e7a04db1..d9c74b9d5a 100644 --- a/src/query/src/datafusion/planner.rs +++ b/src/query/src/datafusion/planner.rs @@ -41,8 +41,6 @@ use snafu::{Location, ResultExt}; use crate::error::{CatalogSnafu, Result}; use crate::query_engine::{DefaultPlanDecoder, QueryEngineState}; -mod function_alias; - pub struct DfContextProviderAdapter { engine_state: Arc, session_state: SessionState, @@ -149,17 +147,7 @@ impl ContextProvider for DfContextProviderAdapter { fn get_function_meta(&self, name: &str) -> Option> { self.engine_state.scalar_function(name).map_or_else( - || { - self.session_state - .scalar_functions() - .get(name) - .cloned() - .or_else(|| { - function_alias::resolve_scalar(name).and_then(|name| { - self.session_state.scalar_functions().get(name).cloned() - }) - }) - }, + || self.session_state.scalar_functions().get(name).cloned(), |func| { Some(Arc::new(func.provide(FunctionContext { query_ctx: self.query_ctx.clone(), @@ -171,17 +159,7 @@ impl ContextProvider for DfContextProviderAdapter { fn get_aggregate_meta(&self, name: &str) -> Option> { self.engine_state.aggr_function(name).map_or_else( - || { - self.session_state - .aggregate_functions() - .get(name) - .cloned() - .or_else(|| { - function_alias::resolve_aggregate(name).and_then(|name| { - self.session_state.aggregate_functions().get(name).cloned() - }) - }) - }, + || self.session_state.aggregate_functions().get(name).cloned(), |func| Some(Arc::new(func)), ) } @@ -215,14 +193,12 @@ impl ContextProvider for DfContextProviderAdapter { fn udf_names(&self) -> Vec { let mut names = self.engine_state.scalar_names(); names.extend(self.session_state.scalar_functions().keys().cloned()); - names.extend(function_alias::scalar_alias_names().map(|name| name.to_string())); names } fn udaf_names(&self) -> Vec { let mut names = self.engine_state.aggr_names(); names.extend(self.session_state.aggregate_functions().keys().cloned()); - names.extend(function_alias::aggregate_alias_names().map(|name| name.to_string())); names } @@ -257,14 +233,9 @@ impl ContextProvider for DfContextProviderAdapter { .table_functions() .get(name) .cloned() - .or_else(|| { - function_alias::resolve_scalar(name) - .and_then(|alias| self.session_state.table_functions().get(alias).cloned()) - }); - - let tbl_func = tbl_func.ok_or_else(|| { - DataFusionError::Plan(format!("table function '{name}' not found")) - })?; + .ok_or_else(|| { + DataFusionError::Plan(format!("table function '{name}' not found")) + })?; let provider = tbl_func.create_table_provider(&args)?; Ok(provider_as_source(provider)) diff --git a/src/query/src/datafusion/planner/function_alias.rs b/src/query/src/datafusion/planner/function_alias.rs deleted file mode 100644 index 898ef81e93..0000000000 --- a/src/query/src/datafusion/planner/function_alias.rs +++ /dev/null @@ -1,86 +0,0 @@ -// Copyright 2023 Greptime Team -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::collections::HashMap; - -use once_cell::sync::Lazy; - -const SCALAR_ALIASES: &[(&str, &str)] = &[ - // SQL compat aliases. - ("ucase", "upper"), - ("lcase", "lower"), - ("ceiling", "ceil"), - ("mid", "substr"), - // MySQL's RAND([seed]) accepts an optional seed argument, while DataFusion's `random()` - // does not. We alias the name for `rand()` compatibility, and `rand(seed)` will error - // due to mismatched arity. - ("rand", "random"), -]; - -const AGGREGATE_ALIASES: &[(&str, &str)] = &[ - // MySQL compat aliases that don't override existing DataFusion aggregate names. - // - // NOTE: We intentionally do NOT alias `stddev` here, because DataFusion defines `stddev` - // as sample standard deviation while MySQL's `STDDEV` is population standard deviation. - ("std", "stddev_pop"), - ("variance", "var_pop"), -]; - -static SCALAR_FUNCTION_ALIAS: Lazy> = - Lazy::new(|| SCALAR_ALIASES.iter().copied().collect()); - -static AGGREGATE_FUNCTION_ALIAS: Lazy> = - Lazy::new(|| AGGREGATE_ALIASES.iter().copied().collect()); - -pub fn resolve_scalar(name: &str) -> Option<&'static str> { - let name = name.to_ascii_lowercase(); - SCALAR_FUNCTION_ALIAS.get(name.as_str()).copied() -} - -pub fn resolve_aggregate(name: &str) -> Option<&'static str> { - let name = name.to_ascii_lowercase(); - AGGREGATE_FUNCTION_ALIAS.get(name.as_str()).copied() -} - -pub fn scalar_alias_names() -> impl Iterator { - SCALAR_ALIASES.iter().map(|(name, _)| *name) -} - -pub fn aggregate_alias_names() -> impl Iterator { - AGGREGATE_ALIASES.iter().map(|(name, _)| *name) -} - -#[cfg(test)] -mod tests { - use super::{resolve_aggregate, resolve_scalar}; - - #[test] - fn resolves_scalar_aliases_case_insensitive() { - assert_eq!(resolve_scalar("ucase"), Some("upper")); - assert_eq!(resolve_scalar("UCASE"), Some("upper")); - assert_eq!(resolve_scalar("lcase"), Some("lower")); - assert_eq!(resolve_scalar("ceiling"), Some("ceil")); - assert_eq!(resolve_scalar("MID"), Some("substr")); - assert_eq!(resolve_scalar("RAND"), Some("random")); - assert_eq!(resolve_scalar("not_a_real_alias"), None); - } - - #[test] - fn resolves_aggregate_aliases_case_insensitive() { - assert_eq!(resolve_aggregate("std"), Some("stddev_pop")); - assert_eq!(resolve_aggregate("variance"), Some("var_pop")); - assert_eq!(resolve_aggregate("STDDEV"), None); - assert_eq!(resolve_aggregate("not_a_real_alias"), None); - } -} diff --git a/src/query/src/query_engine/state.rs b/src/query/src/query_engine/state.rs index 9328f5f736..d232c0367d 100644 --- a/src/query/src/query_engine/state.rs +++ b/src/query/src/query_engine/state.rs @@ -209,6 +209,7 @@ impl QueryEngineState { .build(); let df_context = SessionContext::new_with_state(session_state); + register_function_aliases(&df_context); Self { df_context, @@ -415,6 +416,41 @@ impl QueryPlanner for DfQueryPlanner { } } +/// MySQL-compatible scalar function aliases: (target_name, alias) +const SCALAR_FUNCTION_ALIASES: &[(&str, &str)] = &[ + ("upper", "ucase"), + ("lower", "lcase"), + ("ceil", "ceiling"), + ("substr", "mid"), + ("random", "rand"), +]; + +/// MySQL-compatible aggregate function aliases: (target_name, alias) +const AGGREGATE_FUNCTION_ALIASES: &[(&str, &str)] = + &[("stddev_pop", "std"), ("var_pop", "variance")]; + +/// Register function aliases. +/// +/// This function adds aliases like `ucase` -> `upper`, `lcase` -> `lower`, etc. +/// to make GreptimeDB more compatible with MySQL syntax. +fn register_function_aliases(ctx: &SessionContext) { + let state = ctx.state(); + + for (target, alias) in SCALAR_FUNCTION_ALIASES { + if let Some(func) = state.scalar_functions().get(*target) { + let aliased = func.as_ref().clone().with_aliases([*alias]); + ctx.register_udf(aliased); + } + } + + for (target, alias) in AGGREGATE_FUNCTION_ALIASES { + if let Some(func) = state.aggregate_functions().get(*target) { + let aliased = func.as_ref().clone().with_aliases([*alias]); + ctx.register_udaf(aliased); + } + } +} + impl DfQueryPlanner { fn new( catalog_manager: CatalogManagerRef, diff --git a/tests/cases/standalone/common/function/string/mysql_compat.result b/tests/cases/standalone/common/function/string/mysql_compat.result new file mode 100644 index 0000000000..fffd36d8f8 --- /dev/null +++ b/tests/cases/standalone/common/function/string/mysql_compat.result @@ -0,0 +1,347 @@ +-- MySQL-compatible string function tests +-- LOCATE function tests +SELECT LOCATE('world', 'hello world'); + ++-------------------------------------------+ +| locate(Utf8("world"),Utf8("hello world")) | ++-------------------------------------------+ +| 7 | ++-------------------------------------------+ + +SELECT LOCATE('xyz', 'hello world'); + ++-----------------------------------------+ +| locate(Utf8("xyz"),Utf8("hello world")) | ++-----------------------------------------+ +| 0 | ++-----------------------------------------+ + +SELECT LOCATE('o', 'hello world'); + ++---------------------------------------+ +| locate(Utf8("o"),Utf8("hello world")) | ++---------------------------------------+ +| 5 | ++---------------------------------------+ + +SELECT LOCATE('o', 'hello world', 5); + ++------------------------------------------------+ +| locate(Utf8("o"),Utf8("hello world"),Int64(5)) | ++------------------------------------------------+ +| 5 | ++------------------------------------------------+ + +SELECT LOCATE('o', 'hello world', 6); + ++------------------------------------------------+ +| locate(Utf8("o"),Utf8("hello world"),Int64(6)) | ++------------------------------------------------+ +| 8 | ++------------------------------------------------+ + +SELECT LOCATE('', 'hello'); + ++--------------------------------+ +| locate(Utf8(""),Utf8("hello")) | ++--------------------------------+ +| 1 | ++--------------------------------+ + +SELECT LOCATE('世', 'hello世界'); + ++--------------------------------------+ +| locate(Utf8("世"),Utf8("hello世界")) | ++--------------------------------------+ +| 6 | ++--------------------------------------+ + +SELECT LOCATE(NULL, 'hello'); + ++----------------------------+ +| locate(NULL,Utf8("hello")) | ++----------------------------+ +| | ++----------------------------+ + +SELECT LOCATE('o', NULL); + ++------------------------+ +| locate(Utf8("o"),NULL) | ++------------------------+ +| | ++------------------------+ + +-- ELT function tests +SELECT ELT(1, 'a', 'b', 'c'); + ++---------------------------------------------+ +| elt(Int64(1),Utf8("a"),Utf8("b"),Utf8("c")) | ++---------------------------------------------+ +| a | ++---------------------------------------------+ + +SELECT ELT(2, 'a', 'b', 'c'); + ++---------------------------------------------+ +| elt(Int64(2),Utf8("a"),Utf8("b"),Utf8("c")) | ++---------------------------------------------+ +| b | ++---------------------------------------------+ + +SELECT ELT(3, 'a', 'b', 'c'); + ++---------------------------------------------+ +| elt(Int64(3),Utf8("a"),Utf8("b"),Utf8("c")) | ++---------------------------------------------+ +| c | ++---------------------------------------------+ + +SELECT ELT(0, 'a', 'b', 'c'); + ++---------------------------------------------+ +| elt(Int64(0),Utf8("a"),Utf8("b"),Utf8("c")) | ++---------------------------------------------+ +| | ++---------------------------------------------+ + +SELECT ELT(4, 'a', 'b', 'c'); + ++---------------------------------------------+ +| elt(Int64(4),Utf8("a"),Utf8("b"),Utf8("c")) | ++---------------------------------------------+ +| | ++---------------------------------------------+ + +SELECT ELT(NULL, 'a', 'b', 'c'); + ++-----------------------------------------+ +| elt(NULL,Utf8("a"),Utf8("b"),Utf8("c")) | ++-----------------------------------------+ +| | ++-----------------------------------------+ + +-- FIELD function tests +SELECT FIELD('b', 'a', 'b', 'c'); + ++------------------------------------------------+ +| field(Utf8("b"),Utf8("a"),Utf8("b"),Utf8("c")) | ++------------------------------------------------+ +| 2 | ++------------------------------------------------+ + +SELECT FIELD('d', 'a', 'b', 'c'); + ++------------------------------------------------+ +| field(Utf8("d"),Utf8("a"),Utf8("b"),Utf8("c")) | ++------------------------------------------------+ +| 0 | ++------------------------------------------------+ + +SELECT FIELD('a', 'a', 'b', 'c'); + ++------------------------------------------------+ +| field(Utf8("a"),Utf8("a"),Utf8("b"),Utf8("c")) | ++------------------------------------------------+ +| 1 | ++------------------------------------------------+ + +SELECT FIELD('A', 'a', 'b', 'c'); + ++------------------------------------------------+ +| field(Utf8("A"),Utf8("a"),Utf8("b"),Utf8("c")) | ++------------------------------------------------+ +| 0 | ++------------------------------------------------+ + +SELECT FIELD(NULL, 'a', 'b', 'c'); + ++-------------------------------------------+ +| field(NULL,Utf8("a"),Utf8("b"),Utf8("c")) | ++-------------------------------------------+ +| 0 | ++-------------------------------------------+ + +-- INSERT function tests +SELECT INSERT('Quadratic', 3, 4, 'What'); + ++----------------------------------------------------------+ +| insert(Utf8("Quadratic"),Int64(3),Int64(4),Utf8("What")) | ++----------------------------------------------------------+ +| QuWhattic | ++----------------------------------------------------------+ + +SELECT INSERT('Quadratic', 3, 100, 'What'); + ++------------------------------------------------------------+ +| insert(Utf8("Quadratic"),Int64(3),Int64(100),Utf8("What")) | ++------------------------------------------------------------+ +| QuWhat | ++------------------------------------------------------------+ + +SELECT INSERT('Quadratic', 0, 4, 'What'); + ++----------------------------------------------------------+ +| insert(Utf8("Quadratic"),Int64(0),Int64(4),Utf8("What")) | ++----------------------------------------------------------+ +| Quadratic | ++----------------------------------------------------------+ + +SELECT INSERT('hello', 1, 0, 'X'); + ++---------------------------------------------------+ +| insert(Utf8("hello"),Int64(1),Int64(0),Utf8("X")) | ++---------------------------------------------------+ +| Xhello | ++---------------------------------------------------+ + +SELECT INSERT('hello世界', 6, 1, 'の'); + ++--------------------------------------------------------+ +| insert(Utf8("hello世界"),Int64(6),Int64(1),Utf8("の")) | ++--------------------------------------------------------+ +| helloの界 | ++--------------------------------------------------------+ + +SELECT INSERT(NULL, 1, 1, 'X'); + ++------------------------------------------+ +| insert(NULL,Int64(1),Int64(1),Utf8("X")) | ++------------------------------------------+ +| | ++------------------------------------------+ + +-- SPACE function tests +SELECT SPACE(5); + ++-----------------+ +| space(Int64(5)) | ++-----------------+ +| | ++-----------------+ + +SELECT SPACE(0); + ++-----------------+ +| space(Int64(0)) | ++-----------------+ +| | ++-----------------+ + +SELECT SPACE(-1); + ++------------------+ +| space(Int64(-1)) | ++------------------+ +| | ++------------------+ + +SELECT CONCAT('a', SPACE(3), 'b'); + ++---------------------------------------------+ +| concat(Utf8("a"),space(Int64(3)),Utf8("b")) | ++---------------------------------------------+ +| a b | ++---------------------------------------------+ + +SELECT SPACE(NULL); + ++-------------+ +| space(NULL) | ++-------------+ +| | ++-------------+ + +-- FORMAT function tests +SELECT FORMAT(1234567.891, 2); + ++---------------------------------------+ +| format(Float64(1234567.891),Int64(2)) | ++---------------------------------------+ +| 1,234,567.89 | ++---------------------------------------+ + +SELECT FORMAT(1234567.891, 0); + ++---------------------------------------+ +| format(Float64(1234567.891),Int64(0)) | ++---------------------------------------+ +| 1,234,568 | ++---------------------------------------+ + +SELECT FORMAT(1234.5, 4); + ++----------------------------------+ +| format(Float64(1234.5),Int64(4)) | ++----------------------------------+ +| 1,234.5000 | ++----------------------------------+ + +SELECT FORMAT(-1234567.891, 2); + ++----------------------------------------+ +| format(Float64(-1234567.891),Int64(2)) | ++----------------------------------------+ +| -1,234,567.89 | ++----------------------------------------+ + +SELECT FORMAT(0.5, 2); + ++-------------------------------+ +| format(Float64(0.5),Int64(2)) | ++-------------------------------+ +| 0.50 | ++-------------------------------+ + +SELECT FORMAT(123, 2); + ++-----------------------------+ +| format(Int64(123),Int64(2)) | ++-----------------------------+ +| 123.00 | ++-----------------------------+ + +SELECT FORMAT(NULL, 2); + ++-----------------------+ +| format(NULL,Int64(2)) | ++-----------------------+ +| | ++-----------------------+ + +-- Combined test with table +CREATE TABLE string_test(idx INT, val VARCHAR, ts TIMESTAMP TIME INDEX); + +Affected Rows: 0 + +INSERT INTO string_test VALUES +(1, 'hello world', 1), +(2, 'foo bar baz', 2), +(3, 'hello世界', 3); + +Affected Rows: 3 + +SELECT idx, val, LOCATE('o', val) as loc FROM string_test ORDER BY idx; + ++-----+-------------+-----+ +| idx | val | loc | ++-----+-------------+-----+ +| 1 | hello world | 5 | +| 2 | foo bar baz | 2 | +| 3 | hello世界 | 5 | ++-----+-------------+-----+ + +SELECT idx, val, INSERT(val, 1, 5, 'hi') as inserted FROM string_test ORDER BY idx; + ++-----+-------------+----------+ +| idx | val | inserted | ++-----+-------------+----------+ +| 1 | hello world | hi world | +| 2 | foo bar baz | hiar baz | +| 3 | hello世界 | hi世界 | ++-----+-------------+----------+ + +DROP TABLE string_test; + +Affected Rows: 0 + diff --git a/tests/cases/standalone/common/function/string/mysql_compat.sql b/tests/cases/standalone/common/function/string/mysql_compat.sql new file mode 100644 index 0000000000..8ae2d1d7d9 --- /dev/null +++ b/tests/cases/standalone/common/function/string/mysql_compat.sql @@ -0,0 +1,97 @@ +-- MySQL-compatible string function tests + +-- LOCATE function tests +SELECT LOCATE('world', 'hello world'); + +SELECT LOCATE('xyz', 'hello world'); + +SELECT LOCATE('o', 'hello world'); + +SELECT LOCATE('o', 'hello world', 5); + +SELECT LOCATE('o', 'hello world', 6); + +SELECT LOCATE('', 'hello'); + +SELECT LOCATE('世', 'hello世界'); + +SELECT LOCATE(NULL, 'hello'); + +SELECT LOCATE('o', NULL); + +-- ELT function tests +SELECT ELT(1, 'a', 'b', 'c'); + +SELECT ELT(2, 'a', 'b', 'c'); + +SELECT ELT(3, 'a', 'b', 'c'); + +SELECT ELT(0, 'a', 'b', 'c'); + +SELECT ELT(4, 'a', 'b', 'c'); + +SELECT ELT(NULL, 'a', 'b', 'c'); + +-- FIELD function tests +SELECT FIELD('b', 'a', 'b', 'c'); + +SELECT FIELD('d', 'a', 'b', 'c'); + +SELECT FIELD('a', 'a', 'b', 'c'); + +SELECT FIELD('A', 'a', 'b', 'c'); + +SELECT FIELD(NULL, 'a', 'b', 'c'); + +-- INSERT function tests +SELECT INSERT('Quadratic', 3, 4, 'What'); + +SELECT INSERT('Quadratic', 3, 100, 'What'); + +SELECT INSERT('Quadratic', 0, 4, 'What'); + +SELECT INSERT('hello', 1, 0, 'X'); + +SELECT INSERT('hello世界', 6, 1, 'の'); + +SELECT INSERT(NULL, 1, 1, 'X'); + +-- SPACE function tests +SELECT SPACE(5); + +SELECT SPACE(0); + +SELECT SPACE(-1); + +SELECT CONCAT('a', SPACE(3), 'b'); + +SELECT SPACE(NULL); + +-- FORMAT function tests +SELECT FORMAT(1234567.891, 2); + +SELECT FORMAT(1234567.891, 0); + +SELECT FORMAT(1234.5, 4); + +SELECT FORMAT(-1234567.891, 2); + +SELECT FORMAT(0.5, 2); + +SELECT FORMAT(123, 2); + +SELECT FORMAT(NULL, 2); + +-- Combined test with table +CREATE TABLE string_test(idx INT, val VARCHAR, ts TIMESTAMP TIME INDEX); + +INSERT INTO string_test VALUES +(1, 'hello world', 1), +(2, 'foo bar baz', 2), +(3, 'hello世界', 3); + +SELECT idx, val, LOCATE('o', val) as loc FROM string_test ORDER BY idx; + +SELECT idx, val, INSERT(val, 1, 5, 'hi') as inserted FROM string_test ORDER BY idx; + +DROP TABLE string_test;