feat: include order by in commutativity rule set (#4753)

* feat: include order by in commutativity rule set

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* tune sqlness replace interceptor

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

---------

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>
Ruihang Xia
2024-09-23 16:35:06 +08:00
committed by GitHub
parent 0f99218386
commit 2feddca1cb
4 changed files with 224 additions and 19 deletions
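Conceptually, this change lets the planner treat a Sort (ORDER BY) as partially commutative with the per-region part of the plan: each region orders its own partition, the Sort is stashed as a "stage", and the same Sort is re-applied on top of the MergeScan node so the merged result is globally ordered; the new `column_requirements` bookkeeping makes sure the sort keys survive any pushed-down projection. The sketch below is a simplified, hypothetical rendering of that re-application step (only `expressions` and `with_new_exprs` are real DataFusion calls, mirroring the `expand` method in the diff; the helper name is made up):

    // Hedged sketch: re-apply stashed stage nodes (e.g. the Sort for ORDER BY)
    // on top of the merged per-region plan, keeping each stage's own expressions
    // and swapping in the merged plan as its sole input.
    use datafusion_common::Result;
    use datafusion_expr::LogicalPlan;

    fn reapply_stages(merged: LogicalPlan, stages: Vec<LogicalPlan>) -> Result<LogicalPlan> {
        let mut node = merged;
        for stage in stages {
            // Same expressions (e.g. the ORDER BY keys), new input: the merged plan.
            node = stage.with_new_exprs(stage.expressions(), vec![node])?;
        }
        Ok(node)
    }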


@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::HashSet;
use std::sync::Arc;
use datafusion::datasource::DefaultTableSource;
@@ -19,7 +20,8 @@ use datafusion::error::Result as DfResult;
use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
use datafusion_expr::expr::{Exists, InSubquery};
use datafusion_expr::{col, Expr, LogicalPlan, LogicalPlanBuilder, Subquery};
use datafusion_expr::utils::expr_to_columns;
use datafusion_expr::{col as col_fn, Expr, LogicalPlan, LogicalPlanBuilder, Subquery};
use datafusion_optimizer::analyzer::AnalyzerRule;
use datafusion_optimizer::simplify_expressions::SimplifyExpressions;
use datafusion_optimizer::{OptimizerContext, OptimizerRule};
@@ -104,7 +106,7 @@ impl DistPlannerAnalyzer {
let project_exprs = output_schema
.fields()
.iter()
.map(|f| col(f.name()))
.map(|f| col_fn(f.name()))
.collect::<Vec<_>>();
rewrote_subquery = LogicalPlanBuilder::from(rewrote_subquery)
.project(project_exprs)?
@@ -137,6 +139,7 @@ struct PlanRewriter {
status: RewriterStatus,
/// Partition columns of the table in current pass
partition_cols: Option<Vec<String>>,
column_requirements: HashSet<String>,
}
impl PlanRewriter {
@@ -162,6 +165,7 @@ impl PlanRewriter {
Commutativity::Commutative => {}
Commutativity::PartialCommutative => {
if let Some(plan) = partial_commutative_transformer(plan) {
self.update_column_requirements(&plan);
self.stage.push(plan)
}
}
@@ -169,6 +173,7 @@ impl PlanRewriter {
if let Some(transformer) = transformer
&& let Some(plan) = transformer(plan)
{
self.update_column_requirements(&plan);
self.stage.push(plan)
}
}
@@ -176,6 +181,7 @@ impl PlanRewriter {
if let Some(transformer) = transformer
&& let Some(plan) = transformer(plan)
{
self.update_column_requirements(&plan);
self.stage.push(plan)
}
}
@@ -189,6 +195,18 @@ impl PlanRewriter {
false
}
fn update_column_requirements(&mut self, plan: &LogicalPlan) {
let mut container = HashSet::new();
for expr in plan.expressions() {
// this method won't fail
let _ = expr_to_columns(&expr, &mut container);
}
for col in container {
self.column_requirements.insert(col.flat_name());
}
}
fn is_expanded(&self) -> bool {
self.status == RewriterStatus::Expanded
}
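As a side note on the `update_column_requirements` helper in the hunk above: DataFusion's `expr_to_columns` walks an expression and collects every `Column` it references, and `Column::flat_name()` gives the qualified string form that ends up in `column_requirements`. A minimal, standalone sketch (the column name and literals are made up for illustration):

    use std::collections::HashSet;

    use datafusion_common::{Column, Result};
    use datafusion_expr::utils::expr_to_columns;
    use datafusion_expr::{col, lit, Expr};

    fn main() -> Result<()> {
        // An expression equivalent to `number + 1 > 10`; it references one column.
        let expr: Expr = (col("number") + lit(1)).gt(lit(10));

        let mut columns: HashSet<Column> = HashSet::new();
        expr_to_columns(&expr, &mut columns)?;

        // flat_name() returns "table.column" for qualified columns, or just the
        // bare column name for unqualified ones ("number" here).
        let names: HashSet<String> = columns.iter().map(|c| c.flat_name()).collect();
        assert!(names.contains("number"));
        Ok(())
    }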
@@ -238,6 +256,67 @@ impl PlanRewriter {
self.level -= 1;
self.stack.pop();
}
fn expand(&mut self, mut on_node: LogicalPlan) -> DfResult<LogicalPlan> {
let mut rewriter = EnforceDistRequirementRewriter {
column_requirements: std::mem::take(&mut self.column_requirements),
};
on_node = on_node.rewrite(&mut rewriter)?.data;
// add merge scan as the new root
let mut node = MergeScanLogicalPlan::new(on_node, false).into_logical_plan();
// expand stages
for new_stage in self.stage.drain(..) {
node = new_stage.with_new_exprs(new_stage.expressions(), vec![node.clone()])?
}
self.set_expanded();
Ok(node)
}
}
/// Implementation of the [`TreeNodeRewriter`] trait which is responsible for rewriting
/// logical plans to enforce various requirements for distributed queries.
///
/// Requirements enforced by this rewriter:
/// - Enforce column requirements for `LogicalPlan::Projection` nodes. Makes sure the
/// required columns are available in the sub plan.
struct EnforceDistRequirementRewriter {
column_requirements: HashSet<String>,
}
impl TreeNodeRewriter for EnforceDistRequirementRewriter {
type Node = LogicalPlan;
fn f_down(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
if let LogicalPlan::Projection(ref projection) = node {
let mut column_requirements = std::mem::take(&mut self.column_requirements);
if column_requirements.is_empty() {
return Ok(Transformed::no(node));
}
for expr in &projection.expr {
column_requirements.remove(&expr.name_for_alias()?);
}
if column_requirements.is_empty() {
return Ok(Transformed::no(node));
}
let mut new_exprs = projection.expr.clone();
for col in &column_requirements {
new_exprs.push(col_fn(col));
}
let new_node =
node.with_new_exprs(new_exprs, node.inputs().into_iter().cloned().collect())?;
return Ok(Transformed::yes(new_node));
}
Ok(Transformed::no(node))
}
fn f_up(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
Ok(Transformed::no(node))
}
}
impl TreeNodeRewriter for PlanRewriter {
@@ -274,14 +353,7 @@ impl TreeNodeRewriter for PlanRewriter {
self.maybe_set_partitions(&node);
let Some(parent) = self.get_parent() else {
// add merge scan as the new root
let mut node = MergeScanLogicalPlan::new(node, false).into_logical_plan();
// expand stages
for new_stage in self.stage.drain(..) {
node = new_stage.with_new_exprs(node.expressions(), vec![node.clone()])?
}
self.set_expanded();
let node = self.expand(node)?;
self.pop_stack();
return Ok(Transformed::yes(node));
};
@@ -289,14 +361,7 @@ impl TreeNodeRewriter for PlanRewriter {
// TODO(ruihang): avoid this clone
if self.should_expand(&parent.clone()) {
// TODO(ruihang): does this work for nodes with multiple children?
// replace the current node with expanded one
let mut node = MergeScanLogicalPlan::new(node, false).into_logical_plan();
// expand stages
for new_stage in self.stage.drain(..) {
node = new_stage.with_new_exprs(node.expressions(), vec![node.clone()])?
}
self.set_expanded();
let node = self.expand(node)?;
self.pop_stack();
return Ok(Transformed::yes(node));
}
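The core of `EnforceDistRequirementRewriter::f_down` above boils down to: take a projection's output expressions plus the previously collected set of required column names (for example the ORDER BY keys of a re-applied Sort), and append any required column the projection does not already expose. A hedged, self-contained sketch of that step follows; the helper name and the deterministic ordering are illustrative additions, not the project's code:

    use std::collections::HashSet;

    use datafusion_common::Result;
    use datafusion_expr::{col, Expr};

    // Append still-missing required columns to a projection's expression list so
    // that an operator above (e.g. the re-applied Sort) can still see them.
    fn enforce_requirements(mut exprs: Vec<Expr>, required: &HashSet<String>) -> Result<Vec<Expr>> {
        let mut missing = required.clone();
        for expr in &exprs {
            // `name_for_alias` is the output name this projection exposes for the expression.
            missing.remove(&expr.name_for_alias()?);
        }
        let mut missing: Vec<String> = missing.into_iter().collect();
        missing.sort(); // deterministic order for the appended columns
        exprs.extend(missing.iter().map(|name| col(name.as_str())));
        Ok(exprs)
    }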


@@ -69,7 +69,7 @@ impl Categorizer {
// sort plan needs to consider column priority
// We can implement a merge-sort on partial ordered data
Commutativity::Unimplemented
Commutativity::PartialCommutative
}
LogicalPlan::Join(_) => Commutativity::NonCommutative,
LogicalPlan::CrossJoin(_) => Commutativity::NonCommutative,
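Why partial commutativity is enough for ORDER BY: sorting each region's partition locally does not by itself produce a global order, but re-applying the same sort above the merged partitions does (a merge of already-sorted streams, as the comment above hints, would be a further optimization). A tiny plain-Rust illustration of that idea:

    fn main() {
        let mut region_a = vec![3, 1, 9];
        let mut region_b = vec![2, 8, 4];

        // Each region sorts its own partition (the pushed-down Sort).
        region_a.sort();
        region_b.sort();

        // Concatenating the sorted partitions is not globally sorted yet...
        let mut merged: Vec<i32> = region_a.into_iter().chain(region_b).collect();
        assert_ne!(merged, vec![1, 2, 3, 4, 8, 9]);

        // ...so the same Sort is applied once more above the merge node.
        merged.sort();
        assert_eq!(merged, vec![1, 2, 3, 4, 8, 9]);
    }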