From bea8269a175ffcb7b96511f55a46258d20768ebb Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Mon, 4 Nov 2024 13:07:27 +0000 Subject: [PATCH] zerocopy parsing --- Cargo.lock | 33 +++++++-- proxy/Cargo.toml | 1 + proxy/src/protocol2.rs | 155 +++++++++++++++++++++++++++++++---------- 3 files changed, 145 insertions(+), 44 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 51cb80d985..04183df6d6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -34,7 +34,7 @@ dependencies = [ "getrandom 0.2.11", "once_cell", "version_check", - "zerocopy", + "zerocopy 0.7.31", ] [[package]] @@ -951,7 +951,7 @@ dependencies = [ "bitflags 2.4.1", "cexpr", "clang-sys", - "itertools 0.10.5", + "itertools 0.12.1", "lazy_static", "lazycell", "log", @@ -974,7 +974,7 @@ dependencies = [ "bitflags 2.4.1", "cexpr", "clang-sys", - "itertools 0.10.5", + "itertools 0.12.1", "log", "prettyplease", "proc-macro2", @@ -4314,7 +4314,7 @@ checksum = "0c1318b19085f08681016926435853bbf7858f9c082d0999b80550ff5d9abe15" dependencies = [ "bytes", "heck 0.5.0", - "itertools 0.10.5", + "itertools 0.12.1", "log", "multimap", "once_cell", @@ -4334,7 +4334,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e9552f850d5f0964a4e4d0bf306459ac29323ddfbae05e35a7c0d35cb0803cc5" dependencies = [ "anyhow", - "itertools 0.10.5", + "itertools 0.12.1", "proc-macro2", "quote", "syn 2.0.52", @@ -4457,6 +4457,7 @@ dependencies = [ "walkdir", "workspace_hack", "x509-parser", + "zerocopy 0.8.8", ] [[package]] @@ -7552,7 +7553,16 @@ version = "0.7.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1c4061bedbb353041c12f413700357bec76df2c7e2ca8e4df8bac24c6bf68e3d" dependencies = [ - "zerocopy-derive", + "zerocopy-derive 0.7.31", +] + +[[package]] +name = "zerocopy" +version = "0.8.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a4e33e6dce36f2adba29746927f8e848ba70989fdb61c772773bbdda8b5d6a7" +dependencies = [ + "zerocopy-derive 0.8.8", ] [[package]] @@ -7566,6 +7576,17 @@ dependencies = [ "syn 2.0.52", ] +[[package]] +name = "zerocopy-derive" +version = "0.8.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3cd137b4cc21bde6ecce3bbbb3350130872cda0be2c6888874279ea76e17d4c1" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.52", +] + [[package]] name = "zeroize" version = "1.7.0" diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index 2580b1cf8a..917056099b 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -98,6 +98,7 @@ rustls-native-certs.workspace = true x509-parser.workspace = true postgres-protocol.workspace = true redis.workspace = true +zerocopy = { version = "0.8", features = ["derive"] } # jwt stuff jose-jwa = "0.1.2" diff --git a/proxy/src/protocol2.rs b/proxy/src/protocol2.rs index d1084ca2ff..6269c1970c 100644 --- a/proxy/src/protocol2.rs +++ b/proxy/src/protocol2.rs @@ -11,6 +11,7 @@ use bytes::{Buf, Bytes, BytesMut}; use pin_project_lite::pin_project; use strum_macros::FromRepr; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf}; +use zerocopy::{FromBytes, Immutable, KnownLayout, Unaligned}; pin_project! { /// A chained [`AsyncRead`] with [`AsyncWrite`] passthrough @@ -87,6 +88,55 @@ pub enum ConnectionInfoExtra { Azure { link_id: u32 }, } +#[derive(FromBytes, KnownLayout, Immutable, Unaligned, Copy, Clone)] +#[repr(C)] +struct ProxyProtocolV2Header { + identifier: [u8; 12], + version_and_command: u8, + protocol_and_family: u8, + len: zerocopy::byteorder::network_endian::U16, +} + +#[derive(FromBytes, KnownLayout, Immutable, Unaligned, Copy, Clone)] +#[repr(C)] +struct ProxyProtocolV2HeaderV4 { + src_addr: NetworkEndianIpv4, + dst_addr: NetworkEndianIpv4, + src_port: zerocopy::byteorder::network_endian::U16, + dst_port: zerocopy::byteorder::network_endian::U16, +} + +#[derive(FromBytes, KnownLayout, Immutable, Unaligned, Copy, Clone)] +#[repr(C)] +struct ProxyProtocolV2HeaderV6 { + src_addr: NetworkEndianIpv6, + dst_addr: NetworkEndianIpv6, + src_port: zerocopy::byteorder::network_endian::U16, + dst_port: zerocopy::byteorder::network_endian::U16, +} + +#[derive(FromBytes, KnownLayout, Immutable, Unaligned, Copy, Clone)] +#[repr(transparent)] +struct NetworkEndianIpv4(zerocopy::byteorder::network_endian::U32); + +impl NetworkEndianIpv4 { + #[inline] + fn get(self) -> Ipv4Addr { + Ipv4Addr::from_bits(self.0.get()) + } +} + +#[derive(FromBytes, KnownLayout, Immutable, Unaligned, Copy, Clone)] +#[repr(transparent)] +struct NetworkEndianIpv6(zerocopy::byteorder::network_endian::U128); + +impl NetworkEndianIpv6 { + #[inline] + fn get(self) -> Ipv6Addr { + Ipv6Addr::from_bits(self.0.get()) + } +} + pub(crate) async fn read_proxy_protocol( mut read: T, ) -> std::io::Result<(ChainRW, Option)> { @@ -106,14 +156,15 @@ pub(crate) async fn read_proxy_protocol( }; } - let header = buf.split_to(16); + let header = buf + .try_get::() + .expect("we have checked the length already, so this should not panic"); // The next byte (the 13th one) is the protocol version and command. // The highest four bits contains the version. As of this specification, it must // always be sent as \x2 and the receiver must only accept this value. - let vc = header[12]; - let version = vc >> 4; - let command = vc & 0b1111; + let version = header.version_and_command >> 4; + let command = header.version_and_command & 0b1111; if version != 2 { return Err(io::Error::new( io::ErrorKind::Other, @@ -144,18 +195,13 @@ pub(crate) async fn read_proxy_protocol( // The 14th byte contains the transport protocol and address family. The highest 4 // bits contain the address family, the lowest 4 bits contain the protocol. - let ft = header[13]; - let address_length = match ft { - // - \x11 : TCP over IPv4 : the forwarded connection uses TCP over the AF_INET - // protocol family. Address length is 2*4 + 2*2 = 12 bytes. - // - \x12 : UDP over IPv4 : the forwarded connection uses UDP over the AF_INET - // protocol family. Address length is 2*4 + 2*2 = 12 bytes. - 0x11 | 0x12 => 12, - // - \x21 : TCP over IPv6 : the forwarded connection uses TCP over the AF_INET6 - // protocol family. Address length is 2*16 + 2*2 = 36 bytes. - // - \x22 : UDP over IPv6 : the forwarded connection uses UDP over the AF_INET6 - // protocol family. Address length is 2*16 + 2*2 = 36 bytes. - 0x21 | 0x22 => 36, + let address_length = match header.protocol_and_family { + // - \x11 : TCP over IPv4 : the forwarded connection uses TCP over the AF_INET protocol family. + // - \x12 : UDP over IPv4 : the forwarded connection uses UDP over the AF_INET protocol family. + 0x11 | 0x12 => size_of::(), + // - \x21 : TCP over IPv6 : the forwarded connection uses TCP over the AF_INET6 protocol family. + // - \x22 : UDP over IPv6 : the forwarded connection uses UDP over the AF_INET6 protocol family. + 0x21 | 0x22 => size_of::(), // unspecified or unix stream. ignore the addresses _ => 0, }; @@ -169,16 +215,15 @@ pub(crate) async fn read_proxy_protocol( // of bytes and must not assume zero is presented for LOCAL connections. When a // receiver accepts an incoming connection showing an UNSPEC address family or // protocol, it may or may not decide to log the address information if present. - let remaining_length = u16::from_be_bytes(header[14..16].try_into().unwrap()); + let remaining_length = usize::from(header.len.get()); if remaining_length < address_length { return Err(io::Error::new( io::ErrorKind::Other, "invalid proxy protocol length. not enough to fit requested IP addresses", )); } - drop(header); - while buf.len() < remaining_length as usize { + while buf.len() < remaining_length { if read.read_buf(&mut buf).await? == 0 { return Err(io::Error::new( io::ErrorKind::UnexpectedEof, @@ -193,22 +238,21 @@ pub(crate) async fn read_proxy_protocol( // - destination layer 3 address in network byte order // - source layer 4 address if any, in network byte order (port) // - destination layer 4 address if any, in network byte order (port) - let mut header = buf.split_to(usize::from(remaining_length)); - let mut addr = header.split_to(usize::from(address_length)); - let socket = match addr.len() { + let mut header = buf.split_to(remaining_length); + let socket = match address_length { 12 => { - let src_addr = Ipv4Addr::from_bits(addr.get_u32()); - let _dst_addr = Ipv4Addr::from_bits(addr.get_u32()); - let src_port = addr.get_u16(); - let _dst_port = addr.get_u16(); - Some(SocketAddr::from((src_addr, src_port))) + let addr = header + .try_get::() + .expect("we have verified that 12 bytes are in the buf"); + + Some(SocketAddr::from((addr.src_addr.get(), addr.src_port.get()))) } 36 => { - let src_addr = Ipv6Addr::from_bits(addr.get_u128()); - let _dst_addr = Ipv6Addr::from_bits(addr.get_u128()); - let src_port = addr.get_u16(); - let _dst_port = addr.get_u16(); - Some(SocketAddr::from((src_addr, src_port))) + let addr = header + .try_get::() + .expect("we have verified that 36 bytes are in the buf"); + + Some(SocketAddr::from((addr.src_addr.get(), addr.src_port.get()))) } _ => None, }; @@ -330,6 +374,13 @@ impl ChainRW { } } +#[derive(FromBytes, KnownLayout, Immutable, Unaligned, Copy, Clone)] +#[repr(C)] +struct TlvHeader { + kind: u8, + len: zerocopy::byteorder::network_endian::U16, +} + #[derive(Debug)] struct Tlv { kind: u8, @@ -337,16 +388,44 @@ struct Tlv { } fn read_tlv(b: &mut BytesMut) -> Option { - if b.len() < 3 { - return None; - } - let kind = b.get_u8(); - let len = usize::from(b.get_u16()); + let tlv_header = b.try_get::().ok()?; + let len = usize::from(tlv_header.len.get()); if b.len() < len { return None; } let value = b.split_to(len).freeze(); - Some(Tlv { kind, value }) + Some(Tlv { + kind: tlv_header.kind, + value, + }) +} + +trait BufExt: Sized { + fn try_get(&mut self) + -> Result>; + + // fn peek( + // &self, + // ) -> Option<&T>; +} + +impl BufExt for BytesMut { + fn try_get( + &mut self, + ) -> Result> { + let len = size_of::(); + // this will error in the read_from_bytes if the buf is too small + let len = usize::min(len, self.len()); + let buf = self.split_to(len); + + T::read_from_bytes(&buf).map_err(|e| e.map_src(|_| buf.clone())) + } + + // fn peek( + // &self, + // ) -> Option<&T> { + // T::ref_from_prefix(self).ok().map(|(t, _)| t) + // } } #[cfg(test)]