fix: support dictionary in regex match (#7055)

* fix: support dictionary in regex match

Signed-off-by: evenyag <realevenyag@gmail.com>

* chore: get key from keys buffer directly

Signed-off-by: evenyag <realevenyag@gmail.com>

---------

Signed-off-by: evenyag <realevenyag@gmail.com>
This commit is contained in:
Yingwen
2025-10-10 11:03:34 +08:00
committed by GitHub
parent 591b9f3e81
commit 47c1ef672a

View File

@@ -25,11 +25,11 @@ use datafusion_common::arrow::compute::kernels::cmp;
use datafusion_common::cast::{as_boolean_array, as_null_array, as_string_array};
use datafusion_common::{DataFusionError, ScalarValue, internal_err};
use datatypes::arrow::array::{
Array, ArrayAccessor, ArrayData, BooleanArray, BooleanBufferBuilder, RecordBatch,
StringArrayType,
Array, ArrayAccessor, ArrayData, BooleanArray, BooleanBufferBuilder, DictionaryArray,
RecordBatch, StringArrayType,
};
use datatypes::arrow::compute::filter_record_batch;
use datatypes::arrow::datatypes::DataType;
use datatypes::arrow::datatypes::{DataType, UInt32Type};
use datatypes::arrow::error::ArrowError;
use datatypes::compute::or_kleene;
use datatypes::vectors::VectorRef;
@@ -264,14 +264,29 @@ impl SimpleFilterEvaluator {
fn regex_match(&self, input: &impl Datum) -> std::result::Result<BooleanArray, ArrowError> {
let array = input.get().0;
let string_array = as_string_array(array).map_err(|_| {
ArrowError::CastError(format!("Cannot cast {:?} to StringArray", array))
})?;
let mut result = regexp_is_match_scalar(string_array, self.regex.as_ref())?;
if self.regex_negative {
result = datatypes::compute::not(&result)?;
// Try to cast to StringArray first
if let Ok(string_array) = as_string_array(array) {
let mut result = regexp_is_match_scalar(string_array, self.regex.as_ref())?;
if self.regex_negative {
result = datatypes::compute::not(&result)?;
}
return Ok(result);
}
Ok(result)
// Try to cast to StringDictionaryArray
if let Some(dict_array) = array.as_any().downcast_ref::<DictionaryArray<UInt32Type>>() {
let mut result = regexp_is_match_dictionary(dict_array, self.regex.as_ref())?;
if self.regex_negative {
result = datatypes::compute::not(&result)?;
}
return Ok(result);
}
Err(ArrowError::CastError(format!(
"Cannot cast {:?} to StringArray or StringDictionaryArray",
array.data_type()
)))
}
}
@@ -340,6 +355,55 @@ where
Ok(BooleanArray::from(data))
}
/// Similar to [regexp_is_match_scalar] but for StringDictionaryArray.
/// Iterates through dictionary keys to get string values and applies regex matching.
pub fn regexp_is_match_dictionary(
dict_array: &DictionaryArray<UInt32Type>,
regex: Option<&Regex>,
) -> Result<BooleanArray, ArrowError> {
// Get the string values from the dictionary
let string_values = dict_array
.values()
.as_any()
.downcast_ref::<datatypes::arrow::array::StringArray>()
.ok_or_else(|| {
ArrowError::CastError("Dictionary values must be StringArray".to_string())
})?;
let null_bit_buffer = dict_array.nulls().map(|x| x.inner().sliced());
let mut result = BooleanBufferBuilder::new(dict_array.len());
if let Some(re) = regex {
let keys = dict_array.keys().values();
for i in 0..dict_array.len() {
if dict_array.is_null(i) {
result.append(false);
} else {
let key = keys[i] as usize;
let string_value = string_values.value(key);
result.append(re.is_match(string_value));
}
}
} else {
result.append_n(dict_array.len(), true);
}
let buffer = result.into();
let data = unsafe {
ArrayData::new_unchecked(
DataType::Boolean,
dict_array.len(),
None,
null_bit_buffer,
0,
vec![buffer],
vec![],
)
};
Ok(BooleanArray::from(data))
}
#[cfg(test)]
mod test {
@@ -615,4 +679,44 @@ mod test {
);
assert!(result.is_err());
}
#[test]
fn test_regex_match_dictionary_array() {
use datatypes::arrow::array::StringDictionaryBuilder;
// Create a StringDictionaryArray
let mut builder = StringDictionaryBuilder::<UInt32Type>::new();
builder.append("apple").unwrap();
builder.append("banana").unwrap();
builder.append("apple").unwrap();
builder.append("cherry").unwrap();
let dict_array = builder.finish();
// Test regex that matches "apple"
let regex = regex::Regex::new(r"app.*").unwrap();
let result = regexp_is_match_dictionary(&dict_array, Some(&regex)).unwrap();
// Should match indices 0 and 2 (both "apple")
assert_eq!(result.len(), 4);
assert!(result.value(0)); // "apple"
assert!(!result.value(1)); // "banana"
assert!(result.value(2)); // "apple"
assert!(!result.value(3)); // "cherry"
// Test regex that matches "banana"
let regex2 = regex::Regex::new(r"ban.*").unwrap();
let result2 = regexp_is_match_dictionary(&dict_array, Some(&regex2)).unwrap();
assert!(!result2.value(0)); // "apple"
assert!(result2.value(1)); // "banana"
assert!(!result2.value(2)); // "apple"
assert!(!result2.value(3)); // "cherry"
// Test with no regex (should match all)
let result3 = regexp_is_match_dictionary(&dict_array, None).unwrap();
assert!(result3.value(0));
assert!(result3.value(1));
assert!(result3.value(2));
assert!(result3.value(3));
}
}