1use std::sync::Arc;
16
17use arc_swap::ArcSwap;
18use common_telemetry::{debug, warn};
19use common_time::Timestamp;
20use common_time::range::TimestampRange;
21use common_time::timestamp::TimeUnit;
22use datafusion::common::ScalarValue;
23use datafusion::physical_optimizer::pruning::PruningPredicate;
24use datafusion_common::ToDFSchema;
25use datafusion_common::pruning::PruningStatistics;
26use datafusion_expr::expr::{Expr, InList};
27use datafusion_expr::{Between, BinaryExpr, Operator};
28use datafusion_physical_expr::execution_props::ExecutionProps;
29use datafusion_physical_expr::expressions::DynamicFilterPhysicalExpr;
30use datafusion_physical_expr::{PhysicalExpr, create_physical_expr};
31use datatypes::arrow;
32use datatypes::value::scalar_value_to_timestamp;
33use snafu::ResultExt;
34
35use crate::error;
36
37#[cfg(test)]
38mod stats;
39
/// Makes the enclosing function (or closure) return `None` when the given scalar is a
/// string-typed literal, after logging a warning.
///
/// The argument must be an identifier bound to a `&ScalarValue`, since it is passed
/// straight to [`is_string_timestamp_literal`].
macro_rules! return_none_if_utf8 {
    ($scalar: ident) => {
        if is_string_timestamp_literal($scalar) {
            warn!(
                "Unexpected ScalarValue::Utf8 in time range predicate: {:?}. Maybe it's an implicit bug, please report it to https://github.com/GreptimeTeam/greptimedb/issues",
                $scalar
            );

            return None;
        }
    };
}
55
56pub fn is_string_timestamp_literal(scalar: &ScalarValue) -> bool {
57 matches!(
58 scalar,
59 ScalarValue::Utf8(_) | ScalarValue::LargeUtf8(_) | ScalarValue::Utf8View(_)
60 )
61}
62
63#[derive(Debug, Clone, Default)]
65pub struct Predicate {
66 exprs: Arc<Vec<Expr>>,
68 dyn_filters: Arc<ArcSwap<Vec<Arc<DynamicFilterPhysicalExpr>>>>,
72}
73
74impl Predicate {
75 pub fn new(exprs: Vec<Expr>) -> Self {
79 Self {
80 exprs: Arc::new(exprs),
81 dyn_filters: Arc::new(ArcSwap::new(Arc::new(vec![]))),
82 }
83 }
84
85 pub fn with_dyn_filters(
86 exprs: Vec<Expr>,
87 dyn_filters: Vec<Arc<DynamicFilterPhysicalExpr>>,
88 ) -> Self {
89 Self {
90 exprs: Arc::new(exprs),
91 dyn_filters: Arc::new(ArcSwap::new(Arc::new(dyn_filters))),
92 }
93 }
94
95 pub fn is_empty(&self) -> bool {
96 self.exprs.is_empty() && self.dyn_filters.load().is_empty()
97 }
98
99 pub fn add_dyn_filters(&self, dyn_filters: Vec<Arc<DynamicFilterPhysicalExpr>>) {
101 self.dyn_filters.rcu(|existing| {
102 let mut new_filters = existing.as_ref().clone();
103 new_filters.extend(dyn_filters.clone());
104 Arc::new(new_filters)
105 });
106 }
107
108 pub fn exprs(&self) -> &[Expr] {
110 &self.exprs
111 }
112
113 pub fn dyn_filters(&self) -> Arc<Vec<Arc<DynamicFilterPhysicalExpr>>> {
116 self.dyn_filters.load_full()
117 }
118
119 pub fn dyn_filter_phy_exprs(&self) -> error::Result<Vec<Arc<dyn PhysicalExpr>>> {
122 self.dyn_filters
123 .load()
124 .iter()
125 .map(|e| e.current())
126 .collect::<Result<Vec<_>, _>>()
127 .context(error::DatafusionSnafu)
128 }
129
130 pub fn to_physical_expr(
132 expr: &Expr,
133 schema: &arrow::datatypes::SchemaRef,
134 ) -> error::Result<Arc<dyn PhysicalExpr>> {
135 let df_schema = schema
136 .clone()
137 .to_dfschema_ref()
138 .context(error::DatafusionSnafu)?;
139
140 let execution_props = &ExecutionProps::new();
144
145 create_physical_expr(expr, df_schema.as_ref(), execution_props)
146 .context(error::DatafusionSnafu)
147 }
148
149 pub fn to_physical_exprs(
151 &self,
152 schema: &arrow::datatypes::SchemaRef,
153 ) -> error::Result<Vec<Arc<dyn PhysicalExpr>>> {
154 let dyn_filters = self.dyn_filter_phy_exprs()?;
155
156 Ok(self
157 .exprs
158 .iter()
159 .filter_map(|expr| Self::to_physical_expr(expr, schema).ok())
160 .chain(dyn_filters)
161 .collect::<Vec<_>>())
162 }
163
164 pub fn prune_with_stats<S: PruningStatistics>(
167 &self,
168 stats: &S,
169 schema: &arrow::datatypes::SchemaRef,
170 ) -> Vec<bool> {
171 let mut res = vec![true; stats.num_containers()];
172 let physical_exprs = match self.to_physical_exprs(schema) {
173 Ok(expr) => expr,
174 Err(e) => {
175 warn!(e; "Failed to build physical expr from predicates: {:?}", &self.exprs);
176 return res;
177 }
178 };
179
180 for expr in &physical_exprs {
181 match PruningPredicate::try_new(expr.clone(), schema.clone()) {
182 Ok(p) => match p.prune(stats) {
183 Ok(r) => {
184 for (curr_val, res) in r.into_iter().zip(res.iter_mut()) {
185 *res &= curr_val
186 }
187 }
188 Err(e) => {
189 warn!(e; "Failed to prune row groups");
190 }
191 },
192 Err(e) => {
193 debug!("Failed to create pruning predicate for expr: {e:?}");
195 }
196 }
197 }
198 res
199 }
200}
201
202pub fn build_time_range_predicate(
207 ts_col_name: &str,
208 ts_col_unit: TimeUnit,
209 filters: &[Expr],
210) -> TimestampRange {
211 let mut res = TimestampRange::min_to_max();
212 for expr in filters {
213 if let Some(range) = extract_time_range_from_expr(ts_col_name, ts_col_unit, expr) {
214 res = res.and(&range);
215 }
216 }
217 res
218}
219
220pub fn extract_time_range_from_expr(
223 ts_col_name: &str,
224 ts_col_unit: TimeUnit,
225 expr: &Expr,
226) -> Option<TimestampRange> {
227 match expr {
228 Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
229 extract_from_binary_expr(ts_col_name, ts_col_unit, left, op, right)
230 }
231 Expr::Between(Between {
232 expr,
233 negated,
234 low,
235 high,
236 }) => extract_from_between_expr(ts_col_name, ts_col_unit, expr, negated, low, high),
237 Expr::InList(InList {
238 expr,
239 list,
240 negated,
241 }) => extract_from_in_list_expr(ts_col_name, expr, *negated, list),
242 _ => None,
243 }
244}
245
246fn extract_from_binary_expr(
247 ts_col_name: &str,
248 ts_col_unit: TimeUnit,
249 left: &Expr,
250 op: &Operator,
251 right: &Expr,
252) -> Option<TimestampRange> {
253 match op {
254 Operator::Eq => get_timestamp_filter(ts_col_name, left, right)
255 .and_then(|(ts, _)| ts.convert_to(ts_col_unit))
256 .map(TimestampRange::single),
257 Operator::Lt => {
258 let (ts, reverse) = get_timestamp_filter(ts_col_name, left, right)?;
259 if reverse {
260 let ts_val = ts.convert_to(ts_col_unit)?.value();
262 Some(TimestampRange::from_start(Timestamp::new(
263 ts_val + 1,
264 ts_col_unit,
265 )))
266 } else {
267 ts.convert_to_ceil(ts_col_unit)
269 .map(|ts| TimestampRange::until_end(ts, false))
270 }
271 }
272 Operator::LtEq => {
273 let (ts, reverse) = get_timestamp_filter(ts_col_name, left, right)?;
274 if reverse {
275 ts.convert_to_ceil(ts_col_unit)
277 .map(TimestampRange::from_start)
278 } else {
279 ts.convert_to(ts_col_unit)
281 .map(|ts| TimestampRange::until_end(ts, true))
282 }
283 }
284 Operator::Gt => {
285 let (ts, reverse) = get_timestamp_filter(ts_col_name, left, right)?;
286 if reverse {
287 ts.convert_to_ceil(ts_col_unit)
289 .map(|t| TimestampRange::until_end(t, false))
290 } else {
291 let ts_val = ts.convert_to(ts_col_unit)?.value();
293 Some(TimestampRange::from_start(Timestamp::new(
294 ts_val + 1,
295 ts_col_unit,
296 )))
297 }
298 }
299 Operator::GtEq => {
300 let (ts, reverse) = get_timestamp_filter(ts_col_name, left, right)?;
301 if reverse {
302 ts.convert_to(ts_col_unit)
304 .map(|t| TimestampRange::until_end(t, true))
305 } else {
306 ts.convert_to_ceil(ts_col_unit)
308 .map(TimestampRange::from_start)
309 }
310 }
311 Operator::And => {
312 let left = extract_time_range_from_expr(ts_col_name, ts_col_unit, left)
315 .unwrap_or_else(TimestampRange::min_to_max);
316 let right = extract_time_range_from_expr(ts_col_name, ts_col_unit, right)
317 .unwrap_or_else(TimestampRange::min_to_max);
318 Some(left.and(&right))
319 }
320 Operator::Or => {
321 let left = extract_time_range_from_expr(ts_col_name, ts_col_unit, left)?;
322 let right = extract_time_range_from_expr(ts_col_name, ts_col_unit, right)?;
323 Some(left.or(&right))
324 }
325 _ => None,
326 }
327}
328
329fn get_timestamp_filter(ts_col_name: &str, left: &Expr, right: &Expr) -> Option<(Timestamp, bool)> {
330 let (col, lit, reverse) = match (left, right) {
331 (Expr::Column(column), Expr::Literal(scalar, _)) => (column, scalar, false),
332 (Expr::Literal(scalar, _), Expr::Column(column)) => (column, scalar, true),
333 _ => {
334 return None;
335 }
336 };
337 if col.name != ts_col_name {
338 return None;
339 }
340
341 return_none_if_utf8!(lit);
342 scalar_value_to_timestamp(lit, None).map(|t| (t, reverse))
343}
344
345fn extract_from_between_expr(
346 ts_col_name: &str,
347 ts_col_unit: TimeUnit,
348 expr: &Expr,
349 negated: &bool,
350 low: &Expr,
351 high: &Expr,
352) -> Option<TimestampRange> {
353 let Expr::Column(col) = expr else {
354 return None;
355 };
356 if col.name != ts_col_name {
357 return None;
358 }
359
360 if *negated {
361 return None;
362 }
363
364 match (low, high) {
365 (Expr::Literal(low, _), Expr::Literal(high, _)) => {
366 return_none_if_utf8!(low);
367 return_none_if_utf8!(high);
368
369 let low_opt =
370 scalar_value_to_timestamp(low, None).and_then(|ts| ts.convert_to(ts_col_unit));
371 let high_opt = scalar_value_to_timestamp(high, None)
372 .and_then(|ts| ts.convert_to_ceil(ts_col_unit));
373 Some(TimestampRange::new_inclusive(low_opt, high_opt))
374 }
375 _ => None,
376 }
377}
378
379fn extract_from_in_list_expr(
381 ts_col_name: &str,
382 expr: &Expr,
383 negated: bool,
384 list: &[Expr],
385) -> Option<TimestampRange> {
386 if negated {
387 return None;
388 }
389 let Expr::Column(col) = expr else {
390 return None;
391 };
392 if col.name != ts_col_name {
393 return None;
394 }
395
396 if list.is_empty() {
397 return Some(TimestampRange::empty());
398 }
399 let mut init_range = TimestampRange::empty();
400 for expr in list {
401 if let Expr::Literal(scalar, _) = expr {
402 return_none_if_utf8!(scalar);
403 if let Some(timestamp) = scalar_value_to_timestamp(scalar, None) {
404 init_range = init_range.or(&TimestampRange::single(timestamp))
405 } else {
406 return None;
409 }
410 }
411 }
412 Some(init_range)
413}
414
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use common_test_util::temp_dir::{TempDir, create_temp_dir};
    use datafusion::parquet::arrow::ArrowWriter;
    use datafusion_common::{Column, ScalarValue};
    use datafusion_expr::{BinaryExpr, Literal, Operator, col, lit};
    use datatypes::arrow::array::Int32Array;
    use datatypes::arrow::datatypes::{DataType, Field, Schema};
    use datatypes::arrow::record_batch::RecordBatch;
    use datatypes::arrow_array::StringArray;
    use parquet::arrow::ParquetRecordBatchStreamBuilder;
    use parquet::file::properties::WriterProperties;

    use super::*;
    use crate::predicate::stats::RowGroupPruningStatistics;

    /// Asserts that the time range built from `expr` for a millisecond column named `ts`
    /// equals `expect`.
    fn check_build_predicate(expr: Expr, expect: TimestampRange) {
        assert_eq!(
            expect,
            build_time_range_predicate("ts", TimeUnit::Millisecond, &[expr])
        );
    }

    #[test]
    fn test_gt() {
        // ts > 1ms
        check_build_predicate(
            col("ts").gt(lit(ScalarValue::TimestampMillisecond(Some(1), None))),
            TimestampRange::from_start(Timestamp::new_millisecond(2)),
        );

        // 1ms > ts
        check_build_predicate(
            lit(ScalarValue::TimestampMillisecond(Some(1), None)).gt(col("ts")),
            TimestampRange::until_end(Timestamp::new_millisecond(1), false),
        );

        // 1001us > ts
        check_build_predicate(
            lit(ScalarValue::TimestampMicrosecond(Some(1001), None)).gt(col("ts")),
            TimestampRange::until_end(Timestamp::new_millisecond(1), true),
        );

        // ts > 1001us
        check_build_predicate(
            col("ts").gt(lit(ScalarValue::TimestampMicrosecond(Some(1001), None))),
            TimestampRange::from_start(Timestamp::new_millisecond(2)),
        );

        // 1s > ts
        check_build_predicate(
            lit(ScalarValue::TimestampSecond(Some(1), None)).gt(col("ts")),
            TimestampRange::until_end(Timestamp::new_millisecond(1000), false),
        );

        // ts > 1s
        check_build_predicate(
            col("ts").gt(lit(ScalarValue::TimestampSecond(Some(1), None))),
            TimestampRange::from_start(Timestamp::new_millisecond(1001)),
        );
    }

    #[test]
    fn test_gt_eq() {
        // ts >= 1ms
        check_build_predicate(
            col("ts").gt_eq(lit(ScalarValue::TimestampMillisecond(Some(1), None))),
            TimestampRange::from_start(Timestamp::new_millisecond(1)),
        );

        // 1ms >= ts
        check_build_predicate(
            lit(ScalarValue::TimestampMillisecond(Some(1), None)).gt_eq(col("ts")),
            TimestampRange::until_end(Timestamp::new_millisecond(1), true),
        );

        // 1001us >= ts
        check_build_predicate(
            lit(ScalarValue::TimestampMicrosecond(Some(1001), None)).gt_eq(col("ts")),
            TimestampRange::until_end(Timestamp::new_millisecond(1), true),
        );

        // ts >= 1001us
        check_build_predicate(
            col("ts").gt_eq(lit(ScalarValue::TimestampMicrosecond(Some(1001), None))),
            TimestampRange::from_start(Timestamp::new_millisecond(2)),
        );

        // 1s >= ts
        check_build_predicate(
            lit(ScalarValue::TimestampSecond(Some(1), None)).gt_eq(col("ts")),
            TimestampRange::until_end(Timestamp::new_millisecond(1000), true),
        );

        // ts >= 1s
        check_build_predicate(
            col("ts").gt_eq(lit(ScalarValue::TimestampSecond(Some(1), None))),
            TimestampRange::from_start(Timestamp::new_millisecond(1000)),
        );
    }

    #[test]
    fn test_lt() {
        // ts < 1ms
        check_build_predicate(
            col("ts").lt(lit(ScalarValue::TimestampMillisecond(Some(1), None))),
            TimestampRange::until_end(Timestamp::new_millisecond(1), false),
        );

        // 1ms < ts
        check_build_predicate(
            lit(ScalarValue::TimestampMillisecond(Some(1), None)).lt(col("ts")),
            TimestampRange::from_start(Timestamp::new_millisecond(2)),
        );

        // 1001us < ts
        check_build_predicate(
            lit(ScalarValue::TimestampMicrosecond(Some(1001), None)).lt(col("ts")),
            TimestampRange::from_start(Timestamp::new_millisecond(2)),
        );

        // ts < 1001us
        check_build_predicate(
            col("ts").lt(lit(ScalarValue::TimestampMicrosecond(Some(1001), None))),
            TimestampRange::until_end(Timestamp::new_millisecond(1), true),
        );

        // 1s < ts
        check_build_predicate(
            lit(ScalarValue::TimestampSecond(Some(1), None)).lt(col("ts")),
            TimestampRange::from_start(Timestamp::new_millisecond(1001)),
        );

        // ts < 1s
        check_build_predicate(
            col("ts").lt(lit(ScalarValue::TimestampSecond(Some(1), None))),
            TimestampRange::until_end(Timestamp::new_millisecond(1000), false),
        );
    }

    #[test]
    fn test_lt_eq() {
        // ts <= 1ms
        check_build_predicate(
            col("ts").lt_eq(lit(ScalarValue::TimestampMillisecond(Some(1), None))),
            TimestampRange::until_end(Timestamp::new_millisecond(1), true),
        );

        // 1ms <= ts
        check_build_predicate(
            lit(ScalarValue::TimestampMillisecond(Some(1), None)).lt_eq(col("ts")),
            TimestampRange::from_start(Timestamp::new_millisecond(1)),
        );

        // 1001us <= ts
        check_build_predicate(
            lit(ScalarValue::TimestampMicrosecond(Some(1001), None)).lt_eq(col("ts")),
            TimestampRange::from_start(Timestamp::new_millisecond(2)),
        );

        // ts <= 1001us
        check_build_predicate(
            col("ts").lt_eq(lit(ScalarValue::TimestampMicrosecond(Some(1001), None))),
            TimestampRange::until_end(Timestamp::new_millisecond(1), true),
        );

        // 1s <= ts
        check_build_predicate(
            lit(ScalarValue::TimestampSecond(Some(1), None)).lt_eq(col("ts")),
            TimestampRange::from_start(Timestamp::new_millisecond(1000)),
        );

        // ts <= 1s
        check_build_predicate(
            col("ts").lt_eq(lit(ScalarValue::TimestampSecond(Some(1), None))),
            TimestampRange::until_end(Timestamp::new_millisecond(1000), true),
        );
    }

    /// Writes a parquet file with `cnt` rows into `dir` and returns its path and schema.
    ///
    /// Row `i` holds `name = i.to_string()` and `cnt = i`. Rows are written in batches of
    /// 10 and the writer is configured with at most 10 rows per row group.
    async fn gen_test_parquet_file(dir: &TempDir, cnt: usize) -> (String, Arc<Schema>) {
        let path = dir
            .path()
            .join("test-prune.parquet")
            .to_string_lossy()
            .to_string();

        let name_field = Field::new("name", DataType::Utf8, true);
        let count_field = Field::new("cnt", DataType::Int32, true);
        let schema = Arc::new(Schema::new(vec![name_field, count_field]));

        let file = std::fs::OpenOptions::new()
            .write(true)
            .create(true)
            .truncate(true)
            .open(path.clone())
            .unwrap();

        let write_props = WriterProperties::builder()
            .set_max_row_group_row_count(Some(10))
            .build();
        let mut writer = ArrowWriter::try_new(file, schema.clone(), Some(write_props)).unwrap();

        for i in (0..cnt).step_by(10) {
            let name_array = Arc::new(StringArray::from(
                (i..(i + 10).min(cnt))
                    .map(|i| i.to_string())
                    .collect::<Vec<_>>(),
            )) as Arc<_>;
            let count_array = Arc::new(Int32Array::from(
                (i..(i + 10).min(cnt)).map(|i| i as i32).collect::<Vec<_>>(),
            )) as Arc<_>;
            let rb = RecordBatch::try_new(schema.clone(), vec![name_array, count_array]).unwrap();
            writer.write(&rb).unwrap();
        }
        let _ = writer.close().unwrap();
        (path, schema)
    }

    /// Generates a parquet file with `array_cnt` rows, prunes its row groups with `filters`
    /// and asserts that the per row group result equals `expect`.
    async fn assert_prune(array_cnt: usize, filters: Vec<Expr>, expect: Vec<bool>) {
        let dir = create_temp_dir("prune_parquet");
        let (path, arrow_schema) = gen_test_parquet_file(&dir, array_cnt).await;
        let schema = Arc::new(datatypes::schema::Schema::try_from(arrow_schema.clone()).unwrap());
        let arrow_predicate = Predicate::new(filters);
        let builder = ParquetRecordBatchStreamBuilder::new(
            tokio::fs::OpenOptions::new()
                .read(true)
                .open(path)
                .await
                .unwrap(),
        )
        .await
        .unwrap();
        let metadata = builder.metadata().clone();
        let row_groups = metadata.row_groups();

        let stats = RowGroupPruningStatistics::new(row_groups, &schema);
        let res = arrow_predicate.prune_with_stats(&stats, &arrow_schema);
        assert_eq!(expect, res);
    }

    /// Builds the single filter `cnt <op> max_val`.
    fn gen_predicate(max_val: i32, op: Operator) -> Vec<Expr> {
        vec![datafusion_expr::Expr::BinaryExpr(BinaryExpr {
            left: Box::new(datafusion_expr::Expr::Column(Column::from_name("cnt"))),
            op,
            right: Box::new(max_val.lit()),
        })]
    }

    #[tokio::test]
    async fn test_prune_empty() {
        assert_prune(3, vec![], vec![true]).await;
    }

    #[tokio::test]
    async fn test_prune_none_match() {
        // Only rows 0 and 1 exist, so `cnt > 3` prunes the single row group.
        let p = gen_predicate(3, Operator::Gt);
        assert_prune(2, p, vec![false]).await;
    }

    #[tokio::test]
    async fn test_prune_gt() {
        let p = gen_predicate(29, Operator::Gt);
        assert_prune(
            100,
            p,
            vec![
                false, false, false, true, true, true, true, true, true, true,
            ],
        )
        .await;
    }

    #[tokio::test]
    async fn test_prune_eq_expr() {
        let p = gen_predicate(30, Operator::Eq);
        assert_prune(40, p, vec![false, false, false, true]).await;
    }

    #[tokio::test]
    async fn test_prune_neq_expr() {
        let p = gen_predicate(30, Operator::NotEq);
        assert_prune(40, p, vec![true, true, true, true]).await;
    }

    #[tokio::test]
    async fn test_prune_gteq_expr() {
        let p = gen_predicate(29, Operator::GtEq);
        assert_prune(40, p, vec![false, false, true, true]).await;
    }

    #[tokio::test]
    async fn test_prune_lt_expr() {
        let p = gen_predicate(30, Operator::Lt);
        assert_prune(40, p, vec![true, true, true, false]).await;
    }

    #[tokio::test]
    async fn test_prune_lteq_expr() {
        let p = gen_predicate(30, Operator::LtEq);
        assert_prune(40, p, vec![true, true, true, true]).await;
    }

    #[tokio::test]
    async fn test_prune_between_expr() {
        // Row groups hold [0, 9], [10, 19], [20, 29] and [30, 39]; only the two middle
        // ones can contain a value between 15 and 25.
        let e = datafusion_expr::Expr::Column(Column::from_name("cnt")).between(15.lit(), 25.lit());
        assert_prune(40, vec![e], vec![false, true, true, false]).await;
    }

    #[tokio::test]
    async fn test_or() {
        // cnt > 30 OR cnt < 20
        let e = datafusion_expr::Expr::Column(Column::from_name("cnt"))
            .gt(30.lit())
            .or(datafusion_expr::Expr::Column(Column::from_name("cnt")).lt(20.lit()));
        assert_prune(40, vec![e], vec![true, true, false, true]).await;
    }

    #[tokio::test]
    async fn test_to_physical_expr() {
        let predicate = Predicate::new(vec![
            col("host").eq(lit("host_a")),
            col("ts").gt(lit(ScalarValue::TimestampMicrosecond(Some(123), None))),
        ]);

        // The schema has no `ts` column, so only the `host` filter can be converted.
        let schema = Arc::new(arrow::datatypes::Schema::new(vec![Field::new(
            "host",
            arrow::datatypes::DataType::Utf8,
            false,
        )]));

        let predicates = predicate.to_physical_exprs(&schema).unwrap();
        assert!(!predicates.is_empty());

        let physical_expr = Predicate::to_physical_expr(&col("host").eq(lit("host_a")), &schema);
        assert!(physical_expr.is_ok());
    }
}