chore(proxy): pre-load native tls certificates and propagate compute client config

This commit is contained in:
Conrad Ludgate
2024-12-18 09:29:42 +00:00
parent cd0924c686
commit bbc799ce77
15 changed files with 211 additions and 160 deletions

View File

@@ -13,7 +13,9 @@ use proxy::auth::backend::jwt::JwkCache;
use proxy::auth::backend::local::{LocalBackend, JWKS_ROLE_MAP};
use proxy::auth::{self};
use proxy::cancellation::CancellationHandlerMain;
use proxy::config::{self, AuthenticationConfig, HttpConfig, ProxyConfig, RetryConfig};
use proxy::config::{
self, AuthenticationConfig, ComputeConfig, HttpConfig, ProxyConfig, RetryConfig,
};
use proxy::control_plane::locks::ApiLocks;
use proxy::control_plane::messages::{EndpointJwksResponse, JwksSettings};
use proxy::http::health_server::AppMetrics;
@@ -32,6 +34,8 @@ project_git_version!(GIT_VERSION);
project_build_tag!(BUILD_TAG);
use clap::Parser;
use rustls::crypto::ring;
use rustls::RootCertStore;
use thiserror::Error;
use tokio::net::TcpListener;
use tokio::sync::Notify;
@@ -209,6 +213,7 @@ async fn main() -> anyhow::Result<()> {
http_listener,
shutdown.clone(),
Arc::new(CancellationHandlerMain::new(
&config.connect_to_compute,
Arc::new(DashMap::new()),
None,
proxy::metrics::CancellationSource::Local,
@@ -268,6 +273,22 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig
max_response_size_bytes: args.sql_over_http.sql_over_http_max_response_size_bytes,
};
// local_proxy won't use TLS to talk to postgres.
let root_store = RootCertStore::empty();
let client_config =
rustls::ClientConfig::builder_with_provider(Arc::new(ring::default_provider()))
.with_safe_default_protocol_versions()
.expect("ring should support the default protocol versions")
.with_root_certificates(root_store)
.with_no_client_auth();
let compute_config = ComputeConfig {
retry: RetryConfig::parse(RetryConfig::CONNECT_TO_COMPUTE_DEFAULT_VALUES)?,
tls: Arc::new(client_config),
timeout: Duration::from_secs(2),
};
Ok(Box::leak(Box::new(ProxyConfig {
tls_config: None,
metric_collection: None,
@@ -289,9 +310,7 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig
region: "local".into(),
wake_compute_retry_config: RetryConfig::parse(RetryConfig::WAKE_COMPUTE_DEFAULT_VALUES)?,
connect_compute_locks,
connect_to_compute_retry_config: RetryConfig::parse(
RetryConfig::CONNECT_TO_COMPUTE_DEFAULT_VALUES,
)?,
connect_to_compute: compute_config,
})))
}

View File

@@ -1,14 +1,15 @@
use std::net::SocketAddr;
use std::pin::pin;
use std::sync::Arc;
use std::time::Duration;
use anyhow::bail;
use anyhow::{bail, Context};
use futures::future::Either;
use proxy::auth::backend::jwt::JwkCache;
use proxy::auth::backend::{AuthRateLimiter, ConsoleRedirectBackend, MaybeOwned};
use proxy::cancellation::{CancelMap, CancellationHandler};
use proxy::config::{
self, remote_storage_from_toml, AuthenticationConfig, CacheOptions, HttpConfig,
self, remote_storage_from_toml, AuthenticationConfig, CacheOptions, ComputeConfig, HttpConfig,
ProjectInfoCacheOptions, ProxyConfig, ProxyProtocolV2,
};
use proxy::context::parquet::ParquetUploadArgs;
@@ -25,6 +26,7 @@ use proxy::serverless::cancel_set::CancelSet;
use proxy::serverless::GlobalConnPoolOptions;
use proxy::{auth, control_plane, http, serverless, usage_metrics};
use remote_storage::RemoteStorageConfig;
use rustls::crypto::ring;
use tokio::net::TcpListener;
use tokio::sync::Mutex;
use tokio::task::JoinSet;
@@ -397,6 +399,7 @@ async fn main() -> anyhow::Result<()> {
let cancellation_handler = Arc::new(CancellationHandler::<
Option<Arc<Mutex<RedisPublisherClient>>>,
>::new(
&config.connect_to_compute,
cancel_map.clone(),
redis_publisher,
proxy::metrics::CancellationSource::FromClient,
@@ -492,6 +495,7 @@ async fn main() -> anyhow::Result<()> {
let cache = api.caches.project_info.clone();
if let Some(client) = client1 {
maintenance_tasks.spawn(notifications::task_main(
config,
client,
cache.clone(),
cancel_map.clone(),
@@ -500,6 +504,7 @@ async fn main() -> anyhow::Result<()> {
}
if let Some(client) = client2 {
maintenance_tasks.spawn(notifications::task_main(
config,
client,
cache.clone(),
cancel_map.clone(),
@@ -632,6 +637,23 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
console_redirect_confirmation_timeout: args.webauth_confirmation_timeout,
};
let root_store = load_certs()
.context("loading native tls certificates")?
.clone();
let client_config =
rustls::ClientConfig::builder_with_provider(Arc::new(ring::default_provider()))
.with_safe_default_protocol_versions()
.expect("ring should support the default protocol versions")
.with_root_certificates(root_store)
.with_no_client_auth();
let compute_config = ComputeConfig {
retry: config::RetryConfig::parse(&args.connect_to_compute_retry)?,
tls: Arc::new(client_config),
timeout: Duration::from_secs(2),
};
let config = ProxyConfig {
tls_config,
metric_collection,
@@ -642,9 +664,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
region: args.region.clone(),
wake_compute_retry_config: config::RetryConfig::parse(&args.wake_compute_retry)?,
connect_compute_locks,
connect_to_compute_retry_config: config::RetryConfig::parse(
&args.connect_to_compute_retry,
)?,
connect_to_compute: compute_config,
};
let config = Box::leak(Box::new(config));
@@ -654,6 +674,18 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
Ok(config)
}
pub(crate) fn load_certs() -> anyhow::Result<Arc<rustls::RootCertStore>> {
let der_certs = rustls_native_certs::load_native_certs();
if !der_certs.errors.is_empty() {
bail!("could not parse certificates: {:?}", der_certs.errors);
}
let mut store = rustls::RootCertStore::empty();
store.add_parsable_certificates(der_certs.certs);
Ok(Arc::new(store))
}
/// auth::Backend is created at proxy startup, and lives forever.
fn build_auth_backend(
args: &ProxyCliArgs,

View File

@@ -3,11 +3,9 @@ use std::sync::Arc;
use dashmap::DashMap;
use ipnet::{IpNet, Ipv4Net, Ipv6Net};
use once_cell::sync::OnceCell;
use postgres_client::tls::MakeTlsConnect;
use postgres_client::CancelToken;
use pq_proto::CancelKeyData;
use rustls::crypto::ring;
use thiserror::Error;
use tokio::net::TcpStream;
use tokio::sync::Mutex;
@@ -15,7 +13,7 @@ use tracing::{debug, info};
use uuid::Uuid;
use crate::auth::{check_peer_addr_is_in_list, IpPattern};
use crate::compute::load_certs;
use crate::config::ComputeConfig;
use crate::error::ReportableError;
use crate::ext::LockExt;
use crate::metrics::{CancellationRequest, CancellationSource, Metrics};
@@ -35,6 +33,7 @@ type IpSubnetKey = IpNet;
///
/// If `CancellationPublisher` is available, cancel request will be used to publish the cancellation key to other proxy instances.
pub struct CancellationHandler<P> {
compute_config: &'static ComputeConfig,
map: CancelMap,
client: P,
/// This field used for the monitoring purposes.
@@ -183,7 +182,7 @@ impl<P: CancellationPublisher> CancellationHandler<P> {
"cancelling query per user's request using key {key}, hostname {}, address: {}",
cancel_closure.hostname, cancel_closure.socket_addr
);
cancel_closure.try_cancel_query().await
cancel_closure.try_cancel_query(self.compute_config).await
}
#[cfg(test)]
@@ -198,8 +197,13 @@ impl<P: CancellationPublisher> CancellationHandler<P> {
}
impl CancellationHandler<()> {
pub fn new(map: CancelMap, from: CancellationSource) -> Self {
pub fn new(
compute_config: &'static ComputeConfig,
map: CancelMap,
from: CancellationSource,
) -> Self {
Self {
compute_config,
map,
client: (),
from,
@@ -214,8 +218,14 @@ impl CancellationHandler<()> {
}
impl<P: CancellationPublisherMut> CancellationHandler<Option<Arc<Mutex<P>>>> {
pub fn new(map: CancelMap, client: Option<Arc<Mutex<P>>>, from: CancellationSource) -> Self {
pub fn new(
compute_config: &'static ComputeConfig,
map: CancelMap,
client: Option<Arc<Mutex<P>>>,
from: CancellationSource,
) -> Self {
Self {
compute_config,
map,
client,
from,
@@ -229,8 +239,6 @@ impl<P: CancellationPublisherMut> CancellationHandler<Option<Arc<Mutex<P>>>> {
}
}
static TLS_ROOTS: OnceCell<Arc<rustls::RootCertStore>> = OnceCell::new();
/// This should've been a [`std::future::Future`], but
/// it's impossible to name a type of an unboxed future
/// (we'd need something like `#![feature(type_alias_impl_trait)]`).
@@ -257,27 +265,13 @@ impl CancelClosure {
}
}
/// Cancels the query running on user's compute node.
pub(crate) async fn try_cancel_query(self) -> Result<(), CancelError> {
pub(crate) async fn try_cancel_query(
self,
compute_config: &ComputeConfig,
) -> Result<(), CancelError> {
let socket = TcpStream::connect(self.socket_addr).await?;
let root_store = TLS_ROOTS
.get_or_try_init(load_certs)
.map_err(|_e| {
CancelError::IO(std::io::Error::new(
std::io::ErrorKind::Other,
"TLS root store initialization failed".to_string(),
))
})?
.clone();
let client_config =
rustls::ClientConfig::builder_with_provider(Arc::new(ring::default_provider()))
.with_safe_default_protocol_versions()
.expect("ring should support the default protocol versions")
.with_root_certificates(root_store)
.with_no_client_auth();
let mut mk_tls = crate::postgres_rustls::MakeRustlsConnect::new(client_config);
let mut mk_tls = crate::postgres_rustls::MakeRustlsConnect::new(compute_config.tls.clone());
let tls = <MakeRustlsConnect as MakeTlsConnect<tokio::net::TcpStream>>::make_tls_connect(
&mut mk_tls,
&self.hostname,
@@ -329,11 +323,41 @@ impl<P> Drop for Session<P> {
#[cfg(test)]
#[expect(clippy::unwrap_used)]
mod tests {
use std::time::Duration;
use rustls::crypto::ring;
use rustls::RootCertStore;
use super::*;
use crate::config::RetryConfig;
fn config() -> ComputeConfig {
let retry = RetryConfig {
base_delay: Duration::from_secs(1),
max_retries: 5,
backoff_factor: 2.0,
};
let root_store = RootCertStore::empty();
let client_config =
rustls::ClientConfig::builder_with_provider(Arc::new(ring::default_provider()))
.with_safe_default_protocol_versions()
.expect("ring should support the default protocol versions")
.with_root_certificates(root_store)
.with_no_client_auth();
ComputeConfig {
retry,
tls: Arc::new(client_config),
timeout: Duration::from_secs(2),
}
}
#[tokio::test]
async fn check_session_drop() -> anyhow::Result<()> {
let cancellation_handler = Arc::new(CancellationHandler::<()>::new(
Box::leak(Box::new(config())),
CancelMap::default(),
CancellationSource::FromRedis,
));
@@ -349,8 +373,11 @@ mod tests {
#[tokio::test]
async fn cancel_session_noop_regression() {
let handler =
CancellationHandler::<()>::new(CancelMap::default(), CancellationSource::Local);
let handler = CancellationHandler::<()>::new(
Box::leak(Box::new(config())),
CancelMap::default(),
CancellationSource::Local,
);
handler
.cancel_session(
CancelKeyData {

View File

@@ -1,16 +1,13 @@
use std::io;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use futures::{FutureExt, TryFutureExt};
use itertools::Itertools;
use once_cell::sync::OnceCell;
use postgres_client::tls::MakeTlsConnect;
use postgres_client::{CancelToken, RawConnection};
use postgres_protocol::message::backend::NoticeResponseBody;
use pq_proto::StartupMessageParams;
use rustls::crypto::ring;
use rustls::pki_types::InvalidDnsNameError;
use thiserror::Error;
use tokio::net::TcpStream;
@@ -18,6 +15,7 @@ use tracing::{debug, error, info, warn};
use crate::auth::parse_endpoint_param;
use crate::cancellation::CancelClosure;
use crate::config::ComputeConfig;
use crate::context::RequestContext;
use crate::control_plane::client::ApiLockError;
use crate::control_plane::errors::WakeComputeError;
@@ -40,9 +38,6 @@ pub(crate) enum ConnectionError {
#[error("{COULD_NOT_CONNECT}: {0}")]
CouldNotConnect(#[from] io::Error),
#[error("Couldn't load native TLS certificates: {0:?}")]
TlsCertificateError(Vec<rustls_native_certs::Error>),
#[error("{COULD_NOT_CONNECT}: {0}")]
TlsError(#[from] InvalidDnsNameError),
@@ -89,7 +84,6 @@ impl ReportableError for ConnectionError {
}
ConnectionError::Postgres(_) => crate::error::ErrorKind::Compute,
ConnectionError::CouldNotConnect(_) => crate::error::ErrorKind::Compute,
ConnectionError::TlsCertificateError(_) => crate::error::ErrorKind::Service,
ConnectionError::TlsError(_) => crate::error::ErrorKind::Compute,
ConnectionError::WakeComputeError(e) => e.get_error_kind(),
ConnectionError::TooManyConnectionAttempts(e) => e.get_error_kind(),
@@ -251,25 +245,13 @@ impl ConnCfg {
&self,
ctx: &RequestContext,
aux: MetricsAuxInfo,
timeout: Duration,
config: &ComputeConfig,
) -> Result<PostgresConnection, ConnectionError> {
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let (socket_addr, stream, host) = self.connect_raw(timeout).await?;
let (socket_addr, stream, host) = self.connect_raw(config.timeout).await?;
drop(pause);
let root_store = TLS_ROOTS
.get_or_try_init(load_certs)
.map_err(ConnectionError::TlsCertificateError)?
.clone();
let client_config =
rustls::ClientConfig::builder_with_provider(Arc::new(ring::default_provider()))
.with_safe_default_protocol_versions()
.expect("ring should support the default protocol versions")
.with_root_certificates(root_store)
.with_no_client_auth();
let mut mk_tls = crate::postgres_rustls::MakeRustlsConnect::new(client_config);
let mut mk_tls = crate::postgres_rustls::MakeRustlsConnect::new(config.tls.clone());
let tls = <MakeRustlsConnect as MakeTlsConnect<tokio::net::TcpStream>>::make_tls_connect(
&mut mk_tls,
host,
@@ -341,19 +323,6 @@ fn filtered_options(options: &str) -> Option<String> {
Some(options)
}
pub(crate) fn load_certs() -> Result<Arc<rustls::RootCertStore>, Vec<rustls_native_certs::Error>> {
let der_certs = rustls_native_certs::load_native_certs();
if !der_certs.errors.is_empty() {
return Err(der_certs.errors);
}
let mut store = rustls::RootCertStore::empty();
store.add_parsable_certificates(der_certs.certs);
Ok(Arc::new(store))
}
static TLS_ROOTS: OnceCell<Arc<rustls::RootCertStore>> = OnceCell::new();
#[cfg(test)]
mod tests {
use super::*;

View File

@@ -32,7 +32,13 @@ pub struct ProxyConfig {
pub handshake_timeout: Duration,
pub wake_compute_retry_config: RetryConfig,
pub connect_compute_locks: ApiLocks<Host>,
pub connect_to_compute_retry_config: RetryConfig,
pub connect_to_compute: ComputeConfig,
}
pub struct ComputeConfig {
pub retry: RetryConfig,
pub tls: Arc<rustls::ClientConfig>,
pub timeout: Duration,
}
#[derive(Copy, Clone, Debug, ValueEnum, PartialEq)]

View File

@@ -115,7 +115,7 @@ pub async fn task_main(
Ok(Some(p)) => {
ctx.set_success();
let _disconnect = ctx.log_connect();
match p.proxy_pass().await {
match p.proxy_pass(&config.connect_to_compute).await {
Ok(()) => {}
Err(ErrorSource::Client(e)) => {
error!(?session_id, "per-client task finished with an IO error from the client: {e:#}");
@@ -216,7 +216,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
},
&user_info,
config.wake_compute_retry_config,
config.connect_to_compute_retry_config,
&config.connect_to_compute,
)
.or_else(|e| stream.throw_error(e))
.await?;

View File

@@ -10,13 +10,13 @@ pub mod client;
pub(crate) mod errors;
use std::sync::Arc;
use std::time::Duration;
use crate::auth::backend::jwt::AuthRule;
use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo};
use crate::auth::IpPattern;
use crate::cache::project_info::ProjectInfoCacheImpl;
use crate::cache::{Cached, TimedLru};
use crate::config::ComputeConfig;
use crate::context::RequestContext;
use crate::control_plane::messages::{ControlPlaneErrorMessage, MetricsAuxInfo};
use crate::intern::ProjectIdInt;
@@ -73,9 +73,9 @@ impl NodeInfo {
pub(crate) async fn connect(
&self,
ctx: &RequestContext,
timeout: Duration,
config: &ComputeConfig,
) -> Result<compute::PostgresConnection, compute::ConnectionError> {
self.config.connect(ctx, self.aux.clone(), timeout).await
self.config.connect(ctx, self.aux.clone(), config).await
}
pub(crate) fn reuse_settings(&mut self, other: Self) {

View File

@@ -126,16 +126,14 @@ mod private {
/// That way you can connect to PostgreSQL using `rustls` as the TLS stack.
#[derive(Clone)]
pub struct MakeRustlsConnect {
config: Arc<ClientConfig>,
pub config: Arc<ClientConfig>,
}
impl MakeRustlsConnect {
/// Creates a new `MakeRustlsConnect` from the provided `ClientConfig`.
#[must_use]
pub fn new(config: ClientConfig) -> Self {
Self {
config: Arc::new(config),
}
pub fn new(config: Arc<ClientConfig>) -> Self {
Self { config }
}
}

View File

@@ -6,7 +6,7 @@ use tracing::{debug, info, warn};
use super::retry::ShouldRetryWakeCompute;
use crate::auth::backend::ComputeCredentialKeys;
use crate::compute::{self, PostgresConnection, COULD_NOT_CONNECT};
use crate::config::RetryConfig;
use crate::config::{ComputeConfig, RetryConfig};
use crate::context::RequestContext;
use crate::control_plane::errors::WakeComputeError;
use crate::control_plane::locks::ApiLocks;
@@ -19,8 +19,6 @@ use crate::proxy::retry::{retry_after, should_retry, CouldRetry};
use crate::proxy::wake_compute::wake_compute;
use crate::types::Host;
const CONNECT_TIMEOUT: time::Duration = time::Duration::from_secs(2);
/// If we couldn't connect, a cached connection info might be to blame
/// (e.g. the compute node's address might've changed at the wrong time).
/// Invalidate the cache entry (if any) to prevent subsequent errors.
@@ -49,7 +47,7 @@ pub(crate) trait ConnectMechanism {
&self,
ctx: &RequestContext,
node_info: &control_plane::CachedNodeInfo,
timeout: time::Duration,
config: &ComputeConfig,
) -> Result<Self::Connection, Self::ConnectError>;
fn update_connect_config(&self, conf: &mut compute::ConnCfg);
@@ -86,11 +84,11 @@ impl ConnectMechanism for TcpMechanism<'_> {
&self,
ctx: &RequestContext,
node_info: &control_plane::CachedNodeInfo,
timeout: time::Duration,
config: &ComputeConfig,
) -> Result<PostgresConnection, Self::Error> {
let host = node_info.config.get_host();
let permit = self.locks.get_permit(&host).await?;
permit.release_result(node_info.connect(ctx, timeout).await)
permit.release_result(node_info.connect(ctx, config).await)
}
fn update_connect_config(&self, config: &mut compute::ConnCfg) {
@@ -105,7 +103,7 @@ pub(crate) async fn connect_to_compute<M: ConnectMechanism, B: ComputeConnectBac
mechanism: &M,
user_info: &B,
wake_compute_retry_config: RetryConfig,
connect_to_compute_retry_config: RetryConfig,
compute: &ComputeConfig,
) -> Result<M::Connection, M::Error>
where
M::ConnectError: CouldRetry + ShouldRetryWakeCompute + std::fmt::Debug,
@@ -119,10 +117,7 @@ where
mechanism.update_connect_config(&mut node_info.config);
// try once
let err = match mechanism
.connect_once(ctx, &node_info, CONNECT_TIMEOUT)
.await
{
let err = match mechanism.connect_once(ctx, &node_info, compute).await {
Ok(res) => {
ctx.success();
Metrics::get().proxy.retries_metric.observe(
@@ -142,7 +137,7 @@ where
let node_info = if !node_info.cached() || !err.should_retry_wake_compute() {
// If we just recieved this from cplane and didn't get it from cache, we shouldn't retry.
// Do not need to retrieve a new node_info, just return the old one.
if should_retry(&err, num_retries, connect_to_compute_retry_config) {
if should_retry(&err, num_retries, compute.retry) {
Metrics::get().proxy.retries_metric.observe(
RetriesMetricGroup {
outcome: ConnectOutcome::Failed,
@@ -172,10 +167,7 @@ where
debug!("wake_compute success. attempting to connect");
num_retries = 1;
loop {
match mechanism
.connect_once(ctx, &node_info, CONNECT_TIMEOUT)
.await
{
match mechanism.connect_once(ctx, &node_info, compute).await {
Ok(res) => {
ctx.success();
Metrics::get().proxy.retries_metric.observe(
@@ -190,7 +182,7 @@ where
return Ok(res);
}
Err(e) => {
if !should_retry(&e, num_retries, connect_to_compute_retry_config) {
if !should_retry(&e, num_retries, compute.retry) {
// Don't log an error here, caller will print the error
Metrics::get().proxy.retries_metric.observe(
RetriesMetricGroup {
@@ -206,7 +198,7 @@ where
}
};
let wait_duration = retry_after(num_retries, connect_to_compute_retry_config);
let wait_duration = retry_after(num_retries, compute.retry);
num_retries += 1;
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::RetryTimeout);

View File

@@ -152,7 +152,7 @@ pub async fn task_main(
Ok(Some(p)) => {
ctx.set_success();
let _disconnect = ctx.log_connect();
match p.proxy_pass().await {
match p.proxy_pass(&config.connect_to_compute).await {
Ok(()) => {}
Err(ErrorSource::Client(e)) => {
warn!(?session_id, "per-client task finished with an IO error from the client: {e:#}");
@@ -351,7 +351,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
},
&user_info,
config.wake_compute_retry_config,
config.connect_to_compute_retry_config,
&config.connect_to_compute,
)
.or_else(|e| stream.throw_error(e))
.await?;

View File

@@ -5,6 +5,7 @@ use utils::measured_stream::MeasuredStream;
use super::copy_bidirectional::ErrorSource;
use crate::cancellation;
use crate::compute::PostgresConnection;
use crate::config::ComputeConfig;
use crate::control_plane::messages::MetricsAuxInfo;
use crate::metrics::{Direction, Metrics, NumClientConnectionsGuard, NumConnectionRequestsGuard};
use crate::stream::Stream;
@@ -67,9 +68,17 @@ pub(crate) struct ProxyPassthrough<P, S> {
}
impl<P, S: AsyncRead + AsyncWrite + Unpin> ProxyPassthrough<P, S> {
pub(crate) async fn proxy_pass(self) -> Result<(), ErrorSource> {
pub(crate) async fn proxy_pass(
self,
compute_config: &ComputeConfig,
) -> Result<(), ErrorSource> {
let res = proxy_pass(self.client, self.compute.stream, self.aux).await;
if let Err(err) = self.compute.cancel_closure.try_cancel_query().await {
if let Err(err) = self
.compute
.cancel_closure
.try_cancel_query(compute_config)
.await
{
tracing::warn!(session_id = ?self.session_id, ?err, "could not cancel the query in the database");
}
res

View File

@@ -13,7 +13,7 @@ use postgres_client::tls::{MakeTlsConnect, NoTls};
use retry::{retry_after, ShouldRetryWakeCompute};
use rstest::rstest;
use rustls::crypto::ring;
use rustls::pki_types;
use rustls::{pki_types, RootCertStore};
use tokio::io::DuplexStream;
use super::connect_compute::ConnectMechanism;
@@ -22,7 +22,7 @@ use super::*;
use crate::auth::backend::{
ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo, MaybeOwned,
};
use crate::config::{CertResolver, RetryConfig};
use crate::config::{CertResolver, ComputeConfig, RetryConfig};
use crate::control_plane::client::{ControlPlaneClient, TestControlPlaneClient};
use crate::control_plane::messages::{ControlPlaneErrorMessage, Details, MetricsAuxInfo, Status};
use crate::control_plane::{
@@ -67,7 +67,7 @@ fn generate_certs(
}
struct ClientConfig<'a> {
config: rustls::ClientConfig,
config: Arc<rustls::ClientConfig>,
hostname: &'a str,
}
@@ -120,6 +120,7 @@ fn generate_tls_config<'a>(
store
})
.with_no_client_auth();
let config = Arc::new(config);
ClientConfig { config, hostname }
};
@@ -468,7 +469,7 @@ impl ConnectMechanism for TestConnectMechanism {
&self,
_ctx: &RequestContext,
_node_info: &control_plane::CachedNodeInfo,
_timeout: std::time::Duration,
_config: &ComputeConfig,
) -> Result<Self::Connection, Self::ConnectError> {
let mut counter = self.counter.lock().unwrap();
let action = self.sequence[*counter];
@@ -576,6 +577,29 @@ fn helper_create_connect_info(
user_info
}
fn config() -> ComputeConfig {
let retry = RetryConfig {
base_delay: Duration::from_secs(1),
max_retries: 5,
backoff_factor: 2.0,
};
let root_store = RootCertStore::empty();
let client_config =
rustls::ClientConfig::builder_with_provider(Arc::new(ring::default_provider()))
.with_safe_default_protocol_versions()
.expect("ring should support the default protocol versions")
.with_root_certificates(root_store)
.with_no_client_auth();
ComputeConfig {
retry,
tls: Arc::new(client_config),
timeout: Duration::from_secs(2),
}
}
#[tokio::test]
async fn connect_to_compute_success() {
let _ = env_logger::try_init();
@@ -583,12 +607,8 @@ async fn connect_to_compute_success() {
let ctx = RequestContext::test();
let mechanism = TestConnectMechanism::new(vec![Wake, Connect]);
let user_info = helper_create_connect_info(&mechanism);
let config = RetryConfig {
base_delay: Duration::from_secs(1),
max_retries: 5,
backoff_factor: 2.0,
};
connect_to_compute(&ctx, &mechanism, &user_info, config, config)
let config = config();
connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
.await
.unwrap();
mechanism.verify();
@@ -601,12 +621,8 @@ async fn connect_to_compute_retry() {
let ctx = RequestContext::test();
let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Connect]);
let user_info = helper_create_connect_info(&mechanism);
let config = RetryConfig {
base_delay: Duration::from_secs(1),
max_retries: 5,
backoff_factor: 2.0,
};
connect_to_compute(&ctx, &mechanism, &user_info, config, config)
let config = config();
connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
.await
.unwrap();
mechanism.verify();
@@ -620,12 +636,8 @@ async fn connect_to_compute_non_retry_1() {
let ctx = RequestContext::test();
let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Fail]);
let user_info = helper_create_connect_info(&mechanism);
let config = RetryConfig {
base_delay: Duration::from_secs(1),
max_retries: 5,
backoff_factor: 2.0,
};
connect_to_compute(&ctx, &mechanism, &user_info, config, config)
let config = config();
connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
.await
.unwrap_err();
mechanism.verify();
@@ -639,12 +651,8 @@ async fn connect_to_compute_non_retry_2() {
let ctx = RequestContext::test();
let mechanism = TestConnectMechanism::new(vec![Wake, Fail, Wake, Connect]);
let user_info = helper_create_connect_info(&mechanism);
let config = RetryConfig {
base_delay: Duration::from_secs(1),
max_retries: 5,
backoff_factor: 2.0,
};
connect_to_compute(&ctx, &mechanism, &user_info, config, config)
let config = config();
connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
.await
.unwrap();
mechanism.verify();
@@ -665,17 +673,13 @@ async fn connect_to_compute_non_retry_3() {
max_retries: 1,
backoff_factor: 2.0,
};
let connect_to_compute_retry_config = RetryConfig {
base_delay: Duration::from_secs(1),
max_retries: 5,
backoff_factor: 2.0,
};
let config = config();
connect_to_compute(
&ctx,
&mechanism,
&user_info,
wake_compute_retry_config,
connect_to_compute_retry_config,
&config,
)
.await
.unwrap_err();
@@ -690,12 +694,8 @@ async fn wake_retry() {
let ctx = RequestContext::test();
let mechanism = TestConnectMechanism::new(vec![WakeRetry, Wake, Connect]);
let user_info = helper_create_connect_info(&mechanism);
let config = RetryConfig {
base_delay: Duration::from_secs(1),
max_retries: 5,
backoff_factor: 2.0,
};
connect_to_compute(&ctx, &mechanism, &user_info, config, config)
let config = config();
connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
.await
.unwrap();
mechanism.verify();
@@ -709,12 +709,8 @@ async fn wake_non_retry() {
let ctx = RequestContext::test();
let mechanism = TestConnectMechanism::new(vec![WakeRetry, WakeFail]);
let user_info = helper_create_connect_info(&mechanism);
let config = RetryConfig {
base_delay: Duration::from_secs(1),
max_retries: 5,
backoff_factor: 2.0,
};
connect_to_compute(&ctx, &mechanism, &user_info, config, config)
let config = config();
connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
.await
.unwrap_err();
mechanism.verify();

View File

@@ -12,6 +12,7 @@ use uuid::Uuid;
use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
use crate::cache::project_info::ProjectInfoCache;
use crate::cancellation::{CancelMap, CancellationHandler};
use crate::config::ProxyConfig;
use crate::intern::{ProjectIdInt, RoleNameInt};
use crate::metrics::{Metrics, RedisErrors, RedisEventsCount};
@@ -249,6 +250,7 @@ async fn handle_messages<C: ProjectInfoCache + Send + Sync + 'static>(
/// Handle console's invalidation messages.
#[tracing::instrument(name = "redis_notifications", skip_all)]
pub async fn task_main<C>(
config: &'static ProxyConfig,
redis: ConnectionWithCredentialsProvider,
cache: Arc<C>,
cancel_map: CancelMap,
@@ -258,6 +260,7 @@ where
C: ProjectInfoCache + Send + Sync + 'static,
{
let cancellation_handler = Arc::new(CancellationHandler::<()>::new(
&config.connect_to_compute,
cancel_map,
crate::metrics::CancellationSource::FromRedis,
));

View File

@@ -22,7 +22,7 @@ use crate::compute;
use crate::compute_ctl::{
ComputeCtlError, ExtensionInstallRequest, Privilege, SetRoleGrantsRequest,
};
use crate::config::ProxyConfig;
use crate::config::{ComputeConfig, ProxyConfig};
use crate::context::RequestContext;
use crate::control_plane::client::ApiLockError;
use crate::control_plane::errors::{GetAuthInfoError, WakeComputeError};
@@ -196,7 +196,7 @@ impl PoolingBackend {
},
&backend,
self.config.wake_compute_retry_config,
self.config.connect_to_compute_retry_config,
&self.config.connect_to_compute,
)
.await
}
@@ -237,7 +237,7 @@ impl PoolingBackend {
},
&backend,
self.config.wake_compute_retry_config,
self.config.connect_to_compute_retry_config,
&self.config.connect_to_compute,
)
.await
}
@@ -502,7 +502,7 @@ impl ConnectMechanism for TokioMechanism {
&self,
ctx: &RequestContext,
node_info: &CachedNodeInfo,
timeout: Duration,
compute_config: &ComputeConfig,
) -> Result<Self::Connection, Self::ConnectError> {
let host = node_info.config.get_host();
let permit = self.locks.get_permit(&host).await?;
@@ -511,7 +511,7 @@ impl ConnectMechanism for TokioMechanism {
let config = config
.user(&self.conn_info.user_info.user)
.dbname(&self.conn_info.dbname)
.connect_timeout(timeout);
.connect_timeout(compute_config.timeout);
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let res = config.connect(postgres_client::NoTls).await;
@@ -552,7 +552,7 @@ impl ConnectMechanism for HyperMechanism {
&self,
ctx: &RequestContext,
node_info: &CachedNodeInfo,
timeout: Duration,
config: &ComputeConfig,
) -> Result<Self::Connection, Self::ConnectError> {
let host = node_info.config.get_host();
let permit = self.locks.get_permit(&host).await?;
@@ -560,7 +560,7 @@ impl ConnectMechanism for HyperMechanism {
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let port = node_info.config.get_port();
let res = connect_http2(&host, port, timeout).await;
let res = connect_http2(&host, port, config.timeout).await;
drop(pause);
let (client, connection) = permit.release_result(res)?;

View File

@@ -168,7 +168,7 @@ pub(crate) async fn serve_websocket(
Ok(Some(p)) => {
ctx.set_success();
ctx.log_connect();
match p.proxy_pass().await {
match p.proxy_pass(&config.connect_to_compute).await {
Ok(()) => Ok(()),
Err(ErrorSource::Client(err)) => Err(err).context("client"),
Err(ErrorSource::Compute(err)) => Err(err).context("compute"),