mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-15 17:32:56 +00:00
[proxy]: parse proxy protocol TLVs with aws/azure support
This commit is contained in:
2
Cargo.lock
generated
2
Cargo.lock
generated
@@ -4433,6 +4433,8 @@ dependencies = [
|
||||
"smallvec",
|
||||
"smol_str",
|
||||
"socket2",
|
||||
"strum",
|
||||
"strum_macros",
|
||||
"subtle",
|
||||
"thiserror",
|
||||
"tikv-jemalloc-ctl",
|
||||
|
||||
@@ -74,6 +74,8 @@ sha2 = { workspace = true, features = ["asm", "oid"] }
|
||||
smol_str.workspace = true
|
||||
smallvec.workspace = true
|
||||
socket2.workspace = true
|
||||
strum.workspace = true
|
||||
strum_macros.workspace = true
|
||||
subtle.workspace = true
|
||||
thiserror.workspace = true
|
||||
tikv-jemallocator.workspace = true
|
||||
|
||||
@@ -13,6 +13,7 @@ use itertools::Itertools;
|
||||
use proxy::config::TlsServerEndPoint;
|
||||
use proxy::context::RequestMonitoring;
|
||||
use proxy::metrics::{Metrics, ThreadPoolMetrics};
|
||||
use proxy::protocol2::ConnectionInfo;
|
||||
use proxy::proxy::{copy_bidirectional_client_compute, run_until_cancelled, ErrorSource};
|
||||
use proxy::stream::{PqStream, Stream};
|
||||
use rustls::crypto::aws_lc_rs;
|
||||
@@ -179,7 +180,10 @@ async fn task_main(
|
||||
info!(%peer_addr, "serving");
|
||||
let ctx = RequestMonitoring::new(
|
||||
session_id,
|
||||
peer_addr.ip(),
|
||||
ConnectionInfo {
|
||||
addr: peer_addr,
|
||||
extra: None,
|
||||
},
|
||||
proxy::metrics::Protocol::SniRouter,
|
||||
"sni",
|
||||
);
|
||||
|
||||
@@ -11,7 +11,7 @@ use crate::config::{ProxyConfig, ProxyProtocolV2};
|
||||
use crate::context::RequestMonitoring;
|
||||
use crate::error::ReportableError;
|
||||
use crate::metrics::{Metrics, NumClientConnectionsGuard};
|
||||
use crate::protocol2::read_proxy_protocol;
|
||||
use crate::protocol2::{read_proxy_protocol, ConnectionInfo};
|
||||
use crate::proxy::connect_compute::{connect_to_compute, TcpMechanism};
|
||||
use crate::proxy::handshake::{handshake, HandshakeData};
|
||||
use crate::proxy::passthrough::ProxyPassthrough;
|
||||
@@ -65,8 +65,8 @@ pub async fn task_main(
|
||||
error!("proxy protocol header not supported");
|
||||
return;
|
||||
}
|
||||
Ok((socket, Some(addr))) => (socket, addr.ip()),
|
||||
Ok((socket, None)) => (socket, peer_addr.ip()),
|
||||
Ok((socket, Some(info))) => (socket, info),
|
||||
Ok((socket, None)) => (socket, ConnectionInfo{ addr: peer_addr, extra: None }),
|
||||
};
|
||||
|
||||
match socket.inner.set_nodelay(true) {
|
||||
|
||||
@@ -19,6 +19,7 @@ use crate::intern::{BranchIdInt, ProjectIdInt};
|
||||
use crate::metrics::{
|
||||
ConnectOutcome, InvalidEndpointsGroup, LatencyTimer, Metrics, Protocol, Waiting,
|
||||
};
|
||||
use crate::protocol2::ConnectionInfo;
|
||||
use crate::types::{DbName, EndpointId, RoleName};
|
||||
|
||||
pub mod parquet;
|
||||
@@ -40,7 +41,7 @@ pub struct RequestMonitoring(
|
||||
);
|
||||
|
||||
struct RequestMonitoringInner {
|
||||
pub(crate) peer_addr: IpAddr,
|
||||
pub(crate) conn_info: ConnectionInfo,
|
||||
pub(crate) session_id: Uuid,
|
||||
pub(crate) protocol: Protocol,
|
||||
first_packet: chrono::DateTime<Utc>,
|
||||
@@ -84,7 +85,7 @@ impl Clone for RequestMonitoring {
|
||||
fn clone(&self) -> Self {
|
||||
let inner = self.0.try_lock().expect("should not deadlock");
|
||||
let new = RequestMonitoringInner {
|
||||
peer_addr: inner.peer_addr,
|
||||
conn_info: inner.conn_info.clone(),
|
||||
session_id: inner.session_id,
|
||||
protocol: inner.protocol,
|
||||
first_packet: inner.first_packet,
|
||||
@@ -117,7 +118,7 @@ impl Clone for RequestMonitoring {
|
||||
impl RequestMonitoring {
|
||||
pub fn new(
|
||||
session_id: Uuid,
|
||||
peer_addr: IpAddr,
|
||||
conn_info: ConnectionInfo,
|
||||
protocol: Protocol,
|
||||
region: &'static str,
|
||||
) -> Self {
|
||||
@@ -125,13 +126,13 @@ impl RequestMonitoring {
|
||||
"connect_request",
|
||||
%protocol,
|
||||
?session_id,
|
||||
%peer_addr,
|
||||
%conn_info,
|
||||
ep = tracing::field::Empty,
|
||||
role = tracing::field::Empty,
|
||||
);
|
||||
|
||||
let inner = RequestMonitoringInner {
|
||||
peer_addr,
|
||||
conn_info,
|
||||
session_id,
|
||||
protocol,
|
||||
first_packet: Utc::now(),
|
||||
@@ -162,7 +163,11 @@ impl RequestMonitoring {
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn test() -> Self {
|
||||
RequestMonitoring::new(Uuid::now_v7(), [127, 0, 0, 1].into(), Protocol::Tcp, "test")
|
||||
use std::net::SocketAddr;
|
||||
let ip = IpAddr::from([127, 0, 0, 1]);
|
||||
let addr = SocketAddr::new(ip, 5432);
|
||||
let conn_info = ConnectionInfo { addr, extra: None };
|
||||
RequestMonitoring::new(Uuid::now_v7(), conn_info, Protocol::Tcp, "test")
|
||||
}
|
||||
|
||||
pub(crate) fn console_application_name(&self) -> String {
|
||||
@@ -286,7 +291,12 @@ impl RequestMonitoring {
|
||||
}
|
||||
|
||||
pub(crate) fn peer_addr(&self) -> IpAddr {
|
||||
self.0.try_lock().expect("should not deadlock").peer_addr
|
||||
self.0
|
||||
.try_lock()
|
||||
.expect("should not deadlock")
|
||||
.conn_info
|
||||
.addr
|
||||
.ip()
|
||||
}
|
||||
|
||||
pub(crate) fn cold_start_info(&self) -> ColdStartInfo {
|
||||
@@ -362,7 +372,7 @@ impl RequestMonitoringInner {
|
||||
}
|
||||
|
||||
fn has_private_peer_addr(&self) -> bool {
|
||||
match self.peer_addr {
|
||||
match self.conn_info.addr.ip() {
|
||||
IpAddr::V4(ip) => ip.is_private(),
|
||||
IpAddr::V6(_) => false,
|
||||
}
|
||||
|
||||
@@ -121,7 +121,7 @@ impl From<&RequestMonitoringInner> for RequestData {
|
||||
fn from(value: &RequestMonitoringInner) -> Self {
|
||||
Self {
|
||||
session_id: value.session_id,
|
||||
peer_addr: value.peer_addr.to_string(),
|
||||
peer_addr: value.conn_info.addr.ip().to_string(),
|
||||
timestamp: value.first_packet.naive_utc(),
|
||||
username: value.user.as_deref().map(String::from),
|
||||
application_name: value.application.as_deref().map(String::from),
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
//! Proxy Protocol V2 implementation
|
||||
//! Compatible with <https://www.haproxy.org/download/3.1/doc/proxy-protocol.txt>
|
||||
|
||||
use core::fmt;
|
||||
use std::io;
|
||||
use std::net::SocketAddr;
|
||||
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
|
||||
use bytes::BytesMut;
|
||||
use bytes::{Buf, Bytes, BytesMut};
|
||||
use pin_project_lite::pin_project;
|
||||
use strum_macros::FromRepr;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf};
|
||||
|
||||
pin_project! {
|
||||
@@ -58,9 +61,35 @@ const HEADER: [u8; 12] = [
|
||||
0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A,
|
||||
];
|
||||
|
||||
#[derive(PartialEq, Eq, Clone, Debug)]
|
||||
pub struct ConnectionInfo {
|
||||
pub addr: SocketAddr,
|
||||
pub extra: Option<ConnectionInfoExtra>,
|
||||
}
|
||||
|
||||
impl fmt::Display for ConnectionInfo {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match &self.extra {
|
||||
None => self.addr.ip().fmt(f),
|
||||
Some(ConnectionInfoExtra::Aws { vpce_id }) => {
|
||||
write!(f, "vpce_id[{vpce_id:?}]:addr[{}]", self.addr.ip())
|
||||
}
|
||||
Some(ConnectionInfoExtra::Azure { link_id }) => {
|
||||
write!(f, "link_id[{link_id}]:addr[{}]", self.addr.ip())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Eq, Clone, Debug)]
|
||||
pub enum ConnectionInfoExtra {
|
||||
Aws { vpce_id: Bytes },
|
||||
Azure { link_id: u32 },
|
||||
}
|
||||
|
||||
pub(crate) async fn read_proxy_protocol<T: AsyncRead + Unpin>(
|
||||
mut read: T,
|
||||
) -> std::io::Result<(ChainRW<T>, Option<SocketAddr>)> {
|
||||
) -> std::io::Result<(ChainRW<T>, Option<ConnectionInfo>)> {
|
||||
let mut buf = BytesMut::with_capacity(128);
|
||||
while buf.len() < 16 {
|
||||
let bytes_read = read.read_buf(&mut buf).await?;
|
||||
@@ -164,22 +193,107 @@ pub(crate) async fn read_proxy_protocol<T: AsyncRead + Unpin>(
|
||||
// - destination layer 3 address in network byte order
|
||||
// - source layer 4 address if any, in network byte order (port)
|
||||
// - destination layer 4 address if any, in network byte order (port)
|
||||
let addresses = buf.split_to(remaining_length as usize);
|
||||
let socket = match address_length {
|
||||
let mut header = buf.split_to(usize::from(remaining_length));
|
||||
let mut addr = header.split_to(usize::from(address_length));
|
||||
let socket = match addr.len() {
|
||||
12 => {
|
||||
let src_addr: [u8; 4] = addresses[0..4].try_into().unwrap();
|
||||
let src_port = u16::from_be_bytes(addresses[8..10].try_into().unwrap());
|
||||
let src_addr = Ipv4Addr::from_bits(addr.get_u32());
|
||||
let _dst_addr = addr.get_u32();
|
||||
let src_port = addr.get_u16();
|
||||
let _dst_port = addr.get_u16();
|
||||
Some(SocketAddr::from((src_addr, src_port)))
|
||||
}
|
||||
36 => {
|
||||
let src_addr: [u8; 16] = addresses[0..16].try_into().unwrap();
|
||||
let src_port = u16::from_be_bytes(addresses[32..34].try_into().unwrap());
|
||||
let src_addr = Ipv6Addr::from_bits(addr.get_u128());
|
||||
let _dst_addr = addr.get_u128();
|
||||
let src_port = addr.get_u16();
|
||||
let _dst_port = addr.get_u16();
|
||||
Some(SocketAddr::from((src_addr, src_port)))
|
||||
}
|
||||
_ => None,
|
||||
};
|
||||
|
||||
Ok((ChainRW { inner: read, buf }, socket))
|
||||
let mut extra = None;
|
||||
|
||||
while let Some(mut tlv) = read_tlv(&mut header) {
|
||||
match Pp2Kind::from_repr(tlv.kind) {
|
||||
Some(Pp2Kind::Aws) => {
|
||||
if tlv.value.is_empty() {
|
||||
tracing::warn!("invalid aws tlv: no subtype");
|
||||
}
|
||||
let subtype = tlv.value.get_u8();
|
||||
match Pp2AwsType::from_repr(subtype) {
|
||||
Some(Pp2AwsType::VpceId) => {
|
||||
extra = Some(ConnectionInfoExtra::Aws { vpce_id: tlv.value });
|
||||
}
|
||||
None => {
|
||||
tracing::warn!("unknown aws tlv: subtype={subtype}");
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(Pp2Kind::Azure) => {
|
||||
if tlv.value.is_empty() {
|
||||
tracing::warn!("invalid azure tlv: no subtype");
|
||||
}
|
||||
let subtype = tlv.value.get_u8();
|
||||
match Pp2AzureType::from_repr(subtype) {
|
||||
Some(Pp2AzureType::PrivateEndpointLinkId) => {
|
||||
if tlv.value.len() != 4 {
|
||||
tracing::warn!("invalid azure link_id: {:?}", tlv.value);
|
||||
}
|
||||
extra = Some(ConnectionInfoExtra::Azure {
|
||||
link_id: tlv.value.get_u32_le(),
|
||||
});
|
||||
}
|
||||
None => {
|
||||
tracing::warn!("unknown azure tlv: subtype={subtype}");
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(kind) => {
|
||||
tracing::debug!("unused tlv[{kind:?}]: {:?}", tlv.value);
|
||||
}
|
||||
None => {
|
||||
tracing::debug!("unknown tlv: {tlv:?}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let conn_info = socket.map(|addr| ConnectionInfo { addr, extra });
|
||||
|
||||
Ok((ChainRW { inner: read, buf }, conn_info))
|
||||
}
|
||||
|
||||
#[derive(FromRepr, Debug, Copy, Clone)]
|
||||
#[repr(u8)]
|
||||
enum Pp2Kind {
|
||||
// The following are defined by https://www.haproxy.org/download/3.1/doc/proxy-protocol.txt
|
||||
// we don't use these but it would be interesting to know what's available
|
||||
Alpn = 0x01,
|
||||
Authority = 0x02,
|
||||
Crc32C = 0x03,
|
||||
Noop = 0x04,
|
||||
UniqueId = 0x05,
|
||||
Ssl = 0x20,
|
||||
NetNs = 0x30,
|
||||
|
||||
/// <https://docs.aws.amazon.com/elasticloadbalancing/latest/network/edit-target-group-attributes.html#proxy-protocol>
|
||||
Aws = 0xEA,
|
||||
|
||||
/// <https://learn.microsoft.com/en-us/azure/private-link/private-link-service-overview#getting-connection-information-using-tcp-proxy-v2>
|
||||
Azure = 0xEE,
|
||||
}
|
||||
|
||||
#[derive(FromRepr, Debug, Copy, Clone)]
|
||||
#[repr(u8)]
|
||||
enum Pp2AwsType {
|
||||
VpceId = 0x01,
|
||||
}
|
||||
|
||||
#[derive(FromRepr, Debug, Copy, Clone)]
|
||||
#[repr(u8)]
|
||||
enum Pp2AzureType {
|
||||
PrivateEndpointLinkId = 0x01,
|
||||
}
|
||||
|
||||
impl<T: AsyncRead> AsyncRead for ChainRW<T> {
|
||||
@@ -216,6 +330,25 @@ impl<T: AsyncRead> ChainRW<T> {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Tlv {
|
||||
kind: u8,
|
||||
value: Bytes,
|
||||
}
|
||||
|
||||
fn read_tlv(b: &mut BytesMut) -> Option<Tlv> {
|
||||
if b.len() < 3 {
|
||||
return None;
|
||||
}
|
||||
let kind = b.get_u8();
|
||||
let len = usize::from(b.get_u16());
|
||||
if b.len() < len {
|
||||
return None;
|
||||
}
|
||||
let value = b.split_to(len).freeze();
|
||||
Some(Tlv { kind, value })
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use tokio::io::AsyncReadExt;
|
||||
@@ -242,7 +375,7 @@ mod tests {
|
||||
|
||||
let extra_data = [0x55; 256];
|
||||
|
||||
let (mut read, addr) = read_proxy_protocol(header.chain(extra_data.as_slice()))
|
||||
let (mut read, info) = read_proxy_protocol(header.chain(extra_data.as_slice()))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -250,7 +383,9 @@ mod tests {
|
||||
read.read_to_end(&mut bytes).await.unwrap();
|
||||
|
||||
assert_eq!(bytes, extra_data);
|
||||
assert_eq!(addr, Some(([127, 0, 0, 1], 65535).into()));
|
||||
|
||||
let info = info.unwrap();
|
||||
assert_eq!(info.addr, ([127, 0, 0, 1], 65535).into());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -273,7 +408,7 @@ mod tests {
|
||||
|
||||
let extra_data = [0x55; 256];
|
||||
|
||||
let (mut read, addr) = read_proxy_protocol(header.chain(extra_data.as_slice()))
|
||||
let (mut read, info) = read_proxy_protocol(header.chain(extra_data.as_slice()))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -281,9 +416,11 @@ mod tests {
|
||||
read.read_to_end(&mut bytes).await.unwrap();
|
||||
|
||||
assert_eq!(bytes, extra_data);
|
||||
|
||||
let info = info.unwrap();
|
||||
assert_eq!(
|
||||
addr,
|
||||
Some(([15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0], 257).into())
|
||||
info.addr,
|
||||
([15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0], 257).into()
|
||||
);
|
||||
}
|
||||
|
||||
@@ -291,30 +428,31 @@ mod tests {
|
||||
async fn test_invalid() {
|
||||
let data = [0x55; 256];
|
||||
|
||||
let (mut read, addr) = read_proxy_protocol(data.as_slice()).await.unwrap();
|
||||
let (mut read, info) = read_proxy_protocol(data.as_slice()).await.unwrap();
|
||||
|
||||
let mut bytes = vec![];
|
||||
read.read_to_end(&mut bytes).await.unwrap();
|
||||
assert_eq!(bytes, data);
|
||||
assert_eq!(addr, None);
|
||||
assert_eq!(info, None);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_short() {
|
||||
let data = [0x55; 10];
|
||||
|
||||
let (mut read, addr) = read_proxy_protocol(data.as_slice()).await.unwrap();
|
||||
let (mut read, info) = read_proxy_protocol(data.as_slice()).await.unwrap();
|
||||
|
||||
let mut bytes = vec![];
|
||||
read.read_to_end(&mut bytes).await.unwrap();
|
||||
assert_eq!(bytes, data);
|
||||
assert_eq!(addr, None);
|
||||
assert_eq!(info, None);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_large_tlv() {
|
||||
let tlv = vec![0x55; 32768];
|
||||
let len = (12 + tlv.len() as u16).to_be_bytes();
|
||||
let tlv_len = (tlv.len() as u16).to_be_bytes();
|
||||
let len = (12 + 3 + tlv.len() as u16).to_be_bytes();
|
||||
|
||||
let header = super::HEADER
|
||||
// Proxy command, Inet << 4 | Stream
|
||||
@@ -330,11 +468,13 @@ mod tests {
|
||||
// dst port
|
||||
.chain([1, 1].as_slice())
|
||||
// TLV
|
||||
.chain([255].as_slice())
|
||||
.chain(tlv_len.as_slice())
|
||||
.chain(tlv.as_slice());
|
||||
|
||||
let extra_data = [0xaa; 256];
|
||||
|
||||
let (mut read, addr) = read_proxy_protocol(header.chain(extra_data.as_slice()))
|
||||
let (mut read, info) = read_proxy_protocol(header.chain(extra_data.as_slice()))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -342,6 +482,8 @@ mod tests {
|
||||
read.read_to_end(&mut bytes).await.unwrap();
|
||||
|
||||
assert_eq!(bytes, extra_data);
|
||||
assert_eq!(addr, Some(([55, 56, 57, 58], 65535).into()));
|
||||
|
||||
let info = info.unwrap();
|
||||
assert_eq!(info.addr, ([55, 56, 57, 58], 65535).into());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -28,7 +28,7 @@ use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig};
|
||||
use crate::context::RequestMonitoring;
|
||||
use crate::error::ReportableError;
|
||||
use crate::metrics::{Metrics, NumClientConnectionsGuard};
|
||||
use crate::protocol2::read_proxy_protocol;
|
||||
use crate::protocol2::{read_proxy_protocol, ConnectionInfo};
|
||||
use crate::proxy::handshake::{handshake, HandshakeData};
|
||||
use crate::rate_limiter::EndpointRateLimiter;
|
||||
use crate::stream::{PqStream, Stream};
|
||||
@@ -87,7 +87,7 @@ pub async fn task_main(
|
||||
let endpoint_rate_limiter2 = endpoint_rate_limiter.clone();
|
||||
|
||||
connections.spawn(async move {
|
||||
let (socket, peer_addr) = match read_proxy_protocol(socket).await {
|
||||
let (socket, conn_info) = match read_proxy_protocol(socket).await {
|
||||
Err(e) => {
|
||||
warn!("per-client task finished with an error: {e:#}");
|
||||
return;
|
||||
@@ -100,8 +100,8 @@ pub async fn task_main(
|
||||
warn!("proxy protocol header not supported");
|
||||
return;
|
||||
}
|
||||
Ok((socket, Some(addr))) => (socket, addr.ip()),
|
||||
Ok((socket, None)) => (socket, peer_addr.ip()),
|
||||
Ok((socket, Some(info))) => (socket, info),
|
||||
Ok((socket, None)) => (socket, ConnectionInfo { addr: peer_addr, extra: None }),
|
||||
};
|
||||
|
||||
match socket.inner.set_nodelay(true) {
|
||||
@@ -114,7 +114,7 @@ pub async fn task_main(
|
||||
|
||||
let ctx = RequestMonitoring::new(
|
||||
session_id,
|
||||
peer_addr,
|
||||
conn_info,
|
||||
crate::metrics::Protocol::Tcp,
|
||||
&config.region,
|
||||
);
|
||||
|
||||
@@ -44,10 +44,10 @@ use tracing::{info, warn, Instrument};
|
||||
use utils::http::error::ApiError;
|
||||
|
||||
use crate::cancellation::CancellationHandlerMain;
|
||||
use crate::config::ProxyConfig;
|
||||
use crate::config::{ProxyConfig, ProxyProtocolV2};
|
||||
use crate::context::RequestMonitoring;
|
||||
use crate::metrics::Metrics;
|
||||
use crate::protocol2::{read_proxy_protocol, ChainRW};
|
||||
use crate::protocol2::{read_proxy_protocol, ChainRW, ConnectionInfo};
|
||||
use crate::proxy::run_until_cancelled;
|
||||
use crate::rate_limiter::EndpointRateLimiter;
|
||||
use crate::serverless::backend::PoolingBackend;
|
||||
@@ -180,7 +180,7 @@ pub async fn task_main(
|
||||
peer_addr,
|
||||
))
|
||||
.await;
|
||||
let Some((conn, peer_addr)) = startup_result else {
|
||||
let Some((conn, conn_info)) = startup_result else {
|
||||
return;
|
||||
};
|
||||
|
||||
@@ -192,7 +192,7 @@ pub async fn task_main(
|
||||
endpoint_rate_limiter,
|
||||
conn_token,
|
||||
conn,
|
||||
peer_addr,
|
||||
conn_info,
|
||||
session_id,
|
||||
))
|
||||
.await;
|
||||
@@ -240,7 +240,7 @@ async fn connection_startup(
|
||||
session_id: uuid::Uuid,
|
||||
conn: TcpStream,
|
||||
peer_addr: SocketAddr,
|
||||
) -> Option<(AsyncRW, IpAddr)> {
|
||||
) -> Option<(AsyncRW, ConnectionInfo)> {
|
||||
// handle PROXY protocol
|
||||
let (conn, peer) = match read_proxy_protocol(conn).await {
|
||||
Ok(c) => c,
|
||||
@@ -250,17 +250,32 @@ async fn connection_startup(
|
||||
}
|
||||
};
|
||||
|
||||
let peer_addr = peer.unwrap_or(peer_addr).ip();
|
||||
let has_private_peer_addr = match peer_addr {
|
||||
let conn_info = match peer {
|
||||
None if config.proxy_protocol_v2 == ProxyProtocolV2::Required => {
|
||||
tracing::warn!("missing required proxy protocol header");
|
||||
return None;
|
||||
}
|
||||
Some(_) if config.proxy_protocol_v2 == ProxyProtocolV2::Rejected => {
|
||||
tracing::warn!("proxy protocol header not supported");
|
||||
return None;
|
||||
}
|
||||
Some(info) => info,
|
||||
None => ConnectionInfo {
|
||||
addr: peer_addr,
|
||||
extra: None,
|
||||
},
|
||||
};
|
||||
|
||||
let has_private_peer_addr = match conn_info.addr.ip() {
|
||||
IpAddr::V4(ip) => ip.is_private(),
|
||||
IpAddr::V6(_) => false,
|
||||
};
|
||||
info!(?session_id, %peer_addr, "accepted new TCP connection");
|
||||
info!(?session_id, %conn_info, "accepted new TCP connection");
|
||||
|
||||
// try upgrade to TLS, but with a timeout.
|
||||
let conn = match timeout(config.handshake_timeout, tls_acceptor.accept(conn)).await {
|
||||
Ok(Ok(conn)) => {
|
||||
info!(?session_id, %peer_addr, "accepted new TLS connection");
|
||||
info!(?session_id, %conn_info, "accepted new TLS connection");
|
||||
conn
|
||||
}
|
||||
// The handshake failed
|
||||
@@ -268,7 +283,7 @@ async fn connection_startup(
|
||||
if !has_private_peer_addr {
|
||||
Metrics::get().proxy.tls_handshake_failures.inc();
|
||||
}
|
||||
warn!(?session_id, %peer_addr, "failed to accept TLS connection: {e:?}");
|
||||
warn!(?session_id, %conn_info, "failed to accept TLS connection: {e:?}");
|
||||
return None;
|
||||
}
|
||||
// The handshake timed out
|
||||
@@ -276,12 +291,12 @@ async fn connection_startup(
|
||||
if !has_private_peer_addr {
|
||||
Metrics::get().proxy.tls_handshake_failures.inc();
|
||||
}
|
||||
warn!(?session_id, %peer_addr, "failed to accept TLS connection: {e:?}");
|
||||
warn!(?session_id, %conn_info, "failed to accept TLS connection: {e:?}");
|
||||
return None;
|
||||
}
|
||||
};
|
||||
|
||||
Some((conn, peer_addr))
|
||||
Some((conn, conn_info))
|
||||
}
|
||||
|
||||
/// Handles HTTP connection
|
||||
@@ -297,7 +312,7 @@ async fn connection_handler(
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
cancellation_token: CancellationToken,
|
||||
conn: AsyncRW,
|
||||
peer_addr: IpAddr,
|
||||
conn_info: ConnectionInfo,
|
||||
session_id: uuid::Uuid,
|
||||
) {
|
||||
let session_id = AtomicTake::new(session_id);
|
||||
@@ -306,6 +321,7 @@ async fn connection_handler(
|
||||
let http_cancellation_token = CancellationToken::new();
|
||||
let _cancel_connection = http_cancellation_token.clone().drop_guard();
|
||||
|
||||
let conn_info2 = conn_info.clone();
|
||||
let server = Builder::new(TokioExecutor::new());
|
||||
let conn = server.serve_connection_with_upgrades(
|
||||
hyper_util::rt::TokioIo::new(conn),
|
||||
@@ -340,7 +356,7 @@ async fn connection_handler(
|
||||
connections.clone(),
|
||||
cancellation_handler.clone(),
|
||||
session_id,
|
||||
peer_addr,
|
||||
conn_info2.clone(),
|
||||
http_request_token,
|
||||
endpoint_rate_limiter.clone(),
|
||||
)
|
||||
@@ -365,7 +381,7 @@ async fn connection_handler(
|
||||
// On cancellation, trigger the HTTP connection handler to shut down.
|
||||
let res = match select(pin!(cancellation_token.cancelled()), pin!(conn)).await {
|
||||
Either::Left((_cancelled, mut conn)) => {
|
||||
tracing::debug!(%peer_addr, "cancelling connection");
|
||||
tracing::debug!(%conn_info, "cancelling connection");
|
||||
conn.as_mut().graceful_shutdown();
|
||||
conn.await
|
||||
}
|
||||
@@ -373,8 +389,8 @@ async fn connection_handler(
|
||||
};
|
||||
|
||||
match res {
|
||||
Ok(()) => tracing::info!(%peer_addr, "HTTP connection closed"),
|
||||
Err(e) => tracing::warn!(%peer_addr, "HTTP connection error {e}"),
|
||||
Ok(()) => tracing::info!(%conn_info, "HTTP connection closed"),
|
||||
Err(e) => tracing::warn!(%conn_info, "HTTP connection error {e}"),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -386,7 +402,7 @@ async fn request_handler(
|
||||
ws_connections: TaskTracker,
|
||||
cancellation_handler: Arc<CancellationHandlerMain>,
|
||||
session_id: uuid::Uuid,
|
||||
peer_addr: IpAddr,
|
||||
conn_info: ConnectionInfo,
|
||||
// used to cancel in-flight HTTP requests. not used to cancel websockets
|
||||
http_cancellation_token: CancellationToken,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
@@ -404,7 +420,7 @@ async fn request_handler(
|
||||
{
|
||||
let ctx = RequestMonitoring::new(
|
||||
session_id,
|
||||
peer_addr,
|
||||
conn_info,
|
||||
crate::metrics::Protocol::Ws,
|
||||
&config.region,
|
||||
);
|
||||
@@ -439,7 +455,7 @@ async fn request_handler(
|
||||
} else if request.uri().path() == "/sql" && *request.method() == Method::POST {
|
||||
let ctx = RequestMonitoring::new(
|
||||
session_id,
|
||||
peer_addr,
|
||||
conn_info,
|
||||
crate::metrics::Protocol::Http,
|
||||
&config.region,
|
||||
);
|
||||
|
||||
Reference in New Issue
Block a user