[proxy] sasl::Mechanism may return Output during exchange

This is needed to forward the `ClientKey` that's required
to connect the proxy to a compute.

Co-authored-by: bojanserafimov <bojan.serafimov7@gmail.com>
This commit is contained in:
Dmitry Ivanov
2022-04-12 01:12:07 +03:00
committed by Stas Kelvich
parent 4b1bd32e4a
commit 9df8915b03
5 changed files with 30 additions and 11 deletions

View File

@@ -39,9 +39,20 @@ pub enum Error {
/// A convenient result type for SASL exchange.
pub type Result<T> = std::result::Result<T, Error>;
/// A result of one SASL exchange.
pub enum Step<T, R> {
/// We should continue exchanging messages.
Continue(T),
/// The client has been authenticated successfully.
Authenticated(R),
}
/// Every SASL mechanism (e.g. [SCRAM](crate::scram)) is expected to implement this trait.
pub trait Mechanism: Sized {
/// What's produced as a result of successful authentication.
type Output;
/// Produce a server challenge to be sent to the client.
/// This is how this method is called in PostgreSQL (`libpq/sasl.h`).
fn exchange(self, input: &str) -> Result<(Option<Self>, String)>;
fn exchange(self, input: &str) -> Result<(Step<Self, Self::Output>, String)>;
}

View File

@@ -49,6 +49,7 @@ impl<'a> ServerMessage<&'a str> {
})
}
}
#[cfg(test)]
mod tests {
use super::*;

View File

@@ -51,18 +51,23 @@ impl<S: AsyncWrite + Unpin> SaslStream<'_, S> {
impl<S: AsyncRead + AsyncWrite + Unpin> SaslStream<'_, S> {
/// Perform SASL message exchange according to the underlying algorithm
/// until user is either authenticated or denied access.
pub async fn authenticate(mut self, mut mechanism: impl Mechanism) -> super::Result<()> {
pub async fn authenticate<M: Mechanism>(
mut self,
mut mechanism: M,
) -> super::Result<M::Output> {
loop {
let input = self.recv().await?;
let (moved, reply) = mechanism.exchange(input)?;
use super::Step::*;
match moved {
Some(moved) => {
Continue(moved) => {
self.send(&ServerMessage::Continue(&reply)).await?;
mechanism = moved;
}
None => {
Authenticated(result) => {
self.send(&ServerMessage::Final(&reply)).await?;
return Ok(());
return Ok(result);
}
}
}

View File

@@ -13,10 +13,10 @@ mod password;
mod secret;
mod signature;
pub use secret::*;
pub use exchange::Exchange;
pub use key::ScramKey;
pub use secret::ServerSecret;
pub use secret::*;
use hmac::{Hmac, Mac};
use sha2::{Digest, Sha256};

View File

@@ -62,8 +62,10 @@ impl<'a> Exchange<'a> {
}
impl sasl::Mechanism for Exchange<'_> {
fn exchange(mut self, input: &str) -> sasl::Result<(Option<Self>, String)> {
use ExchangeState::*;
type Output = super::ScramKey;
fn exchange(mut self, input: &str) -> sasl::Result<(sasl::Step<Self, Self::Output>, String)> {
use {sasl::Step::*, ExchangeState::*};
match &self.state {
Initial => {
let client_first_message =
@@ -82,7 +84,7 @@ impl sasl::Mechanism for Exchange<'_> {
server_first_message,
};
Ok((Some(self), msg))
Ok((Continue(self), msg))
}
SaltSent {
cbind_flag,
@@ -124,7 +126,7 @@ impl sasl::Mechanism for Exchange<'_> {
let msg = client_final_message
.build_server_final_message(signature_builder, &self.secret.server_key);
Ok((None, msg))
Ok((Authenticated(client_key), msg))
}
}
}