//! Postgres protocol codec //! //! use std::fmt; use std::io::{self, Cursor}; use bytes::{Buf, BufMut}; use itertools::Itertools; use rand::distr::{Distribution, StandardUniform}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use zerocopy::{FromBytes, Immutable, IntoBytes, big_endian}; pub type ErrorCode = [u8; 5]; pub const FE_PASSWORD_MESSAGE: u8 = b'p'; pub const SQLSTATE_INTERNAL_ERROR: [u8; 5] = *b"XX000"; /// The protocol version number. /// /// The most significant 16 bits are the major version number (3 for the protocol described here). /// The least significant 16 bits are the minor version number (0 for the protocol described here). /// #[derive(Clone, Copy, PartialEq, PartialOrd, FromBytes, IntoBytes, Immutable)] #[repr(C)] pub struct ProtocolVersion { major: big_endian::U16, minor: big_endian::U16, } impl ProtocolVersion { pub const fn new(major: u16, minor: u16) -> Self { Self { major: big_endian::U16::new(major), minor: big_endian::U16::new(minor), } } pub const fn minor(self) -> u16 { self.minor.get() } pub const fn major(self) -> u16 { self.major.get() } } impl fmt::Debug for ProtocolVersion { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_list() .entry(&self.major()) .entry(&self.minor()) .finish() } } /// const MAX_STARTUP_PACKET_LENGTH: usize = 10000; const RESERVED_INVALID_MAJOR_VERSION: u16 = 1234; /// const CANCEL_REQUEST_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5678); /// const NEGOTIATE_SSL_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5679); /// const NEGOTIATE_GSS_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5680); /// This first reads the startup message header, is 8 bytes. /// The first 4 bytes is a big-endian message length, and the next 4 bytes is a version number. /// /// The length value is inclusive of the header. For example, /// an empty message will always have length 8. #[derive(Clone, Copy, FromBytes, IntoBytes, Immutable)] #[repr(C)] struct StartupHeader { len: big_endian::U32, version: ProtocolVersion, } /// read the type from the stream using zerocopy. /// /// not cancel safe. macro_rules! read { ($s:expr => $t:ty) => {{ // cannot be implemented as a function due to lack of const-generic-expr let mut buf = [0; size_of::<$t>()]; $s.read_exact(&mut buf).await?; let res: $t = zerocopy::transmute!(buf); res }}; } /// Returns true if TLS is supported. /// /// This is not cancel safe. pub async fn request_tls(stream: &mut S) -> io::Result where S: AsyncRead + AsyncWrite + Unpin, { let payload = StartupHeader { len: 8.into(), version: NEGOTIATE_SSL_CODE, }; stream.write_all(payload.as_bytes()).await?; stream.flush().await?; // we expect back either `S` or `N` as a single byte. let mut res = *b"0"; stream.read_exact(&mut res).await?; debug_assert!( res == *b"S" || res == *b"N", "unexpected SSL negotiation response: {}", char::from(res[0]), ); // S for SSL. Ok(res == *b"S") } pub async fn read_startup(stream: &mut S) -> io::Result where S: AsyncRead + Unpin, { let header = read!(stream => StartupHeader); // // First byte indicates standard SSL handshake message // (It can't be a Postgres startup length because in network byte order // that would be a startup packet hundreds of megabytes long) if header.as_bytes()[0] == 0x16 { return Ok(FeStartupPacket::SslRequest { // The bytes we read for the header are actually part of a TLS ClientHello. // In theory, if the ClientHello was < 8 bytes we would fail with EOF before we get here. // In practice though, I see no world where a ClientHello is less than 8 bytes // since it includes ephemeral keys etc. direct: Some(zerocopy::transmute!(header)), }); } let Some(len) = (header.len.get() as usize).checked_sub(8) else { return Err(io::Error::other(format!( "invalid startup message length {}, must be at least 8.", header.len, ))); }; // TODO: add a histogram for startup packet lengths if len > MAX_STARTUP_PACKET_LENGTH { tracing::warn!("large startup message detected: {len} bytes"); return Err(io::Error::other(format!( "invalid startup message length {len}" ))); } match header.version { // CANCEL_REQUEST_CODE => { if len != 8 { return Err(io::Error::other( "CancelRequest message is malformed, backend PID / secret key missing", )); } Ok(FeStartupPacket::CancelRequest( read!(stream => CancelKeyData), )) } // NEGOTIATE_SSL_CODE => { // Requested upgrade to SSL (aka TLS) Ok(FeStartupPacket::SslRequest { direct: None }) } NEGOTIATE_GSS_CODE => { // Requested upgrade to GSSAPI Ok(FeStartupPacket::GssEncRequest) } version if version.major() == RESERVED_INVALID_MAJOR_VERSION => Err(io::Error::other( format!("Unrecognized request code {version:?}"), )), // StartupMessage version => { // The protocol version number is followed by one or more pairs of parameter name and value strings. // A zero byte is required as a terminator after the last name/value pair. // Parameters can appear in any order. user is required, others are optional. let mut buf = vec![0; len]; stream.read_exact(&mut buf).await?; if buf.pop() != Some(b'\0') { return Err(io::Error::other( "StartupMessage params: missing null terminator", )); } // TODO: Don't do this. // There's no guarantee that these messages are utf8, // but they usually happen to be simple ascii. let params = String::from_utf8(buf) .map_err(|_| io::Error::other("StartupMessage params: invalid utf-8"))?; Ok(FeStartupPacket::StartupMessage { version, params: StartupMessageParams { params }, }) } } } /// Read a raw postgres packet, which will respect the max length requested. /// /// This returns the message tag, as well as the message body. The message /// body is written into `buf`, and it is otherwise completely overwritten. /// /// This is not cancel safe. pub async fn read_message<'a, S>( stream: &mut S, buf: &'a mut Vec, max: u32, ) -> io::Result<(u8, &'a mut [u8])> where S: AsyncRead + Unpin, { /// This first reads the header, which for regular messages in the 3.0 protocol is 5 bytes. /// The first byte is a message tag, and the next 4 bytes is a big-endian length. /// /// Awkwardly, the length value is inclusive of itself, but not of the tag. For example, /// an empty message will always have length 4. #[derive(Clone, Copy, FromBytes)] #[repr(C)] struct Header { tag: u8, len: big_endian::U32, } let header = read!(stream => Header); // as described above, the length must be at least 4. let Some(len) = header.len.get().checked_sub(4) else { return Err(io::Error::other(format!( "invalid startup message length {}, must be at least 4.", header.len, ))); }; // TODO: add a histogram for message lengths // check if the message exceeds our desired max. if len > max { tracing::warn!("large postgres message detected: {len} bytes"); return Err(io::Error::other(format!("invalid message length {len}"))); } // read in our entire message. buf.resize(len as usize, 0); stream.read_exact(buf).await?; Ok((header.tag, buf)) } pub struct WriteBuf(Cursor>); impl Buf for WriteBuf { #[inline] fn remaining(&self) -> usize { self.0.remaining() } #[inline] fn chunk(&self) -> &[u8] { self.0.chunk() } #[inline] fn advance(&mut self, cnt: usize) { self.0.advance(cnt); } } impl WriteBuf { pub const fn new() -> Self { Self(Cursor::new(Vec::new())) } /// Use a heuristic to determine if we should shrink the write buffer. #[inline] fn should_shrink(&self) -> bool { let n = self.0.position() as usize; let len = self.0.get_ref().len(); // the unused space at the front of our buffer is 2x the size of our filled portion. n + n > len } /// Shrink the write buffer so that subsequent writes have more spare capacity. #[cold] fn shrink(&mut self) { let n = self.0.position() as usize; let buf = self.0.get_mut(); // buf repr: // [----unused------|-----filled-----|-----uninit-----] // ^ n ^ buf.len() ^ buf.capacity() let filled = n..buf.len(); let filled_len = filled.len(); buf.copy_within(filled, 0); buf.truncate(filled_len); self.0.set_position(0); } /// clear the write buffer. pub fn reset(&mut self) { let buf = self.0.get_mut(); buf.clear(); self.0.set_position(0); } /// Shrinks the buffer if efficient to do so, and returns the remaining size. pub fn occupied_len(&mut self) -> usize { if self.should_shrink() { self.shrink(); } self.0.get_mut().len() } /// Write a raw message to the internal buffer. /// /// The size_hint value is only a hint for reserving space. It's ok if it's incorrect, since /// we calculate the length after the fact. pub fn write_raw(&mut self, size_hint: usize, tag: u8, f: impl FnOnce(&mut Vec)) { if self.should_shrink() { self.shrink(); } let buf = self.0.get_mut(); buf.reserve(5 + size_hint); buf.push(tag); let start = buf.len(); buf.extend_from_slice(&[0, 0, 0, 0]); f(buf); let end = buf.len(); let len = (end - start) as u32; buf[start..start + 4].copy_from_slice(&len.to_be_bytes()); } /// Write an encryption response message. pub fn encryption(&mut self, m: u8) { self.0.get_mut().push(m); } pub fn write_error(&mut self, msg: &str, error_code: ErrorCode) { self.shrink(); // // // "SERROR\0CXXXXX\0M\0\0".len() == 17 self.write_raw(17 + msg.len(), b'E', |buf| { // Severity: ERROR buf.put_slice(b"SERROR\0"); // Code: error_code buf.put_u8(b'C'); buf.put_slice(&error_code); buf.put_u8(0); // Message: msg buf.put_u8(b'M'); buf.put_slice(msg.as_bytes()); buf.put_u8(0); // End. buf.put_u8(0); }); } } #[derive(Debug)] pub enum FeStartupPacket { CancelRequest(CancelKeyData), SslRequest { direct: Option<[u8; 8]>, }, GssEncRequest, StartupMessage { version: ProtocolVersion, params: StartupMessageParams, }, } #[derive(Debug, Clone, Default)] pub struct StartupMessageParams { pub params: String, } impl StartupMessageParams { /// Get parameter's value by its name. pub fn get(&self, name: &str) -> Option<&str> { self.iter().find_map(|(k, v)| (k == name).then_some(v)) } /// Split command-line options according to PostgreSQL's logic, /// taking into account all escape sequences but leaving them as-is. /// [`None`] means that there's no `options` in [`Self`]. pub fn options_raw(&self) -> Option> { self.get("options").map(Self::parse_options_raw) } /// Split command-line options according to PostgreSQL's logic, /// taking into account all escape sequences but leaving them as-is. pub fn parse_options_raw(input: &str) -> impl Iterator { // See `postgres: pg_split_opts`. let mut last_was_escape = false; input .split(move |c: char| { // We split by non-escaped whitespace symbols. let should_split = c.is_ascii_whitespace() && !last_was_escape; last_was_escape = c == '\\' && !last_was_escape; should_split }) .filter(|s| !s.is_empty()) } /// Iterate through key-value pairs in an arbitrary order. pub fn iter(&self) -> impl Iterator { self.params.split_terminator('\0').tuples() } // This function is mostly useful in tests. #[cfg(test)] pub fn new<'a, const N: usize>(pairs: [(&'a str, &'a str); N]) -> Self { let mut b = Self { params: String::new(), }; for (k, v) in pairs { b.insert(k, v); } b } /// Set parameter's value by its name. /// name and value must not contain a \0 byte pub fn insert(&mut self, name: &str, value: &str) { self.params.reserve(name.len() + value.len() + 2); self.params.push_str(name); self.params.push('\0'); self.params.push_str(value); self.params.push('\0'); } } /// Cancel keys usually are represented as PID+SecretKey, but to proxy they're just /// opaque bytes. #[derive(Debug, Hash, PartialEq, Eq, Clone, Copy, FromBytes, IntoBytes, Immutable)] pub struct CancelKeyData(pub big_endian::U64); pub fn id_to_cancel_key(id: u64) -> CancelKeyData { CancelKeyData(big_endian::U64::new(id)) } impl fmt::Display for CancelKeyData { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let id = self.0; f.debug_tuple("CancelKeyData") .field(&format_args!("{id:x}")) .finish() } } impl Distribution for StandardUniform { fn sample(&self, rng: &mut R) -> CancelKeyData { id_to_cancel_key(rng.random()) } } pub enum BeMessage<'a> { AuthenticationOk, AuthenticationSasl(BeAuthenticationSaslMessage<'a>), AuthenticationCleartextPassword, BackendKeyData(CancelKeyData), ParameterStatus { name: &'a [u8], value: &'a [u8], }, ReadyForQuery, NoticeResponse(&'a str), NegotiateProtocolVersion { version: ProtocolVersion, options: &'a [&'a str], }, } #[derive(Debug)] pub enum BeAuthenticationSaslMessage<'a> { Methods(&'a [&'a str]), Continue(&'a [u8]), Final(&'a [u8]), } impl BeMessage<'_> { /// Write the message into an internal buffer pub fn write_message(self, buf: &mut WriteBuf) { match self { // BeMessage::AuthenticationOk => { buf.write_raw(1, b'R', |buf| buf.put_i32(0)); } // BeMessage::AuthenticationCleartextPassword => { buf.write_raw(1, b'R', |buf| buf.put_i32(3)); } // BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(methods)) => { let len: usize = methods.iter().map(|m| m.len() + 1).sum(); buf.write_raw(len + 2, b'R', |buf| { buf.put_i32(10); // Specifies that SASL auth method is used. for method in methods { buf.put_slice(method.as_bytes()); buf.put_u8(0); } buf.put_u8(0); // zero terminator for the list }); } // BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Continue(extra)) => { buf.write_raw(extra.len() + 1, b'R', |buf| { buf.put_i32(11); // Continue SASL auth. buf.put_slice(extra); }); } // BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Final(extra)) => { buf.write_raw(extra.len() + 1, b'R', |buf| { buf.put_i32(12); // Send final SASL message. buf.put_slice(extra); }); } // BeMessage::BackendKeyData(key_data) => { buf.write_raw(8, b'K', |buf| buf.put_slice(key_data.as_bytes())); } // // BeMessage::NoticeResponse(msg) => { // 'N' signalizes NoticeResponse messages buf.write_raw(18 + msg.len(), b'N', |buf| { // Severity: NOTICE buf.put_slice(b"SNOTICE\0"); // Code: XX000 (ignored for notice, but still required) buf.put_slice(b"CXX000\0"); // Message: msg buf.put_u8(b'M'); buf.put_slice(msg.as_bytes()); buf.put_u8(0); // End notice. buf.put_u8(0); }); } // BeMessage::ParameterStatus { name, value } => { buf.write_raw(name.len() + value.len() + 2, b'S', |buf| { buf.put_slice(name.as_bytes()); buf.put_u8(0); buf.put_slice(value.as_bytes()); buf.put_u8(0); }); } // BeMessage::ReadyForQuery => { buf.write_raw(1, b'Z', |buf| buf.put_u8(b'I')); } // BeMessage::NegotiateProtocolVersion { version, options } => { let len: usize = options.iter().map(|o| o.len() + 1).sum(); buf.write_raw(8 + len, b'v', |buf| { buf.put_slice(version.as_bytes()); buf.put_u32(options.len() as u32); for option in options { buf.put_slice(option.as_bytes()); buf.put_u8(0); } }); } } } } #[cfg(test)] mod tests { use std::io::Cursor; use tokio::io::{AsyncWriteExt, duplex}; use zerocopy::IntoBytes; use super::ProtocolVersion; use crate::pqproto::{FeStartupPacket, read_message, read_startup}; #[tokio::test] async fn reject_large_startup() { // we're going to define a v3.0 startup message with far too many parameters. let mut payload = vec![]; // 10001 + 8 bytes. payload.extend_from_slice(&10009_u32.to_be_bytes()); payload.extend_from_slice(ProtocolVersion::new(3, 0).as_bytes()); payload.resize(10009, b'a'); let (mut server, mut client) = duplex(128); #[rustfmt::skip] let (server, client) = tokio::join!( async move { read_startup(&mut server).await.unwrap_err() }, async move { client.write_all(&payload).await.unwrap_err() }, ); assert_eq!(server.to_string(), "invalid startup message length 10001"); assert_eq!(client.to_string(), "broken pipe"); } #[tokio::test] async fn reject_large_password() { // we're going to define a password message that is far too long. let mut payload = vec![]; payload.push(b'p'); payload.extend_from_slice(&517_u32.to_be_bytes()); payload.resize(518, b'a'); let (mut server, mut client) = duplex(128); #[rustfmt::skip] let (server, client) = tokio::join!( async move { read_message(&mut server, &mut vec![], 512).await.unwrap_err() }, async move { client.write_all(&payload).await.unwrap_err() }, ); assert_eq!(server.to_string(), "invalid message length 513"); assert_eq!(client.to_string(), "broken pipe"); } #[tokio::test] async fn read_startup_message() { let mut payload = vec![]; payload.extend_from_slice(&17_u32.to_be_bytes()); payload.extend_from_slice(ProtocolVersion::new(3, 0).as_bytes()); payload.extend_from_slice(b"abc\0def\0\0"); let startup = read_startup(&mut Cursor::new(&payload)).await.unwrap(); let FeStartupPacket::StartupMessage { version, params } = startup else { panic!("unexpected startup message: {startup:?}"); }; assert_eq!(version.major(), 3); assert_eq!(version.minor(), 0); assert_eq!(params.params, "abc\0def\0"); } #[tokio::test] async fn read_ssl_message() { let mut payload = vec![]; payload.extend_from_slice(&8_u32.to_be_bytes()); payload.extend_from_slice(ProtocolVersion::new(1234, 5679).as_bytes()); let startup = read_startup(&mut Cursor::new(&payload)).await.unwrap(); let FeStartupPacket::SslRequest { direct: None } = startup else { panic!("unexpected startup message: {startup:?}"); }; } #[tokio::test] async fn read_tls_message() { // sample client hello taken from let client_hello = [ 0x16, 0x03, 0x01, 0x00, 0xf8, 0x01, 0x00, 0x00, 0xf4, 0x03, 0x03, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20, 0xe0, 0xe1, 0xe2, 0xe3, 0xe4, 0xe5, 0xe6, 0xe7, 0xe8, 0xe9, 0xea, 0xeb, 0xec, 0xed, 0xee, 0xef, 0xf0, 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7, 0xf8, 0xf9, 0xfa, 0xfb, 0xfc, 0xfd, 0xfe, 0xff, 0x00, 0x08, 0x13, 0x02, 0x13, 0x03, 0x13, 0x01, 0x00, 0xff, 0x01, 0x00, 0x00, 0xa3, 0x00, 0x00, 0x00, 0x18, 0x00, 0x16, 0x00, 0x00, 0x13, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2e, 0x75, 0x6c, 0x66, 0x68, 0x65, 0x69, 0x6d, 0x2e, 0x6e, 0x65, 0x74, 0x00, 0x0b, 0x00, 0x04, 0x03, 0x00, 0x01, 0x02, 0x00, 0x0a, 0x00, 0x16, 0x00, 0x14, 0x00, 0x1d, 0x00, 0x17, 0x00, 0x1e, 0x00, 0x19, 0x00, 0x18, 0x01, 0x00, 0x01, 0x01, 0x01, 0x02, 0x01, 0x03, 0x01, 0x04, 0x00, 0x23, 0x00, 0x00, 0x00, 0x16, 0x00, 0x00, 0x00, 0x17, 0x00, 0x00, 0x00, 0x0d, 0x00, 0x1e, 0x00, 0x1c, 0x04, 0x03, 0x05, 0x03, 0x06, 0x03, 0x08, 0x07, 0x08, 0x08, 0x08, 0x09, 0x08, 0x0a, 0x08, 0x0b, 0x08, 0x04, 0x08, 0x05, 0x08, 0x06, 0x04, 0x01, 0x05, 0x01, 0x06, 0x01, 0x00, 0x2b, 0x00, 0x03, 0x02, 0x03, 0x04, 0x00, 0x2d, 0x00, 0x02, 0x01, 0x01, 0x00, 0x33, 0x00, 0x26, 0x00, 0x24, 0x00, 0x1d, 0x00, 0x20, 0x35, 0x80, 0x72, 0xd6, 0x36, 0x58, 0x80, 0xd1, 0xae, 0xea, 0x32, 0x9a, 0xdf, 0x91, 0x21, 0x38, 0x38, 0x51, 0xed, 0x21, 0xa2, 0x8e, 0x3b, 0x75, 0xe9, 0x65, 0xd0, 0xd2, 0xcd, 0x16, 0x62, 0x54, ]; let mut cursor = Cursor::new(&client_hello); let startup = read_startup(&mut cursor).await.unwrap(); let FeStartupPacket::SslRequest { direct: Some(prefix), } = startup else { panic!("unexpected startup message: {startup:?}"); }; // check that no data is lost. assert_eq!(prefix, [0x16, 0x03, 0x01, 0x00, 0xf8, 0x01, 0x00, 0x00]); assert_eq!(cursor.position(), 8); } #[tokio::test] async fn read_message_success() { let query = b"Q\0\0\0\x0cSELECT 1Q\0\0\0\x0cSELECT 2"; let mut cursor = Cursor::new(&query); let mut buf = vec![]; let (tag, message) = read_message(&mut cursor, &mut buf, 100).await.unwrap(); assert_eq!(tag, b'Q'); assert_eq!(message, b"SELECT 1"); let (tag, message) = read_message(&mut cursor, &mut buf, 100).await.unwrap(); assert_eq!(tag, b'Q'); assert_eq!(message, b"SELECT 2"); } }