diff --git a/proxy/src/http/websocket.rs b/proxy/src/http/websocket.rs index 72ae3dc26f..fa66df0469 100644 --- a/proxy/src/http/websocket.rs +++ b/proxy/src/http/websocket.rs @@ -2,6 +2,7 @@ use crate::{ cancellation::CancelMap, config::ProxyConfig, error::io_error, + protocol2::{ProxyProtocolAccept, WithClientIp}, proxy::{handle_client, ClientMode}, }; use bytes::{Buf, Bytes}; @@ -292,6 +293,9 @@ pub async fn task_main( let mut addr_incoming = AddrIncoming::from_listener(ws_listener)?; let _ = addr_incoming.set_nodelay(true); + let addr_incoming = ProxyProtocolAccept { + incoming: addr_incoming, + }; let tls_listener = TlsListener::new(tls_acceptor, addr_incoming).filter(|conn| { if let Err(err) = conn { @@ -302,9 +306,11 @@ pub async fn task_main( } }); - let make_svc = - hyper::service::make_service_fn(|stream: &tokio_rustls::server::TlsStream| { - let sni_name = stream.get_ref().1.server_name().map(|s| s.to_string()); + let make_svc = hyper::service::make_service_fn( + |stream: &tokio_rustls::server::TlsStream>| { + let (io, tls) = stream.get_ref(); + let peer_addr = io.client_addr().unwrap_or(io.inner.remote_addr()); + let sni_name = tls.server_name().map(|s| s.to_string()); let conn_pool = conn_pool.clone(); async move { @@ -319,13 +325,15 @@ pub async fn task_main( ws_handler(req, config, conn_pool, cancel_map, session_id, sni_name) .instrument(info_span!( "ws-client", - session = %session_id + session = %session_id, + %peer_addr, )) .await } })) } - }); + }, + ); hyper::Server::builder(accept::from_stream(tls_listener)) .serve(make_svc) diff --git a/proxy/src/lib.rs b/proxy/src/lib.rs index 1e1e216bb7..a3d1cdd3c8 100644 --- a/proxy/src/lib.rs +++ b/proxy/src/lib.rs @@ -16,6 +16,7 @@ pub mod http; pub mod logging; pub mod metrics; pub mod parse; +pub mod protocol2; pub mod proxy; pub mod sasl; pub mod scram; diff --git a/proxy/src/protocol2.rs b/proxy/src/protocol2.rs new file mode 100644 index 0000000000..1d8931be85 --- /dev/null +++ b/proxy/src/protocol2.rs @@ -0,0 +1,479 @@ +//! Proxy Protocol V2 implementation + +use std::{ + future::poll_fn, + future::Future, + io, + net::SocketAddr, + pin::{pin, Pin}, + task::{ready, Context, Poll}, +}; + +use bytes::{Buf, BytesMut}; +use hyper::server::conn::{AddrIncoming, AddrStream}; +use pin_project_lite::pin_project; +use tls_listener::AsyncAccept; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf}; + +pub struct ProxyProtocolAccept { + pub incoming: AddrIncoming, +} + +pin_project! { + pub struct WithClientIp { + #[pin] + pub inner: T, + buf: BytesMut, + tlv_bytes: u16, + state: ProxyParse, + } +} + +#[derive(Clone, PartialEq, Debug)] +enum ProxyParse { + NotStarted, + + Finished(SocketAddr), + None, +} + +impl AsyncWrite for WithClientIp { + #[inline] + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + self.project().inner.poll_write(cx, buf) + } + + #[inline] + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().inner.poll_flush(cx) + } + + #[inline] + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().inner.poll_shutdown(cx) + } + + #[inline] + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll> { + self.project().inner.poll_write_vectored(cx, bufs) + } + + #[inline] + fn is_write_vectored(&self) -> bool { + self.inner.is_write_vectored() + } +} + +impl WithClientIp { + pub fn new(inner: T) -> Self { + WithClientIp { + inner, + buf: BytesMut::with_capacity(128), + tlv_bytes: 0, + state: ProxyParse::NotStarted, + } + } + + pub fn client_addr(&self) -> Option { + match self.state { + ProxyParse::Finished(socket) => Some(socket), + _ => None, + } + } +} + +impl WithClientIp { + pub async fn wait_for_addr(&mut self) -> io::Result> { + match self.state { + ProxyParse::NotStarted => { + let mut pin = Pin::new(&mut *self); + let addr = poll_fn(|cx| pin.as_mut().poll_client_ip(cx)).await?; + match addr { + Some(addr) => self.state = ProxyParse::Finished(addr), + None => self.state = ProxyParse::None, + } + Ok(addr) + } + ProxyParse::Finished(addr) => Ok(Some(addr)), + ProxyParse::None => Ok(None), + } + } +} + +/// Proxy Protocol Version 2 Header +const HEADER: [u8; 12] = [ + 0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A, +]; + +impl WithClientIp { + /// implementation of + /// Version 2 (Binary Format) + fn poll_client_ip( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + // The binary header format starts with a constant 12 bytes block containing the protocol signature : + // \x0D \x0A \x0D \x0A \x00 \x0D \x0A \x51 \x55 \x49 \x54 \x0A + while self.buf.len() < 16 { + let mut this = self.as_mut().project(); + let bytes_read = pin!(this.inner.read_buf(this.buf)).poll(cx)?; + + // exit for bad header + let len = usize::min(self.buf.len(), HEADER.len()); + if self.buf[..len] != HEADER[..len] { + return Poll::Ready(Ok(None)); + } + + // if no more bytes available then exit + if ready!(bytes_read) == 0 { + return Poll::Ready(Ok(None)); + }; + } + + // The next byte (the 13th one) is the protocol version and command. + // The highest four bits contains the version. As of this specification, it must + // always be sent as \x2 and the receiver must only accept this value. + let vc = self.buf[12]; + let version = vc >> 4; + let command = vc & 0b1111; + if version != 2 { + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::Other, + "invalid proxy protocol version. expected version 2", + ))); + } + match command { + // the connection was established on purpose by the proxy + // without being relayed. The connection endpoints are the sender and the + // receiver. Such connections exist when the proxy sends health-checks to the + // server. The receiver must accept this connection as valid and must use the + // real connection endpoints and discard the protocol block including the + // family which is ignored. + 0 => {} + // the connection was established on behalf of another node, + // and reflects the original connection endpoints. The receiver must then use + // the information provided in the protocol block to get original the address. + 1 => {} + // other values are unassigned and must not be emitted by senders. Receivers + // must drop connections presenting unexpected values here. + _ => { + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::Other, + "invalid proxy protocol command. expected local (0) or proxy (1)", + ))) + } + }; + + // The 14th byte contains the transport protocol and address family. The highest 4 + // bits contain the address family, the lowest 4 bits contain the protocol. + let ft = self.buf[13]; + let address_length = match ft { + // - \x11 : TCP over IPv4 : the forwarded connection uses TCP over the AF_INET + // protocol family. Address length is 2*4 + 2*2 = 12 bytes. + // - \x12 : UDP over IPv4 : the forwarded connection uses UDP over the AF_INET + // protocol family. Address length is 2*4 + 2*2 = 12 bytes. + 0x11 | 0x12 => 12, + // - \x21 : TCP over IPv6 : the forwarded connection uses TCP over the AF_INET6 + // protocol family. Address length is 2*16 + 2*2 = 36 bytes. + // - \x22 : UDP over IPv6 : the forwarded connection uses UDP over the AF_INET6 + // protocol family. Address length is 2*16 + 2*2 = 36 bytes. + 0x21 | 0x22 => 36, + // unspecified or unix stream. ignore the addresses + _ => 0, + }; + + // The 15th and 16th bytes is the address length in bytes in network endian order. + // It is used so that the receiver knows how many address bytes to skip even when + // it does not implement the presented protocol. Thus the length of the protocol + // header in bytes is always exactly 16 + this value. When a sender presents a + // LOCAL connection, it should not present any address so it sets this field to + // zero. Receivers MUST always consider this field to skip the appropriate number + // of bytes and must not assume zero is presented for LOCAL connections. When a + // receiver accepts an incoming connection showing an UNSPEC address family or + // protocol, it may or may not decide to log the address information if present. + let remaining_length = u16::from_be_bytes(self.buf[14..16].try_into().unwrap()); + if remaining_length < address_length { + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::Other, + "invalid proxy protocol length. not enough to fit requested IP addresses", + ))); + } + + while self.buf.len() < 16 + address_length as usize { + let mut this = self.as_mut().project(); + if ready!(pin!(this.inner.read_buf(this.buf)).poll(cx)?) == 0 { + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::UnexpectedEof, + "stream closed while waiting for proxy protocol addresses", + ))); + } + } + + let this = self.as_mut().project(); + + // we are sure this is a proxy protocol v2 entry and we have read all the bytes we need + // discard the header we have parsed + this.buf.advance(16); + + // Starting from the 17th byte, addresses are presented in network byte order. + // The address order is always the same : + // - source layer 3 address in network byte order + // - destination layer 3 address in network byte order + // - source layer 4 address if any, in network byte order (port) + // - destination layer 4 address if any, in network byte order (port) + let addresses = this.buf.split_to(address_length as usize); + let socket = match address_length { + 12 => { + let src_addr: [u8; 4] = addresses[0..4].try_into().unwrap(); + let src_port = u16::from_be_bytes(addresses[8..10].try_into().unwrap()); + Some(SocketAddr::from((src_addr, src_port))) + } + 36 => { + let src_addr: [u8; 16] = addresses[0..16].try_into().unwrap(); + let src_port = u16::from_be_bytes(addresses[32..34].try_into().unwrap()); + Some(SocketAddr::from((src_addr, src_port))) + } + _ => None, + }; + + *this.tlv_bytes = remaining_length - address_length; + self.as_mut().skip_tlv_inner(); + + Poll::Ready(Ok(socket)) + } + + #[cold] + fn read_ip(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let ip = ready!(self.as_mut().poll_client_ip(cx)?); + match ip { + Some(x) => *self.as_mut().project().state = ProxyParse::Finished(x), + None => *self.as_mut().project().state = ProxyParse::None, + } + Poll::Ready(Ok(())) + } + + #[cold] + fn skip_tlv(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.as_mut().project(); + // we know that this.buf is empty + debug_assert_eq!(this.buf.len(), 0); + + this.buf.reserve((*this.tlv_bytes).clamp(0, 1024) as usize); + ready!(pin!(this.inner.read_buf(this.buf)).poll(cx)?); + self.skip_tlv_inner(); + + Poll::Ready(Ok(())) + } + + fn skip_tlv_inner(self: Pin<&mut Self>) { + let tlv_bytes_read = match u16::try_from(self.buf.len()) { + // we read more than u16::MAX therefore we must have read the full tlv_bytes + Err(_) => self.tlv_bytes, + // we might not have read the full tlv bytes yet + Ok(n) => u16::min(n, self.tlv_bytes), + }; + let this = self.project(); + *this.tlv_bytes -= tlv_bytes_read; + this.buf.advance(tlv_bytes_read as usize); + } +} + +impl AsyncRead for WithClientIp { + #[inline] + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + // I'm assuming these 3 comparisons will be easy to branch predict. + // especially with the cold attributes + // which should make this read wrapper almost invisible + + if let ProxyParse::NotStarted = self.state { + ready!(self.as_mut().read_ip(cx)?); + } + + while self.tlv_bytes > 0 { + ready!(self.as_mut().skip_tlv(cx)?) + } + + let this = self.project(); + if this.buf.is_empty() { + this.inner.poll_read(cx, buf) + } else { + // we know that tlv_bytes is 0 + debug_assert_eq!(*this.tlv_bytes, 0); + + let write = usize::min(this.buf.len(), buf.remaining()); + let slice = this.buf.split_to(write).freeze(); + buf.put_slice(&slice); + + // reset the allocation so it can be freed + if this.buf.is_empty() { + *this.buf = BytesMut::new(); + } + + Poll::Ready(Ok(())) + } + } +} + +impl AsyncAccept for ProxyProtocolAccept { + type Connection = WithClientIp; + + type Error = io::Error; + + fn poll_accept( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + let conn = ready!(Pin::new(&mut self.incoming).poll_accept(cx)?); + let Some(conn) = conn else { + return Poll::Ready(None); + }; + + Poll::Ready(Some(Ok(WithClientIp::new(conn)))) + } +} + +#[cfg(test)] +mod tests { + use std::pin::pin; + + use tokio::io::AsyncReadExt; + + use crate::protocol2::{ProxyParse, WithClientIp}; + + #[tokio::test] + async fn test_ipv4() { + let header = super::HEADER + // Proxy command, IPV4 | TCP + .chain([(2 << 4) | 1, (1 << 4) | 1].as_slice()) + // 12 + 3 bytes + .chain([0, 15].as_slice()) + // src ip + .chain([127, 0, 0, 1].as_slice()) + // dst ip + .chain([192, 168, 0, 1].as_slice()) + // src port + .chain([255, 255].as_slice()) + // dst port + .chain([1, 1].as_slice()) + // TLV + .chain([1, 2, 3].as_slice()); + + let extra_data = [0x55; 256]; + + let mut read = pin!(WithClientIp::new(header.chain(extra_data.as_slice()))); + + let mut bytes = vec![]; + read.read_to_end(&mut bytes).await.unwrap(); + + assert_eq!(bytes, extra_data); + assert_eq!( + read.state, + ProxyParse::Finished(([127, 0, 0, 1], 65535).into()) + ); + } + + #[tokio::test] + async fn test_ipv6() { + let header = super::HEADER + // Proxy command, IPV6 | UDP + .chain([(2 << 4) | 1, (2 << 4) | 2].as_slice()) + // 36 + 3 bytes + .chain([0, 39].as_slice()) + // src ip + .chain([15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0].as_slice()) + // dst ip + .chain([0, 15, 1, 14, 2, 13, 3, 12, 4, 11, 5, 10, 6, 9, 7, 8].as_slice()) + // src port + .chain([1, 1].as_slice()) + // dst port + .chain([255, 255].as_slice()) + // TLV + .chain([1, 2, 3].as_slice()); + + let extra_data = [0x55; 256]; + + let mut read = pin!(WithClientIp::new(header.chain(extra_data.as_slice()))); + + let mut bytes = vec![]; + read.read_to_end(&mut bytes).await.unwrap(); + + assert_eq!(bytes, extra_data); + assert_eq!( + read.state, + ProxyParse::Finished( + ([15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0], 257).into() + ) + ); + } + + #[tokio::test] + async fn test_invalid() { + let data = [0x55; 256]; + + let mut read = pin!(WithClientIp::new(data.as_slice())); + + let mut bytes = vec![]; + read.read_to_end(&mut bytes).await.unwrap(); + assert_eq!(bytes, data); + assert_eq!(read.state, ProxyParse::None); + } + + #[tokio::test] + async fn test_short() { + let data = [0x55; 10]; + + let mut read = pin!(WithClientIp::new(data.as_slice())); + + let mut bytes = vec![]; + read.read_to_end(&mut bytes).await.unwrap(); + assert_eq!(bytes, data); + assert_eq!(read.state, ProxyParse::None); + } + + #[tokio::test] + async fn test_large_tlv() { + let tlv = vec![0x55; 32768]; + let len = (12 + tlv.len() as u16).to_be_bytes(); + + let header = super::HEADER + // Proxy command, Inet << 4 | Stream + .chain([(2 << 4) | 1, (1 << 4) | 1].as_slice()) + // 12 + 3 bytes + .chain(len.as_slice()) + // src ip + .chain([55, 56, 57, 58].as_slice()) + // dst ip + .chain([192, 168, 0, 1].as_slice()) + // src port + .chain([255, 255].as_slice()) + // dst port + .chain([1, 1].as_slice()) + // TLV + .chain(tlv.as_slice()); + + let extra_data = [0xaa; 256]; + + let mut read = pin!(WithClientIp::new(header.chain(extra_data.as_slice()))); + + let mut bytes = vec![]; + read.read_to_end(&mut bytes).await.unwrap(); + + assert_eq!(bytes, extra_data); + assert_eq!( + read.state, + ProxyParse::Finished(([55, 56, 57, 58], 65535).into()) + ); + } +} diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index 0267d767ee..66ce2e5fd0 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -7,6 +7,7 @@ use crate::{ compute::{self, PostgresConnection}, config::{ProxyConfig, TlsConfig}, console::{self, errors::WakeComputeError, messages::MetricsAuxInfo, Api}, + protocol2::WithClientIp, stream::{PqStream, Stream}, }; use anyhow::{bail, Context}; @@ -100,7 +101,7 @@ pub async fn task_main( loop { tokio::select! { accept_result = listener.accept() => { - let (socket, peer_addr) = accept_result?; + let (socket, _) = accept_result?; let session_id = uuid::Uuid::new_v4(); let cancel_map = Arc::clone(&cancel_map); @@ -108,13 +109,19 @@ pub async fn task_main( async move { info!("accepted postgres client connection"); + let mut socket = WithClientIp::new(socket); + if let Some(ip) = socket.wait_for_addr().await? { + tracing::Span::current().record("peer_addr", &tracing::field::display(ip)); + } + socket + .inner .set_nodelay(true) .context("failed to set socket option")?; handle_client(config, &cancel_map, session_id, socket, ClientMode::Tcp).await } - .instrument(info_span!("handle_client", ?session_id, %peer_addr)) + .instrument(info_span!("handle_client", ?session_id, peer_addr = tracing::field::Empty)) .unwrap_or_else(move |e| { // Acknowledge that the task has finished with an error. error!(?session_id, "per-client task finished with an error: {e:#}"); diff --git a/proxy/src/proxy/tests.rs b/proxy/src/proxy/tests.rs index 5653ec94dc..99ec8fb090 100644 --- a/proxy/src/proxy/tests.rs +++ b/proxy/src/proxy/tests.rs @@ -137,6 +137,7 @@ async fn dummy_proxy( auth: impl TestAuth + Send, ) -> anyhow::Result<()> { let cancel_map = CancelMap::default(); + let client = WithClientIp::new(client); let (mut stream, _params) = handshake(client, tls.as_ref(), &cancel_map) .await? .context("handshake failed")?;