proxy: merge connect compute (#4713)

## Problem

Half of #4699.

TCP/WS have one implementation of `connect_to_compute`, HTTP has another
implementation of `connect_to_compute`.

Having both is annoying to deal with.

## Summary of changes

Creates a set of traits `ConnectMechanism` and `ShouldError` that allows
the `connect_to_compute` to be generic over raw TCP stream or
tokio_postgres based connections.

I'm not super happy with this. I think it would be nice to
remove tokio_postgres entirely but that will need a lot more thought to
be put into it.

I have also slightly refactored the caching to use fewer references.
Instead using ownership to ensure the state of retrying is encoded in
the type system.
This commit is contained in:
Conrad Ludgate
2023-07-17 15:53:01 +01:00
committed by GitHub
parent 1066bca5e3
commit 7c85c7ea91
7 changed files with 215 additions and 176 deletions

View File

@@ -262,24 +262,21 @@ pub mod timed_lru {
token: Option<(C, C::LookupInfo<C::Key>)>,
/// The value itself.
pub value: C::Value,
value: C::Value,
}
impl<C: Cache> Cached<C> {
/// Place any entry into this wrapper; invalidation will be a no-op.
/// Unfortunately, rust doesn't let us implement [`From`] or [`Into`].
pub fn new_uncached(value: impl Into<C::Value>) -> Self {
Self {
token: None,
value: value.into(),
}
pub fn new_uncached(value: C::Value) -> Self {
Self { token: None, value }
}
/// Drop this entry from a cache if it's still there.
pub fn invalidate(&self) {
pub fn invalidate(self) -> C::Value {
if let Some((cache, info)) = &self.token {
cache.invalidate(info);
}
self.value
}
/// Tell if this entry is actually cached.

View File

@@ -1,4 +1,9 @@
use crate::{auth::parse_endpoint_param, cancellation::CancelClosure, error::UserFacingError};
use crate::{
auth::parse_endpoint_param,
cancellation::CancelClosure,
console::errors::WakeComputeError,
error::{io_error, UserFacingError},
};
use futures::{FutureExt, TryFutureExt};
use itertools::Itertools;
use pq_proto::StartupMessageParams;
@@ -24,6 +29,12 @@ pub enum ConnectionError {
TlsError(#[from] native_tls::Error),
}
impl From<WakeComputeError> for ConnectionError {
fn from(value: WakeComputeError) -> Self {
io_error(value).into()
}
}
impl UserFacingError for ConnectionError {
fn to_string_client(&self) -> String {
use ConnectionError::*;

View File

@@ -186,14 +186,14 @@ pub trait Api {
async fn get_auth_info(
&self,
extra: &ConsoleReqExtra<'_>,
creds: &ClientCredentials<'_>,
creds: &ClientCredentials,
) -> Result<Option<AuthInfo>, errors::GetAuthInfoError>;
/// Wake up the compute node and return the corresponding connection info.
async fn wake_compute(
&self,
extra: &ConsoleReqExtra<'_>,
creds: &ClientCredentials<'_>,
creds: &ClientCredentials,
) -> Result<CachedNodeInfo, errors::WakeComputeError>;
}

View File

@@ -106,7 +106,7 @@ impl super::Api for Api {
async fn get_auth_info(
&self,
_extra: &ConsoleReqExtra<'_>,
creds: &ClientCredentials<'_>,
creds: &ClientCredentials,
) -> Result<Option<AuthInfo>, GetAuthInfoError> {
self.do_get_auth_info(creds).await
}
@@ -115,7 +115,7 @@ impl super::Api for Api {
async fn wake_compute(
&self,
_extra: &ConsoleReqExtra<'_>,
_creds: &ClientCredentials<'_>,
_creds: &ClientCredentials,
) -> Result<CachedNodeInfo, WakeComputeError> {
self.do_wake_compute()
.map_ok(CachedNodeInfo::new_uncached)

View File

@@ -123,7 +123,7 @@ impl super::Api for Api {
async fn get_auth_info(
&self,
extra: &ConsoleReqExtra<'_>,
creds: &ClientCredentials<'_>,
creds: &ClientCredentials,
) -> Result<Option<AuthInfo>, GetAuthInfoError> {
self.do_get_auth_info(extra, creds).await
}
@@ -132,7 +132,7 @@ impl super::Api for Api {
async fn wake_compute(
&self,
extra: &ConsoleReqExtra<'_>,
creds: &ClientCredentials<'_>,
creds: &ClientCredentials,
) -> Result<CachedNodeInfo, WakeComputeError> {
let key = creds.project().expect("impossible");

View File

@@ -1,19 +1,17 @@
use anyhow::Context;
use async_trait::async_trait;
use parking_lot::Mutex;
use pq_proto::StartupMessageParams;
use std::fmt;
use std::ops::ControlFlow;
use std::{collections::HashMap, sync::Arc};
use tokio::time;
use crate::config;
use crate::{auth, console};
use crate::{compute, config};
use super::sql_over_http::MAX_RESPONSE_SIZE;
use crate::proxy::{
can_retry_tokio_postgres_error, invalidate_cache, retry_after, try_wake,
NUM_RETRIES_WAKE_COMPUTE,
};
use crate::proxy::ConnectMechanism;
use tracing::error;
use tracing::info;
@@ -187,6 +185,27 @@ impl GlobalConnPool {
}
}
struct TokioMechanism<'a> {
conn_info: &'a ConnInfo,
}
#[async_trait]
impl ConnectMechanism for TokioMechanism<'_> {
type Connection = tokio_postgres::Client;
type ConnectError = tokio_postgres::Error;
type Error = anyhow::Error;
async fn connect_once(
&self,
node_info: &console::CachedNodeInfo,
timeout: time::Duration,
) -> Result<Self::Connection, Self::ConnectError> {
connect_to_compute_once(node_info, self.conn_info, timeout).await
}
fn update_connect_config(&self, _config: &mut compute::ConnCfg) {}
}
// Wake up the destination if needed. Code here is a bit involved because
// we reuse the code from the usual proxy and we need to prepare few structures
// that this code expects.
@@ -220,72 +239,18 @@ async fn connect_to_compute(
application_name: Some(APP_NAME),
};
let node_info = &mut creds.wake_compute(&extra).await?.expect("msg");
let node_info = creds
.wake_compute(&extra)
.await?
.context("missing cache entry from wake_compute")?;
let mut num_retries = 0;
let mut wait_duration = time::Duration::ZERO;
let mut should_wake_with_error = None;
loop {
if !wait_duration.is_zero() {
time::sleep(wait_duration).await;
}
// try wake the compute node if we have determined it's sensible to do so
if let Some(err) = should_wake_with_error.take() {
match try_wake(node_info, &extra, &creds).await {
// we can't wake up the compute node
Ok(None) => return Err(err),
// there was an error communicating with the control plane
Err(e) => return Err(e.into()),
// failed to wake up but we can continue to retry
Ok(Some(ControlFlow::Continue(()))) => {
wait_duration = retry_after(num_retries);
should_wake_with_error = Some(err);
num_retries += 1;
info!(num_retries, "retrying wake compute");
continue;
}
// successfully woke up a compute node and can break the wakeup loop
Ok(Some(ControlFlow::Break(()))) => {}
}
}
match connect_to_compute_once(node_info, conn_info).await {
Ok(res) => return Ok(res),
Err(e) => {
error!(error = ?e, "could not connect to compute node");
if !can_retry_error(&e, num_retries) {
return Err(e.into());
}
wait_duration = retry_after(num_retries);
// after the first connect failure,
// we should invalidate the cache and wake up a new compute node
if num_retries == 0 {
invalidate_cache(node_info);
should_wake_with_error = Some(e.into());
}
}
}
num_retries += 1;
info!(num_retries, "retrying connect");
}
}
fn can_retry_error(err: &tokio_postgres::Error, num_retries: u32) -> bool {
match err {
// retry all errors at least once
_ if num_retries == 0 => true,
_ if num_retries >= NUM_RETRIES_WAKE_COMPUTE => false,
err => can_retry_tokio_postgres_error(err),
}
crate::proxy::connect_to_compute(&TokioMechanism { conn_info }, node_info, &extra, &creds).await
}
async fn connect_to_compute_once(
node_info: &console::CachedNodeInfo,
conn_info: &ConnInfo,
timeout: time::Duration,
) -> Result<tokio_postgres::Client, tokio_postgres::Error> {
let mut config = (*node_info.config).clone();
@@ -294,6 +259,7 @@ async fn connect_to_compute_once(
.password(&conn_info.password)
.dbname(&conn_info.dbname)
.max_backend_message_size(MAX_RESPONSE_SIZE)
.connect_timeout(timeout)
.connect(tokio_postgres::NoTls)
.await?;

View File

@@ -11,16 +11,16 @@ use crate::{
errors::{ApiError, WakeComputeError},
messages::MetricsAuxInfo,
},
error::io_error,
stream::{PqStream, Stream},
};
use anyhow::{bail, Context};
use async_trait::async_trait;
use futures::TryFutureExt;
use hyper::StatusCode;
use metrics::{register_int_counter, register_int_counter_vec, IntCounter, IntCounterVec};
use once_cell::sync::Lazy;
use pq_proto::{BeMessage as Be, FeStartupPacket, StartupMessageParams};
use std::{error::Error, ops::ControlFlow, sync::Arc};
use std::{error::Error, io, ops::ControlFlow, sync::Arc};
use tokio::{
io::{AsyncRead, AsyncWrite, AsyncWriteExt},
time,
@@ -31,7 +31,7 @@ use utils::measured_stream::MeasuredStream;
/// Number of times we should retry the `/proxy_wake_compute` http request.
/// Retry duration is BASE_RETRY_WAIT_DURATION * 1.5^n
pub const NUM_RETRIES_WAKE_COMPUTE: u32 = 10;
const NUM_RETRIES_WAKE_COMPUTE: u32 = 10;
const BASE_RETRY_WAIT_DURATION: time::Duration = time::Duration::from_millis(100);
const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
@@ -303,18 +303,18 @@ async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
/// (e.g. the compute node's address might've changed at the wrong time).
/// Invalidate the cache entry (if any) to prevent subsequent errors.
#[tracing::instrument(name = "invalidate_cache", skip_all)]
pub fn invalidate_cache(node_info: &console::CachedNodeInfo) {
pub fn invalidate_cache(node_info: console::CachedNodeInfo) -> compute::ConnCfg {
let is_cached = node_info.cached();
if is_cached {
warn!("invalidating stalled compute node info cache entry");
node_info.invalidate();
}
let label = match is_cached {
true => "compute_cached",
false => "compute_uncached",
};
NUM_CONNECTION_FAILURES.with_label_values(&[label]).inc();
node_info.invalidate().config
}
/// Try to connect to the compute node once.
@@ -331,47 +331,68 @@ async fn connect_to_compute_once(
.await
}
enum ConnectionState<E> {
Cached(console::CachedNodeInfo),
Invalid(compute::ConnCfg, E),
}
#[async_trait]
pub trait ConnectMechanism {
type Connection;
type ConnectError;
type Error: From<Self::ConnectError>;
async fn connect_once(
&self,
node_info: &console::CachedNodeInfo,
timeout: time::Duration,
) -> Result<Self::Connection, Self::ConnectError>;
fn update_connect_config(&self, conf: &mut compute::ConnCfg);
}
pub struct TcpMechanism<'a> {
/// KV-dictionary with PostgreSQL connection params.
pub params: &'a StartupMessageParams,
}
#[async_trait]
impl ConnectMechanism for TcpMechanism<'_> {
type Connection = PostgresConnection;
type ConnectError = compute::ConnectionError;
type Error = compute::ConnectionError;
async fn connect_once(
&self,
node_info: &console::CachedNodeInfo,
timeout: time::Duration,
) -> Result<PostgresConnection, Self::Error> {
connect_to_compute_once(node_info, timeout).await
}
fn update_connect_config(&self, config: &mut compute::ConnCfg) {
config.set_startup_params(self.params);
}
}
/// Try to connect to the compute node, retrying if necessary.
/// This function might update `node_info`, so we take it by `&mut`.
#[tracing::instrument(skip_all)]
async fn connect_to_compute(
node_info: &mut console::CachedNodeInfo,
params: &StartupMessageParams,
pub async fn connect_to_compute<M: ConnectMechanism>(
mechanism: &M,
mut node_info: console::CachedNodeInfo,
extra: &console::ConsoleReqExtra<'_>,
creds: &auth::BackendType<'_, auth::ClientCredentials<'_>>,
) -> Result<PostgresConnection, compute::ConnectionError> {
) -> Result<M::Connection, M::Error>
where
M::ConnectError: ShouldRetry + std::fmt::Debug,
M::Error: From<WakeComputeError>,
{
mechanism.update_connect_config(&mut node_info.config);
let mut num_retries = 0;
let mut wait_duration = time::Duration::ZERO;
let mut should_wake_with_error = None;
let mut state = ConnectionState::<M::ConnectError>::Cached(node_info);
loop {
// Apply startup params to the (possibly, cached) compute node info.
node_info.config.set_startup_params(params);
if !wait_duration.is_zero() {
time::sleep(wait_duration).await;
}
// try wake the compute node if we have determined it's sensible to do so
if let Some(err) = should_wake_with_error.take() {
match try_wake(node_info, extra, creds).await {
// we can't wake up the compute node
Ok(None) => return Err(err),
// there was an error communicating with the control plane
Err(e) => return Err(io_error(e).into()),
// failed to wake up but we can continue to retry
Ok(Some(ControlFlow::Continue(()))) => {
wait_duration = retry_after(num_retries);
should_wake_with_error = Some(err);
num_retries += 1;
info!(num_retries, "retrying wake compute");
continue;
}
// successfully woke up a compute node and can break the wakeup loop
Ok(Some(ControlFlow::Break(()))) => {}
}
}
// Set a shorter timeout for the initial connection attempt.
//
// In case we try to connect to an outdated address that is no longer valid, the
@@ -391,29 +412,56 @@ async fn connect_to_compute(
time::Duration::from_secs(10)
};
// do this again to ensure we have username?
node_info.config.set_startup_params(params);
match state {
ConnectionState::Invalid(config, err) => {
match try_wake(&config, extra, creds).await {
// we can't wake up the compute node
Ok(None) => return Err(err.into()),
// there was an error communicating with the control plane
Err(e) => return Err(e.into()),
// failed to wake up but we can continue to retry
Ok(Some(ControlFlow::Continue(()))) => {
state = ConnectionState::Invalid(config, err);
let wait_duration = retry_after(num_retries);
num_retries += 1;
match connect_to_compute_once(node_info, timeout).await {
Ok(res) => return Ok(res),
Err(e) => {
error!(error = ?e, "could not connect to compute node");
if !can_retry_error(&e, num_retries) {
return Err(e);
info!(num_retries, "retrying wake compute");
time::sleep(wait_duration).await;
continue;
}
// successfully woke up a compute node and can break the wakeup loop
Ok(Some(ControlFlow::Break(mut node_info))) => {
mechanism.update_connect_config(&mut node_info.config);
state = ConnectionState::Cached(node_info)
}
}
wait_duration = retry_after(num_retries);
}
ConnectionState::Cached(node_info) => {
match mechanism.connect_once(&node_info, timeout).await {
Ok(res) => return Ok(res),
Err(e) => {
error!(error = ?e, "could not connect to compute node");
if !e.should_retry(num_retries) {
return Err(e.into());
}
// after the first connect failure,
// we should invalidate the cache and wake up a new compute node
if num_retries == 0 {
invalidate_cache(node_info);
should_wake_with_error = Some(e);
// after the first connect failure,
// we should invalidate the cache and wake up a new compute node
if num_retries == 0 {
state = ConnectionState::Invalid(invalidate_cache(node_info), e);
} else {
state = ConnectionState::Cached(node_info);
}
let wait_duration = retry_after(num_retries);
num_retries += 1;
info!(num_retries, "retrying wake compute");
time::sleep(wait_duration).await;
}
}
}
}
num_retries += 1;
info!(num_retries, "retrying connect");
}
}
@@ -421,11 +469,11 @@ async fn connect_to_compute(
/// * Returns Ok(Some(true)) if there was an error waking but retries are acceptable
/// * Returns Ok(Some(false)) if the wakeup succeeded
/// * Returns Ok(None) or Err(e) if there was an error
pub async fn try_wake(
node_info: &mut console::CachedNodeInfo,
async fn try_wake(
config: &compute::ConnCfg,
extra: &console::ConsoleReqExtra<'_>,
creds: &auth::BackendType<'_, auth::ClientCredentials<'_>>,
) -> Result<Option<ControlFlow<()>>, WakeComputeError> {
) -> Result<Option<ControlFlow<console::CachedNodeInfo>>, WakeComputeError> {
info!("compute node's state has likely changed; requesting a wake-up");
match creds.wake_compute(extra).await {
// retry wake if the compute was in an invalid state
@@ -435,53 +483,69 @@ pub async fn try_wake(
})) => Ok(Some(ControlFlow::Continue(()))),
// Update `node_info` and try again.
Ok(Some(mut new)) => {
new.config.reuse_password(&node_info.config);
*node_info = new;
Ok(Some(ControlFlow::Break(())))
new.config.reuse_password(config);
Ok(Some(ControlFlow::Break(new)))
}
Err(e) => Err(e),
Ok(None) => Ok(None),
}
}
fn can_retry_error(err: &compute::ConnectionError, num_retries: u32) -> bool {
match err {
// retry all errors at least once
_ if num_retries == 0 => true,
_ if num_retries >= NUM_RETRIES_WAKE_COMPUTE => false,
compute::ConnectionError::Postgres(err) => can_retry_tokio_postgres_error(err),
compute::ConnectionError::CouldNotConnect(err) => is_io_connection_err(err),
_ => false,
pub trait ShouldRetry {
fn could_retry(&self) -> bool;
fn should_retry(&self, num_retries: u32) -> bool {
match self {
// retry all errors at least once
_ if num_retries == 0 => true,
_ if num_retries >= NUM_RETRIES_WAKE_COMPUTE => false,
err => err.could_retry(),
}
}
}
pub fn can_retry_tokio_postgres_error(err: &tokio_postgres::Error) -> bool {
if let Some(io_err) = err.source().and_then(|x| x.downcast_ref()) {
is_io_connection_err(io_err)
} else if let Some(db_err) = err.source().and_then(|x| x.downcast_ref()) {
is_sql_connection_err(db_err)
} else {
false
impl ShouldRetry for io::Error {
fn could_retry(&self) -> bool {
use std::io::ErrorKind;
matches!(
self.kind(),
ErrorKind::ConnectionRefused | ErrorKind::AddrNotAvailable | ErrorKind::TimedOut
)
}
}
fn is_sql_connection_err(err: &tokio_postgres::error::DbError) -> bool {
use tokio_postgres::error::SqlState;
matches!(
err.code(),
&SqlState::CONNECTION_FAILURE
| &SqlState::CONNECTION_EXCEPTION
| &SqlState::CONNECTION_DOES_NOT_EXIST
| &SqlState::SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION,
)
impl ShouldRetry for tokio_postgres::error::DbError {
fn could_retry(&self) -> bool {
use tokio_postgres::error::SqlState;
matches!(
self.code(),
&SqlState::CONNECTION_FAILURE
| &SqlState::CONNECTION_EXCEPTION
| &SqlState::CONNECTION_DOES_NOT_EXIST
| &SqlState::SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION,
)
}
}
fn is_io_connection_err(err: &std::io::Error) -> bool {
use std::io::ErrorKind;
matches!(
err.kind(),
ErrorKind::ConnectionRefused | ErrorKind::AddrNotAvailable | ErrorKind::TimedOut
)
impl ShouldRetry for tokio_postgres::Error {
fn could_retry(&self) -> bool {
if let Some(io_err) = self.source().and_then(|x| x.downcast_ref()) {
io::Error::could_retry(io_err)
} else if let Some(db_err) = self.source().and_then(|x| x.downcast_ref()) {
tokio_postgres::error::DbError::could_retry(db_err)
} else {
false
}
}
}
impl ShouldRetry for compute::ConnectionError {
fn could_retry(&self) -> bool {
match self {
compute::ConnectionError::Postgres(err) => err.could_retry(),
compute::ConnectionError::CouldNotConnect(err) => err.could_retry(),
_ => false,
}
}
}
pub fn retry_after(num_retries: u32) -> time::Duration {
@@ -637,7 +701,8 @@ impl<S: AsyncRead + AsyncWrite + Unpin> Client<'_, S> {
node_info.allow_self_signed_compute = allow_self_signed_compute;
let mut node = connect_to_compute(&mut node_info, params, &extra, &creds)
let aux = node_info.aux.clone();
let mut node = connect_to_compute(&TcpMechanism { params }, node_info, &extra, &creds)
.or_else(|e| stream.throw_error(e))
.await?;
@@ -648,6 +713,6 @@ impl<S: AsyncRead + AsyncWrite + Unpin> Client<'_, S> {
// immediately after opening the connection.
let (stream, read_buf) = stream.into_inner();
node.stream.write_all(&read_buf).await?;
proxy_pass(stream, node.stream, &node_info.aux).await
proxy_pass(stream, node.stream, &aux).await
}
}