mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-14 08:52:56 +00:00
add some more typesafety to pqproto and move stream.rs to a folder
This commit is contained in:
@@ -11,11 +11,87 @@ use rand::distributions::{Distribution, Standard};
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
use zerocopy::{FromBytes, Immutable, IntoBytes, big_endian};
|
||||
|
||||
pub type ErrorCode = [u8; 5];
|
||||
#[derive(Copy, Clone, PartialEq)]
|
||||
pub struct ErrorCode(pub [u8; 5]);
|
||||
|
||||
pub const FE_PASSWORD_MESSAGE: u8 = b'p';
|
||||
#[derive(Copy, Clone, PartialEq)]
|
||||
pub struct FeTag(pub u8);
|
||||
|
||||
pub const SQLSTATE_INTERNAL_ERROR: [u8; 5] = *b"XX000";
|
||||
#[derive(Copy, Clone, PartialEq)]
|
||||
pub struct BeTag(pub u8);
|
||||
|
||||
#[derive(Copy, Clone, PartialEq)]
|
||||
pub struct AuthTag(pub i32);
|
||||
|
||||
pub const FE_PASSWORD_MESSAGE: FeTag = FeTag(b'p');
|
||||
|
||||
pub const BE_AUTH_MESSAGE: BeTag = BeTag(b'R');
|
||||
pub const BE_ERR_MESSAGE: BeTag = BeTag(b'E');
|
||||
pub const BE_KEY_MESSAGE: BeTag = BeTag(b'K');
|
||||
pub const BE_READY_MESSAGE: BeTag = BeTag(b'Z');
|
||||
pub const BE_NEGOTIATE_MESSAGE: BeTag = BeTag(b'v');
|
||||
|
||||
pub const AUTH_OK: AuthTag = AuthTag(0);
|
||||
pub const AUTH_CLEAR: AuthTag = AuthTag(3);
|
||||
pub const AUTH_SASL: AuthTag = AuthTag(10);
|
||||
pub const AUTH_SASL_CONT: AuthTag = AuthTag(11);
|
||||
pub const AUTH_SASL_FINAL: AuthTag = AuthTag(12);
|
||||
|
||||
pub const SQLSTATE_INTERNAL_ERROR: ErrorCode = ErrorCode(*b"XX000");
|
||||
pub const CONNECTION_EXCEPTION: ErrorCode = ErrorCode(*b"08000");
|
||||
pub const CONNECTION_DOES_NOT_EXIST: ErrorCode = ErrorCode(*b"08003");
|
||||
pub const CONNECTION_FAILURE: ErrorCode = ErrorCode(*b"08006");
|
||||
pub const SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION: ErrorCode = ErrorCode(*b"08001");
|
||||
pub const PROTOCOL_VIOLATION: ErrorCode = ErrorCode(*b"08P01");
|
||||
pub const INVALID_PARAMETER_VALUE: ErrorCode = ErrorCode(*b"22023");
|
||||
pub const INVALID_CATALOG_NAME: ErrorCode = ErrorCode(*b"3D000");
|
||||
pub const INVALID_SCHEMA_NAME: ErrorCode = ErrorCode(*b"3F000");
|
||||
pub const T_R_SERIALIZATION_FAILURE: ErrorCode = ErrorCode(*b"40001");
|
||||
pub const SYNTAX_ERROR: ErrorCode = ErrorCode(*b"42601");
|
||||
pub const OUT_OF_MEMORY: ErrorCode = ErrorCode(*b"53200");
|
||||
pub const TOO_MANY_CONNECTIONS: ErrorCode = ErrorCode(*b"53300");
|
||||
|
||||
impl fmt::Display for AuthTag {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self.0 {
|
||||
0 => f.write_str("Ok"),
|
||||
2 => f.write_str("KerberosV5"),
|
||||
3 => f.write_str("CleartextPassword"),
|
||||
5 => f.write_str("MD5Password"),
|
||||
7 => f.write_str("GSS"),
|
||||
8 => f.write_str("GSSContinue"),
|
||||
9 => f.write_str("SSPI"),
|
||||
10 => f.write_str("SASL"),
|
||||
11 => f.write_str("SASLContinue"),
|
||||
12 => f.write_str("SASLFinal"),
|
||||
x => write!(f, "{x}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for BeTag {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match *self {
|
||||
BE_AUTH_MESSAGE => f.write_str("Authentication"),
|
||||
BE_KEY_MESSAGE => f.write_str("BackendKeyData"),
|
||||
BE_ERR_MESSAGE => f.write_str("ErrorResponse"),
|
||||
BE_READY_MESSAGE => f.write_str("ReadyForQuery"),
|
||||
BE_NEGOTIATE_MESSAGE => f.write_str("NegotiateProtocolVersion"),
|
||||
BeTag(b'S') => f.write_str("ParameterStatus"),
|
||||
BeTag(b'N') => f.write_str("NoticeMessage"),
|
||||
BeTag(x) => write!(f, "{:?}", char::from(x)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for FeTag {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match *self {
|
||||
FE_PASSWORD_MESSAGE => f.write_str("Password"),
|
||||
FeTag(x) => write!(f, "{:?}", char::from(x)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The protocol version number.
|
||||
///
|
||||
@@ -313,6 +389,19 @@ impl WriteBuf {
|
||||
self.0.set_position(0);
|
||||
}
|
||||
|
||||
/// Write a startup message.
|
||||
pub fn startup(&mut self, params: &StartupMessageParams) {
|
||||
self.0.get_mut().extend_from_slice(
|
||||
StartupHeader {
|
||||
len: big_endian::U32::new(params.params.len() as u32 + 9),
|
||||
version: ProtocolVersion::new(3, 0),
|
||||
}
|
||||
.as_bytes(),
|
||||
);
|
||||
self.0.get_mut().extend_from_slice(params.params.as_bytes());
|
||||
self.0.get_mut().push(0);
|
||||
}
|
||||
|
||||
/// Write a raw message to the internal buffer.
|
||||
///
|
||||
/// The size_hint value is only a hint for reserving space. It's ok if it's incorrect, since
|
||||
@@ -353,7 +442,7 @@ impl WriteBuf {
|
||||
|
||||
// Code: error_code
|
||||
buf.put_u8(b'C');
|
||||
buf.put_slice(&error_code);
|
||||
buf.put_slice(&error_code.0);
|
||||
buf.put_u8(0);
|
||||
|
||||
// Message: msg
|
||||
@@ -494,17 +583,17 @@ impl BeMessage<'_> {
|
||||
match self {
|
||||
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-AUTHENTICATIONCLEARTEXTPASSWORD>
|
||||
BeMessage::AuthenticationOk => {
|
||||
buf.write_raw(1, b'R', |buf| buf.put_i32(0));
|
||||
buf.write_raw(1, BE_AUTH_MESSAGE.0, |buf| buf.put_i32(0));
|
||||
}
|
||||
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-AUTHENTICATIONCLEARTEXTPASSWORD>
|
||||
BeMessage::AuthenticationCleartextPassword => {
|
||||
buf.write_raw(1, b'R', |buf| buf.put_i32(3));
|
||||
buf.write_raw(1, BE_AUTH_MESSAGE.0, |buf| buf.put_i32(3));
|
||||
}
|
||||
|
||||
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-AUTHENTICATIONSASL>
|
||||
BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(methods)) => {
|
||||
let len: usize = methods.iter().map(|m| m.len() + 1).sum();
|
||||
buf.write_raw(len + 2, b'R', |buf| {
|
||||
buf.write_raw(len + 2, BE_AUTH_MESSAGE.0, |buf| {
|
||||
buf.put_i32(10); // Specifies that SASL auth method is used.
|
||||
for method in methods {
|
||||
buf.put_slice(method.as_bytes());
|
||||
@@ -515,14 +604,14 @@ impl BeMessage<'_> {
|
||||
}
|
||||
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-AUTHENTICATIONSASL>
|
||||
BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Continue(extra)) => {
|
||||
buf.write_raw(extra.len() + 1, b'R', |buf| {
|
||||
buf.write_raw(extra.len() + 1, BE_AUTH_MESSAGE.0, |buf| {
|
||||
buf.put_i32(11); // Continue SASL auth.
|
||||
buf.put_slice(extra);
|
||||
});
|
||||
}
|
||||
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-AUTHENTICATIONSASL>
|
||||
BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Final(extra)) => {
|
||||
buf.write_raw(extra.len() + 1, b'R', |buf| {
|
||||
buf.write_raw(extra.len() + 1, BE_AUTH_MESSAGE.0, |buf| {
|
||||
buf.put_i32(12); // Send final SASL message.
|
||||
buf.put_slice(extra);
|
||||
});
|
||||
@@ -530,7 +619,9 @@ impl BeMessage<'_> {
|
||||
|
||||
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-BACKENDKEYDATA>
|
||||
BeMessage::BackendKeyData(key_data) => {
|
||||
buf.write_raw(8, b'K', |buf| buf.put_slice(key_data.as_bytes()));
|
||||
buf.write_raw(8, BE_KEY_MESSAGE.0, |buf| {
|
||||
buf.put_slice(key_data.as_bytes())
|
||||
});
|
||||
}
|
||||
|
||||
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-NOTICERESPONSE>
|
||||
@@ -566,13 +657,13 @@ impl BeMessage<'_> {
|
||||
|
||||
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-NEGOTIATEPROTOCOLVERSION>
|
||||
BeMessage::ReadyForQuery => {
|
||||
buf.write_raw(1, b'Z', |buf| buf.put_u8(b'I'));
|
||||
buf.write_raw(1, BE_READY_MESSAGE.0, |buf| buf.put_u8(b'I'));
|
||||
}
|
||||
|
||||
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-NEGOTIATEPROTOCOLVERSION>
|
||||
BeMessage::NegotiateProtocolVersion { version, options } => {
|
||||
let len: usize = options.iter().map(|o| o.len() + 1).sum();
|
||||
buf.write_raw(8 + len, b'v', |buf| {
|
||||
buf.write_raw(8 + len, BE_NEGOTIATE_MESSAGE.0, |buf| {
|
||||
buf.put_slice(version.as_bytes());
|
||||
buf.put_u32(options.len() as u32);
|
||||
for option in options {
|
||||
|
||||
@@ -90,7 +90,7 @@ impl<S: AsyncRead + Unpin> PqStream<S> {
|
||||
// and SASL SCRAM messages are no longer than 256 bytes in my testing
|
||||
// (a few hashes and random bytes, encoded into base64).
|
||||
const MAX_PASSWORD_LENGTH: u32 = 512;
|
||||
self.read_raw_expect(FE_PASSWORD_MESSAGE, MAX_PASSWORD_LENGTH)
|
||||
self.read_raw_expect(FE_PASSWORD_MESSAGE.0, MAX_PASSWORD_LENGTH)
|
||||
.await
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user