Merge branch 'main' into ruslan/subzero-integration

This commit is contained in:
Ruslan Talpa
2025-07-01 13:46:57 +03:00
105 changed files with 3256 additions and 1410 deletions

View File

@@ -288,7 +288,6 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig
},
proxy_protocol_v2: config::ProxyProtocolV2::Rejected,
handshake_timeout: Duration::from_secs(10),
region: "local".into(),
wake_compute_retry_config: RetryConfig::parse(RetryConfig::WAKE_COMPUTE_DEFAULT_VALUES)?,
connect_compute_locks,
connect_to_compute: compute_config,

View File

@@ -26,9 +26,10 @@ use utils::sentry_init::init_sentry;
use crate::context::RequestContext;
use crate::metrics::{Metrics, ThreadPoolMetrics};
use crate::pglb::TlsRequired;
use crate::pqproto::FeStartupPacket;
use crate::protocol2::ConnectionInfo;
use crate::proxy::{ErrorSource, TlsRequired, copy_bidirectional_client_compute};
use crate::proxy::{ErrorSource, copy_bidirectional_client_compute};
use crate::stream::{PqStream, Stream};
use crate::util::run_until_cancelled;
@@ -236,7 +237,6 @@ pub(super) async fn task_main(
extra: None,
},
crate::metrics::Protocol::SniRouter,
"sni",
);
handle_client(ctx, dest_suffix, tls_config, compute_tls_config, socket).await
}

View File

