mirror of https://github.com/GreptimeTeam/greptimedb.git
synced 2026-01-07 13:52:59 +00:00

feat: push all possible filters down to parquet exec (#1839)

* feat: push all possible filters down to parquet exec
* fix: project
* test: add ut for DatafusionArrowPredicate
* fix: according to CR comments

Cargo.lock (generated): 2 lines changed
@@ -9150,8 +9150,10 @@ dependencies = [
"common-test-util",
"common-time",
"criterion 0.3.6",
"datafusion",
"datafusion-common",
"datafusion-expr",
"datafusion-physical-expr",
"datatypes",
"futures",
"futures-util",

@@ -294,7 +294,7 @@ fn new_item_field(data_type: ArrowDataType) -> Field {
Field::new("item", data_type, false)
}

fn timestamp_to_scalar_value(unit: TimeUnit, val: Option<i64>) -> ScalarValue {
pub fn timestamp_to_scalar_value(unit: TimeUnit, val: Option<i64>) -> ScalarValue {
match unit {
TimeUnit::Second => ScalarValue::TimestampSecond(val, None),
TimeUnit::Millisecond => ScalarValue::TimestampMillisecond(val, None),

@@ -140,7 +140,7 @@ impl TimeRangeTester {
let _ = exec_selection(self.engine.clone(), sql).await;
let filters = self.table.get_filters().await;

let range = TimeRangePredicateBuilder::new("ts", &filters).build();
let range = TimeRangePredicateBuilder::new("ts", TimeUnit::Millisecond, &filters).build();
assert_eq!(expect, range);
}
}

@@ -23,6 +23,8 @@ common-time = { path = "../common/time" }
datatypes = { path = "../datatypes" }
datafusion-common.workspace = true
datafusion-expr.workspace = true
datafusion-physical-expr.workspace = true
datafusion.workspace = true
futures.workspace = true
futures-util.workspace = true
itertools.workspace = true

@@ -226,10 +226,16 @@ impl ChunkReaderBuilder {
reader_builder = reader_builder.push_batch_iter(iter);
}

let predicate = Predicate::try_new(
self.filters.clone(),
self.schema.store_schema().schema().clone(),
)
.context(error::BuildPredicateSnafu)?;

let read_opts = ReadOptions {
batch_size: self.iter_ctx.batch_size,
projected_schema: schema.clone(),
predicate: Predicate::new(self.filters.clone()),
predicate,
time_range: *time_range,
};
for file in &self.files_to_read {

@@ -270,7 +276,12 @@ impl ChunkReaderBuilder {
/// Build time range predicate from schema and filters.
pub fn build_time_range_predicate(&self) -> TimestampRange {
let Some(ts_col) = self.schema.user_schema().timestamp_column() else { return TimestampRange::min_to_max() };
TimeRangePredicateBuilder::new(&ts_col.name, &self.filters).build()
let unit = ts_col
.data_type
.as_timestamp()
.expect("Timestamp column must have timestamp-compatible type")
.unit();
TimeRangePredicateBuilder::new(&ts_col.name, unit, &self.filters).build()
}

/// Check if SST file's time range matches predicate.

@@ -13,8 +13,9 @@
// limitations under the License.

use common_query::logical_plan::{DfExpr, Expr};
use datafusion_common::ScalarValue;
use datafusion_expr::{BinaryExpr, Operator};
use common_time::timestamp::TimeUnit;
use datafusion_expr::Operator;
use datatypes::value::timestamp_to_scalar_value;

use crate::chunk::{ChunkReaderBuilder, ChunkReaderImpl};
use crate::error;

@@ -31,53 +32,84 @@ pub(crate) async fn build_sst_reader(
) -> error::Result<ChunkReaderImpl> {
// TODO(hl): Schemas in different SSTs may differ, thus we should infer
// timestamp column name from Parquet metadata.
let ts_col_name = schema
.user_schema()
.timestamp_column()
.unwrap()
.name
.clone();

// safety: Region schema's timestamp column must present
let ts_col = schema.user_schema().timestamp_column().unwrap();
let ts_col_unit = ts_col.data_type.as_timestamp().unwrap().unit();
let ts_col_name = ts_col.name.clone();

ChunkReaderBuilder::new(schema, sst_layer)
.pick_ssts(files)
.filters(vec![build_time_range_filter(
lower_sec_inclusive,
upper_sec_exclusive,
&ts_col_name,
)])
.filters(
build_time_range_filter(
lower_sec_inclusive,
upper_sec_exclusive,
&ts_col_name,
ts_col_unit,
)
.into_iter()
.collect(),
)
.build()
.await
}

fn build_time_range_filter(low_sec: i64, high_sec: i64, ts_col_name: &str) -> Expr {
let ts_col = Box::new(DfExpr::Column(datafusion_common::Column::from_name(
ts_col_name,
)));
let lower_bound_expr = Box::new(DfExpr::Literal(ScalarValue::TimestampSecond(
Some(low_sec),
None,
)));
/// Build time range filter expr from lower (inclusive) and upper bound(exclusive).
/// Returns `None` if time range overflows.
fn build_time_range_filter(
low_sec: i64,
high_sec: i64,
ts_col_name: &str,
ts_col_unit: TimeUnit,
) -> Option<Expr> {
debug_assert!(low_sec <= high_sec);
let ts_col = DfExpr::Column(datafusion_common::Column::from_name(ts_col_name));

let upper_bound_expr = Box::new(DfExpr::Literal(ScalarValue::TimestampSecond(
Some(high_sec),
None,
)));
// Converting seconds to whatever unit won't lose precision.
// Here only handles overflow.
let low_ts = common_time::Timestamp::new_second(low_sec)
.convert_to(ts_col_unit)
.map(|ts| ts.value());
let high_ts = common_time::Timestamp::new_second(high_sec)
.convert_to(ts_col_unit)
.map(|ts| ts.value());

let expr = DfExpr::BinaryExpr(BinaryExpr {
left: Box::new(DfExpr::BinaryExpr(BinaryExpr {
left: ts_col.clone(),
op: Operator::GtEq,
right: lower_bound_expr,
})),
op: Operator::And,
right: Box::new(DfExpr::BinaryExpr(BinaryExpr {
left: ts_col,
op: Operator::Lt,
right: upper_bound_expr,
})),
});
let expr = match (low_ts, high_ts) {
(Some(low), Some(high)) => {
let lower_bound_expr =
DfExpr::Literal(timestamp_to_scalar_value(ts_col_unit, Some(low)));
let upper_bound_expr =
DfExpr::Literal(timestamp_to_scalar_value(ts_col_unit, Some(high)));
Some(datafusion_expr::and(
datafusion_expr::binary_expr(ts_col.clone(), Operator::GtEq, lower_bound_expr),
datafusion_expr::binary_expr(ts_col, Operator::Lt, upper_bound_expr),
))
}

Expr::from(expr)
(Some(low), None) => {
let lower_bound_expr =
datafusion_expr::lit(timestamp_to_scalar_value(ts_col_unit, Some(low)));
Some(datafusion_expr::binary_expr(
ts_col,
Operator::GtEq,
lower_bound_expr,
))
}

(None, Some(high)) => {
let upper_bound_expr =
datafusion_expr::lit(timestamp_to_scalar_value(ts_col_unit, Some(high)));
Some(datafusion_expr::binary_expr(
ts_col,
Operator::Lt,
upper_bound_expr,
))
}

(None, None) => None,
};

expr.map(Expr::from)
}

#[cfg(test)]
@@ -490,4 +522,35 @@ mod tests {

assert_eq!(timestamps_in_outputs, timestamps_in_inputs);
}

#[test]
fn test_build_time_range_filter() {
assert!(build_time_range_filter(i64::MIN, i64::MAX, "ts", TimeUnit::Nanosecond).is_none());

assert_eq!(
Expr::from(datafusion_expr::binary_expr(
datafusion_expr::col("ts"),
Operator::Lt,
datafusion_expr::lit(timestamp_to_scalar_value(
TimeUnit::Nanosecond,
Some(TimeUnit::Second.factor() as i64 / TimeUnit::Nanosecond.factor() as i64)
))
)),
build_time_range_filter(i64::MIN, 1, "ts", TimeUnit::Nanosecond).unwrap()
);

assert_eq!(
Expr::from(datafusion_expr::binary_expr(
datafusion_expr::col("ts"),
Operator::GtEq,
datafusion_expr::lit(timestamp_to_scalar_value(
TimeUnit::Nanosecond,
Some(
2 * TimeUnit::Second.factor() as i64 / TimeUnit::Nanosecond.factor() as i64
)
))
)),
build_time_range_filter(2, i64::MAX, "ts", TimeUnit::Nanosecond).unwrap()
);
}
}

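Editor's note: a plain-arithmetic sketch of the overflow case the reworked build_time_range_filter guards against. This is an illustration only, not repository code; common_time's Timestamp::convert_to is assumed to report overflow by returning None, which is what the Option handling in the hunk above relies on.

fn seconds_to_unit(value_sec: i64, ticks_per_second: i64) -> Option<i64> {
    // e.g. 1_000 for milliseconds, 1_000_000_000 for nanoseconds.
    value_sec.checked_mul(ticks_per_second) // None on overflow
}

fn main() {
    // A normal bound converts losslessly from seconds to the column's unit...
    assert_eq!(seconds_to_unit(2, 1_000_000_000), Some(2_000_000_000));
    // ...but i64::MIN (or i64::MAX) seconds cannot be represented as nanoseconds,
    // so that side of the range is dropped from the generated filter, and the
    // whole filter becomes None when both sides overflow.
    assert_eq!(seconds_to_unit(i64::MIN, 1_000_000_000), None);
}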
@@ -522,6 +522,12 @@ pub enum Error {
source: ArrowError,
location: Location,
},

#[snafu(display("Failed to build scan predicate, source: {}", source))]
BuildPredicate {
source: table::error::Error,
location: Location,
},
}

pub type Result<T> = std::result::Result<T, Error>;
@@ -621,6 +627,7 @@ impl ErrorExt for Error {

TtlCalculation { source, .. } => source.status_code(),
ConvertColumnsToRows { .. } | SortArrays { .. } => StatusCode::Unexpected,
BuildPredicate { source, .. } => source.status_code(),
}
}

@@ -21,7 +21,9 @@ use arrow::compute::SortOptions;
use common_query::prelude::Expr;
use common_recordbatch::OrderOption;
use common_test_util::temp_dir::create_temp_dir;
use common_time::timestamp::TimeUnit;
use datafusion_common::Column;
use datatypes::value::timestamp_to_scalar_value;
use log_store::raft_engine::log_store::RaftEngineLogStore;
use store_api::storage::{FlushContext, FlushReason, OpenOptions, Region, ScanRequest};

@@ -404,7 +406,10 @@ async fn test_flush_and_query_empty() {
filters: vec![Expr::from(datafusion_expr::binary_expr(
DfExpr::Column(Column::from("timestamp")),
datafusion_expr::Operator::GtEq,
datafusion_expr::lit(20000),
datafusion_expr::lit(timestamp_to_scalar_value(
TimeUnit::Millisecond,
Some(20000),
)),
))],
output_ordering: Some(vec![OrderOption {
name: "timestamp".to_string(),

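Editor's note: the test filter above now pushes a millisecond timestamp literal instead of a bare integer, so the pushed-down physical expression compares values of matching logical types. A tiny sketch of the difference, assuming DataFusion's ScalarValue as used in the hunk; this is an illustration, not part of the change.

use datafusion_common::ScalarValue;

fn main() {
    // A bare integer literal carries an integer logical type.
    let plain = ScalarValue::Int64(Some(20000));
    // The value produced by timestamp_to_scalar_value(TimeUnit::Millisecond, Some(20000)).
    let typed = ScalarValue::TimestampMillisecond(Some(20000), None);
    // Same raw number, different logical types.
    assert_ne!(plain, typed);
}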
@@ -13,6 +13,7 @@
// limitations under the License.

pub(crate) mod parquet;
mod pruning;
mod stream_writer;

use std::collections::HashMap;

@@ -18,12 +18,6 @@ use std::collections::HashMap;
use std::pin::Pin;
use std::sync::Arc;

use arrow::datatypes::DataType;
use arrow_array::types::Int64Type;
use arrow_array::{
Array, PrimitiveArray, TimestampMicrosecondArray, TimestampMillisecondArray,
TimestampNanosecondArray, TimestampSecondArray,
};
use async_compat::CompatExt;
use async_stream::try_stream;
use async_trait::async_trait;
@@ -31,19 +25,16 @@ use common_telemetry::{debug, error};
use common_time::range::TimestampRange;
use common_time::timestamp::TimeUnit;
use common_time::Timestamp;
use datatypes::arrow::array::BooleanArray;
use datatypes::arrow::error::ArrowError;
use datatypes::arrow::record_batch::RecordBatch;
use datatypes::prelude::ConcreteDataType;
use futures_util::{Stream, StreamExt, TryStreamExt};
use object_store::ObjectStore;
use parquet::arrow::arrow_reader::{ArrowPredicate, RowFilter};
use parquet::arrow::{ParquetRecordBatchStreamBuilder, ProjectionMask};
use parquet::basic::{Compression, Encoding, ZstdLevel};
use parquet::file::metadata::KeyValue;
use parquet::file::properties::WriterProperties;
use parquet::format::FileMetaData;
use parquet::schema::types::{ColumnPath, SchemaDescriptor};
use parquet::schema::types::ColumnPath;
use snafu::{OptionExt, ResultExt};
use store_api::storage::consts::SEQUENCE_COLUMN_NAME;
use table::predicate::Predicate;
@@ -54,6 +45,7 @@ use crate::read::{Batch, BatchReader};
use crate::schema::compat::ReadAdapter;
use crate::schema::{ProjectedSchemaRef, StoreSchema};
use crate::sst;
use crate::sst::pruning::build_row_filter;
use crate::sst::stream_writer::BufferedWriter;
use crate::sst::{FileHandle, Source, SstInfo};

@@ -277,10 +269,7 @@ impl ParquetReader {

let pruned_row_groups = self
.predicate
.prune_row_groups(
store_schema.schema().clone(),
builder.metadata().row_groups(),
)
.prune_row_groups(builder.metadata().row_groups())
.into_iter()
.enumerate()
.filter_map(|(idx, valid)| if valid { Some(idx) } else { None })
@@ -288,15 +277,18 @@

let parquet_schema_desc = builder.metadata().file_metadata().schema_descr_ptr();

let projection = ProjectionMask::roots(&parquet_schema_desc, adapter.fields_to_read());
let projection_mask = ProjectionMask::roots(&parquet_schema_desc, adapter.fields_to_read());
let mut builder = builder
.with_projection(projection)
.with_projection(projection_mask.clone())
.with_row_groups(pruned_row_groups);

// if time range row filter is present, we can push down the filter to reduce rows to scan.
if let Some(row_filter) =
build_time_range_row_filter(self.time_range, &store_schema, &parquet_schema_desc)
{
if let Some(row_filter) = build_row_filter(
self.time_range,
&self.predicate,
&store_schema,
&parquet_schema_desc,
projection_mask,
) {
builder = builder.with_row_filter(row_filter);
}

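Editor's note: a row survives the parquet reader's RowFilter only when every ArrowPredicate in its list selects it, so the pushed-down predicates compose conjunctively with the time range filter. A minimal sketch of that semantics (stated as an assumption for illustration, using the same arrow boolean kernel the filters above already rely on):

use arrow::array::BooleanArray;
use arrow::compute::and;

fn main() {
    // Masks produced by two predicates over the same rows.
    let time_range_mask = BooleanArray::from(vec![true, true, false]);
    let pushed_down_mask = BooleanArray::from(vec![true, false, false]);
    // Only rows selected by both filters keep being decoded.
    let combined = and(&time_range_mask, &pushed_down_mask).unwrap();
    assert_eq!(combined, BooleanArray::from(vec![true, false, false]));
}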
@@ -314,198 +306,6 @@ impl ParquetReader {
}
}

/// Builds time range row filter.
fn build_time_range_row_filter(
time_range: TimestampRange,
store_schema: &Arc<StoreSchema>,
schema_desc: &SchemaDescriptor,
) -> Option<RowFilter> {
let ts_col_idx = store_schema.timestamp_index();

let ts_col = store_schema.columns().get(ts_col_idx)?;

let ts_col_unit = match &ts_col.desc.data_type {
ConcreteDataType::Int64(_) => TimeUnit::Millisecond,
ConcreteDataType::Timestamp(ts_type) => ts_type.unit(),
_ => unreachable!(),
};

let projection = ProjectionMask::roots(schema_desc, vec![ts_col_idx]);

// checks if converting time range unit into ts col unit will result into rounding error.
if time_unit_lossy(&time_range, ts_col_unit) {
let filter = RowFilter::new(vec![Box::new(PlainTimestampRowFilter::new(
time_range, projection,
))]);
return Some(filter);
}

// If any of the conversion overflows, we cannot use arrow's computation method, instead
// we resort to plain filter that compares timestamp with given range, less efficient,
// but simpler.
// TODO(hl): If the range is gt_eq/lt, we also use PlainTimestampRowFilter, but these cases
// can also use arrow's gt_eq_scalar/lt_scalar methods.
let row_filter = if let (Some(lower), Some(upper)) = (
time_range
.start()
.and_then(|s| s.convert_to(ts_col_unit))
.map(|t| t.value()),
time_range
.end()
.and_then(|s| s.convert_to(ts_col_unit))
.map(|t| t.value()),
) {
Box::new(FastTimestampRowFilter::new(projection, lower, upper)) as _
} else {
Box::new(PlainTimestampRowFilter::new(time_range, projection)) as _
};
let filter = RowFilter::new(vec![row_filter]);
Some(filter)
}

fn time_unit_lossy(range: &TimestampRange, ts_col_unit: TimeUnit) -> bool {
range
.start()
.map(|start| start.unit().factor() < ts_col_unit.factor())
.unwrap_or(false)
|| range
.end()
.map(|end| end.unit().factor() < ts_col_unit.factor())
.unwrap_or(false)
}

/// `FastTimestampRowFilter` is used to filter rows within given timestamp range when reading
/// row groups from parquet files, while avoids fetching all columns from SSTs file.
struct FastTimestampRowFilter {
lower_bound: i64,
upper_bound: i64,
projection: ProjectionMask,
}

impl FastTimestampRowFilter {
fn new(projection: ProjectionMask, lower_bound: i64, upper_bound: i64) -> Self {
Self {
lower_bound,
upper_bound,
projection,
}
}
}

impl ArrowPredicate for FastTimestampRowFilter {
fn projection(&self) -> &ProjectionMask {
&self.projection
}

/// Selects the rows matching given time range.
fn evaluate(&mut self, batch: RecordBatch) -> std::result::Result<BooleanArray, ArrowError> {
// the projection has only timestamp column, so we can safely take the first column in batch.
let ts_col = batch.column(0);

macro_rules! downcast_and_compute {
($typ: ty) => {
{
let ts_col = ts_col
.as_any()
.downcast_ref::<$typ>()
.unwrap(); // safety: we've checked the data type of timestamp column.
let left = arrow::compute::gt_eq_scalar(ts_col, self.lower_bound)?;
let right = arrow::compute::lt_scalar(ts_col, self.upper_bound)?;
arrow::compute::and(&left, &right)
}
};
}

match ts_col.data_type() {
DataType::Timestamp(unit, _) => match unit {
arrow::datatypes::TimeUnit::Second => {
downcast_and_compute!(TimestampSecondArray)
}
arrow::datatypes::TimeUnit::Millisecond => {
downcast_and_compute!(TimestampMillisecondArray)
}
arrow::datatypes::TimeUnit::Microsecond => {
downcast_and_compute!(TimestampMicrosecondArray)
}
arrow::datatypes::TimeUnit::Nanosecond => {
downcast_and_compute!(TimestampNanosecondArray)
}
},
DataType::Int64 => downcast_and_compute!(PrimitiveArray<Int64Type>),
_ => {
unreachable!()
}
}
}
}

/// [PlainTimestampRowFilter] iterates each element in timestamp column, build a [Timestamp] struct
/// and checks if given time range contains the timestamp.
struct PlainTimestampRowFilter {
time_range: TimestampRange,
projection: ProjectionMask,
}

impl PlainTimestampRowFilter {
fn new(time_range: TimestampRange, projection: ProjectionMask) -> Self {
Self {
time_range,
projection,
}
}
}

impl ArrowPredicate for PlainTimestampRowFilter {
fn projection(&self) -> &ProjectionMask {
&self.projection
}

fn evaluate(&mut self, batch: RecordBatch) -> std::result::Result<BooleanArray, ArrowError> {
// the projection has only timestamp column, so we can safely take the first column in batch.
let ts_col = batch.column(0);

macro_rules! downcast_and_compute {
($array_ty: ty, $unit: ident) => {{
let ts_col = ts_col
.as_any()
.downcast_ref::<$array_ty>()
.unwrap(); // safety: we've checked the data type of timestamp column.
Ok(BooleanArray::from_iter(ts_col.iter().map(|ts| {
ts.map(|val| {
Timestamp::new(val, TimeUnit::$unit)
}).map(|ts| {
self.time_range.contains(&ts)
})
})))
}};
}

match ts_col.data_type() {
DataType::Timestamp(unit, _) => match unit {
arrow::datatypes::TimeUnit::Second => {
downcast_and_compute!(TimestampSecondArray, Second)
}
arrow::datatypes::TimeUnit::Millisecond => {
downcast_and_compute!(TimestampMillisecondArray, Millisecond)
}
arrow::datatypes::TimeUnit::Microsecond => {
downcast_and_compute!(TimestampMicrosecondArray, Microsecond)
}
arrow::datatypes::TimeUnit::Nanosecond => {
downcast_and_compute!(TimestampNanosecondArray, Nanosecond)
}
},
DataType::Int64 => {
downcast_and_compute!(PrimitiveArray<Int64Type>, Millisecond)
}
_ => {
unreachable!()
}
}
}
}

pub type SendableChunkStream = Pin<Box<dyn Stream<Item = Result<RecordBatch>> + Send>>;

pub struct ChunkStream {
@@ -740,11 +540,12 @@ mod tests {
let operator = create_object_store(dir.path().to_str().unwrap());

let projected_schema = Arc::new(ProjectedSchema::new(schema, Some(vec![1])).unwrap());
let user_schema = projected_schema.projected_user_schema().clone();
let reader = ParquetReader::new(
sst_file_handle,
operator,
projected_schema,
Predicate::empty(),
Predicate::empty(user_schema),
TimestampRange::min_to_max(),
);

@@ -826,11 +627,12 @@ mod tests {
let operator = create_object_store(dir.path().to_str().unwrap());

let projected_schema = Arc::new(ProjectedSchema::new(schema, Some(vec![1])).unwrap());
let user_schema = projected_schema.projected_user_schema().clone();
let reader = ParquetReader::new(
file_handle,
operator,
projected_schema,
Predicate::empty(),
Predicate::empty(user_schema),
TimestampRange::min_to_max(),
);

@@ -854,8 +656,14 @@
range: TimestampRange,
expect: Vec<i64>,
) {
let reader =
ParquetReader::new(file_handle, object_store, schema, Predicate::empty(), range);
let store_schema = schema.schema_to_read().clone();
let reader = ParquetReader::new(
file_handle,
object_store,
schema,
Predicate::empty(store_schema.schema().clone()),
range,
);
let mut stream = reader.chunk_stream().await.unwrap();
let result = stream.next_batch().await;

@@ -981,16 +789,6 @@ mod tests {
.await;
}

fn check_unit_lossy(range_unit: TimeUnit, col_unit: TimeUnit, expect: bool) {
assert_eq!(
expect,
time_unit_lossy(
&TimestampRange::with_unit(0, 1, range_unit).unwrap(),
col_unit
)
)
}

#[tokio::test]
async fn test_write_empty_file() {
common_telemetry::init_default_ut_logging();
@@ -1014,28 +812,4 @@

// The file should not exist when no row has been written.
assert!(!object_store.is_exist(sst_file_name).await.unwrap());
}

#[test]
fn test_time_unit_lossy() {
// converting a range with unit second to millisecond will not cause rounding error
check_unit_lossy(TimeUnit::Second, TimeUnit::Second, false);
check_unit_lossy(TimeUnit::Second, TimeUnit::Millisecond, false);
check_unit_lossy(TimeUnit::Second, TimeUnit::Microsecond, false);
check_unit_lossy(TimeUnit::Second, TimeUnit::Nanosecond, false);

check_unit_lossy(TimeUnit::Millisecond, TimeUnit::Second, true);
check_unit_lossy(TimeUnit::Millisecond, TimeUnit::Millisecond, false);
check_unit_lossy(TimeUnit::Millisecond, TimeUnit::Microsecond, false);
check_unit_lossy(TimeUnit::Millisecond, TimeUnit::Nanosecond, false);

check_unit_lossy(TimeUnit::Microsecond, TimeUnit::Second, true);
check_unit_lossy(TimeUnit::Microsecond, TimeUnit::Millisecond, true);
check_unit_lossy(TimeUnit::Microsecond, TimeUnit::Microsecond, false);
check_unit_lossy(TimeUnit::Microsecond, TimeUnit::Nanosecond, false);

check_unit_lossy(TimeUnit::Nanosecond, TimeUnit::Second, true);
check_unit_lossy(TimeUnit::Nanosecond, TimeUnit::Millisecond, true);
check_unit_lossy(TimeUnit::Nanosecond, TimeUnit::Microsecond, true);
check_unit_lossy(TimeUnit::Nanosecond, TimeUnit::Nanosecond, false);
}
}

src/storage/src/sst/pruning.rs (new file, 408 lines)
@@ -0,0 +1,408 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::sync::Arc;

use arrow::array::{
PrimitiveArray, TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray,
TimestampSecondArray,
};
use arrow::datatypes::{DataType, Int64Type};
use arrow::error::ArrowError;
use arrow_array::{Array, BooleanArray, RecordBatch};
use common_time::range::TimestampRange;
use common_time::timestamp::TimeUnit;
use common_time::Timestamp;
use datafusion::physical_plan::PhysicalExpr;
use datatypes::prelude::ConcreteDataType;
use parquet::arrow::arrow_reader::{ArrowPredicate, RowFilter};
use parquet::arrow::ProjectionMask;
use parquet::schema::types::SchemaDescriptor;
use table::predicate::Predicate;

use crate::error;
use crate::schema::StoreSchema;

/// Builds row filters according to predicates.
pub(crate) fn build_row_filter(
time_range: TimestampRange,
predicate: &Predicate,
store_schema: &Arc<StoreSchema>,
schema_desc: &SchemaDescriptor,
projection_mask: ProjectionMask,
) -> Option<RowFilter> {
let ts_col_idx = store_schema.timestamp_index();
let ts_col = store_schema.columns().get(ts_col_idx)?;
let ts_col_unit = match &ts_col.desc.data_type {
ConcreteDataType::Int64(_) => TimeUnit::Millisecond,
ConcreteDataType::Timestamp(ts_type) => ts_type.unit(),
_ => unreachable!(),
};

let ts_col_projection = ProjectionMask::roots(schema_desc, vec![ts_col_idx]);

// checks if converting time range unit into ts col unit will result into rounding error.
if time_unit_lossy(&time_range, ts_col_unit) {
let filter = RowFilter::new(vec![Box::new(PlainTimestampRowFilter::new(
time_range,
ts_col_projection,
))]);
return Some(filter);
}

// If any of the conversion overflows, we cannot use arrow's computation method, instead
// we resort to plain filter that compares timestamp with given range, less efficient,
// but simpler.
// TODO(hl): If the range is gt_eq/lt, we also use PlainTimestampRowFilter, but these cases
// can also use arrow's gt_eq_scalar/lt_scalar methods.
let time_range_row_filter = if let (Some(lower), Some(upper)) = (
time_range
.start()
.and_then(|s| s.convert_to(ts_col_unit))
.map(|t| t.value()),
time_range
.end()
.and_then(|s| s.convert_to(ts_col_unit))
.map(|t| t.value()),
) {
Box::new(FastTimestampRowFilter::new(ts_col_projection, lower, upper)) as _
} else {
Box::new(PlainTimestampRowFilter::new(time_range, ts_col_projection)) as _
};
let mut predicates = vec![time_range_row_filter];
if let Ok(datafusion_filters) = predicate_to_row_filter(predicate, projection_mask) {
predicates.extend(datafusion_filters);
}
let filter = RowFilter::new(predicates);
Some(filter)
}

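Editor's note: a schematic of the filter-selection logic in build_row_filter above, restated in plain Rust with toy types (not the repository's). A lossy unit conversion forces the per-element Plain filter, both bounds converting cleanly allows the vectorized Fast filter, and anything else falls back to Plain.

#[derive(Debug, PartialEq)]
enum Chosen { Fast, Plain }

fn choose(lossy: bool, lower: Option<i64>, upper: Option<i64>) -> Chosen {
    if lossy {
        // Unit conversion would round, so compare full Timestamp values instead.
        return Chosen::Plain;
    }
    match (lower, upper) {
        // Both bounds representable in the column's unit: vectorized comparison.
        (Some(_), Some(_)) => Chosen::Fast,
        // Conversion overflowed on one side: generic per-element filter.
        _ => Chosen::Plain,
    }
}

fn main() {
    assert_eq!(choose(false, Some(0), Some(10)), Chosen::Fast);
    assert_eq!(choose(true, Some(0), Some(10)), Chosen::Plain);
    assert_eq!(choose(false, None, Some(10)), Chosen::Plain);
}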
fn predicate_to_row_filter(
predicate: &Predicate,
projection_mask: ProjectionMask,
) -> error::Result<Vec<Box<dyn ArrowPredicate>>> {
let mut datafusion_predicates = Vec::with_capacity(predicate.exprs().len());
for expr in predicate.exprs() {
datafusion_predicates.push(Box::new(DatafusionArrowPredicate {
projection_mask: projection_mask.clone(),
physical_expr: expr.clone(),
}) as _);
}

Ok(datafusion_predicates)
}

#[derive(Debug)]
struct DatafusionArrowPredicate {
projection_mask: ProjectionMask,
physical_expr: Arc<dyn PhysicalExpr>,
}

impl ArrowPredicate for DatafusionArrowPredicate {
fn projection(&self) -> &ProjectionMask {
&self.projection_mask
}

fn evaluate(&mut self, batch: RecordBatch) -> Result<BooleanArray, ArrowError> {
match self
.physical_expr
.evaluate(&batch)
.map(|v| v.into_array(batch.num_rows()))
{
Ok(array) => {
let bool_arr = array
.as_any()
.downcast_ref::<BooleanArray>()
.ok_or(ArrowError::CastError(
"Physical expr evaluated res is not a boolean array".to_string(),
))?
.clone();
Ok(bool_arr)
}
Err(e) => Err(ArrowError::ComputeError(format!(
"Error evaluating filter predicate: {e:?}"
))),
}
}
}

fn time_unit_lossy(range: &TimestampRange, ts_col_unit: TimeUnit) -> bool {
range
.start()
.map(|start| start.unit().factor() < ts_col_unit.factor())
.unwrap_or(false)
|| range
.end()
.map(|end| end.unit().factor() < ts_col_unit.factor())
.unwrap_or(false)
}

/// `FastTimestampRowFilter` is used to filter rows within given timestamp range when reading
/// row groups from parquet files, while avoids fetching all columns from SSTs file.
struct FastTimestampRowFilter {
lower_bound: i64,
upper_bound: i64,
projection: ProjectionMask,
}

impl FastTimestampRowFilter {
fn new(projection: ProjectionMask, lower_bound: i64, upper_bound: i64) -> Self {
Self {
lower_bound,
upper_bound,
projection,
}
}
}

impl ArrowPredicate for FastTimestampRowFilter {
fn projection(&self) -> &ProjectionMask {
&self.projection
}

/// Selects the rows matching given time range.
fn evaluate(&mut self, batch: RecordBatch) -> Result<BooleanArray, ArrowError> {
// the projection has only timestamp column, so we can safely take the first column in batch.
let ts_col = batch.column(0);

macro_rules! downcast_and_compute {
($typ: ty) => {
{
let ts_col = ts_col
.as_any()
.downcast_ref::<$typ>()
.unwrap(); // safety: we've checked the data type of timestamp column.
let left = arrow::compute::gt_eq_scalar(ts_col, self.lower_bound)?;
let right = arrow::compute::lt_scalar(ts_col, self.upper_bound)?;
arrow::compute::and(&left, &right)
}
};
}

match ts_col.data_type() {
DataType::Timestamp(unit, _) => match unit {
arrow::datatypes::TimeUnit::Second => {
downcast_and_compute!(TimestampSecondArray)
}
arrow::datatypes::TimeUnit::Millisecond => {
downcast_and_compute!(TimestampMillisecondArray)
}
arrow::datatypes::TimeUnit::Microsecond => {
downcast_and_compute!(TimestampMicrosecondArray)
}
arrow::datatypes::TimeUnit::Nanosecond => {
downcast_and_compute!(TimestampNanosecondArray)
}
},
DataType::Int64 => downcast_and_compute!(PrimitiveArray<Int64Type>),
_ => {
unreachable!()
}
}
}
}

/// [PlainTimestampRowFilter] iterates each element in timestamp column, build a [Timestamp] struct
/// and checks if given time range contains the timestamp.
struct PlainTimestampRowFilter {
time_range: TimestampRange,
projection: ProjectionMask,
}

impl PlainTimestampRowFilter {
fn new(time_range: TimestampRange, projection: ProjectionMask) -> Self {
Self {
time_range,
projection,
}
}
}

impl ArrowPredicate for PlainTimestampRowFilter {
fn projection(&self) -> &ProjectionMask {
&self.projection
}

fn evaluate(&mut self, batch: RecordBatch) -> Result<BooleanArray, ArrowError> {
// the projection has only timestamp column, so we can safely take the first column in batch.
let ts_col = batch.column(0);

macro_rules! downcast_and_compute {
($array_ty: ty, $unit: ident) => {{
let ts_col = ts_col
.as_any()
.downcast_ref::<$array_ty>()
.unwrap(); // safety: we've checked the data type of timestamp column.
Ok(BooleanArray::from_iter(ts_col.iter().map(|ts| {
ts.map(|val| {
Timestamp::new(val, TimeUnit::$unit)
}).map(|ts| {
self.time_range.contains(&ts)
})
})))
}};
}

match ts_col.data_type() {
DataType::Timestamp(unit, _) => match unit {
arrow::datatypes::TimeUnit::Second => {
downcast_and_compute!(TimestampSecondArray, Second)
}
arrow::datatypes::TimeUnit::Millisecond => {
downcast_and_compute!(TimestampMillisecondArray, Millisecond)
}
arrow::datatypes::TimeUnit::Microsecond => {
downcast_and_compute!(TimestampMicrosecondArray, Microsecond)
}
arrow::datatypes::TimeUnit::Nanosecond => {
downcast_and_compute!(TimestampNanosecondArray, Nanosecond)
}
},
DataType::Int64 => {
downcast_and_compute!(PrimitiveArray<Int64Type>, Millisecond)
}
_ => {
unreachable!()
}
}
}
}

#[cfg(test)]
mod tests {
use arrow_array::ArrayRef;
use datafusion_common::ToDFSchema;
use datafusion_expr::Operator;
use datafusion_physical_expr::create_physical_expr;
use datafusion_physical_expr::execution_props::ExecutionProps;
use datatypes::arrow_array::StringArray;
use datatypes::schema::{ColumnSchema, Schema};
use datatypes::value::timestamp_to_scalar_value;
use parquet::arrow::arrow_to_parquet_schema;

use super::*;

fn check_unit_lossy(range_unit: TimeUnit, col_unit: TimeUnit, expect: bool) {
assert_eq!(
expect,
time_unit_lossy(
&TimestampRange::with_unit(0, 1, range_unit).unwrap(),
col_unit
)
)
}

#[test]
fn test_time_unit_lossy() {
// converting a range with unit second to millisecond will not cause rounding error
check_unit_lossy(TimeUnit::Second, TimeUnit::Second, false);
check_unit_lossy(TimeUnit::Second, TimeUnit::Millisecond, false);
check_unit_lossy(TimeUnit::Second, TimeUnit::Microsecond, false);
check_unit_lossy(TimeUnit::Second, TimeUnit::Nanosecond, false);

check_unit_lossy(TimeUnit::Millisecond, TimeUnit::Second, true);
check_unit_lossy(TimeUnit::Millisecond, TimeUnit::Millisecond, false);
check_unit_lossy(TimeUnit::Millisecond, TimeUnit::Microsecond, false);
check_unit_lossy(TimeUnit::Millisecond, TimeUnit::Nanosecond, false);

check_unit_lossy(TimeUnit::Microsecond, TimeUnit::Second, true);
check_unit_lossy(TimeUnit::Microsecond, TimeUnit::Millisecond, true);
check_unit_lossy(TimeUnit::Microsecond, TimeUnit::Microsecond, false);
check_unit_lossy(TimeUnit::Microsecond, TimeUnit::Nanosecond, false);

check_unit_lossy(TimeUnit::Nanosecond, TimeUnit::Second, true);
check_unit_lossy(TimeUnit::Nanosecond, TimeUnit::Millisecond, true);
check_unit_lossy(TimeUnit::Nanosecond, TimeUnit::Microsecond, true);
check_unit_lossy(TimeUnit::Nanosecond, TimeUnit::Nanosecond, false);
}

fn check_arrow_predicate(
schema: Schema,
expr: datafusion_expr::Expr,
columns: Vec<ArrayRef>,
expected: Vec<Option<bool>>,
) {
let arrow_schema = schema.arrow_schema();
let df_schema = arrow_schema.clone().to_dfschema().unwrap();
let physical_expr = create_physical_expr(
&expr,
&df_schema,
arrow_schema.as_ref(),
&ExecutionProps::default(),
)
.unwrap();
let parquet_schema = arrow_to_parquet_schema(arrow_schema).unwrap();
let mut predicate = DatafusionArrowPredicate {
physical_expr,
projection_mask: ProjectionMask::roots(&parquet_schema, vec![0, 1]),
};

let batch = arrow_array::RecordBatch::try_new(arrow_schema.clone(), columns).unwrap();

let res = predicate.evaluate(batch).unwrap();
assert_eq!(expected, res.iter().collect::<Vec<_>>());
}

#[test]
fn test_datafusion_predicate() {
let schema = Schema::new(vec![
ColumnSchema::new(
"ts",
ConcreteDataType::timestamp_datatype(TimeUnit::Nanosecond),
false,
),
ColumnSchema::new("name", ConcreteDataType::string_datatype(), true),
]);

let expr = datafusion_expr::and(
datafusion_expr::binary_expr(
datafusion_expr::col("ts"),
Operator::GtEq,
datafusion_expr::lit(timestamp_to_scalar_value(TimeUnit::Nanosecond, Some(10))),
),
datafusion_expr::binary_expr(
datafusion_expr::col("name"),
Operator::Lt,
datafusion_expr::lit("Bob"),
),
);

let ts_arr = Arc::new(TimestampNanosecondArray::from(vec![9, 11])) as Arc<_>;
let name_arr = Arc::new(StringArray::from(vec![Some("Alice"), Some("Charlie")])) as Arc<_>;

let columns = vec![ts_arr, name_arr];
check_arrow_predicate(
schema.clone(),
expr,
columns.clone(),
vec![Some(false), Some(false)],
);

let expr = datafusion_expr::and(
datafusion_expr::binary_expr(
datafusion_expr::col("ts"),
Operator::Lt,
datafusion_expr::lit(timestamp_to_scalar_value(TimeUnit::Nanosecond, Some(10))),
),
datafusion_expr::binary_expr(
datafusion_expr::col("name"),
Operator::Lt,
datafusion_expr::lit("Bob"),
),
);

check_arrow_predicate(schema, expr, columns, vec![Some(true), Some(false)]);
}
}

@@ -12,66 +12,94 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::sync::Arc;

use common_query::logical_plan::{DfExpr, Expr};
use common_telemetry::{error, warn};
use common_time::range::TimestampRange;
use common_time::timestamp::TimeUnit;
use common_time::Timestamp;
use datafusion::parquet::file::metadata::RowGroupMetaData;
use datafusion::physical_optimizer::pruning::PruningPredicate;
use datafusion_common::ToDFSchema;
use datafusion_expr::expr::InList;
use datafusion_expr::{Between, BinaryExpr, Operator};
use datafusion_physical_expr::create_physical_expr;
use datafusion_physical_expr::execution_props::ExecutionProps;
use datafusion_physical_expr::{create_physical_expr, PhysicalExpr};
use datatypes::schema::SchemaRef;
use datatypes::value::scalar_value_to_timestamp;
use snafu::ResultExt;

use crate::error;
use crate::predicate::stats::RowGroupPruningStatistics;

mod stats;

#[derive(Default, Clone)]
#[derive(Clone)]
pub struct Predicate {
exprs: Vec<Expr>,
/// The schema of underlying storage.
schema: SchemaRef,
/// Physical expressions of this predicate.
exprs: Vec<Arc<dyn PhysicalExpr>>,
}

impl Predicate {
pub fn new(exprs: Vec<Expr>) -> Self {
Self { exprs }
}

pub fn empty() -> Self {
Self { exprs: vec![] }
}

pub fn prune_row_groups(
&self,
schema: SchemaRef,
row_groups: &[RowGroupMetaData],
) -> Vec<bool> {
let mut res = vec![true; row_groups.len()];
let arrow_schema = (*schema.arrow_schema()).clone();
let df_schema = arrow_schema.clone().to_dfschema_ref();
let df_schema = match df_schema {
Ok(x) => x,
Err(e) => {
warn!("Failed to create Datafusion schema when trying to prune row groups, error: {e}");
return res;
}
};
/// Creates a new `Predicate` by converting logical exprs to physical exprs that can be
/// evaluated against record batches.
/// Returns error when failed to convert exprs.
pub fn try_new(exprs: Vec<Expr>, schema: SchemaRef) -> error::Result<Self> {
let arrow_schema = schema.arrow_schema();
let df_schema = arrow_schema
.clone()
.to_dfschema_ref()
.context(error::DatafusionSnafu)?;

// TODO(hl): `execution_props` provides variables required by evaluation.
// we may reuse the `execution_props` from `SessionState` once we support
// registering variables.
let execution_props = &ExecutionProps::new();

let physical_exprs = exprs
.iter()
.map(|expr| {
create_physical_expr(
expr.df_expr(),
df_schema.as_ref(),
arrow_schema.as_ref(),
execution_props,
)
})
.collect::<Result<_, _>>()
.context(error::DatafusionSnafu)?;

Ok(Self {
schema,
exprs: physical_exprs,
})
}

#[inline]
pub fn exprs(&self) -> &[Arc<dyn PhysicalExpr>] {
&self.exprs
}

/// Builds an empty predicate from given schema.
pub fn empty(schema: SchemaRef) -> Self {
Self {
schema,
exprs: vec![],
}
}

/// Evaluates the predicate against row group metadata.
/// Returns a vector of boolean values, among which `false` means the row group can be skipped.
pub fn prune_row_groups(&self, row_groups: &[RowGroupMetaData]) -> Vec<bool> {
let mut res = vec![true; row_groups.len()];
let arrow_schema = self.schema.arrow_schema();
for expr in &self.exprs {
match create_physical_expr(
expr.df_expr(),
df_schema.as_ref(),
arrow_schema.as_ref(),
execution_props,
)
.and_then(|expr| PruningPredicate::try_new(expr, arrow_schema.clone()))
{
match PruningPredicate::try_new(expr.clone(), arrow_schema.clone()) {
Ok(p) => {
let stat = RowGroupPruningStatistics::new(row_groups, &schema);
let stat = RowGroupPruningStatistics::new(row_groups, &self.schema);
match p.prune(&stat) {
Ok(r) => {
for (curr_val, res) in r.into_iter().zip(res.iter_mut()) {
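Editor's note: for intuition, a toy illustration (plain Rust, not DataFusion's PruningPredicate) of what prune_row_groups returns: one boolean per row group, where false means the group's min/max statistics prove the filter cannot match and the group can be skipped.

// For a filter like `cnt > literal`, a row group can be skipped when its
// maximum `cnt` statistic is not greater than the literal.
struct GroupStats { max_cnt: i32 }

fn prune_gt(groups: &[GroupStats], literal: i32) -> Vec<bool> {
    // true = the group may contain matching rows and must be read.
    groups.iter().map(|g| g.max_cnt > literal).collect()
}

fn main() {
    let groups = [GroupStats { max_cnt: 9 }, GroupStats { max_cnt: 19 }];
    assert_eq!(prune_gt(&groups, 12), vec![false, true]);
}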
@@ -94,15 +122,19 @@ impl Predicate {

// tests for `TimeRangePredicateBuilder` locates in src/query/tests/time_range_filter_test.rs
// since it requires query engine to convert sql to filters.
/// `TimeRangePredicateBuilder` extracts time range from logical exprs to facilitate fast
/// time range pruning.
pub struct TimeRangePredicateBuilder<'a> {
ts_col_name: &'a str,
ts_col_unit: TimeUnit,
filters: &'a [Expr],
}

impl<'a> TimeRangePredicateBuilder<'a> {
pub fn new(ts_col_name: &'a str, filters: &'a [Expr]) -> Self {
pub fn new(ts_col_name: &'a str, ts_col_unit: TimeUnit, filters: &'a [Expr]) -> Self {
Self {
ts_col_name,
ts_col_unit,
filters,
}
}
@@ -149,18 +181,23 @@ impl<'a> TimeRangePredicateBuilder<'a> {
match op {
Operator::Eq => self
.get_timestamp_filter(left, right)
.and_then(|ts| ts.convert_to(self.ts_col_unit))
.map(TimestampRange::single),
Operator::Lt => self
.get_timestamp_filter(left, right)
.and_then(|ts| ts.convert_to_ceil(self.ts_col_unit))
.map(|ts| TimestampRange::until_end(ts, false)),
Operator::LtEq => self
.get_timestamp_filter(left, right)
.and_then(|ts| ts.convert_to_ceil(self.ts_col_unit))
.map(|ts| TimestampRange::until_end(ts, true)),
Operator::Gt => self
.get_timestamp_filter(left, right)
.and_then(|ts| ts.convert_to(self.ts_col_unit))
.map(TimestampRange::from_start),
Operator::GtEq => self
.get_timestamp_filter(left, right)
.and_then(|ts| ts.convert_to(self.ts_col_unit))
.map(TimestampRange::from_start),
Operator::And => {
// instead of return none when failed to extract time range from left/right, we unwrap the none into
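Editor's note: a plain-arithmetic sketch of why upper bounds go through convert_to_ceil while lower bounds use convert_to: rounding an upper bound down could exclude rows that actually match, so it is rounded up instead. This is illustrative integer math only; the helper names come from the hunk above.

fn floor_div(v: i64, d: i64) -> i64 { v.div_euclid(d) }
fn ceil_div(v: i64, d: i64) -> i64 { -((-v).div_euclid(d)) }

fn main() {
    // Filter `ts < 1500 ms` evaluated against a column stored in seconds:
    // flooring the bound to 1 s would wrongly drop a row at ts = 1.2 s,
    // so the bound is rounded up to 2 s, which only widens the range.
    assert_eq!(floor_div(1500, 1000), 1);
    assert_eq!(ceil_div(1500, 1000), 2);
}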
@@ -231,8 +268,10 @@ impl<'a> TimeRangePredicateBuilder<'a> {

match (low, high) {
(DfExpr::Literal(low), DfExpr::Literal(high)) => {
let low_opt = scalar_value_to_timestamp(low);
let high_opt = scalar_value_to_timestamp(high);
let low_opt =
scalar_value_to_timestamp(low).and_then(|ts| ts.convert_to(self.ts_col_unit));
let high_opt = scalar_value_to_timestamp(high)
.and_then(|ts| ts.convert_to_ceil(self.ts_col_unit));
Some(TimestampRange::new_inclusive(low_opt, high_opt))
}
_ => None,
@@ -329,10 +368,15 @@ mod tests {
(path, schema)
}

async fn assert_prune(array_cnt: usize, predicate: Predicate, expect: Vec<bool>) {
async fn assert_prune(
array_cnt: usize,
filters: Vec<common_query::logical_plan::Expr>,
expect: Vec<bool>,
) {
let dir = create_temp_dir("prune_parquet");
let (path, schema) = gen_test_parquet_file(&dir, array_cnt).await;
let schema = Arc::new(datatypes::schema::Schema::try_from(schema).unwrap());
let arrow_predicate = Predicate::try_new(filters, schema.clone()).unwrap();
let builder = ParquetRecordBatchStreamBuilder::new(
tokio::fs::OpenOptions::new()
.read(true)
@@ -344,23 +388,23 @@ mod tests {
.unwrap();
let metadata = builder.metadata().clone();
let row_groups = metadata.row_groups();
let res = predicate.prune_row_groups(schema, row_groups);
let res = arrow_predicate.prune_row_groups(row_groups);
assert_eq!(expect, res);
}

fn gen_predicate(max_val: i32, op: Operator) -> Predicate {
Predicate::new(vec![common_query::logical_plan::Expr::from(
Expr::BinaryExpr(BinaryExpr {
fn gen_predicate(max_val: i32, op: Operator) -> Vec<common_query::logical_plan::Expr> {
vec![common_query::logical_plan::Expr::from(Expr::BinaryExpr(
BinaryExpr {
left: Box::new(Expr::Column(Column::from_name("cnt"))),
op,
right: Box::new(Expr::Literal(ScalarValue::Int32(Some(max_val)))),
}),
)])
},
))]
}

#[tokio::test]
async fn test_prune_empty() {
assert_prune(3, Predicate::empty(), vec![true]).await;
assert_prune(3, vec![], vec![true]).await;
}

#[tokio::test]
@@ -424,7 +468,6 @@ mod tests {
let e = Expr::Column(Column::from_name("cnt"))
.gt(30.lit())
.or(Expr::Column(Column::from_name("cnt")).lt(20.lit()));
let p = Predicate::new(vec![e.into()]);
assert_prune(40, p, vec![true, true, false, true]).await;
assert_prune(40, vec![e.into()], vec![true, true, false, true]).await;
}
}