diff --git a/proxy/src/auth_proxy/backend.rs b/proxy/src/auth_proxy/backend.rs new file mode 100644 index 0000000000..e0b8eb4e8b --- /dev/null +++ b/proxy/src/auth_proxy/backend.rs @@ -0,0 +1,270 @@ +mod classic; +mod hacks; + +use tracing::info; + +use crate::auth::backend::{ + ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo, ComputeUserInfoNoEndpoint, +}; +use crate::auth::{self, ComputeUserInfoMaybeEndpoint}; +use crate::auth_proxy::validate_password_and_exchange; +use crate::console::errors::GetAuthInfoError; +use crate::console::provider::{CachedRoleSecret, ConsoleBackend}; +use crate::console::AuthSecret; +use crate::context::RequestMonitoring; +use crate::intern::EndpointIdInt; +use crate::proxy::connect_compute::ComputeConnectBackend; +use crate::scram; +use crate::stream::AuthProxyStreamExt; +use crate::{ + config::AuthenticationConfig, + console::{ + self, + provider::{CachedAllowedIps, CachedNodeInfo}, + Api, + }, +}; + +use super::AuthProxyStream; + +/// Alternative to [`std::borrow::Cow`] but doesn't need `T: ToOwned` as we don't need that functionality +pub enum MaybeOwned<'a, T> { + Owned(T), + Borrowed(&'a T), +} + +impl std::ops::Deref for MaybeOwned<'_, T> { + type Target = T; + + fn deref(&self) -> &Self::Target { + match self { + MaybeOwned::Owned(t) => t, + MaybeOwned::Borrowed(t) => t, + } + } +} + +/// This type serves two purposes: +/// +/// * When `T` is `()`, it's just a regular auth backend selector +/// which we use in [`crate::config::ProxyConfig`]. +/// +/// * However, when we substitute `T` with [`ComputeUserInfoMaybeEndpoint`], +/// this helps us provide the credentials only to those auth +/// backends which require them for the authentication process. +pub enum Backend<'a, T> { + /// Cloud API (V2). + Console(MaybeOwned<'a, ConsoleBackend>, T), +} + +#[cfg(test)] +pub(crate) trait TestBackend: Send + Sync + 'static { + fn wake_compute(&self) -> Result; + fn get_allowed_ips_and_secret( + &self, + ) -> Result<(CachedAllowedIps, Option), console::errors::GetAuthInfoError>; +} + +impl std::fmt::Display for Backend<'_, ()> { + fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Console(api, ()) => match &**api { + ConsoleBackend::Console(endpoint) => { + fmt.debug_tuple("Console").field(&endpoint.url()).finish() + } + #[cfg(any(test, feature = "testing"))] + ConsoleBackend::Postgres(endpoint) => { + fmt.debug_tuple("Postgres").field(&endpoint.url()).finish() + } + #[cfg(test)] + ConsoleBackend::Test(_) => fmt.debug_tuple("Test").finish(), + }, + } + } +} + +impl Backend<'_, T> { + /// Very similar to [`std::option::Option::as_ref`]. + /// This helps us pass structured config to async tasks. + pub(crate) fn as_ref(&self) -> Backend<'_, &T> { + match self { + Self::Console(c, x) => Backend::Console(MaybeOwned::Borrowed(c), x), + } + } +} + +impl<'a, T> Backend<'a, T> { + /// Very similar to [`std::option::Option::map`]. + /// Maps [`Backend`] to [`Backend`] by applying + /// a function to a contained value. + pub(crate) fn map(self, f: impl FnOnce(T) -> R) -> Backend<'a, R> { + match self { + Self::Console(c, x) => Backend::Console(c, f(x)), + } + } +} +impl<'a, T, E> Backend<'a, Result> { + /// Very similar to [`std::option::Option::transpose`]. + /// This is most useful for error handling. + pub(crate) fn transpose(self) -> Result, E> { + match self { + Self::Console(c, x) => x.map(|x| Backend::Console(c, x)), + } + } +} + +/// True to its name, this function encapsulates our current auth trade-offs. +/// Here, we choose the appropriate auth flow based on circumstances. +/// +/// All authentication flows will emit an AuthenticationOk message if successful. +async fn auth_quirks( + api: &impl console::Api, + user_info: ComputeUserInfoMaybeEndpoint, + client: &mut AuthProxyStream, + config: &'static AuthenticationConfig, +) -> auth::Result { + // If there's no project so far, that entails that client doesn't + // support SNI or other means of passing the endpoint (project) name. + // We now expect to see a very specific payload in the place of password. + let (info, unauthenticated_password) = match user_info.try_into() { + Err(info) => { + let res = hacks::password_hack_no_authentication(info, client).await?; + + let password = match res.keys { + ComputeCredentialKeys::Password(p) => p, + ComputeCredentialKeys::AuthKeys(_) | ComputeCredentialKeys::None => { + unreachable!("password hack should return a password") + } + }; + (res.info, Some(password)) + } + Ok(info) => (info, None), + }; + + info!("fetching user's authentication info"); + let cached_secret = api + .get_role_secret(&RequestMonitoring::test(), &info) + .await?; + + let (cached_entry, secret) = cached_secret.take_value(); + + let secret = if let Some(secret) = secret { + secret + } else { + // If we don't have an authentication secret, we mock one to + // prevent malicious probing (possible due to missing protocol steps). + // This mocked secret will never lead to successful authentication. + info!("authentication info not found, mocking it"); + AuthSecret::Scram(scram::ServerSecret::mock(rand::random())) + }; + + match authenticate_with_secret(secret, info, client, unauthenticated_password, config).await { + Ok(keys) => Ok(keys), + Err(e) => { + if e.is_auth_failed() { + // The password could have been changed, so we invalidate the cache. + cached_entry.invalidate(); + } + Err(e) + } + } +} + +async fn authenticate_with_secret( + secret: AuthSecret, + info: ComputeUserInfo, + client: &mut AuthProxyStream, + unauthenticated_password: Option>, + config: &'static AuthenticationConfig, +) -> auth::Result { + if let Some(password) = unauthenticated_password { + let ep = EndpointIdInt::from(&info.endpoint); + + let auth_outcome = + validate_password_and_exchange(&config.thread_pool, ep, &password, secret).await?; + let keys = match auth_outcome { + crate::sasl::Outcome::Success(key) => key, + crate::sasl::Outcome::Failure(reason) => { + info!("auth backend failed with an error: {reason}"); + return Err(auth::AuthError::auth_failed(&*info.user)); + } + }; + + // we have authenticated the password + client.write_message_noflush(&pq_proto::BeMessage::AuthenticationOk)?; + + return Ok(ComputeCredentials { info, keys }); + } + + // Finally, proceed with the main auth flow (SCRAM-based). + classic::authenticate(info, client, config, secret).await +} + +impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint> { + /// Get username from the credentials. + pub(crate) fn get_user(&self) -> &str { + match self { + Self::Console(_, user_info) => &user_info.user, + } + } + + pub(crate) async fn authenticate( + self, + client: &mut AuthProxyStream, + config: &'static AuthenticationConfig, + ) -> auth::Result> { + let res = match self { + Self::Console(api, user_info) => { + info!( + user = &*user_info.user, + project = user_info.endpoint(), + "performing authentication using the console" + ); + + let credentials = auth_quirks(&*api, user_info, client, config).await?; + Backend::Console(api, credentials) + } + }; + + info!("user successfully authenticated"); + Ok(res) + } +} + +impl Backend<'_, ComputeUserInfo> { + pub(crate) async fn get_role_secret( + &self, + ctx: &RequestMonitoring, + ) -> Result { + match self { + Self::Console(api, user_info) => api.get_role_secret(ctx, user_info).await, + } + } + + pub(crate) async fn get_allowed_ips_and_secret( + &self, + ctx: &RequestMonitoring, + ) -> Result<(CachedAllowedIps, Option), GetAuthInfoError> { + match self { + Self::Console(api, user_info) => api.get_allowed_ips_and_secret(ctx, user_info).await, + } + } +} + +#[async_trait::async_trait] +impl ComputeConnectBackend for Backend<'_, ComputeCredentials> { + async fn wake_compute( + &self, + ctx: &RequestMonitoring, + ) -> Result { + match self { + Self::Console(api, creds) => api.wake_compute(ctx, &creds.info).await, + } + } + + fn get_keys(&self) -> &ComputeCredentialKeys { + match self { + Self::Console(_, creds) => &creds.keys, + } + } +} diff --git a/proxy/src/auth_proxy/backend/classic.rs b/proxy/src/auth_proxy/backend/classic.rs new file mode 100644 index 0000000000..e6380e51ac --- /dev/null +++ b/proxy/src/auth_proxy/backend/classic.rs @@ -0,0 +1,69 @@ +use super::{ComputeCredentials, ComputeUserInfo}; +use crate::{ + auth::{self, backend::ComputeCredentialKeys}, + auth_proxy::{self, AuthFlow, AuthProxyStream}, + compute, + config::AuthenticationConfig, + console::AuthSecret, + sasl, +}; +use tracing::{info, warn}; + +pub(super) async fn authenticate( + creds: ComputeUserInfo, + client: &mut AuthProxyStream, + config: &'static AuthenticationConfig, + secret: AuthSecret, +) -> auth::Result { + let flow = AuthFlow::new(client); + let scram_keys = match secret { + #[cfg(any(test, feature = "testing"))] + AuthSecret::Md5(_) => { + info!("auth endpoint chooses MD5"); + return Err(auth::AuthError::bad_auth_method("MD5")); + } + AuthSecret::Scram(secret) => { + info!("auth endpoint chooses SCRAM"); + let scram = auth_proxy::Scram(&secret); + + let auth_outcome = tokio::time::timeout( + config.scram_protocol_timeout, + async { + + flow.begin(scram).await.map_err(|error| { + warn!(?error, "error sending scram acknowledgement"); + error + })?.authenticate().await.map_err(|error| { + warn!(?error, "error processing scram messages"); + error + }) + } + ) + .await + .map_err(|e| { + warn!("error processing scram messages error = authentication timed out, execution time exceeded {} seconds", config.scram_protocol_timeout.as_secs()); + auth::AuthError::user_timeout(e) + })??; + + let client_key = match auth_outcome { + sasl::Outcome::Success(key) => key, + sasl::Outcome::Failure(reason) => { + info!("auth backend failed with an error: {reason}"); + return Err(auth::AuthError::auth_failed(&*creds.user)); + } + }; + + compute::ScramKeys { + client_key: client_key.as_bytes(), + server_key: secret.server_key.as_bytes(), + } + } + }; + + Ok(ComputeCredentials { + info: creds, + keys: ComputeCredentialKeys::AuthKeys(tokio_postgres::config::AuthKeys::ScramSha256( + scram_keys, + )), + }) +} diff --git a/proxy/src/auth_proxy/backend/hacks.rs b/proxy/src/auth_proxy/backend/hacks.rs new file mode 100644 index 0000000000..37e9bbb77c --- /dev/null +++ b/proxy/src/auth_proxy/backend/hacks.rs @@ -0,0 +1,77 @@ +use super::{ + ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo, ComputeUserInfoNoEndpoint, +}; +use crate::{ + auth, + auth_proxy::{self, AuthFlow, AuthProxyStream}, + config::AuthenticationConfig, + console::AuthSecret, + intern::EndpointIdInt, + sasl, +}; +use tracing::{info, warn}; + +/// Compared to [SCRAM](crate::scram), cleartext password auth saves +/// one round trip and *expensive* computations (>= 4096 HMAC iterations). +/// These properties are benefical for serverless JS workers, so we +/// use this mechanism for websocket connections. +pub(crate) async fn authenticate_cleartext( + info: ComputeUserInfo, + client: &mut AuthProxyStream, + secret: AuthSecret, + config: &'static AuthenticationConfig, +) -> auth::Result { + warn!("cleartext auth flow override is enabled, proceeding"); + + let ep = EndpointIdInt::from(&info.endpoint); + + let auth_flow = AuthFlow::new(client) + .begin(auth_proxy::CleartextPassword { + secret, + endpoint: ep, + pool: config.thread_pool.clone(), + }) + .await?; + // cleartext auth is only allowed to the ws/http protocol. + // If we're here, we already received the password in the first message. + // Scram protocol will be executed on the proxy side. + let auth_outcome = auth_flow.authenticate().await?; + + let keys = match auth_outcome { + sasl::Outcome::Success(key) => key, + sasl::Outcome::Failure(reason) => { + info!("auth backend failed with an error: {reason}"); + return Err(auth::AuthError::auth_failed(&*info.user)); + } + }; + + Ok(ComputeCredentials { info, keys }) +} + +/// Workaround for clients which don't provide an endpoint (project) name. +/// Similar to [`authenticate_cleartext`], but there's a specific password format, +/// and passwords are not yet validated (we don't know how to validate them!) +pub(crate) async fn password_hack_no_authentication( + info: ComputeUserInfoNoEndpoint, + client: &mut AuthProxyStream, +) -> auth::Result { + warn!("project not specified, resorting to the password hack auth flow"); + + let payload = AuthFlow::new(client) + .begin(auth_proxy::PasswordHack) + .await? + .get_password() + .await?; + + info!(project = &*payload.endpoint, "received missing parameter"); + + // Report tentative success; compute node will check the password anyway. + Ok(ComputeCredentials { + info: ComputeUserInfo { + user: info.user, + options: info.options, + endpoint: payload.endpoint, + }, + keys: ComputeCredentialKeys::Password(payload.password), + }) +} diff --git a/proxy/src/auth_proxy/flow.rs b/proxy/src/auth_proxy/flow.rs new file mode 100644 index 0000000000..fdf1e6bdae --- /dev/null +++ b/proxy/src/auth_proxy/flow.rs @@ -0,0 +1,218 @@ +//! Main authentication flow. + +use super::{AuthProxyStream, PasswordHackPayload}; +use crate::{ + auth::{self, backend::ComputeCredentialKeys, AuthErrorImpl}, + config::TlsServerEndPoint, + console::AuthSecret, + intern::EndpointIdInt, + sasl, + scram::{self, threadpool::ThreadPool}, + stream::AuthProxyStreamExt, +}; +use pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be}; +use std::{io, sync::Arc}; +use tracing::info; + +/// Every authentication selector is supposed to implement this trait. +pub(crate) trait AuthMethod { + /// Any authentication selector should provide initial backend message + /// containing auth method name and parameters, e.g. md5 salt. + fn first_message(&self, channel_binding: bool) -> BeMessage<'_>; +} + +/// Initial state of [`AuthFlow`]. +pub(crate) struct Begin; + +/// Use [SCRAM](crate::scram)-based auth in [`AuthFlow`]. +pub(crate) struct Scram<'a>(pub(crate) &'a scram::ServerSecret); + +impl AuthMethod for Scram<'_> { + #[inline(always)] + fn first_message(&self, channel_binding: bool) -> BeMessage<'_> { + if channel_binding { + Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(scram::METHODS)) + } else { + Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods( + scram::METHODS_WITHOUT_PLUS, + )) + } + } +} + +/// Use an ad hoc auth flow (for clients which don't support SNI) proposed in +/// . +pub(crate) struct PasswordHack; + +impl AuthMethod for PasswordHack { + #[inline(always)] + fn first_message(&self, _channel_binding: bool) -> BeMessage<'_> { + Be::AuthenticationCleartextPassword + } +} + +/// Use clear-text password auth called `password` in docs +/// +pub(crate) struct CleartextPassword { + pub(crate) pool: Arc, + pub(crate) endpoint: EndpointIdInt, + pub(crate) secret: AuthSecret, +} + +impl AuthMethod for CleartextPassword { + #[inline(always)] + fn first_message(&self, _channel_binding: bool) -> BeMessage<'_> { + Be::AuthenticationCleartextPassword + } +} + +/// This wrapper for [`PqStream`] performs client authentication. +#[must_use] +pub(crate) struct AuthFlow<'a, State> { + /// The underlying stream which implements libpq's protocol. + stream: &'a mut AuthProxyStream, + /// State might contain ancillary data (see [`Self::begin`]). + state: State, + tls_server_end_point: TlsServerEndPoint, +} + +/// Initial state of the stream wrapper. +impl<'a> AuthFlow<'a, Begin> { + /// Create a new wrapper for client authentication. + pub(crate) fn new(stream: &'a mut AuthProxyStream) -> Self { + // TODO: + // let tls_server_end_point = stream.get_ref().tls_server_end_point(); + let tls_server_end_point = TlsServerEndPoint::Undefined; + + Self { + stream, + state: Begin, + tls_server_end_point, + } + } + + /// Move to the next step by sending auth method's name & params to client. + pub(crate) async fn begin(self, method: M) -> io::Result> { + self.stream + .write_message(&method.first_message(self.tls_server_end_point.supported())) + .await?; + + Ok(AuthFlow { + stream: self.stream, + state: method, + tls_server_end_point: self.tls_server_end_point, + }) + } +} + +impl AuthFlow<'_, PasswordHack> { + /// Perform user authentication. Raise an error in case authentication failed. + pub(crate) async fn get_password(self) -> auth::Result { + let msg = self.stream.read_password_message().await?; + let password = msg + .strip_suffix(&[0]) + .ok_or(AuthErrorImpl::MalformedPassword("missing terminator"))?; + + let payload = PasswordHackPayload::parse(password) + // If we ended up here and the payload is malformed, it means that + // the user neither enabled SNI nor resorted to any other method + // for passing the project name we rely on. We should show them + // the most helpful error message and point to the documentation. + .ok_or(AuthErrorImpl::MissingEndpointName)?; + + Ok(payload) + } +} + +impl AuthFlow<'_, CleartextPassword> { + /// Perform user authentication. Raise an error in case authentication failed. + pub(crate) async fn authenticate(self) -> auth::Result> { + let msg = self.stream.read_password_message().await?; + let password = msg + .strip_suffix(&[0]) + .ok_or(AuthErrorImpl::MalformedPassword("missing terminator"))?; + + let outcome = validate_password_and_exchange( + &self.state.pool, + self.state.endpoint, + password, + self.state.secret, + ) + .await?; + + if let sasl::Outcome::Success(_) = &outcome { + self.stream.write_message_noflush(&Be::AuthenticationOk)?; + } + + Ok(outcome) + } +} + +/// Stream wrapper for handling [SCRAM](crate::scram) auth. +impl AuthFlow<'_, Scram<'_>> { + /// Perform user authentication. Raise an error in case authentication failed. + pub(crate) async fn authenticate(self) -> auth::Result> { + let Scram(secret) = self.state; + + // Initial client message contains the chosen auth method's name. + let msg = self.stream.read_password_message().await?; + let sasl = sasl::FirstMessage::parse(&msg) + .ok_or(AuthErrorImpl::MalformedPassword("bad sasl message"))?; + + // Currently, the only supported SASL method is SCRAM. + if !scram::METHODS.contains(&sasl.method) { + return Err(auth::AuthError::bad_auth_method(sasl.method)); + } + + info!("client chooses {}", sasl.method); + + let outcome = sasl::SaslStream2::new(self.stream, sasl.message) + .authenticate(scram::Exchange::new( + secret, + rand::random, + self.tls_server_end_point, + )) + .await?; + + if let sasl::Outcome::Success(_) = &outcome { + self.stream.write_message_noflush(&Be::AuthenticationOk)?; + } + + Ok(outcome) + } +} + +pub(crate) async fn validate_password_and_exchange( + pool: &ThreadPool, + endpoint: EndpointIdInt, + password: &[u8], + secret: AuthSecret, +) -> auth::Result> { + match secret { + #[cfg(any(test, feature = "testing"))] + AuthSecret::Md5(_) => { + // test only + Ok(sasl::Outcome::Success(ComputeCredentialKeys::Password( + password.to_owned(), + ))) + } + // perform scram authentication as both client and server to validate the keys + AuthSecret::Scram(scram_secret) => { + let outcome = crate::scram::exchange(pool, endpoint, &scram_secret, password).await?; + + let client_key = match outcome { + sasl::Outcome::Success(client_key) => client_key, + sasl::Outcome::Failure(reason) => return Ok(sasl::Outcome::Failure(reason)), + }; + + let keys = crate::compute::ScramKeys { + client_key: client_key.as_bytes(), + server_key: scram_secret.server_key.as_bytes(), + }; + + Ok(sasl::Outcome::Success(ComputeCredentialKeys::AuthKeys( + tokio_postgres::config::AuthKeys::ScramSha256(keys), + ))) + } + } +} diff --git a/proxy/src/auth_proxy/mod.rs b/proxy/src/auth_proxy/mod.rs new file mode 100644 index 0000000000..3d7317ed04 --- /dev/null +++ b/proxy/src/auth_proxy/mod.rs @@ -0,0 +1,17 @@ +//! Client authentication mechanisms. + +pub mod backend; +pub use backend::Backend; + +mod password_hack; +use password_hack::PasswordHackPayload; + +mod flow; +pub(crate) use flow::*; +use quinn::{RecvStream, SendStream}; +use tokio::io::Join; +use tokio_util::codec::Framed; + +use crate::PglbCodec; + +pub type AuthProxyStream = Framed, PglbCodec>; diff --git a/proxy/src/auth_proxy/password_hack.rs b/proxy/src/auth_proxy/password_hack.rs new file mode 100644 index 0000000000..8585b8ff48 --- /dev/null +++ b/proxy/src/auth_proxy/password_hack.rs @@ -0,0 +1,121 @@ +//! Payload for ad hoc authentication method for clients that don't support SNI. +//! See the `impl` for [`super::backend::Backend`]. +//! Read more: . +//! UPDATE (Mon Aug 8 13:20:34 UTC 2022): the payload format has been simplified. + +use bstr::ByteSlice; + +use crate::EndpointId; + +pub(crate) struct PasswordHackPayload { + pub(crate) endpoint: EndpointId, + pub(crate) password: Vec, +} + +impl PasswordHackPayload { + pub(crate) fn parse(bytes: &[u8]) -> Option { + // The format is `project=;` or `project=$`. + let separators = [";", "$"]; + for sep in separators { + if let Some((endpoint, password)) = bytes.split_once_str(sep) { + let endpoint = endpoint.to_str().ok()?; + return Some(Self { + endpoint: parse_endpoint_param(endpoint)?.into(), + password: password.to_owned(), + }); + } + } + + None + } +} + +pub(crate) fn parse_endpoint_param(bytes: &str) -> Option<&str> { + bytes + .strip_prefix("project=") + .or_else(|| bytes.strip_prefix("endpoint=")) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parse_endpoint_param_fn() { + let input = ""; + assert!(parse_endpoint_param(input).is_none()); + + let input = "project="; + assert_eq!(parse_endpoint_param(input), Some("")); + + let input = "project=foobar"; + assert_eq!(parse_endpoint_param(input), Some("foobar")); + + let input = "endpoint="; + assert_eq!(parse_endpoint_param(input), Some("")); + + let input = "endpoint=foobar"; + assert_eq!(parse_endpoint_param(input), Some("foobar")); + + let input = "other_option=foobar"; + assert!(parse_endpoint_param(input).is_none()); + } + + #[test] + fn parse_password_hack_payload_project() { + let bytes = b""; + assert!(PasswordHackPayload::parse(bytes).is_none()); + + let bytes = b"project="; + assert!(PasswordHackPayload::parse(bytes).is_none()); + + let bytes = b"project=;"; + let payload: PasswordHackPayload = + PasswordHackPayload::parse(bytes).expect("parsing failed"); + assert_eq!(payload.endpoint, ""); + assert_eq!(payload.password, b""); + + let bytes = b"project=foobar;pass;word"; + let payload = PasswordHackPayload::parse(bytes).expect("parsing failed"); + assert_eq!(payload.endpoint, "foobar"); + assert_eq!(payload.password, b"pass;word"); + } + + #[test] + fn parse_password_hack_payload_endpoint() { + let bytes = b""; + assert!(PasswordHackPayload::parse(bytes).is_none()); + + let bytes = b"endpoint="; + assert!(PasswordHackPayload::parse(bytes).is_none()); + + let bytes = b"endpoint=;"; + let payload = PasswordHackPayload::parse(bytes).expect("parsing failed"); + assert_eq!(payload.endpoint, ""); + assert_eq!(payload.password, b""); + + let bytes = b"endpoint=foobar;pass;word"; + let payload = PasswordHackPayload::parse(bytes).expect("parsing failed"); + assert_eq!(payload.endpoint, "foobar"); + assert_eq!(payload.password, b"pass;word"); + } + + #[test] + fn parse_password_hack_payload_dollar() { + let bytes = b""; + assert!(PasswordHackPayload::parse(bytes).is_none()); + + let bytes = b"endpoint="; + assert!(PasswordHackPayload::parse(bytes).is_none()); + + let bytes = b"endpoint=$"; + let payload = PasswordHackPayload::parse(bytes).expect("parsing failed"); + assert_eq!(payload.endpoint, ""); + assert_eq!(payload.password, b""); + + let bytes = b"endpoint=foobar$pass$word"; + let payload = PasswordHackPayload::parse(bytes).expect("parsing failed"); + assert_eq!(payload.endpoint, "foobar"); + assert_eq!(payload.password, b"pass$word"); + } +} diff --git a/proxy/src/bin/auth_proxy.rs b/proxy/src/bin/auth_proxy.rs index 2f0686f11e..2e9b5b2997 100644 --- a/proxy/src/bin/auth_proxy.rs +++ b/proxy/src/bin/auth_proxy.rs @@ -1,6 +1,7 @@ use std::{sync::Arc, time::Duration}; -use proxy::PglbCodec; +use futures::TryStreamExt; +use proxy::{PglbCodec, PglbControlMessage, PglbMessage}; use quinn::{ crypto::rustls::QuicClientConfig, rustls::client::danger, Endpoint, RecvStream, SendStream, VarInt, @@ -11,10 +12,7 @@ use tokio::{ signal::unix::{signal, SignalKind}, time::interval, }; -use tokio_util::{ - codec::{Framed, FramedRead, FramedWrite}, - task::TaskTracker, -}; +use tokio_util::{codec::Framed, task::TaskTracker}; #[tokio::main] async fn main() { @@ -107,5 +105,11 @@ impl danger::ServerCertVerifier for NoVerify { } async fn handle_stream(send: SendStream, recv: RecvStream) { - let _stream = Framed::new(join(recv, send), PglbCodec); + let mut stream = Framed::new(join(recv, send), PglbCodec); + + let first_msg = stream.try_next().await.unwrap(); + let Some(PglbMessage::Control(PglbControlMessage::ConnectionInitiated(_first_msg))) = first_msg + else { + panic!("invalid first msg") + }; } diff --git a/proxy/src/context.rs b/proxy/src/context.rs index c013218ad9..950516523f 100644 --- a/proxy/src/context.rs +++ b/proxy/src/context.rs @@ -125,7 +125,7 @@ impl RequestMonitoring { Self(TryLock::new(inner)) } - #[cfg(test)] + // #[cfg(test)] pub(crate) fn test() -> Self { RequestMonitoring::new(Uuid::now_v7(), [127, 0, 0, 1].into(), Protocol::Tcp, "test") } diff --git a/proxy/src/lib.rs b/proxy/src/lib.rs index 0c820c8512..82c8a01301 100644 --- a/proxy/src/lib.rs +++ b/proxy/src/lib.rs @@ -98,6 +98,7 @@ use tokio_util::sync::CancellationToken; use tracing::warn; pub mod auth; +pub mod auth_proxy; pub mod cache; pub mod cancellation; pub mod compute; @@ -405,7 +406,7 @@ pub enum PglbControlMessage { #[derive(Serialize, Deserialize)] pub struct ConnectionInitiatedPayload { - tls_server_end_point: TlsServerEndPoint, - server_name: Option, - ip_addr: IpAddr, + pub tls_server_end_point: TlsServerEndPoint, + pub server_name: Option, + pub ip_addr: IpAddr, } diff --git a/proxy/src/sasl.rs b/proxy/src/sasl.rs index 0a36694359..e394499426 100644 --- a/proxy/src/sasl.rs +++ b/proxy/src/sasl.rs @@ -9,6 +9,7 @@ mod channel_binding; mod messages; mod stream; +mod stream2; use crate::error::{ReportableError, UserFacingError}; use std::io; @@ -17,6 +18,7 @@ use thiserror::Error; pub(crate) use channel_binding::ChannelBinding; pub(crate) use messages::FirstMessage; pub(crate) use stream::{Outcome, SaslStream}; +pub(crate) use stream2::SaslStream2; /// Fine-grained auth errors help in writing tests. #[derive(Error, Debug)] diff --git a/proxy/src/sasl/stream2.rs b/proxy/src/sasl/stream2.rs new file mode 100644 index 0000000000..f09633b267 --- /dev/null +++ b/proxy/src/sasl/stream2.rs @@ -0,0 +1,85 @@ +//! Abstraction for the string-oriented SASL protocols. + +use crate::{ + auth_proxy::AuthProxyStream, + sasl::{messages::ServerMessage, Mechanism}, + stream::AuthProxyStreamExt, +}; +use std::io; +use tracing::info; + +use super::Outcome; + +/// Abstracts away all peculiarities of the libpq's protocol. +pub(crate) struct SaslStream2<'a> { + /// The underlying stream. + stream: &'a mut AuthProxyStream, + /// Current password message we received from client. + current: bytes::Bytes, + /// First SASL message produced by client. + first: Option<&'a str>, +} + +impl<'a> SaslStream2<'a> { + pub(crate) fn new(stream: &'a mut AuthProxyStream, first: &'a str) -> Self { + Self { + stream, + current: bytes::Bytes::new(), + first: Some(first), + } + } +} + +impl SaslStream2<'_> { + // Receive a new SASL message from the client. + async fn recv(&mut self) -> io::Result<&str> { + if let Some(first) = self.first.take() { + return Ok(first); + } + + self.current = self.stream.read_password_message().await?; + let s = std::str::from_utf8(&self.current) + .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "bad encoding"))?; + + Ok(s) + } +} + +impl SaslStream2<'_> { + // Send a SASL message to the client. + async fn send(&mut self, msg: &ServerMessage<&str>) -> io::Result<()> { + self.stream.write_message(&msg.to_reply()).await?; + Ok(()) + } +} + +impl SaslStream2<'_> { + /// Perform SASL message exchange according to the underlying algorithm + /// until user is either authenticated or denied access. + pub(crate) async fn authenticate( + mut self, + mut mechanism: M, + ) -> crate::sasl::Result> { + loop { + let input = self.recv().await?; + let step = mechanism.exchange(input).map_err(|error| { + info!(?error, "error during SASL exchange"); + error + })?; + + use crate::sasl::Step; + return Ok(match step { + Step::Continue(moved_mechanism, reply) => { + self.send(&ServerMessage::Continue(&reply)).await?; + mechanism = moved_mechanism; + continue; + } + Step::Success(result, reply) => { + self.send(&ServerMessage::Final(&reply)).await?; + Outcome::Success(result) + } + Step::Failure(reason) => Outcome::Failure(reason), + }); + } + } +} diff --git a/proxy/src/stream.rs b/proxy/src/stream.rs index e2fc73235e..d0f6920271 100644 --- a/proxy/src/stream.rs +++ b/proxy/src/stream.rs @@ -1,8 +1,11 @@ +use crate::auth_proxy::AuthProxyStream; use crate::config::TlsServerEndPoint; use crate::error::{ErrorKind, ReportableError, UserFacingError}; use crate::metrics::Metrics; +use crate::PglbMessage; use bytes::BytesMut; +use futures::{SinkExt, TryStreamExt}; use pq_proto::framed::{ConnectionError, Framed}; use pq_proto::{BeMessage, FeMessage, FeStartupPacket, ProtocolError}; use rustls::ServerConfig; @@ -294,3 +297,161 @@ impl AsyncWrite for Stream { } } } + +pub(crate) trait AuthProxyStreamExt { + /// Write the message into an internal buffer, but don't flush the underlying stream. + fn write_message_noflush(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self>; + + /// Write the message into an internal buffer and flush it. + async fn write_message(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self>; + + // /// Flush the output buffer into the underlying stream. + // async fn flush(&mut self) -> io::Result<&mut Self>; + + /// Write the error message using [`Self::write_message`], then re-throw it. + /// Allowing string literals is safe under the assumption they might not contain any runtime info. + /// This method exists due to `&str` not implementing `Into`. + async fn throw_error_str( + &mut self, + msg: &'static str, + error_kind: ErrorKind, + ) -> Result; + + /// Write the error message using [`Self::write_message`], then re-throw it. + /// Trait [`UserFacingError`] acts as an allowlist for error types. + async fn throw_error(&mut self, error: E) -> Result + where + E: UserFacingError + Into; + + /// Receive [`FeStartupPacket`], which is a first packet sent by a client. + async fn read_startup_packet(&mut self) -> io::Result; + async fn read_message(&mut self) -> io::Result; + + async fn read_password_message(&mut self) -> io::Result; +} + +impl AuthProxyStreamExt for AuthProxyStream { + /// Write the message into an internal buffer, but don't flush the underlying stream. + fn write_message_noflush(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> { + let mut b = BytesMut::new(); + BeMessage::write(&mut b, message).map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; + self.start_send_unpin(PglbMessage::Postgres(b.freeze())) + .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; + Ok(self) + } + + /// Write the message into an internal buffer and flush it. + async fn write_message(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> { + self.write_message_noflush(message)?; + self.flush() + .await + .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; + Ok(self) + } + + /// Write the error message using [`Self::write_message`], then re-throw it. + /// Allowing string literals is safe under the assumption they might not contain any runtime info. + /// This method exists due to `&str` not implementing `Into`. + async fn throw_error_str( + &mut self, + msg: &'static str, + error_kind: ErrorKind, + ) -> Result { + tracing::info!( + kind = error_kind.to_metric_label(), + msg, + "forwarding error to user" + ); + + // already error case, ignore client IO error + self.write_message(&BeMessage::ErrorResponse(msg, None)) + .await + .inspect_err(|e| debug!("write_message failed: {e}")) + .ok(); + + Err(ReportedError { + source: anyhow::anyhow!(msg), + error_kind, + }) + } + + /// Write the error message using [`Self::write_message`], then re-throw it. + /// Trait [`UserFacingError`] acts as an allowlist for error types. + async fn throw_error(&mut self, error: E) -> Result + where + E: UserFacingError + Into, + { + let error_kind = error.get_error_kind(); + let msg = error.to_string_client(); + tracing::info!( + kind=error_kind.to_metric_label(), + error=%error, + msg, + "forwarding error to user" + ); + + // already error case, ignore client IO error + self.write_message(&BeMessage::ErrorResponse(&msg, None)) + .await + .inspect_err(|e| debug!("write_message failed: {e}")) + .ok(); + + Err(ReportedError { + source: anyhow::anyhow!(error), + error_kind, + }) + } + + /// Receive [`FeStartupPacket`], which is a first packet sent by a client. + async fn read_startup_packet(&mut self) -> io::Result { + let msg = self + .try_next() + .await + .map_err(|e| io::Error::new(io::ErrorKind::Other, e))? + .ok_or_else(err_connection)?; + + match msg { + PglbMessage::Control(_) => Err(io::Error::new( + io::ErrorKind::Other, + "unexpected control message", + )), + PglbMessage::Postgres(pg) => { + let mut buf = BytesMut::from(&*pg); + FeStartupPacket::parse(&mut buf) + .map_err(|e| io::Error::new(io::ErrorKind::Other, e))? + .ok_or_else(err_connection) + } + } + } + + async fn read_message(&mut self) -> io::Result { + let msg = self + .try_next() + .await + .map_err(|e| io::Error::new(io::ErrorKind::Other, e))? + .ok_or_else(err_connection)?; + + match msg { + PglbMessage::Control(_) => Err(io::Error::new( + io::ErrorKind::Other, + "unexpected control message", + )), + PglbMessage::Postgres(pg) => { + let mut buf = BytesMut::from(&*pg); + FeMessage::parse(&mut buf) + .map_err(|e| io::Error::new(io::ErrorKind::Other, e))? + .ok_or_else(err_connection) + } + } + } + + async fn read_password_message(&mut self) -> io::Result { + match self.read_message().await? { + FeMessage::PasswordMessage(msg) => Ok(msg), + bad => Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("unexpected message type: {bad:?}"), + )), + } + } +}