Compare commits

...

5 Commits

Author SHA1 Message Date
Conrad Ludgate  c0e1e1dd74  standardise logging  2025-06-25 15:54:26 +01:00
Conrad Ludgate  010cd34635  simplify error handling  2025-06-25 15:54:26 +01:00
Conrad Ludgate  11ffe4c86c  remove unused error retries  2025-06-25 15:54:26 +01:00
Conrad Ludgate  d6a5085664  move h2::handshake outside of HyperMechanism  2025-06-25 15:54:26 +01:00
Conrad Ludgate  16d9889a51  move authenticate outside of TokioMechanism  2025-06-25 15:54:26 +01:00
17 changed files with 246 additions and 523 deletions

View File

@@ -1,12 +1,13 @@
use std::net::IpAddr;
use postgres_protocol2::message::backend::Message;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpStream;
use tokio::sync::mpsc;
use crate::client::SocketConfig;
use crate::codec::BackendMessage;
use crate::config::Host;
use crate::config::{Host, SslMode};
use crate::connect_raw::connect_raw;
use crate::connect_socket::connect_socket;
use crate::connect_tls::connect_tls;
@@ -46,13 +47,7 @@ where
{
let socket = connect_socket(host_addr, host, port, config.connect_timeout).await?;
let stream = connect_tls(socket, config.ssl_mode, tls).await?;
let RawConnection {
stream,
parameters,
delayed_notice,
process_id,
secret_key,
} = connect_raw(stream, config).await?;
let raw = connect_raw(stream, config).await?;
let socket_config = SocketConfig {
host_addr,
@@ -61,24 +56,46 @@ where
connect_timeout: config.connect_timeout,
};
let (client_tx, conn_rx) = mpsc::unbounded_channel();
let (conn_tx, client_rx) = mpsc::channel(4);
let client = Client::new(
client_tx,
client_rx,
socket_config,
config.ssl_mode,
process_id,
secret_key,
);
// delayed notices are always sent as "Async" messages.
let delayed = delayed_notice
.into_iter()
.map(|m| BackendMessage::Async(Message::NoticeResponse(m)))
.collect();
let connection = Connection::new(stream, delayed, parameters, conn_tx, conn_rx);
Ok((client, connection))
Ok(raw.into_managed_conn(socket_config, config.ssl_mode))
}
impl<S, T> RawConnection<S, T>
where
S: AsyncRead + AsyncWrite + Unpin,
T: AsyncRead + AsyncWrite + Unpin,
{
pub fn into_managed_conn(
self,
socket_config: SocketConfig,
ssl_mode: SslMode,
) -> (Client, Connection<S, T>) {
let RawConnection {
stream,
parameters,
delayed_notice,
process_id,
secret_key,
} = self;
let (client_tx, conn_rx) = mpsc::unbounded_channel();
let (conn_tx, client_rx) = mpsc::channel(4);
let client = Client::new(
client_tx,
client_rx,
socket_config,
ssl_mode,
process_id,
secret_key,
);
// delayed notices are always sent as "Async" messages.
let delayed = delayed_notice
.into_iter()
.map(|m| BackendMessage::Async(Message::NoticeResponse(m)))
.collect();
let connection = Connection::new(stream, delayed, parameters, conn_tx, conn_rx);
(client, connection)
}
}
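
The hunk above splits connection setup in two: connect_raw now only drives the postgres startup/authentication handshake, and the new RawConnection::into_managed_conn builds the channel-backed Client/Connection pair. A minimal sketch of the resulting call shape, using only items visible in this diff (stream, config and socket_config are assumed to be in scope with the types shown above):

// 1. Wire handshake only: startup + authentication.
let raw = connect_raw(stream, config).await?;
// 2. Separately, wrap the raw connection into the channel-backed pair.
//    Later commits in this PR run this step from proxy code instead.
let (client, connection) = raw.into_managed_conn(socket_config, config.ssl_mode);
// The Connection future must be polled (e.g. spawned) for the Client to make progress.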

View File

@@ -350,7 +350,7 @@ impl CancellationHandler {
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CancelClosure {
socket_addr: SocketAddr,
cancel_token: RawCancelToken,
pub cancel_token: RawCancelToken,
hostname: String, // for pg_sni router
user_info: ComputeUserInfo,
}

View File

@@ -86,6 +86,14 @@ pub(crate) enum ConnectionError {
#[error("error acquiring resource permit: {0}")]
TooManyConnectionAttempts(#[from] ApiLockError),
#[cfg(test)]
#[error("retryable: {retryable}, wakeable: {wakeable}, kind: {kind:?}")]
TestError {
retryable: bool,
wakeable: bool,
kind: crate::error::ErrorKind,
},
}
impl UserFacingError for ConnectionError {
@@ -96,6 +104,8 @@ impl UserFacingError for ConnectionError {
"Failed to acquire permit to connect to the database. Too many database connection attempts are currently ongoing.".to_owned()
}
ConnectionError::TlsError(_) => COULD_NOT_CONNECT.to_owned(),
#[cfg(test)]
ConnectionError::TestError { .. } => self.to_string(),
}
}
}
@@ -106,6 +116,8 @@ impl ReportableError for ConnectionError {
ConnectionError::TlsError(_) => crate::error::ErrorKind::Compute,
ConnectionError::WakeComputeError(e) => e.get_error_kind(),
ConnectionError::TooManyConnectionAttempts(e) => e.get_error_kind(),
#[cfg(test)]
ConnectionError::TestError { kind, .. } => *kind,
}
}
}
@@ -252,6 +264,19 @@ impl AuthInfo {
.await?;
drop(pause);
// TODO: lots of useful info but maybe we can move it elsewhere (eg traces?)
info!(
compute_id = %compute.aux.compute_id,
pid = connection.process_id,
cold_start_info = ctx.cold_start_info().as_str(),
query_id = ctx.get_testodrome_id().as_deref(),
sslmode = ?compute.ssl_mode,
"connected to compute node at {} ({}) latency={}",
compute.hostname,
compute.socket_addr,
ctx.get_proxy_latency(),
);
let RawConnection {
stream: _,
parameters,
@@ -260,8 +285,6 @@ impl AuthInfo {
secret_key,
} = connection;
tracing::Span::current().record("pid", tracing::field::display(process_id));
// NB: CancelToken is supposed to hold socket_addr, but we use connect_raw.
// Yet another reason to rework the connection establishing code.
let cancel_closure = CancelClosure::new(
@@ -288,6 +311,7 @@ impl ConnectInfo {
async fn connect_raw(
&self,
config: &ComputeConfig,
direct: bool,
) -> Result<(SocketAddr, MaybeTlsStream<TcpStream, RustlsStream>), TlsError> {
let timeout = config.timeout;
@@ -330,7 +354,7 @@ impl ConnectInfo {
match connect_once(&*addrs).await {
Ok((sockaddr, stream)) => Ok((
sockaddr,
tls::connect_tls(stream, self.ssl_mode, config, host).await?,
tls::connect_tls(stream, self.ssl_mode, config, host, direct).await?,
)),
Err(err) => {
warn!("couldn't connect to compute node at {host}:{port}: {err}");
@@ -357,7 +381,7 @@ pub struct PostgresSettings {
pub struct ComputeConnection {
/// Socket connected to a compute node.
pub stream: MaybeTlsStream<tokio::net::TcpStream, RustlsStream>,
pub stream: MaybeRustlsStream,
/// Labels for proxy's metrics.
pub aux: MetricsAuxInfo,
pub hostname: Host,
@@ -373,23 +397,12 @@ impl ConnectInfo {
ctx: &RequestContext,
aux: &MetricsAuxInfo,
config: &ComputeConfig,
direct: bool,
) -> Result<ComputeConnection, ConnectionError> {
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let (socket_addr, stream) = self.connect_raw(config).await?;
let (socket_addr, stream) = self.connect_raw(config, direct).await?;
drop(pause);
tracing::Span::current().record("compute_id", tracing::field::display(&aux.compute_id));
// TODO: lots of useful info but maybe we can move it elsewhere (eg traces?)
info!(
cold_start_info = ctx.cold_start_info().as_str(),
"connected to compute node at {} ({socket_addr}) sslmode={:?}, latency={}, query_id={}",
self.host,
self.ssl_mode,
ctx.get_proxy_latency(),
ctx.get_testodrome_id().unwrap_or_default(),
);
let connection = ComputeConnection {
stream,
socket_addr,

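The relocated info! above also standardises on structured fields instead of interpolating everything into the message string, which keeps compute_id, pid and query_id queryable in log pipelines. A sketch of the before/after shape (tracing macros; field values abbreviated):

// before: values baked into the message text
info!("connected to compute node at {} sslmode={:?}, query_id={}", host, mode, qid);
// after: one structured field per value; only prose stays in the message
info!(compute_id = %aux.compute_id, pid = process_id, sslmode = ?mode,
      "connected to compute node at {host}");
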
View File

@@ -11,8 +11,6 @@ use crate::proxy::retry::CouldRetry;
#[derive(Debug, Error)]
pub enum TlsError {
#[error(transparent)]
Dns(#[from] InvalidDnsNameError),
#[error(transparent)]
Connection(#[from] std::io::Error),
#[error("TLS required but not provided")]
@@ -22,7 +20,6 @@ pub enum TlsError {
impl CouldRetry for TlsError {
fn could_retry(&self) -> bool {
match self {
TlsError::Dns(_) => false,
TlsError::Connection(err) => err.could_retry(),
// perhaps compute didn't realise it supports TLS?
TlsError::Required => true,
@@ -35,6 +32,7 @@ pub async fn connect_tls<S, T>(
mode: SslMode,
tls: &T,
host: &str,
direct: bool,
) -> Result<MaybeTlsStream<S, T::Stream>, TlsError>
where
S: AsyncRead + AsyncWrite + Unpin + Send,
@@ -49,7 +47,7 @@ where
SslMode::Prefer | SslMode::Require => {}
}
if !request_tls(&mut stream).await? {
if !direct && !request_tls(&mut stream).await? {
if SslMode::Require == mode {
return Err(TlsError::Required);
}
@@ -57,7 +55,6 @@ where
return Ok(MaybeTlsStream::Raw(stream));
}
Ok(MaybeTlsStream::Tls(
tls.make_tls_connect(host)?.connect(stream).boxed().await?,
))
let c = tls.make_tls_connect(host).map_err(std::io::Error::other)?;
Ok(MaybeTlsStream::Tls(c.connect(stream).boxed().await?))
}
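
With direct set, the SSLRequest probe (request_tls) is skipped entirely and the TLS handshake starts on the raw socket; the local-proxy h2 path added later in this PR uses that mode. A sketch of the two call shapes (the host name and socket bindings are placeholders):

// ordinary postgres negotiation: send SSLRequest, upgrade only if the server agrees
let a = connect_tls(socket_a, SslMode::Prefer, &tls, "compute.example", false).await?;
// direct mode: no SSLRequest round-trip, TLS from the first byte
let b = connect_tls(socket_b, SslMode::Require, &tls, "compute.example", true).await?;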

View File

@@ -222,6 +222,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
ctx,
&TcpMechanism {
locks: &config.connect_compute_locks,
direct: false,
},
&node_info,
config.wake_compute_retry_config,

View File

@@ -263,7 +263,12 @@ impl NeonControlPlaneClient {
None => SslMode::Disable,
};
let host = match body.server_name {
Some(host) => host.into(),
Some(host) => {
if rustls::pki_types::DnsName::try_from_str(&host).is_err() {
return Err(WakeComputeError::BadComputeAddress(host.into_boxed_str()));
}
host.into()
}
None => host.into(),
};
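
Validating server_name up front turns an unusable hostname into an immediate BadComputeAddress instead of a later TLS failure (note the TlsError::Dns variant deleted earlier in this PR). The check in isolation, as a hypothetical helper:

fn validate_host(host: String) -> Result<String, WakeComputeError> {
    // reject names rustls cannot use for SNI / certificate verification
    if rustls::pki_types::DnsName::try_from_str(&host).is_err() {
        return Err(WakeComputeError::BadComputeAddress(host.into_boxed_str()));
    }
    Ok(host)
}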

View File

@@ -77,8 +77,9 @@ impl NodeInfo {
&self,
ctx: &RequestContext,
config: &ComputeConfig,
direct: bool,
) -> Result<compute::ComputeConnection, compute::ConnectionError> {
self.conn_info.connect(ctx, &self.aux, config).await
self.conn_info.connect(ctx, &self.aux, config, direct).await
}
}

View File

@@ -1,18 +1,15 @@
use async_trait::async_trait;
use tokio::time;
use tracing::{debug, info, warn};
use crate::compute::{self, COULD_NOT_CONNECT, ComputeConnection};
use crate::config::{ComputeConfig, RetryConfig};
use crate::context::RequestContext;
use crate::control_plane::errors::WakeComputeError;
use crate::control_plane::locks::ApiLocks;
use crate::control_plane::{self, NodeInfo};
use crate::error::ReportableError;
use crate::metrics::{
ConnectOutcome, ConnectionFailureKind, Metrics, RetriesMetricGroup, RetryType,
};
use crate::proxy::retry::{CouldRetry, ShouldRetryWakeCompute, retry_after, should_retry};
use crate::proxy::retry::{ShouldRetryWakeCompute, retry_after, should_retry};
use crate::proxy::wake_compute::{WakeComputeBackend, wake_compute};
use crate::types::Host;
@@ -35,42 +32,34 @@ pub(crate) fn invalidate_cache(node_info: control_plane::CachedNodeInfo) -> Node
node_info.invalidate()
}
#[async_trait]
pub(crate) trait ConnectMechanism {
type Connection;
type ConnectError: ReportableError;
type Error: From<Self::ConnectError>;
async fn connect_once(
&self,
ctx: &RequestContext,
node_info: &control_plane::CachedNodeInfo,
config: &ComputeConfig,
) -> Result<Self::Connection, Self::ConnectError>;
) -> Result<Self::Connection, compute::ConnectionError>;
}
pub(crate) struct TcpMechanism {
/// connect_to_compute concurrency lock
pub(crate) locks: &'static ApiLocks<Host>,
// if true, skip the postgres SSLRequest negotiation and start TLS directly.
pub(crate) direct: bool,
}
#[async_trait]
impl ConnectMechanism for TcpMechanism {
type Connection = ComputeConnection;
type ConnectError = compute::ConnectionError;
type Error = compute::ConnectionError;
#[tracing::instrument(skip_all, fields(
pid = tracing::field::Empty,
compute_id = tracing::field::Empty
))]
async fn connect_once(
&self,
ctx: &RequestContext,
node_info: &control_plane::CachedNodeInfo,
config: &ComputeConfig,
) -> Result<ComputeConnection, Self::Error> {
) -> Result<ComputeConnection, compute::ConnectionError> {
let permit = self.locks.get_permit(&node_info.conn_info.host).await?;
permit.release_result(node_info.connect(ctx, config).await)
permit.release_result(node_info.connect(ctx, config, self.direct).await)
}
}
@@ -82,11 +71,7 @@ pub(crate) async fn connect_to_compute<M: ConnectMechanism, B: WakeComputeBacken
user_info: &B,
wake_compute_retry_config: RetryConfig,
compute: &ComputeConfig,
) -> Result<M::Connection, M::Error>
where
M::ConnectError: CouldRetry + ShouldRetryWakeCompute + std::fmt::Debug,
M::Error: From<WakeComputeError>,
{
) -> Result<M::Connection, compute::ConnectionError> {
let mut num_retries = 0;
let node_info =
wake_compute(&mut num_retries, ctx, user_info, wake_compute_retry_config).await?;
@@ -120,7 +105,7 @@ where
},
num_retries.into(),
);
return Err(err.into());
return Err(err);
}
node_info
} else {
@@ -161,7 +146,7 @@ where
},
num_retries.into(),
);
return Err(e.into());
return Err(e);
}
warn!(error = ?e, num_retries, retriable = true, COULD_NOT_CONNECT);
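
With the associated error types collapsed into compute::ConnectionError, the retry loop no longer needs the CouldRetry/ShouldRetryWakeCompute trait bounds on M and can call those methods on the concrete error. A condensed sketch of the loop (control flow simplified; exact helper signatures are assumptions, and retry_config stands in for the configured RetryConfig):

loop {
    match mechanism.connect_once(ctx, &node_info, compute).await {
        Ok(conn) => return Ok(conn),
        // permanent failure: report metrics and bail out
        Err(e) if !should_retry(&e, num_retries, retry_config) => return Err(e),
        // transient failure: optionally re-wake the compute, back off, retry
        Err(e) => {
            if e.should_retry_wake_compute() {
                node_info = wake_compute(&mut num_retries, ctx, user_info, retry_config).await?;
            }
            time::sleep(retry_after(num_retries, retry_config)).await;
            num_retries += 1;
        }
    }
}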

View File

@@ -358,6 +358,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
ctx,
&TcpMechanism {
locks: &config.connect_compute_locks,
direct: false,
},
&auth::Backend::ControlPlane(cplane, creds.info.clone()),
config.wake_compute_retry_config,

View File

@@ -1,4 +1,3 @@
use std::error::Error;
use std::io;
use tokio::time;
@@ -31,85 +30,25 @@ impl CouldRetry for io::Error {
}
}
impl CouldRetry for postgres_client::error::DbError {
fn could_retry(&self) -> bool {
use postgres_client::error::SqlState;
matches!(
self.code(),
&SqlState::CONNECTION_FAILURE
| &SqlState::CONNECTION_EXCEPTION
| &SqlState::CONNECTION_DOES_NOT_EXIST
| &SqlState::SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION,
)
}
}
impl ShouldRetryWakeCompute for postgres_client::error::DbError {
fn should_retry_wake_compute(&self) -> bool {
use postgres_client::error::SqlState;
// Here are errors that happen after the user successfully authenticated to the database.
// TODO: there are pgbouncer errors that should be retried, but they are not listed here.
let non_retriable_pg_errors = matches!(
self.code(),
&SqlState::TOO_MANY_CONNECTIONS
| &SqlState::OUT_OF_MEMORY
| &SqlState::SYNTAX_ERROR
| &SqlState::T_R_SERIALIZATION_FAILURE
| &SqlState::INVALID_CATALOG_NAME
| &SqlState::INVALID_SCHEMA_NAME
| &SqlState::INVALID_PARAMETER_VALUE,
);
if non_retriable_pg_errors {
return false;
}
// PGBouncer errors that should not trigger a wake_compute retry.
if self.code() == &SqlState::PROTOCOL_VIOLATION {
// Source for the error message:
// https://github.com/pgbouncer/pgbouncer/blob/f15997fe3effe3a94ba8bcc1ea562e6117d1a131/src/client.c#L1070
return !self
.message()
.contains("no more connections allowed (max_client_conn)");
}
true
}
}
impl CouldRetry for postgres_client::Error {
fn could_retry(&self) -> bool {
if let Some(io_err) = self.source().and_then(|x| x.downcast_ref()) {
io::Error::could_retry(io_err)
} else if let Some(db_err) = self.source().and_then(|x| x.downcast_ref()) {
postgres_client::error::DbError::could_retry(db_err)
} else {
false
}
}
}
impl ShouldRetryWakeCompute for postgres_client::Error {
fn should_retry_wake_compute(&self) -> bool {
if let Some(db_err) = self.source().and_then(|x| x.downcast_ref()) {
postgres_client::error::DbError::should_retry_wake_compute(db_err)
} else {
// likely an IO error. Possibly the compute has shut down and the
// cache is stale.
true
}
}
}
impl CouldRetry for compute::ConnectionError {
fn could_retry(&self) -> bool {
match self {
compute::ConnectionError::TlsError(err) => err.could_retry(),
compute::ConnectionError::WakeComputeError(err) => err.could_retry(),
compute::ConnectionError::TooManyConnectionAttempts(_) => false,
#[cfg(test)]
compute::ConnectionError::TestError { retryable, .. } => *retryable,
}
}
}
impl ShouldRetryWakeCompute for compute::ConnectionError {
fn should_retry_wake_compute(&self) -> bool {
match self {
// the cache entry was not checked for validity
compute::ConnectionError::TooManyConnectionAttempts(_) => false,
#[cfg(test)]
compute::ConnectionError::TestError { wakeable, .. } => *wakeable,
_ => true,
}
}
@@ -120,56 +59,3 @@ pub(crate) fn retry_after(num_retries: u32, config: RetryConfig) -> time::Durati
.base_delay
.mul_f64(config.backoff_factor.powi((num_retries as i32) - 1))
}
#[cfg(test)]
mod tests {
use postgres_client::error::{DbError, SqlState};
use super::ShouldRetryWakeCompute;
#[test]
fn should_retry_wake_compute_for_db_error() {
// These SQLStates should NOT trigger a wake_compute retry.
let non_retry_states = [
SqlState::TOO_MANY_CONNECTIONS,
SqlState::OUT_OF_MEMORY,
SqlState::SYNTAX_ERROR,
SqlState::T_R_SERIALIZATION_FAILURE,
SqlState::INVALID_CATALOG_NAME,
SqlState::INVALID_SCHEMA_NAME,
SqlState::INVALID_PARAMETER_VALUE,
];
for state in non_retry_states {
let err = DbError::new_test_error(state.clone(), "oops".to_string());
assert!(
!err.should_retry_wake_compute(),
"State {state:?} unexpectedly retried"
);
}
// Errors coming from pgbouncer should not trigger a wake_compute retry
let non_retry_pgbouncer_errors = ["no more connections allowed (max_client_conn)"];
for error in non_retry_pgbouncer_errors {
let err = DbError::new_test_error(SqlState::PROTOCOL_VIOLATION, error.to_string());
assert!(
!err.should_retry_wake_compute(),
"PGBouncer error {error:?} unexpectedly retried"
);
}
// These SQLStates should trigger a wake_compute retry.
let retry_states = [
SqlState::CONNECTION_FAILURE,
SqlState::CONNECTION_EXCEPTION,
SqlState::CONNECTION_DOES_NOT_EXIST,
SqlState::SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION,
];
for state in retry_states {
let err = DbError::new_test_error(state.clone(), "oops".to_string());
assert!(
err.should_retry_wake_compute(),
"State {state:?} unexpectedly skipped retry"
);
}
}
}
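
For reference, retry_after above is plain exponential backoff: base_delay * backoff_factor^(n - 1). A worked example with illustrative config values (using std::time::Duration):

// base_delay = 100ms, backoff_factor = 2.0  =>  100ms, 200ms, 400ms, ...
let num_retries: u32 = 3;
let delay = Duration::from_millis(100).mul_f64(2.0f64.powi(num_retries as i32 - 1));
assert_eq!(delay, Duration::from_millis(400));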

View File

@@ -10,7 +10,7 @@ use async_trait::async_trait;
use http::StatusCode;
use postgres_client::config::SslMode;
use postgres_client::tls::{MakeTlsConnect, NoTls};
use retry::{ShouldRetryWakeCompute, retry_after};
use retry::retry_after;
use rstest::rstest;
use rustls::crypto::ring;
use rustls::pki_types;
@@ -20,6 +20,7 @@ use tracing_test::traced_test;
use super::retry::CouldRetry;
use super::*;
use crate::auth::backend::{ComputeUserInfo, MaybeOwned};
use crate::compute::ConnectionError;
use crate::config::{ComputeConfig, RetryConfig};
use crate::control_plane::client::{ControlPlaneClient, TestControlPlaneClient};
use crate::control_plane::messages::{ControlPlaneErrorMessage, Details, MetricsAuxInfo, Status};
@@ -423,71 +424,36 @@ impl TestConnectMechanism {
#[derive(Debug)]
struct TestConnection;
#[derive(Debug)]
struct TestConnectError {
retryable: bool,
wakeable: bool,
kind: crate::error::ErrorKind,
}
impl ReportableError for TestConnectError {
fn get_error_kind(&self) -> crate::error::ErrorKind {
self.kind
}
}
impl std::fmt::Display for TestConnectError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{self:?}")
}
}
impl std::error::Error for TestConnectError {}
impl CouldRetry for TestConnectError {
fn could_retry(&self) -> bool {
self.retryable
}
}
impl ShouldRetryWakeCompute for TestConnectError {
fn should_retry_wake_compute(&self) -> bool {
self.wakeable
}
}
#[async_trait]
impl ConnectMechanism for TestConnectMechanism {
type Connection = TestConnection;
type ConnectError = TestConnectError;
type Error = anyhow::Error;
async fn connect_once(
&self,
_ctx: &RequestContext,
_node_info: &control_plane::CachedNodeInfo,
_config: &ComputeConfig,
) -> Result<Self::Connection, Self::ConnectError> {
) -> Result<Self::Connection, ConnectionError> {
let mut counter = self.counter.lock().unwrap();
let action = self.sequence[*counter];
*counter += 1;
match action {
ConnectAction::Connect => Ok(TestConnection),
ConnectAction::Retry => Err(TestConnectError {
ConnectAction::Retry => Err(ConnectionError::TestError {
retryable: true,
wakeable: true,
kind: ErrorKind::Compute,
}),
ConnectAction::RetryNoWake => Err(TestConnectError {
ConnectAction::RetryNoWake => Err(ConnectionError::TestError {
retryable: true,
wakeable: false,
kind: ErrorKind::Compute,
}),
ConnectAction::Fail => Err(TestConnectError {
ConnectAction::Fail => Err(ConnectionError::TestError {
retryable: false,
wakeable: true,
kind: ErrorKind::Compute,
}),
ConnectAction::FailNoWake => Err(TestConnectError {
ConnectAction::FailNoWake => Err(ConnectionError::TestError {
retryable: false,
wakeable: false,
kind: ErrorKind::Compute,

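Folding the test error into ConnectionError via the #[cfg(test)] TestError variant (added in the compute.rs hunk above) removes the standalone TestConnectError type and its trait impls. What a scripted action now exercises, in isolation:

// retryable but not wakeable: the loop should back off and reconnect
// without invalidating the cache or calling wake_compute again
let err = ConnectionError::TestError {
    retryable: true,
    wakeable: false,
    kind: crate::error::ErrorKind::Compute,
};
assert!(err.could_retry());
assert!(!err.should_retry_wake_compute());
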
View File

@@ -1,17 +1,12 @@
use std::io;
use std::net::{IpAddr, SocketAddr};
use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use ed25519_dalek::SigningKey;
use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer};
use jose_jwk::jose_b64;
use postgres_client::config::SslMode;
use postgres_client::SocketConfig;
use postgres_client::maybe_tls_stream::MaybeTlsStream;
use rand::rngs::OsRng;
use rustls::pki_types::{DnsName, ServerName};
use tokio::net::{TcpStream, lookup_host};
use tokio_rustls::TlsConnector;
use tracing::field::display;
use tracing::{debug, info};
@@ -23,21 +18,19 @@ use super::local_conn_pool::{self, EXT_NAME, EXT_SCHEMA, EXT_VERSION, LocalConnP
use crate::auth::backend::local::StaticAuthRules;
use crate::auth::backend::{ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo};
use crate::auth::{self, AuthError};
use crate::compute::{self, ComputeConnection};
use crate::compute_ctl::{
ComputeCtlError, ExtensionInstallRequest, Privilege, SetRoleGrantsRequest,
};
use crate::config::{ComputeConfig, ProxyConfig};
use crate::config::ProxyConfig;
use crate::context::RequestContext;
use crate::control_plane::CachedNodeInfo;
use crate::control_plane::client::ApiLockError;
use crate::control_plane::errors::{GetAuthInfoError, WakeComputeError};
use crate::control_plane::locks::ApiLocks;
use crate::error::{ErrorKind, ReportableError, UserFacingError};
use crate::intern::EndpointIdInt;
use crate::proxy::connect_compute::ConnectMechanism;
use crate::proxy::retry::{CouldRetry, ShouldRetryWakeCompute};
use crate::proxy::connect_compute::TcpMechanism;
use crate::rate_limiter::EndpointRateLimiter;
use crate::types::{EndpointId, Host, LOCAL_PROXY_SUFFIX};
use crate::types::{EndpointId, LOCAL_PROXY_SUFFIX};
pub(crate) struct PoolingBackend {
pub(crate) http_conn_pool: Arc<GlobalConnPool<Send, HttpConnPool<Send>>>,
@@ -157,11 +150,6 @@ impl PoolingBackend {
// Wake up the destination if needed. Code here is a bit involved because
// we reuse the code from the usual proxy and we need to prepare a few structures
// that this code expects.
#[tracing::instrument(skip_all, fields(
pid = tracing::field::Empty,
compute_id = tracing::field::Empty,
conn_id = tracing::field::Empty,
))]
pub(crate) async fn connect_to_compute(
&self,
ctx: &RequestContext,
@@ -181,30 +169,24 @@ impl PoolingBackend {
return Ok(client);
}
let conn_id = uuid::Uuid::new_v4();
tracing::Span::current().record("conn_id", display(conn_id));
info!(%conn_id, "pool: opening a new connection '{conn_info}'");
let backend = self.auth_backend.as_ref().map(|()| keys.info);
crate::proxy::connect_compute::connect_to_compute(
let connection = crate::proxy::connect_compute::connect_to_compute(
ctx,
&TokioMechanism {
conn_id,
conn_info,
pool: self.pool.clone(),
&TcpMechanism {
locks: &self.config.connect_compute_locks,
keys: keys.keys,
direct: false,
},
&backend,
self.config.wake_compute_retry_config,
&self.config.connect_to_compute,
)
.await
.await?;
authenticate(ctx, &self.pool, &conn_info, keys.keys, connection, conn_id).await
}
// Wake up the destination if needed
#[tracing::instrument(skip_all, fields(
compute_id = tracing::field::Empty,
conn_id = tracing::field::Empty,
))]
pub(crate) async fn connect_to_local_proxy(
&self,
ctx: &RequestContext,
@@ -216,7 +198,6 @@ impl PoolingBackend {
}
let conn_id = uuid::Uuid::new_v4();
tracing::Span::current().record("conn_id", display(conn_id));
debug!(%conn_id, "pool: opening a new connection '{conn_info}'");
let backend = self.auth_backend.as_ref().map(|()| ComputeUserInfo {
user: conn_info.user_info.user.clone(),
@@ -226,19 +207,19 @@ impl PoolingBackend {
)),
options: conn_info.user_info.options.clone(),
});
crate::proxy::connect_compute::connect_to_compute(
let connection = crate::proxy::connect_compute::connect_to_compute(
ctx,
&HyperMechanism {
conn_id,
conn_info,
pool: self.http_conn_pool.clone(),
&TcpMechanism {
locks: &self.config.connect_compute_locks,
direct: true,
},
&backend,
self.config.wake_compute_retry_config,
&self.config.connect_to_compute,
)
.await
.await?;
h2handshake(ctx, &self.http_conn_pool, &conn_info, connection, conn_id).await
}
/// Connect to postgres over localhost.
@@ -248,10 +229,6 @@ impl PoolingBackend {
/// # Panics
///
/// Panics if called with a non-local_proxy backend.
#[tracing::instrument(skip_all, fields(
pid = tracing::field::Empty,
conn_id = tracing::field::Empty,
))]
pub(crate) async fn connect_to_local_postgres(
&self,
ctx: &RequestContext,
@@ -373,6 +350,8 @@ fn create_random_jwk() -> (SigningKey, jose_jwk::Key) {
pub(crate) enum HttpConnError {
#[error("pooled connection closed at inconsistent state")]
ConnectionClosedAbruptly(#[from] tokio::sync::watch::error::SendError<uuid::Uuid>),
#[error("could not connect to compute")]
ConnectError(#[from] compute::ConnectionError),
#[error("could not connect to postgres in compute")]
PostgresConnectionError(#[from] postgres_client::Error),
#[error("could not connect to local-proxy in compute")]
@@ -394,8 +373,6 @@ pub(crate) enum HttpConnError {
#[derive(Debug, thiserror::Error)]
pub(crate) enum LocalProxyConnError {
#[error("error with connection to local-proxy")]
Io(#[source] std::io::Error),
#[error("could not establish h2 connection")]
H2(#[from] hyper::Error),
}
@@ -403,6 +380,7 @@ pub(crate) enum LocalProxyConnError {
impl ReportableError for HttpConnError {
fn get_error_kind(&self) -> ErrorKind {
match self {
HttpConnError::ConnectError(_) => ErrorKind::Compute,
HttpConnError::ConnectionClosedAbruptly(_) => ErrorKind::Compute,
HttpConnError::PostgresConnectionError(p) => p.get_error_kind(),
HttpConnError::LocalProxyConnectionError(_) => ErrorKind::Compute,
@@ -419,6 +397,7 @@ impl ReportableError for HttpConnError {
impl UserFacingError for HttpConnError {
fn to_string_client(&self) -> String {
match self {
HttpConnError::ConnectError(p) => p.to_string_client(),
HttpConnError::ConnectionClosedAbruptly(_) => self.to_string(),
HttpConnError::PostgresConnectionError(p) => p.to_string(),
HttpConnError::LocalProxyConnectionError(p) => p.to_string(),
@@ -434,36 +413,9 @@ impl UserFacingError for HttpConnError {
}
}
impl CouldRetry for HttpConnError {
fn could_retry(&self) -> bool {
match self {
HttpConnError::PostgresConnectionError(e) => e.could_retry(),
HttpConnError::LocalProxyConnectionError(e) => e.could_retry(),
HttpConnError::ComputeCtl(_) => false,
HttpConnError::ConnectionClosedAbruptly(_) => false,
HttpConnError::JwtPayloadError(_) => false,
HttpConnError::GetAuthInfo(_) => false,
HttpConnError::AuthError(_) => false,
HttpConnError::WakeCompute(_) => false,
HttpConnError::TooManyConnectionAttempts(_) => false,
}
}
}
impl ShouldRetryWakeCompute for HttpConnError {
fn should_retry_wake_compute(&self) -> bool {
match self {
HttpConnError::PostgresConnectionError(e) => e.should_retry_wake_compute(),
// we never checked cache validity
HttpConnError::TooManyConnectionAttempts(_) => false,
_ => true,
}
}
}
impl ReportableError for LocalProxyConnError {
fn get_error_kind(&self) -> ErrorKind {
match self {
LocalProxyConnError::Io(_) => ErrorKind::Compute,
LocalProxyConnError::H2(_) => ErrorKind::Compute,
}
}
@@ -475,208 +427,106 @@ impl UserFacingError for LocalProxyConnError {
}
}
impl CouldRetry for LocalProxyConnError {
fn could_retry(&self) -> bool {
match self {
LocalProxyConnError::Io(_) => false,
LocalProxyConnError::H2(_) => false,
}
}
}
impl ShouldRetryWakeCompute for LocalProxyConnError {
fn should_retry_wake_compute(&self) -> bool {
match self {
LocalProxyConnError::Io(_) => false,
LocalProxyConnError::H2(_) => false,
}
}
}
struct TokioMechanism {
pool: Arc<GlobalConnPool<postgres_client::Client, EndpointConnPool<postgres_client::Client>>>,
conn_info: ConnInfo,
conn_id: uuid::Uuid,
async fn authenticate(
ctx: &RequestContext,
pool: &Arc<GlobalConnPool<postgres_client::Client, EndpointConnPool<postgres_client::Client>>>,
conn_info: &ConnInfo,
keys: ComputeCredentialKeys,
/// connect_to_compute concurrency lock
locks: &'static ApiLocks<Host>,
}
#[async_trait]
impl ConnectMechanism for TokioMechanism {
type Connection = Client<postgres_client::Client>;
type ConnectError = HttpConnError;
type Error = HttpConnError;
async fn connect_once(
&self,
ctx: &RequestContext,
node_info: &CachedNodeInfo,
compute_config: &ComputeConfig,
) -> Result<Self::Connection, Self::ConnectError> {
let permit = self.locks.get_permit(&node_info.conn_info.host).await?;
let mut config = node_info.conn_info.to_postgres_client_config();
let config = config
.user(&self.conn_info.user_info.user)
.dbname(&self.conn_info.dbname)
.connect_timeout(compute_config.timeout);
if let ComputeCredentialKeys::AuthKeys(auth_keys) = self.keys {
config.auth_keys(auth_keys);
}
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let res = config.connect(compute_config).await;
drop(pause);
let (client, connection) = permit.release_result(res)?;
tracing::Span::current().record("pid", tracing::field::display(client.get_process_id()));
tracing::Span::current().record(
"compute_id",
tracing::field::display(&node_info.aux.compute_id),
);
if let Some(query_id) = ctx.get_testodrome_id() {
info!("latency={}, query_id={}", ctx.get_proxy_latency(), query_id);
}
Ok(poll_client(
self.pool.clone(),
ctx,
self.conn_info.clone(),
client,
connection,
self.conn_id,
node_info.aux.clone(),
))
}
}
struct HyperMechanism {
pool: Arc<GlobalConnPool<Send, HttpConnPool<Send>>>,
conn_info: ConnInfo,
compute: ComputeConnection,
conn_id: uuid::Uuid,
) -> Result<Client<postgres_client::Client>, HttpConnError> {
// client config with stubbed connect info.
let mut config = postgres_client::Config::new(String::new(), 0);
config
.user(&conn_info.user_info.user)
.dbname(&conn_info.dbname);
/// connect_to_compute concurrency lock
locks: &'static ApiLocks<Host>,
}
#[async_trait]
impl ConnectMechanism for HyperMechanism {
type Connection = http_conn_pool::Client<Send>;
type ConnectError = HttpConnError;
type Error = HttpConnError;
async fn connect_once(
&self,
ctx: &RequestContext,
node_info: &CachedNodeInfo,
config: &ComputeConfig,
) -> Result<Self::Connection, Self::ConnectError> {
let host_addr = node_info.conn_info.host_addr;
let host = &node_info.conn_info.host;
let permit = self.locks.get_permit(host).await?;
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let tls = if node_info.conn_info.ssl_mode == SslMode::Disable {
None
} else {
Some(&config.tls)
};
let port = node_info.conn_info.port;
let res = connect_http2(host_addr, host, port, config.timeout, tls).await;
drop(pause);
let (client, connection) = permit.release_result(res)?;
tracing::Span::current().record(
"compute_id",
tracing::field::display(&node_info.aux.compute_id),
);
if let Some(query_id) = ctx.get_testodrome_id() {
info!("latency={}, query_id={}", ctx.get_proxy_latency(), query_id);
}
Ok(poll_http2_client(
self.pool.clone(),
ctx,
&self.conn_info,
client,
connection,
self.conn_id,
node_info.aux.clone(),
))
if let ComputeCredentialKeys::AuthKeys(auth_keys) = keys {
config.auth_keys(auth_keys);
}
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let connection = config.authenticate(compute.stream).await?;
drop(pause);
// TODO: lots of useful info but maybe we can move it elsewhere (eg traces?)
info!(
compute_id = %compute.aux.compute_id,
pid = connection.process_id,
cold_start_info = ctx.cold_start_info().as_str(),
query_id = ctx.get_testodrome_id().as_deref(),
sslmode = ?compute.ssl_mode,
%conn_id,
"connected to compute node at {} ({}) latency={}",
compute.hostname,
compute.socket_addr,
ctx.get_proxy_latency(),
);
let (client, connection) = connection.into_managed_conn(
SocketConfig {
host_addr: Some(compute.socket_addr.ip()),
host: postgres_client::config::Host::Tcp(compute.hostname.to_string()),
port: compute.socket_addr.port(),
connect_timeout: None,
},
compute.ssl_mode,
);
Ok(poll_client(
pool.clone(),
ctx,
conn_info.clone(),
client,
connection,
conn_id,
compute.aux,
))
}
async fn connect_http2(
host_addr: Option<IpAddr>,
host: &str,
port: u16,
timeout: Duration,
tls: Option<&Arc<rustls::ClientConfig>>,
) -> Result<(http_conn_pool::Send, http_conn_pool::Connect), LocalProxyConnError> {
let addrs = match host_addr {
Some(addr) => vec![SocketAddr::new(addr, port)],
None => lookup_host((host, port))
.await
.map_err(LocalProxyConnError::Io)?
.collect(),
};
let mut last_err = None;
let mut addrs = addrs.into_iter();
let stream = loop {
let Some(addr) = addrs.next() else {
return Err(last_err.unwrap_or_else(|| {
LocalProxyConnError::Io(io::Error::new(
io::ErrorKind::InvalidInput,
"could not resolve any addresses",
))
}));
};
match tokio::time::timeout(timeout, TcpStream::connect(addr)).await {
Ok(Ok(stream)) => {
stream.set_nodelay(true).map_err(LocalProxyConnError::Io)?;
break stream;
}
Ok(Err(e)) => {
last_err = Some(LocalProxyConnError::Io(e));
}
Err(e) => {
last_err = Some(LocalProxyConnError::Io(io::Error::new(
io::ErrorKind::TimedOut,
e,
)));
}
}
};
let stream = if let Some(tls) = tls {
let host = DnsName::try_from(host)
.map_err(io::Error::other)
.map_err(LocalProxyConnError::Io)?
.to_owned();
let stream = TlsConnector::from(tls.clone())
.connect(ServerName::DnsName(host), stream)
.await
.map_err(LocalProxyConnError::Io)?;
Box::pin(stream) as AsyncRW
} else {
Box::pin(stream) as AsyncRW
async fn h2handshake(
ctx: &RequestContext,
pool: &Arc<GlobalConnPool<Send, HttpConnPool<Send>>>,
conn_info: &ConnInfo,
compute: ComputeConnection,
conn_id: uuid::Uuid,
) -> Result<http_conn_pool::Client<Send>, HttpConnError> {
let stream = match compute.stream {
MaybeTlsStream::Raw(tcp) => Box::pin(tcp) as AsyncRW,
MaybeTlsStream::Tls(tls) => Box::into_pin(tls.0) as AsyncRW,
};
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let (client, connection) = hyper::client::conn::http2::Builder::new(TokioExecutor::new())
.timer(TokioTimer::new())
.keep_alive_interval(Duration::from_secs(20))
.keep_alive_while_idle(true)
.keep_alive_timeout(Duration::from_secs(5))
.handshake(TokioIo::new(stream))
.await?;
.await
.map_err(LocalProxyConnError::H2)?;
drop(pause);
Ok((client, connection))
// TODO: lots of useful info but maybe we can move it elsewhere (eg traces?)
info!(
compute_id = %compute.aux.compute_id,
cold_start_info = ctx.cold_start_info().as_str(),
query_id = ctx.get_testodrome_id().as_deref(),
sslmode = ?compute.ssl_mode,
%conn_id,
"connected to compute node at {} ({}) latency={}",
compute.hostname,
compute.socket_addr,
ctx.get_proxy_latency(),
);
Ok(poll_http2_client(
pool.clone(),
ctx,
conn_info,
client,
connection,
conn_id,
compute.aux,
))
}
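
Net effect of this rewrite: the TokioMechanism and HyperMechanism connect mechanisms are gone, both serverless paths dial through the shared TcpMechanism, and the protocol-specific step runs after the retry loop rather than inside it. In outline (retry_config/compute_config stand in for the config fields used above):

// shared: TCP + TLS, with wake_compute / retry handling
let compute = connect_to_compute(ctx, &TcpMechanism { locks, direct }, &backend,
                                 retry_config, compute_config).await?;
// protocol-specific, executed exactly once per established connection:
//   postgres     -> authenticate(ctx, &pool, &conn_info, keys, compute, conn_id)
//   local-proxy  -> h2handshake(ctx, &http_pool, &conn_info, compute, conn_id)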

View File

@@ -69,7 +69,7 @@ pub(crate) fn poll_client<C: ClientInnerExt>(
let mut session_id = ctx.session_id();
let (tx, mut rx) = tokio::sync::watch::channel(session_id);
let span = info_span!(parent: None, "connection", %conn_id);
let span = info_span!(parent: None, "connection", %conn_id, pid=client.get_process_id(), compute_id=%aux.compute_id);
let cold_start_info = ctx.cold_start_info();
span.in_scope(|| {
info!(cold_start_info = cold_start_info.as_str(), %conn_info, %session_id, "new connection");

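Same logging standardisation here: pid and compute_id become span fields at creation time, so every event inside the connection span carries them without later record calls. The two equivalent tracing forms (illustrative):

// old style: declare the field empty, fill it in later
let span = info_span!(parent: None, "connection", %conn_id, pid = tracing::field::Empty);
span.record("pid", client.get_process_id());
// new style: the value is already known, so set it when the span is created
let span = info_span!(parent: None, "connection", %conn_id, pid = client.get_process_id());
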
View File

@@ -518,15 +518,14 @@ impl<C: ClientInnerExt> GlobalConnPool<C, EndpointConnPool<C>> {
info!("pool: cached connection '{conn_info}' is closed, opening a new one");
return Ok(None);
}
tracing::Span::current()
.record("conn_id", tracing::field::display(client.get_conn_id()));
tracing::Span::current().record(
"pid",
tracing::field::display(client.inner.get_process_id()),
);
debug!(
info!(
conn_id = %client.get_conn_id(),
pid = client.inner.get_process_id(),
compute_id = &*client.aux.compute_id,
cold_start_info = ColdStartInfo::HttpPoolHit.as_str(),
"pool: reusing connection '{conn_info}'"
query_id = ctx.get_testodrome_id().as_deref(),
"reusing connection: latency={}",
ctx.get_proxy_latency(),
);
match client.get_data() {

View File

@@ -6,7 +6,7 @@ use hyper::client::conn::http2;
use hyper_util::rt::{TokioExecutor, TokioIo};
use parking_lot::RwLock;
use smol_str::ToSmolStr;
use tracing::{Instrument, debug, error, info, info_span};
use tracing::{Instrument, error, info, info_span};
use super::AsyncRW;
use super::backend::HttpConnError;
@@ -115,7 +115,6 @@ impl<C: ClientInnerExt + Clone> Drop for HttpConnPool<C> {
}
impl<C: ClientInnerExt + Clone> GlobalConnPool<C, HttpConnPool<C>> {
#[expect(unused_results)]
pub(crate) fn get(
self: &Arc<Self>,
ctx: &RequestContext,
@@ -132,10 +131,13 @@ impl<C: ClientInnerExt + Clone> GlobalConnPool<C, HttpConnPool<C>> {
return result;
};
tracing::Span::current().record("conn_id", tracing::field::display(client.conn.conn_id));
debug!(
info!(
conn_id = %client.conn.conn_id,
compute_id = &*client.conn.aux.compute_id,
cold_start_info = ColdStartInfo::HttpPoolHit.as_str(),
"pool: reusing connection '{conn_info}'"
query_id = ctx.get_testodrome_id().as_deref(),
"reusing connection: latency={}",
ctx.get_proxy_latency(),
);
ctx.set_cold_start_info(ColdStartInfo::HttpPoolHit);
ctx.success();
@@ -197,7 +199,7 @@ pub(crate) fn poll_http2_client(
let conn_gauge = Metrics::get().proxy.db_connections.guard(ctx.protocol());
let session_id = ctx.session_id();
let span = info_span!(parent: None, "connection", %conn_id);
let span = info_span!(parent: None, "connection", %conn_id, compute_id=%aux.compute_id);
let cold_start_info = ctx.cold_start_info();
span.in_scope(|| {
info!(cold_start_info = cold_start_info.as_str(), %conn_info, %session_id, "new connection");
@@ -229,6 +231,8 @@ pub(crate) fn poll_http2_client(
tokio::spawn(
async move {
info!("new local proxy connection");
let _conn_gauge = conn_gauge;
let res = connection.await;
match res {

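The added let _conn_gauge = conn_gauge; is the RAII-guard pattern: the db_connections gauge guard moves into the spawned task and is only dropped, decrementing the metric, when the connection future completes:

tokio::spawn(async move {
    // hold the guard for the lifetime of the connection task
    let _conn_gauge = conn_gauge;
    let res = connection.await;
    // ... result logging as above ...
});
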
View File

@@ -30,7 +30,7 @@ use serde_json::value::RawValue;
use tokio::net::TcpStream;
use tokio::time::Instant;
use tokio_util::sync::CancellationToken;
use tracing::{Instrument, debug, error, info, info_span, warn};
use tracing::{Instrument, error, info, info_span, warn};
use super::backend::HttpConnError;
use super::conn_pool_lib::{
@@ -107,15 +107,13 @@ impl<C: ClientInnerExt> LocalConnPool<C> {
return Ok(None);
}
tracing::Span::current()
.record("conn_id", tracing::field::display(client.get_conn_id()));
tracing::Span::current().record(
"pid",
tracing::field::display(client.inner.get_process_id()),
);
debug!(
info!(
pid = client.inner.get_process_id(),
conn_id = %client.get_conn_id(),
cold_start_info = ColdStartInfo::HttpPoolHit.as_str(),
"local_pool: reusing connection '{conn_info}'"
query_id = ctx.get_testodrome_id().as_deref(),
"reusing connection: latency={}",
ctx.get_proxy_latency(),
);
match client.get_data() {

View File

@@ -60,7 +60,7 @@ mod private {
}
}
pub struct RustlsStream<S>(Box<TlsStream<S>>);
pub struct RustlsStream<S>(pub Box<TlsStream<S>>);
impl<S> postgres_client::tls::TlsStream for RustlsStream<S>
where