mirror of
https://github.com/neondatabase/neon.git
synced 2026-05-31 12:00:42 +00:00
Merge branch 'main' into ruslan/subzero-integration
This commit is contained in:
@@ -288,7 +288,6 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig
|
||||
},
|
||||
proxy_protocol_v2: config::ProxyProtocolV2::Rejected,
|
||||
handshake_timeout: Duration::from_secs(10),
|
||||
region: "local".into(),
|
||||
wake_compute_retry_config: RetryConfig::parse(RetryConfig::WAKE_COMPUTE_DEFAULT_VALUES)?,
|
||||
connect_compute_locks,
|
||||
connect_to_compute: compute_config,
|
||||
|
||||
@@ -26,9 +26,10 @@ use utils::sentry_init::init_sentry;
|
||||
|
||||
use crate::context::RequestContext;
|
||||
use crate::metrics::{Metrics, ThreadPoolMetrics};
|
||||
use crate::pglb::TlsRequired;
|
||||
use crate::pqproto::FeStartupPacket;
|
||||
use crate::protocol2::ConnectionInfo;
|
||||
use crate::proxy::{ErrorSource, TlsRequired, copy_bidirectional_client_compute};
|
||||
use crate::proxy::{ErrorSource, copy_bidirectional_client_compute};
|
||||
use crate::stream::{PqStream, Stream};
|
||||
use crate::util::run_until_cancelled;
|
||||
|
||||
@@ -236,7 +237,6 @@ pub(super) async fn task_main(
|
||||
extra: None,
|
||||
},
|
||||
crate::metrics::Protocol::SniRouter,
|
||||
"sni",
|
||||
);
|
||||
handle_client(ctx, dest_suffix, tls_config, compute_tls_config, socket).await
|
||||
}
|
||||
|
||||
@@ -154,12 +154,6 @@ struct ProxyCliArgs {
|
||||
/// timeout for the TLS handshake
|
||||
#[clap(long, default_value = "15s", value_parser = humantime::parse_duration)]
|
||||
handshake_timeout: tokio::time::Duration,
|
||||
/// http endpoint to receive periodic metric updates
|
||||
#[clap(long)]
|
||||
metric_collection_endpoint: Option<String>,
|
||||
/// how often metrics should be sent to a collection endpoint
|
||||
#[clap(long)]
|
||||
metric_collection_interval: Option<String>,
|
||||
/// cache for `wake_compute` api method (use `size=0` to disable)
|
||||
#[clap(long, default_value = config::CacheOptions::CACHE_DEFAULT_OPTIONS)]
|
||||
wake_compute_cache: String,
|
||||
@@ -186,40 +180,31 @@ struct ProxyCliArgs {
|
||||
/// Wake compute rate limiter max number of requests per second.
|
||||
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_SET)]
|
||||
wake_compute_limit: Vec<RateBucketInfo>,
|
||||
/// Redis rate limiter max number of requests per second.
|
||||
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_REDIS_SET)]
|
||||
redis_rps_limit: Vec<RateBucketInfo>,
|
||||
/// Cancellation channel size (max queue size for redis kv client)
|
||||
#[clap(long, default_value_t = 1024)]
|
||||
cancellation_ch_size: usize,
|
||||
/// Cancellation ops batch size for redis
|
||||
#[clap(long, default_value_t = 8)]
|
||||
cancellation_batch_size: usize,
|
||||
/// cache for `allowed_ips` (use `size=0` to disable)
|
||||
#[clap(long, default_value = config::CacheOptions::CACHE_DEFAULT_OPTIONS)]
|
||||
allowed_ips_cache: String,
|
||||
/// cache for `role_secret` (use `size=0` to disable)
|
||||
#[clap(long, default_value = config::CacheOptions::CACHE_DEFAULT_OPTIONS)]
|
||||
role_secret_cache: String,
|
||||
/// redis url for notifications (if empty, redis_host:port will be used for both notifications and streaming connections)
|
||||
#[clap(long)]
|
||||
redis_notifications: Option<String>,
|
||||
/// what from the available authentications type to use for the regional redis we have. Supported are "irsa" and "plain".
|
||||
/// redis url for plain authentication
|
||||
#[clap(long, alias("redis-notifications"))]
|
||||
redis_plain: Option<String>,
|
||||
/// what from the available authentications type to use for redis. Supported are "irsa" and "plain".
|
||||
#[clap(long, default_value = "irsa")]
|
||||
redis_auth_type: String,
|
||||
/// redis host for streaming connections (might be different from the notifications host)
|
||||
/// redis host for irsa authentication
|
||||
#[clap(long)]
|
||||
redis_host: Option<String>,
|
||||
/// redis port for streaming connections (might be different from the notifications host)
|
||||
/// redis port for irsa authentication
|
||||
#[clap(long)]
|
||||
redis_port: Option<u16>,
|
||||
/// redis cluster name, used in aws elasticache
|
||||
/// redis cluster name for irsa authentication
|
||||
#[clap(long)]
|
||||
redis_cluster_name: Option<String>,
|
||||
/// redis user_id, used in aws elasticache
|
||||
/// redis user_id for irsa authentication
|
||||
#[clap(long)]
|
||||
redis_user_id: Option<String>,
|
||||
/// aws region to retrieve credentials
|
||||
/// aws region for irsa authentication
|
||||
#[clap(long, default_value_t = String::new())]
|
||||
aws_region: String,
|
||||
/// cache for `project_info` (use `size=0` to disable)
|
||||
@@ -231,6 +216,12 @@ struct ProxyCliArgs {
|
||||
#[clap(flatten)]
|
||||
parquet_upload: ParquetUploadArgs,
|
||||
|
||||
/// http endpoint to receive periodic metric updates
|
||||
#[clap(long)]
|
||||
metric_collection_endpoint: Option<String>,
|
||||
/// how often metrics should be sent to a collection endpoint
|
||||
#[clap(long)]
|
||||
metric_collection_interval: Option<String>,
|
||||
/// interval for backup metric collection
|
||||
#[clap(long, default_value = "10m", value_parser = humantime::parse_duration)]
|
||||
metric_backup_collection_interval: std::time::Duration,
|
||||
@@ -243,6 +234,7 @@ struct ProxyCliArgs {
|
||||
/// Size of each event is no more than 400 bytes, so 2**22 is about 200MB before the compression.
|
||||
#[clap(long, default_value = "4194304")]
|
||||
metric_backup_collection_chunk_size: usize,
|
||||
|
||||
/// Whether to retry the connection to the compute node
|
||||
#[clap(long, default_value = config::RetryConfig::CONNECT_TO_COMPUTE_DEFAULT_VALUES)]
|
||||
connect_to_compute_retry: String,
|
||||
@@ -370,7 +362,7 @@ pub async fn run() -> anyhow::Result<()> {
|
||||
Either::Right(auth_backend) => info!("Authentication backend: {auth_backend:?}"),
|
||||
}
|
||||
info!("Using region: {}", args.aws_region);
|
||||
let (regional_redis_client, redis_notifications_client) = configure_redis(&args).await?;
|
||||
let redis_client = configure_redis(&args).await?;
|
||||
|
||||
// Check that we can bind to address before further initialization
|
||||
info!("Starting http on {}", args.http);
|
||||
@@ -425,13 +417,6 @@ pub async fn run() -> anyhow::Result<()> {
|
||||
|
||||
let cancellation_token = CancellationToken::new();
|
||||
|
||||
let redis_rps_limit = Vec::leak(args.redis_rps_limit.clone());
|
||||
RateBucketInfo::validate(redis_rps_limit)?;
|
||||
|
||||
let redis_kv_client = regional_redis_client
|
||||
.as_ref()
|
||||
.map(|redis_publisher| RedisKVClient::new(redis_publisher.clone(), redis_rps_limit));
|
||||
|
||||
let cancellation_handler = Arc::new(CancellationHandler::new(&config.connect_to_compute));
|
||||
|
||||
let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards(
|
||||
@@ -446,7 +431,7 @@ pub async fn run() -> anyhow::Result<()> {
|
||||
match auth_backend {
|
||||
Either::Left(auth_backend) => {
|
||||
if let Some(proxy_listener) = proxy_listener {
|
||||
client_tasks.spawn(crate::proxy::task_main(
|
||||
client_tasks.spawn(crate::pglb::task_main(
|
||||
config,
|
||||
auth_backend,
|
||||
proxy_listener,
|
||||
@@ -528,6 +513,7 @@ pub async fn run() -> anyhow::Result<()> {
|
||||
client_tasks.spawn(crate::context::parquet::worker(
|
||||
cancellation_token.clone(),
|
||||
args.parquet_upload,
|
||||
args.region,
|
||||
));
|
||||
|
||||
// maintenance tasks. these never return unless there's an error
|
||||
@@ -561,32 +547,17 @@ pub async fn run() -> anyhow::Result<()> {
|
||||
#[cfg_attr(not(any(test, feature = "testing")), expect(irrefutable_let_patterns))]
|
||||
if let Either::Left(auth::Backend::ControlPlane(api, ())) = &auth_backend {
|
||||
if let crate::control_plane::client::ControlPlaneClient::ProxyV1(api) = &**api {
|
||||
match (redis_notifications_client, regional_redis_client.clone()) {
|
||||
(None, None) => {}
|
||||
(client1, client2) => {
|
||||
let cache = api.caches.project_info.clone();
|
||||
if let Some(client) = client1 {
|
||||
maintenance_tasks.spawn(notifications::task_main(
|
||||
client,
|
||||
cache.clone(),
|
||||
args.region.clone(),
|
||||
));
|
||||
}
|
||||
if let Some(client) = client2 {
|
||||
maintenance_tasks.spawn(notifications::task_main(
|
||||
client,
|
||||
cache.clone(),
|
||||
args.region.clone(),
|
||||
));
|
||||
}
|
||||
maintenance_tasks.spawn(async move { cache.clone().gc_worker().await });
|
||||
}
|
||||
}
|
||||
if let Some(client) = redis_client {
|
||||
// project info cache and invalidation of that cache.
|
||||
let cache = api.caches.project_info.clone();
|
||||
maintenance_tasks.spawn(notifications::task_main(client.clone(), cache.clone()));
|
||||
maintenance_tasks.spawn(async move { cache.clone().gc_worker().await });
|
||||
|
||||
// Try to connect to Redis 3 times with 1 + (0..0.1) second interval.
|
||||
// This prevents immediate exit and pod restart,
|
||||
// which can cause hammering of the redis in case of connection issues.
|
||||
if let Some(mut redis_kv_client) = redis_kv_client {
|
||||
// Try to connect to Redis 3 times with 1 + (0..0.1) second interval.
|
||||
// This prevents immediate exit and pod restart,
|
||||
// which can cause hammering of the redis in case of connection issues.
|
||||
// cancellation key management
|
||||
let mut redis_kv_client = RedisKVClient::new(client.clone());
|
||||
for attempt in (0..3).with_position() {
|
||||
match redis_kv_client.try_connect().await {
|
||||
Ok(()) => {
|
||||
@@ -611,14 +582,12 @@ pub async fn run() -> anyhow::Result<()> {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(regional_redis_client) = regional_redis_client {
|
||||
// listen for notifications of new projects/endpoints/branches
|
||||
let cache = api.caches.endpoints_cache.clone();
|
||||
let con = regional_redis_client;
|
||||
let span = tracing::info_span!("endpoints_cache");
|
||||
maintenance_tasks.spawn(
|
||||
async move { cache.do_read(con, cancellation_token.clone()).await }
|
||||
async move { cache.do_read(client, cancellation_token.clone()).await }
|
||||
.instrument(span),
|
||||
);
|
||||
}
|
||||
@@ -765,7 +734,6 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
|
||||
authentication_config,
|
||||
proxy_protocol_v2: args.proxy_protocol_v2,
|
||||
handshake_timeout: args.handshake_timeout,
|
||||
region: args.region.clone(),
|
||||
wake_compute_retry_config: config::RetryConfig::parse(&args.wake_compute_retry)?,
|
||||
connect_compute_locks,
|
||||
connect_to_compute: compute_config,
|
||||
@@ -942,21 +910,18 @@ fn build_auth_backend(
|
||||
|
||||
async fn configure_redis(
|
||||
args: &ProxyCliArgs,
|
||||
) -> anyhow::Result<(
|
||||
Option<ConnectionWithCredentialsProvider>,
|
||||
Option<ConnectionWithCredentialsProvider>,
|
||||
)> {
|
||||
) -> anyhow::Result<Option<ConnectionWithCredentialsProvider>> {
|
||||
// TODO: untangle the config args
|
||||
let regional_redis_client = match (args.redis_auth_type.as_str(), &args.redis_notifications) {
|
||||
("plain", redis_url) => match redis_url {
|
||||
let redis_client = match &*args.redis_auth_type {
|
||||
"plain" => match &args.redis_plain {
|
||||
None => {
|
||||
bail!("plain auth requires redis_notifications to be set");
|
||||
bail!("plain auth requires redis_plain to be set");
|
||||
}
|
||||
Some(url) => {
|
||||
Some(ConnectionWithCredentialsProvider::new_with_static_credentials(url.clone()))
|
||||
}
|
||||
},
|
||||
("irsa", _) => match (&args.redis_host, args.redis_port) {
|
||||
"irsa" => match (&args.redis_host, args.redis_port) {
|
||||
(Some(host), Some(port)) => Some(
|
||||
ConnectionWithCredentialsProvider::new_with_credentials_provider(
|
||||
host.clone(),
|
||||
@@ -980,18 +945,12 @@ async fn configure_redis(
|
||||
bail!("redis-host and redis-port must be specified together");
|
||||
}
|
||||
},
|
||||
_ => {
|
||||
bail!("unknown auth type given");
|
||||
auth_type => {
|
||||
bail!("unknown auth type {auth_type:?} given")
|
||||
}
|
||||
};
|
||||
|
||||
let redis_notifications_client = if let Some(url) = &args.redis_notifications {
|
||||
Some(ConnectionWithCredentialsProvider::new_with_static_credentials(&**url))
|
||||
} else {
|
||||
regional_redis_client.clone()
|
||||
};
|
||||
|
||||
Ok((regional_redis_client, redis_notifications_client))
|
||||
Ok(redis_client)
|
||||
}
|
||||
|
||||
|
||||
|
||||
63
proxy/src/cache/timed_lru.rs
vendored
63
proxy/src/cache/timed_lru.rs
vendored
@@ -30,7 +30,7 @@ use super::{Cache, timed_lru};
|
||||
///
|
||||
/// * There's an API for immediate invalidation (removal) of a cache entry;
|
||||
/// It's useful in case we know for sure that the entry is no longer correct.
|
||||
/// See [`timed_lru::LookupInfo`] & [`timed_lru::Cached`] for more information.
|
||||
/// See [`timed_lru::Cached`] for more information.
|
||||
///
|
||||
/// * Expired entries are kept in the cache, until they are evicted by the LRU policy,
|
||||
/// or by a successful lookup (i.e. the entry hasn't expired yet).
|
||||
@@ -54,7 +54,7 @@ pub(crate) struct TimedLru<K, V> {
|
||||
impl<K: Hash + Eq, V> Cache for TimedLru<K, V> {
|
||||
type Key = K;
|
||||
type Value = V;
|
||||
type LookupInfo<Key> = LookupInfo<Key>;
|
||||
type LookupInfo<Key> = Key;
|
||||
|
||||
fn invalidate(&self, info: &Self::LookupInfo<K>) {
|
||||
self.invalidate_raw(info);
|
||||
@@ -87,30 +87,24 @@ impl<K: Hash + Eq, V> TimedLru<K, V> {
|
||||
|
||||
/// Drop an entry from the cache if it's outdated.
|
||||
#[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)]
|
||||
fn invalidate_raw(&self, info: &LookupInfo<K>) {
|
||||
let now = Instant::now();
|
||||
|
||||
fn invalidate_raw(&self, key: &K) {
|
||||
// Do costly things before taking the lock.
|
||||
let mut cache = self.cache.lock();
|
||||
let raw_entry = match cache.raw_entry_mut().from_key(&info.key) {
|
||||
let entry = match cache.raw_entry_mut().from_key(key) {
|
||||
RawEntryMut::Vacant(_) => return,
|
||||
RawEntryMut::Occupied(x) => x,
|
||||
RawEntryMut::Occupied(x) => x.remove(),
|
||||
};
|
||||
|
||||
// Remove the entry if it was created prior to lookup timestamp.
|
||||
let entry = raw_entry.get();
|
||||
let (created_at, expires_at) = (entry.created_at, entry.expires_at);
|
||||
let should_remove = created_at <= info.created_at || expires_at <= now;
|
||||
|
||||
if should_remove {
|
||||
raw_entry.remove();
|
||||
}
|
||||
|
||||
drop(cache); // drop lock before logging
|
||||
|
||||
let Entry {
|
||||
created_at,
|
||||
expires_at,
|
||||
..
|
||||
} = entry;
|
||||
|
||||
debug!(
|
||||
created_at = format_args!("{created_at:?}"),
|
||||
expires_at = format_args!("{expires_at:?}"),
|
||||
entry_removed = should_remove,
|
||||
?created_at,
|
||||
?expires_at,
|
||||
"processed a cache entry invalidation event"
|
||||
);
|
||||
}
|
||||
@@ -215,10 +209,10 @@ impl<K: Hash + Eq + Clone, V: Clone> TimedLru<K, V> {
|
||||
}
|
||||
|
||||
pub(crate) fn insert_unit(&self, key: K, value: V) -> (Option<V>, Cached<&Self, ()>) {
|
||||
let (created_at, old) = self.insert_raw(key.clone(), value);
|
||||
let (_, old) = self.insert_raw(key.clone(), value);
|
||||
|
||||
let cached = Cached {
|
||||
token: Some((self, LookupInfo { created_at, key })),
|
||||
token: Some((self, key)),
|
||||
value: (),
|
||||
};
|
||||
|
||||
@@ -255,28 +249,9 @@ impl<K: Hash + Eq, V: Clone> TimedLru<K, V> {
|
||||
K: Borrow<Q> + Clone,
|
||||
Q: Hash + Eq + ?Sized,
|
||||
{
|
||||
self.get_raw(key, |key, entry| {
|
||||
let info = LookupInfo {
|
||||
created_at: entry.created_at,
|
||||
key: key.clone(),
|
||||
};
|
||||
|
||||
Cached {
|
||||
token: Some((self, info)),
|
||||
value: entry.value.clone(),
|
||||
}
|
||||
self.get_raw(key, |key, entry| Cached {
|
||||
token: Some((self, key.clone())),
|
||||
value: entry.value.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Lookup information for key invalidation.
|
||||
pub(crate) struct LookupInfo<K> {
|
||||
/// Time of creation of a cache [`Entry`].
|
||||
/// We use this during invalidation lookups to prevent eviction of a newer
|
||||
/// entry sharing the same key (it might've been inserted by a different
|
||||
/// task after we got the entry we're trying to invalidate now).
|
||||
created_at: Instant,
|
||||
|
||||
/// Search by this key.
|
||||
key: K,
|
||||
}
|
||||
|
||||
@@ -236,7 +236,7 @@ impl AuthInfo {
|
||||
&self,
|
||||
ctx: &RequestContext,
|
||||
compute: &mut ComputeConnection,
|
||||
user_info: ComputeUserInfo,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<PostgresSettings, PostgresError> {
|
||||
// client config with stubbed connect info.
|
||||
// TODO(conrad): should we rewrite this to bypass tokio-postgres2 entirely,
|
||||
@@ -272,7 +272,7 @@ impl AuthInfo {
|
||||
secret_key,
|
||||
},
|
||||
compute.hostname.to_string(),
|
||||
user_info,
|
||||
user_info.clone(),
|
||||
);
|
||||
|
||||
Ok(PostgresSettings {
|
||||
|
||||
@@ -24,7 +24,6 @@ pub struct ProxyConfig {
|
||||
pub authentication_config: AuthenticationConfig,
|
||||
pub rest_config: RestConfig,
|
||||
pub proxy_protocol_v2: ProxyProtocolV2,
|
||||
pub region: String,
|
||||
pub handshake_timeout: Duration,
|
||||
pub wake_compute_retry_config: RetryConfig,
|
||||
pub connect_compute_locks: ApiLocks<Host>,
|
||||
|
||||
@@ -11,11 +11,12 @@ use crate::config::{ProxyConfig, ProxyProtocolV2};
|
||||
use crate::context::RequestContext;
|
||||
use crate::error::ReportableError;
|
||||
use crate::metrics::{Metrics, NumClientConnectionsGuard};
|
||||
use crate::pglb::ClientRequestError;
|
||||
use crate::pglb::handshake::{HandshakeData, handshake};
|
||||
use crate::pglb::passthrough::ProxyPassthrough;
|
||||
use crate::protocol2::{ConnectHeader, ConnectionInfo, read_proxy_protocol};
|
||||
use crate::proxy::connect_compute::{TcpMechanism, connect_to_compute};
|
||||
use crate::proxy::{ClientRequestError, ErrorSource, prepare_client_connection};
|
||||
use crate::proxy::{ErrorSource, finish_client_init};
|
||||
use crate::util::run_until_cancelled;
|
||||
|
||||
pub async fn task_main(
|
||||
@@ -89,12 +90,7 @@ pub async fn task_main(
|
||||
}
|
||||
}
|
||||
|
||||
let ctx = RequestContext::new(
|
||||
session_id,
|
||||
conn_info,
|
||||
crate::metrics::Protocol::Tcp,
|
||||
&config.region,
|
||||
);
|
||||
let ctx = RequestContext::new(session_id, conn_info, crate::metrics::Protocol::Tcp);
|
||||
|
||||
let res = handle_client(
|
||||
config,
|
||||
@@ -231,13 +227,13 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
.await?;
|
||||
|
||||
let pg_settings = auth_info
|
||||
.authenticate(ctx, &mut node, user_info)
|
||||
.authenticate(ctx, &mut node, &user_info)
|
||||
.or_else(|e| async { Err(stream.throw_error(e, Some(ctx)).await) })
|
||||
.await?;
|
||||
|
||||
let session = cancellation_handler.get_key();
|
||||
|
||||
prepare_client_connection(&pg_settings, *session.key(), &mut stream);
|
||||
finish_client_init(&pg_settings, *session.key(), &mut stream);
|
||||
let stream = stream.flush_and_into_inner().await?;
|
||||
|
||||
let session_id = ctx.session_id();
|
||||
|
||||
@@ -46,7 +46,6 @@ struct RequestContextInner {
|
||||
pub(crate) session_id: Uuid,
|
||||
pub(crate) protocol: Protocol,
|
||||
first_packet: chrono::DateTime<Utc>,
|
||||
region: &'static str,
|
||||
pub(crate) span: Span,
|
||||
|
||||
// filled in as they are discovered
|
||||
@@ -94,7 +93,6 @@ impl Clone for RequestContext {
|
||||
session_id: inner.session_id,
|
||||
protocol: inner.protocol,
|
||||
first_packet: inner.first_packet,
|
||||
region: inner.region,
|
||||
span: info_span!("background_task"),
|
||||
|
||||
project: inner.project,
|
||||
@@ -124,12 +122,7 @@ impl Clone for RequestContext {
|
||||
}
|
||||
|
||||
impl RequestContext {
|
||||
pub fn new(
|
||||
session_id: Uuid,
|
||||
conn_info: ConnectionInfo,
|
||||
protocol: Protocol,
|
||||
region: &'static str,
|
||||
) -> Self {
|
||||
pub fn new(session_id: Uuid, conn_info: ConnectionInfo, protocol: Protocol) -> Self {
|
||||
// TODO: be careful with long lived spans
|
||||
let span = info_span!(
|
||||
"connect_request",
|
||||
@@ -145,7 +138,6 @@ impl RequestContext {
|
||||
session_id,
|
||||
protocol,
|
||||
first_packet: Utc::now(),
|
||||
region,
|
||||
span,
|
||||
|
||||
project: None,
|
||||
@@ -179,7 +171,7 @@ impl RequestContext {
|
||||
let ip = IpAddr::from([127, 0, 0, 1]);
|
||||
let addr = SocketAddr::new(ip, 5432);
|
||||
let conn_info = ConnectionInfo { addr, extra: None };
|
||||
RequestContext::new(Uuid::now_v7(), conn_info, Protocol::Tcp, "test")
|
||||
RequestContext::new(Uuid::now_v7(), conn_info, Protocol::Tcp)
|
||||
}
|
||||
|
||||
pub(crate) fn console_application_name(&self) -> String {
|
||||
|
||||
@@ -74,7 +74,7 @@ pub(crate) const FAILED_UPLOAD_MAX_RETRIES: u32 = 10;
|
||||
|
||||
#[derive(parquet_derive::ParquetRecordWriter)]
|
||||
pub(crate) struct RequestData {
|
||||
region: &'static str,
|
||||
region: String,
|
||||
protocol: &'static str,
|
||||
/// Must be UTC. The derive macro doesn't like the timezones
|
||||
timestamp: chrono::NaiveDateTime,
|
||||
@@ -147,7 +147,7 @@ impl From<&RequestContextInner> for RequestData {
|
||||
}),
|
||||
jwt_issuer: value.jwt_issuer.clone(),
|
||||
protocol: value.protocol.as_str(),
|
||||
region: value.region,
|
||||
region: String::new(),
|
||||
error: value.error_kind.as_ref().map(|e| e.to_metric_label()),
|
||||
success: value.success,
|
||||
cold_start_info: value.cold_start_info.as_str(),
|
||||
@@ -167,6 +167,7 @@ impl From<&RequestContextInner> for RequestData {
|
||||
pub async fn worker(
|
||||
cancellation_token: CancellationToken,
|
||||
config: ParquetUploadArgs,
|
||||
region: String,
|
||||
) -> anyhow::Result<()> {
|
||||
let Some(remote_storage_config) = config.parquet_upload_remote_storage else {
|
||||
tracing::warn!("parquet request upload: no s3 bucket configured");
|
||||
@@ -232,12 +233,17 @@ pub async fn worker(
|
||||
.context("remote storage for disconnect events init")?;
|
||||
let parquet_config_disconnect = parquet_config.clone();
|
||||
tokio::try_join!(
|
||||
worker_inner(storage, rx, parquet_config),
|
||||
worker_inner(storage_disconnect, rx_disconnect, parquet_config_disconnect)
|
||||
worker_inner(storage, rx, parquet_config, ®ion),
|
||||
worker_inner(
|
||||
storage_disconnect,
|
||||
rx_disconnect,
|
||||
parquet_config_disconnect,
|
||||
®ion
|
||||
)
|
||||
)
|
||||
.map(|_| ())
|
||||
} else {
|
||||
worker_inner(storage, rx, parquet_config).await
|
||||
worker_inner(storage, rx, parquet_config, ®ion).await
|
||||
}
|
||||
}
|
||||
|
||||
@@ -257,6 +263,7 @@ async fn worker_inner(
|
||||
storage: GenericRemoteStorage,
|
||||
rx: impl Stream<Item = RequestData>,
|
||||
config: ParquetConfig,
|
||||
region: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
let storage = if config.test_remote_failures > 0 {
|
||||
@@ -277,7 +284,8 @@ async fn worker_inner(
|
||||
let mut last_upload = time::Instant::now();
|
||||
|
||||
let mut len = 0;
|
||||
while let Some(row) = rx.next().await {
|
||||
while let Some(mut row) = rx.next().await {
|
||||
region.clone_into(&mut row.region);
|
||||
rows.push(row);
|
||||
let force = last_upload.elapsed() > config.max_duration;
|
||||
if rows.len() == config.rows_per_group || force {
|
||||
@@ -533,7 +541,7 @@ mod tests {
|
||||
auth_method: None,
|
||||
jwt_issuer: None,
|
||||
protocol: ["tcp", "ws", "http"][rng.gen_range(0..3)],
|
||||
region: "us-east-1",
|
||||
region: String::new(),
|
||||
error: None,
|
||||
success: rng.r#gen(),
|
||||
cold_start_info: "no",
|
||||
@@ -565,7 +573,9 @@ mod tests {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
worker_inner(storage, rx, config).await.unwrap();
|
||||
worker_inner(storage, rx, config, "us-east-1")
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let mut files = WalkDir::new(tmpdir.as_std_path())
|
||||
.into_iter()
|
||||
|
||||
@@ -8,10 +8,10 @@ use crate::config::TlsConfig;
|
||||
use crate::context::RequestContext;
|
||||
use crate::error::ReportableError;
|
||||
use crate::metrics::Metrics;
|
||||
use crate::pglb::TlsRequired;
|
||||
use crate::pqproto::{
|
||||
BeMessage, CancelKeyData, FeStartupPacket, ProtocolVersion, StartupMessageParams,
|
||||
};
|
||||
use crate::proxy::TlsRequired;
|
||||
use crate::stream::{PqStream, Stream, StreamUpgradeError};
|
||||
use crate::tls::PG_ALPN_PROTOCOL;
|
||||
|
||||
|
||||
@@ -2,3 +2,332 @@ pub mod copy_bidirectional;
|
||||
pub mod handshake;
|
||||
pub mod inprocess;
|
||||
pub mod passthrough;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use futures::FutureExt;
|
||||
use smol_str::ToSmolStr;
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{Instrument, debug, error, info, warn};
|
||||
|
||||
use crate::auth;
|
||||
use crate::cancellation::{self, CancellationHandler};
|
||||
use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig};
|
||||
use crate::context::RequestContext;
|
||||
use crate::error::{ReportableError, UserFacingError};
|
||||
use crate::metrics::{Metrics, NumClientConnectionsGuard};
|
||||
pub use crate::pglb::copy_bidirectional::ErrorSource;
|
||||
use crate::pglb::handshake::{HandshakeData, HandshakeError, handshake};
|
||||
use crate::pglb::passthrough::ProxyPassthrough;
|
||||
use crate::protocol2::{ConnectHeader, ConnectionInfo, ConnectionInfoExtra, read_proxy_protocol};
|
||||
use crate::proxy::handle_client;
|
||||
use crate::rate_limiter::EndpointRateLimiter;
|
||||
use crate::stream::Stream;
|
||||
use crate::util::run_until_cancelled;
|
||||
|
||||
pub const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
#[error("{ERR_INSECURE_CONNECTION}")]
|
||||
pub struct TlsRequired;
|
||||
|
||||
impl ReportableError for TlsRequired {
|
||||
fn get_error_kind(&self) -> crate::error::ErrorKind {
|
||||
crate::error::ErrorKind::User
|
||||
}
|
||||
}
|
||||
|
||||
impl UserFacingError for TlsRequired {}
|
||||
|
||||
pub async fn task_main(
|
||||
config: &'static ProxyConfig,
|
||||
auth_backend: &'static auth::Backend<'static, ()>,
|
||||
listener: tokio::net::TcpListener,
|
||||
cancellation_token: CancellationToken,
|
||||
cancellation_handler: Arc<CancellationHandler>,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
) -> anyhow::Result<()> {
|
||||
scopeguard::defer! {
|
||||
info!("proxy has shut down");
|
||||
}
|
||||
|
||||
// When set for the server socket, the keepalive setting
|
||||
// will be inherited by all accepted client sockets.
|
||||
socket2::SockRef::from(&listener).set_keepalive(true)?;
|
||||
|
||||
let connections = tokio_util::task::task_tracker::TaskTracker::new();
|
||||
let cancellations = tokio_util::task::task_tracker::TaskTracker::new();
|
||||
|
||||
while let Some(accept_result) =
|
||||
run_until_cancelled(listener.accept(), &cancellation_token).await
|
||||
{
|
||||
let (socket, peer_addr) = accept_result?;
|
||||
|
||||
let conn_gauge = Metrics::get()
|
||||
.proxy
|
||||
.client_connections
|
||||
.guard(crate::metrics::Protocol::Tcp);
|
||||
|
||||
let session_id = uuid::Uuid::new_v4();
|
||||
let cancellation_handler = Arc::clone(&cancellation_handler);
|
||||
let cancellations = cancellations.clone();
|
||||
|
||||
debug!(protocol = "tcp", %session_id, "accepted new TCP connection");
|
||||
let endpoint_rate_limiter2 = endpoint_rate_limiter.clone();
|
||||
|
||||
connections.spawn(async move {
|
||||
let (socket, conn_info) = match config.proxy_protocol_v2 {
|
||||
ProxyProtocolV2::Required => {
|
||||
match read_proxy_protocol(socket).await {
|
||||
Err(e) => {
|
||||
warn!("per-client task finished with an error: {e:#}");
|
||||
return;
|
||||
}
|
||||
// our load balancers will not send any more data. let's just exit immediately
|
||||
Ok((_socket, ConnectHeader::Local)) => {
|
||||
debug!("healthcheck received");
|
||||
return;
|
||||
}
|
||||
Ok((socket, ConnectHeader::Proxy(info))) => (socket, info),
|
||||
}
|
||||
}
|
||||
// ignore the header - it cannot be confused for a postgres or http connection so will
|
||||
// error later.
|
||||
ProxyProtocolV2::Rejected => (
|
||||
socket,
|
||||
ConnectionInfo {
|
||||
addr: peer_addr,
|
||||
extra: None,
|
||||
},
|
||||
),
|
||||
};
|
||||
|
||||
match socket.set_nodelay(true) {
|
||||
Ok(()) => {}
|
||||
Err(e) => {
|
||||
error!(
|
||||
"per-client task finished with an error: failed to set socket option: {e:#}"
|
||||
);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
let ctx = RequestContext::new(session_id, conn_info, crate::metrics::Protocol::Tcp);
|
||||
|
||||
let res = handle_connection(
|
||||
config,
|
||||
auth_backend,
|
||||
&ctx,
|
||||
cancellation_handler,
|
||||
socket,
|
||||
ClientMode::Tcp,
|
||||
endpoint_rate_limiter2,
|
||||
conn_gauge,
|
||||
cancellations,
|
||||
)
|
||||
.instrument(ctx.span())
|
||||
.boxed()
|
||||
.await;
|
||||
|
||||
match res {
|
||||
Err(e) => {
|
||||
ctx.set_error_kind(e.get_error_kind());
|
||||
warn!(parent: &ctx.span(), "per-client task finished with an error: {e:#}");
|
||||
}
|
||||
Ok(None) => {
|
||||
ctx.set_success();
|
||||
}
|
||||
Ok(Some(p)) => {
|
||||
ctx.set_success();
|
||||
let _disconnect = ctx.log_connect();
|
||||
match p.proxy_pass().await {
|
||||
Ok(()) => {}
|
||||
Err(ErrorSource::Client(e)) => {
|
||||
warn!(
|
||||
?session_id,
|
||||
"per-client task finished with an IO error from the client: {e:#}"
|
||||
);
|
||||
}
|
||||
Err(ErrorSource::Compute(e)) => {
|
||||
error!(
|
||||
?session_id,
|
||||
"per-client task finished with an IO error from the compute: {e:#}"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
connections.close();
|
||||
cancellations.close();
|
||||
drop(listener);
|
||||
|
||||
// Drain connections
|
||||
connections.wait().await;
|
||||
cancellations.wait().await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) enum ClientMode {
|
||||
Tcp,
|
||||
Websockets { hostname: Option<String> },
|
||||
}
|
||||
|
||||
/// Abstracts the logic of handling TCP vs WS clients
|
||||
impl ClientMode {
|
||||
pub fn allow_cleartext(&self) -> bool {
|
||||
match self {
|
||||
ClientMode::Tcp => false,
|
||||
ClientMode::Websockets { .. } => true,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn hostname<'a, S>(&'a self, s: &'a Stream<S>) -> Option<&'a str> {
|
||||
match self {
|
||||
ClientMode::Tcp => s.sni_hostname(),
|
||||
ClientMode::Websockets { hostname } => hostname.as_deref(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn handshake_tls<'a>(&self, tls: Option<&'a TlsConfig>) -> Option<&'a TlsConfig> {
|
||||
match self {
|
||||
ClientMode::Tcp => tls,
|
||||
// TLS is None here if using websockets, because the connection is already encrypted.
|
||||
ClientMode::Websockets { .. } => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
// almost all errors should be reported to the user, but there's a few cases where we cannot
|
||||
// 1. Cancellation: we are not allowed to tell the client any cancellation statuses for security reasons
|
||||
// 2. Handshake: handshake reports errors if it can, otherwise if the handshake fails due to protocol violation,
|
||||
// we cannot be sure the client even understands our error message
|
||||
// 3. PrepareClient: The client disconnected, so we can't tell them anyway...
|
||||
pub(crate) enum ClientRequestError {
|
||||
#[error("{0}")]
|
||||
Cancellation(#[from] cancellation::CancelError),
|
||||
#[error("{0}")]
|
||||
Handshake(#[from] HandshakeError),
|
||||
#[error("{0}")]
|
||||
HandshakeTimeout(#[from] tokio::time::error::Elapsed),
|
||||
#[error("{0}")]
|
||||
PrepareClient(#[from] std::io::Error),
|
||||
#[error("{0}")]
|
||||
ReportedError(#[from] crate::stream::ReportedError),
|
||||
}
|
||||
|
||||
impl ReportableError for ClientRequestError {
|
||||
fn get_error_kind(&self) -> crate::error::ErrorKind {
|
||||
match self {
|
||||
ClientRequestError::Cancellation(e) => e.get_error_kind(),
|
||||
ClientRequestError::Handshake(e) => e.get_error_kind(),
|
||||
ClientRequestError::HandshakeTimeout(_) => crate::error::ErrorKind::RateLimit,
|
||||
ClientRequestError::ReportedError(e) => e.get_error_kind(),
|
||||
ClientRequestError::PrepareClient(_) => crate::error::ErrorKind::ClientDisconnect,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) async fn handle_connection<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
config: &'static ProxyConfig,
|
||||
auth_backend: &'static auth::Backend<'static, ()>,
|
||||
ctx: &RequestContext,
|
||||
cancellation_handler: Arc<CancellationHandler>,
|
||||
client: S,
|
||||
mode: ClientMode,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
conn_gauge: NumClientConnectionsGuard<'static>,
|
||||
cancellations: tokio_util::task::task_tracker::TaskTracker,
|
||||
) -> Result<Option<ProxyPassthrough<S>>, ClientRequestError> {
|
||||
debug!(
|
||||
protocol = %ctx.protocol(),
|
||||
"handling interactive connection from client"
|
||||
);
|
||||
|
||||
let metrics = &Metrics::get().proxy;
|
||||
let proto = ctx.protocol();
|
||||
let request_gauge = metrics.connection_requests.guard(proto);
|
||||
|
||||
let tls = config.tls_config.load();
|
||||
let tls = tls.as_deref();
|
||||
|
||||
let record_handshake_error = !ctx.has_private_peer_addr();
|
||||
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
|
||||
let do_handshake = handshake(ctx, client, mode.handshake_tls(tls), record_handshake_error);
|
||||
|
||||
let (mut client, params) = match tokio::time::timeout(config.handshake_timeout, do_handshake)
|
||||
.await??
|
||||
{
|
||||
HandshakeData::Startup(client, params) => (client, params),
|
||||
HandshakeData::Cancel(cancel_key_data) => {
|
||||
// spawn a task to cancel the session, but don't wait for it
|
||||
cancellations.spawn({
|
||||
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
|
||||
let ctx = ctx.clone();
|
||||
let cancel_span = tracing::span!(parent: None, tracing::Level::INFO, "cancel_session", session_id = ?ctx.session_id());
|
||||
cancel_span.follows_from(tracing::Span::current());
|
||||
async move {
|
||||
cancellation_handler_clone
|
||||
.cancel_session(
|
||||
cancel_key_data,
|
||||
ctx,
|
||||
config.authentication_config.ip_allowlist_check_enabled,
|
||||
config.authentication_config.is_vpc_acccess_proxy,
|
||||
auth_backend.get_api(),
|
||||
)
|
||||
.await
|
||||
.inspect_err(|e | debug!(error = ?e, "cancel_session failed")).ok();
|
||||
}.instrument(cancel_span)
|
||||
});
|
||||
|
||||
return Ok(None);
|
||||
}
|
||||
};
|
||||
drop(pause);
|
||||
|
||||
ctx.set_db_options(params.clone());
|
||||
|
||||
let common_names = tls.map(|tls| &tls.common_names);
|
||||
|
||||
let (node, cancel_on_shutdown) = handle_client(
|
||||
config,
|
||||
auth_backend,
|
||||
ctx,
|
||||
cancellation_handler,
|
||||
&mut client,
|
||||
&mode,
|
||||
endpoint_rate_limiter,
|
||||
common_names,
|
||||
¶ms,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let client = client.flush_and_into_inner().await?;
|
||||
|
||||
let private_link_id = match ctx.extra() {
|
||||
Some(ConnectionInfoExtra::Aws { vpce_id }) => Some(vpce_id.clone()),
|
||||
Some(ConnectionInfoExtra::Azure { link_id }) => Some(link_id.to_smolstr()),
|
||||
None => None,
|
||||
};
|
||||
|
||||
Ok(Some(ProxyPassthrough {
|
||||
client,
|
||||
compute: node.stream,
|
||||
|
||||
aux: node.aux,
|
||||
private_link_id,
|
||||
|
||||
_cancel_on_shutdown: cancel_on_shutdown,
|
||||
|
||||
_req: request_gauge,
|
||||
_conn: conn_gauge,
|
||||
_db_conn: node.guage,
|
||||
}))
|
||||
}
|
||||
|
||||
@@ -112,7 +112,7 @@ where
|
||||
let node_info = if !node_info.cached() || !err.should_retry_wake_compute() {
|
||||
// If we just recieved this from cplane and didn't get it from cache, we shouldn't retry.
|
||||
// Do not need to retrieve a new node_info, just return the old one.
|
||||
if should_retry(&err, num_retries, compute.retry) {
|
||||
if !should_retry(&err, num_retries, compute.retry) {
|
||||
Metrics::get().proxy.retries_metric.observe(
|
||||
RetriesMetricGroup {
|
||||
outcome: ConnectOutcome::Failed,
|
||||
|
||||
@@ -5,328 +5,64 @@ pub(crate) mod connect_compute;
|
||||
pub(crate) mod retry;
|
||||
pub(crate) mod wake_compute;
|
||||
|
||||
use std::collections::HashSet;
|
||||
use std::convert::Infallible;
|
||||
use std::sync::Arc;
|
||||
|
||||
use futures::FutureExt;
|
||||
use itertools::Itertools;
|
||||
use once_cell::sync::OnceCell;
|
||||
use regex::Regex;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use smol_str::{SmolStr, ToSmolStr, format_smolstr};
|
||||
use thiserror::Error;
|
||||
use smol_str::{SmolStr, format_smolstr};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{Instrument, debug, error, info, warn};
|
||||
use tokio::sync::oneshot;
|
||||
use tracing::Instrument;
|
||||
|
||||
use crate::cancellation::{self, CancellationHandler};
|
||||
use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig};
|
||||
use crate::cache::Cache;
|
||||
use crate::cancellation::CancellationHandler;
|
||||
use crate::compute::ComputeConnection;
|
||||
use crate::config::ProxyConfig;
|
||||
use crate::context::RequestContext;
|
||||
use crate::error::{ReportableError, UserFacingError};
|
||||
use crate::metrics::{Metrics, NumClientConnectionsGuard};
|
||||
use crate::control_plane::client::ControlPlaneClient;
|
||||
pub use crate::pglb::copy_bidirectional::{ErrorSource, copy_bidirectional_client_compute};
|
||||
use crate::pglb::handshake::{HandshakeData, HandshakeError, handshake};
|
||||
use crate::pglb::passthrough::ProxyPassthrough;
|
||||
use crate::pglb::{ClientMode, ClientRequestError};
|
||||
use crate::pqproto::{BeMessage, CancelKeyData, StartupMessageParams};
|
||||
use crate::protocol2::{ConnectHeader, ConnectionInfo, ConnectionInfoExtra, read_proxy_protocol};
|
||||
use crate::proxy::connect_compute::{TcpMechanism, connect_to_compute};
|
||||
use crate::proxy::retry::ShouldRetryWakeCompute;
|
||||
use crate::rate_limiter::EndpointRateLimiter;
|
||||
use crate::stream::{PqStream, Stream};
|
||||
use crate::types::EndpointCacheKey;
|
||||
use crate::util::run_until_cancelled;
|
||||
use crate::{auth, compute};
|
||||
|
||||
const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
#[error("{ERR_INSECURE_CONNECTION}")]
|
||||
pub struct TlsRequired;
|
||||
|
||||
impl ReportableError for TlsRequired {
|
||||
fn get_error_kind(&self) -> crate::error::ErrorKind {
|
||||
crate::error::ErrorKind::User
|
||||
}
|
||||
}
|
||||
|
||||
impl UserFacingError for TlsRequired {}
|
||||
|
||||
pub async fn task_main(
|
||||
config: &'static ProxyConfig,
|
||||
auth_backend: &'static auth::Backend<'static, ()>,
|
||||
listener: tokio::net::TcpListener,
|
||||
cancellation_token: CancellationToken,
|
||||
cancellation_handler: Arc<CancellationHandler>,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
) -> anyhow::Result<()> {
|
||||
scopeguard::defer! {
|
||||
info!("proxy has shut down");
|
||||
}
|
||||
|
||||
// When set for the server socket, the keepalive setting
|
||||
// will be inherited by all accepted client sockets.
|
||||
socket2::SockRef::from(&listener).set_keepalive(true)?;
|
||||
|
||||
let connections = tokio_util::task::task_tracker::TaskTracker::new();
|
||||
let cancellations = tokio_util::task::task_tracker::TaskTracker::new();
|
||||
|
||||
while let Some(accept_result) =
|
||||
run_until_cancelled(listener.accept(), &cancellation_token).await
|
||||
{
|
||||
let (socket, peer_addr) = accept_result?;
|
||||
|
||||
let conn_gauge = Metrics::get()
|
||||
.proxy
|
||||
.client_connections
|
||||
.guard(crate::metrics::Protocol::Tcp);
|
||||
|
||||
let session_id = uuid::Uuid::new_v4();
|
||||
let cancellation_handler = Arc::clone(&cancellation_handler);
|
||||
let cancellations = cancellations.clone();
|
||||
|
||||
debug!(protocol = "tcp", %session_id, "accepted new TCP connection");
|
||||
let endpoint_rate_limiter2 = endpoint_rate_limiter.clone();
|
||||
|
||||
connections.spawn(async move {
|
||||
let (socket, conn_info) = match config.proxy_protocol_v2 {
|
||||
ProxyProtocolV2::Required => {
|
||||
match read_proxy_protocol(socket).await {
|
||||
Err(e) => {
|
||||
warn!("per-client task finished with an error: {e:#}");
|
||||
return;
|
||||
}
|
||||
// our load balancers will not send any more data. let's just exit immediately
|
||||
Ok((_socket, ConnectHeader::Local)) => {
|
||||
debug!("healthcheck received");
|
||||
return;
|
||||
}
|
||||
Ok((socket, ConnectHeader::Proxy(info))) => (socket, info),
|
||||
}
|
||||
}
|
||||
// ignore the header - it cannot be confused for a postgres or http connection so will
|
||||
// error later.
|
||||
ProxyProtocolV2::Rejected => (
|
||||
socket,
|
||||
ConnectionInfo {
|
||||
addr: peer_addr,
|
||||
extra: None,
|
||||
},
|
||||
),
|
||||
};
|
||||
|
||||
match socket.set_nodelay(true) {
|
||||
Ok(()) => {}
|
||||
Err(e) => {
|
||||
error!(
|
||||
"per-client task finished with an error: failed to set socket option: {e:#}"
|
||||
);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
let ctx = RequestContext::new(
|
||||
session_id,
|
||||
conn_info,
|
||||
crate::metrics::Protocol::Tcp,
|
||||
&config.region,
|
||||
);
|
||||
|
||||
let res = handle_client(
|
||||
config,
|
||||
auth_backend,
|
||||
&ctx,
|
||||
cancellation_handler,
|
||||
socket,
|
||||
ClientMode::Tcp,
|
||||
endpoint_rate_limiter2,
|
||||
conn_gauge,
|
||||
cancellations,
|
||||
)
|
||||
.instrument(ctx.span())
|
||||
.boxed()
|
||||
.await;
|
||||
|
||||
match res {
|
||||
Err(e) => {
|
||||
ctx.set_error_kind(e.get_error_kind());
|
||||
warn!(parent: &ctx.span(), "per-client task finished with an error: {e:#}");
|
||||
}
|
||||
Ok(None) => {
|
||||
ctx.set_success();
|
||||
}
|
||||
Ok(Some(p)) => {
|
||||
ctx.set_success();
|
||||
let _disconnect = ctx.log_connect();
|
||||
match p.proxy_pass().await {
|
||||
Ok(()) => {}
|
||||
Err(ErrorSource::Client(e)) => {
|
||||
warn!(
|
||||
?session_id,
|
||||
"per-client task finished with an IO error from the client: {e:#}"
|
||||
);
|
||||
}
|
||||
Err(ErrorSource::Compute(e)) => {
|
||||
error!(
|
||||
?session_id,
|
||||
"per-client task finished with an IO error from the compute: {e:#}"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
connections.close();
|
||||
cancellations.close();
|
||||
drop(listener);
|
||||
|
||||
// Drain connections
|
||||
connections.wait().await;
|
||||
cancellations.wait().await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) enum ClientMode {
|
||||
Tcp,
|
||||
Websockets { hostname: Option<String> },
|
||||
}
|
||||
|
||||
/// Abstracts the logic of handling TCP vs WS clients
|
||||
impl ClientMode {
|
||||
pub(crate) fn allow_cleartext(&self) -> bool {
|
||||
match self {
|
||||
ClientMode::Tcp => false,
|
||||
ClientMode::Websockets { .. } => true,
|
||||
}
|
||||
}
|
||||
|
||||
fn hostname<'a, S>(&'a self, s: &'a Stream<S>) -> Option<&'a str> {
|
||||
match self {
|
||||
ClientMode::Tcp => s.sni_hostname(),
|
||||
ClientMode::Websockets { hostname } => hostname.as_deref(),
|
||||
}
|
||||
}
|
||||
|
||||
fn handshake_tls<'a>(&self, tls: Option<&'a TlsConfig>) -> Option<&'a TlsConfig> {
|
||||
match self {
|
||||
ClientMode::Tcp => tls,
|
||||
// TLS is None here if using websockets, because the connection is already encrypted.
|
||||
ClientMode::Websockets { .. } => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
// almost all errors should be reported to the user, but there's a few cases where we cannot
|
||||
// 1. Cancellation: we are not allowed to tell the client any cancellation statuses for security reasons
|
||||
// 2. Handshake: handshake reports errors if it can, otherwise if the handshake fails due to protocol violation,
|
||||
// we cannot be sure the client even understands our error message
|
||||
// 3. PrepareClient: The client disconnected, so we can't tell them anyway...
|
||||
pub(crate) enum ClientRequestError {
|
||||
#[error("{0}")]
|
||||
Cancellation(#[from] cancellation::CancelError),
|
||||
#[error("{0}")]
|
||||
Handshake(#[from] HandshakeError),
|
||||
#[error("{0}")]
|
||||
HandshakeTimeout(#[from] tokio::time::error::Elapsed),
|
||||
#[error("{0}")]
|
||||
PrepareClient(#[from] std::io::Error),
|
||||
#[error("{0}")]
|
||||
ReportedError(#[from] crate::stream::ReportedError),
|
||||
}
|
||||
|
||||
impl ReportableError for ClientRequestError {
|
||||
fn get_error_kind(&self) -> crate::error::ErrorKind {
|
||||
match self {
|
||||
ClientRequestError::Cancellation(e) => e.get_error_kind(),
|
||||
ClientRequestError::Handshake(e) => e.get_error_kind(),
|
||||
ClientRequestError::HandshakeTimeout(_) => crate::error::ErrorKind::RateLimit,
|
||||
ClientRequestError::ReportedError(e) => e.get_error_kind(),
|
||||
ClientRequestError::PrepareClient(_) => crate::error::ErrorKind::ClientDisconnect,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
config: &'static ProxyConfig,
|
||||
auth_backend: &'static auth::Backend<'static, ()>,
|
||||
ctx: &RequestContext,
|
||||
cancellation_handler: Arc<CancellationHandler>,
|
||||
stream: S,
|
||||
mode: ClientMode,
|
||||
client: &mut PqStream<Stream<S>>,
|
||||
mode: &ClientMode,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
conn_gauge: NumClientConnectionsGuard<'static>,
|
||||
cancellations: tokio_util::task::task_tracker::TaskTracker,
|
||||
) -> Result<Option<ProxyPassthrough<S>>, ClientRequestError> {
|
||||
debug!(
|
||||
protocol = %ctx.protocol(),
|
||||
"handling interactive connection from client"
|
||||
);
|
||||
|
||||
let metrics = &Metrics::get().proxy;
|
||||
let proto = ctx.protocol();
|
||||
let request_gauge = metrics.connection_requests.guard(proto);
|
||||
|
||||
let tls = config.tls_config.load();
|
||||
let tls = tls.as_deref();
|
||||
|
||||
let record_handshake_error = !ctx.has_private_peer_addr();
|
||||
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
|
||||
let do_handshake = handshake(ctx, stream, mode.handshake_tls(tls), record_handshake_error);
|
||||
|
||||
let (mut stream, params) = match tokio::time::timeout(config.handshake_timeout, do_handshake)
|
||||
.await??
|
||||
{
|
||||
HandshakeData::Startup(stream, params) => (stream, params),
|
||||
HandshakeData::Cancel(cancel_key_data) => {
|
||||
// spawn a task to cancel the session, but don't wait for it
|
||||
cancellations.spawn({
|
||||
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
|
||||
let ctx = ctx.clone();
|
||||
let cancel_span = tracing::span!(parent: None, tracing::Level::INFO, "cancel_session", session_id = ?ctx.session_id());
|
||||
cancel_span.follows_from(tracing::Span::current());
|
||||
async move {
|
||||
cancellation_handler_clone
|
||||
.cancel_session(
|
||||
cancel_key_data,
|
||||
ctx,
|
||||
config.authentication_config.ip_allowlist_check_enabled,
|
||||
config.authentication_config.is_vpc_acccess_proxy,
|
||||
auth_backend.get_api(),
|
||||
)
|
||||
.await
|
||||
.inspect_err(|e | debug!(error = ?e, "cancel_session failed")).ok();
|
||||
}.instrument(cancel_span)
|
||||
});
|
||||
|
||||
return Ok(None);
|
||||
}
|
||||
};
|
||||
drop(pause);
|
||||
|
||||
ctx.set_db_options(params.clone());
|
||||
|
||||
let hostname = mode.hostname(stream.get_ref());
|
||||
|
||||
let common_names = tls.map(|tls| &tls.common_names);
|
||||
|
||||
common_names: Option<&HashSet<String>>,
|
||||
params: &StartupMessageParams,
|
||||
) -> Result<(ComputeConnection, oneshot::Sender<Infallible>), ClientRequestError> {
|
||||
let hostname = mode.hostname(client.get_ref());
|
||||
// Extract credentials which we're going to use for auth.
|
||||
let result = auth_backend
|
||||
.as_ref()
|
||||
.map(|()| auth::ComputeUserInfoMaybeEndpoint::parse(ctx, ¶ms, hostname, common_names))
|
||||
.map(|()| auth::ComputeUserInfoMaybeEndpoint::parse(ctx, params, hostname, common_names))
|
||||
.transpose();
|
||||
|
||||
let user_info = match result {
|
||||
Ok(user_info) => user_info,
|
||||
Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?,
|
||||
Err(e) => Err(client.throw_error(e, Some(ctx)).await)?,
|
||||
};
|
||||
|
||||
let user = user_info.get_user().to_owned();
|
||||
let user_info = match user_info
|
||||
.authenticate(
|
||||
ctx,
|
||||
&mut stream,
|
||||
client,
|
||||
mode.allow_cleartext(),
|
||||
&config.authentication_config,
|
||||
endpoint_rate_limiter,
|
||||
@@ -339,7 +75,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
let app = params.get("application_name");
|
||||
let params_span = tracing::info_span!("", ?user, ?db, ?app);
|
||||
|
||||
return Err(stream
|
||||
return Err(client
|
||||
.throw_error(e, Some(ctx))
|
||||
.instrument(params_span)
|
||||
.await)?;
|
||||
@@ -352,37 +88,67 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
};
|
||||
let params_compat = creds.info.options.get(NeonOptions::PARAMS_COMPAT).is_some();
|
||||
let mut auth_info = compute::AuthInfo::with_auth_keys(creds.keys);
|
||||
auth_info.set_startup_params(¶ms, params_compat);
|
||||
auth_info.set_startup_params(params, params_compat);
|
||||
|
||||
let res = connect_to_compute(
|
||||
ctx,
|
||||
&TcpMechanism {
|
||||
locks: &config.connect_compute_locks,
|
||||
},
|
||||
&auth::Backend::ControlPlane(cplane, creds.info.clone()),
|
||||
config.wake_compute_retry_config,
|
||||
&config.connect_to_compute,
|
||||
)
|
||||
.await;
|
||||
|
||||
let mut node = match res {
|
||||
Ok(node) => node,
|
||||
Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?,
|
||||
let mut node;
|
||||
let mut attempt = 0;
|
||||
let connect = TcpMechanism {
|
||||
locks: &config.connect_compute_locks,
|
||||
};
|
||||
let backend = auth::Backend::ControlPlane(cplane, creds.info);
|
||||
|
||||
let pg_settings = auth_info.authenticate(ctx, &mut node, creds.info).await;
|
||||
let pg_settings = match pg_settings {
|
||||
Ok(pg_settings) => pg_settings,
|
||||
Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?,
|
||||
// NOTE: This is messy, but should hopefully be detangled with PGLB.
|
||||
// We wanted to separate the concerns of **connect** to compute (a PGLB operation),
|
||||
// from **authenticate** to compute (a NeonKeeper operation).
|
||||
//
|
||||
// This unfortunately removed retry handling for one error case where
|
||||
// the compute was cached, and we connected, but the compute cache was actually stale
|
||||
// and is associated with the wrong endpoint. We detect this when the **authentication** fails.
|
||||
// As such, we retry once here if the `authenticate` function fails and the error is valid to retry.
|
||||
let pg_settings = loop {
|
||||
attempt += 1;
|
||||
|
||||
// TODO: callback to pglb
|
||||
let res = connect_to_compute(
|
||||
ctx,
|
||||
&connect,
|
||||
&backend,
|
||||
config.wake_compute_retry_config,
|
||||
&config.connect_to_compute,
|
||||
)
|
||||
.await;
|
||||
|
||||
match res {
|
||||
Ok(n) => node = n,
|
||||
Err(e) => return Err(client.throw_error(e, Some(ctx)).await)?,
|
||||
}
|
||||
|
||||
let auth::Backend::ControlPlane(cplane, user_info) = &backend else {
|
||||
unreachable!("ensured above");
|
||||
};
|
||||
|
||||
let res = auth_info.authenticate(ctx, &mut node, user_info).await;
|
||||
match res {
|
||||
Ok(pg_settings) => break pg_settings,
|
||||
Err(e) if attempt < 2 && e.should_retry_wake_compute() => {
|
||||
tracing::warn!(error = ?e, "retrying wake compute");
|
||||
|
||||
#[allow(irrefutable_let_patterns)]
|
||||
if let ControlPlaneClient::ProxyV1(cplane_proxy_v1) = &**cplane {
|
||||
let key = user_info.endpoint_cache_key();
|
||||
cplane_proxy_v1.caches.node_info.invalidate(&key);
|
||||
}
|
||||
}
|
||||
Err(e) => Err(client.throw_error(e, Some(ctx)).await)?,
|
||||
}
|
||||
};
|
||||
|
||||
let session = cancellation_handler.get_key();
|
||||
|
||||
prepare_client_connection(&pg_settings, *session.key(), &mut stream);
|
||||
let stream = stream.flush_and_into_inner().await?;
|
||||
finish_client_init(&pg_settings, *session.key(), client);
|
||||
|
||||
let session_id = ctx.session_id();
|
||||
let (cancel_on_shutdown, cancel) = tokio::sync::oneshot::channel();
|
||||
let (cancel_on_shutdown, cancel) = oneshot::channel();
|
||||
tokio::spawn(async move {
|
||||
session
|
||||
.maintain_cancel_key(
|
||||
@@ -394,50 +160,32 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
.await;
|
||||
});
|
||||
|
||||
let private_link_id = match ctx.extra() {
|
||||
Some(ConnectionInfoExtra::Aws { vpce_id }) => Some(vpce_id.clone()),
|
||||
Some(ConnectionInfoExtra::Azure { link_id }) => Some(link_id.to_smolstr()),
|
||||
None => None,
|
||||
};
|
||||
|
||||
Ok(Some(ProxyPassthrough {
|
||||
client: stream,
|
||||
compute: node.stream,
|
||||
|
||||
aux: node.aux,
|
||||
private_link_id,
|
||||
|
||||
_cancel_on_shutdown: cancel_on_shutdown,
|
||||
|
||||
_req: request_gauge,
|
||||
_conn: conn_gauge,
|
||||
_db_conn: node.guage,
|
||||
}))
|
||||
Ok((node, cancel_on_shutdown))
|
||||
}
|
||||
|
||||
/// Finish client connection initialization: confirm auth success, send params, etc.
|
||||
pub(crate) fn prepare_client_connection(
|
||||
pub(crate) fn finish_client_init(
|
||||
settings: &compute::PostgresSettings,
|
||||
cancel_key_data: CancelKeyData,
|
||||
stream: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
) {
|
||||
// Forward all deferred notices to the client.
|
||||
for notice in &settings.delayed_notice {
|
||||
stream.write_raw(notice.as_bytes().len(), b'N', |buf| {
|
||||
client.write_raw(notice.as_bytes().len(), b'N', |buf| {
|
||||
buf.extend_from_slice(notice.as_bytes());
|
||||
});
|
||||
}
|
||||
|
||||
// Forward all postgres connection params to the client.
|
||||
for (name, value) in &settings.params {
|
||||
stream.write_message(BeMessage::ParameterStatus {
|
||||
client.write_message(BeMessage::ParameterStatus {
|
||||
name: name.as_bytes(),
|
||||
value: value.as_bytes(),
|
||||
});
|
||||
}
|
||||
|
||||
stream.write_message(BeMessage::BackendKeyData(cancel_key_data));
|
||||
stream.write_message(BeMessage::ReadyForQuery);
|
||||
client.write_message(BeMessage::BackendKeyData(cancel_key_data));
|
||||
client.write_message(BeMessage::ReadyForQuery);
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
|
||||
@@ -447,7 +195,7 @@ impl NeonOptions {
|
||||
// proxy options:
|
||||
|
||||
/// `PARAMS_COMPAT` allows opting in to forwarding all startup parameters from client to compute.
|
||||
const PARAMS_COMPAT: &str = "proxy_params_compat";
|
||||
pub const PARAMS_COMPAT: &str = "proxy_params_compat";
|
||||
|
||||
// cplane options:
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ use std::io;
|
||||
|
||||
use tokio::time;
|
||||
|
||||
use crate::compute;
|
||||
use crate::compute::{self, PostgresError};
|
||||
use crate::config::RetryConfig;
|
||||
|
||||
pub(crate) trait CouldRetry {
|
||||
@@ -115,6 +115,14 @@ impl ShouldRetryWakeCompute for compute::ConnectionError {
|
||||
}
|
||||
}
|
||||
|
||||
impl ShouldRetryWakeCompute for PostgresError {
|
||||
fn should_retry_wake_compute(&self) -> bool {
|
||||
match self {
|
||||
PostgresError::Postgres(error) => error.should_retry_wake_compute(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn retry_after(num_retries: u32, config: RetryConfig) -> time::Duration {
|
||||
config
|
||||
.base_delay
|
||||
|
||||
@@ -14,6 +14,9 @@ use tokio::io::{AsyncReadExt, AsyncWriteExt, DuplexStream};
|
||||
use tokio_util::codec::{Decoder, Encoder};
|
||||
|
||||
use super::*;
|
||||
use crate::config::TlsConfig;
|
||||
use crate::context::RequestContext;
|
||||
use crate::pglb::handshake::{HandshakeData, handshake};
|
||||
|
||||
enum Intercept {
|
||||
None,
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
mod mitm;
|
||||
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::{Context, bail};
|
||||
@@ -10,26 +11,31 @@ use async_trait::async_trait;
|
||||
use http::StatusCode;
|
||||
use postgres_client::config::SslMode;
|
||||
use postgres_client::tls::{MakeTlsConnect, NoTls};
|
||||
use retry::{ShouldRetryWakeCompute, retry_after};
|
||||
use rstest::rstest;
|
||||
use rustls::crypto::ring;
|
||||
use rustls::pki_types;
|
||||
use tokio::io::DuplexStream;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, DuplexStream};
|
||||
use tracing_test::traced_test;
|
||||
|
||||
use super::retry::CouldRetry;
|
||||
use super::*;
|
||||
use crate::auth::backend::{ComputeUserInfo, MaybeOwned};
|
||||
use crate::config::{ComputeConfig, RetryConfig};
|
||||
use crate::config::{ComputeConfig, RetryConfig, TlsConfig};
|
||||
use crate::context::RequestContext;
|
||||
use crate::control_plane::client::{ControlPlaneClient, TestControlPlaneClient};
|
||||
use crate::control_plane::messages::{ControlPlaneErrorMessage, Details, MetricsAuxInfo, Status};
|
||||
use crate::control_plane::{self, CachedNodeInfo, NodeInfo, NodeInfoCache};
|
||||
use crate::error::ErrorKind;
|
||||
use crate::proxy::connect_compute::ConnectMechanism;
|
||||
use crate::error::{ErrorKind, ReportableError};
|
||||
use crate::pglb::ERR_INSECURE_CONNECTION;
|
||||
use crate::pglb::handshake::{HandshakeData, handshake};
|
||||
use crate::pqproto::BeMessage;
|
||||
use crate::proxy::NeonOptions;
|
||||
use crate::proxy::connect_compute::{ConnectMechanism, connect_to_compute};
|
||||
use crate::proxy::retry::{ShouldRetryWakeCompute, retry_after};
|
||||
use crate::stream::{PqStream, Stream};
|
||||
use crate::tls::client_config::compute_client_config_with_certs;
|
||||
use crate::tls::server_config::CertResolver;
|
||||
use crate::types::{BranchId, EndpointId, ProjectId};
|
||||
use crate::{sasl, scram};
|
||||
use crate::{auth, compute, sasl, scram};
|
||||
|
||||
/// Generate a set of TLS certificates: CA + server.
|
||||
fn generate_certs(
|
||||
@@ -374,6 +380,7 @@ fn connect_compute_total_wait() {
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
enum ConnectAction {
|
||||
Wake,
|
||||
WakeCold,
|
||||
WakeFail,
|
||||
WakeRetry,
|
||||
Connect,
|
||||
@@ -504,6 +511,9 @@ impl TestControlPlaneClient for TestConnectMechanism {
|
||||
*counter += 1;
|
||||
match action {
|
||||
ConnectAction::Wake => Ok(helper_create_cached_node_info(self.cache)),
|
||||
ConnectAction::WakeCold => Ok(CachedNodeInfo::new_uncached(
|
||||
helper_create_uncached_node_info(),
|
||||
)),
|
||||
ConnectAction::WakeFail => {
|
||||
let err = control_plane::errors::ControlPlaneError::Message(Box::new(
|
||||
ControlPlaneErrorMessage {
|
||||
@@ -551,8 +561,8 @@ impl TestControlPlaneClient for TestConnectMechanism {
|
||||
}
|
||||
}
|
||||
|
||||
fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeInfo {
|
||||
let node = NodeInfo {
|
||||
fn helper_create_uncached_node_info() -> NodeInfo {
|
||||
NodeInfo {
|
||||
conn_info: compute::ConnectInfo {
|
||||
host: "test".into(),
|
||||
port: 5432,
|
||||
@@ -566,7 +576,11 @@ fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeIn
|
||||
compute_id: "compute".into(),
|
||||
cold_start_info: crate::control_plane::messages::ColdStartInfo::Warm,
|
||||
},
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeInfo {
|
||||
let node = helper_create_uncached_node_info();
|
||||
let (_, node2) = cache.insert_unit("key".into(), Ok(node.clone()));
|
||||
node2.map(|()| node)
|
||||
}
|
||||
@@ -742,7 +756,7 @@ async fn fail_no_wake_skips_cache_invalidation() {
|
||||
let ctx = RequestContext::test();
|
||||
let mech = TestConnectMechanism::new(vec![
|
||||
ConnectAction::Wake,
|
||||
ConnectAction::FailNoWake,
|
||||
ConnectAction::RetryNoWake,
|
||||
ConnectAction::Connect,
|
||||
]);
|
||||
let user = helper_create_connect_info(&mech);
|
||||
@@ -788,7 +802,7 @@ async fn retry_no_wake_skips_invalidation() {
|
||||
|
||||
let ctx = RequestContext::test();
|
||||
// Wake → RetryNoWake (retryable + NOT wakeable)
|
||||
let mechanism = TestConnectMechanism::new(vec![Wake, RetryNoWake]);
|
||||
let mechanism = TestConnectMechanism::new(vec![Wake, RetryNoWake, Fail]);
|
||||
let user_info = helper_create_connect_info(&mechanism);
|
||||
let cfg = config();
|
||||
|
||||
@@ -802,3 +816,44 @@ async fn retry_no_wake_skips_invalidation() {
|
||||
"invalidating stalled compute node info cache entry"
|
||||
));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[traced_test]
|
||||
async fn retry_no_wake_error_fast() {
|
||||
let _ = env_logger::try_init();
|
||||
use ConnectAction::*;
|
||||
|
||||
let ctx = RequestContext::test();
|
||||
// Wake → FailNoWake (not retryable + NOT wakeable)
|
||||
let mechanism = TestConnectMechanism::new(vec![Wake, FailNoWake]);
|
||||
let user_info = helper_create_connect_info(&mechanism);
|
||||
let cfg = config();
|
||||
|
||||
connect_to_compute(&ctx, &mechanism, &user_info, cfg.retry, &cfg)
|
||||
.await
|
||||
.unwrap_err();
|
||||
mechanism.verify();
|
||||
|
||||
// Because FailNoWake has wakeable=false, we must NOT see invalidate_cache
|
||||
assert!(!logs_contain(
|
||||
"invalidating stalled compute node info cache entry"
|
||||
));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[traced_test]
|
||||
async fn retry_cold_wake_skips_invalidation() {
|
||||
let _ = env_logger::try_init();
|
||||
use ConnectAction::*;
|
||||
|
||||
let ctx = RequestContext::test();
|
||||
// WakeCold → FailNoWake (not retryable + NOT wakeable)
|
||||
let mechanism = TestConnectMechanism::new(vec![WakeCold, Retry, Connect]);
|
||||
let user_info = helper_create_connect_info(&mechanism);
|
||||
let cfg = config();
|
||||
|
||||
connect_to_compute(&ctx, &mechanism, &user_info, cfg.retry, &cfg)
|
||||
.await
|
||||
.unwrap();
|
||||
mechanism.verify();
|
||||
}
|
||||
|
||||
@@ -139,12 +139,6 @@ impl RateBucketInfo {
|
||||
Self::new(200, Duration::from_secs(600)),
|
||||
];
|
||||
|
||||
// For all the sessions will be cancel key. So this limit is essentially global proxy limit.
|
||||
pub const DEFAULT_REDIS_SET: [Self; 2] = [
|
||||
Self::new(100_000, Duration::from_secs(1)),
|
||||
Self::new(50_000, Duration::from_secs(10)),
|
||||
];
|
||||
|
||||
pub fn rps(&self) -> f64 {
|
||||
(self.max_rpi as f64) / self.interval.as_secs_f64()
|
||||
}
|
||||
|
||||
@@ -5,11 +5,9 @@ use redis::aio::ConnectionLike;
|
||||
use redis::{Cmd, FromRedisValue, Pipeline, RedisResult};
|
||||
|
||||
use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
|
||||
use crate::rate_limiter::{GlobalRateLimiter, RateBucketInfo};
|
||||
|
||||
pub struct RedisKVClient {
|
||||
client: ConnectionWithCredentialsProvider,
|
||||
limiter: GlobalRateLimiter,
|
||||
}
|
||||
|
||||
#[allow(async_fn_in_trait)]
|
||||
@@ -30,11 +28,8 @@ impl Queryable for Cmd {
|
||||
}
|
||||
|
||||
impl RedisKVClient {
|
||||
pub fn new(client: ConnectionWithCredentialsProvider, info: &'static [RateBucketInfo]) -> Self {
|
||||
Self {
|
||||
client,
|
||||
limiter: GlobalRateLimiter::new(info.into()),
|
||||
}
|
||||
pub fn new(client: ConnectionWithCredentialsProvider) -> Self {
|
||||
Self { client }
|
||||
}
|
||||
|
||||
pub async fn try_connect(&mut self) -> anyhow::Result<()> {
|
||||
@@ -49,11 +44,6 @@ impl RedisKVClient {
|
||||
&mut self,
|
||||
q: &impl Queryable,
|
||||
) -> anyhow::Result<T> {
|
||||
if !self.limiter.check() {
|
||||
tracing::info!("Rate limit exceeded. Skipping query");
|
||||
return Err(anyhow::anyhow!("Rate limit exceeded"));
|
||||
}
|
||||
|
||||
let e = match q.query(&mut self.client).await {
|
||||
Ok(t) => return Ok(t),
|
||||
Err(e) => e,
|
||||
|
||||
@@ -141,29 +141,19 @@ where
|
||||
|
||||
struct MessageHandler<C: ProjectInfoCache + Send + Sync + 'static> {
|
||||
cache: Arc<C>,
|
||||
region_id: String,
|
||||
}
|
||||
|
||||
impl<C: ProjectInfoCache + Send + Sync + 'static> Clone for MessageHandler<C> {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
cache: self.cache.clone(),
|
||||
region_id: self.region_id.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
|
||||
pub(crate) fn new(cache: Arc<C>, region_id: String) -> Self {
|
||||
Self { cache, region_id }
|
||||
}
|
||||
|
||||
pub(crate) async fn increment_active_listeners(&self) {
|
||||
self.cache.increment_active_listeners().await;
|
||||
}
|
||||
|
||||
pub(crate) async fn decrement_active_listeners(&self) {
|
||||
self.cache.decrement_active_listeners().await;
|
||||
pub(crate) fn new(cache: Arc<C>) -> Self {
|
||||
Self { cache }
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self, msg), fields(session_id = tracing::field::Empty))]
|
||||
@@ -276,7 +266,7 @@ async fn handle_messages<C: ProjectInfoCache + Send + Sync + 'static>(
|
||||
}
|
||||
let mut conn = match try_connect(&redis).await {
|
||||
Ok(conn) => {
|
||||
handler.increment_active_listeners().await;
|
||||
handler.cache.increment_active_listeners().await;
|
||||
conn
|
||||
}
|
||||
Err(e) => {
|
||||
@@ -297,11 +287,11 @@ async fn handle_messages<C: ProjectInfoCache + Send + Sync + 'static>(
|
||||
}
|
||||
}
|
||||
if cancellation_token.is_cancelled() {
|
||||
handler.decrement_active_listeners().await;
|
||||
handler.cache.decrement_active_listeners().await;
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
handler.decrement_active_listeners().await;
|
||||
handler.cache.decrement_active_listeners().await;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -310,12 +300,11 @@ async fn handle_messages<C: ProjectInfoCache + Send + Sync + 'static>(
|
||||
pub async fn task_main<C>(
|
||||
redis: ConnectionWithCredentialsProvider,
|
||||
cache: Arc<C>,
|
||||
region_id: String,
|
||||
) -> anyhow::Result<Infallible>
|
||||
where
|
||||
C: ProjectInfoCache + Send + Sync + 'static,
|
||||
{
|
||||
let handler = MessageHandler::new(cache, region_id);
|
||||
let handler = MessageHandler::new(cache);
|
||||
// 6h - 1m.
|
||||
// There will be 1 minute overlap between two tasks. But at least we can be sure that no message is lost.
|
||||
let mut interval = tokio::time::interval(std::time::Duration::from_secs(6 * 60 * 60 - 60));
|
||||
|
||||
@@ -418,12 +418,7 @@ async fn request_handler(
|
||||
if config.http_config.accept_websockets
|
||||
&& framed_websockets::upgrade::is_upgrade_request(&request)
|
||||
{
|
||||
let ctx = RequestContext::new(
|
||||
session_id,
|
||||
conn_info,
|
||||
crate::metrics::Protocol::Ws,
|
||||
&config.region,
|
||||
);
|
||||
let ctx = RequestContext::new(session_id, conn_info, crate::metrics::Protocol::Ws);
|
||||
|
||||
ctx.set_user_agent(
|
||||
request
|
||||
@@ -463,12 +458,7 @@ async fn request_handler(
|
||||
// Return the response so the spawned future can continue.
|
||||
Ok(response.map(|b| b.map_err(|x| match x {}).boxed()))
|
||||
} else if request.uri().path() == "/sql" && *request.method() == Method::POST {
|
||||
let ctx = RequestContext::new(
|
||||
session_id,
|
||||
conn_info,
|
||||
crate::metrics::Protocol::Http,
|
||||
&config.region,
|
||||
);
|
||||
let ctx = RequestContext::new(session_id, conn_info, crate::metrics::Protocol::Http);
|
||||
let span = ctx.span();
|
||||
|
||||
let testodrome_id = request
|
||||
|
||||
@@ -17,7 +17,8 @@ use crate::config::ProxyConfig;
|
||||
use crate::context::RequestContext;
|
||||
use crate::error::ReportableError;
|
||||
use crate::metrics::Metrics;
|
||||
use crate::proxy::{ClientMode, ErrorSource, handle_client};
|
||||
use crate::pglb::{ClientMode, handle_connection};
|
||||
use crate::proxy::ErrorSource;
|
||||
use crate::rate_limiter::EndpointRateLimiter;
|
||||
|
||||
pin_project! {
|
||||
@@ -142,7 +143,7 @@ pub(crate) async fn serve_websocket(
|
||||
.client_connections
|
||||
.guard(crate::metrics::Protocol::Ws);
|
||||
|
||||
let res = Box::pin(handle_client(
|
||||
let res = Box::pin(handle_connection(
|
||||
config,
|
||||
auth_backend,
|
||||
&ctx,
|
||||
|
||||
Reference in New Issue
Block a user