diff --git a/proxy/src/auth/backend/mod.rs b/proxy/src/auth/backend/mod.rs index fe6886e31e..b012714c62 100644 --- a/proxy/src/auth/backend/mod.rs +++ b/proxy/src/auth/backend/mod.rs @@ -507,6 +507,25 @@ impl ComputeConnectBackend for Backend<'_, ComputeCredentials> { } } +pub struct ControlPlaneWakeCompute<'a> { + pub cplane: &'a ControlPlaneClient, + pub creds: ComputeCredentials, +} + +#[async_trait::async_trait] +impl ComputeConnectBackend for ControlPlaneWakeCompute<'_> { + async fn wake_compute( + &self, + ctx: &RequestContext, + ) -> Result { + self.cplane.wake_compute(ctx, &self.creds.info).await + } + + fn get_keys(&self) -> &ComputeCredentialKeys { + &self.creds.keys + } +} + #[cfg(test)] mod tests { #![allow(clippy::unimplemented, clippy::unwrap_used)] diff --git a/proxy/src/auth/mod.rs b/proxy/src/auth/mod.rs index 5670f8e43d..c9df8f5fd9 100644 --- a/proxy/src/auth/mod.rs +++ b/proxy/src/auth/mod.rs @@ -1,7 +1,7 @@ //! Client authentication mechanisms. pub mod backend; -pub use backend::Backend; +pub use backend::{Backend, ControlPlaneWakeCompute}; mod credentials; pub(crate) use credentials::{ diff --git a/proxy/src/console_redirect_proxy.rs b/proxy/src/console_redirect_proxy.rs index 1b68f328f9..971ca273db 100644 --- a/proxy/src/console_redirect_proxy.rs +++ b/proxy/src/console_redirect_proxy.rs @@ -229,13 +229,13 @@ pub(crate) async fn handle_client( let mut node = connect_to_compute( ctx, - &TcpMechanism { + TcpMechanism { user_info, params_compat: true, params: ¶ms, locks: &config.connect_compute_locks, }, - &node_info, + node_info, config.wake_compute_retry_config, &config.connect_to_compute, ) diff --git a/proxy/src/proxy/connect_compute.rs b/proxy/src/proxy/connect_compute.rs index e013fbbe2e..ada6bd07dd 100644 --- a/proxy/src/proxy/connect_compute.rs +++ b/proxy/src/proxy/connect_compute.rs @@ -53,6 +53,25 @@ pub(crate) trait ConnectMechanism { fn update_connect_config(&self, conf: &mut compute::ConnCfg); } +#[async_trait] +impl ConnectMechanism for &T { + type Connection = T::Connection; + type ConnectError = T::ConnectError; + type Error = T::Error; + async fn connect_once( + &self, + ctx: &RequestContext, + node_info: &control_plane::CachedNodeInfo, + config: &ComputeConfig, + ) -> Result { + T::connect_once(self, ctx, node_info, config).await + } + + fn update_connect_config(&self, conf: &mut compute::ConnCfg) { + T::update_connect_config(self, conf); + } +} + #[async_trait] pub(crate) trait ComputeConnectBackend { async fn wake_compute( @@ -105,8 +124,8 @@ impl ConnectMechanism for TcpMechanism<'_> { #[tracing::instrument(skip_all)] pub(crate) async fn connect_to_compute( ctx: &RequestContext, - mechanism: &M, - user_info: &B, + mechanism: M, + backend: B, wake_compute_retry_config: RetryConfig, compute: &ComputeConfig, ) -> Result @@ -116,9 +135,9 @@ where { let mut num_retries = 0; let mut node_info = - wake_compute(&mut num_retries, ctx, user_info, wake_compute_retry_config).await?; + wake_compute(&mut num_retries, ctx, &backend, wake_compute_retry_config).await?; - node_info.set_keys(user_info.get_keys()); + node_info.set_keys(backend.get_keys()); mechanism.update_connect_config(&mut node_info.config); // try once @@ -159,7 +178,7 @@ where let old_node_info = invalidate_cache(node_info); // TODO: increment num_retries? let mut node_info = - wake_compute(&mut num_retries, ctx, user_info, wake_compute_retry_config).await?; + wake_compute(&mut num_retries, ctx, &backend, wake_compute_retry_config).await?; node_info.reuse_settings(old_node_info); mechanism.update_connect_config(&mut node_info.config); diff --git a/proxy/src/proxy/mod.rs b/proxy/src/proxy/mod.rs index dafa7de167..9bb3f2f305 100644 --- a/proxy/src/proxy/mod.rs +++ b/proxy/src/proxy/mod.rs @@ -394,16 +394,16 @@ pub(crate) async fn handle_client( let mut node = connect_to_compute( ctx, - &TcpMechanism { + TcpMechanism { user_info: compute_user_info, params_compat, params: ¶ms, locks: &config.connect_compute_locks, }, - &auth::Backend::ControlPlane( - auth::backend::MaybeOwned::Borrowed(cplane), - compute_creds, - ), + auth::ControlPlaneWakeCompute { + cplane, + creds: compute_creds, + }, config.wake_compute_retry_config, &config.connect_to_compute, ) diff --git a/proxy/src/proxy/tests/mod.rs b/proxy/src/proxy/tests/mod.rs index f28982df60..0b98ac60ee 100644 --- a/proxy/src/proxy/tests/mod.rs +++ b/proxy/src/proxy/tests/mod.rs @@ -625,7 +625,7 @@ async fn connect_to_compute_success() { let mechanism = TestConnectMechanism::new(vec![Wake, Connect]); let user_info = helper_create_connect_info(&mechanism); let config = config(); - connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config) + connect_to_compute(&ctx, &mechanism, user_info, config.retry, &config) .await .unwrap(); mechanism.verify(); @@ -639,7 +639,7 @@ async fn connect_to_compute_retry() { let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Connect]); let user_info = helper_create_connect_info(&mechanism); let config = config(); - connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config) + connect_to_compute(&ctx, &mechanism, user_info, config.retry, &config) .await .unwrap(); mechanism.verify(); @@ -654,7 +654,7 @@ async fn connect_to_compute_non_retry_1() { let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Fail]); let user_info = helper_create_connect_info(&mechanism); let config = config(); - connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config) + connect_to_compute(&ctx, &mechanism, user_info, config.retry, &config) .await .unwrap_err(); mechanism.verify(); @@ -669,7 +669,7 @@ async fn connect_to_compute_non_retry_2() { let mechanism = TestConnectMechanism::new(vec![Wake, Fail, Wake, Connect]); let user_info = helper_create_connect_info(&mechanism); let config = config(); - connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config) + connect_to_compute(&ctx, &mechanism, user_info, config.retry, &config) .await .unwrap(); mechanism.verify(); @@ -694,7 +694,7 @@ async fn connect_to_compute_non_retry_3() { connect_to_compute( &ctx, &mechanism, - &user_info, + user_info, wake_compute_retry_config, &config, ) @@ -712,7 +712,7 @@ async fn wake_retry() { let mechanism = TestConnectMechanism::new(vec![WakeRetry, Wake, Connect]); let user_info = helper_create_connect_info(&mechanism); let config = config(); - connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config) + connect_to_compute(&ctx, &mechanism, user_info, config.retry, &config) .await .unwrap(); mechanism.verify(); @@ -727,7 +727,7 @@ async fn wake_non_retry() { let mechanism = TestConnectMechanism::new(vec![WakeRetry, WakeFail]); let user_info = helper_create_connect_info(&mechanism); let config = config(); - connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config) + connect_to_compute(&ctx, &mechanism, user_info, config.retry, &config) .await .unwrap_err(); mechanism.verify(); @@ -746,7 +746,7 @@ async fn fail_but_wake_invalidates_cache() { let user = helper_create_connect_info(&mech); let cfg = config(); - connect_to_compute(&ctx, &mech, &user, cfg.retry, &cfg) + connect_to_compute(&ctx, &mech, user, cfg.retry, &cfg) .await .unwrap(); @@ -767,7 +767,7 @@ async fn fail_no_wake_skips_cache_invalidation() { let user = helper_create_connect_info(&mech); let cfg = config(); - connect_to_compute(&ctx, &mech, &user, cfg.retry, &cfg) + connect_to_compute(&ctx, &mech, user, cfg.retry, &cfg) .await .unwrap(); @@ -788,7 +788,7 @@ async fn retry_but_wake_invalidates_cache() { let user_info = helper_create_connect_info(&mechanism); let cfg = config(); - connect_to_compute(&ctx, &mechanism, &user_info, cfg.retry, &cfg) + connect_to_compute(&ctx, &mechanism, user_info, cfg.retry, &cfg) .await .unwrap(); mechanism.verify(); @@ -811,7 +811,7 @@ async fn retry_no_wake_skips_invalidation() { let user_info = helper_create_connect_info(&mechanism); let cfg = config(); - connect_to_compute(&ctx, &mechanism, &user_info, cfg.retry, &cfg) + connect_to_compute(&ctx, &mechanism, user_info, cfg.retry, &cfg) .await .unwrap_err(); mechanism.verify(); diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index 13058f08f1..4d55788883 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -224,13 +224,13 @@ impl PoolingBackend { let backend = self.auth_backend.as_ref().map(|()| keys); crate::proxy::connect_compute::connect_to_compute( ctx, - &TokioMechanism { + TokioMechanism { conn_id, conn_info, pool: self.pool.clone(), locks: &self.config.connect_compute_locks, }, - &backend, + backend, self.config.wake_compute_retry_config, &self.config.connect_to_compute, ) @@ -268,13 +268,13 @@ impl PoolingBackend { }); crate::proxy::connect_compute::connect_to_compute( ctx, - &HyperMechanism { + HyperMechanism { conn_id, conn_info, pool: self.http_conn_pool.clone(), locks: &self.config.connect_compute_locks, }, - &backend, + backend, self.config.wake_compute_retry_config, &self.config.connect_to_compute, )