fix: prom cast to f64 (#7840)

* fix: cast to f64

Signed-off-by: discord9 <discord9@163.com>

* test: div case

Signed-off-by: discord9 <discord9@163.com>

* test: int test

Signed-off-by: discord9 <discord9@163.com>

* chore: sqlness update

Signed-off-by: discord9 <discord9@163.com>

* chore: test

Signed-off-by: discord9 <discord9@163.com>

* chore: update test

Signed-off-by: discord9 <discord9@163.com>

---------

Signed-off-by: discord9 <discord9@163.com>
This commit is contained in:
discord9
2026-03-24 14:24:52 +08:00
committed by GitHub
parent 9bd983ea40
commit 30e895abbe
6 changed files with 494 additions and 62 deletions

View File

@@ -3323,28 +3323,55 @@ impl PromPlanner {
fn prom_token_to_binary_expr_builder(
token: TokenType,
) -> Result<Box<dyn Fn(DfExpr, DfExpr) -> Result<DfExpr>>> {
let cast_float = |expr| {
if matches!(
&expr,
DfExpr::Cast(Cast {
data_type: ArrowDataType::Float64,
..
})
) || matches!(&expr, DfExpr::Literal(ScalarValue::Float64(_), _))
{
expr
} else {
DfExpr::Cast(Cast {
expr: Box::new(expr),
data_type: ArrowDataType::Float64,
})
}
};
match token.id() {
token::T_ADD => Ok(Box::new(|lhs, rhs| Ok(lhs + rhs))),
token::T_SUB => Ok(Box::new(|lhs, rhs| Ok(lhs - rhs))),
token::T_MUL => Ok(Box::new(|lhs, rhs| Ok(lhs * rhs))),
token::T_DIV => Ok(Box::new(|lhs, rhs| Ok(lhs / rhs))),
token::T_MOD => Ok(Box::new(|lhs: DfExpr, rhs| Ok(lhs % rhs))),
token::T_ADD => Ok(Box::new(move |lhs, rhs| {
Ok(cast_float(lhs) + cast_float(rhs))
})),
token::T_SUB => Ok(Box::new(move |lhs, rhs| {
Ok(cast_float(lhs) - cast_float(rhs))
})),
token::T_MUL => Ok(Box::new(move |lhs, rhs| {
Ok(cast_float(lhs) * cast_float(rhs))
})),
token::T_DIV => Ok(Box::new(move |lhs, rhs| {
Ok(cast_float(lhs) / cast_float(rhs))
})),
token::T_MOD => Ok(Box::new(move |lhs: DfExpr, rhs| {
Ok(cast_float(lhs) % cast_float(rhs))
})),
token::T_EQLC => Ok(Box::new(|lhs, rhs| Ok(lhs.eq(rhs)))),
token::T_NEQ => Ok(Box::new(|lhs, rhs| Ok(lhs.not_eq(rhs)))),
token::T_GTR => Ok(Box::new(|lhs, rhs| Ok(lhs.gt(rhs)))),
token::T_LSS => Ok(Box::new(|lhs, rhs| Ok(lhs.lt(rhs)))),
token::T_GTE => Ok(Box::new(|lhs, rhs| Ok(lhs.gt_eq(rhs)))),
token::T_LTE => Ok(Box::new(|lhs, rhs| Ok(lhs.lt_eq(rhs)))),
token::T_POW => Ok(Box::new(|lhs, rhs| {
token::T_POW => Ok(Box::new(move |lhs, rhs| {
Ok(DfExpr::ScalarFunction(ScalarFunction {
func: datafusion_functions::math::power(),
args: vec![lhs, rhs],
args: vec![cast_float(lhs), cast_float(rhs)],
}))
})),
token::T_ATAN2 => Ok(Box::new(|lhs, rhs| {
token::T_ATAN2 => Ok(Box::new(move |lhs, rhs| {
Ok(DfExpr::ScalarFunction(ScalarFunction {
func: datafusion_functions::math::atan2(),
args: vec![lhs, rhs],
args: vec![cast_float(lhs), cast_float(rhs)],
}))
})),
_ => UnexpectedTokenSnafu { token }.fail(),
@@ -5169,7 +5196,7 @@ mod test {
.unwrap();
let expected = String::from(
"Projection: rhs.tag_0, rhs.timestamp, lhs.field_0 + rhs.field_0 AS lhs.field_0 + rhs.field_0 [tag_0:Utf8, timestamp:Timestamp(ms), lhs.field_0 + rhs.field_0:Float64;N]\
"Projection: rhs.tag_0, rhs.timestamp, CAST(lhs.field_0 AS Float64) + CAST(rhs.field_0 AS Float64) AS lhs.field_0 + rhs.field_0 [tag_0:Utf8, timestamp:Timestamp(ms), lhs.field_0 + rhs.field_0:Float64;N]\
\n Inner Join: lhs.tag_0 = rhs.tag_0, lhs.timestamp = rhs.timestamp [tag_0:Utf8, timestamp:Timestamp(ms), field_0:Float64;N, tag_0:Utf8, timestamp:Timestamp(ms), field_0:Float64;N]\
\n SubqueryAlias: lhs [tag_0:Utf8, timestamp:Timestamp(ms), field_0:Float64;N]\
\n PromInstantManipulate: range=[0..100000000], lookback=[1000], interval=[5000], time index=[timestamp] [tag_0:Utf8, timestamp:Timestamp(ms), field_0:Float64;N]\
@@ -5224,7 +5251,7 @@ mod test {
async fn binary_op_literal_column() {
let query = r#"1 + some_metric{tag_0="bar"}"#;
let expected = String::from(
"Projection: some_metric.tag_0, some_metric.timestamp, Float64(1) + some_metric.field_0 AS Float64(1) + field_0 [tag_0:Utf8, timestamp:Timestamp(ms), Float64(1) + field_0:Float64;N]\
"Projection: some_metric.tag_0, some_metric.timestamp, Float64(1) + CAST(some_metric.field_0 AS Float64) AS Float64(1) + field_0 [tag_0:Utf8, timestamp:Timestamp(ms), Float64(1) + field_0:Float64;N]\
\n PromInstantManipulate: range=[0..100000000], lookback=[1000], interval=[5000], time index=[timestamp] [tag_0:Utf8, timestamp:Timestamp(ms), field_0:Float64;N]\
\n PromSeriesDivide: tags=[\"tag_0\"] [tag_0:Utf8, timestamp:Timestamp(ms), field_0:Float64;N]\
\n Sort: some_metric.tag_0 ASC NULLS FIRST, some_metric.timestamp ASC NULLS FIRST [tag_0:Utf8, timestamp:Timestamp(ms), field_0:Float64;N]\
@@ -5262,7 +5289,7 @@ mod test {
async fn bool_with_additional_arithmetic() {
let query = "some_metric + (1 == bool 2)";
let expected = String::from(
"Projection: some_metric.tag_0, some_metric.timestamp, some_metric.field_0 + CAST(Float64(1) = Float64(2) AS Float64) AS field_0 + Float64(1) = Float64(2) [tag_0:Utf8, timestamp:Timestamp(ms), field_0 + Float64(1) = Float64(2):Float64;N]\
"Projection: some_metric.tag_0, some_metric.timestamp, CAST(some_metric.field_0 AS Float64) + CAST(Float64(1) = Float64(2) AS Float64) AS field_0 + Float64(1) = Float64(2) [tag_0:Utf8, timestamp:Timestamp(ms), field_0 + Float64(1) = Float64(2):Float64;N]\
\n PromInstantManipulate: range=[0..100000000], lookback=[1000], interval=[5000], time index=[timestamp] [tag_0:Utf8, timestamp:Timestamp(ms), field_0:Float64;N]\
\n PromSeriesDivide: tags=[\"tag_0\"] [tag_0:Utf8, timestamp:Timestamp(ms), field_0:Float64;N]\
\n Sort: some_metric.tag_0 ASC NULLS FIRST, some_metric.timestamp ASC NULLS FIRST [tag_0:Utf8, timestamp:Timestamp(ms), field_0:Float64;N]\
@@ -5372,7 +5399,7 @@ mod test {
PromPlanner::stmt_to_plan(table_provider, &eval_stmt, &build_query_engine_state())
.await
.unwrap();
let expected = "Projection: http_server_requests_seconds_count.uri, http_server_requests_seconds_count.kubernetes_namespace, http_server_requests_seconds_count.kubernetes_pod_name, http_server_requests_seconds_count.greptime_timestamp, http_server_requests_seconds_sum.greptime_value / http_server_requests_seconds_count.greptime_value AS http_server_requests_seconds_sum.greptime_value / http_server_requests_seconds_count.greptime_value\
let expected = "Projection: http_server_requests_seconds_count.uri, http_server_requests_seconds_count.kubernetes_namespace, http_server_requests_seconds_count.kubernetes_pod_name, http_server_requests_seconds_count.greptime_timestamp, CAST(http_server_requests_seconds_sum.greptime_value AS Float64) / CAST(http_server_requests_seconds_count.greptime_value AS Float64) AS http_server_requests_seconds_sum.greptime_value / http_server_requests_seconds_count.greptime_value\
\n Inner Join: http_server_requests_seconds_sum.greptime_timestamp = http_server_requests_seconds_count.greptime_timestamp, http_server_requests_seconds_sum.uri = http_server_requests_seconds_count.uri\
\n SubqueryAlias: http_server_requests_seconds_sum\
\n PromInstantManipulate: range=[0..100000000], lookback=[1000], interval=[5000], time index=[greptime_timestamp]\
@@ -5763,7 +5790,7 @@ mod test {
let query = "some_alt_metric{__schema__=\"greptime_private\"} / some_metric";
let expected = String::from(
"Projection: some_metric.tag_0, some_metric.timestamp, greptime_private.some_alt_metric.field_0 / some_metric.field_0 AS greptime_private.some_alt_metric.field_0 / some_metric.field_0 [tag_0:Utf8, timestamp:Timestamp(ms), greptime_private.some_alt_metric.field_0 / some_metric.field_0:Float64;N]\
"Projection: some_metric.tag_0, some_metric.timestamp, CAST(greptime_private.some_alt_metric.field_0 AS Float64) / CAST(some_metric.field_0 AS Float64) AS greptime_private.some_alt_metric.field_0 / some_metric.field_0 [tag_0:Utf8, timestamp:Timestamp(ms), greptime_private.some_alt_metric.field_0 / some_metric.field_0:Float64;N]\
\n Inner Join: greptime_private.some_alt_metric.tag_0 = some_metric.tag_0, greptime_private.some_alt_metric.timestamp = some_metric.timestamp [tag_0:Utf8, timestamp:Timestamp(ms), field_0:Float64;N, tag_0:Utf8, timestamp:Timestamp(ms), field_0:Float64;N]\
\n SubqueryAlias: greptime_private.some_alt_metric [tag_0:Utf8, timestamp:Timestamp(ms), field_0:Float64;N]\
\n PromInstantManipulate: range=[0..100000000], lookback=[1000], interval=[5000], time index=[timestamp] [tag_0:Utf8, timestamp:Timestamp(ms), field_0:Float64;N]\