diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index ad0af57eea..1dd455a306 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -113,31 +113,31 @@ impl ProxyConnection { let mut encrypted = false; loop { - let mut msg = match self.pgb.read_message()? { - Some(Fe::StartupMessage(msg)) => msg, + let msg = match self.pgb.read_message()? { + Some(Fe::StartupPacket(msg)) => msg, None => bail!("connection is lost"), bad => bail!("unexpected message type: {:?}", bad), }; println!("got message: {:?}", msg); - match msg.kind { - StartupRequestCode::NegotiateGss => { + match msg { + FeStartupPacket::GssEncRequest => { self.pgb.write_message(&Be::EncryptionResponse(false))?; } - StartupRequestCode::NegotiateSsl => { + FeStartupPacket::SslRequest => { self.pgb.write_message(&Be::EncryptionResponse(have_tls))?; if have_tls { self.pgb.start_tls()?; encrypted = true; } } - StartupRequestCode::Normal => { + FeStartupPacket::StartupMessage { mut params, .. } => { if have_tls && !encrypted { bail!("must connect with TLS"); } let mut get_param = |key| { - msg.params + params .remove(key) .ok_or_else(|| anyhow!("{} is missing in startup packet", key)) }; @@ -145,7 +145,9 @@ impl ProxyConnection { return Ok((get_param("user")?, get_param("database")?)); } // TODO: implement proper stmt cancellation - StartupRequestCode::Cancel => bail!("query cancellation is not supported"), + FeStartupPacket::CancelRequest { .. } => { + bail!("query cancellation is not supported") + } } } } diff --git a/walkeeper/src/handler.rs b/walkeeper/src/handler.rs index 5ed599ab07..31ae848bae 100644 --- a/walkeeper/src/handler.rs +++ b/walkeeper/src/handler.rs @@ -15,7 +15,7 @@ use std::sync::Arc; use zenith_utils::lsn::Lsn; use zenith_utils::postgres_backend; use zenith_utils::postgres_backend::PostgresBackend; -use zenith_utils::pq_proto::{BeMessage, FeStartupMessage, RowDescriptor, INT4_OID, TEXT_OID}; +use zenith_utils::pq_proto::{BeMessage, FeStartupPacket, RowDescriptor, INT4_OID, TEXT_OID}; use zenith_utils::zid::{ZTenantId, ZTimelineId}; use crate::callmemaybe::CallmeEvent; @@ -73,22 +73,26 @@ fn parse_cmd(cmd: &str) -> Result { impl postgres_backend::Handler for SafekeeperPostgresHandler { // ztenant id and ztimeline id are passed in connection string params - fn startup(&mut self, _pgb: &mut PostgresBackend, sm: &FeStartupMessage) -> Result<()> { - self.ztenantid = match sm.params.get("ztenantid") { - Some(z) => Some(ZTenantId::from_str(z)?), // just curious, can I do that from .map? - _ => None, - }; + fn startup(&mut self, _pgb: &mut PostgresBackend, sm: &FeStartupPacket) -> Result<()> { + if let FeStartupPacket::StartupMessage { params, .. } = sm { + self.ztenantid = match params.get("ztenantid") { + Some(z) => Some(ZTenantId::from_str(z)?), // just curious, can I do that from .map? + _ => None, + }; - self.ztimelineid = match sm.params.get("ztimelineid") { - Some(z) => Some(ZTimelineId::from_str(z)?), - _ => None, - }; + self.ztimelineid = match params.get("ztimelineid") { + Some(z) => Some(ZTimelineId::from_str(z)?), + _ => None, + }; - if let Some(app_name) = sm.params.get("application_name") { - self.appname = Some(app_name.clone()); + if let Some(app_name) = params.get("application_name") { + self.appname = Some(app_name.clone()); + } + + Ok(()) + } else { + bail!("Walkeeper received unexpected initial message: {:?}", sm); } - - Ok(()) } fn process_query(&mut self, pgb: &mut PostgresBackend, query_string: &str) -> Result<()> { diff --git a/zenith_utils/src/postgres_backend.rs b/zenith_utils/src/postgres_backend.rs index 8938c2803b..d55ead93bc 100644 --- a/zenith_utils/src/postgres_backend.rs +++ b/zenith_utils/src/postgres_backend.rs @@ -3,9 +3,7 @@ //! implementation determining how to process the queries. Currently its API //! is rather narrow, but we can extend it once required. -use crate::pq_proto::{ - BeMessage, BeParameterStatusMessage, FeMessage, FeStartupMessage, StartupRequestCode, -}; +use crate::pq_proto::{BeMessage, BeParameterStatusMessage, FeMessage, FeStartupPacket}; use crate::sock_split::{BidiStream, ReadStream, WriteStream}; use anyhow::{anyhow, bail, ensure, Result}; use bytes::{Bytes, BytesMut}; @@ -34,7 +32,7 @@ pub trait Handler { /// If Ok(false) is returned postgres_backend will skip auth -- that is needed for new users /// creation is the proxy code. That is quite hacky and ad-hoc solution, may be we could allow /// to override whole init logic in implementations. - fn startup(&mut self, _pgb: &mut PostgresBackend, _sm: &FeStartupMessage) -> Result<()> { + fn startup(&mut self, _pgb: &mut PostgresBackend, _sm: &FeStartupPacket) -> Result<()> { Ok(()) } @@ -237,7 +235,7 @@ impl PostgresBackend { use ProtoState::*; match state { - Initialization | Encrypted => FeStartupMessage::read(stream), + Initialization | Encrypted => FeStartupPacket::read(stream), Authentication | Established => FeMessage::read(stream), } } @@ -329,7 +327,7 @@ impl PostgresBackend { ensure!( matches!( msg, - FeMessage::PasswordMessage(_) | FeMessage::StartupMessage(_) + FeMessage::PasswordMessage(_) | FeMessage::StartupPacket(_) ), "protocol violation" ); @@ -337,11 +335,11 @@ impl PostgresBackend { let have_tls = self.tls_config.is_some(); match msg { - FeMessage::StartupMessage(m) => { + FeMessage::StartupPacket(m) => { trace!("got startup message {:?}", m); - match m.kind { - StartupRequestCode::NegotiateSsl => { + match m { + FeStartupPacket::SslRequest => { info!("SSL requested"); self.write_message(&BeMessage::EncryptionResponse(have_tls))?; @@ -350,11 +348,11 @@ impl PostgresBackend { self.state = ProtoState::Encrypted; } } - StartupRequestCode::NegotiateGss => { + FeStartupPacket::GssEncRequest => { info!("GSS requested"); self.write_message(&BeMessage::EncryptionResponse(false))?; } - StartupRequestCode::Normal => { + FeStartupPacket::StartupMessage { .. } => { if have_tls && !matches!(self.state, ProtoState::Encrypted) { self.write_message(&BeMessage::ErrorResponse( "must connect with TLS".to_string(), @@ -387,7 +385,7 @@ impl PostgresBackend { } } } - StartupRequestCode::Cancel => { + FeStartupPacket::CancelRequest { .. } => { return Ok(ProcessMsgResult::Break); } } diff --git a/zenith_utils/src/pq_proto.rs b/zenith_utils/src/pq_proto.rs index b5e40c66bc..d0fde2486e 100644 --- a/zenith_utils/src/pq_proto.rs +++ b/zenith_utils/src/pq_proto.rs @@ -2,14 +2,14 @@ //! //! on message formats. -use anyhow::{anyhow, bail, Result}; +use anyhow::{anyhow, bail, ensure, Result}; use byteorder::{BigEndian, ByteOrder}; use byteorder::{ReadBytesExt, BE}; use bytes::{Buf, BufMut, Bytes, BytesMut}; // use postgres_ffi::xlog_utils::TimestampTz; use std::collections::HashMap; -use std::io; use std::io::Read; +use std::io::{self, Cursor}; use std::str; pub type Oid = u32; @@ -21,7 +21,7 @@ pub const TEXT_OID: Oid = 25; #[derive(Debug)] pub enum FeMessage { - StartupMessage(FeStartupMessage), + StartupPacket(FeStartupPacket), Query(FeQueryMessage), // Simple query Parse(FeParseMessage), // Extended query protocol Describe(FeDescribeMessage), @@ -37,19 +37,15 @@ pub enum FeMessage { } #[derive(Debug)] -pub struct FeStartupMessage { - pub version: u32, - pub kind: StartupRequestCode, - // optional params arriving in startup packet - pub params: HashMap, -} - -#[derive(Debug)] -pub enum StartupRequestCode { - Cancel, - NegotiateSsl, - NegotiateGss, - Normal, +pub enum FeStartupPacket { + CancelRequest(CancelKeyData), + SslRequest, + GssEncRequest, + StartupMessage { + major_version: u32, + minor_version: u32, + params: HashMap, + }, } #[derive(Debug)] @@ -153,13 +149,14 @@ impl FeMessage { } } -impl FeStartupMessage { +impl FeStartupPacket { /// Read startup message from the stream. pub fn read(stream: &mut impl std::io::Read) -> anyhow::Result> { const MAX_STARTUP_PACKET_LENGTH: usize = 10000; - const CANCEL_REQUEST_CODE: u32 = (1234 << 16) | 5678; - const NEGOTIATE_SSL_CODE: u32 = (1234 << 16) | 5679; - const NEGOTIATE_GSS_CODE: u32 = (1234 << 16) | 5680; + const RESERVED_INVALID_MAJOR_VERSION: u32 = 1234; + const CANCEL_REQUEST_CODE: u32 = 5678; + const NEGOTIATE_SSL_CODE: u32 = 5679; + const NEGOTIATE_GSS_CODE: u32 = 5680; // Read length. If the connection is closed before reading anything (or before // reading 4 bytes, to be precise), return None to indicate that the connection @@ -175,44 +172,60 @@ impl FeStartupMessage { bail!("invalid message length"); } - let version = stream.read_u32::()?; - let kind = match version { - CANCEL_REQUEST_CODE => StartupRequestCode::Cancel, - NEGOTIATE_SSL_CODE => StartupRequestCode::NegotiateSsl, - NEGOTIATE_GSS_CODE => StartupRequestCode::NegotiateGss, - _ => StartupRequestCode::Normal, - }; + let request_code = stream.read_u32::()?; // the rest of startup packet are params let params_len = len - 8; let mut params_bytes = vec![0u8; params_len]; stream.read_exact(params_bytes.as_mut())?; - // Then null-terminated (String) pairs of param name / param value go. - let params_str = str::from_utf8(¶ms_bytes).unwrap(); - let params = params_str.split('\0'); - let mut params_hash: HashMap = HashMap::new(); - for pair in params.collect::>().chunks_exact(2) { - let name = pair[0]; - let value = pair[1]; - if name == "options" { - // deprecated way of passing params as cmd line args - for cmdopt in value.split(' ') { - let nameval: Vec<&str> = cmdopt.split('=').collect(); - if nameval.len() == 2 { - params_hash.insert(nameval[0].to_string(), nameval[1].to_string()); + // Parse params depending on request code + let most_sig_16_bits = request_code >> 16; + let least_sig_16_bits = request_code & ((1 << 16) - 1); + let message = match (most_sig_16_bits, least_sig_16_bits) { + (RESERVED_INVALID_MAJOR_VERSION, CANCEL_REQUEST_CODE) => { + ensure!(params_len == 8, "expected 8 bytes for CancelRequest params"); + let mut cursor = Cursor::new(params_bytes); + FeStartupPacket::CancelRequest(CancelKeyData { + backend_pid: cursor.read_i32::()?, + cancel_key: cursor.read_i32::()?, + }) + } + (RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_SSL_CODE) => FeStartupPacket::SslRequest, + (RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_GSS_CODE) => FeStartupPacket::GssEncRequest, + (RESERVED_INVALID_MAJOR_VERSION, unrecognized_code) => { + bail!("Unrecognized request code {}", unrecognized_code) + } + (major_version, minor_version) => { + // TODO bail if protocol major_version is not 3? + // Parse null-terminated (String) pairs of param name / param value + let params_str = str::from_utf8(¶ms_bytes).unwrap(); + let mut params_tokens = params_str.split('\0'); + let mut params: HashMap = HashMap::new(); + while let Some(name) = params_tokens.next() { + let value = params_tokens.next().ok_or_else(|| { + anyhow!("expected even number of params in StartupMessage") + })?; + if name == "options" { + // deprecated way of passing params as cmd line args + for cmdopt in value.split(' ') { + let nameval: Vec<&str> = cmdopt.split('=').collect(); + if nameval.len() == 2 { + params.insert(nameval[0].to_string(), nameval[1].to_string()); + } + } + } else { + params.insert(name.to_string(), value.to_string()); } } - } else { - params_hash.insert(name.to_string(), value.to_string()); + FeStartupPacket::StartupMessage { + major_version, + minor_version, + params, + } } - } - - Ok(Some(FeMessage::StartupMessage(FeStartupMessage { - version, - kind, - params: params_hash, - }))) + }; + Ok(Some(FeMessage::StartupPacket(message))) } }