mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-10 15:02:56 +00:00
move h2::handshake outside of hypermechanism
This commit is contained in:
@@ -288,6 +288,7 @@ impl ConnectInfo {
|
||||
async fn connect_raw(
|
||||
&self,
|
||||
config: &ComputeConfig,
|
||||
direct: bool,
|
||||
) -> Result<(SocketAddr, MaybeTlsStream<TcpStream, RustlsStream>), TlsError> {
|
||||
let timeout = config.timeout;
|
||||
|
||||
@@ -330,7 +331,7 @@ impl ConnectInfo {
|
||||
match connect_once(&*addrs).await {
|
||||
Ok((sockaddr, stream)) => Ok((
|
||||
sockaddr,
|
||||
tls::connect_tls(stream, self.ssl_mode, config, host).await?,
|
||||
tls::connect_tls(stream, self.ssl_mode, config, host, direct).await?,
|
||||
)),
|
||||
Err(err) => {
|
||||
warn!("couldn't connect to compute node at {host}:{port}: {err}");
|
||||
@@ -373,9 +374,10 @@ impl ConnectInfo {
|
||||
ctx: &RequestContext,
|
||||
aux: &MetricsAuxInfo,
|
||||
config: &ComputeConfig,
|
||||
direct: bool,
|
||||
) -> Result<ComputeConnection, ConnectionError> {
|
||||
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
|
||||
let (socket_addr, stream) = self.connect_raw(config).await?;
|
||||
let (socket_addr, stream) = self.connect_raw(config, direct).await?;
|
||||
drop(pause);
|
||||
|
||||
tracing::Span::current().record("compute_id", tracing::field::display(&aux.compute_id));
|
||||
|
||||
@@ -32,6 +32,7 @@ pub async fn connect_tls<S, T>(
|
||||
mode: SslMode,
|
||||
tls: &T,
|
||||
host: &str,
|
||||
direct: bool,
|
||||
) -> Result<MaybeTlsStream<S, T::Stream>, TlsError>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin + Send,
|
||||
@@ -46,7 +47,7 @@ where
|
||||
SslMode::Prefer | SslMode::Require => {}
|
||||
}
|
||||
|
||||
if !request_tls(&mut stream).await? {
|
||||
if !direct && !request_tls(&mut stream).await? {
|
||||
if SslMode::Require == mode {
|
||||
return Err(TlsError::Required);
|
||||
}
|
||||
|
||||
@@ -222,6 +222,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
ctx,
|
||||
&TcpMechanism {
|
||||
locks: &config.connect_compute_locks,
|
||||
direct: false,
|
||||
},
|
||||
&node_info,
|
||||
config.wake_compute_retry_config,
|
||||
|
||||
@@ -77,8 +77,9 @@ impl NodeInfo {
|
||||
&self,
|
||||
ctx: &RequestContext,
|
||||
config: &ComputeConfig,
|
||||
direct: bool,
|
||||
) -> Result<compute::ComputeConnection, compute::ConnectionError> {
|
||||
self.conn_info.connect(ctx, &self.aux, config).await
|
||||
self.conn_info.connect(ctx, &self.aux, config, direct).await
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -51,6 +51,8 @@ pub(crate) trait ConnectMechanism {
|
||||
pub(crate) struct TcpMechanism {
|
||||
/// connect_to_compute concurrency lock
|
||||
pub(crate) locks: &'static ApiLocks<Host>,
|
||||
// whether to negotiate TLS for postgres protocol.
|
||||
pub(crate) direct: bool,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
@@ -70,7 +72,7 @@ impl ConnectMechanism for TcpMechanism {
|
||||
config: &ComputeConfig,
|
||||
) -> Result<ComputeConnection, Self::Error> {
|
||||
let permit = self.locks.get_permit(&node_info.conn_info.host).await?;
|
||||
permit.release_result(node_info.connect(ctx, config).await)
|
||||
permit.release_result(node_info.connect(ctx, config, self.direct).await)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -358,6 +358,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
ctx,
|
||||
&TcpMechanism {
|
||||
locks: &config.connect_compute_locks,
|
||||
direct: false,
|
||||
},
|
||||
&auth::Backend::ControlPlane(cplane, creds.info.clone()),
|
||||
config.wake_compute_retry_config,
|
||||
|
||||
@@ -1,18 +1,12 @@
|
||||
use std::io;
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use ed25519_dalek::SigningKey;
|
||||
use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer};
|
||||
use jose_jwk::jose_b64;
|
||||
use postgres_client::SocketConfig;
|
||||
use postgres_client::config::SslMode;
|
||||
use postgres_client::maybe_tls_stream::MaybeTlsStream;
|
||||
use rand::rngs::OsRng;
|
||||
use rustls::pki_types::{DnsName, ServerName};
|
||||
use tokio::net::{TcpStream, lookup_host};
|
||||
use tokio_rustls::TlsConnector;
|
||||
use tracing::field::display;
|
||||
use tracing::{debug, info};
|
||||
|
||||
@@ -28,18 +22,16 @@ use crate::compute::{self, ComputeConnection};
|
||||
use crate::compute_ctl::{
|
||||
ComputeCtlError, ExtensionInstallRequest, Privilege, SetRoleGrantsRequest,
|
||||
};
|
||||
use crate::config::{ComputeConfig, ProxyConfig};
|
||||
use crate::config::ProxyConfig;
|
||||
use crate::context::RequestContext;
|
||||
use crate::control_plane::CachedNodeInfo;
|
||||
use crate::control_plane::client::ApiLockError;
|
||||
use crate::control_plane::errors::{GetAuthInfoError, WakeComputeError};
|
||||
use crate::control_plane::locks::ApiLocks;
|
||||
use crate::error::{ErrorKind, ReportableError, UserFacingError};
|
||||
use crate::intern::EndpointIdInt;
|
||||
use crate::proxy::connect_compute::{ConnectMechanism, TcpMechanism};
|
||||
use crate::proxy::connect_compute::TcpMechanism;
|
||||
use crate::proxy::retry::{CouldRetry, ShouldRetryWakeCompute};
|
||||
use crate::rate_limiter::EndpointRateLimiter;
|
||||
use crate::types::{EndpointId, Host, LOCAL_PROXY_SUFFIX};
|
||||
use crate::types::{EndpointId, LOCAL_PROXY_SUFFIX};
|
||||
|
||||
pub(crate) struct PoolingBackend {
|
||||
pub(crate) http_conn_pool: Arc<GlobalConnPool<Send, HttpConnPool<Send>>>,
|
||||
@@ -190,6 +182,7 @@ impl PoolingBackend {
|
||||
ctx,
|
||||
&TcpMechanism {
|
||||
locks: &self.config.connect_compute_locks,
|
||||
direct: false,
|
||||
},
|
||||
&backend,
|
||||
self.config.wake_compute_retry_config,
|
||||
@@ -226,19 +219,19 @@ impl PoolingBackend {
|
||||
)),
|
||||
options: conn_info.user_info.options.clone(),
|
||||
});
|
||||
crate::proxy::connect_compute::connect_to_compute(
|
||||
let connection = crate::proxy::connect_compute::connect_to_compute(
|
||||
ctx,
|
||||
&HyperMechanism {
|
||||
conn_id,
|
||||
conn_info,
|
||||
pool: self.http_conn_pool.clone(),
|
||||
&TcpMechanism {
|
||||
locks: &self.config.connect_compute_locks,
|
||||
direct: true,
|
||||
},
|
||||
&backend,
|
||||
self.config.wake_compute_retry_config,
|
||||
&self.config.connect_to_compute,
|
||||
)
|
||||
.await
|
||||
.await?;
|
||||
|
||||
h2handshake(ctx, &self.http_conn_pool, &conn_info, connection, conn_id).await
|
||||
}
|
||||
|
||||
/// Connect to postgres over localhost.
|
||||
@@ -396,8 +389,6 @@ pub(crate) enum HttpConnError {
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub(crate) enum LocalProxyConnError {
|
||||
#[error("error with connection to local-proxy")]
|
||||
Io(#[source] std::io::Error),
|
||||
#[error("could not establish h2 connection")]
|
||||
H2(#[from] hyper::Error),
|
||||
}
|
||||
@@ -468,7 +459,6 @@ impl ShouldRetryWakeCompute for HttpConnError {
|
||||
impl ReportableError for LocalProxyConnError {
|
||||
fn get_error_kind(&self) -> ErrorKind {
|
||||
match self {
|
||||
LocalProxyConnError::Io(_) => ErrorKind::Compute,
|
||||
LocalProxyConnError::H2(_) => ErrorKind::Compute,
|
||||
}
|
||||
}
|
||||
@@ -483,7 +473,6 @@ impl UserFacingError for LocalProxyConnError {
|
||||
impl CouldRetry for LocalProxyConnError {
|
||||
fn could_retry(&self) -> bool {
|
||||
match self {
|
||||
LocalProxyConnError::Io(_) => false,
|
||||
LocalProxyConnError::H2(_) => false,
|
||||
}
|
||||
}
|
||||
@@ -491,7 +480,6 @@ impl CouldRetry for LocalProxyConnError {
|
||||
impl ShouldRetryWakeCompute for LocalProxyConnError {
|
||||
fn should_retry_wake_compute(&self) -> bool {
|
||||
match self {
|
||||
LocalProxyConnError::Io(_) => false,
|
||||
LocalProxyConnError::H2(_) => false,
|
||||
}
|
||||
}
|
||||
@@ -542,130 +530,45 @@ async fn authenticate(
|
||||
))
|
||||
}
|
||||
|
||||
struct HyperMechanism {
|
||||
pool: Arc<GlobalConnPool<Send, HttpConnPool<Send>>>,
|
||||
conn_info: ConnInfo,
|
||||
async fn h2handshake(
|
||||
ctx: &RequestContext,
|
||||
pool: &Arc<GlobalConnPool<Send, HttpConnPool<Send>>>,
|
||||
conn_info: &ConnInfo,
|
||||
compute: ComputeConnection,
|
||||
conn_id: uuid::Uuid,
|
||||
|
||||
/// connect_to_compute concurrency lock
|
||||
locks: &'static ApiLocks<Host>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ConnectMechanism for HyperMechanism {
|
||||
type Connection = http_conn_pool::Client<Send>;
|
||||
type ConnectError = HttpConnError;
|
||||
type Error = HttpConnError;
|
||||
|
||||
async fn connect_once(
|
||||
&self,
|
||||
ctx: &RequestContext,
|
||||
node_info: &CachedNodeInfo,
|
||||
config: &ComputeConfig,
|
||||
) -> Result<Self::Connection, Self::ConnectError> {
|
||||
let host_addr = node_info.conn_info.host_addr;
|
||||
let host = &node_info.conn_info.host;
|
||||
let permit = self.locks.get_permit(host).await?;
|
||||
|
||||
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
|
||||
|
||||
let tls = if node_info.conn_info.ssl_mode == SslMode::Disable {
|
||||
None
|
||||
} else {
|
||||
Some(&config.tls)
|
||||
};
|
||||
|
||||
let port = node_info.conn_info.port;
|
||||
let res = connect_http2(host_addr, host, port, config.timeout, tls).await;
|
||||
drop(pause);
|
||||
let (client, connection) = permit.release_result(res)?;
|
||||
|
||||
tracing::Span::current().record(
|
||||
"compute_id",
|
||||
tracing::field::display(&node_info.aux.compute_id),
|
||||
);
|
||||
|
||||
if let Some(query_id) = ctx.get_testodrome_id() {
|
||||
info!("latency={}, query_id={}", ctx.get_proxy_latency(), query_id);
|
||||
}
|
||||
|
||||
Ok(poll_http2_client(
|
||||
self.pool.clone(),
|
||||
ctx,
|
||||
&self.conn_info,
|
||||
client,
|
||||
connection,
|
||||
self.conn_id,
|
||||
node_info.aux.clone(),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
async fn connect_http2(
|
||||
host_addr: Option<IpAddr>,
|
||||
host: &str,
|
||||
port: u16,
|
||||
timeout: Duration,
|
||||
tls: Option<&Arc<rustls::ClientConfig>>,
|
||||
) -> Result<(http_conn_pool::Send, http_conn_pool::Connect), LocalProxyConnError> {
|
||||
let addrs = match host_addr {
|
||||
Some(addr) => vec![SocketAddr::new(addr, port)],
|
||||
None => lookup_host((host, port))
|
||||
.await
|
||||
.map_err(LocalProxyConnError::Io)?
|
||||
.collect(),
|
||||
};
|
||||
let mut last_err = None;
|
||||
|
||||
let mut addrs = addrs.into_iter();
|
||||
let stream = loop {
|
||||
let Some(addr) = addrs.next() else {
|
||||
return Err(last_err.unwrap_or_else(|| {
|
||||
LocalProxyConnError::Io(io::Error::new(
|
||||
io::ErrorKind::InvalidInput,
|
||||
"could not resolve any addresses",
|
||||
))
|
||||
}));
|
||||
};
|
||||
|
||||
match tokio::time::timeout(timeout, TcpStream::connect(addr)).await {
|
||||
Ok(Ok(stream)) => {
|
||||
stream.set_nodelay(true).map_err(LocalProxyConnError::Io)?;
|
||||
break stream;
|
||||
}
|
||||
Ok(Err(e)) => {
|
||||
last_err = Some(LocalProxyConnError::Io(e));
|
||||
}
|
||||
Err(e) => {
|
||||
last_err = Some(LocalProxyConnError::Io(io::Error::new(
|
||||
io::ErrorKind::TimedOut,
|
||||
e,
|
||||
)));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let stream = if let Some(tls) = tls {
|
||||
let host = DnsName::try_from(host)
|
||||
.map_err(io::Error::other)
|
||||
.map_err(LocalProxyConnError::Io)?
|
||||
.to_owned();
|
||||
let stream = TlsConnector::from(tls.clone())
|
||||
.connect(ServerName::DnsName(host), stream)
|
||||
.await
|
||||
.map_err(LocalProxyConnError::Io)?;
|
||||
Box::pin(stream) as AsyncRW
|
||||
} else {
|
||||
Box::pin(stream) as AsyncRW
|
||||
) -> Result<http_conn_pool::Client<Send>, HttpConnError> {
|
||||
let stream = match compute.stream {
|
||||
MaybeTlsStream::Raw(tcp) => Box::pin(tcp) as AsyncRW,
|
||||
MaybeTlsStream::Tls(tls) => Box::into_pin(tls.0) as AsyncRW,
|
||||
};
|
||||
|
||||
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
|
||||
let (client, connection) = hyper::client::conn::http2::Builder::new(TokioExecutor::new())
|
||||
.timer(TokioTimer::new())
|
||||
.keep_alive_interval(Duration::from_secs(20))
|
||||
.keep_alive_while_idle(true)
|
||||
.keep_alive_timeout(Duration::from_secs(5))
|
||||
.handshake(TokioIo::new(stream))
|
||||
.await?;
|
||||
.await
|
||||
.map_err(LocalProxyConnError::H2)?;
|
||||
drop(pause);
|
||||
|
||||
Ok((client, connection))
|
||||
tracing::Span::current().record(
|
||||
"compute_id",
|
||||
tracing::field::display(&compute.aux.compute_id),
|
||||
);
|
||||
|
||||
if let Some(query_id) = ctx.get_testodrome_id() {
|
||||
info!("latency={}, query_id={}", ctx.get_proxy_latency(), query_id);
|
||||
}
|
||||
|
||||
Ok(poll_http2_client(
|
||||
pool.clone(),
|
||||
ctx,
|
||||
conn_info,
|
||||
client,
|
||||
connection,
|
||||
conn_id,
|
||||
compute.aux,
|
||||
))
|
||||
}
|
||||
|
||||
@@ -60,7 +60,7 @@ mod private {
|
||||
}
|
||||
}
|
||||
|
||||
pub struct RustlsStream<S>(Box<TlsStream<S>>);
|
||||
pub struct RustlsStream<S>(pub Box<TlsStream<S>>);
|
||||
|
||||
impl<S> postgres_client::tls::TlsStream for RustlsStream<S>
|
||||
where
|
||||
|
||||
Reference in New Issue
Block a user