mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-08 05:52:55 +00:00
proxy: merge AuthError and AuthErrorImpl (#9418)
Since GetAuthInfoError now boxes the ControlPlaneError message the variant is not big anymore and AuthError is 32 bytes.
This commit is contained in:
@@ -9,7 +9,7 @@ use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing::info;
|
||||
|
||||
use super::backend::ComputeCredentialKeys;
|
||||
use super::{AuthErrorImpl, PasswordHackPayload};
|
||||
use super::{AuthError, PasswordHackPayload};
|
||||
use crate::config::TlsServerEndPoint;
|
||||
use crate::context::RequestMonitoring;
|
||||
use crate::control_plane::AuthSecret;
|
||||
@@ -117,14 +117,14 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, PasswordHack> {
|
||||
let msg = self.stream.read_password_message().await?;
|
||||
let password = msg
|
||||
.strip_suffix(&[0])
|
||||
.ok_or(AuthErrorImpl::MalformedPassword("missing terminator"))?;
|
||||
.ok_or(AuthError::MalformedPassword("missing terminator"))?;
|
||||
|
||||
let payload = PasswordHackPayload::parse(password)
|
||||
// If we ended up here and the payload is malformed, it means that
|
||||
// the user neither enabled SNI nor resorted to any other method
|
||||
// for passing the project name we rely on. We should show them
|
||||
// the most helpful error message and point to the documentation.
|
||||
.ok_or(AuthErrorImpl::MissingEndpointName)?;
|
||||
.ok_or(AuthError::MissingEndpointName)?;
|
||||
|
||||
Ok(payload)
|
||||
}
|
||||
@@ -136,7 +136,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, CleartextPassword> {
|
||||
let msg = self.stream.read_password_message().await?;
|
||||
let password = msg
|
||||
.strip_suffix(&[0])
|
||||
.ok_or(AuthErrorImpl::MalformedPassword("missing terminator"))?;
|
||||
.ok_or(AuthError::MalformedPassword("missing terminator"))?;
|
||||
|
||||
let outcome = validate_password_and_exchange(
|
||||
&self.state.pool,
|
||||
@@ -166,7 +166,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
|
||||
// Initial client message contains the chosen auth method's name.
|
||||
let msg = self.stream.read_password_message().await?;
|
||||
let sasl = sasl::FirstMessage::parse(&msg)
|
||||
.ok_or(AuthErrorImpl::MalformedPassword("bad sasl message"))?;
|
||||
.ok_or(AuthError::MalformedPassword("bad sasl message"))?;
|
||||
|
||||
// Currently, the only supported SASL method is SCRAM.
|
||||
if !scram::METHODS.contains(&sasl.method) {
|
||||
|
||||
@@ -29,7 +29,7 @@ pub(crate) type Result<T> = std::result::Result<T, AuthError>;
|
||||
|
||||
/// Common authentication error.
|
||||
#[derive(Debug, Error)]
|
||||
pub(crate) enum AuthErrorImpl {
|
||||
pub(crate) enum AuthError {
|
||||
#[error(transparent)]
|
||||
Web(#[from] backend::WebAuthError),
|
||||
|
||||
@@ -78,80 +78,70 @@ pub(crate) enum AuthErrorImpl {
|
||||
ConfirmationTimeout(humantime::Duration),
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
#[error(transparent)]
|
||||
pub(crate) struct AuthError(Box<AuthErrorImpl>);
|
||||
|
||||
impl AuthError {
|
||||
pub(crate) fn bad_auth_method(name: impl Into<Box<str>>) -> Self {
|
||||
AuthErrorImpl::BadAuthMethod(name.into()).into()
|
||||
AuthError::BadAuthMethod(name.into())
|
||||
}
|
||||
|
||||
pub(crate) fn auth_failed(user: impl Into<Box<str>>) -> Self {
|
||||
AuthErrorImpl::AuthFailed(user.into()).into()
|
||||
AuthError::AuthFailed(user.into())
|
||||
}
|
||||
|
||||
pub(crate) fn ip_address_not_allowed(ip: IpAddr) -> Self {
|
||||
AuthErrorImpl::IpAddressNotAllowed(ip).into()
|
||||
AuthError::IpAddressNotAllowed(ip)
|
||||
}
|
||||
|
||||
pub(crate) fn too_many_connections() -> Self {
|
||||
AuthErrorImpl::TooManyConnections.into()
|
||||
AuthError::TooManyConnections
|
||||
}
|
||||
|
||||
pub(crate) fn is_auth_failed(&self) -> bool {
|
||||
matches!(self.0.as_ref(), AuthErrorImpl::AuthFailed(_))
|
||||
matches!(self, AuthError::AuthFailed(_))
|
||||
}
|
||||
|
||||
pub(crate) fn user_timeout(elapsed: Elapsed) -> Self {
|
||||
AuthErrorImpl::UserTimeout(elapsed).into()
|
||||
AuthError::UserTimeout(elapsed)
|
||||
}
|
||||
|
||||
pub(crate) fn confirmation_timeout(timeout: humantime::Duration) -> Self {
|
||||
AuthErrorImpl::ConfirmationTimeout(timeout).into()
|
||||
}
|
||||
}
|
||||
|
||||
impl<E: Into<AuthErrorImpl>> From<E> for AuthError {
|
||||
fn from(e: E) -> Self {
|
||||
Self(Box::new(e.into()))
|
||||
AuthError::ConfirmationTimeout(timeout)
|
||||
}
|
||||
}
|
||||
|
||||
impl UserFacingError for AuthError {
|
||||
fn to_string_client(&self) -> String {
|
||||
match self.0.as_ref() {
|
||||
AuthErrorImpl::Web(e) => e.to_string_client(),
|
||||
AuthErrorImpl::GetAuthInfo(e) => e.to_string_client(),
|
||||
AuthErrorImpl::Sasl(e) => e.to_string_client(),
|
||||
AuthErrorImpl::AuthFailed(_) => self.to_string(),
|
||||
AuthErrorImpl::BadAuthMethod(_) => self.to_string(),
|
||||
AuthErrorImpl::MalformedPassword(_) => self.to_string(),
|
||||
AuthErrorImpl::MissingEndpointName => self.to_string(),
|
||||
AuthErrorImpl::Io(_) => "Internal error".to_string(),
|
||||
AuthErrorImpl::IpAddressNotAllowed(_) => self.to_string(),
|
||||
AuthErrorImpl::TooManyConnections => self.to_string(),
|
||||
AuthErrorImpl::UserTimeout(_) => self.to_string(),
|
||||
AuthErrorImpl::ConfirmationTimeout(_) => self.to_string(),
|
||||
match self {
|
||||
Self::Web(e) => e.to_string_client(),
|
||||
Self::GetAuthInfo(e) => e.to_string_client(),
|
||||
Self::Sasl(e) => e.to_string_client(),
|
||||
Self::AuthFailed(_) => self.to_string(),
|
||||
Self::BadAuthMethod(_) => self.to_string(),
|
||||
Self::MalformedPassword(_) => self.to_string(),
|
||||
Self::MissingEndpointName => self.to_string(),
|
||||
Self::Io(_) => "Internal error".to_string(),
|
||||
Self::IpAddressNotAllowed(_) => self.to_string(),
|
||||
Self::TooManyConnections => self.to_string(),
|
||||
Self::UserTimeout(_) => self.to_string(),
|
||||
Self::ConfirmationTimeout(_) => self.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ReportableError for AuthError {
|
||||
fn get_error_kind(&self) -> crate::error::ErrorKind {
|
||||
match self.0.as_ref() {
|
||||
AuthErrorImpl::Web(e) => e.get_error_kind(),
|
||||
AuthErrorImpl::GetAuthInfo(e) => e.get_error_kind(),
|
||||
AuthErrorImpl::Sasl(e) => e.get_error_kind(),
|
||||
AuthErrorImpl::AuthFailed(_) => crate::error::ErrorKind::User,
|
||||
AuthErrorImpl::BadAuthMethod(_) => crate::error::ErrorKind::User,
|
||||
AuthErrorImpl::MalformedPassword(_) => crate::error::ErrorKind::User,
|
||||
AuthErrorImpl::MissingEndpointName => crate::error::ErrorKind::User,
|
||||
AuthErrorImpl::Io(_) => crate::error::ErrorKind::ClientDisconnect,
|
||||
AuthErrorImpl::IpAddressNotAllowed(_) => crate::error::ErrorKind::User,
|
||||
AuthErrorImpl::TooManyConnections => crate::error::ErrorKind::RateLimit,
|
||||
AuthErrorImpl::UserTimeout(_) => crate::error::ErrorKind::User,
|
||||
AuthErrorImpl::ConfirmationTimeout(_) => crate::error::ErrorKind::User,
|
||||
match self {
|
||||
Self::Web(e) => e.get_error_kind(),
|
||||
Self::GetAuthInfo(e) => e.get_error_kind(),
|
||||
Self::Sasl(e) => e.get_error_kind(),
|
||||
Self::AuthFailed(_) => crate::error::ErrorKind::User,
|
||||
Self::BadAuthMethod(_) => crate::error::ErrorKind::User,
|
||||
Self::MalformedPassword(_) => crate::error::ErrorKind::User,
|
||||
Self::MissingEndpointName => crate::error::ErrorKind::User,
|
||||
Self::Io(_) => crate::error::ErrorKind::ClientDisconnect,
|
||||
Self::IpAddressNotAllowed(_) => crate::error::ErrorKind::User,
|
||||
Self::TooManyConnections => crate::error::ErrorKind::RateLimit,
|
||||
Self::UserTimeout(_) => crate::error::ErrorKind::User,
|
||||
Self::ConfirmationTimeout(_) => crate::error::ErrorKind::User,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user