mirror of
https://github.com/neondatabase/neon.git
synced 2026-05-16 04:30:38 +00:00
This is a full switch: fs IO operations are also tokio ones, working through the thread pool. Similar to the pageserver, we have multiple runtimes for easier `top` usage and isolation. Notable points: - Now that the guts of safekeeper.rs are full of `.await`s, we need to be very careful not to drop a task at a random point, leaving the timeline in an unclear state. Currently the only writer is walreceiver and we don't have top-level cancellation there, so we are good. But to be safe we should probably add a fuse that panics if a task is dropped while an operation on a timeline is in progress. - The timeline lock is a Tokio one now, as we do disk IO under it. - Collecting metrics got a crutch: since the prometheus Collector is synchronous, it spawns a thread with a current-thread runtime that collects the data. - Anything involving closures becomes significantly more complicated, as async fns are already kind of closures, and async closures are unstable. - The main thread now tracks the other main tasks, which got much easier. - The only sync place left is the initial data loading, as otherwise clippy complains about the timeline map lock being held across await points -- which is not bad here, as it happens only in the single-threaded runtime of the main thread. But having it sync doesn't hurt either. I'm concerned about the performance of thread pool IO offloading, async traits and the many await points, but we can try and see how it goes. fixes https://github.com/neondatabase/neon/issues/3036 fixes https://github.com/neondatabase/neon/issues/3966
395 lines
14 KiB
Rust
395 lines
14 KiB
Rust
use crate::auth::{Claims, JwtAuth};
|
|
use crate::http::error::{api_error_handler, route_error_handler, ApiError};
|
|
use anyhow::Context;
|
|
use hyper::header::{HeaderName, AUTHORIZATION};
|
|
use hyper::http::HeaderValue;
|
|
use hyper::Method;
|
|
use hyper::{header::CONTENT_TYPE, Body, Request, Response};
|
|
use metrics::{register_int_counter, Encoder, IntCounter, TextEncoder};
|
|
use once_cell::sync::Lazy;
|
|
use routerify::ext::RequestExt;
|
|
use routerify::{Middleware, RequestInfo, Router, RouterBuilder};
|
|
use tokio::task::JoinError;
|
|
use tracing::{self, debug, info, info_span, warn, Instrument};
|
|
|
|
use std::future::Future;
|
|
use std::str::FromStr;
|
|
|
|
static SERVE_METRICS_COUNT: Lazy<IntCounter> = Lazy::new(|| {
|
|
register_int_counter!(
|
|
"libmetrics_metric_handler_requests_total",
|
|
"Number of metric requests made"
|
|
)
|
|
.expect("failed to define a metric")
|
|
});
|
|
|
|
static X_REQUEST_ID_HEADER_STR: &str = "x-request-id";
|
|
|
|
static X_REQUEST_ID_HEADER: HeaderName = HeaderName::from_static(X_REQUEST_ID_HEADER_STR);
|
|
/// Request ID kept in the per-request context: stored by `add_request_id_middleware`, read back
/// by `request_span` and `add_request_id_header_to_response`.
#[derive(Clone, Debug, Default)]
struct RequestId(String);
|
|
|
|
/// Adds a tracing info_span! instrumentation around the handler events,
|
|
/// logs the request start and end events for non-GET requests and non-200 responses.
|
|
///
|
|
/// Usage: Replace `my_handler` with `|r| request_span(r, my_handler)`
|
|
///
|
|
/// Use this to distinguish between logs of different HTTP requests: every request handler wrapped
|
|
/// with this will get request info logged in the wrapping span, including the unique request ID.
|
|
///
|
|
/// This also handles errors, logging them and converting them to an HTTP error response.
|
|
///
|
|
/// NB: If the client disconnects, Hyper will drop the Future, without polling it to
|
|
/// completion. In other words, the handler must be async cancellation safe! request_span
|
|
/// prints a warning to the log when that happens, so that you have some trace of it in
|
|
/// the log.
|
|
///
|
|
///
|
|
/// There could be other ways to implement similar functionality:
|
|
///
|
|
/// * procmacros placed on top of all handler methods
|
|
/// With all the drawbacks of procmacros, brings no difference implementation-wise,
|
|
/// and little code reduction compared to the existing approach.
|
|
///
|
|
/// * Another `TraitExt` with e.g. the `get_with_span`, `post_with_span` methods to do similar logic,
|
|
/// implemented for [`RouterBuilder`].
|
|
/// Could be simpler, but we don't want to depend on [`routerify`] more, targeting to use other library later.
|
|
///
|
|
/// * In theory, a span guard could've been created in a pre-request middleware and placed into a global collection, to be dropped
|
|
/// later, in a post-response middleware.
|
|
/// Due to suspendable nature of the futures, would give contradictive results which is exactly the opposite of what `tracing-futures`
|
|
/// tries to achive with its `.instrument` used in the current approach.
|
|
///
|
|
/// If needed, a declarative macro to substitute the |r| ... closure boilerplate could be introduced.
|
|
pub async fn request_span<R, H>(request: Request<Body>, handler: H) -> R::Output
|
|
where
|
|
R: Future<Output = Result<Response<Body>, ApiError>> + Send + 'static,
|
|
H: FnOnce(Request<Body>) -> R + Send + Sync + 'static,
|
|
{
|
|
let request_id = request.context::<RequestId>().unwrap_or_default().0;
|
|
let method = request.method();
|
|
let path = request.uri().path();
|
|
let request_span = info_span!("request", %method, %path, %request_id);
|
|
|
|
let log_quietly = method == Method::GET;
|
|
async move {
|
|
let cancellation_guard = RequestCancelled::warn_when_dropped_without_responding();
|
|
if log_quietly {
|
|
debug!("Handling request");
|
|
} else {
|
|
info!("Handling request");
|
|
}
|
|
|
|
// No special handling for panics here. There's a `tracing_panic_hook` from another
|
|
// module to do that globally.
|
|
let res = handler(request).await;
|
|
|
|
cancellation_guard.disarm();
|
|
|
|
// Log the result if needed.
|
|
//
|
|
// We also convert any errors into an Ok response with HTTP error code here.
|
|
// `make_router` sets a last-resort error handler that would do the same, but
|
|
// we prefer to do it here, before we exit the request span, so that the error
|
|
// is still logged with the span.
|
|
//
|
|
// (Because we convert errors to Ok response, we never actually return an error,
|
|
// and we could declare the function to return the never type (`!`). However,
|
|
// using `routerify::RouterBuilder` requires a proper error type.)
|
|
match res {
|
|
Ok(response) => {
|
|
let response_status = response.status();
|
|
if log_quietly && response_status.is_success() {
|
|
debug!("Request handled, status: {response_status}");
|
|
} else {
|
|
info!("Request handled, status: {response_status}");
|
|
}
|
|
Ok(response)
|
|
}
|
|
Err(err) => Ok(api_error_handler(err)),
|
|
}
|
|
}
|
|
.instrument(request_span)
|
|
.await
|
|
}
|
|
|
|
/// Drop guard to WARN in case the request was dropped before completion.
|
|
struct RequestCancelled {
|
|
warn: Option<tracing::Span>,
|
|
}
|
|
|
|
impl RequestCancelled {
|
|
/// Create the drop guard using the [`tracing::Span::current`] as the span.
|
|
fn warn_when_dropped_without_responding() -> Self {
|
|
RequestCancelled {
|
|
warn: Some(tracing::Span::current()),
|
|
}
|
|
}
|
|
|
|
/// Consume the drop guard without logging anything.
|
|
fn disarm(mut self) {
|
|
self.warn = None;
|
|
}
|
|
}
|
|
|
|
impl Drop for RequestCancelled {
|
|
fn drop(&mut self) {
|
|
if std::thread::panicking() {
|
|
// we are unwinding due to panicking, assume we are not dropped for cancellation
|
|
} else if let Some(span) = self.warn.take() {
|
|
// the span has all of the info already, but the outer `.instrument(span)` has already
|
|
// been dropped, so we need to manually re-enter it for this message.
|
|
//
|
|
// this is what the instrument would do before polling so it is fine.
|
|
let _g = span.entered();
|
|
warn!("request was dropped before completing");
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn prometheus_metrics_handler(_req: Request<Body>) -> Result<Response<Body>, ApiError> {
|
|
SERVE_METRICS_COUNT.inc();
|
|
|
|
let mut buffer = vec![];
|
|
let encoder = TextEncoder::new();
|
|
|
|
let metrics = tokio::task::spawn_blocking(move || {
|
|
// Currently we take a lot of mutexes while collecting metrics, so it's
|
|
// better to spawn a blocking task to avoid blocking the event loop.
|
|
metrics::gather()
|
|
})
|
|
.await
|
|
.map_err(|e: JoinError| ApiError::InternalServerError(e.into()))?;
|
|
encoder.encode(&metrics, &mut buffer).unwrap();
|
|
|
|
let response = Response::builder()
|
|
.status(200)
|
|
.header(CONTENT_TYPE, encoder.format_type())
|
|
.body(Body::from(buffer))
|
|
.unwrap();
|
|
|
|
Ok(response)
|
|
}
|
|
|
|
pub fn add_request_id_middleware<B: hyper::body::HttpBody + Send + Sync + 'static>(
|
|
) -> Middleware<B, ApiError> {
|
|
Middleware::pre(move |req| async move {
|
|
let request_id = match req.headers().get(&X_REQUEST_ID_HEADER) {
|
|
Some(request_id) => request_id
|
|
.to_str()
|
|
.expect("extract request id value")
|
|
.to_owned(),
|
|
None => {
|
|
let request_id = uuid::Uuid::new_v4();
|
|
request_id.to_string()
|
|
}
|
|
};
|
|
req.set_context(RequestId(request_id));
|
|
|
|
Ok(req)
|
|
})
|
|
}
|
|
|
|
async fn add_request_id_header_to_response(
|
|
mut res: Response<Body>,
|
|
req_info: RequestInfo,
|
|
) -> Result<Response<Body>, ApiError> {
|
|
if let Some(request_id) = req_info.context::<RequestId>() {
|
|
if let Ok(request_header_value) = HeaderValue::from_str(&request_id.0) {
|
|
res.headers_mut()
|
|
.insert(&X_REQUEST_ID_HEADER, request_header_value);
|
|
};
|
|
};
|
|
|
|
Ok(res)
|
|
}
|
|
|
|
pub fn make_router() -> RouterBuilder<hyper::Body, ApiError> {
|
|
Router::builder()
|
|
.middleware(add_request_id_middleware())
|
|
.middleware(Middleware::post_with_info(
|
|
add_request_id_header_to_response,
|
|
))
|
|
.get("/metrics", |r| request_span(r, prometheus_metrics_handler))
|
|
.err_handler(route_error_handler)
|
|
}
|
|
|
|
pub fn attach_openapi_ui(
|
|
router_builder: RouterBuilder<hyper::Body, ApiError>,
|
|
spec: &'static [u8],
|
|
spec_mount_path: &'static str,
|
|
ui_mount_path: &'static str,
|
|
) -> RouterBuilder<hyper::Body, ApiError> {
|
|
router_builder
|
|
.get(spec_mount_path,
|
|
move |r| request_span(r, move |_| async move {
|
|
Ok(Response::builder().body(Body::from(spec)).unwrap())
|
|
})
|
|
)
|
|
.get(ui_mount_path,
|
|
move |r| request_span(r, move |_| async move {
|
|
Ok(Response::builder().body(Body::from(format!(r#"
|
|
<!DOCTYPE html>
|
|
<html lang="en">
|
|
<head>
|
|
<title>rweb</title>
|
|
<link href="https://cdn.jsdelivr.net/npm/swagger-ui-dist@3/swagger-ui.css" rel="stylesheet">
|
|
</head>
|
|
<body>
|
|
<div id="swagger-ui"></div>
|
|
<script src="https://cdn.jsdelivr.net/npm/swagger-ui-dist@3/swagger-ui-bundle.js" charset="UTF-8"> </script>
|
|
<script>
|
|
window.onload = function() {{
|
|
const ui = SwaggerUIBundle({{
|
|
"dom_id": "\#swagger-ui",
|
|
presets: [
|
|
SwaggerUIBundle.presets.apis,
|
|
SwaggerUIBundle.SwaggerUIStandalonePreset
|
|
],
|
|
layout: "BaseLayout",
|
|
deepLinking: true,
|
|
showExtensions: true,
|
|
showCommonExtensions: true,
|
|
url: "{}",
|
|
}})
|
|
window.ui = ui;
|
|
}};
|
|
</script>
|
|
</body>
|
|
</html>
|
|
"#, spec_mount_path))).unwrap())
|
|
})
|
|
)
|
|
}
|
|
|
|
fn parse_token(header_value: &str) -> Result<&str, ApiError> {
|
|
// header must be in form Bearer <token>
|
|
let (prefix, token) = header_value
|
|
.split_once(' ')
|
|
.ok_or_else(|| ApiError::Unauthorized("malformed authorization header".to_string()))?;
|
|
if prefix != "Bearer" {
|
|
return Err(ApiError::Unauthorized(
|
|
"malformed authorization header".to_string(),
|
|
));
|
|
}
|
|
Ok(token)
|
|
}
|
|
|
|
pub fn auth_middleware<B: hyper::body::HttpBody + Send + Sync + 'static>(
|
|
provide_auth: fn(&Request<Body>) -> Option<&JwtAuth>,
|
|
) -> Middleware<B, ApiError> {
|
|
Middleware::pre(move |req| async move {
|
|
if let Some(auth) = provide_auth(&req) {
|
|
match req.headers().get(AUTHORIZATION) {
|
|
Some(value) => {
|
|
let header_value = value.to_str().map_err(|_| {
|
|
ApiError::Unauthorized("malformed authorization header".to_string())
|
|
})?;
|
|
let token = parse_token(header_value)?;
|
|
|
|
let data = auth
|
|
.decode(token)
|
|
.map_err(|_| ApiError::Unauthorized("malformed jwt token".to_string()))?;
|
|
req.set_context(data.claims);
|
|
}
|
|
None => {
|
|
return Err(ApiError::Unauthorized(
|
|
"missing authorization header".to_string(),
|
|
))
|
|
}
|
|
}
|
|
}
|
|
Ok(req)
|
|
})
|
|
}
|
|
|
|
pub fn add_response_header_middleware<B>(
|
|
header: &str,
|
|
value: &str,
|
|
) -> anyhow::Result<Middleware<B, ApiError>>
|
|
where
|
|
B: hyper::body::HttpBody + Send + Sync + 'static,
|
|
{
|
|
let name =
|
|
HeaderName::from_str(header).with_context(|| format!("invalid header name: {header}"))?;
|
|
let value =
|
|
HeaderValue::from_str(value).with_context(|| format!("invalid header value: {value}"))?;
|
|
Ok(Middleware::post_with_info(
|
|
move |mut response, request_info| {
|
|
let name = name.clone();
|
|
let value = value.clone();
|
|
async move {
|
|
let headers = response.headers_mut();
|
|
if headers.contains_key(&name) {
|
|
warn!(
|
|
"{} response already contains header {:?}",
|
|
request_info.uri(),
|
|
&name,
|
|
);
|
|
} else {
|
|
headers.insert(name, value);
|
|
}
|
|
Ok(response)
|
|
}
|
|
},
|
|
))
|
|
}
|
|
|
|
pub fn check_permission_with(
|
|
req: &Request<Body>,
|
|
check_permission: impl Fn(&Claims) -> Result<(), anyhow::Error>,
|
|
) -> Result<(), ApiError> {
|
|
match req.context::<Claims>() {
|
|
Some(claims) => {
|
|
Ok(check_permission(&claims).map_err(|err| ApiError::Forbidden(err.to_string()))?)
|
|
}
|
|
None => Ok(()), // claims is None because auth is disabled
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use futures::future::poll_fn;
    use hyper::service::Service;
    use routerify::RequestServiceBuilder;
    use std::net::{IpAddr, SocketAddr};

    /// Builds a request service from `make_router`, waits until it is ready and sends `req`
    /// through it, returning the response.
    async fn send(req: Request<Body>) -> Response<hyper::body::Body> {
        let builder = RequestServiceBuilder::new(make_router().build().unwrap()).unwrap();
        let remote_addr = SocketAddr::new(IpAddr::from_str("127.0.0.1").unwrap(), 80);
        let mut service = builder.build(remote_addr);
        if let Err(e) = poll_fn(|ctx| service.poll_ready(ctx)).await {
            panic!("request service is not ready: {:?}", e);
        }

        service.call(req).await.unwrap()
    }

    /// A request ID supplied by the client is echoed back on the response.
    #[tokio::test]
    async fn test_request_id_returned() {
        let mut req: Request<Body> = Request::default();
        req.headers_mut()
            .append(&X_REQUEST_ID_HEADER, HeaderValue::from_str("42").unwrap());

        let resp = send(req).await;

        let header_val = resp.headers().get(&X_REQUEST_ID_HEADER).unwrap();
        assert!(header_val == "42", "response header mismatch");
    }

    /// When the client sends no request ID, the response still carries one.
    #[tokio::test]
    async fn test_request_id_empty() {
        let resp = send(Request::default()).await;

        let header_val = resp.headers().get(&X_REQUEST_ID_HEADER);
        assert_ne!(header_val, None, "response header should NOT be empty");
    }
}
|