diff --git a/proxy/src/binary/proxy.rs b/proxy/src/binary/proxy.rs index dcae263647..757c1e988b 100644 --- a/proxy/src/binary/proxy.rs +++ b/proxy/src/binary/proxy.rs @@ -221,8 +221,7 @@ struct ProxyCliArgs { is_private_access_proxy: bool, /// Configure whether all incoming requests have a Proxy Protocol V2 packet. - // TODO(conradludgate): switch default to rejected or required once we've updated all deployments - #[clap(value_enum, long, default_value_t = ProxyProtocolV2::Supported)] + #[clap(value_enum, long, default_value_t = ProxyProtocolV2::Rejected)] proxy_protocol_v2: ProxyProtocolV2, /// Time the proxy waits for the webauth session to be confirmed by the control plane. diff --git a/proxy/src/config.rs b/proxy/src/config.rs index a97339df9a..248584a19a 100644 --- a/proxy/src/config.rs +++ b/proxy/src/config.rs @@ -39,8 +39,6 @@ pub struct ComputeConfig { pub enum ProxyProtocolV2 { /// Connection will error if PROXY protocol v2 header is missing Required, - /// Connection will parse PROXY protocol v2 header, but accept the connection if it's missing. - Supported, /// Connection will error if PROXY protocol v2 header is provided Rejected, } diff --git a/proxy/src/console_redirect_proxy.rs b/proxy/src/console_redirect_proxy.rs index 7fb84b5ee5..6755499b45 100644 --- a/proxy/src/console_redirect_proxy.rs +++ b/proxy/src/console_redirect_proxy.rs @@ -54,30 +54,24 @@ pub async fn task_main( debug!(protocol = "tcp", %session_id, "accepted new TCP connection"); connections.spawn(async move { - let (socket, peer_addr) = match read_proxy_protocol(socket).await { - Err(e) => { - error!("per-client task finished with an error: {e:#}"); - return; + let (socket, conn_info) = match config.proxy_protocol_v2 { + ProxyProtocolV2::Required => { + match read_proxy_protocol(socket).await { + Err(e) => { + error!("per-client task finished with an error: {e:#}"); + return; + } + // our load balancers will not send any more data. let's just exit immediately + Ok((_socket, ConnectHeader::Local)) => { + debug!("healthcheck received"); + return; + } + Ok((socket, ConnectHeader::Proxy(info))) => (socket, info), + } } - // our load balancers will not send any more data. let's just exit immediately - Ok((_socket, ConnectHeader::Local)) => { - debug!("healthcheck received"); - return; - } - Ok((_socket, ConnectHeader::Missing)) - if config.proxy_protocol_v2 == ProxyProtocolV2::Required => - { - error!("missing required proxy protocol header"); - return; - } - Ok((_socket, ConnectHeader::Proxy(_))) - if config.proxy_protocol_v2 == ProxyProtocolV2::Rejected => - { - error!("proxy protocol header not supported"); - return; - } - Ok((socket, ConnectHeader::Proxy(info))) => (socket, info), - Ok((socket, ConnectHeader::Missing)) => ( + // ignore the header - it cannot be confused for a postgres or http connection so will + // error later. + ProxyProtocolV2::Rejected => ( socket, ConnectionInfo { addr: peer_addr, @@ -86,7 +80,7 @@ pub async fn task_main( ), }; - match socket.inner.set_nodelay(true) { + match socket.set_nodelay(true) { Ok(()) => {} Err(e) => { error!( @@ -98,7 +92,7 @@ pub async fn task_main( let ctx = RequestContext::new( session_id, - peer_addr, + conn_info, crate::metrics::Protocol::Tcp, &config.region, ); diff --git a/proxy/src/protocol2.rs b/proxy/src/protocol2.rs index 0793998639..5bec6d6ca3 100644 --- a/proxy/src/protocol2.rs +++ b/proxy/src/protocol2.rs @@ -4,60 +4,13 @@ use core::fmt; use std::io; use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr}; -use std::pin::Pin; -use std::task::{Context, Poll}; -use bytes::{Buf, Bytes, BytesMut}; -use pin_project_lite::pin_project; +use bytes::Buf; use smol_str::SmolStr; use strum_macros::FromRepr; -use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf}; +use tokio::io::{AsyncRead, AsyncReadExt}; use zerocopy::{FromBytes, Immutable, KnownLayout, Unaligned, network_endian}; -pin_project! { - /// A chained [`AsyncRead`] with [`AsyncWrite`] passthrough - pub(crate) struct ChainRW { - #[pin] - pub(crate) inner: T, - buf: BytesMut, - } -} - -impl AsyncWrite for ChainRW { - #[inline] - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - self.project().inner.poll_write(cx, buf) - } - - #[inline] - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().inner.poll_flush(cx) - } - - #[inline] - fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().inner.poll_shutdown(cx) - } - - #[inline] - fn poll_write_vectored( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - bufs: &[io::IoSlice<'_>], - ) -> Poll> { - self.project().inner.poll_write_vectored(cx, bufs) - } - - #[inline] - fn is_write_vectored(&self) -> bool { - self.inner.is_write_vectored() - } -} - /// Proxy Protocol Version 2 Header const SIGNATURE: [u8; 12] = [ 0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A, @@ -79,7 +32,6 @@ pub struct ConnectionInfo { #[derive(PartialEq, Eq, Clone, Debug)] pub enum ConnectHeader { - Missing, Local, Proxy(ConnectionInfo), } @@ -106,47 +58,24 @@ pub enum ConnectionInfoExtra { pub(crate) async fn read_proxy_protocol( mut read: T, -) -> std::io::Result<(ChainRW, ConnectHeader)> { - let mut buf = BytesMut::with_capacity(128); - let header = loop { - let bytes_read = read.read_buf(&mut buf).await?; - - // exit for bad header signature - let len = usize::min(buf.len(), SIGNATURE.len()); - if buf[..len] != SIGNATURE[..len] { - return Ok((ChainRW { inner: read, buf }, ConnectHeader::Missing)); - } - - // if no more bytes available then exit - if bytes_read == 0 { - return Ok((ChainRW { inner: read, buf }, ConnectHeader::Missing)); - } - - // check if we have enough bytes to continue - if let Some(header) = buf.try_get::() { - break header; - } - }; - - let remaining_length = usize::from(header.len.get()); - - while buf.len() < remaining_length { - if read.read_buf(&mut buf).await? == 0 { - return Err(io::Error::new( - io::ErrorKind::UnexpectedEof, - "stream closed while waiting for proxy protocol addresses", - )); - } +) -> std::io::Result<(T, ConnectHeader)> { + let mut header = [0; size_of::()]; + read.read_exact(&mut header).await?; + let header: ProxyProtocolV2Header = zerocopy::transmute!(header); + if header.signature != SIGNATURE { + return Err(std::io::Error::other("invalid proxy protocol header")); } - let payload = buf.split_to(remaining_length); - let res = process_proxy_payload(header, payload)?; - Ok((ChainRW { inner: read, buf }, res)) + let mut payload = vec![0; usize::from(header.len.get())]; + read.read_exact(&mut payload).await?; + + let res = process_proxy_payload(header, &payload)?; + Ok((read, res)) } fn process_proxy_payload( header: ProxyProtocolV2Header, - mut payload: BytesMut, + mut payload: &[u8], ) -> std::io::Result { match header.version_and_command { // the connection was established on purpose by the proxy @@ -162,13 +91,12 @@ fn process_proxy_payload( PROXY_V2 => {} // other values are unassigned and must not be emitted by senders. Receivers // must drop connections presenting unexpected values here. - #[rustfmt::skip] // https://github.com/rust-lang/rustfmt/issues/6384 - _ => return Err(io::Error::other( - format!( + _ => { + return Err(io::Error::other(format!( "invalid proxy protocol command 0x{:02X}. expected local (0x20) or proxy (0x21)", header.version_and_command - ), - )), + ))); + } } let size_err = @@ -206,7 +134,7 @@ fn process_proxy_payload( } let subtype = tlv.value.get_u8(); match Pp2AwsType::from_repr(subtype) { - Some(Pp2AwsType::VpceId) => match std::str::from_utf8(&tlv.value) { + Some(Pp2AwsType::VpceId) => match std::str::from_utf8(tlv.value) { Ok(s) => { extra = Some(ConnectionInfoExtra::Aws { vpce_id: s.into() }); } @@ -282,65 +210,28 @@ enum Pp2AzureType { PrivateEndpointLinkId = 0x01, } -impl AsyncRead for ChainRW { - #[inline] - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - if self.buf.is_empty() { - self.project().inner.poll_read(cx, buf) - } else { - self.read_from_buf(buf) - } - } -} - -impl ChainRW { - #[cold] - fn read_from_buf(self: Pin<&mut Self>, buf: &mut ReadBuf<'_>) -> Poll> { - debug_assert!(!self.buf.is_empty()); - let this = self.project(); - - let write = usize::min(this.buf.len(), buf.remaining()); - let slice = this.buf.split_to(write).freeze(); - buf.put_slice(&slice); - - // reset the allocation so it can be freed - if this.buf.is_empty() { - *this.buf = BytesMut::new(); - } - - Poll::Ready(Ok(())) - } -} - #[derive(Debug)] -struct Tlv { +struct Tlv<'a> { kind: u8, - value: Bytes, + value: &'a [u8], } -fn read_tlv(b: &mut BytesMut) -> Option { +fn read_tlv<'a>(b: &mut &'a [u8]) -> Option> { let tlv_header = b.try_get::()?; let len = usize::from(tlv_header.len.get()); - if b.len() < len { - return None; - } Some(Tlv { kind: tlv_header.kind, - value: b.split_to(len).freeze(), + value: b.split_off(..len)?, }) } trait BufExt: Sized { fn try_get(&mut self) -> Option; } -impl BufExt for BytesMut { +impl BufExt for &[u8] { fn try_get(&mut self) -> Option { - let (res, _) = T::read_from_prefix(self).ok()?; - self.advance(size_of::()); + let (res, rest) = T::read_from_prefix(self).ok()?; + *self = rest; Some(res) } } @@ -481,27 +372,19 @@ mod tests { } #[tokio::test] + #[should_panic = "invalid proxy protocol header"] async fn test_invalid() { let data = [0x55; 256]; - let (mut read, info) = read_proxy_protocol(data.as_slice()).await.unwrap(); - - let mut bytes = vec![]; - read.read_to_end(&mut bytes).await.unwrap(); - assert_eq!(bytes, data); - assert_eq!(info, ConnectHeader::Missing); + read_proxy_protocol(data.as_slice()).await.unwrap(); } #[tokio::test] + #[should_panic = "early eof"] async fn test_short() { let data = [0x55; 10]; - let (mut read, info) = read_proxy_protocol(data.as_slice()).await.unwrap(); - - let mut bytes = vec![]; - read.read_to_end(&mut bytes).await.unwrap(); - assert_eq!(bytes, data); - assert_eq!(info, ConnectHeader::Missing); + read_proxy_protocol(data.as_slice()).await.unwrap(); } #[tokio::test] diff --git a/proxy/src/proxy/mod.rs b/proxy/src/proxy/mod.rs index 0ffc54aa88..477baff1c9 100644 --- a/proxy/src/proxy/mod.rs +++ b/proxy/src/proxy/mod.rs @@ -102,30 +102,24 @@ pub async fn task_main( let endpoint_rate_limiter2 = endpoint_rate_limiter.clone(); connections.spawn(async move { - let (socket, conn_info) = match read_proxy_protocol(socket).await { - Err(e) => { - warn!("per-client task finished with an error: {e:#}"); - return; + let (socket, conn_info) = match config.proxy_protocol_v2 { + ProxyProtocolV2::Required => { + match read_proxy_protocol(socket).await { + Err(e) => { + warn!("per-client task finished with an error: {e:#}"); + return; + } + // our load balancers will not send any more data. let's just exit immediately + Ok((_socket, ConnectHeader::Local)) => { + debug!("healthcheck received"); + return; + } + Ok((socket, ConnectHeader::Proxy(info))) => (socket, info), + } } - // our load balancers will not send any more data. let's just exit immediately - Ok((_socket, ConnectHeader::Local)) => { - debug!("healthcheck received"); - return; - } - Ok((_socket, ConnectHeader::Missing)) - if config.proxy_protocol_v2 == ProxyProtocolV2::Required => - { - warn!("missing required proxy protocol header"); - return; - } - Ok((_socket, ConnectHeader::Proxy(_))) - if config.proxy_protocol_v2 == ProxyProtocolV2::Rejected => - { - warn!("proxy protocol header not supported"); - return; - } - Ok((socket, ConnectHeader::Proxy(info))) => (socket, info), - Ok((socket, ConnectHeader::Missing)) => ( + // ignore the header - it cannot be confused for a postgres or http connection so will + // error later. + ProxyProtocolV2::Rejected => ( socket, ConnectionInfo { addr: peer_addr, @@ -134,7 +128,7 @@ pub async fn task_main( ), }; - match socket.inner.set_nodelay(true) { + match socket.set_nodelay(true) { Ok(()) => {} Err(e) => { error!( diff --git a/proxy/src/proxy/tests/mod.rs b/proxy/src/proxy/tests/mod.rs index 61e8ee4a10..117c42e19c 100644 --- a/proxy/src/proxy/tests/mod.rs +++ b/proxy/src/proxy/tests/mod.rs @@ -173,7 +173,6 @@ async fn dummy_proxy( tls: Option, auth: impl TestAuth + Send, ) -> anyhow::Result<()> { - let (client, _) = read_proxy_protocol(client).await?; let mut stream = match handshake(&RequestContext::test(), client, tls.as_ref(), false).await? { HandshakeData::Startup(stream, _) => stream, HandshakeData::Cancel(_) => bail!("cancellation not supported"), diff --git a/proxy/src/serverless/mod.rs b/proxy/src/serverless/mod.rs index 2a7069b1c2..f6f681ac45 100644 --- a/proxy/src/serverless/mod.rs +++ b/proxy/src/serverless/mod.rs @@ -49,7 +49,7 @@ use crate::config::{ProxyConfig, ProxyProtocolV2}; use crate::context::RequestContext; use crate::ext::TaskExt; use crate::metrics::Metrics; -use crate::protocol2::{ChainRW, ConnectHeader, ConnectionInfo, read_proxy_protocol}; +use crate::protocol2::{ConnectHeader, ConnectionInfo, read_proxy_protocol}; use crate::proxy::run_until_cancelled; use crate::rate_limiter::EndpointRateLimiter; use crate::serverless::backend::PoolingBackend; @@ -207,12 +207,12 @@ pub(crate) type AsyncRW = Pin>; #[async_trait] trait MaybeTlsAcceptor: Send + Sync + 'static { - async fn accept(&self, conn: ChainRW) -> std::io::Result; + async fn accept(&self, conn: TcpStream) -> std::io::Result; } #[async_trait] impl MaybeTlsAcceptor for &'static ArcSwapOption { - async fn accept(&self, conn: ChainRW) -> std::io::Result { + async fn accept(&self, conn: TcpStream) -> std::io::Result { match &*self.load() { Some(config) => Ok(Box::pin( TlsAcceptor::from(config.http_config.clone()) @@ -235,33 +235,30 @@ async fn connection_startup( peer_addr: SocketAddr, ) -> Option<(AsyncRW, ConnectionInfo)> { // handle PROXY protocol - let (conn, peer) = match read_proxy_protocol(conn).await { - Ok(c) => c, - Err(e) => { - tracing::warn!(?session_id, %peer_addr, "failed to accept TCP connection: invalid PROXY protocol V2 header: {e:#}"); - return None; + let (conn, conn_info) = match config.proxy_protocol_v2 { + ProxyProtocolV2::Required => { + match read_proxy_protocol(conn).await { + Err(e) => { + warn!("per-client task finished with an error: {e:#}"); + return None; + } + // our load balancers will not send any more data. let's just exit immediately + Ok((_conn, ConnectHeader::Local)) => { + tracing::debug!("healthcheck received"); + return None; + } + Ok((conn, ConnectHeader::Proxy(info))) => (conn, info), + } } - }; - - let conn_info = match peer { - // our load balancers will not send any more data. let's just exit immediately - ConnectHeader::Local => { - tracing::debug!("healthcheck received"); - return None; - } - ConnectHeader::Missing if config.proxy_protocol_v2 == ProxyProtocolV2::Required => { - tracing::warn!("missing required proxy protocol header"); - return None; - } - ConnectHeader::Proxy(_) if config.proxy_protocol_v2 == ProxyProtocolV2::Rejected => { - tracing::warn!("proxy protocol header not supported"); - return None; - } - ConnectHeader::Proxy(info) => info, - ConnectHeader::Missing => ConnectionInfo { - addr: peer_addr, - extra: None, - }, + // ignore the header - it cannot be confused for a postgres or http connection so will + // error later. + ProxyProtocolV2::Rejected => ( + conn, + ConnectionInfo { + addr: peer_addr, + extra: None, + }, + ), }; let has_private_peer_addr = match conn_info.addr.ip() {