mirror of
https://github.com/neondatabase/neon.git
synced 2026-05-16 12:40:36 +00:00
Compare commits
6 Commits
release-pr
...
conrad/pro
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a4a72c8075 | ||
|
|
90ce4f3002 | ||
|
|
b79a1dd337 | ||
|
|
bbc799ce77 | ||
|
|
cd0924c686 | ||
|
|
2548926ea6 |
@@ -9,7 +9,7 @@ use std::io;
|
||||
pub(crate) async fn cancel_query<T>(
|
||||
config: Option<SocketConfig>,
|
||||
ssl_mode: SslMode,
|
||||
mut tls: T,
|
||||
tls: T,
|
||||
process_id: i32,
|
||||
secret_key: i32,
|
||||
) -> Result<(), Error>
|
||||
|
||||
@@ -10,7 +10,7 @@ use tokio::net::TcpStream;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
pub async fn connect<T>(
|
||||
mut tls: T,
|
||||
tls: T,
|
||||
config: &Config,
|
||||
) -> Result<(Client, Connection<TcpStream, T::Stream>), Error>
|
||||
where
|
||||
|
||||
@@ -46,7 +46,7 @@ pub trait MakeTlsConnect<S> {
|
||||
/// Creates a new `TlsConnect`or.
|
||||
///
|
||||
/// The domain name is provided for certificate verification and SNI.
|
||||
fn make_tls_connect(&mut self, domain: &str) -> Result<Self::TlsConnect, Self::Error>;
|
||||
fn make_tls_connect(self, domain: &str) -> Result<Self::TlsConnect, Self::Error>;
|
||||
}
|
||||
|
||||
/// An asynchronous function wrapping a stream in a TLS session.
|
||||
@@ -84,7 +84,7 @@ impl<S> MakeTlsConnect<S> for NoTls {
|
||||
type TlsConnect = NoTls;
|
||||
type Error = NoTlsError;
|
||||
|
||||
fn make_tls_connect(&mut self, _: &str) -> Result<NoTls, NoTlsError> {
|
||||
fn make_tls_connect(self, _: &str) -> Result<NoTls, NoTlsError> {
|
||||
Ok(NoTls)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,7 +13,9 @@ use proxy::auth::backend::jwt::JwkCache;
|
||||
use proxy::auth::backend::local::{LocalBackend, JWKS_ROLE_MAP};
|
||||
use proxy::auth::{self};
|
||||
use proxy::cancellation::CancellationHandlerMain;
|
||||
use proxy::config::{self, AuthenticationConfig, HttpConfig, ProxyConfig, RetryConfig};
|
||||
use proxy::config::{
|
||||
self, AuthenticationConfig, ComputeConfig, HttpConfig, ProxyConfig, RetryConfig,
|
||||
};
|
||||
use proxy::control_plane::locks::ApiLocks;
|
||||
use proxy::control_plane::messages::{EndpointJwksResponse, JwksSettings};
|
||||
use proxy::http::health_server::AppMetrics;
|
||||
@@ -32,6 +34,8 @@ project_git_version!(GIT_VERSION);
|
||||
project_build_tag!(BUILD_TAG);
|
||||
|
||||
use clap::Parser;
|
||||
use rustls::crypto::ring;
|
||||
use rustls::RootCertStore;
|
||||
use thiserror::Error;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::sync::Notify;
|
||||
@@ -209,6 +213,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
http_listener,
|
||||
shutdown.clone(),
|
||||
Arc::new(CancellationHandlerMain::new(
|
||||
&config.connect_to_compute,
|
||||
Arc::new(DashMap::new()),
|
||||
None,
|
||||
proxy::metrics::CancellationSource::Local,
|
||||
@@ -268,10 +273,25 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig
|
||||
max_response_size_bytes: args.sql_over_http.sql_over_http_max_response_size_bytes,
|
||||
};
|
||||
|
||||
// local_proxy won't use TLS to talk to postgres.
|
||||
let root_store = RootCertStore::empty();
|
||||
|
||||
let client_config =
|
||||
rustls::ClientConfig::builder_with_provider(Arc::new(ring::default_provider()))
|
||||
.with_safe_default_protocol_versions()
|
||||
.expect("ring should support the default protocol versions")
|
||||
.with_root_certificates(root_store)
|
||||
.with_no_client_auth();
|
||||
|
||||
let compute_config = ComputeConfig {
|
||||
retry: RetryConfig::parse(RetryConfig::CONNECT_TO_COMPUTE_DEFAULT_VALUES)?,
|
||||
tls: Arc::new(client_config).into(),
|
||||
timeout: Duration::from_secs(2),
|
||||
};
|
||||
|
||||
Ok(Box::leak(Box::new(ProxyConfig {
|
||||
tls_config: None,
|
||||
metric_collection: None,
|
||||
allow_self_signed_compute: false,
|
||||
http_config,
|
||||
authentication_config: AuthenticationConfig {
|
||||
jwks_cache: JwkCache::default(),
|
||||
@@ -290,9 +310,7 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig
|
||||
region: "local".into(),
|
||||
wake_compute_retry_config: RetryConfig::parse(RetryConfig::WAKE_COMPUTE_DEFAULT_VALUES)?,
|
||||
connect_compute_locks,
|
||||
connect_to_compute_retry_config: RetryConfig::parse(
|
||||
RetryConfig::CONNECT_TO_COMPUTE_DEFAULT_VALUES,
|
||||
)?,
|
||||
connect_to_compute: compute_config,
|
||||
})))
|
||||
}
|
||||
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
use std::net::SocketAddr;
|
||||
use std::pin::pin;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::bail;
|
||||
use anyhow::{bail, Context};
|
||||
use futures::future::Either;
|
||||
use proxy::auth::backend::jwt::JwkCache;
|
||||
use proxy::auth::backend::{AuthRateLimiter, ConsoleRedirectBackend, MaybeOwned};
|
||||
use proxy::cancellation::{CancelMap, CancellationHandler};
|
||||
use proxy::config::{
|
||||
self, remote_storage_from_toml, AuthenticationConfig, CacheOptions, HttpConfig,
|
||||
self, remote_storage_from_toml, AuthenticationConfig, CacheOptions, ComputeConfig, HttpConfig,
|
||||
ProjectInfoCacheOptions, ProxyConfig, ProxyProtocolV2,
|
||||
};
|
||||
use proxy::context::parquet::ParquetUploadArgs;
|
||||
@@ -25,6 +26,7 @@ use proxy::serverless::cancel_set::CancelSet;
|
||||
use proxy::serverless::GlobalConnPoolOptions;
|
||||
use proxy::{auth, control_plane, http, serverless, usage_metrics};
|
||||
use remote_storage::RemoteStorageConfig;
|
||||
use rustls::crypto::ring;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::task::JoinSet;
|
||||
@@ -129,9 +131,6 @@ struct ProxyCliArgs {
|
||||
/// lock for `connect_compute` api method. example: "shards=32,permits=4,epoch=10m,timeout=1s". (use `permits=0` to disable).
|
||||
#[clap(long, default_value = config::ConcurrencyLockOptions::DEFAULT_OPTIONS_CONNECT_COMPUTE_LOCK)]
|
||||
connect_compute_lock: String,
|
||||
/// Allow self-signed certificates for compute nodes (for testing)
|
||||
#[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
|
||||
allow_self_signed_compute: bool,
|
||||
#[clap(flatten)]
|
||||
sql_over_http: SqlOverHttpArgs,
|
||||
/// timeout for scram authentication protocol
|
||||
@@ -400,6 +399,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
let cancellation_handler = Arc::new(CancellationHandler::<
|
||||
Option<Arc<Mutex<RedisPublisherClient>>>,
|
||||
>::new(
|
||||
&config.connect_to_compute,
|
||||
cancel_map.clone(),
|
||||
redis_publisher,
|
||||
proxy::metrics::CancellationSource::FromClient,
|
||||
@@ -495,6 +495,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
let cache = api.caches.project_info.clone();
|
||||
if let Some(client) = client1 {
|
||||
maintenance_tasks.spawn(notifications::task_main(
|
||||
config,
|
||||
client,
|
||||
cache.clone(),
|
||||
cancel_map.clone(),
|
||||
@@ -503,6 +504,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
}
|
||||
if let Some(client) = client2 {
|
||||
maintenance_tasks.spawn(notifications::task_main(
|
||||
config,
|
||||
client,
|
||||
cache.clone(),
|
||||
cancel_map.clone(),
|
||||
@@ -564,9 +566,6 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
|
||||
_ => bail!("either both or neither tls-key and tls-cert must be specified"),
|
||||
};
|
||||
|
||||
if args.allow_self_signed_compute {
|
||||
warn!("allowing self-signed compute certificates");
|
||||
}
|
||||
let backup_metric_collection_config = config::MetricBackupCollectionConfig {
|
||||
interval: args.metric_backup_collection_interval,
|
||||
remote_storage_config: args.metric_backup_collection_remote_storage.clone(),
|
||||
@@ -638,10 +637,26 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
|
||||
console_redirect_confirmation_timeout: args.webauth_confirmation_timeout,
|
||||
};
|
||||
|
||||
let root_store = load_certs()
|
||||
.context("loading native tls certificates")?
|
||||
.clone();
|
||||
|
||||
let client_config =
|
||||
rustls::ClientConfig::builder_with_provider(Arc::new(ring::default_provider()))
|
||||
.with_safe_default_protocol_versions()
|
||||
.expect("ring should support the default protocol versions")
|
||||
.with_root_certificates(root_store)
|
||||
.with_no_client_auth();
|
||||
|
||||
let compute_config = ComputeConfig {
|
||||
retry: config::RetryConfig::parse(&args.connect_to_compute_retry)?,
|
||||
tls: Arc::new(client_config).into(),
|
||||
timeout: Duration::from_secs(2),
|
||||
};
|
||||
|
||||
let config = ProxyConfig {
|
||||
tls_config,
|
||||
metric_collection,
|
||||
allow_self_signed_compute: args.allow_self_signed_compute,
|
||||
http_config,
|
||||
authentication_config,
|
||||
proxy_protocol_v2: args.proxy_protocol_v2,
|
||||
@@ -649,9 +664,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
|
||||
region: args.region.clone(),
|
||||
wake_compute_retry_config: config::RetryConfig::parse(&args.wake_compute_retry)?,
|
||||
connect_compute_locks,
|
||||
connect_to_compute_retry_config: config::RetryConfig::parse(
|
||||
&args.connect_to_compute_retry,
|
||||
)?,
|
||||
connect_to_compute: compute_config,
|
||||
};
|
||||
|
||||
let config = Box::leak(Box::new(config));
|
||||
@@ -661,6 +674,18 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
pub(crate) fn load_certs() -> anyhow::Result<Arc<rustls::RootCertStore>> {
|
||||
let der_certs = rustls_native_certs::load_native_certs();
|
||||
|
||||
if !der_certs.errors.is_empty() {
|
||||
bail!("could not parse certificates: {:?}", der_certs.errors);
|
||||
}
|
||||
|
||||
let mut store = rustls::RootCertStore::empty();
|
||||
store.add_parsable_certificates(der_certs.certs);
|
||||
Ok(Arc::new(store))
|
||||
}
|
||||
|
||||
/// auth::Backend is created at proxy startup, and lives forever.
|
||||
fn build_auth_backend(
|
||||
args: &ProxyCliArgs,
|
||||
|
||||
@@ -3,10 +3,9 @@ use std::sync::Arc;
|
||||
|
||||
use dashmap::DashMap;
|
||||
use ipnet::{IpNet, Ipv4Net, Ipv6Net};
|
||||
use once_cell::sync::OnceCell;
|
||||
use postgres_client::{tls::MakeTlsConnect, CancelToken};
|
||||
use postgres_client::tls::MakeTlsConnect;
|
||||
use postgres_client::CancelToken;
|
||||
use pq_proto::CancelKeyData;
|
||||
use rustls::crypto::ring;
|
||||
use thiserror::Error;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::sync::Mutex;
|
||||
@@ -14,17 +13,16 @@ use tracing::{debug, info};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::auth::{check_peer_addr_is_in_list, IpPattern};
|
||||
use crate::config::ComputeConfig;
|
||||
use crate::error::ReportableError;
|
||||
use crate::ext::LockExt;
|
||||
use crate::metrics::{CancellationRequest, CancellationSource, Metrics};
|
||||
use crate::postgres_rustls::MakeRustlsConnect;
|
||||
use crate::rate_limiter::LeakyBucketRateLimiter;
|
||||
use crate::redis::cancellation_publisher::{
|
||||
CancellationPublisher, CancellationPublisherMut, RedisPublisherClient,
|
||||
};
|
||||
|
||||
use crate::compute::{load_certs, AcceptEverythingVerifier};
|
||||
use crate::postgres_rustls::MakeRustlsConnect;
|
||||
|
||||
pub type CancelMap = Arc<DashMap<CancelKeyData, Option<CancelClosure>>>;
|
||||
pub type CancellationHandlerMain = CancellationHandler<Option<Arc<Mutex<RedisPublisherClient>>>>;
|
||||
pub(crate) type CancellationHandlerMainInternal = Option<Arc<Mutex<RedisPublisherClient>>>;
|
||||
@@ -35,6 +33,7 @@ type IpSubnetKey = IpNet;
|
||||
///
|
||||
/// If `CancellationPublisher` is available, cancel request will be used to publish the cancellation key to other proxy instances.
|
||||
pub struct CancellationHandler<P> {
|
||||
compute_config: &'static ComputeConfig,
|
||||
map: CancelMap,
|
||||
client: P,
|
||||
/// This field used for the monitoring purposes.
|
||||
@@ -183,7 +182,7 @@ impl<P: CancellationPublisher> CancellationHandler<P> {
|
||||
"cancelling query per user's request using key {key}, hostname {}, address: {}",
|
||||
cancel_closure.hostname, cancel_closure.socket_addr
|
||||
);
|
||||
cancel_closure.try_cancel_query().await
|
||||
cancel_closure.try_cancel_query(self.compute_config).await
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -198,8 +197,13 @@ impl<P: CancellationPublisher> CancellationHandler<P> {
|
||||
}
|
||||
|
||||
impl CancellationHandler<()> {
|
||||
pub fn new(map: CancelMap, from: CancellationSource) -> Self {
|
||||
pub fn new(
|
||||
compute_config: &'static ComputeConfig,
|
||||
map: CancelMap,
|
||||
from: CancellationSource,
|
||||
) -> Self {
|
||||
Self {
|
||||
compute_config,
|
||||
map,
|
||||
client: (),
|
||||
from,
|
||||
@@ -214,8 +218,14 @@ impl CancellationHandler<()> {
|
||||
}
|
||||
|
||||
impl<P: CancellationPublisherMut> CancellationHandler<Option<Arc<Mutex<P>>>> {
|
||||
pub fn new(map: CancelMap, client: Option<Arc<Mutex<P>>>, from: CancellationSource) -> Self {
|
||||
pub fn new(
|
||||
compute_config: &'static ComputeConfig,
|
||||
map: CancelMap,
|
||||
client: Option<Arc<Mutex<P>>>,
|
||||
from: CancellationSource,
|
||||
) -> Self {
|
||||
Self {
|
||||
compute_config,
|
||||
map,
|
||||
client,
|
||||
from,
|
||||
@@ -229,8 +239,6 @@ impl<P: CancellationPublisherMut> CancellationHandler<Option<Arc<Mutex<P>>>> {
|
||||
}
|
||||
}
|
||||
|
||||
static TLS_ROOTS: OnceCell<Arc<rustls::RootCertStore>> = OnceCell::new();
|
||||
|
||||
/// This should've been a [`std::future::Future`], but
|
||||
/// it's impossible to name a type of an unboxed future
|
||||
/// (we'd need something like `#![feature(type_alias_impl_trait)]`).
|
||||
@@ -240,7 +248,6 @@ pub struct CancelClosure {
|
||||
cancel_token: CancelToken,
|
||||
ip_allowlist: Vec<IpPattern>,
|
||||
hostname: String, // for pg_sni router
|
||||
allow_self_signed_compute: bool,
|
||||
}
|
||||
|
||||
impl CancelClosure {
|
||||
@@ -249,49 +256,23 @@ impl CancelClosure {
|
||||
cancel_token: CancelToken,
|
||||
ip_allowlist: Vec<IpPattern>,
|
||||
hostname: String,
|
||||
allow_self_signed_compute: bool,
|
||||
) -> Self {
|
||||
Self {
|
||||
socket_addr,
|
||||
cancel_token,
|
||||
ip_allowlist,
|
||||
hostname,
|
||||
allow_self_signed_compute,
|
||||
}
|
||||
}
|
||||
/// Cancels the query running on user's compute node.
|
||||
pub(crate) async fn try_cancel_query(self) -> Result<(), CancelError> {
|
||||
pub(crate) async fn try_cancel_query(
|
||||
self,
|
||||
compute_config: &ComputeConfig,
|
||||
) -> Result<(), CancelError> {
|
||||
let socket = TcpStream::connect(self.socket_addr).await?;
|
||||
|
||||
let client_config = if self.allow_self_signed_compute {
|
||||
// Allow all certificates for creating the connection. Used only for tests
|
||||
let verifier = Arc::new(AcceptEverythingVerifier);
|
||||
rustls::ClientConfig::builder_with_provider(Arc::new(ring::default_provider()))
|
||||
.with_safe_default_protocol_versions()
|
||||
.expect("ring should support the default protocol versions")
|
||||
.dangerous()
|
||||
.with_custom_certificate_verifier(verifier)
|
||||
} else {
|
||||
let root_store = TLS_ROOTS
|
||||
.get_or_try_init(load_certs)
|
||||
.map_err(|_e| {
|
||||
CancelError::IO(std::io::Error::new(
|
||||
std::io::ErrorKind::Other,
|
||||
"TLS root store initialization failed".to_string(),
|
||||
))
|
||||
})?
|
||||
.clone();
|
||||
rustls::ClientConfig::builder_with_provider(Arc::new(ring::default_provider()))
|
||||
.with_safe_default_protocol_versions()
|
||||
.expect("ring should support the default protocol versions")
|
||||
.with_root_certificates(root_store)
|
||||
};
|
||||
|
||||
let client_config = client_config.with_no_client_auth();
|
||||
|
||||
let mut mk_tls = crate::postgres_rustls::MakeRustlsConnect::new(client_config);
|
||||
let tls = <MakeRustlsConnect as MakeTlsConnect<tokio::net::TcpStream>>::make_tls_connect(
|
||||
&mut mk_tls,
|
||||
crate::postgres_rustls::MakeRustlsConnect::new(&compute_config.tls),
|
||||
&self.hostname,
|
||||
)
|
||||
.map_err(|e| {
|
||||
@@ -341,11 +322,41 @@ impl<P> Drop for Session<P> {
|
||||
#[cfg(test)]
|
||||
#[expect(clippy::unwrap_used)]
|
||||
mod tests {
|
||||
use std::time::Duration;
|
||||
|
||||
use rustls::crypto::ring;
|
||||
use rustls::RootCertStore;
|
||||
|
||||
use super::*;
|
||||
use crate::config::RetryConfig;
|
||||
|
||||
fn config() -> ComputeConfig {
|
||||
let retry = RetryConfig {
|
||||
base_delay: Duration::from_secs(1),
|
||||
max_retries: 5,
|
||||
backoff_factor: 2.0,
|
||||
};
|
||||
|
||||
let root_store = RootCertStore::empty();
|
||||
|
||||
let client_config =
|
||||
rustls::ClientConfig::builder_with_provider(Arc::new(ring::default_provider()))
|
||||
.with_safe_default_protocol_versions()
|
||||
.expect("ring should support the default protocol versions")
|
||||
.with_root_certificates(root_store)
|
||||
.with_no_client_auth();
|
||||
|
||||
ComputeConfig {
|
||||
retry,
|
||||
tls: Arc::new(client_config).into(),
|
||||
timeout: Duration::from_secs(2),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn check_session_drop() -> anyhow::Result<()> {
|
||||
let cancellation_handler = Arc::new(CancellationHandler::<()>::new(
|
||||
Box::leak(Box::new(config())),
|
||||
CancelMap::default(),
|
||||
CancellationSource::FromRedis,
|
||||
));
|
||||
@@ -361,8 +372,11 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn cancel_session_noop_regression() {
|
||||
let handler =
|
||||
CancellationHandler::<()>::new(CancelMap::default(), CancellationSource::Local);
|
||||
let handler = CancellationHandler::<()>::new(
|
||||
Box::leak(Box::new(config())),
|
||||
CancelMap::default(),
|
||||
CancellationSource::Local,
|
||||
);
|
||||
handler
|
||||
.cancel_session(
|
||||
CancelKeyData {
|
||||
|
||||
@@ -1,17 +1,13 @@
|
||||
use std::io;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use futures::{FutureExt, TryFutureExt};
|
||||
use itertools::Itertools;
|
||||
use once_cell::sync::OnceCell;
|
||||
use postgres_client::tls::MakeTlsConnect;
|
||||
use postgres_client::{CancelToken, RawConnection};
|
||||
use postgres_protocol::message::backend::NoticeResponseBody;
|
||||
use pq_proto::StartupMessageParams;
|
||||
use rustls::client::danger::ServerCertVerifier;
|
||||
use rustls::crypto::ring;
|
||||
use rustls::pki_types::InvalidDnsNameError;
|
||||
use thiserror::Error;
|
||||
use tokio::net::TcpStream;
|
||||
@@ -19,6 +15,7 @@ use tracing::{debug, error, info, warn};
|
||||
|
||||
use crate::auth::parse_endpoint_param;
|
||||
use crate::cancellation::CancelClosure;
|
||||
use crate::config::ComputeConfig;
|
||||
use crate::context::RequestContext;
|
||||
use crate::control_plane::client::ApiLockError;
|
||||
use crate::control_plane::errors::WakeComputeError;
|
||||
@@ -41,9 +38,6 @@ pub(crate) enum ConnectionError {
|
||||
#[error("{COULD_NOT_CONNECT}: {0}")]
|
||||
CouldNotConnect(#[from] io::Error),
|
||||
|
||||
#[error("Couldn't load native TLS certificates: {0:?}")]
|
||||
TlsCertificateError(Vec<rustls_native_certs::Error>),
|
||||
|
||||
#[error("{COULD_NOT_CONNECT}: {0}")]
|
||||
TlsError(#[from] InvalidDnsNameError),
|
||||
|
||||
@@ -90,7 +84,6 @@ impl ReportableError for ConnectionError {
|
||||
}
|
||||
ConnectionError::Postgres(_) => crate::error::ErrorKind::Compute,
|
||||
ConnectionError::CouldNotConnect(_) => crate::error::ErrorKind::Compute,
|
||||
ConnectionError::TlsCertificateError(_) => crate::error::ErrorKind::Service,
|
||||
ConnectionError::TlsError(_) => crate::error::ErrorKind::Compute,
|
||||
ConnectionError::WakeComputeError(e) => e.get_error_kind(),
|
||||
ConnectionError::TooManyConnectionAttempts(e) => e.get_error_kind(),
|
||||
@@ -228,7 +221,7 @@ impl ConnCfg {
|
||||
}
|
||||
}
|
||||
|
||||
type RustlsStream = <MakeRustlsConnect as MakeTlsConnect<tokio::net::TcpStream>>::Stream;
|
||||
type RustlsStream = crate::postgres_rustls::RustlsStream<tokio::net::TcpStream>;
|
||||
|
||||
pub(crate) struct PostgresConnection {
|
||||
/// Socket connected to a compute node.
|
||||
@@ -251,37 +244,15 @@ impl ConnCfg {
|
||||
pub(crate) async fn connect(
|
||||
&self,
|
||||
ctx: &RequestContext,
|
||||
allow_self_signed_compute: bool,
|
||||
aux: MetricsAuxInfo,
|
||||
timeout: Duration,
|
||||
config: &ComputeConfig,
|
||||
) -> Result<PostgresConnection, ConnectionError> {
|
||||
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
|
||||
let (socket_addr, stream, host) = self.connect_raw(timeout).await?;
|
||||
let (socket_addr, stream, host) = self.connect_raw(config.timeout).await?;
|
||||
drop(pause);
|
||||
|
||||
let client_config = if allow_self_signed_compute {
|
||||
// Allow all certificates for creating the connection
|
||||
let verifier = Arc::new(AcceptEverythingVerifier);
|
||||
rustls::ClientConfig::builder_with_provider(Arc::new(ring::default_provider()))
|
||||
.with_safe_default_protocol_versions()
|
||||
.expect("ring should support the default protocol versions")
|
||||
.dangerous()
|
||||
.with_custom_certificate_verifier(verifier)
|
||||
} else {
|
||||
let root_store = TLS_ROOTS
|
||||
.get_or_try_init(load_certs)
|
||||
.map_err(ConnectionError::TlsCertificateError)?
|
||||
.clone();
|
||||
rustls::ClientConfig::builder_with_provider(Arc::new(ring::default_provider()))
|
||||
.with_safe_default_protocol_versions()
|
||||
.expect("ring should support the default protocol versions")
|
||||
.with_root_certificates(root_store)
|
||||
};
|
||||
let client_config = client_config.with_no_client_auth();
|
||||
|
||||
let mut mk_tls = crate::postgres_rustls::MakeRustlsConnect::new(client_config);
|
||||
let tls = <MakeRustlsConnect as MakeTlsConnect<tokio::net::TcpStream>>::make_tls_connect(
|
||||
&mut mk_tls,
|
||||
crate::postgres_rustls::MakeRustlsConnect::new(&config.tls),
|
||||
host,
|
||||
)?;
|
||||
|
||||
@@ -320,7 +291,6 @@ impl ConnCfg {
|
||||
},
|
||||
vec![],
|
||||
host.to_string(),
|
||||
allow_self_signed_compute,
|
||||
);
|
||||
|
||||
let connection = PostgresConnection {
|
||||
@@ -352,63 +322,6 @@ fn filtered_options(options: &str) -> Option<String> {
|
||||
Some(options)
|
||||
}
|
||||
|
||||
pub(crate) fn load_certs() -> Result<Arc<rustls::RootCertStore>, Vec<rustls_native_certs::Error>> {
|
||||
let der_certs = rustls_native_certs::load_native_certs();
|
||||
|
||||
if !der_certs.errors.is_empty() {
|
||||
return Err(der_certs.errors);
|
||||
}
|
||||
|
||||
let mut store = rustls::RootCertStore::empty();
|
||||
store.add_parsable_certificates(der_certs.certs);
|
||||
Ok(Arc::new(store))
|
||||
}
|
||||
static TLS_ROOTS: OnceCell<Arc<rustls::RootCertStore>> = OnceCell::new();
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct AcceptEverythingVerifier;
|
||||
impl ServerCertVerifier for AcceptEverythingVerifier {
|
||||
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
|
||||
use rustls::SignatureScheme;
|
||||
// The schemes for which `SignatureScheme::supported_in_tls13` returns true.
|
||||
vec![
|
||||
SignatureScheme::ECDSA_NISTP521_SHA512,
|
||||
SignatureScheme::ECDSA_NISTP384_SHA384,
|
||||
SignatureScheme::ECDSA_NISTP256_SHA256,
|
||||
SignatureScheme::RSA_PSS_SHA512,
|
||||
SignatureScheme::RSA_PSS_SHA384,
|
||||
SignatureScheme::RSA_PSS_SHA256,
|
||||
SignatureScheme::ED25519,
|
||||
]
|
||||
}
|
||||
fn verify_server_cert(
|
||||
&self,
|
||||
_end_entity: &rustls::pki_types::CertificateDer<'_>,
|
||||
_intermediates: &[rustls::pki_types::CertificateDer<'_>],
|
||||
_server_name: &rustls::pki_types::ServerName<'_>,
|
||||
_ocsp_response: &[u8],
|
||||
_now: rustls::pki_types::UnixTime,
|
||||
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
|
||||
Ok(rustls::client::danger::ServerCertVerified::assertion())
|
||||
}
|
||||
fn verify_tls12_signature(
|
||||
&self,
|
||||
_message: &[u8],
|
||||
_cert: &rustls::pki_types::CertificateDer<'_>,
|
||||
_dss: &rustls::DigitallySignedStruct,
|
||||
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
|
||||
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
|
||||
}
|
||||
fn verify_tls13_signature(
|
||||
&self,
|
||||
_message: &[u8],
|
||||
_cert: &rustls::pki_types::CertificateDer<'_>,
|
||||
_dss: &rustls::DigitallySignedStruct,
|
||||
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
|
||||
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
@@ -10,6 +10,7 @@ use remote_storage::RemoteStorageConfig;
|
||||
use rustls::crypto::ring::{self, sign};
|
||||
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
|
||||
use sha2::{Digest, Sha256};
|
||||
use tokio_rustls::TlsConnector;
|
||||
use tracing::{error, info};
|
||||
use x509_parser::oid_registry;
|
||||
|
||||
@@ -25,7 +26,6 @@ use crate::types::Host;
|
||||
pub struct ProxyConfig {
|
||||
pub tls_config: Option<TlsConfig>,
|
||||
pub metric_collection: Option<MetricCollectionConfig>,
|
||||
pub allow_self_signed_compute: bool,
|
||||
pub http_config: HttpConfig,
|
||||
pub authentication_config: AuthenticationConfig,
|
||||
pub proxy_protocol_v2: ProxyProtocolV2,
|
||||
@@ -33,7 +33,13 @@ pub struct ProxyConfig {
|
||||
pub handshake_timeout: Duration,
|
||||
pub wake_compute_retry_config: RetryConfig,
|
||||
pub connect_compute_locks: ApiLocks<Host>,
|
||||
pub connect_to_compute_retry_config: RetryConfig,
|
||||
pub connect_to_compute: ComputeConfig,
|
||||
}
|
||||
|
||||
pub struct ComputeConfig {
|
||||
pub retry: RetryConfig,
|
||||
pub tls: TlsConnector,
|
||||
pub timeout: Duration,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, ValueEnum, PartialEq)]
|
||||
|
||||
@@ -115,7 +115,7 @@ pub async fn task_main(
|
||||
Ok(Some(p)) => {
|
||||
ctx.set_success();
|
||||
let _disconnect = ctx.log_connect();
|
||||
match p.proxy_pass().await {
|
||||
match p.proxy_pass(&config.connect_to_compute).await {
|
||||
Ok(()) => {}
|
||||
Err(ErrorSource::Client(e)) => {
|
||||
error!(?session_id, "per-client task finished with an IO error from the client: {e:#}");
|
||||
@@ -213,11 +213,10 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
params_compat: true,
|
||||
params: ¶ms,
|
||||
locks: &config.connect_compute_locks,
|
||||
allow_self_signed_compute: config.allow_self_signed_compute,
|
||||
},
|
||||
&user_info,
|
||||
config.wake_compute_retry_config,
|
||||
config.connect_to_compute_retry_config,
|
||||
&config.connect_to_compute,
|
||||
)
|
||||
.or_else(|e| stream.throw_error(e))
|
||||
.await?;
|
||||
|
||||
@@ -10,13 +10,13 @@ pub mod client;
|
||||
pub(crate) mod errors;
|
||||
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::auth::backend::jwt::AuthRule;
|
||||
use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo};
|
||||
use crate::auth::IpPattern;
|
||||
use crate::cache::project_info::ProjectInfoCacheImpl;
|
||||
use crate::cache::{Cached, TimedLru};
|
||||
use crate::config::ComputeConfig;
|
||||
use crate::context::RequestContext;
|
||||
use crate::control_plane::messages::{ControlPlaneErrorMessage, MetricsAuxInfo};
|
||||
use crate::intern::ProjectIdInt;
|
||||
@@ -73,12 +73,9 @@ impl NodeInfo {
|
||||
pub(crate) async fn connect(
|
||||
&self,
|
||||
ctx: &RequestContext,
|
||||
allow_self_signed_compute: bool,
|
||||
timeout: Duration,
|
||||
config: &ComputeConfig,
|
||||
) -> Result<compute::PostgresConnection, compute::ConnectionError> {
|
||||
self.config
|
||||
.connect(ctx, allow_self_signed_compute, self.aux.clone(), timeout)
|
||||
.await
|
||||
self.config.connect(ctx, self.aux.clone(), config).await
|
||||
}
|
||||
|
||||
pub(crate) fn reuse_settings(&mut self, other: Self) {
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
use std::convert::TryFrom;
|
||||
use std::sync::Arc;
|
||||
|
||||
use postgres_client::tls::MakeTlsConnect;
|
||||
pub use private::RustlsStream;
|
||||
use rustls::pki_types::ServerName;
|
||||
use rustls::ClientConfig;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio_rustls::TlsConnector;
|
||||
|
||||
mod private {
|
||||
use std::future::Future;
|
||||
@@ -35,14 +35,12 @@ mod private {
|
||||
}
|
||||
}
|
||||
|
||||
pub struct RustlsConnect(pub RustlsConnectData);
|
||||
|
||||
pub struct RustlsConnectData {
|
||||
pub struct RustlsConnectData<'a> {
|
||||
pub hostname: ServerName<'static>,
|
||||
pub connector: TlsConnector,
|
||||
pub connector: &'a TlsConnector,
|
||||
}
|
||||
|
||||
impl<S> TlsConnect<S> for RustlsConnect
|
||||
impl<S> TlsConnect<S> for RustlsConnectData<'_>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
@@ -52,7 +50,7 @@ mod private {
|
||||
|
||||
fn connect(self, stream: S) -> Self::Future {
|
||||
TlsConnectFuture {
|
||||
inner: self.0.connector.connect(self.0.hostname, stream),
|
||||
inner: self.connector.connect(self.hostname, stream),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -124,35 +122,30 @@ mod private {
|
||||
/// A `MakeTlsConnect` implementation using `rustls`.
|
||||
///
|
||||
/// That way you can connect to PostgreSQL using `rustls` as the TLS stack.
|
||||
#[derive(Clone)]
|
||||
pub struct MakeRustlsConnect {
|
||||
config: Arc<ClientConfig>,
|
||||
pub struct MakeRustlsConnect<'a> {
|
||||
pub connector: &'a TlsConnector,
|
||||
}
|
||||
|
||||
impl MakeRustlsConnect {
|
||||
/// Creates a new `MakeRustlsConnect` from the provided `ClientConfig`.
|
||||
impl<'a> MakeRustlsConnect<'a> {
|
||||
/// Creates a new `MakeRustlsConnect` from the provided `TlsConnector`.
|
||||
#[must_use]
|
||||
pub fn new(config: ClientConfig) -> Self {
|
||||
Self {
|
||||
config: Arc::new(config),
|
||||
}
|
||||
pub fn new(connector: &'a TlsConnector) -> Self {
|
||||
Self { connector }
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> MakeTlsConnect<S> for MakeRustlsConnect
|
||||
impl<'a, S> MakeTlsConnect<S> for MakeRustlsConnect<'a>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
type Stream = private::RustlsStream<S>;
|
||||
type TlsConnect = private::RustlsConnect;
|
||||
type TlsConnect = private::RustlsConnectData<'a>;
|
||||
type Error = rustls::pki_types::InvalidDnsNameError;
|
||||
|
||||
fn make_tls_connect(&mut self, hostname: &str) -> Result<Self::TlsConnect, Self::Error> {
|
||||
ServerName::try_from(hostname).map(|dns_name| {
|
||||
private::RustlsConnect(private::RustlsConnectData {
|
||||
hostname: dns_name.to_owned(),
|
||||
connector: Arc::clone(&self.config).into(),
|
||||
})
|
||||
fn make_tls_connect(self, hostname: &str) -> Result<Self::TlsConnect, Self::Error> {
|
||||
ServerName::try_from(hostname).map(|dns_name| private::RustlsConnectData {
|
||||
hostname: dns_name.to_owned(),
|
||||
connector: self.connector,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ use tracing::{debug, info, warn};
|
||||
use super::retry::ShouldRetryWakeCompute;
|
||||
use crate::auth::backend::ComputeCredentialKeys;
|
||||
use crate::compute::{self, PostgresConnection, COULD_NOT_CONNECT};
|
||||
use crate::config::RetryConfig;
|
||||
use crate::config::{ComputeConfig, RetryConfig};
|
||||
use crate::context::RequestContext;
|
||||
use crate::control_plane::errors::WakeComputeError;
|
||||
use crate::control_plane::locks::ApiLocks;
|
||||
@@ -19,8 +19,6 @@ use crate::proxy::retry::{retry_after, should_retry, CouldRetry};
|
||||
use crate::proxy::wake_compute::wake_compute;
|
||||
use crate::types::Host;
|
||||
|
||||
const CONNECT_TIMEOUT: time::Duration = time::Duration::from_secs(2);
|
||||
|
||||
/// If we couldn't connect, a cached connection info might be to blame
|
||||
/// (e.g. the compute node's address might've changed at the wrong time).
|
||||
/// Invalidate the cache entry (if any) to prevent subsequent errors.
|
||||
@@ -49,7 +47,7 @@ pub(crate) trait ConnectMechanism {
|
||||
&self,
|
||||
ctx: &RequestContext,
|
||||
node_info: &control_plane::CachedNodeInfo,
|
||||
timeout: time::Duration,
|
||||
config: &ComputeConfig,
|
||||
) -> Result<Self::Connection, Self::ConnectError>;
|
||||
|
||||
fn update_connect_config(&self, conf: &mut compute::ConnCfg);
|
||||
@@ -73,9 +71,6 @@ pub(crate) struct TcpMechanism<'a> {
|
||||
|
||||
/// connect_to_compute concurrency lock
|
||||
pub(crate) locks: &'static ApiLocks<Host>,
|
||||
|
||||
/// Whether we should accept self-signed certificates (for testing)
|
||||
pub(crate) allow_self_signed_compute: bool,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
@@ -89,15 +84,11 @@ impl ConnectMechanism for TcpMechanism<'_> {
|
||||
&self,
|
||||
ctx: &RequestContext,
|
||||
node_info: &control_plane::CachedNodeInfo,
|
||||
timeout: time::Duration,
|
||||
config: &ComputeConfig,
|
||||
) -> Result<PostgresConnection, Self::Error> {
|
||||
let host = node_info.config.get_host();
|
||||
let permit = self.locks.get_permit(&host).await?;
|
||||
permit.release_result(
|
||||
node_info
|
||||
.connect(ctx, self.allow_self_signed_compute, timeout)
|
||||
.await,
|
||||
)
|
||||
permit.release_result(node_info.connect(ctx, config).await)
|
||||
}
|
||||
|
||||
fn update_connect_config(&self, config: &mut compute::ConnCfg) {
|
||||
@@ -112,7 +103,7 @@ pub(crate) async fn connect_to_compute<M: ConnectMechanism, B: ComputeConnectBac
|
||||
mechanism: &M,
|
||||
user_info: &B,
|
||||
wake_compute_retry_config: RetryConfig,
|
||||
connect_to_compute_retry_config: RetryConfig,
|
||||
compute: &ComputeConfig,
|
||||
) -> Result<M::Connection, M::Error>
|
||||
where
|
||||
M::ConnectError: CouldRetry + ShouldRetryWakeCompute + std::fmt::Debug,
|
||||
@@ -126,10 +117,7 @@ where
|
||||
mechanism.update_connect_config(&mut node_info.config);
|
||||
|
||||
// try once
|
||||
let err = match mechanism
|
||||
.connect_once(ctx, &node_info, CONNECT_TIMEOUT)
|
||||
.await
|
||||
{
|
||||
let err = match mechanism.connect_once(ctx, &node_info, compute).await {
|
||||
Ok(res) => {
|
||||
ctx.success();
|
||||
Metrics::get().proxy.retries_metric.observe(
|
||||
@@ -149,7 +137,7 @@ where
|
||||
let node_info = if !node_info.cached() || !err.should_retry_wake_compute() {
|
||||
// If we just recieved this from cplane and didn't get it from cache, we shouldn't retry.
|
||||
// Do not need to retrieve a new node_info, just return the old one.
|
||||
if should_retry(&err, num_retries, connect_to_compute_retry_config) {
|
||||
if should_retry(&err, num_retries, compute.retry) {
|
||||
Metrics::get().proxy.retries_metric.observe(
|
||||
RetriesMetricGroup {
|
||||
outcome: ConnectOutcome::Failed,
|
||||
@@ -179,10 +167,7 @@ where
|
||||
debug!("wake_compute success. attempting to connect");
|
||||
num_retries = 1;
|
||||
loop {
|
||||
match mechanism
|
||||
.connect_once(ctx, &node_info, CONNECT_TIMEOUT)
|
||||
.await
|
||||
{
|
||||
match mechanism.connect_once(ctx, &node_info, compute).await {
|
||||
Ok(res) => {
|
||||
ctx.success();
|
||||
Metrics::get().proxy.retries_metric.observe(
|
||||
@@ -197,7 +182,7 @@ where
|
||||
return Ok(res);
|
||||
}
|
||||
Err(e) => {
|
||||
if !should_retry(&e, num_retries, connect_to_compute_retry_config) {
|
||||
if !should_retry(&e, num_retries, compute.retry) {
|
||||
// Don't log an error here, caller will print the error
|
||||
Metrics::get().proxy.retries_metric.observe(
|
||||
RetriesMetricGroup {
|
||||
@@ -213,7 +198,7 @@ where
|
||||
}
|
||||
};
|
||||
|
||||
let wait_duration = retry_after(num_retries, connect_to_compute_retry_config);
|
||||
let wait_duration = retry_after(num_retries, compute.retry);
|
||||
num_retries += 1;
|
||||
|
||||
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::RetryTimeout);
|
||||
|
||||
@@ -152,7 +152,7 @@ pub async fn task_main(
|
||||
Ok(Some(p)) => {
|
||||
ctx.set_success();
|
||||
let _disconnect = ctx.log_connect();
|
||||
match p.proxy_pass().await {
|
||||
match p.proxy_pass(&config.connect_to_compute).await {
|
||||
Ok(()) => {}
|
||||
Err(ErrorSource::Client(e)) => {
|
||||
warn!(?session_id, "per-client task finished with an IO error from the client: {e:#}");
|
||||
@@ -348,12 +348,10 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
params_compat,
|
||||
params: ¶ms,
|
||||
locks: &config.connect_compute_locks,
|
||||
// only used for console redirect testing.
|
||||
allow_self_signed_compute: false,
|
||||
},
|
||||
&user_info,
|
||||
config.wake_compute_retry_config,
|
||||
config.connect_to_compute_retry_config,
|
||||
&config.connect_to_compute,
|
||||
)
|
||||
.or_else(|e| stream.throw_error(e))
|
||||
.await?;
|
||||
|
||||
@@ -5,6 +5,7 @@ use utils::measured_stream::MeasuredStream;
|
||||
use super::copy_bidirectional::ErrorSource;
|
||||
use crate::cancellation;
|
||||
use crate::compute::PostgresConnection;
|
||||
use crate::config::ComputeConfig;
|
||||
use crate::control_plane::messages::MetricsAuxInfo;
|
||||
use crate::metrics::{Direction, Metrics, NumClientConnectionsGuard, NumConnectionRequestsGuard};
|
||||
use crate::stream::Stream;
|
||||
@@ -67,9 +68,17 @@ pub(crate) struct ProxyPassthrough<P, S> {
|
||||
}
|
||||
|
||||
impl<P, S: AsyncRead + AsyncWrite + Unpin> ProxyPassthrough<P, S> {
|
||||
pub(crate) async fn proxy_pass(self) -> Result<(), ErrorSource> {
|
||||
pub(crate) async fn proxy_pass(
|
||||
self,
|
||||
compute_config: &ComputeConfig,
|
||||
) -> Result<(), ErrorSource> {
|
||||
let res = proxy_pass(self.client, self.compute.stream, self.aux).await;
|
||||
if let Err(err) = self.compute.cancel_closure.try_cancel_query().await {
|
||||
if let Err(err) = self
|
||||
.compute
|
||||
.cancel_closure
|
||||
.try_cancel_query(compute_config)
|
||||
.await
|
||||
{
|
||||
tracing::warn!(session_id = ?self.session_id, ?err, "could not cancel the query in the database");
|
||||
}
|
||||
res
|
||||
|
||||
@@ -13,8 +13,9 @@ use postgres_client::tls::{MakeTlsConnect, NoTls};
|
||||
use retry::{retry_after, ShouldRetryWakeCompute};
|
||||
use rstest::rstest;
|
||||
use rustls::crypto::ring;
|
||||
use rustls::pki_types;
|
||||
use rustls::{pki_types, RootCertStore};
|
||||
use tokio::io::DuplexStream;
|
||||
use tokio_rustls::TlsConnector;
|
||||
|
||||
use super::connect_compute::ConnectMechanism;
|
||||
use super::retry::CouldRetry;
|
||||
@@ -22,7 +23,7 @@ use super::*;
|
||||
use crate::auth::backend::{
|
||||
ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo, MaybeOwned,
|
||||
};
|
||||
use crate::config::{CertResolver, RetryConfig};
|
||||
use crate::config::{CertResolver, ComputeConfig, RetryConfig};
|
||||
use crate::control_plane::client::{ControlPlaneClient, TestControlPlaneClient};
|
||||
use crate::control_plane::messages::{ControlPlaneErrorMessage, Details, MetricsAuxInfo, Status};
|
||||
use crate::control_plane::{
|
||||
@@ -67,16 +68,16 @@ fn generate_certs(
|
||||
}
|
||||
|
||||
struct ClientConfig<'a> {
|
||||
config: rustls::ClientConfig,
|
||||
config: TlsConnector,
|
||||
hostname: &'a str,
|
||||
}
|
||||
|
||||
type TlsConnect<S> = <MakeRustlsConnect as MakeTlsConnect<S>>::TlsConnect;
|
||||
type TlsConnect<'a, S> = <MakeRustlsConnect<'a> as MakeTlsConnect<S>>::TlsConnect;
|
||||
|
||||
impl ClientConfig<'_> {
|
||||
fn make_tls_connect(self) -> anyhow::Result<TlsConnect<DuplexStream>> {
|
||||
let mut mk = MakeRustlsConnect::new(self.config);
|
||||
let tls = MakeTlsConnect::<DuplexStream>::make_tls_connect(&mut mk, self.hostname)?;
|
||||
fn make_tls_connect(&self) -> anyhow::Result<TlsConnect<DuplexStream>> {
|
||||
let mk = MakeRustlsConnect::new(&self.config);
|
||||
let tls = MakeTlsConnect::<DuplexStream>::make_tls_connect(mk, self.hostname)?;
|
||||
Ok(tls)
|
||||
}
|
||||
}
|
||||
@@ -120,8 +121,12 @@ fn generate_tls_config<'a>(
|
||||
store
|
||||
})
|
||||
.with_no_client_auth();
|
||||
let config = Arc::new(config);
|
||||
|
||||
ClientConfig { config, hostname }
|
||||
ClientConfig {
|
||||
config: TlsConnector::from(config),
|
||||
hostname,
|
||||
}
|
||||
};
|
||||
|
||||
Ok((client_config, tls_config))
|
||||
@@ -468,7 +473,7 @@ impl ConnectMechanism for TestConnectMechanism {
|
||||
&self,
|
||||
_ctx: &RequestContext,
|
||||
_node_info: &control_plane::CachedNodeInfo,
|
||||
_timeout: std::time::Duration,
|
||||
_config: &ComputeConfig,
|
||||
) -> Result<Self::Connection, Self::ConnectError> {
|
||||
let mut counter = self.counter.lock().unwrap();
|
||||
let action = self.sequence[*counter];
|
||||
@@ -576,6 +581,29 @@ fn helper_create_connect_info(
|
||||
user_info
|
||||
}
|
||||
|
||||
fn config() -> ComputeConfig {
|
||||
let retry = RetryConfig {
|
||||
base_delay: Duration::from_secs(1),
|
||||
max_retries: 5,
|
||||
backoff_factor: 2.0,
|
||||
};
|
||||
|
||||
let root_store = RootCertStore::empty();
|
||||
|
||||
let client_config =
|
||||
rustls::ClientConfig::builder_with_provider(Arc::new(ring::default_provider()))
|
||||
.with_safe_default_protocol_versions()
|
||||
.expect("ring should support the default protocol versions")
|
||||
.with_root_certificates(root_store)
|
||||
.with_no_client_auth();
|
||||
|
||||
ComputeConfig {
|
||||
retry,
|
||||
tls: Arc::new(client_config).into(),
|
||||
timeout: Duration::from_secs(2),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn connect_to_compute_success() {
|
||||
let _ = env_logger::try_init();
|
||||
@@ -583,12 +611,8 @@ async fn connect_to_compute_success() {
|
||||
let ctx = RequestContext::test();
|
||||
let mechanism = TestConnectMechanism::new(vec![Wake, Connect]);
|
||||
let user_info = helper_create_connect_info(&mechanism);
|
||||
let config = RetryConfig {
|
||||
base_delay: Duration::from_secs(1),
|
||||
max_retries: 5,
|
||||
backoff_factor: 2.0,
|
||||
};
|
||||
connect_to_compute(&ctx, &mechanism, &user_info, config, config)
|
||||
let config = config();
|
||||
connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
|
||||
.await
|
||||
.unwrap();
|
||||
mechanism.verify();
|
||||
@@ -601,12 +625,8 @@ async fn connect_to_compute_retry() {
|
||||
let ctx = RequestContext::test();
|
||||
let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Connect]);
|
||||
let user_info = helper_create_connect_info(&mechanism);
|
||||
let config = RetryConfig {
|
||||
base_delay: Duration::from_secs(1),
|
||||
max_retries: 5,
|
||||
backoff_factor: 2.0,
|
||||
};
|
||||
connect_to_compute(&ctx, &mechanism, &user_info, config, config)
|
||||
let config = config();
|
||||
connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
|
||||
.await
|
||||
.unwrap();
|
||||
mechanism.verify();
|
||||
@@ -620,12 +640,8 @@ async fn connect_to_compute_non_retry_1() {
|
||||
let ctx = RequestContext::test();
|
||||
let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Fail]);
|
||||
let user_info = helper_create_connect_info(&mechanism);
|
||||
let config = RetryConfig {
|
||||
base_delay: Duration::from_secs(1),
|
||||
max_retries: 5,
|
||||
backoff_factor: 2.0,
|
||||
};
|
||||
connect_to_compute(&ctx, &mechanism, &user_info, config, config)
|
||||
let config = config();
|
||||
connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
|
||||
.await
|
||||
.unwrap_err();
|
||||
mechanism.verify();
|
||||
@@ -639,12 +655,8 @@ async fn connect_to_compute_non_retry_2() {
|
||||
let ctx = RequestContext::test();
|
||||
let mechanism = TestConnectMechanism::new(vec![Wake, Fail, Wake, Connect]);
|
||||
let user_info = helper_create_connect_info(&mechanism);
|
||||
let config = RetryConfig {
|
||||
base_delay: Duration::from_secs(1),
|
||||
max_retries: 5,
|
||||
backoff_factor: 2.0,
|
||||
};
|
||||
connect_to_compute(&ctx, &mechanism, &user_info, config, config)
|
||||
let config = config();
|
||||
connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
|
||||
.await
|
||||
.unwrap();
|
||||
mechanism.verify();
|
||||
@@ -665,17 +677,13 @@ async fn connect_to_compute_non_retry_3() {
|
||||
max_retries: 1,
|
||||
backoff_factor: 2.0,
|
||||
};
|
||||
let connect_to_compute_retry_config = RetryConfig {
|
||||
base_delay: Duration::from_secs(1),
|
||||
max_retries: 5,
|
||||
backoff_factor: 2.0,
|
||||
};
|
||||
let config = config();
|
||||
connect_to_compute(
|
||||
&ctx,
|
||||
&mechanism,
|
||||
&user_info,
|
||||
wake_compute_retry_config,
|
||||
connect_to_compute_retry_config,
|
||||
&config,
|
||||
)
|
||||
.await
|
||||
.unwrap_err();
|
||||
@@ -690,12 +698,8 @@ async fn wake_retry() {
|
||||
let ctx = RequestContext::test();
|
||||
let mechanism = TestConnectMechanism::new(vec![WakeRetry, Wake, Connect]);
|
||||
let user_info = helper_create_connect_info(&mechanism);
|
||||
let config = RetryConfig {
|
||||
base_delay: Duration::from_secs(1),
|
||||
max_retries: 5,
|
||||
backoff_factor: 2.0,
|
||||
};
|
||||
connect_to_compute(&ctx, &mechanism, &user_info, config, config)
|
||||
let config = config();
|
||||
connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
|
||||
.await
|
||||
.unwrap();
|
||||
mechanism.verify();
|
||||
@@ -709,12 +713,8 @@ async fn wake_non_retry() {
|
||||
let ctx = RequestContext::test();
|
||||
let mechanism = TestConnectMechanism::new(vec![WakeRetry, WakeFail]);
|
||||
let user_info = helper_create_connect_info(&mechanism);
|
||||
let config = RetryConfig {
|
||||
base_delay: Duration::from_secs(1),
|
||||
max_retries: 5,
|
||||
backoff_factor: 2.0,
|
||||
};
|
||||
connect_to_compute(&ctx, &mechanism, &user_info, config, config)
|
||||
let config = config();
|
||||
connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
|
||||
.await
|
||||
.unwrap_err();
|
||||
mechanism.verify();
|
||||
|
||||
@@ -12,6 +12,7 @@ use uuid::Uuid;
|
||||
use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
|
||||
use crate::cache::project_info::ProjectInfoCache;
|
||||
use crate::cancellation::{CancelMap, CancellationHandler};
|
||||
use crate::config::ProxyConfig;
|
||||
use crate::intern::{ProjectIdInt, RoleNameInt};
|
||||
use crate::metrics::{Metrics, RedisErrors, RedisEventsCount};
|
||||
|
||||
@@ -249,6 +250,7 @@ async fn handle_messages<C: ProjectInfoCache + Send + Sync + 'static>(
|
||||
/// Handle console's invalidation messages.
|
||||
#[tracing::instrument(name = "redis_notifications", skip_all)]
|
||||
pub async fn task_main<C>(
|
||||
config: &'static ProxyConfig,
|
||||
redis: ConnectionWithCredentialsProvider,
|
||||
cache: Arc<C>,
|
||||
cancel_map: CancelMap,
|
||||
@@ -258,6 +260,7 @@ where
|
||||
C: ProjectInfoCache + Send + Sync + 'static,
|
||||
{
|
||||
let cancellation_handler = Arc::new(CancellationHandler::<()>::new(
|
||||
&config.connect_to_compute,
|
||||
cancel_map,
|
||||
crate::metrics::CancellationSource::FromRedis,
|
||||
));
|
||||
|
||||
@@ -22,7 +22,7 @@ use crate::compute;
|
||||
use crate::compute_ctl::{
|
||||
ComputeCtlError, ExtensionInstallRequest, Privilege, SetRoleGrantsRequest,
|
||||
};
|
||||
use crate::config::ProxyConfig;
|
||||
use crate::config::{ComputeConfig, ProxyConfig};
|
||||
use crate::context::RequestContext;
|
||||
use crate::control_plane::client::ApiLockError;
|
||||
use crate::control_plane::errors::{GetAuthInfoError, WakeComputeError};
|
||||
@@ -30,6 +30,7 @@ use crate::control_plane::locks::ApiLocks;
|
||||
use crate::control_plane::CachedNodeInfo;
|
||||
use crate::error::{ErrorKind, ReportableError, UserFacingError};
|
||||
use crate::intern::EndpointIdInt;
|
||||
use crate::postgres_rustls::MakeRustlsConnect;
|
||||
use crate::proxy::connect_compute::ConnectMechanism;
|
||||
use crate::proxy::retry::{CouldRetry, ShouldRetryWakeCompute};
|
||||
use crate::rate_limiter::EndpointRateLimiter;
|
||||
@@ -196,7 +197,7 @@ impl PoolingBackend {
|
||||
},
|
||||
&backend,
|
||||
self.config.wake_compute_retry_config,
|
||||
self.config.connect_to_compute_retry_config,
|
||||
&self.config.connect_to_compute,
|
||||
)
|
||||
.await
|
||||
}
|
||||
@@ -237,7 +238,7 @@ impl PoolingBackend {
|
||||
},
|
||||
&backend,
|
||||
self.config.wake_compute_retry_config,
|
||||
self.config.connect_to_compute_retry_config,
|
||||
&self.config.connect_to_compute,
|
||||
)
|
||||
.await
|
||||
}
|
||||
@@ -502,7 +503,7 @@ impl ConnectMechanism for TokioMechanism {
|
||||
&self,
|
||||
ctx: &RequestContext,
|
||||
node_info: &CachedNodeInfo,
|
||||
timeout: Duration,
|
||||
compute_config: &ComputeConfig,
|
||||
) -> Result<Self::Connection, Self::ConnectError> {
|
||||
let host = node_info.config.get_host();
|
||||
let permit = self.locks.get_permit(&host).await?;
|
||||
@@ -511,10 +512,12 @@ impl ConnectMechanism for TokioMechanism {
|
||||
let config = config
|
||||
.user(&self.conn_info.user_info.user)
|
||||
.dbname(&self.conn_info.dbname)
|
||||
.connect_timeout(timeout);
|
||||
.connect_timeout(compute_config.timeout);
|
||||
|
||||
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
|
||||
let res = config.connect(postgres_client::NoTls).await;
|
||||
let res = config
|
||||
.connect(MakeRustlsConnect::new(&compute_config.tls))
|
||||
.await;
|
||||
drop(pause);
|
||||
let (client, connection) = permit.release_result(res)?;
|
||||
|
||||
@@ -552,7 +555,7 @@ impl ConnectMechanism for HyperMechanism {
|
||||
&self,
|
||||
ctx: &RequestContext,
|
||||
node_info: &CachedNodeInfo,
|
||||
timeout: Duration,
|
||||
config: &ComputeConfig,
|
||||
) -> Result<Self::Connection, Self::ConnectError> {
|
||||
let host = node_info.config.get_host();
|
||||
let permit = self.locks.get_permit(&host).await?;
|
||||
@@ -560,7 +563,11 @@ impl ConnectMechanism for HyperMechanism {
|
||||
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
|
||||
|
||||
let port = node_info.config.get_port();
|
||||
let res = connect_http2(&host, port, timeout).await;
|
||||
|
||||
// TODO(conrad): how would we roll-out TLS for these connections?
|
||||
// Postgres has negotiation, but no such thing for HTTP.
|
||||
// Assume https, fall back to http (on the same port)?
|
||||
let res = connect_http2(&host, port, config.timeout).await;
|
||||
drop(pause);
|
||||
let (client, connection) = permit.release_result(res)?;
|
||||
|
||||
|
||||
@@ -5,7 +5,6 @@ use std::task::{ready, Poll};
|
||||
|
||||
use futures::future::poll_fn;
|
||||
use futures::Future;
|
||||
use postgres_client::tls::NoTlsStream;
|
||||
use postgres_client::AsyncMessage;
|
||||
use smallvec::SmallVec;
|
||||
use tokio::net::TcpStream;
|
||||
@@ -26,6 +25,7 @@ use super::conn_pool_lib::{
|
||||
use crate::context::RequestContext;
|
||||
use crate::control_plane::messages::MetricsAuxInfo;
|
||||
use crate::metrics::Metrics;
|
||||
use crate::postgres_rustls::RustlsStream;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct ConnInfoWithAuth {
|
||||
@@ -58,7 +58,7 @@ pub(crate) fn poll_client<C: ClientInnerExt>(
|
||||
ctx: &RequestContext,
|
||||
conn_info: ConnInfo,
|
||||
client: C,
|
||||
mut connection: postgres_client::Connection<TcpStream, NoTlsStream>,
|
||||
mut connection: postgres_client::Connection<TcpStream, RustlsStream<tokio::net::TcpStream>>,
|
||||
conn_id: uuid::Uuid,
|
||||
aux: MetricsAuxInfo,
|
||||
) -> Client<C> {
|
||||
|
||||
@@ -168,7 +168,7 @@ pub(crate) async fn serve_websocket(
|
||||
Ok(Some(p)) => {
|
||||
ctx.set_success();
|
||||
ctx.log_connect();
|
||||
match p.proxy_pass().await {
|
||||
match p.proxy_pass(&config.connect_to_compute).await {
|
||||
Ok(()) => Ok(()),
|
||||
Err(ErrorSource::Client(err)) => Err(err).context("client"),
|
||||
Err(ErrorSource::Compute(err)) => Err(err).context("compute"),
|
||||
|
||||
@@ -3219,7 +3219,6 @@ class NeonProxy(PgProtocol):
|
||||
# Link auth backend params
|
||||
*["--auth-backend", "link"],
|
||||
*["--uri", NeonProxy.link_auth_uri],
|
||||
*["--allow-self-signed-compute", "true"],
|
||||
]
|
||||
|
||||
class ProxyV1(AuthBackend):
|
||||
|
||||
Reference in New Issue
Block a user