feat: push all possible filters down to parquet exec (#1839)

* feat: push all possible filters down to parquet exec

* fix: project

* test: add ut for DatafusionArrowPredicate

* fix: according to CR comments
Author: Lei, HUANG
Date: 2023-06-28 20:14:37 +08:00
Committed by: GitHub
Parent: bc33fdc8ef
Commit: 559d1f73a2
12 changed files with 658 additions and 342 deletions
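Before the per-file diffs, a minimal sketch (not code from this commit) of what pushing filters down to the parquet reader means here: in addition to pruning whole row groups, a `RowFilter` built from the query predicates is attached to the reader builder, so rows that fail the predicate are dropped during decoding. The synchronous reader is used for brevity; the storage code below uses the async `ParquetRecordBatchStreamBuilder`.

use std::fs::File;

use parquet::arrow::arrow_reader::{ParquetRecordBatchReaderBuilder, RowFilter};

// Attach a pre-built row filter to a parquet reader; only rows passing the
// filter are materialized into record batches.
fn read_with_pushdown(file: File, row_filter: RowFilter) -> Result<(), Box<dyn std::error::Error>> {
    let reader = ParquetRecordBatchReaderBuilder::try_new(file)?
        .with_row_filter(row_filter)
        .build()?;
    for batch in reader {
        let _batch = batch?; // rows rejected by the filter never reach here
    }
    Ok(())
}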

Cargo.lock

@@ -9150,8 +9150,10 @@ dependencies = [
"common-test-util",
"common-time",
"criterion 0.3.6",
"datafusion",
"datafusion-common",
"datafusion-expr",
"datafusion-physical-expr",
"datatypes",
"futures",
"futures-util",


@@ -294,7 +294,7 @@ fn new_item_field(data_type: ArrowDataType) -> Field {
Field::new("item", data_type, false)
}
fn timestamp_to_scalar_value(unit: TimeUnit, val: Option<i64>) -> ScalarValue {
pub fn timestamp_to_scalar_value(unit: TimeUnit, val: Option<i64>) -> ScalarValue {
match unit {
TimeUnit::Second => ScalarValue::TimestampSecond(val, None),
TimeUnit::Millisecond => ScalarValue::TimestampMillisecond(val, None),
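The hunk above makes `timestamp_to_scalar_value` public so that filter literals can be built in the timestamp column's own unit. A small illustrative sketch (not part of the diff) of the call pattern used throughout this commit:

use common_time::timestamp::TimeUnit;
use datafusion_expr::{Expr, Operator};
use datatypes::value::timestamp_to_scalar_value;

// Builds `col >= literal`, where the literal carries the given time unit so the
// comparison in the pushed-down filter is unit-correct.
fn ts_gte_literal(col: &str, unit: TimeUnit, val: i64) -> Expr {
    datafusion_expr::binary_expr(
        datafusion_expr::col(col),
        Operator::GtEq,
        datafusion_expr::lit(timestamp_to_scalar_value(unit, Some(val))),
    )
}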


@@ -140,7 +140,7 @@ impl TimeRangeTester {
let _ = exec_selection(self.engine.clone(), sql).await;
let filters = self.table.get_filters().await;
let range = TimeRangePredicateBuilder::new("ts", &filters).build();
let range = TimeRangePredicateBuilder::new("ts", TimeUnit::Millisecond, &filters).build();
assert_eq!(expect, range);
}
}


@@ -23,6 +23,8 @@ common-time = { path = "../common/time" }
datatypes = { path = "../datatypes" }
datafusion-common.workspace = true
datafusion-expr.workspace = true
datafusion-physical-expr.workspace = true
datafusion.workspace = true
futures.workspace = true
futures-util.workspace = true
itertools.workspace = true


@@ -226,10 +226,16 @@ impl ChunkReaderBuilder {
reader_builder = reader_builder.push_batch_iter(iter);
}
let predicate = Predicate::try_new(
self.filters.clone(),
self.schema.store_schema().schema().clone(),
)
.context(error::BuildPredicateSnafu)?;
let read_opts = ReadOptions {
batch_size: self.iter_ctx.batch_size,
projected_schema: schema.clone(),
predicate: Predicate::new(self.filters.clone()),
predicate,
time_range: *time_range,
};
for file in &self.files_to_read {
@@ -270,7 +276,12 @@ impl ChunkReaderBuilder {
/// Build time range predicate from schema and filters.
pub fn build_time_range_predicate(&self) -> TimestampRange {
let Some(ts_col) = self.schema.user_schema().timestamp_column() else { return TimestampRange::min_to_max() };
TimeRangePredicateBuilder::new(&ts_col.name, &self.filters).build()
let unit = ts_col
.data_type
.as_timestamp()
.expect("Timestamp column must have timestamp-compatible type")
.unit();
TimeRangePredicateBuilder::new(&ts_col.name, unit, &self.filters).build()
}
/// Check if SST file's time range matches predicate.


@@ -13,8 +13,9 @@
// limitations under the License.
use common_query::logical_plan::{DfExpr, Expr};
use datafusion_common::ScalarValue;
use datafusion_expr::{BinaryExpr, Operator};
use common_time::timestamp::TimeUnit;
use datafusion_expr::Operator;
use datatypes::value::timestamp_to_scalar_value;
use crate::chunk::{ChunkReaderBuilder, ChunkReaderImpl};
use crate::error;
@@ -31,53 +32,84 @@ pub(crate) async fn build_sst_reader(
) -> error::Result<ChunkReaderImpl> {
// TODO(hl): Schemas in different SSTs may differ, so we should infer the
// timestamp column name from the Parquet metadata.
let ts_col_name = schema
.user_schema()
.timestamp_column()
.unwrap()
.name
.clone();
// safety: Region schema's timestamp column must be present
let ts_col = schema.user_schema().timestamp_column().unwrap();
let ts_col_unit = ts_col.data_type.as_timestamp().unwrap().unit();
let ts_col_name = ts_col.name.clone();
ChunkReaderBuilder::new(schema, sst_layer)
.pick_ssts(files)
.filters(vec![build_time_range_filter(
lower_sec_inclusive,
upper_sec_exclusive,
&ts_col_name,
)])
.filters(
build_time_range_filter(
lower_sec_inclusive,
upper_sec_exclusive,
&ts_col_name,
ts_col_unit,
)
.into_iter()
.collect(),
)
.build()
.await
}
fn build_time_range_filter(low_sec: i64, high_sec: i64, ts_col_name: &str) -> Expr {
let ts_col = Box::new(DfExpr::Column(datafusion_common::Column::from_name(
ts_col_name,
)));
let lower_bound_expr = Box::new(DfExpr::Literal(ScalarValue::TimestampSecond(
Some(low_sec),
None,
)));
/// Builds a time range filter expr from the lower bound (inclusive) and upper bound (exclusive).
/// Returns `None` if both bounds overflow when converted to the timestamp column's unit.
fn build_time_range_filter(
low_sec: i64,
high_sec: i64,
ts_col_name: &str,
ts_col_unit: TimeUnit,
) -> Option<Expr> {
debug_assert!(low_sec <= high_sec);
let ts_col = DfExpr::Column(datafusion_common::Column::from_name(ts_col_name));
let upper_bound_expr = Box::new(DfExpr::Literal(ScalarValue::TimestampSecond(
Some(high_sec),
None,
)));
// Converting seconds to any other unit never loses precision,
// so we only need to handle overflow here.
let low_ts = common_time::Timestamp::new_second(low_sec)
.convert_to(ts_col_unit)
.map(|ts| ts.value());
let high_ts = common_time::Timestamp::new_second(high_sec)
.convert_to(ts_col_unit)
.map(|ts| ts.value());
let expr = DfExpr::BinaryExpr(BinaryExpr {
left: Box::new(DfExpr::BinaryExpr(BinaryExpr {
left: ts_col.clone(),
op: Operator::GtEq,
right: lower_bound_expr,
})),
op: Operator::And,
right: Box::new(DfExpr::BinaryExpr(BinaryExpr {
left: ts_col,
op: Operator::Lt,
right: upper_bound_expr,
})),
});
let expr = match (low_ts, high_ts) {
(Some(low), Some(high)) => {
let lower_bound_expr =
DfExpr::Literal(timestamp_to_scalar_value(ts_col_unit, Some(low)));
let upper_bound_expr =
DfExpr::Literal(timestamp_to_scalar_value(ts_col_unit, Some(high)));
Some(datafusion_expr::and(
datafusion_expr::binary_expr(ts_col.clone(), Operator::GtEq, lower_bound_expr),
datafusion_expr::binary_expr(ts_col, Operator::Lt, upper_bound_expr),
))
}
Expr::from(expr)
(Some(low), None) => {
let lower_bound_expr =
datafusion_expr::lit(timestamp_to_scalar_value(ts_col_unit, Some(low)));
Some(datafusion_expr::binary_expr(
ts_col,
Operator::GtEq,
lower_bound_expr,
))
}
(None, Some(high)) => {
let upper_bound_expr =
datafusion_expr::lit(timestamp_to_scalar_value(ts_col_unit, Some(high)));
Some(datafusion_expr::binary_expr(
ts_col,
Operator::Lt,
upper_bound_expr,
))
}
(None, None) => None,
};
expr.map(Expr::from)
}
#[cfg(test)]
@@ -490,4 +522,35 @@ mod tests {
assert_eq!(timestamps_in_outputs, timestamps_in_inputs);
}
#[test]
fn test_build_time_range_filter() {
assert!(build_time_range_filter(i64::MIN, i64::MAX, "ts", TimeUnit::Nanosecond).is_none());
assert_eq!(
Expr::from(datafusion_expr::binary_expr(
datafusion_expr::col("ts"),
Operator::Lt,
datafusion_expr::lit(timestamp_to_scalar_value(
TimeUnit::Nanosecond,
Some(TimeUnit::Second.factor() as i64 / TimeUnit::Nanosecond.factor() as i64)
))
)),
build_time_range_filter(i64::MIN, 1, "ts", TimeUnit::Nanosecond).unwrap()
);
assert_eq!(
Expr::from(datafusion_expr::binary_expr(
datafusion_expr::col("ts"),
Operator::GtEq,
datafusion_expr::lit(timestamp_to_scalar_value(
TimeUnit::Nanosecond,
Some(
2 * TimeUnit::Second.factor() as i64 / TimeUnit::Nanosecond.factor() as i64
)
))
)),
build_time_range_filter(2, i64::MAX, "ts", TimeUnit::Nanosecond).unwrap()
);
}
}
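The `(None, None) => None` arm and the test above hinge on timestamp conversion overflow. A hedged sketch of that conversion, assuming `Timestamp::convert_to` returns `None` on i64 overflow, as the code above implies:

use common_time::timestamp::TimeUnit;
use common_time::Timestamp;

// Returns the bound expressed in `unit`, or None when the conversion overflows i64.
fn bound_in_unit(sec: i64, unit: TimeUnit) -> Option<i64> {
    Timestamp::new_second(sec).convert_to(unit).map(|ts| ts.value())
}

// bound_in_unit(1, TimeUnit::Nanosecond)        => Some(1_000_000_000)
// bound_in_unit(i64::MIN, TimeUnit::Nanosecond) => None, so that bound is omitted
// and build_time_range_filter falls back to a single-sided (or no) filter.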


@@ -522,6 +522,12 @@ pub enum Error {
source: ArrowError,
location: Location,
},
#[snafu(display("Failed to build scan predicate, source: {}", source))]
BuildPredicate {
source: table::error::Error,
location: Location,
},
}
pub type Result<T> = std::result::Result<T, Error>;
@@ -621,6 +627,7 @@ impl ErrorExt for Error {
TtlCalculation { source, .. } => source.status_code(),
ConvertColumnsToRows { .. } | SortArrays { .. } => StatusCode::Unexpected,
BuildPredicate { source, .. } => source.status_code(),
}
}


@@ -21,7 +21,9 @@ use arrow::compute::SortOptions;
use common_query::prelude::Expr;
use common_recordbatch::OrderOption;
use common_test_util::temp_dir::create_temp_dir;
use common_time::timestamp::TimeUnit;
use datafusion_common::Column;
use datatypes::value::timestamp_to_scalar_value;
use log_store::raft_engine::log_store::RaftEngineLogStore;
use store_api::storage::{FlushContext, FlushReason, OpenOptions, Region, ScanRequest};
@@ -404,7 +406,10 @@ async fn test_flush_and_query_empty() {
filters: vec![Expr::from(datafusion_expr::binary_expr(
DfExpr::Column(Column::from("timestamp")),
datafusion_expr::Operator::GtEq,
datafusion_expr::lit(20000),
datafusion_expr::lit(timestamp_to_scalar_value(
TimeUnit::Millisecond,
Some(20000),
)),
))],
output_ordering: Some(vec![OrderOption {
name: "timestamp".to_string(),


@@ -13,6 +13,7 @@
// limitations under the License.
pub(crate) mod parquet;
mod pruning;
mod stream_writer;
use std::collections::HashMap;


@@ -18,12 +18,6 @@ use std::collections::HashMap;
use std::pin::Pin;
use std::sync::Arc;
use arrow::datatypes::DataType;
use arrow_array::types::Int64Type;
use arrow_array::{
Array, PrimitiveArray, TimestampMicrosecondArray, TimestampMillisecondArray,
TimestampNanosecondArray, TimestampSecondArray,
};
use async_compat::CompatExt;
use async_stream::try_stream;
use async_trait::async_trait;
@@ -31,19 +25,16 @@ use common_telemetry::{debug, error};
use common_time::range::TimestampRange;
use common_time::timestamp::TimeUnit;
use common_time::Timestamp;
use datatypes::arrow::array::BooleanArray;
use datatypes::arrow::error::ArrowError;
use datatypes::arrow::record_batch::RecordBatch;
use datatypes::prelude::ConcreteDataType;
use futures_util::{Stream, StreamExt, TryStreamExt};
use object_store::ObjectStore;
use parquet::arrow::arrow_reader::{ArrowPredicate, RowFilter};
use parquet::arrow::{ParquetRecordBatchStreamBuilder, ProjectionMask};
use parquet::basic::{Compression, Encoding, ZstdLevel};
use parquet::file::metadata::KeyValue;
use parquet::file::properties::WriterProperties;
use parquet::format::FileMetaData;
use parquet::schema::types::{ColumnPath, SchemaDescriptor};
use parquet::schema::types::ColumnPath;
use snafu::{OptionExt, ResultExt};
use store_api::storage::consts::SEQUENCE_COLUMN_NAME;
use table::predicate::Predicate;
@@ -54,6 +45,7 @@ use crate::read::{Batch, BatchReader};
use crate::schema::compat::ReadAdapter;
use crate::schema::{ProjectedSchemaRef, StoreSchema};
use crate::sst;
use crate::sst::pruning::build_row_filter;
use crate::sst::stream_writer::BufferedWriter;
use crate::sst::{FileHandle, Source, SstInfo};
@@ -277,10 +269,7 @@ impl ParquetReader {
let pruned_row_groups = self
.predicate
.prune_row_groups(
store_schema.schema().clone(),
builder.metadata().row_groups(),
)
.prune_row_groups(builder.metadata().row_groups())
.into_iter()
.enumerate()
.filter_map(|(idx, valid)| if valid { Some(idx) } else { None })
@@ -288,15 +277,18 @@ impl ParquetReader {
let parquet_schema_desc = builder.metadata().file_metadata().schema_descr_ptr();
let projection = ProjectionMask::roots(&parquet_schema_desc, adapter.fields_to_read());
let projection_mask = ProjectionMask::roots(&parquet_schema_desc, adapter.fields_to_read());
let mut builder = builder
.with_projection(projection)
.with_projection(projection_mask.clone())
.with_row_groups(pruned_row_groups);
// If a time range row filter is present, push it down to reduce the rows to scan.
if let Some(row_filter) =
build_time_range_row_filter(self.time_range, &store_schema, &parquet_schema_desc)
{
if let Some(row_filter) = build_row_filter(
self.time_range,
&self.predicate,
&store_schema,
&parquet_schema_desc,
projection_mask,
) {
builder = builder.with_row_filter(row_filter);
}
@@ -314,198 +306,6 @@ impl ParquetReader {
}
}
/// Builds time range row filter.
fn build_time_range_row_filter(
time_range: TimestampRange,
store_schema: &Arc<StoreSchema>,
schema_desc: &SchemaDescriptor,
) -> Option<RowFilter> {
let ts_col_idx = store_schema.timestamp_index();
let ts_col = store_schema.columns().get(ts_col_idx)?;
let ts_col_unit = match &ts_col.desc.data_type {
ConcreteDataType::Int64(_) => TimeUnit::Millisecond,
ConcreteDataType::Timestamp(ts_type) => ts_type.unit(),
_ => unreachable!(),
};
let projection = ProjectionMask::roots(schema_desc, vec![ts_col_idx]);
// Checks whether converting the time range's unit into the timestamp column's unit would cause rounding errors.
if time_unit_lossy(&time_range, ts_col_unit) {
let filter = RowFilter::new(vec![Box::new(PlainTimestampRowFilter::new(
time_range, projection,
))]);
return Some(filter);
}
// If either conversion overflows, we cannot use arrow's compute kernels; instead we
// fall back to a plain filter that compares each timestamp against the given range,
// which is less efficient but simpler.
// TODO(hl): If the range is gt_eq/lt only, we also use PlainTimestampRowFilter, but these
// cases could use arrow's gt_eq_scalar/lt_scalar methods as well.
let row_filter = if let (Some(lower), Some(upper)) = (
time_range
.start()
.and_then(|s| s.convert_to(ts_col_unit))
.map(|t| t.value()),
time_range
.end()
.and_then(|s| s.convert_to(ts_col_unit))
.map(|t| t.value()),
) {
Box::new(FastTimestampRowFilter::new(projection, lower, upper)) as _
} else {
Box::new(PlainTimestampRowFilter::new(time_range, projection)) as _
};
let filter = RowFilter::new(vec![row_filter]);
Some(filter)
}
fn time_unit_lossy(range: &TimestampRange, ts_col_unit: TimeUnit) -> bool {
range
.start()
.map(|start| start.unit().factor() < ts_col_unit.factor())
.unwrap_or(false)
|| range
.end()
.map(|end| end.unit().factor() < ts_col_unit.factor())
.unwrap_or(false)
}
/// `FastTimestampRowFilter` filters rows within the given timestamp range when reading
/// row groups from parquet files, while avoiding fetching all columns from the SST file.
struct FastTimestampRowFilter {
lower_bound: i64,
upper_bound: i64,
projection: ProjectionMask,
}
impl FastTimestampRowFilter {
fn new(projection: ProjectionMask, lower_bound: i64, upper_bound: i64) -> Self {
Self {
lower_bound,
upper_bound,
projection,
}
}
}
impl ArrowPredicate for FastTimestampRowFilter {
fn projection(&self) -> &ProjectionMask {
&self.projection
}
/// Selects the rows matching the given time range.
fn evaluate(&mut self, batch: RecordBatch) -> std::result::Result<BooleanArray, ArrowError> {
// The projection contains only the timestamp column, so we can safely take the first column in the batch.
let ts_col = batch.column(0);
macro_rules! downcast_and_compute {
($typ: ty) => {
{
let ts_col = ts_col
.as_any()
.downcast_ref::<$typ>()
.unwrap(); // safety: we've checked the data type of timestamp column.
let left = arrow::compute::gt_eq_scalar(ts_col, self.lower_bound)?;
let right = arrow::compute::lt_scalar(ts_col, self.upper_bound)?;
arrow::compute::and(&left, &right)
}
};
}
match ts_col.data_type() {
DataType::Timestamp(unit, _) => match unit {
arrow::datatypes::TimeUnit::Second => {
downcast_and_compute!(TimestampSecondArray)
}
arrow::datatypes::TimeUnit::Millisecond => {
downcast_and_compute!(TimestampMillisecondArray)
}
arrow::datatypes::TimeUnit::Microsecond => {
downcast_and_compute!(TimestampMicrosecondArray)
}
arrow::datatypes::TimeUnit::Nanosecond => {
downcast_and_compute!(TimestampNanosecondArray)
}
},
DataType::Int64 => downcast_and_compute!(PrimitiveArray<Int64Type>),
_ => {
unreachable!()
}
}
}
}
/// [PlainTimestampRowFilter] iterates over each element in the timestamp column, builds a [Timestamp]
/// and checks whether the given time range contains it.
struct PlainTimestampRowFilter {
time_range: TimestampRange,
projection: ProjectionMask,
}
impl PlainTimestampRowFilter {
fn new(time_range: TimestampRange, projection: ProjectionMask) -> Self {
Self {
time_range,
projection,
}
}
}
impl ArrowPredicate for PlainTimestampRowFilter {
fn projection(&self) -> &ProjectionMask {
&self.projection
}
fn evaluate(&mut self, batch: RecordBatch) -> std::result::Result<BooleanArray, ArrowError> {
// The projection contains only the timestamp column, so we can safely take the first column in the batch.
let ts_col = batch.column(0);
macro_rules! downcast_and_compute {
($array_ty: ty, $unit: ident) => {{
let ts_col = ts_col
.as_any()
.downcast_ref::<$array_ty>()
.unwrap(); // safety: we've checked the data type of timestamp column.
Ok(BooleanArray::from_iter(ts_col.iter().map(|ts| {
ts.map(|val| {
Timestamp::new(val, TimeUnit::$unit)
}).map(|ts| {
self.time_range.contains(&ts)
})
})))
}};
}
match ts_col.data_type() {
DataType::Timestamp(unit, _) => match unit {
arrow::datatypes::TimeUnit::Second => {
downcast_and_compute!(TimestampSecondArray, Second)
}
arrow::datatypes::TimeUnit::Millisecond => {
downcast_and_compute!(TimestampMillisecondArray, Millisecond)
}
arrow::datatypes::TimeUnit::Microsecond => {
downcast_and_compute!(TimestampMicrosecondArray, Microsecond)
}
arrow::datatypes::TimeUnit::Nanosecond => {
downcast_and_compute!(TimestampNanosecondArray, Nanosecond)
}
},
DataType::Int64 => {
downcast_and_compute!(PrimitiveArray<Int64Type>, Millisecond)
}
_ => {
unreachable!()
}
}
}
}
pub type SendableChunkStream = Pin<Box<dyn Stream<Item = Result<RecordBatch>> + Send>>;
pub struct ChunkStream {
@@ -740,11 +540,12 @@ mod tests {
let operator = create_object_store(dir.path().to_str().unwrap());
let projected_schema = Arc::new(ProjectedSchema::new(schema, Some(vec![1])).unwrap());
let user_schema = projected_schema.projected_user_schema().clone();
let reader = ParquetReader::new(
sst_file_handle,
operator,
projected_schema,
Predicate::empty(),
Predicate::empty(user_schema),
TimestampRange::min_to_max(),
);
@@ -826,11 +627,12 @@ mod tests {
let operator = create_object_store(dir.path().to_str().unwrap());
let projected_schema = Arc::new(ProjectedSchema::new(schema, Some(vec![1])).unwrap());
let user_schema = projected_schema.projected_user_schema().clone();
let reader = ParquetReader::new(
file_handle,
operator,
projected_schema,
Predicate::empty(),
Predicate::empty(user_schema),
TimestampRange::min_to_max(),
);
@@ -854,8 +656,14 @@ mod tests {
range: TimestampRange,
expect: Vec<i64>,
) {
let reader =
ParquetReader::new(file_handle, object_store, schema, Predicate::empty(), range);
let store_schema = schema.schema_to_read().clone();
let reader = ParquetReader::new(
file_handle,
object_store,
schema,
Predicate::empty(store_schema.schema().clone()),
range,
);
let mut stream = reader.chunk_stream().await.unwrap();
let result = stream.next_batch().await;
@@ -981,16 +789,6 @@ mod tests {
.await;
}
fn check_unit_lossy(range_unit: TimeUnit, col_unit: TimeUnit, expect: bool) {
assert_eq!(
expect,
time_unit_lossy(
&TimestampRange::with_unit(0, 1, range_unit).unwrap(),
col_unit
)
)
}
#[tokio::test]
async fn test_write_empty_file() {
common_telemetry::init_default_ut_logging();
@@ -1014,28 +812,4 @@ mod tests {
// The file should not exist when no row has been written.
assert!(!object_store.is_exist(sst_file_name).await.unwrap());
}
#[test]
fn test_time_unit_lossy() {
// converting a range with unit second to millisecond will not cause rounding error
check_unit_lossy(TimeUnit::Second, TimeUnit::Second, false);
check_unit_lossy(TimeUnit::Second, TimeUnit::Millisecond, false);
check_unit_lossy(TimeUnit::Second, TimeUnit::Microsecond, false);
check_unit_lossy(TimeUnit::Second, TimeUnit::Nanosecond, false);
check_unit_lossy(TimeUnit::Millisecond, TimeUnit::Second, true);
check_unit_lossy(TimeUnit::Millisecond, TimeUnit::Millisecond, false);
check_unit_lossy(TimeUnit::Millisecond, TimeUnit::Microsecond, false);
check_unit_lossy(TimeUnit::Millisecond, TimeUnit::Nanosecond, false);
check_unit_lossy(TimeUnit::Microsecond, TimeUnit::Second, true);
check_unit_lossy(TimeUnit::Microsecond, TimeUnit::Millisecond, true);
check_unit_lossy(TimeUnit::Microsecond, TimeUnit::Microsecond, false);
check_unit_lossy(TimeUnit::Microsecond, TimeUnit::Nanosecond, false);
check_unit_lossy(TimeUnit::Nanosecond, TimeUnit::Second, true);
check_unit_lossy(TimeUnit::Nanosecond, TimeUnit::Millisecond, true);
check_unit_lossy(TimeUnit::Nanosecond, TimeUnit::Microsecond, true);
check_unit_lossy(TimeUnit::Nanosecond, TimeUnit::Nanosecond, false);
}
}


@@ -0,0 +1,408 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
use arrow::array::{
PrimitiveArray, TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray,
TimestampSecondArray,
};
use arrow::datatypes::{DataType, Int64Type};
use arrow::error::ArrowError;
use arrow_array::{Array, BooleanArray, RecordBatch};
use common_time::range::TimestampRange;
use common_time::timestamp::TimeUnit;
use common_time::Timestamp;
use datafusion::physical_plan::PhysicalExpr;
use datatypes::prelude::ConcreteDataType;
use parquet::arrow::arrow_reader::{ArrowPredicate, RowFilter};
use parquet::arrow::ProjectionMask;
use parquet::schema::types::SchemaDescriptor;
use table::predicate::Predicate;
use crate::error;
use crate::schema::StoreSchema;
/// Builds row filters according to predicates.
pub(crate) fn build_row_filter(
time_range: TimestampRange,
predicate: &Predicate,
store_schema: &Arc<StoreSchema>,
schema_desc: &SchemaDescriptor,
projection_mask: ProjectionMask,
) -> Option<RowFilter> {
let ts_col_idx = store_schema.timestamp_index();
let ts_col = store_schema.columns().get(ts_col_idx)?;
let ts_col_unit = match &ts_col.desc.data_type {
ConcreteDataType::Int64(_) => TimeUnit::Millisecond,
ConcreteDataType::Timestamp(ts_type) => ts_type.unit(),
_ => unreachable!(),
};
let ts_col_projection = ProjectionMask::roots(schema_desc, vec![ts_col_idx]);
// Checks whether converting the time range's unit into the timestamp column's unit would cause rounding errors.
if time_unit_lossy(&time_range, ts_col_unit) {
let filter = RowFilter::new(vec![Box::new(PlainTimestampRowFilter::new(
time_range,
ts_col_projection,
))]);
return Some(filter);
}
// If either conversion overflows, we cannot use arrow's compute kernels; instead we
// fall back to a plain filter that compares each timestamp against the given range,
// which is less efficient but simpler.
// TODO(hl): If the range is gt_eq/lt only, we also use PlainTimestampRowFilter, but these
// cases could use arrow's gt_eq_scalar/lt_scalar methods as well.
let time_range_row_filter = if let (Some(lower), Some(upper)) = (
time_range
.start()
.and_then(|s| s.convert_to(ts_col_unit))
.map(|t| t.value()),
time_range
.end()
.and_then(|s| s.convert_to(ts_col_unit))
.map(|t| t.value()),
) {
Box::new(FastTimestampRowFilter::new(ts_col_projection, lower, upper)) as _
} else {
Box::new(PlainTimestampRowFilter::new(time_range, ts_col_projection)) as _
};
let mut predicates = vec![time_range_row_filter];
if let Ok(datafusion_filters) = predicate_to_row_filter(predicate, projection_mask) {
predicates.extend(datafusion_filters);
}
let filter = RowFilter::new(predicates);
Some(filter)
}
fn predicate_to_row_filter(
predicate: &Predicate,
projection_mask: ProjectionMask,
) -> error::Result<Vec<Box<dyn ArrowPredicate>>> {
let mut datafusion_predicates = Vec::with_capacity(predicate.exprs().len());
for expr in predicate.exprs() {
datafusion_predicates.push(Box::new(DatafusionArrowPredicate {
projection_mask: projection_mask.clone(),
physical_expr: expr.clone(),
}) as _);
}
Ok(datafusion_predicates)
}
#[derive(Debug)]
struct DatafusionArrowPredicate {
projection_mask: ProjectionMask,
physical_expr: Arc<dyn PhysicalExpr>,
}
impl ArrowPredicate for DatafusionArrowPredicate {
fn projection(&self) -> &ProjectionMask {
&self.projection_mask
}
fn evaluate(&mut self, batch: RecordBatch) -> Result<BooleanArray, ArrowError> {
match self
.physical_expr
.evaluate(&batch)
.map(|v| v.into_array(batch.num_rows()))
{
Ok(array) => {
let bool_arr = array
.as_any()
.downcast_ref::<BooleanArray>()
.ok_or(ArrowError::CastError(
"Physical expr evaluated res is not a boolean array".to_string(),
))?
.clone();
Ok(bool_arr)
}
Err(e) => Err(ArrowError::ComputeError(format!(
"Error evaluating filter predicate: {e:?}"
))),
}
}
}
fn time_unit_lossy(range: &TimestampRange, ts_col_unit: TimeUnit) -> bool {
range
.start()
.map(|start| start.unit().factor() < ts_col_unit.factor())
.unwrap_or(false)
|| range
.end()
.map(|end| end.unit().factor() < ts_col_unit.factor())
.unwrap_or(false)
}
/// `FastTimestampRowFilter` filters rows within the given timestamp range when reading
/// row groups from parquet files, while avoiding fetching all columns from the SST file.
struct FastTimestampRowFilter {
lower_bound: i64,
upper_bound: i64,
projection: ProjectionMask,
}
impl FastTimestampRowFilter {
fn new(projection: ProjectionMask, lower_bound: i64, upper_bound: i64) -> Self {
Self {
lower_bound,
upper_bound,
projection,
}
}
}
impl ArrowPredicate for FastTimestampRowFilter {
fn projection(&self) -> &ProjectionMask {
&self.projection
}
/// Selects the rows matching the given time range.
fn evaluate(&mut self, batch: RecordBatch) -> Result<BooleanArray, ArrowError> {
// The projection contains only the timestamp column, so we can safely take the first column in the batch.
let ts_col = batch.column(0);
macro_rules! downcast_and_compute {
($typ: ty) => {
{
let ts_col = ts_col
.as_any()
.downcast_ref::<$typ>()
.unwrap(); // safety: we've checked the data type of timestamp column.
let left = arrow::compute::gt_eq_scalar(ts_col, self.lower_bound)?;
let right = arrow::compute::lt_scalar(ts_col, self.upper_bound)?;
arrow::compute::and(&left, &right)
}
};
}
match ts_col.data_type() {
DataType::Timestamp(unit, _) => match unit {
arrow::datatypes::TimeUnit::Second => {
downcast_and_compute!(TimestampSecondArray)
}
arrow::datatypes::TimeUnit::Millisecond => {
downcast_and_compute!(TimestampMillisecondArray)
}
arrow::datatypes::TimeUnit::Microsecond => {
downcast_and_compute!(TimestampMicrosecondArray)
}
arrow::datatypes::TimeUnit::Nanosecond => {
downcast_and_compute!(TimestampNanosecondArray)
}
},
DataType::Int64 => downcast_and_compute!(PrimitiveArray<Int64Type>),
_ => {
unreachable!()
}
}
}
}
/// [PlainTimestampRowFilter] iterates over each element in the timestamp column, builds a [Timestamp]
/// and checks whether the given time range contains it.
struct PlainTimestampRowFilter {
time_range: TimestampRange,
projection: ProjectionMask,
}
impl PlainTimestampRowFilter {
fn new(time_range: TimestampRange, projection: ProjectionMask) -> Self {
Self {
time_range,
projection,
}
}
}
impl ArrowPredicate for PlainTimestampRowFilter {
fn projection(&self) -> &ProjectionMask {
&self.projection
}
fn evaluate(&mut self, batch: RecordBatch) -> Result<BooleanArray, ArrowError> {
// The projection contains only the timestamp column, so we can safely take the first column in the batch.
let ts_col = batch.column(0);
macro_rules! downcast_and_compute {
($array_ty: ty, $unit: ident) => {{
let ts_col = ts_col
.as_any()
.downcast_ref::<$array_ty>()
.unwrap(); // safety: we've checked the data type of timestamp column.
Ok(BooleanArray::from_iter(ts_col.iter().map(|ts| {
ts.map(|val| {
Timestamp::new(val, TimeUnit::$unit)
}).map(|ts| {
self.time_range.contains(&ts)
})
})))
}};
}
match ts_col.data_type() {
DataType::Timestamp(unit, _) => match unit {
arrow::datatypes::TimeUnit::Second => {
downcast_and_compute!(TimestampSecondArray, Second)
}
arrow::datatypes::TimeUnit::Millisecond => {
downcast_and_compute!(TimestampMillisecondArray, Millisecond)
}
arrow::datatypes::TimeUnit::Microsecond => {
downcast_and_compute!(TimestampMicrosecondArray, Microsecond)
}
arrow::datatypes::TimeUnit::Nanosecond => {
downcast_and_compute!(TimestampNanosecondArray, Nanosecond)
}
},
DataType::Int64 => {
downcast_and_compute!(PrimitiveArray<Int64Type>, Millisecond)
}
_ => {
unreachable!()
}
}
}
}
#[cfg(test)]
mod tests {
use arrow_array::ArrayRef;
use datafusion_common::ToDFSchema;
use datafusion_expr::Operator;
use datafusion_physical_expr::create_physical_expr;
use datafusion_physical_expr::execution_props::ExecutionProps;
use datatypes::arrow_array::StringArray;
use datatypes::schema::{ColumnSchema, Schema};
use datatypes::value::timestamp_to_scalar_value;
use parquet::arrow::arrow_to_parquet_schema;
use super::*;
fn check_unit_lossy(range_unit: TimeUnit, col_unit: TimeUnit, expect: bool) {
assert_eq!(
expect,
time_unit_lossy(
&TimestampRange::with_unit(0, 1, range_unit).unwrap(),
col_unit
)
)
}
#[test]
fn test_time_unit_lossy() {
// converting a range with unit second to millisecond will not cause rounding error
check_unit_lossy(TimeUnit::Second, TimeUnit::Second, false);
check_unit_lossy(TimeUnit::Second, TimeUnit::Millisecond, false);
check_unit_lossy(TimeUnit::Second, TimeUnit::Microsecond, false);
check_unit_lossy(TimeUnit::Second, TimeUnit::Nanosecond, false);
check_unit_lossy(TimeUnit::Millisecond, TimeUnit::Second, true);
check_unit_lossy(TimeUnit::Millisecond, TimeUnit::Millisecond, false);
check_unit_lossy(TimeUnit::Millisecond, TimeUnit::Microsecond, false);
check_unit_lossy(TimeUnit::Millisecond, TimeUnit::Nanosecond, false);
check_unit_lossy(TimeUnit::Microsecond, TimeUnit::Second, true);
check_unit_lossy(TimeUnit::Microsecond, TimeUnit::Millisecond, true);
check_unit_lossy(TimeUnit::Microsecond, TimeUnit::Microsecond, false);
check_unit_lossy(TimeUnit::Microsecond, TimeUnit::Nanosecond, false);
check_unit_lossy(TimeUnit::Nanosecond, TimeUnit::Second, true);
check_unit_lossy(TimeUnit::Nanosecond, TimeUnit::Millisecond, true);
check_unit_lossy(TimeUnit::Nanosecond, TimeUnit::Microsecond, true);
check_unit_lossy(TimeUnit::Nanosecond, TimeUnit::Nanosecond, false);
}
fn check_arrow_predicate(
schema: Schema,
expr: datafusion_expr::Expr,
columns: Vec<ArrayRef>,
expected: Vec<Option<bool>>,
) {
let arrow_schema = schema.arrow_schema();
let df_schema = arrow_schema.clone().to_dfschema().unwrap();
let physical_expr = create_physical_expr(
&expr,
&df_schema,
arrow_schema.as_ref(),
&ExecutionProps::default(),
)
.unwrap();
let parquet_schema = arrow_to_parquet_schema(arrow_schema).unwrap();
let mut predicate = DatafusionArrowPredicate {
physical_expr,
projection_mask: ProjectionMask::roots(&parquet_schema, vec![0, 1]),
};
let batch = arrow_array::RecordBatch::try_new(arrow_schema.clone(), columns).unwrap();
let res = predicate.evaluate(batch).unwrap();
assert_eq!(expected, res.iter().collect::<Vec<_>>());
}
#[test]
fn test_datafusion_predicate() {
let schema = Schema::new(vec![
ColumnSchema::new(
"ts",
ConcreteDataType::timestamp_datatype(TimeUnit::Nanosecond),
false,
),
ColumnSchema::new("name", ConcreteDataType::string_datatype(), true),
]);
let expr = datafusion_expr::and(
datafusion_expr::binary_expr(
datafusion_expr::col("ts"),
Operator::GtEq,
datafusion_expr::lit(timestamp_to_scalar_value(TimeUnit::Nanosecond, Some(10))),
),
datafusion_expr::binary_expr(
datafusion_expr::col("name"),
Operator::Lt,
datafusion_expr::lit("Bob"),
),
);
let ts_arr = Arc::new(TimestampNanosecondArray::from(vec![9, 11])) as Arc<_>;
let name_arr = Arc::new(StringArray::from(vec![Some("Alice"), Some("Charlie")])) as Arc<_>;
let columns = vec![ts_arr, name_arr];
check_arrow_predicate(
schema.clone(),
expr,
columns.clone(),
vec![Some(false), Some(false)],
);
let expr = datafusion_expr::and(
datafusion_expr::binary_expr(
datafusion_expr::col("ts"),
Operator::Lt,
datafusion_expr::lit(timestamp_to_scalar_value(TimeUnit::Nanosecond, Some(10))),
),
datafusion_expr::binary_expr(
datafusion_expr::col("name"),
Operator::Lt,
datafusion_expr::lit("Bob"),
),
);
check_arrow_predicate(schema, expr, columns, vec![Some(true), Some(false)]);
}
}


@@ -12,66 +12,94 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
use common_query::logical_plan::{DfExpr, Expr};
use common_telemetry::{error, warn};
use common_time::range::TimestampRange;
use common_time::timestamp::TimeUnit;
use common_time::Timestamp;
use datafusion::parquet::file::metadata::RowGroupMetaData;
use datafusion::physical_optimizer::pruning::PruningPredicate;
use datafusion_common::ToDFSchema;
use datafusion_expr::expr::InList;
use datafusion_expr::{Between, BinaryExpr, Operator};
use datafusion_physical_expr::create_physical_expr;
use datafusion_physical_expr::execution_props::ExecutionProps;
use datafusion_physical_expr::{create_physical_expr, PhysicalExpr};
use datatypes::schema::SchemaRef;
use datatypes::value::scalar_value_to_timestamp;
use snafu::ResultExt;
use crate::error;
use crate::predicate::stats::RowGroupPruningStatistics;
mod stats;
#[derive(Default, Clone)]
#[derive(Clone)]
pub struct Predicate {
exprs: Vec<Expr>,
/// The schema of the underlying storage.
schema: SchemaRef,
/// Physical expressions of this predicate.
exprs: Vec<Arc<dyn PhysicalExpr>>,
}
impl Predicate {
pub fn new(exprs: Vec<Expr>) -> Self {
Self { exprs }
}
pub fn empty() -> Self {
Self { exprs: vec![] }
}
pub fn prune_row_groups(
&self,
schema: SchemaRef,
row_groups: &[RowGroupMetaData],
) -> Vec<bool> {
let mut res = vec![true; row_groups.len()];
let arrow_schema = (*schema.arrow_schema()).clone();
let df_schema = arrow_schema.clone().to_dfschema_ref();
let df_schema = match df_schema {
Ok(x) => x,
Err(e) => {
warn!("Failed to create Datafusion schema when trying to prune row groups, error: {e}");
return res;
}
};
/// Creates a new `Predicate` by converting logical exprs to physical exprs that can be
/// evaluated against record batches.
/// Returns an error if any expr fails to convert.
pub fn try_new(exprs: Vec<Expr>, schema: SchemaRef) -> error::Result<Self> {
let arrow_schema = schema.arrow_schema();
let df_schema = arrow_schema
.clone()
.to_dfschema_ref()
.context(error::DatafusionSnafu)?;
// TODO(hl): `execution_props` provides variables required by evaluation.
// We may reuse the `execution_props` from `SessionState` once we support
// registering variables.
let execution_props = &ExecutionProps::new();
let physical_exprs = exprs
.iter()
.map(|expr| {
create_physical_expr(
expr.df_expr(),
df_schema.as_ref(),
arrow_schema.as_ref(),
execution_props,
)
})
.collect::<Result<_, _>>()
.context(error::DatafusionSnafu)?;
Ok(Self {
schema,
exprs: physical_exprs,
})
}
#[inline]
pub fn exprs(&self) -> &[Arc<dyn PhysicalExpr>] {
&self.exprs
}
/// Builds an empty predicate from the given schema.
pub fn empty(schema: SchemaRef) -> Self {
Self {
schema,
exprs: vec![],
}
}
/// Evaluates the predicate against the row group metadata.
/// Returns a vector of booleans in which `false` means the corresponding row group can be skipped.
pub fn prune_row_groups(&self, row_groups: &[RowGroupMetaData]) -> Vec<bool> {
let mut res = vec![true; row_groups.len()];
let arrow_schema = self.schema.arrow_schema();
for expr in &self.exprs {
match create_physical_expr(
expr.df_expr(),
df_schema.as_ref(),
arrow_schema.as_ref(),
execution_props,
)
.and_then(|expr| PruningPredicate::try_new(expr, arrow_schema.clone()))
{
match PruningPredicate::try_new(expr.clone(), arrow_schema.clone()) {
Ok(p) => {
let stat = RowGroupPruningStatistics::new(row_groups, &schema);
let stat = RowGroupPruningStatistics::new(row_groups, &self.schema);
match p.prune(&stat) {
Ok(r) => {
for (curr_val, res) in r.into_iter().zip(res.iter_mut()) {
@@ -94,15 +122,19 @@ impl Predicate {
// Tests for `TimeRangePredicateBuilder` are located in src/query/tests/time_range_filter_test.rs
// since they require the query engine to convert SQL into filters.
/// `TimeRangePredicateBuilder` extracts the time range from logical exprs to facilitate fast
/// time range pruning.
pub struct TimeRangePredicateBuilder<'a> {
ts_col_name: &'a str,
ts_col_unit: TimeUnit,
filters: &'a [Expr],
}
impl<'a> TimeRangePredicateBuilder<'a> {
pub fn new(ts_col_name: &'a str, filters: &'a [Expr]) -> Self {
pub fn new(ts_col_name: &'a str, ts_col_unit: TimeUnit, filters: &'a [Expr]) -> Self {
Self {
ts_col_name,
ts_col_unit,
filters,
}
}
@@ -149,18 +181,23 @@ impl<'a> TimeRangePredicateBuilder<'a> {
match op {
Operator::Eq => self
.get_timestamp_filter(left, right)
.and_then(|ts| ts.convert_to(self.ts_col_unit))
.map(TimestampRange::single),
Operator::Lt => self
.get_timestamp_filter(left, right)
.and_then(|ts| ts.convert_to_ceil(self.ts_col_unit))
.map(|ts| TimestampRange::until_end(ts, false)),
Operator::LtEq => self
.get_timestamp_filter(left, right)
.and_then(|ts| ts.convert_to_ceil(self.ts_col_unit))
.map(|ts| TimestampRange::until_end(ts, true)),
Operator::Gt => self
.get_timestamp_filter(left, right)
.and_then(|ts| ts.convert_to(self.ts_col_unit))
.map(TimestampRange::from_start),
Operator::GtEq => self
.get_timestamp_filter(left, right)
.and_then(|ts| ts.convert_to(self.ts_col_unit))
.map(TimestampRange::from_start),
Operator::And => {
// instead of return none when failed to extract time range from left/right, we unwrap the none into
@@ -231,8 +268,10 @@ impl<'a> TimeRangePredicateBuilder<'a> {
match (low, high) {
(DfExpr::Literal(low), DfExpr::Literal(high)) => {
let low_opt = scalar_value_to_timestamp(low);
let high_opt = scalar_value_to_timestamp(high);
let low_opt =
scalar_value_to_timestamp(low).and_then(|ts| ts.convert_to(self.ts_col_unit));
let high_opt = scalar_value_to_timestamp(high)
.and_then(|ts| ts.convert_to_ceil(self.ts_col_unit));
Some(TimestampRange::new_inclusive(low_opt, high_opt))
}
_ => None,
@@ -329,10 +368,15 @@ mod tests {
(path, schema)
}
async fn assert_prune(array_cnt: usize, predicate: Predicate, expect: Vec<bool>) {
async fn assert_prune(
array_cnt: usize,
filters: Vec<common_query::logical_plan::Expr>,
expect: Vec<bool>,
) {
let dir = create_temp_dir("prune_parquet");
let (path, schema) = gen_test_parquet_file(&dir, array_cnt).await;
let schema = Arc::new(datatypes::schema::Schema::try_from(schema).unwrap());
let arrow_predicate = Predicate::try_new(filters, schema.clone()).unwrap();
let builder = ParquetRecordBatchStreamBuilder::new(
tokio::fs::OpenOptions::new()
.read(true)
@@ -344,23 +388,23 @@ mod tests {
.unwrap();
let metadata = builder.metadata().clone();
let row_groups = metadata.row_groups();
let res = predicate.prune_row_groups(schema, row_groups);
let res = arrow_predicate.prune_row_groups(row_groups);
assert_eq!(expect, res);
}
fn gen_predicate(max_val: i32, op: Operator) -> Predicate {
Predicate::new(vec![common_query::logical_plan::Expr::from(
Expr::BinaryExpr(BinaryExpr {
fn gen_predicate(max_val: i32, op: Operator) -> Vec<common_query::logical_plan::Expr> {
vec![common_query::logical_plan::Expr::from(Expr::BinaryExpr(
BinaryExpr {
left: Box::new(Expr::Column(Column::from_name("cnt"))),
op,
right: Box::new(Expr::Literal(ScalarValue::Int32(Some(max_val)))),
}),
)])
},
))]
}
#[tokio::test]
async fn test_prune_empty() {
assert_prune(3, Predicate::empty(), vec![true]).await;
assert_prune(3, vec![], vec![true]).await;
}
#[tokio::test]
@@ -424,7 +468,6 @@ mod tests {
let e = Expr::Column(Column::from_name("cnt"))
.gt(30.lit())
.or(Expr::Column(Column::from_name("cnt")).lt(20.lit()));
let p = Predicate::new(vec![e.into()]);
assert_prune(40, p, vec![true, true, false, true]).await;
assert_prune(40, vec![e.into()], vec![true, true, false, true]).await;
}
}
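To tie the pieces together, a hedged sketch (not code from this commit) of how callers are expected to use the reworked `Predicate`: logical filters are converted to physical exprs once via `Predicate::try_new`, then reused both for row-group pruning and, through `Predicate::exprs()`, for the per-row parquet filters built in `build_row_filter` above. The keep-all fallback here is illustrative only; the storage layer instead surfaces conversion failures as `BuildPredicate` errors.

use datafusion::parquet::file::metadata::RowGroupMetaData;
use datatypes::schema::SchemaRef;
use table::predicate::Predicate;

// Converts logical filters into a Predicate and evaluates it against row group metadata.
fn prune_groups(
    filters: Vec<common_query::logical_plan::Expr>,
    schema: SchemaRef,
    row_groups: &[RowGroupMetaData],
) -> Vec<bool> {
    match Predicate::try_new(filters, schema) {
        Ok(predicate) => predicate.prune_row_groups(row_groups),
        // Illustrative fallback: keep every row group when conversion fails.
        Err(_) => vec![true; row_groups.len()],
    }
}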