Consolidate PG proto parsing-deparsing and backend code.

Now postgres_backend communicates with the client, passing queries to the
provided handler; we have two currently, for wal_acceptor and pageserver.

Now BytesMut is again used for writing data to avoid manual message length
calculation.

ref #118
This commit is contained in:
Arseny Sher
2021-05-31 09:26:26 +03:00
committed by arssher
parent 2b0193e6bf
commit b2f51026aa
9 changed files with 1188 additions and 971 deletions

3
Cargo.lock generated
View File

@@ -2478,9 +2478,12 @@ dependencies = [
name = "zenith_utils" name = "zenith_utils"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"anyhow",
"bincode", "bincode",
"byteorder",
"bytes", "bytes",
"hex-literal", "hex-literal",
"log",
"serde", "serde",
"thiserror", "thiserror",
"workspace_hack", "workspace_hack",

File diff suppressed because it is too large Load Diff

View File

@@ -1,20 +1,18 @@
//! This module implements the replication protocol, starting with the //! This module implements the streaming side of replication protocol, starting
//! "START REPLICATION" message. //! with the "START REPLICATION" message.
use crate::pq_protocol::{BeMessage, FeMessage}; use crate::send_wal::SendWalHandler;
use crate::send_wal::SendWalConn;
use crate::timeline::{Timeline, TimelineTools}; use crate::timeline::{Timeline, TimelineTools};
use crate::WalAcceptorConf;
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use bytes::{BufMut, Bytes, BytesMut}; use bytes::Bytes;
use log::*; use log::*;
use postgres_ffi::xlog_utils::{get_current_timestamp, TimestampTz, XLogFileName, MAX_SEND_SIZE}; use postgres_ffi::xlog_utils::{get_current_timestamp, TimestampTz, XLogFileName, MAX_SEND_SIZE};
use regex::Regex; use regex::Regex;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::cmp::min; use std::cmp::min;
use std::fs::File; use std::fs::File;
use std::io::{BufReader, Read, Seek, SeekFrom, Write}; use std::io::{BufReader, Read, Seek, SeekFrom};
use std::net::{Shutdown, TcpStream}; use std::net::TcpStream;
use std::path::Path; use std::path::Path;
use std::sync::Arc; use std::sync::Arc;
use std::thread::sleep; use std::thread::sleep;
@@ -22,10 +20,9 @@ use std::time::Duration;
use std::{str, thread}; use std::{str, thread};
use zenith_utils::bin_ser::BeSer; use zenith_utils::bin_ser::BeSer;
use zenith_utils::lsn::Lsn; use zenith_utils::lsn::Lsn;
use zenith_utils::postgres_backend::PostgresBackend;
use zenith_utils::pq_proto::{BeMessage, FeMessage, XLogDataBody};
const XLOG_HDR_SIZE: usize = 1 + 8 * 3; /* 'w' + startPos + walEnd + timestamp */
const LIBPQ_HDR_SIZE: usize = 5; /* 1 byte with message type + 4 bytes length */
const LIBPQ_MSG_SIZE_OFFS: usize = 1;
pub const END_REPLICATION_MARKER: Lsn = Lsn::MAX; pub const END_REPLICATION_MARKER: Lsn = Lsn::MAX;
type FullTransactionId = u64; type FullTransactionId = u64;
@@ -40,37 +37,16 @@ pub struct HotStandbyFeedback {
/// A network connection that's speaking the replication protocol. /// A network connection that's speaking the replication protocol.
pub struct ReplicationConn { pub struct ReplicationConn {
timeline: Option<Arc<Timeline>>,
/// Postgres connection, buffered input
///
/// This is an `Option` because we will spawn a background thread that will /// This is an `Option` because we will spawn a background thread that will
/// `take` it from us. /// `take` it from us.
stream_in: Option<BufReader<TcpStream>>, stream_in: Option<BufReader<TcpStream>>,
/// Postgres connection, output
stream_out: TcpStream,
/// wal acceptor configuration
conf: WalAcceptorConf,
/// assigned application name
appname: Option<String>,
}
// Separate thread is reading keepalives from the socket. When main one
// finishes, tell it to get down by shutdowning the socket.
impl Drop for ReplicationConn {
fn drop(&mut self) {
let _res = self.stream_out.shutdown(Shutdown::Both);
}
} }
impl ReplicationConn { impl ReplicationConn {
/// Create a new `SendWal`, consuming the `Connection`. /// Create a new `ReplicationConn`
pub fn new(conn: SendWalConn) -> Self { pub fn new(pgb: &mut PostgresBackend) -> Self {
Self { Self {
timeline: conn.timeline, stream_in: pgb.take_stream_in(),
stream_in: Some(conn.stream_in),
stream_out: conn.stream_out,
conf: conn.conf,
appname: None,
} }
} }
@@ -82,9 +58,9 @@ impl ReplicationConn {
// Wait for replica's feedback. // Wait for replica's feedback.
// We only handle `CopyData` messages. Anything else is ignored. // We only handle `CopyData` messages. Anything else is ignored.
loop { loop {
match FeMessage::read_from(&mut stream_in)? { match FeMessage::read(&mut stream_in)? {
FeMessage::CopyData(m) => { Some(FeMessage::CopyData(m)) => {
let feedback = HotStandbyFeedback::des(&m.body)?; let feedback = HotStandbyFeedback::des(&m)?;
timeline.add_hs_feedback(feedback) timeline.add_hs_feedback(feedback)
} }
msg => { msg => {
@@ -128,9 +104,14 @@ impl ReplicationConn {
/// ///
/// Handle START_REPLICATION replication command /// Handle START_REPLICATION replication command
/// ///
pub fn run(&mut self, cmd: &Bytes) -> Result<()> { pub fn run(
&mut self,
swh: &mut SendWalHandler,
pgb: &mut PostgresBackend,
cmd: &Bytes,
) -> Result<()> {
// spawn the background thread which receives HotStandbyFeedback messages. // spawn the background thread which receives HotStandbyFeedback messages.
let bg_timeline = Arc::clone(self.timeline.get()); let bg_timeline = Arc::clone(swh.timeline.get());
let bg_stream_in = self.stream_in.take().unwrap(); let bg_stream_in = self.stream_in.take().unwrap();
thread::spawn(move || { thread::spawn(move || {
@@ -143,7 +124,7 @@ impl ReplicationConn {
let mut wal_seg_size: usize; let mut wal_seg_size: usize;
loop { loop {
wal_seg_size = self.timeline.get().get_info().server.wal_seg_size as usize; wal_seg_size = swh.timeline.get().get_info().server.wal_seg_size as usize;
if wal_seg_size == 0 { if wal_seg_size == 0 {
error!("Can not start replication before connecting to wal_proposer"); error!("Can not start replication before connecting to wal_proposer");
sleep(Duration::from_secs(1)); sleep(Duration::from_secs(1));
@@ -151,19 +132,17 @@ impl ReplicationConn {
break; break;
} }
} }
let (wal_end, timeline) = self.timeline.find_end_of_wal(&self.conf.data_dir, false); let (wal_end, timeline) = swh.timeline.find_end_of_wal(&swh.conf.data_dir, false);
if start_pos == Lsn(0) { if start_pos == Lsn(0) {
start_pos = wal_end; start_pos = wal_end;
} }
if stop_pos == Lsn(0) && self.appname == Some("wal_proposer_recovery".to_string()) { if stop_pos == Lsn(0) && swh.appname == Some("wal_proposer_recovery".to_string()) {
stop_pos = wal_end; stop_pos = wal_end;
} }
info!("Start replication from {} till {}", start_pos, stop_pos); info!("Start replication from {} till {}", start_pos, stop_pos);
let mut outbuf = BytesMut::new(); // switch to copy
BeMessage::write(&mut outbuf, &BeMessage::Copy); pgb.write_message(&BeMessage::CopyBothResponse)?;
self.send(&outbuf)?;
outbuf.clear();
let mut end_pos: Lsn; let mut end_pos: Lsn;
let mut wal_file: Option<File> = None; let mut wal_file: Option<File> = None;
@@ -179,7 +158,7 @@ impl ReplicationConn {
end_pos = stop_pos; end_pos = stop_pos;
} else { } else {
/* normal mode */ /* normal mode */
let timeline = self.timeline.get(); let timeline = swh.timeline.get();
end_pos = timeline.wait_for_lsn(start_pos); end_pos = timeline.wait_for_lsn(start_pos);
} }
if end_pos == END_REPLICATION_MARKER { if end_pos == END_REPLICATION_MARKER {
@@ -193,8 +172,8 @@ impl ReplicationConn {
// Open a new file. // Open a new file.
let segno = start_pos.segment_number(wal_seg_size); let segno = start_pos.segment_number(wal_seg_size);
let wal_file_name = XLogFileName(timeline, segno, wal_seg_size); let wal_file_name = XLogFileName(timeline, segno, wal_seg_size);
let timeline_id = self.timeline.get().timelineid.to_string(); let timeline_id = swh.timeline.get().timelineid.to_string();
let wal_file_path = self.conf.data_dir.join(timeline_id).join(wal_file_name); let wal_file_path = swh.conf.data_dir.join(timeline_id).join(wal_file_name);
Self::open_wal_file(&wal_file_path)? Self::open_wal_file(&wal_file_path)?
} }
}; };
@@ -207,32 +186,19 @@ impl ReplicationConn {
let send_size = min(send_size, wal_seg_size - xlogoff); let send_size = min(send_size, wal_seg_size - xlogoff);
let send_size = min(send_size, MAX_SEND_SIZE); let send_size = min(send_size, MAX_SEND_SIZE);
let msg_size = LIBPQ_HDR_SIZE + XLOG_HDR_SIZE + send_size;
// Read some data from the file. // Read some data from the file.
let mut file_buf = vec![0u8; send_size]; let mut file_buf = vec![0u8; send_size];
file.seek(SeekFrom::Start(xlogoff as u64))?; file.seek(SeekFrom::Start(xlogoff as u64))?;
file.read_exact(&mut file_buf)?; file.read_exact(&mut file_buf)?;
// Write some data to the network socket. // Write some data to the network socket.
// FIXME: turn these into structs. pgb.write_message(&BeMessage::XLogData(XLogDataBody {
// 'd' is CopyData; wal_start: start_pos.0,
// 'w' is "WAL records" wal_end: end_pos.0,
// https://www.postgresql.org/docs/9.1/protocol-message-formats.html timestamp: get_current_timestamp(),
// src/backend/replication/walreceiver.c data: &file_buf,
outbuf.clear(); }))?;
outbuf.put_u8(b'd');
outbuf.put_u32((msg_size - LIBPQ_MSG_SIZE_OFFS) as u32);
outbuf.put_u8(b'w');
outbuf.put_u64(start_pos.0);
outbuf.put_u64(end_pos.0);
outbuf.put_u64(get_current_timestamp());
assert!(outbuf.len() + file_buf.len() == msg_size);
// This thread has exclusive access to the TcpStream, so it's fine
// to do this as two separate calls.
self.send(&outbuf)?;
self.send(&file_buf)?;
start_pos += send_size as u64; start_pos += send_size as u64;
debug!("Sent WAL to page server up to {}", end_pos); debug!("Sent WAL to page server up to {}", end_pos);
@@ -245,10 +211,4 @@ impl ReplicationConn {
} }
Ok(()) Ok(())
} }
/// Send messages on the network.
fn send(&mut self, buf: &[u8]) -> Result<()> {
self.stream_out.write_all(buf.as_ref())?;
Ok(())
}
} }

View File

@@ -1,111 +1,69 @@
//! This implements the libpq replication protocol between wal_acceptor //! Part of WAL acceptor pretending to be Postgres, streaming xlog to
//! and replicas/pagers //! pageserver/any other consumer.
//! //!
use crate::pq_protocol::{
BeMessage, FeMessage, FeStartupMessage, RowDescriptor, StartupRequestCode,
};
use crate::replication::ReplicationConn; use crate::replication::ReplicationConn;
use crate::timeline::{Timeline, TimelineTools}; use crate::timeline::{Timeline, TimelineTools};
use crate::WalAcceptorConf; use crate::WalAcceptorConf;
use anyhow::{bail, Result}; use anyhow::{bail, Result};
use bytes::BytesMut; use bytes::Bytes;
use log::*; use pageserver::ZTimelineId;
use std::io::{BufReader, Write}; use std::str::FromStr;
use std::net::{SocketAddr, TcpStream};
use std::sync::Arc; use std::sync::Arc;
use zenith_utils::postgres_backend;
use zenith_utils::postgres_backend::PostgresBackend;
use zenith_utils::pq_proto::{BeMessage, FeStartupMessage, RowDescriptor};
/// A network connection that's speaking the libpq replication protocol. /// Handler for streaming WAL from acceptor
pub struct SendWalConn { pub struct SendWalHandler {
pub timeline: Option<Arc<Timeline>>,
/// Postgres connection, buffered input
pub stream_in: BufReader<TcpStream>,
/// Postgres connection, output
pub stream_out: TcpStream,
/// The cached result of socket.peer_addr()
pub peer_addr: SocketAddr,
/// wal acceptor configuration /// wal acceptor configuration
pub conf: WalAcceptorConf, pub conf: WalAcceptorConf,
/// assigned application name /// assigned application name
appname: Option<String>, pub appname: Option<String>,
pub timeline: Option<Arc<Timeline>>,
} }
impl SendWalConn { impl postgres_backend::Handler for SendWalHandler {
/// Create a new `SendWal`, consuming the `Connection`. fn startup(&mut self, _pgb: &mut PostgresBackend, sm: &FeStartupMessage) -> Result<()> {
pub fn new(socket: TcpStream, conf: WalAcceptorConf) -> Result<Self> { match sm.params.get("ztimelineid") {
let peer_addr = socket.peer_addr()?; Some(ref ztimelineid) => {
let conn = SendWalConn { let ztlid = ZTimelineId::from_str(ztimelineid)?;
timeline: None, self.timeline.set(ztlid)?;
stream_in: BufReader::new(socket.try_clone()?), }
stream_out: socket, _ => bail!("timelineid is required"),
peer_addr, }
conf, if let Some(app_name) = sm.params.get("application_name") {
appname: None, self.appname = Some(app_name.clone());
}; }
Ok(conn) Ok(())
} }
/// fn process_query(&mut self, pgb: &mut PostgresBackend, query_string: Bytes) -> Result<()> {
/// Send WAL to replica or WAL receiver using standard libpq replication protocol if query_string.starts_with(b"IDENTIFY_SYSTEM") {
/// self.handle_identify_system(pgb)?;
pub fn run(mut self) -> Result<()> { Ok(())
let peer_addr = self.peer_addr; } else if query_string.starts_with(b"START_REPLICATION") {
info!("WAL sender to {:?} is started", peer_addr); ReplicationConn::new(pgb).run(self, pgb, &query_string)?;
Ok(())
// Handle the startup message first. } else {
bail!("Unexpected command {:?}", query_string);
let m = FeStartupMessage::read_from(&mut self.stream_in)?;
trace!("got startup message {:?}", m);
match m.kind {
StartupRequestCode::NegotiateGss | StartupRequestCode::NegotiateSsl => {
let mut buf = BytesMut::new();
BeMessage::write(&mut buf, &BeMessage::Negotiate);
info!("SSL requested");
self.stream_out.write_all(&buf)?;
}
StartupRequestCode::Normal => {
let mut buf = BytesMut::new();
BeMessage::write(&mut buf, &BeMessage::AuthenticationOk);
BeMessage::write(&mut buf, &BeMessage::ReadyForQuery);
self.stream_out.write_all(&buf)?;
self.timeline.set(m.timelineid)?;
self.appname = m.appname;
}
StartupRequestCode::Cancel => return Ok(()),
} }
}
}
loop { impl SendWalHandler {
let msg = FeMessage::read_from(&mut self.stream_in)?; pub fn new(conf: WalAcceptorConf) -> Self {
match msg { SendWalHandler {
FeMessage::Query(q) => { conf,
trace!("got query {:?}", q.body); appname: None,
timeline: None,
if q.body.starts_with(b"IDENTIFY_SYSTEM") {
self.handle_identify_system()?;
} else if q.body.starts_with(b"START_REPLICATION") {
// Create a new replication object, consuming `self`.
ReplicationConn::new(self).run(&q.body)?;
break;
} else {
bail!("Unexpected command {:?}", q.body);
}
}
FeMessage::Terminate => {
break;
}
_ => {
bail!("unexpected message");
}
}
} }
info!("WAL sender to {:?} is finished", peer_addr);
Ok(())
} }
/// ///
/// Handle IDENTIFY_SYSTEM replication command /// Handle IDENTIFY_SYSTEM replication command
/// ///
fn handle_identify_system(&mut self) -> Result<()> { fn handle_identify_system(&mut self, pgb: &mut PostgresBackend) -> Result<()> {
let (start_pos, timeline) = self.timeline.find_end_of_wal(&self.conf.data_dir, false); let (start_pos, timeline) = self.timeline.find_end_of_wal(&self.conf.data_dir, false);
let lsn = start_pos.to_string(); let lsn = start_pos.to_string();
let tli = timeline.to_string(); let tli = timeline.to_string();
@@ -114,42 +72,39 @@ impl SendWalConn {
let tli_bytes = tli.as_bytes(); let tli_bytes = tli.as_bytes();
let sysid_bytes = sysid.as_bytes(); let sysid_bytes = sysid.as_bytes();
let mut outbuf = BytesMut::new(); pgb.write_message_noflush(&BeMessage::RowDescription(&[
BeMessage::write( RowDescriptor {
&mut outbuf, name: b"systemid",
&BeMessage::RowDescription(&[ typoid: 25,
RowDescriptor { typlen: -1,
name: b"systemid\0", ..Default::default()
typoid: 25, },
typlen: -1, RowDescriptor {
}, name: b"timeline",
RowDescriptor { typoid: 23,
name: b"timeline\0", typlen: 4,
typoid: 23, ..Default::default()
typlen: 4, },
}, RowDescriptor {
RowDescriptor { name: b"xlogpos",
name: b"xlogpos\0", typoid: 25,
typoid: 25, typlen: -1,
typlen: -1, ..Default::default()
}, },
RowDescriptor { RowDescriptor {
name: b"dbname\0", name: b"dbname",
typoid: 25, typoid: 25,
typlen: -1, typlen: -1,
}, ..Default::default()
]), },
); ]))?
BeMessage::write( .write_message_noflush(&BeMessage::DataRow(&[
&mut outbuf, Some(sysid_bytes),
&BeMessage::DataRow(&[Some(sysid_bytes), Some(tli_bytes), Some(lsn_bytes), None]), Some(tli_bytes),
); Some(lsn_bytes),
BeMessage::write( None,
&mut outbuf, ]))?
&BeMessage::CommandComplete(b"IDENTIFY_SYSTEM\0"), .write_message(&BeMessage::CommandComplete(b"IDENTIFY_SYSTEM"))?;
);
BeMessage::write(&mut outbuf, &BeMessage::ReadyForQuery);
self.stream_out.write_all(&outbuf)?;
Ok(()) Ok(())
} }
} }

View File

@@ -9,8 +9,9 @@ use std::net::{TcpListener, TcpStream};
use std::thread; use std::thread;
use crate::receive_wal::ReceiveWalConn; use crate::receive_wal::ReceiveWalConn;
use crate::send_wal::SendWalConn; use crate::send_wal::SendWalHandler;
use crate::WalAcceptorConf; use crate::WalAcceptorConf;
use zenith_utils::postgres_backend::PostgresBackend;
/// Accept incoming TCP connections and spawn them into a background thread. /// Accept incoming TCP connections and spawn them into a background thread.
pub fn thread_main(conf: WalAcceptorConf) -> Result<()> { pub fn thread_main(conf: WalAcceptorConf) -> Result<()> {
@@ -48,7 +49,10 @@ fn handle_socket(mut socket: TcpStream, conf: WalAcceptorConf) -> Result<()> {
socket.read_exact(&mut [0u8; 4])?; socket.read_exact(&mut [0u8; 4])?;
ReceiveWalConn::new(socket, conf)?.run()?; // internal protocol between wal_proposer and wal_acceptor ReceiveWalConn::new(socket, conf)?.run()?; // internal protocol between wal_proposer and wal_acceptor
} else { } else {
SendWalConn::new(socket, conf)?.run()?; // libpq replication protocol between wal_acceptor and replicas/pagers let mut conn_handler = SendWalHandler::new(conf);
let mut pgbackend = PostgresBackend::new(socket)?;
// libpq replication protocol between wal_acceptor and replicas/pagers
pgbackend.run(&mut conn_handler)?;
} }
Ok(()) Ok(())
} }

View File

@@ -5,6 +5,10 @@ authors = ["Eric Seppanen <eric@zenith.tech>"]
edition = "2018" edition = "2018"
[dependencies] [dependencies]
anyhow = "1.0"
bytes = "1.0.1"
byteorder = "1.4.3"
log = "0.4.14"
serde = { version = "1.0", features = ["derive"] } serde = { version = "1.0", features = ["derive"] }
bincode = "1.3" bincode = "1.3"
thiserror = "1.0" thiserror = "1.0"

View File

@@ -10,3 +10,6 @@ pub mod seqwait;
// pub mod seqwait_async; // pub mod seqwait_async;
pub mod bin_ser; pub mod bin_ser;
pub mod postgres_backend;
pub mod pq_proto;

View File

@@ -0,0 +1,181 @@
//! Server-side synchronous Postgres connection, as limited as we need.
//! To use, create PostgresBackend and run() it, passing the Handler
//! implementation determining how to process the queries. Currently its API
//! is rather narrow, but we can extend it once required.
use crate::pq_proto::{BeMessage, FeMessage, FeStartupMessage, StartupRequestCode};
use anyhow::bail;
use anyhow::Result;
use bytes::{Bytes, BytesMut};
use log::*;
use std::io;
use std::io::{BufReader, Write};
use std::net::{Shutdown, TcpStream};
pub trait Handler {
/// Handle single query.
/// postgres_backend will issue ReadyForQuery after calling this (this
/// might be not what we want after CopyData streaming, but currently we don't
/// care).
fn process_query(&mut self, pgb: &mut PostgresBackend, query_string: Bytes) -> Result<()>;
/// Called on startup packet receival, allows to process params.
fn startup(&mut self, _pgb: &mut PostgresBackend, _sm: &FeStartupMessage) -> Result<()> {
Ok(())
}
}
pub struct PostgresBackend {
// replication.rs wants to handle reading on its own in separate thread, so
// wrap in Option to be able to take and transfer the BufReader. Ugly, but I
// have no better ideas.
stream_in: Option<BufReader<TcpStream>>,
stream_out: TcpStream,
// Output buffer. c.f. BeMessage::write why we are using BytesMut here.
buf_out: BytesMut,
init_done: bool,
}
// In replication.rs a separate thread is reading keepalives from the
// socket. When main one finishes, tell it to get down by shutdowning the
// socket.
impl Drop for PostgresBackend {
fn drop(&mut self) {
let _res = self.stream_out.shutdown(Shutdown::Both);
}
}
impl PostgresBackend {
pub fn new(socket: TcpStream) -> Result<Self, std::io::Error> {
let mut pb = PostgresBackend {
stream_in: None,
stream_out: socket,
buf_out: BytesMut::with_capacity(10 * 1024),
init_done: false,
};
// if socket cloning fails, report the error and bail out
pb.stream_in = match pb.stream_out.try_clone() {
Ok(read_sock) => Some(BufReader::new(read_sock)),
Err(error) => {
let errmsg = format!("{}", error);
let _res = pb.write_message_noflush(&BeMessage::ErrorResponse(errmsg));
return Err(error);
}
};
Ok(pb)
}
/// Get direct reference (into the Option) to the read stream.
fn get_stream_in(&mut self) -> Result<&mut BufReader<TcpStream>> {
match self.stream_in {
Some(ref mut stream_in) => Ok(stream_in),
None => bail!("stream_in was taken"),
}
}
pub fn take_stream_in(&mut self) -> Option<BufReader<TcpStream>> {
self.stream_in.take()
}
/// Read full message or return None if connection is closed.
pub fn read_message(&mut self) -> Result<Option<FeMessage>> {
if !self.init_done {
FeStartupMessage::read(self.get_stream_in()?)
} else {
FeMessage::read(self.get_stream_in()?)
}
}
/// Write message into internal output buffer.
pub fn write_message_noflush(&mut self, message: &BeMessage) -> io::Result<&mut Self> {
BeMessage::write(&mut self.buf_out, message)?;
Ok(self)
}
/// Flush output buffer into the socket.
pub fn flush(&mut self) -> io::Result<&mut Self> {
self.stream_out.write_all(&self.buf_out)?;
self.buf_out.clear();
Ok(self)
}
/// Write message into internal buffer and flush it.
pub fn write_message(&mut self, message: &BeMessage) -> io::Result<&mut Self> {
self.write_message_noflush(message)?;
self.flush()
}
pub fn run(&mut self, handler: &mut impl Handler) -> Result<()> {
let peer_addr = self.stream_out.peer_addr()?;
info!("postgres backend to {:?} started", peer_addr);
let mut unnamed_query_string = Bytes::new();
loop {
let msg = self.read_message()?;
trace!("got message {:?}", msg);
match msg {
Some(FeMessage::StartupMessage(m)) => {
trace!("got startup message {:?}", m);
handler.startup(self, &m)?;
match m.kind {
StartupRequestCode::NegotiateGss | StartupRequestCode::NegotiateSsl => {
info!("SSL requested");
self.write_message(&BeMessage::Negotiate)?;
}
StartupRequestCode::Normal => {
self.write_message_noflush(&BeMessage::AuthenticationOk)?;
// psycopg2 will not connect if client_encoding is not
// specified by the server
self.write_message_noflush(&BeMessage::ParameterStatus)?;
self.write_message(&BeMessage::ReadyForQuery)?;
self.init_done = true;
}
StartupRequestCode::Cancel => break,
}
}
Some(FeMessage::Query(m)) => {
trace!("got query {:?}", m.body);
// xxx distinguish fatal and recoverable errors?
if let Err(e) = handler.process_query(self, m.body) {
let errmsg = format!("{}", e);
self.write_message_noflush(&BeMessage::ErrorResponse(errmsg))?;
}
self.write_message(&BeMessage::ReadyForQuery)?;
}
Some(FeMessage::Parse(m)) => {
unnamed_query_string = m.query_string;
self.write_message(&BeMessage::ParseComplete)?;
}
Some(FeMessage::Describe(_)) => {
self.write_message_noflush(&BeMessage::ParameterDescription)?
.write_message(&BeMessage::NoData)?;
}
Some(FeMessage::Bind(_)) => {
self.write_message(&BeMessage::BindComplete)?;
}
Some(FeMessage::Close(_)) => {
self.write_message(&BeMessage::CloseComplete)?;
}
Some(FeMessage::Execute(_)) => {
handler.process_query(self, unnamed_query_string.clone())?;
}
Some(FeMessage::Sync) => {
self.write_message(&BeMessage::ReadyForQuery)?;
}
Some(FeMessage::Terminate) => {
break;
}
None => {
info!("connection closed");
break;
}
x => {
bail!("unexpected message type : {:?}", x);
}
}
}
info!("postgres backend to {:?} exited", peer_addr);
Ok(())
}
}

View File

@@ -0,0 +1,639 @@
//! Postgres protocol messages serialization-deserialization. See
//! https://www.postgresql.org/docs/devel/protocol-message-formats.html
//! on message formats.
use anyhow::{anyhow, bail, Result};
use byteorder::{BigEndian, ByteOrder};
use byteorder::{ReadBytesExt, BE};
use bytes::{Buf, BufMut, Bytes, BytesMut};
// use postgres_ffi::xlog_utils::TimestampTz;
use std::collections::HashMap;
use std::io;
use std::io::Read;
use std::str;
pub type Oid = u32;
pub type SystemId = u64;
#[derive(Debug)]
pub enum FeMessage {
StartupMessage(FeStartupMessage),
Query(FeQueryMessage), // Simple query
Parse(FeParseMessage), // Extended query protocol
Describe(FeDescribeMessage),
Bind(FeBindMessage),
Execute(FeExecuteMessage),
Close(FeCloseMessage),
Sync,
Terminate,
CopyData(Bytes),
CopyDone,
}
#[derive(Debug)]
pub struct FeStartupMessage {
pub version: u32,
pub kind: StartupRequestCode,
// optional params arriving in startup packet
pub params: HashMap<String, String>,
}
#[derive(Debug)]
pub enum StartupRequestCode {
Cancel,
NegotiateSsl,
NegotiateGss,
Normal,
}
#[derive(Debug)]
pub struct FeQueryMessage {
pub body: Bytes,
}
// We only support the simple case of Parse on unnamed prepared statement and
// no params
#[derive(Debug)]
pub struct FeParseMessage {
pub query_string: Bytes,
}
#[derive(Debug)]
pub struct FeDescribeMessage {
pub kind: u8, // 'S' to describe a prepared statement; or 'P' to describe a portal.
// we only support unnamed prepared stmt or portal
}
// we only support unnamed prepared stmt and portal
#[derive(Debug)]
pub struct FeBindMessage {}
// we only support unnamed prepared stmt or portal
#[derive(Debug)]
pub struct FeExecuteMessage {
/// max # of rows
pub maxrows: i32,
}
// we only support unnamed prepared stmt and portal
#[derive(Debug)]
pub struct FeCloseMessage {}
impl FeMessage {
/// Read one message from the stream.
pub fn read(stream: &mut impl Read) -> anyhow::Result<Option<FeMessage>> {
// Each libpq message begins with a message type byte, followed by message length
// If the client closes the connection, return None. But if the client closes the
// connection in the middle of a message, we will return an error.
let tag = match stream.read_u8() {
Ok(b) => b,
Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
Err(e) => return Err(e.into()),
};
let len = stream.read_u32::<BE>()?;
// The message length includes itself, so it better be at least 4
let bodylen = len.checked_sub(4).ok_or(io::Error::new(
io::ErrorKind::InvalidInput,
"invalid message length: parsing u32",
))?;
// Read message body
let mut body_buf: Vec<u8> = vec![0; bodylen as usize];
stream.read_exact(&mut body_buf)?;
let body = Bytes::from(body_buf);
// Parse it
match tag {
b'Q' => Ok(Some(FeMessage::Query(FeQueryMessage { body }))),
b'P' => Ok(Some(FeParseMessage::parse(body)?)),
b'D' => Ok(Some(FeDescribeMessage::parse(body)?)),
b'E' => Ok(Some(FeExecuteMessage::parse(body)?)),
b'B' => Ok(Some(FeBindMessage::parse(body)?)),
b'C' => Ok(Some(FeCloseMessage::parse(body)?)),
b'S' => Ok(Some(FeMessage::Sync)),
b'X' => Ok(Some(FeMessage::Terminate)),
b'd' => Ok(Some(FeMessage::CopyData(body))),
b'c' => Ok(Some(FeMessage::CopyDone)),
tag => Err(anyhow!("unknown message tag: {},'{:?}'", tag, body)),
}
}
}
impl FeStartupMessage {
/// Read startup message from the stream.
pub fn read(stream: &mut impl std::io::Read) -> anyhow::Result<Option<FeMessage>> {
const MAX_STARTUP_PACKET_LENGTH: usize = 10000;
const CANCEL_REQUEST_CODE: u32 = (1234 << 16) | 5678;
const NEGOTIATE_SSL_CODE: u32 = (1234 << 16) | 5679;
const NEGOTIATE_GSS_CODE: u32 = (1234 << 16) | 5680;
// Read length. If the connection is closed before reading anything (or before
// reading 4 bytes, to be precise), return None to indicate that the connection
// was closed. This matches the PostgreSQL server's behavior, which avoids noise
// in the log if the client opens connection but closes it immediately.
let len = match stream.read_u32::<BE>() {
Ok(len) => len as usize,
Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
Err(e) => return Err(e.into()),
};
if len < 4 || len > MAX_STARTUP_PACKET_LENGTH {
bail!("invalid message length");
}
let version = stream.read_u32::<BE>()?;
let kind = match version {
CANCEL_REQUEST_CODE => StartupRequestCode::Cancel,
NEGOTIATE_SSL_CODE => StartupRequestCode::NegotiateSsl,
NEGOTIATE_GSS_CODE => StartupRequestCode::NegotiateGss,
_ => StartupRequestCode::Normal,
};
// the rest of startup packet are params
let params_len = len - 8;
let mut params_bytes = vec![0u8; params_len];
stream.read_exact(params_bytes.as_mut())?;
// Then null-terminated (String) pairs of param name / param value go.
let params_str = str::from_utf8(&params_bytes).unwrap();
let params = params_str.split('\0');
let mut params_hash: HashMap<String, String> = HashMap::new();
for pair in params.collect::<Vec<_>>().chunks_exact(2) {
let name = pair[0];
let value = pair[1];
if name == "options" {
// deprecated way of passing params as cmd line args
for cmdopt in value.split(' ') {
let nameval: Vec<&str> = cmdopt.split('=').collect();
if nameval.len() == 2 {
params_hash.insert(nameval[0].to_string(), nameval[1].to_string());
}
}
} else {
params_hash.insert(name.to_string(), value.to_string());
}
}
Ok(Some(FeMessage::StartupMessage(FeStartupMessage {
version,
kind,
params: params_hash,
})))
}
}
impl FeParseMessage {
pub fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
let _pstmt_name = read_null_terminated(&mut buf)?;
let query_string = read_null_terminated(&mut buf)?;
let nparams = buf.get_i16();
// FIXME: the rust-postgres driver uses a named prepared statement
// for copy_out(). We're not prepared to handle that correctly. For
// now, just ignore the statement name, assuming that the client never
// uses more than one prepared statement at a time.
/*
if !pstmt_name.is_empty() {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"named prepared statements not implemented in Parse",
));
}
*/
if nparams != 0 {
bail!("query params not implemented");
}
Ok(FeMessage::Parse(FeParseMessage { query_string }))
}
}
impl FeDescribeMessage {
pub fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
let kind = buf.get_u8();
let _pstmt_name = read_null_terminated(&mut buf)?;
// FIXME: see FeParseMessage::parse
/*
if !pstmt_name.is_empty() {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"named prepared statements not implemented in Describe",
));
}
*/
if kind != b'S' {
bail!("only prepared statmement Describe is implemented");
}
Ok(FeMessage::Describe(FeDescribeMessage { kind }))
}
}
impl FeExecuteMessage {
pub fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
let portal_name = read_null_terminated(&mut buf)?;
let maxrows = buf.get_i32();
if !portal_name.is_empty() {
bail!("named portals not implemented");
}
if maxrows != 0 {
bail!("row limit in Execute message not supported");
}
Ok(FeMessage::Execute(FeExecuteMessage { maxrows }))
}
}
impl FeBindMessage {
pub fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
let portal_name = read_null_terminated(&mut buf)?;
let _pstmt_name = read_null_terminated(&mut buf)?;
if !portal_name.is_empty() {
bail!("named portals not implemented");
}
// FIXME: see FeParseMessage::parse
/*
if !pstmt_name.is_empty() {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"named prepared statements not implemented",
));
}
*/
Ok(FeMessage::Bind(FeBindMessage {}))
}
}
impl FeCloseMessage {
pub fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
let _kind = buf.get_u8();
let _pstmt_or_portal_name = read_null_terminated(&mut buf)?;
// FIXME: we do nothing with Close
Ok(FeMessage::Close(FeCloseMessage {}))
}
}
fn read_null_terminated(buf: &mut Bytes) -> anyhow::Result<Bytes> {
let mut result = BytesMut::new();
loop {
if !buf.has_remaining() {
bail!("no null-terminator in string");
}
let byte = buf.get_u8();
if byte == 0 {
break;
}
result.put_u8(byte);
}
Ok(result.freeze())
}
// Backend
#[derive(Debug)]
pub enum BeMessage<'a> {
AuthenticationOk,
BindComplete,
CommandComplete(&'a [u8]),
ControlFile,
CopyData(&'a [u8]),
CopyDone,
CopyInResponse,
CopyOutResponse,
CopyBothResponse,
CloseComplete,
// None means column is NULL
DataRow(&'a [Option<&'a [u8]>]),
ErrorResponse(String),
// see https://www.postgresql.org/docs/devel/protocol-flow.html#id-1.10.5.7.11
Negotiate,
NoData,
ParameterDescription,
ParameterStatus,
ParseComplete,
ReadyForQuery,
RowDescription(&'a [RowDescriptor<'a>]),
XLogData(XLogDataBody<'a>),
}
// One row desciption in RowDescription packet.
#[derive(Debug)]
pub struct RowDescriptor<'a> {
pub name: &'a [u8],
pub tableoid: Oid,
pub attnum: i16,
pub typoid: Oid,
pub typlen: i16,
pub typmod: i32,
pub formatcode: i16,
}
impl Default for RowDescriptor<'_> {
fn default() -> RowDescriptor<'static> {
RowDescriptor {
name: b"",
tableoid: 0,
attnum: 0,
typoid: 0,
typlen: 0,
typmod: 0,
formatcode: 0,
}
}
}
#[derive(Debug)]
pub struct XLogDataBody<'a> {
pub wal_start: u64,
pub wal_end: u64,
pub timestamp: u64,
pub data: &'a [u8],
}
pub static HELLO_WORLD_ROW: BeMessage = BeMessage::DataRow(&[Some(b"hello world")]);
pub const TEXT_OID: Oid = 25;
// single text column
pub static SINGLE_COL_ROWDESC: BeMessage = BeMessage::RowDescription(&[RowDescriptor {
name: b"data",
tableoid: 0,
attnum: 0,
typoid: TEXT_OID,
typlen: -1,
typmod: 0,
formatcode: 0,
}]);
// Safe usize -> i32|i16 conversion, from rust-postgres
trait FromUsize: Sized {
fn from_usize(x: usize) -> Result<Self, io::Error>;
}
macro_rules! from_usize {
($t:ty) => {
impl FromUsize for $t {
#[inline]
fn from_usize(x: usize) -> io::Result<$t> {
if x > <$t>::max_value() as usize {
Err(io::Error::new(
io::ErrorKind::InvalidInput,
"value too large to transmit",
))
} else {
Ok(x as $t)
}
}
}
};
}
from_usize!(i32);
/// Call f() to write body of the message and prepend it with 4-byte len as
/// prescribed by the protocol.
fn write_body<F>(buf: &mut BytesMut, f: F) -> io::Result<()>
where
F: FnOnce(&mut BytesMut) -> io::Result<()>,
{
let base = buf.len();
buf.extend_from_slice(&[0; 4]);
f(buf)?;
let size = i32::from_usize(buf.len() - base)?;
BigEndian::write_i32(&mut buf[base..], size);
Ok(())
}
/// Safe write of s into buf as cstring (String in the protocol).
fn write_cstr(s: &[u8], buf: &mut BytesMut) -> Result<(), io::Error> {
if s.contains(&0) {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"string contains embedded null",
));
}
buf.put_slice(s);
buf.put_u8(0);
Ok(())
}
impl<'a> BeMessage<'a> {
/// Write message to the given buf.
// Unlike the reading side, we use BytesMut
// here as msg len preceeds its body and it is handy to write it down first
// and then fill the length. With Write we would have to either calc it
// manually or have one more buffer.
pub fn write(buf: &mut BytesMut, message: &BeMessage) -> io::Result<()> {
match message {
BeMessage::AuthenticationOk => {
buf.put_u8(b'R');
write_body(buf, |buf| {
buf.put_i32(0);
Ok::<_, io::Error>(())
})
.unwrap(); // write into BytesMut can't fail
}
BeMessage::BindComplete => {
buf.put_u8(b'2');
write_body(buf, |_| Ok::<(), io::Error>(())).unwrap();
}
BeMessage::CloseComplete => {
buf.put_u8(b'3');
write_body(buf, |_| Ok::<(), io::Error>(())).unwrap();
}
BeMessage::CommandComplete(cmd) => {
buf.put_u8(b'C');
write_body(buf, |buf| {
write_cstr(cmd, buf)?;
Ok::<_, io::Error>(())
})?;
}
BeMessage::ControlFile => {
// TODO pass checkpoint and xid info in this message
BeMessage::write(buf, &BeMessage::DataRow(&[Some(b"hello pg_control")]))?;
}
BeMessage::CopyData(data) => {
buf.put_u8(b'd');
write_body(buf, |buf| {
buf.put_slice(data);
Ok::<_, io::Error>(())
})
.unwrap();
}
BeMessage::CopyDone => {
buf.put_u8(b'c');
write_body(buf, |_| Ok::<(), io::Error>(())).unwrap();
}
BeMessage::CopyInResponse => {
buf.put_u8(b'G');
write_body(buf, |buf| {
buf.put_u8(1); /* copy_is_binary */
buf.put_i16(0); /* numAttributes */
Ok::<_, io::Error>(())
})
.unwrap();
}
BeMessage::CopyOutResponse => {
buf.put_u8(b'H');
write_body(buf, |buf| {
buf.put_u8(0); /* copy_is_binary */
buf.put_i16(0); /* numAttributes */
Ok::<_, io::Error>(())
})
.unwrap();
}
BeMessage::CopyBothResponse => {
buf.put_u8(b'W');
write_body(buf, |buf| {
// doesn't matter, used only for replication
buf.put_u8(0); /* copy_is_binary */
buf.put_i16(0); /* numAttributes */
Ok::<_, io::Error>(())
})
.unwrap();
}
BeMessage::DataRow(vals) => {
buf.put_u8(b'D');
write_body(buf, |buf| {
buf.put_u16(vals.len() as u16); // num of cols
for val_opt in vals.iter() {
if let Some(val) = val_opt {
buf.put_u32(val.len() as u32);
buf.put_slice(val);
} else {
buf.put_i32(-1);
}
}
Ok::<_, io::Error>(())
})
.unwrap();
}
// ErrorResponse is a zero-terminated array of zero-terminated fields.
// First byte of each field represents type of this field. Set just enough fields
// to satisfy rust-postgres client: 'S' -- severity, 'C' -- error, 'M' -- error
// message text.
BeMessage::ErrorResponse(error_msg) => {
// For all the errors set Severity to Error and error code to
// 'internal error'.
// 'E' signalizes ErrorResponse messages
buf.put_u8(b'E');
write_body(buf, |buf| {
buf.put_u8(b'S'); // severity
write_cstr(&Bytes::from("ERROR"), buf)?;
buf.put_u8(b'C'); // SQLSTATE error code
write_cstr(&Bytes::from("CXX000"), buf)?;
buf.put_u8(b'M'); // the message
write_cstr(error_msg.as_bytes(), buf)?;
buf.put_u8(0); // terminator
Ok::<_, io::Error>(())
})
.unwrap();
}
BeMessage::NoData => {
buf.put_u8(b'n');
write_body(buf, |_| Ok::<(), io::Error>(())).unwrap();
}
BeMessage::Negotiate => {
buf.put_u8(b'N');
}
BeMessage::ParameterStatus => {
buf.put_u8(b'S');
// parameter names and values are specified by null terminated strings
const PARAM_NAME_VALUE: &[u8] = b"client_encoding\0UTF8\0";
write_body(buf, |buf| {
buf.put_slice(PARAM_NAME_VALUE);
Ok::<_, io::Error>(())
})
.unwrap();
}
BeMessage::ParameterDescription => {
buf.put_u8(b't');
write_body(buf, |buf| {
// we don't support params, so always 0
buf.put_i16(0);
Ok::<_, io::Error>(())
})
.unwrap();
}
BeMessage::ParseComplete => {
buf.put_u8(b'1');
write_body(buf, |_| Ok::<(), io::Error>(())).unwrap();
}
BeMessage::ReadyForQuery => {
buf.put_u8(b'Z');
write_body(buf, |buf| {
buf.put_u8(b'I');
Ok::<_, io::Error>(())
})
.unwrap();
}
BeMessage::RowDescription(rows) => {
buf.put_u8(b'T');
write_body(buf, |buf| {
buf.put_i16(rows.len() as i16); // # of fields
for row in rows.iter() {
write_cstr(row.name, buf)?;
buf.put_i32(0); /* table oid */
buf.put_i16(0); /* attnum */
buf.put_u32(row.typoid);
buf.put_i16(row.typlen);
buf.put_i32(-1); /* typmod */
buf.put_i16(0); /* format code */
}
Ok::<_, io::Error>(())
})?;
}
BeMessage::XLogData(body) => {
buf.put_u8(b'd');
write_body(buf, |buf| {
buf.put_u8(b'w');
buf.put_u64(body.wal_start);
buf.put_u64(body.wal_end);
buf.put_u64(body.timestamp);
buf.put_slice(body.data);
Ok::<_, io::Error>(())
})
.unwrap();
}
}
Ok(())
}
}