add new auth proxy backend with new codec

This commit is contained in:
Conrad Ludgate
2024-09-12 16:36:41 +01:00
parent f47401f2e9
commit 91e8b7d22b
12 changed files with 1035 additions and 10 deletions

View File

@@ -0,0 +1,270 @@
mod classic;
mod hacks;
use tracing::info;
use crate::auth::backend::{
ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo, ComputeUserInfoNoEndpoint,
};
use crate::auth::{self, ComputeUserInfoMaybeEndpoint};
use crate::auth_proxy::validate_password_and_exchange;
use crate::console::errors::GetAuthInfoError;
use crate::console::provider::{CachedRoleSecret, ConsoleBackend};
use crate::console::AuthSecret;
use crate::context::RequestMonitoring;
use crate::intern::EndpointIdInt;
use crate::proxy::connect_compute::ComputeConnectBackend;
use crate::scram;
use crate::stream::AuthProxyStreamExt;
use crate::{
config::AuthenticationConfig,
console::{
self,
provider::{CachedAllowedIps, CachedNodeInfo},
Api,
},
};
use super::AuthProxyStream;
/// Alternative to [`std::borrow::Cow`] but doesn't need `T: ToOwned` as we don't need that functionality
pub enum MaybeOwned<'a, T> {
Owned(T),
Borrowed(&'a T),
}
impl<T> std::ops::Deref for MaybeOwned<'_, T> {
type Target = T;
fn deref(&self) -> &Self::Target {
match self {
MaybeOwned::Owned(t) => t,
MaybeOwned::Borrowed(t) => t,
}
}
}
/// This type serves two purposes:
///
/// * When `T` is `()`, it's just a regular auth backend selector
/// which we use in [`crate::config::ProxyConfig`].
///
/// * However, when we substitute `T` with [`ComputeUserInfoMaybeEndpoint`],
/// this helps us provide the credentials only to those auth
/// backends which require them for the authentication process.
pub enum Backend<'a, T> {
/// Cloud API (V2).
Console(MaybeOwned<'a, ConsoleBackend>, T),
}
#[cfg(test)]
pub(crate) trait TestBackend: Send + Sync + 'static {
fn wake_compute(&self) -> Result<CachedNodeInfo, console::errors::WakeComputeError>;
fn get_allowed_ips_and_secret(
&self,
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), console::errors::GetAuthInfoError>;
}
impl std::fmt::Display for Backend<'_, ()> {
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Console(api, ()) => match &**api {
ConsoleBackend::Console(endpoint) => {
fmt.debug_tuple("Console").field(&endpoint.url()).finish()
}
#[cfg(any(test, feature = "testing"))]
ConsoleBackend::Postgres(endpoint) => {
fmt.debug_tuple("Postgres").field(&endpoint.url()).finish()
}
#[cfg(test)]
ConsoleBackend::Test(_) => fmt.debug_tuple("Test").finish(),
},
}
}
}
impl<T> Backend<'_, T> {
/// Very similar to [`std::option::Option::as_ref`].
/// This helps us pass structured config to async tasks.
pub(crate) fn as_ref(&self) -> Backend<'_, &T> {
match self {
Self::Console(c, x) => Backend::Console(MaybeOwned::Borrowed(c), x),
}
}
}
impl<'a, T> Backend<'a, T> {
/// Very similar to [`std::option::Option::map`].
/// Maps [`Backend<T>`] to [`Backend<R>`] by applying
/// a function to a contained value.
pub(crate) fn map<R>(self, f: impl FnOnce(T) -> R) -> Backend<'a, R> {
match self {
Self::Console(c, x) => Backend::Console(c, f(x)),
}
}
}
impl<'a, T, E> Backend<'a, Result<T, E>> {
/// Very similar to [`std::option::Option::transpose`].
/// This is most useful for error handling.
pub(crate) fn transpose(self) -> Result<Backend<'a, T>, E> {
match self {
Self::Console(c, x) => x.map(|x| Backend::Console(c, x)),
}
}
}
/// True to its name, this function encapsulates our current auth trade-offs.
/// Here, we choose the appropriate auth flow based on circumstances.
///
/// All authentication flows will emit an AuthenticationOk message if successful.
async fn auth_quirks(
api: &impl console::Api,
user_info: ComputeUserInfoMaybeEndpoint,
client: &mut AuthProxyStream,
config: &'static AuthenticationConfig,
) -> auth::Result<ComputeCredentials> {
// If there's no project so far, that entails that client doesn't
// support SNI or other means of passing the endpoint (project) name.
// We now expect to see a very specific payload in the place of password.
let (info, unauthenticated_password) = match user_info.try_into() {
Err(info) => {
let res = hacks::password_hack_no_authentication(info, client).await?;
let password = match res.keys {
ComputeCredentialKeys::Password(p) => p,
ComputeCredentialKeys::AuthKeys(_) | ComputeCredentialKeys::None => {
unreachable!("password hack should return a password")
}
};
(res.info, Some(password))
}
Ok(info) => (info, None),
};
info!("fetching user's authentication info");
let cached_secret = api
.get_role_secret(&RequestMonitoring::test(), &info)
.await?;
let (cached_entry, secret) = cached_secret.take_value();
let secret = if let Some(secret) = secret {
secret
} else {
// If we don't have an authentication secret, we mock one to
// prevent malicious probing (possible due to missing protocol steps).
// This mocked secret will never lead to successful authentication.
info!("authentication info not found, mocking it");
AuthSecret::Scram(scram::ServerSecret::mock(rand::random()))
};
match authenticate_with_secret(secret, info, client, unauthenticated_password, config).await {
Ok(keys) => Ok(keys),
Err(e) => {
if e.is_auth_failed() {
// The password could have been changed, so we invalidate the cache.
cached_entry.invalidate();
}
Err(e)
}
}
}
async fn authenticate_with_secret(
secret: AuthSecret,
info: ComputeUserInfo,
client: &mut AuthProxyStream,
unauthenticated_password: Option<Vec<u8>>,
config: &'static AuthenticationConfig,
) -> auth::Result<ComputeCredentials> {
if let Some(password) = unauthenticated_password {
let ep = EndpointIdInt::from(&info.endpoint);
let auth_outcome =
validate_password_and_exchange(&config.thread_pool, ep, &password, secret).await?;
let keys = match auth_outcome {
crate::sasl::Outcome::Success(key) => key,
crate::sasl::Outcome::Failure(reason) => {
info!("auth backend failed with an error: {reason}");
return Err(auth::AuthError::auth_failed(&*info.user));
}
};
// we have authenticated the password
client.write_message_noflush(&pq_proto::BeMessage::AuthenticationOk)?;
return Ok(ComputeCredentials { info, keys });
}
// Finally, proceed with the main auth flow (SCRAM-based).
classic::authenticate(info, client, config, secret).await
}
impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint> {
/// Get username from the credentials.
pub(crate) fn get_user(&self) -> &str {
match self {
Self::Console(_, user_info) => &user_info.user,
}
}
pub(crate) async fn authenticate(
self,
client: &mut AuthProxyStream,
config: &'static AuthenticationConfig,
) -> auth::Result<Backend<'a, ComputeCredentials>> {
let res = match self {
Self::Console(api, user_info) => {
info!(
user = &*user_info.user,
project = user_info.endpoint(),
"performing authentication using the console"
);
let credentials = auth_quirks(&*api, user_info, client, config).await?;
Backend::Console(api, credentials)
}
};
info!("user successfully authenticated");
Ok(res)
}
}
impl Backend<'_, ComputeUserInfo> {
pub(crate) async fn get_role_secret(
&self,
ctx: &RequestMonitoring,
) -> Result<CachedRoleSecret, GetAuthInfoError> {
match self {
Self::Console(api, user_info) => api.get_role_secret(ctx, user_info).await,
}
}
pub(crate) async fn get_allowed_ips_and_secret(
&self,
ctx: &RequestMonitoring,
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), GetAuthInfoError> {
match self {
Self::Console(api, user_info) => api.get_allowed_ips_and_secret(ctx, user_info).await,
}
}
}
#[async_trait::async_trait]
impl ComputeConnectBackend for Backend<'_, ComputeCredentials> {
async fn wake_compute(
&self,
ctx: &RequestMonitoring,
) -> Result<CachedNodeInfo, console::errors::WakeComputeError> {
match self {
Self::Console(api, creds) => api.wake_compute(ctx, &creds.info).await,
}
}
fn get_keys(&self) -> &ComputeCredentialKeys {
match self {
Self::Console(_, creds) => &creds.keys,
}
}
}

