From ad0988f27856f8b80f86f808ad2dd4ec90aadac0 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Wed, 7 Aug 2024 14:37:03 +0100 Subject: [PATCH] proxy: random changes (#8602) ## Problem 1. Hard to correlate startup parameters with the endpoint that provided them. 2. Some configurations are not needed in the `ProxyConfig` struct. ## Summary of changes Because of some borrow checker fun, I needed to switch to an interior-mutability implementation of our `RequestMonitoring` context system. Using https://docs.rs/try-lock/latest/try_lock/ as a cheap lock for such a use-case (needed to be thread safe). Removed the log of each startup message, instead just logging only the startup params in a successful handshake. Also removed some values from `ProxyConfig` and kept them as arguments. (needed for local-proxy config) --- Cargo.lock | 5 +- Cargo.toml | 1 + proxy/Cargo.toml | 1 + proxy/src/auth/backend.rs | 40 ++--- proxy/src/auth/backend/classic.rs | 4 +- proxy/src/auth/backend/hacks.rs | 8 +- proxy/src/auth/backend/link.rs | 2 +- proxy/src/auth/credentials.rs | 60 ++++--- proxy/src/auth/flow.rs | 10 +- proxy/src/bin/pg_sni_router.rs | 6 +- proxy/src/bin/proxy.rs | 20 +-- proxy/src/cache/endpoints.rs | 2 +- proxy/src/compute.rs | 10 +- proxy/src/config.rs | 4 - proxy/src/console/provider.rs | 14 +- proxy/src/console/provider/mock.rs | 6 +- proxy/src/console/provider/neon.rs | 22 +-- proxy/src/context.rs | 241 +++++++++++++++++++------- proxy/src/context/parquet.rs | 6 +- proxy/src/metrics.rs | 31 +--- proxy/src/proxy.rs | 18 +- proxy/src/proxy/connect_compute.rs | 16 +- proxy/src/proxy/handshake.rs | 25 ++- proxy/src/proxy/tests.rs | 41 ++--- proxy/src/proxy/tests/mitm.rs | 11 +- proxy/src/proxy/wake_compute.rs | 6 +- proxy/src/serverless.rs | 4 +- proxy/src/serverless/backend.rs | 12 +- proxy/src/serverless/conn_pool.rs | 15 +- proxy/src/serverless/sql_over_http.rs | 17 +- proxy/src/serverless/websocket.rs | 4 +- 31 files changed, 386 insertions(+), 276 deletions(-) diff 
--git a/Cargo.lock b/Cargo.lock index 764c0fbd30..f565119dbd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4324,6 +4324,7 @@ dependencies = [ "tracing-opentelemetry", "tracing-subscriber", "tracing-utils", + "try-lock", "typed-json", "url", "urlencoding", @@ -6563,9 +6564,9 @@ dependencies = [ [[package]] name = "try-lock" -version = "0.2.4" +version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3528ecfd12c466c6f163363caf2d02a71161dd5e1cc6ae7b34207ea2d42d81ed" +checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" [[package]] name = "tungstenite" diff --git a/Cargo.toml b/Cargo.toml index af1c1dfc82..963841e340 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -184,6 +184,7 @@ tracing = "0.1" tracing-error = "0.2.0" tracing-opentelemetry = "0.21.0" tracing-subscriber = { version = "0.3", default-features = false, features = ["smallvec", "fmt", "tracing-log", "std", "env-filter", "json"] } +try-lock = "0.2.5" twox-hash = { version = "1.6.3", default-features = false } typed-json = "0.1" url = "2.2" diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index 2f18b5fbc6..b316c53034 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -92,6 +92,7 @@ tracing-opentelemetry.workspace = true tracing-subscriber.workspace = true tracing-utils.workspace = true tracing.workspace = true +try-lock.workspace = true typed-json.workspace = true url.workspace = true urlencoding.workspace = true diff --git a/proxy/src/auth/backend.rs b/proxy/src/auth/backend.rs index 67c4dd019e..90dea01bf3 100644 --- a/proxy/src/auth/backend.rs +++ b/proxy/src/auth/backend.rs @@ -218,7 +218,7 @@ impl RateBucketInfo { impl AuthenticationConfig { pub fn check_rate_limit( &self, - ctx: &mut RequestMonitoring, + ctx: &RequestMonitoring, config: &AuthenticationConfig, secret: AuthSecret, endpoint: &EndpointId, @@ -243,7 +243,7 @@ impl AuthenticationConfig { let limit_not_exceeded = self.rate_limiter.check( ( endpoint_int, - 
MaskedIp::new(ctx.peer_addr, config.rate_limit_ip_subnet), + MaskedIp::new(ctx.peer_addr(), config.rate_limit_ip_subnet), ), password_weight, ); @@ -274,7 +274,7 @@ impl AuthenticationConfig { /// /// All authentication flows will emit an AuthenticationOk message if successful. async fn auth_quirks( - ctx: &mut RequestMonitoring, + ctx: &RequestMonitoring, api: &impl console::Api, user_info: ComputeUserInfoMaybeEndpoint, client: &mut stream::PqStream>, @@ -303,8 +303,8 @@ async fn auth_quirks( let (allowed_ips, maybe_secret) = api.get_allowed_ips_and_secret(ctx, &info).await?; // check allowed list - if !check_peer_addr_is_in_list(&ctx.peer_addr, &allowed_ips) { - return Err(auth::AuthError::ip_address_not_allowed(ctx.peer_addr)); + if !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips) { + return Err(auth::AuthError::ip_address_not_allowed(ctx.peer_addr())); } if !endpoint_rate_limiter.check(info.endpoint.clone().into(), 1) { @@ -356,7 +356,7 @@ async fn auth_quirks( } async fn authenticate_with_secret( - ctx: &mut RequestMonitoring, + ctx: &RequestMonitoring, secret: AuthSecret, info: ComputeUserInfo, client: &mut stream::PqStream>, @@ -421,7 +421,7 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint, &()> { #[tracing::instrument(fields(allow_cleartext = allow_cleartext), skip_all)] pub async fn authenticate( self, - ctx: &mut RequestMonitoring, + ctx: &RequestMonitoring, client: &mut stream::PqStream>, allow_cleartext: bool, config: &'static AuthenticationConfig, @@ -467,7 +467,7 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint, &()> { impl BackendType<'_, ComputeUserInfo, &()> { pub async fn get_role_secret( &self, - ctx: &mut RequestMonitoring, + ctx: &RequestMonitoring, ) -> Result { use BackendType::*; match self { @@ -478,7 +478,7 @@ impl BackendType<'_, ComputeUserInfo, &()> { pub async fn get_allowed_ips_and_secret( &self, - ctx: &mut RequestMonitoring, + ctx: &RequestMonitoring, ) -> Result<(CachedAllowedIps, Option), 
GetAuthInfoError> { use BackendType::*; match self { @@ -492,7 +492,7 @@ impl BackendType<'_, ComputeUserInfo, &()> { impl ComputeConnectBackend for BackendType<'_, ComputeCredentials, NodeInfo> { async fn wake_compute( &self, - ctx: &mut RequestMonitoring, + ctx: &RequestMonitoring, ) -> Result { use BackendType::*; @@ -514,7 +514,7 @@ impl ComputeConnectBackend for BackendType<'_, ComputeCredentials, NodeInfo> { impl ComputeConnectBackend for BackendType<'_, ComputeCredentials, &()> { async fn wake_compute( &self, - ctx: &mut RequestMonitoring, + ctx: &RequestMonitoring, ) -> Result { use BackendType::*; @@ -571,7 +571,7 @@ mod tests { impl console::Api for Auth { async fn get_role_secret( &self, - _ctx: &mut RequestMonitoring, + _ctx: &RequestMonitoring, _user_info: &super::ComputeUserInfo, ) -> Result { Ok(CachedRoleSecret::new_uncached(Some(self.secret.clone()))) @@ -579,7 +579,7 @@ mod tests { async fn get_allowed_ips_and_secret( &self, - _ctx: &mut RequestMonitoring, + _ctx: &RequestMonitoring, _user_info: &super::ComputeUserInfo, ) -> Result<(CachedAllowedIps, Option), console::errors::GetAuthInfoError> { @@ -591,7 +591,7 @@ mod tests { async fn wake_compute( &self, - _ctx: &mut RequestMonitoring, + _ctx: &RequestMonitoring, _user_info: &super::ComputeUserInfo, ) -> Result { unimplemented!() @@ -665,7 +665,7 @@ mod tests { let (mut client, server) = tokio::io::duplex(1024); let mut stream = PqStream::new(Stream::from_raw(server)); - let mut ctx = RequestMonitoring::test(); + let ctx = RequestMonitoring::test(); let api = Auth { ips: vec![], secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()), @@ -723,7 +723,7 @@ mod tests { )); let _creds = auth_quirks( - &mut ctx, + &ctx, &api, user_info, &mut stream, @@ -742,7 +742,7 @@ mod tests { let (mut client, server) = tokio::io::duplex(1024); let mut stream = PqStream::new(Stream::from_raw(server)); - let mut ctx = RequestMonitoring::test(); + let ctx = RequestMonitoring::test(); let 
api = Auth { ips: vec![], secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()), @@ -775,7 +775,7 @@ mod tests { )); let _creds = auth_quirks( - &mut ctx, + &ctx, &api, user_info, &mut stream, @@ -794,7 +794,7 @@ mod tests { let (mut client, server) = tokio::io::duplex(1024); let mut stream = PqStream::new(Stream::from_raw(server)); - let mut ctx = RequestMonitoring::test(); + let ctx = RequestMonitoring::test(); let api = Auth { ips: vec![], secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()), @@ -828,7 +828,7 @@ mod tests { )); let creds = auth_quirks( - &mut ctx, + &ctx, &api, user_info, &mut stream, diff --git a/proxy/src/auth/backend/classic.rs b/proxy/src/auth/backend/classic.rs index b98fa63120..285fa29428 100644 --- a/proxy/src/auth/backend/classic.rs +++ b/proxy/src/auth/backend/classic.rs @@ -12,7 +12,7 @@ use tokio::io::{AsyncRead, AsyncWrite}; use tracing::{info, warn}; pub(super) async fn authenticate( - ctx: &mut RequestMonitoring, + ctx: &RequestMonitoring, creds: ComputeUserInfo, client: &mut PqStream>, config: &'static AuthenticationConfig, @@ -27,7 +27,7 @@ pub(super) async fn authenticate( } AuthSecret::Scram(secret) => { info!("auth endpoint chooses SCRAM"); - let scram = auth::Scram(&secret, &mut *ctx); + let scram = auth::Scram(&secret, ctx); let auth_outcome = tokio::time::timeout( config.scram_protocol_timeout, diff --git a/proxy/src/auth/backend/hacks.rs b/proxy/src/auth/backend/hacks.rs index 6b0f5e1726..56921dd949 100644 --- a/proxy/src/auth/backend/hacks.rs +++ b/proxy/src/auth/backend/hacks.rs @@ -18,7 +18,7 @@ use tracing::{info, warn}; /// These properties are benefical for serverless JS workers, so we /// use this mechanism for websocket connections. 
pub async fn authenticate_cleartext( - ctx: &mut RequestMonitoring, + ctx: &RequestMonitoring, info: ComputeUserInfo, client: &mut stream::PqStream>, secret: AuthSecret, @@ -28,7 +28,7 @@ pub async fn authenticate_cleartext( ctx.set_auth_method(crate::context::AuthMethod::Cleartext); // pause the timer while we communicate with the client - let paused = ctx.latency_timer.pause(crate::metrics::Waiting::Client); + let paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client); let ep = EndpointIdInt::from(&info.endpoint); @@ -60,7 +60,7 @@ pub async fn authenticate_cleartext( /// Similar to [`authenticate_cleartext`], but there's a specific password format, /// and passwords are not yet validated (we don't know how to validate them!) pub async fn password_hack_no_authentication( - ctx: &mut RequestMonitoring, + ctx: &RequestMonitoring, info: ComputeUserInfoNoEndpoint, client: &mut stream::PqStream>, ) -> auth::Result { @@ -68,7 +68,7 @@ pub async fn password_hack_no_authentication( ctx.set_auth_method(crate::context::AuthMethod::Cleartext); // pause the timer while we communicate with the client - let _paused = ctx.latency_timer.pause(crate::metrics::Waiting::Client); + let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client); let payload = AuthFlow::new(client) .begin(auth::PasswordHack) diff --git a/proxy/src/auth/backend/link.rs b/proxy/src/auth/backend/link.rs index 5932e1337c..95f4614736 100644 --- a/proxy/src/auth/backend/link.rs +++ b/proxy/src/auth/backend/link.rs @@ -57,7 +57,7 @@ pub fn new_psql_session_id() -> String { } pub(super) async fn authenticate( - ctx: &mut RequestMonitoring, + ctx: &RequestMonitoring, link_uri: &reqwest::Url, client: &mut PqStream, ) -> auth::Result { diff --git a/proxy/src/auth/credentials.rs b/proxy/src/auth/credentials.rs index d06f5614f1..8f4a392131 100644 --- a/proxy/src/auth/credentials.rs +++ b/proxy/src/auth/credentials.rs @@ -84,7 +84,7 @@ pub fn endpoint_sni( impl ComputeUserInfoMaybeEndpoint { pub 
fn parse( - ctx: &mut RequestMonitoring, + ctx: &RequestMonitoring, params: &StartupMessageParams, sni: Option<&str>, common_names: Option<&HashSet>, @@ -249,8 +249,8 @@ mod tests { fn parse_bare_minimum() -> anyhow::Result<()> { // According to postgresql, only `user` should be required. let options = StartupMessageParams::new([("user", "john_doe")]); - let mut ctx = RequestMonitoring::test(); - let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, None, None)?; + let ctx = RequestMonitoring::test(); + let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?; assert_eq!(user_info.user, "john_doe"); assert_eq!(user_info.endpoint_id, None); @@ -264,8 +264,8 @@ mod tests { ("database", "world"), // should be ignored ("foo", "bar"), // should be ignored ]); - let mut ctx = RequestMonitoring::test(); - let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, None, None)?; + let ctx = RequestMonitoring::test(); + let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?; assert_eq!(user_info.user, "john_doe"); assert_eq!(user_info.endpoint_id, None); @@ -279,9 +279,9 @@ mod tests { let sni = Some("foo.localhost"); let common_names = Some(["localhost".into()].into()); - let mut ctx = RequestMonitoring::test(); + let ctx = RequestMonitoring::test(); let user_info = - ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())?; + ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())?; assert_eq!(user_info.user, "john_doe"); assert_eq!(user_info.endpoint_id.as_deref(), Some("foo")); assert_eq!(user_info.options.get_cache_key("foo"), "foo"); @@ -296,8 +296,8 @@ mod tests { ("options", "-ckey=1 project=bar -c geqo=off"), ]); - let mut ctx = RequestMonitoring::test(); - let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, None, None)?; + let ctx = RequestMonitoring::test(); + let user_info = 
ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?; assert_eq!(user_info.user, "john_doe"); assert_eq!(user_info.endpoint_id.as_deref(), Some("bar")); @@ -311,8 +311,8 @@ mod tests { ("options", "-ckey=1 endpoint=bar -c geqo=off"), ]); - let mut ctx = RequestMonitoring::test(); - let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, None, None)?; + let ctx = RequestMonitoring::test(); + let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?; assert_eq!(user_info.user, "john_doe"); assert_eq!(user_info.endpoint_id.as_deref(), Some("bar")); @@ -329,8 +329,8 @@ mod tests { ), ]); - let mut ctx = RequestMonitoring::test(); - let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, None, None)?; + let ctx = RequestMonitoring::test(); + let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?; assert_eq!(user_info.user, "john_doe"); assert!(user_info.endpoint_id.is_none()); @@ -344,8 +344,8 @@ mod tests { ("options", "-ckey=1 endpoint=bar project=foo -c geqo=off"), ]); - let mut ctx = RequestMonitoring::test(); - let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, None, None)?; + let ctx = RequestMonitoring::test(); + let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?; assert_eq!(user_info.user, "john_doe"); assert!(user_info.endpoint_id.is_none()); @@ -359,9 +359,9 @@ mod tests { let sni = Some("baz.localhost"); let common_names = Some(["localhost".into()].into()); - let mut ctx = RequestMonitoring::test(); + let ctx = RequestMonitoring::test(); let user_info = - ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())?; + ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())?; assert_eq!(user_info.user, "john_doe"); assert_eq!(user_info.endpoint_id.as_deref(), Some("baz")); @@ -374,16 +374,16 @@ mod tests { let common_names = Some(["a.com".into(), "b.com".into()].into()); 
let sni = Some("p1.a.com"); - let mut ctx = RequestMonitoring::test(); + let ctx = RequestMonitoring::test(); let user_info = - ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())?; + ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())?; assert_eq!(user_info.endpoint_id.as_deref(), Some("p1")); let common_names = Some(["a.com".into(), "b.com".into()].into()); let sni = Some("p1.b.com"); - let mut ctx = RequestMonitoring::test(); + let ctx = RequestMonitoring::test(); let user_info = - ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())?; + ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())?; assert_eq!(user_info.endpoint_id.as_deref(), Some("p1")); Ok(()) @@ -397,10 +397,9 @@ mod tests { let sni = Some("second.localhost"); let common_names = Some(["localhost".into()].into()); - let mut ctx = RequestMonitoring::test(); - let err = - ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref()) - .expect_err("should fail"); + let ctx = RequestMonitoring::test(); + let err = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref()) + .expect_err("should fail"); match err { InconsistentProjectNames { domain, option } => { assert_eq!(option, "first"); @@ -417,10 +416,9 @@ mod tests { let sni = Some("project.localhost"); let common_names = Some(["example.com".into()].into()); - let mut ctx = RequestMonitoring::test(); - let err = - ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref()) - .expect_err("should fail"); + let ctx = RequestMonitoring::test(); + let err = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref()) + .expect_err("should fail"); match err { UnknownCommonName { cn } => { assert_eq!(cn, "localhost"); @@ -438,9 +436,9 @@ mod tests { let sni = Some("project.localhost"); let common_names = Some(["localhost".into()].into()); - let mut ctx = 
RequestMonitoring::test(); + let ctx = RequestMonitoring::test(); let user_info = - ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())?; + ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())?; assert_eq!(user_info.endpoint_id.as_deref(), Some("project")); assert_eq!( user_info.options.get_cache_key("project"), diff --git a/proxy/src/auth/flow.rs b/proxy/src/auth/flow.rs index 59d1ac17f4..acf7b4f6b6 100644 --- a/proxy/src/auth/flow.rs +++ b/proxy/src/auth/flow.rs @@ -27,7 +27,7 @@ pub trait AuthMethod { pub struct Begin; /// Use [SCRAM](crate::scram)-based auth in [`AuthFlow`]. -pub struct Scram<'a>(pub &'a scram::ServerSecret, pub &'a mut RequestMonitoring); +pub struct Scram<'a>(pub &'a scram::ServerSecret, pub &'a RequestMonitoring); impl AuthMethod for Scram<'_> { #[inline(always)] @@ -155,7 +155,7 @@ impl AuthFlow<'_, S, Scram<'_>> { let Scram(secret, ctx) = self.state; // pause the timer while we communicate with the client - let _paused = ctx.latency_timer.pause(crate::metrics::Waiting::Client); + let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client); // Initial client message contains the chosen auth method's name. 
let msg = self.stream.read_password_message().await?; @@ -168,10 +168,8 @@ impl AuthFlow<'_, S, Scram<'_>> { } match sasl.method { - SCRAM_SHA_256 => ctx.auth_method = Some(crate::context::AuthMethod::ScramSha256), - SCRAM_SHA_256_PLUS => { - ctx.auth_method = Some(crate::context::AuthMethod::ScramSha256Plus) - } + SCRAM_SHA_256 => ctx.set_auth_method(crate::context::AuthMethod::ScramSha256), + SCRAM_SHA_256_PLUS => ctx.set_auth_method(crate::context::AuthMethod::ScramSha256Plus), _ => {} } info!("client chooses {}", sasl.method); diff --git a/proxy/src/bin/pg_sni_router.rs b/proxy/src/bin/pg_sni_router.rs index d7a3eb9a4d..1038fa5116 100644 --- a/proxy/src/bin/pg_sni_router.rs +++ b/proxy/src/bin/pg_sni_router.rs @@ -205,7 +205,7 @@ async fn task_main( const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)"; async fn ssl_handshake( - ctx: &mut RequestMonitoring, + ctx: &RequestMonitoring, raw_stream: S, tls_config: Arc, tls_server_end_point: TlsServerEndPoint, @@ -256,13 +256,13 @@ async fn ssl_handshake( } async fn handle_client( - mut ctx: RequestMonitoring, + ctx: RequestMonitoring, dest_suffix: Arc, tls_config: Arc, tls_server_end_point: TlsServerEndPoint, stream: impl AsyncRead + AsyncWrite + Unpin, ) -> anyhow::Result<()> { - let mut tls_stream = ssl_handshake(&mut ctx, stream, tls_config, tls_server_end_point).await?; + let mut tls_stream = ssl_handshake(&ctx, stream, tls_config, tls_server_end_point).await?; // Cut off first part of the SNI domain // We receive required destination details in the format of diff --git a/proxy/src/bin/proxy.rs b/proxy/src/bin/proxy.rs index c1fd6dfd80..b44e0ddd2f 100644 --- a/proxy/src/bin/proxy.rs +++ b/proxy/src/bin/proxy.rs @@ -5,6 +5,7 @@ use aws_config::meta::region::RegionProviderChain; use aws_config::profile::ProfileFileCredentialsProvider; use aws_config::provider_config::ProviderConfig; use aws_config::web_identity_token::WebIdentityTokenCredentialsProvider; +use 
aws_config::Region; use futures::future::Either; use proxy::auth; use proxy::auth::backend::AuthRateLimiter; @@ -290,9 +291,10 @@ async fn main() -> anyhow::Result<()> { let config = build_config(&args)?; info!("Authentication backend: {}", config.auth_backend); - info!("Using region: {}", config.aws_region); + info!("Using region: {}", args.aws_region); - let region_provider = RegionProviderChain::default_provider().or_else(&*config.aws_region); // Replace with your Redis region if needed + let region_provider = + RegionProviderChain::default_provider().or_else(Region::new(args.aws_region.clone())); let provider_conf = ProviderConfig::without_region().with_region(region_provider.region().await); let aws_credentials_provider = { @@ -318,7 +320,7 @@ async fn main() -> anyhow::Result<()> { }; let elasticache_credentials_provider = Arc::new(elasticache::CredentialsProvider::new( elasticache::AWSIRSAConfig::new( - config.aws_region.clone(), + args.aws_region.clone(), args.redis_cluster_name, args.redis_user_id, ), @@ -376,11 +378,14 @@ async fn main() -> anyhow::Result<()> { let cancel_map = CancelMap::default(); + let redis_rps_limit = Vec::leak(args.redis_rps_limit.clone()); + RateBucketInfo::validate(redis_rps_limit)?; + let redis_publisher = match &regional_redis_client { Some(redis_publisher) => Some(Arc::new(Mutex::new(RedisPublisherClient::new( redis_publisher.clone(), args.region.clone(), - &config.redis_rps_limit, + redis_rps_limit, )?))), None => None, }; @@ -656,7 +661,6 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { )?; let http_config = HttpConfig { - request_timeout: args.sql_over_http.sql_over_http_timeout, pool_options: GlobalConnPoolOptions { max_conns_per_endpoint: args.sql_over_http.sql_over_http_pool_max_conns_per_endpoint, gc_epoch: args.sql_over_http.sql_over_http_pool_gc_epoch, @@ -676,9 +680,6 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { rate_limit_ip_subnet: 
args.auth_rate_limit_ip_subnet, }; - let mut redis_rps_limit = args.redis_rps_limit.clone(); - RateBucketInfo::validate(&mut redis_rps_limit)?; - let config = Box::leak(Box::new(ProxyConfig { tls_config, auth_backend, @@ -687,11 +688,8 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { http_config, authentication_config, require_client_ip: args.require_client_ip, - disable_ip_check_for_http: args.disable_ip_check_for_http, - redis_rps_limit, handshake_timeout: args.handshake_timeout, region: args.region.clone(), - aws_region: args.aws_region.clone(), wake_compute_retry_config: config::RetryConfig::parse(&args.wake_compute_retry)?, connect_compute_locks, connect_to_compute_retry_config: config::RetryConfig::parse( diff --git a/proxy/src/cache/endpoints.rs b/proxy/src/cache/endpoints.rs index 4bc10a6020..8c851790c2 100644 --- a/proxy/src/cache/endpoints.rs +++ b/proxy/src/cache/endpoints.rs @@ -68,7 +68,7 @@ impl EndpointsCache { ready: AtomicBool::new(false), } } - pub async fn is_valid(&self, ctx: &mut RequestMonitoring, endpoint: &EndpointId) -> bool { + pub async fn is_valid(&self, ctx: &RequestMonitoring, endpoint: &EndpointId) -> bool { if !self.ready.load(Ordering::Acquire) { return true; } diff --git a/proxy/src/compute.rs b/proxy/src/compute.rs index f91693c704..21687160ea 100644 --- a/proxy/src/compute.rs +++ b/proxy/src/compute.rs @@ -288,12 +288,12 @@ impl ConnCfg { /// Connect to a corresponding compute node. 
pub async fn connect( &self, - ctx: &mut RequestMonitoring, + ctx: &RequestMonitoring, allow_self_signed_compute: bool, aux: MetricsAuxInfo, timeout: Duration, ) -> Result { - let pause = ctx.latency_timer.pause(crate::metrics::Waiting::Compute); + let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute); let (socket_addr, stream, host) = self.connect_raw(timeout).await?; drop(pause); @@ -316,14 +316,14 @@ impl ConnCfg { )?; // connect_raw() will not use TLS if sslmode is "disable" - let pause = ctx.latency_timer.pause(crate::metrics::Waiting::Compute); + let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute); let (client, connection) = self.0.connect_raw(stream, tls).await?; drop(pause); tracing::Span::current().record("pid", tracing::field::display(client.get_process_id())); let stream = connection.stream.into_inner(); info!( - cold_start_info = ctx.cold_start_info.as_str(), + cold_start_info = ctx.cold_start_info().as_str(), "connected to compute node at {host} ({socket_addr}) sslmode={:?}", self.0.get_ssl_mode() ); @@ -342,7 +342,7 @@ impl ConnCfg { params, cancel_closure, aux, - _guage: Metrics::get().proxy.db_connections.guard(ctx.protocol), + _guage: Metrics::get().proxy.db_connections.guard(ctx.protocol()), }; Ok(connection) diff --git a/proxy/src/config.rs b/proxy/src/config.rs index 6504919760..1412095505 100644 --- a/proxy/src/config.rs +++ b/proxy/src/config.rs @@ -31,11 +31,8 @@ pub struct ProxyConfig { pub http_config: HttpConfig, pub authentication_config: AuthenticationConfig, pub require_client_ip: bool, - pub disable_ip_check_for_http: bool, - pub redis_rps_limit: Vec, pub region: String, pub handshake_timeout: Duration, - pub aws_region: String, pub wake_compute_retry_config: RetryConfig, pub connect_compute_locks: ApiLocks, pub connect_to_compute_retry_config: RetryConfig, @@ -55,7 +52,6 @@ pub struct TlsConfig { } pub struct HttpConfig { - pub request_timeout: tokio::time::Duration, pub pool_options: 
GlobalConnPoolOptions, pub cancel_set: CancelSet, pub client_conn_threshold: u64, diff --git a/proxy/src/console/provider.rs b/proxy/src/console/provider.rs index 7a9637066f..15fc0134b3 100644 --- a/proxy/src/console/provider.rs +++ b/proxy/src/console/provider.rs @@ -292,7 +292,7 @@ pub struct NodeInfo { impl NodeInfo { pub async fn connect( &self, - ctx: &mut RequestMonitoring, + ctx: &RequestMonitoring, timeout: Duration, ) -> Result { self.config @@ -330,20 +330,20 @@ pub(crate) trait Api { /// We still have to mock the scram to avoid leaking information that user doesn't exist. async fn get_role_secret( &self, - ctx: &mut RequestMonitoring, + ctx: &RequestMonitoring, user_info: &ComputeUserInfo, ) -> Result; async fn get_allowed_ips_and_secret( &self, - ctx: &mut RequestMonitoring, + ctx: &RequestMonitoring, user_info: &ComputeUserInfo, ) -> Result<(CachedAllowedIps, Option), errors::GetAuthInfoError>; /// Wake up the compute node and return the corresponding connection info. async fn wake_compute( &self, - ctx: &mut RequestMonitoring, + ctx: &RequestMonitoring, user_info: &ComputeUserInfo, ) -> Result; } @@ -363,7 +363,7 @@ pub enum ConsoleBackend { impl Api for ConsoleBackend { async fn get_role_secret( &self, - ctx: &mut RequestMonitoring, + ctx: &RequestMonitoring, user_info: &ComputeUserInfo, ) -> Result { use ConsoleBackend::*; @@ -378,7 +378,7 @@ impl Api for ConsoleBackend { async fn get_allowed_ips_and_secret( &self, - ctx: &mut RequestMonitoring, + ctx: &RequestMonitoring, user_info: &ComputeUserInfo, ) -> Result<(CachedAllowedIps, Option), errors::GetAuthInfoError> { use ConsoleBackend::*; @@ -393,7 +393,7 @@ impl Api for ConsoleBackend { async fn wake_compute( &self, - ctx: &mut RequestMonitoring, + ctx: &RequestMonitoring, user_info: &ComputeUserInfo, ) -> Result { use ConsoleBackend::*; diff --git a/proxy/src/console/provider/mock.rs b/proxy/src/console/provider/mock.rs index cfe491f2aa..2093da7562 100644 --- a/proxy/src/console/provider/mock.rs 
+++ b/proxy/src/console/provider/mock.rs @@ -158,7 +158,7 @@ impl super::Api for Api { #[tracing::instrument(skip_all)] async fn get_role_secret( &self, - _ctx: &mut RequestMonitoring, + _ctx: &RequestMonitoring, user_info: &ComputeUserInfo, ) -> Result { Ok(CachedRoleSecret::new_uncached( @@ -168,7 +168,7 @@ impl super::Api for Api { async fn get_allowed_ips_and_secret( &self, - _ctx: &mut RequestMonitoring, + _ctx: &RequestMonitoring, user_info: &ComputeUserInfo, ) -> Result<(CachedAllowedIps, Option), GetAuthInfoError> { Ok(( @@ -182,7 +182,7 @@ impl super::Api for Api { #[tracing::instrument(skip_all)] async fn wake_compute( &self, - _ctx: &mut RequestMonitoring, + _ctx: &RequestMonitoring, _user_info: &ComputeUserInfo, ) -> Result { self.do_wake_compute().map_ok(Cached::new_uncached).await diff --git a/proxy/src/console/provider/neon.rs b/proxy/src/console/provider/neon.rs index 768cd2fdfa..7eda238b66 100644 --- a/proxy/src/console/provider/neon.rs +++ b/proxy/src/console/provider/neon.rs @@ -57,7 +57,7 @@ impl Api { async fn do_get_auth_info( &self, - ctx: &mut RequestMonitoring, + ctx: &RequestMonitoring, user_info: &ComputeUserInfo, ) -> Result { if !self @@ -69,7 +69,7 @@ impl Api { info!("endpoint is not valid, skipping the request"); return Ok(AuthInfo::default()); } - let request_id = ctx.session_id.to_string(); + let request_id = ctx.session_id().to_string(); let application_name = ctx.console_application_name(); async { let request = self @@ -77,7 +77,7 @@ impl Api { .get("proxy_get_role_secret") .header("X-Request-ID", &request_id) .header("Authorization", format!("Bearer {}", &self.jwt)) - .query(&[("session_id", ctx.session_id)]) + .query(&[("session_id", ctx.session_id())]) .query(&[ ("application_name", application_name.as_str()), ("project", user_info.endpoint.as_str()), @@ -87,7 +87,7 @@ impl Api { info!(url = request.url().as_str(), "sending http request"); let start = Instant::now(); - let pause = 
ctx.latency_timer.pause(crate::metrics::Waiting::Cplane); + let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Cplane); let response = self.endpoint.execute(request).await?; drop(pause); info!(duration = ?start.elapsed(), "received http response"); @@ -130,10 +130,10 @@ impl Api { async fn do_wake_compute( &self, - ctx: &mut RequestMonitoring, + ctx: &RequestMonitoring, user_info: &ComputeUserInfo, ) -> Result { - let request_id = ctx.session_id.to_string(); + let request_id = ctx.session_id().to_string(); let application_name = ctx.console_application_name(); async { let mut request_builder = self @@ -141,7 +141,7 @@ impl Api { .get("proxy_wake_compute") .header("X-Request-ID", &request_id) .header("Authorization", format!("Bearer {}", &self.jwt)) - .query(&[("session_id", ctx.session_id)]) + .query(&[("session_id", ctx.session_id())]) .query(&[ ("application_name", application_name.as_str()), ("project", user_info.endpoint.as_str()), @@ -156,7 +156,7 @@ impl Api { info!(url = request.url().as_str(), "sending http request"); let start = Instant::now(); - let pause = ctx.latency_timer.pause(crate::metrics::Waiting::Cplane); + let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Cplane); let response = self.endpoint.execute(request).await?; drop(pause); info!(duration = ?start.elapsed(), "received http response"); @@ -192,7 +192,7 @@ impl super::Api for Api { #[tracing::instrument(skip_all)] async fn get_role_secret( &self, - ctx: &mut RequestMonitoring, + ctx: &RequestMonitoring, user_info: &ComputeUserInfo, ) -> Result { let normalized_ep = &user_info.endpoint.normalize(); @@ -226,7 +226,7 @@ impl super::Api for Api { async fn get_allowed_ips_and_secret( &self, - ctx: &mut RequestMonitoring, + ctx: &RequestMonitoring, user_info: &ComputeUserInfo, ) -> Result<(CachedAllowedIps, Option), GetAuthInfoError> { let normalized_ep = &user_info.endpoint.normalize(); @@ -268,7 +268,7 @@ impl super::Api for Api { #[tracing::instrument(skip_all)] async fn 
wake_compute( &self, - ctx: &mut RequestMonitoring, + ctx: &RequestMonitoring, user_info: &ComputeUserInfo, ) -> Result { let key = user_info.endpoint_cache_key(); diff --git a/proxy/src/context.rs b/proxy/src/context.rs index ff79ba8275..e925f67233 100644 --- a/proxy/src/context.rs +++ b/proxy/src/context.rs @@ -7,13 +7,14 @@ use smol_str::SmolStr; use std::net::IpAddr; use tokio::sync::mpsc; use tracing::{field::display, info, info_span, Span}; +use try_lock::TryLock; use uuid::Uuid; use crate::{ console::messages::{ColdStartInfo, MetricsAuxInfo}, error::ErrorKind, intern::{BranchIdInt, ProjectIdInt}, - metrics::{ConnectOutcome, InvalidEndpointsGroup, LatencyTimer, Metrics, Protocol}, + metrics::{ConnectOutcome, InvalidEndpointsGroup, LatencyTimer, Metrics, Protocol, Waiting}, DbName, EndpointId, RoleName, }; @@ -28,7 +29,15 @@ pub static LOG_CHAN_DISCONNECT: OnceCell> /// /// This data should **not** be used for connection logic, only for observability and limiting purposes. /// All connection logic should instead use strongly typed state machines, not a bunch of Options. -pub struct RequestMonitoring { +pub struct RequestMonitoring( + /// To allow easier use of the ctx object, we have interior mutability. + /// I would typically use a RefCell but that would break the `Send` requirements + /// so we need something with thread-safety. `TryLock` is a cheap alternative + /// that offers similar semantics to a `RefCell` but with synchronisation. 
+ TryLock, +); + +struct RequestMonitoringInner { pub peer_addr: IpAddr, pub session_id: Uuid, pub protocol: Protocol, @@ -85,7 +94,7 @@ impl RequestMonitoring { role = tracing::field::Empty, ); - Self { + let inner = RequestMonitoringInner { peer_addr, session_id, protocol, @@ -110,7 +119,9 @@ impl RequestMonitoring { disconnect_sender: LOG_CHAN_DISCONNECT.get().and_then(|tx| tx.upgrade()), latency_timer: LatencyTimer::new(protocol), disconnect_timestamp: None, - } + }; + + Self(TryLock::new(inner)) } #[cfg(test)] @@ -119,48 +130,177 @@ impl RequestMonitoring { } pub fn console_application_name(&self) -> String { + let this = self.0.try_lock().expect("should not deadlock"); format!( "{}/{}", - self.application.as_deref().unwrap_or_default(), - self.protocol + this.application.as_deref().unwrap_or_default(), + this.protocol ) } - pub fn set_rejected(&mut self, rejected: bool) { - self.rejected = Some(rejected); + pub fn set_rejected(&self, rejected: bool) { + let mut this = self.0.try_lock().expect("should not deadlock"); + this.rejected = Some(rejected); } - pub fn set_cold_start_info(&mut self, info: ColdStartInfo) { + pub fn set_cold_start_info(&self, info: ColdStartInfo) { + self.0 + .try_lock() + .expect("should not deadlock") + .set_cold_start_info(info); + } + + pub fn set_db_options(&self, options: StartupMessageParams) { + let mut this = self.0.try_lock().expect("should not deadlock"); + this.set_application(options.get("application_name").map(SmolStr::from)); + if let Some(user) = options.get("user") { + this.set_user(user.into()); + } + if let Some(dbname) = options.get("database") { + this.set_dbname(dbname.into()); + } + + this.pg_options = Some(options); + } + + pub fn set_project(&self, x: MetricsAuxInfo) { + let mut this = self.0.try_lock().expect("should not deadlock"); + if this.endpoint_id.is_none() { + this.set_endpoint_id(x.endpoint_id.as_str().into()) + } + this.branch = Some(x.branch_id); + this.project = Some(x.project_id); + 
this.set_cold_start_info(x.cold_start_info); + } + + pub fn set_project_id(&self, project_id: ProjectIdInt) { + let mut this = self.0.try_lock().expect("should not deadlock"); + this.project = Some(project_id); + } + + pub fn set_endpoint_id(&self, endpoint_id: EndpointId) { + self.0 + .try_lock() + .expect("should not deadlock") + .set_endpoint_id(endpoint_id); + } + + pub fn set_dbname(&self, dbname: DbName) { + self.0 + .try_lock() + .expect("should not deadlock") + .set_dbname(dbname); + } + + pub fn set_user(&self, user: RoleName) { + self.0 + .try_lock() + .expect("should not deadlock") + .set_user(user); + } + + pub fn set_auth_method(&self, auth_method: AuthMethod) { + let mut this = self.0.try_lock().expect("should not deadlock"); + this.auth_method = Some(auth_method); + } + + pub fn has_private_peer_addr(&self) -> bool { + self.0 + .try_lock() + .expect("should not deadlock") + .has_private_peer_addr() + } + + pub fn set_error_kind(&self, kind: ErrorKind) { + let mut this = self.0.try_lock().expect("should not deadlock"); + // Do not record errors from the private address to metrics. 
+ if !this.has_private_peer_addr() { + Metrics::get().proxy.errors_total.inc(kind); + } + if let Some(ep) = &this.endpoint_id { + let metric = &Metrics::get().proxy.endpoints_affected_by_errors; + let label = metric.with_labels(kind); + metric.get_metric(label).measure(ep); + } + this.error_kind = Some(kind); + } + + pub fn set_success(&self) { + let mut this = self.0.try_lock().expect("should not deadlock"); + this.success = true; + } + + pub fn log_connect(&self) { + self.0 + .try_lock() + .expect("should not deadlock") + .log_connect(); + } + + pub fn protocol(&self) -> Protocol { + self.0.try_lock().expect("should not deadlock").protocol + } + + pub fn span(&self) -> Span { + self.0.try_lock().expect("should not deadlock").span.clone() + } + + pub fn session_id(&self) -> Uuid { + self.0.try_lock().expect("should not deadlock").session_id + } + + pub fn peer_addr(&self) -> IpAddr { + self.0.try_lock().expect("should not deadlock").peer_addr + } + + pub fn cold_start_info(&self) -> ColdStartInfo { + self.0 + .try_lock() + .expect("should not deadlock") + .cold_start_info + } + + pub fn latency_timer_pause(&self, waiting_for: Waiting) -> LatencyTimerPause { + LatencyTimerPause { + ctx: self, + start: tokio::time::Instant::now(), + waiting_for, + } + } + + pub fn success(&self) { + self.0 + .try_lock() + .expect("should not deadlock") + .latency_timer + .success() + } +} + +pub struct LatencyTimerPause<'a> { + ctx: &'a RequestMonitoring, + start: tokio::time::Instant, + waiting_for: Waiting, +} + +impl Drop for LatencyTimerPause<'_> { + fn drop(&mut self) { + self.ctx + .0 + .try_lock() + .expect("should not deadlock") + .latency_timer + .unpause(self.start, self.waiting_for); + } +} + +impl RequestMonitoringInner { + fn set_cold_start_info(&mut self, info: ColdStartInfo) { self.cold_start_info = info; self.latency_timer.cold_start_info(info); } - pub fn set_db_options(&mut self, options: StartupMessageParams) { - 
self.set_application(options.get("application_name").map(SmolStr::from)); - if let Some(user) = options.get("user") { - self.set_user(user.into()); - } - if let Some(dbname) = options.get("database") { - self.set_dbname(dbname.into()); - } - - self.pg_options = Some(options); - } - - pub fn set_project(&mut self, x: MetricsAuxInfo) { - if self.endpoint_id.is_none() { - self.set_endpoint_id(x.endpoint_id.as_str().into()) - } - self.branch = Some(x.branch_id); - self.project = Some(x.project_id); - self.set_cold_start_info(x.cold_start_info); - } - - pub fn set_project_id(&mut self, project_id: ProjectIdInt) { - self.project = Some(project_id); - } - - pub fn set_endpoint_id(&mut self, endpoint_id: EndpointId) { + fn set_endpoint_id(&mut self, endpoint_id: EndpointId) { if self.endpoint_id.is_none() { self.span.record("ep", display(&endpoint_id)); let metric = &Metrics::get().proxy.connecting_endpoints; @@ -176,44 +316,23 @@ impl RequestMonitoring { } } - pub fn set_dbname(&mut self, dbname: DbName) { + fn set_dbname(&mut self, dbname: DbName) { self.dbname = Some(dbname); } - pub fn set_user(&mut self, user: RoleName) { + fn set_user(&mut self, user: RoleName) { self.span.record("role", display(&user)); self.user = Some(user); } - pub fn set_auth_method(&mut self, auth_method: AuthMethod) { - self.auth_method = Some(auth_method); - } - - pub fn has_private_peer_addr(&self) -> bool { + fn has_private_peer_addr(&self) -> bool { match self.peer_addr { IpAddr::V4(ip) => ip.is_private(), _ => false, } } - pub fn set_error_kind(&mut self, kind: ErrorKind) { - // Do not record errors from the private address to metrics. 
- if !self.has_private_peer_addr() { - Metrics::get().proxy.errors_total.inc(kind); - } - if let Some(ep) = &self.endpoint_id { - let metric = &Metrics::get().proxy.endpoints_affected_by_errors; - let label = metric.with_labels(kind); - metric.get_metric(label).measure(ep); - } - self.error_kind = Some(kind); - } - - pub fn set_success(&mut self) { - self.success = true; - } - - pub fn log_connect(&mut self) { + fn log_connect(&mut self) { let outcome = if self.success { ConnectOutcome::Success } else { @@ -256,7 +375,7 @@ impl RequestMonitoring { } } -impl Drop for RequestMonitoring { +impl Drop for RequestMonitoringInner { fn drop(&mut self) { if self.sender.is_some() { self.log_connect(); diff --git a/proxy/src/context/parquet.rs b/proxy/src/context/parquet.rs index 543a458274..bb02a476fc 100644 --- a/proxy/src/context/parquet.rs +++ b/proxy/src/context/parquet.rs @@ -23,7 +23,7 @@ use utils::backoff; use crate::{config::remote_storage_from_toml, context::LOG_CHAN_DISCONNECT}; -use super::{RequestMonitoring, LOG_CHAN}; +use super::{RequestMonitoringInner, LOG_CHAN}; #[derive(clap::Args, Clone, Debug)] pub struct ParquetUploadArgs { @@ -118,8 +118,8 @@ impl<'a> serde::Serialize for Options<'a> { } } -impl From<&RequestMonitoring> for RequestData { - fn from(value: &RequestMonitoring) -> Self { +impl From<&RequestMonitoringInner> for RequestData { + fn from(value: &RequestMonitoringInner) -> Self { Self { session_id: value.session_id, peer_addr: value.peer_addr.to_string(), diff --git a/proxy/src/metrics.rs b/proxy/src/metrics.rs index db25ac0311..0167553e30 100644 --- a/proxy/src/metrics.rs +++ b/proxy/src/metrics.rs @@ -370,6 +370,7 @@ pub struct CancellationRequest { pub kind: CancellationOutcome, } +#[derive(Clone, Copy)] pub enum Waiting { Cplane, Client, @@ -398,12 +399,6 @@ pub struct LatencyTimer { outcome: ConnectOutcome, } -pub struct LatencyTimerPause<'a> { - timer: &'a mut LatencyTimer, - start: time::Instant, - waiting_for: Waiting, -} - impl 
LatencyTimer { pub fn new(protocol: Protocol) -> Self { Self { @@ -417,11 +412,13 @@ impl LatencyTimer { } } - pub fn pause(&mut self, waiting_for: Waiting) -> LatencyTimerPause<'_> { - LatencyTimerPause { - timer: self, - start: Instant::now(), - waiting_for, + pub fn unpause(&mut self, start: Instant, waiting_for: Waiting) { + let dur = start.elapsed(); + match waiting_for { + Waiting::Cplane => self.accumulated.cplane += dur, + Waiting::Client => self.accumulated.client += dur, + Waiting::Compute => self.accumulated.compute += dur, + Waiting::RetryTimeout => self.accumulated.retry += dur, } } @@ -438,18 +435,6 @@ impl LatencyTimer { } } -impl Drop for LatencyTimerPause<'_> { - fn drop(&mut self) { - let dur = self.start.elapsed(); - match self.waiting_for { - Waiting::Cplane => self.timer.accumulated.cplane += dur, - Waiting::Client => self.timer.accumulated.client += dur, - Waiting::Compute => self.timer.accumulated.compute += dur, - Waiting::RetryTimeout => self.timer.accumulated.retry += dur, - } - } -} - #[derive(FixedCardinalityLabel, Clone, Copy, Debug)] pub enum ConnectOutcome { Success, diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index 3edefcf21a..2182f38fe7 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -113,18 +113,18 @@ pub async fn task_main( } }; - let mut ctx = RequestMonitoring::new( + let ctx = RequestMonitoring::new( session_id, peer_addr, crate::metrics::Protocol::Tcp, &config.region, ); - let span = ctx.span.clone(); + let span = ctx.span(); let startup = Box::pin( handle_client( config, - &mut ctx, + &ctx, cancellation_handler, socket, ClientMode::Tcp, @@ -240,7 +240,7 @@ impl ReportableError for ClientRequestError { pub async fn handle_client( config: &'static ProxyConfig, - ctx: &mut RequestMonitoring, + ctx: &RequestMonitoring, cancellation_handler: Arc, stream: S, mode: ClientMode, @@ -248,25 +248,25 @@ pub async fn handle_client( conn_gauge: NumClientConnectionsGuard<'static>, ) -> Result>, ClientRequestError> { 
info!( - protocol = %ctx.protocol, + protocol = %ctx.protocol(), "handling interactive connection from client" ); let metrics = &Metrics::get().proxy; - let proto = ctx.protocol; + let proto = ctx.protocol(); let _request_gauge = metrics.connection_requests.guard(proto); let tls = config.tls_config.as_ref(); let record_handshake_error = !ctx.has_private_peer_addr(); - let pause = ctx.latency_timer.pause(crate::metrics::Waiting::Client); - let do_handshake = handshake(stream, mode.handshake_tls(tls), record_handshake_error); + let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client); + let do_handshake = handshake(ctx, stream, mode.handshake_tls(tls), record_handshake_error); let (mut stream, params) = match tokio::time::timeout(config.handshake_timeout, do_handshake).await?? { HandshakeData::Startup(stream, params) => (stream, params), HandshakeData::Cancel(cancel_key_data) => { return Ok(cancellation_handler - .cancel_session(cancel_key_data, ctx.session_id) + .cancel_session(cancel_key_data, ctx.session_id()) .await .map(|()| None)?) 
} diff --git a/proxy/src/proxy/connect_compute.rs b/proxy/src/proxy/connect_compute.rs index 82180aaee3..f38e43ba5a 100644 --- a/proxy/src/proxy/connect_compute.rs +++ b/proxy/src/proxy/connect_compute.rs @@ -46,7 +46,7 @@ pub trait ConnectMechanism { type Error: From; async fn connect_once( &self, - ctx: &mut RequestMonitoring, + ctx: &RequestMonitoring, node_info: &console::CachedNodeInfo, timeout: time::Duration, ) -> Result; @@ -58,7 +58,7 @@ pub trait ConnectMechanism { pub trait ComputeConnectBackend { async fn wake_compute( &self, - ctx: &mut RequestMonitoring, + ctx: &RequestMonitoring, ) -> Result; fn get_keys(&self) -> Option<&ComputeCredentialKeys>; @@ -81,7 +81,7 @@ impl ConnectMechanism for TcpMechanism<'_> { #[tracing::instrument(fields(pid = tracing::field::Empty), skip_all)] async fn connect_once( &self, - ctx: &mut RequestMonitoring, + ctx: &RequestMonitoring, node_info: &console::CachedNodeInfo, timeout: time::Duration, ) -> Result { @@ -98,7 +98,7 @@ impl ConnectMechanism for TcpMechanism<'_> { /// Try to connect to the compute node, retrying if necessary. 
#[tracing::instrument(skip_all)] pub async fn connect_to_compute( - ctx: &mut RequestMonitoring, + ctx: &RequestMonitoring, mechanism: &M, user_info: &B, allow_self_signed_compute: bool, @@ -126,7 +126,7 @@ where .await { Ok(res) => { - ctx.latency_timer.success(); + ctx.success(); Metrics::get().proxy.retries_metric.observe( RetriesMetricGroup { outcome: ConnectOutcome::Success, @@ -178,7 +178,7 @@ where .await { Ok(res) => { - ctx.latency_timer.success(); + ctx.success(); Metrics::get().proxy.retries_metric.observe( RetriesMetricGroup { outcome: ConnectOutcome::Success, @@ -209,9 +209,7 @@ where let wait_duration = retry_after(num_retries, connect_to_compute_retry_config); num_retries += 1; - let pause = ctx - .latency_timer - .pause(crate::metrics::Waiting::RetryTimeout); + let pause = ctx.latency_timer_pause(crate::metrics::Waiting::RetryTimeout); time::sleep(wait_duration).await; drop(pause); } diff --git a/proxy/src/proxy/handshake.rs b/proxy/src/proxy/handshake.rs index d488aea927..c65a5558d9 100644 --- a/proxy/src/proxy/handshake.rs +++ b/proxy/src/proxy/handshake.rs @@ -10,6 +10,7 @@ use tracing::{info, warn}; use crate::{ auth::endpoint_sni, config::{TlsConfig, PG_ALPN_PROTOCOL}, + context::RequestMonitoring, error::ReportableError, metrics::Metrics, proxy::ERR_INSECURE_CONNECTION, @@ -67,6 +68,7 @@ pub enum HandshakeData { /// we also take an extra care of propagating only the select handshake errors to client. 
#[tracing::instrument(skip_all)] pub async fn handshake( + ctx: &RequestMonitoring, stream: S, mut tls: Option<&TlsConfig>, record_handshake_error: bool, @@ -80,8 +82,6 @@ pub async fn handshake( let mut stream = PqStream::new(Stream::from_raw(stream)); loop { let msg = stream.read_startup_packet().await?; - info!("received {msg:?}"); - use FeStartupPacket::*; match msg { SslRequest { direct } => match stream.get_ref() { @@ -145,16 +145,20 @@ pub async fn handshake( let conn_info = tls_stream.get_ref().1; + // try parse endpoint + let ep = conn_info + .server_name() + .and_then(|sni| endpoint_sni(sni, &tls.common_names).ok().flatten()); + if let Some(ep) = ep { + ctx.set_endpoint_id(ep); + } + // check the ALPN, if exists, as required. match conn_info.alpn_protocol() { None | Some(PG_ALPN_PROTOCOL) => {} Some(other) => { - // try parse ep for better error - let ep = conn_info.server_name().and_then(|sni| { - endpoint_sni(sni, &tls.common_names).ok().flatten() - }); let alpn = String::from_utf8_lossy(other); - warn!(?ep, %alpn, "unexpected ALPN"); + warn!(%alpn, "unexpected ALPN"); return Err(HandshakeError::ProtocolViolation); } } @@ -198,7 +202,12 @@ pub async fn handshake( .await?; } - info!(?version, session_type = "normal", "successful handshake"); + info!( + ?version, + ?params, + session_type = "normal", + "successful handshake" + ); break Ok(HandshakeData::Startup(stream, params)); } // downgrade protocol version diff --git a/proxy/src/proxy/tests.rs b/proxy/src/proxy/tests.rs index 5186a9e1b0..d8308c4f2a 100644 --- a/proxy/src/proxy/tests.rs +++ b/proxy/src/proxy/tests.rs @@ -155,7 +155,7 @@ impl TestAuth for Scram { stream: &mut PqStream>, ) -> anyhow::Result<()> { let outcome = auth::AuthFlow::new(stream) - .begin(auth::Scram(&self.0, &mut RequestMonitoring::test())) + .begin(auth::Scram(&self.0, &RequestMonitoring::test())) .await? 
.authenticate() .await?; @@ -175,10 +175,11 @@ async fn dummy_proxy( auth: impl TestAuth + Send, ) -> anyhow::Result<()> { let (client, _) = read_proxy_protocol(client).await?; - let mut stream = match handshake(client, tls.as_ref(), false).await? { - HandshakeData::Startup(stream, _) => stream, - HandshakeData::Cancel(_) => bail!("cancellation not supported"), - }; + let mut stream = + match handshake(&RequestMonitoring::test(), client, tls.as_ref(), false).await? { + HandshakeData::Startup(stream, _) => stream, + HandshakeData::Cancel(_) => bail!("cancellation not supported"), + }; auth.authenticate(&mut stream).await?; @@ -457,7 +458,7 @@ impl ConnectMechanism for TestConnectMechanism { async fn connect_once( &self, - _ctx: &mut RequestMonitoring, + _ctx: &RequestMonitoring, _node_info: &console::CachedNodeInfo, _timeout: std::time::Duration, ) -> Result { @@ -565,7 +566,7 @@ fn helper_create_connect_info( async fn connect_to_compute_success() { let _ = env_logger::try_init(); use ConnectAction::*; - let mut ctx = RequestMonitoring::test(); + let ctx = RequestMonitoring::test(); let mechanism = TestConnectMechanism::new(vec![Wake, Connect]); let user_info = helper_create_connect_info(&mechanism); let config = RetryConfig { @@ -573,7 +574,7 @@ async fn connect_to_compute_success() { max_retries: 5, backoff_factor: 2.0, }; - connect_to_compute(&mut ctx, &mechanism, &user_info, false, config, config) + connect_to_compute(&ctx, &mechanism, &user_info, false, config, config) .await .unwrap(); mechanism.verify(); @@ -583,7 +584,7 @@ async fn connect_to_compute_success() { async fn connect_to_compute_retry() { let _ = env_logger::try_init(); use ConnectAction::*; - let mut ctx = RequestMonitoring::test(); + let ctx = RequestMonitoring::test(); let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Connect]); let user_info = helper_create_connect_info(&mechanism); let config = RetryConfig { @@ -591,7 +592,7 @@ async fn connect_to_compute_retry() { 
max_retries: 5, backoff_factor: 2.0, }; - connect_to_compute(&mut ctx, &mechanism, &user_info, false, config, config) + connect_to_compute(&ctx, &mechanism, &user_info, false, config, config) .await .unwrap(); mechanism.verify(); @@ -602,7 +603,7 @@ async fn connect_to_compute_retry() { async fn connect_to_compute_non_retry_1() { let _ = env_logger::try_init(); use ConnectAction::*; - let mut ctx = RequestMonitoring::test(); + let ctx = RequestMonitoring::test(); let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Fail]); let user_info = helper_create_connect_info(&mechanism); let config = RetryConfig { @@ -610,7 +611,7 @@ async fn connect_to_compute_non_retry_1() { max_retries: 5, backoff_factor: 2.0, }; - connect_to_compute(&mut ctx, &mechanism, &user_info, false, config, config) + connect_to_compute(&ctx, &mechanism, &user_info, false, config, config) .await .unwrap_err(); mechanism.verify(); @@ -621,7 +622,7 @@ async fn connect_to_compute_non_retry_1() { async fn connect_to_compute_non_retry_2() { let _ = env_logger::try_init(); use ConnectAction::*; - let mut ctx = RequestMonitoring::test(); + let ctx = RequestMonitoring::test(); let mechanism = TestConnectMechanism::new(vec![Wake, Fail, Wake, Connect]); let user_info = helper_create_connect_info(&mechanism); let config = RetryConfig { @@ -629,7 +630,7 @@ async fn connect_to_compute_non_retry_2() { max_retries: 5, backoff_factor: 2.0, }; - connect_to_compute(&mut ctx, &mechanism, &user_info, false, config, config) + connect_to_compute(&ctx, &mechanism, &user_info, false, config, config) .await .unwrap(); mechanism.verify(); @@ -641,7 +642,7 @@ async fn connect_to_compute_non_retry_3() { let _ = env_logger::try_init(); tokio::time::pause(); use ConnectAction::*; - let mut ctx = RequestMonitoring::test(); + let ctx = RequestMonitoring::test(); let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Retry, Retry, Retry, Retry, Retry]); let user_info = 
helper_create_connect_info(&mechanism); @@ -656,7 +657,7 @@ async fn connect_to_compute_non_retry_3() { backoff_factor: 2.0, }; connect_to_compute( - &mut ctx, + &ctx, &mechanism, &user_info, false, @@ -673,7 +674,7 @@ async fn connect_to_compute_non_retry_3() { async fn wake_retry() { let _ = env_logger::try_init(); use ConnectAction::*; - let mut ctx = RequestMonitoring::test(); + let ctx = RequestMonitoring::test(); let mechanism = TestConnectMechanism::new(vec![WakeRetry, Wake, Connect]); let user_info = helper_create_connect_info(&mechanism); let config = RetryConfig { @@ -681,7 +682,7 @@ async fn wake_retry() { max_retries: 5, backoff_factor: 2.0, }; - connect_to_compute(&mut ctx, &mechanism, &user_info, false, config, config) + connect_to_compute(&ctx, &mechanism, &user_info, false, config, config) .await .unwrap(); mechanism.verify(); @@ -692,7 +693,7 @@ async fn wake_retry() { async fn wake_non_retry() { let _ = env_logger::try_init(); use ConnectAction::*; - let mut ctx = RequestMonitoring::test(); + let ctx = RequestMonitoring::test(); let mechanism = TestConnectMechanism::new(vec![WakeRetry, WakeFail]); let user_info = helper_create_connect_info(&mechanism); let config = RetryConfig { @@ -700,7 +701,7 @@ async fn wake_non_retry() { max_retries: 5, backoff_factor: 2.0, }; - connect_to_compute(&mut ctx, &mechanism, &user_info, false, config, config) + connect_to_compute(&ctx, &mechanism, &user_info, false, config, config) .await .unwrap_err(); mechanism.verify(); diff --git a/proxy/src/proxy/tests/mitm.rs b/proxy/src/proxy/tests/mitm.rs index d96dd0947b..c8ec2b2db6 100644 --- a/proxy/src/proxy/tests/mitm.rs +++ b/proxy/src/proxy/tests/mitm.rs @@ -34,9 +34,14 @@ async fn proxy_mitm( tokio::spawn(async move { // begin handshake with end_server let end_server = connect_tls(server2, client_config2.make_tls_connect().unwrap()).await; - let (end_client, startup) = match handshake(client1, Some(&server_config1), false) - .await - .unwrap() + let (end_client, 
startup) = match handshake( + &RequestMonitoring::test(), + client1, + Some(&server_config1), + false, + ) + .await + .unwrap() { HandshakeData::Startup(stream, params) => (stream, params), HandshakeData::Cancel(_) => panic!("cancellation not supported"), diff --git a/proxy/src/proxy/wake_compute.rs b/proxy/src/proxy/wake_compute.rs index fef349aac0..5b06e8f054 100644 --- a/proxy/src/proxy/wake_compute.rs +++ b/proxy/src/proxy/wake_compute.rs @@ -14,7 +14,7 @@ use super::connect_compute::ComputeConnectBackend; pub async fn wake_compute( num_retries: &mut u32, - ctx: &mut RequestMonitoring, + ctx: &RequestMonitoring, api: &B, config: RetryConfig, ) -> Result { @@ -52,9 +52,7 @@ pub async fn wake_compute( let wait_duration = retry_after(*num_retries, config); *num_retries += 1; - let pause = ctx - .latency_timer - .pause(crate::metrics::Waiting::RetryTimeout); + let pause = ctx.latency_timer_pause(crate::metrics::Waiting::RetryTimeout); tokio::time::sleep(wait_duration).await; drop(pause); } diff --git a/proxy/src/serverless.rs b/proxy/src/serverless.rs index efa999ed7d..115bef7375 100644 --- a/proxy/src/serverless.rs +++ b/proxy/src/serverless.rs @@ -334,7 +334,7 @@ async fn request_handler( &config.region, ); - let span = ctx.span.clone(); + let span = ctx.span(); info!(parent: &span, "performing websocket upgrade"); let (response, websocket) = framed_websockets::upgrade::upgrade(&mut request) @@ -367,7 +367,7 @@ async fn request_handler( crate::metrics::Protocol::Http, &config.region, ); - let span = ctx.span.clone(); + let span = ctx.span(); sql_over_http::handle(config, ctx, request, backend, http_cancellation_token) .instrument(span) diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index 3b86c1838c..80d46c67eb 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -35,15 +35,15 @@ pub struct PoolingBackend { impl PoolingBackend { pub async fn authenticate( &self, - ctx: &mut RequestMonitoring, + ctx: 
&RequestMonitoring, config: &AuthenticationConfig, conn_info: &ConnInfo, ) -> Result { let user_info = conn_info.user_info.clone(); let backend = self.config.auth_backend.as_ref().map(|_| user_info.clone()); let (allowed_ips, maybe_secret) = backend.get_allowed_ips_and_secret(ctx).await?; - if !check_peer_addr_is_in_list(&ctx.peer_addr, &allowed_ips) { - return Err(AuthError::ip_address_not_allowed(ctx.peer_addr)); + if !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips) { + return Err(AuthError::ip_address_not_allowed(ctx.peer_addr())); } if !self .endpoint_rate_limiter @@ -100,7 +100,7 @@ impl PoolingBackend { #[tracing::instrument(fields(pid = tracing::field::Empty), skip_all)] pub async fn connect_to_compute( &self, - ctx: &mut RequestMonitoring, + ctx: &RequestMonitoring, conn_info: ConnInfo, keys: ComputeCredentials, force_new: bool, @@ -222,7 +222,7 @@ impl ConnectMechanism for TokioMechanism { async fn connect_once( &self, - ctx: &mut RequestMonitoring, + ctx: &RequestMonitoring, node_info: &CachedNodeInfo, timeout: Duration, ) -> Result { @@ -240,7 +240,7 @@ impl ConnectMechanism for TokioMechanism { .param("client_encoding", "UTF8") .expect("client encoding UTF8 is always valid"); - let pause = ctx.latency_timer.pause(crate::metrics::Waiting::Compute); + let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute); let res = config.connect(tokio_postgres::NoTls).await; drop(pause); let (client, connection) = permit.release_result(res)?; diff --git a/proxy/src/serverless/conn_pool.rs b/proxy/src/serverless/conn_pool.rs index dbc58d48ec..e1dc44dc1c 100644 --- a/proxy/src/serverless/conn_pool.rs +++ b/proxy/src/serverless/conn_pool.rs @@ -377,7 +377,7 @@ impl GlobalConnPool { pub fn get( self: &Arc, - ctx: &mut RequestMonitoring, + ctx: &RequestMonitoring, conn_info: &ConnInfo, ) -> Result>, HttpConnError> { let mut client: Option> = None; @@ -409,9 +409,9 @@ impl GlobalConnPool { cold_start_info = ColdStartInfo::HttpPoolHit.as_str(), "pool: 
reusing connection '{conn_info}'" ); - client.session.send(ctx.session_id)?; + client.session.send(ctx.session_id())?; ctx.set_cold_start_info(ColdStartInfo::HttpPoolHit); - ctx.latency_timer.success(); + ctx.success(); return Ok(Some(Client::new(client, conn_info.clone(), endpoint_pool))); } } @@ -465,19 +465,19 @@ impl GlobalConnPool { pub fn poll_client( global_pool: Arc>, - ctx: &mut RequestMonitoring, + ctx: &RequestMonitoring, conn_info: ConnInfo, client: C, mut connection: tokio_postgres::Connection, conn_id: uuid::Uuid, aux: MetricsAuxInfo, ) -> Client { - let conn_gauge = Metrics::get().proxy.db_connections.guard(ctx.protocol); - let mut session_id = ctx.session_id; + let conn_gauge = Metrics::get().proxy.db_connections.guard(ctx.protocol()); + let mut session_id = ctx.session_id(); let (tx, mut rx) = tokio::sync::watch::channel(session_id); let span = info_span!(parent: None, "connection", %conn_id); - let cold_start_info = ctx.cold_start_info; + let cold_start_info = ctx.cold_start_info(); span.in_scope(|| { info!(cold_start_info = cold_start_info.as_str(), %conn_info, %session_id, "new connection"); }); @@ -766,7 +766,6 @@ mod tests { opt_in: false, max_total_conns: 3, }, - request_timeout: Duration::from_secs(1), cancel_set: CancelSet::new(0), client_conn_threshold: u64::MAX, })); diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index 6400e4ac7b..77ec6b1c73 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -144,7 +144,7 @@ impl UserFacingError for ConnInfoError { } fn get_conn_info( - ctx: &mut RequestMonitoring, + ctx: &RequestMonitoring, headers: &HeaderMap, tls: &TlsConfig, ) -> Result { @@ -224,12 +224,12 @@ fn get_conn_info( // TODO: return different http error codes pub async fn handle( config: &'static ProxyConfig, - mut ctx: RequestMonitoring, + ctx: RequestMonitoring, request: Request, backend: Arc, cancel: CancellationToken, ) -> Result>, ApiError> { - 
let result = handle_inner(cancel, config, &mut ctx, request, backend).await; + let result = handle_inner(cancel, config, &ctx, request, backend).await; let mut response = match result { Ok(r) => { @@ -482,13 +482,16 @@ fn map_isolation_level_to_headers(level: IsolationLevel) -> Option async fn handle_inner( cancel: CancellationToken, config: &'static ProxyConfig, - ctx: &mut RequestMonitoring, + ctx: &RequestMonitoring, request: Request, backend: Arc, ) -> Result>, SqlOverHttpError> { - let _requeset_gauge = Metrics::get().proxy.connection_requests.guard(ctx.protocol); + let _requeset_gauge = Metrics::get() + .proxy + .connection_requests + .guard(ctx.protocol()); info!( - protocol = %ctx.protocol, + protocol = %ctx.protocol(), "handling interactive connection from client" ); @@ -544,7 +547,7 @@ async fn handle_inner( .await?; // not strictly necessary to mark success here, // but it's just insurance for if we forget it somewhere else - ctx.latency_timer.success(); + ctx.success(); Ok::<_, HttpConnError>(client) } .map_err(SqlOverHttpError::from), diff --git a/proxy/src/serverless/websocket.rs b/proxy/src/serverless/websocket.rs index 0d5b88f07b..4fba4d141c 100644 --- a/proxy/src/serverless/websocket.rs +++ b/proxy/src/serverless/websocket.rs @@ -129,7 +129,7 @@ impl AsyncBufRead for WebSocketRw { pub async fn serve_websocket( config: &'static ProxyConfig, - mut ctx: RequestMonitoring, + ctx: RequestMonitoring, websocket: OnUpgrade, cancellation_handler: Arc, endpoint_rate_limiter: Arc, @@ -145,7 +145,7 @@ pub async fn serve_websocket( let res = Box::pin(handle_client( config, - &mut ctx, + &ctx, cancellation_handler, WebSocketRw::new(websocket), ClientMode::Websockets { hostname },