proxy: random changes (#8602)

## Problem

1. Hard to correlate startup parameters with the endpoint that provided
them.
2. Some configurations are not needed in the `ProxyConfig` struct.

## Summary of changes

Because of some borrow checker fun, I needed to switch to an
interior-mutability implementation of our `RequestMonitoring` context
system. Using https://docs.rs/try-lock/latest/try_lock/ as a cheap lock
for such a use-case (needed to be thread safe).

Removed the lock of each startup message, instead just logging only the
startup params in a successful handshake.

Also removed from values from `ProxyConfig` and kept as arguments.
(needed for local-proxy config)
This commit is contained in:
Conrad Ludgate
2024-08-07 14:37:03 +01:00
committed by GitHub
parent 4d7c0dac93
commit ad0988f278
31 changed files with 386 additions and 276 deletions

View File

@@ -218,7 +218,7 @@ impl RateBucketInfo {
impl AuthenticationConfig {
pub fn check_rate_limit(
&self,
ctx: &mut RequestMonitoring,
ctx: &RequestMonitoring,
config: &AuthenticationConfig,
secret: AuthSecret,
endpoint: &EndpointId,
@@ -243,7 +243,7 @@ impl AuthenticationConfig {
let limit_not_exceeded = self.rate_limiter.check(
(
endpoint_int,
MaskedIp::new(ctx.peer_addr, config.rate_limit_ip_subnet),
MaskedIp::new(ctx.peer_addr(), config.rate_limit_ip_subnet),
),
password_weight,
);
@@ -274,7 +274,7 @@ impl AuthenticationConfig {
///
/// All authentication flows will emit an AuthenticationOk message if successful.
async fn auth_quirks(
ctx: &mut RequestMonitoring,
ctx: &RequestMonitoring,
api: &impl console::Api,
user_info: ComputeUserInfoMaybeEndpoint,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
@@ -303,8 +303,8 @@ async fn auth_quirks(
let (allowed_ips, maybe_secret) = api.get_allowed_ips_and_secret(ctx, &info).await?;
// check allowed list
if !check_peer_addr_is_in_list(&ctx.peer_addr, &allowed_ips) {
return Err(auth::AuthError::ip_address_not_allowed(ctx.peer_addr));
if !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips) {
return Err(auth::AuthError::ip_address_not_allowed(ctx.peer_addr()));
}
if !endpoint_rate_limiter.check(info.endpoint.clone().into(), 1) {
@@ -356,7 +356,7 @@ async fn auth_quirks(
}
async fn authenticate_with_secret(
ctx: &mut RequestMonitoring,
ctx: &RequestMonitoring,
secret: AuthSecret,
info: ComputeUserInfo,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
@@ -421,7 +421,7 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint, &()> {
#[tracing::instrument(fields(allow_cleartext = allow_cleartext), skip_all)]
pub async fn authenticate(
self,
ctx: &mut RequestMonitoring,
ctx: &RequestMonitoring,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
allow_cleartext: bool,
config: &'static AuthenticationConfig,
@@ -467,7 +467,7 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint, &()> {
impl BackendType<'_, ComputeUserInfo, &()> {
pub async fn get_role_secret(
&self,
ctx: &mut RequestMonitoring,
ctx: &RequestMonitoring,
) -> Result<CachedRoleSecret, GetAuthInfoError> {
use BackendType::*;
match self {
@@ -478,7 +478,7 @@ impl BackendType<'_, ComputeUserInfo, &()> {
pub async fn get_allowed_ips_and_secret(
&self,
ctx: &mut RequestMonitoring,
ctx: &RequestMonitoring,
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), GetAuthInfoError> {
use BackendType::*;
match self {
@@ -492,7 +492,7 @@ impl BackendType<'_, ComputeUserInfo, &()> {
impl ComputeConnectBackend for BackendType<'_, ComputeCredentials, NodeInfo> {
async fn wake_compute(
&self,
ctx: &mut RequestMonitoring,
ctx: &RequestMonitoring,
) -> Result<CachedNodeInfo, console::errors::WakeComputeError> {
use BackendType::*;
@@ -514,7 +514,7 @@ impl ComputeConnectBackend for BackendType<'_, ComputeCredentials, NodeInfo> {
impl ComputeConnectBackend for BackendType<'_, ComputeCredentials, &()> {
async fn wake_compute(
&self,
ctx: &mut RequestMonitoring,
ctx: &RequestMonitoring,
) -> Result<CachedNodeInfo, console::errors::WakeComputeError> {
use BackendType::*;
@@ -571,7 +571,7 @@ mod tests {
impl console::Api for Auth {
async fn get_role_secret(
&self,
_ctx: &mut RequestMonitoring,
_ctx: &RequestMonitoring,
_user_info: &super::ComputeUserInfo,
) -> Result<CachedRoleSecret, console::errors::GetAuthInfoError> {
Ok(CachedRoleSecret::new_uncached(Some(self.secret.clone())))
@@ -579,7 +579,7 @@ mod tests {
async fn get_allowed_ips_and_secret(
&self,
_ctx: &mut RequestMonitoring,
_ctx: &RequestMonitoring,
_user_info: &super::ComputeUserInfo,
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), console::errors::GetAuthInfoError>
{
@@ -591,7 +591,7 @@ mod tests {
async fn wake_compute(
&self,
_ctx: &mut RequestMonitoring,
_ctx: &RequestMonitoring,
_user_info: &super::ComputeUserInfo,
) -> Result<CachedNodeInfo, console::errors::WakeComputeError> {
unimplemented!()
@@ -665,7 +665,7 @@ mod tests {
let (mut client, server) = tokio::io::duplex(1024);
let mut stream = PqStream::new(Stream::from_raw(server));
let mut ctx = RequestMonitoring::test();
let ctx = RequestMonitoring::test();
let api = Auth {
ips: vec![],
secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()),
@@ -723,7 +723,7 @@ mod tests {
));
let _creds = auth_quirks(
&mut ctx,
&ctx,
&api,
user_info,
&mut stream,
@@ -742,7 +742,7 @@ mod tests {
let (mut client, server) = tokio::io::duplex(1024);
let mut stream = PqStream::new(Stream::from_raw(server));
let mut ctx = RequestMonitoring::test();
let ctx = RequestMonitoring::test();
let api = Auth {
ips: vec![],
secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()),
@@ -775,7 +775,7 @@ mod tests {
));
let _creds = auth_quirks(
&mut ctx,
&ctx,
&api,
user_info,
&mut stream,
@@ -794,7 +794,7 @@ mod tests {
let (mut client, server) = tokio::io::duplex(1024);
let mut stream = PqStream::new(Stream::from_raw(server));
let mut ctx = RequestMonitoring::test();
let ctx = RequestMonitoring::test();
let api = Auth {
ips: vec![],
secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()),
@@ -828,7 +828,7 @@ mod tests {
));
let creds = auth_quirks(
&mut ctx,
&ctx,
&api,
user_info,
&mut stream,

View File

@@ -12,7 +12,7 @@ use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{info, warn};
pub(super) async fn authenticate(
ctx: &mut RequestMonitoring,
ctx: &RequestMonitoring,
creds: ComputeUserInfo,
client: &mut PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
config: &'static AuthenticationConfig,
@@ -27,7 +27,7 @@ pub(super) async fn authenticate(
}
AuthSecret::Scram(secret) => {
info!("auth endpoint chooses SCRAM");
let scram = auth::Scram(&secret, &mut *ctx);
let scram = auth::Scram(&secret, ctx);
let auth_outcome = tokio::time::timeout(
config.scram_protocol_timeout,

View File

@@ -18,7 +18,7 @@ use tracing::{info, warn};
/// These properties are benefical for serverless JS workers, so we
/// use this mechanism for websocket connections.
pub async fn authenticate_cleartext(
ctx: &mut RequestMonitoring,
ctx: &RequestMonitoring,
info: ComputeUserInfo,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
secret: AuthSecret,
@@ -28,7 +28,7 @@ pub async fn authenticate_cleartext(
ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
// pause the timer while we communicate with the client
let paused = ctx.latency_timer.pause(crate::metrics::Waiting::Client);
let paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
let ep = EndpointIdInt::from(&info.endpoint);
@@ -60,7 +60,7 @@ pub async fn authenticate_cleartext(
/// Similar to [`authenticate_cleartext`], but there's a specific password format,
/// and passwords are not yet validated (we don't know how to validate them!)
pub async fn password_hack_no_authentication(
ctx: &mut RequestMonitoring,
ctx: &RequestMonitoring,
info: ComputeUserInfoNoEndpoint,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
) -> auth::Result<ComputeCredentials> {
@@ -68,7 +68,7 @@ pub async fn password_hack_no_authentication(
ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
// pause the timer while we communicate with the client
let _paused = ctx.latency_timer.pause(crate::metrics::Waiting::Client);
let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
let payload = AuthFlow::new(client)
.begin(auth::PasswordHack)

View File

@@ -57,7 +57,7 @@ pub fn new_psql_session_id() -> String {
}
pub(super) async fn authenticate(
ctx: &mut RequestMonitoring,
ctx: &RequestMonitoring,
link_uri: &reqwest::Url,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> auth::Result<NodeInfo> {

View File

@@ -84,7 +84,7 @@ pub fn endpoint_sni(
impl ComputeUserInfoMaybeEndpoint {
pub fn parse(
ctx: &mut RequestMonitoring,
ctx: &RequestMonitoring,
params: &StartupMessageParams,
sni: Option<&str>,
common_names: Option<&HashSet<String>>,
@@ -249,8 +249,8 @@ mod tests {
fn parse_bare_minimum() -> anyhow::Result<()> {
// According to postgresql, only `user` should be required.
let options = StartupMessageParams::new([("user", "john_doe")]);
let mut ctx = RequestMonitoring::test();
let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, None, None)?;
let ctx = RequestMonitoring::test();
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?;
assert_eq!(user_info.user, "john_doe");
assert_eq!(user_info.endpoint_id, None);
@@ -264,8 +264,8 @@ mod tests {
("database", "world"), // should be ignored
("foo", "bar"), // should be ignored
]);
let mut ctx = RequestMonitoring::test();
let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, None, None)?;
let ctx = RequestMonitoring::test();
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?;
assert_eq!(user_info.user, "john_doe");
assert_eq!(user_info.endpoint_id, None);
@@ -279,9 +279,9 @@ mod tests {
let sni = Some("foo.localhost");
let common_names = Some(["localhost".into()].into());
let mut ctx = RequestMonitoring::test();
let ctx = RequestMonitoring::test();
let user_info =
ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())?;
ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())?;
assert_eq!(user_info.user, "john_doe");
assert_eq!(user_info.endpoint_id.as_deref(), Some("foo"));
assert_eq!(user_info.options.get_cache_key("foo"), "foo");
@@ -296,8 +296,8 @@ mod tests {
("options", "-ckey=1 project=bar -c geqo=off"),
]);
let mut ctx = RequestMonitoring::test();
let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, None, None)?;
let ctx = RequestMonitoring::test();
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?;
assert_eq!(user_info.user, "john_doe");
assert_eq!(user_info.endpoint_id.as_deref(), Some("bar"));
@@ -311,8 +311,8 @@ mod tests {
("options", "-ckey=1 endpoint=bar -c geqo=off"),
]);
let mut ctx = RequestMonitoring::test();
let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, None, None)?;
let ctx = RequestMonitoring::test();
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?;
assert_eq!(user_info.user, "john_doe");
assert_eq!(user_info.endpoint_id.as_deref(), Some("bar"));
@@ -329,8 +329,8 @@ mod tests {
),
]);
let mut ctx = RequestMonitoring::test();
let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, None, None)?;
let ctx = RequestMonitoring::test();
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?;
assert_eq!(user_info.user, "john_doe");
assert!(user_info.endpoint_id.is_none());
@@ -344,8 +344,8 @@ mod tests {
("options", "-ckey=1 endpoint=bar project=foo -c geqo=off"),
]);
let mut ctx = RequestMonitoring::test();
let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, None, None)?;
let ctx = RequestMonitoring::test();
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?;
assert_eq!(user_info.user, "john_doe");
assert!(user_info.endpoint_id.is_none());
@@ -359,9 +359,9 @@ mod tests {
let sni = Some("baz.localhost");
let common_names = Some(["localhost".into()].into());
let mut ctx = RequestMonitoring::test();
let ctx = RequestMonitoring::test();
let user_info =
ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())?;
ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())?;
assert_eq!(user_info.user, "john_doe");
assert_eq!(user_info.endpoint_id.as_deref(), Some("baz"));
@@ -374,16 +374,16 @@ mod tests {
let common_names = Some(["a.com".into(), "b.com".into()].into());
let sni = Some("p1.a.com");
let mut ctx = RequestMonitoring::test();
let ctx = RequestMonitoring::test();
let user_info =
ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())?;
ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())?;
assert_eq!(user_info.endpoint_id.as_deref(), Some("p1"));
let common_names = Some(["a.com".into(), "b.com".into()].into());
let sni = Some("p1.b.com");
let mut ctx = RequestMonitoring::test();
let ctx = RequestMonitoring::test();
let user_info =
ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())?;
ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())?;
assert_eq!(user_info.endpoint_id.as_deref(), Some("p1"));
Ok(())
@@ -397,10 +397,9 @@ mod tests {
let sni = Some("second.localhost");
let common_names = Some(["localhost".into()].into());
let mut ctx = RequestMonitoring::test();
let err =
ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())
.expect_err("should fail");
let ctx = RequestMonitoring::test();
let err = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())
.expect_err("should fail");
match err {
InconsistentProjectNames { domain, option } => {
assert_eq!(option, "first");
@@ -417,10 +416,9 @@ mod tests {
let sni = Some("project.localhost");
let common_names = Some(["example.com".into()].into());
let mut ctx = RequestMonitoring::test();
let err =
ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())
.expect_err("should fail");
let ctx = RequestMonitoring::test();
let err = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())
.expect_err("should fail");
match err {
UnknownCommonName { cn } => {
assert_eq!(cn, "localhost");
@@ -438,9 +436,9 @@ mod tests {
let sni = Some("project.localhost");
let common_names = Some(["localhost".into()].into());
let mut ctx = RequestMonitoring::test();
let ctx = RequestMonitoring::test();
let user_info =
ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())?;
ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())?;
assert_eq!(user_info.endpoint_id.as_deref(), Some("project"));
assert_eq!(
user_info.options.get_cache_key("project"),

View File

@@ -27,7 +27,7 @@ pub trait AuthMethod {
pub struct Begin;
/// Use [SCRAM](crate::scram)-based auth in [`AuthFlow`].
pub struct Scram<'a>(pub &'a scram::ServerSecret, pub &'a mut RequestMonitoring);
pub struct Scram<'a>(pub &'a scram::ServerSecret, pub &'a RequestMonitoring);
impl AuthMethod for Scram<'_> {
#[inline(always)]
@@ -155,7 +155,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
let Scram(secret, ctx) = self.state;
// pause the timer while we communicate with the client
let _paused = ctx.latency_timer.pause(crate::metrics::Waiting::Client);
let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
// Initial client message contains the chosen auth method's name.
let msg = self.stream.read_password_message().await?;
@@ -168,10 +168,8 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
}
match sasl.method {
SCRAM_SHA_256 => ctx.auth_method = Some(crate::context::AuthMethod::ScramSha256),
SCRAM_SHA_256_PLUS => {
ctx.auth_method = Some(crate::context::AuthMethod::ScramSha256Plus)
}
SCRAM_SHA_256 => ctx.set_auth_method(crate::context::AuthMethod::ScramSha256),
SCRAM_SHA_256_PLUS => ctx.set_auth_method(crate::context::AuthMethod::ScramSha256Plus),
_ => {}
}
info!("client chooses {}", sasl.method);

View File

@@ -205,7 +205,7 @@ async fn task_main(
const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
ctx: &mut RequestMonitoring,
ctx: &RequestMonitoring,
raw_stream: S,
tls_config: Arc<rustls::ServerConfig>,
tls_server_end_point: TlsServerEndPoint,
@@ -256,13 +256,13 @@ async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
}
async fn handle_client(
mut ctx: RequestMonitoring,
ctx: RequestMonitoring,
dest_suffix: Arc<String>,
tls_config: Arc<rustls::ServerConfig>,
tls_server_end_point: TlsServerEndPoint,
stream: impl AsyncRead + AsyncWrite + Unpin,
) -> anyhow::Result<()> {
let mut tls_stream = ssl_handshake(&mut ctx, stream, tls_config, tls_server_end_point).await?;
let mut tls_stream = ssl_handshake(&ctx, stream, tls_config, tls_server_end_point).await?;
// Cut off first part of the SNI domain
// We receive required destination details in the format of

View File

@@ -5,6 +5,7 @@ use aws_config::meta::region::RegionProviderChain;
use aws_config::profile::ProfileFileCredentialsProvider;
use aws_config::provider_config::ProviderConfig;
use aws_config::web_identity_token::WebIdentityTokenCredentialsProvider;
use aws_config::Region;
use futures::future::Either;
use proxy::auth;
use proxy::auth::backend::AuthRateLimiter;
@@ -290,9 +291,10 @@ async fn main() -> anyhow::Result<()> {
let config = build_config(&args)?;
info!("Authentication backend: {}", config.auth_backend);
info!("Using region: {}", config.aws_region);
info!("Using region: {}", args.aws_region);
let region_provider = RegionProviderChain::default_provider().or_else(&*config.aws_region); // Replace with your Redis region if needed
let region_provider =
RegionProviderChain::default_provider().or_else(Region::new(args.aws_region.clone()));
let provider_conf =
ProviderConfig::without_region().with_region(region_provider.region().await);
let aws_credentials_provider = {
@@ -318,7 +320,7 @@ async fn main() -> anyhow::Result<()> {
};
let elasticache_credentials_provider = Arc::new(elasticache::CredentialsProvider::new(
elasticache::AWSIRSAConfig::new(
config.aws_region.clone(),
args.aws_region.clone(),
args.redis_cluster_name,
args.redis_user_id,
),
@@ -376,11 +378,14 @@ async fn main() -> anyhow::Result<()> {
let cancel_map = CancelMap::default();
let redis_rps_limit = Vec::leak(args.redis_rps_limit.clone());
RateBucketInfo::validate(redis_rps_limit)?;
let redis_publisher = match &regional_redis_client {
Some(redis_publisher) => Some(Arc::new(Mutex::new(RedisPublisherClient::new(
redis_publisher.clone(),
args.region.clone(),
&config.redis_rps_limit,
redis_rps_limit,
)?))),
None => None,
};
@@ -656,7 +661,6 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
)?;
let http_config = HttpConfig {
request_timeout: args.sql_over_http.sql_over_http_timeout,
pool_options: GlobalConnPoolOptions {
max_conns_per_endpoint: args.sql_over_http.sql_over_http_pool_max_conns_per_endpoint,
gc_epoch: args.sql_over_http.sql_over_http_pool_gc_epoch,
@@ -676,9 +680,6 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
rate_limit_ip_subnet: args.auth_rate_limit_ip_subnet,
};
let mut redis_rps_limit = args.redis_rps_limit.clone();
RateBucketInfo::validate(&mut redis_rps_limit)?;
let config = Box::leak(Box::new(ProxyConfig {
tls_config,
auth_backend,
@@ -687,11 +688,8 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
http_config,
authentication_config,
require_client_ip: args.require_client_ip,
disable_ip_check_for_http: args.disable_ip_check_for_http,
redis_rps_limit,
handshake_timeout: args.handshake_timeout,
region: args.region.clone(),
aws_region: args.aws_region.clone(),
wake_compute_retry_config: config::RetryConfig::parse(&args.wake_compute_retry)?,
connect_compute_locks,
connect_to_compute_retry_config: config::RetryConfig::parse(

View File

@@ -68,7 +68,7 @@ impl EndpointsCache {
ready: AtomicBool::new(false),
}
}
pub async fn is_valid(&self, ctx: &mut RequestMonitoring, endpoint: &EndpointId) -> bool {
pub async fn is_valid(&self, ctx: &RequestMonitoring, endpoint: &EndpointId) -> bool {
if !self.ready.load(Ordering::Acquire) {
return true;
}

View File

@@ -288,12 +288,12 @@ impl ConnCfg {
/// Connect to a corresponding compute node.
pub async fn connect(
&self,
ctx: &mut RequestMonitoring,
ctx: &RequestMonitoring,
allow_self_signed_compute: bool,
aux: MetricsAuxInfo,
timeout: Duration,
) -> Result<PostgresConnection, ConnectionError> {
let pause = ctx.latency_timer.pause(crate::metrics::Waiting::Compute);
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let (socket_addr, stream, host) = self.connect_raw(timeout).await?;
drop(pause);
@@ -316,14 +316,14 @@ impl ConnCfg {
)?;
// connect_raw() will not use TLS if sslmode is "disable"
let pause = ctx.latency_timer.pause(crate::metrics::Waiting::Compute);
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let (client, connection) = self.0.connect_raw(stream, tls).await?;
drop(pause);
tracing::Span::current().record("pid", tracing::field::display(client.get_process_id()));
let stream = connection.stream.into_inner();
info!(
cold_start_info = ctx.cold_start_info.as_str(),
cold_start_info = ctx.cold_start_info().as_str(),
"connected to compute node at {host} ({socket_addr}) sslmode={:?}",
self.0.get_ssl_mode()
);
@@ -342,7 +342,7 @@ impl ConnCfg {
params,
cancel_closure,
aux,
_guage: Metrics::get().proxy.db_connections.guard(ctx.protocol),
_guage: Metrics::get().proxy.db_connections.guard(ctx.protocol()),
};
Ok(connection)

View File

@@ -31,11 +31,8 @@ pub struct ProxyConfig {
pub http_config: HttpConfig,
pub authentication_config: AuthenticationConfig,
pub require_client_ip: bool,
pub disable_ip_check_for_http: bool,
pub redis_rps_limit: Vec<RateBucketInfo>,
pub region: String,
pub handshake_timeout: Duration,
pub aws_region: String,
pub wake_compute_retry_config: RetryConfig,
pub connect_compute_locks: ApiLocks<Host>,
pub connect_to_compute_retry_config: RetryConfig,
@@ -55,7 +52,6 @@ pub struct TlsConfig {
}
pub struct HttpConfig {
pub request_timeout: tokio::time::Duration,
pub pool_options: GlobalConnPoolOptions,
pub cancel_set: CancelSet,
pub client_conn_threshold: u64,

View File

@@ -292,7 +292,7 @@ pub struct NodeInfo {
impl NodeInfo {
pub async fn connect(
&self,
ctx: &mut RequestMonitoring,
ctx: &RequestMonitoring,
timeout: Duration,
) -> Result<compute::PostgresConnection, compute::ConnectionError> {
self.config
@@ -330,20 +330,20 @@ pub(crate) trait Api {
/// We still have to mock the scram to avoid leaking information that user doesn't exist.
async fn get_role_secret(
&self,
ctx: &mut RequestMonitoring,
ctx: &RequestMonitoring,
user_info: &ComputeUserInfo,
) -> Result<CachedRoleSecret, errors::GetAuthInfoError>;
async fn get_allowed_ips_and_secret(
&self,
ctx: &mut RequestMonitoring,
ctx: &RequestMonitoring,
user_info: &ComputeUserInfo,
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), errors::GetAuthInfoError>;
/// Wake up the compute node and return the corresponding connection info.
async fn wake_compute(
&self,
ctx: &mut RequestMonitoring,
ctx: &RequestMonitoring,
user_info: &ComputeUserInfo,
) -> Result<CachedNodeInfo, errors::WakeComputeError>;
}
@@ -363,7 +363,7 @@ pub enum ConsoleBackend {
impl Api for ConsoleBackend {
async fn get_role_secret(
&self,
ctx: &mut RequestMonitoring,
ctx: &RequestMonitoring,
user_info: &ComputeUserInfo,
) -> Result<CachedRoleSecret, errors::GetAuthInfoError> {
use ConsoleBackend::*;
@@ -378,7 +378,7 @@ impl Api for ConsoleBackend {
async fn get_allowed_ips_and_secret(
&self,
ctx: &mut RequestMonitoring,
ctx: &RequestMonitoring,
user_info: &ComputeUserInfo,
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), errors::GetAuthInfoError> {
use ConsoleBackend::*;
@@ -393,7 +393,7 @@ impl Api for ConsoleBackend {
async fn wake_compute(
&self,
ctx: &mut RequestMonitoring,
ctx: &RequestMonitoring,
user_info: &ComputeUserInfo,
) -> Result<CachedNodeInfo, errors::WakeComputeError> {
use ConsoleBackend::*;

View File

@@ -158,7 +158,7 @@ impl super::Api for Api {
#[tracing::instrument(skip_all)]
async fn get_role_secret(
&self,
_ctx: &mut RequestMonitoring,
_ctx: &RequestMonitoring,
user_info: &ComputeUserInfo,
) -> Result<CachedRoleSecret, GetAuthInfoError> {
Ok(CachedRoleSecret::new_uncached(
@@ -168,7 +168,7 @@ impl super::Api for Api {
async fn get_allowed_ips_and_secret(
&self,
_ctx: &mut RequestMonitoring,
_ctx: &RequestMonitoring,
user_info: &ComputeUserInfo,
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), GetAuthInfoError> {
Ok((
@@ -182,7 +182,7 @@ impl super::Api for Api {
#[tracing::instrument(skip_all)]
async fn wake_compute(
&self,
_ctx: &mut RequestMonitoring,
_ctx: &RequestMonitoring,
_user_info: &ComputeUserInfo,
) -> Result<CachedNodeInfo, WakeComputeError> {
self.do_wake_compute().map_ok(Cached::new_uncached).await

View File

@@ -57,7 +57,7 @@ impl Api {
async fn do_get_auth_info(
&self,
ctx: &mut RequestMonitoring,
ctx: &RequestMonitoring,
user_info: &ComputeUserInfo,
) -> Result<AuthInfo, GetAuthInfoError> {
if !self
@@ -69,7 +69,7 @@ impl Api {
info!("endpoint is not valid, skipping the request");
return Ok(AuthInfo::default());
}
let request_id = ctx.session_id.to_string();
let request_id = ctx.session_id().to_string();
let application_name = ctx.console_application_name();
async {
let request = self
@@ -77,7 +77,7 @@ impl Api {
.get("proxy_get_role_secret")
.header("X-Request-ID", &request_id)
.header("Authorization", format!("Bearer {}", &self.jwt))
.query(&[("session_id", ctx.session_id)])
.query(&[("session_id", ctx.session_id())])
.query(&[
("application_name", application_name.as_str()),
("project", user_info.endpoint.as_str()),
@@ -87,7 +87,7 @@ impl Api {
info!(url = request.url().as_str(), "sending http request");
let start = Instant::now();
let pause = ctx.latency_timer.pause(crate::metrics::Waiting::Cplane);
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Cplane);
let response = self.endpoint.execute(request).await?;
drop(pause);
info!(duration = ?start.elapsed(), "received http response");
@@ -130,10 +130,10 @@ impl Api {
async fn do_wake_compute(
&self,
ctx: &mut RequestMonitoring,
ctx: &RequestMonitoring,
user_info: &ComputeUserInfo,
) -> Result<NodeInfo, WakeComputeError> {
let request_id = ctx.session_id.to_string();
let request_id = ctx.session_id().to_string();
let application_name = ctx.console_application_name();
async {
let mut request_builder = self
@@ -141,7 +141,7 @@ impl Api {
.get("proxy_wake_compute")
.header("X-Request-ID", &request_id)
.header("Authorization", format!("Bearer {}", &self.jwt))
.query(&[("session_id", ctx.session_id)])
.query(&[("session_id", ctx.session_id())])
.query(&[
("application_name", application_name.as_str()),
("project", user_info.endpoint.as_str()),
@@ -156,7 +156,7 @@ impl Api {
info!(url = request.url().as_str(), "sending http request");
let start = Instant::now();
let pause = ctx.latency_timer.pause(crate::metrics::Waiting::Cplane);
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Cplane);
let response = self.endpoint.execute(request).await?;
drop(pause);
info!(duration = ?start.elapsed(), "received http response");
@@ -192,7 +192,7 @@ impl super::Api for Api {
#[tracing::instrument(skip_all)]
async fn get_role_secret(
&self,
ctx: &mut RequestMonitoring,
ctx: &RequestMonitoring,
user_info: &ComputeUserInfo,
) -> Result<CachedRoleSecret, GetAuthInfoError> {
let normalized_ep = &user_info.endpoint.normalize();
@@ -226,7 +226,7 @@ impl super::Api for Api {
async fn get_allowed_ips_and_secret(
&self,
ctx: &mut RequestMonitoring,
ctx: &RequestMonitoring,
user_info: &ComputeUserInfo,
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), GetAuthInfoError> {
let normalized_ep = &user_info.endpoint.normalize();
@@ -268,7 +268,7 @@ impl super::Api for Api {
#[tracing::instrument(skip_all)]
async fn wake_compute(
&self,
ctx: &mut RequestMonitoring,
ctx: &RequestMonitoring,
user_info: &ComputeUserInfo,
) -> Result<CachedNodeInfo, WakeComputeError> {
let key = user_info.endpoint_cache_key();

View File

@@ -7,13 +7,14 @@ use smol_str::SmolStr;
use std::net::IpAddr;
use tokio::sync::mpsc;
use tracing::{field::display, info, info_span, Span};
use try_lock::TryLock;
use uuid::Uuid;
use crate::{
console::messages::{ColdStartInfo, MetricsAuxInfo},
error::ErrorKind,
intern::{BranchIdInt, ProjectIdInt},
metrics::{ConnectOutcome, InvalidEndpointsGroup, LatencyTimer, Metrics, Protocol},
metrics::{ConnectOutcome, InvalidEndpointsGroup, LatencyTimer, Metrics, Protocol, Waiting},
DbName, EndpointId, RoleName,
};
@@ -28,7 +29,15 @@ pub static LOG_CHAN_DISCONNECT: OnceCell<mpsc::WeakUnboundedSender<RequestData>>
///
/// This data should **not** be used for connection logic, only for observability and limiting purposes.
/// All connection logic should instead use strongly typed state machines, not a bunch of Options.
pub struct RequestMonitoring {
pub struct RequestMonitoring(
/// To allow easier use of the ctx object, we have interior mutability.
/// I would typically use a RefCell but that would break the `Send` requirements
/// so we need something with thread-safety. `TryLock` is a cheap alternative
/// that offers similar semantics to a `RefCell` but with synchronisation.
TryLock<RequestMonitoringInner>,
);
struct RequestMonitoringInner {
pub peer_addr: IpAddr,
pub session_id: Uuid,
pub protocol: Protocol,
@@ -85,7 +94,7 @@ impl RequestMonitoring {
role = tracing::field::Empty,
);
Self {
let inner = RequestMonitoringInner {
peer_addr,
session_id,
protocol,
@@ -110,7 +119,9 @@ impl RequestMonitoring {
disconnect_sender: LOG_CHAN_DISCONNECT.get().and_then(|tx| tx.upgrade()),
latency_timer: LatencyTimer::new(protocol),
disconnect_timestamp: None,
}
};
Self(TryLock::new(inner))
}
#[cfg(test)]
@@ -119,48 +130,177 @@ impl RequestMonitoring {
}
pub fn console_application_name(&self) -> String {
let this = self.0.try_lock().expect("should not deadlock");
format!(
"{}/{}",
self.application.as_deref().unwrap_or_default(),
self.protocol
this.application.as_deref().unwrap_or_default(),
this.protocol
)
}
pub fn set_rejected(&mut self, rejected: bool) {
self.rejected = Some(rejected);
pub fn set_rejected(&self, rejected: bool) {
let mut this = self.0.try_lock().expect("should not deadlock");
this.rejected = Some(rejected);
}
pub fn set_cold_start_info(&mut self, info: ColdStartInfo) {
pub fn set_cold_start_info(&self, info: ColdStartInfo) {
self.0
.try_lock()
.expect("should not deadlock")
.set_cold_start_info(info);
}
pub fn set_db_options(&self, options: StartupMessageParams) {
let mut this = self.0.try_lock().expect("should not deadlock");
this.set_application(options.get("application_name").map(SmolStr::from));
if let Some(user) = options.get("user") {
this.set_user(user.into());
}
if let Some(dbname) = options.get("database") {
this.set_dbname(dbname.into());
}
this.pg_options = Some(options);
}
pub fn set_project(&self, x: MetricsAuxInfo) {
let mut this = self.0.try_lock().expect("should not deadlock");
if this.endpoint_id.is_none() {
this.set_endpoint_id(x.endpoint_id.as_str().into())
}
this.branch = Some(x.branch_id);
this.project = Some(x.project_id);
this.set_cold_start_info(x.cold_start_info);
}
pub fn set_project_id(&self, project_id: ProjectIdInt) {
let mut this = self.0.try_lock().expect("should not deadlock");
this.project = Some(project_id);
}
pub fn set_endpoint_id(&self, endpoint_id: EndpointId) {
self.0
.try_lock()
.expect("should not deadlock")
.set_endpoint_id(endpoint_id);
}
pub fn set_dbname(&self, dbname: DbName) {
self.0
.try_lock()
.expect("should not deadlock")
.set_dbname(dbname);
}
pub fn set_user(&self, user: RoleName) {
self.0
.try_lock()
.expect("should not deadlock")
.set_user(user);
}
pub fn set_auth_method(&self, auth_method: AuthMethod) {
let mut this = self.0.try_lock().expect("should not deadlock");
this.auth_method = Some(auth_method);
}
pub fn has_private_peer_addr(&self) -> bool {
self.0
.try_lock()
.expect("should not deadlock")
.has_private_peer_addr()
}
pub fn set_error_kind(&self, kind: ErrorKind) {
let mut this = self.0.try_lock().expect("should not deadlock");
// Do not record errors from the private address to metrics.
if !this.has_private_peer_addr() {
Metrics::get().proxy.errors_total.inc(kind);
}
if let Some(ep) = &this.endpoint_id {
let metric = &Metrics::get().proxy.endpoints_affected_by_errors;
let label = metric.with_labels(kind);
metric.get_metric(label).measure(ep);
}
this.error_kind = Some(kind);
}
pub fn set_success(&self) {
let mut this = self.0.try_lock().expect("should not deadlock");
this.success = true;
}
pub fn log_connect(&self) {
self.0
.try_lock()
.expect("should not deadlock")
.log_connect();
}
pub fn protocol(&self) -> Protocol {
self.0.try_lock().expect("should not deadlock").protocol
}
pub fn span(&self) -> Span {
self.0.try_lock().expect("should not deadlock").span.clone()
}
pub fn session_id(&self) -> Uuid {
self.0.try_lock().expect("should not deadlock").session_id
}
pub fn peer_addr(&self) -> IpAddr {
self.0.try_lock().expect("should not deadlock").peer_addr
}
pub fn cold_start_info(&self) -> ColdStartInfo {
self.0
.try_lock()
.expect("should not deadlock")
.cold_start_info
}
pub fn latency_timer_pause(&self, waiting_for: Waiting) -> LatencyTimerPause {
LatencyTimerPause {
ctx: self,
start: tokio::time::Instant::now(),
waiting_for,
}
}
pub fn success(&self) {
self.0
.try_lock()
.expect("should not deadlock")
.latency_timer
.success()
}
}
pub struct LatencyTimerPause<'a> {
ctx: &'a RequestMonitoring,
start: tokio::time::Instant,
waiting_for: Waiting,
}
impl Drop for LatencyTimerPause<'_> {
fn drop(&mut self) {
self.ctx
.0
.try_lock()
.expect("should not deadlock")
.latency_timer
.unpause(self.start, self.waiting_for);
}
}
impl RequestMonitoringInner {
fn set_cold_start_info(&mut self, info: ColdStartInfo) {
self.cold_start_info = info;
self.latency_timer.cold_start_info(info);
}
pub fn set_db_options(&mut self, options: StartupMessageParams) {
self.set_application(options.get("application_name").map(SmolStr::from));
if let Some(user) = options.get("user") {
self.set_user(user.into());
}
if let Some(dbname) = options.get("database") {
self.set_dbname(dbname.into());
}
self.pg_options = Some(options);
}
pub fn set_project(&mut self, x: MetricsAuxInfo) {
if self.endpoint_id.is_none() {
self.set_endpoint_id(x.endpoint_id.as_str().into())
}
self.branch = Some(x.branch_id);
self.project = Some(x.project_id);
self.set_cold_start_info(x.cold_start_info);
}
pub fn set_project_id(&mut self, project_id: ProjectIdInt) {
self.project = Some(project_id);
}
pub fn set_endpoint_id(&mut self, endpoint_id: EndpointId) {
fn set_endpoint_id(&mut self, endpoint_id: EndpointId) {
if self.endpoint_id.is_none() {
self.span.record("ep", display(&endpoint_id));
let metric = &Metrics::get().proxy.connecting_endpoints;
@@ -176,44 +316,23 @@ impl RequestMonitoring {
}
}
pub fn set_dbname(&mut self, dbname: DbName) {
fn set_dbname(&mut self, dbname: DbName) {
self.dbname = Some(dbname);
}
pub fn set_user(&mut self, user: RoleName) {
fn set_user(&mut self, user: RoleName) {
self.span.record("role", display(&user));
self.user = Some(user);
}
pub fn set_auth_method(&mut self, auth_method: AuthMethod) {
self.auth_method = Some(auth_method);
}
pub fn has_private_peer_addr(&self) -> bool {
fn has_private_peer_addr(&self) -> bool {
match self.peer_addr {
IpAddr::V4(ip) => ip.is_private(),
_ => false,
}
}
pub fn set_error_kind(&mut self, kind: ErrorKind) {
// Do not record errors from the private address to metrics.
if !self.has_private_peer_addr() {
Metrics::get().proxy.errors_total.inc(kind);
}
if let Some(ep) = &self.endpoint_id {
let metric = &Metrics::get().proxy.endpoints_affected_by_errors;
let label = metric.with_labels(kind);
metric.get_metric(label).measure(ep);
}
self.error_kind = Some(kind);
}
pub fn set_success(&mut self) {
self.success = true;
}
pub fn log_connect(&mut self) {
fn log_connect(&mut self) {
let outcome = if self.success {
ConnectOutcome::Success
} else {
@@ -256,7 +375,7 @@ impl RequestMonitoring {
}
}
impl Drop for RequestMonitoring {
impl Drop for RequestMonitoringInner {
fn drop(&mut self) {
if self.sender.is_some() {
self.log_connect();

View File

@@ -23,7 +23,7 @@ use utils::backoff;
use crate::{config::remote_storage_from_toml, context::LOG_CHAN_DISCONNECT};
use super::{RequestMonitoring, LOG_CHAN};
use super::{RequestMonitoringInner, LOG_CHAN};
#[derive(clap::Args, Clone, Debug)]
pub struct ParquetUploadArgs {
@@ -118,8 +118,8 @@ impl<'a> serde::Serialize for Options<'a> {
}
}
impl From<&RequestMonitoring> for RequestData {
fn from(value: &RequestMonitoring) -> Self {
impl From<&RequestMonitoringInner> for RequestData {
fn from(value: &RequestMonitoringInner) -> Self {
Self {
session_id: value.session_id,
peer_addr: value.peer_addr.to_string(),

View File

@@ -370,6 +370,7 @@ pub struct CancellationRequest {
pub kind: CancellationOutcome,
}
#[derive(Clone, Copy)]
pub enum Waiting {
Cplane,
Client,
@@ -398,12 +399,6 @@ pub struct LatencyTimer {
outcome: ConnectOutcome,
}
pub struct LatencyTimerPause<'a> {
timer: &'a mut LatencyTimer,
start: time::Instant,
waiting_for: Waiting,
}
impl LatencyTimer {
pub fn new(protocol: Protocol) -> Self {
Self {
@@ -417,11 +412,13 @@ impl LatencyTimer {
}
}
pub fn pause(&mut self, waiting_for: Waiting) -> LatencyTimerPause<'_> {
LatencyTimerPause {
timer: self,
start: Instant::now(),
waiting_for,
pub fn unpause(&mut self, start: Instant, waiting_for: Waiting) {
let dur = start.elapsed();
match waiting_for {
Waiting::Cplane => self.accumulated.cplane += dur,
Waiting::Client => self.accumulated.client += dur,
Waiting::Compute => self.accumulated.compute += dur,
Waiting::RetryTimeout => self.accumulated.retry += dur,
}
}
@@ -438,18 +435,6 @@ impl LatencyTimer {
}
}
impl Drop for LatencyTimerPause<'_> {
fn drop(&mut self) {
let dur = self.start.elapsed();
match self.waiting_for {
Waiting::Cplane => self.timer.accumulated.cplane += dur,
Waiting::Client => self.timer.accumulated.client += dur,
Waiting::Compute => self.timer.accumulated.compute += dur,
Waiting::RetryTimeout => self.timer.accumulated.retry += dur,
}
}
}
#[derive(FixedCardinalityLabel, Clone, Copy, Debug)]
pub enum ConnectOutcome {
Success,

View File

@@ -113,18 +113,18 @@ pub async fn task_main(
}
};
let mut ctx = RequestMonitoring::new(
let ctx = RequestMonitoring::new(
session_id,
peer_addr,
crate::metrics::Protocol::Tcp,
&config.region,
);
let span = ctx.span.clone();
let span = ctx.span();
let startup = Box::pin(
handle_client(
config,
&mut ctx,
&ctx,
cancellation_handler,
socket,
ClientMode::Tcp,
@@ -240,7 +240,7 @@ impl ReportableError for ClientRequestError {
pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
config: &'static ProxyConfig,
ctx: &mut RequestMonitoring,
ctx: &RequestMonitoring,
cancellation_handler: Arc<CancellationHandlerMain>,
stream: S,
mode: ClientMode,
@@ -248,25 +248,25 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
conn_gauge: NumClientConnectionsGuard<'static>,
) -> Result<Option<ProxyPassthrough<CancellationHandlerMainInternal, S>>, ClientRequestError> {
info!(
protocol = %ctx.protocol,
protocol = %ctx.protocol(),
"handling interactive connection from client"
);
let metrics = &Metrics::get().proxy;
let proto = ctx.protocol;
let proto = ctx.protocol();
let _request_gauge = metrics.connection_requests.guard(proto);
let tls = config.tls_config.as_ref();
let record_handshake_error = !ctx.has_private_peer_addr();
let pause = ctx.latency_timer.pause(crate::metrics::Waiting::Client);
let do_handshake = handshake(stream, mode.handshake_tls(tls), record_handshake_error);
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
let do_handshake = handshake(ctx, stream, mode.handshake_tls(tls), record_handshake_error);
let (mut stream, params) =
match tokio::time::timeout(config.handshake_timeout, do_handshake).await?? {
HandshakeData::Startup(stream, params) => (stream, params),
HandshakeData::Cancel(cancel_key_data) => {
return Ok(cancellation_handler
.cancel_session(cancel_key_data, ctx.session_id)
.cancel_session(cancel_key_data, ctx.session_id())
.await
.map(|()| None)?)
}

View File

@@ -46,7 +46,7 @@ pub trait ConnectMechanism {
type Error: From<Self::ConnectError>;
async fn connect_once(
&self,
ctx: &mut RequestMonitoring,
ctx: &RequestMonitoring,
node_info: &console::CachedNodeInfo,
timeout: time::Duration,
) -> Result<Self::Connection, Self::ConnectError>;
@@ -58,7 +58,7 @@ pub trait ConnectMechanism {
pub trait ComputeConnectBackend {
async fn wake_compute(
&self,
ctx: &mut RequestMonitoring,
ctx: &RequestMonitoring,
) -> Result<CachedNodeInfo, console::errors::WakeComputeError>;
fn get_keys(&self) -> Option<&ComputeCredentialKeys>;
@@ -81,7 +81,7 @@ impl ConnectMechanism for TcpMechanism<'_> {
#[tracing::instrument(fields(pid = tracing::field::Empty), skip_all)]
async fn connect_once(
&self,
ctx: &mut RequestMonitoring,
ctx: &RequestMonitoring,
node_info: &console::CachedNodeInfo,
timeout: time::Duration,
) -> Result<PostgresConnection, Self::Error> {
@@ -98,7 +98,7 @@ impl ConnectMechanism for TcpMechanism<'_> {
/// Try to connect to the compute node, retrying if necessary.
#[tracing::instrument(skip_all)]
pub async fn connect_to_compute<M: ConnectMechanism, B: ComputeConnectBackend>(
ctx: &mut RequestMonitoring,
ctx: &RequestMonitoring,
mechanism: &M,
user_info: &B,
allow_self_signed_compute: bool,
@@ -126,7 +126,7 @@ where
.await
{
Ok(res) => {
ctx.latency_timer.success();
ctx.success();
Metrics::get().proxy.retries_metric.observe(
RetriesMetricGroup {
outcome: ConnectOutcome::Success,
@@ -178,7 +178,7 @@ where
.await
{
Ok(res) => {
ctx.latency_timer.success();
ctx.success();
Metrics::get().proxy.retries_metric.observe(
RetriesMetricGroup {
outcome: ConnectOutcome::Success,
@@ -209,9 +209,7 @@ where
let wait_duration = retry_after(num_retries, connect_to_compute_retry_config);
num_retries += 1;
let pause = ctx
.latency_timer
.pause(crate::metrics::Waiting::RetryTimeout);
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::RetryTimeout);
time::sleep(wait_duration).await;
drop(pause);
}

View File

@@ -10,6 +10,7 @@ use tracing::{info, warn};
use crate::{
auth::endpoint_sni,
config::{TlsConfig, PG_ALPN_PROTOCOL},
context::RequestMonitoring,
error::ReportableError,
metrics::Metrics,
proxy::ERR_INSECURE_CONNECTION,
@@ -67,6 +68,7 @@ pub enum HandshakeData<S> {
/// we also take an extra care of propagating only the select handshake errors to client.
#[tracing::instrument(skip_all)]
pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
ctx: &RequestMonitoring,
stream: S,
mut tls: Option<&TlsConfig>,
record_handshake_error: bool,
@@ -80,8 +82,6 @@ pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
let mut stream = PqStream::new(Stream::from_raw(stream));
loop {
let msg = stream.read_startup_packet().await?;
info!("received {msg:?}");
use FeStartupPacket::*;
match msg {
SslRequest { direct } => match stream.get_ref() {
@@ -145,16 +145,20 @@ pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
let conn_info = tls_stream.get_ref().1;
// try parse endpoint
let ep = conn_info
.server_name()
.and_then(|sni| endpoint_sni(sni, &tls.common_names).ok().flatten());
if let Some(ep) = ep {
ctx.set_endpoint_id(ep);
}
// check the ALPN, if exists, as required.
match conn_info.alpn_protocol() {
None | Some(PG_ALPN_PROTOCOL) => {}
Some(other) => {
// try parse ep for better error
let ep = conn_info.server_name().and_then(|sni| {
endpoint_sni(sni, &tls.common_names).ok().flatten()
});
let alpn = String::from_utf8_lossy(other);
warn!(?ep, %alpn, "unexpected ALPN");
warn!(%alpn, "unexpected ALPN");
return Err(HandshakeError::ProtocolViolation);
}
}
@@ -198,7 +202,12 @@ pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
.await?;
}
info!(?version, session_type = "normal", "successful handshake");
info!(
?version,
?params,
session_type = "normal",
"successful handshake"
);
break Ok(HandshakeData::Startup(stream, params));
}
// downgrade protocol version

View File

@@ -155,7 +155,7 @@ impl TestAuth for Scram {
stream: &mut PqStream<Stream<S>>,
) -> anyhow::Result<()> {
let outcome = auth::AuthFlow::new(stream)
.begin(auth::Scram(&self.0, &mut RequestMonitoring::test()))
.begin(auth::Scram(&self.0, &RequestMonitoring::test()))
.await?
.authenticate()
.await?;
@@ -175,10 +175,11 @@ async fn dummy_proxy(
auth: impl TestAuth + Send,
) -> anyhow::Result<()> {
let (client, _) = read_proxy_protocol(client).await?;
let mut stream = match handshake(client, tls.as_ref(), false).await? {
HandshakeData::Startup(stream, _) => stream,
HandshakeData::Cancel(_) => bail!("cancellation not supported"),
};
let mut stream =
match handshake(&RequestMonitoring::test(), client, tls.as_ref(), false).await? {
HandshakeData::Startup(stream, _) => stream,
HandshakeData::Cancel(_) => bail!("cancellation not supported"),
};
auth.authenticate(&mut stream).await?;
@@ -457,7 +458,7 @@ impl ConnectMechanism for TestConnectMechanism {
async fn connect_once(
&self,
_ctx: &mut RequestMonitoring,
_ctx: &RequestMonitoring,
_node_info: &console::CachedNodeInfo,
_timeout: std::time::Duration,
) -> Result<Self::Connection, Self::ConnectError> {
@@ -565,7 +566,7 @@ fn helper_create_connect_info(
async fn connect_to_compute_success() {
let _ = env_logger::try_init();
use ConnectAction::*;
let mut ctx = RequestMonitoring::test();
let ctx = RequestMonitoring::test();
let mechanism = TestConnectMechanism::new(vec![Wake, Connect]);
let user_info = helper_create_connect_info(&mechanism);
let config = RetryConfig {
@@ -573,7 +574,7 @@ async fn connect_to_compute_success() {
max_retries: 5,
backoff_factor: 2.0,
};
connect_to_compute(&mut ctx, &mechanism, &user_info, false, config, config)
connect_to_compute(&ctx, &mechanism, &user_info, false, config, config)
.await
.unwrap();
mechanism.verify();
@@ -583,7 +584,7 @@ async fn connect_to_compute_success() {
async fn connect_to_compute_retry() {
let _ = env_logger::try_init();
use ConnectAction::*;
let mut ctx = RequestMonitoring::test();
let ctx = RequestMonitoring::test();
let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Connect]);
let user_info = helper_create_connect_info(&mechanism);
let config = RetryConfig {
@@ -591,7 +592,7 @@ async fn connect_to_compute_retry() {
max_retries: 5,
backoff_factor: 2.0,
};
connect_to_compute(&mut ctx, &mechanism, &user_info, false, config, config)
connect_to_compute(&ctx, &mechanism, &user_info, false, config, config)
.await
.unwrap();
mechanism.verify();
@@ -602,7 +603,7 @@ async fn connect_to_compute_retry() {
async fn connect_to_compute_non_retry_1() {
let _ = env_logger::try_init();
use ConnectAction::*;
let mut ctx = RequestMonitoring::test();
let ctx = RequestMonitoring::test();
let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Fail]);
let user_info = helper_create_connect_info(&mechanism);
let config = RetryConfig {
@@ -610,7 +611,7 @@ async fn connect_to_compute_non_retry_1() {
max_retries: 5,
backoff_factor: 2.0,
};
connect_to_compute(&mut ctx, &mechanism, &user_info, false, config, config)
connect_to_compute(&ctx, &mechanism, &user_info, false, config, config)
.await
.unwrap_err();
mechanism.verify();
@@ -621,7 +622,7 @@ async fn connect_to_compute_non_retry_1() {
async fn connect_to_compute_non_retry_2() {
let _ = env_logger::try_init();
use ConnectAction::*;
let mut ctx = RequestMonitoring::test();
let ctx = RequestMonitoring::test();
let mechanism = TestConnectMechanism::new(vec![Wake, Fail, Wake, Connect]);
let user_info = helper_create_connect_info(&mechanism);
let config = RetryConfig {
@@ -629,7 +630,7 @@ async fn connect_to_compute_non_retry_2() {
max_retries: 5,
backoff_factor: 2.0,
};
connect_to_compute(&mut ctx, &mechanism, &user_info, false, config, config)
connect_to_compute(&ctx, &mechanism, &user_info, false, config, config)
.await
.unwrap();
mechanism.verify();
@@ -641,7 +642,7 @@ async fn connect_to_compute_non_retry_3() {
let _ = env_logger::try_init();
tokio::time::pause();
use ConnectAction::*;
let mut ctx = RequestMonitoring::test();
let ctx = RequestMonitoring::test();
let mechanism =
TestConnectMechanism::new(vec![Wake, Retry, Wake, Retry, Retry, Retry, Retry, Retry]);
let user_info = helper_create_connect_info(&mechanism);
@@ -656,7 +657,7 @@ async fn connect_to_compute_non_retry_3() {
backoff_factor: 2.0,
};
connect_to_compute(
&mut ctx,
&ctx,
&mechanism,
&user_info,
false,
@@ -673,7 +674,7 @@ async fn connect_to_compute_non_retry_3() {
async fn wake_retry() {
let _ = env_logger::try_init();
use ConnectAction::*;
let mut ctx = RequestMonitoring::test();
let ctx = RequestMonitoring::test();
let mechanism = TestConnectMechanism::new(vec![WakeRetry, Wake, Connect]);
let user_info = helper_create_connect_info(&mechanism);
let config = RetryConfig {
@@ -681,7 +682,7 @@ async fn wake_retry() {
max_retries: 5,
backoff_factor: 2.0,
};
connect_to_compute(&mut ctx, &mechanism, &user_info, false, config, config)
connect_to_compute(&ctx, &mechanism, &user_info, false, config, config)
.await
.unwrap();
mechanism.verify();
@@ -692,7 +693,7 @@ async fn wake_retry() {
async fn wake_non_retry() {
let _ = env_logger::try_init();
use ConnectAction::*;
let mut ctx = RequestMonitoring::test();
let ctx = RequestMonitoring::test();
let mechanism = TestConnectMechanism::new(vec![WakeRetry, WakeFail]);
let user_info = helper_create_connect_info(&mechanism);
let config = RetryConfig {
@@ -700,7 +701,7 @@ async fn wake_non_retry() {
max_retries: 5,
backoff_factor: 2.0,
};
connect_to_compute(&mut ctx, &mechanism, &user_info, false, config, config)
connect_to_compute(&ctx, &mechanism, &user_info, false, config, config)
.await
.unwrap_err();
mechanism.verify();

View File

@@ -34,9 +34,14 @@ async fn proxy_mitm(
tokio::spawn(async move {
// begin handshake with end_server
let end_server = connect_tls(server2, client_config2.make_tls_connect().unwrap()).await;
let (end_client, startup) = match handshake(client1, Some(&server_config1), false)
.await
.unwrap()
let (end_client, startup) = match handshake(
&RequestMonitoring::test(),
client1,
Some(&server_config1),
false,
)
.await
.unwrap()
{
HandshakeData::Startup(stream, params) => (stream, params),
HandshakeData::Cancel(_) => panic!("cancellation not supported"),

View File

@@ -14,7 +14,7 @@ use super::connect_compute::ComputeConnectBackend;
pub async fn wake_compute<B: ComputeConnectBackend>(
num_retries: &mut u32,
ctx: &mut RequestMonitoring,
ctx: &RequestMonitoring,
api: &B,
config: RetryConfig,
) -> Result<CachedNodeInfo, WakeComputeError> {
@@ -52,9 +52,7 @@ pub async fn wake_compute<B: ComputeConnectBackend>(
let wait_duration = retry_after(*num_retries, config);
*num_retries += 1;
let pause = ctx
.latency_timer
.pause(crate::metrics::Waiting::RetryTimeout);
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::RetryTimeout);
tokio::time::sleep(wait_duration).await;
drop(pause);
}

View File

@@ -334,7 +334,7 @@ async fn request_handler(
&config.region,
);
let span = ctx.span.clone();
let span = ctx.span();
info!(parent: &span, "performing websocket upgrade");
let (response, websocket) = framed_websockets::upgrade::upgrade(&mut request)
@@ -367,7 +367,7 @@ async fn request_handler(
crate::metrics::Protocol::Http,
&config.region,
);
let span = ctx.span.clone();
let span = ctx.span();
sql_over_http::handle(config, ctx, request, backend, http_cancellation_token)
.instrument(span)

View File

@@ -35,15 +35,15 @@ pub struct PoolingBackend {
impl PoolingBackend {
pub async fn authenticate(
&self,
ctx: &mut RequestMonitoring,
ctx: &RequestMonitoring,
config: &AuthenticationConfig,
conn_info: &ConnInfo,
) -> Result<ComputeCredentials, AuthError> {
let user_info = conn_info.user_info.clone();
let backend = self.config.auth_backend.as_ref().map(|_| user_info.clone());
let (allowed_ips, maybe_secret) = backend.get_allowed_ips_and_secret(ctx).await?;
if !check_peer_addr_is_in_list(&ctx.peer_addr, &allowed_ips) {
return Err(AuthError::ip_address_not_allowed(ctx.peer_addr));
if !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips) {
return Err(AuthError::ip_address_not_allowed(ctx.peer_addr()));
}
if !self
.endpoint_rate_limiter
@@ -100,7 +100,7 @@ impl PoolingBackend {
#[tracing::instrument(fields(pid = tracing::field::Empty), skip_all)]
pub async fn connect_to_compute(
&self,
ctx: &mut RequestMonitoring,
ctx: &RequestMonitoring,
conn_info: ConnInfo,
keys: ComputeCredentials,
force_new: bool,
@@ -222,7 +222,7 @@ impl ConnectMechanism for TokioMechanism {
async fn connect_once(
&self,
ctx: &mut RequestMonitoring,
ctx: &RequestMonitoring,
node_info: &CachedNodeInfo,
timeout: Duration,
) -> Result<Self::Connection, Self::ConnectError> {
@@ -240,7 +240,7 @@ impl ConnectMechanism for TokioMechanism {
.param("client_encoding", "UTF8")
.expect("client encoding UTF8 is always valid");
let pause = ctx.latency_timer.pause(crate::metrics::Waiting::Compute);
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let res = config.connect(tokio_postgres::NoTls).await;
drop(pause);
let (client, connection) = permit.release_result(res)?;

View File

@@ -377,7 +377,7 @@ impl<C: ClientInnerExt> GlobalConnPool<C> {
pub fn get(
self: &Arc<Self>,
ctx: &mut RequestMonitoring,
ctx: &RequestMonitoring,
conn_info: &ConnInfo,
) -> Result<Option<Client<C>>, HttpConnError> {
let mut client: Option<ClientInner<C>> = None;
@@ -409,9 +409,9 @@ impl<C: ClientInnerExt> GlobalConnPool<C> {
cold_start_info = ColdStartInfo::HttpPoolHit.as_str(),
"pool: reusing connection '{conn_info}'"
);
client.session.send(ctx.session_id)?;
client.session.send(ctx.session_id())?;
ctx.set_cold_start_info(ColdStartInfo::HttpPoolHit);
ctx.latency_timer.success();
ctx.success();
return Ok(Some(Client::new(client, conn_info.clone(), endpoint_pool)));
}
}
@@ -465,19 +465,19 @@ impl<C: ClientInnerExt> GlobalConnPool<C> {
pub fn poll_client<C: ClientInnerExt>(
global_pool: Arc<GlobalConnPool<C>>,
ctx: &mut RequestMonitoring,
ctx: &RequestMonitoring,
conn_info: ConnInfo,
client: C,
mut connection: tokio_postgres::Connection<Socket, NoTlsStream>,
conn_id: uuid::Uuid,
aux: MetricsAuxInfo,
) -> Client<C> {
let conn_gauge = Metrics::get().proxy.db_connections.guard(ctx.protocol);
let mut session_id = ctx.session_id;
let conn_gauge = Metrics::get().proxy.db_connections.guard(ctx.protocol());
let mut session_id = ctx.session_id();
let (tx, mut rx) = tokio::sync::watch::channel(session_id);
let span = info_span!(parent: None, "connection", %conn_id);
let cold_start_info = ctx.cold_start_info;
let cold_start_info = ctx.cold_start_info();
span.in_scope(|| {
info!(cold_start_info = cold_start_info.as_str(), %conn_info, %session_id, "new connection");
});
@@ -766,7 +766,6 @@ mod tests {
opt_in: false,
max_total_conns: 3,
},
request_timeout: Duration::from_secs(1),
cancel_set: CancelSet::new(0),
client_conn_threshold: u64::MAX,
}));

View File

@@ -144,7 +144,7 @@ impl UserFacingError for ConnInfoError {
}
fn get_conn_info(
ctx: &mut RequestMonitoring,
ctx: &RequestMonitoring,
headers: &HeaderMap,
tls: &TlsConfig,
) -> Result<ConnInfo, ConnInfoError> {
@@ -224,12 +224,12 @@ fn get_conn_info(
// TODO: return different http error codes
pub async fn handle(
config: &'static ProxyConfig,
mut ctx: RequestMonitoring,
ctx: RequestMonitoring,
request: Request<Incoming>,
backend: Arc<PoolingBackend>,
cancel: CancellationToken,
) -> Result<Response<Full<Bytes>>, ApiError> {
let result = handle_inner(cancel, config, &mut ctx, request, backend).await;
let result = handle_inner(cancel, config, &ctx, request, backend).await;
let mut response = match result {
Ok(r) => {
@@ -482,13 +482,16 @@ fn map_isolation_level_to_headers(level: IsolationLevel) -> Option<HeaderValue>
async fn handle_inner(
cancel: CancellationToken,
config: &'static ProxyConfig,
ctx: &mut RequestMonitoring,
ctx: &RequestMonitoring,
request: Request<Incoming>,
backend: Arc<PoolingBackend>,
) -> Result<Response<Full<Bytes>>, SqlOverHttpError> {
let _requeset_gauge = Metrics::get().proxy.connection_requests.guard(ctx.protocol);
let _requeset_gauge = Metrics::get()
.proxy
.connection_requests
.guard(ctx.protocol());
info!(
protocol = %ctx.protocol,
protocol = %ctx.protocol(),
"handling interactive connection from client"
);
@@ -544,7 +547,7 @@ async fn handle_inner(
.await?;
// not strictly necessary to mark success here,
// but it's just insurance for if we forget it somewhere else
ctx.latency_timer.success();
ctx.success();
Ok::<_, HttpConnError>(client)
}
.map_err(SqlOverHttpError::from),

View File

@@ -129,7 +129,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncBufRead for WebSocketRw<S> {
pub async fn serve_websocket(
config: &'static ProxyConfig,
mut ctx: RequestMonitoring,
ctx: RequestMonitoring,
websocket: OnUpgrade,
cancellation_handler: Arc<CancellationHandlerMain>,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
@@ -145,7 +145,7 @@ pub async fn serve_websocket(
let res = Box::pin(handle_client(
config,
&mut ctx,
&ctx,
cancellation_handler,
WebSocketRw::new(websocket),
ClientMode::Websockets { hostname },