add some more typesafety to pqproto and move stream.rs to a folder

This commit is contained in:
Conrad Ludgate
2025-06-10 14:34:26 -07:00
parent b509982bbf
commit 72b1c573b1
2 changed files with 104 additions and 13 deletions

View File

@@ -11,11 +11,87 @@ use rand::distributions::{Distribution, Standard};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use zerocopy::{FromBytes, Immutable, IntoBytes, big_endian};
pub type ErrorCode = [u8; 5];
#[derive(Copy, Clone, PartialEq)]
pub struct ErrorCode(pub [u8; 5]);
pub const FE_PASSWORD_MESSAGE: u8 = b'p';
#[derive(Copy, Clone, PartialEq)]
pub struct FeTag(pub u8);
pub const SQLSTATE_INTERNAL_ERROR: [u8; 5] = *b"XX000";
#[derive(Copy, Clone, PartialEq)]
pub struct BeTag(pub u8);
#[derive(Copy, Clone, PartialEq)]
pub struct AuthTag(pub i32);
pub const FE_PASSWORD_MESSAGE: FeTag = FeTag(b'p');
pub const BE_AUTH_MESSAGE: BeTag = BeTag(b'R');
pub const BE_ERR_MESSAGE: BeTag = BeTag(b'E');
pub const BE_KEY_MESSAGE: BeTag = BeTag(b'K');
pub const BE_READY_MESSAGE: BeTag = BeTag(b'Z');
pub const BE_NEGOTIATE_MESSAGE: BeTag = BeTag(b'v');
pub const AUTH_OK: AuthTag = AuthTag(0);
pub const AUTH_CLEAR: AuthTag = AuthTag(3);
pub const AUTH_SASL: AuthTag = AuthTag(10);
pub const AUTH_SASL_CONT: AuthTag = AuthTag(11);
pub const AUTH_SASL_FINAL: AuthTag = AuthTag(12);
pub const SQLSTATE_INTERNAL_ERROR: ErrorCode = ErrorCode(*b"XX000");
pub const CONNECTION_EXCEPTION: ErrorCode = ErrorCode(*b"08000");
pub const CONNECTION_DOES_NOT_EXIST: ErrorCode = ErrorCode(*b"08003");
pub const CONNECTION_FAILURE: ErrorCode = ErrorCode(*b"08006");
pub const SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION: ErrorCode = ErrorCode(*b"08001");
pub const PROTOCOL_VIOLATION: ErrorCode = ErrorCode(*b"08P01");
pub const INVALID_PARAMETER_VALUE: ErrorCode = ErrorCode(*b"22023");
pub const INVALID_CATALOG_NAME: ErrorCode = ErrorCode(*b"3D000");
pub const INVALID_SCHEMA_NAME: ErrorCode = ErrorCode(*b"3F000");
pub const T_R_SERIALIZATION_FAILURE: ErrorCode = ErrorCode(*b"40001");
pub const SYNTAX_ERROR: ErrorCode = ErrorCode(*b"42601");
pub const OUT_OF_MEMORY: ErrorCode = ErrorCode(*b"53200");
pub const TOO_MANY_CONNECTIONS: ErrorCode = ErrorCode(*b"53300");
impl fmt::Display for AuthTag {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self.0 {
0 => f.write_str("Ok"),
2 => f.write_str("KerberosV5"),
3 => f.write_str("CleartextPassword"),
5 => f.write_str("MD5Password"),
7 => f.write_str("GSS"),
8 => f.write_str("GSSContinue"),
9 => f.write_str("SSPI"),
10 => f.write_str("SASL"),
11 => f.write_str("SASLContinue"),
12 => f.write_str("SASLFinal"),
x => write!(f, "{x}"),
}
}
}
impl fmt::Display for BeTag {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match *self {
BE_AUTH_MESSAGE => f.write_str("Authentication"),
BE_KEY_MESSAGE => f.write_str("BackendKeyData"),
BE_ERR_MESSAGE => f.write_str("ErrorResponse"),
BE_READY_MESSAGE => f.write_str("ReadyForQuery"),
BE_NEGOTIATE_MESSAGE => f.write_str("NegotiateProtocolVersion"),
BeTag(b'S') => f.write_str("ParameterStatus"),
BeTag(b'N') => f.write_str("NoticeMessage"),
BeTag(x) => write!(f, "{:?}", char::from(x)),
}
}
}
impl fmt::Display for FeTag {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match *self {
FE_PASSWORD_MESSAGE => f.write_str("Password"),
FeTag(x) => write!(f, "{:?}", char::from(x)),
}
}
}
/// The protocol version number.
///
@@ -313,6 +389,19 @@ impl WriteBuf {
self.0.set_position(0);
}
/// Write a startup message.
pub fn startup(&mut self, params: &StartupMessageParams) {
self.0.get_mut().extend_from_slice(
StartupHeader {
len: big_endian::U32::new(params.params.len() as u32 + 9),
version: ProtocolVersion::new(3, 0),
}
.as_bytes(),
);
self.0.get_mut().extend_from_slice(params.params.as_bytes());
self.0.get_mut().push(0);
}
/// Write a raw message to the internal buffer.
///
/// The size_hint value is only a hint for reserving space. It's ok if it's incorrect, since
@@ -353,7 +442,7 @@ impl WriteBuf {
// Code: error_code
buf.put_u8(b'C');
buf.put_slice(&error_code);
buf.put_slice(&error_code.0);
buf.put_u8(0);
// Message: msg
@@ -494,17 +583,17 @@ impl BeMessage<'_> {
match self {
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-AUTHENTICATIONCLEARTEXTPASSWORD>
BeMessage::AuthenticationOk => {
buf.write_raw(1, b'R', |buf| buf.put_i32(0));
buf.write_raw(1, BE_AUTH_MESSAGE.0, |buf| buf.put_i32(0));
}
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-AUTHENTICATIONCLEARTEXTPASSWORD>
BeMessage::AuthenticationCleartextPassword => {
buf.write_raw(1, b'R', |buf| buf.put_i32(3));
buf.write_raw(1, BE_AUTH_MESSAGE.0, |buf| buf.put_i32(3));
}
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-AUTHENTICATIONSASL>
BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(methods)) => {
let len: usize = methods.iter().map(|m| m.len() + 1).sum();
buf.write_raw(len + 2, b'R', |buf| {
buf.write_raw(len + 2, BE_AUTH_MESSAGE.0, |buf| {
buf.put_i32(10); // Specifies that SASL auth method is used.
for method in methods {
buf.put_slice(method.as_bytes());
@@ -515,14 +604,14 @@ impl BeMessage<'_> {
}
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-AUTHENTICATIONSASL>
BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Continue(extra)) => {
buf.write_raw(extra.len() + 1, b'R', |buf| {
buf.write_raw(extra.len() + 1, BE_AUTH_MESSAGE.0, |buf| {
buf.put_i32(11); // Continue SASL auth.
buf.put_slice(extra);
});
}
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-AUTHENTICATIONSASL>
BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Final(extra)) => {
buf.write_raw(extra.len() + 1, b'R', |buf| {
buf.write_raw(extra.len() + 1, BE_AUTH_MESSAGE.0, |buf| {
buf.put_i32(12); // Send final SASL message.
buf.put_slice(extra);
});
@@ -530,7 +619,9 @@ impl BeMessage<'_> {
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-BACKENDKEYDATA>
BeMessage::BackendKeyData(key_data) => {
buf.write_raw(8, b'K', |buf| buf.put_slice(key_data.as_bytes()));
buf.write_raw(8, BE_KEY_MESSAGE.0, |buf| {
buf.put_slice(key_data.as_bytes())
});
}
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-NOTICERESPONSE>
@@ -566,13 +657,13 @@ impl BeMessage<'_> {
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-NEGOTIATEPROTOCOLVERSION>
BeMessage::ReadyForQuery => {
buf.write_raw(1, b'Z', |buf| buf.put_u8(b'I'));
buf.write_raw(1, BE_READY_MESSAGE.0, |buf| buf.put_u8(b'I'));
}
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-NEGOTIATEPROTOCOLVERSION>
BeMessage::NegotiateProtocolVersion { version, options } => {
let len: usize = options.iter().map(|o| o.len() + 1).sum();
buf.write_raw(8 + len, b'v', |buf| {
buf.write_raw(8 + len, BE_NEGOTIATE_MESSAGE.0, |buf| {
buf.put_slice(version.as_bytes());
buf.put_u32(options.len() as u32);
for option in options {

View File

@@ -90,7 +90,7 @@ impl<S: AsyncRead + Unpin> PqStream<S> {
// and SASL SCRAM messages are no longer than 256 bytes in my testing
// (a few hashes and random bytes, encoded into base64).
const MAX_PASSWORD_LENGTH: u32 = 512;
self.read_raw_expect(FE_PASSWORD_MESSAGE, MAX_PASSWORD_LENGTH)
self.read_raw_expect(FE_PASSWORD_MESSAGE.0, MAX_PASSWORD_LENGTH)
.await
}
}