mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-08 05:52:55 +00:00
proxy: random changes (#8602)
## Problem 1. Hard to correlate startup parameters with the endpoint that provided them. 2. Some configurations are not needed in the `ProxyConfig` struct. ## Summary of changes Because of some borrow checker fun, I needed to switch to an interior-mutability implementation of our `RequestMonitoring` context system. Using https://docs.rs/try-lock/latest/try_lock/ as a cheap lock for such a use-case (needed to be thread safe). Removed the lock of each startup message, instead just logging only the startup params in a successful handshake. Also removed from values from `ProxyConfig` and kept as arguments. (needed for local-proxy config)
This commit is contained in:
@@ -218,7 +218,7 @@ impl RateBucketInfo {
|
||||
impl AuthenticationConfig {
|
||||
pub fn check_rate_limit(
|
||||
&self,
|
||||
ctx: &mut RequestMonitoring,
|
||||
ctx: &RequestMonitoring,
|
||||
config: &AuthenticationConfig,
|
||||
secret: AuthSecret,
|
||||
endpoint: &EndpointId,
|
||||
@@ -243,7 +243,7 @@ impl AuthenticationConfig {
|
||||
let limit_not_exceeded = self.rate_limiter.check(
|
||||
(
|
||||
endpoint_int,
|
||||
MaskedIp::new(ctx.peer_addr, config.rate_limit_ip_subnet),
|
||||
MaskedIp::new(ctx.peer_addr(), config.rate_limit_ip_subnet),
|
||||
),
|
||||
password_weight,
|
||||
);
|
||||
@@ -274,7 +274,7 @@ impl AuthenticationConfig {
|
||||
///
|
||||
/// All authentication flows will emit an AuthenticationOk message if successful.
|
||||
async fn auth_quirks(
|
||||
ctx: &mut RequestMonitoring,
|
||||
ctx: &RequestMonitoring,
|
||||
api: &impl console::Api,
|
||||
user_info: ComputeUserInfoMaybeEndpoint,
|
||||
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
@@ -303,8 +303,8 @@ async fn auth_quirks(
|
||||
let (allowed_ips, maybe_secret) = api.get_allowed_ips_and_secret(ctx, &info).await?;
|
||||
|
||||
// check allowed list
|
||||
if !check_peer_addr_is_in_list(&ctx.peer_addr, &allowed_ips) {
|
||||
return Err(auth::AuthError::ip_address_not_allowed(ctx.peer_addr));
|
||||
if !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips) {
|
||||
return Err(auth::AuthError::ip_address_not_allowed(ctx.peer_addr()));
|
||||
}
|
||||
|
||||
if !endpoint_rate_limiter.check(info.endpoint.clone().into(), 1) {
|
||||
@@ -356,7 +356,7 @@ async fn auth_quirks(
|
||||
}
|
||||
|
||||
async fn authenticate_with_secret(
|
||||
ctx: &mut RequestMonitoring,
|
||||
ctx: &RequestMonitoring,
|
||||
secret: AuthSecret,
|
||||
info: ComputeUserInfo,
|
||||
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
@@ -421,7 +421,7 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint, &()> {
|
||||
#[tracing::instrument(fields(allow_cleartext = allow_cleartext), skip_all)]
|
||||
pub async fn authenticate(
|
||||
self,
|
||||
ctx: &mut RequestMonitoring,
|
||||
ctx: &RequestMonitoring,
|
||||
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
allow_cleartext: bool,
|
||||
config: &'static AuthenticationConfig,
|
||||
@@ -467,7 +467,7 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint, &()> {
|
||||
impl BackendType<'_, ComputeUserInfo, &()> {
|
||||
pub async fn get_role_secret(
|
||||
&self,
|
||||
ctx: &mut RequestMonitoring,
|
||||
ctx: &RequestMonitoring,
|
||||
) -> Result<CachedRoleSecret, GetAuthInfoError> {
|
||||
use BackendType::*;
|
||||
match self {
|
||||
@@ -478,7 +478,7 @@ impl BackendType<'_, ComputeUserInfo, &()> {
|
||||
|
||||
pub async fn get_allowed_ips_and_secret(
|
||||
&self,
|
||||
ctx: &mut RequestMonitoring,
|
||||
ctx: &RequestMonitoring,
|
||||
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), GetAuthInfoError> {
|
||||
use BackendType::*;
|
||||
match self {
|
||||
@@ -492,7 +492,7 @@ impl BackendType<'_, ComputeUserInfo, &()> {
|
||||
impl ComputeConnectBackend for BackendType<'_, ComputeCredentials, NodeInfo> {
|
||||
async fn wake_compute(
|
||||
&self,
|
||||
ctx: &mut RequestMonitoring,
|
||||
ctx: &RequestMonitoring,
|
||||
) -> Result<CachedNodeInfo, console::errors::WakeComputeError> {
|
||||
use BackendType::*;
|
||||
|
||||
@@ -514,7 +514,7 @@ impl ComputeConnectBackend for BackendType<'_, ComputeCredentials, NodeInfo> {
|
||||
impl ComputeConnectBackend for BackendType<'_, ComputeCredentials, &()> {
|
||||
async fn wake_compute(
|
||||
&self,
|
||||
ctx: &mut RequestMonitoring,
|
||||
ctx: &RequestMonitoring,
|
||||
) -> Result<CachedNodeInfo, console::errors::WakeComputeError> {
|
||||
use BackendType::*;
|
||||
|
||||
@@ -571,7 +571,7 @@ mod tests {
|
||||
impl console::Api for Auth {
|
||||
async fn get_role_secret(
|
||||
&self,
|
||||
_ctx: &mut RequestMonitoring,
|
||||
_ctx: &RequestMonitoring,
|
||||
_user_info: &super::ComputeUserInfo,
|
||||
) -> Result<CachedRoleSecret, console::errors::GetAuthInfoError> {
|
||||
Ok(CachedRoleSecret::new_uncached(Some(self.secret.clone())))
|
||||
@@ -579,7 +579,7 @@ mod tests {
|
||||
|
||||
async fn get_allowed_ips_and_secret(
|
||||
&self,
|
||||
_ctx: &mut RequestMonitoring,
|
||||
_ctx: &RequestMonitoring,
|
||||
_user_info: &super::ComputeUserInfo,
|
||||
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), console::errors::GetAuthInfoError>
|
||||
{
|
||||
@@ -591,7 +591,7 @@ mod tests {
|
||||
|
||||
async fn wake_compute(
|
||||
&self,
|
||||
_ctx: &mut RequestMonitoring,
|
||||
_ctx: &RequestMonitoring,
|
||||
_user_info: &super::ComputeUserInfo,
|
||||
) -> Result<CachedNodeInfo, console::errors::WakeComputeError> {
|
||||
unimplemented!()
|
||||
@@ -665,7 +665,7 @@ mod tests {
|
||||
let (mut client, server) = tokio::io::duplex(1024);
|
||||
let mut stream = PqStream::new(Stream::from_raw(server));
|
||||
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let ctx = RequestMonitoring::test();
|
||||
let api = Auth {
|
||||
ips: vec![],
|
||||
secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()),
|
||||
@@ -723,7 +723,7 @@ mod tests {
|
||||
));
|
||||
|
||||
let _creds = auth_quirks(
|
||||
&mut ctx,
|
||||
&ctx,
|
||||
&api,
|
||||
user_info,
|
||||
&mut stream,
|
||||
@@ -742,7 +742,7 @@ mod tests {
|
||||
let (mut client, server) = tokio::io::duplex(1024);
|
||||
let mut stream = PqStream::new(Stream::from_raw(server));
|
||||
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let ctx = RequestMonitoring::test();
|
||||
let api = Auth {
|
||||
ips: vec![],
|
||||
secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()),
|
||||
@@ -775,7 +775,7 @@ mod tests {
|
||||
));
|
||||
|
||||
let _creds = auth_quirks(
|
||||
&mut ctx,
|
||||
&ctx,
|
||||
&api,
|
||||
user_info,
|
||||
&mut stream,
|
||||
@@ -794,7 +794,7 @@ mod tests {
|
||||
let (mut client, server) = tokio::io::duplex(1024);
|
||||
let mut stream = PqStream::new(Stream::from_raw(server));
|
||||
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let ctx = RequestMonitoring::test();
|
||||
let api = Auth {
|
||||
ips: vec![],
|
||||
secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()),
|
||||
@@ -828,7 +828,7 @@ mod tests {
|
||||
));
|
||||
|
||||
let creds = auth_quirks(
|
||||
&mut ctx,
|
||||
&ctx,
|
||||
&api,
|
||||
user_info,
|
||||
&mut stream,
|
||||
|
||||
@@ -12,7 +12,7 @@ use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing::{info, warn};
|
||||
|
||||
pub(super) async fn authenticate(
|
||||
ctx: &mut RequestMonitoring,
|
||||
ctx: &RequestMonitoring,
|
||||
creds: ComputeUserInfo,
|
||||
client: &mut PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
config: &'static AuthenticationConfig,
|
||||
@@ -27,7 +27,7 @@ pub(super) async fn authenticate(
|
||||
}
|
||||
AuthSecret::Scram(secret) => {
|
||||
info!("auth endpoint chooses SCRAM");
|
||||
let scram = auth::Scram(&secret, &mut *ctx);
|
||||
let scram = auth::Scram(&secret, ctx);
|
||||
|
||||
let auth_outcome = tokio::time::timeout(
|
||||
config.scram_protocol_timeout,
|
||||
|
||||
@@ -18,7 +18,7 @@ use tracing::{info, warn};
|
||||
/// These properties are benefical for serverless JS workers, so we
|
||||
/// use this mechanism for websocket connections.
|
||||
pub async fn authenticate_cleartext(
|
||||
ctx: &mut RequestMonitoring,
|
||||
ctx: &RequestMonitoring,
|
||||
info: ComputeUserInfo,
|
||||
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
secret: AuthSecret,
|
||||
@@ -28,7 +28,7 @@ pub async fn authenticate_cleartext(
|
||||
ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
|
||||
|
||||
// pause the timer while we communicate with the client
|
||||
let paused = ctx.latency_timer.pause(crate::metrics::Waiting::Client);
|
||||
let paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
|
||||
|
||||
let ep = EndpointIdInt::from(&info.endpoint);
|
||||
|
||||
@@ -60,7 +60,7 @@ pub async fn authenticate_cleartext(
|
||||
/// Similar to [`authenticate_cleartext`], but there's a specific password format,
|
||||
/// and passwords are not yet validated (we don't know how to validate them!)
|
||||
pub async fn password_hack_no_authentication(
|
||||
ctx: &mut RequestMonitoring,
|
||||
ctx: &RequestMonitoring,
|
||||
info: ComputeUserInfoNoEndpoint,
|
||||
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
) -> auth::Result<ComputeCredentials> {
|
||||
@@ -68,7 +68,7 @@ pub async fn password_hack_no_authentication(
|
||||
ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
|
||||
|
||||
// pause the timer while we communicate with the client
|
||||
let _paused = ctx.latency_timer.pause(crate::metrics::Waiting::Client);
|
||||
let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
|
||||
|
||||
let payload = AuthFlow::new(client)
|
||||
.begin(auth::PasswordHack)
|
||||
|
||||
@@ -57,7 +57,7 @@ pub fn new_psql_session_id() -> String {
|
||||
}
|
||||
|
||||
pub(super) async fn authenticate(
|
||||
ctx: &mut RequestMonitoring,
|
||||
ctx: &RequestMonitoring,
|
||||
link_uri: &reqwest::Url,
|
||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
) -> auth::Result<NodeInfo> {
|
||||
|
||||
@@ -84,7 +84,7 @@ pub fn endpoint_sni(
|
||||
|
||||
impl ComputeUserInfoMaybeEndpoint {
|
||||
pub fn parse(
|
||||
ctx: &mut RequestMonitoring,
|
||||
ctx: &RequestMonitoring,
|
||||
params: &StartupMessageParams,
|
||||
sni: Option<&str>,
|
||||
common_names: Option<&HashSet<String>>,
|
||||
@@ -249,8 +249,8 @@ mod tests {
|
||||
fn parse_bare_minimum() -> anyhow::Result<()> {
|
||||
// According to postgresql, only `user` should be required.
|
||||
let options = StartupMessageParams::new([("user", "john_doe")]);
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, None, None)?;
|
||||
let ctx = RequestMonitoring::test();
|
||||
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?;
|
||||
assert_eq!(user_info.user, "john_doe");
|
||||
assert_eq!(user_info.endpoint_id, None);
|
||||
|
||||
@@ -264,8 +264,8 @@ mod tests {
|
||||
("database", "world"), // should be ignored
|
||||
("foo", "bar"), // should be ignored
|
||||
]);
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, None, None)?;
|
||||
let ctx = RequestMonitoring::test();
|
||||
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?;
|
||||
assert_eq!(user_info.user, "john_doe");
|
||||
assert_eq!(user_info.endpoint_id, None);
|
||||
|
||||
@@ -279,9 +279,9 @@ mod tests {
|
||||
let sni = Some("foo.localhost");
|
||||
let common_names = Some(["localhost".into()].into());
|
||||
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let ctx = RequestMonitoring::test();
|
||||
let user_info =
|
||||
ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())?;
|
||||
ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())?;
|
||||
assert_eq!(user_info.user, "john_doe");
|
||||
assert_eq!(user_info.endpoint_id.as_deref(), Some("foo"));
|
||||
assert_eq!(user_info.options.get_cache_key("foo"), "foo");
|
||||
@@ -296,8 +296,8 @@ mod tests {
|
||||
("options", "-ckey=1 project=bar -c geqo=off"),
|
||||
]);
|
||||
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, None, None)?;
|
||||
let ctx = RequestMonitoring::test();
|
||||
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?;
|
||||
assert_eq!(user_info.user, "john_doe");
|
||||
assert_eq!(user_info.endpoint_id.as_deref(), Some("bar"));
|
||||
|
||||
@@ -311,8 +311,8 @@ mod tests {
|
||||
("options", "-ckey=1 endpoint=bar -c geqo=off"),
|
||||
]);
|
||||
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, None, None)?;
|
||||
let ctx = RequestMonitoring::test();
|
||||
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?;
|
||||
assert_eq!(user_info.user, "john_doe");
|
||||
assert_eq!(user_info.endpoint_id.as_deref(), Some("bar"));
|
||||
|
||||
@@ -329,8 +329,8 @@ mod tests {
|
||||
),
|
||||
]);
|
||||
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, None, None)?;
|
||||
let ctx = RequestMonitoring::test();
|
||||
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?;
|
||||
assert_eq!(user_info.user, "john_doe");
|
||||
assert!(user_info.endpoint_id.is_none());
|
||||
|
||||
@@ -344,8 +344,8 @@ mod tests {
|
||||
("options", "-ckey=1 endpoint=bar project=foo -c geqo=off"),
|
||||
]);
|
||||
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, None, None)?;
|
||||
let ctx = RequestMonitoring::test();
|
||||
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?;
|
||||
assert_eq!(user_info.user, "john_doe");
|
||||
assert!(user_info.endpoint_id.is_none());
|
||||
|
||||
@@ -359,9 +359,9 @@ mod tests {
|
||||
let sni = Some("baz.localhost");
|
||||
let common_names = Some(["localhost".into()].into());
|
||||
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let ctx = RequestMonitoring::test();
|
||||
let user_info =
|
||||
ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())?;
|
||||
ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())?;
|
||||
assert_eq!(user_info.user, "john_doe");
|
||||
assert_eq!(user_info.endpoint_id.as_deref(), Some("baz"));
|
||||
|
||||
@@ -374,16 +374,16 @@ mod tests {
|
||||
|
||||
let common_names = Some(["a.com".into(), "b.com".into()].into());
|
||||
let sni = Some("p1.a.com");
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let ctx = RequestMonitoring::test();
|
||||
let user_info =
|
||||
ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())?;
|
||||
ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())?;
|
||||
assert_eq!(user_info.endpoint_id.as_deref(), Some("p1"));
|
||||
|
||||
let common_names = Some(["a.com".into(), "b.com".into()].into());
|
||||
let sni = Some("p1.b.com");
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let ctx = RequestMonitoring::test();
|
||||
let user_info =
|
||||
ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())?;
|
||||
ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())?;
|
||||
assert_eq!(user_info.endpoint_id.as_deref(), Some("p1"));
|
||||
|
||||
Ok(())
|
||||
@@ -397,10 +397,9 @@ mod tests {
|
||||
let sni = Some("second.localhost");
|
||||
let common_names = Some(["localhost".into()].into());
|
||||
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let err =
|
||||
ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())
|
||||
.expect_err("should fail");
|
||||
let ctx = RequestMonitoring::test();
|
||||
let err = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())
|
||||
.expect_err("should fail");
|
||||
match err {
|
||||
InconsistentProjectNames { domain, option } => {
|
||||
assert_eq!(option, "first");
|
||||
@@ -417,10 +416,9 @@ mod tests {
|
||||
let sni = Some("project.localhost");
|
||||
let common_names = Some(["example.com".into()].into());
|
||||
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let err =
|
||||
ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())
|
||||
.expect_err("should fail");
|
||||
let ctx = RequestMonitoring::test();
|
||||
let err = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())
|
||||
.expect_err("should fail");
|
||||
match err {
|
||||
UnknownCommonName { cn } => {
|
||||
assert_eq!(cn, "localhost");
|
||||
@@ -438,9 +436,9 @@ mod tests {
|
||||
|
||||
let sni = Some("project.localhost");
|
||||
let common_names = Some(["localhost".into()].into());
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let ctx = RequestMonitoring::test();
|
||||
let user_info =
|
||||
ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())?;
|
||||
ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())?;
|
||||
assert_eq!(user_info.endpoint_id.as_deref(), Some("project"));
|
||||
assert_eq!(
|
||||
user_info.options.get_cache_key("project"),
|
||||
|
||||
@@ -27,7 +27,7 @@ pub trait AuthMethod {
|
||||
pub struct Begin;
|
||||
|
||||
/// Use [SCRAM](crate::scram)-based auth in [`AuthFlow`].
|
||||
pub struct Scram<'a>(pub &'a scram::ServerSecret, pub &'a mut RequestMonitoring);
|
||||
pub struct Scram<'a>(pub &'a scram::ServerSecret, pub &'a RequestMonitoring);
|
||||
|
||||
impl AuthMethod for Scram<'_> {
|
||||
#[inline(always)]
|
||||
@@ -155,7 +155,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
|
||||
let Scram(secret, ctx) = self.state;
|
||||
|
||||
// pause the timer while we communicate with the client
|
||||
let _paused = ctx.latency_timer.pause(crate::metrics::Waiting::Client);
|
||||
let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
|
||||
|
||||
// Initial client message contains the chosen auth method's name.
|
||||
let msg = self.stream.read_password_message().await?;
|
||||
@@ -168,10 +168,8 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
|
||||
}
|
||||
|
||||
match sasl.method {
|
||||
SCRAM_SHA_256 => ctx.auth_method = Some(crate::context::AuthMethod::ScramSha256),
|
||||
SCRAM_SHA_256_PLUS => {
|
||||
ctx.auth_method = Some(crate::context::AuthMethod::ScramSha256Plus)
|
||||
}
|
||||
SCRAM_SHA_256 => ctx.set_auth_method(crate::context::AuthMethod::ScramSha256),
|
||||
SCRAM_SHA_256_PLUS => ctx.set_auth_method(crate::context::AuthMethod::ScramSha256Plus),
|
||||
_ => {}
|
||||
}
|
||||
info!("client chooses {}", sasl.method);
|
||||
|
||||
@@ -205,7 +205,7 @@ async fn task_main(
|
||||
const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
|
||||
|
||||
async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
ctx: &mut RequestMonitoring,
|
||||
ctx: &RequestMonitoring,
|
||||
raw_stream: S,
|
||||
tls_config: Arc<rustls::ServerConfig>,
|
||||
tls_server_end_point: TlsServerEndPoint,
|
||||
@@ -256,13 +256,13 @@ async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
}
|
||||
|
||||
async fn handle_client(
|
||||
mut ctx: RequestMonitoring,
|
||||
ctx: RequestMonitoring,
|
||||
dest_suffix: Arc<String>,
|
||||
tls_config: Arc<rustls::ServerConfig>,
|
||||
tls_server_end_point: TlsServerEndPoint,
|
||||
stream: impl AsyncRead + AsyncWrite + Unpin,
|
||||
) -> anyhow::Result<()> {
|
||||
let mut tls_stream = ssl_handshake(&mut ctx, stream, tls_config, tls_server_end_point).await?;
|
||||
let mut tls_stream = ssl_handshake(&ctx, stream, tls_config, tls_server_end_point).await?;
|
||||
|
||||
// Cut off first part of the SNI domain
|
||||
// We receive required destination details in the format of
|
||||
|
||||
@@ -5,6 +5,7 @@ use aws_config::meta::region::RegionProviderChain;
|
||||
use aws_config::profile::ProfileFileCredentialsProvider;
|
||||
use aws_config::provider_config::ProviderConfig;
|
||||
use aws_config::web_identity_token::WebIdentityTokenCredentialsProvider;
|
||||
use aws_config::Region;
|
||||
use futures::future::Either;
|
||||
use proxy::auth;
|
||||
use proxy::auth::backend::AuthRateLimiter;
|
||||
@@ -290,9 +291,10 @@ async fn main() -> anyhow::Result<()> {
|
||||
let config = build_config(&args)?;
|
||||
|
||||
info!("Authentication backend: {}", config.auth_backend);
|
||||
info!("Using region: {}", config.aws_region);
|
||||
info!("Using region: {}", args.aws_region);
|
||||
|
||||
let region_provider = RegionProviderChain::default_provider().or_else(&*config.aws_region); // Replace with your Redis region if needed
|
||||
let region_provider =
|
||||
RegionProviderChain::default_provider().or_else(Region::new(args.aws_region.clone()));
|
||||
let provider_conf =
|
||||
ProviderConfig::without_region().with_region(region_provider.region().await);
|
||||
let aws_credentials_provider = {
|
||||
@@ -318,7 +320,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
};
|
||||
let elasticache_credentials_provider = Arc::new(elasticache::CredentialsProvider::new(
|
||||
elasticache::AWSIRSAConfig::new(
|
||||
config.aws_region.clone(),
|
||||
args.aws_region.clone(),
|
||||
args.redis_cluster_name,
|
||||
args.redis_user_id,
|
||||
),
|
||||
@@ -376,11 +378,14 @@ async fn main() -> anyhow::Result<()> {
|
||||
|
||||
let cancel_map = CancelMap::default();
|
||||
|
||||
let redis_rps_limit = Vec::leak(args.redis_rps_limit.clone());
|
||||
RateBucketInfo::validate(redis_rps_limit)?;
|
||||
|
||||
let redis_publisher = match ®ional_redis_client {
|
||||
Some(redis_publisher) => Some(Arc::new(Mutex::new(RedisPublisherClient::new(
|
||||
redis_publisher.clone(),
|
||||
args.region.clone(),
|
||||
&config.redis_rps_limit,
|
||||
redis_rps_limit,
|
||||
)?))),
|
||||
None => None,
|
||||
};
|
||||
@@ -656,7 +661,6 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
|
||||
)?;
|
||||
|
||||
let http_config = HttpConfig {
|
||||
request_timeout: args.sql_over_http.sql_over_http_timeout,
|
||||
pool_options: GlobalConnPoolOptions {
|
||||
max_conns_per_endpoint: args.sql_over_http.sql_over_http_pool_max_conns_per_endpoint,
|
||||
gc_epoch: args.sql_over_http.sql_over_http_pool_gc_epoch,
|
||||
@@ -676,9 +680,6 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
|
||||
rate_limit_ip_subnet: args.auth_rate_limit_ip_subnet,
|
||||
};
|
||||
|
||||
let mut redis_rps_limit = args.redis_rps_limit.clone();
|
||||
RateBucketInfo::validate(&mut redis_rps_limit)?;
|
||||
|
||||
let config = Box::leak(Box::new(ProxyConfig {
|
||||
tls_config,
|
||||
auth_backend,
|
||||
@@ -687,11 +688,8 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
|
||||
http_config,
|
||||
authentication_config,
|
||||
require_client_ip: args.require_client_ip,
|
||||
disable_ip_check_for_http: args.disable_ip_check_for_http,
|
||||
redis_rps_limit,
|
||||
handshake_timeout: args.handshake_timeout,
|
||||
region: args.region.clone(),
|
||||
aws_region: args.aws_region.clone(),
|
||||
wake_compute_retry_config: config::RetryConfig::parse(&args.wake_compute_retry)?,
|
||||
connect_compute_locks,
|
||||
connect_to_compute_retry_config: config::RetryConfig::parse(
|
||||
|
||||
2
proxy/src/cache/endpoints.rs
vendored
2
proxy/src/cache/endpoints.rs
vendored
@@ -68,7 +68,7 @@ impl EndpointsCache {
|
||||
ready: AtomicBool::new(false),
|
||||
}
|
||||
}
|
||||
pub async fn is_valid(&self, ctx: &mut RequestMonitoring, endpoint: &EndpointId) -> bool {
|
||||
pub async fn is_valid(&self, ctx: &RequestMonitoring, endpoint: &EndpointId) -> bool {
|
||||
if !self.ready.load(Ordering::Acquire) {
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -288,12 +288,12 @@ impl ConnCfg {
|
||||
/// Connect to a corresponding compute node.
|
||||
pub async fn connect(
|
||||
&self,
|
||||
ctx: &mut RequestMonitoring,
|
||||
ctx: &RequestMonitoring,
|
||||
allow_self_signed_compute: bool,
|
||||
aux: MetricsAuxInfo,
|
||||
timeout: Duration,
|
||||
) -> Result<PostgresConnection, ConnectionError> {
|
||||
let pause = ctx.latency_timer.pause(crate::metrics::Waiting::Compute);
|
||||
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
|
||||
let (socket_addr, stream, host) = self.connect_raw(timeout).await?;
|
||||
drop(pause);
|
||||
|
||||
@@ -316,14 +316,14 @@ impl ConnCfg {
|
||||
)?;
|
||||
|
||||
// connect_raw() will not use TLS if sslmode is "disable"
|
||||
let pause = ctx.latency_timer.pause(crate::metrics::Waiting::Compute);
|
||||
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
|
||||
let (client, connection) = self.0.connect_raw(stream, tls).await?;
|
||||
drop(pause);
|
||||
tracing::Span::current().record("pid", tracing::field::display(client.get_process_id()));
|
||||
let stream = connection.stream.into_inner();
|
||||
|
||||
info!(
|
||||
cold_start_info = ctx.cold_start_info.as_str(),
|
||||
cold_start_info = ctx.cold_start_info().as_str(),
|
||||
"connected to compute node at {host} ({socket_addr}) sslmode={:?}",
|
||||
self.0.get_ssl_mode()
|
||||
);
|
||||
@@ -342,7 +342,7 @@ impl ConnCfg {
|
||||
params,
|
||||
cancel_closure,
|
||||
aux,
|
||||
_guage: Metrics::get().proxy.db_connections.guard(ctx.protocol),
|
||||
_guage: Metrics::get().proxy.db_connections.guard(ctx.protocol()),
|
||||
};
|
||||
|
||||
Ok(connection)
|
||||
|
||||
@@ -31,11 +31,8 @@ pub struct ProxyConfig {
|
||||
pub http_config: HttpConfig,
|
||||
pub authentication_config: AuthenticationConfig,
|
||||
pub require_client_ip: bool,
|
||||
pub disable_ip_check_for_http: bool,
|
||||
pub redis_rps_limit: Vec<RateBucketInfo>,
|
||||
pub region: String,
|
||||
pub handshake_timeout: Duration,
|
||||
pub aws_region: String,
|
||||
pub wake_compute_retry_config: RetryConfig,
|
||||
pub connect_compute_locks: ApiLocks<Host>,
|
||||
pub connect_to_compute_retry_config: RetryConfig,
|
||||
@@ -55,7 +52,6 @@ pub struct TlsConfig {
|
||||
}
|
||||
|
||||
pub struct HttpConfig {
|
||||
pub request_timeout: tokio::time::Duration,
|
||||
pub pool_options: GlobalConnPoolOptions,
|
||||
pub cancel_set: CancelSet,
|
||||
pub client_conn_threshold: u64,
|
||||
|
||||
@@ -292,7 +292,7 @@ pub struct NodeInfo {
|
||||
impl NodeInfo {
|
||||
pub async fn connect(
|
||||
&self,
|
||||
ctx: &mut RequestMonitoring,
|
||||
ctx: &RequestMonitoring,
|
||||
timeout: Duration,
|
||||
) -> Result<compute::PostgresConnection, compute::ConnectionError> {
|
||||
self.config
|
||||
@@ -330,20 +330,20 @@ pub(crate) trait Api {
|
||||
/// We still have to mock the scram to avoid leaking information that user doesn't exist.
|
||||
async fn get_role_secret(
|
||||
&self,
|
||||
ctx: &mut RequestMonitoring,
|
||||
ctx: &RequestMonitoring,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<CachedRoleSecret, errors::GetAuthInfoError>;
|
||||
|
||||
async fn get_allowed_ips_and_secret(
|
||||
&self,
|
||||
ctx: &mut RequestMonitoring,
|
||||
ctx: &RequestMonitoring,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), errors::GetAuthInfoError>;
|
||||
|
||||
/// Wake up the compute node and return the corresponding connection info.
|
||||
async fn wake_compute(
|
||||
&self,
|
||||
ctx: &mut RequestMonitoring,
|
||||
ctx: &RequestMonitoring,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<CachedNodeInfo, errors::WakeComputeError>;
|
||||
}
|
||||
@@ -363,7 +363,7 @@ pub enum ConsoleBackend {
|
||||
impl Api for ConsoleBackend {
|
||||
async fn get_role_secret(
|
||||
&self,
|
||||
ctx: &mut RequestMonitoring,
|
||||
ctx: &RequestMonitoring,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<CachedRoleSecret, errors::GetAuthInfoError> {
|
||||
use ConsoleBackend::*;
|
||||
@@ -378,7 +378,7 @@ impl Api for ConsoleBackend {
|
||||
|
||||
async fn get_allowed_ips_and_secret(
|
||||
&self,
|
||||
ctx: &mut RequestMonitoring,
|
||||
ctx: &RequestMonitoring,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), errors::GetAuthInfoError> {
|
||||
use ConsoleBackend::*;
|
||||
@@ -393,7 +393,7 @@ impl Api for ConsoleBackend {
|
||||
|
||||
async fn wake_compute(
|
||||
&self,
|
||||
ctx: &mut RequestMonitoring,
|
||||
ctx: &RequestMonitoring,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<CachedNodeInfo, errors::WakeComputeError> {
|
||||
use ConsoleBackend::*;
|
||||
|
||||
@@ -158,7 +158,7 @@ impl super::Api for Api {
|
||||
#[tracing::instrument(skip_all)]
|
||||
async fn get_role_secret(
|
||||
&self,
|
||||
_ctx: &mut RequestMonitoring,
|
||||
_ctx: &RequestMonitoring,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<CachedRoleSecret, GetAuthInfoError> {
|
||||
Ok(CachedRoleSecret::new_uncached(
|
||||
@@ -168,7 +168,7 @@ impl super::Api for Api {
|
||||
|
||||
async fn get_allowed_ips_and_secret(
|
||||
&self,
|
||||
_ctx: &mut RequestMonitoring,
|
||||
_ctx: &RequestMonitoring,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), GetAuthInfoError> {
|
||||
Ok((
|
||||
@@ -182,7 +182,7 @@ impl super::Api for Api {
|
||||
#[tracing::instrument(skip_all)]
|
||||
async fn wake_compute(
|
||||
&self,
|
||||
_ctx: &mut RequestMonitoring,
|
||||
_ctx: &RequestMonitoring,
|
||||
_user_info: &ComputeUserInfo,
|
||||
) -> Result<CachedNodeInfo, WakeComputeError> {
|
||||
self.do_wake_compute().map_ok(Cached::new_uncached).await
|
||||
|
||||
@@ -57,7 +57,7 @@ impl Api {
|
||||
|
||||
async fn do_get_auth_info(
|
||||
&self,
|
||||
ctx: &mut RequestMonitoring,
|
||||
ctx: &RequestMonitoring,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<AuthInfo, GetAuthInfoError> {
|
||||
if !self
|
||||
@@ -69,7 +69,7 @@ impl Api {
|
||||
info!("endpoint is not valid, skipping the request");
|
||||
return Ok(AuthInfo::default());
|
||||
}
|
||||
let request_id = ctx.session_id.to_string();
|
||||
let request_id = ctx.session_id().to_string();
|
||||
let application_name = ctx.console_application_name();
|
||||
async {
|
||||
let request = self
|
||||
@@ -77,7 +77,7 @@ impl Api {
|
||||
.get("proxy_get_role_secret")
|
||||
.header("X-Request-ID", &request_id)
|
||||
.header("Authorization", format!("Bearer {}", &self.jwt))
|
||||
.query(&[("session_id", ctx.session_id)])
|
||||
.query(&[("session_id", ctx.session_id())])
|
||||
.query(&[
|
||||
("application_name", application_name.as_str()),
|
||||
("project", user_info.endpoint.as_str()),
|
||||
@@ -87,7 +87,7 @@ impl Api {
|
||||
|
||||
info!(url = request.url().as_str(), "sending http request");
|
||||
let start = Instant::now();
|
||||
let pause = ctx.latency_timer.pause(crate::metrics::Waiting::Cplane);
|
||||
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Cplane);
|
||||
let response = self.endpoint.execute(request).await?;
|
||||
drop(pause);
|
||||
info!(duration = ?start.elapsed(), "received http response");
|
||||
@@ -130,10 +130,10 @@ impl Api {
|
||||
|
||||
async fn do_wake_compute(
|
||||
&self,
|
||||
ctx: &mut RequestMonitoring,
|
||||
ctx: &RequestMonitoring,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<NodeInfo, WakeComputeError> {
|
||||
let request_id = ctx.session_id.to_string();
|
||||
let request_id = ctx.session_id().to_string();
|
||||
let application_name = ctx.console_application_name();
|
||||
async {
|
||||
let mut request_builder = self
|
||||
@@ -141,7 +141,7 @@ impl Api {
|
||||
.get("proxy_wake_compute")
|
||||
.header("X-Request-ID", &request_id)
|
||||
.header("Authorization", format!("Bearer {}", &self.jwt))
|
||||
.query(&[("session_id", ctx.session_id)])
|
||||
.query(&[("session_id", ctx.session_id())])
|
||||
.query(&[
|
||||
("application_name", application_name.as_str()),
|
||||
("project", user_info.endpoint.as_str()),
|
||||
@@ -156,7 +156,7 @@ impl Api {
|
||||
|
||||
info!(url = request.url().as_str(), "sending http request");
|
||||
let start = Instant::now();
|
||||
let pause = ctx.latency_timer.pause(crate::metrics::Waiting::Cplane);
|
||||
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Cplane);
|
||||
let response = self.endpoint.execute(request).await?;
|
||||
drop(pause);
|
||||
info!(duration = ?start.elapsed(), "received http response");
|
||||
@@ -192,7 +192,7 @@ impl super::Api for Api {
|
||||
#[tracing::instrument(skip_all)]
|
||||
async fn get_role_secret(
|
||||
&self,
|
||||
ctx: &mut RequestMonitoring,
|
||||
ctx: &RequestMonitoring,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<CachedRoleSecret, GetAuthInfoError> {
|
||||
let normalized_ep = &user_info.endpoint.normalize();
|
||||
@@ -226,7 +226,7 @@ impl super::Api for Api {
|
||||
|
||||
async fn get_allowed_ips_and_secret(
|
||||
&self,
|
||||
ctx: &mut RequestMonitoring,
|
||||
ctx: &RequestMonitoring,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), GetAuthInfoError> {
|
||||
let normalized_ep = &user_info.endpoint.normalize();
|
||||
@@ -268,7 +268,7 @@ impl super::Api for Api {
|
||||
#[tracing::instrument(skip_all)]
|
||||
async fn wake_compute(
|
||||
&self,
|
||||
ctx: &mut RequestMonitoring,
|
||||
ctx: &RequestMonitoring,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<CachedNodeInfo, WakeComputeError> {
|
||||
let key = user_info.endpoint_cache_key();
|
||||
|
||||
@@ -7,13 +7,14 @@ use smol_str::SmolStr;
|
||||
use std::net::IpAddr;
|
||||
use tokio::sync::mpsc;
|
||||
use tracing::{field::display, info, info_span, Span};
|
||||
use try_lock::TryLock;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{
|
||||
console::messages::{ColdStartInfo, MetricsAuxInfo},
|
||||
error::ErrorKind,
|
||||
intern::{BranchIdInt, ProjectIdInt},
|
||||
metrics::{ConnectOutcome, InvalidEndpointsGroup, LatencyTimer, Metrics, Protocol},
|
||||
metrics::{ConnectOutcome, InvalidEndpointsGroup, LatencyTimer, Metrics, Protocol, Waiting},
|
||||
DbName, EndpointId, RoleName,
|
||||
};
|
||||
|
||||
@@ -28,7 +29,15 @@ pub static LOG_CHAN_DISCONNECT: OnceCell<mpsc::WeakUnboundedSender<RequestData>>
|
||||
///
|
||||
/// This data should **not** be used for connection logic, only for observability and limiting purposes.
|
||||
/// All connection logic should instead use strongly typed state machines, not a bunch of Options.
|
||||
pub struct RequestMonitoring {
|
||||
pub struct RequestMonitoring(
|
||||
/// To allow easier use of the ctx object, we have interior mutability.
|
||||
/// I would typically use a RefCell but that would break the `Send` requirements
|
||||
/// so we need something with thread-safety. `TryLock` is a cheap alternative
|
||||
/// that offers similar semantics to a `RefCell` but with synchronisation.
|
||||
TryLock<RequestMonitoringInner>,
|
||||
);
|
||||
|
||||
struct RequestMonitoringInner {
|
||||
pub peer_addr: IpAddr,
|
||||
pub session_id: Uuid,
|
||||
pub protocol: Protocol,
|
||||
@@ -85,7 +94,7 @@ impl RequestMonitoring {
|
||||
role = tracing::field::Empty,
|
||||
);
|
||||
|
||||
Self {
|
||||
let inner = RequestMonitoringInner {
|
||||
peer_addr,
|
||||
session_id,
|
||||
protocol,
|
||||
@@ -110,7 +119,9 @@ impl RequestMonitoring {
|
||||
disconnect_sender: LOG_CHAN_DISCONNECT.get().and_then(|tx| tx.upgrade()),
|
||||
latency_timer: LatencyTimer::new(protocol),
|
||||
disconnect_timestamp: None,
|
||||
}
|
||||
};
|
||||
|
||||
Self(TryLock::new(inner))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -119,48 +130,177 @@ impl RequestMonitoring {
|
||||
}
|
||||
|
||||
pub fn console_application_name(&self) -> String {
|
||||
let this = self.0.try_lock().expect("should not deadlock");
|
||||
format!(
|
||||
"{}/{}",
|
||||
self.application.as_deref().unwrap_or_default(),
|
||||
self.protocol
|
||||
this.application.as_deref().unwrap_or_default(),
|
||||
this.protocol
|
||||
)
|
||||
}
|
||||
|
||||
pub fn set_rejected(&mut self, rejected: bool) {
|
||||
self.rejected = Some(rejected);
|
||||
pub fn set_rejected(&self, rejected: bool) {
|
||||
let mut this = self.0.try_lock().expect("should not deadlock");
|
||||
this.rejected = Some(rejected);
|
||||
}
|
||||
|
||||
pub fn set_cold_start_info(&mut self, info: ColdStartInfo) {
|
||||
pub fn set_cold_start_info(&self, info: ColdStartInfo) {
|
||||
self.0
|
||||
.try_lock()
|
||||
.expect("should not deadlock")
|
||||
.set_cold_start_info(info);
|
||||
}
|
||||
|
||||
pub fn set_db_options(&self, options: StartupMessageParams) {
|
||||
let mut this = self.0.try_lock().expect("should not deadlock");
|
||||
this.set_application(options.get("application_name").map(SmolStr::from));
|
||||
if let Some(user) = options.get("user") {
|
||||
this.set_user(user.into());
|
||||
}
|
||||
if let Some(dbname) = options.get("database") {
|
||||
this.set_dbname(dbname.into());
|
||||
}
|
||||
|
||||
this.pg_options = Some(options);
|
||||
}
|
||||
|
||||
pub fn set_project(&self, x: MetricsAuxInfo) {
|
||||
let mut this = self.0.try_lock().expect("should not deadlock");
|
||||
if this.endpoint_id.is_none() {
|
||||
this.set_endpoint_id(x.endpoint_id.as_str().into())
|
||||
}
|
||||
this.branch = Some(x.branch_id);
|
||||
this.project = Some(x.project_id);
|
||||
this.set_cold_start_info(x.cold_start_info);
|
||||
}
|
||||
|
||||
pub fn set_project_id(&self, project_id: ProjectIdInt) {
|
||||
let mut this = self.0.try_lock().expect("should not deadlock");
|
||||
this.project = Some(project_id);
|
||||
}
|
||||
|
||||
pub fn set_endpoint_id(&self, endpoint_id: EndpointId) {
|
||||
self.0
|
||||
.try_lock()
|
||||
.expect("should not deadlock")
|
||||
.set_endpoint_id(endpoint_id);
|
||||
}
|
||||
|
||||
pub fn set_dbname(&self, dbname: DbName) {
|
||||
self.0
|
||||
.try_lock()
|
||||
.expect("should not deadlock")
|
||||
.set_dbname(dbname);
|
||||
}
|
||||
|
||||
pub fn set_user(&self, user: RoleName) {
|
||||
self.0
|
||||
.try_lock()
|
||||
.expect("should not deadlock")
|
||||
.set_user(user);
|
||||
}
|
||||
|
||||
pub fn set_auth_method(&self, auth_method: AuthMethod) {
|
||||
let mut this = self.0.try_lock().expect("should not deadlock");
|
||||
this.auth_method = Some(auth_method);
|
||||
}
|
||||
|
||||
pub fn has_private_peer_addr(&self) -> bool {
|
||||
self.0
|
||||
.try_lock()
|
||||
.expect("should not deadlock")
|
||||
.has_private_peer_addr()
|
||||
}
|
||||
|
||||
pub fn set_error_kind(&self, kind: ErrorKind) {
|
||||
let mut this = self.0.try_lock().expect("should not deadlock");
|
||||
// Do not record errors from the private address to metrics.
|
||||
if !this.has_private_peer_addr() {
|
||||
Metrics::get().proxy.errors_total.inc(kind);
|
||||
}
|
||||
if let Some(ep) = &this.endpoint_id {
|
||||
let metric = &Metrics::get().proxy.endpoints_affected_by_errors;
|
||||
let label = metric.with_labels(kind);
|
||||
metric.get_metric(label).measure(ep);
|
||||
}
|
||||
this.error_kind = Some(kind);
|
||||
}
|
||||
|
||||
pub fn set_success(&self) {
|
||||
let mut this = self.0.try_lock().expect("should not deadlock");
|
||||
this.success = true;
|
||||
}
|
||||
|
||||
pub fn log_connect(&self) {
|
||||
self.0
|
||||
.try_lock()
|
||||
.expect("should not deadlock")
|
||||
.log_connect();
|
||||
}
|
||||
|
||||
pub fn protocol(&self) -> Protocol {
|
||||
self.0.try_lock().expect("should not deadlock").protocol
|
||||
}
|
||||
|
||||
pub fn span(&self) -> Span {
|
||||
self.0.try_lock().expect("should not deadlock").span.clone()
|
||||
}
|
||||
|
||||
pub fn session_id(&self) -> Uuid {
|
||||
self.0.try_lock().expect("should not deadlock").session_id
|
||||
}
|
||||
|
||||
pub fn peer_addr(&self) -> IpAddr {
|
||||
self.0.try_lock().expect("should not deadlock").peer_addr
|
||||
}
|
||||
|
||||
pub fn cold_start_info(&self) -> ColdStartInfo {
|
||||
self.0
|
||||
.try_lock()
|
||||
.expect("should not deadlock")
|
||||
.cold_start_info
|
||||
}
|
||||
|
||||
pub fn latency_timer_pause(&self, waiting_for: Waiting) -> LatencyTimerPause {
|
||||
LatencyTimerPause {
|
||||
ctx: self,
|
||||
start: tokio::time::Instant::now(),
|
||||
waiting_for,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn success(&self) {
|
||||
self.0
|
||||
.try_lock()
|
||||
.expect("should not deadlock")
|
||||
.latency_timer
|
||||
.success()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct LatencyTimerPause<'a> {
|
||||
ctx: &'a RequestMonitoring,
|
||||
start: tokio::time::Instant,
|
||||
waiting_for: Waiting,
|
||||
}
|
||||
|
||||
impl Drop for LatencyTimerPause<'_> {
|
||||
fn drop(&mut self) {
|
||||
self.ctx
|
||||
.0
|
||||
.try_lock()
|
||||
.expect("should not deadlock")
|
||||
.latency_timer
|
||||
.unpause(self.start, self.waiting_for);
|
||||
}
|
||||
}
|
||||
|
||||
impl RequestMonitoringInner {
|
||||
fn set_cold_start_info(&mut self, info: ColdStartInfo) {
|
||||
self.cold_start_info = info;
|
||||
self.latency_timer.cold_start_info(info);
|
||||
}
|
||||
|
||||
pub fn set_db_options(&mut self, options: StartupMessageParams) {
|
||||
self.set_application(options.get("application_name").map(SmolStr::from));
|
||||
if let Some(user) = options.get("user") {
|
||||
self.set_user(user.into());
|
||||
}
|
||||
if let Some(dbname) = options.get("database") {
|
||||
self.set_dbname(dbname.into());
|
||||
}
|
||||
|
||||
self.pg_options = Some(options);
|
||||
}
|
||||
|
||||
pub fn set_project(&mut self, x: MetricsAuxInfo) {
|
||||
if self.endpoint_id.is_none() {
|
||||
self.set_endpoint_id(x.endpoint_id.as_str().into())
|
||||
}
|
||||
self.branch = Some(x.branch_id);
|
||||
self.project = Some(x.project_id);
|
||||
self.set_cold_start_info(x.cold_start_info);
|
||||
}
|
||||
|
||||
pub fn set_project_id(&mut self, project_id: ProjectIdInt) {
|
||||
self.project = Some(project_id);
|
||||
}
|
||||
|
||||
pub fn set_endpoint_id(&mut self, endpoint_id: EndpointId) {
|
||||
fn set_endpoint_id(&mut self, endpoint_id: EndpointId) {
|
||||
if self.endpoint_id.is_none() {
|
||||
self.span.record("ep", display(&endpoint_id));
|
||||
let metric = &Metrics::get().proxy.connecting_endpoints;
|
||||
@@ -176,44 +316,23 @@ impl RequestMonitoring {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_dbname(&mut self, dbname: DbName) {
|
||||
fn set_dbname(&mut self, dbname: DbName) {
|
||||
self.dbname = Some(dbname);
|
||||
}
|
||||
|
||||
pub fn set_user(&mut self, user: RoleName) {
|
||||
fn set_user(&mut self, user: RoleName) {
|
||||
self.span.record("role", display(&user));
|
||||
self.user = Some(user);
|
||||
}
|
||||
|
||||
pub fn set_auth_method(&mut self, auth_method: AuthMethod) {
|
||||
self.auth_method = Some(auth_method);
|
||||
}
|
||||
|
||||
pub fn has_private_peer_addr(&self) -> bool {
|
||||
fn has_private_peer_addr(&self) -> bool {
|
||||
match self.peer_addr {
|
||||
IpAddr::V4(ip) => ip.is_private(),
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_error_kind(&mut self, kind: ErrorKind) {
|
||||
// Do not record errors from the private address to metrics.
|
||||
if !self.has_private_peer_addr() {
|
||||
Metrics::get().proxy.errors_total.inc(kind);
|
||||
}
|
||||
if let Some(ep) = &self.endpoint_id {
|
||||
let metric = &Metrics::get().proxy.endpoints_affected_by_errors;
|
||||
let label = metric.with_labels(kind);
|
||||
metric.get_metric(label).measure(ep);
|
||||
}
|
||||
self.error_kind = Some(kind);
|
||||
}
|
||||
|
||||
pub fn set_success(&mut self) {
|
||||
self.success = true;
|
||||
}
|
||||
|
||||
pub fn log_connect(&mut self) {
|
||||
fn log_connect(&mut self) {
|
||||
let outcome = if self.success {
|
||||
ConnectOutcome::Success
|
||||
} else {
|
||||
@@ -256,7 +375,7 @@ impl RequestMonitoring {
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for RequestMonitoring {
|
||||
impl Drop for RequestMonitoringInner {
|
||||
fn drop(&mut self) {
|
||||
if self.sender.is_some() {
|
||||
self.log_connect();
|
||||
|
||||
@@ -23,7 +23,7 @@ use utils::backoff;
|
||||
|
||||
use crate::{config::remote_storage_from_toml, context::LOG_CHAN_DISCONNECT};
|
||||
|
||||
use super::{RequestMonitoring, LOG_CHAN};
|
||||
use super::{RequestMonitoringInner, LOG_CHAN};
|
||||
|
||||
#[derive(clap::Args, Clone, Debug)]
|
||||
pub struct ParquetUploadArgs {
|
||||
@@ -118,8 +118,8 @@ impl<'a> serde::Serialize for Options<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&RequestMonitoring> for RequestData {
|
||||
fn from(value: &RequestMonitoring) -> Self {
|
||||
impl From<&RequestMonitoringInner> for RequestData {
|
||||
fn from(value: &RequestMonitoringInner) -> Self {
|
||||
Self {
|
||||
session_id: value.session_id,
|
||||
peer_addr: value.peer_addr.to_string(),
|
||||
|
||||
@@ -370,6 +370,7 @@ pub struct CancellationRequest {
|
||||
pub kind: CancellationOutcome,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
pub enum Waiting {
|
||||
Cplane,
|
||||
Client,
|
||||
@@ -398,12 +399,6 @@ pub struct LatencyTimer {
|
||||
outcome: ConnectOutcome,
|
||||
}
|
||||
|
||||
pub struct LatencyTimerPause<'a> {
|
||||
timer: &'a mut LatencyTimer,
|
||||
start: time::Instant,
|
||||
waiting_for: Waiting,
|
||||
}
|
||||
|
||||
impl LatencyTimer {
|
||||
pub fn new(protocol: Protocol) -> Self {
|
||||
Self {
|
||||
@@ -417,11 +412,13 @@ impl LatencyTimer {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn pause(&mut self, waiting_for: Waiting) -> LatencyTimerPause<'_> {
|
||||
LatencyTimerPause {
|
||||
timer: self,
|
||||
start: Instant::now(),
|
||||
waiting_for,
|
||||
pub fn unpause(&mut self, start: Instant, waiting_for: Waiting) {
|
||||
let dur = start.elapsed();
|
||||
match waiting_for {
|
||||
Waiting::Cplane => self.accumulated.cplane += dur,
|
||||
Waiting::Client => self.accumulated.client += dur,
|
||||
Waiting::Compute => self.accumulated.compute += dur,
|
||||
Waiting::RetryTimeout => self.accumulated.retry += dur,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -438,18 +435,6 @@ impl LatencyTimer {
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for LatencyTimerPause<'_> {
|
||||
fn drop(&mut self) {
|
||||
let dur = self.start.elapsed();
|
||||
match self.waiting_for {
|
||||
Waiting::Cplane => self.timer.accumulated.cplane += dur,
|
||||
Waiting::Client => self.timer.accumulated.client += dur,
|
||||
Waiting::Compute => self.timer.accumulated.compute += dur,
|
||||
Waiting::RetryTimeout => self.timer.accumulated.retry += dur,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(FixedCardinalityLabel, Clone, Copy, Debug)]
|
||||
pub enum ConnectOutcome {
|
||||
Success,
|
||||
|
||||
@@ -113,18 +113,18 @@ pub async fn task_main(
|
||||
}
|
||||
};
|
||||
|
||||
let mut ctx = RequestMonitoring::new(
|
||||
let ctx = RequestMonitoring::new(
|
||||
session_id,
|
||||
peer_addr,
|
||||
crate::metrics::Protocol::Tcp,
|
||||
&config.region,
|
||||
);
|
||||
let span = ctx.span.clone();
|
||||
let span = ctx.span();
|
||||
|
||||
let startup = Box::pin(
|
||||
handle_client(
|
||||
config,
|
||||
&mut ctx,
|
||||
&ctx,
|
||||
cancellation_handler,
|
||||
socket,
|
||||
ClientMode::Tcp,
|
||||
@@ -240,7 +240,7 @@ impl ReportableError for ClientRequestError {
|
||||
|
||||
pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
config: &'static ProxyConfig,
|
||||
ctx: &mut RequestMonitoring,
|
||||
ctx: &RequestMonitoring,
|
||||
cancellation_handler: Arc<CancellationHandlerMain>,
|
||||
stream: S,
|
||||
mode: ClientMode,
|
||||
@@ -248,25 +248,25 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
conn_gauge: NumClientConnectionsGuard<'static>,
|
||||
) -> Result<Option<ProxyPassthrough<CancellationHandlerMainInternal, S>>, ClientRequestError> {
|
||||
info!(
|
||||
protocol = %ctx.protocol,
|
||||
protocol = %ctx.protocol(),
|
||||
"handling interactive connection from client"
|
||||
);
|
||||
|
||||
let metrics = &Metrics::get().proxy;
|
||||
let proto = ctx.protocol;
|
||||
let proto = ctx.protocol();
|
||||
let _request_gauge = metrics.connection_requests.guard(proto);
|
||||
|
||||
let tls = config.tls_config.as_ref();
|
||||
|
||||
let record_handshake_error = !ctx.has_private_peer_addr();
|
||||
let pause = ctx.latency_timer.pause(crate::metrics::Waiting::Client);
|
||||
let do_handshake = handshake(stream, mode.handshake_tls(tls), record_handshake_error);
|
||||
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
|
||||
let do_handshake = handshake(ctx, stream, mode.handshake_tls(tls), record_handshake_error);
|
||||
let (mut stream, params) =
|
||||
match tokio::time::timeout(config.handshake_timeout, do_handshake).await?? {
|
||||
HandshakeData::Startup(stream, params) => (stream, params),
|
||||
HandshakeData::Cancel(cancel_key_data) => {
|
||||
return Ok(cancellation_handler
|
||||
.cancel_session(cancel_key_data, ctx.session_id)
|
||||
.cancel_session(cancel_key_data, ctx.session_id())
|
||||
.await
|
||||
.map(|()| None)?)
|
||||
}
|
||||
|
||||
@@ -46,7 +46,7 @@ pub trait ConnectMechanism {
|
||||
type Error: From<Self::ConnectError>;
|
||||
async fn connect_once(
|
||||
&self,
|
||||
ctx: &mut RequestMonitoring,
|
||||
ctx: &RequestMonitoring,
|
||||
node_info: &console::CachedNodeInfo,
|
||||
timeout: time::Duration,
|
||||
) -> Result<Self::Connection, Self::ConnectError>;
|
||||
@@ -58,7 +58,7 @@ pub trait ConnectMechanism {
|
||||
pub trait ComputeConnectBackend {
|
||||
async fn wake_compute(
|
||||
&self,
|
||||
ctx: &mut RequestMonitoring,
|
||||
ctx: &RequestMonitoring,
|
||||
) -> Result<CachedNodeInfo, console::errors::WakeComputeError>;
|
||||
|
||||
fn get_keys(&self) -> Option<&ComputeCredentialKeys>;
|
||||
@@ -81,7 +81,7 @@ impl ConnectMechanism for TcpMechanism<'_> {
|
||||
#[tracing::instrument(fields(pid = tracing::field::Empty), skip_all)]
|
||||
async fn connect_once(
|
||||
&self,
|
||||
ctx: &mut RequestMonitoring,
|
||||
ctx: &RequestMonitoring,
|
||||
node_info: &console::CachedNodeInfo,
|
||||
timeout: time::Duration,
|
||||
) -> Result<PostgresConnection, Self::Error> {
|
||||
@@ -98,7 +98,7 @@ impl ConnectMechanism for TcpMechanism<'_> {
|
||||
/// Try to connect to the compute node, retrying if necessary.
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub async fn connect_to_compute<M: ConnectMechanism, B: ComputeConnectBackend>(
|
||||
ctx: &mut RequestMonitoring,
|
||||
ctx: &RequestMonitoring,
|
||||
mechanism: &M,
|
||||
user_info: &B,
|
||||
allow_self_signed_compute: bool,
|
||||
@@ -126,7 +126,7 @@ where
|
||||
.await
|
||||
{
|
||||
Ok(res) => {
|
||||
ctx.latency_timer.success();
|
||||
ctx.success();
|
||||
Metrics::get().proxy.retries_metric.observe(
|
||||
RetriesMetricGroup {
|
||||
outcome: ConnectOutcome::Success,
|
||||
@@ -178,7 +178,7 @@ where
|
||||
.await
|
||||
{
|
||||
Ok(res) => {
|
||||
ctx.latency_timer.success();
|
||||
ctx.success();
|
||||
Metrics::get().proxy.retries_metric.observe(
|
||||
RetriesMetricGroup {
|
||||
outcome: ConnectOutcome::Success,
|
||||
@@ -209,9 +209,7 @@ where
|
||||
let wait_duration = retry_after(num_retries, connect_to_compute_retry_config);
|
||||
num_retries += 1;
|
||||
|
||||
let pause = ctx
|
||||
.latency_timer
|
||||
.pause(crate::metrics::Waiting::RetryTimeout);
|
||||
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::RetryTimeout);
|
||||
time::sleep(wait_duration).await;
|
||||
drop(pause);
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ use tracing::{info, warn};
|
||||
use crate::{
|
||||
auth::endpoint_sni,
|
||||
config::{TlsConfig, PG_ALPN_PROTOCOL},
|
||||
context::RequestMonitoring,
|
||||
error::ReportableError,
|
||||
metrics::Metrics,
|
||||
proxy::ERR_INSECURE_CONNECTION,
|
||||
@@ -67,6 +68,7 @@ pub enum HandshakeData<S> {
|
||||
/// we also take an extra care of propagating only the select handshake errors to client.
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
ctx: &RequestMonitoring,
|
||||
stream: S,
|
||||
mut tls: Option<&TlsConfig>,
|
||||
record_handshake_error: bool,
|
||||
@@ -80,8 +82,6 @@ pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
let mut stream = PqStream::new(Stream::from_raw(stream));
|
||||
loop {
|
||||
let msg = stream.read_startup_packet().await?;
|
||||
info!("received {msg:?}");
|
||||
|
||||
use FeStartupPacket::*;
|
||||
match msg {
|
||||
SslRequest { direct } => match stream.get_ref() {
|
||||
@@ -145,16 +145,20 @@ pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
|
||||
let conn_info = tls_stream.get_ref().1;
|
||||
|
||||
// try parse endpoint
|
||||
let ep = conn_info
|
||||
.server_name()
|
||||
.and_then(|sni| endpoint_sni(sni, &tls.common_names).ok().flatten());
|
||||
if let Some(ep) = ep {
|
||||
ctx.set_endpoint_id(ep);
|
||||
}
|
||||
|
||||
// check the ALPN, if exists, as required.
|
||||
match conn_info.alpn_protocol() {
|
||||
None | Some(PG_ALPN_PROTOCOL) => {}
|
||||
Some(other) => {
|
||||
// try parse ep for better error
|
||||
let ep = conn_info.server_name().and_then(|sni| {
|
||||
endpoint_sni(sni, &tls.common_names).ok().flatten()
|
||||
});
|
||||
let alpn = String::from_utf8_lossy(other);
|
||||
warn!(?ep, %alpn, "unexpected ALPN");
|
||||
warn!(%alpn, "unexpected ALPN");
|
||||
return Err(HandshakeError::ProtocolViolation);
|
||||
}
|
||||
}
|
||||
@@ -198,7 +202,12 @@ pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
.await?;
|
||||
}
|
||||
|
||||
info!(?version, session_type = "normal", "successful handshake");
|
||||
info!(
|
||||
?version,
|
||||
?params,
|
||||
session_type = "normal",
|
||||
"successful handshake"
|
||||
);
|
||||
break Ok(HandshakeData::Startup(stream, params));
|
||||
}
|
||||
// downgrade protocol version
|
||||
|
||||
@@ -155,7 +155,7 @@ impl TestAuth for Scram {
|
||||
stream: &mut PqStream<Stream<S>>,
|
||||
) -> anyhow::Result<()> {
|
||||
let outcome = auth::AuthFlow::new(stream)
|
||||
.begin(auth::Scram(&self.0, &mut RequestMonitoring::test()))
|
||||
.begin(auth::Scram(&self.0, &RequestMonitoring::test()))
|
||||
.await?
|
||||
.authenticate()
|
||||
.await?;
|
||||
@@ -175,10 +175,11 @@ async fn dummy_proxy(
|
||||
auth: impl TestAuth + Send,
|
||||
) -> anyhow::Result<()> {
|
||||
let (client, _) = read_proxy_protocol(client).await?;
|
||||
let mut stream = match handshake(client, tls.as_ref(), false).await? {
|
||||
HandshakeData::Startup(stream, _) => stream,
|
||||
HandshakeData::Cancel(_) => bail!("cancellation not supported"),
|
||||
};
|
||||
let mut stream =
|
||||
match handshake(&RequestMonitoring::test(), client, tls.as_ref(), false).await? {
|
||||
HandshakeData::Startup(stream, _) => stream,
|
||||
HandshakeData::Cancel(_) => bail!("cancellation not supported"),
|
||||
};
|
||||
|
||||
auth.authenticate(&mut stream).await?;
|
||||
|
||||
@@ -457,7 +458,7 @@ impl ConnectMechanism for TestConnectMechanism {
|
||||
|
||||
async fn connect_once(
|
||||
&self,
|
||||
_ctx: &mut RequestMonitoring,
|
||||
_ctx: &RequestMonitoring,
|
||||
_node_info: &console::CachedNodeInfo,
|
||||
_timeout: std::time::Duration,
|
||||
) -> Result<Self::Connection, Self::ConnectError> {
|
||||
@@ -565,7 +566,7 @@ fn helper_create_connect_info(
|
||||
async fn connect_to_compute_success() {
|
||||
let _ = env_logger::try_init();
|
||||
use ConnectAction::*;
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let ctx = RequestMonitoring::test();
|
||||
let mechanism = TestConnectMechanism::new(vec![Wake, Connect]);
|
||||
let user_info = helper_create_connect_info(&mechanism);
|
||||
let config = RetryConfig {
|
||||
@@ -573,7 +574,7 @@ async fn connect_to_compute_success() {
|
||||
max_retries: 5,
|
||||
backoff_factor: 2.0,
|
||||
};
|
||||
connect_to_compute(&mut ctx, &mechanism, &user_info, false, config, config)
|
||||
connect_to_compute(&ctx, &mechanism, &user_info, false, config, config)
|
||||
.await
|
||||
.unwrap();
|
||||
mechanism.verify();
|
||||
@@ -583,7 +584,7 @@ async fn connect_to_compute_success() {
|
||||
async fn connect_to_compute_retry() {
|
||||
let _ = env_logger::try_init();
|
||||
use ConnectAction::*;
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let ctx = RequestMonitoring::test();
|
||||
let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Connect]);
|
||||
let user_info = helper_create_connect_info(&mechanism);
|
||||
let config = RetryConfig {
|
||||
@@ -591,7 +592,7 @@ async fn connect_to_compute_retry() {
|
||||
max_retries: 5,
|
||||
backoff_factor: 2.0,
|
||||
};
|
||||
connect_to_compute(&mut ctx, &mechanism, &user_info, false, config, config)
|
||||
connect_to_compute(&ctx, &mechanism, &user_info, false, config, config)
|
||||
.await
|
||||
.unwrap();
|
||||
mechanism.verify();
|
||||
@@ -602,7 +603,7 @@ async fn connect_to_compute_retry() {
|
||||
async fn connect_to_compute_non_retry_1() {
|
||||
let _ = env_logger::try_init();
|
||||
use ConnectAction::*;
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let ctx = RequestMonitoring::test();
|
||||
let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Fail]);
|
||||
let user_info = helper_create_connect_info(&mechanism);
|
||||
let config = RetryConfig {
|
||||
@@ -610,7 +611,7 @@ async fn connect_to_compute_non_retry_1() {
|
||||
max_retries: 5,
|
||||
backoff_factor: 2.0,
|
||||
};
|
||||
connect_to_compute(&mut ctx, &mechanism, &user_info, false, config, config)
|
||||
connect_to_compute(&ctx, &mechanism, &user_info, false, config, config)
|
||||
.await
|
||||
.unwrap_err();
|
||||
mechanism.verify();
|
||||
@@ -621,7 +622,7 @@ async fn connect_to_compute_non_retry_1() {
|
||||
async fn connect_to_compute_non_retry_2() {
|
||||
let _ = env_logger::try_init();
|
||||
use ConnectAction::*;
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let ctx = RequestMonitoring::test();
|
||||
let mechanism = TestConnectMechanism::new(vec![Wake, Fail, Wake, Connect]);
|
||||
let user_info = helper_create_connect_info(&mechanism);
|
||||
let config = RetryConfig {
|
||||
@@ -629,7 +630,7 @@ async fn connect_to_compute_non_retry_2() {
|
||||
max_retries: 5,
|
||||
backoff_factor: 2.0,
|
||||
};
|
||||
connect_to_compute(&mut ctx, &mechanism, &user_info, false, config, config)
|
||||
connect_to_compute(&ctx, &mechanism, &user_info, false, config, config)
|
||||
.await
|
||||
.unwrap();
|
||||
mechanism.verify();
|
||||
@@ -641,7 +642,7 @@ async fn connect_to_compute_non_retry_3() {
|
||||
let _ = env_logger::try_init();
|
||||
tokio::time::pause();
|
||||
use ConnectAction::*;
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let ctx = RequestMonitoring::test();
|
||||
let mechanism =
|
||||
TestConnectMechanism::new(vec![Wake, Retry, Wake, Retry, Retry, Retry, Retry, Retry]);
|
||||
let user_info = helper_create_connect_info(&mechanism);
|
||||
@@ -656,7 +657,7 @@ async fn connect_to_compute_non_retry_3() {
|
||||
backoff_factor: 2.0,
|
||||
};
|
||||
connect_to_compute(
|
||||
&mut ctx,
|
||||
&ctx,
|
||||
&mechanism,
|
||||
&user_info,
|
||||
false,
|
||||
@@ -673,7 +674,7 @@ async fn connect_to_compute_non_retry_3() {
|
||||
async fn wake_retry() {
|
||||
let _ = env_logger::try_init();
|
||||
use ConnectAction::*;
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let ctx = RequestMonitoring::test();
|
||||
let mechanism = TestConnectMechanism::new(vec![WakeRetry, Wake, Connect]);
|
||||
let user_info = helper_create_connect_info(&mechanism);
|
||||
let config = RetryConfig {
|
||||
@@ -681,7 +682,7 @@ async fn wake_retry() {
|
||||
max_retries: 5,
|
||||
backoff_factor: 2.0,
|
||||
};
|
||||
connect_to_compute(&mut ctx, &mechanism, &user_info, false, config, config)
|
||||
connect_to_compute(&ctx, &mechanism, &user_info, false, config, config)
|
||||
.await
|
||||
.unwrap();
|
||||
mechanism.verify();
|
||||
@@ -692,7 +693,7 @@ async fn wake_retry() {
|
||||
async fn wake_non_retry() {
|
||||
let _ = env_logger::try_init();
|
||||
use ConnectAction::*;
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let ctx = RequestMonitoring::test();
|
||||
let mechanism = TestConnectMechanism::new(vec![WakeRetry, WakeFail]);
|
||||
let user_info = helper_create_connect_info(&mechanism);
|
||||
let config = RetryConfig {
|
||||
@@ -700,7 +701,7 @@ async fn wake_non_retry() {
|
||||
max_retries: 5,
|
||||
backoff_factor: 2.0,
|
||||
};
|
||||
connect_to_compute(&mut ctx, &mechanism, &user_info, false, config, config)
|
||||
connect_to_compute(&ctx, &mechanism, &user_info, false, config, config)
|
||||
.await
|
||||
.unwrap_err();
|
||||
mechanism.verify();
|
||||
|
||||
@@ -34,9 +34,14 @@ async fn proxy_mitm(
|
||||
tokio::spawn(async move {
|
||||
// begin handshake with end_server
|
||||
let end_server = connect_tls(server2, client_config2.make_tls_connect().unwrap()).await;
|
||||
let (end_client, startup) = match handshake(client1, Some(&server_config1), false)
|
||||
.await
|
||||
.unwrap()
|
||||
let (end_client, startup) = match handshake(
|
||||
&RequestMonitoring::test(),
|
||||
client1,
|
||||
Some(&server_config1),
|
||||
false,
|
||||
)
|
||||
.await
|
||||
.unwrap()
|
||||
{
|
||||
HandshakeData::Startup(stream, params) => (stream, params),
|
||||
HandshakeData::Cancel(_) => panic!("cancellation not supported"),
|
||||
|
||||
@@ -14,7 +14,7 @@ use super::connect_compute::ComputeConnectBackend;
|
||||
|
||||
pub async fn wake_compute<B: ComputeConnectBackend>(
|
||||
num_retries: &mut u32,
|
||||
ctx: &mut RequestMonitoring,
|
||||
ctx: &RequestMonitoring,
|
||||
api: &B,
|
||||
config: RetryConfig,
|
||||
) -> Result<CachedNodeInfo, WakeComputeError> {
|
||||
@@ -52,9 +52,7 @@ pub async fn wake_compute<B: ComputeConnectBackend>(
|
||||
|
||||
let wait_duration = retry_after(*num_retries, config);
|
||||
*num_retries += 1;
|
||||
let pause = ctx
|
||||
.latency_timer
|
||||
.pause(crate::metrics::Waiting::RetryTimeout);
|
||||
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::RetryTimeout);
|
||||
tokio::time::sleep(wait_duration).await;
|
||||
drop(pause);
|
||||
}
|
||||
|
||||
@@ -334,7 +334,7 @@ async fn request_handler(
|
||||
&config.region,
|
||||
);
|
||||
|
||||
let span = ctx.span.clone();
|
||||
let span = ctx.span();
|
||||
info!(parent: &span, "performing websocket upgrade");
|
||||
|
||||
let (response, websocket) = framed_websockets::upgrade::upgrade(&mut request)
|
||||
@@ -367,7 +367,7 @@ async fn request_handler(
|
||||
crate::metrics::Protocol::Http,
|
||||
&config.region,
|
||||
);
|
||||
let span = ctx.span.clone();
|
||||
let span = ctx.span();
|
||||
|
||||
sql_over_http::handle(config, ctx, request, backend, http_cancellation_token)
|
||||
.instrument(span)
|
||||
|
||||
@@ -35,15 +35,15 @@ pub struct PoolingBackend {
|
||||
impl PoolingBackend {
|
||||
pub async fn authenticate(
|
||||
&self,
|
||||
ctx: &mut RequestMonitoring,
|
||||
ctx: &RequestMonitoring,
|
||||
config: &AuthenticationConfig,
|
||||
conn_info: &ConnInfo,
|
||||
) -> Result<ComputeCredentials, AuthError> {
|
||||
let user_info = conn_info.user_info.clone();
|
||||
let backend = self.config.auth_backend.as_ref().map(|_| user_info.clone());
|
||||
let (allowed_ips, maybe_secret) = backend.get_allowed_ips_and_secret(ctx).await?;
|
||||
if !check_peer_addr_is_in_list(&ctx.peer_addr, &allowed_ips) {
|
||||
return Err(AuthError::ip_address_not_allowed(ctx.peer_addr));
|
||||
if !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips) {
|
||||
return Err(AuthError::ip_address_not_allowed(ctx.peer_addr()));
|
||||
}
|
||||
if !self
|
||||
.endpoint_rate_limiter
|
||||
@@ -100,7 +100,7 @@ impl PoolingBackend {
|
||||
#[tracing::instrument(fields(pid = tracing::field::Empty), skip_all)]
|
||||
pub async fn connect_to_compute(
|
||||
&self,
|
||||
ctx: &mut RequestMonitoring,
|
||||
ctx: &RequestMonitoring,
|
||||
conn_info: ConnInfo,
|
||||
keys: ComputeCredentials,
|
||||
force_new: bool,
|
||||
@@ -222,7 +222,7 @@ impl ConnectMechanism for TokioMechanism {
|
||||
|
||||
async fn connect_once(
|
||||
&self,
|
||||
ctx: &mut RequestMonitoring,
|
||||
ctx: &RequestMonitoring,
|
||||
node_info: &CachedNodeInfo,
|
||||
timeout: Duration,
|
||||
) -> Result<Self::Connection, Self::ConnectError> {
|
||||
@@ -240,7 +240,7 @@ impl ConnectMechanism for TokioMechanism {
|
||||
.param("client_encoding", "UTF8")
|
||||
.expect("client encoding UTF8 is always valid");
|
||||
|
||||
let pause = ctx.latency_timer.pause(crate::metrics::Waiting::Compute);
|
||||
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
|
||||
let res = config.connect(tokio_postgres::NoTls).await;
|
||||
drop(pause);
|
||||
let (client, connection) = permit.release_result(res)?;
|
||||
|
||||
@@ -377,7 +377,7 @@ impl<C: ClientInnerExt> GlobalConnPool<C> {
|
||||
|
||||
pub fn get(
|
||||
self: &Arc<Self>,
|
||||
ctx: &mut RequestMonitoring,
|
||||
ctx: &RequestMonitoring,
|
||||
conn_info: &ConnInfo,
|
||||
) -> Result<Option<Client<C>>, HttpConnError> {
|
||||
let mut client: Option<ClientInner<C>> = None;
|
||||
@@ -409,9 +409,9 @@ impl<C: ClientInnerExt> GlobalConnPool<C> {
|
||||
cold_start_info = ColdStartInfo::HttpPoolHit.as_str(),
|
||||
"pool: reusing connection '{conn_info}'"
|
||||
);
|
||||
client.session.send(ctx.session_id)?;
|
||||
client.session.send(ctx.session_id())?;
|
||||
ctx.set_cold_start_info(ColdStartInfo::HttpPoolHit);
|
||||
ctx.latency_timer.success();
|
||||
ctx.success();
|
||||
return Ok(Some(Client::new(client, conn_info.clone(), endpoint_pool)));
|
||||
}
|
||||
}
|
||||
@@ -465,19 +465,19 @@ impl<C: ClientInnerExt> GlobalConnPool<C> {
|
||||
|
||||
pub fn poll_client<C: ClientInnerExt>(
|
||||
global_pool: Arc<GlobalConnPool<C>>,
|
||||
ctx: &mut RequestMonitoring,
|
||||
ctx: &RequestMonitoring,
|
||||
conn_info: ConnInfo,
|
||||
client: C,
|
||||
mut connection: tokio_postgres::Connection<Socket, NoTlsStream>,
|
||||
conn_id: uuid::Uuid,
|
||||
aux: MetricsAuxInfo,
|
||||
) -> Client<C> {
|
||||
let conn_gauge = Metrics::get().proxy.db_connections.guard(ctx.protocol);
|
||||
let mut session_id = ctx.session_id;
|
||||
let conn_gauge = Metrics::get().proxy.db_connections.guard(ctx.protocol());
|
||||
let mut session_id = ctx.session_id();
|
||||
let (tx, mut rx) = tokio::sync::watch::channel(session_id);
|
||||
|
||||
let span = info_span!(parent: None, "connection", %conn_id);
|
||||
let cold_start_info = ctx.cold_start_info;
|
||||
let cold_start_info = ctx.cold_start_info();
|
||||
span.in_scope(|| {
|
||||
info!(cold_start_info = cold_start_info.as_str(), %conn_info, %session_id, "new connection");
|
||||
});
|
||||
@@ -766,7 +766,6 @@ mod tests {
|
||||
opt_in: false,
|
||||
max_total_conns: 3,
|
||||
},
|
||||
request_timeout: Duration::from_secs(1),
|
||||
cancel_set: CancelSet::new(0),
|
||||
client_conn_threshold: u64::MAX,
|
||||
}));
|
||||
|
||||
@@ -144,7 +144,7 @@ impl UserFacingError for ConnInfoError {
|
||||
}
|
||||
|
||||
fn get_conn_info(
|
||||
ctx: &mut RequestMonitoring,
|
||||
ctx: &RequestMonitoring,
|
||||
headers: &HeaderMap,
|
||||
tls: &TlsConfig,
|
||||
) -> Result<ConnInfo, ConnInfoError> {
|
||||
@@ -224,12 +224,12 @@ fn get_conn_info(
|
||||
// TODO: return different http error codes
|
||||
pub async fn handle(
|
||||
config: &'static ProxyConfig,
|
||||
mut ctx: RequestMonitoring,
|
||||
ctx: RequestMonitoring,
|
||||
request: Request<Incoming>,
|
||||
backend: Arc<PoolingBackend>,
|
||||
cancel: CancellationToken,
|
||||
) -> Result<Response<Full<Bytes>>, ApiError> {
|
||||
let result = handle_inner(cancel, config, &mut ctx, request, backend).await;
|
||||
let result = handle_inner(cancel, config, &ctx, request, backend).await;
|
||||
|
||||
let mut response = match result {
|
||||
Ok(r) => {
|
||||
@@ -482,13 +482,16 @@ fn map_isolation_level_to_headers(level: IsolationLevel) -> Option<HeaderValue>
|
||||
async fn handle_inner(
|
||||
cancel: CancellationToken,
|
||||
config: &'static ProxyConfig,
|
||||
ctx: &mut RequestMonitoring,
|
||||
ctx: &RequestMonitoring,
|
||||
request: Request<Incoming>,
|
||||
backend: Arc<PoolingBackend>,
|
||||
) -> Result<Response<Full<Bytes>>, SqlOverHttpError> {
|
||||
let _requeset_gauge = Metrics::get().proxy.connection_requests.guard(ctx.protocol);
|
||||
let _requeset_gauge = Metrics::get()
|
||||
.proxy
|
||||
.connection_requests
|
||||
.guard(ctx.protocol());
|
||||
info!(
|
||||
protocol = %ctx.protocol,
|
||||
protocol = %ctx.protocol(),
|
||||
"handling interactive connection from client"
|
||||
);
|
||||
|
||||
@@ -544,7 +547,7 @@ async fn handle_inner(
|
||||
.await?;
|
||||
// not strictly necessary to mark success here,
|
||||
// but it's just insurance for if we forget it somewhere else
|
||||
ctx.latency_timer.success();
|
||||
ctx.success();
|
||||
Ok::<_, HttpConnError>(client)
|
||||
}
|
||||
.map_err(SqlOverHttpError::from),
|
||||
|
||||
@@ -129,7 +129,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncBufRead for WebSocketRw<S> {
|
||||
|
||||
pub async fn serve_websocket(
|
||||
config: &'static ProxyConfig,
|
||||
mut ctx: RequestMonitoring,
|
||||
ctx: RequestMonitoring,
|
||||
websocket: OnUpgrade,
|
||||
cancellation_handler: Arc<CancellationHandlerMain>,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
@@ -145,7 +145,7 @@ pub async fn serve_websocket(
|
||||
|
||||
let res = Box::pin(handle_client(
|
||||
config,
|
||||
&mut ctx,
|
||||
&ctx,
|
||||
cancellation_handler,
|
||||
WebSocketRw::new(websocket),
|
||||
ClientMode::Websockets { hostname },
|
||||
|
||||
Reference in New Issue
Block a user