feat: support function aliases and add MySQL-compatible aliases (#7410)

* feat: support function aliases and add MySQL-compatible aliases

Signed-off-by: Dennis Zhuang <killme2008@gmail.com>

* fix: get_table_function_source

Signed-off-by: Dennis Zhuang <killme2008@gmail.com>

* refactor: add function_alias mod

Signed-off-by: Dennis Zhuang <killme2008@gmail.com>

* fix: license

Signed-off-by: Dennis Zhuang <killme2008@gmail.com>

---------

Signed-off-by: Dennis Zhuang <killme2008@gmail.com>
This commit is contained in:
dennis zhuang
2025-12-16 14:56:23 +08:00
committed by GitHub
parent f7d5c87ac0
commit 2dfcf35fee
4 changed files with 228 additions and 5 deletions

View File

@@ -41,6 +41,8 @@ use snafu::{Location, ResultExt};
use crate::error::{CatalogSnafu, Result};
use crate::query_engine::{DefaultPlanDecoder, QueryEngineState};
mod function_alias;
pub struct DfContextProviderAdapter {
engine_state: Arc<QueryEngineState>,
session_state: SessionState,
@@ -147,7 +149,17 @@ impl ContextProvider for DfContextProviderAdapter {
fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>> {
self.engine_state.scalar_function(name).map_or_else(
|| self.session_state.scalar_functions().get(name).cloned(),
|| {
self.session_state
.scalar_functions()
.get(name)
.cloned()
.or_else(|| {
function_alias::resolve_scalar(name).and_then(|name| {
self.session_state.scalar_functions().get(name).cloned()
})
})
},
|func| {
Some(Arc::new(func.provide(FunctionContext {
query_ctx: self.query_ctx.clone(),
@@ -159,7 +171,17 @@ impl ContextProvider for DfContextProviderAdapter {
fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
self.engine_state.aggr_function(name).map_or_else(
|| self.session_state.aggregate_functions().get(name).cloned(),
|| {
self.session_state
.aggregate_functions()
.get(name)
.cloned()
.or_else(|| {
function_alias::resolve_aggregate(name).and_then(|name| {
self.session_state.aggregate_functions().get(name).cloned()
})
})
},
|func| Some(Arc::new(func)),
)
}
@@ -193,12 +215,14 @@ impl ContextProvider for DfContextProviderAdapter {
/// Collects the scalar function names known to this provider.
///
/// The result starts with the names reported by the engine state, followed by the
/// scalar functions registered in the session state, and finally the scalar alias
/// names from `function_alias`. No de-duplication is performed.
fn udf_names(&self) -> Vec<String> {
    let mut all_names = self.engine_state.scalar_names();

    let session_names = self.session_state.scalar_functions().keys().cloned();
    all_names.extend(session_names);

    let alias_names = function_alias::scalar_alias_names().map(str::to_string);
    all_names.extend(alias_names);

    all_names
}
/// Collects the aggregate function names known to this provider.
///
/// The result starts with the names reported by the engine state, followed by the
/// aggregate functions registered in the session state, and finally the aggregate
/// alias names from `function_alias`. No de-duplication is performed.
fn udaf_names(&self) -> Vec<String> {
    let mut all_names = self.engine_state.aggr_names();

    let session_names = self.session_state.aggregate_functions().keys().cloned();
    all_names.extend(session_names);

    let alias_names = function_alias::aggregate_alias_names().map(str::to_string);
    all_names.extend(alias_names);

    all_names
}
@@ -233,9 +257,14 @@ impl ContextProvider for DfContextProviderAdapter {
.table_functions()
.get(name)
.cloned()
.ok_or_else(|| {
DataFusionError::Plan(format!("table function '{name}' not found"))
})?;
.or_else(|| {
function_alias::resolve_scalar(name)
.and_then(|alias| self.session_state.table_functions().get(alias).cloned())
});
let tbl_func = tbl_func.ok_or_else(|| {
DataFusionError::Plan(format!("table function '{name}' not found"))
})?;
let provider = tbl_func.create_table_provider(&args)?;
Ok(provider_as_source(provider))

View File

@@ -0,0 +1,86 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::HashMap;
use once_cell::sync::Lazy;
/// Scalar function aliases as `(alias, target)` pairs.
///
/// Alias keys are written in lowercase: `resolve_scalar` lowercases the requested name
/// before looking it up, so a key containing uppercase characters would never match.
const SCALAR_ALIASES: &[(&str, &str)] = &[
// SQL compat aliases.
("ucase", "upper"),
("lcase", "lower"),
("ceiling", "ceil"),
("mid", "substr"),
// MySQL's RAND([seed]) accepts an optional seed argument, while DataFusion's `random()`
// does not. We alias the name for `rand()` compatibility, and `rand(seed)` will error
// due to mismatched arity.
("rand", "random"),
];
/// Aggregate function aliases as `(alias, target)` pairs.
///
/// Alias keys are written in lowercase: `resolve_aggregate` lowercases the requested name
/// before looking it up, so a key containing uppercase characters would never match.
const AGGREGATE_ALIASES: &[(&str, &str)] = &[
// MySQL compat aliases that don't override existing DataFusion aggregate names.
//
// NOTE: We intentionally do NOT alias `stddev` here, because DataFusion defines `stddev`
// as sample standard deviation while MySQL's `STDDEV` is population standard deviation.
("std", "stddev_pop"),
("variance", "var_pop"),
];
/// Lazily built lookup table from scalar alias name to its target function name,
/// populated from `SCALAR_ALIASES` on first access.
static SCALAR_FUNCTION_ALIAS: Lazy<HashMap<&'static str, &'static str>> = Lazy::new(|| {
    let mut aliases = HashMap::new();
    aliases.extend(SCALAR_ALIASES.iter().copied());
    aliases
});
/// Lazily built lookup table from aggregate alias name to its target function name,
/// populated from `AGGREGATE_ALIASES` on first access.
static AGGREGATE_FUNCTION_ALIAS: Lazy<HashMap<&'static str, &'static str>> = Lazy::new(|| {
    let mut aliases = HashMap::new();
    aliases.extend(AGGREGATE_ALIASES.iter().copied());
    aliases
});
/// Resolves a scalar function alias to the name of the function it stands for.
///
/// The lookup is ASCII case-insensitive: `name` is lowercased before it is matched
/// against the alias table. Returns `None` when `name` is not a known scalar alias.
pub fn resolve_scalar(name: &str) -> Option<&'static str> {
    let lowered = name.to_ascii_lowercase();
    let target = SCALAR_FUNCTION_ALIAS.get(lowered.as_str())?;
    Some(*target)
}
/// Resolves an aggregate function alias to the name of the function it stands for.
///
/// The lookup is ASCII case-insensitive: `name` is lowercased before it is matched
/// against the alias table. Returns `None` when `name` is not a known aggregate alias.
pub fn resolve_aggregate(name: &str) -> Option<&'static str> {
    let lowered = name.to_ascii_lowercase();
    let target = AGGREGATE_FUNCTION_ALIAS.get(lowered.as_str())?;
    Some(*target)
}
/// Iterates over every scalar alias name, in the order they are declared in
/// `SCALAR_ALIASES`. Only the alias side of each pair is yielded, not its target.
pub fn scalar_alias_names() -> impl Iterator<Item = &'static str> {
    SCALAR_ALIASES.iter().map(|&(alias, _)| alias)
}
/// Iterates over every aggregate alias name, in the order they are declared in
/// `AGGREGATE_ALIASES`. Only the alias side of each pair is yielded, not its target.
pub fn aggregate_alias_names() -> impl Iterator<Item = &'static str> {
    AGGREGATE_ALIASES.iter().map(|&(alias, _)| alias)
}
#[cfg(test)]
mod tests {
    use super::{aggregate_alias_names, resolve_aggregate, resolve_scalar, scalar_alias_names};

    #[test]
    fn resolves_scalar_aliases_case_insensitive() {
        assert_eq!(resolve_scalar("ucase"), Some("upper"));
        assert_eq!(resolve_scalar("UCASE"), Some("upper"));
        assert_eq!(resolve_scalar("lcase"), Some("lower"));
        assert_eq!(resolve_scalar("ceiling"), Some("ceil"));
        assert_eq!(resolve_scalar("MID"), Some("substr"));
        assert_eq!(resolve_scalar("RAND"), Some("random"));
        assert_eq!(resolve_scalar("not_a_real_alias"), None);
    }

    #[test]
    fn resolves_aggregate_aliases_case_insensitive() {
        assert_eq!(resolve_aggregate("std"), Some("stddev_pop"));
        assert_eq!(resolve_aggregate("variance"), Some("var_pop"));
        // Upper- and mixed-case spellings must resolve to the same targets.
        assert_eq!(resolve_aggregate("STD"), Some("stddev_pop"));
        assert_eq!(resolve_aggregate("Variance"), Some("var_pop"));
        // `stddev` is deliberately not aliased, regardless of case.
        assert_eq!(resolve_aggregate("STDDEV"), None);
        assert_eq!(resolve_aggregate("not_a_real_alias"), None);
    }

    #[test]
    fn alias_names_are_lowercase_and_resolvable() {
        // The resolvers lowercase the input before the lookup, so an alias key that is
        // not already lowercase could never be matched.
        for alias in scalar_alias_names() {
            assert_eq!(alias, alias.to_ascii_lowercase());
            assert!(resolve_scalar(alias).is_some());
        }
        for alias in aggregate_alias_names() {
            assert_eq!(alias, alias.to_ascii_lowercase());
            assert!(resolve_aggregate(alias).is_some());
        }
    }
}