fix: gRPC auth (#6827)

* fix: internal service

Signed-off-by: shuiyisong <xixing.sys@gmail.com>

* refactor: gRPC auth

Signed-off-by: shuiyisong <xixing.sys@gmail.com>

* chore: add permission check for bulk ingest

Signed-off-by: shuiyisong <xixing.sys@gmail.com>

* chore: remove unused grpc auth middleware

Signed-off-by: shuiyisong <xixing.sys@gmail.com>

* chore: extract header function

Signed-off-by: shuiyisong <xixing.sys@gmail.com>

* refactor: extract common code and add auth to otel arrow api

Signed-off-by: shuiyisong <xixing.sys@gmail.com>

* chore: rename utils to context_auth

Signed-off-by: shuiyisong <xixing.sys@gmail.com>

* test: otel arrow auth

Signed-off-by: shuiyisong <xixing.sys@gmail.com>

* chore: add support for old auth value

Signed-off-by: shuiyisong <xixing.sys@gmail.com>

---------

Signed-off-by: shuiyisong <xixing.sys@gmail.com>
This commit is contained in:
shuiyisong
2025-08-28 12:00:45 +08:00
committed by GitHub
parent 32e73dad12
commit ec817f6877
17 changed files with 357 additions and 364 deletions

View File

@@ -32,6 +32,7 @@ pub enum PermissionReq<'a> {
PromStoreRead,
Otlp,
LogWrite,
BulkInsert,
}
#[derive(Debug)]

View File

@@ -473,8 +473,8 @@ impl Database {
}) = &self.ctx.auth_header
{
let encoded = BASE64_STANDARD.encode(format!("{username}:{password}"));
let value =
MetadataValue::from_str(&encoded).context(InvalidTonicMetadataValueSnafu)?;
let value = MetadataValue::from_str(&format!("Basic {encoded}"))
.context(InvalidTonicMetadataValueSnafu)?;
request.metadata_mut().insert("x-greptime-auth", value);
}

View File

