proxy: async dns resolver

This commit is contained in:
Conrad Ludgate
2024-04-23 16:51:38 +01:00
parent 5dda371c2b
commit c003b43781
16 changed files with 386 additions and 80 deletions

135
Cargo.lock generated
View File

@@ -1780,6 +1780,18 @@ dependencies = [
"cfg-if",
]
[[package]]
name = "enum-as-inner"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5ffccbb6966c05b32ef8fbac435df276c4ae4d3dc55a8cd0eb9745e6c12f546a"
dependencies = [
"heck 0.4.1",
"proc-macro2",
"quote",
"syn 2.0.52",
]
[[package]]
name = "enum-map"
version = "2.5.0"
@@ -1971,9 +1983,9 @@ checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b"
[[package]]
name = "form_urlencoded"
version = "1.1.0"
version = "1.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a9c384f161156f5260c24a097c56119f9be8c798586aecc13afbcbe7b7e26bf8"
checksum = "e13624c2627564efccf4934284bdd98cbaa14e79b0b5a141218e507b3a823456"
dependencies = [
"percent-encoding",
]
@@ -2332,6 +2344,51 @@ version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6fe2267d4ed49bc07b63801559be28c718ea06c4738b7a03c94df7386d2cde46"
[[package]]
name = "hickory-proto"
version = "0.24.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "07698b8420e2f0d6447a436ba999ec85d8fbf2a398bbd737b82cac4a2e96e512"
dependencies = [
"async-trait",
"cfg-if",
"data-encoding",
"enum-as-inner",
"futures-channel",
"futures-io",
"futures-util",
"idna 0.4.0",
"ipnet",
"once_cell",
"rand 0.8.5",
"thiserror",
"tinyvec",
"tokio",
"tracing",
"url",
]
[[package]]
name = "hickory-resolver"
version = "0.24.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "28757f23aa75c98f254cf0405e6d8c25b831b32921b050a66692427679b1f243"
dependencies = [
"cfg-if",
"futures-util",
"hickory-proto",
"ipconfig",
"lru-cache",
"once_cell",
"parking_lot 0.12.1",
"rand 0.8.5",
"resolv-conf",
"smallvec",
"thiserror",
"tokio",
"tracing",
]
[[package]]
name = "histogram"
version = "0.7.4"
@@ -2612,9 +2669,19 @@ checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39"
[[package]]
name = "idna"
version = "0.3.0"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e14ddfc70884202db2244c223200c204c2bda1bc6e0998d11b5e024d657209e6"
checksum = "7d20d6b07bfbc108882d88ed8e37d39636dcc260e15e30c45e6ba089610b917c"
dependencies = [
"unicode-bidi",
"unicode-normalization",
]
[[package]]
name = "idna"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "634d9b1461af396cad843f47fdba5597a4f9e6ddd4bfb6ff5d85028c25cb12f6"
dependencies = [
"unicode-bidi",
"unicode-normalization",
@@ -2719,6 +2786,18 @@ dependencies = [
"libc",
]
[[package]]
name = "ipconfig"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b58db92f96b720de98181bbbe63c831e87005ab460c1bf306eb2622b4707997f"
dependencies = [
"socket2 0.5.5",
"widestring",
"windows-sys 0.48.0",
"winreg",
]
[[package]]
name = "ipnet"
version = "2.9.0"
@@ -2860,6 +2939,12 @@ version = "0.2.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058"
[[package]]
name = "linked-hash-map"
version = "0.5.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0717cef1bc8b636c6e1c1bbdefc09e6322da8a9321966e8928ef80d20f7f770f"
[[package]]
name = "linux-raw-sys"
version = "0.1.4"
@@ -2894,6 +2979,15 @@ version = "0.4.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f"
[[package]]
name = "lru-cache"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "31e24f1ad8321ca0e8a1e0ac13f23cb668e6f5466c2c57319f6a5cf1cc8e3b1c"
dependencies = [
"linked-hash-map",
]
[[package]]
name = "match_cfg"
version = "0.1.0"
@@ -3889,9 +3983,9 @@ dependencies = [
[[package]]
name = "percent-encoding"
version = "2.2.0"
version = "2.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "478c572c3d73181ff3c2539045f6eb99e5491218eae919370993b890cdbdd98e"
checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e"
[[package]]
name = "petgraph"
@@ -4310,6 +4404,7 @@ dependencies = [
"hashbrown 0.13.2",
"hashlink",
"hex",
"hickory-resolver",
"hmac",
"hostname",
"http 1.1.0",
@@ -4385,6 +4480,12 @@ dependencies = [
"x509-parser",
]
[[package]]
name = "quick-error"
version = "1.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0"
[[package]]
name = "quick-xml"
version = "0.31.0"
@@ -4772,6 +4873,16 @@ dependencies = [
"tracing-opentelemetry",
]
[[package]]
name = "resolv-conf"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "52e44394d2086d010551b14b53b1f24e31647570cd1deb0379e2c21b329aba00"
dependencies = [
"hostname",
"quick-error",
]
[[package]]
name = "retry-policies"
version = "0.1.2"
@@ -6695,12 +6806,12 @@ dependencies = [
[[package]]
name = "url"
version = "2.3.1"
version = "2.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0d68c799ae75762b8c3fe375feb6600ef5602c883c5d21eb51c09f22b83c4643"
checksum = "31e6302e3bb753d46e83516cae55ae196fc0c309407cf11ab35cc51a4c2a4633"
dependencies = [
"form_urlencoded",
"idna",
"idna 0.5.0",
"percent-encoding",
"serde",
]
@@ -7032,6 +7143,12 @@ dependencies = [
"once_cell",
]
[[package]]
name = "widestring"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "653f141f39ec16bba3c5abe400a0c60da7468261cc2cbf36805022876bc721a8"
[[package]]
name = "winapi"
version = "0.3.9"

View File

@@ -31,6 +31,7 @@ git-version.workspace = true
hashbrown.workspace = true
hashlink.workspace = true
hex.workspace = true
hickory-resolver = "0.24.1"
hmac.workspace = true
hostname.workspace = true
http.workspace = true

View File

@@ -18,6 +18,7 @@ use proxy::config::HttpConfig;
use proxy::config::ProjectInfoCacheOptions;
use proxy::console;
use proxy::context::parquet::ParquetUploadArgs;
use proxy::dns::Dns;
use proxy::http;
use proxy::http::health_server::AppMetrics;
use proxy::metrics::Metrics;
@@ -400,7 +401,7 @@ async fn main() -> anyhow::Result<()> {
if let Some(metrics_config) = &config.metric_collection {
// TODO: Add gc regardless of the metric collection being enabled.
maintenance_tasks.spawn(usage_metrics::task_main(metrics_config));
maintenance_tasks.spawn(usage_metrics::task_main(config.dns.clone(), metrics_config));
client_tasks.spawn(usage_metrics::task_backup(
&metrics_config.backup_metric_collection_config,
cancellation_token.clone(),
@@ -497,6 +498,8 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
bail!("dynamic rate limiter should be disabled");
}
let dns = Dns::new();
let auth_backend = match &args.auth_backend {
AuthBackend::Console => {
let wake_compute_cache_config: CacheOptions = args.wake_compute_cache.parse()?;
@@ -537,7 +540,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
tokio::spawn(locks.garbage_collect_worker());
let url = args.auth_endpoint.parse()?;
let endpoint = http::Endpoint::new(url, http::new_client());
let endpoint = http::Endpoint::new(url, http::new_client(dns.clone()));
let mut endpoint_rps_limit = args.endpoint_rps_limit.clone();
RateBucketInfo::validate(&mut endpoint_rps_limit)?;
@@ -581,6 +584,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
RateBucketInfo::validate(&mut redis_rps_limit)?;
let config = Box::leak(Box::new(ProxyConfig {
dns,
tls_config,
auth_backend,
metric_collection,

View File

@@ -3,17 +3,21 @@ use crate::{
cancellation::CancelClosure,
console::{errors::WakeComputeError, messages::MetricsAuxInfo},
context::RequestMonitoring,
dns::Dns,
error::{ReportableError, UserFacingError},
metrics::{Metrics, NumDbConnectionsGuard},
proxy::neon_option,
};
use futures::{FutureExt, TryFutureExt};
use futures::TryFutureExt;
use itertools::Itertools;
use pq_proto::StartupMessageParams;
use std::{io, net::SocketAddr, time::Duration};
use thiserror::Error;
use tokio::net::TcpStream;
use tokio_postgres::tls::MakeTlsConnect;
use tokio_postgres::{
tls::{MakeTlsConnect, NoTlsError},
Connection,
};
use tracing::{error, info, warn};
const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node";
@@ -33,6 +37,9 @@ pub enum ConnectionError {
#[error("{COULD_NOT_CONNECT}: {0}")]
WakeComputeError(#[from] WakeComputeError),
#[error("{COULD_NOT_CONNECT}: {0}")]
TlsNotSupported(#[from] NoTlsError),
}
impl UserFacingError for ConnectionError {
@@ -70,6 +77,7 @@ impl ReportableError for ConnectionError {
ConnectionError::Postgres(_) => crate::error::ErrorKind::Compute,
ConnectionError::CouldNotConnect(_) => crate::error::ErrorKind::Compute,
ConnectionError::TlsError(_) => crate::error::ErrorKind::Compute,
ConnectionError::TlsNotSupported(_) => crate::error::ErrorKind::Compute,
ConnectionError::WakeComputeError(e) => e.get_error_kind(),
}
}
@@ -165,20 +173,42 @@ impl std::ops::DerefMut for ConnCfg {
impl ConnCfg {
/// Establish a raw TCP connection to the compute node.
async fn connect_raw(&self, timeout: Duration) -> io::Result<(SocketAddr, TcpStream, &str)> {
async fn connect_raw(
&self,
dns: &Dns,
timeout: Duration,
) -> io::Result<(SocketAddr, TcpStream, &str)> {
use tokio_postgres::config::Host;
// wrap TcpStream::connect with timeout
let connect_with_timeout = |host, port| {
tokio::time::timeout(timeout, TcpStream::connect((host, port))).map(
move |res| match res {
Ok(tcpstream_connect_res) => tcpstream_connect_res,
Err(_) => Err(io::Error::new(
io::ErrorKind::TimedOut,
format!("exceeded connection timeout {timeout:?}"),
)),
},
)
let connect_with_timeout = |host, port| async move {
let addrs = dns
.resolve(host)
.await
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?
.collect::<Vec<_>>();
let timeout = timeout / addrs.len() as u32;
let mut last_err = None;
for addr in addrs {
match tokio::time::timeout(timeout, TcpStream::connect((addr, port))).await {
Ok(Ok(stream)) => return Ok(stream),
Ok(Err(e)) => last_err = Some(e),
Err(_) => {
last_err = Some(io::Error::new(
io::ErrorKind::TimedOut,
format!("exceeded connection timeout {timeout:?}"),
))
}
};
}
Err(last_err.unwrap_or_else(|| {
io::Error::new(
io::ErrorKind::InvalidInput,
"could not resolve to any address",
)
}))
};
let connect_once = |host, port| {
@@ -235,12 +265,11 @@ impl ConnCfg {
}
}
type TlsStream = postgres_native_tls::TlsStream<TcpStream>;
pub struct PostgresConnection {
/// Socket connected to a compute node.
pub stream: tokio_postgres::maybe_tls_stream::MaybeTlsStream<
tokio::net::TcpStream,
postgres_native_tls::TlsStream<tokio::net::TcpStream>,
>,
pub stream: tokio_postgres::maybe_tls_stream::MaybeTlsStream<TcpStream, TlsStream>,
/// PostgreSQL connection parameters.
pub params: std::collections::HashMap<String, String>,
/// Query cancellation token.
@@ -253,26 +282,30 @@ pub struct PostgresConnection {
impl ConnCfg {
/// Connect to a corresponding compute node.
pub async fn connect(
pub async fn connect_managed<Tls: MakeTlsConnect<TcpStream>>(
&self,
ctx: &mut RequestMonitoring,
allow_self_signed_compute: bool,
aux: MetricsAuxInfo,
dns: &Dns,
timeout: Duration,
) -> Result<PostgresConnection, ConnectionError> {
let (socket_addr, stream, host) = self.connect_raw(timeout).await?;
mut tls: Tls,
) -> Result<
(
SocketAddr,
tokio_postgres::Client,
Connection<TcpStream, Tls::Stream>,
),
ConnectionError,
>
where
ConnectionError: From<Tls::Error>,
{
let (socket_addr, stream, host) = self.connect_raw(dns, timeout).await?;
let tls_connector = native_tls::TlsConnector::builder()
.danger_accept_invalid_certs(allow_self_signed_compute)
.build()
.unwrap();
let mut mk_tls = postgres_native_tls::MakeTlsConnector::new(tls_connector);
let tls = MakeTlsConnect::<tokio::net::TcpStream>::make_tls_connect(&mut mk_tls, host)?;
let tls = MakeTlsConnect::<TcpStream>::make_tls_connect(&mut tls, host)?;
// connect_raw() will not use TLS if sslmode is "disable"
let (client, connection) = self.0.connect_raw(stream, tls).await?;
tracing::Span::current().record("pid", &tracing::field::display(client.get_process_id()));
let stream = connection.stream.into_inner();
info!(
cold_start_info = ctx.cold_start_info.as_str(),
@@ -280,6 +313,28 @@ impl ConnCfg {
self.0.get_ssl_mode()
);
Ok((socket_addr, client, connection))
}
/// Connect to a corresponding compute node.
pub async fn connect(
&self,
ctx: &mut RequestMonitoring,
dns: &Dns,
allow_self_signed_compute: bool,
aux: MetricsAuxInfo,
timeout: Duration,
) -> Result<PostgresConnection, ConnectionError> {
let tls_connector = native_tls::TlsConnector::builder()
.danger_accept_invalid_certs(allow_self_signed_compute)
.build()
.unwrap();
let mk_tls = postgres_native_tls::MakeTlsConnector::new(tls_connector);
let (socket_addr, client, connection) =
self.connect_managed(ctx, dns, timeout, mk_tls).await?;
let stream = connection.stream.into_inner();
// This is very ugly but as of now there's no better way to
// extract the connection parameters from tokio-postgres' connection.
// TODO: solve this problem in a more elegant manner (e.g. the new library).

View File

@@ -1,5 +1,6 @@
use crate::{
auth::{self, backend::AuthRateLimiter},
dns::Dns,
rate_limiter::RateBucketInfo,
serverless::GlobalConnPoolOptions,
};
@@ -21,6 +22,7 @@ use tracing::{error, info};
use x509_parser::oid_registry;
pub struct ProxyConfig {
pub dns: Dns,
pub tls_config: Option<TlsConfig>,
pub auth_backend: auth::BackendType<'static, (), ()>,
pub metric_collection: Option<MetricCollectionConfig>,

View File

@@ -12,6 +12,7 @@ use crate::{
compute,
config::{CacheOptions, EndpointCacheConfig, ProjectInfoCacheOptions},
context::RequestMonitoring,
dns::Dns,
intern::ProjectIdInt,
metrics::ApiLockMetrics,
scram, EndpointCacheKey,
@@ -302,11 +303,13 @@ impl NodeInfo {
pub async fn connect(
&self,
ctx: &mut RequestMonitoring,
dns: &Dns,
timeout: Duration,
) -> Result<compute::PostgresConnection, compute::ConnectionError> {
self.config
.connect(
ctx,
dns,
self.allow_self_signed_compute,
self.aux.clone(),
timeout,

55
proxy/src/dns.rs Normal file
View File

@@ -0,0 +1,55 @@
//! Async dns resolvers
use std::{
net::{IpAddr, SocketAddr},
sync::Arc,
};
use hickory_resolver::error::ResolveError;
use tokio::time::Instant;
use tracing::trace;
/// Handle to an asynchronous DNS resolver.
///
/// The resolver is kept behind an [`Arc`], so cloning a `Dns` produces another
/// handle to the same underlying resolver rather than a new resolver.
#[derive(Clone)]
pub struct Dns {
/// Shared `hickory_resolver` Tokio resolver that performs the lookups.
resolver: Arc<hickory_resolver::TokioAsyncResolver>,
}
impl Default for Dns {
fn default() -> Self {
Self::new()
}
}
impl Dns {
    /// Builds a resolver from the system DNS configuration.
    ///
    /// # Panics
    ///
    /// Panics if the system configuration cannot be read.
    pub fn new() -> Self {
        let (resolver_config, resolver_opts) = hickory_resolver::system_conf::read_system_conf()
            .expect("could not read resolv.conf");
        let tokio_resolver =
            hickory_resolver::TokioAsyncResolver::tokio(resolver_config, resolver_opts);
        Self {
            resolver: Arc::new(tokio_resolver),
        }
    }

    /// Looks up the IP addresses of `name`.
    ///
    /// A trace event carrying the lookup duration is emitted whether or not
    /// the lookup succeeded.
    ///
    /// # Errors
    ///
    /// Returns the [`ResolveError`] reported by the underlying resolver.
    pub async fn resolve(&self, name: &str) -> Result<impl Iterator<Item = IpAddr>, ResolveError> {
        let started = Instant::now();
        let lookup = self.resolver.lookup_ip(name).await;
        let elapsed = started.elapsed();

        // Log before propagating the result so failed lookups are timed too.
        trace!(duration = ?elapsed, addr = %name, "resolve host complete");

        lookup.map(IntoIterator::into_iter)
    }
}
/// Lets a [`Dns`] handle be plugged into `reqwest` as its DNS resolver.
impl reqwest::dns::Resolve for Dns {
    fn resolve(&self, name: hyper::client::connect::dns::Name) -> reqwest::dns::Resolving {
        // The returned future must own its resolver handle, so clone it up front.
        let this = self.clone();
        Box::pin(async move {
            // Delegates to the inherent `Dns::resolve`; `?` boxes the resolve error.
            let ips = this.resolve(name.as_str()).await?;
            // Only IP addresses are known at this point, so each socket address
            // is given port 0.
            let addrs = ips.map(|ip| SocketAddr::new(ip, 0));
            Ok(Box::new(addrs) as Box<_>)
        })
    }
}

View File

@@ -14,6 +14,7 @@ use tokio::time::Instant;
use tracing::trace;
use crate::{
dns::Dns,
metrics::{ConsoleRequest, Metrics},
url::ApiUrl,
};
@@ -22,9 +23,9 @@ use reqwest_middleware::RequestBuilder;
/// This is the preferred way to create new http clients,
/// because it takes care of observability (OpenTelemetry).
/// We deliberately don't want to replace this with a public static.
pub fn new_client() -> ClientWithMiddleware {
pub fn new_client(dns: Dns) -> ClientWithMiddleware {
let client = reqwest::ClientBuilder::new()
.dns_resolver(Arc::new(GaiResolver::default()))
.dns_resolver(Arc::new(dns))
.connection_verbose(true)
.build()
.expect("Failed to create http client");
@@ -34,9 +35,9 @@ pub fn new_client() -> ClientWithMiddleware {
.build()
}
pub fn new_client_with_timeout(default_timout: Duration) -> ClientWithMiddleware {
pub fn new_client_with_timeout(dns: Dns, default_timout: Duration) -> ClientWithMiddleware {
let timeout_client = reqwest::ClientBuilder::new()
.dns_resolver(Arc::new(GaiResolver::default()))
.dns_resolver(Arc::new(dns))
.connection_verbose(true)
.timeout(default_timout)
.build()

View File

@@ -14,6 +14,7 @@ pub mod compute;
pub mod config;
pub mod console;
pub mod context;
pub mod dns;
pub mod error;
pub mod http;
pub mod intern;

View File

@@ -307,6 +307,7 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
ctx,
&TcpMechanism { params: &params },
&user_info,
&config.dns,
mode.allow_self_signed_compute(config),
config.wake_compute_retry_config,
config.connect_to_compute_retry_config,

View File

@@ -4,6 +4,7 @@ use crate::{
config::RetryConfig,
console::{self, errors::WakeComputeError, CachedNodeInfo, NodeInfo},
context::RequestMonitoring,
dns::Dns,
error::ReportableError,
metrics::{ConnectOutcome, ConnectionFailureKind, Metrics, RetriesMetricGroup, RetryType},
proxy::{
@@ -44,6 +45,7 @@ pub trait ConnectMechanism {
async fn connect_once(
&self,
ctx: &mut RequestMonitoring,
dns: &Dns,
node_info: &console::CachedNodeInfo,
timeout: time::Duration,
) -> Result<Self::Connection, Self::ConnectError>;
@@ -76,10 +78,11 @@ impl ConnectMechanism for TcpMechanism<'_> {
async fn connect_once(
&self,
ctx: &mut RequestMonitoring,
dns: &Dns,
node_info: &console::CachedNodeInfo,
timeout: time::Duration,
) -> Result<PostgresConnection, Self::Error> {
node_info.connect(ctx, timeout).await
node_info.connect(ctx, dns, timeout).await
}
fn update_connect_config(&self, config: &mut compute::ConnCfg) {
@@ -93,6 +96,7 @@ pub async fn connect_to_compute<M: ConnectMechanism, B: ComputeConnectBackend>(
ctx: &mut RequestMonitoring,
mechanism: &M,
user_info: &B,
dns: &Dns,
allow_self_signed_compute: bool,
wake_compute_retry_config: RetryConfig,
connect_to_compute_retry_config: RetryConfig,
@@ -114,7 +118,7 @@ where
// try once
let err = match mechanism
.connect_once(ctx, &node_info, CONNECT_TIMEOUT)
.connect_once(ctx, dns, &node_info, CONNECT_TIMEOUT)
.await
{
Ok(res) => {
@@ -159,7 +163,7 @@ where
num_retries = 1;
loop {
match mechanism
.connect_once(ctx, &node_info, CONNECT_TIMEOUT)
.connect_once(ctx, dns, &node_info, CONNECT_TIMEOUT)
.await
{
Ok(res) => {

View File

@@ -15,6 +15,7 @@ use crate::console::caches::NodeInfoCache;
use crate::console::messages::MetricsAuxInfo;
use crate::console::provider::{CachedAllowedIps, CachedRoleSecret, ConsoleBackend};
use crate::console::{self, CachedNodeInfo, NodeInfo};
use crate::dns::Dns;
use crate::error::ErrorKind;
use crate::proxy::retry::retry_after;
use crate::{http, sasl, scram, BranchId, EndpointId, ProjectId};
@@ -453,6 +454,7 @@ impl ConnectMechanism for TestConnectMechanism {
async fn connect_once(
&self,
_ctx: &mut RequestMonitoring,
_dns: &Dns,
_node_info: &console::CachedNodeInfo,
_timeout: std::time::Duration,
) -> Result<Self::Connection, Self::ConnectError> {
@@ -558,9 +560,17 @@ async fn connect_to_compute_success() {
max_retries: 5,
backoff_factor: 2.0,
};
connect_to_compute(&mut ctx, &mechanism, &user_info, false, config, config)
.await
.unwrap();
connect_to_compute(
&mut ctx,
&mechanism,
&user_info,
&Dns::new(),
false,
config,
config,
)
.await
.unwrap();
mechanism.verify();
}
@@ -576,9 +586,17 @@ async fn connect_to_compute_retry() {
max_retries: 5,
backoff_factor: 2.0,
};
connect_to_compute(&mut ctx, &mechanism, &user_info, false, config, config)
.await
.unwrap();
connect_to_compute(
&mut ctx,
&mechanism,
&user_info,
&Dns::new(),
false,
config,
config,
)
.await
.unwrap();
mechanism.verify();
}
@@ -595,9 +613,17 @@ async fn connect_to_compute_non_retry_1() {
max_retries: 5,
backoff_factor: 2.0,
};
connect_to_compute(&mut ctx, &mechanism, &user_info, false, config, config)
.await
.unwrap_err();
connect_to_compute(
&mut ctx,
&mechanism,
&user_info,
&Dns::new(),
false,
config,
config,
)
.await
.unwrap_err();
mechanism.verify();
}
@@ -614,9 +640,17 @@ async fn connect_to_compute_non_retry_2() {
max_retries: 5,
backoff_factor: 2.0,
};
connect_to_compute(&mut ctx, &mechanism, &user_info, false, config, config)
.await
.unwrap();
connect_to_compute(
&mut ctx,
&mechanism,
&user_info,
&Dns::new(),
false,
config,
config,
)
.await
.unwrap();
mechanism.verify();
}
@@ -644,6 +678,7 @@ async fn connect_to_compute_non_retry_3() {
&mut ctx,
&mechanism,
&user_info,
&Dns::new(),
false,
wake_compute_retry_config,
connect_to_compute_retry_config,
@@ -666,9 +701,17 @@ async fn wake_retry() {
max_retries: 5,
backoff_factor: 2.0,
};
connect_to_compute(&mut ctx, &mechanism, &user_info, false, config, config)
.await
.unwrap();
connect_to_compute(
&mut ctx,
&mechanism,
&user_info,
&Dns::new(),
false,
config,
config,
)
.await
.unwrap();
mechanism.verify();
}
@@ -685,8 +728,16 @@ async fn wake_non_retry() {
max_retries: 5,
backoff_factor: 2.0,
};
connect_to_compute(&mut ctx, &mechanism, &user_info, false, config, config)
.await
.unwrap_err();
connect_to_compute(
&mut ctx,
&mechanism,
&user_info,
&Dns::new(),
false,
config,
config,
)
.await
.unwrap_err();
mechanism.verify();
}

View File

@@ -1,17 +1,19 @@
use std::{sync::Arc, time::Duration};
use async_trait::async_trait;
use tokio_postgres::NoTls;
use tracing::{field::display, info};
use crate::{
auth::{backend::ComputeCredentials, check_peer_addr_is_in_list, AuthError},
compute,
compute::{self, ConnectionError},
config::{AuthenticationConfig, ProxyConfig},
console::{
errors::{GetAuthInfoError, WakeComputeError},
CachedNodeInfo,
},
context::RequestMonitoring,
dns::Dns,
error::{ErrorKind, ReportableError, UserFacingError},
proxy::connect_compute::ConnectMechanism,
};
@@ -107,6 +109,7 @@ impl PoolingBackend {
pool: self.pool.clone(),
},
&backend,
&self.config.dns,
false, // do not allow self signed compute for http flow
self.config.wake_compute_retry_config,
self.config.connect_to_compute_retry_config,
@@ -120,7 +123,7 @@ pub enum HttpConnError {
#[error("pooled connection closed at inconsistent state")]
ConnectionClosedAbruptly(#[from] tokio::sync::watch::error::SendError<uuid::Uuid>),
#[error("could not connection to compute")]
ConnectionError(#[from] tokio_postgres::Error),
ConnectionError(#[from] ConnectionError),
#[error("could not get auth info")]
GetAuthInfo(#[from] GetAuthInfoError),
@@ -163,23 +166,24 @@ struct TokioMechanism {
#[async_trait]
impl ConnectMechanism for TokioMechanism {
type Connection = Client<tokio_postgres::Client>;
type ConnectError = tokio_postgres::Error;
type ConnectError = ConnectionError;
type Error = HttpConnError;
async fn connect_once(
&self,
ctx: &mut RequestMonitoring,
dns: &Dns,
node_info: &CachedNodeInfo,
timeout: Duration,
) -> Result<Self::Connection, Self::ConnectError> {
let mut config = (*node_info.config).clone();
let config = config
) -> Result<Self::Connection, ConnectionError> {
let mut config = node_info.config.clone();
config
.user(&self.conn_info.user_info.user)
.password(&*self.conn_info.password)
.dbname(&self.conn_info.dbname)
.connect_timeout(timeout);
let (client, connection) = config.connect(tokio_postgres::NoTls).await?;
let (_, client, connection) = config.connect_managed(ctx, dns, timeout, NoTls).await?;
tracing::Span::current().record("pid", &tracing::field::display(client.get_process_id()));
Ok(poll_client(

View File

@@ -12,9 +12,10 @@ use std::{
ops::Deref,
sync::atomic::{self, AtomicUsize},
};
use tokio::net::TcpStream;
use tokio::time::Instant;
use tokio_postgres::tls::NoTlsStream;
use tokio_postgres::{AsyncMessage, ReadyForQueryStatus, Socket};
use tokio_postgres::{AsyncMessage, ReadyForQueryStatus};
use tokio_util::sync::CancellationToken;
use crate::console::messages::{ColdStartInfo, MetricsAuxInfo};
@@ -468,7 +469,7 @@ pub fn poll_client<C: ClientInnerExt>(
ctx: &mut RequestMonitoring,
conn_info: ConnInfo,
client: C,
mut connection: tokio_postgres::Connection<Socket, NoTlsStream>,
mut connection: tokio_postgres::Connection<TcpStream, NoTlsStream>,
conn_id: uuid::Uuid,
aux: MetricsAuxInfo,
) -> Client<C> {

View File

@@ -37,6 +37,7 @@ use utils::http::error::ApiError;
use crate::auth::backend::ComputeUserInfo;
use crate::auth::endpoint_sni;
use crate::auth::ComputeUserInfoParseError;
use crate::compute::ConnectionError;
use crate::config::ProxyConfig;
use crate::config::TlsConfig;
use crate::context::RequestMonitoring;
@@ -257,7 +258,9 @@ pub async fn handle(
let mut message = e.to_string_client();
let db_error = match &e {
SqlOverHttpError::ConnectCompute(HttpConnError::ConnectionError(e))
SqlOverHttpError::ConnectCompute(HttpConnError::ConnectionError(
ConnectionError::Postgres(e),
))
| SqlOverHttpError::Postgres(e) => e.as_db_error(),
_ => None,
};
@@ -661,7 +664,9 @@ impl QueryData {
// query failed or was cancelled.
Ok(Err(error)) => {
let db_error = match &error {
SqlOverHttpError::ConnectCompute(HttpConnError::ConnectionError(e))
SqlOverHttpError::ConnectCompute(HttpConnError::ConnectionError(
ConnectionError::Postgres(e),
))
| SqlOverHttpError::Postgres(e) => e.as_db_error(),
_ => None,
};

View File

@@ -3,6 +3,7 @@
use crate::{
config::{MetricBackupCollectionConfig, MetricCollectionConfig},
context::parquet::{FAILED_UPLOAD_MAX_RETRIES, FAILED_UPLOAD_WARN_THRESHOLD},
dns::Dns,
http,
intern::{BranchIdInt, EndpointIdInt},
};
@@ -217,13 +218,13 @@ impl Metrics {
pub static USAGE_METRICS: Lazy<Metrics> = Lazy::new(Metrics::default);
pub async fn task_main(config: &MetricCollectionConfig) -> anyhow::Result<Infallible> {
pub async fn task_main(dns: Dns, config: &MetricCollectionConfig) -> anyhow::Result<Infallible> {
info!("metrics collector config: {config:?}");
scopeguard::defer! {
info!("metrics collector has shut down");
}
let http_client = http::new_client_with_timeout(DEFAULT_HTTP_REPORTING_TIMEOUT);
let http_client = http::new_client_with_timeout(dns, DEFAULT_HTTP_REPORTING_TIMEOUT);
let hostname = hostname::get()?.as_os_str().to_string_lossy().into_owned();
let mut prev = Utc::now();
@@ -495,7 +496,7 @@ mod tests {
use url::Url;
use super::*;
use crate::{http, BranchId, EndpointId};
use crate::{dns::Dns, http, BranchId, EndpointId};
#[tokio::test]
async fn metrics() {
@@ -525,7 +526,7 @@ mod tests {
tokio::spawn(server);
let metrics = Metrics::default();
let client = http::new_client();
let client = http::new_client(Dns::new());
let endpoint = Url::parse(&format!("http://{addr}")).unwrap();
let now = Utc::now();