diff --git a/src/common/recordbatch/src/filter.rs b/src/common/recordbatch/src/filter.rs index d8955bf57b..a62b4aa394 100644 --- a/src/common/recordbatch/src/filter.rs +++ b/src/common/recordbatch/src/filter.rs @@ -25,11 +25,11 @@ use datafusion_common::arrow::compute::kernels::cmp; use datafusion_common::cast::{as_boolean_array, as_null_array, as_string_array}; use datafusion_common::{DataFusionError, ScalarValue, internal_err}; use datatypes::arrow::array::{ - Array, ArrayAccessor, ArrayData, BooleanArray, BooleanBufferBuilder, RecordBatch, - StringArrayType, + Array, ArrayAccessor, ArrayData, BooleanArray, BooleanBufferBuilder, DictionaryArray, + RecordBatch, StringArrayType, }; use datatypes::arrow::compute::filter_record_batch; -use datatypes::arrow::datatypes::DataType; +use datatypes::arrow::datatypes::{DataType, UInt32Type}; use datatypes::arrow::error::ArrowError; use datatypes::compute::or_kleene; use datatypes::vectors::VectorRef; @@ -264,14 +264,29 @@ impl SimpleFilterEvaluator { fn regex_match(&self, input: &impl Datum) -> std::result::Result { let array = input.get().0; - let string_array = as_string_array(array).map_err(|_| { - ArrowError::CastError(format!("Cannot cast {:?} to StringArray", array)) - })?; - let mut result = regexp_is_match_scalar(string_array, self.regex.as_ref())?; - if self.regex_negative { - result = datatypes::compute::not(&result)?; + + // Try to cast to StringArray first + if let Ok(string_array) = as_string_array(array) { + let mut result = regexp_is_match_scalar(string_array, self.regex.as_ref())?; + if self.regex_negative { + result = datatypes::compute::not(&result)?; + } + return Ok(result); } - Ok(result) + + // Try to cast to StringDictionaryArray + if let Some(dict_array) = array.as_any().downcast_ref::>() { + let mut result = regexp_is_match_dictionary(dict_array, self.regex.as_ref())?; + if self.regex_negative { + result = datatypes::compute::not(&result)?; + } + return Ok(result); + } + + Err(ArrowError::CastError(format!( + "Cannot cast {:?} to StringArray or StringDictionaryArray", + array.data_type() + ))) } } @@ -340,6 +355,55 @@ where Ok(BooleanArray::from(data)) } +/// Similar to [regexp_is_match_scalar] but for StringDictionaryArray. +/// Iterates through dictionary keys to get string values and applies regex matching. +pub fn regexp_is_match_dictionary( + dict_array: &DictionaryArray, + regex: Option<&Regex>, +) -> Result { + // Get the string values from the dictionary + let string_values = dict_array + .values() + .as_any() + .downcast_ref::() + .ok_or_else(|| { + ArrowError::CastError("Dictionary values must be StringArray".to_string()) + })?; + + let null_bit_buffer = dict_array.nulls().map(|x| x.inner().sliced()); + let mut result = BooleanBufferBuilder::new(dict_array.len()); + + if let Some(re) = regex { + let keys = dict_array.keys().values(); + for i in 0..dict_array.len() { + if dict_array.is_null(i) { + result.append(false); + } else { + let key = keys[i] as usize; + let string_value = string_values.value(key); + result.append(re.is_match(string_value)); + } + } + } else { + result.append_n(dict_array.len(), true); + } + + let buffer = result.into(); + let data = unsafe { + ArrayData::new_unchecked( + DataType::Boolean, + dict_array.len(), + None, + null_bit_buffer, + 0, + vec![buffer], + vec![], + ) + }; + + Ok(BooleanArray::from(data)) +} + #[cfg(test)] mod test { @@ -615,4 +679,44 @@ mod test { ); assert!(result.is_err()); } + + #[test] + fn test_regex_match_dictionary_array() { + use datatypes::arrow::array::StringDictionaryBuilder; + + // Create a StringDictionaryArray + let mut builder = StringDictionaryBuilder::::new(); + builder.append("apple").unwrap(); + builder.append("banana").unwrap(); + builder.append("apple").unwrap(); + builder.append("cherry").unwrap(); + let dict_array = builder.finish(); + + // Test regex that matches "apple" + let regex = regex::Regex::new(r"app.*").unwrap(); + let result = regexp_is_match_dictionary(&dict_array, Some(®ex)).unwrap(); + + // Should match indices 0 and 2 (both "apple") + assert_eq!(result.len(), 4); + assert!(result.value(0)); // "apple" + assert!(!result.value(1)); // "banana" + assert!(result.value(2)); // "apple" + assert!(!result.value(3)); // "cherry" + + // Test regex that matches "banana" + let regex2 = regex::Regex::new(r"ban.*").unwrap(); + let result2 = regexp_is_match_dictionary(&dict_array, Some(®ex2)).unwrap(); + + assert!(!result2.value(0)); // "apple" + assert!(result2.value(1)); // "banana" + assert!(!result2.value(2)); // "apple" + assert!(!result2.value(3)); // "cherry" + + // Test with no regex (should match all) + let result3 = regexp_is_match_dictionary(&dict_array, None).unwrap(); + assert!(result3.value(0)); + assert!(result3.value(1)); + assert!(result3.value(2)); + assert!(result3.value(3)); + } }