From c2876ec55d985d2820467bd0e248500a29be649c Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Thu, 7 Mar 2024 12:36:47 +0000 Subject: [PATCH] proxy http tls investigations (#7045) ## Problem Some HTTP-specific TLS errors ## Summary of changes Add more logging, vendor `tls-listener` with minor modifications. --- Cargo.lock | 15 -- Cargo.toml | 1 - proxy/Cargo.toml | 1 - proxy/src/metrics.rs | 10 +- proxy/src/protocol2.rs | 78 +++++++- proxy/src/proxy.rs | 14 +- proxy/src/serverless.rs | 50 +++-- proxy/src/serverless/tls_listener.rs | 283 +++++++++++++++++++++++++++ proxy/src/serverless/websocket.rs | 6 + proxy/src/stream.rs | 6 +- 10 files changed, 418 insertions(+), 46 deletions(-) create mode 100644 proxy/src/serverless/tls_listener.rs diff --git a/Cargo.lock b/Cargo.lock index 864e5c9046..167a2b2179 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4216,7 +4216,6 @@ dependencies = [ "thiserror", "tikv-jemalloc-ctl", "tikv-jemallocator", - "tls-listener", "tokio", "tokio-postgres", "tokio-postgres-rustls", @@ -5794,20 +5793,6 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" -[[package]] -name = "tls-listener" -version = "0.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81294c017957a1a69794f506723519255879e15a870507faf45dfed288b763dd" -dependencies = [ - "futures-util", - "hyper", - "pin-project-lite", - "thiserror", - "tokio", - "tokio-rustls", -] - [[package]] name = "tokio" version = "1.36.0" diff --git a/Cargo.toml b/Cargo.toml index 90b02b30ec..42deaac19b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -156,7 +156,6 @@ test-context = "0.1" thiserror = "1.0" tikv-jemallocator = "0.5" tikv-jemalloc-ctl = "0.5" -tls-listener = { version = "0.7", features = ["rustls", "hyper-h1"] } tokio = { version = "1.17", features = ["macros"] } tokio-epoll-uring = { git = "https://github.com/neondatabase/tokio-epoll-uring.git" , 
branch = "main" } tokio-io-timeout = "1.2.0" diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index 0777d361d2..d8112c8bf0 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -68,7 +68,6 @@ task-local-extensions.workspace = true thiserror.workspace = true tikv-jemallocator.workspace = true tikv-jemalloc-ctl = { workspace = true, features = ["use_std"] } -tls-listener.workspace = true tokio-postgres.workspace = true tokio-rustls.workspace = true tokio-util.workspace = true diff --git a/proxy/src/metrics.rs b/proxy/src/metrics.rs index 2464b1e611..0477176c45 100644 --- a/proxy/src/metrics.rs +++ b/proxy/src/metrics.rs @@ -4,7 +4,7 @@ use ::metrics::{ register_int_gauge_vec, Histogram, HistogramVec, HyperLogLogVec, IntCounterPairVec, IntCounterVec, IntGauge, IntGaugeVec, }; -use metrics::{register_int_counter_pair, IntCounterPair}; +use metrics::{register_int_counter, register_int_counter_pair, IntCounter, IntCounterPair}; use once_cell::sync::Lazy; use tokio::time; @@ -312,3 +312,11 @@ pub static REDIS_BROKEN_MESSAGES: Lazy = Lazy::new(|| { ) .unwrap() }); + +pub static TLS_HANDSHAKE_FAILURES: Lazy = Lazy::new(|| { + register_int_counter!( + "proxy_tls_handshake_failures", + "Number of TLS handshake failures", + ) + .unwrap() +}); diff --git a/proxy/src/protocol2.rs b/proxy/src/protocol2.rs index 1d8931be85..3a7aabca32 100644 --- a/proxy/src/protocol2.rs +++ b/proxy/src/protocol2.rs @@ -1,22 +1,27 @@ //! 
Proxy Protocol V2 implementation use std::{ - future::poll_fn, - future::Future, + future::{poll_fn, Future}, io, net::SocketAddr, pin::{pin, Pin}, + sync::Mutex, task::{ready, Context, Poll}, }; use bytes::{Buf, BytesMut}; +use hyper::server::accept::Accept; use hyper::server::conn::{AddrIncoming, AddrStream}; +use metrics::IntCounterPairGuard; use pin_project_lite::pin_project; -use tls_listener::AsyncAccept; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf}; +use uuid::Uuid; + +use crate::{metrics::NUM_CLIENT_CONNECTION_GAUGE, serverless::tls_listener::AsyncAccept}; pub struct ProxyProtocolAccept { pub incoming: AddrIncoming, + pub protocol: &'static str, } pin_project! { @@ -327,7 +332,7 @@ impl AsyncRead for WithClientIp { } impl AsyncAccept for ProxyProtocolAccept { - type Connection = WithClientIp; + type Connection = WithConnectionGuard>; type Error = io::Error; @@ -336,11 +341,74 @@ impl AsyncAccept for ProxyProtocolAccept { cx: &mut Context<'_>, ) -> Poll>> { let conn = ready!(Pin::new(&mut self.incoming).poll_accept(cx)?); + tracing::info!(protocol = self.protocol, "accepted new TCP connection"); let Some(conn) = conn else { return Poll::Ready(None); }; - Poll::Ready(Some(Ok(WithClientIp::new(conn)))) + Poll::Ready(Some(Ok(WithConnectionGuard { + inner: WithClientIp::new(conn), + connection_id: Uuid::new_v4(), + gauge: Mutex::new(Some( + NUM_CLIENT_CONNECTION_GAUGE + .with_label_values(&[self.protocol]) + .guard(), + )), + }))) + } +} + +pin_project! 
{ + pub struct WithConnectionGuard { + #[pin] + pub inner: T, + pub connection_id: Uuid, + pub gauge: Mutex>, + } +} + +impl AsyncWrite for WithConnectionGuard { + #[inline] + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + self.project().inner.poll_write(cx, buf) + } + + #[inline] + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().inner.poll_flush(cx) + } + + #[inline] + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().inner.poll_shutdown(cx) + } + + #[inline] + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll> { + self.project().inner.poll_write_vectored(cx, bufs) + } + + #[inline] + fn is_write_vectored(&self) -> bool { + self.inner.is_write_vectored() + } +} + +impl AsyncRead for WithConnectionGuard { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + self.project().inner.poll_read(cx, buf) } } diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index d94fc67491..aeba08bc4f 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -24,6 +24,7 @@ use crate::{ }; use futures::TryFutureExt; use itertools::Itertools; +use metrics::IntCounterPairGuard; use once_cell::sync::OnceCell; use pq_proto::{BeMessage as Be, StartupMessageParams}; use regex::Regex; @@ -78,10 +79,16 @@ pub async fn task_main( { let (socket, peer_addr) = accept_result?; + let conn_gauge = NUM_CLIENT_CONNECTION_GAUGE + .with_label_values(&["tcp"]) + .guard(); + let session_id = uuid::Uuid::new_v4(); let cancellation_handler = Arc::clone(&cancellation_handler); let endpoint_rate_limiter = endpoint_rate_limiter.clone(); + tracing::info!(protocol = "tcp", %session_id, "accepted new TCP connection"); + connections.spawn(async move { let mut socket = WithClientIp::new(socket); let mut peer_addr = peer_addr.ip(); @@ -116,6 +123,7 @@ pub async fn task_main( 
socket, ClientMode::Tcp, endpoint_rate_limiter, + conn_gauge, ) .instrument(span.clone()) .await; @@ -229,13 +237,11 @@ pub async fn handle_client( stream: S, mode: ClientMode, endpoint_rate_limiter: Arc, + conn_gauge: IntCounterPairGuard, ) -> Result>, ClientRequestError> { info!("handling interactive connection from client"); let proto = ctx.protocol; - let _client_gauge = NUM_CLIENT_CONNECTION_GAUGE - .with_label_values(&[proto]) - .guard(); let _request_gauge = NUM_CONNECTION_REQUESTS_GAUGE .with_label_values(&[proto]) .guard(); @@ -325,7 +331,7 @@ pub async fn handle_client( aux: node.aux.clone(), compute: node, req: _request_gauge, - conn: _client_gauge, + conn: conn_gauge, cancel: session, })) } diff --git a/proxy/src/serverless.rs b/proxy/src/serverless.rs index b5806aec53..c81ae03b23 100644 --- a/proxy/src/serverless.rs +++ b/proxy/src/serverless.rs @@ -6,6 +6,7 @@ mod backend; mod conn_pool; mod json; mod sql_over_http; +pub mod tls_listener; mod websocket; pub use conn_pool::GlobalConnPoolOptions; @@ -20,8 +21,8 @@ pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware}; use tokio_util::task::TaskTracker; use crate::context::RequestMonitoring; -use crate::metrics::NUM_CLIENT_CONNECTION_GAUGE; -use crate::protocol2::{ProxyProtocolAccept, WithClientIp}; +use crate::metrics::TLS_HANDSHAKE_FAILURES; +use crate::protocol2::{ProxyProtocolAccept, WithClientIp, WithConnectionGuard}; use crate::rate_limiter::EndpointRateLimiter; use crate::serverless::backend::PoolingBackend; use crate::{cancellation::CancellationHandler, config::ProxyConfig}; @@ -98,6 +99,7 @@ pub async fn task_main( let _ = addr_incoming.set_nodelay(true); let addr_incoming = ProxyProtocolAccept { incoming: addr_incoming, + protocol: "http", }; let ws_connections = tokio_util::task::task_tracker::TaskTracker::new(); @@ -105,18 +107,34 @@ pub async fn task_main( let tls_listener = TlsListener::new(tls_acceptor, addr_incoming).filter(|conn| { if let Err(err) = conn { - 
error!("failed to accept TLS connection for websockets: {err:?}"); + error!( + protocol = "http", + "failed to accept TLS connection: {err:?}" + ); + TLS_HANDSHAKE_FAILURES.inc(); ready(false) } else { + info!(protocol = "http", "accepted new TLS connection"); ready(true) } }); let make_svc = hyper::service::make_service_fn( - |stream: &tokio_rustls::server::TlsStream>| { - let (io, _) = stream.get_ref(); - let client_addr = io.client_addr(); - let remote_addr = io.inner.remote_addr(); + |stream: &tokio_rustls::server::TlsStream< + WithConnectionGuard>, + >| { + let (conn, _) = stream.get_ref(); + + // this is jank. should disappear with hyper 1.0 migration. + let gauge = conn + .gauge + .lock() + .expect("lock should not be poisoned") + .take() + .expect("gauge should be set on connection start"); + + let client_addr = conn.inner.client_addr(); + let remote_addr = conn.inner.inner.remote_addr(); let backend = backend.clone(); let ws_connections = ws_connections.clone(); let endpoint_rate_limiter = endpoint_rate_limiter.clone(); @@ -127,8 +145,8 @@ pub async fn task_main( None if config.require_client_ip => bail!("missing required client ip"), None => remote_addr, }; - Ok(MetricService::new(hyper::service::service_fn( - move |req: Request| { + Ok(MetricService::new( + hyper::service::service_fn(move |req: Request| { let backend = backend.clone(); let ws_connections = ws_connections.clone(); let endpoint_rate_limiter = endpoint_rate_limiter.clone(); @@ -149,8 +167,9 @@ pub async fn task_main( .map_or_else(|e| e.into_response(), |r| r), ) } - }, - ))) + }), + gauge, + )) } }, ); @@ -172,13 +191,8 @@ struct MetricService { } impl MetricService { - fn new(inner: S) -> MetricService { - MetricService { - inner, - _gauge: NUM_CLIENT_CONNECTION_GAUGE - .with_label_values(&["http"]) - .guard(), - } + fn new(inner: S, _gauge: IntCounterPairGuard) -> MetricService { + MetricService { inner, _gauge } } } diff --git a/proxy/src/serverless/tls_listener.rs 
b/proxy/src/serverless/tls_listener.rs new file mode 100644 index 0000000000..6196ff393c --- /dev/null +++ b/proxy/src/serverless/tls_listener.rs @@ -0,0 +1,283 @@ +use std::{ + pin::Pin, + task::{Context, Poll}, + time::Duration, +}; + +use futures::{Future, Stream, StreamExt}; +use pin_project_lite::pin_project; +use thiserror::Error; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + task::JoinSet, + time::timeout, +}; + +/// Default timeout for the TLS handshake. +pub const DEFAULT_HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10); + +/// Trait for TLS implementation. +/// +/// An implementation is provided in this file for `tokio_rustls::TlsAcceptor`. +pub trait AsyncTls: Clone { + /// The type of the TLS stream created from the underlying stream. + type Stream: Send + 'static; + /// Error type for completing the TLS handshake + type Error: std::error::Error + Send + 'static; + /// Type of the Future for the TLS stream that is accepted. + type AcceptFuture: Future> + Send + 'static; + + /// Accept a TLS connection on an underlying stream + fn accept(&self, stream: C) -> Self::AcceptFuture; +} + +/// Asynchronously accept connections. +pub trait AsyncAccept { + /// The type of the connection that is accepted. + type Connection: AsyncRead + AsyncWrite; + /// The type of error that may be returned. + type Error; + + /// Poll to accept the next connection. + fn poll_accept( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>>; + + /// Return a new `AsyncAccept` that stops accepting connections after + /// `ender` completes. + /// + /// Useful for graceful shutdown. + /// + /// See [examples/echo.rs](https://github.com/tmccombs/tls-listener/blob/main/examples/echo.rs) + /// for example of how to use. + fn until(self, ender: F) -> Until + where + Self: Sized, + { + Until { + acceptor: self, + ender, + } + } +} + +pin_project! { + /// + /// Wraps a `Stream` of connections (such as a TCP listener) so that each connection is itself + /// encrypted using TLS. 
+ /// + /// It is similar to: + /// + /// ```ignore + /// tcpListener.and_then(|s| tlsAcceptor.accept(s)) + /// ``` + /// + /// except that it has the ability to accept multiple transport-level connections + /// simultaneously while the TLS handshake is pending for other connections. + /// + /// By default, if a client fails the TLS handshake, that is treated as an error, and the + /// `TlsListener` will return an `Err`. If the `TlsListener` is passed directly to a hyper + /// [`Server`][1], then an invalid handshake can cause the server to stop accepting connections. + /// See [`http-stream.rs`][2] or [`http-low-level`][3] examples, for examples of how to avoid this. + /// + /// Note that if the maximum number of pending connections is greater than 1, the resulting + /// [`T::Stream`][4] connections may come in a different order than the connections produced by the + /// underlying listener. + /// + /// [1]: https://docs.rs/hyper/latest/hyper/server/struct.Server.html + /// [2]: https://github.com/tmccombs/tls-listener/blob/main/examples/http-stream.rs + /// [3]: https://github.com/tmccombs/tls-listener/blob/main/examples/http-low-level.rs + /// [4]: AsyncTls::Stream + /// + #[allow(clippy::type_complexity)] + pub struct TlsListener> { + #[pin] + listener: A, + tls: T, + waiting: JoinSet, tokio::time::error::Elapsed>>, + timeout: Duration, + } +} + +/// Builder for `TlsListener`. +#[derive(Clone)] +pub struct Builder { + tls: T, + handshake_timeout: Duration, +} + +/// Wraps errors from either the listener or the TLS Acceptor +#[derive(Debug, Error)] +pub enum Error { + /// An error that arose from the listener ([AsyncAccept::Error]) + #[error("{0}")] + ListenerError(#[source] LE), + /// An error that occurred during the TLS accept handshake + #[error("{0}")] + TlsAcceptError(#[source] TE), +} + +impl TlsListener +where + T: AsyncTls, +{ + /// Create a `TlsListener` with default options. 
+ pub fn new(tls: T, listener: A) -> Self { + builder(tls).listen(listener) + } +} + +impl TlsListener +where + A: AsyncAccept, + A::Error: std::error::Error, + T: AsyncTls, +{ + /// Accept the next connection + /// + /// This is essentially an alias to `self.next()` with a more domain-appropriate name. + pub async fn accept(&mut self) -> Option<::Item> + where + Self: Unpin, + { + self.next().await + } + + /// Replaces the Tls Acceptor configuration, which will be used for new connections. + /// + /// This can be used to change the certificate used at runtime. + pub fn replace_acceptor(&mut self, acceptor: T) { + self.tls = acceptor; + } + + /// Replaces the Tls Acceptor configuration from a pinned reference to `Self`. + /// + /// This is useful if your listener is `!Unpin`. + /// + /// This can be used to change the certificate used at runtime. + pub fn replace_acceptor_pin(self: Pin<&mut Self>, acceptor: T) { + *self.project().tls = acceptor; + } +} + +impl Stream for TlsListener +where + A: AsyncAccept, + A::Error: std::error::Error, + T: AsyncTls, +{ + type Item = Result>; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + + loop { + match this.listener.as_mut().poll_accept(cx) { + Poll::Pending => break, + Poll::Ready(Some(Ok(conn))) => { + this.waiting + .spawn(timeout(*this.timeout, this.tls.accept(conn))); + } + Poll::Ready(Some(Err(e))) => { + return Poll::Ready(Some(Err(Error::ListenerError(e)))); + } + Poll::Ready(None) => return Poll::Ready(None), + } + } + + loop { + return match this.waiting.poll_join_next(cx) { + Poll::Ready(Some(Ok(Ok(conn)))) => { + Poll::Ready(Some(conn.map_err(Error::TlsAcceptError))) + } + // The handshake timed out, try getting another connection from the queue + Poll::Ready(Some(Ok(Err(_)))) => continue, + // The handshake panicked + Poll::Ready(Some(Err(e))) if e.is_panic() => { + std::panic::resume_unwind(e.into_panic()) + } + // The handshake was externally aborted + 
Poll::Ready(Some(Err(_))) => unreachable!("handshake tasks are never aborted"), + _ => Poll::Pending, + }; + } + } +} + +impl AsyncTls for tokio_rustls::TlsAcceptor { + type Stream = tokio_rustls::server::TlsStream; + type Error = std::io::Error; + type AcceptFuture = tokio_rustls::Accept; + + fn accept(&self, conn: C) -> Self::AcceptFuture { + tokio_rustls::TlsAcceptor::accept(self, conn) + } +} + +impl Builder { + /// Set the timeout for handshakes. + /// + /// If a handshake takes longer than `timeout`, then the handshake will be + /// aborted and the underlying connection will be dropped. + /// + /// Defaults to `DEFAULT_HANDSHAKE_TIMEOUT`. + pub fn handshake_timeout(&mut self, timeout: Duration) -> &mut Self { + self.handshake_timeout = timeout; + self + } + + /// Create a `TlsListener` from the builder + /// + /// Actually build the `TlsListener`. The `listener` argument should be + /// an implementation of the `AsyncAccept` trait that accepts new connections + /// that the `TlsListener` will encrypt using TLS. + pub fn listen(&self, listener: A) -> TlsListener + where + T: AsyncTls, + { + TlsListener { + listener, + tls: self.tls.clone(), + waiting: JoinSet::new(), + timeout: self.handshake_timeout, + } + } +} + +/// Create a new Builder for a TlsListener +/// +/// `tls` will be used to configure the TLS sessions. +pub fn builder(tls: T) -> Builder { + Builder { + tls, + handshake_timeout: DEFAULT_HANDSHAKE_TIMEOUT, + } +} + +pin_project! 
{ + /// See [`AsyncAccept::until`] + pub struct Until { + #[pin] + acceptor: A, + #[pin] + ender: E, + } +} + +impl AsyncAccept for Until { + type Connection = A::Connection; + type Error = A::Error; + + fn poll_accept( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + let this = self.project(); + + match this.ender.poll(cx) { + Poll::Pending => this.acceptor.poll_accept(cx), + Poll::Ready(_) => Poll::Ready(None), + } + } +} diff --git a/proxy/src/serverless/websocket.rs b/proxy/src/serverless/websocket.rs index 24f2bb7e8c..a72ede6d0a 100644 --- a/proxy/src/serverless/websocket.rs +++ b/proxy/src/serverless/websocket.rs @@ -3,6 +3,7 @@ use crate::{ config::ProxyConfig, context::RequestMonitoring, error::{io_error, ReportableError}, + metrics::NUM_CLIENT_CONNECTION_GAUGE, proxy::{handle_client, ClientMode}, rate_limiter::EndpointRateLimiter, }; @@ -138,6 +139,10 @@ pub async fn serve_websocket( endpoint_rate_limiter: Arc, ) -> anyhow::Result<()> { let websocket = websocket.await?; + let conn_gauge = NUM_CLIENT_CONNECTION_GAUGE + .with_label_values(&["ws"]) + .guard(); + let res = handle_client( config, &mut ctx, @@ -145,6 +150,7 @@ pub async fn serve_websocket( WebSocketRw::new(websocket), ClientMode::Websockets { hostname }, endpoint_rate_limiter, + conn_gauge, ) .await; diff --git a/proxy/src/stream.rs b/proxy/src/stream.rs index 0d639d2c07..b6b7a85659 100644 --- a/proxy/src/stream.rs +++ b/proxy/src/stream.rs @@ -1,5 +1,6 @@ use crate::config::TlsServerEndPoint; use crate::error::{ErrorKind, ReportableError, UserFacingError}; +use crate::metrics::TLS_HANDSHAKE_FAILURES; use bytes::BytesMut; use pq_proto::framed::{ConnectionError, Framed}; @@ -224,7 +225,10 @@ impl Stream { /// If possible, upgrade raw stream into a secure TLS-based stream. 
pub async fn upgrade(self, cfg: Arc) -> Result, StreamUpgradeError> { match self { - Stream::Raw { raw } => Ok(tokio_rustls::TlsAcceptor::from(cfg).accept(raw).await?), + Stream::Raw { raw } => Ok(tokio_rustls::TlsAcceptor::from(cfg) + .accept(raw) + .await + .inspect_err(|_| TLS_HANDSHAKE_FAILURES.inc())?), Stream::Tls { .. } => Err(StreamUpgradeError::AlreadyTls), } }