proxy: random changes (#8602)

## Problem

1. Hard to correlate startup parameters with the endpoint that provided
them.
2. Some configurations are not needed in the `ProxyConfig` struct.

## Summary of changes

Because of some borrow checker fun, I needed to switch to an
interior-mutability implementation of our `RequestMonitoring` context
system. Using https://docs.rs/try-lock/latest/try_lock/ as a cheap lock
for such a use-case (needed to be thread safe).

Removed the lock of each startup message, instead just logging only the
startup params in a successful handshake.

Also removed from values from `ProxyConfig` and kept as arguments.
(needed for local-proxy config)
This commit is contained in:
Conrad Ludgate
2024-08-07 14:37:03 +01:00
committed by John Spray
parent 38f184bc91
commit 1f73dfb842
31 changed files with 386 additions and 276 deletions

View File

@@ -218,7 +218,7 @@ impl RateBucketInfo {
impl AuthenticationConfig {
pub fn check_rate_limit(
&self,
ctx: &mut RequestMonitoring,
ctx: &RequestMonitoring,
config: &AuthenticationConfig,
secret: AuthSecret,
endpoint: &EndpointId,
@@ -243,7 +243,7 @@ impl AuthenticationConfig {
let limit_not_exceeded = self.rate_limiter.check(
(
endpoint_int,
MaskedIp::new(ctx.peer_addr, config.rate_limit_ip_subnet),
MaskedIp::new(ctx.peer_addr(), config.rate_limit_ip_subnet),
),
password_weight,
);
@@ -274,7 +274,7 @@ impl AuthenticationConfig {
///
/// All authentication flows will emit an AuthenticationOk message if successful.
async fn auth_quirks(
ctx: &mut RequestMonitoring,
ctx: &RequestMonitoring,
api: &impl console::Api,
user_info: ComputeUserInfoMaybeEndpoint,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
@@ -303,8 +303,8 @@ async fn auth_quirks(
let (allowed_ips, maybe_secret) = api.get_allowed_ips_and_secret(ctx, &info).await?;
// check allowed list
if !check_peer_addr_is_in_list(&ctx.peer_addr, &allowed_ips) {
return Err(auth::AuthError::ip_address_not_allowed(ctx.peer_addr));
if !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips) {
return Err(auth::AuthError::ip_address_not_allowed(ctx.peer_addr()));
}
if !endpoint_rate_limiter.check(info.endpoint.clone().into(), 1) {
@@ -356,7 +356,7 @@ async fn auth_quirks(
}
async fn authenticate_with_secret(
ctx: &mut RequestMonitoring,
ctx: &RequestMonitoring,
secret: AuthSecret,
info: ComputeUserInfo,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
@@ -421,7 +421,7 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint, &()> {
#[tracing::instrument(fields(allow_cleartext = allow_cleartext), skip_all)]
pub async fn authenticate(
self,
ctx: &mut RequestMonitoring,
ctx: &RequestMonitoring,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
allow_cleartext: bool,
config: &'static AuthenticationConfig,
@@ -467,7 +467,7 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint, &()> {
impl BackendType<'_, ComputeUserInfo, &()> {
pub async fn get_role_secret(
&self,
ctx: &mut RequestMonitoring,
ctx: &RequestMonitoring,
) -> Result<CachedRoleSecret, GetAuthInfoError> {
use BackendType::*;
match self {
@@ -478,7 +478,7 @@ impl BackendType<'_, ComputeUserInfo, &()> {
pub async fn get_allowed_ips_and_secret(
&self,
ctx: &mut RequestMonitoring,
ctx: &RequestMonitoring,
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), GetAuthInfoError> {
use BackendType::*;
match self {
@@ -492,7 +492,7 @@ impl BackendType<'_, ComputeUserInfo, &()> {
impl ComputeConnectBackend for BackendType<'_, ComputeCredentials, NodeInfo> {
async fn wake_compute(
&self,
ctx: &mut RequestMonitoring,
ctx: &RequestMonitoring,
) -> Result<CachedNodeInfo, console::errors::WakeComputeError> {
use BackendType::*;
@@ -514,7 +514,7 @@ impl ComputeConnectBackend for BackendType<'_, ComputeCredentials, NodeInfo> {
impl ComputeConnectBackend for BackendType<'_, ComputeCredentials, &()> {
async fn wake_compute(
&self,
ctx: &mut RequestMonitoring,
ctx: &RequestMonitoring,
) -> Result<CachedNodeInfo, console::errors::WakeComputeError> {
use BackendType::*;
@@ -571,7 +571,7 @@ mod tests {
impl console::Api for Auth {
async fn get_role_secret(
&self,
_ctx: &mut RequestMonitoring,
_ctx: &RequestMonitoring,
_user_info: &super::ComputeUserInfo,
) -> Result<CachedRoleSecret, console::errors::GetAuthInfoError> {
Ok(CachedRoleSecret::new_uncached(Some(self.secret.clone())))
@@ -579,7 +579,7 @@ mod tests {
async fn get_allowed_ips_and_secret(
&self,
_ctx: &mut RequestMonitoring,
_ctx: &RequestMonitoring,
_user_info: &super::ComputeUserInfo,
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), console::errors::GetAuthInfoError>
{
@@ -591,7 +591,7 @@ mod tests {
async fn wake_compute(
&self,
_ctx: &mut RequestMonitoring,
_ctx: &RequestMonitoring,
_user_info: &super::ComputeUserInfo,
) -> Result<CachedNodeInfo, console::errors::WakeComputeError> {
unimplemented!()
@@ -665,7 +665,7 @@ mod tests {
let (mut client, server) = tokio::io::duplex(1024);
let mut stream = PqStream::new(Stream::from_raw(server));
let mut ctx = RequestMonitoring::test();
let ctx = RequestMonitoring::test();
let api = Auth {
ips: vec![],
secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()),
@@ -723,7 +723,7 @@ mod tests {
));
let _creds = auth_quirks(
&mut ctx,
&ctx,
&api,
user_info,
&mut stream,
@@ -742,7 +742,7 @@ mod tests {
let (mut client, server) = tokio::io::duplex(1024);
let mut stream = PqStream::new(Stream::from_raw(server));
let mut ctx = RequestMonitoring::test();
let ctx = RequestMonitoring::test();
let api = Auth {
ips: vec![],
secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()),
@@ -775,7 +775,7 @@ mod tests {
));
let _creds = auth_quirks(
&mut ctx,
&ctx,
&api,
user_info,
&mut stream,
@@ -794,7 +794,7 @@ mod tests {
let (mut client, server) = tokio::io::duplex(1024);
let mut stream = PqStream::new(Stream::from_raw(server));
let mut ctx = RequestMonitoring::test();
let ctx = RequestMonitoring::test();
let api = Auth {
ips: vec![],
secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()),
@@ -828,7 +828,7 @@ mod tests {
));
let creds = auth_quirks(
&mut ctx,
&ctx,
&api,
user_info,
&mut stream,

View File

@@ -12,7 +12,7 @@ use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{info, warn};
pub(super) async fn authenticate(
ctx: &mut RequestMonitoring,
ctx: &RequestMonitoring,
creds: ComputeUserInfo,
client: &mut PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
config: &'static AuthenticationConfig,
@@ -27,7 +27,7 @@ pub(super) async fn authenticate(
}
AuthSecret::Scram(secret) => {
info!("auth endpoint chooses SCRAM");
let scram = auth::Scram(&secret, &mut *ctx);
let scram = auth::Scram(&secret, ctx);
let auth_outcome = tokio::time::timeout(
config.scram_protocol_timeout,

View File

@@ -18,7 +18,7 @@ use tracing::{info, warn};
/// These properties are benefical for serverless JS workers, so we
/// use this mechanism for websocket connections.
pub async fn authenticate_cleartext(
ctx: &mut RequestMonitoring,
ctx: &RequestMonitoring,
info: ComputeUserInfo,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
secret: AuthSecret,
@@ -28,7 +28,7 @@ pub async fn authenticate_cleartext(
ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
// pause the timer while we communicate with the client
let paused = ctx.latency_timer.pause(crate::metrics::Waiting::Client);
let paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
let ep = EndpointIdInt::from(&info.endpoint);
@@ -60,7 +60,7 @@ pub async fn authenticate_cleartext(
/// Similar to [`authenticate_cleartext`], but there's a specific password format,
/// and passwords are not yet validated (we don't know how to validate them!)
pub async fn password_hack_no_authentication(
ctx: &mut RequestMonitoring,
ctx: &RequestMonitoring,
info: ComputeUserInfoNoEndpoint,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
) -> auth::Result<ComputeCredentials> {
@@ -68,7 +68,7 @@ pub async fn password_hack_no_authentication(
ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
// pause the timer while we communicate with the client
let _paused = ctx.latency_timer.pause(crate::metrics::Waiting::Client);
let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
let payload = AuthFlow::new(client)
.begin(auth::PasswordHack)

View File

@@ -57,7 +57,7 @@ pub fn new_psql_session_id() -> String {
}
pub(super) async fn authenticate(
ctx: &mut RequestMonitoring,
ctx: &RequestMonitoring,
link_uri: &reqwest::Url,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> auth::Result<NodeInfo> {

View File

@@ -84,7 +84,7 @@ pub fn endpoint_sni(
impl ComputeUserInfoMaybeEndpoint {
pub fn parse(
ctx: &mut RequestMonitoring,
ctx: &RequestMonitoring,
params: &StartupMessageParams,
sni: Option<&str>,
common_names: Option<&HashSet<String>>,
@@ -249,8 +249,8 @@ mod tests {
fn parse_bare_minimum() -> anyhow::Result<()> {
// According to postgresql, only `user` should be required.
let options = StartupMessageParams::new([("user", "john_doe")]);
let mut ctx = RequestMonitoring::test();
let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, None, None)?;
let ctx = RequestMonitoring::test();
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?;
assert_eq!(user_info.user, "john_doe");
assert_eq!(user_info.endpoint_id, None);
@@ -264,8 +264,8 @@ mod tests {
("database", "world"), // should be ignored
("foo", "bar"), // should be ignored
]);
let mut ctx = RequestMonitoring::test();
let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, None, None)?;
let ctx = RequestMonitoring::test();
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?;
assert_eq!(user_info.user, "john_doe");
assert_eq!(user_info.endpoint_id, None);
@@ -279,9 +279,9 @@ mod tests {
let sni = Some("foo.localhost");
let common_names = Some(["localhost".into()].into());
let mut ctx = RequestMonitoring::test();
let ctx = RequestMonitoring::test();
let user_info =
ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())?;
ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())?;
assert_eq!(user_info.user, "john_doe");
assert_eq!(user_info.endpoint_id.as_deref(), Some("foo"));
assert_eq!(user_info.options.get_cache_key("foo"), "foo");
@@ -296,8 +296,8 @@ mod tests {
("options", "-ckey=1 project=bar -c geqo=off"),
]);
let mut ctx = RequestMonitoring::test();
let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, None, None)?;
let ctx = RequestMonitoring::test();
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?;
assert_eq!(user_info.user, "john_doe");
assert_eq!(user_info.endpoint_id.as_deref(), Some("bar"));
@@ -311,8 +311,8 @@ mod tests {
("options", "-ckey=1 endpoint=bar -c geqo=off"),
]);
let mut ctx = RequestMonitoring::test();
let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, None, None)?;
let ctx = RequestMonitoring::test();
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?;
assert_eq!(user_info.user, "john_doe");
assert_eq!(user_info.endpoint_id.as_deref(), Some("bar"));
@@ -329,8 +329,8 @@ mod tests {
),
]);
let mut ctx = RequestMonitoring::test();
let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, None, None)?;
let ctx = RequestMonitoring::test();
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?;
assert_eq!(user_info.user, "john_doe");
assert!(user_info.endpoint_id.is_none());
@@ -344,8 +344,8 @@ mod tests {
("options", "-ckey=1 endpoint=bar project=foo -c geqo=off"),
]);
let mut ctx = RequestMonitoring::test();
let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, None, None)?;
let ctx = RequestMonitoring::test();
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?;
assert_eq!(user_info.user, "john_doe");
assert!(user_info.endpoint_id.is_none());
@@ -359,9 +359,9 @@ mod tests {
let sni = Some("baz.localhost");
let common_names = Some(["localhost".into()].into());
let mut ctx = RequestMonitoring::test();
let ctx = RequestMonitoring::test();
let user_info =
ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())?;
ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())?;
assert_eq!(user_info.user, "john_doe");
assert_eq!(user_info.endpoint_id.as_deref(), Some("baz"));
@@ -374,16 +374,16 @@ mod tests {
let common_names = Some(["a.com".into(), "b.com".into()].into());
let sni = Some("p1.a.com");
let mut ctx = RequestMonitoring::test();
let ctx = RequestMonitoring::test();
let user_info =
ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())?;
ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())?;
assert_eq!(user_info.endpoint_id.as_deref(), Some("p1"));
let common_names = Some(["a.com".into(), "b.com".into()].into());
let sni = Some("p1.b.com");
let mut ctx = RequestMonitoring::test();
let ctx = RequestMonitoring::test();
let user_info =
ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())?;
ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())?;
assert_eq!(user_info.endpoint_id.as_deref(), Some("p1"));
Ok(())
@@ -397,10 +397,9 @@ mod tests {
let sni = Some("second.localhost");
let common_names = Some(["localhost".into()].into());
let mut ctx = RequestMonitoring::test();
let err =
ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())
.expect_err("should fail");
let ctx = RequestMonitoring::test();
let err = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())
.expect_err("should fail");
match err {
InconsistentProjectNames { domain, option } => {
assert_eq!(option, "first");
@@ -417,10 +416,9 @@ mod tests {
let sni = Some("project.localhost");
let common_names = Some(["example.com".into()].into());
let mut ctx = RequestMonitoring::test();
let err =
ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())
.expect_err("should fail");
let ctx = RequestMonitoring::test();
let err = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())
.expect_err("should fail");
match err {
UnknownCommonName { cn } => {
assert_eq!(cn, "localhost");
@@ -438,9 +436,9 @@ mod tests {
let sni = Some("project.localhost");
let common_names = Some(["localhost".into()].into());
let mut ctx = RequestMonitoring::test();
let ctx = RequestMonitoring::test();
let user_info =
ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())?;
ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())?;
assert_eq!(user_info.endpoint_id.as_deref(), Some("project"));
assert_eq!(
user_info.options.get_cache_key("project"),

View File

@@ -27,7 +27,7 @@ pub trait AuthMethod {
pub struct Begin;
/// Use [SCRAM](crate::scram)-based auth in [`AuthFlow`].
pub struct Scram<'a>(pub &'a scram::ServerSecret, pub &'a mut RequestMonitoring);
pub struct Scram<'a>(pub &'a scram::ServerSecret, pub &'a RequestMonitoring);
impl AuthMethod for Scram<'_> {
#[inline(always)]
@@ -155,7 +155,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
let Scram(secret, ctx) = self.state;
// pause the timer while we communicate with the client
let _paused = ctx.latency_timer.pause(crate::metrics::Waiting::Client);
let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
// Initial client message contains the chosen auth method's name.
let msg = self.stream.read_password_message().await?;
@@ -168,10 +168,8 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
}
match sasl.method {
SCRAM_SHA_256 => ctx.auth_method = Some(crate::context::AuthMethod::ScramSha256),
SCRAM_SHA_256_PLUS => {
ctx.auth_method = Some(crate::context::AuthMethod::ScramSha256Plus)
}
SCRAM_SHA_256 => ctx.set_auth_method(crate::context::AuthMethod::ScramSha256),
SCRAM_SHA_256_PLUS => ctx.set_auth_method(crate::context::AuthMethod::ScramSha256Plus),
_ => {}
}
info!("client chooses {}", sasl.method);