tests: lossy downcast

Signed-off-by: discord9 <discord9@163.com>
This commit is contained in:
discord9
2026-04-27 20:47:02 +08:00
parent 6de646adc7
commit f921fd0a70
5 changed files with 295 additions and 2 deletions

View File

@@ -309,9 +309,12 @@ mod tests {
use std::collections::HashMap;
use std::sync::Arc;
use datafusion_common::arrow::datatypes::Field;
use datafusion::execution::SessionStateBuilder;
use datafusion_common::arrow::datatypes::{Field, TimeUnit as ArrowTimeUnit};
use datafusion_common::{Column, DFSchema};
use datafusion_expr::simplify::SimplifyContext;
use datafusion_expr::{Literal, LogicalPlanBuilder};
use datafusion_optimizer::simplify_expressions::ExprSimplifier;
use datafusion_sql::TableReference;
use session::context::QueryContext;
@@ -384,6 +387,102 @@ mod tests {
);
}
/// Pins how DataFusion's `ExprSimplifier` currently rewrites comparisons between a
/// nanosecond `ts` column cast to millisecond precision and a millisecond literal:
/// the cast is removed and the literal is rescaled to nanoseconds.
///
/// TODO(discord9): update these expectations once DataFusion is upgraded and the
/// lossy downcast problem is fixed.
#[test]
fn test_datafusion_simplifier_unwraps_timestamp_precision_cast_comparisons() {
    // Single non-nullable, unqualified nanosecond timestamp column named `ts`.
    let ts_field = Field::new(
        "ts",
        DataType::Timestamp(ArrowTimeUnit::Nanosecond, None),
        false,
    );
    let df_schema =
        DFSchema::from_unqualified_fields(vec![Arc::new(ts_field)].into(), HashMap::new())
            .unwrap();
    let simplifier =
        ExprSimplifier::new(SimplifyContext::default().with_schema(Arc::new(df_schema)));

    let column = Expr::Column(Column::from_name("ts"));
    let casted = Expr::Cast(datafusion_expr::Cast {
        expr: Box::new(column.clone()),
        data_type: DataType::Timestamp(ArrowTimeUnit::Millisecond, None),
    });
    let millis = ScalarValue::TimestampMillisecond(Some(1000), None).lit();
    let nanos = ScalarValue::TimestampNanosecond(Some(1_000_000_000), None).lit();

    // (expression handed to the simplifier, expression expected back)
    let cases = [
        (
            casted.clone().eq(millis.clone()),
            column.clone().eq(nanos.clone()),
        ),
        (
            casted.clone().not_eq(millis.clone()),
            column.clone().not_eq(nanos.clone()),
        ),
        (
            casted.clone().lt(millis.clone()),
            column.clone().lt(nanos.clone()),
        ),
        (
            casted.clone().lt_eq(millis.clone()),
            column.clone().lt_eq(nanos.clone()),
        ),
        (
            casted.clone().gt(millis.clone()),
            column.clone().gt(nanos.clone()),
        ),
        (
            casted.clone().gt_eq(millis.clone()),
            column.clone().gt_eq(nanos.clone()),
        ),
        // Literal on the left-hand side: the comparison comes back flipped.
        (millis.lt(casted), column.gt(nanos)),
    ];

    for (input, expected) in cases {
        assert_eq!(simplifier.simplify(input).unwrap(), expected);
    }
}
/// Builds `VALUES -> Projection(CAST(column1 AS ms) AS ts_ms, column2 AS val) ->
/// Filter(ts_ms = 1000ms)`, runs it through the default session optimizer and checks
/// the rendered plan: the filter must be expressed directly on `column1` with a
/// nanosecond literal, and must appear after the `Projection:` line in the output.
#[test]
fn test_datafusion_optimizer_pushes_filter_through_timestamp_cast_projection() {
    let projected_ts = Expr::Cast(datafusion_expr::Cast {
        expr: Box::new(Expr::Column(Column::from_name("column1"))),
        data_type: DataType::Timestamp(ArrowTimeUnit::Millisecond, None),
    })
    .alias("ts_ms");
    let projected_val = Expr::Column(Column::from_name("column2")).alias("val");

    // One input row; the timestamp carries a sub-millisecond remainder (…123 ns).
    let row = vec![
        ScalarValue::TimestampNanosecond(Some(1_000_000_123), None).lit(),
        1_i64.lit(),
    ];
    let predicate = Expr::Column(Column::from_name("ts_ms"))
        .eq(ScalarValue::TimestampMillisecond(Some(1000), None).lit());

    let plan = LogicalPlanBuilder::values(vec![row])
        .unwrap()
        .project(vec![projected_ts, projected_val])
        .unwrap()
        .filter(predicate)
        .unwrap()
        .build()
        .unwrap();

    let session_state = SessionStateBuilder::new().with_default_features().build();
    let optimized_plan = session_state.optimize(&plan).unwrap();
    let optimized = optimized_plan.display_indent().to_string();

    let projection_at = optimized.find("Projection:");
    let filter_at = optimized.find("Filter:");

    assert!(projection_at.is_some(), "{optimized}");
    assert!(
        optimized.contains("Filter: column1 = TimestampNanosecond(1000000000, None)"),
        "{optimized}"
    );
    // Both offsets are `Some` here, so this compares their positions in the text.
    assert!(projection_at < filter_at, "{optimized}");
}
#[test]
fn test_convert_timestamp_str() {
use datatypes::arrow::datatypes::TimeUnit as ArrowTimeUnit;

View File

@@ -22,7 +22,10 @@ use common_recordbatch::{RecordBatch, SendableRecordBatchStream};
use common_time::Timestamp;
use common_time::range::TimestampRange;
use common_time::timestamp::TimeUnit;
use datafusion_common::ScalarValue;
use datafusion_expr::expr::Expr;
use datafusion_expr::{col, lit};
use datatypes::arrow::datatypes::{DataType, TimeUnit as ArrowTimeUnit};
use datatypes::data_type::ConcreteDataType;
use datatypes::schema::{ColumnSchema, Schema};
use datatypes::vectors::{Int64Vector, TimestampMillisecondVector};
@@ -135,6 +138,77 @@ impl TimeRangeTester {
}
}
/// Returns an expression that casts the column `name` to a millisecond timestamp
/// without a timezone.
fn cast_to_ms_col(name: &str) -> Expr {
    let target_type = DataType::Timestamp(ArrowTimeUnit::Millisecond, None);
    let column = Box::new(col(name));
    Expr::Cast(datafusion_expr::Cast {
        expr: column,
        data_type: target_type,
    })
}
/// Returns a literal expression holding `value` as a millisecond timestamp without a
/// timezone.
fn ms_lit(value: i64) -> Expr {
    let scalar = ScalarValue::TimestampMillisecond(Some(value), None);
    lit(scalar)
}
/// Checks the time ranges `build_time_range_predicate` derives when the `ts` time
/// index (microsecond or nanosecond) is cast to milliseconds and compared against the
/// literal `1000ms`, including which boundary timestamps the `=` range contains.
#[test]
fn test_casted_time_index_precision_boundaries() {
    let cast_ts = cast_to_ms_col("ts");
    // Range derived for a single filter on the `ts` time index of the given unit.
    let range_of =
        |unit: TimeUnit, filter: Expr| build_time_range_predicate("ts", unit, &[filter]);

    // --- Microsecond time index ---
    let us_bucket =
        TimestampRange::with_unit(1_000_000, 1_001_000, TimeUnit::Microsecond).unwrap();
    assert_eq!(
        us_bucket,
        range_of(TimeUnit::Microsecond, cast_ts.clone().eq(ms_lit(1000)))
    );
    let last_us_inside = Timestamp::new(1_000_999, TimeUnit::Microsecond);
    let first_us_outside = Timestamp::new(1_001_000, TimeUnit::Microsecond);
    assert!(us_bucket.contains(&last_us_inside));
    assert!(!us_bucket.contains(&first_us_outside));

    let us_lt_eq_expected =
        TimestampRange::until_end(Timestamp::new(1_001_000, TimeUnit::Microsecond), false);
    assert_eq!(
        us_lt_eq_expected,
        range_of(TimeUnit::Microsecond, cast_ts.clone().lt_eq(ms_lit(1000)))
    );

    // --- Nanosecond time index ---
    let ns_bucket =
        TimestampRange::with_unit(1_000_000_000, 1_001_000_000, TimeUnit::Nanosecond).unwrap();
    assert_eq!(
        ns_bucket,
        range_of(TimeUnit::Nanosecond, cast_ts.clone().eq(ms_lit(1000)))
    );
    let last_ns_inside = Timestamp::new(1_000_999_999, TimeUnit::Nanosecond);
    let first_ns_outside = Timestamp::new(1_001_000_000, TimeUnit::Nanosecond);
    assert!(ns_bucket.contains(&last_ns_inside));
    assert!(!ns_bucket.contains(&first_ns_outside));

    let ns_gt_eq_expected =
        TimestampRange::from_start(Timestamp::new(1_000_000_000, TimeUnit::Nanosecond));
    assert_eq!(
        ns_gt_eq_expected,
        range_of(TimeUnit::Nanosecond, cast_ts.clone().gt_eq(ms_lit(1000)))
    );

    let ns_gt_expected =
        TimestampRange::from_start(Timestamp::new(1_001_000_000, TimeUnit::Nanosecond));
    assert_eq!(
        ns_gt_expected,
        range_of(TimeUnit::Nanosecond, cast_ts.clone().gt(ms_lit(1000)))
    );

    let ns_lt_expected =
        TimestampRange::until_end(Timestamp::new(1_000_000_000, TimeUnit::Nanosecond), false);
    assert_eq!(
        ns_lt_expected,
        range_of(TimeUnit::Nanosecond, cast_ts.lt(ms_lit(1000)))
    );
}
#[tokio::test]
async fn test_range_filter() {
let tester = create_test_engine();

View File

@@ -323,7 +323,7 @@ fn get_casted_timestamp_filter(
) -> Option<TimestampRange> {
let (lit, op) = match (left, right) {
(expr, Expr::Literal(scalar, _)) if is_casted_time_index(expr, ts_col_name) => {
(scalar, op.clone())
(scalar, *op)
}
(Expr::Literal(scalar, _), expr) if is_casted_time_index(expr, ts_col_name) => {
(scalar, reverse_operator(op)?)

View File

@@ -0,0 +1,75 @@
-- Corresponding to issue #7913.
-- Verify a filter over a projected millisecond cast of a non-ms time index
-- is passed down to scan as a casted time-index predicate for pruning.
CREATE TABLE cast_time_index_filter_pushdown (
ts TIMESTAMP_NS NOT NULL TIME INDEX,
val BIGINT,
) ENGINE = mito
WITH
(append_mode = 'true', sst_format = 'flat');
Affected Rows: 0
INSERT INTO cast_time_index_filter_pushdown VALUES
('2023-06-12 01:04:49.999999999'::TIMESTAMP_NS, 1),
('2023-06-12 01:04:50.000000123'::TIMESTAMP_NS, 2),
('2023-06-12 01:04:50.999999999'::TIMESTAMP_NS, 3),
('2023-06-12 01:04:51.000000000'::TIMESTAMP_NS, 4);
Affected Rows: 4
ADMIN FLUSH_TABLE ('cast_time_index_filter_pushdown');
+------------------------------------------------------+
| ADMIN FLUSH_TABLE('cast_time_index_filter_pushdown') |
+------------------------------------------------------+
| 0 |
+------------------------------------------------------+
-- SQLNESS REPLACE (-+) -
-- SQLNESS REPLACE (\s\s+) _
-- SQLNESS REPLACE (peers.*) REDACTED
-- SQLNESS REPLACE (metrics.*) REDACTED
-- SQLNESS REPLACE region=\d+\(\d+,\s+\d+\) region=REDACTED
-- SQLNESS REPLACE num_ranges=\d+ num_ranges=REDACTED
-- SQLNESS REPLACE (RepartitionExec:.*) RepartitionExec: REDACTED
-- SQLNESS REPLACE "flat_format":\s\w+, "flat_format": REDACTED,
EXPLAIN ANALYZE VERBOSE
SELECT ts_ms, val
FROM (
SELECT ts::TIMESTAMP_MS AS ts_ms, val
FROM cast_time_index_filter_pushdown
) projected
WHERE ts_ms = '2023-06-12 01:04:50'::TIMESTAMP_MS
ORDER BY val;
+-+-+-+
| stage | node | plan_|
+-+-+-+
| 0_| 0_|_SortPreservingMergeExec: [val@1 ASC NULLS LAST] REDACTED
|_|_|_SortExec: expr=[val@1 ASC NULLS LAST], preserve_partitioning=[true] REDACTED
|_|_|_ProjectionExec: expr=[CAST(ts@0 AS Timestamp(ms)) as ts_ms, val@1 as val] REDACTED
|_|_|_FilterExec: ts@0 = 1686531890000000000 REDACTED
|_|_|_MergeScanExec: REDACTED
|_|_|_|
| 1_| 0_|_CooperativeExec REDACTED
|_|_|_UnorderedScan: region=REDACTED, {"partition_count":{"count":1, "mem_ranges":0, "files":1, "file_ranges":1}, "projection": ["ts", "val"], "files": [{"file_id":"4398046511104(1024, 0)/eaed16e2-3420-4acb-aa01-e2d2c5e30fa9","time_range_start":"1686531889999999999::Nanosecond","time_range_end":"1686531891000000000::Nanosecond","rows":4,"size":2356,"index_size":0}], "flat_format": REDACTED, "REDACTED
|_|_|_|
|_|_| Total rows: 0_|
+-+-+-+
SELECT ts_ms, val
FROM (
SELECT ts::TIMESTAMP_MS AS ts_ms, val
FROM cast_time_index_filter_pushdown
) projected
WHERE ts_ms = '2023-06-12 01:04:50'::TIMESTAMP_MS
ORDER BY val;
++
++
DROP TABLE cast_time_index_filter_pushdown;
Affected Rows: 0

View File

@@ -0,0 +1,45 @@
-- Corresponding to issue #7913.
-- Verify a filter over a projected millisecond cast of a non-ms time index
-- is passed down to scan as a casted time-index predicate for pruning.
CREATE TABLE cast_time_index_filter_pushdown (
ts TIMESTAMP_NS NOT NULL TIME INDEX,
val BIGINT,
) ENGINE = mito
WITH
(append_mode = 'true', sst_format = 'flat');
INSERT INTO cast_time_index_filter_pushdown VALUES
('2023-06-12 01:04:49.999999999'::TIMESTAMP_NS, 1),
('2023-06-12 01:04:50.000000123'::TIMESTAMP_NS, 2),
('2023-06-12 01:04:50.999999999'::TIMESTAMP_NS, 3),
('2023-06-12 01:04:51.000000000'::TIMESTAMP_NS, 4);
ADMIN FLUSH_TABLE ('cast_time_index_filter_pushdown');
-- SQLNESS REPLACE (-+) -
-- SQLNESS REPLACE (\s\s+) _
-- SQLNESS REPLACE (peers.*) REDACTED
-- SQLNESS REPLACE (metrics.*) REDACTED
-- SQLNESS REPLACE region=\d+\(\d+,\s+\d+\) region=REDACTED
-- SQLNESS REPLACE num_ranges=\d+ num_ranges=REDACTED
-- SQLNESS REPLACE (RepartitionExec:.*) RepartitionExec: REDACTED
-- SQLNESS REPLACE "flat_format":\s\w+, "flat_format": REDACTED,
EXPLAIN ANALYZE VERBOSE
SELECT ts_ms, val
FROM (
SELECT ts::TIMESTAMP_MS AS ts_ms, val
FROM cast_time_index_filter_pushdown
) projected
WHERE ts_ms = '2023-06-12 01:04:50'::TIMESTAMP_MS
ORDER BY val;
SELECT ts_ms, val
FROM (
SELECT ts::TIMESTAMP_MS AS ts_ms, val
FROM cast_time_index_filter_pushdown
) projected
WHERE ts_ms = '2023-06-12 01:04:50'::TIMESTAMP_MS
ORDER BY val;
DROP TABLE cast_time_index_filter_pushdown;