use std::future::Future;
use std::io::Write as _;
use std::str::FromStr;
use std::time::Duration;

use anyhow::{Context, anyhow};
use bytes::{Bytes, BytesMut};
use hyper::header::{AUTHORIZATION, CONTENT_DISPOSITION, CONTENT_TYPE, HeaderName};
use hyper::http::HeaderValue;
use hyper::{Body, Method, Request, Response};
use jsonwebtoken::TokenData;
use metrics::{Encoder, IntCounter, TextEncoder, register_int_counter};
use once_cell::sync::Lazy;
use pprof::ProfilerGuardBuilder;
use pprof::protos::Message as _;
use routerify::ext::RequestExt;
use routerify::{Middleware, RequestInfo, Router, RouterBuilder};
use tokio::sync::{Mutex, Notify, mpsc};
use tokio_stream::wrappers::ReceiverStream;
use tokio_util::io::ReaderStream;
use tracing::{Instrument, debug, info, info_span, warn};
use utils::auth::{AuthError, Claims, SwappableJwtAuth};

use crate::error::{ApiError, api_error_handler, route_error_handler};
use crate::request::{get_query_param, parse_query_param};

static SERVE_METRICS_COUNT: Lazy<IntCounter> = Lazy::new(|| {
    register_int_counter!(
        "libmetrics_metric_handler_requests_total",
        "Number of metric requests made"
    )
    .expect("failed to define a metric")
});

static X_REQUEST_ID_HEADER_STR: &str = "x-request-id";

static X_REQUEST_ID_HEADER: HeaderName = HeaderName::from_static(X_REQUEST_ID_HEADER_STR);

#[derive(Debug, Default, Clone)]
struct RequestId(String);

/// Adds a tracing `info_span!` instrumentation around the handler events, and
/// logs the request start and end events for non-GET requests and non-200 responses.
///
/// Usage: replace `my_handler` with `|r| request_span(r, my_handler)`.
///
/// Use this to distinguish between logs of different HTTP requests: every request handler wrapped
/// with this will get request info logged in the wrapping span, including the unique request ID.
///
/// This also handles errors, logging them and converting them to an HTTP error response.
///
/// NB: if the client disconnects, Hyper will drop the future without polling it to
/// completion. In other words, the handler must be async-cancellation safe! `request_span`
/// prints a warning to the log when that happens, so that you have some trace of it in
/// the log.
///
/// There could be other ways to implement similar functionality:
///
/// * proc macros placed on top of all handler methods.
///   With all the drawbacks of proc macros, this brings no difference implementation-wise,
///   and little code reduction compared to the existing approach.
///
/// * Another `TraitExt` with e.g. `get_with_span`, `post_with_span` methods to do similar logic,
///   implemented for [`RouterBuilder`].
///   Could be simpler, but we don't want to depend on [`routerify`] more, aiming to switch to
///   another library later.
///
/// * In theory, a span guard could be created in a pre-request middleware and placed into a
///   global collection, to be dropped later, in a post-response middleware.
///   Due to the suspendable nature of futures, this would give contradictory results, which is
///   exactly the opposite of what `tracing-futures` tries to achieve with the `.instrument`
///   used in the current approach.
///
/// If needed, a declarative macro to substitute the `|r| ...` closure boilerplate could be
/// introduced.
pub async fn request_span<R, H>(request: Request<Body>, handler: H) -> R::Output
where
    R: Future<Output = Result<Response<Body>, ApiError>> + Send + 'static,
    H: FnOnce(Request<Body>) -> R + Send + Sync + 'static,
{
    let request_id = request.context::<RequestId>().unwrap_or_default().0;
    let method = request.method();
    let path = request.uri().path();
    let request_span = info_span!("request", %method, %path, %request_id);

    let log_quietly = method == Method::GET;
    async move {
        let cancellation_guard = RequestCancelled::warn_when_dropped_without_responding();
        if log_quietly {
            debug!("Handling request");
        } else {
            info!("Handling request");
        }

        // No special handling for panics here. There's a `tracing_panic_hook` from another
        // module to do that globally.
        let res = handler(request).await;

        cancellation_guard.disarm();

        // Log the result if needed.
        //
        // We also convert any errors into an Ok response with HTTP error code here.
        // `make_router` sets a last-resort error handler that would do the same, but
        // we prefer to do it here, before we exit the request span, so that the error
        // is still logged with the span.
        //
        // (Because we convert errors to Ok response, we never actually return an error,
        // and we could declare the function to return the never type (`!`). However,
        // using `routerify::RouterBuilder` requires a proper error type.)
        match res {
            Ok(response) => {
                let response_status = response.status();
                if log_quietly && response_status.is_success() {
                    debug!("Request handled, status: {response_status}");
                } else {
                    info!("Request handled, status: {response_status}");
                }
                Ok(response)
            }
            Err(err) => Ok(api_error_handler(err)),
        }
    }
    .instrument(request_span)
    .await
}
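// Example (sketch): wiring a handler through `request_span`, following the usage
// pattern documented above. The path and the `example_status_handler` name are
// illustrative, not part of any existing API surface.
#[allow(dead_code)]
async fn example_status_handler(_req: Request<Body>) -> Result<Response<Body>, ApiError> {
    Ok(Response::builder()
        .status(200)
        .body(Body::from("ok"))
        .unwrap())
}

#[allow(dead_code)]
fn example_router_with_span() -> RouterBuilder<Body, ApiError> {
    // Each request to /v1/status now runs inside an info_span carrying the method,
    // path and request ID, and errors are converted to HTTP error responses.
    make_router().get("/v1/status", |r| request_span(r, example_status_handler))
}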
/// Drop guard to WARN in case the request was dropped before completion.
struct RequestCancelled {
    warn: Option<tracing::Span>,
}

impl RequestCancelled {
    /// Create the drop guard using the [`tracing::Span::current`] as the span.
    fn warn_when_dropped_without_responding() -> Self {
        RequestCancelled {
            warn: Some(tracing::Span::current()),
        }
    }

    /// Consume the drop guard without logging anything.
    fn disarm(mut self) {
        self.warn = None;
    }
}

impl Drop for RequestCancelled {
    fn drop(&mut self) {
        if std::thread::panicking() {
            // we are unwinding due to panicking, assume we are not dropped for cancellation
        } else if let Some(span) = self.warn.take() {
            // the span has all of the info already, but the outer `.instrument(span)` has already
            // been dropped, so we need to manually re-enter it for this message.
            //
            // this is what the instrument would do before polling so it is fine.
            let _g = span.entered();
            warn!("request was dropped before completing");
        }
    }
}

/// An [`std::io::Write`] implementation on top of a channel sending [`bytes::Bytes`] chunks.
pub struct ChannelWriter {
    buffer: BytesMut,
    pub tx: mpsc::Sender<std::io::Result<Bytes>>,
    written: usize,
    /// Time spent waiting for the channel to make progress. It is not the same as time to upload a
    /// buffer because we cannot know anything about that, but this should allow us to understand
    /// the actual time taken without the time spent `std::thread::park`ed.
    wait_time: std::time::Duration,
}

impl ChannelWriter {
    pub fn new(buf_len: usize, tx: mpsc::Sender<std::io::Result<Bytes>>) -> Self {
        assert_ne!(buf_len, 0);
        ChannelWriter {
            // split about half off the buffer from the start, because we flush depending on
            // capacity. the first flush will come sooner than without this, but resizes will
            // now have a better chance of picking up the "other" half. not guaranteed of course.
            buffer: BytesMut::with_capacity(buf_len).split_off(buf_len / 2),
            tx,
            written: 0,
            wait_time: std::time::Duration::ZERO,
        }
    }

    pub fn flush0(&mut self) -> std::io::Result<usize> {
        let n = self.buffer.len();
        if n == 0 {
            return Ok(0);
        }
        tracing::trace!(n, "flushing");
        let ready = self.buffer.split().freeze();

        let wait_started_at = std::time::Instant::now();

        // not ideal to call block_on from blocking code, but we are sure that this
        // operation does not spawn_blocking other tasks
        let res: Result<(), ()> = tokio::runtime::Handle::current().block_on(async {
            self.tx.send(Ok(ready)).await.map_err(|_| ())?;

            // throttle sending to allow reuse of our buffer in `write`.
            self.tx.reserve().await.map_err(|_| ())?;

            // now the response task has picked up the buffer and hopefully started
            // sending it to the client.
            Ok(())
        });

        self.wait_time += wait_started_at.elapsed();

        if res.is_err() {
            return Err(std::io::ErrorKind::BrokenPipe.into());
        }
        self.written += n;
        Ok(n)
    }

    pub fn flushed_bytes(&self) -> usize {
        self.written
    }

    pub fn wait_time(&self) -> std::time::Duration {
        self.wait_time
    }
}

impl std::io::Write for ChannelWriter {
    fn write(&mut self, mut buf: &[u8]) -> std::io::Result<usize> {
        let remaining = self.buffer.capacity() - self.buffer.len();

        let out_of_space = remaining < buf.len();

        let original_len = buf.len();

        if out_of_space {
            let can_still_fit = buf.len() - remaining;
            self.buffer.extend_from_slice(&buf[..can_still_fit]);
            buf = &buf[can_still_fit..];
            self.flush0()?;
        }

        // assume that this will often under normal operation just move the pointer back to the
        // beginning of allocation, because previous split off parts are already sent and
        // dropped.
        self.buffer.extend_from_slice(buf);

        Ok(original_len)
    }

    fn flush(&mut self) -> std::io::Result<()> {
        self.flush0().map(|_| ())
    }
}
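// Example (sketch): streaming a response body produced by blocking code through
// `ChannelWriter`. The payload and buffer size are illustrative.
#[allow(dead_code)]
async fn example_channel_writer_response() -> Response<Body> {
    let (tx, rx) = mpsc::channel(1);
    let body = Body::wrap_stream(ReceiverStream::new(rx));

    tokio::task::spawn_blocking(move || {
        let mut writer = ChannelWriter::new(64 * 1024, tx);
        // `flush0` calls `Handle::current().block_on`, so the writer must be used on a
        // blocking thread inside a Tokio runtime, which spawn_blocking provides.
        let _ = writer
            .write_all(b"example payload")
            .and_then(|()| writer.flush());
    });

    // Hyper polls the body and thereby drains the channel as it sends the response.
    Response::builder().status(200).body(body).unwrap()
}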
pub async fn prometheus_metrics_handler(_req: Request<Body>) -> Result<Response<Body>, ApiError> {
    SERVE_METRICS_COUNT.inc();

    let started_at = std::time::Instant::now();

    let (tx, rx) = mpsc::channel(1);

    let body = Body::wrap_stream(ReceiverStream::new(rx));

    let mut writer = ChannelWriter::new(128 * 1024, tx);

    let encoder = TextEncoder::new();

    let response = Response::builder()
        .status(200)
        .header(CONTENT_TYPE, encoder.format_type())
        .body(body)
        .unwrap();

    let span = info_span!("blocking");
    tokio::task::spawn_blocking(move || {
        // there are situations where we lose scraped metrics under load, try to gather some
        // clues. since all nodes are queried for this, keep the message count low.
        let spawned_at = std::time::Instant::now();

        let _span = span.entered();

        let metrics = metrics::gather();

        let gathered_at = std::time::Instant::now();

        let res = encoder
            .encode(&metrics, &mut writer)
            .and_then(|_| writer.flush().map_err(|e| e.into()));

        // this instant is not when we finally got the full response sent, sending is done by hyper
        // in another task.
        let encoded_at = std::time::Instant::now();

        let spawned_in = spawned_at - started_at;
        let collected_in = gathered_at - spawned_at;
        // remove the wait time here in case the tcp connection was clogged
        let encoded_in = encoded_at - gathered_at - writer.wait_time();
        let total = encoded_at - started_at;

        match res {
            Ok(()) => {
                tracing::info!(
                    bytes = writer.flushed_bytes(),
                    total_ms = total.as_millis(),
                    spawning_ms = spawned_in.as_millis(),
                    collection_ms = collected_in.as_millis(),
                    encoding_ms = encoded_in.as_millis(),
                    "responded /metrics"
                );
            }
            Err(e) => {
                // there is a chance that this error is not the BrokenPipe we generate in the
                // writer for "closed connection", but it is highly unlikely.
                tracing::warn!(
                    after_bytes = writer.flushed_bytes(),
                    total_ms = total.as_millis(),
                    spawning_ms = spawned_in.as_millis(),
                    collection_ms = collected_in.as_millis(),
                    encoding_ms = encoded_in.as_millis(),
                    "failed to write out /metrics response: {e:?}"
                );
                // semantics of this error are quite... unclear. we want to error the stream out to
                // abort the response to somehow notify the client that we failed.
                //
                // though, most likely the reason for failure is that the receiver is already gone.
                drop(
                    writer
                        .tx
                        .blocking_send(Err(std::io::ErrorKind::BrokenPipe.into())),
                );
            }
        }
    });

    Ok(response)
}

/// Generates CPU profiles.
pub async fn profile_cpu_handler(req: Request<Body>) -> Result<Response<Body>, ApiError> {
    enum Format {
        Pprof,
        Svg,
    }

    // Parameters.
    let format = match get_query_param(&req, "format")?.as_deref() {
        None => Format::Pprof,
        Some("pprof") => Format::Pprof,
        Some("svg") => Format::Svg,
        Some(format) => return Err(ApiError::BadRequest(anyhow!("invalid format {format}"))),
    };
    let seconds = match parse_query_param(&req, "seconds")? {
        None => 5,
        Some(seconds @ 1..=60) => seconds,
        Some(_) => return Err(ApiError::BadRequest(anyhow!("duration must be 1-60 secs"))),
    };
    let frequency_hz = match parse_query_param(&req, "frequency")? {
        None => 99,
        Some(1001..) => return Err(ApiError::BadRequest(anyhow!("frequency must be <=1000 Hz"))),
        Some(frequency) => frequency,
    };
    let force: bool = parse_query_param(&req, "force")?.unwrap_or_default();

    // Take the profile.
    static PROFILE_LOCK: Lazy<Mutex<()>> = Lazy::new(|| Mutex::new(()));
    static PROFILE_CANCEL: Lazy<Notify> = Lazy::new(Notify::new);

    let report = {
        // Only allow one profiler at a time. If force is true, cancel a running profile (e.g. a
        // Grafana continuous profile). We use a try_lock() loop when cancelling instead of waiting
        // for a lock(), to avoid races where the notify isn't currently awaited.
        let _lock = loop {
            match PROFILE_LOCK.try_lock() {
                Ok(lock) => break lock,
                Err(_) if force => PROFILE_CANCEL.notify_waiters(),
                Err(_) => {
                    return Err(ApiError::Conflict(
                        "profiler already running (use ?force=true to cancel it)".into(),
                    ));
                }
            }
            tokio::time::sleep(Duration::from_millis(1)).await; // don't busy-wait
        };

        let guard = ProfilerGuardBuilder::default()
            .frequency(frequency_hz)
            .blocklist(&["libc", "libgcc", "pthread", "vdso"])
            .build()
            .map_err(|err| ApiError::InternalServerError(err.into()))?;

        tokio::select! {
            _ = tokio::time::sleep(Duration::from_secs(seconds)) => {},
            _ = PROFILE_CANCEL.notified() => {},
        };

        guard
            .report()
            .build()
            .map_err(|err| ApiError::InternalServerError(err.into()))?
    };

    // Return the report in the requested format.
    match format {
        Format::Pprof => {
            let body = report
                .pprof()
                .map_err(|err| ApiError::InternalServerError(err.into()))?
                .encode_to_vec();
            Response::builder()
                .status(200)
                .header(CONTENT_TYPE, "application/octet-stream")
                .header(CONTENT_DISPOSITION, "attachment; filename=\"profile.pb\"")
                .body(Body::from(body))
                .map_err(|err| ApiError::InternalServerError(err.into()))
        }
        Format::Svg => {
            let mut body = Vec::new();
            report
                .flamegraph(&mut body)
                .map_err(|err| ApiError::InternalServerError(err.into()))?;
            Response::builder()
                .status(200)
                .header(CONTENT_TYPE, "image/svg+xml")
                .body(Body::from(body))
                .map_err(|err| ApiError::InternalServerError(err.into()))
        }
    }
}
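// Example (sketch): exposing the profiling handlers on a router. The mount paths
// are illustrative; a client would then fetch e.g.
// `/profile/cpu?format=svg&seconds=10` or `/profile/heap?format=pprof`.
#[allow(dead_code)]
fn example_profiling_router() -> RouterBuilder<Body, ApiError> {
    make_router()
        .get("/profile/cpu", |r| request_span(r, profile_cpu_handler))
        .get("/profile/heap", |r| request_span(r, profile_heap_handler))
}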
/// Generates heap profiles.
///
/// This only works with jemalloc on Linux.
pub async fn profile_heap_handler(req: Request<Body>) -> Result<Response<Body>, ApiError> {
    enum Format {
        Jemalloc,
        Pprof,
        Svg,
    }

    // Parameters.
    let format = match get_query_param(&req, "format")?.as_deref() {
        None => Format::Pprof,
        Some("jemalloc") => Format::Jemalloc,
        Some("pprof") => Format::Pprof,
        Some("svg") => Format::Svg,
        Some(format) => return Err(ApiError::BadRequest(anyhow!("invalid format {format}"))),
    };

    // Obtain profiler handle.
    let mut prof_ctl = jemalloc_pprof::PROF_CTL
        .as_ref()
        .ok_or(ApiError::InternalServerError(anyhow!(
            "heap profiling not enabled"
        )))?
        .lock()
        .await;
    if !prof_ctl.activated() {
        return Err(ApiError::InternalServerError(anyhow!(
            "heap profiling not enabled"
        )));
    }

    // Take and return the profile.
    match format {
        Format::Jemalloc => {
            // NB: file is an open handle to a tempfile that's already deleted.
            let file = tokio::task::spawn_blocking(move || prof_ctl.dump())
                .await
                .map_err(|join_err| ApiError::InternalServerError(join_err.into()))?
                .map_err(ApiError::InternalServerError)?;
            let stream = ReaderStream::new(tokio::fs::File::from_std(file));
            Response::builder()
                .status(200)
                .header(CONTENT_TYPE, "application/octet-stream")
                .header(CONTENT_DISPOSITION, "attachment; filename=\"heap.dump\"")
                .body(Body::wrap_stream(stream))
                .map_err(|err| ApiError::InternalServerError(err.into()))
        }
        Format::Pprof => {
            let data = tokio::task::spawn_blocking(move || prof_ctl.dump_pprof())
                .await
                .map_err(|join_err| ApiError::InternalServerError(join_err.into()))?
                .map_err(ApiError::InternalServerError)?;
            Response::builder()
                .status(200)
                .header(CONTENT_TYPE, "application/octet-stream")
                .header(CONTENT_DISPOSITION, "attachment; filename=\"heap.pb.gz\"")
                .body(Body::from(data))
                .map_err(|err| ApiError::InternalServerError(err.into()))
        }
        Format::Svg => {
            let svg = tokio::task::spawn_blocking(move || prof_ctl.dump_flamegraph())
                .await
                .map_err(|join_err| ApiError::InternalServerError(join_err.into()))?
                .map_err(ApiError::InternalServerError)?;
            Response::builder()
                .status(200)
                .header(CONTENT_TYPE, "image/svg+xml")
                .body(Body::from(svg))
                .map_err(|err| ApiError::InternalServerError(err.into()))
        }
    }
}

pub fn add_request_id_middleware<B: hyper::body::HttpBody + Send + Sync + 'static>()
-> Middleware<B, ApiError> {
    Middleware::pre(move |req| async move {
        let request_id = match req.headers().get(&X_REQUEST_ID_HEADER) {
            Some(request_id) => request_id
                .to_str()
                .expect("extract request id value")
                .to_owned(),
            None => {
                let request_id = uuid::Uuid::new_v4();
                request_id.to_string()
            }
        };
        req.set_context(RequestId(request_id));

        Ok(req)
    })
}

async fn add_request_id_header_to_response(
    mut res: Response<Body>,
    req_info: RequestInfo,
) -> Result<Response<Body>, ApiError> {
    if let Some(request_id) = req_info.context::<RequestId>() {
        if let Ok(request_header_value) = HeaderValue::from_str(&request_id.0) {
            res.headers_mut()
                .insert(&X_REQUEST_ID_HEADER, request_header_value);
        };
    };

    Ok(res)
}

pub fn make_router() -> RouterBuilder<hyper::Body, ApiError> {
    Router::builder()
        .middleware(add_request_id_middleware())
        .middleware(Middleware::post_with_info(
            add_request_id_header_to_response,
        ))
        .err_handler(route_error_handler)
}

pub fn attach_openapi_ui(
    router_builder: RouterBuilder<hyper::Body, ApiError>,
    spec: &'static [u8],
    spec_mount_path: &'static str,
    ui_mount_path: &'static str,
) -> RouterBuilder<hyper::Body, ApiError> {
    router_builder
        .get(spec_mount_path, move |r| {
            request_span(r, move |_| async move {
                Ok(Response::builder().body(Body::from(spec)).unwrap())
            })
        })
        .get(ui_mount_path, move |r| {
            request_span(r, move |_| async move {
                Ok(Response::builder().body(Body::from(format!(r#"
                <!DOCTYPE html>
                <html lang="en">
                <head>
                <title>rweb</title>
                <link href="https://cdn.jsdelivr.net/npm/swagger-ui-dist@3/swagger-ui.css" rel="stylesheet">
                </head>
                <body>
                    <div id="swagger-ui"></div>
                    <script src="https://cdn.jsdelivr.net/npm/swagger-ui-dist@3/swagger-ui-bundle.js" charset="UTF-8"></script>
                    <script>
                        window.onload = function() {{
                            const ui = SwaggerUIBundle({{
                                "dom_id": "\#swagger-ui",
                                presets: [
                                    SwaggerUIBundle.presets.apis,
                                    SwaggerUIBundle.SwaggerUIStandalonePreset
                                ],
                                layout: "BaseLayout",
                                deepLinking: true,
                                showExtensions: true,
                                showCommonExtensions: true,
                                url: "{}",
                            }});
                            window.ui = ui;
                        }};
                    </script>
                </body>
                </html>
            "#, spec_mount_path))).unwrap())
            })
        })
}
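// Example (sketch): serving an embedded OpenAPI spec together with its UI page.
// The spec constant and mount paths are illustrative.
#[allow(dead_code)]
fn example_openapi_router() -> RouterBuilder<Body, ApiError> {
    static SPEC: &[u8] = b"openapi: 3.0.2"; // normally `include_bytes!(...)`
    attach_openapi_ui(make_router(), SPEC, "/swagger.yml", "/swagger")
}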
"#, spec_mount_path))).unwrap()) }) ) } fn parse_token(header_value: &str) -> Result<&str, ApiError> { // header must be in form Bearer let (prefix, token) = header_value .split_once(' ') .ok_or_else(|| ApiError::Unauthorized("malformed authorization header".to_string()))?; if prefix != "Bearer" { return Err(ApiError::Unauthorized( "malformed authorization header".to_string(), )); } Ok(token) } pub fn auth_middleware( provide_auth: fn(&Request) -> Option<&SwappableJwtAuth>, ) -> Middleware { Middleware::pre(move |req| async move { if let Some(auth) = provide_auth(&req) { match req.headers().get(AUTHORIZATION) { Some(value) => { let header_value = value.to_str().map_err(|_| { ApiError::Unauthorized("malformed authorization header".to_string()) })?; let token = parse_token(header_value)?; let data: TokenData = auth.decode(token).map_err(|err| { warn!("Authentication error: {err}"); // Rely on From for ApiError impl err })?; req.set_context(data.claims); } None => { return Err(ApiError::Unauthorized( "missing authorization header".to_string(), )); } } } Ok(req) }) } pub fn add_response_header_middleware( header: &str, value: &str, ) -> anyhow::Result> where B: hyper::body::HttpBody + Send + Sync + 'static, { let name = HeaderName::from_str(header).with_context(|| format!("invalid header name: {header}"))?; let value = HeaderValue::from_str(value).with_context(|| format!("invalid header value: {value}"))?; Ok(Middleware::post_with_info( move |mut response, request_info| { let name = name.clone(); let value = value.clone(); async move { let headers = response.headers_mut(); if headers.contains_key(&name) { warn!( "{} response already contains header {:?}", request_info.uri(), &name, ); } else { headers.insert(name, value); } Ok(response) } }, )) } pub fn check_permission_with( req: &Request, check_permission: impl Fn(&Claims) -> Result<(), AuthError>, ) -> Result<(), ApiError> { match req.context::() { Some(claims) => Ok(check_permission(&claims) .map_err(|_err| ApiError::Forbidden("JWT authentication error".to_string()))?), None => Ok(()), // claims is None because auth is disabled } } #[cfg(test)] mod tests { use std::future::poll_fn; use std::net::{IpAddr, SocketAddr}; use hyper::service::Service; use routerify::RequestServiceBuilder; use super::*; #[tokio::test] async fn test_request_id_returned() { let builder = RequestServiceBuilder::new(make_router().build().unwrap()).unwrap(); let remote_addr = SocketAddr::new(IpAddr::from_str("127.0.0.1").unwrap(), 80); let mut service = builder.build(remote_addr); if let Err(e) = poll_fn(|ctx| service.poll_ready(ctx)).await { panic!("request service is not ready: {:?}", e); } let mut req: Request = Request::default(); req.headers_mut() .append(&X_REQUEST_ID_HEADER, HeaderValue::from_str("42").unwrap()); let resp: Response = service.call(req).await.unwrap(); let header_val = resp.headers().get(&X_REQUEST_ID_HEADER).unwrap(); assert!(header_val == "42", "response header mismatch"); } #[tokio::test] async fn test_request_id_empty() { let builder = RequestServiceBuilder::new(make_router().build().unwrap()).unwrap(); let remote_addr = SocketAddr::new(IpAddr::from_str("127.0.0.1").unwrap(), 80); let mut service = builder.build(remote_addr); if let Err(e) = poll_fn(|ctx| service.poll_ready(ctx)).await { panic!("request service is not ready: {:?}", e); } let req: Request = Request::default(); let resp: Response = service.call(req).await.unwrap(); let header_val = resp.headers().get(&X_REQUEST_ID_HEADER); assert_ne!(header_val, None, "response header should 
NOT be empty"); } }