diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index 48c8604d86..438190261d 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -4,6 +4,10 @@ version = "0.1.0" edition.workspace = true license.workspace = true +[features] +default = [] +testing = [] + [dependencies] anyhow.workspace = true async-trait.workspace = true @@ -69,6 +73,7 @@ webpki-roots.workspace = true x509-parser.workspace = true native-tls.workspace = true postgres-native-tls.workspace = true +postgres-protocol.workspace = true smol_str.workspace = true workspace_hack.workspace = true @@ -78,4 +83,3 @@ tokio-util.workspace = true rcgen.workspace = true rstest.workspace = true tokio-postgres-rustls.workspace = true -postgres-protocol.workspace = true diff --git a/proxy/src/auth/backend.rs b/proxy/src/auth/backend.rs index aa872285b1..649b3f40f2 100644 --- a/proxy/src/auth/backend.rs +++ b/proxy/src/auth/backend.rs @@ -3,9 +3,11 @@ mod hacks; mod link; pub use link::LinkAuthError; +use smol_str::SmolStr; use tokio_postgres::config::AuthKeys; use crate::auth::credentials::check_peer_addr_is_in_list; +use crate::auth::validate_password_and_exchange; use crate::console::errors::GetAuthInfoError; use crate::console::provider::AuthInfo; use crate::console::AuthSecret; @@ -24,31 +26,12 @@ use crate::{ }; use futures::TryFutureExt; use std::borrow::Cow; +use std::net::IpAddr; use std::ops::ControlFlow; use std::sync::Arc; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::{error, info, warn}; -/// A product of successful authentication. -pub struct AuthSuccess { - /// Did we send [`pq_proto::BeMessage::AuthenticationOk`] to client? - pub reported_auth_ok: bool, - /// Something to be considered a positive result. - pub value: T, -} - -impl AuthSuccess { - /// Very similar to [`std::option::Option::map`]. - /// Maps [`AuthSuccess`] to [`AuthSuccess`] by applying - /// a function to a contained value. - pub fn map(self, f: impl FnOnce(T) -> R) -> AuthSuccess { - AuthSuccess { - reported_auth_ok: self.reported_auth_ok, - value: f(self.value), - } - } -} - /// This type serves two purposes: /// /// * When `T` is `()`, it's just a regular auth backend selector @@ -61,9 +44,11 @@ pub enum BackendType<'a, T> { /// Current Cloud API (V2). Console(Cow<'a, console::provider::neon::Api>, T), /// Local mock of Cloud API (V2). + #[cfg(feature = "testing")] Postgres(Cow<'a, console::provider::mock::Api>, T), /// Authentication via a web browser. Link(Cow<'a, url::ApiUrl>), + #[cfg(test)] /// Test backend. Test(&'a dyn TestBackend), } @@ -78,8 +63,10 @@ impl std::fmt::Display for BackendType<'_, ()> { use BackendType::*; match self { Console(endpoint, _) => fmt.debug_tuple("Console").field(&endpoint.url()).finish(), + #[cfg(feature = "testing")] Postgres(endpoint, _) => fmt.debug_tuple("Postgres").field(&endpoint.url()).finish(), Link(url) => fmt.debug_tuple("Link").field(&url.as_str()).finish(), + #[cfg(test)] Test(_) => fmt.debug_tuple("Test").finish(), } } @@ -92,8 +79,10 @@ impl BackendType<'_, T> { use BackendType::*; match self { Console(c, x) => Console(Cow::Borrowed(c), x), + #[cfg(feature = "testing")] Postgres(c, x) => Postgres(Cow::Borrowed(c), x), Link(c) => Link(Cow::Borrowed(c)), + #[cfg(test)] Test(x) => Test(*x), } } @@ -107,8 +96,10 @@ impl<'a, T> BackendType<'a, T> { use BackendType::*; match self { Console(c, x) => Console(c, f(x)), + #[cfg(feature = "testing")] Postgres(c, x) => Postgres(c, f(x)), Link(c) => Link(c), + #[cfg(test)] Test(x) => Test(x), } } @@ -121,51 +112,87 @@ impl<'a, T, E> BackendType<'a, Result> { use BackendType::*; match self { Console(c, x) => x.map(|x| Console(c, x)), + #[cfg(feature = "testing")] Postgres(c, x) => x.map(|x| Postgres(c, x)), Link(c) => Ok(Link(c)), + #[cfg(test)] Test(x) => Ok(Test(x)), } } } -pub enum ComputeCredentials { +pub struct ComputeCredentials { + pub info: ComputeUserInfo, + pub keys: T, +} + +pub struct ComputeUserInfoNoEndpoint { + pub user: SmolStr, + pub peer_addr: IpAddr, + pub cache_key: SmolStr, +} + +pub struct ComputeUserInfo { + pub endpoint: SmolStr, + pub inner: ComputeUserInfoNoEndpoint, +} + +pub enum ComputeCredentialKeys { + #[cfg(feature = "testing")] Password(Vec), AuthKeys(AuthKeys), } +impl TryFrom for ComputeUserInfo { + // user name + type Error = ComputeUserInfoNoEndpoint; + + fn try_from(creds: ClientCredentials) -> Result { + let inner = ComputeUserInfoNoEndpoint { + user: creds.user, + peer_addr: creds.peer_addr, + cache_key: creds.cache_key, + }; + match creds.project { + None => Err(inner), + Some(endpoint) => Ok(ComputeUserInfo { endpoint, inner }), + } + } +} + /// True to its name, this function encapsulates our current auth trade-offs. /// Here, we choose the appropriate auth flow based on circumstances. -async fn auth_quirks_creds( +/// +/// All authentication flows will emit an AuthenticationOk message if successful. +async fn auth_quirks( api: &impl console::Api, extra: &ConsoleReqExtra<'_>, - creds: &mut ClientCredentials<'_>, + creds: ClientCredentials, client: &mut stream::PqStream>, allow_cleartext: bool, config: &'static AuthenticationConfig, latency_timer: &mut LatencyTimer, -) -> auth::Result> { +) -> auth::Result> { // If there's no project so far, that entails that client doesn't // support SNI or other means of passing the endpoint (project) name. // We now expect to see a very specific payload in the place of password. - let maybe_success = if creds.project.is_none() { - // Password will be checked by the compute node later. - Some(hacks::password_hack(creds, client, latency_timer).await?) - } else { - None + let (info, unauthenticated_password) = match creds.try_into() { + Err(info) => { + let res = hacks::password_hack_no_authentication(info, client, latency_timer).await?; + (res.info, Some(res.keys)) + } + Ok(info) => (info, None), }; - // Password hack should set the project name. - // TODO: make `creds.project` more type-safe. - assert!(creds.project.is_some()); info!("fetching user's authentication info"); // TODO(anna): this will slow down both "hacks" below; we probably need a cache. let AuthInfo { secret, allowed_ips, - } = api.get_auth_info(extra, creds).await?; + } = api.get_auth_info(extra, &info).await?; // check allowed list - if !check_peer_addr_is_in_list(&creds.peer_addr.ip(), &allowed_ips) { + if !check_peer_addr_is_in_list(&info.inner.peer_addr, &allowed_ips) { return Err(auth::AuthError::ip_address_not_allowed()); } let secret = secret.unwrap_or_else(|| { @@ -173,36 +200,49 @@ async fn auth_quirks_creds( // prevent malicious probing (possible due to missing protocol steps). // This mocked secret will never lead to successful authentication. info!("authentication info not found, mocking it"); - AuthSecret::Scram(scram::ServerSecret::mock(creds.user, rand::random())) + AuthSecret::Scram(scram::ServerSecret::mock(&info.inner.user, rand::random())) }); - if let Some(success) = maybe_success { - return Ok(success); + if let Some(password) = unauthenticated_password { + let auth_outcome = validate_password_and_exchange(&password, secret)?; + let keys = match auth_outcome { + crate::sasl::Outcome::Success(key) => key, + crate::sasl::Outcome::Failure(reason) => { + info!("auth backend failed with an error: {reason}"); + return Err(auth::AuthError::auth_failed(&*info.inner.user)); + } + }; + + // we have authenticated the password + client.write_message_noflush(&pq_proto::BeMessage::AuthenticationOk)?; + + return Ok(ComputeCredentials { info, keys }); } + // -- the remaining flows are self-authenticating -- + // Perform cleartext auth if we're allowed to do that. // Currently, we use it for websocket connections (latency). if allow_cleartext { - // Password will be checked by the compute node later. - return hacks::cleartext_hack(client, latency_timer).await; + return hacks::authenticate_cleartext(info, client, latency_timer, secret).await; } // Finally, proceed with the main auth flow (SCRAM-based). - classic::authenticate(creds, client, config, latency_timer, secret).await + classic::authenticate(info, client, config, latency_timer, secret).await } -/// True to its name, this function encapsulates our current auth trade-offs. -/// Here, we choose the appropriate auth flow based on circumstances. -async fn auth_quirks( +/// Authenticate the user and then wake a compute (or retrieve an existing compute session from cache) +/// only if authentication was successfuly. +async fn auth_and_wake_compute( api: &impl console::Api, extra: &ConsoleReqExtra<'_>, - creds: &mut ClientCredentials<'_>, + creds: ClientCredentials, client: &mut stream::PqStream>, allow_cleartext: bool, config: &'static AuthenticationConfig, latency_timer: &mut LatencyTimer, -) -> auth::Result> { - let auth_stuff = auth_quirks_creds( +) -> auth::Result<(CachedNodeInfo, ComputeUserInfo)> { + let compute_credentials = auth_quirks( api, extra, creds, @@ -215,7 +255,7 @@ async fn auth_quirks( let mut num_retries = 0; let mut node = loop { - let wake_res = api.wake_compute(extra, creds).await; + let wake_res = api.wake_compute(extra, &compute_credentials.info).await; match handle_try_wake(wake_res, num_retries) { Err(e) => { error!(error = ?e, num_retries, retriable = false, "couldn't wake compute node"); @@ -232,27 +272,27 @@ async fn auth_quirks( tokio::time::sleep(wait_duration).await; }; - match auth_stuff.value { - ComputeCredentials::Password(password) => node.config.password(password), - ComputeCredentials::AuthKeys(auth_keys) => node.config.auth_keys(auth_keys), + match compute_credentials.keys { + #[cfg(feature = "testing")] + ComputeCredentialKeys::Password(password) => node.config.password(password), + ComputeCredentialKeys::AuthKeys(auth_keys) => node.config.auth_keys(auth_keys), }; - Ok(AuthSuccess { - reported_auth_ok: auth_stuff.reported_auth_ok, - value: node, - }) + Ok((node, compute_credentials.info)) } -impl BackendType<'_, ClientCredentials<'_>> { +impl<'a> BackendType<'a, ClientCredentials> { /// Get compute endpoint name from the credentials. - pub fn get_endpoint(&self) -> Option { + pub fn get_endpoint(&self) -> Option { use BackendType::*; match self { Console(_, creds) => creds.project.clone(), + #[cfg(feature = "testing")] Postgres(_, creds) => creds.project.clone(), - Link(_) => Some("link".to_owned()), - Test(_) => Some("test".to_owned()), + Link(_) => Some("link".into()), + #[cfg(test)] + Test(_) => Some("test".into()), } } @@ -261,9 +301,11 @@ impl BackendType<'_, ClientCredentials<'_>> { use BackendType::*; match self { - Console(_, creds) => creds.user, - Postgres(_, creds) => creds.user, + Console(_, creds) => &creds.user, + #[cfg(feature = "testing")] + Postgres(_, creds) => &creds.user, Link(_) => "link", + #[cfg(test)] Test(_) => "test", } } @@ -271,26 +313,25 @@ impl BackendType<'_, ClientCredentials<'_>> { /// Authenticate the client via the requested backend, possibly using credentials. #[tracing::instrument(fields(allow_cleartext = allow_cleartext), skip_all)] pub async fn authenticate( - &mut self, + self, extra: &ConsoleReqExtra<'_>, client: &mut stream::PqStream>, allow_cleartext: bool, config: &'static AuthenticationConfig, latency_timer: &mut LatencyTimer, - ) -> auth::Result> { + ) -> auth::Result<(CachedNodeInfo, BackendType<'a, ComputeUserInfo>)> { use BackendType::*; let res = match self { Console(api, creds) => { info!( - user = creds.user, + user = &*creds.user, project = creds.project(), "performing authentication using the console" ); - let api = api.as_ref(); - auth_quirks( - api, + let (cache_info, user_info) = auth_and_wake_compute( + &*api, extra, creds, client, @@ -298,18 +339,19 @@ impl BackendType<'_, ClientCredentials<'_>> { config, latency_timer, ) - .await? + .await?; + (cache_info, BackendType::Console(api, user_info)) } + #[cfg(feature = "testing")] Postgres(api, creds) => { info!( - user = creds.user, + user = &*creds.user, project = creds.project(), "performing authentication using a local postgres instance" ); - let api = api.as_ref(); - auth_quirks( - api, + let (cache_info, user_info) = auth_and_wake_compute( + &*api, extra, creds, client, @@ -317,16 +359,21 @@ impl BackendType<'_, ClientCredentials<'_>> { config, latency_timer, ) - .await? + .await?; + (cache_info, BackendType::Postgres(api, user_info)) } // NOTE: this auth backend doesn't use client credentials. Link(url) => { info!("performing link authentication"); - link::authenticate(url, client) - .await? - .map(CachedNodeInfo::new_uncached) + let node_info = link::authenticate(&url, client).await?; + + ( + CachedNodeInfo::new_uncached(node_info), + BackendType::Link(url), + ) } + #[cfg(test)] Test(_) => { unreachable!("this function should never be called in the test backend") } @@ -335,7 +382,9 @@ impl BackendType<'_, ClientCredentials<'_>> { info!("user successfully authenticated"); Ok(res) } +} +impl BackendType<'_, ComputeUserInfo> { pub async fn get_allowed_ips( &self, extra: &ConsoleReqExtra<'_>, @@ -343,8 +392,10 @@ impl BackendType<'_, ClientCredentials<'_>> { use BackendType::*; match self { Console(api, creds) => api.get_allowed_ips(extra, creds).await, + #[cfg(feature = "testing")] Postgres(api, creds) => api.get_allowed_ips(extra, creds).await, Link(_) => Ok(Arc::new(vec![])), + #[cfg(test)] Test(x) => x.get_allowed_ips(), } } @@ -359,8 +410,10 @@ impl BackendType<'_, ClientCredentials<'_>> { match self { Console(api, creds) => api.wake_compute(extra, creds).map_ok(Some).await, + #[cfg(feature = "testing")] Postgres(api, creds) => api.wake_compute(extra, creds).map_ok(Some).await, Link(_) => Ok(None), + #[cfg(test)] Test(x) => x.wake_compute().map(Some), } } diff --git a/proxy/src/auth/backend/classic.rs b/proxy/src/auth/backend/classic.rs index bb210821cd..ce52daf16c 100644 --- a/proxy/src/auth/backend/classic.rs +++ b/proxy/src/auth/backend/classic.rs @@ -1,6 +1,6 @@ -use super::{AuthSuccess, ComputeCredentials}; +use super::{ComputeCredentials, ComputeUserInfo}; use crate::{ - auth::{self, AuthFlow, ClientCredentials}, + auth::{self, backend::ComputeCredentialKeys, AuthFlow}, compute, config::AuthenticationConfig, console::AuthSecret, @@ -12,14 +12,15 @@ use tokio::io::{AsyncRead, AsyncWrite}; use tracing::{info, warn}; pub(super) async fn authenticate( - creds: &ClientCredentials<'_>, + creds: ComputeUserInfo, client: &mut PqStream>, config: &'static AuthenticationConfig, latency_timer: &mut LatencyTimer, secret: AuthSecret, -) -> auth::Result> { +) -> auth::Result> { let flow = AuthFlow::new(client); let scram_keys = match secret { + #[cfg(feature = "testing")] AuthSecret::Md5(_) => { info!("auth endpoint chooses MD5"); return Err(auth::AuthError::bad_auth_method("MD5")); @@ -53,7 +54,7 @@ pub(super) async fn authenticate( sasl::Outcome::Success(key) => key, sasl::Outcome::Failure(reason) => { info!("auth backend failed with an error: {reason}"); - return Err(auth::AuthError::auth_failed(creds.user)); + return Err(auth::AuthError::auth_failed(&*creds.inner.user)); } }; @@ -64,9 +65,9 @@ pub(super) async fn authenticate( } }; - Ok(AuthSuccess { - reported_auth_ok: false, - value: ComputeCredentials::AuthKeys(tokio_postgres::config::AuthKeys::ScramSha256( + Ok(ComputeCredentials { + info: creds, + keys: ComputeCredentialKeys::AuthKeys(tokio_postgres::config::AuthKeys::ScramSha256( scram_keys, )), }) diff --git a/proxy/src/auth/backend/hacks.rs b/proxy/src/auth/backend/hacks.rs index 4448dbc56a..abbd25008b 100644 --- a/proxy/src/auth/backend/hacks.rs +++ b/proxy/src/auth/backend/hacks.rs @@ -1,7 +1,11 @@ -use super::{AuthSuccess, ComputeCredentials}; +use super::{ + ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo, ComputeUserInfoNoEndpoint, +}; use crate::{ - auth::{self, AuthFlow, ClientCredentials}, + auth::{self, AuthFlow}, + console::AuthSecret, proxy::LatencyTimer, + sasl, stream::{self, Stream}, }; use tokio::io::{AsyncRead, AsyncWrite}; @@ -11,35 +15,42 @@ use tracing::{info, warn}; /// one round trip and *expensive* computations (>= 4096 HMAC iterations). /// These properties are benefical for serverless JS workers, so we /// use this mechanism for websocket connections. -pub async fn cleartext_hack( +pub async fn authenticate_cleartext( + info: ComputeUserInfo, client: &mut stream::PqStream>, latency_timer: &mut LatencyTimer, -) -> auth::Result> { + secret: AuthSecret, +) -> auth::Result> { warn!("cleartext auth flow override is enabled, proceeding"); // pause the timer while we communicate with the client let _paused = latency_timer.pause(); - let password = AuthFlow::new(client) - .begin(auth::CleartextPassword) + let auth_outcome = AuthFlow::new(client) + .begin(auth::CleartextPassword(secret)) .await? .authenticate() .await?; - // Report tentative success; compute node will check the password anyway. - Ok(AuthSuccess { - reported_auth_ok: false, - value: ComputeCredentials::Password(password), - }) + let keys = match auth_outcome { + sasl::Outcome::Success(key) => key, + sasl::Outcome::Failure(reason) => { + info!("auth backend failed with an error: {reason}"); + return Err(auth::AuthError::auth_failed(&*info.inner.user)); + } + }; + + Ok(ComputeCredentials { info, keys }) } /// Workaround for clients which don't provide an endpoint (project) name. -/// Very similar to [`cleartext_hack`], but there's a specific password format. -pub async fn password_hack( - creds: &mut ClientCredentials<'_>, +/// Similar to [`authenticate_cleartext`], but there's a specific password format, +/// and passwords are not yet validated (we don't know how to validate them!) +pub async fn password_hack_no_authentication( + info: ComputeUserInfoNoEndpoint, client: &mut stream::PqStream>, latency_timer: &mut LatencyTimer, -) -> auth::Result> { +) -> auth::Result>> { warn!("project not specified, resorting to the password hack auth flow"); // pause the timer while we communicate with the client @@ -48,15 +59,17 @@ pub async fn password_hack( let payload = AuthFlow::new(client) .begin(auth::PasswordHack) .await? - .authenticate() + .get_password() .await?; - info!(project = &payload.endpoint, "received missing parameter"); - creds.project = Some(payload.endpoint); + info!(project = &*payload.endpoint, "received missing parameter"); // Report tentative success; compute node will check the password anyway. - Ok(AuthSuccess { - reported_auth_ok: false, - value: ComputeCredentials::Password(payload.password), + Ok(ComputeCredentials { + info: ComputeUserInfo { + inner: info, + endpoint: payload.endpoint, + }, + keys: payload.password, }) } diff --git a/proxy/src/auth/backend/link.rs b/proxy/src/auth/backend/link.rs index 3a77d7e5ca..2cf7e3acc7 100644 --- a/proxy/src/auth/backend/link.rs +++ b/proxy/src/auth/backend/link.rs @@ -1,4 +1,3 @@ -use super::AuthSuccess; use crate::{ auth, compute, console::{self, provider::NodeInfo}, @@ -57,7 +56,7 @@ pub fn new_psql_session_id() -> String { pub(super) async fn authenticate( link_uri: &reqwest::Url, client: &mut PqStream, -) -> auth::Result> { +) -> auth::Result { let psql_session_id = new_psql_session_id(); let span = info_span!("link", psql_session_id = &psql_session_id); let greeting = hello_message(link_uri, &psql_session_id); @@ -102,12 +101,9 @@ pub(super) async fn authenticate( config.password(password.as_ref()); } - Ok(AuthSuccess { - reported_auth_ok: true, - value: NodeInfo { - config, - aux: db_info.aux, - allow_self_signed_compute: false, // caller may override - }, + Ok(NodeInfo { + config, + aux: db_info.aux, + allow_self_signed_compute: false, // caller may override }) } diff --git a/proxy/src/auth/credentials.rs b/proxy/src/auth/credentials.rs index facb8da8cd..dd7c58255f 100644 --- a/proxy/src/auth/credentials.rs +++ b/proxy/src/auth/credentials.rs @@ -7,10 +7,8 @@ use crate::{ }; use itertools::Itertools; use pq_proto::StartupMessageParams; -use std::{ - collections::HashSet, - net::{IpAddr, SocketAddr}, -}; +use smol_str::SmolStr; +use std::{collections::HashSet, net::IpAddr}; use thiserror::Error; use tracing::{info, warn}; @@ -24,7 +22,7 @@ pub enum ClientCredsParseError { SNI ('{}') and project option ('{}').", .domain, .option, )] - InconsistentProjectNames { domain: String, option: String }, + InconsistentProjectNames { domain: SmolStr, option: SmolStr }, #[error( "Common name inferred from SNI ('{}') is not known", @@ -33,7 +31,7 @@ pub enum ClientCredsParseError { UnknownCommonName { cn: String }, #[error("Project name ('{0}') must contain only alphanumeric characters and hyphen.")] - MalformedProjectName(String), + MalformedProjectName(SmolStr), } impl UserFacingError for ClientCredsParseError {} @@ -41,34 +39,34 @@ impl UserFacingError for ClientCredsParseError {} /// Various client credentials which we use for authentication. /// Note that we don't store any kind of client key or password here. #[derive(Debug, Clone, PartialEq, Eq)] -pub struct ClientCredentials<'a> { - pub user: &'a str, +pub struct ClientCredentials { + pub user: SmolStr, // TODO: this is a severe misnomer! We should think of a new name ASAP. - pub project: Option, + pub project: Option, - pub cache_key: String, - pub peer_addr: SocketAddr, + pub cache_key: SmolStr, + pub peer_addr: IpAddr, } -impl ClientCredentials<'_> { +impl ClientCredentials { #[inline] pub fn project(&self) -> Option<&str> { self.project.as_deref() } } -impl<'a> ClientCredentials<'a> { +impl ClientCredentials { pub fn parse( - params: &'a StartupMessageParams, + params: &StartupMessageParams, sni: Option<&str>, common_names: Option>, - peer_addr: SocketAddr, + peer_addr: IpAddr, ) -> Result { use ClientCredsParseError::*; // Some parameters are stored in the startup message. let get_param = |key| params.get(key).ok_or(MissingKey(key)); - let user = get_param("user")?; + let user = get_param("user")?.into(); // Project name might be passed via PG's command-line options. let project_option = params @@ -82,7 +80,7 @@ impl<'a> ClientCredentials<'a> { .at_most_one() .ok()? }) - .map(|name| name.to_string()); + .map(|name| name.into()); let project_from_domain = if let Some(sni_str) = sni { if let Some(cn) = common_names { @@ -121,7 +119,7 @@ impl<'a> ClientCredentials<'a> { } .transpose()?; - info!(user, project = project.as_deref(), "credentials"); + info!(%user, project = project.as_deref(), "credentials"); if sni.is_some() { info!("Connection with sni"); NUM_CONNECTION_ACCEPTED_BY_SNI @@ -143,7 +141,8 @@ impl<'a> ClientCredentials<'a> { "{}{}", project.as_deref().unwrap_or(""), neon_options(params).unwrap_or("".to_string()) - ); + ) + .into(); Ok(Self { user, @@ -206,10 +205,10 @@ fn project_name_valid(name: &str) -> bool { name.chars().all(|c| c.is_alphanumeric() || c == '-') } -fn subdomain_from_sni(sni: &str, common_name: &str) -> Option { +fn subdomain_from_sni(sni: &str, common_name: &str) -> Option { sni.strip_suffix(common_name)? .strip_suffix('.') - .map(str::to_owned) + .map(SmolStr::from) } #[cfg(test)] @@ -221,7 +220,7 @@ mod tests { fn parse_bare_minimum() -> anyhow::Result<()> { // According to postgresql, only `user` should be required. let options = StartupMessageParams::new([("user", "john_doe")]); - let peer_addr = SocketAddr::from(([127, 0, 0, 1], 1234)); + let peer_addr = IpAddr::from([127, 0, 0, 1]); let creds = ClientCredentials::parse(&options, None, None, peer_addr)?; assert_eq!(creds.user, "john_doe"); assert_eq!(creds.project, None); @@ -236,7 +235,7 @@ mod tests { ("database", "world"), // should be ignored ("foo", "bar"), // should be ignored ]); - let peer_addr = SocketAddr::from(([127, 0, 0, 1], 1234)); + let peer_addr = IpAddr::from([127, 0, 0, 1]); let creds = ClientCredentials::parse(&options, None, None, peer_addr)?; assert_eq!(creds.user, "john_doe"); assert_eq!(creds.project, None); @@ -251,7 +250,7 @@ mod tests { let sni = Some("foo.localhost"); let common_names = Some(["localhost".into()].into()); - let peer_addr = SocketAddr::from(([127, 0, 0, 1], 1234)); + let peer_addr = IpAddr::from([127, 0, 0, 1]); let creds = ClientCredentials::parse(&options, sni, common_names, peer_addr)?; assert_eq!(creds.user, "john_doe"); assert_eq!(creds.project.as_deref(), Some("foo")); @@ -267,7 +266,7 @@ mod tests { ("options", "-ckey=1 project=bar -c geqo=off"), ]); - let peer_addr = SocketAddr::from(([127, 0, 0, 1], 1234)); + let peer_addr = IpAddr::from([127, 0, 0, 1]); let creds = ClientCredentials::parse(&options, None, None, peer_addr)?; assert_eq!(creds.user, "john_doe"); assert_eq!(creds.project.as_deref(), Some("bar")); @@ -282,7 +281,7 @@ mod tests { ("options", "-ckey=1 endpoint=bar -c geqo=off"), ]); - let peer_addr = SocketAddr::from(([127, 0, 0, 1], 1234)); + let peer_addr = IpAddr::from([127, 0, 0, 1]); let creds = ClientCredentials::parse(&options, None, None, peer_addr)?; assert_eq!(creds.user, "john_doe"); assert_eq!(creds.project.as_deref(), Some("bar")); @@ -300,7 +299,7 @@ mod tests { ), ]); - let peer_addr = SocketAddr::from(([127, 0, 0, 1], 1234)); + let peer_addr = IpAddr::from([127, 0, 0, 1]); let creds = ClientCredentials::parse(&options, None, None, peer_addr)?; assert_eq!(creds.user, "john_doe"); assert!(creds.project.is_none()); @@ -315,7 +314,7 @@ mod tests { ("options", "-ckey=1 endpoint=bar project=foo -c geqo=off"), ]); - let peer_addr = SocketAddr::from(([127, 0, 0, 1], 1234)); + let peer_addr = IpAddr::from([127, 0, 0, 1]); let creds = ClientCredentials::parse(&options, None, None, peer_addr)?; assert_eq!(creds.user, "john_doe"); assert!(creds.project.is_none()); @@ -330,7 +329,7 @@ mod tests { let sni = Some("baz.localhost"); let common_names = Some(["localhost".into()].into()); - let peer_addr = SocketAddr::from(([127, 0, 0, 1], 1234)); + let peer_addr = IpAddr::from([127, 0, 0, 1]); let creds = ClientCredentials::parse(&options, sni, common_names, peer_addr)?; assert_eq!(creds.user, "john_doe"); assert_eq!(creds.project.as_deref(), Some("baz")); @@ -344,13 +343,13 @@ mod tests { let common_names = Some(["a.com".into(), "b.com".into()].into()); let sni = Some("p1.a.com"); - let peer_addr = SocketAddr::from(([127, 0, 0, 1], 1234)); + let peer_addr = IpAddr::from([127, 0, 0, 1]); let creds = ClientCredentials::parse(&options, sni, common_names, peer_addr)?; assert_eq!(creds.project.as_deref(), Some("p1")); let common_names = Some(["a.com".into(), "b.com".into()].into()); let sni = Some("p1.b.com"); - let peer_addr = SocketAddr::from(([127, 0, 0, 1], 1234)); + let peer_addr = IpAddr::from([127, 0, 0, 1]); let creds = ClientCredentials::parse(&options, sni, common_names, peer_addr)?; assert_eq!(creds.project.as_deref(), Some("p1")); @@ -365,7 +364,7 @@ mod tests { let sni = Some("second.localhost"); let common_names = Some(["localhost".into()].into()); - let peer_addr = SocketAddr::from(([127, 0, 0, 1], 1234)); + let peer_addr = IpAddr::from([127, 0, 0, 1]); let err = ClientCredentials::parse(&options, sni, common_names, peer_addr) .expect_err("should fail"); match err { @@ -384,7 +383,7 @@ mod tests { let sni = Some("project.localhost"); let common_names = Some(["example.com".into()].into()); - let peer_addr = SocketAddr::from(([127, 0, 0, 1], 1234)); + let peer_addr = IpAddr::from([127, 0, 0, 1]); let err = ClientCredentials::parse(&options, sni, common_names, peer_addr) .expect_err("should fail"); match err { @@ -404,7 +403,7 @@ mod tests { let sni = Some("project.localhost"); let common_names = Some(["localhost".into()].into()); - let peer_addr = SocketAddr::from(([127, 0, 0, 1], 1234)); + let peer_addr = IpAddr::from([127, 0, 0, 1]); let creds = ClientCredentials::parse(&options, sni, common_names, peer_addr)?; assert_eq!(creds.project.as_deref(), Some("project")); assert_eq!( diff --git a/proxy/src/auth/flow.rs b/proxy/src/auth/flow.rs index efb90733d6..3151a77263 100644 --- a/proxy/src/auth/flow.rs +++ b/proxy/src/auth/flow.rs @@ -1,8 +1,9 @@ //! Main authentication flow. -use super::{AuthErrorImpl, PasswordHackPayload}; +use super::{backend::ComputeCredentialKeys, AuthErrorImpl, PasswordHackPayload}; use crate::{ config::TlsServerEndPoint, + console::AuthSecret, sasl, scram, stream::{PqStream, Stream}, }; @@ -50,7 +51,7 @@ impl AuthMethod for PasswordHack { /// Use clear-text password auth called `password` in docs /// -pub struct CleartextPassword; +pub struct CleartextPassword(pub AuthSecret); impl AuthMethod for CleartextPassword { #[inline(always)] @@ -98,7 +99,7 @@ impl<'a, S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'a, S, Begin> { impl AuthFlow<'_, S, PasswordHack> { /// Perform user authentication. Raise an error in case authentication failed. - pub async fn authenticate(self) -> super::Result { + pub async fn get_password(self) -> super::Result { let msg = self.stream.read_password_message().await?; let password = msg .strip_suffix(&[0]) @@ -117,13 +118,19 @@ impl AuthFlow<'_, S, PasswordHack> { impl AuthFlow<'_, S, CleartextPassword> { /// Perform user authentication. Raise an error in case authentication failed. - pub async fn authenticate(self) -> super::Result> { + pub async fn authenticate(self) -> super::Result> { let msg = self.stream.read_password_message().await?; let password = msg .strip_suffix(&[0]) .ok_or(AuthErrorImpl::MalformedPassword("missing terminator"))?; - Ok(password.to_vec()) + let outcome = validate_password_and_exchange(password, self.state.0)?; + + if let sasl::Outcome::Success(_) = &outcome { + self.stream.write_message_noflush(&Be::AuthenticationOk)?; + } + + Ok(outcome) } } @@ -152,6 +159,49 @@ impl AuthFlow<'_, S, Scram<'_>> { )) .await?; + if let sasl::Outcome::Success(_) = &outcome { + self.stream.write_message_noflush(&Be::AuthenticationOk)?; + } + Ok(outcome) } } + +pub(super) fn validate_password_and_exchange( + password: &[u8], + secret: AuthSecret, +) -> super::Result> { + match secret { + #[cfg(feature = "testing")] + AuthSecret::Md5(_) => { + // test only + Ok(sasl::Outcome::Success(ComputeCredentialKeys::Password( + password.to_owned(), + ))) + } + // perform scram authentication as both client and server to validate the keys + AuthSecret::Scram(scram_secret) => { + use postgres_protocol::authentication::sasl::{ChannelBinding, ScramSha256}; + let sasl_client = ScramSha256::new(password, ChannelBinding::unsupported()); + let outcome = crate::scram::exchange( + &scram_secret, + sasl_client, + crate::config::TlsServerEndPoint::Undefined, + )?; + + let client_key = match outcome { + sasl::Outcome::Success(client_key) => client_key, + sasl::Outcome::Failure(reason) => return Ok(sasl::Outcome::Failure(reason)), + }; + + let keys = crate::compute::ScramKeys { + client_key: client_key.as_bytes(), + server_key: scram_secret.server_key.as_bytes(), + }; + + Ok(sasl::Outcome::Success(ComputeCredentialKeys::AuthKeys( + tokio_postgres::config::AuthKeys::ScramSha256(keys), + ))) + } + } +} diff --git a/proxy/src/auth/password_hack.rs b/proxy/src/auth/password_hack.rs index d1da208fef..372b0764ee 100644 --- a/proxy/src/auth/password_hack.rs +++ b/proxy/src/auth/password_hack.rs @@ -4,9 +4,10 @@ //! UPDATE (Mon Aug 8 13:20:34 UTC 2022): the payload format has been simplified. use bstr::ByteSlice; +use smol_str::SmolStr; pub struct PasswordHackPayload { - pub endpoint: String, + pub endpoint: SmolStr, pub password: Vec, } @@ -18,7 +19,7 @@ impl PasswordHackPayload { if let Some((endpoint, password)) = bytes.split_once_str(sep) { let endpoint = endpoint.to_str().ok()?; return Some(Self { - endpoint: parse_endpoint_param(endpoint)?.to_owned(), + endpoint: parse_endpoint_param(endpoint)?.into(), password: password.to_owned(), }); } diff --git a/proxy/src/bin/proxy.rs b/proxy/src/bin/proxy.rs index 6c4189de75..fc1c44809a 100644 --- a/proxy/src/bin/proxy.rs +++ b/proxy/src/bin/proxy.rs @@ -30,6 +30,7 @@ use clap::{Parser, ValueEnum}; #[derive(Clone, Debug, ValueEnum)] enum AuthBackend { Console, + #[cfg(feature = "testing")] Postgres, Link, } @@ -289,6 +290,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { let api = console::provider::neon::Api::new(endpoint, caches, locks); auth::BackendType::Console(Cow::Owned(api), ()) } + #[cfg(feature = "testing")] AuthBackend::Postgres => { let url = args.auth_endpoint.parse()?; let api = console::provider::mock::Api::new(url); diff --git a/proxy/src/console/provider.rs b/proxy/src/console/provider.rs index e735b9f66c..ccb5cbdb92 100644 --- a/proxy/src/console/provider.rs +++ b/proxy/src/console/provider.rs @@ -1,9 +1,10 @@ +#[cfg(feature = "testing")] pub mod mock; pub mod neon; use super::messages::MetricsAuxInfo; use crate::{ - auth::ClientCredentials, + auth::backend::ComputeUserInfo, cache::{timed_lru, TimedLru}, compute, scram, }; @@ -205,6 +206,7 @@ pub struct ConsoleReqExtra<'a> { /// Auth secret which is managed by the cloud. pub enum AuthSecret { + #[cfg(feature = "testing")] /// Md5 hash of user's password. Md5([u8; 16]), @@ -247,20 +249,20 @@ pub trait Api { async fn get_auth_info( &self, extra: &ConsoleReqExtra<'_>, - creds: &ClientCredentials, + creds: &ComputeUserInfo, ) -> Result; async fn get_allowed_ips( &self, extra: &ConsoleReqExtra<'_>, - creds: &ClientCredentials, + creds: &ComputeUserInfo, ) -> Result>, errors::GetAuthInfoError>; /// Wake up the compute node and return the corresponding connection info. async fn wake_compute( &self, extra: &ConsoleReqExtra<'_>, - creds: &ClientCredentials, + creds: &ComputeUserInfo, ) -> Result; } diff --git a/proxy/src/console/provider/mock.rs b/proxy/src/console/provider/mock.rs index 4cc68f0ac1..8aad8c06bc 100644 --- a/proxy/src/console/provider/mock.rs +++ b/proxy/src/console/provider/mock.rs @@ -6,7 +6,7 @@ use super::{ errors::{ApiError, GetAuthInfoError, WakeComputeError}, AuthInfo, AuthSecret, CachedNodeInfo, ConsoleReqExtra, NodeInfo, }; -use crate::{auth::ClientCredentials, compute, error::io_error, scram, url::ApiUrl}; +use crate::{auth::backend::ComputeUserInfo, compute, error::io_error, scram, url::ApiUrl}; use async_trait::async_trait; use futures::TryFutureExt; use thiserror::Error; @@ -47,7 +47,7 @@ impl Api { async fn do_get_auth_info( &self, - creds: &ClientCredentials<'_>, + creds: &ComputeUserInfo, ) -> Result { let (secret, allowed_ips) = async { // Perhaps we could persist this connection, but then we'd have to @@ -60,7 +60,7 @@ impl Api { let secret = match get_execute_postgres_query( &client, "select rolpassword from pg_catalog.pg_authid where rolname = $1", - &[&creds.user], + &[&&*creds.inner.user], "rolpassword", ) .await? @@ -71,14 +71,14 @@ impl Api { secret.or_else(|| parse_md5(&entry).map(AuthSecret::Md5)) } None => { - warn!("user '{}' does not exist", creds.user); + warn!("user '{}' does not exist", creds.inner.user); None } }; let allowed_ips = match get_execute_postgres_query( &client, "select allowed_ips from neon_control_plane.endpoints where endpoint_id = $1", - &[&creds.project.clone().unwrap_or_default().as_str()], + &[&creds.endpoint.as_str()], "allowed_ips", ) .await? @@ -145,7 +145,7 @@ impl super::Api for Api { async fn get_auth_info( &self, _extra: &ConsoleReqExtra<'_>, - creds: &ClientCredentials, + creds: &ComputeUserInfo, ) -> Result { self.do_get_auth_info(creds).await } @@ -153,7 +153,7 @@ impl super::Api for Api { async fn get_allowed_ips( &self, _extra: &ConsoleReqExtra<'_>, - creds: &ClientCredentials, + creds: &ComputeUserInfo, ) -> Result>, GetAuthInfoError> { Ok(Arc::new(self.do_get_auth_info(creds).await?.allowed_ips)) } @@ -162,7 +162,7 @@ impl super::Api for Api { async fn wake_compute( &self, _extra: &ConsoleReqExtra<'_>, - _creds: &ClientCredentials, + _creds: &ComputeUserInfo, ) -> Result { self.do_wake_compute() .map_ok(CachedNodeInfo::new_uncached) diff --git a/proxy/src/console/provider/neon.rs b/proxy/src/console/provider/neon.rs index 7828a7d7e4..f0510e91ea 100644 --- a/proxy/src/console/provider/neon.rs +++ b/proxy/src/console/provider/neon.rs @@ -5,12 +5,8 @@ use super::{ errors::{ApiError, GetAuthInfoError, WakeComputeError}, ApiCaches, ApiLocks, AuthInfo, AuthSecret, CachedNodeInfo, ConsoleReqExtra, NodeInfo, }; -use crate::{ - auth::ClientCredentials, - compute, http, - proxy::{ALLOWED_IPS_BY_CACHE_OUTCOME, ALLOWED_IPS_NUMBER}, - scram, -}; +use crate::proxy::{ALLOWED_IPS_BY_CACHE_OUTCOME, ALLOWED_IPS_NUMBER}; +use crate::{auth::backend::ComputeUserInfo, compute, http, scram}; use async_trait::async_trait; use futures::TryFutureExt; use itertools::Itertools; @@ -53,7 +49,7 @@ impl Api { async fn do_get_auth_info( &self, extra: &ConsoleReqExtra<'_>, - creds: &ClientCredentials<'_>, + creds: &ComputeUserInfo, ) -> Result { let request_id = uuid::Uuid::new_v4().to_string(); async { @@ -65,8 +61,8 @@ impl Api { .query(&[("session_id", extra.session_id)]) .query(&[ ("application_name", extra.application_name), - ("project", Some(creds.project().expect("impossible"))), - ("role", Some(creds.user)), + ("project", Some(&creds.endpoint)), + ("role", Some(&creds.inner.user)), ]) .build()?; @@ -106,9 +102,8 @@ impl Api { async fn do_wake_compute( &self, extra: &ConsoleReqExtra<'_>, - creds: &ClientCredentials<'_>, + creds: &ComputeUserInfo, ) -> Result { - let project = creds.project().expect("impossible"); let request_id = uuid::Uuid::new_v4().to_string(); async { let request = self @@ -119,7 +114,7 @@ impl Api { .query(&[("session_id", extra.session_id)]) .query(&[ ("application_name", extra.application_name), - ("project", Some(project)), + ("project", Some(&creds.endpoint)), ("options", extra.options), ]) .build()?; @@ -162,7 +157,7 @@ impl super::Api for Api { async fn get_auth_info( &self, extra: &ConsoleReqExtra<'_>, - creds: &ClientCredentials, + creds: &ComputeUserInfo, ) -> Result { self.do_get_auth_info(extra, creds).await } @@ -170,9 +165,9 @@ impl super::Api for Api { async fn get_allowed_ips( &self, extra: &ConsoleReqExtra<'_>, - creds: &ClientCredentials, + creds: &ComputeUserInfo, ) -> Result>, GetAuthInfoError> { - let key: &str = creds.project().expect("impossible"); + let key: &str = &creds.endpoint; if let Some(allowed_ips) = self.caches.allowed_ips.get(key) { ALLOWED_IPS_BY_CACHE_OUTCOME .with_label_values(&["hit"]) @@ -193,9 +188,9 @@ impl super::Api for Api { async fn wake_compute( &self, extra: &ConsoleReqExtra<'_>, - creds: &ClientCredentials, + creds: &ComputeUserInfo, ) -> Result { - let key: &str = &creds.cache_key; + let key: &str = &creds.inner.cache_key; // Every time we do a wakeup http request, the compute node will stay up // for some time (highly depends on the console's scale-to-zero policy); diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index 3b6d9cb61d..7cf3ed5b8a 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -2,7 +2,7 @@ mod tests; use crate::{ - auth::{self, backend::AuthSuccess}, + auth, cancellation::{self, CancelMap}, compute::{self, PostgresConnection}, config::{AuthenticationConfig, ProxyConfig, TlsConfig}, @@ -24,7 +24,7 @@ use prometheus::{ IntGaugeVec, }; use regex::Regex; -use std::{error::Error, io, net::SocketAddr, ops::ControlFlow, sync::Arc, time::Instant}; +use std::{error::Error, io, net::IpAddr, ops::ControlFlow, sync::Arc, time::Instant}; use tokio::{ io::{AsyncRead, AsyncWrite, AsyncWriteExt}, time, @@ -318,7 +318,7 @@ pub async fn task_main( .set_nodelay(true) .context("failed to set socket option")?; - handle_client(config, &cancel_map, session_id, socket, ClientMode::Tcp, peer_addr).await + handle_client(config, &cancel_map, session_id, socket, ClientMode::Tcp, peer_addr.ip()).await } .instrument(info_span!("handle_client", ?session_id, peer_addr = tracing::field::Empty)) .unwrap_or_else(move |e| { @@ -408,7 +408,7 @@ pub async fn handle_client( session_id: uuid::Uuid, stream: S, mode: ClientMode, - peer_addr: SocketAddr, + peer_addr: IpAddr, ) -> anyhow::Result<()> { info!( protocol = mode.protocol_label(), @@ -666,7 +666,7 @@ pub async fn connect_to_compute( mechanism: &M, mut node_info: console::CachedNodeInfo, extra: &console::ConsoleReqExtra<'_>, - creds: &auth::BackendType<'_, auth::ClientCredentials<'_>>, + creds: &auth::BackendType<'_, auth::backend::ComputeUserInfo>, mut latency_timer: LatencyTimer, ) -> Result where @@ -696,10 +696,12 @@ where let node_info = loop { let wake_res = match creds { auth::BackendType::Console(api, creds) => api.wake_compute(extra, creds).await, + #[cfg(feature = "testing")] auth::BackendType::Postgres(api, creds) => api.wake_compute(extra, creds).await, // nothing to do? auth::BackendType::Link(_) => return Err(err.into()), // test backend + #[cfg(test)] auth::BackendType::Test(x) => x.wake_compute(), }; @@ -838,7 +840,6 @@ pub fn retry_after(num_retries: u32) -> time::Duration { #[tracing::instrument(skip_all)] async fn prepare_client_connection( node: &compute::PostgresConnection, - reported_auth_ok: bool, session: cancellation::Session<'_>, stream: &mut PqStream, ) -> anyhow::Result<()> { @@ -846,13 +847,6 @@ async fn prepare_client_connection( // The new token (cancel_key_data) will be sent to the client. let cancel_key_data = session.enable_query_cancellation(node.cancel_closure.clone()); - // Report authentication success if we haven't done this already. - // Note that we do this only (for the most part) after we've connected - // to a compute (see above) which performs its own authentication. - if !reported_auth_ok { - stream.write_message_noflush(&Be::AuthenticationOk)?; - } - // Forward all postgres connection params to the client. // Right now the implementation is very hacky and inefficent (ideally, // we don't need an intermediate hashmap), but at least it should be correct. @@ -921,7 +915,7 @@ struct Client<'a, S> { /// The underlying libpq protocol stream. stream: PqStream>, /// Client credentials that we care about. - creds: auth::BackendType<'a, auth::ClientCredentials<'a>>, + creds: auth::BackendType<'a, auth::ClientCredentials>, /// KV-dictionary with PostgreSQL connection params. params: &'a StartupMessageParams, /// Unique connection ID. @@ -934,7 +928,7 @@ impl<'a, S> Client<'a, S> { /// Construct a new connection context. fn new( stream: PqStream>, - creds: auth::BackendType<'a, auth::ClientCredentials<'a>>, + creds: auth::BackendType<'a, auth::ClientCredentials>, params: &'a StartupMessageParams, session_id: uuid::Uuid, allow_self_signed_compute: bool, @@ -953,7 +947,7 @@ impl Client<'_, S> { /// Let the client authenticate and connect to the designated compute node. // Instrumentation logs endpoint name everywhere. Doesn't work for link // auth; strictly speaking we don't know endpoint name in its case. - #[tracing::instrument(name = "", fields(ep = self.creds.get_endpoint().unwrap_or("".to_owned())), skip_all)] + #[tracing::instrument(name = "", fields(ep = %self.creds.get_endpoint().unwrap_or_default()), skip_all)] async fn connect_to_db( self, session: cancellation::Session<'_>, @@ -962,7 +956,7 @@ impl Client<'_, S> { ) -> anyhow::Result<()> { let Self { mut stream, - mut creds, + creds, params, session_id, allow_self_signed_compute, @@ -978,6 +972,7 @@ impl Client<'_, S> { let mut latency_timer = LatencyTimer::new(mode.protocol_label()); + let user = creds.get_user().to_owned(); let auth_result = match creds .authenticate( &extra, @@ -990,7 +985,6 @@ impl Client<'_, S> { { Ok(auth_result) => auth_result, Err(e) => { - let user = creds.get_user(); let db = params.get("database"); let app = params.get("application_name"); let params_span = tracing::info_span!("", ?user, ?db, ?app); @@ -999,10 +993,7 @@ impl Client<'_, S> { } }; - let AuthSuccess { - reported_auth_ok, - value: mut node_info, - } = auth_result; + let (mut node_info, creds) = auth_result; node_info.allow_self_signed_compute = allow_self_signed_compute; @@ -1025,7 +1016,7 @@ impl Client<'_, S> { NUM_DB_CONNECTIONS_CLOSED_COUNTER.with_label_values(&[proto]).inc(); } - prepare_client_connection(&node, reported_auth_ok, session, &mut stream).await?; + prepare_client_connection(&node, session, &mut stream).await?; // Before proxy passing, forward to compute whatever data is left in the // PqStream input buffer. Normally there is none, but our serverless npm // driver in pipeline mode sends startup, password and first query diff --git a/proxy/src/proxy/tests.rs b/proxy/src/proxy/tests.rs index b97c0efce4..222661db4a 100644 --- a/proxy/src/proxy/tests.rs +++ b/proxy/src/proxy/tests.rs @@ -3,8 +3,7 @@ mod mitm; use super::*; -use crate::auth::backend::TestBackend; -use crate::auth::ClientCredentials; +use crate::auth::backend::{ComputeUserInfo, TestBackend}; use crate::config::CertResolver; use crate::console::{CachedNodeInfo, NodeInfo}; use crate::{auth, http, sasl, scram}; @@ -109,8 +108,9 @@ fn generate_tls_config<'a>( trait TestAuth: Sized { async fn authenticate( self, - _stream: &mut PqStream>, + stream: &mut PqStream>, ) -> anyhow::Result<()> { + stream.write_message_noflush(&Be::AuthenticationOk)?; Ok(()) } } @@ -168,7 +168,6 @@ async fn dummy_proxy( auth.authenticate(&mut stream).await?; stream - .write_message_noflush(&Be::AuthenticationOk)? .write_message_noflush(&Be::CLIENT_ENCODING)? .write_message(&Be::ReadyForQuery) .await?; @@ -486,7 +485,7 @@ fn helper_create_connect_info( ) -> ( CachedNodeInfo, console::ConsoleReqExtra<'static>, - auth::BackendType<'_, ClientCredentials<'static>>, + auth::BackendType<'_, ComputeUserInfo>, ) { let cache = helper_create_cached_node_info(); let extra = console::ConsoleReqExtra { diff --git a/proxy/src/scram.rs b/proxy/src/scram.rs index 63271309e1..49a7a13043 100644 --- a/proxy/src/scram.rs +++ b/proxy/src/scram.rs @@ -15,7 +15,7 @@ mod signature; #[cfg(any(test, doc))] mod password; -pub use exchange::Exchange; +pub use exchange::{exchange, Exchange}; pub use key::ScramKey; pub use secret::ServerSecret; diff --git a/proxy/src/scram/exchange.rs b/proxy/src/scram/exchange.rs index facaba3798..9af7db5201 100644 --- a/proxy/src/scram/exchange.rs +++ b/proxy/src/scram/exchange.rs @@ -1,5 +1,9 @@ //! Implementation of the SCRAM authentication algorithm. +use std::convert::Infallible; + +use postgres_protocol::authentication::sasl::ScramSha256; + use super::messages::{ ClientFinalMessage, ClientFirstMessage, OwnedServerFirstMessage, SCRAM_RAW_NONCE_LEN, }; @@ -29,22 +33,27 @@ impl std::str::FromStr for TlsServerEndPoint { } } +struct SaslSentInner { + cbind_flag: ChannelBinding, + client_first_message_bare: String, + server_first_message: OwnedServerFirstMessage, +} + +struct SaslInitial { + nonce: fn() -> [u8; SCRAM_RAW_NONCE_LEN], +} + enum ExchangeState { /// Waiting for [`ClientFirstMessage`]. - Initial, + Initial(SaslInitial), /// Waiting for [`ClientFinalMessage`]. - SaltSent { - cbind_flag: ChannelBinding, - client_first_message_bare: String, - server_first_message: OwnedServerFirstMessage, - }, + SaltSent(SaslSentInner), } /// Server's side of SCRAM auth algorithm. pub struct Exchange<'a> { state: ExchangeState, secret: &'a ServerSecret, - nonce: fn() -> [u8; SCRAM_RAW_NONCE_LEN], tls_server_end_point: config::TlsServerEndPoint, } @@ -55,90 +64,160 @@ impl<'a> Exchange<'a> { tls_server_end_point: config::TlsServerEndPoint, ) -> Self { Self { - state: ExchangeState::Initial, + state: ExchangeState::Initial(SaslInitial { nonce }), secret, - nonce, tls_server_end_point, } } } +pub fn exchange( + secret: &ServerSecret, + mut client: ScramSha256, + tls_server_end_point: config::TlsServerEndPoint, +) -> sasl::Result> { + use sasl::Step::*; + + let init = SaslInitial { + nonce: rand::random, + }; + + let client_first = std::str::from_utf8(client.message()) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?; + let sent = match init.transition(secret, &tls_server_end_point, client_first)? { + Continue(sent, server_first) => { + client.update(server_first.as_bytes())?; + sent + } + Success(x, _) => match x {}, + Failure(msg) => return Ok(sasl::Outcome::Failure(msg)), + }; + + let client_final = std::str::from_utf8(client.message()) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?; + let keys = match sent.transition(secret, &tls_server_end_point, client_final)? { + Success(keys, server_final) => { + client.finish(server_final.as_bytes())?; + keys + } + Continue(x, _) => match x {}, + Failure(msg) => return Ok(sasl::Outcome::Failure(msg)), + }; + + Ok(sasl::Outcome::Success(keys)) +} + +impl SaslInitial { + fn transition( + &self, + secret: &ServerSecret, + tls_server_end_point: &config::TlsServerEndPoint, + input: &str, + ) -> sasl::Result> { + let client_first_message = ClientFirstMessage::parse(input) + .ok_or(SaslError::BadClientMessage("invalid client-first-message"))?; + + // If the flag is set to "y" and the server supports channel + // binding, the server MUST fail authentication + if client_first_message.cbind_flag == ChannelBinding::NotSupportedServer + && tls_server_end_point.supported() + { + return Err(SaslError::ChannelBindingFailed("SCRAM-PLUS not used")); + } + + let server_first_message = client_first_message.build_server_first_message( + &(self.nonce)(), + &secret.salt_base64, + secret.iterations, + ); + let msg = server_first_message.as_str().to_owned(); + + let next = SaslSentInner { + cbind_flag: client_first_message.cbind_flag.and_then(str::parse)?, + client_first_message_bare: client_first_message.bare.to_owned(), + server_first_message, + }; + + Ok(sasl::Step::Continue(next, msg)) + } +} + +impl SaslSentInner { + fn transition( + &self, + secret: &ServerSecret, + tls_server_end_point: &config::TlsServerEndPoint, + input: &str, + ) -> sasl::Result> { + let Self { + cbind_flag, + client_first_message_bare, + server_first_message, + } = self; + + let client_final_message = ClientFinalMessage::parse(input) + .ok_or(SaslError::BadClientMessage("invalid client-final-message"))?; + + let channel_binding = cbind_flag.encode(|_| match tls_server_end_point { + config::TlsServerEndPoint::Sha256(x) => Ok(x), + config::TlsServerEndPoint::Undefined => Err(SaslError::MissingBinding), + })?; + + // This might've been caused by a MITM attack + if client_final_message.channel_binding != channel_binding { + return Err(SaslError::ChannelBindingFailed( + "insecure connection: secure channel data mismatch", + )); + } + + if client_final_message.nonce != server_first_message.nonce() { + return Err(SaslError::BadClientMessage("combined nonce doesn't match")); + } + + let signature_builder = SignatureBuilder { + client_first_message_bare, + server_first_message: server_first_message.as_str(), + client_final_message_without_proof: client_final_message.without_proof, + }; + + let client_key = signature_builder + .build(&secret.stored_key) + .derive_client_key(&client_final_message.proof); + + // Auth fails either if keys don't match or it's pre-determined to fail. + if client_key.sha256() != secret.stored_key || secret.doomed { + return Ok(sasl::Step::Failure("password doesn't match")); + } + + let msg = + client_final_message.build_server_final_message(signature_builder, &secret.server_key); + + Ok(sasl::Step::Success(client_key, msg)) + } +} + impl sasl::Mechanism for Exchange<'_> { type Output = super::ScramKey; fn exchange(mut self, input: &str) -> sasl::Result> { use {sasl::Step::*, ExchangeState::*}; match &self.state { - Initial => { - let client_first_message = ClientFirstMessage::parse(input) - .ok_or(SaslError::BadClientMessage("invalid client-first-message"))?; - - // If the flag is set to "y" and the server supports channel - // binding, the server MUST fail authentication - if client_first_message.cbind_flag == ChannelBinding::NotSupportedServer - && self.tls_server_end_point.supported() - { - return Err(SaslError::ChannelBindingFailed("SCRAM-PLUS not used")); + Initial(init) => { + match init.transition(self.secret, &self.tls_server_end_point, input)? { + Continue(sent, msg) => { + self.state = SaltSent(sent); + Ok(Continue(self, msg)) + } + Success(x, _) => match x {}, + Failure(msg) => Ok(Failure(msg)), } - - let server_first_message = client_first_message.build_server_first_message( - &(self.nonce)(), - &self.secret.salt_base64, - self.secret.iterations, - ); - let msg = server_first_message.as_str().to_owned(); - - self.state = SaltSent { - cbind_flag: client_first_message.cbind_flag.and_then(str::parse)?, - client_first_message_bare: client_first_message.bare.to_owned(), - server_first_message, - }; - - Ok(Continue(self, msg)) } - SaltSent { - cbind_flag, - client_first_message_bare, - server_first_message, - } => { - let client_final_message = ClientFinalMessage::parse(input) - .ok_or(SaslError::BadClientMessage("invalid client-final-message"))?; - - let channel_binding = cbind_flag.encode(|_| match &self.tls_server_end_point { - config::TlsServerEndPoint::Sha256(x) => Ok(x), - config::TlsServerEndPoint::Undefined => Err(SaslError::MissingBinding), - })?; - - // This might've been caused by a MITM attack - if client_final_message.channel_binding != channel_binding { - return Err(SaslError::ChannelBindingFailed( - "insecure connection: secure channel data mismatch", - )); + SaltSent(sent) => { + match sent.transition(self.secret, &self.tls_server_end_point, input)? { + Success(keys, msg) => Ok(Success(keys, msg)), + Continue(x, _) => match x {}, + Failure(msg) => Ok(Failure(msg)), } - - if client_final_message.nonce != server_first_message.nonce() { - return Err(SaslError::BadClientMessage("combined nonce doesn't match")); - } - - let signature_builder = SignatureBuilder { - client_first_message_bare, - server_first_message: server_first_message.as_str(), - client_final_message_without_proof: client_final_message.without_proof, - }; - - let client_key = signature_builder - .build(&self.secret.stored_key) - .derive_client_key(&client_final_message.proof); - - // Auth fails either if keys don't match or it's pre-determined to fail. - if client_key.sha256() != self.secret.stored_key || self.secret.doomed { - return Ok(Failure("password doesn't match")); - } - - let msg = client_final_message - .build_server_final_message(signature_builder, &self.secret.server_key); - - Ok(Success(client_key, msg)) } } } diff --git a/proxy/src/serverless.rs b/proxy/src/serverless.rs index 45f8132393..5a992d6461 100644 --- a/proxy/src/serverless.rs +++ b/proxy/src/serverless.rs @@ -23,7 +23,7 @@ use hyper::{ Body, Method, Request, Response, }; -use std::net::SocketAddr; +use std::net::IpAddr; use std::task::Poll; use std::{future::ready, sync::Arc}; use tls_listener::TlsListener; @@ -103,7 +103,13 @@ pub async fn task_main( let session_id = uuid::Uuid::new_v4(); request_handler( - req, config, conn_pool, cancel_map, session_id, sni_name, peer_addr, + req, + config, + conn_pool, + cancel_map, + session_id, + sni_name, + peer_addr.ip(), ) .instrument(info_span!( "serverless", @@ -171,7 +177,7 @@ async fn request_handler( cancel_map: Arc, session_id: uuid::Uuid, sni_hostname: Option, - peer_addr: SocketAddr, + peer_addr: IpAddr, ) -> Result, ApiError> { let host = request .headers() diff --git a/proxy/src/serverless/conn_pool.rs b/proxy/src/serverless/conn_pool.rs index ca7a9ad0a0..b9d1a9692d 100644 --- a/proxy/src/serverless/conn_pool.rs +++ b/proxy/src/serverless/conn_pool.rs @@ -1,4 +1,4 @@ -use anyhow::Context; +use anyhow::{anyhow, Context}; use async_trait::async_trait; use dashmap::DashMap; use futures::future::poll_fn; @@ -9,7 +9,7 @@ use pbkdf2::{ }; use pq_proto::StartupMessageParams; use smol_str::SmolStr; -use std::{collections::HashMap, net::SocketAddr, sync::Arc}; +use std::{collections::HashMap, net::IpAddr, sync::Arc}; use std::{ fmt, task::{ready, Poll}, @@ -22,7 +22,7 @@ use tokio::time; use tokio_postgres::{AsyncMessage, ReadyForQueryStatus}; use crate::{ - auth::{self, check_peer_addr_is_in_list}, + auth::{self, backend::ComputeUserInfo, check_peer_addr_is_in_list}, console, proxy::{ neon_options, LatencyTimer, NUM_DB_CONNECTIONS_CLOSED_COUNTER, @@ -146,7 +146,7 @@ impl GlobalConnPool { conn_info: &ConnInfo, force_new: bool, session_id: uuid::Uuid, - peer_addr: SocketAddr, + peer_addr: IpAddr, ) -> anyhow::Result { let mut client: Option = None; let mut latency_timer = LatencyTimer::new("http"); @@ -406,7 +406,7 @@ async fn connect_to_compute( conn_id: uuid::Uuid, session_id: uuid::Uuid, latency_timer: LatencyTimer, - peer_addr: SocketAddr, + peer_addr: IpAddr, ) -> anyhow::Result { let tls = config.tls_config.as_ref(); let common_names = tls.and_then(|tls| tls.common_names.clone()); @@ -423,6 +423,9 @@ async fn connect_to_compute( common_names, peer_addr, )?; + + let creds = + ComputeUserInfo::try_from(creds).map_err(|_| anyhow!("missing endpoint identifier"))?; let backend = config.auth_backend.as_ref().map(|_| creds); let console_options = neon_options(¶ms); @@ -435,7 +438,7 @@ async fn connect_to_compute( // TODO(anna): this is a bit hacky way, consider using console notification listener. if !config.disable_ip_check_for_http { let allowed_ips = backend.get_allowed_ips(&extra).await?; - if !check_peer_addr_is_in_list(&peer_addr.ip(), &allowed_ips) { + if !check_peer_addr_is_in_list(&peer_addr, &allowed_ips) { return Err(auth::AuthError::ip_address_not_allowed().into()); } } diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index 6c337a837c..6e80260193 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -1,4 +1,4 @@ -use std::net::SocketAddr; +use std::net::IpAddr; use std::sync::Arc; use anyhow::bail; @@ -202,7 +202,7 @@ pub async fn handle( sni_hostname: Option, conn_pool: Arc, session_id: uuid::Uuid, - peer_addr: SocketAddr, + peer_addr: IpAddr, config: &'static HttpConfig, ) -> Result, ApiError> { let result = tokio::time::timeout( @@ -301,7 +301,7 @@ async fn handle_inner( sni_hostname: Option, conn_pool: Arc, session_id: uuid::Uuid, - peer_addr: SocketAddr, + peer_addr: IpAddr, ) -> anyhow::Result> { NUM_CONNECTIONS_ACCEPTED_COUNTER .with_label_values(&["http"]) diff --git a/proxy/src/serverless/websocket.rs b/proxy/src/serverless/websocket.rs index 8fb9a3dee4..199b03550d 100644 --- a/proxy/src/serverless/websocket.rs +++ b/proxy/src/serverless/websocket.rs @@ -11,7 +11,7 @@ use hyper_tungstenite::{tungstenite::Message, HyperWebsocket, WebSocketStream}; use pin_project_lite::pin_project; use std::{ - net::SocketAddr, + net::IpAddr, pin::Pin, task::{ready, Context, Poll}, }; @@ -133,7 +133,7 @@ pub async fn serve_websocket( cancel_map: &CancelMap, session_id: uuid::Uuid, hostname: Option, - peer_addr: SocketAddr, + peer_addr: IpAddr, ) -> anyhow::Result<()> { let websocket = websocket.await?; handle_client(