mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-05 20:42:54 +00:00
proxy: make auth more type safe (#5689)
## Problema5292f7e67/proxy/src/auth/backend.rs (L146-L148)a5292f7e67/proxy/src/console/provider/neon.rs (L90)a5292f7e67/proxy/src/console/provider/neon.rs (L154)## Summary of changes 1. Test backend is only enabled on `cfg(test)`. 2. Postgres mock backend + MD5 auth keys are only enabled on `cfg(feature = testing)` 3. Password hack and cleartext flow will have their passwords validated before proceeding. 4. Distinguish between ClientCredentials with endpoint and without, removing many panics in the process
This commit is contained in:
@@ -4,6 +4,10 @@ version = "0.1.0"
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[features]
|
||||
default = []
|
||||
testing = []
|
||||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
async-trait.workspace = true
|
||||
@@ -69,6 +73,7 @@ webpki-roots.workspace = true
|
||||
x509-parser.workspace = true
|
||||
native-tls.workspace = true
|
||||
postgres-native-tls.workspace = true
|
||||
postgres-protocol.workspace = true
|
||||
smol_str.workspace = true
|
||||
|
||||
workspace_hack.workspace = true
|
||||
@@ -78,4 +83,3 @@ tokio-util.workspace = true
|
||||
rcgen.workspace = true
|
||||
rstest.workspace = true
|
||||
tokio-postgres-rustls.workspace = true
|
||||
postgres-protocol.workspace = true
|
||||
|
||||
@@ -3,9 +3,11 @@ mod hacks;
|
||||
mod link;
|
||||
|
||||
pub use link::LinkAuthError;
|
||||
use smol_str::SmolStr;
|
||||
use tokio_postgres::config::AuthKeys;
|
||||
|
||||
use crate::auth::credentials::check_peer_addr_is_in_list;
|
||||
use crate::auth::validate_password_and_exchange;
|
||||
use crate::console::errors::GetAuthInfoError;
|
||||
use crate::console::provider::AuthInfo;
|
||||
use crate::console::AuthSecret;
|
||||
@@ -24,31 +26,12 @@ use crate::{
|
||||
};
|
||||
use futures::TryFutureExt;
|
||||
use std::borrow::Cow;
|
||||
use std::net::IpAddr;
|
||||
use std::ops::ControlFlow;
|
||||
use std::sync::Arc;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing::{error, info, warn};
|
||||
|
||||
/// A product of successful authentication.
|
||||
pub struct AuthSuccess<T> {
|
||||
/// Did we send [`pq_proto::BeMessage::AuthenticationOk`] to client?
|
||||
pub reported_auth_ok: bool,
|
||||
/// Something to be considered a positive result.
|
||||
pub value: T,
|
||||
}
|
||||
|
||||
impl<T> AuthSuccess<T> {
|
||||
/// Very similar to [`std::option::Option::map`].
|
||||
/// Maps [`AuthSuccess<T>`] to [`AuthSuccess<R>`] by applying
|
||||
/// a function to a contained value.
|
||||
pub fn map<R>(self, f: impl FnOnce(T) -> R) -> AuthSuccess<R> {
|
||||
AuthSuccess {
|
||||
reported_auth_ok: self.reported_auth_ok,
|
||||
value: f(self.value),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// This type serves two purposes:
|
||||
///
|
||||
/// * When `T` is `()`, it's just a regular auth backend selector
|
||||
@@ -61,9 +44,11 @@ pub enum BackendType<'a, T> {
|
||||
/// Current Cloud API (V2).
|
||||
Console(Cow<'a, console::provider::neon::Api>, T),
|
||||
/// Local mock of Cloud API (V2).
|
||||
#[cfg(feature = "testing")]
|
||||
Postgres(Cow<'a, console::provider::mock::Api>, T),
|
||||
/// Authentication via a web browser.
|
||||
Link(Cow<'a, url::ApiUrl>),
|
||||
#[cfg(test)]
|
||||
/// Test backend.
|
||||
Test(&'a dyn TestBackend),
|
||||
}
|
||||
@@ -78,8 +63,10 @@ impl std::fmt::Display for BackendType<'_, ()> {
|
||||
use BackendType::*;
|
||||
match self {
|
||||
Console(endpoint, _) => fmt.debug_tuple("Console").field(&endpoint.url()).finish(),
|
||||
#[cfg(feature = "testing")]
|
||||
Postgres(endpoint, _) => fmt.debug_tuple("Postgres").field(&endpoint.url()).finish(),
|
||||
Link(url) => fmt.debug_tuple("Link").field(&url.as_str()).finish(),
|
||||
#[cfg(test)]
|
||||
Test(_) => fmt.debug_tuple("Test").finish(),
|
||||
}
|
||||
}
|
||||
@@ -92,8 +79,10 @@ impl<T> BackendType<'_, T> {
|
||||
use BackendType::*;
|
||||
match self {
|
||||
Console(c, x) => Console(Cow::Borrowed(c), x),
|
||||
#[cfg(feature = "testing")]
|
||||
Postgres(c, x) => Postgres(Cow::Borrowed(c), x),
|
||||
Link(c) => Link(Cow::Borrowed(c)),
|
||||
#[cfg(test)]
|
||||
Test(x) => Test(*x),
|
||||
}
|
||||
}
|
||||
@@ -107,8 +96,10 @@ impl<'a, T> BackendType<'a, T> {
|
||||
use BackendType::*;
|
||||
match self {
|
||||
Console(c, x) => Console(c, f(x)),
|
||||
#[cfg(feature = "testing")]
|
||||
Postgres(c, x) => Postgres(c, f(x)),
|
||||
Link(c) => Link(c),
|
||||
#[cfg(test)]
|
||||
Test(x) => Test(x),
|
||||
}
|
||||
}
|
||||
@@ -121,51 +112,87 @@ impl<'a, T, E> BackendType<'a, Result<T, E>> {
|
||||
use BackendType::*;
|
||||
match self {
|
||||
Console(c, x) => x.map(|x| Console(c, x)),
|
||||
#[cfg(feature = "testing")]
|
||||
Postgres(c, x) => x.map(|x| Postgres(c, x)),
|
||||
Link(c) => Ok(Link(c)),
|
||||
#[cfg(test)]
|
||||
Test(x) => Ok(Test(x)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub enum ComputeCredentials {
|
||||
pub struct ComputeCredentials<T> {
|
||||
pub info: ComputeUserInfo,
|
||||
pub keys: T,
|
||||
}
|
||||
|
||||
pub struct ComputeUserInfoNoEndpoint {
|
||||
pub user: SmolStr,
|
||||
pub peer_addr: IpAddr,
|
||||
pub cache_key: SmolStr,
|
||||
}
|
||||
|
||||
pub struct ComputeUserInfo {
|
||||
pub endpoint: SmolStr,
|
||||
pub inner: ComputeUserInfoNoEndpoint,
|
||||
}
|
||||
|
||||
pub enum ComputeCredentialKeys {
|
||||
#[cfg(feature = "testing")]
|
||||
Password(Vec<u8>),
|
||||
AuthKeys(AuthKeys),
|
||||
}
|
||||
|
||||
impl TryFrom<ClientCredentials> for ComputeUserInfo {
|
||||
// user name
|
||||
type Error = ComputeUserInfoNoEndpoint;
|
||||
|
||||
fn try_from(creds: ClientCredentials) -> Result<Self, Self::Error> {
|
||||
let inner = ComputeUserInfoNoEndpoint {
|
||||
user: creds.user,
|
||||
peer_addr: creds.peer_addr,
|
||||
cache_key: creds.cache_key,
|
||||
};
|
||||
match creds.project {
|
||||
None => Err(inner),
|
||||
Some(endpoint) => Ok(ComputeUserInfo { endpoint, inner }),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// True to its name, this function encapsulates our current auth trade-offs.
|
||||
/// Here, we choose the appropriate auth flow based on circumstances.
|
||||
async fn auth_quirks_creds(
|
||||
///
|
||||
/// All authentication flows will emit an AuthenticationOk message if successful.
|
||||
async fn auth_quirks(
|
||||
api: &impl console::Api,
|
||||
extra: &ConsoleReqExtra<'_>,
|
||||
creds: &mut ClientCredentials<'_>,
|
||||
creds: ClientCredentials,
|
||||
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
allow_cleartext: bool,
|
||||
config: &'static AuthenticationConfig,
|
||||
latency_timer: &mut LatencyTimer,
|
||||
) -> auth::Result<AuthSuccess<ComputeCredentials>> {
|
||||
) -> auth::Result<ComputeCredentials<ComputeCredentialKeys>> {
|
||||
// If there's no project so far, that entails that client doesn't
|
||||
// support SNI or other means of passing the endpoint (project) name.
|
||||
// We now expect to see a very specific payload in the place of password.
|
||||
let maybe_success = if creds.project.is_none() {
|
||||
// Password will be checked by the compute node later.
|
||||
Some(hacks::password_hack(creds, client, latency_timer).await?)
|
||||
} else {
|
||||
None
|
||||
let (info, unauthenticated_password) = match creds.try_into() {
|
||||
Err(info) => {
|
||||
let res = hacks::password_hack_no_authentication(info, client, latency_timer).await?;
|
||||
(res.info, Some(res.keys))
|
||||
}
|
||||
Ok(info) => (info, None),
|
||||
};
|
||||
|
||||
// Password hack should set the project name.
|
||||
// TODO: make `creds.project` more type-safe.
|
||||
assert!(creds.project.is_some());
|
||||
info!("fetching user's authentication info");
|
||||
// TODO(anna): this will slow down both "hacks" below; we probably need a cache.
|
||||
let AuthInfo {
|
||||
secret,
|
||||
allowed_ips,
|
||||
} = api.get_auth_info(extra, creds).await?;
|
||||
} = api.get_auth_info(extra, &info).await?;
|
||||
|
||||
// check allowed list
|
||||
if !check_peer_addr_is_in_list(&creds.peer_addr.ip(), &allowed_ips) {
|
||||
if !check_peer_addr_is_in_list(&info.inner.peer_addr, &allowed_ips) {
|
||||
return Err(auth::AuthError::ip_address_not_allowed());
|
||||
}
|
||||
let secret = secret.unwrap_or_else(|| {
|
||||
@@ -173,36 +200,49 @@ async fn auth_quirks_creds(
|
||||
// prevent malicious probing (possible due to missing protocol steps).
|
||||
// This mocked secret will never lead to successful authentication.
|
||||
info!("authentication info not found, mocking it");
|
||||
AuthSecret::Scram(scram::ServerSecret::mock(creds.user, rand::random()))
|
||||
AuthSecret::Scram(scram::ServerSecret::mock(&info.inner.user, rand::random()))
|
||||
});
|
||||
|
||||
if let Some(success) = maybe_success {
|
||||
return Ok(success);
|
||||
if let Some(password) = unauthenticated_password {
|
||||
let auth_outcome = validate_password_and_exchange(&password, secret)?;
|
||||
let keys = match auth_outcome {
|
||||
crate::sasl::Outcome::Success(key) => key,
|
||||
crate::sasl::Outcome::Failure(reason) => {
|
||||
info!("auth backend failed with an error: {reason}");
|
||||
return Err(auth::AuthError::auth_failed(&*info.inner.user));
|
||||
}
|
||||
};
|
||||
|
||||
// we have authenticated the password
|
||||
client.write_message_noflush(&pq_proto::BeMessage::AuthenticationOk)?;
|
||||
|
||||
return Ok(ComputeCredentials { info, keys });
|
||||
}
|
||||
|
||||
// -- the remaining flows are self-authenticating --
|
||||
|
||||
// Perform cleartext auth if we're allowed to do that.
|
||||
// Currently, we use it for websocket connections (latency).
|
||||
if allow_cleartext {
|
||||
// Password will be checked by the compute node later.
|
||||
return hacks::cleartext_hack(client, latency_timer).await;
|
||||
return hacks::authenticate_cleartext(info, client, latency_timer, secret).await;
|
||||
}
|
||||
|
||||
// Finally, proceed with the main auth flow (SCRAM-based).
|
||||
classic::authenticate(creds, client, config, latency_timer, secret).await
|
||||
classic::authenticate(info, client, config, latency_timer, secret).await
|
||||
}
|
||||
|
||||
/// True to its name, this function encapsulates our current auth trade-offs.
|
||||
/// Here, we choose the appropriate auth flow based on circumstances.
|
||||
async fn auth_quirks(
|
||||
/// Authenticate the user and then wake a compute (or retrieve an existing compute session from cache)
|
||||
/// only if authentication was successfuly.
|
||||
async fn auth_and_wake_compute(
|
||||
api: &impl console::Api,
|
||||
extra: &ConsoleReqExtra<'_>,
|
||||
creds: &mut ClientCredentials<'_>,
|
||||
creds: ClientCredentials,
|
||||
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
allow_cleartext: bool,
|
||||
config: &'static AuthenticationConfig,
|
||||
latency_timer: &mut LatencyTimer,
|
||||
) -> auth::Result<AuthSuccess<CachedNodeInfo>> {
|
||||
let auth_stuff = auth_quirks_creds(
|
||||
) -> auth::Result<(CachedNodeInfo, ComputeUserInfo)> {
|
||||
let compute_credentials = auth_quirks(
|
||||
api,
|
||||
extra,
|
||||
creds,
|
||||
@@ -215,7 +255,7 @@ async fn auth_quirks(
|
||||
|
||||
let mut num_retries = 0;
|
||||
let mut node = loop {
|
||||
let wake_res = api.wake_compute(extra, creds).await;
|
||||
let wake_res = api.wake_compute(extra, &compute_credentials.info).await;
|
||||
match handle_try_wake(wake_res, num_retries) {
|
||||
Err(e) => {
|
||||
error!(error = ?e, num_retries, retriable = false, "couldn't wake compute node");
|
||||
@@ -232,27 +272,27 @@ async fn auth_quirks(
|
||||
tokio::time::sleep(wait_duration).await;
|
||||
};
|
||||
|
||||
match auth_stuff.value {
|
||||
ComputeCredentials::Password(password) => node.config.password(password),
|
||||
ComputeCredentials::AuthKeys(auth_keys) => node.config.auth_keys(auth_keys),
|
||||
match compute_credentials.keys {
|
||||
#[cfg(feature = "testing")]
|
||||
ComputeCredentialKeys::Password(password) => node.config.password(password),
|
||||
ComputeCredentialKeys::AuthKeys(auth_keys) => node.config.auth_keys(auth_keys),
|
||||
};
|
||||
|
||||
Ok(AuthSuccess {
|
||||
reported_auth_ok: auth_stuff.reported_auth_ok,
|
||||
value: node,
|
||||
})
|
||||
Ok((node, compute_credentials.info))
|
||||
}
|
||||
|
||||
impl BackendType<'_, ClientCredentials<'_>> {
|
||||
impl<'a> BackendType<'a, ClientCredentials> {
|
||||
/// Get compute endpoint name from the credentials.
|
||||
pub fn get_endpoint(&self) -> Option<String> {
|
||||
pub fn get_endpoint(&self) -> Option<SmolStr> {
|
||||
use BackendType::*;
|
||||
|
||||
match self {
|
||||
Console(_, creds) => creds.project.clone(),
|
||||
#[cfg(feature = "testing")]
|
||||
Postgres(_, creds) => creds.project.clone(),
|
||||
Link(_) => Some("link".to_owned()),
|
||||
Test(_) => Some("test".to_owned()),
|
||||
Link(_) => Some("link".into()),
|
||||
#[cfg(test)]
|
||||
Test(_) => Some("test".into()),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -261,9 +301,11 @@ impl BackendType<'_, ClientCredentials<'_>> {
|
||||
use BackendType::*;
|
||||
|
||||
match self {
|
||||
Console(_, creds) => creds.user,
|
||||
Postgres(_, creds) => creds.user,
|
||||
Console(_, creds) => &creds.user,
|
||||
#[cfg(feature = "testing")]
|
||||
Postgres(_, creds) => &creds.user,
|
||||
Link(_) => "link",
|
||||
#[cfg(test)]
|
||||
Test(_) => "test",
|
||||
}
|
||||
}
|
||||
@@ -271,26 +313,25 @@ impl BackendType<'_, ClientCredentials<'_>> {
|
||||
/// Authenticate the client via the requested backend, possibly using credentials.
|
||||
#[tracing::instrument(fields(allow_cleartext = allow_cleartext), skip_all)]
|
||||
pub async fn authenticate(
|
||||
&mut self,
|
||||
self,
|
||||
extra: &ConsoleReqExtra<'_>,
|
||||
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
allow_cleartext: bool,
|
||||
config: &'static AuthenticationConfig,
|
||||
latency_timer: &mut LatencyTimer,
|
||||
) -> auth::Result<AuthSuccess<CachedNodeInfo>> {
|
||||
) -> auth::Result<(CachedNodeInfo, BackendType<'a, ComputeUserInfo>)> {
|
||||
use BackendType::*;
|
||||
|
||||
let res = match self {
|
||||
Console(api, creds) => {
|
||||
info!(
|
||||
user = creds.user,
|
||||
user = &*creds.user,
|
||||
project = creds.project(),
|
||||
"performing authentication using the console"
|
||||
);
|
||||
|
||||
let api = api.as_ref();
|
||||
auth_quirks(
|
||||
api,
|
||||
let (cache_info, user_info) = auth_and_wake_compute(
|
||||
&*api,
|
||||
extra,
|
||||
creds,
|
||||
client,
|
||||
@@ -298,18 +339,19 @@ impl BackendType<'_, ClientCredentials<'_>> {
|
||||
config,
|
||||
latency_timer,
|
||||
)
|
||||
.await?
|
||||
.await?;
|
||||
(cache_info, BackendType::Console(api, user_info))
|
||||
}
|
||||
#[cfg(feature = "testing")]
|
||||
Postgres(api, creds) => {
|
||||
info!(
|
||||
user = creds.user,
|
||||
user = &*creds.user,
|
||||
project = creds.project(),
|
||||
"performing authentication using a local postgres instance"
|
||||
);
|
||||
|
||||
let api = api.as_ref();
|
||||
auth_quirks(
|
||||
api,
|
||||
let (cache_info, user_info) = auth_and_wake_compute(
|
||||
&*api,
|
||||
extra,
|
||||
creds,
|
||||
client,
|
||||
@@ -317,16 +359,21 @@ impl BackendType<'_, ClientCredentials<'_>> {
|
||||
config,
|
||||
latency_timer,
|
||||
)
|
||||
.await?
|
||||
.await?;
|
||||
(cache_info, BackendType::Postgres(api, user_info))
|
||||
}
|
||||
// NOTE: this auth backend doesn't use client credentials.
|
||||
Link(url) => {
|
||||
info!("performing link authentication");
|
||||
|
||||
link::authenticate(url, client)
|
||||
.await?
|
||||
.map(CachedNodeInfo::new_uncached)
|
||||
let node_info = link::authenticate(&url, client).await?;
|
||||
|
||||
(
|
||||
CachedNodeInfo::new_uncached(node_info),
|
||||
BackendType::Link(url),
|
||||
)
|
||||
}
|
||||
#[cfg(test)]
|
||||
Test(_) => {
|
||||
unreachable!("this function should never be called in the test backend")
|
||||
}
|
||||
@@ -335,7 +382,9 @@ impl BackendType<'_, ClientCredentials<'_>> {
|
||||
info!("user successfully authenticated");
|
||||
Ok(res)
|
||||
}
|
||||
}
|
||||
|
||||
impl BackendType<'_, ComputeUserInfo> {
|
||||
pub async fn get_allowed_ips(
|
||||
&self,
|
||||
extra: &ConsoleReqExtra<'_>,
|
||||
@@ -343,8 +392,10 @@ impl BackendType<'_, ClientCredentials<'_>> {
|
||||
use BackendType::*;
|
||||
match self {
|
||||
Console(api, creds) => api.get_allowed_ips(extra, creds).await,
|
||||
#[cfg(feature = "testing")]
|
||||
Postgres(api, creds) => api.get_allowed_ips(extra, creds).await,
|
||||
Link(_) => Ok(Arc::new(vec![])),
|
||||
#[cfg(test)]
|
||||
Test(x) => x.get_allowed_ips(),
|
||||
}
|
||||
}
|
||||
@@ -359,8 +410,10 @@ impl BackendType<'_, ClientCredentials<'_>> {
|
||||
|
||||
match self {
|
||||
Console(api, creds) => api.wake_compute(extra, creds).map_ok(Some).await,
|
||||
#[cfg(feature = "testing")]
|
||||
Postgres(api, creds) => api.wake_compute(extra, creds).map_ok(Some).await,
|
||||
Link(_) => Ok(None),
|
||||
#[cfg(test)]
|
||||
Test(x) => x.wake_compute().map(Some),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use super::{AuthSuccess, ComputeCredentials};
|
||||
use super::{ComputeCredentials, ComputeUserInfo};
|
||||
use crate::{
|
||||
auth::{self, AuthFlow, ClientCredentials},
|
||||
auth::{self, backend::ComputeCredentialKeys, AuthFlow},
|
||||
compute,
|
||||
config::AuthenticationConfig,
|
||||
console::AuthSecret,
|
||||
@@ -12,14 +12,15 @@ use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing::{info, warn};
|
||||
|
||||
pub(super) async fn authenticate(
|
||||
creds: &ClientCredentials<'_>,
|
||||
creds: ComputeUserInfo,
|
||||
client: &mut PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
config: &'static AuthenticationConfig,
|
||||
latency_timer: &mut LatencyTimer,
|
||||
secret: AuthSecret,
|
||||
) -> auth::Result<AuthSuccess<ComputeCredentials>> {
|
||||
) -> auth::Result<ComputeCredentials<ComputeCredentialKeys>> {
|
||||
let flow = AuthFlow::new(client);
|
||||
let scram_keys = match secret {
|
||||
#[cfg(feature = "testing")]
|
||||
AuthSecret::Md5(_) => {
|
||||
info!("auth endpoint chooses MD5");
|
||||
return Err(auth::AuthError::bad_auth_method("MD5"));
|
||||
@@ -53,7 +54,7 @@ pub(super) async fn authenticate(
|
||||
sasl::Outcome::Success(key) => key,
|
||||
sasl::Outcome::Failure(reason) => {
|
||||
info!("auth backend failed with an error: {reason}");
|
||||
return Err(auth::AuthError::auth_failed(creds.user));
|
||||
return Err(auth::AuthError::auth_failed(&*creds.inner.user));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -64,9 +65,9 @@ pub(super) async fn authenticate(
|
||||
}
|
||||
};
|
||||
|
||||
Ok(AuthSuccess {
|
||||
reported_auth_ok: false,
|
||||
value: ComputeCredentials::AuthKeys(tokio_postgres::config::AuthKeys::ScramSha256(
|
||||
Ok(ComputeCredentials {
|
||||
info: creds,
|
||||
keys: ComputeCredentialKeys::AuthKeys(tokio_postgres::config::AuthKeys::ScramSha256(
|
||||
scram_keys,
|
||||
)),
|
||||
})
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
use super::{AuthSuccess, ComputeCredentials};
|
||||
use super::{
|
||||
ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo, ComputeUserInfoNoEndpoint,
|
||||
};
|
||||
use crate::{
|
||||
auth::{self, AuthFlow, ClientCredentials},
|
||||
auth::{self, AuthFlow},
|
||||
console::AuthSecret,
|
||||
proxy::LatencyTimer,
|
||||
sasl,
|
||||
stream::{self, Stream},
|
||||
};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
@@ -11,35 +15,42 @@ use tracing::{info, warn};
|
||||
/// one round trip and *expensive* computations (>= 4096 HMAC iterations).
|
||||
/// These properties are benefical for serverless JS workers, so we
|
||||
/// use this mechanism for websocket connections.
|
||||
pub async fn cleartext_hack(
|
||||
pub async fn authenticate_cleartext(
|
||||
info: ComputeUserInfo,
|
||||
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
latency_timer: &mut LatencyTimer,
|
||||
) -> auth::Result<AuthSuccess<ComputeCredentials>> {
|
||||
secret: AuthSecret,
|
||||
) -> auth::Result<ComputeCredentials<ComputeCredentialKeys>> {
|
||||
warn!("cleartext auth flow override is enabled, proceeding");
|
||||
|
||||
// pause the timer while we communicate with the client
|
||||
let _paused = latency_timer.pause();
|
||||
|
||||
let password = AuthFlow::new(client)
|
||||
.begin(auth::CleartextPassword)
|
||||
let auth_outcome = AuthFlow::new(client)
|
||||
.begin(auth::CleartextPassword(secret))
|
||||
.await?
|
||||
.authenticate()
|
||||
.await?;
|
||||
|
||||
// Report tentative success; compute node will check the password anyway.
|
||||
Ok(AuthSuccess {
|
||||
reported_auth_ok: false,
|
||||
value: ComputeCredentials::Password(password),
|
||||
})
|
||||
let keys = match auth_outcome {
|
||||
sasl::Outcome::Success(key) => key,
|
||||
sasl::Outcome::Failure(reason) => {
|
||||
info!("auth backend failed with an error: {reason}");
|
||||
return Err(auth::AuthError::auth_failed(&*info.inner.user));
|
||||
}
|
||||
};
|
||||
|
||||
Ok(ComputeCredentials { info, keys })
|
||||
}
|
||||
|
||||
/// Workaround for clients which don't provide an endpoint (project) name.
|
||||
/// Very similar to [`cleartext_hack`], but there's a specific password format.
|
||||
pub async fn password_hack(
|
||||
creds: &mut ClientCredentials<'_>,
|
||||
/// Similar to [`authenticate_cleartext`], but there's a specific password format,
|
||||
/// and passwords are not yet validated (we don't know how to validate them!)
|
||||
pub async fn password_hack_no_authentication(
|
||||
info: ComputeUserInfoNoEndpoint,
|
||||
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
latency_timer: &mut LatencyTimer,
|
||||
) -> auth::Result<AuthSuccess<ComputeCredentials>> {
|
||||
) -> auth::Result<ComputeCredentials<Vec<u8>>> {
|
||||
warn!("project not specified, resorting to the password hack auth flow");
|
||||
|
||||
// pause the timer while we communicate with the client
|
||||
@@ -48,15 +59,17 @@ pub async fn password_hack(
|
||||
let payload = AuthFlow::new(client)
|
||||
.begin(auth::PasswordHack)
|
||||
.await?
|
||||
.authenticate()
|
||||
.get_password()
|
||||
.await?;
|
||||
|
||||
info!(project = &payload.endpoint, "received missing parameter");
|
||||
creds.project = Some(payload.endpoint);
|
||||
info!(project = &*payload.endpoint, "received missing parameter");
|
||||
|
||||
// Report tentative success; compute node will check the password anyway.
|
||||
Ok(AuthSuccess {
|
||||
reported_auth_ok: false,
|
||||
value: ComputeCredentials::Password(payload.password),
|
||||
Ok(ComputeCredentials {
|
||||
info: ComputeUserInfo {
|
||||
inner: info,
|
||||
endpoint: payload.endpoint,
|
||||
},
|
||||
keys: payload.password,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
use super::AuthSuccess;
|
||||
use crate::{
|
||||
auth, compute,
|
||||
console::{self, provider::NodeInfo},
|
||||
@@ -57,7 +56,7 @@ pub fn new_psql_session_id() -> String {
|
||||
pub(super) async fn authenticate(
|
||||
link_uri: &reqwest::Url,
|
||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
) -> auth::Result<AuthSuccess<NodeInfo>> {
|
||||
) -> auth::Result<NodeInfo> {
|
||||
let psql_session_id = new_psql_session_id();
|
||||
let span = info_span!("link", psql_session_id = &psql_session_id);
|
||||
let greeting = hello_message(link_uri, &psql_session_id);
|
||||
@@ -102,12 +101,9 @@ pub(super) async fn authenticate(
|
||||
config.password(password.as_ref());
|
||||
}
|
||||
|
||||
Ok(AuthSuccess {
|
||||
reported_auth_ok: true,
|
||||
value: NodeInfo {
|
||||
config,
|
||||
aux: db_info.aux,
|
||||
allow_self_signed_compute: false, // caller may override
|
||||
},
|
||||
Ok(NodeInfo {
|
||||
config,
|
||||
aux: db_info.aux,
|
||||
allow_self_signed_compute: false, // caller may override
|
||||
})
|
||||
}
|
||||
|
||||
@@ -7,10 +7,8 @@ use crate::{
|
||||
};
|
||||
use itertools::Itertools;
|
||||
use pq_proto::StartupMessageParams;
|
||||
use std::{
|
||||
collections::HashSet,
|
||||
net::{IpAddr, SocketAddr},
|
||||
};
|
||||
use smol_str::SmolStr;
|
||||
use std::{collections::HashSet, net::IpAddr};
|
||||
use thiserror::Error;
|
||||
use tracing::{info, warn};
|
||||
|
||||
@@ -24,7 +22,7 @@ pub enum ClientCredsParseError {
|
||||
SNI ('{}') and project option ('{}').",
|
||||
.domain, .option,
|
||||
)]
|
||||
InconsistentProjectNames { domain: String, option: String },
|
||||
InconsistentProjectNames { domain: SmolStr, option: SmolStr },
|
||||
|
||||
#[error(
|
||||
"Common name inferred from SNI ('{}') is not known",
|
||||
@@ -33,7 +31,7 @@ pub enum ClientCredsParseError {
|
||||
UnknownCommonName { cn: String },
|
||||
|
||||
#[error("Project name ('{0}') must contain only alphanumeric characters and hyphen.")]
|
||||
MalformedProjectName(String),
|
||||
MalformedProjectName(SmolStr),
|
||||
}
|
||||
|
||||
impl UserFacingError for ClientCredsParseError {}
|
||||
@@ -41,34 +39,34 @@ impl UserFacingError for ClientCredsParseError {}
|
||||
/// Various client credentials which we use for authentication.
|
||||
/// Note that we don't store any kind of client key or password here.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct ClientCredentials<'a> {
|
||||
pub user: &'a str,
|
||||
pub struct ClientCredentials {
|
||||
pub user: SmolStr,
|
||||
// TODO: this is a severe misnomer! We should think of a new name ASAP.
|
||||
pub project: Option<String>,
|
||||
pub project: Option<SmolStr>,
|
||||
|
||||
pub cache_key: String,
|
||||
pub peer_addr: SocketAddr,
|
||||
pub cache_key: SmolStr,
|
||||
pub peer_addr: IpAddr,
|
||||
}
|
||||
|
||||
impl ClientCredentials<'_> {
|
||||
impl ClientCredentials {
|
||||
#[inline]
|
||||
pub fn project(&self) -> Option<&str> {
|
||||
self.project.as_deref()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> ClientCredentials<'a> {
|
||||
impl ClientCredentials {
|
||||
pub fn parse(
|
||||
params: &'a StartupMessageParams,
|
||||
params: &StartupMessageParams,
|
||||
sni: Option<&str>,
|
||||
common_names: Option<HashSet<String>>,
|
||||
peer_addr: SocketAddr,
|
||||
peer_addr: IpAddr,
|
||||
) -> Result<Self, ClientCredsParseError> {
|
||||
use ClientCredsParseError::*;
|
||||
|
||||
// Some parameters are stored in the startup message.
|
||||
let get_param = |key| params.get(key).ok_or(MissingKey(key));
|
||||
let user = get_param("user")?;
|
||||
let user = get_param("user")?.into();
|
||||
|
||||
// Project name might be passed via PG's command-line options.
|
||||
let project_option = params
|
||||
@@ -82,7 +80,7 @@ impl<'a> ClientCredentials<'a> {
|
||||
.at_most_one()
|
||||
.ok()?
|
||||
})
|
||||
.map(|name| name.to_string());
|
||||
.map(|name| name.into());
|
||||
|
||||
let project_from_domain = if let Some(sni_str) = sni {
|
||||
if let Some(cn) = common_names {
|
||||
@@ -121,7 +119,7 @@ impl<'a> ClientCredentials<'a> {
|
||||
}
|
||||
.transpose()?;
|
||||
|
||||
info!(user, project = project.as_deref(), "credentials");
|
||||
info!(%user, project = project.as_deref(), "credentials");
|
||||
if sni.is_some() {
|
||||
info!("Connection with sni");
|
||||
NUM_CONNECTION_ACCEPTED_BY_SNI
|
||||
@@ -143,7 +141,8 @@ impl<'a> ClientCredentials<'a> {
|
||||
"{}{}",
|
||||
project.as_deref().unwrap_or(""),
|
||||
neon_options(params).unwrap_or("".to_string())
|
||||
);
|
||||
)
|
||||
.into();
|
||||
|
||||
Ok(Self {
|
||||
user,
|
||||
@@ -206,10 +205,10 @@ fn project_name_valid(name: &str) -> bool {
|
||||
name.chars().all(|c| c.is_alphanumeric() || c == '-')
|
||||
}
|
||||
|
||||
fn subdomain_from_sni(sni: &str, common_name: &str) -> Option<String> {
|
||||
fn subdomain_from_sni(sni: &str, common_name: &str) -> Option<SmolStr> {
|
||||
sni.strip_suffix(common_name)?
|
||||
.strip_suffix('.')
|
||||
.map(str::to_owned)
|
||||
.map(SmolStr::from)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -221,7 +220,7 @@ mod tests {
|
||||
fn parse_bare_minimum() -> anyhow::Result<()> {
|
||||
// According to postgresql, only `user` should be required.
|
||||
let options = StartupMessageParams::new([("user", "john_doe")]);
|
||||
let peer_addr = SocketAddr::from(([127, 0, 0, 1], 1234));
|
||||
let peer_addr = IpAddr::from([127, 0, 0, 1]);
|
||||
let creds = ClientCredentials::parse(&options, None, None, peer_addr)?;
|
||||
assert_eq!(creds.user, "john_doe");
|
||||
assert_eq!(creds.project, None);
|
||||
@@ -236,7 +235,7 @@ mod tests {
|
||||
("database", "world"), // should be ignored
|
||||
("foo", "bar"), // should be ignored
|
||||
]);
|
||||
let peer_addr = SocketAddr::from(([127, 0, 0, 1], 1234));
|
||||
let peer_addr = IpAddr::from([127, 0, 0, 1]);
|
||||
let creds = ClientCredentials::parse(&options, None, None, peer_addr)?;
|
||||
assert_eq!(creds.user, "john_doe");
|
||||
assert_eq!(creds.project, None);
|
||||
@@ -251,7 +250,7 @@ mod tests {
|
||||
let sni = Some("foo.localhost");
|
||||
let common_names = Some(["localhost".into()].into());
|
||||
|
||||
let peer_addr = SocketAddr::from(([127, 0, 0, 1], 1234));
|
||||
let peer_addr = IpAddr::from([127, 0, 0, 1]);
|
||||
let creds = ClientCredentials::parse(&options, sni, common_names, peer_addr)?;
|
||||
assert_eq!(creds.user, "john_doe");
|
||||
assert_eq!(creds.project.as_deref(), Some("foo"));
|
||||
@@ -267,7 +266,7 @@ mod tests {
|
||||
("options", "-ckey=1 project=bar -c geqo=off"),
|
||||
]);
|
||||
|
||||
let peer_addr = SocketAddr::from(([127, 0, 0, 1], 1234));
|
||||
let peer_addr = IpAddr::from([127, 0, 0, 1]);
|
||||
let creds = ClientCredentials::parse(&options, None, None, peer_addr)?;
|
||||
assert_eq!(creds.user, "john_doe");
|
||||
assert_eq!(creds.project.as_deref(), Some("bar"));
|
||||
@@ -282,7 +281,7 @@ mod tests {
|
||||
("options", "-ckey=1 endpoint=bar -c geqo=off"),
|
||||
]);
|
||||
|
||||
let peer_addr = SocketAddr::from(([127, 0, 0, 1], 1234));
|
||||
let peer_addr = IpAddr::from([127, 0, 0, 1]);
|
||||
let creds = ClientCredentials::parse(&options, None, None, peer_addr)?;
|
||||
assert_eq!(creds.user, "john_doe");
|
||||
assert_eq!(creds.project.as_deref(), Some("bar"));
|
||||
@@ -300,7 +299,7 @@ mod tests {
|
||||
),
|
||||
]);
|
||||
|
||||
let peer_addr = SocketAddr::from(([127, 0, 0, 1], 1234));
|
||||
let peer_addr = IpAddr::from([127, 0, 0, 1]);
|
||||
let creds = ClientCredentials::parse(&options, None, None, peer_addr)?;
|
||||
assert_eq!(creds.user, "john_doe");
|
||||
assert!(creds.project.is_none());
|
||||
@@ -315,7 +314,7 @@ mod tests {
|
||||
("options", "-ckey=1 endpoint=bar project=foo -c geqo=off"),
|
||||
]);
|
||||
|
||||
let peer_addr = SocketAddr::from(([127, 0, 0, 1], 1234));
|
||||
let peer_addr = IpAddr::from([127, 0, 0, 1]);
|
||||
let creds = ClientCredentials::parse(&options, None, None, peer_addr)?;
|
||||
assert_eq!(creds.user, "john_doe");
|
||||
assert!(creds.project.is_none());
|
||||
@@ -330,7 +329,7 @@ mod tests {
|
||||
let sni = Some("baz.localhost");
|
||||
let common_names = Some(["localhost".into()].into());
|
||||
|
||||
let peer_addr = SocketAddr::from(([127, 0, 0, 1], 1234));
|
||||
let peer_addr = IpAddr::from([127, 0, 0, 1]);
|
||||
let creds = ClientCredentials::parse(&options, sni, common_names, peer_addr)?;
|
||||
assert_eq!(creds.user, "john_doe");
|
||||
assert_eq!(creds.project.as_deref(), Some("baz"));
|
||||
@@ -344,13 +343,13 @@ mod tests {
|
||||
|
||||
let common_names = Some(["a.com".into(), "b.com".into()].into());
|
||||
let sni = Some("p1.a.com");
|
||||
let peer_addr = SocketAddr::from(([127, 0, 0, 1], 1234));
|
||||
let peer_addr = IpAddr::from([127, 0, 0, 1]);
|
||||
let creds = ClientCredentials::parse(&options, sni, common_names, peer_addr)?;
|
||||
assert_eq!(creds.project.as_deref(), Some("p1"));
|
||||
|
||||
let common_names = Some(["a.com".into(), "b.com".into()].into());
|
||||
let sni = Some("p1.b.com");
|
||||
let peer_addr = SocketAddr::from(([127, 0, 0, 1], 1234));
|
||||
let peer_addr = IpAddr::from([127, 0, 0, 1]);
|
||||
let creds = ClientCredentials::parse(&options, sni, common_names, peer_addr)?;
|
||||
assert_eq!(creds.project.as_deref(), Some("p1"));
|
||||
|
||||
@@ -365,7 +364,7 @@ mod tests {
|
||||
let sni = Some("second.localhost");
|
||||
let common_names = Some(["localhost".into()].into());
|
||||
|
||||
let peer_addr = SocketAddr::from(([127, 0, 0, 1], 1234));
|
||||
let peer_addr = IpAddr::from([127, 0, 0, 1]);
|
||||
let err = ClientCredentials::parse(&options, sni, common_names, peer_addr)
|
||||
.expect_err("should fail");
|
||||
match err {
|
||||
@@ -384,7 +383,7 @@ mod tests {
|
||||
let sni = Some("project.localhost");
|
||||
let common_names = Some(["example.com".into()].into());
|
||||
|
||||
let peer_addr = SocketAddr::from(([127, 0, 0, 1], 1234));
|
||||
let peer_addr = IpAddr::from([127, 0, 0, 1]);
|
||||
let err = ClientCredentials::parse(&options, sni, common_names, peer_addr)
|
||||
.expect_err("should fail");
|
||||
match err {
|
||||
@@ -404,7 +403,7 @@ mod tests {
|
||||
|
||||
let sni = Some("project.localhost");
|
||||
let common_names = Some(["localhost".into()].into());
|
||||
let peer_addr = SocketAddr::from(([127, 0, 0, 1], 1234));
|
||||
let peer_addr = IpAddr::from([127, 0, 0, 1]);
|
||||
let creds = ClientCredentials::parse(&options, sni, common_names, peer_addr)?;
|
||||
assert_eq!(creds.project.as_deref(), Some("project"));
|
||||
assert_eq!(
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
//! Main authentication flow.
|
||||
|
||||
use super::{AuthErrorImpl, PasswordHackPayload};
|
||||
use super::{backend::ComputeCredentialKeys, AuthErrorImpl, PasswordHackPayload};
|
||||
use crate::{
|
||||
config::TlsServerEndPoint,
|
||||
console::AuthSecret,
|
||||
sasl, scram,
|
||||
stream::{PqStream, Stream},
|
||||
};
|
||||
@@ -50,7 +51,7 @@ impl AuthMethod for PasswordHack {
|
||||
|
||||
/// Use clear-text password auth called `password` in docs
|
||||
/// <https://www.postgresql.org/docs/current/auth-password.html>
|
||||
pub struct CleartextPassword;
|
||||
pub struct CleartextPassword(pub AuthSecret);
|
||||
|
||||
impl AuthMethod for CleartextPassword {
|
||||
#[inline(always)]
|
||||
@@ -98,7 +99,7 @@ impl<'a, S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'a, S, Begin> {
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, PasswordHack> {
|
||||
/// Perform user authentication. Raise an error in case authentication failed.
|
||||
pub async fn authenticate(self) -> super::Result<PasswordHackPayload> {
|
||||
pub async fn get_password(self) -> super::Result<PasswordHackPayload> {
|
||||
let msg = self.stream.read_password_message().await?;
|
||||
let password = msg
|
||||
.strip_suffix(&[0])
|
||||
@@ -117,13 +118,19 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, PasswordHack> {
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, CleartextPassword> {
|
||||
/// Perform user authentication. Raise an error in case authentication failed.
|
||||
pub async fn authenticate(self) -> super::Result<Vec<u8>> {
|
||||
pub async fn authenticate(self) -> super::Result<sasl::Outcome<ComputeCredentialKeys>> {
|
||||
let msg = self.stream.read_password_message().await?;
|
||||
let password = msg
|
||||
.strip_suffix(&[0])
|
||||
.ok_or(AuthErrorImpl::MalformedPassword("missing terminator"))?;
|
||||
|
||||
Ok(password.to_vec())
|
||||
let outcome = validate_password_and_exchange(password, self.state.0)?;
|
||||
|
||||
if let sasl::Outcome::Success(_) = &outcome {
|
||||
self.stream.write_message_noflush(&Be::AuthenticationOk)?;
|
||||
}
|
||||
|
||||
Ok(outcome)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -152,6 +159,49 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
|
||||
))
|
||||
.await?;
|
||||
|
||||
if let sasl::Outcome::Success(_) = &outcome {
|
||||
self.stream.write_message_noflush(&Be::AuthenticationOk)?;
|
||||
}
|
||||
|
||||
Ok(outcome)
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn validate_password_and_exchange(
|
||||
password: &[u8],
|
||||
secret: AuthSecret,
|
||||
) -> super::Result<sasl::Outcome<ComputeCredentialKeys>> {
|
||||
match secret {
|
||||
#[cfg(feature = "testing")]
|
||||
AuthSecret::Md5(_) => {
|
||||
// test only
|
||||
Ok(sasl::Outcome::Success(ComputeCredentialKeys::Password(
|
||||
password.to_owned(),
|
||||
)))
|
||||
}
|
||||
// perform scram authentication as both client and server to validate the keys
|
||||
AuthSecret::Scram(scram_secret) => {
|
||||
use postgres_protocol::authentication::sasl::{ChannelBinding, ScramSha256};
|
||||
let sasl_client = ScramSha256::new(password, ChannelBinding::unsupported());
|
||||
let outcome = crate::scram::exchange(
|
||||
&scram_secret,
|
||||
sasl_client,
|
||||
crate::config::TlsServerEndPoint::Undefined,
|
||||
)?;
|
||||
|
||||
let client_key = match outcome {
|
||||
sasl::Outcome::Success(client_key) => client_key,
|
||||
sasl::Outcome::Failure(reason) => return Ok(sasl::Outcome::Failure(reason)),
|
||||
};
|
||||
|
||||
let keys = crate::compute::ScramKeys {
|
||||
client_key: client_key.as_bytes(),
|
||||
server_key: scram_secret.server_key.as_bytes(),
|
||||
};
|
||||
|
||||
Ok(sasl::Outcome::Success(ComputeCredentialKeys::AuthKeys(
|
||||
tokio_postgres::config::AuthKeys::ScramSha256(keys),
|
||||
)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,9 +4,10 @@
|
||||
//! UPDATE (Mon Aug 8 13:20:34 UTC 2022): the payload format has been simplified.
|
||||
|
||||
use bstr::ByteSlice;
|
||||
use smol_str::SmolStr;
|
||||
|
||||
pub struct PasswordHackPayload {
|
||||
pub endpoint: String,
|
||||
pub endpoint: SmolStr,
|
||||
pub password: Vec<u8>,
|
||||
}
|
||||
|
||||
@@ -18,7 +19,7 @@ impl PasswordHackPayload {
|
||||
if let Some((endpoint, password)) = bytes.split_once_str(sep) {
|
||||
let endpoint = endpoint.to_str().ok()?;
|
||||
return Some(Self {
|
||||
endpoint: parse_endpoint_param(endpoint)?.to_owned(),
|
||||
endpoint: parse_endpoint_param(endpoint)?.into(),
|
||||
password: password.to_owned(),
|
||||
});
|
||||
}
|
||||
|
||||
@@ -30,6 +30,7 @@ use clap::{Parser, ValueEnum};
|
||||
#[derive(Clone, Debug, ValueEnum)]
|
||||
enum AuthBackend {
|
||||
Console,
|
||||
#[cfg(feature = "testing")]
|
||||
Postgres,
|
||||
Link,
|
||||
}
|
||||
@@ -289,6 +290,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
|
||||
let api = console::provider::neon::Api::new(endpoint, caches, locks);
|
||||
auth::BackendType::Console(Cow::Owned(api), ())
|
||||
}
|
||||
#[cfg(feature = "testing")]
|
||||
AuthBackend::Postgres => {
|
||||
let url = args.auth_endpoint.parse()?;
|
||||
let api = console::provider::mock::Api::new(url);
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
#[cfg(feature = "testing")]
|
||||
pub mod mock;
|
||||
pub mod neon;
|
||||
|
||||
use super::messages::MetricsAuxInfo;
|
||||
use crate::{
|
||||
auth::ClientCredentials,
|
||||
auth::backend::ComputeUserInfo,
|
||||
cache::{timed_lru, TimedLru},
|
||||
compute, scram,
|
||||
};
|
||||
@@ -205,6 +206,7 @@ pub struct ConsoleReqExtra<'a> {
|
||||
|
||||
/// Auth secret which is managed by the cloud.
|
||||
pub enum AuthSecret {
|
||||
#[cfg(feature = "testing")]
|
||||
/// Md5 hash of user's password.
|
||||
Md5([u8; 16]),
|
||||
|
||||
@@ -247,20 +249,20 @@ pub trait Api {
|
||||
async fn get_auth_info(
|
||||
&self,
|
||||
extra: &ConsoleReqExtra<'_>,
|
||||
creds: &ClientCredentials,
|
||||
creds: &ComputeUserInfo,
|
||||
) -> Result<AuthInfo, errors::GetAuthInfoError>;
|
||||
|
||||
async fn get_allowed_ips(
|
||||
&self,
|
||||
extra: &ConsoleReqExtra<'_>,
|
||||
creds: &ClientCredentials,
|
||||
creds: &ComputeUserInfo,
|
||||
) -> Result<Arc<Vec<String>>, errors::GetAuthInfoError>;
|
||||
|
||||
/// Wake up the compute node and return the corresponding connection info.
|
||||
async fn wake_compute(
|
||||
&self,
|
||||
extra: &ConsoleReqExtra<'_>,
|
||||
creds: &ClientCredentials,
|
||||
creds: &ComputeUserInfo,
|
||||
) -> Result<CachedNodeInfo, errors::WakeComputeError>;
|
||||
}
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ use super::{
|
||||
errors::{ApiError, GetAuthInfoError, WakeComputeError},
|
||||
AuthInfo, AuthSecret, CachedNodeInfo, ConsoleReqExtra, NodeInfo,
|
||||
};
|
||||
use crate::{auth::ClientCredentials, compute, error::io_error, scram, url::ApiUrl};
|
||||
use crate::{auth::backend::ComputeUserInfo, compute, error::io_error, scram, url::ApiUrl};
|
||||
use async_trait::async_trait;
|
||||
use futures::TryFutureExt;
|
||||
use thiserror::Error;
|
||||
@@ -47,7 +47,7 @@ impl Api {
|
||||
|
||||
async fn do_get_auth_info(
|
||||
&self,
|
||||
creds: &ClientCredentials<'_>,
|
||||
creds: &ComputeUserInfo,
|
||||
) -> Result<AuthInfo, GetAuthInfoError> {
|
||||
let (secret, allowed_ips) = async {
|
||||
// Perhaps we could persist this connection, but then we'd have to
|
||||
@@ -60,7 +60,7 @@ impl Api {
|
||||
let secret = match get_execute_postgres_query(
|
||||
&client,
|
||||
"select rolpassword from pg_catalog.pg_authid where rolname = $1",
|
||||
&[&creds.user],
|
||||
&[&&*creds.inner.user],
|
||||
"rolpassword",
|
||||
)
|
||||
.await?
|
||||
@@ -71,14 +71,14 @@ impl Api {
|
||||
secret.or_else(|| parse_md5(&entry).map(AuthSecret::Md5))
|
||||
}
|
||||
None => {
|
||||
warn!("user '{}' does not exist", creds.user);
|
||||
warn!("user '{}' does not exist", creds.inner.user);
|
||||
None
|
||||
}
|
||||
};
|
||||
let allowed_ips = match get_execute_postgres_query(
|
||||
&client,
|
||||
"select allowed_ips from neon_control_plane.endpoints where endpoint_id = $1",
|
||||
&[&creds.project.clone().unwrap_or_default().as_str()],
|
||||
&[&creds.endpoint.as_str()],
|
||||
"allowed_ips",
|
||||
)
|
||||
.await?
|
||||
@@ -145,7 +145,7 @@ impl super::Api for Api {
|
||||
async fn get_auth_info(
|
||||
&self,
|
||||
_extra: &ConsoleReqExtra<'_>,
|
||||
creds: &ClientCredentials,
|
||||
creds: &ComputeUserInfo,
|
||||
) -> Result<AuthInfo, GetAuthInfoError> {
|
||||
self.do_get_auth_info(creds).await
|
||||
}
|
||||
@@ -153,7 +153,7 @@ impl super::Api for Api {
|
||||
async fn get_allowed_ips(
|
||||
&self,
|
||||
_extra: &ConsoleReqExtra<'_>,
|
||||
creds: &ClientCredentials,
|
||||
creds: &ComputeUserInfo,
|
||||
) -> Result<Arc<Vec<String>>, GetAuthInfoError> {
|
||||
Ok(Arc::new(self.do_get_auth_info(creds).await?.allowed_ips))
|
||||
}
|
||||
@@ -162,7 +162,7 @@ impl super::Api for Api {
|
||||
async fn wake_compute(
|
||||
&self,
|
||||
_extra: &ConsoleReqExtra<'_>,
|
||||
_creds: &ClientCredentials,
|
||||
_creds: &ComputeUserInfo,
|
||||
) -> Result<CachedNodeInfo, WakeComputeError> {
|
||||
self.do_wake_compute()
|
||||
.map_ok(CachedNodeInfo::new_uncached)
|
||||
|
||||
@@ -5,12 +5,8 @@ use super::{
|
||||
errors::{ApiError, GetAuthInfoError, WakeComputeError},
|
||||
ApiCaches, ApiLocks, AuthInfo, AuthSecret, CachedNodeInfo, ConsoleReqExtra, NodeInfo,
|
||||
};
|
||||
use crate::{
|
||||
auth::ClientCredentials,
|
||||
compute, http,
|
||||
proxy::{ALLOWED_IPS_BY_CACHE_OUTCOME, ALLOWED_IPS_NUMBER},
|
||||
scram,
|
||||
};
|
||||
use crate::proxy::{ALLOWED_IPS_BY_CACHE_OUTCOME, ALLOWED_IPS_NUMBER};
|
||||
use crate::{auth::backend::ComputeUserInfo, compute, http, scram};
|
||||
use async_trait::async_trait;
|
||||
use futures::TryFutureExt;
|
||||
use itertools::Itertools;
|
||||
@@ -53,7 +49,7 @@ impl Api {
|
||||
async fn do_get_auth_info(
|
||||
&self,
|
||||
extra: &ConsoleReqExtra<'_>,
|
||||
creds: &ClientCredentials<'_>,
|
||||
creds: &ComputeUserInfo,
|
||||
) -> Result<AuthInfo, GetAuthInfoError> {
|
||||
let request_id = uuid::Uuid::new_v4().to_string();
|
||||
async {
|
||||
@@ -65,8 +61,8 @@ impl Api {
|
||||
.query(&[("session_id", extra.session_id)])
|
||||
.query(&[
|
||||
("application_name", extra.application_name),
|
||||
("project", Some(creds.project().expect("impossible"))),
|
||||
("role", Some(creds.user)),
|
||||
("project", Some(&creds.endpoint)),
|
||||
("role", Some(&creds.inner.user)),
|
||||
])
|
||||
.build()?;
|
||||
|
||||
@@ -106,9 +102,8 @@ impl Api {
|
||||
async fn do_wake_compute(
|
||||
&self,
|
||||
extra: &ConsoleReqExtra<'_>,
|
||||
creds: &ClientCredentials<'_>,
|
||||
creds: &ComputeUserInfo,
|
||||
) -> Result<NodeInfo, WakeComputeError> {
|
||||
let project = creds.project().expect("impossible");
|
||||
let request_id = uuid::Uuid::new_v4().to_string();
|
||||
async {
|
||||
let request = self
|
||||
@@ -119,7 +114,7 @@ impl Api {
|
||||
.query(&[("session_id", extra.session_id)])
|
||||
.query(&[
|
||||
("application_name", extra.application_name),
|
||||
("project", Some(project)),
|
||||
("project", Some(&creds.endpoint)),
|
||||
("options", extra.options),
|
||||
])
|
||||
.build()?;
|
||||
@@ -162,7 +157,7 @@ impl super::Api for Api {
|
||||
async fn get_auth_info(
|
||||
&self,
|
||||
extra: &ConsoleReqExtra<'_>,
|
||||
creds: &ClientCredentials,
|
||||
creds: &ComputeUserInfo,
|
||||
) -> Result<AuthInfo, GetAuthInfoError> {
|
||||
self.do_get_auth_info(extra, creds).await
|
||||
}
|
||||
@@ -170,9 +165,9 @@ impl super::Api for Api {
|
||||
async fn get_allowed_ips(
|
||||
&self,
|
||||
extra: &ConsoleReqExtra<'_>,
|
||||
creds: &ClientCredentials,
|
||||
creds: &ComputeUserInfo,
|
||||
) -> Result<Arc<Vec<String>>, GetAuthInfoError> {
|
||||
let key: &str = creds.project().expect("impossible");
|
||||
let key: &str = &creds.endpoint;
|
||||
if let Some(allowed_ips) = self.caches.allowed_ips.get(key) {
|
||||
ALLOWED_IPS_BY_CACHE_OUTCOME
|
||||
.with_label_values(&["hit"])
|
||||
@@ -193,9 +188,9 @@ impl super::Api for Api {
|
||||
async fn wake_compute(
|
||||
&self,
|
||||
extra: &ConsoleReqExtra<'_>,
|
||||
creds: &ClientCredentials,
|
||||
creds: &ComputeUserInfo,
|
||||
) -> Result<CachedNodeInfo, WakeComputeError> {
|
||||
let key: &str = &creds.cache_key;
|
||||
let key: &str = &creds.inner.cache_key;
|
||||
|
||||
// Every time we do a wakeup http request, the compute node will stay up
|
||||
// for some time (highly depends on the console's scale-to-zero policy);
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
mod tests;
|
||||
|
||||
use crate::{
|
||||
auth::{self, backend::AuthSuccess},
|
||||
auth,
|
||||
cancellation::{self, CancelMap},
|
||||
compute::{self, PostgresConnection},
|
||||
config::{AuthenticationConfig, ProxyConfig, TlsConfig},
|
||||
@@ -24,7 +24,7 @@ use prometheus::{
|
||||
IntGaugeVec,
|
||||
};
|
||||
use regex::Regex;
|
||||
use std::{error::Error, io, net::SocketAddr, ops::ControlFlow, sync::Arc, time::Instant};
|
||||
use std::{error::Error, io, net::IpAddr, ops::ControlFlow, sync::Arc, time::Instant};
|
||||
use tokio::{
|
||||
io::{AsyncRead, AsyncWrite, AsyncWriteExt},
|
||||
time,
|
||||
@@ -318,7 +318,7 @@ pub async fn task_main(
|
||||
.set_nodelay(true)
|
||||
.context("failed to set socket option")?;
|
||||
|
||||
handle_client(config, &cancel_map, session_id, socket, ClientMode::Tcp, peer_addr).await
|
||||
handle_client(config, &cancel_map, session_id, socket, ClientMode::Tcp, peer_addr.ip()).await
|
||||
}
|
||||
.instrument(info_span!("handle_client", ?session_id, peer_addr = tracing::field::Empty))
|
||||
.unwrap_or_else(move |e| {
|
||||
@@ -408,7 +408,7 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
session_id: uuid::Uuid,
|
||||
stream: S,
|
||||
mode: ClientMode,
|
||||
peer_addr: SocketAddr,
|
||||
peer_addr: IpAddr,
|
||||
) -> anyhow::Result<()> {
|
||||
info!(
|
||||
protocol = mode.protocol_label(),
|
||||
@@ -666,7 +666,7 @@ pub async fn connect_to_compute<M: ConnectMechanism>(
|
||||
mechanism: &M,
|
||||
mut node_info: console::CachedNodeInfo,
|
||||
extra: &console::ConsoleReqExtra<'_>,
|
||||
creds: &auth::BackendType<'_, auth::ClientCredentials<'_>>,
|
||||
creds: &auth::BackendType<'_, auth::backend::ComputeUserInfo>,
|
||||
mut latency_timer: LatencyTimer,
|
||||
) -> Result<M::Connection, M::Error>
|
||||
where
|
||||
@@ -696,10 +696,12 @@ where
|
||||
let node_info = loop {
|
||||
let wake_res = match creds {
|
||||
auth::BackendType::Console(api, creds) => api.wake_compute(extra, creds).await,
|
||||
#[cfg(feature = "testing")]
|
||||
auth::BackendType::Postgres(api, creds) => api.wake_compute(extra, creds).await,
|
||||
// nothing to do?
|
||||
auth::BackendType::Link(_) => return Err(err.into()),
|
||||
// test backend
|
||||
#[cfg(test)]
|
||||
auth::BackendType::Test(x) => x.wake_compute(),
|
||||
};
|
||||
|
||||
@@ -838,7 +840,6 @@ pub fn retry_after(num_retries: u32) -> time::Duration {
|
||||
#[tracing::instrument(skip_all)]
|
||||
async fn prepare_client_connection(
|
||||
node: &compute::PostgresConnection,
|
||||
reported_auth_ok: bool,
|
||||
session: cancellation::Session<'_>,
|
||||
stream: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
) -> anyhow::Result<()> {
|
||||
@@ -846,13 +847,6 @@ async fn prepare_client_connection(
|
||||
// The new token (cancel_key_data) will be sent to the client.
|
||||
let cancel_key_data = session.enable_query_cancellation(node.cancel_closure.clone());
|
||||
|
||||
// Report authentication success if we haven't done this already.
|
||||
// Note that we do this only (for the most part) after we've connected
|
||||
// to a compute (see above) which performs its own authentication.
|
||||
if !reported_auth_ok {
|
||||
stream.write_message_noflush(&Be::AuthenticationOk)?;
|
||||
}
|
||||
|
||||
// Forward all postgres connection params to the client.
|
||||
// Right now the implementation is very hacky and inefficent (ideally,
|
||||
// we don't need an intermediate hashmap), but at least it should be correct.
|
||||
@@ -921,7 +915,7 @@ struct Client<'a, S> {
|
||||
/// The underlying libpq protocol stream.
|
||||
stream: PqStream<Stream<S>>,
|
||||
/// Client credentials that we care about.
|
||||
creds: auth::BackendType<'a, auth::ClientCredentials<'a>>,
|
||||
creds: auth::BackendType<'a, auth::ClientCredentials>,
|
||||
/// KV-dictionary with PostgreSQL connection params.
|
||||
params: &'a StartupMessageParams,
|
||||
/// Unique connection ID.
|
||||
@@ -934,7 +928,7 @@ impl<'a, S> Client<'a, S> {
|
||||
/// Construct a new connection context.
|
||||
fn new(
|
||||
stream: PqStream<Stream<S>>,
|
||||
creds: auth::BackendType<'a, auth::ClientCredentials<'a>>,
|
||||
creds: auth::BackendType<'a, auth::ClientCredentials>,
|
||||
params: &'a StartupMessageParams,
|
||||
session_id: uuid::Uuid,
|
||||
allow_self_signed_compute: bool,
|
||||
@@ -953,7 +947,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> Client<'_, S> {
|
||||
/// Let the client authenticate and connect to the designated compute node.
|
||||
// Instrumentation logs endpoint name everywhere. Doesn't work for link
|
||||
// auth; strictly speaking we don't know endpoint name in its case.
|
||||
#[tracing::instrument(name = "", fields(ep = self.creds.get_endpoint().unwrap_or("".to_owned())), skip_all)]
|
||||
#[tracing::instrument(name = "", fields(ep = %self.creds.get_endpoint().unwrap_or_default()), skip_all)]
|
||||
async fn connect_to_db(
|
||||
self,
|
||||
session: cancellation::Session<'_>,
|
||||
@@ -962,7 +956,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> Client<'_, S> {
|
||||
) -> anyhow::Result<()> {
|
||||
let Self {
|
||||
mut stream,
|
||||
mut creds,
|
||||
creds,
|
||||
params,
|
||||
session_id,
|
||||
allow_self_signed_compute,
|
||||
@@ -978,6 +972,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> Client<'_, S> {
|
||||
|
||||
let mut latency_timer = LatencyTimer::new(mode.protocol_label());
|
||||
|
||||
let user = creds.get_user().to_owned();
|
||||
let auth_result = match creds
|
||||
.authenticate(
|
||||
&extra,
|
||||
@@ -990,7 +985,6 @@ impl<S: AsyncRead + AsyncWrite + Unpin> Client<'_, S> {
|
||||
{
|
||||
Ok(auth_result) => auth_result,
|
||||
Err(e) => {
|
||||
let user = creds.get_user();
|
||||
let db = params.get("database");
|
||||
let app = params.get("application_name");
|
||||
let params_span = tracing::info_span!("", ?user, ?db, ?app);
|
||||
@@ -999,10 +993,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> Client<'_, S> {
|
||||
}
|
||||
};
|
||||
|
||||
let AuthSuccess {
|
||||
reported_auth_ok,
|
||||
value: mut node_info,
|
||||
} = auth_result;
|
||||
let (mut node_info, creds) = auth_result;
|
||||
|
||||
node_info.allow_self_signed_compute = allow_self_signed_compute;
|
||||
|
||||
@@ -1025,7 +1016,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> Client<'_, S> {
|
||||
NUM_DB_CONNECTIONS_CLOSED_COUNTER.with_label_values(&[proto]).inc();
|
||||
}
|
||||
|
||||
prepare_client_connection(&node, reported_auth_ok, session, &mut stream).await?;
|
||||
prepare_client_connection(&node, session, &mut stream).await?;
|
||||
// Before proxy passing, forward to compute whatever data is left in the
|
||||
// PqStream input buffer. Normally there is none, but our serverless npm
|
||||
// driver in pipeline mode sends startup, password and first query
|
||||
|
||||
@@ -3,8 +3,7 @@
|
||||
mod mitm;
|
||||
|
||||
use super::*;
|
||||
use crate::auth::backend::TestBackend;
|
||||
use crate::auth::ClientCredentials;
|
||||
use crate::auth::backend::{ComputeUserInfo, TestBackend};
|
||||
use crate::config::CertResolver;
|
||||
use crate::console::{CachedNodeInfo, NodeInfo};
|
||||
use crate::{auth, http, sasl, scram};
|
||||
@@ -109,8 +108,9 @@ fn generate_tls_config<'a>(
|
||||
trait TestAuth: Sized {
|
||||
async fn authenticate<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
self,
|
||||
_stream: &mut PqStream<Stream<S>>,
|
||||
stream: &mut PqStream<Stream<S>>,
|
||||
) -> anyhow::Result<()> {
|
||||
stream.write_message_noflush(&Be::AuthenticationOk)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -168,7 +168,6 @@ async fn dummy_proxy(
|
||||
auth.authenticate(&mut stream).await?;
|
||||
|
||||
stream
|
||||
.write_message_noflush(&Be::AuthenticationOk)?
|
||||
.write_message_noflush(&Be::CLIENT_ENCODING)?
|
||||
.write_message(&Be::ReadyForQuery)
|
||||
.await?;
|
||||
@@ -486,7 +485,7 @@ fn helper_create_connect_info(
|
||||
) -> (
|
||||
CachedNodeInfo,
|
||||
console::ConsoleReqExtra<'static>,
|
||||
auth::BackendType<'_, ClientCredentials<'static>>,
|
||||
auth::BackendType<'_, ComputeUserInfo>,
|
||||
) {
|
||||
let cache = helper_create_cached_node_info();
|
||||
let extra = console::ConsoleReqExtra {
|
||||
|
||||
@@ -15,7 +15,7 @@ mod signature;
|
||||
#[cfg(any(test, doc))]
|
||||
mod password;
|
||||
|
||||
pub use exchange::Exchange;
|
||||
pub use exchange::{exchange, Exchange};
|
||||
pub use key::ScramKey;
|
||||
pub use secret::ServerSecret;
|
||||
|
||||
|
||||
@@ -1,5 +1,9 @@
|
||||
//! Implementation of the SCRAM authentication algorithm.
|
||||
|
||||
use std::convert::Infallible;
|
||||
|
||||
use postgres_protocol::authentication::sasl::ScramSha256;
|
||||
|
||||
use super::messages::{
|
||||
ClientFinalMessage, ClientFirstMessage, OwnedServerFirstMessage, SCRAM_RAW_NONCE_LEN,
|
||||
};
|
||||
@@ -29,22 +33,27 @@ impl std::str::FromStr for TlsServerEndPoint {
|
||||
}
|
||||
}
|
||||
|
||||
struct SaslSentInner {
|
||||
cbind_flag: ChannelBinding<TlsServerEndPoint>,
|
||||
client_first_message_bare: String,
|
||||
server_first_message: OwnedServerFirstMessage,
|
||||
}
|
||||
|
||||
struct SaslInitial {
|
||||
nonce: fn() -> [u8; SCRAM_RAW_NONCE_LEN],
|
||||
}
|
||||
|
||||
enum ExchangeState {
|
||||
/// Waiting for [`ClientFirstMessage`].
|
||||
Initial,
|
||||
Initial(SaslInitial),
|
||||
/// Waiting for [`ClientFinalMessage`].
|
||||
SaltSent {
|
||||
cbind_flag: ChannelBinding<TlsServerEndPoint>,
|
||||
client_first_message_bare: String,
|
||||
server_first_message: OwnedServerFirstMessage,
|
||||
},
|
||||
SaltSent(SaslSentInner),
|
||||
}
|
||||
|
||||
/// Server's side of SCRAM auth algorithm.
|
||||
pub struct Exchange<'a> {
|
||||
state: ExchangeState,
|
||||
secret: &'a ServerSecret,
|
||||
nonce: fn() -> [u8; SCRAM_RAW_NONCE_LEN],
|
||||
tls_server_end_point: config::TlsServerEndPoint,
|
||||
}
|
||||
|
||||
@@ -55,90 +64,160 @@ impl<'a> Exchange<'a> {
|
||||
tls_server_end_point: config::TlsServerEndPoint,
|
||||
) -> Self {
|
||||
Self {
|
||||
state: ExchangeState::Initial,
|
||||
state: ExchangeState::Initial(SaslInitial { nonce }),
|
||||
secret,
|
||||
nonce,
|
||||
tls_server_end_point,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn exchange(
|
||||
secret: &ServerSecret,
|
||||
mut client: ScramSha256,
|
||||
tls_server_end_point: config::TlsServerEndPoint,
|
||||
) -> sasl::Result<sasl::Outcome<super::ScramKey>> {
|
||||
use sasl::Step::*;
|
||||
|
||||
let init = SaslInitial {
|
||||
nonce: rand::random,
|
||||
};
|
||||
|
||||
let client_first = std::str::from_utf8(client.message())
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
|
||||
let sent = match init.transition(secret, &tls_server_end_point, client_first)? {
|
||||
Continue(sent, server_first) => {
|
||||
client.update(server_first.as_bytes())?;
|
||||
sent
|
||||
}
|
||||
Success(x, _) => match x {},
|
||||
Failure(msg) => return Ok(sasl::Outcome::Failure(msg)),
|
||||
};
|
||||
|
||||
let client_final = std::str::from_utf8(client.message())
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
|
||||
let keys = match sent.transition(secret, &tls_server_end_point, client_final)? {
|
||||
Success(keys, server_final) => {
|
||||
client.finish(server_final.as_bytes())?;
|
||||
keys
|
||||
}
|
||||
Continue(x, _) => match x {},
|
||||
Failure(msg) => return Ok(sasl::Outcome::Failure(msg)),
|
||||
};
|
||||
|
||||
Ok(sasl::Outcome::Success(keys))
|
||||
}
|
||||
|
||||
impl SaslInitial {
|
||||
fn transition(
|
||||
&self,
|
||||
secret: &ServerSecret,
|
||||
tls_server_end_point: &config::TlsServerEndPoint,
|
||||
input: &str,
|
||||
) -> sasl::Result<sasl::Step<SaslSentInner, Infallible>> {
|
||||
let client_first_message = ClientFirstMessage::parse(input)
|
||||
.ok_or(SaslError::BadClientMessage("invalid client-first-message"))?;
|
||||
|
||||
// If the flag is set to "y" and the server supports channel
|
||||
// binding, the server MUST fail authentication
|
||||
if client_first_message.cbind_flag == ChannelBinding::NotSupportedServer
|
||||
&& tls_server_end_point.supported()
|
||||
{
|
||||
return Err(SaslError::ChannelBindingFailed("SCRAM-PLUS not used"));
|
||||
}
|
||||
|
||||
let server_first_message = client_first_message.build_server_first_message(
|
||||
&(self.nonce)(),
|
||||
&secret.salt_base64,
|
||||
secret.iterations,
|
||||
);
|
||||
let msg = server_first_message.as_str().to_owned();
|
||||
|
||||
let next = SaslSentInner {
|
||||
cbind_flag: client_first_message.cbind_flag.and_then(str::parse)?,
|
||||
client_first_message_bare: client_first_message.bare.to_owned(),
|
||||
server_first_message,
|
||||
};
|
||||
|
||||
Ok(sasl::Step::Continue(next, msg))
|
||||
}
|
||||
}
|
||||
|
||||
impl SaslSentInner {
|
||||
fn transition(
|
||||
&self,
|
||||
secret: &ServerSecret,
|
||||
tls_server_end_point: &config::TlsServerEndPoint,
|
||||
input: &str,
|
||||
) -> sasl::Result<sasl::Step<Infallible, super::ScramKey>> {
|
||||
let Self {
|
||||
cbind_flag,
|
||||
client_first_message_bare,
|
||||
server_first_message,
|
||||
} = self;
|
||||
|
||||
let client_final_message = ClientFinalMessage::parse(input)
|
||||
.ok_or(SaslError::BadClientMessage("invalid client-final-message"))?;
|
||||
|
||||
let channel_binding = cbind_flag.encode(|_| match tls_server_end_point {
|
||||
config::TlsServerEndPoint::Sha256(x) => Ok(x),
|
||||
config::TlsServerEndPoint::Undefined => Err(SaslError::MissingBinding),
|
||||
})?;
|
||||
|
||||
// This might've been caused by a MITM attack
|
||||
if client_final_message.channel_binding != channel_binding {
|
||||
return Err(SaslError::ChannelBindingFailed(
|
||||
"insecure connection: secure channel data mismatch",
|
||||
));
|
||||
}
|
||||
|
||||
if client_final_message.nonce != server_first_message.nonce() {
|
||||
return Err(SaslError::BadClientMessage("combined nonce doesn't match"));
|
||||
}
|
||||
|
||||
let signature_builder = SignatureBuilder {
|
||||
client_first_message_bare,
|
||||
server_first_message: server_first_message.as_str(),
|
||||
client_final_message_without_proof: client_final_message.without_proof,
|
||||
};
|
||||
|
||||
let client_key = signature_builder
|
||||
.build(&secret.stored_key)
|
||||
.derive_client_key(&client_final_message.proof);
|
||||
|
||||
// Auth fails either if keys don't match or it's pre-determined to fail.
|
||||
if client_key.sha256() != secret.stored_key || secret.doomed {
|
||||
return Ok(sasl::Step::Failure("password doesn't match"));
|
||||
}
|
||||
|
||||
let msg =
|
||||
client_final_message.build_server_final_message(signature_builder, &secret.server_key);
|
||||
|
||||
Ok(sasl::Step::Success(client_key, msg))
|
||||
}
|
||||
}
|
||||
|
||||
impl sasl::Mechanism for Exchange<'_> {
|
||||
type Output = super::ScramKey;
|
||||
|
||||
fn exchange(mut self, input: &str) -> sasl::Result<sasl::Step<Self, Self::Output>> {
|
||||
use {sasl::Step::*, ExchangeState::*};
|
||||
match &self.state {
|
||||
Initial => {
|
||||
let client_first_message = ClientFirstMessage::parse(input)
|
||||
.ok_or(SaslError::BadClientMessage("invalid client-first-message"))?;
|
||||
|
||||
// If the flag is set to "y" and the server supports channel
|
||||
// binding, the server MUST fail authentication
|
||||
if client_first_message.cbind_flag == ChannelBinding::NotSupportedServer
|
||||
&& self.tls_server_end_point.supported()
|
||||
{
|
||||
return Err(SaslError::ChannelBindingFailed("SCRAM-PLUS not used"));
|
||||
Initial(init) => {
|
||||
match init.transition(self.secret, &self.tls_server_end_point, input)? {
|
||||
Continue(sent, msg) => {
|
||||
self.state = SaltSent(sent);
|
||||
Ok(Continue(self, msg))
|
||||
}
|
||||
Success(x, _) => match x {},
|
||||
Failure(msg) => Ok(Failure(msg)),
|
||||
}
|
||||
|
||||
let server_first_message = client_first_message.build_server_first_message(
|
||||
&(self.nonce)(),
|
||||
&self.secret.salt_base64,
|
||||
self.secret.iterations,
|
||||
);
|
||||
let msg = server_first_message.as_str().to_owned();
|
||||
|
||||
self.state = SaltSent {
|
||||
cbind_flag: client_first_message.cbind_flag.and_then(str::parse)?,
|
||||
client_first_message_bare: client_first_message.bare.to_owned(),
|
||||
server_first_message,
|
||||
};
|
||||
|
||||
Ok(Continue(self, msg))
|
||||
}
|
||||
SaltSent {
|
||||
cbind_flag,
|
||||
client_first_message_bare,
|
||||
server_first_message,
|
||||
} => {
|
||||
let client_final_message = ClientFinalMessage::parse(input)
|
||||
.ok_or(SaslError::BadClientMessage("invalid client-final-message"))?;
|
||||
|
||||
let channel_binding = cbind_flag.encode(|_| match &self.tls_server_end_point {
|
||||
config::TlsServerEndPoint::Sha256(x) => Ok(x),
|
||||
config::TlsServerEndPoint::Undefined => Err(SaslError::MissingBinding),
|
||||
})?;
|
||||
|
||||
// This might've been caused by a MITM attack
|
||||
if client_final_message.channel_binding != channel_binding {
|
||||
return Err(SaslError::ChannelBindingFailed(
|
||||
"insecure connection: secure channel data mismatch",
|
||||
));
|
||||
SaltSent(sent) => {
|
||||
match sent.transition(self.secret, &self.tls_server_end_point, input)? {
|
||||
Success(keys, msg) => Ok(Success(keys, msg)),
|
||||
Continue(x, _) => match x {},
|
||||
Failure(msg) => Ok(Failure(msg)),
|
||||
}
|
||||
|
||||
if client_final_message.nonce != server_first_message.nonce() {
|
||||
return Err(SaslError::BadClientMessage("combined nonce doesn't match"));
|
||||
}
|
||||
|
||||
let signature_builder = SignatureBuilder {
|
||||
client_first_message_bare,
|
||||
server_first_message: server_first_message.as_str(),
|
||||
client_final_message_without_proof: client_final_message.without_proof,
|
||||
};
|
||||
|
||||
let client_key = signature_builder
|
||||
.build(&self.secret.stored_key)
|
||||
.derive_client_key(&client_final_message.proof);
|
||||
|
||||
// Auth fails either if keys don't match or it's pre-determined to fail.
|
||||
if client_key.sha256() != self.secret.stored_key || self.secret.doomed {
|
||||
return Ok(Failure("password doesn't match"));
|
||||
}
|
||||
|
||||
let msg = client_final_message
|
||||
.build_server_final_message(signature_builder, &self.secret.server_key);
|
||||
|
||||
Ok(Success(client_key, msg))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -23,7 +23,7 @@ use hyper::{
|
||||
Body, Method, Request, Response,
|
||||
};
|
||||
|
||||
use std::net::SocketAddr;
|
||||
use std::net::IpAddr;
|
||||
use std::task::Poll;
|
||||
use std::{future::ready, sync::Arc};
|
||||
use tls_listener::TlsListener;
|
||||
@@ -103,7 +103,13 @@ pub async fn task_main(
|
||||
let session_id = uuid::Uuid::new_v4();
|
||||
|
||||
request_handler(
|
||||
req, config, conn_pool, cancel_map, session_id, sni_name, peer_addr,
|
||||
req,
|
||||
config,
|
||||
conn_pool,
|
||||
cancel_map,
|
||||
session_id,
|
||||
sni_name,
|
||||
peer_addr.ip(),
|
||||
)
|
||||
.instrument(info_span!(
|
||||
"serverless",
|
||||
@@ -171,7 +177,7 @@ async fn request_handler(
|
||||
cancel_map: Arc<CancelMap>,
|
||||
session_id: uuid::Uuid,
|
||||
sni_hostname: Option<String>,
|
||||
peer_addr: SocketAddr,
|
||||
peer_addr: IpAddr,
|
||||
) -> Result<Response<Body>, ApiError> {
|
||||
let host = request
|
||||
.headers()
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use anyhow::Context;
|
||||
use anyhow::{anyhow, Context};
|
||||
use async_trait::async_trait;
|
||||
use dashmap::DashMap;
|
||||
use futures::future::poll_fn;
|
||||
@@ -9,7 +9,7 @@ use pbkdf2::{
|
||||
};
|
||||
use pq_proto::StartupMessageParams;
|
||||
use smol_str::SmolStr;
|
||||
use std::{collections::HashMap, net::SocketAddr, sync::Arc};
|
||||
use std::{collections::HashMap, net::IpAddr, sync::Arc};
|
||||
use std::{
|
||||
fmt,
|
||||
task::{ready, Poll},
|
||||
@@ -22,7 +22,7 @@ use tokio::time;
|
||||
use tokio_postgres::{AsyncMessage, ReadyForQueryStatus};
|
||||
|
||||
use crate::{
|
||||
auth::{self, check_peer_addr_is_in_list},
|
||||
auth::{self, backend::ComputeUserInfo, check_peer_addr_is_in_list},
|
||||
console,
|
||||
proxy::{
|
||||
neon_options, LatencyTimer, NUM_DB_CONNECTIONS_CLOSED_COUNTER,
|
||||
@@ -146,7 +146,7 @@ impl GlobalConnPool {
|
||||
conn_info: &ConnInfo,
|
||||
force_new: bool,
|
||||
session_id: uuid::Uuid,
|
||||
peer_addr: SocketAddr,
|
||||
peer_addr: IpAddr,
|
||||
) -> anyhow::Result<Client> {
|
||||
let mut client: Option<ClientInner> = None;
|
||||
let mut latency_timer = LatencyTimer::new("http");
|
||||
@@ -406,7 +406,7 @@ async fn connect_to_compute(
|
||||
conn_id: uuid::Uuid,
|
||||
session_id: uuid::Uuid,
|
||||
latency_timer: LatencyTimer,
|
||||
peer_addr: SocketAddr,
|
||||
peer_addr: IpAddr,
|
||||
) -> anyhow::Result<ClientInner> {
|
||||
let tls = config.tls_config.as_ref();
|
||||
let common_names = tls.and_then(|tls| tls.common_names.clone());
|
||||
@@ -423,6 +423,9 @@ async fn connect_to_compute(
|
||||
common_names,
|
||||
peer_addr,
|
||||
)?;
|
||||
|
||||
let creds =
|
||||
ComputeUserInfo::try_from(creds).map_err(|_| anyhow!("missing endpoint identifier"))?;
|
||||
let backend = config.auth_backend.as_ref().map(|_| creds);
|
||||
|
||||
let console_options = neon_options(¶ms);
|
||||
@@ -435,7 +438,7 @@ async fn connect_to_compute(
|
||||
// TODO(anna): this is a bit hacky way, consider using console notification listener.
|
||||
if !config.disable_ip_check_for_http {
|
||||
let allowed_ips = backend.get_allowed_ips(&extra).await?;
|
||||
if !check_peer_addr_is_in_list(&peer_addr.ip(), &allowed_ips) {
|
||||
if !check_peer_addr_is_in_list(&peer_addr, &allowed_ips) {
|
||||
return Err(auth::AuthError::ip_address_not_allowed().into());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use std::net::SocketAddr;
|
||||
use std::net::IpAddr;
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::bail;
|
||||
@@ -202,7 +202,7 @@ pub async fn handle(
|
||||
sni_hostname: Option<String>,
|
||||
conn_pool: Arc<GlobalConnPool>,
|
||||
session_id: uuid::Uuid,
|
||||
peer_addr: SocketAddr,
|
||||
peer_addr: IpAddr,
|
||||
config: &'static HttpConfig,
|
||||
) -> Result<Response<Body>, ApiError> {
|
||||
let result = tokio::time::timeout(
|
||||
@@ -301,7 +301,7 @@ async fn handle_inner(
|
||||
sni_hostname: Option<String>,
|
||||
conn_pool: Arc<GlobalConnPool>,
|
||||
session_id: uuid::Uuid,
|
||||
peer_addr: SocketAddr,
|
||||
peer_addr: IpAddr,
|
||||
) -> anyhow::Result<Response<Body>> {
|
||||
NUM_CONNECTIONS_ACCEPTED_COUNTER
|
||||
.with_label_values(&["http"])
|
||||
|
||||
@@ -11,7 +11,7 @@ use hyper_tungstenite::{tungstenite::Message, HyperWebsocket, WebSocketStream};
|
||||
use pin_project_lite::pin_project;
|
||||
|
||||
use std::{
|
||||
net::SocketAddr,
|
||||
net::IpAddr,
|
||||
pin::Pin,
|
||||
task::{ready, Context, Poll},
|
||||
};
|
||||
@@ -133,7 +133,7 @@ pub async fn serve_websocket(
|
||||
cancel_map: &CancelMap,
|
||||
session_id: uuid::Uuid,
|
||||
hostname: Option<String>,
|
||||
peer_addr: SocketAddr,
|
||||
peer_addr: IpAddr,
|
||||
) -> anyhow::Result<()> {
|
||||
let websocket = websocket.await?;
|
||||
handle_client(
|
||||
|
||||
Reference in New Issue
Block a user