1use std::sync::Arc;
16
17use common_telemetry::{error, warn};
18use common_time::Timestamp;
19use common_time::range::TimestampRange;
20use common_time::timestamp::TimeUnit;
21use datafusion::common::ScalarValue;
22use datafusion::physical_optimizer::pruning::PruningPredicate;
23use datafusion_common::ToDFSchema;
24use datafusion_common::pruning::PruningStatistics;
25use datafusion_expr::expr::{Expr, InList};
26use datafusion_expr::{Between, BinaryExpr, Operator};
27use datafusion_physical_expr::execution_props::ExecutionProps;
28use datafusion_physical_expr::expressions::DynamicFilterPhysicalExpr;
29use datafusion_physical_expr::{PhysicalExpr, create_physical_expr};
30use datatypes::arrow;
31use datatypes::value::scalar_value_to_timestamp;
32use snafu::ResultExt;
33
34use crate::error;
35
36#[cfg(test)]
37mod stats;
38
/// Early-returns `None` from the enclosing function when the `&ScalarValue` bound to `$lit` is a
/// `ScalarValue::Utf8`, logging a warning first.
///
/// Time-range extraction does not expect string literals here, so one showing up is reported as
/// a probable bug rather than being interpreted as a timestamp.
macro_rules! return_none_if_utf8 {
    ($lit: ident) => {
        if let ScalarValue::Utf8(_) = $lit {
            warn!(
                "Unexpected ScalarValue::Utf8 in time range predicate: {:?}. Maybe it's an implicit bug, please report it to https://github.com/GreptimeTeam/greptimedb/issues",
                $lit
            );
            return None;
        }
    };
}
54
/// Filters used for pruning: static logical [`Expr`]s plus dynamic physical filters.
///
/// Both lists are stored behind `Arc`, so the derived `Clone` only bumps reference counts.
#[derive(Debug, Clone)]
pub struct Predicate {
    /// Logical filter expressions; converted to physical expressions against a schema on demand.
    exprs: Arc<Vec<Expr>>,
    /// Dynamic filters; their `current()` expression is read every time the physical
    /// expressions are built (see `dyn_filter_phy_exprs`).
    dyn_filters: Arc<Vec<DynamicFilterPhysicalExpr>>,
}
65
66impl Predicate {
67 pub fn new(exprs: Vec<Expr>) -> Self {
71 Self {
72 exprs: Arc::new(exprs),
73 dyn_filters: Arc::new(vec![]),
74 }
75 }
76
77 pub fn with_dyn_filters(self, dyn_filters: Arc<Vec<DynamicFilterPhysicalExpr>>) -> Self {
79 Self {
80 exprs: self.exprs,
81 dyn_filters,
82 }
83 }
84
85 pub fn exprs(&self) -> &[Expr] {
87 &self.exprs
88 }
89
90 pub fn dyn_filters(&self) -> &Arc<Vec<DynamicFilterPhysicalExpr>> {
93 &self.dyn_filters
94 }
95
96 pub fn dyn_filter_phy_exprs(&self) -> error::Result<Vec<Arc<dyn PhysicalExpr>>> {
99 self.dyn_filters
100 .iter()
101 .map(|e| e.current())
102 .collect::<Result<Vec<_>, _>>()
103 .context(error::DatafusionSnafu)
104 }
105
106 pub fn to_physical_exprs(
108 &self,
109 schema: &arrow::datatypes::SchemaRef,
110 ) -> error::Result<Vec<Arc<dyn PhysicalExpr>>> {
111 let df_schema = schema
112 .clone()
113 .to_dfschema_ref()
114 .context(error::DatafusionSnafu)?;
115
116 let execution_props = &ExecutionProps::new();
120
121 let dyn_filters = self.dyn_filter_phy_exprs()?;
122
123 Ok(self
124 .exprs
125 .iter()
126 .filter_map(|expr| create_physical_expr(expr, df_schema.as_ref(), execution_props).ok())
127 .chain(dyn_filters)
128 .collect::<Vec<_>>())
129 }
130
131 pub fn prune_with_stats<S: PruningStatistics>(
134 &self,
135 stats: &S,
136 schema: &arrow::datatypes::SchemaRef,
137 ) -> Vec<bool> {
138 let mut res = vec![true; stats.num_containers()];
139 let physical_exprs = match self.to_physical_exprs(schema) {
140 Ok(expr) => expr,
141 Err(e) => {
142 warn!(e; "Failed to build physical expr from predicates: {:?}", &self.exprs);
143 return res;
144 }
145 };
146
147 for expr in &physical_exprs {
148 match PruningPredicate::try_new(expr.clone(), schema.clone()) {
149 Ok(p) => match p.prune(stats) {
150 Ok(r) => {
151 for (curr_val, res) in r.into_iter().zip(res.iter_mut()) {
152 *res &= curr_val
153 }
154 }
155 Err(e) => {
156 warn!(e; "Failed to prune row groups");
157 }
158 },
159 Err(e) => {
160 error!(e; "Failed to create predicate for expr");
161 }
162 }
163 }
164 res
165 }
166}
167
168pub fn build_time_range_predicate(
173 ts_col_name: &str,
174 ts_col_unit: TimeUnit,
175 filters: &[Expr],
176) -> TimestampRange {
177 let mut res = TimestampRange::min_to_max();
178 for expr in filters {
179 if let Some(range) = extract_time_range_from_expr(ts_col_name, ts_col_unit, expr) {
180 res = res.and(&range);
181 }
182 }
183 res
184}
185
186fn extract_time_range_from_expr(
189 ts_col_name: &str,
190 ts_col_unit: TimeUnit,
191 expr: &Expr,
192) -> Option<TimestampRange> {
193 match expr {
194 Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
195 extract_from_binary_expr(ts_col_name, ts_col_unit, left, op, right)
196 }
197 Expr::Between(Between {
198 expr,
199 negated,
200 low,
201 high,
202 }) => extract_from_between_expr(ts_col_name, ts_col_unit, expr, negated, low, high),
203 Expr::InList(InList {
204 expr,
205 list,
206 negated,
207 }) => extract_from_in_list_expr(ts_col_name, expr, *negated, list),
208 _ => None,
209 }
210}
211
212fn extract_from_binary_expr(
213 ts_col_name: &str,
214 ts_col_unit: TimeUnit,
215 left: &Expr,
216 op: &Operator,
217 right: &Expr,
218) -> Option<TimestampRange> {
219 match op {
220 Operator::Eq => get_timestamp_filter(ts_col_name, left, right)
221 .and_then(|(ts, _)| ts.convert_to(ts_col_unit))
222 .map(TimestampRange::single),
223 Operator::Lt => {
224 let (ts, reverse) = get_timestamp_filter(ts_col_name, left, right)?;
225 if reverse {
226 let ts_val = ts.convert_to(ts_col_unit)?.value();
228 Some(TimestampRange::from_start(Timestamp::new(
229 ts_val + 1,
230 ts_col_unit,
231 )))
232 } else {
233 ts.convert_to_ceil(ts_col_unit)
235 .map(|ts| TimestampRange::until_end(ts, false))
236 }
237 }
238 Operator::LtEq => {
239 let (ts, reverse) = get_timestamp_filter(ts_col_name, left, right)?;
240 if reverse {
241 ts.convert_to_ceil(ts_col_unit)
243 .map(TimestampRange::from_start)
244 } else {
245 ts.convert_to(ts_col_unit)
247 .map(|ts| TimestampRange::until_end(ts, true))
248 }
249 }
250 Operator::Gt => {
251 let (ts, reverse) = get_timestamp_filter(ts_col_name, left, right)?;
252 if reverse {
253 ts.convert_to_ceil(ts_col_unit)
255 .map(|t| TimestampRange::until_end(t, false))
256 } else {
257 let ts_val = ts.convert_to(ts_col_unit)?.value();
259 Some(TimestampRange::from_start(Timestamp::new(
260 ts_val + 1,
261 ts_col_unit,
262 )))
263 }
264 }
265 Operator::GtEq => {
266 let (ts, reverse) = get_timestamp_filter(ts_col_name, left, right)?;
267 if reverse {
268 ts.convert_to(ts_col_unit)
270 .map(|t| TimestampRange::until_end(t, true))
271 } else {
272 ts.convert_to_ceil(ts_col_unit)
274 .map(TimestampRange::from_start)
275 }
276 }
277 Operator::And => {
278 let left = extract_time_range_from_expr(ts_col_name, ts_col_unit, left)
281 .unwrap_or_else(TimestampRange::min_to_max);
282 let right = extract_time_range_from_expr(ts_col_name, ts_col_unit, right)
283 .unwrap_or_else(TimestampRange::min_to_max);
284 Some(left.and(&right))
285 }
286 Operator::Or => {
287 let left = extract_time_range_from_expr(ts_col_name, ts_col_unit, left)?;
288 let right = extract_time_range_from_expr(ts_col_name, ts_col_unit, right)?;
289 Some(left.or(&right))
290 }
291 _ => None,
292 }
293}
294
295fn get_timestamp_filter(ts_col_name: &str, left: &Expr, right: &Expr) -> Option<(Timestamp, bool)> {
296 let (col, lit, reverse) = match (left, right) {
297 (Expr::Column(column), Expr::Literal(scalar, _)) => (column, scalar, false),
298 (Expr::Literal(scalar, _), Expr::Column(column)) => (column, scalar, true),
299 _ => {
300 return None;
301 }
302 };
303 if col.name != ts_col_name {
304 return None;
305 }
306
307 return_none_if_utf8!(lit);
308 scalar_value_to_timestamp(lit, None).map(|t| (t, reverse))
309}
310
311fn extract_from_between_expr(
312 ts_col_name: &str,
313 ts_col_unit: TimeUnit,
314 expr: &Expr,
315 negated: &bool,
316 low: &Expr,
317 high: &Expr,
318) -> Option<TimestampRange> {
319 let Expr::Column(col) = expr else {
320 return None;
321 };
322 if col.name != ts_col_name {
323 return None;
324 }
325
326 if *negated {
327 return None;
328 }
329
330 match (low, high) {
331 (Expr::Literal(low, _), Expr::Literal(high, _)) => {
332 return_none_if_utf8!(low);
333 return_none_if_utf8!(high);
334
335 let low_opt =
336 scalar_value_to_timestamp(low, None).and_then(|ts| ts.convert_to(ts_col_unit));
337 let high_opt = scalar_value_to_timestamp(high, None)
338 .and_then(|ts| ts.convert_to_ceil(ts_col_unit));
339 Some(TimestampRange::new_inclusive(low_opt, high_opt))
340 }
341 _ => None,
342 }
343}
344
345fn extract_from_in_list_expr(
347 ts_col_name: &str,
348 expr: &Expr,
349 negated: bool,
350 list: &[Expr],
351) -> Option<TimestampRange> {
352 if negated {
353 return None;
354 }
355 let Expr::Column(col) = expr else {
356 return None;
357 };
358 if col.name != ts_col_name {
359 return None;
360 }
361
362 if list.is_empty() {
363 return Some(TimestampRange::empty());
364 }
365 let mut init_range = TimestampRange::empty();
366 for expr in list {
367 if let Expr::Literal(scalar, _) = expr {
368 return_none_if_utf8!(scalar);
369 if let Some(timestamp) = scalar_value_to_timestamp(scalar, None) {
370 init_range = init_range.or(&TimestampRange::single(timestamp))
371 } else {
372 return None;
375 }
376 }
377 }
378 Some(init_range)
379}
380
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use common_test_util::temp_dir::{TempDir, create_temp_dir};
    use datafusion::parquet::arrow::ArrowWriter;
    use datafusion_common::{Column, ScalarValue};
    use datafusion_expr::{BinaryExpr, Literal, Operator, col, lit};
    use datatypes::arrow::array::Int32Array;
    use datatypes::arrow::datatypes::{DataType, Field, Schema};
    use datatypes::arrow::record_batch::RecordBatch;
    use datatypes::arrow_array::StringArray;
    use parquet::arrow::ParquetRecordBatchStreamBuilder;
    use parquet::file::properties::WriterProperties;

    use super::*;
    use crate::predicate::stats::RowGroupPruningStatistics;

    /// Asserts that `expr` on a millisecond `ts` column yields exactly `expect`.
    fn check_build_predicate(expr: Expr, expect: TimestampRange) {
        assert_eq!(
            expect,
            build_time_range_predicate("ts", TimeUnit::Millisecond, &[expr])
        );
    }

    #[test]
    fn test_gt() {
        // col > lit, same unit
        check_build_predicate(
            col("ts").gt(lit(ScalarValue::TimestampMillisecond(Some(1), None))),
            TimestampRange::from_start(Timestamp::new_millisecond(2)),
        );

        // lit > col, same unit
        check_build_predicate(
            lit(ScalarValue::TimestampMillisecond(Some(1), None)).gt(col("ts")),
            TimestampRange::until_end(Timestamp::new_millisecond(1), false),
        );

        // lit > col, finer literal unit
        check_build_predicate(
            lit(ScalarValue::TimestampMicrosecond(Some(1001), None)).gt(col("ts")),
            TimestampRange::until_end(Timestamp::new_millisecond(1), true),
        );

        // col > lit, finer literal unit
        check_build_predicate(
            col("ts").gt(lit(ScalarValue::TimestampMicrosecond(Some(1001), None))),
            TimestampRange::from_start(Timestamp::new_millisecond(2)),
        );

        // lit > col, coarser literal unit
        check_build_predicate(
            lit(ScalarValue::TimestampSecond(Some(1), None)).gt(col("ts")),
            TimestampRange::until_end(Timestamp::new_millisecond(1000), false),
        );

        // col > lit, coarser literal unit
        check_build_predicate(
            col("ts").gt(lit(ScalarValue::TimestampSecond(Some(1), None))),
            TimestampRange::from_start(Timestamp::new_millisecond(1001)),
        );
    }

    #[test]
    fn test_gt_eq() {
        // col >= lit, same unit
        check_build_predicate(
            col("ts").gt_eq(lit(ScalarValue::TimestampMillisecond(Some(1), None))),
            TimestampRange::from_start(Timestamp::new_millisecond(1)),
        );

        // lit >= col, same unit
        check_build_predicate(
            lit(ScalarValue::TimestampMillisecond(Some(1), None)).gt_eq(col("ts")),
            TimestampRange::until_end(Timestamp::new_millisecond(1), true),
        );

        // lit >= col, finer literal unit
        check_build_predicate(
            lit(ScalarValue::TimestampMicrosecond(Some(1001), None)).gt_eq(col("ts")),
            TimestampRange::until_end(Timestamp::new_millisecond(1), true),
        );

        // col >= lit, finer literal unit
        check_build_predicate(
            col("ts").gt_eq(lit(ScalarValue::TimestampMicrosecond(Some(1001), None))),
            TimestampRange::from_start(Timestamp::new_millisecond(2)),
        );

        // lit >= col, coarser literal unit
        check_build_predicate(
            lit(ScalarValue::TimestampSecond(Some(1), None)).gt_eq(col("ts")),
            TimestampRange::until_end(Timestamp::new_millisecond(1000), true),
        );

        // col >= lit, coarser literal unit
        check_build_predicate(
            col("ts").gt_eq(lit(ScalarValue::TimestampSecond(Some(1), None))),
            TimestampRange::from_start(Timestamp::new_millisecond(1000)),
        );
    }

    #[test]
    fn test_lt() {
        // col < lit, same unit
        check_build_predicate(
            col("ts").lt(lit(ScalarValue::TimestampMillisecond(Some(1), None))),
            TimestampRange::until_end(Timestamp::new_millisecond(1), false),
        );

        // lit < col, same unit
        check_build_predicate(
            lit(ScalarValue::TimestampMillisecond(Some(1), None)).lt(col("ts")),
            TimestampRange::from_start(Timestamp::new_millisecond(2)),
        );

        // lit < col, finer literal unit
        check_build_predicate(
            lit(ScalarValue::TimestampMicrosecond(Some(1001), None)).lt(col("ts")),
            TimestampRange::from_start(Timestamp::new_millisecond(2)),
        );

        // col < lit, finer literal unit
        check_build_predicate(
            col("ts").lt(lit(ScalarValue::TimestampMicrosecond(Some(1001), None))),
            TimestampRange::until_end(Timestamp::new_millisecond(1), true),
        );

        // lit < col, coarser literal unit
        check_build_predicate(
            lit(ScalarValue::TimestampSecond(Some(1), None)).lt(col("ts")),
            TimestampRange::from_start(Timestamp::new_millisecond(1001)),
        );

        // col < lit, coarser literal unit
        check_build_predicate(
            col("ts").lt(lit(ScalarValue::TimestampSecond(Some(1), None))),
            TimestampRange::until_end(Timestamp::new_millisecond(1000), false),
        );
    }

    #[test]
    fn test_lt_eq() {
        // col <= lit, same unit
        check_build_predicate(
            col("ts").lt_eq(lit(ScalarValue::TimestampMillisecond(Some(1), None))),
            TimestampRange::until_end(Timestamp::new_millisecond(1), true),
        );

        // lit <= col, same unit
        check_build_predicate(
            lit(ScalarValue::TimestampMillisecond(Some(1), None)).lt_eq(col("ts")),
            TimestampRange::from_start(Timestamp::new_millisecond(1)),
        );

        // lit <= col, finer literal unit
        check_build_predicate(
            lit(ScalarValue::TimestampMicrosecond(Some(1001), None)).lt_eq(col("ts")),
            TimestampRange::from_start(Timestamp::new_millisecond(2)),
        );

        // col <= lit, finer literal unit
        check_build_predicate(
            col("ts").lt_eq(lit(ScalarValue::TimestampMicrosecond(Some(1001), None))),
            TimestampRange::until_end(Timestamp::new_millisecond(1), true),
        );

        // lit <= col, coarser literal unit
        check_build_predicate(
            lit(ScalarValue::TimestampSecond(Some(1), None)).lt_eq(col("ts")),
            TimestampRange::from_start(Timestamp::new_millisecond(1000)),
        );

        // col <= lit, coarser literal unit
        check_build_predicate(
            col("ts").lt_eq(lit(ScalarValue::TimestampSecond(Some(1), None))),
            TimestampRange::until_end(Timestamp::new_millisecond(1000), true),
        );
    }

    /// Writes a parquet file with `cnt` rows of `(name: Utf8, cnt: Int32)`, where row `i` holds
    /// `(i.to_string(), i)`. Rows are grouped into row groups of 10.
    async fn gen_test_parquet_file(dir: &TempDir, cnt: usize) -> (String, Arc<Schema>) {
        let path = dir
            .path()
            .join("test-prune.parquet")
            .to_string_lossy()
            .to_string();

        let name_field = Field::new("name", DataType::Utf8, true);
        let count_field = Field::new("cnt", DataType::Int32, true);
        let schema = Arc::new(Schema::new(vec![name_field, count_field]));

        let file = std::fs::OpenOptions::new()
            .write(true)
            .create(true)
            .truncate(true)
            .open(path.clone())
            .unwrap();

        let write_props = WriterProperties::builder()
            .set_max_row_group_size(10)
            .build();
        let mut writer = ArrowWriter::try_new(file, schema.clone(), Some(write_props)).unwrap();

        for i in (0..cnt).step_by(10) {
            let name_array = Arc::new(StringArray::from(
                (i..(i + 10).min(cnt))
                    .map(|i| i.to_string())
                    .collect::<Vec<_>>(),
            )) as Arc<_>;
            let count_array = Arc::new(Int32Array::from(
                (i..(i + 10).min(cnt)).map(|i| i as i32).collect::<Vec<_>>(),
            )) as Arc<_>;
            let rb = RecordBatch::try_new(schema.clone(), vec![name_array, count_array]).unwrap();
            writer.write(&rb).unwrap();
        }
        let _ = writer.close().unwrap();
        (path, schema)
    }

    /// Generates a file with `array_cnt` rows, prunes its row groups with `filters`, and checks
    /// the keep-mask equals `expect`.
    async fn assert_prune(array_cnt: usize, filters: Vec<Expr>, expect: Vec<bool>) {
        let dir = create_temp_dir("prune_parquet");
        let (path, arrow_schema) = gen_test_parquet_file(&dir, array_cnt).await;
        let schema = Arc::new(datatypes::schema::Schema::try_from(arrow_schema.clone()).unwrap());
        let arrow_predicate = Predicate::new(filters);
        let builder = ParquetRecordBatchStreamBuilder::new(
            tokio::fs::OpenOptions::new()
                .read(true)
                .open(path)
                .await
                .unwrap(),
        )
        .await
        .unwrap();
        let metadata = builder.metadata().clone();
        let row_groups = metadata.row_groups();

        let stats = RowGroupPruningStatistics::new(row_groups, &schema);
        let res = arrow_predicate.prune_with_stats(&stats, &arrow_schema);
        assert_eq!(expect, res);
    }

    /// Builds the single filter `cnt <op> max_val`.
    fn gen_predicate(max_val: i32, op: Operator) -> Vec<Expr> {
        vec![datafusion_expr::Expr::BinaryExpr(BinaryExpr {
            left: Box::new(datafusion_expr::Expr::Column(Column::from_name("cnt"))),
            op,
            right: Box::new(max_val.lit()),
        })]
    }

    #[tokio::test]
    async fn test_prune_empty() {
        assert_prune(3, vec![], vec![true]).await;
    }

    #[tokio::test]
    async fn test_prune_all_match() {
        // The only row group holds cnt in [0, 1], so `cnt > 3` prunes it.
        let p = gen_predicate(3, Operator::Gt);
        assert_prune(2, p, vec![false]).await;
    }

    #[tokio::test]
    async fn test_prune_gt() {
        let p = gen_predicate(29, Operator::Gt);
        assert_prune(
            100,
            p,
            vec![
                false, false, false, true, true, true, true, true, true, true,
            ],
        )
        .await;
    }

    #[tokio::test]
    async fn test_prune_eq_expr() {
        let p = gen_predicate(30, Operator::Eq);
        assert_prune(40, p, vec![false, false, false, true]).await;
    }

    #[tokio::test]
    async fn test_prune_neq_expr() {
        let p = gen_predicate(30, Operator::NotEq);
        assert_prune(40, p, vec![true, true, true, true]).await;
    }

    #[tokio::test]
    async fn test_prune_gteq_expr() {
        let p = gen_predicate(29, Operator::GtEq);
        assert_prune(40, p, vec![false, false, true, true]).await;
    }

    #[tokio::test]
    async fn test_prune_lt_expr() {
        let p = gen_predicate(30, Operator::Lt);
        assert_prune(40, p, vec![true, true, true, false]).await;
    }

    #[tokio::test]
    async fn test_prune_lteq_expr() {
        let p = gen_predicate(30, Operator::LtEq);
        assert_prune(40, p, vec![true, true, true, true]).await;
    }

    #[tokio::test]
    async fn test_prune_between_expr() {
        // Row groups hold cnt in [0,9], [10,19], [20,29], [30,39]; only the middle two
        // overlap [15, 25].
        let p = vec![col("cnt").between(lit(15), lit(25))];
        assert_prune(40, p, vec![false, true, true, false]).await;
    }

    #[tokio::test]
    async fn test_or() {
        // cnt > 30 OR cnt < 20
        let e = datafusion_expr::Expr::Column(Column::from_name("cnt"))
            .gt(30.lit())
            .or(datafusion_expr::Expr::Column(Column::from_name("cnt")).lt(20.lit()));
        assert_prune(40, vec![e], vec![true, true, false, true]).await;
    }

    #[tokio::test]
    async fn test_to_physical_expr() {
        // The `ts` filter references a column missing from the schema and is skipped, but the
        // `host` filter still converts.
        let predicate = Predicate::new(vec![
            col("host").eq(lit("host_a")),
            col("ts").gt(lit(ScalarValue::TimestampMicrosecond(Some(123), None))),
        ]);

        let schema = Arc::new(arrow::datatypes::Schema::new(vec![Field::new(
            "host",
            arrow::datatypes::DataType::Utf8,
            false,
        )]));

        let predicates = predicate.to_physical_exprs(&schema).unwrap();
        assert!(!predicates.is_empty());
    }
}
716}