View File

@@ -0,0 +1,69 @@
use super::{ComputeCredentials, ComputeUserInfo};
use crate::{
auth::{self, backend::ComputeCredentialKeys},
auth_proxy::{self, AuthFlow, AuthProxyStream},
compute,
config::AuthenticationConfig,
console::AuthSecret,
sasl,
};
use tracing::{info, warn};
pub(super) async fn authenticate(
creds: ComputeUserInfo,
client: &mut AuthProxyStream,
config: &'static AuthenticationConfig,
secret: AuthSecret,
) -> auth::Result<ComputeCredentials> {
let flow = AuthFlow::new(client);
let scram_keys = match secret {
#[cfg(any(test, feature = "testing"))]
AuthSecret::Md5(_) => {
info!("auth endpoint chooses MD5");
return Err(auth::AuthError::bad_auth_method("MD5"));
}
AuthSecret::Scram(secret) => {
info!("auth endpoint chooses SCRAM");
let scram = auth_proxy::Scram(&secret);
let auth_outcome = tokio::time::timeout(
config.scram_protocol_timeout,
async {
flow.begin(scram).await.map_err(|error| {
warn!(?error, "error sending scram acknowledgement");
error
})?.authenticate().await.map_err(|error| {
warn!(?error, "error processing scram messages");
error
})
}
)
.await
.map_err(|e| {
warn!("error processing scram messages error = authentication timed out, execution time exceeded {} seconds", config.scram_protocol_timeout.as_secs());
auth::AuthError::user_timeout(e)
})??;
let client_key = match auth_outcome {
sasl::Outcome::Success(key) => key,
sasl::Outcome::Failure(reason) => {
info!("auth backend failed with an error: {reason}");
return Err(auth::AuthError::auth_failed(&*creds.user));
}
};
compute::ScramKeys {
client_key: client_key.as_bytes(),
server_key: secret.server_key.as_bytes(),
}
}
};
Ok(ComputeCredentials {
info: creds,
keys: ComputeCredentialKeys::AuthKeys(tokio_postgres::config::AuthKeys::ScramSha256(
scram_keys,
)),
})
}

