diff --git a/Cargo.lock b/Cargo.lock index e380e72dc0..b96f7dbc99 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2747,7 +2747,7 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" name = "pq_proto" version = "0.1.0" dependencies = [ - "anyhow", + "byteorder", "bytes", "pin-project-lite", "postgres-protocol", diff --git a/libs/postgres_backend/src/lib.rs b/libs/postgres_backend/src/lib.rs index ba28add9f9..ce46899779 100644 --- a/libs/postgres_backend/src/lib.rs +++ b/libs/postgres_backend/src/lib.rs @@ -17,9 +17,9 @@ use tokio::io::{AsyncRead, AsyncWrite}; use tokio_rustls::TlsAcceptor; use tracing::{debug, error, info, trace}; -use pq_proto::framed::{Framed, FramedReader, FramedWriter}; +use pq_proto::framed::{ConnectionError, Framed, FramedReader, FramedWriter}; use pq_proto::{ - BeMessage, ConnectionError, FeMessage, FeStartupPacket, SQLSTATE_INTERNAL_ERROR, + BeMessage, FeMessage, FeStartupPacket, ProtocolError, SQLSTATE_INTERNAL_ERROR, SQLSTATE_SUCCESSFUL_COMPLETION, }; @@ -37,7 +37,7 @@ pub enum QueryError { impl From for QueryError { fn from(e: io::Error) -> Self { - Self::Disconnected(ConnectionError::Socket(e)) + Self::Disconnected(ConnectionError::Io(e)) } } @@ -219,7 +219,7 @@ impl MaybeWriteOnly { } } - fn write_message_noflush(&mut self, msg: &BeMessage<'_>) -> Result<(), ConnectionError> { + fn write_message_noflush(&mut self, msg: &BeMessage<'_>) -> Result<(), ProtocolError> { match self { MaybeWriteOnly::Full(framed) => framed.write_message(msg), MaybeWriteOnly::WriteOnly(framed_writer) => framed_writer.write_message_noflush(msg), @@ -701,8 +701,7 @@ impl PostgresBackend { FeMessage::CopyData(_) | FeMessage::CopyDone | FeMessage::CopyFail - | FeMessage::PasswordMessage(_) - | FeMessage::StartupPacket(_) => { + | FeMessage::PasswordMessage(_) => { return Err(QueryError::Other(anyhow::anyhow!( "unexpected message type: {msg:?}", ))); @@ -721,7 +720,7 @@ impl PostgresBackend { let expected_end = match &end { ServerInitiated(_) | CopyDone | CopyFail | Terminate | EOF => true, - CopyStreamHandlerEnd::Disconnected(ConnectionError::Socket(io_error)) + CopyStreamHandlerEnd::Disconnected(ConnectionError::Io(io_error)) if is_expected_io_error(io_error) => { true @@ -800,7 +799,7 @@ impl PostgresBackendReader { FeMessage::CopyFail => Err(CopyStreamHandlerEnd::CopyFail), FeMessage::Terminate => Err(CopyStreamHandlerEnd::Terminate), _ => Err(CopyStreamHandlerEnd::from(ConnectionError::Protocol( - format!("unexpected message in COPY stream {:?}", msg), + ProtocolError::Protocol(format!("unexpected message in COPY stream {:?}", msg)), ))), }, None => Err(CopyStreamHandlerEnd::EOF), @@ -871,7 +870,7 @@ pub fn short_error(e: &QueryError) -> String { fn log_query_error(query: &str, e: &QueryError) { match e { - QueryError::Disconnected(ConnectionError::Socket(io_error)) => { + QueryError::Disconnected(ConnectionError::Io(io_error)) => { if is_expected_io_error(io_error) { info!("query handler for '{query}' failed with expected io error: {io_error}"); } else { diff --git a/libs/pq_proto/Cargo.toml b/libs/pq_proto/Cargo.toml index bc90a7a2c1..76b71729ed 100644 --- a/libs/pq_proto/Cargo.toml +++ b/libs/pq_proto/Cargo.toml @@ -5,8 +5,8 @@ edition.workspace = true license.workspace = true [dependencies] -anyhow.workspace = true bytes.workspace = true +byteorder.workspace = true pin-project-lite.workspace = true postgres-protocol.workspace = true rand.workspace = true diff --git a/libs/pq_proto/src/framed.rs b/libs/pq_proto/src/framed.rs index 7c33222e6e..972730cbab 100644 --- a/libs/pq_proto/src/framed.rs +++ b/libs/pq_proto/src/framed.rs @@ -1,51 +1,84 @@ //! Provides `Framed` -- writing/flushing and reading Postgres messages to/from -//! the async stream. +//! the async stream based on (and buffered with) BytesMut. All functions are +//! cancellation safe. +//! +//! It is similar to what tokio_util::codec::Framed with appropriate codec +//! provides, but `FramedReader` and `FramedWriter` read/write parts can be used +//! separately without using split from futures::stream::StreamExt (which +//! allocates box[1] in polling internally). tokio::io::split is used for splitting +//! instead. Plus we customize error messages more than a single type for all io +//! calls. +//! +//! [1] https://docs.rs/futures-util/0.3.26/src/futures_util/lock/bilock.rs.html#107 use bytes::{Buf, BytesMut}; use std::{ future::Future, io::{self, ErrorKind}, }; -use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufReader, ReadHalf, WriteHalf}; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf}; -use crate::{BeMessage, ConnectionError, FeMessage, FeStartupPacket}; +use crate::{BeMessage, FeMessage, FeStartupPacket, ProtocolError}; const INITIAL_CAPACITY: usize = 8 * 1024; +/// Error on postgres connection: either IO (physical transport error) or +/// protocol violation. +#[derive(thiserror::Error, Debug)] +pub enum ConnectionError { + #[error(transparent)] + Io(#[from] io::Error), + #[error(transparent)] + Protocol(#[from] ProtocolError), +} + +impl ConnectionError { + /// Proxy stream.rs uses only io::Error; provide it. + pub fn into_io_error(self) -> io::Error { + match self { + ConnectionError::Io(io) => io, + ConnectionError::Protocol(pe) => io::Error::new(io::ErrorKind::Other, pe.to_string()), + } + } +} + /// Wraps async io `stream`, providing messages to write/flush + read Postgres /// messages. pub struct Framed { - stream: BufReader, + stream: S, + read_buf: BytesMut, write_buf: BytesMut, } -impl Framed { +impl Framed { pub fn new(stream: S) -> Self { Self { - stream: BufReader::new(stream), + stream, + read_buf: BytesMut::with_capacity(INITIAL_CAPACITY), write_buf: BytesMut::with_capacity(INITIAL_CAPACITY), } } /// Get a shared reference to the underlying stream. pub fn get_ref(&self) -> &S { - self.stream.get_ref() + &self.stream } /// Extract the underlying stream. pub fn into_inner(self) -> S { - self.stream.into_inner() + self.stream } /// Return new Framed with stream type transformed by async f, for TLS /// upgrade. - pub async fn map_stream(self, f: F) -> Result, E> + pub async fn map_stream(self, f: F) -> Result, E> where F: FnOnce(S) -> Fut, Fut: Future>, { - let stream = f(self.stream.into_inner()).await?; + let stream = f(self.stream).await?; Ok(Framed { - stream: BufReader::new(stream), + stream, + read_buf: self.read_buf, write_buf: self.write_buf, }) } @@ -55,24 +88,18 @@ impl Framed { pub async fn read_startup_message( &mut self, ) -> Result, ConnectionError> { - let msg = FeStartupPacket::read(&mut self.stream).await?; - - match msg { - Some(FeMessage::StartupPacket(packet)) => Ok(Some(packet)), - None => Ok(None), - _ => panic!("unreachable state"), - } + read_message(&mut self.stream, &mut self.read_buf, FeStartupPacket::parse).await } pub async fn read_message(&mut self) -> Result, ConnectionError> { - FeMessage::read(&mut self.stream).await + read_message(&mut self.stream, &mut self.read_buf, FeMessage::parse).await } } -impl Framed { +impl Framed { /// Write next message to the output buffer; doesn't flush. - pub fn write_message(&mut self, msg: &BeMessage<'_>) -> Result<(), ConnectionError> { - BeMessage::write(&mut self.write_buf, msg).map_err(|e| e.into()) + pub fn write_message(&mut self, msg: &BeMessage<'_>) -> Result<(), ProtocolError> { + BeMessage::write(&mut self.write_buf, msg) } /// Flush out the buffer. This function is cancellation safe: it can be @@ -93,7 +120,10 @@ impl Framed { /// https://github.com/tokio-rs/tls/issues/40 pub fn split(self) -> (FramedReader, FramedWriter) { let (read_half, write_half) = tokio::io::split(self.stream); - let reader = FramedReader { stream: read_half }; + let reader = FramedReader { + stream: read_half, + read_buf: self.read_buf, + }; let writer = FramedWriter { stream: write_half, write_buf: self.write_buf, @@ -105,6 +135,7 @@ impl Framed { pub fn unsplit(reader: FramedReader, writer: FramedWriter) -> Self { Self { stream: reader.stream.unsplit(writer.stream), + read_buf: reader.read_buf, write_buf: writer.write_buf, } } @@ -112,25 +143,26 @@ impl Framed { /// Read-only version of `Framed`. pub struct FramedReader { - stream: ReadHalf>, + stream: ReadHalf, + read_buf: BytesMut, } impl FramedReader { pub async fn read_message(&mut self) -> Result, ConnectionError> { - FeMessage::read(&mut self.stream).await + read_message(&mut self.stream, &mut self.read_buf, FeMessage::parse).await } } /// Write-only version of `Framed`. pub struct FramedWriter { - stream: WriteHalf>, + stream: WriteHalf, write_buf: BytesMut, } -impl FramedWriter { +impl FramedWriter { /// Write next message to the output buffer; doesn't flush. - pub fn write_message_noflush(&mut self, msg: &BeMessage<'_>) -> Result<(), ConnectionError> { - BeMessage::write(&mut self.write_buf, msg).map_err(|e| e.into()) + pub fn write_message_noflush(&mut self, msg: &BeMessage<'_>) -> Result<(), ProtocolError> { + BeMessage::write(&mut self.write_buf, msg) } /// Flush out the buffer. This function is cancellation safe: it can be @@ -145,6 +177,43 @@ impl FramedWriter { } } +/// Read next message from the stream. Returns Ok(None), if EOF happened and we +/// don't have remaining data in the buffer. This function is cancellation safe: +/// you can drop future which is not yet complete and finalize reading message +/// with the next call. +/// +/// Parametrized to allow reading startup or usual message, having different +/// format. +async fn read_message( + stream: &mut S, + read_buf: &mut BytesMut, + parse: P, +) -> Result, ConnectionError> +where + P: Fn(&mut BytesMut) -> Result, ProtocolError>, +{ + loop { + if let Some(msg) = parse(read_buf)? { + return Ok(Some(msg)); + } + // If we can't build a frame yet, try to read more data and try again. + // Make sure we've got room for at least one byte to read to ensure + // that we don't get a spurious 0 that looks like EOF. + read_buf.reserve(1); + if stream.read_buf(read_buf).await? == 0 { + if read_buf.has_remaining() { + return Err(io::Error::new( + ErrorKind::UnexpectedEof, + "EOF with unprocessed data in the buffer", + ) + .into()); + } else { + return Ok(None); // clean EOF + } + } + } +} + async fn flush( stream: &mut S, write_buf: &mut BytesMut, diff --git a/libs/pq_proto/src/lib.rs b/libs/pq_proto/src/lib.rs index 6980c4afae..46d531239a 100644 --- a/libs/pq_proto/src/lib.rs +++ b/libs/pq_proto/src/lib.rs @@ -4,19 +4,16 @@ pub mod framed; -use anyhow::{ensure, Context, Result}; +use byteorder::{BigEndian, ReadBytesExt}; use bytes::{Buf, BufMut, Bytes, BytesMut}; use postgres_protocol::PG_EPOCH; use serde::{Deserialize, Serialize}; use std::{ borrow::Cow, collections::HashMap, - fmt, - io::{self, Cursor}, - str, + fmt, io, str, time::{Duration, SystemTime}, }; -use tokio::io::AsyncReadExt; use tracing::{trace, warn}; pub type Oid = u32; @@ -28,7 +25,6 @@ pub const TEXT_OID: Oid = 25; #[derive(Debug)] pub enum FeMessage { - StartupPacket(FeStartupPacket), // Simple query. Query(Bytes), // Extended query protocol. @@ -188,100 +184,90 @@ pub struct FeExecuteMessage { #[derive(Debug)] pub struct FeCloseMessage; -/// Retry a read on EINTR -/// -/// This runs the enclosed expression, and if it returns -/// Err(io::ErrorKind::Interrupted), retries it. -macro_rules! retry_read { - ( $x:expr ) => { - loop { - match $x { - Err(e) if e.kind() == io::ErrorKind::Interrupted => continue, - res => break res, - } - } - }; -} - -/// An error occured during connection being open. +/// An error occured while parsing or serializing raw stream into Postgres +/// messages. #[derive(thiserror::Error, Debug)] -pub enum ConnectionError { - /// IO error during writing to or reading from the connection socket. - #[error("Socket IO error: {0}")] - Socket(#[from] std::io::Error), - /// Invalid packet was received from client +pub enum ProtocolError { + /// Invalid packet was received from the client (e.g. unexpected message + /// type or broken len). #[error("Protocol error: {0}")] Protocol(String), - /// Failed to parse a protocol mesage + /// Failed to parse or, (unlikely), serialize a protocol message. #[error("Message parse error: {0}")] - MessageParse(anyhow::Error), + BadMessage(String), } -impl From for ConnectionError { - fn from(e: anyhow::Error) -> Self { - Self::MessageParse(e) - } -} - -impl ConnectionError { +impl ProtocolError { + /// Proxy stream.rs uses only io::Error; provide it. pub fn into_io_error(self) -> io::Error { - match self { - ConnectionError::Socket(io) => io, - other => io::Error::new(io::ErrorKind::Other, other.to_string()), - } + io::Error::new(io::ErrorKind::Other, self.to_string()) } } impl FeMessage { - /// Read one message from the stream. - /// This function returns `Ok(None)` in case of EOF. - pub async fn read(stream: &mut Reader) -> Result, ConnectionError> - where - Reader: tokio::io::AsyncRead + Unpin, - { - // We return a Future that's sync (has a `wait` method) if and only if the provided stream is SyncProof. - // SyncFuture contract: we are only allowed to await on sync-proof futures, the AsyncRead and - // AsyncReadExt methods of the stream. - // Each libpq message begins with a message type byte, followed by message length - // If the client closes the connection, return None. But if the client closes the - // connection in the middle of a message, we will return an error. - let tag = match retry_read!(stream.read_u8().await) { - Ok(b) => b, - Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None), - Err(e) => return Err(ConnectionError::Socket(e)), - }; + /// Read and parse one message from the `buf` input buffer. If there is at + /// least one valid message, returns it, advancing `buf`; redundant copies + /// are avoided, as thanks to `bytes` crate ptrs in parsed message point + /// directly into the `buf` (processed data is garbage collected after + /// parsed message is dropped). + /// + /// Returns None if `buf` doesn't contain enough data for a single message. + /// For efficiency, tries to reserve large enough space in `buf` for the + /// next message in this case to save the repeated calls. + /// + /// Returns Error if message is malformed, the only possible ErrorKind is + /// InvalidInput. + // + // Inspired by rust-postgres Message::parse. + pub fn parse(buf: &mut BytesMut) -> Result, ProtocolError> { + // Every message contains message type byte and 4 bytes len; can't do + // much without them. + if buf.len() < 5 { + let to_read = 5 - buf.len(); + buf.reserve(to_read); + return Ok(None); + } - // The message length includes itself, so it better be at least 4. - let len = retry_read!(stream.read_u32().await) - .map_err(ConnectionError::Socket)? - .checked_sub(4) - .ok_or_else(|| ConnectionError::Protocol("invalid message length".to_string()))?; + // We shouldn't advance `buf` as probably full message is not there yet, + // so can't directly use Bytes::get_u32 etc. + let tag = buf[0]; + let len = (&buf[1..5]).read_u32::().unwrap(); + if len < 4 { + return Err(ProtocolError::Protocol(format!( + "invalid message length {}", + len + ))); + } - let body = { - let mut buffer = vec![0u8; len as usize]; - stream - .read_exact(&mut buffer) - .await - .map_err(ConnectionError::Socket)?; - Bytes::from(buffer) - }; + // length field includes itself, but not message type. + let total_len = len as usize + 1; + if buf.len() < total_len { + // Don't have full message yet. + let to_read = total_len - buf.len(); + buf.reserve(to_read); + return Ok(None); + } + + // got the message, advance buffer + let mut msg = buf.split_to(total_len).freeze(); + msg.advance(5); // consume message type and len match tag { - b'Q' => Ok(Some(FeMessage::Query(body))), - b'P' => Ok(Some(FeParseMessage::parse(body)?)), - b'D' => Ok(Some(FeDescribeMessage::parse(body)?)), - b'E' => Ok(Some(FeExecuteMessage::parse(body)?)), - b'B' => Ok(Some(FeBindMessage::parse(body)?)), - b'C' => Ok(Some(FeCloseMessage::parse(body)?)), + b'Q' => Ok(Some(FeMessage::Query(msg))), + b'P' => Ok(Some(FeParseMessage::parse(msg)?)), + b'D' => Ok(Some(FeDescribeMessage::parse(msg)?)), + b'E' => Ok(Some(FeExecuteMessage::parse(msg)?)), + b'B' => Ok(Some(FeBindMessage::parse(msg)?)), + b'C' => Ok(Some(FeCloseMessage::parse(msg)?)), b'S' => Ok(Some(FeMessage::Sync)), b'X' => Ok(Some(FeMessage::Terminate)), - b'd' => Ok(Some(FeMessage::CopyData(body))), + b'd' => Ok(Some(FeMessage::CopyData(msg))), b'c' => Ok(Some(FeMessage::CopyDone)), b'f' => Ok(Some(FeMessage::CopyFail)), - b'p' => Ok(Some(FeMessage::PasswordMessage(body))), + b'p' => Ok(Some(FeMessage::PasswordMessage(msg))), tag => { - return Err(ConnectionError::Protocol(format!( - "unknown message tag: {tag},'{body:?}'" + return Err(ProtocolError::Protocol(format!( + "unknown message tag: {tag},'{msg:?}'" ))) } } @@ -289,60 +275,59 @@ impl FeMessage { } impl FeStartupPacket { - /// Read startup message from the stream. - // XXX: It's tempting yet undesirable to accept `stream` by value, - // since such a change will cause user-supplied &mut references to be consumed - pub async fn read(stream: &mut Reader) -> Result, ConnectionError> - where - Reader: tokio::io::AsyncRead + Unpin, - { + /// Read and parse startup message from the `buf` input buffer. It is + /// different from [`FeMessage::parse`] because startup messages don't have + /// message type byte; otherwise, its comments apply. + pub fn parse(buf: &mut BytesMut) -> Result, ProtocolError> { const MAX_STARTUP_PACKET_LENGTH: usize = 10000; const RESERVED_INVALID_MAJOR_VERSION: u32 = 1234; const CANCEL_REQUEST_CODE: u32 = 5678; const NEGOTIATE_SSL_CODE: u32 = 5679; const NEGOTIATE_GSS_CODE: u32 = 5680; - // Read length. If the connection is closed before reading anything (or before - // reading 4 bytes, to be precise), return None to indicate that the connection - // was closed. This matches the PostgreSQL server's behavior, which avoids noise - // in the log if the client opens connection but closes it immediately. - let len = match retry_read!(stream.read_u32().await) { - Ok(len) => len as usize, - Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None), - Err(e) => return Err(ConnectionError::Socket(e)), - }; + // need at least 4 bytes with packet len + if buf.len() < 4 { + let to_read = 4 - buf.len(); + buf.reserve(to_read); + return Ok(None); + } - #[allow(clippy::manual_range_contains)] + // We shouldn't advance `buf` as probably full message is not there yet, + // so can't directly use Bytes::get_u32 etc. + let len = (&buf[0..4]).read_u32::().unwrap() as usize; if len < 4 || len > MAX_STARTUP_PACKET_LENGTH { - return Err(ConnectionError::Protocol(format!( - "invalid message length {len}" + return Err(ProtocolError::Protocol(format!( + "invalid startup packet message length {}", + len ))); } - let request_code = retry_read!(stream.read_u32().await).map_err(ConnectionError::Socket)?; + if buf.len() < len { + // Don't have full message yet. + let to_read = len - buf.len(); + buf.reserve(to_read); + return Ok(None); + } - // the rest of startup packet are params - let params_len = len - 8; - let mut params_bytes = vec![0u8; params_len]; - stream - .read_exact(params_bytes.as_mut()) - .await - .map_err(ConnectionError::Socket)?; + // got the message, advance buffer + let mut msg = buf.split_to(len).freeze(); + msg.advance(4); // consume len - // Parse params depending on request code + let request_code = msg.get_u32(); let req_hi = request_code >> 16; let req_lo = request_code & ((1 << 16) - 1); + // StartupMessage, CancelRequest, SSLRequest etc are differentiated by request code. let message = match (req_hi, req_lo) { (RESERVED_INVALID_MAJOR_VERSION, CANCEL_REQUEST_CODE) => { - if params_len != 8 { - return Err(ConnectionError::Protocol( - "expected 8 bytes for CancelRequest params".to_string(), + if msg.remaining() != 8 { + return Err(ProtocolError::BadMessage( + "CancelRequest message is malformed, backend PID / secret key missing" + .to_owned(), )); } - let mut cursor = Cursor::new(params_bytes); FeStartupPacket::CancelRequest(CancelKeyData { - backend_pid: cursor.read_i32().await.map_err(ConnectionError::Socket)?, - cancel_key: cursor.read_i32().await.map_err(ConnectionError::Socket)?, + backend_pid: msg.get_i32(), + cancel_key: msg.get_i32(), }) } (RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_SSL_CODE) => { @@ -354,19 +339,23 @@ impl FeStartupPacket { FeStartupPacket::GssEncRequest } (RESERVED_INVALID_MAJOR_VERSION, unrecognized_code) => { - return Err(ConnectionError::Protocol(format!( + return Err(ProtocolError::Protocol(format!( "Unrecognized request code {unrecognized_code}" ))); } // TODO bail if protocol major_version is not 3? (major_version, minor_version) => { + // StartupMessage + // Parse pairs of null-terminated strings (key, value). // See `postgres: ProcessStartupPacket, build_startup_packet`. - let mut tokens = str::from_utf8(¶ms_bytes) - .context("StartupMessage params: invalid utf-8")? + let mut tokens = str::from_utf8(&msg) + .map_err(|_e| { + ProtocolError::BadMessage("StartupMessage params: invalid utf-8".to_owned()) + })? .strip_suffix('\0') // drop packet's own null .ok_or_else(|| { - ConnectionError::Protocol( + ProtocolError::Protocol( "StartupMessage params: missing null terminator".to_string(), ) })? @@ -375,7 +364,7 @@ impl FeStartupPacket { let mut params = HashMap::new(); while let Some(name) = tokens.next() { let value = tokens.next().ok_or_else(|| { - ConnectionError::Protocol( + ProtocolError::Protocol( "StartupMessage params: key without value".to_string(), ) })?; @@ -390,13 +379,12 @@ impl FeStartupPacket { } } }; - - Ok(Some(FeMessage::StartupPacket(message))) + Ok(Some(message)) } } impl FeParseMessage { - fn parse(mut buf: Bytes) -> anyhow::Result { + fn parse(mut buf: Bytes) -> Result { // FIXME: the rust-postgres driver uses a named prepared statement // for copy_out(). We're not prepared to handle that correctly. For // now, just ignore the statement name, assuming that the client never @@ -404,55 +392,82 @@ impl FeParseMessage { let _pstmt_name = read_cstr(&mut buf)?; let query_string = read_cstr(&mut buf)?; + if buf.remaining() < 2 { + return Err(ProtocolError::BadMessage( + "Parse message is malformed, nparams missing".to_string(), + )); + } let nparams = buf.get_i16(); - ensure!(nparams == 0, "query params not implemented"); + if nparams != 0 { + return Err(ProtocolError::BadMessage( + "query params not implemented".to_string(), + )); + } Ok(FeMessage::Parse(FeParseMessage { query_string })) } } impl FeDescribeMessage { - fn parse(mut buf: Bytes) -> anyhow::Result { + fn parse(mut buf: Bytes) -> Result { let kind = buf.get_u8(); let _pstmt_name = read_cstr(&mut buf)?; // FIXME: see FeParseMessage::parse - ensure!( - kind == b'S', - "only prepared statemement Describe is implemented" - ); + if kind != b'S' { + return Err(ProtocolError::BadMessage( + "only prepared statemement Describe is implemented".to_string(), + )); + } Ok(FeMessage::Describe(FeDescribeMessage { kind })) } } impl FeExecuteMessage { - fn parse(mut buf: Bytes) -> anyhow::Result { + fn parse(mut buf: Bytes) -> Result { let portal_name = read_cstr(&mut buf)?; + if buf.remaining() < 4 { + return Err(ProtocolError::BadMessage( + "FeExecuteMessage message is malformed, maxrows missing".to_string(), + )); + } let maxrows = buf.get_i32(); - ensure!(portal_name.is_empty(), "named portals not implemented"); - ensure!(maxrows == 0, "row limit in Execute message not implemented"); + if !portal_name.is_empty() { + return Err(ProtocolError::BadMessage( + "named portals not implemented".to_string(), + )); + } + if maxrows != 0 { + return Err(ProtocolError::BadMessage( + "row limit in Execute message not implemented".to_string(), + )); + } Ok(FeMessage::Execute(FeExecuteMessage { maxrows })) } } impl FeBindMessage { - fn parse(mut buf: Bytes) -> anyhow::Result { + fn parse(mut buf: Bytes) -> Result { let portal_name = read_cstr(&mut buf)?; let _pstmt_name = read_cstr(&mut buf)?; // FIXME: see FeParseMessage::parse - ensure!(portal_name.is_empty(), "named portals not implemented"); + if !portal_name.is_empty() { + return Err(ProtocolError::BadMessage( + "named portals not implemented".to_string(), + )); + } Ok(FeMessage::Bind(FeBindMessage)) } } impl FeCloseMessage { - fn parse(mut buf: Bytes) -> anyhow::Result { + fn parse(mut buf: Bytes) -> Result { let _kind = buf.get_u8(); let _pstmt_or_portal_name = read_cstr(&mut buf)?; @@ -481,6 +496,7 @@ pub enum BeMessage<'a> { CloseComplete, // None means column is NULL DataRow(&'a [Option<&'a [u8]>]), + // None errcode means internal_error will be sent. ErrorResponse(&'a str, Option<&'a [u8; 5]>), /// Single byte - used in response to SSLRequest/GSSENCRequest. EncryptionResponse(bool), @@ -594,7 +610,7 @@ impl RowDescriptor<'_> { #[derive(Debug)] pub struct XLogDataBody<'a> { pub wal_start: u64, - pub wal_end: u64, + pub wal_end: u64, // current end of WAL on the server pub timestamp: i64, pub data: &'a [u8], } @@ -634,12 +650,11 @@ fn write_body(buf: &mut BytesMut, f: impl FnOnce(&mut BytesMut) -> R) -> R { } /// Safe write of s into buf as cstring (String in the protocol). -fn write_cstr(s: impl AsRef<[u8]>, buf: &mut BytesMut) -> io::Result<()> { +fn write_cstr(s: impl AsRef<[u8]>, buf: &mut BytesMut) -> Result<(), ProtocolError> { let bytes = s.as_ref(); if bytes.contains(&0) { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - "string contains embedded null", + return Err(ProtocolError::BadMessage( + "string contains embedded null".to_owned(), )); } buf.put_slice(bytes); @@ -647,9 +662,13 @@ fn write_cstr(s: impl AsRef<[u8]>, buf: &mut BytesMut) -> io::Result<()> { Ok(()) } -fn read_cstr(buf: &mut Bytes) -> anyhow::Result { - let pos = buf.iter().position(|x| *x == 0); - let result = buf.split_to(pos.context("missing terminator")?); +/// Read cstring from buf, advancing it. +fn read_cstr(buf: &mut Bytes) -> Result { + let pos = buf + .iter() + .position(|x| *x == 0) + .ok_or_else(|| ProtocolError::BadMessage("missing cstring terminator".to_owned()))?; + let result = buf.split_to(pos); buf.advance(1); // drop the null terminator Ok(result) } @@ -658,12 +677,12 @@ pub const SQLSTATE_INTERNAL_ERROR: &[u8; 5] = b"XX000"; pub const SQLSTATE_SUCCESSFUL_COMPLETION: &[u8; 5] = b"00000"; impl<'a> BeMessage<'a> { - /// Write message to the given buf. - // Unlike the reading side, we use BytesMut - // here as msg len precedes its body and it is handy to write it down first - // and then fill the length. With Write we would have to either calc it - // manually or have one more buffer. - pub fn write(buf: &mut BytesMut, message: &BeMessage) -> io::Result<()> { + /// Serialize `message` to the given `buf`. + /// Apart from smart memory managemet, BytesMut is good here as msg len + /// precedes its body and it is handy to write it down first and then fill + /// the length. With Write we would have to either calc it manually or have + /// one more buffer. + pub fn write(buf: &mut BytesMut, message: &BeMessage) -> Result<(), ProtocolError> { match message { BeMessage::AuthenticationOk => { buf.put_u8(b'R'); @@ -708,7 +727,7 @@ impl<'a> BeMessage<'a> { buf.put_slice(extra); } } - Ok::<_, io::Error>(()) + Ok(()) })?; } @@ -812,7 +831,7 @@ impl<'a> BeMessage<'a> { write_cstr(error_msg, buf)?; buf.put_u8(0); // terminator - Ok::<_, io::Error>(()) + Ok(()) })?; } @@ -835,7 +854,7 @@ impl<'a> BeMessage<'a> { write_cstr(error_msg.as_bytes(), buf)?; buf.put_u8(0); // terminator - Ok::<_, io::Error>(()) + Ok(()) })?; } @@ -890,7 +909,7 @@ impl<'a> BeMessage<'a> { buf.put_i32(-1); /* typmod */ buf.put_i16(0); /* format code */ } - Ok::<_, io::Error>(()) + Ok(()) })?; } @@ -957,7 +976,7 @@ impl ReplicationFeedback { // null-terminated string - key, // uint32 - value length in bytes // value itself - pub fn serialize(&self, buf: &mut BytesMut) -> Result<()> { + pub fn serialize(&self, buf: &mut BytesMut) { buf.put_u8(REPLICATION_FEEDBACK_FIELDS_NUMBER); // # of keys buf.put_slice(b"current_timeline_size\0"); buf.put_i32(8); @@ -982,7 +1001,6 @@ impl ReplicationFeedback { buf.put_slice(b"ps_replytime\0"); buf.put_i32(8); buf.put_i64(timestamp); - Ok(()) } // Deserialize ReplicationFeedback message @@ -1050,7 +1068,7 @@ mod tests { // because it is rounded up to microseconds during serialization. rf.ps_replytime = *PG_EPOCH + Duration::from_secs(100_000_000); let mut data = BytesMut::new(); - rf.serialize(&mut data).unwrap(); + rf.serialize(&mut data); let rf_parsed = ReplicationFeedback::parse(data.freeze()); assert_eq!(rf, rf_parsed); @@ -1065,7 +1083,7 @@ mod tests { // because it is rounded up to microseconds during serialization. rf.ps_replytime = *PG_EPOCH + Duration::from_secs(100_000_000); let mut data = BytesMut::new(); - rf.serialize(&mut data).unwrap(); + rf.serialize(&mut data); // Add an extra field to the buffer and adjust number of keys if let Some(first) = data.first_mut() { diff --git a/pageserver/src/page_service.rs b/pageserver/src/page_service.rs index bdcd71a20f..40e11a70b7 100644 --- a/pageserver/src/page_service.rs +++ b/pageserver/src/page_service.rs @@ -21,7 +21,7 @@ use pageserver_api::models::{ PagestreamNblocksRequest, PagestreamNblocksResponse, }; use postgres_backend::{self, is_expected_io_error, AuthType, PostgresBackend, QueryError}; -use pq_proto::ConnectionError; +use pq_proto::framed::ConnectionError; use pq_proto::FeStartupPacket; use pq_proto::{BeMessage, FeMessage, RowDescriptor}; use std::io; @@ -78,7 +78,7 @@ fn copyin_stream(pgb: &mut PostgresBackend) -> impl Stream continue, FeMessage::Terminate => { let msg = "client terminated connection with Terminate message during COPY"; - let query_error = QueryError::Disconnected(ConnectionError::Socket(io::Error::new(io::ErrorKind::ConnectionReset, msg))); + let query_error = QueryError::Disconnected(ConnectionError::Io(io::Error::new(io::ErrorKind::ConnectionReset, msg))); // error can't happen here, ErrorResponse serialization should be always ok pgb.write_message_noflush(&BeMessage::ErrorResponse(msg, Some(query_error.pg_error_code()))).map_err(|e| e.into_io_error())?; Err(io::Error::new(io::ErrorKind::ConnectionReset, msg))?; @@ -97,13 +97,13 @@ fn copyin_stream(pgb: &mut PostgresBackend) -> impl Stream { let msg = "client closed connection during COPY"; - let query_error = QueryError::Disconnected(ConnectionError::Socket(io::Error::new(io::ErrorKind::ConnectionReset, msg))); + let query_error = QueryError::Disconnected(ConnectionError::Io(io::Error::new(io::ErrorKind::ConnectionReset, msg))); // error can't happen here, ErrorResponse serialization should be always ok pgb.write_message_noflush(&BeMessage::ErrorResponse(msg, Some(query_error.pg_error_code()))).map_err(|e| e.into_io_error())?; pgb.flush().await?; Err(io::Error::new(io::ErrorKind::ConnectionReset, msg))?; } - Err(QueryError::Disconnected(ConnectionError::Socket(io_error))) => { + Err(QueryError::Disconnected(ConnectionError::Io(io_error))) => { Err(io_error)?; } Err(other) => { @@ -214,7 +214,7 @@ async fn page_service_conn_main( // we've been requested to shut down Ok(()) } - Err(QueryError::Disconnected(ConnectionError::Socket(io_error))) => { + Err(QueryError::Disconnected(ConnectionError::Io(io_error))) => { if is_expected_io_error(&io_error) { info!("Postgres client disconnected ({io_error})"); Ok(()) @@ -1057,7 +1057,7 @@ impl From for QueryError { fn from(e: GetActiveTenantError) -> Self { match e { GetActiveTenantError::WaitForActiveTimeout { .. } => QueryError::Disconnected( - ConnectionError::Socket(io::Error::new(io::ErrorKind::TimedOut, e.to_string())), + ConnectionError::Io(io::Error::new(io::ErrorKind::TimedOut, e.to_string())), ), GetActiveTenantError::Other(e) => QueryError::Other(e), } diff --git a/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs b/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs index 41ac61b7b6..7194a4f3ed 100644 --- a/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs +++ b/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs @@ -354,7 +354,7 @@ pub async fn handle_walreceiver_connection( debug!("neon_status_update {status_update:?}"); let mut data = BytesMut::new(); - status_update.serialize(&mut data)?; + status_update.serialize(&mut data); physical_stream .as_mut() .zenith_status_update(data.len() as u64, &data) diff --git a/proxy/src/stream.rs b/proxy/src/stream.rs index e0cf1326b9..5a802dafb2 100644 --- a/proxy/src/stream.rs +++ b/proxy/src/stream.rs @@ -1,45 +1,40 @@ use crate::error::UserFacingError; use anyhow::bail; -use bytes::BytesMut; use pin_project_lite::pin_project; -use pq_proto::{BeMessage, ConnectionError, FeMessage, FeStartupPacket}; +use pq_proto::framed::{ConnectionError, Framed}; +use pq_proto::{BeMessage, FeMessage, FeStartupPacket, ProtocolError}; use rustls::ServerConfig; use std::pin::Pin; use std::sync::Arc; use std::{io, task}; use thiserror::Error; -use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tokio_rustls::server::TlsStream; -pin_project! { - /// Stream wrapper which implements libpq's protocol. - /// NOTE: This object deliberately doesn't implement [`AsyncRead`] - /// or [`AsyncWrite`] to prevent subtle errors (e.g. trying - /// to pass random malformed bytes through the connection). - pub struct PqStream { - #[pin] - stream: S, - buffer: BytesMut, - } +/// Stream wrapper which implements libpq's protocol. +/// NOTE: This object deliberately doesn't implement [`AsyncRead`] +/// or [`AsyncWrite`] to prevent subtle errors (e.g. trying +/// to pass random malformed bytes through the connection). +pub struct PqStream { + framed: Framed, } impl PqStream { /// Construct a new libpq protocol wrapper. pub fn new(stream: S) -> Self { Self { - stream, - buffer: Default::default(), + framed: Framed::new(stream), } } /// Extract the underlying stream. pub fn into_inner(self) -> S { - self.stream + self.framed.into_inner() } /// Get a shared reference to the underlying stream. pub fn get_ref(&self) -> &S { - &self.stream + self.framed.get_ref() } } @@ -50,16 +45,19 @@ fn err_connection() -> io::Error { impl PqStream { /// Receive [`FeStartupPacket`], which is a first packet sent by a client. pub async fn read_startup_packet(&mut self) -> io::Result { - // TODO: `FeStartupPacket::read_fut` should return `FeStartupPacket` - let msg = FeStartupPacket::read(&mut self.stream) + self.framed + .read_startup_message() .await .map_err(ConnectionError::into_io_error)? - .ok_or_else(err_connection)?; + .ok_or_else(err_connection) + } - match msg { - FeMessage::StartupPacket(packet) => Ok(packet), - _ => panic!("unreachable state"), - } + async fn read_message(&mut self) -> io::Result { + self.framed + .read_message() + .await + .map_err(ConnectionError::into_io_error)? + .ok_or_else(err_connection) } pub async fn read_password_message(&mut self) -> io::Result { @@ -71,19 +69,14 @@ impl PqStream { )), } } - - async fn read_message(&mut self) -> io::Result { - FeMessage::read(&mut self.stream) - .await - .map_err(ConnectionError::into_io_error)? - .ok_or_else(err_connection) - } } impl PqStream { /// Write the message into an internal buffer, but don't flush the underlying stream. pub fn write_message_noflush(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> { - BeMessage::write(&mut self.buffer, message)?; + self.framed + .write_message(message) + .map_err(ProtocolError::into_io_error)?; Ok(self) } @@ -96,9 +89,7 @@ impl PqStream { /// Flush the output buffer into the underlying stream. pub async fn flush(&mut self) -> io::Result<&mut Self> { - self.stream.write_all(&self.buffer).await?; - self.buffer.clear(); - self.stream.flush().await?; + self.framed.flush().await?; Ok(self) } diff --git a/safekeeper/src/safekeeper.rs b/safekeeper/src/safekeeper.rs index 4a046cb048..d8fe36d7f8 100644 --- a/safekeeper/src/safekeeper.rs +++ b/safekeeper/src/safekeeper.rs @@ -488,7 +488,7 @@ impl AcceptorProposerMessage { buf.put_u64_le(msg.hs_feedback.xmin); buf.put_u64_le(msg.hs_feedback.catalog_xmin); - msg.pageserver_feedback.serialize(buf)?; + msg.pageserver_feedback.serialize(buf); } }