Parse cancel message in pq_proto (#1060)

This commit is contained in:
bojanserafimov
2021-12-28 16:43:44 -05:00
committed by GitHub
parent 1e3ddd43bc
commit 24eca8d58b
4 changed files with 100 additions and 83 deletions

View File

@@ -113,31 +113,31 @@ impl ProxyConnection {
let mut encrypted = false;
loop {
let mut msg = match self.pgb.read_message()? {
Some(Fe::StartupMessage(msg)) => msg,
let msg = match self.pgb.read_message()? {
Some(Fe::StartupPacket(msg)) => msg,
None => bail!("connection is lost"),
bad => bail!("unexpected message type: {:?}", bad),
};
println!("got message: {:?}", msg);
match msg.kind {
StartupRequestCode::NegotiateGss => {
match msg {
FeStartupPacket::GssEncRequest => {
self.pgb.write_message(&Be::EncryptionResponse(false))?;
}
StartupRequestCode::NegotiateSsl => {
FeStartupPacket::SslRequest => {
self.pgb.write_message(&Be::EncryptionResponse(have_tls))?;
if have_tls {
self.pgb.start_tls()?;
encrypted = true;
}
}
StartupRequestCode::Normal => {
FeStartupPacket::StartupMessage { mut params, .. } => {
if have_tls && !encrypted {
bail!("must connect with TLS");
}
let mut get_param = |key| {
msg.params
params
.remove(key)
.ok_or_else(|| anyhow!("{} is missing in startup packet", key))
};
@@ -145,7 +145,9 @@ impl ProxyConnection {
return Ok((get_param("user")?, get_param("database")?));
}
// TODO: implement proper stmt cancellation
StartupRequestCode::Cancel => bail!("query cancellation is not supported"),
FeStartupPacket::CancelRequest { .. } => {
bail!("query cancellation is not supported")
}
}
}
}

View File

@@ -15,7 +15,7 @@ use std::sync::Arc;
use zenith_utils::lsn::Lsn;
use zenith_utils::postgres_backend;
use zenith_utils::postgres_backend::PostgresBackend;
use zenith_utils::pq_proto::{BeMessage, FeStartupMessage, RowDescriptor, INT4_OID, TEXT_OID};
use zenith_utils::pq_proto::{BeMessage, FeStartupPacket, RowDescriptor, INT4_OID, TEXT_OID};
use zenith_utils::zid::{ZTenantId, ZTimelineId};
use crate::callmemaybe::CallmeEvent;
@@ -73,22 +73,26 @@ fn parse_cmd(cmd: &str) -> Result<SafekeeperPostgresCommand> {
impl postgres_backend::Handler for SafekeeperPostgresHandler {
// ztenant id and ztimeline id are passed in connection string params
fn startup(&mut self, _pgb: &mut PostgresBackend, sm: &FeStartupMessage) -> Result<()> {
self.ztenantid = match sm.params.get("ztenantid") {
Some(z) => Some(ZTenantId::from_str(z)?), // just curious, can I do that from .map?
_ => None,
};
fn startup(&mut self, _pgb: &mut PostgresBackend, sm: &FeStartupPacket) -> Result<()> {
if let FeStartupPacket::StartupMessage { params, .. } = sm {
self.ztenantid = match params.get("ztenantid") {
Some(z) => Some(ZTenantId::from_str(z)?), // just curious, can I do that from .map?
_ => None,
};
self.ztimelineid = match sm.params.get("ztimelineid") {
Some(z) => Some(ZTimelineId::from_str(z)?),
_ => None,
};
self.ztimelineid = match params.get("ztimelineid") {
Some(z) => Some(ZTimelineId::from_str(z)?),
_ => None,
};
if let Some(app_name) = sm.params.get("application_name") {
self.appname = Some(app_name.clone());
if let Some(app_name) = params.get("application_name") {
self.appname = Some(app_name.clone());
}
Ok(())
} else {
bail!("Walkeeper received unexpected initial message: {:?}", sm);
}
Ok(())
}
fn process_query(&mut self, pgb: &mut PostgresBackend, query_string: &str) -> Result<()> {

View File

@@ -3,9 +3,7 @@
//! implementation determining how to process the queries. Currently its API
//! is rather narrow, but we can extend it once required.
use crate::pq_proto::{
BeMessage, BeParameterStatusMessage, FeMessage, FeStartupMessage, StartupRequestCode,
};
use crate::pq_proto::{BeMessage, BeParameterStatusMessage, FeMessage, FeStartupPacket};
use crate::sock_split::{BidiStream, ReadStream, WriteStream};
use anyhow::{anyhow, bail, ensure, Result};
use bytes::{Bytes, BytesMut};
@@ -34,7 +32,7 @@ pub trait Handler {
/// If Ok(false) is returned postgres_backend will skip auth -- that is needed for new users
/// creation is the proxy code. That is quite hacky and ad-hoc solution, may be we could allow
/// to override whole init logic in implementations.
fn startup(&mut self, _pgb: &mut PostgresBackend, _sm: &FeStartupMessage) -> Result<()> {
fn startup(&mut self, _pgb: &mut PostgresBackend, _sm: &FeStartupPacket) -> Result<()> {
Ok(())
}
@@ -237,7 +235,7 @@ impl PostgresBackend {
use ProtoState::*;
match state {
Initialization | Encrypted => FeStartupMessage::read(stream),
Initialization | Encrypted => FeStartupPacket::read(stream),
Authentication | Established => FeMessage::read(stream),
}
}
@@ -329,7 +327,7 @@ impl PostgresBackend {
ensure!(
matches!(
msg,
FeMessage::PasswordMessage(_) | FeMessage::StartupMessage(_)
FeMessage::PasswordMessage(_) | FeMessage::StartupPacket(_)
),
"protocol violation"
);
@@ -337,11 +335,11 @@ impl PostgresBackend {
let have_tls = self.tls_config.is_some();
match msg {
FeMessage::StartupMessage(m) => {
FeMessage::StartupPacket(m) => {
trace!("got startup message {:?}", m);
match m.kind {
StartupRequestCode::NegotiateSsl => {
match m {
FeStartupPacket::SslRequest => {
info!("SSL requested");
self.write_message(&BeMessage::EncryptionResponse(have_tls))?;
@@ -350,11 +348,11 @@ impl PostgresBackend {
self.state = ProtoState::Encrypted;
}
}
StartupRequestCode::NegotiateGss => {
FeStartupPacket::GssEncRequest => {
info!("GSS requested");
self.write_message(&BeMessage::EncryptionResponse(false))?;
}
StartupRequestCode::Normal => {
FeStartupPacket::StartupMessage { .. } => {
if have_tls && !matches!(self.state, ProtoState::Encrypted) {
self.write_message(&BeMessage::ErrorResponse(
"must connect with TLS".to_string(),
@@ -387,7 +385,7 @@ impl PostgresBackend {
}
}
}
StartupRequestCode::Cancel => {
FeStartupPacket::CancelRequest { .. } => {
return Ok(ProcessMsgResult::Break);
}
}

View File

@@ -2,14 +2,14 @@
//! <https://www.postgresql.org/docs/devel/protocol-message-formats.html>
//! on message formats.
use anyhow::{anyhow, bail, Result};
use anyhow::{anyhow, bail, ensure, Result};
use byteorder::{BigEndian, ByteOrder};
use byteorder::{ReadBytesExt, BE};
use bytes::{Buf, BufMut, Bytes, BytesMut};
// use postgres_ffi::xlog_utils::TimestampTz;
use std::collections::HashMap;
use std::io;
use std::io::Read;
use std::io::{self, Cursor};
use std::str;
pub type Oid = u32;
@@ -21,7 +21,7 @@ pub const TEXT_OID: Oid = 25;
#[derive(Debug)]
pub enum FeMessage {
StartupMessage(FeStartupMessage),
StartupPacket(FeStartupPacket),
Query(FeQueryMessage), // Simple query
Parse(FeParseMessage), // Extended query protocol
Describe(FeDescribeMessage),
@@ -37,19 +37,15 @@ pub enum FeMessage {
}
#[derive(Debug)]
pub struct FeStartupMessage {
pub version: u32,
pub kind: StartupRequestCode,
// optional params arriving in startup packet
pub params: HashMap<String, String>,
}
#[derive(Debug)]
pub enum StartupRequestCode {
Cancel,
NegotiateSsl,
NegotiateGss,
Normal,
pub enum FeStartupPacket {
CancelRequest(CancelKeyData),
SslRequest,
GssEncRequest,
StartupMessage {
major_version: u32,
minor_version: u32,
params: HashMap<String, String>,
},
}
#[derive(Debug)]
@@ -153,13 +149,14 @@ impl FeMessage {
}
}
impl FeStartupMessage {
impl FeStartupPacket {
/// Read startup message from the stream.
pub fn read(stream: &mut impl std::io::Read) -> anyhow::Result<Option<FeMessage>> {
const MAX_STARTUP_PACKET_LENGTH: usize = 10000;
const CANCEL_REQUEST_CODE: u32 = (1234 << 16) | 5678;
const NEGOTIATE_SSL_CODE: u32 = (1234 << 16) | 5679;
const NEGOTIATE_GSS_CODE: u32 = (1234 << 16) | 5680;
const RESERVED_INVALID_MAJOR_VERSION: u32 = 1234;
const CANCEL_REQUEST_CODE: u32 = 5678;
const NEGOTIATE_SSL_CODE: u32 = 5679;
const NEGOTIATE_GSS_CODE: u32 = 5680;
// Read length. If the connection is closed before reading anything (or before
// reading 4 bytes, to be precise), return None to indicate that the connection
@@ -175,44 +172,60 @@ impl FeStartupMessage {
bail!("invalid message length");
}
let version = stream.read_u32::<BE>()?;
let kind = match version {
CANCEL_REQUEST_CODE => StartupRequestCode::Cancel,
NEGOTIATE_SSL_CODE => StartupRequestCode::NegotiateSsl,
NEGOTIATE_GSS_CODE => StartupRequestCode::NegotiateGss,
_ => StartupRequestCode::Normal,
};
let request_code = stream.read_u32::<BE>()?;
// the rest of startup packet are params
let params_len = len - 8;
let mut params_bytes = vec![0u8; params_len];
stream.read_exact(params_bytes.as_mut())?;
// Then null-terminated (String) pairs of param name / param value go.
let params_str = str::from_utf8(&params_bytes).unwrap();
let params = params_str.split('\0');
let mut params_hash: HashMap<String, String> = HashMap::new();
for pair in params.collect::<Vec<_>>().chunks_exact(2) {
let name = pair[0];
let value = pair[1];
if name == "options" {
// deprecated way of passing params as cmd line args
for cmdopt in value.split(' ') {
let nameval: Vec<&str> = cmdopt.split('=').collect();
if nameval.len() == 2 {
params_hash.insert(nameval[0].to_string(), nameval[1].to_string());
// Parse params depending on request code
let most_sig_16_bits = request_code >> 16;
let least_sig_16_bits = request_code & ((1 << 16) - 1);
let message = match (most_sig_16_bits, least_sig_16_bits) {
(RESERVED_INVALID_MAJOR_VERSION, CANCEL_REQUEST_CODE) => {
ensure!(params_len == 8, "expected 8 bytes for CancelRequest params");
let mut cursor = Cursor::new(params_bytes);
FeStartupPacket::CancelRequest(CancelKeyData {
backend_pid: cursor.read_i32::<BigEndian>()?,
cancel_key: cursor.read_i32::<BigEndian>()?,
})
}
(RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_SSL_CODE) => FeStartupPacket::SslRequest,
(RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_GSS_CODE) => FeStartupPacket::GssEncRequest,
(RESERVED_INVALID_MAJOR_VERSION, unrecognized_code) => {
bail!("Unrecognized request code {}", unrecognized_code)
}
(major_version, minor_version) => {
// TODO bail if protocol major_version is not 3?
// Parse null-terminated (String) pairs of param name / param value
let params_str = str::from_utf8(&params_bytes).unwrap();
let mut params_tokens = params_str.split('\0');
let mut params: HashMap<String, String> = HashMap::new();
while let Some(name) = params_tokens.next() {
let value = params_tokens.next().ok_or_else(|| {
anyhow!("expected even number of params in StartupMessage")
})?;
if name == "options" {
// deprecated way of passing params as cmd line args
for cmdopt in value.split(' ') {
let nameval: Vec<&str> = cmdopt.split('=').collect();
if nameval.len() == 2 {
params.insert(nameval[0].to_string(), nameval[1].to_string());
}
}
} else {
params.insert(name.to_string(), value.to_string());
}
}
} else {
params_hash.insert(name.to_string(), value.to_string());
FeStartupPacket::StartupMessage {
major_version,
minor_version,
params,
}
}
}
Ok(Some(FeMessage::StartupMessage(FeStartupMessage {
version,
kind,
params: params_hash,
})))
};
Ok(Some(FeMessage::StartupPacket(message)))
}
}