Enable async deserialization of FeMessage

Now it's possible to call Fe{Startup,}Message in both
sync and async contexts, which is good for proxy.

Co-authored-by: bojanserafimov <bojan.serafimov7@gmail.com>
This commit is contained in:
Dmitry Ivanov
2021-10-28 19:53:26 +03:00
parent 33251a9d8f
commit c2927353a5
5 changed files with 334 additions and 106 deletions

1
Cargo.lock generated
View File

@@ -2765,6 +2765,7 @@ dependencies = [
"jsonwebtoken",
"lazy_static",
"nix",
"pin-project-lite",
"postgres 0.19.1 (git+https://github.com/zenithdb/rust-postgres.git?rev=2949d98df52587d562986aad155dd4e889e408b7)",
"postgres-protocol 0.6.1 (git+https://github.com/zenithdb/rust-postgres.git?rev=2949d98df52587d562986aad155dd4e889e408b7)",
"rand",

View File

@@ -7,10 +7,10 @@ edition = "2021"
[dependencies]
anyhow = "1.0"
bincode = "1.3"
byteorder = "1.4.3"
bytes = "1.0.1"
hyper = { version = "0.14.7", features = ["full"] }
lazy_static = "1.4.0"
pin-project-lite = "0.2.7"
postgres = { git = "https://github.com/zenithdb/rust-postgres.git", rev="2949d98df52587d562986aad155dd4e889e408b7" }
postgres-protocol = { git = "https://github.com/zenithdb/rust-postgres.git", rev="2949d98df52587d562986aad155dd4e889e408b7" }
routerify = "2"
@@ -33,7 +33,8 @@ zenith_metrics = { path = "../zenith_metrics" }
workspace_hack = { path = "../workspace_hack" }
[dev-dependencies]
byteorder = "1.4.3"
bytes = "1.0.1"
hex-literal = "0.3"
bytes = "1.0"
webpki = "0.21"
tempfile = "3.2"
webpki = "0.21"

View File

@@ -42,6 +42,9 @@ pub mod logging;
pub mod accum;
pub mod shutdown;
// Tools for calling certain async methods in sync contexts
pub mod sync;
// Utility for binding TcpListeners with proper socket options.
pub mod tcp_listener;

View File

@@ -2,17 +2,17 @@
//! <https://www.postgresql.org/docs/devel/protocol-message-formats.html>
//! on message formats.
use crate::sync::{AsyncishRead, SyncFuture};
use anyhow::{bail, ensure, Context, Result};
use byteorder::{BigEndian, ByteOrder};
use byteorder::{ReadBytesExt, BE};
use bytes::{Buf, BufMut, Bytes, BytesMut};
use postgres_protocol::PG_EPOCH;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::io::Read;
use std::future::Future;
use std::io::{self, Cursor};
use std::str;
use std::time::{Duration, SystemTime};
use tokio::io::AsyncReadExt;
use tracing::info;
pub type Oid = u32;
@@ -96,14 +96,14 @@ impl FeMessage {
/// One way to handle this properly:
///
/// ```
/// # use std::io::Read;
/// # use std::io;
/// # use zenith_utils::pq_proto::FeMessage;
/// #
/// # fn process_message(msg: FeMessage) -> anyhow::Result<()> {
/// # Ok(())
/// # };
/// #
/// fn do_the_job(stream: &mut impl Read) -> anyhow::Result<()> {
/// fn do_the_job(stream: &mut (impl io::Read + Unpin)) -> anyhow::Result<()> {
/// while let Some(msg) = FeMessage::read(stream)? {
/// process_message(msg)?;
/// }
@@ -111,124 +111,159 @@ impl FeMessage {
/// Ok(())
/// }
/// ```
pub fn read(stream: &mut impl Read) -> anyhow::Result<Option<FeMessage>> {
// Each libpq message begins with a message type byte, followed by message length
// If the client closes the connection, return None. But if the client closes the
// connection in the middle of a message, we will return an error.
let tag = match stream.read_u8() {
Ok(b) => b,
Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
Err(e) => return Err(e.into()),
};
let len = stream.read_u32::<BE>()?;
#[inline(never)]
pub fn read(stream: &mut (impl io::Read + Unpin)) -> anyhow::Result<Option<FeMessage>> {
Self::read_fut(&mut AsyncishRead(stream)).wait()
}
// The message length includes itself, so it better be at least 4
let bodylen = len
.checked_sub(4)
.context("invalid message length: parsing u32")?;
/// Read one message from the stream.
/// See documentation for `Self::read`.
pub fn read_fut<Reader>(
stream: &mut Reader,
) -> SyncFuture<Reader, impl Future<Output = anyhow::Result<Option<FeMessage>>> + '_>
where
Reader: tokio::io::AsyncRead + Unpin,
{
// We return a Future that's sync (has a `wait` method) if and only if the provided stream is SyncProof.
// SyncFuture contract: we are only allowed to await on sync-proof futures, the AsyncRead and
// AsyncReadExt methods of the stream.
SyncFuture::new(async move {
// Each libpq message begins with a message type byte, followed by message length
// If the client closes the connection, return None. But if the client closes the
// connection in the middle of a message, we will return an error.
let tag = match stream.read_u8().await {
Ok(b) => b,
Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
Err(e) => return Err(e.into()),
};
let len = stream.read_u32().await?;
// Read message body
let mut body_buf: Vec<u8> = vec![0; bodylen as usize];
stream.read_exact(&mut body_buf)?;
// The message length includes itself, so it better be at least 4
let bodylen = len
.checked_sub(4)
.context("invalid message length: parsing u32")?;
let body = Bytes::from(body_buf);
// Read message body
let mut body_buf: Vec<u8> = vec![0; bodylen as usize];
stream.read_exact(&mut body_buf).await?;
// Parse it
match tag {
b'Q' => Ok(Some(FeMessage::Query(FeQueryMessage { body }))),
b'P' => Ok(Some(FeParseMessage::parse(body)?)),
b'D' => Ok(Some(FeDescribeMessage::parse(body)?)),
b'E' => Ok(Some(FeExecuteMessage::parse(body)?)),
b'B' => Ok(Some(FeBindMessage::parse(body)?)),
b'C' => Ok(Some(FeCloseMessage::parse(body)?)),
b'S' => Ok(Some(FeMessage::Sync)),
b'X' => Ok(Some(FeMessage::Terminate)),
b'd' => Ok(Some(FeMessage::CopyData(body))),
b'c' => Ok(Some(FeMessage::CopyDone)),
b'f' => Ok(Some(FeMessage::CopyFail)),
b'p' => Ok(Some(FeMessage::PasswordMessage(body))),
tag => bail!("unknown message tag: {},'{:?}'", tag, body),
}
let body = Bytes::from(body_buf);
// Parse it
match tag {
b'Q' => Ok(Some(FeMessage::Query(FeQueryMessage { body }))),
b'P' => Ok(Some(FeParseMessage::parse(body)?)),
b'D' => Ok(Some(FeDescribeMessage::parse(body)?)),
b'E' => Ok(Some(FeExecuteMessage::parse(body)?)),
b'B' => Ok(Some(FeBindMessage::parse(body)?)),
b'C' => Ok(Some(FeCloseMessage::parse(body)?)),
b'S' => Ok(Some(FeMessage::Sync)),
b'X' => Ok(Some(FeMessage::Terminate)),
b'd' => Ok(Some(FeMessage::CopyData(body))),
b'c' => Ok(Some(FeMessage::CopyDone)),
b'f' => Ok(Some(FeMessage::CopyFail)),
b'p' => Ok(Some(FeMessage::PasswordMessage(body))),
tag => bail!("unknown message tag: {},'{:?}'", tag, body),
}
})
}
}
impl FeStartupPacket {
/// Read startup message from the stream.
pub fn read(stream: &mut impl std::io::Read) -> anyhow::Result<Option<FeMessage>> {
// XXX: It's tempting yet undesirable to accept `stream` by value,
// since such a change will cause user-supplied &mut references to be consumed
pub fn read(stream: &mut (impl io::Read + Unpin)) -> anyhow::Result<Option<FeMessage>> {
Self::read_fut(&mut AsyncishRead(stream)).wait()
}
/// Read startup message from the stream.
// XXX: It's tempting yet undesirable to accept `stream` by value,
// since such a change will cause user-supplied &mut references to be consumed
pub fn read_fut<Reader>(
stream: &mut Reader,
) -> SyncFuture<Reader, impl Future<Output = anyhow::Result<Option<FeMessage>>> + '_>
where
Reader: tokio::io::AsyncRead + Unpin,
{
const MAX_STARTUP_PACKET_LENGTH: usize = 10000;
const RESERVED_INVALID_MAJOR_VERSION: u32 = 1234;
const CANCEL_REQUEST_CODE: u32 = 5678;
const NEGOTIATE_SSL_CODE: u32 = 5679;
const NEGOTIATE_GSS_CODE: u32 = 5680;
// Read length. If the connection is closed before reading anything (or before
// reading 4 bytes, to be precise), return None to indicate that the connection
// was closed. This matches the PostgreSQL server's behavior, which avoids noise
// in the log if the client opens connection but closes it immediately.
let len = match stream.read_u32::<BE>() {
Ok(len) => len as usize,
Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
Err(e) => return Err(e.into()),
};
SyncFuture::new(async move {
// Read length. If the connection is closed before reading anything (or before
// reading 4 bytes, to be precise), return None to indicate that the connection
// was closed. This matches the PostgreSQL server's behavior, which avoids noise
// in the log if the client opens connection but closes it immediately.
let len = match stream.read_u32().await {
Ok(len) => len as usize,
Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
Err(e) => return Err(e.into()),
};
if len < 4 || len > MAX_STARTUP_PACKET_LENGTH {
bail!("invalid message length");
}
let request_code = stream.read_u32::<BE>()?;
// the rest of startup packet are params
let params_len = len - 8;
let mut params_bytes = vec![0u8; params_len];
stream.read_exact(params_bytes.as_mut())?;
// Parse params depending on request code
let most_sig_16_bits = request_code >> 16;
let least_sig_16_bits = request_code & ((1 << 16) - 1);
let message = match (most_sig_16_bits, least_sig_16_bits) {
(RESERVED_INVALID_MAJOR_VERSION, CANCEL_REQUEST_CODE) => {
ensure!(params_len == 8, "expected 8 bytes for CancelRequest params");
let mut cursor = Cursor::new(params_bytes);
FeStartupPacket::CancelRequest(CancelKeyData {
backend_pid: cursor.read_i32::<BigEndian>()?,
cancel_key: cursor.read_i32::<BigEndian>()?,
})
if len < 4 || len > MAX_STARTUP_PACKET_LENGTH {
bail!("invalid message length");
}
(RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_SSL_CODE) => FeStartupPacket::SslRequest,
(RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_GSS_CODE) => FeStartupPacket::GssEncRequest,
(RESERVED_INVALID_MAJOR_VERSION, unrecognized_code) => {
bail!("Unrecognized request code {}", unrecognized_code)
}
(major_version, minor_version) => {
// TODO bail if protocol major_version is not 3?
// Parse null-terminated (String) pairs of param name / param value
let params_str = str::from_utf8(&params_bytes).unwrap();
let mut params_tokens = params_str.split('\0');
let mut params: HashMap<String, String> = HashMap::new();
while let Some(name) = params_tokens.next() {
let value = params_tokens
.next()
.context("expected even number of params in StartupMessage")?;
if name == "options" {
// deprecated way of passing params as cmd line args
for cmdopt in value.split(' ') {
let nameval: Vec<&str> = cmdopt.split('=').collect();
if nameval.len() == 2 {
params.insert(nameval[0].to_string(), nameval[1].to_string());
let request_code = stream.read_u32().await?;
// the rest of startup packet are params
let params_len = len - 8;
let mut params_bytes = vec![0u8; params_len];
stream.read_exact(params_bytes.as_mut()).await?;
// Parse params depending on request code
let most_sig_16_bits = request_code >> 16;
let least_sig_16_bits = request_code & ((1 << 16) - 1);
let message = match (most_sig_16_bits, least_sig_16_bits) {
(RESERVED_INVALID_MAJOR_VERSION, CANCEL_REQUEST_CODE) => {
ensure!(params_len == 8, "expected 8 bytes for CancelRequest params");
let mut cursor = Cursor::new(params_bytes);
FeStartupPacket::CancelRequest(CancelKeyData {
backend_pid: cursor.read_i32().await?,
cancel_key: cursor.read_i32().await?,
})
}
(RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_SSL_CODE) => FeStartupPacket::SslRequest,
(RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_GSS_CODE) => {
FeStartupPacket::GssEncRequest
}
(RESERVED_INVALID_MAJOR_VERSION, unrecognized_code) => {
bail!("Unrecognized request code {}", unrecognized_code)
}
(major_version, minor_version) => {
// TODO bail if protocol major_version is not 3?
// Parse null-terminated (String) pairs of param name / param value
let params_str = str::from_utf8(&params_bytes).unwrap();
let mut params_tokens = params_str.split('\0');
let mut params: HashMap<String, String> = HashMap::new();
while let Some(name) = params_tokens.next() {
let value = params_tokens
.next()
.context("expected even number of params in StartupMessage")?;
if name == "options" {
// deprecated way of passing params as cmd line args
for cmdopt in value.split(' ') {
let nameval: Vec<&str> = cmdopt.split('=').collect();
if nameval.len() == 2 {
params.insert(nameval[0].to_string(), nameval[1].to_string());
}
}
} else {
params.insert(name.to_string(), value.to_string());
}
} else {
params.insert(name.to_string(), value.to_string());
}
FeStartupPacket::StartupMessage {
major_version,
minor_version,
params,
}
}
FeStartupPacket::StartupMessage {
major_version,
minor_version,
params,
}
}
};
Ok(Some(FeMessage::StartupPacket(message)))
};
Ok(Some(FeMessage::StartupPacket(message)))
})
}
}
@@ -502,7 +537,7 @@ where
f(buf)?;
let size = i32::from_usize(buf.len() - base)?;
BigEndian::write_i32(&mut buf[base..], size);
(&mut buf[base..]).put_slice(&size.to_be_bytes());
Ok(())
}
@@ -977,4 +1012,13 @@ mod tests {
let zf_parsed = ZenithFeedback::parse(data.freeze());
assert_eq!(zf, zf_parsed);
}
// Make sure that `read` is sync/async callable
async fn _assert(stream: &mut (impl tokio::io::AsyncRead + Unpin)) {
let _ = FeMessage::read(&mut [].as_ref());
let _ = FeMessage::read_fut(stream).await;
let _ = FeStartupPacket::read(&mut [].as_ref());
let _ = FeStartupPacket::read_fut(stream).await;
}
}

179
zenith_utils/src/sync.rs Normal file
View File

@@ -0,0 +1,179 @@
use pin_project_lite::pin_project;
use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use std::{io, task};
pin_project! {
/// We use this future to mark certain methods
/// as callable in both sync and async modes.
#[repr(transparent)]
pub struct SyncFuture<S, T: Future> {
#[pin]
inner: T,
_marker: PhantomData<S>,
}
}
/// This wrapper lets us synchronously wait for inner future's completion
/// (see [`SyncFuture::wait`]) **provided that `S` implements [`SyncProof`]**.
/// For instance, `S` may be substituted with types implementing
/// [`tokio::io::AsyncRead`], but it's not the only viable option.
impl<S, T: Future> SyncFuture<S, T> {
/// NOTE: caller should carefully pick a type for `S`,
/// because we don't want to enable [`SyncFuture::wait`] when
/// it's in fact impossible to run the future synchronously.
/// Violation of this contract will not cause UB, but
/// panics and async event loop freezes won't please you.
///
/// Example:
///
/// ```
/// # use zenith_utils::sync::SyncFuture;
/// # use std::future::Future;
/// # use tokio::io::AsyncReadExt;
/// #
/// // Parse a pair of numbers from a stream
/// pub fn parse_pair<Reader>(
/// stream: &mut Reader,
/// ) -> SyncFuture<Reader, impl Future<Output = anyhow::Result<(u32, u64)>> + '_>
/// where
/// Reader: tokio::io::AsyncRead + Unpin,
/// {
/// // If `Reader` is a `SyncProof`, this will give caller
/// // an opportunity to use `SyncFuture::wait`, because
/// // `.await` will always result in `Poll::Ready`.
/// SyncFuture::new(async move {
/// let x = stream.read_u32().await?;
/// let y = stream.read_u64().await?;
/// Ok((x, y))
/// })
/// }
/// ```
pub fn new(inner: T) -> Self {
Self {
inner,
_marker: PhantomData,
}
}
}
impl<S, T: Future> Future for SyncFuture<S, T> {
type Output = T::Output;
/// In async code, [`SyncFuture`] behaves like a regular wrapper.
#[inline(always)]
fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<Self::Output> {
self.project().inner.poll(cx)
}
}
/// Postulates that we can call [`SyncFuture::wait`].
/// If implementer is also a [`Future`], it should always
/// return [`task::Poll::Ready`] from [`Future::poll`].
///
/// Each implementation should document which futures
/// specifically are being declared sync-proof.
pub trait SyncPostulate {}
impl<T: SyncPostulate> SyncPostulate for &T {}
impl<T: SyncPostulate> SyncPostulate for &mut T {}
impl<P: SyncPostulate, T: Future> SyncFuture<P, T> {
/// Synchronously wait for future completion.
pub fn wait(mut self) -> T::Output {
const RAW_WAKER: task::RawWaker = task::RawWaker::new(
std::ptr::null(),
&task::RawWakerVTable::new(
|_| RAW_WAKER,
|_| panic!("SyncFuture: failed to wake"),
|_| panic!("SyncFuture: failed to wake by ref"),
|_| { /* drop is no-op */ },
),
);
// SAFETY: We never move `self` during this call;
// furthermore, it will be dropped in the end regardless of panics
let this = unsafe { Pin::new_unchecked(&mut self) };
// SAFETY: This waker doesn't do anything apart from panicking
let waker = unsafe { task::Waker::from_raw(RAW_WAKER) };
let context = &mut task::Context::from_waker(&waker);
match this.poll(context) {
task::Poll::Ready(res) => res,
_ => panic!("SyncFuture: unexpected pending!"),
}
}
}
/// This wrapper turns any [`std::io::Read`] into a blocking [`tokio::io::AsyncRead`],
/// which lets us abstract over sync & async readers in methods returning [`SyncFuture`].
/// NOTE: you **should not** use this in async code.
#[repr(transparent)]
pub struct AsyncishRead<T: io::Read + Unpin>(pub T);
/// This lets us call [`SyncFuture<AsyncishRead<_>, _>::wait`],
/// and allows the future to await on any of the [`AsyncRead`]
/// and [`AsyncReadExt`] methods on `AsyncishRead`.
impl<T: io::Read + Unpin> SyncPostulate for AsyncishRead<T> {}
impl<T: io::Read + Unpin> tokio::io::AsyncRead for AsyncishRead<T> {
#[inline(always)]
fn poll_read(
mut self: Pin<&mut Self>,
_cx: &mut task::Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> task::Poll<io::Result<()>> {
task::Poll::Ready(
// `Read::read` will block, meaning we don't need a real event loop!
self.0
.read(buf.initialize_unfilled())
.map(|sz| buf.advance(sz)),
)
}
}
#[cfg(test)]
mod tests {
use super::*;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
// async helper(stream: &mut impl AsyncRead) -> io::Result<u32>
fn bytes_add<Reader>(
stream: &mut Reader,
) -> SyncFuture<Reader, impl Future<Output = io::Result<u32>> + '_>
where
Reader: tokio::io::AsyncRead + Unpin,
{
SyncFuture::new(async move {
let a = stream.read_u32().await?;
let b = stream.read_u32().await?;
Ok(a + b)
})
}
#[test]
fn test_sync() {
let bytes = [100u32.to_be_bytes(), 200u32.to_be_bytes()].concat();
let res = bytes_add(&mut AsyncishRead(&mut &bytes[..]))
.wait()
.unwrap();
assert_eq!(res, 300);
}
// We need a single-threaded executor for this test
#[tokio::test(flavor = "current_thread")]
async fn test_async() {
let (mut tx, mut rx) = tokio::net::UnixStream::pair().unwrap();
let write = async move {
tx.write_u32(100).await?;
tx.write_u32(200).await?;
Ok(())
};
let (res, ()) = tokio::try_join!(bytes_add(&mut rx), write).unwrap();
assert_eq!(res, 300);
}
}