diff --git a/proxy/src/pqproto.rs b/proxy/src/pqproto.rs index ad99eecda5..3af7560b7c 100644 --- a/proxy/src/pqproto.rs +++ b/proxy/src/pqproto.rs @@ -11,11 +11,87 @@ use rand::distributions::{Distribution, Standard}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use zerocopy::{FromBytes, Immutable, IntoBytes, big_endian}; -pub type ErrorCode = [u8; 5]; +#[derive(Copy, Clone, PartialEq)] +pub struct ErrorCode(pub [u8; 5]); -pub const FE_PASSWORD_MESSAGE: u8 = b'p'; +#[derive(Copy, Clone, PartialEq)] +pub struct FeTag(pub u8); -pub const SQLSTATE_INTERNAL_ERROR: [u8; 5] = *b"XX000"; +#[derive(Copy, Clone, PartialEq)] +pub struct BeTag(pub u8); + +#[derive(Copy, Clone, PartialEq)] +pub struct AuthTag(pub i32); + +pub const FE_PASSWORD_MESSAGE: FeTag = FeTag(b'p'); + +pub const BE_AUTH_MESSAGE: BeTag = BeTag(b'R'); +pub const BE_ERR_MESSAGE: BeTag = BeTag(b'E'); +pub const BE_KEY_MESSAGE: BeTag = BeTag(b'K'); +pub const BE_READY_MESSAGE: BeTag = BeTag(b'Z'); +pub const BE_NEGOTIATE_MESSAGE: BeTag = BeTag(b'v'); + +pub const AUTH_OK: AuthTag = AuthTag(0); +pub const AUTH_CLEAR: AuthTag = AuthTag(3); +pub const AUTH_SASL: AuthTag = AuthTag(10); +pub const AUTH_SASL_CONT: AuthTag = AuthTag(11); +pub const AUTH_SASL_FINAL: AuthTag = AuthTag(12); + +pub const SQLSTATE_INTERNAL_ERROR: ErrorCode = ErrorCode(*b"XX000"); +pub const CONNECTION_EXCEPTION: ErrorCode = ErrorCode(*b"08000"); +pub const CONNECTION_DOES_NOT_EXIST: ErrorCode = ErrorCode(*b"08003"); +pub const CONNECTION_FAILURE: ErrorCode = ErrorCode(*b"08006"); +pub const SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION: ErrorCode = ErrorCode(*b"08001"); +pub const PROTOCOL_VIOLATION: ErrorCode = ErrorCode(*b"08P01"); +pub const INVALID_PARAMETER_VALUE: ErrorCode = ErrorCode(*b"22023"); +pub const INVALID_CATALOG_NAME: ErrorCode = ErrorCode(*b"3D000"); +pub const INVALID_SCHEMA_NAME: ErrorCode = ErrorCode(*b"3F000"); +pub const T_R_SERIALIZATION_FAILURE: ErrorCode = ErrorCode(*b"40001"); +pub const SYNTAX_ERROR: ErrorCode = ErrorCode(*b"42601"); +pub const OUT_OF_MEMORY: ErrorCode = ErrorCode(*b"53200"); +pub const TOO_MANY_CONNECTIONS: ErrorCode = ErrorCode(*b"53300"); + +impl fmt::Display for AuthTag { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.0 { + 0 => f.write_str("Ok"), + 2 => f.write_str("KerberosV5"), + 3 => f.write_str("CleartextPassword"), + 5 => f.write_str("MD5Password"), + 7 => f.write_str("GSS"), + 8 => f.write_str("GSSContinue"), + 9 => f.write_str("SSPI"), + 10 => f.write_str("SASL"), + 11 => f.write_str("SASLContinue"), + 12 => f.write_str("SASLFinal"), + x => write!(f, "{x}"), + } + } +} + +impl fmt::Display for BeTag { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + BE_AUTH_MESSAGE => f.write_str("Authentication"), + BE_KEY_MESSAGE => f.write_str("BackendKeyData"), + BE_ERR_MESSAGE => f.write_str("ErrorResponse"), + BE_READY_MESSAGE => f.write_str("ReadyForQuery"), + BE_NEGOTIATE_MESSAGE => f.write_str("NegotiateProtocolVersion"), + BeTag(b'S') => f.write_str("ParameterStatus"), + BeTag(b'N') => f.write_str("NoticeMessage"), + BeTag(x) => write!(f, "{:?}", char::from(x)), + } + } +} + +impl fmt::Display for FeTag { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + FE_PASSWORD_MESSAGE => f.write_str("Password"), + FeTag(x) => write!(f, "{:?}", char::from(x)), + } + } +} /// The protocol version number. /// @@ -313,6 +389,19 @@ impl WriteBuf { self.0.set_position(0); } + /// Write a startup message. + pub fn startup(&mut self, params: &StartupMessageParams) { + self.0.get_mut().extend_from_slice( + StartupHeader { + len: big_endian::U32::new(params.params.len() as u32 + 9), + version: ProtocolVersion::new(3, 0), + } + .as_bytes(), + ); + self.0.get_mut().extend_from_slice(params.params.as_bytes()); + self.0.get_mut().push(0); + } + /// Write a raw message to the internal buffer. /// /// The size_hint value is only a hint for reserving space. It's ok if it's incorrect, since @@ -353,7 +442,7 @@ impl WriteBuf { // Code: error_code buf.put_u8(b'C'); - buf.put_slice(&error_code); + buf.put_slice(&error_code.0); buf.put_u8(0); // Message: msg @@ -494,17 +583,17 @@ impl BeMessage<'_> { match self { // BeMessage::AuthenticationOk => { - buf.write_raw(1, b'R', |buf| buf.put_i32(0)); + buf.write_raw(1, BE_AUTH_MESSAGE.0, |buf| buf.put_i32(0)); } // BeMessage::AuthenticationCleartextPassword => { - buf.write_raw(1, b'R', |buf| buf.put_i32(3)); + buf.write_raw(1, BE_AUTH_MESSAGE.0, |buf| buf.put_i32(3)); } // BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(methods)) => { let len: usize = methods.iter().map(|m| m.len() + 1).sum(); - buf.write_raw(len + 2, b'R', |buf| { + buf.write_raw(len + 2, BE_AUTH_MESSAGE.0, |buf| { buf.put_i32(10); // Specifies that SASL auth method is used. for method in methods { buf.put_slice(method.as_bytes()); @@ -515,14 +604,14 @@ impl BeMessage<'_> { } // BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Continue(extra)) => { - buf.write_raw(extra.len() + 1, b'R', |buf| { + buf.write_raw(extra.len() + 1, BE_AUTH_MESSAGE.0, |buf| { buf.put_i32(11); // Continue SASL auth. buf.put_slice(extra); }); } // BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Final(extra)) => { - buf.write_raw(extra.len() + 1, b'R', |buf| { + buf.write_raw(extra.len() + 1, BE_AUTH_MESSAGE.0, |buf| { buf.put_i32(12); // Send final SASL message. buf.put_slice(extra); }); @@ -530,7 +619,9 @@ impl BeMessage<'_> { // BeMessage::BackendKeyData(key_data) => { - buf.write_raw(8, b'K', |buf| buf.put_slice(key_data.as_bytes())); + buf.write_raw(8, BE_KEY_MESSAGE.0, |buf| { + buf.put_slice(key_data.as_bytes()) + }); } // @@ -566,13 +657,13 @@ impl BeMessage<'_> { // BeMessage::ReadyForQuery => { - buf.write_raw(1, b'Z', |buf| buf.put_u8(b'I')); + buf.write_raw(1, BE_READY_MESSAGE.0, |buf| buf.put_u8(b'I')); } // BeMessage::NegotiateProtocolVersion { version, options } => { let len: usize = options.iter().map(|o| o.len() + 1).sum(); - buf.write_raw(8 + len, b'v', |buf| { + buf.write_raw(8 + len, BE_NEGOTIATE_MESSAGE.0, |buf| { buf.put_slice(version.as_bytes()); buf.put_u32(options.len() as u32); for option in options { diff --git a/proxy/src/stream.rs b/proxy/src/stream/mod.rs similarity index 99% rename from proxy/src/stream.rs rename to proxy/src/stream/mod.rs index c49a431c95..676bd26116 100644 --- a/proxy/src/stream.rs +++ b/proxy/src/stream/mod.rs @@ -90,7 +90,7 @@ impl PqStream { // and SASL SCRAM messages are no longer than 256 bytes in my testing // (a few hashes and random bytes, encoded into base64). const MAX_PASSWORD_LENGTH: u32 = 512; - self.read_raw_expect(FE_PASSWORD_MESSAGE, MAX_PASSWORD_LENGTH) + self.read_raw_expect(FE_PASSWORD_MESSAGE.0, MAX_PASSWORD_LENGTH) .await } }