Files
neon/libs/pq_proto/src/lib.rs
2022-11-03 22:56:04 +03:00

1082 lines
36 KiB
Rust

//! Postgres protocol messages serialization-deserialization. See
//! <https://www.postgresql.org/docs/devel/protocol-message-formats.html>
//! on message formats.
// Tools for calling certain async methods in sync contexts.
pub mod sync;
use anyhow::{bail, ensure, Context, Result};
use bytes::{Buf, BufMut, Bytes, BytesMut};
use postgres_protocol::PG_EPOCH;
use serde::{Deserialize, Serialize};
use std::{
borrow::Cow,
collections::HashMap,
fmt,
future::Future,
io::{self, Cursor},
str,
time::{Duration, SystemTime},
};
use sync::{AsyncishRead, SyncFuture};
use tokio::io::AsyncReadExt;
use tracing::{trace, warn};
/// Object identifier, represented as a 32-bit unsigned integer.
pub type Oid = u32;
/// 64-bit system identifier; presumably the PostgreSQL system id — confirm against callers.
pub type SystemId = u64;
/// Type OID used by [`RowDescriptor::int8_col`].
pub const INT8_OID: Oid = 20;
/// Type OID for int4 values, going by the constant's name.
pub const INT4_OID: Oid = 23;
/// Type OID used by [`RowDescriptor::text_col`].
pub const TEXT_OID: Oid = 25;
/// Messages sent by the frontend (client) that this crate can decode.
///
/// Regular messages are decoded by `FeMessage::read_fut`, which dispatches on
/// the one-byte message tag; the startup packet has no tag and is decoded by
/// `FeStartupPacket::read_fut`.
#[derive(Debug)]
pub enum FeMessage {
/// Untagged first packet of a connection.
StartupPacket(FeStartupPacket),
// Simple query.
/// Tag `Q`; carries the raw message body.
Query(Bytes),
// Extended query protocol.
/// Tag `P`.
Parse(FeParseMessage),
/// Tag `D`.
Describe(FeDescribeMessage),
/// Tag `B`.
Bind(FeBindMessage),
/// Tag `E`.
Execute(FeExecuteMessage),
/// Tag `C`.
Close(FeCloseMessage),
/// Tag `S`; the body is ignored.
Sync,
/// Tag `X`; the body is ignored.
Terminate,
/// Tag `d`; carries the raw message body.
CopyData(Bytes),
/// Tag `c`; the body is ignored.
CopyDone,
/// Tag `f`; the body is ignored.
CopyFail,
/// Tag `p`; carries the raw message body.
PasswordMessage(Bytes),
}
/// The untagged packet a client sends first on a new connection.
///
/// The variant is selected from the 32-bit request code that follows the
/// length word: its high 16 bits equal to 1234 mark a special request, and
/// the low 16 bits then choose between cancel (5678), SSL (5679) and GSS
/// (5680). Any other high half is treated as a protocol version.
#[derive(Debug)]
pub enum FeStartupPacket {
/// Request to cancel the query identified by the given key data.
CancelRequest(CancelKeyData),
/// Client asks to upgrade the connection to SSL (aka TLS).
SslRequest,
/// Client asks to upgrade the connection to GSSAPI encryption.
GssEncRequest,
/// Regular startup message: protocol version plus key/value parameters.
StartupMessage {
// High 16 bits of the request code.
major_version: u32,
// Low 16 bits of the request code.
minor_version: u32,
// Parsed pairs of null-terminated strings from the packet body.
params: StartupMessageParams,
},
}
/// Key/value parameters carried by a frontend `StartupMessage`.
#[derive(Debug)]
pub struct StartupMessageParams {
    params: HashMap<String, String>,
}

impl StartupMessageParams {
    /// Look up a parameter by name, returning `None` when it is absent.
    pub fn get(&self, name: &str) -> Option<&str> {
        self.params.get(name).map(String::as_str)
    }

    /// Tokenize the `options` parameter the way PostgreSQL does
    /// (see `postgres: pg_split_opts`), keeping escape sequences verbatim.
    /// Yields [`None`] when [`Self`] has no `options` entry.
    pub fn options_raw(&self) -> Option<impl Iterator<Item = &str>> {
        let options = self.get("options")?;
        // True while the previous character was a backslash that escapes
        // the current one.
        let mut escape_pending = false;
        let tokens = options
            .split(move |ch: char| {
                if escape_pending {
                    // An escaped character never separates tokens and never
                    // starts a new escape sequence.
                    escape_pending = false;
                    false
                } else {
                    escape_pending = ch == '\\';
                    ch.is_ascii_whitespace()
                }
            })
            .filter(|token| !token.is_empty());
        Some(tokens)
    }

    /// Tokenize the `options` parameter the way PostgreSQL does
    /// (see `postgres: pg_split_opts`) and resolve escape sequences,
    /// allocating only for tokens that actually contain a backslash.
    /// Yields [`None`] when [`Self`] has no `options` entry.
    pub fn options_escaped(&self) -> Option<impl Iterator<Item = Cow<'_, str>>> {
        let raw_tokens = self.options_raw()?;
        let unescaped = raw_tokens.map(|token| {
            if token.contains('\\') {
                // A backslash is dropped unless the one right before it was
                // dropped, in which case it is the escaped character itself.
                let mut previous_dropped = false;
                let stripped = token.replace(
                    |ch: char| {
                        if previous_dropped {
                            previous_dropped = false;
                            false
                        } else {
                            previous_dropped = ch == '\\';
                            previous_dropped
                        }
                    },
                    "",
                );
                Cow::Owned(stripped)
            } else {
                Cow::Borrowed(token)
            }
        });
        Some(unescaped)
    }

    // Constructor that is mostly handy in tests.
    #[doc(hidden)]
    pub fn new<'a, const N: usize>(pairs: [(&'a str, &'a str); N]) -> Self {
        let params = pairs
            .iter()
            .map(|&(key, value)| (key.to_owned(), value.to_owned()))
            .collect();
        Self { params }
    }
}
/// Pair of values identifying a backend for query cancellation.
#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy)]
pub struct CancelKeyData {
    pub backend_pid: i32,
    pub cancel_key: i32,
}

