fix: prevent spoof

Signed-off-by: discord9 <discord9@163.com>
This commit is contained in:
discord9
2026-04-21 14:42:11 +08:00
parent 7450e70505
commit 6d6d2802b7
3 changed files with 71 additions and 3 deletions

View File

@@ -35,7 +35,7 @@ use common_time::timezone::parse_timezone;
use futures_util::StreamExt;
use session::context::{
Channel, QueryContextBuilder, QueryContextRef, REMOTE_QUERY_ID_EXTENSION_KEY,
generate_remote_query_id,
generate_remote_query_id, is_reserved_extension_key,
};
use session::hints::READ_PREFERENCE_HINT;
use snafu::{OptionExt, ResultExt};
@@ -238,6 +238,9 @@ pub(crate) fn create_query_context(
}
for (key, value) in extensions {
if is_reserved_extension_key(&key) {
continue;
}
ctx_builder = ctx_builder.set_extension(key, value);
}
Ok(ctx_builder.build().into())
@@ -327,4 +330,23 @@ mod tests {
Some(extensions[1].1.as_str())
);
}
#[test]
fn test_create_query_context_ignores_remote_query_id_extension() {
let query_context = create_query_context(
Channel::Grpc,
None,
vec![(
REMOTE_QUERY_ID_EXTENSION_KEY.to_string(),
"spoofed-query-id".to_string(),
)],
)
.unwrap();
assert_ne!(query_context.remote_query_id(), Some("spoofed-query-id"));
assert_eq!(
query_context.extension(REMOTE_QUERY_ID_EXTENSION_KEY),
query_context.remote_query_id()
);
}
}

View File

@@ -16,7 +16,7 @@ use axum::body::Body;
use axum::http::Request;
use axum::middleware::Next;
use axum::response::Response;
use session::context::QueryContext;
use session::context::{QueryContext, REMOTE_QUERY_ID_EXTENSION_KEY, is_reserved_extension_key};
use crate::hint_headers;
@@ -24,8 +24,50 @@ pub async fn extract_hints(mut request: Request<Body>, next: Next) -> Response {
let hints = hint_headers::extract_hints(request.headers());
if let Some(query_ctx) = request.extensions_mut().get_mut::<QueryContext>() {
for (key, value) in hints {
query_ctx.set_extension(key, value);
apply_hint(query_ctx, key, value);
}
}
next.run(request).await
}
fn apply_hint(query_ctx: &mut QueryContext, key: String, value: String) {
if is_reserved_extension_key(&key) {
return;
}
query_ctx.set_extension(key, value);
}
#[cfg(test)]
mod tests {
use session::context::{QueryContextBuilder, generate_remote_query_id};
use super::*;
#[test]
fn test_apply_hint_ignores_remote_query_id() {
let expected_remote_query_id = generate_remote_query_id();
let mut query_ctx = QueryContextBuilder::default()
.set_extension(
REMOTE_QUERY_ID_EXTENSION_KEY.to_string(),
expected_remote_query_id.clone(),
)
.build();
apply_hint(
&mut query_ctx,
REMOTE_QUERY_ID_EXTENSION_KEY.to_string(),
"spoofed-query-id".to_string(),
);
apply_hint(
&mut query_ctx,
"auto_create_table".to_string(),
"true".to_string(),
);
assert_eq!(
query_ctx.remote_query_id(),
Some(expected_remote_query_id.as_str())
);
assert_eq!(query_ctx.extension("auto_create_table"), Some("true"));
}
}

View File

@@ -43,6 +43,10 @@ pub type ConnInfoRef = Arc<ConnInfo>;
const CURSOR_COUNT_WARNING_LIMIT: usize = 10;
pub const REMOTE_QUERY_ID_EXTENSION_KEY: &str = "remote_query_id";
pub fn is_reserved_extension_key(key: &str) -> bool {
key == REMOTE_QUERY_ID_EXTENSION_KEY
}
pub fn generate_remote_query_id() -> String {
generate_remote_query_id_value().to_string()
}