diff --git a/proxy/src/auth.rs b/proxy/src/auth.rs index 64ef108e11..0707c1331f 100644 --- a/proxy/src/auth.rs +++ b/proxy/src/auth.rs @@ -4,7 +4,7 @@ pub mod backend; pub use backend::BackendType; mod credentials; -pub use credentials::{check_peer_addr_is_in_list, ClientCredentials}; +pub use credentials::{check_peer_addr_is_in_list, endpoint_sni, ComputeUserInfoMaybeEndpoint}; mod password_hack; pub use password_hack::parse_endpoint_param; diff --git a/proxy/src/auth/backend.rs b/proxy/src/auth/backend.rs index a4c5512521..120ed46992 100644 --- a/proxy/src/auth/backend.rs +++ b/proxy/src/auth/backend.rs @@ -14,14 +14,15 @@ use crate::console::AuthSecret; use crate::context::RequestMonitoring; use crate::proxy::connect_compute::handle_try_wake; use crate::proxy::retry::retry_after; +use crate::proxy::NeonOptions; use crate::scram; use crate::stream::Stream; use crate::{ - auth::{self, ClientCredentials}, + auth::{self, ComputeUserInfoMaybeEndpoint}, config::AuthenticationConfig, console::{ self, - provider::{CachedAllowedIps, CachedNodeInfo, ConsoleReqExtra}, + provider::{CachedAllowedIps, CachedNodeInfo}, Api, }, stream, url, @@ -38,7 +39,7 @@ use tracing::{error, info, warn}; /// * When `T` is `()`, it's just a regular auth backend selector /// which we use in [`crate::config::ProxyConfig`]. /// -/// * However, when we substitute `T` with [`ClientCredentials`], +/// * However, when we substitute `T` with [`ComputeUserInfoMaybeEndpoint`], /// this helps us provide the credentials only to those auth /// backends which require them for the authentication process. pub enum BackendType<'a, T> { @@ -127,14 +128,23 @@ pub struct ComputeCredentials { pub keys: T, } +#[derive(Debug, Clone)] pub struct ComputeUserInfoNoEndpoint { pub user: SmolStr, - pub cache_key: SmolStr, + pub options: NeonOptions, } +#[derive(Debug, Clone)] pub struct ComputeUserInfo { pub endpoint: SmolStr, - pub inner: ComputeUserInfoNoEndpoint, + pub user: SmolStr, + pub options: NeonOptions, +} + +impl ComputeUserInfo { + pub fn endpoint_cache_key(&self) -> SmolStr { + self.options.get_cache_key(&self.endpoint) + } } pub enum ComputeCredentialKeys { @@ -143,18 +153,21 @@ pub enum ComputeCredentialKeys { AuthKeys(AuthKeys), } -impl TryFrom for ComputeUserInfo { +impl TryFrom for ComputeUserInfo { // user name type Error = ComputeUserInfoNoEndpoint; - fn try_from(creds: ClientCredentials) -> Result { - let inner = ComputeUserInfoNoEndpoint { - user: creds.user, - cache_key: creds.cache_key, - }; - match creds.project { - None => Err(inner), - Some(endpoint) => Ok(ComputeUserInfo { endpoint, inner }), + fn try_from(user_info: ComputeUserInfoMaybeEndpoint) -> Result { + match user_info.project { + None => Err(ComputeUserInfoNoEndpoint { + user: user_info.user, + options: user_info.options, + }), + Some(endpoint) => Ok(ComputeUserInfo { + endpoint, + user: user_info.user, + options: user_info.options, + }), } } } @@ -166,7 +179,7 @@ impl TryFrom for ComputeUserInfo { async fn auth_quirks( ctx: &mut RequestMonitoring, api: &impl console::Api, - creds: ClientCredentials, + user_info: ComputeUserInfoMaybeEndpoint, client: &mut stream::PqStream>, allow_cleartext: bool, config: &'static AuthenticationConfig, @@ -174,7 +187,7 @@ async fn auth_quirks( // If there's no project so far, that entails that client doesn't // support SNI or other means of passing the endpoint (project) name. // We now expect to see a very specific payload in the place of password. - let (info, unauthenticated_password) = match creds.try_into() { + let (info, unauthenticated_password) = match user_info.try_into() { Err(info) => { let res = hacks::password_hack_no_authentication(info, client, &mut ctx.latency_timer) .await?; @@ -199,7 +212,7 @@ async fn auth_quirks( // This mocked secret will never lead to successful authentication. info!("authentication info not found, mocking it"); Cached::new_uncached(AuthSecret::Scram(scram::ServerSecret::mock( - &info.inner.user, + &info.user, rand::random(), ))) }); @@ -240,7 +253,7 @@ async fn authenticate_with_secret( crate::sasl::Outcome::Success(key) => key, crate::sasl::Outcome::Failure(reason) => { info!("auth backend failed with an error: {reason}"); - return Err(auth::AuthError::auth_failed(&*info.inner.user)); + return Err(auth::AuthError::auth_failed(&*info.user)); } }; @@ -267,19 +280,17 @@ async fn authenticate_with_secret( async fn auth_and_wake_compute( ctx: &mut RequestMonitoring, api: &impl console::Api, - extra: &ConsoleReqExtra, - creds: ClientCredentials, + user_info: ComputeUserInfoMaybeEndpoint, client: &mut stream::PqStream>, allow_cleartext: bool, config: &'static AuthenticationConfig, ) -> auth::Result<(CachedNodeInfo, ComputeUserInfo)> { - let compute_credentials = auth_quirks(ctx, api, creds, client, allow_cleartext, config).await?; + let compute_credentials = + auth_quirks(ctx, api, user_info, client, allow_cleartext, config).await?; let mut num_retries = 0; let mut node = loop { - let wake_res = api - .wake_compute(ctx, extra, &compute_credentials.info) - .await; + let wake_res = api.wake_compute(ctx, &compute_credentials.info).await; match handle_try_wake(wake_res, num_retries) { Err(e) => { error!(error = ?e, num_retries, retriable = false, "couldn't wake compute node"); @@ -307,15 +318,15 @@ async fn auth_and_wake_compute( Ok((node, compute_credentials.info)) } -impl<'a> BackendType<'a, ClientCredentials> { +impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> { /// Get compute endpoint name from the credentials. pub fn get_endpoint(&self) -> Option { use BackendType::*; match self { - Console(_, creds) => creds.project.clone(), + Console(_, user_info) => user_info.project.clone(), #[cfg(feature = "testing")] - Postgres(_, creds) => creds.project.clone(), + Postgres(_, user_info) => user_info.project.clone(), Link(_) => Some("link".into()), #[cfg(test)] Test(_) => Some("test".into()), @@ -327,9 +338,9 @@ impl<'a> BackendType<'a, ClientCredentials> { use BackendType::*; match self { - Console(_, creds) => &creds.user, + Console(_, user_info) => &user_info.user, #[cfg(feature = "testing")] - Postgres(_, creds) => &creds.user, + Postgres(_, user_info) => &user_info.user, Link(_) => "link", #[cfg(test)] Test(_) => "test", @@ -341,7 +352,6 @@ impl<'a> BackendType<'a, ClientCredentials> { pub async fn authenticate( self, ctx: &mut RequestMonitoring, - extra: &ConsoleReqExtra, client: &mut stream::PqStream>, allow_cleartext: bool, config: &'static AuthenticationConfig, @@ -349,43 +359,29 @@ impl<'a> BackendType<'a, ClientCredentials> { use BackendType::*; let res = match self { - Console(api, creds) => { + Console(api, user_info) => { info!( - user = &*creds.user, - project = creds.project(), + user = &*user_info.user, + project = user_info.project(), "performing authentication using the console" ); - let (cache_info, user_info) = auth_and_wake_compute( - ctx, - &*api, - extra, - creds, - client, - allow_cleartext, - config, - ) - .await?; + let (cache_info, user_info) = + auth_and_wake_compute(ctx, &*api, user_info, client, allow_cleartext, config) + .await?; (cache_info, BackendType::Console(api, user_info)) } #[cfg(feature = "testing")] - Postgres(api, creds) => { + Postgres(api, user_info) => { info!( - user = &*creds.user, - project = creds.project(), + user = &*user_info.user, + project = user_info.project(), "performing authentication using a local postgres instance" ); - let (cache_info, user_info) = auth_and_wake_compute( - ctx, - &*api, - extra, - creds, - client, - allow_cleartext, - config, - ) - .await?; + let (cache_info, user_info) = + auth_and_wake_compute(ctx, &*api, user_info, client, allow_cleartext, config) + .await?; (cache_info, BackendType::Postgres(api, user_info)) } // NOTE: this auth backend doesn't use client credentials. @@ -417,9 +413,9 @@ impl BackendType<'_, ComputeUserInfo> { ) -> Result { use BackendType::*; match self { - Console(api, creds) => api.get_allowed_ips(ctx, creds).await, + Console(api, user_info) => api.get_allowed_ips(ctx, user_info).await, #[cfg(feature = "testing")] - Postgres(api, creds) => api.get_allowed_ips(ctx, creds).await, + Postgres(api, user_info) => api.get_allowed_ips(ctx, user_info).await, Link(_) => Ok(Cached::new_uncached(Arc::new(vec![]))), #[cfg(test)] Test(x) => Ok(Cached::new_uncached(Arc::new(x.get_allowed_ips()?))), @@ -431,14 +427,13 @@ impl BackendType<'_, ComputeUserInfo> { pub async fn wake_compute( &self, ctx: &mut RequestMonitoring, - extra: &ConsoleReqExtra, ) -> Result, console::errors::WakeComputeError> { use BackendType::*; match self { - Console(api, creds) => api.wake_compute(ctx, extra, creds).map_ok(Some).await, + Console(api, user_info) => api.wake_compute(ctx, user_info).map_ok(Some).await, #[cfg(feature = "testing")] - Postgres(api, creds) => api.wake_compute(ctx, extra, creds).map_ok(Some).await, + Postgres(api, user_info) => api.wake_compute(ctx, user_info).map_ok(Some).await, Link(_) => Ok(None), #[cfg(test)] Test(x) => x.wake_compute().map(Some), diff --git a/proxy/src/auth/backend/classic.rs b/proxy/src/auth/backend/classic.rs index 5c394ec649..358b335b88 100644 --- a/proxy/src/auth/backend/classic.rs +++ b/proxy/src/auth/backend/classic.rs @@ -54,7 +54,7 @@ pub(super) async fn authenticate( sasl::Outcome::Success(key) => key, sasl::Outcome::Failure(reason) => { info!("auth backend failed with an error: {reason}"); - return Err(auth::AuthError::auth_failed(&*creds.inner.user)); + return Err(auth::AuthError::auth_failed(&*creds.user)); } }; diff --git a/proxy/src/auth/backend/hacks.rs b/proxy/src/auth/backend/hacks.rs index 5dde514bca..b6c1a92d3c 100644 --- a/proxy/src/auth/backend/hacks.rs +++ b/proxy/src/auth/backend/hacks.rs @@ -36,7 +36,7 @@ pub async fn authenticate_cleartext( sasl::Outcome::Success(key) => key, sasl::Outcome::Failure(reason) => { info!("auth backend failed with an error: {reason}"); - return Err(auth::AuthError::auth_failed(&*info.inner.user)); + return Err(auth::AuthError::auth_failed(&*info.user)); } }; @@ -67,7 +67,8 @@ pub async fn password_hack_no_authentication( // Report tentative success; compute node will check the password anyway. Ok(ComputeCredentials { info: ComputeUserInfo { - inner: info, + user: info.user, + options: info.options, endpoint: payload.endpoint, }, keys: payload.password, diff --git a/proxy/src/auth/credentials.rs b/proxy/src/auth/credentials.rs index d282e894c8..ada7f3614c 100644 --- a/proxy/src/auth/credentials.rs +++ b/proxy/src/auth/credentials.rs @@ -2,7 +2,7 @@ use crate::{ auth::password_hack::parse_endpoint_param, context::RequestMonitoring, error::UserFacingError, - metrics::NUM_CONNECTION_ACCEPTED_BY_SNI, proxy::neon_options_str, + metrics::NUM_CONNECTION_ACCEPTED_BY_SNI, proxy::NeonOptions, }; use itertools::Itertools; use pq_proto::StartupMessageParams; @@ -12,7 +12,7 @@ use thiserror::Error; use tracing::{info, warn}; #[derive(Debug, Error, PartialEq, Eq, Clone)] -pub enum ClientCredsParseError { +pub enum ComputeUserInfoParseError { #[error("Parameter '{0}' is missing in startup packet.")] MissingKey(&'static str), @@ -33,34 +33,49 @@ pub enum ClientCredsParseError { MalformedProjectName(SmolStr), } -impl UserFacingError for ClientCredsParseError {} +impl UserFacingError for ComputeUserInfoParseError {} /// Various client credentials which we use for authentication. /// Note that we don't store any kind of client key or password here. #[derive(Debug, Clone, PartialEq, Eq)] -pub struct ClientCredentials { +pub struct ComputeUserInfoMaybeEndpoint { pub user: SmolStr, // TODO: this is a severe misnomer! We should think of a new name ASAP. pub project: Option, - pub cache_key: SmolStr, + pub options: NeonOptions, } -impl ClientCredentials { +impl ComputeUserInfoMaybeEndpoint { #[inline] pub fn project(&self) -> Option<&str> { self.project.as_deref() } } -impl ClientCredentials { +pub fn endpoint_sni<'a>( + sni: &'a str, + common_names: &HashSet, +) -> Result<&'a str, ComputeUserInfoParseError> { + let Some((subdomain, common_name)) = sni.split_once('.') else { + return Err(ComputeUserInfoParseError::UnknownCommonName { cn: sni.into() }); + }; + if !common_names.contains(common_name) { + return Err(ComputeUserInfoParseError::UnknownCommonName { + cn: common_name.into(), + }); + } + Ok(subdomain) +} + +impl ComputeUserInfoMaybeEndpoint { pub fn parse( ctx: &mut RequestMonitoring, params: &StartupMessageParams, sni: Option<&str>, - common_names: Option>, - ) -> Result { - use ClientCredsParseError::*; + common_names: Option<&HashSet>, + ) -> Result { + use ComputeUserInfoParseError::*; // Some parameters are stored in the startup message. let get_param = |key| params.get(key).ok_or(MissingKey(key)); @@ -87,21 +102,7 @@ impl ClientCredentials { let project_from_domain = if let Some(sni_str) = sni { if let Some(cn) = common_names { - let common_name_from_sni = sni_str.split_once('.').map(|(_, domain)| domain); - - let project = common_name_from_sni - .and_then(|domain| { - if cn.contains(domain) { - subdomain_from_sni(sni_str, domain) - } else { - None - } - }) - .ok_or_else(|| UnknownCommonName { - cn: common_name_from_sni.unwrap_or("").into(), - })?; - - Some(project) + Some(SmolStr::from(endpoint_sni(sni_str, cn)?)) } else { None } @@ -140,17 +141,12 @@ impl ClientCredentials { info!("Connection with password hack"); } - let cache_key = format!( - "{}{}", - project.as_deref().unwrap_or(""), - neon_options_str(params) - ) - .into(); + let options = NeonOptions::parse_params(params); Ok(Self { user, project, - cache_key, + options, }) } } @@ -207,25 +203,19 @@ fn project_name_valid(name: &str) -> bool { name.chars().all(|c| c.is_alphanumeric() || c == '-') } -fn subdomain_from_sni(sni: &str, common_name: &str) -> Option { - sni.strip_suffix(common_name)? - .strip_suffix('.') - .map(SmolStr::from) -} - #[cfg(test)] mod tests { use super::*; - use ClientCredsParseError::*; + use ComputeUserInfoParseError::*; #[test] fn parse_bare_minimum() -> anyhow::Result<()> { // According to postgresql, only `user` should be required. let options = StartupMessageParams::new([("user", "john_doe")]); let mut ctx = RequestMonitoring::test(); - let creds = ClientCredentials::parse(&mut ctx, &options, None, None)?; - assert_eq!(creds.user, "john_doe"); - assert_eq!(creds.project, None); + let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, None, None)?; + assert_eq!(user_info.user, "john_doe"); + assert_eq!(user_info.project, None); Ok(()) } @@ -238,9 +228,9 @@ mod tests { ("foo", "bar"), // should be ignored ]); let mut ctx = RequestMonitoring::test(); - let creds = ClientCredentials::parse(&mut ctx, &options, None, None)?; - assert_eq!(creds.user, "john_doe"); - assert_eq!(creds.project, None); + let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, None, None)?; + assert_eq!(user_info.user, "john_doe"); + assert_eq!(user_info.project, None); Ok(()) } @@ -253,10 +243,11 @@ mod tests { let common_names = Some(["localhost".into()].into()); let mut ctx = RequestMonitoring::test(); - let creds = ClientCredentials::parse(&mut ctx, &options, sni, common_names)?; - assert_eq!(creds.user, "john_doe"); - assert_eq!(creds.project.as_deref(), Some("foo")); - assert_eq!(creds.cache_key, "foo"); + let user_info = + ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())?; + assert_eq!(user_info.user, "john_doe"); + assert_eq!(user_info.project.as_deref(), Some("foo")); + assert_eq!(user_info.options.get_cache_key("foo"), "foo"); Ok(()) } @@ -269,9 +260,9 @@ mod tests { ]); let mut ctx = RequestMonitoring::test(); - let creds = ClientCredentials::parse(&mut ctx, &options, None, None)?; - assert_eq!(creds.user, "john_doe"); - assert_eq!(creds.project.as_deref(), Some("bar")); + let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, None, None)?; + assert_eq!(user_info.user, "john_doe"); + assert_eq!(user_info.project.as_deref(), Some("bar")); Ok(()) } @@ -284,9 +275,9 @@ mod tests { ]); let mut ctx = RequestMonitoring::test(); - let creds = ClientCredentials::parse(&mut ctx, &options, None, None)?; - assert_eq!(creds.user, "john_doe"); - assert_eq!(creds.project.as_deref(), Some("bar")); + let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, None, None)?; + assert_eq!(user_info.user, "john_doe"); + assert_eq!(user_info.project.as_deref(), Some("bar")); Ok(()) } @@ -302,9 +293,9 @@ mod tests { ]); let mut ctx = RequestMonitoring::test(); - let creds = ClientCredentials::parse(&mut ctx, &options, None, None)?; - assert_eq!(creds.user, "john_doe"); - assert!(creds.project.is_none()); + let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, None, None)?; + assert_eq!(user_info.user, "john_doe"); + assert!(user_info.project.is_none()); Ok(()) } @@ -317,9 +308,9 @@ mod tests { ]); let mut ctx = RequestMonitoring::test(); - let creds = ClientCredentials::parse(&mut ctx, &options, None, None)?; - assert_eq!(creds.user, "john_doe"); - assert!(creds.project.is_none()); + let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, None, None)?; + assert_eq!(user_info.user, "john_doe"); + assert!(user_info.project.is_none()); Ok(()) } @@ -332,9 +323,10 @@ mod tests { let common_names = Some(["localhost".into()].into()); let mut ctx = RequestMonitoring::test(); - let creds = ClientCredentials::parse(&mut ctx, &options, sni, common_names)?; - assert_eq!(creds.user, "john_doe"); - assert_eq!(creds.project.as_deref(), Some("baz")); + let user_info = + ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())?; + assert_eq!(user_info.user, "john_doe"); + assert_eq!(user_info.project.as_deref(), Some("baz")); Ok(()) } @@ -346,14 +338,16 @@ mod tests { let common_names = Some(["a.com".into(), "b.com".into()].into()); let sni = Some("p1.a.com"); let mut ctx = RequestMonitoring::test(); - let creds = ClientCredentials::parse(&mut ctx, &options, sni, common_names)?; - assert_eq!(creds.project.as_deref(), Some("p1")); + let user_info = + ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())?; + assert_eq!(user_info.project.as_deref(), Some("p1")); let common_names = Some(["a.com".into(), "b.com".into()].into()); let sni = Some("p1.b.com"); let mut ctx = RequestMonitoring::test(); - let creds = ClientCredentials::parse(&mut ctx, &options, sni, common_names)?; - assert_eq!(creds.project.as_deref(), Some("p1")); + let user_info = + ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())?; + assert_eq!(user_info.project.as_deref(), Some("p1")); Ok(()) } @@ -367,8 +361,9 @@ mod tests { let common_names = Some(["localhost".into()].into()); let mut ctx = RequestMonitoring::test(); - let err = ClientCredentials::parse(&mut ctx, &options, sni, common_names) - .expect_err("should fail"); + let err = + ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref()) + .expect_err("should fail"); match err { InconsistentProjectNames { domain, option } => { assert_eq!(option, "first"); @@ -386,8 +381,9 @@ mod tests { let common_names = Some(["example.com".into()].into()); let mut ctx = RequestMonitoring::test(); - let err = ClientCredentials::parse(&mut ctx, &options, sni, common_names) - .expect_err("should fail"); + let err = + ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref()) + .expect_err("should fail"); match err { UnknownCommonName { cn } => { assert_eq!(cn, "localhost"); @@ -406,9 +402,13 @@ mod tests { let sni = Some("project.localhost"); let common_names = Some(["localhost".into()].into()); let mut ctx = RequestMonitoring::test(); - let creds = ClientCredentials::parse(&mut ctx, &options, sni, common_names)?; - assert_eq!(creds.project.as_deref(), Some("project")); - assert_eq!(creds.cache_key, "projectendpoint_type:read_write lsn:0/2"); + let user_info = + ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())?; + assert_eq!(user_info.project.as_deref(), Some("project")); + assert_eq!( + user_info.options.get_cache_key("project"), + "project endpoint_type:read_write lsn:0/2" + ); Ok(()) } diff --git a/proxy/src/config.rs b/proxy/src/config.rs index 043d8d0791..2c46458a49 100644 --- a/proxy/src/config.rs +++ b/proxy/src/config.rs @@ -32,7 +32,7 @@ pub struct MetricCollectionConfig { pub struct TlsConfig { pub config: Arc, - pub common_names: Option>, + pub common_names: HashSet, pub cert_resolver: Arc, } @@ -97,7 +97,7 @@ pub fn configure_tls( Ok(TlsConfig { config, - common_names: Some(common_names), + common_names, cert_resolver, }) } diff --git a/proxy/src/console.rs b/proxy/src/console.rs index 07bc807950..fd3c46b946 100644 --- a/proxy/src/console.rs +++ b/proxy/src/console.rs @@ -6,7 +6,7 @@ pub mod messages; /// Wrappers for console APIs and their mocks. pub mod provider; -pub use provider::{errors, Api, AuthSecret, CachedNodeInfo, ConsoleReqExtra, NodeInfo}; +pub use provider::{errors, Api, AuthSecret, CachedNodeInfo, NodeInfo}; /// Various cache-related types. pub mod caches { diff --git a/proxy/src/console/provider.rs b/proxy/src/console/provider.rs index 9497d36bc7..84c43183cc 100644 --- a/proxy/src/console/provider.rs +++ b/proxy/src/console/provider.rs @@ -197,22 +197,6 @@ pub mod errors { } } -/// Extra query params we'd like to pass to the console. -pub struct ConsoleReqExtra { - pub options: Vec<(String, String)>, -} - -impl ConsoleReqExtra { - // https://swagger.io/docs/specification/serialization/ DeepObject format - // paramName[prop1]=value1¶mName[prop2]=value2&.... - pub fn options_as_deep_object(&self) -> Vec<(String, String)> { - self.options - .iter() - .map(|(k, v)| (format!("options[{}]", k), v.to_string())) - .collect() - } -} - /// Auth secret which is managed by the cloud. #[derive(Clone, Eq, PartialEq, Debug)] pub enum AuthSecret { @@ -249,7 +233,7 @@ pub struct NodeInfo { pub allow_self_signed_compute: bool, } -pub type NodeInfoCache = TimedLru, NodeInfo>; +pub type NodeInfoCache = TimedLru; pub type CachedNodeInfo = Cached<&'static NodeInfoCache>; pub type CachedRoleSecret = Cached<&'static ProjectInfoCacheImpl, AuthSecret>; pub type CachedAllowedIps = Cached<&'static ProjectInfoCacheImpl, Arc>>; @@ -277,7 +261,6 @@ pub trait Api { async fn wake_compute( &self, ctx: &mut RequestMonitoring, - extra: &ConsoleReqExtra, creds: &ComputeUserInfo, ) -> Result; } @@ -310,7 +293,7 @@ impl ApiCaches { /// Various caches for [`console`](super). pub struct ApiLocks { name: &'static str, - node_locks: DashMap, Arc>, + node_locks: DashMap>, permits: usize, timeout: Duration, registered: prometheus::IntCounter, @@ -378,7 +361,7 @@ impl ApiLocks { pub async fn get_wake_compute_permit( &self, - key: &Arc, + key: &SmolStr, ) -> Result { if self.permits == 0 { return Ok(WakeComputePermit { permit: None }); diff --git a/proxy/src/console/provider/mock.rs b/proxy/src/console/provider/mock.rs index 8f50865288..cc35a06708 100644 --- a/proxy/src/console/provider/mock.rs +++ b/proxy/src/console/provider/mock.rs @@ -2,7 +2,7 @@ use super::{ errors::{ApiError, GetAuthInfoError, WakeComputeError}, - AuthInfo, AuthSecret, CachedNodeInfo, ConsoleReqExtra, NodeInfo, + AuthInfo, AuthSecret, CachedNodeInfo, NodeInfo, }; use crate::cache::Cached; use crate::console::provider::{CachedAllowedIps, CachedRoleSecret}; @@ -50,7 +50,7 @@ impl Api { async fn do_get_auth_info( &self, - creds: &ComputeUserInfo, + user_info: &ComputeUserInfo, ) -> Result { let (secret, allowed_ips) = async { // Perhaps we could persist this connection, but then we'd have to @@ -63,7 +63,7 @@ impl Api { let secret = match get_execute_postgres_query( &client, "select rolpassword from pg_catalog.pg_authid where rolname = $1", - &[&&*creds.inner.user], + &[&&*user_info.user], "rolpassword", ) .await? @@ -74,14 +74,14 @@ impl Api { secret.or_else(|| parse_md5(&entry).map(AuthSecret::Md5)) } None => { - warn!("user '{}' does not exist", creds.inner.user); + warn!("user '{}' does not exist", user_info.user); None } }; let allowed_ips = match get_execute_postgres_query( &client, "select allowed_ips from neon_control_plane.endpoints where endpoint_id = $1", - &[&creds.endpoint.as_str()], + &[&user_info.endpoint.as_str()], "allowed_ips", ) .await? @@ -149,10 +149,10 @@ impl super::Api for Api { async fn get_role_secret( &self, _ctx: &mut RequestMonitoring, - creds: &ComputeUserInfo, + user_info: &ComputeUserInfo, ) -> Result, GetAuthInfoError> { Ok(self - .do_get_auth_info(creds) + .do_get_auth_info(user_info) .await? .secret .map(CachedRoleSecret::new_uncached)) @@ -161,10 +161,10 @@ impl super::Api for Api { async fn get_allowed_ips( &self, _ctx: &mut RequestMonitoring, - creds: &ComputeUserInfo, + user_info: &ComputeUserInfo, ) -> Result { Ok(Cached::new_uncached(Arc::new( - self.do_get_auth_info(creds).await?.allowed_ips, + self.do_get_auth_info(user_info).await?.allowed_ips, ))) } @@ -172,8 +172,7 @@ impl super::Api for Api { async fn wake_compute( &self, _ctx: &mut RequestMonitoring, - _extra: &ConsoleReqExtra, - _creds: &ComputeUserInfo, + _user_info: &ComputeUserInfo, ) -> Result { self.do_wake_compute() .map_ok(CachedNodeInfo::new_uncached) diff --git a/proxy/src/console/provider/neon.rs b/proxy/src/console/provider/neon.rs index e0bb7952b5..b61e7d2301 100644 --- a/proxy/src/console/provider/neon.rs +++ b/proxy/src/console/provider/neon.rs @@ -4,7 +4,7 @@ use super::{ super::messages::{ConsoleError, GetRoleSecret, WakeCompute}, errors::{ApiError, GetAuthInfoError, WakeComputeError}, ApiCaches, ApiLocks, AuthInfo, AuthSecret, CachedAllowedIps, CachedNodeInfo, CachedRoleSecret, - ConsoleReqExtra, NodeInfo, + NodeInfo, }; use crate::{auth::backend::ComputeUserInfo, compute, http, scram}; use crate::{ @@ -55,7 +55,7 @@ impl Api { async fn do_get_auth_info( &self, ctx: &mut RequestMonitoring, - creds: &ComputeUserInfo, + user_info: &ComputeUserInfo, ) -> Result { let request_id = uuid::Uuid::new_v4().to_string(); let application_name = ctx.console_application_name(); @@ -68,8 +68,8 @@ impl Api { .query(&[("session_id", ctx.session_id)]) .query(&[ ("application_name", application_name.as_str()), - ("project", creds.endpoint.as_str()), - ("role", creds.inner.user.as_str()), + ("project", user_info.endpoint.as_str()), + ("role", user_info.user.as_str()), ]) .build()?; @@ -110,8 +110,7 @@ impl Api { async fn do_wake_compute( &self, ctx: &mut RequestMonitoring, - extra: &ConsoleReqExtra, - creds: &ComputeUserInfo, + user_info: &ComputeUserInfo, ) -> Result { let request_id = uuid::Uuid::new_v4().to_string(); let application_name = ctx.console_application_name(); @@ -124,14 +123,14 @@ impl Api { .query(&[("session_id", ctx.session_id)]) .query(&[ ("application_name", application_name.as_str()), - ("project", creds.endpoint.as_str()), + ("project", user_info.endpoint.as_str()), ]); - request_builder = if extra.options.is_empty() { - request_builder - } else { - request_builder.query(&extra.options_as_deep_object()) - }; + let options = user_info.options.to_deep_object(); + if !options.is_empty() { + request_builder = request_builder.query(&options); + } + let request = request_builder.build()?; info!(url = request.url().as_str(), "sending http request"); @@ -172,14 +171,14 @@ impl super::Api for Api { async fn get_role_secret( &self, ctx: &mut RequestMonitoring, - creds: &ComputeUserInfo, + user_info: &ComputeUserInfo, ) -> Result, GetAuthInfoError> { - let ep = &creds.endpoint; - let user = &creds.inner.user; + let ep = &user_info.endpoint; + let user = &user_info.user; if let Some(role_secret) = self.caches.project_info.get_role_secret(ep, user) { return Ok(Some(role_secret)); } - let auth_info = self.do_get_auth_info(ctx, creds).await?; + let auth_info = self.do_get_auth_info(ctx, user_info).await?; let project_id = auth_info.project_id.unwrap_or(ep.clone()); if let Some(secret) = &auth_info.secret { self.caches @@ -198,9 +197,9 @@ impl super::Api for Api { async fn get_allowed_ips( &self, ctx: &mut RequestMonitoring, - creds: &ComputeUserInfo, + user_info: &ComputeUserInfo, ) -> Result { - let ep = &creds.endpoint; + let ep = &user_info.endpoint; if let Some(allowed_ips) = self.caches.project_info.get_allowed_ips(ep) { ALLOWED_IPS_BY_CACHE_OUTCOME .with_label_values(&["hit"]) @@ -210,9 +209,9 @@ impl super::Api for Api { ALLOWED_IPS_BY_CACHE_OUTCOME .with_label_values(&["miss"]) .inc(); - let auth_info = self.do_get_auth_info(ctx, creds).await?; + let auth_info = self.do_get_auth_info(ctx, user_info).await?; let allowed_ips = Arc::new(auth_info.allowed_ips); - let user = &creds.inner.user; + let user = &user_info.user; let project_id = auth_info.project_id.unwrap_or(ep.clone()); if let Some(secret) = &auth_info.secret { self.caches @@ -229,22 +228,19 @@ impl super::Api for Api { async fn wake_compute( &self, ctx: &mut RequestMonitoring, - extra: &ConsoleReqExtra, - creds: &ComputeUserInfo, + user_info: &ComputeUserInfo, ) -> Result { - let key: &str = &creds.inner.cache_key; + let key = user_info.endpoint_cache_key(); // Every time we do a wakeup http request, the compute node will stay up // for some time (highly depends on the console's scale-to-zero policy); // The connection info remains the same during that period of time, // which means that we might cache it to reduce the load and latency. - if let Some(cached) = self.caches.node_info.get(key) { - info!(key = key, "found cached compute node info"); + if let Some(cached) = self.caches.node_info.get(&*key) { + info!(key = &*key, "found cached compute node info"); return Ok(cached); } - let key: Arc = key.into(); - let permit = self.locks.get_wake_compute_permit(&key).await?; // after getting back a permit - it's possible the cache was filled @@ -256,7 +252,7 @@ impl super::Api for Api { } } - let node = self.do_wake_compute(ctx, extra, creds).await?; + let node = self.do_wake_compute(ctx, user_info).await?; let (_, cached) = self.caches.node_info.insert(key.clone(), node); info!(key = &*key, "created a cache entry for compute node info"); diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index 4aba222082..84b4c266e6 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -9,7 +9,7 @@ use crate::{ cancellation::{self, CancelMap}, compute, config::{AuthenticationConfig, ProxyConfig, TlsConfig}, - console::{self, messages::MetricsAuxInfo}, + console::messages::MetricsAuxInfo, context::RequestMonitoring, metrics::{ NUM_BYTES_PROXIED_COUNTER, NUM_BYTES_PROXIED_PER_CLIENT_COUNTER, @@ -26,6 +26,7 @@ use itertools::Itertools; use once_cell::sync::OnceCell; use pq_proto::{BeMessage as Be, FeStartupPacket, StartupMessageParams}; use regex::Regex; +use smol_str::SmolStr; use std::sync::Arc; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tokio_util::sync::CancellationToken; @@ -198,27 +199,29 @@ pub async fn handle_client( drop(pause); // Extract credentials which we're going to use for auth. - let creds = { + let user_info = { let hostname = mode.hostname(stream.get_ref()); - let common_names = tls.and_then(|tls| tls.common_names.clone()); + let common_names = tls.map(|tls| &tls.common_names); let result = config .auth_backend .as_ref() - .map(|_| auth::ClientCredentials::parse(ctx, ¶ms, hostname, common_names)) + .map(|_| { + auth::ComputeUserInfoMaybeEndpoint::parse(ctx, ¶ms, hostname, common_names) + }) .transpose(); match result { - Ok(creds) => creds, + Ok(user_info) => user_info, Err(e) => stream.throw_error(e).await?, } }; - ctx.set_endpoint_id(creds.get_endpoint()); + ctx.set_endpoint_id(user_info.get_endpoint()); let client = Client::new( stream, - creds, + user_info, ¶ms, mode.allow_self_signed_compute(config), endpoint_rate_limiter, @@ -397,7 +400,7 @@ struct Client<'a, S> { /// The underlying libpq protocol stream. stream: PqStream>, /// Client credentials that we care about. - creds: auth::BackendType<'a, auth::ClientCredentials>, + user_info: auth::BackendType<'a, auth::ComputeUserInfoMaybeEndpoint>, /// KV-dictionary with PostgreSQL connection params. params: &'a StartupMessageParams, /// Allow self-signed certificates (for testing). @@ -410,14 +413,14 @@ impl<'a, S> Client<'a, S> { /// Construct a new connection context. fn new( stream: PqStream>, - creds: auth::BackendType<'a, auth::ClientCredentials>, + user_info: auth::BackendType<'a, auth::ComputeUserInfoMaybeEndpoint>, params: &'a StartupMessageParams, allow_self_signed_compute: bool, endpoint_rate_limiter: Arc, ) -> Self { Self { stream, - creds, + user_info, params, allow_self_signed_compute, endpoint_rate_limiter, @@ -429,7 +432,7 @@ impl Client<'_, S> { /// Let the client authenticate and connect to the designated compute node. // Instrumentation logs endpoint name everywhere. Doesn't work for link // auth; strictly speaking we don't know endpoint name in its case. - #[tracing::instrument(name = "", fields(ep = %self.creds.get_endpoint().unwrap_or_default()), skip_all)] + #[tracing::instrument(name = "", fields(ep = %self.user_info.get_endpoint().unwrap_or_default()), skip_all)] async fn connect_to_db( self, ctx: &mut RequestMonitoring, @@ -439,14 +442,14 @@ impl Client<'_, S> { ) -> anyhow::Result<()> { let Self { mut stream, - creds, + user_info, params, allow_self_signed_compute, endpoint_rate_limiter, } = self; // check rate limit - if let Some(ep) = creds.get_endpoint() { + if let Some(ep) = user_info.get_endpoint() { if !endpoint_rate_limiter.check(ep) { return stream .throw_error(auth::AuthError::too_many_connections()) @@ -454,13 +457,9 @@ impl Client<'_, S> { } } - let extra = console::ConsoleReqExtra { - options: neon_options(params), - }; - - let user = creds.get_user().to_owned(); - let auth_result = match creds - .authenticate(ctx, &extra, &mut stream, mode.allow_cleartext(), config) + let user = user_info.get_user().to_owned(); + let auth_result = match user_info + .authenticate(ctx, &mut stream, mode.allow_cleartext(), config) .await { Ok(auth_result) => auth_result, @@ -473,12 +472,12 @@ impl Client<'_, S> { } }; - let (mut node_info, creds) = auth_result; + let (mut node_info, user_info) = auth_result; node_info.allow_self_signed_compute = allow_self_signed_compute; let aux = node_info.aux.clone(); - let mut node = connect_to_compute(ctx, &TcpMechanism { params }, node_info, &extra, &creds) + let mut node = connect_to_compute(ctx, &TcpMechanism { params }, node_info, &user_info) .or_else(|e| stream.throw_error(e)) .await?; @@ -493,29 +492,52 @@ impl Client<'_, S> { } } -pub fn neon_options(params: &StartupMessageParams) -> Vec<(String, String)> { - #[allow(unstable_name_collisions)] - match params.options_raw() { - Some(options) => options.filter_map(neon_option).collect(), - None => vec![], +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub struct NeonOptions(Vec<(SmolStr, SmolStr)>); + +impl NeonOptions { + pub fn parse_params(params: &StartupMessageParams) -> Self { + params + .options_raw() + .map(Self::parse_from_iter) + .unwrap_or_default() + } + pub fn parse_options_raw(options: &str) -> Self { + Self::parse_from_iter(StartupMessageParams::parse_options_raw(options)) + } + + fn parse_from_iter<'a>(options: impl Iterator) -> Self { + let mut options = options + .filter_map(neon_option) + .map(|(k, v)| (k.into(), v.into())) + .collect_vec(); + options.sort(); + Self(options) + } + + pub fn get_cache_key(&self, prefix: &str) -> SmolStr { + // prefix + format!(" {k}:{v}") + // kinda jank because SmolStr is immutable + std::iter::once(prefix) + .chain(self.0.iter().flat_map(|(k, v)| [" ", &**k, ":", &**v])) + .collect() + } + + /// DeepObject format + /// `paramName[prop1]=value1¶mName[prop2]=value2&...` + pub fn to_deep_object(&self) -> Vec<(String, SmolStr)> { + self.0 + .iter() + .map(|(k, v)| (format!("options[{}]", k), v.clone())) + .collect() } } -pub fn neon_options_str(params: &StartupMessageParams) -> String { - #[allow(unstable_name_collisions)] - neon_options(params) - .iter() - .map(|(k, v)| format!("{}:{}", k, v)) - .sorted() // we sort it to use as cache key - .intersperse(" ".to_owned()) - .collect() -} - -pub fn neon_option(bytes: &str) -> Option<(String, String)> { +pub fn neon_option(bytes: &str) -> Option<(&str, &str)> { static RE: OnceCell = OnceCell::new(); let re = RE.get_or_init(|| Regex::new(r"^neon_(\w+):(.+)").unwrap()); let cap = re.captures(bytes)?; let (_, [k, v]) = cap.extract(); - Some((k.to_owned(), v.to_owned())) + Some((k, v)) } diff --git a/proxy/src/proxy/connect_compute.rs b/proxy/src/proxy/connect_compute.rs index 1da2dee10b..72cab1fe5d 100644 --- a/proxy/src/proxy/connect_compute.rs +++ b/proxy/src/proxy/connect_compute.rs @@ -128,8 +128,7 @@ pub async fn connect_to_compute( ctx: &mut RequestMonitoring, mechanism: &M, mut node_info: console::CachedNodeInfo, - extra: &console::ConsoleReqExtra, - creds: &auth::BackendType<'_, auth::backend::ComputeUserInfo>, + user_info: &auth::BackendType<'_, auth::backend::ComputeUserInfo>, ) -> Result where M::ConnectError: ShouldRetry + std::fmt::Debug, @@ -159,10 +158,10 @@ where // if we failed to connect, it's likely that the compute node was suspended, wake a new compute node info!("compute node's state has likely changed; requesting a wake-up"); let node_info = loop { - let wake_res = match creds { - auth::BackendType::Console(api, creds) => api.wake_compute(ctx, extra, creds).await, + let wake_res = match user_info { + auth::BackendType::Console(api, user_info) => api.wake_compute(ctx, user_info).await, #[cfg(feature = "testing")] - auth::BackendType::Postgres(api, creds) => api.wake_compute(ctx, extra, creds).await, + auth::BackendType::Postgres(api, user_info) => api.wake_compute(ctx, user_info).await, // nothing to do? auth::BackendType::Link(_) => return Err(err.into()), // test backend diff --git a/proxy/src/proxy/tests.rs b/proxy/src/proxy/tests.rs index 1d7c9bac54..73fde2d7d0 100644 --- a/proxy/src/proxy/tests.rs +++ b/proxy/src/proxy/tests.rs @@ -7,7 +7,7 @@ use super::retry::ShouldRetry; use super::*; use crate::auth::backend::{ComputeUserInfo, TestBackend}; use crate::config::CertResolver; -use crate::console::{CachedNodeInfo, NodeInfo}; +use crate::console::{self, CachedNodeInfo, NodeInfo}; use crate::proxy::retry::{retry_after, NUM_RETRIES_CONNECT}; use crate::{auth, http, sasl, scram}; use async_trait::async_trait; @@ -83,7 +83,7 @@ fn generate_tls_config<'a>( let mut cert_resolver = CertResolver::new(); cert_resolver.add_cert(key, vec![cert], true)?; - let common_names = Some(cert_resolver.get_common_names()); + let common_names = cert_resolver.get_common_names(); TlsConfig { config, @@ -487,15 +487,10 @@ fn helper_create_cached_node_info() -> CachedNodeInfo { fn helper_create_connect_info( mechanism: &TestConnectMechanism, -) -> ( - CachedNodeInfo, - console::ConsoleReqExtra, - auth::BackendType<'_, ComputeUserInfo>, -) { +) -> (CachedNodeInfo, auth::BackendType<'_, ComputeUserInfo>) { let cache = helper_create_cached_node_info(); - let extra = console::ConsoleReqExtra { options: vec![] }; - let creds = auth::BackendType::Test(mechanism); - (cache, extra, creds) + let user_info = auth::BackendType::Test(mechanism); + (cache, user_info) } #[tokio::test] @@ -503,8 +498,8 @@ async fn connect_to_compute_success() { use ConnectAction::*; let mut ctx = RequestMonitoring::test(); let mechanism = TestConnectMechanism::new(vec![Connect]); - let (cache, extra, creds) = helper_create_connect_info(&mechanism); - connect_to_compute(&mut ctx, &mechanism, cache, &extra, &creds) + let (cache, user_info) = helper_create_connect_info(&mechanism); + connect_to_compute(&mut ctx, &mechanism, cache, &user_info) .await .unwrap(); mechanism.verify(); @@ -515,8 +510,8 @@ async fn connect_to_compute_retry() { use ConnectAction::*; let mut ctx = RequestMonitoring::test(); let mechanism = TestConnectMechanism::new(vec![Retry, Wake, Retry, Connect]); - let (cache, extra, creds) = helper_create_connect_info(&mechanism); - connect_to_compute(&mut ctx, &mechanism, cache, &extra, &creds) + let (cache, user_info) = helper_create_connect_info(&mechanism); + connect_to_compute(&mut ctx, &mechanism, cache, &user_info) .await .unwrap(); mechanism.verify(); @@ -528,8 +523,8 @@ async fn connect_to_compute_non_retry_1() { use ConnectAction::*; let mut ctx = RequestMonitoring::test(); let mechanism = TestConnectMechanism::new(vec![Retry, Wake, Retry, Fail]); - let (cache, extra, creds) = helper_create_connect_info(&mechanism); - connect_to_compute(&mut ctx, &mechanism, cache, &extra, &creds) + let (cache, user_info) = helper_create_connect_info(&mechanism); + connect_to_compute(&mut ctx, &mechanism, cache, &user_info) .await .unwrap_err(); mechanism.verify(); @@ -541,8 +536,8 @@ async fn connect_to_compute_non_retry_2() { use ConnectAction::*; let mut ctx = RequestMonitoring::test(); let mechanism = TestConnectMechanism::new(vec![Fail, Wake, Retry, Connect]); - let (cache, extra, creds) = helper_create_connect_info(&mechanism); - connect_to_compute(&mut ctx, &mechanism, cache, &extra, &creds) + let (cache, user_info) = helper_create_connect_info(&mechanism); + connect_to_compute(&mut ctx, &mechanism, cache, &user_info) .await .unwrap(); mechanism.verify(); @@ -558,8 +553,8 @@ async fn connect_to_compute_non_retry_3() { Retry, Wake, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, /* the 17th time */ Retry, ]); - let (cache, extra, creds) = helper_create_connect_info(&mechanism); - connect_to_compute(&mut ctx, &mechanism, cache, &extra, &creds) + let (cache, user_info) = helper_create_connect_info(&mechanism); + connect_to_compute(&mut ctx, &mechanism, cache, &user_info) .await .unwrap_err(); mechanism.verify(); @@ -571,8 +566,8 @@ async fn wake_retry() { use ConnectAction::*; let mut ctx = RequestMonitoring::test(); let mechanism = TestConnectMechanism::new(vec![Retry, WakeRetry, Wake, Connect]); - let (cache, extra, creds) = helper_create_connect_info(&mechanism); - connect_to_compute(&mut ctx, &mechanism, cache, &extra, &creds) + let (cache, user_info) = helper_create_connect_info(&mechanism); + connect_to_compute(&mut ctx, &mechanism, cache, &user_info) .await .unwrap(); mechanism.verify(); @@ -584,8 +579,8 @@ async fn wake_non_retry() { use ConnectAction::*; let mut ctx = RequestMonitoring::test(); let mechanism = TestConnectMechanism::new(vec![Retry, WakeFail]); - let (cache, extra, creds) = helper_create_connect_info(&mechanism); - connect_to_compute(&mut ctx, &mechanism, cache, &extra, &creds) + let (cache, user_info) = helper_create_connect_info(&mechanism); + connect_to_compute(&mut ctx, &mechanism, cache, &user_info) .await .unwrap_err(); mechanism.verify(); diff --git a/proxy/src/serverless.rs b/proxy/src/serverless.rs index 83a9773052..8af008394a 100644 --- a/proxy/src/serverless.rs +++ b/proxy/src/serverless.rs @@ -17,6 +17,7 @@ pub use reqwest_middleware::{ClientWithMiddleware, Error}; pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware}; use tokio_util::task::TaskTracker; +use crate::config::TlsConfig; use crate::context::RequestMonitoring; use crate::metrics::NUM_CLIENT_CONNECTION_GAUGE; use crate::protocol2::{ProxyProtocolAccept, WithClientIp}; @@ -69,14 +70,14 @@ pub async fn task_main( } }); - let tls_config = config.tls_config.as_ref().map(|cfg| cfg.to_server_config()); - let tls_acceptor: tokio_rustls::TlsAcceptor = match tls_config { - Some(config) => config.into(), + let tls_config = match config.tls_config.as_ref() { + Some(config) => config, None => { warn!("TLS config is missing, WebSocket Secure server will not be started"); return Ok(()); } }; + let tls_acceptor: tokio_rustls::TlsAcceptor = tls_config.to_server_config().into(); let mut addr_incoming = AddrIncoming::from_listener(ws_listener)?; let _ = addr_incoming.set_nodelay(true); @@ -126,6 +127,7 @@ pub async fn task_main( request_handler( req, config, + tls_config, conn_pool, ws_connections, cancel_map, @@ -195,6 +197,7 @@ where async fn request_handler( mut request: Request, config: &'static ProxyConfig, + tls: &'static TlsConfig, conn_pool: Arc, ws_connections: TaskTracker, cancel_map: Arc, @@ -243,6 +246,7 @@ async fn request_handler( let mut ctx = RequestMonitoring::new(session_id, peer_addr, "http", &config.region); sql_over_http::handle( + tls, &config.http_config, &mut ctx, request, diff --git a/proxy/src/serverless/conn_pool.rs b/proxy/src/serverless/conn_pool.rs index c9f3fd6a38..787b8bb28e 100644 --- a/proxy/src/serverless/conn_pool.rs +++ b/proxy/src/serverless/conn_pool.rs @@ -1,4 +1,4 @@ -use anyhow::{anyhow, Context}; +use anyhow::Context; use async_trait::async_trait; use dashmap::DashMap; use futures::{future::poll_fn, Future}; @@ -9,7 +9,6 @@ use pbkdf2::{ password_hash::{PasswordHashString, PasswordHasher, PasswordVerifier, SaltString}, Params, Pbkdf2, }; -use pq_proto::StartupMessageParams; use prometheus::{exponential_buckets, register_histogram, Histogram}; use rand::Rng; use smol_str::SmolStr; @@ -30,7 +29,7 @@ use crate::{ console, context::RequestMonitoring, metrics::NUM_DB_CONNECTIONS_GAUGE, - proxy::{connect_compute::ConnectMechanism, neon_options}, + proxy::connect_compute::ConnectMechanism, usage_metrics::{Ids, MetricCounter, USAGE_METRICS}, }; use crate::{compute, config}; @@ -38,28 +37,37 @@ use crate::{compute, config}; use tracing::{debug, error, warn, Span}; use tracing::{info, info_span, Instrument}; -pub const APP_NAME: &str = "/sql_over_http"; +pub const APP_NAME: SmolStr = SmolStr::new_inline("/sql_over_http"); #[derive(Debug, Clone)] pub struct ConnInfo { - pub username: SmolStr, + pub user_info: ComputeUserInfo, pub dbname: SmolStr, - pub hostname: SmolStr, pub password: SmolStr, - pub options: Option, } impl ConnInfo { // hm, change to hasher to avoid cloning? pub fn db_and_user(&self) -> (SmolStr, SmolStr) { - (self.dbname.clone(), self.username.clone()) + (self.dbname.clone(), self.user_info.user.clone()) + } + + pub fn endpoint_cache_key(&self) -> SmolStr { + self.user_info.endpoint_cache_key() } } impl fmt::Display for ConnInfo { // use custom display to avoid logging password fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}@{}/{}", self.username, self.hostname, self.dbname) + write!( + f, + "{}@{}/{}?{}", + self.user_info.user, + self.user_info.endpoint, + self.dbname, + self.user_info.options.get_cache_key("") + ) } } @@ -319,7 +327,7 @@ impl GlobalConnPool { let mut hash_valid = false; let mut endpoint_pool = Weak::new(); if !force_new { - let pool = self.get_or_create_endpoint_pool(&conn_info.hostname); + let pool = self.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key()); endpoint_pool = Arc::downgrade(&pool); let mut hash = None; @@ -401,7 +409,7 @@ impl GlobalConnPool { Err(err) if hash_valid && err.to_string().contains("password authentication failed") => { - let pool = self.get_or_create_endpoint_pool(&conn_info.hostname); + let pool = self.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key()); let mut pool = pool.write(); if let Some(entry) = pool.pools.get_mut(&conn_info.db_and_user()) { entry.password_hash = None; @@ -418,7 +426,7 @@ impl GlobalConnPool { }) .await??; - let pool = self.get_or_create_endpoint_pool(&conn_info.hostname); + let pool = self.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key()); let mut pool = pool.write(); pool.pools .entry(conn_info.db_and_user()) @@ -520,23 +528,11 @@ async fn connect_to_compute( conn_id: uuid::Uuid, pool: Weak>, ) -> anyhow::Result { - let tls = config.tls_config.as_ref(); - let common_names = tls.and_then(|tls| tls.common_names.clone()); - - let params = StartupMessageParams::new([ - ("user", &conn_info.username), - ("database", &conn_info.dbname), - ("application_name", APP_NAME), - ("options", conn_info.options.as_deref().unwrap_or("")), - ]); - let creds = - auth::ClientCredentials::parse(ctx, ¶ms, Some(&conn_info.hostname), common_names)?; - - let creds = - ComputeUserInfo::try_from(creds).map_err(|_| anyhow!("missing endpoint identifier"))?; - let backend = config.auth_backend.as_ref().map(|_| creds); - - let console_options = neon_options(¶ms); + ctx.set_application(Some(APP_NAME)); + let backend = config + .auth_backend + .as_ref() + .map(|_| conn_info.user_info.clone()); if !config.disable_ip_check_for_http { let allowed_ips = backend.get_allowed_ips(ctx).await?; @@ -544,11 +540,8 @@ async fn connect_to_compute( return Err(auth::AuthError::ip_address_not_allowed().into()); } } - let extra = console::ConsoleReqExtra { - options: console_options, - }; let node_info = backend - .wake_compute(ctx, &extra) + .wake_compute(ctx) .await? .context("missing cache entry from wake_compute")?; @@ -563,7 +556,6 @@ async fn connect_to_compute( idle: config.http_config.pool_options.idle_timeout, }, node_info, - &extra, &backend, ) .await @@ -582,7 +574,7 @@ async fn connect_to_compute_once( let mut session = ctx.session_id; let (client, mut connection) = config - .user(&conn_info.username) + .user(&conn_info.user_info.user) .password(&*conn_info.password) .dbname(&conn_info.dbname) .connect_timeout(timeout) diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index 70c0343fa3..719559ed48 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -28,9 +28,13 @@ use url::Url; use utils::http::error::ApiError; use utils::http::json::json_response; +use crate::auth::backend::ComputeUserInfo; +use crate::auth::endpoint_sni; use crate::config::HttpConfig; +use crate::config::TlsConfig; use crate::context::RequestMonitoring; use crate::metrics::NUM_CONNECTION_REQUESTS_GAUGE; +use crate::proxy::NeonOptions; use super::conn_pool::ConnInfo; use super::conn_pool::GlobalConnPool; @@ -125,6 +129,7 @@ fn get_conn_info( ctx: &mut RequestMonitoring, headers: &HeaderMap, sni_hostname: Option, + tls: &TlsConfig, ) -> Result { let connection_string = headers .get("Neon-Connection-String") @@ -179,8 +184,10 @@ fn get_conn_info( } } - let hostname: SmolStr = hostname.into(); - ctx.set_endpoint_id(Some(hostname.clone())); + let endpoint = endpoint_sni(hostname, &tls.common_names)?; + + let endpoint: SmolStr = endpoint.into(); + ctx.set_endpoint_id(Some(endpoint.clone())); let pairs = connection_url.query_pairs(); @@ -188,22 +195,27 @@ fn get_conn_info( for (key, value) in pairs { if key == "options" { - options = Some(value.into()); + options = Some(NeonOptions::parse_options_raw(&value)); break; } } + let user_info = ComputeUserInfo { + endpoint, + user: username, + options: options.unwrap_or_default(), + }; + Ok(ConnInfo { - username, + user_info, dbname: dbname.into(), - hostname, password: password.into(), - options, }) } // TODO: return different http error codes pub async fn handle( + tls: &'static TlsConfig, config: &'static HttpConfig, ctx: &mut RequestMonitoring, request: Request, @@ -212,7 +224,7 @@ pub async fn handle( ) -> Result, ApiError> { let result = tokio::time::timeout( config.request_timeout, - handle_inner(config, ctx, request, sni_hostname, conn_pool), + handle_inner(tls, config, ctx, request, sni_hostname, conn_pool), ) .await; let mut response = match result { @@ -294,6 +306,7 @@ pub async fn handle( #[instrument(name = "sql-over-http", fields(pid = tracing::field::Empty), skip_all)] async fn handle_inner( + tls: &'static TlsConfig, config: &'static HttpConfig, ctx: &mut RequestMonitoring, request: Request, @@ -308,7 +321,7 @@ async fn handle_inner( // Determine the destination and connection params // let headers = request.headers(); - let conn_info = get_conn_info(ctx, headers, sni_hostname)?; + let conn_info = get_conn_info(ctx, headers, sni_hostname, tls)?; // Determine the output options. Default behaviour is 'false'. Anything that is not // strictly 'true' assumed to be false.