Skip to main content

catalog/
table_source.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::collections::HashMap;
16use std::sync::Arc;
17
18use bytes::Bytes;
19use common_catalog::format_full_table_name;
20use common_query::logical_plan::{SubstraitPlanDecoderRef, rename_logical_plan_columns};
21use datafusion::common::{ResolvedTableReference, TableReference};
22use datafusion::datasource::view::ViewTable;
23use datafusion::datasource::{TableProvider, provider_as_source};
24use datafusion::logical_expr::TableSource;
25use itertools::Itertools;
26use session::context::QueryContextRef;
27use snafu::{OptionExt, ResultExt, ensure};
28use table::metadata::TableType;
29use table::table::adapter::DfTableProviderAdapter;
30pub mod dummy_catalog;
31use dummy_catalog::DummyCatalogList;
32use table::TableRef;
33
34use crate::CatalogManagerRef;
35use crate::error::{
36    CastManagerSnafu, DecodePlanSnafu, GetViewCacheSnafu, ProjectViewColumnsSnafu,
37    QueryAccessDeniedSnafu, Result, TableNotExistSnafu, ViewInfoNotFoundSnafu,
38    ViewPlanColumnsChangedSnafu,
39};
40use crate::kvbackend::KvBackendCatalogManager;
41
42pub struct DfTableSourceProvider {
43    catalog_manager: CatalogManagerRef,
44    resolved_tables: HashMap<String, Arc<dyn TableSource>>,
45    disallow_cross_catalog_query: bool,
46    default_catalog: String,
47    default_schema: String,
48    query_ctx: QueryContextRef,
49    plan_decoder: SubstraitPlanDecoderRef,
50    enable_ident_normalization: bool,
51}
52
53impl DfTableSourceProvider {
54    pub fn new(
55        catalog_manager: CatalogManagerRef,
56        disallow_cross_catalog_query: bool,
57        query_ctx: QueryContextRef,
58        plan_decoder: SubstraitPlanDecoderRef,
59        enable_ident_normalization: bool,
60    ) -> Self {
61        Self {
62            catalog_manager,
63            disallow_cross_catalog_query,
64            resolved_tables: HashMap::new(),
65            default_catalog: query_ctx.current_catalog().to_owned(),
66            default_schema: query_ctx.current_schema(),
67            query_ctx,
68            plan_decoder,
69            enable_ident_normalization,
70        }
71    }
72
73    /// Returns the query context.
74    pub fn query_ctx(&self) -> &QueryContextRef {
75        &self.query_ctx
76    }
77
78    pub fn resolve_table_ref(&self, table_ref: TableReference) -> Result<ResolvedTableReference> {
79        if self.disallow_cross_catalog_query {
80            match &table_ref {
81                TableReference::Bare { .. } | TableReference::Partial { .. } => {}
82                TableReference::Full {
83                    catalog, schema, ..
84                } => {
85                    ensure!(
86                        catalog.as_ref() == self.default_catalog,
87                        QueryAccessDeniedSnafu {
88                            catalog: catalog.as_ref(),
89                            schema: schema.as_ref(),
90                        }
91                    );
92                }
93            };
94        }
95
96        Ok(table_ref.resolve(&self.default_catalog, &self.default_schema))
97    }
98
99    pub async fn resolve_table(
100        &mut self,
101        table_ref: TableReference,
102    ) -> Result<Arc<dyn TableSource>> {
103        let table_ref = self.resolve_table_ref(table_ref)?;
104
105        let resolved_name = table_ref.to_string();
106        if let Some(table) = self.resolved_tables.get(&resolved_name) {
107            return Ok(table.clone());
108        }
109
110        let catalog_name = table_ref.catalog.as_ref();
111        let schema_name = table_ref.schema.as_ref();
112        let table_name = table_ref.table.as_ref();
113
114        let table = self
115            .catalog_manager
116            .table(catalog_name, schema_name, table_name, Some(&self.query_ctx))
117            .await?
118            .with_context(|| TableNotExistSnafu {
119                table: format_full_table_name(catalog_name, schema_name, table_name),
120            })?;
121
122        let provider: Arc<dyn TableProvider> = if table.table_info().table_type == TableType::View {
123            self.create_view_provider(&table).await?
124        } else {
125            Arc::new(DfTableProviderAdapter::new(table))
126        };
127
128        let source = provider_as_source(provider);
129
130        let _ = self.resolved_tables.insert(resolved_name, source.clone());
131        Ok(source)
132    }
133
134    async fn create_view_provider(&self, table: &TableRef) -> Result<Arc<dyn TableProvider>> {
135        let catalog_manager = self
136            .catalog_manager
137            .as_any()
138            .downcast_ref::<KvBackendCatalogManager>()
139            .context(CastManagerSnafu)?;
140
141        let view_info = catalog_manager
142            .view_info_cache()?
143            .get(table.table_info().ident.table_id)
144            .await
145            .context(GetViewCacheSnafu)?
146            .context(ViewInfoNotFoundSnafu {
147                name: &table.table_info().name,
148            })?;
149
150        // Build the catalog list provider for deserialization.
151        let catalog_list = Arc::new(DummyCatalogList::new(self.catalog_manager.clone()));
152        let logical_plan = self
153            .plan_decoder
154            .decode(
155                Bytes::from(view_info.view_info.clone()),
156                catalog_list,
157                false,
158            )
159            .await
160            .context(DecodePlanSnafu {
161                name: &table.table_info().name,
162            })?;
163
164        let columns: Vec<_> = view_info.columns.iter().map(|c| c.as_str()).collect();
165
166        let original_plan_columns: Vec<_> =
167            view_info.plan_columns.iter().map(|c| c.as_str()).collect();
168
169        let plan_columns: Vec<_> = logical_plan
170            .schema()
171            .columns()
172            .into_iter()
173            .map(|c| c.name)
174            .collect();
175
176        // Only check columns number, because substrait doesn't include aliases currently.
177        // See https://github.com/apache/datafusion/issues/10815#issuecomment-2158666881
178        // and https://github.com/apache/datafusion/issues/6489
179        // TODO(dennis): check column names
180        ensure!(
181            original_plan_columns.len() == plan_columns.len(),
182            ViewPlanColumnsChangedSnafu {
183                origin_names: original_plan_columns.iter().join(","),
184                actual_names: plan_columns.iter().join(","),
185            }
186        );
187
188        // We have to do `columns` projection here, because
189        // substrait doesn't include aliases neither for tables nor for columns:
190        // https://github.com/apache/datafusion/issues/10815#issuecomment-2158666881
191        let logical_plan = if !columns.is_empty() {
192            rename_logical_plan_columns(
193                self.enable_ident_normalization,
194                logical_plan,
195                plan_columns
196                    .iter()
197                    .map(|c| c.as_str())
198                    .zip(columns)
199                    .collect(),
200            )
201            .context(ProjectViewColumnsSnafu)?
202        } else {
203            logical_plan
204        };
205
206        Ok(Arc::new(ViewTable::new(
207            logical_plan,
208            Some(view_info.definition.clone()),
209        )))
210    }
211}
212
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
    use cache::{build_fundamental_cache_registry, with_default_composite_cache_registry};
    use common_meta::cache::{CacheRegistryBuilder, LayeredCacheRegistryBuilder};
    use common_meta::key::TableMetadataManager;
    use common_meta::kv_backend::memory::MemoryKvBackend;
    use common_query::error::Result as QueryResult;
    use common_query::logical_plan::SubstraitPlanDecoder;
    use common_query::test_util::DummyDecoder;
    use datafusion::catalog::CatalogProviderList;
    use datafusion::logical_expr::builder::LogicalTableSource;
    use datafusion::logical_expr::{LogicalPlan, LogicalPlanBuilder, col, lit};
    use session::context::QueryContext;

    use super::*;
    use crate::information_schema::NoopInformationExtension;
    use crate::kvbackend::KvBackendCatalogManagerBuilder;
    use crate::memory::MemoryCatalogManager;

    /// With cross-catalog queries disallowed, only fully-qualified references
    /// that name a catalog other than the current one are rejected.
    #[test]
    fn test_validate_table_ref() {
        let query_ctx = Arc::new(QueryContext::with("greptime", "public"));
        let table_provider = DfTableSourceProvider::new(
            MemoryCatalogManager::with_default_setup(),
            true,
            query_ctx,
            DummyDecoder::arc(),
            true,
        );

        // (reference, whether resolution is expected to succeed)
        let cases = [
            (TableReference::bare("table_name"), true),
            (TableReference::partial("public", "table_name"), true),
            (TableReference::partial("wrong_schema", "table_name"), true),
            (
                TableReference::full("greptime", "public", "table_name"),
                true,
            ),
            (
                TableReference::full("wrong_catalog", "public", "table_name"),
                false,
            ),
            (
                TableReference::partial("information_schema", "columns"),
                true,
            ),
            (
                TableReference::full("greptime", "information_schema", "columns"),
                true,
            ),
            (
                TableReference::full("dummy", "information_schema", "columns"),
                false,
            ),
            (
                TableReference::full("greptime", "greptime_private", "columns"),
                true,
            ),
        ];

        for (table_ref, expect_ok) in cases {
            let shown = table_ref.to_string();
            let outcome = table_provider.resolve_table_ref(table_ref);
            assert_eq!(
                expect_ok,
                outcome.is_ok(),
                "unexpected resolution outcome for `{shown}`"
            );
        }
    }

    /// Plan decoder that ignores its input and always yields [`mock_plan`].
    struct MockDecoder;

    impl MockDecoder {
        pub fn arc() -> Arc<Self> {
            Arc::new(Self)
        }
    }

    #[async_trait::async_trait]
    impl SubstraitPlanDecoder for MockDecoder {
        async fn decode(
            &self,
            _message: bytes::Bytes,
            _catalog_list: Arc<dyn CatalogProviderList>,
            _optimize: bool,
        ) -> QueryResult<LogicalPlan> {
            Ok(mock_plan())
        }
    }

    /// Builds `SELECT * FROM person WHERE id > 500` over a two-column table.
    fn mock_plan() -> LogicalPlan {
        let fields = vec![
            Field::new("id", DataType::Int32, true),
            Field::new("name", DataType::Utf8, true),
        ];
        let table_source = LogicalTableSource::new(SchemaRef::new(Schema::new(fields)));

        LogicalPlanBuilder::scan("person", Arc::new(table_source), None)
            .unwrap()
            .filter(col("id").gt(lit(500)))
            .unwrap()
            .build()
            .unwrap()
    }

    /// Resolving a view decodes its stored plan and projects the plan's
    /// columns onto the names declared for the view.
    #[tokio::test]
    async fn test_resolve_view() {
        let kv_backend = Arc::new(MemoryKvBackend::default());

        // Layer an empty registry with the fundamental caches, then add the
        // default composite caches on top.
        let registry_builder = LayeredCacheRegistryBuilder::default()
            .add_cache_registry(CacheRegistryBuilder::default().build());
        let fundamental_registry = build_fundamental_cache_registry(kv_backend.clone());
        let cache_registry = Arc::new(
            with_default_composite_cache_registry(
                registry_builder.add_cache_registry(fundamental_registry),
            )
            .unwrap()
            .build(),
        );

        let catalog_manager = KvBackendCatalogManagerBuilder::new(
            Arc::new(NoopInformationExtension),
            kv_backend.clone(),
            cache_registry,
        )
        .build();

        // Register a view whose declared columns are `a`, `b` and whose
        // recorded plan columns are `id`, `name`.
        let metadata_manager = TableMetadataManager::new(kv_backend);
        let mut view_info = common_meta::key::test_utils::new_test_table_info(1024);
        view_info.table_type = TableType::View;
        let encoded_plan = vec![1, 2, 3];
        metadata_manager
            .create_view_metadata(
                view_info.clone(),
                encoded_plan,
                HashSet::new(),
                vec!["a".to_string(), "b".to_string()],
                vec!["id".to_string(), "name".to_string()],
                "definition".to_string(),
            )
            .await
            .unwrap();

        let query_ctx = Arc::new(QueryContext::with("greptime", "public"));
        let mut table_provider = DfTableSourceProvider::new(
            catalog_manager,
            true,
            query_ctx,
            MockDecoder::arc(),
            true,
        );

        // View not found
        let missing_view = TableReference::bare("not_exists_view");
        assert!(table_provider.resolve_table(missing_view).await.is_err());

        let existing_view = TableReference::bare(view_info.name);
        let source = table_provider.resolve_table(existing_view).await.unwrap();

        let expected_plan = r#"
Projection: person.id AS a, person.name AS b
  Filter: person.id > Int32(500)
    TableScan: person"#;
        assert_eq!(
            expected_plan,
            format!("\n{}", source.get_logical_plan().unwrap())
        );
    }
}