diff --git a/src/common/recordbatch/src/filter.rs b/src/common/recordbatch/src/filter.rs index 32aae0190e..0d8a7632d9 100644 --- a/src/common/recordbatch/src/filter.rs +++ b/src/common/recordbatch/src/filter.rs @@ -22,10 +22,12 @@ use datafusion::physical_plan::PhysicalExpr; use datafusion_common::arrow::array::{ArrayRef, Datum, Scalar}; use datafusion_common::arrow::buffer::BooleanBuffer; use datafusion_common::arrow::compute::kernels::cmp; -use datafusion_common::cast::{as_boolean_array, as_null_array}; +use datafusion_common::cast::{as_boolean_array, as_null_array, as_string_array}; use datafusion_common::{internal_err, DataFusionError, ScalarValue}; use datatypes::arrow::array::{Array, BooleanArray, RecordBatch}; use datatypes::arrow::compute::filter_record_batch; +use datatypes::arrow::error::ArrowError; +use datatypes::compute::kernels::regexp; use datatypes::compute::or_kleene; use datatypes::vectors::VectorRef; use snafu::ResultExt; @@ -36,7 +38,8 @@ use crate::error::{ArrowComputeSnafu, Result, ToArrowScalarSnafu, UnsupportedOpe /// - `col` `op` `literal` /// - `literal` `op` `col` /// -/// And the `op` is one of `=`, `!=`, `>`, `>=`, `<`, `<=`. +/// And the `op` is one of `=`, `!=`, `>`, `>=`, `<`, `<=`, +/// or regex operators: `~`, `~*`, `!~`, `!~*`. /// /// This struct contains normalized predicate expr. In the form of /// `col` `op` `literal` where the `col` is provided from input. @@ -86,7 +89,11 @@ impl SimpleFilterEvaluator { | Operator::Lt | Operator::LtEq | Operator::Gt - | Operator::GtEq => {} + | Operator::GtEq + | Operator::RegexMatch + | Operator::RegexIMatch + | Operator::RegexNotMatch + | Operator::RegexNotIMatch => {} Operator::Or => { let lhs = Self::try_new(&binary.left)?; let rhs = Self::try_new(&binary.right)?; @@ -172,6 +179,10 @@ impl SimpleFilterEvaluator { Operator::LtEq => cmp::lt_eq(input, &self.literal), Operator::Gt => cmp::gt(input, &self.literal), Operator::GtEq => cmp::gt_eq(input, &self.literal), + Operator::RegexMatch => self.regex_match(input, false, false), + Operator::RegexIMatch => self.regex_match(input, true, false), + Operator::RegexNotMatch => self.regex_match(input, false, true), + Operator::RegexNotIMatch => self.regex_match(input, true, true), Operator::Or => { // OR operator stands for OR-chained EQs (or INLIST in other words) let mut result: BooleanArray = vec![false; input_len].into(); @@ -192,6 +203,28 @@ impl SimpleFilterEvaluator { .context(ArrowComputeSnafu) .map(|array| array.values().clone()) } + + fn regex_match( + &self, + input: &impl Datum, + ignore_case: bool, + negative: bool, + ) -> std::result::Result { + let flag = if ignore_case { Some("i") } else { None }; + let array = input.get().0; + let string_array = as_string_array(array).map_err(|_| { + ArrowError::CastError(format!("Cannot cast {:?} to StringArray", array)) + })?; + let literal_array = self.literal.clone().into_inner(); + let regex_array = as_string_array(&literal_array).map_err(|_| { + ArrowError::CastError(format!("Cannot cast {:?} to StringArray", literal_array)) + })?; + let mut result = regexp::regexp_is_match_scalar(string_array, regex_array.value(0), flag)?; + if negative { + result = datatypes::compute::not(&result)?; + } + Ok(result) + } } /// Evaluate the predicate on the input [RecordBatch], and return a new [RecordBatch]. diff --git a/tests/cases/standalone/common/select/tql_filter.result b/tests/cases/standalone/common/select/tql_filter.result index bad003eb05..ad598a21a0 100644 --- a/tests/cases/standalone/common/select/tql_filter.result +++ b/tests/cases/standalone/common/select/tql_filter.result @@ -72,11 +72,7 @@ tql analyze (1, 3, '1s') t1{ a =~ "a.*" }; |_|_|_SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] REDACTED |_|_|_MergeScanExec: REDACTED |_|_|_| -| 1_| 0_|_SortPreservingMergeExec: [a@0 DESC NULLS LAST, b@1 DESC NULLS LAST] REDACTED -|_|_|_SortExec: expr=[a@0 DESC NULLS LAST, b@1 DESC NULLS LAST], preserve_partitioning=[true] REDACTED -|_|_|_CoalesceBatchesExec: target_batch_size=8192 REDACTED -|_|_|_FilterExec: a@0 ~ a.* REDACTED -|_|_|_RepartitionExec: partitioning=REDACTED +| 1_| 0_|_SortExec: expr=[a@0 DESC NULLS LAST, b@1 DESC NULLS LAST], preserve_partitioning=[false] REDACTED |_|_|_SeqScan: region=REDACTED, partition_count=1 (1 memtable ranges, 0 file 0 ranges) REDACTED |_|_|_| |_|_| Total rows: 3_|