impl fmt::Display for CancelKeyData {
    /// Formats the key as `CancelKeyData(<hex>)`, where `<hex>` is the 64-bit
    /// value whose upper half is `backend_pid` and lower half is `cancel_key`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Reinterpret each signed half as its raw 32 bits before widening.
        // Widening the `i32` directly would sign-extend negative values, and
        // a negative `cancel_key` would then overwrite the `backend_pid` half.
        let hi = u64::from(self.backend_pid as u32) << 32;
        let lo = u64::from(self.cancel_key as u32);
        let id = hi | lo;
        // This format is more compact and might work better for logs.
        f.debug_tuple("CancelKeyData")
            .field(&format_args!("{:x}", id))
            .finish()
    }
}
use rand::distributions::{Distribution, Standard};
impl Distribution<CancelKeyData> for Standard {
    /// Draws both fields from `rng`, `backend_pid` first and `cancel_key` second.
    fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> CancelKeyData {
        let backend_pid: i32 = rng.gen();
        let cancel_key: i32 = rng.gen();
        CancelKeyData {
            backend_pid,
            cancel_key,
        }
    }
}
// We only support the simple case of Parse on unnamed prepared statement and
// no params
/// Decoded Parse message (tag `P`).
#[derive(Debug)]
pub struct FeParseMessage {
/// Query text as sent by the client, without its null terminator.
pub query_string: Bytes,
}
/// Decoded Describe message (tag `D`).
#[derive(Debug)]
pub struct FeDescribeMessage {
pub kind: u8, // 'S' to describe a prepared statement; or 'P' to describe a portal.
// we only support unnamed prepared stmt or portal
// Note: the parser in this file currently rejects every kind other than 'S'.
}
// we only support unnamed prepared stmt and portal
/// Decoded Bind message (tag `B`); carries no data because only the unnamed
/// portal is accepted by the parser.
#[derive(Debug)]
pub struct FeBindMessage;
// we only support unnamed prepared stmt or portal
/// Decoded Execute message (tag `E`).
#[derive(Debug)]
pub struct FeExecuteMessage {
/// max # of rows
// The parser in this file only accepts 0 here (row limits are not implemented).
pub maxrows: i32,
}
// we only support unnamed prepared stmt and portal
/// Decoded Close message (tag `C`); the parser discards the kind and name, so
/// nothing is stored.
#[derive(Debug)]
pub struct FeCloseMessage;
/// Retry a read on EINTR
///
/// This runs the enclosed expression, and if it returns
/// Err(io::ErrorKind::Interrupted), retries it.
///
/// The expression is re-evaluated from scratch on every retry, and the macro
/// evaluates to the first result that is not an `Interrupted` error (either
/// `Ok` or any other `Err`).
macro_rules! retry_read {
( $x:expr ) => {
loop {
match $x {
// Interrupted reads carry no data; simply try again.
Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
// Anything else, success or failure, is handed back to the caller.
res => break res,
}
}
};
}
impl FeMessage {
/// Read one message from the stream.
/// This function returns `Ok(None)` in case of EOF.
/// One way to handle this properly:
///
/// ```
/// # use std::io;
/// # use pq_proto::FeMessage;
/// #
/// # fn process_message(msg: FeMessage) -> anyhow::Result<()> {
/// # Ok(())
/// # };
/// #
/// fn do_the_job(stream: &mut (impl io::Read + Unpin)) -> anyhow::Result<()> {
/// while let Some(msg) = FeMessage::read(stream)? {
/// process_message(msg)?;
/// }
///
/// Ok(())
/// }
/// ```
///
/// This is the blocking counterpart of [`Self::read_fut`]: it wraps `stream`
/// in `AsyncishRead` and waits on the resulting `SyncFuture`.
#[inline(never)]
pub fn read(stream: &mut (impl io::Read + Unpin)) -> anyhow::Result<Option<FeMessage>> {
Self::read_fut(&mut AsyncishRead(stream)).wait()
}
/// Read one message from the stream.
/// See documentation for `Self::read`.
///
/// Returns `Ok(None)` only when EOF is hit while reading the tag byte; EOF
/// anywhere later in the message is reported as an error, as are a length
/// word smaller than 4, an unknown tag and a body its parser rejects.
pub fn read_fut<Reader>(
stream: &mut Reader,
) -> SyncFuture<Reader, impl Future<Output = anyhow::Result<Option<FeMessage>>> + '_>
where
Reader: tokio::io::AsyncRead + Unpin,
{
// We return a Future that's sync (has a `wait` method) if and only if the provided stream is SyncProof.
// SyncFuture contract: we are only allowed to await on sync-proof futures, the AsyncRead and
// AsyncReadExt methods of the stream.
SyncFuture::new(async move {
// Each libpq message begins with a message type byte, followed by message length
// If the client closes the connection, return None. But if the client closes the
// connection in the middle of a message, we will return an error.
let tag = match retry_read!(stream.read_u8().await) {
Ok(b) => b,
Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
Err(e) => return Err(e.into()),
};
// The message length includes itself, so it better be at least 4.
let len = retry_read!(stream.read_u32().await)?
.checked_sub(4)
.context("invalid message length")?;
// NOTE(review): `len` comes straight from the peer and is not capped here,
// so a single message can request a buffer of up to ~4 GiB.
let body = {
let mut buffer = vec![0u8; len as usize];
stream.read_exact(&mut buffer).await?;
Bytes::from(buffer)
};
// Dispatch on the tag. Messages without a dedicated parser either keep the
// raw body (Query, CopyData, PasswordMessage) or drop it entirely.
match tag {
b'Q' => Ok(Some(FeMessage::Query(body))),
b'P' => Ok(Some(FeParseMessage::parse(body)?)),
b'D' => Ok(Some(FeDescribeMessage::parse(body)?)),
b'E' => Ok(Some(FeExecuteMessage::parse(body)?)),
b'B' => Ok(Some(FeBindMessage::parse(body)?)),
b'C' => Ok(Some(FeCloseMessage::parse(body)?)),
b'S' => Ok(Some(FeMessage::Sync)),
b'X' => Ok(Some(FeMessage::Terminate)),
b'd' => Ok(Some(FeMessage::CopyData(body))),
b'c' => Ok(Some(FeMessage::CopyDone)),
b'f' => Ok(Some(FeMessage::CopyFail)),
b'p' => Ok(Some(FeMessage::PasswordMessage(body))),
tag => bail!("unknown message tag: {},'{:?}'", tag, body),
}
})
}
}
impl FeStartupPacket {
    /// Read startup message from the stream.
    ///
    /// Blocking counterpart of [`Self::read_fut`]; see it for the details of
    /// what is returned.
    // XXX: It's tempting yet undesirable to accept `stream` by value,
    // since such a change will cause user-supplied &mut references to be consumed
    pub fn read(stream: &mut (impl io::Read + Unpin)) -> anyhow::Result<Option<FeMessage>> {
        Self::read_fut(&mut AsyncishRead(stream)).wait()
    }

