diff --git a/pageserver/src/page_service.rs b/pageserver/src/page_service.rs index a12d188d3c..f804b843ba 100644 --- a/pageserver/src/page_service.rs +++ b/pageserver/src/page_service.rs @@ -165,7 +165,7 @@ fn page_service_conn_main(conf: &'static PageServerConf, socket: TcpStream) -> a } let mut conn_handler = PageServerHandler::new(conf); - let mut pgbackend = PostgresBackend::new(socket, AuthType::Trust)?; + let pgbackend = PostgresBackend::new(socket, AuthType::Trust)?; pgbackend.run(&mut conn_handler) } diff --git a/proxy/src/mgmt.rs b/proxy/src/mgmt.rs index fab8d2121c..b8c59bbe03 100644 --- a/proxy/src/mgmt.rs +++ b/proxy/src/mgmt.rs @@ -34,7 +34,7 @@ pub fn thread_main(state: &'static ProxyState, listener: TcpListener) -> anyhow: pub fn mgmt_conn_main(state: &'static ProxyState, socket: TcpStream) -> anyhow::Result<()> { let mut conn_handler = MgmtHandler { state }; - let mut pgbackend = PostgresBackend::new(socket, AuthType::Trust)?; + let pgbackend = PostgresBackend::new(socket, AuthType::Trust)?; pgbackend.run(&mut conn_handler) } diff --git a/walkeeper/src/wal_service.rs b/walkeeper/src/wal_service.rs index 124a273159..f4e05a0bc9 100644 --- a/walkeeper/src/wal_service.rs +++ b/walkeeper/src/wal_service.rs @@ -50,7 +50,7 @@ fn handle_socket(mut socket: TcpStream, conf: WalAcceptorConf) -> Result<()> { ReceiveWalConn::new(socket, conf)?.run()?; // internal protocol between wal_proposer and wal_acceptor } else { let mut conn_handler = SendWalHandler::new(conf); - let mut pgbackend = PostgresBackend::new(socket, AuthType::Trust)?; + let pgbackend = PostgresBackend::new(socket, AuthType::Trust)?; // libpq replication protocol between wal_acceptor and replicas/pagers pgbackend.run(&mut conn_handler)?; } diff --git a/zenith_utils/src/postgres_backend.rs b/zenith_utils/src/postgres_backend.rs index 8c54b31e1e..62f3ecee62 100644 --- a/zenith_utils/src/postgres_backend.rs +++ b/zenith_utils/src/postgres_backend.rs @@ -4,15 +4,12 @@ //! is rather narrow, but we can extend it once required. use crate::pq_proto::{BeMessage, FeMessage, FeStartupMessage, StartupRequestCode}; -use anyhow::bail; -use anyhow::Result; +use anyhow::{bail, Result}; use bytes::{Bytes, BytesMut}; use log::*; use rand::Rng; -use std::io; -use std::io::{BufReader, Write}; -use std::net::Shutdown; -use std::net::TcpStream; +use std::io::{self, BufReader, Write}; +use std::net::{Shutdown, TcpStream}; pub trait Handler { /// Handle single query. @@ -36,19 +33,27 @@ pub trait Handler { } } -#[derive(PartialEq)] +/// PostgresBackend protocol state. +/// XXX: The order of the constructors matters. +#[derive(Clone, Copy, PartialEq, PartialOrd)] pub enum ProtoState { Initialization, Authentication, Established, } -#[derive(PartialEq)] +#[derive(Clone, Copy, PartialEq)] pub enum AuthType { Trust, MD5, } +#[derive(Clone, Copy)] +pub enum ProcessMsgResult { + Continue, + Break, +} + pub struct PostgresBackend { // replication.rs wants to handle reading on its own in separate thread, so // wrap in Option to be able to take and transfer the BufReader. Ugly, but I @@ -75,7 +80,10 @@ pub fn query_from_cstring(query_string: Bytes) -> Vec { } impl PostgresBackend { - pub fn new(socket: TcpStream, auth_type: AuthType) -> Result { + pub fn new( + socket: TcpStream, + auth_type: AuthType, + ) -> io::Result { let mut pb = PostgresBackend { stream_in: None, stream_out: socket, @@ -84,6 +92,7 @@ impl PostgresBackend { md5_salt: [0u8; 4], auth_type, }; + // if socket cloning fails, report the error and bail out pb.stream_in = match pb.stream_out.try_clone() { Ok(read_sock) => Some(BufReader::new(read_sock)), @@ -93,6 +102,7 @@ impl PostgresBackend { return Err(error); } }; + Ok(pb) } @@ -114,11 +124,12 @@ impl PostgresBackend { /// Read full message or return None if connection is closed. pub fn read_message(&mut self) -> Result> { - match self.state { - ProtoState::Initialization => FeStartupMessage::read(self.get_stream_in()?), - ProtoState::Authentication | ProtoState::Established => { - FeMessage::read(self.get_stream_in()?) - } + let (state, stream) = (self.state, self.get_stream_in()?); + + use ProtoState::*; + match state { + ProtoState::Initialization => FeStartupMessage::read(stream), + Authentication | Established => FeMessage::read(stream), } } @@ -141,137 +152,158 @@ impl PostgresBackend { self.flush() } - // wrapper for run_internal() that shuts down socket when we are done - pub fn run(&mut self, handler: &mut impl Handler) -> Result<()> { - let ret = self.run_internal(handler); + // Wrapper for run_message_loop() that shuts down socket when we are done + pub fn run(mut self, handler: &mut impl Handler) -> Result<()> { + let ret = self.run_message_loop(handler); let _res = self.stream_out.shutdown(Shutdown::Both); ret } - fn run_internal(&mut self, handler: &mut impl Handler) -> Result<()> { + fn run_message_loop(&mut self, handler: &mut impl Handler) -> Result<()> { let peer_addr = self.stream_out.peer_addr()?; info!("postgres backend to {:?} started", peer_addr); + let mut unnamed_query_string = Bytes::new(); - loop { - let msg = self.read_message()?; + while let Some(msg) = self.read_message()? { trace!("got message {:?}", msg); - // Allow only startup and password messages during auth. Otherwise client would be able to bypass auth - // TODO: change that to proper top-level match of protocol state with separate message handling for each state - if self.state == ProtoState::Authentication || self.state == ProtoState::Initialization - { - match msg { - Some(FeMessage::PasswordMessage(ref _m)) => {} - Some(FeMessage::StartupMessage(ref _m)) => {} - Some(_) => { - bail!("protocol violation"); - } - None => {} - }; - } - - match msg { - Some(FeMessage::StartupMessage(m)) => { - trace!("got startup message {:?}", m); - - match m.kind { - StartupRequestCode::NegotiateGss | StartupRequestCode::NegotiateSsl => { - info!("SSL requested"); - self.write_message(&BeMessage::Negotiate)?; - } - StartupRequestCode::Normal => { - // NB: startup() may change self.auth_type -- we are using that in proxy code - // to bypass auth for new users. - handler.startup(self, &m)?; - - match self.auth_type { - AuthType::Trust => { - self.write_message_noflush(&BeMessage::AuthenticationOk)?; - // psycopg2 will not connect if client_encoding is not - // specified by the server - self.write_message_noflush(&BeMessage::ParameterStatus)?; - self.write_message(&BeMessage::ReadyForQuery)?; - self.state = ProtoState::Established; - } - AuthType::MD5 => { - rand::thread_rng().fill(&mut self.md5_salt); - let md5_salt = self.md5_salt; - self.write_message(&BeMessage::AuthenticationMD5Password( - &md5_salt, - ))?; - self.state = ProtoState::Authentication; - } - } - } - StartupRequestCode::Cancel => break, - } - } - - Some(FeMessage::PasswordMessage(m)) => { - trace!("got password message '{:?}'", m); - - assert!(self.state == ProtoState::Authentication); - - let (_, md5_response) = m - .split_last() - .ok_or_else(|| anyhow::Error::msg("protocol violation"))?; - - if let Err(e) = handler.check_auth_md5(self, md5_response) { - self.write_message(&BeMessage::ErrorResponse(format!("{}", e)))?; - bail!("auth failed: {}", e); - } else { - self.write_message_noflush(&BeMessage::AuthenticationOk)?; - // psycopg2 will not connect if client_encoding is not - // specified by the server - self.write_message_noflush(&BeMessage::ParameterStatus)?; - self.write_message(&BeMessage::ReadyForQuery)?; - self.state = ProtoState::Established; - } - } - - Some(FeMessage::Query(m)) => { - trace!("got query {:?}", m.body); - // xxx distinguish fatal and recoverable errors? - if let Err(e) = handler.process_query(self, m.body) { - let errmsg = format!("{}", e); - self.write_message_noflush(&BeMessage::ErrorResponse(errmsg))?; - } - self.write_message(&BeMessage::ReadyForQuery)?; - } - Some(FeMessage::Parse(m)) => { - unnamed_query_string = m.query_string; - self.write_message(&BeMessage::ParseComplete)?; - } - Some(FeMessage::Describe(_)) => { - self.write_message_noflush(&BeMessage::ParameterDescription)? - .write_message(&BeMessage::NoData)?; - } - Some(FeMessage::Bind(_)) => { - self.write_message(&BeMessage::BindComplete)?; - } - Some(FeMessage::Close(_)) => { - self.write_message(&BeMessage::CloseComplete)?; - } - Some(FeMessage::Execute(_)) => { - handler.process_query(self, unnamed_query_string.clone())?; - } - Some(FeMessage::Sync) => { - self.write_message(&BeMessage::ReadyForQuery)?; - } - Some(FeMessage::Terminate) => { - break; - } - None => { - info!("connection closed"); - break; - } - x => { - bail!("unexpected message type : {:?}", x); - } + match self.process_message(handler, msg, &mut unnamed_query_string)? { + ProcessMsgResult::Continue => continue, + ProcessMsgResult::Break => break, } } + info!("postgres backend to {:?} exited", peer_addr); Ok(()) } + + fn process_message( + &mut self, + handler: &mut impl Handler, + msg: FeMessage, + unnamed_query_string: &mut Bytes, + ) -> Result { + // Allow only startup and password messages during auth. Otherwise client would be able to bypass auth + // TODO: change that to proper top-level match of protocol state with separate message handling for each state + if self.state < ProtoState::Established { + match &msg { + FeMessage::PasswordMessage(_m) => {} + FeMessage::StartupMessage(_m) => {} + _ => { + bail!("protocol violation"); + } + } + } + + match msg { + FeMessage::StartupMessage(m) => { + trace!("got startup message {:?}", m); + + match m.kind { + StartupRequestCode::NegotiateGss | StartupRequestCode::NegotiateSsl => { + info!("SSL requested"); + self.write_message(&BeMessage::Negotiate)?; + } + StartupRequestCode::Normal => { + // NB: startup() may change self.auth_type -- we are using that in proxy code + // to bypass auth for new users. + handler.startup(self, &m)?; + + match self.auth_type { + AuthType::Trust => { + self.write_message_noflush(&BeMessage::AuthenticationOk)?; + // psycopg2 will not connect if client_encoding is not + // specified by the server + self.write_message_noflush(&BeMessage::ParameterStatus)?; + self.write_message(&BeMessage::ReadyForQuery)?; + self.state = ProtoState::Established; + } + AuthType::MD5 => { + rand::thread_rng().fill(&mut self.md5_salt); + let md5_salt = self.md5_salt; + self.write_message(&BeMessage::AuthenticationMD5Password( + &md5_salt, + ))?; + self.state = ProtoState::Authentication; + } + } + } + StartupRequestCode::Cancel => { + return Ok(ProcessMsgResult::Break); + } + } + } + + FeMessage::PasswordMessage(m) => { + trace!("got password message '{:?}'", m); + + assert!(self.state == ProtoState::Authentication); + + let (_, md5_response) = m + .split_last() + .ok_or_else(|| anyhow::Error::msg("protocol violation"))?; + + if let Err(e) = handler.check_auth_md5(self, md5_response) { + self.write_message(&BeMessage::ErrorResponse(format!("{}", e)))?; + bail!("auth failed: {}", e); + } else { + self.write_message_noflush(&BeMessage::AuthenticationOk)?; + // psycopg2 will not connect if client_encoding is not + // specified by the server + self.write_message_noflush(&BeMessage::ParameterStatus)?; + self.write_message(&BeMessage::ReadyForQuery)?; + self.state = ProtoState::Established; + } + } + + FeMessage::Query(m) => { + trace!("got query {:?}", m.body); + // xxx distinguish fatal and recoverable errors? + if let Err(e) = handler.process_query(self, m.body) { + let errmsg = format!("{}", e); + self.write_message_noflush(&BeMessage::ErrorResponse(errmsg))?; + } + self.write_message(&BeMessage::ReadyForQuery)?; + } + + FeMessage::Parse(m) => { + *unnamed_query_string = m.query_string; + self.write_message(&BeMessage::ParseComplete)?; + } + + FeMessage::Describe(_) => { + self.write_message_noflush(&BeMessage::ParameterDescription)? + .write_message(&BeMessage::NoData)?; + } + + FeMessage::Bind(_) => { + self.write_message(&BeMessage::BindComplete)?; + } + + FeMessage::Close(_) => { + self.write_message(&BeMessage::CloseComplete)?; + } + + FeMessage::Execute(_) => { + handler.process_query(self, unnamed_query_string.clone())?; + } + + FeMessage::Sync => { + self.write_message(&BeMessage::ReadyForQuery)?; + } + + FeMessage::Terminate => { + return Ok(ProcessMsgResult::Break); + } + + // We prefer explicit pattern matching to wildcards, because + // this helps us spot the places where new variants are missing + FeMessage::CopyData(_) | FeMessage::CopyDone => { + bail!("unexpected message type: {:?}", msg); + } + } + + Ok(ProcessMsgResult::Continue) + } }