From 2e047b64fdfb4ef911e3550a17859e5249aef73e Mon Sep 17 00:00:00 2001 From: Dmitry Ivanov Date: Thu, 15 Dec 2022 15:38:22 +0300 Subject: [PATCH] [WIP] Implement proper severity levels in pq_proto's ErrorResponse --- libs/pq_proto/src/lib.rs | 14 +++++++------- libs/utils/src/postgres_backend.rs | 8 ++++---- libs/utils/src/postgres_backend_async.rs | 8 ++++---- proxy/src/auth/backend/link.rs | 4 ++-- proxy/src/auth/flow.rs | 4 ++-- proxy/src/proxy.rs | 4 ++-- proxy/src/proxy/tests.rs | 2 +- proxy/src/sasl/messages.rs | 7 +++---- 8 files changed, 25 insertions(+), 26 deletions(-) diff --git a/libs/pq_proto/src/lib.rs b/libs/pq_proto/src/lib.rs index 2e311dd6e3..dbdecb89c7 100644 --- a/libs/pq_proto/src/lib.rs +++ b/libs/pq_proto/src/lib.rs @@ -444,7 +444,7 @@ impl FeCloseMessage { pub enum BeMessage<'a> { AuthenticationOk, AuthenticationMD5Password([u8; 4]), - AuthenticationSasl(BeAuthenticationSaslMessage<'a>), + AuthenticationSasl(SaslMessage<'a>), AuthenticationCleartextPassword, BackendKeyData(CancelKeyData), BindComplete, @@ -463,7 +463,7 @@ pub enum BeMessage<'a> { EncryptionResponse(bool), NoData, ParameterDescription, - ParameterStatus(BeParameterStatusMessage<'a>), + ParameterStatus(ParameterStatusMessage<'a>), ParseComplete, ReadyForQuery, RowDescription(&'a [RowDescriptor<'a>]), @@ -473,19 +473,19 @@ pub enum BeMessage<'a> { } #[derive(Debug)] -pub enum BeAuthenticationSaslMessage<'a> { +pub enum SaslMessage<'a> { Methods(&'a [&'a str]), Continue(&'a [u8]), Final(&'a [u8]), } #[derive(Debug)] -pub enum BeParameterStatusMessage<'a> { +pub enum ParameterStatusMessage<'a> { Encoding(&'a str), ServerVersion(&'a str), } -impl BeParameterStatusMessage<'static> { +impl ParameterStatusMessage<'static> { pub fn encoding() -> BeMessage<'static> { BeMessage::ParameterStatus(Self::Encoding("UTF8")) } @@ -639,7 +639,7 @@ impl<'a> BeMessage<'a> { BeMessage::AuthenticationSasl(msg) => { buf.put_u8(b'R'); write_body(buf, |buf| { - use BeAuthenticationSaslMessage::*; + use SaslMessage::*; match msg { Methods(methods) => { buf.put_i32(10); // Specifies that SASL auth method is used. @@ -801,7 +801,7 @@ impl<'a> BeMessage<'a> { BeMessage::ParameterStatus(param) => { use std::io::{IoSlice, Write}; - use BeParameterStatusMessage::*; + use ParameterStatusMessage::*; let [name, value] = match param { Encoding(name) => [b"client_encoding", name.as_bytes()], diff --git a/libs/utils/src/postgres_backend.rs b/libs/utils/src/postgres_backend.rs index 89f7197718..831ba812a0 100644 --- a/libs/utils/src/postgres_backend.rs +++ b/libs/utils/src/postgres_backend.rs @@ -6,7 +6,7 @@ use crate::sock_split::{BidiStream, ReadStream, WriteStream}; use anyhow::{bail, ensure, Context, Result}; use bytes::{Bytes, BytesMut}; -use pq_proto::{BeMessage, BeParameterStatusMessage, FeMessage, FeStartupPacket}; +use pq_proto::{BeMessage, FeMessage, FeStartupPacket, ParameterStatusMessage}; use rand::Rng; use serde::{Deserialize, Serialize}; use std::fmt; @@ -361,10 +361,10 @@ impl PostgresBackend { match self.auth_type { AuthType::Trust => { self.write_message_noflush(&BeMessage::AuthenticationOk)? - .write_message_noflush(&BeParameterStatusMessage::encoding())? + .write_message_noflush(&ParameterStatusMessage::encoding())? // The async python driver requires a valid server_version .write_message_noflush(&BeMessage::ParameterStatus( - BeParameterStatusMessage::ServerVersion("14.1"), + ParameterStatusMessage::ServerVersion("14.1"), ))? .write_message(&BeMessage::ReadyForQuery)?; self.state = ProtoState::Established; @@ -413,7 +413,7 @@ impl PostgresBackend { } } self.write_message_noflush(&BeMessage::AuthenticationOk)? - .write_message_noflush(&BeParameterStatusMessage::encoding())? + .write_message_noflush(&ParameterStatusMessage::encoding())? .write_message(&BeMessage::ReadyForQuery)?; self.state = ProtoState::Established; } diff --git a/libs/utils/src/postgres_backend_async.rs b/libs/utils/src/postgres_backend_async.rs index 376819027b..85a2c225a0 100644 --- a/libs/utils/src/postgres_backend_async.rs +++ b/libs/utils/src/postgres_backend_async.rs @@ -6,7 +6,7 @@ use crate::postgres_backend::AuthType; use anyhow::{bail, Context, Result}; use bytes::{Bytes, BytesMut}; -use pq_proto::{BeMessage, BeParameterStatusMessage, FeMessage, FeStartupPacket}; +use pq_proto::{BeMessage, FeMessage, FeStartupPacket, ParameterStatusMessage}; use rand::Rng; use std::future::Future; use std::net::SocketAddr; @@ -331,10 +331,10 @@ impl PostgresBackend { match self.auth_type { AuthType::Trust => { self.write_message(&BeMessage::AuthenticationOk)? - .write_message(&BeParameterStatusMessage::encoding())? + .write_message(&ParameterStatusMessage::encoding())? // The async python driver requires a valid server_version .write_message(&BeMessage::ParameterStatus( - BeParameterStatusMessage::ServerVersion("14.1"), + ParameterStatusMessage::ServerVersion("14.1"), ))? .write_message(&BeMessage::ReadyForQuery)?; self.state = ProtoState::Established; @@ -384,7 +384,7 @@ impl PostgresBackend { } } self.write_message(&BeMessage::AuthenticationOk)? - .write_message(&BeParameterStatusMessage::encoding())? + .write_message(&ParameterStatusMessage::encoding())? .write_message(&BeMessage::ReadyForQuery)?; self.state = ProtoState::Established; } diff --git a/proxy/src/auth/backend/link.rs b/proxy/src/auth/backend/link.rs index 440a55f194..1fdce876af 100644 --- a/proxy/src/auth/backend/link.rs +++ b/proxy/src/auth/backend/link.rs @@ -1,6 +1,6 @@ use super::{AuthSuccess, NodeInfo}; use crate::{auth, compute, error::UserFacingError, stream::PqStream, waiters}; -use pq_proto::{BeMessage as Be, BeParameterStatusMessage}; +use pq_proto::{BeMessage as Be, ParameterStatusMessage}; use thiserror::Error; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::{info, info_span}; @@ -60,7 +60,7 @@ pub async fn handle_user( info!(parent: &span, "sending the auth URL to the user"); client .write_message_noflush(&Be::AuthenticationOk)? - .write_message_noflush(&BeParameterStatusMessage::encoding())? + .write_message_noflush(&ParameterStatusMessage::encoding())? .write_message(&Be::NoticeResponse(&greeting)) .await?; diff --git a/proxy/src/auth/flow.rs b/proxy/src/auth/flow.rs index d9ee50894d..4a5edfb255 100644 --- a/proxy/src/auth/flow.rs +++ b/proxy/src/auth/flow.rs @@ -2,7 +2,7 @@ use super::{AuthErrorImpl, PasswordHackPayload}; use crate::{sasl, scram, stream::PqStream}; -use pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be}; +use pq_proto::{BeMessage, BeMessage as Be, SaslMessage}; use std::io; use tokio::io::{AsyncRead, AsyncWrite}; @@ -22,7 +22,7 @@ pub struct Scram<'a>(pub &'a scram::ServerSecret); impl AuthMethod for Scram<'_> { #[inline(always)] fn first_message(&self) -> BeMessage<'_> { - Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(scram::METHODS)) + Be::AuthenticationSasl(SaslMessage::Methods(scram::METHODS)) } } diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index da3cb144e3..4679d5a0f5 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -257,12 +257,12 @@ impl Client<'_, S> { if !auth_result.reported_auth_ok { stream .write_message_noflush(&Be::AuthenticationOk)? - .write_message_noflush(&BeParameterStatusMessage::encoding())?; + .write_message_noflush(&ParameterStatusMessage::encoding())?; } stream .write_message_noflush(&BeMessage::ParameterStatus( - BeParameterStatusMessage::ServerVersion(&db.version), + ParameterStatusMessage::ServerVersion(&db.version), ))? .write_message_noflush(&Be::BackendKeyData(cancel_key_data))? .write_message(&BeMessage::ReadyForQuery) diff --git a/proxy/src/proxy/tests.rs b/proxy/src/proxy/tests.rs index 24fbc57b99..617f5c3cc9 100644 --- a/proxy/src/proxy/tests.rs +++ b/proxy/src/proxy/tests.rs @@ -139,7 +139,7 @@ async fn dummy_proxy( stream .write_message_noflush(&Be::AuthenticationOk)? - .write_message_noflush(&BeParameterStatusMessage::encoding())? + .write_message_noflush(&ParameterStatusMessage::encoding())? .write_message(&BeMessage::ReadyForQuery) .await?; diff --git a/proxy/src/sasl/messages.rs b/proxy/src/sasl/messages.rs index fb3833c8b6..d42ed9ebfe 100644 --- a/proxy/src/sasl/messages.rs +++ b/proxy/src/sasl/messages.rs @@ -1,7 +1,7 @@ //! Definitions for SASL messages. use crate::parse::{split_at_const, split_cstr}; -use pq_proto::{BeAuthenticationSaslMessage, BeMessage}; +use pq_proto::{BeMessage, SaslMessage}; /// SASL-specific payload of [`PasswordMessage`](pq_proto::FeMessage::PasswordMessage). #[derive(Debug)] @@ -42,10 +42,9 @@ pub(super) enum ServerMessage { impl<'a> ServerMessage<&'a str> { pub(super) fn to_reply(&self) -> BeMessage<'a> { - use BeAuthenticationSaslMessage::*; BeMessage::AuthenticationSasl(match self { - ServerMessage::Continue(s) => Continue(s.as_bytes()), - ServerMessage::Final(s) => Final(s.as_bytes()), + Self::Continue(s) => SaslMessage::Continue(s.as_bytes()), + Self::Final(s) => SaslMessage::Final(s.as_bytes()), }) } }