    /// Read startup message from the stream.
    ///
    /// Returns `Ok(None)` if the connection is closed before the 4-byte
    /// length word could be read, `Ok(Some(FeMessage::StartupPacket(..)))` on
    /// success, and an error when the packet is truncated, has a length
    /// outside of `8..=10000`, carries an unrecognized special request code,
    /// or has malformed parameters.
    // XXX: It's tempting yet undesirable to accept `stream` by value,
    // since such a change will cause user-supplied &mut references to be consumed
    pub fn read_fut<Reader>(
        stream: &mut Reader,
    ) -> SyncFuture<Reader, impl Future<Output = anyhow::Result<Option<FeMessage>>> + '_>
    where
        Reader: tokio::io::AsyncRead + Unpin,
    {
        const MAX_STARTUP_PACKET_LENGTH: usize = 10000;
        // Every startup packet starts with a 4-byte length (which counts
        // itself) followed by a 4-byte request code.
        const STARTUP_PACKET_HEADER_LENGTH: usize = 8;
        const RESERVED_INVALID_MAJOR_VERSION: u32 = 1234;
        const CANCEL_REQUEST_CODE: u32 = 5678;
        const NEGOTIATE_SSL_CODE: u32 = 5679;
        const NEGOTIATE_GSS_CODE: u32 = 5680;

        SyncFuture::new(async move {
            // Read length. If the connection is closed before reading anything (or before
            // reading 4 bytes, to be precise), return None to indicate that the connection
            // was closed. This matches the PostgreSQL server's behavior, which avoids noise
            // in the log if the client opens connection but closes it immediately.
            let len = match retry_read!(stream.read_u32().await) {
                Ok(len) => len as usize,
                Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
                Err(e) => return Err(e.into()),
            };

            // The length must cover at least the header: anything shorter would
            // make the subtraction below underflow (a panic in debug builds and
            // an enormous allocation in release builds).
            // The range-based spelling suggested by clippy is less readable here.
            #[allow(clippy::manual_range_contains)]
            if len < STARTUP_PACKET_HEADER_LENGTH || len > MAX_STARTUP_PACKET_LENGTH {
                bail!("invalid message length");
            }

            let request_code = retry_read!(stream.read_u32().await)?;

            // the rest of startup packet are params
            let params_len = len - STARTUP_PACKET_HEADER_LENGTH;
            let mut params_bytes = vec![0u8; params_len];
            stream.read_exact(params_bytes.as_mut()).await?;

            // Parse params depending on request code
            let req_hi = request_code >> 16;
            let req_lo = request_code & ((1 << 16) - 1);
            let message = match (req_hi, req_lo) {
                (RESERVED_INVALID_MAJOR_VERSION, CANCEL_REQUEST_CODE) => {
                    ensure!(params_len == 8, "expected 8 bytes for CancelRequest params");
                    let mut cursor = Cursor::new(params_bytes);
                    FeStartupPacket::CancelRequest(CancelKeyData {
                        backend_pid: cursor.read_i32().await?,
                        cancel_key: cursor.read_i32().await?,
                    })
                }
                (RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_SSL_CODE) => {
                    // Requested upgrade to SSL (aka TLS)
                    FeStartupPacket::SslRequest
                }
                (RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_GSS_CODE) => {
                    // Requested upgrade to GSSAPI
                    FeStartupPacket::GssEncRequest
                }
                (RESERVED_INVALID_MAJOR_VERSION, unrecognized_code) => {
                    bail!("Unrecognized request code {}", unrecognized_code)
                }
                // TODO bail if protocol major_version is not 3?
                (major_version, minor_version) => {
                    // Parse pairs of null-terminated strings (key, value).
                    // See `postgres: ProcessStartupPacket, build_startup_packet`.
                    let mut tokens = str::from_utf8(&params_bytes)
                        .context("StartupMessage params: invalid utf-8")?
                        .strip_suffix('\0') // drop packet's own null terminator
                        .context("StartupMessage params: missing null terminator")?
                        .split_terminator('\0');

                    let mut params = HashMap::new();
                    while let Some(name) = tokens.next() {
                        let value = tokens
                            .next()
                            .context("StartupMessage params: key without value")?;
                        params.insert(name.to_owned(), value.to_owned());
                    }

                    FeStartupPacket::StartupMessage {
                        major_version,
                        minor_version,
                        params: StartupMessageParams { params },
                    }
                }
            };

            Ok(Some(FeMessage::StartupPacket(message)))
        })
    }
}
impl FeParseMessage {
    /// Decode the body of a Parse message.
    ///
    /// Returns an error when either string lacks its null terminator, when
    /// the parameter count is missing, or when it is non-zero (query
    /// parameters are not implemented).
    fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
        // FIXME: the rust-postgres driver uses a named prepared statement
        // for copy_out(). We're not prepared to handle that correctly. For
        // now, just ignore the statement name, assuming that the client never
        // uses more than one prepared statement at a time.
        let _pstmt_name = read_cstr(&mut buf)?;
        let query_string = read_cstr(&mut buf)?;
        // `get_i16` panics on a short buffer; the body comes from the client,
        // so report truncation as an error instead.
        ensure!(
            buf.remaining() >= 2,
            "Parse message is malformed, nparams missing"
        );
        let nparams = buf.get_i16();
        ensure!(nparams == 0, "query params not implemented");
        Ok(FeMessage::Parse(FeParseMessage { query_string }))
    }
}
impl FeDescribeMessage {
    /// Decode the body of a Describe message.
    ///
    /// Returns an error when the body is empty, when the name lacks its null
    /// terminator, or when the kind is anything other than `b'S'`.
    fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
        // `get_u8` panics on an empty buffer; the body comes from the client,
        // so report truncation as an error instead.
        ensure!(
            buf.has_remaining(),
            "Describe message is malformed, kind missing"
        );
        let kind = buf.get_u8();
        let _pstmt_name = read_cstr(&mut buf)?;
        // FIXME: see FeParseMessage::parse
        ensure!(
            kind == b'S',
            "only prepared statemement Describe is implemented"
        );
        Ok(FeMessage::Describe(FeDescribeMessage { kind }))
    }
}
impl FeExecuteMessage {
    /// Decode the body of an Execute message.
    ///
    /// Returns an error when the portal name lacks its null terminator, when
    /// the row limit is missing, when a named portal is used, or when the row
    /// limit is non-zero.
    fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
        let portal_name = read_cstr(&mut buf)?;
        // `get_i32` panics on a short buffer; the body comes from the client,
        // so report truncation as an error instead.
        ensure!(
            buf.remaining() >= 4,
            "Execute message is malformed, maxrows missing"
        );
        let maxrows = buf.get_i32();
        ensure!(portal_name.is_empty(), "named portals not implemented");
        ensure!(maxrows == 0, "row limit in Execute message not implemented");
        Ok(FeMessage::Execute(FeExecuteMessage { maxrows }))
    }
}
impl FeBindMessage {
fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
let portal_name = read_cstr(&mut buf)?;
let _pstmt_name = read_cstr(&mut buf)?;
// FIXME: see FeParseMessage::parse
ensure!(portal_name.is_empty(), "named portals not implemented");
Ok(FeMessage::Bind(FeBindMessage))
}
}
impl FeCloseMessage {
fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
let _kind = buf.get_u8();
let _pstmt_or_portal_name = read_cstr(&mut buf)?;
// FIXME: we do nothing with Close
Ok(FeMessage::Close(FeCloseMessage))
}
}
// Backend
/// Messages the backend (server side) can send; serialized by `BeMessage::write`.
///
/// Payloads are borrowed for `'a`, so building a message does not copy data.
#[derive(Debug)]
pub enum BeMessage<'a> {
AuthenticationOk,
/// Carries the 4-byte salt for MD5 authentication.
AuthenticationMD5Password([u8; 4]),
AuthenticationSasl(BeAuthenticationSaslMessage<'a>),
AuthenticationCleartextPassword,
BackendKeyData(CancelKeyData),
BindComplete,
/// Command tag; written as a null-terminated string.
CommandComplete(&'a [u8]),
CopyData(&'a [u8]),
CopyDone,
CopyFail,
CopyInResponse,
CopyOutResponse,
CopyBothResponse,
CloseComplete,
// None means column is NULL
DataRow(&'a [Option<&'a [u8]>]),
/// Error text; severity and SQLSTATE are filled in by `write`.
ErrorResponse(&'a str),
/// Single byte - used in response to SSLRequest/GSSENCRequest.
EncryptionResponse(bool),
NoData,
ParameterDescription,
ParameterStatus(BeParameterStatusMessage<'a>),
ParseComplete,
ReadyForQuery,
RowDescription(&'a [RowDescriptor<'a>]),
/// Written as a `d` message whose body starts with `w`.
XLogData(XLogDataBody<'a>),
/// Notice text; severity and SQLSTATE are filled in by `write`.
NoticeResponse(&'a str),
/// Written as a `d` message whose body starts with `k`.
KeepAlive(WalSndKeepAlive),
}
/// The three SASL authentication messages, distinguished on the wire by the
/// authentication code that `BeMessage::write` puts in front of the payload.
#[derive(Debug)]
pub enum BeAuthenticationSaslMessage<'a> {
/// Code 10: list of mechanism names, each null-terminated, plus a final zero byte.
Methods(&'a [&'a str]),
/// Code 11: opaque data continuing the exchange.
Continue(&'a [u8]),
/// Code 12: opaque data finishing the exchange.
Final(&'a [u8]),
}
/// Run-time parameters that can be reported to the client; each variant holds
/// the value, while the parameter name is chosen by `BeMessage::write`.
#[derive(Debug)]
pub enum BeParameterStatusMessage<'a> {
/// Reported under the name `client_encoding`.
Encoding(&'a str),
/// Reported under the name `server_version`.
ServerVersion(&'a str),
}
impl BeParameterStatusMessage<'static> {
/// Ready-made `ParameterStatus` message announcing the `UTF8` encoding.
pub fn encoding() -> BeMessage<'static> {
BeMessage::ParameterStatus(Self::Encoding("UTF8"))
}
}
// One row description in RowDescription packet.
//
// NOTE(review): `BeMessage::write` currently serializes only `name`, `typoid`
// and `typlen`; for the remaining fields it writes fixed values (0 for
// tableoid/attnum/formatcode, -1 for typmod) regardless of what is stored here.
#[derive(Debug)]
pub struct RowDescriptor<'a> {
// Column name; written as a null-terminated string.
pub name: &'a [u8],
pub tableoid: Oid,
pub attnum: i16,
// OID of the column's data type.
pub typoid: Oid,
// Size of the data type; the text helper uses -1 here.
pub typlen: i16,
pub typmod: i32,
pub formatcode: i16,
}
impl Default for RowDescriptor<'_> {
/// Descriptor with an empty name and every numeric field set to zero.
fn default() -> RowDescriptor<'static> {
RowDescriptor {
name: b"",
tableoid: 0,
attnum: 0,
typoid: 0,
typlen: 0,
typmod: 0,
formatcode: 0,
}
}
}
impl RowDescriptor<'_> {
/// Convenience function to create a RowDescriptor message for an int8 column
pub const fn int8_col(name: &[u8]) -> RowDescriptor {
RowDescriptor {
name,
tableoid: 0,
attnum: 0,
typoid: INT8_OID,
typlen: 8,
typmod: 0,
formatcode: 0,
}
}
pub const fn text_col(name: &[u8]) -> RowDescriptor {
RowDescriptor {
name,
tableoid: 0,
attnum: 0,
typoid: TEXT_OID,
typlen: -1,
typmod: 0,
formatcode: 0,
}
}
}
/// Payload of `BeMessage::XLogData`: the three 64-bit header fields are written
/// in declaration order, followed by the raw `data` bytes.
#[derive(Debug)]
pub struct XLogDataBody<'a> {
pub wal_start: u64,
pub wal_end: u64,
// NOTE(review): unit and epoch are not visible here — confirm with the sender.
pub timestamp: i64,
pub data: &'a [u8],
}
/// Payload of `BeMessage::KeepAlive`; fields are written in declaration order,
/// with `request_reply` encoded as a single byte (1 or 0).
#[derive(Debug)]
pub struct WalSndKeepAlive {
pub sent_ptr: u64,
// NOTE(review): unit and epoch are not visible here — confirm with the sender.
pub timestamp: i64,
pub request_reply: bool,
}
/// A `DataRow` with one non-NULL column containing `hello world`.
pub static HELLO_WORLD_ROW: BeMessage = BeMessage::DataRow(&[Some(b"hello world")]);
// single text column
/// A `RowDescription` with one text column named `data`.
pub static SINGLE_COL_ROWDESC: BeMessage = BeMessage::RowDescription(&[RowDescriptor {
name: b"data",
tableoid: 0,
attnum: 0,
typoid: TEXT_OID,
typlen: -1,
typmod: 0,
formatcode: 0,
}]);
/// Run `f` to append a message body to `buf`, framing it with the 4-byte
/// big-endian length the protocol requires. The length counts itself.
///
/// Whatever `f` returns is passed through to the caller.
///
/// # Panics
///
/// Panics if the length does not fit into an `i32`.
fn write_body<R>(buf: &mut BytesMut, f: impl FnOnce(&mut BytesMut) -> R) -> R {
    // Reserve room for the length; it is only known once the body is written.
    let len_start = buf.len();
    buf.put_slice(&[0u8; 4]);

    let outcome = f(buf);

    let framed_len = buf.len() - len_start;
    let size = i32::try_from(framed_len).expect("message too big to transmit");
    buf[len_start..len_start + 4].copy_from_slice(&size.to_be_bytes());
    outcome
}
/// Append `s` to `buf` as a protocol String, i.e. followed by a zero byte.
///
/// Fails with `InvalidInput`, writing nothing, when `s` itself contains a
/// zero byte.
fn write_cstr(s: &[u8], buf: &mut BytesMut) -> Result<(), io::Error> {
    let has_embedded_null = s.iter().any(|&byte| byte == 0);
    if has_embedded_null {
        let err = io::Error::new(
            io::ErrorKind::InvalidInput,
            "string contains embedded null",
        );
        return Err(err);
    }
    buf.extend_from_slice(s);
    buf.put_u8(0);
    Ok(())
}
/// Split a null-terminated string off the front of `buf`.
///
/// Returns the bytes before the first zero byte and leaves `buf` positioned
/// just past that terminator. Fails when `buf` contains no zero byte.
fn read_cstr(buf: &mut Bytes) -> anyhow::Result<Bytes> {
    let terminator_at = buf
        .iter()
        .position(|&byte| byte == 0)
        .context("missing terminator")?;
    let text = buf.split_to(terminator_at);
    // Skip over the terminator itself.
    buf.advance(1);
    Ok(text)
}
impl<'a> BeMessage<'a> {
    /// Write message to the given buf.
    ///
    /// Every message except `EncryptionResponse` is written as a one-byte tag
    /// followed by a length-prefixed body (see `write_body`).
    ///
    /// # Errors
    ///
    /// Returns `InvalidInput` when a value that must be sent as a
    /// null-terminated string contains an embedded zero byte. In that case
    /// `buf` may already contain part of the message.
    // Unlike the reading side, we use BytesMut
    // here as msg len precedes its body and it is handy to write it down first
    // and then fill the length. With Write we would have to either calc it
    // manually or have one more buffer.
    pub fn write(buf: &mut BytesMut, message: &BeMessage) -> io::Result<()> {
        match message {
            BeMessage::AuthenticationOk => {
                buf.put_u8(b'R');
                write_body(buf, |buf| {
                    buf.put_i32(0); // Specifies that the authentication was successful.
                });
            }
            BeMessage::AuthenticationCleartextPassword => {
                buf.put_u8(b'R');
                write_body(buf, |buf| {
                    buf.put_i32(3); // Specifies that clear text password is required.
                });
            }
            BeMessage::AuthenticationMD5Password(salt) => {
                buf.put_u8(b'R');
                write_body(buf, |buf| {
                    buf.put_i32(5); // Specifies that an MD5-encrypted password is required.
                    buf.put_slice(&salt[..]);
                });
            }
            BeMessage::AuthenticationSasl(msg) => {
                buf.put_u8(b'R');
                write_body(buf, |buf| {
                    use BeAuthenticationSaslMessage::*;
                    match msg {
                        Methods(methods) => {
                            buf.put_i32(10); // Specifies that SASL auth method is used.
                            for method in methods.iter() {
                                write_cstr(method.as_bytes(), buf)?;
                            }
                            buf.put_u8(0); // zero terminator for the list
                        }
                        Continue(extra) => {
                            buf.put_i32(11); // Continue SASL auth.
                            buf.put_slice(extra);
                        }
                        Final(extra) => {
                            buf.put_i32(12); // Send final SASL message.
                            buf.put_slice(extra);
                        }
                    }
                    Ok::<_, io::Error>(())
                })?;
            }
            BeMessage::BackendKeyData(key_data) => {
                buf.put_u8(b'K');
                write_body(buf, |buf| {
                    buf.put_i32(key_data.backend_pid);
                    buf.put_i32(key_data.cancel_key);
                });
            }
            BeMessage::BindComplete => {
                buf.put_u8(b'2');
                write_body(buf, |_| {});
            }
            BeMessage::CloseComplete => {
                buf.put_u8(b'3');
                write_body(buf, |_| {});
            }
            BeMessage::CommandComplete(cmd) => {
                buf.put_u8(b'C');
                write_body(buf, |buf| write_cstr(cmd, buf))?;
            }
            BeMessage::CopyData(data) => {
                buf.put_u8(b'd');
                write_body(buf, |buf| {
                    buf.put_slice(data);
                });
            }
            BeMessage::CopyDone => {
                buf.put_u8(b'c');
                write_body(buf, |_| {});
            }
            BeMessage::CopyFail => {
                buf.put_u8(b'f');
                write_body(buf, |_| {});
            }
            BeMessage::CopyInResponse => {
                buf.put_u8(b'G');
                write_body(buf, |buf| {
                    buf.put_u8(1); // copy_is_binary
                    buf.put_i16(0); // numAttributes
                });
            }
            BeMessage::CopyOutResponse => {
                buf.put_u8(b'H');
                write_body(buf, |buf| {
                    buf.put_u8(0); // copy_is_binary
                    buf.put_i16(0); // numAttributes
                });
            }
            BeMessage::CopyBothResponse => {
                buf.put_u8(b'W');
                write_body(buf, |buf| {
                    // doesn't matter, used only for replication
                    buf.put_u8(0); // copy_is_binary
                    buf.put_i16(0); // numAttributes
                });
            }
            BeMessage::DataRow(vals) => {
                buf.put_u8(b'D');
                write_body(buf, |buf| {
                    // NOTE(review): both casts below truncate silently if a row has
                    // more than u16::MAX columns or a value longer than u32::MAX.
                    buf.put_u16(vals.len() as u16); // num of cols
                    for val_opt in vals.iter() {
                        if let Some(val) = val_opt {
                            buf.put_u32(val.len() as u32);
                            buf.put_slice(val);
                        } else {
                            buf.put_i32(-1);
                        }
                    }
                });
            }
            // ErrorResponse is a zero-terminated array of zero-terminated fields.
            // First byte of each field represents type of this field. Set just enough fields
            // to satisfy rust-postgres client: 'S' -- severity, 'C' -- error, 'M' -- error
            // message text.
            BeMessage::ErrorResponse(error_msg) => {
                // For all the errors set Severity to Error and error code to
                // 'internal error'.
                // 'E' signalizes ErrorResponse messages
                buf.put_u8(b'E');
                write_body(buf, |buf| {
                    buf.put_u8(b'S'); // severity
                    buf.put_slice(b"ERROR\0");
                    buf.put_u8(b'C'); // SQLSTATE error code
                    // SQLSTATE codes are exactly five characters; the 'C' field
                    // tag written above is not part of the code itself.
                    buf.put_slice(b"XX000\0");
                    buf.put_u8(b'M'); // the message
                    write_cstr(error_msg.as_bytes(), buf)?;
                    buf.put_u8(0); // terminator
                    Ok::<_, io::Error>(())
                })?;
            }
            // NoticeResponse has the same format as ErrorResponse. From doc: "The frontend should display the
            // message but continue listening for ReadyForQuery or ErrorResponse"
            BeMessage::NoticeResponse(error_msg) => {
                // For all the notices set Severity to Notice and error code to
                // 'internal error'.
                // 'N' signalizes NoticeResponse messages
                buf.put_u8(b'N');
                write_body(buf, |buf| {
                    buf.put_u8(b'S'); // severity
                    buf.put_slice(b"NOTICE\0");
                    buf.put_u8(b'C'); // SQLSTATE error code
                    // SQLSTATE codes are exactly five characters; the 'C' field
                    // tag written above is not part of the code itself.
                    buf.put_slice(b"XX000\0");
                    buf.put_u8(b'M'); // the message
                    write_cstr(error_msg.as_bytes(), buf)?;
                    buf.put_u8(0); // terminator
                    Ok::<_, io::Error>(())
                })?;
            }
            BeMessage::NoData => {
                buf.put_u8(b'n');
                write_body(buf, |_| {});
            }
            BeMessage::EncryptionResponse(should_negotiate) => {
                // The only message without a length: a single answer byte.
                let response = if *should_negotiate { b'S' } else { b'N' };
                buf.put_u8(response);
            }
            BeMessage::ParameterStatus(param) => {
                use BeParameterStatusMessage::*;
                let (name, value): (&[u8], &[u8]) = match param {
                    Encoding(encoding) => (&b"client_encoding"[..], encoding.as_bytes()),
                    ServerVersion(version) => (&b"server_version"[..], version.as_bytes()),
                };
                buf.put_u8(b'S');
                // Parameter names and values are passed as null-terminated strings.
                // Writing them straight into `buf` handles values of any length and
                // rejects embedded zero bytes, which would corrupt the framing.
                write_body(buf, |buf| {
                    write_cstr(name, buf)?;
                    write_cstr(value, buf)
                })?;
            }
            BeMessage::ParameterDescription => {
                buf.put_u8(b't');
                write_body(buf, |buf| {
                    // we don't support params, so always 0
                    buf.put_i16(0);
                });
            }
            BeMessage::ParseComplete => {
                buf.put_u8(b'1');
                write_body(buf, |_| {});
            }
            BeMessage::ReadyForQuery => {
                buf.put_u8(b'Z');
                write_body(buf, |buf| {
                    buf.put_u8(b'I');
                });
            }
            BeMessage::RowDescription(rows) => {
                buf.put_u8(b'T');
                write_body(buf, |buf| {
                    buf.put_i16(rows.len() as i16); // # of fields
                    for row in rows.iter() {
                        write_cstr(row.name, buf)?;
                        // Only name, typoid and typlen are taken from the descriptor;
                        // the other fields are sent as fixed values.
                        buf.put_i32(0); /* table oid */
                        buf.put_i16(0); /* attnum */
                        buf.put_u32(row.typoid);
                        buf.put_i16(row.typlen);
                        buf.put_i32(-1); /* typmod */
                        buf.put_i16(0); /* format code */
                    }
                    Ok::<_, io::Error>(())
                })?;
            }
            BeMessage::XLogData(body) => {
                buf.put_u8(b'd');
                write_body(buf, |buf| {
                    buf.put_u8(b'w');
                    buf.put_u64(body.wal_start);
                    buf.put_u64(body.wal_end);
                    buf.put_i64(body.timestamp);
                    buf.put_slice(body.data);
                });
            }
            BeMessage::KeepAlive(req) => {
                buf.put_u8(b'd');
                write_body(buf, |buf| {
                    buf.put_u8(b'k');
                    buf.put_u64(req.sent_ptr);
                    buf.put_i64(req.timestamp);
                    buf.put_u8(if req.request_reply { 1 } else { 0 });
                });
            }
        }
        Ok(())
    }
}
// Neon extension of postgres replication protocol
// See NEON_STATUS_UPDATE_TAG_BYTE
//
// Serialized with a custom key/value layout by `ReplicationFeedback::serialize`
// and read back by `ReplicationFeedback::parse`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct ReplicationFeedback {
// Last known size of the timeline. Used to enforce timeline size limit.
pub current_timeline_size: u64,
// Parts of StandbyStatusUpdate we resend to compute via safekeeper
pub ps_writelsn: u64,
pub ps_applylsn: u64,
pub ps_flushlsn: u64,
// Sent on the wire as a signed count of microseconds relative to PG_EPOCH.
pub ps_replytime: SystemTime,
}
// NOTE: Do not forget to increment this number when adding new fields to ReplicationFeedback.
// Do not remove previously available fields because this might be backwards incompatible.
pub const REPLICATION_FEEDBACK_FIELDS_NUMBER: u8 = 5;
impl ReplicationFeedback {
    /// Feedback with all numeric fields zeroed and `ps_replytime` set to the
    /// current system time.
    pub fn empty() -> ReplicationFeedback {
        ReplicationFeedback {
            current_timeline_size: 0,
            ps_writelsn: 0,
            ps_applylsn: 0,
            ps_flushlsn: 0,
            ps_replytime: SystemTime::now(),
        }
    }

