diff --git a/src/servers/src/grpc/greptime_handler.rs b/src/servers/src/grpc/greptime_handler.rs index 40b825c821..69a1d92bf1 100644 --- a/src/servers/src/grpc/greptime_handler.rs +++ b/src/servers/src/grpc/greptime_handler.rs @@ -35,7 +35,7 @@ use common_time::timezone::parse_timezone; use futures_util::StreamExt; use session::context::{ Channel, QueryContextBuilder, QueryContextRef, REMOTE_QUERY_ID_EXTENSION_KEY, - generate_remote_query_id, + generate_remote_query_id, is_reserved_extension_key, }; use session::hints::READ_PREFERENCE_HINT; use snafu::{OptionExt, ResultExt}; @@ -238,6 +238,9 @@ pub(crate) fn create_query_context( } for (key, value) in extensions { + if is_reserved_extension_key(&key) { + continue; + } ctx_builder = ctx_builder.set_extension(key, value); } Ok(ctx_builder.build().into()) @@ -327,4 +330,23 @@ mod tests { Some(extensions[1].1.as_str()) ); } + + #[test] + fn test_create_query_context_ignores_remote_query_id_extension() { + let query_context = create_query_context( + Channel::Grpc, + None, + vec![( + REMOTE_QUERY_ID_EXTENSION_KEY.to_string(), + "spoofed-query-id".to_string(), + )], + ) + .unwrap(); + + assert_ne!(query_context.remote_query_id(), Some("spoofed-query-id")); + assert_eq!( + query_context.extension(REMOTE_QUERY_ID_EXTENSION_KEY), + query_context.remote_query_id() + ); + } } diff --git a/src/servers/src/http/hints.rs b/src/servers/src/http/hints.rs index 7f98461cf6..19cd8dc514 100644 --- a/src/servers/src/http/hints.rs +++ b/src/servers/src/http/hints.rs @@ -16,7 +16,7 @@ use axum::body::Body; use axum::http::Request; use axum::middleware::Next; use axum::response::Response; -use session::context::QueryContext; +use session::context::{QueryContext, REMOTE_QUERY_ID_EXTENSION_KEY, is_reserved_extension_key}; use crate::hint_headers; @@ -24,8 +24,50 @@ pub async fn extract_hints(mut request: Request
, next: Next) -> Response { let hints = hint_headers::extract_hints(request.headers()); if let Some(query_ctx) = request.extensions_mut().get_mut::