1use axum::body::Body;
16use axum::http::Request;
17use axum::middleware::Next;
18use axum::response::Response;
19use common_telemetry::debug;
20use session::context::QueryContext;
21use session::hints::is_reserved_extension_key;
22
23use crate::hint_headers;
24
25pub async fn extract_hints(mut request: Request<Body>, next: Next) -> Response {
26 let hints = hint_headers::extract_hints(request.headers());
27 if let Some(query_ctx) = request.extensions_mut().get_mut::<QueryContext>() {
28 apply_hints(query_ctx, hints);
29 }
30 next.run(request).await
31}
32
33fn apply_hints(query_ctx: &mut QueryContext, hints: Vec<(String, String)>) {
34 for (key, value) in hints {
35 if is_reserved_extension_key(&key) {
36 debug!(
37 key = key.as_str(),
38 "Ignoring reserved external query context extension key"
39 );
40 continue;
41 }
42 query_ctx.set_extension(key, value);
43 }
44}
45
#[cfg(test)]
mod tests {
    use session::context::{QueryContextBuilder, generate_remote_query_id};
    use session::hints::REMOTE_QUERY_ID_EXTENSION_KEY;

    use super::apply_hints;

    /// A reserved key supplied as a hint must not overwrite the value the
    /// server already stored, while ordinary hints are still applied.
    #[test]
    fn test_apply_hints_ignores_reserved_extension_keys() {
        let original_query_id = generate_remote_query_id();
        let mut query_ctx = QueryContextBuilder::default()
            .set_extension(
                REMOTE_QUERY_ID_EXTENSION_KEY.to_string(),
                original_query_id.clone(),
            )
            .build();

        apply_hints(
            &mut query_ctx,
            vec![
                (
                    REMOTE_QUERY_ID_EXTENSION_KEY.to_string(),
                    "spoofed".to_string(),
                ),
                ("ttl".to_string(), "7d".to_string()),
            ],
        );

        assert_eq!(
            query_ctx.remote_query_id(),
            Some(original_query_id.as_str())
        );
        assert_eq!(query_ctx.extension("ttl"), Some("7d"));
    }

    /// A reserved key supplied as a hint must not be injected into a context
    /// that was built without an explicit value for it.
    #[test]
    fn test_apply_hints_does_not_insert_reserved_extension_keys() {
        let mut query_ctx = QueryContextBuilder::default().build();
        // Snapshot the default state so this test does not depend on whether
        // the builder populates a remote query id on its own.
        let before = query_ctx.remote_query_id().map(str::to_string);

        apply_hints(
            &mut query_ctx,
            vec![
                (
                    REMOTE_QUERY_ID_EXTENSION_KEY.to_string(),
                    "spoofed".to_string(),
                ),
                ("ttl".to_string(), "7d".to_string()),
            ],
        );

        assert_ne!(query_ctx.remote_query_id(), Some("spoofed"));
        assert_eq!(query_ctx.remote_query_id(), before.as_deref());
        assert_eq!(query_ctx.extension("ttl"), Some("7d"));
    }
}