mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-11 23:42:55 +00:00
Refactor postgres protocol parsing.
1) Remove allocation and data copy during each message read. Instead, parsing functions now accept BytesMut from which data they form messages, with pointers (e.g. in CopyData) pointing directly into BytesMut buffer. Accordingly, move ConnectionError containing IO error subtype into framed.rs providing this and leave in pq_proto only ProtocolError. 2) Remove anyhow from pq_proto. 3) Move FeStartupPacket out of FeMessage. Now FeStartupPacket::parse returns it directly, eliminating dead code where user wants startup packet but has to match for others. proxy stream.rs is adapted to framed.rs with minimal changes. It also benefits from framed.rs improvements described above.
This commit is contained in:
2
Cargo.lock
generated
2
Cargo.lock
generated
@@ -2747,7 +2747,7 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de"
|
||||
name = "pq_proto"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"byteorder",
|
||||
"bytes",
|
||||
"pin-project-lite",
|
||||
"postgres-protocol",
|
||||
|
||||
@@ -17,9 +17,9 @@ use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio_rustls::TlsAcceptor;
|
||||
use tracing::{debug, error, info, trace};
|
||||
|
||||
use pq_proto::framed::{Framed, FramedReader, FramedWriter};
|
||||
use pq_proto::framed::{ConnectionError, Framed, FramedReader, FramedWriter};
|
||||
use pq_proto::{
|
||||
BeMessage, ConnectionError, FeMessage, FeStartupPacket, SQLSTATE_INTERNAL_ERROR,
|
||||
BeMessage, FeMessage, FeStartupPacket, ProtocolError, SQLSTATE_INTERNAL_ERROR,
|
||||
SQLSTATE_SUCCESSFUL_COMPLETION,
|
||||
};
|
||||
|
||||
@@ -37,7 +37,7 @@ pub enum QueryError {
|
||||
|
||||
impl From<io::Error> for QueryError {
|
||||
fn from(e: io::Error) -> Self {
|
||||
Self::Disconnected(ConnectionError::Socket(e))
|
||||
Self::Disconnected(ConnectionError::Io(e))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -219,7 +219,7 @@ impl MaybeWriteOnly {
|
||||
}
|
||||
}
|
||||
|
||||
fn write_message_noflush(&mut self, msg: &BeMessage<'_>) -> Result<(), ConnectionError> {
|
||||
fn write_message_noflush(&mut self, msg: &BeMessage<'_>) -> Result<(), ProtocolError> {
|
||||
match self {
|
||||
MaybeWriteOnly::Full(framed) => framed.write_message(msg),
|
||||
MaybeWriteOnly::WriteOnly(framed_writer) => framed_writer.write_message_noflush(msg),
|
||||
@@ -701,8 +701,7 @@ impl PostgresBackend {
|
||||
FeMessage::CopyData(_)
|
||||
| FeMessage::CopyDone
|
||||
| FeMessage::CopyFail
|
||||
| FeMessage::PasswordMessage(_)
|
||||
| FeMessage::StartupPacket(_) => {
|
||||
| FeMessage::PasswordMessage(_) => {
|
||||
return Err(QueryError::Other(anyhow::anyhow!(
|
||||
"unexpected message type: {msg:?}",
|
||||
)));
|
||||
@@ -721,7 +720,7 @@ impl PostgresBackend {
|
||||
|
||||
let expected_end = match &end {
|
||||
ServerInitiated(_) | CopyDone | CopyFail | Terminate | EOF => true,
|
||||
CopyStreamHandlerEnd::Disconnected(ConnectionError::Socket(io_error))
|
||||
CopyStreamHandlerEnd::Disconnected(ConnectionError::Io(io_error))
|
||||
if is_expected_io_error(io_error) =>
|
||||
{
|
||||
true
|
||||
@@ -800,7 +799,7 @@ impl PostgresBackendReader {
|
||||
FeMessage::CopyFail => Err(CopyStreamHandlerEnd::CopyFail),
|
||||
FeMessage::Terminate => Err(CopyStreamHandlerEnd::Terminate),
|
||||
_ => Err(CopyStreamHandlerEnd::from(ConnectionError::Protocol(
|
||||
format!("unexpected message in COPY stream {:?}", msg),
|
||||
ProtocolError::Protocol(format!("unexpected message in COPY stream {:?}", msg)),
|
||||
))),
|
||||
},
|
||||
None => Err(CopyStreamHandlerEnd::EOF),
|
||||
@@ -871,7 +870,7 @@ pub fn short_error(e: &QueryError) -> String {
|
||||
|
||||
fn log_query_error(query: &str, e: &QueryError) {
|
||||
match e {
|
||||
QueryError::Disconnected(ConnectionError::Socket(io_error)) => {
|
||||
QueryError::Disconnected(ConnectionError::Io(io_error)) => {
|
||||
if is_expected_io_error(io_error) {
|
||||
info!("query handler for '{query}' failed with expected io error: {io_error}");
|
||||
} else {
|
||||
|
||||
@@ -5,8 +5,8 @@ edition.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
bytes.workspace = true
|
||||
byteorder.workspace = true
|
||||
pin-project-lite.workspace = true
|
||||
postgres-protocol.workspace = true
|
||||
rand.workspace = true
|
||||
|
||||
@@ -1,51 +1,84 @@
|
||||
//! Provides `Framed` -- writing/flushing and reading Postgres messages to/from
|
||||
//! the async stream.
|
||||
//! the async stream based on (and buffered with) BytesMut. All functions are
|
||||
//! cancellation safe.
|
||||
//!
|
||||
//! It is similar to what tokio_util::codec::Framed with appropriate codec
|
||||
//! provides, but `FramedReader` and `FramedWriter` read/write parts can be used
|
||||
//! separately without using split from futures::stream::StreamExt (which
|
||||
//! allocates box[1] in polling internally). tokio::io::split is used for splitting
|
||||
//! instead. Plus we customize error messages more than a single type for all io
|
||||
//! calls.
|
||||
//!
|
||||
//! [1] https://docs.rs/futures-util/0.3.26/src/futures_util/lock/bilock.rs.html#107
|
||||
use bytes::{Buf, BytesMut};
|
||||
use std::{
|
||||
future::Future,
|
||||
io::{self, ErrorKind},
|
||||
};
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufReader, ReadHalf, WriteHalf};
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf};
|
||||
|
||||
use crate::{BeMessage, ConnectionError, FeMessage, FeStartupPacket};
|
||||
use crate::{BeMessage, FeMessage, FeStartupPacket, ProtocolError};
|
||||
|
||||
const INITIAL_CAPACITY: usize = 8 * 1024;
|
||||
|
||||
/// Error on postgres connection: either IO (physical transport error) or
|
||||
/// protocol violation.
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum ConnectionError {
|
||||
#[error(transparent)]
|
||||
Io(#[from] io::Error),
|
||||
#[error(transparent)]
|
||||
Protocol(#[from] ProtocolError),
|
||||
}
|
||||
|
||||
impl ConnectionError {
|
||||
/// Proxy stream.rs uses only io::Error; provide it.
|
||||
pub fn into_io_error(self) -> io::Error {
|
||||
match self {
|
||||
ConnectionError::Io(io) => io,
|
||||
ConnectionError::Protocol(pe) => io::Error::new(io::ErrorKind::Other, pe.to_string()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Wraps async io `stream`, providing messages to write/flush + read Postgres
|
||||
/// messages.
|
||||
pub struct Framed<S> {
|
||||
stream: BufReader<S>,
|
||||
stream: S,
|
||||
read_buf: BytesMut,
|
||||
write_buf: BytesMut,
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + Unpin> Framed<S> {
|
||||
impl<S> Framed<S> {
|
||||
pub fn new(stream: S) -> Self {
|
||||
Self {
|
||||
stream: BufReader::new(stream),
|
||||
stream,
|
||||
read_buf: BytesMut::with_capacity(INITIAL_CAPACITY),
|
||||
write_buf: BytesMut::with_capacity(INITIAL_CAPACITY),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get a shared reference to the underlying stream.
|
||||
pub fn get_ref(&self) -> &S {
|
||||
self.stream.get_ref()
|
||||
&self.stream
|
||||
}
|
||||
|
||||
/// Extract the underlying stream.
|
||||
pub fn into_inner(self) -> S {
|
||||
self.stream.into_inner()
|
||||
self.stream
|
||||
}
|
||||
|
||||
/// Return new Framed with stream type transformed by async f, for TLS
|
||||
/// upgrade.
|
||||
pub async fn map_stream<S2: AsyncRead, E, F, Fut>(self, f: F) -> Result<Framed<S2>, E>
|
||||
pub async fn map_stream<S2, E, F, Fut>(self, f: F) -> Result<Framed<S2>, E>
|
||||
where
|
||||
F: FnOnce(S) -> Fut,
|
||||
Fut: Future<Output = Result<S2, E>>,
|
||||
{
|
||||
let stream = f(self.stream.into_inner()).await?;
|
||||
let stream = f(self.stream).await?;
|
||||
Ok(Framed {
|
||||
stream: BufReader::new(stream),
|
||||
stream,
|
||||
read_buf: self.read_buf,
|
||||
write_buf: self.write_buf,
|
||||
})
|
||||
}
|
||||
@@ -55,24 +88,18 @@ impl<S: AsyncRead + Unpin> Framed<S> {
|
||||
pub async fn read_startup_message(
|
||||
&mut self,
|
||||
) -> Result<Option<FeStartupPacket>, ConnectionError> {
|
||||
let msg = FeStartupPacket::read(&mut self.stream).await?;
|
||||
|
||||
match msg {
|
||||
Some(FeMessage::StartupPacket(packet)) => Ok(Some(packet)),
|
||||
None => Ok(None),
|
||||
_ => panic!("unreachable state"),
|
||||
}
|
||||
read_message(&mut self.stream, &mut self.read_buf, FeStartupPacket::parse).await
|
||||
}
|
||||
|
||||
pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
|
||||
FeMessage::read(&mut self.stream).await
|
||||
read_message(&mut self.stream, &mut self.read_buf, FeMessage::parse).await
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncWrite + AsyncRead + Unpin> Framed<S> {
|
||||
impl<S: AsyncWrite + Unpin> Framed<S> {
|
||||
/// Write next message to the output buffer; doesn't flush.
|
||||
pub fn write_message(&mut self, msg: &BeMessage<'_>) -> Result<(), ConnectionError> {
|
||||
BeMessage::write(&mut self.write_buf, msg).map_err(|e| e.into())
|
||||
pub fn write_message(&mut self, msg: &BeMessage<'_>) -> Result<(), ProtocolError> {
|
||||
BeMessage::write(&mut self.write_buf, msg)
|
||||
}
|
||||
|
||||
/// Flush out the buffer. This function is cancellation safe: it can be
|
||||
@@ -93,7 +120,10 @@ impl<S: AsyncRead + AsyncWrite + Unpin> Framed<S> {
|
||||
/// https://github.com/tokio-rs/tls/issues/40
|
||||
pub fn split(self) -> (FramedReader<S>, FramedWriter<S>) {
|
||||
let (read_half, write_half) = tokio::io::split(self.stream);
|
||||
let reader = FramedReader { stream: read_half };
|
||||
let reader = FramedReader {
|
||||
stream: read_half,
|
||||
read_buf: self.read_buf,
|
||||
};
|
||||
let writer = FramedWriter {
|
||||
stream: write_half,
|
||||
write_buf: self.write_buf,
|
||||
@@ -105,6 +135,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> Framed<S> {
|
||||
pub fn unsplit(reader: FramedReader<S>, writer: FramedWriter<S>) -> Self {
|
||||
Self {
|
||||
stream: reader.stream.unsplit(writer.stream),
|
||||
read_buf: reader.read_buf,
|
||||
write_buf: writer.write_buf,
|
||||
}
|
||||
}
|
||||
@@ -112,25 +143,26 @@ impl<S: AsyncRead + AsyncWrite + Unpin> Framed<S> {
|
||||
|
||||
/// Read-only version of `Framed`.
|
||||
pub struct FramedReader<S> {
|
||||
stream: ReadHalf<BufReader<S>>,
|
||||
stream: ReadHalf<S>,
|
||||
read_buf: BytesMut,
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + Unpin> FramedReader<S> {
|
||||
pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
|
||||
FeMessage::read(&mut self.stream).await
|
||||
read_message(&mut self.stream, &mut self.read_buf, FeMessage::parse).await
|
||||
}
|
||||
}
|
||||
|
||||
/// Write-only version of `Framed`.
|
||||
pub struct FramedWriter<S> {
|
||||
stream: WriteHalf<BufReader<S>>,
|
||||
stream: WriteHalf<S>,
|
||||
write_buf: BytesMut,
|
||||
}
|
||||
|
||||
impl<S: AsyncWrite + AsyncRead + Unpin> FramedWriter<S> {
|
||||
impl<S: AsyncWrite + Unpin> FramedWriter<S> {
|
||||
/// Write next message to the output buffer; doesn't flush.
|
||||
pub fn write_message_noflush(&mut self, msg: &BeMessage<'_>) -> Result<(), ConnectionError> {
|
||||
BeMessage::write(&mut self.write_buf, msg).map_err(|e| e.into())
|
||||
pub fn write_message_noflush(&mut self, msg: &BeMessage<'_>) -> Result<(), ProtocolError> {
|
||||
BeMessage::write(&mut self.write_buf, msg)
|
||||
}
|
||||
|
||||
/// Flush out the buffer. This function is cancellation safe: it can be
|
||||
@@ -145,6 +177,43 @@ impl<S: AsyncWrite + AsyncRead + Unpin> FramedWriter<S> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Read next message from the stream. Returns Ok(None), if EOF happened and we
|
||||
/// don't have remaining data in the buffer. This function is cancellation safe:
|
||||
/// you can drop future which is not yet complete and finalize reading message
|
||||
/// with the next call.
|
||||
///
|
||||
/// Parametrized to allow reading startup or usual message, having different
|
||||
/// format.
|
||||
async fn read_message<S: AsyncRead + Unpin, M, P>(
|
||||
stream: &mut S,
|
||||
read_buf: &mut BytesMut,
|
||||
parse: P,
|
||||
) -> Result<Option<M>, ConnectionError>
|
||||
where
|
||||
P: Fn(&mut BytesMut) -> Result<Option<M>, ProtocolError>,
|
||||
{
|
||||
loop {
|
||||
if let Some(msg) = parse(read_buf)? {
|
||||
return Ok(Some(msg));
|
||||
}
|
||||
// If we can't build a frame yet, try to read more data and try again.
|
||||
// Make sure we've got room for at least one byte to read to ensure
|
||||
// that we don't get a spurious 0 that looks like EOF.
|
||||
read_buf.reserve(1);
|
||||
if stream.read_buf(read_buf).await? == 0 {
|
||||
if read_buf.has_remaining() {
|
||||
return Err(io::Error::new(
|
||||
ErrorKind::UnexpectedEof,
|
||||
"EOF with unprocessed data in the buffer",
|
||||
)
|
||||
.into());
|
||||
} else {
|
||||
return Ok(None); // clean EOF
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn flush<S: AsyncWrite + Unpin>(
|
||||
stream: &mut S,
|
||||
write_buf: &mut BytesMut,
|
||||
|
||||
@@ -4,19 +4,16 @@
|
||||
|
||||
pub mod framed;
|
||||
|
||||
use anyhow::{ensure, Context, Result};
|
||||
use byteorder::{BigEndian, ReadBytesExt};
|
||||
use bytes::{Buf, BufMut, Bytes, BytesMut};
|
||||
use postgres_protocol::PG_EPOCH;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{
|
||||
borrow::Cow,
|
||||
collections::HashMap,
|
||||
fmt,
|
||||
io::{self, Cursor},
|
||||
str,
|
||||
fmt, io, str,
|
||||
time::{Duration, SystemTime},
|
||||
};
|
||||
use tokio::io::AsyncReadExt;
|
||||
use tracing::{trace, warn};
|
||||
|
||||
pub type Oid = u32;
|
||||
@@ -28,7 +25,6 @@ pub const TEXT_OID: Oid = 25;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum FeMessage {
|
||||
StartupPacket(FeStartupPacket),
|
||||
// Simple query.
|
||||
Query(Bytes),
|
||||
// Extended query protocol.
|
||||
@@ -188,100 +184,90 @@ pub struct FeExecuteMessage {
|
||||
#[derive(Debug)]
|
||||
pub struct FeCloseMessage;
|
||||
|
||||
/// Retry a read on EINTR
|
||||
///
|
||||
/// This runs the enclosed expression, and if it returns
|
||||
/// Err(io::ErrorKind::Interrupted), retries it.
|
||||
macro_rules! retry_read {
|
||||
( $x:expr ) => {
|
||||
loop {
|
||||
match $x {
|
||||
Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
|
||||
res => break res,
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/// An error occured during connection being open.
|
||||
/// An error occured while parsing or serializing raw stream into Postgres
|
||||
/// messages.
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum ConnectionError {
|
||||
/// IO error during writing to or reading from the connection socket.
|
||||
#[error("Socket IO error: {0}")]
|
||||
Socket(#[from] std::io::Error),
|
||||
/// Invalid packet was received from client
|
||||
pub enum ProtocolError {
|
||||
/// Invalid packet was received from the client (e.g. unexpected message
|
||||
/// type or broken len).
|
||||
#[error("Protocol error: {0}")]
|
||||
Protocol(String),
|
||||
/// Failed to parse a protocol mesage
|
||||
/// Failed to parse or, (unlikely), serialize a protocol message.
|
||||
#[error("Message parse error: {0}")]
|
||||
MessageParse(anyhow::Error),
|
||||
BadMessage(String),
|
||||
}
|
||||
|
||||
impl From<anyhow::Error> for ConnectionError {
|
||||
fn from(e: anyhow::Error) -> Self {
|
||||
Self::MessageParse(e)
|
||||
}
|
||||
}
|
||||
|
||||
impl ConnectionError {
|
||||
impl ProtocolError {
|
||||
/// Proxy stream.rs uses only io::Error; provide it.
|
||||
pub fn into_io_error(self) -> io::Error {
|
||||
match self {
|
||||
ConnectionError::Socket(io) => io,
|
||||
other => io::Error::new(io::ErrorKind::Other, other.to_string()),
|
||||
}
|
||||
io::Error::new(io::ErrorKind::Other, self.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl FeMessage {
|
||||
/// Read one message from the stream.
|
||||
/// This function returns `Ok(None)` in case of EOF.
|
||||
pub async fn read<Reader>(stream: &mut Reader) -> Result<Option<FeMessage>, ConnectionError>
|
||||
where
|
||||
Reader: tokio::io::AsyncRead + Unpin,
|
||||
{
|
||||
// We return a Future that's sync (has a `wait` method) if and only if the provided stream is SyncProof.
|
||||
// SyncFuture contract: we are only allowed to await on sync-proof futures, the AsyncRead and
|
||||
// AsyncReadExt methods of the stream.
|
||||
// Each libpq message begins with a message type byte, followed by message length
|
||||
// If the client closes the connection, return None. But if the client closes the
|
||||
// connection in the middle of a message, we will return an error.
|
||||
let tag = match retry_read!(stream.read_u8().await) {
|
||||
Ok(b) => b,
|
||||
Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
|
||||
Err(e) => return Err(ConnectionError::Socket(e)),
|
||||
};
|
||||
/// Read and parse one message from the `buf` input buffer. If there is at
|
||||
/// least one valid message, returns it, advancing `buf`; redundant copies
|
||||
/// are avoided, as thanks to `bytes` crate ptrs in parsed message point
|
||||
/// directly into the `buf` (processed data is garbage collected after
|
||||
/// parsed message is dropped).
|
||||
///
|
||||
/// Returns None if `buf` doesn't contain enough data for a single message.
|
||||
/// For efficiency, tries to reserve large enough space in `buf` for the
|
||||
/// next message in this case to save the repeated calls.
|
||||
///
|
||||
/// Returns Error if message is malformed, the only possible ErrorKind is
|
||||
/// InvalidInput.
|
||||
//
|
||||
// Inspired by rust-postgres Message::parse.
|
||||
pub fn parse(buf: &mut BytesMut) -> Result<Option<FeMessage>, ProtocolError> {
|
||||
// Every message contains message type byte and 4 bytes len; can't do
|
||||
// much without them.
|
||||
if buf.len() < 5 {
|
||||
let to_read = 5 - buf.len();
|
||||
buf.reserve(to_read);
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// The message length includes itself, so it better be at least 4.
|
||||
let len = retry_read!(stream.read_u32().await)
|
||||
.map_err(ConnectionError::Socket)?
|
||||
.checked_sub(4)
|
||||
.ok_or_else(|| ConnectionError::Protocol("invalid message length".to_string()))?;
|
||||
// We shouldn't advance `buf` as probably full message is not there yet,
|
||||
// so can't directly use Bytes::get_u32 etc.
|
||||
let tag = buf[0];
|
||||
let len = (&buf[1..5]).read_u32::<BigEndian>().unwrap();
|
||||
if len < 4 {
|
||||
return Err(ProtocolError::Protocol(format!(
|
||||
"invalid message length {}",
|
||||
len
|
||||
)));
|
||||
}
|
||||
|
||||
let body = {
|
||||
let mut buffer = vec![0u8; len as usize];
|
||||
stream
|
||||
.read_exact(&mut buffer)
|
||||
.await
|
||||
.map_err(ConnectionError::Socket)?;
|
||||
Bytes::from(buffer)
|
||||
};
|
||||
// length field includes itself, but not message type.
|
||||
let total_len = len as usize + 1;
|
||||
if buf.len() < total_len {
|
||||
// Don't have full message yet.
|
||||
let to_read = total_len - buf.len();
|
||||
buf.reserve(to_read);
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// got the message, advance buffer
|
||||
let mut msg = buf.split_to(total_len).freeze();
|
||||
msg.advance(5); // consume message type and len
|
||||
|
||||
match tag {
|
||||
b'Q' => Ok(Some(FeMessage::Query(body))),
|
||||
b'P' => Ok(Some(FeParseMessage::parse(body)?)),
|
||||
b'D' => Ok(Some(FeDescribeMessage::parse(body)?)),
|
||||
b'E' => Ok(Some(FeExecuteMessage::parse(body)?)),
|
||||
b'B' => Ok(Some(FeBindMessage::parse(body)?)),
|
||||
b'C' => Ok(Some(FeCloseMessage::parse(body)?)),
|
||||
b'Q' => Ok(Some(FeMessage::Query(msg))),
|
||||
b'P' => Ok(Some(FeParseMessage::parse(msg)?)),
|
||||
b'D' => Ok(Some(FeDescribeMessage::parse(msg)?)),
|
||||
b'E' => Ok(Some(FeExecuteMessage::parse(msg)?)),
|
||||
b'B' => Ok(Some(FeBindMessage::parse(msg)?)),
|
||||
b'C' => Ok(Some(FeCloseMessage::parse(msg)?)),
|
||||
b'S' => Ok(Some(FeMessage::Sync)),
|
||||
b'X' => Ok(Some(FeMessage::Terminate)),
|
||||
b'd' => Ok(Some(FeMessage::CopyData(body))),
|
||||
b'd' => Ok(Some(FeMessage::CopyData(msg))),
|
||||
b'c' => Ok(Some(FeMessage::CopyDone)),
|
||||
b'f' => Ok(Some(FeMessage::CopyFail)),
|
||||
b'p' => Ok(Some(FeMessage::PasswordMessage(body))),
|
||||
b'p' => Ok(Some(FeMessage::PasswordMessage(msg))),
|
||||
tag => {
|
||||
return Err(ConnectionError::Protocol(format!(
|
||||
"unknown message tag: {tag},'{body:?}'"
|
||||
return Err(ProtocolError::Protocol(format!(
|
||||
"unknown message tag: {tag},'{msg:?}'"
|
||||
)))
|
||||
}
|
||||
}
|
||||
@@ -289,60 +275,59 @@ impl FeMessage {
|
||||
}
|
||||
|
||||
impl FeStartupPacket {
|
||||
/// Read startup message from the stream.
|
||||
// XXX: It's tempting yet undesirable to accept `stream` by value,
|
||||
// since such a change will cause user-supplied &mut references to be consumed
|
||||
pub async fn read<Reader>(stream: &mut Reader) -> Result<Option<FeMessage>, ConnectionError>
|
||||
where
|
||||
Reader: tokio::io::AsyncRead + Unpin,
|
||||
{
|
||||
/// Read and parse startup message from the `buf` input buffer. It is
|
||||
/// different from [`FeMessage::parse`] because startup messages don't have
|
||||
/// message type byte; otherwise, its comments apply.
|
||||
pub fn parse(buf: &mut BytesMut) -> Result<Option<FeStartupPacket>, ProtocolError> {
|
||||
const MAX_STARTUP_PACKET_LENGTH: usize = 10000;
|
||||
const RESERVED_INVALID_MAJOR_VERSION: u32 = 1234;
|
||||
const CANCEL_REQUEST_CODE: u32 = 5678;
|
||||
const NEGOTIATE_SSL_CODE: u32 = 5679;
|
||||
const NEGOTIATE_GSS_CODE: u32 = 5680;
|
||||
|
||||
// Read length. If the connection is closed before reading anything (or before
|
||||
// reading 4 bytes, to be precise), return None to indicate that the connection
|
||||
// was closed. This matches the PostgreSQL server's behavior, which avoids noise
|
||||
// in the log if the client opens connection but closes it immediately.
|
||||
let len = match retry_read!(stream.read_u32().await) {
|
||||
Ok(len) => len as usize,
|
||||
Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
|
||||
Err(e) => return Err(ConnectionError::Socket(e)),
|
||||
};
|
||||
// need at least 4 bytes with packet len
|
||||
if buf.len() < 4 {
|
||||
let to_read = 4 - buf.len();
|
||||
buf.reserve(to_read);
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
#[allow(clippy::manual_range_contains)]
|
||||
// We shouldn't advance `buf` as probably full message is not there yet,
|
||||
// so can't directly use Bytes::get_u32 etc.
|
||||
let len = (&buf[0..4]).read_u32::<BigEndian>().unwrap() as usize;
|
||||
if len < 4 || len > MAX_STARTUP_PACKET_LENGTH {
|
||||
return Err(ConnectionError::Protocol(format!(
|
||||
"invalid message length {len}"
|
||||
return Err(ProtocolError::Protocol(format!(
|
||||
"invalid startup packet message length {}",
|
||||
len
|
||||
)));
|
||||
}
|
||||
|
||||
let request_code = retry_read!(stream.read_u32().await).map_err(ConnectionError::Socket)?;
|
||||
if buf.len() < len {
|
||||
// Don't have full message yet.
|
||||
let to_read = len - buf.len();
|
||||
buf.reserve(to_read);
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// the rest of startup packet are params
|
||||
let params_len = len - 8;
|
||||
let mut params_bytes = vec![0u8; params_len];
|
||||
stream
|
||||
.read_exact(params_bytes.as_mut())
|
||||
.await
|
||||
.map_err(ConnectionError::Socket)?;
|
||||
// got the message, advance buffer
|
||||
let mut msg = buf.split_to(len).freeze();
|
||||
msg.advance(4); // consume len
|
||||
|
||||
// Parse params depending on request code
|
||||
let request_code = msg.get_u32();
|
||||
let req_hi = request_code >> 16;
|
||||
let req_lo = request_code & ((1 << 16) - 1);
|
||||
// StartupMessage, CancelRequest, SSLRequest etc are differentiated by request code.
|
||||
let message = match (req_hi, req_lo) {
|
||||
(RESERVED_INVALID_MAJOR_VERSION, CANCEL_REQUEST_CODE) => {
|
||||
if params_len != 8 {
|
||||
return Err(ConnectionError::Protocol(
|
||||
"expected 8 bytes for CancelRequest params".to_string(),
|
||||
if msg.remaining() != 8 {
|
||||
return Err(ProtocolError::BadMessage(
|
||||
"CancelRequest message is malformed, backend PID / secret key missing"
|
||||
.to_owned(),
|
||||
));
|
||||
}
|
||||
let mut cursor = Cursor::new(params_bytes);
|
||||
FeStartupPacket::CancelRequest(CancelKeyData {
|
||||
backend_pid: cursor.read_i32().await.map_err(ConnectionError::Socket)?,
|
||||
cancel_key: cursor.read_i32().await.map_err(ConnectionError::Socket)?,
|
||||
backend_pid: msg.get_i32(),
|
||||
cancel_key: msg.get_i32(),
|
||||
})
|
||||
}
|
||||
(RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_SSL_CODE) => {
|
||||
@@ -354,19 +339,23 @@ impl FeStartupPacket {
|
||||
FeStartupPacket::GssEncRequest
|
||||
}
|
||||
(RESERVED_INVALID_MAJOR_VERSION, unrecognized_code) => {
|
||||
return Err(ConnectionError::Protocol(format!(
|
||||
return Err(ProtocolError::Protocol(format!(
|
||||
"Unrecognized request code {unrecognized_code}"
|
||||
)));
|
||||
}
|
||||
// TODO bail if protocol major_version is not 3?
|
||||
(major_version, minor_version) => {
|
||||
// StartupMessage
|
||||
|
||||
// Parse pairs of null-terminated strings (key, value).
|
||||
// See `postgres: ProcessStartupPacket, build_startup_packet`.
|
||||
let mut tokens = str::from_utf8(¶ms_bytes)
|
||||
.context("StartupMessage params: invalid utf-8")?
|
||||
let mut tokens = str::from_utf8(&msg)
|
||||
.map_err(|_e| {
|
||||
ProtocolError::BadMessage("StartupMessage params: invalid utf-8".to_owned())
|
||||
})?
|
||||
.strip_suffix('\0') // drop packet's own null
|
||||
.ok_or_else(|| {
|
||||
ConnectionError::Protocol(
|
||||
ProtocolError::Protocol(
|
||||
"StartupMessage params: missing null terminator".to_string(),
|
||||
)
|
||||
})?
|
||||
@@ -375,7 +364,7 @@ impl FeStartupPacket {
|
||||
let mut params = HashMap::new();
|
||||
while let Some(name) = tokens.next() {
|
||||
let value = tokens.next().ok_or_else(|| {
|
||||
ConnectionError::Protocol(
|
||||
ProtocolError::Protocol(
|
||||
"StartupMessage params: key without value".to_string(),
|
||||
)
|
||||
})?;
|
||||
@@ -390,13 +379,12 @@ impl FeStartupPacket {
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Ok(Some(FeMessage::StartupPacket(message)))
|
||||
Ok(Some(message))
|
||||
}
|
||||
}
|
||||
|
||||
impl FeParseMessage {
|
||||
fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
|
||||
fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
|
||||
// FIXME: the rust-postgres driver uses a named prepared statement
|
||||
// for copy_out(). We're not prepared to handle that correctly. For
|
||||
// now, just ignore the statement name, assuming that the client never
|
||||
@@ -404,55 +392,82 @@ impl FeParseMessage {
|
||||
|
||||
let _pstmt_name = read_cstr(&mut buf)?;
|
||||
let query_string = read_cstr(&mut buf)?;
|
||||
if buf.remaining() < 2 {
|
||||
return Err(ProtocolError::BadMessage(
|
||||
"Parse message is malformed, nparams missing".to_string(),
|
||||
));
|
||||
}
|
||||
let nparams = buf.get_i16();
|
||||
|
||||
ensure!(nparams == 0, "query params not implemented");
|
||||
if nparams != 0 {
|
||||
return Err(ProtocolError::BadMessage(
|
||||
"query params not implemented".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
Ok(FeMessage::Parse(FeParseMessage { query_string }))
|
||||
}
|
||||
}
|
||||
|
||||
impl FeDescribeMessage {
|
||||
fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
|
||||
fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
|
||||
let kind = buf.get_u8();
|
||||
let _pstmt_name = read_cstr(&mut buf)?;
|
||||
|
||||
// FIXME: see FeParseMessage::parse
|
||||
ensure!(
|
||||
kind == b'S',
|
||||
"only prepared statemement Describe is implemented"
|
||||
);
|
||||
if kind != b'S' {
|
||||
return Err(ProtocolError::BadMessage(
|
||||
"only prepared statemement Describe is implemented".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
Ok(FeMessage::Describe(FeDescribeMessage { kind }))
|
||||
}
|
||||
}
|
||||
|
||||
impl FeExecuteMessage {
|
||||
fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
|
||||
fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
|
||||
let portal_name = read_cstr(&mut buf)?;
|
||||
if buf.remaining() < 4 {
|
||||
return Err(ProtocolError::BadMessage(
|
||||
"FeExecuteMessage message is malformed, maxrows missing".to_string(),
|
||||
));
|
||||
}
|
||||
let maxrows = buf.get_i32();
|
||||
|
||||
ensure!(portal_name.is_empty(), "named portals not implemented");
|
||||
ensure!(maxrows == 0, "row limit in Execute message not implemented");
|
||||
if !portal_name.is_empty() {
|
||||
return Err(ProtocolError::BadMessage(
|
||||
"named portals not implemented".to_string(),
|
||||
));
|
||||
}
|
||||
if maxrows != 0 {
|
||||
return Err(ProtocolError::BadMessage(
|
||||
"row limit in Execute message not implemented".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
Ok(FeMessage::Execute(FeExecuteMessage { maxrows }))
|
||||
}
|
||||
}
|
||||
|
||||
impl FeBindMessage {
|
||||
fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
|
||||
fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
|
||||
let portal_name = read_cstr(&mut buf)?;
|
||||
let _pstmt_name = read_cstr(&mut buf)?;
|
||||
|
||||
// FIXME: see FeParseMessage::parse
|
||||
ensure!(portal_name.is_empty(), "named portals not implemented");
|
||||
if !portal_name.is_empty() {
|
||||
return Err(ProtocolError::BadMessage(
|
||||
"named portals not implemented".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
Ok(FeMessage::Bind(FeBindMessage))
|
||||
}
|
||||
}
|
||||
|
||||
impl FeCloseMessage {
|
||||
fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
|
||||
fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
|
||||
let _kind = buf.get_u8();
|
||||
let _pstmt_or_portal_name = read_cstr(&mut buf)?;
|
||||
|
||||
@@ -481,6 +496,7 @@ pub enum BeMessage<'a> {
|
||||
CloseComplete,
|
||||
// None means column is NULL
|
||||
DataRow(&'a [Option<&'a [u8]>]),
|
||||
// None errcode means internal_error will be sent.
|
||||
ErrorResponse(&'a str, Option<&'a [u8; 5]>),
|
||||
/// Single byte - used in response to SSLRequest/GSSENCRequest.
|
||||
EncryptionResponse(bool),
|
||||
@@ -594,7 +610,7 @@ impl RowDescriptor<'_> {
|
||||
#[derive(Debug)]
|
||||
pub struct XLogDataBody<'a> {
|
||||
pub wal_start: u64,
|
||||
pub wal_end: u64,
|
||||
pub wal_end: u64, // current end of WAL on the server
|
||||
pub timestamp: i64,
|
||||
pub data: &'a [u8],
|
||||
}
|
||||
@@ -634,12 +650,11 @@ fn write_body<R>(buf: &mut BytesMut, f: impl FnOnce(&mut BytesMut) -> R) -> R {
|
||||
}
|
||||
|
||||
/// Safe write of s into buf as cstring (String in the protocol).
|
||||
fn write_cstr(s: impl AsRef<[u8]>, buf: &mut BytesMut) -> io::Result<()> {
|
||||
fn write_cstr(s: impl AsRef<[u8]>, buf: &mut BytesMut) -> Result<(), ProtocolError> {
|
||||
let bytes = s.as_ref();
|
||||
if bytes.contains(&0) {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidInput,
|
||||
"string contains embedded null",
|
||||
return Err(ProtocolError::BadMessage(
|
||||
"string contains embedded null".to_owned(),
|
||||
));
|
||||
}
|
||||
buf.put_slice(bytes);
|
||||
@@ -647,9 +662,13 @@ fn write_cstr(s: impl AsRef<[u8]>, buf: &mut BytesMut) -> io::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn read_cstr(buf: &mut Bytes) -> anyhow::Result<Bytes> {
|
||||
let pos = buf.iter().position(|x| *x == 0);
|
||||
let result = buf.split_to(pos.context("missing terminator")?);
|
||||
/// Read cstring from buf, advancing it.
|
||||
fn read_cstr(buf: &mut Bytes) -> Result<Bytes, ProtocolError> {
|
||||
let pos = buf
|
||||
.iter()
|
||||
.position(|x| *x == 0)
|
||||
.ok_or_else(|| ProtocolError::BadMessage("missing cstring terminator".to_owned()))?;
|
||||
let result = buf.split_to(pos);
|
||||
buf.advance(1); // drop the null terminator
|
||||
Ok(result)
|
||||
}
|
||||
@@ -658,12 +677,12 @@ pub const SQLSTATE_INTERNAL_ERROR: &[u8; 5] = b"XX000";
|
||||
pub const SQLSTATE_SUCCESSFUL_COMPLETION: &[u8; 5] = b"00000";
|
||||
|
||||
impl<'a> BeMessage<'a> {
|
||||
/// Write message to the given buf.
|
||||
// Unlike the reading side, we use BytesMut
|
||||
// here as msg len precedes its body and it is handy to write it down first
|
||||
// and then fill the length. With Write we would have to either calc it
|
||||
// manually or have one more buffer.
|
||||
pub fn write(buf: &mut BytesMut, message: &BeMessage) -> io::Result<()> {
|
||||
/// Serialize `message` to the given `buf`.
|
||||
/// Apart from smart memory managemet, BytesMut is good here as msg len
|
||||
/// precedes its body and it is handy to write it down first and then fill
|
||||
/// the length. With Write we would have to either calc it manually or have
|
||||
/// one more buffer.
|
||||
pub fn write(buf: &mut BytesMut, message: &BeMessage) -> Result<(), ProtocolError> {
|
||||
match message {
|
||||
BeMessage::AuthenticationOk => {
|
||||
buf.put_u8(b'R');
|
||||
@@ -708,7 +727,7 @@ impl<'a> BeMessage<'a> {
|
||||
buf.put_slice(extra);
|
||||
}
|
||||
}
|
||||
Ok::<_, io::Error>(())
|
||||
Ok(())
|
||||
})?;
|
||||
}
|
||||
|
||||
@@ -812,7 +831,7 @@ impl<'a> BeMessage<'a> {
|
||||
write_cstr(error_msg, buf)?;
|
||||
|
||||
buf.put_u8(0); // terminator
|
||||
Ok::<_, io::Error>(())
|
||||
Ok(())
|
||||
})?;
|
||||
}
|
||||
|
||||
@@ -835,7 +854,7 @@ impl<'a> BeMessage<'a> {
|
||||
write_cstr(error_msg.as_bytes(), buf)?;
|
||||
|
||||
buf.put_u8(0); // terminator
|
||||
Ok::<_, io::Error>(())
|
||||
Ok(())
|
||||
})?;
|
||||
}
|
||||
|
||||
@@ -890,7 +909,7 @@ impl<'a> BeMessage<'a> {
|
||||
buf.put_i32(-1); /* typmod */
|
||||
buf.put_i16(0); /* format code */
|
||||
}
|
||||
Ok::<_, io::Error>(())
|
||||
Ok(())
|
||||
})?;
|
||||
}
|
||||
|
||||
@@ -957,7 +976,7 @@ impl ReplicationFeedback {
|
||||
// null-terminated string - key,
|
||||
// uint32 - value length in bytes
|
||||
// value itself
|
||||
pub fn serialize(&self, buf: &mut BytesMut) -> Result<()> {
|
||||
pub fn serialize(&self, buf: &mut BytesMut) {
|
||||
buf.put_u8(REPLICATION_FEEDBACK_FIELDS_NUMBER); // # of keys
|
||||
buf.put_slice(b"current_timeline_size\0");
|
||||
buf.put_i32(8);
|
||||
@@ -982,7 +1001,6 @@ impl ReplicationFeedback {
|
||||
buf.put_slice(b"ps_replytime\0");
|
||||
buf.put_i32(8);
|
||||
buf.put_i64(timestamp);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Deserialize ReplicationFeedback message
|
||||
@@ -1050,7 +1068,7 @@ mod tests {
|
||||
// because it is rounded up to microseconds during serialization.
|
||||
rf.ps_replytime = *PG_EPOCH + Duration::from_secs(100_000_000);
|
||||
let mut data = BytesMut::new();
|
||||
rf.serialize(&mut data).unwrap();
|
||||
rf.serialize(&mut data);
|
||||
|
||||
let rf_parsed = ReplicationFeedback::parse(data.freeze());
|
||||
assert_eq!(rf, rf_parsed);
|
||||
@@ -1065,7 +1083,7 @@ mod tests {
|
||||
// because it is rounded up to microseconds during serialization.
|
||||
rf.ps_replytime = *PG_EPOCH + Duration::from_secs(100_000_000);
|
||||
let mut data = BytesMut::new();
|
||||
rf.serialize(&mut data).unwrap();
|
||||
rf.serialize(&mut data);
|
||||
|
||||
// Add an extra field to the buffer and adjust number of keys
|
||||
if let Some(first) = data.first_mut() {
|
||||
|
||||
@@ -21,7 +21,7 @@ use pageserver_api::models::{
|
||||
PagestreamNblocksRequest, PagestreamNblocksResponse,
|
||||
};
|
||||
use postgres_backend::{self, is_expected_io_error, AuthType, PostgresBackend, QueryError};
|
||||
use pq_proto::ConnectionError;
|
||||
use pq_proto::framed::ConnectionError;
|
||||
use pq_proto::FeStartupPacket;
|
||||
use pq_proto::{BeMessage, FeMessage, RowDescriptor};
|
||||
use std::io;
|
||||
@@ -78,7 +78,7 @@ fn copyin_stream(pgb: &mut PostgresBackend) -> impl Stream<Item = io::Result<Byt
|
||||
FeMessage::Sync => continue,
|
||||
FeMessage::Terminate => {
|
||||
let msg = "client terminated connection with Terminate message during COPY";
|
||||
let query_error = QueryError::Disconnected(ConnectionError::Socket(io::Error::new(io::ErrorKind::ConnectionReset, msg)));
|
||||
let query_error = QueryError::Disconnected(ConnectionError::Io(io::Error::new(io::ErrorKind::ConnectionReset, msg)));
|
||||
// error can't happen here, ErrorResponse serialization should be always ok
|
||||
pgb.write_message_noflush(&BeMessage::ErrorResponse(msg, Some(query_error.pg_error_code()))).map_err(|e| e.into_io_error())?;
|
||||
Err(io::Error::new(io::ErrorKind::ConnectionReset, msg))?;
|
||||
@@ -97,13 +97,13 @@ fn copyin_stream(pgb: &mut PostgresBackend) -> impl Stream<Item = io::Result<Byt
|
||||
}
|
||||
Ok(None) => {
|
||||
let msg = "client closed connection during COPY";
|
||||
let query_error = QueryError::Disconnected(ConnectionError::Socket(io::Error::new(io::ErrorKind::ConnectionReset, msg)));
|
||||
let query_error = QueryError::Disconnected(ConnectionError::Io(io::Error::new(io::ErrorKind::ConnectionReset, msg)));
|
||||
// error can't happen here, ErrorResponse serialization should be always ok
|
||||
pgb.write_message_noflush(&BeMessage::ErrorResponse(msg, Some(query_error.pg_error_code()))).map_err(|e| e.into_io_error())?;
|
||||
pgb.flush().await?;
|
||||
Err(io::Error::new(io::ErrorKind::ConnectionReset, msg))?;
|
||||
}
|
||||
Err(QueryError::Disconnected(ConnectionError::Socket(io_error))) => {
|
||||
Err(QueryError::Disconnected(ConnectionError::Io(io_error))) => {
|
||||
Err(io_error)?;
|
||||
}
|
||||
Err(other) => {
|
||||
@@ -214,7 +214,7 @@ async fn page_service_conn_main(
|
||||
// we've been requested to shut down
|
||||
Ok(())
|
||||
}
|
||||
Err(QueryError::Disconnected(ConnectionError::Socket(io_error))) => {
|
||||
Err(QueryError::Disconnected(ConnectionError::Io(io_error))) => {
|
||||
if is_expected_io_error(&io_error) {
|
||||
info!("Postgres client disconnected ({io_error})");
|
||||
Ok(())
|
||||
@@ -1057,7 +1057,7 @@ impl From<GetActiveTenantError> for QueryError {
|
||||
fn from(e: GetActiveTenantError) -> Self {
|
||||
match e {
|
||||
GetActiveTenantError::WaitForActiveTimeout { .. } => QueryError::Disconnected(
|
||||
ConnectionError::Socket(io::Error::new(io::ErrorKind::TimedOut, e.to_string())),
|
||||
ConnectionError::Io(io::Error::new(io::ErrorKind::TimedOut, e.to_string())),
|
||||
),
|
||||
GetActiveTenantError::Other(e) => QueryError::Other(e),
|
||||
}
|
||||
|
||||
@@ -354,7 +354,7 @@ pub async fn handle_walreceiver_connection(
|
||||
debug!("neon_status_update {status_update:?}");
|
||||
|
||||
let mut data = BytesMut::new();
|
||||
status_update.serialize(&mut data)?;
|
||||
status_update.serialize(&mut data);
|
||||
physical_stream
|
||||
.as_mut()
|
||||
.zenith_status_update(data.len() as u64, &data)
|
||||
|
||||
@@ -1,45 +1,40 @@
|
||||
use crate::error::UserFacingError;
|
||||
use anyhow::bail;
|
||||
use bytes::BytesMut;
|
||||
use pin_project_lite::pin_project;
|
||||
use pq_proto::{BeMessage, ConnectionError, FeMessage, FeStartupPacket};
|
||||
use pq_proto::framed::{ConnectionError, Framed};
|
||||
use pq_proto::{BeMessage, FeMessage, FeStartupPacket, ProtocolError};
|
||||
use rustls::ServerConfig;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::{io, task};
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf};
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
use tokio_rustls::server::TlsStream;
|
||||
|
||||
pin_project! {
|
||||
/// Stream wrapper which implements libpq's protocol.
|
||||
/// NOTE: This object deliberately doesn't implement [`AsyncRead`]
|
||||
/// or [`AsyncWrite`] to prevent subtle errors (e.g. trying
|
||||
/// to pass random malformed bytes through the connection).
|
||||
pub struct PqStream<S> {
|
||||
#[pin]
|
||||
stream: S,
|
||||
buffer: BytesMut,
|
||||
}
|
||||
/// Stream wrapper which implements libpq's protocol.
|
||||
/// NOTE: This object deliberately doesn't implement [`AsyncRead`]
|
||||
/// or [`AsyncWrite`] to prevent subtle errors (e.g. trying
|
||||
/// to pass random malformed bytes through the connection).
|
||||
pub struct PqStream<S> {
|
||||
framed: Framed<S>,
|
||||
}
|
||||
|
||||
impl<S> PqStream<S> {
|
||||
/// Construct a new libpq protocol wrapper.
|
||||
pub fn new(stream: S) -> Self {
|
||||
Self {
|
||||
stream,
|
||||
buffer: Default::default(),
|
||||
framed: Framed::new(stream),
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract the underlying stream.
|
||||
pub fn into_inner(self) -> S {
|
||||
self.stream
|
||||
self.framed.into_inner()
|
||||
}
|
||||
|
||||
/// Get a shared reference to the underlying stream.
|
||||
pub fn get_ref(&self) -> &S {
|
||||
&self.stream
|
||||
self.framed.get_ref()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -50,16 +45,19 @@ fn err_connection() -> io::Error {
|
||||
impl<S: AsyncRead + Unpin> PqStream<S> {
|
||||
/// Receive [`FeStartupPacket`], which is a first packet sent by a client.
|
||||
pub async fn read_startup_packet(&mut self) -> io::Result<FeStartupPacket> {
|
||||
// TODO: `FeStartupPacket::read_fut` should return `FeStartupPacket`
|
||||
let msg = FeStartupPacket::read(&mut self.stream)
|
||||
self.framed
|
||||
.read_startup_message()
|
||||
.await
|
||||
.map_err(ConnectionError::into_io_error)?
|
||||
.ok_or_else(err_connection)?;
|
||||
.ok_or_else(err_connection)
|
||||
}
|
||||
|
||||
match msg {
|
||||
FeMessage::StartupPacket(packet) => Ok(packet),
|
||||
_ => panic!("unreachable state"),
|
||||
}
|
||||
async fn read_message(&mut self) -> io::Result<FeMessage> {
|
||||
self.framed
|
||||
.read_message()
|
||||
.await
|
||||
.map_err(ConnectionError::into_io_error)?
|
||||
.ok_or_else(err_connection)
|
||||
}
|
||||
|
||||
pub async fn read_password_message(&mut self) -> io::Result<bytes::Bytes> {
|
||||
@@ -71,19 +69,14 @@ impl<S: AsyncRead + Unpin> PqStream<S> {
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn read_message(&mut self) -> io::Result<FeMessage> {
|
||||
FeMessage::read(&mut self.stream)
|
||||
.await
|
||||
.map_err(ConnectionError::into_io_error)?
|
||||
.ok_or_else(err_connection)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncWrite + Unpin> PqStream<S> {
|
||||
/// Write the message into an internal buffer, but don't flush the underlying stream.
|
||||
pub fn write_message_noflush(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> {
|
||||
BeMessage::write(&mut self.buffer, message)?;
|
||||
self.framed
|
||||
.write_message(message)
|
||||
.map_err(ProtocolError::into_io_error)?;
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
@@ -96,9 +89,7 @@ impl<S: AsyncWrite + Unpin> PqStream<S> {
|
||||
|
||||
/// Flush the output buffer into the underlying stream.
|
||||
pub async fn flush(&mut self) -> io::Result<&mut Self> {
|
||||
self.stream.write_all(&self.buffer).await?;
|
||||
self.buffer.clear();
|
||||
self.stream.flush().await?;
|
||||
self.framed.flush().await?;
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
|
||||
@@ -488,7 +488,7 @@ impl AcceptorProposerMessage {
|
||||
buf.put_u64_le(msg.hs_feedback.xmin);
|
||||
buf.put_u64_le(msg.hs_feedback.catalog_xmin);
|
||||
|
||||
msg.pageserver_feedback.serialize(buf)?;
|
||||
msg.pageserver_feedback.serialize(buf);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user