mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-05 20:42:54 +00:00
proxy refactor tls listener (#7056)
## Problem Now that we have tls-listener vendored, we can refactor and remove a lot of bloated code and make the whole flow a bit simpler ## Summary of changes 1. Remove dead code 2. Move the error handling to inside the `TlsListener` accept() function 3. Extract the peer_addr from the PROXY protocol header and log it with errors
This commit is contained in:
@@ -17,7 +17,7 @@ use pin_project_lite::pin_project;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{metrics::NUM_CLIENT_CONNECTION_GAUGE, serverless::tls_listener::AsyncAccept};
|
||||
use crate::metrics::NUM_CLIENT_CONNECTION_GAUGE;
|
||||
|
||||
pub struct ProxyProtocolAccept {
|
||||
pub incoming: AddrIncoming,
|
||||
@@ -331,15 +331,15 @@ impl<T: AsyncRead> AsyncRead for WithClientIp<T> {
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncAccept for ProxyProtocolAccept {
|
||||
type Connection = WithConnectionGuard<WithClientIp<AddrStream>>;
|
||||
impl Accept for ProxyProtocolAccept {
|
||||
type Conn = WithConnectionGuard<WithClientIp<AddrStream>>;
|
||||
|
||||
type Error = io::Error;
|
||||
|
||||
fn poll_accept(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Option<Result<Self::Connection, Self::Error>>> {
|
||||
) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
|
||||
let conn = ready!(Pin::new(&mut self.incoming).poll_accept(cx)?);
|
||||
tracing::info!(protocol = self.protocol, "accepted new TCP connection");
|
||||
let Some(conn) = conn else {
|
||||
|
||||
@@ -21,24 +21,19 @@ pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
|
||||
use tokio_util::task::TaskTracker;
|
||||
|
||||
use crate::context::RequestMonitoring;
|
||||
use crate::metrics::TLS_HANDSHAKE_FAILURES;
|
||||
use crate::protocol2::{ProxyProtocolAccept, WithClientIp, WithConnectionGuard};
|
||||
use crate::rate_limiter::EndpointRateLimiter;
|
||||
use crate::serverless::backend::PoolingBackend;
|
||||
use crate::{cancellation::CancellationHandler, config::ProxyConfig};
|
||||
use futures::StreamExt;
|
||||
use hyper::{
|
||||
server::{
|
||||
accept,
|
||||
conn::{AddrIncoming, AddrStream},
|
||||
},
|
||||
server::conn::{AddrIncoming, AddrStream},
|
||||
Body, Method, Request, Response,
|
||||
};
|
||||
|
||||
use std::convert::Infallible;
|
||||
use std::net::IpAddr;
|
||||
use std::sync::Arc;
|
||||
use std::task::Poll;
|
||||
use std::{future::ready, sync::Arc};
|
||||
use tls_listener::TlsListener;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
@@ -105,19 +100,12 @@ pub async fn task_main(
|
||||
let ws_connections = tokio_util::task::task_tracker::TaskTracker::new();
|
||||
ws_connections.close(); // allows `ws_connections.wait to complete`
|
||||
|
||||
let tls_listener = TlsListener::new(tls_acceptor, addr_incoming).filter(|conn| {
|
||||
if let Err(err) = conn {
|
||||
error!(
|
||||
protocol = "http",
|
||||
"failed to accept TLS connection: {err:?}"
|
||||
);
|
||||
TLS_HANDSHAKE_FAILURES.inc();
|
||||
ready(false)
|
||||
} else {
|
||||
info!(protocol = "http", "accepted new TLS connection");
|
||||
ready(true)
|
||||
}
|
||||
});
|
||||
let tls_listener = TlsListener::new(
|
||||
tls_acceptor,
|
||||
addr_incoming,
|
||||
"http",
|
||||
config.handshake_timeout,
|
||||
);
|
||||
|
||||
let make_svc = hyper::service::make_service_fn(
|
||||
|stream: &tokio_rustls::server::TlsStream<
|
||||
@@ -174,7 +162,7 @@ pub async fn task_main(
|
||||
},
|
||||
);
|
||||
|
||||
hyper::Server::builder(accept::from_stream(tls_listener))
|
||||
hyper::Server::builder(tls_listener)
|
||||
.serve(make_svc)
|
||||
.with_graceful_shutdown(cancellation_token.cancelled())
|
||||
.await?;
|
||||
|
||||
@@ -1,186 +1,110 @@
|
||||
use std::{
|
||||
convert::Infallible,
|
||||
pin::Pin,
|
||||
task::{Context, Poll},
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use futures::{Future, Stream, StreamExt};
|
||||
use hyper::server::{accept::Accept, conn::AddrStream};
|
||||
use pin_project_lite::pin_project;
|
||||
use thiserror::Error;
|
||||
use tokio::{
|
||||
io::{AsyncRead, AsyncWrite},
|
||||
task::JoinSet,
|
||||
time::timeout,
|
||||
};
|
||||
use tokio_rustls::{server::TlsStream, TlsAcceptor};
|
||||
use tracing::{info, warn};
|
||||
|
||||
/// Default timeout for the TLS handshake.
|
||||
pub const DEFAULT_HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10);
|
||||
use crate::{
|
||||
metrics::TLS_HANDSHAKE_FAILURES,
|
||||
protocol2::{WithClientIp, WithConnectionGuard},
|
||||
};
|
||||
|
||||
/// Trait for TLS implementation.
|
||||
///
|
||||
/// Implementations are provided by the rustls and native-tls features.
|
||||
pub trait AsyncTls<C: AsyncRead + AsyncWrite>: Clone {
|
||||
/// The type of the TLS stream created from the underlying stream.
|
||||
type Stream: Send + 'static;
|
||||
/// Error type for completing the TLS handshake
|
||||
type Error: std::error::Error + Send + 'static;
|
||||
/// Type of the Future for the TLS stream that is accepted.
|
||||
type AcceptFuture: Future<Output = Result<Self::Stream, Self::Error>> + Send + 'static;
|
||||
|
||||
/// Accept a TLS connection on an underlying stream
|
||||
fn accept(&self, stream: C) -> Self::AcceptFuture;
|
||||
pin_project! {
|
||||
/// Wraps a `Stream` of connections (such as a TCP listener) so that each connection is itself
|
||||
/// encrypted using TLS.
|
||||
pub(crate) struct TlsListener<A: Accept> {
|
||||
#[pin]
|
||||
listener: A,
|
||||
tls: TlsAcceptor,
|
||||
waiting: JoinSet<Option<TlsStream<A::Conn>>>,
|
||||
timeout: Duration,
|
||||
protocol: &'static str,
|
||||
}
|
||||
}
|
||||
|
||||
/// Asynchronously accept connections.
|
||||
pub trait AsyncAccept {
|
||||
/// The type of the connection that is accepted.
|
||||
type Connection: AsyncRead + AsyncWrite;
|
||||
/// The type of error that may be returned.
|
||||
type Error;
|
||||
|
||||
/// Poll to accept the next connection.
|
||||
fn poll_accept(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Option<Result<Self::Connection, Self::Error>>>;
|
||||
|
||||
/// Return a new `AsyncAccept` that stops accepting connections after
|
||||
/// `ender` completes.
|
||||
///
|
||||
/// Useful for graceful shutdown.
|
||||
///
|
||||
/// See [examples/echo.rs](https://github.com/tmccombs/tls-listener/blob/main/examples/echo.rs)
|
||||
/// for example of how to use.
|
||||
fn until<F: Future>(self, ender: F) -> Until<Self, F>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
Until {
|
||||
acceptor: self,
|
||||
ender,
|
||||
impl<A: Accept> TlsListener<A> {
|
||||
/// Create a `TlsListener` with default options.
|
||||
pub(crate) fn new(
|
||||
tls: TlsAcceptor,
|
||||
listener: A,
|
||||
protocol: &'static str,
|
||||
timeout: Duration,
|
||||
) -> Self {
|
||||
TlsListener {
|
||||
listener,
|
||||
tls,
|
||||
waiting: JoinSet::new(),
|
||||
timeout,
|
||||
protocol,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pin_project! {
|
||||
///
|
||||
/// Wraps a `Stream` of connections (such as a TCP listener) so that each connection is itself
|
||||
/// encrypted using TLS.
|
||||
///
|
||||
/// It is similar to:
|
||||
///
|
||||
/// ```ignore
|
||||
/// tcpListener.and_then(|s| tlsAcceptor.accept(s))
|
||||
/// ```
|
||||
///
|
||||
/// except that it has the ability to accept multiple transport-level connections
|
||||
/// simultaneously while the TLS handshake is pending for other connections.
|
||||
///
|
||||
/// By default, if a client fails the TLS handshake, that is treated as an error, and the
|
||||
/// `TlsListener` will return an `Err`. If the `TlsListener` is passed directly to a hyper
|
||||
/// [`Server`][1], then an invalid handshake can cause the server to stop accepting connections.
|
||||
/// See [`http-stream.rs`][2] or [`http-low-level`][3] examples, for examples of how to avoid this.
|
||||
///
|
||||
/// Note that if the maximum number of pending connections is greater than 1, the resulting
|
||||
/// [`T::Stream`][4] connections may come in a different order than the connections produced by the
|
||||
/// underlying listener.
|
||||
///
|
||||
/// [1]: https://docs.rs/hyper/latest/hyper/server/struct.Server.html
|
||||
/// [2]: https://github.com/tmccombs/tls-listener/blob/main/examples/http-stream.rs
|
||||
/// [3]: https://github.com/tmccombs/tls-listener/blob/main/examples/http-low-level.rs
|
||||
/// [4]: AsyncTls::Stream
|
||||
///
|
||||
#[allow(clippy::type_complexity)]
|
||||
pub struct TlsListener<A: AsyncAccept, T: AsyncTls<A::Connection>> {
|
||||
#[pin]
|
||||
listener: A,
|
||||
tls: T,
|
||||
waiting: JoinSet<Result<Result<T::Stream, T::Error>, tokio::time::error::Elapsed>>,
|
||||
timeout: Duration,
|
||||
}
|
||||
}
|
||||
|
||||
/// Builder for `TlsListener`.
|
||||
#[derive(Clone)]
|
||||
pub struct Builder<T> {
|
||||
tls: T,
|
||||
handshake_timeout: Duration,
|
||||
}
|
||||
|
||||
/// Wraps errors from either the listener or the TLS Acceptor
|
||||
#[derive(Debug, Error)]
|
||||
pub enum Error<LE: std::error::Error, TE: std::error::Error> {
|
||||
/// An error that arose from the listener ([AsyncAccept::Error])
|
||||
#[error("{0}")]
|
||||
ListenerError(#[source] LE),
|
||||
/// An error that occurred during the TLS accept handshake
|
||||
#[error("{0}")]
|
||||
TlsAcceptError(#[source] TE),
|
||||
}
|
||||
|
||||
impl<A: AsyncAccept, T> TlsListener<A, T>
|
||||
impl<A> Accept for TlsListener<A>
|
||||
where
|
||||
T: AsyncTls<A::Connection>,
|
||||
{
|
||||
/// Create a `TlsListener` with default options.
|
||||
pub fn new(tls: T, listener: A) -> Self {
|
||||
builder(tls).listen(listener)
|
||||
}
|
||||
}
|
||||
|
||||
impl<A, T> TlsListener<A, T>
|
||||
where
|
||||
A: AsyncAccept,
|
||||
A: Accept<Conn = WithConnectionGuard<WithClientIp<AddrStream>>>,
|
||||
A::Error: std::error::Error,
|
||||
T: AsyncTls<A::Connection>,
|
||||
A::Conn: AsyncRead + AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
/// Accept the next connection
|
||||
///
|
||||
/// This is essentially an alias to `self.next()` with a more domain-appropriate name.
|
||||
pub async fn accept(&mut self) -> Option<<Self as Stream>::Item>
|
||||
where
|
||||
Self: Unpin,
|
||||
{
|
||||
self.next().await
|
||||
}
|
||||
type Conn = TlsStream<A::Conn>;
|
||||
|
||||
/// Replaces the Tls Acceptor configuration, which will be used for new connections.
|
||||
///
|
||||
/// This can be used to change the certificate used at runtime.
|
||||
pub fn replace_acceptor(&mut self, acceptor: T) {
|
||||
self.tls = acceptor;
|
||||
}
|
||||
type Error = Infallible;
|
||||
|
||||
/// Replaces the Tls Acceptor configuration from a pinned reference to `Self`.
|
||||
///
|
||||
/// This is useful if your listener is `!Unpin`.
|
||||
///
|
||||
/// This can be used to change the certificate used at runtime.
|
||||
pub fn replace_acceptor_pin(self: Pin<&mut Self>, acceptor: T) {
|
||||
*self.project().tls = acceptor;
|
||||
}
|
||||
}
|
||||
|
||||
impl<A, T> Stream for TlsListener<A, T>
|
||||
where
|
||||
A: AsyncAccept,
|
||||
A::Error: std::error::Error,
|
||||
T: AsyncTls<A::Connection>,
|
||||
{
|
||||
type Item = Result<T::Stream, Error<A::Error, T::Error>>;
|
||||
|
||||
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
fn poll_accept(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
|
||||
let mut this = self.project();
|
||||
|
||||
loop {
|
||||
match this.listener.as_mut().poll_accept(cx) {
|
||||
Poll::Pending => break,
|
||||
Poll::Ready(Some(Ok(conn))) => {
|
||||
this.waiting
|
||||
.spawn(timeout(*this.timeout, this.tls.accept(conn)));
|
||||
Poll::Ready(Some(Ok(mut conn))) => {
|
||||
let t = *this.timeout;
|
||||
let tls = this.tls.clone();
|
||||
let protocol = *this.protocol;
|
||||
this.waiting.spawn(async move {
|
||||
let peer_addr = match conn.inner.wait_for_addr().await {
|
||||
Ok(Some(addr)) => addr,
|
||||
Err(e) => {
|
||||
tracing::error!("failed to accept TCP connection: invalid PROXY protocol V2 header: {e:#}");
|
||||
return None;
|
||||
}
|
||||
Ok(None) => conn.inner.inner.remote_addr()
|
||||
};
|
||||
|
||||
let accept = tls.accept(conn);
|
||||
match timeout(t, accept).await {
|
||||
Ok(Ok(conn)) => Some(conn),
|
||||
// The handshake failed, try getting another connection from the queue
|
||||
Ok(Err(e)) => {
|
||||
TLS_HANDSHAKE_FAILURES.inc();
|
||||
warn!(%peer_addr, protocol, "failed to accept TLS connection: {e:?}");
|
||||
None
|
||||
}
|
||||
// The handshake timed out, try getting another connection from the queue
|
||||
Err(_) => {
|
||||
TLS_HANDSHAKE_FAILURES.inc();
|
||||
warn!(%peer_addr, protocol, "failed to accept TLS connection: timeout");
|
||||
None
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
Poll::Ready(Some(Err(e))) => {
|
||||
return Poll::Ready(Some(Err(Error::ListenerError(e))));
|
||||
tracing::error!("error accepting TCP connection: {e}");
|
||||
continue;
|
||||
}
|
||||
Poll::Ready(None) => return Poll::Ready(None),
|
||||
}
|
||||
@@ -188,96 +112,19 @@ where
|
||||
|
||||
loop {
|
||||
return match this.waiting.poll_join_next(cx) {
|
||||
Poll::Ready(Some(Ok(Ok(conn)))) => {
|
||||
Poll::Ready(Some(conn.map_err(Error::TlsAcceptError)))
|
||||
Poll::Ready(Some(Ok(Some(conn)))) => {
|
||||
info!(protocol = this.protocol, "accepted new TLS connection");
|
||||
Poll::Ready(Some(Ok(conn)))
|
||||
}
|
||||
// The handshake timed out, try getting another connection from the queue
|
||||
Poll::Ready(Some(Ok(Err(_)))) => continue,
|
||||
// The handshake panicked
|
||||
Poll::Ready(Some(Err(e))) if e.is_panic() => {
|
||||
std::panic::resume_unwind(e.into_panic())
|
||||
// The handshake failed to complete, try getting another connection from the queue
|
||||
Poll::Ready(Some(Ok(None))) => continue,
|
||||
// The handshake panicked or was cancelled. ignore and get another connection
|
||||
Poll::Ready(Some(Err(e))) => {
|
||||
tracing::warn!("handshake aborted: {e}");
|
||||
continue;
|
||||
}
|
||||
// The handshake was externally aborted
|
||||
Poll::Ready(Some(Err(_))) => unreachable!("handshake tasks are never aborted"),
|
||||
_ => Poll::Pending,
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: AsyncRead + AsyncWrite + Unpin + Send + 'static> AsyncTls<C> for tokio_rustls::TlsAcceptor {
|
||||
type Stream = tokio_rustls::server::TlsStream<C>;
|
||||
type Error = std::io::Error;
|
||||
type AcceptFuture = tokio_rustls::Accept<C>;
|
||||
|
||||
fn accept(&self, conn: C) -> Self::AcceptFuture {
|
||||
tokio_rustls::TlsAcceptor::accept(self, conn)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Builder<T> {
|
||||
/// Set the timeout for handshakes.
|
||||
///
|
||||
/// If a timeout takes longer than `timeout`, then the handshake will be
|
||||
/// aborted and the underlying connection will be dropped.
|
||||
///
|
||||
/// Defaults to `DEFAULT_HANDSHAKE_TIMEOUT`.
|
||||
pub fn handshake_timeout(&mut self, timeout: Duration) -> &mut Self {
|
||||
self.handshake_timeout = timeout;
|
||||
self
|
||||
}
|
||||
|
||||
/// Create a `TlsListener` from the builder
|
||||
///
|
||||
/// Actually build the `TlsListener`. The `listener` argument should be
|
||||
/// an implementation of the `AsyncAccept` trait that accepts new connections
|
||||
/// that the `TlsListener` will encrypt using TLS.
|
||||
pub fn listen<A: AsyncAccept>(&self, listener: A) -> TlsListener<A, T>
|
||||
where
|
||||
T: AsyncTls<A::Connection>,
|
||||
{
|
||||
TlsListener {
|
||||
listener,
|
||||
tls: self.tls.clone(),
|
||||
waiting: JoinSet::new(),
|
||||
timeout: self.handshake_timeout,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new Builder for a TlsListener
|
||||
///
|
||||
/// `server_config` will be used to configure the TLS sessions.
|
||||
pub fn builder<T>(tls: T) -> Builder<T> {
|
||||
Builder {
|
||||
tls,
|
||||
handshake_timeout: DEFAULT_HANDSHAKE_TIMEOUT,
|
||||
}
|
||||
}
|
||||
|
||||
pin_project! {
|
||||
/// See [`AsyncAccept::until`]
|
||||
pub struct Until<A, E> {
|
||||
#[pin]
|
||||
acceptor: A,
|
||||
#[pin]
|
||||
ender: E,
|
||||
}
|
||||
}
|
||||
|
||||
impl<A: AsyncAccept, E: Future> AsyncAccept for Until<A, E> {
|
||||
type Connection = A::Connection;
|
||||
type Error = A::Error;
|
||||
|
||||
fn poll_accept(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Option<Result<Self::Connection, Self::Error>>> {
|
||||
let this = self.project();
|
||||
|
||||
match this.ender.poll(cx) {
|
||||
Poll::Pending => this.acceptor.poll_accept(cx),
|
||||
Poll::Ready(_) => Poll::Ready(None),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user