diff --git a/walkeeper/src/send_wal.rs b/walkeeper/src/send_wal.rs
index 0ee6a7829c..1884c1aab4 100644
--- a/walkeeper/src/send_wal.rs
+++ b/walkeeper/src/send_wal.rs
@@ -7,7 +7,6 @@ use crate::pq_protocol::{
 };
 use crate::replication::ReplicationConn;
 use crate::timeline::{Timeline, TimelineTools};
-use crate::wal_service::Connection;
 use crate::WalAcceptorConf;
 use anyhow::{bail, Result};
 use bytes::BytesMut;
@@ -33,15 +32,20 @@ pub struct SendWalConn {
 
 impl SendWalConn {
-    /// Create a new `SendWal`, consuming the `Connection`.
-    pub fn new(conn: Connection) -> Self {
-        Self {
-            timeline: conn.timeline,
-            stream_in: conn.stream_in,
-            stream_out: conn.stream_out,
-            peer_addr: conn.peer_addr,
-            conf: conn.conf,
+    /// Create a new `SendWalConn` that takes ownership of an accepted `socket`.
+    ///
+    /// The socket is cloned so that input is read through a `BufReader` while
+    /// output goes to the socket itself. Fails if the peer address cannot be
+    /// read or the socket cannot be cloned.
+    pub fn new(socket: TcpStream, conf: WalAcceptorConf) -> Result<Self> {
+        let peer_addr = socket.peer_addr()?;
+        let conn = SendWalConn {
+            timeline: None,
+            stream_in: BufReader::new(socket.try_clone()?),
+            stream_out: socket,
+            peer_addr,
+            conf,
             appname: None,
-        }
+        };
+        Ok(conn)
     }
 
     ///
diff --git a/walkeeper/src/wal_service.rs b/walkeeper/src/wal_service.rs
index 02f7702b2c..eb005059f7 100644
--- a/walkeeper/src/wal_service.rs
+++ b/walkeeper/src/wal_service.rs
@@ -3,14 +3,13 @@
 //! receive WAL from wal_proposer and send it to WAL receivers
 //!
use anyhow::{bail, Result};
-use byteorder::{BigEndian, ByteOrder, ReadBytesExt};
 use fs2::FileExt;
 use log::*;
 use postgres::{Client, NoTls};
 use serde::{Deserialize, Serialize};
 use std::cmp::{max, min};
 use std::fs::{self, File, OpenOptions};
-use std::io::{BufRead, BufReader, Read, Seek, SeekFrom, Write};
+use std::io::{BufReader, Read, Seek, SeekFrom, Write};
 use std::net::{SocketAddr, TcpListener, TcpStream};
 use std::str;
 use std::sync::Arc;
@@ -230,7 +229,7 @@ impl SharedState {
 }
 
 #[derive(Debug)]
-pub struct Connection {
+pub struct ReceiveWalConn {
     pub timeline: Option>,
     /// Postgres connection, buffered input
     pub stream_in: BufReader,
@@ -294,19 +293,57 @@ pub fn thread_main(conf: WalAcceptorConf) -> Result<()> {
 
 /// This is run by main_loop, inside a background thread.
 ///
-/// This is only a separate function to make a convenient place to collect
-/// all errors for logging. Our caller can log errors in a single place.
-fn handle_socket(socket: TcpStream, conf: WalAcceptorConf) -> Result<()> {
+/// Peeks at the first four bytes to decide which protocol the peer speaks:
+/// zero selects the internal wal_proposer protocol (`ReceiveWalConn`), anything
+/// else is treated as the libpq replication protocol (`SendWalConn`).
+fn handle_socket(mut socket: TcpStream, conf: WalAcceptorConf) -> Result<()> {
     socket.set_nodelay(true)?;
-    let conn = Connection::new(socket, conf)?;
-    conn.run()?;
+
+    // Peek at the incoming data to see what protocol is being sent.
+    let peeked = peek_u32(&mut socket)?;
+    if peeked == 0 {
+        // Consume the 4 bytes we peeked at. This protocol begins after them.
+        socket.read_exact(&mut [0u8; 4])?;
+        ReceiveWalConn::new(socket, conf)?.run()?; // internal protocol between wal_proposer and wal_acceptor
+    } else {
+        SendWalConn::new(socket, conf)?.run()?; // libpq replication protocol between wal_acceptor and replicas/pagers
+    }
     Ok(())
 }
 
-impl Connection {
-    pub fn new(socket: TcpStream, conf: WalAcceptorConf) -> Result<Self> {
+/// Fetch the first 4 bytes from the network (big endian), without consuming them.
+///
+/// This is used to help determine what protocol the peer is using.
+///
+/// # Errors
+///
+/// Propagates I/O errors from `peek`, and fails if the peer closes the
+/// connection before sending anything or if a full 4-byte header has not
+/// arrived after a bounded number of retries.
+fn peek_u32(stream: &mut TcpStream) -> Result<u32> {
+    // `peek` returns as soon as any bytes are buffered, so a header that
+    // arrives in pieces (or a peer that closes after 1-3 bytes) would make a
+    // bare retry loop spin forever; sleep between polls and give up eventually.
+    const MAX_ATTEMPTS: u32 = 1000;
+    let retry_delay = std::time::Duration::from_millis(10);
+
+    let mut buf = [0u8; 4];
+    for _ in 0..MAX_ATTEMPTS {
+        match stream.peek(&mut buf)? {
+            4 => return Ok(u32::from_be_bytes(buf)),
+            // Zero bytes means the peer closed the connection (EOF).
+            0 => bail!("connection closed before the protocol header was received"),
+            // Partial header: wait briefly for the remaining bytes.
+            _ => std::thread::sleep(retry_delay),
+        }
+    }
+    bail!("timed out waiting for the 4-byte protocol header")
+}
+
+impl ReceiveWalConn {
+    pub fn new(socket: TcpStream, conf: WalAcceptorConf) -> Result<Self> {
         let peer_addr = socket.peer_addr()?;
-        let conn = Connection {
+        let conn = ReceiveWalConn {
             timeline: None,
             stream_in: BufReader::new(socket.try_clone()?),
             stream_out: socket,
@@ -316,29 +353,6 @@ impl Connection {
         Ok(conn)
     }
 
-    fn run(mut self) -> Result<()> {
-        // Peek at the first 4 bytes of the incoming data, to determine which
-        // protocol is being spoken. fill_buf` does not consume any of the
-        // bytes we peek at; they are left in the BufReader's internal buffer
-        // for the next reader.
-        let peek_buf = self.stream_in.fill_buf()?;
-        if peek_buf.len() < 4 {
-            // Empty peek_buf means the socket was closed.
-            // Less than 4 bytes doesn't seem likely unless the sender is malicious.
-            // read_u32 would panic on any of these, so just return an error.
-            bail!("fill_buf EOF or underrun");
-        }
-        let startup_pkg_len = BigEndian::read_u32(peek_buf);
-        if startup_pkg_len == 0 {
-            // Consume the 4 bytes we peeked at. This protocol begins after them.
-            self.stream_in.read_u32::<BigEndian>()?;
-            self.receive_wal()?; // internal protocol between wal_proposer and wal_acceptor
-        } else {
-            SendWalConn::new(self).run()?; // libpq replication protocol between wal_acceptor and replicas/pagers
-        }
-        Ok(())
-    }
-
     fn read_req(&mut self) -> Result {
         // As the trait bound implies, this always encodes little-endian.
         Ok(T::des_from(&mut self.stream_in)?)
@@ -371,7 +385,7 @@ impl Connection {
     }
 
     /// Receive WAL from wal_proposer
-    fn receive_wal(&mut self) -> Result<()> {
+    fn run(&mut self) -> Result<()> {
         // Receive information about server
         let server_info = self.read_req::()?;
         info!(