From 4d99b6ff4d1e5ab87f198421bae8bab3948c6b66 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Fri, 6 Jun 2025 11:29:55 +0100 Subject: [PATCH] [proxy] separate compute connect from compute authentication (#12145) ## Problem PGLB/Neonkeeper needs to separate the concerns of connecting to compute, and authenticating to compute. Additionally, the code within `connect_to_compute` is rather messy, spending effort on recovering the authentication info after wake_compute. ## Summary of changes Split `ConnCfg` into `ConnectInfo` and `AuthInfo`. `wake_compute` only returns `ConnectInfo` and `AuthInfo` is determined separately from the `handshake`/`authenticate` process. Additionally, `ConnectInfo::connect_raw` is in-charge or establishing the TLS connection, and the `postgres_client::Config::connect_raw` is configured to use `NoTls` which will force it to skip the TLS negotiation. This should just work. --- .../proxy/tokio-postgres2/src/cancel_query.rs | 2 +- libs/proxy/tokio-postgres2/src/config.rs | 3 +- libs/proxy/tokio-postgres2/src/connect.rs | 2 +- libs/proxy/tokio-postgres2/src/tls.rs | 4 +- proxy/src/auth/backend/classic.rs | 5 - proxy/src/auth/backend/console_redirect.rs | 59 +++-- proxy/src/auth/backend/local.rs | 10 +- proxy/src/auth/backend/mod.rs | 9 - proxy/src/auth/flow.rs | 7 - proxy/src/cancellation.rs | 7 +- proxy/src/{compute.rs => compute/mod.rs} | 202 ++++++++++-------- proxy/src/compute/tls.rs | 63 ++++++ proxy/src/console_redirect_proxy.rs | 6 +- .../control_plane/client/cplane_proxy_v1.rs | 24 +-- proxy/src/control_plane/client/mock.rs | 46 ++-- proxy/src/control_plane/mod.rs | 30 +-- proxy/src/pglb/connect_compute.rs | 47 ++-- proxy/src/pqproto.rs | 77 ++++--- proxy/src/proxy/mod.rs | 16 +- proxy/src/proxy/retry.rs | 4 +- proxy/src/proxy/tests/mod.rs | 26 ++- proxy/src/serverless/backend.rs | 39 ++-- proxy/src/serverless/conn_pool.rs | 4 +- proxy/src/tls/postgres_rustls.rs | 46 ++-- 24 files changed, 382 insertions(+), 356 deletions(-) rename proxy/src/{compute.rs => compute/mod.rs} (68%) create mode 100644 proxy/src/compute/tls.rs diff --git a/libs/proxy/tokio-postgres2/src/cancel_query.rs b/libs/proxy/tokio-postgres2/src/cancel_query.rs index 0bdad0b554..4c2a5ef50f 100644 --- a/libs/proxy/tokio-postgres2/src/cancel_query.rs +++ b/libs/proxy/tokio-postgres2/src/cancel_query.rs @@ -10,7 +10,7 @@ use crate::{Error, cancel_query_raw, connect_socket}; pub(crate) async fn cancel_query( config: Option, ssl_mode: SslMode, - mut tls: T, + tls: T, process_id: i32, secret_key: i32, ) -> Result<(), Error> diff --git a/libs/proxy/tokio-postgres2/src/config.rs b/libs/proxy/tokio-postgres2/src/config.rs index 978d348741..243a5bc725 100644 --- a/libs/proxy/tokio-postgres2/src/config.rs +++ b/libs/proxy/tokio-postgres2/src/config.rs @@ -17,7 +17,6 @@ use crate::{Client, Connection, Error}; /// TLS configuration. #[derive(Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize)] -#[non_exhaustive] pub enum SslMode { /// Do not use TLS. Disable, @@ -231,7 +230,7 @@ impl Config { /// Requires the `runtime` Cargo feature (enabled by default). pub async fn connect( &self, - tls: T, + tls: &T, ) -> Result<(Client, Connection), Error> where T: MakeTlsConnect, diff --git a/libs/proxy/tokio-postgres2/src/connect.rs b/libs/proxy/tokio-postgres2/src/connect.rs index 39a0a87c74..f7bc863337 100644 --- a/libs/proxy/tokio-postgres2/src/connect.rs +++ b/libs/proxy/tokio-postgres2/src/connect.rs @@ -13,7 +13,7 @@ use crate::tls::{MakeTlsConnect, TlsConnect}; use crate::{Client, Config, Connection, Error, RawConnection}; pub async fn connect( - mut tls: T, + tls: &T, config: &Config, ) -> Result<(Client, Connection), Error> where diff --git a/libs/proxy/tokio-postgres2/src/tls.rs b/libs/proxy/tokio-postgres2/src/tls.rs index 41b51368ff..f9cbcf4991 100644 --- a/libs/proxy/tokio-postgres2/src/tls.rs +++ b/libs/proxy/tokio-postgres2/src/tls.rs @@ -47,7 +47,7 @@ pub trait MakeTlsConnect { /// Creates a new `TlsConnect`or. /// /// The domain name is provided for certificate verification and SNI. - fn make_tls_connect(&mut self, domain: &str) -> Result; + fn make_tls_connect(&self, domain: &str) -> Result; } /// An asynchronous function wrapping a stream in a TLS session. @@ -85,7 +85,7 @@ impl MakeTlsConnect for NoTls { type TlsConnect = NoTls; type Error = NoTlsError; - fn make_tls_connect(&mut self, _: &str) -> Result { + fn make_tls_connect(&self, _: &str) -> Result { Ok(NoTls) } } diff --git a/proxy/src/auth/backend/classic.rs b/proxy/src/auth/backend/classic.rs index 8445368740..f35b3ecc05 100644 --- a/proxy/src/auth/backend/classic.rs +++ b/proxy/src/auth/backend/classic.rs @@ -18,11 +18,6 @@ pub(super) async fn authenticate( secret: AuthSecret, ) -> auth::Result { let scram_keys = match secret { - #[cfg(any(test, feature = "testing"))] - AuthSecret::Md5(_) => { - debug!("auth endpoint chooses MD5"); - return Err(auth::AuthError::MalformedPassword("MD5 not supported")); - } AuthSecret::Scram(secret) => { debug!("auth endpoint chooses SCRAM"); diff --git a/proxy/src/auth/backend/console_redirect.rs b/proxy/src/auth/backend/console_redirect.rs index c388848926..455d96c90a 100644 --- a/proxy/src/auth/backend/console_redirect.rs +++ b/proxy/src/auth/backend/console_redirect.rs @@ -6,10 +6,9 @@ use thiserror::Error; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::{info, info_span}; -use super::ComputeCredentialKeys; -use crate::auth::IpPattern; use crate::auth::backend::ComputeUserInfo; use crate::cache::Cached; +use crate::compute::AuthInfo; use crate::config::AuthenticationConfig; use crate::context::RequestContext; use crate::control_plane::client::cplane_proxy_v1; @@ -98,15 +97,11 @@ impl ConsoleRedirectBackend { ctx: &RequestContext, auth_config: &'static AuthenticationConfig, client: &mut PqStream, - ) -> auth::Result<( - ConsoleRedirectNodeInfo, - ComputeUserInfo, - Option>, - )> { + ) -> auth::Result<(ConsoleRedirectNodeInfo, AuthInfo, ComputeUserInfo)> { authenticate(ctx, auth_config, &self.console_uri, client) .await - .map(|(node_info, user_info, ip_allowlist)| { - (ConsoleRedirectNodeInfo(node_info), user_info, ip_allowlist) + .map(|(node_info, auth_info, user_info)| { + (ConsoleRedirectNodeInfo(node_info), auth_info, user_info) }) } } @@ -121,10 +116,6 @@ impl ComputeConnectBackend for ConsoleRedirectNodeInfo { ) -> Result { Ok(Cached::new_uncached(self.0.clone())) } - - fn get_keys(&self) -> &ComputeCredentialKeys { - &ComputeCredentialKeys::None - } } async fn authenticate( @@ -132,7 +123,7 @@ async fn authenticate( auth_config: &'static AuthenticationConfig, link_uri: &reqwest::Url, client: &mut PqStream, -) -> auth::Result<(NodeInfo, ComputeUserInfo, Option>)> { +) -> auth::Result<(NodeInfo, AuthInfo, ComputeUserInfo)> { ctx.set_auth_method(crate::context::AuthMethod::ConsoleRedirect); // registering waiter can fail if we get unlucky with rng. @@ -192,10 +183,24 @@ async fn authenticate( client.write_message(BeMessage::NoticeResponse("Connecting to database.")); - // This config should be self-contained, because we won't - // take username or dbname from client's startup message. - let mut config = compute::ConnCfg::new(db_info.host.to_string(), db_info.port); - config.dbname(&db_info.dbname).user(&db_info.user); + // Backwards compatibility. pg_sni_proxy uses "--" in domain names + // while direct connections do not. Once we migrate to pg_sni_proxy + // everywhere, we can remove this. + let ssl_mode = if db_info.host.contains("--") { + // we need TLS connection with SNI info to properly route it + SslMode::Require + } else { + SslMode::Disable + }; + + let conn_info = compute::ConnectInfo { + host: db_info.host.into(), + port: db_info.port, + ssl_mode, + host_addr: None, + }; + let auth_info = + AuthInfo::for_console_redirect(&db_info.dbname, &db_info.user, db_info.password.as_deref()); let user: RoleName = db_info.user.into(); let user_info = ComputeUserInfo { @@ -209,26 +214,12 @@ async fn authenticate( ctx.set_project(db_info.aux.clone()); info!("woken up a compute node"); - // Backwards compatibility. pg_sni_proxy uses "--" in domain names - // while direct connections do not. Once we migrate to pg_sni_proxy - // everywhere, we can remove this. - if db_info.host.contains("--") { - // we need TLS connection with SNI info to properly route it - config.ssl_mode(SslMode::Require); - } else { - config.ssl_mode(SslMode::Disable); - } - - if let Some(password) = db_info.password { - config.password(password.as_ref()); - } - Ok(( NodeInfo { - config, + conn_info, aux: db_info.aux, }, + auth_info, user_info, - db_info.allowed_ips, )) } diff --git a/proxy/src/auth/backend/local.rs b/proxy/src/auth/backend/local.rs index 7a6dceb194..2224f492b8 100644 --- a/proxy/src/auth/backend/local.rs +++ b/proxy/src/auth/backend/local.rs @@ -1,11 +1,12 @@ use std::net::SocketAddr; use arc_swap::ArcSwapOption; +use postgres_client::config::SslMode; use tokio::sync::Semaphore; use super::jwt::{AuthRule, FetchAuthRules}; use crate::auth::backend::jwt::FetchAuthRulesError; -use crate::compute::ConnCfg; +use crate::compute::ConnectInfo; use crate::compute_ctl::ComputeCtlApi; use crate::context::RequestContext; use crate::control_plane::NodeInfo; @@ -29,7 +30,12 @@ impl LocalBackend { api: http::Endpoint::new(compute_ctl, http::new_client()), }, node_info: NodeInfo { - config: ConnCfg::new(postgres_addr.ip().to_string(), postgres_addr.port()), + conn_info: ConnectInfo { + host_addr: Some(postgres_addr.ip()), + host: postgres_addr.ip().to_string().into(), + port: postgres_addr.port(), + ssl_mode: SslMode::Disable, + }, // TODO(conrad): make this better reflect compute info rather than endpoint info. aux: MetricsAuxInfo { endpoint_id: EndpointIdTag::get_interner().get_or_intern("local"), diff --git a/proxy/src/auth/backend/mod.rs b/proxy/src/auth/backend/mod.rs index f978f655c4..edc1ae06d9 100644 --- a/proxy/src/auth/backend/mod.rs +++ b/proxy/src/auth/backend/mod.rs @@ -168,8 +168,6 @@ impl ComputeUserInfo { #[cfg_attr(test, derive(Debug))] pub(crate) enum ComputeCredentialKeys { - #[cfg(any(test, feature = "testing"))] - Password(Vec), AuthKeys(AuthKeys), JwtPayload(Vec), None, @@ -419,13 +417,6 @@ impl ComputeConnectBackend for Backend<'_, ComputeCredentials> { Self::Local(local) => Ok(Cached::new_uncached(local.node_info.clone())), } } - - fn get_keys(&self) -> &ComputeCredentialKeys { - match self { - Self::ControlPlane(_, creds) => &creds.keys, - Self::Local(_) => &ComputeCredentialKeys::None, - } - } } #[cfg(test)] diff --git a/proxy/src/auth/flow.rs b/proxy/src/auth/flow.rs index 8fbc4577e9..c825d5bf4b 100644 --- a/proxy/src/auth/flow.rs +++ b/proxy/src/auth/flow.rs @@ -169,13 +169,6 @@ pub(crate) async fn validate_password_and_exchange( secret: AuthSecret, ) -> super::Result> { match secret { - #[cfg(any(test, feature = "testing"))] - AuthSecret::Md5(_) => { - // test only - Ok(sasl::Outcome::Success(ComputeCredentialKeys::Password( - password.to_owned(), - ))) - } // perform scram authentication as both client and server to validate the keys AuthSecret::Scram(scram_secret) => { let outcome = crate::scram::exchange(pool, endpoint, &scram_secret, password).await?; diff --git a/proxy/src/cancellation.rs b/proxy/src/cancellation.rs index d26641db46..cce4c1d3a0 100644 --- a/proxy/src/cancellation.rs +++ b/proxy/src/cancellation.rs @@ -24,7 +24,6 @@ use crate::pqproto::CancelKeyData; use crate::rate_limiter::LeakyBucketRateLimiter; use crate::redis::keys::KeyPrefix; use crate::redis::kv_ops::RedisKVClient; -use crate::tls::postgres_rustls::MakeRustlsConnect; type IpSubnetKey = IpNet; @@ -497,10 +496,8 @@ impl CancelClosure { ) -> Result<(), CancelError> { let socket = TcpStream::connect(self.socket_addr).await?; - let mut mk_tls = - crate::tls::postgres_rustls::MakeRustlsConnect::new(compute_config.tls.clone()); - let tls = >::make_tls_connect( - &mut mk_tls, + let tls = <_ as MakeTlsConnect>::make_tls_connect( + compute_config, &self.hostname, ) .map_err(|e| CancelError::IO(std::io::Error::other(e.to_string())))?; diff --git a/proxy/src/compute.rs b/proxy/src/compute/mod.rs similarity index 68% rename from proxy/src/compute.rs rename to proxy/src/compute/mod.rs index 2899f25129..0dacd15547 100644 --- a/proxy/src/compute.rs +++ b/proxy/src/compute/mod.rs @@ -1,21 +1,24 @@ +mod tls; + use std::fmt::Debug; use std::io; -use std::net::SocketAddr; -use std::time::Duration; +use std::net::{IpAddr, SocketAddr}; use futures::{FutureExt, TryFutureExt}; use itertools::Itertools; +use postgres_client::config::{AuthKeys, SslMode}; +use postgres_client::maybe_tls_stream::MaybeTlsStream; use postgres_client::tls::MakeTlsConnect; -use postgres_client::{CancelToken, RawConnection}; +use postgres_client::{CancelToken, NoTls, RawConnection}; use postgres_protocol::message::backend::NoticeResponseBody; -use rustls::pki_types::InvalidDnsNameError; use thiserror::Error; use tokio::net::{TcpStream, lookup_host}; use tracing::{debug, error, info, warn}; -use crate::auth::backend::ComputeUserInfo; +use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo}; use crate::auth::parse_endpoint_param; use crate::cancellation::CancelClosure; +use crate::compute::tls::TlsError; use crate::config::ComputeConfig; use crate::context::RequestContext; use crate::control_plane::client::ApiLockError; @@ -25,7 +28,6 @@ use crate::error::{ReportableError, UserFacingError}; use crate::metrics::{Metrics, NumDbConnectionsGuard}; use crate::pqproto::StartupMessageParams; use crate::proxy::neon_option; -use crate::tls::postgres_rustls::MakeRustlsConnect; use crate::types::Host; pub const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node"; @@ -38,10 +40,7 @@ pub(crate) enum ConnectionError { Postgres(#[from] postgres_client::Error), #[error("{COULD_NOT_CONNECT}: {0}")] - CouldNotConnect(#[from] io::Error), - - #[error("{COULD_NOT_CONNECT}: {0}")] - TlsError(#[from] InvalidDnsNameError), + TlsError(#[from] TlsError), #[error("{COULD_NOT_CONNECT}: {0}")] WakeComputeError(#[from] WakeComputeError), @@ -73,7 +72,7 @@ impl UserFacingError for ConnectionError { ConnectionError::TooManyConnectionAttempts(_) => { "Failed to acquire permit to connect to the database. Too many database connection attempts are currently ongoing.".to_owned() } - _ => COULD_NOT_CONNECT.to_owned(), + ConnectionError::TlsError(_) => COULD_NOT_CONNECT.to_owned(), } } } @@ -85,7 +84,6 @@ impl ReportableError for ConnectionError { crate::error::ErrorKind::Postgres } ConnectionError::Postgres(_) => crate::error::ErrorKind::Compute, - ConnectionError::CouldNotConnect(_) => crate::error::ErrorKind::Compute, ConnectionError::TlsError(_) => crate::error::ErrorKind::Compute, ConnectionError::WakeComputeError(e) => e.get_error_kind(), ConnectionError::TooManyConnectionAttempts(e) => e.get_error_kind(), @@ -96,34 +94,85 @@ impl ReportableError for ConnectionError { /// A pair of `ClientKey` & `ServerKey` for `SCRAM-SHA-256`. pub(crate) type ScramKeys = postgres_client::config::ScramKeys<32>; -/// A config for establishing a connection to compute node. -/// Eventually, `postgres_client` will be replaced with something better. -/// Newtype allows us to implement methods on top of it. #[derive(Clone)] -pub(crate) struct ConnCfg(Box); +pub enum Auth { + /// Only used during console-redirect. + Password(Vec), + /// Used by sql-over-http, ws, tcp. + Scram(Box), +} + +/// A config for authenticating to the compute node. +pub(crate) struct AuthInfo { + /// None for local-proxy, as we use trust-based localhost auth. + /// Some for sql-over-http, ws, tcp, and in most cases for console-redirect. + /// Might be None for console-redirect, but that's only a consequence of testing environments ATM. + auth: Option, + server_params: StartupMessageParams, + + /// Console redirect sets user and database, we shouldn't re-use those from the params. + skip_db_user: bool, +} + +/// Contains only the data needed to establish a secure connection to compute. +#[derive(Clone)] +pub struct ConnectInfo { + pub host_addr: Option, + pub host: Host, + pub port: u16, + pub ssl_mode: SslMode, +} /// Creation and initialization routines. -impl ConnCfg { - pub(crate) fn new(host: String, port: u16) -> Self { - Self(Box::new(postgres_client::Config::new(host, port))) - } - - /// Reuse password or auth keys from the other config. - pub(crate) fn reuse_password(&mut self, other: Self) { - if let Some(password) = other.get_password() { - self.password(password); - } - - if let Some(keys) = other.get_auth_keys() { - self.auth_keys(keys); +impl AuthInfo { + pub(crate) fn for_console_redirect(db: &str, user: &str, pw: Option<&str>) -> Self { + let mut server_params = StartupMessageParams::default(); + server_params.insert("database", db); + server_params.insert("user", user); + Self { + auth: pw.map(|pw| Auth::Password(pw.as_bytes().to_owned())), + server_params, + skip_db_user: true, } } - pub(crate) fn get_host(&self) -> Host { - match self.0.get_host() { - postgres_client::config::Host::Tcp(s) => s.into(), + pub(crate) fn with_auth_keys(keys: &ComputeCredentialKeys) -> Self { + Self { + auth: match keys { + ComputeCredentialKeys::AuthKeys(AuthKeys::ScramSha256(auth_keys)) => { + Some(Auth::Scram(Box::new(*auth_keys))) + } + ComputeCredentialKeys::JwtPayload(_) | ComputeCredentialKeys::None => None, + }, + server_params: StartupMessageParams::default(), + skip_db_user: false, } } +} + +impl ConnectInfo { + pub fn to_postgres_client_config(&self) -> postgres_client::Config { + let mut config = postgres_client::Config::new(self.host.to_string(), self.port); + config.ssl_mode(self.ssl_mode); + if let Some(host_addr) = self.host_addr { + config.set_host_addr(host_addr); + } + config + } +} + +impl AuthInfo { + fn enrich(&self, mut config: postgres_client::Config) -> postgres_client::Config { + match &self.auth { + Some(Auth::Scram(keys)) => config.auth_keys(AuthKeys::ScramSha256(**keys)), + Some(Auth::Password(pw)) => config.password(pw), + None => &mut config, + }; + for (k, v) in self.server_params.iter() { + config.set_param(k, v); + } + config + } /// Apply startup message params to the connection config. pub(crate) fn set_startup_params( @@ -132,27 +181,26 @@ impl ConnCfg { arbitrary_params: bool, ) { if !arbitrary_params { - self.set_param("client_encoding", "UTF8"); + self.server_params.insert("client_encoding", "UTF8"); } for (k, v) in params.iter() { match k { // Only set `user` if it's not present in the config. // Console redirect auth flow takes username from the console's response. - "user" if self.user_is_set() => {} - "database" if self.db_is_set() => {} + "user" | "database" if self.skip_db_user => {} "options" => { if let Some(options) = filtered_options(v) { - self.set_param(k, &options); + self.server_params.insert(k, &options); } } "user" | "database" | "application_name" | "replication" => { - self.set_param(k, v); + self.server_params.insert(k, v); } // if we allow arbitrary params, then we forward them through. // this is a flag for a period of backwards compatibility k if arbitrary_params => { - self.set_param(k, v); + self.server_params.insert(k, v); } _ => {} } @@ -160,25 +208,13 @@ impl ConnCfg { } } -impl std::ops::Deref for ConnCfg { - type Target = postgres_client::Config; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -/// For now, let's make it easier to setup the config. -impl std::ops::DerefMut for ConnCfg { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.0 - } -} - -impl ConnCfg { - /// Establish a raw TCP connection to the compute node. - async fn connect_raw(&self, timeout: Duration) -> io::Result<(SocketAddr, TcpStream, &str)> { - use postgres_client::config::Host; +impl ConnectInfo { + /// Establish a raw TCP+TLS connection to the compute node. + async fn connect_raw( + &self, + config: &ComputeConfig, + ) -> Result<(SocketAddr, MaybeTlsStream), TlsError> { + let timeout = config.timeout; // wrap TcpStream::connect with timeout let connect_with_timeout = |addrs| { @@ -208,34 +244,32 @@ impl ConnCfg { // We can't reuse connection establishing logic from `postgres_client` here, // because it has no means for extracting the underlying socket which we // require for our business. - let port = self.0.get_port(); - let host = self.0.get_host(); + let port = self.port; + let host = &*self.host; - let host = match host { - Host::Tcp(host) => host.as_str(), - }; - - let addrs = match self.0.get_host_addr() { + let addrs = match self.host_addr { Some(addr) => vec![SocketAddr::new(addr, port)], None => lookup_host((host, port)).await?.collect(), }; match connect_once(&*addrs).await { - Ok((sockaddr, stream)) => Ok((sockaddr, stream, host)), + Ok((sockaddr, stream)) => Ok(( + sockaddr, + tls::connect_tls(stream, self.ssl_mode, config, host).await?, + )), Err(err) => { warn!("couldn't connect to compute node at {host}:{port}: {err}"); - Err(err) + Err(TlsError::Connection(err)) } } } } -type RustlsStream = >::Stream; +type RustlsStream = >::Stream; pub(crate) struct PostgresConnection { /// Socket connected to a compute node. - pub(crate) stream: - postgres_client::maybe_tls_stream::MaybeTlsStream, + pub(crate) stream: MaybeTlsStream, /// PostgreSQL connection parameters. pub(crate) params: std::collections::HashMap, /// Query cancellation token. @@ -248,28 +282,23 @@ pub(crate) struct PostgresConnection { _guage: NumDbConnectionsGuard<'static>, } -impl ConnCfg { +impl ConnectInfo { /// Connect to a corresponding compute node. pub(crate) async fn connect( &self, ctx: &RequestContext, aux: MetricsAuxInfo, + auth: &AuthInfo, config: &ComputeConfig, user_info: ComputeUserInfo, ) -> Result { - let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute); - let (socket_addr, stream, host) = self.connect_raw(config.timeout).await?; - drop(pause); + let mut tmp_config = auth.enrich(self.to_postgres_client_config()); + // we setup SSL early in `ConnectInfo::connect_raw`. + tmp_config.ssl_mode(SslMode::Disable); - let mut mk_tls = crate::tls::postgres_rustls::MakeRustlsConnect::new(config.tls.clone()); - let tls = >::make_tls_connect( - &mut mk_tls, - host, - )?; - - // connect_raw() will not use TLS if sslmode is "disable" let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute); - let connection = self.0.connect_raw(stream, tls).await?; + let (socket_addr, stream) = self.connect_raw(config).await?; + let connection = tmp_config.connect_raw(stream, NoTls).await?; drop(pause); let RawConnection { @@ -282,13 +311,14 @@ impl ConnCfg { tracing::Span::current().record("pid", tracing::field::display(process_id)); tracing::Span::current().record("compute_id", tracing::field::display(&aux.compute_id)); - let stream = stream.into_inner(); + let MaybeTlsStream::Raw(stream) = stream.into_inner(); // TODO: lots of useful info but maybe we can move it elsewhere (eg traces?) info!( cold_start_info = ctx.cold_start_info().as_str(), - "connected to compute node at {host} ({socket_addr}) sslmode={:?}, latency={}, query_id={}", - self.0.get_ssl_mode(), + "connected to compute node at {} ({socket_addr}) sslmode={:?}, latency={}, query_id={}", + self.host, + self.ssl_mode, ctx.get_proxy_latency(), ctx.get_testodrome_id().unwrap_or_default(), ); @@ -299,11 +329,11 @@ impl ConnCfg { socket_addr, CancelToken { socket_config: None, - ssl_mode: self.0.get_ssl_mode(), + ssl_mode: self.ssl_mode, process_id, secret_key, }, - host.to_string(), + self.host.to_string(), user_info, ); diff --git a/proxy/src/compute/tls.rs b/proxy/src/compute/tls.rs new file mode 100644 index 0000000000..000d75fca5 --- /dev/null +++ b/proxy/src/compute/tls.rs @@ -0,0 +1,63 @@ +use futures::FutureExt; +use postgres_client::config::SslMode; +use postgres_client::maybe_tls_stream::MaybeTlsStream; +use postgres_client::tls::{MakeTlsConnect, TlsConnect}; +use rustls::pki_types::InvalidDnsNameError; +use thiserror::Error; +use tokio::io::{AsyncRead, AsyncWrite}; + +use crate::pqproto::request_tls; +use crate::proxy::retry::CouldRetry; + +#[derive(Debug, Error)] +pub enum TlsError { + #[error(transparent)] + Dns(#[from] InvalidDnsNameError), + #[error(transparent)] + Connection(#[from] std::io::Error), + #[error("TLS required but not provided")] + Required, +} + +impl CouldRetry for TlsError { + fn could_retry(&self) -> bool { + match self { + TlsError::Dns(_) => false, + TlsError::Connection(err) => err.could_retry(), + // perhaps compute didn't realise it supports TLS? + TlsError::Required => true, + } + } +} + +pub async fn connect_tls( + mut stream: S, + mode: SslMode, + tls: &T, + host: &str, +) -> Result, TlsError> +where + S: AsyncRead + AsyncWrite + Unpin + Send, + T: MakeTlsConnect< + S, + Error = InvalidDnsNameError, + TlsConnect: TlsConnect, + >, +{ + match mode { + SslMode::Disable => return Ok(MaybeTlsStream::Raw(stream)), + SslMode::Prefer | SslMode::Require => {} + } + + if !request_tls(&mut stream).await? { + if SslMode::Require == mode { + return Err(TlsError::Required); + } + + return Ok(MaybeTlsStream::Raw(stream)); + } + + Ok(MaybeTlsStream::Tls( + tls.make_tls_connect(host)?.connect(stream).boxed().await?, + )) +} diff --git a/proxy/src/console_redirect_proxy.rs b/proxy/src/console_redirect_proxy.rs index f2484b54b8..324dcf5824 100644 --- a/proxy/src/console_redirect_proxy.rs +++ b/proxy/src/console_redirect_proxy.rs @@ -210,20 +210,20 @@ pub(crate) async fn handle_client( ctx.set_db_options(params.clone()); - let (node_info, user_info, _ip_allowlist) = match backend + let (node_info, mut auth_info, user_info) = match backend .authenticate(ctx, &config.authentication_config, &mut stream) .await { Ok(auth_result) => auth_result, Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?, }; + auth_info.set_startup_params(¶ms, true); let node = connect_to_compute( ctx, &TcpMechanism { user_info, - params_compat: true, - params: ¶ms, + auth: auth_info, locks: &config.connect_compute_locks, }, &node_info, diff --git a/proxy/src/control_plane/client/cplane_proxy_v1.rs b/proxy/src/control_plane/client/cplane_proxy_v1.rs index da548d6b2c..cf2d9fba14 100644 --- a/proxy/src/control_plane/client/cplane_proxy_v1.rs +++ b/proxy/src/control_plane/client/cplane_proxy_v1.rs @@ -261,24 +261,18 @@ impl NeonControlPlaneClient { Some(_) => SslMode::Require, None => SslMode::Disable, }; - let host_name = match body.server_name { - Some(host) => host, - None => host.to_owned(), + let host = match body.server_name { + Some(host) => host.into(), + None => host.into(), }; - // Don't set anything but host and port! This config will be cached. - // We'll set username and such later using the startup message. - // TODO: add more type safety (in progress). - let mut config = compute::ConnCfg::new(host_name, port); - - if let Some(addr) = host_addr { - config.set_host_addr(addr); - } - - config.ssl_mode(ssl_mode); - let node = NodeInfo { - config, + conn_info: compute::ConnectInfo { + host_addr, + host, + port, + ssl_mode, + }, aux: body.aux, }; diff --git a/proxy/src/control_plane/client/mock.rs b/proxy/src/control_plane/client/mock.rs index ece7153fce..aeea57f2fc 100644 --- a/proxy/src/control_plane/client/mock.rs +++ b/proxy/src/control_plane/client/mock.rs @@ -6,6 +6,7 @@ use std::str::FromStr; use std::sync::Arc; use futures::TryFutureExt; +use postgres_client::config::SslMode; use thiserror::Error; use tokio_postgres::Client; use tracing::{Instrument, error, info, info_span, warn}; @@ -14,6 +15,7 @@ use crate::auth::IpPattern; use crate::auth::backend::ComputeUserInfo; use crate::auth::backend::jwt::AuthRule; use crate::cache::Cached; +use crate::compute::ConnectInfo; use crate::context::RequestContext; use crate::control_plane::errors::{ ControlPlaneError, GetAuthInfoError, GetEndpointJwksError, WakeComputeError, @@ -24,9 +26,9 @@ use crate::control_plane::{ RoleAccessControl, }; use crate::intern::RoleNameInt; +use crate::scram; use crate::types::{BranchId, EndpointId, ProjectId, RoleName}; use crate::url::ApiUrl; -use crate::{compute, scram}; #[derive(Debug, Error)] enum MockApiError { @@ -87,8 +89,7 @@ impl MockControlPlane { .await? { info!("got a secret: {entry}"); // safe since it's not a prod scenario - let secret = scram::ServerSecret::parse(&entry).map(AuthSecret::Scram); - secret.or_else(|| parse_md5(&entry).map(AuthSecret::Md5)) + scram::ServerSecret::parse(&entry).map(AuthSecret::Scram) } else { warn!("user '{role}' does not exist"); None @@ -170,25 +171,23 @@ impl MockControlPlane { async fn do_wake_compute(&self) -> Result { let port = self.endpoint.port().unwrap_or(5432); - let mut config = match self.endpoint.host_str() { - None => { - let mut config = compute::ConnCfg::new("localhost".to_string(), port); - config.set_host_addr(IpAddr::V4(Ipv4Addr::LOCALHOST)); - config - } - Some(host) => { - let mut config = compute::ConnCfg::new(host.to_string(), port); - if let Ok(addr) = IpAddr::from_str(host) { - config.set_host_addr(addr); - } - config - } + let conn_info = match self.endpoint.host_str() { + None => ConnectInfo { + host_addr: Some(IpAddr::V4(Ipv4Addr::LOCALHOST)), + host: "localhost".into(), + port, + ssl_mode: SslMode::Disable, + }, + Some(host) => ConnectInfo { + host_addr: IpAddr::from_str(host).ok(), + host: host.into(), + port, + ssl_mode: SslMode::Disable, + }, }; - config.ssl_mode(postgres_client::config::SslMode::Disable); - let node = NodeInfo { - config, + conn_info, aux: MetricsAuxInfo { endpoint_id: (&EndpointId::from("endpoint")).into(), project_id: (&ProjectId::from("project")).into(), @@ -266,12 +265,3 @@ impl super::ControlPlaneApi for MockControlPlane { self.do_wake_compute().map_ok(Cached::new_uncached).await } } - -fn parse_md5(input: &str) -> Option<[u8; 16]> { - let text = input.strip_prefix("md5")?; - - let mut bytes = [0u8; 16]; - hex::decode_to_slice(text, &mut bytes).ok()?; - - Some(bytes) -} diff --git a/proxy/src/control_plane/mod.rs b/proxy/src/control_plane/mod.rs index 7ff093d9dc..ad10cf4257 100644 --- a/proxy/src/control_plane/mod.rs +++ b/proxy/src/control_plane/mod.rs @@ -11,8 +11,8 @@ pub(crate) mod errors; use std::sync::Arc; +use crate::auth::backend::ComputeUserInfo; use crate::auth::backend::jwt::AuthRule; -use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo}; use crate::auth::{AuthError, IpPattern, check_peer_addr_is_in_list}; use crate::cache::{Cached, TimedLru}; use crate::config::ComputeConfig; @@ -39,10 +39,6 @@ pub mod mgmt; /// Auth secret which is managed by the cloud. #[derive(Clone, Eq, PartialEq, Debug)] pub(crate) enum AuthSecret { - #[cfg(any(test, feature = "testing"))] - /// Md5 hash of user's password. - Md5([u8; 16]), - /// [SCRAM](crate::scram) authentication info. Scram(scram::ServerSecret), } @@ -63,13 +59,9 @@ pub(crate) struct AuthInfo { } /// Info for establishing a connection to a compute node. -/// This is what we get after auth succeeded, but not before! #[derive(Clone)] pub(crate) struct NodeInfo { - /// Compute node connection params. - /// It's sad that we have to clone this, but this will improve - /// once we migrate to a bespoke connection logic. - pub(crate) config: compute::ConnCfg, + pub(crate) conn_info: compute::ConnectInfo, /// Labels for proxy's metrics. pub(crate) aux: MetricsAuxInfo, @@ -79,26 +71,14 @@ impl NodeInfo { pub(crate) async fn connect( &self, ctx: &RequestContext, + auth: &compute::AuthInfo, config: &ComputeConfig, user_info: ComputeUserInfo, ) -> Result { - self.config - .connect(ctx, self.aux.clone(), config, user_info) + self.conn_info + .connect(ctx, self.aux.clone(), auth, config, user_info) .await } - - pub(crate) fn reuse_settings(&mut self, other: Self) { - self.config.reuse_password(other.config); - } - - pub(crate) fn set_keys(&mut self, keys: &ComputeCredentialKeys) { - match keys { - #[cfg(any(test, feature = "testing"))] - ComputeCredentialKeys::Password(password) => self.config.password(password), - ComputeCredentialKeys::AuthKeys(auth_keys) => self.config.auth_keys(*auth_keys), - ComputeCredentialKeys::JwtPayload(_) | ComputeCredentialKeys::None => &mut self.config, - }; - } } #[derive(Copy, Clone, Default)] diff --git a/proxy/src/pglb/connect_compute.rs b/proxy/src/pglb/connect_compute.rs index 1d6ca5fbb3..1807cdff0e 100644 --- a/proxy/src/pglb/connect_compute.rs +++ b/proxy/src/pglb/connect_compute.rs @@ -2,8 +2,8 @@ use async_trait::async_trait; use tokio::time; use tracing::{debug, info, warn}; -use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo}; -use crate::compute::{self, COULD_NOT_CONNECT, PostgresConnection}; +use crate::auth::backend::ComputeUserInfo; +use crate::compute::{self, AuthInfo, COULD_NOT_CONNECT, PostgresConnection}; use crate::config::{ComputeConfig, RetryConfig}; use crate::context::RequestContext; use crate::control_plane::errors::WakeComputeError; @@ -13,7 +13,6 @@ use crate::error::ReportableError; use crate::metrics::{ ConnectOutcome, ConnectionFailureKind, Metrics, RetriesMetricGroup, RetryType, }; -use crate::pqproto::StartupMessageParams; use crate::proxy::retry::{CouldRetry, ShouldRetryWakeCompute, retry_after, should_retry}; use crate::proxy::wake_compute::wake_compute; use crate::types::Host; @@ -48,8 +47,6 @@ pub(crate) trait ConnectMechanism { node_info: &control_plane::CachedNodeInfo, config: &ComputeConfig, ) -> Result; - - fn update_connect_config(&self, conf: &mut compute::ConnCfg); } #[async_trait] @@ -58,24 +55,17 @@ pub(crate) trait ComputeConnectBackend { &self, ctx: &RequestContext, ) -> Result; - - fn get_keys(&self) -> &ComputeCredentialKeys; } -pub(crate) struct TcpMechanism<'a> { - pub(crate) params_compat: bool, - - /// KV-dictionary with PostgreSQL connection params. - pub(crate) params: &'a StartupMessageParams, - +pub(crate) struct TcpMechanism { + pub(crate) auth: AuthInfo, /// connect_to_compute concurrency lock pub(crate) locks: &'static ApiLocks, - pub(crate) user_info: ComputeUserInfo, } #[async_trait] -impl ConnectMechanism for TcpMechanism<'_> { +impl ConnectMechanism for TcpMechanism { type Connection = PostgresConnection; type ConnectError = compute::ConnectionError; type Error = compute::ConnectionError; @@ -90,13 +80,12 @@ impl ConnectMechanism for TcpMechanism<'_> { node_info: &control_plane::CachedNodeInfo, config: &ComputeConfig, ) -> Result { - let host = node_info.config.get_host(); - let permit = self.locks.get_permit(&host).await?; - permit.release_result(node_info.connect(ctx, config, self.user_info.clone()).await) - } - - fn update_connect_config(&self, config: &mut compute::ConnCfg) { - config.set_startup_params(self.params, self.params_compat); + let permit = self.locks.get_permit(&node_info.conn_info.host).await?; + permit.release_result( + node_info + .connect(ctx, &self.auth, config, self.user_info.clone()) + .await, + ) } } @@ -114,12 +103,9 @@ where M::Error: From, { let mut num_retries = 0; - let mut node_info = + let node_info = wake_compute(&mut num_retries, ctx, user_info, wake_compute_retry_config).await?; - node_info.set_keys(user_info.get_keys()); - mechanism.update_connect_config(&mut node_info.config); - // try once let err = match mechanism.connect_once(ctx, &node_info, compute).await { Ok(res) => { @@ -155,14 +141,9 @@ where } else { // if we failed to connect, it's likely that the compute node was suspended, wake a new compute node debug!("compute node's state has likely changed; requesting a wake-up"); - let old_node_info = invalidate_cache(node_info); + invalidate_cache(node_info); // TODO: increment num_retries? - let mut node_info = - wake_compute(&mut num_retries, ctx, user_info, wake_compute_retry_config).await?; - node_info.reuse_settings(old_node_info); - - mechanism.update_connect_config(&mut node_info.config); - node_info + wake_compute(&mut num_retries, ctx, user_info, wake_compute_retry_config).await? }; // now that we have a new node, try connect to it repeatedly. diff --git a/proxy/src/pqproto.rs b/proxy/src/pqproto.rs index 43074bf208..ad99eecda5 100644 --- a/proxy/src/pqproto.rs +++ b/proxy/src/pqproto.rs @@ -8,7 +8,7 @@ use std::io::{self, Cursor}; use bytes::{Buf, BufMut}; use itertools::Itertools; use rand::distributions::{Distribution, Standard}; -use tokio::io::{AsyncRead, AsyncReadExt}; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use zerocopy::{FromBytes, Immutable, IntoBytes, big_endian}; pub type ErrorCode = [u8; 5]; @@ -53,6 +53,28 @@ impl fmt::Debug for ProtocolVersion { } } +/// +const MAX_STARTUP_PACKET_LENGTH: usize = 10000; +const RESERVED_INVALID_MAJOR_VERSION: u16 = 1234; +/// +const CANCEL_REQUEST_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5678); +/// +const NEGOTIATE_SSL_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5679); +/// +const NEGOTIATE_GSS_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5680); + +/// This first reads the startup message header, is 8 bytes. +/// The first 4 bytes is a big-endian message length, and the next 4 bytes is a version number. +/// +/// The length value is inclusive of the header. For example, +/// an empty message will always have length 8. +#[derive(Clone, Copy, FromBytes, IntoBytes, Immutable)] +#[repr(C)] +struct StartupHeader { + len: big_endian::U32, + version: ProtocolVersion, +} + /// read the type from the stream using zerocopy. /// /// not cancel safe. @@ -66,32 +88,38 @@ macro_rules! read { }}; } +/// Returns true if TLS is supported. +/// +/// This is not cancel safe. +pub async fn request_tls(stream: &mut S) -> io::Result +where + S: AsyncRead + AsyncWrite + Unpin, +{ + let payload = StartupHeader { + len: 8.into(), + version: NEGOTIATE_SSL_CODE, + }; + stream.write_all(payload.as_bytes()).await?; + stream.flush().await?; + + // we expect back either `S` or `N` as a single byte. + let mut res = *b"0"; + stream.read_exact(&mut res).await?; + + debug_assert!( + res == *b"S" || res == *b"N", + "unexpected SSL negotiation response: {}", + char::from(res[0]), + ); + + // S for SSL. + Ok(res == *b"S") +} + pub async fn read_startup(stream: &mut S) -> io::Result where S: AsyncRead + Unpin, { - /// - const MAX_STARTUP_PACKET_LENGTH: usize = 10000; - const RESERVED_INVALID_MAJOR_VERSION: u16 = 1234; - /// - const CANCEL_REQUEST_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5678); - /// - const NEGOTIATE_SSL_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5679); - /// - const NEGOTIATE_GSS_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5680); - - /// This first reads the startup message header, is 8 bytes. - /// The first 4 bytes is a big-endian message length, and the next 4 bytes is a version number. - /// - /// The length value is inclusive of the header. For example, - /// an empty message will always have length 8. - #[derive(Clone, Copy, FromBytes, IntoBytes, Immutable)] - #[repr(C)] - struct StartupHeader { - len: big_endian::U32, - version: ProtocolVersion, - } - let header = read!(stream => StartupHeader); // @@ -564,9 +592,8 @@ mod tests { use tokio::io::{AsyncWriteExt, duplex}; use zerocopy::IntoBytes; - use crate::pqproto::{FeStartupPacket, read_message, read_startup}; - use super::ProtocolVersion; + use crate::pqproto::{FeStartupPacket, read_message, read_startup}; #[tokio::test] async fn reject_large_startup() { diff --git a/proxy/src/proxy/mod.rs b/proxy/src/proxy/mod.rs index 0e138cc0c7..0e00c4f97e 100644 --- a/proxy/src/proxy/mod.rs +++ b/proxy/src/proxy/mod.rs @@ -358,21 +358,19 @@ pub(crate) async fn handle_client( } }; - let compute_user_info = match &user_info { - auth::Backend::ControlPlane(_, info) => &info.info, + let creds = match &user_info { + auth::Backend::ControlPlane(_, creds) => creds, auth::Backend::Local(_) => unreachable!("local proxy does not run tcp proxy service"), }; - let params_compat = compute_user_info - .options - .get(NeonOptions::PARAMS_COMPAT) - .is_some(); + let params_compat = creds.info.options.get(NeonOptions::PARAMS_COMPAT).is_some(); + let mut auth_info = compute::AuthInfo::with_auth_keys(&creds.keys); + auth_info.set_startup_params(¶ms, params_compat); let res = connect_to_compute( ctx, &TcpMechanism { - user_info: compute_user_info.clone(), - params_compat, - params: ¶ms, + user_info: creds.info.clone(), + auth: auth_info, locks: &config.connect_compute_locks, }, &user_info, diff --git a/proxy/src/proxy/retry.rs b/proxy/src/proxy/retry.rs index 01e603ec14..0f19944afa 100644 --- a/proxy/src/proxy/retry.rs +++ b/proxy/src/proxy/retry.rs @@ -100,9 +100,9 @@ impl CouldRetry for compute::ConnectionError { fn could_retry(&self) -> bool { match self { compute::ConnectionError::Postgres(err) => err.could_retry(), - compute::ConnectionError::CouldNotConnect(err) => err.could_retry(), + compute::ConnectionError::TlsError(err) => err.could_retry(), compute::ConnectionError::WakeComputeError(err) => err.could_retry(), - _ => false, + compute::ConnectionError::TooManyConnectionAttempts(_) => false, } } } diff --git a/proxy/src/proxy/tests/mod.rs b/proxy/src/proxy/tests/mod.rs index e5db0013a7..028247a97d 100644 --- a/proxy/src/proxy/tests/mod.rs +++ b/proxy/src/proxy/tests/mod.rs @@ -8,7 +8,7 @@ use std::time::Duration; use anyhow::{Context, bail}; use async_trait::async_trait; use http::StatusCode; -use postgres_client::config::SslMode; +use postgres_client::config::{AuthKeys, ScramKeys, SslMode}; use postgres_client::tls::{MakeTlsConnect, NoTls}; use retry::{ShouldRetryWakeCompute, retry_after}; use rstest::rstest; @@ -29,7 +29,6 @@ use crate::control_plane::{self, CachedNodeInfo, NodeInfo, NodeInfoCache}; use crate::error::ErrorKind; use crate::pglb::connect_compute::ConnectMechanism; use crate::tls::client_config::compute_client_config_with_certs; -use crate::tls::postgres_rustls::MakeRustlsConnect; use crate::tls::server_config::CertResolver; use crate::types::{BranchId, EndpointId, ProjectId}; use crate::{sasl, scram}; @@ -72,13 +71,14 @@ struct ClientConfig<'a> { hostname: &'a str, } -type TlsConnect = >::TlsConnect; +type TlsConnect = >::TlsConnect; impl ClientConfig<'_> { fn make_tls_connect(self) -> anyhow::Result> { - let mut mk = MakeRustlsConnect::new(self.config); - let tls = MakeTlsConnect::::make_tls_connect(&mut mk, self.hostname)?; - Ok(tls) + Ok(crate::tls::postgres_rustls::make_tls_connect( + &self.config, + self.hostname, + )?) } } @@ -497,8 +497,6 @@ impl ConnectMechanism for TestConnectMechanism { x => panic!("expecting action {x:?}, connect is called instead"), } } - - fn update_connect_config(&self, _conf: &mut compute::ConnCfg) {} } impl TestControlPlaneClient for TestConnectMechanism { @@ -557,7 +555,12 @@ impl TestControlPlaneClient for TestConnectMechanism { fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeInfo { let node = NodeInfo { - config: compute::ConnCfg::new("test".to_owned(), 5432), + conn_info: compute::ConnectInfo { + host: "test".into(), + port: 5432, + ssl_mode: SslMode::Disable, + host_addr: None, + }, aux: MetricsAuxInfo { endpoint_id: (&EndpointId::from("endpoint")).into(), project_id: (&ProjectId::from("project")).into(), @@ -581,7 +584,10 @@ fn helper_create_connect_info( user: "user".into(), options: NeonOptions::parse_options_raw(""), }, - keys: ComputeCredentialKeys::Password("password".into()), + keys: ComputeCredentialKeys::AuthKeys(AuthKeys::ScramSha256(ScramKeys { + client_key: [0; 32], + server_key: [0; 32], + })), }, ) } diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index 748e0ce6f2..a0e782dab0 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -23,7 +23,6 @@ use super::local_conn_pool::{self, EXT_NAME, EXT_SCHEMA, EXT_VERSION, LocalConnP use crate::auth::backend::local::StaticAuthRules; use crate::auth::backend::{ComputeCredentials, ComputeUserInfo}; use crate::auth::{self, AuthError}; -use crate::compute; use crate::compute_ctl::{ ComputeCtlError, ExtensionInstallRequest, Privilege, SetRoleGrantsRequest, }; @@ -305,12 +304,13 @@ impl PoolingBackend { tracing::Span::current().record("conn_id", display(conn_id)); info!(%conn_id, "local_pool: opening a new connection '{conn_info}'"); - let mut node_info = local_backend.node_info.clone(); - let (key, jwk) = create_random_jwk(); - let config = node_info - .config + let mut config = local_backend + .node_info + .conn_info + .to_postgres_client_config(); + config .user(&conn_info.user_info.user) .dbname(&conn_info.dbname) .set_param( @@ -322,7 +322,7 @@ impl PoolingBackend { ); let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute); - let (client, connection) = config.connect(postgres_client::NoTls).await?; + let (client, connection) = config.connect(&postgres_client::NoTls).await?; drop(pause); let pid = client.get_process_id(); @@ -336,7 +336,7 @@ impl PoolingBackend { connection, key, conn_id, - node_info.aux.clone(), + local_backend.node_info.aux.clone(), ); { @@ -512,19 +512,16 @@ impl ConnectMechanism for TokioMechanism { node_info: &CachedNodeInfo, compute_config: &ComputeConfig, ) -> Result { - let host = node_info.config.get_host(); - let permit = self.locks.get_permit(&host).await?; + let permit = self.locks.get_permit(&node_info.conn_info.host).await?; - let mut config = (*node_info.config).clone(); + let mut config = node_info.conn_info.to_postgres_client_config(); let config = config .user(&self.conn_info.user_info.user) .dbname(&self.conn_info.dbname) .connect_timeout(compute_config.timeout); - let mk_tls = - crate::tls::postgres_rustls::MakeRustlsConnect::new(compute_config.tls.clone()); let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute); - let res = config.connect(mk_tls).await; + let res = config.connect(compute_config).await; drop(pause); let (client, connection) = permit.release_result(res)?; @@ -548,8 +545,6 @@ impl ConnectMechanism for TokioMechanism { node_info.aux.clone(), )) } - - fn update_connect_config(&self, _config: &mut compute::ConnCfg) {} } struct HyperMechanism { @@ -573,20 +568,20 @@ impl ConnectMechanism for HyperMechanism { node_info: &CachedNodeInfo, config: &ComputeConfig, ) -> Result { - let host_addr = node_info.config.get_host_addr(); - let host = node_info.config.get_host(); - let permit = self.locks.get_permit(&host).await?; + let host_addr = node_info.conn_info.host_addr; + let host = &node_info.conn_info.host; + let permit = self.locks.get_permit(host).await?; let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute); - let tls = if node_info.config.get_ssl_mode() == SslMode::Disable { + let tls = if node_info.conn_info.ssl_mode == SslMode::Disable { None } else { Some(&config.tls) }; - let port = node_info.config.get_port(); - let res = connect_http2(host_addr, &host, port, config.timeout, tls).await; + let port = node_info.conn_info.port; + let res = connect_http2(host_addr, host, port, config.timeout, tls).await; drop(pause); let (client, connection) = permit.release_result(res)?; @@ -609,8 +604,6 @@ impl ConnectMechanism for HyperMechanism { node_info.aux.clone(), )) } - - fn update_connect_config(&self, _config: &mut compute::ConnCfg) {} } async fn connect_http2( diff --git a/proxy/src/serverless/conn_pool.rs b/proxy/src/serverless/conn_pool.rs index 87176ff7d6..dd8cf052c5 100644 --- a/proxy/src/serverless/conn_pool.rs +++ b/proxy/src/serverless/conn_pool.rs @@ -23,12 +23,12 @@ use super::conn_pool_lib::{ Client, ClientDataEnum, ClientInnerCommon, ClientInnerExt, ConnInfo, EndpointConnPool, GlobalConnPool, }; +use crate::config::ComputeConfig; use crate::context::RequestContext; use crate::control_plane::messages::MetricsAuxInfo; use crate::metrics::Metrics; -use crate::tls::postgres_rustls::MakeRustlsConnect; -type TlsStream = >::Stream; +type TlsStream = >::Stream; #[derive(Debug, Clone)] pub(crate) struct ConnInfoWithAuth { diff --git a/proxy/src/tls/postgres_rustls.rs b/proxy/src/tls/postgres_rustls.rs index 013b307f0b..9269ad8a06 100644 --- a/proxy/src/tls/postgres_rustls.rs +++ b/proxy/src/tls/postgres_rustls.rs @@ -2,10 +2,11 @@ use std::convert::TryFrom; use std::sync::Arc; use postgres_client::tls::MakeTlsConnect; -use rustls::ClientConfig; -use rustls::pki_types::ServerName; +use rustls::pki_types::{InvalidDnsNameError, ServerName}; use tokio::io::{AsyncRead, AsyncWrite}; +use crate::config::ComputeConfig; + mod private { use std::future::Future; use std::io; @@ -123,36 +124,27 @@ mod private { } } -/// A `MakeTlsConnect` implementation using `rustls`. -/// -/// That way you can connect to PostgreSQL using `rustls` as the TLS stack. -#[derive(Clone)] -pub struct MakeRustlsConnect { - pub config: Arc, -} - -impl MakeRustlsConnect { - /// Creates a new `MakeRustlsConnect` from the provided `ClientConfig`. - #[must_use] - pub fn new(config: Arc) -> Self { - Self { config } - } -} - -impl MakeTlsConnect for MakeRustlsConnect +impl MakeTlsConnect for ComputeConfig where S: AsyncRead + AsyncWrite + Unpin + Send + 'static, { type Stream = private::RustlsStream; type TlsConnect = private::RustlsConnect; - type Error = rustls::pki_types::InvalidDnsNameError; + type Error = InvalidDnsNameError; - fn make_tls_connect(&mut self, hostname: &str) -> Result { - ServerName::try_from(hostname).map(|dns_name| { - private::RustlsConnect(private::RustlsConnectData { - hostname: dns_name.to_owned(), - connector: Arc::clone(&self.config).into(), - }) - }) + fn make_tls_connect(&self, hostname: &str) -> Result { + make_tls_connect(&self.tls, hostname) } } + +pub fn make_tls_connect( + tls: &Arc, + hostname: &str, +) -> Result { + ServerName::try_from(hostname).map(|dns_name| { + private::RustlsConnect(private::RustlsConnectData { + hostname: dns_name.to_owned(), + connector: tls.clone().into(), + }) + }) +}