// query/datafusion/planner.rs
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15use std::collections::HashMap;
16use std::collections::hash_map::Entry;
17use std::sync::Arc;
18
19use arrow_schema::DataType;
20use catalog::table_source::DfTableSourceProvider;
21use common_function::function::FunctionContext;
22use datafusion::common::TableReference;
23use datafusion::datasource::cte_worktable::CteWorkTable;
24use datafusion::datasource::file_format::{FileFormatFactory, format_as_file_type};
25use datafusion::datasource::provider_as_source;
26use datafusion::error::Result as DfResult;
27use datafusion::execution::SessionStateDefaults;
28use datafusion::execution::context::SessionState;
29use datafusion::sql::planner::ContextProvider;
30use datafusion::variable::VarType;
31use datafusion_common::DataFusionError;
32use datafusion_common::config::ConfigOptions;
33use datafusion_common::file_options::file_type::FileType;
34use datafusion_expr::planner::{ExprPlanner, TypePlanner};
35use datafusion_expr::var_provider::is_system_variables;
36use datafusion_expr::{AggregateUDF, ScalarUDF, TableSource, WindowUDF};
37use datafusion_sql::parser::Statement as DfStatement;
38use session::context::QueryContextRef;
39use snafu::{Location, ResultExt};
40
41use crate::datafusion::json_expr_planner::JsonExprPlanner;
42use crate::error::{CatalogSnafu, Result};
43use crate::query_engine::{DefaultPlanDecoder, QueryEngineState};
44
/// Adapter that implements DataFusion's [`ContextProvider`] on top of
/// GreptimeDB's query engine state, catalog-backed table provider, and
/// per-query session context.
pub struct DfContextProviderAdapter {
    // Shared engine-level state (UDF registries, catalog manager, etc.).
    engine_state: Arc<QueryEngineState>,
    // DataFusion session state used for default/session-registered functions
    // and configuration options.
    session_state: SessionState,
    // Table sources pre-resolved in `try_new`, keyed by the fully-qualified
    // table name (`resolved_name.to_string()`).
    tables: HashMap<String, Arc<dyn TableSource>>,
    // Catalog-backed provider used to resolve table references.
    table_provider: DfTableSourceProvider,
    // Context of the query currently being planned.
    query_ctx: QueryContextRef,

