Compare commits

...

3 Commits

Author          SHA1        Message                                                     Date
Conrad Ludgate  bea8269a17  zerocopy parsing                                            2024-11-04 13:07:27 +00:00
Conrad Ludgate  3500a758af  parse ip addr for consistency                               2024-11-04 11:19:36 +00:00
Conrad Ludgate  3b3c2da57f  [proxy]: parse proxy protocol TLVs with aws/azure support  2024-11-01 18:11:08 +00:00
9 changed files with 363 additions and 86 deletions
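
The bulk of the change is replacing manual byte-slicing of the PROXY protocol v2 header with zerocopy-derived wire structs, and parsing vendor TLVs (AWS VPC endpoint IDs, Azure private-link IDs) into a new ConnectionInfo type that flows through RequestMonitoring. As a rough sketch of the pattern the diff relies on (the struct and field names here are illustrative, not from the diff; assumes zerocopy 0.8 with the "derive" feature, as added in Cargo.toml below), a #[repr(C)] struct deriving FromBytes can be read straight out of a byte buffer, with network_endian integer wrappers handling byte order:

// Minimal sketch of derive-based wire parsing with zerocopy 0.8.
use std::mem::size_of;

use zerocopy::byteorder::network_endian::U16;
use zerocopy::{FromBytes, Immutable, KnownLayout, Unaligned};

#[derive(FromBytes, KnownLayout, Immutable, Unaligned, Copy, Clone)]
#[repr(C)]
struct WireHeader {
    magic: [u8; 4],
    len: U16, // stored big-endian on the wire; `.get()` converts to host order
}

fn parse_header(buf: &[u8]) -> Option<WireHeader> {
    // read_from_bytes wants exactly size_of::<WireHeader>() bytes and fails otherwise,
    // so short reads surface as an error instead of a panic or out-of-bounds slice.
    let prefix = buf.get(..size_of::<WireHeader>())?;
    WireHeader::read_from_bytes(prefix).ok()
}

This is the mechanism the BufExt::try_get helper in protocol2.rs builds on: split off size_of::<T>() bytes, then let read_from_bytes report a size error if the buffer was too short.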

Cargo.lock (generated)
View File

@@ -34,7 +34,7 @@ dependencies = [
"getrandom 0.2.11",
"once_cell",
"version_check",
"zerocopy",
"zerocopy 0.7.31",
]
[[package]]
@@ -951,7 +951,7 @@ dependencies = [
"bitflags 2.4.1",
"cexpr",
"clang-sys",
"itertools 0.10.5",
"itertools 0.12.1",
"lazy_static",
"lazycell",
"log",
@@ -974,7 +974,7 @@ dependencies = [
"bitflags 2.4.1",
"cexpr",
"clang-sys",
"itertools 0.10.5",
"itertools 0.12.1",
"log",
"prettyplease",
"proc-macro2",
@@ -4314,7 +4314,7 @@ checksum = "0c1318b19085f08681016926435853bbf7858f9c082d0999b80550ff5d9abe15"
dependencies = [
"bytes",
"heck 0.5.0",
"itertools 0.10.5",
"itertools 0.12.1",
"log",
"multimap",
"once_cell",
@@ -4334,7 +4334,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e9552f850d5f0964a4e4d0bf306459ac29323ddfbae05e35a7c0d35cb0803cc5"
dependencies = [
"anyhow",
"itertools 0.10.5",
"itertools 0.12.1",
"proc-macro2",
"quote",
"syn 2.0.52",
@@ -4433,6 +4433,8 @@ dependencies = [
"smallvec",
"smol_str",
"socket2",
"strum",
"strum_macros",
"subtle",
"thiserror",
"tikv-jemalloc-ctl",
@@ -4455,6 +4457,7 @@ dependencies = [
"walkdir",
"workspace_hack",
"x509-parser",
"zerocopy 0.8.8",
]
[[package]]
@@ -7550,7 +7553,16 @@ version = "0.7.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1c4061bedbb353041c12f413700357bec76df2c7e2ca8e4df8bac24c6bf68e3d"
dependencies = [
"zerocopy-derive",
"zerocopy-derive 0.7.31",
]
[[package]]
name = "zerocopy"
version = "0.8.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a4e33e6dce36f2adba29746927f8e848ba70989fdb61c772773bbdda8b5d6a7"
dependencies = [
"zerocopy-derive 0.8.8",
]
[[package]]
@@ -7564,6 +7576,17 @@ dependencies = [
"syn 2.0.52",
]
[[package]]
name = "zerocopy-derive"
version = "0.8.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3cd137b4cc21bde6ecce3bbbb3350130872cda0be2c6888874279ea76e17d4c1"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.52",
]
[[package]]
name = "zeroize"
version = "1.7.0"

View File

@@ -74,6 +74,8 @@ sha2 = { workspace = true, features = ["asm", "oid"] }
smol_str.workspace = true
smallvec.workspace = true
socket2.workspace = true
strum.workspace = true
strum_macros.workspace = true
subtle.workspace = true
thiserror.workspace = true
tikv-jemallocator.workspace = true
@@ -96,6 +98,7 @@ rustls-native-certs.workspace = true
x509-parser.workspace = true
postgres-protocol.workspace = true
redis.workspace = true
zerocopy = { version = "0.8", features = ["derive"] }
# jwt stuff
jose-jwa = "0.1.2"
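
The strum / strum_macros dependencies are pulled in for the FromRepr derive used by the TLV enums later in the diff. Roughly, FromRepr generates a from_repr constructor that maps a raw discriminant back to the enum and returns None for unknown values; a minimal sketch (the enum here is illustrative, the real discriminants are defined in protocol2.rs further down):

use strum_macros::FromRepr;

#[derive(FromRepr, Debug, PartialEq, Copy, Clone)]
#[repr(u8)]
enum TlvKind {
    Alpn = 0x01,
    Aws = 0xEA,
}

fn main() {
    // from_repr returns None for discriminants it does not know, which is how the
    // TLV loop tolerates vendor types it does not recognise.
    assert_eq!(TlvKind::from_repr(0xEA), Some(TlvKind::Aws));
    assert_eq!(TlvKind::from_repr(0x42), None);
}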

View File

@@ -13,6 +13,7 @@ use itertools::Itertools;
use proxy::config::TlsServerEndPoint;
use proxy::context::RequestMonitoring;
use proxy::metrics::{Metrics, ThreadPoolMetrics};
use proxy::protocol2::ConnectionInfo;
use proxy::proxy::{copy_bidirectional_client_compute, run_until_cancelled, ErrorSource};
use proxy::stream::{PqStream, Stream};
use rustls::crypto::aws_lc_rs;
@@ -179,7 +180,10 @@ async fn task_main(
info!(%peer_addr, "serving");
let ctx = RequestMonitoring::new(
session_id,
peer_addr.ip(),
ConnectionInfo {
addr: peer_addr,
extra: None,
},
proxy::metrics::Protocol::SniRouter,
"sni",
);

View File

@@ -11,7 +11,7 @@ use crate::config::{ProxyConfig, ProxyProtocolV2};
use crate::context::RequestMonitoring;
use crate::error::ReportableError;
use crate::metrics::{Metrics, NumClientConnectionsGuard};
use crate::protocol2::read_proxy_protocol;
use crate::protocol2::{read_proxy_protocol, ConnectionInfo};
use crate::proxy::connect_compute::{connect_to_compute, TcpMechanism};
use crate::proxy::handshake::{handshake, HandshakeData};
use crate::proxy::passthrough::ProxyPassthrough;
@@ -65,8 +65,8 @@ pub async fn task_main(
error!("proxy protocol header not supported");
return;
}
Ok((socket, Some(addr))) => (socket, addr.ip()),
Ok((socket, None)) => (socket, peer_addr.ip()),
Ok((socket, Some(info))) => (socket, info),
Ok((socket, None)) => (socket, ConnectionInfo{ addr: peer_addr, extra: None }),
};
match socket.inner.set_nodelay(true) {

View File

@@ -19,6 +19,7 @@ use crate::intern::{BranchIdInt, ProjectIdInt};
use crate::metrics::{
ConnectOutcome, InvalidEndpointsGroup, LatencyTimer, Metrics, Protocol, Waiting,
};
use crate::protocol2::ConnectionInfo;
use crate::types::{DbName, EndpointId, RoleName};
pub mod parquet;
@@ -40,7 +41,7 @@ pub struct RequestMonitoring(
);
struct RequestMonitoringInner {
pub(crate) peer_addr: IpAddr,
pub(crate) conn_info: ConnectionInfo,
pub(crate) session_id: Uuid,
pub(crate) protocol: Protocol,
first_packet: chrono::DateTime<Utc>,
@@ -84,7 +85,7 @@ impl Clone for RequestMonitoring {
fn clone(&self) -> Self {
let inner = self.0.try_lock().expect("should not deadlock");
let new = RequestMonitoringInner {
peer_addr: inner.peer_addr,
conn_info: inner.conn_info.clone(),
session_id: inner.session_id,
protocol: inner.protocol,
first_packet: inner.first_packet,
@@ -117,7 +118,7 @@ impl Clone for RequestMonitoring {
impl RequestMonitoring {
pub fn new(
session_id: Uuid,
peer_addr: IpAddr,
conn_info: ConnectionInfo,
protocol: Protocol,
region: &'static str,
) -> Self {
@@ -125,13 +126,13 @@ impl RequestMonitoring {
"connect_request",
%protocol,
?session_id,
%peer_addr,
%conn_info,
ep = tracing::field::Empty,
role = tracing::field::Empty,
);
let inner = RequestMonitoringInner {
peer_addr,
conn_info,
session_id,
protocol,
first_packet: Utc::now(),
@@ -162,7 +163,11 @@ impl RequestMonitoring {
#[cfg(test)]
pub(crate) fn test() -> Self {
RequestMonitoring::new(Uuid::now_v7(), [127, 0, 0, 1].into(), Protocol::Tcp, "test")
use std::net::SocketAddr;
let ip = IpAddr::from([127, 0, 0, 1]);
let addr = SocketAddr::new(ip, 5432);
let conn_info = ConnectionInfo { addr, extra: None };
RequestMonitoring::new(Uuid::now_v7(), conn_info, Protocol::Tcp, "test")
}
pub(crate) fn console_application_name(&self) -> String {
@@ -286,7 +291,12 @@ impl RequestMonitoring {
}
pub(crate) fn peer_addr(&self) -> IpAddr {
self.0.try_lock().expect("should not deadlock").peer_addr
self.0
.try_lock()
.expect("should not deadlock")
.conn_info
.addr
.ip()
}
pub(crate) fn cold_start_info(&self) -> ColdStartInfo {
@@ -362,7 +372,7 @@ impl RequestMonitoringInner {
}
fn has_private_peer_addr(&self) -> bool {
match self.peer_addr {
match self.conn_info.addr.ip() {
IpAddr::V4(ip) => ip.is_private(),
IpAddr::V6(_) => false,
}
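
has_private_peer_addr now reads the IP out of conn_info.addr but keeps the same policy: only IPv4 private ranges count as private, and IPv6 peers never do. A self-contained restatement of that check using only the standard library (the function name is made up for illustration):

use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};

fn is_private_peer(addr: IpAddr) -> bool {
    match addr {
        // Ipv4Addr::is_private covers 10.0.0.0/8, 172.16.0.0/12 and 192.168.0.0/16
        IpAddr::V4(ip) => ip.is_private(),
        // as in the diff, IPv6 peers are never treated as private
        IpAddr::V6(_) => false,
    }
}

fn main() {
    assert!(is_private_peer(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1))));
    assert!(!is_private_peer(IpAddr::V6(Ipv6Addr::LOCALHOST)));
}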

View File

@@ -121,7 +121,7 @@ impl From<&RequestMonitoringInner> for RequestData {
fn from(value: &RequestMonitoringInner) -> Self {
Self {
session_id: value.session_id,
peer_addr: value.peer_addr.to_string(),
peer_addr: value.conn_info.addr.ip().to_string(),
timestamp: value.first_packet.naive_utc(),
username: value.user.as_deref().map(String::from),
application_name: value.application.as_deref().map(String::from),

View File

@@ -1,13 +1,17 @@
//! Proxy Protocol V2 implementation
//! Compatible with <https://www.haproxy.org/download/3.1/doc/proxy-protocol.txt>
use core::fmt;
use std::io;
use std::net::SocketAddr;
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};
use std::pin::Pin;
use std::task::{Context, Poll};
use bytes::BytesMut;
use bytes::{Buf, Bytes, BytesMut};
use pin_project_lite::pin_project;
use strum_macros::FromRepr;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf};
use zerocopy::{FromBytes, Immutable, KnownLayout, Unaligned};
pin_project! {
/// A chained [`AsyncRead`] with [`AsyncWrite`] passthrough
@@ -58,9 +62,84 @@ const HEADER: [u8; 12] = [
0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A,
];
#[derive(PartialEq, Eq, Clone, Debug)]
pub struct ConnectionInfo {
pub addr: SocketAddr,
pub extra: Option<ConnectionInfoExtra>,
}
impl fmt::Display for ConnectionInfo {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match &self.extra {
None => self.addr.ip().fmt(f),
Some(ConnectionInfoExtra::Aws { vpce_id }) => {
write!(f, "vpce_id[{vpce_id:?}]:addr[{}]", self.addr.ip())
}
Some(ConnectionInfoExtra::Azure { link_id }) => {
write!(f, "link_id[{link_id}]:addr[{}]", self.addr.ip())
}
}
}
}
#[derive(PartialEq, Eq, Clone, Debug)]
pub enum ConnectionInfoExtra {
Aws { vpce_id: Bytes },
Azure { link_id: u32 },
}
#[derive(FromBytes, KnownLayout, Immutable, Unaligned, Copy, Clone)]
#[repr(C)]
struct ProxyProtocolV2Header {
identifier: [u8; 12],
version_and_command: u8,
protocol_and_family: u8,
len: zerocopy::byteorder::network_endian::U16,
}
#[derive(FromBytes, KnownLayout, Immutable, Unaligned, Copy, Clone)]
#[repr(C)]
struct ProxyProtocolV2HeaderV4 {
src_addr: NetworkEndianIpv4,
dst_addr: NetworkEndianIpv4,
src_port: zerocopy::byteorder::network_endian::U16,
dst_port: zerocopy::byteorder::network_endian::U16,
}
#[derive(FromBytes, KnownLayout, Immutable, Unaligned, Copy, Clone)]
#[repr(C)]
struct ProxyProtocolV2HeaderV6 {
src_addr: NetworkEndianIpv6,
dst_addr: NetworkEndianIpv6,
src_port: zerocopy::byteorder::network_endian::U16,
dst_port: zerocopy::byteorder::network_endian::U16,
}
#[derive(FromBytes, KnownLayout, Immutable, Unaligned, Copy, Clone)]
#[repr(transparent)]
struct NetworkEndianIpv4(zerocopy::byteorder::network_endian::U32);
impl NetworkEndianIpv4 {
#[inline]
fn get(self) -> Ipv4Addr {
Ipv4Addr::from_bits(self.0.get())
}
}
#[derive(FromBytes, KnownLayout, Immutable, Unaligned, Copy, Clone)]
#[repr(transparent)]
struct NetworkEndianIpv6(zerocopy::byteorder::network_endian::U128);
impl NetworkEndianIpv6 {
#[inline]
fn get(self) -> Ipv6Addr {
Ipv6Addr::from_bits(self.0.get())
}
}
pub(crate) async fn read_proxy_protocol<T: AsyncRead + Unpin>(
mut read: T,
) -> std::io::Result<(ChainRW<T>, Option<SocketAddr>)> {
) -> std::io::Result<(ChainRW<T>, Option<ConnectionInfo>)> {
let mut buf = BytesMut::with_capacity(128);
while buf.len() < 16 {
let bytes_read = read.read_buf(&mut buf).await?;
@@ -77,14 +156,15 @@ pub(crate) async fn read_proxy_protocol<T: AsyncRead + Unpin>(
};
}
let header = buf.split_to(16);
let header = buf
.try_get::<ProxyProtocolV2Header>()
.expect("we have checked the length already, so this should not panic");
// The next byte (the 13th one) is the protocol version and command.
// The highest four bits contains the version. As of this specification, it must
// always be sent as \x2 and the receiver must only accept this value.
let vc = header[12];
let version = vc >> 4;
let command = vc & 0b1111;
let version = header.version_and_command >> 4;
let command = header.version_and_command & 0b1111;
if version != 2 {
return Err(io::Error::new(
io::ErrorKind::Other,
@@ -115,18 +195,13 @@ pub(crate) async fn read_proxy_protocol<T: AsyncRead + Unpin>(
// The 14th byte contains the transport protocol and address family. The highest 4
// bits contain the address family, the lowest 4 bits contain the protocol.
let ft = header[13];
let address_length = match ft {
// - \x11 : TCP over IPv4 : the forwarded connection uses TCP over the AF_INET
// protocol family. Address length is 2*4 + 2*2 = 12 bytes.
// - \x12 : UDP over IPv4 : the forwarded connection uses UDP over the AF_INET
// protocol family. Address length is 2*4 + 2*2 = 12 bytes.
0x11 | 0x12 => 12,
// - \x21 : TCP over IPv6 : the forwarded connection uses TCP over the AF_INET6
// protocol family. Address length is 2*16 + 2*2 = 36 bytes.
// - \x22 : UDP over IPv6 : the forwarded connection uses UDP over the AF_INET6
// protocol family. Address length is 2*16 + 2*2 = 36 bytes.
0x21 | 0x22 => 36,
let address_length = match header.protocol_and_family {
// - \x11 : TCP over IPv4 : the forwarded connection uses TCP over the AF_INET protocol family.
// - \x12 : UDP over IPv4 : the forwarded connection uses UDP over the AF_INET protocol family.
0x11 | 0x12 => size_of::<ProxyProtocolV2HeaderV4>(),
// - \x21 : TCP over IPv6 : the forwarded connection uses TCP over the AF_INET6 protocol family.
// - \x22 : UDP over IPv6 : the forwarded connection uses UDP over the AF_INET6 protocol family.
0x21 | 0x22 => size_of::<ProxyProtocolV2HeaderV6>(),
// unspecified or unix stream. ignore the addresses
_ => 0,
};
@@ -140,16 +215,15 @@ pub(crate) async fn read_proxy_protocol<T: AsyncRead + Unpin>(
// of bytes and must not assume zero is presented for LOCAL connections. When a
// receiver accepts an incoming connection showing an UNSPEC address family or
// protocol, it may or may not decide to log the address information if present.
let remaining_length = u16::from_be_bytes(header[14..16].try_into().unwrap());
let remaining_length = usize::from(header.len.get());
if remaining_length < address_length {
return Err(io::Error::new(
io::ErrorKind::Other,
"invalid proxy protocol length. not enough to fit requested IP addresses",
));
}
drop(header);
while buf.len() < remaining_length as usize {
while buf.len() < remaining_length {
if read.read_buf(&mut buf).await? == 0 {
return Err(io::Error::new(
io::ErrorKind::UnexpectedEof,
@@ -164,22 +238,106 @@ pub(crate) async fn read_proxy_protocol<T: AsyncRead + Unpin>(
// - destination layer 3 address in network byte order
// - source layer 4 address if any, in network byte order (port)
// - destination layer 4 address if any, in network byte order (port)
let addresses = buf.split_to(remaining_length as usize);
let mut header = buf.split_to(remaining_length);
let socket = match address_length {
12 => {
let src_addr: [u8; 4] = addresses[0..4].try_into().unwrap();
let src_port = u16::from_be_bytes(addresses[8..10].try_into().unwrap());
Some(SocketAddr::from((src_addr, src_port)))
let addr = header
.try_get::<ProxyProtocolV2HeaderV4>()
.expect("we have verified that 12 bytes are in the buf");
Some(SocketAddr::from((addr.src_addr.get(), addr.src_port.get())))
}
36 => {
let src_addr: [u8; 16] = addresses[0..16].try_into().unwrap();
let src_port = u16::from_be_bytes(addresses[32..34].try_into().unwrap());
Some(SocketAddr::from((src_addr, src_port)))
let addr = header
.try_get::<ProxyProtocolV2HeaderV6>()
.expect("we have verified that 36 bytes are in the buf");
Some(SocketAddr::from((addr.src_addr.get(), addr.src_port.get())))
}
_ => None,
};
Ok((ChainRW { inner: read, buf }, socket))
let mut extra = None;
while let Some(mut tlv) = read_tlv(&mut header) {
match Pp2Kind::from_repr(tlv.kind) {
Some(Pp2Kind::Aws) => {
if tlv.value.is_empty() {
tracing::warn!("invalid aws tlv: no subtype");
}
let subtype = tlv.value.get_u8();
match Pp2AwsType::from_repr(subtype) {
Some(Pp2AwsType::VpceId) => {
extra = Some(ConnectionInfoExtra::Aws { vpce_id: tlv.value });
}
None => {
tracing::warn!("unknown aws tlv: subtype={subtype}");
}
}
}
Some(Pp2Kind::Azure) => {
if tlv.value.is_empty() {
tracing::warn!("invalid azure tlv: no subtype");
}
let subtype = tlv.value.get_u8();
match Pp2AzureType::from_repr(subtype) {
Some(Pp2AzureType::PrivateEndpointLinkId) => {
if tlv.value.len() != 4 {
tracing::warn!("invalid azure link_id: {:?}", tlv.value);
}
extra = Some(ConnectionInfoExtra::Azure {
link_id: tlv.value.get_u32_le(),
});
}
None => {
tracing::warn!("unknown azure tlv: subtype={subtype}");
}
}
}
Some(kind) => {
tracing::debug!("unused tlv[{kind:?}]: {:?}", tlv.value);
}
None => {
tracing::debug!("unknown tlv: {tlv:?}");
}
}
}
let conn_info = socket.map(|addr| ConnectionInfo { addr, extra });
Ok((ChainRW { inner: read, buf }, conn_info))
}
#[derive(FromRepr, Debug, Copy, Clone)]
#[repr(u8)]
enum Pp2Kind {
// The following are defined by https://www.haproxy.org/download/3.1/doc/proxy-protocol.txt
// we don't use these but it would be interesting to know what's available
Alpn = 0x01,
Authority = 0x02,
Crc32C = 0x03,
Noop = 0x04,
UniqueId = 0x05,
Ssl = 0x20,
NetNs = 0x30,
/// <https://docs.aws.amazon.com/elasticloadbalancing/latest/network/edit-target-group-attributes.html#proxy-protocol>
Aws = 0xEA,
/// <https://learn.microsoft.com/en-us/azure/private-link/private-link-service-overview#getting-connection-information-using-tcp-proxy-v2>
Azure = 0xEE,
}
#[derive(FromRepr, Debug, Copy, Clone)]
#[repr(u8)]
enum Pp2AwsType {
VpceId = 0x01,
}
#[derive(FromRepr, Debug, Copy, Clone)]
#[repr(u8)]
enum Pp2AzureType {
PrivateEndpointLinkId = 0x01,
}
impl<T: AsyncRead> AsyncRead for ChainRW<T> {
@@ -216,6 +374,60 @@ impl<T: AsyncRead> ChainRW<T> {
}
}
#[derive(FromBytes, KnownLayout, Immutable, Unaligned, Copy, Clone)]
#[repr(C)]
struct TlvHeader {
kind: u8,
len: zerocopy::byteorder::network_endian::U16,
}
#[derive(Debug)]
struct Tlv {
kind: u8,
value: Bytes,
}
fn read_tlv(b: &mut BytesMut) -> Option<Tlv> {
let tlv_header = b.try_get::<TlvHeader>().ok()?;
let len = usize::from(tlv_header.len.get());
if b.len() < len {
return None;
}
let value = b.split_to(len).freeze();
Some(Tlv {
kind: tlv_header.kind,
value,
})
}
trait BufExt: Sized {
fn try_get<T: zerocopy::FromBytes>(&mut self)
-> Result<T, zerocopy::error::SizeError<Self, T>>;
// fn peek<T: zerocopy::FromBytes + zerocopy::KnownLayout + zerocopy::Immutable>(
// &self,
// ) -> Option<&T>;
}
impl BufExt for BytesMut {
fn try_get<T: zerocopy::FromBytes>(
&mut self,
) -> Result<T, zerocopy::error::SizeError<Self, T>> {
let len = size_of::<T>();
// this will error in the read_from_bytes if the buf is too small
let len = usize::min(len, self.len());
let buf = self.split_to(len);
T::read_from_bytes(&buf).map_err(|e| e.map_src(|_| buf.clone()))
}
// fn peek<T: zerocopy::FromBytes + zerocopy::KnownLayout + zerocopy::Immutable>(
// &self,
// ) -> Option<&T> {
// T::ref_from_prefix(self).ok().map(|(t, _)| t)
// }
}
#[cfg(test)]
mod tests {
use tokio::io::AsyncReadExt;
@@ -242,7 +454,7 @@ mod tests {
let extra_data = [0x55; 256];
let (mut read, addr) = read_proxy_protocol(header.chain(extra_data.as_slice()))
let (mut read, info) = read_proxy_protocol(header.chain(extra_data.as_slice()))
.await
.unwrap();
@@ -250,7 +462,9 @@ mod tests {
read.read_to_end(&mut bytes).await.unwrap();
assert_eq!(bytes, extra_data);
assert_eq!(addr, Some(([127, 0, 0, 1], 65535).into()));
let info = info.unwrap();
assert_eq!(info.addr, ([127, 0, 0, 1], 65535).into());
}
#[tokio::test]
@@ -273,7 +487,7 @@ mod tests {
let extra_data = [0x55; 256];
let (mut read, addr) = read_proxy_protocol(header.chain(extra_data.as_slice()))
let (mut read, info) = read_proxy_protocol(header.chain(extra_data.as_slice()))
.await
.unwrap();
@@ -281,9 +495,11 @@ mod tests {
read.read_to_end(&mut bytes).await.unwrap();
assert_eq!(bytes, extra_data);
let info = info.unwrap();
assert_eq!(
addr,
Some(([15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0], 257).into())
info.addr,
([15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0], 257).into()
);
}
@@ -291,30 +507,31 @@ mod tests {
async fn test_invalid() {
let data = [0x55; 256];
let (mut read, addr) = read_proxy_protocol(data.as_slice()).await.unwrap();
let (mut read, info) = read_proxy_protocol(data.as_slice()).await.unwrap();
let mut bytes = vec![];
read.read_to_end(&mut bytes).await.unwrap();
assert_eq!(bytes, data);
assert_eq!(addr, None);
assert_eq!(info, None);
}
#[tokio::test]
async fn test_short() {
let data = [0x55; 10];
let (mut read, addr) = read_proxy_protocol(data.as_slice()).await.unwrap();
let (mut read, info) = read_proxy_protocol(data.as_slice()).await.unwrap();
let mut bytes = vec![];
read.read_to_end(&mut bytes).await.unwrap();
assert_eq!(bytes, data);
assert_eq!(addr, None);
assert_eq!(info, None);
}
#[tokio::test]
async fn test_large_tlv() {
let tlv = vec![0x55; 32768];
let len = (12 + tlv.len() as u16).to_be_bytes();
let tlv_len = (tlv.len() as u16).to_be_bytes();
let len = (12 + 3 + tlv.len() as u16).to_be_bytes();
let header = super::HEADER
// Proxy command, Inet << 4 | Stream
@@ -330,11 +547,13 @@ mod tests {
// dst port
.chain([1, 1].as_slice())
// TLV
.chain([255].as_slice())
.chain(tlv_len.as_slice())
.chain(tlv.as_slice());
let extra_data = [0xaa; 256];
let (mut read, addr) = read_proxy_protocol(header.chain(extra_data.as_slice()))
let (mut read, info) = read_proxy_protocol(header.chain(extra_data.as_slice()))
.await
.unwrap();
@@ -342,6 +561,8 @@ mod tests {
read.read_to_end(&mut bytes).await.unwrap();
assert_eq!(bytes, extra_data);
assert_eq!(addr, Some(([55, 56, 57, 58], 65535).into()));
let info = info.unwrap();
assert_eq!(info.addr, ([55, 56, 57, 58], 65535).into());
}
}
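
For reference, the TLVs consumed by the loop above are framed as a one-byte type, a two-byte big-endian length, and then length bytes of value, which is exactly what read_tlv decodes. A standalone sketch over a plain Bytes buffer (it mirrors read_tlv but is not that code, and the example payload is invented):

use bytes::{Buf, Bytes};

fn next_tlv(buf: &mut Bytes) -> Option<(u8, Bytes)> {
    if buf.len() < 3 {
        return None;
    }
    let kind = buf.get_u8();
    let len = usize::from(buf.get_u16()); // length is big-endian on the wire
    if buf.len() < len {
        return None;
    }
    Some((kind, buf.split_to(len)))
}

fn main() {
    // kind 0xEA (AWS), length 9: one subtype byte (0x01 = VPCE id) plus "vpce-abc"
    let mut buf = Bytes::from_static(&[
        0xEA, 0x00, 0x09, 0x01, b'v', b'p', b'c', b'e', b'-', b'a', b'b', b'c',
    ]);
    let (kind, value) = next_tlv(&mut buf).unwrap();
    assert_eq!(kind, 0xEA);
    assert_eq!(value[0], 0x01);
    assert_eq!(&value[1..], b"vpce-abc");
}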

View File

@@ -28,7 +28,7 @@ use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig};
use crate::context::RequestMonitoring;
use crate::error::ReportableError;
use crate::metrics::{Metrics, NumClientConnectionsGuard};
use crate::protocol2::read_proxy_protocol;
use crate::protocol2::{read_proxy_protocol, ConnectionInfo};
use crate::proxy::handshake::{handshake, HandshakeData};
use crate::rate_limiter::EndpointRateLimiter;
use crate::stream::{PqStream, Stream};
@@ -87,7 +87,7 @@ pub async fn task_main(
let endpoint_rate_limiter2 = endpoint_rate_limiter.clone();
connections.spawn(async move {
let (socket, peer_addr) = match read_proxy_protocol(socket).await {
let (socket, conn_info) = match read_proxy_protocol(socket).await {
Err(e) => {
warn!("per-client task finished with an error: {e:#}");
return;
@@ -100,8 +100,8 @@ pub async fn task_main(
warn!("proxy protocol header not supported");
return;
}
Ok((socket, Some(addr))) => (socket, addr.ip()),
Ok((socket, None)) => (socket, peer_addr.ip()),
Ok((socket, Some(info))) => (socket, info),
Ok((socket, None)) => (socket, ConnectionInfo { addr: peer_addr, extra: None }),
};
match socket.inner.set_nodelay(true) {
@@ -114,7 +114,7 @@ pub async fn task_main(
let ctx = RequestMonitoring::new(
session_id,
peer_addr,
conn_info,
crate::metrics::Protocol::Tcp,
&config.region,
);

View File

@@ -44,10 +44,10 @@ use tracing::{info, warn, Instrument};
use utils::http::error::ApiError;
use crate::cancellation::CancellationHandlerMain;
use crate::config::ProxyConfig;
use crate::config::{ProxyConfig, ProxyProtocolV2};
use crate::context::RequestMonitoring;
use crate::metrics::Metrics;
use crate::protocol2::{read_proxy_protocol, ChainRW};
use crate::protocol2::{read_proxy_protocol, ChainRW, ConnectionInfo};
use crate::proxy::run_until_cancelled;
use crate::rate_limiter::EndpointRateLimiter;
use crate::serverless::backend::PoolingBackend;
@@ -180,7 +180,7 @@ pub async fn task_main(
peer_addr,
))
.await;
let Some((conn, peer_addr)) = startup_result else {
let Some((conn, conn_info)) = startup_result else {
return;
};
@@ -192,7 +192,7 @@ pub async fn task_main(
endpoint_rate_limiter,
conn_token,
conn,
peer_addr,
conn_info,
session_id,
))
.await;
@@ -240,7 +240,7 @@ async fn connection_startup(
session_id: uuid::Uuid,
conn: TcpStream,
peer_addr: SocketAddr,
) -> Option<(AsyncRW, IpAddr)> {
) -> Option<(AsyncRW, ConnectionInfo)> {
// handle PROXY protocol
let (conn, peer) = match read_proxy_protocol(conn).await {
Ok(c) => c,
@@ -250,17 +250,32 @@ async fn connection_startup(
}
};
let peer_addr = peer.unwrap_or(peer_addr).ip();
let has_private_peer_addr = match peer_addr {
let conn_info = match peer {
None if config.proxy_protocol_v2 == ProxyProtocolV2::Required => {
tracing::warn!("missing required proxy protocol header");
return None;
}
Some(_) if config.proxy_protocol_v2 == ProxyProtocolV2::Rejected => {
tracing::warn!("proxy protocol header not supported");
return None;
}
Some(info) => info,
None => ConnectionInfo {
addr: peer_addr,
extra: None,
},
};
let has_private_peer_addr = match conn_info.addr.ip() {
IpAddr::V4(ip) => ip.is_private(),
IpAddr::V6(_) => false,
};
info!(?session_id, %peer_addr, "accepted new TCP connection");
info!(?session_id, %conn_info, "accepted new TCP connection");
// try upgrade to TLS, but with a timeout.
let conn = match timeout(config.handshake_timeout, tls_acceptor.accept(conn)).await {
Ok(Ok(conn)) => {
info!(?session_id, %peer_addr, "accepted new TLS connection");
info!(?session_id, %conn_info, "accepted new TLS connection");
conn
}
// The handshake failed
@@ -268,7 +283,7 @@ async fn connection_startup(
if !has_private_peer_addr {
Metrics::get().proxy.tls_handshake_failures.inc();
}
warn!(?session_id, %peer_addr, "failed to accept TLS connection: {e:?}");
warn!(?session_id, %conn_info, "failed to accept TLS connection: {e:?}");
return None;
}
// The handshake timed out
@@ -276,12 +291,12 @@ async fn connection_startup(
if !has_private_peer_addr {
Metrics::get().proxy.tls_handshake_failures.inc();
}
warn!(?session_id, %peer_addr, "failed to accept TLS connection: {e:?}");
warn!(?session_id, %conn_info, "failed to accept TLS connection: {e:?}");
return None;
}
};
Some((conn, peer_addr))
Some((conn, conn_info))
}
/// Handles HTTP connection
@@ -297,7 +312,7 @@ async fn connection_handler(
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
cancellation_token: CancellationToken,
conn: AsyncRW,
peer_addr: IpAddr,
conn_info: ConnectionInfo,
session_id: uuid::Uuid,
) {
let session_id = AtomicTake::new(session_id);
@@ -306,6 +321,7 @@ async fn connection_handler(
let http_cancellation_token = CancellationToken::new();
let _cancel_connection = http_cancellation_token.clone().drop_guard();
let conn_info2 = conn_info.clone();
let server = Builder::new(TokioExecutor::new());
let conn = server.serve_connection_with_upgrades(
hyper_util::rt::TokioIo::new(conn),
@@ -340,7 +356,7 @@ async fn connection_handler(
connections.clone(),
cancellation_handler.clone(),
session_id,
peer_addr,
conn_info2.clone(),
http_request_token,
endpoint_rate_limiter.clone(),
)
@@ -365,7 +381,7 @@ async fn connection_handler(
// On cancellation, trigger the HTTP connection handler to shut down.
let res = match select(pin!(cancellation_token.cancelled()), pin!(conn)).await {
Either::Left((_cancelled, mut conn)) => {
tracing::debug!(%peer_addr, "cancelling connection");
tracing::debug!(%conn_info, "cancelling connection");
conn.as_mut().graceful_shutdown();
conn.await
}
@@ -373,8 +389,8 @@ async fn connection_handler(
};
match res {
Ok(()) => tracing::info!(%peer_addr, "HTTP connection closed"),
Err(e) => tracing::warn!(%peer_addr, "HTTP connection error {e}"),
Ok(()) => tracing::info!(%conn_info, "HTTP connection closed"),
Err(e) => tracing::warn!(%conn_info, "HTTP connection error {e}"),
}
}
@@ -386,7 +402,7 @@ async fn request_handler(
ws_connections: TaskTracker,
cancellation_handler: Arc<CancellationHandlerMain>,
session_id: uuid::Uuid,
peer_addr: IpAddr,
conn_info: ConnectionInfo,
// used to cancel in-flight HTTP requests. not used to cancel websockets
http_cancellation_token: CancellationToken,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
@@ -404,7 +420,7 @@ async fn request_handler(
{
let ctx = RequestMonitoring::new(
session_id,
peer_addr,
conn_info,
crate::metrics::Protocol::Ws,
&config.region,
);
@@ -439,7 +455,7 @@ async fn request_handler(
} else if request.uri().path() == "/sql" && *request.method() == Method::POST {
let ctx = RequestMonitoring::new(
session_id,
peer_addr,
conn_info,
crate::metrics::Protocol::Http,
&config.region,
);
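
Both the TCP and serverless listeners now apply the same policy when deciding what to hand to RequestMonitoring: a missing header is fatal when PROXY protocol is required, a present header is fatal when it is rejected, and otherwise either the parsed info or the raw peer address is used. A condensed sketch of that decision (the Supported variant name and the helper itself are assumptions; the diff only shows the Required and Rejected comparisons):

use std::net::SocketAddr;

#[derive(PartialEq, Clone, Copy)]
enum ProxyProtocolV2 {
    Required,
    Rejected,
    Supported, // assumed name for the "accept either" case
}

struct ConnInfo {
    addr: SocketAddr,
    extra: Option<u32>, // stand-in for ConnectionInfoExtra
}

fn resolve(policy: ProxyProtocolV2, parsed: Option<ConnInfo>, peer_addr: SocketAddr) -> Option<ConnInfo> {
    match parsed {
        // header required but missing: drop the connection
        None if policy == ProxyProtocolV2::Required => None,
        // header present but this listener rejects PROXY protocol: drop the connection
        Some(_) if policy == ProxyProtocolV2::Rejected => None,
        Some(info) => Some(info),
        // no header and not required: fall back to the raw TCP peer address
        None => Some(ConnInfo { addr: peer_addr, extra: None }),
    }
}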