diff --git a/Cargo.lock b/Cargo.lock index 4ac9e9f3d9..9ea826f808 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2765,6 +2765,7 @@ dependencies = [ "jsonwebtoken", "lazy_static", "nix", + "pin-project-lite", "postgres 0.19.1 (git+https://github.com/zenithdb/rust-postgres.git?rev=2949d98df52587d562986aad155dd4e889e408b7)", "postgres-protocol 0.6.1 (git+https://github.com/zenithdb/rust-postgres.git?rev=2949d98df52587d562986aad155dd4e889e408b7)", "rand", diff --git a/zenith_utils/Cargo.toml b/zenith_utils/Cargo.toml index 34c4c03d97..055787a2ec 100644 --- a/zenith_utils/Cargo.toml +++ b/zenith_utils/Cargo.toml @@ -7,10 +7,10 @@ edition = "2021" [dependencies] anyhow = "1.0" bincode = "1.3" -byteorder = "1.4.3" bytes = "1.0.1" hyper = { version = "0.14.7", features = ["full"] } lazy_static = "1.4.0" +pin-project-lite = "0.2.7" postgres = { git = "https://github.com/zenithdb/rust-postgres.git", rev="2949d98df52587d562986aad155dd4e889e408b7" } postgres-protocol = { git = "https://github.com/zenithdb/rust-postgres.git", rev="2949d98df52587d562986aad155dd4e889e408b7" } routerify = "2" @@ -33,7 +33,8 @@ zenith_metrics = { path = "../zenith_metrics" } workspace_hack = { path = "../workspace_hack" } [dev-dependencies] +byteorder = "1.4.3" +bytes = "1.0.1" hex-literal = "0.3" -bytes = "1.0" -webpki = "0.21" tempfile = "3.2" +webpki = "0.21" diff --git a/zenith_utils/src/lib.rs b/zenith_utils/src/lib.rs index b0e5131a11..7d8ef63b1c 100644 --- a/zenith_utils/src/lib.rs +++ b/zenith_utils/src/lib.rs @@ -42,6 +42,9 @@ pub mod logging; pub mod accum; pub mod shutdown; +// Tools for calling certain async methods in sync contexts +pub mod sync; + // Utility for binding TcpListeners with proper socket options. pub mod tcp_listener; diff --git a/zenith_utils/src/pq_proto.rs b/zenith_utils/src/pq_proto.rs index dec4f460a6..694a990448 100644 --- a/zenith_utils/src/pq_proto.rs +++ b/zenith_utils/src/pq_proto.rs @@ -2,17 +2,17 @@ //! //! on message formats. +use crate::sync::{AsyncishRead, SyncFuture}; use anyhow::{bail, ensure, Context, Result}; -use byteorder::{BigEndian, ByteOrder}; -use byteorder::{ReadBytesExt, BE}; use bytes::{Buf, BufMut, Bytes, BytesMut}; use postgres_protocol::PG_EPOCH; use serde::{Deserialize, Serialize}; use std::collections::HashMap; -use std::io::Read; +use std::future::Future; use std::io::{self, Cursor}; use std::str; use std::time::{Duration, SystemTime}; +use tokio::io::AsyncReadExt; use tracing::info; pub type Oid = u32; @@ -96,14 +96,14 @@ impl FeMessage { /// One way to handle this properly: /// /// ``` - /// # use std::io::Read; + /// # use std::io; /// # use zenith_utils::pq_proto::FeMessage; /// # /// # fn process_message(msg: FeMessage) -> anyhow::Result<()> { /// # Ok(()) /// # }; /// # - /// fn do_the_job(stream: &mut impl Read) -> anyhow::Result<()> { + /// fn do_the_job(stream: &mut (impl io::Read + Unpin)) -> anyhow::Result<()> { /// while let Some(msg) = FeMessage::read(stream)? { /// process_message(msg)?; /// } @@ -111,124 +111,159 @@ impl FeMessage { /// Ok(()) /// } /// ``` - pub fn read(stream: &mut impl Read) -> anyhow::Result> { - // Each libpq message begins with a message type byte, followed by message length - // If the client closes the connection, return None. But if the client closes the - // connection in the middle of a message, we will return an error. - let tag = match stream.read_u8() { - Ok(b) => b, - Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None), - Err(e) => return Err(e.into()), - }; - let len = stream.read_u32::()?; + #[inline(never)] + pub fn read(stream: &mut (impl io::Read + Unpin)) -> anyhow::Result> { + Self::read_fut(&mut AsyncishRead(stream)).wait() + } - // The message length includes itself, so it better be at least 4 - let bodylen = len - .checked_sub(4) - .context("invalid message length: parsing u32")?; + /// Read one message from the stream. + /// See documentation for `Self::read`. + pub fn read_fut( + stream: &mut Reader, + ) -> SyncFuture>> + '_> + where + Reader: tokio::io::AsyncRead + Unpin, + { + // We return a Future that's sync (has a `wait` method) if and only if the provided stream is SyncProof. + // SyncFuture contract: we are only allowed to await on sync-proof futures, the AsyncRead and + // AsyncReadExt methods of the stream. + SyncFuture::new(async move { + // Each libpq message begins with a message type byte, followed by message length + // If the client closes the connection, return None. But if the client closes the + // connection in the middle of a message, we will return an error. + let tag = match stream.read_u8().await { + Ok(b) => b, + Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None), + Err(e) => return Err(e.into()), + }; + let len = stream.read_u32().await?; - // Read message body - let mut body_buf: Vec = vec![0; bodylen as usize]; - stream.read_exact(&mut body_buf)?; + // The message length includes itself, so it better be at least 4 + let bodylen = len + .checked_sub(4) + .context("invalid message length: parsing u32")?; - let body = Bytes::from(body_buf); + // Read message body + let mut body_buf: Vec = vec![0; bodylen as usize]; + stream.read_exact(&mut body_buf).await?; - // Parse it - match tag { - b'Q' => Ok(Some(FeMessage::Query(FeQueryMessage { body }))), - b'P' => Ok(Some(FeParseMessage::parse(body)?)), - b'D' => Ok(Some(FeDescribeMessage::parse(body)?)), - b'E' => Ok(Some(FeExecuteMessage::parse(body)?)), - b'B' => Ok(Some(FeBindMessage::parse(body)?)), - b'C' => Ok(Some(FeCloseMessage::parse(body)?)), - b'S' => Ok(Some(FeMessage::Sync)), - b'X' => Ok(Some(FeMessage::Terminate)), - b'd' => Ok(Some(FeMessage::CopyData(body))), - b'c' => Ok(Some(FeMessage::CopyDone)), - b'f' => Ok(Some(FeMessage::CopyFail)), - b'p' => Ok(Some(FeMessage::PasswordMessage(body))), - tag => bail!("unknown message tag: {},'{:?}'", tag, body), - } + let body = Bytes::from(body_buf); + + // Parse it + match tag { + b'Q' => Ok(Some(FeMessage::Query(FeQueryMessage { body }))), + b'P' => Ok(Some(FeParseMessage::parse(body)?)), + b'D' => Ok(Some(FeDescribeMessage::parse(body)?)), + b'E' => Ok(Some(FeExecuteMessage::parse(body)?)), + b'B' => Ok(Some(FeBindMessage::parse(body)?)), + b'C' => Ok(Some(FeCloseMessage::parse(body)?)), + b'S' => Ok(Some(FeMessage::Sync)), + b'X' => Ok(Some(FeMessage::Terminate)), + b'd' => Ok(Some(FeMessage::CopyData(body))), + b'c' => Ok(Some(FeMessage::CopyDone)), + b'f' => Ok(Some(FeMessage::CopyFail)), + b'p' => Ok(Some(FeMessage::PasswordMessage(body))), + tag => bail!("unknown message tag: {},'{:?}'", tag, body), + } + }) } } impl FeStartupPacket { /// Read startup message from the stream. - pub fn read(stream: &mut impl std::io::Read) -> anyhow::Result> { + // XXX: It's tempting yet undesirable to accept `stream` by value, + // since such a change will cause user-supplied &mut references to be consumed + pub fn read(stream: &mut (impl io::Read + Unpin)) -> anyhow::Result> { + Self::read_fut(&mut AsyncishRead(stream)).wait() + } + + /// Read startup message from the stream. + // XXX: It's tempting yet undesirable to accept `stream` by value, + // since such a change will cause user-supplied &mut references to be consumed + pub fn read_fut( + stream: &mut Reader, + ) -> SyncFuture>> + '_> + where + Reader: tokio::io::AsyncRead + Unpin, + { const MAX_STARTUP_PACKET_LENGTH: usize = 10000; const RESERVED_INVALID_MAJOR_VERSION: u32 = 1234; const CANCEL_REQUEST_CODE: u32 = 5678; const NEGOTIATE_SSL_CODE: u32 = 5679; const NEGOTIATE_GSS_CODE: u32 = 5680; - // Read length. If the connection is closed before reading anything (or before - // reading 4 bytes, to be precise), return None to indicate that the connection - // was closed. This matches the PostgreSQL server's behavior, which avoids noise - // in the log if the client opens connection but closes it immediately. - let len = match stream.read_u32::() { - Ok(len) => len as usize, - Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None), - Err(e) => return Err(e.into()), - }; + SyncFuture::new(async move { + // Read length. If the connection is closed before reading anything (or before + // reading 4 bytes, to be precise), return None to indicate that the connection + // was closed. This matches the PostgreSQL server's behavior, which avoids noise + // in the log if the client opens connection but closes it immediately. + let len = match stream.read_u32().await { + Ok(len) => len as usize, + Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None), + Err(e) => return Err(e.into()), + }; - if len < 4 || len > MAX_STARTUP_PACKET_LENGTH { - bail!("invalid message length"); - } - - let request_code = stream.read_u32::()?; - - // the rest of startup packet are params - let params_len = len - 8; - let mut params_bytes = vec![0u8; params_len]; - stream.read_exact(params_bytes.as_mut())?; - - // Parse params depending on request code - let most_sig_16_bits = request_code >> 16; - let least_sig_16_bits = request_code & ((1 << 16) - 1); - let message = match (most_sig_16_bits, least_sig_16_bits) { - (RESERVED_INVALID_MAJOR_VERSION, CANCEL_REQUEST_CODE) => { - ensure!(params_len == 8, "expected 8 bytes for CancelRequest params"); - let mut cursor = Cursor::new(params_bytes); - FeStartupPacket::CancelRequest(CancelKeyData { - backend_pid: cursor.read_i32::()?, - cancel_key: cursor.read_i32::()?, - }) + if len < 4 || len > MAX_STARTUP_PACKET_LENGTH { + bail!("invalid message length"); } - (RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_SSL_CODE) => FeStartupPacket::SslRequest, - (RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_GSS_CODE) => FeStartupPacket::GssEncRequest, - (RESERVED_INVALID_MAJOR_VERSION, unrecognized_code) => { - bail!("Unrecognized request code {}", unrecognized_code) - } - (major_version, minor_version) => { - // TODO bail if protocol major_version is not 3? - // Parse null-terminated (String) pairs of param name / param value - let params_str = str::from_utf8(¶ms_bytes).unwrap(); - let mut params_tokens = params_str.split('\0'); - let mut params: HashMap = HashMap::new(); - while let Some(name) = params_tokens.next() { - let value = params_tokens - .next() - .context("expected even number of params in StartupMessage")?; - if name == "options" { - // deprecated way of passing params as cmd line args - for cmdopt in value.split(' ') { - let nameval: Vec<&str> = cmdopt.split('=').collect(); - if nameval.len() == 2 { - params.insert(nameval[0].to_string(), nameval[1].to_string()); + + let request_code = stream.read_u32().await?; + + // the rest of startup packet are params + let params_len = len - 8; + let mut params_bytes = vec![0u8; params_len]; + stream.read_exact(params_bytes.as_mut()).await?; + + // Parse params depending on request code + let most_sig_16_bits = request_code >> 16; + let least_sig_16_bits = request_code & ((1 << 16) - 1); + let message = match (most_sig_16_bits, least_sig_16_bits) { + (RESERVED_INVALID_MAJOR_VERSION, CANCEL_REQUEST_CODE) => { + ensure!(params_len == 8, "expected 8 bytes for CancelRequest params"); + let mut cursor = Cursor::new(params_bytes); + FeStartupPacket::CancelRequest(CancelKeyData { + backend_pid: cursor.read_i32().await?, + cancel_key: cursor.read_i32().await?, + }) + } + (RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_SSL_CODE) => FeStartupPacket::SslRequest, + (RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_GSS_CODE) => { + FeStartupPacket::GssEncRequest + } + (RESERVED_INVALID_MAJOR_VERSION, unrecognized_code) => { + bail!("Unrecognized request code {}", unrecognized_code) + } + (major_version, minor_version) => { + // TODO bail if protocol major_version is not 3? + // Parse null-terminated (String) pairs of param name / param value + let params_str = str::from_utf8(¶ms_bytes).unwrap(); + let mut params_tokens = params_str.split('\0'); + let mut params: HashMap = HashMap::new(); + while let Some(name) = params_tokens.next() { + let value = params_tokens + .next() + .context("expected even number of params in StartupMessage")?; + if name == "options" { + // deprecated way of passing params as cmd line args + for cmdopt in value.split(' ') { + let nameval: Vec<&str> = cmdopt.split('=').collect(); + if nameval.len() == 2 { + params.insert(nameval[0].to_string(), nameval[1].to_string()); + } } + } else { + params.insert(name.to_string(), value.to_string()); } - } else { - params.insert(name.to_string(), value.to_string()); + } + FeStartupPacket::StartupMessage { + major_version, + minor_version, + params, } } - FeStartupPacket::StartupMessage { - major_version, - minor_version, - params, - } - } - }; - Ok(Some(FeMessage::StartupPacket(message))) + }; + Ok(Some(FeMessage::StartupPacket(message))) + }) } } @@ -502,7 +537,7 @@ where f(buf)?; let size = i32::from_usize(buf.len() - base)?; - BigEndian::write_i32(&mut buf[base..], size); + (&mut buf[base..]).put_slice(&size.to_be_bytes()); Ok(()) } @@ -977,4 +1012,13 @@ mod tests { let zf_parsed = ZenithFeedback::parse(data.freeze()); assert_eq!(zf, zf_parsed); } + + // Make sure that `read` is sync/async callable + async fn _assert(stream: &mut (impl tokio::io::AsyncRead + Unpin)) { + let _ = FeMessage::read(&mut [].as_ref()); + let _ = FeMessage::read_fut(stream).await; + + let _ = FeStartupPacket::read(&mut [].as_ref()); + let _ = FeStartupPacket::read_fut(stream).await; + } } diff --git a/zenith_utils/src/sync.rs b/zenith_utils/src/sync.rs new file mode 100644 index 0000000000..5e61480bc3 --- /dev/null +++ b/zenith_utils/src/sync.rs @@ -0,0 +1,179 @@ +use pin_project_lite::pin_project; +use std::future::Future; +use std::marker::PhantomData; +use std::pin::Pin; +use std::{io, task}; + +pin_project! { + /// We use this future to mark certain methods + /// as callable in both sync and async modes. + #[repr(transparent)] + pub struct SyncFuture { + #[pin] + inner: T, + _marker: PhantomData, + } +} + +/// This wrapper lets us synchronously wait for inner future's completion +/// (see [`SyncFuture::wait`]) **provided that `S` implements [`SyncProof`]**. +/// For instance, `S` may be substituted with types implementing +/// [`tokio::io::AsyncRead`], but it's not the only viable option. +impl SyncFuture { + /// NOTE: caller should carefully pick a type for `S`, + /// because we don't want to enable [`SyncFuture::wait`] when + /// it's in fact impossible to run the future synchronously. + /// Violation of this contract will not cause UB, but + /// panics and async event loop freezes won't please you. + /// + /// Example: + /// + /// ``` + /// # use zenith_utils::sync::SyncFuture; + /// # use std::future::Future; + /// # use tokio::io::AsyncReadExt; + /// # + /// // Parse a pair of numbers from a stream + /// pub fn parse_pair( + /// stream: &mut Reader, + /// ) -> SyncFuture> + '_> + /// where + /// Reader: tokio::io::AsyncRead + Unpin, + /// { + /// // If `Reader` is a `SyncProof`, this will give caller + /// // an opportunity to use `SyncFuture::wait`, because + /// // `.await` will always result in `Poll::Ready`. + /// SyncFuture::new(async move { + /// let x = stream.read_u32().await?; + /// let y = stream.read_u64().await?; + /// Ok((x, y)) + /// }) + /// } + /// ``` + pub fn new(inner: T) -> Self { + Self { + inner, + _marker: PhantomData, + } + } +} + +impl Future for SyncFuture { + type Output = T::Output; + + /// In async code, [`SyncFuture`] behaves like a regular wrapper. + #[inline(always)] + fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll { + self.project().inner.poll(cx) + } +} + +/// Postulates that we can call [`SyncFuture::wait`]. +/// If implementer is also a [`Future`], it should always +/// return [`task::Poll::Ready`] from [`Future::poll`]. +/// +/// Each implementation should document which futures +/// specifically are being declared sync-proof. +pub trait SyncPostulate {} + +impl SyncPostulate for &T {} +impl SyncPostulate for &mut T {} + +impl SyncFuture { + /// Synchronously wait for future completion. + pub fn wait(mut self) -> T::Output { + const RAW_WAKER: task::RawWaker = task::RawWaker::new( + std::ptr::null(), + &task::RawWakerVTable::new( + |_| RAW_WAKER, + |_| panic!("SyncFuture: failed to wake"), + |_| panic!("SyncFuture: failed to wake by ref"), + |_| { /* drop is no-op */ }, + ), + ); + + // SAFETY: We never move `self` during this call; + // furthermore, it will be dropped in the end regardless of panics + let this = unsafe { Pin::new_unchecked(&mut self) }; + + // SAFETY: This waker doesn't do anything apart from panicking + let waker = unsafe { task::Waker::from_raw(RAW_WAKER) }; + let context = &mut task::Context::from_waker(&waker); + + match this.poll(context) { + task::Poll::Ready(res) => res, + _ => panic!("SyncFuture: unexpected pending!"), + } + } +} + +/// This wrapper turns any [`std::io::Read`] into a blocking [`tokio::io::AsyncRead`], +/// which lets us abstract over sync & async readers in methods returning [`SyncFuture`]. +/// NOTE: you **should not** use this in async code. +#[repr(transparent)] +pub struct AsyncishRead(pub T); + +/// This lets us call [`SyncFuture, _>::wait`], +/// and allows the future to await on any of the [`AsyncRead`] +/// and [`AsyncReadExt`] methods on `AsyncishRead`. +impl SyncPostulate for AsyncishRead {} + +impl tokio::io::AsyncRead for AsyncishRead { + #[inline(always)] + fn poll_read( + mut self: Pin<&mut Self>, + _cx: &mut task::Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> task::Poll> { + task::Poll::Ready( + // `Read::read` will block, meaning we don't need a real event loop! + self.0 + .read(buf.initialize_unfilled()) + .map(|sz| buf.advance(sz)), + ) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + + // async helper(stream: &mut impl AsyncRead) -> io::Result + fn bytes_add( + stream: &mut Reader, + ) -> SyncFuture> + '_> + where + Reader: tokio::io::AsyncRead + Unpin, + { + SyncFuture::new(async move { + let a = stream.read_u32().await?; + let b = stream.read_u32().await?; + Ok(a + b) + }) + } + + #[test] + fn test_sync() { + let bytes = [100u32.to_be_bytes(), 200u32.to_be_bytes()].concat(); + let res = bytes_add(&mut AsyncishRead(&mut &bytes[..])) + .wait() + .unwrap(); + assert_eq!(res, 300); + } + + // We need a single-threaded executor for this test + #[tokio::test(flavor = "current_thread")] + async fn test_async() { + let (mut tx, mut rx) = tokio::net::UnixStream::pair().unwrap(); + + let write = async move { + tx.write_u32(100).await?; + tx.write_u32(200).await?; + Ok(()) + }; + + let (res, ()) = tokio::try_join!(bytes_add(&mut rx), write).unwrap(); + assert_eq!(res, 300); + } +}