diff --git a/Cargo.lock b/Cargo.lock index 3f184ebe0b..02b02a09c1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1312,6 +1312,7 @@ dependencies = [ "tracing-utils", "url", "utils", + "uuid", "vm_monitor", "workspace_hack", "zstd", diff --git a/compute_tools/Cargo.toml b/compute_tools/Cargo.toml index 33892813c4..b04f364cbb 100644 --- a/compute_tools/Cargo.toml +++ b/compute_tools/Cargo.toml @@ -51,6 +51,7 @@ tracing-subscriber.workspace = true tracing-utils.workspace = true thiserror.workspace = true url.workspace = true +uuid.workspace = true prometheus.workspace = true postgres_initdb.workspace = true diff --git a/compute_tools/src/http/server.rs b/compute_tools/src/http/server.rs index 33d4b489a0..40fb1f4b4d 100644 --- a/compute_tools/src/http/server.rs +++ b/compute_tools/src/http/server.rs @@ -1,15 +1,14 @@ use std::{ net::{IpAddr, Ipv6Addr, SocketAddr}, - sync::{ - atomic::{AtomicU64, Ordering}, - Arc, - }, + sync::Arc, thread, time::Duration, }; use anyhow::Result; use axum::{ + extract::Request, + middleware::{self, Next}, response::{IntoResponse, Response}, routing::{get, post}, Router, @@ -17,11 +16,9 @@ use axum::{ use http::StatusCode; use tokio::net::TcpListener; use tower::ServiceBuilder; -use tower_http::{ - request_id::{MakeRequestId, PropagateRequestIdLayer, RequestId, SetRequestIdLayer}, - trace::TraceLayer, -}; +use tower_http::{request_id::PropagateRequestIdLayer, trace::TraceLayer}; use tracing::{debug, error, info, Span}; +use uuid::Uuid; use super::routes::{ check_writability, configure, database_schema, dbs_and_roles, extension_server, extensions, @@ -34,30 +31,24 @@ async fn handle_404() -> Response { StatusCode::NOT_FOUND.into_response() } -#[derive(Clone, Default)] -struct ComputeMakeRequestId(Arc); +const X_REQUEST_ID: &str = "x-request-id"; -impl MakeRequestId for ComputeMakeRequestId { - fn make_request_id( - &mut self, - _request: &http::Request, - ) -> Option { - let request_id = self - .0 - .fetch_add(1, Ordering::SeqCst) - .to_string() - .parse() - .unwrap(); +/// This middleware function allows compute_ctl to generate its own request ID +/// if one isn't supplied. The control plane will always send one as a UUID. The +/// neon Postgres extension on the other hand does not send one. +async fn maybe_add_request_id_header(mut request: Request, next: Next) -> Response { + let headers = request.headers_mut(); - Some(RequestId::new(request_id)) + if headers.get(X_REQUEST_ID).is_none() { + headers.append(X_REQUEST_ID, Uuid::new_v4().to_string().parse().unwrap()); } + + next.run(request).await } /// Run the HTTP server and wait on it forever. #[tokio::main] async fn serve(port: u16, compute: Arc) { - const X_REQUEST_ID: &str = "x-request-id"; - let mut app = Router::new() .route("/check_writability", post(check_writability::is_writable)) .route("/configure", post(configure::configure)) @@ -82,9 +73,8 @@ async fn serve(port: u16, compute: Arc) { .fallback(handle_404) .layer( ServiceBuilder::new() - .layer(SetRequestIdLayer::x_request_id( - ComputeMakeRequestId::default(), - )) + // Add this middleware since we assume the request ID exists + .layer(middleware::from_fn(maybe_add_request_id_header)) .layer( TraceLayer::new_for_http() .on_request(|request: &http::Request<_>, _span: &Span| {