// query/datafusion/planner.rs

use std::collections::HashMap;
use std::collections::hash_map::Entry;
use std::sync::Arc;

use arrow_schema::DataType;
use catalog::table_source::DfTableSourceProvider;
use common_function::function::FunctionContext;
use datafusion::common::TableReference;
use datafusion::datasource::cte_worktable::CteWorkTable;
use datafusion::datasource::file_format::{FileFormatFactory, format_as_file_type};
use datafusion::datasource::provider_as_source;
use datafusion::error::Result as DfResult;
use datafusion::execution::SessionStateDefaults;
use datafusion::execution::context::SessionState;
use datafusion::sql::planner::ContextProvider;
use datafusion::variable::VarType;
use datafusion_common::DataFusionError;
use datafusion_common::config::ConfigOptions;
use datafusion_common::file_options::file_type::FileType;
use datafusion_expr::planner::{ExprPlanner, TypePlanner};
use datafusion_expr::var_provider::is_system_variables;
use datafusion_expr::{AggregateUDF, ScalarUDF, TableSource, WindowUDF};
use datafusion_sql::parser::Statement as DfStatement;
use session::context::QueryContextRef;
use snafu::{Location, ResultExt};

use crate::datafusion::json_expr_planner::JsonExprPlanner;
use crate::error::{CatalogSnafu, Result};
use crate::query_engine::{DefaultPlanDecoder, QueryEngineState};

/// Adapts the engine's state to DataFusion's SQL planner by implementing
/// [`ContextProvider`] on top of [`QueryEngineState`] and a [`SessionState`].
pub struct DfContextProviderAdapter {
    engine_state: Arc<QueryEngineState>,
    session_state: SessionState,
    // Tables resolved eagerly in `try_new`, keyed by fully resolved table name.
    tables: HashMap<String, Arc<dyn TableSource>>,
    table_provider: DfTableSourceProvider,
    query_ctx: QueryContextRef,

    // DataFusion's default file format factories, keyed by lowercased file extension.
    file_formats: HashMap<String, Arc<dyn FileFormatFactory>>,
    // Expression planners in consultation order; the JSON planner is placed first.
    expr_planners: Vec<Arc<dyn ExprPlanner>>,
}
60
61impl DfContextProviderAdapter {
62 pub(crate) async fn try_new(
63 engine_state: Arc<QueryEngineState>,
64 session_state: SessionState,
65 df_stmt: Option<&DfStatement>,
66 query_ctx: QueryContextRef,
67 ) -> Result<Self> {
68 let table_names = if let Some(df_stmt) = df_stmt {
69 session_state.resolve_table_references(df_stmt)?
70 } else {
71 vec![]
72 };
73
74 let mut table_provider = DfTableSourceProvider::new(
75 engine_state.catalog_manager().clone(),
76 engine_state.disallow_cross_catalog_query(),
77 query_ctx.clone(),
78 Arc::new(DefaultPlanDecoder::new(session_state.clone(), &query_ctx)?),
79 session_state
80 .config_options()
81 .sql_parser
82 .enable_ident_normalization,
83 );
84
85 let tables = resolve_tables(table_names, &mut table_provider).await?;
86 let file_formats = SessionStateDefaults::default_file_formats()
87 .into_iter()
88 .map(|format| (format.get_ext().to_lowercase(), format))
89 .collect();
90
91 let mut expr_planners = SessionStateDefaults::default_expr_planners();
92 expr_planners.insert(0, Arc::new(JsonExprPlanner));
93
94 Ok(Self {
95 engine_state,
96 session_state,
97 tables,
98 table_provider,
99 query_ctx,
100 file_formats,
101 expr_planners,
102 })
103 }
104}
105
/// Resolves each referenced table through `table_provider`, returning a map
/// keyed by the fully resolved table name.
///
/// Resolution errors whose `should_fail()` is false are skipped and the table
/// is simply absent from the returned map — presumably such names can still be
/// satisfied later during planning (e.g. as CTEs); TODO confirm.
async fn resolve_tables(
    table_names: Vec<TableReference>,
    table_provider: &mut DfTableSourceProvider,
) -> Result<HashMap<String, Arc<dyn TableSource>>> {
    let mut tables = HashMap::with_capacity(table_names.len());

    for table_name in table_names {
        let resolved_name = table_provider
            .resolve_table_ref(table_name.clone())
            .context(CatalogSnafu)?;

        // A statement may mention the same table several times; only the first
        // occurrence triggers an actual lookup.
        if let Entry::Vacant(v) = tables.entry(resolved_name.to_string()) {
            match table_provider.resolve_table(table_name).await {
                Ok(table) => {
                    let _ = v.insert(table);
                }
                // Only errors flagged as fatal abort planning.
                Err(e) if e.should_fail() => {
                    return Err(e).context(CatalogSnafu);
                }
                // Non-fatal errors: leave the name unresolved.
                _ => {
                }
            }
        }
    }
    Ok(tables)
}
136
impl ContextProvider for DfContextProviderAdapter {
    /// Returns a table pre-resolved in `try_new`; nothing is fetched lazily
    /// here, so a name missing from `self.tables` is reported as not found.
    fn get_table_source(&self, name: TableReference) -> DfResult<Arc<dyn TableSource>> {
        let table_ref = self.table_provider.resolve_table_ref(name)?;
        self.tables
            .get(&table_ref.to_string())
            .cloned()
            .ok_or_else(|| {
                crate::error::Error::TableNotFound {
                    table: table_ref.to_string(),
                    location: Location::default(),
                }
                .into()
            })
    }

