Mirror of https://github.com/neondatabase/neon.git (synced 2026-01-07 13:32:57 +00:00)
proxy http cancellation safety (#7117)
## Problem

hyper auto-cancels the request futures on connection close. `sql_over_http::handle` is not 'drop cancel safe', so we need to do some other work to make sure connections and queries are cancelled in the right way.

## Summary of changes

1. tokio::spawn the request handler to resolve the initial cancel-safety issue.
2. Share a cancellation token, and cancel it when the request `Service` is dropped.
3. Add a new log span to be able to track the HTTP connection lifecycle.
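Taken together, the changes amount to the pattern sketched below. This is a minimal, self-contained illustration (assuming `tokio` and `tokio-util`), not the proxy code itself: `handle_request`, `serve_one_connection`, and the `Request`/`Response` types are hypothetical stand-ins. A connection-scoped `CancellationToken` is cancelled by a `DropGuard` when the connection's service is dropped, each request gets a child token, and the handler runs inside `tokio::spawn` so that dropping the response future can no longer cancel it mid-query.

```rust
use tokio_util::sync::CancellationToken;

// Hypothetical request/response types standing in for hyper's.
struct Request;
struct Response;

// Not drop-cancel-safe by design: it only reacts to `cancel` at well-defined points.
async fn handle_request(_req: Request, cancel: CancellationToken) -> Response {
    tokio::select! {
        // Cancellation observed here, e.g. to cancel the backend query.
        _ = cancel.cancelled() => Response,
        // Simulated query work.
        _ = tokio::time::sleep(std::time::Duration::from_secs(1)) => Response,
    }
}

async fn serve_one_connection(requests: Vec<Request>) {
    // One token per connection; the guard cancels it when the service is dropped.
    let conn_token = CancellationToken::new();
    let _cancel_connection = conn_token.clone().drop_guard();

    for req in requests {
        // One child token per request (h2 can cancel a single stream via RST_STREAM).
        let req_token = conn_token.child_token();

        // Spawning shields the handler from the caller dropping this future;
        // it is only ever cancelled through the token it was given.
        let task = tokio::spawn(handle_request(req, req_token.clone()));
        let cancel_session = req_token.drop_guard();

        let _resp = task.await.expect("handler panicked");
        // The request completed normally; don't fire the cancellation on drop.
        let _ = cancel_session.disarm();
    }
}

#[tokio::main]
async fn main() {
    serve_one_connection(vec![Request, Request]).await;
}
```

In the diff itself, the connection-level guard ends up as the `_cancel` field of `MetricService`, the per-request child token is threaded through `request_handler` into `sql_over_http::handle`, and the guard is disarmed once the handler has produced a response.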
@@ -341,7 +341,14 @@ impl Accept for ProxyProtocolAccept {
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
        let conn = ready!(Pin::new(&mut self.incoming).poll_accept(cx)?);
        tracing::info!(protocol = self.protocol, "accepted new TCP connection");

        let conn_id = uuid::Uuid::new_v4();
        let span = tracing::info_span!("http_conn", ?conn_id);
        {
            let _enter = span.enter();
            tracing::info!("accepted new TCP connection");
        }

        let Some(conn) = conn else {
            return Poll::Ready(None);
        };
@@ -354,6 +361,7 @@ impl Accept for ProxyProtocolAccept {
                    .with_label_values(&[self.protocol])
                    .guard(),
            )),
            span,
        })))
    }
}
@@ -364,6 +372,14 @@ pin_project! {
        pub inner: T,
        pub connection_id: Uuid,
        pub gauge: Mutex<Option<IntCounterPairGuard>>,
        pub span: tracing::Span,
    }

    impl<S> PinnedDrop for WithConnectionGuard<S> {
        fn drop(this: Pin<&mut Self>) {
            let _enter = this.span.enter();
            tracing::info!("HTTP connection closed")
        }
    }
}

@@ -19,6 +19,7 @@ use rand::SeedableRng;
pub use reqwest_middleware::{ClientWithMiddleware, Error};
pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
use tokio_util::task::TaskTracker;
use tracing::instrument::Instrumented;

use crate::context::RequestMonitoring;
use crate::protocol2::{ProxyProtocolAccept, WithClientIp, WithConnectionGuard};
@@ -30,13 +31,12 @@ use hyper::{
    Body, Method, Request, Response,
};

use std::convert::Infallible;
use std::net::IpAddr;
use std::sync::Arc;
use std::task::Poll;
use tls_listener::TlsListener;
use tokio::net::TcpListener;
use tokio_util::sync::CancellationToken;
use tokio_util::sync::{CancellationToken, DropGuard};
use tracing::{error, info, warn, Instrument};
use utils::http::{error::ApiError, json::json_response};

@@ -100,12 +100,7 @@ pub async fn task_main(
    let ws_connections = tokio_util::task::task_tracker::TaskTracker::new();
    ws_connections.close(); // allows `ws_connections.wait to complete`

    let tls_listener = TlsListener::new(
        tls_acceptor,
        addr_incoming,
        "http",
        config.handshake_timeout,
    );
    let tls_listener = TlsListener::new(tls_acceptor, addr_incoming, config.handshake_timeout);

    let make_svc = hyper::service::make_service_fn(
        |stream: &tokio_rustls::server::TlsStream<
@@ -121,6 +116,11 @@ pub async fn task_main(
                .take()
                .expect("gauge should be set on connection start");

            // Cancel all current inflight HTTP requests if the HTTP connection is closed.
            let http_cancellation_token = CancellationToken::new();
            let cancel_connection = http_cancellation_token.clone().drop_guard();

            let span = conn.span.clone();
            let client_addr = conn.inner.client_addr();
            let remote_addr = conn.inner.inner.remote_addr();
            let backend = backend.clone();
@@ -136,27 +136,43 @@ pub async fn task_main(
            Ok(MetricService::new(
                hyper::service::service_fn(move |req: Request<Body>| {
                    let backend = backend.clone();
                    let ws_connections = ws_connections.clone();
                    let ws_connections2 = ws_connections.clone();
                    let endpoint_rate_limiter = endpoint_rate_limiter.clone();
                    let cancellation_handler = cancellation_handler.clone();
                    let http_cancellation_token = http_cancellation_token.child_token();

                    async move {
                        Ok::<_, Infallible>(
                            request_handler(
                            // `request_handler` is not cancel safe. It expects to be cancelled only at specific times.
                            // By spawning the future, we ensure it never gets cancelled until it decides to.
                            ws_connections.spawn(
                                async move {
                                    // Cancel the current inflight HTTP request if the request stream is closed.
                                    // This is slightly different to `_cancel_connection` in that
                                    // h2 can cancel individual requests with a `RST_STREAM`.
                                    let _cancel_session = http_cancellation_token.clone().drop_guard();

                                    let res = request_handler(
                                        req,
                                        config,
                                        backend,
                                        ws_connections,
                                        ws_connections2,
                                        cancellation_handler,
                                        peer_addr.ip(),
                                        endpoint_rate_limiter,
                                        http_cancellation_token,
                                    )
                                    .await
                                    .map_or_else(|e| e.into_response(), |r| r),
                                    )
                                    }
                                    .map_or_else(|e| e.into_response(), |r| r);

                                    _cancel_session.disarm();

                                    res
                                }
                                .in_current_span(),
                            )
                }),
                gauge,
                cancel_connection,
                span,
            ))
        }
    },
@@ -176,11 +192,23 @@ pub async fn task_main(
struct MetricService<S> {
    inner: S,
    _gauge: IntCounterPairGuard,
    _cancel: DropGuard,
    span: tracing::Span,
}

impl<S> MetricService<S> {
    fn new(inner: S, _gauge: IntCounterPairGuard) -> MetricService<S> {
        MetricService { inner, _gauge }
    fn new(
        inner: S,
        _gauge: IntCounterPairGuard,
        _cancel: DropGuard,
        span: tracing::Span,
    ) -> MetricService<S> {
        MetricService {
            inner,
            _gauge,
            _cancel,
            span,
        }
    }
}

@@ -190,14 +218,16 @@ where
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = S::Future;
    type Future = Instrumented<S::Future>;

    fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
        self.inner.call(req)
        self.span
            .in_scope(|| self.inner.call(req))
            .instrument(self.span.clone())
    }
}

@@ -210,6 +240,8 @@ async fn request_handler(
    cancellation_handler: Arc<CancellationHandler>,
    peer_addr: IpAddr,
    endpoint_rate_limiter: Arc<EndpointRateLimiter>,
    // used to cancel in-flight HTTP requests. not used to cancel websockets
    http_cancellation_token: CancellationToken,
) -> Result<Response<Body>, ApiError> {
    let session_id = uuid::Uuid::new_v4();

@@ -253,7 +285,7 @@ async fn request_handler(
        let ctx = RequestMonitoring::new(session_id, peer_addr, "http", &config.region);
        let span = ctx.span.clone();

        sql_over_http::handle(config, ctx, request, backend)
        sql_over_http::handle(config, ctx, request, backend, http_cancellation_token)
            .instrument(span)
            .await
    } else if request.uri().path() == "/sql" && request.method() == Method::OPTIONS {

@@ -217,8 +217,8 @@ pub async fn handle(
    mut ctx: RequestMonitoring,
    request: Request<Body>,
    backend: Arc<PoolingBackend>,
    cancel: CancellationToken,
) -> Result<Response<Body>, ApiError> {
    let cancel = CancellationToken::new();
    let cancel2 = cancel.clone();
    let handle = tokio::spawn(async move {
        time::sleep(config.http_config.request_timeout).await;

@@ -13,7 +13,7 @@ use tokio::{
    time::timeout,
};
use tokio_rustls::{server::TlsStream, TlsAcceptor};
use tracing::{info, warn};
use tracing::{info, warn, Instrument};

use crate::{
    metrics::TLS_HANDSHAKE_FAILURES,
@@ -29,24 +29,17 @@ pin_project! {
        tls: TlsAcceptor,
        waiting: JoinSet<Option<TlsStream<A::Conn>>>,
        timeout: Duration,
        protocol: &'static str,
    }
}

impl<A: Accept> TlsListener<A> {
    /// Create a `TlsListener` with default options.
    pub(crate) fn new(
        tls: TlsAcceptor,
        listener: A,
        protocol: &'static str,
        timeout: Duration,
    ) -> Self {
    pub(crate) fn new(tls: TlsAcceptor, listener: A, timeout: Duration) -> Self {
        TlsListener {
            listener,
            tls,
            waiting: JoinSet::new(),
            timeout,
            protocol,
        }
    }
}
@@ -73,7 +66,7 @@ where
            Poll::Ready(Some(Ok(mut conn))) => {
                let t = *this.timeout;
                let tls = this.tls.clone();
                let protocol = *this.protocol;
                let span = conn.span.clone();
                this.waiting.spawn(async move {
                    let peer_addr = match conn.inner.wait_for_addr().await {
                        Ok(Some(addr)) => addr,
@@ -86,21 +79,24 @@ where

                    let accept = tls.accept(conn);
                    match timeout(t, accept).await {
                        Ok(Ok(conn)) => Some(conn),
                        Ok(Ok(conn)) => {
                            info!(%peer_addr, "accepted new TLS connection");
                            Some(conn)
                        },
                        // The handshake failed, try getting another connection from the queue
                        Ok(Err(e)) => {
                            TLS_HANDSHAKE_FAILURES.inc();
                            warn!(%peer_addr, protocol, "failed to accept TLS connection: {e:?}");
                            warn!(%peer_addr, "failed to accept TLS connection: {e:?}");
                            None
                        }
                        // The handshake timed out, try getting another connection from the queue
                        Err(_) => {
                            TLS_HANDSHAKE_FAILURES.inc();
                            warn!(%peer_addr, protocol, "failed to accept TLS connection: timeout");
                            warn!(%peer_addr, "failed to accept TLS connection: timeout");
                            None
                        }
                    }
                });
                }.instrument(span));
            }
            Poll::Ready(Some(Err(e))) => {
                tracing::error!("error accepting TCP connection: {e}");
@@ -112,10 +108,7 @@ where

        loop {
            return match this.waiting.poll_join_next(cx) {
                Poll::Ready(Some(Ok(Some(conn)))) => {
                    info!(protocol = this.protocol, "accepted new TLS connection");
                    Poll::Ready(Some(Ok(conn)))
                }
                Poll::Ready(Some(Ok(Some(conn)))) => Poll::Ready(Some(Ok(conn))),
                // The handshake failed to complete, try getting another connection from the queue
                Poll::Ready(Some(Ok(None))) => continue,
                // The handshake panicked or was cancelled. ignore and get another connection

@@ -2944,6 +2944,7 @@ class NeonProxy(PgProtocol):
        user = quote(kwargs["user"])
        password = quote(kwargs["password"])
        expected_code = kwargs.get("expected_code")
        timeout = kwargs.get("timeout")

        log.info(f"Executing http query: {query}")

@@ -2957,6 +2958,7 @@ class NeonProxy(PgProtocol):
                "Neon-Pool-Opt-In": "true",
            },
            verify=str(self.test_output_dir / "proxy.crt"),
            timeout=timeout,
        )

        if expected_code is not None:

@@ -596,3 +596,39 @@ def test_sql_over_http_timeout_cancel(static_proxy: NeonProxy):
    assert (
        "duplicate key value violates unique constraint" in res["message"]
    ), "HTTP query should conflict"


def test_sql_over_http_connection_cancel(static_proxy: NeonProxy):
    static_proxy.safe_psql("create role http with login password 'http' superuser")

    static_proxy.safe_psql("create table test_table ( id int primary key )")

    # insert into a table, with a unique constraint, after sleeping for n seconds
    query = "WITH temp AS ( \
        SELECT pg_sleep($1) as sleep, $2::int as id \
    ) INSERT INTO test_table (id) SELECT id FROM temp"

    try:
        # The request should complete before the proxy HTTP timeout triggers.
        # Timeout and cancel the request on the client side before the query completes.
        static_proxy.http_query(
            query,
            [static_proxy.http_timeout_seconds - 1, 1],
            user="http",
            password="http",
            timeout=2,
        )
    except requests.exceptions.ReadTimeout:
        pass

    # wait until the query _would_ have been complete
    time.sleep(static_proxy.http_timeout_seconds)

    res = static_proxy.http_query(query, [1, 1], user="http", password="http", expected_code=200)
    assert res["command"] == "INSERT", "HTTP query should insert"
    assert res["rowCount"] == 1, "HTTP query should insert"

    res = static_proxy.http_query(query, [0, 1], user="http", password="http", expected_code=400)
    assert (
        "duplicate key value violates unique constraint" in res["message"]
    ), "HTTP query should conflict"