@@ -247,6 +247,7 @@ impl GrpcQueryHandler for Instance {
table_ref: &mut Option<TableRef>,
decoder: &mut FlightDecoder,
data: FlightData,
ctx: QueryContextRef,
) -> Result<AffectedRows> {
let table = if let Some(table) = table_ref {
table.clone()
@@ -268,6 +269,18 @@ impl GrpcQueryHandler for Instance {
table
};
let interceptor_ref = self.plugins.get::<GrpcQueryInterceptorRef<Error>>();
let interceptor = interceptor_ref.as_ref();
interceptor.pre_bulk_insert(table.clone(), ctx.clone())?;
self.plugins
.get::<PermissionCheckerRef>()
.as_ref()
.check_permission(ctx.current_user(), PermissionReq::BulkInsert)
.context(PermissionSnafu)?;
// do we check limit for bulk insert?
self.inserter
.handle_bulk_insert(table, decoder, data)
.await

View File

@@ -170,10 +170,13 @@ where
.name(name)
.database_handler(greptime_request_handler.clone())
.prometheus_handler(self.instance.clone(), user_provider.clone())
.otel_arrow_handler(OtelArrowServiceHandler::new(self.instance.clone()))
.otel_arrow_handler(OtelArrowServiceHandler::new(
self.instance.clone(),
user_provider.clone(),
))
.flight_handler(Arc::new(greptime_request_handler));
let grpc_server = if external {
let grpc_server = if !external {
let frontend_grpc_handler =
FrontendGrpcHandler::new(self.instance.process_manager().clone());
grpc_server.frontend_grpc_handler(frontend_grpc_handler)

View File

@@ -12,9 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
mod authorize;
pub mod builder;
mod cancellation;
pub mod context_auth;
mod database;
pub mod flight;
pub mod frontend_grpc_handler;

View File

@@ -1,199 +0,0 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::pin::Pin;
use std::result::Result as StdResult;
use std::task::{Context, Poll};
use auth::UserProviderRef;
use session::context::{Channel, QueryContext};
use tonic::body::Body;
use tonic::server::NamedService;
use tower::{Layer, Service};
use crate::http::authorize::{extract_catalog_and_schema, extract_username_and_password};
#[derive(Clone)]
pub struct AuthMiddlewareLayer {
user_provider: Option<UserProviderRef>,
}
impl<S> Layer<S> for AuthMiddlewareLayer {
type Service = AuthMiddleware<S>;
fn layer(&self, service: S) -> Self::Service {
AuthMiddleware {
inner: service,
user_provider: self.user_provider.clone(),
}
}
}
/// This middleware is responsible for authenticating the user and setting the user
/// info in the request extension.
///
/// Detail: Authorization information is passed in through the Authorization request
/// header.
#[derive(Clone)]
pub struct AuthMiddleware<S> {
inner: S,
user_provider: Option<UserProviderRef>,
}
impl<S> NamedService for AuthMiddleware<S>
where
S: NamedService,
{
const NAME: &'static str = S::NAME;
}
type BoxFuture<'a, T> = Pin<Box<dyn std::future::Future<Output = T> + Send + 'a>>;
impl<S> Service<http::Request<Body>> for AuthMiddleware<S>
where
S: Service<http::Request<Body>, Response = http::Response<Body>> + Clone + Send + 'static,
S::Future: Send + 'static,
{
type Response = S::Response;
type Error = S::Error;
type Future = BoxFuture<'static, StdResult<Self::Response, Self::Error>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<StdResult<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, mut req: http::Request<Body>) -> Self::Future {
// This is necessary because tonic internally uses `tower::buffer::Buffer`.
// See https://github.com/tower-rs/tower/issues/547#issuecomment-767629149
// for details on why this is necessary.
let clone = self.inner.clone();
let mut inner = std::mem::replace(&mut self.inner, clone);
let user_provider = self.user_provider.clone();
Box::pin(async move {
if let Err(status) = do_auth(&mut req, user_provider).await {
return Ok(status.into_http());
}
inner.call(req).await
})
}
}
async fn do_auth<T>(
req: &mut http::Request<T>,
user_provider: Option<UserProviderRef>,
) -> Result<(), tonic::Status> {
let (catalog, schema) = extract_catalog_and_schema(req);
let query_ctx = QueryContext::with_channel(&catalog, &schema, Channel::Grpc);
let Some(user_provider) = user_provider else {
query_ctx.set_current_user(auth::userinfo_by_name(None));
let _ = req.extensions_mut().insert(query_ctx);
return Ok(());
};
let (username, password) = extract_username_and_password(req)
.map_err(|e| tonic::Status::invalid_argument(e.to_string()))?;
let id = auth::Identity::UserId(&username, None);
let pwd = auth::Password::PlainText(password);
let user_info = user_provider
.auth(id, pwd, &catalog, &schema)
.await
.map_err(|e| tonic::Status::unauthenticated(e.to_string()))?;
query_ctx.set_current_user(user_info);
let _ = req.extensions_mut().insert(query_ctx);
Ok(())
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use auth::tests::MockUserProvider;
use base64::engine::general_purpose::STANDARD;
use base64::Engine;
use headers::Header;
use hyper::Request;
use session::context::QueryContext;
use crate::grpc::authorize::do_auth;
use crate::http::header::GreptimeDbName;
#[tokio::test]
async fn test_do_auth_with_user_provider() {
let user_provider = Arc::new(MockUserProvider::default());
// auth success
let authorization_val = format!("Basic {}", STANDARD.encode("greptime:greptime"));
let mut req = Request::new(());
req.headers_mut()
.insert("authorization", authorization_val.parse().unwrap());
let auth_result = do_auth(&mut req, Some(user_provider.clone())).await;
assert!(auth_result.is_ok());
check_req(&req, "greptime", "public", "greptime");
// auth failed, err: user not exist.
let authorization_val = format!("Basic {}", STANDARD.encode("greptime2:greptime2"));
let mut req = Request::new(());
req.headers_mut()
.insert("authorization", authorization_val.parse().unwrap());
let auth_result = do_auth(&mut req, Some(user_provider)).await;
assert!(auth_result.is_err());
}
#[tokio::test]
async fn test_do_auth_without_user_provider() {
let mut req = Request::new(());
req.headers_mut()
.insert("authentication", "pwd".parse().unwrap());
let auth_result = do_auth(&mut req, None).await;
assert!(auth_result.is_ok());
check_req(&req, "greptime", "public", "greptime");
let mut req = Request::new(());
let auth_result = do_auth(&mut req, None).await;
assert!(auth_result.is_ok());
check_req(&req, "greptime", "public", "greptime");
let mut req = Request::new(());
req.headers_mut()
.insert(GreptimeDbName::name(), "catalog-schema".parse().unwrap());
let auth_result = do_auth(&mut req, None).await;
assert!(auth_result.is_ok());
check_req(&req, "catalog", "schema", "greptime");
}
fn check_req<T>(
req: &Request<T>,
expected_catalog: &str,
expected_schema: &str,
expected_user_name: &str,
) {
let ctx = req.extensions().get::<QueryContext>().unwrap();
assert_eq!(expected_catalog, ctx.current_catalog());
assert_eq!(expected_schema, ctx.current_schema());
let user_info = ctx.current_user();
assert_eq!(expected_user_name, user_info.username());
}
}

View File

@@ -0,0 +1,163 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
use api::v1::auth_header::AuthScheme;
use api::v1::{AuthHeader, RequestHeader};
use auth::{Identity, Password, UserInfoRef, UserProviderRef};
use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
use common_catalog::parse_catalog_and_schema_from_db_string;
use common_error::ext::ErrorExt;
use session::context::{Channel, QueryContextBuilder, QueryContextRef};
use snafu::{OptionExt, ResultExt};
use tonic::metadata::MetadataMap;
use tonic::Status;
use crate::error::Error::UnsupportedAuthScheme;
use crate::error::{AuthSnafu, InvalidParameterSnafu, NotFoundAuthHeaderSnafu, Result};
use crate::grpc::TonicResult;
use crate::http::header::constants::GREPTIME_DB_HEADER_NAME;
use crate::http::AUTHORIZATION_HEADER;
use crate::metrics::METRIC_AUTH_FAILURE;
/// Create a query context from the grpc metadata.
pub fn create_query_context_from_grpc_metadata(
headers: &MetadataMap,
) -> TonicResult<QueryContextRef> {
let (catalog, schema) = if let Some(db) = extract_header(headers, &[GREPTIME_DB_HEADER_NAME])? {
parse_catalog_and_schema_from_db_string(db)
} else {
(
DEFAULT_CATALOG_NAME.to_string(),
DEFAULT_SCHEMA_NAME.to_string(),
)
};
Ok(Arc::new(
QueryContextBuilder::default()
.current_catalog(catalog)
.current_schema(schema)
.channel(Channel::Grpc)
.build(),
))
}
/// Helper function to extract a header from the metadata map.
/// Can be multiple keys, and the first one found will be returned.
pub fn extract_header<'a>(headers: &'a MetadataMap, keys: &[&str]) -> TonicResult<Option<&'a str>> {
let mut value = None;
for key in keys {
if let Some(v) = headers.get(*key) {
value = Some(v);
break;
}
}
let Some(v) = value else {
return Ok(None);
};
let Ok(v) = std::str::from_utf8(v.as_bytes()) else {
return Err(InvalidParameterSnafu {
reason: "expect valid UTF-8 value",
}
.build()
.into());
};
Ok(Some(v))
}
/// Helper function to extract the header from the metadata and authenticate the user.
pub async fn check_auth(
user_provider: Option<UserProviderRef>,
headers: &MetadataMap,
query_ctx: QueryContextRef,
) -> TonicResult<bool> {
if user_provider.is_none() {
return Ok(true);
}
let auth_schema = extract_header(
headers,
&[AUTHORIZATION_HEADER, http::header::AUTHORIZATION.as_str()],
)?
.map(|x| {
if x.len() > 5 && x[0..5].eq_ignore_ascii_case("Basic") {
x.try_into()
} else {
// compatible with old version
format!("Basic {}", x).as_str().try_into()
}
})
.transpose()?
.map(|x: crate::http::authorize::AuthScheme| x.into());
let auth_schema = auth_schema.context(NotFoundAuthHeaderSnafu)?;
let header = RequestHeader {
authorization: Some(AuthHeader {
auth_scheme: Some(auth_schema),
}),
catalog: query_ctx.current_catalog().to_string(),
schema: query_ctx.current_schema(),
..Default::default()
};
match auth(user_provider, Some(&header), &query_ctx).await {
Ok(user_info) => {
query_ctx.set_current_user(user_info);
Ok(true)
}
Err(_) => Err(Status::unauthenticated("auth failed")),
}
}
/// Authenticate the user based on the header and query context.
pub async fn auth(
user_provider: Option<UserProviderRef>,
header: Option<&RequestHeader>,
query_ctx: &QueryContextRef,
) -> Result<UserInfoRef> {
let Some(user_provider) = user_provider else {
return Ok(auth::userinfo_by_name(None));
};
let auth_scheme = header
.and_then(|header| {
header
.authorization
.as_ref()
.and_then(|x| x.auth_scheme.clone())
})
.context(NotFoundAuthHeaderSnafu)?;
match auth_scheme {
AuthScheme::Basic(api::v1::Basic { username, password }) => user_provider
.auth(
Identity::UserId(&username, None),
Password::PlainText(password.into()),
query_ctx.current_catalog(),
&query_ctx.current_schema(),
)
.await
.context(AuthSnafu),
AuthScheme::Token(_) => Err(UnsupportedAuthScheme {
name: "Token AuthScheme".to_string(),
}),
}
.inspect_err(|e| {
METRIC_AUTH_FAILURE
.with_label_values(&[e.status_code().as_ref()])
.inc();
})
}

View File

@@ -26,8 +26,6 @@ use arrow_flight::{
};
use async_trait::async_trait;
use bytes::Bytes;
use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
use common_catalog::parse_catalog_and_schema_from_db_string;
use common_grpc::flight::do_put::{DoPutMetadata, DoPutResponse};
use common_grpc::flight::{FlightEncoder, FlightMessage};
use common_query::{Output, OutputData};
@@ -46,9 +44,7 @@ use tonic::{Request, Response, Status, Streaming};
use crate::error::{InvalidParameterSnafu, ParseJsonSnafu, Result, ToJsonSnafu};
pub use crate::grpc::flight::stream::FlightRecordBatchStream;
use crate::grpc::greptime_handler::{get_request_type, GreptimeRequestHandler};
use crate::grpc::{FlightCompression, TonicResult};
use crate::http::header::constants::GREPTIME_DB_HEADER_NAME;
use crate::http::AUTHORIZATION_HEADER;
use crate::grpc::{context_auth, FlightCompression, TonicResult};
use crate::{error, hint_headers};
pub type TonicStream<T> = Pin<Box<dyn Stream<Item = TonicResult<T>> + Send + 'static>>;
@@ -189,7 +185,6 @@ impl FlightCraft for GreptimeRequestHandler {
let ticket = request.into_inner().ticket;
let request =
GreptimeRequest::decode(ticket.as_ref()).context(error::InvalidFlightTicketSnafu)?;
let query_ctx = QueryContext::arc();
// The Grpc protocol pass query by Flight. It needs to be wrapped under a span, in order to record stream
let span = info_span!(
@@ -204,7 +199,7 @@ impl FlightCraft for GreptimeRequestHandler {
output,
TracingContext::from_current_span(),
flight_compression,
query_ctx,
QueryContext::arc(),
);
Ok(Response::new(stream))
}
@@ -218,34 +213,20 @@ impl FlightCraft for GreptimeRequestHandler {
) -> TonicResult<Response<TonicStream<PutResult>>> {
let (headers, _, stream) = request.into_parts();
let header = |key: &str| -> TonicResult<Option<&str>> {
let Some(v) = headers.get(key) else {
return Ok(None);
};
let Ok(v) = std::str::from_utf8(v.as_bytes()) else {
return Err(InvalidParameterSnafu {
reason: "expect valid UTF-8 value",
}
.build()
.into());
};
Ok(Some(v))
};
let username_and_password = header(AUTHORIZATION_HEADER)?;
let db = header(GREPTIME_DB_HEADER_NAME)?;
if !self.validate_auth(username_and_password, db).await? {
return Err(Status::unauthenticated("auth failed"));
}
let query_ctx = context_auth::create_query_context_from_grpc_metadata(&headers)?;
context_auth::check_auth(self.user_provider.clone(), &headers, query_ctx.clone()).await?;
const MAX_PENDING_RESPONSES: usize = 32;
let (tx, rx) = mpsc::channel::<TonicResult<DoPutResponse>>(MAX_PENDING_RESPONSES);
let stream = PutRecordBatchRequestStream {
flight_data_stream: stream,
state: PutRecordBatchRequestStreamState::Init(db.map(ToString::to_string)),
state: PutRecordBatchRequestStreamState::Init(
query_ctx.current_catalog().to_string(),
query_ctx.current_schema(),
),
};
self.put_record_batches(stream, tx).await;
self.put_record_batches(stream, tx, query_ctx).await;
let response = ReceiverStream::new(rx)
.and_then(|response| {
@@ -292,7 +273,7 @@ pub(crate) struct PutRecordBatchRequestStream {
}
enum PutRecordBatchRequestStreamState {
Init(Option<String>),
Init(String, String),
Started(TableName),
}
@@ -319,21 +300,12 @@ impl Stream for PutRecordBatchRequestStream {
let poll = ready!(self.flight_data_stream.poll_next_unpin(cx));
let result = match &mut self.state {
PutRecordBatchRequestStreamState::Init(db) => match poll {
PutRecordBatchRequestStreamState::Init(catalog, schema) => match poll {
Some(Ok(mut flight_data)) => {
let flight_descriptor = flight_data.flight_descriptor.take();
let result = if let Some(descriptor) = flight_descriptor {
let table_name = extract_table_name(descriptor).map(|x| {
let (catalog, schema) = if let Some(db) = db {
parse_catalog_and_schema_from_db_string(db)
} else {
(
DEFAULT_CATALOG_NAME.to_string(),
DEFAULT_SCHEMA_NAME.to_string(),
)
};
TableName::new(catalog, schema, x)
});
let table_name = extract_table_name(descriptor)
.map(|x| TableName::new(catalog.clone(), schema.clone(), x));
let table_name = match table_name {
Ok(table_name) => table_name,
Err(e) => return Poll::Ready(Some(Err(e.into()))),

View File

@@ -18,11 +18,8 @@ use std::str::FromStr;
use std::time::Instant;
use api::helper::request_type;
use api::v1::auth_header::AuthScheme;
use api::v1::{AuthHeader, Basic, GreptimeRequest, RequestHeader};
use auth::{Identity, Password, UserInfoRef, UserProviderRef};
use base64::prelude::BASE64_STANDARD;
use base64::Engine;
use api::v1::{GreptimeRequest, RequestHeader};
use auth::UserProviderRef;
use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
use common_catalog::parse_catalog_and_schema_from_db_string;
use common_error::ext::ErrorExt;
@@ -37,28 +34,24 @@ use common_telemetry::tracing_context::{FutureExt, TracingContext};
use common_telemetry::{debug, error, tracing, warn};
use common_time::timezone::parse_timezone;
use futures_util::StreamExt;
use session::context::{Channel, QueryContext, QueryContextBuilder, QueryContextRef};
use session::context::{Channel, QueryContextBuilder, QueryContextRef};
use session::hints::READ_PREFERENCE_HINT;
use snafu::{OptionExt, ResultExt};
use table::TableRef;
use tokio::sync::mpsc;
use tokio::sync::mpsc::error::TrySendError;
use crate::error::Error::UnsupportedAuthScheme;
use crate::error::{
AuthSnafu, InvalidAuthHeaderInvalidUtf8ValueSnafu, InvalidBase64ValueSnafu, InvalidQuerySnafu,
JoinTaskSnafu, NotFoundAuthHeaderSnafu, Result, UnknownHintSnafu,
};
use crate::error::{InvalidQuerySnafu, JoinTaskSnafu, Result, UnknownHintSnafu};
use crate::grpc::flight::{PutRecordBatchRequest, PutRecordBatchRequestStream};
use crate::grpc::{FlightCompression, TonicResult};
use crate::grpc::{context_auth, FlightCompression, TonicResult};
use crate::metrics;
use crate::metrics::{METRIC_AUTH_FAILURE, METRIC_SERVER_GRPC_DB_REQUEST_TIMER};
use crate::metrics::METRIC_SERVER_GRPC_DB_REQUEST_TIMER;
use crate::query_handler::grpc::ServerGrpcQueryHandlerRef;
#[derive(Clone)]
pub struct GreptimeRequestHandler {
handler: ServerGrpcQueryHandlerRef,
user_provider: Option<UserProviderRef>,
pub(crate) user_provider: Option<UserProviderRef>,
runtime: Option<Runtime>,
pub(crate) flight_compression: FlightCompression,
}
@@ -90,7 +83,7 @@ impl GreptimeRequestHandler {
let header = request.header.as_ref();
let query_ctx = create_query_context(Channel::Grpc, header, hints)?;
let user_info = auth(self.user_provider.clone(), header, &query_ctx).await?;
let user_info = context_auth::auth(self.user_provider.clone(), header, &query_ctx).await?;
query_ctx.set_current_user(user_info);
let handler = self.handler.clone();
@@ -143,6 +136,7 @@ impl GreptimeRequestHandler {
&self,
mut stream: PutRecordBatchRequestStream,
result_sender: mpsc::Sender<TonicResult<DoPutResponse>>,
query_ctx: QueryContextRef,
) {
let handler = self.handler.clone();
let runtime = self
@@ -170,7 +164,7 @@ impl GreptimeRequestHandler {
let timer = metrics::GRPC_BULK_INSERT_ELAPSED.start_timer();
let result = handler
.put_record_batch(&table_name, &mut table_ref, &mut decoder, data)
.put_record_batch(&table_name, &mut table_ref, &mut decoder, data, query_ctx.clone())
.await
.inspect_err(|e| error!(e; "Failed to handle flight record batches"));
timer.observe_duration();
@@ -185,58 +179,6 @@ impl GreptimeRequestHandler {
}
});
}
pub(crate) async fn validate_auth(
&self,
username_and_password: Option<&str>,
db: Option<&str>,
) -> Result<bool> {
if self.user_provider.is_none() {
return Ok(true);
}
let username_and_password = username_and_password.context(NotFoundAuthHeaderSnafu)?;
let username_and_password = BASE64_STANDARD
.decode(username_and_password)
.context(InvalidBase64ValueSnafu)
.and_then(|x| String::from_utf8(x).context(InvalidAuthHeaderInvalidUtf8ValueSnafu))?;
let mut split = username_and_password.splitn(2, ':');
let (username, password) = match (split.next(), split.next()) {
(Some(username), Some(password)) => (username, password),
(Some(username), None) => (username, ""),
(None, None) => return Ok(false),
_ => unreachable!(), // because this iterator won't yield Some after None
};
let (catalog, schema) = if let Some(db) = db {
parse_catalog_and_schema_from_db_string(db)
} else {
(
DEFAULT_CATALOG_NAME.to_string(),
DEFAULT_SCHEMA_NAME.to_string(),
)
};
let header = RequestHeader {
authorization: Some(AuthHeader {
auth_scheme: Some(AuthScheme::Basic(Basic {
username: username.to_string(),
password: password.to_string(),
})),
}),
catalog,
schema,
..Default::default()
};
Ok(auth(
self.user_provider.clone(),
Some(&header),
&QueryContext::arc(),
)
.await
.is_ok())
}
}
pub fn get_request_type(request: &GreptimeRequest) -> &'static str {
@@ -247,45 +189,6 @@ pub fn get_request_type(request: &GreptimeRequest) -> &'static str {
.unwrap_or_default()
}
pub(crate) async fn auth(
user_provider: Option<UserProviderRef>,
header: Option<&RequestHeader>,
query_ctx: &QueryContextRef,
) -> Result<UserInfoRef> {
let Some(user_provider) = user_provider else {
return Ok(auth::userinfo_by_name(None));
};
let auth_scheme = header
.and_then(|header| {
header
.authorization
.as_ref()
.and_then(|x| x.auth_scheme.clone())
})
.context(NotFoundAuthHeaderSnafu)?;
match auth_scheme {
AuthScheme::Basic(Basic { username, password }) => user_provider
.auth(
Identity::UserId(&username, None),
Password::PlainText(password.into()),
query_ctx.current_catalog(),
&query_ctx.current_schema(),
)
.await
.context(AuthSnafu),
AuthScheme::Token(_) => Err(UnsupportedAuthScheme {
name: "Token AuthScheme".to_string(),
}),
}
.inspect_err(|e| {
METRIC_AUTH_FAILURE
.with_label_values(&[e.status_code().as_ref()])
.inc();
})
}
/// Creates a new `QueryContext` from the provided request header and extensions.
/// Strongly recommend setting an appropriate channel, as this is very helpful for statistics.
pub(crate) fn create_query_context(

View File

@@ -32,7 +32,8 @@ use snafu::OptionExt;
use tonic::{Request, Response};
use crate::error::InvalidQuerySnafu;
use crate::grpc::greptime_handler::{auth, create_query_context};
use crate::grpc::context_auth::auth;
use crate::grpc::greptime_handler::create_query_context;
use crate::grpc::TonicResult;
use crate::http::prometheus::{retrieve_metric_name_and_result_type, PrometheusJsonResponse};
use crate::prometheus_handler::PrometheusHandlerRef;

View File

@@ -13,13 +13,14 @@
// limitations under the License.
use ::auth::UserProviderRef;
use api::v1::Basic;
use axum::extract::{Request, State};
use axum::http::{self, StatusCode};
use axum::middleware::Next;
use axum::response::{IntoResponse, Response};
use base64::prelude::BASE64_STANDARD;
use base64::Engine;
use common_base::secrets::SecretString;
use common_base::secrets::{ExposeSecret, SecretString};
use common_catalog::consts::DEFAULT_SCHEMA_NAME;
use common_catalog::parse_catalog_and_schema_from_db_string;
use common_error::ext::ErrorExt;
@@ -240,6 +241,19 @@ impl TryFrom<&str> for AuthScheme {
}
}
impl From<AuthScheme> for api::v1::auth_header::AuthScheme {
fn from(value: AuthScheme) -> Self {
match value {
AuthScheme::Basic(username, password) => {
api::v1::auth_header::AuthScheme::Basic(Basic {
username,
password: password.expose_secret().to_string(),
})
}
}
}
}
type Credential<'a> = &'a str;
fn auth_header<B>(req: &Request<B>) -> Result<AuthScheme> {

View File

@@ -26,6 +26,7 @@ use log_query::LogQuery;
use query::parser::PromQuery;
use session::context::QueryContextRef;
use sql::statements::statement::Statement;
use table::TableRef;
use vrl::value::Value;
/// SqlQueryInterceptor can track life cycle of a sql query and customize or
@@ -148,6 +149,15 @@ pub trait GrpcQueryInterceptor {
Ok(())
}
/// Called before bulk insert is executed.
fn pre_bulk_insert(
&self,
_table: TableRef,
_query_ctx: QueryContextRef,
) -> Result<(), Self::Error> {
Ok(())
}
/// Called after execution finished. The implementation can modify the
/// output if needed.
fn post_execute(
@@ -180,6 +190,18 @@ where
}
}
fn pre_bulk_insert(
&self,
_table: TableRef,
_query_ctx: QueryContextRef,
) -> Result<(), Self::Error> {
if let Some(this) = self {
this.pre_bulk_insert(_table, _query_ctx)
} else {
Ok(())
}
}
fn post_execute(
&self,
output: Output,

View File

@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use auth::UserProviderRef;
use common_error::ext::ErrorExt;
use common_error::status_code::status_to_tonic_code;
use common_telemetry::error;
@@ -19,19 +20,25 @@ use futures::SinkExt;
use otel_arrow_rust::proto::opentelemetry::arrow::v1::arrow_metrics_service_server::ArrowMetricsService;
use otel_arrow_rust::proto::opentelemetry::arrow::v1::{BatchArrowRecords, BatchStatus};
use otel_arrow_rust::Consumer;
use session::context::QueryContext;
use tonic::metadata::{Entry, MetadataValue};
use tonic::service::Interceptor;
use tonic::{Request, Response, Status, Streaming};
use crate::error;
use crate::grpc::context_auth;
use crate::query_handler::OpenTelemetryProtocolHandlerRef;
pub struct OtelArrowServiceHandler<T>(pub T);
pub struct OtelArrowServiceHandler<T> {
handler: T,
user_provider: Option<UserProviderRef>,
}
impl<T> OtelArrowServiceHandler<T> {
pub fn new(handler: T) -> Self {
Self(handler)
pub fn new(handler: T, user_provider: Option<UserProviderRef>) -> Self {
Self {
handler,
user_provider,
}
}
}
@@ -43,9 +50,14 @@ impl ArrowMetricsService for OtelArrowServiceHandler<OpenTelemetryProtocolHandle
request: Request<Streaming<BatchArrowRecords>>,
) -> Result<Response<Self::ArrowMetricsStream>, Status> {
let (mut sender, receiver) = futures::channel::mpsc::channel(100);
let mut incoming_requests = request.into_inner();
let handler = self.0.clone();
let query_context = QueryContext::arc();
let (headers, _, mut incoming_requests) = request.into_parts();
let query_ctx = context_auth::create_query_context_from_grpc_metadata(&headers)?;
context_auth::check_auth(self.user_provider.clone(), &headers, query_ctx.clone()).await?;
let handler = self.handler.clone();
// handles incoming requests
common_runtime::spawn_global(async move {
let mut consumer = Consumer::default();
@@ -87,7 +99,7 @@ impl ArrowMetricsService for OtelArrowServiceHandler<OpenTelemetryProtocolHandle
}
};
// use metric engine by default
if let Err(e) = handler.metrics(request, query_context.clone()).await {
if let Err(e) = handler.metrics(request, query_ctx.clone()).await {
let _ = sender
.send(Err(Status::new(
status_to_tonic_code(e.status_code()),

View File

@@ -49,6 +49,7 @@ pub trait GrpcQueryHandler {
table_ref: &mut Option<TableRef>,
decoder: &mut FlightDecoder,
flight_data: FlightData,
ctx: QueryContextRef,
) -> std::result::Result<AffectedRows, Self::Error>;
}
@@ -81,9 +82,10 @@ where
table_ref: &mut Option<TableRef>,
decoder: &mut FlightDecoder,
data: FlightData,
ctx: QueryContextRef,
) -> Result<AffectedRows> {
self.0
.put_record_batch(table_name, table_ref, decoder, data)
.put_record_batch(table_name, table_ref, decoder, data, ctx)
.await
.map_err(BoxedError::new)
.context(error::ExecuteGrpcRequestSnafu)

View File

@@ -168,6 +168,7 @@ impl GrpcQueryHandler for DummyInstance {
_table_ref: &mut Option<TableRef>,
_decoder: &mut FlightDecoder,
_data: FlightData,
_ctx: QueryContextRef,
) -> std::result::Result<AffectedRows, Self::Error> {
unimplemented!()
}

View File

@@ -44,6 +44,7 @@ use servers::grpc::{FlightCompression, GrpcOptions, GrpcServer, GrpcServerConfig
use servers::http::{HttpOptions, HttpServerBuilder, PromValidationMode};
use servers::metrics_handler::MetricsHandler;
use servers::mysql::server::{MysqlServer, MysqlSpawnConfig, MysqlSpawnRef};
use servers::otel_arrow::OtelArrowServiceHandler;
use servers::postgres::PostgresServer;
use servers::query_handler::grpc::ServerGrpcQueryHandlerAdapter;
use servers::query_handler::sql::{ServerSqlQueryHandlerAdapter, SqlQueryHandler};
@@ -592,7 +593,8 @@ pub async fn setup_grpc_server_with(
let grpc_builder = GrpcServerBuilder::new(grpc_config.clone(), runtime)
.database_handler(greptime_request_handler)
.flight_handler(flight_handler)
.prometheus_handler(fe_instance_ref.clone(), user_provider)
.prometheus_handler(fe_instance_ref.clone(), user_provider.clone())
.otel_arrow_handler(OtelArrowServiceHandler::new(fe_instance_ref, user_provider))
.with_tls_config(grpc_config.tls)
.unwrap();

View File

@@ -28,6 +28,8 @@ use common_recordbatch::RecordBatches;
use common_runtime::runtime::{BuilderBuild, RuntimeTrait};
use common_runtime::Runtime;
use common_test_util::find_workspace_path;
use otel_arrow_rust::proto::opentelemetry::arrow::v1::arrow_metrics_service_client::ArrowMetricsServiceClient;
use otel_arrow_rust::proto::opentelemetry::arrow::v1::BatchArrowRecords;
use servers::grpc::builder::GrpcServerBuilder;
use servers::grpc::GrpcServerConfig;
use servers::http::prometheus::{
@@ -39,6 +41,8 @@ use servers::tls::{TlsMode, TlsOption};
use tests_integration::test_util::{
setup_grpc_server, setup_grpc_server_with, setup_grpc_server_with_user_provider, StorageType,
};
use tonic::metadata::MetadataValue;
use tonic::Request;
#[macro_export]
macro_rules! grpc_test {
@@ -73,6 +77,7 @@ macro_rules! grpc_tests {
test_invalid_dbname,
test_auto_create_table,
test_auto_create_table_with_hints,
test_otel_arrow_auth,
test_insert_and_select,
test_dbname,
test_grpc_message_size_ok,
@@ -276,6 +281,84 @@ pub async fn test_grpc_auth(store_type: StorageType) {
let _ = fe_grpc_server.shutdown().await;
}
pub async fn test_otel_arrow_auth(store_type: StorageType) {
let user_provider = user_provider_from_option(
&"static_user_provider:cmd:greptime_user=greptime_pwd".to_string(),
)
.unwrap();
let (_db, fe_grpc_server) = setup_grpc_server_with_user_provider(
store_type,
"test_otel_arrow_auth",
Some(user_provider),
)
.await;
let addr = fe_grpc_server.bind_addr().unwrap().to_string();
let mut client = ArrowMetricsServiceClient::connect(format!("http://{}", addr))
.await
.unwrap();
let batch_arrow_records = BatchArrowRecords {
batch_id: 1,
arrow_payloads: vec![],
headers: vec![],
};
// test without auth
{
let records = batch_arrow_records.clone();
let stream = futures::stream::once(async { records });
let request = Request::new(stream);
let response = client.arrow_metrics(request).await;
assert!(response.is_err());
let error = response.unwrap_err();
assert_eq!(error.code(), tonic::Code::Unauthenticated);
}
// test auth
{
let records = batch_arrow_records.clone();
let stream = futures::stream::once(async { records });
let mut request = Request::new(stream);
request.metadata_mut().insert(
"authorization",
MetadataValue::from_static("Basic Z3JlcHRpbWVfdXNlcjpncmVwdGltZV9wd2Q="), // greptime_user:greptime_pwd base64 encoded
);
let response = client.arrow_metrics(request).await;
assert!(response.is_ok());
let mut response_stream = response.unwrap().into_inner();
let resp = response_stream.message().await;
assert!(resp.is_err());
let error = resp.unwrap_err();
assert_eq!(
error.message(),
"Failed to handle otel-arrow request, error message: Batch is empty"
);
}
// test old auth
{
let stream = futures::stream::once(async { batch_arrow_records });
let mut request = Request::new(stream);
request.metadata_mut().insert(
"authorization",
MetadataValue::from_static("Z3JlcHRpbWVfdXNlcjpncmVwdGltZV9wd2Q="), // greptime_user:greptime_pwd base64 encoded
);
let response = client.arrow_metrics(request).await;
assert!(response.is_ok());
let mut response_stream = response.unwrap().into_inner();
let resp = response_stream.message().await;
assert!(resp.is_err());
let error = resp.unwrap_err();
assert_eq!(
error.message(),
"Failed to handle otel-arrow request, error message: Batch is empty"
);
}
let _ = fe_grpc_server.shutdown().await;
}
pub async fn test_auto_create_table(store_type: StorageType) {
let (_db, fe_grpc_server) = setup_grpc_server(store_type, "test_auto_create_table").await;
let addr = fe_grpc_server.bind_addr().unwrap().to_string();