diff --git a/proxy/src/auth/backend/console_redirect.rs b/proxy/src/auth/backend/console_redirect.rs index a7133b22e5..8440d198df 100644 --- a/proxy/src/auth/backend/console_redirect.rs +++ b/proxy/src/auth/backend/console_redirect.rs @@ -14,9 +14,9 @@ use crate::context::RequestContext; use crate::control_plane::client::cplane_proxy_v1; use crate::control_plane::{self, CachedNodeInfo, NodeInfo}; use crate::error::{ReportableError, UserFacingError}; -use crate::pglb::connect_compute::WakeComputeBackend; use crate::pqproto::BeMessage; use crate::proxy::NeonOptions; +use crate::proxy::wake_compute::WakeComputeBackend; use crate::stream::PqStream; use crate::types::RoleName; use crate::{auth, compute, waiters}; diff --git a/proxy/src/auth/backend/mod.rs b/proxy/src/auth/backend/mod.rs index a153ea4d42..7ceb1e6814 100644 --- a/proxy/src/auth/backend/mod.rs +++ b/proxy/src/auth/backend/mod.rs @@ -25,9 +25,9 @@ use crate::control_plane::{ RoleAccessControl, }; use crate::intern::EndpointIdInt; -use crate::pglb::connect_compute::WakeComputeBackend; use crate::pqproto::BeMessage; use crate::proxy::NeonOptions; +use crate::proxy::wake_compute::WakeComputeBackend; use crate::rate_limiter::EndpointRateLimiter; use crate::stream::Stream; use crate::types::{EndpointCacheKey, EndpointId, RoleName}; diff --git a/proxy/src/binary/pg_sni_router.rs b/proxy/src/binary/pg_sni_router.rs index a4f517fead..481bd8501c 100644 --- a/proxy/src/binary/pg_sni_router.rs +++ b/proxy/src/binary/pg_sni_router.rs @@ -28,10 +28,9 @@ use crate::context::RequestContext; use crate::metrics::{Metrics, ThreadPoolMetrics}; use crate::pqproto::FeStartupPacket; use crate::protocol2::ConnectionInfo; -use crate::proxy::{ - ErrorSource, TlsRequired, copy_bidirectional_client_compute, run_until_cancelled, -}; +use crate::proxy::{ErrorSource, TlsRequired, copy_bidirectional_client_compute}; use crate::stream::{PqStream, Stream}; +use crate::util::run_until_cancelled; project_git_version!(GIT_VERSION); diff --git a/proxy/src/console_redirect_proxy.rs b/proxy/src/console_redirect_proxy.rs index 324dcf5824..5331ea41fd 100644 --- a/proxy/src/console_redirect_proxy.rs +++ b/proxy/src/console_redirect_proxy.rs @@ -11,13 +11,12 @@ use crate::config::{ProxyConfig, ProxyProtocolV2}; use crate::context::RequestContext; use crate::error::ReportableError; use crate::metrics::{Metrics, NumClientConnectionsGuard}; -use crate::pglb::connect_compute::{TcpMechanism, connect_to_compute}; use crate::pglb::handshake::{HandshakeData, handshake}; use crate::pglb::passthrough::ProxyPassthrough; use crate::protocol2::{ConnectHeader, ConnectionInfo, read_proxy_protocol}; -use crate::proxy::{ - ClientRequestError, ErrorSource, prepare_client_connection, run_until_cancelled, -}; +use crate::proxy::connect_compute::{TcpMechanism, connect_to_compute}; +use crate::proxy::{ClientRequestError, ErrorSource, prepare_client_connection}; +use crate::util::run_until_cancelled; pub async fn task_main( config: &'static ProxyConfig, diff --git a/proxy/src/lib.rs b/proxy/src/lib.rs index d65d056585..026c6aeba9 100644 --- a/proxy/src/lib.rs +++ b/proxy/src/lib.rs @@ -106,4 +106,5 @@ mod tls; mod types; mod url; mod usage_metrics; +mod util; mod waiters; diff --git a/proxy/src/pglb/mod.rs b/proxy/src/pglb/mod.rs index 4b107142a7..cb82524cf6 100644 --- a/proxy/src/pglb/mod.rs +++ b/proxy/src/pglb/mod.rs @@ -1,4 +1,3 @@ -pub mod connect_compute; pub mod copy_bidirectional; pub mod handshake; pub mod inprocess; diff --git a/proxy/src/pglb/connect_compute.rs b/proxy/src/proxy/connect_compute.rs similarity index 94% rename from proxy/src/pglb/connect_compute.rs rename to proxy/src/proxy/connect_compute.rs index 4984588f75..92ed84f50f 100644 --- a/proxy/src/pglb/connect_compute.rs +++ b/proxy/src/proxy/connect_compute.rs @@ -8,19 +8,19 @@ use crate::config::{ComputeConfig, RetryConfig}; use crate::context::RequestContext; use crate::control_plane::errors::WakeComputeError; use crate::control_plane::locks::ApiLocks; -use crate::control_plane::{self, CachedNodeInfo, NodeInfo}; +use crate::control_plane::{self, NodeInfo}; use crate::error::ReportableError; use crate::metrics::{ ConnectOutcome, ConnectionFailureKind, Metrics, RetriesMetricGroup, RetryType, }; use crate::proxy::retry::{CouldRetry, ShouldRetryWakeCompute, retry_after, should_retry}; -use crate::proxy::wake_compute::wake_compute; +use crate::proxy::wake_compute::{WakeComputeBackend, wake_compute}; use crate::types::Host; /// If we couldn't connect, a cached connection info might be to blame /// (e.g. the compute node's address might've changed at the wrong time). /// Invalidate the cache entry (if any) to prevent subsequent errors. -#[tracing::instrument(name = "invalidate_cache", skip_all)] +#[tracing::instrument(skip_all)] pub(crate) fn invalidate_cache(node_info: control_plane::CachedNodeInfo) -> NodeInfo { let is_cached = node_info.cached(); if is_cached { @@ -49,14 +49,6 @@ pub(crate) trait ConnectMechanism { ) -> Result; } -#[async_trait] -pub(crate) trait WakeComputeBackend { - async fn wake_compute( - &self, - ctx: &RequestContext, - ) -> Result; -} - pub(crate) struct TcpMechanism { pub(crate) auth: AuthInfo, /// connect_to_compute concurrency lock diff --git a/proxy/src/proxy/mod.rs b/proxy/src/proxy/mod.rs index 4a294c1e82..4211406f6c 100644 --- a/proxy/src/proxy/mod.rs +++ b/proxy/src/proxy/mod.rs @@ -1,8 +1,10 @@ #[cfg(test)] mod tests; +pub(crate) mod connect_compute; pub(crate) mod retry; pub(crate) mod wake_compute; + use std::sync::Arc; use futures::FutureExt; @@ -21,15 +23,16 @@ use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig}; use crate::context::RequestContext; use crate::error::{ReportableError, UserFacingError}; use crate::metrics::{Metrics, NumClientConnectionsGuard}; -use crate::pglb::connect_compute::{TcpMechanism, connect_to_compute}; pub use crate::pglb::copy_bidirectional::{ErrorSource, copy_bidirectional_client_compute}; use crate::pglb::handshake::{HandshakeData, HandshakeError, handshake}; use crate::pglb::passthrough::ProxyPassthrough; use crate::pqproto::{BeMessage, CancelKeyData, StartupMessageParams}; use crate::protocol2::{ConnectHeader, ConnectionInfo, ConnectionInfoExtra, read_proxy_protocol}; +use crate::proxy::connect_compute::{TcpMechanism, connect_to_compute}; use crate::rate_limiter::EndpointRateLimiter; use crate::stream::{PqStream, Stream}; use crate::types::EndpointCacheKey; +use crate::util::run_until_cancelled; use crate::{auth, compute}; const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)"; @@ -46,21 +49,6 @@ impl ReportableError for TlsRequired { impl UserFacingError for TlsRequired {} -pub async fn run_until_cancelled( - f: F, - cancellation_token: &CancellationToken, -) -> Option { - match futures::future::select( - std::pin::pin!(f), - std::pin::pin!(cancellation_token.cancelled()), - ) - .await - { - futures::future::Either::Left((f, _)) => Some(f), - futures::future::Either::Right(((), _)) => None, - } -} - pub async fn task_main( config: &'static ProxyConfig, auth_backend: &'static auth::Backend<'static, ()>, diff --git a/proxy/src/proxy/tests/mod.rs b/proxy/src/proxy/tests/mod.rs index de85bf9df2..12de5cbc09 100644 --- a/proxy/src/proxy/tests/mod.rs +++ b/proxy/src/proxy/tests/mod.rs @@ -25,7 +25,7 @@ use crate::control_plane::client::{ControlPlaneClient, TestControlPlaneClient}; use crate::control_plane::messages::{ControlPlaneErrorMessage, Details, MetricsAuxInfo, Status}; use crate::control_plane::{self, CachedNodeInfo, NodeInfo, NodeInfoCache}; use crate::error::ErrorKind; -use crate::pglb::connect_compute::ConnectMechanism; +use crate::proxy::connect_compute::ConnectMechanism; use crate::tls::client_config::compute_client_config_with_certs; use crate::tls::server_config::CertResolver; use crate::types::{BranchId, EndpointId, ProjectId}; diff --git a/proxy/src/proxy/wake_compute.rs b/proxy/src/proxy/wake_compute.rs index fd59800745..b8edf9fd5c 100644 --- a/proxy/src/proxy/wake_compute.rs +++ b/proxy/src/proxy/wake_compute.rs @@ -1,3 +1,4 @@ +use async_trait::async_trait; use tracing::{error, info}; use crate::config::RetryConfig; @@ -8,7 +9,6 @@ use crate::error::ReportableError; use crate::metrics::{ ConnectOutcome, ConnectionFailuresBreakdownGroup, Metrics, RetriesMetricGroup, RetryType, }; -use crate::pglb::connect_compute::WakeComputeBackend; use crate::proxy::retry::{retry_after, should_retry}; // Use macro to retain original callsite. @@ -23,6 +23,11 @@ macro_rules! log_wake_compute_error { }; } +#[async_trait] +pub(crate) trait WakeComputeBackend { + async fn wake_compute(&self, ctx: &RequestContext) -> Result; +} + pub(crate) async fn wake_compute( num_retries: &mut u32, ctx: &RequestContext, diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index d74e3cad3d..316e038344 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -34,7 +34,7 @@ use crate::control_plane::errors::{GetAuthInfoError, WakeComputeError}; use crate::control_plane::locks::ApiLocks; use crate::error::{ErrorKind, ReportableError, UserFacingError}; use crate::intern::EndpointIdInt; -use crate::pglb::connect_compute::ConnectMechanism; +use crate::proxy::connect_compute::ConnectMechanism; use crate::proxy::retry::{CouldRetry, ShouldRetryWakeCompute}; use crate::rate_limiter::EndpointRateLimiter; use crate::types::{EndpointId, Host, LOCAL_PROXY_SUFFIX}; @@ -181,7 +181,7 @@ impl PoolingBackend { tracing::Span::current().record("conn_id", display(conn_id)); info!(%conn_id, "pool: opening a new connection '{conn_info}'"); let backend = self.auth_backend.as_ref().map(|()| keys.info); - crate::pglb::connect_compute::connect_to_compute( + crate::proxy::connect_compute::connect_to_compute( ctx, &TokioMechanism { conn_id, @@ -223,7 +223,7 @@ impl PoolingBackend { )), options: conn_info.user_info.options.clone(), }); - crate::pglb::connect_compute::connect_to_compute( + crate::proxy::connect_compute::connect_to_compute( ctx, &HyperMechanism { conn_id, diff --git a/proxy/src/serverless/mod.rs b/proxy/src/serverless/mod.rs index f6f681ac45..ed33bf1246 100644 --- a/proxy/src/serverless/mod.rs +++ b/proxy/src/serverless/mod.rs @@ -50,10 +50,10 @@ use crate::context::RequestContext; use crate::ext::TaskExt; use crate::metrics::Metrics; use crate::protocol2::{ConnectHeader, ConnectionInfo, read_proxy_protocol}; -use crate::proxy::run_until_cancelled; use crate::rate_limiter::EndpointRateLimiter; use crate::serverless::backend::PoolingBackend; use crate::serverless::http_util::{api_error_into_response, json_response}; +use crate::util::run_until_cancelled; pub(crate) const SERVERLESS_DRIVER_SNI: &str = "api"; pub(crate) const AUTH_BROKER_SNI: &str = "apiauth"; diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index eb80ac9ad0..b2eb801f5c 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -41,10 +41,11 @@ use crate::error::{ErrorKind, ReportableError, UserFacingError}; use crate::http::{ReadBodyError, read_body_with_limit}; use crate::metrics::{HttpDirection, Metrics, SniGroup, SniKind}; use crate::pqproto::StartupMessageParams; -use crate::proxy::{NeonOptions, run_until_cancelled}; +use crate::proxy::NeonOptions; use crate::serverless::backend::HttpConnError; use crate::types::{DbName, RoleName}; use crate::usage_metrics::{MetricCounter, MetricCounterRecorder}; +use crate::util::run_until_cancelled; #[derive(serde::Deserialize)] #[serde(rename_all = "camelCase")] diff --git a/proxy/src/util.rs b/proxy/src/util.rs new file mode 100644 index 0000000000..7fc2d9fbdb --- /dev/null +++ b/proxy/src/util.rs @@ -0,0 +1,14 @@ +use std::pin::pin; + +use futures::future::{Either, select}; +use tokio_util::sync::CancellationToken; + +pub async fn run_until_cancelled( + f: F, + cancellation_token: &CancellationToken, +) -> Option { + match select(pin!(f), pin!(cancellation_token.cancelled())).await { + Either::Left((f, _)) => Some(f), + Either::Right(((), _)) => None, + } +}