mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-16 09:52:54 +00:00
Refactor postgres protocol parsing.
1) Remove allocation and data copy during each message read. Instead, parsing functions now accept BytesMut from which data they form messages, with pointers (e.g. in CopyData) pointing directly into BytesMut buffer. Accordingly, move ConnectionError containing IO error subtype into framed.rs providing this and leave in pq_proto only ProtocolError. 2) Remove anyhow from pq_proto. 3) Move FeStartupPacket out of FeMessage. Now FeStartupPacket::parse returns it directly, eliminating dead code where user wants startup packet but has to match for others. proxy stream.rs is adapted to framed.rs with minimal changes. It also benefits from framed.rs improvements described above.
This commit is contained in:
2
Cargo.lock
generated
2
Cargo.lock
generated
@@ -2747,7 +2747,7 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de"
|
|||||||
name = "pq_proto"
|
name = "pq_proto"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"anyhow",
|
"byteorder",
|
||||||
"bytes",
|
"bytes",
|
||||||
"pin-project-lite",
|
"pin-project-lite",
|
||||||
"postgres-protocol",
|
"postgres-protocol",
|
||||||
|
|||||||
@@ -17,9 +17,9 @@ use tokio::io::{AsyncRead, AsyncWrite};
|
|||||||
use tokio_rustls::TlsAcceptor;
|
use tokio_rustls::TlsAcceptor;
|
||||||
use tracing::{debug, error, info, trace};
|
use tracing::{debug, error, info, trace};
|
||||||
|
|
||||||
use pq_proto::framed::{Framed, FramedReader, FramedWriter};
|
use pq_proto::framed::{ConnectionError, Framed, FramedReader, FramedWriter};
|
||||||
use pq_proto::{
|
use pq_proto::{
|
||||||
BeMessage, ConnectionError, FeMessage, FeStartupPacket, SQLSTATE_INTERNAL_ERROR,
|
BeMessage, FeMessage, FeStartupPacket, ProtocolError, SQLSTATE_INTERNAL_ERROR,
|
||||||
SQLSTATE_SUCCESSFUL_COMPLETION,
|
SQLSTATE_SUCCESSFUL_COMPLETION,
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -37,7 +37,7 @@ pub enum QueryError {
|
|||||||
|
|
||||||
impl From<io::Error> for QueryError {
|
impl From<io::Error> for QueryError {
|
||||||
fn from(e: io::Error) -> Self {
|
fn from(e: io::Error) -> Self {
|
||||||
Self::Disconnected(ConnectionError::Socket(e))
|
Self::Disconnected(ConnectionError::Io(e))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -219,7 +219,7 @@ impl MaybeWriteOnly {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn write_message_noflush(&mut self, msg: &BeMessage<'_>) -> Result<(), ConnectionError> {
|
fn write_message_noflush(&mut self, msg: &BeMessage<'_>) -> Result<(), ProtocolError> {
|
||||||
match self {
|
match self {
|
||||||
MaybeWriteOnly::Full(framed) => framed.write_message(msg),
|
MaybeWriteOnly::Full(framed) => framed.write_message(msg),
|
||||||
MaybeWriteOnly::WriteOnly(framed_writer) => framed_writer.write_message_noflush(msg),
|
MaybeWriteOnly::WriteOnly(framed_writer) => framed_writer.write_message_noflush(msg),
|
||||||
@@ -701,8 +701,7 @@ impl PostgresBackend {
|
|||||||
FeMessage::CopyData(_)
|
FeMessage::CopyData(_)
|
||||||
| FeMessage::CopyDone
|
| FeMessage::CopyDone
|
||||||
| FeMessage::CopyFail
|
| FeMessage::CopyFail
|
||||||
| FeMessage::PasswordMessage(_)
|
| FeMessage::PasswordMessage(_) => {
|
||||||
| FeMessage::StartupPacket(_) => {
|
|
||||||
return Err(QueryError::Other(anyhow::anyhow!(
|
return Err(QueryError::Other(anyhow::anyhow!(
|
||||||
"unexpected message type: {msg:?}",
|
"unexpected message type: {msg:?}",
|
||||||
)));
|
)));
|
||||||
@@ -721,7 +720,7 @@ impl PostgresBackend {
|
|||||||
|
|
||||||
let expected_end = match &end {
|
let expected_end = match &end {
|
||||||
ServerInitiated(_) | CopyDone | CopyFail | Terminate | EOF => true,
|
ServerInitiated(_) | CopyDone | CopyFail | Terminate | EOF => true,
|
||||||
CopyStreamHandlerEnd::Disconnected(ConnectionError::Socket(io_error))
|
CopyStreamHandlerEnd::Disconnected(ConnectionError::Io(io_error))
|
||||||
if is_expected_io_error(io_error) =>
|
if is_expected_io_error(io_error) =>
|
||||||
{
|
{
|
||||||
true
|
true
|
||||||
@@ -800,7 +799,7 @@ impl PostgresBackendReader {
|
|||||||
FeMessage::CopyFail => Err(CopyStreamHandlerEnd::CopyFail),
|
FeMessage::CopyFail => Err(CopyStreamHandlerEnd::CopyFail),
|
||||||
FeMessage::Terminate => Err(CopyStreamHandlerEnd::Terminate),
|
FeMessage::Terminate => Err(CopyStreamHandlerEnd::Terminate),
|
||||||
_ => Err(CopyStreamHandlerEnd::from(ConnectionError::Protocol(
|
_ => Err(CopyStreamHandlerEnd::from(ConnectionError::Protocol(
|
||||||
format!("unexpected message in COPY stream {:?}", msg),
|
ProtocolError::Protocol(format!("unexpected message in COPY stream {:?}", msg)),
|
||||||
))),
|
))),
|
||||||
},
|
},
|
||||||
None => Err(CopyStreamHandlerEnd::EOF),
|
None => Err(CopyStreamHandlerEnd::EOF),
|
||||||
@@ -871,7 +870,7 @@ pub fn short_error(e: &QueryError) -> String {
|
|||||||
|
|
||||||
fn log_query_error(query: &str, e: &QueryError) {
|
fn log_query_error(query: &str, e: &QueryError) {
|
||||||
match e {
|
match e {
|
||||||
QueryError::Disconnected(ConnectionError::Socket(io_error)) => {
|
QueryError::Disconnected(ConnectionError::Io(io_error)) => {
|
||||||
if is_expected_io_error(io_error) {
|
if is_expected_io_error(io_error) {
|
||||||
info!("query handler for '{query}' failed with expected io error: {io_error}");
|
info!("query handler for '{query}' failed with expected io error: {io_error}");
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -5,8 +5,8 @@ edition.workspace = true
|
|||||||
license.workspace = true
|
license.workspace = true
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
anyhow.workspace = true
|
|
||||||
bytes.workspace = true
|
bytes.workspace = true
|
||||||
|
byteorder.workspace = true
|
||||||
pin-project-lite.workspace = true
|
pin-project-lite.workspace = true
|
||||||
postgres-protocol.workspace = true
|
postgres-protocol.workspace = true
|
||||||
rand.workspace = true
|
rand.workspace = true
|
||||||
|
|||||||
@@ -1,51 +1,84 @@
|
|||||||
//! Provides `Framed` -- writing/flushing and reading Postgres messages to/from
|
//! Provides `Framed` -- writing/flushing and reading Postgres messages to/from
|
||||||
//! the async stream.
|
//! the async stream based on (and buffered with) BytesMut. All functions are
|
||||||
|
//! cancellation safe.
|
||||||
|
//!
|
||||||
|
//! It is similar to what tokio_util::codec::Framed with appropriate codec
|
||||||
|
//! provides, but `FramedReader` and `FramedWriter` read/write parts can be used
|
||||||
|
//! separately without using split from futures::stream::StreamExt (which
|
||||||
|
//! allocates box[1] in polling internally). tokio::io::split is used for splitting
|
||||||
|
//! instead. Plus we customize error messages more than a single type for all io
|
||||||
|
//! calls.
|
||||||
|
//!
|
||||||
|
//! [1] https://docs.rs/futures-util/0.3.26/src/futures_util/lock/bilock.rs.html#107
|
||||||
use bytes::{Buf, BytesMut};
|
use bytes::{Buf, BytesMut};
|
||||||
use std::{
|
use std::{
|
||||||
future::Future,
|
future::Future,
|
||||||
io::{self, ErrorKind},
|
io::{self, ErrorKind},
|
||||||
};
|
};
|
||||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufReader, ReadHalf, WriteHalf};
|
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf};
|
||||||
|
|
||||||
use crate::{BeMessage, ConnectionError, FeMessage, FeStartupPacket};
|
use crate::{BeMessage, FeMessage, FeStartupPacket, ProtocolError};
|
||||||
|
|
||||||
const INITIAL_CAPACITY: usize = 8 * 1024;
|
const INITIAL_CAPACITY: usize = 8 * 1024;
|
||||||
|
|
||||||
|
/// Error on postgres connection: either IO (physical transport error) or
|
||||||
|
/// protocol violation.
|
||||||
|
#[derive(thiserror::Error, Debug)]
|
||||||
|
pub enum ConnectionError {
|
||||||
|
#[error(transparent)]
|
||||||
|
Io(#[from] io::Error),
|
||||||
|
#[error(transparent)]
|
||||||
|
Protocol(#[from] ProtocolError),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ConnectionError {
|
||||||
|
/// Proxy stream.rs uses only io::Error; provide it.
|
||||||
|
pub fn into_io_error(self) -> io::Error {
|
||||||
|
match self {
|
||||||
|
ConnectionError::Io(io) => io,
|
||||||
|
ConnectionError::Protocol(pe) => io::Error::new(io::ErrorKind::Other, pe.to_string()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Wraps async io `stream`, providing messages to write/flush + read Postgres
|
/// Wraps async io `stream`, providing messages to write/flush + read Postgres
|
||||||
/// messages.
|
/// messages.
|
||||||
pub struct Framed<S> {
|
pub struct Framed<S> {
|
||||||
stream: BufReader<S>,
|
stream: S,
|
||||||
|
read_buf: BytesMut,
|
||||||
write_buf: BytesMut,
|
write_buf: BytesMut,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<S: AsyncRead + Unpin> Framed<S> {
|
impl<S> Framed<S> {
|
||||||
pub fn new(stream: S) -> Self {
|
pub fn new(stream: S) -> Self {
|
||||||
Self {
|
Self {
|
||||||
stream: BufReader::new(stream),
|
stream,
|
||||||
|
read_buf: BytesMut::with_capacity(INITIAL_CAPACITY),
|
||||||
write_buf: BytesMut::with_capacity(INITIAL_CAPACITY),
|
write_buf: BytesMut::with_capacity(INITIAL_CAPACITY),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get a shared reference to the underlying stream.
|
/// Get a shared reference to the underlying stream.
|
||||||
pub fn get_ref(&self) -> &S {
|
pub fn get_ref(&self) -> &S {
|
||||||
self.stream.get_ref()
|
&self.stream
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Extract the underlying stream.
|
/// Extract the underlying stream.
|
||||||
pub fn into_inner(self) -> S {
|
pub fn into_inner(self) -> S {
|
||||||
self.stream.into_inner()
|
self.stream
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Return new Framed with stream type transformed by async f, for TLS
|
/// Return new Framed with stream type transformed by async f, for TLS
|
||||||
/// upgrade.
|
/// upgrade.
|
||||||
pub async fn map_stream<S2: AsyncRead, E, F, Fut>(self, f: F) -> Result<Framed<S2>, E>
|
pub async fn map_stream<S2, E, F, Fut>(self, f: F) -> Result<Framed<S2>, E>
|
||||||
where
|
where
|
||||||
F: FnOnce(S) -> Fut,
|
F: FnOnce(S) -> Fut,
|
||||||
Fut: Future<Output = Result<S2, E>>,
|
Fut: Future<Output = Result<S2, E>>,
|
||||||
{
|
{
|
||||||
let stream = f(self.stream.into_inner()).await?;
|
let stream = f(self.stream).await?;
|
||||||
Ok(Framed {
|
Ok(Framed {
|
||||||
stream: BufReader::new(stream),
|
stream,
|
||||||
|
read_buf: self.read_buf,
|
||||||
write_buf: self.write_buf,
|
write_buf: self.write_buf,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -55,24 +88,18 @@ impl<S: AsyncRead + Unpin> Framed<S> {
|
|||||||
pub async fn read_startup_message(
|
pub async fn read_startup_message(
|
||||||
&mut self,
|
&mut self,
|
||||||
) -> Result<Option<FeStartupPacket>, ConnectionError> {
|
) -> Result<Option<FeStartupPacket>, ConnectionError> {
|
||||||
let msg = FeStartupPacket::read(&mut self.stream).await?;
|
read_message(&mut self.stream, &mut self.read_buf, FeStartupPacket::parse).await
|
||||||
|
|
||||||
match msg {
|
|
||||||
Some(FeMessage::StartupPacket(packet)) => Ok(Some(packet)),
|
|
||||||
None => Ok(None),
|
|
||||||
_ => panic!("unreachable state"),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
|
pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
|
||||||
FeMessage::read(&mut self.stream).await
|
read_message(&mut self.stream, &mut self.read_buf, FeMessage::parse).await
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<S: AsyncWrite + AsyncRead + Unpin> Framed<S> {
|
impl<S: AsyncWrite + Unpin> Framed<S> {
|
||||||
/// Write next message to the output buffer; doesn't flush.
|
/// Write next message to the output buffer; doesn't flush.
|
||||||
pub fn write_message(&mut self, msg: &BeMessage<'_>) -> Result<(), ConnectionError> {
|
pub fn write_message(&mut self, msg: &BeMessage<'_>) -> Result<(), ProtocolError> {
|
||||||
BeMessage::write(&mut self.write_buf, msg).map_err(|e| e.into())
|
BeMessage::write(&mut self.write_buf, msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Flush out the buffer. This function is cancellation safe: it can be
|
/// Flush out the buffer. This function is cancellation safe: it can be
|
||||||
@@ -93,7 +120,10 @@ impl<S: AsyncRead + AsyncWrite + Unpin> Framed<S> {
|
|||||||
/// https://github.com/tokio-rs/tls/issues/40
|
/// https://github.com/tokio-rs/tls/issues/40
|
||||||
pub fn split(self) -> (FramedReader<S>, FramedWriter<S>) {
|
pub fn split(self) -> (FramedReader<S>, FramedWriter<S>) {
|
||||||
let (read_half, write_half) = tokio::io::split(self.stream);
|
let (read_half, write_half) = tokio::io::split(self.stream);
|
||||||
let reader = FramedReader { stream: read_half };
|
let reader = FramedReader {
|
||||||
|
stream: read_half,
|
||||||
|
read_buf: self.read_buf,
|
||||||
|
};
|
||||||
let writer = FramedWriter {
|
let writer = FramedWriter {
|
||||||
stream: write_half,
|
stream: write_half,
|
||||||
write_buf: self.write_buf,
|
write_buf: self.write_buf,
|
||||||
@@ -105,6 +135,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> Framed<S> {
|
|||||||
pub fn unsplit(reader: FramedReader<S>, writer: FramedWriter<S>) -> Self {
|
pub fn unsplit(reader: FramedReader<S>, writer: FramedWriter<S>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
stream: reader.stream.unsplit(writer.stream),
|
stream: reader.stream.unsplit(writer.stream),
|
||||||
|
read_buf: reader.read_buf,
|
||||||
write_buf: writer.write_buf,
|
write_buf: writer.write_buf,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -112,25 +143,26 @@ impl<S: AsyncRead + AsyncWrite + Unpin> Framed<S> {
|
|||||||
|
|
||||||
/// Read-only version of `Framed`.
|
/// Read-only version of `Framed`.
|
||||||
pub struct FramedReader<S> {
|
pub struct FramedReader<S> {
|
||||||
stream: ReadHalf<BufReader<S>>,
|
stream: ReadHalf<S>,
|
||||||
|
read_buf: BytesMut,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<S: AsyncRead + Unpin> FramedReader<S> {
|
impl<S: AsyncRead + Unpin> FramedReader<S> {
|
||||||
pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
|
pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
|
||||||
FeMessage::read(&mut self.stream).await
|
read_message(&mut self.stream, &mut self.read_buf, FeMessage::parse).await
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Write-only version of `Framed`.
|
/// Write-only version of `Framed`.
|
||||||
pub struct FramedWriter<S> {
|
pub struct FramedWriter<S> {
|
||||||
stream: WriteHalf<BufReader<S>>,
|
stream: WriteHalf<S>,
|
||||||
write_buf: BytesMut,
|
write_buf: BytesMut,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<S: AsyncWrite + AsyncRead + Unpin> FramedWriter<S> {
|
impl<S: AsyncWrite + Unpin> FramedWriter<S> {
|
||||||
/// Write next message to the output buffer; doesn't flush.
|
/// Write next message to the output buffer; doesn't flush.
|
||||||
pub fn write_message_noflush(&mut self, msg: &BeMessage<'_>) -> Result<(), ConnectionError> {
|
pub fn write_message_noflush(&mut self, msg: &BeMessage<'_>) -> Result<(), ProtocolError> {
|
||||||
BeMessage::write(&mut self.write_buf, msg).map_err(|e| e.into())
|
BeMessage::write(&mut self.write_buf, msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Flush out the buffer. This function is cancellation safe: it can be
|
/// Flush out the buffer. This function is cancellation safe: it can be
|
||||||
@@ -145,6 +177,43 @@ impl<S: AsyncWrite + AsyncRead + Unpin> FramedWriter<S> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Read next message from the stream. Returns Ok(None), if EOF happened and we
|
||||||
|
/// don't have remaining data in the buffer. This function is cancellation safe:
|
||||||
|
/// you can drop future which is not yet complete and finalize reading message
|
||||||
|
/// with the next call.
|
||||||
|
///
|
||||||
|
/// Parametrized to allow reading startup or usual message, having different
|
||||||
|
/// format.
|
||||||
|
async fn read_message<S: AsyncRead + Unpin, M, P>(
|
||||||
|
stream: &mut S,
|
||||||
|
read_buf: &mut BytesMut,
|
||||||
|
parse: P,
|
||||||
|
) -> Result<Option<M>, ConnectionError>
|
||||||
|
where
|
||||||
|
P: Fn(&mut BytesMut) -> Result<Option<M>, ProtocolError>,
|
||||||
|
{
|
||||||
|
loop {
|
||||||
|
if let Some(msg) = parse(read_buf)? {
|
||||||
|
return Ok(Some(msg));
|
||||||
|
}
|
||||||
|
// If we can't build a frame yet, try to read more data and try again.
|
||||||
|
// Make sure we've got room for at least one byte to read to ensure
|
||||||
|
// that we don't get a spurious 0 that looks like EOF.
|
||||||
|
read_buf.reserve(1);
|
||||||
|
if stream.read_buf(read_buf).await? == 0 {
|
||||||
|
if read_buf.has_remaining() {
|
||||||
|
return Err(io::Error::new(
|
||||||
|
ErrorKind::UnexpectedEof,
|
||||||
|
"EOF with unprocessed data in the buffer",
|
||||||
|
)
|
||||||
|
.into());
|
||||||
|
} else {
|
||||||
|
return Ok(None); // clean EOF
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
async fn flush<S: AsyncWrite + Unpin>(
|
async fn flush<S: AsyncWrite + Unpin>(
|
||||||
stream: &mut S,
|
stream: &mut S,
|
||||||
write_buf: &mut BytesMut,
|
write_buf: &mut BytesMut,
|
||||||
|
|||||||
@@ -4,19 +4,16 @@
|
|||||||
|
|
||||||
pub mod framed;
|
pub mod framed;
|
||||||
|
|
||||||
use anyhow::{ensure, Context, Result};
|
use byteorder::{BigEndian, ReadBytesExt};
|
||||||
use bytes::{Buf, BufMut, Bytes, BytesMut};
|
use bytes::{Buf, BufMut, Bytes, BytesMut};
|
||||||
use postgres_protocol::PG_EPOCH;
|
use postgres_protocol::PG_EPOCH;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::{
|
use std::{
|
||||||
borrow::Cow,
|
borrow::Cow,
|
||||||
collections::HashMap,
|
collections::HashMap,
|
||||||
fmt,
|
fmt, io, str,
|
||||||
io::{self, Cursor},
|
|
||||||
str,
|
|
||||||
time::{Duration, SystemTime},
|
time::{Duration, SystemTime},
|
||||||
};
|
};
|
||||||
use tokio::io::AsyncReadExt;
|
|
||||||
use tracing::{trace, warn};
|
use tracing::{trace, warn};
|
||||||
|
|
||||||
pub type Oid = u32;
|
pub type Oid = u32;
|
||||||
@@ -28,7 +25,6 @@ pub const TEXT_OID: Oid = 25;
|
|||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub enum FeMessage {
|
pub enum FeMessage {
|
||||||
StartupPacket(FeStartupPacket),
|
|
||||||
// Simple query.
|
// Simple query.
|
||||||
Query(Bytes),
|
Query(Bytes),
|
||||||
// Extended query protocol.
|
// Extended query protocol.
|
||||||
@@ -188,100 +184,90 @@ pub struct FeExecuteMessage {
|
|||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct FeCloseMessage;
|
pub struct FeCloseMessage;
|
||||||
|
|
||||||
/// Retry a read on EINTR
|
/// An error occured while parsing or serializing raw stream into Postgres
|
||||||
///
|
/// messages.
|
||||||
/// This runs the enclosed expression, and if it returns
|
|
||||||
/// Err(io::ErrorKind::Interrupted), retries it.
|
|
||||||
macro_rules! retry_read {
|
|
||||||
( $x:expr ) => {
|
|
||||||
loop {
|
|
||||||
match $x {
|
|
||||||
Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
|
|
||||||
res => break res,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
/// An error occured during connection being open.
|
|
||||||
#[derive(thiserror::Error, Debug)]
|
#[derive(thiserror::Error, Debug)]
|
||||||
pub enum ConnectionError {
|
pub enum ProtocolError {
|
||||||
/// IO error during writing to or reading from the connection socket.
|
/// Invalid packet was received from the client (e.g. unexpected message
|
||||||
#[error("Socket IO error: {0}")]
|
/// type or broken len).
|
||||||
Socket(#[from] std::io::Error),
|
|
||||||
/// Invalid packet was received from client
|
|
||||||
#[error("Protocol error: {0}")]
|
#[error("Protocol error: {0}")]
|
||||||
Protocol(String),
|
Protocol(String),
|
||||||
/// Failed to parse a protocol mesage
|
/// Failed to parse or, (unlikely), serialize a protocol message.
|
||||||
#[error("Message parse error: {0}")]
|
#[error("Message parse error: {0}")]
|
||||||
MessageParse(anyhow::Error),
|
BadMessage(String),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<anyhow::Error> for ConnectionError {
|
impl ProtocolError {
|
||||||
fn from(e: anyhow::Error) -> Self {
|
/// Proxy stream.rs uses only io::Error; provide it.
|
||||||
Self::MessageParse(e)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ConnectionError {
|
|
||||||
pub fn into_io_error(self) -> io::Error {
|
pub fn into_io_error(self) -> io::Error {
|
||||||
match self {
|
io::Error::new(io::ErrorKind::Other, self.to_string())
|
||||||
ConnectionError::Socket(io) => io,
|
|
||||||
other => io::Error::new(io::ErrorKind::Other, other.to_string()),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl FeMessage {
|
impl FeMessage {
|
||||||
/// Read one message from the stream.
|
/// Read and parse one message from the `buf` input buffer. If there is at
|
||||||
/// This function returns `Ok(None)` in case of EOF.
|
/// least one valid message, returns it, advancing `buf`; redundant copies
|
||||||
pub async fn read<Reader>(stream: &mut Reader) -> Result<Option<FeMessage>, ConnectionError>
|
/// are avoided, as thanks to `bytes` crate ptrs in parsed message point
|
||||||
where
|
/// directly into the `buf` (processed data is garbage collected after
|
||||||
Reader: tokio::io::AsyncRead + Unpin,
|
/// parsed message is dropped).
|
||||||
{
|
///
|
||||||
// We return a Future that's sync (has a `wait` method) if and only if the provided stream is SyncProof.
|
/// Returns None if `buf` doesn't contain enough data for a single message.
|
||||||
// SyncFuture contract: we are only allowed to await on sync-proof futures, the AsyncRead and
|
/// For efficiency, tries to reserve large enough space in `buf` for the
|
||||||
// AsyncReadExt methods of the stream.
|
/// next message in this case to save the repeated calls.
|
||||||
// Each libpq message begins with a message type byte, followed by message length
|
///
|
||||||
// If the client closes the connection, return None. But if the client closes the
|
/// Returns Error if message is malformed, the only possible ErrorKind is
|
||||||
// connection in the middle of a message, we will return an error.
|
/// InvalidInput.
|
||||||
let tag = match retry_read!(stream.read_u8().await) {
|
//
|
||||||
Ok(b) => b,
|
// Inspired by rust-postgres Message::parse.
|
||||||
Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
|
pub fn parse(buf: &mut BytesMut) -> Result<Option<FeMessage>, ProtocolError> {
|
||||||
Err(e) => return Err(ConnectionError::Socket(e)),
|
// Every message contains message type byte and 4 bytes len; can't do
|
||||||
};
|
// much without them.
|
||||||
|
if buf.len() < 5 {
|
||||||
|
let to_read = 5 - buf.len();
|
||||||
|
buf.reserve(to_read);
|
||||||
|
return Ok(None);
|
||||||
|
}
|
||||||
|
|
||||||
// The message length includes itself, so it better be at least 4.
|
// We shouldn't advance `buf` as probably full message is not there yet,
|
||||||
let len = retry_read!(stream.read_u32().await)
|
// so can't directly use Bytes::get_u32 etc.
|
||||||
.map_err(ConnectionError::Socket)?
|
let tag = buf[0];
|
||||||
.checked_sub(4)
|
let len = (&buf[1..5]).read_u32::<BigEndian>().unwrap();
|
||||||
.ok_or_else(|| ConnectionError::Protocol("invalid message length".to_string()))?;
|
if len < 4 {
|
||||||
|
return Err(ProtocolError::Protocol(format!(
|
||||||
|
"invalid message length {}",
|
||||||
|
len
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
let body = {
|
// length field includes itself, but not message type.
|
||||||
let mut buffer = vec![0u8; len as usize];
|
let total_len = len as usize + 1;
|
||||||
stream
|
if buf.len() < total_len {
|
||||||
.read_exact(&mut buffer)
|
// Don't have full message yet.
|
||||||
.await
|
let to_read = total_len - buf.len();
|
||||||
.map_err(ConnectionError::Socket)?;
|
buf.reserve(to_read);
|
||||||
Bytes::from(buffer)
|
return Ok(None);
|
||||||
};
|
}
|
||||||
|
|
||||||
|
// got the message, advance buffer
|
||||||
|
let mut msg = buf.split_to(total_len).freeze();
|
||||||
|
msg.advance(5); // consume message type and len
|
||||||
|
|
||||||
match tag {
|
match tag {
|
||||||
b'Q' => Ok(Some(FeMessage::Query(body))),
|
b'Q' => Ok(Some(FeMessage::Query(msg))),
|
||||||
b'P' => Ok(Some(FeParseMessage::parse(body)?)),
|
b'P' => Ok(Some(FeParseMessage::parse(msg)?)),
|
||||||
b'D' => Ok(Some(FeDescribeMessage::parse(body)?)),
|
b'D' => Ok(Some(FeDescribeMessage::parse(msg)?)),
|
||||||
b'E' => Ok(Some(FeExecuteMessage::parse(body)?)),
|
b'E' => Ok(Some(FeExecuteMessage::parse(msg)?)),
|
||||||
b'B' => Ok(Some(FeBindMessage::parse(body)?)),
|
b'B' => Ok(Some(FeBindMessage::parse(msg)?)),
|
||||||
b'C' => Ok(Some(FeCloseMessage::parse(body)?)),
|
b'C' => Ok(Some(FeCloseMessage::parse(msg)?)),
|
||||||
b'S' => Ok(Some(FeMessage::Sync)),
|
b'S' => Ok(Some(FeMessage::Sync)),
|
||||||
b'X' => Ok(Some(FeMessage::Terminate)),
|
b'X' => Ok(Some(FeMessage::Terminate)),
|
||||||
b'd' => Ok(Some(FeMessage::CopyData(body))),
|
b'd' => Ok(Some(FeMessage::CopyData(msg))),
|
||||||
b'c' => Ok(Some(FeMessage::CopyDone)),
|
b'c' => Ok(Some(FeMessage::CopyDone)),
|
||||||
b'f' => Ok(Some(FeMessage::CopyFail)),
|
b'f' => Ok(Some(FeMessage::CopyFail)),
|
||||||
b'p' => Ok(Some(FeMessage::PasswordMessage(body))),
|
b'p' => Ok(Some(FeMessage::PasswordMessage(msg))),
|
||||||
tag => {
|
tag => {
|
||||||
return Err(ConnectionError::Protocol(format!(
|
return Err(ProtocolError::Protocol(format!(
|
||||||
"unknown message tag: {tag},'{body:?}'"
|
"unknown message tag: {tag},'{msg:?}'"
|
||||||
)))
|
)))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -289,60 +275,59 @@ impl FeMessage {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl FeStartupPacket {
|
impl FeStartupPacket {
|
||||||
/// Read startup message from the stream.
|
/// Read and parse startup message from the `buf` input buffer. It is
|
||||||
// XXX: It's tempting yet undesirable to accept `stream` by value,
|
/// different from [`FeMessage::parse`] because startup messages don't have
|
||||||
// since such a change will cause user-supplied &mut references to be consumed
|
/// message type byte; otherwise, its comments apply.
|
||||||
pub async fn read<Reader>(stream: &mut Reader) -> Result<Option<FeMessage>, ConnectionError>
|
pub fn parse(buf: &mut BytesMut) -> Result<Option<FeStartupPacket>, ProtocolError> {
|
||||||
where
|
|
||||||
Reader: tokio::io::AsyncRead + Unpin,
|
|
||||||
{
|
|
||||||
const MAX_STARTUP_PACKET_LENGTH: usize = 10000;
|
const MAX_STARTUP_PACKET_LENGTH: usize = 10000;
|
||||||
const RESERVED_INVALID_MAJOR_VERSION: u32 = 1234;
|
const RESERVED_INVALID_MAJOR_VERSION: u32 = 1234;
|
||||||
const CANCEL_REQUEST_CODE: u32 = 5678;
|
const CANCEL_REQUEST_CODE: u32 = 5678;
|
||||||
const NEGOTIATE_SSL_CODE: u32 = 5679;
|
const NEGOTIATE_SSL_CODE: u32 = 5679;
|
||||||
const NEGOTIATE_GSS_CODE: u32 = 5680;
|
const NEGOTIATE_GSS_CODE: u32 = 5680;
|
||||||
|
|
||||||
// Read length. If the connection is closed before reading anything (or before
|
// need at least 4 bytes with packet len
|
||||||
// reading 4 bytes, to be precise), return None to indicate that the connection
|
if buf.len() < 4 {
|
||||||
// was closed. This matches the PostgreSQL server's behavior, which avoids noise
|
let to_read = 4 - buf.len();
|
||||||
// in the log if the client opens connection but closes it immediately.
|
buf.reserve(to_read);
|
||||||
let len = match retry_read!(stream.read_u32().await) {
|
return Ok(None);
|
||||||
Ok(len) => len as usize,
|
}
|
||||||
Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
|
|
||||||
Err(e) => return Err(ConnectionError::Socket(e)),
|
|
||||||
};
|
|
||||||
|
|
||||||
#[allow(clippy::manual_range_contains)]
|
// We shouldn't advance `buf` as probably full message is not there yet,
|
||||||
|
// so can't directly use Bytes::get_u32 etc.
|
||||||
|
let len = (&buf[0..4]).read_u32::<BigEndian>().unwrap() as usize;
|
||||||
if len < 4 || len > MAX_STARTUP_PACKET_LENGTH {
|
if len < 4 || len > MAX_STARTUP_PACKET_LENGTH {
|
||||||
return Err(ConnectionError::Protocol(format!(
|
return Err(ProtocolError::Protocol(format!(
|
||||||
"invalid message length {len}"
|
"invalid startup packet message length {}",
|
||||||
|
len
|
||||||
)));
|
)));
|
||||||
}
|
}
|
||||||
|
|
||||||
let request_code = retry_read!(stream.read_u32().await).map_err(ConnectionError::Socket)?;
|
if buf.len() < len {
|
||||||
|
// Don't have full message yet.
|
||||||
|
let to_read = len - buf.len();
|
||||||
|
buf.reserve(to_read);
|
||||||
|
return Ok(None);
|
||||||
|
}
|
||||||
|
|
||||||
// the rest of startup packet are params
|
// got the message, advance buffer
|
||||||
let params_len = len - 8;
|
let mut msg = buf.split_to(len).freeze();
|
||||||
let mut params_bytes = vec![0u8; params_len];
|
msg.advance(4); // consume len
|
||||||
stream
|
|
||||||
.read_exact(params_bytes.as_mut())
|
|
||||||
.await
|
|
||||||
.map_err(ConnectionError::Socket)?;
|
|
||||||
|
|
||||||
// Parse params depending on request code
|
let request_code = msg.get_u32();
|
||||||
let req_hi = request_code >> 16;
|
let req_hi = request_code >> 16;
|
||||||
let req_lo = request_code & ((1 << 16) - 1);
|
let req_lo = request_code & ((1 << 16) - 1);
|
||||||
|
// StartupMessage, CancelRequest, SSLRequest etc are differentiated by request code.
|
||||||
let message = match (req_hi, req_lo) {
|
let message = match (req_hi, req_lo) {
|
||||||
(RESERVED_INVALID_MAJOR_VERSION, CANCEL_REQUEST_CODE) => {
|
(RESERVED_INVALID_MAJOR_VERSION, CANCEL_REQUEST_CODE) => {
|
||||||
if params_len != 8 {
|
if msg.remaining() != 8 {
|
||||||
return Err(ConnectionError::Protocol(
|
return Err(ProtocolError::BadMessage(
|
||||||
"expected 8 bytes for CancelRequest params".to_string(),
|
"CancelRequest message is malformed, backend PID / secret key missing"
|
||||||
|
.to_owned(),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
let mut cursor = Cursor::new(params_bytes);
|
|
||||||
FeStartupPacket::CancelRequest(CancelKeyData {
|
FeStartupPacket::CancelRequest(CancelKeyData {
|
||||||
backend_pid: cursor.read_i32().await.map_err(ConnectionError::Socket)?,
|
backend_pid: msg.get_i32(),
|
||||||
cancel_key: cursor.read_i32().await.map_err(ConnectionError::Socket)?,
|
cancel_key: msg.get_i32(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
(RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_SSL_CODE) => {
|
(RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_SSL_CODE) => {
|
||||||
@@ -354,19 +339,23 @@ impl FeStartupPacket {
|
|||||||
FeStartupPacket::GssEncRequest
|
FeStartupPacket::GssEncRequest
|
||||||
}
|
}
|
||||||
(RESERVED_INVALID_MAJOR_VERSION, unrecognized_code) => {
|
(RESERVED_INVALID_MAJOR_VERSION, unrecognized_code) => {
|
||||||
return Err(ConnectionError::Protocol(format!(
|
return Err(ProtocolError::Protocol(format!(
|
||||||
"Unrecognized request code {unrecognized_code}"
|
"Unrecognized request code {unrecognized_code}"
|
||||||
)));
|
)));
|
||||||
}
|
}
|
||||||
// TODO bail if protocol major_version is not 3?
|
// TODO bail if protocol major_version is not 3?
|
||||||
(major_version, minor_version) => {
|
(major_version, minor_version) => {
|
||||||
|
// StartupMessage
|
||||||
|
|
||||||
// Parse pairs of null-terminated strings (key, value).
|
// Parse pairs of null-terminated strings (key, value).
|
||||||
// See `postgres: ProcessStartupPacket, build_startup_packet`.
|
// See `postgres: ProcessStartupPacket, build_startup_packet`.
|
||||||
let mut tokens = str::from_utf8(¶ms_bytes)
|
let mut tokens = str::from_utf8(&msg)
|
||||||
.context("StartupMessage params: invalid utf-8")?
|
.map_err(|_e| {
|
||||||
|
ProtocolError::BadMessage("StartupMessage params: invalid utf-8".to_owned())
|
||||||
|
})?
|
||||||
.strip_suffix('\0') // drop packet's own null
|
.strip_suffix('\0') // drop packet's own null
|
||||||
.ok_or_else(|| {
|
.ok_or_else(|| {
|
||||||
ConnectionError::Protocol(
|
ProtocolError::Protocol(
|
||||||
"StartupMessage params: missing null terminator".to_string(),
|
"StartupMessage params: missing null terminator".to_string(),
|
||||||
)
|
)
|
||||||
})?
|
})?
|
||||||
@@ -375,7 +364,7 @@ impl FeStartupPacket {
|
|||||||
let mut params = HashMap::new();
|
let mut params = HashMap::new();
|
||||||
while let Some(name) = tokens.next() {
|
while let Some(name) = tokens.next() {
|
||||||
let value = tokens.next().ok_or_else(|| {
|
let value = tokens.next().ok_or_else(|| {
|
||||||
ConnectionError::Protocol(
|
ProtocolError::Protocol(
|
||||||
"StartupMessage params: key without value".to_string(),
|
"StartupMessage params: key without value".to_string(),
|
||||||
)
|
)
|
||||||
})?;
|
})?;
|
||||||
@@ -390,13 +379,12 @@ impl FeStartupPacket {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
Ok(Some(message))
|
||||||
Ok(Some(FeMessage::StartupPacket(message)))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl FeParseMessage {
|
impl FeParseMessage {
|
||||||
fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
|
fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
|
||||||
// FIXME: the rust-postgres driver uses a named prepared statement
|
// FIXME: the rust-postgres driver uses a named prepared statement
|
||||||
// for copy_out(). We're not prepared to handle that correctly. For
|
// for copy_out(). We're not prepared to handle that correctly. For
|
||||||
// now, just ignore the statement name, assuming that the client never
|
// now, just ignore the statement name, assuming that the client never
|
||||||
@@ -404,55 +392,82 @@ impl FeParseMessage {
|
|||||||
|
|
||||||
let _pstmt_name = read_cstr(&mut buf)?;
|
let _pstmt_name = read_cstr(&mut buf)?;
|
||||||
let query_string = read_cstr(&mut buf)?;
|
let query_string = read_cstr(&mut buf)?;
|
||||||
|
if buf.remaining() < 2 {
|
||||||
|
return Err(ProtocolError::BadMessage(
|
||||||
|
"Parse message is malformed, nparams missing".to_string(),
|
||||||
|
));
|
||||||
|
}
|
||||||
let nparams = buf.get_i16();
|
let nparams = buf.get_i16();
|
||||||
|
|
||||||
ensure!(nparams == 0, "query params not implemented");
|
if nparams != 0 {
|
||||||
|
return Err(ProtocolError::BadMessage(
|
||||||
|
"query params not implemented".to_string(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
Ok(FeMessage::Parse(FeParseMessage { query_string }))
|
Ok(FeMessage::Parse(FeParseMessage { query_string }))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl FeDescribeMessage {
|
impl FeDescribeMessage {
|
||||||
fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
|
fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
|
||||||
let kind = buf.get_u8();
|
let kind = buf.get_u8();
|
||||||
let _pstmt_name = read_cstr(&mut buf)?;
|
let _pstmt_name = read_cstr(&mut buf)?;
|
||||||
|
|
||||||
// FIXME: see FeParseMessage::parse
|
// FIXME: see FeParseMessage::parse
|
||||||
ensure!(
|
if kind != b'S' {
|
||||||
kind == b'S',
|
return Err(ProtocolError::BadMessage(
|
||||||
"only prepared statemement Describe is implemented"
|
"only prepared statemement Describe is implemented".to_string(),
|
||||||
);
|
));
|
||||||
|
}
|
||||||
|
|
||||||
Ok(FeMessage::Describe(FeDescribeMessage { kind }))
|
Ok(FeMessage::Describe(FeDescribeMessage { kind }))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl FeExecuteMessage {
|
impl FeExecuteMessage {
|
||||||
fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
|
fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
|
||||||
let portal_name = read_cstr(&mut buf)?;
|
let portal_name = read_cstr(&mut buf)?;
|
||||||
|
if buf.remaining() < 4 {
|
||||||
|
return Err(ProtocolError::BadMessage(
|
||||||
|
"FeExecuteMessage message is malformed, maxrows missing".to_string(),
|
||||||
|
));
|
||||||
|
}
|
||||||
let maxrows = buf.get_i32();
|
let maxrows = buf.get_i32();
|
||||||
|
|
||||||
ensure!(portal_name.is_empty(), "named portals not implemented");
|
if !portal_name.is_empty() {
|
||||||
ensure!(maxrows == 0, "row limit in Execute message not implemented");
|
return Err(ProtocolError::BadMessage(
|
||||||
|
"named portals not implemented".to_string(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
if maxrows != 0 {
|
||||||
|
return Err(ProtocolError::BadMessage(
|
||||||
|
"row limit in Execute message not implemented".to_string(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
Ok(FeMessage::Execute(FeExecuteMessage { maxrows }))
|
Ok(FeMessage::Execute(FeExecuteMessage { maxrows }))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl FeBindMessage {
|
impl FeBindMessage {
|
||||||
fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
|
fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
|
||||||
let portal_name = read_cstr(&mut buf)?;
|
let portal_name = read_cstr(&mut buf)?;
|
||||||
let _pstmt_name = read_cstr(&mut buf)?;
|
let _pstmt_name = read_cstr(&mut buf)?;
|
||||||
|
|
||||||
// FIXME: see FeParseMessage::parse
|
// FIXME: see FeParseMessage::parse
|
||||||
ensure!(portal_name.is_empty(), "named portals not implemented");
|
if !portal_name.is_empty() {
|
||||||
|
return Err(ProtocolError::BadMessage(
|
||||||
|
"named portals not implemented".to_string(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
Ok(FeMessage::Bind(FeBindMessage))
|
Ok(FeMessage::Bind(FeBindMessage))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl FeCloseMessage {
|
impl FeCloseMessage {
|
||||||
fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
|
fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
|
||||||
let _kind = buf.get_u8();
|
let _kind = buf.get_u8();
|
||||||
let _pstmt_or_portal_name = read_cstr(&mut buf)?;
|
let _pstmt_or_portal_name = read_cstr(&mut buf)?;
|
||||||
|
|
||||||
@@ -481,6 +496,7 @@ pub enum BeMessage<'a> {
|
|||||||
CloseComplete,
|
CloseComplete,
|
||||||
// None means column is NULL
|
// None means column is NULL
|
||||||
DataRow(&'a [Option<&'a [u8]>]),
|
DataRow(&'a [Option<&'a [u8]>]),
|
||||||
|
// None errcode means internal_error will be sent.
|
||||||
ErrorResponse(&'a str, Option<&'a [u8; 5]>),
|
ErrorResponse(&'a str, Option<&'a [u8; 5]>),
|
||||||
/// Single byte - used in response to SSLRequest/GSSENCRequest.
|
/// Single byte - used in response to SSLRequest/GSSENCRequest.
|
||||||
EncryptionResponse(bool),
|
EncryptionResponse(bool),
|
||||||
@@ -594,7 +610,7 @@ impl RowDescriptor<'_> {
|
|||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct XLogDataBody<'a> {
|
pub struct XLogDataBody<'a> {
|
||||||
pub wal_start: u64,
|
pub wal_start: u64,
|
||||||
pub wal_end: u64,
|
pub wal_end: u64, // current end of WAL on the server
|
||||||
pub timestamp: i64,
|
pub timestamp: i64,
|
||||||
pub data: &'a [u8],
|
pub data: &'a [u8],
|
||||||
}
|
}
|
||||||
@@ -634,12 +650,11 @@ fn write_body<R>(buf: &mut BytesMut, f: impl FnOnce(&mut BytesMut) -> R) -> R {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Safe write of s into buf as cstring (String in the protocol).
|
/// Safe write of s into buf as cstring (String in the protocol).
|
||||||
fn write_cstr(s: impl AsRef<[u8]>, buf: &mut BytesMut) -> io::Result<()> {
|
fn write_cstr(s: impl AsRef<[u8]>, buf: &mut BytesMut) -> Result<(), ProtocolError> {
|
||||||
let bytes = s.as_ref();
|
let bytes = s.as_ref();
|
||||||
if bytes.contains(&0) {
|
if bytes.contains(&0) {
|
||||||
return Err(io::Error::new(
|
return Err(ProtocolError::BadMessage(
|
||||||
io::ErrorKind::InvalidInput,
|
"string contains embedded null".to_owned(),
|
||||||
"string contains embedded null",
|
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
buf.put_slice(bytes);
|
buf.put_slice(bytes);
|
||||||
@@ -647,9 +662,13 @@ fn write_cstr(s: impl AsRef<[u8]>, buf: &mut BytesMut) -> io::Result<()> {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn read_cstr(buf: &mut Bytes) -> anyhow::Result<Bytes> {
|
/// Read cstring from buf, advancing it.
|
||||||
let pos = buf.iter().position(|x| *x == 0);
|
fn read_cstr(buf: &mut Bytes) -> Result<Bytes, ProtocolError> {
|
||||||
let result = buf.split_to(pos.context("missing terminator")?);
|
let pos = buf
|
||||||
|
.iter()
|
||||||
|
.position(|x| *x == 0)
|
||||||
|
.ok_or_else(|| ProtocolError::BadMessage("missing cstring terminator".to_owned()))?;
|
||||||
|
let result = buf.split_to(pos);
|
||||||
buf.advance(1); // drop the null terminator
|
buf.advance(1); // drop the null terminator
|
||||||
Ok(result)
|
Ok(result)
|
||||||
}
|
}
|
||||||
@@ -658,12 +677,12 @@ pub const SQLSTATE_INTERNAL_ERROR: &[u8; 5] = b"XX000";
|
|||||||
pub const SQLSTATE_SUCCESSFUL_COMPLETION: &[u8; 5] = b"00000";
|
pub const SQLSTATE_SUCCESSFUL_COMPLETION: &[u8; 5] = b"00000";
|
||||||
|
|
||||||
impl<'a> BeMessage<'a> {
|
impl<'a> BeMessage<'a> {
|
||||||
/// Write message to the given buf.
|
/// Serialize `message` to the given `buf`.
|
||||||
// Unlike the reading side, we use BytesMut
|
/// Apart from smart memory managemet, BytesMut is good here as msg len
|
||||||
// here as msg len precedes its body and it is handy to write it down first
|
/// precedes its body and it is handy to write it down first and then fill
|
||||||
// and then fill the length. With Write we would have to either calc it
|
/// the length. With Write we would have to either calc it manually or have
|
||||||
// manually or have one more buffer.
|
/// one more buffer.
|
||||||
pub fn write(buf: &mut BytesMut, message: &BeMessage) -> io::Result<()> {
|
pub fn write(buf: &mut BytesMut, message: &BeMessage) -> Result<(), ProtocolError> {
|
||||||
match message {
|
match message {
|
||||||
BeMessage::AuthenticationOk => {
|
BeMessage::AuthenticationOk => {
|
||||||
buf.put_u8(b'R');
|
buf.put_u8(b'R');
|
||||||
@@ -708,7 +727,7 @@ impl<'a> BeMessage<'a> {
|
|||||||
buf.put_slice(extra);
|
buf.put_slice(extra);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok::<_, io::Error>(())
|
Ok(())
|
||||||
})?;
|
})?;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -812,7 +831,7 @@ impl<'a> BeMessage<'a> {
|
|||||||
write_cstr(error_msg, buf)?;
|
write_cstr(error_msg, buf)?;
|
||||||
|
|
||||||
buf.put_u8(0); // terminator
|
buf.put_u8(0); // terminator
|
||||||
Ok::<_, io::Error>(())
|
Ok(())
|
||||||
})?;
|
})?;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -835,7 +854,7 @@ impl<'a> BeMessage<'a> {
|
|||||||
write_cstr(error_msg.as_bytes(), buf)?;
|
write_cstr(error_msg.as_bytes(), buf)?;
|
||||||
|
|
||||||
buf.put_u8(0); // terminator
|
buf.put_u8(0); // terminator
|
||||||
Ok::<_, io::Error>(())
|
Ok(())
|
||||||
})?;
|
})?;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -890,7 +909,7 @@ impl<'a> BeMessage<'a> {
|
|||||||
buf.put_i32(-1); /* typmod */
|
buf.put_i32(-1); /* typmod */
|
||||||
buf.put_i16(0); /* format code */
|
buf.put_i16(0); /* format code */
|
||||||
}
|
}
|
||||||
Ok::<_, io::Error>(())
|
Ok(())
|
||||||
})?;
|
})?;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -957,7 +976,7 @@ impl ReplicationFeedback {
|
|||||||
// null-terminated string - key,
|
// null-terminated string - key,
|
||||||
// uint32 - value length in bytes
|
// uint32 - value length in bytes
|
||||||
// value itself
|
// value itself
|
||||||
pub fn serialize(&self, buf: &mut BytesMut) -> Result<()> {
|
pub fn serialize(&self, buf: &mut BytesMut) {
|
||||||
buf.put_u8(REPLICATION_FEEDBACK_FIELDS_NUMBER); // # of keys
|
buf.put_u8(REPLICATION_FEEDBACK_FIELDS_NUMBER); // # of keys
|
||||||
buf.put_slice(b"current_timeline_size\0");
|
buf.put_slice(b"current_timeline_size\0");
|
||||||
buf.put_i32(8);
|
buf.put_i32(8);
|
||||||
@@ -982,7 +1001,6 @@ impl ReplicationFeedback {
|
|||||||
buf.put_slice(b"ps_replytime\0");
|
buf.put_slice(b"ps_replytime\0");
|
||||||
buf.put_i32(8);
|
buf.put_i32(8);
|
||||||
buf.put_i64(timestamp);
|
buf.put_i64(timestamp);
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Deserialize ReplicationFeedback message
|
// Deserialize ReplicationFeedback message
|
||||||
@@ -1050,7 +1068,7 @@ mod tests {
|
|||||||
// because it is rounded up to microseconds during serialization.
|
// because it is rounded up to microseconds during serialization.
|
||||||
rf.ps_replytime = *PG_EPOCH + Duration::from_secs(100_000_000);
|
rf.ps_replytime = *PG_EPOCH + Duration::from_secs(100_000_000);
|
||||||
let mut data = BytesMut::new();
|
let mut data = BytesMut::new();
|
||||||
rf.serialize(&mut data).unwrap();
|
rf.serialize(&mut data);
|
||||||
|
|
||||||
let rf_parsed = ReplicationFeedback::parse(data.freeze());
|
let rf_parsed = ReplicationFeedback::parse(data.freeze());
|
||||||
assert_eq!(rf, rf_parsed);
|
assert_eq!(rf, rf_parsed);
|
||||||
@@ -1065,7 +1083,7 @@ mod tests {
|
|||||||
// because it is rounded up to microseconds during serialization.
|
// because it is rounded up to microseconds during serialization.
|
||||||
rf.ps_replytime = *PG_EPOCH + Duration::from_secs(100_000_000);
|
rf.ps_replytime = *PG_EPOCH + Duration::from_secs(100_000_000);
|
||||||
let mut data = BytesMut::new();
|
let mut data = BytesMut::new();
|
||||||
rf.serialize(&mut data).unwrap();
|
rf.serialize(&mut data);
|
||||||
|
|
||||||
// Add an extra field to the buffer and adjust number of keys
|
// Add an extra field to the buffer and adjust number of keys
|
||||||
if let Some(first) = data.first_mut() {
|
if let Some(first) = data.first_mut() {
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ use pageserver_api::models::{
|
|||||||
PagestreamNblocksRequest, PagestreamNblocksResponse,
|
PagestreamNblocksRequest, PagestreamNblocksResponse,
|
||||||
};
|
};
|
||||||
use postgres_backend::{self, is_expected_io_error, AuthType, PostgresBackend, QueryError};
|
use postgres_backend::{self, is_expected_io_error, AuthType, PostgresBackend, QueryError};
|
||||||
use pq_proto::ConnectionError;
|
use pq_proto::framed::ConnectionError;
|
||||||
use pq_proto::FeStartupPacket;
|
use pq_proto::FeStartupPacket;
|
||||||
use pq_proto::{BeMessage, FeMessage, RowDescriptor};
|
use pq_proto::{BeMessage, FeMessage, RowDescriptor};
|
||||||
use std::io;
|
use std::io;
|
||||||
@@ -78,7 +78,7 @@ fn copyin_stream(pgb: &mut PostgresBackend) -> impl Stream<Item = io::Result<Byt
|
|||||||
FeMessage::Sync => continue,
|
FeMessage::Sync => continue,
|
||||||
FeMessage::Terminate => {
|
FeMessage::Terminate => {
|
||||||
let msg = "client terminated connection with Terminate message during COPY";
|
let msg = "client terminated connection with Terminate message during COPY";
|
||||||
let query_error = QueryError::Disconnected(ConnectionError::Socket(io::Error::new(io::ErrorKind::ConnectionReset, msg)));
|
let query_error = QueryError::Disconnected(ConnectionError::Io(io::Error::new(io::ErrorKind::ConnectionReset, msg)));
|
||||||
// error can't happen here, ErrorResponse serialization should be always ok
|
// error can't happen here, ErrorResponse serialization should be always ok
|
||||||
pgb.write_message_noflush(&BeMessage::ErrorResponse(msg, Some(query_error.pg_error_code()))).map_err(|e| e.into_io_error())?;
|
pgb.write_message_noflush(&BeMessage::ErrorResponse(msg, Some(query_error.pg_error_code()))).map_err(|e| e.into_io_error())?;
|
||||||
Err(io::Error::new(io::ErrorKind::ConnectionReset, msg))?;
|
Err(io::Error::new(io::ErrorKind::ConnectionReset, msg))?;
|
||||||
@@ -97,13 +97,13 @@ fn copyin_stream(pgb: &mut PostgresBackend) -> impl Stream<Item = io::Result<Byt
|
|||||||
}
|
}
|
||||||
Ok(None) => {
|
Ok(None) => {
|
||||||
let msg = "client closed connection during COPY";
|
let msg = "client closed connection during COPY";
|
||||||
let query_error = QueryError::Disconnected(ConnectionError::Socket(io::Error::new(io::ErrorKind::ConnectionReset, msg)));
|
let query_error = QueryError::Disconnected(ConnectionError::Io(io::Error::new(io::ErrorKind::ConnectionReset, msg)));
|
||||||
// error can't happen here, ErrorResponse serialization should be always ok
|
// error can't happen here, ErrorResponse serialization should be always ok
|
||||||
pgb.write_message_noflush(&BeMessage::ErrorResponse(msg, Some(query_error.pg_error_code()))).map_err(|e| e.into_io_error())?;
|
pgb.write_message_noflush(&BeMessage::ErrorResponse(msg, Some(query_error.pg_error_code()))).map_err(|e| e.into_io_error())?;
|
||||||
pgb.flush().await?;
|
pgb.flush().await?;
|
||||||
Err(io::Error::new(io::ErrorKind::ConnectionReset, msg))?;
|
Err(io::Error::new(io::ErrorKind::ConnectionReset, msg))?;
|
||||||
}
|
}
|
||||||
Err(QueryError::Disconnected(ConnectionError::Socket(io_error))) => {
|
Err(QueryError::Disconnected(ConnectionError::Io(io_error))) => {
|
||||||
Err(io_error)?;
|
Err(io_error)?;
|
||||||
}
|
}
|
||||||
Err(other) => {
|
Err(other) => {
|
||||||
@@ -214,7 +214,7 @@ async fn page_service_conn_main(
|
|||||||
// we've been requested to shut down
|
// we've been requested to shut down
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
Err(QueryError::Disconnected(ConnectionError::Socket(io_error))) => {
|
Err(QueryError::Disconnected(ConnectionError::Io(io_error))) => {
|
||||||
if is_expected_io_error(&io_error) {
|
if is_expected_io_error(&io_error) {
|
||||||
info!("Postgres client disconnected ({io_error})");
|
info!("Postgres client disconnected ({io_error})");
|
||||||
Ok(())
|
Ok(())
|
||||||
@@ -1057,7 +1057,7 @@ impl From<GetActiveTenantError> for QueryError {
|
|||||||
fn from(e: GetActiveTenantError) -> Self {
|
fn from(e: GetActiveTenantError) -> Self {
|
||||||
match e {
|
match e {
|
||||||
GetActiveTenantError::WaitForActiveTimeout { .. } => QueryError::Disconnected(
|
GetActiveTenantError::WaitForActiveTimeout { .. } => QueryError::Disconnected(
|
||||||
ConnectionError::Socket(io::Error::new(io::ErrorKind::TimedOut, e.to_string())),
|
ConnectionError::Io(io::Error::new(io::ErrorKind::TimedOut, e.to_string())),
|
||||||
),
|
),
|
||||||
GetActiveTenantError::Other(e) => QueryError::Other(e),
|
GetActiveTenantError::Other(e) => QueryError::Other(e),
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -354,7 +354,7 @@ pub async fn handle_walreceiver_connection(
|
|||||||
debug!("neon_status_update {status_update:?}");
|
debug!("neon_status_update {status_update:?}");
|
||||||
|
|
||||||
let mut data = BytesMut::new();
|
let mut data = BytesMut::new();
|
||||||
status_update.serialize(&mut data)?;
|
status_update.serialize(&mut data);
|
||||||
physical_stream
|
physical_stream
|
||||||
.as_mut()
|
.as_mut()
|
||||||
.zenith_status_update(data.len() as u64, &data)
|
.zenith_status_update(data.len() as u64, &data)
|
||||||
|
|||||||
@@ -1,45 +1,40 @@
|
|||||||
use crate::error::UserFacingError;
|
use crate::error::UserFacingError;
|
||||||
use anyhow::bail;
|
use anyhow::bail;
|
||||||
use bytes::BytesMut;
|
|
||||||
use pin_project_lite::pin_project;
|
use pin_project_lite::pin_project;
|
||||||
use pq_proto::{BeMessage, ConnectionError, FeMessage, FeStartupPacket};
|
use pq_proto::framed::{ConnectionError, Framed};
|
||||||
|
use pq_proto::{BeMessage, FeMessage, FeStartupPacket, ProtocolError};
|
||||||
use rustls::ServerConfig;
|
use rustls::ServerConfig;
|
||||||
use std::pin::Pin;
|
use std::pin::Pin;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::{io, task};
|
use std::{io, task};
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf};
|
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||||
use tokio_rustls::server::TlsStream;
|
use tokio_rustls::server::TlsStream;
|
||||||
|
|
||||||
pin_project! {
|
/// Stream wrapper which implements libpq's protocol.
|
||||||
/// Stream wrapper which implements libpq's protocol.
|
/// NOTE: This object deliberately doesn't implement [`AsyncRead`]
|
||||||
/// NOTE: This object deliberately doesn't implement [`AsyncRead`]
|
/// or [`AsyncWrite`] to prevent subtle errors (e.g. trying
|
||||||
/// or [`AsyncWrite`] to prevent subtle errors (e.g. trying
|
/// to pass random malformed bytes through the connection).
|
||||||
/// to pass random malformed bytes through the connection).
|
pub struct PqStream<S> {
|
||||||
pub struct PqStream<S> {
|
framed: Framed<S>,
|
||||||
#[pin]
|
|
||||||
stream: S,
|
|
||||||
buffer: BytesMut,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<S> PqStream<S> {
|
impl<S> PqStream<S> {
|
||||||
/// Construct a new libpq protocol wrapper.
|
/// Construct a new libpq protocol wrapper.
|
||||||
pub fn new(stream: S) -> Self {
|
pub fn new(stream: S) -> Self {
|
||||||
Self {
|
Self {
|
||||||
stream,
|
framed: Framed::new(stream),
|
||||||
buffer: Default::default(),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Extract the underlying stream.
|
/// Extract the underlying stream.
|
||||||
pub fn into_inner(self) -> S {
|
pub fn into_inner(self) -> S {
|
||||||
self.stream
|
self.framed.into_inner()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get a shared reference to the underlying stream.
|
/// Get a shared reference to the underlying stream.
|
||||||
pub fn get_ref(&self) -> &S {
|
pub fn get_ref(&self) -> &S {
|
||||||
&self.stream
|
self.framed.get_ref()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -50,16 +45,19 @@ fn err_connection() -> io::Error {
|
|||||||
impl<S: AsyncRead + Unpin> PqStream<S> {
|
impl<S: AsyncRead + Unpin> PqStream<S> {
|
||||||
/// Receive [`FeStartupPacket`], which is a first packet sent by a client.
|
/// Receive [`FeStartupPacket`], which is a first packet sent by a client.
|
||||||
pub async fn read_startup_packet(&mut self) -> io::Result<FeStartupPacket> {
|
pub async fn read_startup_packet(&mut self) -> io::Result<FeStartupPacket> {
|
||||||
// TODO: `FeStartupPacket::read_fut` should return `FeStartupPacket`
|
self.framed
|
||||||
let msg = FeStartupPacket::read(&mut self.stream)
|
.read_startup_message()
|
||||||
.await
|
.await
|
||||||
.map_err(ConnectionError::into_io_error)?
|
.map_err(ConnectionError::into_io_error)?
|
||||||
.ok_or_else(err_connection)?;
|
.ok_or_else(err_connection)
|
||||||
|
}
|
||||||
|
|
||||||
match msg {
|
async fn read_message(&mut self) -> io::Result<FeMessage> {
|
||||||
FeMessage::StartupPacket(packet) => Ok(packet),
|
self.framed
|
||||||
_ => panic!("unreachable state"),
|
.read_message()
|
||||||
}
|
.await
|
||||||
|
.map_err(ConnectionError::into_io_error)?
|
||||||
|
.ok_or_else(err_connection)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn read_password_message(&mut self) -> io::Result<bytes::Bytes> {
|
pub async fn read_password_message(&mut self) -> io::Result<bytes::Bytes> {
|
||||||
@@ -71,19 +69,14 @@ impl<S: AsyncRead + Unpin> PqStream<S> {
|
|||||||
)),
|
)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn read_message(&mut self) -> io::Result<FeMessage> {
|
|
||||||
FeMessage::read(&mut self.stream)
|
|
||||||
.await
|
|
||||||
.map_err(ConnectionError::into_io_error)?
|
|
||||||
.ok_or_else(err_connection)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<S: AsyncWrite + Unpin> PqStream<S> {
|
impl<S: AsyncWrite + Unpin> PqStream<S> {
|
||||||
/// Write the message into an internal buffer, but don't flush the underlying stream.
|
/// Write the message into an internal buffer, but don't flush the underlying stream.
|
||||||
pub fn write_message_noflush(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> {
|
pub fn write_message_noflush(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> {
|
||||||
BeMessage::write(&mut self.buffer, message)?;
|
self.framed
|
||||||
|
.write_message(message)
|
||||||
|
.map_err(ProtocolError::into_io_error)?;
|
||||||
Ok(self)
|
Ok(self)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -96,9 +89,7 @@ impl<S: AsyncWrite + Unpin> PqStream<S> {
|
|||||||
|
|
||||||
/// Flush the output buffer into the underlying stream.
|
/// Flush the output buffer into the underlying stream.
|
||||||
pub async fn flush(&mut self) -> io::Result<&mut Self> {
|
pub async fn flush(&mut self) -> io::Result<&mut Self> {
|
||||||
self.stream.write_all(&self.buffer).await?;
|
self.framed.flush().await?;
|
||||||
self.buffer.clear();
|
|
||||||
self.stream.flush().await?;
|
|
||||||
Ok(self)
|
Ok(self)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -488,7 +488,7 @@ impl AcceptorProposerMessage {
|
|||||||
buf.put_u64_le(msg.hs_feedback.xmin);
|
buf.put_u64_le(msg.hs_feedback.xmin);
|
||||||
buf.put_u64_le(msg.hs_feedback.catalog_xmin);
|
buf.put_u64_le(msg.hs_feedback.catalog_xmin);
|
||||||
|
|
||||||
msg.pageserver_feedback.serialize(buf)?;
|
msg.pageserver_feedback.serialize(buf);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user