@@ -154,12 +154,6 @@ struct ProxyCliArgs {
/// timeout for the TLS handshake
#[clap(long, default_value = "15s", value_parser = humantime::parse_duration)]
handshake_timeout: tokio::time::Duration,
/// http endpoint to receive periodic metric updates
#[clap(long)]
metric_collection_endpoint: Option<String>,
/// how often metrics should be sent to a collection endpoint
#[clap(long)]
metric_collection_interval: Option<String>,
/// cache for `wake_compute` api method (use `size=0` to disable)
#[clap(long, default_value = config::CacheOptions::CACHE_DEFAULT_OPTIONS)]
wake_compute_cache: String,
@@ -186,40 +180,31 @@ struct ProxyCliArgs {
/// Wake compute rate limiter max number of requests per second.
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_SET)]
wake_compute_limit: Vec<RateBucketInfo>,
/// Redis rate limiter max number of requests per second.
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_REDIS_SET)]
redis_rps_limit: Vec<RateBucketInfo>,
/// Cancellation channel size (max queue size for redis kv client)
#[clap(long, default_value_t = 1024)]
cancellation_ch_size: usize,
/// Cancellation ops batch size for redis
#[clap(long, default_value_t = 8)]
cancellation_batch_size: usize,
/// cache for `allowed_ips` (use `size=0` to disable)
#[clap(long, default_value = config::CacheOptions::CACHE_DEFAULT_OPTIONS)]
allowed_ips_cache: String,
/// cache for `role_secret` (use `size=0` to disable)
#[clap(long, default_value = config::CacheOptions::CACHE_DEFAULT_OPTIONS)]
role_secret_cache: String,
/// redis url for notifications (if empty, redis_host:port will be used for both notifications and streaming connections)
#[clap(long)]
redis_notifications: Option<String>,
/// what from the available authentications type to use for the regional redis we have. Supported are "irsa" and "plain".
/// redis url for plain authentication
#[clap(long, alias("redis-notifications"))]
redis_plain: Option<String>,
/// what from the available authentications type to use for redis. Supported are "irsa" and "plain".
#[clap(long, default_value = "irsa")]
redis_auth_type: String,
/// redis host for streaming connections (might be different from the notifications host)
/// redis host for irsa authentication
#[clap(long)]
redis_host: Option<String>,
/// redis port for streaming connections (might be different from the notifications host)
/// redis port for irsa authentication
#[clap(long)]
redis_port: Option<u16>,
/// redis cluster name, used in aws elasticache
/// redis cluster name for irsa authentication
#[clap(long)]
redis_cluster_name: Option<String>,
/// redis user_id, used in aws elasticache
/// redis user_id for irsa authentication
#[clap(long)]
redis_user_id: Option<String>,
/// aws region to retrieve credentials
/// aws region for irsa authentication
#[clap(long, default_value_t = String::new())]
aws_region: String,
/// cache for `project_info` (use `size=0` to disable)
@@ -231,6 +216,12 @@ struct ProxyCliArgs {
#[clap(flatten)]
parquet_upload: ParquetUploadArgs,
/// http endpoint to receive periodic metric updates
#[clap(long)]
metric_collection_endpoint: Option<String>,
/// how often metrics should be sent to a collection endpoint
#[clap(long)]
metric_collection_interval: Option<String>,
/// interval for backup metric collection
#[clap(long, default_value = "10m", value_parser = humantime::parse_duration)]
metric_backup_collection_interval: std::time::Duration,
@@ -243,6 +234,7 @@ struct ProxyCliArgs {
/// Size of each event is no more than 400 bytes, so 2**22 is about 200MB before the compression.
#[clap(long, default_value = "4194304")]
metric_backup_collection_chunk_size: usize,
/// Whether to retry the connection to the compute node
#[clap(long, default_value = config::RetryConfig::CONNECT_TO_COMPUTE_DEFAULT_VALUES)]
connect_to_compute_retry: String,
@@ -370,7 +362,7 @@ pub async fn run() -> anyhow::Result<()> {
Either::Right(auth_backend) => info!("Authentication backend: {auth_backend:?}"),
}
info!("Using region: {}", args.aws_region);
let (regional_redis_client, redis_notifications_client) = configure_redis(&args).await?;
let redis_client = configure_redis(&args).await?;
// Check that we can bind to address before further initialization
info!("Starting http on {}", args.http);
@@ -425,13 +417,6 @@ pub async fn run() -> anyhow::Result<()> {
let cancellation_token = CancellationToken::new();
let redis_rps_limit = Vec::leak(args.redis_rps_limit.clone());
RateBucketInfo::validate(redis_rps_limit)?;
let redis_kv_client = regional_redis_client
.as_ref()
.map(|redis_publisher| RedisKVClient::new(redis_publisher.clone(), redis_rps_limit));
let cancellation_handler = Arc::new(CancellationHandler::new(&config.connect_to_compute));
let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards(
@@ -446,7 +431,7 @@ pub async fn run() -> anyhow::Result<()> {
match auth_backend {
Either::Left(auth_backend) => {
if let Some(proxy_listener) = proxy_listener {
client_tasks.spawn(crate::proxy::task_main(
client_tasks.spawn(crate::pglb::task_main(
config,
auth_backend,
proxy_listener,
@@ -528,6 +513,7 @@ pub async fn run() -> anyhow::Result<()> {
client_tasks.spawn(crate::context::parquet::worker(
cancellation_token.clone(),
args.parquet_upload,
args.region,
));
// maintenance tasks. these never return unless there's an error
@@ -561,32 +547,17 @@ pub async fn run() -> anyhow::Result<()> {
#[cfg_attr(not(any(test, feature = "testing")), expect(irrefutable_let_patterns))]
if let Either::Left(auth::Backend::ControlPlane(api, ())) = &auth_backend {
if let crate::control_plane::client::ControlPlaneClient::ProxyV1(api) = &**api {
match (redis_notifications_client, regional_redis_client.clone()) {
(None, None) => {}
(client1, client2) => {
let cache = api.caches.project_info.clone();
if let Some(client) = client1 {
maintenance_tasks.spawn(notifications::task_main(
client,
cache.clone(),
args.region.clone(),
));
}
if let Some(client) = client2 {
maintenance_tasks.spawn(notifications::task_main(
client,
cache.clone(),
args.region.clone(),
));
}
maintenance_tasks.spawn(async move { cache.clone().gc_worker().await });
}
}
if let Some(client) = redis_client {
// project info cache and invalidation of that cache.
let cache = api.caches.project_info.clone();
maintenance_tasks.spawn(notifications::task_main(client.clone(), cache.clone()));
maintenance_tasks.spawn(async move { cache.clone().gc_worker().await });
// Try to connect to Redis 3 times with 1 + (0..0.1) second interval.
// This prevents immediate exit and pod restart,
// which can cause hammering of the redis in case of connection issues.
if let Some(mut redis_kv_client) = redis_kv_client {
// Try to connect to Redis 3 times with 1 + (0..0.1) second interval.
// This prevents immediate exit and pod restart,
// which can cause hammering of the redis in case of connection issues.
// cancellation key management
let mut redis_kv_client = RedisKVClient::new(client.clone());
for attempt in (0..3).with_position() {
match redis_kv_client.try_connect().await {
Ok(()) => {
@@ -611,14 +582,12 @@ pub async fn run() -> anyhow::Result<()> {
}
}
}
}
if let Some(regional_redis_client) = regional_redis_client {
// listen for notifications of new projects/endpoints/branches
let cache = api.caches.endpoints_cache.clone();
let con = regional_redis_client;
let span = tracing::info_span!("endpoints_cache");
maintenance_tasks.spawn(
async move { cache.do_read(con, cancellation_token.clone()).await }
async move { cache.do_read(client, cancellation_token.clone()).await }
.instrument(span),
);
}
@@ -765,7 +734,6 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
authentication_config,
proxy_protocol_v2: args.proxy_protocol_v2,
handshake_timeout: args.handshake_timeout,
region: args.region.clone(),
wake_compute_retry_config: config::RetryConfig::parse(&args.wake_compute_retry)?,
connect_compute_locks,
connect_to_compute: compute_config,
@@ -942,21 +910,18 @@ fn build_auth_backend(
async fn configure_redis(
args: &ProxyCliArgs,
) -> anyhow::Result<(
Option<ConnectionWithCredentialsProvider>,
Option<ConnectionWithCredentialsProvider>,
)> {
) -> anyhow::Result<Option<ConnectionWithCredentialsProvider>> {
// TODO: untangle the config args
let regional_redis_client = match (args.redis_auth_type.as_str(), &args.redis_notifications) {
("plain", redis_url) => match redis_url {
let redis_client = match &*args.redis_auth_type {
"plain" => match &args.redis_plain {
None => {
bail!("plain auth requires redis_notifications to be set");
bail!("plain auth requires redis_plain to be set");
}
Some(url) => {
Some(ConnectionWithCredentialsProvider::new_with_static_credentials(url.clone()))
}
},
("irsa", _) => match (&args.redis_host, args.redis_port) {
"irsa" => match (&args.redis_host, args.redis_port) {
(Some(host), Some(port)) => Some(
ConnectionWithCredentialsProvider::new_with_credentials_provider(
host.clone(),
@@ -980,18 +945,12 @@ async fn configure_redis(
bail!("redis-host and redis-port must be specified together");
}
},
_ => {
bail!("unknown auth type given");
auth_type => {
bail!("unknown auth type {auth_type:?} given")
}
};
let redis_notifications_client = if let Some(url) = &args.redis_notifications {
Some(ConnectionWithCredentialsProvider::new_with_static_credentials(&**url))
} else {
regional_redis_client.clone()
};
Ok((regional_redis_client, redis_notifications_client))
Ok(redis_client)
}

View File

@@ -30,7 +30,7 @@ use super::{Cache, timed_lru};
///
/// * There's an API for immediate invalidation (removal) of a cache entry;
/// It's useful in case we know for sure that the entry is no longer correct.
/// See [`timed_lru::LookupInfo`] & [`timed_lru::Cached`] for more information.
/// See [`timed_lru::Cached`] for more information.
///
/// * Expired entries are kept in the cache, until they are evicted by the LRU policy,
/// or by a successful lookup (i.e. the entry hasn't expired yet).
@@ -54,7 +54,7 @@ pub(crate) struct TimedLru<K, V> {
impl<K: Hash + Eq, V> Cache for TimedLru<K, V> {
type Key = K;
type Value = V;
type LookupInfo<Key> = LookupInfo<Key>;
type LookupInfo<Key> = Key;
fn invalidate(&self, info: &Self::LookupInfo<K>) {
self.invalidate_raw(info);
@@ -87,30 +87,24 @@ impl<K: Hash + Eq, V> TimedLru<K, V> {
/// Drop an entry from the cache if it's outdated.
#[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)]
fn invalidate_raw(&self, info: &LookupInfo<K>) {
let now = Instant::now();
fn invalidate_raw(&self, key: &K) {
// Do costly things before taking the lock.
let mut cache = self.cache.lock();
let raw_entry = match cache.raw_entry_mut().from_key(&info.key) {
let entry = match cache.raw_entry_mut().from_key(key) {
RawEntryMut::Vacant(_) => return,
RawEntryMut::Occupied(x) => x,
RawEntryMut::Occupied(x) => x.remove(),
};
// Remove the entry if it was created prior to lookup timestamp.
let entry = raw_entry.get();
let (created_at, expires_at) = (entry.created_at, entry.expires_at);
let should_remove = created_at <= info.created_at || expires_at <= now;
if should_remove {
raw_entry.remove();
}
drop(cache); // drop lock before logging
let Entry {
created_at,
expires_at,
..
} = entry;
debug!(
created_at = format_args!("{created_at:?}"),
expires_at = format_args!("{expires_at:?}"),
entry_removed = should_remove,
?created_at,
?expires_at,
"processed a cache entry invalidation event"
);
}
@@ -215,10 +209,10 @@ impl<K: Hash + Eq + Clone, V: Clone> TimedLru<K, V> {
}
pub(crate) fn insert_unit(&self, key: K, value: V) -> (Option<V>, Cached<&Self, ()>) {
let (created_at, old) = self.insert_raw(key.clone(), value);
let (_, old) = self.insert_raw(key.clone(), value);
let cached = Cached {
token: Some((self, LookupInfo { created_at, key })),
token: Some((self, key)),
value: (),
};
@@ -255,28 +249,9 @@ impl<K: Hash + Eq, V: Clone> TimedLru<K, V> {
K: Borrow<Q> + Clone,
Q: Hash + Eq + ?Sized,
{
self.get_raw(key, |key, entry| {
let info = LookupInfo {
created_at: entry.created_at,
key: key.clone(),
};
Cached {
token: Some((self, info)),
value: entry.value.clone(),
}
self.get_raw(key, |key, entry| Cached {
token: Some((self, key.clone())),
value: entry.value.clone(),
})
}
}
/// Lookup information for key invalidation.
pub(crate) struct LookupInfo<K> {
/// Time of creation of a cache [`Entry`].
/// We use this during invalidation lookups to prevent eviction of a newer
/// entry sharing the same key (it might've been inserted by a different
/// task after we got the entry we're trying to invalidate now).
created_at: Instant,
/// Search by this key.
key: K,
}

View File

@@ -236,7 +236,7 @@ impl AuthInfo {
&self,
ctx: &RequestContext,
compute: &mut ComputeConnection,
user_info: ComputeUserInfo,
user_info: &ComputeUserInfo,
) -> Result<PostgresSettings, PostgresError> {
// client config with stubbed connect info.
// TODO(conrad): should we rewrite this to bypass tokio-postgres2 entirely,
@@ -272,7 +272,7 @@ impl AuthInfo {
secret_key,
},
compute.hostname.to_string(),
user_info,
user_info.clone(),
);
Ok(PostgresSettings {

View File

@@ -24,7 +24,6 @@ pub struct ProxyConfig {
pub authentication_config: AuthenticationConfig,
pub rest_config: RestConfig,
pub proxy_protocol_v2: ProxyProtocolV2,
pub region: String,
pub handshake_timeout: Duration,
pub wake_compute_retry_config: RetryConfig,
pub connect_compute_locks: ApiLocks<Host>,

View File

@@ -11,11 +11,12 @@ use crate::config::{ProxyConfig, ProxyProtocolV2};
use crate::context::RequestContext;
use crate::error::ReportableError;
use crate::metrics::{Metrics, NumClientConnectionsGuard};
use crate::pglb::ClientRequestError;
use crate::pglb::handshake::{HandshakeData, handshake};
use crate::pglb::passthrough::ProxyPassthrough;
use crate::protocol2::{ConnectHeader, ConnectionInfo, read_proxy_protocol};
use crate::proxy::connect_compute::{TcpMechanism, connect_to_compute};
use crate::proxy::{ClientRequestError, ErrorSource, prepare_client_connection};
use crate::proxy::{ErrorSource, finish_client_init};
use crate::util::run_until_cancelled;
pub async fn task_main(
@@ -89,12 +90,7 @@ pub async fn task_main(
}
}
let ctx = RequestContext::new(
session_id,
conn_info,
crate::metrics::Protocol::Tcp,
&config.region,
);
let ctx = RequestContext::new(session_id, conn_info, crate::metrics::Protocol::Tcp);
let res = handle_client(
config,
@@ -231,13 +227,13 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
.await?;
let pg_settings = auth_info
.authenticate(ctx, &mut node, user_info)
.authenticate(ctx, &mut node, &user_info)
.or_else(|e| async { Err(stream.throw_error(e, Some(ctx)).await) })
.await?;
let session = cancellation_handler.get_key();
prepare_client_connection(&pg_settings, *session.key(), &mut stream);
finish_client_init(&pg_settings, *session.key(), &mut stream);
let stream = stream.flush_and_into_inner().await?;
let session_id = ctx.session_id();

View File

@@ -46,7 +46,6 @@ struct RequestContextInner {
pub(crate) session_id: Uuid,
pub(crate) protocol: Protocol,
first_packet: chrono::DateTime<Utc>,
region: &'static str,
pub(crate) span: Span,
// filled in as they are discovered
@@ -94,7 +93,6 @@ impl Clone for RequestContext {
session_id: inner.session_id,
protocol: inner.protocol,
first_packet: inner.first_packet,
region: inner.region,
span: info_span!("background_task"),
project: inner.project,
@@ -124,12 +122,7 @@ impl Clone for RequestContext {
}
impl RequestContext {
pub fn new(
session_id: Uuid,
conn_info: ConnectionInfo,
protocol: Protocol,
region: &'static str,
) -> Self {
pub fn new(session_id: Uuid, conn_info: ConnectionInfo, protocol: Protocol) -> Self {
// TODO: be careful with long lived spans
let span = info_span!(
"connect_request",
@@ -145,7 +138,6 @@ impl RequestContext {
session_id,
protocol,
first_packet: Utc::now(),
region,
span,
project: None,
@@ -179,7 +171,7 @@ impl RequestContext {
let ip = IpAddr::from([127, 0, 0, 1]);
let addr = SocketAddr::new(ip, 5432);
let conn_info = ConnectionInfo { addr, extra: None };
RequestContext::new(Uuid::now_v7(), conn_info, Protocol::Tcp, "test")
RequestContext::new(Uuid::now_v7(), conn_info, Protocol::Tcp)
}
pub(crate) fn console_application_name(&self) -> String {

View File

@@ -74,7 +74,7 @@ pub(crate) const FAILED_UPLOAD_MAX_RETRIES: u32 = 10;
#[derive(parquet_derive::ParquetRecordWriter)]
pub(crate) struct RequestData {
region: &'static str,
region: String,
protocol: &'static str,
/// Must be UTC. The derive macro doesn't like the timezones
timestamp: chrono::NaiveDateTime,
@@ -147,7 +147,7 @@ impl From<&RequestContextInner> for RequestData {
}),
jwt_issuer: value.jwt_issuer.clone(),
protocol: value.protocol.as_str(),
region: value.region,
region: String::new(),
error: value.error_kind.as_ref().map(|e| e.to_metric_label()),
success: value.success,
cold_start_info: value.cold_start_info.as_str(),
@@ -167,6 +167,7 @@ impl From<&RequestContextInner> for RequestData {
pub async fn worker(
cancellation_token: CancellationToken,
config: ParquetUploadArgs,
region: String,
) -> anyhow::Result<()> {
let Some(remote_storage_config) = config.parquet_upload_remote_storage else {
tracing::warn!("parquet request upload: no s3 bucket configured");
@@ -232,12 +233,17 @@ pub async fn worker(
.context("remote storage for disconnect events init")?;
let parquet_config_disconnect = parquet_config.clone();
tokio::try_join!(
worker_inner(storage, rx, parquet_config),
worker_inner(storage_disconnect, rx_disconnect, parquet_config_disconnect)
worker_inner(storage, rx, parquet_config, &region),
worker_inner(
storage_disconnect,
rx_disconnect,
parquet_config_disconnect,
&region
)
)
.map(|_| ())
} else {
worker_inner(storage, rx, parquet_config).await
worker_inner(storage, rx, parquet_config, &region).await
}
}
@@ -257,6 +263,7 @@ async fn worker_inner(
storage: GenericRemoteStorage,
rx: impl Stream<Item = RequestData>,
config: ParquetConfig,
region: &str,
) -> anyhow::Result<()> {
#[cfg(any(test, feature = "testing"))]
let storage = if config.test_remote_failures > 0 {
@@ -277,7 +284,8 @@ async fn worker_inner(
let mut last_upload = time::Instant::now();
let mut len = 0;
while let Some(row) = rx.next().await {
while let Some(mut row) = rx.next().await {
region.clone_into(&mut row.region);
rows.push(row);
let force = last_upload.elapsed() > config.max_duration;
if rows.len() == config.rows_per_group || force {
@@ -533,7 +541,7 @@ mod tests {
auth_method: None,
jwt_issuer: None,
protocol: ["tcp", "ws", "http"][rng.gen_range(0..3)],
region: "us-east-1",
region: String::new(),
error: None,
success: rng.r#gen(),
cold_start_info: "no",
@@ -565,7 +573,9 @@ mod tests {
.await
.unwrap();
worker_inner(storage, rx, config).await.unwrap();
worker_inner(storage, rx, config, "us-east-1")
.await
.unwrap();
let mut files = WalkDir::new(tmpdir.as_std_path())
.into_iter()

View File

@@ -8,10 +8,10 @@ use crate::config::TlsConfig;
use crate::context::RequestContext;
use crate::error::ReportableError;
use crate::metrics::Metrics;
use crate::pglb::TlsRequired;
use crate::pqproto::{
BeMessage, CancelKeyData, FeStartupPacket, ProtocolVersion, StartupMessageParams,
};
use crate::proxy::TlsRequired;
use crate::stream::{PqStream, Stream, StreamUpgradeError};
use crate::tls::PG_ALPN_PROTOCOL;

View File

@@ -2,3 +2,332 @@ pub mod copy_bidirectional;
pub mod handshake;
pub mod inprocess;
pub mod passthrough;
use std::sync::Arc;
use futures::FutureExt;
use smol_str::ToSmolStr;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::sync::CancellationToken;
use tracing::{Instrument, debug, error, info, warn};
use crate::auth;
use crate::cancellation::{self, CancellationHandler};
use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig};
use crate::context::RequestContext;
use crate::error::{ReportableError, UserFacingError};
use crate::metrics::{Metrics, NumClientConnectionsGuard};
pub use crate::pglb::copy_bidirectional::ErrorSource;
use crate::pglb::handshake::{HandshakeData, HandshakeError, handshake};
use crate::pglb::passthrough::ProxyPassthrough;
use crate::protocol2::{ConnectHeader, ConnectionInfo, ConnectionInfoExtra, read_proxy_protocol};
use crate::proxy::handle_client;
use crate::rate_limiter::EndpointRateLimiter;
use crate::stream::Stream;
use crate::util::run_until_cancelled;
pub const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
#[derive(Error, Debug)]
#[error("{ERR_INSECURE_CONNECTION}")]
pub struct TlsRequired;
impl ReportableError for TlsRequired {
fn get_error_kind(&self) -> crate::error::ErrorKind {
crate::error::ErrorKind::User
}
}
impl UserFacingError for TlsRequired {}
pub async fn task_main(
config: &'static ProxyConfig,
auth_backend: &'static auth::Backend<'static, ()>,
listener: tokio::net::TcpListener,
cancellation_token: CancellationToken,
cancellation_handler: Arc<CancellationHandler>,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> anyhow::Result<()> {
scopeguard::defer! {
info!("proxy has shut down");
}
// When set for the server socket, the keepalive setting
// will be inherited by all accepted client sockets.
socket2::SockRef::from(&listener).set_keepalive(true)?;
let connections = tokio_util::task::task_tracker::TaskTracker::new();
let cancellations = tokio_util::task::task_tracker::TaskTracker::new();
while let Some(accept_result) =
run_until_cancelled(listener.accept(), &cancellation_token).await
{
let (socket, peer_addr) = accept_result?;
let conn_gauge = Metrics::get()
.proxy
.client_connections
.guard(crate::metrics::Protocol::Tcp);
let session_id = uuid::Uuid::new_v4();
let cancellation_handler = Arc::clone(&cancellation_handler);
let cancellations = cancellations.clone();
debug!(protocol = "tcp", %session_id, "accepted new TCP connection");
let endpoint_rate_limiter2 = endpoint_rate_limiter.clone();
connections.spawn(async move {
let (socket, conn_info) = match config.proxy_protocol_v2 {
ProxyProtocolV2::Required => {
match read_proxy_protocol(socket).await {
Err(e) => {
warn!("per-client task finished with an error: {e:#}");
return;
}
// our load balancers will not send any more data. let's just exit immediately
Ok((_socket, ConnectHeader::Local)) => {
debug!("healthcheck received");
return;
}
Ok((socket, ConnectHeader::Proxy(info))) => (socket, info),
}
}
// ignore the header - it cannot be confused for a postgres or http connection so will
// error later.
ProxyProtocolV2::Rejected => (
socket,
ConnectionInfo {
addr: peer_addr,
extra: None,
},
),
};
match socket.set_nodelay(true) {
Ok(()) => {}
Err(e) => {
error!(
"per-client task finished with an error: failed to set socket option: {e:#}"
);
return;
}
}
let ctx = RequestContext::new(session_id, conn_info, crate::metrics::Protocol::Tcp);
let res = handle_connection(
config,
auth_backend,
&ctx,
cancellation_handler,
socket,
ClientMode::Tcp,
endpoint_rate_limiter2,
conn_gauge,
cancellations,
)
.instrument(ctx.span())
.boxed()
.await;
match res {
Err(e) => {
ctx.set_error_kind(e.get_error_kind());
warn!(parent: &ctx.span(), "per-client task finished with an error: {e:#}");
}
Ok(None) => {
ctx.set_success();
}
Ok(Some(p)) => {
ctx.set_success();
let _disconnect = ctx.log_connect();
match p.proxy_pass().await {
Ok(()) => {}
Err(ErrorSource::Client(e)) => {
warn!(
?session_id,
"per-client task finished with an IO error from the client: {e:#}"
);
}
Err(ErrorSource::Compute(e)) => {
error!(
?session_id,
"per-client task finished with an IO error from the compute: {e:#}"
);
}
}
}
}
});
}
connections.close();
cancellations.close();
drop(listener);
// Drain connections
connections.wait().await;
cancellations.wait().await;
Ok(())
}
pub(crate) enum ClientMode {
Tcp,
Websockets { hostname: Option<String> },
}
/// Abstracts the logic of handling TCP vs WS clients
impl ClientMode {
pub fn allow_cleartext(&self) -> bool {
match self {
ClientMode::Tcp => false,
ClientMode::Websockets { .. } => true,
}
}
pub fn hostname<'a, S>(&'a self, s: &'a Stream<S>) -> Option<&'a str> {
match self {
ClientMode::Tcp => s.sni_hostname(),
ClientMode::Websockets { hostname } => hostname.as_deref(),
}
}
pub fn handshake_tls<'a>(&self, tls: Option<&'a TlsConfig>) -> Option<&'a TlsConfig> {
match self {
ClientMode::Tcp => tls,
// TLS is None here if using websockets, because the connection is already encrypted.
ClientMode::Websockets { .. } => None,
}
}
}
#[derive(Debug, Error)]
// almost all errors should be reported to the user, but there's a few cases where we cannot
// 1. Cancellation: we are not allowed to tell the client any cancellation statuses for security reasons
// 2. Handshake: handshake reports errors if it can, otherwise if the handshake fails due to protocol violation,
// we cannot be sure the client even understands our error message
// 3. PrepareClient: The client disconnected, so we can't tell them anyway...
pub(crate) enum ClientRequestError {
#[error("{0}")]
Cancellation(#[from] cancellation::CancelError),
#[error("{0}")]
Handshake(#[from] HandshakeError),
#[error("{0}")]
HandshakeTimeout(#[from] tokio::time::error::Elapsed),
#[error("{0}")]
PrepareClient(#[from] std::io::Error),
#[error("{0}")]
ReportedError(#[from] crate::stream::ReportedError),
}
impl ReportableError for ClientRequestError {
fn get_error_kind(&self) -> crate::error::ErrorKind {
match self {
ClientRequestError::Cancellation(e) => e.get_error_kind(),
ClientRequestError::Handshake(e) => e.get_error_kind(),
ClientRequestError::HandshakeTimeout(_) => crate::error::ErrorKind::RateLimit,
ClientRequestError::ReportedError(e) => e.get_error_kind(),
ClientRequestError::PrepareClient(_) => crate::error::ErrorKind::ClientDisconnect,
}
}
}
#[allow(clippy::too_many_arguments)]
pub(crate) async fn handle_connection<S: AsyncRead + AsyncWrite + Unpin + Send>(
config: &'static ProxyConfig,
auth_backend: &'static auth::Backend<'static, ()>,
ctx: &RequestContext,
cancellation_handler: Arc<CancellationHandler>,
client: S,
mode: ClientMode,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
conn_gauge: NumClientConnectionsGuard<'static>,
cancellations: tokio_util::task::task_tracker::TaskTracker,
) -> Result<Option<ProxyPassthrough<S>>, ClientRequestError> {
debug!(
protocol = %ctx.protocol(),
"handling interactive connection from client"
);
let metrics = &Metrics::get().proxy;
let proto = ctx.protocol();
let request_gauge = metrics.connection_requests.guard(proto);
let tls = config.tls_config.load();
let tls = tls.as_deref();
let record_handshake_error = !ctx.has_private_peer_addr();
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
let do_handshake = handshake(ctx, client, mode.handshake_tls(tls), record_handshake_error);
let (mut client, params) = match tokio::time::timeout(config.handshake_timeout, do_handshake)
.await??
{
HandshakeData::Startup(client, params) => (client, params),
HandshakeData::Cancel(cancel_key_data) => {
// spawn a task to cancel the session, but don't wait for it
cancellations.spawn({
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
let ctx = ctx.clone();
let cancel_span = tracing::span!(parent: None, tracing::Level::INFO, "cancel_session", session_id = ?ctx.session_id());
cancel_span.follows_from(tracing::Span::current());
async move {
cancellation_handler_clone
.cancel_session(
cancel_key_data,
ctx,
config.authentication_config.ip_allowlist_check_enabled,
config.authentication_config.is_vpc_acccess_proxy,
auth_backend.get_api(),
)
.await
.inspect_err(|e | debug!(error = ?e, "cancel_session failed")).ok();
}.instrument(cancel_span)
});
return Ok(None);
}
};
drop(pause);
ctx.set_db_options(params.clone());
let common_names = tls.map(|tls| &tls.common_names);
let (node, cancel_on_shutdown) = handle_client(
config,
auth_backend,
ctx,
cancellation_handler,
&mut client,
&mode,
endpoint_rate_limiter,
common_names,
&params,
)
.await?;
let client = client.flush_and_into_inner().await?;
let private_link_id = match ctx.extra() {
Some(ConnectionInfoExtra::Aws { vpce_id }) => Some(vpce_id.clone()),
Some(ConnectionInfoExtra::Azure { link_id }) => Some(link_id.to_smolstr()),
None => None,
};
Ok(Some(ProxyPassthrough {
client,
compute: node.stream,
aux: node.aux,
private_link_id,
_cancel_on_shutdown: cancel_on_shutdown,
_req: request_gauge,
_conn: conn_gauge,
_db_conn: node.guage,
}))
}

View File

@@ -112,7 +112,7 @@ where
let node_info = if !node_info.cached() || !err.should_retry_wake_compute() {
// If we just recieved this from cplane and didn't get it from cache, we shouldn't retry.
// Do not need to retrieve a new node_info, just return the old one.
if should_retry(&err, num_retries, compute.retry) {
if !should_retry(&err, num_retries, compute.retry) {
Metrics::get().proxy.retries_metric.observe(
RetriesMetricGroup {
outcome: ConnectOutcome::Failed,

View File

@@ -5,328 +5,64 @@ pub(crate) mod connect_compute;
pub(crate) mod retry;
pub(crate) mod wake_compute;
use std::collections::HashSet;
use std::convert::Infallible;
use std::sync::Arc;
use futures::FutureExt;
use itertools::Itertools;
use once_cell::sync::OnceCell;
use regex::Regex;
use serde::{Deserialize, Serialize};
use smol_str::{SmolStr, ToSmolStr, format_smolstr};
use thiserror::Error;
use smol_str::{SmolStr, format_smolstr};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::sync::CancellationToken;
use tracing::{Instrument, debug, error, info, warn};
use tokio::sync::oneshot;
use tracing::Instrument;
use crate::cancellation::{self, CancellationHandler};
use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig};
use crate::cache::Cache;
use crate::cancellation::CancellationHandler;
use crate::compute::ComputeConnection;
use crate::config::ProxyConfig;
use crate::context::RequestContext;
use crate::error::{ReportableError, UserFacingError};
use crate::metrics::{Metrics, NumClientConnectionsGuard};
use crate::control_plane::client::ControlPlaneClient;
pub use crate::pglb::copy_bidirectional::{ErrorSource, copy_bidirectional_client_compute};
use crate::pglb::handshake::{HandshakeData, HandshakeError, handshake};
use crate::pglb::passthrough::ProxyPassthrough;
use crate::pglb::{ClientMode, ClientRequestError};
use crate::pqproto::{BeMessage, CancelKeyData, StartupMessageParams};
use crate::protocol2::{ConnectHeader, ConnectionInfo, ConnectionInfoExtra, read_proxy_protocol};
use crate::proxy::connect_compute::{TcpMechanism, connect_to_compute};
use crate::proxy::retry::ShouldRetryWakeCompute;
use crate::rate_limiter::EndpointRateLimiter;
use crate::stream::{PqStream, Stream};
use crate::types::EndpointCacheKey;
use crate::util::run_until_cancelled;
use crate::{auth, compute};
const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
#[derive(Error, Debug)]
#[error("{ERR_INSECURE_CONNECTION}")]
pub struct TlsRequired;
impl ReportableError for TlsRequired {
fn get_error_kind(&self) -> crate::error::ErrorKind {
crate::error::ErrorKind::User
}
}
impl UserFacingError for TlsRequired {}
pub async fn task_main(
config: &'static ProxyConfig,
auth_backend: &'static auth::Backend<'static, ()>,
listener: tokio::net::TcpListener,
cancellation_token: CancellationToken,
cancellation_handler: Arc<CancellationHandler>,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> anyhow::Result<()> {
scopeguard::defer! {
info!("proxy has shut down");
}
// When set for the server socket, the keepalive setting
// will be inherited by all accepted client sockets.
socket2::SockRef::from(&listener).set_keepalive(true)?;
let connections = tokio_util::task::task_tracker::TaskTracker::new();
let cancellations = tokio_util::task::task_tracker::TaskTracker::new();
while let Some(accept_result) =
run_until_cancelled(listener.accept(), &cancellation_token).await
{
let (socket, peer_addr) = accept_result?;
let conn_gauge = Metrics::get()
.proxy
.client_connections
.guard(crate::metrics::Protocol::Tcp);
let session_id = uuid::Uuid::new_v4();
let cancellation_handler = Arc::clone(&cancellation_handler);
let cancellations = cancellations.clone();
debug!(protocol = "tcp", %session_id, "accepted new TCP connection");
let endpoint_rate_limiter2 = endpoint_rate_limiter.clone();
connections.spawn(async move {
let (socket, conn_info) = match config.proxy_protocol_v2 {
ProxyProtocolV2::Required => {
match read_proxy_protocol(socket).await {
Err(e) => {
warn!("per-client task finished with an error: {e:#}");
return;
}
// our load balancers will not send any more data. let's just exit immediately
Ok((_socket, ConnectHeader::Local)) => {
debug!("healthcheck received");
return;
}
Ok((socket, ConnectHeader::Proxy(info))) => (socket, info),
}
}
// ignore the header - it cannot be confused for a postgres or http connection so will
// error later.
ProxyProtocolV2::Rejected => (
socket,
ConnectionInfo {
addr: peer_addr,
extra: None,
},
),
};
match socket.set_nodelay(true) {
Ok(()) => {}
Err(e) => {
error!(
"per-client task finished with an error: failed to set socket option: {e:#}"
);
return;
}
}
let ctx = RequestContext::new(
session_id,
conn_info,
crate::metrics::Protocol::Tcp,
&config.region,
);
let res = handle_client(
config,
auth_backend,
&ctx,
cancellation_handler,
socket,
ClientMode::Tcp,
endpoint_rate_limiter2,
conn_gauge,
cancellations,
)
.instrument(ctx.span())
.boxed()
.await;
match res {
Err(e) => {
ctx.set_error_kind(e.get_error_kind());
warn!(parent: &ctx.span(), "per-client task finished with an error: {e:#}");
}
Ok(None) => {
ctx.set_success();
}
Ok(Some(p)) => {
ctx.set_success();
let _disconnect = ctx.log_connect();
match p.proxy_pass().await {
Ok(()) => {}
Err(ErrorSource::Client(e)) => {
warn!(
?session_id,
"per-client task finished with an IO error from the client: {e:#}"
);
}
Err(ErrorSource::Compute(e)) => {
error!(
?session_id,
"per-client task finished with an IO error from the compute: {e:#}"
);
}
}
}
}
});
}
connections.close();
cancellations.close();
drop(listener);
// Drain connections
connections.wait().await;
cancellations.wait().await;
Ok(())
}
pub(crate) enum ClientMode {
Tcp,
Websockets { hostname: Option<String> },
}
/// Abstracts the logic of handling TCP vs WS clients
impl ClientMode {
pub(crate) fn allow_cleartext(&self) -> bool {
match self {
ClientMode::Tcp => false,
ClientMode::Websockets { .. } => true,
}
}
fn hostname<'a, S>(&'a self, s: &'a Stream<S>) -> Option<&'a str> {
match self {
ClientMode::Tcp => s.sni_hostname(),
ClientMode::Websockets { hostname } => hostname.as_deref(),
}
}
fn handshake_tls<'a>(&self, tls: Option<&'a TlsConfig>) -> Option<&'a TlsConfig> {
match self {
ClientMode::Tcp => tls,
// TLS is None here if using websockets, because the connection is already encrypted.
ClientMode::Websockets { .. } => None,
}
}
}
#[derive(Debug, Error)]
// almost all errors should be reported to the user, but there's a few cases where we cannot
// 1. Cancellation: we are not allowed to tell the client any cancellation statuses for security reasons
// 2. Handshake: handshake reports errors if it can, otherwise if the handshake fails due to protocol violation,
// we cannot be sure the client even understands our error message
// 3. PrepareClient: The client disconnected, so we can't tell them anyway...
pub(crate) enum ClientRequestError {
#[error("{0}")]
Cancellation(#[from] cancellation::CancelError),
#[error("{0}")]
Handshake(#[from] HandshakeError),
#[error("{0}")]
HandshakeTimeout(#[from] tokio::time::error::Elapsed),
#[error("{0}")]
PrepareClient(#[from] std::io::Error),
#[error("{0}")]
ReportedError(#[from] crate::stream::ReportedError),
}
impl ReportableError for ClientRequestError {
fn get_error_kind(&self) -> crate::error::ErrorKind {
match self {
ClientRequestError::Cancellation(e) => e.get_error_kind(),
ClientRequestError::Handshake(e) => e.get_error_kind(),
ClientRequestError::HandshakeTimeout(_) => crate::error::ErrorKind::RateLimit,
ClientRequestError::ReportedError(e) => e.get_error_kind(),
ClientRequestError::PrepareClient(_) => crate::error::ErrorKind::ClientDisconnect,
}
}
}
#[allow(clippy::too_many_arguments)]
pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
config: &'static ProxyConfig,
auth_backend: &'static auth::Backend<'static, ()>,
ctx: &RequestContext,
cancellation_handler: Arc<CancellationHandler>,
stream: S,
mode: ClientMode,
client: &mut PqStream<Stream<S>>,
mode: &ClientMode,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
conn_gauge: NumClientConnectionsGuard<'static>,
cancellations: tokio_util::task::task_tracker::TaskTracker,
) -> Result<Option<ProxyPassthrough<S>>, ClientRequestError> {
debug!(
protocol = %ctx.protocol(),
"handling interactive connection from client"
);
let metrics = &Metrics::get().proxy;
let proto = ctx.protocol();
let request_gauge = metrics.connection_requests.guard(proto);
let tls = config.tls_config.load();
let tls = tls.as_deref();
let record_handshake_error = !ctx.has_private_peer_addr();
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
let do_handshake = handshake(ctx, stream, mode.handshake_tls(tls), record_handshake_error);
let (mut stream, params) = match tokio::time::timeout(config.handshake_timeout, do_handshake)
.await??
{
HandshakeData::Startup(stream, params) => (stream, params),
HandshakeData::Cancel(cancel_key_data) => {
// spawn a task to cancel the session, but don't wait for it
cancellations.spawn({
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
let ctx = ctx.clone();
let cancel_span = tracing::span!(parent: None, tracing::Level::INFO, "cancel_session", session_id = ?ctx.session_id());
cancel_span.follows_from(tracing::Span::current());
async move {
cancellation_handler_clone
.cancel_session(
cancel_key_data,
ctx,
config.authentication_config.ip_allowlist_check_enabled,
config.authentication_config.is_vpc_acccess_proxy,
auth_backend.get_api(),
)
.await
.inspect_err(|e | debug!(error = ?e, "cancel_session failed")).ok();
}.instrument(cancel_span)
});
return Ok(None);
}
};
drop(pause);
ctx.set_db_options(params.clone());
let hostname = mode.hostname(stream.get_ref());
let common_names = tls.map(|tls| &tls.common_names);
common_names: Option<&HashSet<String>>,
params: &StartupMessageParams,
) -> Result<(ComputeConnection, oneshot::Sender<Infallible>), ClientRequestError> {
let hostname = mode.hostname(client.get_ref());
// Extract credentials which we're going to use for auth.
let result = auth_backend
.as_ref()
.map(|()| auth::ComputeUserInfoMaybeEndpoint::parse(ctx, &params, hostname, common_names))
.map(|()| auth::ComputeUserInfoMaybeEndpoint::parse(ctx, params, hostname, common_names))
.transpose();
let user_info = match result {
Ok(user_info) => user_info,
Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?,
Err(e) => Err(client.throw_error(e, Some(ctx)).await)?,
};
let user = user_info.get_user().to_owned();
let user_info = match user_info
.authenticate(
ctx,
&mut stream,
client,
mode.allow_cleartext(),
&config.authentication_config,
endpoint_rate_limiter,
@@ -339,7 +75,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
let app = params.get("application_name");
let params_span = tracing::info_span!("", ?user, ?db, ?app);
return Err(stream
return Err(client
.throw_error(e, Some(ctx))
.instrument(params_span)
.await)?;
@@ -352,37 +88,67 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
};
let params_compat = creds.info.options.get(NeonOptions::PARAMS_COMPAT).is_some();
let mut auth_info = compute::AuthInfo::with_auth_keys(creds.keys);
auth_info.set_startup_params(&params, params_compat);
auth_info.set_startup_params(params, params_compat);
let res = connect_to_compute(
ctx,
&TcpMechanism {
locks: &config.connect_compute_locks,
},
&auth::Backend::ControlPlane(cplane, creds.info.clone()),
config.wake_compute_retry_config,
&config.connect_to_compute,
)
.await;
let mut node = match res {
Ok(node) => node,
Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?,
let mut node;
let mut attempt = 0;
let connect = TcpMechanism {
locks: &config.connect_compute_locks,
};
let backend = auth::Backend::ControlPlane(cplane, creds.info);
let pg_settings = auth_info.authenticate(ctx, &mut node, creds.info).await;
let pg_settings = match pg_settings {
Ok(pg_settings) => pg_settings,
Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?,
// NOTE: This is messy, but should hopefully be detangled with PGLB.
// We wanted to separate the concerns of **connect** to compute (a PGLB operation),
// from **authenticate** to compute (a NeonKeeper operation).
//
// This unfortunately removed retry handling for one error case where
// the compute was cached, and we connected, but the compute cache was actually stale
// and is associated with the wrong endpoint. We detect this when the **authentication** fails.
// As such, we retry once here if the `authenticate` function fails and the error is valid to retry.
let pg_settings = loop {
attempt += 1;
// TODO: callback to pglb
let res = connect_to_compute(
ctx,
&connect,
&backend,
config.wake_compute_retry_config,
&config.connect_to_compute,
)
.await;
match res {
Ok(n) => node = n,
Err(e) => return Err(client.throw_error(e, Some(ctx)).await)?,
}
let auth::Backend::ControlPlane(cplane, user_info) = &backend else {
unreachable!("ensured above");
};
let res = auth_info.authenticate(ctx, &mut node, user_info).await;
match res {
Ok(pg_settings) => break pg_settings,
Err(e) if attempt < 2 && e.should_retry_wake_compute() => {
tracing::warn!(error = ?e, "retrying wake compute");
#[allow(irrefutable_let_patterns)]
if let ControlPlaneClient::ProxyV1(cplane_proxy_v1) = &**cplane {
let key = user_info.endpoint_cache_key();
cplane_proxy_v1.caches.node_info.invalidate(&key);
}
}
Err(e) => Err(client.throw_error(e, Some(ctx)).await)?,
}
};
let session = cancellation_handler.get_key();
prepare_client_connection(&pg_settings, *session.key(), &mut stream);
let stream = stream.flush_and_into_inner().await?;
finish_client_init(&pg_settings, *session.key(), client);
let session_id = ctx.session_id();
let (cancel_on_shutdown, cancel) = tokio::sync::oneshot::channel();
let (cancel_on_shutdown, cancel) = oneshot::channel();
tokio::spawn(async move {
session
.maintain_cancel_key(
@@ -394,50 +160,32 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
.await;
});
let private_link_id = match ctx.extra() {
Some(ConnectionInfoExtra::Aws { vpce_id }) => Some(vpce_id.clone()),
Some(ConnectionInfoExtra::Azure { link_id }) => Some(link_id.to_smolstr()),
None => None,
};
Ok(Some(ProxyPassthrough {
client: stream,
compute: node.stream,
aux: node.aux,
private_link_id,
_cancel_on_shutdown: cancel_on_shutdown,
_req: request_gauge,
_conn: conn_gauge,
_db_conn: node.guage,
}))
Ok((node, cancel_on_shutdown))
}
/// Finish client connection initialization: confirm auth success, send params, etc.
pub(crate) fn prepare_client_connection(
pub(crate) fn finish_client_init(
settings: &compute::PostgresSettings,
cancel_key_data: CancelKeyData,
stream: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) {
// Forward all deferred notices to the client.
for notice in &settings.delayed_notice {
stream.write_raw(notice.as_bytes().len(), b'N', |buf| {
client.write_raw(notice.as_bytes().len(), b'N', |buf| {
buf.extend_from_slice(notice.as_bytes());
});
}
// Forward all postgres connection params to the client.
for (name, value) in &settings.params {
stream.write_message(BeMessage::ParameterStatus {
client.write_message(BeMessage::ParameterStatus {
name: name.as_bytes(),
value: value.as_bytes(),
});
}
stream.write_message(BeMessage::BackendKeyData(cancel_key_data));
stream.write_message(BeMessage::ReadyForQuery);
client.write_message(BeMessage::BackendKeyData(cancel_key_data));
client.write_message(BeMessage::ReadyForQuery);
}
#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
@@ -447,7 +195,7 @@ impl NeonOptions {
// proxy options:
/// `PARAMS_COMPAT` allows opting in to forwarding all startup parameters from client to compute.
const PARAMS_COMPAT: &str = "proxy_params_compat";
pub const PARAMS_COMPAT: &str = "proxy_params_compat";
// cplane options:

View File

@@ -3,7 +3,7 @@ use std::io;
use tokio::time;
use crate::compute;
use crate::compute::{self, PostgresError};
use crate::config::RetryConfig;
pub(crate) trait CouldRetry {
@@ -115,6 +115,14 @@ impl ShouldRetryWakeCompute for compute::ConnectionError {
}
}
impl ShouldRetryWakeCompute for PostgresError {
fn should_retry_wake_compute(&self) -> bool {
match self {
PostgresError::Postgres(error) => error.should_retry_wake_compute(),
}
}
}
pub(crate) fn retry_after(num_retries: u32, config: RetryConfig) -> time::Duration {
config
.base_delay

View File

@@ -14,6 +14,9 @@ use tokio::io::{AsyncReadExt, AsyncWriteExt, DuplexStream};
use tokio_util::codec::{Decoder, Encoder};
use super::*;
use crate::config::TlsConfig;
use crate::context::RequestContext;
use crate::pglb::handshake::{HandshakeData, handshake};
enum Intercept {
None,

View File

@@ -3,6 +3,7 @@
mod mitm;
use std::sync::Arc;
use std::time::Duration;
use anyhow::{Context, bail};
@@ -10,26 +11,31 @@ use async_trait::async_trait;
use http::StatusCode;
use postgres_client::config::SslMode;
use postgres_client::tls::{MakeTlsConnect, NoTls};
use retry::{ShouldRetryWakeCompute, retry_after};
use rstest::rstest;
use rustls::crypto::ring;
use rustls::pki_types;
use tokio::io::DuplexStream;
use tokio::io::{AsyncRead, AsyncWrite, DuplexStream};
use tracing_test::traced_test;
use super::retry::CouldRetry;
use super::*;
use crate::auth::backend::{ComputeUserInfo, MaybeOwned};
use crate::config::{ComputeConfig, RetryConfig};
use crate::config::{ComputeConfig, RetryConfig, TlsConfig};
use crate::context::RequestContext;
use crate::control_plane::client::{ControlPlaneClient, TestControlPlaneClient};
use crate::control_plane::messages::{ControlPlaneErrorMessage, Details, MetricsAuxInfo, Status};
use crate::control_plane::{self, CachedNodeInfo, NodeInfo, NodeInfoCache};
use crate::error::ErrorKind;
use crate::proxy::connect_compute::ConnectMechanism;
use crate::error::{ErrorKind, ReportableError};
use crate::pglb::ERR_INSECURE_CONNECTION;
use crate::pglb::handshake::{HandshakeData, handshake};
use crate::pqproto::BeMessage;
use crate::proxy::NeonOptions;
use crate::proxy::connect_compute::{ConnectMechanism, connect_to_compute};
use crate::proxy::retry::{ShouldRetryWakeCompute, retry_after};
use crate::stream::{PqStream, Stream};
use crate::tls::client_config::compute_client_config_with_certs;
use crate::tls::server_config::CertResolver;
use crate::types::{BranchId, EndpointId, ProjectId};
use crate::{sasl, scram};
use crate::{auth, compute, sasl, scram};
/// Generate a set of TLS certificates: CA + server.
fn generate_certs(
@@ -374,6 +380,7 @@ fn connect_compute_total_wait() {
#[derive(Clone, Copy, Debug)]
enum ConnectAction {
Wake,
WakeCold,
WakeFail,
WakeRetry,
Connect,
@@ -504,6 +511,9 @@ impl TestControlPlaneClient for TestConnectMechanism {
*counter += 1;
match action {
ConnectAction::Wake => Ok(helper_create_cached_node_info(self.cache)),
ConnectAction::WakeCold => Ok(CachedNodeInfo::new_uncached(
helper_create_uncached_node_info(),
)),
ConnectAction::WakeFail => {
let err = control_plane::errors::ControlPlaneError::Message(Box::new(
ControlPlaneErrorMessage {
@@ -551,8 +561,8 @@ impl TestControlPlaneClient for TestConnectMechanism {
}
}
fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeInfo {
let node = NodeInfo {
fn helper_create_uncached_node_info() -> NodeInfo {
NodeInfo {
conn_info: compute::ConnectInfo {
host: "test".into(),
port: 5432,
@@ -566,7 +576,11 @@ fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeIn
compute_id: "compute".into(),
cold_start_info: crate::control_plane::messages::ColdStartInfo::Warm,
},
};
}
}
fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeInfo {
let node = helper_create_uncached_node_info();
let (_, node2) = cache.insert_unit("key".into(), Ok(node.clone()));
node2.map(|()| node)
}
@@ -742,7 +756,7 @@ async fn fail_no_wake_skips_cache_invalidation() {
let ctx = RequestContext::test();
let mech = TestConnectMechanism::new(vec![
ConnectAction::Wake,
ConnectAction::FailNoWake,
ConnectAction::RetryNoWake,
ConnectAction::Connect,
]);
let user = helper_create_connect_info(&mech);
@@ -788,7 +802,7 @@ async fn retry_no_wake_skips_invalidation() {
let ctx = RequestContext::test();
// Wake → RetryNoWake (retryable + NOT wakeable)
let mechanism = TestConnectMechanism::new(vec![Wake, RetryNoWake]);
let mechanism = TestConnectMechanism::new(vec![Wake, RetryNoWake, Fail]);
let user_info = helper_create_connect_info(&mechanism);
let cfg = config();
@@ -802,3 +816,44 @@ async fn retry_no_wake_skips_invalidation() {
"invalidating stalled compute node info cache entry"
));
}
#[tokio::test]
#[traced_test]
async fn retry_no_wake_error_fast() {
let _ = env_logger::try_init();
use ConnectAction::*;
let ctx = RequestContext::test();
// Wake → FailNoWake (not retryable + NOT wakeable)
let mechanism = TestConnectMechanism::new(vec![Wake, FailNoWake]);
let user_info = helper_create_connect_info(&mechanism);
let cfg = config();
connect_to_compute(&ctx, &mechanism, &user_info, cfg.retry, &cfg)
.await
.unwrap_err();
mechanism.verify();
// Because FailNoWake has wakeable=false, we must NOT see invalidate_cache
assert!(!logs_contain(
"invalidating stalled compute node info cache entry"
));
}
#[tokio::test]
#[traced_test]
async fn retry_cold_wake_skips_invalidation() {
let _ = env_logger::try_init();
use ConnectAction::*;
let ctx = RequestContext::test();
// WakeCold → FailNoWake (not retryable + NOT wakeable)
let mechanism = TestConnectMechanism::new(vec![WakeCold, Retry, Connect]);
let user_info = helper_create_connect_info(&mechanism);
let cfg = config();
connect_to_compute(&ctx, &mechanism, &user_info, cfg.retry, &cfg)
.await
.unwrap();
mechanism.verify();
}

View File

@@ -139,12 +139,6 @@ impl RateBucketInfo {
Self::new(200, Duration::from_secs(600)),
];
// For all the sessions will be cancel key. So this limit is essentially global proxy limit.
pub const DEFAULT_REDIS_SET: [Self; 2] = [
Self::new(100_000, Duration::from_secs(1)),
Self::new(50_000, Duration::from_secs(10)),
];
pub fn rps(&self) -> f64 {
(self.max_rpi as f64) / self.interval.as_secs_f64()
}

View File

@@ -5,11 +5,9 @@ use redis::aio::ConnectionLike;
use redis::{Cmd, FromRedisValue, Pipeline, RedisResult};
use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
use crate::rate_limiter::{GlobalRateLimiter, RateBucketInfo};
pub struct RedisKVClient {
client: ConnectionWithCredentialsProvider,
limiter: GlobalRateLimiter,
}
#[allow(async_fn_in_trait)]
@@ -30,11 +28,8 @@ impl Queryable for Cmd {
}
impl RedisKVClient {
pub fn new(client: ConnectionWithCredentialsProvider, info: &'static [RateBucketInfo]) -> Self {
Self {
client,
limiter: GlobalRateLimiter::new(info.into()),
}
pub fn new(client: ConnectionWithCredentialsProvider) -> Self {
Self { client }
}
pub async fn try_connect(&mut self) -> anyhow::Result<()> {
@@ -49,11 +44,6 @@ impl RedisKVClient {
&mut self,
q: &impl Queryable,
) -> anyhow::Result<T> {
if !self.limiter.check() {
tracing::info!("Rate limit exceeded. Skipping query");
return Err(anyhow::anyhow!("Rate limit exceeded"));
}
let e = match q.query(&mut self.client).await {
Ok(t) => return Ok(t),
Err(e) => e,

View File

@@ -141,29 +141,19 @@ where
struct MessageHandler<C: ProjectInfoCache + Send + Sync + 'static> {
cache: Arc<C>,
region_id: String,
}
impl<C: ProjectInfoCache + Send + Sync + 'static> Clone for MessageHandler<C> {
fn clone(&self) -> Self {
Self {
cache: self.cache.clone(),
region_id: self.region_id.clone(),
}
}
}
impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
pub(crate) fn new(cache: Arc<C>, region_id: String) -> Self {
Self { cache, region_id }
}
pub(crate) async fn increment_active_listeners(&self) {
self.cache.increment_active_listeners().await;
}
pub(crate) async fn decrement_active_listeners(&self) {
self.cache.decrement_active_listeners().await;
pub(crate) fn new(cache: Arc<C>) -> Self {
Self { cache }
}
#[tracing::instrument(skip(self, msg), fields(session_id = tracing::field::Empty))]
@@ -276,7 +266,7 @@ async fn handle_messages<C: ProjectInfoCache + Send + Sync + 'static>(
}
let mut conn = match try_connect(&redis).await {
Ok(conn) => {
handler.increment_active_listeners().await;
handler.cache.increment_active_listeners().await;
conn
}
Err(e) => {
@@ -297,11 +287,11 @@ async fn handle_messages<C: ProjectInfoCache + Send + Sync + 'static>(
}
}
if cancellation_token.is_cancelled() {
handler.decrement_active_listeners().await;
handler.cache.decrement_active_listeners().await;
return Ok(());
}
}
handler.decrement_active_listeners().await;
handler.cache.decrement_active_listeners().await;
}
}
@@ -310,12 +300,11 @@ async fn handle_messages<C: ProjectInfoCache + Send + Sync + 'static>(
pub async fn task_main<C>(
redis: ConnectionWithCredentialsProvider,
cache: Arc<C>,
region_id: String,
) -> anyhow::Result<Infallible>
where
C: ProjectInfoCache + Send + Sync + 'static,
{
let handler = MessageHandler::new(cache, region_id);
let handler = MessageHandler::new(cache);
// 6h - 1m.
// There will be 1 minute overlap between two tasks. But at least we can be sure that no message is lost.
let mut interval = tokio::time::interval(std::time::Duration::from_secs(6 * 60 * 60 - 60));

View File

@@ -418,12 +418,7 @@ async fn request_handler(
if config.http_config.accept_websockets
&& framed_websockets::upgrade::is_upgrade_request(&request)
{
let ctx = RequestContext::new(
session_id,
conn_info,
crate::metrics::Protocol::Ws,
&config.region,
);
let ctx = RequestContext::new(session_id, conn_info, crate::metrics::Protocol::Ws);
ctx.set_user_agent(
request
@@ -463,12 +458,7 @@ async fn request_handler(
// Return the response so the spawned future can continue.
Ok(response.map(|b| b.map_err(|x| match x {}).boxed()))
} else if request.uri().path() == "/sql" && *request.method() == Method::POST {
let ctx = RequestContext::new(
session_id,
conn_info,
crate::metrics::Protocol::Http,
&config.region,
);
let ctx = RequestContext::new(session_id, conn_info, crate::metrics::Protocol::Http);
let span = ctx.span();
let testodrome_id = request

View File

@@ -17,7 +17,8 @@ use crate::config::ProxyConfig;
use crate::context::RequestContext;
use crate::error::ReportableError;
use crate::metrics::Metrics;
use crate::proxy::{ClientMode, ErrorSource, handle_client};
use crate::pglb::{ClientMode, handle_connection};
use crate::proxy::ErrorSource;
use crate::rate_limiter::EndpointRateLimiter;
pin_project! {
@@ -142,7 +143,7 @@ pub(crate) async fn serve_websocket(
.client_connections
.guard(crate::metrics::Protocol::Ws);
let res = Box::pin(handle_client(
let res = Box::pin(handle_connection(
config,
auth_backend,
&ctx,