feat: apply rewriter to subquery exprs (#2245)

* apply rewriter to subquery exprs

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* workaround for datafusion's check

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* clean up

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* add sqlness test

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* fix typo

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* change time index type

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

---------

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>
This commit is contained in:
Ruihang Xia
2023-08-24 06:48:04 -05:00
committed by GitHub
parent 0a6ab2a287
commit b633a16667
5 changed files with 284 additions and 19 deletions

18
Cargo.lock generated
View File

@@ -2436,7 +2436,7 @@ dependencies = [
[[package]]
name = "datafusion"
version = "27.0.0"
source = "git+https://github.com/waynexia/arrow-datafusion.git?rev=2ceb7f927c40787773fdc466d6a4b79f3a6c0001#2ceb7f927c40787773fdc466d6a4b79f3a6c0001"
source = "git+https://github.com/waynexia/arrow-datafusion.git?rev=c0b0fca548e99d020c76e1a1cd7132aab26000e1#c0b0fca548e99d020c76e1a1cd7132aab26000e1"
dependencies = [
"ahash 0.8.3",
"arrow",
@@ -2484,7 +2484,7 @@ dependencies = [
[[package]]
name = "datafusion-common"
version = "27.0.0"
source = "git+https://github.com/waynexia/arrow-datafusion.git?rev=2ceb7f927c40787773fdc466d6a4b79f3a6c0001#2ceb7f927c40787773fdc466d6a4b79f3a6c0001"
source = "git+https://github.com/waynexia/arrow-datafusion.git?rev=c0b0fca548e99d020c76e1a1cd7132aab26000e1#c0b0fca548e99d020c76e1a1cd7132aab26000e1"
dependencies = [
"arrow",
"arrow-array",
@@ -2498,7 +2498,7 @@ dependencies = [
[[package]]
name = "datafusion-execution"
version = "27.0.0"
source = "git+https://github.com/waynexia/arrow-datafusion.git?rev=2ceb7f927c40787773fdc466d6a4b79f3a6c0001#2ceb7f927c40787773fdc466d6a4b79f3a6c0001"
source = "git+https://github.com/waynexia/arrow-datafusion.git?rev=c0b0fca548e99d020c76e1a1cd7132aab26000e1#c0b0fca548e99d020c76e1a1cd7132aab26000e1"
dependencies = [
"dashmap",
"datafusion-common",
@@ -2515,7 +2515,7 @@ dependencies = [
[[package]]
name = "datafusion-expr"
version = "27.0.0"
source = "git+https://github.com/waynexia/arrow-datafusion.git?rev=2ceb7f927c40787773fdc466d6a4b79f3a6c0001#2ceb7f927c40787773fdc466d6a4b79f3a6c0001"
source = "git+https://github.com/waynexia/arrow-datafusion.git?rev=c0b0fca548e99d020c76e1a1cd7132aab26000e1#c0b0fca548e99d020c76e1a1cd7132aab26000e1"
dependencies = [
"ahash 0.8.3",
"arrow",
@@ -2529,7 +2529,7 @@ dependencies = [
[[package]]
name = "datafusion-optimizer"
version = "27.0.0"
source = "git+https://github.com/waynexia/arrow-datafusion.git?rev=2ceb7f927c40787773fdc466d6a4b79f3a6c0001#2ceb7f927c40787773fdc466d6a4b79f3a6c0001"
source = "git+https://github.com/waynexia/arrow-datafusion.git?rev=c0b0fca548e99d020c76e1a1cd7132aab26000e1#c0b0fca548e99d020c76e1a1cd7132aab26000e1"
dependencies = [
"arrow",
"async-trait",
@@ -2546,7 +2546,7 @@ dependencies = [
[[package]]
name = "datafusion-physical-expr"
version = "27.0.0"
source = "git+https://github.com/waynexia/arrow-datafusion.git?rev=2ceb7f927c40787773fdc466d6a4b79f3a6c0001#2ceb7f927c40787773fdc466d6a4b79f3a6c0001"
source = "git+https://github.com/waynexia/arrow-datafusion.git?rev=c0b0fca548e99d020c76e1a1cd7132aab26000e1#c0b0fca548e99d020c76e1a1cd7132aab26000e1"
dependencies = [
"ahash 0.8.3",
"arrow",
@@ -2581,7 +2581,7 @@ dependencies = [
[[package]]
name = "datafusion-row"
version = "27.0.0"
source = "git+https://github.com/waynexia/arrow-datafusion.git?rev=2ceb7f927c40787773fdc466d6a4b79f3a6c0001#2ceb7f927c40787773fdc466d6a4b79f3a6c0001"
source = "git+https://github.com/waynexia/arrow-datafusion.git?rev=c0b0fca548e99d020c76e1a1cd7132aab26000e1#c0b0fca548e99d020c76e1a1cd7132aab26000e1"
dependencies = [
"arrow",
"datafusion-common",
@@ -2592,7 +2592,7 @@ dependencies = [
[[package]]
name = "datafusion-sql"
version = "27.0.0"
source = "git+https://github.com/waynexia/arrow-datafusion.git?rev=2ceb7f927c40787773fdc466d6a4b79f3a6c0001#2ceb7f927c40787773fdc466d6a4b79f3a6c0001"
source = "git+https://github.com/waynexia/arrow-datafusion.git?rev=c0b0fca548e99d020c76e1a1cd7132aab26000e1#c0b0fca548e99d020c76e1a1cd7132aab26000e1"
dependencies = [
"arrow",
"arrow-schema",
@@ -2605,7 +2605,7 @@ dependencies = [
[[package]]
name = "datafusion-substrait"
version = "27.0.0"
source = "git+https://github.com/waynexia/arrow-datafusion.git?rev=2ceb7f927c40787773fdc466d6a4b79f3a6c0001#2ceb7f927c40787773fdc466d6a4b79f3a6c0001"
source = "git+https://github.com/waynexia/arrow-datafusion.git?rev=c0b0fca548e99d020c76e1a1cd7132aab26000e1#c0b0fca548e99d020c76e1a1cd7132aab26000e1"
dependencies = [
"async-recursion",
"chrono",

View File

@@ -67,13 +67,13 @@ arrow-schema = { version = "43.0", features = ["serde"] }
async-stream = "0.3"
async-trait = "0.1"
chrono = { version = "0.4", features = ["serde"] }
datafusion = { git = "https://github.com/waynexia/arrow-datafusion.git", rev = "2ceb7f927c40787773fdc466d6a4b79f3a6c0001" }
datafusion-common = { git = "https://github.com/waynexia/arrow-datafusion.git", rev = "2ceb7f927c40787773fdc466d6a4b79f3a6c0001" }
datafusion-expr = { git = "https://github.com/waynexia/arrow-datafusion.git", rev = "2ceb7f927c40787773fdc466d6a4b79f3a6c0001" }
datafusion-optimizer = { git = "https://github.com/waynexia/arrow-datafusion.git", rev = "2ceb7f927c40787773fdc466d6a4b79f3a6c0001" }
datafusion-physical-expr = { git = "https://github.com/waynexia/arrow-datafusion.git", rev = "2ceb7f927c40787773fdc466d6a4b79f3a6c0001" }
datafusion-sql = { git = "https://github.com/waynexia/arrow-datafusion.git", rev = "2ceb7f927c40787773fdc466d6a4b79f3a6c0001" }
datafusion-substrait = { git = "https://github.com/waynexia/arrow-datafusion.git", rev = "2ceb7f927c40787773fdc466d6a4b79f3a6c0001" }
datafusion = { git = "https://github.com/waynexia/arrow-datafusion.git", rev = "c0b0fca548e99d020c76e1a1cd7132aab26000e1" }
datafusion-common = { git = "https://github.com/waynexia/arrow-datafusion.git", rev = "c0b0fca548e99d020c76e1a1cd7132aab26000e1" }
datafusion-expr = { git = "https://github.com/waynexia/arrow-datafusion.git", rev = "c0b0fca548e99d020c76e1a1cd7132aab26000e1" }
datafusion-optimizer = { git = "https://github.com/waynexia/arrow-datafusion.git", rev = "c0b0fca548e99d020c76e1a1cd7132aab26000e1" }
datafusion-physical-expr = { git = "https://github.com/waynexia/arrow-datafusion.git", rev = "c0b0fca548e99d020c76e1a1cd7132aab26000e1" }
datafusion-sql = { git = "https://github.com/waynexia/arrow-datafusion.git", rev = "c0b0fca548e99d020c76e1a1cd7132aab26000e1" }
datafusion-substrait = { git = "https://github.com/waynexia/arrow-datafusion.git", rev = "c0b0fca548e99d020c76e1a1cd7132aab26000e1" }
derive_builder = "0.12"
futures = "0.3"
futures-util = "0.3"

View File

@@ -12,11 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
use common_telemetry::info;
use datafusion::datasource::DefaultTableSource;
use datafusion::error::Result as DfResult;
use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{RewriteRecursion, TreeNode, TreeNodeRewriter};
use datafusion_expr::LogicalPlan;
use datafusion_common::tree_node::{RewriteRecursion, Transformed, TreeNode, TreeNodeRewriter};
use datafusion_expr::expr::{Exists, InSubquery};
use datafusion_expr::utils::from_plan;
use datafusion_expr::{col, Expr, LogicalPlan, LogicalPlanBuilder, Subquery};
use datafusion_optimizer::analyzer::AnalyzerRule;
use substrait::{DFLogicalSubstraitConvertor, SubstraitPlan};
use table::metadata::TableType;
@@ -39,11 +44,66 @@ impl AnalyzerRule for DistPlannerAnalyzer {
plan: LogicalPlan,
_config: &ConfigOptions,
) -> datafusion_common::Result<LogicalPlan> {
let plan = plan.transform(&Self::inspect_plan_with_subquery)?;
let mut rewriter = PlanRewriter::default();
plan.rewrite(&mut rewriter)
}
}
impl DistPlannerAnalyzer {
fn inspect_plan_with_subquery(plan: LogicalPlan) -> DfResult<Transformed<LogicalPlan>> {
let exprs = plan
.expressions()
.into_iter()
.map(|e| e.transform(&Self::transform_subquery))
.collect::<DfResult<Vec<_>>>()?;
let inputs = plan.inputs().into_iter().cloned().collect::<Vec<_>>();
Ok(Transformed::Yes(from_plan(&plan, &exprs, &inputs)?))
}
fn transform_subquery(expr: Expr) -> DfResult<Transformed<Expr>> {
match expr {
Expr::Exists(exists) => Ok(Transformed::Yes(Expr::Exists(Exists {
subquery: Self::handle_subquery(exists.subquery)?,
negated: exists.negated,
}))),
Expr::InSubquery(in_subquery) => Ok(Transformed::Yes(Expr::InSubquery(InSubquery {
expr: in_subquery.expr,
subquery: Self::handle_subquery(in_subquery.subquery)?,
negated: in_subquery.negated,
}))),
Expr::ScalarSubquery(scalar_subquery) => Ok(Transformed::Yes(Expr::ScalarSubquery(
Self::handle_subquery(scalar_subquery)?,
))),
_ => Ok(Transformed::No(expr)),
}
}
fn handle_subquery(subquery: Subquery) -> DfResult<Subquery> {
let mut rewriter = PlanRewriter::default();
let mut rewrote_subquery = subquery.subquery.as_ref().clone().rewrite(&mut rewriter)?;
// Workaround. DF doesn't support the first plan in subquery to be an Extension
if matches!(rewrote_subquery, LogicalPlan::Extension(_)) {
let output_schema = rewrote_subquery.schema().clone();
let project_exprs = output_schema
.fields()
.iter()
.map(|f| col(f.name()))
.collect::<Vec<_>>();
rewrote_subquery = LogicalPlanBuilder::from(rewrote_subquery)
.project(project_exprs)?
.build()?;
}
Ok(Subquery {
subquery: Arc::new(rewrote_subquery),
outer_ref_columns: subquery.outer_ref_columns,
})
}
}
/// Status of the rewriter to mark if the current pass is expanded
#[derive(Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
enum RewriterStatus {
@@ -78,7 +138,7 @@ impl PlanRewriter {
/// Return true if should stop and expand. The input plan is the parent node of current node
fn should_expand(&mut self, plan: &LogicalPlan) -> bool {
if DFLogicalSubstraitConvertor.encode(plan).is_err() {
common_telemetry::info!(
info!(
"substrait error: {:?}",
DFLogicalSubstraitConvertor.encode(plan)
);
@@ -177,6 +237,7 @@ impl TreeNodeRewriter for PlanRewriter {
self.stage.clear();
self.set_unexpanded();
self.partition_cols = None;
Ok(RewriteRecursion::Continue)
}
@@ -230,6 +291,7 @@ mod test {
use std::sync::Arc;
use datafusion::datasource::DefaultTableSource;
use datafusion_common::JoinType;
use datafusion_expr::{avg, col, lit, Expr, LogicalPlanBuilder};
use table::table::adapter::DfTableProviderAdapter;
use table::table::numbers::NumbersTable;
@@ -343,4 +405,41 @@ mod test {
.join("\n");
assert_eq!(expected, format!("{:?}", result));
}
#[test]
fn transform_unalighed_join_with_alias() {
let left = NumbersTable::table(0);
let right = NumbersTable::table(1);
let left_source = Arc::new(DefaultTableSource::new(Arc::new(
DfTableProviderAdapter::new(left),
)));
let right_source = Arc::new(DefaultTableSource::new(Arc::new(
DfTableProviderAdapter::new(right),
)));
let right_plan = LogicalPlanBuilder::scan_with_filters("t", right_source, None, vec![])
.unwrap()
.alias("right")
.unwrap()
.build()
.unwrap();
let plan = LogicalPlanBuilder::scan_with_filters("t", left_source, None, vec![])
.unwrap()
.join_using(right_plan, JoinType::LeftSemi, vec!["number"])
.unwrap()
.limit(0, Some(1))
.unwrap()
.build()
.unwrap();
let config = ConfigOptions::default();
let result = DistPlannerAnalyzer {}.analyze(plan, &config).unwrap();
let expected = "Limit: skip=0, fetch=1\
\n LeftSemi Join: Using t.number = right.number\
\n MergeScan [is_placeholder=false]\
\n SubqueryAlias: right\
\n MergeScan [is_placeholder=false]";
assert_eq!(expected, format!("{:?}", result));
}
}

View File

@@ -0,0 +1,131 @@
CREATE TABLE integers(i INTEGER, j TIMESTAMP TIME INDEX);
Affected Rows: 0
-- SQLNESS REPLACE (-+) -
-- SQLNESS REPLACE (\s\s+) _
-- SQLNESS REPLACE (RoundRobinBatch.*) REDACTED
-- SQLNESS REPLACE (Hash.*) REDACTED
-- SQLNESS REPLACE (peer-.*) REDACTED
EXPLAIN SELECT * FROM integers WHERE i IN ((SELECT i FROM integers)) ORDER BY i;
+-+-+
| plan_type_| plan_|
+-+-+
| logical_plan_| Sort: integers.i ASC NULLS LAST_|
|_|_LeftSemi Join: integers.i = __correlated_sq_1.i_|
|_|_MergeScan [is_placeholder=false]_|
|_|_SubqueryAlias: __correlated_sq_1_|
|_|_MergeScan [is_placeholder=false]_|
| physical_plan | SortPreservingMergeExec: [i@0 ASC NULLS LAST]_|
|_|_SortExec: expr=[i@0 ASC NULLS LAST]_|
|_|_CoalesceBatchesExec: target_batch_size=8192_|
|_|_REDACTED
|_|_CoalesceBatchesExec: target_batch_size=8192_|
|_|_RepartitionExec: partitioning=REDACTED
|_|_RepartitionExec: partitioning=REDACTED
|_|_MergeScanExec: peers=[REDACTED
|_|_CoalesceBatchesExec: target_batch_size=8192_|
|_|_RepartitionExec: partitioning=REDACTED
|_|_RepartitionExec: partitioning=REDACTED
|_|_MergeScanExec: peers=[REDACTED
|_|_|
+-+-+
-- SQLNESS REPLACE (-+) -
-- SQLNESS REPLACE (\s\s+) _
-- SQLNESS REPLACE (RoundRobinBatch.*) REDACTED
-- SQLNESS REPLACE (Hash.*) REDACTED
-- SQLNESS REPLACE (peer-.*) REDACTED
EXPLAIN SELECT * FROM integers i1 WHERE EXISTS(SELECT i FROM integers WHERE i=i1.i) ORDER BY i1.i;
+-+-+
| plan_type_| plan_|
+-+-+
| logical_plan_| Sort: i1.i ASC NULLS LAST_|
|_|_LeftSemi Join: i1.i = __correlated_sq_1.i_|
|_|_SubqueryAlias: i1_|
|_|_MergeScan [is_placeholder=false]_|
|_|_SubqueryAlias: __correlated_sq_1_|
|_|_Projection: integers.i_|
|_|_MergeScan [is_placeholder=false]_|
| physical_plan | SortPreservingMergeExec: [i@0 ASC NULLS LAST]_|
|_|_SortExec: expr=[i@0 ASC NULLS LAST]_|
|_|_CoalesceBatchesExec: target_batch_size=8192_|
|_|_REDACTED
|_|_CoalesceBatchesExec: target_batch_size=8192_|
|_|_RepartitionExec: partitioning=REDACTED
|_|_RepartitionExec: partitioning=REDACTED
|_|_MergeScanExec: peers=[REDACTED
|_|_CoalesceBatchesExec: target_batch_size=8192_|
|_|_RepartitionExec: partitioning=REDACTED
|_|_RepartitionExec: partitioning=REDACTED
|_|_ProjectionExec: expr=[i@0 as i]_|
|_|_MergeScanExec: peers=[REDACTED
|_|_|
+-+-+
create table other (i INTEGER, j TIMESTAMP TIME INDEX);
Affected Rows: 0
-- SQLNESS REPLACE (-+) -
-- SQLNESS REPLACE (\s\s+) _
-- SQLNESS REPLACE (RoundRobinBatch.*) REDACTED
-- SQLNESS REPLACE (Hash.*) REDACTED
-- SQLNESS REPLACE (peer-.*) REDACTED
explain select t.i
from (
select * from integers join other on 1=1
) t
where t.i is not null
order by t.i desc;
+-+-+
| plan_type_| plan_|
+-+-+
| logical_plan_| Sort: t.i DESC NULLS FIRST_|
|_|_SubqueryAlias: t_|
|_|_Inner Join:_|
|_|_Projection:_|
|_|_MergeScan [is_placeholder=false]_|
|_|_Filter: other.i IS NOT NULL_|
|_|_Projection: other.i_|
|_|_MergeScan [is_placeholder=false]_|
| physical_plan | SortExec: expr=[i@0 DESC]_|
|_|_NestedLoopJoinExec: join_type=Inner_|
|_|_ProjectionExec: expr=[]_|
|_|_MergeScanExec: peers=[REDACTED
|_|_CoalescePartitionsExec_|
|_|_CoalesceBatchesExec: target_batch_size=8192_|
|_|_FilterExec: i@0 IS NOT NULL_|
|_|_RepartitionExec: partitioning=REDACTED
|_|_ProjectionExec: expr=[i@0 as i]_|
|_|_MergeScanExec: peers=[REDACTED
|_|_|
+-+-+
EXPLAIN INSERT INTO other SELECT i, 2 FROM integers WHERE i=(SELECT MAX(i) FROM integers);
+--------------+-------------------------------------------------------------------+
| plan_type | plan |
+--------------+-------------------------------------------------------------------+
| logical_plan | Dml: op=[Insert] table=[other] |
| | Projection: integers.i AS i, TimestampMillisecond(2, None) AS j |
| | Inner Join: integers.i = __scalar_sq_1.MAX(integers.i) |
| | Projection: integers.i |
| | MergeScan [is_placeholder=false] |
| | SubqueryAlias: __scalar_sq_1 |
| | Aggregate: groupBy=[[]], aggr=[[MAX(integers.i)]] |
| | Projection: integers.i |
| | MergeScan [is_placeholder=false] |
+--------------+-------------------------------------------------------------------+
drop table other;
Affected Rows: 1
drop table integers;
Affected Rows: 1

View File

@@ -0,0 +1,35 @@
CREATE TABLE integers(i INTEGER, j TIMESTAMP TIME INDEX);
-- SQLNESS REPLACE (-+) -
-- SQLNESS REPLACE (\s\s+) _
-- SQLNESS REPLACE (RoundRobinBatch.*) REDACTED
-- SQLNESS REPLACE (Hash.*) REDACTED
-- SQLNESS REPLACE (peer-.*) REDACTED
EXPLAIN SELECT * FROM integers WHERE i IN ((SELECT i FROM integers)) ORDER BY i;
-- SQLNESS REPLACE (-+) -
-- SQLNESS REPLACE (\s\s+) _
-- SQLNESS REPLACE (RoundRobinBatch.*) REDACTED
-- SQLNESS REPLACE (Hash.*) REDACTED
-- SQLNESS REPLACE (peer-.*) REDACTED
EXPLAIN SELECT * FROM integers i1 WHERE EXISTS(SELECT i FROM integers WHERE i=i1.i) ORDER BY i1.i;
create table other (i INTEGER, j TIMESTAMP TIME INDEX);
-- SQLNESS REPLACE (-+) -
-- SQLNESS REPLACE (\s\s+) _
-- SQLNESS REPLACE (RoundRobinBatch.*) REDACTED
-- SQLNESS REPLACE (Hash.*) REDACTED
-- SQLNESS REPLACE (peer-.*) REDACTED
explain select t.i
from (
select * from integers join other on 1=1
) t
where t.i is not null
order by t.i desc;
EXPLAIN INSERT INTO other SELECT i, 2 FROM integers WHERE i=(SELECT MAX(i) FROM integers);
drop table other;
drop table integers;