From 1bde1ba3991dc642b4d4b22d08cd9a12acb2ebd8 Mon Sep 17 00:00:00 2001
From: "Lei, HUANG" <6406592+v0y4g3r@users.noreply.github.com>
Date: Thu, 8 Dec 2022 17:44:04 +0800
Subject: [PATCH] fix: row group pruning (#725)

* fix: row group pruning

* chore: use macro to simplify stats implementation

* fix: CR comments

* fix: row group metadata length mismatch

* fix: simplify code
---
 Cargo.lock                          |   1 +
 src/table/Cargo.toml                |   2 +
 src/table/src/error.rs              |   6 +-
 src/table/src/predicate.rs          | 116 +++++++++-------
 src/table/src/predicate/stats.rs    | 148 +++++++++++-----------------
 src/table/src/table/adapter.rs      |   9 +-
 src/table/src/table/numbers.rs      |   3 +-
 src/table/src/table/scan.rs         |  12 ++-
 src/table/src/test_util/memtable.rs |  20 ++--
 9 files changed, 134 insertions(+), 183 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index c9201010d5..468633826d 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -6707,6 +6707,7 @@ dependencies = [
  "datatypes",
  "derive_builder",
  "futures",
+ "parquet",
  "parquet-format-async-temp",
  "paste",
  "serde",
diff --git a/src/table/Cargo.toml b/src/table/Cargo.toml
index d80d5df293..59d9a901bd 100644
--- a/src/table/Cargo.toml
+++ b/src/table/Cargo.toml
@@ -14,6 +14,7 @@ common-recordbatch = { path = "../common/recordbatch" }
 common-telemetry = { path = "../common/telemetry" }
 datafusion = "14.0.0"
 datafusion-common = "14.0.0"
+datafusion-expr = "14.0.0"
 datatypes = { path = "../datatypes" }
 derive_builder = "0.11"
 futures = "0.3"
@@ -28,3 +29,4 @@ tokio = { version = "1.18", features = ["full"] }
 datafusion-expr = "14.0.0"
 tempdir = "0.3"
 tokio-util = { version = "0.7", features = ["compat"] }
+parquet = { version = "26", features = ["async"] }
diff --git a/src/table/src/error.rs b/src/table/src/error.rs
index ed18c471ce..3605ab0a1a 100644
--- a/src/table/src/error.rs
+++ b/src/table/src/error.rs
@@ -152,7 +152,9 @@ impl From<InnerError> for DataFusionError {
 
 impl From<InnerError> for RecordBatchError {
     fn from(e: InnerError) -> RecordBatchError {
-        RecordBatchError::new(e)
+        RecordBatchError::External {
+            source: BoxedError::new(e),
+        }
     }
 }
 
@@ -173,7 +175,7 @@ mod tests {
     }
 
     fn throw_arrow() -> Result<()> {
-        Err(ArrowError::Overflow).context(PollStreamSnafu)?
+        Err(ArrowError::ComputeError("Overflow".to_string())).context(PollStreamSnafu)?
} #[test] diff --git a/src/table/src/predicate.rs b/src/table/src/predicate.rs index 64d32d57f4..847ae495e6 100644 --- a/src/table/src/predicate.rs +++ b/src/table/src/predicate.rs @@ -16,8 +16,8 @@ mod stats; use common_query::logical_plan::Expr; use common_telemetry::{error, warn}; +use datafusion::parquet::file::metadata::RowGroupMetaData; use datafusion::physical_optimizer::pruning::PruningPredicate; -use datatypes::arrow::io::parquet::read::RowGroupMetaData; use datatypes::schema::SchemaRef; use crate::predicate::stats::RowGroupPruningStatistics; @@ -70,19 +70,17 @@ impl Predicate { mod tests { use std::sync::Arc; - pub use datafusion::parquet::schema::types::{BasicTypeInfo, PhysicalType}; - use datafusion_common::Column; - use datafusion_expr::{Expr, Literal, Operator}; - use datatypes::arrow::array::{Int32Array, Utf8Array}; - use datatypes::arrow::chunk::Chunk; + use datafusion::parquet::arrow::ArrowWriter; + pub use datafusion::parquet::schema::types::BasicTypeInfo; + use datafusion_common::{Column, ScalarValue}; + use datafusion_expr::{BinaryExpr, Expr, Literal, Operator}; + use datatypes::arrow::array::Int32Array; use datatypes::arrow::datatypes::{DataType, Field, Schema}; - use datatypes::arrow::io::parquet::read::FileReader; - use datatypes::arrow::io::parquet::write::{ - Compression, Encoding, FileSink, Version, WriteOptions, - }; - use futures::{AsyncWriteExt, SinkExt}; + use datatypes::arrow::record_batch::RecordBatch; + use datatypes::arrow_array::StringArray; + use parquet::arrow::ParquetRecordBatchStreamBuilder; + use parquet::file::properties::WriterProperties; use tempdir::TempDir; - use tokio_util::compat::TokioAsyncWriteCompatExt; use super::*; @@ -95,80 +93,62 @@ mod tests { let name_field = Field::new("name", DataType::Utf8, true); let count_field = Field::new("cnt", DataType::Int32, true); + let schema = Arc::new(Schema::new(vec![name_field, count_field])); - let schema = Schema::from(vec![name_field, count_field]); - - // now all physical types use plain encoding, maybe let caller to choose encoding for each type. 
-        let encodings = vec![Encoding::Plain].repeat(schema.fields.len());
-
-        let mut writer = tokio::fs::OpenOptions::new()
+        let file = std::fs::OpenOptions::new()
             .write(true)
             .create(true)
-            .open(&path)
-            .await
-            .unwrap()
-            .compat_write();
+            .open(path.clone())
+            .unwrap();
 
-        let mut sink = FileSink::try_new(
-            &mut writer,
-            schema.clone(),
-            encodings,
-            WriteOptions {
-                write_statistics: true,
-                compression: Compression::Gzip,
-                version: Version::V2,
-            },
-        )
-        .unwrap();
+        let write_props = WriterProperties::builder()
+            .set_max_row_group_size(10)
+            .build();
+        let mut writer = ArrowWriter::try_new(file, schema.clone(), Some(write_props)).unwrap();
 
         for i in (0..cnt).step_by(10) {
-            let name_array = Utf8Array::<i32>::from(
-                &(i..(i + 10).min(cnt))
-                    .map(|i| Some(i.to_string()))
+            let name_array = Arc::new(StringArray::from(
+                (i..(i + 10).min(cnt))
+                    .map(|i| i.to_string())
                     .collect::<Vec<_>>(),
-            );
-            let count_array = Int32Array::from(
-                &(i..(i + 10).min(cnt))
-                    .map(|i| Some(i as i32))
-                    .collect::<Vec<_>>(),
-            );
-
-            sink.send(Chunk::new(vec![
-                Arc::new(name_array),
-                Arc::new(count_array),
-            ]))
-            .await
-            .unwrap();
+            )) as Arc<_>;
+            let count_array = Arc::new(Int32Array::from(
+                (i..(i + 10).min(cnt)).map(|i| i as i32).collect::<Vec<_>>(),
+            )) as Arc<_>;
+            let rb = RecordBatch::try_new(schema.clone(), vec![name_array, count_array]).unwrap();
+            writer.write(&rb).unwrap();
         }
 
-        sink.close().await.unwrap();
-
-        drop(sink);
-        writer.flush().await.unwrap();
-
-        (path, Arc::new(schema))
+        writer.close().unwrap();
+        (path, schema)
     }
 
     async fn assert_prune(array_cnt: usize, predicate: Predicate, expect: Vec<bool>) {
         let dir = TempDir::new("prune_parquet").unwrap();
         let (path, schema) = gen_test_parquet_file(&dir, array_cnt).await;
-        let file_reader =
-            FileReader::try_new(std::fs::File::open(path).unwrap(), None, None, None, None)
-                .unwrap();
-        let schema = Arc::new(datatypes::schema::Schema::try_from(schema).unwrap());
-
-        let vec = file_reader.metadata().row_groups.clone();
-        let res = predicate.prune_row_groups(schema, &vec);
+        let builder = ParquetRecordBatchStreamBuilder::new(
+            tokio::fs::OpenOptions::new()
+                .read(true)
+                .open(path)
+                .await
+                .unwrap(),
+        )
+        .await
+        .unwrap();
+        let metadata = builder.metadata().clone();
+        let row_groups = metadata.row_groups().clone();
+        let res = predicate.prune_row_groups(schema, &row_groups);
         assert_eq!(expect, res);
     }
 
     fn gen_predicate(max_val: i32, op: Operator) -> Predicate {
-        Predicate::new(vec![Expr::BinaryExpr {
-            left: Box::new(Expr::Column(Column::from_name("cnt".to_string()))),
-            op,
-            right: Box::new(max_val.lit()),
-        }
-        .into()])
+        Predicate::new(vec![common_query::logical_plan::Expr::from(
+            Expr::BinaryExpr(BinaryExpr {
+                left: Box::new(Expr::Column(Column::from_name("cnt"))),
+                op,
+                right: Box::new(Expr::Literal(ScalarValue::Int32(Some(max_val)))),
+            }),
+        )])
     }
 
     #[tokio::test]
diff --git a/src/table/src/predicate/stats.rs b/src/table/src/predicate/stats.rs
index b474eddeb1..f092cd5418 100644
--- a/src/table/src/predicate/stats.rs
+++ b/src/table/src/predicate/stats.rs
@@ -12,17 +12,14 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-use datafusion::parquet::metadata::RowGroupMetaData;
-use datafusion::parquet::statistics::{
-    BinaryStatistics, BooleanStatistics, FixedLenStatistics, PrimitiveStatistics,
-};
+use std::sync::Arc;
+
+use datafusion::parquet::file::metadata::RowGroupMetaData;
+use datafusion::parquet::file::statistics::Statistics as ParquetStats;
 use datafusion::physical_optimizer::pruning::PruningStatistics;
 use datafusion_common::{Column, ScalarValue};
-use datatypes::arrow::array::ArrayRef;
+use datatypes::arrow::array::{ArrayRef, UInt64Array};
 use datatypes::arrow::datatypes::DataType;
-use datatypes::arrow::io::parquet::read::PhysicalType;
-use datatypes::prelude::Vector;
-use datatypes::vectors::Int64Vector;
 use paste::paste;
 
 pub struct RowGroupPruningStatistics<'a> {
@@ -40,92 +37,58 @@ impl<'a> RowGroupPruningStatistics<'a> {
 
     fn field_by_name(&self, name: &str) -> Option<(usize, &DataType)> {
         let idx = self.schema.column_index_by_name(name)?;
-        let data_type = &self.schema.arrow_schema().fields.get(idx)?.data_type;
+        let data_type = &self.schema.arrow_schema().fields.get(idx)?.data_type();
         Some((idx, data_type))
     }
 }
 
 macro_rules! impl_min_max_values {
-    ($self:ident, $col:ident, $min_max: ident) => {
-        paste! {
-            {
-                let (column_index, data_type) = $self.field_by_name(&$col.name)?;
-                let null_scalar: ScalarValue = data_type.try_into().ok()?;
-                let scalar_values: Vec<ScalarValue> = $self
-                    .meta_data
-                    .iter()
-                    .flat_map(|meta| meta.column(column_index).statistics())
-                    .map(|stats| {
-                        let stats = stats.ok()?;
-                        let res = match stats.physical_type() {
-                            PhysicalType::Boolean => {
-                                let $min_max = stats.as_any().downcast_ref::<BooleanStatistics>().unwrap().[<$min_max _value>];
-                                Some(ScalarValue::Boolean($min_max))
-                            }
-                            PhysicalType::Int32 => {
-                                let $min_max = stats
-                                    .as_any()
-                                    .downcast_ref::<PrimitiveStatistics<i32>>()
-                                    .unwrap()
-                                    .[<$min_max _value>];
-                                Some(ScalarValue::Int32($min_max))
-                            }
-                            PhysicalType::Int64 => {
-                                let $min_max = stats
-                                    .as_any()
-                                    .downcast_ref::<PrimitiveStatistics<i64>>()
-                                    .unwrap()
-                                    .[<$min_max _value>];
-                                Some(ScalarValue::Int64($min_max))
-                            }
-                            PhysicalType::Int96 => {
-                                // INT96 currently not supported
-                                None
-                            }
-                            PhysicalType::Float => {
-                                let $min_max = stats
-                                    .as_any()
-                                    .downcast_ref::<PrimitiveStatistics<f32>>()
-                                    .unwrap()
-                                    .[<$min_max _value>];
-                                Some(ScalarValue::Float32($min_max))
-                            }
-                            PhysicalType::Double => {
-                                let $min_max = stats
-                                    .as_any()
-                                    .downcast_ref::<PrimitiveStatistics<f64>>()
-                                    .unwrap()
-                                    .[<$min_max _value>];
-                                Some(ScalarValue::Float64($min_max))
-                            }
-                            PhysicalType::ByteArray => {
-                                let $min_max = stats
-                                    .as_any()
-                                    .downcast_ref::<BinaryStatistics>()
-                                    .unwrap()
-                                    .[<$min_max _value>]
-                                    .clone();
-                                Some(ScalarValue::Binary($min_max))
-                            }
-                            PhysicalType::FixedLenByteArray(_) => {
-                                let $min_max = stats
-                                    .as_any()
-                                    .downcast_ref::<FixedLenStatistics>()
-                                    .unwrap()
-                                    .[<$min_max _value>]
-                                    .clone();
-                                Some(ScalarValue::Binary($min_max))
-                            }
-                        };
-
-                        res
-                    })
-                    .map(|maybe_scalar| maybe_scalar.unwrap_or_else(|| null_scalar.clone()))
-                    .collect::<Vec<_>>();
-                ScalarValue::iter_to_array(scalar_values).ok()
-            }
-        }
-    };
+    ($self:ident, $col:ident, $min_max: ident) => {{
+        let arrow_schema = $self.schema.arrow_schema().clone();
+        let (column_index, field) = if let Some((v, f)) = arrow_schema.column_with_name(&$col.name)
+        {
+            (v, f)
+        } else {
+            return None;
+        };
+        let data_type = field.data_type();
+        let null_scalar: ScalarValue = if let Ok(v) = data_type.try_into() {
+            v
+        } else {
+            return None;
+        };
+
+        let scalar_values = $self
+            .meta_data
+            .iter()
+            .map(|meta| {
+                let stats = meta.column(column_index).statistics()?;
+                if !stats.has_min_max_set() {
+                    return None;
+                }
+                match stats {
+                    ParquetStats::Boolean(s) => Some(ScalarValue::Boolean(Some(*s.$min_max()))),
+                    ParquetStats::Int32(s) => Some(ScalarValue::Int32(Some(*s.$min_max()))),
+                    ParquetStats::Int64(s) => Some(ScalarValue::Int64(Some(*s.$min_max()))),
+
+                    ParquetStats::Int96(_) => None,
+                    ParquetStats::Float(s) => Some(ScalarValue::Float32(Some(*s.$min_max()))),
+                    ParquetStats::Double(s) => Some(ScalarValue::Float64(Some(*s.$min_max()))),
+                    ParquetStats::ByteArray(s) => {
+                        paste! {
+                            let s = String::from_utf8(s.[<$min_max _bytes>]().to_owned()).ok();
+                        }
+                        Some(ScalarValue::Utf8(s))
+                    }
+
+                    ParquetStats::FixedLenByteArray(_) => None,
+                }
+            })
+            .map(|maybe_scalar| maybe_scalar.unwrap_or_else(|| null_scalar.clone()))
+            .collect::<Vec<ScalarValue>>();
+        debug_assert_eq!(scalar_values.len(), $self.meta_data.len());
+        ScalarValue::iter_to_array(scalar_values).ok()
+    }};
 }
 
 impl<'a> PruningStatistics for RowGroupPruningStatistics<'a> {
@@ -143,14 +106,13 @@ impl<'a> PruningStatistics for RowGroupPruningStatistics<'a> {
 
     fn null_counts(&self, column: &Column) -> Option<ArrayRef> {
         let (idx, _) = self.field_by_name(&column.name)?;
-        let mut values: Vec<Option<i64>> = Vec::with_capacity(self.meta_data.len());
+        let mut values: Vec<Option<u64>> = Vec::with_capacity(self.meta_data.len());
         for m in self.meta_data {
             let col = m.column(idx);
-            let stat = col.statistics()?.ok()?;
+            let stat = col.statistics()?;
             let bs = stat.null_count();
-            values.push(bs);
+            values.push(Some(bs));
         }
-
-        Some(Int64Vector::from(values).to_arrow_array())
+        Some(Arc::new(UInt64Array::from(values)))
     }
 }
diff --git a/src/table/src/table/adapter.rs b/src/table/src/table/adapter.rs
index 32824e7a49..98ff82d08a 100644
--- a/src/table/src/table/adapter.rs
+++ b/src/table/src/table/adapter.rs
@@ -23,7 +23,9 @@ use datafusion::arrow::datatypes::SchemaRef as DfSchemaRef;
 use datafusion::datasource::datasource::TableProviderFilterPushDown as DfTableProviderFilterPushDown;
 use datafusion::datasource::{TableProvider, TableType as DfTableType};
 use datafusion::error::Result as DfResult;
-use datafusion::logical_plan::Expr as DfExpr;
+use datafusion::execution::context::SessionState;
+use datafusion::prelude::SessionContext;
+use datafusion_expr::expr::Expr as DfExpr;
 use datatypes::schema::{SchemaRef as TableSchemaRef, SchemaRef};
 use snafu::prelude::*;
 
@@ -66,6 +68,7 @@ impl TableProvider for DfTableProviderAdapter {
 
     async fn scan(
         &self,
+        _ctx: &SessionState,
         projection: &Option<Vec<usize>>,
         filters: &[DfExpr],
         limit: Option<usize>,
@@ -135,11 +138,12 @@ impl Table for TableAdapter {
         filters: &[Expr],
         limit: Option<usize>,
     ) -> Result<PhysicalPlanRef> {
+        let ctx = SessionContext::new();
         let filters: Vec<DfExpr> = filters.iter().map(|e| e.df_expr().clone()).collect();
         debug!("TableScan filter size: {}", filters.len());
         let execution_plan = self
             .table_provider
-            .scan(projection, &filters, limit)
+            .scan(&ctx.state(), projection, &filters, limit)
             .await
             .context(error::DatafusionSnafu)?;
         let schema: SchemaRef = Arc::new(
@@ -168,7 +172,6 @@ impl Table for TableAdapter {
 mod tests {
     use datafusion::arrow;
     use datafusion::datasource::empty::EmptyTable;
-    use datafusion_common::field_util::SchemaExt;
 
     use super::*;
     use crate::metadata::TableType::Base;
diff --git a/src/table/src/table/numbers.rs b/src/table/src/table/numbers.rs
index db33769c31..46b12d0e45 100644
--- a/src/table/src/table/numbers.rs
+++ b/src/table/src/table/numbers.rs
@@ -19,7 +19,8 @@ use std::sync::Arc;
 use common_query::physical_plan::PhysicalPlanRef;
 use common_recordbatch::error::Result as RecordBatchResult;
 use common_recordbatch::{RecordBatch, RecordBatchStream};
-use datafusion_common::record_batch::RecordBatch as DfRecordBatch;
+use datafusion::arrow::record_batch::RecordBatch as DfRecordBatch;
+use datafusion_common::from_slice::FromSlice;
 use datatypes::arrow::array::UInt32Array;
 use datatypes::data_type::ConcreteDataType;
 use datatypes::schema::{ColumnSchema, SchemaBuilder, SchemaRef};
diff --git a/src/table/src/table/scan.rs b/src/table/src/table/scan.rs
index 4e1ef884e7..2ee0c3109e 100644
--- a/src/table/src/table/scan.rs
+++ b/src/table/src/table/scan.rs
@@ -18,8 +18,9 @@ use std::sync::{Arc, Mutex};
 
 use common_query::error as query_error;
 use common_query::error::Result as QueryResult;
-use common_query::physical_plan::{Partitioning, PhysicalPlan, PhysicalPlanRef, RuntimeEnv};
+use common_query::physical_plan::{Partitioning, PhysicalPlan, PhysicalPlanRef};
 use common_recordbatch::SendableRecordBatchStream;
+use datafusion::execution::context::TaskContext;
 use datatypes::schema::SchemaRef;
 use snafu::OptionExt;
 
@@ -71,7 +72,7 @@ impl PhysicalPlan for SimpleTableScan {
     fn execute(
         &self,
         _partition: usize,
-        _runtime: Arc<RuntimeEnv>,
+        _context: Arc<TaskContext>,
     ) -> QueryResult<SendableRecordBatchStream> {
         let mut stream = self.stream.lock().unwrap();
         Ok(stream.take().context(query_error::ExecuteRepeatedlySnafu)?)
@@ -81,6 +82,7 @@ impl PhysicalPlan for SimpleTableScan {
 #[cfg(test)]
 mod test {
     use common_recordbatch::{util, RecordBatch, RecordBatches};
+    use datafusion::prelude::SessionContext;
     use datatypes::data_type::ConcreteDataType;
     use datatypes::schema::{ColumnSchema, Schema};
     use datatypes::vectors::Int32Vector;
@@ -89,6 +91,7 @@ mod test {
 
     #[tokio::test]
     async fn test_simple_table_scan() {
+        let ctx = SessionContext::new();
         let schema = Arc::new(Schema::new(vec![ColumnSchema::new(
             "a",
             ConcreteDataType::int32_datatype(),
@@ -114,13 +117,12 @@ mod test {
 
         assert_eq!(scan.schema(), schema);
 
-        let runtime = Arc::new(RuntimeEnv::default());
-        let stream = scan.execute(0, runtime.clone()).unwrap();
+        let stream = scan.execute(0, ctx.task_ctx()).unwrap();
         let recordbatches = util::collect(stream).await.unwrap();
         assert_eq!(recordbatches[0], batch1);
         assert_eq!(recordbatches[1], batch2);
 
-        let result = scan.execute(0, runtime);
+        let result = scan.execute(0, ctx.task_ctx());
         assert!(result.is_err());
         match result {
             Err(e) => assert!(e
diff --git a/src/table/src/test_util/memtable.rs b/src/table/src/test_util/memtable.rs
index 5f35e73c82..2fdd1228f6 100644
--- a/src/table/src/test_util/memtable.rs
+++ b/src/table/src/test_util/memtable.rs
@@ -197,28 +197,27 @@ impl Stream for MemtableStream {
 
 #[cfg(test)]
 mod test {
-    use common_query::physical_plan::RuntimeEnv;
     use common_recordbatch::util;
+    use datafusion::prelude::SessionContext;
     use datatypes::prelude::*;
     use datatypes::schema::ColumnSchema;
-    use datatypes::vectors::{Int32Vector, StringVector};
+    use datatypes::vectors::{Helper, Int32Vector, StringVector};
 
     use super::*;
 
     #[tokio::test]
    async fn test_scan_with_projection() {
+        let ctx = SessionContext::new();
         let table = build_testing_table();
 
         let scan_stream = table.scan(&Some(vec![1]), &[], None).await.unwrap();
-        let scan_stream = scan_stream
-            .execute(0, Arc::new(RuntimeEnv::default()))
-            .unwrap();
+        let scan_stream = scan_stream.execute(0, ctx.task_ctx()).unwrap();
         let recordbatch = util::collect(scan_stream).await.unwrap();
         assert_eq!(1, recordbatch.len());
         let columns = recordbatch[0].df_recordbatch.columns();
         assert_eq!(1, columns.len());
 
-        let string_column = VectorHelper::try_into_vector(&columns[0]).unwrap();
+        let string_column = Helper::try_into_vector(&columns[0]).unwrap();
         let string_column = string_column
             .as_any()
             .downcast_ref::<StringVector>()
             .unwrap();
@@ -229,23 +228,22 @@
     #[tokio::test]
     async fn test_scan_with_limit() {
+        let ctx = SessionContext::new();
         let table = build_testing_table();
 
         let scan_stream = table.scan(&None, &[], Some(2)).await.unwrap();
-        let scan_stream = scan_stream
-            .execute(0, Arc::new(RuntimeEnv::default()))
-            .unwrap();
+        let scan_stream = scan_stream.execute(0, ctx.task_ctx()).unwrap();
         let recordbatch = util::collect(scan_stream).await.unwrap();
         assert_eq!(1, recordbatch.len());
         let columns = recordbatch[0].df_recordbatch.columns();
         assert_eq!(2, columns.len());
 
-        let i32_column = VectorHelper::try_into_vector(&columns[0]).unwrap();
+        let i32_column = Helper::try_into_vector(&columns[0]).unwrap();
         let i32_column = i32_column.as_any().downcast_ref::<Int32Vector>().unwrap();
         let i32_column = i32_column.iter_data().flatten().collect::<Vec<_>>();
         assert_eq!(vec![-100], i32_column);
 
-        let string_column = VectorHelper::try_into_vector(&columns[1]).unwrap();
+        let string_column = Helper::try_into_vector(&columns[1]).unwrap();
         let string_column = string_column
             .as_any()
             .downcast_ref::<StringVector>()
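
Note on how the pieces in this patch fit together: the new RowGroupPruningStatistics adapts per-row-group parquet Statistics into the min/max/null-count arrays that DataFusion's PruningPredicate consumes, and the tests drive that path through Predicate::prune_row_groups. The body of prune_row_groups is not part of this diff, so the sketch below is only an assumed shape, not the actual implementation; in particular the RowGroupPruningStatistics::new constructor is hypothetical, while df_expr(), arrow_schema(), and the DataFusion 14 PruningPredicate::try_new/prune calls are the ones the diff itself uses.

use common_query::logical_plan::Expr;
use datafusion::parquet::file::metadata::RowGroupMetaData;
use datafusion::physical_optimizer::pruning::PruningPredicate;
use datatypes::schema::SchemaRef;

use crate::predicate::stats::RowGroupPruningStatistics;

// Hedged sketch: AND together the pruning verdicts of every pushed-down filter.
// A row group is kept unless some predicate proves from its statistics that it
// cannot contain matching rows.
fn prune_row_groups_sketch(
    schema: SchemaRef,
    filters: &[Expr],
    row_groups: &[RowGroupMetaData],
) -> Vec<bool> {
    let mut keep = vec![true; row_groups.len()];
    for expr in filters {
        // DataFusion 14: PruningPredicate::try_new takes a logical Expr plus the
        // Arrow schema of the file being pruned.
        if let Ok(predicate) =
            PruningPredicate::try_new(expr.df_expr().clone(), schema.arrow_schema().clone())
        {
            // `new` is assumed here; the diff only shows the struct and its
            // PruningStatistics impl (min_values/max_values/null_counts).
            let stats = RowGroupPruningStatistics::new(row_groups, &schema);
            if let Ok(mask) = predicate.prune(&stats) {
                for (keep_group, passes) in keep.iter_mut().zip(mask) {
                    *keep_group &= passes;
                }
            }
        }
    }
    keep
}

In the tests above, assert_prune exercises this path end to end: it writes a parquet file with 10-row row groups, builds a "cnt <op> value" predicate via gen_predicate, and compares the returned keep/skip mask against the expected row groups.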