diff --git a/libs/proxy/tokio-postgres2/src/connect.rs b/libs/proxy/tokio-postgres2/src/connect.rs index 4a07eccf9a..5724405054 100644 --- a/libs/proxy/tokio-postgres2/src/connect.rs +++ b/libs/proxy/tokio-postgres2/src/connect.rs @@ -1,12 +1,13 @@ use std::net::IpAddr; use postgres_protocol2::message::backend::Message; +use tokio::io::{AsyncRead, AsyncWrite}; use tokio::net::TcpStream; use tokio::sync::mpsc; use crate::client::SocketConfig; use crate::codec::BackendMessage; -use crate::config::Host; +use crate::config::{Host, SslMode}; use crate::connect_raw::connect_raw; use crate::connect_socket::connect_socket; use crate::connect_tls::connect_tls; @@ -46,13 +47,7 @@ where { let socket = connect_socket(host_addr, host, port, config.connect_timeout).await?; let stream = connect_tls(socket, config.ssl_mode, tls).await?; - let RawConnection { - stream, - parameters, - delayed_notice, - process_id, - secret_key, - } = connect_raw(stream, config).await?; + let raw = connect_raw(stream, config).await?; let socket_config = SocketConfig { host_addr, @@ -61,24 +56,46 @@ where connect_timeout: config.connect_timeout, }; - let (client_tx, conn_rx) = mpsc::unbounded_channel(); - let (conn_tx, client_rx) = mpsc::channel(4); - let client = Client::new( - client_tx, - client_rx, - socket_config, - config.ssl_mode, - process_id, - secret_key, - ); - - // delayed notices are always sent as "Async" messages. - let delayed = delayed_notice - .into_iter() - .map(|m| BackendMessage::Async(Message::NoticeResponse(m))) - .collect(); - - let connection = Connection::new(stream, delayed, parameters, conn_tx, conn_rx); - - Ok((client, connection)) + Ok(raw.into_managed_conn(socket_config, config.ssl_mode)) +} + +impl RawConnection +where + S: AsyncRead + AsyncWrite + Unpin, + T: AsyncRead + AsyncWrite + Unpin, +{ + pub fn into_managed_conn( + self, + socket_config: SocketConfig, + ssl_mode: SslMode, + ) -> (Client, Connection) { + let RawConnection { + stream, + parameters, + delayed_notice, + process_id, + secret_key, + } = self; + + let (client_tx, conn_rx) = mpsc::unbounded_channel(); + let (conn_tx, client_rx) = mpsc::channel(4); + let client = Client::new( + client_tx, + client_rx, + socket_config, + ssl_mode, + process_id, + secret_key, + ); + + // delayed notices are always sent as "Async" messages. + let delayed = delayed_notice + .into_iter() + .map(|m| BackendMessage::Async(Message::NoticeResponse(m))) + .collect(); + + let connection = Connection::new(stream, delayed, parameters, conn_tx, conn_rx); + + (client, connection) + } } diff --git a/proxy/src/compute/mod.rs b/proxy/src/compute/mod.rs index 7fb88e6a45..0ec53867be 100644 --- a/proxy/src/compute/mod.rs +++ b/proxy/src/compute/mod.rs @@ -357,7 +357,7 @@ pub struct PostgresSettings { pub struct ComputeConnection { /// Socket connected to a compute node. - pub stream: MaybeTlsStream, + pub stream: MaybeRustlsStream, /// Labels for proxy's metrics. pub aux: MetricsAuxInfo, pub hostname: Host, diff --git a/proxy/src/compute/tls.rs b/proxy/src/compute/tls.rs index 000d75fca5..780d211786 100644 --- a/proxy/src/compute/tls.rs +++ b/proxy/src/compute/tls.rs @@ -11,8 +11,6 @@ use crate::proxy::retry::CouldRetry; #[derive(Debug, Error)] pub enum TlsError { - #[error(transparent)] - Dns(#[from] InvalidDnsNameError), #[error(transparent)] Connection(#[from] std::io::Error), #[error("TLS required but not provided")] @@ -22,7 +20,6 @@ pub enum TlsError { impl CouldRetry for TlsError { fn could_retry(&self) -> bool { match self { - TlsError::Dns(_) => false, TlsError::Connection(err) => err.could_retry(), // perhaps compute didn't realise it supports TLS? TlsError::Required => true, @@ -57,7 +54,6 @@ where return Ok(MaybeTlsStream::Raw(stream)); } - Ok(MaybeTlsStream::Tls( - tls.make_tls_connect(host)?.connect(stream).boxed().await?, - )) + let c = tls.make_tls_connect(host).map_err(std::io::Error::other)?; + Ok(MaybeTlsStream::Tls(c.connect(stream).boxed().await?)) } diff --git a/proxy/src/control_plane/client/cplane_proxy_v1.rs b/proxy/src/control_plane/client/cplane_proxy_v1.rs index 8c76d034f7..cf2afaccca 100644 --- a/proxy/src/control_plane/client/cplane_proxy_v1.rs +++ b/proxy/src/control_plane/client/cplane_proxy_v1.rs @@ -263,7 +263,12 @@ impl NeonControlPlaneClient { None => SslMode::Disable, }; let host = match body.server_name { - Some(host) => host.into(), + Some(host) => { + if rustls::pki_types::DnsName::try_from_str(&host).is_err() { + return Err(WakeComputeError::BadComputeAddress(host.into_boxed_str())); + } + host.into() + } None => host.into(), }; diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index 26269d0a6e..df5598ade8 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -7,6 +7,7 @@ use async_trait::async_trait; use ed25519_dalek::SigningKey; use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer}; use jose_jwk::jose_b64; +use postgres_client::SocketConfig; use postgres_client::config::SslMode; use rand::rngs::OsRng; use rustls::pki_types::{DnsName, ServerName}; @@ -23,6 +24,7 @@ use super::local_conn_pool::{self, EXT_NAME, EXT_SCHEMA, EXT_VERSION, LocalConnP use crate::auth::backend::local::StaticAuthRules; use crate::auth::backend::{ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo}; use crate::auth::{self, AuthError}; +use crate::compute::{self, ComputeConnection}; use crate::compute_ctl::{ ComputeCtlError, ExtensionInstallRequest, Privilege, SetRoleGrantsRequest, }; @@ -34,7 +36,7 @@ use crate::control_plane::errors::{GetAuthInfoError, WakeComputeError}; use crate::control_plane::locks::ApiLocks; use crate::error::{ErrorKind, ReportableError, UserFacingError}; use crate::intern::EndpointIdInt; -use crate::proxy::connect_compute::ConnectMechanism; +use crate::proxy::connect_compute::{ConnectMechanism, TcpMechanism}; use crate::proxy::retry::{CouldRetry, ShouldRetryWakeCompute}; use crate::rate_limiter::EndpointRateLimiter; use crate::types::{EndpointId, Host, LOCAL_PROXY_SUFFIX}; @@ -184,20 +186,18 @@ impl PoolingBackend { tracing::Span::current().record("conn_id", display(conn_id)); info!(%conn_id, "pool: opening a new connection '{conn_info}'"); let backend = self.auth_backend.as_ref().map(|()| keys.info); - crate::proxy::connect_compute::connect_to_compute( + let connection = crate::proxy::connect_compute::connect_to_compute( ctx, - &TokioMechanism { - conn_id, - conn_info, - pool: self.pool.clone(), + &TcpMechanism { locks: &self.config.connect_compute_locks, - keys: keys.keys, }, &backend, self.config.wake_compute_retry_config, &self.config.connect_to_compute, ) - .await + .await?; + + authenticate(ctx, &self.pool, &conn_info, keys.keys, connection, conn_id).await } // Wake up the destination if needed @@ -373,6 +373,8 @@ fn create_random_jwk() -> (SigningKey, jose_jwk::Key) { pub(crate) enum HttpConnError { #[error("pooled connection closed at inconsistent state")] ConnectionClosedAbruptly(#[from] tokio::sync::watch::error::SendError), + #[error("could not connect to compute")] + ConnectError(#[from] compute::ConnectionError), #[error("could not connect to postgres in compute")] PostgresConnectionError(#[from] postgres_client::Error), #[error("could not connect to local-proxy in compute")] @@ -403,6 +405,7 @@ pub(crate) enum LocalProxyConnError { impl ReportableError for HttpConnError { fn get_error_kind(&self) -> ErrorKind { match self { + HttpConnError::ConnectError(_) => ErrorKind::Compute, HttpConnError::ConnectionClosedAbruptly(_) => ErrorKind::Compute, HttpConnError::PostgresConnectionError(p) => p.get_error_kind(), HttpConnError::LocalProxyConnectionError(_) => ErrorKind::Compute, @@ -419,6 +422,7 @@ impl ReportableError for HttpConnError { impl UserFacingError for HttpConnError { fn to_string_client(&self) -> String { match self { + HttpConnError::ConnectError(p) => p.to_string_client(), HttpConnError::ConnectionClosedAbruptly(_) => self.to_string(), HttpConnError::PostgresConnectionError(p) => p.to_string(), HttpConnError::LocalProxyConnectionError(p) => p.to_string(), @@ -437,6 +441,7 @@ impl UserFacingError for HttpConnError { impl CouldRetry for HttpConnError { fn could_retry(&self) -> bool { match self { + HttpConnError::ConnectError(e) => e.could_retry(), HttpConnError::PostgresConnectionError(e) => e.could_retry(), HttpConnError::LocalProxyConnectionError(e) => e.could_retry(), HttpConnError::ComputeCtl(_) => false, @@ -492,65 +497,49 @@ impl ShouldRetryWakeCompute for LocalProxyConnError { } } -struct TokioMechanism { - pool: Arc>>, - conn_info: ConnInfo, - conn_id: uuid::Uuid, +async fn authenticate( + ctx: &RequestContext, + pool: &Arc>>, + conn_info: &ConnInfo, keys: ComputeCredentialKeys, + compute: ComputeConnection, + conn_id: uuid::Uuid, +) -> Result, HttpConnError> { + // client config with stubbed connect info. + let mut config = postgres_client::Config::new(String::new(), 0); + config + .user(&conn_info.user_info.user) + .dbname(&conn_info.dbname); - /// connect_to_compute concurrency lock - locks: &'static ApiLocks, -} - -#[async_trait] -impl ConnectMechanism for TokioMechanism { - type Connection = Client; - type ConnectError = HttpConnError; - type Error = HttpConnError; - - async fn connect_once( - &self, - ctx: &RequestContext, - node_info: &CachedNodeInfo, - compute_config: &ComputeConfig, - ) -> Result { - let permit = self.locks.get_permit(&node_info.conn_info.host).await?; - - let mut config = node_info.conn_info.to_postgres_client_config(); - let config = config - .user(&self.conn_info.user_info.user) - .dbname(&self.conn_info.dbname) - .connect_timeout(compute_config.timeout); - - if let ComputeCredentialKeys::AuthKeys(auth_keys) = self.keys { - config.auth_keys(auth_keys); - } - - let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute); - let res = config.connect(compute_config).await; - drop(pause); - let (client, connection) = permit.release_result(res)?; - - tracing::Span::current().record("pid", tracing::field::display(client.get_process_id())); - tracing::Span::current().record( - "compute_id", - tracing::field::display(&node_info.aux.compute_id), - ); - - if let Some(query_id) = ctx.get_testodrome_id() { - info!("latency={}, query_id={}", ctx.get_proxy_latency(), query_id); - } - - Ok(poll_client( - self.pool.clone(), - ctx, - self.conn_info.clone(), - client, - connection, - self.conn_id, - node_info.aux.clone(), - )) + if let ComputeCredentialKeys::AuthKeys(auth_keys) = keys { + config.auth_keys(auth_keys); } + + let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute); + let connection = config.authenticate(compute.stream).await?; + drop(pause); + + let (client, connection) = connection.into_managed_conn( + SocketConfig { + host_addr: Some(compute.socket_addr.ip()), + host: postgres_client::config::Host::Tcp(compute.hostname.to_string()), + port: compute.socket_addr.port(), + connect_timeout: None, + }, + compute.ssl_mode, + ); + + tracing::Span::current().record("pid", tracing::field::display(client.get_process_id())); + + Ok(poll_client( + pool.clone(), + ctx, + conn_info.clone(), + client, + connection, + conn_id, + compute.aux, + )) } struct HyperMechanism {