diff --git a/src/frontend/src/instance.rs b/src/frontend/src/instance.rs index 54e8ea5e2d..89f10d08d6 100644 --- a/src/frontend/src/instance.rs +++ b/src/frontend/src/instance.rs @@ -58,7 +58,9 @@ use query::query_engine::DescribeResult; use query::{QueryEngineFactory, QueryEngineRef}; use servers::error as server_error; use servers::error::{ExecuteQuerySnafu, ParsePromQLSnafu}; -use servers::interceptor::{SqlQueryInterceptor, SqlQueryInterceptorRef}; +use servers::interceptor::{ + PromQueryInterceptor, PromQueryInterceptorRef, SqlQueryInterceptor, SqlQueryInterceptorRef, +}; use servers::prom::PromHandler; use servers::query_handler::grpc::{GrpcQueryHandler, GrpcQueryHandlerRef}; use servers::query_handler::sql::SqlQueryHandler; @@ -571,17 +573,25 @@ impl PromHandler for Instance { query: &PromQuery, query_ctx: QueryContextRef, ) -> server_error::Result { + let interceptor = self + .plugins + .get::>(); + interceptor.pre_execute(query, query_ctx.clone())?; + let stmt = QueryLanguageParser::parse_promql(query).with_context(|_| ParsePromQLSnafu { query: query.clone(), })?; - self.statement_executor - .execute_stmt(stmt, query_ctx) + let output = self + .statement_executor + .execute_stmt(stmt, query_ctx.clone()) .await .map_err(BoxedError::new) .with_context(|_| ExecuteQuerySnafu { query: format!("{query:?}"), - }) + })?; + + Ok(interceptor.post_execute(output, query_ctx)?) } } diff --git a/src/servers/src/interceptor.rs b/src/servers/src/interceptor.rs index 7471c45547..197819e340 100644 --- a/src/servers/src/interceptor.rs +++ b/src/servers/src/interceptor.rs @@ -18,6 +18,7 @@ use std::sync::Arc; use api::v1::greptime_request::Request; use common_error::prelude::ErrorExt; use common_query::Output; +use query::parser::PromQuery; use query::plan::LogicalPlan; use session::context::QueryContextRef; use sql::statements::statement::Statement; @@ -186,3 +187,62 @@ where } } } + +/// PromQueryInterceptor can track life cycle of a prometheus request and customize or +/// abort its execution at given point. +pub trait PromQueryInterceptor { + type Error: ErrorExt; + + /// Called before request is actually executed. + fn pre_execute( + &self, + _query: &PromQuery, + _query_ctx: QueryContextRef, + ) -> Result<(), Self::Error> { + Ok(()) + } + + /// Called after execution finished. The implementation can modify the + /// output if needed. + fn post_execute( + &self, + output: Output, + _query_ctx: QueryContextRef, + ) -> Result { + Ok(output) + } +} + +pub type PromQueryInterceptorRef = + Arc + Send + Sync + 'static>; + +impl PromQueryInterceptor for Option> +where + E: ErrorExt, +{ + type Error = E; + + fn pre_execute( + &self, + query: &PromQuery, + query_ctx: QueryContextRef, + ) -> Result<(), Self::Error> { + if let Some(this) = self { + this.pre_execute(query, query_ctx) + } else { + Ok(()) + } + } + + fn post_execute( + &self, + output: Output, + query_ctx: QueryContextRef, + ) -> Result { + if let Some(this) = self { + this.post_execute(output, query_ctx) + } else { + Ok(output) + } + } +} diff --git a/src/servers/tests/interceptor.rs b/src/servers/tests/interceptor.rs index 9c9f93e7aa..ab721e03eb 100644 --- a/src/servers/tests/interceptor.rs +++ b/src/servers/tests/interceptor.rs @@ -17,8 +17,10 @@ use std::sync::Arc; use api::v1::greptime_request::Request; use api::v1::{InsertRequest, InsertRequests}; -use servers::error::{self, NotSupportedSnafu, Result}; -use servers::interceptor::{GrpcQueryInterceptor, SqlQueryInterceptor}; +use common_query::Output; +use query::parser::PromQuery; +use servers::error::{self, InternalSnafu, NotSupportedSnafu, Result}; +use servers::interceptor::{GrpcQueryInterceptor, PromQueryInterceptor, SqlQueryInterceptor}; use session::context::{QueryContext, QueryContextRef}; use snafu::ensure; @@ -85,3 +87,48 @@ fn test_grpc_interceptor() { let req = Request::Inserts(InsertRequests::default()); GrpcQueryInterceptor::pre_execute(&di, &req, ctx).unwrap(); } + +impl PromQueryInterceptor for NoopInterceptor { + type Error = error::Error; + + fn pre_execute( + &self, + query: &PromQuery, + _query_ctx: QueryContextRef, + ) -> std::result::Result<(), Self::Error> { + match query.query.as_str() { + "up" => InternalSnafu { err_msg: "test" }.fail(), + _ => Ok(()), + } + } + + fn post_execute( + &self, + output: Output, + _query_ctx: QueryContextRef, + ) -> std::result::Result { + match output { + Output::AffectedRows(1) => Ok(Output::AffectedRows(2)), + _ => Ok(output), + } + } +} + +#[test] +fn test_prom_interceptor() { + let di = NoopInterceptor; + let ctx = Arc::new(QueryContext::new()); + + let query = PromQuery { + query: "up".to_string(), + ..Default::default() + }; + + let fail = PromQueryInterceptor::pre_execute(&di, &query, ctx.clone()); + assert!(fail.is_err()); + + let output = Output::AffectedRows(1); + let two = PromQueryInterceptor::post_execute(&di, output, ctx); + assert!(two.is_ok()); + matches!(two.unwrap(), Output::AffectedRows(2)); +}