Compare commits

...

1 Commit

Author        SHA1        Message                                          Date
Ivan Efremov  0acc612e3e  impr(proxy): introduce proxy_id for cancel key   2025-01-20 18:43:16 +02:00
9 changed files with 71 additions and 8 deletions

View File

@@ -14,7 +14,7 @@ use proxy::auth::backend::local::{LocalBackend, JWKS_ROLE_MAP};
use proxy::auth::{self};
use proxy::cancellation::CancellationHandlerMain;
use proxy::config::{
- self, AuthenticationConfig, ComputeConfig, HttpConfig, ProxyConfig, RetryConfig,
+ self, obfuscated_proxy_id, AuthenticationConfig, ComputeConfig, HttpConfig, ProxyConfig, RetryConfig,
};
use proxy::control_plane::locks::ApiLocks;
use proxy::control_plane::messages::{EndpointJwksResponse, JwksSettings};
@@ -218,6 +218,7 @@ async fn main() -> anyhow::Result<()> {
proxy::metrics::CancellationSource::Local,
)),
endpoint_rate_limiter,
+ obfuscated_proxy_id(std::process::id(), "local"),
);
match futures::future::select(pin!(maintenance_tasks.join_next()), pin!(task)).await {

View File

@@ -10,6 +10,7 @@ use clap::Arg;
use futures::future::Either;
use futures::TryFutureExt;
use itertools::Itertools;
+ use proxy::config::obfuscated_proxy_id;
use proxy::context::RequestContext;
use proxy::metrics::{Metrics, ThreadPoolMetrics};
use proxy::protocol2::ConnectionInfo;
@@ -185,6 +186,7 @@ async fn task_main(
},
proxy::metrics::Protocol::SniRouter,
"sni",
+ obfuscated_proxy_id(std::process::id(), "sni-router"), // just a shim for context
);
handle_client(ctx, dest_suffix, tls_config, tls_server_end_point, socket).await
}

View File

@@ -10,7 +10,7 @@ use proxy::auth::backend::{AuthRateLimiter, ConsoleRedirectBackend, MaybeOwned};
use proxy::cancellation::{CancelMap, CancellationHandler};
use proxy::config::{
self, remote_storage_from_toml, AuthenticationConfig, CacheOptions, ComputeConfig, HttpConfig,
- ProjectInfoCacheOptions, ProxyConfig, ProxyProtocolV2,
+ ProjectInfoCacheOptions, ProxyConfig, ProxyProtocolV2, obfuscated_proxy_id,
};
use proxy::context::parquet::ParquetUploadArgs;
use proxy::http::health_server::AppMetrics;
@@ -396,6 +396,8 @@ async fn main() -> anyhow::Result<()> {
None => None,
};
+ let proxy_id: u16 = obfuscated_proxy_id(std::process::id(), &args.region);
let cancellation_handler = Arc::new(CancellationHandler::<
Option<Arc<Mutex<RedisPublisherClient>>>,
>::new(
@@ -437,6 +439,7 @@ async fn main() -> anyhow::Result<()> {
cancellation_token.clone(),
cancellation_handler.clone(),
endpoint_rate_limiter.clone(),
+ proxy_id,
));
}
@@ -448,6 +451,7 @@ async fn main() -> anyhow::Result<()> {
cancellation_token.clone(),
cancellation_handler.clone(),
endpoint_rate_limiter.clone(),
+ proxy_id,
));
}
}
@@ -459,6 +463,7 @@ async fn main() -> anyhow::Result<()> {
proxy_listener,
cancellation_token.clone(),
cancellation_handler.clone(),
+ proxy_id,
));
}
}

View File

@@ -80,15 +80,24 @@ impl ReportableError for CancelError {
impl<P: CancellationPublisher> CancellationHandler<P> {
/// Run async action within an ephemeral session identified by [`CancelKeyData`].
- pub(crate) fn get_session(self: Arc<Self>) -> Session<P> {
+ pub(crate) fn get_session(self: Arc<Self>, proxy_id: u16) -> Session<P> {
// we intentionally generate a random "backend pid" and "secret key" here.
// we use the corresponding u64 as an identifier for the
// actual endpoint+pid+secret for postgres/pgbouncer.
//
// if we forwarded the backend_pid from postgres to the client, there would be a lot
// of overlap between our computes as most pids are small (~100).
let key = loop {
- let key = rand::random();
+ // 48 random bits; the top 16 bits of the key are taken by proxy_id.
+ let key_rand: u64 = rand::random::<u64>() & 0x0000_ffff_ffff_ffff;
+ // backend_pid = [proxy_id : 16][random : 16]; cancel_key = low 32 random bits.
+ let backend_pid = ((proxy_id as u32) << 16) | ((key_rand >> 32) as u32);
+ let cancel_key = (key_rand as u32) as i32;
+ let key = CancelKeyData {
+     backend_pid: backend_pid as i32,
+     cancel_key,
+ };
// Random key collisions are unlikely to happen here, but they're still possible,
// which is why we have to take care not to rewrite an existing key.
@@ -451,7 +460,7 @@ mod tests {
CancellationSource::FromRedis,
));
- let session = cancellation_handler.clone().get_session();
+ let session = cancellation_handler.clone().get_session(123);
assert!(cancellation_handler.contains(&session));
drop(session);
// Check that the session has been dropped.
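
Taken together, get_session now derives the wire-visible CancelKeyData from a 16-bit proxy_id plus 48 random bits. A minimal standalone sketch of that packing (hypothetical pack_cancel_key helper that mirrors the bit layout in the hunk above; it is not the actual proxy::cancellation API):

fn pack_cancel_key(proxy_id: u16, key_rand: u64) -> (i32, i32) {
    // Keep only the low 48 bits of the random value; the top 16 bits
    // of the combined key are reserved for proxy_id.
    let key_rand = key_rand & 0x0000_ffff_ffff_ffff;
    // backend_pid = [proxy_id : 16][random bits 47..32 : 16]
    let backend_pid = ((proxy_id as u32) << 16) | (key_rand >> 32) as u32;
    // cancel_key = random bits 31..0
    (backend_pid as i32, key_rand as u32 as i32)
}

fn main() {
    let (pid, key) = pack_cancel_key(0x00aa, 0x0000_bbbb_1234_5678);
    assert_eq!(pid as u32, 0x00aa_bbbb); // proxy_id occupies the top half
    assert_eq!(key as u32, 0x1234_5678);
    println!("backend_pid={pid:#010x} cancel_key={key:#010x}");
}

As long as two proxy processes end up with different obfuscated ids, their cancel-key spaces are disjoint, so the collision check in the loop only has to guard against keys issued by the same process.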

View File

@@ -16,6 +16,8 @@ use crate::serverless::GlobalConnPoolOptions;
pub use crate::tls::server_config::{configure_tls, TlsConfig};
use crate::types::Host;
+ use sha2::{Digest, Sha256};
pub struct ProxyConfig {
pub tls_config: Option<TlsConfig>,
pub metric_collection: Option<MetricCollectionConfig>,
@@ -416,6 +418,23 @@ impl FromStr for ConcurrencyLockOptions {
}
}
+ fn map_u32_to_u8(value: u32) -> u8 {
+     // Widen to u64 so large pids cannot overflow the multiplication.
+     ((value as u64 * 31 + 17) % 255) as u8
+ }
+
+ /// Derives a stable u16 id for this proxy process from its pid and region;
+ /// used as the top 16 bits of generated cancel keys.
+ pub fn obfuscated_proxy_id(process_id: u32, region_id: &str) -> u16 {
+     let process_id = map_u32_to_u8(process_id);
+     let hash_region_id = Sha256::digest(region_id.as_bytes());
+     // Polynomial hash of the digest, folded mod 255. The weight is reduced
+     // mod 255 at every step so the intermediate never overflows u64
+     // (a bare BASE.pow(i) would overflow once i >= 8).
+     const BASE: u64 = 257;
+     let combined_region = hash_region_id
+         .iter()
+         .fold((0u64, 1u64), |(acc, weight), &byte| {
+             ((acc + byte as u64 * weight) % 255, (weight * BASE) % 255)
+         })
+         .0 as u8;
+     ((combined_region as u16) * 257 + (process_id as u16)) % 65535
+ }
#[cfg(test)]
mod tests {
use super::*;
@@ -511,4 +530,12 @@ mod tests {
Ok(())
}
+ #[test]
+ fn test_proxy_id_obfuscation() {
+     let proxy_id = obfuscated_proxy_id(123, "us-west-2");
+     // Stable for a given (process, region) pair.
+     assert_eq!(proxy_id, obfuscated_proxy_id(123, "us-west-2"));
+     // Distinct process ids map to distinct low bytes, so the ids differ.
+     assert_ne!(proxy_id, obfuscated_proxy_id(124, "us-west-2"));
+ }
}
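
The packing is easy to sanity-check: both bytes land in [0, 254], so region_byte * 257 + process_byte is at most 254 * 257 + 254 = 65532, and the trailing % 65535 never actually wraps. (Side note on the hash: since 257 % 255 == 2 and 2^8 % 255 == 1, the digest weights effectively cycle through the powers of two mod 255.) A small worked sketch of the arithmetic, using an illustrative region byte rather than a real digest fold:

fn main() {
    // Process byte for pid 123: (123 * 31 + 17) % 255 == 3830 % 255 == 5.
    let pid_byte = ((123u64 * 31 + 17) % 255) as u8;
    assert_eq!(pid_byte, 5);
    // Suppose the region digest folds to 0x2a (illustrative, not computed).
    let region_byte: u8 = 0x2a;
    // Base-257 packing keeps both bytes recoverable: id / 257 and id % 257.
    let proxy_id = (region_byte as u16) * 257 + pid_byte as u16;
    assert_eq!(proxy_id / 257, 0x2a);
    assert_eq!(proxy_id % 257, 5);
    println!("proxy_id = {proxy_id:#06x}");
}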

View File

@@ -25,6 +25,7 @@ pub async fn task_main(
listener: tokio::net::TcpListener,
cancellation_token: CancellationToken,
cancellation_handler: Arc<CancellationHandlerMain>,
+ proxy_id: u16,
) -> anyhow::Result<()> {
scopeguard::defer! {
info!("proxy has shut down");
@@ -89,6 +90,7 @@ pub async fn task_main(
peer_addr,
crate::metrics::Protocol::Tcp,
&config.region,
+ proxy_id,
);
let res = handle_client(
@@ -222,7 +224,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
node.cancel_closure
.set_ip_allowlist(ip_allowlist.unwrap_or_default());
- let session = cancellation_handler.get_session();
+ let session = cancellation_handler.get_session(ctx.proxy_id());
prepare_client_connection(&node, &session, &mut stream).await?;
// Before proxy passing, forward to compute whatever data is left in the

View File

@@ -46,6 +46,7 @@ struct RequestContextInner {
pub(crate) protocol: Protocol,
first_packet: chrono::DateTime<Utc>,
region: &'static str,
+ pub(crate) proxy_id: u16, // for generating cancel keys per region/process
pub(crate) span: Span,
// filled in as they are discovered
@@ -92,6 +93,7 @@ impl Clone for RequestContext {
protocol: inner.protocol,
first_packet: inner.first_packet,
region: inner.region,
+ proxy_id: inner.proxy_id,
span: info_span!("background_task"),
project: inner.project,
@@ -124,6 +126,7 @@ impl RequestContext {
conn_info: ConnectionInfo,
protocol: Protocol,
region: &'static str,
+ proxy_id: u16,
) -> Self {
// TODO: be careful with long lived spans
let span = info_span!(
@@ -141,6 +144,7 @@ impl RequestContext {
protocol,
first_packet: Utc::now(),
region,
+ proxy_id,
span,
project: None,
@@ -172,7 +176,7 @@ impl RequestContext {
let ip = IpAddr::from([127, 0, 0, 1]);
let addr = SocketAddr::new(ip, 5432);
let conn_info = ConnectionInfo { addr, extra: None };
- RequestContext::new(Uuid::now_v7(), conn_info, Protocol::Tcp, "test")
+ RequestContext::new(Uuid::now_v7(), conn_info, Protocol::Tcp, "test", 1)
}
pub(crate) fn console_application_name(&self) -> String {
@@ -334,6 +338,10 @@ impl RequestContext {
.latency_timer
.success();
}
+ pub(crate) fn proxy_id(&self) -> u16 {
+     self.0.try_lock().expect("should not deadlock").proxy_id
+ }
}
pub(crate) struct LatencyTimerPause<'a> {

View File

@@ -59,6 +59,7 @@ pub async fn task_main(
cancellation_token: CancellationToken,
cancellation_handler: Arc<CancellationHandlerMain>,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
+ proxy_id: u16,
) -> anyhow::Result<()> {
scopeguard::defer! {
info!("proxy has shut down");
@@ -124,6 +125,7 @@ pub async fn task_main(
conn_info,
crate::metrics::Protocol::Tcp,
&config.region,
+ proxy_id,
);
let res = handle_client(
@@ -358,7 +360,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
node.cancel_closure
.set_ip_allowlist(ip_allowlist.unwrap_or_default());
- let session = cancellation_handler.get_session();
+ let session = cancellation_handler.get_session(ctx.proxy_id());
prepare_client_connection(&node, &session, &mut stream).await?;
// Before proxy passing, forward to compute whatever data is left in the

View File

@@ -63,6 +63,7 @@ pub async fn task_main(
cancellation_token: CancellationToken,
cancellation_handler: Arc<CancellationHandlerMain>,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
+ proxy_id: u16,
) -> anyhow::Result<()> {
scopeguard::defer! {
info!("websocket server has shut down");
@@ -198,6 +199,7 @@ pub async fn task_main(
conn,
conn_info,
session_id,
+ proxy_id,
))
.await;
}
@@ -324,6 +326,7 @@ async fn connection_handler(
conn: AsyncRW,
conn_info: ConnectionInfo,
session_id: uuid::Uuid,
+ proxy_id: u16,
) {
let session_id = AtomicTake::new(session_id);
@@ -371,6 +374,7 @@ async fn connection_handler(
http_request_token,
endpoint_rate_limiter.clone(),
cancellations,
+ proxy_id,
)
.in_current_span()
.map_ok_or_else(api_error_into_response, |r| r),
@@ -419,6 +423,7 @@ async fn request_handler(
http_cancellation_token: CancellationToken,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
cancellations: TaskTracker,
+ proxy_id: u16,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, ApiError> {
let host = request
.headers()
@@ -436,6 +441,7 @@ async fn request_handler(
conn_info,
crate::metrics::Protocol::Ws,
&config.region,
+ proxy_id,
);
let span = ctx.span();
@@ -473,6 +479,7 @@ async fn request_handler(
conn_info,
crate::metrics::Protocol::Http,
&config.region,
+ proxy_id,
);
let span = ctx.span();