View File

@@ -0,0 +1,77 @@
use super::{
ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo, ComputeUserInfoNoEndpoint,
};
use crate::{
auth,
auth_proxy::{self, AuthFlow, AuthProxyStream},
config::AuthenticationConfig,
console::AuthSecret,
intern::EndpointIdInt,
sasl,
};
use tracing::{info, warn};
/// Compared to [SCRAM](crate::scram), cleartext password auth saves
/// one round trip and *expensive* computations (>= 4096 HMAC iterations).
/// These properties are benefical for serverless JS workers, so we
/// use this mechanism for websocket connections.
pub(crate) async fn authenticate_cleartext(
info: ComputeUserInfo,
client: &mut AuthProxyStream,
secret: AuthSecret,
config: &'static AuthenticationConfig,
) -> auth::Result<ComputeCredentials> {
warn!("cleartext auth flow override is enabled, proceeding");
let ep = EndpointIdInt::from(&info.endpoint);
let auth_flow = AuthFlow::new(client)
.begin(auth_proxy::CleartextPassword {
secret,
endpoint: ep,
pool: config.thread_pool.clone(),
})
.await?;
// cleartext auth is only allowed to the ws/http protocol.
// If we're here, we already received the password in the first message.
// Scram protocol will be executed on the proxy side.
let auth_outcome = auth_flow.authenticate().await?;
let keys = match auth_outcome {
sasl::Outcome::Success(key) => key,
sasl::Outcome::Failure(reason) => {
info!("auth backend failed with an error: {reason}");
return Err(auth::AuthError::auth_failed(&*info.user));
}
};
Ok(ComputeCredentials { info, keys })
}
/// Workaround for clients which don't provide an endpoint (project) name.
/// Similar to [`authenticate_cleartext`], but there's a specific password format,
/// and passwords are not yet validated (we don't know how to validate them!)
pub(crate) async fn password_hack_no_authentication(
info: ComputeUserInfoNoEndpoint,
client: &mut AuthProxyStream,
) -> auth::Result<ComputeCredentials> {
warn!("project not specified, resorting to the password hack auth flow");
let payload = AuthFlow::new(client)
.begin(auth_proxy::PasswordHack)
.await?
.get_password()
.await?;
info!(project = &*payload.endpoint, "received missing parameter");
// Report tentative success; compute node will check the password anyway.
Ok(ComputeCredentials {
info: ComputeUserInfo {
user: info.user,
options: info.options,
endpoint: payload.endpoint,
},
keys: ComputeCredentialKeys::Password(payload.password),
})
}

