diff --git a/proxy/src/cache.rs b/proxy/src/cache.rs index 4e16cc39ec..a9d6793bbd 100644 --- a/proxy/src/cache.rs +++ b/proxy/src/cache.rs @@ -262,24 +262,21 @@ pub mod timed_lru { token: Option<(C, C::LookupInfo)>, /// The value itself. - pub value: C::Value, + value: C::Value, } impl Cached { /// Place any entry into this wrapper; invalidation will be a no-op. - /// Unfortunately, rust doesn't let us implement [`From`] or [`Into`]. - pub fn new_uncached(value: impl Into) -> Self { - Self { - token: None, - value: value.into(), - } + pub fn new_uncached(value: C::Value) -> Self { + Self { token: None, value } } /// Drop this entry from a cache if it's still there. - pub fn invalidate(&self) { + pub fn invalidate(self) -> C::Value { if let Some((cache, info)) = &self.token { cache.invalidate(info); } + self.value } /// Tell if this entry is actually cached. diff --git a/proxy/src/compute.rs b/proxy/src/compute.rs index ccf100397b..b1cf2a8559 100644 --- a/proxy/src/compute.rs +++ b/proxy/src/compute.rs @@ -1,4 +1,9 @@ -use crate::{auth::parse_endpoint_param, cancellation::CancelClosure, error::UserFacingError}; +use crate::{ + auth::parse_endpoint_param, + cancellation::CancelClosure, + console::errors::WakeComputeError, + error::{io_error, UserFacingError}, +}; use futures::{FutureExt, TryFutureExt}; use itertools::Itertools; use pq_proto::StartupMessageParams; @@ -24,6 +29,12 @@ pub enum ConnectionError { TlsError(#[from] native_tls::Error), } +impl From for ConnectionError { + fn from(value: WakeComputeError) -> Self { + io_error(value).into() + } +} + impl UserFacingError for ConnectionError { fn to_string_client(&self) -> String { use ConnectionError::*; diff --git a/proxy/src/console/provider.rs b/proxy/src/console/provider.rs index 77b4330e44..3eaed1b82b 100644 --- a/proxy/src/console/provider.rs +++ b/proxy/src/console/provider.rs @@ -186,14 +186,14 @@ pub trait Api { async fn get_auth_info( &self, extra: &ConsoleReqExtra<'_>, - creds: &ClientCredentials<'_>, + creds: &ClientCredentials, ) -> Result, errors::GetAuthInfoError>; /// Wake up the compute node and return the corresponding connection info. async fn wake_compute( &self, extra: &ConsoleReqExtra<'_>, - creds: &ClientCredentials<'_>, + creds: &ClientCredentials, ) -> Result; } diff --git a/proxy/src/console/provider/mock.rs b/proxy/src/console/provider/mock.rs index 3b42c73a34..282567269d 100644 --- a/proxy/src/console/provider/mock.rs +++ b/proxy/src/console/provider/mock.rs @@ -106,7 +106,7 @@ impl super::Api for Api { async fn get_auth_info( &self, _extra: &ConsoleReqExtra<'_>, - creds: &ClientCredentials<'_>, + creds: &ClientCredentials, ) -> Result, GetAuthInfoError> { self.do_get_auth_info(creds).await } @@ -115,7 +115,7 @@ impl super::Api for Api { async fn wake_compute( &self, _extra: &ConsoleReqExtra<'_>, - _creds: &ClientCredentials<'_>, + _creds: &ClientCredentials, ) -> Result { self.do_wake_compute() .map_ok(CachedNodeInfo::new_uncached) diff --git a/proxy/src/console/provider/neon.rs b/proxy/src/console/provider/neon.rs index a8e855b2c8..22e766b5f1 100644 --- a/proxy/src/console/provider/neon.rs +++ b/proxy/src/console/provider/neon.rs @@ -123,7 +123,7 @@ impl super::Api for Api { async fn get_auth_info( &self, extra: &ConsoleReqExtra<'_>, - creds: &ClientCredentials<'_>, + creds: &ClientCredentials, ) -> Result, GetAuthInfoError> { self.do_get_auth_info(extra, creds).await } @@ -132,7 +132,7 @@ impl super::Api for Api { async fn wake_compute( &self, extra: &ConsoleReqExtra<'_>, - creds: &ClientCredentials<'_>, + creds: &ClientCredentials, ) -> Result { let key = creds.project().expect("impossible"); diff --git a/proxy/src/http/conn_pool.rs b/proxy/src/http/conn_pool.rs index 49c94a830f..703632a511 100644 --- a/proxy/src/http/conn_pool.rs +++ b/proxy/src/http/conn_pool.rs @@ -1,19 +1,17 @@ +use anyhow::Context; +use async_trait::async_trait; use parking_lot::Mutex; use pq_proto::StartupMessageParams; use std::fmt; -use std::ops::ControlFlow; use std::{collections::HashMap, sync::Arc}; use tokio::time; -use crate::config; use crate::{auth, console}; +use crate::{compute, config}; use super::sql_over_http::MAX_RESPONSE_SIZE; -use crate::proxy::{ - can_retry_tokio_postgres_error, invalidate_cache, retry_after, try_wake, - NUM_RETRIES_WAKE_COMPUTE, -}; +use crate::proxy::ConnectMechanism; use tracing::error; use tracing::info; @@ -187,6 +185,27 @@ impl GlobalConnPool { } } +struct TokioMechanism<'a> { + conn_info: &'a ConnInfo, +} + +#[async_trait] +impl ConnectMechanism for TokioMechanism<'_> { + type Connection = tokio_postgres::Client; + type ConnectError = tokio_postgres::Error; + type Error = anyhow::Error; + + async fn connect_once( + &self, + node_info: &console::CachedNodeInfo, + timeout: time::Duration, + ) -> Result { + connect_to_compute_once(node_info, self.conn_info, timeout).await + } + + fn update_connect_config(&self, _config: &mut compute::ConnCfg) {} +} + // Wake up the destination if needed. Code here is a bit involved because // we reuse the code from the usual proxy and we need to prepare few structures // that this code expects. @@ -220,72 +239,18 @@ async fn connect_to_compute( application_name: Some(APP_NAME), }; - let node_info = &mut creds.wake_compute(&extra).await?.expect("msg"); + let node_info = creds + .wake_compute(&extra) + .await? + .context("missing cache entry from wake_compute")?; - let mut num_retries = 0; - let mut wait_duration = time::Duration::ZERO; - let mut should_wake_with_error = None; - loop { - if !wait_duration.is_zero() { - time::sleep(wait_duration).await; - } - - // try wake the compute node if we have determined it's sensible to do so - if let Some(err) = should_wake_with_error.take() { - match try_wake(node_info, &extra, &creds).await { - // we can't wake up the compute node - Ok(None) => return Err(err), - // there was an error communicating with the control plane - Err(e) => return Err(e.into()), - // failed to wake up but we can continue to retry - Ok(Some(ControlFlow::Continue(()))) => { - wait_duration = retry_after(num_retries); - should_wake_with_error = Some(err); - - num_retries += 1; - info!(num_retries, "retrying wake compute"); - continue; - } - // successfully woke up a compute node and can break the wakeup loop - Ok(Some(ControlFlow::Break(()))) => {} - } - } - - match connect_to_compute_once(node_info, conn_info).await { - Ok(res) => return Ok(res), - Err(e) => { - error!(error = ?e, "could not connect to compute node"); - if !can_retry_error(&e, num_retries) { - return Err(e.into()); - } - wait_duration = retry_after(num_retries); - - // after the first connect failure, - // we should invalidate the cache and wake up a new compute node - if num_retries == 0 { - invalidate_cache(node_info); - should_wake_with_error = Some(e.into()); - } - } - } - - num_retries += 1; - info!(num_retries, "retrying connect"); - } -} - -fn can_retry_error(err: &tokio_postgres::Error, num_retries: u32) -> bool { - match err { - // retry all errors at least once - _ if num_retries == 0 => true, - _ if num_retries >= NUM_RETRIES_WAKE_COMPUTE => false, - err => can_retry_tokio_postgres_error(err), - } + crate::proxy::connect_to_compute(&TokioMechanism { conn_info }, node_info, &extra, &creds).await } async fn connect_to_compute_once( node_info: &console::CachedNodeInfo, conn_info: &ConnInfo, + timeout: time::Duration, ) -> Result { let mut config = (*node_info.config).clone(); @@ -294,6 +259,7 @@ async fn connect_to_compute_once( .password(&conn_info.password) .dbname(&conn_info.dbname) .max_backend_message_size(MAX_RESPONSE_SIZE) + .connect_timeout(timeout) .connect(tokio_postgres::NoTls) .await?; diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index 8722109b80..d4a3f2641e 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -11,16 +11,16 @@ use crate::{ errors::{ApiError, WakeComputeError}, messages::MetricsAuxInfo, }, - error::io_error, stream::{PqStream, Stream}, }; use anyhow::{bail, Context}; +use async_trait::async_trait; use futures::TryFutureExt; use hyper::StatusCode; use metrics::{register_int_counter, register_int_counter_vec, IntCounter, IntCounterVec}; use once_cell::sync::Lazy; use pq_proto::{BeMessage as Be, FeStartupPacket, StartupMessageParams}; -use std::{error::Error, ops::ControlFlow, sync::Arc}; +use std::{error::Error, io, ops::ControlFlow, sync::Arc}; use tokio::{ io::{AsyncRead, AsyncWrite, AsyncWriteExt}, time, @@ -31,7 +31,7 @@ use utils::measured_stream::MeasuredStream; /// Number of times we should retry the `/proxy_wake_compute` http request. /// Retry duration is BASE_RETRY_WAIT_DURATION * 1.5^n -pub const NUM_RETRIES_WAKE_COMPUTE: u32 = 10; +const NUM_RETRIES_WAKE_COMPUTE: u32 = 10; const BASE_RETRY_WAIT_DURATION: time::Duration = time::Duration::from_millis(100); const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)"; @@ -303,18 +303,18 @@ async fn handshake( /// (e.g. the compute node's address might've changed at the wrong time). /// Invalidate the cache entry (if any) to prevent subsequent errors. #[tracing::instrument(name = "invalidate_cache", skip_all)] -pub fn invalidate_cache(node_info: &console::CachedNodeInfo) { +pub fn invalidate_cache(node_info: console::CachedNodeInfo) -> compute::ConnCfg { let is_cached = node_info.cached(); if is_cached { warn!("invalidating stalled compute node info cache entry"); - node_info.invalidate(); } - let label = match is_cached { true => "compute_cached", false => "compute_uncached", }; NUM_CONNECTION_FAILURES.with_label_values(&[label]).inc(); + + node_info.invalidate().config } /// Try to connect to the compute node once. @@ -331,47 +331,68 @@ async fn connect_to_compute_once( .await } +enum ConnectionState { + Cached(console::CachedNodeInfo), + Invalid(compute::ConnCfg, E), +} + +#[async_trait] +pub trait ConnectMechanism { + type Connection; + type ConnectError; + type Error: From; + async fn connect_once( + &self, + node_info: &console::CachedNodeInfo, + timeout: time::Duration, + ) -> Result; + + fn update_connect_config(&self, conf: &mut compute::ConnCfg); +} + +pub struct TcpMechanism<'a> { + /// KV-dictionary with PostgreSQL connection params. + pub params: &'a StartupMessageParams, +} + +#[async_trait] +impl ConnectMechanism for TcpMechanism<'_> { + type Connection = PostgresConnection; + type ConnectError = compute::ConnectionError; + type Error = compute::ConnectionError; + + async fn connect_once( + &self, + node_info: &console::CachedNodeInfo, + timeout: time::Duration, + ) -> Result { + connect_to_compute_once(node_info, timeout).await + } + + fn update_connect_config(&self, config: &mut compute::ConnCfg) { + config.set_startup_params(self.params); + } +} + /// Try to connect to the compute node, retrying if necessary. /// This function might update `node_info`, so we take it by `&mut`. #[tracing::instrument(skip_all)] -async fn connect_to_compute( - node_info: &mut console::CachedNodeInfo, - params: &StartupMessageParams, +pub async fn connect_to_compute( + mechanism: &M, + mut node_info: console::CachedNodeInfo, extra: &console::ConsoleReqExtra<'_>, creds: &auth::BackendType<'_, auth::ClientCredentials<'_>>, -) -> Result { +) -> Result +where + M::ConnectError: ShouldRetry + std::fmt::Debug, + M::Error: From, +{ + mechanism.update_connect_config(&mut node_info.config); + let mut num_retries = 0; - let mut wait_duration = time::Duration::ZERO; - let mut should_wake_with_error = None; + let mut state = ConnectionState::::Cached(node_info); + loop { - // Apply startup params to the (possibly, cached) compute node info. - node_info.config.set_startup_params(params); - - if !wait_duration.is_zero() { - time::sleep(wait_duration).await; - } - - // try wake the compute node if we have determined it's sensible to do so - if let Some(err) = should_wake_with_error.take() { - match try_wake(node_info, extra, creds).await { - // we can't wake up the compute node - Ok(None) => return Err(err), - // there was an error communicating with the control plane - Err(e) => return Err(io_error(e).into()), - // failed to wake up but we can continue to retry - Ok(Some(ControlFlow::Continue(()))) => { - wait_duration = retry_after(num_retries); - should_wake_with_error = Some(err); - - num_retries += 1; - info!(num_retries, "retrying wake compute"); - continue; - } - // successfully woke up a compute node and can break the wakeup loop - Ok(Some(ControlFlow::Break(()))) => {} - } - } - // Set a shorter timeout for the initial connection attempt. // // In case we try to connect to an outdated address that is no longer valid, the @@ -391,29 +412,56 @@ async fn connect_to_compute( time::Duration::from_secs(10) }; - // do this again to ensure we have username? - node_info.config.set_startup_params(params); + match state { + ConnectionState::Invalid(config, err) => { + match try_wake(&config, extra, creds).await { + // we can't wake up the compute node + Ok(None) => return Err(err.into()), + // there was an error communicating with the control plane + Err(e) => return Err(e.into()), + // failed to wake up but we can continue to retry + Ok(Some(ControlFlow::Continue(()))) => { + state = ConnectionState::Invalid(config, err); + let wait_duration = retry_after(num_retries); + num_retries += 1; - match connect_to_compute_once(node_info, timeout).await { - Ok(res) => return Ok(res), - Err(e) => { - error!(error = ?e, "could not connect to compute node"); - if !can_retry_error(&e, num_retries) { - return Err(e); + info!(num_retries, "retrying wake compute"); + time::sleep(wait_duration).await; + continue; + } + // successfully woke up a compute node and can break the wakeup loop + Ok(Some(ControlFlow::Break(mut node_info))) => { + mechanism.update_connect_config(&mut node_info.config); + state = ConnectionState::Cached(node_info) + } } - wait_duration = retry_after(num_retries); + } + ConnectionState::Cached(node_info) => { + match mechanism.connect_once(&node_info, timeout).await { + Ok(res) => return Ok(res), + Err(e) => { + error!(error = ?e, "could not connect to compute node"); + if !e.should_retry(num_retries) { + return Err(e.into()); + } - // after the first connect failure, - // we should invalidate the cache and wake up a new compute node - if num_retries == 0 { - invalidate_cache(node_info); - should_wake_with_error = Some(e); + // after the first connect failure, + // we should invalidate the cache and wake up a new compute node + if num_retries == 0 { + state = ConnectionState::Invalid(invalidate_cache(node_info), e); + } else { + state = ConnectionState::Cached(node_info); + } + + let wait_duration = retry_after(num_retries); + num_retries += 1; + + info!(num_retries, "retrying wake compute"); + time::sleep(wait_duration).await; + } } } } - - num_retries += 1; - info!(num_retries, "retrying connect"); } } @@ -421,11 +469,11 @@ async fn connect_to_compute( /// * Returns Ok(Some(true)) if there was an error waking but retries are acceptable /// * Returns Ok(Some(false)) if the wakeup succeeded /// * Returns Ok(None) or Err(e) if there was an error -pub async fn try_wake( - node_info: &mut console::CachedNodeInfo, +async fn try_wake( + config: &compute::ConnCfg, extra: &console::ConsoleReqExtra<'_>, creds: &auth::BackendType<'_, auth::ClientCredentials<'_>>, -) -> Result>, WakeComputeError> { +) -> Result>, WakeComputeError> { info!("compute node's state has likely changed; requesting a wake-up"); match creds.wake_compute(extra).await { // retry wake if the compute was in an invalid state @@ -435,53 +483,69 @@ pub async fn try_wake( })) => Ok(Some(ControlFlow::Continue(()))), // Update `node_info` and try again. Ok(Some(mut new)) => { - new.config.reuse_password(&node_info.config); - *node_info = new; - Ok(Some(ControlFlow::Break(()))) + new.config.reuse_password(config); + Ok(Some(ControlFlow::Break(new))) } Err(e) => Err(e), Ok(None) => Ok(None), } } -fn can_retry_error(err: &compute::ConnectionError, num_retries: u32) -> bool { - match err { - // retry all errors at least once - _ if num_retries == 0 => true, - _ if num_retries >= NUM_RETRIES_WAKE_COMPUTE => false, - compute::ConnectionError::Postgres(err) => can_retry_tokio_postgres_error(err), - compute::ConnectionError::CouldNotConnect(err) => is_io_connection_err(err), - _ => false, +pub trait ShouldRetry { + fn could_retry(&self) -> bool; + fn should_retry(&self, num_retries: u32) -> bool { + match self { + // retry all errors at least once + _ if num_retries == 0 => true, + _ if num_retries >= NUM_RETRIES_WAKE_COMPUTE => false, + err => err.could_retry(), + } } } -pub fn can_retry_tokio_postgres_error(err: &tokio_postgres::Error) -> bool { - if let Some(io_err) = err.source().and_then(|x| x.downcast_ref()) { - is_io_connection_err(io_err) - } else if let Some(db_err) = err.source().and_then(|x| x.downcast_ref()) { - is_sql_connection_err(db_err) - } else { - false +impl ShouldRetry for io::Error { + fn could_retry(&self) -> bool { + use std::io::ErrorKind; + matches!( + self.kind(), + ErrorKind::ConnectionRefused | ErrorKind::AddrNotAvailable | ErrorKind::TimedOut + ) } } -fn is_sql_connection_err(err: &tokio_postgres::error::DbError) -> bool { - use tokio_postgres::error::SqlState; - matches!( - err.code(), - &SqlState::CONNECTION_FAILURE - | &SqlState::CONNECTION_EXCEPTION - | &SqlState::CONNECTION_DOES_NOT_EXIST - | &SqlState::SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION, - ) +impl ShouldRetry for tokio_postgres::error::DbError { + fn could_retry(&self) -> bool { + use tokio_postgres::error::SqlState; + matches!( + self.code(), + &SqlState::CONNECTION_FAILURE + | &SqlState::CONNECTION_EXCEPTION + | &SqlState::CONNECTION_DOES_NOT_EXIST + | &SqlState::SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION, + ) + } } -fn is_io_connection_err(err: &std::io::Error) -> bool { - use std::io::ErrorKind; - matches!( - err.kind(), - ErrorKind::ConnectionRefused | ErrorKind::AddrNotAvailable | ErrorKind::TimedOut - ) +impl ShouldRetry for tokio_postgres::Error { + fn could_retry(&self) -> bool { + if let Some(io_err) = self.source().and_then(|x| x.downcast_ref()) { + io::Error::could_retry(io_err) + } else if let Some(db_err) = self.source().and_then(|x| x.downcast_ref()) { + tokio_postgres::error::DbError::could_retry(db_err) + } else { + false + } + } +} + +impl ShouldRetry for compute::ConnectionError { + fn could_retry(&self) -> bool { + match self { + compute::ConnectionError::Postgres(err) => err.could_retry(), + compute::ConnectionError::CouldNotConnect(err) => err.could_retry(), + _ => false, + } + } } pub fn retry_after(num_retries: u32) -> time::Duration { @@ -637,7 +701,8 @@ impl Client<'_, S> { node_info.allow_self_signed_compute = allow_self_signed_compute; - let mut node = connect_to_compute(&mut node_info, params, &extra, &creds) + let aux = node_info.aux.clone(); + let mut node = connect_to_compute(&TcpMechanism { params }, node_info, &extra, &creds) .or_else(|e| stream.throw_error(e)) .await?; @@ -648,6 +713,6 @@ impl Client<'_, S> { // immediately after opening the connection. let (stream, read_buf) = stream.into_inner(); node.stream.write_all(&read_buf).await?; - proxy_pass(stream, node.stream, &node_info.aux).await + proxy_pass(stream, node.stream, &aux).await } }