proxy: refactor how neon-options are handled (#6306)

## Problem

HTTP connection pool was not respecting the PitR options.

## Summary of changes

1. refactor neon_options a bit to allow easier access to cache_key
2. make HTTP not go through `StartupMessageParams`
3. expose SNI processing to replace what was removed in step 2.
This commit is contained in:
Conrad Ludgate
2024-01-11 14:58:31 +00:00
committed by GitHub
parent a84935d266
commit 551f0cc097
16 changed files with 315 additions and 316 deletions

View File

@@ -4,7 +4,7 @@ pub mod backend;
pub use backend::BackendType;
mod credentials;
pub use credentials::{check_peer_addr_is_in_list, ClientCredentials};
pub use credentials::{check_peer_addr_is_in_list, endpoint_sni, ComputeUserInfoMaybeEndpoint};
mod password_hack;
pub use password_hack::parse_endpoint_param;

View File

@@ -14,14 +14,15 @@ use crate::console::AuthSecret;
use crate::context::RequestMonitoring;
use crate::proxy::connect_compute::handle_try_wake;
use crate::proxy::retry::retry_after;
use crate::proxy::NeonOptions;
use crate::scram;
use crate::stream::Stream;
use crate::{
auth::{self, ClientCredentials},
auth::{self, ComputeUserInfoMaybeEndpoint},
config::AuthenticationConfig,
console::{
self,
provider::{CachedAllowedIps, CachedNodeInfo, ConsoleReqExtra},
provider::{CachedAllowedIps, CachedNodeInfo},
Api,
},
stream, url,
@@ -38,7 +39,7 @@ use tracing::{error, info, warn};
/// * When `T` is `()`, it's just a regular auth backend selector
/// which we use in [`crate::config::ProxyConfig`].
///
/// * However, when we substitute `T` with [`ClientCredentials`],
/// * However, when we substitute `T` with [`ComputeUserInfoMaybeEndpoint`],
/// this helps us provide the credentials only to those auth
/// backends which require them for the authentication process.
pub enum BackendType<'a, T> {
@@ -127,14 +128,23 @@ pub struct ComputeCredentials<T> {
pub keys: T,
}
#[derive(Debug, Clone)]
pub struct ComputeUserInfoNoEndpoint {
pub user: SmolStr,
pub cache_key: SmolStr,
pub options: NeonOptions,
}
#[derive(Debug, Clone)]
pub struct ComputeUserInfo {
pub endpoint: SmolStr,
pub inner: ComputeUserInfoNoEndpoint,
pub user: SmolStr,
pub options: NeonOptions,
}
impl ComputeUserInfo {
pub fn endpoint_cache_key(&self) -> SmolStr {
self.options.get_cache_key(&self.endpoint)
}
}
pub enum ComputeCredentialKeys {
@@ -143,18 +153,21 @@ pub enum ComputeCredentialKeys {
AuthKeys(AuthKeys),
}
impl TryFrom<ClientCredentials> for ComputeUserInfo {
impl TryFrom<ComputeUserInfoMaybeEndpoint> for ComputeUserInfo {
// user name
type Error = ComputeUserInfoNoEndpoint;
fn try_from(creds: ClientCredentials) -> Result<Self, Self::Error> {
let inner = ComputeUserInfoNoEndpoint {
user: creds.user,
cache_key: creds.cache_key,
};
match creds.project {
None => Err(inner),
Some(endpoint) => Ok(ComputeUserInfo { endpoint, inner }),
fn try_from(user_info: ComputeUserInfoMaybeEndpoint) -> Result<Self, Self::Error> {
match user_info.project {
None => Err(ComputeUserInfoNoEndpoint {
user: user_info.user,
options: user_info.options,
}),
Some(endpoint) => Ok(ComputeUserInfo {
endpoint,
user: user_info.user,
options: user_info.options,
}),
}
}
}
@@ -166,7 +179,7 @@ impl TryFrom<ClientCredentials> for ComputeUserInfo {
async fn auth_quirks(
ctx: &mut RequestMonitoring,
api: &impl console::Api,
creds: ClientCredentials,
user_info: ComputeUserInfoMaybeEndpoint,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
allow_cleartext: bool,
config: &'static AuthenticationConfig,
@@ -174,7 +187,7 @@ async fn auth_quirks(
// If there's no project so far, that entails that client doesn't
// support SNI or other means of passing the endpoint (project) name.
// We now expect to see a very specific payload in the place of password.
let (info, unauthenticated_password) = match creds.try_into() {
let (info, unauthenticated_password) = match user_info.try_into() {
Err(info) => {
let res = hacks::password_hack_no_authentication(info, client, &mut ctx.latency_timer)
.await?;
@@ -199,7 +212,7 @@ async fn auth_quirks(
// This mocked secret will never lead to successful authentication.
info!("authentication info not found, mocking it");
Cached::new_uncached(AuthSecret::Scram(scram::ServerSecret::mock(
&info.inner.user,
&info.user,
rand::random(),
)))
});
@@ -240,7 +253,7 @@ async fn authenticate_with_secret(
crate::sasl::Outcome::Success(key) => key,
crate::sasl::Outcome::Failure(reason) => {
info!("auth backend failed with an error: {reason}");
return Err(auth::AuthError::auth_failed(&*info.inner.user));
return Err(auth::AuthError::auth_failed(&*info.user));
}
};
@@ -267,19 +280,17 @@ async fn authenticate_with_secret(
async fn auth_and_wake_compute(
ctx: &mut RequestMonitoring,
api: &impl console::Api,
extra: &ConsoleReqExtra,
creds: ClientCredentials,
user_info: ComputeUserInfoMaybeEndpoint,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
allow_cleartext: bool,
config: &'static AuthenticationConfig,
) -> auth::Result<(CachedNodeInfo, ComputeUserInfo)> {
let compute_credentials = auth_quirks(ctx, api, creds, client, allow_cleartext, config).await?;
let compute_credentials =
auth_quirks(ctx, api, user_info, client, allow_cleartext, config).await?;
let mut num_retries = 0;
let mut node = loop {
let wake_res = api
.wake_compute(ctx, extra, &compute_credentials.info)
.await;
let wake_res = api.wake_compute(ctx, &compute_credentials.info).await;
match handle_try_wake(wake_res, num_retries) {
Err(e) => {
error!(error = ?e, num_retries, retriable = false, "couldn't wake compute node");
@@ -307,15 +318,15 @@ async fn auth_and_wake_compute(
Ok((node, compute_credentials.info))
}
impl<'a> BackendType<'a, ClientCredentials> {
impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> {
/// Get compute endpoint name from the credentials.
pub fn get_endpoint(&self) -> Option<SmolStr> {
use BackendType::*;
match self {
Console(_, creds) => creds.project.clone(),
Console(_, user_info) => user_info.project.clone(),
#[cfg(feature = "testing")]
Postgres(_, creds) => creds.project.clone(),
Postgres(_, user_info) => user_info.project.clone(),
Link(_) => Some("link".into()),
#[cfg(test)]
Test(_) => Some("test".into()),
@@ -327,9 +338,9 @@ impl<'a> BackendType<'a, ClientCredentials> {
use BackendType::*;
match self {
Console(_, creds) => &creds.user,
Console(_, user_info) => &user_info.user,
#[cfg(feature = "testing")]
Postgres(_, creds) => &creds.user,
Postgres(_, user_info) => &user_info.user,
Link(_) => "link",
#[cfg(test)]
Test(_) => "test",
@@ -341,7 +352,6 @@ impl<'a> BackendType<'a, ClientCredentials> {
pub async fn authenticate(
self,
ctx: &mut RequestMonitoring,
extra: &ConsoleReqExtra,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
allow_cleartext: bool,
config: &'static AuthenticationConfig,
@@ -349,43 +359,29 @@ impl<'a> BackendType<'a, ClientCredentials> {
use BackendType::*;
let res = match self {
Console(api, creds) => {
Console(api, user_info) => {
info!(
user = &*creds.user,
project = creds.project(),
user = &*user_info.user,
project = user_info.project(),
"performing authentication using the console"
);
let (cache_info, user_info) = auth_and_wake_compute(
ctx,
&*api,
extra,
creds,
client,
allow_cleartext,
config,
)
.await?;
let (cache_info, user_info) =
auth_and_wake_compute(ctx, &*api, user_info, client, allow_cleartext, config)
.await?;
(cache_info, BackendType::Console(api, user_info))
}
#[cfg(feature = "testing")]
Postgres(api, creds) => {
Postgres(api, user_info) => {
info!(
user = &*creds.user,
project = creds.project(),
user = &*user_info.user,
project = user_info.project(),
"performing authentication using a local postgres instance"
);
let (cache_info, user_info) = auth_and_wake_compute(
ctx,
&*api,
extra,
creds,
client,
allow_cleartext,
config,
)
.await?;
let (cache_info, user_info) =
auth_and_wake_compute(ctx, &*api, user_info, client, allow_cleartext, config)
.await?;
(cache_info, BackendType::Postgres(api, user_info))
}
// NOTE: this auth backend doesn't use client credentials.
@@ -417,9 +413,9 @@ impl BackendType<'_, ComputeUserInfo> {
) -> Result<CachedAllowedIps, GetAuthInfoError> {
use BackendType::*;
match self {
Console(api, creds) => api.get_allowed_ips(ctx, creds).await,
Console(api, user_info) => api.get_allowed_ips(ctx, user_info).await,
#[cfg(feature = "testing")]
Postgres(api, creds) => api.get_allowed_ips(ctx, creds).await,
Postgres(api, user_info) => api.get_allowed_ips(ctx, user_info).await,
Link(_) => Ok(Cached::new_uncached(Arc::new(vec![]))),
#[cfg(test)]
Test(x) => Ok(Cached::new_uncached(Arc::new(x.get_allowed_ips()?))),
@@ -431,14 +427,13 @@ impl BackendType<'_, ComputeUserInfo> {
pub async fn wake_compute(
&self,
ctx: &mut RequestMonitoring,
extra: &ConsoleReqExtra,
) -> Result<Option<CachedNodeInfo>, console::errors::WakeComputeError> {
use BackendType::*;
match self {
Console(api, creds) => api.wake_compute(ctx, extra, creds).map_ok(Some).await,
Console(api, user_info) => api.wake_compute(ctx, user_info).map_ok(Some).await,
#[cfg(feature = "testing")]
Postgres(api, creds) => api.wake_compute(ctx, extra, creds).map_ok(Some).await,
Postgres(api, user_info) => api.wake_compute(ctx, user_info).map_ok(Some).await,
Link(_) => Ok(None),
#[cfg(test)]
Test(x) => x.wake_compute().map(Some),

View File

@@ -54,7 +54,7 @@ pub(super) async fn authenticate(
sasl::Outcome::Success(key) => key,
sasl::Outcome::Failure(reason) => {
info!("auth backend failed with an error: {reason}");
return Err(auth::AuthError::auth_failed(&*creds.inner.user));
return Err(auth::AuthError::auth_failed(&*creds.user));
}
};

View File

@@ -36,7 +36,7 @@ pub async fn authenticate_cleartext(
sasl::Outcome::Success(key) => key,
sasl::Outcome::Failure(reason) => {
info!("auth backend failed with an error: {reason}");
return Err(auth::AuthError::auth_failed(&*info.inner.user));
return Err(auth::AuthError::auth_failed(&*info.user));
}
};
@@ -67,7 +67,8 @@ pub async fn password_hack_no_authentication(
// Report tentative success; compute node will check the password anyway.
Ok(ComputeCredentials {
info: ComputeUserInfo {
inner: info,
user: info.user,
options: info.options,
endpoint: payload.endpoint,
},
keys: payload.password,

View File

@@ -2,7 +2,7 @@
use crate::{
auth::password_hack::parse_endpoint_param, context::RequestMonitoring, error::UserFacingError,
metrics::NUM_CONNECTION_ACCEPTED_BY_SNI, proxy::neon_options_str,
metrics::NUM_CONNECTION_ACCEPTED_BY_SNI, proxy::NeonOptions,
};
use itertools::Itertools;
use pq_proto::StartupMessageParams;
@@ -12,7 +12,7 @@ use thiserror::Error;
use tracing::{info, warn};
#[derive(Debug, Error, PartialEq, Eq, Clone)]
pub enum ClientCredsParseError {
pub enum ComputeUserInfoParseError {
#[error("Parameter '{0}' is missing in startup packet.")]
MissingKey(&'static str),
@@ -33,34 +33,49 @@ pub enum ClientCredsParseError {
MalformedProjectName(SmolStr),
}
impl UserFacingError for ClientCredsParseError {}
impl UserFacingError for ComputeUserInfoParseError {}
/// Various client credentials which we use for authentication.
/// Note that we don't store any kind of client key or password here.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ClientCredentials {
pub struct ComputeUserInfoMaybeEndpoint {
pub user: SmolStr,
// TODO: this is a severe misnomer! We should think of a new name ASAP.
pub project: Option<SmolStr>,
pub cache_key: SmolStr,
pub options: NeonOptions,
}
impl ClientCredentials {
impl ComputeUserInfoMaybeEndpoint {
#[inline]
pub fn project(&self) -> Option<&str> {
self.project.as_deref()
}
}
impl ClientCredentials {
pub fn endpoint_sni<'a>(
sni: &'a str,
common_names: &HashSet<String>,
) -> Result<&'a str, ComputeUserInfoParseError> {
let Some((subdomain, common_name)) = sni.split_once('.') else {
return Err(ComputeUserInfoParseError::UnknownCommonName { cn: sni.into() });
};
if !common_names.contains(common_name) {
return Err(ComputeUserInfoParseError::UnknownCommonName {
cn: common_name.into(),
});
}
Ok(subdomain)
}
impl ComputeUserInfoMaybeEndpoint {
pub fn parse(
ctx: &mut RequestMonitoring,
params: &StartupMessageParams,
sni: Option<&str>,
common_names: Option<HashSet<String>>,
) -> Result<Self, ClientCredsParseError> {
use ClientCredsParseError::*;
common_names: Option<&HashSet<String>>,
) -> Result<Self, ComputeUserInfoParseError> {
use ComputeUserInfoParseError::*;
// Some parameters are stored in the startup message.
let get_param = |key| params.get(key).ok_or(MissingKey(key));
@@ -87,21 +102,7 @@ impl ClientCredentials {
let project_from_domain = if let Some(sni_str) = sni {
if let Some(cn) = common_names {
let common_name_from_sni = sni_str.split_once('.').map(|(_, domain)| domain);
let project = common_name_from_sni
.and_then(|domain| {
if cn.contains(domain) {
subdomain_from_sni(sni_str, domain)
} else {
None
}
})
.ok_or_else(|| UnknownCommonName {
cn: common_name_from_sni.unwrap_or("").into(),
})?;
Some(project)
Some(SmolStr::from(endpoint_sni(sni_str, cn)?))
} else {
None
}
@@ -140,17 +141,12 @@ impl ClientCredentials {
info!("Connection with password hack");
}
let cache_key = format!(
"{}{}",
project.as_deref().unwrap_or(""),
neon_options_str(params)
)
.into();
let options = NeonOptions::parse_params(params);
Ok(Self {
user,
project,
cache_key,
options,
})
}
}
@@ -207,25 +203,19 @@ fn project_name_valid(name: &str) -> bool {
name.chars().all(|c| c.is_alphanumeric() || c == '-')
}
fn subdomain_from_sni(sni: &str, common_name: &str) -> Option<SmolStr> {
sni.strip_suffix(common_name)?
.strip_suffix('.')
.map(SmolStr::from)
}
#[cfg(test)]
mod tests {
use super::*;
use ClientCredsParseError::*;
use ComputeUserInfoParseError::*;
#[test]
fn parse_bare_minimum() -> anyhow::Result<()> {
// According to postgresql, only `user` should be required.
let options = StartupMessageParams::new([("user", "john_doe")]);
let mut ctx = RequestMonitoring::test();
let creds = ClientCredentials::parse(&mut ctx, &options, None, None)?;
assert_eq!(creds.user, "john_doe");
assert_eq!(creds.project, None);
let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, None, None)?;
assert_eq!(user_info.user, "john_doe");
assert_eq!(user_info.project, None);
Ok(())
}
@@ -238,9 +228,9 @@ mod tests {
("foo", "bar"), // should be ignored
]);
let mut ctx = RequestMonitoring::test();
let creds = ClientCredentials::parse(&mut ctx, &options, None, None)?;
assert_eq!(creds.user, "john_doe");
assert_eq!(creds.project, None);
let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, None, None)?;
assert_eq!(user_info.user, "john_doe");
assert_eq!(user_info.project, None);
Ok(())
}
@@ -253,10 +243,11 @@ mod tests {
let common_names = Some(["localhost".into()].into());
let mut ctx = RequestMonitoring::test();
let creds = ClientCredentials::parse(&mut ctx, &options, sni, common_names)?;
assert_eq!(creds.user, "john_doe");
assert_eq!(creds.project.as_deref(), Some("foo"));
assert_eq!(creds.cache_key, "foo");
let user_info =
ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())?;
assert_eq!(user_info.user, "john_doe");
assert_eq!(user_info.project.as_deref(), Some("foo"));
assert_eq!(user_info.options.get_cache_key("foo"), "foo");
Ok(())
}
@@ -269,9 +260,9 @@ mod tests {
]);
let mut ctx = RequestMonitoring::test();
let creds = ClientCredentials::parse(&mut ctx, &options, None, None)?;
assert_eq!(creds.user, "john_doe");
assert_eq!(creds.project.as_deref(), Some("bar"));
let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, None, None)?;
assert_eq!(user_info.user, "john_doe");
assert_eq!(user_info.project.as_deref(), Some("bar"));
Ok(())
}
@@ -284,9 +275,9 @@ mod tests {
]);
let mut ctx = RequestMonitoring::test();
let creds = ClientCredentials::parse(&mut ctx, &options, None, None)?;
assert_eq!(creds.user, "john_doe");
assert_eq!(creds.project.as_deref(), Some("bar"));
let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, None, None)?;
assert_eq!(user_info.user, "john_doe");
assert_eq!(user_info.project.as_deref(), Some("bar"));
Ok(())
}
@@ -302,9 +293,9 @@ mod tests {
]);
let mut ctx = RequestMonitoring::test();
let creds = ClientCredentials::parse(&mut ctx, &options, None, None)?;
assert_eq!(creds.user, "john_doe");
assert!(creds.project.is_none());
let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, None, None)?;
assert_eq!(user_info.user, "john_doe");
assert!(user_info.project.is_none());
Ok(())
}
@@ -317,9 +308,9 @@ mod tests {
]);
let mut ctx = RequestMonitoring::test();
let creds = ClientCredentials::parse(&mut ctx, &options, None, None)?;
assert_eq!(creds.user, "john_doe");
assert!(creds.project.is_none());
let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, None, None)?;
assert_eq!(user_info.user, "john_doe");
assert!(user_info.project.is_none());
Ok(())
}
@@ -332,9 +323,10 @@ mod tests {
let common_names = Some(["localhost".into()].into());
let mut ctx = RequestMonitoring::test();
let creds = ClientCredentials::parse(&mut ctx, &options, sni, common_names)?;
assert_eq!(creds.user, "john_doe");
assert_eq!(creds.project.as_deref(), Some("baz"));
let user_info =
ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())?;
assert_eq!(user_info.user, "john_doe");
assert_eq!(user_info.project.as_deref(), Some("baz"));
Ok(())
}
@@ -346,14 +338,16 @@ mod tests {
let common_names = Some(["a.com".into(), "b.com".into()].into());
let sni = Some("p1.a.com");
let mut ctx = RequestMonitoring::test();
let creds = ClientCredentials::parse(&mut ctx, &options, sni, common_names)?;
assert_eq!(creds.project.as_deref(), Some("p1"));
let user_info =
ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())?;
assert_eq!(user_info.project.as_deref(), Some("p1"));
let common_names = Some(["a.com".into(), "b.com".into()].into());
let sni = Some("p1.b.com");
let mut ctx = RequestMonitoring::test();
let creds = ClientCredentials::parse(&mut ctx, &options, sni, common_names)?;
assert_eq!(creds.project.as_deref(), Some("p1"));
let user_info =
ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())?;
assert_eq!(user_info.project.as_deref(), Some("p1"));
Ok(())
}
@@ -367,8 +361,9 @@ mod tests {
let common_names = Some(["localhost".into()].into());
let mut ctx = RequestMonitoring::test();
let err = ClientCredentials::parse(&mut ctx, &options, sni, common_names)
.expect_err("should fail");
let err =
ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())
.expect_err("should fail");
match err {
InconsistentProjectNames { domain, option } => {
assert_eq!(option, "first");
@@ -386,8 +381,9 @@ mod tests {
let common_names = Some(["example.com".into()].into());
let mut ctx = RequestMonitoring::test();
let err = ClientCredentials::parse(&mut ctx, &options, sni, common_names)
.expect_err("should fail");
let err =
ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())
.expect_err("should fail");
match err {
UnknownCommonName { cn } => {
assert_eq!(cn, "localhost");
@@ -406,9 +402,13 @@ mod tests {
let sni = Some("project.localhost");
let common_names = Some(["localhost".into()].into());
let mut ctx = RequestMonitoring::test();
let creds = ClientCredentials::parse(&mut ctx, &options, sni, common_names)?;
assert_eq!(creds.project.as_deref(), Some("project"));
assert_eq!(creds.cache_key, "projectendpoint_type:read_write lsn:0/2");
let user_info =
ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())?;
assert_eq!(user_info.project.as_deref(), Some("project"));
assert_eq!(
user_info.options.get_cache_key("project"),
"project endpoint_type:read_write lsn:0/2"
);
Ok(())
}

View File

@@ -32,7 +32,7 @@ pub struct MetricCollectionConfig {
pub struct TlsConfig {
pub config: Arc<rustls::ServerConfig>,
pub common_names: Option<HashSet<String>>,
pub common_names: HashSet<String>,
pub cert_resolver: Arc<CertResolver>,
}
@@ -97,7 +97,7 @@ pub fn configure_tls(
Ok(TlsConfig {
config,
common_names: Some(common_names),
common_names,
cert_resolver,
})
}

View File

@@ -6,7 +6,7 @@ pub mod messages;
/// Wrappers for console APIs and their mocks.
pub mod provider;
pub use provider::{errors, Api, AuthSecret, CachedNodeInfo, ConsoleReqExtra, NodeInfo};
pub use provider::{errors, Api, AuthSecret, CachedNodeInfo, NodeInfo};
/// Various cache-related types.
pub mod caches {

View File

@@ -197,22 +197,6 @@ pub mod errors {
}
}
/// Extra query params we'd like to pass to the console.
pub struct ConsoleReqExtra {
pub options: Vec<(String, String)>,
}
impl ConsoleReqExtra {
// https://swagger.io/docs/specification/serialization/ DeepObject format
// paramName[prop1]=value1&paramName[prop2]=value2&....
pub fn options_as_deep_object(&self) -> Vec<(String, String)> {
self.options
.iter()
.map(|(k, v)| (format!("options[{}]", k), v.to_string()))
.collect()
}
}
/// Auth secret which is managed by the cloud.
#[derive(Clone, Eq, PartialEq, Debug)]
pub enum AuthSecret {
@@ -249,7 +233,7 @@ pub struct NodeInfo {
pub allow_self_signed_compute: bool,
}
pub type NodeInfoCache = TimedLru<Arc<str>, NodeInfo>;
pub type NodeInfoCache = TimedLru<SmolStr, NodeInfo>;
pub type CachedNodeInfo = Cached<&'static NodeInfoCache>;
pub type CachedRoleSecret = Cached<&'static ProjectInfoCacheImpl, AuthSecret>;
pub type CachedAllowedIps = Cached<&'static ProjectInfoCacheImpl, Arc<Vec<SmolStr>>>;
@@ -277,7 +261,6 @@ pub trait Api {
async fn wake_compute(
&self,
ctx: &mut RequestMonitoring,
extra: &ConsoleReqExtra,
creds: &ComputeUserInfo,
) -> Result<CachedNodeInfo, errors::WakeComputeError>;
}
@@ -310,7 +293,7 @@ impl ApiCaches {
/// Various caches for [`console`](super).
pub struct ApiLocks {
name: &'static str,
node_locks: DashMap<Arc<str>, Arc<Semaphore>>,
node_locks: DashMap<SmolStr, Arc<Semaphore>>,
permits: usize,
timeout: Duration,
registered: prometheus::IntCounter,
@@ -378,7 +361,7 @@ impl ApiLocks {
pub async fn get_wake_compute_permit(
&self,
key: &Arc<str>,
key: &SmolStr,
) -> Result<WakeComputePermit, errors::WakeComputeError> {
if self.permits == 0 {
return Ok(WakeComputePermit { permit: None });

View File

@@ -2,7 +2,7 @@
use super::{
errors::{ApiError, GetAuthInfoError, WakeComputeError},
AuthInfo, AuthSecret, CachedNodeInfo, ConsoleReqExtra, NodeInfo,
AuthInfo, AuthSecret, CachedNodeInfo, NodeInfo,
};
use crate::cache::Cached;
use crate::console::provider::{CachedAllowedIps, CachedRoleSecret};
@@ -50,7 +50,7 @@ impl Api {
async fn do_get_auth_info(
&self,
creds: &ComputeUserInfo,
user_info: &ComputeUserInfo,
) -> Result<AuthInfo, GetAuthInfoError> {
let (secret, allowed_ips) = async {
// Perhaps we could persist this connection, but then we'd have to
@@ -63,7 +63,7 @@ impl Api {
let secret = match get_execute_postgres_query(
&client,
"select rolpassword from pg_catalog.pg_authid where rolname = $1",
&[&&*creds.inner.user],
&[&&*user_info.user],
"rolpassword",
)
.await?
@@ -74,14 +74,14 @@ impl Api {
secret.or_else(|| parse_md5(&entry).map(AuthSecret::Md5))
}
None => {
warn!("user '{}' does not exist", creds.inner.user);
warn!("user '{}' does not exist", user_info.user);
None
}
};
let allowed_ips = match get_execute_postgres_query(
&client,
"select allowed_ips from neon_control_plane.endpoints where endpoint_id = $1",
&[&creds.endpoint.as_str()],
&[&user_info.endpoint.as_str()],
"allowed_ips",
)
.await?
@@ -149,10 +149,10 @@ impl super::Api for Api {
async fn get_role_secret(
&self,
_ctx: &mut RequestMonitoring,
creds: &ComputeUserInfo,
user_info: &ComputeUserInfo,
) -> Result<Option<CachedRoleSecret>, GetAuthInfoError> {
Ok(self
.do_get_auth_info(creds)
.do_get_auth_info(user_info)
.await?
.secret
.map(CachedRoleSecret::new_uncached))
@@ -161,10 +161,10 @@ impl super::Api for Api {
async fn get_allowed_ips(
&self,
_ctx: &mut RequestMonitoring,
creds: &ComputeUserInfo,
user_info: &ComputeUserInfo,
) -> Result<CachedAllowedIps, GetAuthInfoError> {
Ok(Cached::new_uncached(Arc::new(
self.do_get_auth_info(creds).await?.allowed_ips,
self.do_get_auth_info(user_info).await?.allowed_ips,
)))
}
@@ -172,8 +172,7 @@ impl super::Api for Api {
async fn wake_compute(
&self,
_ctx: &mut RequestMonitoring,
_extra: &ConsoleReqExtra,
_creds: &ComputeUserInfo,
_user_info: &ComputeUserInfo,
) -> Result<CachedNodeInfo, WakeComputeError> {
self.do_wake_compute()
.map_ok(CachedNodeInfo::new_uncached)

View File

@@ -4,7 +4,7 @@ use super::{
super::messages::{ConsoleError, GetRoleSecret, WakeCompute},
errors::{ApiError, GetAuthInfoError, WakeComputeError},
ApiCaches, ApiLocks, AuthInfo, AuthSecret, CachedAllowedIps, CachedNodeInfo, CachedRoleSecret,
ConsoleReqExtra, NodeInfo,
NodeInfo,
};
use crate::{auth::backend::ComputeUserInfo, compute, http, scram};
use crate::{
@@ -55,7 +55,7 @@ impl Api {
async fn do_get_auth_info(
&self,
ctx: &mut RequestMonitoring,
creds: &ComputeUserInfo,
user_info: &ComputeUserInfo,
) -> Result<AuthInfo, GetAuthInfoError> {
let request_id = uuid::Uuid::new_v4().to_string();
let application_name = ctx.console_application_name();
@@ -68,8 +68,8 @@ impl Api {
.query(&[("session_id", ctx.session_id)])
.query(&[
("application_name", application_name.as_str()),
("project", creds.endpoint.as_str()),
("role", creds.inner.user.as_str()),
("project", user_info.endpoint.as_str()),
("role", user_info.user.as_str()),
])
.build()?;
@@ -110,8 +110,7 @@ impl Api {
async fn do_wake_compute(
&self,
ctx: &mut RequestMonitoring,
extra: &ConsoleReqExtra,
creds: &ComputeUserInfo,
user_info: &ComputeUserInfo,
) -> Result<NodeInfo, WakeComputeError> {
let request_id = uuid::Uuid::new_v4().to_string();
let application_name = ctx.console_application_name();
@@ -124,14 +123,14 @@ impl Api {
.query(&[("session_id", ctx.session_id)])
.query(&[
("application_name", application_name.as_str()),
("project", creds.endpoint.as_str()),
("project", user_info.endpoint.as_str()),
]);
request_builder = if extra.options.is_empty() {
request_builder
} else {
request_builder.query(&extra.options_as_deep_object())
};
let options = user_info.options.to_deep_object();
if !options.is_empty() {
request_builder = request_builder.query(&options);
}
let request = request_builder.build()?;
info!(url = request.url().as_str(), "sending http request");
@@ -172,14 +171,14 @@ impl super::Api for Api {
async fn get_role_secret(
&self,
ctx: &mut RequestMonitoring,
creds: &ComputeUserInfo,
user_info: &ComputeUserInfo,
) -> Result<Option<CachedRoleSecret>, GetAuthInfoError> {
let ep = &creds.endpoint;
let user = &creds.inner.user;
let ep = &user_info.endpoint;
let user = &user_info.user;
if let Some(role_secret) = self.caches.project_info.get_role_secret(ep, user) {
return Ok(Some(role_secret));
}
let auth_info = self.do_get_auth_info(ctx, creds).await?;
let auth_info = self.do_get_auth_info(ctx, user_info).await?;
let project_id = auth_info.project_id.unwrap_or(ep.clone());
if let Some(secret) = &auth_info.secret {
self.caches
@@ -198,9 +197,9 @@ impl super::Api for Api {
async fn get_allowed_ips(
&self,
ctx: &mut RequestMonitoring,
creds: &ComputeUserInfo,
user_info: &ComputeUserInfo,
) -> Result<CachedAllowedIps, GetAuthInfoError> {
let ep = &creds.endpoint;
let ep = &user_info.endpoint;
if let Some(allowed_ips) = self.caches.project_info.get_allowed_ips(ep) {
ALLOWED_IPS_BY_CACHE_OUTCOME
.with_label_values(&["hit"])
@@ -210,9 +209,9 @@ impl super::Api for Api {
ALLOWED_IPS_BY_CACHE_OUTCOME
.with_label_values(&["miss"])
.inc();
let auth_info = self.do_get_auth_info(ctx, creds).await?;
let auth_info = self.do_get_auth_info(ctx, user_info).await?;
let allowed_ips = Arc::new(auth_info.allowed_ips);
let user = &creds.inner.user;
let user = &user_info.user;
let project_id = auth_info.project_id.unwrap_or(ep.clone());
if let Some(secret) = &auth_info.secret {
self.caches
@@ -229,22 +228,19 @@ impl super::Api for Api {
async fn wake_compute(
&self,
ctx: &mut RequestMonitoring,
extra: &ConsoleReqExtra,
creds: &ComputeUserInfo,
user_info: &ComputeUserInfo,
) -> Result<CachedNodeInfo, WakeComputeError> {
let key: &str = &creds.inner.cache_key;
let key = user_info.endpoint_cache_key();
// Every time we do a wakeup http request, the compute node will stay up
// for some time (highly depends on the console's scale-to-zero policy);
// The connection info remains the same during that period of time,
// which means that we might cache it to reduce the load and latency.
if let Some(cached) = self.caches.node_info.get(key) {
info!(key = key, "found cached compute node info");
if let Some(cached) = self.caches.node_info.get(&*key) {
info!(key = &*key, "found cached compute node info");
return Ok(cached);
}
let key: Arc<str> = key.into();
let permit = self.locks.get_wake_compute_permit(&key).await?;
// after getting back a permit - it's possible the cache was filled
@@ -256,7 +252,7 @@ impl super::Api for Api {
}
}
let node = self.do_wake_compute(ctx, extra, creds).await?;
let node = self.do_wake_compute(ctx, user_info).await?;
let (_, cached) = self.caches.node_info.insert(key.clone(), node);
info!(key = &*key, "created a cache entry for compute node info");

View File

@@ -9,7 +9,7 @@ use crate::{
cancellation::{self, CancelMap},
compute,
config::{AuthenticationConfig, ProxyConfig, TlsConfig},
console::{self, messages::MetricsAuxInfo},
console::messages::MetricsAuxInfo,
context::RequestMonitoring,
metrics::{
NUM_BYTES_PROXIED_COUNTER, NUM_BYTES_PROXIED_PER_CLIENT_COUNTER,
@@ -26,6 +26,7 @@ use itertools::Itertools;
use once_cell::sync::OnceCell;
use pq_proto::{BeMessage as Be, FeStartupPacket, StartupMessageParams};
use regex::Regex;
use smol_str::SmolStr;
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio_util::sync::CancellationToken;
@@ -198,27 +199,29 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
drop(pause);
// Extract credentials which we're going to use for auth.
let creds = {
let user_info = {
let hostname = mode.hostname(stream.get_ref());
let common_names = tls.and_then(|tls| tls.common_names.clone());
let common_names = tls.map(|tls| &tls.common_names);
let result = config
.auth_backend
.as_ref()
.map(|_| auth::ClientCredentials::parse(ctx, &params, hostname, common_names))
.map(|_| {
auth::ComputeUserInfoMaybeEndpoint::parse(ctx, &params, hostname, common_names)
})
.transpose();
match result {
Ok(creds) => creds,
Ok(user_info) => user_info,
Err(e) => stream.throw_error(e).await?,
}
};
ctx.set_endpoint_id(creds.get_endpoint());
ctx.set_endpoint_id(user_info.get_endpoint());
let client = Client::new(
stream,
creds,
user_info,
&params,
mode.allow_self_signed_compute(config),
endpoint_rate_limiter,
@@ -397,7 +400,7 @@ struct Client<'a, S> {
/// The underlying libpq protocol stream.
stream: PqStream<Stream<S>>,
/// Client credentials that we care about.
creds: auth::BackendType<'a, auth::ClientCredentials>,
user_info: auth::BackendType<'a, auth::ComputeUserInfoMaybeEndpoint>,
/// KV-dictionary with PostgreSQL connection params.
params: &'a StartupMessageParams,
/// Allow self-signed certificates (for testing).
@@ -410,14 +413,14 @@ impl<'a, S> Client<'a, S> {
/// Construct a new connection context.
fn new(
stream: PqStream<Stream<S>>,
creds: auth::BackendType<'a, auth::ClientCredentials>,
user_info: auth::BackendType<'a, auth::ComputeUserInfoMaybeEndpoint>,
params: &'a StartupMessageParams,
allow_self_signed_compute: bool,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> Self {
Self {
stream,
creds,
user_info,
params,
allow_self_signed_compute,
endpoint_rate_limiter,
@@ -429,7 +432,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> Client<'_, S> {
/// Let the client authenticate and connect to the designated compute node.
// Instrumentation logs endpoint name everywhere. Doesn't work for link
// auth; strictly speaking we don't know endpoint name in its case.
#[tracing::instrument(name = "", fields(ep = %self.creds.get_endpoint().unwrap_or_default()), skip_all)]
#[tracing::instrument(name = "", fields(ep = %self.user_info.get_endpoint().unwrap_or_default()), skip_all)]
async fn connect_to_db(
self,
ctx: &mut RequestMonitoring,
@@ -439,14 +442,14 @@ impl<S: AsyncRead + AsyncWrite + Unpin> Client<'_, S> {
) -> anyhow::Result<()> {
let Self {
mut stream,
creds,
user_info,
params,
allow_self_signed_compute,
endpoint_rate_limiter,
} = self;
// check rate limit
if let Some(ep) = creds.get_endpoint() {
if let Some(ep) = user_info.get_endpoint() {
if !endpoint_rate_limiter.check(ep) {
return stream
.throw_error(auth::AuthError::too_many_connections())
@@ -454,13 +457,9 @@ impl<S: AsyncRead + AsyncWrite + Unpin> Client<'_, S> {
}
}
let extra = console::ConsoleReqExtra {
options: neon_options(params),
};
let user = creds.get_user().to_owned();
let auth_result = match creds
.authenticate(ctx, &extra, &mut stream, mode.allow_cleartext(), config)
let user = user_info.get_user().to_owned();
let auth_result = match user_info
.authenticate(ctx, &mut stream, mode.allow_cleartext(), config)
.await
{
Ok(auth_result) => auth_result,
@@ -473,12 +472,12 @@ impl<S: AsyncRead + AsyncWrite + Unpin> Client<'_, S> {
}
};
let (mut node_info, creds) = auth_result;
let (mut node_info, user_info) = auth_result;
node_info.allow_self_signed_compute = allow_self_signed_compute;
let aux = node_info.aux.clone();
let mut node = connect_to_compute(ctx, &TcpMechanism { params }, node_info, &extra, &creds)
let mut node = connect_to_compute(ctx, &TcpMechanism { params }, node_info, &user_info)
.or_else(|e| stream.throw_error(e))
.await?;
@@ -493,29 +492,52 @@ impl<S: AsyncRead + AsyncWrite + Unpin> Client<'_, S> {
}
}
pub fn neon_options(params: &StartupMessageParams) -> Vec<(String, String)> {
#[allow(unstable_name_collisions)]
match params.options_raw() {
Some(options) => options.filter_map(neon_option).collect(),
None => vec![],
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct NeonOptions(Vec<(SmolStr, SmolStr)>);
impl NeonOptions {
pub fn parse_params(params: &StartupMessageParams) -> Self {
params
.options_raw()
.map(Self::parse_from_iter)
.unwrap_or_default()
}
pub fn parse_options_raw(options: &str) -> Self {
Self::parse_from_iter(StartupMessageParams::parse_options_raw(options))
}
fn parse_from_iter<'a>(options: impl Iterator<Item = &'a str>) -> Self {
let mut options = options
.filter_map(neon_option)
.map(|(k, v)| (k.into(), v.into()))
.collect_vec();
options.sort();
Self(options)
}
pub fn get_cache_key(&self, prefix: &str) -> SmolStr {
// prefix + format!(" {k}:{v}")
// kinda jank because SmolStr is immutable
std::iter::once(prefix)
.chain(self.0.iter().flat_map(|(k, v)| [" ", &**k, ":", &**v]))
.collect()
}
/// <https://swagger.io/docs/specification/serialization/> DeepObject format
/// `paramName[prop1]=value1&paramName[prop2]=value2&...`
pub fn to_deep_object(&self) -> Vec<(String, SmolStr)> {
self.0
.iter()
.map(|(k, v)| (format!("options[{}]", k), v.clone()))
.collect()
}
}
pub fn neon_options_str(params: &StartupMessageParams) -> String {
#[allow(unstable_name_collisions)]
neon_options(params)
.iter()
.map(|(k, v)| format!("{}:{}", k, v))
.sorted() // we sort it to use as cache key
.intersperse(" ".to_owned())
.collect()
}
pub fn neon_option(bytes: &str) -> Option<(String, String)> {
pub fn neon_option(bytes: &str) -> Option<(&str, &str)> {
static RE: OnceCell<Regex> = OnceCell::new();
let re = RE.get_or_init(|| Regex::new(r"^neon_(\w+):(.+)").unwrap());
let cap = re.captures(bytes)?;
let (_, [k, v]) = cap.extract();
Some((k.to_owned(), v.to_owned()))
Some((k, v))
}

View File

@@ -128,8 +128,7 @@ pub async fn connect_to_compute<M: ConnectMechanism>(
ctx: &mut RequestMonitoring,
mechanism: &M,
mut node_info: console::CachedNodeInfo,
extra: &console::ConsoleReqExtra,
creds: &auth::BackendType<'_, auth::backend::ComputeUserInfo>,
user_info: &auth::BackendType<'_, auth::backend::ComputeUserInfo>,
) -> Result<M::Connection, M::Error>
where
M::ConnectError: ShouldRetry + std::fmt::Debug,
@@ -159,10 +158,10 @@ where
// if we failed to connect, it's likely that the compute node was suspended, wake a new compute node
info!("compute node's state has likely changed; requesting a wake-up");
let node_info = loop {
let wake_res = match creds {
auth::BackendType::Console(api, creds) => api.wake_compute(ctx, extra, creds).await,
let wake_res = match user_info {
auth::BackendType::Console(api, user_info) => api.wake_compute(ctx, user_info).await,
#[cfg(feature = "testing")]
auth::BackendType::Postgres(api, creds) => api.wake_compute(ctx, extra, creds).await,
auth::BackendType::Postgres(api, user_info) => api.wake_compute(ctx, user_info).await,
// nothing to do?
auth::BackendType::Link(_) => return Err(err.into()),
// test backend

View File

@@ -7,7 +7,7 @@ use super::retry::ShouldRetry;
use super::*;
use crate::auth::backend::{ComputeUserInfo, TestBackend};
use crate::config::CertResolver;
use crate::console::{CachedNodeInfo, NodeInfo};
use crate::console::{self, CachedNodeInfo, NodeInfo};
use crate::proxy::retry::{retry_after, NUM_RETRIES_CONNECT};
use crate::{auth, http, sasl, scram};
use async_trait::async_trait;
@@ -83,7 +83,7 @@ fn generate_tls_config<'a>(
let mut cert_resolver = CertResolver::new();
cert_resolver.add_cert(key, vec![cert], true)?;
let common_names = Some(cert_resolver.get_common_names());
let common_names = cert_resolver.get_common_names();
TlsConfig {
config,
@@ -487,15 +487,10 @@ fn helper_create_cached_node_info() -> CachedNodeInfo {
fn helper_create_connect_info(
mechanism: &TestConnectMechanism,
) -> (
CachedNodeInfo,
console::ConsoleReqExtra,
auth::BackendType<'_, ComputeUserInfo>,
) {
) -> (CachedNodeInfo, auth::BackendType<'_, ComputeUserInfo>) {
let cache = helper_create_cached_node_info();
let extra = console::ConsoleReqExtra { options: vec![] };
let creds = auth::BackendType::Test(mechanism);
(cache, extra, creds)
let user_info = auth::BackendType::Test(mechanism);
(cache, user_info)
}
#[tokio::test]
@@ -503,8 +498,8 @@ async fn connect_to_compute_success() {
use ConnectAction::*;
let mut ctx = RequestMonitoring::test();
let mechanism = TestConnectMechanism::new(vec![Connect]);
let (cache, extra, creds) = helper_create_connect_info(&mechanism);
connect_to_compute(&mut ctx, &mechanism, cache, &extra, &creds)
let (cache, user_info) = helper_create_connect_info(&mechanism);
connect_to_compute(&mut ctx, &mechanism, cache, &user_info)
.await
.unwrap();
mechanism.verify();
@@ -515,8 +510,8 @@ async fn connect_to_compute_retry() {
use ConnectAction::*;
let mut ctx = RequestMonitoring::test();
let mechanism = TestConnectMechanism::new(vec![Retry, Wake, Retry, Connect]);
let (cache, extra, creds) = helper_create_connect_info(&mechanism);
connect_to_compute(&mut ctx, &mechanism, cache, &extra, &creds)
let (cache, user_info) = helper_create_connect_info(&mechanism);
connect_to_compute(&mut ctx, &mechanism, cache, &user_info)
.await
.unwrap();
mechanism.verify();
@@ -528,8 +523,8 @@ async fn connect_to_compute_non_retry_1() {
use ConnectAction::*;
let mut ctx = RequestMonitoring::test();
let mechanism = TestConnectMechanism::new(vec![Retry, Wake, Retry, Fail]);
let (cache, extra, creds) = helper_create_connect_info(&mechanism);
connect_to_compute(&mut ctx, &mechanism, cache, &extra, &creds)
let (cache, user_info) = helper_create_connect_info(&mechanism);
connect_to_compute(&mut ctx, &mechanism, cache, &user_info)
.await
.unwrap_err();
mechanism.verify();
@@ -541,8 +536,8 @@ async fn connect_to_compute_non_retry_2() {
use ConnectAction::*;
let mut ctx = RequestMonitoring::test();
let mechanism = TestConnectMechanism::new(vec![Fail, Wake, Retry, Connect]);
let (cache, extra, creds) = helper_create_connect_info(&mechanism);
connect_to_compute(&mut ctx, &mechanism, cache, &extra, &creds)
let (cache, user_info) = helper_create_connect_info(&mechanism);
connect_to_compute(&mut ctx, &mechanism, cache, &user_info)
.await
.unwrap();
mechanism.verify();
@@ -558,8 +553,8 @@ async fn connect_to_compute_non_retry_3() {
Retry, Wake, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry,
Retry, Retry, Retry, Retry, /* the 17th time */ Retry,
]);
let (cache, extra, creds) = helper_create_connect_info(&mechanism);
connect_to_compute(&mut ctx, &mechanism, cache, &extra, &creds)
let (cache, user_info) = helper_create_connect_info(&mechanism);
connect_to_compute(&mut ctx, &mechanism, cache, &user_info)
.await
.unwrap_err();
mechanism.verify();
@@ -571,8 +566,8 @@ async fn wake_retry() {
use ConnectAction::*;
let mut ctx = RequestMonitoring::test();
let mechanism = TestConnectMechanism::new(vec![Retry, WakeRetry, Wake, Connect]);
let (cache, extra, creds) = helper_create_connect_info(&mechanism);
connect_to_compute(&mut ctx, &mechanism, cache, &extra, &creds)
let (cache, user_info) = helper_create_connect_info(&mechanism);
connect_to_compute(&mut ctx, &mechanism, cache, &user_info)
.await
.unwrap();
mechanism.verify();
@@ -584,8 +579,8 @@ async fn wake_non_retry() {
use ConnectAction::*;
let mut ctx = RequestMonitoring::test();
let mechanism = TestConnectMechanism::new(vec![Retry, WakeFail]);
let (cache, extra, creds) = helper_create_connect_info(&mechanism);
connect_to_compute(&mut ctx, &mechanism, cache, &extra, &creds)
let (cache, user_info) = helper_create_connect_info(&mechanism);
connect_to_compute(&mut ctx, &mechanism, cache, &user_info)
.await
.unwrap_err();
mechanism.verify();

View File

@@ -17,6 +17,7 @@ pub use reqwest_middleware::{ClientWithMiddleware, Error};
pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
use tokio_util::task::TaskTracker;
use crate::config::TlsConfig;
use crate::context::RequestMonitoring;
use crate::metrics::NUM_CLIENT_CONNECTION_GAUGE;
use crate::protocol2::{ProxyProtocolAccept, WithClientIp};
@@ -69,14 +70,14 @@ pub async fn task_main(
}
});
let tls_config = config.tls_config.as_ref().map(|cfg| cfg.to_server_config());
let tls_acceptor: tokio_rustls::TlsAcceptor = match tls_config {
Some(config) => config.into(),
let tls_config = match config.tls_config.as_ref() {
Some(config) => config,
None => {
warn!("TLS config is missing, WebSocket Secure server will not be started");
return Ok(());
}
};
let tls_acceptor: tokio_rustls::TlsAcceptor = tls_config.to_server_config().into();
let mut addr_incoming = AddrIncoming::from_listener(ws_listener)?;
let _ = addr_incoming.set_nodelay(true);
@@ -126,6 +127,7 @@ pub async fn task_main(
request_handler(
req,
config,
tls_config,
conn_pool,
ws_connections,
cancel_map,
@@ -195,6 +197,7 @@ where
async fn request_handler(
mut request: Request<Body>,
config: &'static ProxyConfig,
tls: &'static TlsConfig,
conn_pool: Arc<conn_pool::GlobalConnPool>,
ws_connections: TaskTracker,
cancel_map: Arc<CancelMap>,
@@ -243,6 +246,7 @@ async fn request_handler(
let mut ctx = RequestMonitoring::new(session_id, peer_addr, "http", &config.region);
sql_over_http::handle(
tls,
&config.http_config,
&mut ctx,
request,

View File

@@ -1,4 +1,4 @@
use anyhow::{anyhow, Context};
use anyhow::Context;
use async_trait::async_trait;
use dashmap::DashMap;
use futures::{future::poll_fn, Future};
@@ -9,7 +9,6 @@ use pbkdf2::{
password_hash::{PasswordHashString, PasswordHasher, PasswordVerifier, SaltString},
Params, Pbkdf2,
};
use pq_proto::StartupMessageParams;
use prometheus::{exponential_buckets, register_histogram, Histogram};
use rand::Rng;
use smol_str::SmolStr;
@@ -30,7 +29,7 @@ use crate::{
console,
context::RequestMonitoring,
metrics::NUM_DB_CONNECTIONS_GAUGE,
proxy::{connect_compute::ConnectMechanism, neon_options},
proxy::connect_compute::ConnectMechanism,
usage_metrics::{Ids, MetricCounter, USAGE_METRICS},
};
use crate::{compute, config};
@@ -38,28 +37,37 @@ use crate::{compute, config};
use tracing::{debug, error, warn, Span};
use tracing::{info, info_span, Instrument};
pub const APP_NAME: &str = "/sql_over_http";
pub const APP_NAME: SmolStr = SmolStr::new_inline("/sql_over_http");
#[derive(Debug, Clone)]
pub struct ConnInfo {
pub username: SmolStr,
pub user_info: ComputeUserInfo,
pub dbname: SmolStr,
pub hostname: SmolStr,
pub password: SmolStr,
pub options: Option<SmolStr>,
}
impl ConnInfo {
// hm, change to hasher to avoid cloning?
pub fn db_and_user(&self) -> (SmolStr, SmolStr) {
(self.dbname.clone(), self.username.clone())
(self.dbname.clone(), self.user_info.user.clone())
}
pub fn endpoint_cache_key(&self) -> SmolStr {
self.user_info.endpoint_cache_key()
}
}
impl fmt::Display for ConnInfo {
// use custom display to avoid logging password
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}@{}/{}", self.username, self.hostname, self.dbname)
write!(
f,
"{}@{}/{}?{}",
self.user_info.user,
self.user_info.endpoint,
self.dbname,
self.user_info.options.get_cache_key("")
)
}
}
@@ -319,7 +327,7 @@ impl GlobalConnPool {
let mut hash_valid = false;
let mut endpoint_pool = Weak::new();
if !force_new {
let pool = self.get_or_create_endpoint_pool(&conn_info.hostname);
let pool = self.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key());
endpoint_pool = Arc::downgrade(&pool);
let mut hash = None;
@@ -401,7 +409,7 @@ impl GlobalConnPool {
Err(err)
if hash_valid && err.to_string().contains("password authentication failed") =>
{
let pool = self.get_or_create_endpoint_pool(&conn_info.hostname);
let pool = self.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key());
let mut pool = pool.write();
if let Some(entry) = pool.pools.get_mut(&conn_info.db_and_user()) {
entry.password_hash = None;
@@ -418,7 +426,7 @@ impl GlobalConnPool {
})
.await??;
let pool = self.get_or_create_endpoint_pool(&conn_info.hostname);
let pool = self.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key());
let mut pool = pool.write();
pool.pools
.entry(conn_info.db_and_user())
@@ -520,23 +528,11 @@ async fn connect_to_compute(
conn_id: uuid::Uuid,
pool: Weak<RwLock<EndpointConnPool>>,
) -> anyhow::Result<ClientInner> {
let tls = config.tls_config.as_ref();
let common_names = tls.and_then(|tls| tls.common_names.clone());
let params = StartupMessageParams::new([
("user", &conn_info.username),
("database", &conn_info.dbname),
("application_name", APP_NAME),
("options", conn_info.options.as_deref().unwrap_or("")),
]);
let creds =
auth::ClientCredentials::parse(ctx, &params, Some(&conn_info.hostname), common_names)?;
let creds =
ComputeUserInfo::try_from(creds).map_err(|_| anyhow!("missing endpoint identifier"))?;
let backend = config.auth_backend.as_ref().map(|_| creds);
let console_options = neon_options(&params);
ctx.set_application(Some(APP_NAME));
let backend = config
.auth_backend
.as_ref()
.map(|_| conn_info.user_info.clone());
if !config.disable_ip_check_for_http {
let allowed_ips = backend.get_allowed_ips(ctx).await?;
@@ -544,11 +540,8 @@ async fn connect_to_compute(
return Err(auth::AuthError::ip_address_not_allowed().into());
}
}
let extra = console::ConsoleReqExtra {
options: console_options,
};
let node_info = backend
.wake_compute(ctx, &extra)
.wake_compute(ctx)
.await?
.context("missing cache entry from wake_compute")?;
@@ -563,7 +556,6 @@ async fn connect_to_compute(
idle: config.http_config.pool_options.idle_timeout,
},
node_info,
&extra,
&backend,
)
.await
@@ -582,7 +574,7 @@ async fn connect_to_compute_once(
let mut session = ctx.session_id;
let (client, mut connection) = config
.user(&conn_info.username)
.user(&conn_info.user_info.user)
.password(&*conn_info.password)
.dbname(&conn_info.dbname)
.connect_timeout(timeout)

View File

@@ -28,9 +28,13 @@ use url::Url;
use utils::http::error::ApiError;
use utils::http::json::json_response;
use crate::auth::backend::ComputeUserInfo;
use crate::auth::endpoint_sni;
use crate::config::HttpConfig;
use crate::config::TlsConfig;
use crate::context::RequestMonitoring;
use crate::metrics::NUM_CONNECTION_REQUESTS_GAUGE;
use crate::proxy::NeonOptions;
use super::conn_pool::ConnInfo;
use super::conn_pool::GlobalConnPool;
@@ -125,6 +129,7 @@ fn get_conn_info(
ctx: &mut RequestMonitoring,
headers: &HeaderMap,
sni_hostname: Option<String>,
tls: &TlsConfig,
) -> Result<ConnInfo, anyhow::Error> {
let connection_string = headers
.get("Neon-Connection-String")
@@ -179,8 +184,10 @@ fn get_conn_info(
}
}
let hostname: SmolStr = hostname.into();
ctx.set_endpoint_id(Some(hostname.clone()));
let endpoint = endpoint_sni(hostname, &tls.common_names)?;
let endpoint: SmolStr = endpoint.into();
ctx.set_endpoint_id(Some(endpoint.clone()));
let pairs = connection_url.query_pairs();
@@ -188,22 +195,27 @@ fn get_conn_info(
for (key, value) in pairs {
if key == "options" {
options = Some(value.into());
options = Some(NeonOptions::parse_options_raw(&value));
break;
}
}
let user_info = ComputeUserInfo {
endpoint,
user: username,
options: options.unwrap_or_default(),
};
Ok(ConnInfo {
username,
user_info,
dbname: dbname.into(),
hostname,
password: password.into(),
options,
})
}
// TODO: return different http error codes
pub async fn handle(
tls: &'static TlsConfig,
config: &'static HttpConfig,
ctx: &mut RequestMonitoring,
request: Request<Body>,
@@ -212,7 +224,7 @@ pub async fn handle(
) -> Result<Response<Body>, ApiError> {
let result = tokio::time::timeout(
config.request_timeout,
handle_inner(config, ctx, request, sni_hostname, conn_pool),
handle_inner(tls, config, ctx, request, sni_hostname, conn_pool),
)
.await;
let mut response = match result {
@@ -294,6 +306,7 @@ pub async fn handle(
#[instrument(name = "sql-over-http", fields(pid = tracing::field::Empty), skip_all)]
async fn handle_inner(
tls: &'static TlsConfig,
config: &'static HttpConfig,
ctx: &mut RequestMonitoring,
request: Request<Body>,
@@ -308,7 +321,7 @@ async fn handle_inner(
// Determine the destination and connection params
//
let headers = request.headers();
let conn_info = get_conn_info(ctx, headers, sni_hostname)?;
let conn_info = get_conn_info(ctx, headers, sni_hostname, tls)?;
// Determine the output options. Default behaviour is 'false'. Anything that is not
// strictly 'true' assumed to be false.