View File

@@ -0,0 +1,218 @@
//! Main authentication flow.
use super::{AuthProxyStream, PasswordHackPayload};
use crate::{
auth::{self, backend::ComputeCredentialKeys, AuthErrorImpl},
config::TlsServerEndPoint,
console::AuthSecret,
intern::EndpointIdInt,
sasl,
scram::{self, threadpool::ThreadPool},
stream::AuthProxyStreamExt,
};
use pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be};
use std::{io, sync::Arc};
use tracing::info;
/// Every authentication selector is supposed to implement this trait.
pub(crate) trait AuthMethod {
/// Any authentication selector should provide initial backend message
/// containing auth method name and parameters, e.g. md5 salt.
fn first_message(&self, channel_binding: bool) -> BeMessage<'_>;
}
/// Initial state of [`AuthFlow`].
pub(crate) struct Begin;
/// Use [SCRAM](crate::scram)-based auth in [`AuthFlow`].
pub(crate) struct Scram<'a>(pub(crate) &'a scram::ServerSecret);
impl AuthMethod for Scram<'_> {
#[inline(always)]
fn first_message(&self, channel_binding: bool) -> BeMessage<'_> {
if channel_binding {
Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(scram::METHODS))
} else {
Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(
scram::METHODS_WITHOUT_PLUS,
))
}
}
}
/// Use an ad hoc auth flow (for clients which don't support SNI) proposed in
/// <https://github.com/neondatabase/cloud/issues/1620#issuecomment-1165332290>.
pub(crate) struct PasswordHack;
impl AuthMethod for PasswordHack {
#[inline(always)]
fn first_message(&self, _channel_binding: bool) -> BeMessage<'_> {
Be::AuthenticationCleartextPassword
}
}
/// Use clear-text password auth called `password` in docs
/// <https://www.postgresql.org/docs/current/auth-password.html>
pub(crate) struct CleartextPassword {
pub(crate) pool: Arc<ThreadPool>,
pub(crate) endpoint: EndpointIdInt,
pub(crate) secret: AuthSecret,
}
impl AuthMethod for CleartextPassword {
#[inline(always)]
fn first_message(&self, _channel_binding: bool) -> BeMessage<'_> {
Be::AuthenticationCleartextPassword
}
}
/// This wrapper for [`PqStream`] performs client authentication.
#[must_use]
pub(crate) struct AuthFlow<'a, State> {
/// The underlying stream which implements libpq's protocol.
stream: &'a mut AuthProxyStream,
/// State might contain ancillary data (see [`Self::begin`]).
state: State,
tls_server_end_point: TlsServerEndPoint,
}
/// Initial state of the stream wrapper.
impl<'a> AuthFlow<'a, Begin> {
/// Create a new wrapper for client authentication.
pub(crate) fn new(stream: &'a mut AuthProxyStream) -> Self {
// TODO:
// let tls_server_end_point = stream.get_ref().tls_server_end_point();
let tls_server_end_point = TlsServerEndPoint::Undefined;
Self {
stream,
state: Begin,
tls_server_end_point,
}
}
/// Move to the next step by sending auth method's name & params to client.
pub(crate) async fn begin<M: AuthMethod>(self, method: M) -> io::Result<AuthFlow<'a, M>> {
self.stream
.write_message(&method.first_message(self.tls_server_end_point.supported()))
.await?;
Ok(AuthFlow {
stream: self.stream,
state: method,
tls_server_end_point: self.tls_server_end_point,
})
}
}
impl AuthFlow<'_, PasswordHack> {
/// Perform user authentication. Raise an error in case authentication failed.
pub(crate) async fn get_password(self) -> auth::Result<PasswordHackPayload> {
let msg = self.stream.read_password_message().await?;
let password = msg
.strip_suffix(&[0])
.ok_or(AuthErrorImpl::MalformedPassword("missing terminator"))?;
let payload = PasswordHackPayload::parse(password)
// If we ended up here and the payload is malformed, it means that
// the user neither enabled SNI nor resorted to any other method
// for passing the project name we rely on. We should show them
// the most helpful error message and point to the documentation.
.ok_or(AuthErrorImpl::MissingEndpointName)?;
Ok(payload)
}
}
impl AuthFlow<'_, CleartextPassword> {
/// Perform user authentication. Raise an error in case authentication failed.
pub(crate) async fn authenticate(self) -> auth::Result<sasl::Outcome<ComputeCredentialKeys>> {
let msg = self.stream.read_password_message().await?;
let password = msg
.strip_suffix(&[0])
.ok_or(AuthErrorImpl::MalformedPassword("missing terminator"))?;
let outcome = validate_password_and_exchange(
&self.state.pool,
self.state.endpoint,
password,
self.state.secret,
)
.await?;
if let sasl::Outcome::Success(_) = &outcome {
self.stream.write_message_noflush(&Be::AuthenticationOk)?;
}
Ok(outcome)
}
}
/// Stream wrapper for handling [SCRAM](crate::scram) auth.
impl AuthFlow<'_, Scram<'_>> {
/// Perform user authentication. Raise an error in case authentication failed.
pub(crate) async fn authenticate(self) -> auth::Result<sasl::Outcome<scram::ScramKey>> {
let Scram(secret) = self.state;
// Initial client message contains the chosen auth method's name.
let msg = self.stream.read_password_message().await?;
let sasl = sasl::FirstMessage::parse(&msg)
.ok_or(AuthErrorImpl::MalformedPassword("bad sasl message"))?;
// Currently, the only supported SASL method is SCRAM.
if !scram::METHODS.contains(&sasl.method) {
return Err(auth::AuthError::bad_auth_method(sasl.method));
}
info!("client chooses {}", sasl.method);
let outcome = sasl::SaslStream2::new(self.stream, sasl.message)
.authenticate(scram::Exchange::new(
secret,
rand::random,
self.tls_server_end_point,
))
.await?;
if let sasl::Outcome::Success(_) = &outcome {
self.stream.write_message_noflush(&Be::AuthenticationOk)?;
}
Ok(outcome)
}
}
pub(crate) async fn validate_password_and_exchange(
pool: &ThreadPool,
endpoint: EndpointIdInt,
password: &[u8],
secret: AuthSecret,
) -> auth::Result<sasl::Outcome<ComputeCredentialKeys>> {
match secret {
#[cfg(any(test, feature = "testing"))]
AuthSecret::Md5(_) => {
// test only
Ok(sasl::Outcome::Success(ComputeCredentialKeys::Password(
password.to_owned(),
)))
}
// perform scram authentication as both client and server to validate the keys
AuthSecret::Scram(scram_secret) => {
let outcome = crate::scram::exchange(pool, endpoint, &scram_secret, password).await?;
let client_key = match outcome {
sasl::Outcome::Success(client_key) => client_key,
sasl::Outcome::Failure(reason) => return Ok(sasl::Outcome::Failure(reason)),
};
let keys = crate::compute::ScramKeys {
client_key: client_key.as_bytes(),
server_key: scram_secret.server_key.as_bytes(),
};
Ok(sasl::Outcome::Success(ComputeCredentialKeys::AuthKeys(
tokio_postgres::config::AuthKeys::ScramSha256(keys),
)))
}
}
}

