mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-16 01:42:55 +00:00
Enable async deserialization of FeMessage
Now it's possible to call Fe{Startup,}Message in both
sync and async contexts, which is good for proxy.
Co-authored-by: bojanserafimov <bojan.serafimov7@gmail.com>
This commit is contained in:
1
Cargo.lock
generated
1
Cargo.lock
generated
@@ -2765,6 +2765,7 @@ dependencies = [
|
||||
"jsonwebtoken",
|
||||
"lazy_static",
|
||||
"nix",
|
||||
"pin-project-lite",
|
||||
"postgres 0.19.1 (git+https://github.com/zenithdb/rust-postgres.git?rev=2949d98df52587d562986aad155dd4e889e408b7)",
|
||||
"postgres-protocol 0.6.1 (git+https://github.com/zenithdb/rust-postgres.git?rev=2949d98df52587d562986aad155dd4e889e408b7)",
|
||||
"rand",
|
||||
|
||||
@@ -7,10 +7,10 @@ edition = "2021"
|
||||
[dependencies]
|
||||
anyhow = "1.0"
|
||||
bincode = "1.3"
|
||||
byteorder = "1.4.3"
|
||||
bytes = "1.0.1"
|
||||
hyper = { version = "0.14.7", features = ["full"] }
|
||||
lazy_static = "1.4.0"
|
||||
pin-project-lite = "0.2.7"
|
||||
postgres = { git = "https://github.com/zenithdb/rust-postgres.git", rev="2949d98df52587d562986aad155dd4e889e408b7" }
|
||||
postgres-protocol = { git = "https://github.com/zenithdb/rust-postgres.git", rev="2949d98df52587d562986aad155dd4e889e408b7" }
|
||||
routerify = "2"
|
||||
@@ -33,7 +33,8 @@ zenith_metrics = { path = "../zenith_metrics" }
|
||||
workspace_hack = { path = "../workspace_hack" }
|
||||
|
||||
[dev-dependencies]
|
||||
byteorder = "1.4.3"
|
||||
bytes = "1.0.1"
|
||||
hex-literal = "0.3"
|
||||
bytes = "1.0"
|
||||
webpki = "0.21"
|
||||
tempfile = "3.2"
|
||||
webpki = "0.21"
|
||||
|
||||
@@ -42,6 +42,9 @@ pub mod logging;
|
||||
pub mod accum;
|
||||
pub mod shutdown;
|
||||
|
||||
// Tools for calling certain async methods in sync contexts
|
||||
pub mod sync;
|
||||
|
||||
// Utility for binding TcpListeners with proper socket options.
|
||||
pub mod tcp_listener;
|
||||
|
||||
|
||||
@@ -2,17 +2,17 @@
|
||||
//! <https://www.postgresql.org/docs/devel/protocol-message-formats.html>
|
||||
//! on message formats.
|
||||
|
||||
use crate::sync::{AsyncishRead, SyncFuture};
|
||||
use anyhow::{bail, ensure, Context, Result};
|
||||
use byteorder::{BigEndian, ByteOrder};
|
||||
use byteorder::{ReadBytesExt, BE};
|
||||
use bytes::{Buf, BufMut, Bytes, BytesMut};
|
||||
use postgres_protocol::PG_EPOCH;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::io::Read;
|
||||
use std::future::Future;
|
||||
use std::io::{self, Cursor};
|
||||
use std::str;
|
||||
use std::time::{Duration, SystemTime};
|
||||
use tokio::io::AsyncReadExt;
|
||||
use tracing::info;
|
||||
|
||||
pub type Oid = u32;
|
||||
@@ -96,14 +96,14 @@ impl FeMessage {
|
||||
/// One way to handle this properly:
|
||||
///
|
||||
/// ```
|
||||
/// # use std::io::Read;
|
||||
/// # use std::io;
|
||||
/// # use zenith_utils::pq_proto::FeMessage;
|
||||
/// #
|
||||
/// # fn process_message(msg: FeMessage) -> anyhow::Result<()> {
|
||||
/// # Ok(())
|
||||
/// # };
|
||||
/// #
|
||||
/// fn do_the_job(stream: &mut impl Read) -> anyhow::Result<()> {
|
||||
/// fn do_the_job(stream: &mut (impl io::Read + Unpin)) -> anyhow::Result<()> {
|
||||
/// while let Some(msg) = FeMessage::read(stream)? {
|
||||
/// process_message(msg)?;
|
||||
/// }
|
||||
@@ -111,124 +111,159 @@ impl FeMessage {
|
||||
/// Ok(())
|
||||
/// }
|
||||
/// ```
|
||||
pub fn read(stream: &mut impl Read) -> anyhow::Result<Option<FeMessage>> {
|
||||
// Each libpq message begins with a message type byte, followed by message length
|
||||
// If the client closes the connection, return None. But if the client closes the
|
||||
// connection in the middle of a message, we will return an error.
|
||||
let tag = match stream.read_u8() {
|
||||
Ok(b) => b,
|
||||
Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
|
||||
Err(e) => return Err(e.into()),
|
||||
};
|
||||
let len = stream.read_u32::<BE>()?;
|
||||
#[inline(never)]
|
||||
pub fn read(stream: &mut (impl io::Read + Unpin)) -> anyhow::Result<Option<FeMessage>> {
|
||||
Self::read_fut(&mut AsyncishRead(stream)).wait()
|
||||
}
|
||||
|
||||
// The message length includes itself, so it better be at least 4
|
||||
let bodylen = len
|
||||
.checked_sub(4)
|
||||
.context("invalid message length: parsing u32")?;
|
||||
/// Read one message from the stream.
|
||||
/// See documentation for `Self::read`.
|
||||
pub fn read_fut<Reader>(
|
||||
stream: &mut Reader,
|
||||
) -> SyncFuture<Reader, impl Future<Output = anyhow::Result<Option<FeMessage>>> + '_>
|
||||
where
|
||||
Reader: tokio::io::AsyncRead + Unpin,
|
||||
{
|
||||
// We return a Future that's sync (has a `wait` method) if and only if the provided stream is SyncProof.
|
||||
// SyncFuture contract: we are only allowed to await on sync-proof futures, the AsyncRead and
|
||||
// AsyncReadExt methods of the stream.
|
||||
SyncFuture::new(async move {
|
||||
// Each libpq message begins with a message type byte, followed by message length
|
||||
// If the client closes the connection, return None. But if the client closes the
|
||||
// connection in the middle of a message, we will return an error.
|
||||
let tag = match stream.read_u8().await {
|
||||
Ok(b) => b,
|
||||
Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
|
||||
Err(e) => return Err(e.into()),
|
||||
};
|
||||
let len = stream.read_u32().await?;
|
||||
|
||||
// Read message body
|
||||
let mut body_buf: Vec<u8> = vec![0; bodylen as usize];
|
||||
stream.read_exact(&mut body_buf)?;
|
||||
// The message length includes itself, so it better be at least 4
|
||||
let bodylen = len
|
||||
.checked_sub(4)
|
||||
.context("invalid message length: parsing u32")?;
|
||||
|
||||
let body = Bytes::from(body_buf);
|
||||
// Read message body
|
||||
let mut body_buf: Vec<u8> = vec![0; bodylen as usize];
|
||||
stream.read_exact(&mut body_buf).await?;
|
||||
|
||||
// Parse it
|
||||
match tag {
|
||||
b'Q' => Ok(Some(FeMessage::Query(FeQueryMessage { body }))),
|
||||
b'P' => Ok(Some(FeParseMessage::parse(body)?)),
|
||||
b'D' => Ok(Some(FeDescribeMessage::parse(body)?)),
|
||||
b'E' => Ok(Some(FeExecuteMessage::parse(body)?)),
|
||||
b'B' => Ok(Some(FeBindMessage::parse(body)?)),
|
||||
b'C' => Ok(Some(FeCloseMessage::parse(body)?)),
|
||||
b'S' => Ok(Some(FeMessage::Sync)),
|
||||
b'X' => Ok(Some(FeMessage::Terminate)),
|
||||
b'd' => Ok(Some(FeMessage::CopyData(body))),
|
||||
b'c' => Ok(Some(FeMessage::CopyDone)),
|
||||
b'f' => Ok(Some(FeMessage::CopyFail)),
|
||||
b'p' => Ok(Some(FeMessage::PasswordMessage(body))),
|
||||
tag => bail!("unknown message tag: {},'{:?}'", tag, body),
|
||||
}
|
||||
let body = Bytes::from(body_buf);
|
||||
|
||||
// Parse it
|
||||
match tag {
|
||||
b'Q' => Ok(Some(FeMessage::Query(FeQueryMessage { body }))),
|
||||
b'P' => Ok(Some(FeParseMessage::parse(body)?)),
|
||||
b'D' => Ok(Some(FeDescribeMessage::parse(body)?)),
|
||||
b'E' => Ok(Some(FeExecuteMessage::parse(body)?)),
|
||||
b'B' => Ok(Some(FeBindMessage::parse(body)?)),
|
||||
b'C' => Ok(Some(FeCloseMessage::parse(body)?)),
|
||||
b'S' => Ok(Some(FeMessage::Sync)),
|
||||
b'X' => Ok(Some(FeMessage::Terminate)),
|
||||
b'd' => Ok(Some(FeMessage::CopyData(body))),
|
||||
b'c' => Ok(Some(FeMessage::CopyDone)),
|
||||
b'f' => Ok(Some(FeMessage::CopyFail)),
|
||||
b'p' => Ok(Some(FeMessage::PasswordMessage(body))),
|
||||
tag => bail!("unknown message tag: {},'{:?}'", tag, body),
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl FeStartupPacket {
|
||||
/// Read startup message from the stream.
|
||||
pub fn read(stream: &mut impl std::io::Read) -> anyhow::Result<Option<FeMessage>> {
|
||||
// XXX: It's tempting yet undesirable to accept `stream` by value,
|
||||
// since such a change will cause user-supplied &mut references to be consumed
|
||||
pub fn read(stream: &mut (impl io::Read + Unpin)) -> anyhow::Result<Option<FeMessage>> {
|
||||
Self::read_fut(&mut AsyncishRead(stream)).wait()
|
||||
}
|
||||
|
||||
/// Read startup message from the stream.
|
||||
// XXX: It's tempting yet undesirable to accept `stream` by value,
|
||||
// since such a change will cause user-supplied &mut references to be consumed
|
||||
pub fn read_fut<Reader>(
|
||||
stream: &mut Reader,
|
||||
) -> SyncFuture<Reader, impl Future<Output = anyhow::Result<Option<FeMessage>>> + '_>
|
||||
where
|
||||
Reader: tokio::io::AsyncRead + Unpin,
|
||||
{
|
||||
const MAX_STARTUP_PACKET_LENGTH: usize = 10000;
|
||||
const RESERVED_INVALID_MAJOR_VERSION: u32 = 1234;
|
||||
const CANCEL_REQUEST_CODE: u32 = 5678;
|
||||
const NEGOTIATE_SSL_CODE: u32 = 5679;
|
||||
const NEGOTIATE_GSS_CODE: u32 = 5680;
|
||||
|
||||
// Read length. If the connection is closed before reading anything (or before
|
||||
// reading 4 bytes, to be precise), return None to indicate that the connection
|
||||
// was closed. This matches the PostgreSQL server's behavior, which avoids noise
|
||||
// in the log if the client opens connection but closes it immediately.
|
||||
let len = match stream.read_u32::<BE>() {
|
||||
Ok(len) => len as usize,
|
||||
Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
|
||||
Err(e) => return Err(e.into()),
|
||||
};
|
||||
SyncFuture::new(async move {
|
||||
// Read length. If the connection is closed before reading anything (or before
|
||||
// reading 4 bytes, to be precise), return None to indicate that the connection
|
||||
// was closed. This matches the PostgreSQL server's behavior, which avoids noise
|
||||
// in the log if the client opens connection but closes it immediately.
|
||||
let len = match stream.read_u32().await {
|
||||
Ok(len) => len as usize,
|
||||
Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
|
||||
Err(e) => return Err(e.into()),
|
||||
};
|
||||
|
||||
if len < 4 || len > MAX_STARTUP_PACKET_LENGTH {
|
||||
bail!("invalid message length");
|
||||
}
|
||||
|
||||
let request_code = stream.read_u32::<BE>()?;
|
||||
|
||||
// the rest of startup packet are params
|
||||
let params_len = len - 8;
|
||||
let mut params_bytes = vec![0u8; params_len];
|
||||
stream.read_exact(params_bytes.as_mut())?;
|
||||
|
||||
// Parse params depending on request code
|
||||
let most_sig_16_bits = request_code >> 16;
|
||||
let least_sig_16_bits = request_code & ((1 << 16) - 1);
|
||||
let message = match (most_sig_16_bits, least_sig_16_bits) {
|
||||
(RESERVED_INVALID_MAJOR_VERSION, CANCEL_REQUEST_CODE) => {
|
||||
ensure!(params_len == 8, "expected 8 bytes for CancelRequest params");
|
||||
let mut cursor = Cursor::new(params_bytes);
|
||||
FeStartupPacket::CancelRequest(CancelKeyData {
|
||||
backend_pid: cursor.read_i32::<BigEndian>()?,
|
||||
cancel_key: cursor.read_i32::<BigEndian>()?,
|
||||
})
|
||||
if len < 4 || len > MAX_STARTUP_PACKET_LENGTH {
|
||||
bail!("invalid message length");
|
||||
}
|
||||
(RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_SSL_CODE) => FeStartupPacket::SslRequest,
|
||||
(RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_GSS_CODE) => FeStartupPacket::GssEncRequest,
|
||||
(RESERVED_INVALID_MAJOR_VERSION, unrecognized_code) => {
|
||||
bail!("Unrecognized request code {}", unrecognized_code)
|
||||
}
|
||||
(major_version, minor_version) => {
|
||||
// TODO bail if protocol major_version is not 3?
|
||||
// Parse null-terminated (String) pairs of param name / param value
|
||||
let params_str = str::from_utf8(¶ms_bytes).unwrap();
|
||||
let mut params_tokens = params_str.split('\0');
|
||||
let mut params: HashMap<String, String> = HashMap::new();
|
||||
while let Some(name) = params_tokens.next() {
|
||||
let value = params_tokens
|
||||
.next()
|
||||
.context("expected even number of params in StartupMessage")?;
|
||||
if name == "options" {
|
||||
// deprecated way of passing params as cmd line args
|
||||
for cmdopt in value.split(' ') {
|
||||
let nameval: Vec<&str> = cmdopt.split('=').collect();
|
||||
if nameval.len() == 2 {
|
||||
params.insert(nameval[0].to_string(), nameval[1].to_string());
|
||||
|
||||
let request_code = stream.read_u32().await?;
|
||||
|
||||
// the rest of startup packet are params
|
||||
let params_len = len - 8;
|
||||
let mut params_bytes = vec![0u8; params_len];
|
||||
stream.read_exact(params_bytes.as_mut()).await?;
|
||||
|
||||
// Parse params depending on request code
|
||||
let most_sig_16_bits = request_code >> 16;
|
||||
let least_sig_16_bits = request_code & ((1 << 16) - 1);
|
||||
let message = match (most_sig_16_bits, least_sig_16_bits) {
|
||||
(RESERVED_INVALID_MAJOR_VERSION, CANCEL_REQUEST_CODE) => {
|
||||
ensure!(params_len == 8, "expected 8 bytes for CancelRequest params");
|
||||
let mut cursor = Cursor::new(params_bytes);
|
||||
FeStartupPacket::CancelRequest(CancelKeyData {
|
||||
backend_pid: cursor.read_i32().await?,
|
||||
cancel_key: cursor.read_i32().await?,
|
||||
})
|
||||
}
|
||||
(RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_SSL_CODE) => FeStartupPacket::SslRequest,
|
||||
(RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_GSS_CODE) => {
|
||||
FeStartupPacket::GssEncRequest
|
||||
}
|
||||
(RESERVED_INVALID_MAJOR_VERSION, unrecognized_code) => {
|
||||
bail!("Unrecognized request code {}", unrecognized_code)
|
||||
}
|
||||
(major_version, minor_version) => {
|
||||
// TODO bail if protocol major_version is not 3?
|
||||
// Parse null-terminated (String) pairs of param name / param value
|
||||
let params_str = str::from_utf8(¶ms_bytes).unwrap();
|
||||
let mut params_tokens = params_str.split('\0');
|
||||
let mut params: HashMap<String, String> = HashMap::new();
|
||||
while let Some(name) = params_tokens.next() {
|
||||
let value = params_tokens
|
||||
.next()
|
||||
.context("expected even number of params in StartupMessage")?;
|
||||
if name == "options" {
|
||||
// deprecated way of passing params as cmd line args
|
||||
for cmdopt in value.split(' ') {
|
||||
let nameval: Vec<&str> = cmdopt.split('=').collect();
|
||||
if nameval.len() == 2 {
|
||||
params.insert(nameval[0].to_string(), nameval[1].to_string());
|
||||
}
|
||||
}
|
||||
} else {
|
||||
params.insert(name.to_string(), value.to_string());
|
||||
}
|
||||
} else {
|
||||
params.insert(name.to_string(), value.to_string());
|
||||
}
|
||||
FeStartupPacket::StartupMessage {
|
||||
major_version,
|
||||
minor_version,
|
||||
params,
|
||||
}
|
||||
}
|
||||
FeStartupPacket::StartupMessage {
|
||||
major_version,
|
||||
minor_version,
|
||||
params,
|
||||
}
|
||||
}
|
||||
};
|
||||
Ok(Some(FeMessage::StartupPacket(message)))
|
||||
};
|
||||
Ok(Some(FeMessage::StartupPacket(message)))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -502,7 +537,7 @@ where
|
||||
f(buf)?;
|
||||
|
||||
let size = i32::from_usize(buf.len() - base)?;
|
||||
BigEndian::write_i32(&mut buf[base..], size);
|
||||
(&mut buf[base..]).put_slice(&size.to_be_bytes());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -977,4 +1012,13 @@ mod tests {
|
||||
let zf_parsed = ZenithFeedback::parse(data.freeze());
|
||||
assert_eq!(zf, zf_parsed);
|
||||
}
|
||||
|
||||
// Make sure that `read` is sync/async callable
|
||||
async fn _assert(stream: &mut (impl tokio::io::AsyncRead + Unpin)) {
|
||||
let _ = FeMessage::read(&mut [].as_ref());
|
||||
let _ = FeMessage::read_fut(stream).await;
|
||||
|
||||
let _ = FeStartupPacket::read(&mut [].as_ref());
|
||||
let _ = FeStartupPacket::read_fut(stream).await;
|
||||
}
|
||||
}
|
||||
|
||||
179
zenith_utils/src/sync.rs
Normal file
179
zenith_utils/src/sync.rs
Normal file
@@ -0,0 +1,179 @@
|
||||
use pin_project_lite::pin_project;
|
||||
use std::future::Future;
|
||||
use std::marker::PhantomData;
|
||||
use std::pin::Pin;
|
||||
use std::{io, task};
|
||||
|
||||
pin_project! {
|
||||
/// We use this future to mark certain methods
|
||||
/// as callable in both sync and async modes.
|
||||
#[repr(transparent)]
|
||||
pub struct SyncFuture<S, T: Future> {
|
||||
#[pin]
|
||||
inner: T,
|
||||
_marker: PhantomData<S>,
|
||||
}
|
||||
}
|
||||
|
||||
/// This wrapper lets us synchronously wait for inner future's completion
|
||||
/// (see [`SyncFuture::wait`]) **provided that `S` implements [`SyncProof`]**.
|
||||
/// For instance, `S` may be substituted with types implementing
|
||||
/// [`tokio::io::AsyncRead`], but it's not the only viable option.
|
||||
impl<S, T: Future> SyncFuture<S, T> {
|
||||
/// NOTE: caller should carefully pick a type for `S`,
|
||||
/// because we don't want to enable [`SyncFuture::wait`] when
|
||||
/// it's in fact impossible to run the future synchronously.
|
||||
/// Violation of this contract will not cause UB, but
|
||||
/// panics and async event loop freezes won't please you.
|
||||
///
|
||||
/// Example:
|
||||
///
|
||||
/// ```
|
||||
/// # use zenith_utils::sync::SyncFuture;
|
||||
/// # use std::future::Future;
|
||||
/// # use tokio::io::AsyncReadExt;
|
||||
/// #
|
||||
/// // Parse a pair of numbers from a stream
|
||||
/// pub fn parse_pair<Reader>(
|
||||
/// stream: &mut Reader,
|
||||
/// ) -> SyncFuture<Reader, impl Future<Output = anyhow::Result<(u32, u64)>> + '_>
|
||||
/// where
|
||||
/// Reader: tokio::io::AsyncRead + Unpin,
|
||||
/// {
|
||||
/// // If `Reader` is a `SyncProof`, this will give caller
|
||||
/// // an opportunity to use `SyncFuture::wait`, because
|
||||
/// // `.await` will always result in `Poll::Ready`.
|
||||
/// SyncFuture::new(async move {
|
||||
/// let x = stream.read_u32().await?;
|
||||
/// let y = stream.read_u64().await?;
|
||||
/// Ok((x, y))
|
||||
/// })
|
||||
/// }
|
||||
/// ```
|
||||
pub fn new(inner: T) -> Self {
|
||||
Self {
|
||||
inner,
|
||||
_marker: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, T: Future> Future for SyncFuture<S, T> {
|
||||
type Output = T::Output;
|
||||
|
||||
/// In async code, [`SyncFuture`] behaves like a regular wrapper.
|
||||
#[inline(always)]
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<Self::Output> {
|
||||
self.project().inner.poll(cx)
|
||||
}
|
||||
}
|
||||
|
||||
/// Postulates that we can call [`SyncFuture::wait`].
|
||||
/// If implementer is also a [`Future`], it should always
|
||||
/// return [`task::Poll::Ready`] from [`Future::poll`].
|
||||
///
|
||||
/// Each implementation should document which futures
|
||||
/// specifically are being declared sync-proof.
|
||||
pub trait SyncPostulate {}
|
||||
|
||||
impl<T: SyncPostulate> SyncPostulate for &T {}
|
||||
impl<T: SyncPostulate> SyncPostulate for &mut T {}
|
||||
|
||||
impl<P: SyncPostulate, T: Future> SyncFuture<P, T> {
|
||||
/// Synchronously wait for future completion.
|
||||
pub fn wait(mut self) -> T::Output {
|
||||
const RAW_WAKER: task::RawWaker = task::RawWaker::new(
|
||||
std::ptr::null(),
|
||||
&task::RawWakerVTable::new(
|
||||
|_| RAW_WAKER,
|
||||
|_| panic!("SyncFuture: failed to wake"),
|
||||
|_| panic!("SyncFuture: failed to wake by ref"),
|
||||
|_| { /* drop is no-op */ },
|
||||
),
|
||||
);
|
||||
|
||||
// SAFETY: We never move `self` during this call;
|
||||
// furthermore, it will be dropped in the end regardless of panics
|
||||
let this = unsafe { Pin::new_unchecked(&mut self) };
|
||||
|
||||
// SAFETY: This waker doesn't do anything apart from panicking
|
||||
let waker = unsafe { task::Waker::from_raw(RAW_WAKER) };
|
||||
let context = &mut task::Context::from_waker(&waker);
|
||||
|
||||
match this.poll(context) {
|
||||
task::Poll::Ready(res) => res,
|
||||
_ => panic!("SyncFuture: unexpected pending!"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// This wrapper turns any [`std::io::Read`] into a blocking [`tokio::io::AsyncRead`],
|
||||
/// which lets us abstract over sync & async readers in methods returning [`SyncFuture`].
|
||||
/// NOTE: you **should not** use this in async code.
|
||||
#[repr(transparent)]
|
||||
pub struct AsyncishRead<T: io::Read + Unpin>(pub T);
|
||||
|
||||
/// This lets us call [`SyncFuture<AsyncishRead<_>, _>::wait`],
|
||||
/// and allows the future to await on any of the [`AsyncRead`]
|
||||
/// and [`AsyncReadExt`] methods on `AsyncishRead`.
|
||||
impl<T: io::Read + Unpin> SyncPostulate for AsyncishRead<T> {}
|
||||
|
||||
impl<T: io::Read + Unpin> tokio::io::AsyncRead for AsyncishRead<T> {
|
||||
#[inline(always)]
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
_cx: &mut task::Context<'_>,
|
||||
buf: &mut tokio::io::ReadBuf<'_>,
|
||||
) -> task::Poll<io::Result<()>> {
|
||||
task::Poll::Ready(
|
||||
// `Read::read` will block, meaning we don't need a real event loop!
|
||||
self.0
|
||||
.read(buf.initialize_unfilled())
|
||||
.map(|sz| buf.advance(sz)),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
|
||||
// async helper(stream: &mut impl AsyncRead) -> io::Result<u32>
|
||||
fn bytes_add<Reader>(
|
||||
stream: &mut Reader,
|
||||
) -> SyncFuture<Reader, impl Future<Output = io::Result<u32>> + '_>
|
||||
where
|
||||
Reader: tokio::io::AsyncRead + Unpin,
|
||||
{
|
||||
SyncFuture::new(async move {
|
||||
let a = stream.read_u32().await?;
|
||||
let b = stream.read_u32().await?;
|
||||
Ok(a + b)
|
||||
})
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sync() {
|
||||
let bytes = [100u32.to_be_bytes(), 200u32.to_be_bytes()].concat();
|
||||
let res = bytes_add(&mut AsyncishRead(&mut &bytes[..]))
|
||||
.wait()
|
||||
.unwrap();
|
||||
assert_eq!(res, 300);
|
||||
}
|
||||
|
||||
// We need a single-threaded executor for this test
|
||||
#[tokio::test(flavor = "current_thread")]
|
||||
async fn test_async() {
|
||||
let (mut tx, mut rx) = tokio::net::UnixStream::pair().unwrap();
|
||||
|
||||
let write = async move {
|
||||
tx.write_u32(100).await?;
|
||||
tx.write_u32(200).await?;
|
||||
Ok(())
|
||||
};
|
||||
|
||||
let (res, ()) = tokio::try_join!(bytes_add(&mut rx), write).unwrap();
|
||||
assert_eq!(res, 300);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user