diff --git a/proxy/src/bin/local_proxy.rs b/proxy/src/bin/local_proxy.rs index 644f670f88..e59d852d70 100644 --- a/proxy/src/bin/local_proxy.rs +++ b/proxy/src/bin/local_proxy.rs @@ -16,6 +16,7 @@ use proxy::cancellation::CancellationHandlerMain; use proxy::config::{ self, AuthenticationConfig, ComputeConfig, HttpConfig, ProxyConfig, RetryConfig, }; +use proxy::conn::TokioTcpAcceptor; use proxy::control_plane::locks::ApiLocks; use proxy::control_plane::messages::{EndpointJwksResponse, JwksSettings}; use proxy::http::health_server::AppMetrics; @@ -36,7 +37,6 @@ project_build_tag!(BUILD_TAG); use clap::Parser; use thiserror::Error; -use tokio::net::TcpListener; use tokio::sync::Notify; use tokio::task::JoinSet; use tokio_util::sync::CancellationToken; @@ -166,8 +166,8 @@ async fn main() -> anyhow::Result<()> { } }; - let metrics_listener = TcpListener::bind(args.metrics).await?.into_std()?; - let http_listener = TcpListener::bind(args.http).await?; + let metrics_listener = TokioTcpAcceptor::bind(args.metrics).await?; + let http_listener = TokioTcpAcceptor::bind(args.http).await?; let shutdown = CancellationToken::new(); // todo: should scale with CU diff --git a/proxy/src/bin/pg_sni_router.rs b/proxy/src/bin/pg_sni_router.rs index 97d870a83a..121ee0d85e 100644 --- a/proxy/src/bin/pg_sni_router.rs +++ b/proxy/src/bin/pg_sni_router.rs @@ -10,6 +10,7 @@ use clap::Arg; use futures::future::Either; use futures::TryFutureExt; use itertools::Itertools; +use proxy::conn::{Acceptor, TokioTcpAcceptor}; use proxy::context::RequestContext; use proxy::metrics::{Metrics, ThreadPoolMetrics}; use proxy::protocol2::ConnectionInfo; @@ -122,7 +123,7 @@ async fn main() -> anyhow::Result<()> { // Start listening for incoming client connections let proxy_address: SocketAddr = args.get_one::("listen").unwrap().parse()?; info!("Starting sni router on {proxy_address}"); - let proxy_listener = TcpListener::bind(proxy_address).await?; + let proxy_listener = TokioTcpAcceptor::bind(proxy_address).await?; let cancellation_token = CancellationToken::new(); @@ -152,17 +153,13 @@ async fn task_main( dest_suffix: Arc, tls_config: Arc, tls_server_end_point: TlsServerEndPoint, - listener: tokio::net::TcpListener, + acceptor: TokioTcpAcceptor, cancellation_token: CancellationToken, ) -> anyhow::Result<()> { - // When set for the server socket, the keepalive setting - // will be inherited by all accepted client sockets. - socket2::SockRef::from(&listener).set_keepalive(true)?; - let connections = tokio_util::task::task_tracker::TaskTracker::new(); while let Some(accept_result) = - run_until_cancelled(listener.accept(), &cancellation_token).await + run_until_cancelled(acceptor.accept(), &cancellation_token).await { let (socket, peer_addr) = accept_result?; @@ -172,10 +169,6 @@ async fn task_main( connections.spawn( async move { - socket - .set_nodelay(true) - .context("failed to set socket option")?; - info!(%peer_addr, "serving"); let ctx = RequestContext::new( session_id, @@ -197,7 +190,7 @@ async fn task_main( } connections.close(); - drop(listener); + drop(acceptor); connections.wait().await; diff --git a/proxy/src/bin/proxy.rs b/proxy/src/bin/proxy.rs index 3b122d771c..99dcae5223 100644 --- a/proxy/src/bin/proxy.rs +++ b/proxy/src/bin/proxy.rs @@ -12,6 +12,7 @@ use proxy::config::{ self, remote_storage_from_toml, AuthenticationConfig, CacheOptions, ComputeConfig, HttpConfig, ProjectInfoCacheOptions, ProxyConfig, ProxyProtocolV2, }; +use proxy::conn::TokioTcpAcceptor; use proxy::context::parquet::ParquetUploadArgs; use proxy::http::health_server::AppMetrics; use proxy::metrics::Metrics; @@ -27,7 +28,6 @@ use proxy::serverless::GlobalConnPoolOptions; use proxy::tls::client_config::compute_client_config_with_root_certs; use proxy::{auth, control_plane, http, serverless, usage_metrics}; use remote_storage::RemoteStorageConfig; -use tokio::net::TcpListener; use tokio::sync::Mutex; use tokio::task::JoinSet; use tokio_util::sync::CancellationToken; @@ -353,17 +353,17 @@ async fn main() -> anyhow::Result<()> { // Check that we can bind to address before further initialization let http_address: SocketAddr = args.http.parse()?; info!("Starting http on {http_address}"); - let http_listener = TcpListener::bind(http_address).await?.into_std()?; + let http_listener = TokioTcpAcceptor::bind(http_address).await?; let mgmt_address: SocketAddr = args.mgmt.parse()?; info!("Starting mgmt on {mgmt_address}"); - let mgmt_listener = TcpListener::bind(mgmt_address).await?; + let mgmt_listener = TokioTcpAcceptor::bind(mgmt_address).await?; let proxy_listener = if !args.is_auth_broker { let proxy_address: SocketAddr = args.proxy.parse()?; info!("Starting proxy on {proxy_address}"); - Some(TcpListener::bind(proxy_address).await?) + Some(TokioTcpAcceptor::bind(proxy_address).await?) } else { None }; @@ -373,7 +373,7 @@ async fn main() -> anyhow::Result<()> { let serverless_listener = if let Some(serverless_address) = args.wss { let serverless_address: SocketAddr = serverless_address.parse()?; info!("Starting wss on {serverless_address}"); - Some(TcpListener::bind(serverless_address).await?) + Some(TokioTcpAcceptor::bind(serverless_address).await?) } else if args.is_auth_broker { bail!("wss arg must be present for auth-broker") } else { diff --git a/proxy/src/conn.rs b/proxy/src/conn.rs new file mode 100644 index 0000000000..21e93c896d --- /dev/null +++ b/proxy/src/conn.rs @@ -0,0 +1,221 @@ +use std::future::{poll_fn, Future}; +use std::io; +use std::net::SocketAddr; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use tokio::net::{TcpListener, TcpStream, ToSocketAddrs}; + +pub trait Acceptor { + type Connection: AsyncRead + AsyncWrite + Send + Unpin + 'static; + type Error: std::error::Error + Send + Sync + 'static; + + #[inline] + fn poll_ready(&self, cx: &mut Context<'_>) -> Poll> { + let _ = cx; + Poll::Ready(Ok(())) + } + + fn accept( + &self, + ) -> impl Future> + Send; +} + +pub trait Connector { + type Connection: AsyncRead + AsyncWrite + Send + Unpin + 'static; + type Error: std::error::Error + Send + Sync + 'static; + + #[inline] + fn poll_ready(&self, cx: &mut Context<'_>) -> Poll> { + let _ = cx; + Poll::Ready(Ok(())) + } + + fn connect( + &self, + addr: SocketAddr, + ) -> impl Future> + Send; +} + +pub struct TokioTcpAcceptor { + listener: TcpListener, + tcp_nodelay: Option, + tcp_keepalive: Option, +} + +impl TokioTcpAcceptor { + pub async fn bind(addr: A) -> io::Result { + let listener = TcpListener::bind(addr).await?; + // When set for the server socket, the keepalive setting + // will be inherited by all accepted client sockets. + socket2::SockRef::from(&listener).set_keepalive(true)?; + Ok(Self { + listener, + tcp_nodelay: Some(true), + tcp_keepalive: None, + }) + } + + pub fn into_std(self) -> io::Result { + self.listener.into_std() + } +} + +impl Acceptor for TokioTcpAcceptor { + type Connection = TcpStream; + type Error = io::Error; + + fn accept(&self) -> impl Future> { + async move { + let (stream, addr) = self.listener.accept().await?; + + let socket = socket2::SockRef::from(&stream); + if let Some(nodelay) = self.tcp_nodelay { + socket.set_nodelay(nodelay)?; + } + if let Some(keepalive) = self.tcp_keepalive { + socket.set_keepalive(keepalive)?; + } + + Ok((stream, addr)) + } + } +} + +pub struct TokioTcpConnector; + +impl Connector for TokioTcpConnector { + type Connection = TcpStream; + type Error = io::Error; + + fn connect( + &self, + addr: SocketAddr, + ) -> impl Future> { + async move { + let socket = TcpStream::connect(addr).await?; + socket.set_nodelay(true)?; + Ok(socket) + } + } +} + +pub trait Stream: AsyncRead + AsyncWrite + Send + Unpin + 'static {} + +impl Stream for T {} + +pub trait AsyncRead { + fn readable(&self) -> impl Future> + Send + where + Self: Send + Sync, + { + poll_fn(move |cx| self.poll_read_ready(cx)) + } + + fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll>; + + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll>; + + fn poll_read_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &mut [io::IoSliceMut<'_>], + ) -> Poll>; +} + +pub trait AsyncWrite { + fn writable(&self) -> impl Future> + Send + where + Self: Send + Sync, + { + poll_fn(move |cx| self.poll_write_ready(cx)) + } + + fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll>; + + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll>; + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll>; + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll>; + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll>; +} + +impl AsyncRead for tokio::net::TcpStream { + fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll> { + tokio::net::TcpStream::poll_read_ready(self, cx) + } + + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + match tokio::net::TcpStream::try_read(Pin::new(&mut *self).get_mut(), buf) { + Ok(n) => Poll::Ready(Ok(n)), + Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + cx.waker().wake_by_ref(); + Poll::Pending + } + Err(e) => Poll::Ready(Err(e)), + } + } + + fn poll_read_vectored( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &mut [io::IoSliceMut<'_>], + ) -> Poll> { + match tokio::net::TcpStream::try_read_vectored(Pin::new(&mut *self).get_mut(), bufs) { + Ok(n) => Poll::Ready(Ok(n)), + Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + cx.waker().wake_by_ref(); + Poll::Pending + } + Err(e) => Poll::Ready(Err(e)), + } + } +} + +impl AsyncWrite for tokio::net::TcpStream { + fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll> { + tokio::net::TcpStream::poll_write_ready(self, cx) + } + + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + ::poll_write(self, cx, buf) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll> { + ::poll_write_vectored(self, cx, bufs) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + ::poll_flush(self, cx) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + ::poll_shutdown(self, cx) + } +} diff --git a/proxy/src/console_redirect_proxy.rs b/proxy/src/console_redirect_proxy.rs index 25a549039c..02659404ef 100644 --- a/proxy/src/console_redirect_proxy.rs +++ b/proxy/src/console_redirect_proxy.rs @@ -8,6 +8,7 @@ use tracing::{debug, error, info, Instrument}; use crate::auth::backend::ConsoleRedirectBackend; use crate::cancellation::{CancellationHandlerMain, CancellationHandlerMainInternal}; use crate::config::{ProxyConfig, ProxyProtocolV2}; +use crate::conn::{Acceptor, TokioTcpAcceptor}; use crate::context::RequestContext; use crate::error::ReportableError; use crate::metrics::{Metrics, NumClientConnectionsGuard}; @@ -22,7 +23,7 @@ use crate::proxy::{ pub async fn task_main( config: &'static ProxyConfig, backend: &'static ConsoleRedirectBackend, - listener: tokio::net::TcpListener, + acceptor: TokioTcpAcceptor, cancellation_token: CancellationToken, cancellation_handler: Arc, ) -> anyhow::Result<()> { @@ -30,15 +31,11 @@ pub async fn task_main( info!("proxy has shut down"); } - // When set for the server socket, the keepalive setting - // will be inherited by all accepted client sockets. - socket2::SockRef::from(&listener).set_keepalive(true)?; - let connections = tokio_util::task::task_tracker::TaskTracker::new(); let cancellations = tokio_util::task::task_tracker::TaskTracker::new(); while let Some(accept_result) = - run_until_cancelled(listener.accept(), &cancellation_token).await + run_until_cancelled(acceptor.accept(), &cancellation_token).await { let (socket, peer_addr) = accept_result?; @@ -131,7 +128,7 @@ pub async fn task_main( connections.close(); cancellations.close(); - drop(listener); + drop(acceptor); // Drain connections connections.wait().await; diff --git a/proxy/src/control_plane/mgmt.rs b/proxy/src/control_plane/mgmt.rs index 2f7359240d..9b2a5e24fe 100644 --- a/proxy/src/control_plane/mgmt.rs +++ b/proxy/src/control_plane/mgmt.rs @@ -4,10 +4,11 @@ use anyhow::Context; use once_cell::sync::Lazy; use postgres_backend::{AuthType, PostgresBackend, PostgresBackendTCP, QueryError}; use pq_proto::{BeMessage, SINGLE_COL_ROWDESC}; -use tokio::net::{TcpListener, TcpStream}; +use tokio::net::TcpStream; use tokio_util::sync::CancellationToken; use tracing::{error, info, info_span, Instrument}; +use crate::conn::{Acceptor, TokioTcpAcceptor}; use crate::control_plane::messages::{DatabaseInfo, KickSession}; use crate::waiters::{self, Waiter, Waiters}; @@ -26,19 +27,15 @@ pub(crate) fn notify(psql_session_id: &str, msg: ComputeReady) -> Result<(), wai /// Management API listener task. /// It spawns management response handlers needed for the console redirect auth flow. -pub async fn task_main(listener: TcpListener) -> anyhow::Result { +pub async fn task_main(acceptor: TokioTcpAcceptor) -> anyhow::Result { scopeguard::defer! { info!("mgmt has shut down"); } loop { - let (socket, peer_addr) = listener.accept().await?; + let (socket, peer_addr) = acceptor.accept().await?; info!("accepted connection from {peer_addr}"); - socket - .set_nodelay(true) - .context("failed to set client socket option")?; - let span = info_span!("mgmt", peer = %peer_addr); tokio::task::spawn( diff --git a/proxy/src/http/health_server.rs b/proxy/src/http/health_server.rs index 6ca091feb7..a688b68462 100644 --- a/proxy/src/http/health_server.rs +++ b/proxy/src/http/health_server.rs @@ -1,5 +1,4 @@ use std::convert::Infallible; -use std::net::TcpListener; use std::sync::{Arc, Mutex}; use anyhow::{anyhow, bail}; @@ -14,6 +13,7 @@ use utils::http::error::ApiError; use utils::http::json::json_response; use utils::http::{RouterBuilder, RouterService}; +use crate::conn::TokioTcpAcceptor; use crate::ext::{LockExt, TaskExt}; use crate::jemalloc; @@ -36,7 +36,7 @@ fn make_router(metrics: AppMetrics) -> RouterBuilder { } pub async fn task_main( - http_listener: TcpListener, + http_acceptor: TokioTcpAcceptor, metrics: AppMetrics, ) -> anyhow::Result { scopeguard::defer! { @@ -45,7 +45,7 @@ pub async fn task_main( let service = || RouterService::new(make_router(metrics).build()?); - hyper0::Server::from_tcp(http_listener)? + hyper0::Server::from_tcp(http_acceptor.into_std()?)? .serve(service().map_err(|e| anyhow!(e))?) .await?; diff --git a/proxy/src/lib.rs b/proxy/src/lib.rs index c56474edd7..342c444587 100644 --- a/proxy/src/lib.rs +++ b/proxy/src/lib.rs @@ -78,6 +78,7 @@ pub mod cancellation; pub mod compute; pub mod compute_ctl; pub mod config; +pub mod conn; pub mod console_redirect_proxy; pub mod context; pub mod control_plane; diff --git a/proxy/src/proxy/mod.rs b/proxy/src/proxy/mod.rs index 3926c56fec..3659839788 100644 --- a/proxy/src/proxy/mod.rs +++ b/proxy/src/proxy/mod.rs @@ -25,6 +25,7 @@ use self::connect_compute::{connect_to_compute, TcpMechanism}; use self::passthrough::ProxyPassthrough; use crate::cancellation::{self, CancellationHandlerMain, CancellationHandlerMainInternal}; use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig}; +use crate::conn::{Acceptor, TokioTcpAcceptor}; use crate::context::RequestContext; use crate::error::ReportableError; use crate::metrics::{Metrics, NumClientConnectionsGuard}; @@ -55,7 +56,7 @@ pub async fn run_until_cancelled( pub async fn task_main( config: &'static ProxyConfig, auth_backend: &'static auth::Backend<'static, ()>, - listener: tokio::net::TcpListener, + acceptor: TokioTcpAcceptor, cancellation_token: CancellationToken, cancellation_handler: Arc, endpoint_rate_limiter: Arc, @@ -64,15 +65,11 @@ pub async fn task_main( info!("proxy has shut down"); } - // When set for the server socket, the keepalive setting - // will be inherited by all accepted client sockets. - socket2::SockRef::from(&listener).set_keepalive(true)?; - let connections = tokio_util::task::task_tracker::TaskTracker::new(); let cancellations = tokio_util::task::task_tracker::TaskTracker::new(); while let Some(accept_result) = - run_until_cancelled(listener.accept(), &cancellation_token).await + run_until_cancelled(acceptor.accept(), &cancellation_token).await { let (socket, peer_addr) = accept_result?; @@ -168,7 +165,7 @@ pub async fn task_main( connections.close(); cancellations.close(); - drop(listener); + drop(acceptor); // Drain connections connections.wait().await; diff --git a/proxy/src/serverless/mod.rs b/proxy/src/serverless/mod.rs index c2623e0eca..0f1a98bc76 100644 --- a/proxy/src/serverless/mod.rs +++ b/proxy/src/serverless/mod.rs @@ -35,7 +35,7 @@ use rand::rngs::StdRng; use rand::SeedableRng; use sql_over_http::{uuid_to_header_value, NEON_REQUEST_ID}; use tokio::io::{AsyncRead, AsyncWrite}; -use tokio::net::{TcpListener, TcpStream}; +use tokio::net::TcpStream; use tokio::time::timeout; use tokio_rustls::TlsAcceptor; use tokio_util::sync::CancellationToken; @@ -45,6 +45,7 @@ use utils::http::error::ApiError; use crate::cancellation::CancellationHandlerMain; use crate::config::{ProxyConfig, ProxyProtocolV2}; +use crate::conn::{Acceptor, TokioTcpAcceptor}; use crate::context::RequestContext; use crate::ext::TaskExt; use crate::metrics::Metrics; @@ -59,7 +60,7 @@ pub(crate) const SERVERLESS_DRIVER_SNI: &str = "api"; pub async fn task_main( config: &'static ProxyConfig, auth_backend: &'static crate::auth::Backend<'static, ()>, - ws_listener: TcpListener, + ws_acceptor: TokioTcpAcceptor, cancellation_token: CancellationToken, cancellation_handler: Arc, endpoint_rate_limiter: Arc, @@ -134,7 +135,7 @@ pub async fn task_main( connections.close(); // allows `connections.wait to complete` let cancellations = tokio_util::task::task_tracker::TaskTracker::new(); - while let Some(res) = run_until_cancelled(ws_listener.accept(), &cancellation_token).await { + while let Some(res) = run_until_cancelled(ws_acceptor.accept(), &cancellation_token).await { let (conn, peer_addr) = res.context("could not accept TCP stream")?; if let Err(e) = conn.set_nodelay(true) { tracing::error!("could not set nodelay: {e}");