View File

@@ -0,0 +1,17 @@
//! Client authentication mechanisms.
pub mod backend;
pub use backend::Backend;
mod password_hack;
use password_hack::PasswordHackPayload;
mod flow;
pub(crate) use flow::*;
use quinn::{RecvStream, SendStream};
use tokio::io::Join;
use tokio_util::codec::Framed;
use crate::PglbCodec;
pub type AuthProxyStream = Framed<Join<RecvStream, SendStream>, PglbCodec>;

View File

@@ -0,0 +1,121 @@
//! Payload for ad hoc authentication method for clients that don't support SNI.
//! See the `impl` for [`super::backend::Backend<ClientCredentials>`].
//! Read more: <https://github.com/neondatabase/cloud/issues/1620#issuecomment-1165332290>.
//! UPDATE (Mon Aug 8 13:20:34 UTC 2022): the payload format has been simplified.
use bstr::ByteSlice;
use crate::EndpointId;
pub(crate) struct PasswordHackPayload {
pub(crate) endpoint: EndpointId,
pub(crate) password: Vec<u8>,
}
impl PasswordHackPayload {
pub(crate) fn parse(bytes: &[u8]) -> Option<Self> {
// The format is `project=<utf-8>;<password-bytes>` or `project=<utf-8>$<password-bytes>`.
let separators = [";", "$"];
for sep in separators {
if let Some((endpoint, password)) = bytes.split_once_str(sep) {
let endpoint = endpoint.to_str().ok()?;
return Some(Self {
endpoint: parse_endpoint_param(endpoint)?.into(),
password: password.to_owned(),
});
}
}
None
}
}
pub(crate) fn parse_endpoint_param(bytes: &str) -> Option<&str> {
bytes
.strip_prefix("project=")
.or_else(|| bytes.strip_prefix("endpoint="))
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn parse_endpoint_param_fn() {
let input = "";
assert!(parse_endpoint_param(input).is_none());
let input = "project=";
assert_eq!(parse_endpoint_param(input), Some(""));
let input = "project=foobar";
assert_eq!(parse_endpoint_param(input), Some("foobar"));
let input = "endpoint=";
assert_eq!(parse_endpoint_param(input), Some(""));
let input = "endpoint=foobar";
assert_eq!(parse_endpoint_param(input), Some("foobar"));
let input = "other_option=foobar";
assert!(parse_endpoint_param(input).is_none());
}
#[test]
fn parse_password_hack_payload_project() {
let bytes = b"";
assert!(PasswordHackPayload::parse(bytes).is_none());
let bytes = b"project=";
assert!(PasswordHackPayload::parse(bytes).is_none());
let bytes = b"project=;";
let payload: PasswordHackPayload =
PasswordHackPayload::parse(bytes).expect("parsing failed");
assert_eq!(payload.endpoint, "");
assert_eq!(payload.password, b"");
let bytes = b"project=foobar;pass;word";
let payload = PasswordHackPayload::parse(bytes).expect("parsing failed");
assert_eq!(payload.endpoint, "foobar");
assert_eq!(payload.password, b"pass;word");
}
#[test]
fn parse_password_hack_payload_endpoint() {
let bytes = b"";
assert!(PasswordHackPayload::parse(bytes).is_none());
let bytes = b"endpoint=";
assert!(PasswordHackPayload::parse(bytes).is_none());
let bytes = b"endpoint=;";
let payload = PasswordHackPayload::parse(bytes).expect("parsing failed");
assert_eq!(payload.endpoint, "");
assert_eq!(payload.password, b"");
let bytes = b"endpoint=foobar;pass;word";
let payload = PasswordHackPayload::parse(bytes).expect("parsing failed");
assert_eq!(payload.endpoint, "foobar");
assert_eq!(payload.password, b"pass;word");
}
#[test]
fn parse_password_hack_payload_dollar() {
let bytes = b"";
assert!(PasswordHackPayload::parse(bytes).is_none());
let bytes = b"endpoint=";
assert!(PasswordHackPayload::parse(bytes).is_none());
let bytes = b"endpoint=$";
let payload = PasswordHackPayload::parse(bytes).expect("parsing failed");
assert_eq!(payload.endpoint, "");
assert_eq!(payload.password, b"");
let bytes = b"endpoint=foobar$pass$word";
let payload = PasswordHackPayload::parse(bytes).expect("parsing failed");
assert_eq!(payload.endpoint, "foobar");
assert_eq!(payload.password, b"pass$word");
}
}

