From 32fd850c20e56dc8ccdb3ad4655a5da8d5aebda7 Mon Sep 17 00:00:00 2001 From: Ruihang Xia Date: Thu, 13 Mar 2025 18:08:29 -0700 Subject: [PATCH] perf: support in list in simple filter (#5709) * feat: support in list in simple filter Signed-off-by: Ruihang Xia * fix clippy Signed-off-by: Ruihang Xia --------- Signed-off-by: Ruihang Xia --- src/common/recordbatch/src/filter.rs | 96 ++++++++++++++++++- .../optimizer/filter_push_down.result | 23 +++++ .../standalone/optimizer/filter_push_down.sql | 8 ++ 3 files changed, 123 insertions(+), 4 deletions(-) diff --git a/src/common/recordbatch/src/filter.rs b/src/common/recordbatch/src/filter.rs index 8c1ebe7d53..32aae0190e 100644 --- a/src/common/recordbatch/src/filter.rs +++ b/src/common/recordbatch/src/filter.rs @@ -26,6 +26,7 @@ use datafusion_common::cast::{as_boolean_array, as_null_array}; use datafusion_common::{internal_err, DataFusionError, ScalarValue}; use datatypes::arrow::array::{Array, BooleanArray, RecordBatch}; use datatypes::arrow::compute::filter_record_batch; +use datatypes::compute::or_kleene; use datatypes::vectors::VectorRef; use snafu::ResultExt; @@ -47,6 +48,8 @@ pub struct SimpleFilterEvaluator { literal: Scalar<ArrayRef>, /// The operator. op: Operator, + /// Only used when the operator is `Or`-chain. 
+ literal_list: Vec<Scalar<ArrayRef>>, } impl SimpleFilterEvaluator { @@ -69,6 +72,7 @@ impl SimpleFilterEvaluator { column_name, literal: val.to_scalar().ok()?, op, + literal_list: vec![], }) } @@ -83,6 +87,35 @@ impl SimpleFilterEvaluator { | Operator::LtEq | Operator::Gt | Operator::GtEq => {} + Operator::Or => { + let lhs = Self::try_new(&binary.left)?; + let rhs = Self::try_new(&binary.right)?; + if lhs.column_name != rhs.column_name + || !matches!(lhs.op, Operator::Eq | Operator::Or) + || !matches!(rhs.op, Operator::Eq | Operator::Or) + { + return None; + } + let mut list = vec![]; + let placeholder_literal = lhs.literal.clone(); + // above check guarantees the op is either `Eq` or `Or` + if matches!(lhs.op, Operator::Or) { + list.extend(lhs.literal_list); + } else { + list.push(lhs.literal); + } + if matches!(rhs.op, Operator::Or) { + list.extend(rhs.literal_list); + } else { + list.push(rhs.literal); + } + return Some(Self { + column_name: lhs.column_name, + literal: placeholder_literal, + op: Operator::Or, + literal_list: list, + }); + } _ => return None, } @@ -103,6 +136,7 @@ impl SimpleFilterEvaluator { column_name: lhs.name.clone(), literal, op, + literal_list: vec![], }) } _ => None, @@ -118,19 +152,19 @@ impl SimpleFilterEvaluator { let input = input .to_scalar() .with_context(|_| ToArrowScalarSnafu { v: input.clone() })?; - let result = self.evaluate_datum(&input)?; + let result = self.evaluate_datum(&input, 1)?; Ok(result.value(0)) } pub fn evaluate_array(&self, input: &ArrayRef) -> Result { - self.evaluate_datum(input) + self.evaluate_datum(input, input.len()) } pub fn evaluate_vector(&self, input: &VectorRef) -> Result { - self.evaluate_datum(&input.to_arrow_array()) + self.evaluate_datum(&input.to_arrow_array(), input.len()) } - fn evaluate_datum(&self, input: &impl Datum) -> Result { + fn evaluate_datum(&self, input: &impl Datum, input_len: usize) -> Result { let result = match self.op { Operator::Eq => cmp::eq(input, &self.literal), Operator::NotEq => 
cmp::neq(input, &self.literal), @@ -138,6 +172,15 @@ impl SimpleFilterEvaluator { Operator::LtEq => cmp::lt_eq(input, &self.literal), Operator::Gt => cmp::gt(input, &self.literal), Operator::GtEq => cmp::gt_eq(input, &self.literal), + Operator::Or => { + // OR operator stands for OR-chained EQs (or INLIST in other words) + let mut result: BooleanArray = vec![false; input_len].into(); + for literal in &self.literal_list { + let rhs = cmp::eq(input, literal).context(ArrowComputeSnafu)?; + result = or_kleene(&result, &rhs).context(ArrowComputeSnafu)?; + } + Ok(result) + } _ => { return UnsupportedOperationSnafu { reason: format!("{:?}", self.op), @@ -349,4 +392,49 @@ mod test { let expected = datatypes::arrow::array::Int32Array::from(vec![5, 6]); assert_eq!(first_column_values, &expected); } + + #[test] + fn test_complex_filter_expression() { + // Create an expression tree for: col = 'B' OR col = 'C' OR col = 'D' + let col_eq_b = col("col").eq(lit("B")); + let col_eq_c = col("col").eq(lit("C")); + let col_eq_d = col("col").eq(lit("D")); + + // Build the OR chain + let col_or_expr = col_eq_b.or(col_eq_c).or(col_eq_d); + + // Check that SimpleFilterEvaluator can handle OR chain + let or_evaluator = SimpleFilterEvaluator::try_new(&col_or_expr).unwrap(); + assert_eq!(or_evaluator.column_name, "col"); + assert_eq!(or_evaluator.op, Operator::Or); + assert_eq!(or_evaluator.literal_list.len(), 3); + assert_eq!(format!("{:?}", or_evaluator.literal_list), "[Scalar(StringArray\n[\n \"B\",\n]), Scalar(StringArray\n[\n \"C\",\n]), Scalar(StringArray\n[\n \"D\",\n])]"); + + // Create a schema and batch for testing + let schema = Schema::new(vec![Field::new("col", DataType::Utf8, false)]); + let df_schema = DFSchema::try_from(schema.clone()).unwrap(); + let props = ExecutionProps::new(); + let physical_expr = create_physical_expr(&col_or_expr, &df_schema, &props).unwrap(); + + // Create test data + let col_data = Arc::new(datatypes::arrow::array::StringArray::from(vec![ + "B", "C", 
"E", "B", "C", "D", "F", + ])); + let batch = RecordBatch::try_new(Arc::new(schema), vec![col_data]).unwrap(); + let expected = datatypes::arrow::array::StringArray::from(vec!["B", "C", "B", "C", "D"]); + + // Filter the batch + let filtered_batch = batch_filter(&batch, &physical_expr).unwrap(); + + // Expected: rows with col in ("B", "C", "D") + // That would be rows 0, 1, 3, 4, 5 + assert_eq!(filtered_batch.num_rows(), 5); + + let col_filtered = filtered_batch + .column(0) + .as_any() + .downcast_ref::<datatypes::arrow::array::StringArray>() + .unwrap(); + assert_eq!(col_filtered, &expected); + } } diff --git a/tests/cases/standalone/optimizer/filter_push_down.result b/tests/cases/standalone/optimizer/filter_push_down.result index 33ce01865b..30e4789fdb 100644 --- a/tests/cases/standalone/optimizer/filter_push_down.result +++ b/tests/cases/standalone/optimizer/filter_push_down.result @@ -204,3 +204,26 @@ DROP TABLE integers; Affected Rows: 0 +CREATE TABLE characters(c STRING, t TIMESTAMP TIME INDEX); + +Affected Rows: 0 + +INSERT INTO characters VALUES ('a', 1), ('b', 2), ('c', 3), (NULL, 4), ('a', 5), ('b', 6), ('c', 7), (NULL, 8); + +Affected Rows: 8 + +SELECT * FROM characters WHERE c IN ('a', 'c') ORDER BY t; + ++---+-------------------------+ +| c | t | ++---+-------------------------+ +| a | 1970-01-01T00:00:00.001 | +| c | 1970-01-01T00:00:00.003 | +| a | 1970-01-01T00:00:00.005 | +| c | 1970-01-01T00:00:00.007 | ++---+-------------------------+ + +DROP TABLE characters; + +Affected Rows: 0 + diff --git a/tests/cases/standalone/optimizer/filter_push_down.sql b/tests/cases/standalone/optimizer/filter_push_down.sql index 0d47ed3713..36688412d1 100644 --- a/tests/cases/standalone/optimizer/filter_push_down.sql +++ b/tests/cases/standalone/optimizer/filter_push_down.sql @@ -57,3 +57,11 @@ SELECT * FROM (SELECT i1.i AS a, i2.i AS b, row_number() OVER (ORDER BY i1.i, i2 SELECT * FROM (SELECT 0=1 AS cond FROM integers i1, integers i2 GROUP BY 1) a1 WHERE cond ORDER BY 1; DROP TABLE integers; + 
+CREATE TABLE characters(c STRING, t TIMESTAMP TIME INDEX); + +INSERT INTO characters VALUES ('a', 1), ('b', 2), ('c', 3), (NULL, 4), ('a', 5), ('b', 6), ('c', 7), (NULL, 8); + +SELECT * FROM characters WHERE c IN ('a', 'c') ORDER BY t; + +DROP TABLE characters;