Compare commits

...

1 Commits

Author SHA1 Message Date
Folke Behrens
4f88c4b8f3 proxy: introduce Acceptor and Connector traits 2025-01-03 15:53:39 +01:00
10 changed files with 254 additions and 47 deletions

View File

@@ -16,6 +16,7 @@ use proxy::cancellation::CancellationHandlerMain;
use proxy::config::{
self, AuthenticationConfig, ComputeConfig, HttpConfig, ProxyConfig, RetryConfig,
};
use proxy::conn::TokioTcpAcceptor;
use proxy::control_plane::locks::ApiLocks;
use proxy::control_plane::messages::{EndpointJwksResponse, JwksSettings};
use proxy::http::health_server::AppMetrics;
@@ -36,7 +37,6 @@ project_build_tag!(BUILD_TAG);
use clap::Parser;
use thiserror::Error;
use tokio::net::TcpListener;
use tokio::sync::Notify;
use tokio::task::JoinSet;
use tokio_util::sync::CancellationToken;
@@ -166,8 +166,8 @@ async fn main() -> anyhow::Result<()> {
}
};
let metrics_listener = TcpListener::bind(args.metrics).await?.into_std()?;
let http_listener = TcpListener::bind(args.http).await?;
let metrics_listener = TokioTcpAcceptor::bind(args.metrics).await?;
let http_listener = TokioTcpAcceptor::bind(args.http).await?;
let shutdown = CancellationToken::new();
// todo: should scale with CU

View File

