Compare commits

...

1 Commits

Author SHA1 Message Date
Ivan Efremov
0acc612e3e impr(proxy): introduce proxy_id for cancel key 2025-01-20 18:43:16 +02:00
9 changed files with 71 additions and 8 deletions

View File

@@ -14,7 +14,7 @@ use proxy::auth::backend::local::{LocalBackend, JWKS_ROLE_MAP};
use proxy::auth::{self}; use proxy::auth::{self};
use proxy::cancellation::CancellationHandlerMain; use proxy::cancellation::CancellationHandlerMain;
use proxy::config::{ use proxy::config::{
self, AuthenticationConfig, ComputeConfig, HttpConfig, ProxyConfig, RetryConfig, self, obfuscated_proxy_id, AuthenticationConfig, ComputeConfig, HttpConfig, ProxyConfig, RetryConfig
}; };
use proxy::control_plane::locks::ApiLocks; use proxy::control_plane::locks::ApiLocks;
use proxy::control_plane::messages::{EndpointJwksResponse, JwksSettings}; use proxy::control_plane::messages::{EndpointJwksResponse, JwksSettings};
@@ -218,6 +218,7 @@ async fn main() -> anyhow::Result<()> {
proxy::metrics::CancellationSource::Local, proxy::metrics::CancellationSource::Local,
)), )),
endpoint_rate_limiter, endpoint_rate_limiter,
obfuscated_proxy_id(std::process::id(), "local"),
); );
match futures::future::select(pin!(maintenance_tasks.join_next()), pin!(task)).await { match futures::future::select(pin!(maintenance_tasks.join_next()), pin!(task)).await {

View File

@@ -10,6 +10,7 @@ use clap::Arg;
use futures::future::Either; use futures::future::Either;
use futures::TryFutureExt; use futures::TryFutureExt;
use itertools::Itertools; use itertools::Itertools;
use proxy::config::obfuscated_proxy_id;
use proxy::context::RequestContext; use proxy::context::RequestContext;
use proxy::metrics::{Metrics, ThreadPoolMetrics}; use proxy::metrics::{Metrics, ThreadPoolMetrics};
use proxy::protocol2::ConnectionInfo; use proxy::protocol2::ConnectionInfo;
@@ -185,6 +186,7 @@ async fn task_main(
}, },
proxy::metrics::Protocol::SniRouter, proxy::metrics::Protocol::SniRouter,
"sni", "sni",
obfuscated_proxy_id(std::process::id(), "sni-router"), // just a shim for context
); );
handle_client(ctx, dest_suffix, tls_config, tls_server_end_point, socket).await handle_client(ctx, dest_suffix, tls_config, tls_server_end_point, socket).await
} }

View File

@@ -10,7 +10,7 @@ use proxy::auth::backend::{AuthRateLimiter, ConsoleRedirectBackend, MaybeOwned};
use proxy::cancellation::{CancelMap, CancellationHandler}; use proxy::cancellation::{CancelMap, CancellationHandler};
use proxy::config::{ use proxy::config::{
self, remote_storage_from_toml, AuthenticationConfig, CacheOptions, ComputeConfig, HttpConfig, self, remote_storage_from_toml, AuthenticationConfig, CacheOptions, ComputeConfig, HttpConfig,
ProjectInfoCacheOptions, ProxyConfig, ProxyProtocolV2, ProjectInfoCacheOptions, ProxyConfig, ProxyProtocolV2, obfuscated_proxy_id,
}; };
use proxy::context::parquet::ParquetUploadArgs; use proxy::context::parquet::ParquetUploadArgs;
use proxy::http::health_server::AppMetrics; use proxy::http::health_server::AppMetrics;
@@ -396,6 +396,8 @@ async fn main() -> anyhow::Result<()> {
None => None, None => None,
}; };
let proxy_id: u16 = obfuscated_proxy_id(std::process::id(), &args.region);
let cancellation_handler = Arc::new(CancellationHandler::< let cancellation_handler = Arc::new(CancellationHandler::<
Option<Arc<Mutex<RedisPublisherClient>>>, Option<Arc<Mutex<RedisPublisherClient>>>,
>::new( >::new(
@@ -437,6 +439,7 @@ async fn main() -> anyhow::Result<()> {
cancellation_token.clone(), cancellation_token.clone(),
cancellation_handler.clone(), cancellation_handler.clone(),
endpoint_rate_limiter.clone(), endpoint_rate_limiter.clone(),
proxy_id,
)); ));
} }
@@ -448,6 +451,7 @@ async fn main() -> anyhow::Result<()> {
cancellation_token.clone(), cancellation_token.clone(),
cancellation_handler.clone(), cancellation_handler.clone(),
endpoint_rate_limiter.clone(), endpoint_rate_limiter.clone(),
proxy_id,
)); ));
} }
} }
@@ -459,6 +463,7 @@ async fn main() -> anyhow::Result<()> {
proxy_listener, proxy_listener,
cancellation_token.clone(), cancellation_token.clone(),
cancellation_handler.clone(), cancellation_handler.clone(),
proxy_id,
)); ));
} }
} }

