mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-08 05:52:55 +00:00
update proxy protocol parsing to not a rw wrapper (#12035)
## Problem I believe in all environments we now specify either required/rejected for proxy-protocol V2 as required. We no longer rely on the supported flow. This means we no longer need to keep around read bytes incase they're not in a header. While I designed ChainRW to be fast (the hot path with an empty buffer is very easy to branch predict), it's still unnecessary. ## Summary of changes * Remove the ChainRW wrapper * Refactor how we read the proxy-protocol header using read_exact. Slightly worse perf but it's hardly significant. * Don't try and parse the header if it's rejected.
This commit is contained in:
@@ -221,8 +221,7 @@ struct ProxyCliArgs {
|
||||
is_private_access_proxy: bool,
|
||||
|
||||
/// Configure whether all incoming requests have a Proxy Protocol V2 packet.
|
||||
// TODO(conradludgate): switch default to rejected or required once we've updated all deployments
|
||||
#[clap(value_enum, long, default_value_t = ProxyProtocolV2::Supported)]
|
||||
#[clap(value_enum, long, default_value_t = ProxyProtocolV2::Rejected)]
|
||||
proxy_protocol_v2: ProxyProtocolV2,
|
||||
|
||||
/// Time the proxy waits for the webauth session to be confirmed by the control plane.
|
||||
|
||||
@@ -39,8 +39,6 @@ pub struct ComputeConfig {
|
||||
pub enum ProxyProtocolV2 {
|
||||
/// Connection will error if PROXY protocol v2 header is missing
|
||||
Required,
|
||||
/// Connection will parse PROXY protocol v2 header, but accept the connection if it's missing.
|
||||
Supported,
|
||||
/// Connection will error if PROXY protocol v2 header is provided
|
||||
Rejected,
|
||||
}
|
||||
|
||||
@@ -54,30 +54,24 @@ pub async fn task_main(
|
||||
debug!(protocol = "tcp", %session_id, "accepted new TCP connection");
|
||||
|
||||
connections.spawn(async move {
|
||||
let (socket, peer_addr) = match read_proxy_protocol(socket).await {
|
||||
Err(e) => {
|
||||
error!("per-client task finished with an error: {e:#}");
|
||||
return;
|
||||
let (socket, conn_info) = match config.proxy_protocol_v2 {
|
||||
ProxyProtocolV2::Required => {
|
||||
match read_proxy_protocol(socket).await {
|
||||
Err(e) => {
|
||||
error!("per-client task finished with an error: {e:#}");
|
||||
return;
|
||||
}
|
||||
// our load balancers will not send any more data. let's just exit immediately
|
||||
Ok((_socket, ConnectHeader::Local)) => {
|
||||
debug!("healthcheck received");
|
||||
return;
|
||||
}
|
||||
Ok((socket, ConnectHeader::Proxy(info))) => (socket, info),
|
||||
}
|
||||
}
|
||||
// our load balancers will not send any more data. let's just exit immediately
|
||||
Ok((_socket, ConnectHeader::Local)) => {
|
||||
debug!("healthcheck received");
|
||||
return;
|
||||
}
|
||||
Ok((_socket, ConnectHeader::Missing))
|
||||
if config.proxy_protocol_v2 == ProxyProtocolV2::Required =>
|
||||
{
|
||||
error!("missing required proxy protocol header");
|
||||
return;
|
||||
}
|
||||
Ok((_socket, ConnectHeader::Proxy(_)))
|
||||
if config.proxy_protocol_v2 == ProxyProtocolV2::Rejected =>
|
||||
{
|
||||
error!("proxy protocol header not supported");
|
||||
return;
|
||||
}
|
||||
Ok((socket, ConnectHeader::Proxy(info))) => (socket, info),
|
||||
Ok((socket, ConnectHeader::Missing)) => (
|
||||
// ignore the header - it cannot be confused for a postgres or http connection so will
|
||||
// error later.
|
||||
ProxyProtocolV2::Rejected => (
|
||||
socket,
|
||||
ConnectionInfo {
|
||||
addr: peer_addr,
|
||||
@@ -86,7 +80,7 @@ pub async fn task_main(
|
||||
),
|
||||
};
|
||||
|
||||
match socket.inner.set_nodelay(true) {
|
||||
match socket.set_nodelay(true) {
|
||||
Ok(()) => {}
|
||||
Err(e) => {
|
||||
error!(
|
||||
@@ -98,7 +92,7 @@ pub async fn task_main(
|
||||
|
||||
let ctx = RequestContext::new(
|
||||
session_id,
|
||||
peer_addr,
|
||||
conn_info,
|
||||
crate::metrics::Protocol::Tcp,
|
||||
&config.region,
|
||||
);
|
||||
|
||||
@@ -4,60 +4,13 @@
|
||||
use core::fmt;
|
||||
use std::io;
|
||||
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
|
||||
use bytes::{Buf, Bytes, BytesMut};
|
||||
use pin_project_lite::pin_project;
|
||||
use bytes::Buf;
|
||||
use smol_str::SmolStr;
|
||||
use strum_macros::FromRepr;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf};
|
||||
use tokio::io::{AsyncRead, AsyncReadExt};
|
||||
use zerocopy::{FromBytes, Immutable, KnownLayout, Unaligned, network_endian};
|
||||
|
||||
pin_project! {
|
||||
/// A chained [`AsyncRead`] with [`AsyncWrite`] passthrough
|
||||
pub(crate) struct ChainRW<T> {
|
||||
#[pin]
|
||||
pub(crate) inner: T,
|
||||
buf: BytesMut,
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: AsyncWrite> AsyncWrite for ChainRW<T> {
|
||||
#[inline]
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<Result<usize, io::Error>> {
|
||||
self.project().inner.poll_write(cx, buf)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
|
||||
self.project().inner.poll_flush(cx)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
|
||||
self.project().inner.poll_shutdown(cx)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn poll_write_vectored(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
bufs: &[io::IoSlice<'_>],
|
||||
) -> Poll<Result<usize, io::Error>> {
|
||||
self.project().inner.poll_write_vectored(cx, bufs)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn is_write_vectored(&self) -> bool {
|
||||
self.inner.is_write_vectored()
|
||||
}
|
||||
}
|
||||
|
||||
/// Proxy Protocol Version 2 Header
|
||||
const SIGNATURE: [u8; 12] = [
|
||||
0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A,
|
||||
@@ -79,7 +32,6 @@ pub struct ConnectionInfo {
|
||||
|
||||
#[derive(PartialEq, Eq, Clone, Debug)]
|
||||
pub enum ConnectHeader {
|
||||
Missing,
|
||||
Local,
|
||||
Proxy(ConnectionInfo),
|
||||
}
|
||||
@@ -106,47 +58,24 @@ pub enum ConnectionInfoExtra {
|
||||
|
||||
pub(crate) async fn read_proxy_protocol<T: AsyncRead + Unpin>(
|
||||
mut read: T,
|
||||
) -> std::io::Result<(ChainRW<T>, ConnectHeader)> {
|
||||
let mut buf = BytesMut::with_capacity(128);
|
||||
let header = loop {
|
||||
let bytes_read = read.read_buf(&mut buf).await?;
|
||||
|
||||
// exit for bad header signature
|
||||
let len = usize::min(buf.len(), SIGNATURE.len());
|
||||
if buf[..len] != SIGNATURE[..len] {
|
||||
return Ok((ChainRW { inner: read, buf }, ConnectHeader::Missing));
|
||||
}
|
||||
|
||||
// if no more bytes available then exit
|
||||
if bytes_read == 0 {
|
||||
return Ok((ChainRW { inner: read, buf }, ConnectHeader::Missing));
|
||||
}
|
||||
|
||||
// check if we have enough bytes to continue
|
||||
if let Some(header) = buf.try_get::<ProxyProtocolV2Header>() {
|
||||
break header;
|
||||
}
|
||||
};
|
||||
|
||||
let remaining_length = usize::from(header.len.get());
|
||||
|
||||
while buf.len() < remaining_length {
|
||||
if read.read_buf(&mut buf).await? == 0 {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::UnexpectedEof,
|
||||
"stream closed while waiting for proxy protocol addresses",
|
||||
));
|
||||
}
|
||||
) -> std::io::Result<(T, ConnectHeader)> {
|
||||
let mut header = [0; size_of::<ProxyProtocolV2Header>()];
|
||||
read.read_exact(&mut header).await?;
|
||||
let header: ProxyProtocolV2Header = zerocopy::transmute!(header);
|
||||
if header.signature != SIGNATURE {
|
||||
return Err(std::io::Error::other("invalid proxy protocol header"));
|
||||
}
|
||||
let payload = buf.split_to(remaining_length);
|
||||
|
||||
let res = process_proxy_payload(header, payload)?;
|
||||
Ok((ChainRW { inner: read, buf }, res))
|
||||
let mut payload = vec![0; usize::from(header.len.get())];
|
||||
read.read_exact(&mut payload).await?;
|
||||
|
||||
let res = process_proxy_payload(header, &payload)?;
|
||||
Ok((read, res))
|
||||
}
|
||||
|
||||
fn process_proxy_payload(
|
||||
header: ProxyProtocolV2Header,
|
||||
mut payload: BytesMut,
|
||||
mut payload: &[u8],
|
||||
) -> std::io::Result<ConnectHeader> {
|
||||
match header.version_and_command {
|
||||
// the connection was established on purpose by the proxy
|
||||
@@ -162,13 +91,12 @@ fn process_proxy_payload(
|
||||
PROXY_V2 => {}
|
||||
// other values are unassigned and must not be emitted by senders. Receivers
|
||||
// must drop connections presenting unexpected values here.
|
||||
#[rustfmt::skip] // https://github.com/rust-lang/rustfmt/issues/6384
|
||||
_ => return Err(io::Error::other(
|
||||
format!(
|
||||
_ => {
|
||||
return Err(io::Error::other(format!(
|
||||
"invalid proxy protocol command 0x{:02X}. expected local (0x20) or proxy (0x21)",
|
||||
header.version_and_command
|
||||
),
|
||||
)),
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
let size_err =
|
||||
@@ -206,7 +134,7 @@ fn process_proxy_payload(
|
||||
}
|
||||
let subtype = tlv.value.get_u8();
|
||||
match Pp2AwsType::from_repr(subtype) {
|
||||
Some(Pp2AwsType::VpceId) => match std::str::from_utf8(&tlv.value) {
|
||||
Some(Pp2AwsType::VpceId) => match std::str::from_utf8(tlv.value) {
|
||||
Ok(s) => {
|
||||
extra = Some(ConnectionInfoExtra::Aws { vpce_id: s.into() });
|
||||
}
|
||||
@@ -282,65 +210,28 @@ enum Pp2AzureType {
|
||||
PrivateEndpointLinkId = 0x01,
|
||||
}
|
||||
|
||||
impl<T: AsyncRead> AsyncRead for ChainRW<T> {
|
||||
#[inline]
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
if self.buf.is_empty() {
|
||||
self.project().inner.poll_read(cx, buf)
|
||||
} else {
|
||||
self.read_from_buf(buf)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: AsyncRead> ChainRW<T> {
|
||||
#[cold]
|
||||
fn read_from_buf(self: Pin<&mut Self>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
|
||||
debug_assert!(!self.buf.is_empty());
|
||||
let this = self.project();
|
||||
|
||||
let write = usize::min(this.buf.len(), buf.remaining());
|
||||
let slice = this.buf.split_to(write).freeze();
|
||||
buf.put_slice(&slice);
|
||||
|
||||
// reset the allocation so it can be freed
|
||||
if this.buf.is_empty() {
|
||||
*this.buf = BytesMut::new();
|
||||
}
|
||||
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Tlv {
|
||||
struct Tlv<'a> {
|
||||
kind: u8,
|
||||
value: Bytes,
|
||||
value: &'a [u8],
|
||||
}
|
||||
|
||||
fn read_tlv(b: &mut BytesMut) -> Option<Tlv> {
|
||||
fn read_tlv<'a>(b: &mut &'a [u8]) -> Option<Tlv<'a>> {
|
||||
let tlv_header = b.try_get::<TlvHeader>()?;
|
||||
let len = usize::from(tlv_header.len.get());
|
||||
if b.len() < len {
|
||||
return None;
|
||||
}
|
||||
Some(Tlv {
|
||||
kind: tlv_header.kind,
|
||||
value: b.split_to(len).freeze(),
|
||||
value: b.split_off(..len)?,
|
||||
})
|
||||
}
|
||||
|
||||
trait BufExt: Sized {
|
||||
fn try_get<T: FromBytes>(&mut self) -> Option<T>;
|
||||
}
|
||||
impl BufExt for BytesMut {
|
||||
impl BufExt for &[u8] {
|
||||
fn try_get<T: FromBytes>(&mut self) -> Option<T> {
|
||||
let (res, _) = T::read_from_prefix(self).ok()?;
|
||||
self.advance(size_of::<T>());
|
||||
let (res, rest) = T::read_from_prefix(self).ok()?;
|
||||
*self = rest;
|
||||
Some(res)
|
||||
}
|
||||
}
|
||||
@@ -481,27 +372,19 @@ mod tests {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[should_panic = "invalid proxy protocol header"]
|
||||
async fn test_invalid() {
|
||||
let data = [0x55; 256];
|
||||
|
||||
let (mut read, info) = read_proxy_protocol(data.as_slice()).await.unwrap();
|
||||
|
||||
let mut bytes = vec![];
|
||||
read.read_to_end(&mut bytes).await.unwrap();
|
||||
assert_eq!(bytes, data);
|
||||
assert_eq!(info, ConnectHeader::Missing);
|
||||
read_proxy_protocol(data.as_slice()).await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[should_panic = "early eof"]
|
||||
async fn test_short() {
|
||||
let data = [0x55; 10];
|
||||
|
||||
let (mut read, info) = read_proxy_protocol(data.as_slice()).await.unwrap();
|
||||
|
||||
let mut bytes = vec![];
|
||||
read.read_to_end(&mut bytes).await.unwrap();
|
||||
assert_eq!(bytes, data);
|
||||
assert_eq!(info, ConnectHeader::Missing);
|
||||
read_proxy_protocol(data.as_slice()).await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
||||
@@ -102,30 +102,24 @@ pub async fn task_main(
|
||||
let endpoint_rate_limiter2 = endpoint_rate_limiter.clone();
|
||||
|
||||
connections.spawn(async move {
|
||||
let (socket, conn_info) = match read_proxy_protocol(socket).await {
|
||||
Err(e) => {
|
||||
warn!("per-client task finished with an error: {e:#}");
|
||||
return;
|
||||
let (socket, conn_info) = match config.proxy_protocol_v2 {
|
||||
ProxyProtocolV2::Required => {
|
||||
match read_proxy_protocol(socket).await {
|
||||
Err(e) => {
|
||||
warn!("per-client task finished with an error: {e:#}");
|
||||
return;
|
||||
}
|
||||
// our load balancers will not send any more data. let's just exit immediately
|
||||
Ok((_socket, ConnectHeader::Local)) => {
|
||||
debug!("healthcheck received");
|
||||
return;
|
||||
}
|
||||
Ok((socket, ConnectHeader::Proxy(info))) => (socket, info),
|
||||
}
|
||||
}
|
||||
// our load balancers will not send any more data. let's just exit immediately
|
||||
Ok((_socket, ConnectHeader::Local)) => {
|
||||
debug!("healthcheck received");
|
||||
return;
|
||||
}
|
||||
Ok((_socket, ConnectHeader::Missing))
|
||||
if config.proxy_protocol_v2 == ProxyProtocolV2::Required =>
|
||||
{
|
||||
warn!("missing required proxy protocol header");
|
||||
return;
|
||||
}
|
||||
Ok((_socket, ConnectHeader::Proxy(_)))
|
||||
if config.proxy_protocol_v2 == ProxyProtocolV2::Rejected =>
|
||||
{
|
||||
warn!("proxy protocol header not supported");
|
||||
return;
|
||||
}
|
||||
Ok((socket, ConnectHeader::Proxy(info))) => (socket, info),
|
||||
Ok((socket, ConnectHeader::Missing)) => (
|
||||
// ignore the header - it cannot be confused for a postgres or http connection so will
|
||||
// error later.
|
||||
ProxyProtocolV2::Rejected => (
|
||||
socket,
|
||||
ConnectionInfo {
|
||||
addr: peer_addr,
|
||||
@@ -134,7 +128,7 @@ pub async fn task_main(
|
||||
),
|
||||
};
|
||||
|
||||
match socket.inner.set_nodelay(true) {
|
||||
match socket.set_nodelay(true) {
|
||||
Ok(()) => {}
|
||||
Err(e) => {
|
||||
error!(
|
||||
|
||||
@@ -173,7 +173,6 @@ async fn dummy_proxy(
|
||||
tls: Option<TlsConfig>,
|
||||
auth: impl TestAuth + Send,
|
||||
) -> anyhow::Result<()> {
|
||||
let (client, _) = read_proxy_protocol(client).await?;
|
||||
let mut stream = match handshake(&RequestContext::test(), client, tls.as_ref(), false).await? {
|
||||
HandshakeData::Startup(stream, _) => stream,
|
||||
HandshakeData::Cancel(_) => bail!("cancellation not supported"),
|
||||
|
||||
@@ -49,7 +49,7 @@ use crate::config::{ProxyConfig, ProxyProtocolV2};
|
||||
use crate::context::RequestContext;
|
||||
use crate::ext::TaskExt;
|
||||
use crate::metrics::Metrics;
|
||||
use crate::protocol2::{ChainRW, ConnectHeader, ConnectionInfo, read_proxy_protocol};
|
||||
use crate::protocol2::{ConnectHeader, ConnectionInfo, read_proxy_protocol};
|
||||
use crate::proxy::run_until_cancelled;
|
||||
use crate::rate_limiter::EndpointRateLimiter;
|
||||
use crate::serverless::backend::PoolingBackend;
|
||||
@@ -207,12 +207,12 @@ pub(crate) type AsyncRW = Pin<Box<dyn AsyncReadWrite>>;
|
||||
|
||||
#[async_trait]
|
||||
trait MaybeTlsAcceptor: Send + Sync + 'static {
|
||||
async fn accept(&self, conn: ChainRW<TcpStream>) -> std::io::Result<AsyncRW>;
|
||||
async fn accept(&self, conn: TcpStream) -> std::io::Result<AsyncRW>;
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl MaybeTlsAcceptor for &'static ArcSwapOption<crate::config::TlsConfig> {
|
||||
async fn accept(&self, conn: ChainRW<TcpStream>) -> std::io::Result<AsyncRW> {
|
||||
async fn accept(&self, conn: TcpStream) -> std::io::Result<AsyncRW> {
|
||||
match &*self.load() {
|
||||
Some(config) => Ok(Box::pin(
|
||||
TlsAcceptor::from(config.http_config.clone())
|
||||
@@ -235,33 +235,30 @@ async fn connection_startup(
|
||||
peer_addr: SocketAddr,
|
||||
) -> Option<(AsyncRW, ConnectionInfo)> {
|
||||
// handle PROXY protocol
|
||||
let (conn, peer) = match read_proxy_protocol(conn).await {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
tracing::warn!(?session_id, %peer_addr, "failed to accept TCP connection: invalid PROXY protocol V2 header: {e:#}");
|
||||
return None;
|
||||
let (conn, conn_info) = match config.proxy_protocol_v2 {
|
||||
ProxyProtocolV2::Required => {
|
||||
match read_proxy_protocol(conn).await {
|
||||
Err(e) => {
|
||||
warn!("per-client task finished with an error: {e:#}");
|
||||
return None;
|
||||
}
|
||||
// our load balancers will not send any more data. let's just exit immediately
|
||||
Ok((_conn, ConnectHeader::Local)) => {
|
||||
tracing::debug!("healthcheck received");
|
||||
return None;
|
||||
}
|
||||
Ok((conn, ConnectHeader::Proxy(info))) => (conn, info),
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let conn_info = match peer {
|
||||
// our load balancers will not send any more data. let's just exit immediately
|
||||
ConnectHeader::Local => {
|
||||
tracing::debug!("healthcheck received");
|
||||
return None;
|
||||
}
|
||||
ConnectHeader::Missing if config.proxy_protocol_v2 == ProxyProtocolV2::Required => {
|
||||
tracing::warn!("missing required proxy protocol header");
|
||||
return None;
|
||||
}
|
||||
ConnectHeader::Proxy(_) if config.proxy_protocol_v2 == ProxyProtocolV2::Rejected => {
|
||||
tracing::warn!("proxy protocol header not supported");
|
||||
return None;
|
||||
}
|
||||
ConnectHeader::Proxy(info) => info,
|
||||
ConnectHeader::Missing => ConnectionInfo {
|
||||
addr: peer_addr,
|
||||
extra: None,
|
||||
},
|
||||
// ignore the header - it cannot be confused for a postgres or http connection so will
|
||||
// error later.
|
||||
ProxyProtocolV2::Rejected => (
|
||||
conn,
|
||||
ConnectionInfo {
|
||||
addr: peer_addr,
|
||||
extra: None,
|
||||
},
|
||||
),
|
||||
};
|
||||
|
||||
let has_private_peer_addr = match conn_info.addr.ip() {
|
||||
|
||||
Reference in New Issue
Block a user