diff --git a/libs/proxy/postgres-protocol2/src/authentication/sasl.rs b/libs/proxy/postgres-protocol2/src/authentication/sasl.rs index b8304f9d8d..274c81c500 100644 --- a/libs/proxy/postgres-protocol2/src/authentication/sasl.rs +++ b/libs/proxy/postgres-protocol2/src/authentication/sasl.rs @@ -52,7 +52,7 @@ pub(crate) async fn hi(str: &[u8], salt: &[u8], iterations: u32) -> [u8; 32] { } // yield every ~250us // hopefully reduces tail latencies - if i % 1024 == 0 { + if i.is_multiple_of(1024) { yield_now().await } } diff --git a/libs/proxy/tokio-postgres2/src/client.rs b/libs/proxy/tokio-postgres2/src/client.rs index 41b22e35b6..828884ffd8 100644 --- a/libs/proxy/tokio-postgres2/src/client.rs +++ b/libs/proxy/tokio-postgres2/src/client.rs @@ -90,7 +90,7 @@ pub struct InnerClient { } impl InnerClient { - pub fn start(&mut self) -> Result { + pub fn start(&mut self) -> Result, Error> { self.responses.waiting += 1; Ok(PartialQuery(Some(self))) } @@ -227,7 +227,7 @@ impl Client { &mut self, statement: &str, params: I, - ) -> Result + ) -> Result, Error> where S: AsRef, I: IntoIterator>, @@ -262,7 +262,7 @@ impl Client { pub(crate) async fn simple_query_raw( &mut self, query: &str, - ) -> Result { + ) -> Result, Error> { simple_query::simple_query(self.inner_mut(), query).await } diff --git a/libs/proxy/tokio-postgres2/src/generic_client.rs b/libs/proxy/tokio-postgres2/src/generic_client.rs index eeefb45d26..4c5fc623c5 100644 --- a/libs/proxy/tokio-postgres2/src/generic_client.rs +++ b/libs/proxy/tokio-postgres2/src/generic_client.rs @@ -12,7 +12,11 @@ mod private { /// This trait is "sealed", and cannot be implemented outside of this crate. pub trait GenericClient: private::Sealed { /// Like `Client::query_raw_txt`. - async fn query_raw_txt(&mut self, statement: &str, params: I) -> Result + async fn query_raw_txt( + &mut self, + statement: &str, + params: I, + ) -> Result, Error> where S: AsRef + Sync + Send, I: IntoIterator> + Sync + Send, @@ -22,7 +26,11 @@ pub trait GenericClient: private::Sealed { impl private::Sealed for Client {} impl GenericClient for Client { - async fn query_raw_txt(&mut self, statement: &str, params: I) -> Result + async fn query_raw_txt( + &mut self, + statement: &str, + params: I, + ) -> Result, Error> where S: AsRef + Sync + Send, I: IntoIterator> + Sync + Send, @@ -35,7 +43,11 @@ impl GenericClient for Client { impl private::Sealed for Transaction<'_> {} impl GenericClient for Transaction<'_> { - async fn query_raw_txt(&mut self, statement: &str, params: I) -> Result + async fn query_raw_txt( + &mut self, + statement: &str, + params: I, + ) -> Result, Error> where S: AsRef + Sync + Send, I: IntoIterator> + Sync + Send, diff --git a/libs/proxy/tokio-postgres2/src/transaction.rs b/libs/proxy/tokio-postgres2/src/transaction.rs index 12fe0737d4..0e37d2aad7 100644 --- a/libs/proxy/tokio-postgres2/src/transaction.rs +++ b/libs/proxy/tokio-postgres2/src/transaction.rs @@ -47,7 +47,7 @@ impl<'a> Transaction<'a> { &mut self, statement: &str, params: I, - ) -> Result + ) -> Result, Error> where S: AsRef, I: IntoIterator>, diff --git a/proxy/src/auth/backend/console_redirect.rs b/proxy/src/auth/backend/console_redirect.rs index 8440d198df..f561df9202 100644 --- a/proxy/src/auth/backend/console_redirect.rs +++ b/proxy/src/auth/backend/console_redirect.rs @@ -164,21 +164,20 @@ async fn authenticate( })? 
.map_err(ConsoleRedirectError::from)?; - if auth_config.ip_allowlist_check_enabled { - if let Some(allowed_ips) = &db_info.allowed_ips { - if !auth::check_peer_addr_is_in_list(&ctx.peer_addr(), allowed_ips) { - return Err(auth::AuthError::ip_address_not_allowed(ctx.peer_addr())); - } - } + if auth_config.ip_allowlist_check_enabled + && let Some(allowed_ips) = &db_info.allowed_ips + && !auth::check_peer_addr_is_in_list(&ctx.peer_addr(), allowed_ips) + { + return Err(auth::AuthError::ip_address_not_allowed(ctx.peer_addr())); } // Check if the access over the public internet is allowed, otherwise block. Note that // the console redirect is not behind the VPC service endpoint, so we don't need to check // the VPC endpoint ID. - if let Some(public_access_allowed) = db_info.public_access_allowed { - if !public_access_allowed { - return Err(auth::AuthError::NetworkNotAllowed); - } + if let Some(public_access_allowed) = db_info.public_access_allowed + && !public_access_allowed + { + return Err(auth::AuthError::NetworkNotAllowed); } client.write_message(BeMessage::NoticeResponse("Connecting to database.")); diff --git a/proxy/src/auth/backend/jwt.rs b/proxy/src/auth/backend/jwt.rs index 5edc878243..a716890a00 100644 --- a/proxy/src/auth/backend/jwt.rs +++ b/proxy/src/auth/backend/jwt.rs @@ -399,36 +399,36 @@ impl JwkCacheEntryLock { tracing::debug!(?payload, "JWT signature valid with claims"); - if let Some(aud) = expected_audience { - if payload.audience.0.iter().all(|s| s != aud) { - return Err(JwtError::InvalidClaims( - JwtClaimsError::InvalidJwtTokenAudience, - )); - } + if let Some(aud) = expected_audience + && payload.audience.0.iter().all(|s| s != aud) + { + return Err(JwtError::InvalidClaims( + JwtClaimsError::InvalidJwtTokenAudience, + )); } let now = SystemTime::now(); - if let Some(exp) = payload.expiration { - if now >= exp + CLOCK_SKEW_LEEWAY { - return Err(JwtError::InvalidClaims(JwtClaimsError::JwtTokenHasExpired( - exp.duration_since(SystemTime::UNIX_EPOCH) - .unwrap_or_default() - .as_secs(), - ))); - } + if let Some(exp) = payload.expiration + && now >= exp + CLOCK_SKEW_LEEWAY + { + return Err(JwtError::InvalidClaims(JwtClaimsError::JwtTokenHasExpired( + exp.duration_since(SystemTime::UNIX_EPOCH) + .unwrap_or_default() + .as_secs(), + ))); } - if let Some(nbf) = payload.not_before { - if nbf >= now + CLOCK_SKEW_LEEWAY { - return Err(JwtError::InvalidClaims( - JwtClaimsError::JwtTokenNotYetReadyToUse( - nbf.duration_since(SystemTime::UNIX_EPOCH) - .unwrap_or_default() - .as_secs(), - ), - )); - } + if let Some(nbf) = payload.not_before + && nbf >= now + CLOCK_SKEW_LEEWAY + { + return Err(JwtError::InvalidClaims( + JwtClaimsError::JwtTokenNotYetReadyToUse( + nbf.duration_since(SystemTime::UNIX_EPOCH) + .unwrap_or_default() + .as_secs(), + ), + )); } Ok(ComputeCredentialKeys::JwtPayload(payloadb)) diff --git a/proxy/src/auth/backend/mod.rs b/proxy/src/auth/backend/mod.rs index 8fc3ea1978..e7805d8bfe 100644 --- a/proxy/src/auth/backend/mod.rs +++ b/proxy/src/auth/backend/mod.rs @@ -345,15 +345,13 @@ impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint> { Err(e) => { // The password could have been changed, so we invalidate the cache. // We should only invalidate the cache if the TTL might have expired. 
- if e.is_password_failed() { - #[allow(irrefutable_let_patterns)] - if let ControlPlaneClient::ProxyV1(api) = &*api { - if let Some(ep) = &user_info.endpoint_id { - api.caches - .project_info - .maybe_invalidate_role_secret(ep, &user_info.user); - } - } + if e.is_password_failed() + && let ControlPlaneClient::ProxyV1(api) = &*api + && let Some(ep) = &user_info.endpoint_id + { + api.caches + .project_info + .maybe_invalidate_role_secret(ep, &user_info.user); } Err(e) diff --git a/proxy/src/binary/local_proxy.rs b/proxy/src/binary/local_proxy.rs index 04cc7b3907..401203d48c 100644 --- a/proxy/src/binary/local_proxy.rs +++ b/proxy/src/binary/local_proxy.rs @@ -7,9 +7,7 @@ use anyhow::bail; use arc_swap::ArcSwapOption; use camino::Utf8PathBuf; use clap::Parser; - use futures::future::Either; - use tokio::net::TcpListener; use tokio::sync::Notify; use tokio::task::JoinSet; @@ -22,9 +20,9 @@ use crate::auth::backend::jwt::JwkCache; use crate::auth::backend::local::LocalBackend; use crate::auth::{self}; use crate::cancellation::CancellationHandler; -use crate::config::refresh_config_loop; use crate::config::{ self, AuthenticationConfig, ComputeConfig, HttpConfig, ProxyConfig, RetryConfig, + refresh_config_loop, }; use crate::control_plane::locks::ApiLocks; use crate::http::health_server::AppMetrics; diff --git a/proxy/src/binary/proxy.rs b/proxy/src/binary/proxy.rs index 7522dd5162..c10678dc68 100644 --- a/proxy/src/binary/proxy.rs +++ b/proxy/src/binary/proxy.rs @@ -10,11 +10,15 @@ use std::time::Duration; use anyhow::Context; use anyhow::{bail, ensure}; use arc_swap::ArcSwapOption; +#[cfg(any(test, feature = "testing"))] +use camino::Utf8PathBuf; use futures::future::Either; use itertools::{Itertools, Position}; use rand::{Rng, thread_rng}; use remote_storage::RemoteStorageConfig; use tokio::net::TcpListener; +#[cfg(any(test, feature = "testing"))] +use tokio::sync::Notify; use tokio::task::JoinSet; use tokio_util::sync::CancellationToken; use tracing::{Instrument, error, info, warn}; @@ -47,10 +51,6 @@ use crate::tls::client_config::compute_client_config_with_root_certs; #[cfg(any(test, feature = "testing"))] use crate::url::ApiUrl; use crate::{auth, control_plane, http, serverless, usage_metrics}; -#[cfg(any(test, feature = "testing"))] -use camino::Utf8PathBuf; -#[cfg(any(test, feature = "testing"))] -use tokio::sync::Notify; project_git_version!(GIT_VERSION); project_build_tag!(BUILD_TAG); @@ -520,54 +520,51 @@ pub async fn run() -> anyhow::Result<()> { maintenance_tasks.spawn(usage_metrics::task_main(metrics_config)); } - #[cfg_attr(not(any(test, feature = "testing")), expect(irrefutable_let_patterns))] - if let Either::Left(auth::Backend::ControlPlane(api, ())) = &auth_backend { - if let crate::control_plane::client::ControlPlaneClient::ProxyV1(api) = &**api { - if let Some(client) = redis_client { - // project info cache and invalidation of that cache. - let cache = api.caches.project_info.clone(); - maintenance_tasks.spawn(notifications::task_main(client.clone(), cache.clone())); - maintenance_tasks.spawn(async move { cache.clone().gc_worker().await }); + if let Either::Left(auth::Backend::ControlPlane(api, ())) = &auth_backend + && let crate::control_plane::client::ControlPlaneClient::ProxyV1(api) = &**api + && let Some(client) = redis_client + { + // project info cache and invalidation of that cache. 
+ let cache = api.caches.project_info.clone(); + maintenance_tasks.spawn(notifications::task_main(client.clone(), cache.clone())); + maintenance_tasks.spawn(async move { cache.clone().gc_worker().await }); - // Try to connect to Redis 3 times with 1 + (0..0.1) second interval. - // This prevents immediate exit and pod restart, - // which can cause hammering of the redis in case of connection issues. - // cancellation key management - let mut redis_kv_client = RedisKVClient::new(client.clone()); - for attempt in (0..3).with_position() { - match redis_kv_client.try_connect().await { - Ok(()) => { - info!("Connected to Redis KV client"); - cancellation_handler.init_tx(BatchQueue::new(CancellationProcessor { - client: redis_kv_client, - batch_size: args.cancellation_batch_size, - })); + // Try to connect to Redis 3 times with 1 + (0..0.1) second interval. + // This prevents immediate exit and pod restart, + // which can cause hammering of the redis in case of connection issues. + // cancellation key management + let mut redis_kv_client = RedisKVClient::new(client.clone()); + for attempt in (0..3).with_position() { + match redis_kv_client.try_connect().await { + Ok(()) => { + info!("Connected to Redis KV client"); + cancellation_handler.init_tx(BatchQueue::new(CancellationProcessor { + client: redis_kv_client, + batch_size: args.cancellation_batch_size, + })); - break; - } - Err(e) => { - error!("Failed to connect to Redis KV client: {e}"); - if matches!(attempt, Position::Last(_)) { - bail!( - "Failed to connect to Redis KV client after {} attempts", - attempt.into_inner() - ); - } - let jitter = thread_rng().gen_range(0..100); - tokio::time::sleep(Duration::from_millis(1000 + jitter)).await; - } - } + break; + } + Err(e) => { + error!("Failed to connect to Redis KV client: {e}"); + if matches!(attempt, Position::Last(_)) { + bail!( + "Failed to connect to Redis KV client after {} attempts", + attempt.into_inner() + ); + } + let jitter = thread_rng().gen_range(0..100); + tokio::time::sleep(Duration::from_millis(1000 + jitter)).await; } - - // listen for notifications of new projects/endpoints/branches - let cache = api.caches.endpoints_cache.clone(); - let span = tracing::info_span!("endpoints_cache"); - maintenance_tasks.spawn( - async move { cache.do_read(client, cancellation_token.clone()).await } - .instrument(span), - ); } } + + // listen for notifications of new projects/endpoints/branches + let cache = api.caches.endpoints_cache.clone(); + let span = tracing::info_span!("endpoints_cache"); + maintenance_tasks.spawn( + async move { cache.do_read(client, cancellation_token.clone()).await }.instrument(span), + ); } let maintenance = loop { diff --git a/proxy/src/config.rs b/proxy/src/config.rs index d5e6e1e4cb..f97006e206 100644 --- a/proxy/src/config.rs +++ b/proxy/src/config.rs @@ -4,28 +4,26 @@ use std::time::Duration; use anyhow::{Context, Ok, bail, ensure}; use arc_swap::ArcSwapOption; +use camino::{Utf8Path, Utf8PathBuf}; use clap::ValueEnum; +use compute_api::spec::LocalProxySpec; use remote_storage::RemoteStorageConfig; +use thiserror::Error; +use tokio::sync::Notify; +use tracing::{debug, error, info, warn}; use crate::auth::backend::jwt::JwkCache; +use crate::auth::backend::local::JWKS_ROLE_MAP; use crate::control_plane::locks::ApiLocks; +use crate::control_plane::messages::{EndpointJwksResponse, JwksSettings}; +use crate::ext::TaskExt; +use crate::intern::RoleNameInt; use crate::rate_limiter::{RateBucketInfo, RateLimitAlgorithm, RateLimiterConfig}; use 
crate::scram::threadpool::ThreadPool; use crate::serverless::GlobalConnPoolOptions; use crate::serverless::cancel_set::CancelSet; pub use crate::tls::server_config::{TlsConfig, configure_tls}; -use crate::types::Host; - -use crate::auth::backend::local::JWKS_ROLE_MAP; -use crate::control_plane::messages::{EndpointJwksResponse, JwksSettings}; -use crate::ext::TaskExt; -use crate::intern::RoleNameInt; -use crate::types::RoleName; -use camino::{Utf8Path, Utf8PathBuf}; -use compute_api::spec::LocalProxySpec; -use thiserror::Error; -use tokio::sync::Notify; -use tracing::{debug, error, info, warn}; +use crate::types::{Host, RoleName}; pub struct ProxyConfig { pub tls_config: ArcSwapOption, diff --git a/proxy/src/context/mod.rs b/proxy/src/context/mod.rs index df1c4e194a..7b0549e76f 100644 --- a/proxy/src/context/mod.rs +++ b/proxy/src/context/mod.rs @@ -209,11 +209,9 @@ impl RequestContext { if let Some(options_str) = options.get("options") { // If not found directly, try to extract it from the options string for option in options_str.split_whitespace() { - if option.starts_with("neon_query_id:") { - if let Some(value) = option.strip_prefix("neon_query_id:") { - this.set_testodrome_id(value.into()); - break; - } + if let Some(value) = option.strip_prefix("neon_query_id:") { + this.set_testodrome_id(value.into()); + break; } } } diff --git a/proxy/src/control_plane/client/cplane_proxy_v1.rs b/proxy/src/control_plane/client/cplane_proxy_v1.rs index 8c76d034f7..fbacc97661 100644 --- a/proxy/src/control_plane/client/cplane_proxy_v1.rs +++ b/proxy/src/control_plane/client/cplane_proxy_v1.rs @@ -250,10 +250,8 @@ impl NeonControlPlaneClient { info!(duration = ?start.elapsed(), "received http response"); let body = parse_body::(response.status(), response.bytes().await?)?; - // Unfortunately, ownership won't let us use `Option::ok_or` here. - let (host, port) = match parse_host_port(&body.address) { - None => return Err(WakeComputeError::BadComputeAddress(body.address)), - Some(x) => x, + let Some((host, port)) = parse_host_port(&body.address) else { + return Err(WakeComputeError::BadComputeAddress(body.address)); }; let host_addr = IpAddr::from_str(host).ok(); diff --git a/proxy/src/logging.rs b/proxy/src/logging.rs index 2e444164df..e608300bd2 100644 --- a/proxy/src/logging.rs +++ b/proxy/src/logging.rs @@ -271,18 +271,18 @@ where }); // In case logging fails we generate a simpler JSON object. - if let Err(err) = res { - if let Ok(mut line) = serde_json::to_vec(&serde_json::json!( { + if let Err(err) = res + && let Ok(mut line) = serde_json::to_vec(&serde_json::json!( { "timestamp": now.to_rfc3339_opts(chrono::SecondsFormat::Micros, true), "level": "ERROR", "message": format_args!("cannot log event: {err:?}"), "fields": { "event": format_args!("{event:?}"), }, - })) { - line.push(b'\n'); - self.writer.make_writer().write_all(&line).ok(); - } + })) + { + line.push(b'\n'); + self.writer.make_writer().write_all(&line).ok(); } } @@ -583,10 +583,11 @@ impl EventFormatter { THREAD_ID.with(|tid| serializer.serialize_entry("thread_id", tid))?; // TODO: tls cache? 
name could change - if let Some(thread_name) = std::thread::current().name() { - if !thread_name.is_empty() && thread_name != "tokio-runtime-worker" { - serializer.serialize_entry("thread_name", thread_name)?; - } + if let Some(thread_name) = std::thread::current().name() + && !thread_name.is_empty() + && thread_name != "tokio-runtime-worker" + { + serializer.serialize_entry("thread_name", thread_name)?; } if let Some(task_id) = tokio::task::try_id() { @@ -596,10 +597,10 @@ impl EventFormatter { serializer.serialize_entry("target", meta.target())?; // Skip adding module if it's the same as target. - if let Some(module) = meta.module_path() { - if module != meta.target() { - serializer.serialize_entry("module", module)?; - } + if let Some(module) = meta.module_path() + && module != meta.target() + { + serializer.serialize_entry("module", module)?; } if let Some(file) = meta.file() { diff --git a/proxy/src/metrics.rs b/proxy/src/metrics.rs index 4c340edfd5..7a21e4ecee 100644 --- a/proxy/src/metrics.rs +++ b/proxy/src/metrics.rs @@ -236,13 +236,6 @@ pub enum Bool { False, } -#[derive(FixedCardinalityLabel, Copy, Clone)] -#[label(singleton = "outcome")] -pub enum Outcome { - Success, - Failed, -} - #[derive(FixedCardinalityLabel, Copy, Clone)] #[label(singleton = "outcome")] pub enum CacheOutcome { diff --git a/proxy/src/pglb/copy_bidirectional.rs b/proxy/src/pglb/copy_bidirectional.rs index 97f8d7c6af..5e4262a323 100644 --- a/proxy/src/pglb/copy_bidirectional.rs +++ b/proxy/src/pglb/copy_bidirectional.rs @@ -90,27 +90,27 @@ where // TODO: 1 info log, with a enum label for close direction. // Early termination checks from compute to client. - if let TransferState::Done(_) = compute_to_client { - if let TransferState::Running(buf) = &client_to_compute { - info!("Compute is done, terminate client"); - // Initiate shutdown - client_to_compute = TransferState::ShuttingDown(buf.amt); - client_to_compute_result = - transfer_one_direction(cx, &mut client_to_compute, client, compute) - .map_err(ErrorSource::from_client)?; - } + if let TransferState::Done(_) = compute_to_client + && let TransferState::Running(buf) = &client_to_compute + { + info!("Compute is done, terminate client"); + // Initiate shutdown + client_to_compute = TransferState::ShuttingDown(buf.amt); + client_to_compute_result = + transfer_one_direction(cx, &mut client_to_compute, client, compute) + .map_err(ErrorSource::from_client)?; } // Early termination checks from client to compute. - if let TransferState::Done(_) = client_to_compute { - if let TransferState::Running(buf) = &compute_to_client { - info!("Client is done, terminate compute"); - // Initiate shutdown - compute_to_client = TransferState::ShuttingDown(buf.amt); - compute_to_client_result = - transfer_one_direction(cx, &mut compute_to_client, compute, client) - .map_err(ErrorSource::from_compute)?; - } + if let TransferState::Done(_) = client_to_compute + && let TransferState::Running(buf) = &compute_to_client + { + info!("Client is done, terminate compute"); + // Initiate shutdown + compute_to_client = TransferState::ShuttingDown(buf.amt); + compute_to_client_result = + transfer_one_direction(cx, &mut compute_to_client, compute, client) + .map_err(ErrorSource::from_compute)?; } // It is not a problem if ready! returns early ... 
(comment remains the same) diff --git a/proxy/src/rate_limiter/leaky_bucket.rs b/proxy/src/rate_limiter/leaky_bucket.rs index f7e54ebfe7..12b4bda0c0 100644 --- a/proxy/src/rate_limiter/leaky_bucket.rs +++ b/proxy/src/rate_limiter/leaky_bucket.rs @@ -39,7 +39,11 @@ impl LeakyBucketRateLimiter { let config = config.map_or(self.default_config, Into::into); - if self.access_count.fetch_add(1, Ordering::AcqRel) % 2048 == 0 { + if self + .access_count + .fetch_add(1, Ordering::AcqRel) + .is_multiple_of(2048) + { self.do_gc(now); } diff --git a/proxy/src/rate_limiter/limiter.rs b/proxy/src/rate_limiter/limiter.rs index 2e40f5bf60..61d4636c2b 100644 --- a/proxy/src/rate_limiter/limiter.rs +++ b/proxy/src/rate_limiter/limiter.rs @@ -211,7 +211,11 @@ impl BucketRateLimiter { // worst case memory usage is about: // = 2 * 2048 * 64 * (48B + 72B) // = 30MB - if self.access_count.fetch_add(1, Ordering::AcqRel) % 2048 == 0 { + if self + .access_count + .fetch_add(1, Ordering::AcqRel) + .is_multiple_of(2048) + { self.do_gc(); } diff --git a/proxy/src/redis/cancellation_publisher.rs b/proxy/src/redis/cancellation_publisher.rs deleted file mode 100644 index 6f56aeea06..0000000000 --- a/proxy/src/redis/cancellation_publisher.rs +++ /dev/null @@ -1,79 +0,0 @@ -use core::net::IpAddr; -use std::sync::Arc; - -use tokio::sync::Mutex; -use uuid::Uuid; - -use crate::pqproto::CancelKeyData; - -pub trait CancellationPublisherMut: Send + Sync + 'static { - #[allow(async_fn_in_trait)] - async fn try_publish( - &mut self, - cancel_key_data: CancelKeyData, - session_id: Uuid, - peer_addr: IpAddr, - ) -> anyhow::Result<()>; -} - -pub trait CancellationPublisher: Send + Sync + 'static { - #[allow(async_fn_in_trait)] - async fn try_publish( - &self, - cancel_key_data: CancelKeyData, - session_id: Uuid, - peer_addr: IpAddr, - ) -> anyhow::Result<()>; -} - -impl CancellationPublisher for () { - async fn try_publish( - &self, - _cancel_key_data: CancelKeyData, - _session_id: Uuid, - _peer_addr: IpAddr, - ) -> anyhow::Result<()> { - Ok(()) - } -} - -impl CancellationPublisherMut for P { - async fn try_publish( - &mut self, - cancel_key_data: CancelKeyData, - session_id: Uuid, - peer_addr: IpAddr, - ) -> anyhow::Result<()> { -
<P as CancellationPublisher>::try_publish(self, cancel_key_data, session_id, peer_addr) - .await - } -} - -impl<P: CancellationPublisher> CancellationPublisher for Option<P>
{ - async fn try_publish( - &self, - cancel_key_data: CancelKeyData, - session_id: Uuid, - peer_addr: IpAddr, - ) -> anyhow::Result<()> { - if let Some(p) = self { - p.try_publish(cancel_key_data, session_id, peer_addr).await - } else { - Ok(()) - } - } -} - -impl CancellationPublisher for Arc> { - async fn try_publish( - &self, - cancel_key_data: CancelKeyData, - session_id: Uuid, - peer_addr: IpAddr, - ) -> anyhow::Result<()> { - self.lock() - .await - .try_publish(cancel_key_data, session_id, peer_addr) - .await - } -} diff --git a/proxy/src/redis/connection_with_credentials_provider.rs b/proxy/src/redis/connection_with_credentials_provider.rs index 0465493799..35a3fe4334 100644 --- a/proxy/src/redis/connection_with_credentials_provider.rs +++ b/proxy/src/redis/connection_with_credentials_provider.rs @@ -1,11 +1,11 @@ -use std::convert::Infallible; -use std::sync::{Arc, atomic::AtomicBool, atomic::Ordering}; +use std::sync::Arc; +use std::sync::atomic::{AtomicBool, Ordering}; use std::time::Duration; use futures::FutureExt; use redis::aio::{ConnectionLike, MultiplexedConnection}; use redis::{ConnectionInfo, IntoConnectionInfo, RedisConnectionInfo, RedisResult}; -use tokio::task::JoinHandle; +use tokio::task::AbortHandle; use tracing::{error, info, warn}; use super::elasticache::CredentialsProvider; @@ -32,7 +32,7 @@ pub struct ConnectionWithCredentialsProvider { credentials: Credentials, // TODO: with more load on the connection, we should consider using a connection pool con: Option, - refresh_token_task: Option>, + refresh_token_task: Option, mutex: tokio::sync::Mutex<()>, credentials_refreshed: Arc, } @@ -127,7 +127,7 @@ impl ConnectionWithCredentialsProvider { credentials_provider, credentials_refreshed, )); - self.refresh_token_task = Some(f); + self.refresh_token_task = Some(f.abort_handle()); } match Self::ping(&mut con).await { Ok(()) => { @@ -179,7 +179,7 @@ impl ConnectionWithCredentialsProvider { mut con: MultiplexedConnection, credentials_provider: Arc, credentials_refreshed: Arc, - ) -> Infallible { + ) -> ! { loop { // The connection lives for 12h, for the sanity check we refresh it every hour. 
tokio::time::sleep(Duration::from_secs(60 * 60)).await; @@ -244,7 +244,7 @@ impl ConnectionLike for ConnectionWithCredentialsProvider { &'a mut self, cmd: &'a redis::Cmd, ) -> redis::RedisFuture<'a, redis::Value> { - (async move { self.send_packed_command(cmd).await }).boxed() + self.send_packed_command(cmd).boxed() } fn req_packed_commands<'a>( @@ -253,10 +253,10 @@ impl ConnectionLike for ConnectionWithCredentialsProvider { offset: usize, count: usize, ) -> redis::RedisFuture<'a, Vec> { - (async move { self.send_packed_commands(cmd, offset, count).await }).boxed() + self.send_packed_commands(cmd, offset, count).boxed() } fn get_db(&self) -> i64 { - 0 + self.con.as_ref().map_or(0, |c| c.get_db()) } } diff --git a/proxy/src/redis/mod.rs b/proxy/src/redis/mod.rs index 8b46a8e6ca..4f5e24ab5f 100644 --- a/proxy/src/redis/mod.rs +++ b/proxy/src/redis/mod.rs @@ -1,4 +1,3 @@ -pub mod cancellation_publisher; pub mod connection_with_credentials_provider; pub mod elasticache; pub mod keys; diff --git a/proxy/src/sasl/channel_binding.rs b/proxy/src/sasl/channel_binding.rs index e548cf3a83..fcc262f415 100644 --- a/proxy/src/sasl/channel_binding.rs +++ b/proxy/src/sasl/channel_binding.rs @@ -54,9 +54,7 @@ impl ChannelBinding { "eSws".into() } Self::Required(mode) => { - use std::io::Write; - let mut cbind_input = vec![]; - write!(&mut cbind_input, "p={mode},,",).unwrap(); + let mut cbind_input = format!("p={mode},,",).into_bytes(); cbind_input.extend_from_slice(get_cbind_data(mode)?); BASE64_STANDARD.encode(&cbind_input).into() } diff --git a/proxy/src/scram/exchange.rs b/proxy/src/scram/exchange.rs index 3ba8a79368..a0918fca9f 100644 --- a/proxy/src/scram/exchange.rs +++ b/proxy/src/scram/exchange.rs @@ -107,7 +107,7 @@ pub(crate) async fn exchange( secret: &ServerSecret, password: &[u8], ) -> sasl::Result> { - let salt = BASE64_STANDARD.decode(&secret.salt_base64)?; + let salt = BASE64_STANDARD.decode(&*secret.salt_base64)?; let client_key = derive_client_key(pool, endpoint, password, &salt, secret.iterations).await; if secret.is_password_invalid(&client_key).into() { diff --git a/proxy/src/scram/messages.rs b/proxy/src/scram/messages.rs index 42039f099c..c0073917a1 100644 --- a/proxy/src/scram/messages.rs +++ b/proxy/src/scram/messages.rs @@ -87,13 +87,20 @@ impl<'a> ClientFirstMessage<'a> { salt_base64: &str, iterations: u32, ) -> OwnedServerFirstMessage { - use std::fmt::Write; + let mut message = String::with_capacity(128); + message.push_str("r="); - let mut message = String::new(); - write!(&mut message, "r={}", self.nonce).unwrap(); + // write combined nonce + let combined_nonce_start = message.len(); + message.push_str(self.nonce); BASE64_STANDARD.encode_string(nonce, &mut message); - let combined_nonce = 2..message.len(); - write!(&mut message, ",s={salt_base64},i={iterations}").unwrap(); + let combined_nonce = combined_nonce_start..message.len(); + + // write salt and iterations + message.push_str(",s="); + message.push_str(salt_base64); + message.push_str(",i="); + message.push_str(itoa::Buffer::new().format(iterations)); // This design guarantees that it's impossible to create a // server-first-message without receiving a client-first-message diff --git a/proxy/src/scram/secret.rs b/proxy/src/scram/secret.rs index f03617f34d..0e070c2f27 100644 --- a/proxy/src/scram/secret.rs +++ b/proxy/src/scram/secret.rs @@ -14,7 +14,7 @@ pub(crate) struct ServerSecret { /// Number of iterations for `PBKDF2` function. pub(crate) iterations: u32, /// Salt used to hash user's password. 
- pub(crate) salt_base64: String, + pub(crate) salt_base64: Box, /// Hashed `ClientKey`. pub(crate) stored_key: ScramKey, /// Used by client to verify server's signature. @@ -35,7 +35,7 @@ impl ServerSecret { let secret = ServerSecret { iterations: iterations.parse().ok()?, - salt_base64: salt.to_owned(), + salt_base64: salt.into(), stored_key: base64_decode_array(stored_key)?.into(), server_key: base64_decode_array(server_key)?.into(), doomed: false, @@ -58,7 +58,7 @@ impl ServerSecret { // iteration count 1 for our generated passwords going forward. // PG16 users can set iteration count=1 already today. iterations: 1, - salt_base64: BASE64_STANDARD.encode(nonce), + salt_base64: BASE64_STANDARD.encode(nonce).into_boxed_str(), stored_key: ScramKey::default(), server_key: ScramKey::default(), doomed: true, @@ -88,7 +88,7 @@ mod tests { let parsed = ServerSecret::parse(&secret).unwrap(); assert_eq!(parsed.iterations, iterations); - assert_eq!(parsed.salt_base64, salt); + assert_eq!(&*parsed.salt_base64, salt); assert_eq!(BASE64_STANDARD.encode(parsed.stored_key), stored_key); assert_eq!(BASE64_STANDARD.encode(parsed.server_key), server_key); diff --git a/proxy/src/scram/threadpool.rs b/proxy/src/scram/threadpool.rs index 8f1684c75b..1aa402227f 100644 --- a/proxy/src/scram/threadpool.rs +++ b/proxy/src/scram/threadpool.rs @@ -137,7 +137,7 @@ impl Future for JobSpec { let state = state.as_mut().expect("should be set on thread startup"); state.tick = state.tick.wrapping_add(1); - if state.tick % SKETCH_RESET_INTERVAL == 0 { + if state.tick.is_multiple_of(SKETCH_RESET_INTERVAL) { state.countmin.reset(); } diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index 4b3f379e76..daa6429039 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -349,11 +349,11 @@ impl PoolingBackend { debug!("setting up backend session state"); // initiates the auth session - if !disable_pg_session_jwt { - if let Err(e) = client.batch_execute("select auth.init();").await { - discard.discard(); - return Err(e.into()); - } + if !disable_pg_session_jwt + && let Err(e) = client.batch_execute("select auth.init();").await + { + discard.discard(); + return Err(e.into()); } info!("backend session state initialized"); diff --git a/proxy/src/serverless/conn_pool.rs b/proxy/src/serverless/conn_pool.rs index dd8cf052c5..672e59f81f 100644 --- a/proxy/src/serverless/conn_pool.rs +++ b/proxy/src/serverless/conn_pool.rs @@ -148,11 +148,10 @@ pub(crate) fn poll_client( } // remove from connection pool - if let Some(pool) = pool.clone().upgrade() { - if pool.write().remove_client(db_user.clone(), conn_id) { + if let Some(pool) = pool.clone().upgrade() + && pool.write().remove_client(db_user.clone(), conn_id) { info!("closed connection removed"); } - } Poll::Ready(()) }).await; diff --git a/proxy/src/serverless/http_conn_pool.rs b/proxy/src/serverless/http_conn_pool.rs index 18f7ecc0b1..7acd816026 100644 --- a/proxy/src/serverless/http_conn_pool.rs +++ b/proxy/src/serverless/http_conn_pool.rs @@ -2,6 +2,8 @@ use std::collections::VecDeque; use std::sync::atomic::{self, AtomicUsize}; use std::sync::{Arc, Weak}; +use bytes::Bytes; +use http_body_util::combinators::BoxBody; use hyper::client::conn::http2; use hyper_util::rt::{TokioExecutor, TokioIo}; use parking_lot::RwLock; @@ -20,8 +22,6 @@ use crate::metrics::{HttpEndpointPoolsGuard, Metrics}; use crate::protocol2::ConnectionInfoExtra; use crate::types::EndpointCacheKey; use crate::usage_metrics::{Ids, MetricCounter, 
USAGE_METRICS}; -use bytes::Bytes; -use http_body_util::combinators::BoxBody; pub(crate) type Send = http2::SendRequest>; pub(crate) type Connect = @@ -240,10 +240,10 @@ pub(crate) fn poll_http2_client( } // remove from connection pool - if let Some(pool) = pool.clone().upgrade() { - if pool.write().remove_conn(conn_id) { - info!("closed connection removed"); - } + if let Some(pool) = pool.clone().upgrade() + && pool.write().remove_conn(conn_id) + { + info!("closed connection removed"); } } .instrument(span), diff --git a/proxy/src/serverless/http_util.rs b/proxy/src/serverless/http_util.rs index c876d8f096..0c91ac6835 100644 --- a/proxy/src/serverless/http_util.rs +++ b/proxy/src/serverless/http_util.rs @@ -12,8 +12,7 @@ use serde::Serialize; use url::Url; use uuid::Uuid; -use super::conn_pool::AuthData; -use super::conn_pool::ConnInfoWithAuth; +use super::conn_pool::{AuthData, ConnInfoWithAuth}; use super::conn_pool_lib::ConnInfo; use super::error::{ConnInfoError, Credentials}; use crate::auth::backend::ComputeUserInfo; diff --git a/proxy/src/serverless/local_conn_pool.rs b/proxy/src/serverless/local_conn_pool.rs index c367615fb8..e4cbd02bfe 100644 --- a/proxy/src/serverless/local_conn_pool.rs +++ b/proxy/src/serverless/local_conn_pool.rs @@ -249,11 +249,10 @@ pub(crate) fn poll_client( } // remove from connection pool - if let Some(pool) = pool.clone().upgrade() { - if pool.global_pool.write().remove_client(db_user.clone(), conn_id) { + if let Some(pool) = pool.clone().upgrade() + && pool.global_pool.write().remove_client(db_user.clone(), conn_id) { info!("closed connection removed"); } - } Poll::Ready(()) }).await; diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index a901a47746..7a718d0280 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -1,23 +1,25 @@ +use std::pin::pin; +use std::sync::Arc; + use bytes::Bytes; use futures::future::{Either, select, try_join}; use futures::{StreamExt, TryFutureExt}; -use http::{Method, header::AUTHORIZATION}; -use http_body_util::{BodyExt, Full, combinators::BoxBody}; +use http::Method; +use http::header::AUTHORIZATION; +use http_body_util::combinators::BoxBody; +use http_body_util::{BodyExt, Full}; use http_utils::error::ApiError; use hyper::body::Incoming; -use hyper::{ - Request, Response, StatusCode, header, - http::{HeaderName, HeaderValue}, -}; +use hyper::http::{HeaderName, HeaderValue}; +use hyper::{Request, Response, StatusCode, header}; use indexmap::IndexMap; use postgres_client::error::{DbError, ErrorPosition, SqlState}; use postgres_client::{ GenericClient, IsolationLevel, NoTls, ReadyForQueryStatus, RowStream, Transaction, }; use serde::Serialize; -use serde_json::{Value, value::RawValue}; -use std::pin::pin; -use std::sync::Arc; +use serde_json::Value; +use serde_json::value::RawValue; use tokio::time::{self, Instant}; use tokio_util::sync::CancellationToken; use tracing::{Level, debug, error, info}; @@ -33,7 +35,6 @@ use super::http_util::{ }; use super::json::{JsonConversionError, json_to_pg_text, pg_text_row_to_json}; use crate::auth::backend::ComputeCredentialKeys; - use crate::config::{HttpConfig, ProxyConfig}; use crate::context::RequestContext; use crate::error::{ErrorKind, ReportableError, UserFacingError}; diff --git a/proxy/src/stream.rs b/proxy/src/stream.rs index c49a431c95..4e55654515 100644 --- a/proxy/src/stream.rs +++ b/proxy/src/stream.rs @@ -199,27 +199,27 @@ impl PqStream { let probe_msg; let mut msg = &*msg; - if let Some(ctx) 
= ctx { - if ctx.get_testodrome_id().is_some() { - let tag = match error_kind { - ErrorKind::User => "client", - ErrorKind::ClientDisconnect => "client", - ErrorKind::RateLimit => "proxy", - ErrorKind::ServiceRateLimit => "proxy", - ErrorKind::Quota => "proxy", - ErrorKind::Service => "proxy", - ErrorKind::ControlPlane => "controlplane", - ErrorKind::Postgres => "other", - ErrorKind::Compute => "compute", - }; - probe_msg = typed_json::json!({ - "tag": tag, - "msg": msg, - "cold_start_info": ctx.cold_start_info(), - }) - .to_string(); - msg = &probe_msg; - } + if let Some(ctx) = ctx + && ctx.get_testodrome_id().is_some() + { + let tag = match error_kind { + ErrorKind::User => "client", + ErrorKind::ClientDisconnect => "client", + ErrorKind::RateLimit => "proxy", + ErrorKind::ServiceRateLimit => "proxy", + ErrorKind::Quota => "proxy", + ErrorKind::Service => "proxy", + ErrorKind::ControlPlane => "controlplane", + ErrorKind::Postgres => "other", + ErrorKind::Compute => "compute", + }; + probe_msg = typed_json::json!({ + "tag": tag, + "msg": msg, + "cold_start_info": ctx.cold_start_info(), + }) + .to_string(); + msg = &probe_msg; } // TODO: either preserve the error code from postgres, or assign error codes to proxy errors.
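
Note: most of the churn in this patch is mechanical — nested `if let` / `if` blocks (console_redirect.rs, jwt.rs, backend/mod.rs, logging.rs, copy_bidirectional.rs, the conn pools, stream.rs) are collapsed into let-chains, which are stable on the Rust 2024 edition. A minimal standalone sketch of the before/after shape, using hypothetical names rather than the proxy's real helpers:

// Rust 2021 style: nesting grows with every extra condition.
fn allowed_nested(check_enabled: bool, allowlist: Option<&[String]>, peer: &str) -> bool {
    if check_enabled {
        if let Some(ips) = allowlist {
            if !ips.iter().any(|ip| ip == peer) {
                return false;
            }
        }
    }
    true
}

// Rust 2024 let-chain: one condition, one early return, same semantics.
fn allowed_chained(check_enabled: bool, allowlist: Option<&[String]>, peer: &str) -> bool {
    if check_enabled
        && let Some(ips) = allowlist
        && !ips.iter().any(|ip| ip == peer)
    {
        return false;
    }
    true
}

fn main() {
    let ips = vec!["10.0.0.1".to_string()];
    for peer in ["10.0.0.1", "203.0.113.7"] {
        assert_eq!(
            allowed_nested(true, Some(ips.as_slice()), peer),
            allowed_chained(true, Some(ips.as_slice()), peer),
        );
    }
}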
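The `i % 1024 == 0`-style checks in sasl.rs, leaky_bucket.rs, limiter.rs and threadpool.rs all move to the unsigned-integer `is_multiple_of` method, which was stabilized in recent Rust releases. It states the intent directly and, unlike `%`, is defined for a zero divisor instead of panicking. A tiny self-contained sketch mirroring the yield-every-1024-iterations check:

fn main() {
    // Same shape as the yield check in `hi()`: yield every 1024 iterations.
    let mut yields = 0u32;
    for i in 0u32..4096 {
        if i.is_multiple_of(1024) {
            yields += 1;
        }
    }
    assert_eq!(yields, 4); // i = 0, 1024, 2048, 3072

    // Unlike `%`, a zero divisor is defined rather than panicking.
    assert!(!5u32.is_multiple_of(0));
    assert!(0u32.is_multiple_of(0));
}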
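cplane_proxy_v1.rs replaces a `match ... { None => return Err(...), Some(x) => x }` with `let ... else`, which keeps the happy-path binding flat. A small sketch with a hypothetical address splitter (not the proxy's actual `parse_host_port`):

// let-else: bind on the happy path, early-return otherwise.
fn split_host_port(addr: &str) -> Result<(&str, u16), String> {
    let Some((host, port)) = addr.rsplit_once(':') else {
        return Err(format!("bad compute address: {addr}"));
    };
    let port: u16 = port.parse().map_err(|e| format!("bad port: {e}"))?;
    Ok((host, port))
}

fn main() {
    assert_eq!(split_host_port("db.internal:5432").unwrap(), ("db.internal", 5432));
    assert!(split_host_port("no-port").is_err());
}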
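connection_with_credentials_provider.rs now stores only a `tokio::task::AbortHandle` (obtained via `JoinHandle::abort_handle()`) for the token-refresh loop, since the task's output is never awaited, and the loop's return type becomes `!` instead of `Infallible`. A rough sketch of that ownership pattern, with invented names and assuming a recent tokio with the rt and macros features enabled:

use std::time::Duration;
use tokio::task::AbortHandle;

// The owner only ever needs to cancel the background task, never join it,
// so it keeps the lightweight AbortHandle and drops the JoinHandle.
struct Refresher {
    task: Option<AbortHandle>,
}

impl Refresher {
    fn start(&mut self) {
        let handle = tokio::spawn(async {
            loop {
                tokio::time::sleep(Duration::from_secs(60 * 60)).await;
                // refresh credentials here
            }
        });
        // keep only the abort handle; the JoinHandle itself is dropped
        self.task = Some(handle.abort_handle());
    }
}

impl Drop for Refresher {
    fn drop(&mut self) {
        if let Some(task) = &self.task {
            task.abort();
        }
    }
}

#[tokio::main]
async fn main() {
    let mut r = Refresher { task: None };
    r.start();
    drop(r); // background loop is aborted here
}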
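scram/secret.rs changes `salt_base64` from `String` to `Box<str>`, presumably because the salt is never mutated after parsing: `Box<str>` drops `String`'s unused capacity field and still derefs to `&str`, which is why the test compares `&*parsed.salt_base64` and exchange.rs passes `&*secret.salt_base64` to the base64 decoder. A minimal sketch (hypothetical struct, not the real ServerSecret):

// Box<str>: pointer + length (two usizes) vs String's pointer + length + capacity.
struct Secret {
    salt_base64: Box<str>,
}

fn main() {
    let s = Secret { salt_base64: "c2FsdA==".into() }; // &str -> Box<str> via From
    // Deref to &str works the same as with String:
    assert_eq!(&*s.salt_base64, "c2FsdA==");
    assert_eq!(
        std::mem::size_of::<Box<str>>() + std::mem::size_of::<usize>(),
        std::mem::size_of::<String>()
    );
}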