View File

@@ -80,15 +80,24 @@ impl ReportableError for CancelError {
impl<P: CancellationPublisher> CancellationHandler<P> { impl<P: CancellationPublisher> CancellationHandler<P> {
/// Run async action within an ephemeral session identified by [`CancelKeyData`]. /// Run async action within an ephemeral session identified by [`CancelKeyData`].
pub(crate) fn get_session(self: Arc<Self>) -> Session<P> { pub(crate) fn get_session(self: Arc<Self>, proxy_id: u16) -> Session<P> {
// we intentionally generate a random "backend pid" and "secret key" here. // we intentionally generate a random "backend pid" and "secret key" here.
// we use the corresponding u64 as an identifier for the // we use the corresponding u64 as an identifier for the
// actual endpoint+pid+secret for postgres/pgbouncer. // actual endpoint+pid+secret for postgres/pgbouncer.
// //
// if we forwarded the backend_pid from postgres to the client, there would be a lot // if we forwarded the backend_pid from postgres to the client, there would be a lot
// of overlap between our computes as most pids are small (~100). // of overlap between our computes as most pids are small (~100).
let key = loop { let key = loop {
let key = rand::random(); let key_rand: u64 = rand::random::<u64>() & 0x0000_ffff_ffff_ffff;
let backend_pid = ((proxy_id as u32) << 16) | ((key_rand >> 32) as u32) as u32;
let cancel_key = (key_rand as u32) as i32;
let key = CancelKeyData {
backend_pid: (backend_pid as i32),
cancel_key,
};
// Random key collisions are unlikely to happen here, but they're still possible, // Random key collisions are unlikely to happen here, but they're still possible,
// which is why we have to take care not to rewrite an existing key. // which is why we have to take care not to rewrite an existing key.
@@ -451,7 +460,7 @@ mod tests {
CancellationSource::FromRedis, CancellationSource::FromRedis,
)); ));
let session = cancellation_handler.clone().get_session(); let session = cancellation_handler.clone().get_session(123);
assert!(cancellation_handler.contains(&session)); assert!(cancellation_handler.contains(&session));
drop(session); drop(session);
// Check that the session has been dropped. // Check that the session has been dropped.

View File

@@ -16,6 +16,8 @@ use crate::serverless::GlobalConnPoolOptions;
pub use crate::tls::server_config::{configure_tls, TlsConfig}; pub use crate::tls::server_config::{configure_tls, TlsConfig};
use crate::types::Host; use crate::types::Host;
use sha2::{Digest, Sha256};
pub struct ProxyConfig { pub struct ProxyConfig {
pub tls_config: Option<TlsConfig>, pub tls_config: Option<TlsConfig>,
pub metric_collection: Option<MetricCollectionConfig>, pub metric_collection: Option<MetricCollectionConfig>,
@@ -416,6 +418,23 @@ impl FromStr for ConcurrencyLockOptions {
} }
} }
/// Fold an arbitrary `u32` (a process id) down to a `u8` in `0..=254`.
///
/// Uses wrapping arithmetic: `value * 31` would overflow `u32` for large
/// pids (panic in debug builds, silent wrap in release), making the result
/// profile-dependent. Wrapping makes it deterministic everywhere and is
/// identical to the plain arithmetic whenever no overflow occurs.
fn map_u32_to_u8(value: u32) -> u8 {
    (value.wrapping_mul(31).wrapping_add(17) % 255) as u8
}
/// Derive a stable 16-bit proxy identifier from the process id and region,
/// used as the high bits of generated cancel keys so they rarely collide
/// across proxy instances.
///
/// The region contribution is a base-257 polynomial hash of the SHA-256
/// digest bytes, evaluated mod 255. NOTE: the naive `BASE.pow(i)` form
/// overflows `u64` for `i >= 8` (257^8 > u64::MAX), panicking in debug
/// builds and wrapping in release; we therefore reduce every intermediate
/// mod 255 via Horner's method, which computes the intended value exactly
/// and never overflows.
///
/// Returns a value in `0..=65532` (254 * 257 + 254), so no final reduction
/// is required. Collisions are still possible (both components are mod 255).
pub fn obfuscated_proxy_id(process_id: u32, region_id: &str) -> u16 {
    let process_id = map_u32_to_u8(process_id);
    let hash_region_id = Sha256::digest(region_id.as_bytes());
    const BASE: u64 = 257;
    // Horner evaluation of sum(byte[i] * BASE^i) mod 255: iterate from the
    // highest index down, keeping the accumulator reduced each step.
    let combined_region = hash_region_id
        .iter()
        .rev()
        .fold(0u64, |acc, &byte| (acc * BASE + byte as u64) % 255) as u8;
    (combined_region as u16) * 257 + (process_id as u16)
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
@@ -511,4 +530,12 @@ mod tests {
Ok(()) Ok(())
} }
#[test]
fn test_proxy_id_obfuscation() {
    // The previous version pinned a magic constant produced by arithmetic
    // that overflows u64 (257^i for i >= 8): it panics in debug builds and
    // wraps in release, so the constant was profile-dependent. Assert
    // profile-independent properties instead.
    let a = obfuscated_proxy_id(123, "us-west-2");
    // deterministic: pure function of (pid, region)
    assert_eq!(a, obfuscated_proxy_id(123, "us-west-2"));
    // sensitive to the process id (same region, different pid component)
    assert_ne!(a, obfuscated_proxy_id(124, "us-west-2"));
    // both mod-255 components cap the result at 254 * 257 + 254
    assert!(a <= 65532);
}
} }

View File

@@ -25,6 +25,7 @@ pub async fn task_main(
listener: tokio::net::TcpListener, listener: tokio::net::TcpListener,
cancellation_token: CancellationToken, cancellation_token: CancellationToken,
cancellation_handler: Arc<CancellationHandlerMain>, cancellation_handler: Arc<CancellationHandlerMain>,
proxy_id: u16,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
scopeguard::defer! { scopeguard::defer! {
info!("proxy has shut down"); info!("proxy has shut down");
@@ -89,6 +90,7 @@ pub async fn task_main(
peer_addr, peer_addr,
crate::metrics::Protocol::Tcp, crate::metrics::Protocol::Tcp,
&config.region, &config.region,
proxy_id,
); );
let res = handle_client( let res = handle_client(
@@ -222,7 +224,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
node.cancel_closure node.cancel_closure
.set_ip_allowlist(ip_allowlist.unwrap_or_default()); .set_ip_allowlist(ip_allowlist.unwrap_or_default());
let session = cancellation_handler.get_session(); let session = cancellation_handler.get_session(ctx.proxy_id());
prepare_client_connection(&node, &session, &mut stream).await?; prepare_client_connection(&node, &session, &mut stream).await?;
// Before proxy passing, forward to compute whatever data is left in the // Before proxy passing, forward to compute whatever data is left in the

View File

@@ -46,6 +46,7 @@ struct RequestContextInner {
pub(crate) protocol: Protocol, pub(crate) protocol: Protocol,
first_packet: chrono::DateTime<Utc>, first_packet: chrono::DateTime<Utc>,
region: &'static str, region: &'static str,
pub(crate) proxy_id: u16, // for generating cancel keys per region/process
pub(crate) span: Span, pub(crate) span: Span,
// filled in as they are discovered // filled in as they are discovered
@@ -92,6 +93,7 @@ impl Clone for RequestContext {
protocol: inner.protocol, protocol: inner.protocol,
first_packet: inner.first_packet, first_packet: inner.first_packet,
region: inner.region, region: inner.region,
proxy_id: inner.proxy_id,
span: info_span!("background_task"), span: info_span!("background_task"),
project: inner.project, project: inner.project,
@@ -124,6 +126,7 @@ impl RequestContext {
conn_info: ConnectionInfo, conn_info: ConnectionInfo,
protocol: Protocol, protocol: Protocol,
region: &'static str, region: &'static str,
proxy_id: u16,
) -> Self { ) -> Self {
// TODO: be careful with long lived spans // TODO: be careful with long lived spans
let span = info_span!( let span = info_span!(
@@ -141,6 +144,7 @@ impl RequestContext {
protocol, protocol,
first_packet: Utc::now(), first_packet: Utc::now(),
region, region,
proxy_id,
span, span,
project: None, project: None,
@@ -172,7 +176,7 @@ impl RequestContext {
let ip = IpAddr::from([127, 0, 0, 1]); let ip = IpAddr::from([127, 0, 0, 1]);
let addr = SocketAddr::new(ip, 5432); let addr = SocketAddr::new(ip, 5432);
let conn_info = ConnectionInfo { addr, extra: None }; let conn_info = ConnectionInfo { addr, extra: None };
RequestContext::new(Uuid::now_v7(), conn_info, Protocol::Tcp, "test") RequestContext::new(Uuid::now_v7(), conn_info, Protocol::Tcp, "test", 1)
} }
pub(crate) fn console_application_name(&self) -> String { pub(crate) fn console_application_name(&self) -> String {
@@ -334,6 +338,10 @@ impl RequestContext {
.latency_timer .latency_timer
.success(); .success();
} }
pub(crate) fn proxy_id(&self) -> u16 {
self.0.try_lock().expect("should not deadlock").proxy_id
}
} }
pub(crate) struct LatencyTimerPause<'a> { pub(crate) struct LatencyTimerPause<'a> {

View File

@@ -59,6 +59,7 @@ pub async fn task_main(
cancellation_token: CancellationToken, cancellation_token: CancellationToken,
cancellation_handler: Arc<CancellationHandlerMain>, cancellation_handler: Arc<CancellationHandlerMain>,
endpoint_rate_limiter: Arc<EndpointRateLimiter>, endpoint_rate_limiter: Arc<EndpointRateLimiter>,
proxy_id: u16,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
scopeguard::defer! { scopeguard::defer! {
info!("proxy has shut down"); info!("proxy has shut down");
@@ -124,6 +125,7 @@ pub async fn task_main(
conn_info, conn_info,
crate::metrics::Protocol::Tcp, crate::metrics::Protocol::Tcp,
&config.region, &config.region,
proxy_id,
); );
let res = handle_client( let res = handle_client(
@@ -358,7 +360,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
node.cancel_closure node.cancel_closure
.set_ip_allowlist(ip_allowlist.unwrap_or_default()); .set_ip_allowlist(ip_allowlist.unwrap_or_default());
let session = cancellation_handler.get_session(); let session = cancellation_handler.get_session(ctx.proxy_id());
prepare_client_connection(&node, &session, &mut stream).await?; prepare_client_connection(&node, &session, &mut stream).await?;
// Before proxy passing, forward to compute whatever data is left in the // Before proxy passing, forward to compute whatever data is left in the

View File

@@ -63,6 +63,7 @@ pub async fn task_main(
cancellation_token: CancellationToken, cancellation_token: CancellationToken,
cancellation_handler: Arc<CancellationHandlerMain>, cancellation_handler: Arc<CancellationHandlerMain>,
endpoint_rate_limiter: Arc<EndpointRateLimiter>, endpoint_rate_limiter: Arc<EndpointRateLimiter>,
proxy_id: u16,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
scopeguard::defer! { scopeguard::defer! {
info!("websocket server has shut down"); info!("websocket server has shut down");
@@ -198,6 +199,7 @@ pub async fn task_main(
conn, conn,
conn_info, conn_info,
session_id, session_id,
proxy_id,
)) ))
.await; .await;
} }
@@ -324,6 +326,7 @@ async fn connection_handler(
conn: AsyncRW, conn: AsyncRW,
conn_info: ConnectionInfo, conn_info: ConnectionInfo,
session_id: uuid::Uuid, session_id: uuid::Uuid,
proxy_id: u16,
) { ) {
let session_id = AtomicTake::new(session_id); let session_id = AtomicTake::new(session_id);
@@ -371,6 +374,7 @@ async fn connection_handler(
http_request_token, http_request_token,
endpoint_rate_limiter.clone(), endpoint_rate_limiter.clone(),
cancellations, cancellations,
proxy_id,
) )
.in_current_span() .in_current_span()
.map_ok_or_else(api_error_into_response, |r| r), .map_ok_or_else(api_error_into_response, |r| r),
@@ -419,6 +423,7 @@ async fn request_handler(
http_cancellation_token: CancellationToken, http_cancellation_token: CancellationToken,
endpoint_rate_limiter: Arc<EndpointRateLimiter>, endpoint_rate_limiter: Arc<EndpointRateLimiter>,
cancellations: TaskTracker, cancellations: TaskTracker,
proxy_id: u16,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, ApiError> { ) -> Result<Response<BoxBody<Bytes, hyper::Error>>, ApiError> {
let host = request let host = request
.headers() .headers()
@@ -436,6 +441,7 @@ async fn request_handler(
conn_info, conn_info,
crate::metrics::Protocol::Ws, crate::metrics::Protocol::Ws,
&config.region, &config.region,
proxy_id,
); );
let span = ctx.span(); let span = ctx.span();
@@ -473,6 +479,7 @@ async fn request_handler(
conn_info, conn_info,
crate::metrics::Protocol::Http, crate::metrics::Protocol::Http,
&config.region, &config.region,
proxy_id,
); );
let span = ctx.span(); let span = ctx.span();