feat: adds regex_extract function and more type tests (#7107)

* feat: adds format, regex_extract function and more type tests

Signed-off-by: Dennis Zhuang <killme2008@gmail.com>

* fix: forgot functions

Signed-off-by: Dennis Zhuang <killme2008@gmail.com>

* chore: forgot null type

Signed-off-by: Dennis Zhuang <killme2008@gmail.com>

* test: forgot date type

Signed-off-by: Dennis Zhuang <killme2008@gmail.com>

* feat: remove format function

Signed-off-by: Dennis Zhuang <killme2008@gmail.com>

* test: update results after upgrading datafusion

Signed-off-by: Dennis Zhuang <killme2008@gmail.com>

---------

Signed-off-by: Dennis Zhuang <killme2008@gmail.com>
This commit is contained in:
dennis zhuang
2025-10-25 16:41:49 +08:00
committed by GitHub
parent 7da2f5ed12
commit d8563ba56d
58 changed files with 6502 additions and 15 deletions

View File

@@ -51,6 +51,7 @@ nalgebra.workspace = true
num = "0.4"
num-traits = "0.2"
paste.workspace = true
regex.workspace = true
s2 = { version = "0.0.12", optional = true }
serde.workspace = true
serde_json.workspace = true

View File

@@ -34,6 +34,7 @@ use crate::scalars::json::JsonFunction;
use crate::scalars::matches::MatchesFunction;
use crate::scalars::matches_term::MatchesTermFunction;
use crate::scalars::math::MathFunction;
use crate::scalars::string::register_string_functions;
use crate::scalars::timestamp::TimestampFunction;
use crate::scalars::uddsketch_calc::UddSketchCalcFunction;
use crate::scalars::vector::VectorFunction as VectorScalarFunction;
@@ -154,6 +155,9 @@ pub static FUNCTION_REGISTRY: LazyLock<Arc<FunctionRegistry>> = LazyLock::new(||
// Json related functions
JsonFunction::register(&function_registry);
// String related functions
register_string_functions(&function_registry);
// Vector related functions
VectorScalarFunction::register(&function_registry);
VectorAggrFunction::register(&function_registry);

View File

@@ -20,6 +20,7 @@ pub mod json;
pub mod matches;
pub mod matches_term;
pub mod math;
pub(crate) mod string;
pub mod vector;
pub(crate) mod hll_count;

View File

@@ -20,7 +20,9 @@ use common_query::error;
use common_time::{Date, Timestamp};
use datafusion_common::DataFusionError;
use datafusion_common::arrow::array::{Array, AsArray, StringViewBuilder};
use datafusion_common::arrow::datatypes::{ArrowTimestampType, DataType, Date32Type, TimeUnit};
use datafusion_common::arrow::datatypes::{
ArrowTimestampType, DataType, Date32Type, Date64Type, TimeUnit,
};
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature};
use snafu::ResultExt;
@@ -40,6 +42,7 @@ impl Default for DateFormatFunction {
signature: helper::one_of_sigs2(
vec![
DataType::Date32,
DataType::Date64,
DataType::Timestamp(TimeUnit::Second, None),
DataType::Timestamp(TimeUnit::Millisecond, None),
DataType::Timestamp(TimeUnit::Microsecond, None),
@@ -115,6 +118,29 @@ impl Function for DateFormatFunction {
builder.append_option(result.as_deref());
}
}
DataType::Date64 => {
let left = left.as_primitive::<Date64Type>();
for i in 0..size {
let date = left.is_valid(i).then(|| {
let ms = left.value(i);
Timestamp::new_millisecond(ms)
});
let format = formats.is_valid(i).then(|| formats.value(i));
let result = match (date, format) {
(Some(ts), Some(fmt)) => {
Some(ts.as_formatted_string(fmt, Some(timezone)).map_err(|e| {
DataFusionError::Execution(format!(
"cannot format {ts:?} as '{fmt}': {e}"
))
})?)
}
_ => None,
};
builder.append_option(result.as_deref());
}
}
x => {
return Err(DataFusionError::Execution(format!(
"unsupported input data type {x}"
@@ -137,7 +163,9 @@ mod tests {
use std::sync::Arc;
use arrow_schema::Field;
use datafusion_common::arrow::array::{Date32Array, StringArray, TimestampSecondArray};
use datafusion_common::arrow::array::{
Date32Array, Date64Array, StringArray, TimestampSecondArray,
};
use datafusion_common::config::ConfigOptions;
use datafusion_expr::{TypeSignature, Volatility};
@@ -166,7 +194,7 @@ mod tests {
Signature {
type_signature: TypeSignature::OneOf(sigs),
volatility: Volatility::Immutable
} if sigs.len() == 5));
} if sigs.len() == 6));
}
#[test]
@@ -213,6 +241,50 @@ mod tests {
}
}
#[test]
fn test_date64_date_format() {
let f = DateFormatFunction::default();
let dates = vec![Some(123000), None, Some(42000), None];
let formats = vec![
"%Y-%m-%d %T.%3f",
"%Y-%m-%d %T.%3f",
"%Y-%m-%d %T.%3f",
"%Y-%m-%d %T.%3f",
];
let results = [
Some("1970-01-01 00:02:03.000"),
None,
Some("1970-01-01 00:00:42.000"),
None,
];
let mut config_options = ConfigOptions::default();
config_options.extensions.insert(FunctionContext::default());
let config_options = Arc::new(config_options);
let args = ScalarFunctionArgs {
args: vec![
ColumnarValue::Array(Arc::new(Date64Array::from(dates))),
ColumnarValue::Array(Arc::new(StringArray::from_iter_values(formats))),
],
arg_fields: vec![],
number_rows: 4,
return_field: Arc::new(Field::new("x", DataType::Utf8View, false)),
config_options,
};
let result = f
.invoke_with_args(args)
.and_then(|x| x.to_array(4))
.unwrap();
let vector = result.as_string_view();
assert_eq!(4, vector.len());
for (actual, expect) in vector.iter().zip(results) {
assert_eq!(actual, expect);
}
}
#[test]
fn test_date_date_format() {
let f = DateFormatFunction::default();

View File

@@ -0,0 +1,26 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! String scalar functions
mod regexp_extract;
pub(crate) use regexp_extract::RegexpExtractFunction;
use crate::function_registry::FunctionRegistry;
/// Register all string functions
pub fn register_string_functions(registry: &FunctionRegistry) {
RegexpExtractFunction::register(registry);
}

View File

@@ -0,0 +1,339 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Implementation of REGEXP_EXTRACT function
use std::fmt;
use std::sync::Arc;
use datafusion_common::DataFusionError;
use datafusion_common::arrow::array::{Array, AsArray, LargeStringBuilder};
use datafusion_common::arrow::compute::cast;
use datafusion_common::arrow::datatypes::DataType;
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, TypeSignature, Volatility};
use regex::{Regex, RegexBuilder};
use crate::function::Function;
use crate::function_registry::FunctionRegistry;
const NAME: &str = "regexp_extract";
// Safety limits
const MAX_REGEX_SIZE: usize = 1024 * 1024; // compiled regex heap cap
const MAX_DFA_SIZE: usize = 2 * 1024 * 1024; // lazy DFA cap
const MAX_TOTAL_RESULT_SIZE: usize = 64 * 1024 * 1024; // total batch cap
const MAX_SINGLE_MATCH: usize = 1024 * 1024; // per-row cap
const MAX_PATTERN_LEN: usize = 10_000; // pattern text length cap
/// REGEXP_EXTRACT function implementation
/// Extracts the first substring matching the given regular expression pattern.
/// If no match is found, returns NULL.
///
#[derive(Debug)]
pub struct RegexpExtractFunction {
signature: Signature,
}
impl RegexpExtractFunction {
pub fn register(registry: &FunctionRegistry) {
registry.register_scalar(RegexpExtractFunction::default());
}
}
impl Default for RegexpExtractFunction {
fn default() -> Self {
Self {
signature: Signature::one_of(
vec![
TypeSignature::Exact(vec![DataType::Utf8View, DataType::Utf8]),
TypeSignature::Exact(vec![DataType::Utf8View, DataType::Utf8View]),
TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8View]),
TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Utf8View]),
TypeSignature::Exact(vec![DataType::Utf8View, DataType::LargeUtf8]),
TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8]),
TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Utf8]),
TypeSignature::Exact(vec![DataType::Utf8, DataType::LargeUtf8]),
TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::LargeUtf8]),
],
Volatility::Immutable,
),
}
}
}
impl fmt::Display for RegexpExtractFunction {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", NAME.to_ascii_uppercase())
}
}
impl Function for RegexpExtractFunction {
fn name(&self) -> &str {
NAME
}
// Always return LargeUtf8 for simplicity and safety
fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
Ok(DataType::LargeUtf8)
}
fn signature(&self) -> &Signature {
&self.signature
}
fn invoke_with_args(
&self,
args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
if args.args.len() != 2 {
return Err(DataFusionError::Execution(
"REGEXP_EXTRACT requires exactly two arguments (text, pattern)".to_string(),
));
}
// Keep original ColumnarValue variants for scalar-pattern fast path
let pattern_is_scalar = matches!(args.args[1], ColumnarValue::Scalar(_));
let arrays = ColumnarValue::values_to_arrays(&args.args)?;
let text_array = &arrays[0];
let pattern_array = &arrays[1];
// Cast both to LargeUtf8 for uniform access (supports Utf8/Utf8View/Dictionary<String>)
let text_large = cast(text_array.as_ref(), &DataType::LargeUtf8).map_err(|e| {
DataFusionError::Execution(format!("REGEXP_EXTRACT: text cast failed: {e}"))
})?;
let pattern_large = cast(pattern_array.as_ref(), &DataType::LargeUtf8).map_err(|e| {
DataFusionError::Execution(format!("REGEXP_EXTRACT: pattern cast failed: {e}"))
})?;
let text = text_large.as_string::<i64>();
let pattern = pattern_large.as_string::<i64>();
let len = text.len();
// Pre-size result builder with conservative estimate
let mut estimated_total = 0usize;
for i in 0..len {
if !text.is_null(i) {
estimated_total = estimated_total.saturating_add(text.value_length(i) as usize);
if estimated_total > MAX_TOTAL_RESULT_SIZE {
return Err(DataFusionError::ResourcesExhausted(format!(
"REGEXP_EXTRACT total output exceeds {} bytes",
MAX_TOTAL_RESULT_SIZE
)));
}
}
}
let mut builder = LargeStringBuilder::with_capacity(len, estimated_total);
// Fast path: if pattern is scalar, compile once
let compiled_scalar: Option<Regex> = if pattern_is_scalar && len > 0 && !pattern.is_null(0)
{
Some(compile_regex_checked(pattern.value(0))?)
} else {
None
};
for i in 0..len {
if text.is_null(i) || pattern.is_null(i) {
builder.append_null();
continue;
}
let s = text.value(i);
let pat = pattern.value(i);
// Compile or reuse regex
let re = if let Some(ref compiled) = compiled_scalar {
compiled
} else {
// TODO: For performance-critical applications with repeating patterns,
// consider adding a small LRU cache here
&compile_regex_checked(pat)?
};
// First match only
if let Some(m) = re.find(s) {
let m_str = m.as_str();
if m_str.len() > MAX_SINGLE_MATCH {
return Err(DataFusionError::Execution(
"REGEXP_EXTRACT match exceeds per-row limit (1MB)".to_string(),
));
}
builder.append_value(m_str);
} else {
builder.append_null();
}
}
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
}
}
// Compile a regex with safety checks
fn compile_regex_checked(pattern: &str) -> datafusion_common::Result<Regex> {
if pattern.len() > MAX_PATTERN_LEN {
return Err(DataFusionError::Execution(format!(
"REGEXP_EXTRACT pattern too long (> {} chars)",
MAX_PATTERN_LEN
)));
}
RegexBuilder::new(pattern)
.size_limit(MAX_REGEX_SIZE)
.dfa_size_limit(MAX_DFA_SIZE)
.build()
.map_err(|e| {
DataFusionError::Execution(format!("REGEXP_EXTRACT invalid pattern '{}': {e}", pattern))
})
}
#[cfg(test)]
mod tests {
use datafusion_common::arrow::array::StringArray;
use datafusion_common::arrow::datatypes::Field;
use datafusion_expr::ScalarFunctionArgs;
use super::*;
#[test]
fn test_regexp_extract_function_basic() {
let text_array = Arc::new(StringArray::from(vec!["version 1.2.3", "no match here"]));
let pattern_array = Arc::new(StringArray::from(vec!["\\d+\\.\\d+\\.\\d+", "\\d+"]));
let args = ScalarFunctionArgs {
args: vec![
ColumnarValue::Array(text_array),
ColumnarValue::Array(pattern_array),
],
arg_fields: vec![
Arc::new(Field::new("arg_0", DataType::Utf8, false)),
Arc::new(Field::new("arg_1", DataType::Utf8, false)),
],
return_field: Arc::new(Field::new("result", DataType::LargeUtf8, true)),
number_rows: 2,
config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
};
let function = RegexpExtractFunction::default();
let result = function.invoke_with_args(args).unwrap();
if let ColumnarValue::Array(array) = result {
let string_array = array.as_string::<i64>();
assert_eq!(string_array.value(0), "1.2.3");
assert!(string_array.is_null(1)); // no match should return NULL
} else {
panic!("Expected array result");
}
}
#[test]
fn test_regexp_extract_phone_number() {
let text_array = Arc::new(StringArray::from(vec!["Phone: 123-456-7890", "No phone"]));
let pattern_array = Arc::new(StringArray::from(vec![
"\\d{3}-\\d{3}-\\d{4}",
"\\d{3}-\\d{3}-\\d{4}",
]));
let args = ScalarFunctionArgs {
args: vec![
ColumnarValue::Array(text_array),
ColumnarValue::Array(pattern_array),
],
arg_fields: vec![
Arc::new(Field::new("arg_0", DataType::Utf8, false)),
Arc::new(Field::new("arg_1", DataType::Utf8, false)),
],
return_field: Arc::new(Field::new("result", DataType::LargeUtf8, true)),
number_rows: 2,
config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
};
let function = RegexpExtractFunction::default();
let result = function.invoke_with_args(args).unwrap();
if let ColumnarValue::Array(array) = result {
let string_array = array.as_string::<i64>();
assert_eq!(string_array.value(0), "123-456-7890");
assert!(string_array.is_null(1)); // no match should return NULL
} else {
panic!("Expected array result");
}
}
#[test]
fn test_regexp_extract_email() {
let text_array = Arc::new(StringArray::from(vec![
"Email: user@domain.com",
"Invalid email",
]));
let pattern_array = Arc::new(StringArray::from(vec![
"[a-zA-Z0-9]+@[a-zA-Z0-9]+\\.[a-zA-Z]+",
"[a-zA-Z0-9]+@[a-zA-Z0-9]+\\.[a-zA-Z]+",
]));
let args = ScalarFunctionArgs {
args: vec![
ColumnarValue::Array(text_array),
ColumnarValue::Array(pattern_array),
],
arg_fields: vec![
Arc::new(Field::new("arg_0", DataType::Utf8, false)),
Arc::new(Field::new("arg_1", DataType::Utf8, false)),
],
return_field: Arc::new(Field::new("result", DataType::LargeUtf8, true)),
number_rows: 2,
config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
};
let function = RegexpExtractFunction::default();
let result = function.invoke_with_args(args).unwrap();
if let ColumnarValue::Array(array) = result {
let string_array = array.as_string::<i64>();
assert_eq!(string_array.value(0), "user@domain.com");
assert!(string_array.is_null(1)); // no match should return NULL
} else {
panic!("Expected array result");
}
}
#[test]
fn test_regexp_extract_with_nulls() {
let text_array = Arc::new(StringArray::from(vec![Some("test 123"), None]));
let pattern_array = Arc::new(StringArray::from(vec![Some("\\d+"), Some("\\d+")]));
let args = ScalarFunctionArgs {
args: vec![
ColumnarValue::Array(text_array),
ColumnarValue::Array(pattern_array),
],
arg_fields: vec![
Arc::new(Field::new("arg_0", DataType::Utf8, true)),
Arc::new(Field::new("arg_1", DataType::Utf8, false)),
],
return_field: Arc::new(Field::new("result", DataType::LargeUtf8, true)),
number_rows: 2,
config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
};
let function = RegexpExtractFunction::default();
let result = function.invoke_with_args(args).unwrap();
if let ColumnarValue::Array(array) = result {
let string_array = array.as_string::<i64>();
assert_eq!(string_array.value(0), "123");
assert!(string_array.is_null(1)); // NULL input should return NULL
} else {
panic!("Expected array result");
}
}
}