Skip to main content

servers/http/
hints.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use axum::body::Body;
16use axum::http::Request;
17use axum::middleware::Next;
18use axum::response::Response;
19use common_telemetry::debug;
20use session::context::QueryContext;
21use session::hints::is_reserved_extension_key;
22
23use crate::hint_headers;
24
25pub async fn extract_hints(mut request: Request<Body>, next: Next) -> Response {
26    let hints = hint_headers::extract_hints(request.headers());
27    if let Some(query_ctx) = request.extensions_mut().get_mut::<QueryContext>() {
28        apply_hints(query_ctx, hints);
29    }
30    next.run(request).await
31}
32
33fn apply_hints(query_ctx: &mut QueryContext, hints: Vec<(String, String)>) {
34    for (key, value) in hints {
35        if is_reserved_extension_key(&key) {
36            debug!(
37                key = key.as_str(),
38                "Ignoring reserved external query context extension key"
39            );
40            continue;
41        }
42        query_ctx.set_extension(key, value);
43    }
44}
45
#[cfg(test)]
mod tests {
    use session::context::{QueryContextBuilder, generate_remote_query_id};
    use session::hints::REMOTE_QUERY_ID_EXTENSION_KEY;

    use super::apply_hints;

    /// A reserved key already present on the context must keep its original
    /// value, while ordinary hints are still applied.
    #[test]
    fn test_apply_hints_ignores_reserved_extension_keys() {
        let original_query_id = generate_remote_query_id();
        let mut query_ctx = QueryContextBuilder::default()
            .set_extension(
                REMOTE_QUERY_ID_EXTENSION_KEY.to_string(),
                original_query_id.clone(),
            )
            .build();

        apply_hints(
            &mut query_ctx,
            vec![
                (
                    REMOTE_QUERY_ID_EXTENSION_KEY.to_string(),
                    "spoofed".to_string(),
                ),
                ("ttl".to_string(), "7d".to_string()),
            ],
        );

        assert_eq!(
            query_ctx.remote_query_id(),
            Some(original_query_id.as_str())
        );
        assert_eq!(query_ctx.extension("ttl"), Some("7d"));
    }

    /// A reserved key must not be injected even when the context does not
    /// already carry a value for it, while ordinary hints are still applied.
    #[test]
    fn test_apply_hints_does_not_inject_reserved_extension_keys() {
        let mut query_ctx = QueryContextBuilder::default().build();

        apply_hints(
            &mut query_ctx,
            vec![
                (
                    REMOTE_QUERY_ID_EXTENSION_KEY.to_string(),
                    "spoofed".to_string(),
                ),
                ("ttl".to_string(), "7d".to_string()),
            ],
        );

        // Only assert that the spoofed value did not land; whether the builder
        // populates a default remote query id is not asserted here.
        assert_ne!(query_ctx.remote_query_id(), Some("spoofed"));
        assert_eq!(query_ctx.extension("ttl"), Some("7d"));
    }
}