mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-15 01:12:56 +00:00
zerocopy parsing
This commit is contained in:
33
Cargo.lock
generated
33
Cargo.lock
generated
@@ -34,7 +34,7 @@ dependencies = [
|
||||
"getrandom 0.2.11",
|
||||
"once_cell",
|
||||
"version_check",
|
||||
"zerocopy",
|
||||
"zerocopy 0.7.31",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -951,7 +951,7 @@ dependencies = [
|
||||
"bitflags 2.4.1",
|
||||
"cexpr",
|
||||
"clang-sys",
|
||||
"itertools 0.10.5",
|
||||
"itertools 0.12.1",
|
||||
"lazy_static",
|
||||
"lazycell",
|
||||
"log",
|
||||
@@ -974,7 +974,7 @@ dependencies = [
|
||||
"bitflags 2.4.1",
|
||||
"cexpr",
|
||||
"clang-sys",
|
||||
"itertools 0.10.5",
|
||||
"itertools 0.12.1",
|
||||
"log",
|
||||
"prettyplease",
|
||||
"proc-macro2",
|
||||
@@ -4314,7 +4314,7 @@ checksum = "0c1318b19085f08681016926435853bbf7858f9c082d0999b80550ff5d9abe15"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"heck 0.5.0",
|
||||
"itertools 0.10.5",
|
||||
"itertools 0.12.1",
|
||||
"log",
|
||||
"multimap",
|
||||
"once_cell",
|
||||
@@ -4334,7 +4334,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e9552f850d5f0964a4e4d0bf306459ac29323ddfbae05e35a7c0d35cb0803cc5"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"itertools 0.10.5",
|
||||
"itertools 0.12.1",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.52",
|
||||
@@ -4457,6 +4457,7 @@ dependencies = [
|
||||
"walkdir",
|
||||
"workspace_hack",
|
||||
"x509-parser",
|
||||
"zerocopy 0.8.8",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -7552,7 +7553,16 @@ version = "0.7.31"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1c4061bedbb353041c12f413700357bec76df2c7e2ca8e4df8bac24c6bf68e3d"
|
||||
dependencies = [
|
||||
"zerocopy-derive",
|
||||
"zerocopy-derive 0.7.31",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zerocopy"
|
||||
version = "0.8.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5a4e33e6dce36f2adba29746927f8e848ba70989fdb61c772773bbdda8b5d6a7"
|
||||
dependencies = [
|
||||
"zerocopy-derive 0.8.8",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -7566,6 +7576,17 @@ dependencies = [
|
||||
"syn 2.0.52",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zerocopy-derive"
|
||||
version = "0.8.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3cd137b4cc21bde6ecce3bbbb3350130872cda0be2c6888874279ea76e17d4c1"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.52",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zeroize"
|
||||
version = "1.7.0"
|
||||
|
||||
@@ -98,6 +98,7 @@ rustls-native-certs.workspace = true
|
||||
x509-parser.workspace = true
|
||||
postgres-protocol.workspace = true
|
||||
redis.workspace = true
|
||||
zerocopy = { version = "0.8", features = ["derive"] }
|
||||
|
||||
# jwt stuff
|
||||
jose-jwa = "0.1.2"
|
||||
|
||||
@@ -11,6 +11,7 @@ use bytes::{Buf, Bytes, BytesMut};
|
||||
use pin_project_lite::pin_project;
|
||||
use strum_macros::FromRepr;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf};
|
||||
use zerocopy::{FromBytes, Immutable, KnownLayout, Unaligned};
|
||||
|
||||
pin_project! {
|
||||
/// A chained [`AsyncRead`] with [`AsyncWrite`] passthrough
|
||||
@@ -87,6 +88,55 @@ pub enum ConnectionInfoExtra {
|
||||
Azure { link_id: u32 },
|
||||
}
|
||||
|
||||
#[derive(FromBytes, KnownLayout, Immutable, Unaligned, Copy, Clone)]
|
||||
#[repr(C)]
|
||||
struct ProxyProtocolV2Header {
|
||||
identifier: [u8; 12],
|
||||
version_and_command: u8,
|
||||
protocol_and_family: u8,
|
||||
len: zerocopy::byteorder::network_endian::U16,
|
||||
}
|
||||
|
||||
#[derive(FromBytes, KnownLayout, Immutable, Unaligned, Copy, Clone)]
|
||||
#[repr(C)]
|
||||
struct ProxyProtocolV2HeaderV4 {
|
||||
src_addr: NetworkEndianIpv4,
|
||||
dst_addr: NetworkEndianIpv4,
|
||||
src_port: zerocopy::byteorder::network_endian::U16,
|
||||
dst_port: zerocopy::byteorder::network_endian::U16,
|
||||
}
|
||||
|
||||
#[derive(FromBytes, KnownLayout, Immutable, Unaligned, Copy, Clone)]
|
||||
#[repr(C)]
|
||||
struct ProxyProtocolV2HeaderV6 {
|
||||
src_addr: NetworkEndianIpv6,
|
||||
dst_addr: NetworkEndianIpv6,
|
||||
src_port: zerocopy::byteorder::network_endian::U16,
|
||||
dst_port: zerocopy::byteorder::network_endian::U16,
|
||||
}
|
||||
|
||||
#[derive(FromBytes, KnownLayout, Immutable, Unaligned, Copy, Clone)]
|
||||
#[repr(transparent)]
|
||||
struct NetworkEndianIpv4(zerocopy::byteorder::network_endian::U32);
|
||||
|
||||
impl NetworkEndianIpv4 {
|
||||
#[inline]
|
||||
fn get(self) -> Ipv4Addr {
|
||||
Ipv4Addr::from_bits(self.0.get())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(FromBytes, KnownLayout, Immutable, Unaligned, Copy, Clone)]
|
||||
#[repr(transparent)]
|
||||
struct NetworkEndianIpv6(zerocopy::byteorder::network_endian::U128);
|
||||
|
||||
impl NetworkEndianIpv6 {
|
||||
#[inline]
|
||||
fn get(self) -> Ipv6Addr {
|
||||
Ipv6Addr::from_bits(self.0.get())
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn read_proxy_protocol<T: AsyncRead + Unpin>(
|
||||
mut read: T,
|
||||
) -> std::io::Result<(ChainRW<T>, Option<ConnectionInfo>)> {
|
||||
@@ -106,14 +156,15 @@ pub(crate) async fn read_proxy_protocol<T: AsyncRead + Unpin>(
|
||||
};
|
||||
}
|
||||
|
||||
let header = buf.split_to(16);
|
||||
let header = buf
|
||||
.try_get::<ProxyProtocolV2Header>()
|
||||
.expect("we have checked the length already, so this should not panic");
|
||||
|
||||
// The next byte (the 13th one) is the protocol version and command.
|
||||
// The highest four bits contains the version. As of this specification, it must
|
||||
// always be sent as \x2 and the receiver must only accept this value.
|
||||
let vc = header[12];
|
||||
let version = vc >> 4;
|
||||
let command = vc & 0b1111;
|
||||
let version = header.version_and_command >> 4;
|
||||
let command = header.version_and_command & 0b1111;
|
||||
if version != 2 {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
@@ -144,18 +195,13 @@ pub(crate) async fn read_proxy_protocol<T: AsyncRead + Unpin>(
|
||||
|
||||
// The 14th byte contains the transport protocol and address family. The highest 4
|
||||
// bits contain the address family, the lowest 4 bits contain the protocol.
|
||||
let ft = header[13];
|
||||
let address_length = match ft {
|
||||
// - \x11 : TCP over IPv4 : the forwarded connection uses TCP over the AF_INET
|
||||
// protocol family. Address length is 2*4 + 2*2 = 12 bytes.
|
||||
// - \x12 : UDP over IPv4 : the forwarded connection uses UDP over the AF_INET
|
||||
// protocol family. Address length is 2*4 + 2*2 = 12 bytes.
|
||||
0x11 | 0x12 => 12,
|
||||
// - \x21 : TCP over IPv6 : the forwarded connection uses TCP over the AF_INET6
|
||||
// protocol family. Address length is 2*16 + 2*2 = 36 bytes.
|
||||
// - \x22 : UDP over IPv6 : the forwarded connection uses UDP over the AF_INET6
|
||||
// protocol family. Address length is 2*16 + 2*2 = 36 bytes.
|
||||
0x21 | 0x22 => 36,
|
||||
let address_length = match header.protocol_and_family {
|
||||
// - \x11 : TCP over IPv4 : the forwarded connection uses TCP over the AF_INET protocol family.
|
||||
// - \x12 : UDP over IPv4 : the forwarded connection uses UDP over the AF_INET protocol family.
|
||||
0x11 | 0x12 => size_of::<ProxyProtocolV2HeaderV4>(),
|
||||
// - \x21 : TCP over IPv6 : the forwarded connection uses TCP over the AF_INET6 protocol family.
|
||||
// - \x22 : UDP over IPv6 : the forwarded connection uses UDP over the AF_INET6 protocol family.
|
||||
0x21 | 0x22 => size_of::<ProxyProtocolV2HeaderV6>(),
|
||||
// unspecified or unix stream. ignore the addresses
|
||||
_ => 0,
|
||||
};
|
||||
@@ -169,16 +215,15 @@ pub(crate) async fn read_proxy_protocol<T: AsyncRead + Unpin>(
|
||||
// of bytes and must not assume zero is presented for LOCAL connections. When a
|
||||
// receiver accepts an incoming connection showing an UNSPEC address family or
|
||||
// protocol, it may or may not decide to log the address information if present.
|
||||
let remaining_length = u16::from_be_bytes(header[14..16].try_into().unwrap());
|
||||
let remaining_length = usize::from(header.len.get());
|
||||
if remaining_length < address_length {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"invalid proxy protocol length. not enough to fit requested IP addresses",
|
||||
));
|
||||
}
|
||||
drop(header);
|
||||
|
||||
while buf.len() < remaining_length as usize {
|
||||
while buf.len() < remaining_length {
|
||||
if read.read_buf(&mut buf).await? == 0 {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::UnexpectedEof,
|
||||
@@ -193,22 +238,21 @@ pub(crate) async fn read_proxy_protocol<T: AsyncRead + Unpin>(
|
||||
// - destination layer 3 address in network byte order
|
||||
// - source layer 4 address if any, in network byte order (port)
|
||||
// - destination layer 4 address if any, in network byte order (port)
|
||||
let mut header = buf.split_to(usize::from(remaining_length));
|
||||
let mut addr = header.split_to(usize::from(address_length));
|
||||
let socket = match addr.len() {
|
||||
let mut header = buf.split_to(remaining_length);
|
||||
let socket = match address_length {
|
||||
12 => {
|
||||
let src_addr = Ipv4Addr::from_bits(addr.get_u32());
|
||||
let _dst_addr = Ipv4Addr::from_bits(addr.get_u32());
|
||||
let src_port = addr.get_u16();
|
||||
let _dst_port = addr.get_u16();
|
||||
Some(SocketAddr::from((src_addr, src_port)))
|
||||
let addr = header
|
||||
.try_get::<ProxyProtocolV2HeaderV4>()
|
||||
.expect("we have verified that 12 bytes are in the buf");
|
||||
|
||||
Some(SocketAddr::from((addr.src_addr.get(), addr.src_port.get())))
|
||||
}
|
||||
36 => {
|
||||
let src_addr = Ipv6Addr::from_bits(addr.get_u128());
|
||||
let _dst_addr = Ipv6Addr::from_bits(addr.get_u128());
|
||||
let src_port = addr.get_u16();
|
||||
let _dst_port = addr.get_u16();
|
||||
Some(SocketAddr::from((src_addr, src_port)))
|
||||
let addr = header
|
||||
.try_get::<ProxyProtocolV2HeaderV6>()
|
||||
.expect("we have verified that 36 bytes are in the buf");
|
||||
|
||||
Some(SocketAddr::from((addr.src_addr.get(), addr.src_port.get())))
|
||||
}
|
||||
_ => None,
|
||||
};
|
||||
@@ -330,6 +374,13 @@ impl<T: AsyncRead> ChainRW<T> {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(FromBytes, KnownLayout, Immutable, Unaligned, Copy, Clone)]
|
||||
#[repr(C)]
|
||||
struct TlvHeader {
|
||||
kind: u8,
|
||||
len: zerocopy::byteorder::network_endian::U16,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Tlv {
|
||||
kind: u8,
|
||||
@@ -337,16 +388,44 @@ struct Tlv {
|
||||
}
|
||||
|
||||
fn read_tlv(b: &mut BytesMut) -> Option<Tlv> {
|
||||
if b.len() < 3 {
|
||||
return None;
|
||||
}
|
||||
let kind = b.get_u8();
|
||||
let len = usize::from(b.get_u16());
|
||||
let tlv_header = b.try_get::<TlvHeader>().ok()?;
|
||||
let len = usize::from(tlv_header.len.get());
|
||||
if b.len() < len {
|
||||
return None;
|
||||
}
|
||||
let value = b.split_to(len).freeze();
|
||||
Some(Tlv { kind, value })
|
||||
Some(Tlv {
|
||||
kind: tlv_header.kind,
|
||||
value,
|
||||
})
|
||||
}
|
||||
|
||||
trait BufExt: Sized {
|
||||
fn try_get<T: zerocopy::FromBytes>(&mut self)
|
||||
-> Result<T, zerocopy::error::SizeError<Self, T>>;
|
||||
|
||||
// fn peek<T: zerocopy::FromBytes + zerocopy::KnownLayout + zerocopy::Immutable>(
|
||||
// &self,
|
||||
// ) -> Option<&T>;
|
||||
}
|
||||
|
||||
impl BufExt for BytesMut {
|
||||
fn try_get<T: zerocopy::FromBytes>(
|
||||
&mut self,
|
||||
) -> Result<T, zerocopy::error::SizeError<Self, T>> {
|
||||
let len = size_of::<T>();
|
||||
// this will error in the read_from_bytes if the buf is too small
|
||||
let len = usize::min(len, self.len());
|
||||
let buf = self.split_to(len);
|
||||
|
||||
T::read_from_bytes(&buf).map_err(|e| e.map_src(|_| buf.clone()))
|
||||
}
|
||||
|
||||
// fn peek<T: zerocopy::FromBytes + zerocopy::KnownLayout + zerocopy::Immutable>(
|
||||
// &self,
|
||||
// ) -> Option<&T> {
|
||||
// T::ref_from_prefix(self).ok().map(|(t, _)| t)
|
||||
// }
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
Reference in New Issue
Block a user