@@ -10,6 +10,7 @@ use clap::Arg;
use futures::future::Either;
use futures::TryFutureExt;
use itertools::Itertools;
use proxy::conn::{Acceptor, TokioTcpAcceptor};
use proxy::context::RequestContext;
use proxy::metrics::{Metrics, ThreadPoolMetrics};
use proxy::protocol2::ConnectionInfo;
@@ -122,7 +123,7 @@ async fn main() -> anyhow::Result<()> {
// Start listening for incoming client connections
let proxy_address: SocketAddr = args.get_one::<String>("listen").unwrap().parse()?;
info!("Starting sni router on {proxy_address}");
let proxy_listener = TcpListener::bind(proxy_address).await?;
let proxy_listener = TokioTcpAcceptor::bind(proxy_address).await?;
let cancellation_token = CancellationToken::new();
@@ -152,17 +153,13 @@ async fn task_main(
dest_suffix: Arc<String>,
tls_config: Arc<rustls::ServerConfig>,
tls_server_end_point: TlsServerEndPoint,
listener: tokio::net::TcpListener,
acceptor: TokioTcpAcceptor,
cancellation_token: CancellationToken,
) -> anyhow::Result<()> {
// When set for the server socket, the keepalive setting
// will be inherited by all accepted client sockets.
socket2::SockRef::from(&listener).set_keepalive(true)?;
let connections = tokio_util::task::task_tracker::TaskTracker::new();
while let Some(accept_result) =
run_until_cancelled(listener.accept(), &cancellation_token).await
run_until_cancelled(acceptor.accept(), &cancellation_token).await
{
let (socket, peer_addr) = accept_result?;
@@ -172,10 +169,6 @@ async fn task_main(
connections.spawn(
async move {
socket
.set_nodelay(true)
.context("failed to set socket option")?;
info!(%peer_addr, "serving");
let ctx = RequestContext::new(
session_id,
@@ -197,7 +190,7 @@ async fn task_main(
}
connections.close();
drop(listener);
drop(acceptor);
connections.wait().await;

View File

@@ -12,6 +12,7 @@ use proxy::config::{
self, remote_storage_from_toml, AuthenticationConfig, CacheOptions, ComputeConfig, HttpConfig,
ProjectInfoCacheOptions, ProxyConfig, ProxyProtocolV2,
};
use proxy::conn::TokioTcpAcceptor;
use proxy::context::parquet::ParquetUploadArgs;
use proxy::http::health_server::AppMetrics;
use proxy::metrics::Metrics;
@@ -27,7 +28,6 @@ use proxy::serverless::GlobalConnPoolOptions;
use proxy::tls::client_config::compute_client_config_with_root_certs;
use proxy::{auth, control_plane, http, serverless, usage_metrics};
use remote_storage::RemoteStorageConfig;
use tokio::net::TcpListener;
use tokio::sync::Mutex;
use tokio::task::JoinSet;
use tokio_util::sync::CancellationToken;
@@ -353,17 +353,17 @@ async fn main() -> anyhow::Result<()> {
// Check that we can bind to address before further initialization
let http_address: SocketAddr = args.http.parse()?;
info!("Starting http on {http_address}");
let http_listener = TcpListener::bind(http_address).await?.into_std()?;
let http_listener = TokioTcpAcceptor::bind(http_address).await?;
let mgmt_address: SocketAddr = args.mgmt.parse()?;
info!("Starting mgmt on {mgmt_address}");
let mgmt_listener = TcpListener::bind(mgmt_address).await?;
let mgmt_listener = TokioTcpAcceptor::bind(mgmt_address).await?;
let proxy_listener = if !args.is_auth_broker {
let proxy_address: SocketAddr = args.proxy.parse()?;
info!("Starting proxy on {proxy_address}");
Some(TcpListener::bind(proxy_address).await?)
Some(TokioTcpAcceptor::bind(proxy_address).await?)
} else {
None
};
@@ -373,7 +373,7 @@ async fn main() -> anyhow::Result<()> {
let serverless_listener = if let Some(serverless_address) = args.wss {
let serverless_address: SocketAddr = serverless_address.parse()?;
info!("Starting wss on {serverless_address}");
Some(TcpListener::bind(serverless_address).await?)
Some(TokioTcpAcceptor::bind(serverless_address).await?)
} else if args.is_auth_broker {
bail!("wss arg must be present for auth-broker")
} else {

221
proxy/src/conn.rs Normal file
View File

@@ -0,0 +1,221 @@
use std::future::{poll_fn, Future};
use std::io;
use std::net::SocketAddr;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
pub trait Acceptor {
type Connection: AsyncRead + AsyncWrite + Send + Unpin + 'static;
type Error: std::error::Error + Send + Sync + 'static;
#[inline]
fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
let _ = cx;
Poll::Ready(Ok(()))
}
fn accept(
&self,
) -> impl Future<Output = Result<(Self::Connection, SocketAddr), Self::Error>> + Send;
}
pub trait Connector {
type Connection: AsyncRead + AsyncWrite + Send + Unpin + 'static;
type Error: std::error::Error + Send + Sync + 'static;
#[inline]
fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
let _ = cx;
Poll::Ready(Ok(()))
}
fn connect(
&self,
addr: SocketAddr,
) -> impl Future<Output = Result<Self::Connection, Self::Error>> + Send;
}
pub struct TokioTcpAcceptor {
listener: TcpListener,
tcp_nodelay: Option<bool>,
tcp_keepalive: Option<bool>,
}
impl TokioTcpAcceptor {
pub async fn bind<A: ToSocketAddrs>(addr: A) -> io::Result<Self> {
let listener = TcpListener::bind(addr).await?;
// When set for the server socket, the keepalive setting
// will be inherited by all accepted client sockets.
socket2::SockRef::from(&listener).set_keepalive(true)?;
Ok(Self {
listener,
tcp_nodelay: Some(true),
tcp_keepalive: None,
})
}
pub fn into_std(self) -> io::Result<std::net::TcpListener> {
self.listener.into_std()
}
}
impl Acceptor for TokioTcpAcceptor {
type Connection = TcpStream;
type Error = io::Error;
fn accept(&self) -> impl Future<Output = Result<(Self::Connection, SocketAddr), Self::Error>> {
async move {
let (stream, addr) = self.listener.accept().await?;
let socket = socket2::SockRef::from(&stream);
if let Some(nodelay) = self.tcp_nodelay {
socket.set_nodelay(nodelay)?;
}
if let Some(keepalive) = self.tcp_keepalive {
socket.set_keepalive(keepalive)?;
}
Ok((stream, addr))
}
}
}
pub struct TokioTcpConnector;
impl Connector for TokioTcpConnector {
type Connection = TcpStream;
type Error = io::Error;
fn connect(
&self,
addr: SocketAddr,
) -> impl Future<Output = Result<Self::Connection, Self::Error>> {
async move {
let socket = TcpStream::connect(addr).await?;
socket.set_nodelay(true)?;
Ok(socket)
}
}
}
pub trait Stream: AsyncRead + AsyncWrite + Send + Unpin + 'static {}
impl<T: AsyncRead + AsyncWrite + Send + Unpin + 'static> Stream for T {}
pub trait AsyncRead {
fn readable(&self) -> impl Future<Output = io::Result<()>> + Send
where
Self: Send + Sync,
{
poll_fn(move |cx| self.poll_read_ready(cx))
}
fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>>;
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>>;
fn poll_read_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &mut [io::IoSliceMut<'_>],
) -> Poll<io::Result<usize>>;
}
pub trait AsyncWrite {
fn writable(&self) -> impl Future<Output = io::Result<()>> + Send
where
Self: Send + Sync,
{
poll_fn(move |cx| self.poll_write_ready(cx))
}
fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>>;
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>>;
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[io::IoSlice<'_>],
) -> Poll<io::Result<usize>>;
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>;
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>;
}
impl AsyncRead for tokio::net::TcpStream {
fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
tokio::net::TcpStream::poll_read_ready(self, cx)
}
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
match tokio::net::TcpStream::try_read(Pin::new(&mut *self).get_mut(), buf) {
Ok(n) => Poll::Ready(Ok(n)),
Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
cx.waker().wake_by_ref();
Poll::Pending
}
Err(e) => Poll::Ready(Err(e)),
}
}
fn poll_read_vectored(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &mut [io::IoSliceMut<'_>],
) -> Poll<io::Result<usize>> {
match tokio::net::TcpStream::try_read_vectored(Pin::new(&mut *self).get_mut(), bufs) {
Ok(n) => Poll::Ready(Ok(n)),
Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
cx.waker().wake_by_ref();
Poll::Pending
}
Err(e) => Poll::Ready(Err(e)),
}
}
}
impl AsyncWrite for tokio::net::TcpStream {
fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
tokio::net::TcpStream::poll_write_ready(self, cx)
}
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
<Self as tokio::io::AsyncWrite>::poll_write(self, cx, buf)
}
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[io::IoSlice<'_>],
) -> Poll<io::Result<usize>> {
<Self as tokio::io::AsyncWrite>::poll_write_vectored(self, cx, bufs)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
<Self as tokio::io::AsyncWrite>::poll_flush(self, cx)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
<Self as tokio::io::AsyncWrite>::poll_shutdown(self, cx)
}
}

View File

@@ -8,6 +8,7 @@ use tracing::{debug, error, info, Instrument};
use crate::auth::backend::ConsoleRedirectBackend;
use crate::cancellation::{CancellationHandlerMain, CancellationHandlerMainInternal};
use crate::config::{ProxyConfig, ProxyProtocolV2};
use crate::conn::{Acceptor, TokioTcpAcceptor};
use crate::context::RequestContext;
use crate::error::ReportableError;
use crate::metrics::{Metrics, NumClientConnectionsGuard};
@@ -22,7 +23,7 @@ use crate::proxy::{
pub async fn task_main(
config: &'static ProxyConfig,
backend: &'static ConsoleRedirectBackend,
listener: tokio::net::TcpListener,
acceptor: TokioTcpAcceptor,
cancellation_token: CancellationToken,
cancellation_handler: Arc<CancellationHandlerMain>,
) -> anyhow::Result<()> {
@@ -30,15 +31,11 @@ pub async fn task_main(
info!("proxy has shut down");
}
// When set for the server socket, the keepalive setting
// will be inherited by all accepted client sockets.
socket2::SockRef::from(&listener).set_keepalive(true)?;
let connections = tokio_util::task::task_tracker::TaskTracker::new();
let cancellations = tokio_util::task::task_tracker::TaskTracker::new();
while let Some(accept_result) =
run_until_cancelled(listener.accept(), &cancellation_token).await
run_until_cancelled(acceptor.accept(), &cancellation_token).await
{
let (socket, peer_addr) = accept_result?;
@@ -131,7 +128,7 @@ pub async fn task_main(
connections.close();
cancellations.close();
drop(listener);
drop(acceptor);
// Drain connections
connections.wait().await;

View File

@@ -4,10 +4,11 @@ use anyhow::Context;
use once_cell::sync::Lazy;
use postgres_backend::{AuthType, PostgresBackend, PostgresBackendTCP, QueryError};
use pq_proto::{BeMessage, SINGLE_COL_ROWDESC};
use tokio::net::{TcpListener, TcpStream};
use tokio::net::TcpStream;
use tokio_util::sync::CancellationToken;
use tracing::{error, info, info_span, Instrument};
use crate::conn::{Acceptor, TokioTcpAcceptor};
use crate::control_plane::messages::{DatabaseInfo, KickSession};
use crate::waiters::{self, Waiter, Waiters};
@@ -26,19 +27,15 @@ pub(crate) fn notify(psql_session_id: &str, msg: ComputeReady) -> Result<(), wai
/// Management API listener task.
/// It spawns management response handlers needed for the console redirect auth flow.
pub async fn task_main(listener: TcpListener) -> anyhow::Result<Infallible> {
pub async fn task_main(acceptor: TokioTcpAcceptor) -> anyhow::Result<Infallible> {
scopeguard::defer! {
info!("mgmt has shut down");
}
loop {
let (socket, peer_addr) = listener.accept().await?;
let (socket, peer_addr) = acceptor.accept().await?;
info!("accepted connection from {peer_addr}");
socket
.set_nodelay(true)
.context("failed to set client socket option")?;
let span = info_span!("mgmt", peer = %peer_addr);
tokio::task::spawn(

View File

@@ -1,5 +1,4 @@
use std::convert::Infallible;
use std::net::TcpListener;
use std::sync::{Arc, Mutex};
use anyhow::{anyhow, bail};
@@ -14,6 +13,7 @@ use utils::http::error::ApiError;
use utils::http::json::json_response;
use utils::http::{RouterBuilder, RouterService};
use crate::conn::TokioTcpAcceptor;
use crate::ext::{LockExt, TaskExt};
use crate::jemalloc;
@@ -36,7 +36,7 @@ fn make_router(metrics: AppMetrics) -> RouterBuilder<hyper0::Body, ApiError> {
}
pub async fn task_main(
http_listener: TcpListener,
http_acceptor: TokioTcpAcceptor,
metrics: AppMetrics,
) -> anyhow::Result<Infallible> {
scopeguard::defer! {
@@ -45,7 +45,7 @@ pub async fn task_main(
let service = || RouterService::new(make_router(metrics).build()?);
hyper0::Server::from_tcp(http_listener)?
hyper0::Server::from_tcp(http_acceptor.into_std()?)?
.serve(service().map_err(|e| anyhow!(e))?)
.await?;

View File

@@ -78,6 +78,7 @@ pub mod cancellation;
pub mod compute;
pub mod compute_ctl;
pub mod config;
pub mod conn;
pub mod console_redirect_proxy;
pub mod context;
pub mod control_plane;

View File

@@ -25,6 +25,7 @@ use self::connect_compute::{connect_to_compute, TcpMechanism};
use self::passthrough::ProxyPassthrough;
use crate::cancellation::{self, CancellationHandlerMain, CancellationHandlerMainInternal};
use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig};
use crate::conn::{Acceptor, TokioTcpAcceptor};
use crate::context::RequestContext;
use crate::error::ReportableError;
use crate::metrics::{Metrics, NumClientConnectionsGuard};
@@ -55,7 +56,7 @@ pub async fn run_until_cancelled<F: std::future::Future>(
pub async fn task_main(
config: &'static ProxyConfig,
auth_backend: &'static auth::Backend<'static, ()>,
listener: tokio::net::TcpListener,
acceptor: TokioTcpAcceptor,
cancellation_token: CancellationToken,
cancellation_handler: Arc<CancellationHandlerMain>,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
@@ -64,15 +65,11 @@ pub async fn task_main(
info!("proxy has shut down");
}
// When set for the server socket, the keepalive setting
// will be inherited by all accepted client sockets.
socket2::SockRef::from(&listener).set_keepalive(true)?;
let connections = tokio_util::task::task_tracker::TaskTracker::new();
let cancellations = tokio_util::task::task_tracker::TaskTracker::new();
while let Some(accept_result) =
run_until_cancelled(listener.accept(), &cancellation_token).await
run_until_cancelled(acceptor.accept(), &cancellation_token).await
{
let (socket, peer_addr) = accept_result?;
@@ -168,7 +165,7 @@ pub async fn task_main(
connections.close();
cancellations.close();
drop(listener);
drop(acceptor);
// Drain connections
connections.wait().await;

View File

@@ -35,7 +35,7 @@ use rand::rngs::StdRng;
use rand::SeedableRng;
use sql_over_http::{uuid_to_header_value, NEON_REQUEST_ID};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::{TcpListener, TcpStream};
use tokio::net::TcpStream;
use tokio::time::timeout;
use tokio_rustls::TlsAcceptor;
use tokio_util::sync::CancellationToken;
@@ -45,6 +45,7 @@ use utils::http::error::ApiError;
use crate::cancellation::CancellationHandlerMain;
use crate::config::{ProxyConfig, ProxyProtocolV2};
use crate::conn::{Acceptor, TokioTcpAcceptor};
use crate::context::RequestContext;
use crate::ext::TaskExt;
use crate::metrics::Metrics;
@@ -59,7 +60,7 @@ pub(crate) const SERVERLESS_DRIVER_SNI: &str = "api";
pub async fn task_main(
config: &'static ProxyConfig,
auth_backend: &'static crate::auth::Backend<'static, ()>,
ws_listener: TcpListener,
ws_acceptor: TokioTcpAcceptor,
cancellation_token: CancellationToken,
cancellation_handler: Arc<CancellationHandlerMain>,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
@@ -134,7 +135,7 @@ pub async fn task_main(
connections.close(); // allows `connections.wait to complete`
let cancellations = tokio_util::task::task_tracker::TaskTracker::new();
while let Some(res) = run_until_cancelled(ws_listener.accept(), &cancellation_token).await {
while let Some(res) = run_until_cancelled(ws_acceptor.accept(), &cancellation_token).await {
let (conn, peer_addr) = res.context("could not accept TCP stream")?;
if let Err(e) = conn.set_nodelay(true) {
tracing::error!("could not set nodelay: {e}");