View File

@@ -1,6 +1,7 @@
use std::{sync::Arc, time::Duration};
use proxy::PglbCodec;
use futures::TryStreamExt;
use proxy::{PglbCodec, PglbControlMessage, PglbMessage};
use quinn::{
crypto::rustls::QuicClientConfig, rustls::client::danger, Endpoint, RecvStream, SendStream,
VarInt,
@@ -11,10 +12,7 @@ use tokio::{
signal::unix::{signal, SignalKind},
time::interval,
};
use tokio_util::{
codec::{Framed, FramedRead, FramedWrite},
task::TaskTracker,
};
use tokio_util::{codec::Framed, task::TaskTracker};
#[tokio::main]
async fn main() {
@@ -107,5 +105,11 @@ impl danger::ServerCertVerifier for NoVerify {
}
async fn handle_stream(send: SendStream, recv: RecvStream) {
let _stream = Framed::new(join(recv, send), PglbCodec);
let mut stream = Framed::new(join(recv, send), PglbCodec);
let first_msg = stream.try_next().await.unwrap();
let Some(PglbMessage::Control(PglbControlMessage::ConnectionInitiated(_first_msg))) = first_msg
else {
panic!("invalid first msg")
};
}

View File

@@ -125,7 +125,7 @@ impl RequestMonitoring {
Self(TryLock::new(inner))
}
#[cfg(test)]
// #[cfg(test)]
pub(crate) fn test() -> Self {
RequestMonitoring::new(Uuid::now_v7(), [127, 0, 0, 1].into(), Protocol::Tcp, "test")
}

View File

@@ -98,6 +98,7 @@ use tokio_util::sync::CancellationToken;
use tracing::warn;
pub mod auth;
pub mod auth_proxy;
pub mod cache;
pub mod cancellation;
pub mod compute;
@@ -405,7 +406,7 @@ pub enum PglbControlMessage {
#[derive(Serialize, Deserialize)]
pub struct ConnectionInitiatedPayload {
tls_server_end_point: TlsServerEndPoint,
server_name: Option<String>,
ip_addr: IpAddr,
pub tls_server_end_point: TlsServerEndPoint,
pub server_name: Option<String>,
pub ip_addr: IpAddr,
}

View File

@@ -9,6 +9,7 @@
mod channel_binding;
mod messages;
mod stream;
mod stream2;
use crate::error::{ReportableError, UserFacingError};
use std::io;
@@ -17,6 +18,7 @@ use thiserror::Error;
pub(crate) use channel_binding::ChannelBinding;
pub(crate) use messages::FirstMessage;
pub(crate) use stream::{Outcome, SaslStream};
pub(crate) use stream2::SaslStream2;
/// Fine-grained auth errors help in writing tests.
#[derive(Error, Debug)]

85
proxy/src/sasl/stream2.rs Normal file
View File

@@ -0,0 +1,85 @@
//! Abstraction for the string-oriented SASL protocols.
use crate::{
auth_proxy::AuthProxyStream,
sasl::{messages::ServerMessage, Mechanism},
stream::AuthProxyStreamExt,
};
use std::io;
use tracing::info;
use super::Outcome;
/// Abstracts away all peculiarities of the libpq's protocol.
pub(crate) struct SaslStream2<'a> {
/// The underlying stream.
stream: &'a mut AuthProxyStream,
/// Current password message we received from client.
current: bytes::Bytes,
/// First SASL message produced by client.
first: Option<&'a str>,
}
impl<'a> SaslStream2<'a> {
pub(crate) fn new(stream: &'a mut AuthProxyStream, first: &'a str) -> Self {
Self {
stream,
current: bytes::Bytes::new(),
first: Some(first),
}
}
}
impl SaslStream2<'_> {
// Receive a new SASL message from the client.
async fn recv(&mut self) -> io::Result<&str> {
if let Some(first) = self.first.take() {
return Ok(first);
}
self.current = self.stream.read_password_message().await?;
let s = std::str::from_utf8(&self.current)
.map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "bad encoding"))?;
Ok(s)
}
}
impl SaslStream2<'_> {
// Send a SASL message to the client.
async fn send(&mut self, msg: &ServerMessage<&str>) -> io::Result<()> {
self.stream.write_message(&msg.to_reply()).await?;
Ok(())
}
}
impl SaslStream2<'_> {
/// Perform SASL message exchange according to the underlying algorithm
/// until user is either authenticated or denied access.
pub(crate) async fn authenticate<M: Mechanism>(
mut self,
mut mechanism: M,
) -> crate::sasl::Result<Outcome<M::Output>> {
loop {
let input = self.recv().await?;
let step = mechanism.exchange(input).map_err(|error| {
info!(?error, "error during SASL exchange");
error
})?;
use crate::sasl::Step;
return Ok(match step {
Step::Continue(moved_mechanism, reply) => {
self.send(&ServerMessage::Continue(&reply)).await?;
mechanism = moved_mechanism;
continue;
}
Step::Success(result, reply) => {
self.send(&ServerMessage::Final(&reply)).await?;
Outcome::Success(result)
}
Step::Failure(reason) => Outcome::Failure(reason),
});
}
}
}

