From 3bd6551b36be636c7497ee774c65718320093bc3 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Thu, 14 Mar 2024 08:20:56 +0000 Subject: [PATCH] proxy http cancellation safety (#7117) ## Problem hyper auto-cancels the request futures on connection close. `sql_over_http::handle` is not 'drop cancel safe', so we need to do some other work to make sure connections are queries in the right way. ## Summary of changes 1. tokio::spawn the request handler to resolve the initial cancel-safety issue 2. share a cancellation token, and cancel it when the request `Service` is dropped. 3. Add a new log span to be able to track the HTTP connection lifecycle. --- proxy/src/protocol2.rs | 18 ++++++- proxy/src/serverless.rs | 74 +++++++++++++++++++-------- proxy/src/serverless/sql_over_http.rs | 2 +- proxy/src/serverless/tls_listener.rs | 29 ++++------- test_runner/fixtures/neon_fixtures.py | 2 + test_runner/regress/test_proxy.py | 36 +++++++++++++ 6 files changed, 120 insertions(+), 41 deletions(-) diff --git a/proxy/src/protocol2.rs b/proxy/src/protocol2.rs index f476cb9b37..700c8c8681 100644 --- a/proxy/src/protocol2.rs +++ b/proxy/src/protocol2.rs @@ -341,7 +341,14 @@ impl Accept for ProxyProtocolAccept { cx: &mut Context<'_>, ) -> Poll>> { let conn = ready!(Pin::new(&mut self.incoming).poll_accept(cx)?); - tracing::info!(protocol = self.protocol, "accepted new TCP connection"); + + let conn_id = uuid::Uuid::new_v4(); + let span = tracing::info_span!("http_conn", ?conn_id); + { + let _enter = span.enter(); + tracing::info!("accepted new TCP connection"); + } + let Some(conn) = conn else { return Poll::Ready(None); }; @@ -354,6 +361,7 @@ impl Accept for ProxyProtocolAccept { .with_label_values(&[self.protocol]) .guard(), )), + span, }))) } } @@ -364,6 +372,14 @@ pin_project! { pub inner: T, pub connection_id: Uuid, pub gauge: Mutex>, + pub span: tracing::Span, + } + + impl PinnedDrop for WithConnectionGuard { + fn drop(this: Pin<&mut Self>) { + let _enter = this.span.enter(); + tracing::info!("HTTP connection closed") + } } } diff --git a/proxy/src/serverless.rs b/proxy/src/serverless.rs index 68f68eaba1..be9f90acde 100644 --- a/proxy/src/serverless.rs +++ b/proxy/src/serverless.rs @@ -19,6 +19,7 @@ use rand::SeedableRng; pub use reqwest_middleware::{ClientWithMiddleware, Error}; pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware}; use tokio_util::task::TaskTracker; +use tracing::instrument::Instrumented; use crate::context::RequestMonitoring; use crate::protocol2::{ProxyProtocolAccept, WithClientIp, WithConnectionGuard}; @@ -30,13 +31,12 @@ use hyper::{ Body, Method, Request, Response, }; -use std::convert::Infallible; use std::net::IpAddr; use std::sync::Arc; use std::task::Poll; use tls_listener::TlsListener; use tokio::net::TcpListener; -use tokio_util::sync::CancellationToken; +use tokio_util::sync::{CancellationToken, DropGuard}; use tracing::{error, info, warn, Instrument}; use utils::http::{error::ApiError, json::json_response}; @@ -100,12 +100,7 @@ pub async fn task_main( let ws_connections = tokio_util::task::task_tracker::TaskTracker::new(); ws_connections.close(); // allows `ws_connections.wait to complete` - let tls_listener = TlsListener::new( - tls_acceptor, - addr_incoming, - "http", - config.handshake_timeout, - ); + let tls_listener = TlsListener::new(tls_acceptor, addr_incoming, config.handshake_timeout); let make_svc = hyper::service::make_service_fn( |stream: &tokio_rustls::server::TlsStream< @@ -121,6 +116,11 @@ pub async fn task_main( .take() .expect("gauge should be set on connection start"); + // Cancel all current inflight HTTP requests if the HTTP connection is closed. + let http_cancellation_token = CancellationToken::new(); + let cancel_connection = http_cancellation_token.clone().drop_guard(); + + let span = conn.span.clone(); let client_addr = conn.inner.client_addr(); let remote_addr = conn.inner.inner.remote_addr(); let backend = backend.clone(); @@ -136,27 +136,43 @@ pub async fn task_main( Ok(MetricService::new( hyper::service::service_fn(move |req: Request| { let backend = backend.clone(); - let ws_connections = ws_connections.clone(); + let ws_connections2 = ws_connections.clone(); let endpoint_rate_limiter = endpoint_rate_limiter.clone(); let cancellation_handler = cancellation_handler.clone(); + let http_cancellation_token = http_cancellation_token.child_token(); - async move { - Ok::<_, Infallible>( - request_handler( + // `request_handler` is not cancel safe. It expects to be cancelled only at specific times. + // By spawning the future, we ensure it never gets cancelled until it decides to. + ws_connections.spawn( + async move { + // Cancel the current inflight HTTP request if the requets stream is closed. + // This is slightly different to `_cancel_connection` in that + // h2 can cancel individual requests with a `RST_STREAM`. + let _cancel_session = http_cancellation_token.clone().drop_guard(); + + let res = request_handler( req, config, backend, - ws_connections, + ws_connections2, cancellation_handler, peer_addr.ip(), endpoint_rate_limiter, + http_cancellation_token, ) .await - .map_or_else(|e| e.into_response(), |r| r), - ) - } + .map_or_else(|e| e.into_response(), |r| r); + + _cancel_session.disarm(); + + res + } + .in_current_span(), + ) }), gauge, + cancel_connection, + span, )) } }, @@ -176,11 +192,23 @@ pub async fn task_main( struct MetricService { inner: S, _gauge: IntCounterPairGuard, + _cancel: DropGuard, + span: tracing::Span, } impl MetricService { - fn new(inner: S, _gauge: IntCounterPairGuard) -> MetricService { - MetricService { inner, _gauge } + fn new( + inner: S, + _gauge: IntCounterPairGuard, + _cancel: DropGuard, + span: tracing::Span, + ) -> MetricService { + MetricService { + inner, + _gauge, + _cancel, + span, + } } } @@ -190,14 +218,16 @@ where { type Response = S::Response; type Error = S::Error; - type Future = S::Future; + type Future = Instrumented; fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll> { self.inner.poll_ready(cx) } fn call(&mut self, req: Request) -> Self::Future { - self.inner.call(req) + self.span + .in_scope(|| self.inner.call(req)) + .instrument(self.span.clone()) } } @@ -210,6 +240,8 @@ async fn request_handler( cancellation_handler: Arc, peer_addr: IpAddr, endpoint_rate_limiter: Arc, + // used to cancel in-flight HTTP requests. not used to cancel websockets + http_cancellation_token: CancellationToken, ) -> Result, ApiError> { let session_id = uuid::Uuid::new_v4(); @@ -253,7 +285,7 @@ async fn request_handler( let ctx = RequestMonitoring::new(session_id, peer_addr, "http", &config.region); let span = ctx.span.clone(); - sql_over_http::handle(config, ctx, request, backend) + sql_over_http::handle(config, ctx, request, backend, http_cancellation_token) .instrument(span) .await } else if request.uri().path() == "/sql" && request.method() == Method::OPTIONS { diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index 86c278030f..f675375ff1 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -217,8 +217,8 @@ pub async fn handle( mut ctx: RequestMonitoring, request: Request, backend: Arc, + cancel: CancellationToken, ) -> Result, ApiError> { - let cancel = CancellationToken::new(); let cancel2 = cancel.clone(); let handle = tokio::spawn(async move { time::sleep(config.http_config.request_timeout).await; diff --git a/proxy/src/serverless/tls_listener.rs b/proxy/src/serverless/tls_listener.rs index cce02e3850..33f194dd59 100644 --- a/proxy/src/serverless/tls_listener.rs +++ b/proxy/src/serverless/tls_listener.rs @@ -13,7 +13,7 @@ use tokio::{ time::timeout, }; use tokio_rustls::{server::TlsStream, TlsAcceptor}; -use tracing::{info, warn}; +use tracing::{info, warn, Instrument}; use crate::{ metrics::TLS_HANDSHAKE_FAILURES, @@ -29,24 +29,17 @@ pin_project! { tls: TlsAcceptor, waiting: JoinSet>>, timeout: Duration, - protocol: &'static str, } } impl TlsListener { /// Create a `TlsListener` with default options. - pub(crate) fn new( - tls: TlsAcceptor, - listener: A, - protocol: &'static str, - timeout: Duration, - ) -> Self { + pub(crate) fn new(tls: TlsAcceptor, listener: A, timeout: Duration) -> Self { TlsListener { listener, tls, waiting: JoinSet::new(), timeout, - protocol, } } } @@ -73,7 +66,7 @@ where Poll::Ready(Some(Ok(mut conn))) => { let t = *this.timeout; let tls = this.tls.clone(); - let protocol = *this.protocol; + let span = conn.span.clone(); this.waiting.spawn(async move { let peer_addr = match conn.inner.wait_for_addr().await { Ok(Some(addr)) => addr, @@ -86,21 +79,24 @@ where let accept = tls.accept(conn); match timeout(t, accept).await { - Ok(Ok(conn)) => Some(conn), + Ok(Ok(conn)) => { + info!(%peer_addr, "accepted new TLS connection"); + Some(conn) + }, // The handshake failed, try getting another connection from the queue Ok(Err(e)) => { TLS_HANDSHAKE_FAILURES.inc(); - warn!(%peer_addr, protocol, "failed to accept TLS connection: {e:?}"); + warn!(%peer_addr, "failed to accept TLS connection: {e:?}"); None } // The handshake timed out, try getting another connection from the queue Err(_) => { TLS_HANDSHAKE_FAILURES.inc(); - warn!(%peer_addr, protocol, "failed to accept TLS connection: timeout"); + warn!(%peer_addr, "failed to accept TLS connection: timeout"); None } } - }); + }.instrument(span)); } Poll::Ready(Some(Err(e))) => { tracing::error!("error accepting TCP connection: {e}"); @@ -112,10 +108,7 @@ where loop { return match this.waiting.poll_join_next(cx) { - Poll::Ready(Some(Ok(Some(conn)))) => { - info!(protocol = this.protocol, "accepted new TLS connection"); - Poll::Ready(Some(Ok(conn))) - } + Poll::Ready(Some(Ok(Some(conn)))) => Poll::Ready(Some(Ok(conn))), // The handshake failed to complete, try getting another connection from the queue Poll::Ready(Some(Ok(None))) => continue, // The handshake panicked or was cancelled. ignore and get another connection diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py index b3f460c7fe..5b76e808d5 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -2944,6 +2944,7 @@ class NeonProxy(PgProtocol): user = quote(kwargs["user"]) password = quote(kwargs["password"]) expected_code = kwargs.get("expected_code") + timeout = kwargs.get("timeout") log.info(f"Executing http query: {query}") @@ -2957,6 +2958,7 @@ class NeonProxy(PgProtocol): "Neon-Pool-Opt-In": "true", }, verify=str(self.test_output_dir / "proxy.crt"), + timeout=timeout, ) if expected_code is not None: diff --git a/test_runner/regress/test_proxy.py b/test_runner/regress/test_proxy.py index 078589d8eb..3e986a8f7b 100644 --- a/test_runner/regress/test_proxy.py +++ b/test_runner/regress/test_proxy.py @@ -596,3 +596,39 @@ def test_sql_over_http_timeout_cancel(static_proxy: NeonProxy): assert ( "duplicate key value violates unique constraint" in res["message"] ), "HTTP query should conflict" + + +def test_sql_over_http_connection_cancel(static_proxy: NeonProxy): + static_proxy.safe_psql("create role http with login password 'http' superuser") + + static_proxy.safe_psql("create table test_table ( id int primary key )") + + # insert into a table, with a unique constraint, after sleeping for n seconds + query = "WITH temp AS ( \ + SELECT pg_sleep($1) as sleep, $2::int as id \ + ) INSERT INTO test_table (id) SELECT id FROM temp" + + try: + # The request should complete before the proxy HTTP timeout triggers. + # Timeout and cancel the request on the client side before the query completes. + static_proxy.http_query( + query, + [static_proxy.http_timeout_seconds - 1, 1], + user="http", + password="http", + timeout=2, + ) + except requests.exceptions.ReadTimeout: + pass + + # wait until the query _would_ have been complete + time.sleep(static_proxy.http_timeout_seconds) + + res = static_proxy.http_query(query, [1, 1], user="http", password="http", expected_code=200) + assert res["command"] == "INSERT", "HTTP query should insert" + assert res["rowCount"] == 1, "HTTP query should insert" + + res = static_proxy.http_query(query, [0, 1], user="http", password="http", expected_code=400) + assert ( + "duplicate key value violates unique constraint" in res["message"] + ), "HTTP query should conflict"