mirror of
https://github.com/neondatabase/neon.git
synced 2026-05-21 07:00:38 +00:00
1082 lines
36 KiB
Rust
1082 lines
36 KiB
Rust
//! Postgres protocol messages serialization-deserialization. See
|
|
//! <https://www.postgresql.org/docs/devel/protocol-message-formats.html>
|
|
//! on message formats.
|
|
|
|
// Tools for calling certain async methods in sync contexts.
|
|
pub mod sync;
|
|
|
|
use anyhow::{bail, ensure, Context, Result};
|
|
use bytes::{Buf, BufMut, Bytes, BytesMut};
|
|
use postgres_protocol::PG_EPOCH;
|
|
use serde::{Deserialize, Serialize};
|
|
use std::{
|
|
borrow::Cow,
|
|
collections::HashMap,
|
|
fmt,
|
|
future::Future,
|
|
io::{self, Cursor},
|
|
str,
|
|
time::{Duration, SystemTime},
|
|
};
|
|
use sync::{AsyncishRead, SyncFuture};
|
|
use tokio::io::AsyncReadExt;
|
|
use tracing::{trace, warn};
|
|
|
|
/// Numeric identifier type used for the OID fields of [`RowDescriptor`].
pub type Oid = u32;
/// 64-bit system identifier. NOTE(review): not used in this part of the file;
/// confirm its meaning against the callers.
pub type SystemId = u64;

/// Type OID used by [`RowDescriptor::int8_col`].
pub const INT8_OID: Oid = 20;
/// NOTE(review): presumably the OID of the 4-byte integer type; unused in this file.
pub const INT4_OID: Oid = 23;
/// Type OID used by [`RowDescriptor::text_col`] and [`SINGLE_COL_ROWDESC`].
pub const TEXT_OID: Oid = 25;
|
|
|
|
/// A message received from the frontend (client).
///
/// `StartupPacket` is produced by `FeStartupPacket::read_fut`; every other
/// variant is produced by `FeMessage::read_fut` from the one-byte message tag
/// noted on the variant.
#[derive(Debug)]
pub enum FeMessage {
    /// Untagged packet read by `FeStartupPacket::read_fut`.
    StartupPacket(FeStartupPacket),
    // Simple query.
    /// Tag `Q`; carries the raw message body.
    Query(Bytes),
    // Extended query protocol.
    /// Tag `P`.
    Parse(FeParseMessage),
    /// Tag `D`.
    Describe(FeDescribeMessage),
    /// Tag `B`.
    Bind(FeBindMessage),
    /// Tag `E`.
    Execute(FeExecuteMessage),
    /// Tag `C`.
    Close(FeCloseMessage),
    /// Tag `S`; the body is read but discarded.
    Sync,
    /// Tag `X`; the body is read but discarded.
    Terminate,
    /// Tag `d`; carries the raw message body.
    CopyData(Bytes),
    /// Tag `c`; the body is read but discarded.
    CopyDone,
    /// Tag `f`; the body is read but discarded.
    CopyFail,
    /// Tag `p`; carries the raw message body.
    PasswordMessage(Bytes),
}
|
|
|
|
/// The first, untagged packet of a connection, as decoded by
/// `FeStartupPacket::read_fut` from the 32-bit request code.
#[derive(Debug)]
pub enum FeStartupPacket {
    /// Request code `1234.5678`; carries the backend pid and cancel key.
    CancelRequest(CancelKeyData),
    /// Request code `1234.5679`.
    SslRequest,
    /// Request code `1234.5680`.
    GssEncRequest,
    /// Any request code whose upper 16 bits are not `1234`.
    StartupMessage {
        /// Upper 16 bits of the request code.
        major_version: u32,
        /// Lower 16 bits of the request code.
        minor_version: u32,
        /// Key/value pairs parsed from the rest of the packet.
        params: StartupMessageParams,
    },
}
|
|
|
|
/// Key/value parameters carried by a `StartupMessage`.
#[derive(Debug)]
pub struct StartupMessageParams {
    params: HashMap<String, String>,
}

impl StartupMessageParams {
    /// Look up the value of the parameter called `name`, if there is one.
    pub fn get(&self, name: &str) -> Option<&str> {
        self.params.get(name).map(String::as_str)
    }

    /// Split the `options` parameter into command-line options the way
    /// PostgreSQL does, honouring escape sequences but leaving them in place.
    /// Returns [`None`] when [`Self`] has no `options` parameter.
    pub fn options_raw(&self) -> Option<impl Iterator<Item = &str>> {
        let options = self.get("options")?;

        // See `postgres: pg_split_opts`.
        // `escaped` is true while the previous character was a backslash that
        // was not itself escaped; whitespace right after it does not split.
        let mut escaped = false;
        let is_separator = move |c: char| {
            let split_here = !escaped && c.is_ascii_whitespace();
            escaped = !escaped && c == '\\';
            split_here
        };

        Some(
            options
                .split(is_separator)
                .filter(|piece| !piece.is_empty()),
        )
    }

    /// Split the `options` parameter like [`Self::options_raw`], then apply
    /// the escape sequences, allocating only for pieces that contain one.
    /// Returns [`None`] when [`Self`] has no `options` parameter.
    pub fn options_escaped(&self) -> Option<impl Iterator<Item = Cow<'_, str>>> {
        // See `postgres: pg_split_opts`.
        let raw = self.options_raw()?;

        let unescaped = raw.map(|piece| {
            if !piece.contains('\\') {
                return Cow::Borrowed(piece);
            }

            // A backslash is dropped unless the backslash before it was dropped.
            let mut keep_next_backslash = false;
            let is_escape_char = |c: char| {
                let strip = !keep_next_backslash && c == '\\';
                keep_next_backslash = strip;
                strip
            };

            Cow::Owned(piece.replace(is_escape_char, ""))
        });