    // Serialize ReplicationFeedback using custom format
    // to support protocol extensibility.
    //
    // Following layout is used:
    // char - number of key-value pairs that follow.
    //
    // key-value pairs:
    // null-terminated string - key,
    // uint32 - value length in bytes
    // value itself
    //
    // Returns an error, leaving `buf` untouched, when `ps_replytime` is
    // earlier than PG_EPOCH or too far in the future to be expressed as a
    // 64-bit number of microseconds.
    pub fn serialize(&self, buf: &mut BytesMut) -> Result<()> {
        // Convert the timestamp before writing anything, so that a failure
        // does not leave a half-written message in `buf`.
        let since_epoch = self
            .ps_replytime
            .duration_since(*PG_EPOCH)
            .context("failed to serialize pg_replytime earlier than PG_EPOCH")?;
        let timestamp = i64::try_from(since_epoch.as_micros())
            .context("failed to serialize pg_replytime: microseconds do not fit into i64")?;

        buf.put_u8(REPLICATION_FEEDBACK_FIELDS_NUMBER); // # of keys

        buf.put_slice(b"current_timeline_size\0");
        buf.put_i32(8);
        buf.put_u64(self.current_timeline_size);

        buf.put_slice(b"ps_writelsn\0");
        buf.put_i32(8);
        buf.put_u64(self.ps_writelsn);

        buf.put_slice(b"ps_flushlsn\0");
        buf.put_i32(8);
        buf.put_u64(self.ps_flushlsn);

        buf.put_slice(b"ps_applylsn\0");
        buf.put_i32(8);
        buf.put_u64(self.ps_applylsn);

        buf.put_slice(b"ps_replytime\0");
        buf.put_i32(8);
        buf.put_i64(timestamp);
        Ok(())
    }

