From b2f51026aa00b147a1204263f02bfcd8112edeec Mon Sep 17 00:00:00 2001 From: Arseny Sher Date: Mon, 31 May 2021 09:26:26 +0300 Subject: [PATCH] Consolidate PG proto parsing-deparsing and backend code. Now postgres_backend communicates with the client, passing queries to the provided handler; we have two currently, for wal_acceptor and pageserver. Now BytesMut is again used for writing data to avoid manual message length calculation. ref #118 --- Cargo.lock | 3 + pageserver/src/page_service.rs | 1014 ++++++-------------------- walkeeper/src/replication.rs | 110 +-- walkeeper/src/send_wal.rs | 197 ++--- walkeeper/src/wal_service.rs | 8 +- zenith_utils/Cargo.toml | 4 + zenith_utils/src/lib.rs | 3 + zenith_utils/src/postgres_backend.rs | 181 +++++ zenith_utils/src/pq_proto.rs | 639 ++++++++++++++++ 9 files changed, 1188 insertions(+), 971 deletions(-) create mode 100644 zenith_utils/src/postgres_backend.rs create mode 100644 zenith_utils/src/pq_proto.rs diff --git a/Cargo.lock b/Cargo.lock index d3cf801ede..d06573544a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2478,9 +2478,12 @@ dependencies = [ name = "zenith_utils" version = "0.1.0" dependencies = [ + "anyhow", "bincode", + "byteorder", "bytes", "hex-literal", + "log", "serde", "thiserror", "workspace_hack", diff --git a/pageserver/src/page_service.rs b/pageserver/src/page_service.rs index 4e6ee91399..cc801ef7a8 100644 --- a/pageserver/src/page_service.rs +++ b/pageserver/src/page_service.rs @@ -11,15 +11,17 @@ // use anyhow::{anyhow, bail}; -use byteorder::{ReadBytesExt, WriteBytesExt, BE}; use bytes::{Buf, BufMut, Bytes, BytesMut}; use log::*; use regex::Regex; -use std::io; -use std::io::{BufReader, BufWriter, Read, Write}; -use std::net::{TcpListener, TcpStream}; +use std::io::Write; +use std::net::TcpListener; use std::str::FromStr; use std::thread; +use std::{io, net::TcpStream}; +use zenith_utils::postgres_backend; +use zenith_utils::postgres_backend::PostgresBackend; +use zenith_utils::pq_proto::{BeMessage, 
FeMessage, HELLO_WORLD_ROW, SINGLE_COL_ROWDESC}; use zenith_utils::{bin_ser::BeSer, lsn::Lsn}; use crate::basebackup; @@ -31,40 +33,6 @@ use crate::walreceiver; use crate::PageServerConf; use crate::ZTimelineId; -#[derive(Debug)] -enum FeMessage { - StartupMessage(FeStartupMessage), - Query(FeQueryMessage), // Simple query - Parse(FeParseMessage), // Extended query protocol - Describe(FeDescribeMessage), - Bind(FeBindMessage), - Execute(FeExecuteMessage), - Close(FeCloseMessage), - Sync, - Terminate, - CopyData(Bytes), - CopyDone, -} - -#[derive(Debug)] -enum BeMessage { - AuthenticationOk, - ParameterStatus, - ReadyForQuery, - RowDescription, - ParseComplete, - ParameterDescription, - NoData, - BindComplete, - CloseComplete, - DataRow(Bytes), - CommandComplete, - ControlFile, - CopyData(Bytes), - ErrorResponse(String), - CopyInResponse, -} - // Wrapped in libpq CopyData enum PagestreamFeMessage { Exists(PagestreamRequest), @@ -79,8 +47,6 @@ enum PagestreamBeMessage { Read(PagestreamReadResponse), } -static HELLO_WORLD_ROW: BeMessage = BeMessage::DataRow(Bytes::from_static(b"hello world")); - #[derive(Debug)] struct PagestreamRequest { spcnode: u32, @@ -104,264 +70,6 @@ struct PagestreamReadResponse { page: Bytes, } -#[derive(Debug)] -struct FeStartupMessage { - version: u32, - kind: StartupRequestCode, -} - -#[derive(Debug)] -enum StartupRequestCode { - Cancel, - NegotiateSsl, - NegotiateGss, - Normal, -} - -impl FeStartupMessage { - pub fn read(stream: &mut dyn std::io::Read) -> anyhow::Result> { - const MAX_STARTUP_PACKET_LENGTH: u32 = 10000; - const CANCEL_REQUEST_CODE: u32 = (1234 << 16) | 5678; - const NEGOTIATE_SSL_CODE: u32 = (1234 << 16) | 5679; - const NEGOTIATE_GSS_CODE: u32 = (1234 << 16) | 5680; - - // Read length. If the connection is closed before reading anything (or before - // reading 4 bytes, to be precise), return None to indicate that the connection - // was closed. 
This matches the PostgreSQL server's behavior, which avoids noise - // in the log if the client opens connection but closes it immediately. - let len = match stream.read_u32::() { - Ok(len) => len, - Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None), - Err(e) => return Err(e.into()), - }; - - if len < 4 || len > MAX_STARTUP_PACKET_LENGTH { - bail!("invalid message length"); - } - let bodylen = len - 4; - - // Read the rest of the startup packet - let mut body_buf: Vec = vec![0; bodylen as usize]; - stream.read_exact(&mut body_buf)?; - let mut body = Bytes::from(body_buf); - - // Parse the first field, which indicates what kind of a packet it is - let version = body.get_u32(); - let kind = match version { - CANCEL_REQUEST_CODE => StartupRequestCode::Cancel, - NEGOTIATE_SSL_CODE => StartupRequestCode::NegotiateSsl, - NEGOTIATE_GSS_CODE => StartupRequestCode::NegotiateGss, - _ => StartupRequestCode::Normal, - }; - - // Ignore the rest of the packet - - Ok(Some(FeMessage::StartupMessage(FeStartupMessage { - version, - kind, - }))) - } -} - -#[derive(Debug)] -struct Buffer { - bytes: Bytes, - idx: usize, -} - -#[derive(Debug)] -struct FeQueryMessage { - body: Bytes, -} - -// We only support the simple case of Parse on unnamed prepared statement and -// no params -#[derive(Debug)] -struct FeParseMessage { - query_string: Bytes, -} - -fn read_null_terminated(buf: &mut Bytes) -> anyhow::Result { - let mut result = BytesMut::new(); - - loop { - if !buf.has_remaining() { - bail!("no null-terminator in string"); - } - - let byte = buf.get_u8(); - - if byte == 0 { - break; - } - result.put_u8(byte); - } - Ok(result.freeze()) -} - -impl FeParseMessage { - pub fn parse(mut buf: Bytes) -> anyhow::Result { - let _pstmt_name = read_null_terminated(&mut buf)?; - let query_string = read_null_terminated(&mut buf)?; - let nparams = buf.get_i16(); - - // FIXME: the rust-postgres driver uses a named prepared statement - // for copy_out(). 
We're not prepared to handle that correctly. For - // now, just ignore the statement name, assuming that the client never - // uses more than one prepared statement at a time. - /* - if !pstmt_name.is_empty() { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - "named prepared statements not implemented in Parse", - )); - } - */ - - if nparams != 0 { - bail!("query params not implemented"); - } - - Ok(FeMessage::Parse(FeParseMessage { query_string })) - } -} - -#[derive(Debug)] -struct FeDescribeMessage { - kind: u8, // 'S' to describe a prepared statement; or 'P' to describe a portal. - // we only support unnamed prepared stmt or portal -} - -impl FeDescribeMessage { - pub fn parse(mut buf: Bytes) -> anyhow::Result { - let kind = buf.get_u8(); - let _pstmt_name = read_null_terminated(&mut buf)?; - - // FIXME: see FeParseMessage::parse - /* - if !pstmt_name.is_empty() { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - "named prepared statements not implemented in Describe", - )); - } - */ - - if kind != b'S' { - bail!("only prepared statmement Describe is implemented"); - } - - Ok(FeMessage::Describe(FeDescribeMessage { kind })) - } -} - -// we only support unnamed prepared stmt or portal -#[derive(Debug)] -struct FeExecuteMessage { - /// max # of rows - maxrows: i32, -} - -impl FeExecuteMessage { - pub fn parse(mut buf: Bytes) -> anyhow::Result { - let portal_name = read_null_terminated(&mut buf)?; - let maxrows = buf.get_i32(); - - if !portal_name.is_empty() { - bail!("named portals not implemented"); - } - - if maxrows != 0 { - bail!("row limit in Execute message not supported"); - } - - Ok(FeMessage::Execute(FeExecuteMessage { maxrows })) - } -} - -// we only support unnamed prepared stmt and portal -#[derive(Debug)] -struct FeBindMessage {} - -impl FeBindMessage { - pub fn parse(mut buf: Bytes) -> anyhow::Result { - let portal_name = read_null_terminated(&mut buf)?; - let _pstmt_name = read_null_terminated(&mut buf)?; - - if 
!portal_name.is_empty() { - bail!("named portals not implemented"); - } - - // FIXME: see FeParseMessage::parse - /* - if !pstmt_name.is_empty() { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - "named prepared statements not implemented", - )); - } - */ - - Ok(FeMessage::Bind(FeBindMessage {})) - } -} - -// we only support unnamed prepared stmt and portal -#[derive(Debug)] -struct FeCloseMessage {} - -impl FeCloseMessage { - pub fn parse(mut buf: Bytes) -> anyhow::Result { - let _kind = buf.get_u8(); - let _pstmt_or_portal_name = read_null_terminated(&mut buf)?; - - // FIXME: we do nothing with Close - - Ok(FeMessage::Close(FeCloseMessage {})) - } -} - -impl FeMessage { - pub fn read(stream: &mut dyn Read) -> anyhow::Result> { - // Each libpq message begins with a message type byte, followed by message length - // If the client closes the connection, return None. But if the client closes the - // connection in the middle of a message, we will return an error. - let tag = match stream.read_u8() { - Ok(b) => b, - Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None), - Err(e) => return Err(e.into()), - }; - let len = stream.read_u32::()?; - - // The message length includes itself, so it better be at least 4 - if len < 4 { - bail!("invalid message length: parsing u32"); - } - let bodylen = len - 4; - - // Read message body - let mut body_buf: Vec = vec![0; bodylen as usize]; - stream.read_exact(&mut body_buf)?; - - let body = Bytes::from(body_buf); - - // Parse it - match tag { - b'Q' => Ok(Some(FeMessage::Query(FeQueryMessage { body }))), - b'P' => Ok(Some(FeParseMessage::parse(body)?)), - b'D' => Ok(Some(FeDescribeMessage::parse(body)?)), - b'E' => Ok(Some(FeExecuteMessage::parse(body)?)), - b'B' => Ok(Some(FeBindMessage::parse(body)?)), - b'C' => Ok(Some(FeCloseMessage::parse(body)?)), - b'S' => Ok(Some(FeMessage::Sync)), - b'X' => Ok(Some(FeMessage::Terminate)), - b'd' => Ok(Some(FeMessage::CopyData(body))), - b'c' => 
Ok(Some(FeMessage::CopyDone)), - tag => Err(anyhow!("unknown message tag: {},'{:?}'", tag, body)), - } - } -} - impl PagestreamFeMessage { fn parse(mut body: Bytes) -> anyhow::Result { // TODO these gets can fail @@ -432,473 +140,44 @@ pub fn thread_main(conf: &'static PageServerConf, listener: TcpListener) -> anyh let (socket, peer_addr) = listener.accept()?; debug!("accepted connection from {}", peer_addr); socket.set_nodelay(true).unwrap(); - let mut conn_handler = Connection::new(conf, socket); thread::spawn(move || { - if let Err(err) = conn_handler.run() { + if let Err(err) = page_service_conn_main(conf, socket) { error!("error: {}", err); } }); } } +fn page_service_conn_main(conf: &'static PageServerConf, socket: TcpStream) -> anyhow::Result<()> { + let mut conn_handler = PageServerHandler::new(conf); + let mut pgbackend = PostgresBackend::new(socket)?; + pgbackend.run(&mut conn_handler) +} + #[derive(Debug)] -struct Connection { - stream_in: BufReader, - stream: BufWriter, - init_done: bool, +struct PageServerHandler { conf: &'static PageServerConf, } -impl Connection { - pub fn new(conf: &'static PageServerConf, socket: TcpStream) -> Connection { - Connection { - stream_in: BufReader::new(socket.try_clone().unwrap()), - stream: BufWriter::new(socket), - init_done: false, - conf, - } +impl PageServerHandler { + pub fn new(conf: &'static PageServerConf) -> Self { + PageServerHandler { conf } } - // - // Read full message or return None if connection is closed - // - fn read_message(&mut self) -> anyhow::Result> { - if !self.init_done { - FeStartupMessage::read(&mut self.stream_in) - } else { - FeMessage::read(&mut self.stream_in) - } - } - - fn write_message_noflush(&mut self, message: &BeMessage) -> io::Result<()> { - match message { - BeMessage::AuthenticationOk => { - self.stream.write_u8(b'R')?; - self.stream.write_i32::(4 + 4)?; - self.stream.write_i32::(0)?; - } - - BeMessage::ParameterStatus => { - self.stream.write_u8(b'S')?; - // parameter names and 
values are specified by null terminated strings - const PARAM_NAME_VALUE: &[u8] = b"client_encoding\0UTF8\0"; - // length of this i32 + rest of data in message - self.stream - .write_i32::(4 + PARAM_NAME_VALUE.len() as i32)?; - self.stream.write_all(PARAM_NAME_VALUE)?; - } - - BeMessage::ReadyForQuery => { - self.stream.write_u8(b'Z')?; - self.stream.write_i32::(4 + 1)?; - self.stream.write_u8(b'I')?; - } - - BeMessage::ParseComplete => { - self.stream.write_u8(b'1')?; - self.stream.write_i32::(4)?; - } - - BeMessage::BindComplete => { - self.stream.write_u8(b'2')?; - self.stream.write_i32::(4)?; - } - - BeMessage::CloseComplete => { - self.stream.write_u8(b'3')?; - self.stream.write_i32::(4)?; - } - - BeMessage::NoData => { - self.stream.write_u8(b'n')?; - self.stream.write_i32::(4)?; - } - - BeMessage::ParameterDescription => { - self.stream.write_u8(b't')?; - self.stream.write_i32::(6)?; - // we don't support params, so always 0 - self.stream.write_i16::(0)?; - } - - BeMessage::RowDescription => { - // XXX - let b = Bytes::from("data\0"); - - self.stream.write_u8(b'T')?; - self.stream - .write_i32::(4 + 2 + b.len() as i32 + 3 * (4 + 2))?; - - self.stream.write_i16::(1)?; - self.stream.write_all(&b)?; - self.stream.write_i32::(0)?; /* table oid */ - self.stream.write_i16::(0)?; /* attnum */ - self.stream.write_i32::(25)?; /* TEXTOID */ - self.stream.write_i16::(-1)?; /* typlen */ - self.stream.write_i32::(0)?; /* typmod */ - self.stream.write_i16::(0)?; /* format code */ - } - - // XXX: accept some text data - BeMessage::DataRow(b) => { - self.stream.write_u8(b'D')?; - self.stream.write_i32::(4 + 2 + 4 + b.len() as i32)?; - - self.stream.write_i16::(1)?; - self.stream.write_i32::(b.len() as i32)?; - self.stream.write_all(&b)?; - } - - BeMessage::ControlFile => { - // TODO pass checkpoint and xid info in this message - let b = Bytes::from("hello pg_control"); - - self.stream.write_u8(b'D')?; - self.stream.write_i32::(4 + 2 + 4 + b.len() as i32)?; - - 
self.stream.write_i16::(1)?; - self.stream.write_i32::(b.len() as i32)?; - self.stream.write_all(&b)?; - } - - BeMessage::CommandComplete => { - let b = Bytes::from("SELECT 1\0"); - - self.stream.write_u8(b'C')?; - self.stream.write_i32::(4 + b.len() as i32)?; - self.stream.write_all(&b)?; - } - - BeMessage::CopyData(data) => { - self.stream.write_u8(b'd')?; - self.stream.write_u32::(4 + data.len() as u32)?; - self.stream.write_all(&data)?; - } - - // ErrorResponse is a zero-terminated array of zero-terminated fields. - // First byte of each field represents type of this field. Set just enough fields - // to satisfy rust-postgres client: 'S' -- severity, 'C' -- error, 'M' -- error - // message text. - BeMessage::ErrorResponse(error_msg) => { - // For all the errors set Severity to Error and error code to - // 'internal error'. - let severity = Bytes::from("SERROR\0"); - let code = Bytes::from("CXX000\0"); - - // 'E' signalizes ErrorResponse messages - self.stream.write_u8(b'E')?; - self.stream.write_u32::( - 4 + severity.len() as u32 - + code.len() as u32 - + (1 + error_msg.len() as u32 + 1) - + 1, - )?; - - // Send severity and code fields - self.stream.write_all(&severity)?; - self.stream.write_all(&code)?; - - // Send error message field - self.stream.write_u8(b'M')?; - self.stream.write_all(error_msg.as_bytes())?; - self.stream.write_u8(0)?; - - // Terminate fields - self.stream.write_u8(0)?; - } - - BeMessage::CopyInResponse => { - self.stream.write_u8(b'G')?; - self.stream.write_u32::(4 + 1 + 2)?; - self.stream.write_u8(1)?; // binary - self.stream.write_u16::(0)?; // no columns - } - } + fn handle_controlfile(&self, pgb: &mut PostgresBackend) -> io::Result<()> { + pgb.write_message_noflush(&SINGLE_COL_ROWDESC)? + .write_message_noflush(&BeMessage::ControlFile)? 
+ .write_message(&BeMessage::CommandComplete(b"SELECT 1"))?; Ok(()) } - fn write_message(&mut self, message: &BeMessage) -> io::Result<()> { - self.write_message_noflush(message)?; - self.stream.flush() - } - - fn run(&mut self) -> anyhow::Result<()> { - let mut unnamed_query_string = Bytes::new(); - loop { - let msg = self.read_message()?; - trace!("got message {:?}", msg); - match msg { - Some(FeMessage::StartupMessage(m)) => { - trace!("got message {:?}", m); - - match m.kind { - StartupRequestCode::NegotiateGss | StartupRequestCode::NegotiateSsl => { - let b = Bytes::from("N"); - self.stream.write_all(&b)?; - self.stream.flush()?; - } - StartupRequestCode::Normal => { - self.write_message_noflush(&BeMessage::AuthenticationOk)?; - // psycopg2 will not connect if client_encoding is not - // specified by the server - self.write_message_noflush(&BeMessage::ParameterStatus)?; - self.write_message(&BeMessage::ReadyForQuery)?; - self.init_done = true; - } - StartupRequestCode::Cancel => return Ok(()), - } - } - Some(FeMessage::Query(m)) => { - if let Err(e) = self.process_query(m.body) { - let errmsg = format!("{}", e); - self.write_message_noflush(&BeMessage::ErrorResponse(errmsg))?; - } - self.write_message(&BeMessage::ReadyForQuery)?; - } - Some(FeMessage::Parse(m)) => { - unnamed_query_string = m.query_string; - self.write_message(&BeMessage::ParseComplete)?; - } - Some(FeMessage::Describe(_)) => { - self.write_message_noflush(&BeMessage::ParameterDescription)?; - self.write_message(&BeMessage::NoData)?; - } - Some(FeMessage::Bind(_)) => { - self.write_message(&BeMessage::BindComplete)?; - } - Some(FeMessage::Close(_)) => { - self.write_message(&BeMessage::CloseComplete)?; - } - Some(FeMessage::Execute(_)) => { - self.process_query(unnamed_query_string.clone())?; - self.stream.flush()?; - } - Some(FeMessage::Sync) => { - self.write_message(&BeMessage::ReadyForQuery)?; - } - Some(FeMessage::Terminate) => { - break; - } - None => { - info!("connection closed"); - 
break; - } - x => { - bail!("unexpected message type : {:?}", x); - } - } - } - - Ok(()) - } - - fn process_query(&mut self, query_string: Bytes) -> anyhow::Result<()> { - debug!("process query {:?}", query_string); - - // remove null terminator, if any - let mut query_string = query_string; - if query_string.last() == Some(&0) { - query_string.truncate(query_string.len() - 1); - } - - if query_string.starts_with(b"controlfile") { - self.handle_controlfile()?; - } else if query_string.starts_with(b"pagestream ") { - let (_l, r) = query_string.split_at("pagestream ".len()); - let timelineid_str = String::from_utf8(r.to_vec())?; - let timelineid = ZTimelineId::from_str(&timelineid_str)?; - - self.handle_pagerequests(timelineid)?; - } else if query_string.starts_with(b"basebackup ") { - let (_l, r) = query_string.split_at("basebackup ".len()); - let r = r.to_vec(); - let basebackup_args = String::from(String::from_utf8(r)?.trim_end()); - let args: Vec<&str> = basebackup_args.rsplit(' ').collect(); - let timelineid_str = args[0]; - info!("got basebackup command: \"{}\"", timelineid_str); - let timelineid = ZTimelineId::from_str(&timelineid_str)?; - let lsn = if args.len() > 1 { - Some(Lsn::from_str(args[1])?) 
- } else { - None - }; - // Check that the timeline exists - self.handle_basebackup_request(timelineid, lsn)?; - self.write_message_noflush(&BeMessage::CommandComplete)?; - } else if query_string.starts_with(b"callmemaybe ") { - let query_str = String::from_utf8(query_string.to_vec())?; - - // callmemaybe - // TODO lazy static - let re = Regex::new(r"^callmemaybe ([[:xdigit:]]+) (.*)$").unwrap(); - let caps = re - .captures(&query_str) - .ok_or_else(|| anyhow!("invalid callmemaybe: '{}'", query_str))?; - - let timelineid = ZTimelineId::from_str(caps.get(1).unwrap().as_str())?; - let connstr: String = String::from(caps.get(2).unwrap().as_str()); - - // Check that the timeline exists - let repository = page_cache::get_repository(); - if repository.get_timeline(timelineid).is_err() { - bail!("client requested callmemaybe on timeline {} which does not exist in page server", timelineid); - } - - walreceiver::launch_wal_receiver(&self.conf, timelineid, &connstr); - - self.write_message_noflush(&BeMessage::CommandComplete)?; - } else if query_string.starts_with(b"branch_create ") { - let query_str = String::from_utf8(query_string.to_vec())?; - let err = || anyhow!("invalid branch_create: '{}'", query_str); - - // branch_create - // TODO lazy static - // TOOD: escaping, to allow branch names with spaces - let re = Regex::new(r"^branch_create (\S+) ([^\r\n\s;]+)[\r\n\s;]*;?$").unwrap(); - let caps = re.captures(&query_str).ok_or_else(err)?; - - let branchname: String = String::from(caps.get(1).ok_or_else(err)?.as_str()); - let startpoint_str: String = String::from(caps.get(2).ok_or_else(err)?.as_str()); - - let branch = branches::create_branch(&self.conf, &branchname, &startpoint_str)?; - let branch = serde_json::to_vec(&branch)?; - - self.write_message_noflush(&BeMessage::RowDescription)?; - self.write_message_noflush(&BeMessage::DataRow(Bytes::from(branch)))?; - self.write_message_noflush(&BeMessage::CommandComplete)?; - } else if query_string.starts_with(b"push ") { - 
let query_str = std::str::from_utf8(&query_string)?; - let mut it = query_str.split(' '); - it.next().unwrap(); - let timeline_id: ZTimelineId = it - .next() - .ok_or_else(|| anyhow!("missing timeline id"))? - .parse()?; - - let start_lsn = Lsn(0); // TODO this needs to come from the repo - let timeline = - page_cache::get_repository().create_empty_timeline(timeline_id, start_lsn)?; - - self.write_message(&BeMessage::CopyInResponse)?; - - let mut last_lsn = Lsn(0); - - while let Some(msg) = self.read_message()? { - match msg { - FeMessage::CopyData(bytes) => { - let relation_update = RelationUpdate::des(&bytes)?; - - last_lsn = relation_update.lsn; - - match relation_update.update { - Update::Page { blknum, img } => { - let tag = BufferTag { - rel: relation_update.rel, - blknum, - }; - - timeline.put_page_image(tag, relation_update.lsn, img)?; - } - Update::WALRecord { blknum, rec } => { - let tag = BufferTag { - rel: relation_update.rel, - blknum, - }; - - timeline.put_wal_record(tag, rec)?; - } - Update::Truncate { n_blocks } => { - timeline.put_truncation( - relation_update.rel, - relation_update.lsn, - n_blocks, - )?; - } - Update::Unlink => { - todo!() - } - } - } - FeMessage::CopyDone => { - timeline.advance_last_valid_lsn(last_lsn); - break; - } - FeMessage::Sync => {} - _ => bail!("unexpected message {:?}", msg), - } - } - - self.write_message(&BeMessage::CommandComplete)?; - } else if query_string.starts_with(b"request_push ") { - let query_str = std::str::from_utf8(&query_string)?; - let mut it = query_str.split(' '); - it.next().unwrap(); - - let timeline_id: ZTimelineId = it - .next() - .ok_or_else(|| anyhow!("missing timeline id"))? 
- .parse()?; - let timeline = page_cache::get_repository().get_timeline(timeline_id)?; - - let postgres_connection_uri = it.next().ok_or(anyhow!("missing postgres uri"))?; - - let mut conn = postgres::Client::connect(postgres_connection_uri, postgres::NoTls)?; - let mut copy_in = conn.copy_in(format!("push {}", timeline_id.to_string()).as_str())?; - - let history = timeline.history()?; - for update_res in history { - let update = update_res?; - let update_bytes = update.ser()?; - copy_in.write_all(&update_bytes)?; - copy_in.flush()?; // ensure that messages are sent inside individual CopyData packets - } - - copy_in.finish()?; - - self.write_message(&BeMessage::CommandComplete)?; - } else if query_string.starts_with(b"branch_list") { - let branches = crate::branches::get_branches(&self.conf)?; - let branches_buf = serde_json::to_vec(&branches)?; - - self.write_message_noflush(&BeMessage::RowDescription)?; - self.write_message_noflush(&BeMessage::DataRow(Bytes::from(branches_buf)))?; - self.write_message_noflush(&BeMessage::CommandComplete)?; - } else if query_string.starts_with(b"status") { - self.write_message_noflush(&BeMessage::RowDescription)?; - self.write_message_noflush(&HELLO_WORLD_ROW)?; - self.write_message_noflush(&BeMessage::CommandComplete)?; - } else if query_string.to_ascii_lowercase().starts_with(b"set ") { - // important because psycopg2 executes "SET datestyle TO 'ISO'" - // on connect - self.write_message_noflush(&BeMessage::CommandComplete)?; - } else if query_string - .to_ascii_lowercase() - .starts_with(b"identify_system") - { - // TODO: match postgres response formarmat for 'identify_system' - let system_id = crate::branches::get_system_id(&self.conf)?.to_string(); - - self.write_message_noflush(&BeMessage::RowDescription)?; - self.write_message_noflush(&BeMessage::DataRow(Bytes::from(system_id)))?; - self.write_message_noflush(&BeMessage::CommandComplete)?; - } else { - bail!("unknown command"); - } - - Ok(()) - } - - fn 
handle_controlfile(&mut self) -> io::Result<()> { - self.write_message_noflush(&BeMessage::RowDescription)?; - self.write_message_noflush(&BeMessage::ControlFile)?; - self.write_message(&BeMessage::CommandComplete)?; - - Ok(()) - } - - fn handle_pagerequests(&mut self, timelineid: ZTimelineId) -> anyhow::Result<()> { + fn handle_pagerequests( + &self, + pgb: &mut PostgresBackend, + timelineid: ZTimelineId, + ) -> anyhow::Result<()> { // Check that the timeline exists let repository = page_cache::get_repository(); let timeline = repository.get_timeline(timelineid).map_err(|_| { @@ -909,13 +188,9 @@ impl Connection { })?; /* switch client to COPYBOTH */ - self.stream.write_u8(b'W')?; - self.stream.write_i32::(4 + 1 + 2)?; - self.stream.write_u8(0)?; /* copy_is_binary */ - self.stream.write_i16::(0)?; /* numAttributes */ - self.stream.flush()?; + pgb.write_message(&BeMessage::CopyBothResponse)?; - while let Some(message) = self.read_message()? { + while let Some(message) = pgb.read_message()? 
{ trace!("query({:?}): {:?}", timelineid, message); let copy_data_bytes = match message { @@ -985,14 +260,15 @@ impl Connection { } }; - self.write_message(&BeMessage::CopyData(response.serialize()))?; + pgb.write_message(&BeMessage::CopyData(&response.serialize()))?; } Ok(()) } fn handle_basebackup_request( - &mut self, + &self, + pgb: &mut PostgresBackend, timelineid: ZTimelineId, lsn: Option, ) -> anyhow::Result<()> { @@ -1006,12 +282,7 @@ impl Connection { ) })?; /* switch client to COPYOUT */ - let stream = &mut self.stream; - stream.write_u8(b'H')?; - stream.write_i32::(4 + 1 + 2)?; - stream.write_u8(0)?; /* copy_is_binary */ - stream.write_i16::(0)?; /* numAttributes */ - stream.flush()?; + pgb.write_message(&BeMessage::CopyOutResponse)?; info!("sent CopyOut"); /* Send a tarball of the latest snapshot on the timeline */ @@ -1021,28 +292,229 @@ impl Connection { restore_local_repo::find_latest_snapshot(&self.conf, timelineid).unwrap(); let req_lsn = lsn.unwrap_or(snapshot_lsn); basebackup::send_tarball_at_lsn( - &mut CopyDataSink { stream }, + &mut CopyDataSink { pgb }, timelineid, &timeline, req_lsn, snapshot_lsn, )?; - // CopyDone - self.stream.write_u8(b'c')?; - self.stream.write_u32::(4)?; - self.stream.flush()?; + pgb.write_message(&BeMessage::CopyDone)?; debug!("CopyDone sent!"); Ok(()) } } +impl postgres_backend::Handler for PageServerHandler { + fn process_query( + &mut self, + pgb: &mut PostgresBackend, + query_string: Bytes, + ) -> anyhow::Result<()> { + debug!("process query {:?}", query_string); + + // remove null terminator, if any + let mut query_string = query_string; + if query_string.last() == Some(&0) { + query_string.truncate(query_string.len() - 1); + } + + if query_string.starts_with(b"controlfile") { + self.handle_controlfile(pgb)?; + } else if query_string.starts_with(b"pagestream ") { + let (_l, r) = query_string.split_at("pagestream ".len()); + let timelineid_str = String::from_utf8(r.to_vec())?; + let timelineid = 
ZTimelineId::from_str(&timelineid_str)?; + + self.handle_pagerequests(pgb, timelineid)?; + } else if query_string.starts_with(b"basebackup ") { + let (_l, r) = query_string.split_at("basebackup ".len()); + let r = r.to_vec(); + let basebackup_args = String::from(String::from_utf8(r)?.trim_end()); + let args: Vec<&str> = basebackup_args.rsplit(' ').collect(); + let timelineid_str = args[0]; + info!("got basebackup command: \"{}\"", timelineid_str); + let timelineid = ZTimelineId::from_str(&timelineid_str)?; + let lsn = if args.len() > 1 { + Some(Lsn::from_str(args[1])?) + } else { + None + }; + // Check that the timeline exists + self.handle_basebackup_request(pgb, timelineid, lsn)?; + pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?; + } else if query_string.starts_with(b"callmemaybe ") { + let query_str = String::from_utf8(query_string.to_vec())?; + + // callmemaybe + // TODO lazy static + let re = Regex::new(r"^callmemaybe ([[:xdigit:]]+) (.*)$").unwrap(); + let caps = re + .captures(&query_str) + .ok_or_else(|| anyhow!("invalid callmemaybe: '{}'", query_str))?; + + let timelineid = ZTimelineId::from_str(caps.get(1).unwrap().as_str())?; + let connstr: String = String::from(caps.get(2).unwrap().as_str()); + + // Check that the timeline exists + let repository = page_cache::get_repository(); + if repository.get_timeline(timelineid).is_err() { + bail!("client requested callmemaybe on timeline {} which does not exist in page server", timelineid); + } + + walreceiver::launch_wal_receiver(&self.conf, timelineid, &connstr); + + pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?; + } else if query_string.starts_with(b"branch_create ") { + let query_str = String::from_utf8(query_string.to_vec())?; + let err = || anyhow!("invalid branch_create: '{}'", query_str); + + // branch_create + // TODO lazy static + // TODO: escaping, to allow branch names with spaces + let re = Regex::new(r"^branch_create (\S+) 
([^\r\n\s;]+)[\r\n\s;]*;?$").unwrap(); + let caps = re.captures(&query_str).ok_or_else(err)?; + + let branchname: String = String::from(caps.get(1).ok_or_else(err)?.as_str()); + let startpoint_str: String = String::from(caps.get(2).ok_or_else(err)?.as_str()); + + let branch = branches::create_branch(&self.conf, &branchname, &startpoint_str)?; + let branch = serde_json::to_vec(&branch)?; + + pgb.write_message_noflush(&SINGLE_COL_ROWDESC)? + .write_message_noflush(&BeMessage::DataRow(&[Some(&branch)]))? + .write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?; + } else if query_string.starts_with(b"push ") { + let query_str = std::str::from_utf8(&query_string)?; + let mut it = query_str.split(' '); + it.next().unwrap(); + let timeline_id: ZTimelineId = it + .next() + .ok_or_else(|| anyhow!("missing timeline id"))? + .parse()?; + + let start_lsn = Lsn(0); // TODO this needs to come from the repo + let timeline = + page_cache::get_repository().create_empty_timeline(timeline_id, start_lsn)?; + + pgb.write_message(&BeMessage::CopyInResponse)?; + + let mut last_lsn = Lsn(0); + + while let Some(msg) = pgb.read_message()? 
{ + match msg { + FeMessage::CopyData(bytes) => { + let relation_update = RelationUpdate::des(&bytes)?; + + last_lsn = relation_update.lsn; + + match relation_update.update { + Update::Page { blknum, img } => { + let tag = BufferTag { + rel: relation_update.rel, + blknum, + }; + + timeline.put_page_image(tag, relation_update.lsn, img)?; + } + Update::WALRecord { blknum, rec } => { + let tag = BufferTag { + rel: relation_update.rel, + blknum, + }; + + timeline.put_wal_record(tag, rec)?; + } + Update::Truncate { n_blocks } => { + timeline.put_truncation( + relation_update.rel, + relation_update.lsn, + n_blocks, + )?; + } + Update::Unlink => { + todo!() + } + } + } + FeMessage::CopyDone => { + timeline.advance_last_valid_lsn(last_lsn); + break; + } + FeMessage::Sync => {} + _ => bail!("unexpected message {:?}", msg), + } + } + + pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?; + } else if query_string.starts_with(b"request_push ") { + let query_str = std::str::from_utf8(&query_string)?; + let mut it = query_str.split(' '); + it.next().unwrap(); + + let timeline_id: ZTimelineId = it + .next() + .ok_or_else(|| anyhow!("missing timeline id"))? 
+ .parse()?; + let timeline = page_cache::get_repository().get_timeline(timeline_id)?; + + let postgres_connection_uri = it.next().ok_or(anyhow!("missing postgres uri"))?; + + let mut conn = postgres::Client::connect(postgres_connection_uri, postgres::NoTls)?; + let mut copy_in = conn.copy_in(format!("push {}", timeline_id.to_string()).as_str())?; + + let history = timeline.history()?; + for update_res in history { + let update = update_res?; + let update_bytes = update.ser()?; + copy_in.write_all(&update_bytes)?; + copy_in.flush()?; // ensure that messages are sent inside individual CopyData packets + } + + copy_in.finish()?; + + pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?; + } else if query_string.starts_with(b"branch_list") { + let branches = crate::branches::get_branches(&self.conf)?; + let branches_buf = serde_json::to_vec(&branches)?; + + pgb.write_message_noflush(&SINGLE_COL_ROWDESC)? + .write_message_noflush(&BeMessage::DataRow(&[Some(&branches_buf)]))? + .write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?; + } else if query_string.starts_with(b"status") { + pgb.write_message_noflush(&SINGLE_COL_ROWDESC)? + .write_message_noflush(&HELLO_WORLD_ROW)? 
+ .write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?; + } else if query_string.to_ascii_lowercase().starts_with(b"set ") { + // important because psycopg2 executes "SET datestyle TO 'ISO'" + // on connect + pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?; + } else if query_string + .to_ascii_lowercase() + .starts_with(b"identify_system") + { + // TODO: match postgres response formarmat for 'identify_system' + let system_id = crate::branches::get_system_id(&self.conf)?.to_string(); + + pgb.write_message_noflush(&SINGLE_COL_ROWDESC)?; + pgb.write_message_noflush(&BeMessage::DataRow(&[Some(system_id.as_bytes())]))?; + pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?; + } else { + bail!("unknown command"); + } + + pgb.flush()?; + + Ok(()) + } +} + /// /// A std::io::Write implementation that wraps all data written to it in CopyData /// messages. /// struct CopyDataSink<'a> { - stream: &'a mut BufWriter, + pgb: &'a mut PostgresBackend, } impl<'a> io::Write for CopyDataSink<'a> { @@ -1051,14 +523,10 @@ impl<'a> io::Write for CopyDataSink<'a> { // FIXME: if the input is large, we should split it into multiple messages. // Not sure what the threshold should be, but the ultimate hard limit is that // the length cannot exceed u32. - self.stream.write_u8(b'd')?; - self.stream.write_u32::((4 + data.len()) as u32)?; - self.stream.write_all(&data)?; - trace!("CopyData sent for {} bytes!", data.len()); - // FIXME: flush isn't really required, but makes it easier // to view in wireshark - self.stream.flush()?; + self.pgb.write_message(&BeMessage::CopyData(data))?; + trace!("CopyData sent for {} bytes!", data.len()); Ok(data.len()) } diff --git a/walkeeper/src/replication.rs b/walkeeper/src/replication.rs index a9b1bf1f82..cef15f242c 100644 --- a/walkeeper/src/replication.rs +++ b/walkeeper/src/replication.rs @@ -1,20 +1,18 @@ -//! This module implements the replication protocol, starting with the -//! 
"START REPLICATION" message. +//! This module implements the streaming side of replication protocol, starting +//! with the "START REPLICATION" message. -use crate::pq_protocol::{BeMessage, FeMessage}; -use crate::send_wal::SendWalConn; +use crate::send_wal::SendWalHandler; use crate::timeline::{Timeline, TimelineTools}; -use crate::WalAcceptorConf; use anyhow::{anyhow, Result}; -use bytes::{BufMut, Bytes, BytesMut}; +use bytes::Bytes; use log::*; use postgres_ffi::xlog_utils::{get_current_timestamp, TimestampTz, XLogFileName, MAX_SEND_SIZE}; use regex::Regex; use serde::{Deserialize, Serialize}; use std::cmp::min; use std::fs::File; -use std::io::{BufReader, Read, Seek, SeekFrom, Write}; -use std::net::{Shutdown, TcpStream}; +use std::io::{BufReader, Read, Seek, SeekFrom}; +use std::net::TcpStream; use std::path::Path; use std::sync::Arc; use std::thread::sleep; @@ -22,10 +20,9 @@ use std::time::Duration; use std::{str, thread}; use zenith_utils::bin_ser::BeSer; use zenith_utils::lsn::Lsn; +use zenith_utils::postgres_backend::PostgresBackend; +use zenith_utils::pq_proto::{BeMessage, FeMessage, XLogDataBody}; -const XLOG_HDR_SIZE: usize = 1 + 8 * 3; /* 'w' + startPos + walEnd + timestamp */ -const LIBPQ_HDR_SIZE: usize = 5; /* 1 byte with message type + 4 bytes length */ -const LIBPQ_MSG_SIZE_OFFS: usize = 1; pub const END_REPLICATION_MARKER: Lsn = Lsn::MAX; type FullTransactionId = u64; @@ -40,37 +37,16 @@ pub struct HotStandbyFeedback { /// A network connection that's speaking the replication protocol. pub struct ReplicationConn { - timeline: Option>, - /// Postgres connection, buffered input - /// /// This is an `Option` because we will spawn a background thread that will /// `take` it from us. stream_in: Option>, - /// Postgres connection, output - stream_out: TcpStream, - /// wal acceptor configuration - conf: WalAcceptorConf, - /// assigned application name - appname: Option, -} - -// Separate thread is reading keepalives from the socket. 
When main one -// finishes, tell it to get down by shutdowning the socket. -impl Drop for ReplicationConn { - fn drop(&mut self) { - let _res = self.stream_out.shutdown(Shutdown::Both); - } } impl ReplicationConn { - /// Create a new `SendWal`, consuming the `Connection`. - pub fn new(conn: SendWalConn) -> Self { + /// Create a new `ReplicationConn` + pub fn new(pgb: &mut PostgresBackend) -> Self { Self { - timeline: conn.timeline, - stream_in: Some(conn.stream_in), - stream_out: conn.stream_out, - conf: conn.conf, - appname: None, + stream_in: pgb.take_stream_in(), } } @@ -82,9 +58,9 @@ impl ReplicationConn { // Wait for replica's feedback. // We only handle `CopyData` messages. Anything else is ignored. loop { - match FeMessage::read_from(&mut stream_in)? { - FeMessage::CopyData(m) => { - let feedback = HotStandbyFeedback::des(&m.body)?; + match FeMessage::read(&mut stream_in)? { + Some(FeMessage::CopyData(m)) => { + let feedback = HotStandbyFeedback::des(&m)?; timeline.add_hs_feedback(feedback) } msg => { @@ -128,9 +104,14 @@ impl ReplicationConn { /// /// Handle START_REPLICATION replication command /// - pub fn run(&mut self, cmd: &Bytes) -> Result<()> { + pub fn run( + &mut self, + swh: &mut SendWalHandler, + pgb: &mut PostgresBackend, + cmd: &Bytes, + ) -> Result<()> { // spawn the background thread which receives HotStandbyFeedback messages. 
- let bg_timeline = Arc::clone(self.timeline.get()); + let bg_timeline = Arc::clone(swh.timeline.get()); let bg_stream_in = self.stream_in.take().unwrap(); thread::spawn(move || { @@ -143,7 +124,7 @@ impl ReplicationConn { let mut wal_seg_size: usize; loop { - wal_seg_size = self.timeline.get().get_info().server.wal_seg_size as usize; + wal_seg_size = swh.timeline.get().get_info().server.wal_seg_size as usize; if wal_seg_size == 0 { error!("Can not start replication before connecting to wal_proposer"); sleep(Duration::from_secs(1)); @@ -151,19 +132,17 @@ impl ReplicationConn { break; } } - let (wal_end, timeline) = self.timeline.find_end_of_wal(&self.conf.data_dir, false); + let (wal_end, timeline) = swh.timeline.find_end_of_wal(&swh.conf.data_dir, false); if start_pos == Lsn(0) { start_pos = wal_end; } - if stop_pos == Lsn(0) && self.appname == Some("wal_proposer_recovery".to_string()) { + if stop_pos == Lsn(0) && swh.appname == Some("wal_proposer_recovery".to_string()) { stop_pos = wal_end; } info!("Start replication from {} till {}", start_pos, stop_pos); - let mut outbuf = BytesMut::new(); - BeMessage::write(&mut outbuf, &BeMessage::Copy); - self.send(&outbuf)?; - outbuf.clear(); + // switch to copy + pgb.write_message(&BeMessage::CopyBothResponse)?; let mut end_pos: Lsn; let mut wal_file: Option = None; @@ -179,7 +158,7 @@ impl ReplicationConn { end_pos = stop_pos; } else { /* normal mode */ - let timeline = self.timeline.get(); + let timeline = swh.timeline.get(); end_pos = timeline.wait_for_lsn(start_pos); } if end_pos == END_REPLICATION_MARKER { @@ -193,8 +172,8 @@ impl ReplicationConn { // Open a new file. 
let segno = start_pos.segment_number(wal_seg_size); let wal_file_name = XLogFileName(timeline, segno, wal_seg_size); - let timeline_id = self.timeline.get().timelineid.to_string(); - let wal_file_path = self.conf.data_dir.join(timeline_id).join(wal_file_name); + let timeline_id = swh.timeline.get().timelineid.to_string(); + let wal_file_path = swh.conf.data_dir.join(timeline_id).join(wal_file_name); Self::open_wal_file(&wal_file_path)? } }; @@ -207,32 +186,19 @@ impl ReplicationConn { let send_size = min(send_size, wal_seg_size - xlogoff); let send_size = min(send_size, MAX_SEND_SIZE); - let msg_size = LIBPQ_HDR_SIZE + XLOG_HDR_SIZE + send_size; - // Read some data from the file. let mut file_buf = vec![0u8; send_size]; file.seek(SeekFrom::Start(xlogoff as u64))?; file.read_exact(&mut file_buf)?; // Write some data to the network socket. - // FIXME: turn these into structs. - // 'd' is CopyData; - // 'w' is "WAL records" - // https://www.postgresql.org/docs/9.1/protocol-message-formats.html - // src/backend/replication/walreceiver.c - outbuf.clear(); - outbuf.put_u8(b'd'); - outbuf.put_u32((msg_size - LIBPQ_MSG_SIZE_OFFS) as u32); - outbuf.put_u8(b'w'); - outbuf.put_u64(start_pos.0); - outbuf.put_u64(end_pos.0); - outbuf.put_u64(get_current_timestamp()); + pgb.write_message(&BeMessage::XLogData(XLogDataBody { + wal_start: start_pos.0, + wal_end: end_pos.0, + timestamp: get_current_timestamp(), + data: &file_buf, + }))?; - assert!(outbuf.len() + file_buf.len() == msg_size); - // This thread has exclusive access to the TcpStream, so it's fine - // to do this as two separate calls. - self.send(&outbuf)?; - self.send(&file_buf)?; start_pos += send_size as u64; debug!("Sent WAL to page server up to {}", end_pos); @@ -245,10 +211,4 @@ impl ReplicationConn { } Ok(()) } - - /// Send messages on the network. 
- fn send(&mut self, buf: &[u8]) -> Result<()> { - self.stream_out.write_all(buf.as_ref())?; - Ok(()) - } } diff --git a/walkeeper/src/send_wal.rs b/walkeeper/src/send_wal.rs index fc29c3575d..711be4b9ca 100644 --- a/walkeeper/src/send_wal.rs +++ b/walkeeper/src/send_wal.rs @@ -1,111 +1,69 @@ -//! This implements the libpq replication protocol between wal_acceptor -//! and replicas/pagers +//! Part of WAL acceptor pretending to be Postgres, streaming xlog to +//! pageserver/any other consumer. //! -use crate::pq_protocol::{ - BeMessage, FeMessage, FeStartupMessage, RowDescriptor, StartupRequestCode, -}; use crate::replication::ReplicationConn; use crate::timeline::{Timeline, TimelineTools}; use crate::WalAcceptorConf; use anyhow::{bail, Result}; -use bytes::BytesMut; -use log::*; -use std::io::{BufReader, Write}; -use std::net::{SocketAddr, TcpStream}; +use bytes::Bytes; +use pageserver::ZTimelineId; +use std::str::FromStr; use std::sync::Arc; +use zenith_utils::postgres_backend; +use zenith_utils::postgres_backend::PostgresBackend; +use zenith_utils::pq_proto::{BeMessage, FeStartupMessage, RowDescriptor}; -/// A network connection that's speaking the libpq replication protocol. -pub struct SendWalConn { - pub timeline: Option>, - /// Postgres connection, buffered input - pub stream_in: BufReader, - /// Postgres connection, output - pub stream_out: TcpStream, - /// The cached result of socket.peer_addr() - pub peer_addr: SocketAddr, +/// Handler for streaming WAL from acceptor +pub struct SendWalHandler { /// wal acceptor configuration pub conf: WalAcceptorConf, /// assigned application name - appname: Option, + pub appname: Option, + pub timeline: Option>, } -impl SendWalConn { - /// Create a new `SendWal`, consuming the `Connection`. 
- pub fn new(socket: TcpStream, conf: WalAcceptorConf) -> Result { - let peer_addr = socket.peer_addr()?; - let conn = SendWalConn { - timeline: None, - stream_in: BufReader::new(socket.try_clone()?), - stream_out: socket, - peer_addr, - conf, - appname: None, - }; - Ok(conn) +impl postgres_backend::Handler for SendWalHandler { + fn startup(&mut self, _pgb: &mut PostgresBackend, sm: &FeStartupMessage) -> Result<()> { + match sm.params.get("ztimelineid") { + Some(ref ztimelineid) => { + let ztlid = ZTimelineId::from_str(ztimelineid)?; + self.timeline.set(ztlid)?; + } + _ => bail!("timelineid is required"), + } + if let Some(app_name) = sm.params.get("application_name") { + self.appname = Some(app_name.clone()); + } + Ok(()) } - /// - /// Send WAL to replica or WAL receiver using standard libpq replication protocol - /// - pub fn run(mut self) -> Result<()> { - let peer_addr = self.peer_addr; - info!("WAL sender to {:?} is started", peer_addr); - - // Handle the startup message first. - - let m = FeStartupMessage::read_from(&mut self.stream_in)?; - trace!("got startup message {:?}", m); - match m.kind { - StartupRequestCode::NegotiateGss | StartupRequestCode::NegotiateSsl => { - let mut buf = BytesMut::new(); - BeMessage::write(&mut buf, &BeMessage::Negotiate); - info!("SSL requested"); - self.stream_out.write_all(&buf)?; - } - StartupRequestCode::Normal => { - let mut buf = BytesMut::new(); - BeMessage::write(&mut buf, &BeMessage::AuthenticationOk); - BeMessage::write(&mut buf, &BeMessage::ReadyForQuery); - self.stream_out.write_all(&buf)?; - self.timeline.set(m.timelineid)?; - self.appname = m.appname; - } - StartupRequestCode::Cancel => return Ok(()), + fn process_query(&mut self, pgb: &mut PostgresBackend, query_string: Bytes) -> Result<()> { + if query_string.starts_with(b"IDENTIFY_SYSTEM") { + self.handle_identify_system(pgb)?; + Ok(()) + } else if query_string.starts_with(b"START_REPLICATION") { + ReplicationConn::new(pgb).run(self, pgb, &query_string)?; + 
Ok(()) + } else { + bail!("Unexpected command {:?}", query_string); } + } +} - loop { - let msg = FeMessage::read_from(&mut self.stream_in)?; - match msg { - FeMessage::Query(q) => { - trace!("got query {:?}", q.body); - - if q.body.starts_with(b"IDENTIFY_SYSTEM") { - self.handle_identify_system()?; - } else if q.body.starts_with(b"START_REPLICATION") { - // Create a new replication object, consuming `self`. - ReplicationConn::new(self).run(&q.body)?; - break; - } else { - bail!("Unexpected command {:?}", q.body); - } - } - FeMessage::Terminate => { - break; - } - _ => { - bail!("unexpected message"); - } - } +impl SendWalHandler { + pub fn new(conf: WalAcceptorConf) -> Self { + SendWalHandler { + conf, + appname: None, + timeline: None, } - info!("WAL sender to {:?} is finished", peer_addr); - Ok(()) } /// /// Handle IDENTIFY_SYSTEM replication command /// - fn handle_identify_system(&mut self) -> Result<()> { + fn handle_identify_system(&mut self, pgb: &mut PostgresBackend) -> Result<()> { let (start_pos, timeline) = self.timeline.find_end_of_wal(&self.conf.data_dir, false); let lsn = start_pos.to_string(); let tli = timeline.to_string(); @@ -114,42 +72,39 @@ impl SendWalConn { let tli_bytes = tli.as_bytes(); let sysid_bytes = sysid.as_bytes(); - let mut outbuf = BytesMut::new(); - BeMessage::write( - &mut outbuf, - &BeMessage::RowDescription(&[ - RowDescriptor { - name: b"systemid\0", - typoid: 25, - typlen: -1, - }, - RowDescriptor { - name: b"timeline\0", - typoid: 23, - typlen: 4, - }, - RowDescriptor { - name: b"xlogpos\0", - typoid: 25, - typlen: -1, - }, - RowDescriptor { - name: b"dbname\0", - typoid: 25, - typlen: -1, - }, - ]), - ); - BeMessage::write( - &mut outbuf, - &BeMessage::DataRow(&[Some(sysid_bytes), Some(tli_bytes), Some(lsn_bytes), None]), - ); - BeMessage::write( - &mut outbuf, - &BeMessage::CommandComplete(b"IDENTIFY_SYSTEM\0"), - ); - BeMessage::write(&mut outbuf, &BeMessage::ReadyForQuery); - self.stream_out.write_all(&outbuf)?; + 
pgb.write_message_noflush(&BeMessage::RowDescription(&[ + RowDescriptor { + name: b"systemid", + typoid: 25, + typlen: -1, + ..Default::default() + }, + RowDescriptor { + name: b"timeline", + typoid: 23, + typlen: 4, + ..Default::default() + }, + RowDescriptor { + name: b"xlogpos", + typoid: 25, + typlen: -1, + ..Default::default() + }, + RowDescriptor { + name: b"dbname", + typoid: 25, + typlen: -1, + ..Default::default() + }, + ]))? + .write_message_noflush(&BeMessage::DataRow(&[ + Some(sysid_bytes), + Some(tli_bytes), + Some(lsn_bytes), + None, + ]))? + .write_message(&BeMessage::CommandComplete(b"IDENTIFY_SYSTEM"))?; Ok(()) } } diff --git a/walkeeper/src/wal_service.rs b/walkeeper/src/wal_service.rs index 0c045bd9e6..fab07a24c4 100644 --- a/walkeeper/src/wal_service.rs +++ b/walkeeper/src/wal_service.rs @@ -9,8 +9,9 @@ use std::net::{TcpListener, TcpStream}; use std::thread; use crate::receive_wal::ReceiveWalConn; -use crate::send_wal::SendWalConn; +use crate::send_wal::SendWalHandler; use crate::WalAcceptorConf; +use zenith_utils::postgres_backend::PostgresBackend; /// Accept incoming TCP connections and spawn them into a background thread. 
pub fn thread_main(conf: WalAcceptorConf) -> Result<()> { @@ -48,7 +49,10 @@ fn handle_socket(mut socket: TcpStream, conf: WalAcceptorConf) -> Result<()> { socket.read_exact(&mut [0u8; 4])?; ReceiveWalConn::new(socket, conf)?.run()?; // internal protocol between wal_proposer and wal_acceptor } else { - SendWalConn::new(socket, conf)?.run()?; // libpq replication protocol between wal_acceptor and replicas/pagers + let mut conn_handler = SendWalHandler::new(conf); + let mut pgbackend = PostgresBackend::new(socket)?; + // libpq replication protocol between wal_acceptor and replicas/pagers + pgbackend.run(&mut conn_handler)?; } Ok(()) } diff --git a/zenith_utils/Cargo.toml b/zenith_utils/Cargo.toml index e3576d97d2..3c644cb6d9 100644 --- a/zenith_utils/Cargo.toml +++ b/zenith_utils/Cargo.toml @@ -5,6 +5,10 @@ authors = ["Eric Seppanen "] edition = "2018" [dependencies] +anyhow = "1.0" +bytes = "1.0.1" +byteorder = "1.4.3" +log = "0.4.14" serde = { version = "1.0", features = ["derive"] } bincode = "1.3" thiserror = "1.0" diff --git a/zenith_utils/src/lib.rs b/zenith_utils/src/lib.rs index 8388c5f9ed..fc24a0ad1e 100644 --- a/zenith_utils/src/lib.rs +++ b/zenith_utils/src/lib.rs @@ -10,3 +10,6 @@ pub mod seqwait; // pub mod seqwait_async; pub mod bin_ser; + +pub mod postgres_backend; +pub mod pq_proto; diff --git a/zenith_utils/src/postgres_backend.rs b/zenith_utils/src/postgres_backend.rs new file mode 100644 index 0000000000..78993af761 --- /dev/null +++ b/zenith_utils/src/postgres_backend.rs @@ -0,0 +1,181 @@ +//! Server-side synchronous Postgres connection, as limited as we need. +//! To use, create PostgresBackend and run() it, passing the Handler +//! implementation determining how to process the queries. Currently its API +//! is rather narrow, but we can extend it once required. 
+ +use crate::pq_proto::{BeMessage, FeMessage, FeStartupMessage, StartupRequestCode}; +use anyhow::bail; +use anyhow::Result; +use bytes::{Bytes, BytesMut}; +use log::*; +use std::io; +use std::io::{BufReader, Write}; +use std::net::{Shutdown, TcpStream}; + +pub trait Handler { + /// Handle single query. + /// postgres_backend will issue ReadyForQuery after calling this (this + /// might be not what we want after CopyData streaming, but currently we don't + /// care). + fn process_query(&mut self, pgb: &mut PostgresBackend, query_string: Bytes) -> Result<()>; + /// Called on startup packet receival, allows to process params. + fn startup(&mut self, _pgb: &mut PostgresBackend, _sm: &FeStartupMessage) -> Result<()> { + Ok(()) + } +} + +pub struct PostgresBackend { + // replication.rs wants to handle reading on its own in separate thread, so + // wrap in Option to be able to take and transfer the BufReader. Ugly, but I + // have no better ideas. + stream_in: Option>, + stream_out: TcpStream, + // Output buffer. c.f. BeMessage::write why we are using BytesMut here. + buf_out: BytesMut, + init_done: bool, +} + +// In replication.rs a separate thread is reading keepalives from the +// socket. When main one finishes, tell it to get down by shutdowning the +// socket. 
+impl Drop for PostgresBackend { + fn drop(&mut self) { + let _res = self.stream_out.shutdown(Shutdown::Both); + } +} + +impl PostgresBackend { + pub fn new(socket: TcpStream) -> Result { + let mut pb = PostgresBackend { + stream_in: None, + stream_out: socket, + buf_out: BytesMut::with_capacity(10 * 1024), + init_done: false, + }; + // if socket cloning fails, report the error and bail out + pb.stream_in = match pb.stream_out.try_clone() { + Ok(read_sock) => Some(BufReader::new(read_sock)), + Err(error) => { + let errmsg = format!("{}", error); + let _res = pb.write_message_noflush(&BeMessage::ErrorResponse(errmsg)); + return Err(error); + } + }; + Ok(pb) + } + + /// Get direct reference (into the Option) to the read stream. + fn get_stream_in(&mut self) -> Result<&mut BufReader> { + match self.stream_in { + Some(ref mut stream_in) => Ok(stream_in), + None => bail!("stream_in was taken"), + } + } + + pub fn take_stream_in(&mut self) -> Option> { + self.stream_in.take() + } + + /// Read full message or return None if connection is closed. + pub fn read_message(&mut self) -> Result> { + if !self.init_done { + FeStartupMessage::read(self.get_stream_in()?) + } else { + FeMessage::read(self.get_stream_in()?) + } + } + + /// Write message into internal output buffer. + pub fn write_message_noflush(&mut self, message: &BeMessage) -> io::Result<&mut Self> { + BeMessage::write(&mut self.buf_out, message)?; + Ok(self) + } + + /// Flush output buffer into the socket. + pub fn flush(&mut self) -> io::Result<&mut Self> { + self.stream_out.write_all(&self.buf_out)?; + self.buf_out.clear(); + Ok(self) + } + + /// Write message into internal buffer and flush it. 
+ pub fn write_message(&mut self, message: &BeMessage) -> io::Result<&mut Self> { + self.write_message_noflush(message)?; + self.flush() + } + + pub fn run(&mut self, handler: &mut impl Handler) -> Result<()> { + let peer_addr = self.stream_out.peer_addr()?; + info!("postgres backend to {:?} started", peer_addr); + let mut unnamed_query_string = Bytes::new(); + + loop { + let msg = self.read_message()?; + trace!("got message {:?}", msg); + match msg { + Some(FeMessage::StartupMessage(m)) => { + trace!("got startup message {:?}", m); + + handler.startup(self, &m)?; + + match m.kind { + StartupRequestCode::NegotiateGss | StartupRequestCode::NegotiateSsl => { + info!("SSL requested"); + self.write_message(&BeMessage::Negotiate)?; + } + StartupRequestCode::Normal => { + self.write_message_noflush(&BeMessage::AuthenticationOk)?; + // psycopg2 will not connect if client_encoding is not + // specified by the server + self.write_message_noflush(&BeMessage::ParameterStatus)?; + self.write_message(&BeMessage::ReadyForQuery)?; + self.init_done = true; + } + StartupRequestCode::Cancel => break, + } + } + Some(FeMessage::Query(m)) => { + trace!("got query {:?}", m.body); + // xxx distinguish fatal and recoverable errors? + if let Err(e) = handler.process_query(self, m.body) { + let errmsg = format!("{}", e); + self.write_message_noflush(&BeMessage::ErrorResponse(errmsg))?; + } + self.write_message(&BeMessage::ReadyForQuery)?; + } + Some(FeMessage::Parse(m)) => { + unnamed_query_string = m.query_string; + self.write_message(&BeMessage::ParseComplete)?; + } + Some(FeMessage::Describe(_)) => { + self.write_message_noflush(&BeMessage::ParameterDescription)? 
+ .write_message(&BeMessage::NoData)?; + } + Some(FeMessage::Bind(_)) => { + self.write_message(&BeMessage::BindComplete)?; + } + Some(FeMessage::Close(_)) => { + self.write_message(&BeMessage::CloseComplete)?; + } + Some(FeMessage::Execute(_)) => { + handler.process_query(self, unnamed_query_string.clone())?; + } + Some(FeMessage::Sync) => { + self.write_message(&BeMessage::ReadyForQuery)?; + } + Some(FeMessage::Terminate) => { + break; + } + None => { + info!("connection closed"); + break; + } + x => { + bail!("unexpected message type : {:?}", x); + } + } + } + info!("postgres backend to {:?} exited", peer_addr); + Ok(()) + } +} diff --git a/zenith_utils/src/pq_proto.rs b/zenith_utils/src/pq_proto.rs new file mode 100644 index 0000000000..e1f43067e5 --- /dev/null +++ b/zenith_utils/src/pq_proto.rs @@ -0,0 +1,639 @@ +//! Postgres protocol messages serialization-deserialization. See +//! https://www.postgresql.org/docs/devel/protocol-message-formats.html +//! on message formats. + +use anyhow::{anyhow, bail, Result}; +use byteorder::{BigEndian, ByteOrder}; +use byteorder::{ReadBytesExt, BE}; +use bytes::{Buf, BufMut, Bytes, BytesMut}; +// use postgres_ffi::xlog_utils::TimestampTz; +use std::collections::HashMap; +use std::io; +use std::io::Read; +use std::str; + +pub type Oid = u32; +pub type SystemId = u64; + +#[derive(Debug)] +pub enum FeMessage { + StartupMessage(FeStartupMessage), + Query(FeQueryMessage), // Simple query + Parse(FeParseMessage), // Extended query protocol + Describe(FeDescribeMessage), + Bind(FeBindMessage), + Execute(FeExecuteMessage), + Close(FeCloseMessage), + Sync, + Terminate, + CopyData(Bytes), + CopyDone, +} + +#[derive(Debug)] +pub struct FeStartupMessage { + pub version: u32, + pub kind: StartupRequestCode, + // optional params arriving in startup packet + pub params: HashMap, +} + +#[derive(Debug)] +pub enum StartupRequestCode { + Cancel, + NegotiateSsl, + NegotiateGss, + Normal, +} + +#[derive(Debug)] +pub struct FeQueryMessage { + 
pub body: Bytes, +} + +// We only support the simple case of Parse on unnamed prepared statement and +// no params +#[derive(Debug)] +pub struct FeParseMessage { + pub query_string: Bytes, +} + +#[derive(Debug)] +pub struct FeDescribeMessage { + pub kind: u8, // 'S' to describe a prepared statement; or 'P' to describe a portal. + // we only support unnamed prepared stmt or portal +} + +// we only support unnamed prepared stmt and portal +#[derive(Debug)] +pub struct FeBindMessage {} + +// we only support unnamed prepared stmt or portal +#[derive(Debug)] +pub struct FeExecuteMessage { + /// max # of rows + pub maxrows: i32, +} + +// we only support unnamed prepared stmt and portal +#[derive(Debug)] +pub struct FeCloseMessage {} + +impl FeMessage { + /// Read one message from the stream. + pub fn read(stream: &mut impl Read) -> anyhow::Result> { + // Each libpq message begins with a message type byte, followed by message length + // If the client closes the connection, return None. But if the client closes the + // connection in the middle of a message, we will return an error. 
+ let tag = match stream.read_u8() { + Ok(b) => b, + Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None), + Err(e) => return Err(e.into()), + }; + let len = stream.read_u32::()?; + + // The message length includes itself, so it better be at least 4 + let bodylen = len.checked_sub(4).ok_or(io::Error::new( + io::ErrorKind::InvalidInput, + "invalid message length: parsing u32", + ))?; + + // Read message body + let mut body_buf: Vec = vec![0; bodylen as usize]; + stream.read_exact(&mut body_buf)?; + + let body = Bytes::from(body_buf); + + // Parse it + match tag { + b'Q' => Ok(Some(FeMessage::Query(FeQueryMessage { body }))), + b'P' => Ok(Some(FeParseMessage::parse(body)?)), + b'D' => Ok(Some(FeDescribeMessage::parse(body)?)), + b'E' => Ok(Some(FeExecuteMessage::parse(body)?)), + b'B' => Ok(Some(FeBindMessage::parse(body)?)), + b'C' => Ok(Some(FeCloseMessage::parse(body)?)), + b'S' => Ok(Some(FeMessage::Sync)), + b'X' => Ok(Some(FeMessage::Terminate)), + b'd' => Ok(Some(FeMessage::CopyData(body))), + b'c' => Ok(Some(FeMessage::CopyDone)), + tag => Err(anyhow!("unknown message tag: {},'{:?}'", tag, body)), + } + } +} + +impl FeStartupMessage { + /// Read startup message from the stream. + pub fn read(stream: &mut impl std::io::Read) -> anyhow::Result> { + const MAX_STARTUP_PACKET_LENGTH: usize = 10000; + const CANCEL_REQUEST_CODE: u32 = (1234 << 16) | 5678; + const NEGOTIATE_SSL_CODE: u32 = (1234 << 16) | 5679; + const NEGOTIATE_GSS_CODE: u32 = (1234 << 16) | 5680; + + // Read length. If the connection is closed before reading anything (or before + // reading 4 bytes, to be precise), return None to indicate that the connection + // was closed. This matches the PostgreSQL server's behavior, which avoids noise + // in the log if the client opens connection but closes it immediately. 
+ let len = match stream.read_u32::() { + Ok(len) => len as usize, + Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None), + Err(e) => return Err(e.into()), + }; + + if len < 4 || len > MAX_STARTUP_PACKET_LENGTH { + bail!("invalid message length"); + } + + let version = stream.read_u32::()?; + let kind = match version { + CANCEL_REQUEST_CODE => StartupRequestCode::Cancel, + NEGOTIATE_SSL_CODE => StartupRequestCode::NegotiateSsl, + NEGOTIATE_GSS_CODE => StartupRequestCode::NegotiateGss, + _ => StartupRequestCode::Normal, + }; + + // the rest of startup packet are params + let params_len = len - 8; + let mut params_bytes = vec![0u8; params_len]; + stream.read_exact(params_bytes.as_mut())?; + + // Then null-terminated (String) pairs of param name / param value go. + let params_str = str::from_utf8(¶ms_bytes).unwrap(); + let params = params_str.split('\0'); + let mut params_hash: HashMap = HashMap::new(); + for pair in params.collect::>().chunks_exact(2) { + let name = pair[0]; + let value = pair[1]; + if name == "options" { + // deprecated way of passing params as cmd line args + for cmdopt in value.split(' ') { + let nameval: Vec<&str> = cmdopt.split('=').collect(); + if nameval.len() == 2 { + params_hash.insert(nameval[0].to_string(), nameval[1].to_string()); + } + } + } else { + params_hash.insert(name.to_string(), value.to_string()); + } + } + + Ok(Some(FeMessage::StartupMessage(FeStartupMessage { + version, + kind, + params: params_hash, + }))) + } +} + +impl FeParseMessage { + pub fn parse(mut buf: Bytes) -> anyhow::Result { + let _pstmt_name = read_null_terminated(&mut buf)?; + let query_string = read_null_terminated(&mut buf)?; + let nparams = buf.get_i16(); + + // FIXME: the rust-postgres driver uses a named prepared statement + // for copy_out(). We're not prepared to handle that correctly. For + // now, just ignore the statement name, assuming that the client never + // uses more than one prepared statement at a time. 
+ /* + if !pstmt_name.is_empty() { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "named prepared statements not implemented in Parse", + )); + } + */ + + if nparams != 0 { + bail!("query params not implemented"); + } + + Ok(FeMessage::Parse(FeParseMessage { query_string })) + } +} + +impl FeDescribeMessage { + pub fn parse(mut buf: Bytes) -> anyhow::Result { + let kind = buf.get_u8(); + let _pstmt_name = read_null_terminated(&mut buf)?; + + // FIXME: see FeParseMessage::parse + /* + if !pstmt_name.is_empty() { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "named prepared statements not implemented in Describe", + )); + } + */ + + if kind != b'S' { + bail!("only prepared statmement Describe is implemented"); + } + + Ok(FeMessage::Describe(FeDescribeMessage { kind })) + } +} + +impl FeExecuteMessage { + pub fn parse(mut buf: Bytes) -> anyhow::Result { + let portal_name = read_null_terminated(&mut buf)?; + let maxrows = buf.get_i32(); + + if !portal_name.is_empty() { + bail!("named portals not implemented"); + } + + if maxrows != 0 { + bail!("row limit in Execute message not supported"); + } + + Ok(FeMessage::Execute(FeExecuteMessage { maxrows })) + } +} + +impl FeBindMessage { + pub fn parse(mut buf: Bytes) -> anyhow::Result { + let portal_name = read_null_terminated(&mut buf)?; + let _pstmt_name = read_null_terminated(&mut buf)?; + + if !portal_name.is_empty() { + bail!("named portals not implemented"); + } + + // FIXME: see FeParseMessage::parse + /* + if !pstmt_name.is_empty() { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "named prepared statements not implemented", + )); + } + */ + + Ok(FeMessage::Bind(FeBindMessage {})) + } +} + +impl FeCloseMessage { + pub fn parse(mut buf: Bytes) -> anyhow::Result { + let _kind = buf.get_u8(); + let _pstmt_or_portal_name = read_null_terminated(&mut buf)?; + + // FIXME: we do nothing with Close + + Ok(FeMessage::Close(FeCloseMessage {})) + } +} + +fn read_null_terminated(buf: 
&mut Bytes) -> anyhow::Result { + let mut result = BytesMut::new(); + + loop { + if !buf.has_remaining() { + bail!("no null-terminator in string"); + } + + let byte = buf.get_u8(); + + if byte == 0 { + break; + } + result.put_u8(byte); + } + Ok(result.freeze()) +} + +// Backend + +#[derive(Debug)] +pub enum BeMessage<'a> { + AuthenticationOk, + BindComplete, + CommandComplete(&'a [u8]), + ControlFile, + CopyData(&'a [u8]), + CopyDone, + CopyInResponse, + CopyOutResponse, + CopyBothResponse, + CloseComplete, + // None means column is NULL + DataRow(&'a [Option<&'a [u8]>]), + ErrorResponse(String), + // see https://www.postgresql.org/docs/devel/protocol-flow.html#id-1.10.5.7.11 + Negotiate, + NoData, + ParameterDescription, + ParameterStatus, + ParseComplete, + ReadyForQuery, + RowDescription(&'a [RowDescriptor<'a>]), + XLogData(XLogDataBody<'a>), +} + +// One row desciption in RowDescription packet. +#[derive(Debug)] +pub struct RowDescriptor<'a> { + pub name: &'a [u8], + pub tableoid: Oid, + pub attnum: i16, + pub typoid: Oid, + pub typlen: i16, + pub typmod: i32, + pub formatcode: i16, +} + +impl Default for RowDescriptor<'_> { + fn default() -> RowDescriptor<'static> { + RowDescriptor { + name: b"", + tableoid: 0, + attnum: 0, + typoid: 0, + typlen: 0, + typmod: 0, + formatcode: 0, + } + } +} + +#[derive(Debug)] +pub struct XLogDataBody<'a> { + pub wal_start: u64, + pub wal_end: u64, + pub timestamp: u64, + pub data: &'a [u8], +} + +pub static HELLO_WORLD_ROW: BeMessage = BeMessage::DataRow(&[Some(b"hello world")]); +pub const TEXT_OID: Oid = 25; +// single text column +pub static SINGLE_COL_ROWDESC: BeMessage = BeMessage::RowDescription(&[RowDescriptor { + name: b"data", + tableoid: 0, + attnum: 0, + typoid: TEXT_OID, + typlen: -1, + typmod: 0, + formatcode: 0, +}]); + +// Safe usize -> i32|i16 conversion, from rust-postgres +trait FromUsize: Sized { + fn from_usize(x: usize) -> Result; +} + +macro_rules! 
from_usize { + ($t:ty) => { + impl FromUsize for $t { + #[inline] + fn from_usize(x: usize) -> io::Result<$t> { + if x > <$t>::max_value() as usize { + Err(io::Error::new( + io::ErrorKind::InvalidInput, + "value too large to transmit", + )) + } else { + Ok(x as $t) + } + } + } + }; +} + +from_usize!(i32); + +/// Call f() to write body of the message and prepend it with 4-byte len as +/// prescribed by the protocol. +fn write_body(buf: &mut BytesMut, f: F) -> io::Result<()> +where + F: FnOnce(&mut BytesMut) -> io::Result<()>, +{ + let base = buf.len(); + buf.extend_from_slice(&[0; 4]); + + f(buf)?; + + let size = i32::from_usize(buf.len() - base)?; + BigEndian::write_i32(&mut buf[base..], size); + Ok(()) +} + +/// Safe write of s into buf as cstring (String in the protocol). +fn write_cstr(s: &[u8], buf: &mut BytesMut) -> Result<(), io::Error> { + if s.contains(&0) { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "string contains embedded null", + )); + } + buf.put_slice(s); + buf.put_u8(0); + Ok(()) +} + +impl<'a> BeMessage<'a> { + /// Write message to the given buf. + // Unlike the reading side, we use BytesMut + // here as msg len preceeds its body and it is handy to write it down first + // and then fill the length. With Write we would have to either calc it + // manually or have one more buffer. 
+ pub fn write(buf: &mut BytesMut, message: &BeMessage) -> io::Result<()> { + match message { + BeMessage::AuthenticationOk => { + buf.put_u8(b'R'); + write_body(buf, |buf| { + buf.put_i32(0); + Ok::<_, io::Error>(()) + }) + .unwrap(); // write into BytesMut can't fail + } + + BeMessage::BindComplete => { + buf.put_u8(b'2'); + write_body(buf, |_| Ok::<(), io::Error>(())).unwrap(); + } + + BeMessage::CloseComplete => { + buf.put_u8(b'3'); + write_body(buf, |_| Ok::<(), io::Error>(())).unwrap(); + } + + BeMessage::CommandComplete(cmd) => { + buf.put_u8(b'C'); + write_body(buf, |buf| { + write_cstr(cmd, buf)?; + Ok::<_, io::Error>(()) + })?; + } + + BeMessage::ControlFile => { + // TODO pass checkpoint and xid info in this message + BeMessage::write(buf, &BeMessage::DataRow(&[Some(b"hello pg_control")]))?; + } + + BeMessage::CopyData(data) => { + buf.put_u8(b'd'); + write_body(buf, |buf| { + buf.put_slice(data); + Ok::<_, io::Error>(()) + }) + .unwrap(); + } + + BeMessage::CopyDone => { + buf.put_u8(b'c'); + write_body(buf, |_| Ok::<(), io::Error>(())).unwrap(); + } + + BeMessage::CopyInResponse => { + buf.put_u8(b'G'); + write_body(buf, |buf| { + buf.put_u8(1); /* copy_is_binary */ + buf.put_i16(0); /* numAttributes */ + Ok::<_, io::Error>(()) + }) + .unwrap(); + } + + BeMessage::CopyOutResponse => { + buf.put_u8(b'H'); + write_body(buf, |buf| { + buf.put_u8(0); /* copy_is_binary */ + buf.put_i16(0); /* numAttributes */ + Ok::<_, io::Error>(()) + }) + .unwrap(); + } + + BeMessage::CopyBothResponse => { + buf.put_u8(b'W'); + write_body(buf, |buf| { + // doesn't matter, used only for replication + buf.put_u8(0); /* copy_is_binary */ + buf.put_i16(0); /* numAttributes */ + Ok::<_, io::Error>(()) + }) + .unwrap(); + } + + BeMessage::DataRow(vals) => { + buf.put_u8(b'D'); + write_body(buf, |buf| { + buf.put_u16(vals.len() as u16); // num of cols + for val_opt in vals.iter() { + if let Some(val) = val_opt { + buf.put_u32(val.len() as u32); + buf.put_slice(val); + } else { 
+ buf.put_i32(-1); + } + } + Ok::<_, io::Error>(()) + }) + .unwrap(); + } + + // ErrorResponse is a zero-terminated array of zero-terminated fields. + // First byte of each field represents type of this field. Set just enough fields + // to satisfy rust-postgres client: 'S' -- severity, 'C' -- error, 'M' -- error + // message text. + BeMessage::ErrorResponse(error_msg) => { + // For all the errors set Severity to Error and error code to + // 'internal error'. + + // 'E' signalizes ErrorResponse messages + buf.put_u8(b'E'); + write_body(buf, |buf| { + buf.put_u8(b'S'); // severity + write_cstr(&Bytes::from("ERROR"), buf)?; + + buf.put_u8(b'C'); // SQLSTATE error code + write_cstr(&Bytes::from("CXX000"), buf)?; + + buf.put_u8(b'M'); // the message + write_cstr(error_msg.as_bytes(), buf)?; + + buf.put_u8(0); // terminator + Ok::<_, io::Error>(()) + }) + .unwrap(); + } + + BeMessage::NoData => { + buf.put_u8(b'n'); + write_body(buf, |_| Ok::<(), io::Error>(())).unwrap(); + } + + BeMessage::Negotiate => { + buf.put_u8(b'N'); + } + + BeMessage::ParameterStatus => { + buf.put_u8(b'S'); + // parameter names and values are specified by null terminated strings + const PARAM_NAME_VALUE: &[u8] = b"client_encoding\0UTF8\0"; + write_body(buf, |buf| { + buf.put_slice(PARAM_NAME_VALUE); + Ok::<_, io::Error>(()) + }) + .unwrap(); + } + + BeMessage::ParameterDescription => { + buf.put_u8(b't'); + write_body(buf, |buf| { + // we don't support params, so always 0 + buf.put_i16(0); + Ok::<_, io::Error>(()) + }) + .unwrap(); + } + + BeMessage::ParseComplete => { + buf.put_u8(b'1'); + write_body(buf, |_| Ok::<(), io::Error>(())).unwrap(); + } + + BeMessage::ReadyForQuery => { + buf.put_u8(b'Z'); + write_body(buf, |buf| { + buf.put_u8(b'I'); + Ok::<_, io::Error>(()) + }) + .unwrap(); + } + + BeMessage::RowDescription(rows) => { + buf.put_u8(b'T'); + write_body(buf, |buf| { + buf.put_i16(rows.len() as i16); // # of fields + for row in rows.iter() { + write_cstr(row.name, buf)?; + 
buf.put_i32(0); /* table oid */ + buf.put_i16(0); /* attnum */ + buf.put_u32(row.typoid); + buf.put_i16(row.typlen); + buf.put_i32(-1); /* typmod */ + buf.put_i16(0); /* format code */ + } + Ok::<_, io::Error>(()) + })?; + } + + BeMessage::XLogData(body) => { + buf.put_u8(b'd'); + write_body(buf, |buf| { + buf.put_u8(b'w'); + buf.put_u64(body.wal_start); + buf.put_u64(body.wal_end); + buf.put_u64(body.timestamp); + buf.put_slice(body.data); + Ok::<_, io::Error>(()) + }) + .unwrap(); + } + } + Ok(()) + } +}