        Some(unescaped)
    }

    // This function is mostly useful in tests.
    #[doc(hidden)]
    pub fn new<'a, const N: usize>(pairs: [(&'a str, &'a str); N]) -> Self {
        let owned = pairs.map(|(name, value)| (name.to_owned(), value.to_owned()));
        Self {
            params: HashMap::from(owned),
        }
    }
}
|
|
|
|
/// Backend pid and secret key pair carried by `CancelRequest` and
/// `BackendKeyData` messages.
#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy)]
pub struct CancelKeyData {
    pub backend_pid: i32,
    pub cancel_key: i32,
}

impl fmt::Display for CancelKeyData {
    /// Prints both halves as one 64-bit hex id: `backend_pid` in the upper
    /// 32 bits and `cancel_key` in the lower 32 bits.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Reinterpret each half as u32 before widening. Casting a negative
        // i32 straight to u64 sign-extends, which would fill the upper half
        // with ones and overwrite the pid when the two are OR-ed together.
        let hi = u64::from(self.backend_pid as u32) << 32;
        let lo = u64::from(self.cancel_key as u32);
        let id = hi | lo;

        // This format is more compact and might work better for logs.
        f.debug_tuple("CancelKeyData")
            .field(&format_args!("{:x}", id))
            .finish()
    }
}
|
|
|
|
use rand::distributions::{Distribution, Standard};
|
|
impl Distribution<CancelKeyData> for Standard {
|
|
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> CancelKeyData {
|
|
CancelKeyData {
|
|
backend_pid: rng.gen(),
|
|
cancel_key: rng.gen(),
|
|
}
|
|
}
|
|
}
|
|
|
|
// We only support the simple case of Parse on unnamed prepared statement and
// no params
/// Body of a `Parse` message, as decoded by `FeParseMessage::parse`.
#[derive(Debug)]
pub struct FeParseMessage {
    /// The query text, without its NUL terminator.
    pub query_string: Bytes,
}
|
|
|
|
/// Body of a `Describe` message, as decoded by `FeDescribeMessage::parse`
/// (which currently rejects every kind other than `b'S'`).
#[derive(Debug)]
pub struct FeDescribeMessage {
    pub kind: u8, // 'S' to describe a prepared statement; or 'P' to describe a portal.
                  // we only support unnamed prepared stmt or portal
}
|
|
|
|
// we only support unnamed prepared stmt and portal
/// Marker for a `Bind` message; `FeBindMessage::parse` keeps none of its fields.
#[derive(Debug)]
pub struct FeBindMessage;
|
|
|
|
// we only support unnamed prepared stmt or portal
/// Body of an `Execute` message, as decoded by `FeExecuteMessage::parse`.
#[derive(Debug)]
pub struct FeExecuteMessage {
    /// max # of rows (the parser currently accepts only 0)
    pub maxrows: i32,
}
|
|
|
|
// we only support unnamed prepared stmt and portal
/// Marker for a `Close` message; `FeCloseMessage::parse` keeps none of its fields.
#[derive(Debug)]
pub struct FeCloseMessage;
|
|
|
|
/// Retry a read on EINTR
///
/// Evaluates the enclosed expression repeatedly for as long as it yields
/// `Err` with kind `io::ErrorKind::Interrupted`; the first other result
/// (success or a different error) becomes the value of the macro.
macro_rules! retry_read {
    ( $x:expr ) => {
        loop {
            let outcome = $x;
            if matches!(&outcome, Err(e) if e.kind() == io::ErrorKind::Interrupted) {
                continue;
            }
            break outcome;
        }
    };
}
|
|
|
|
impl FeMessage {
|
|
/// Read one message from the stream.
|
|
/// This function returns `Ok(None)` in case of EOF.
|
|
/// One way to handle this properly:
|
|
///
|
|
/// ```
|
|
/// # use std::io;
|
|
/// # use pq_proto::FeMessage;
|
|
/// #
|
|
/// # fn process_message(msg: FeMessage) -> anyhow::Result<()> {
|
|
/// # Ok(())
|
|
/// # };
|
|
/// #
|
|
/// fn do_the_job(stream: &mut (impl io::Read + Unpin)) -> anyhow::Result<()> {
|
|
/// while let Some(msg) = FeMessage::read(stream)? {
|
|
/// process_message(msg)?;
|
|
/// }
|
|
///
|
|
/// Ok(())
|
|
/// }
|
|
/// ```
|
|
#[inline(never)]
|
|
pub fn read(stream: &mut (impl io::Read + Unpin)) -> anyhow::Result<Option<FeMessage>> {
|
|
Self::read_fut(&mut AsyncishRead(stream)).wait()
|
|
}
|
|
|
|
/// Read one message from the stream.
|
|
/// See documentation for `Self::read`.
|
|
pub fn read_fut<Reader>(
|
|
stream: &mut Reader,
|
|
) -> SyncFuture<Reader, impl Future<Output = anyhow::Result<Option<FeMessage>>> + '_>
|
|
where
|
|
Reader: tokio::io::AsyncRead + Unpin,
|
|
{
|
|
// We return a Future that's sync (has a `wait` method) if and only if the provided stream is SyncProof.
|
|
// SyncFuture contract: we are only allowed to await on sync-proof futures, the AsyncRead and
|
|
// AsyncReadExt methods of the stream.
|
|
SyncFuture::new(async move {
|
|
// Each libpq message begins with a message type byte, followed by message length
|
|
// If the client closes the connection, return None. But if the client closes the
|
|
// connection in the middle of a message, we will return an error.
|
|
let tag = match retry_read!(stream.read_u8().await) {
|
|
Ok(b) => b,
|
|
Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
|
|
Err(e) => return Err(e.into()),
|
|
};
|
|
|
|
// The message length includes itself, so it better be at least 4.
|
|
let len = retry_read!(stream.read_u32().await)?
|
|
.checked_sub(4)
|
|
.context("invalid message length")?;
|
|
|
|
let body = {
|
|
let mut buffer = vec![0u8; len as usize];
|
|
stream.read_exact(&mut buffer).await?;
|
|
Bytes::from(buffer)
|
|
};
|
|
|
|
match tag {
|
|
b'Q' => Ok(Some(FeMessage::Query(body))),
|
|
b'P' => Ok(Some(FeParseMessage::parse(body)?)),
|
|
b'D' => Ok(Some(FeDescribeMessage::parse(body)?)),
|
|
b'E' => Ok(Some(FeExecuteMessage::parse(body)?)),
|
|
b'B' => Ok(Some(FeBindMessage::parse(body)?)),
|
|
b'C' => Ok(Some(FeCloseMessage::parse(body)?)),
|
|
b'S' => Ok(Some(FeMessage::Sync)),
|
|
b'X' => Ok(Some(FeMessage::Terminate)),
|
|
b'd' => Ok(Some(FeMessage::CopyData(body))),
|
|
b'c' => Ok(Some(FeMessage::CopyDone)),
|
|
b'f' => Ok(Some(FeMessage::CopyFail)),
|
|
b'p' => Ok(Some(FeMessage::PasswordMessage(body))),
|
|
tag => bail!("unknown message tag: {},'{:?}'", tag, body),
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
impl FeStartupPacket {
|
|
/// Read startup message from the stream.
|
|
// XXX: It's tempting yet undesirable to accept `stream` by value,
|
|
// since such a change will cause user-supplied &mut references to be consumed
|
|
pub fn read(stream: &mut (impl io::Read + Unpin)) -> anyhow::Result<Option<FeMessage>> {
|
|
Self::read_fut(&mut AsyncishRead(stream)).wait()
|
|
}
|
|
|
|
/// Read startup message from the stream.
|
|
// XXX: It's tempting yet undesirable to accept `stream` by value,
|
|
// since such a change will cause user-supplied &mut references to be consumed
|
|
pub fn read_fut<Reader>(
|
|
stream: &mut Reader,
|
|
) -> SyncFuture<Reader, impl Future<Output = anyhow::Result<Option<FeMessage>>> + '_>
|
|
where
|
|
Reader: tokio::io::AsyncRead + Unpin,
|
|
{
|
|
const MAX_STARTUP_PACKET_LENGTH: usize = 10000;
|
|
const RESERVED_INVALID_MAJOR_VERSION: u32 = 1234;
|
|
const CANCEL_REQUEST_CODE: u32 = 5678;
|
|
const NEGOTIATE_SSL_CODE: u32 = 5679;
|
|
const NEGOTIATE_GSS_CODE: u32 = 5680;
|
|
|
|
SyncFuture::new(async move {
|
|
// Read length. If the connection is closed before reading anything (or before
|
|
// reading 4 bytes, to be precise), return None to indicate that the connection
|
|
// was closed. This matches the PostgreSQL server's behavior, which avoids noise
|
|
// in the log if the client opens connection but closes it immediately.
|
|
let len = match retry_read!(stream.read_u32().await) {
|
|
Ok(len) => len as usize,
|
|
Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
|
|
Err(e) => return Err(e.into()),
|
|
};
|
|
|
|
#[allow(clippy::manual_range_contains)]
|
|
if len < 4 || len > MAX_STARTUP_PACKET_LENGTH {
|
|
bail!("invalid message length");
|
|
}
|
|
|
|
let request_code = retry_read!(stream.read_u32().await)?;
|
|
|
|
// the rest of startup packet are params
|
|
let params_len = len - 8;
|
|
let mut params_bytes = vec![0u8; params_len];
|
|
stream.read_exact(params_bytes.as_mut()).await?;
|
|
|
|
// Parse params depending on request code
|
|
let req_hi = request_code >> 16;
|
|
let req_lo = request_code & ((1 << 16) - 1);
|
|
let message = match (req_hi, req_lo) {
|
|
(RESERVED_INVALID_MAJOR_VERSION, CANCEL_REQUEST_CODE) => {
|
|
ensure!(params_len == 8, "expected 8 bytes for CancelRequest params");
|
|
let mut cursor = Cursor::new(params_bytes);
|
|
FeStartupPacket::CancelRequest(CancelKeyData {
|
|
backend_pid: cursor.read_i32().await?,
|
|
cancel_key: cursor.read_i32().await?,
|
|
})
|
|
}
|
|
(RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_SSL_CODE) => {
|
|
// Requested upgrade to SSL (aka TLS)
|
|
FeStartupPacket::SslRequest
|
|
}
|
|
(RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_GSS_CODE) => {
|
|
// Requested upgrade to GSSAPI
|
|
FeStartupPacket::GssEncRequest
|
|
}
|
|
(RESERVED_INVALID_MAJOR_VERSION, unrecognized_code) => {
|
|
bail!("Unrecognized request code {}", unrecognized_code)
|
|
}
|
|
// TODO bail if protocol major_version is not 3?
|
|
(major_version, minor_version) => {
|
|
// Parse pairs of null-terminated strings (key, value).
|
|
// See `postgres: ProcessStartupPacket, build_startup_packet`.
|
|
let mut tokens = str::from_utf8(¶ms_bytes)
|
|
.context("StartupMessage params: invalid utf-8")?
|
|
.strip_suffix('\0') // drop packet's own null terminator
|
|
.context("StartupMessage params: missing null terminator")?
|
|
.split_terminator('\0');
|
|
|
|
let mut params = HashMap::new();
|
|
while let Some(name) = tokens.next() {
|
|
let value = tokens
|
|
.next()
|
|
.context("StartupMessage params: key without value")?;
|
|
|
|
params.insert(name.to_owned(), value.to_owned());
|
|
}
|
|
|
|
FeStartupPacket::StartupMessage {
|
|
major_version,
|
|
minor_version,
|
|
params: StartupMessageParams { params },
|
|
}
|
|
}
|
|
};
|
|
|
|
Ok(Some(FeMessage::StartupPacket(message)))
|
|
})
|
|
}
|
|
}
|
|
|
|
impl FeParseMessage {
|
|
fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
|
|
// FIXME: the rust-postgres driver uses a named prepared statement
|
|
// for copy_out(). We're not prepared to handle that correctly. For
|
|
// now, just ignore the statement name, assuming that the client never
|
|
// uses more than one prepared statement at a time.
|
|
|
|
let _pstmt_name = read_cstr(&mut buf)?;
|
|
let query_string = read_cstr(&mut buf)?;
|
|
let nparams = buf.get_i16();
|
|
|
|
ensure!(nparams == 0, "query params not implemented");
|
|
|
|
Ok(FeMessage::Parse(FeParseMessage { query_string }))
|
|
}
|
|
}
|
|
|
|
impl FeDescribeMessage {
|
|
fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
|
|
let kind = buf.get_u8();
|
|
let _pstmt_name = read_cstr(&mut buf)?;
|
|
|
|
// FIXME: see FeParseMessage::parse
|
|
ensure!(
|
|
kind == b'S',
|
|
"only prepared statemement Describe is implemented"
|
|
);
|
|
|
|
Ok(FeMessage::Describe(FeDescribeMessage { kind }))
|
|
}
|
|
}
|
|
|
|
impl FeExecuteMessage {
|
|
fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
|
|
let portal_name = read_cstr(&mut buf)?;
|
|
let maxrows = buf.get_i32();
|
|
|
|
ensure!(portal_name.is_empty(), "named portals not implemented");
|
|
ensure!(maxrows == 0, "row limit in Execute message not implemented");
|
|
|
|
Ok(FeMessage::Execute(FeExecuteMessage { maxrows }))
|
|
}
|
|
}
|
|
|
|
impl FeBindMessage {
|
|
fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
|
|
let portal_name = read_cstr(&mut buf)?;
|
|
let _pstmt_name = read_cstr(&mut buf)?;
|
|
|
|
// FIXME: see FeParseMessage::parse
|
|
ensure!(portal_name.is_empty(), "named portals not implemented");
|
|
|
|
Ok(FeMessage::Bind(FeBindMessage))
|
|
}
|
|
}
|
|
|
|
impl FeCloseMessage {
|
|
fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
|
|
let _kind = buf.get_u8();
|
|
let _pstmt_or_portal_name = read_cstr(&mut buf)?;
|
|
|
|
// FIXME: we do nothing with Close
|
|
Ok(FeMessage::Close(FeCloseMessage))
|
|
}
|
|
}
|
|
|
|
// Backend
|
|
|
|
/// A message sent by the backend (server); serialized by `BeMessage::write`.
///
/// The tag byte noted on each variant is the one `BeMessage::write` emits.
#[derive(Debug)]
pub enum BeMessage<'a> {
    /// Tag `R`, auth code 0.
    AuthenticationOk,
    /// Tag `R`, auth code 5, followed by the 4-byte salt.
    AuthenticationMD5Password([u8; 4]),
    /// Tag `R`, auth codes 10/11/12 depending on the SASL step.
    AuthenticationSasl(BeAuthenticationSaslMessage<'a>),
    /// Tag `R`, auth code 3.
    AuthenticationCleartextPassword,
    /// Tag `K`.
    BackendKeyData(CancelKeyData),
    /// Tag `2`.
    BindComplete,
    /// Tag `C`; the payload is written as a NUL-terminated string.
    CommandComplete(&'a [u8]),
    /// Tag `d`.
    CopyData(&'a [u8]),
    /// Tag `c`.
    CopyDone,
    /// Tag `f`.
    CopyFail,
    /// Tag `G`.
    CopyInResponse,
    /// Tag `H`.
    CopyOutResponse,
    /// Tag `W`.
    CopyBothResponse,
    /// Tag `3`.
    CloseComplete,
    // None means column is NULL
    /// Tag `D`.
    DataRow(&'a [Option<&'a [u8]>]),
    /// Tag `E`; carries only the message text.
    ErrorResponse(&'a str),
    /// Single byte - used in response to SSLRequest/GSSENCRequest.
    EncryptionResponse(bool),
    /// Tag `n`.
    NoData,
    /// Tag `t`; always written with zero parameters.
    ParameterDescription,
    /// Tag `S`.
    ParameterStatus(BeParameterStatusMessage<'a>),
    /// Tag `1`.
    ParseComplete,
    /// Tag `Z`; always written with status `I`.
    ReadyForQuery,
    /// Tag `T`.
    RowDescription(&'a [RowDescriptor<'a>]),
    /// Tag `d` with sub-tag `w`.
    XLogData(XLogDataBody<'a>),
    /// Tag `N`; carries only the message text.
    NoticeResponse(&'a str),
    /// Tag `d` with sub-tag `k`.
    KeepAlive(WalSndKeepAlive),
}
|
|
|
|
/// The SASL authentication steps that `BeMessage::write` can serialize.
#[derive(Debug)]
pub enum BeAuthenticationSaslMessage<'a> {
    /// Auth code 10: a list of NUL-terminated method names, then a closing NUL.
    Methods(&'a [&'a str]),
    /// Auth code 11: the payload is written verbatim.
    Continue(&'a [u8]),
    /// Auth code 12: the payload is written verbatim.
    Final(&'a [u8]),
}
|
|
|
|
/// The run-time parameters that can be reported via `ParameterStatus`.
#[derive(Debug)]
pub enum BeParameterStatusMessage<'a> {
    /// Written under the parameter name `client_encoding`.
    Encoding(&'a str),
    /// Written under the parameter name `server_version`.
    ServerVersion(&'a str),
}
|
|
|
|
impl BeParameterStatusMessage<'static> {
|
|
pub fn encoding() -> BeMessage<'static> {
|
|
BeMessage::ParameterStatus(Self::Encoding("UTF8"))
|
|
}
|
|
}
|
|
|
|
// One row description in RowDescription packet.
/// Description of a single result column.
///
/// Note that `BeMessage::write`, as written in this file, serializes only
/// `name`, `typoid` and `typlen`; the remaining fields are replaced by
/// constants on the wire.
#[derive(Debug)]
pub struct RowDescriptor<'a> {
    /// Column name; must not contain a NUL byte to be serializable.
    pub name: &'a [u8],
    pub tableoid: Oid,
    pub attnum: i16,
    pub typoid: Oid,
    pub typlen: i16,
    pub typmod: i32,
    pub formatcode: i16,
}
|
|
|
|
impl Default for RowDescriptor<'_> {
    /// An all-zero descriptor with an empty name.
    fn default() -> RowDescriptor<'static> {
        RowDescriptor {
            name: b"",
            tableoid: 0,
            attnum: 0,
            typoid: 0,
            typlen: 0,
            typmod: 0,
            formatcode: 0,
        }
    }
}
|
|
|
|
impl RowDescriptor<'_> {
|
|
/// Convenience function to create a RowDescriptor message for an int8 column
|
|
pub const fn int8_col(name: &[u8]) -> RowDescriptor {
|
|
RowDescriptor {
|
|
name,
|
|
tableoid: 0,
|
|
attnum: 0,
|
|
typoid: INT8_OID,
|
|
typlen: 8,
|
|
typmod: 0,
|
|
formatcode: 0,
|
|
}
|
|
}
|
|
|
|
pub const fn text_col(name: &[u8]) -> RowDescriptor {
|
|
RowDescriptor {
|
|
name,
|
|
tableoid: 0,
|
|
attnum: 0,
|
|
typoid: TEXT_OID,
|
|
typlen: -1,
|
|
typmod: 0,
|
|
formatcode: 0,
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Payload of `BeMessage::XLogData`. `BeMessage::write` emits the fields in
/// declaration order after the `w` sub-tag.
#[derive(Debug)]
pub struct XLogDataBody<'a> {
    pub wal_start: u64,
    pub wal_end: u64,
    pub timestamp: i64,
    /// Written verbatim after the fixed-size fields.
    pub data: &'a [u8],
}
|
|
|
|
/// Payload of `BeMessage::KeepAlive`. `BeMessage::write` emits the fields in
/// declaration order after the `k` sub-tag.
#[derive(Debug)]
pub struct WalSndKeepAlive {
    pub sent_ptr: u64,
    pub timestamp: i64,
    /// Serialized as a single byte: 1 when set, 0 otherwise.
    pub request_reply: bool,
}
|
|
|
|
/// A `DataRow` with a single non-NULL column holding the bytes `hello world`.
pub static HELLO_WORLD_ROW: BeMessage = BeMessage::DataRow(&[Some(b"hello world")]);
|
|
|
|
// single text column
/// A `RowDescription` with one text column named `data`.
pub static SINGLE_COL_ROWDESC: BeMessage = BeMessage::RowDescription(&[RowDescriptor {
    name: b"data",
    tableoid: 0,
    attnum: 0,
    typoid: TEXT_OID,
    typlen: -1,
    typmod: 0,
    formatcode: 0,
}]);
|
|
|
|
/// Call f() to write body of the message and prepend it with 4-byte len as
|
|
/// prescribed by the protocol.
|
|
fn write_body<R>(buf: &mut BytesMut, f: impl FnOnce(&mut BytesMut) -> R) -> R {
|
|
let base = buf.len();
|
|
buf.extend_from_slice(&[0; 4]);
|
|
|
|
let res = f(buf);
|
|
|
|
let size = i32::try_from(buf.len() - base).expect("message too big to transmit");
|
|
(&mut buf[base..]).put_slice(&size.to_be_bytes());
|
|
|
|
res
|
|
}
|
|
|
|
/// Safe write of s into buf as cstring (String in the protocol).
|
|
fn write_cstr(s: &[u8], buf: &mut BytesMut) -> Result<(), io::Error> {
|
|
if s.contains(&0) {
|
|
return Err(io::Error::new(
|
|
io::ErrorKind::InvalidInput,
|
|
"string contains embedded null",
|
|
));
|
|
}
|
|
buf.put_slice(s);
|
|
buf.put_u8(0);
|
|
Ok(())
|
|
}
|
|
|
|
/// Split a NUL-terminated string off the front of `buf`.
///
/// Returns the bytes before the first NUL and advances `buf` past that NUL.
/// Fails with "missing terminator" when `buf` contains no NUL byte.
fn read_cstr(buf: &mut Bytes) -> anyhow::Result<Bytes> {
    let nul_at = buf
        .iter()
        .position(|&byte| byte == 0)
        .context("missing terminator")?;

    let cstr = buf.split_to(nul_at);
    buf.advance(1); // drop the null terminator
    Ok(cstr)
}
|
|
|
|
impl<'a> BeMessage<'a> {
|
|
/// Write message to the given buf.
|
|
// Unlike the reading side, we use BytesMut
|
|
// here as msg len precedes its body and it is handy to write it down first
|
|
// and then fill the length. With Write we would have to either calc it
|
|
// manually or have one more buffer.
|
|
pub fn write(buf: &mut BytesMut, message: &BeMessage) -> io::Result<()> {
|
|
match message {
|
|
BeMessage::AuthenticationOk => {
|
|
buf.put_u8(b'R');
|
|
write_body(buf, |buf| {
|
|
buf.put_i32(0); // Specifies that the authentication was successful.
|
|
});
|
|
}
|
|
|
|
BeMessage::AuthenticationCleartextPassword => {
|
|
buf.put_u8(b'R');
|
|
write_body(buf, |buf| {
|
|
buf.put_i32(3); // Specifies that clear text password is required.
|
|
});
|
|
}
|
|
|
|
BeMessage::AuthenticationMD5Password(salt) => {
|
|
buf.put_u8(b'R');
|
|
write_body(buf, |buf| {
|
|
buf.put_i32(5); // Specifies that an MD5-encrypted password is required.
|
|
buf.put_slice(&salt[..]);
|
|
});
|
|
}
|
|
|
|
BeMessage::AuthenticationSasl(msg) => {
|
|
buf.put_u8(b'R');
|
|
write_body(buf, |buf| {
|
|
use BeAuthenticationSaslMessage::*;
|
|
match msg {
|
|
Methods(methods) => {
|
|
buf.put_i32(10); // Specifies that SASL auth method is used.
|
|
for method in methods.iter() {
|
|
write_cstr(method.as_bytes(), buf)?;
|
|
}
|
|
buf.put_u8(0); // zero terminator for the list
|
|
}
|
|
Continue(extra) => {
|
|
buf.put_i32(11); // Continue SASL auth.
|
|
buf.put_slice(extra);
|
|
}
|
|
Final(extra) => {
|
|
buf.put_i32(12); // Send final SASL message.
|
|
buf.put_slice(extra);
|
|
}
|
|
}
|
|
Ok::<_, io::Error>(())
|
|
})?;
|
|
}
|
|
|
|
BeMessage::BackendKeyData(key_data) => {
|
|
buf.put_u8(b'K');
|
|
write_body(buf, |buf| {
|
|
buf.put_i32(key_data.backend_pid);
|
|
buf.put_i32(key_data.cancel_key);
|
|
});
|
|
}
|
|
|
|
BeMessage::BindComplete => {
|
|
buf.put_u8(b'2');
|
|
write_body(buf, |_| {});
|
|
}
|
|
|
|
BeMessage::CloseComplete => {
|
|
buf.put_u8(b'3');
|
|
write_body(buf, |_| {});
|
|
}
|
|
|
|
BeMessage::CommandComplete(cmd) => {
|
|
buf.put_u8(b'C');
|
|
write_body(buf, |buf| write_cstr(cmd, buf))?;
|
|
}
|
|
|
|
BeMessage::CopyData(data) => {
|
|
buf.put_u8(b'd');
|
|
write_body(buf, |buf| {
|
|
buf.put_slice(data);
|
|
});
|
|
}
|
|
|
|
BeMessage::CopyDone => {
|
|
buf.put_u8(b'c');
|
|
write_body(buf, |_| {});
|
|
}
|
|
|
|
BeMessage::CopyFail => {
|
|
buf.put_u8(b'f');
|
|
write_body(buf, |_| {});
|
|
}
|
|
|
|
BeMessage::CopyInResponse => {
|
|
buf.put_u8(b'G');
|
|
write_body(buf, |buf| {
|
|
buf.put_u8(1); // copy_is_binary
|
|
buf.put_i16(0); // numAttributes
|
|
});
|
|
}
|
|
|
|
BeMessage::CopyOutResponse => {
|
|
buf.put_u8(b'H');
|
|
write_body(buf, |buf| {
|
|
buf.put_u8(0); // copy_is_binary
|
|
buf.put_i16(0); // numAttributes
|
|
});
|
|
}
|
|
|
|
BeMessage::CopyBothResponse => {
|
|
buf.put_u8(b'W');
|
|
write_body(buf, |buf| {
|
|
// doesn't matter, used only for replication
|
|
buf.put_u8(0); // copy_is_binary
|
|
buf.put_i16(0); // numAttributes
|
|
});
|
|
}
|
|
|
|
BeMessage::DataRow(vals) => {
|
|
buf.put_u8(b'D');
|
|
write_body(buf, |buf| {
|
|
buf.put_u16(vals.len() as u16); // num of cols
|
|
for val_opt in vals.iter() {
|
|
if let Some(val) = val_opt {
|
|
buf.put_u32(val.len() as u32);
|
|
buf.put_slice(val);
|
|
} else {
|
|
buf.put_i32(-1);
|
|
}
|
|
}
|
|
});
|
|
}
|
|
|
|
// ErrorResponse is a zero-terminated array of zero-terminated fields.
|
|
// First byte of each field represents type of this field. Set just enough fields
|
|
// to satisfy rust-postgres client: 'S' -- severity, 'C' -- error, 'M' -- error
|
|
// message text.
|
|
BeMessage::ErrorResponse(error_msg) => {
|
|
// For all the errors set Severity to Error and error code to
|
|
// 'internal error'.
|
|
|
|
// 'E' signalizes ErrorResponse messages
|
|
buf.put_u8(b'E');
|
|
write_body(buf, |buf| {
|
|
buf.put_u8(b'S'); // severity
|
|
buf.put_slice(b"ERROR\0");
|
|
|
|
buf.put_u8(b'C'); // SQLSTATE error code
|
|
buf.put_slice(b"CXX000\0");
|
|
|
|
buf.put_u8(b'M'); // the message
|
|
write_cstr(error_msg.as_bytes(), buf)?;
|
|
|
|
buf.put_u8(0); // terminator
|
|
Ok::<_, io::Error>(())
|
|
})?;
|
|
}
|
|
|
|
// NoticeResponse has the same format as ErrorResponse. From doc: "The frontend should display the
|
|
// message but continue listening for ReadyForQuery or ErrorResponse"
|
|
BeMessage::NoticeResponse(error_msg) => {
|
|
// For all the errors set Severity to Error and error code to
|
|
// 'internal error'.
|
|
|
|
// 'N' signalizes NoticeResponse messages
|
|
buf.put_u8(b'N');
|
|
write_body(buf, |buf| {
|
|
buf.put_u8(b'S'); // severity
|
|
buf.put_slice(b"NOTICE\0");
|
|
|
|
buf.put_u8(b'C'); // SQLSTATE error code
|
|
buf.put_slice(b"CXX000\0");
|
|
|
|
buf.put_u8(b'M'); // the message
|
|
write_cstr(error_msg.as_bytes(), buf)?;
|
|
|
|
buf.put_u8(0); // terminator
|
|
Ok::<_, io::Error>(())
|
|
})?;
|
|
}
|
|
|
|
BeMessage::NoData => {
|
|
buf.put_u8(b'n');
|
|
write_body(buf, |_| {});
|
|
}
|
|
|
|
BeMessage::EncryptionResponse(should_negotiate) => {
|
|
let response = if *should_negotiate { b'S' } else { b'N' };
|
|
buf.put_u8(response);
|
|
}
|
|
|
|
BeMessage::ParameterStatus(param) => {
|
|
use std::io::{IoSlice, Write};
|
|
use BeParameterStatusMessage::*;
|
|
|
|
let [name, value] = match param {
|
|
Encoding(name) => [b"client_encoding", name.as_bytes()],
|
|
ServerVersion(version) => [b"server_version", version.as_bytes()],
|
|
};
|
|
|
|
// Parameter names and values are passed as null-terminated strings
|
|
let iov = &mut [name, b"\0", value, b"\0"].map(IoSlice::new);
|
|
let mut buffer = [0u8; 64]; // this should be enough
|
|
let cnt = buffer.as_mut().write_vectored(iov).unwrap();
|
|
|
|
buf.put_u8(b'S');
|
|
write_body(buf, |buf| {
|
|
buf.put_slice(&buffer[..cnt]);
|
|
});
|
|
}
|
|
|
|
BeMessage::ParameterDescription => {
|
|
buf.put_u8(b't');
|
|
write_body(buf, |buf| {
|
|
// we don't support params, so always 0
|
|
buf.put_i16(0);
|
|
});
|
|
}
|
|
|
|
BeMessage::ParseComplete => {
|
|
buf.put_u8(b'1');
|
|
write_body(buf, |_| {});
|
|
}
|
|
|
|
BeMessage::ReadyForQuery => {
|
|
buf.put_u8(b'Z');
|
|
write_body(buf, |buf| {
|
|
buf.put_u8(b'I');
|
|
});
|
|
}
|
|
|
|
BeMessage::RowDescription(rows) => {
|
|
buf.put_u8(b'T');
|
|
write_body(buf, |buf| {
|
|
buf.put_i16(rows.len() as i16); // # of fields
|
|
for row in rows.iter() {
|
|
write_cstr(row.name, buf)?;
|
|
buf.put_i32(0); /* table oid */
|
|
buf.put_i16(0); /* attnum */
|
|
buf.put_u32(row.typoid);
|
|
buf.put_i16(row.typlen);
|
|
buf.put_i32(-1); /* typmod */
|
|
buf.put_i16(0); /* format code */
|
|
}
|
|
Ok::<_, io::Error>(())
|
|
})?;
|
|
}
|
|
|
|
BeMessage::XLogData(body) => {
|
|
buf.put_u8(b'd');
|
|
write_body(buf, |buf| {
|
|
buf.put_u8(b'w');
|
|
buf.put_u64(body.wal_start);
|
|
buf.put_u64(body.wal_end);
|
|
buf.put_i64(body.timestamp);
|
|
buf.put_slice(body.data);
|
|
});
|
|
}
|
|
|
|
BeMessage::KeepAlive(req) => {
|
|
buf.put_u8(b'd');
|
|
write_body(buf, |buf| {
|
|
buf.put_u8(b'k');
|
|
buf.put_u64(req.sent_ptr);
|
|
buf.put_i64(req.timestamp);
|
|
buf.put_u8(if req.request_reply { 1 } else { 0 });
|
|
});
|
|
}
|
|
}
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
// Neon extension of postgres replication protocol
// See NEON_STATUS_UPDATE_TAG_BYTE
/// Feedback record exchanged as a list of key/value pairs; see
/// `ReplicationFeedback::serialize` for the wire layout.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct ReplicationFeedback {
    // Last known size of the timeline. Used to enforce timeline size limit.
    pub current_timeline_size: u64,
    // Parts of StandbyStatusUpdate we resend to compute via safekeeper
    pub ps_writelsn: u64,
    pub ps_applylsn: u64,
    pub ps_flushlsn: u64,
    // Sent on the wire as microseconds relative to `PG_EPOCH`.
    pub ps_replytime: SystemTime,
}
|
|
|
|
// NOTE: Do not forget to increment this number when adding new fields to ReplicationFeedback.
// Do not remove previously available fields because this might be backwards incompatible.
/// Number of key/value pairs `ReplicationFeedback::serialize` writes; it is
/// emitted as the first byte of the serialized form.
pub const REPLICATION_FEEDBACK_FIELDS_NUMBER: u8 = 5;
|
|
|
|
impl ReplicationFeedback {
|
|
pub fn empty() -> ReplicationFeedback {
|
|
ReplicationFeedback {
|
|
current_timeline_size: 0,
|
|
ps_writelsn: 0,
|
|
ps_applylsn: 0,
|
|
ps_flushlsn: 0,
|
|
ps_replytime: SystemTime::now(),
|
|
}
|
|
}
|
|
|
|
// Serialize ReplicationFeedback using custom format
|
|
// to support protocol extensibility.
|
|
//
|
|
// Following layout is used:
|
|
// char - number of key-value pairs that follow.
|
|
//
|
|
// key-value pairs:
|
|
// null-terminated string - key,
|
|
// uint32 - value length in bytes
|
|
// value itself
|
|
pub fn serialize(&self, buf: &mut BytesMut) -> Result<()> {
|
|
buf.put_u8(REPLICATION_FEEDBACK_FIELDS_NUMBER); // # of keys
|
|
buf.put_slice(b"current_timeline_size\0");
|
|
buf.put_i32(8);
|
|
buf.put_u64(self.current_timeline_size);
|
|
|
|
buf.put_slice(b"ps_writelsn\0");
|
|
buf.put_i32(8);
|
|
buf.put_u64(self.ps_writelsn);
|
|
buf.put_slice(b"ps_flushlsn\0");
|
|
buf.put_i32(8);
|
|
buf.put_u64(self.ps_flushlsn);
|
|
buf.put_slice(b"ps_applylsn\0");
|
|
buf.put_i32(8);
|
|
buf.put_u64(self.ps_applylsn);
|
|
|
|
let timestamp = self
|
|
.ps_replytime
|
|
.duration_since(*PG_EPOCH)
|
|
.expect("failed to serialize pg_replytime earlier than PG_EPOCH")
|
|
.as_micros() as i64;
|
|
|
|
buf.put_slice(b"ps_replytime\0");
|
|
buf.put_i32(8);
|
|
buf.put_i64(timestamp);
|
|
Ok(())
|
|
}
|
|
|
|
// Deserialize ReplicationFeedback message
|
|
pub fn parse(mut buf: Bytes) -> ReplicationFeedback {
|
|
let mut rf = ReplicationFeedback::empty();
|
|
let nfields = buf.get_u8();
|
|
for _ in 0..nfields {
|
|
let key = read_cstr(&mut buf).unwrap();
|
|
match key.as_ref() {
|
|
b"current_timeline_size" => {
|
|
let len = buf.get_i32();
|
|
assert_eq!(len, 8);
|
|
rf.current_timeline_size = buf.get_u64();
|
|
}
|
|
b"ps_writelsn" => {
|
|
let len = buf.get_i32();
|
|
assert_eq!(len, 8);
|
|
rf.ps_writelsn = buf.get_u64();
|
|
}
|
|
b"ps_flushlsn" => {
|
|
let len = buf.get_i32();
|
|
assert_eq!(len, 8);
|
|
rf.ps_flushlsn = buf.get_u64();
|
|
}
|
|
b"ps_applylsn" => {
|
|
let len = buf.get_i32();
|
|
assert_eq!(len, 8);
|
|
rf.ps_applylsn = buf.get_u64();
|
|
}
|
|
b"ps_replytime" => {
|
|
let len = buf.get_i32();
|
|
assert_eq!(len, 8);
|
|
let raw_time = buf.get_i64();
|
|
if raw_time > 0 {
|
|
rf.ps_replytime = *PG_EPOCH + Duration::from_micros(raw_time as u64);
|
|
} else {
|
|
rf.ps_replytime = *PG_EPOCH - Duration::from_micros(-raw_time as u64);
|
|
}
|
|
}
|
|
_ => {
|
|
let len = buf.get_i32();
|
|
warn!(
|
|
"ReplicationFeedback parse. unknown key {} of len {len}. Skip it.",
|
|
String::from_utf8_lossy(key.as_ref())
|
|
);
|
|
buf.advance(len as usize);
|
|
}
|
|
}
|
|
}
|
|
trace!("ReplicationFeedback parsed is {:?}", rf);
|
|
rf
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// A serialized feedback record parses back to an equal value.
    #[test]
    fn test_replication_feedback_serialization() {
        let mut rf = ReplicationFeedback::empty();
        // Fill rf with some values
        rf.current_timeline_size = 12345678;
        // Set rounded time to be able to compare it with deserialized value,
        // because it is rounded up to microseconds during serialization.
        rf.ps_replytime = *PG_EPOCH + Duration::from_secs(100_000_000);
        let mut data = BytesMut::new();
        rf.serialize(&mut data).unwrap();

        let rf_parsed = ReplicationFeedback::parse(data.freeze());
        assert_eq!(rf, rf_parsed);
    }

    /// An unknown trailing key is skipped without affecting the known fields.
    #[test]
    fn test_replication_feedback_unknown_key() {
        let mut rf = ReplicationFeedback::empty();
        // Fill rf with some values
        rf.current_timeline_size = 12345678;
        // Set rounded time to be able to compare it with deserialized value,
        // because it is rounded up to microseconds during serialization.
        rf.ps_replytime = *PG_EPOCH + Duration::from_secs(100_000_000);
        let mut data = BytesMut::new();
        rf.serialize(&mut data).unwrap();

        // Add an extra field to the buffer and adjust number of keys
        if let Some(first) = data.first_mut() {
            *first = REPLICATION_FEEDBACK_FIELDS_NUMBER + 1;
        }

        data.put_slice(b"new_field_one\0");
        data.put_i32(8);
        data.put_u64(42);

        // Parse serialized data and check that new field is not parsed
        let rf_parsed = ReplicationFeedback::parse(data.freeze());
        assert_eq!(rf, rf_parsed);
    }

    /// Options are split on unescaped whitespace and escape sequences are applied.
    #[test]
    fn test_startup_message_params_options_escaped() {
        fn split_options(params: &StartupMessageParams) -> Vec<Cow<'_, str>> {
            params
                .options_escaped()
                .expect("options are None")
                .collect()
        }

        let make_params = |options| StartupMessageParams::new([("options", options)]);

        let params = StartupMessageParams::new([]);
        assert!(matches!(params.options_escaped(), None));

        let params = make_params("");
        assert!(split_options(&params).is_empty());

        let params = make_params("foo");
        assert_eq!(split_options(&params), ["foo"]);

        // Runs of whitespace must not produce empty options.
        let params = make_params("  foo  bar  ");
        assert_eq!(split_options(&params), ["foo", "bar"]);

        // The first space after `baz\\` is escaped and stays in the option;
        // the second one is the separator before `lol`.
        let params = make_params("foo\\ bar \\ \\\\ baz\\  lol");
        assert_eq!(split_options(&params), ["foo bar", " \\", "baz ", "lol"]);
    }

    // Make sure that `read` is sync/async callable
    async fn _assert(stream: &mut (impl tokio::io::AsyncRead + Unpin)) {
        let _ = FeMessage::read(&mut [].as_ref());
        let _ = FeMessage::read_fut(stream).await;

        let _ = FeStartupPacket::read(&mut [].as_ref());
        let _ = FeStartupPacket::read_fut(stream).await;
    }
}
|