From 513696a485790b5eea39df780840eac99dd04a30 Mon Sep 17 00:00:00 2001 From: Eric Seppanen Date: Wed, 12 May 2021 19:59:34 -0700 Subject: [PATCH] break wal_service into multiple pieces The pieces are: base Connection SendWal ReplicationHandler There are lots of other changes here: - Put the replication reader in a background thread; this gets rid of some hacks with nonblocking mode. - Stop manually buffering input data; use BufReader instead. - Use BytesMut a lot less; use Read/Write traits where possible. --- walkeeper/src/bin/wal_acceptor.rs | 5 +- walkeeper/src/pq_protocol.rs | 70 +-- walkeeper/src/wal_service.rs | 887 ++++++++++++++++-------------- zenith_utils/src/bin_ser.rs | 71 ++- 4 files changed, 565 insertions(+), 468 deletions(-) diff --git a/walkeeper/src/bin/wal_acceptor.rs b/walkeeper/src/bin/wal_acceptor.rs index bbf25a13e3..fed8f7746f 100644 --- a/walkeeper/src/bin/wal_acceptor.rs +++ b/walkeeper/src/bin/wal_acceptor.rs @@ -167,7 +167,10 @@ fn start_wal_acceptor(conf: WalAcceptorConf) -> Result<()> { .name("WAL acceptor thread".into()) .spawn(|| { // thread code - wal_service::thread_main(conf); + let thread_result = wal_service::thread_main(conf); + if let Err(e) = thread_result { + info!("wal_service thread terminated: {}", e); + } }) .unwrap(); threads.push(wal_acceptor_thread); diff --git a/walkeeper/src/pq_protocol.rs b/walkeeper/src/pq_protocol.rs index 25d801f714..0a9ca667e8 100644 --- a/walkeeper/src/pq_protocol.rs +++ b/walkeeper/src/pq_protocol.rs @@ -1,7 +1,7 @@ -use byteorder::{BigEndian, ByteOrder}; -use bytes::{Buf, BufMut, Bytes, BytesMut}; +use byteorder::{BigEndian, ReadBytesExt}; +use bytes::{BufMut, Bytes, BytesMut}; use pageserver::ZTimelineId; -use std::io; +use std::io::{self, Read}; use std::str; use std::str::FromStr; @@ -10,7 +10,6 @@ pub type SystemId = u64; #[derive(Debug)] pub enum FeMessage { - StartupMessage(FeStartupMessage), Query(FeQueryMessage), Terminate, CopyData(FeCopyData), @@ -51,28 +50,22 @@ pub enum StartupRequestCode { } impl FeStartupMessage { - pub fn parse(buf: &mut BytesMut) -> io::Result> { + pub fn read_from(reader: &mut impl Read) -> io::Result { const MAX_STARTUP_PACKET_LENGTH: usize = 10000; const CANCEL_REQUEST_CODE: u32 = (1234 << 16) | 5678; const NEGOTIATE_SSL_CODE: u32 = (1234 << 16) | 5679; const NEGOTIATE_GSS_CODE: u32 = (1234 << 16) | 5680; - if buf.len() < 4 { - return Ok(None); - } - let len = BigEndian::read_u32(&buf[0..4]) as usize; + let len = reader.read_u32::()? as usize; if len < 4 || len > MAX_STARTUP_PACKET_LENGTH { return Err(io::Error::new( io::ErrorKind::InvalidData, - "invalid message length", + "FeStartupMessage: invalid message length", )); } - if buf.len() < len { - return Ok(None); - } - let version = BigEndian::read_u32(&buf[4..8]); + let version = reader.read_u32::()?; let kind = match version { CANCEL_REQUEST_CODE => StartupRequestCode::Cancel, @@ -81,7 +74,11 @@ impl FeStartupMessage { _ => StartupRequestCode::Normal, }; - let params_bytes = &buf[8..len]; + // FIXME: A buffer pool would be nice, to avoid zeroing the buffer. + let params_len = len - 8; + let mut params_bytes = vec![0u8; params_len]; + reader.read_exact(params_bytes.as_mut())?; + let params_str = str::from_utf8(¶ms_bytes).unwrap(); let params = params_str.split('\0'); let mut options = false; @@ -109,13 +106,12 @@ impl FeStartupMessage { )); } - buf.advance(len as usize); - Ok(Some(FeMessage::StartupMessage(FeStartupMessage { + Ok(FeStartupMessage { version, kind, appname, timelineid: timelineid.unwrap(), - }))) + }) } } @@ -201,44 +197,28 @@ impl<'a> BeMessage<'a> { } impl FeMessage { - pub fn parse(buf: &mut BytesMut) -> io::Result> { - if buf.len() < 5 { - let to_read = 5 - buf.len(); - buf.reserve(to_read); - return Ok(None); - } - - let tag = buf[0]; - let len = BigEndian::read_u32(&buf[1..5]); + pub fn read_from(reader: &mut impl Read) -> io::Result { + let tag = reader.read_u8()?; + let len = reader.read_u32::()?; if len < 4 { return Err(io::Error::new( io::ErrorKind::InvalidInput, - "invalid message length: parsing u32", + "FeMessage: invalid message length", )); } - let total_len = len as usize + 1; - if buf.len() < total_len { - let to_read = total_len - buf.len(); - buf.reserve(to_read); - return Ok(None); - } - - let mut body = buf.split_to(total_len); - body.advance(5); + let body_len = (len - 4) as usize; + let mut body = vec![0u8; body_len]; + reader.read_exact(&mut body)?; match tag { - b'Q' => Ok(Some(FeMessage::Query(FeQueryMessage { - body: body.freeze(), - }))), - b'd' => Ok(Some(FeMessage::CopyData(FeCopyData { - body: body.freeze(), - }))), - b'X' => Ok(Some(FeMessage::Terminate)), + b'Q' => Ok(FeMessage::Query(FeQueryMessage { body: body.into() })), + b'd' => Ok(FeMessage::CopyData(FeCopyData { body: body.into() })), + b'X' => Ok(FeMessage::Terminate), tag => Err(io::Error::new( io::ErrorKind::InvalidInput, - format!("unknown message tag: {},'{:?}'", tag, buf), + format!("unknown message tag: {},'{:?}'", tag, body), )), } } diff --git a/walkeeper/src/wal_service.rs b/walkeeper/src/wal_service.rs index f4c64d9ef9..dd3fd77657 100644 --- a/walkeeper/src/wal_service.rs +++ b/walkeeper/src/wal_service.rs @@ -2,21 +2,21 @@ //! WAL service listens for client connections and //! receive WAL from wal_proposer and send it to WAL receivers //! -use anyhow::{anyhow, bail, Result}; -use byteorder::{BigEndian, ByteOrder}; +use anyhow::{bail, Result}; +use byteorder::{BigEndian, ByteOrder, ReadBytesExt}; use bytes::{Buf, BufMut, Bytes, BytesMut}; use fs2::FileExt; use lazy_static::lazy_static; use log::*; use postgres::{Client, NoTls}; -use regex::Regex; use serde::{de::DeserializeOwned, Deserialize, Serialize}; use std::cmp::{max, min}; use std::collections::HashMap; use std::fs::{self, File, OpenOptions}; -use std::io::{self, BufReader, Read, Seek, SeekFrom, Write}; +use std::io::{BufRead, BufReader, Read, Seek, SeekFrom, Write}; use std::mem; use std::net::{SocketAddr, TcpListener, TcpStream}; +use std::path::Path; use std::str; use std::sync::{Arc, Condvar, Mutex}; use std::thread; @@ -27,7 +27,7 @@ use crate::pq_protocol::*; use crate::WalAcceptorConf; use pageserver::ZTimelineId; use postgres_ffi::xlog_utils::{ - find_end_of_wal, get_current_timestamp, TimeLineID, TimestampTz, XLogFileName, XLOG_BLCKSZ, + find_end_of_wal, TimeLineID, TimestampTz, XLogFileName, XLOG_BLCKSZ, }; type FullTransactionId = u64; @@ -44,39 +44,6 @@ const LIBPQ_MSG_SIZE_OFFS: usize = 1; const CONTROL_FILE_NAME: &str = "safekeeper.control"; const END_OF_STREAM: Lsn = Lsn(0); -/// Read some bytes from a type that implements [`Read`] into a [`BytesMut`] -/// -/// Will return the number of bytes read, just like `Read::read()` would. -/// -fn read_into(r: &mut impl Read, buf: &mut BytesMut) -> io::Result { - // This is a workaround, because BytesMut and std::io don't play - // well together. - // - // I think this code needs to go away, and I'm confident that - // that's possible, if we are willing to refactor this code to - // use std::io::BufReader instead of doing buffer management - // ourselves. - // - // SAFETY: we already have exclusive access to self.inbuf, so - // there are no concurrency problems; the only risk would be - // accidentally exposing uninitialized parts of the buffer. - // - // We write into the buffer just past the known-initialized part, - // then manually increment its length by the exact number of - // bytes we read. So no uninitialized memory should be exposed. - - let start = buf.len(); - let end = buf.capacity(); - - let num_bytes = unsafe { - let fill_here = buf.get_unchecked_mut(start..end); - let num_bytes_read = r.read(fill_here)?; - buf.set_len(start + num_bytes_read); - num_bytes_read - }; - Ok(num_bytes) -} - /// Unique node identifier used by Paxos #[derive(Debug, Clone, Copy, Ord, PartialOrd, PartialEq, Eq, Serialize, Deserialize)] struct NodeId { @@ -181,9 +148,37 @@ pub struct Timeline { cond: Condvar, } +// Useful utilities needed by various Connection-like objects +trait TimelineTools { + fn set(&mut self, timeline_id: ZTimelineId) -> Result<()>; + fn get(&self) -> &Arc; + fn find_end_of_wal(&self, data_dir: &Path, precise: bool) -> (Lsn, TimeLineID); +} + +impl TimelineTools for Option> { + fn set(&mut self, timeline_id: ZTimelineId) -> Result<()> { + // We will only set the timeline once. If it were to ever change, + // anyone who cloned the Arc would be out of date. + assert!(self.is_none()); + *self = Some(GlobalTimelines::store(timeline_id)?); + Ok(()) + } + + fn get(&self) -> &Arc { + self.as_ref().unwrap() + } + + /// Find last WAL record. If "precise" is false then just locate last partial segment + fn find_end_of_wal(&self, data_dir: &Path, precise: bool) -> (Lsn, TimeLineID) { + let seg_size = self.get().get_info().server.wal_seg_size as usize; + let (lsn, timeline) = find_end_of_wal(data_dir, seg_size, precise); + (Lsn(lsn), timeline) + } +} + /// Private data #[derive(Debug)] -struct Connection { +pub struct Connection { timeline: Option>, /// Postgres connection, buffered input stream_in: BufReader, @@ -192,13 +187,9 @@ struct Connection { /// The cached result of socket.peer_addr() peer_addr: SocketAddr, /// input buffer - inbuf: BytesMut, + //inbuf: BytesMut, /// output buffer outbuf: BytesMut, - /// startup packet proceeded - init_done: bool, - /// assigned application name - appname: Option, /// wal acceptor configuration conf: WalAcceptorConf, } @@ -214,8 +205,8 @@ trait NewSerializer: Serialize + DeserializeOwned { } fn unpack(buf: &mut BytesMut) -> Self { - let buf_r = buf.reader(); - Self::des_from(buf_r).unwrap() + let mut buf_r = buf.reader(); + Self::des_from(&mut buf_r).unwrap() } } @@ -262,25 +253,35 @@ lazy_static! { Mutex::new(HashMap::new()); } -pub fn thread_main(conf: WalAcceptorConf) { - info!("Starting wal acceptor on {}", conf.listen_addr); - main_loop(&conf).unwrap(); -} +/// A zero-sized struct used to manage access to the global timelines map. +struct GlobalTimelines; -/// This is run by main_loop, inside a background thread. -/// -/// This is only a separate function to make a convenient place to collect -/// all errors for logging. Our caller can log errors in a single place. -fn handle_socket(socket: TcpStream, conf: WalAcceptorConf) -> Result<()> { - socket.set_nodelay(true)?; - let mut conn = Connection::new(socket, conf)?; - conn.run()?; - Ok(()) +impl GlobalTimelines { + /// Store a new timeline into the global TIMELINES map. + fn store(timeline_id: ZTimelineId) -> Result> { + let mut timelines = TIMELINES.lock().unwrap(); + + match timelines.get(&timeline_id) { + Some(result) => Ok(Arc::clone(result)), + None => { + info!("creating timeline dir {}", timeline_id); + fs::create_dir_all(timeline_id.to_string())?; + let new_tid = Arc::new(Timeline::new(timeline_id)); + timelines.insert(timeline_id, Arc::clone(&new_tid)); + Ok(new_tid) + } + } + } } /// Accept incoming TCP connections and spawn them into a background thread. -fn main_loop(conf: &WalAcceptorConf) -> Result<()> { - let listener = TcpListener::bind(conf.listen_addr)?; +pub fn thread_main(conf: WalAcceptorConf) -> Result<()> { + info!("Starting wal acceptor on {}", conf.listen_addr); + let listener = TcpListener::bind(conf.listen_addr).map_err(|e| { + error!("failed to bind to address {}: {}", conf.listen_addr, e); + e + })?; + loop { match listener.accept() { Ok((socket, peer_addr)) => { @@ -297,6 +298,17 @@ fn main_loop(conf: &WalAcceptorConf) -> Result<()> { } } +/// This is run by main_loop, inside a background thread. +/// +/// This is only a separate function to make a convenient place to collect +/// all errors for logging. Our caller can log errors in a single place. +fn handle_socket(socket: TcpStream, conf: WalAcceptorConf) -> Result<()> { + socket.set_nodelay(true)?; + let conn = Connection::new(socket, conf)?; + conn.run()?; + Ok(()) +} + impl Timeline { pub fn new(timelineid: ZTimelineId) -> Timeline { let shared_state = SharedState { @@ -447,36 +459,38 @@ impl Connection { stream_in: BufReader::new(socket.try_clone()?), stream_out: socket, peer_addr, - inbuf: BytesMut::with_capacity(10 * 1024), outbuf: BytesMut::with_capacity(10 * 1024), - init_done: false, - appname: None, conf, }; Ok(conn) } - fn timeline(&self) -> Arc { - self.timeline.as_ref().unwrap().clone() - } - - fn run(&mut self) -> Result<()> { - self.inbuf.resize(4, 0u8); - self.stream_in.read_exact(&mut self.inbuf[0..4])?; - let startup_pkg_len = BigEndian::read_u32(&self.inbuf[0..4]); + fn run(mut self) -> Result<()> { + // Peek at the first 4 bytes of the incoming data, to determine which protocol + // is being spoken. + // `fill_buf` does not consume any of the bytes we peek at; they are left + // in the BufReader's internal buffer for the next reader. + let peek_buf = self.stream_in.fill_buf()?; + if peek_buf.len() < 4 { + // Empty peek_buf means the socket was closed. + // Less than 4 bytes doesn't seem likely unless the sender is malicious. + // read_u32 would panic on any of these, so just return an error. + bail!("fill_buf EOF or underrun"); + } + let startup_pkg_len = BigEndian::read_u32(peek_buf); if startup_pkg_len == 0 { + // Consume the 4 bytes we peeked at. This protocol begins after them. + self.stream_in.read_u32::()?; self.receive_wal()?; // internal protocol between wal_proposer and wal_acceptor } else { - self.send_wal()?; // libpq replication protocol between wal_acceptor and replicas/pagers + send_wal::SendWal::new(self).run()?; // libpq replication protocol between wal_acceptor and replicas/pagers } Ok(()) } fn read_req(&mut self) -> Result { - let size = mem::size_of::(); - self.inbuf.resize(size, 0u8); - self.stream_in.read_exact(&mut self.inbuf[0..size])?; - Ok(T::unpack(&mut self.inbuf)) + // NewSerializer is always little-endian. + Ok(T::des_from(&mut self.stream_in)?) } fn request_callback(&self) -> std::result::Result<(), postgres::error::Error> { @@ -490,10 +504,10 @@ impl Connection { ); let callme = format!( "callmemaybe {} host={} port={} options='-c ztimelineid={}'", - self.timeline().timelineid, + self.timeline.get().timelineid, self.conf.listen_addr.ip(), self.conf.listen_addr.port(), - self.timeline().timelineid + self.timeline.get().timelineid ); info!( "requesting page server to connect to us: start {} {}", @@ -505,17 +519,6 @@ impl Connection { Ok(()) } - fn set_timeline(&mut self, timelineid: ZTimelineId) -> Result<()> { - let mut timelines = TIMELINES.lock().unwrap(); - if !timelines.contains_key(&timelineid) { - info!("creating timeline dir {}", timelineid); - fs::create_dir_all(timelineid.to_string())?; - timelines.insert(timelineid, Arc::new(Timeline::new(timelineid))); - } - self.timeline = Some(timelines.get(&timelineid).unwrap().clone()); - Ok(()) - } - /// Receive WAL from wal_proposer fn receive_wal(&mut self) -> Result<()> { // Receive information about server @@ -525,15 +528,15 @@ impl Connection { self.peer_addr, server_info.system_id, server_info.timeline_id, ); // FIXME: also check that the system identifier matches - self.set_timeline(server_info.timeline_id)?; - self.timeline().load_control_file(&self.conf)?; + self.timeline.set(server_info.timeline_id)?; + self.timeline.get().load_control_file(&self.conf)?; - let mut my_info = self.timeline().get_info(); + let mut my_info = self.timeline.get().get_info(); /* Check protocol compatibility */ if server_info.protocol_version != SK_PROTOCOL_VERSION { bail!( - "Incompatible protocol version {} vs. {}", + "Incompatible protocol version {}, expected {}", server_info.protocol_version, SK_PROTOCOL_VERSION ); @@ -543,17 +546,18 @@ impl Connection { && my_info.server.pg_version != UNKNOWN_SERVER_VERSION { info!( - "Server version doesn't match {} vs. {}", + "Incompatible server version {}, expected {}", server_info.pg_version, my_info.server.pg_version ); } + /* Update information about server, but preserve locally stored node_id */ let node_id = my_info.server.node_id; my_info.server = server_info; my_info.server.node_id = node_id; /* Calculate WAL end based on local data */ - let (flush_lsn, timeline) = self.find_end_of_wal(true); + let (flush_lsn, timeline) = self.timeline.find_end_of_wal(&self.conf.data_dir, true); my_info.flush_lsn = flush_lsn; my_info.server.timeline = timeline; @@ -577,9 +581,9 @@ impl Connection { ); } my_info.server.node_id = prop.node_id; - self.timeline().set_info(&my_info); + self.timeline.get().set_info(&my_info); /* Need to persist our vote first */ - self.timeline().save_control_file(true)?; + self.timeline.get().save_control_file(true)?; let mut flushed_restart_lsn = Lsn(0); let wal_seg_size = server_info.wal_seg_size as usize; @@ -626,11 +630,11 @@ impl Connection { ); /* Receive message body */ - self.inbuf.resize(rec_size, 0u8); - self.stream_in.read_exact(&mut self.inbuf[0..rec_size])?; + let mut inbuf = vec![0u8; rec_size]; + self.stream_in.read_exact(&mut inbuf)?; /* Save message in file */ - self.write_wal_file(start_pos, timeline, wal_seg_size, &self.inbuf[0..rec_size])?; + self.write_wal_file(start_pos, timeline, wal_seg_size, &inbuf)?; my_info.restart_lsn = req.restart_lsn; my_info.commit_lsn = req.commit_lsn; @@ -655,7 +659,7 @@ impl Connection { * when restart_lsn delta exceeds WAL segment size. */ sync_control_file |= flushed_restart_lsn + (wal_seg_size as u64) < my_info.restart_lsn; - self.timeline().save_control_file(sync_control_file)?; + self.timeline.get().save_control_file(sync_control_file)?; if sync_control_file { flushed_restart_lsn = my_info.restart_lsn; @@ -666,7 +670,7 @@ impl Connection { let resp = SafeKeeperResponse { epoch: my_info.epoch, flush_lsn: end_pos, - hs_feedback: self.timeline().get_hs_feedback(), + hs_feedback: self.timeline.get().get_hs_feedback(), }; self.start_sending(); resp.pack(&mut self.outbuf); @@ -676,43 +680,13 @@ impl Connection { * Ping wal sender that new data is available. * FlushLSN (end_pos) can be smaller than commitLSN in case we are at catching-up safekeeper. */ - self.timeline() + self.timeline + .get() .notify_wal_senders(min(req.commit_lsn, end_pos)); } Ok(()) } - /// - /// Read full message or return None if connection is closed - /// - fn read_message(&mut self) -> Result> { - loop { - if let Some(message) = self.parse_message()? { - return Ok(Some(message)); - } - - if read_into(&mut self.stream_in, &mut self.inbuf)? == 0 { - if self.inbuf.is_empty() { - return Ok(None); - } else { - bail!("connection reset by peer"); - } - } - } - } - - /// - /// Parse libpq message - /// - fn parse_message(&mut self) -> Result> { - let msg = if !self.init_done { - FeStartupMessage::parse(&mut self.inbuf)? - } else { - FeMessage::parse(&mut self.inbuf)? - }; - Ok(msg) - } - /// /// Reset output buffer to start accumulating data of new message /// @@ -727,270 +701,6 @@ impl Connection { Ok(self.stream_out.write_all(&self.outbuf)?) } - /// - /// Send WAL to replica or WAL receiver using standard libpq replication protocol - /// - fn send_wal(&mut self) -> Result<()> { - info!("WAL sender to {:?} is started", self.peer_addr); - loop { - self.start_sending(); - match self.read_message()? { - Some(FeMessage::StartupMessage(m)) => { - trace!("got message {:?}", m); - - match m.kind { - StartupRequestCode::NegotiateGss | StartupRequestCode::NegotiateSsl => { - BeMessage::write(&mut self.outbuf, &BeMessage::Negotiate); - info!("SSL requested"); - self.send()?; - } - StartupRequestCode::Normal => { - BeMessage::write(&mut self.outbuf, &BeMessage::AuthenticationOk); - BeMessage::write(&mut self.outbuf, &BeMessage::ReadyForQuery); - self.send()?; - self.init_done = true; - self.set_timeline(m.timelineid)?; - self.appname = m.appname; - } - StartupRequestCode::Cancel => return Ok(()), - } - } - Some(FeMessage::Query(m)) => { - if !self.process_query(&m)? { - break; - } - } - Some(FeMessage::Terminate) => { - break; - } - None => { - info!("connection closed"); - break; - } - _ => { - bail!("unexpected message"); - } - } - } - info!("WAL sender to {:?} is finished", self.peer_addr); - Ok(()) - } - - /// - /// Handle IDENTIFY_SYSTEM replication command - /// - fn handle_identify_system(&mut self) -> Result { - let (start_pos, timeline) = self.find_end_of_wal(false); - let lsn = start_pos.to_string(); - let tli = timeline.to_string(); - let sysid = self.timeline().get_info().server.system_id.to_string(); - let lsn_bytes = lsn.as_bytes(); - let tli_bytes = tli.as_bytes(); - let sysid_bytes = sysid.as_bytes(); - - BeMessage::write( - &mut self.outbuf, - &BeMessage::RowDescription(&[ - RowDescriptor { - name: b"systemid\0", - typoid: 25, - typlen: -1, - }, - RowDescriptor { - name: b"timeline\0", - typoid: 23, - typlen: 4, - }, - RowDescriptor { - name: b"xlogpos\0", - typoid: 25, - typlen: -1, - }, - RowDescriptor { - name: b"dbname\0", - typoid: 25, - typlen: -1, - }, - ]), - ); - BeMessage::write( - &mut self.outbuf, - &BeMessage::DataRow(&[Some(sysid_bytes), Some(tli_bytes), Some(lsn_bytes), None]), - ); - BeMessage::write( - &mut self.outbuf, - &BeMessage::CommandComplete(b"IDENTIFY_SYSTEM\0"), - ); - BeMessage::write(&mut self.outbuf, &BeMessage::ReadyForQuery); - self.send()?; - Ok(true) - } - - /// - /// Handle START_REPLICATION replication command - /// - fn handle_start_replication(&mut self, cmd: &Bytes) -> Result { - // helper function to encapsulate the regex -> Lsn magic - fn get_start_stop(cmd: &[u8]) -> Result<(Lsn, Lsn)> { - let re = Regex::new(r"([[:xdigit:]]+/[[:xdigit:]]+)").unwrap(); - let caps = re.captures_iter(str::from_utf8(&cmd[..])?); - let mut lsns = caps.map(|cap| cap[1].parse::()); - let start_pos = lsns - .next() - .ok_or_else(|| anyhow!("failed to find start LSN"))??; - let stop_pos = lsns.next().transpose()?.unwrap_or(Lsn(0)); - Ok((start_pos, stop_pos)) - } - - let (mut start_pos, mut stop_pos) = get_start_stop(&cmd)?; - - let wal_seg_size = self.timeline().get_info().server.wal_seg_size as usize; - if wal_seg_size == 0 { - bail!("Can not start replication before connecting to wal_proposer"); - } - let (wal_end, timeline) = self.find_end_of_wal(false); - if start_pos == Lsn(0) { - start_pos = wal_end; - } - if stop_pos == Lsn(0) && self.appname == Some("wal_proposer_recovery".to_string()) { - stop_pos = wal_end; - } - info!("Start replication from {} till {}", start_pos, stop_pos); - BeMessage::write(&mut self.outbuf, &BeMessage::Copy); - self.send()?; - - let mut end_pos: Lsn; - let mut commit_lsn: Lsn; - let mut wal_file: Option = None; - self.outbuf - .resize(LIBPQ_HDR_SIZE + XLOG_HDR_SIZE + MAX_SEND_SIZE, 0u8); - loop { - /* Wait until we have some data to stream */ - if stop_pos != Lsn(0) { - /* recovery mode: stream up to the specified LSN (VCL) */ - if start_pos >= stop_pos { - /* recovery finished */ - break; - } - end_pos = stop_pos; - } else { - /* normal mode */ - let timeline = self.timeline(); - let mut shared_state = timeline.mutex.lock().unwrap(); - loop { - commit_lsn = shared_state.commit_lsn; - if start_pos < commit_lsn { - end_pos = commit_lsn; - break; - } - shared_state = timeline.cond.wait(shared_state).unwrap(); - } - } - if end_pos == END_REPLICATION_MARKER { - break; - } - // Try to fetch replica's feedback - - // Temporarily set this stream into nonblocking mode. - // FIXME: This seems like a dirty hack. - // Should this task be done on a background thread? - // FIXME: set_nonblocking plus BufReader seems questionable. - self.stream_in.get_ref().set_nonblocking(true).unwrap(); - let read_result = self.stream_in.read(&mut self.inbuf); - self.stream_in.get_ref().set_nonblocking(false).unwrap(); - - match read_result { - Ok(0) => break, - Ok(_) => { - if let Some(FeMessage::CopyData(m)) = self.parse_message()? { - self.timeline() - .add_hs_feedback(HotStandbyFeedback::parse(&m.body)) - } - } - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {} - Err(e) => { - return Err(e.into()); - } - } - - /* Open file if not opened yet */ - let curr_file = wal_file.take(); - let mut file: File; - if let Some(opened_file) = curr_file { - file = opened_file; - } else { - let segno = start_pos.segment_number(wal_seg_size as u64); - let wal_file_name = XLogFileName(timeline, segno, wal_seg_size); - let wal_file_path = self - .conf - .data_dir - .join(self.timeline().timelineid.to_string()) - .join(wal_file_name.clone() + ".partial"); - if let Ok(opened_file) = File::open(&wal_file_path) { - file = opened_file; - } else { - let wal_file_path = self - .conf - .data_dir - .join(self.timeline().timelineid.to_string()) - .join(wal_file_name); - match File::open(&wal_file_path) { - Ok(opened_file) => file = opened_file, - Err(e) => { - error!("Failed to open log file {:?}: {}", &wal_file_path, e); - return Err(e.into()); - } - } - } - } - let xlogoff = start_pos.segment_offset(wal_seg_size as u64) as usize; - - // How much to read and send in message? We cannot cross the WAL file - // boundary, and we don't want send more than MAX_SEND_SIZE. - let send_size = end_pos.checked_sub(start_pos).unwrap().0 as usize; - let send_size = min(send_size, wal_seg_size - xlogoff); - let send_size = min(send_size, MAX_SEND_SIZE); - - let msg_size = LIBPQ_HDR_SIZE + XLOG_HDR_SIZE + send_size; - let data_start = LIBPQ_HDR_SIZE + XLOG_HDR_SIZE; - let data_end = data_start + send_size; - - file.seek(SeekFrom::Start(xlogoff as u64))?; - file.read_exact(&mut self.outbuf[data_start..data_end])?; - self.outbuf[0] = b'd'; - BigEndian::write_u32( - &mut self.outbuf[1..5], - (msg_size - LIBPQ_MSG_SIZE_OFFS) as u32, - ); - self.outbuf[5] = b'w'; - BigEndian::write_u64(&mut self.outbuf[6..14], start_pos.0); - BigEndian::write_u64(&mut self.outbuf[14..22], end_pos.0); - BigEndian::write_u64(&mut self.outbuf[22..30], get_current_timestamp()); - - self.stream_out.write_all(&self.outbuf[0..msg_size])?; - start_pos += send_size as u64; - - debug!("Sent WAL to page server up to {}", end_pos); - - if start_pos.segment_offset(wal_seg_size as u64) != 0 { - wal_file = Some(file); - } - } - Ok(false) - } - - fn process_query(&mut self, q: &FeQueryMessage) -> Result { - trace!("got query {:?}", q.body); - - if q.body.starts_with(b"IDENTIFY_SYSTEM") { - self.handle_identify_system() - } else if q.body.starts_with(b"START_REPLICATION") { - self.handle_start_replication(&q.body) - } else { - bail!("Unexpected command {:?}", q.body); - } - } - fn write_wal_file( &self, startpos: Lsn, @@ -1026,12 +736,12 @@ impl Connection { let wal_file_path = self .conf .data_dir - .join(self.timeline().timelineid.to_string()) + .join(self.timeline.get().timelineid.to_string()) .join(wal_file_name.clone()); let wal_file_partial_path = self .conf .data_dir - .join(self.timeline().timelineid.to_string()) + .join(self.timeline.get().timelineid.to_string()) .join(wal_file_name.clone() + ".partial"); { @@ -1089,14 +799,373 @@ impl Connection { } Ok(()) } +} - /// Find last WAL record. If "precise" is false then just locate last partial segment - fn find_end_of_wal(&self, precise: bool) -> (Lsn, TimeLineID) { - let (lsn, timeline) = find_end_of_wal( - &self.conf.data_dir, - self.timeline().get_info().server.wal_seg_size as usize, - precise, - ); - (Lsn(lsn), timeline) +mod send_wal { + use super::{ + Connection, HotStandbyFeedback, Timeline, TimelineTools, END_REPLICATION_MARKER, + LIBPQ_HDR_SIZE, LIBPQ_MSG_SIZE_OFFS, MAX_SEND_SIZE, XLOG_HDR_SIZE, + }; + use crate::pq_protocol::{ + BeMessage, FeMessage, FeStartupMessage, RowDescriptor, StartupRequestCode, + }; + use crate::WalAcceptorConf; + use anyhow::{anyhow, bail, Result}; + use bytes::{BufMut, Bytes, BytesMut}; + use log::*; + use postgres_ffi::xlog_utils::{get_current_timestamp, XLogFileName}; + use regex::Regex; + use std::cmp::min; + use std::fs::File; + use std::io::{BufReader, Read, Seek, SeekFrom, Write}; + use std::net::{SocketAddr, TcpStream}; + use std::path::Path; + use std::sync::{Arc, Mutex}; + use std::{str, thread}; + use zenith_utils::lsn::Lsn; + + pub struct SendWal { + timeline: Option>, + /// Postgres connection, buffered input + stream_in: BufReader, + /// Postgres connection, output FIXME: To buffer, or not to buffer? flush() is a pain. + stream_out: TcpStream, + /// The cached result of socket.peer_addr() + peer_addr: SocketAddr, + /// wal acceptor configuration + conf: WalAcceptorConf, + /// assigned application name + appname: Option, + } + + impl SendWal { + /// Create a new `SendWal`, consuming the `Connection`. + pub fn new(conn: Connection) -> Self { + Self { + timeline: conn.timeline, + stream_in: conn.stream_in, + stream_out: conn.stream_out, + peer_addr: conn.peer_addr, + conf: conn.conf, + appname: None, + } + } + + /// + /// Send WAL to replica or WAL receiver using standard libpq replication protocol + /// + pub fn run(mut self) -> Result<()> { + let peer_addr = self.peer_addr.clone(); + info!("WAL sender to {:?} is started", peer_addr); + + // Handle the startup message first. + + let m = FeStartupMessage::read_from(&mut self.stream_in)?; + trace!("got startup message {:?}", m); + match m.kind { + StartupRequestCode::NegotiateGss | StartupRequestCode::NegotiateSsl => { + let mut buf = BytesMut::new(); + BeMessage::write(&mut buf, &BeMessage::Negotiate); + info!("SSL requested"); + self.stream_out.write_all(&buf)?; + } + StartupRequestCode::Normal => { + let mut buf = BytesMut::new(); + BeMessage::write(&mut buf, &BeMessage::AuthenticationOk); + BeMessage::write(&mut buf, &BeMessage::ReadyForQuery); + self.stream_out.write_all(&buf)?; + self.timeline.set(m.timelineid)?; + self.appname = m.appname; + } + StartupRequestCode::Cancel => return Ok(()), + } + + loop { + let msg = FeMessage::read_from(&mut self.stream_in)?; + match msg { + FeMessage::Query(q) => { + trace!("got query {:?}", q.body); + + if q.body.starts_with(b"IDENTIFY_SYSTEM") { + self.handle_identify_system()?; + } else if q.body.starts_with(b"START_REPLICATION") { + // Create a new replication object, consuming `self`. + let mut replication = ReplicationHandler::new(self); + replication.run(&q.body)?; + break; + } else { + bail!("Unexpected command {:?}", q.body); + } + } + FeMessage::Terminate => { + break; + } + _ => { + bail!("unexpected message"); + } + } + } + info!("WAL sender to {:?} is finished", peer_addr); + Ok(()) + } + + /// + /// Handle IDENTIFY_SYSTEM replication command + /// + fn handle_identify_system(&mut self) -> Result<()> { + let (start_pos, timeline) = self.timeline.find_end_of_wal(&self.conf.data_dir, false); + let lsn = start_pos.to_string(); + let tli = timeline.to_string(); + let sysid = self.timeline.get().get_info().server.system_id.to_string(); + let lsn_bytes = lsn.as_bytes(); + let tli_bytes = tli.as_bytes(); + let sysid_bytes = sysid.as_bytes(); + + let mut outbuf = BytesMut::new(); + BeMessage::write( + &mut outbuf, + &BeMessage::RowDescription(&[ + RowDescriptor { + name: b"systemid\0", + typoid: 25, + typlen: -1, + }, + RowDescriptor { + name: b"timeline\0", + typoid: 23, + typlen: 4, + }, + RowDescriptor { + name: b"xlogpos\0", + typoid: 25, + typlen: -1, + }, + RowDescriptor { + name: b"dbname\0", + typoid: 25, + typlen: -1, + }, + ]), + ); + BeMessage::write( + &mut outbuf, + &BeMessage::DataRow(&[Some(sysid_bytes), Some(tli_bytes), Some(lsn_bytes), None]), + ); + BeMessage::write( + &mut outbuf, + &BeMessage::CommandComplete(b"IDENTIFY_SYSTEM\0"), + ); + BeMessage::write(&mut outbuf, &BeMessage::ReadyForQuery); + self.stream_out.write_all(&outbuf)?; + Ok(()) + } + } + + pub struct ReplicationHandler { + timeline: Option>, + /// Postgres connection, buffered input + stream_in: Option>, + /// Postgres connection, output FIXME: To buffer, or not to buffer? flush() is a pain. + stream_out: Mutex, + /// wal acceptor configuration + conf: WalAcceptorConf, + /// assigned application name + appname: Option, + } + + impl ReplicationHandler { + /// Create a new `SendWal`, consuming the `Connection`. + pub fn new(conn: SendWal) -> Self { + Self { + timeline: conn.timeline, + stream_in: Some(conn.stream_in), + stream_out: Mutex::new(conn.stream_out), + conf: conn.conf, + appname: None, + } + } + + /// Handle incoming messages from the network. + /// + /// This is spawned into the background by `handle_start_replication`. + /// + fn background_thread(mut stream_in: impl Read, timeline: Arc) -> Result<()> { + // Wait for replica's feedback. + // We only handle `CopyData` messages. Anything else is ignored. + + loop { + match FeMessage::read_from(&mut stream_in)? { + FeMessage::CopyData(m) => { + timeline.add_hs_feedback(HotStandbyFeedback::parse(&m.body)) + } + msg => { + info!("unexpected message {:?}", msg); + } + } + } + } + + /// Helper function that parses a pair of LSNs. + fn parse_start_stop(cmd: &[u8]) -> Result<(Lsn, Lsn)> { + let re = Regex::new(r"([[:xdigit:]]+/[[:xdigit:]]+)").unwrap(); + let caps = re.captures_iter(str::from_utf8(&cmd[..])?); + let mut lsns = caps.map(|cap| cap[1].parse::()); + let start_pos = lsns + .next() + .ok_or_else(|| anyhow!("failed to find start LSN"))??; + let stop_pos = lsns.next().transpose()?.unwrap_or(Lsn(0)); + Ok((start_pos, stop_pos)) + } + + /// Helper function for opening a wal file. + fn open_wal_file(wal_file_path: &Path) -> Result { + // First try to open the .partial file. + let mut partial_path = wal_file_path.to_owned(); + partial_path.set_extension("partial"); + if let Ok(opened_file) = File::open(&partial_path) { + return Ok(opened_file); + } + + // If that failed, try it without the .partial extension. + match File::open(&wal_file_path) { + Ok(opened_file) => return Ok(opened_file), + Err(e) => { + error!("Failed to open log file {:?}: {}", &wal_file_path, e); + return Err(e.into()); + } + } + } + + /// + /// Handle START_REPLICATION replication command + /// + fn run(&mut self, cmd: &Bytes) -> Result<()> { + // spawn the background thread which receives HotStandbyFeedback messages. + let bg_timeline = Arc::clone(self.timeline.get()); + let bg_stream_in = self.stream_in.take().unwrap(); + + thread::spawn(move || { + if let Err(err) = Self::background_thread(bg_stream_in, bg_timeline) { + error!("socket error: {}", err); + } + }); + + let (mut start_pos, mut stop_pos) = Self::parse_start_stop(&cmd)?; + + let wal_seg_size = self.timeline.get().get_info().server.wal_seg_size as usize; + if wal_seg_size == 0 { + bail!("Can not start replication before connecting to wal_proposer"); + } + let (wal_end, timeline) = self.timeline.find_end_of_wal(&self.conf.data_dir, false); + if start_pos == Lsn(0) { + start_pos = wal_end; + } + if stop_pos == Lsn(0) && self.appname == Some("wal_proposer_recovery".to_string()) { + stop_pos = wal_end; + } + info!("Start replication from {} till {}", start_pos, stop_pos); + + let mut outbuf = BytesMut::new(); + BeMessage::write(&mut outbuf, &BeMessage::Copy); + self.send(&outbuf)?; + outbuf.clear(); + + let mut end_pos: Lsn; + let mut commit_lsn: Lsn; + let mut wal_file: Option = None; + + loop { + /* Wait until we have some data to stream */ + if stop_pos != Lsn(0) { + /* recovery mode: stream up to the specified LSN (VCL) */ + if start_pos >= stop_pos { + /* recovery finished */ + break; + } + end_pos = stop_pos; + } else { + /* normal mode */ + let timeline = self.timeline.get(); + let mut shared_state = timeline.mutex.lock().unwrap(); + loop { + commit_lsn = shared_state.commit_lsn; + if start_pos < commit_lsn { + end_pos = commit_lsn; + break; + } + shared_state = timeline.cond.wait(shared_state).unwrap(); + } + } + if end_pos == END_REPLICATION_MARKER { + break; + } + + // Take the `File` from `wal_file`, or open a new file. + let mut file = match wal_file.take() { + Some(file) => file, + None => { + // Open a new file. + let segno = start_pos.segment_number(wal_seg_size as u64); + let wal_file_name = XLogFileName(timeline, segno, wal_seg_size); + let timeline_id = self.timeline.get().timelineid.to_string(); + let wal_file_path = + self.conf.data_dir.join(timeline_id).join(wal_file_name); + Self::open_wal_file(&wal_file_path)? + } + }; + + let xlogoff = start_pos.segment_offset(wal_seg_size as u64) as usize; + + // How much to read and send in message? We cannot cross the WAL file + // boundary, and we don't want send more than MAX_SEND_SIZE. + let send_size = end_pos.checked_sub(start_pos).unwrap().0 as usize; + let send_size = min(send_size, wal_seg_size - xlogoff); + let send_size = min(send_size, MAX_SEND_SIZE); + + let msg_size = LIBPQ_HDR_SIZE + XLOG_HDR_SIZE + send_size; + + // Read some data from the file. + let mut file_buf = vec![0u8; send_size]; + file.seek(SeekFrom::Start(xlogoff as u64))?; + file.read_exact(&mut file_buf)?; + + // Write some data to the network socket. + // FIXME: turn these into structs. + // 'd' is CopyData; + // 'w' is "WAL records" + // https://www.postgresql.org/docs/9.1/protocol-message-formats.html + // src/backend/replication/walreceiver.c + outbuf.clear(); + outbuf.put_u8(b'd'); + outbuf.put_u32((msg_size - LIBPQ_MSG_SIZE_OFFS) as u32); + outbuf.put_u8(b'w'); + outbuf.put_u64(start_pos.0); + outbuf.put_u64(end_pos.0); + outbuf.put_u64(get_current_timestamp()); + + assert!(outbuf.len() + file_buf.len() == msg_size); + // FIXME: combine these two into a single send, + // so that no other traffic can be sent in between them. + self.send(&outbuf)?; + self.send(&file_buf)?; + start_pos += send_size as u64; + + debug!("Sent WAL to page server up to {}", end_pos); + + // Decide whether to reuse this file. If we don't set wal_file here + // a new file will be opened next time. + if start_pos.segment_offset(wal_seg_size as u64) != 0 { + wal_file = Some(file); + } + } + Ok(()) + } + + /// Unlock the mutex and send bytes on the network. + fn send(&self, buf: &[u8]) -> Result<()> { + let mut writer = self.stream_out.lock().unwrap(); + writer.write_all(buf.as_ref())?; + Ok(()) + } } } diff --git a/zenith_utils/src/bin_ser.rs b/zenith_utils/src/bin_ser.rs index fb1c56b44c..ad0ae38cd9 100644 --- a/zenith_utils/src/bin_ser.rs +++ b/zenith_utils/src/bin_ser.rs @@ -93,9 +93,7 @@ pub trait BeSer: Serialize + DeserializeOwned { } /// Deserialize from a reader - /// - /// tip: `&[u8]` implements `Read` - fn des_from(r: R) -> Result { + fn des_from(r: &mut R) -> Result { be_coder().deserialize_from(r).or(Err(DeserializeError)) } } @@ -128,11 +126,56 @@ pub trait LeSer: Serialize + DeserializeOwned { le_coder().deserialize(buf).or(Err(DeserializeError)) } + /// Deserialize from a reader + fn des_from(r: &mut R) -> Result { + le_coder().deserialize_from(r).or(Err(DeserializeError)) + } +} + +/// Binary serialize/deserialize helper functions (Big Endian) +/// +/// This version panics on every serialization/deserialization error. +/// That can be useful if you want to see a backtrace to find where the +/// error occurred. +pub trait LeSerPanic: Serialize + DeserializeOwned { + /// Serialize into a byte slice + fn ser_into_slice(&self, b: &mut [u8]) -> Result<(), SerializeError> { + // This is slightly awkward; we need a mutable reference to a mutable reference. + let mut w = b; + self.ser_into(&mut w) + } + + /// Serialize into a borrowed writer + /// + /// This is useful for most `Write` types except `&mut [u8]`, which + /// can more easily use [`ser_into_slice`](Self::ser_into_slice). + fn ser_into(&self, w: &mut W) -> Result<(), SerializeError> { + le_coder() + .serialize_into(w, &self) + .or_else(|e| panic!("ser_into failed: {}", e)) + } + + /// Serialize into a new heap-allocated buffer + fn ser(&self) -> Result, SerializeError> { + le_coder() + .serialize(&self) + .or_else(|e| panic!("ser failed: {}", e)) + } + + /// Deserialize from a byte slice + fn des(buf: &[u8]) -> Result { + le_coder() + .deserialize(buf) + .or_else(|e| panic!("des failed: {}", e)) + } + /// Deserialize from a reader /// /// tip: `&[u8]` implements `Read` - fn des_from(r: R) -> Result { - le_coder().deserialize_from(r).or(Err(DeserializeError)) + fn des_from(r: &mut R) -> Result { + le_coder() + .deserialize_from(r) + .or_else(|e| panic!("des_from failed: {}", e)) } } @@ -140,6 +183,8 @@ impl BeSer for T where T: Serialize + DeserializeOwned {} impl LeSer for T where T: Serialize + DeserializeOwned {} +impl LeSerPanic for T where T: Serialize + DeserializeOwned {} + #[cfg(test)] mod tests { use serde::{Deserialize, Serialize}; @@ -205,13 +250,13 @@ mod tests { assert_eq!(buf.into_inner(), SHORT1_ENC_BE_TRAILING); // deserialize from a `Write` sink. - let buf = Cursor::new(SHORT2_ENC_BE); - let decoded = ShortStruct::des_from(buf).unwrap(); + let mut buf = Cursor::new(SHORT2_ENC_BE); + let decoded = ShortStruct::des_from(&mut buf).unwrap(); assert_eq!(decoded, SHORT2); // deserialize from a `Write` sink that terminates early. - let buf = Cursor::new([0u8; 4]); - ShortStruct::des_from(buf).unwrap_err(); + let mut buf = Cursor::new([0u8; 4]); + ShortStruct::des_from(&mut buf).unwrap_err(); } #[test] @@ -234,13 +279,13 @@ mod tests { assert_eq!(buf.into_inner(), SHORT1_ENC_LE_TRAILING); // deserialize from a `Write` sink. - let buf = Cursor::new(SHORT2_ENC_LE); - let decoded = ShortStruct::des_from(buf).unwrap(); + let mut buf = Cursor::new(SHORT2_ENC_LE); + let decoded = ShortStruct::des_from(&mut buf).unwrap(); assert_eq!(decoded, SHORT2); // deserialize from a `Write` sink that terminates early. - let buf = Cursor::new([0u8; 4]); - ShortStruct::des_from(buf).unwrap_err(); + let mut buf = Cursor::new([0u8; 4]); + ShortStruct::des_from(&mut buf).unwrap_err(); } #[test]