mirror of
https://github.com/neondatabase/neon.git
synced 2026-06-04 22:10:39 +00:00
proxy: refactor how neon-options are handled (#6306)
## Problem HTTP connection pool was not respecting the PitR options. ## Summary of changes 1. refactor neon_options a bit to allow easier access to cache_key 2. make HTTP not go through `StartupMessageParams` 3. expose SNI processing to replace what was removed in step 2.
This commit is contained in:
@@ -4,7 +4,7 @@ pub mod backend;
|
||||
pub use backend::BackendType;
|
||||
|
||||
mod credentials;
|
||||
pub use credentials::{check_peer_addr_is_in_list, ClientCredentials};
|
||||
pub use credentials::{check_peer_addr_is_in_list, endpoint_sni, ComputeUserInfoMaybeEndpoint};
|
||||
|
||||
mod password_hack;
|
||||
pub use password_hack::parse_endpoint_param;
|
||||
|
||||
@@ -14,14 +14,15 @@ use crate::console::AuthSecret;
|
||||
use crate::context::RequestMonitoring;
|
||||
use crate::proxy::connect_compute::handle_try_wake;
|
||||
use crate::proxy::retry::retry_after;
|
||||
use crate::proxy::NeonOptions;
|
||||
use crate::scram;
|
||||
use crate::stream::Stream;
|
||||
use crate::{
|
||||
auth::{self, ClientCredentials},
|
||||
auth::{self, ComputeUserInfoMaybeEndpoint},
|
||||
config::AuthenticationConfig,
|
||||
console::{
|
||||
self,
|
||||
provider::{CachedAllowedIps, CachedNodeInfo, ConsoleReqExtra},
|
||||
provider::{CachedAllowedIps, CachedNodeInfo},
|
||||
Api,
|
||||
},
|
||||
stream, url,
|
||||
@@ -38,7 +39,7 @@ use tracing::{error, info, warn};
|
||||
/// * When `T` is `()`, it's just a regular auth backend selector
|
||||
/// which we use in [`crate::config::ProxyConfig`].
|
||||
///
|
||||
/// * However, when we substitute `T` with [`ClientCredentials`],
|
||||
/// * However, when we substitute `T` with [`ComputeUserInfoMaybeEndpoint`],
|
||||
/// this helps us provide the credentials only to those auth
|
||||
/// backends which require them for the authentication process.
|
||||
pub enum BackendType<'a, T> {
|
||||
@@ -127,14 +128,23 @@ pub struct ComputeCredentials<T> {
|
||||
pub keys: T,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ComputeUserInfoNoEndpoint {
|
||||
pub user: SmolStr,
|
||||
pub cache_key: SmolStr,
|
||||
pub options: NeonOptions,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ComputeUserInfo {
|
||||
pub endpoint: SmolStr,
|
||||
pub inner: ComputeUserInfoNoEndpoint,
|
||||
pub user: SmolStr,
|
||||
pub options: NeonOptions,
|
||||
}
|
||||
|
||||
impl ComputeUserInfo {
|
||||
pub fn endpoint_cache_key(&self) -> SmolStr {
|
||||
self.options.get_cache_key(&self.endpoint)
|
||||
}
|
||||
}
|
||||
|
||||
pub enum ComputeCredentialKeys {
|
||||
@@ -143,18 +153,21 @@ pub enum ComputeCredentialKeys {
|
||||
AuthKeys(AuthKeys),
|
||||
}
|
||||
|
||||
impl TryFrom<ClientCredentials> for ComputeUserInfo {
|
||||
impl TryFrom<ComputeUserInfoMaybeEndpoint> for ComputeUserInfo {
|
||||
// user name
|
||||
type Error = ComputeUserInfoNoEndpoint;
|
||||
|
||||
fn try_from(creds: ClientCredentials) -> Result<Self, Self::Error> {
|
||||
let inner = ComputeUserInfoNoEndpoint {
|
||||
user: creds.user,
|
||||
cache_key: creds.cache_key,
|
||||
};
|
||||
match creds.project {
|
||||
None => Err(inner),
|
||||
Some(endpoint) => Ok(ComputeUserInfo { endpoint, inner }),
|
||||
fn try_from(user_info: ComputeUserInfoMaybeEndpoint) -> Result<Self, Self::Error> {
|
||||
match user_info.project {
|
||||
None => Err(ComputeUserInfoNoEndpoint {
|
||||
user: user_info.user,
|
||||
options: user_info.options,
|
||||
}),
|
||||
Some(endpoint) => Ok(ComputeUserInfo {
|
||||
endpoint,
|
||||
user: user_info.user,
|
||||
options: user_info.options,
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -166,7 +179,7 @@ impl TryFrom<ClientCredentials> for ComputeUserInfo {
|
||||
async fn auth_quirks(
|
||||
ctx: &mut RequestMonitoring,
|
||||
api: &impl console::Api,
|
||||
creds: ClientCredentials,
|
||||
user_info: ComputeUserInfoMaybeEndpoint,
|
||||
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
allow_cleartext: bool,
|
||||
config: &'static AuthenticationConfig,
|
||||
@@ -174,7 +187,7 @@ async fn auth_quirks(
|
||||
// If there's no project so far, that entails that client doesn't
|
||||
// support SNI or other means of passing the endpoint (project) name.
|
||||
// We now expect to see a very specific payload in the place of password.
|
||||
let (info, unauthenticated_password) = match creds.try_into() {
|
||||
let (info, unauthenticated_password) = match user_info.try_into() {
|
||||
Err(info) => {
|
||||
let res = hacks::password_hack_no_authentication(info, client, &mut ctx.latency_timer)
|
||||
.await?;
|
||||
@@ -199,7 +212,7 @@ async fn auth_quirks(
|
||||
// This mocked secret will never lead to successful authentication.
|
||||
info!("authentication info not found, mocking it");
|
||||
Cached::new_uncached(AuthSecret::Scram(scram::ServerSecret::mock(
|
||||
&info.inner.user,
|
||||
&info.user,
|
||||
rand::random(),
|
||||
)))
|
||||
});
|
||||
@@ -240,7 +253,7 @@ async fn authenticate_with_secret(
|
||||
crate::sasl::Outcome::Success(key) => key,
|
||||
crate::sasl::Outcome::Failure(reason) => {
|
||||
info!("auth backend failed with an error: {reason}");
|
||||
return Err(auth::AuthError::auth_failed(&*info.inner.user));
|
||||
return Err(auth::AuthError::auth_failed(&*info.user));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -267,19 +280,17 @@ async fn authenticate_with_secret(
|
||||
async fn auth_and_wake_compute(
|
||||
ctx: &mut RequestMonitoring,
|
||||
api: &impl console::Api,
|
||||
extra: &ConsoleReqExtra,
|
||||
creds: ClientCredentials,
|
||||
user_info: ComputeUserInfoMaybeEndpoint,
|
||||
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
allow_cleartext: bool,
|
||||
config: &'static AuthenticationConfig,
|
||||
) -> auth::Result<(CachedNodeInfo, ComputeUserInfo)> {
|
||||
let compute_credentials = auth_quirks(ctx, api, creds, client, allow_cleartext, config).await?;
|
||||
let compute_credentials =
|
||||
auth_quirks(ctx, api, user_info, client, allow_cleartext, config).await?;
|
||||
|
||||
let mut num_retries = 0;
|
||||
let mut node = loop {
|
||||
let wake_res = api
|
||||
.wake_compute(ctx, extra, &compute_credentials.info)
|
||||
.await;
|
||||
let wake_res = api.wake_compute(ctx, &compute_credentials.info).await;
|
||||
match handle_try_wake(wake_res, num_retries) {
|
||||
Err(e) => {
|
||||
error!(error = ?e, num_retries, retriable = false, "couldn't wake compute node");
|
||||
@@ -307,15 +318,15 @@ async fn auth_and_wake_compute(
|
||||
Ok((node, compute_credentials.info))
|
||||
}
|
||||
|
||||
impl<'a> BackendType<'a, ClientCredentials> {
|
||||
impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> {
|
||||
/// Get compute endpoint name from the credentials.
|
||||
pub fn get_endpoint(&self) -> Option<SmolStr> {
|
||||
use BackendType::*;
|
||||
|
||||
match self {
|
||||
Console(_, creds) => creds.project.clone(),
|
||||
Console(_, user_info) => user_info.project.clone(),
|
||||
#[cfg(feature = "testing")]
|
||||
Postgres(_, creds) => creds.project.clone(),
|
||||
Postgres(_, user_info) => user_info.project.clone(),
|
||||
Link(_) => Some("link".into()),
|
||||
#[cfg(test)]
|
||||
Test(_) => Some("test".into()),
|
||||
@@ -327,9 +338,9 @@ impl<'a> BackendType<'a, ClientCredentials> {
|
||||
use BackendType::*;
|
||||
|
||||
match self {
|
||||
Console(_, creds) => &creds.user,
|
||||
Console(_, user_info) => &user_info.user,
|
||||
#[cfg(feature = "testing")]
|
||||
Postgres(_, creds) => &creds.user,
|
||||
Postgres(_, user_info) => &user_info.user,
|
||||
Link(_) => "link",
|
||||
#[cfg(test)]
|
||||
Test(_) => "test",
|
||||
@@ -341,7 +352,6 @@ impl<'a> BackendType<'a, ClientCredentials> {
|
||||
pub async fn authenticate(
|
||||
self,
|
||||
ctx: &mut RequestMonitoring,
|
||||
extra: &ConsoleReqExtra,
|
||||
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
allow_cleartext: bool,
|
||||
config: &'static AuthenticationConfig,
|
||||
@@ -349,43 +359,29 @@ impl<'a> BackendType<'a, ClientCredentials> {
|
||||
use BackendType::*;
|
||||
|
||||
let res = match self {
|
||||
Console(api, creds) => {
|
||||
Console(api, user_info) => {
|
||||
info!(
|
||||
user = &*creds.user,
|
||||
project = creds.project(),
|
||||
user = &*user_info.user,
|
||||
project = user_info.project(),
|
||||
"performing authentication using the console"
|
||||
);
|
||||
|
||||
let (cache_info, user_info) = auth_and_wake_compute(
|
||||
ctx,
|
||||
&*api,
|
||||
extra,
|
||||
creds,
|
||||
client,
|
||||
allow_cleartext,
|
||||
config,
|
||||
)
|
||||
.await?;
|
||||
let (cache_info, user_info) =
|
||||
auth_and_wake_compute(ctx, &*api, user_info, client, allow_cleartext, config)
|
||||
.await?;
|
||||
(cache_info, BackendType::Console(api, user_info))
|
||||
}
|
||||
#[cfg(feature = "testing")]
|
||||
Postgres(api, creds) => {
|
||||
Postgres(api, user_info) => {
|
||||
info!(
|
||||
user = &*creds.user,
|
||||
project = creds.project(),
|
||||
user = &*user_info.user,
|
||||
project = user_info.project(),
|
||||
"performing authentication using a local postgres instance"
|
||||
);
|
||||
|
||||
let (cache_info, user_info) = auth_and_wake_compute(
|
||||
ctx,
|
||||
&*api,
|
||||
extra,
|
||||
creds,
|
||||
client,
|
||||
allow_cleartext,
|
||||
config,
|
||||
)
|
||||
.await?;
|
||||
let (cache_info, user_info) =
|
||||
auth_and_wake_compute(ctx, &*api, user_info, client, allow_cleartext, config)
|
||||
.await?;
|
||||
(cache_info, BackendType::Postgres(api, user_info))
|
||||
}
|
||||
// NOTE: this auth backend doesn't use client credentials.
|
||||
@@ -417,9 +413,9 @@ impl BackendType<'_, ComputeUserInfo> {
|
||||
) -> Result<CachedAllowedIps, GetAuthInfoError> {
|
||||
use BackendType::*;
|
||||
match self {
|
||||
Console(api, creds) => api.get_allowed_ips(ctx, creds).await,
|
||||
Console(api, user_info) => api.get_allowed_ips(ctx, user_info).await,
|
||||
#[cfg(feature = "testing")]
|
||||
Postgres(api, creds) => api.get_allowed_ips(ctx, creds).await,
|
||||
Postgres(api, user_info) => api.get_allowed_ips(ctx, user_info).await,
|
||||
Link(_) => Ok(Cached::new_uncached(Arc::new(vec![]))),
|
||||
#[cfg(test)]
|
||||
Test(x) => Ok(Cached::new_uncached(Arc::new(x.get_allowed_ips()?))),
|
||||
@@ -431,14 +427,13 @@ impl BackendType<'_, ComputeUserInfo> {
|
||||
pub async fn wake_compute(
|
||||
&self,
|
||||
ctx: &mut RequestMonitoring,
|
||||
extra: &ConsoleReqExtra,
|
||||
) -> Result<Option<CachedNodeInfo>, console::errors::WakeComputeError> {
|
||||
use BackendType::*;
|
||||
|
||||
match self {
|
||||
Console(api, creds) => api.wake_compute(ctx, extra, creds).map_ok(Some).await,
|
||||
Console(api, user_info) => api.wake_compute(ctx, user_info).map_ok(Some).await,
|
||||
#[cfg(feature = "testing")]
|
||||
Postgres(api, creds) => api.wake_compute(ctx, extra, creds).map_ok(Some).await,
|
||||
Postgres(api, user_info) => api.wake_compute(ctx, user_info).map_ok(Some).await,
|
||||
Link(_) => Ok(None),
|
||||
#[cfg(test)]
|
||||
Test(x) => x.wake_compute().map(Some),
|
||||
|
||||
@@ -54,7 +54,7 @@ pub(super) async fn authenticate(
|
||||
sasl::Outcome::Success(key) => key,
|
||||
sasl::Outcome::Failure(reason) => {
|
||||
info!("auth backend failed with an error: {reason}");
|
||||
return Err(auth::AuthError::auth_failed(&*creds.inner.user));
|
||||
return Err(auth::AuthError::auth_failed(&*creds.user));
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -36,7 +36,7 @@ pub async fn authenticate_cleartext(
|
||||
sasl::Outcome::Success(key) => key,
|
||||
sasl::Outcome::Failure(reason) => {
|
||||
info!("auth backend failed with an error: {reason}");
|
||||
return Err(auth::AuthError::auth_failed(&*info.inner.user));
|
||||
return Err(auth::AuthError::auth_failed(&*info.user));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -67,7 +67,8 @@ pub async fn password_hack_no_authentication(
|
||||
// Report tentative success; compute node will check the password anyway.
|
||||
Ok(ComputeCredentials {
|
||||
info: ComputeUserInfo {
|
||||
inner: info,
|
||||
user: info.user,
|
||||
options: info.options,
|
||||
endpoint: payload.endpoint,
|
||||
},
|
||||
keys: payload.password,
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
use crate::{
|
||||
auth::password_hack::parse_endpoint_param, context::RequestMonitoring, error::UserFacingError,
|
||||
metrics::NUM_CONNECTION_ACCEPTED_BY_SNI, proxy::neon_options_str,
|
||||
metrics::NUM_CONNECTION_ACCEPTED_BY_SNI, proxy::NeonOptions,
|
||||
};
|
||||
use itertools::Itertools;
|
||||
use pq_proto::StartupMessageParams;
|
||||
@@ -12,7 +12,7 @@ use thiserror::Error;
|
||||
use tracing::{info, warn};
|
||||
|
||||
#[derive(Debug, Error, PartialEq, Eq, Clone)]
|
||||
pub enum ClientCredsParseError {
|
||||
pub enum ComputeUserInfoParseError {
|
||||
#[error("Parameter '{0}' is missing in startup packet.")]
|
||||
MissingKey(&'static str),
|
||||
|
||||
@@ -33,34 +33,49 @@ pub enum ClientCredsParseError {
|
||||
MalformedProjectName(SmolStr),
|
||||
}
|
||||
|
||||
impl UserFacingError for ClientCredsParseError {}
|
||||
impl UserFacingError for ComputeUserInfoParseError {}
|
||||
|
||||
/// Various client credentials which we use for authentication.
|
||||
/// Note that we don't store any kind of client key or password here.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct ClientCredentials {
|
||||
pub struct ComputeUserInfoMaybeEndpoint {
|
||||
pub user: SmolStr,
|
||||
// TODO: this is a severe misnomer! We should think of a new name ASAP.
|
||||
pub project: Option<SmolStr>,
|
||||
|
||||
pub cache_key: SmolStr,
|
||||
pub options: NeonOptions,
|
||||
}
|
||||
|
||||
impl ClientCredentials {
|
||||
impl ComputeUserInfoMaybeEndpoint {
|
||||
#[inline]
|
||||
pub fn project(&self) -> Option<&str> {
|
||||
self.project.as_deref()
|
||||
}
|
||||
}
|
||||
|
||||
impl ClientCredentials {
|
||||
pub fn endpoint_sni<'a>(
|
||||
sni: &'a str,
|
||||
common_names: &HashSet<String>,
|
||||
) -> Result<&'a str, ComputeUserInfoParseError> {
|
||||
let Some((subdomain, common_name)) = sni.split_once('.') else {
|
||||
return Err(ComputeUserInfoParseError::UnknownCommonName { cn: sni.into() });
|
||||
};
|
||||
if !common_names.contains(common_name) {
|
||||
return Err(ComputeUserInfoParseError::UnknownCommonName {
|
||||
cn: common_name.into(),
|
||||
});
|
||||
}
|
||||
Ok(subdomain)
|
||||
}
|
||||
|
||||
impl ComputeUserInfoMaybeEndpoint {
|
||||
pub fn parse(
|
||||
ctx: &mut RequestMonitoring,
|
||||
params: &StartupMessageParams,
|
||||
sni: Option<&str>,
|
||||
common_names: Option<HashSet<String>>,
|
||||
) -> Result<Self, ClientCredsParseError> {
|
||||
use ClientCredsParseError::*;
|
||||
common_names: Option<&HashSet<String>>,
|
||||
) -> Result<Self, ComputeUserInfoParseError> {
|
||||
use ComputeUserInfoParseError::*;
|
||||
|
||||
// Some parameters are stored in the startup message.
|
||||
let get_param = |key| params.get(key).ok_or(MissingKey(key));
|
||||
@@ -87,21 +102,7 @@ impl ClientCredentials {
|
||||
|
||||
let project_from_domain = if let Some(sni_str) = sni {
|
||||
if let Some(cn) = common_names {
|
||||
let common_name_from_sni = sni_str.split_once('.').map(|(_, domain)| domain);
|
||||
|
||||
let project = common_name_from_sni
|
||||
.and_then(|domain| {
|
||||
if cn.contains(domain) {
|
||||
subdomain_from_sni(sni_str, domain)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.ok_or_else(|| UnknownCommonName {
|
||||
cn: common_name_from_sni.unwrap_or("").into(),
|
||||
})?;
|
||||
|
||||
Some(project)
|
||||
Some(SmolStr::from(endpoint_sni(sni_str, cn)?))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
@@ -140,17 +141,12 @@ impl ClientCredentials {
|
||||
info!("Connection with password hack");
|
||||
}
|
||||
|
||||
let cache_key = format!(
|
||||
"{}{}",
|
||||
project.as_deref().unwrap_or(""),
|
||||
neon_options_str(params)
|
||||
)
|
||||
.into();
|
||||
let options = NeonOptions::parse_params(params);
|
||||
|
||||
Ok(Self {
|
||||
user,
|
||||
project,
|
||||
cache_key,
|
||||
options,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -207,25 +203,19 @@ fn project_name_valid(name: &str) -> bool {
|
||||
name.chars().all(|c| c.is_alphanumeric() || c == '-')
|
||||
}
|
||||
|
||||
fn subdomain_from_sni(sni: &str, common_name: &str) -> Option<SmolStr> {
|
||||
sni.strip_suffix(common_name)?
|
||||
.strip_suffix('.')
|
||||
.map(SmolStr::from)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use ClientCredsParseError::*;
|
||||
use ComputeUserInfoParseError::*;
|
||||
|
||||
#[test]
|
||||
fn parse_bare_minimum() -> anyhow::Result<()> {
|
||||
// According to postgresql, only `user` should be required.
|
||||
let options = StartupMessageParams::new([("user", "john_doe")]);
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let creds = ClientCredentials::parse(&mut ctx, &options, None, None)?;
|
||||
assert_eq!(creds.user, "john_doe");
|
||||
assert_eq!(creds.project, None);
|
||||
let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, None, None)?;
|
||||
assert_eq!(user_info.user, "john_doe");
|
||||
assert_eq!(user_info.project, None);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -238,9 +228,9 @@ mod tests {
|
||||
("foo", "bar"), // should be ignored
|
||||
]);
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let creds = ClientCredentials::parse(&mut ctx, &options, None, None)?;
|
||||
assert_eq!(creds.user, "john_doe");
|
||||
assert_eq!(creds.project, None);
|
||||
let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, None, None)?;
|
||||
assert_eq!(user_info.user, "john_doe");
|
||||
assert_eq!(user_info.project, None);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -253,10 +243,11 @@ mod tests {
|
||||
let common_names = Some(["localhost".into()].into());
|
||||
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let creds = ClientCredentials::parse(&mut ctx, &options, sni, common_names)?;
|
||||
assert_eq!(creds.user, "john_doe");
|
||||
assert_eq!(creds.project.as_deref(), Some("foo"));
|
||||
assert_eq!(creds.cache_key, "foo");
|
||||
let user_info =
|
||||
ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())?;
|
||||
assert_eq!(user_info.user, "john_doe");
|
||||
assert_eq!(user_info.project.as_deref(), Some("foo"));
|
||||
assert_eq!(user_info.options.get_cache_key("foo"), "foo");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -269,9 +260,9 @@ mod tests {
|
||||
]);
|
||||
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let creds = ClientCredentials::parse(&mut ctx, &options, None, None)?;
|
||||
assert_eq!(creds.user, "john_doe");
|
||||
assert_eq!(creds.project.as_deref(), Some("bar"));
|
||||
let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, None, None)?;
|
||||
assert_eq!(user_info.user, "john_doe");
|
||||
assert_eq!(user_info.project.as_deref(), Some("bar"));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -284,9 +275,9 @@ mod tests {
|
||||
]);
|
||||
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let creds = ClientCredentials::parse(&mut ctx, &options, None, None)?;
|
||||
assert_eq!(creds.user, "john_doe");
|
||||
assert_eq!(creds.project.as_deref(), Some("bar"));
|
||||
let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, None, None)?;
|
||||
assert_eq!(user_info.user, "john_doe");
|
||||
assert_eq!(user_info.project.as_deref(), Some("bar"));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -302,9 +293,9 @@ mod tests {
|
||||
]);
|
||||
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let creds = ClientCredentials::parse(&mut ctx, &options, None, None)?;
|
||||
assert_eq!(creds.user, "john_doe");
|
||||
assert!(creds.project.is_none());
|
||||
let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, None, None)?;
|
||||
assert_eq!(user_info.user, "john_doe");
|
||||
assert!(user_info.project.is_none());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -317,9 +308,9 @@ mod tests {
|
||||
]);
|
||||
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let creds = ClientCredentials::parse(&mut ctx, &options, None, None)?;
|
||||
assert_eq!(creds.user, "john_doe");
|
||||
assert!(creds.project.is_none());
|
||||
let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, None, None)?;
|
||||
assert_eq!(user_info.user, "john_doe");
|
||||
assert!(user_info.project.is_none());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -332,9 +323,10 @@ mod tests {
|
||||
let common_names = Some(["localhost".into()].into());
|
||||
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let creds = ClientCredentials::parse(&mut ctx, &options, sni, common_names)?;
|
||||
assert_eq!(creds.user, "john_doe");
|
||||
assert_eq!(creds.project.as_deref(), Some("baz"));
|
||||
let user_info =
|
||||
ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())?;
|
||||
assert_eq!(user_info.user, "john_doe");
|
||||
assert_eq!(user_info.project.as_deref(), Some("baz"));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -346,14 +338,16 @@ mod tests {
|
||||
let common_names = Some(["a.com".into(), "b.com".into()].into());
|
||||
let sni = Some("p1.a.com");
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let creds = ClientCredentials::parse(&mut ctx, &options, sni, common_names)?;
|
||||
assert_eq!(creds.project.as_deref(), Some("p1"));
|
||||
let user_info =
|
||||
ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())?;
|
||||
assert_eq!(user_info.project.as_deref(), Some("p1"));
|
||||
|
||||
let common_names = Some(["a.com".into(), "b.com".into()].into());
|
||||
let sni = Some("p1.b.com");
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let creds = ClientCredentials::parse(&mut ctx, &options, sni, common_names)?;
|
||||
assert_eq!(creds.project.as_deref(), Some("p1"));
|
||||
let user_info =
|
||||
ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())?;
|
||||
assert_eq!(user_info.project.as_deref(), Some("p1"));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -367,8 +361,9 @@ mod tests {
|
||||
let common_names = Some(["localhost".into()].into());
|
||||
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let err = ClientCredentials::parse(&mut ctx, &options, sni, common_names)
|
||||
.expect_err("should fail");
|
||||
let err =
|
||||
ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())
|
||||
.expect_err("should fail");
|
||||
match err {
|
||||
InconsistentProjectNames { domain, option } => {
|
||||
assert_eq!(option, "first");
|
||||
@@ -386,8 +381,9 @@ mod tests {
|
||||
let common_names = Some(["example.com".into()].into());
|
||||
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let err = ClientCredentials::parse(&mut ctx, &options, sni, common_names)
|
||||
.expect_err("should fail");
|
||||
let err =
|
||||
ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())
|
||||
.expect_err("should fail");
|
||||
match err {
|
||||
UnknownCommonName { cn } => {
|
||||
assert_eq!(cn, "localhost");
|
||||
@@ -406,9 +402,13 @@ mod tests {
|
||||
let sni = Some("project.localhost");
|
||||
let common_names = Some(["localhost".into()].into());
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let creds = ClientCredentials::parse(&mut ctx, &options, sni, common_names)?;
|
||||
assert_eq!(creds.project.as_deref(), Some("project"));
|
||||
assert_eq!(creds.cache_key, "projectendpoint_type:read_write lsn:0/2");
|
||||
let user_info =
|
||||
ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())?;
|
||||
assert_eq!(user_info.project.as_deref(), Some("project"));
|
||||
assert_eq!(
|
||||
user_info.options.get_cache_key("project"),
|
||||
"project endpoint_type:read_write lsn:0/2"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -32,7 +32,7 @@ pub struct MetricCollectionConfig {
|
||||
|
||||
pub struct TlsConfig {
|
||||
pub config: Arc<rustls::ServerConfig>,
|
||||
pub common_names: Option<HashSet<String>>,
|
||||
pub common_names: HashSet<String>,
|
||||
pub cert_resolver: Arc<CertResolver>,
|
||||
}
|
||||
|
||||
@@ -97,7 +97,7 @@ pub fn configure_tls(
|
||||
|
||||
Ok(TlsConfig {
|
||||
config,
|
||||
common_names: Some(common_names),
|
||||
common_names,
|
||||
cert_resolver,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ pub mod messages;
|
||||
|
||||
/// Wrappers for console APIs and their mocks.
|
||||
pub mod provider;
|
||||
pub use provider::{errors, Api, AuthSecret, CachedNodeInfo, ConsoleReqExtra, NodeInfo};
|
||||
pub use provider::{errors, Api, AuthSecret, CachedNodeInfo, NodeInfo};
|
||||
|
||||
/// Various cache-related types.
|
||||
pub mod caches {
|
||||
|
||||
@@ -197,22 +197,6 @@ pub mod errors {
|
||||
}
|
||||
}
|
||||
|
||||
/// Extra query params we'd like to pass to the console.
|
||||
pub struct ConsoleReqExtra {
|
||||
pub options: Vec<(String, String)>,
|
||||
}
|
||||
|
||||
impl ConsoleReqExtra {
|
||||
// https://swagger.io/docs/specification/serialization/ DeepObject format
|
||||
// paramName[prop1]=value1¶mName[prop2]=value2&....
|
||||
pub fn options_as_deep_object(&self) -> Vec<(String, String)> {
|
||||
self.options
|
||||
.iter()
|
||||
.map(|(k, v)| (format!("options[{}]", k), v.to_string()))
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
/// Auth secret which is managed by the cloud.
|
||||
#[derive(Clone, Eq, PartialEq, Debug)]
|
||||
pub enum AuthSecret {
|
||||
@@ -249,7 +233,7 @@ pub struct NodeInfo {
|
||||
pub allow_self_signed_compute: bool,
|
||||
}
|
||||
|
||||
pub type NodeInfoCache = TimedLru<Arc<str>, NodeInfo>;
|
||||
pub type NodeInfoCache = TimedLru<SmolStr, NodeInfo>;
|
||||
pub type CachedNodeInfo = Cached<&'static NodeInfoCache>;
|
||||
pub type CachedRoleSecret = Cached<&'static ProjectInfoCacheImpl, AuthSecret>;
|
||||
pub type CachedAllowedIps = Cached<&'static ProjectInfoCacheImpl, Arc<Vec<SmolStr>>>;
|
||||
@@ -277,7 +261,6 @@ pub trait Api {
|
||||
async fn wake_compute(
|
||||
&self,
|
||||
ctx: &mut RequestMonitoring,
|
||||
extra: &ConsoleReqExtra,
|
||||
creds: &ComputeUserInfo,
|
||||
) -> Result<CachedNodeInfo, errors::WakeComputeError>;
|
||||
}
|
||||
@@ -310,7 +293,7 @@ impl ApiCaches {
|
||||
/// Various caches for [`console`](super).
|
||||
pub struct ApiLocks {
|
||||
name: &'static str,
|
||||
node_locks: DashMap<Arc<str>, Arc<Semaphore>>,
|
||||
node_locks: DashMap<SmolStr, Arc<Semaphore>>,
|
||||
permits: usize,
|
||||
timeout: Duration,
|
||||
registered: prometheus::IntCounter,
|
||||
@@ -378,7 +361,7 @@ impl ApiLocks {
|
||||
|
||||
pub async fn get_wake_compute_permit(
|
||||
&self,
|
||||
key: &Arc<str>,
|
||||
key: &SmolStr,
|
||||
) -> Result<WakeComputePermit, errors::WakeComputeError> {
|
||||
if self.permits == 0 {
|
||||
return Ok(WakeComputePermit { permit: None });
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
use super::{
|
||||
errors::{ApiError, GetAuthInfoError, WakeComputeError},
|
||||
AuthInfo, AuthSecret, CachedNodeInfo, ConsoleReqExtra, NodeInfo,
|
||||
AuthInfo, AuthSecret, CachedNodeInfo, NodeInfo,
|
||||
};
|
||||
use crate::cache::Cached;
|
||||
use crate::console::provider::{CachedAllowedIps, CachedRoleSecret};
|
||||
@@ -50,7 +50,7 @@ impl Api {
|
||||
|
||||
async fn do_get_auth_info(
|
||||
&self,
|
||||
creds: &ComputeUserInfo,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<AuthInfo, GetAuthInfoError> {
|
||||
let (secret, allowed_ips) = async {
|
||||
// Perhaps we could persist this connection, but then we'd have to
|
||||
@@ -63,7 +63,7 @@ impl Api {
|
||||
let secret = match get_execute_postgres_query(
|
||||
&client,
|
||||
"select rolpassword from pg_catalog.pg_authid where rolname = $1",
|
||||
&[&&*creds.inner.user],
|
||||
&[&&*user_info.user],
|
||||
"rolpassword",
|
||||
)
|
||||
.await?
|
||||
@@ -74,14 +74,14 @@ impl Api {
|
||||
secret.or_else(|| parse_md5(&entry).map(AuthSecret::Md5))
|
||||
}
|
||||
None => {
|
||||
warn!("user '{}' does not exist", creds.inner.user);
|
||||
warn!("user '{}' does not exist", user_info.user);
|
||||
None
|
||||
}
|
||||
};
|
||||
let allowed_ips = match get_execute_postgres_query(
|
||||
&client,
|
||||
"select allowed_ips from neon_control_plane.endpoints where endpoint_id = $1",
|
||||
&[&creds.endpoint.as_str()],
|
||||
&[&user_info.endpoint.as_str()],
|
||||
"allowed_ips",
|
||||
)
|
||||
.await?
|
||||
@@ -149,10 +149,10 @@ impl super::Api for Api {
|
||||
async fn get_role_secret(
|
||||
&self,
|
||||
_ctx: &mut RequestMonitoring,
|
||||
creds: &ComputeUserInfo,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<Option<CachedRoleSecret>, GetAuthInfoError> {
|
||||
Ok(self
|
||||
.do_get_auth_info(creds)
|
||||
.do_get_auth_info(user_info)
|
||||
.await?
|
||||
.secret
|
||||
.map(CachedRoleSecret::new_uncached))
|
||||
@@ -161,10 +161,10 @@ impl super::Api for Api {
|
||||
async fn get_allowed_ips(
|
||||
&self,
|
||||
_ctx: &mut RequestMonitoring,
|
||||
creds: &ComputeUserInfo,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<CachedAllowedIps, GetAuthInfoError> {
|
||||
Ok(Cached::new_uncached(Arc::new(
|
||||
self.do_get_auth_info(creds).await?.allowed_ips,
|
||||
self.do_get_auth_info(user_info).await?.allowed_ips,
|
||||
)))
|
||||
}
|
||||
|
||||
@@ -172,8 +172,7 @@ impl super::Api for Api {
|
||||
async fn wake_compute(
|
||||
&self,
|
||||
_ctx: &mut RequestMonitoring,
|
||||
_extra: &ConsoleReqExtra,
|
||||
_creds: &ComputeUserInfo,
|
||||
_user_info: &ComputeUserInfo,
|
||||
) -> Result<CachedNodeInfo, WakeComputeError> {
|
||||
self.do_wake_compute()
|
||||
.map_ok(CachedNodeInfo::new_uncached)
|
||||
|
||||
@@ -4,7 +4,7 @@ use super::{
|
||||
super::messages::{ConsoleError, GetRoleSecret, WakeCompute},
|
||||
errors::{ApiError, GetAuthInfoError, WakeComputeError},
|
||||
ApiCaches, ApiLocks, AuthInfo, AuthSecret, CachedAllowedIps, CachedNodeInfo, CachedRoleSecret,
|
||||
ConsoleReqExtra, NodeInfo,
|
||||
NodeInfo,
|
||||
};
|
||||
use crate::{auth::backend::ComputeUserInfo, compute, http, scram};
|
||||
use crate::{
|
||||
@@ -55,7 +55,7 @@ impl Api {
|
||||
async fn do_get_auth_info(
|
||||
&self,
|
||||
ctx: &mut RequestMonitoring,
|
||||
creds: &ComputeUserInfo,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<AuthInfo, GetAuthInfoError> {
|
||||
let request_id = uuid::Uuid::new_v4().to_string();
|
||||
let application_name = ctx.console_application_name();
|
||||
@@ -68,8 +68,8 @@ impl Api {
|
||||
.query(&[("session_id", ctx.session_id)])
|
||||
.query(&[
|
||||
("application_name", application_name.as_str()),
|
||||
("project", creds.endpoint.as_str()),
|
||||
("role", creds.inner.user.as_str()),
|
||||
("project", user_info.endpoint.as_str()),
|
||||
("role", user_info.user.as_str()),
|
||||
])
|
||||
.build()?;
|
||||
|
||||
@@ -110,8 +110,7 @@ impl Api {
|
||||
async fn do_wake_compute(
|
||||
&self,
|
||||
ctx: &mut RequestMonitoring,
|
||||
extra: &ConsoleReqExtra,
|
||||
creds: &ComputeUserInfo,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<NodeInfo, WakeComputeError> {
|
||||
let request_id = uuid::Uuid::new_v4().to_string();
|
||||
let application_name = ctx.console_application_name();
|
||||
@@ -124,14 +123,14 @@ impl Api {
|
||||
.query(&[("session_id", ctx.session_id)])
|
||||
.query(&[
|
||||
("application_name", application_name.as_str()),
|
||||
("project", creds.endpoint.as_str()),
|
||||
("project", user_info.endpoint.as_str()),
|
||||
]);
|
||||
|
||||
request_builder = if extra.options.is_empty() {
|
||||
request_builder
|
||||
} else {
|
||||
request_builder.query(&extra.options_as_deep_object())
|
||||
};
|
||||
let options = user_info.options.to_deep_object();
|
||||
if !options.is_empty() {
|
||||
request_builder = request_builder.query(&options);
|
||||
}
|
||||
|
||||
let request = request_builder.build()?;
|
||||
|
||||
info!(url = request.url().as_str(), "sending http request");
|
||||
@@ -172,14 +171,14 @@ impl super::Api for Api {
|
||||
async fn get_role_secret(
|
||||
&self,
|
||||
ctx: &mut RequestMonitoring,
|
||||
creds: &ComputeUserInfo,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<Option<CachedRoleSecret>, GetAuthInfoError> {
|
||||
let ep = &creds.endpoint;
|
||||
let user = &creds.inner.user;
|
||||
let ep = &user_info.endpoint;
|
||||
let user = &user_info.user;
|
||||
if let Some(role_secret) = self.caches.project_info.get_role_secret(ep, user) {
|
||||
return Ok(Some(role_secret));
|
||||
}
|
||||
let auth_info = self.do_get_auth_info(ctx, creds).await?;
|
||||
let auth_info = self.do_get_auth_info(ctx, user_info).await?;
|
||||
let project_id = auth_info.project_id.unwrap_or(ep.clone());
|
||||
if let Some(secret) = &auth_info.secret {
|
||||
self.caches
|
||||
@@ -198,9 +197,9 @@ impl super::Api for Api {
|
||||
async fn get_allowed_ips(
|
||||
&self,
|
||||
ctx: &mut RequestMonitoring,
|
||||
creds: &ComputeUserInfo,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<CachedAllowedIps, GetAuthInfoError> {
|
||||
let ep = &creds.endpoint;
|
||||
let ep = &user_info.endpoint;
|
||||
if let Some(allowed_ips) = self.caches.project_info.get_allowed_ips(ep) {
|
||||
ALLOWED_IPS_BY_CACHE_OUTCOME
|
||||
.with_label_values(&["hit"])
|
||||
@@ -210,9 +209,9 @@ impl super::Api for Api {
|
||||
ALLOWED_IPS_BY_CACHE_OUTCOME
|
||||
.with_label_values(&["miss"])
|
||||
.inc();
|
||||
let auth_info = self.do_get_auth_info(ctx, creds).await?;
|
||||
let auth_info = self.do_get_auth_info(ctx, user_info).await?;
|
||||
let allowed_ips = Arc::new(auth_info.allowed_ips);
|
||||
let user = &creds.inner.user;
|
||||
let user = &user_info.user;
|
||||
let project_id = auth_info.project_id.unwrap_or(ep.clone());
|
||||
if let Some(secret) = &auth_info.secret {
|
||||
self.caches
|
||||
@@ -229,22 +228,19 @@ impl super::Api for Api {
|
||||
async fn wake_compute(
|
||||
&self,
|
||||
ctx: &mut RequestMonitoring,
|
||||
extra: &ConsoleReqExtra,
|
||||
creds: &ComputeUserInfo,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<CachedNodeInfo, WakeComputeError> {
|
||||
let key: &str = &creds.inner.cache_key;
|
||||
let key = user_info.endpoint_cache_key();
|
||||
|
||||
// Every time we do a wakeup http request, the compute node will stay up
|
||||
// for some time (highly depends on the console's scale-to-zero policy);
|
||||
// The connection info remains the same during that period of time,
|
||||
// which means that we might cache it to reduce the load and latency.
|
||||
if let Some(cached) = self.caches.node_info.get(key) {
|
||||
info!(key = key, "found cached compute node info");
|
||||
if let Some(cached) = self.caches.node_info.get(&*key) {
|
||||
info!(key = &*key, "found cached compute node info");
|
||||
return Ok(cached);
|
||||
}
|
||||
|
||||
let key: Arc<str> = key.into();
|
||||
|
||||
let permit = self.locks.get_wake_compute_permit(&key).await?;
|
||||
|
||||
// after getting back a permit - it's possible the cache was filled
|
||||
@@ -256,7 +252,7 @@ impl super::Api for Api {
|
||||
}
|
||||
}
|
||||
|
||||
let node = self.do_wake_compute(ctx, extra, creds).await?;
|
||||
let node = self.do_wake_compute(ctx, user_info).await?;
|
||||
let (_, cached) = self.caches.node_info.insert(key.clone(), node);
|
||||
info!(key = &*key, "created a cache entry for compute node info");
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ use crate::{
|
||||
cancellation::{self, CancelMap},
|
||||
compute,
|
||||
config::{AuthenticationConfig, ProxyConfig, TlsConfig},
|
||||
console::{self, messages::MetricsAuxInfo},
|
||||
console::messages::MetricsAuxInfo,
|
||||
context::RequestMonitoring,
|
||||
metrics::{
|
||||
NUM_BYTES_PROXIED_COUNTER, NUM_BYTES_PROXIED_PER_CLIENT_COUNTER,
|
||||
@@ -26,6 +26,7 @@ use itertools::Itertools;
|
||||
use once_cell::sync::OnceCell;
|
||||
use pq_proto::{BeMessage as Be, FeStartupPacket, StartupMessageParams};
|
||||
use regex::Regex;
|
||||
use smol_str::SmolStr;
|
||||
use std::sync::Arc;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
@@ -198,27 +199,29 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
drop(pause);
|
||||
|
||||
// Extract credentials which we're going to use for auth.
|
||||
let creds = {
|
||||
let user_info = {
|
||||
let hostname = mode.hostname(stream.get_ref());
|
||||
|
||||
let common_names = tls.and_then(|tls| tls.common_names.clone());
|
||||
let common_names = tls.map(|tls| &tls.common_names);
|
||||
let result = config
|
||||
.auth_backend
|
||||
.as_ref()
|
||||
.map(|_| auth::ClientCredentials::parse(ctx, ¶ms, hostname, common_names))
|
||||
.map(|_| {
|
||||
auth::ComputeUserInfoMaybeEndpoint::parse(ctx, ¶ms, hostname, common_names)
|
||||
})
|
||||
.transpose();
|
||||
|
||||
match result {
|
||||
Ok(creds) => creds,
|
||||
Ok(user_info) => user_info,
|
||||
Err(e) => stream.throw_error(e).await?,
|
||||
}
|
||||
};
|
||||
|
||||
ctx.set_endpoint_id(creds.get_endpoint());
|
||||
ctx.set_endpoint_id(user_info.get_endpoint());
|
||||
|
||||
let client = Client::new(
|
||||
stream,
|
||||
creds,
|
||||
user_info,
|
||||
¶ms,
|
||||
mode.allow_self_signed_compute(config),
|
||||
endpoint_rate_limiter,
|
||||
@@ -397,7 +400,7 @@ struct Client<'a, S> {
|
||||
/// The underlying libpq protocol stream.
|
||||
stream: PqStream<Stream<S>>,
|
||||
/// Client credentials that we care about.
|
||||
creds: auth::BackendType<'a, auth::ClientCredentials>,
|
||||
user_info: auth::BackendType<'a, auth::ComputeUserInfoMaybeEndpoint>,
|
||||
/// KV-dictionary with PostgreSQL connection params.
|
||||
params: &'a StartupMessageParams,
|
||||
/// Allow self-signed certificates (for testing).
|
||||
@@ -410,14 +413,14 @@ impl<'a, S> Client<'a, S> {
|
||||
/// Construct a new connection context.
|
||||
fn new(
|
||||
stream: PqStream<Stream<S>>,
|
||||
creds: auth::BackendType<'a, auth::ClientCredentials>,
|
||||
user_info: auth::BackendType<'a, auth::ComputeUserInfoMaybeEndpoint>,
|
||||
params: &'a StartupMessageParams,
|
||||
allow_self_signed_compute: bool,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
) -> Self {
|
||||
Self {
|
||||
stream,
|
||||
creds,
|
||||
user_info,
|
||||
params,
|
||||
allow_self_signed_compute,
|
||||
endpoint_rate_limiter,
|
||||
@@ -429,7 +432,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> Client<'_, S> {
|
||||
/// Let the client authenticate and connect to the designated compute node.
|
||||
// Instrumentation logs endpoint name everywhere. Doesn't work for link
|
||||
// auth; strictly speaking we don't know endpoint name in its case.
|
||||
#[tracing::instrument(name = "", fields(ep = %self.creds.get_endpoint().unwrap_or_default()), skip_all)]
|
||||
#[tracing::instrument(name = "", fields(ep = %self.user_info.get_endpoint().unwrap_or_default()), skip_all)]
|
||||
async fn connect_to_db(
|
||||
self,
|
||||
ctx: &mut RequestMonitoring,
|
||||
@@ -439,14 +442,14 @@ impl<S: AsyncRead + AsyncWrite + Unpin> Client<'_, S> {
|
||||
) -> anyhow::Result<()> {
|
||||
let Self {
|
||||
mut stream,
|
||||
creds,
|
||||
user_info,
|
||||
params,
|
||||
allow_self_signed_compute,
|
||||
endpoint_rate_limiter,
|
||||
} = self;
|
||||
|
||||
// check rate limit
|
||||
if let Some(ep) = creds.get_endpoint() {
|
||||
if let Some(ep) = user_info.get_endpoint() {
|
||||
if !endpoint_rate_limiter.check(ep) {
|
||||
return stream
|
||||
.throw_error(auth::AuthError::too_many_connections())
|
||||
@@ -454,13 +457,9 @@ impl<S: AsyncRead + AsyncWrite + Unpin> Client<'_, S> {
|
||||
}
|
||||
}
|
||||
|
||||
let extra = console::ConsoleReqExtra {
|
||||
options: neon_options(params),
|
||||
};
|
||||
|
||||
let user = creds.get_user().to_owned();
|
||||
let auth_result = match creds
|
||||
.authenticate(ctx, &extra, &mut stream, mode.allow_cleartext(), config)
|
||||
let user = user_info.get_user().to_owned();
|
||||
let auth_result = match user_info
|
||||
.authenticate(ctx, &mut stream, mode.allow_cleartext(), config)
|
||||
.await
|
||||
{
|
||||
Ok(auth_result) => auth_result,
|
||||
@@ -473,12 +472,12 @@ impl<S: AsyncRead + AsyncWrite + Unpin> Client<'_, S> {
|
||||
}
|
||||
};
|
||||
|
||||
let (mut node_info, creds) = auth_result;
|
||||
let (mut node_info, user_info) = auth_result;
|
||||
|
||||
node_info.allow_self_signed_compute = allow_self_signed_compute;
|
||||
|
||||
let aux = node_info.aux.clone();
|
||||
let mut node = connect_to_compute(ctx, &TcpMechanism { params }, node_info, &extra, &creds)
|
||||
let mut node = connect_to_compute(ctx, &TcpMechanism { params }, node_info, &user_info)
|
||||
.or_else(|e| stream.throw_error(e))
|
||||
.await?;
|
||||
|
||||
@@ -493,29 +492,52 @@ impl<S: AsyncRead + AsyncWrite + Unpin> Client<'_, S> {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn neon_options(params: &StartupMessageParams) -> Vec<(String, String)> {
|
||||
#[allow(unstable_name_collisions)]
|
||||
match params.options_raw() {
|
||||
Some(options) => options.filter_map(neon_option).collect(),
|
||||
None => vec![],
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Default)]
|
||||
pub struct NeonOptions(Vec<(SmolStr, SmolStr)>);
|
||||
|
||||
impl NeonOptions {
|
||||
pub fn parse_params(params: &StartupMessageParams) -> Self {
|
||||
params
|
||||
.options_raw()
|
||||
.map(Self::parse_from_iter)
|
||||
.unwrap_or_default()
|
||||
}
|
||||
pub fn parse_options_raw(options: &str) -> Self {
|
||||
Self::parse_from_iter(StartupMessageParams::parse_options_raw(options))
|
||||
}
|
||||
|
||||
fn parse_from_iter<'a>(options: impl Iterator<Item = &'a str>) -> Self {
|
||||
let mut options = options
|
||||
.filter_map(neon_option)
|
||||
.map(|(k, v)| (k.into(), v.into()))
|
||||
.collect_vec();
|
||||
options.sort();
|
||||
Self(options)
|
||||
}
|
||||
|
||||
pub fn get_cache_key(&self, prefix: &str) -> SmolStr {
|
||||
// prefix + format!(" {k}:{v}")
|
||||
// kinda jank because SmolStr is immutable
|
||||
std::iter::once(prefix)
|
||||
.chain(self.0.iter().flat_map(|(k, v)| [" ", &**k, ":", &**v]))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// <https://swagger.io/docs/specification/serialization/> DeepObject format
|
||||
/// `paramName[prop1]=value1¶mName[prop2]=value2&...`
|
||||
pub fn to_deep_object(&self) -> Vec<(String, SmolStr)> {
|
||||
self.0
|
||||
.iter()
|
||||
.map(|(k, v)| (format!("options[{}]", k), v.clone()))
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn neon_options_str(params: &StartupMessageParams) -> String {
|
||||
#[allow(unstable_name_collisions)]
|
||||
neon_options(params)
|
||||
.iter()
|
||||
.map(|(k, v)| format!("{}:{}", k, v))
|
||||
.sorted() // we sort it to use as cache key
|
||||
.intersperse(" ".to_owned())
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn neon_option(bytes: &str) -> Option<(String, String)> {
|
||||
pub fn neon_option(bytes: &str) -> Option<(&str, &str)> {
|
||||
static RE: OnceCell<Regex> = OnceCell::new();
|
||||
let re = RE.get_or_init(|| Regex::new(r"^neon_(\w+):(.+)").unwrap());
|
||||
|
||||
let cap = re.captures(bytes)?;
|
||||
let (_, [k, v]) = cap.extract();
|
||||
Some((k.to_owned(), v.to_owned()))
|
||||
Some((k, v))
|
||||
}
|
||||
|
||||
@@ -128,8 +128,7 @@ pub async fn connect_to_compute<M: ConnectMechanism>(
|
||||
ctx: &mut RequestMonitoring,
|
||||
mechanism: &M,
|
||||
mut node_info: console::CachedNodeInfo,
|
||||
extra: &console::ConsoleReqExtra,
|
||||
creds: &auth::BackendType<'_, auth::backend::ComputeUserInfo>,
|
||||
user_info: &auth::BackendType<'_, auth::backend::ComputeUserInfo>,
|
||||
) -> Result<M::Connection, M::Error>
|
||||
where
|
||||
M::ConnectError: ShouldRetry + std::fmt::Debug,
|
||||
@@ -159,10 +158,10 @@ where
|
||||
// if we failed to connect, it's likely that the compute node was suspended, wake a new compute node
|
||||
info!("compute node's state has likely changed; requesting a wake-up");
|
||||
let node_info = loop {
|
||||
let wake_res = match creds {
|
||||
auth::BackendType::Console(api, creds) => api.wake_compute(ctx, extra, creds).await,
|
||||
let wake_res = match user_info {
|
||||
auth::BackendType::Console(api, user_info) => api.wake_compute(ctx, user_info).await,
|
||||
#[cfg(feature = "testing")]
|
||||
auth::BackendType::Postgres(api, creds) => api.wake_compute(ctx, extra, creds).await,
|
||||
auth::BackendType::Postgres(api, user_info) => api.wake_compute(ctx, user_info).await,
|
||||
// nothing to do?
|
||||
auth::BackendType::Link(_) => return Err(err.into()),
|
||||
// test backend
|
||||
|
||||
@@ -7,7 +7,7 @@ use super::retry::ShouldRetry;
|
||||
use super::*;
|
||||
use crate::auth::backend::{ComputeUserInfo, TestBackend};
|
||||
use crate::config::CertResolver;
|
||||
use crate::console::{CachedNodeInfo, NodeInfo};
|
||||
use crate::console::{self, CachedNodeInfo, NodeInfo};
|
||||
use crate::proxy::retry::{retry_after, NUM_RETRIES_CONNECT};
|
||||
use crate::{auth, http, sasl, scram};
|
||||
use async_trait::async_trait;
|
||||
@@ -83,7 +83,7 @@ fn generate_tls_config<'a>(
|
||||
let mut cert_resolver = CertResolver::new();
|
||||
cert_resolver.add_cert(key, vec![cert], true)?;
|
||||
|
||||
let common_names = Some(cert_resolver.get_common_names());
|
||||
let common_names = cert_resolver.get_common_names();
|
||||
|
||||
TlsConfig {
|
||||
config,
|
||||
@@ -487,15 +487,10 @@ fn helper_create_cached_node_info() -> CachedNodeInfo {
|
||||
|
||||
fn helper_create_connect_info(
|
||||
mechanism: &TestConnectMechanism,
|
||||
) -> (
|
||||
CachedNodeInfo,
|
||||
console::ConsoleReqExtra,
|
||||
auth::BackendType<'_, ComputeUserInfo>,
|
||||
) {
|
||||
) -> (CachedNodeInfo, auth::BackendType<'_, ComputeUserInfo>) {
|
||||
let cache = helper_create_cached_node_info();
|
||||
let extra = console::ConsoleReqExtra { options: vec![] };
|
||||
let creds = auth::BackendType::Test(mechanism);
|
||||
(cache, extra, creds)
|
||||
let user_info = auth::BackendType::Test(mechanism);
|
||||
(cache, user_info)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -503,8 +498,8 @@ async fn connect_to_compute_success() {
|
||||
use ConnectAction::*;
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let mechanism = TestConnectMechanism::new(vec![Connect]);
|
||||
let (cache, extra, creds) = helper_create_connect_info(&mechanism);
|
||||
connect_to_compute(&mut ctx, &mechanism, cache, &extra, &creds)
|
||||
let (cache, user_info) = helper_create_connect_info(&mechanism);
|
||||
connect_to_compute(&mut ctx, &mechanism, cache, &user_info)
|
||||
.await
|
||||
.unwrap();
|
||||
mechanism.verify();
|
||||
@@ -515,8 +510,8 @@ async fn connect_to_compute_retry() {
|
||||
use ConnectAction::*;
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let mechanism = TestConnectMechanism::new(vec![Retry, Wake, Retry, Connect]);
|
||||
let (cache, extra, creds) = helper_create_connect_info(&mechanism);
|
||||
connect_to_compute(&mut ctx, &mechanism, cache, &extra, &creds)
|
||||
let (cache, user_info) = helper_create_connect_info(&mechanism);
|
||||
connect_to_compute(&mut ctx, &mechanism, cache, &user_info)
|
||||
.await
|
||||
.unwrap();
|
||||
mechanism.verify();
|
||||
@@ -528,8 +523,8 @@ async fn connect_to_compute_non_retry_1() {
|
||||
use ConnectAction::*;
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let mechanism = TestConnectMechanism::new(vec![Retry, Wake, Retry, Fail]);
|
||||
let (cache, extra, creds) = helper_create_connect_info(&mechanism);
|
||||
connect_to_compute(&mut ctx, &mechanism, cache, &extra, &creds)
|
||||
let (cache, user_info) = helper_create_connect_info(&mechanism);
|
||||
connect_to_compute(&mut ctx, &mechanism, cache, &user_info)
|
||||
.await
|
||||
.unwrap_err();
|
||||
mechanism.verify();
|
||||
@@ -541,8 +536,8 @@ async fn connect_to_compute_non_retry_2() {
|
||||
use ConnectAction::*;
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let mechanism = TestConnectMechanism::new(vec![Fail, Wake, Retry, Connect]);
|
||||
let (cache, extra, creds) = helper_create_connect_info(&mechanism);
|
||||
connect_to_compute(&mut ctx, &mechanism, cache, &extra, &creds)
|
||||
let (cache, user_info) = helper_create_connect_info(&mechanism);
|
||||
connect_to_compute(&mut ctx, &mechanism, cache, &user_info)
|
||||
.await
|
||||
.unwrap();
|
||||
mechanism.verify();
|
||||
@@ -558,8 +553,8 @@ async fn connect_to_compute_non_retry_3() {
|
||||
Retry, Wake, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry,
|
||||
Retry, Retry, Retry, Retry, /* the 17th time */ Retry,
|
||||
]);
|
||||
let (cache, extra, creds) = helper_create_connect_info(&mechanism);
|
||||
connect_to_compute(&mut ctx, &mechanism, cache, &extra, &creds)
|
||||
let (cache, user_info) = helper_create_connect_info(&mechanism);
|
||||
connect_to_compute(&mut ctx, &mechanism, cache, &user_info)
|
||||
.await
|
||||
.unwrap_err();
|
||||
mechanism.verify();
|
||||
@@ -571,8 +566,8 @@ async fn wake_retry() {
|
||||
use ConnectAction::*;
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let mechanism = TestConnectMechanism::new(vec![Retry, WakeRetry, Wake, Connect]);
|
||||
let (cache, extra, creds) = helper_create_connect_info(&mechanism);
|
||||
connect_to_compute(&mut ctx, &mechanism, cache, &extra, &creds)
|
||||
let (cache, user_info) = helper_create_connect_info(&mechanism);
|
||||
connect_to_compute(&mut ctx, &mechanism, cache, &user_info)
|
||||
.await
|
||||
.unwrap();
|
||||
mechanism.verify();
|
||||
@@ -584,8 +579,8 @@ async fn wake_non_retry() {
|
||||
use ConnectAction::*;
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let mechanism = TestConnectMechanism::new(vec![Retry, WakeFail]);
|
||||
let (cache, extra, creds) = helper_create_connect_info(&mechanism);
|
||||
connect_to_compute(&mut ctx, &mechanism, cache, &extra, &creds)
|
||||
let (cache, user_info) = helper_create_connect_info(&mechanism);
|
||||
connect_to_compute(&mut ctx, &mechanism, cache, &user_info)
|
||||
.await
|
||||
.unwrap_err();
|
||||
mechanism.verify();
|
||||
|
||||
@@ -17,6 +17,7 @@ pub use reqwest_middleware::{ClientWithMiddleware, Error};
|
||||
pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
|
||||
use tokio_util::task::TaskTracker;
|
||||
|
||||
use crate::config::TlsConfig;
|
||||
use crate::context::RequestMonitoring;
|
||||
use crate::metrics::NUM_CLIENT_CONNECTION_GAUGE;
|
||||
use crate::protocol2::{ProxyProtocolAccept, WithClientIp};
|
||||
@@ -69,14 +70,14 @@ pub async fn task_main(
|
||||
}
|
||||
});
|
||||
|
||||
let tls_config = config.tls_config.as_ref().map(|cfg| cfg.to_server_config());
|
||||
let tls_acceptor: tokio_rustls::TlsAcceptor = match tls_config {
|
||||
Some(config) => config.into(),
|
||||
let tls_config = match config.tls_config.as_ref() {
|
||||
Some(config) => config,
|
||||
None => {
|
||||
warn!("TLS config is missing, WebSocket Secure server will not be started");
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
let tls_acceptor: tokio_rustls::TlsAcceptor = tls_config.to_server_config().into();
|
||||
|
||||
let mut addr_incoming = AddrIncoming::from_listener(ws_listener)?;
|
||||
let _ = addr_incoming.set_nodelay(true);
|
||||
@@ -126,6 +127,7 @@ pub async fn task_main(
|
||||
request_handler(
|
||||
req,
|
||||
config,
|
||||
tls_config,
|
||||
conn_pool,
|
||||
ws_connections,
|
||||
cancel_map,
|
||||
@@ -195,6 +197,7 @@ where
|
||||
async fn request_handler(
|
||||
mut request: Request<Body>,
|
||||
config: &'static ProxyConfig,
|
||||
tls: &'static TlsConfig,
|
||||
conn_pool: Arc<conn_pool::GlobalConnPool>,
|
||||
ws_connections: TaskTracker,
|
||||
cancel_map: Arc<CancelMap>,
|
||||
@@ -243,6 +246,7 @@ async fn request_handler(
|
||||
let mut ctx = RequestMonitoring::new(session_id, peer_addr, "http", &config.region);
|
||||
|
||||
sql_over_http::handle(
|
||||
tls,
|
||||
&config.http_config,
|
||||
&mut ctx,
|
||||
request,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use anyhow::{anyhow, Context};
|
||||
use anyhow::Context;
|
||||
use async_trait::async_trait;
|
||||
use dashmap::DashMap;
|
||||
use futures::{future::poll_fn, Future};
|
||||
@@ -9,7 +9,6 @@ use pbkdf2::{
|
||||
password_hash::{PasswordHashString, PasswordHasher, PasswordVerifier, SaltString},
|
||||
Params, Pbkdf2,
|
||||
};
|
||||
use pq_proto::StartupMessageParams;
|
||||
use prometheus::{exponential_buckets, register_histogram, Histogram};
|
||||
use rand::Rng;
|
||||
use smol_str::SmolStr;
|
||||
@@ -30,7 +29,7 @@ use crate::{
|
||||
console,
|
||||
context::RequestMonitoring,
|
||||
metrics::NUM_DB_CONNECTIONS_GAUGE,
|
||||
proxy::{connect_compute::ConnectMechanism, neon_options},
|
||||
proxy::connect_compute::ConnectMechanism,
|
||||
usage_metrics::{Ids, MetricCounter, USAGE_METRICS},
|
||||
};
|
||||
use crate::{compute, config};
|
||||
@@ -38,28 +37,37 @@ use crate::{compute, config};
|
||||
use tracing::{debug, error, warn, Span};
|
||||
use tracing::{info, info_span, Instrument};
|
||||
|
||||
pub const APP_NAME: &str = "/sql_over_http";
|
||||
pub const APP_NAME: SmolStr = SmolStr::new_inline("/sql_over_http");
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ConnInfo {
|
||||
pub username: SmolStr,
|
||||
pub user_info: ComputeUserInfo,
|
||||
pub dbname: SmolStr,
|
||||
pub hostname: SmolStr,
|
||||
pub password: SmolStr,
|
||||
pub options: Option<SmolStr>,
|
||||
}
|
||||
|
||||
impl ConnInfo {
|
||||
// hm, change to hasher to avoid cloning?
|
||||
pub fn db_and_user(&self) -> (SmolStr, SmolStr) {
|
||||
(self.dbname.clone(), self.username.clone())
|
||||
(self.dbname.clone(), self.user_info.user.clone())
|
||||
}
|
||||
|
||||
pub fn endpoint_cache_key(&self) -> SmolStr {
|
||||
self.user_info.endpoint_cache_key()
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for ConnInfo {
|
||||
// use custom display to avoid logging password
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "{}@{}/{}", self.username, self.hostname, self.dbname)
|
||||
write!(
|
||||
f,
|
||||
"{}@{}/{}?{}",
|
||||
self.user_info.user,
|
||||
self.user_info.endpoint,
|
||||
self.dbname,
|
||||
self.user_info.options.get_cache_key("")
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -319,7 +327,7 @@ impl GlobalConnPool {
|
||||
let mut hash_valid = false;
|
||||
let mut endpoint_pool = Weak::new();
|
||||
if !force_new {
|
||||
let pool = self.get_or_create_endpoint_pool(&conn_info.hostname);
|
||||
let pool = self.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key());
|
||||
endpoint_pool = Arc::downgrade(&pool);
|
||||
let mut hash = None;
|
||||
|
||||
@@ -401,7 +409,7 @@ impl GlobalConnPool {
|
||||
Err(err)
|
||||
if hash_valid && err.to_string().contains("password authentication failed") =>
|
||||
{
|
||||
let pool = self.get_or_create_endpoint_pool(&conn_info.hostname);
|
||||
let pool = self.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key());
|
||||
let mut pool = pool.write();
|
||||
if let Some(entry) = pool.pools.get_mut(&conn_info.db_and_user()) {
|
||||
entry.password_hash = None;
|
||||
@@ -418,7 +426,7 @@ impl GlobalConnPool {
|
||||
})
|
||||
.await??;
|
||||
|
||||
let pool = self.get_or_create_endpoint_pool(&conn_info.hostname);
|
||||
let pool = self.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key());
|
||||
let mut pool = pool.write();
|
||||
pool.pools
|
||||
.entry(conn_info.db_and_user())
|
||||
@@ -520,23 +528,11 @@ async fn connect_to_compute(
|
||||
conn_id: uuid::Uuid,
|
||||
pool: Weak<RwLock<EndpointConnPool>>,
|
||||
) -> anyhow::Result<ClientInner> {
|
||||
let tls = config.tls_config.as_ref();
|
||||
let common_names = tls.and_then(|tls| tls.common_names.clone());
|
||||
|
||||
let params = StartupMessageParams::new([
|
||||
("user", &conn_info.username),
|
||||
("database", &conn_info.dbname),
|
||||
("application_name", APP_NAME),
|
||||
("options", conn_info.options.as_deref().unwrap_or("")),
|
||||
]);
|
||||
let creds =
|
||||
auth::ClientCredentials::parse(ctx, ¶ms, Some(&conn_info.hostname), common_names)?;
|
||||
|
||||
let creds =
|
||||
ComputeUserInfo::try_from(creds).map_err(|_| anyhow!("missing endpoint identifier"))?;
|
||||
let backend = config.auth_backend.as_ref().map(|_| creds);
|
||||
|
||||
let console_options = neon_options(¶ms);
|
||||
ctx.set_application(Some(APP_NAME));
|
||||
let backend = config
|
||||
.auth_backend
|
||||
.as_ref()
|
||||
.map(|_| conn_info.user_info.clone());
|
||||
|
||||
if !config.disable_ip_check_for_http {
|
||||
let allowed_ips = backend.get_allowed_ips(ctx).await?;
|
||||
@@ -544,11 +540,8 @@ async fn connect_to_compute(
|
||||
return Err(auth::AuthError::ip_address_not_allowed().into());
|
||||
}
|
||||
}
|
||||
let extra = console::ConsoleReqExtra {
|
||||
options: console_options,
|
||||
};
|
||||
let node_info = backend
|
||||
.wake_compute(ctx, &extra)
|
||||
.wake_compute(ctx)
|
||||
.await?
|
||||
.context("missing cache entry from wake_compute")?;
|
||||
|
||||
@@ -563,7 +556,6 @@ async fn connect_to_compute(
|
||||
idle: config.http_config.pool_options.idle_timeout,
|
||||
},
|
||||
node_info,
|
||||
&extra,
|
||||
&backend,
|
||||
)
|
||||
.await
|
||||
@@ -582,7 +574,7 @@ async fn connect_to_compute_once(
|
||||
let mut session = ctx.session_id;
|
||||
|
||||
let (client, mut connection) = config
|
||||
.user(&conn_info.username)
|
||||
.user(&conn_info.user_info.user)
|
||||
.password(&*conn_info.password)
|
||||
.dbname(&conn_info.dbname)
|
||||
.connect_timeout(timeout)
|
||||
|
||||
@@ -28,9 +28,13 @@ use url::Url;
|
||||
use utils::http::error::ApiError;
|
||||
use utils::http::json::json_response;
|
||||
|
||||
use crate::auth::backend::ComputeUserInfo;
|
||||
use crate::auth::endpoint_sni;
|
||||
use crate::config::HttpConfig;
|
||||
use crate::config::TlsConfig;
|
||||
use crate::context::RequestMonitoring;
|
||||
use crate::metrics::NUM_CONNECTION_REQUESTS_GAUGE;
|
||||
use crate::proxy::NeonOptions;
|
||||
|
||||
use super::conn_pool::ConnInfo;
|
||||
use super::conn_pool::GlobalConnPool;
|
||||
@@ -125,6 +129,7 @@ fn get_conn_info(
|
||||
ctx: &mut RequestMonitoring,
|
||||
headers: &HeaderMap,
|
||||
sni_hostname: Option<String>,
|
||||
tls: &TlsConfig,
|
||||
) -> Result<ConnInfo, anyhow::Error> {
|
||||
let connection_string = headers
|
||||
.get("Neon-Connection-String")
|
||||
@@ -179,8 +184,10 @@ fn get_conn_info(
|
||||
}
|
||||
}
|
||||
|
||||
let hostname: SmolStr = hostname.into();
|
||||
ctx.set_endpoint_id(Some(hostname.clone()));
|
||||
let endpoint = endpoint_sni(hostname, &tls.common_names)?;
|
||||
|
||||
let endpoint: SmolStr = endpoint.into();
|
||||
ctx.set_endpoint_id(Some(endpoint.clone()));
|
||||
|
||||
let pairs = connection_url.query_pairs();
|
||||
|
||||
@@ -188,22 +195,27 @@ fn get_conn_info(
|
||||
|
||||
for (key, value) in pairs {
|
||||
if key == "options" {
|
||||
options = Some(value.into());
|
||||
options = Some(NeonOptions::parse_options_raw(&value));
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let user_info = ComputeUserInfo {
|
||||
endpoint,
|
||||
user: username,
|
||||
options: options.unwrap_or_default(),
|
||||
};
|
||||
|
||||
Ok(ConnInfo {
|
||||
username,
|
||||
user_info,
|
||||
dbname: dbname.into(),
|
||||
hostname,
|
||||
password: password.into(),
|
||||
options,
|
||||
})
|
||||
}
|
||||
|
||||
// TODO: return different http error codes
|
||||
pub async fn handle(
|
||||
tls: &'static TlsConfig,
|
||||
config: &'static HttpConfig,
|
||||
ctx: &mut RequestMonitoring,
|
||||
request: Request<Body>,
|
||||
@@ -212,7 +224,7 @@ pub async fn handle(
|
||||
) -> Result<Response<Body>, ApiError> {
|
||||
let result = tokio::time::timeout(
|
||||
config.request_timeout,
|
||||
handle_inner(config, ctx, request, sni_hostname, conn_pool),
|
||||
handle_inner(tls, config, ctx, request, sni_hostname, conn_pool),
|
||||
)
|
||||
.await;
|
||||
let mut response = match result {
|
||||
@@ -294,6 +306,7 @@ pub async fn handle(
|
||||
|
||||
#[instrument(name = "sql-over-http", fields(pid = tracing::field::Empty), skip_all)]
|
||||
async fn handle_inner(
|
||||
tls: &'static TlsConfig,
|
||||
config: &'static HttpConfig,
|
||||
ctx: &mut RequestMonitoring,
|
||||
request: Request<Body>,
|
||||
@@ -308,7 +321,7 @@ async fn handle_inner(
|
||||
// Determine the destination and connection params
|
||||
//
|
||||
let headers = request.headers();
|
||||
let conn_info = get_conn_info(ctx, headers, sni_hostname)?;
|
||||
let conn_info = get_conn_info(ctx, headers, sni_hostname, tls)?;
|
||||
|
||||
// Determine the output options. Default behaviour is 'false'. Anything that is not
|
||||
// strictly 'true' assumed to be false.
|
||||
|
||||
Reference in New Issue
Block a user