diff --git a/proxy/src/compute.rs b/proxy/src/compute.rs index 18c82fe379..bb0cfb609a 100644 --- a/proxy/src/compute.rs +++ b/proxy/src/compute.rs @@ -16,8 +16,10 @@ use rustls::{client::danger::ServerCertVerifier, pki_types::InvalidDnsNameError} use std::{io, net::SocketAddr, sync::Arc, time::Duration}; use thiserror::Error; use tokio::net::TcpStream; -use tokio_postgres::tls::MakeTlsConnect; -use tokio_postgres_rustls::MakeRustlsConnect; +use tokio_postgres::{ + tls::{MakeTlsConnect, NoTlsError}, + Client, Connection, +}; use tracing::{error, info, warn}; const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node"; @@ -42,6 +44,12 @@ pub enum ConnectionError { TooManyConnectionAttempts(#[from] ApiLockError), } +impl From for ConnectionError { + fn from(value: NoTlsError) -> Self { + Self::CouldNotConnect(io::Error::new(io::ErrorKind::Other, value.to_string())) + } +} + impl UserFacingError for ConnectionError { fn to_string_client(&self) -> String { use ConnectionError::*; @@ -273,6 +281,30 @@ pub struct PostgresConnection { } impl ConnCfg { + /// Connect to a corresponding compute node. + pub async fn connect2>( + &self, + ctx: &RequestMonitoring, + timeout: Duration, + mktls: &mut M, + ) -> Result<(SocketAddr, Client, Connection), ConnectionError> + where + ConnectionError: From, + { + let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute); + let (socket_addr, stream, host) = self.connect_raw(timeout).await?; + drop(pause); + + let tls = mktls.make_tls_connect(host)?; + + // connect_raw() will not use TLS if sslmode is "disable" + let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute); + let (client, connection) = self.0.connect_raw(stream, tls).await?; + drop(pause); + + Ok((socket_addr, client, connection)) + } + /// Connect to a corresponding compute node. pub async fn connect( &self, @@ -281,10 +313,6 @@ impl ConnCfg { aux: MetricsAuxInfo, timeout: Duration, ) -> Result { - let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute); - let (socket_addr, stream, host) = self.connect_raw(timeout).await?; - drop(pause); - let client_config = if allow_self_signed_compute { // Allow all certificates for creating the connection let verifier = Arc::new(AcceptEverythingVerifier) as Arc; @@ -298,21 +326,14 @@ impl ConnCfg { let client_config = client_config.with_no_client_auth(); let mut mk_tls = tokio_postgres_rustls::MakeRustlsConnect::new(client_config); - let tls = >::make_tls_connect( - &mut mk_tls, - host, - )?; - // connect_raw() will not use TLS if sslmode is "disable" - let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute); - let (client, connection) = self.0.connect_raw(stream, tls).await?; - drop(pause); + let (socket_addr, client, connection) = self.connect2(ctx, timeout, &mut mk_tls).await?; tracing::Span::current().record("pid", tracing::field::display(client.get_process_id())); let stream = connection.stream.into_inner(); info!( cold_start_info = ctx.cold_start_info().as_str(), - "connected to compute node at {host} ({socket_addr}) sslmode={:?}", + "connected to compute node ({socket_addr}) sslmode={:?}", self.0.get_ssl_mode() ); diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index 295ea1a1c7..977d7bda82 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -5,7 +5,7 @@ use tracing::{field::display, info}; use crate::{ auth::{backend::ComputeCredentials, check_peer_addr_is_in_list, AuthError}, - compute, + compute::{self, ConnectionError}, config::{AuthenticationConfig, ProxyConfig}, console::{ errors::{GetAuthInfoError, WakeComputeError}, @@ -142,7 +142,7 @@ pub enum HttpConnError { #[error("pooled connection closed at inconsistent state")] ConnectionClosedAbruptly(#[from] tokio::sync::watch::error::SendError), #[error("could not connection to compute")] - ConnectionError(#[from] tokio_postgres::Error), + ConnectionError(#[from] ConnectionError), #[error("could not get auth info")] GetAuthInfo(#[from] GetAuthInfoError), @@ -229,17 +229,13 @@ impl ConnectMechanism for TokioMechanism { let host = node_info.config.get_host()?; let permit = self.locks.get_permit(&host).await?; - let mut config = (*node_info.config).clone(); - let config = config - .user(&self.conn_info.user_info.user) - .password(&*self.conn_info.password) - .dbname(&self.conn_info.dbname) - .connect_timeout(timeout); - let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute); - let res = config.connect(tokio_postgres::NoTls).await; + let res = node_info + .config + .connect2(ctx, timeout, &mut tokio_postgres::NoTls) + .await; drop(pause); - let (client, connection) = permit.release_result(res)?; + let (_, client, connection) = permit.release_result(res)?; tracing::Span::current().record("pid", tracing::field::display(client.get_process_id())); Ok(poll_client( @@ -253,5 +249,10 @@ impl ConnectMechanism for TokioMechanism { )) } - fn update_connect_config(&self, _config: &mut compute::ConnCfg) {} + fn update_connect_config(&self, config: &mut compute::ConnCfg) { + config + .user(&self.conn_info.user_info.user) + .dbname(&self.conn_info.dbname) + .password(&self.conn_info.password); + } } diff --git a/proxy/src/serverless/conn_pool.rs b/proxy/src/serverless/conn_pool.rs index e1dc44dc1c..5bb136072d 100644 --- a/proxy/src/serverless/conn_pool.rs +++ b/proxy/src/serverless/conn_pool.rs @@ -12,9 +12,10 @@ use std::{ ops::Deref, sync::atomic::{self, AtomicUsize}, }; +use tokio::net::TcpStream; use tokio::time::Instant; use tokio_postgres::tls::NoTlsStream; -use tokio_postgres::{AsyncMessage, ReadyForQueryStatus, Socket}; +use tokio_postgres::{AsyncMessage, ReadyForQueryStatus}; use tokio_util::sync::CancellationToken; use crate::console::messages::{ColdStartInfo, MetricsAuxInfo}; @@ -468,7 +469,7 @@ pub fn poll_client( ctx: &RequestMonitoring, conn_info: ConnInfo, client: C, - mut connection: tokio_postgres::Connection, + mut connection: tokio_postgres::Connection, conn_id: uuid::Uuid, aux: MetricsAuxInfo, ) -> Client { diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index c41df07a4d..5d8f0bd6c4 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -40,6 +40,7 @@ use utils::http::error::ApiError; use crate::auth::backend::ComputeUserInfo; use crate::auth::endpoint_sni; use crate::auth::ComputeUserInfoParseError; +use crate::compute::ConnectionError; use crate::config::ProxyConfig; use crate::config::TlsConfig; use crate::context::RequestMonitoring; @@ -261,7 +262,9 @@ pub async fn handle( let mut message = e.to_string_client(); let db_error = match &e { - SqlOverHttpError::ConnectCompute(HttpConnError::ConnectionError(e)) + SqlOverHttpError::ConnectCompute(HttpConnError::ConnectionError( + ConnectionError::Postgres(e), + )) | SqlOverHttpError::Postgres(e) => e.as_db_error(), _ => None, }; @@ -663,7 +666,9 @@ impl QueryData { // query failed or was cancelled. Ok(Err(error)) => { let db_error = match &error { - SqlOverHttpError::ConnectCompute(HttpConnError::ConnectionError(e)) + SqlOverHttpError::ConnectCompute(HttpConnError::ConnectionError( + ConnectionError::Postgres(e), + )) | SqlOverHttpError::Postgres(e) => e.as_db_error(), _ => None, };