Consolidate PG proto parsing-deparsing and backend code.

Now postgres_backend communicates with the client, passing queries to the
provided handler; we have two currently, for wal_acceptor and pageserver.

Now BytesMut is again used for writing data to avoid manual message length
calculation.

ref #118
This commit is contained in:
Arseny Sher
2021-05-31 09:26:26 +03:00
committed by arssher
parent 2b0193e6bf
commit b2f51026aa
9 changed files with 1188 additions and 971 deletions

3
Cargo.lock generated
View File

@@ -2478,9 +2478,12 @@ dependencies = [
name = "zenith_utils"
version = "0.1.0"
dependencies = [
"anyhow",
"bincode",
"byteorder",
"bytes",
"hex-literal",
"log",
"serde",
"thiserror",
"workspace_hack",

File diff suppressed because it is too large Load Diff

View File

@@ -1,20 +1,18 @@
//! This module implements the replication protocol, starting with the
//! "START REPLICATION" message.
//! This module implements the streaming side of replication protocol, starting
//! with the "START REPLICATION" message.
use crate::pq_protocol::{BeMessage, FeMessage};
use crate::send_wal::SendWalConn;
use crate::send_wal::SendWalHandler;
use crate::timeline::{Timeline, TimelineTools};
use crate::WalAcceptorConf;
use anyhow::{anyhow, Result};
use bytes::{BufMut, Bytes, BytesMut};
use bytes::Bytes;
use log::*;
use postgres_ffi::xlog_utils::{get_current_timestamp, TimestampTz, XLogFileName, MAX_SEND_SIZE};
use regex::Regex;
use serde::{Deserialize, Serialize};
use std::cmp::min;
use std::fs::File;
use std::io::{BufReader, Read, Seek, SeekFrom, Write};
use std::net::{Shutdown, TcpStream};
use std::io::{BufReader, Read, Seek, SeekFrom};
use std::net::TcpStream;
use std::path::Path;
use std::sync::Arc;
use std::thread::sleep;
@@ -22,10 +20,9 @@ use std::time::Duration;
use std::{str, thread};
use zenith_utils::bin_ser::BeSer;
use zenith_utils::lsn::Lsn;
use zenith_utils::postgres_backend::PostgresBackend;
use zenith_utils::pq_proto::{BeMessage, FeMessage, XLogDataBody};
const XLOG_HDR_SIZE: usize = 1 + 8 * 3; /* 'w' + startPos + walEnd + timestamp */
const LIBPQ_HDR_SIZE: usize = 5; /* 1 byte with message type + 4 bytes length */
const LIBPQ_MSG_SIZE_OFFS: usize = 1;
pub const END_REPLICATION_MARKER: Lsn = Lsn::MAX;
type FullTransactionId = u64;
@@ -40,37 +37,16 @@ pub struct HotStandbyFeedback {
/// A network connection that's speaking the replication protocol.
pub struct ReplicationConn {
timeline: Option<Arc<Timeline>>,
/// Postgres connection, buffered input
///
/// This is an `Option` because we will spawn a background thread that will
/// `take` it from us.
stream_in: Option<BufReader<TcpStream>>,
/// Postgres connection, output
stream_out: TcpStream,
/// wal acceptor configuration
conf: WalAcceptorConf,
/// assigned application name
appname: Option<String>,
}
// Separate thread is reading keepalives from the socket. When main one
// finishes, tell it to get down by shutdowning the socket.
impl Drop for ReplicationConn {
fn drop(&mut self) {
let _res = self.stream_out.shutdown(Shutdown::Both);
}
}
impl ReplicationConn {
/// Create a new `SendWal`, consuming the `Connection`.
pub fn new(conn: SendWalConn) -> Self {
/// Create a new `ReplicationConn`
pub fn new(pgb: &mut PostgresBackend) -> Self {
Self {
timeline: conn.timeline,
stream_in: Some(conn.stream_in),
stream_out: conn.stream_out,
conf: conn.conf,
appname: None,
stream_in: pgb.take_stream_in(),
}
}
@@ -82,9 +58,9 @@ impl ReplicationConn {
// Wait for replica's feedback.
// We only handle `CopyData` messages. Anything else is ignored.
loop {
match FeMessage::read_from(&mut stream_in)? {
FeMessage::CopyData(m) => {
let feedback = HotStandbyFeedback::des(&m.body)?;
match FeMessage::read(&mut stream_in)? {
Some(FeMessage::CopyData(m)) => {
let feedback = HotStandbyFeedback::des(&m)?;
timeline.add_hs_feedback(feedback)
}
msg => {
@@ -128,9 +104,14 @@ impl ReplicationConn {
///
/// Handle START_REPLICATION replication command
///
pub fn run(&mut self, cmd: &Bytes) -> Result<()> {
pub fn run(
&mut self,
swh: &mut SendWalHandler,
pgb: &mut PostgresBackend,
cmd: &Bytes,
) -> Result<()> {
// spawn the background thread which receives HotStandbyFeedback messages.
let bg_timeline = Arc::clone(self.timeline.get());
let bg_timeline = Arc::clone(swh.timeline.get());
let bg_stream_in = self.stream_in.take().unwrap();
thread::spawn(move || {
@@ -143,7 +124,7 @@ impl ReplicationConn {
let mut wal_seg_size: usize;
loop {
wal_seg_size = self.timeline.get().get_info().server.wal_seg_size as usize;
wal_seg_size = swh.timeline.get().get_info().server.wal_seg_size as usize;
if wal_seg_size == 0 {
error!("Can not start replication before connecting to wal_proposer");
sleep(Duration::from_secs(1));
@@ -151,19 +132,17 @@ impl ReplicationConn {
break;
}
}
let (wal_end, timeline) = self.timeline.find_end_of_wal(&self.conf.data_dir, false);
let (wal_end, timeline) = swh.timeline.find_end_of_wal(&swh.conf.data_dir, false);
if start_pos == Lsn(0) {
start_pos = wal_end;
}
if stop_pos == Lsn(0) && self.appname == Some("wal_proposer_recovery".to_string()) {
if stop_pos == Lsn(0) && swh.appname == Some("wal_proposer_recovery".to_string()) {
stop_pos = wal_end;
}
info!("Start replication from {} till {}", start_pos, stop_pos);
let mut outbuf = BytesMut::new();
BeMessage::write(&mut outbuf, &BeMessage::Copy);
self.send(&outbuf)?;
outbuf.clear();
// switch to copy
pgb.write_message(&BeMessage::CopyBothResponse)?;
let mut end_pos: Lsn;
let mut wal_file: Option<File> = None;
@@ -179,7 +158,7 @@ impl ReplicationConn {
end_pos = stop_pos;
} else {
/* normal mode */
let timeline = self.timeline.get();
let timeline = swh.timeline.get();
end_pos = timeline.wait_for_lsn(start_pos);
}
if end_pos == END_REPLICATION_MARKER {
@@ -193,8 +172,8 @@ impl ReplicationConn {
// Open a new file.
let segno = start_pos.segment_number(wal_seg_size);
let wal_file_name = XLogFileName(timeline, segno, wal_seg_size);
let timeline_id = self.timeline.get().timelineid.to_string();
let wal_file_path = self.conf.data_dir.join(timeline_id).join(wal_file_name);
let timeline_id = swh.timeline.get().timelineid.to_string();
let wal_file_path = swh.conf.data_dir.join(timeline_id).join(wal_file_name);
Self::open_wal_file(&wal_file_path)?
}
};
@@ -207,32 +186,19 @@ impl ReplicationConn {
let send_size = min(send_size, wal_seg_size - xlogoff);
let send_size = min(send_size, MAX_SEND_SIZE);
let msg_size = LIBPQ_HDR_SIZE + XLOG_HDR_SIZE + send_size;
// Read some data from the file.
let mut file_buf = vec![0u8; send_size];
file.seek(SeekFrom::Start(xlogoff as u64))?;
file.read_exact(&mut file_buf)?;
// Write some data to the network socket.
// FIXME: turn these into structs.
// 'd' is CopyData;
// 'w' is "WAL records"
// https://www.postgresql.org/docs/9.1/protocol-message-formats.html
// src/backend/replication/walreceiver.c
outbuf.clear();
outbuf.put_u8(b'd');
outbuf.put_u32((msg_size - LIBPQ_MSG_SIZE_OFFS) as u32);
outbuf.put_u8(b'w');
outbuf.put_u64(start_pos.0);
outbuf.put_u64(end_pos.0);
outbuf.put_u64(get_current_timestamp());
pgb.write_message(&BeMessage::XLogData(XLogDataBody {
wal_start: start_pos.0,
wal_end: end_pos.0,
timestamp: get_current_timestamp(),
data: &file_buf,
}))?;
assert!(outbuf.len() + file_buf.len() == msg_size);
// This thread has exclusive access to the TcpStream, so it's fine
// to do this as two separate calls.
self.send(&outbuf)?;
self.send(&file_buf)?;
start_pos += send_size as u64;
debug!("Sent WAL to page server up to {}", end_pos);
@@ -245,10 +211,4 @@ impl ReplicationConn {
}
Ok(())
}
/// Send messages on the network.
fn send(&mut self, buf: &[u8]) -> Result<()> {
self.stream_out.write_all(buf.as_ref())?;
Ok(())
}
}

View File

@@ -1,111 +1,69 @@
//! This implements the libpq replication protocol between wal_acceptor
//! and replicas/pagers
//! Part of WAL acceptor pretending to be Postgres, streaming xlog to
//! pageserver/any other consumer.
//!
use crate::pq_protocol::{
BeMessage, FeMessage, FeStartupMessage, RowDescriptor, StartupRequestCode,
};
use crate::replication::ReplicationConn;
use crate::timeline::{Timeline, TimelineTools};
use crate::WalAcceptorConf;
use anyhow::{bail, Result};
use bytes::BytesMut;
use log::*;
use std::io::{BufReader, Write};
use std::net::{SocketAddr, TcpStream};
use bytes::Bytes;
use pageserver::ZTimelineId;
use std::str::FromStr;
use std::sync::Arc;
use zenith_utils::postgres_backend;
use zenith_utils::postgres_backend::PostgresBackend;
use zenith_utils::pq_proto::{BeMessage, FeStartupMessage, RowDescriptor};
/// A network connection that's speaking the libpq replication protocol.
pub struct SendWalConn {
pub timeline: Option<Arc<Timeline>>,
/// Postgres connection, buffered input
pub stream_in: BufReader<TcpStream>,
/// Postgres connection, output
pub stream_out: TcpStream,
/// The cached result of socket.peer_addr()
pub peer_addr: SocketAddr,
/// Handler for streaming WAL from acceptor
pub struct SendWalHandler {
/// wal acceptor configuration
pub conf: WalAcceptorConf,
/// assigned application name
appname: Option<String>,
pub appname: Option<String>,
pub timeline: Option<Arc<Timeline>>,
}
impl SendWalConn {
/// Create a new `SendWal`, consuming the `Connection`.
pub fn new(socket: TcpStream, conf: WalAcceptorConf) -> Result<Self> {
let peer_addr = socket.peer_addr()?;
let conn = SendWalConn {
timeline: None,
stream_in: BufReader::new(socket.try_clone()?),
stream_out: socket,
peer_addr,
conf,
appname: None,
};
Ok(conn)
impl postgres_backend::Handler for SendWalHandler {
fn startup(&mut self, _pgb: &mut PostgresBackend, sm: &FeStartupMessage) -> Result<()> {
match sm.params.get("ztimelineid") {
Some(ref ztimelineid) => {
let ztlid = ZTimelineId::from_str(ztimelineid)?;
self.timeline.set(ztlid)?;
}
_ => bail!("timelineid is required"),
}
if let Some(app_name) = sm.params.get("application_name") {
self.appname = Some(app_name.clone());
}
Ok(())
}
///
/// Send WAL to replica or WAL receiver using standard libpq replication protocol
///
pub fn run(mut self) -> Result<()> {
let peer_addr = self.peer_addr;
info!("WAL sender to {:?} is started", peer_addr);
// Handle the startup message first.
let m = FeStartupMessage::read_from(&mut self.stream_in)?;
trace!("got startup message {:?}", m);
match m.kind {
StartupRequestCode::NegotiateGss | StartupRequestCode::NegotiateSsl => {
let mut buf = BytesMut::new();
BeMessage::write(&mut buf, &BeMessage::Negotiate);
info!("SSL requested");
self.stream_out.write_all(&buf)?;
}
StartupRequestCode::Normal => {
let mut buf = BytesMut::new();
BeMessage::write(&mut buf, &BeMessage::AuthenticationOk);
BeMessage::write(&mut buf, &BeMessage::ReadyForQuery);
self.stream_out.write_all(&buf)?;
self.timeline.set(m.timelineid)?;
self.appname = m.appname;
}
StartupRequestCode::Cancel => return Ok(()),
fn process_query(&mut self, pgb: &mut PostgresBackend, query_string: Bytes) -> Result<()> {
if query_string.starts_with(b"IDENTIFY_SYSTEM") {
self.handle_identify_system(pgb)?;
Ok(())
} else if query_string.starts_with(b"START_REPLICATION") {
ReplicationConn::new(pgb).run(self, pgb, &query_string)?;
Ok(())
} else {
bail!("Unexpected command {:?}", query_string);
}
}
}
loop {
let msg = FeMessage::read_from(&mut self.stream_in)?;
match msg {
FeMessage::Query(q) => {
trace!("got query {:?}", q.body);
if q.body.starts_with(b"IDENTIFY_SYSTEM") {
self.handle_identify_system()?;
} else if q.body.starts_with(b"START_REPLICATION") {
// Create a new replication object, consuming `self`.
ReplicationConn::new(self).run(&q.body)?;
break;
} else {
bail!("Unexpected command {:?}", q.body);
}
}
FeMessage::Terminate => {
break;
}
_ => {
bail!("unexpected message");
}
}
impl SendWalHandler {
pub fn new(conf: WalAcceptorConf) -> Self {
SendWalHandler {
conf,
appname: None,
timeline: None,
}
info!("WAL sender to {:?} is finished", peer_addr);
Ok(())
}
///
/// Handle IDENTIFY_SYSTEM replication command
///
fn handle_identify_system(&mut self) -> Result<()> {
fn handle_identify_system(&mut self, pgb: &mut PostgresBackend) -> Result<()> {
let (start_pos, timeline) = self.timeline.find_end_of_wal(&self.conf.data_dir, false);
let lsn = start_pos.to_string();
let tli = timeline.to_string();
@@ -114,42 +72,39 @@ impl SendWalConn {
let tli_bytes = tli.as_bytes();
let sysid_bytes = sysid.as_bytes();
let mut outbuf = BytesMut::new();
BeMessage::write(
&mut outbuf,
&BeMessage::RowDescription(&[
RowDescriptor {
name: b"systemid\0",
typoid: 25,
typlen: -1,
},
RowDescriptor {
name: b"timeline\0",
typoid: 23,
typlen: 4,
},
RowDescriptor {
name: b"xlogpos\0",
typoid: 25,
typlen: -1,
},
RowDescriptor {
name: b"dbname\0",
typoid: 25,
typlen: -1,
},
]),
);
BeMessage::write(
&mut outbuf,
&BeMessage::DataRow(&[Some(sysid_bytes), Some(tli_bytes), Some(lsn_bytes), None]),
);
BeMessage::write(
&mut outbuf,
&BeMessage::CommandComplete(b"IDENTIFY_SYSTEM\0"),
);
BeMessage::write(&mut outbuf, &BeMessage::ReadyForQuery);
self.stream_out.write_all(&outbuf)?;
pgb.write_message_noflush(&BeMessage::RowDescription(&[
RowDescriptor {
name: b"systemid",
typoid: 25,
typlen: -1,
..Default::default()
},
RowDescriptor {
name: b"timeline",
typoid: 23,
typlen: 4,
..Default::default()
},
RowDescriptor {
name: b"xlogpos",
typoid: 25,
typlen: -1,
..Default::default()
},
RowDescriptor {
name: b"dbname",
typoid: 25,
typlen: -1,
..Default::default()
},
]))?
.write_message_noflush(&BeMessage::DataRow(&[
Some(sysid_bytes),
Some(tli_bytes),
Some(lsn_bytes),
None,
]))?
.write_message(&BeMessage::CommandComplete(b"IDENTIFY_SYSTEM"))?;
Ok(())
}
}

View File

@@ -9,8 +9,9 @@ use std::net::{TcpListener, TcpStream};
use std::thread;
use crate::receive_wal::ReceiveWalConn;
use crate::send_wal::SendWalConn;
use crate::send_wal::SendWalHandler;
use crate::WalAcceptorConf;
use zenith_utils::postgres_backend::PostgresBackend;
/// Accept incoming TCP connections and spawn them into a background thread.
pub fn thread_main(conf: WalAcceptorConf) -> Result<()> {
@@ -48,7 +49,10 @@ fn handle_socket(mut socket: TcpStream, conf: WalAcceptorConf) -> Result<()> {
socket.read_exact(&mut [0u8; 4])?;
ReceiveWalConn::new(socket, conf)?.run()?; // internal protocol between wal_proposer and wal_acceptor
} else {
SendWalConn::new(socket, conf)?.run()?; // libpq replication protocol between wal_acceptor and replicas/pagers
let mut conn_handler = SendWalHandler::new(conf);
let mut pgbackend = PostgresBackend::new(socket)?;
// libpq replication protocol between wal_acceptor and replicas/pagers
pgbackend.run(&mut conn_handler)?;
}
Ok(())
}

View File

@@ -5,6 +5,10 @@ authors = ["Eric Seppanen <eric@zenith.tech>"]
edition = "2018"
[dependencies]
anyhow = "1.0"
bytes = "1.0.1"
byteorder = "1.4.3"
log = "0.4.14"
serde = { version = "1.0", features = ["derive"] }
bincode = "1.3"
thiserror = "1.0"

View File

@@ -10,3 +10,6 @@ pub mod seqwait;
// pub mod seqwait_async;
pub mod bin_ser;
pub mod postgres_backend;
pub mod pq_proto;

View File

@@ -0,0 +1,181 @@
//! Server-side synchronous Postgres connection, as limited as we need.
//! To use, create PostgresBackend and run() it, passing the Handler
//! implementation determining how to process the queries. Currently its API
//! is rather narrow, but we can extend it once required.
use crate::pq_proto::{BeMessage, FeMessage, FeStartupMessage, StartupRequestCode};
use anyhow::bail;
use anyhow::Result;
use bytes::{Bytes, BytesMut};
use log::*;
use std::io;
use std::io::{BufReader, Write};
use std::net::{Shutdown, TcpStream};
pub trait Handler {
/// Handle single query.
/// postgres_backend will issue ReadyForQuery after calling this (this
/// might be not what we want after CopyData streaming, but currently we don't
/// care).
fn process_query(&mut self, pgb: &mut PostgresBackend, query_string: Bytes) -> Result<()>;
/// Called on startup packet receival, allows to process params.
fn startup(&mut self, _pgb: &mut PostgresBackend, _sm: &FeStartupMessage) -> Result<()> {
Ok(())
}
}
pub struct PostgresBackend {
// replication.rs wants to handle reading on its own in separate thread, so
// wrap in Option to be able to take and transfer the BufReader. Ugly, but I
// have no better ideas.
stream_in: Option<BufReader<TcpStream>>,
stream_out: TcpStream,
// Output buffer. c.f. BeMessage::write why we are using BytesMut here.
buf_out: BytesMut,
init_done: bool,
}
// In replication.rs a separate thread is reading keepalives from the
// socket. When main one finishes, tell it to get down by shutdowning the
// socket.
impl Drop for PostgresBackend {
fn drop(&mut self) {
let _res = self.stream_out.shutdown(Shutdown::Both);
}
}
impl PostgresBackend {
pub fn new(socket: TcpStream) -> Result<Self, std::io::Error> {
let mut pb = PostgresBackend {
stream_in: None,
stream_out: socket,
buf_out: BytesMut::with_capacity(10 * 1024),
init_done: false,
};
// if socket cloning fails, report the error and bail out
pb.stream_in = match pb.stream_out.try_clone() {
Ok(read_sock) => Some(BufReader::new(read_sock)),
Err(error) => {
let errmsg = format!("{}", error);
let _res = pb.write_message_noflush(&BeMessage::ErrorResponse(errmsg));
return Err(error);
}
};
Ok(pb)
}
/// Get direct reference (into the Option) to the read stream.
fn get_stream_in(&mut self) -> Result<&mut BufReader<TcpStream>> {
match self.stream_in {
Some(ref mut stream_in) => Ok(stream_in),
None => bail!("stream_in was taken"),
}
}
pub fn take_stream_in(&mut self) -> Option<BufReader<TcpStream>> {
self.stream_in.take()
}
/// Read full message or return None if connection is closed.
pub fn read_message(&mut self) -> Result<Option<FeMessage>> {
if !self.init_done {
FeStartupMessage::read(self.get_stream_in()?)
} else {
FeMessage::read(self.get_stream_in()?)
}
}
/// Write message into internal output buffer.
pub fn write_message_noflush(&mut self, message: &BeMessage) -> io::Result<&mut Self> {
BeMessage::write(&mut self.buf_out, message)?;
Ok(self)
}
/// Flush output buffer into the socket.
pub fn flush(&mut self) -> io::Result<&mut Self> {
self.stream_out.write_all(&self.buf_out)?;
self.buf_out.clear();
Ok(self)
}
/// Write message into internal buffer and flush it.
pub fn write_message(&mut self, message: &BeMessage) -> io::Result<&mut Self> {
self.write_message_noflush(message)?;
self.flush()
}
pub fn run(&mut self, handler: &mut impl Handler) -> Result<()> {
let peer_addr = self.stream_out.peer_addr()?;
info!("postgres backend to {:?} started", peer_addr);
let mut unnamed_query_string = Bytes::new();
loop {
let msg = self.read_message()?;
trace!("got message {:?}", msg);
match msg {
Some(FeMessage::StartupMessage(m)) => {
trace!("got startup message {:?}", m);
handler.startup(self, &m)?;
match m.kind {
StartupRequestCode::NegotiateGss | StartupRequestCode::NegotiateSsl => {
info!("SSL requested");
self.write_message(&BeMessage::Negotiate)?;
}
StartupRequestCode::Normal => {
self.write_message_noflush(&BeMessage::AuthenticationOk)?;
// psycopg2 will not connect if client_encoding is not
// specified by the server
self.write_message_noflush(&BeMessage::ParameterStatus)?;
self.write_message(&BeMessage::ReadyForQuery)?;
self.init_done = true;
}
StartupRequestCode::Cancel => break,
}
}
Some(FeMessage::Query(m)) => {
trace!("got query {:?}", m.body);
// xxx distinguish fatal and recoverable errors?
if let Err(e) = handler.process_query(self, m.body) {
let errmsg = format!("{}", e);
self.write_message_noflush(&BeMessage::ErrorResponse(errmsg))?;
}
self.write_message(&BeMessage::ReadyForQuery)?;
}
Some(FeMessage::Parse(m)) => {
unnamed_query_string = m.query_string;
self.write_message(&BeMessage::ParseComplete)?;
}
Some(FeMessage::Describe(_)) => {
self.write_message_noflush(&BeMessage::ParameterDescription)?
.write_message(&BeMessage::NoData)?;
}
Some(FeMessage::Bind(_)) => {
self.write_message(&BeMessage::BindComplete)?;
}
Some(FeMessage::Close(_)) => {
self.write_message(&BeMessage::CloseComplete)?;
}
Some(FeMessage::Execute(_)) => {
handler.process_query(self, unnamed_query_string.clone())?;
}
Some(FeMessage::Sync) => {
self.write_message(&BeMessage::ReadyForQuery)?;
}
Some(FeMessage::Terminate) => {
break;
}
None => {
info!("connection closed");
break;
}
x => {
bail!("unexpected message type : {:?}", x);
}
}
}
info!("postgres backend to {:?} exited", peer_addr);
Ok(())
}
}

View File

@@ -0,0 +1,639 @@
//! Postgres protocol messages serialization-deserialization. See
//! https://www.postgresql.org/docs/devel/protocol-message-formats.html
//! on message formats.
use anyhow::{anyhow, bail, Result};
use byteorder::{BigEndian, ByteOrder};
use byteorder::{ReadBytesExt, BE};
use bytes::{Buf, BufMut, Bytes, BytesMut};
// use postgres_ffi::xlog_utils::TimestampTz;
use std::collections::HashMap;
use std::io;
use std::io::Read;
use std::str;
pub type Oid = u32;
pub type SystemId = u64;
#[derive(Debug)]
pub enum FeMessage {
StartupMessage(FeStartupMessage),
Query(FeQueryMessage), // Simple query
Parse(FeParseMessage), // Extended query protocol
Describe(FeDescribeMessage),
Bind(FeBindMessage),
Execute(FeExecuteMessage),
Close(FeCloseMessage),
Sync,
Terminate,
CopyData(Bytes),
CopyDone,
}
#[derive(Debug)]
pub struct FeStartupMessage {
pub version: u32,
pub kind: StartupRequestCode,
// optional params arriving in startup packet
pub params: HashMap<String, String>,
}
#[derive(Debug)]
pub enum StartupRequestCode {
Cancel,
NegotiateSsl,
NegotiateGss,
Normal,
}
#[derive(Debug)]
pub struct FeQueryMessage {
pub body: Bytes,
}
// We only support the simple case of Parse on unnamed prepared statement and
// no params
#[derive(Debug)]
pub struct FeParseMessage {
pub query_string: Bytes,
}
#[derive(Debug)]
pub struct FeDescribeMessage {
pub kind: u8, // 'S' to describe a prepared statement; or 'P' to describe a portal.
// we only support unnamed prepared stmt or portal
}
// we only support unnamed prepared stmt and portal
#[derive(Debug)]
pub struct FeBindMessage {}
// we only support unnamed prepared stmt or portal
#[derive(Debug)]
pub struct FeExecuteMessage {
/// max # of rows
pub maxrows: i32,
}
// we only support unnamed prepared stmt and portal
#[derive(Debug)]
pub struct FeCloseMessage {}
impl FeMessage {
/// Read one message from the stream.
pub fn read(stream: &mut impl Read) -> anyhow::Result<Option<FeMessage>> {
// Each libpq message begins with a message type byte, followed by message length
// If the client closes the connection, return None. But if the client closes the
// connection in the middle of a message, we will return an error.
let tag = match stream.read_u8() {
Ok(b) => b,
Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
Err(e) => return Err(e.into()),
};
let len = stream.read_u32::<BE>()?;
// The message length includes itself, so it better be at least 4
let bodylen = len.checked_sub(4).ok_or(io::Error::new(
io::ErrorKind::InvalidInput,
"invalid message length: parsing u32",
))?;
// Read message body
let mut body_buf: Vec<u8> = vec![0; bodylen as usize];
stream.read_exact(&mut body_buf)?;
let body = Bytes::from(body_buf);
// Parse it
match tag {
b'Q' => Ok(Some(FeMessage::Query(FeQueryMessage { body }))),
b'P' => Ok(Some(FeParseMessage::parse(body)?)),
b'D' => Ok(Some(FeDescribeMessage::parse(body)?)),
b'E' => Ok(Some(FeExecuteMessage::parse(body)?)),
b'B' => Ok(Some(FeBindMessage::parse(body)?)),
b'C' => Ok(Some(FeCloseMessage::parse(body)?)),
b'S' => Ok(Some(FeMessage::Sync)),
b'X' => Ok(Some(FeMessage::Terminate)),
b'd' => Ok(Some(FeMessage::CopyData(body))),
b'c' => Ok(Some(FeMessage::CopyDone)),
tag => Err(anyhow!("unknown message tag: {},'{:?}'", tag, body)),
}
}
}
impl FeStartupMessage {
/// Read startup message from the stream.
pub fn read(stream: &mut impl std::io::Read) -> anyhow::Result<Option<FeMessage>> {
const MAX_STARTUP_PACKET_LENGTH: usize = 10000;
const CANCEL_REQUEST_CODE: u32 = (1234 << 16) | 5678;
const NEGOTIATE_SSL_CODE: u32 = (1234 << 16) | 5679;
const NEGOTIATE_GSS_CODE: u32 = (1234 << 16) | 5680;
// Read length. If the connection is closed before reading anything (or before
// reading 4 bytes, to be precise), return None to indicate that the connection
// was closed. This matches the PostgreSQL server's behavior, which avoids noise
// in the log if the client opens connection but closes it immediately.
let len = match stream.read_u32::<BE>() {
Ok(len) => len as usize,
Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
Err(e) => return Err(e.into()),
};
if len < 4 || len > MAX_STARTUP_PACKET_LENGTH {
bail!("invalid message length");
}
let version = stream.read_u32::<BE>()?;
let kind = match version {
CANCEL_REQUEST_CODE => StartupRequestCode::Cancel,
NEGOTIATE_SSL_CODE => StartupRequestCode::NegotiateSsl,
NEGOTIATE_GSS_CODE => StartupRequestCode::NegotiateGss,
_ => StartupRequestCode::Normal,
};
// the rest of startup packet are params
let params_len = len - 8;
let mut params_bytes = vec![0u8; params_len];
stream.read_exact(params_bytes.as_mut())?;
// Then null-terminated (String) pairs of param name / param value go.
let params_str = str::from_utf8(&params_bytes).unwrap();
let params = params_str.split('\0');
let mut params_hash: HashMap<String, String> = HashMap::new();
for pair in params.collect::<Vec<_>>().chunks_exact(2) {
let name = pair[0];
let value = pair[1];
if name == "options" {
// deprecated way of passing params as cmd line args
for cmdopt in value.split(' ') {
let nameval: Vec<&str> = cmdopt.split('=').collect();
if nameval.len() == 2 {
params_hash.insert(nameval[0].to_string(), nameval[1].to_string());
}
}
} else {
params_hash.insert(name.to_string(), value.to_string());
}
}
Ok(Some(FeMessage::StartupMessage(FeStartupMessage {
version,
kind,
params: params_hash,
})))
}
}
impl FeParseMessage {
pub fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
let _pstmt_name = read_null_terminated(&mut buf)?;
let query_string = read_null_terminated(&mut buf)?;
let nparams = buf.get_i16();
// FIXME: the rust-postgres driver uses a named prepared statement
// for copy_out(). We're not prepared to handle that correctly. For
// now, just ignore the statement name, assuming that the client never
// uses more than one prepared statement at a time.
/*
if !pstmt_name.is_empty() {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"named prepared statements not implemented in Parse",
));
}
*/
if nparams != 0 {
bail!("query params not implemented");
}
Ok(FeMessage::Parse(FeParseMessage { query_string }))
}
}
impl FeDescribeMessage {
pub fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
let kind = buf.get_u8();
let _pstmt_name = read_null_terminated(&mut buf)?;
// FIXME: see FeParseMessage::parse
/*
if !pstmt_name.is_empty() {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"named prepared statements not implemented in Describe",
));
}
*/
if kind != b'S' {
bail!("only prepared statmement Describe is implemented");
}
Ok(FeMessage::Describe(FeDescribeMessage { kind }))
}
}
impl FeExecuteMessage {
pub fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
let portal_name = read_null_terminated(&mut buf)?;
let maxrows = buf.get_i32();
if !portal_name.is_empty() {
bail!("named portals not implemented");
}
if maxrows != 0 {
bail!("row limit in Execute message not supported");
}
Ok(FeMessage::Execute(FeExecuteMessage { maxrows }))
}
}
impl FeBindMessage {
pub fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
let portal_name = read_null_terminated(&mut buf)?;
let _pstmt_name = read_null_terminated(&mut buf)?;
if !portal_name.is_empty() {
bail!("named portals not implemented");
}
// FIXME: see FeParseMessage::parse
/*
if !pstmt_name.is_empty() {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"named prepared statements not implemented",
));
}
*/
Ok(FeMessage::Bind(FeBindMessage {}))
}
}
impl FeCloseMessage {
pub fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
let _kind = buf.get_u8();
let _pstmt_or_portal_name = read_null_terminated(&mut buf)?;
// FIXME: we do nothing with Close
Ok(FeMessage::Close(FeCloseMessage {}))
}
}
fn read_null_terminated(buf: &mut Bytes) -> anyhow::Result<Bytes> {
let mut result = BytesMut::new();
loop {
if !buf.has_remaining() {
bail!("no null-terminator in string");
}
let byte = buf.get_u8();
if byte == 0 {
break;
}
result.put_u8(byte);
}
Ok(result.freeze())
}
// Backend
#[derive(Debug)]
pub enum BeMessage<'a> {
AuthenticationOk,
BindComplete,
CommandComplete(&'a [u8]),
ControlFile,
CopyData(&'a [u8]),
CopyDone,
CopyInResponse,
CopyOutResponse,
CopyBothResponse,
CloseComplete,
// None means column is NULL
DataRow(&'a [Option<&'a [u8]>]),
ErrorResponse(String),
// see https://www.postgresql.org/docs/devel/protocol-flow.html#id-1.10.5.7.11
Negotiate,
NoData,
ParameterDescription,
ParameterStatus,
ParseComplete,
ReadyForQuery,
RowDescription(&'a [RowDescriptor<'a>]),
XLogData(XLogDataBody<'a>),
}
// One row desciption in RowDescription packet.
#[derive(Debug)]
pub struct RowDescriptor<'a> {
pub name: &'a [u8],
pub tableoid: Oid,
pub attnum: i16,
pub typoid: Oid,
pub typlen: i16,
pub typmod: i32,
pub formatcode: i16,
}
impl Default for RowDescriptor<'_> {
fn default() -> RowDescriptor<'static> {
RowDescriptor {
name: b"",
tableoid: 0,
attnum: 0,
typoid: 0,
typlen: 0,
typmod: 0,
formatcode: 0,
}
}
}
#[derive(Debug)]
pub struct XLogDataBody<'a> {
pub wal_start: u64,
pub wal_end: u64,
pub timestamp: u64,
pub data: &'a [u8],
}
pub static HELLO_WORLD_ROW: BeMessage = BeMessage::DataRow(&[Some(b"hello world")]);
pub const TEXT_OID: Oid = 25;
// single text column
pub static SINGLE_COL_ROWDESC: BeMessage = BeMessage::RowDescription(&[RowDescriptor {
name: b"data",
tableoid: 0,
attnum: 0,
typoid: TEXT_OID,
typlen: -1,
typmod: 0,
formatcode: 0,
}]);
// Safe usize -> i32|i16 conversion, from rust-postgres
trait FromUsize: Sized {
fn from_usize(x: usize) -> Result<Self, io::Error>;
}
macro_rules! from_usize {
($t:ty) => {
impl FromUsize for $t {
#[inline]
fn from_usize(x: usize) -> io::Result<$t> {
if x > <$t>::max_value() as usize {
Err(io::Error::new(
io::ErrorKind::InvalidInput,
"value too large to transmit",
))
} else {
Ok(x as $t)
}
}
}
};
}
from_usize!(i32);
/// Call f() to write body of the message and prepend it with 4-byte len as
/// prescribed by the protocol.
fn write_body<F>(buf: &mut BytesMut, f: F) -> io::Result<()>
where
F: FnOnce(&mut BytesMut) -> io::Result<()>,
{
let base = buf.len();
buf.extend_from_slice(&[0; 4]);
f(buf)?;
let size = i32::from_usize(buf.len() - base)?;
BigEndian::write_i32(&mut buf[base..], size);
Ok(())
}
/// Safe write of s into buf as cstring (String in the protocol).
fn write_cstr(s: &[u8], buf: &mut BytesMut) -> Result<(), io::Error> {
if s.contains(&0) {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"string contains embedded null",
));
}
buf.put_slice(s);
buf.put_u8(0);
Ok(())
}
impl<'a> BeMessage<'a> {
/// Write message to the given buf.
// Unlike the reading side, we use BytesMut
// here as msg len preceeds its body and it is handy to write it down first
// and then fill the length. With Write we would have to either calc it
// manually or have one more buffer.
pub fn write(buf: &mut BytesMut, message: &BeMessage) -> io::Result<()> {
match message {
BeMessage::AuthenticationOk => {
buf.put_u8(b'R');
write_body(buf, |buf| {
buf.put_i32(0);
Ok::<_, io::Error>(())
})
.unwrap(); // write into BytesMut can't fail
}
BeMessage::BindComplete => {
buf.put_u8(b'2');
write_body(buf, |_| Ok::<(), io::Error>(())).unwrap();
}
BeMessage::CloseComplete => {
buf.put_u8(b'3');
write_body(buf, |_| Ok::<(), io::Error>(())).unwrap();
}
BeMessage::CommandComplete(cmd) => {
buf.put_u8(b'C');
write_body(buf, |buf| {
write_cstr(cmd, buf)?;
Ok::<_, io::Error>(())
})?;
}
BeMessage::ControlFile => {
// TODO pass checkpoint and xid info in this message
BeMessage::write(buf, &BeMessage::DataRow(&[Some(b"hello pg_control")]))?;
}
BeMessage::CopyData(data) => {
buf.put_u8(b'd');
write_body(buf, |buf| {
buf.put_slice(data);
Ok::<_, io::Error>(())
})
.unwrap();
}
BeMessage::CopyDone => {
buf.put_u8(b'c');
write_body(buf, |_| Ok::<(), io::Error>(())).unwrap();
}
BeMessage::CopyInResponse => {
buf.put_u8(b'G');
write_body(buf, |buf| {
buf.put_u8(1); /* copy_is_binary */
buf.put_i16(0); /* numAttributes */
Ok::<_, io::Error>(())
})
.unwrap();
}
BeMessage::CopyOutResponse => {
buf.put_u8(b'H');
write_body(buf, |buf| {
buf.put_u8(0); /* copy_is_binary */
buf.put_i16(0); /* numAttributes */
Ok::<_, io::Error>(())
})
.unwrap();
}
BeMessage::CopyBothResponse => {
buf.put_u8(b'W');
write_body(buf, |buf| {
// doesn't matter, used only for replication
buf.put_u8(0); /* copy_is_binary */
buf.put_i16(0); /* numAttributes */
Ok::<_, io::Error>(())
})
.unwrap();
}
BeMessage::DataRow(vals) => {
buf.put_u8(b'D');
write_body(buf, |buf| {
buf.put_u16(vals.len() as u16); // num of cols
for val_opt in vals.iter() {
if let Some(val) = val_opt {
buf.put_u32(val.len() as u32);
buf.put_slice(val);
} else {
buf.put_i32(-1);
}
}
Ok::<_, io::Error>(())
})
.unwrap();
}
// ErrorResponse is a zero-terminated array of zero-terminated fields.
// First byte of each field represents type of this field. Set just enough fields
// to satisfy rust-postgres client: 'S' -- severity, 'C' -- error, 'M' -- error
// message text.
BeMessage::ErrorResponse(error_msg) => {
// For all the errors set Severity to Error and error code to
// 'internal error'.
// 'E' signalizes ErrorResponse messages
buf.put_u8(b'E');
write_body(buf, |buf| {
buf.put_u8(b'S'); // severity
write_cstr(&Bytes::from("ERROR"), buf)?;
buf.put_u8(b'C'); // SQLSTATE error code
write_cstr(&Bytes::from("CXX000"), buf)?;
buf.put_u8(b'M'); // the message
write_cstr(error_msg.as_bytes(), buf)?;
buf.put_u8(0); // terminator
Ok::<_, io::Error>(())
})
.unwrap();
}
BeMessage::NoData => {
buf.put_u8(b'n');
write_body(buf, |_| Ok::<(), io::Error>(())).unwrap();
}
BeMessage::Negotiate => {
buf.put_u8(b'N');
}
BeMessage::ParameterStatus => {
buf.put_u8(b'S');
// parameter names and values are specified by null terminated strings
const PARAM_NAME_VALUE: &[u8] = b"client_encoding\0UTF8\0";
write_body(buf, |buf| {
buf.put_slice(PARAM_NAME_VALUE);
Ok::<_, io::Error>(())
})
.unwrap();
}
BeMessage::ParameterDescription => {
buf.put_u8(b't');
write_body(buf, |buf| {
// we don't support params, so always 0
buf.put_i16(0);
Ok::<_, io::Error>(())
})
.unwrap();
}
BeMessage::ParseComplete => {
buf.put_u8(b'1');
write_body(buf, |_| Ok::<(), io::Error>(())).unwrap();
}
BeMessage::ReadyForQuery => {
buf.put_u8(b'Z');
write_body(buf, |buf| {
buf.put_u8(b'I');
Ok::<_, io::Error>(())
})
.unwrap();
}
BeMessage::RowDescription(rows) => {
buf.put_u8(b'T');
write_body(buf, |buf| {
buf.put_i16(rows.len() as i16); // # of fields
for row in rows.iter() {
write_cstr(row.name, buf)?;
buf.put_i32(0); /* table oid */
buf.put_i16(0); /* attnum */
buf.put_u32(row.typoid);
buf.put_i16(row.typlen);
buf.put_i32(-1); /* typmod */
buf.put_i16(0); /* format code */
}
Ok::<_, io::Error>(())
})?;
}
BeMessage::XLogData(body) => {
buf.put_u8(b'd');
write_body(buf, |buf| {
buf.put_u8(b'w');
buf.put_u64(body.wal_start);
buf.put_u64(body.wal_end);
buf.put_u64(body.timestamp);
buf.put_slice(body.data);
Ok::<_, io::Error>(())
})
.unwrap();
}
}
Ok(())
}
}