    // Deserialize ReplicationFeedback message
    //
    // Fields missing from the message keep their `empty()` values; unknown
    // keys are logged and skipped using their declared length.
    //
    // Panics if the buffer is truncated, a key lacks its null terminator, a
    // known key has a length other than 8, or an unknown key declares a
    // negative length.
    pub fn parse(mut buf: Bytes) -> ReplicationFeedback {
        let mut rf = ReplicationFeedback::empty();
        let nfields = buf.get_u8();
        for _ in 0..nfields {
            let key = read_cstr(&mut buf).unwrap();
            // Every pair has the same framing: key, value length, value.
            let len = buf.get_i32();
            match key.as_ref() {
                b"current_timeline_size" => {
                    assert_eq!(len, 8);
                    rf.current_timeline_size = buf.get_u64();
                }
                b"ps_writelsn" => {
                    assert_eq!(len, 8);
                    rf.ps_writelsn = buf.get_u64();
                }
                b"ps_flushlsn" => {
                    assert_eq!(len, 8);
                    rf.ps_flushlsn = buf.get_u64();
                }
                b"ps_applylsn" => {
                    assert_eq!(len, 8);
                    rf.ps_applylsn = buf.get_u64();
                }
                b"ps_replytime" => {
                    assert_eq!(len, 8);
                    let raw_time = buf.get_i64();
                    // `unsigned_abs` cannot overflow, unlike negating `i64::MIN`.
                    let offset = Duration::from_micros(raw_time.unsigned_abs());
                    rf.ps_replytime = if raw_time > 0 {
                        *PG_EPOCH + offset
                    } else {
                        *PG_EPOCH - offset
                    };
                }
                _ => {
                    warn!(
                        "ReplicationFeedback parse. unknown key {} of len {len}. Skip it.",
                        String::from_utf8_lossy(key.as_ref())
                    );
                    let skip = usize::try_from(len)
                        .expect("ReplicationFeedback field length must not be negative");
                    buf.advance(skip);
                }
            }
        }
        trace!("ReplicationFeedback parsed is {:?}", rf);
        rf
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A serialized feedback must parse back into an equal value.
    #[test]
    fn test_replication_feedback_serialization() {
        let mut rf = ReplicationFeedback::empty();
        // Fill rf with some values
        rf.current_timeline_size = 12345678;
        // Set rounded time to be able to compare it with deserialized value,
        // because it is truncated to microseconds during serialization.
        rf.ps_replytime = *PG_EPOCH + Duration::from_secs(100_000_000);
        let mut data = BytesMut::new();
        rf.serialize(&mut data).unwrap();

        let rf_parsed = ReplicationFeedback::parse(data.freeze());
        assert_eq!(rf, rf_parsed);
    }

    /// Keys the parser does not know about must be skipped without affecting
    /// the known fields.
    #[test]
    fn test_replication_feedback_unknown_key() {
        let mut rf = ReplicationFeedback::empty();
        // Fill rf with some values
        rf.current_timeline_size = 12345678;
        // Set rounded time to be able to compare it with deserialized value,
        // because it is truncated to microseconds during serialization.
        rf.ps_replytime = *PG_EPOCH + Duration::from_secs(100_000_000);
        let mut data = BytesMut::new();
        rf.serialize(&mut data).unwrap();

        // Add an extra field to the buffer and adjust number of keys
        if let Some(first) = data.first_mut() {
            *first = REPLICATION_FEEDBACK_FIELDS_NUMBER + 1;
        }
        data.put_slice(b"new_field_one\0");
        data.put_i32(8);
        data.put_u64(42);

        // Parse serialized data and check that new field is not parsed
        let rf_parsed = ReplicationFeedback::parse(data.freeze());
        assert_eq!(rf, rf_parsed);
    }

    /// Splitting and unescaping of the `options` startup parameter.
    #[test]
    fn test_startup_message_params_options_escaped() {
        fn split_options(params: &StartupMessageParams) -> Vec<Cow<'_, str>> {
            params
                .options_escaped()
                .expect("options are None")
                .collect()
        }

        let make_params = |options| StartupMessageParams::new([("options", options)]);

        let params = StartupMessageParams::new([]);
        assert!(matches!(params.options_escaped(), None));

        let params = make_params("");
        assert!(split_options(&params).is_empty());

        let params = make_params("foo");
        assert_eq!(split_options(&params), ["foo"]);

        // Runs of whitespace must not produce empty tokens.
        let params = make_params("  foo   bar  ");
        assert_eq!(split_options(&params), ["foo", "bar"]);

        // Note the two spaces after `baz\\`: the first one is escaped and stays
        // part of the token, the second one separates it from `lol`.
        let params = make_params("foo\\ bar \\ \\\\ baz\\  lol");
        assert_eq!(split_options(&params), ["foo bar", " \\", "baz ", "lol"]);
    }

    // Make sure that `read` is sync/async callable
    async fn _assert(stream: &mut (impl tokio::io::AsyncRead + Unpin)) {
        let _ = FeMessage::read(&mut [].as_ref());
        let _ = FeMessage::read_fut(stream).await;

        let _ = FeStartupPacket::read(&mut [].as_ref());
        let _ = FeStartupPacket::read_fut(stream).await;
    }
}