mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-10 06:52:55 +00:00
Parse cancel message in pq_proto (#1060)
This commit is contained in:
@@ -113,31 +113,31 @@ impl ProxyConnection {
|
||||
let mut encrypted = false;
|
||||
|
||||
loop {
|
||||
let mut msg = match self.pgb.read_message()? {
|
||||
Some(Fe::StartupMessage(msg)) => msg,
|
||||
let msg = match self.pgb.read_message()? {
|
||||
Some(Fe::StartupPacket(msg)) => msg,
|
||||
None => bail!("connection is lost"),
|
||||
bad => bail!("unexpected message type: {:?}", bad),
|
||||
};
|
||||
println!("got message: {:?}", msg);
|
||||
|
||||
match msg.kind {
|
||||
StartupRequestCode::NegotiateGss => {
|
||||
match msg {
|
||||
FeStartupPacket::GssEncRequest => {
|
||||
self.pgb.write_message(&Be::EncryptionResponse(false))?;
|
||||
}
|
||||
StartupRequestCode::NegotiateSsl => {
|
||||
FeStartupPacket::SslRequest => {
|
||||
self.pgb.write_message(&Be::EncryptionResponse(have_tls))?;
|
||||
if have_tls {
|
||||
self.pgb.start_tls()?;
|
||||
encrypted = true;
|
||||
}
|
||||
}
|
||||
StartupRequestCode::Normal => {
|
||||
FeStartupPacket::StartupMessage { mut params, .. } => {
|
||||
if have_tls && !encrypted {
|
||||
bail!("must connect with TLS");
|
||||
}
|
||||
|
||||
let mut get_param = |key| {
|
||||
msg.params
|
||||
params
|
||||
.remove(key)
|
||||
.ok_or_else(|| anyhow!("{} is missing in startup packet", key))
|
||||
};
|
||||
@@ -145,7 +145,9 @@ impl ProxyConnection {
|
||||
return Ok((get_param("user")?, get_param("database")?));
|
||||
}
|
||||
// TODO: implement proper stmt cancellation
|
||||
StartupRequestCode::Cancel => bail!("query cancellation is not supported"),
|
||||
FeStartupPacket::CancelRequest { .. } => {
|
||||
bail!("query cancellation is not supported")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,7 +15,7 @@ use std::sync::Arc;
|
||||
use zenith_utils::lsn::Lsn;
|
||||
use zenith_utils::postgres_backend;
|
||||
use zenith_utils::postgres_backend::PostgresBackend;
|
||||
use zenith_utils::pq_proto::{BeMessage, FeStartupMessage, RowDescriptor, INT4_OID, TEXT_OID};
|
||||
use zenith_utils::pq_proto::{BeMessage, FeStartupPacket, RowDescriptor, INT4_OID, TEXT_OID};
|
||||
use zenith_utils::zid::{ZTenantId, ZTimelineId};
|
||||
|
||||
use crate::callmemaybe::CallmeEvent;
|
||||
@@ -73,22 +73,26 @@ fn parse_cmd(cmd: &str) -> Result<SafekeeperPostgresCommand> {
|
||||
|
||||
impl postgres_backend::Handler for SafekeeperPostgresHandler {
|
||||
// ztenant id and ztimeline id are passed in connection string params
|
||||
fn startup(&mut self, _pgb: &mut PostgresBackend, sm: &FeStartupMessage) -> Result<()> {
|
||||
self.ztenantid = match sm.params.get("ztenantid") {
|
||||
Some(z) => Some(ZTenantId::from_str(z)?), // just curious, can I do that from .map?
|
||||
_ => None,
|
||||
};
|
||||
fn startup(&mut self, _pgb: &mut PostgresBackend, sm: &FeStartupPacket) -> Result<()> {
|
||||
if let FeStartupPacket::StartupMessage { params, .. } = sm {
|
||||
self.ztenantid = match params.get("ztenantid") {
|
||||
Some(z) => Some(ZTenantId::from_str(z)?), // just curious, can I do that from .map?
|
||||
_ => None,
|
||||
};
|
||||
|
||||
self.ztimelineid = match sm.params.get("ztimelineid") {
|
||||
Some(z) => Some(ZTimelineId::from_str(z)?),
|
||||
_ => None,
|
||||
};
|
||||
self.ztimelineid = match params.get("ztimelineid") {
|
||||
Some(z) => Some(ZTimelineId::from_str(z)?),
|
||||
_ => None,
|
||||
};
|
||||
|
||||
if let Some(app_name) = sm.params.get("application_name") {
|
||||
self.appname = Some(app_name.clone());
|
||||
if let Some(app_name) = params.get("application_name") {
|
||||
self.appname = Some(app_name.clone());
|
||||
}
|
||||
|
||||
Ok(())
|
||||
} else {
|
||||
bail!("Walkeeper received unexpected initial message: {:?}", sm);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn process_query(&mut self, pgb: &mut PostgresBackend, query_string: &str) -> Result<()> {
|
||||
|
||||
@@ -3,9 +3,7 @@
|
||||
//! implementation determining how to process the queries. Currently its API
|
||||
//! is rather narrow, but we can extend it once required.
|
||||
|
||||
use crate::pq_proto::{
|
||||
BeMessage, BeParameterStatusMessage, FeMessage, FeStartupMessage, StartupRequestCode,
|
||||
};
|
||||
use crate::pq_proto::{BeMessage, BeParameterStatusMessage, FeMessage, FeStartupPacket};
|
||||
use crate::sock_split::{BidiStream, ReadStream, WriteStream};
|
||||
use anyhow::{anyhow, bail, ensure, Result};
|
||||
use bytes::{Bytes, BytesMut};
|
||||
@@ -34,7 +32,7 @@ pub trait Handler {
|
||||
/// If Ok(false) is returned postgres_backend will skip auth -- that is needed for new users
|
||||
/// creation is the proxy code. That is quite hacky and ad-hoc solution, may be we could allow
|
||||
/// to override whole init logic in implementations.
|
||||
fn startup(&mut self, _pgb: &mut PostgresBackend, _sm: &FeStartupMessage) -> Result<()> {
|
||||
fn startup(&mut self, _pgb: &mut PostgresBackend, _sm: &FeStartupPacket) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -237,7 +235,7 @@ impl PostgresBackend {
|
||||
|
||||
use ProtoState::*;
|
||||
match state {
|
||||
Initialization | Encrypted => FeStartupMessage::read(stream),
|
||||
Initialization | Encrypted => FeStartupPacket::read(stream),
|
||||
Authentication | Established => FeMessage::read(stream),
|
||||
}
|
||||
}
|
||||
@@ -329,7 +327,7 @@ impl PostgresBackend {
|
||||
ensure!(
|
||||
matches!(
|
||||
msg,
|
||||
FeMessage::PasswordMessage(_) | FeMessage::StartupMessage(_)
|
||||
FeMessage::PasswordMessage(_) | FeMessage::StartupPacket(_)
|
||||
),
|
||||
"protocol violation"
|
||||
);
|
||||
@@ -337,11 +335,11 @@ impl PostgresBackend {
|
||||
|
||||
let have_tls = self.tls_config.is_some();
|
||||
match msg {
|
||||
FeMessage::StartupMessage(m) => {
|
||||
FeMessage::StartupPacket(m) => {
|
||||
trace!("got startup message {:?}", m);
|
||||
|
||||
match m.kind {
|
||||
StartupRequestCode::NegotiateSsl => {
|
||||
match m {
|
||||
FeStartupPacket::SslRequest => {
|
||||
info!("SSL requested");
|
||||
|
||||
self.write_message(&BeMessage::EncryptionResponse(have_tls))?;
|
||||
@@ -350,11 +348,11 @@ impl PostgresBackend {
|
||||
self.state = ProtoState::Encrypted;
|
||||
}
|
||||
}
|
||||
StartupRequestCode::NegotiateGss => {
|
||||
FeStartupPacket::GssEncRequest => {
|
||||
info!("GSS requested");
|
||||
self.write_message(&BeMessage::EncryptionResponse(false))?;
|
||||
}
|
||||
StartupRequestCode::Normal => {
|
||||
FeStartupPacket::StartupMessage { .. } => {
|
||||
if have_tls && !matches!(self.state, ProtoState::Encrypted) {
|
||||
self.write_message(&BeMessage::ErrorResponse(
|
||||
"must connect with TLS".to_string(),
|
||||
@@ -387,7 +385,7 @@ impl PostgresBackend {
|
||||
}
|
||||
}
|
||||
}
|
||||
StartupRequestCode::Cancel => {
|
||||
FeStartupPacket::CancelRequest { .. } => {
|
||||
return Ok(ProcessMsgResult::Break);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,14 +2,14 @@
|
||||
//! <https://www.postgresql.org/docs/devel/protocol-message-formats.html>
|
||||
//! on message formats.
|
||||
|
||||
use anyhow::{anyhow, bail, Result};
|
||||
use anyhow::{anyhow, bail, ensure, Result};
|
||||
use byteorder::{BigEndian, ByteOrder};
|
||||
use byteorder::{ReadBytesExt, BE};
|
||||
use bytes::{Buf, BufMut, Bytes, BytesMut};
|
||||
// use postgres_ffi::xlog_utils::TimestampTz;
|
||||
use std::collections::HashMap;
|
||||
use std::io;
|
||||
use std::io::Read;
|
||||
use std::io::{self, Cursor};
|
||||
use std::str;
|
||||
|
||||
pub type Oid = u32;
|
||||
@@ -21,7 +21,7 @@ pub const TEXT_OID: Oid = 25;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum FeMessage {
|
||||
StartupMessage(FeStartupMessage),
|
||||
StartupPacket(FeStartupPacket),
|
||||
Query(FeQueryMessage), // Simple query
|
||||
Parse(FeParseMessage), // Extended query protocol
|
||||
Describe(FeDescribeMessage),
|
||||
@@ -37,19 +37,15 @@ pub enum FeMessage {
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct FeStartupMessage {
|
||||
pub version: u32,
|
||||
pub kind: StartupRequestCode,
|
||||
// optional params arriving in startup packet
|
||||
pub params: HashMap<String, String>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum StartupRequestCode {
|
||||
Cancel,
|
||||
NegotiateSsl,
|
||||
NegotiateGss,
|
||||
Normal,
|
||||
pub enum FeStartupPacket {
|
||||
CancelRequest(CancelKeyData),
|
||||
SslRequest,
|
||||
GssEncRequest,
|
||||
StartupMessage {
|
||||
major_version: u32,
|
||||
minor_version: u32,
|
||||
params: HashMap<String, String>,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
@@ -153,13 +149,14 @@ impl FeMessage {
|
||||
}
|
||||
}
|
||||
|
||||
impl FeStartupMessage {
|
||||
impl FeStartupPacket {
|
||||
/// Read startup message from the stream.
|
||||
pub fn read(stream: &mut impl std::io::Read) -> anyhow::Result<Option<FeMessage>> {
|
||||
const MAX_STARTUP_PACKET_LENGTH: usize = 10000;
|
||||
const CANCEL_REQUEST_CODE: u32 = (1234 << 16) | 5678;
|
||||
const NEGOTIATE_SSL_CODE: u32 = (1234 << 16) | 5679;
|
||||
const NEGOTIATE_GSS_CODE: u32 = (1234 << 16) | 5680;
|
||||
const RESERVED_INVALID_MAJOR_VERSION: u32 = 1234;
|
||||
const CANCEL_REQUEST_CODE: u32 = 5678;
|
||||
const NEGOTIATE_SSL_CODE: u32 = 5679;
|
||||
const NEGOTIATE_GSS_CODE: u32 = 5680;
|
||||
|
||||
// Read length. If the connection is closed before reading anything (or before
|
||||
// reading 4 bytes, to be precise), return None to indicate that the connection
|
||||
@@ -175,44 +172,60 @@ impl FeStartupMessage {
|
||||
bail!("invalid message length");
|
||||
}
|
||||
|
||||
let version = stream.read_u32::<BE>()?;
|
||||
let kind = match version {
|
||||
CANCEL_REQUEST_CODE => StartupRequestCode::Cancel,
|
||||
NEGOTIATE_SSL_CODE => StartupRequestCode::NegotiateSsl,
|
||||
NEGOTIATE_GSS_CODE => StartupRequestCode::NegotiateGss,
|
||||
_ => StartupRequestCode::Normal,
|
||||
};
|
||||
let request_code = stream.read_u32::<BE>()?;
|
||||
|
||||
// the rest of startup packet are params
|
||||
let params_len = len - 8;
|
||||
let mut params_bytes = vec![0u8; params_len];
|
||||
stream.read_exact(params_bytes.as_mut())?;
|
||||
|
||||
// Then null-terminated (String) pairs of param name / param value go.
|
||||
let params_str = str::from_utf8(¶ms_bytes).unwrap();
|
||||
let params = params_str.split('\0');
|
||||
let mut params_hash: HashMap<String, String> = HashMap::new();
|
||||
for pair in params.collect::<Vec<_>>().chunks_exact(2) {
|
||||
let name = pair[0];
|
||||
let value = pair[1];
|
||||
if name == "options" {
|
||||
// deprecated way of passing params as cmd line args
|
||||
for cmdopt in value.split(' ') {
|
||||
let nameval: Vec<&str> = cmdopt.split('=').collect();
|
||||
if nameval.len() == 2 {
|
||||
params_hash.insert(nameval[0].to_string(), nameval[1].to_string());
|
||||
// Parse params depending on request code
|
||||
let most_sig_16_bits = request_code >> 16;
|
||||
let least_sig_16_bits = request_code & ((1 << 16) - 1);
|
||||
let message = match (most_sig_16_bits, least_sig_16_bits) {
|
||||
(RESERVED_INVALID_MAJOR_VERSION, CANCEL_REQUEST_CODE) => {
|
||||
ensure!(params_len == 8, "expected 8 bytes for CancelRequest params");
|
||||
let mut cursor = Cursor::new(params_bytes);
|
||||
FeStartupPacket::CancelRequest(CancelKeyData {
|
||||
backend_pid: cursor.read_i32::<BigEndian>()?,
|
||||
cancel_key: cursor.read_i32::<BigEndian>()?,
|
||||
})
|
||||
}
|
||||
(RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_SSL_CODE) => FeStartupPacket::SslRequest,
|
||||
(RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_GSS_CODE) => FeStartupPacket::GssEncRequest,
|
||||
(RESERVED_INVALID_MAJOR_VERSION, unrecognized_code) => {
|
||||
bail!("Unrecognized request code {}", unrecognized_code)
|
||||
}
|
||||
(major_version, minor_version) => {
|
||||
// TODO bail if protocol major_version is not 3?
|
||||
// Parse null-terminated (String) pairs of param name / param value
|
||||
let params_str = str::from_utf8(¶ms_bytes).unwrap();
|
||||
let mut params_tokens = params_str.split('\0');
|
||||
let mut params: HashMap<String, String> = HashMap::new();
|
||||
while let Some(name) = params_tokens.next() {
|
||||
let value = params_tokens.next().ok_or_else(|| {
|
||||
anyhow!("expected even number of params in StartupMessage")
|
||||
})?;
|
||||
if name == "options" {
|
||||
// deprecated way of passing params as cmd line args
|
||||
for cmdopt in value.split(' ') {
|
||||
let nameval: Vec<&str> = cmdopt.split('=').collect();
|
||||
if nameval.len() == 2 {
|
||||
params.insert(nameval[0].to_string(), nameval[1].to_string());
|
||||
}
|
||||
}
|
||||
} else {
|
||||
params.insert(name.to_string(), value.to_string());
|
||||
}
|
||||
}
|
||||
} else {
|
||||
params_hash.insert(name.to_string(), value.to_string());
|
||||
FeStartupPacket::StartupMessage {
|
||||
major_version,
|
||||
minor_version,
|
||||
params,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Some(FeMessage::StartupMessage(FeStartupMessage {
|
||||
version,
|
||||
kind,
|
||||
params: params_hash,
|
||||
})))
|
||||
};
|
||||
Ok(Some(FeMessage::StartupPacket(message)))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user