From ec817f6877eb6a0b72b9f36fc54a79b612792424 Mon Sep 17 00:00:00 2001 From: shuiyisong <113876041+shuiyisong@users.noreply.github.com> Date: Thu, 28 Aug 2025 12:00:45 +0800 Subject: [PATCH] fix: gRPC auth (#6827) * fix: internal service Signed-off-by: shuiyisong * refactor: gRPC auth Signed-off-by: shuiyisong * chore: add permission check for bulk ingest Signed-off-by: shuiyisong * chore: remove unused grpc auth middleware Signed-off-by: shuiyisong * chore: extract header function Signed-off-by: shuiyisong * refactor: extract common code and add auth to otel arrow api Signed-off-by: shuiyisong * chore: rename utils to context_auth Signed-off-by: shuiyisong * test: otel arrow auth Signed-off-by: shuiyisong * chore: add support for old auth value Signed-off-by: shuiyisong --------- Signed-off-by: shuiyisong --- src/auth/src/permission.rs | 1 + src/client/src/database.rs | 4 +- src/frontend/src/instance/grpc.rs | 13 ++ src/frontend/src/server.rs | 7 +- src/servers/src/grpc.rs | 2 +- src/servers/src/grpc/authorize.rs | 199 --------------------- src/servers/src/grpc/context_auth.rs | 163 +++++++++++++++++ src/servers/src/grpc/flight.rs | 54 ++---- src/servers/src/grpc/greptime_handler.rs | 117 ++---------- src/servers/src/grpc/prom_query_gateway.rs | 3 +- src/servers/src/http/authorize.rs | 16 +- src/servers/src/interceptor.rs | 22 +++ src/servers/src/otel_arrow.rs | 28 ++- src/servers/src/query_handler/grpc.rs | 4 +- src/servers/tests/mod.rs | 1 + tests-integration/src/test_util.rs | 4 +- tests-integration/tests/grpc.rs | 83 +++++++++ 17 files changed, 357 insertions(+), 364 deletions(-) delete mode 100644 src/servers/src/grpc/authorize.rs create mode 100644 src/servers/src/grpc/context_auth.rs diff --git a/src/auth/src/permission.rs b/src/auth/src/permission.rs index 6c33a766a6..0f0650d282 100644 --- a/src/auth/src/permission.rs +++ b/src/auth/src/permission.rs @@ -32,6 +32,7 @@ pub enum PermissionReq<'a> { PromStoreRead, Otlp, LogWrite, + BulkInsert, } #[derive(Debug)] diff --git a/src/client/src/database.rs b/src/client/src/database.rs index 11736c1996..2b608d0010 100644 --- a/src/client/src/database.rs +++ b/src/client/src/database.rs @@ -473,8 +473,8 @@ impl Database { }) = &self.ctx.auth_header { let encoded = BASE64_STANDARD.encode(format!("{username}:{password}")); - let value = - MetadataValue::from_str(&encoded).context(InvalidTonicMetadataValueSnafu)?; + let value = MetadataValue::from_str(&format!("Basic {encoded}")) + .context(InvalidTonicMetadataValueSnafu)?; request.metadata_mut().insert("x-greptime-auth", value); } diff --git a/src/frontend/src/instance/grpc.rs b/src/frontend/src/instance/grpc.rs index c2ac853eb6..5a269267d0 100644 --- a/src/frontend/src/instance/grpc.rs +++ b/src/frontend/src/instance/grpc.rs @@ -247,6 +247,7 @@ impl GrpcQueryHandler for Instance { table_ref: &mut Option, decoder: &mut FlightDecoder, data: FlightData, + ctx: QueryContextRef, ) -> Result { let table = if let Some(table) = table_ref { table.clone() @@ -268,6 +269,18 @@ impl GrpcQueryHandler for Instance { table }; + let interceptor_ref = self.plugins.get::>(); + let interceptor = interceptor_ref.as_ref(); + interceptor.pre_bulk_insert(table.clone(), ctx.clone())?; + + self.plugins + .get::() + .as_ref() + .check_permission(ctx.current_user(), PermissionReq::BulkInsert) + .context(PermissionSnafu)?; + + // do we check limit for bulk insert? + self.inserter .handle_bulk_insert(table, decoder, data) .await diff --git a/src/frontend/src/server.rs b/src/frontend/src/server.rs index ab33a5b500..73c86deabc 100644 --- a/src/frontend/src/server.rs +++ b/src/frontend/src/server.rs @@ -170,10 +170,13 @@ where .name(name) .database_handler(greptime_request_handler.clone()) .prometheus_handler(self.instance.clone(), user_provider.clone()) - .otel_arrow_handler(OtelArrowServiceHandler::new(self.instance.clone())) + .otel_arrow_handler(OtelArrowServiceHandler::new( + self.instance.clone(), + user_provider.clone(), + )) .flight_handler(Arc::new(greptime_request_handler)); - let grpc_server = if external { + let grpc_server = if !external { let frontend_grpc_handler = FrontendGrpcHandler::new(self.instance.process_manager().clone()); grpc_server.frontend_grpc_handler(frontend_grpc_handler) diff --git a/src/servers/src/grpc.rs b/src/servers/src/grpc.rs index 4205b7debc..0310abc0e8 100644 --- a/src/servers/src/grpc.rs +++ b/src/servers/src/grpc.rs @@ -12,9 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -mod authorize; pub mod builder; mod cancellation; +pub mod context_auth; mod database; pub mod flight; pub mod frontend_grpc_handler; diff --git a/src/servers/src/grpc/authorize.rs b/src/servers/src/grpc/authorize.rs deleted file mode 100644 index 9fd236f54c..0000000000 --- a/src/servers/src/grpc/authorize.rs +++ /dev/null @@ -1,199 +0,0 @@ -// Copyright 2023 Greptime Team -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::pin::Pin; -use std::result::Result as StdResult; -use std::task::{Context, Poll}; - -use auth::UserProviderRef; -use session::context::{Channel, QueryContext}; -use tonic::body::Body; -use tonic::server::NamedService; -use tower::{Layer, Service}; - -use crate::http::authorize::{extract_catalog_and_schema, extract_username_and_password}; - -#[derive(Clone)] -pub struct AuthMiddlewareLayer { - user_provider: Option, -} - -impl Layer for AuthMiddlewareLayer { - type Service = AuthMiddleware; - - fn layer(&self, service: S) -> Self::Service { - AuthMiddleware { - inner: service, - user_provider: self.user_provider.clone(), - } - } -} - -/// This middleware is responsible for authenticating the user and setting the user -/// info in the request extension. -/// -/// Detail: Authorization information is passed in through the Authorization request -/// header. -#[derive(Clone)] -pub struct AuthMiddleware { - inner: S, - user_provider: Option, -} - -impl NamedService for AuthMiddleware -where - S: NamedService, -{ - const NAME: &'static str = S::NAME; -} - -type BoxFuture<'a, T> = Pin + Send + 'a>>; - -impl Service> for AuthMiddleware -where - S: Service, Response = http::Response> + Clone + Send + 'static, - S::Future: Send + 'static, -{ - type Response = S::Response; - type Error = S::Error; - type Future = BoxFuture<'static, StdResult>; - - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - self.inner.poll_ready(cx) - } - - fn call(&mut self, mut req: http::Request) -> Self::Future { - // This is necessary because tonic internally uses `tower::buffer::Buffer`. - // See https://github.com/tower-rs/tower/issues/547#issuecomment-767629149 - // for details on why this is necessary. - let clone = self.inner.clone(); - let mut inner = std::mem::replace(&mut self.inner, clone); - - let user_provider = self.user_provider.clone(); - - Box::pin(async move { - if let Err(status) = do_auth(&mut req, user_provider).await { - return Ok(status.into_http()); - } - inner.call(req).await - }) - } -} - -async fn do_auth( - req: &mut http::Request, - user_provider: Option, -) -> Result<(), tonic::Status> { - let (catalog, schema) = extract_catalog_and_schema(req); - - let query_ctx = QueryContext::with_channel(&catalog, &schema, Channel::Grpc); - - let Some(user_provider) = user_provider else { - query_ctx.set_current_user(auth::userinfo_by_name(None)); - let _ = req.extensions_mut().insert(query_ctx); - return Ok(()); - }; - - let (username, password) = extract_username_and_password(req) - .map_err(|e| tonic::Status::invalid_argument(e.to_string()))?; - - let id = auth::Identity::UserId(&username, None); - let pwd = auth::Password::PlainText(password); - - let user_info = user_provider - .auth(id, pwd, &catalog, &schema) - .await - .map_err(|e| tonic::Status::unauthenticated(e.to_string()))?; - - query_ctx.set_current_user(user_info); - let _ = req.extensions_mut().insert(query_ctx); - - Ok(()) -} - -#[cfg(test)] -mod tests { - use std::sync::Arc; - - use auth::tests::MockUserProvider; - use base64::engine::general_purpose::STANDARD; - use base64::Engine; - use headers::Header; - use hyper::Request; - use session::context::QueryContext; - - use crate::grpc::authorize::do_auth; - use crate::http::header::GreptimeDbName; - - #[tokio::test] - async fn test_do_auth_with_user_provider() { - let user_provider = Arc::new(MockUserProvider::default()); - - // auth success - let authorization_val = format!("Basic {}", STANDARD.encode("greptime:greptime")); - let mut req = Request::new(()); - req.headers_mut() - .insert("authorization", authorization_val.parse().unwrap()); - - let auth_result = do_auth(&mut req, Some(user_provider.clone())).await; - - assert!(auth_result.is_ok()); - check_req(&req, "greptime", "public", "greptime"); - - // auth failed, err: user not exist. - let authorization_val = format!("Basic {}", STANDARD.encode("greptime2:greptime2")); - let mut req = Request::new(()); - req.headers_mut() - .insert("authorization", authorization_val.parse().unwrap()); - - let auth_result = do_auth(&mut req, Some(user_provider)).await; - assert!(auth_result.is_err()); - } - - #[tokio::test] - async fn test_do_auth_without_user_provider() { - let mut req = Request::new(()); - req.headers_mut() - .insert("authentication", "pwd".parse().unwrap()); - let auth_result = do_auth(&mut req, None).await; - assert!(auth_result.is_ok()); - check_req(&req, "greptime", "public", "greptime"); - - let mut req = Request::new(()); - let auth_result = do_auth(&mut req, None).await; - assert!(auth_result.is_ok()); - check_req(&req, "greptime", "public", "greptime"); - - let mut req = Request::new(()); - req.headers_mut() - .insert(GreptimeDbName::name(), "catalog-schema".parse().unwrap()); - let auth_result = do_auth(&mut req, None).await; - assert!(auth_result.is_ok()); - check_req(&req, "catalog", "schema", "greptime"); - } - - fn check_req( - req: &Request, - expected_catalog: &str, - expected_schema: &str, - expected_user_name: &str, - ) { - let ctx = req.extensions().get::().unwrap(); - assert_eq!(expected_catalog, ctx.current_catalog()); - assert_eq!(expected_schema, ctx.current_schema()); - - let user_info = ctx.current_user(); - assert_eq!(expected_user_name, user_info.username()); - } -} diff --git a/src/servers/src/grpc/context_auth.rs b/src/servers/src/grpc/context_auth.rs new file mode 100644 index 0000000000..f1e1bdd7e3 --- /dev/null +++ b/src/servers/src/grpc/context_auth.rs @@ -0,0 +1,163 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::sync::Arc; + +use api::v1::auth_header::AuthScheme; +use api::v1::{AuthHeader, RequestHeader}; +use auth::{Identity, Password, UserInfoRef, UserProviderRef}; +use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; +use common_catalog::parse_catalog_and_schema_from_db_string; +use common_error::ext::ErrorExt; +use session::context::{Channel, QueryContextBuilder, QueryContextRef}; +use snafu::{OptionExt, ResultExt}; +use tonic::metadata::MetadataMap; +use tonic::Status; + +use crate::error::Error::UnsupportedAuthScheme; +use crate::error::{AuthSnafu, InvalidParameterSnafu, NotFoundAuthHeaderSnafu, Result}; +use crate::grpc::TonicResult; +use crate::http::header::constants::GREPTIME_DB_HEADER_NAME; +use crate::http::AUTHORIZATION_HEADER; +use crate::metrics::METRIC_AUTH_FAILURE; + +/// Create a query context from the grpc metadata. +pub fn create_query_context_from_grpc_metadata( + headers: &MetadataMap, +) -> TonicResult { + let (catalog, schema) = if let Some(db) = extract_header(headers, &[GREPTIME_DB_HEADER_NAME])? { + parse_catalog_and_schema_from_db_string(db) + } else { + ( + DEFAULT_CATALOG_NAME.to_string(), + DEFAULT_SCHEMA_NAME.to_string(), + ) + }; + + Ok(Arc::new( + QueryContextBuilder::default() + .current_catalog(catalog) + .current_schema(schema) + .channel(Channel::Grpc) + .build(), + )) +} + +/// Helper function to extract a header from the metadata map. +/// Can be multiple keys, and the first one found will be returned. +pub fn extract_header<'a>(headers: &'a MetadataMap, keys: &[&str]) -> TonicResult> { + let mut value = None; + for key in keys { + if let Some(v) = headers.get(*key) { + value = Some(v); + break; + } + } + + let Some(v) = value else { + return Ok(None); + }; + let Ok(v) = std::str::from_utf8(v.as_bytes()) else { + return Err(InvalidParameterSnafu { + reason: "expect valid UTF-8 value", + } + .build() + .into()); + }; + Ok(Some(v)) +} + +/// Helper function to extract the header from the metadata and authenticate the user. +pub async fn check_auth( + user_provider: Option, + headers: &MetadataMap, + query_ctx: QueryContextRef, +) -> TonicResult { + if user_provider.is_none() { + return Ok(true); + } + + let auth_schema = extract_header( + headers, + &[AUTHORIZATION_HEADER, http::header::AUTHORIZATION.as_str()], + )? + .map(|x| { + if x.len() > 5 && x[0..5].eq_ignore_ascii_case("Basic") { + x.try_into() + } else { + // compatible with old version + format!("Basic {}", x).as_str().try_into() + } + }) + .transpose()? + .map(|x: crate::http::authorize::AuthScheme| x.into()); + + let auth_schema = auth_schema.context(NotFoundAuthHeaderSnafu)?; + let header = RequestHeader { + authorization: Some(AuthHeader { + auth_scheme: Some(auth_schema), + }), + catalog: query_ctx.current_catalog().to_string(), + schema: query_ctx.current_schema(), + ..Default::default() + }; + + match auth(user_provider, Some(&header), &query_ctx).await { + Ok(user_info) => { + query_ctx.set_current_user(user_info); + Ok(true) + } + Err(_) => Err(Status::unauthenticated("auth failed")), + } +} + +/// Authenticate the user based on the header and query context. +pub async fn auth( + user_provider: Option, + header: Option<&RequestHeader>, + query_ctx: &QueryContextRef, +) -> Result { + let Some(user_provider) = user_provider else { + return Ok(auth::userinfo_by_name(None)); + }; + + let auth_scheme = header + .and_then(|header| { + header + .authorization + .as_ref() + .and_then(|x| x.auth_scheme.clone()) + }) + .context(NotFoundAuthHeaderSnafu)?; + + match auth_scheme { + AuthScheme::Basic(api::v1::Basic { username, password }) => user_provider + .auth( + Identity::UserId(&username, None), + Password::PlainText(password.into()), + query_ctx.current_catalog(), + &query_ctx.current_schema(), + ) + .await + .context(AuthSnafu), + AuthScheme::Token(_) => Err(UnsupportedAuthScheme { + name: "Token AuthScheme".to_string(), + }), + } + .inspect_err(|e| { + METRIC_AUTH_FAILURE + .with_label_values(&[e.status_code().as_ref()]) + .inc(); + }) +} diff --git a/src/servers/src/grpc/flight.rs b/src/servers/src/grpc/flight.rs index 0264a3b46b..d799ef3a69 100644 --- a/src/servers/src/grpc/flight.rs +++ b/src/servers/src/grpc/flight.rs @@ -26,8 +26,6 @@ use arrow_flight::{ }; use async_trait::async_trait; use bytes::Bytes; -use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; -use common_catalog::parse_catalog_and_schema_from_db_string; use common_grpc::flight::do_put::{DoPutMetadata, DoPutResponse}; use common_grpc::flight::{FlightEncoder, FlightMessage}; use common_query::{Output, OutputData}; @@ -46,9 +44,7 @@ use tonic::{Request, Response, Status, Streaming}; use crate::error::{InvalidParameterSnafu, ParseJsonSnafu, Result, ToJsonSnafu}; pub use crate::grpc::flight::stream::FlightRecordBatchStream; use crate::grpc::greptime_handler::{get_request_type, GreptimeRequestHandler}; -use crate::grpc::{FlightCompression, TonicResult}; -use crate::http::header::constants::GREPTIME_DB_HEADER_NAME; -use crate::http::AUTHORIZATION_HEADER; +use crate::grpc::{context_auth, FlightCompression, TonicResult}; use crate::{error, hint_headers}; pub type TonicStream = Pin> + Send + 'static>>; @@ -189,7 +185,6 @@ impl FlightCraft for GreptimeRequestHandler { let ticket = request.into_inner().ticket; let request = GreptimeRequest::decode(ticket.as_ref()).context(error::InvalidFlightTicketSnafu)?; - let query_ctx = QueryContext::arc(); // The Grpc protocol pass query by Flight. It needs to be wrapped under a span, in order to record stream let span = info_span!( @@ -204,7 +199,7 @@ impl FlightCraft for GreptimeRequestHandler { output, TracingContext::from_current_span(), flight_compression, - query_ctx, + QueryContext::arc(), ); Ok(Response::new(stream)) } @@ -218,34 +213,20 @@ impl FlightCraft for GreptimeRequestHandler { ) -> TonicResult>> { let (headers, _, stream) = request.into_parts(); - let header = |key: &str| -> TonicResult> { - let Some(v) = headers.get(key) else { - return Ok(None); - }; - let Ok(v) = std::str::from_utf8(v.as_bytes()) else { - return Err(InvalidParameterSnafu { - reason: "expect valid UTF-8 value", - } - .build() - .into()); - }; - Ok(Some(v)) - }; - - let username_and_password = header(AUTHORIZATION_HEADER)?; - let db = header(GREPTIME_DB_HEADER_NAME)?; - if !self.validate_auth(username_and_password, db).await? { - return Err(Status::unauthenticated("auth failed")); - } + let query_ctx = context_auth::create_query_context_from_grpc_metadata(&headers)?; + context_auth::check_auth(self.user_provider.clone(), &headers, query_ctx.clone()).await?; const MAX_PENDING_RESPONSES: usize = 32; let (tx, rx) = mpsc::channel::>(MAX_PENDING_RESPONSES); let stream = PutRecordBatchRequestStream { flight_data_stream: stream, - state: PutRecordBatchRequestStreamState::Init(db.map(ToString::to_string)), + state: PutRecordBatchRequestStreamState::Init( + query_ctx.current_catalog().to_string(), + query_ctx.current_schema(), + ), }; - self.put_record_batches(stream, tx).await; + self.put_record_batches(stream, tx, query_ctx).await; let response = ReceiverStream::new(rx) .and_then(|response| { @@ -292,7 +273,7 @@ pub(crate) struct PutRecordBatchRequestStream { } enum PutRecordBatchRequestStreamState { - Init(Option), + Init(String, String), Started(TableName), } @@ -319,21 +300,12 @@ impl Stream for PutRecordBatchRequestStream { let poll = ready!(self.flight_data_stream.poll_next_unpin(cx)); let result = match &mut self.state { - PutRecordBatchRequestStreamState::Init(db) => match poll { + PutRecordBatchRequestStreamState::Init(catalog, schema) => match poll { Some(Ok(mut flight_data)) => { let flight_descriptor = flight_data.flight_descriptor.take(); let result = if let Some(descriptor) = flight_descriptor { - let table_name = extract_table_name(descriptor).map(|x| { - let (catalog, schema) = if let Some(db) = db { - parse_catalog_and_schema_from_db_string(db) - } else { - ( - DEFAULT_CATALOG_NAME.to_string(), - DEFAULT_SCHEMA_NAME.to_string(), - ) - }; - TableName::new(catalog, schema, x) - }); + let table_name = extract_table_name(descriptor) + .map(|x| TableName::new(catalog.clone(), schema.clone(), x)); let table_name = match table_name { Ok(table_name) => table_name, Err(e) => return Poll::Ready(Some(Err(e.into()))), diff --git a/src/servers/src/grpc/greptime_handler.rs b/src/servers/src/grpc/greptime_handler.rs index e2eebbb80c..2ec2e4e617 100644 --- a/src/servers/src/grpc/greptime_handler.rs +++ b/src/servers/src/grpc/greptime_handler.rs @@ -18,11 +18,8 @@ use std::str::FromStr; use std::time::Instant; use api::helper::request_type; -use api::v1::auth_header::AuthScheme; -use api::v1::{AuthHeader, Basic, GreptimeRequest, RequestHeader}; -use auth::{Identity, Password, UserInfoRef, UserProviderRef}; -use base64::prelude::BASE64_STANDARD; -use base64::Engine; +use api::v1::{GreptimeRequest, RequestHeader}; +use auth::UserProviderRef; use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; use common_catalog::parse_catalog_and_schema_from_db_string; use common_error::ext::ErrorExt; @@ -37,28 +34,24 @@ use common_telemetry::tracing_context::{FutureExt, TracingContext}; use common_telemetry::{debug, error, tracing, warn}; use common_time::timezone::parse_timezone; use futures_util::StreamExt; -use session::context::{Channel, QueryContext, QueryContextBuilder, QueryContextRef}; +use session::context::{Channel, QueryContextBuilder, QueryContextRef}; use session::hints::READ_PREFERENCE_HINT; use snafu::{OptionExt, ResultExt}; use table::TableRef; use tokio::sync::mpsc; use tokio::sync::mpsc::error::TrySendError; -use crate::error::Error::UnsupportedAuthScheme; -use crate::error::{ - AuthSnafu, InvalidAuthHeaderInvalidUtf8ValueSnafu, InvalidBase64ValueSnafu, InvalidQuerySnafu, - JoinTaskSnafu, NotFoundAuthHeaderSnafu, Result, UnknownHintSnafu, -}; +use crate::error::{InvalidQuerySnafu, JoinTaskSnafu, Result, UnknownHintSnafu}; use crate::grpc::flight::{PutRecordBatchRequest, PutRecordBatchRequestStream}; -use crate::grpc::{FlightCompression, TonicResult}; +use crate::grpc::{context_auth, FlightCompression, TonicResult}; use crate::metrics; -use crate::metrics::{METRIC_AUTH_FAILURE, METRIC_SERVER_GRPC_DB_REQUEST_TIMER}; +use crate::metrics::METRIC_SERVER_GRPC_DB_REQUEST_TIMER; use crate::query_handler::grpc::ServerGrpcQueryHandlerRef; #[derive(Clone)] pub struct GreptimeRequestHandler { handler: ServerGrpcQueryHandlerRef, - user_provider: Option, + pub(crate) user_provider: Option, runtime: Option, pub(crate) flight_compression: FlightCompression, } @@ -90,7 +83,7 @@ impl GreptimeRequestHandler { let header = request.header.as_ref(); let query_ctx = create_query_context(Channel::Grpc, header, hints)?; - let user_info = auth(self.user_provider.clone(), header, &query_ctx).await?; + let user_info = context_auth::auth(self.user_provider.clone(), header, &query_ctx).await?; query_ctx.set_current_user(user_info); let handler = self.handler.clone(); @@ -143,6 +136,7 @@ impl GreptimeRequestHandler { &self, mut stream: PutRecordBatchRequestStream, result_sender: mpsc::Sender>, + query_ctx: QueryContextRef, ) { let handler = self.handler.clone(); let runtime = self @@ -170,7 +164,7 @@ impl GreptimeRequestHandler { let timer = metrics::GRPC_BULK_INSERT_ELAPSED.start_timer(); let result = handler - .put_record_batch(&table_name, &mut table_ref, &mut decoder, data) + .put_record_batch(&table_name, &mut table_ref, &mut decoder, data, query_ctx.clone()) .await .inspect_err(|e| error!(e; "Failed to handle flight record batches")); timer.observe_duration(); @@ -185,58 +179,6 @@ impl GreptimeRequestHandler { } }); } - - pub(crate) async fn validate_auth( - &self, - username_and_password: Option<&str>, - db: Option<&str>, - ) -> Result { - if self.user_provider.is_none() { - return Ok(true); - } - - let username_and_password = username_and_password.context(NotFoundAuthHeaderSnafu)?; - let username_and_password = BASE64_STANDARD - .decode(username_and_password) - .context(InvalidBase64ValueSnafu) - .and_then(|x| String::from_utf8(x).context(InvalidAuthHeaderInvalidUtf8ValueSnafu))?; - - let mut split = username_and_password.splitn(2, ':'); - let (username, password) = match (split.next(), split.next()) { - (Some(username), Some(password)) => (username, password), - (Some(username), None) => (username, ""), - (None, None) => return Ok(false), - _ => unreachable!(), // because this iterator won't yield Some after None - }; - - let (catalog, schema) = if let Some(db) = db { - parse_catalog_and_schema_from_db_string(db) - } else { - ( - DEFAULT_CATALOG_NAME.to_string(), - DEFAULT_SCHEMA_NAME.to_string(), - ) - }; - let header = RequestHeader { - authorization: Some(AuthHeader { - auth_scheme: Some(AuthScheme::Basic(Basic { - username: username.to_string(), - password: password.to_string(), - })), - }), - catalog, - schema, - ..Default::default() - }; - - Ok(auth( - self.user_provider.clone(), - Some(&header), - &QueryContext::arc(), - ) - .await - .is_ok()) - } } pub fn get_request_type(request: &GreptimeRequest) -> &'static str { @@ -247,45 +189,6 @@ pub fn get_request_type(request: &GreptimeRequest) -> &'static str { .unwrap_or_default() } -pub(crate) async fn auth( - user_provider: Option, - header: Option<&RequestHeader>, - query_ctx: &QueryContextRef, -) -> Result { - let Some(user_provider) = user_provider else { - return Ok(auth::userinfo_by_name(None)); - }; - - let auth_scheme = header - .and_then(|header| { - header - .authorization - .as_ref() - .and_then(|x| x.auth_scheme.clone()) - }) - .context(NotFoundAuthHeaderSnafu)?; - - match auth_scheme { - AuthScheme::Basic(Basic { username, password }) => user_provider - .auth( - Identity::UserId(&username, None), - Password::PlainText(password.into()), - query_ctx.current_catalog(), - &query_ctx.current_schema(), - ) - .await - .context(AuthSnafu), - AuthScheme::Token(_) => Err(UnsupportedAuthScheme { - name: "Token AuthScheme".to_string(), - }), - } - .inspect_err(|e| { - METRIC_AUTH_FAILURE - .with_label_values(&[e.status_code().as_ref()]) - .inc(); - }) -} - /// Creates a new `QueryContext` from the provided request header and extensions. /// Strongly recommend setting an appropriate channel, as this is very helpful for statistics. pub(crate) fn create_query_context( diff --git a/src/servers/src/grpc/prom_query_gateway.rs b/src/servers/src/grpc/prom_query_gateway.rs index 3ee7902f2c..a8565737f9 100644 --- a/src/servers/src/grpc/prom_query_gateway.rs +++ b/src/servers/src/grpc/prom_query_gateway.rs @@ -32,7 +32,8 @@ use snafu::OptionExt; use tonic::{Request, Response}; use crate::error::InvalidQuerySnafu; -use crate::grpc::greptime_handler::{auth, create_query_context}; +use crate::grpc::context_auth::auth; +use crate::grpc::greptime_handler::create_query_context; use crate::grpc::TonicResult; use crate::http::prometheus::{retrieve_metric_name_and_result_type, PrometheusJsonResponse}; use crate::prometheus_handler::PrometheusHandlerRef; diff --git a/src/servers/src/http/authorize.rs b/src/servers/src/http/authorize.rs index f04f1676f0..e8b42603d3 100644 --- a/src/servers/src/http/authorize.rs +++ b/src/servers/src/http/authorize.rs @@ -13,13 +13,14 @@ // limitations under the License. use ::auth::UserProviderRef; +use api::v1::Basic; use axum::extract::{Request, State}; use axum::http::{self, StatusCode}; use axum::middleware::Next; use axum::response::{IntoResponse, Response}; use base64::prelude::BASE64_STANDARD; use base64::Engine; -use common_base::secrets::SecretString; +use common_base::secrets::{ExposeSecret, SecretString}; use common_catalog::consts::DEFAULT_SCHEMA_NAME; use common_catalog::parse_catalog_and_schema_from_db_string; use common_error::ext::ErrorExt; @@ -240,6 +241,19 @@ impl TryFrom<&str> for AuthScheme { } } +impl From for api::v1::auth_header::AuthScheme { + fn from(value: AuthScheme) -> Self { + match value { + AuthScheme::Basic(username, password) => { + api::v1::auth_header::AuthScheme::Basic(Basic { + username, + password: password.expose_secret().to_string(), + }) + } + } + } +} + type Credential<'a> = &'a str; fn auth_header(req: &Request) -> Result { diff --git a/src/servers/src/interceptor.rs b/src/servers/src/interceptor.rs index c366fd2efd..48f44409fd 100644 --- a/src/servers/src/interceptor.rs +++ b/src/servers/src/interceptor.rs @@ -26,6 +26,7 @@ use log_query::LogQuery; use query::parser::PromQuery; use session::context::QueryContextRef; use sql::statements::statement::Statement; +use table::TableRef; use vrl::value::Value; /// SqlQueryInterceptor can track life cycle of a sql query and customize or @@ -148,6 +149,15 @@ pub trait GrpcQueryInterceptor { Ok(()) } + /// Called before bulk insert is executed. + fn pre_bulk_insert( + &self, + _table: TableRef, + _query_ctx: QueryContextRef, + ) -> Result<(), Self::Error> { + Ok(()) + } + /// Called after execution finished. The implementation can modify the /// output if needed. fn post_execute( @@ -180,6 +190,18 @@ where } } + fn pre_bulk_insert( + &self, + _table: TableRef, + _query_ctx: QueryContextRef, + ) -> Result<(), Self::Error> { + if let Some(this) = self { + this.pre_bulk_insert(_table, _query_ctx) + } else { + Ok(()) + } + } + fn post_execute( &self, output: Output, diff --git a/src/servers/src/otel_arrow.rs b/src/servers/src/otel_arrow.rs index 808bef5c02..edf2e8e89b 100644 --- a/src/servers/src/otel_arrow.rs +++ b/src/servers/src/otel_arrow.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +use auth::UserProviderRef; use common_error::ext::ErrorExt; use common_error::status_code::status_to_tonic_code; use common_telemetry::error; @@ -19,19 +20,25 @@ use futures::SinkExt; use otel_arrow_rust::proto::opentelemetry::arrow::v1::arrow_metrics_service_server::ArrowMetricsService; use otel_arrow_rust::proto::opentelemetry::arrow::v1::{BatchArrowRecords, BatchStatus}; use otel_arrow_rust::Consumer; -use session::context::QueryContext; use tonic::metadata::{Entry, MetadataValue}; use tonic::service::Interceptor; use tonic::{Request, Response, Status, Streaming}; use crate::error; +use crate::grpc::context_auth; use crate::query_handler::OpenTelemetryProtocolHandlerRef; -pub struct OtelArrowServiceHandler(pub T); +pub struct OtelArrowServiceHandler { + handler: T, + user_provider: Option, +} impl OtelArrowServiceHandler { - pub fn new(handler: T) -> Self { - Self(handler) + pub fn new(handler: T, user_provider: Option) -> Self { + Self { + handler, + user_provider, + } } } @@ -43,9 +50,14 @@ impl ArrowMetricsService for OtelArrowServiceHandler>, ) -> Result, Status> { let (mut sender, receiver) = futures::channel::mpsc::channel(100); - let mut incoming_requests = request.into_inner(); - let handler = self.0.clone(); - let query_context = QueryContext::arc(); + + let (headers, _, mut incoming_requests) = request.into_parts(); + + let query_ctx = context_auth::create_query_context_from_grpc_metadata(&headers)?; + context_auth::check_auth(self.user_provider.clone(), &headers, query_ctx.clone()).await?; + + let handler = self.handler.clone(); + // handles incoming requests common_runtime::spawn_global(async move { let mut consumer = Consumer::default(); @@ -87,7 +99,7 @@ impl ArrowMetricsService for OtelArrowServiceHandler, decoder: &mut FlightDecoder, flight_data: FlightData, + ctx: QueryContextRef, ) -> std::result::Result; } @@ -81,9 +82,10 @@ where table_ref: &mut Option, decoder: &mut FlightDecoder, data: FlightData, + ctx: QueryContextRef, ) -> Result { self.0 - .put_record_batch(table_name, table_ref, decoder, data) + .put_record_batch(table_name, table_ref, decoder, data, ctx) .await .map_err(BoxedError::new) .context(error::ExecuteGrpcRequestSnafu) diff --git a/src/servers/tests/mod.rs b/src/servers/tests/mod.rs index 461c448a84..e38ac5abc9 100644 --- a/src/servers/tests/mod.rs +++ b/src/servers/tests/mod.rs @@ -168,6 +168,7 @@ impl GrpcQueryHandler for DummyInstance { _table_ref: &mut Option, _decoder: &mut FlightDecoder, _data: FlightData, + _ctx: QueryContextRef, ) -> std::result::Result { unimplemented!() } diff --git a/tests-integration/src/test_util.rs b/tests-integration/src/test_util.rs index 9dbf0e92f8..9396205737 100644 --- a/tests-integration/src/test_util.rs +++ b/tests-integration/src/test_util.rs @@ -44,6 +44,7 @@ use servers::grpc::{FlightCompression, GrpcOptions, GrpcServer, GrpcServerConfig use servers::http::{HttpOptions, HttpServerBuilder, PromValidationMode}; use servers::metrics_handler::MetricsHandler; use servers::mysql::server::{MysqlServer, MysqlSpawnConfig, MysqlSpawnRef}; +use servers::otel_arrow::OtelArrowServiceHandler; use servers::postgres::PostgresServer; use servers::query_handler::grpc::ServerGrpcQueryHandlerAdapter; use servers::query_handler::sql::{ServerSqlQueryHandlerAdapter, SqlQueryHandler}; @@ -592,7 +593,8 @@ pub async fn setup_grpc_server_with( let grpc_builder = GrpcServerBuilder::new(grpc_config.clone(), runtime) .database_handler(greptime_request_handler) .flight_handler(flight_handler) - .prometheus_handler(fe_instance_ref.clone(), user_provider) + .prometheus_handler(fe_instance_ref.clone(), user_provider.clone()) + .otel_arrow_handler(OtelArrowServiceHandler::new(fe_instance_ref, user_provider)) .with_tls_config(grpc_config.tls) .unwrap(); diff --git a/tests-integration/tests/grpc.rs b/tests-integration/tests/grpc.rs index 95cbeb70dd..7f03e84cf5 100644 --- a/tests-integration/tests/grpc.rs +++ b/tests-integration/tests/grpc.rs @@ -28,6 +28,8 @@ use common_recordbatch::RecordBatches; use common_runtime::runtime::{BuilderBuild, RuntimeTrait}; use common_runtime::Runtime; use common_test_util::find_workspace_path; +use otel_arrow_rust::proto::opentelemetry::arrow::v1::arrow_metrics_service_client::ArrowMetricsServiceClient; +use otel_arrow_rust::proto::opentelemetry::arrow::v1::BatchArrowRecords; use servers::grpc::builder::GrpcServerBuilder; use servers::grpc::GrpcServerConfig; use servers::http::prometheus::{ @@ -39,6 +41,8 @@ use servers::tls::{TlsMode, TlsOption}; use tests_integration::test_util::{ setup_grpc_server, setup_grpc_server_with, setup_grpc_server_with_user_provider, StorageType, }; +use tonic::metadata::MetadataValue; +use tonic::Request; #[macro_export] macro_rules! grpc_test { @@ -73,6 +77,7 @@ macro_rules! grpc_tests { test_invalid_dbname, test_auto_create_table, test_auto_create_table_with_hints, + test_otel_arrow_auth, test_insert_and_select, test_dbname, test_grpc_message_size_ok, @@ -276,6 +281,84 @@ pub async fn test_grpc_auth(store_type: StorageType) { let _ = fe_grpc_server.shutdown().await; } +pub async fn test_otel_arrow_auth(store_type: StorageType) { + let user_provider = user_provider_from_option( + &"static_user_provider:cmd:greptime_user=greptime_pwd".to_string(), + ) + .unwrap(); + let (_db, fe_grpc_server) = setup_grpc_server_with_user_provider( + store_type, + "test_otel_arrow_auth", + Some(user_provider), + ) + .await; + let addr = fe_grpc_server.bind_addr().unwrap().to_string(); + + let mut client = ArrowMetricsServiceClient::connect(format!("http://{}", addr)) + .await + .unwrap(); + + let batch_arrow_records = BatchArrowRecords { + batch_id: 1, + arrow_payloads: vec![], + headers: vec![], + }; + + // test without auth + { + let records = batch_arrow_records.clone(); + let stream = futures::stream::once(async { records }); + let request = Request::new(stream); + let response = client.arrow_metrics(request).await; + assert!(response.is_err()); + let error = response.unwrap_err(); + assert_eq!(error.code(), tonic::Code::Unauthenticated); + } + // test auth + { + let records = batch_arrow_records.clone(); + let stream = futures::stream::once(async { records }); + let mut request = Request::new(stream); + request.metadata_mut().insert( + "authorization", + MetadataValue::from_static("Basic Z3JlcHRpbWVfdXNlcjpncmVwdGltZV9wd2Q="), // greptime_user:greptime_pwd base64 encoded + ); + let response = client.arrow_metrics(request).await; + assert!(response.is_ok()); + + let mut response_stream = response.unwrap().into_inner(); + let resp = response_stream.message().await; + assert!(resp.is_err()); + let error = resp.unwrap_err(); + assert_eq!( + error.message(), + "Failed to handle otel-arrow request, error message: Batch is empty" + ); + } + // test old auth + { + let stream = futures::stream::once(async { batch_arrow_records }); + let mut request = Request::new(stream); + request.metadata_mut().insert( + "authorization", + MetadataValue::from_static("Z3JlcHRpbWVfdXNlcjpncmVwdGltZV9wd2Q="), // greptime_user:greptime_pwd base64 encoded + ); + let response = client.arrow_metrics(request).await; + assert!(response.is_ok()); + + let mut response_stream = response.unwrap().into_inner(); + let resp = response_stream.message().await; + assert!(resp.is_err()); + let error = resp.unwrap_err(); + assert_eq!( + error.message(), + "Failed to handle otel-arrow request, error message: Batch is empty" + ); + } + + let _ = fe_grpc_server.shutdown().await; +} + pub async fn test_auto_create_table(store_type: StorageType) { let (_db, fe_grpc_server) = setup_grpc_server(store_type, "test_auto_create_table").await; let addr = fe_grpc_server.bind_addr().unwrap().to_string();