diff --git a/src/servers/src/grpc.rs b/src/servers/src/grpc.rs index ef10c7df39..337add7b76 100644 --- a/src/servers/src/grpc.rs +++ b/src/servers/src/grpc.rs @@ -55,6 +55,7 @@ type TonicResult = std::result::Result; pub struct GrpcServer { shutdown_tx: Mutex>>, request_handler: Arc, + user_provider: Option, /// Handler for Prometheus-compatible PromQL queries. Only present for frontend server. prometheus_handler: Option, @@ -72,12 +73,13 @@ impl GrpcServer { ) -> Self { let request_handler = Arc::new(GreptimeRequestHandler::new( query_handler, - user_provider, + user_provider.clone(), runtime, )); Self { shutdown_tx: Mutex::new(None), request_handler, + user_provider, prometheus_handler, serve_state: Mutex::new(None), } @@ -99,7 +101,10 @@ impl GrpcServer { &self, handler: PrometheusHandlerRef, ) -> PrometheusGatewayServer { - PrometheusGatewayServer::new(PrometheusGatewayService::new(handler)) + PrometheusGatewayServer::new(PrometheusGatewayService::new( + handler, + self.user_provider.clone(), + )) } pub async fn wait_for_serve(&self) -> Result<()> { diff --git a/src/servers/src/grpc/handler.rs b/src/servers/src/grpc/handler.rs index d79a25f943..ed6e36ad2e 100644 --- a/src/servers/src/grpc/handler.rs +++ b/src/servers/src/grpc/handler.rs @@ -64,7 +64,7 @@ impl GreptimeRequestHandler { let header = request.header.as_ref(); let query_ctx = create_query_context(header); - let user_info = self.auth(header, &query_ctx).await?; + let user_info = auth(self.user_provider.clone(), header, &query_ctx).await?; query_ctx.set_current_user(user_info); let handler = self.handler.clone(); @@ -96,48 +96,48 @@ impl GreptimeRequestHandler { e })? } +} - async fn auth( - &self, - header: Option<&RequestHeader>, - query_ctx: &QueryContextRef, - ) -> Result> { - let Some(user_provider) = self.user_provider.as_ref() else { - return Ok(None); - }; +pub(crate) async fn auth( + user_provider: Option, + header: Option<&RequestHeader>, + query_ctx: &QueryContextRef, +) -> Result> { + let Some(user_provider) = user_provider else { + return Ok(None); + }; - let auth_scheme = header - .and_then(|header| { - header - .authorization - .as_ref() - .and_then(|x| x.auth_scheme.clone()) - }) - .context(NotFoundAuthHeaderSnafu)?; - - match auth_scheme { - AuthScheme::Basic(Basic { username, password }) => user_provider - .auth( - Identity::UserId(&username, None), - Password::PlainText(password.into()), - query_ctx.current_catalog(), - query_ctx.current_schema(), - ) - .await - .context(AuthSnafu), - AuthScheme::Token(_) => Err(UnsupportedAuthScheme { - name: "Token AuthScheme".to_string(), - }), - } - .map(Some) - .map_err(|e| { - increment_counter!( - METRIC_AUTH_FAILURE, - &[(METRIC_CODE_LABEL, format!("{}", e.status_code()))] - ); - e + let auth_scheme = header + .and_then(|header| { + header + .authorization + .as_ref() + .and_then(|x| x.auth_scheme.clone()) }) + .context(NotFoundAuthHeaderSnafu)?; + + match auth_scheme { + AuthScheme::Basic(Basic { username, password }) => user_provider + .auth( + Identity::UserId(&username, None), + Password::PlainText(password.into()), + query_ctx.current_catalog(), + query_ctx.current_schema(), + ) + .await + .context(AuthSnafu), + AuthScheme::Token(_) => Err(UnsupportedAuthScheme { + name: "Token AuthScheme".to_string(), + }), } + .map(Some) + .map_err(|e| { + increment_counter!( + METRIC_AUTH_FAILURE, + &[(METRIC_CODE_LABEL, format!("{}", e.status_code()))] + ); + e + }) } pub(crate) fn create_query_context(header: Option<&RequestHeader>) -> QueryContextRef { diff --git a/src/servers/src/grpc/prom_query_gateway.rs b/src/servers/src/grpc/prom_query_gateway.rs index 6a93acab0a..1cae3b3a45 100644 --- a/src/servers/src/grpc/prom_query_gateway.rs +++ b/src/servers/src/grpc/prom_query_gateway.rs @@ -21,6 +21,7 @@ use api::v1::prometheus_gateway_server::PrometheusGateway; use api::v1::promql_request::Promql; use api::v1::{PromqlRequest, PromqlResponse, ResponseHeader}; use async_trait::async_trait; +use auth::UserProviderRef; use common_error::ext::ErrorExt; use common_error::status_code::StatusCode; use common_telemetry::timer; @@ -32,7 +33,7 @@ use snafu::OptionExt; use tonic::{Request, Response}; use crate::error::InvalidQuerySnafu; -use crate::grpc::handler::create_query_context; +use crate::grpc::handler::{auth, create_query_context}; use crate::grpc::TonicResult; use crate::prometheus::{ retrieve_metric_name_and_result_type, PrometheusHandlerRef, PrometheusJsonResponse, @@ -40,6 +41,7 @@ use crate::prometheus::{ pub struct PrometheusGatewayService { handler: PrometheusHandlerRef, + user_provider: Option, } #[async_trait] @@ -74,9 +76,13 @@ impl PrometheusGateway for PrometheusGatewayService { } }; - let query_context = create_query_context(inner.header.as_ref()); + let header = inner.header.as_ref(); + let query_ctx = create_query_context(header); + let user_info = auth(self.user_provider.clone(), header, &query_ctx).await?; + query_ctx.set_current_user(user_info); + let json_response = self - .handle_inner(prom_query, query_context, is_range_query) + .handle_inner(prom_query, query_ctx, is_range_query) .await; let json_bytes = serde_json::to_string(&json_response).unwrap().into_bytes(); @@ -94,8 +100,11 @@ impl PrometheusGateway for PrometheusGatewayService { } impl PrometheusGatewayService { - pub fn new(handler: PrometheusHandlerRef) -> Self { - Self { handler } + pub fn new(handler: PrometheusHandlerRef, user_provider: Option) -> Self { + Self { + handler, + user_provider, + } } async fn handle_inner(