diff --git a/Cargo.lock b/Cargo.lock index 85a59ec0ed..2b4637b299 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1780,6 +1780,18 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "enum-as-inner" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5ffccbb6966c05b32ef8fbac435df276c4ae4d3dc55a8cd0eb9745e6c12f546a" +dependencies = [ + "heck 0.4.1", + "proc-macro2", + "quote", + "syn 2.0.52", +] + [[package]] name = "enum-map" version = "2.5.0" @@ -1971,9 +1983,9 @@ checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" [[package]] name = "form_urlencoded" -version = "1.1.0" +version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a9c384f161156f5260c24a097c56119f9be8c798586aecc13afbcbe7b7e26bf8" +checksum = "e13624c2627564efccf4934284bdd98cbaa14e79b0b5a141218e507b3a823456" dependencies = [ "percent-encoding", ] @@ -2332,6 +2344,51 @@ version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6fe2267d4ed49bc07b63801559be28c718ea06c4738b7a03c94df7386d2cde46" +[[package]] +name = "hickory-proto" +version = "0.24.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07698b8420e2f0d6447a436ba999ec85d8fbf2a398bbd737b82cac4a2e96e512" +dependencies = [ + "async-trait", + "cfg-if", + "data-encoding", + "enum-as-inner", + "futures-channel", + "futures-io", + "futures-util", + "idna 0.4.0", + "ipnet", + "once_cell", + "rand 0.8.5", + "thiserror", + "tinyvec", + "tokio", + "tracing", + "url", +] + +[[package]] +name = "hickory-resolver" +version = "0.24.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28757f23aa75c98f254cf0405e6d8c25b831b32921b050a66692427679b1f243" +dependencies = [ + "cfg-if", + "futures-util", + "hickory-proto", + "ipconfig", + "lru-cache", + "once_cell", + "parking_lot 0.12.1", + "rand 0.8.5", + "resolv-conf", + "smallvec", + "thiserror", + "tokio", + "tracing", +] + [[package]] name = "histogram" version = "0.7.4" @@ -2612,9 +2669,19 @@ checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" [[package]] name = "idna" -version = "0.3.0" +version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e14ddfc70884202db2244c223200c204c2bda1bc6e0998d11b5e024d657209e6" +checksum = "7d20d6b07bfbc108882d88ed8e37d39636dcc260e15e30c45e6ba089610b917c" +dependencies = [ + "unicode-bidi", + "unicode-normalization", +] + +[[package]] +name = "idna" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "634d9b1461af396cad843f47fdba5597a4f9e6ddd4bfb6ff5d85028c25cb12f6" dependencies = [ "unicode-bidi", "unicode-normalization", @@ -2719,6 +2786,18 @@ dependencies = [ "libc", ] +[[package]] +name = "ipconfig" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b58db92f96b720de98181bbbe63c831e87005ab460c1bf306eb2622b4707997f" +dependencies = [ + "socket2 0.5.5", + "widestring", + "windows-sys 0.48.0", + "winreg", +] + [[package]] name = "ipnet" version = "2.9.0" @@ -2860,6 +2939,12 @@ version = "0.2.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" +[[package]] +name = "linked-hash-map" +version = "0.5.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0717cef1bc8b636c6e1c1bbdefc09e6322da8a9321966e8928ef80d20f7f770f" + [[package]] name = "linux-raw-sys" version = "0.1.4" @@ -2894,6 +2979,15 @@ version = "0.4.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f" +[[package]] +name = "lru-cache" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "31e24f1ad8321ca0e8a1e0ac13f23cb668e6f5466c2c57319f6a5cf1cc8e3b1c" +dependencies = [ + "linked-hash-map", +] + [[package]] name = "match_cfg" version = "0.1.0" @@ -3889,9 +3983,9 @@ dependencies = [ [[package]] name = "percent-encoding" -version = "2.2.0" +version = "2.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "478c572c3d73181ff3c2539045f6eb99e5491218eae919370993b890cdbdd98e" +checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" [[package]] name = "petgraph" @@ -4310,6 +4404,7 @@ dependencies = [ "hashbrown 0.13.2", "hashlink", "hex", + "hickory-resolver", "hmac", "hostname", "http 1.1.0", @@ -4385,6 +4480,12 @@ dependencies = [ "x509-parser", ] +[[package]] +name = "quick-error" +version = "1.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0" + [[package]] name = "quick-xml" version = "0.31.0" @@ -4772,6 +4873,16 @@ dependencies = [ "tracing-opentelemetry", ] +[[package]] +name = "resolv-conf" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52e44394d2086d010551b14b53b1f24e31647570cd1deb0379e2c21b329aba00" +dependencies = [ + "hostname", + "quick-error", +] + [[package]] name = "retry-policies" version = "0.1.2" @@ -6695,12 +6806,12 @@ dependencies = [ [[package]] name = "url" -version = "2.3.1" +version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0d68c799ae75762b8c3fe375feb6600ef5602c883c5d21eb51c09f22b83c4643" +checksum = "31e6302e3bb753d46e83516cae55ae196fc0c309407cf11ab35cc51a4c2a4633" dependencies = [ "form_urlencoded", - "idna", + "idna 0.5.0", "percent-encoding", "serde", ] @@ -7032,6 +7143,12 @@ dependencies = [ "once_cell", ] +[[package]] +name = "widestring" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "653f141f39ec16bba3c5abe400a0c60da7468261cc2cbf36805022876bc721a8" + [[package]] name = "winapi" version = "0.3.9" diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index 6b8f2ecbf4..f1070895fc 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -31,6 +31,7 @@ git-version.workspace = true hashbrown.workspace = true hashlink.workspace = true hex.workspace = true +hickory-resolver = "0.24.1" hmac.workspace = true hostname.workspace = true http.workspace = true diff --git a/proxy/src/bin/proxy.rs b/proxy/src/bin/proxy.rs index 760ccf40d4..f75408e07d 100644 --- a/proxy/src/bin/proxy.rs +++ b/proxy/src/bin/proxy.rs @@ -18,6 +18,7 @@ use proxy::config::HttpConfig; use proxy::config::ProjectInfoCacheOptions; use proxy::console; use proxy::context::parquet::ParquetUploadArgs; +use proxy::dns::Dns; use proxy::http; use proxy::http::health_server::AppMetrics; use proxy::metrics::Metrics; @@ -400,7 +401,7 @@ async fn main() -> anyhow::Result<()> { if let Some(metrics_config) = &config.metric_collection { // TODO: Add gc regardles of the metric collection being enabled. - maintenance_tasks.spawn(usage_metrics::task_main(metrics_config)); + maintenance_tasks.spawn(usage_metrics::task_main(config.dns.clone(), metrics_config)); client_tasks.spawn(usage_metrics::task_backup( &metrics_config.backup_metric_collection_config, cancellation_token.clone(), @@ -497,6 +498,8 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { bail!("dynamic rate limiter should be disabled"); } + let dns = Dns::new(); + let auth_backend = match &args.auth_backend { AuthBackend::Console => { let wake_compute_cache_config: CacheOptions = args.wake_compute_cache.parse()?; @@ -537,7 +540,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { tokio::spawn(locks.garbage_collect_worker()); let url = args.auth_endpoint.parse()?; - let endpoint = http::Endpoint::new(url, http::new_client()); + let endpoint = http::Endpoint::new(url, http::new_client(dns.clone())); let mut endpoint_rps_limit = args.endpoint_rps_limit.clone(); RateBucketInfo::validate(&mut endpoint_rps_limit)?; @@ -581,6 +584,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { RateBucketInfo::validate(&mut redis_rps_limit)?; let config = Box::leak(Box::new(ProxyConfig { + dns, tls_config, auth_backend, metric_collection, diff --git a/proxy/src/compute.rs b/proxy/src/compute.rs index 149a619316..47e7dd68a8 100644 --- a/proxy/src/compute.rs +++ b/proxy/src/compute.rs @@ -3,17 +3,21 @@ use crate::{ cancellation::CancelClosure, console::{errors::WakeComputeError, messages::MetricsAuxInfo}, context::RequestMonitoring, + dns::Dns, error::{ReportableError, UserFacingError}, metrics::{Metrics, NumDbConnectionsGuard}, proxy::neon_option, }; -use futures::{FutureExt, TryFutureExt}; +use futures::TryFutureExt; use itertools::Itertools; use pq_proto::StartupMessageParams; use std::{io, net::SocketAddr, time::Duration}; use thiserror::Error; use tokio::net::TcpStream; -use tokio_postgres::tls::MakeTlsConnect; +use tokio_postgres::{ + tls::{MakeTlsConnect, NoTlsError}, + Connection, +}; use tracing::{error, info, warn}; const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node"; @@ -33,6 +37,9 @@ pub enum ConnectionError { #[error("{COULD_NOT_CONNECT}: {0}")] WakeComputeError(#[from] WakeComputeError), + + #[error("{COULD_NOT_CONNECT}: {0}")] + TlsNotSupported(#[from] NoTlsError), } impl UserFacingError for ConnectionError { @@ -70,6 +77,7 @@ impl ReportableError for ConnectionError { ConnectionError::Postgres(_) => crate::error::ErrorKind::Compute, ConnectionError::CouldNotConnect(_) => crate::error::ErrorKind::Compute, ConnectionError::TlsError(_) => crate::error::ErrorKind::Compute, + ConnectionError::TlsNotSupported(_) => crate::error::ErrorKind::Compute, ConnectionError::WakeComputeError(e) => e.get_error_kind(), } } @@ -165,20 +173,42 @@ impl std::ops::DerefMut for ConnCfg { impl ConnCfg { /// Establish a raw TCP connection to the compute node. - async fn connect_raw(&self, timeout: Duration) -> io::Result<(SocketAddr, TcpStream, &str)> { + async fn connect_raw( + &self, + dns: &Dns, + timeout: Duration, + ) -> io::Result<(SocketAddr, TcpStream, &str)> { use tokio_postgres::config::Host; // wrap TcpStream::connect with timeout - let connect_with_timeout = |host, port| { - tokio::time::timeout(timeout, TcpStream::connect((host, port))).map( - move |res| match res { - Ok(tcpstream_connect_res) => tcpstream_connect_res, - Err(_) => Err(io::Error::new( - io::ErrorKind::TimedOut, - format!("exceeded connection timeout {timeout:?}"), - )), - }, - ) + let connect_with_timeout = |host, port| async move { + let addrs = dns + .resolve(host) + .await + .map_err(|e| io::Error::new(io::ErrorKind::Other, e))? + .collect::>(); + let timeout = timeout / addrs.len() as u32; + + let mut last_err = None; + for addr in addrs { + match tokio::time::timeout(timeout, TcpStream::connect((addr, port))).await { + Ok(Ok(stream)) => return Ok(stream), + Ok(Err(e)) => last_err = Some(e), + Err(_) => { + last_err = Some(io::Error::new( + io::ErrorKind::TimedOut, + format!("exceeded connection timeout {timeout:?}"), + )) + } + }; + } + + Err(last_err.unwrap_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidInput, + "could not resolve to any address", + ) + })) }; let connect_once = |host, port| { @@ -235,12 +265,11 @@ impl ConnCfg { } } +type TlsStream = postgres_native_tls::TlsStream; + pub struct PostgresConnection { /// Socket connected to a compute node. - pub stream: tokio_postgres::maybe_tls_stream::MaybeTlsStream< - tokio::net::TcpStream, - postgres_native_tls::TlsStream, - >, + pub stream: tokio_postgres::maybe_tls_stream::MaybeTlsStream, /// PostgreSQL connection parameters. pub params: std::collections::HashMap, /// Query cancellation token. @@ -253,26 +282,30 @@ pub struct PostgresConnection { impl ConnCfg { /// Connect to a corresponding compute node. - pub async fn connect( + pub async fn connect_managed>( &self, ctx: &mut RequestMonitoring, - allow_self_signed_compute: bool, - aux: MetricsAuxInfo, + dns: &Dns, timeout: Duration, - ) -> Result { - let (socket_addr, stream, host) = self.connect_raw(timeout).await?; + mut tls: Tls, + ) -> Result< + ( + SocketAddr, + tokio_postgres::Client, + Connection, + ), + ConnectionError, + > + where + ConnectionError: From, + { + let (socket_addr, stream, host) = self.connect_raw(dns, timeout).await?; - let tls_connector = native_tls::TlsConnector::builder() - .danger_accept_invalid_certs(allow_self_signed_compute) - .build() - .unwrap(); - let mut mk_tls = postgres_native_tls::MakeTlsConnector::new(tls_connector); - let tls = MakeTlsConnect::::make_tls_connect(&mut mk_tls, host)?; + let tls = MakeTlsConnect::::make_tls_connect(&mut tls, host)?; // connect_raw() will not use TLS if sslmode is "disable" let (client, connection) = self.0.connect_raw(stream, tls).await?; tracing::Span::current().record("pid", &tracing::field::display(client.get_process_id())); - let stream = connection.stream.into_inner(); info!( cold_start_info = ctx.cold_start_info.as_str(), @@ -280,6 +313,28 @@ impl ConnCfg { self.0.get_ssl_mode() ); + Ok((socket_addr, client, connection)) + } + + /// Connect to a corresponding compute node. + pub async fn connect( + &self, + ctx: &mut RequestMonitoring, + dns: &Dns, + allow_self_signed_compute: bool, + aux: MetricsAuxInfo, + timeout: Duration, + ) -> Result { + let tls_connector = native_tls::TlsConnector::builder() + .danger_accept_invalid_certs(allow_self_signed_compute) + .build() + .unwrap(); + let mk_tls = postgres_native_tls::MakeTlsConnector::new(tls_connector); + + let (socket_addr, client, connection) = + self.connect_managed(ctx, dns, timeout, mk_tls).await?; + let stream = connection.stream.into_inner(); + // This is very ugly but as of now there's no better way to // extract the connection parameters from tokio-postgres' connection. // TODO: solve this problem in a more elegant manner (e.g. the new library). diff --git a/proxy/src/config.rs b/proxy/src/config.rs index ae7606e5d4..2a346cf165 100644 --- a/proxy/src/config.rs +++ b/proxy/src/config.rs @@ -1,5 +1,6 @@ use crate::{ auth::{self, backend::AuthRateLimiter}, + dns::Dns, rate_limiter::RateBucketInfo, serverless::GlobalConnPoolOptions, }; @@ -21,6 +22,7 @@ use tracing::{error, info}; use x509_parser::oid_registry; pub struct ProxyConfig { + pub dns: Dns, pub tls_config: Option, pub auth_backend: auth::BackendType<'static, (), ()>, pub metric_collection: Option, diff --git a/proxy/src/console/provider.rs b/proxy/src/console/provider.rs index aa1800a9da..105dc94cf3 100644 --- a/proxy/src/console/provider.rs +++ b/proxy/src/console/provider.rs @@ -12,6 +12,7 @@ use crate::{ compute, config::{CacheOptions, EndpointCacheConfig, ProjectInfoCacheOptions}, context::RequestMonitoring, + dns::Dns, intern::ProjectIdInt, metrics::ApiLockMetrics, scram, EndpointCacheKey, @@ -302,11 +303,13 @@ impl NodeInfo { pub async fn connect( &self, ctx: &mut RequestMonitoring, + dns: &Dns, timeout: Duration, ) -> Result { self.config .connect( ctx, + dns, self.allow_self_signed_compute, self.aux.clone(), timeout, diff --git a/proxy/src/dns.rs b/proxy/src/dns.rs new file mode 100644 index 0000000000..2cf4027a61 --- /dev/null +++ b/proxy/src/dns.rs @@ -0,0 +1,55 @@ +//! Async dns resolvers + +use std::{ + net::{IpAddr, SocketAddr}, + sync::Arc, +}; + +use hickory_resolver::error::ResolveError; +use tokio::time::Instant; +use tracing::trace; + +#[derive(Clone)] +pub struct Dns { + resolver: Arc, +} + +impl Default for Dns { + fn default() -> Self { + Self::new() + } +} + +impl Dns { + pub fn new() -> Self { + let (config, options) = + hickory_resolver::system_conf::read_system_conf().expect("could not read resolv.conf"); + + let resolver = Arc::new(hickory_resolver::TokioAsyncResolver::tokio(config, options)); + + Self { resolver } + } + + pub async fn resolve(&self, name: &str) -> Result, ResolveError> { + let start = Instant::now(); + + let res = self.resolver.lookup_ip(name).await; + + let resolve_duration = start.elapsed(); + trace!(duration = ?resolve_duration, addr = %name, "resolve host complete"); + + Ok(res?.into_iter()) + } +} + +impl reqwest::dns::Resolve for Dns { + fn resolve(&self, name: hyper::client::connect::dns::Name) -> reqwest::dns::Resolving { + let this = self.clone(); + Box::pin(async move { + match this.resolve(name.as_str()).await { + Ok(iter) => Ok(Box::new(iter.map(|ip| SocketAddr::new(ip, 0))) as Box<_>), + Err(e) => Err(e.into()), + } + }) + } +} diff --git a/proxy/src/http.rs b/proxy/src/http.rs index e20488e23c..c6c017573f 100644 --- a/proxy/src/http.rs +++ b/proxy/src/http.rs @@ -14,6 +14,7 @@ use tokio::time::Instant; use tracing::trace; use crate::{ + dns::Dns, metrics::{ConsoleRequest, Metrics}, url::ApiUrl, }; @@ -22,9 +23,9 @@ use reqwest_middleware::RequestBuilder; /// This is the preferred way to create new http clients, /// because it takes care of observability (OpenTelemetry). /// We deliberately don't want to replace this with a public static. -pub fn new_client() -> ClientWithMiddleware { +pub fn new_client(dns: Dns) -> ClientWithMiddleware { let client = reqwest::ClientBuilder::new() - .dns_resolver(Arc::new(GaiResolver::default())) + .dns_resolver(Arc::new(dns)) .connection_verbose(true) .build() .expect("Failed to create http client"); @@ -34,9 +35,9 @@ pub fn new_client() -> ClientWithMiddleware { .build() } -pub fn new_client_with_timeout(default_timout: Duration) -> ClientWithMiddleware { +pub fn new_client_with_timeout(dns: Dns, default_timout: Duration) -> ClientWithMiddleware { let timeout_client = reqwest::ClientBuilder::new() - .dns_resolver(Arc::new(GaiResolver::default())) + .dns_resolver(Arc::new(dns)) .connection_verbose(true) .timeout(default_timout) .build() diff --git a/proxy/src/lib.rs b/proxy/src/lib.rs index 3f6d985fe8..608d191d27 100644 --- a/proxy/src/lib.rs +++ b/proxy/src/lib.rs @@ -14,6 +14,7 @@ pub mod compute; pub mod config; pub mod console; pub mod context; +pub mod dns; pub mod error; pub mod http; pub mod intern; diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index a4554eef38..c8f51ab5c0 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -307,6 +307,7 @@ pub async fn handle_client( ctx, &TcpMechanism { params: ¶ms }, &user_info, + &config.dns, mode.allow_self_signed_compute(config), config.wake_compute_retry_config, config.connect_to_compute_retry_config, diff --git a/proxy/src/proxy/connect_compute.rs b/proxy/src/proxy/connect_compute.rs index 8a220aaa0c..bf54b3336a 100644 --- a/proxy/src/proxy/connect_compute.rs +++ b/proxy/src/proxy/connect_compute.rs @@ -4,6 +4,7 @@ use crate::{ config::RetryConfig, console::{self, errors::WakeComputeError, CachedNodeInfo, NodeInfo}, context::RequestMonitoring, + dns::Dns, error::ReportableError, metrics::{ConnectOutcome, ConnectionFailureKind, Metrics, RetriesMetricGroup, RetryType}, proxy::{ @@ -44,6 +45,7 @@ pub trait ConnectMechanism { async fn connect_once( &self, ctx: &mut RequestMonitoring, + dns: &Dns, node_info: &console::CachedNodeInfo, timeout: time::Duration, ) -> Result; @@ -76,10 +78,11 @@ impl ConnectMechanism for TcpMechanism<'_> { async fn connect_once( &self, ctx: &mut RequestMonitoring, + dns: &Dns, node_info: &console::CachedNodeInfo, timeout: time::Duration, ) -> Result { - node_info.connect(ctx, timeout).await + node_info.connect(ctx, dns, timeout).await } fn update_connect_config(&self, config: &mut compute::ConnCfg) { @@ -93,6 +96,7 @@ pub async fn connect_to_compute( ctx: &mut RequestMonitoring, mechanism: &M, user_info: &B, + dns: &Dns, allow_self_signed_compute: bool, wake_compute_retry_config: RetryConfig, connect_to_compute_retry_config: RetryConfig, @@ -114,7 +118,7 @@ where // try once let err = match mechanism - .connect_once(ctx, &node_info, CONNECT_TIMEOUT) + .connect_once(ctx, dns, &node_info, CONNECT_TIMEOUT) .await { Ok(res) => { @@ -159,7 +163,7 @@ where num_retries = 1; loop { match mechanism - .connect_once(ctx, &node_info, CONNECT_TIMEOUT) + .connect_once(ctx, dns, &node_info, CONNECT_TIMEOUT) .await { Ok(res) => { diff --git a/proxy/src/proxy/tests.rs b/proxy/src/proxy/tests.rs index e0ec90cb44..ed19bc2fa4 100644 --- a/proxy/src/proxy/tests.rs +++ b/proxy/src/proxy/tests.rs @@ -15,6 +15,7 @@ use crate::console::caches::NodeInfoCache; use crate::console::messages::MetricsAuxInfo; use crate::console::provider::{CachedAllowedIps, CachedRoleSecret, ConsoleBackend}; use crate::console::{self, CachedNodeInfo, NodeInfo}; +use crate::dns::Dns; use crate::error::ErrorKind; use crate::proxy::retry::retry_after; use crate::{http, sasl, scram, BranchId, EndpointId, ProjectId}; @@ -453,6 +454,7 @@ impl ConnectMechanism for TestConnectMechanism { async fn connect_once( &self, _ctx: &mut RequestMonitoring, + _dns: &Dns, _node_info: &console::CachedNodeInfo, _timeout: std::time::Duration, ) -> Result { @@ -558,9 +560,17 @@ async fn connect_to_compute_success() { max_retries: 5, backoff_factor: 2.0, }; - connect_to_compute(&mut ctx, &mechanism, &user_info, false, config, config) - .await - .unwrap(); + connect_to_compute( + &mut ctx, + &mechanism, + &user_info, + &Dns::new(), + false, + config, + config, + ) + .await + .unwrap(); mechanism.verify(); } @@ -576,9 +586,17 @@ async fn connect_to_compute_retry() { max_retries: 5, backoff_factor: 2.0, }; - connect_to_compute(&mut ctx, &mechanism, &user_info, false, config, config) - .await - .unwrap(); + connect_to_compute( + &mut ctx, + &mechanism, + &user_info, + &Dns::new(), + false, + config, + config, + ) + .await + .unwrap(); mechanism.verify(); } @@ -595,9 +613,17 @@ async fn connect_to_compute_non_retry_1() { max_retries: 5, backoff_factor: 2.0, }; - connect_to_compute(&mut ctx, &mechanism, &user_info, false, config, config) - .await - .unwrap_err(); + connect_to_compute( + &mut ctx, + &mechanism, + &user_info, + &Dns::new(), + false, + config, + config, + ) + .await + .unwrap_err(); mechanism.verify(); } @@ -614,9 +640,17 @@ async fn connect_to_compute_non_retry_2() { max_retries: 5, backoff_factor: 2.0, }; - connect_to_compute(&mut ctx, &mechanism, &user_info, false, config, config) - .await - .unwrap(); + connect_to_compute( + &mut ctx, + &mechanism, + &user_info, + &Dns::new(), + false, + config, + config, + ) + .await + .unwrap(); mechanism.verify(); } @@ -644,6 +678,7 @@ async fn connect_to_compute_non_retry_3() { &mut ctx, &mechanism, &user_info, + &Dns::new(), false, wake_compute_retry_config, connect_to_compute_retry_config, @@ -666,9 +701,17 @@ async fn wake_retry() { max_retries: 5, backoff_factor: 2.0, }; - connect_to_compute(&mut ctx, &mechanism, &user_info, false, config, config) - .await - .unwrap(); + connect_to_compute( + &mut ctx, + &mechanism, + &user_info, + &Dns::new(), + false, + config, + config, + ) + .await + .unwrap(); mechanism.verify(); } @@ -685,8 +728,16 @@ async fn wake_non_retry() { max_retries: 5, backoff_factor: 2.0, }; - connect_to_compute(&mut ctx, &mechanism, &user_info, false, config, config) - .await - .unwrap_err(); + connect_to_compute( + &mut ctx, + &mechanism, + &user_info, + &Dns::new(), + false, + config, + config, + ) + .await + .unwrap_err(); mechanism.verify(); } diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index b91c0e62ed..dbe9df3543 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -1,17 +1,19 @@ use std::{sync::Arc, time::Duration}; use async_trait::async_trait; +use tokio_postgres::NoTls; use tracing::{field::display, info}; use crate::{ auth::{backend::ComputeCredentials, check_peer_addr_is_in_list, AuthError}, - compute, + compute::{self, ConnectionError}, config::{AuthenticationConfig, ProxyConfig}, console::{ errors::{GetAuthInfoError, WakeComputeError}, CachedNodeInfo, }, context::RequestMonitoring, + dns::Dns, error::{ErrorKind, ReportableError, UserFacingError}, proxy::connect_compute::ConnectMechanism, }; @@ -107,6 +109,7 @@ impl PoolingBackend { pool: self.pool.clone(), }, &backend, + &self.config.dns, false, // do not allow self signed compute for http flow self.config.wake_compute_retry_config, self.config.connect_to_compute_retry_config, @@ -120,7 +123,7 @@ pub enum HttpConnError { #[error("pooled connection closed at inconsistent state")] ConnectionClosedAbruptly(#[from] tokio::sync::watch::error::SendError), #[error("could not connection to compute")] - ConnectionError(#[from] tokio_postgres::Error), + ConnectionError(#[from] ConnectionError), #[error("could not get auth info")] GetAuthInfo(#[from] GetAuthInfoError), @@ -163,23 +166,24 @@ struct TokioMechanism { #[async_trait] impl ConnectMechanism for TokioMechanism { type Connection = Client; - type ConnectError = tokio_postgres::Error; + type ConnectError = ConnectionError; type Error = HttpConnError; async fn connect_once( &self, ctx: &mut RequestMonitoring, + dns: &Dns, node_info: &CachedNodeInfo, timeout: Duration, - ) -> Result { - let mut config = (*node_info.config).clone(); - let config = config + ) -> Result { + let mut config = node_info.config.clone(); + config .user(&self.conn_info.user_info.user) .password(&*self.conn_info.password) .dbname(&self.conn_info.dbname) .connect_timeout(timeout); - let (client, connection) = config.connect(tokio_postgres::NoTls).await?; + let (_, client, connection) = config.connect_managed(ctx, dns, timeout, NoTls).await?; tracing::Span::current().record("pid", &tracing::field::display(client.get_process_id())); Ok(poll_client( diff --git a/proxy/src/serverless/conn_pool.rs b/proxy/src/serverless/conn_pool.rs index 798e488509..dc522a7944 100644 --- a/proxy/src/serverless/conn_pool.rs +++ b/proxy/src/serverless/conn_pool.rs @@ -12,9 +12,10 @@ use std::{ ops::Deref, sync::atomic::{self, AtomicUsize}, }; +use tokio::net::TcpStream; use tokio::time::Instant; use tokio_postgres::tls::NoTlsStream; -use tokio_postgres::{AsyncMessage, ReadyForQueryStatus, Socket}; +use tokio_postgres::{AsyncMessage, ReadyForQueryStatus}; use tokio_util::sync::CancellationToken; use crate::console::messages::{ColdStartInfo, MetricsAuxInfo}; @@ -468,7 +469,7 @@ pub fn poll_client( ctx: &mut RequestMonitoring, conn_info: ConnInfo, client: C, - mut connection: tokio_postgres::Connection, + mut connection: tokio_postgres::Connection, conn_id: uuid::Uuid, aux: MetricsAuxInfo, ) -> Client { diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index e856053a7e..0651c3ac40 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -37,6 +37,7 @@ use utils::http::error::ApiError; use crate::auth::backend::ComputeUserInfo; use crate::auth::endpoint_sni; use crate::auth::ComputeUserInfoParseError; +use crate::compute::ConnectionError; use crate::config::ProxyConfig; use crate::config::TlsConfig; use crate::context::RequestMonitoring; @@ -257,7 +258,9 @@ pub async fn handle( let mut message = e.to_string_client(); let db_error = match &e { - SqlOverHttpError::ConnectCompute(HttpConnError::ConnectionError(e)) + SqlOverHttpError::ConnectCompute(HttpConnError::ConnectionError( + ConnectionError::Postgres(e), + )) | SqlOverHttpError::Postgres(e) => e.as_db_error(), _ => None, }; @@ -661,7 +664,9 @@ impl QueryData { // query failed or was cancelled. Ok(Err(error)) => { let db_error = match &error { - SqlOverHttpError::ConnectCompute(HttpConnError::ConnectionError(e)) + SqlOverHttpError::ConnectCompute(HttpConnError::ConnectionError( + ConnectionError::Postgres(e), + )) | SqlOverHttpError::Postgres(e) => e.as_db_error(), _ => None, }; diff --git a/proxy/src/usage_metrics.rs b/proxy/src/usage_metrics.rs index 56ed2145dc..11c870b736 100644 --- a/proxy/src/usage_metrics.rs +++ b/proxy/src/usage_metrics.rs @@ -3,6 +3,7 @@ use crate::{ config::{MetricBackupCollectionConfig, MetricCollectionConfig}, context::parquet::{FAILED_UPLOAD_MAX_RETRIES, FAILED_UPLOAD_WARN_THRESHOLD}, + dns::Dns, http, intern::{BranchIdInt, EndpointIdInt}, }; @@ -217,13 +218,13 @@ impl Metrics { pub static USAGE_METRICS: Lazy = Lazy::new(Metrics::default); -pub async fn task_main(config: &MetricCollectionConfig) -> anyhow::Result { +pub async fn task_main(dns: Dns, config: &MetricCollectionConfig) -> anyhow::Result { info!("metrics collector config: {config:?}"); scopeguard::defer! { info!("metrics collector has shut down"); } - let http_client = http::new_client_with_timeout(DEFAULT_HTTP_REPORTING_TIMEOUT); + let http_client = http::new_client_with_timeout(dns, DEFAULT_HTTP_REPORTING_TIMEOUT); let hostname = hostname::get()?.as_os_str().to_string_lossy().into_owned(); let mut prev = Utc::now(); @@ -495,7 +496,7 @@ mod tests { use url::Url; use super::*; - use crate::{http, BranchId, EndpointId}; + use crate::{dns::Dns, http, BranchId, EndpointId}; #[tokio::test] async fn metrics() { @@ -525,7 +526,7 @@ mod tests { tokio::spawn(server); let metrics = Metrics::default(); - let client = http::new_client(); + let client = http::new_client(Dns::new()); let endpoint = Url::parse(&format!("http://{addr}")).unwrap(); let now = Utc::now();