Extract message processing function from PostgresBackend's event loop

This patch has been extracted from #348, where it became unnecessary
after we had decided that we didn't want to measure anything inside
PostgresBackend.

IMO the change is good enough to make its way into the codebase,
even though it brings nothing "new" to the code.
This commit is contained in:
Dmitry Ivanov
2021-08-03 23:16:57 +03:00
parent bcaa59c0b9
commit ed634ec320
4 changed files with 167 additions and 135 deletions

View File

@@ -165,7 +165,7 @@ fn page_service_conn_main(conf: &'static PageServerConf, socket: TcpStream) -> a
}
let mut conn_handler = PageServerHandler::new(conf);
let mut pgbackend = PostgresBackend::new(socket, AuthType::Trust)?;
let pgbackend = PostgresBackend::new(socket, AuthType::Trust)?;
pgbackend.run(&mut conn_handler)
}

View File

@@ -34,7 +34,7 @@ pub fn thread_main(state: &'static ProxyState, listener: TcpListener) -> anyhow:
pub fn mgmt_conn_main(state: &'static ProxyState, socket: TcpStream) -> anyhow::Result<()> {
let mut conn_handler = MgmtHandler { state };
let mut pgbackend = PostgresBackend::new(socket, AuthType::Trust)?;
let pgbackend = PostgresBackend::new(socket, AuthType::Trust)?;
pgbackend.run(&mut conn_handler)
}

View File

@@ -50,7 +50,7 @@ fn handle_socket(mut socket: TcpStream, conf: WalAcceptorConf) -> Result<()> {
ReceiveWalConn::new(socket, conf)?.run()?; // internal protocol between wal_proposer and wal_acceptor
} else {
let mut conn_handler = SendWalHandler::new(conf);
let mut pgbackend = PostgresBackend::new(socket, AuthType::Trust)?;
let pgbackend = PostgresBackend::new(socket, AuthType::Trust)?;
// libpq replication protocol between wal_acceptor and replicas/pagers
pgbackend.run(&mut conn_handler)?;
}

View File

@@ -4,15 +4,12 @@
//! is rather narrow, but we can extend it once required.
use crate::pq_proto::{BeMessage, FeMessage, FeStartupMessage, StartupRequestCode};
use anyhow::bail;
use anyhow::Result;
use anyhow::{bail, Result};
use bytes::{Bytes, BytesMut};
use log::*;
use rand::Rng;
use std::io;
use std::io::{BufReader, Write};
use std::net::Shutdown;
use std::net::TcpStream;
use std::io::{self, BufReader, Write};
use std::net::{Shutdown, TcpStream};
pub trait Handler {
/// Handle single query.
@@ -36,19 +33,27 @@ pub trait Handler {
}
}
#[derive(PartialEq)]
/// PostgresBackend protocol state.
/// XXX: The order of the constructors matters.
#[derive(Clone, Copy, PartialEq, PartialOrd)]
pub enum ProtoState {
Initialization,
Authentication,
Established,
}
#[derive(PartialEq)]
#[derive(Clone, Copy, PartialEq)]
pub enum AuthType {
Trust,
MD5,
}
#[derive(Clone, Copy)]
pub enum ProcessMsgResult {
Continue,
Break,
}
pub struct PostgresBackend {
// replication.rs wants to handle reading on its own in separate thread, so
// wrap in Option to be able to take and transfer the BufReader. Ugly, but I
@@ -75,7 +80,10 @@ pub fn query_from_cstring(query_string: Bytes) -> Vec<u8> {
}
impl PostgresBackend {
pub fn new(socket: TcpStream, auth_type: AuthType) -> Result<Self, std::io::Error> {
pub fn new(
socket: TcpStream,
auth_type: AuthType,
) -> io::Result<Self> {
let mut pb = PostgresBackend {
stream_in: None,
stream_out: socket,
@@ -84,6 +92,7 @@ impl PostgresBackend {
md5_salt: [0u8; 4],
auth_type,
};
// if socket cloning fails, report the error and bail out
pb.stream_in = match pb.stream_out.try_clone() {
Ok(read_sock) => Some(BufReader::new(read_sock)),
@@ -93,6 +102,7 @@ impl PostgresBackend {
return Err(error);
}
};
Ok(pb)
}
@@ -114,11 +124,12 @@ impl PostgresBackend {
/// Read full message or return None if connection is closed.
pub fn read_message(&mut self) -> Result<Option<FeMessage>> {
match self.state {
ProtoState::Initialization => FeStartupMessage::read(self.get_stream_in()?),
ProtoState::Authentication | ProtoState::Established => {
FeMessage::read(self.get_stream_in()?)
}
let (state, stream) = (self.state, self.get_stream_in()?);
use ProtoState::*;
match state {
ProtoState::Initialization => FeStartupMessage::read(stream),
Authentication | Established => FeMessage::read(stream),
}
}
@@ -141,137 +152,158 @@ impl PostgresBackend {
self.flush()
}
// wrapper for run_internal() that shuts down socket when we are done
pub fn run(&mut self, handler: &mut impl Handler) -> Result<()> {
let ret = self.run_internal(handler);
// Wrapper for run_message_loop() that shuts down socket when we are done
pub fn run(mut self, handler: &mut impl Handler) -> Result<()> {
let ret = self.run_message_loop(handler);
let _res = self.stream_out.shutdown(Shutdown::Both);
ret
}
fn run_internal(&mut self, handler: &mut impl Handler) -> Result<()> {
fn run_message_loop(&mut self, handler: &mut impl Handler) -> Result<()> {
let peer_addr = self.stream_out.peer_addr()?;
info!("postgres backend to {:?} started", peer_addr);
let mut unnamed_query_string = Bytes::new();
loop {
let msg = self.read_message()?;
while let Some(msg) = self.read_message()? {
trace!("got message {:?}", msg);
// Allow only startup and password messages during auth. Otherwise client would be able to bypass auth
// TODO: change that to proper top-level match of protocol state with separate message handling for each state
if self.state == ProtoState::Authentication || self.state == ProtoState::Initialization
{
match msg {
Some(FeMessage::PasswordMessage(ref _m)) => {}
Some(FeMessage::StartupMessage(ref _m)) => {}
Some(_) => {
bail!("protocol violation");
}
None => {}
};
}
match msg {
Some(FeMessage::StartupMessage(m)) => {
trace!("got startup message {:?}", m);
match m.kind {
StartupRequestCode::NegotiateGss | StartupRequestCode::NegotiateSsl => {
info!("SSL requested");
self.write_message(&BeMessage::Negotiate)?;
}
StartupRequestCode::Normal => {
// NB: startup() may change self.auth_type -- we are using that in proxy code
// to bypass auth for new users.
handler.startup(self, &m)?;
match self.auth_type {
AuthType::Trust => {
self.write_message_noflush(&BeMessage::AuthenticationOk)?;
// psycopg2 will not connect if client_encoding is not
// specified by the server
self.write_message_noflush(&BeMessage::ParameterStatus)?;
self.write_message(&BeMessage::ReadyForQuery)?;
self.state = ProtoState::Established;
}
AuthType::MD5 => {
rand::thread_rng().fill(&mut self.md5_salt);
let md5_salt = self.md5_salt;
self.write_message(&BeMessage::AuthenticationMD5Password(
&md5_salt,
))?;
self.state = ProtoState::Authentication;
}
}
}
StartupRequestCode::Cancel => break,
}
}
Some(FeMessage::PasswordMessage(m)) => {
trace!("got password message '{:?}'", m);
assert!(self.state == ProtoState::Authentication);
let (_, md5_response) = m
.split_last()
.ok_or_else(|| anyhow::Error::msg("protocol violation"))?;
if let Err(e) = handler.check_auth_md5(self, md5_response) {
self.write_message(&BeMessage::ErrorResponse(format!("{}", e)))?;
bail!("auth failed: {}", e);
} else {
self.write_message_noflush(&BeMessage::AuthenticationOk)?;
// psycopg2 will not connect if client_encoding is not
// specified by the server
self.write_message_noflush(&BeMessage::ParameterStatus)?;
self.write_message(&BeMessage::ReadyForQuery)?;
self.state = ProtoState::Established;
}
}
Some(FeMessage::Query(m)) => {
trace!("got query {:?}", m.body);
// xxx distinguish fatal and recoverable errors?
if let Err(e) = handler.process_query(self, m.body) {
let errmsg = format!("{}", e);
self.write_message_noflush(&BeMessage::ErrorResponse(errmsg))?;
}
self.write_message(&BeMessage::ReadyForQuery)?;
}
Some(FeMessage::Parse(m)) => {
unnamed_query_string = m.query_string;
self.write_message(&BeMessage::ParseComplete)?;
}
Some(FeMessage::Describe(_)) => {
self.write_message_noflush(&BeMessage::ParameterDescription)?
.write_message(&BeMessage::NoData)?;
}
Some(FeMessage::Bind(_)) => {
self.write_message(&BeMessage::BindComplete)?;
}
Some(FeMessage::Close(_)) => {
self.write_message(&BeMessage::CloseComplete)?;
}
Some(FeMessage::Execute(_)) => {
handler.process_query(self, unnamed_query_string.clone())?;
}
Some(FeMessage::Sync) => {
self.write_message(&BeMessage::ReadyForQuery)?;
}
Some(FeMessage::Terminate) => {
break;
}
None => {
info!("connection closed");
break;
}
x => {
bail!("unexpected message type : {:?}", x);
}
match self.process_message(handler, msg, &mut unnamed_query_string)? {
ProcessMsgResult::Continue => continue,
ProcessMsgResult::Break => break,
}
}
info!("postgres backend to {:?} exited", peer_addr);
Ok(())
}
fn process_message(
&mut self,
handler: &mut impl Handler,
msg: FeMessage,
unnamed_query_string: &mut Bytes,
) -> Result<ProcessMsgResult> {
// Allow only startup and password messages during auth. Otherwise client would be able to bypass auth
// TODO: change that to proper top-level match of protocol state with separate message handling for each state
if self.state < ProtoState::Established {
match &msg {
FeMessage::PasswordMessage(_m) => {}
FeMessage::StartupMessage(_m) => {}
_ => {
bail!("protocol violation");
}
}
}
match msg {
FeMessage::StartupMessage(m) => {
trace!("got startup message {:?}", m);
match m.kind {
StartupRequestCode::NegotiateGss | StartupRequestCode::NegotiateSsl => {
info!("SSL requested");
self.write_message(&BeMessage::Negotiate)?;
}
StartupRequestCode::Normal => {
// NB: startup() may change self.auth_type -- we are using that in proxy code
// to bypass auth for new users.
handler.startup(self, &m)?;
match self.auth_type {
AuthType::Trust => {
self.write_message_noflush(&BeMessage::AuthenticationOk)?;
// psycopg2 will not connect if client_encoding is not
// specified by the server
self.write_message_noflush(&BeMessage::ParameterStatus)?;
self.write_message(&BeMessage::ReadyForQuery)?;
self.state = ProtoState::Established;
}
AuthType::MD5 => {
rand::thread_rng().fill(&mut self.md5_salt);
let md5_salt = self.md5_salt;
self.write_message(&BeMessage::AuthenticationMD5Password(
&md5_salt,
))?;
self.state = ProtoState::Authentication;
}
}
}
StartupRequestCode::Cancel => {
return Ok(ProcessMsgResult::Break);
}
}
}
FeMessage::PasswordMessage(m) => {
trace!("got password message '{:?}'", m);
assert!(self.state == ProtoState::Authentication);
let (_, md5_response) = m
.split_last()
.ok_or_else(|| anyhow::Error::msg("protocol violation"))?;
if let Err(e) = handler.check_auth_md5(self, md5_response) {
self.write_message(&BeMessage::ErrorResponse(format!("{}", e)))?;
bail!("auth failed: {}", e);
} else {
self.write_message_noflush(&BeMessage::AuthenticationOk)?;
// psycopg2 will not connect if client_encoding is not
// specified by the server
self.write_message_noflush(&BeMessage::ParameterStatus)?;
self.write_message(&BeMessage::ReadyForQuery)?;
self.state = ProtoState::Established;
}
}
FeMessage::Query(m) => {
trace!("got query {:?}", m.body);
// xxx distinguish fatal and recoverable errors?
if let Err(e) = handler.process_query(self, m.body) {
let errmsg = format!("{}", e);
self.write_message_noflush(&BeMessage::ErrorResponse(errmsg))?;
}
self.write_message(&BeMessage::ReadyForQuery)?;
}
FeMessage::Parse(m) => {
*unnamed_query_string = m.query_string;
self.write_message(&BeMessage::ParseComplete)?;
}
FeMessage::Describe(_) => {
self.write_message_noflush(&BeMessage::ParameterDescription)?
.write_message(&BeMessage::NoData)?;
}
FeMessage::Bind(_) => {
self.write_message(&BeMessage::BindComplete)?;
}
FeMessage::Close(_) => {
self.write_message(&BeMessage::CloseComplete)?;
}
FeMessage::Execute(_) => {
handler.process_query(self, unnamed_query_string.clone())?;
}
FeMessage::Sync => {
self.write_message(&BeMessage::ReadyForQuery)?;
}
FeMessage::Terminate => {
return Ok(ProcessMsgResult::Break);
}
// We prefer explicit pattern matching to wildcards, because
// this helps us spot the places where new variants are missing
FeMessage::CopyData(_) | FeMessage::CopyDone => {
bail!("unexpected message type: {:?}", msg);
}
}
Ok(ProcessMsgResult::Continue)
}
}