frontend/instance.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

pub mod builder;
mod grpc;
mod influxdb;
mod jaeger;
mod log_handler;
mod logs;
mod opentsdb;
mod otlp;
pub mod prom_store;
mod promql;
mod region_query;
pub mod standalone;

use std::pin::Pin;
use std::sync::atomic::AtomicBool;
use std::sync::{Arc, atomic};
use std::time::{Duration, SystemTime};

use async_stream::stream;
use async_trait::async_trait;
use auth::{PermissionChecker, PermissionCheckerRef, PermissionReq};
use catalog::CatalogManagerRef;
use catalog::process_manager::{
    ProcessManagerRef, QueryStatement as CatalogQueryStatement, SlowQueryTimer,
};
use client::OutputData;
use common_base::Plugins;
use common_base::cancellation::CancellableFuture;
use common_error::ext::{BoxedError, ErrorExt};
use common_event_recorder::EventRecorderRef;
use common_meta::cache_invalidator::CacheInvalidatorRef;
use common_meta::key::TableMetadataManagerRef;
use common_meta::key::table_name::TableNameKey;
use common_meta::node_manager::NodeManagerRef;
use common_meta::procedure_executor::ProcedureExecutorRef;
use common_query::Output;
use common_recordbatch::RecordBatchStreamWrapper;
use common_recordbatch::error::StreamTimeoutSnafu;
use common_telemetry::logging::SlowQueryOptions;
use common_telemetry::{debug, error, tracing};
use dashmap::DashMap;
use datafusion_expr::LogicalPlan;
use futures::{Stream, StreamExt};
use lazy_static::lazy_static;
use operator::delete::DeleterRef;
use operator::insert::InserterRef;
use operator::statement::{StatementExecutor, StatementExecutorRef};
use partition::manager::PartitionRuleManagerRef;
use pipeline::pipeline_operator::PipelineOperator;
use prometheus::HistogramTimer;
use promql_parser::label::Matcher;
use query::QueryEngineRef;
use query::metrics::OnDone;
use query::parser::{PromQuery, QueryLanguageParser, QueryStatement};
use query::query_engine::DescribeResult;
use query::query_engine::options::{QueryOptions, validate_catalog_and_schema};
use servers::error::{
    self as server_error, AuthSnafu, CommonMetaSnafu, ExecuteQuerySnafu,
    OtlpMetricModeIncompatibleSnafu, ParsePromQLSnafu, UnexpectedResultSnafu,
};
use servers::interceptor::{
    PromQueryInterceptor, PromQueryInterceptorRef, SqlQueryInterceptor, SqlQueryInterceptorRef,
};
use servers::otlp::metrics::legacy_normalize_otlp_name;
use servers::prometheus_handler::PrometheusHandler;
use servers::query_handler::sql::SqlQueryHandler;
use session::context::{Channel, QueryContextRef};
use session::table_name::table_idents_to_full_name;
use snafu::prelude::*;
use sql::ast::ObjectNamePartExt;
use sql::dialect::Dialect;
use sql::parser::{ParseOptions, ParserContext};
use sql::statements::comment::CommentObject;
use sql::statements::copy::{CopyDatabase, CopyTable};
use sql::statements::statement::Statement;
use sql::statements::tql::Tql;
use sqlparser::ast::ObjectName;
pub use standalone::StandaloneDatanodeManager;
use table::requests::{OTLP_METRIC_COMPAT_KEY, OTLP_METRIC_COMPAT_PROM};

use crate::error::{
    self, Error, ExecLogicalPlanSnafu, ExecutePromqlSnafu, ExternalSnafu, InvalidSqlSnafu,
    ParseSqlSnafu, PermissionSnafu, PlanStatementSnafu, Result, SqlExecInterceptedSnafu,
    StatementTimeoutSnafu, TableOperationSnafu,
};
use crate::stream_wrapper::CancellableStreamWrapper;

lazy_static! {
    static ref OTLP_LEGACY_DEFAULT_VALUE: String = "legacy".to_string();
}

/// The frontend instance holds the components a frontend node needs and
/// implements many protocol traits, such as
/// [`servers::query_handler::grpc::GrpcQueryHandler`] and
/// [`servers::query_handler::sql::SqlQueryHandler`].
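///
/// A minimal usage sketch (illustrative only: construction goes through
/// [`builder`], which is elided here, and `do_query` comes from the
/// [`SqlQueryHandler`] trait):
///
/// ```ignore
/// let instance: Arc<Instance> = /* built via the frontend builder */;
/// let ctx = QueryContext::arc();
/// for result in instance.do_query("SELECT 1", ctx.clone()).await {
///     let output = result?;
///     // output.data is affected rows, record batches, or a stream
/// }
/// ```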
#[derive(Clone)]
pub struct Instance {
    catalog_manager: CatalogManagerRef,
    pipeline_operator: Arc<PipelineOperator>,
    statement_executor: Arc<StatementExecutor>,
    query_engine: QueryEngineRef,
    plugins: Plugins,
    inserter: InserterRef,
    deleter: DeleterRef,
    table_metadata_manager: TableMetadataManagerRef,
    event_recorder: Option<EventRecorderRef>,
    process_manager: ProcessManagerRef,
    slow_query_options: SlowQueryOptions,
    suspend: Arc<AtomicBool>,

    // Cache for OTLP metrics: the outer map is keyed by the db string, the
    // inner map by the raw input metric name, and the value records whether
    // that metric runs in legacy mode.
    otlp_metrics_table_legacy_cache: DashMap<String, DashMap<String, bool>>,
}

impl Instance {
    pub fn catalog_manager(&self) -> &CatalogManagerRef {
        &self.catalog_manager
    }

    pub fn query_engine(&self) -> &QueryEngineRef {
        &self.query_engine
    }

    pub fn plugins(&self) -> &Plugins {
        &self.plugins
    }

    pub fn statement_executor(&self) -> &StatementExecutorRef {
        &self.statement_executor
    }

    pub fn table_metadata_manager(&self) -> &TableMetadataManagerRef {
        &self.table_metadata_manager
    }

    pub fn inserter(&self) -> &InserterRef {
        &self.inserter
    }

    pub fn process_manager(&self) -> &ProcessManagerRef {
        &self.process_manager
    }

    pub fn node_manager(&self) -> &NodeManagerRef {
        self.inserter.node_manager()
    }

    pub fn partition_manager(&self) -> &PartitionRuleManagerRef {
        self.inserter.partition_manager()
    }

    pub fn cache_invalidator(&self) -> &CacheInvalidatorRef {
        self.statement_executor.cache_invalidator()
    }

    pub fn procedure_executor(&self) -> &ProcedureExecutorRef {
        self.statement_executor.procedure_executor()
    }

    pub fn suspend_state(&self) -> Arc<AtomicBool> {
        self.suspend.clone()
    }

    pub(crate) fn is_suspended(&self) -> bool {
        self.suspend.load(atomic::Ordering::Relaxed)
    }
}

fn parse_stmt(sql: &str, dialect: &(dyn Dialect + Send + Sync)) -> Result<Vec<Statement>> {
    ParserContext::create_with_dialect(sql, dialect, ParseOptions::default()).context(ParseSqlSnafu)
}

impl Instance {
    async fn query_statement(&self, stmt: Statement, query_ctx: QueryContextRef) -> Result<Output> {
        check_permission(self.plugins.clone(), &stmt, &query_ctx)?;

        let query_interceptor = self.plugins.get::<SqlQueryInterceptorRef<Error>>();
        let query_interceptor = query_interceptor.as_ref();

        if should_capture_statement(Some(&stmt)) {
            let slow_query_timer = self
                .slow_query_options
                .enable
                .then(|| self.event_recorder.clone())
                .flatten()
                .map(|event_recorder| {
                    SlowQueryTimer::new(
                        CatalogQueryStatement::Sql(stmt.clone()),
                        self.slow_query_options.threshold,
                        self.slow_query_options.sample_ratio,
                        self.slow_query_options.record_type,
                        event_recorder,
                    )
                });

            let ticket = self.process_manager.register_query(
                query_ctx.current_catalog().to_string(),
                vec![query_ctx.current_schema()],
                stmt.to_string(),
                query_ctx.conn_info().to_string(),
                Some(query_ctx.process_id()),
                slow_query_timer,
            );

            let query_fut = self.exec_statement_with_timeout(stmt, query_ctx, query_interceptor);

            CancellableFuture::new(query_fut, ticket.cancellation_handle.clone())
                .await
                .map_err(|_| error::CancelledSnafu.build())?
                .map(|output| {
                    let Output { meta, data } = output;

                    let data = match data {
                        OutputData::Stream(stream) => OutputData::Stream(Box::pin(
                            CancellableStreamWrapper::new(stream, ticket),
                        )),
                        other => other,
                    };
                    Output { data, meta }
                })
        } else {
            self.exec_statement_with_timeout(stmt, query_ctx, query_interceptor)
                .await
        }
    }

    async fn exec_statement_with_timeout(
        &self,
        stmt: Statement,
        query_ctx: QueryContextRef,
        query_interceptor: Option<&SqlQueryInterceptorRef<Error>>,
    ) -> Result<Output> {
        let timeout = derive_timeout(&stmt, &query_ctx);
        match timeout {
            Some(timeout) => {
                let start = tokio::time::Instant::now();
                let output = tokio::time::timeout(
                    timeout,
                    self.exec_statement(stmt, query_ctx, query_interceptor),
                )
                .await
                .map_err(|_| StatementTimeoutSnafu.build())??;
                // compute remaining timeout
                let remaining_timeout = timeout.checked_sub(start.elapsed()).unwrap_or_default();
                attach_timeout(output, remaining_timeout)
            }
            None => {
                self.exec_statement(stmt, query_ctx, query_interceptor)
                    .await
            }
        }
    }

    async fn exec_statement(
        &self,
        stmt: Statement,
        query_ctx: QueryContextRef,
        query_interceptor: Option<&SqlQueryInterceptorRef<Error>>,
    ) -> Result<Output> {
        match stmt {
            Statement::Query(_) | Statement::Explain(_) | Statement::Delete(_) => {
                // TODO: remove this when format is supported in datafusion
                if let Statement::Explain(explain) = &stmt
                    && let Some(format) = explain.format()
                {
                    query_ctx.set_explain_format(format.to_string());
                }

                self.plan_and_exec_sql(stmt, &query_ctx, query_interceptor)
                    .await
            }
            Statement::Tql(tql) => {
                self.plan_and_exec_tql(&query_ctx, query_interceptor, tql)
                    .await
            }
            _ => {
                query_interceptor.pre_execute(&stmt, None, query_ctx.clone())?;
                self.statement_executor
                    .execute_sql(stmt, query_ctx)
                    .await
                    .context(TableOperationSnafu)
            }
        }
    }

    async fn plan_and_exec_sql(
        &self,
        stmt: Statement,
        query_ctx: &QueryContextRef,
        query_interceptor: Option<&SqlQueryInterceptorRef<Error>>,
    ) -> Result<Output> {
        let stmt = QueryStatement::Sql(stmt);
        let plan = self
            .statement_executor
            .plan(&stmt, query_ctx.clone())
            .await?;
        let QueryStatement::Sql(stmt) = stmt else {
            unreachable!()
        };
        query_interceptor.pre_execute(&stmt, Some(&plan), query_ctx.clone())?;
        self.statement_executor
            .exec_plan(plan, query_ctx.clone())
            .await
            .context(TableOperationSnafu)
    }

    async fn plan_and_exec_tql(
        &self,
        query_ctx: &QueryContextRef,
        query_interceptor: Option<&SqlQueryInterceptorRef<Error>>,
        tql: Tql,
    ) -> Result<Output> {
        let plan = self
            .statement_executor
            .plan_tql(tql.clone(), query_ctx)
            .await?;
        query_interceptor.pre_execute(&Statement::Tql(tql), Some(&plan), query_ctx.clone())?;
        self.statement_executor
            .exec_plan(plan, query_ctx.clone())
            .await
            .context(TableOperationSnafu)
    }

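    /// Decides whether the OTLP metrics in `names` should be written in
    /// legacy mode, consulting the per-database cache first and falling back
    /// to table metadata on a miss.
    ///
    /// A hedged sketch of the decision, assuming two metrics and an empty
    /// cache (the metric names are hypothetical):
    ///
    /// ```ignore
    /// // "http.server.duration": its legacy-normalized name resolves to a
    /// //   table whose OTLP_METRIC_COMPAT_KEY option is "legacy" => true
    /// // "brand.new.metric": no table exists yet                 => false
    /// // Mixing "legacy" and "prom" tables in one request fails with
    /// // OtlpMetricModeIncompatibleSnafu.
    /// ```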
    async fn check_otlp_legacy(
        &self,
        names: &[&String],
        ctx: QueryContextRef,
    ) -> server_error::Result<bool> {
        let db_string = ctx.get_db_string();
        // fast cache check
        let cache = self
            .otlp_metrics_table_legacy_cache
            .entry(db_string.clone())
            .or_default();
        if let Some(flag) = fast_legacy_check(&cache, names)? {
            return Ok(flag);
        }
        // release cache reference to avoid lock contention
        drop(cache);

        let catalog = ctx.current_catalog();
        let schema = ctx.current_schema();

        // query legacy table names
        let normalized_names = names
            .iter()
            .map(|n| legacy_normalize_otlp_name(n))
            .collect::<Vec<_>>();
        let table_names = normalized_names
            .iter()
            .map(|n| TableNameKey::new(catalog, &schema, n))
            .collect::<Vec<_>>();
        let table_values = self
            .table_metadata_manager()
            .table_name_manager()
            .batch_get(table_names)
            .await
            .context(CommonMetaSnafu)?;
        let table_ids = table_values
            .into_iter()
            .filter_map(|v| v.map(|vi| vi.table_id()))
            .collect::<Vec<_>>();

        // No existing table was found; use the new mode.
        if table_ids.is_empty() {
            let cache = self
                .otlp_metrics_table_legacy_cache
                .entry(db_string)
                .or_default();
            names.iter().for_each(|name| {
                cache.insert((*name).clone(), false);
            });
            return Ok(false);
        }

        // Tables exist; check their options.
        let table_infos = self
            .table_metadata_manager()
            .table_info_manager()
            .batch_get(&table_ids)
            .await
            .context(CommonMetaSnafu)?;
        let options = table_infos
            .values()
            .map(|info| {
                info.table_info
                    .meta
                    .options
                    .extra_options
                    .get(OTLP_METRIC_COMPAT_KEY)
                    .unwrap_or(&OTLP_LEGACY_DEFAULT_VALUE)
            })
            .collect::<Vec<_>>();
        let cache = self
            .otlp_metrics_table_legacy_cache
            .entry(db_string)
            .or_default();
        if !options.is_empty() {
            // Ensure the option values are consistent across tables.
            let has_prom = options.iter().any(|opt| *opt == OTLP_METRIC_COMPAT_PROM);
            let has_legacy = options
                .iter()
                .any(|opt| *opt == OTLP_LEGACY_DEFAULT_VALUE.as_str());
            ensure!(!(has_prom && has_legacy), OtlpMetricModeIncompatibleSnafu);
            let flag = has_legacy;
            names.iter().for_each(|name| {
                cache.insert((*name).clone(), flag);
            });
            Ok(flag)
        } else {
            // No table info; use the new mode.
            names.iter().for_each(|name| {
                cache.insert((*name).clone(), false);
            });
            Ok(false)
        }
    }
}

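/// Checks the per-database cache for `names`: returns `Some(flag)` when at
/// least one name is cached (back-filling the missing names with the same
/// flag), `None` on a complete miss, and an error when the cached entries
/// disagree.
///
/// A hedged sketch of the contract (see the deadlock-prevention unit test
/// below for the full behavior):
///
/// ```ignore
/// // cache = {"a": true}
/// // fast_legacy_check(&cache, &[&a, &b]) == Ok(Some(true)); "b" is cached as true
/// // fast_legacy_check(&cache, &[&c])     == Ok(None)
/// // cache = {"a": true, "x": false}
/// // fast_legacy_check(&cache, &[&a, &x]) => OtlpMetricModeIncompatibleSnafu
/// ```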
fn fast_legacy_check(
    cache: &DashMap<String, bool>,
    names: &[&String],
) -> server_error::Result<Option<bool>> {
    let hit_cache = names
        .iter()
        .filter_map(|name| cache.get(*name))
        .collect::<Vec<_>>();
    if !hit_cache.is_empty() {
        let hit_legacy = hit_cache.iter().any(|en| *en.value());
        let hit_prom = hit_cache.iter().any(|en| !*en.value());

        // Hits with both true and false mean the legacy and new modes are
        // mixed in one request. We cannot handle that case, so return an
        // error.
        // TODO: add doc links to the error message.
        ensure!(!(hit_legacy && hit_prom), OtlpMetricModeIncompatibleSnafu);

        let flag = hit_legacy;
        // Drop hit_cache to release its references before inserting, to avoid
        // a deadlock.
        drop(hit_cache);

        // set cache for all names
        names.iter().for_each(|name| {
            if !cache.contains_key(*name) {
                cache.insert((*name).clone(), flag);
            }
        });
        Ok(Some(flag))
    } else {
        Ok(None)
    }
}

/// If a query timeout is set in the session, it is enforced for all PostgreSQL
/// statements. For MySQL, it applies only to read-only statements.
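///
/// A hedged sketch of the resulting behavior, assuming the session has
/// `query_timeout = Some(10s)`:
///
/// ```ignore
/// // Channel::Postgres, any statement          => Some(10s)
/// // Channel::Mysql,    SELECT (read-only)     => Some(10s)
/// // Channel::Mysql,    INSERT (not read-only) => None
/// // any other channel                         => None
/// // a zero or unset query_timeout             => None
/// ```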
fn derive_timeout(stmt: &Statement, query_ctx: &QueryContextRef) -> Option<Duration> {
    let query_timeout = query_ctx.query_timeout()?;
    if query_timeout.is_zero() {
        return None;
    }
    match query_ctx.channel() {
        Channel::Mysql if stmt.is_readonly() => Some(query_timeout),
        Channel::Postgres => Some(query_timeout),
        _ => None,
    }
}

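/// Enforces the remaining `timeout` on a streaming output: each poll of the
/// stream must complete within the time left, and the budget shrinks as
/// batches arrive. Non-streaming outputs pass through unchanged.
///
/// A hedged worked example of the budget arithmetic:
///
/// ```ignore
/// // timeout = 700ms
/// // batch 1 arrives at t = 300ms -> remaining budget becomes 400ms
/// // batch 2 still pending at t = 700ms -> the stream fails with
/// //   StreamTimeoutSnafu
/// ```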
fn attach_timeout(output: Output, mut timeout: Duration) -> Result<Output> {
    if timeout.is_zero() {
        return StatementTimeoutSnafu.fail();
    }

    let output = match output.data {
        OutputData::AffectedRows(_) | OutputData::RecordBatches(_) => output,
        OutputData::Stream(mut stream) => {
            let schema = stream.schema();
            let s = Box::pin(stream! {
                let mut start = tokio::time::Instant::now();
                while let Some(item) = tokio::time::timeout(timeout, stream.next()).await.map_err(|_| StreamTimeoutSnafu.build())? {
                    yield item;

                    let now = tokio::time::Instant::now();
                    timeout = timeout.checked_sub(now - start).unwrap_or(Duration::ZERO);
                    start = now;
                    // tokio::time::timeout may not return an error immediately when timeout is 0.
                    if timeout.is_zero() {
                        StreamTimeoutSnafu.fail()?;
                    }
                }
            }) as Pin<Box<dyn Stream<Item = _> + Send>>;
            let stream = RecordBatchStreamWrapper {
                schema,
                stream: s,
                output_ordering: None,
                metrics: Default::default(),
            };
            Output::new(OutputData::Stream(Box::pin(stream)), output.meta)
        }
    };

    Ok(output)
}

#[async_trait]
impl SqlQueryHandler for Instance {
    type Error = Error;

    #[tracing::instrument(skip_all)]
    async fn do_query(&self, query: &str, query_ctx: QueryContextRef) -> Vec<Result<Output>> {
        if self.is_suspended() {
            return vec![error::SuspendedSnafu {}.fail()];
        }

        let query_interceptor_opt = self.plugins.get::<SqlQueryInterceptorRef<Error>>();
        let query_interceptor = query_interceptor_opt.as_ref();
        let query = match query_interceptor.pre_parsing(query, query_ctx.clone()) {
            Ok(q) => q,
            Err(e) => return vec![Err(e)],
        };

        let checker_ref = self.plugins.get::<PermissionCheckerRef>();
        let checker = checker_ref.as_ref();

        match parse_stmt(query.as_ref(), query_ctx.sql_dialect())
            .and_then(|stmts| query_interceptor.post_parsing(stmts, query_ctx.clone()))
        {
            Ok(stmts) => {
                if stmts.is_empty() {
                    return vec![
                        InvalidSqlSnafu {
                            err_msg: "empty statements",
                        }
                        .fail(),
                    ];
                }

                let mut results = Vec::with_capacity(stmts.len());
                for stmt in stmts {
                    if let Err(e) = checker
                        .check_permission(
                            query_ctx.current_user(),
                            PermissionReq::SqlStatement(&stmt),
                        )
                        .context(PermissionSnafu)
                    {
                        results.push(Err(e));
                        break;
                    }

                    match self.query_statement(stmt.clone(), query_ctx.clone()).await {
                        Ok(output) => {
                            let output_result =
                                query_interceptor.post_execute(output, query_ctx.clone());
                            results.push(output_result);
                        }
                        Err(e) => {
                            if e.status_code().should_log_error() {
                                error!(e; "Failed to execute query: {stmt}");
                            } else {
                                debug!("Failed to execute query: {stmt}, {e}");
                            }
                            results.push(Err(e));
                            break;
                        }
                    }
                }
                results
            }
            Err(e) => {
                vec![Err(e)]
            }
        }
    }

    async fn do_exec_plan(
        &self,
        stmt: Option<Statement>,
        plan: LogicalPlan,
        query_ctx: QueryContextRef,
    ) -> Result<Output> {
        ensure!(!self.is_suspended(), error::SuspendedSnafu);

        if should_capture_statement(stmt.as_ref()) {
            // It's safe to unwrap here because we've already checked the type.
            let stmt = stmt.unwrap();
            let query = stmt.to_string();
            let slow_query_timer = self
                .slow_query_options
                .enable
                .then(|| self.event_recorder.clone())
                .flatten()
                .map(|event_recorder| {
                    SlowQueryTimer::new(
                        CatalogQueryStatement::Sql(stmt.clone()),
                        self.slow_query_options.threshold,
                        self.slow_query_options.sample_ratio,
                        self.slow_query_options.record_type,
                        event_recorder,
                    )
                });

            let ticket = self.process_manager.register_query(
                query_ctx.current_catalog().to_string(),
                vec![query_ctx.current_schema()],
                query,
                query_ctx.conn_info().to_string(),
                Some(query_ctx.process_id()),
                slow_query_timer,
            );

            let query_fut = self.query_engine.execute(plan.clone(), query_ctx);

            CancellableFuture::new(query_fut, ticket.cancellation_handle.clone())
                .await
                .map_err(|_| error::CancelledSnafu.build())?
                .map(|output| {
                    let Output { meta, data } = output;

                    let data = match data {
                        OutputData::Stream(stream) => OutputData::Stream(Box::pin(
                            CancellableStreamWrapper::new(stream, ticket),
                        )),
                        other => other,
                    };
                    Output { data, meta }
                })
                .context(ExecLogicalPlanSnafu)
        } else {
            // The plan should have been prepared before execution; the checks
            // are done there.
            self.query_engine
                .execute(plan.clone(), query_ctx)
                .await
                .context(ExecLogicalPlanSnafu)
        }
    }

    #[tracing::instrument(skip_all)]
    async fn do_promql_query(
        &self,
        query: &PromQuery,
        query_ctx: QueryContextRef,
    ) -> Vec<Result<Output>> {
        if self.is_suspended() {
            return vec![error::SuspendedSnafu {}.fail()];
        }

        // Permission checks are done in the Prometheus handler's do_query.
        let result = PrometheusHandler::do_query(self, query, query_ctx)
            .await
            .with_context(|_| ExecutePromqlSnafu {
                query: format!("{query:?}"),
            });
        vec![result]
    }

    async fn do_describe(
        &self,
        stmt: Statement,
        query_ctx: QueryContextRef,
    ) -> Result<Option<DescribeResult>> {
        ensure!(!self.is_suspended(), error::SuspendedSnafu);

        if matches!(
            stmt,
            Statement::Insert(_) | Statement::Query(_) | Statement::Delete(_)
        ) {
            self.plugins
                .get::<PermissionCheckerRef>()
                .as_ref()
                .check_permission(query_ctx.current_user(), PermissionReq::SqlStatement(&stmt))
                .context(PermissionSnafu)?;

            let plan = self
                .query_engine
                .planner()
                .plan(&QueryStatement::Sql(stmt), query_ctx.clone())
                .await
                .context(PlanStatementSnafu)?;
            self.query_engine
                .describe(plan, query_ctx)
                .await
                .map(Some)
                .context(error::DescribeStatementSnafu)
        } else {
            Ok(None)
        }
    }

    async fn is_valid_schema(&self, catalog: &str, schema: &str) -> Result<bool> {
        self.catalog_manager
            .schema_exists(catalog, schema, None)
            .await
            .context(error::CatalogSnafu)
    }
}

/// Attaches a timer to the output and observes it once the output is exhausted.
pub fn attach_timer(output: Output, timer: HistogramTimer) -> Output {
    match output.data {
        OutputData::AffectedRows(_) | OutputData::RecordBatches(_) => output,
        OutputData::Stream(stream) => {
            let stream = OnDone::new(stream, move || {
                timer.observe_duration();
            });
            Output::new(OutputData::Stream(Box::pin(stream)), output.meta)
        }
    }
}

#[async_trait]
impl PrometheusHandler for Instance {
    #[tracing::instrument(skip_all)]
    async fn do_query(
        &self,
        query: &PromQuery,
        query_ctx: QueryContextRef,
    ) -> server_error::Result<Output> {
        let interceptor = self
            .plugins
            .get::<PromQueryInterceptorRef<server_error::Error>>();

        self.plugins
            .get::<PermissionCheckerRef>()
            .as_ref()
            .check_permission(query_ctx.current_user(), PermissionReq::PromQuery)
            .context(AuthSnafu)?;

        let stmt = QueryLanguageParser::parse_promql(query, &query_ctx).with_context(|_| {
            ParsePromQLSnafu {
                query: query.clone(),
            }
        })?;

        let plan = self
            .statement_executor
            .plan(&stmt, query_ctx.clone())
            .await
            .map_err(BoxedError::new)
            .context(ExecuteQuerySnafu)?;

        interceptor.pre_execute(query, Some(&plan), query_ctx.clone())?;

        // Take the EvalStmt from the original QueryStatement and use it to create the CatalogQueryStatement.
        let query_statement = if let QueryStatement::Promql(eval_stmt, alias) = stmt {
            CatalogQueryStatement::Promql(eval_stmt, alias)
        } else {
            // It should not happen since the query is already parsed successfully.
            return UnexpectedResultSnafu {
                reason: "The query should always be promql.".to_string(),
            }
            .fail();
        };
        let query = query_statement.to_string();

        let slow_query_timer = self
            .slow_query_options
            .enable
            .then(|| self.event_recorder.clone())
            .flatten()
            .map(|event_recorder| {
                SlowQueryTimer::new(
                    query_statement,
                    self.slow_query_options.threshold,
                    self.slow_query_options.sample_ratio,
                    self.slow_query_options.record_type,
                    event_recorder,
                )
            });

        let ticket = self.process_manager.register_query(
            query_ctx.current_catalog().to_string(),
            vec![query_ctx.current_schema()],
            query,
            query_ctx.conn_info().to_string(),
            Some(query_ctx.process_id()),
            slow_query_timer,
        );

        let query_fut = self.statement_executor.exec_plan(plan, query_ctx.clone());

        let output = CancellableFuture::new(query_fut, ticket.cancellation_handle.clone())
            .await
            .map_err(|_| servers::error::CancelledSnafu.build())?
            .map(|output| {
                let Output { meta, data } = output;
                let data = match data {
                    OutputData::Stream(stream) => {
                        OutputData::Stream(Box::pin(CancellableStreamWrapper::new(stream, ticket)))
                    }
                    other => other,
                };
                Output { data, meta }
            })
            .map_err(BoxedError::new)
            .context(ExecuteQuerySnafu)?;

        Ok(interceptor.post_execute(output, query_ctx)?)
    }

    async fn query_metric_names(
        &self,
        matchers: Vec<Matcher>,
        ctx: &QueryContextRef,
    ) -> server_error::Result<Vec<String>> {
        self.handle_query_metric_names(matchers, ctx)
            .await
            .map_err(BoxedError::new)
            .context(ExecuteQuerySnafu)
    }

    async fn query_label_values(
        &self,
        metric: String,
        label_name: String,
        matchers: Vec<Matcher>,
        start: SystemTime,
        end: SystemTime,
        ctx: &QueryContextRef,
    ) -> server_error::Result<Vec<String>> {
        self.handle_query_label_values(metric, label_name, matchers, start, end, ctx)
            .await
            .map_err(BoxedError::new)
            .context(ExecuteQuerySnafu)
    }

    fn catalog_manager(&self) -> CatalogManagerRef {
        self.catalog_manager.clone()
    }
}

/// Validates the `stmt.database` permission if it is present.
macro_rules! validate_db_permission {
    ($stmt: expr, $query_ctx: expr) => {
        if let Some(database) = &$stmt.database {
            validate_catalog_and_schema($query_ctx.current_catalog(), database, $query_ctx)
                .map_err(BoxedError::new)
                .context(SqlExecInterceptedSnafu)?;
        }
    };
}

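/// Rejects statements that would cross catalogs when
/// `disallow_cross_catalog_query` is set; statements that the query engine
/// validates itself pass through.
///
/// A hedged sketch, assuming the session catalog is `greptime`:
///
/// ```ignore
/// // "DROP TABLE greptime.public.demo;"     => Ok(())
/// // "DROP TABLE wrongcatalog.public.demo;" => Err(..)   (cross-catalog)
/// // "SELECT * FROM wrongcatalog.public.t;" => Ok(())    (checked later by
/// //                                           the query engine)
/// ```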
pub fn check_permission(
    plugins: Plugins,
    stmt: &Statement,
    query_ctx: &QueryContextRef,
) -> Result<()> {
    let need_validate = plugins
        .get::<QueryOptions>()
        .map(|opts| opts.disallow_cross_catalog_query)
        .unwrap_or_default();

    if !need_validate {
        return Ok(());
    }

    match stmt {
        // Will be checked in execution.
        // TODO(dennis): add a hook for admin commands.
        Statement::Admin(_) => {}
        // These are executed by the query engine, and will be checked there.
        Statement::Query(_)
        | Statement::Explain(_)
        | Statement::Tql(_)
        | Statement::Delete(_)
        | Statement::DeclareCursor(_)
        | Statement::Copy(sql::statements::copy::Copy::CopyQueryTo(_)) => {}
        // Database operations are not checked.
        Statement::CreateDatabase(_)
        | Statement::ShowDatabases(_)
        | Statement::DropDatabase(_)
        | Statement::AlterDatabase(_)
        | Statement::DropFlow(_)
        | Statement::Use(_) => {}
        #[cfg(feature = "enterprise")]
        Statement::DropTrigger(_) => {}
        Statement::ShowCreateDatabase(stmt) => {
            validate_database(&stmt.database_name, query_ctx)?;
        }
        Statement::ShowCreateTable(stmt) => {
            validate_param(&stmt.table_name, query_ctx)?;
        }
        Statement::ShowCreateFlow(stmt) => {
            validate_flow(&stmt.flow_name, query_ctx)?;
        }
        #[cfg(feature = "enterprise")]
        Statement::ShowCreateTrigger(stmt) => {
            validate_param(&stmt.trigger_name, query_ctx)?;
        }
        Statement::ShowCreateView(stmt) => {
            validate_param(&stmt.view_name, query_ctx)?;
        }
        Statement::CreateExternalTable(stmt) => {
            validate_param(&stmt.name, query_ctx)?;
        }
        Statement::CreateFlow(stmt) => {
            // TODO: should the source table names be validated here as well?
            validate_param(&stmt.sink_table_name, query_ctx)?;
        }
        #[cfg(feature = "enterprise")]
        Statement::CreateTrigger(stmt) => {
            validate_param(&stmt.trigger_name, query_ctx)?;
        }
        Statement::CreateView(stmt) => {
            validate_param(&stmt.name, query_ctx)?;
        }
        Statement::AlterTable(stmt) => {
            validate_param(stmt.table_name(), query_ctx)?;
        }
        #[cfg(feature = "enterprise")]
        Statement::AlterTrigger(_) => {}
        // SET/SHOW variables currently only alter or show session-level
        // variables, so they are not checked.
        Statement::SetVariables(_) | Statement::ShowVariables(_) => {}
        // SHOW CHARSET and SHOW COLLATION are not checked.
        Statement::ShowCharset(_) | Statement::ShowCollation(_) => {}

        Statement::Comment(comment) => match &comment.object {
            CommentObject::Table(table) => validate_param(table, query_ctx)?,
            CommentObject::Column { table, .. } => validate_param(table, query_ctx)?,
            CommentObject::Flow(flow) => validate_flow(flow, query_ctx)?,
        },

        Statement::Insert(insert) => {
            let name = insert.table_name().context(ParseSqlSnafu)?;
            validate_param(name, query_ctx)?;
        }
        Statement::CreateTable(stmt) => {
            validate_param(&stmt.name, query_ctx)?;
        }
        Statement::CreateTableLike(stmt) => {
            validate_param(&stmt.table_name, query_ctx)?;
            validate_param(&stmt.source_name, query_ctx)?;
        }
        Statement::DropTable(drop_stmt) => {
            for table_name in drop_stmt.table_names() {
                validate_param(table_name, query_ctx)?;
            }
        }
        Statement::DropView(stmt) => {
            validate_param(&stmt.view_name, query_ctx)?;
        }
        Statement::ShowTables(stmt) => {
            validate_db_permission!(stmt, query_ctx);
        }
        Statement::ShowTableStatus(stmt) => {
            validate_db_permission!(stmt, query_ctx);
        }
        Statement::ShowColumns(stmt) => {
            validate_db_permission!(stmt, query_ctx);
        }
        Statement::ShowIndex(stmt) => {
            validate_db_permission!(stmt, query_ctx);
        }
        Statement::ShowRegion(stmt) => {
            validate_db_permission!(stmt, query_ctx);
        }
        Statement::ShowViews(stmt) => {
            validate_db_permission!(stmt, query_ctx);
        }
        Statement::ShowFlows(stmt) => {
            validate_db_permission!(stmt, query_ctx);
        }
        #[cfg(feature = "enterprise")]
        Statement::ShowTriggers(_stmt) => {
            // Triggers are organized by catalog, so there is no need to check
            // the database (schema) permission.
        }
        Statement::ShowStatus(_stmt) => {}
        Statement::ShowSearchPath(_stmt) => {}
        Statement::DescribeTable(stmt) => {
            validate_param(stmt.name(), query_ctx)?;
        }
        Statement::Copy(sql::statements::copy::Copy::CopyTable(stmt)) => match stmt {
            CopyTable::To(copy_table_to) => validate_param(&copy_table_to.table_name, query_ctx)?,
            CopyTable::From(copy_table_from) => {
                validate_param(&copy_table_from.table_name, query_ctx)?
            }
        },
        Statement::Copy(sql::statements::copy::Copy::CopyDatabase(copy_database)) => {
            match copy_database {
                CopyDatabase::To(stmt) => validate_database(&stmt.database_name, query_ctx)?,
                CopyDatabase::From(stmt) => validate_database(&stmt.database_name, query_ctx)?,
            }
        }
        Statement::TruncateTable(stmt) => {
            validate_param(stmt.table_name(), query_ctx)?;
        }
        // Cursor operations are always allowed once the cursor is created.
        Statement::FetchCursor(_) | Statement::CloseCursor(_) => {}
        // Users can only kill processes in their own catalog.
        Statement::Kill(_) => {}
        // SHOW PROCESSLIST
        Statement::ShowProcesslist(_) => {}
    }
    Ok(())
}

fn validate_param(name: &ObjectName, query_ctx: &QueryContextRef) -> Result<()> {
    let (catalog, schema, _) = table_idents_to_full_name(name, query_ctx)
        .map_err(BoxedError::new)
        .context(ExternalSnafu)?;

    validate_catalog_and_schema(&catalog, &schema, query_ctx)
        .map_err(BoxedError::new)
        .context(SqlExecInterceptedSnafu)
}

fn validate_flow(name: &ObjectName, query_ctx: &QueryContextRef) -> Result<()> {
    let catalog = match &name.0[..] {
        [_flow] => query_ctx.current_catalog().to_string(),
        [catalog, _flow] => catalog.to_string_unquoted(),
        _ => {
            return InvalidSqlSnafu {
                err_msg: format!(
                    "expect flow name to be <catalog>.<flow_name> or <flow_name>, actual: {name}",
                ),
            }
            .fail();
        }
    };

    let schema = query_ctx.current_schema();

    validate_catalog_and_schema(&catalog, &schema, query_ctx)
        .map_err(BoxedError::new)
        .context(SqlExecInterceptedSnafu)
}

fn validate_database(name: &ObjectName, query_ctx: &QueryContextRef) -> Result<()> {
    let (catalog, schema) = match &name.0[..] {
        [schema] => (
            query_ctx.current_catalog().to_string(),
            schema.to_string_unquoted(),
        ),
        [catalog, schema] => (catalog.to_string_unquoted(), schema.to_string_unquoted()),
        _ => InvalidSqlSnafu {
            err_msg: format!(
                "expect database name to be <catalog>.<schema> or <schema>, actual: {name}",
            ),
        }
        .fail()?,
    };

    validate_catalog_and_schema(&catalog, &schema, query_ctx)
        .map_err(BoxedError::new)
        .context(SqlExecInterceptedSnafu)
}

/// Returns true if a query ticket and slow query timer should be created for
/// the statement, which is the case for queries and other read-only
/// statements.
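///
/// A hedged sketch of the classification (assuming `Statement::is_readonly`
/// treats SHOW-style statements as read-only):
///
/// ```ignore
/// // should_capture_statement(Some(&select_stmt)) == true  (a query)
/// // should_capture_statement(Some(&show_stmt))   == true  (read-only)
/// // should_capture_statement(Some(&insert_stmt)) == false
/// // should_capture_statement(None)               == false
/// ```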
fn should_capture_statement(stmt: Option<&Statement>) -> bool {
    if let Some(stmt) = stmt {
        matches!(stmt, Statement::Query(_)) || stmt.is_readonly()
    } else {
        false
    }
}

#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::{Arc, Barrier};
    use std::thread;
    use std::time::{Duration, Instant};

    use common_base::Plugins;
    use query::query_engine::options::QueryOptions;
    use session::context::QueryContext;
    use sql::dialect::GreptimeDbDialect;
    use strfmt::Format;

    use super::*;

    #[test]
    fn test_fast_legacy_check_deadlock_prevention() {
        // Create a DashMap to simulate the cache
        let cache = DashMap::new();

        // Pre-populate cache with some entries
        cache.insert("metric1".to_string(), true); // legacy mode
        cache.insert("metric2".to_string(), false); // prom mode
        cache.insert("metric3".to_string(), true); // legacy mode

        // Test case 1: Normal operation with cache hits
        let metric1 = "metric1".to_string();
        let metric4 = "metric4".to_string();
        let names1 = vec![&metric1, &metric4];
        let result = fast_legacy_check(&cache, &names1);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), Some(true)); // should return legacy mode

        // Verify that metric4 was added to cache
        assert!(cache.contains_key("metric4"));
        assert!(*cache.get("metric4").unwrap().value());

        // Test case 2: No cache hits
        let metric5 = "metric5".to_string();
        let metric6 = "metric6".to_string();
        let names2 = vec![&metric5, &metric6];
        let result = fast_legacy_check(&cache, &names2);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), None); // should return None as no cache hits

        // Test case 3: Incompatible modes should return error
        let cache_incompatible = DashMap::new();
        cache_incompatible.insert("metric1".to_string(), true); // legacy
        cache_incompatible.insert("metric2".to_string(), false); // prom
        let metric1_test = "metric1".to_string();
        let metric2_test = "metric2".to_string();
        let names3 = vec![&metric1_test, &metric2_test];
        let result = fast_legacy_check(&cache_incompatible, &names3);
        assert!(result.is_err()); // should error due to incompatible modes

        // Test case 4: Intensive concurrent access to test deadlock prevention
        // This test specifically targets the scenario where multiple threads
        // access the same cache entries simultaneously
        let cache_concurrent = Arc::new(DashMap::new());
        cache_concurrent.insert("shared_metric".to_string(), true);

        let num_threads = 8;
        let operations_per_thread = 100;
        let barrier = Arc::new(Barrier::new(num_threads));
        let success_flag = Arc::new(AtomicBool::new(true));

        let handles: Vec<_> = (0..num_threads)
            .map(|thread_id| {
                let cache_clone = Arc::clone(&cache_concurrent);
                let barrier_clone = Arc::clone(&barrier);
                let success_flag_clone = Arc::clone(&success_flag);

                thread::spawn(move || {
                    // Wait for all threads to be ready
                    barrier_clone.wait();

                    let start_time = Instant::now();
                    for i in 0..operations_per_thread {
                        // Each operation references existing cache entry and adds new ones
                        let shared_metric = "shared_metric".to_string();
                        let new_metric = format!("thread_{}_metric_{}", thread_id, i);
                        let names = vec![&shared_metric, &new_metric];

                        match fast_legacy_check(&cache_clone, &names) {
                            Ok(_) => {}
                            Err(_) => {
                                success_flag_clone.store(false, Ordering::Relaxed);
                                return;
                            }
                        }

                        // If the test takes too long, it likely means deadlock
                        if start_time.elapsed() > Duration::from_secs(10) {
                            success_flag_clone.store(false, Ordering::Relaxed);
                            return;
                        }
                    }
                })
            })
            .collect();

        // Join all threads with timeout
        let start_time = Instant::now();
        for (i, handle) in handles.into_iter().enumerate() {
            let join_result = handle.join();

            // Check if we're taking too long (potential deadlock)
            if start_time.elapsed() > Duration::from_secs(30) {
                panic!("Test timed out - possible deadlock detected!");
            }

            if join_result.is_err() {
                panic!("Thread {} panicked during execution", i);
            }
        }

        // Verify all operations completed successfully
        assert!(
            success_flag.load(Ordering::Relaxed),
            "Some operations failed"
        );

        // Verify that many new entries were added (proving operations completed)
        let final_count = cache_concurrent.len();
        assert!(
            final_count > 1 + num_threads * operations_per_thread / 2,
            "Expected more cache entries, got {}",
            final_count
        );
    }

    #[test]
    fn test_exec_validation() {
        let query_ctx = QueryContext::arc();
        let plugins: Plugins = Plugins::new();
        plugins.insert(QueryOptions {
            disallow_cross_catalog_query: true,
        });

        let sql = r#"
        SELECT * FROM demo;
        EXPLAIN SELECT * FROM demo;
        CREATE DATABASE test_database;
        SHOW DATABASES;
        "#;
        let stmts = parse_stmt(sql, &GreptimeDbDialect {}).unwrap();
        assert_eq!(stmts.len(), 4);
        for stmt in stmts {
            let re = check_permission(plugins.clone(), &stmt, &query_ctx);
            re.unwrap();
        }

        let sql = r#"
        SHOW CREATE TABLE demo;
        ALTER TABLE demo ADD COLUMN new_col INT;
        "#;
        let stmts = parse_stmt(sql, &GreptimeDbDialect {}).unwrap();
        assert_eq!(stmts.len(), 2);
        for stmt in stmts {
            let re = check_permission(plugins.clone(), &stmt, &query_ctx);
            re.unwrap();
        }

        fn replace_test(template_sql: &str, plugins: Plugins, query_ctx: &QueryContextRef) {
            // Catalog/schema combinations that should pass.
            let right = vec![("", ""), ("", "public."), ("greptime.", "public.")];
            for (catalog, schema) in right {
                let sql = do_fmt(template_sql, catalog, schema);
                do_test(&sql, plugins.clone(), query_ctx, true);
            }

            let wrong = vec![
                ("wrongcatalog.", "public."),
                ("wrongcatalog.", "wrongschema."),
            ];
            for (catalog, schema) in wrong {
                let sql = do_fmt(template_sql, catalog, schema);
                do_test(&sql, plugins.clone(), query_ctx, false);
            }
        }

        fn do_fmt(template: &str, catalog: &str, schema: &str) -> String {
            let vars = HashMap::from([
                ("catalog".to_string(), catalog),
                ("schema".to_string(), schema),
            ]);
            template.format(&vars).unwrap()
        }

        fn do_test(sql: &str, plugins: Plugins, query_ctx: &QueryContextRef, is_ok: bool) {
            let stmt = &parse_stmt(sql, &GreptimeDbDialect {}).unwrap()[0];
            let re = check_permission(plugins, stmt, query_ctx);
            if is_ok {
                re.unwrap();
            } else {
                assert!(re.is_err());
            }
        }

        // test insert
        let sql = "INSERT INTO {catalog}{schema}monitor(host) VALUES ('host1');";
        replace_test(sql, plugins.clone(), &query_ctx);

        // test create table
        let sql = r#"CREATE TABLE {catalog}{schema}demo(
                            host STRING,
                            ts TIMESTAMP,
                            TIME INDEX (ts),
                            PRIMARY KEY(host)
                        ) engine=mito;"#;
        replace_test(sql, plugins.clone(), &query_ctx);

        // test drop table
        let sql = "DROP TABLE {catalog}{schema}demo;";
        replace_test(sql, plugins.clone(), &query_ctx);

        // test show tables
        let sql = "SHOW TABLES FROM public";
        let stmt = parse_stmt(sql, &GreptimeDbDialect {}).unwrap();
        check_permission(plugins.clone(), &stmt[0], &query_ctx).unwrap();

        let sql = "SHOW TABLES FROM private";
        let stmt = parse_stmt(sql, &GreptimeDbDialect {}).unwrap();
        let re = check_permission(plugins.clone(), &stmt[0], &query_ctx);
        assert!(re.is_ok());

        // test describe table
        let sql = "DESC TABLE {catalog}{schema}demo;";
        replace_test(sql, plugins.clone(), &query_ctx);

        let comment_flow_cases = [
            ("COMMENT ON FLOW my_flow IS 'comment';", true),
            ("COMMENT ON FLOW greptime.my_flow IS 'comment';", true),
            ("COMMENT ON FLOW wrongcatalog.my_flow IS 'comment';", false),
        ];
        for (sql, is_ok) in comment_flow_cases {
            let stmt = &parse_stmt(sql, &GreptimeDbDialect {}).unwrap()[0];
            let result = check_permission(plugins.clone(), stmt, &query_ctx);
            assert_eq!(result.is_ok(), is_ok);
        }

        let show_flow_cases = [
            ("SHOW CREATE FLOW my_flow;", true),
            ("SHOW CREATE FLOW greptime.my_flow;", true),
            ("SHOW CREATE FLOW wrongcatalog.my_flow;", false),
        ];
        for (sql, is_ok) in show_flow_cases {
            let stmt = &parse_stmt(sql, &GreptimeDbDialect {}).unwrap()[0];
            let result = check_permission(plugins.clone(), stmt, &query_ctx);
            assert_eq!(result.is_ok(), is_ok);
        }
    }
}