diff --git a/proxy/src/protocol2.rs b/proxy/src/protocol2.rs index 70f9b4bfab..1dd4563514 100644 --- a/proxy/src/protocol2.rs +++ b/proxy/src/protocol2.rs @@ -1,42 +1,26 @@ //! Proxy Protocol V2 implementation use std::{ - future::{poll_fn, Future}, io, net::SocketAddr, - pin::{pin, Pin}, - task::{ready, Context, Poll}, + pin::Pin, + task::{Context, Poll}, }; -use bytes::{Buf, BytesMut}; -use hyper::server::conn::AddrIncoming; +use bytes::BytesMut; use pin_project_lite::pin_project; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf}; -pub struct ProxyProtocolAccept { - pub incoming: AddrIncoming, - pub protocol: &'static str, -} - pin_project! { - pub struct WithClientIp { + /// A chained [`AsyncRead`] with [`AsyncWrite`] passthrough + pub struct ChainRW { #[pin] pub inner: T, buf: BytesMut, - tlv_bytes: u16, - state: ProxyParse, } } -#[derive(Clone, PartialEq, Debug)] -enum ProxyParse { - NotStarted, - - Finished(SocketAddr), - None, -} - -impl AsyncWrite for WithClientIp { +impl AsyncWrite for ChainRW { #[inline] fn poll_write( self: Pin<&mut Self>, @@ -71,267 +55,174 @@ impl AsyncWrite for WithClientIp { } } -impl WithClientIp { - pub fn new(inner: T) -> Self { - WithClientIp { - inner, - buf: BytesMut::with_capacity(128), - tlv_bytes: 0, - state: ProxyParse::NotStarted, - } - } - - pub fn client_addr(&self) -> Option { - match self.state { - ProxyParse::Finished(socket) => Some(socket), - _ => None, - } - } -} - -impl WithClientIp { - pub async fn wait_for_addr(&mut self) -> io::Result> { - match self.state { - ProxyParse::NotStarted => { - let mut pin = Pin::new(&mut *self); - let addr = poll_fn(|cx| pin.as_mut().poll_client_ip(cx)).await?; - match addr { - Some(addr) => self.state = ProxyParse::Finished(addr), - None => self.state = ProxyParse::None, - } - Ok(addr) - } - ProxyParse::Finished(addr) => Ok(Some(addr)), - ProxyParse::None => Ok(None), - } - } -} - /// Proxy Protocol Version 2 Header const HEADER: [u8; 12] = [ 0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A, ]; -impl WithClientIp { - /// implementation of - /// Version 2 (Binary Format) - fn poll_client_ip( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll>> { - // The binary header format starts with a constant 12 bytes block containing the protocol signature : - // \x0D \x0A \x0D \x0A \x00 \x0D \x0A \x51 \x55 \x49 \x54 \x0A - while self.buf.len() < 16 { - let mut this = self.as_mut().project(); - let bytes_read = pin!(this.inner.read_buf(this.buf)).poll(cx)?; +pub async fn read_proxy_protocol( + mut read: T, +) -> std::io::Result<(ChainRW, Option)> { + let mut buf = BytesMut::with_capacity(128); + while buf.len() < 16 { + let bytes_read = read.read_buf(&mut buf).await?; - // exit for bad header - let len = usize::min(self.buf.len(), HEADER.len()); - if self.buf[..len] != HEADER[..len] { - return Poll::Ready(Ok(None)); - } - - // if no more bytes available then exit - if ready!(bytes_read) == 0 { - return Poll::Ready(Ok(None)); - }; + // exit for bad header + let len = usize::min(buf.len(), HEADER.len()); + if buf[..len] != HEADER[..len] { + return Ok((ChainRW { inner: read, buf }, None)); } - // The next byte (the 13th one) is the protocol version and command. - // The highest four bits contains the version. As of this specification, it must - // always be sent as \x2 and the receiver must only accept this value. - let vc = self.buf[12]; - let version = vc >> 4; - let command = vc & 0b1111; - if version != 2 { - return Poll::Ready(Err(io::Error::new( + // if no more bytes available then exit + if bytes_read == 0 { + return Ok((ChainRW { inner: read, buf }, None)); + }; + } + + let header = buf.split_to(16); + + // The next byte (the 13th one) is the protocol version and command. + // The highest four bits contains the version. As of this specification, it must + // always be sent as \x2 and the receiver must only accept this value. + let vc = header[12]; + let version = vc >> 4; + let command = vc & 0b1111; + if version != 2 { + return Err(io::Error::new( + io::ErrorKind::Other, + "invalid proxy protocol version. expected version 2", + )); + } + match command { + // the connection was established on purpose by the proxy + // without being relayed. The connection endpoints are the sender and the + // receiver. Such connections exist when the proxy sends health-checks to the + // server. The receiver must accept this connection as valid and must use the + // real connection endpoints and discard the protocol block including the + // family which is ignored. + 0 => {} + // the connection was established on behalf of another node, + // and reflects the original connection endpoints. The receiver must then use + // the information provided in the protocol block to get original the address. + 1 => {} + // other values are unassigned and must not be emitted by senders. Receivers + // must drop connections presenting unexpected values here. + _ => { + return Err(io::Error::new( io::ErrorKind::Other, - "invalid proxy protocol version. expected version 2", - ))); + "invalid proxy protocol command. expected local (0) or proxy (1)", + )) } - match command { - // the connection was established on purpose by the proxy - // without being relayed. The connection endpoints are the sender and the - // receiver. Such connections exist when the proxy sends health-checks to the - // server. The receiver must accept this connection as valid and must use the - // real connection endpoints and discard the protocol block including the - // family which is ignored. - 0 => {} - // the connection was established on behalf of another node, - // and reflects the original connection endpoints. The receiver must then use - // the information provided in the protocol block to get original the address. - 1 => {} - // other values are unassigned and must not be emitted by senders. Receivers - // must drop connections presenting unexpected values here. - _ => { - return Poll::Ready(Err(io::Error::new( - io::ErrorKind::Other, - "invalid proxy protocol command. expected local (0) or proxy (1)", - ))) - } - }; + }; - // The 14th byte contains the transport protocol and address family. The highest 4 - // bits contain the address family, the lowest 4 bits contain the protocol. - let ft = self.buf[13]; - let address_length = match ft { - // - \x11 : TCP over IPv4 : the forwarded connection uses TCP over the AF_INET - // protocol family. Address length is 2*4 + 2*2 = 12 bytes. - // - \x12 : UDP over IPv4 : the forwarded connection uses UDP over the AF_INET - // protocol family. Address length is 2*4 + 2*2 = 12 bytes. - 0x11 | 0x12 => 12, - // - \x21 : TCP over IPv6 : the forwarded connection uses TCP over the AF_INET6 - // protocol family. Address length is 2*16 + 2*2 = 36 bytes. - // - \x22 : UDP over IPv6 : the forwarded connection uses UDP over the AF_INET6 - // protocol family. Address length is 2*16 + 2*2 = 36 bytes. - 0x21 | 0x22 => 36, - // unspecified or unix stream. ignore the addresses - _ => 0, - }; + // The 14th byte contains the transport protocol and address family. The highest 4 + // bits contain the address family, the lowest 4 bits contain the protocol. + let ft = header[13]; + let address_length = match ft { + // - \x11 : TCP over IPv4 : the forwarded connection uses TCP over the AF_INET + // protocol family. Address length is 2*4 + 2*2 = 12 bytes. + // - \x12 : UDP over IPv4 : the forwarded connection uses UDP over the AF_INET + // protocol family. Address length is 2*4 + 2*2 = 12 bytes. + 0x11 | 0x12 => 12, + // - \x21 : TCP over IPv6 : the forwarded connection uses TCP over the AF_INET6 + // protocol family. Address length is 2*16 + 2*2 = 36 bytes. + // - \x22 : UDP over IPv6 : the forwarded connection uses UDP over the AF_INET6 + // protocol family. Address length is 2*16 + 2*2 = 36 bytes. + 0x21 | 0x22 => 36, + // unspecified or unix stream. ignore the addresses + _ => 0, + }; - // The 15th and 16th bytes is the address length in bytes in network endian order. - // It is used so that the receiver knows how many address bytes to skip even when - // it does not implement the presented protocol. Thus the length of the protocol - // header in bytes is always exactly 16 + this value. When a sender presents a - // LOCAL connection, it should not present any address so it sets this field to - // zero. Receivers MUST always consider this field to skip the appropriate number - // of bytes and must not assume zero is presented for LOCAL connections. When a - // receiver accepts an incoming connection showing an UNSPEC address family or - // protocol, it may or may not decide to log the address information if present. - let remaining_length = u16::from_be_bytes(self.buf[14..16].try_into().unwrap()); - if remaining_length < address_length { - return Poll::Ready(Err(io::Error::new( - io::ErrorKind::Other, - "invalid proxy protocol length. not enough to fit requested IP addresses", - ))); + // The 15th and 16th bytes is the address length in bytes in network endian order. + // It is used so that the receiver knows how many address bytes to skip even when + // it does not implement the presented protocol. Thus the length of the protocol + // header in bytes is always exactly 16 + this value. When a sender presents a + // LOCAL connection, it should not present any address so it sets this field to + // zero. Receivers MUST always consider this field to skip the appropriate number + // of bytes and must not assume zero is presented for LOCAL connections. When a + // receiver accepts an incoming connection showing an UNSPEC address family or + // protocol, it may or may not decide to log the address information if present. + let remaining_length = u16::from_be_bytes(header[14..16].try_into().unwrap()); + if remaining_length < address_length { + return Err(io::Error::new( + io::ErrorKind::Other, + "invalid proxy protocol length. not enough to fit requested IP addresses", + )); + } + drop(header); + + while buf.len() < remaining_length as usize { + if read.read_buf(&mut buf).await? == 0 { + return Err(io::Error::new( + io::ErrorKind::UnexpectedEof, + "stream closed while waiting for proxy protocol addresses", + )); } - - while self.buf.len() < 16 + address_length as usize { - let mut this = self.as_mut().project(); - if ready!(pin!(this.inner.read_buf(this.buf)).poll(cx)?) == 0 { - return Poll::Ready(Err(io::Error::new( - io::ErrorKind::UnexpectedEof, - "stream closed while waiting for proxy protocol addresses", - ))); - } - } - - let this = self.as_mut().project(); - - // we are sure this is a proxy protocol v2 entry and we have read all the bytes we need - // discard the header we have parsed - this.buf.advance(16); - - // Starting from the 17th byte, addresses are presented in network byte order. - // The address order is always the same : - // - source layer 3 address in network byte order - // - destination layer 3 address in network byte order - // - source layer 4 address if any, in network byte order (port) - // - destination layer 4 address if any, in network byte order (port) - let addresses = this.buf.split_to(address_length as usize); - let socket = match address_length { - 12 => { - let src_addr: [u8; 4] = addresses[0..4].try_into().unwrap(); - let src_port = u16::from_be_bytes(addresses[8..10].try_into().unwrap()); - Some(SocketAddr::from((src_addr, src_port))) - } - 36 => { - let src_addr: [u8; 16] = addresses[0..16].try_into().unwrap(); - let src_port = u16::from_be_bytes(addresses[32..34].try_into().unwrap()); - Some(SocketAddr::from((src_addr, src_port))) - } - _ => None, - }; - - *this.tlv_bytes = remaining_length - address_length; - self.as_mut().skip_tlv_inner(); - - Poll::Ready(Ok(socket)) } - #[cold] - fn read_ip(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let ip = ready!(self.as_mut().poll_client_ip(cx)?); - match ip { - Some(x) => *self.as_mut().project().state = ProxyParse::Finished(x), - None => *self.as_mut().project().state = ProxyParse::None, + // Starting from the 17th byte, addresses are presented in network byte order. + // The address order is always the same : + // - source layer 3 address in network byte order + // - destination layer 3 address in network byte order + // - source layer 4 address if any, in network byte order (port) + // - destination layer 4 address if any, in network byte order (port) + let addresses = buf.split_to(remaining_length as usize); + let socket = match address_length { + 12 => { + let src_addr: [u8; 4] = addresses[0..4].try_into().unwrap(); + let src_port = u16::from_be_bytes(addresses[8..10].try_into().unwrap()); + Some(SocketAddr::from((src_addr, src_port))) } - Poll::Ready(Ok(())) - } + 36 => { + let src_addr: [u8; 16] = addresses[0..16].try_into().unwrap(); + let src_port = u16::from_be_bytes(addresses[32..34].try_into().unwrap()); + Some(SocketAddr::from((src_addr, src_port))) + } + _ => None, + }; - #[cold] - fn skip_tlv(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let mut this = self.as_mut().project(); - // we know that this.buf is empty - debug_assert_eq!(this.buf.len(), 0); - - this.buf.reserve((*this.tlv_bytes).clamp(0, 1024) as usize); - ready!(pin!(this.inner.read_buf(this.buf)).poll(cx)?); - self.skip_tlv_inner(); - - Poll::Ready(Ok(())) - } - - fn skip_tlv_inner(self: Pin<&mut Self>) { - let tlv_bytes_read = match u16::try_from(self.buf.len()) { - // we read more than u16::MAX therefore we must have read the full tlv_bytes - Err(_) => self.tlv_bytes, - // we might not have read the full tlv bytes yet - Ok(n) => u16::min(n, self.tlv_bytes), - }; - let this = self.project(); - *this.tlv_bytes -= tlv_bytes_read; - this.buf.advance(tlv_bytes_read as usize); - } + Ok((ChainRW { inner: read, buf }, socket)) } -impl AsyncRead for WithClientIp { +impl AsyncRead for ChainRW { #[inline] fn poll_read( - mut self: Pin<&mut Self>, + self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { - // I'm assuming these 3 comparisons will be easy to branch predict. - // especially with the cold attributes - // which should make this read wrapper almost invisible - - if let ProxyParse::NotStarted = self.state { - ready!(self.as_mut().read_ip(cx)?); - } - - while self.tlv_bytes > 0 { - ready!(self.as_mut().skip_tlv(cx)?) - } - - let this = self.project(); - if this.buf.is_empty() { - this.inner.poll_read(cx, buf) + if self.buf.is_empty() { + self.project().inner.poll_read(cx, buf) } else { - // we know that tlv_bytes is 0 - debug_assert_eq!(*this.tlv_bytes, 0); - - let write = usize::min(this.buf.len(), buf.remaining()); - let slice = this.buf.split_to(write).freeze(); - buf.put_slice(&slice); - - // reset the allocation so it can be freed - if this.buf.is_empty() { - *this.buf = BytesMut::new(); - } - - Poll::Ready(Ok(())) + self.read_from_buf(buf) } } } +impl ChainRW { + #[cold] + fn read_from_buf(self: Pin<&mut Self>, buf: &mut ReadBuf<'_>) -> Poll> { + debug_assert!(!self.buf.is_empty()); + let this = self.project(); + + let write = usize::min(this.buf.len(), buf.remaining()); + let slice = this.buf.split_to(write).freeze(); + buf.put_slice(&slice); + + // reset the allocation so it can be freed + if this.buf.is_empty() { + *this.buf = BytesMut::new(); + } + + Poll::Ready(Ok(())) + } +} + #[cfg(test)] mod tests { - use std::pin::pin; - use tokio::io::AsyncReadExt; - use crate::protocol2::{ProxyParse, WithClientIp}; + use crate::protocol2::read_proxy_protocol; #[tokio::test] async fn test_ipv4() { @@ -353,16 +244,15 @@ mod tests { let extra_data = [0x55; 256]; - let mut read = pin!(WithClientIp::new(header.chain(extra_data.as_slice()))); + let (mut read, addr) = read_proxy_protocol(header.chain(extra_data.as_slice())) + .await + .unwrap(); let mut bytes = vec![]; read.read_to_end(&mut bytes).await.unwrap(); assert_eq!(bytes, extra_data); - assert_eq!( - read.state, - ProxyParse::Finished(([127, 0, 0, 1], 65535).into()) - ); + assert_eq!(addr, Some(([127, 0, 0, 1], 65535).into())); } #[tokio::test] @@ -385,17 +275,17 @@ mod tests { let extra_data = [0x55; 256]; - let mut read = pin!(WithClientIp::new(header.chain(extra_data.as_slice()))); + let (mut read, addr) = read_proxy_protocol(header.chain(extra_data.as_slice())) + .await + .unwrap(); let mut bytes = vec![]; read.read_to_end(&mut bytes).await.unwrap(); assert_eq!(bytes, extra_data); assert_eq!( - read.state, - ProxyParse::Finished( - ([15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0], 257).into() - ) + addr, + Some(([15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0], 257).into()) ); } @@ -403,24 +293,24 @@ mod tests { async fn test_invalid() { let data = [0x55; 256]; - let mut read = pin!(WithClientIp::new(data.as_slice())); + let (mut read, addr) = read_proxy_protocol(data.as_slice()).await.unwrap(); let mut bytes = vec![]; read.read_to_end(&mut bytes).await.unwrap(); assert_eq!(bytes, data); - assert_eq!(read.state, ProxyParse::None); + assert_eq!(addr, None); } #[tokio::test] async fn test_short() { let data = [0x55; 10]; - let mut read = pin!(WithClientIp::new(data.as_slice())); + let (mut read, addr) = read_proxy_protocol(data.as_slice()).await.unwrap(); let mut bytes = vec![]; read.read_to_end(&mut bytes).await.unwrap(); assert_eq!(bytes, data); - assert_eq!(read.state, ProxyParse::None); + assert_eq!(addr, None); } #[tokio::test] @@ -446,15 +336,14 @@ mod tests { let extra_data = [0xaa; 256]; - let mut read = pin!(WithClientIp::new(header.chain(extra_data.as_slice()))); + let (mut read, addr) = read_proxy_protocol(header.chain(extra_data.as_slice())) + .await + .unwrap(); let mut bytes = vec![]; read.read_to_end(&mut bytes).await.unwrap(); assert_eq!(bytes, extra_data); - assert_eq!( - read.state, - ProxyParse::Finished(([55, 56, 57, 58], 65535).into()) - ); + assert_eq!(addr, Some(([55, 56, 57, 58], 65535).into())); } } diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index a4554eef38..ddae6536fb 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -17,7 +17,7 @@ use crate::{ context::RequestMonitoring, error::ReportableError, metrics::{Metrics, NumClientConnectionsGuard}, - protocol2::WithClientIp, + protocol2::read_proxy_protocol, proxy::handshake::{handshake, HandshakeData}, stream::{PqStream, Stream}, EndpointCacheKey, @@ -88,20 +88,18 @@ pub async fn task_main( tracing::info!(protocol = "tcp", %session_id, "accepted new TCP connection"); connections.spawn(async move { - let mut socket = WithClientIp::new(socket); - let mut peer_addr = peer_addr.ip(); - match socket.wait_for_addr().await { - Ok(Some(addr)) => peer_addr = addr.ip(), + let (socket, peer_addr) = match read_proxy_protocol(socket).await{ + Ok((socket, Some(addr))) => (socket, addr.ip()), Err(e) => { error!("per-client task finished with an error: {e:#}"); return; } - Ok(None) if config.require_client_ip => { + Ok((_socket, None)) if config.require_client_ip => { error!("missing required client IP"); return; } - Ok(None) => {} - } + Ok((socket, None)) => (socket, peer_addr.ip()) + }; match socket.inner.set_nodelay(true) { Ok(()) => {}, diff --git a/proxy/src/proxy/tests.rs b/proxy/src/proxy/tests.rs index e0ec90cb44..ad48af0093 100644 --- a/proxy/src/proxy/tests.rs +++ b/proxy/src/proxy/tests.rs @@ -174,7 +174,7 @@ async fn dummy_proxy( tls: Option, auth: impl TestAuth + Send, ) -> anyhow::Result<()> { - let client = WithClientIp::new(client); + let (client, _) = read_proxy_protocol(client).await?; let mut stream = match handshake(client, tls.as_ref(), false).await? { HandshakeData::Startup(stream, _) => stream, HandshakeData::Cancel(_) => bail!("cancellation not supported"), diff --git a/proxy/src/serverless.rs b/proxy/src/serverless.rs index b0f4026c76..1a0d1f7b0e 100644 --- a/proxy/src/serverless.rs +++ b/proxy/src/serverless.rs @@ -33,7 +33,7 @@ use crate::cancellation::CancellationHandlerMain; use crate::config::ProxyConfig; use crate::context::RequestMonitoring; use crate::metrics::Metrics; -use crate::protocol2::WithClientIp; +use crate::protocol2::read_proxy_protocol; use crate::proxy::run_until_cancelled; use crate::serverless::backend::PoolingBackend; use crate::serverless::http_util::{api_error_into_response, json_response}; @@ -158,9 +158,8 @@ async fn connection_handler( .guard(crate::metrics::Protocol::Http); // handle PROXY protocol - let mut conn = WithClientIp::new(conn); - let peer = match conn.wait_for_addr().await { - Ok(peer) => peer, + let (conn, peer) = match read_proxy_protocol(conn).await { + Ok(c) => c, Err(e) => { tracing::error!(?session_id, %peer_addr, "failed to accept TCP connection: invalid PROXY protocol V2 header: {e:#}"); return;