    /// Prefers the engine's own scalar function — instantiated against the
    /// current query context and function state — and falls back to the
    /// session's registered scalar UDFs.
    fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>> {
        self.engine_state.scalar_function(name).map_or_else(
            || self.session_state.scalar_functions().get(name).cloned(),
            |func| {
                Some(Arc::new(func.provide(FunctionContext {
                    query_ctx: self.query_ctx.clone(),
                    state: self.engine_state.function_state(),
                })))
            },
        )
    }

    /// Prefers the engine's aggregate function; falls back to the session's.
    fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
        self.engine_state.aggr_function(name).map_or_else(
            || self.session_state.aggregate_functions().get(name).cloned(),
            |func| Some(Arc::new(func)),
        )
    }

    /// Window functions come solely from the session state.
    fn get_window_meta(&self, name: &str) -> Option<Arc<WindowUDF>> {
        self.session_state.window_functions().get(name).cloned()
    }

    /// Resolves the data type of a session variable (system or user-defined)
    /// via the variable providers registered on the session, if any.
    fn get_variable_type(&self, variable_names: &[String]) -> Option<DataType> {
        if variable_names.is_empty() {
            return None;
        }

        let provider_type = if is_system_variables(variable_names) {
            VarType::System
        } else {
            VarType::UserDefined
        };

        self.session_state
            .execution_props()
            .var_providers
            .as_ref()
            .and_then(|provider| provider.get(&provider_type)?.get_type(variable_names))
    }

    fn options(&self) -> &ConfigOptions {
        self.session_state.config_options()
    }

    /// Union of engine- and session-registered scalar UDF names
    /// (not deduplicated).
    fn udf_names(&self) -> Vec<String> {
        let mut names = self.engine_state.scalar_names();
        names.extend(self.session_state.scalar_functions().keys().cloned());
        names
    }

    /// Union of engine- and session-registered aggregate UDF names
    /// (not deduplicated).
    fn udaf_names(&self) -> Vec<String> {
        let mut names = self.engine_state.aggr_names();
        names.extend(self.session_state.aggregate_functions().keys().cloned());
        names
    }

    fn udwf_names(&self) -> Vec<String> {
        self.session_state
            .window_functions()
            .keys()
            .cloned()
            .collect()
    }

    /// Maps a file extension (matched case-insensitively) to one of the
    /// default file formats registered in `try_new`.
    fn get_file_type(&self, ext: &str) -> DfResult<Arc<dyn FileType>> {
        self.file_formats
            .get(&ext.to_lowercase())
            .ok_or_else(|| {
                DataFusionError::Plan(format!("There is no registered file format with ext {ext}"))
            })
            .map(|file_type| format_as_file_type(Arc::clone(file_type)))
    }

    /// Prefers an engine-registered table function; falls back to the
    /// session's table functions.
    fn get_table_function_source(
        &self,
        name: &str,
        args: Vec<datafusion_expr::Expr>,
    ) -> DfResult<Arc<dyn TableSource>> {
        if let Some(tbl_func) = self.engine_state.table_function(name) {
            let provider = tbl_func.create_table_provider(&args)?;
            Ok(provider_as_source(provider))
        } else {
            let tbl_func = self
                .session_state
                .table_functions()
                .get(name)
                .cloned()
                .ok_or_else(|| {
                    DataFusionError::Plan(format!("table function '{name}' not found"))
                })?;
            let provider = tbl_func.create_table_provider(&args)?;

            Ok(provider_as_source(provider))
        }
    }

    /// Creates DataFusion's work table used while planning a recursive CTE.
    fn create_cte_work_table(
        &self,
        name: &str,
        schema: arrow_schema::SchemaRef,
    ) -> DfResult<Arc<dyn TableSource>> {
        let table = Arc::new(CteWorkTable::new(name, schema));
        Ok(provider_as_source(table))
    }

    /// The planners assembled in `try_new` (JSON planner first, then defaults).
    fn get_expr_planners(&self) -> &[Arc<dyn ExprPlanner>] {
        &self.expr_planners
    }

    /// No custom type planner is provided.
    fn get_type_planner(&self) -> Option<Arc<dyn TypePlanner>> {
        None
    }
}