View File

@@ -1,8 +1,11 @@
use crate::auth_proxy::AuthProxyStream;
use crate::config::TlsServerEndPoint;
use crate::error::{ErrorKind, ReportableError, UserFacingError};
use crate::metrics::Metrics;
use crate::PglbMessage;
use bytes::BytesMut;
use futures::{SinkExt, TryStreamExt};
use pq_proto::framed::{ConnectionError, Framed};
use pq_proto::{BeMessage, FeMessage, FeStartupPacket, ProtocolError};
use rustls::ServerConfig;
@@ -294,3 +297,161 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for Stream<S> {
}
}
}
pub(crate) trait AuthProxyStreamExt {
/// Write the message into an internal buffer, but don't flush the underlying stream.
fn write_message_noflush(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self>;
/// Write the message into an internal buffer and flush it.
async fn write_message(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self>;
// /// Flush the output buffer into the underlying stream.
// async fn flush(&mut self) -> io::Result<&mut Self>;
/// Write the error message using [`Self::write_message`], then re-throw it.
/// Allowing string literals is safe under the assumption they might not contain any runtime info.
/// This method exists due to `&str` not implementing `Into<anyhow::Error>`.
async fn throw_error_str<T>(
&mut self,
msg: &'static str,
error_kind: ErrorKind,
) -> Result<T, ReportedError>;
/// Write the error message using [`Self::write_message`], then re-throw it.
/// Trait [`UserFacingError`] acts as an allowlist for error types.
async fn throw_error<T, E>(&mut self, error: E) -> Result<T, ReportedError>
where
E: UserFacingError + Into<anyhow::Error>;
/// Receive [`FeStartupPacket`], which is a first packet sent by a client.
async fn read_startup_packet(&mut self) -> io::Result<FeStartupPacket>;
async fn read_message(&mut self) -> io::Result<FeMessage>;
async fn read_password_message(&mut self) -> io::Result<bytes::Bytes>;
}
impl AuthProxyStreamExt for AuthProxyStream {
/// Write the message into an internal buffer, but don't flush the underlying stream.
fn write_message_noflush(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> {
let mut b = BytesMut::new();
BeMessage::write(&mut b, message).map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
self.start_send_unpin(PglbMessage::Postgres(b.freeze()))
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
Ok(self)
}
/// Write the message into an internal buffer and flush it.
async fn write_message(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> {
self.write_message_noflush(message)?;
self.flush()
.await
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
Ok(self)
}
/// Write the error message using [`Self::write_message`], then re-throw it.
/// Allowing string literals is safe under the assumption they might not contain any runtime info.
/// This method exists due to `&str` not implementing `Into<anyhow::Error>`.
async fn throw_error_str<T>(
&mut self,
msg: &'static str,
error_kind: ErrorKind,
) -> Result<T, ReportedError> {
tracing::info!(
kind = error_kind.to_metric_label(),
msg,
"forwarding error to user"
);
// already error case, ignore client IO error
self.write_message(&BeMessage::ErrorResponse(msg, None))
.await
.inspect_err(|e| debug!("write_message failed: {e}"))
.ok();
Err(ReportedError {
source: anyhow::anyhow!(msg),
error_kind,
})
}
/// Write the error message using [`Self::write_message`], then re-throw it.
/// Trait [`UserFacingError`] acts as an allowlist for error types.
async fn throw_error<T, E>(&mut self, error: E) -> Result<T, ReportedError>
where
E: UserFacingError + Into<anyhow::Error>,
{
let error_kind = error.get_error_kind();
let msg = error.to_string_client();
tracing::info!(
kind=error_kind.to_metric_label(),
error=%error,
msg,
"forwarding error to user"
);
// already error case, ignore client IO error
self.write_message(&BeMessage::ErrorResponse(&msg, None))
.await
.inspect_err(|e| debug!("write_message failed: {e}"))
.ok();
Err(ReportedError {
source: anyhow::anyhow!(error),
error_kind,
})
}
/// Receive [`FeStartupPacket`], which is a first packet sent by a client.
async fn read_startup_packet(&mut self) -> io::Result<FeStartupPacket> {
let msg = self
.try_next()
.await
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?
.ok_or_else(err_connection)?;
match msg {
PglbMessage::Control(_) => Err(io::Error::new(
io::ErrorKind::Other,
"unexpected control message",
)),
PglbMessage::Postgres(pg) => {
let mut buf = BytesMut::from(&*pg);
FeStartupPacket::parse(&mut buf)
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?
.ok_or_else(err_connection)
}
}
}
async fn read_message(&mut self) -> io::Result<FeMessage> {
let msg = self
.try_next()
.await
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?
.ok_or_else(err_connection)?;
match msg {
PglbMessage::Control(_) => Err(io::Error::new(
io::ErrorKind::Other,
"unexpected control message",
)),
PglbMessage::Postgres(pg) => {
let mut buf = BytesMut::from(&*pg);
FeMessage::parse(&mut buf)
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?
.ok_or_else(err_connection)
}
}
}
async fn read_password_message(&mut self) -> io::Result<bytes::Bytes> {
match self.read_message().await? {
FeMessage::PasswordMessage(msg) => Ok(msg),
bad => Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("unexpected message type: {bad:?}"),
)),
}
}
}