mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2026-05-14 12:00:40 +00:00
@@ -35,7 +35,7 @@ use common_time::timezone::parse_timezone;
|
||||
use futures_util::StreamExt;
|
||||
use session::context::{
|
||||
Channel, QueryContextBuilder, QueryContextRef, REMOTE_QUERY_ID_EXTENSION_KEY,
|
||||
generate_remote_query_id,
|
||||
generate_remote_query_id, is_reserved_extension_key,
|
||||
};
|
||||
use session::hints::READ_PREFERENCE_HINT;
|
||||
use snafu::{OptionExt, ResultExt};
|
||||
@@ -238,6 +238,9 @@ pub(crate) fn create_query_context(
|
||||
}
|
||||
|
||||
for (key, value) in extensions {
|
||||
if is_reserved_extension_key(&key) {
|
||||
continue;
|
||||
}
|
||||
ctx_builder = ctx_builder.set_extension(key, value);
|
||||
}
|
||||
Ok(ctx_builder.build().into())
|
||||
@@ -327,4 +330,23 @@ mod tests {
|
||||
Some(extensions[1].1.as_str())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_create_query_context_ignores_remote_query_id_extension() {
|
||||
let query_context = create_query_context(
|
||||
Channel::Grpc,
|
||||
None,
|
||||
vec![(
|
||||
REMOTE_QUERY_ID_EXTENSION_KEY.to_string(),
|
||||
"spoofed-query-id".to_string(),
|
||||
)],
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_ne!(query_context.remote_query_id(), Some("spoofed-query-id"));
|
||||
assert_eq!(
|
||||
query_context.extension(REMOTE_QUERY_ID_EXTENSION_KEY),
|
||||
query_context.remote_query_id()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,7 +16,7 @@ use axum::body::Body;
|
||||
use axum::http::Request;
|
||||
use axum::middleware::Next;
|
||||
use axum::response::Response;
|
||||
use session::context::QueryContext;
|
||||
use session::context::{QueryContext, REMOTE_QUERY_ID_EXTENSION_KEY, is_reserved_extension_key};
|
||||
|
||||
use crate::hint_headers;
|
||||
|
||||
@@ -24,8 +24,50 @@ pub async fn extract_hints(mut request: Request<Body>, next: Next) -> Response {
|
||||
let hints = hint_headers::extract_hints(request.headers());
|
||||
if let Some(query_ctx) = request.extensions_mut().get_mut::<QueryContext>() {
|
||||
for (key, value) in hints {
|
||||
query_ctx.set_extension(key, value);
|
||||
apply_hint(query_ctx, key, value);
|
||||
}
|
||||
}
|
||||
next.run(request).await
|
||||
}
|
||||
|
||||
fn apply_hint(query_ctx: &mut QueryContext, key: String, value: String) {
|
||||
if is_reserved_extension_key(&key) {
|
||||
return;
|
||||
}
|
||||
query_ctx.set_extension(key, value);
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use session::context::{QueryContextBuilder, generate_remote_query_id};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_apply_hint_ignores_remote_query_id() {
|
||||
let expected_remote_query_id = generate_remote_query_id();
|
||||
let mut query_ctx = QueryContextBuilder::default()
|
||||
.set_extension(
|
||||
REMOTE_QUERY_ID_EXTENSION_KEY.to_string(),
|
||||
expected_remote_query_id.clone(),
|
||||
)
|
||||
.build();
|
||||
|
||||
apply_hint(
|
||||
&mut query_ctx,
|
||||
REMOTE_QUERY_ID_EXTENSION_KEY.to_string(),
|
||||
"spoofed-query-id".to_string(),
|
||||
);
|
||||
apply_hint(
|
||||
&mut query_ctx,
|
||||
"auto_create_table".to_string(),
|
||||
"true".to_string(),
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
query_ctx.remote_query_id(),
|
||||
Some(expected_remote_query_id.as_str())
|
||||
);
|
||||
assert_eq!(query_ctx.extension("auto_create_table"), Some("true"));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -43,6 +43,10 @@ pub type ConnInfoRef = Arc<ConnInfo>;
|
||||
const CURSOR_COUNT_WARNING_LIMIT: usize = 10;
|
||||
pub const REMOTE_QUERY_ID_EXTENSION_KEY: &str = "remote_query_id";
|
||||
|
||||
pub fn is_reserved_extension_key(key: &str) -> bool {
|
||||
key == REMOTE_QUERY_ID_EXTENSION_KEY
|
||||
}
|
||||
|
||||
pub fn generate_remote_query_id() -> String {
|
||||
generate_remote_query_id_value().to_string()
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user