proxy refactor tls listener (#7056)

## Problem

Now that we have tls-listener vendored, we can refactor and remove a lot
of bloated code and make the whole flow a bit simpler

## Summary of changes

1. Remove dead code
2. Move the error handling to inside the `TlsListener` accept() function
3. Extract the peer_addr from the PROXY protocol header and log it with
errors
This commit is contained in:
Conrad Ludgate
2024-03-12 13:05:40 +00:00
committed by GitHub
parent 580e136b2e
commit 1f7d54f987
3 changed files with 97 additions and 262 deletions

View File

@@ -17,7 +17,7 @@ use pin_project_lite::pin_project;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf};
use uuid::Uuid;
use crate::{metrics::NUM_CLIENT_CONNECTION_GAUGE, serverless::tls_listener::AsyncAccept};
use crate::metrics::NUM_CLIENT_CONNECTION_GAUGE;
pub struct ProxyProtocolAccept {
pub incoming: AddrIncoming,
@@ -331,15 +331,15 @@ impl<T: AsyncRead> AsyncRead for WithClientIp<T> {
}
}
impl AsyncAccept for ProxyProtocolAccept {
type Connection = WithConnectionGuard<WithClientIp<AddrStream>>;
impl Accept for ProxyProtocolAccept {
type Conn = WithConnectionGuard<WithClientIp<AddrStream>>;
type Error = io::Error;
fn poll_accept(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Connection, Self::Error>>> {
) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
let conn = ready!(Pin::new(&mut self.incoming).poll_accept(cx)?);
tracing::info!(protocol = self.protocol, "accepted new TCP connection");
let Some(conn) = conn else {

View File

@@ -21,24 +21,19 @@ pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
use tokio_util::task::TaskTracker;
use crate::context::RequestMonitoring;
use crate::metrics::TLS_HANDSHAKE_FAILURES;
use crate::protocol2::{ProxyProtocolAccept, WithClientIp, WithConnectionGuard};
use crate::rate_limiter::EndpointRateLimiter;
use crate::serverless::backend::PoolingBackend;
use crate::{cancellation::CancellationHandler, config::ProxyConfig};
use futures::StreamExt;
use hyper::{
server::{
accept,
conn::{AddrIncoming, AddrStream},
},
server::conn::{AddrIncoming, AddrStream},
Body, Method, Request, Response,
};
use std::convert::Infallible;
use std::net::IpAddr;
use std::sync::Arc;
use std::task::Poll;
use std::{future::ready, sync::Arc};
use tls_listener::TlsListener;
use tokio::net::TcpListener;
use tokio_util::sync::CancellationToken;
@@ -105,19 +100,12 @@ pub async fn task_main(
let ws_connections = tokio_util::task::task_tracker::TaskTracker::new();
ws_connections.close(); // allows `ws_connections.wait to complete`
let tls_listener = TlsListener::new(tls_acceptor, addr_incoming).filter(|conn| {
if let Err(err) = conn {
error!(
protocol = "http",
"failed to accept TLS connection: {err:?}"
);
TLS_HANDSHAKE_FAILURES.inc();
ready(false)
} else {
info!(protocol = "http", "accepted new TLS connection");
ready(true)
}
});
let tls_listener = TlsListener::new(
tls_acceptor,
addr_incoming,
"http",
config.handshake_timeout,
);
let make_svc = hyper::service::make_service_fn(
|stream: &tokio_rustls::server::TlsStream<
@@ -174,7 +162,7 @@ pub async fn task_main(
},
);
hyper::Server::builder(accept::from_stream(tls_listener))
hyper::Server::builder(tls_listener)
.serve(make_svc)
.with_graceful_shutdown(cancellation_token.cancelled())
.await?;

View File

@@ -1,186 +1,110 @@
use std::{
convert::Infallible,
pin::Pin,
task::{Context, Poll},
time::Duration,
};
use futures::{Future, Stream, StreamExt};
use hyper::server::{accept::Accept, conn::AddrStream};
use pin_project_lite::pin_project;
use thiserror::Error;
use tokio::{
io::{AsyncRead, AsyncWrite},
task::JoinSet,
time::timeout,
};
use tokio_rustls::{server::TlsStream, TlsAcceptor};
use tracing::{info, warn};
/// Default timeout for the TLS handshake.
pub const DEFAULT_HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10);
use crate::{
metrics::TLS_HANDSHAKE_FAILURES,
protocol2::{WithClientIp, WithConnectionGuard},
};
/// Trait for TLS implementation.
///
/// Implementations are provided by the rustls and native-tls features.
pub trait AsyncTls<C: AsyncRead + AsyncWrite>: Clone {
/// The type of the TLS stream created from the underlying stream.
type Stream: Send + 'static;
/// Error type for completing the TLS handshake
type Error: std::error::Error + Send + 'static;
/// Type of the Future for the TLS stream that is accepted.
type AcceptFuture: Future<Output = Result<Self::Stream, Self::Error>> + Send + 'static;
/// Accept a TLS connection on an underlying stream
fn accept(&self, stream: C) -> Self::AcceptFuture;
pin_project! {
/// Wraps a `Stream` of connections (such as a TCP listener) so that each connection is itself
/// encrypted using TLS.
pub(crate) struct TlsListener<A: Accept> {
#[pin]
listener: A,
tls: TlsAcceptor,
waiting: JoinSet<Option<TlsStream<A::Conn>>>,
timeout: Duration,
protocol: &'static str,
}
}
/// Asynchronously accept connections.
pub trait AsyncAccept {
/// The type of the connection that is accepted.
type Connection: AsyncRead + AsyncWrite;
/// The type of error that may be returned.
type Error;
/// Poll to accept the next connection.
fn poll_accept(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Connection, Self::Error>>>;
/// Return a new `AsyncAccept` that stops accepting connections after
/// `ender` completes.
///
/// Useful for graceful shutdown.
///
/// See [examples/echo.rs](https://github.com/tmccombs/tls-listener/blob/main/examples/echo.rs)
/// for example of how to use.
fn until<F: Future>(self, ender: F) -> Until<Self, F>
where
Self: Sized,
{
Until {
acceptor: self,
ender,
impl<A: Accept> TlsListener<A> {
/// Create a `TlsListener` with default options.
pub(crate) fn new(
tls: TlsAcceptor,
listener: A,
protocol: &'static str,
timeout: Duration,
) -> Self {
TlsListener {
listener,
tls,
waiting: JoinSet::new(),
timeout,
protocol,
}
}
}
pin_project! {
///
/// Wraps a `Stream` of connections (such as a TCP listener) so that each connection is itself
/// encrypted using TLS.
///
/// It is similar to:
///
/// ```ignore
/// tcpListener.and_then(|s| tlsAcceptor.accept(s))
/// ```
///
/// except that it has the ability to accept multiple transport-level connections
/// simultaneously while the TLS handshake is pending for other connections.
///
/// By default, if a client fails the TLS handshake, that is treated as an error, and the
/// `TlsListener` will return an `Err`. If the `TlsListener` is passed directly to a hyper
/// [`Server`][1], then an invalid handshake can cause the server to stop accepting connections.
/// See [`http-stream.rs`][2] or [`http-low-level`][3] examples, for examples of how to avoid this.
///
/// Note that if the maximum number of pending connections is greater than 1, the resulting
/// [`T::Stream`][4] connections may come in a different order than the connections produced by the
/// underlying listener.
///
/// [1]: https://docs.rs/hyper/latest/hyper/server/struct.Server.html
/// [2]: https://github.com/tmccombs/tls-listener/blob/main/examples/http-stream.rs
/// [3]: https://github.com/tmccombs/tls-listener/blob/main/examples/http-low-level.rs
/// [4]: AsyncTls::Stream
///
#[allow(clippy::type_complexity)]
pub struct TlsListener<A: AsyncAccept, T: AsyncTls<A::Connection>> {
#[pin]
listener: A,
tls: T,
waiting: JoinSet<Result<Result<T::Stream, T::Error>, tokio::time::error::Elapsed>>,
timeout: Duration,
}
}
/// Builder for `TlsListener`.
#[derive(Clone)]
pub struct Builder<T> {
tls: T,
handshake_timeout: Duration,
}
/// Wraps errors from either the listener or the TLS Acceptor
#[derive(Debug, Error)]
pub enum Error<LE: std::error::Error, TE: std::error::Error> {
/// An error that arose from the listener ([AsyncAccept::Error])
#[error("{0}")]
ListenerError(#[source] LE),
/// An error that occurred during the TLS accept handshake
#[error("{0}")]
TlsAcceptError(#[source] TE),
}
impl<A: AsyncAccept, T> TlsListener<A, T>
impl<A> Accept for TlsListener<A>
where
T: AsyncTls<A::Connection>,
{
/// Create a `TlsListener` with default options.
pub fn new(tls: T, listener: A) -> Self {
builder(tls).listen(listener)
}
}
impl<A, T> TlsListener<A, T>
where
A: AsyncAccept,
A: Accept<Conn = WithConnectionGuard<WithClientIp<AddrStream>>>,
A::Error: std::error::Error,
T: AsyncTls<A::Connection>,
A::Conn: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
/// Accept the next connection
///
/// This is essentially an alias to `self.next()` with a more domain-appropriate name.
pub async fn accept(&mut self) -> Option<<Self as Stream>::Item>
where
Self: Unpin,
{
self.next().await
}
type Conn = TlsStream<A::Conn>;
/// Replaces the Tls Acceptor configuration, which will be used for new connections.
///
/// This can be used to change the certificate used at runtime.
pub fn replace_acceptor(&mut self, acceptor: T) {
self.tls = acceptor;
}
type Error = Infallible;
/// Replaces the Tls Acceptor configuration from a pinned reference to `Self`.
///
/// This is useful if your listener is `!Unpin`.
///
/// This can be used to change the certificate used at runtime.
pub fn replace_acceptor_pin(self: Pin<&mut Self>, acceptor: T) {
*self.project().tls = acceptor;
}
}
impl<A, T> Stream for TlsListener<A, T>
where
A: AsyncAccept,
A::Error: std::error::Error,
T: AsyncTls<A::Connection>,
{
type Item = Result<T::Stream, Error<A::Error, T::Error>>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
fn poll_accept(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
let mut this = self.project();
loop {
match this.listener.as_mut().poll_accept(cx) {
Poll::Pending => break,
Poll::Ready(Some(Ok(conn))) => {
this.waiting
.spawn(timeout(*this.timeout, this.tls.accept(conn)));
Poll::Ready(Some(Ok(mut conn))) => {
let t = *this.timeout;
let tls = this.tls.clone();
let protocol = *this.protocol;
this.waiting.spawn(async move {
let peer_addr = match conn.inner.wait_for_addr().await {
Ok(Some(addr)) => addr,
Err(e) => {
tracing::error!("failed to accept TCP connection: invalid PROXY protocol V2 header: {e:#}");
return None;
}
Ok(None) => conn.inner.inner.remote_addr()
};
let accept = tls.accept(conn);
match timeout(t, accept).await {
Ok(Ok(conn)) => Some(conn),
// The handshake failed, try getting another connection from the queue
Ok(Err(e)) => {
TLS_HANDSHAKE_FAILURES.inc();
warn!(%peer_addr, protocol, "failed to accept TLS connection: {e:?}");
None
}
// The handshake timed out, try getting another connection from the queue
Err(_) => {
TLS_HANDSHAKE_FAILURES.inc();
warn!(%peer_addr, protocol, "failed to accept TLS connection: timeout");
None
}
}
});
}
Poll::Ready(Some(Err(e))) => {
return Poll::Ready(Some(Err(Error::ListenerError(e))));
tracing::error!("error accepting TCP connection: {e}");
continue;
}
Poll::Ready(None) => return Poll::Ready(None),
}
@@ -188,96 +112,19 @@ where
loop {
return match this.waiting.poll_join_next(cx) {
Poll::Ready(Some(Ok(Ok(conn)))) => {
Poll::Ready(Some(conn.map_err(Error::TlsAcceptError)))
Poll::Ready(Some(Ok(Some(conn)))) => {
info!(protocol = this.protocol, "accepted new TLS connection");
Poll::Ready(Some(Ok(conn)))
}
// The handshake timed out, try getting another connection from the queue
Poll::Ready(Some(Ok(Err(_)))) => continue,
// The handshake panicked
Poll::Ready(Some(Err(e))) if e.is_panic() => {
std::panic::resume_unwind(e.into_panic())
// The handshake failed to complete, try getting another connection from the queue
Poll::Ready(Some(Ok(None))) => continue,
// The handshake panicked or was cancelled. ignore and get another connection
Poll::Ready(Some(Err(e))) => {
tracing::warn!("handshake aborted: {e}");
continue;
}
// The handshake was externally aborted
Poll::Ready(Some(Err(_))) => unreachable!("handshake tasks are never aborted"),
_ => Poll::Pending,
};
}
}
}
impl<C: AsyncRead + AsyncWrite + Unpin + Send + 'static> AsyncTls<C> for tokio_rustls::TlsAcceptor {
type Stream = tokio_rustls::server::TlsStream<C>;
type Error = std::io::Error;
type AcceptFuture = tokio_rustls::Accept<C>;
fn accept(&self, conn: C) -> Self::AcceptFuture {
tokio_rustls::TlsAcceptor::accept(self, conn)
}
}
impl<T> Builder<T> {
/// Set the timeout for handshakes.
///
/// If a timeout takes longer than `timeout`, then the handshake will be
/// aborted and the underlying connection will be dropped.
///
/// Defaults to `DEFAULT_HANDSHAKE_TIMEOUT`.
pub fn handshake_timeout(&mut self, timeout: Duration) -> &mut Self {
self.handshake_timeout = timeout;
self
}
/// Create a `TlsListener` from the builder
///
/// Actually build the `TlsListener`. The `listener` argument should be
/// an implementation of the `AsyncAccept` trait that accepts new connections
/// that the `TlsListener` will encrypt using TLS.
pub fn listen<A: AsyncAccept>(&self, listener: A) -> TlsListener<A, T>
where
T: AsyncTls<A::Connection>,
{
TlsListener {
listener,
tls: self.tls.clone(),
waiting: JoinSet::new(),
timeout: self.handshake_timeout,
}
}
}
/// Create a new Builder for a TlsListener
///
/// `server_config` will be used to configure the TLS sessions.
pub fn builder<T>(tls: T) -> Builder<T> {
Builder {
tls,
handshake_timeout: DEFAULT_HANDSHAKE_TIMEOUT,
}
}
pin_project! {
/// See [`AsyncAccept::until`]
pub struct Until<A, E> {
#[pin]
acceptor: A,
#[pin]
ender: E,
}
}
impl<A: AsyncAccept, E: Future> AsyncAccept for Until<A, E> {
type Connection = A::Connection;
type Error = A::Error;
fn poll_accept(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Connection, Self::Error>>> {
let this = self.project();
match this.ender.poll(cx) {
Poll::Pending => this.acceptor.poll_accept(cx),
Poll::Ready(_) => Poll::Ready(None),
}
}
}