move h2::handshake outside of hypermechanism

This commit is contained in:
Conrad Ludgate
2025-06-24 16:45:06 +01:00
parent 16d9889a51
commit d6a5085664
8 changed files with 56 additions and 145 deletions

View File

@@ -288,6 +288,7 @@ impl ConnectInfo {
async fn connect_raw(
&self,
config: &ComputeConfig,
direct: bool,
) -> Result<(SocketAddr, MaybeTlsStream<TcpStream, RustlsStream>), TlsError> {
let timeout = config.timeout;
@@ -330,7 +331,7 @@ impl ConnectInfo {
match connect_once(&*addrs).await {
Ok((sockaddr, stream)) => Ok((
sockaddr,
tls::connect_tls(stream, self.ssl_mode, config, host).await?,
tls::connect_tls(stream, self.ssl_mode, config, host, direct).await?,
)),
Err(err) => {
warn!("couldn't connect to compute node at {host}:{port}: {err}");
@@ -373,9 +374,10 @@ impl ConnectInfo {
ctx: &RequestContext,
aux: &MetricsAuxInfo,
config: &ComputeConfig,
direct: bool,
) -> Result<ComputeConnection, ConnectionError> {
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let (socket_addr, stream) = self.connect_raw(config).await?;
let (socket_addr, stream) = self.connect_raw(config, direct).await?;
drop(pause);
tracing::Span::current().record("compute_id", tracing::field::display(&aux.compute_id));

View File

@@ -32,6 +32,7 @@ pub async fn connect_tls<S, T>(
mode: SslMode,
tls: &T,
host: &str,
direct: bool,
) -> Result<MaybeTlsStream<S, T::Stream>, TlsError>
where
S: AsyncRead + AsyncWrite + Unpin + Send,
@@ -46,7 +47,7 @@ where
SslMode::Prefer | SslMode::Require => {}
}
if !request_tls(&mut stream).await? {
if !direct && !request_tls(&mut stream).await? {
if SslMode::Require == mode {
return Err(TlsError::Required);
}

View File

@@ -222,6 +222,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
ctx,
&TcpMechanism {
locks: &config.connect_compute_locks,
direct: false,
},
&node_info,
config.wake_compute_retry_config,

View File

@@ -77,8 +77,9 @@ impl NodeInfo {
&self,
ctx: &RequestContext,
config: &ComputeConfig,
direct: bool,
) -> Result<compute::ComputeConnection, compute::ConnectionError> {
self.conn_info.connect(ctx, &self.aux, config).await
self.conn_info.connect(ctx, &self.aux, config, direct).await
}
}

View File

@@ -51,6 +51,8 @@ pub(crate) trait ConnectMechanism {
/// Connect mechanism whose `ConnectMechanism` impl takes a permit from `locks`
/// for the target host and then calls `node_info.connect(..)`, forwarding
/// `direct`, to produce a `ComputeConnection`.
pub(crate) struct TcpMechanism {
/// connect_to_compute concurrency lock
pub(crate) locks: &'static ApiLocks<Host>,
/// Forwarded through `connect` / `connect_raw` down to `tls::connect_tls`.
///
/// When `false`, `connect_tls` first calls `request_tls` on the stream (the
/// postgres-protocol TLS request) and, if that returns `false` while the
/// mode is `SslMode::Require`, fails with `TlsError::Required`.
///
/// When `true`, the `request_tls` step is skipped entirely.
/// NOTE(review): presumably the TLS handshake then starts immediately on the
/// raw stream — confirm against the rest of `connect_tls`.
pub(crate) direct: bool,
}
#[async_trait]
@@ -70,7 +72,7 @@ impl ConnectMechanism for TcpMechanism {
config: &ComputeConfig,
) -> Result<ComputeConnection, Self::Error> {
let permit = self.locks.get_permit(&node_info.conn_info.host).await?;
permit.release_result(node_info.connect(ctx, config).await)
permit.release_result(node_info.connect(ctx, config, self.direct).await)
}
}

View File

@@ -358,6 +358,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
ctx,
&TcpMechanism {
locks: &config.connect_compute_locks,
direct: false,
},
&auth::Backend::ControlPlane(cplane, creds.info.clone()),
config.wake_compute_retry_config,

View File

@@ -1,18 +1,12 @@
use std::io;
use std::net::{IpAddr, SocketAddr};
use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use ed25519_dalek::SigningKey;
use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer};
use jose_jwk::jose_b64;
use postgres_client::SocketConfig;
use postgres_client::config::SslMode;
use postgres_client::maybe_tls_stream::MaybeTlsStream;
use rand::rngs::OsRng;
use rustls::pki_types::{DnsName, ServerName};
use tokio::net::{TcpStream, lookup_host};
use tokio_rustls::TlsConnector;
use tracing::field::display;
use tracing::{debug, info};
@@ -28,18 +22,16 @@ use crate::compute::{self, ComputeConnection};
use crate::compute_ctl::{
ComputeCtlError, ExtensionInstallRequest, Privilege, SetRoleGrantsRequest,
};
use crate::config::{ComputeConfig, ProxyConfig};
use crate::config::ProxyConfig;
use crate::context::RequestContext;
use crate::control_plane::CachedNodeInfo;
use crate::control_plane::client::ApiLockError;
use crate::control_plane::errors::{GetAuthInfoError, WakeComputeError};
use crate::control_plane::locks::ApiLocks;
use crate::error::{ErrorKind, ReportableError, UserFacingError};
use crate::intern::EndpointIdInt;
use crate::proxy::connect_compute::{ConnectMechanism, TcpMechanism};
use crate::proxy::connect_compute::TcpMechanism;
use crate::proxy::retry::{CouldRetry, ShouldRetryWakeCompute};
use crate::rate_limiter::EndpointRateLimiter;
use crate::types::{EndpointId, Host, LOCAL_PROXY_SUFFIX};
use crate::types::{EndpointId, LOCAL_PROXY_SUFFIX};
pub(crate) struct PoolingBackend {
pub(crate) http_conn_pool: Arc<GlobalConnPool<Send, HttpConnPool<Send>>>,
@@ -190,6 +182,7 @@ impl PoolingBackend {
ctx,
&TcpMechanism {
locks: &self.config.connect_compute_locks,
direct: false,
},
&backend,
self.config.wake_compute_retry_config,
@@ -226,19 +219,19 @@ impl PoolingBackend {
)),
options: conn_info.user_info.options.clone(),
});
crate::proxy::connect_compute::connect_to_compute(
let connection = crate::proxy::connect_compute::connect_to_compute(
ctx,
&HyperMechanism {
conn_id,
conn_info,
pool: self.http_conn_pool.clone(),
&TcpMechanism {
locks: &self.config.connect_compute_locks,
direct: true,
},
&backend,
self.config.wake_compute_retry_config,
&self.config.connect_to_compute,
)
.await
.await?;
h2handshake(ctx, &self.http_conn_pool, &conn_info, connection, conn_id).await
}
/// Connect to postgres over localhost.
@@ -396,8 +389,6 @@ pub(crate) enum HttpConnError {
#[derive(Debug, thiserror::Error)]
pub(crate) enum LocalProxyConnError {
#[error("error with connection to local-proxy")]
Io(#[source] std::io::Error),
#[error("could not establish h2 connection")]
H2(#[from] hyper::Error),
}
@@ -468,7 +459,6 @@ impl ShouldRetryWakeCompute for HttpConnError {
impl ReportableError for LocalProxyConnError {
fn get_error_kind(&self) -> ErrorKind {
match self {
LocalProxyConnError::Io(_) => ErrorKind::Compute,
LocalProxyConnError::H2(_) => ErrorKind::Compute,
}
}
@@ -483,7 +473,6 @@ impl UserFacingError for LocalProxyConnError {
impl CouldRetry for LocalProxyConnError {
fn could_retry(&self) -> bool {
match self {
LocalProxyConnError::Io(_) => false,
LocalProxyConnError::H2(_) => false,
}
}
@@ -491,7 +480,6 @@ impl CouldRetry for LocalProxyConnError {
impl ShouldRetryWakeCompute for LocalProxyConnError {
fn should_retry_wake_compute(&self) -> bool {
match self {
LocalProxyConnError::Io(_) => false,
LocalProxyConnError::H2(_) => false,
}
}
@@ -542,130 +530,45 @@ async fn authenticate(
))
}
struct HyperMechanism {
pool: Arc<GlobalConnPool<Send, HttpConnPool<Send>>>,
conn_info: ConnInfo,
async fn h2handshake(
ctx: &RequestContext,
pool: &Arc<GlobalConnPool<Send, HttpConnPool<Send>>>,
conn_info: &ConnInfo,
compute: ComputeConnection,
conn_id: uuid::Uuid,
/// connect_to_compute concurrency lock
locks: &'static ApiLocks<Host>,
}
#[async_trait]
impl ConnectMechanism for HyperMechanism {
type Connection = http_conn_pool::Client<Send>;
type ConnectError = HttpConnError;
type Error = HttpConnError;
async fn connect_once(
&self,
ctx: &RequestContext,
node_info: &CachedNodeInfo,
config: &ComputeConfig,
) -> Result<Self::Connection, Self::ConnectError> {
let host_addr = node_info.conn_info.host_addr;
let host = &node_info.conn_info.host;
let permit = self.locks.get_permit(host).await?;
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let tls = if node_info.conn_info.ssl_mode == SslMode::Disable {
None
} else {
Some(&config.tls)
};
let port = node_info.conn_info.port;
let res = connect_http2(host_addr, host, port, config.timeout, tls).await;
drop(pause);
let (client, connection) = permit.release_result(res)?;
tracing::Span::current().record(
"compute_id",
tracing::field::display(&node_info.aux.compute_id),
);
if let Some(query_id) = ctx.get_testodrome_id() {
info!("latency={}, query_id={}", ctx.get_proxy_latency(), query_id);
}
Ok(poll_http2_client(
self.pool.clone(),
ctx,
&self.conn_info,
client,
connection,
self.conn_id,
node_info.aux.clone(),
))
}
}
async fn connect_http2(
host_addr: Option<IpAddr>,
host: &str,
port: u16,
timeout: Duration,
tls: Option<&Arc<rustls::ClientConfig>>,
) -> Result<(http_conn_pool::Send, http_conn_pool::Connect), LocalProxyConnError> {
let addrs = match host_addr {
Some(addr) => vec![SocketAddr::new(addr, port)],
None => lookup_host((host, port))
.await
.map_err(LocalProxyConnError::Io)?
.collect(),
};
let mut last_err = None;
let mut addrs = addrs.into_iter();
let stream = loop {
let Some(addr) = addrs.next() else {
return Err(last_err.unwrap_or_else(|| {
LocalProxyConnError::Io(io::Error::new(
io::ErrorKind::InvalidInput,
"could not resolve any addresses",
))
}));
};
match tokio::time::timeout(timeout, TcpStream::connect(addr)).await {
Ok(Ok(stream)) => {
stream.set_nodelay(true).map_err(LocalProxyConnError::Io)?;
break stream;
}
Ok(Err(e)) => {
last_err = Some(LocalProxyConnError::Io(e));
}
Err(e) => {
last_err = Some(LocalProxyConnError::Io(io::Error::new(
io::ErrorKind::TimedOut,
e,
)));
}
}
};
let stream = if let Some(tls) = tls {
let host = DnsName::try_from(host)
.map_err(io::Error::other)
.map_err(LocalProxyConnError::Io)?
.to_owned();
let stream = TlsConnector::from(tls.clone())
.connect(ServerName::DnsName(host), stream)
.await
.map_err(LocalProxyConnError::Io)?;
Box::pin(stream) as AsyncRW
} else {
Box::pin(stream) as AsyncRW
) -> Result<http_conn_pool::Client<Send>, HttpConnError> {
let stream = match compute.stream {
MaybeTlsStream::Raw(tcp) => Box::pin(tcp) as AsyncRW,
MaybeTlsStream::Tls(tls) => Box::into_pin(tls.0) as AsyncRW,
};
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let (client, connection) = hyper::client::conn::http2::Builder::new(TokioExecutor::new())
.timer(TokioTimer::new())
.keep_alive_interval(Duration::from_secs(20))
.keep_alive_while_idle(true)
.keep_alive_timeout(Duration::from_secs(5))
.handshake(TokioIo::new(stream))
.await?;
.await
.map_err(LocalProxyConnError::H2)?;
drop(pause);
Ok((client, connection))
tracing::Span::current().record(
"compute_id",
tracing::field::display(&compute.aux.compute_id),
);
if let Some(query_id) = ctx.get_testodrome_id() {
info!("latency={}, query_id={}", ctx.get_proxy_latency(), query_id);
}
Ok(poll_http2_client(
pool.clone(),
ctx,
conn_info,
client,
connection,
conn_id,
compute.aux,
))
}

View File

@@ -60,7 +60,7 @@ mod private {
}
}
pub struct RustlsStream<S>(Box<TlsStream<S>>);
pub struct RustlsStream<S>(pub Box<TlsStream<S>>);
impl<S> postgres_client::tls::TlsStream for RustlsStream<S>
where