From 1f7d54f9872482b4b181f93dee2e6d91173d0ef8 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Tue, 12 Mar 2024 13:05:40 +0000 Subject: [PATCH] proxy refactor tls listener (#7056) ## Problem Now that we have tls-listener vendored, we can refactor and remove a lot of bloated code and make the whole flow a bit simpler ## Summary of changes 1. Remove dead code 2. Move the error handling to inside the `TlsListener` accept() function 3. Extract the peer_addr from the PROXY protocol header and log it with errors --- proxy/src/protocol2.rs | 8 +- proxy/src/serverless.rs | 30 +-- proxy/src/serverless/tls_listener.rs | 321 +++++++-------------------- 3 files changed, 97 insertions(+), 262 deletions(-) diff --git a/proxy/src/protocol2.rs b/proxy/src/protocol2.rs index 3a7aabca32..f476cb9b37 100644 --- a/proxy/src/protocol2.rs +++ b/proxy/src/protocol2.rs @@ -17,7 +17,7 @@ use pin_project_lite::pin_project; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf}; use uuid::Uuid; -use crate::{metrics::NUM_CLIENT_CONNECTION_GAUGE, serverless::tls_listener::AsyncAccept}; +use crate::metrics::NUM_CLIENT_CONNECTION_GAUGE; pub struct ProxyProtocolAccept { pub incoming: AddrIncoming, @@ -331,15 +331,15 @@ impl AsyncRead for WithClientIp { } } -impl AsyncAccept for ProxyProtocolAccept { - type Connection = WithConnectionGuard>; +impl Accept for ProxyProtocolAccept { + type Conn = WithConnectionGuard>; type Error = io::Error; fn poll_accept( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll>> { let conn = ready!(Pin::new(&mut self.incoming).poll_accept(cx)?); tracing::info!(protocol = self.protocol, "accepted new TCP connection"); let Some(conn) = conn else { diff --git a/proxy/src/serverless.rs b/proxy/src/serverless.rs index c81ae03b23..68f68eaba1 100644 --- a/proxy/src/serverless.rs +++ b/proxy/src/serverless.rs @@ -21,24 +21,19 @@ pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware}; use tokio_util::task::TaskTracker; use crate::context::RequestMonitoring; -use crate::metrics::TLS_HANDSHAKE_FAILURES; use crate::protocol2::{ProxyProtocolAccept, WithClientIp, WithConnectionGuard}; use crate::rate_limiter::EndpointRateLimiter; use crate::serverless::backend::PoolingBackend; use crate::{cancellation::CancellationHandler, config::ProxyConfig}; -use futures::StreamExt; use hyper::{ - server::{ - accept, - conn::{AddrIncoming, AddrStream}, - }, + server::conn::{AddrIncoming, AddrStream}, Body, Method, Request, Response, }; use std::convert::Infallible; use std::net::IpAddr; +use std::sync::Arc; use std::task::Poll; -use std::{future::ready, sync::Arc}; use tls_listener::TlsListener; use tokio::net::TcpListener; use tokio_util::sync::CancellationToken; @@ -105,19 +100,12 @@ pub async fn task_main( let ws_connections = tokio_util::task::task_tracker::TaskTracker::new(); ws_connections.close(); // allows `ws_connections.wait to complete` - let tls_listener = TlsListener::new(tls_acceptor, addr_incoming).filter(|conn| { - if let Err(err) = conn { - error!( - protocol = "http", - "failed to accept TLS connection: {err:?}" - ); - TLS_HANDSHAKE_FAILURES.inc(); - ready(false) - } else { - info!(protocol = "http", "accepted new TLS connection"); - ready(true) - } - }); + let tls_listener = TlsListener::new( + tls_acceptor, + addr_incoming, + "http", + config.handshake_timeout, + ); let make_svc = hyper::service::make_service_fn( |stream: &tokio_rustls::server::TlsStream< @@ -174,7 +162,7 @@ pub async fn task_main( }, ); - hyper::Server::builder(accept::from_stream(tls_listener)) + hyper::Server::builder(tls_listener) .serve(make_svc) .with_graceful_shutdown(cancellation_token.cancelled()) .await?; diff --git a/proxy/src/serverless/tls_listener.rs b/proxy/src/serverless/tls_listener.rs index 6196ff393c..cce02e3850 100644 --- a/proxy/src/serverless/tls_listener.rs +++ b/proxy/src/serverless/tls_listener.rs @@ -1,186 +1,110 @@ use std::{ + convert::Infallible, pin::Pin, task::{Context, Poll}, time::Duration, }; -use futures::{Future, Stream, StreamExt}; +use hyper::server::{accept::Accept, conn::AddrStream}; use pin_project_lite::pin_project; -use thiserror::Error; use tokio::{ io::{AsyncRead, AsyncWrite}, task::JoinSet, time::timeout, }; +use tokio_rustls::{server::TlsStream, TlsAcceptor}; +use tracing::{info, warn}; -/// Default timeout for the TLS handshake. -pub const DEFAULT_HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10); +use crate::{ + metrics::TLS_HANDSHAKE_FAILURES, + protocol2::{WithClientIp, WithConnectionGuard}, +}; -/// Trait for TLS implementation. -/// -/// Implementations are provided by the rustls and native-tls features. -pub trait AsyncTls: Clone { - /// The type of the TLS stream created from the underlying stream. - type Stream: Send + 'static; - /// Error type for completing the TLS handshake - type Error: std::error::Error + Send + 'static; - /// Type of the Future for the TLS stream that is accepted. - type AcceptFuture: Future> + Send + 'static; - - /// Accept a TLS connection on an underlying stream - fn accept(&self, stream: C) -> Self::AcceptFuture; +pin_project! { + /// Wraps a `Stream` of connections (such as a TCP listener) so that each connection is itself + /// encrypted using TLS. + pub(crate) struct TlsListener { + #[pin] + listener: A, + tls: TlsAcceptor, + waiting: JoinSet>>, + timeout: Duration, + protocol: &'static str, + } } -/// Asynchronously accept connections. -pub trait AsyncAccept { - /// The type of the connection that is accepted. - type Connection: AsyncRead + AsyncWrite; - /// The type of error that may be returned. - type Error; - - /// Poll to accept the next connection. - fn poll_accept( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll>>; - - /// Return a new `AsyncAccept` that stops accepting connections after - /// `ender` completes. - /// - /// Useful for graceful shutdown. - /// - /// See [examples/echo.rs](https://github.com/tmccombs/tls-listener/blob/main/examples/echo.rs) - /// for example of how to use. - fn until(self, ender: F) -> Until - where - Self: Sized, - { - Until { - acceptor: self, - ender, +impl TlsListener { + /// Create a `TlsListener` with default options. + pub(crate) fn new( + tls: TlsAcceptor, + listener: A, + protocol: &'static str, + timeout: Duration, + ) -> Self { + TlsListener { + listener, + tls, + waiting: JoinSet::new(), + timeout, + protocol, } } } -pin_project! { - /// - /// Wraps a `Stream` of connections (such as a TCP listener) so that each connection is itself - /// encrypted using TLS. - /// - /// It is similar to: - /// - /// ```ignore - /// tcpListener.and_then(|s| tlsAcceptor.accept(s)) - /// ``` - /// - /// except that it has the ability to accept multiple transport-level connections - /// simultaneously while the TLS handshake is pending for other connections. - /// - /// By default, if a client fails the TLS handshake, that is treated as an error, and the - /// `TlsListener` will return an `Err`. If the `TlsListener` is passed directly to a hyper - /// [`Server`][1], then an invalid handshake can cause the server to stop accepting connections. - /// See [`http-stream.rs`][2] or [`http-low-level`][3] examples, for examples of how to avoid this. - /// - /// Note that if the maximum number of pending connections is greater than 1, the resulting - /// [`T::Stream`][4] connections may come in a different order than the connections produced by the - /// underlying listener. - /// - /// [1]: https://docs.rs/hyper/latest/hyper/server/struct.Server.html - /// [2]: https://github.com/tmccombs/tls-listener/blob/main/examples/http-stream.rs - /// [3]: https://github.com/tmccombs/tls-listener/blob/main/examples/http-low-level.rs - /// [4]: AsyncTls::Stream - /// - #[allow(clippy::type_complexity)] - pub struct TlsListener> { - #[pin] - listener: A, - tls: T, - waiting: JoinSet, tokio::time::error::Elapsed>>, - timeout: Duration, - } -} - -/// Builder for `TlsListener`. -#[derive(Clone)] -pub struct Builder { - tls: T, - handshake_timeout: Duration, -} - -/// Wraps errors from either the listener or the TLS Acceptor -#[derive(Debug, Error)] -pub enum Error { - /// An error that arose from the listener ([AsyncAccept::Error]) - #[error("{0}")] - ListenerError(#[source] LE), - /// An error that occurred during the TLS accept handshake - #[error("{0}")] - TlsAcceptError(#[source] TE), -} - -impl TlsListener +impl Accept for TlsListener where - T: AsyncTls, -{ - /// Create a `TlsListener` with default options. - pub fn new(tls: T, listener: A) -> Self { - builder(tls).listen(listener) - } -} - -impl TlsListener -where - A: AsyncAccept, + A: Accept>>, A::Error: std::error::Error, - T: AsyncTls, + A::Conn: AsyncRead + AsyncWrite + Unpin + Send + 'static, { - /// Accept the next connection - /// - /// This is essentially an alias to `self.next()` with a more domain-appropriate name. - pub async fn accept(&mut self) -> Option<::Item> - where - Self: Unpin, - { - self.next().await - } + type Conn = TlsStream; - /// Replaces the Tls Acceptor configuration, which will be used for new connections. - /// - /// This can be used to change the certificate used at runtime. - pub fn replace_acceptor(&mut self, acceptor: T) { - self.tls = acceptor; - } + type Error = Infallible; - /// Replaces the Tls Acceptor configuration from a pinned reference to `Self`. - /// - /// This is useful if your listener is `!Unpin`. - /// - /// This can be used to change the certificate used at runtime. - pub fn replace_acceptor_pin(self: Pin<&mut Self>, acceptor: T) { - *self.project().tls = acceptor; - } -} - -impl Stream for TlsListener -where - A: AsyncAccept, - A::Error: std::error::Error, - T: AsyncTls, -{ - type Item = Result>; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + fn poll_accept( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { let mut this = self.project(); loop { match this.listener.as_mut().poll_accept(cx) { Poll::Pending => break, - Poll::Ready(Some(Ok(conn))) => { - this.waiting - .spawn(timeout(*this.timeout, this.tls.accept(conn))); + Poll::Ready(Some(Ok(mut conn))) => { + let t = *this.timeout; + let tls = this.tls.clone(); + let protocol = *this.protocol; + this.waiting.spawn(async move { + let peer_addr = match conn.inner.wait_for_addr().await { + Ok(Some(addr)) => addr, + Err(e) => { + tracing::error!("failed to accept TCP connection: invalid PROXY protocol V2 header: {e:#}"); + return None; + } + Ok(None) => conn.inner.inner.remote_addr() + }; + + let accept = tls.accept(conn); + match timeout(t, accept).await { + Ok(Ok(conn)) => Some(conn), + // The handshake failed, try getting another connection from the queue + Ok(Err(e)) => { + TLS_HANDSHAKE_FAILURES.inc(); + warn!(%peer_addr, protocol, "failed to accept TLS connection: {e:?}"); + None + } + // The handshake timed out, try getting another connection from the queue + Err(_) => { + TLS_HANDSHAKE_FAILURES.inc(); + warn!(%peer_addr, protocol, "failed to accept TLS connection: timeout"); + None + } + } + }); } Poll::Ready(Some(Err(e))) => { - return Poll::Ready(Some(Err(Error::ListenerError(e)))); + tracing::error!("error accepting TCP connection: {e}"); + continue; } Poll::Ready(None) => return Poll::Ready(None), } @@ -188,96 +112,19 @@ where loop { return match this.waiting.poll_join_next(cx) { - Poll::Ready(Some(Ok(Ok(conn)))) => { - Poll::Ready(Some(conn.map_err(Error::TlsAcceptError))) + Poll::Ready(Some(Ok(Some(conn)))) => { + info!(protocol = this.protocol, "accepted new TLS connection"); + Poll::Ready(Some(Ok(conn))) } - // The handshake timed out, try getting another connection from the queue - Poll::Ready(Some(Ok(Err(_)))) => continue, - // The handshake panicked - Poll::Ready(Some(Err(e))) if e.is_panic() => { - std::panic::resume_unwind(e.into_panic()) + // The handshake failed to complete, try getting another connection from the queue + Poll::Ready(Some(Ok(None))) => continue, + // The handshake panicked or was cancelled. ignore and get another connection + Poll::Ready(Some(Err(e))) => { + tracing::warn!("handshake aborted: {e}"); + continue; } - // The handshake was externally aborted - Poll::Ready(Some(Err(_))) => unreachable!("handshake tasks are never aborted"), _ => Poll::Pending, }; } } } - -impl AsyncTls for tokio_rustls::TlsAcceptor { - type Stream = tokio_rustls::server::TlsStream; - type Error = std::io::Error; - type AcceptFuture = tokio_rustls::Accept; - - fn accept(&self, conn: C) -> Self::AcceptFuture { - tokio_rustls::TlsAcceptor::accept(self, conn) - } -} - -impl Builder { - /// Set the timeout for handshakes. - /// - /// If a timeout takes longer than `timeout`, then the handshake will be - /// aborted and the underlying connection will be dropped. - /// - /// Defaults to `DEFAULT_HANDSHAKE_TIMEOUT`. - pub fn handshake_timeout(&mut self, timeout: Duration) -> &mut Self { - self.handshake_timeout = timeout; - self - } - - /// Create a `TlsListener` from the builder - /// - /// Actually build the `TlsListener`. The `listener` argument should be - /// an implementation of the `AsyncAccept` trait that accepts new connections - /// that the `TlsListener` will encrypt using TLS. - pub fn listen(&self, listener: A) -> TlsListener - where - T: AsyncTls, - { - TlsListener { - listener, - tls: self.tls.clone(), - waiting: JoinSet::new(), - timeout: self.handshake_timeout, - } - } -} - -/// Create a new Builder for a TlsListener -/// -/// `server_config` will be used to configure the TLS sessions. -pub fn builder(tls: T) -> Builder { - Builder { - tls, - handshake_timeout: DEFAULT_HANDSHAKE_TIMEOUT, - } -} - -pin_project! { - /// See [`AsyncAccept::until`] - pub struct Until { - #[pin] - acceptor: A, - #[pin] - ender: E, - } -} - -impl AsyncAccept for Until { - type Connection = A::Connection; - type Error = A::Error; - - fn poll_accept( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll>> { - let this = self.project(); - - match this.ender.poll(cx) { - Poll::Pending => this.acceptor.poll_accept(cx), - Poll::Ready(_) => Poll::Ready(None), - } - } -}