mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-08 14:02:55 +00:00
Consolidate PG proto parsing-deparsing and backend code.
postgres_backend now handles communication with the client and passes queries to the provided handler; there are currently two handlers, one for wal_acceptor and one for pageserver. BytesMut is again used for writing data, which avoids calculating message lengths manually. ref #118
This commit is contained in:
3
Cargo.lock
generated
3
Cargo.lock
generated
@@ -2478,9 +2478,12 @@ dependencies = [
|
||||
name = "zenith_utils"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"bincode",
|
||||
"byteorder",
|
||||
"bytes",
|
||||
"hex-literal",
|
||||
"log",
|
||||
"serde",
|
||||
"thiserror",
|
||||
"workspace_hack",
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,20 +1,18 @@
|
||||
//! This module implements the replication protocol, starting with the
|
||||
//! "START REPLICATION" message.
|
||||
//! This module implements the streaming side of replication protocol, starting
|
||||
//! with the "START REPLICATION" message.
|
||||
|
||||
use crate::pq_protocol::{BeMessage, FeMessage};
|
||||
use crate::send_wal::SendWalConn;
|
||||
use crate::send_wal::SendWalHandler;
|
||||
use crate::timeline::{Timeline, TimelineTools};
|
||||
use crate::WalAcceptorConf;
|
||||
use anyhow::{anyhow, Result};
|
||||
use bytes::{BufMut, Bytes, BytesMut};
|
||||
use bytes::Bytes;
|
||||
use log::*;
|
||||
use postgres_ffi::xlog_utils::{get_current_timestamp, TimestampTz, XLogFileName, MAX_SEND_SIZE};
|
||||
use regex::Regex;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::cmp::min;
|
||||
use std::fs::File;
|
||||
use std::io::{BufReader, Read, Seek, SeekFrom, Write};
|
||||
use std::net::{Shutdown, TcpStream};
|
||||
use std::io::{BufReader, Read, Seek, SeekFrom};
|
||||
use std::net::TcpStream;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use std::thread::sleep;
|
||||
@@ -22,10 +20,9 @@ use std::time::Duration;
|
||||
use std::{str, thread};
|
||||
use zenith_utils::bin_ser::BeSer;
|
||||
use zenith_utils::lsn::Lsn;
|
||||
use zenith_utils::postgres_backend::PostgresBackend;
|
||||
use zenith_utils::pq_proto::{BeMessage, FeMessage, XLogDataBody};
|
||||
|
||||
const XLOG_HDR_SIZE: usize = 1 + 8 * 3; /* 'w' + startPos + walEnd + timestamp */
|
||||
const LIBPQ_HDR_SIZE: usize = 5; /* 1 byte with message type + 4 bytes length */
|
||||
const LIBPQ_MSG_SIZE_OFFS: usize = 1;
|
||||
pub const END_REPLICATION_MARKER: Lsn = Lsn::MAX;
|
||||
|
||||
type FullTransactionId = u64;
|
||||
@@ -40,37 +37,16 @@ pub struct HotStandbyFeedback {
|
||||
|
||||
/// A network connection that's speaking the replication protocol.
|
||||
pub struct ReplicationConn {
|
||||
timeline: Option<Arc<Timeline>>,
|
||||
/// Postgres connection, buffered input
|
||||
///
|
||||
/// This is an `Option` because we will spawn a background thread that will
|
||||
/// `take` it from us.
|
||||
stream_in: Option<BufReader<TcpStream>>,
|
||||
/// Postgres connection, output
|
||||
stream_out: TcpStream,
|
||||
/// wal acceptor configuration
|
||||
conf: WalAcceptorConf,
|
||||
/// assigned application name
|
||||
appname: Option<String>,
|
||||
}
|
||||
|
||||
// Separate thread is reading keepalives from the socket. When main one
|
||||
// finishes, tell it to get down by shutdowning the socket.
|
||||
impl Drop for ReplicationConn {
|
||||
fn drop(&mut self) {
|
||||
let _res = self.stream_out.shutdown(Shutdown::Both);
|
||||
}
|
||||
}
|
||||
|
||||
impl ReplicationConn {
|
||||
/// Create a new `SendWal`, consuming the `Connection`.
|
||||
pub fn new(conn: SendWalConn) -> Self {
|
||||
/// Create a new `ReplicationConn`
|
||||
pub fn new(pgb: &mut PostgresBackend) -> Self {
|
||||
Self {
|
||||
timeline: conn.timeline,
|
||||
stream_in: Some(conn.stream_in),
|
||||
stream_out: conn.stream_out,
|
||||
conf: conn.conf,
|
||||
appname: None,
|
||||
stream_in: pgb.take_stream_in(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -82,9 +58,9 @@ impl ReplicationConn {
|
||||
// Wait for replica's feedback.
|
||||
// We only handle `CopyData` messages. Anything else is ignored.
|
||||
loop {
|
||||
match FeMessage::read_from(&mut stream_in)? {
|
||||
FeMessage::CopyData(m) => {
|
||||
let feedback = HotStandbyFeedback::des(&m.body)?;
|
||||
match FeMessage::read(&mut stream_in)? {
|
||||
Some(FeMessage::CopyData(m)) => {
|
||||
let feedback = HotStandbyFeedback::des(&m)?;
|
||||
timeline.add_hs_feedback(feedback)
|
||||
}
|
||||
msg => {
|
||||
@@ -128,9 +104,14 @@ impl ReplicationConn {
|
||||
///
|
||||
/// Handle START_REPLICATION replication command
|
||||
///
|
||||
pub fn run(&mut self, cmd: &Bytes) -> Result<()> {
|
||||
pub fn run(
|
||||
&mut self,
|
||||
swh: &mut SendWalHandler,
|
||||
pgb: &mut PostgresBackend,
|
||||
cmd: &Bytes,
|
||||
) -> Result<()> {
|
||||
// spawn the background thread which receives HotStandbyFeedback messages.
|
||||
let bg_timeline = Arc::clone(self.timeline.get());
|
||||
let bg_timeline = Arc::clone(swh.timeline.get());
|
||||
let bg_stream_in = self.stream_in.take().unwrap();
|
||||
|
||||
thread::spawn(move || {
|
||||
@@ -143,7 +124,7 @@ impl ReplicationConn {
|
||||
|
||||
let mut wal_seg_size: usize;
|
||||
loop {
|
||||
wal_seg_size = self.timeline.get().get_info().server.wal_seg_size as usize;
|
||||
wal_seg_size = swh.timeline.get().get_info().server.wal_seg_size as usize;
|
||||
if wal_seg_size == 0 {
|
||||
error!("Can not start replication before connecting to wal_proposer");
|
||||
sleep(Duration::from_secs(1));
|
||||
@@ -151,19 +132,17 @@ impl ReplicationConn {
|
||||
break;
|
||||
}
|
||||
}
|
||||
let (wal_end, timeline) = self.timeline.find_end_of_wal(&self.conf.data_dir, false);
|
||||
let (wal_end, timeline) = swh.timeline.find_end_of_wal(&swh.conf.data_dir, false);
|
||||
if start_pos == Lsn(0) {
|
||||
start_pos = wal_end;
|
||||
}
|
||||
if stop_pos == Lsn(0) && self.appname == Some("wal_proposer_recovery".to_string()) {
|
||||
if stop_pos == Lsn(0) && swh.appname == Some("wal_proposer_recovery".to_string()) {
|
||||
stop_pos = wal_end;
|
||||
}
|
||||
info!("Start replication from {} till {}", start_pos, stop_pos);
|
||||
|
||||
let mut outbuf = BytesMut::new();
|
||||
BeMessage::write(&mut outbuf, &BeMessage::Copy);
|
||||
self.send(&outbuf)?;
|
||||
outbuf.clear();
|
||||
// switch to copy
|
||||
pgb.write_message(&BeMessage::CopyBothResponse)?;
|
||||
|
||||
let mut end_pos: Lsn;
|
||||
let mut wal_file: Option<File> = None;
|
||||
@@ -179,7 +158,7 @@ impl ReplicationConn {
|
||||
end_pos = stop_pos;
|
||||
} else {
|
||||
/* normal mode */
|
||||
let timeline = self.timeline.get();
|
||||
let timeline = swh.timeline.get();
|
||||
end_pos = timeline.wait_for_lsn(start_pos);
|
||||
}
|
||||
if end_pos == END_REPLICATION_MARKER {
|
||||
@@ -193,8 +172,8 @@ impl ReplicationConn {
|
||||
// Open a new file.
|
||||
let segno = start_pos.segment_number(wal_seg_size);
|
||||
let wal_file_name = XLogFileName(timeline, segno, wal_seg_size);
|
||||
let timeline_id = self.timeline.get().timelineid.to_string();
|
||||
let wal_file_path = self.conf.data_dir.join(timeline_id).join(wal_file_name);
|
||||
let timeline_id = swh.timeline.get().timelineid.to_string();
|
||||
let wal_file_path = swh.conf.data_dir.join(timeline_id).join(wal_file_name);
|
||||
Self::open_wal_file(&wal_file_path)?
|
||||
}
|
||||
};
|
||||
@@ -207,32 +186,19 @@ impl ReplicationConn {
|
||||
let send_size = min(send_size, wal_seg_size - xlogoff);
|
||||
let send_size = min(send_size, MAX_SEND_SIZE);
|
||||
|
||||
let msg_size = LIBPQ_HDR_SIZE + XLOG_HDR_SIZE + send_size;
|
||||
|
||||
// Read some data from the file.
|
||||
let mut file_buf = vec![0u8; send_size];
|
||||
file.seek(SeekFrom::Start(xlogoff as u64))?;
|
||||
file.read_exact(&mut file_buf)?;
|
||||
|
||||
// Write some data to the network socket.
|
||||
// FIXME: turn these into structs.
|
||||
// 'd' is CopyData;
|
||||
// 'w' is "WAL records"
|
||||
// https://www.postgresql.org/docs/9.1/protocol-message-formats.html
|
||||
// src/backend/replication/walreceiver.c
|
||||
outbuf.clear();
|
||||
outbuf.put_u8(b'd');
|
||||
outbuf.put_u32((msg_size - LIBPQ_MSG_SIZE_OFFS) as u32);
|
||||
outbuf.put_u8(b'w');
|
||||
outbuf.put_u64(start_pos.0);
|
||||
outbuf.put_u64(end_pos.0);
|
||||
outbuf.put_u64(get_current_timestamp());
|
||||
pgb.write_message(&BeMessage::XLogData(XLogDataBody {
|
||||
wal_start: start_pos.0,
|
||||
wal_end: end_pos.0,
|
||||
timestamp: get_current_timestamp(),
|
||||
data: &file_buf,
|
||||
}))?;
|
||||
|
||||
assert!(outbuf.len() + file_buf.len() == msg_size);
|
||||
// This thread has exclusive access to the TcpStream, so it's fine
|
||||
// to do this as two separate calls.
|
||||
self.send(&outbuf)?;
|
||||
self.send(&file_buf)?;
|
||||
start_pos += send_size as u64;
|
||||
|
||||
debug!("Sent WAL to page server up to {}", end_pos);
|
||||
@@ -245,10 +211,4 @@ impl ReplicationConn {
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Send messages on the network.
|
||||
fn send(&mut self, buf: &[u8]) -> Result<()> {
|
||||
self.stream_out.write_all(buf.as_ref())?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,111 +1,69 @@
|
||||
//! This implements the libpq replication protocol between wal_acceptor
|
||||
//! and replicas/pagers
|
||||
//! Part of WAL acceptor pretending to be Postgres, streaming xlog to
|
||||
//! pageserver/any other consumer.
|
||||
//!
|
||||
|
||||
use crate::pq_protocol::{
|
||||
BeMessage, FeMessage, FeStartupMessage, RowDescriptor, StartupRequestCode,
|
||||
};
|
||||
use crate::replication::ReplicationConn;
|
||||
use crate::timeline::{Timeline, TimelineTools};
|
||||
use crate::WalAcceptorConf;
|
||||
use anyhow::{bail, Result};
|
||||
use bytes::BytesMut;
|
||||
use log::*;
|
||||
use std::io::{BufReader, Write};
|
||||
use std::net::{SocketAddr, TcpStream};
|
||||
use bytes::Bytes;
|
||||
use pageserver::ZTimelineId;
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
use zenith_utils::postgres_backend;
|
||||
use zenith_utils::postgres_backend::PostgresBackend;
|
||||
use zenith_utils::pq_proto::{BeMessage, FeStartupMessage, RowDescriptor};
|
||||
|
||||
/// A network connection that's speaking the libpq replication protocol.
|
||||
pub struct SendWalConn {
|
||||
pub timeline: Option<Arc<Timeline>>,
|
||||
/// Postgres connection, buffered input
|
||||
pub stream_in: BufReader<TcpStream>,
|
||||
/// Postgres connection, output
|
||||
pub stream_out: TcpStream,
|
||||
/// The cached result of socket.peer_addr()
|
||||
pub peer_addr: SocketAddr,
|
||||
/// Handler for streaming WAL from acceptor
|
||||
pub struct SendWalHandler {
|
||||
/// wal acceptor configuration
|
||||
pub conf: WalAcceptorConf,
|
||||
/// assigned application name
|
||||
appname: Option<String>,
|
||||
pub appname: Option<String>,
|
||||
pub timeline: Option<Arc<Timeline>>,
|
||||
}
|
||||
|
||||
impl SendWalConn {
|
||||
/// Create a new `SendWal`, consuming the `Connection`.
|
||||
pub fn new(socket: TcpStream, conf: WalAcceptorConf) -> Result<Self> {
|
||||
let peer_addr = socket.peer_addr()?;
|
||||
let conn = SendWalConn {
|
||||
timeline: None,
|
||||
stream_in: BufReader::new(socket.try_clone()?),
|
||||
stream_out: socket,
|
||||
peer_addr,
|
||||
conf,
|
||||
appname: None,
|
||||
};
|
||||
Ok(conn)
|
||||
impl postgres_backend::Handler for SendWalHandler {
|
||||
fn startup(&mut self, _pgb: &mut PostgresBackend, sm: &FeStartupMessage) -> Result<()> {
|
||||
match sm.params.get("ztimelineid") {
|
||||
Some(ref ztimelineid) => {
|
||||
let ztlid = ZTimelineId::from_str(ztimelineid)?;
|
||||
self.timeline.set(ztlid)?;
|
||||
}
|
||||
_ => bail!("timelineid is required"),
|
||||
}
|
||||
if let Some(app_name) = sm.params.get("application_name") {
|
||||
self.appname = Some(app_name.clone());
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
///
|
||||
/// Send WAL to replica or WAL receiver using standard libpq replication protocol
|
||||
///
|
||||
pub fn run(mut self) -> Result<()> {
|
||||
let peer_addr = self.peer_addr;
|
||||
info!("WAL sender to {:?} is started", peer_addr);
|
||||
|
||||
// Handle the startup message first.
|
||||
|
||||
let m = FeStartupMessage::read_from(&mut self.stream_in)?;
|
||||
trace!("got startup message {:?}", m);
|
||||
match m.kind {
|
||||
StartupRequestCode::NegotiateGss | StartupRequestCode::NegotiateSsl => {
|
||||
let mut buf = BytesMut::new();
|
||||
BeMessage::write(&mut buf, &BeMessage::Negotiate);
|
||||
info!("SSL requested");
|
||||
self.stream_out.write_all(&buf)?;
|
||||
}
|
||||
StartupRequestCode::Normal => {
|
||||
let mut buf = BytesMut::new();
|
||||
BeMessage::write(&mut buf, &BeMessage::AuthenticationOk);
|
||||
BeMessage::write(&mut buf, &BeMessage::ReadyForQuery);
|
||||
self.stream_out.write_all(&buf)?;
|
||||
self.timeline.set(m.timelineid)?;
|
||||
self.appname = m.appname;
|
||||
}
|
||||
StartupRequestCode::Cancel => return Ok(()),
|
||||
fn process_query(&mut self, pgb: &mut PostgresBackend, query_string: Bytes) -> Result<()> {
|
||||
if query_string.starts_with(b"IDENTIFY_SYSTEM") {
|
||||
self.handle_identify_system(pgb)?;
|
||||
Ok(())
|
||||
} else if query_string.starts_with(b"START_REPLICATION") {
|
||||
ReplicationConn::new(pgb).run(self, pgb, &query_string)?;
|
||||
Ok(())
|
||||
} else {
|
||||
bail!("Unexpected command {:?}", query_string);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
loop {
|
||||
let msg = FeMessage::read_from(&mut self.stream_in)?;
|
||||
match msg {
|
||||
FeMessage::Query(q) => {
|
||||
trace!("got query {:?}", q.body);
|
||||
|
||||
if q.body.starts_with(b"IDENTIFY_SYSTEM") {
|
||||
self.handle_identify_system()?;
|
||||
} else if q.body.starts_with(b"START_REPLICATION") {
|
||||
// Create a new replication object, consuming `self`.
|
||||
ReplicationConn::new(self).run(&q.body)?;
|
||||
break;
|
||||
} else {
|
||||
bail!("Unexpected command {:?}", q.body);
|
||||
}
|
||||
}
|
||||
FeMessage::Terminate => {
|
||||
break;
|
||||
}
|
||||
_ => {
|
||||
bail!("unexpected message");
|
||||
}
|
||||
}
|
||||
impl SendWalHandler {
|
||||
pub fn new(conf: WalAcceptorConf) -> Self {
|
||||
SendWalHandler {
|
||||
conf,
|
||||
appname: None,
|
||||
timeline: None,
|
||||
}
|
||||
info!("WAL sender to {:?} is finished", peer_addr);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
///
|
||||
/// Handle IDENTIFY_SYSTEM replication command
|
||||
///
|
||||
fn handle_identify_system(&mut self) -> Result<()> {
|
||||
fn handle_identify_system(&mut self, pgb: &mut PostgresBackend) -> Result<()> {
|
||||
let (start_pos, timeline) = self.timeline.find_end_of_wal(&self.conf.data_dir, false);
|
||||
let lsn = start_pos.to_string();
|
||||
let tli = timeline.to_string();
|
||||
@@ -114,42 +72,39 @@ impl SendWalConn {
|
||||
let tli_bytes = tli.as_bytes();
|
||||
let sysid_bytes = sysid.as_bytes();
|
||||
|
||||
let mut outbuf = BytesMut::new();
|
||||
BeMessage::write(
|
||||
&mut outbuf,
|
||||
&BeMessage::RowDescription(&[
|
||||
RowDescriptor {
|
||||
name: b"systemid\0",
|
||||
typoid: 25,
|
||||
typlen: -1,
|
||||
},
|
||||
RowDescriptor {
|
||||
name: b"timeline\0",
|
||||
typoid: 23,
|
||||
typlen: 4,
|
||||
},
|
||||
RowDescriptor {
|
||||
name: b"xlogpos\0",
|
||||
typoid: 25,
|
||||
typlen: -1,
|
||||
},
|
||||
RowDescriptor {
|
||||
name: b"dbname\0",
|
||||
typoid: 25,
|
||||
typlen: -1,
|
||||
},
|
||||
]),
|
||||
);
|
||||
BeMessage::write(
|
||||
&mut outbuf,
|
||||
&BeMessage::DataRow(&[Some(sysid_bytes), Some(tli_bytes), Some(lsn_bytes), None]),
|
||||
);
|
||||
BeMessage::write(
|
||||
&mut outbuf,
|
||||
&BeMessage::CommandComplete(b"IDENTIFY_SYSTEM\0"),
|
||||
);
|
||||
BeMessage::write(&mut outbuf, &BeMessage::ReadyForQuery);
|
||||
self.stream_out.write_all(&outbuf)?;
|
||||
pgb.write_message_noflush(&BeMessage::RowDescription(&[
|
||||
RowDescriptor {
|
||||
name: b"systemid",
|
||||
typoid: 25,
|
||||
typlen: -1,
|
||||
..Default::default()
|
||||
},
|
||||
RowDescriptor {
|
||||
name: b"timeline",
|
||||
typoid: 23,
|
||||
typlen: 4,
|
||||
..Default::default()
|
||||
},
|
||||
RowDescriptor {
|
||||
name: b"xlogpos",
|
||||
typoid: 25,
|
||||
typlen: -1,
|
||||
..Default::default()
|
||||
},
|
||||
RowDescriptor {
|
||||
name: b"dbname",
|
||||
typoid: 25,
|
||||
typlen: -1,
|
||||
..Default::default()
|
||||
},
|
||||
]))?
|
||||
.write_message_noflush(&BeMessage::DataRow(&[
|
||||
Some(sysid_bytes),
|
||||
Some(tli_bytes),
|
||||
Some(lsn_bytes),
|
||||
None,
|
||||
]))?
|
||||
.write_message(&BeMessage::CommandComplete(b"IDENTIFY_SYSTEM"))?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,8 +9,9 @@ use std::net::{TcpListener, TcpStream};
|
||||
use std::thread;
|
||||
|
||||
use crate::receive_wal::ReceiveWalConn;
|
||||
use crate::send_wal::SendWalConn;
|
||||
use crate::send_wal::SendWalHandler;
|
||||
use crate::WalAcceptorConf;
|
||||
use zenith_utils::postgres_backend::PostgresBackend;
|
||||
|
||||
/// Accept incoming TCP connections and spawn them into a background thread.
|
||||
pub fn thread_main(conf: WalAcceptorConf) -> Result<()> {
|
||||
@@ -48,7 +49,10 @@ fn handle_socket(mut socket: TcpStream, conf: WalAcceptorConf) -> Result<()> {
|
||||
socket.read_exact(&mut [0u8; 4])?;
|
||||
ReceiveWalConn::new(socket, conf)?.run()?; // internal protocol between wal_proposer and wal_acceptor
|
||||
} else {
|
||||
SendWalConn::new(socket, conf)?.run()?; // libpq replication protocol between wal_acceptor and replicas/pagers
|
||||
let mut conn_handler = SendWalHandler::new(conf);
|
||||
let mut pgbackend = PostgresBackend::new(socket)?;
|
||||
// libpq replication protocol between wal_acceptor and replicas/pagers
|
||||
pgbackend.run(&mut conn_handler)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -5,6 +5,10 @@ authors = ["Eric Seppanen <eric@zenith.tech>"]
|
||||
edition = "2018"
|
||||
|
||||
[dependencies]
|
||||
anyhow = "1.0"
|
||||
bytes = "1.0.1"
|
||||
byteorder = "1.4.3"
|
||||
log = "0.4.14"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
bincode = "1.3"
|
||||
thiserror = "1.0"
|
||||
|
||||
@@ -10,3 +10,6 @@ pub mod seqwait;
|
||||
// pub mod seqwait_async;
|
||||
|
||||
pub mod bin_ser;
|
||||
|
||||
pub mod postgres_backend;
|
||||
pub mod pq_proto;
|
||||
|
||||
181
zenith_utils/src/postgres_backend.rs
Normal file
181
zenith_utils/src/postgres_backend.rs
Normal file
@@ -0,0 +1,181 @@
|
||||
//! Server-side synchronous Postgres connection, as limited as we need.
|
||||
//! To use, create PostgresBackend and run() it, passing the Handler
|
||||
//! implementation determining how to process the queries. Currently its API
|
||||
//! is rather narrow, but we can extend it once required.
|
||||
|
||||
use crate::pq_proto::{BeMessage, FeMessage, FeStartupMessage, StartupRequestCode};
|
||||
use anyhow::bail;
|
||||
use anyhow::Result;
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use log::*;
|
||||
use std::io;
|
||||
use std::io::{BufReader, Write};
|
||||
use std::net::{Shutdown, TcpStream};
|
||||
|
||||
pub trait Handler {
|
||||
/// Handle single query.
|
||||
/// postgres_backend will issue ReadyForQuery after calling this (this
|
||||
/// might be not what we want after CopyData streaming, but currently we don't
|
||||
/// care).
|
||||
fn process_query(&mut self, pgb: &mut PostgresBackend, query_string: Bytes) -> Result<()>;
|
||||
/// Called on startup packet receival, allows to process params.
|
||||
fn startup(&mut self, _pgb: &mut PostgresBackend, _sm: &FeStartupMessage) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PostgresBackend {
|
||||
// replication.rs wants to handle reading on its own in separate thread, so
|
||||
// wrap in Option to be able to take and transfer the BufReader. Ugly, but I
|
||||
// have no better ideas.
|
||||
stream_in: Option<BufReader<TcpStream>>,
|
||||
stream_out: TcpStream,
|
||||
// Output buffer. c.f. BeMessage::write why we are using BytesMut here.
|
||||
buf_out: BytesMut,
|
||||
init_done: bool,
|
||||
}
|
||||
|
||||
// In replication.rs a separate thread is reading keepalives from the
|
||||
// socket. When main one finishes, tell it to get down by shutdowning the
|
||||
// socket.
|
||||
impl Drop for PostgresBackend {
|
||||
fn drop(&mut self) {
|
||||
let _res = self.stream_out.shutdown(Shutdown::Both);
|
||||
}
|
||||
}
|
||||
|
||||
impl PostgresBackend {
|
||||
pub fn new(socket: TcpStream) -> Result<Self, std::io::Error> {
|
||||
let mut pb = PostgresBackend {
|
||||
stream_in: None,
|
||||
stream_out: socket,
|
||||
buf_out: BytesMut::with_capacity(10 * 1024),
|
||||
init_done: false,
|
||||
};
|
||||
// if socket cloning fails, report the error and bail out
|
||||
pb.stream_in = match pb.stream_out.try_clone() {
|
||||
Ok(read_sock) => Some(BufReader::new(read_sock)),
|
||||
Err(error) => {
|
||||
let errmsg = format!("{}", error);
|
||||
let _res = pb.write_message_noflush(&BeMessage::ErrorResponse(errmsg));
|
||||
return Err(error);
|
||||
}
|
||||
};
|
||||
Ok(pb)
|
||||
}
|
||||
|
||||
/// Get direct reference (into the Option) to the read stream.
|
||||
fn get_stream_in(&mut self) -> Result<&mut BufReader<TcpStream>> {
|
||||
match self.stream_in {
|
||||
Some(ref mut stream_in) => Ok(stream_in),
|
||||
None => bail!("stream_in was taken"),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn take_stream_in(&mut self) -> Option<BufReader<TcpStream>> {
|
||||
self.stream_in.take()
|
||||
}
|
||||
|
||||
/// Read full message or return None if connection is closed.
|
||||
pub fn read_message(&mut self) -> Result<Option<FeMessage>> {
|
||||
if !self.init_done {
|
||||
FeStartupMessage::read(self.get_stream_in()?)
|
||||
} else {
|
||||
FeMessage::read(self.get_stream_in()?)
|
||||
}
|
||||
}
|
||||
|
||||
/// Write message into internal output buffer.
|
||||
pub fn write_message_noflush(&mut self, message: &BeMessage) -> io::Result<&mut Self> {
|
||||
BeMessage::write(&mut self.buf_out, message)?;
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
/// Flush output buffer into the socket.
|
||||
pub fn flush(&mut self) -> io::Result<&mut Self> {
|
||||
self.stream_out.write_all(&self.buf_out)?;
|
||||
self.buf_out.clear();
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
/// Write message into internal buffer and flush it.
|
||||
pub fn write_message(&mut self, message: &BeMessage) -> io::Result<&mut Self> {
|
||||
self.write_message_noflush(message)?;
|
||||
self.flush()
|
||||
}
|
||||
|
||||
pub fn run(&mut self, handler: &mut impl Handler) -> Result<()> {
|
||||
let peer_addr = self.stream_out.peer_addr()?;
|
||||
info!("postgres backend to {:?} started", peer_addr);
|
||||
let mut unnamed_query_string = Bytes::new();
|
||||
|
||||
loop {
|
||||
let msg = self.read_message()?;
|
||||
trace!("got message {:?}", msg);
|
||||
match msg {
|
||||
Some(FeMessage::StartupMessage(m)) => {
|
||||
trace!("got startup message {:?}", m);
|
||||
|
||||
handler.startup(self, &m)?;
|
||||
|
||||
match m.kind {
|
||||
StartupRequestCode::NegotiateGss | StartupRequestCode::NegotiateSsl => {
|
||||
info!("SSL requested");
|
||||
self.write_message(&BeMessage::Negotiate)?;
|
||||
}
|
||||
StartupRequestCode::Normal => {
|
||||
self.write_message_noflush(&BeMessage::AuthenticationOk)?;
|
||||
// psycopg2 will not connect if client_encoding is not
|
||||
// specified by the server
|
||||
self.write_message_noflush(&BeMessage::ParameterStatus)?;
|
||||
self.write_message(&BeMessage::ReadyForQuery)?;
|
||||
self.init_done = true;
|
||||
}
|
||||
StartupRequestCode::Cancel => break,
|
||||
}
|
||||
}
|
||||
Some(FeMessage::Query(m)) => {
|
||||
trace!("got query {:?}", m.body);
|
||||
// xxx distinguish fatal and recoverable errors?
|
||||
if let Err(e) = handler.process_query(self, m.body) {
|
||||
let errmsg = format!("{}", e);
|
||||
self.write_message_noflush(&BeMessage::ErrorResponse(errmsg))?;
|
||||
}
|
||||
self.write_message(&BeMessage::ReadyForQuery)?;
|
||||
}
|
||||
Some(FeMessage::Parse(m)) => {
|
||||
unnamed_query_string = m.query_string;
|
||||
self.write_message(&BeMessage::ParseComplete)?;
|
||||
}
|
||||
Some(FeMessage::Describe(_)) => {
|
||||
self.write_message_noflush(&BeMessage::ParameterDescription)?
|
||||
.write_message(&BeMessage::NoData)?;
|
||||
}
|
||||
Some(FeMessage::Bind(_)) => {
|
||||
self.write_message(&BeMessage::BindComplete)?;
|
||||
}
|
||||
Some(FeMessage::Close(_)) => {
|
||||
self.write_message(&BeMessage::CloseComplete)?;
|
||||
}
|
||||
Some(FeMessage::Execute(_)) => {
|
||||
handler.process_query(self, unnamed_query_string.clone())?;
|
||||
}
|
||||
Some(FeMessage::Sync) => {
|
||||
self.write_message(&BeMessage::ReadyForQuery)?;
|
||||
}
|
||||
Some(FeMessage::Terminate) => {
|
||||
break;
|
||||
}
|
||||
None => {
|
||||
info!("connection closed");
|
||||
break;
|
||||
}
|
||||
x => {
|
||||
bail!("unexpected message type : {:?}", x);
|
||||
}
|
||||
}
|
||||
}
|
||||
info!("postgres backend to {:?} exited", peer_addr);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
639
zenith_utils/src/pq_proto.rs
Normal file
639
zenith_utils/src/pq_proto.rs
Normal file
@@ -0,0 +1,639 @@
|
||||
//! Postgres protocol messages serialization-deserialization. See
|
||||
//! https://www.postgresql.org/docs/devel/protocol-message-formats.html
|
||||
//! on message formats.
|
||||
|
||||
use anyhow::{anyhow, bail, Result};
|
||||
use byteorder::{BigEndian, ByteOrder};
|
||||
use byteorder::{ReadBytesExt, BE};
|
||||
use bytes::{Buf, BufMut, Bytes, BytesMut};
|
||||
// use postgres_ffi::xlog_utils::TimestampTz;
|
||||
use std::collections::HashMap;
|
||||
use std::io;
|
||||
use std::io::Read;
|
||||
use std::str;
|
||||
|
||||
/// 32-bit identifier; presumably a Postgres object id (OID) -- confirm against users.
pub type Oid = u32;

/// 64-bit identifier; presumably the Postgres system identifier -- confirm against users.
pub type SystemId = u64;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum FeMessage {
|
||||
StartupMessage(FeStartupMessage),
|
||||
Query(FeQueryMessage), // Simple query
|
||||
Parse(FeParseMessage), // Extended query protocol
|
||||
Describe(FeDescribeMessage),
|
||||
Bind(FeBindMessage),
|
||||
Execute(FeExecuteMessage),
|
||||
Close(FeCloseMessage),
|
||||
Sync,
|
||||
Terminate,
|
||||
CopyData(Bytes),
|
||||
CopyDone,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct FeStartupMessage {
|
||||
pub version: u32,
|
||||
pub kind: StartupRequestCode,
|
||||
// optional params arriving in startup packet
|
||||
pub params: HashMap<String, String>,
|
||||
}
|
||||
|
||||
/// What kind of request a startup packet carries, derived from the 32-bit
/// code that follows the packet length.
///
/// The enum has no payload, so it is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StartupRequestCode {
    /// Request to cancel a running query.
    Cancel,
    /// Request to negotiate SSL encryption.
    NegotiateSsl,
    /// Request to negotiate GSSAPI encryption.
    NegotiateGss,
    /// Any other code: a regular startup packet.
    Normal,
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct FeQueryMessage {
|
||||
pub body: Bytes,
|
||||
}
|
||||
|
||||
// We only support the simple case of Parse on unnamed prepared statement and
|
||||
// no params
|
||||
#[derive(Debug)]
|
||||
pub struct FeParseMessage {
|
||||
pub query_string: Bytes,
|
||||
}
|
||||
|
||||
/// Describe message of the extended query protocol.
///
/// Only the unnamed prepared statement or portal is supported.
#[derive(Debug)]
pub struct FeDescribeMessage {
    /// 'S' to describe a prepared statement; or 'P' to describe a portal.
    pub kind: u8,
}
|
||||
|
||||
/// Bind message of the extended query protocol.
///
/// Carries no data: only the unnamed prepared statement and portal are
/// supported.
#[derive(Debug)]
pub struct FeBindMessage {}
|
||||
|
||||
/// Execute message of the extended query protocol.
///
/// Only the unnamed prepared statement or portal is supported.
#[derive(Debug)]
pub struct FeExecuteMessage {
    /// Maximum number of rows to return.
    pub maxrows: i32,
}
|
||||
|
||||
/// Close message of the extended query protocol.
///
/// Carries no data: only the unnamed prepared statement and portal are
/// supported.
#[derive(Debug)]
pub struct FeCloseMessage {}
|
||||
|
||||
impl FeMessage {
|
||||
/// Read one message from the stream.
|
||||
pub fn read(stream: &mut impl Read) -> anyhow::Result<Option<FeMessage>> {
|
||||
// Each libpq message begins with a message type byte, followed by message length
|
||||
// If the client closes the connection, return None. But if the client closes the
|
||||
// connection in the middle of a message, we will return an error.
|
||||
let tag = match stream.read_u8() {
|
||||
Ok(b) => b,
|
||||
Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
|
||||
Err(e) => return Err(e.into()),
|
||||
};
|
||||
let len = stream.read_u32::<BE>()?;
|
||||
|
||||
// The message length includes itself, so it better be at least 4
|
||||
let bodylen = len.checked_sub(4).ok_or(io::Error::new(
|
||||
io::ErrorKind::InvalidInput,
|
||||
"invalid message length: parsing u32",
|
||||
))?;
|
||||
|
||||
// Read message body
|
||||
let mut body_buf: Vec<u8> = vec![0; bodylen as usize];
|
||||
stream.read_exact(&mut body_buf)?;
|
||||
|
||||
let body = Bytes::from(body_buf);
|
||||
|
||||
// Parse it
|
||||
match tag {
|
||||
b'Q' => Ok(Some(FeMessage::Query(FeQueryMessage { body }))),
|
||||
b'P' => Ok(Some(FeParseMessage::parse(body)?)),
|
||||
b'D' => Ok(Some(FeDescribeMessage::parse(body)?)),
|
||||
b'E' => Ok(Some(FeExecuteMessage::parse(body)?)),
|
||||
b'B' => Ok(Some(FeBindMessage::parse(body)?)),
|
||||
b'C' => Ok(Some(FeCloseMessage::parse(body)?)),
|
||||
b'S' => Ok(Some(FeMessage::Sync)),
|
||||
b'X' => Ok(Some(FeMessage::Terminate)),
|
||||
b'd' => Ok(Some(FeMessage::CopyData(body))),
|
||||
b'c' => Ok(Some(FeMessage::CopyDone)),
|
||||
tag => Err(anyhow!("unknown message tag: {},'{:?}'", tag, body)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl FeStartupMessage {
|
||||
/// Read startup message from the stream.
|
||||
pub fn read(stream: &mut impl std::io::Read) -> anyhow::Result<Option<FeMessage>> {
|
||||
const MAX_STARTUP_PACKET_LENGTH: usize = 10000;
|
||||
const CANCEL_REQUEST_CODE: u32 = (1234 << 16) | 5678;
|
||||
const NEGOTIATE_SSL_CODE: u32 = (1234 << 16) | 5679;
|
||||
const NEGOTIATE_GSS_CODE: u32 = (1234 << 16) | 5680;
|
||||
|
||||
// Read length. If the connection is closed before reading anything (or before
|
||||
// reading 4 bytes, to be precise), return None to indicate that the connection
|
||||
// was closed. This matches the PostgreSQL server's behavior, which avoids noise
|
||||
// in the log if the client opens connection but closes it immediately.
|
||||
let len = match stream.read_u32::<BE>() {
|
||||
Ok(len) => len as usize,
|
||||
Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
|
||||
Err(e) => return Err(e.into()),
|
||||
};
|
||||
|
||||
if len < 4 || len > MAX_STARTUP_PACKET_LENGTH {
|
||||
bail!("invalid message length");
|
||||
}
|
||||
|
||||
let version = stream.read_u32::<BE>()?;
|
||||
let kind = match version {
|
||||
CANCEL_REQUEST_CODE => StartupRequestCode::Cancel,
|
||||
NEGOTIATE_SSL_CODE => StartupRequestCode::NegotiateSsl,
|
||||
NEGOTIATE_GSS_CODE => StartupRequestCode::NegotiateGss,
|
||||
_ => StartupRequestCode::Normal,
|
||||
};
|
||||
|
||||
// the rest of startup packet are params
|
||||
let params_len = len - 8;
|
||||
let mut params_bytes = vec![0u8; params_len];
|
||||
stream.read_exact(params_bytes.as_mut())?;
|
||||
|
||||
// Then null-terminated (String) pairs of param name / param value go.
|
||||
let params_str = str::from_utf8(¶ms_bytes).unwrap();
|
||||
let params = params_str.split('\0');
|
||||
let mut params_hash: HashMap<String, String> = HashMap::new();
|
||||
for pair in params.collect::<Vec<_>>().chunks_exact(2) {
|
||||
let name = pair[0];
|
||||
let value = pair[1];
|
||||
if name == "options" {
|
||||
// deprecated way of passing params as cmd line args
|
||||
for cmdopt in value.split(' ') {
|
||||
let nameval: Vec<&str> = cmdopt.split('=').collect();
|
||||
if nameval.len() == 2 {
|
||||
params_hash.insert(nameval[0].to_string(), nameval[1].to_string());
|
||||
}
|
||||
}
|
||||
} else {
|
||||
params_hash.insert(name.to_string(), value.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Some(FeMessage::StartupMessage(FeStartupMessage {
|
||||
version,
|
||||
kind,
|
||||
params: params_hash,
|
||||
})))
|
||||
}
|
||||
}
|
||||
|
||||
impl FeParseMessage {
|
||||
pub fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
|
||||
let _pstmt_name = read_null_terminated(&mut buf)?;
|
||||
let query_string = read_null_terminated(&mut buf)?;
|
||||
let nparams = buf.get_i16();
|
||||
|
||||
// FIXME: the rust-postgres driver uses a named prepared statement
|
||||
// for copy_out(). We're not prepared to handle that correctly. For
|
||||
// now, just ignore the statement name, assuming that the client never
|
||||
// uses more than one prepared statement at a time.
|
||||
/*
|
||||
if !pstmt_name.is_empty() {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidInput,
|
||||
"named prepared statements not implemented in Parse",
|
||||
));
|
||||
}
|
||||
*/
|
||||
|
||||
if nparams != 0 {
|
||||
bail!("query params not implemented");
|
||||
}
|
||||
|
||||
Ok(FeMessage::Parse(FeParseMessage { query_string }))
|
||||
}
|
||||
}
|
||||
|
||||
impl FeDescribeMessage {
|
||||
pub fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
|
||||
let kind = buf.get_u8();
|
||||
let _pstmt_name = read_null_terminated(&mut buf)?;
|
||||
|
||||
// FIXME: see FeParseMessage::parse
|
||||
/*
|
||||
if !pstmt_name.is_empty() {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidInput,
|
||||
"named prepared statements not implemented in Describe",
|
||||
));
|
||||
}
|
||||
*/
|
||||
|
||||
if kind != b'S' {
|
||||
bail!("only prepared statmement Describe is implemented");
|
||||
}
|
||||
|
||||
Ok(FeMessage::Describe(FeDescribeMessage { kind }))
|
||||
}
|
||||
}
|
||||
|
||||
impl FeExecuteMessage {
|
||||
pub fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
|
||||
let portal_name = read_null_terminated(&mut buf)?;
|
||||
let maxrows = buf.get_i32();
|
||||
|
||||
if !portal_name.is_empty() {
|
||||
bail!("named portals not implemented");
|
||||
}
|
||||
|
||||
if maxrows != 0 {
|
||||
bail!("row limit in Execute message not supported");
|
||||
}
|
||||
|
||||
Ok(FeMessage::Execute(FeExecuteMessage { maxrows }))
|
||||
}
|
||||
}
|
||||
|
||||
impl FeBindMessage {
|
||||
pub fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
|
||||
let portal_name = read_null_terminated(&mut buf)?;
|
||||
let _pstmt_name = read_null_terminated(&mut buf)?;
|
||||
|
||||
if !portal_name.is_empty() {
|
||||
bail!("named portals not implemented");
|
||||
}
|
||||
|
||||
// FIXME: see FeParseMessage::parse
|
||||
/*
|
||||
if !pstmt_name.is_empty() {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidInput,
|
||||
"named prepared statements not implemented",
|
||||
));
|
||||
}
|
||||
*/
|
||||
|
||||
Ok(FeMessage::Bind(FeBindMessage {}))
|
||||
}
|
||||
}
|
||||
|
||||
impl FeCloseMessage {
|
||||
pub fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
|
||||
let _kind = buf.get_u8();
|
||||
let _pstmt_or_portal_name = read_null_terminated(&mut buf)?;
|
||||
|
||||
// FIXME: we do nothing with Close
|
||||
|
||||
Ok(FeMessage::Close(FeCloseMessage {}))
|
||||
}
|
||||
}
|
||||
|
||||
fn read_null_terminated(buf: &mut Bytes) -> anyhow::Result<Bytes> {
|
||||
let mut result = BytesMut::new();
|
||||
|
||||
loop {
|
||||
if !buf.has_remaining() {
|
||||
bail!("no null-terminator in string");
|
||||
}
|
||||
|
||||
let byte = buf.get_u8();
|
||||
|
||||
if byte == 0 {
|
||||
break;
|
||||
}
|
||||
result.put_u8(byte);
|
||||
}
|
||||
Ok(result.freeze())
|
||||
}
|
||||
|
||||
// Backend
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum BeMessage<'a> {
|
||||
AuthenticationOk,
|
||||
BindComplete,
|
||||
CommandComplete(&'a [u8]),
|
||||
ControlFile,
|
||||
CopyData(&'a [u8]),
|
||||
CopyDone,
|
||||
CopyInResponse,
|
||||
CopyOutResponse,
|
||||
CopyBothResponse,
|
||||
CloseComplete,
|
||||
// None means column is NULL
|
||||
DataRow(&'a [Option<&'a [u8]>]),
|
||||
ErrorResponse(String),
|
||||
// see https://www.postgresql.org/docs/devel/protocol-flow.html#id-1.10.5.7.11
|
||||
Negotiate,
|
||||
NoData,
|
||||
ParameterDescription,
|
||||
ParameterStatus,
|
||||
ParseComplete,
|
||||
ReadyForQuery,
|
||||
RowDescription(&'a [RowDescriptor<'a>]),
|
||||
XLogData(XLogDataBody<'a>),
|
||||
}
|
||||
|
||||
// One row desciption in RowDescription packet.
|
||||
#[derive(Debug)]
|
||||
pub struct RowDescriptor<'a> {
|
||||
pub name: &'a [u8],
|
||||
pub tableoid: Oid,
|
||||
pub attnum: i16,
|
||||
pub typoid: Oid,
|
||||
pub typlen: i16,
|
||||
pub typmod: i32,
|
||||
pub formatcode: i16,
|
||||
}
|
||||
|
||||
impl Default for RowDescriptor<'_> {
|
||||
fn default() -> RowDescriptor<'static> {
|
||||
RowDescriptor {
|
||||
name: b"",
|
||||
tableoid: 0,
|
||||
attnum: 0,
|
||||
typoid: 0,
|
||||
typlen: 0,
|
||||
typmod: 0,
|
||||
formatcode: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Contents of a `BeMessage::XLogData` message.
///
/// When serialized, the three integers are written in declaration order as
/// big-endian 64-bit values, followed by `data` verbatim.
#[derive(Debug)]
pub struct XLogDataBody<'a> {
    /// First header field; presumably the WAL position where `data` starts.
    pub wal_start: u64,
    /// Second header field; presumably the sender's current end of WAL.
    pub wal_end: u64,
    /// Third header field; presumably the sender's clock at send time.
    pub timestamp: u64,
    /// WAL bytes carried by this message.
    pub data: &'a [u8],
}
|
||||
|
||||
pub static HELLO_WORLD_ROW: BeMessage = BeMessage::DataRow(&[Some(b"hello world")]);
|
||||
pub const TEXT_OID: Oid = 25;
|
||||
// single text column
|
||||
pub static SINGLE_COL_ROWDESC: BeMessage = BeMessage::RowDescription(&[RowDescriptor {
|
||||
name: b"data",
|
||||
tableoid: 0,
|
||||
attnum: 0,
|
||||
typoid: TEXT_OID,
|
||||
typlen: -1,
|
||||
typmod: 0,
|
||||
formatcode: 0,
|
||||
}]);
|
||||
|
||||
// Checked usize -> integer conversion, from rust-postgres. Only the i32
// implementation is generated below.
trait FromUsize: Sized {
    /// Converts `x`, failing with `InvalidInput` when it does not fit.
    fn from_usize(x: usize) -> Result<Self, io::Error>;
}

/// Implements `FromUsize` for the given integer type.
macro_rules! from_usize {
    ($t:ty) => {
        impl FromUsize for $t {
            #[inline]
            fn from_usize(x: usize) -> io::Result<$t> {
                // The conversion fails exactly when `x` exceeds the target's
                // maximum, which is the only way a usize can be out of range.
                <$t as ::std::convert::TryFrom<usize>>::try_from(x).map_err(|_| {
                    io::Error::new(io::ErrorKind::InvalidInput, "value too large to transmit")
                })
            }
        }
    };
}

from_usize!(i32);
|
||||
|
||||
/// Call f() to write body of the message and prepend it with 4-byte len as
|
||||
/// prescribed by the protocol.
|
||||
fn write_body<F>(buf: &mut BytesMut, f: F) -> io::Result<()>
|
||||
where
|
||||
F: FnOnce(&mut BytesMut) -> io::Result<()>,
|
||||
{
|
||||
let base = buf.len();
|
||||
buf.extend_from_slice(&[0; 4]);
|
||||
|
||||
f(buf)?;
|
||||
|
||||
let size = i32::from_usize(buf.len() - base)?;
|
||||
BigEndian::write_i32(&mut buf[base..], size);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Safe write of s into buf as cstring (String in the protocol).
|
||||
fn write_cstr(s: &[u8], buf: &mut BytesMut) -> Result<(), io::Error> {
|
||||
if s.contains(&0) {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidInput,
|
||||
"string contains embedded null",
|
||||
));
|
||||
}
|
||||
buf.put_slice(s);
|
||||
buf.put_u8(0);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
impl<'a> BeMessage<'a> {
|
||||
/// Write message to the given buf.
|
||||
// Unlike the reading side, we use BytesMut
|
||||
// here as msg len preceeds its body and it is handy to write it down first
|
||||
// and then fill the length. With Write we would have to either calc it
|
||||
// manually or have one more buffer.
|
||||
pub fn write(buf: &mut BytesMut, message: &BeMessage) -> io::Result<()> {
|
||||
match message {
|
||||
BeMessage::AuthenticationOk => {
|
||||
buf.put_u8(b'R');
|
||||
write_body(buf, |buf| {
|
||||
buf.put_i32(0);
|
||||
Ok::<_, io::Error>(())
|
||||
})
|
||||
.unwrap(); // write into BytesMut can't fail
|
||||
}
|
||||
|
||||
BeMessage::BindComplete => {
|
||||
buf.put_u8(b'2');
|
||||
write_body(buf, |_| Ok::<(), io::Error>(())).unwrap();
|
||||
}
|
||||
|
||||
BeMessage::CloseComplete => {
|
||||
buf.put_u8(b'3');
|
||||
write_body(buf, |_| Ok::<(), io::Error>(())).unwrap();
|
||||
}
|
||||
|
||||
BeMessage::CommandComplete(cmd) => {
|
||||
buf.put_u8(b'C');
|
||||
write_body(buf, |buf| {
|
||||
write_cstr(cmd, buf)?;
|
||||
Ok::<_, io::Error>(())
|
||||
})?;
|
||||
}
|
||||
|
||||
BeMessage::ControlFile => {
|
||||
// TODO pass checkpoint and xid info in this message
|
||||
BeMessage::write(buf, &BeMessage::DataRow(&[Some(b"hello pg_control")]))?;
|
||||
}
|
||||
|
||||
BeMessage::CopyData(data) => {
|
||||
buf.put_u8(b'd');
|
||||
write_body(buf, |buf| {
|
||||
buf.put_slice(data);
|
||||
Ok::<_, io::Error>(())
|
||||
})
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
BeMessage::CopyDone => {
|
||||
buf.put_u8(b'c');
|
||||
write_body(buf, |_| Ok::<(), io::Error>(())).unwrap();
|
||||
}
|
||||
|
||||
BeMessage::CopyInResponse => {
|
||||
buf.put_u8(b'G');
|
||||
write_body(buf, |buf| {
|
||||
buf.put_u8(1); /* copy_is_binary */
|
||||
buf.put_i16(0); /* numAttributes */
|
||||
Ok::<_, io::Error>(())
|
||||
})
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
BeMessage::CopyOutResponse => {
|
||||
buf.put_u8(b'H');
|
||||
write_body(buf, |buf| {
|
||||
buf.put_u8(0); /* copy_is_binary */
|
||||
buf.put_i16(0); /* numAttributes */
|
||||
Ok::<_, io::Error>(())
|
||||
})
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
BeMessage::CopyBothResponse => {
|
||||
buf.put_u8(b'W');
|
||||
write_body(buf, |buf| {
|
||||
// doesn't matter, used only for replication
|
||||
buf.put_u8(0); /* copy_is_binary */
|
||||
buf.put_i16(0); /* numAttributes */
|
||||
Ok::<_, io::Error>(())
|
||||
})
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
BeMessage::DataRow(vals) => {
|
||||
buf.put_u8(b'D');
|
||||
write_body(buf, |buf| {
|
||||
buf.put_u16(vals.len() as u16); // num of cols
|
||||
for val_opt in vals.iter() {
|
||||
if let Some(val) = val_opt {
|
||||
buf.put_u32(val.len() as u32);
|
||||
buf.put_slice(val);
|
||||
} else {
|
||||
buf.put_i32(-1);
|
||||
}
|
||||
}
|
||||
Ok::<_, io::Error>(())
|
||||
})
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
// ErrorResponse is a zero-terminated array of zero-terminated fields.
|
||||
// First byte of each field represents type of this field. Set just enough fields
|
||||
// to satisfy rust-postgres client: 'S' -- severity, 'C' -- error, 'M' -- error
|
||||
// message text.
|
||||
BeMessage::ErrorResponse(error_msg) => {
|
||||
// For all the errors set Severity to Error and error code to
|
||||
// 'internal error'.
|
||||
|
||||
// 'E' signalizes ErrorResponse messages
|
||||
buf.put_u8(b'E');
|
||||
write_body(buf, |buf| {
|
||||
buf.put_u8(b'S'); // severity
|
||||
write_cstr(&Bytes::from("ERROR"), buf)?;
|
||||
|
||||
buf.put_u8(b'C'); // SQLSTATE error code
|
||||
write_cstr(&Bytes::from("CXX000"), buf)?;
|
||||
|
||||
buf.put_u8(b'M'); // the message
|
||||
write_cstr(error_msg.as_bytes(), buf)?;
|
||||
|
||||
buf.put_u8(0); // terminator
|
||||
Ok::<_, io::Error>(())
|
||||
})
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
BeMessage::NoData => {
|
||||
buf.put_u8(b'n');
|
||||
write_body(buf, |_| Ok::<(), io::Error>(())).unwrap();
|
||||
}
|
||||
|
||||
BeMessage::Negotiate => {
|
||||
buf.put_u8(b'N');
|
||||
}
|
||||
|
||||
BeMessage::ParameterStatus => {
|
||||
buf.put_u8(b'S');
|
||||
// parameter names and values are specified by null terminated strings
|
||||
const PARAM_NAME_VALUE: &[u8] = b"client_encoding\0UTF8\0";
|
||||
write_body(buf, |buf| {
|
||||
buf.put_slice(PARAM_NAME_VALUE);
|
||||
Ok::<_, io::Error>(())
|
||||
})
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
BeMessage::ParameterDescription => {
|
||||
buf.put_u8(b't');
|
||||
write_body(buf, |buf| {
|
||||
// we don't support params, so always 0
|
||||
buf.put_i16(0);
|
||||
Ok::<_, io::Error>(())
|
||||
})
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
BeMessage::ParseComplete => {
|
||||
buf.put_u8(b'1');
|
||||
write_body(buf, |_| Ok::<(), io::Error>(())).unwrap();
|
||||
}
|
||||
|
||||
BeMessage::ReadyForQuery => {
|
||||
buf.put_u8(b'Z');
|
||||
write_body(buf, |buf| {
|
||||
buf.put_u8(b'I');
|
||||
Ok::<_, io::Error>(())
|
||||
})
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
BeMessage::RowDescription(rows) => {
|
||||
buf.put_u8(b'T');
|
||||
write_body(buf, |buf| {
|
||||
buf.put_i16(rows.len() as i16); // # of fields
|
||||
for row in rows.iter() {
|
||||
write_cstr(row.name, buf)?;
|
||||
buf.put_i32(0); /* table oid */
|
||||
buf.put_i16(0); /* attnum */
|
||||
buf.put_u32(row.typoid);
|
||||
buf.put_i16(row.typlen);
|
||||
buf.put_i32(-1); /* typmod */
|
||||
buf.put_i16(0); /* format code */
|
||||
}
|
||||
Ok::<_, io::Error>(())
|
||||
})?;
|
||||
}
|
||||
|
||||
BeMessage::XLogData(body) => {
|
||||
buf.put_u8(b'd');
|
||||
write_body(buf, |buf| {
|
||||
buf.put_u8(b'w');
|
||||
buf.put_u64(body.wal_start);
|
||||
buf.put_u64(body.wal_end);
|
||||
buf.put_u64(body.timestamp);
|
||||
buf.put_slice(body.data);
|
||||
Ok::<_, io::Error>(())
|
||||
})
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user