Compare commits

3 Commits

Author       SHA1        Message                                         Date
discord9     b6e7fb5e08  feat: async decode                              2025-03-14 13:48:19 +08:00
yihong       a5df3954f3  chore: update flate2 version (#5706)            2025-03-14 02:15:27 +00:00
                         Signed-off-by: yihong0618 <zouzou0208@gmail.com>
Ruihang Xia  32fd850c20  perf: support in list in simple filter (#5709)  2025-03-14 01:08:29 +00:00
                         * feat: support in list in simple filter
                         Signed-off-by: Ruihang Xia <waynestxia@gmail.com>
                         * fix clippy
                         Signed-off-by: Ruihang Xia <waynestxia@gmail.com>
                         ---------
                         Signed-off-by: Ruihang Xia <waynestxia@gmail.com>
9 changed files with 343 additions and 15 deletions

Cargo.lock (generated)

@@ -4119,11 +4119,12 @@ dependencies = [
 [[package]]
 name = "flate2"
-version = "1.0.34"
+version = "1.1.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a1b589b4dc103969ad3cf85c950899926ec64300a1a46d76c03a6072957036f0"
+checksum = "11faaf5a5236997af9848be0bef4db95824b1d534ebc64d0f0c6cf3e67bd38dc"
 dependencies = [
  "crc32fast",
+ "libz-rs-sys",
- "libz-sys",
  "miniz_oxide",
 ]
@@ -6278,6 +6279,15 @@ dependencies = [
  "vcpkg",
 ]
+[[package]]
+name = "libz-rs-sys"
+version = "0.4.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "902bc563b5d65ad9bba616b490842ef0651066a1a1dc3ce1087113ffcb873c8d"
+dependencies = [
+ "zlib-rs",
+]
+
 [[package]]
 name = "libz-sys"
 version = "1.1.20"
@@ -6822,9 +6832,9 @@ checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a"
 [[package]]
 name = "miniz_oxide"
-version = "0.8.0"
+version = "0.8.5"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e2d80299ef12ff69b16a84bb182e3b9df68b5a91574d3d4fa6e41b65deec4df1"
+checksum = "8e3e04debbb59698c15bacbb6d93584a8c0ca9cc3213cb423d31f760d8843ce5"
 dependencies = [
  "adler2",
 ]
@@ -13954,6 +13964,12 @@ dependencies = [
  "syn 2.0.96",
 ]
+[[package]]
+name = "zlib-rs"
+version = "0.4.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8b20717f0917c908dc63de2e44e97f1e6b126ca58d0e391cee86d504eb8fbd05"
+
 [[package]]
 name = "zstd"
 version = "0.11.2+zstd.1.5.2"

@@ -126,6 +126,7 @@ deadpool-postgres = "0.12"
 derive_builder = "0.12"
 dotenv = "0.15"
 etcd-client = "0.14"
+flate2 = { version = "1.1.0", default-features = false, features = ["zlib-rs"] }
 fst = "0.4.7"
 futures = "0.3"
 futures-util = "0.3"
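The new workspace entry pins flate2 to the pure-Rust zlib-rs backend, which is why the `libz-rs-sys` and `zlib-rs` entries appear in Cargo.lock above. The backend is chosen at build time by the Cargo feature; flate2's API is unchanged, so call sites keep working as-is. A minimal sketch of such a call site, assuming only the flate2 crate:

```rust
use std::io::Write;

use flate2::write::GzEncoder;
use flate2::Compression;

// Gzip-compresses `data`; with the `zlib-rs` feature enabled, the actual
// deflate work is done by the pure-Rust zlib-rs implementation rather than
// a C libz.
fn gzip(data: &[u8]) -> std::io::Result<Vec<u8>> {
    let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
    encoder.write_all(data)?;
    encoder.finish()
}
```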

@@ -26,6 +26,7 @@ use datafusion_common::cast::{as_boolean_array, as_null_array};
 use datafusion_common::{internal_err, DataFusionError, ScalarValue};
 use datatypes::arrow::array::{Array, BooleanArray, RecordBatch};
 use datatypes::arrow::compute::filter_record_batch;
+use datatypes::compute::or_kleene;
 use datatypes::vectors::VectorRef;
 use snafu::ResultExt;
@@ -47,6 +48,8 @@ pub struct SimpleFilterEvaluator {
     literal: Scalar<ArrayRef>,
     /// The operator.
     op: Operator,
+    /// Only used when the operator is an `Or` chain.
+    literal_list: Vec<Scalar<ArrayRef>>,
 }

 impl SimpleFilterEvaluator {
@@ -69,6 +72,7 @@ impl SimpleFilterEvaluator {
             column_name,
             literal: val.to_scalar().ok()?,
             op,
+            literal_list: vec![],
         })
     }
@@ -83,6 +87,35 @@ impl SimpleFilterEvaluator {
             | Operator::LtEq
             | Operator::Gt
             | Operator::GtEq => {}
+            Operator::Or => {
+                let lhs = Self::try_new(&binary.left)?;
+                let rhs = Self::try_new(&binary.right)?;
+                if lhs.column_name != rhs.column_name
+                    || !matches!(lhs.op, Operator::Eq | Operator::Or)
+                    || !matches!(rhs.op, Operator::Eq | Operator::Or)
+                {
+                    return None;
+                }
+                let mut list = vec![];
+                let placeholder_literal = lhs.literal.clone();
+                // The check above guarantees the op is either `Eq` or `Or`.
+                if matches!(lhs.op, Operator::Or) {
+                    list.extend(lhs.literal_list);
+                } else {
+                    list.push(lhs.literal);
+                }
+                if matches!(rhs.op, Operator::Or) {
+                    list.extend(rhs.literal_list);
+                } else {
+                    list.push(rhs.literal);
+                }
+                return Some(Self {
+                    column_name: lhs.column_name,
+                    literal: placeholder_literal,
+                    op: Operator::Or,
+                    literal_list: list,
+                });
+            }
             _ => return None,
         }
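Worth noting: the `column_name` equality check above means only OR chains over a single column collapse into one evaluator; anything else falls through to `None` and is left for the generic filter path. A hedged illustration with hypothetical expressions:

```rust
use datafusion_expr::{col, lit};

// Accepted: one column, equalities only.
let same_col = col("c").eq(lit(1)).or(col("c").eq(lit(2)));
assert!(SimpleFilterEvaluator::try_new(&same_col).is_some());

// Rejected: the chain touches two different columns.
let mixed_cols = col("a").eq(lit(1)).or(col("b").eq(lit(2)));
assert!(SimpleFilterEvaluator::try_new(&mixed_cols).is_none());
```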
@@ -103,6 +136,7 @@ impl SimpleFilterEvaluator {
                 column_name: lhs.name.clone(),
                 literal,
                 op,
+                literal_list: vec![],
             })
         }
         _ => None,
@@ -118,19 +152,19 @@ impl SimpleFilterEvaluator {
         let input = input
             .to_scalar()
             .with_context(|_| ToArrowScalarSnafu { v: input.clone() })?;
-        let result = self.evaluate_datum(&input)?;
+        let result = self.evaluate_datum(&input, 1)?;
         Ok(result.value(0))
     }

     pub fn evaluate_array(&self, input: &ArrayRef) -> Result<BooleanBuffer> {
-        self.evaluate_datum(input)
+        self.evaluate_datum(input, input.len())
     }

     pub fn evaluate_vector(&self, input: &VectorRef) -> Result<BooleanBuffer> {
-        self.evaluate_datum(&input.to_arrow_array())
+        self.evaluate_datum(&input.to_arrow_array(), input.len())
     }

-    fn evaluate_datum(&self, input: &impl Datum) -> Result<BooleanBuffer> {
+    fn evaluate_datum(&self, input: &impl Datum, input_len: usize) -> Result<BooleanBuffer> {
         let result = match self.op {
             Operator::Eq => cmp::eq(input, &self.literal),
             Operator::NotEq => cmp::neq(input, &self.literal),
@@ -138,6 +172,15 @@ impl SimpleFilterEvaluator {
             Operator::LtEq => cmp::lt_eq(input, &self.literal),
             Operator::Gt => cmp::gt(input, &self.literal),
             Operator::GtEq => cmp::gt_eq(input, &self.literal),
+            Operator::Or => {
+                // The `Or` operator stands for OR-chained `Eq`s (an IN list, in other words).
+                let mut result: BooleanArray = vec![false; input_len].into();
+                for literal in &self.literal_list {
+                    let rhs = cmp::eq(input, literal).context(ArrowComputeSnafu)?;
+                    result = or_kleene(&result, &rhs).context(ArrowComputeSnafu)?;
+                }
+                Ok(result)
+            }
             _ => {
                 return UnsupportedOperationSnafu {
                     reason: format!("{:?}", self.op),
@@ -349,4 +392,49 @@ mod test {
         let expected = datatypes::arrow::array::Int32Array::from(vec![5, 6]);
         assert_eq!(first_column_values, &expected);
     }
+
+    #[test]
+    fn test_complex_filter_expression() {
+        // Create an expression tree for: col = 'B' OR col = 'C' OR col = 'D'
+        let col_eq_b = col("col").eq(lit("B"));
+        let col_eq_c = col("col").eq(lit("C"));
+        let col_eq_d = col("col").eq(lit("D"));
+
+        // Build the OR chain
+        let col_or_expr = col_eq_b.or(col_eq_c).or(col_eq_d);
+
+        // Check that SimpleFilterEvaluator can handle the OR chain
+        let or_evaluator = SimpleFilterEvaluator::try_new(&col_or_expr).unwrap();
+        assert_eq!(or_evaluator.column_name, "col");
+        assert_eq!(or_evaluator.op, Operator::Or);
+        assert_eq!(or_evaluator.literal_list.len(), 3);
+        assert_eq!(format!("{:?}", or_evaluator.literal_list), "[Scalar(StringArray\n[\n \"B\",\n]), Scalar(StringArray\n[\n \"C\",\n]), Scalar(StringArray\n[\n \"D\",\n])]");
+
+        // Create a schema and batch for testing
+        let schema = Schema::new(vec![Field::new("col", DataType::Utf8, false)]);
+        let df_schema = DFSchema::try_from(schema.clone()).unwrap();
+        let props = ExecutionProps::new();
+        let physical_expr = create_physical_expr(&col_or_expr, &df_schema, &props).unwrap();
+
+        // Create test data
+        let col_data = Arc::new(datatypes::arrow::array::StringArray::from(vec![
+            "B", "C", "E", "B", "C", "D", "F",
+        ]));
+        let batch = RecordBatch::try_new(Arc::new(schema), vec![col_data]).unwrap();
+        let expected = datatypes::arrow::array::StringArray::from(vec!["B", "C", "B", "C", "D"]);
+
+        // Filter the batch
+        let filtered_batch = batch_filter(&batch, &physical_expr).unwrap();
+
+        // Expected: rows with col in ("B", "C", "D")
+        // That would be rows 0, 1, 3, 4, 5
+        assert_eq!(filtered_batch.num_rows(), 5);
+        let col_filtered = filtered_batch
+            .column(0)
+            .as_any()
+            .downcast_ref::<datatypes::arrow::array::StringArray>()
+            .unwrap();
+        assert_eq!(col_filtered, &expected);
+    }
 }
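Taken together: `try_new` flattens a chain of `Eq`s joined by `Or` into one evaluator carrying a `literal_list`, and `evaluate_datum` answers it with one comparison pass per literal. A hedged usage sketch (imports abbreviated; `SimpleFilterEvaluator`'s module path is not shown in this diff):

```rust
use std::sync::Arc;

use datafusion_expr::{col, lit};
use datatypes::arrow::array::{ArrayRef, StringArray};

// `host = 'a' OR host = 'c'` is the shape a small IN list takes after planning.
let expr = col("host").eq(lit("a")).or(col("host").eq(lit("c")));
let evaluator = SimpleFilterEvaluator::try_new(&expr).expect("single-column OR chain");

let input: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c"]));
// One pass per literal; yields a BooleanBuffer selecting rows 0 and 2.
let selected = evaluator.evaluate_array(&input).unwrap();
assert_eq!(selected.count_set_bits(), 2);
```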

src/query/src/expand.rs (new file)

@@ -0,0 +1,191 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::sync::Arc;

use common_error::ext::BoxedError;
use common_query::logical_plan::SubstraitPlanDecoder;
use datafusion::catalog::CatalogProviderList;
use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion};
use datafusion_common::{DFSchema, DFSchemaRef, DataFusionError};
use datafusion_expr::{LogicalPlan, UserDefinedLogicalNodeCore};
use snafu::ResultExt;

use crate::error::{DataFusionSnafu, Error, QueryPlanSnafu};
use crate::query_engine::DefaultPlanDecoder;

#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct UnexpandedNode {
    pub inner: Vec<u8>,
    pub schema: DFSchemaRef,
}

impl UnexpandedNode {
    pub fn new_no_schema(inner: Vec<u8>) -> Self {
        Self {
            inner,
            schema: Arc::new(DFSchema::empty()),
        }
    }
}

impl PartialOrd for UnexpandedNode {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        self.inner.partial_cmp(&other.inner)
    }
}

impl UnexpandedNode {
    const NAME: &'static str = "Unexpanded";
}

impl UserDefinedLogicalNodeCore for UnexpandedNode {
    fn name(&self) -> &'static str {
        Self::NAME
    }

    fn inputs(&self) -> Vec<&LogicalPlan> {
        vec![]
    }

    fn schema(&self) -> &DFSchemaRef {
        &self.schema
    }

    fn with_exprs_and_inputs(
        &self,
        _: Vec<datafusion_expr::Expr>,
        _: Vec<LogicalPlan>,
    ) -> datafusion_common::Result<Self> {
        Ok(self.clone())
    }

    fn fmt_for_explain(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(f, "{}", Self::NAME)
    }

    fn expressions(&self) -> Vec<datafusion_expr::Expr> {
        vec![]
    }
}

/// Rewrites a decoded `LogicalPlan` so that all `UnexpandedNode`s are expanded.
///
/// This is a hack to support decoding substrait plans with async functions.
///
/// The corresponding encode method should put each custom logical node's input
/// plan into an `UnexpandedNode` after encoding it into bytes.
pub struct UnexpandDecoder {
    pub default_decoder: DefaultPlanDecoder,
}

impl UnexpandDecoder {
    pub fn new(default_decoder: DefaultPlanDecoder) -> Self {
        Self { default_decoder }
    }
}

impl UnexpandDecoder {
    /// Decodes a substrait plan into a `LogicalPlan` and recursively expands
    /// all unexpanded nodes.
    ///
    /// Async, so our custom logical plan's inputs can be decoded as well.
    pub async fn decode(
        &self,
        message: bytes::Bytes,
        catalog_list: Arc<dyn CatalogProviderList>,
        optimize: bool,
    ) -> Result<LogicalPlan, Error> {
        let plan = self
            .default_decoder
            .decode(message, catalog_list.clone(), optimize)
            .await
            .map_err(BoxedError::new)
            .context(QueryPlanSnafu)?;
        self.expand(plan, catalog_list, optimize)
            .await
            .map_err(BoxedError::new)
            .context(QueryPlanSnafu)
    }

    /// Recursively expands all unexpanded nodes in the plan.
    pub async fn expand(
        &self,
        plan: LogicalPlan,
        catalog_list: Arc<dyn CatalogProviderList>,
        optimize: bool,
    ) -> Result<LogicalPlan, Error> {
        let mut cur_unexpanded_node = None;
        let mut root_expanded_plan = plan.clone();
        loop {
            root_expanded_plan
                .apply(|p| {
                    if let LogicalPlan::Extension(node) = p {
                        if node.node.name() == UnexpandedNode::NAME {
                            let node = node.node.as_any().downcast_ref::<UnexpandedNode>().ok_or(
                                DataFusionError::Plan(
                                    "Failed to downcast to UnexpandedNode".to_string(),
                                ),
                            )?;
                            cur_unexpanded_node = Some(node.clone());
                            return Ok(TreeNodeRecursion::Stop);
                        }
                    }
                    Ok(TreeNodeRecursion::Continue)
                })
                .context(DataFusionSnafu)?;

            if let Some(unexpanded) = cur_unexpanded_node.take() {
                let decoded = self
                    .default_decoder
                    .decode(
                        unexpanded.inner.clone().into(),
                        catalog_list.clone(),
                        optimize,
                    )
                    .await
                    .map_err(BoxedError::new)
                    .context(QueryPlanSnafu)?;
                let mut decoded = Some(decoded);
                // Replace the first unexpanded node with the decoded plan;
                // `transform` visits nodes in the same order as `apply` above,
                // so the first `UnexpandedNode` it sees is the same node.
                root_expanded_plan = root_expanded_plan
                    .transform(|p| {
                        let Some(decoded) = decoded.take() else {
                            return Ok(Transformed::no(p));
                        };
                        if let LogicalPlan::Extension(node) = &p
                            && node.node.name() == UnexpandedNode::NAME
                        {
                            let _ = node.node.as_any().downcast_ref::<UnexpandedNode>().ok_or(
                                DataFusionError::Plan(
                                    "Failed to downcast to UnexpandedNode".to_string(),
                                ),
                            )?;
                            Ok(Transformed::yes(decoded))
                        } else {
                            Ok(Transformed::no(p))
                        }
                    })
                    .context(DataFusionSnafu)?
                    .data;
            } else {
                // All nodes are expanded.
                break;
            }
        }
        Ok(root_expanded_plan)
    }
}
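This decoder only works if the encode side cooperates. A hedged sketch of that contract, per the doc comment above (the real encoder is not part of this diff, and `encoded_input_bytes` is a hypothetical parameter):

```rust
use std::sync::Arc;

use datafusion_expr::{Extension, LogicalPlan};

// Hypothetical encode-side helper: wrap the already-encoded Substrait bytes
// of a custom node's input plan in an UnexpandedNode placeholder.
// `UnexpandDecoder` later finds the placeholder by name and splices the
// decoded input plan back in.
fn wrap_unexpanded(encoded_input_bytes: Vec<u8>) -> LogicalPlan {
    LogicalPlan::Extension(Extension {
        node: Arc::new(UnexpandedNode::new_no_schema(encoded_input_bytes)),
    })
}
```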

@@ -26,6 +26,7 @@ pub mod dist_plan;
 pub mod dummy_catalog;
 pub mod error;
 pub mod executor;
+pub mod expand;
 pub mod log_query;
 pub mod metrics;
 mod optimizer;

@@ -37,16 +37,16 @@ common-telemetry.workspace = true
 common-test-util.workspace = true
 common-time.workspace = true
 common-wal.workspace = true
-datanode = { workspace = true }
+datanode.workspace = true
 datatypes.workspace = true
 dotenv.workspace = true
-flate2 = "1.0"
+flate2.workspace = true
 flow.workspace = true
 frontend = { workspace = true, features = ["testing"] }
 futures.workspace = true
 futures-util.workspace = true
 hyper-util = { workspace = true, features = ["tokio"] }
-log-query = { workspace = true }
+log-query.workspace = true
 loki-proto.workspace = true
 meta-client.workspace = true
 meta-srv = { workspace = true, features = ["mock"] }
@@ -96,5 +96,5 @@ prost.workspace = true
 rand.workspace = true
 session = { workspace = true, features = ["testing"] }
 store-api.workspace = true
-tokio-postgres = { workspace = true }
+tokio-postgres.workspace = true
 url = "2.3"

@@ -204,3 +204,26 @@ DROP TABLE integers;

 Affected Rows: 0

+CREATE TABLE characters(c STRING, t TIMESTAMP TIME INDEX);
+
+Affected Rows: 0
+
+INSERT INTO characters VALUES ('a', 1), ('b', 2), ('c', 3), (NULL, 4), ('a', 5), ('b', 6), ('c', 7), (NULL, 8);
+
+Affected Rows: 8
+
+SELECT * FROM characters WHERE c IN ('a', 'c') ORDER BY t;
+
++---+-------------------------+
+| c | t                       |
++---+-------------------------+
+| a | 1970-01-01T00:00:00.001 |
+| c | 1970-01-01T00:00:00.003 |
+| a | 1970-01-01T00:00:00.005 |
+| c | 1970-01-01T00:00:00.007 |
++---+-------------------------+
+
+DROP TABLE characters;
+
+Affected Rows: 0
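Note that the two NULL rows are absent from the result, which matches SQL IN semantics. That falls out of the Kleene OR used in `evaluate_datum`: `eq(NULL, 'a')` yields NULL, `false OR NULL` stays NULL, and filtering drops NULL rows. A small sketch under that assumption, using the array paths from the diff:

```rust
use datatypes::arrow::array::BooleanArray;
use datatypes::compute::or_kleene;

// Row 0 matched some literal; row 1 compared against a NULL value.
let acc = BooleanArray::from(vec![Some(false), Some(false)]);
let eq_result = BooleanArray::from(vec![Some(true), None]);
let combined = or_kleene(&acc, &eq_result).unwrap();
assert!(combined.value(0)); // false OR true = true  -> row kept
assert!(combined.is_null(1)); // false OR NULL = NULL -> row filtered out
```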

@@ -57,3 +57,11 @@ SELECT * FROM (SELECT i1.i AS a, i2.i AS b, row_number() OVER (ORDER BY i1.i, i2
 SELECT * FROM (SELECT 0=1 AS cond FROM integers i1, integers i2 GROUP BY 1) a1 WHERE cond ORDER BY 1;

 DROP TABLE integers;

+CREATE TABLE characters(c STRING, t TIMESTAMP TIME INDEX);
+
+INSERT INTO characters VALUES ('a', 1), ('b', 2), ('c', 3), (NULL, 4), ('a', 5), ('b', 6), ('c', 7), (NULL, 8);
+
+SELECT * FROM characters WHERE c IN ('a', 'c') ORDER BY t;
+
+DROP TABLE characters;

@@ -15,8 +15,8 @@ common-error.workspace = true
 common-query.workspace = true
 common-recordbatch.workspace = true
 common-time.workspace = true
-datatypes = { workspace = true }
-flate2 = "1.0"
+datatypes.workspace = true
+flate2.workspace = true
 hex = "0.4"
 local-ip-address = "0.6"
 mysql = { version = "25.0.1", default-features = false, features = ["minimal", "rustls-tls"] }
@@ -31,5 +31,5 @@ tar = "0.4"
 tempfile.workspace = true
 tinytemplate = "1.2"
 tokio.workspace = true
-tokio-postgres = { workspace = true }
+tokio-postgres.workspace = true
 tokio-stream.workspace = true