diff --git a/proxy/src/compute/mod.rs b/proxy/src/compute/mod.rs index 0ec53867be..4c03f422e1 100644 --- a/proxy/src/compute/mod.rs +++ b/proxy/src/compute/mod.rs @@ -288,6 +288,7 @@ impl ConnectInfo { async fn connect_raw( &self, config: &ComputeConfig, + direct: bool, ) -> Result<(SocketAddr, MaybeTlsStream), TlsError> { let timeout = config.timeout; @@ -330,7 +331,7 @@ impl ConnectInfo { match connect_once(&*addrs).await { Ok((sockaddr, stream)) => Ok(( sockaddr, - tls::connect_tls(stream, self.ssl_mode, config, host).await?, + tls::connect_tls(stream, self.ssl_mode, config, host, direct).await?, )), Err(err) => { warn!("couldn't connect to compute node at {host}:{port}: {err}"); @@ -373,9 +374,10 @@ impl ConnectInfo { ctx: &RequestContext, aux: &MetricsAuxInfo, config: &ComputeConfig, + direct: bool, ) -> Result { let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute); - let (socket_addr, stream) = self.connect_raw(config).await?; + let (socket_addr, stream) = self.connect_raw(config, direct).await?; drop(pause); tracing::Span::current().record("compute_id", tracing::field::display(&aux.compute_id)); diff --git a/proxy/src/compute/tls.rs b/proxy/src/compute/tls.rs index 780d211786..0c1cf449a0 100644 --- a/proxy/src/compute/tls.rs +++ b/proxy/src/compute/tls.rs @@ -32,6 +32,7 @@ pub async fn connect_tls( mode: SslMode, tls: &T, host: &str, + direct: bool, ) -> Result, TlsError> where S: AsyncRead + AsyncWrite + Unpin + Send, @@ -46,7 +47,7 @@ where SslMode::Prefer | SslMode::Require => {} } - if !request_tls(&mut stream).await? { + if !direct && !request_tls(&mut stream).await? { if SslMode::Require == mode { return Err(TlsError::Required); } diff --git a/proxy/src/console_redirect_proxy.rs b/proxy/src/console_redirect_proxy.rs index 113a11beab..145b841770 100644 --- a/proxy/src/console_redirect_proxy.rs +++ b/proxy/src/console_redirect_proxy.rs @@ -222,6 +222,7 @@ pub(crate) async fn handle_client( ctx, &TcpMechanism { locks: &config.connect_compute_locks, + direct: false, }, &node_info, config.wake_compute_retry_config, diff --git a/proxy/src/control_plane/mod.rs b/proxy/src/control_plane/mod.rs index a8c59dad0c..9b9de1dbd7 100644 --- a/proxy/src/control_plane/mod.rs +++ b/proxy/src/control_plane/mod.rs @@ -77,8 +77,9 @@ impl NodeInfo { &self, ctx: &RequestContext, config: &ComputeConfig, + direct: bool, ) -> Result { - self.conn_info.connect(ctx, &self.aux, config).await + self.conn_info.connect(ctx, &self.aux, config, direct).await } } diff --git a/proxy/src/proxy/connect_compute.rs b/proxy/src/proxy/connect_compute.rs index aa675a439e..36c174030f 100644 --- a/proxy/src/proxy/connect_compute.rs +++ b/proxy/src/proxy/connect_compute.rs @@ -51,6 +51,8 @@ pub(crate) trait ConnectMechanism { pub(crate) struct TcpMechanism { /// connect_to_compute concurrency lock pub(crate) locks: &'static ApiLocks, + // whether to negotiate TLS for postgres protocol. + pub(crate) direct: bool, } #[async_trait] @@ -70,7 +72,7 @@ impl ConnectMechanism for TcpMechanism { config: &ComputeConfig, ) -> Result { let permit = self.locks.get_permit(&node_info.conn_info.host).await?; - permit.release_result(node_info.connect(ctx, config).await) + permit.release_result(node_info.connect(ctx, config, self.direct).await) } } diff --git a/proxy/src/proxy/mod.rs b/proxy/src/proxy/mod.rs index 6947e07488..f7a28d770e 100644 --- a/proxy/src/proxy/mod.rs +++ b/proxy/src/proxy/mod.rs @@ -358,6 +358,7 @@ pub(crate) async fn handle_client( ctx, &TcpMechanism { locks: &config.connect_compute_locks, + direct: false, }, &auth::Backend::ControlPlane(cplane, creds.info.clone()), config.wake_compute_retry_config, diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index df5598ade8..57457c337e 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -1,18 +1,12 @@ -use std::io; -use std::net::{IpAddr, SocketAddr}; use std::sync::Arc; use std::time::Duration; -use async_trait::async_trait; use ed25519_dalek::SigningKey; use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer}; use jose_jwk::jose_b64; use postgres_client::SocketConfig; -use postgres_client::config::SslMode; +use postgres_client::maybe_tls_stream::MaybeTlsStream; use rand::rngs::OsRng; -use rustls::pki_types::{DnsName, ServerName}; -use tokio::net::{TcpStream, lookup_host}; -use tokio_rustls::TlsConnector; use tracing::field::display; use tracing::{debug, info}; @@ -28,18 +22,16 @@ use crate::compute::{self, ComputeConnection}; use crate::compute_ctl::{ ComputeCtlError, ExtensionInstallRequest, Privilege, SetRoleGrantsRequest, }; -use crate::config::{ComputeConfig, ProxyConfig}; +use crate::config::ProxyConfig; use crate::context::RequestContext; -use crate::control_plane::CachedNodeInfo; use crate::control_plane::client::ApiLockError; use crate::control_plane::errors::{GetAuthInfoError, WakeComputeError}; -use crate::control_plane::locks::ApiLocks; use crate::error::{ErrorKind, ReportableError, UserFacingError}; use crate::intern::EndpointIdInt; -use crate::proxy::connect_compute::{ConnectMechanism, TcpMechanism}; +use crate::proxy::connect_compute::TcpMechanism; use crate::proxy::retry::{CouldRetry, ShouldRetryWakeCompute}; use crate::rate_limiter::EndpointRateLimiter; -use crate::types::{EndpointId, Host, LOCAL_PROXY_SUFFIX}; +use crate::types::{EndpointId, LOCAL_PROXY_SUFFIX}; pub(crate) struct PoolingBackend { pub(crate) http_conn_pool: Arc>>, @@ -190,6 +182,7 @@ impl PoolingBackend { ctx, &TcpMechanism { locks: &self.config.connect_compute_locks, + direct: false, }, &backend, self.config.wake_compute_retry_config, @@ -226,19 +219,19 @@ impl PoolingBackend { )), options: conn_info.user_info.options.clone(), }); - crate::proxy::connect_compute::connect_to_compute( + let connection = crate::proxy::connect_compute::connect_to_compute( ctx, - &HyperMechanism { - conn_id, - conn_info, - pool: self.http_conn_pool.clone(), + &TcpMechanism { locks: &self.config.connect_compute_locks, + direct: true, }, &backend, self.config.wake_compute_retry_config, &self.config.connect_to_compute, ) - .await + .await?; + + h2handshake(ctx, &self.http_conn_pool, &conn_info, connection, conn_id).await } /// Connect to postgres over localhost. @@ -396,8 +389,6 @@ pub(crate) enum HttpConnError { #[derive(Debug, thiserror::Error)] pub(crate) enum LocalProxyConnError { - #[error("error with connection to local-proxy")] - Io(#[source] std::io::Error), #[error("could not establish h2 connection")] H2(#[from] hyper::Error), } @@ -468,7 +459,6 @@ impl ShouldRetryWakeCompute for HttpConnError { impl ReportableError for LocalProxyConnError { fn get_error_kind(&self) -> ErrorKind { match self { - LocalProxyConnError::Io(_) => ErrorKind::Compute, LocalProxyConnError::H2(_) => ErrorKind::Compute, } } @@ -483,7 +473,6 @@ impl UserFacingError for LocalProxyConnError { impl CouldRetry for LocalProxyConnError { fn could_retry(&self) -> bool { match self { - LocalProxyConnError::Io(_) => false, LocalProxyConnError::H2(_) => false, } } @@ -491,7 +480,6 @@ impl CouldRetry for LocalProxyConnError { impl ShouldRetryWakeCompute for LocalProxyConnError { fn should_retry_wake_compute(&self) -> bool { match self { - LocalProxyConnError::Io(_) => false, LocalProxyConnError::H2(_) => false, } } @@ -542,130 +530,45 @@ async fn authenticate( )) } -struct HyperMechanism { - pool: Arc>>, - conn_info: ConnInfo, +async fn h2handshake( + ctx: &RequestContext, + pool: &Arc>>, + conn_info: &ConnInfo, + compute: ComputeConnection, conn_id: uuid::Uuid, - - /// connect_to_compute concurrency lock - locks: &'static ApiLocks, -} - -#[async_trait] -impl ConnectMechanism for HyperMechanism { - type Connection = http_conn_pool::Client; - type ConnectError = HttpConnError; - type Error = HttpConnError; - - async fn connect_once( - &self, - ctx: &RequestContext, - node_info: &CachedNodeInfo, - config: &ComputeConfig, - ) -> Result { - let host_addr = node_info.conn_info.host_addr; - let host = &node_info.conn_info.host; - let permit = self.locks.get_permit(host).await?; - - let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute); - - let tls = if node_info.conn_info.ssl_mode == SslMode::Disable { - None - } else { - Some(&config.tls) - }; - - let port = node_info.conn_info.port; - let res = connect_http2(host_addr, host, port, config.timeout, tls).await; - drop(pause); - let (client, connection) = permit.release_result(res)?; - - tracing::Span::current().record( - "compute_id", - tracing::field::display(&node_info.aux.compute_id), - ); - - if let Some(query_id) = ctx.get_testodrome_id() { - info!("latency={}, query_id={}", ctx.get_proxy_latency(), query_id); - } - - Ok(poll_http2_client( - self.pool.clone(), - ctx, - &self.conn_info, - client, - connection, - self.conn_id, - node_info.aux.clone(), - )) - } -} - -async fn connect_http2( - host_addr: Option, - host: &str, - port: u16, - timeout: Duration, - tls: Option<&Arc>, -) -> Result<(http_conn_pool::Send, http_conn_pool::Connect), LocalProxyConnError> { - let addrs = match host_addr { - Some(addr) => vec![SocketAddr::new(addr, port)], - None => lookup_host((host, port)) - .await - .map_err(LocalProxyConnError::Io)? - .collect(), - }; - let mut last_err = None; - - let mut addrs = addrs.into_iter(); - let stream = loop { - let Some(addr) = addrs.next() else { - return Err(last_err.unwrap_or_else(|| { - LocalProxyConnError::Io(io::Error::new( - io::ErrorKind::InvalidInput, - "could not resolve any addresses", - )) - })); - }; - - match tokio::time::timeout(timeout, TcpStream::connect(addr)).await { - Ok(Ok(stream)) => { - stream.set_nodelay(true).map_err(LocalProxyConnError::Io)?; - break stream; - } - Ok(Err(e)) => { - last_err = Some(LocalProxyConnError::Io(e)); - } - Err(e) => { - last_err = Some(LocalProxyConnError::Io(io::Error::new( - io::ErrorKind::TimedOut, - e, - ))); - } - } - }; - - let stream = if let Some(tls) = tls { - let host = DnsName::try_from(host) - .map_err(io::Error::other) - .map_err(LocalProxyConnError::Io)? - .to_owned(); - let stream = TlsConnector::from(tls.clone()) - .connect(ServerName::DnsName(host), stream) - .await - .map_err(LocalProxyConnError::Io)?; - Box::pin(stream) as AsyncRW - } else { - Box::pin(stream) as AsyncRW +) -> Result, HttpConnError> { + let stream = match compute.stream { + MaybeTlsStream::Raw(tcp) => Box::pin(tcp) as AsyncRW, + MaybeTlsStream::Tls(tls) => Box::into_pin(tls.0) as AsyncRW, }; + let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute); let (client, connection) = hyper::client::conn::http2::Builder::new(TokioExecutor::new()) .timer(TokioTimer::new()) .keep_alive_interval(Duration::from_secs(20)) .keep_alive_while_idle(true) .keep_alive_timeout(Duration::from_secs(5)) .handshake(TokioIo::new(stream)) - .await?; + .await + .map_err(LocalProxyConnError::H2)?; + drop(pause); - Ok((client, connection)) + tracing::Span::current().record( + "compute_id", + tracing::field::display(&compute.aux.compute_id), + ); + + if let Some(query_id) = ctx.get_testodrome_id() { + info!("latency={}, query_id={}", ctx.get_proxy_latency(), query_id); + } + + Ok(poll_http2_client( + pool.clone(), + ctx, + conn_info, + client, + connection, + conn_id, + compute.aux, + )) } diff --git a/proxy/src/tls/postgres_rustls.rs b/proxy/src/tls/postgres_rustls.rs index 9269ad8a06..def3afa6af 100644 --- a/proxy/src/tls/postgres_rustls.rs +++ b/proxy/src/tls/postgres_rustls.rs @@ -60,7 +60,7 @@ mod private { } } - pub struct RustlsStream(Box>); + pub struct RustlsStream(pub Box>); impl postgres_client::tls::TlsStream for RustlsStream where