    // Fields from session state defaults:
    /// Holds registered external FileFormat implementations
    /// DataFusion doesn't pub this field, so we need to store it here.
    file_formats: HashMap<String, Arc<dyn FileFormatFactory>>,
    /// Provides support for customising the SQL planner, e.g. to add support for custom operators like `->>` or `?`
    /// DataFusion doesn't pub this field, so we need to store it here.
    expr_planners: Vec<Arc<dyn ExprPlanner>>,
}
60
61impl DfContextProviderAdapter {
62    pub(crate) async fn try_new(
63        engine_state: Arc<QueryEngineState>,
64        session_state: SessionState,
65        df_stmt: Option<&DfStatement>,
66        query_ctx: QueryContextRef,
67    ) -> Result<Self> {
68        let table_names = if let Some(df_stmt) = df_stmt {
69            session_state.resolve_table_references(df_stmt)?
70        } else {
71            vec![]
72        };
73
74        let mut table_provider = DfTableSourceProvider::new(
75            engine_state.catalog_manager().clone(),
76            engine_state.disallow_cross_catalog_query(),
77            query_ctx.clone(),
78            Arc::new(DefaultPlanDecoder::new(session_state.clone(), &query_ctx)?),
79            session_state
80                .config_options()
81                .sql_parser
82                .enable_ident_normalization,
83        );
84
85        let tables = resolve_tables(table_names, &mut table_provider).await?;
86        let file_formats = SessionStateDefaults::default_file_formats()
87            .into_iter()
88            .map(|format| (format.get_ext().to_lowercase(), format))
89            .collect();
90
91        let mut expr_planners = SessionStateDefaults::default_expr_planners();
92        expr_planners.insert(0, Arc::new(JsonExprPlanner));
93
94        Ok(Self {
95            engine_state,
96            session_state,
97            tables,
98            table_provider,
99            query_ctx,
100            file_formats,
101            expr_planners,
102        })
103    }
104}
105
106async fn resolve_tables(
107    table_names: Vec<TableReference>,
108    table_provider: &mut DfTableSourceProvider,
109) -> Result<HashMap<String, Arc<dyn TableSource>>> {
110    let mut tables = HashMap::with_capacity(table_names.len());
111
112    for table_name in table_names {
113        let resolved_name = table_provider
114            .resolve_table_ref(table_name.clone())
115            .context(CatalogSnafu)?;
116
117        if let Entry::Vacant(v) = tables.entry(resolved_name.to_string()) {
118            // Try our best to resolve the tables here, but we don't return an error if table is not found,
119            // because the table name may be a temporary name of CTE, they can't be found until plan
120            // execution.
121            match table_provider.resolve_table(table_name).await {
122                Ok(table) => {
123                    let _ = v.insert(table);
124                }
125                Err(e) if e.should_fail() => {
126                    return Err(e).context(CatalogSnafu);
127                }
128                _ => {
129                    // ignore
130                }
131            }
132        }
133    }
134    Ok(tables)
135}
136
137impl ContextProvider for DfContextProviderAdapter {
138    fn get_table_source(&self, name: TableReference) -> DfResult<Arc<dyn TableSource>> {
139        let table_ref = self.table_provider.resolve_table_ref(name)?;
140        self.tables
141            .get(&table_ref.to_string())
142            .cloned()
143            .ok_or_else(|| {
144                crate::error::Error::TableNotFound {
145                    table: table_ref.to_string(),
146                    location: Location::default(),
147                }
148                .into()
149            })
150    }
151
152    fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>> {
153        self.engine_state.scalar_function(name).map_or_else(
154            || self.session_state.scalar_functions().get(name).cloned(),
155            |func| {
156                Some(Arc::new(func.provide(FunctionContext {
157                    query_ctx: self.query_ctx.clone(),
158                    state: self.engine_state.function_state(),
159                })))
160            },
161        )
162    }
163
164    fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
165        self.engine_state.aggr_function(name).map_or_else(
166            || self.session_state.aggregate_functions().get(name).cloned(),
167            |func| Some(Arc::new(func)),
168        )
169    }
170
171    fn get_window_meta(&self, name: &str) -> Option<Arc<WindowUDF>> {
172        self.session_state.window_functions().get(name).cloned()
173    }
174
175    fn get_variable_type(&self, variable_names: &[String]) -> Option<DataType> {
176        if variable_names.is_empty() {
177            return None;
178        }
179
180        let provider_type = if is_system_variables(variable_names) {
181            VarType::System
182        } else {
183            VarType::UserDefined
184        };
185
186        self.session_state
187            .execution_props()
188            .var_providers
189            .as_ref()
190            .and_then(|provider| provider.get(&provider_type)?.get_type(variable_names))
191    }
192
193    fn options(&self) -> &ConfigOptions {
194        self.session_state.config_options()
195    }
196
197    fn udf_names(&self) -> Vec<String> {
198        let mut names = self.engine_state.scalar_names();
199        names.extend(self.session_state.scalar_functions().keys().cloned());
200        names
201    }
202
203    fn udaf_names(&self) -> Vec<String> {
204        let mut names = self.engine_state.aggr_names();
205        names.extend(self.session_state.aggregate_functions().keys().cloned());
206        names
207    }
208
209    fn udwf_names(&self) -> Vec<String> {
210        self.session_state
211            .window_functions()
212            .keys()
213            .cloned()
214            .collect()
215    }
216
217    fn get_file_type(&self, ext: &str) -> DfResult<Arc<dyn FileType>> {
218        self.file_formats
219            .get(&ext.to_lowercase())
220            .ok_or_else(|| {
221                DataFusionError::Plan(format!("There is no registered file format with ext {ext}"))
222            })
223            .map(|file_type| format_as_file_type(Arc::clone(file_type)))
224    }
225
226    fn get_table_function_source(
227        &self,
228        name: &str,
229        args: Vec<datafusion_expr::Expr>,
230    ) -> DfResult<Arc<dyn TableSource>> {
231        if let Some(tbl_func) = self.engine_state.table_function(name) {
232            let provider = tbl_func.create_table_provider(&args)?;
233            Ok(provider_as_source(provider))
234        } else {
235            let tbl_func = self
236                .session_state
237                .table_functions()
238                .get(name)
239                .cloned()
240                .ok_or_else(|| {
241                    DataFusionError::Plan(format!("table function '{name}' not found"))
242                })?;
243            let provider = tbl_func.create_table_provider(&args)?;
244
245            Ok(provider_as_source(provider))
246        }
247    }
248
249    fn create_cte_work_table(
250        &self,
251        name: &str,
252        schema: arrow_schema::SchemaRef,
253    ) -> DfResult<Arc<dyn TableSource>> {
254        let table = Arc::new(CteWorkTable::new(name, schema));
255        Ok(provider_as_source(table))
256    }
257
258    fn get_expr_planners(&self) -> &[Arc<dyn ExprPlanner>] {
259        &self.expr_planners
260    }
261
262    fn get_type_planner(&self) -> Option<Arc<dyn TypePlanner>> {
263        None
264    }
265}