diff --git a/src/operator/src/statement/ddl.rs b/src/operator/src/statement/ddl.rs index d1c380306b..5debbb0548 100644 --- a/src/operator/src/statement/ddl.rs +++ b/src/operator/src/statement/ddl.rs @@ -2169,6 +2169,7 @@ mod test { use sql::dialect::GreptimeDbDialect; use sql::parser::{ParseOptions, ParserContext}; use sql::statements::statement::Statement; + use sqlparser::parser::Parser; use super::*; use crate::expr_helper; @@ -2186,6 +2187,39 @@ mod test { assert!(!NAME_PATTERN_REG.is_match("#")); } + #[test] + fn test_partition_expr_equivalence_with_swapped_operands() { + let column_name = "device_id".to_string(); + let column_name_and_type = + HashMap::from([(&column_name, ConcreteDataType::int32_datatype())]); + let timezone = Timezone::from_tz_string("UTC").unwrap(); + let dialect = GreptimeDbDialect {}; + + let mut parser = Parser::new(&dialect) + .try_with_sql("device_id < 100") + .unwrap(); + let expr_left = parser.parse_expr().unwrap(); + + let mut parser = Parser::new(&dialect) + .try_with_sql("100 > device_id") + .unwrap(); + let expr_right = parser.parse_expr().unwrap(); + + let partition_left = + convert_one_expr(&expr_left, &column_name_and_type, &timezone).unwrap(); + let partition_right = + convert_one_expr(&expr_right, &column_name_and_type, &timezone).unwrap(); + + assert_eq!(partition_left, partition_right); + assert!([partition_left.clone()].contains(&partition_right)); + + let mut physical_partition_exprs = vec![partition_left]; + let mut logical_partition_exprs = vec![partition_right]; + physical_partition_exprs.sort_unstable(); + logical_partition_exprs.sort_unstable(); + assert_eq!(physical_partition_exprs, logical_partition_exprs); + } + #[tokio::test] #[ignore = "TODO(ruihang): WIP new partition rule"] async fn test_parse_partitions() { diff --git a/src/partition/src/expr.rs b/src/partition/src/expr.rs index 31eb910fdb..317a9b2711 100644 --- a/src/partition/src/expr.rs +++ b/src/partition/src/expr.rs @@ -185,6 +185,19 @@ impl RestrictedOp { Self::Or => ParserBinaryOperator::Or, } } + + fn invert_for_swap(&self) -> Self { + match self { + Self::Eq => Self::Eq, + Self::NotEq => Self::NotEq, + Self::Lt => Self::Gt, + Self::LtEq => Self::GtEq, + Self::Gt => Self::Lt, + Self::GtEq => Self::LtEq, + Self::And => Self::And, + Self::Or => Self::Or, + } + } } impl Display for RestrictedOp { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { @@ -208,6 +221,32 @@ impl PartitionExpr { op, rhs: Box::new(rhs), } + .canonicalize() + } + + /// Canonicalize to `Column op Value` form when possible for consistent equality checks. + pub fn canonicalize(self) -> Self { + let lhs = Self::canonicalize_operand(*self.lhs); + let rhs = Self::canonicalize_operand(*self.rhs); + let mut expr = Self { + lhs: Box::new(lhs), + op: self.op, + rhs: Box::new(rhs), + }; + + if matches!(&*expr.lhs, Operand::Value(_)) && matches!(&*expr.rhs, Operand::Column(_)) { + std::mem::swap(&mut expr.lhs, &mut expr.rhs); + expr.op = expr.op.invert_for_swap(); + } + + expr + } + + fn canonicalize_operand(operand: Operand) -> Operand { + match operand { + Operand::Expr(expr) => Operand::Expr(expr.canonicalize()), + other => other, + } } /// Convert [Self] back to sqlparser's [Expr] @@ -354,7 +393,7 @@ impl PartitionExpr { let bound: PartitionBound = serde_json::from_str(s).context(error::DeserializeJsonSnafu)?; match bound { - PartitionBound::Expr(expr) => Ok(Some(expr)), + PartitionBound::Expr(expr) => Ok(Some(expr.canonicalize())), _ => Ok(None), } }