Merge branch 'main' into devin/1745492468-add-dev-flag-pr11517

Authored by John Spray on 2025-06-02 11:43:31 +01:00; committed by GitHub.
53 changed files with 1708 additions and 1995 deletions

View File

@@ -310,13 +310,13 @@ RUN curl -sSO https://static.rust-lang.org/rustup/dist/$(uname -m)-unknown-linux
. "$HOME/.cargo/env" && \
cargo --version && rustup --version && \
rustup component add llvm-tools rustfmt clippy && \
cargo install rustfilt --version ${RUSTFILT_VERSION} && \
cargo install cargo-hakari --version ${CARGO_HAKARI_VERSION} && \
cargo install cargo-deny --locked --version ${CARGO_DENY_VERSION} && \
cargo install cargo-hack --version ${CARGO_HACK_VERSION} && \
cargo install cargo-nextest --version ${CARGO_NEXTEST_VERSION} && \
cargo install cargo-chef --locked --version ${CARGO_CHEF_VERSION} && \
cargo install diesel_cli --version ${CARGO_DIESEL_CLI_VERSION} \
cargo install rustfilt --version ${RUSTFILT_VERSION} --locked && \
cargo install cargo-hakari --version ${CARGO_HAKARI_VERSION} --locked && \
cargo install cargo-deny --version ${CARGO_DENY_VERSION} --locked && \
cargo install cargo-hack --version ${CARGO_HACK_VERSION} --locked && \
cargo install cargo-nextest --version ${CARGO_NEXTEST_VERSION} --locked && \
cargo install cargo-chef --version ${CARGO_CHEF_VERSION} --locked && \
cargo install diesel_cli --version ${CARGO_DIESEL_CLI_VERSION} --locked \
--features postgres-bundled --no-default-features && \
rm -rf /home/nonroot/.cargo/registry && \
rm -rf /home/nonroot/.cargo/git
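For context, `--locked` makes `cargo install` build each tool against the `Cargo.lock` shipped with its release instead of re-resolving dependencies, so the builder image gets reproducible tool versions for every binary, not just the ones that already passed the flag.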

View File

@@ -57,21 +57,6 @@ use tracing::{error, info};
use url::Url;
use utils::failpoint_support;
// Compatibility hack: if the control plane passed a remote-ext-config value that is
// not a URL, fall back to the default extension storage proxy gateway.
// Remove this once the control plane is updated to pass the gateway URL
fn parse_remote_ext_base_url(arg: &str) -> Result<String> {
const FALLBACK_PG_EXT_GATEWAY_BASE_URL: &str =
"http://pg-ext-s3-gateway.pg-ext-s3-gateway.svc.cluster.local";
Ok(if arg.starts_with("http") {
arg
} else {
FALLBACK_PG_EXT_GATEWAY_BASE_URL
}
.to_owned())
}
#[derive(Parser)]
#[command(rename_all = "kebab-case")]
struct Cli {
@@ -79,9 +64,8 @@ struct Cli {
pub pgbin: String,
/// The base URL for the remote extension storage proxy gateway.
/// Should be in the form of `http(s)://<gateway-hostname>[:<port>]`.
#[arg(short = 'r', long, value_parser = parse_remote_ext_base_url, alias = "remote-ext-config")]
pub remote_ext_base_url: Option<String>,
#[arg(short = 'r', long)]
pub remote_ext_base_url: Option<Url>,
/// The port to bind the external listening HTTP server to. Clients running
/// outside the compute will talk to the compute through this port. Keep
@@ -276,18 +260,4 @@ mod test {
fn verify_cli() {
Cli::command().debug_assert()
}
#[test]
fn parse_pg_ext_gateway_base_url() {
let arg = "http://pg-ext-s3-gateway2";
let result = super::parse_remote_ext_base_url(arg).unwrap();
assert_eq!(result, arg);
let arg = "pg-ext-s3-gateway";
let result = super::parse_remote_ext_base_url(arg).unwrap();
assert_eq!(
result,
"http://pg-ext-s3-gateway.pg-ext-s3-gateway.svc.cluster.local"
);
}
}
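For context, dropping the custom parser works because clap derives a value parser for any argument type whose `FromStr` error implements `std::error::Error`, which `url::Url` satisfies. A minimal, self-contained sketch of the new argument shape (only this one field is shown; everything else about the real `Cli` struct is omitted):

```rust
use clap::Parser;
use url::Url;

#[derive(Parser)]
#[command(rename_all = "kebab-case")]
struct Cli {
    /// Base URL of the remote extension storage proxy gateway,
    /// e.g. `http(s)://<gateway-hostname>[:<port>]`.
    #[arg(short = 'r', long)]
    remote_ext_base_url: Option<Url>,
}

fn main() {
    // Well-formed URLs parse; anything else is rejected by clap at startup,
    // replacing the old "fall back to a hard-coded gateway" behaviour.
    let cli = Cli::parse_from(["compute_ctl", "--remote-ext-base-url", "http://pg-ext-s3-gateway2"]);
    assert_eq!(
        cli.remote_ext_base_url.unwrap().as_str(),
        "http://pg-ext-s3-gateway2/"
    );
}
```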

View File

@@ -31,6 +31,7 @@ use std::time::{Duration, Instant};
use std::{env, fs};
use tokio::spawn;
use tracing::{Instrument, debug, error, info, instrument, warn};
use url::Url;
use utils::id::{TenantId, TimelineId};
use utils::lsn::Lsn;
use utils::measured_stream::MeasuredReader;
@@ -96,7 +97,7 @@ pub struct ComputeNodeParams {
pub internal_http_port: u16,
/// the address of extension storage proxy gateway
pub remote_ext_base_url: Option<String>,
pub remote_ext_base_url: Option<Url>,
/// Interval for installed extensions collection
pub installed_extensions_collection_interval: u64,

View File

@@ -83,6 +83,7 @@ use reqwest::StatusCode;
use tar::Archive;
use tracing::info;
use tracing::log::warn;
use url::Url;
use zstd::stream::read::Decoder;
use crate::metrics::{REMOTE_EXT_REQUESTS_TOTAL, UNKNOWN_HTTP_STATUS};
@@ -158,14 +159,14 @@ fn parse_pg_version(human_version: &str) -> PostgresMajorVersion {
pub async fn download_extension(
ext_name: &str,
ext_path: &RemotePath,
remote_ext_base_url: &str,
remote_ext_base_url: &Url,
pgbin: &str,
) -> Result<u64> {
info!("Download extension {:?} from {:?}", ext_name, ext_path);
// TODO add retry logic
let download_buffer =
match download_extension_tar(remote_ext_base_url, &ext_path.to_string()).await {
match download_extension_tar(remote_ext_base_url.as_str(), &ext_path.to_string()).await {
Ok(buffer) => buffer,
Err(error_message) => {
return Err(anyhow::anyhow!(

View File

@@ -107,7 +107,7 @@ impl<const N: usize> MetricType for HyperLogLogState<N> {
}
impl<const N: usize> HyperLogLogState<N> {
pub fn measure(&self, item: &impl Hash) {
pub fn measure(&self, item: &(impl Hash + ?Sized)) {
// changing the hasher will break compatibility with previous measurements.
self.record(BuildHasherDefault::<xxh3::Hash64>::default().hash_one(item));
}
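The `?Sized` relaxation matters because `&T` implements `Hash` whenever `T: Hash + ?Sized`, so callers can now pass unsized values such as `str` directly instead of going through an owned `String`. A standalone sketch of the same pattern, using the standard-library hasher rather than the xxh3 one above:

```rust
use std::collections::hash_map::DefaultHasher;
use std::hash::{BuildHasher, BuildHasherDefault, Hash};

// Sketch of the relaxed bound: `?Sized` lets unsized values such as `str`
// be passed by reference, since `&T: Hash` whenever `T: Hash + ?Sized`.
fn measure(item: &(impl Hash + ?Sized)) -> u64 {
    BuildHasherDefault::<DefaultHasher>::default().hash_one(item)
}

fn main() {
    let endpoint = String::from("ep-example-123456");
    // Previously this needed `&endpoint` (a `&String`); now a plain `&str` works too.
    let a = measure(endpoint.as_str());
    let b = measure(&endpoint);
    // `String` forwards its Hash impl to `str`, so both calls hash the same bytes.
    assert_eq!(a, b);
}
```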

View File

@@ -713,9 +713,9 @@ impl Default for ConfigToml {
enable_tls_page_service_api: false,
dev_mode: false,
timeline_import_config: TimelineImportConfig {
import_job_concurrency: NonZeroUsize::new(128).unwrap(),
import_job_soft_size_limit: NonZeroUsize::new(1024 * 1024 * 1024).unwrap(),
import_job_checkpoint_threshold: NonZeroUsize::new(128).unwrap(),
import_job_concurrency: NonZeroUsize::new(32).unwrap(),
import_job_soft_size_limit: NonZeroUsize::new(256 * 1024 * 1024).unwrap(),
import_job_checkpoint_threshold: NonZeroUsize::new(32).unwrap(),
},
basebackup_cache_config: None,
posthog_config: None,
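Assuming the soft size limit applies per import job, the new defaults cap in-flight import data at roughly 32 × 256 MiB = 8 GiB, versus 128 × 1 GiB = 128 GiB with the old values, and checkpoint progress every 32 completed jobs instead of every 128.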

View File

@@ -28,6 +28,7 @@ use std::time::Duration;
use tokio::sync::Notify;
use tokio::time::Instant;
#[derive(Clone, Copy)]
pub struct LeakyBucketConfig {
/// This is the "time cost" of a single request unit.
/// Should loosely represent how long it takes to handle a request unit in active resource time.

View File

@@ -106,6 +106,8 @@ pub async fn doit(
);
}
tracing::info!("Import plan executed. Flushing remote changes and notifying storcon");
timeline
.remote_client
.schedule_index_upload_for_file_changes()?;

View File

@@ -130,7 +130,15 @@ async fn run_v1(
pausable_failpoint!("import-timeline-pre-execute-pausable");
let jobs_count = import_progress.as_ref().map(|p| p.jobs);
let start_from_job_idx = import_progress.map(|progress| progress.completed);
tracing::info!(
start_from_job_idx=?start_from_job_idx,
jobs=?jobs_count,
"Executing import plan"
);
plan.execute(timeline, start_from_job_idx, plan_hash, &import_config, ctx)
.await
}
@@ -484,6 +492,8 @@ impl Plan {
anyhow::anyhow!("Shut down while putting timeline import status")
})?;
}
tracing::info!(last_completed_job_idx, jobs=%jobs_in_plan, "Checkpointing import status");
},
Some(Err(_)) => {
anyhow::bail!(
@@ -760,7 +770,7 @@ impl ImportTask for ImportRelBlocksTask {
layer_writer: &mut ImageLayerWriter,
ctx: &RequestContext,
) -> anyhow::Result<usize> {
const MAX_BYTE_RANGE_SIZE: usize = 128 * 1024 * 1024;
const MAX_BYTE_RANGE_SIZE: usize = 4 * 1024 * 1024;
debug!("Importing relation file");

View File

@@ -113,7 +113,7 @@ impl WalReceiver {
}
connection_manager_state.shutdown().await;
*loop_status.write().unwrap() = None;
debug!("task exits");
info!("task exits");
}
.instrument(info_span!(parent: None, "wal_connection_manager", tenant_id = %tenant_shard_id.tenant_id, shard_id = %tenant_shard_id.shard_slug(), timeline_id = %timeline_id))
});

View File

@@ -297,6 +297,7 @@ pub(super) async fn handle_walreceiver_connection(
let mut expected_wal_start = startpoint;
while let Some(replication_message) = {
select! {
biased;
_ = cancellation.cancelled() => {
debug!("walreceiver interrupted");
None
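For readers unfamiliar with it, `biased;` makes tokio's `select!` poll branches in source order rather than in a random order, so cancellation is always observed before another replication message is awaited. A small sketch of the pattern with illustrative names (not the actual walreceiver types):

```rust
use tokio::select;
use tokio::sync::mpsc::Receiver;
use tokio_util::sync::CancellationToken;

// Sketch only: poll the cancellation branch first on every loop iteration.
async fn recv_loop(cancel: CancellationToken, mut rx: Receiver<Vec<u8>>) {
    loop {
        select! {
            biased;
            _ = cancel.cancelled() => {
                // Mirrors the "walreceiver interrupted" branch above.
                break;
            }
            msg = rx.recv() => match msg {
                Some(_wal) => { /* process the replication message */ }
                None => break, // sender dropped
            },
        }
    }
}
```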

View File

@@ -17,35 +17,27 @@ pub(super) async fn authenticate(
config: &'static AuthenticationConfig,
secret: AuthSecret,
) -> auth::Result<ComputeCredentials> {
let flow = AuthFlow::new(client);
let scram_keys = match secret {
#[cfg(any(test, feature = "testing"))]
AuthSecret::Md5(_) => {
debug!("auth endpoint chooses MD5");
return Err(auth::AuthError::bad_auth_method("MD5"));
return Err(auth::AuthError::MalformedPassword("MD5 not supported"));
}
AuthSecret::Scram(secret) => {
debug!("auth endpoint chooses SCRAM");
let scram = auth::Scram(&secret, ctx);
let auth_outcome = tokio::time::timeout(
config.scram_protocol_timeout,
async {
flow.begin(scram).await.map_err(|error| {
warn!(?error, "error sending scram acknowledgement");
error
})?.authenticate().await.map_err(|error| {
let auth_outcome = tokio::time::timeout(config.scram_protocol_timeout, async {
AuthFlow::new(client, scram)
.authenticate()
.await
.inspect_err(|error| {
warn!(?error, "error processing scram messages");
error
})
}
)
})
.await
.map_err(|e| {
warn!("error processing scram messages error = authentication timed out, execution time exceeded {} seconds", config.scram_protocol_timeout.as_secs());
auth::AuthError::user_timeout(e)
})??;
.inspect_err(|_| warn!("error processing scram messages error = authentication timed out, execution time exceeded {} seconds", config.scram_protocol_timeout.as_secs()))
.map_err(auth::AuthError::user_timeout)??;
let client_key = match auth_outcome {
sasl::Outcome::Success(key) => key,

View File

@@ -2,7 +2,6 @@ use std::fmt;
use async_trait::async_trait;
use postgres_client::config::SslMode;
use pq_proto::BeMessage as Be;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{info, info_span};
@@ -16,6 +15,7 @@ use crate::context::RequestContext;
use crate::control_plane::client::cplane_proxy_v1;
use crate::control_plane::{self, CachedNodeInfo, NodeInfo};
use crate::error::{ReportableError, UserFacingError};
use crate::pqproto::BeMessage;
use crate::proxy::NeonOptions;
use crate::proxy::connect_compute::ComputeConnectBackend;
use crate::stream::PqStream;
@@ -154,11 +154,13 @@ async fn authenticate(
// Give user a URL to spawn a new database.
info!(parent: &span, "sending the auth URL to the user");
client
.write_message_noflush(&Be::AuthenticationOk)?
.write_message_noflush(&Be::CLIENT_ENCODING)?
.write_message(&Be::NoticeResponse(&greeting))
.await?;
client.write_message(BeMessage::AuthenticationOk);
client.write_message(BeMessage::ParameterStatus {
name: b"client_encoding",
value: b"UTF8",
});
client.write_message(BeMessage::NoticeResponse(&greeting));
client.flush().await?;
// Wait for console response via control plane (see `mgmt`).
info!(parent: &span, "waiting for console's reply...");
@@ -188,7 +190,7 @@ async fn authenticate(
}
}
client.write_message_noflush(&Be::NoticeResponse("Connecting to database."))?;
client.write_message(BeMessage::NoticeResponse("Connecting to database."));
// This config should be self-contained, because we won't
// take username or dbname from client's startup message.

View File

@@ -24,23 +24,25 @@ pub(crate) async fn authenticate_cleartext(
debug!("cleartext auth flow override is enabled, proceeding");
ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
// pause the timer while we communicate with the client
let paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
let ep = EndpointIdInt::from(&info.endpoint);
let auth_flow = AuthFlow::new(client)
.begin(auth::CleartextPassword {
let auth_flow = AuthFlow::new(
client,
auth::CleartextPassword {
secret,
endpoint: ep,
pool: config.thread_pool.clone(),
})
.await?;
drop(paused);
// cleartext auth is only allowed to the ws/http protocol.
// If we're here, we already received the password in the first message.
// Scram protocol will be executed on the proxy side.
let auth_outcome = auth_flow.authenticate().await?;
},
);
let auth_outcome = {
// pause the timer while we communicate with the client
let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
// cleartext auth is only allowed to the ws/http protocol.
// If we're here, we already received the password in the first message.
// Scram protocol will be executed on the proxy side.
auth_flow.authenticate().await?
};
let keys = match auth_outcome {
sasl::Outcome::Success(key) => key,
@@ -67,9 +69,7 @@ pub(crate) async fn password_hack_no_authentication(
// pause the timer while we communicate with the client
let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
let payload = AuthFlow::new(client)
.begin(auth::PasswordHack)
.await?
let payload = AuthFlow::new(client, auth::PasswordHack)
.get_password()
.await?;

View File

@@ -4,37 +4,31 @@ mod hacks;
pub mod jwt;
pub mod local;
use std::net::IpAddr;
use std::sync::Arc;
pub use console_redirect::ConsoleRedirectBackend;
pub(crate) use console_redirect::ConsoleRedirectError;
use ipnet::{Ipv4Net, Ipv6Net};
use local::LocalBackend;
use postgres_client::config::AuthKeys;
use serde::{Deserialize, Serialize};
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{debug, info, warn};
use tracing::{debug, info};
use crate::auth::credentials::check_peer_addr_is_in_list;
use crate::auth::{
self, AuthError, ComputeUserInfoMaybeEndpoint, IpPattern, validate_password_and_exchange,
};
use crate::auth::{self, AuthError, ComputeUserInfoMaybeEndpoint, validate_password_and_exchange};
use crate::cache::Cached;
use crate::config::AuthenticationConfig;
use crate::context::RequestContext;
use crate::control_plane::client::ControlPlaneClient;
use crate::control_plane::errors::GetAuthInfoError;
use crate::control_plane::{
self, AccessBlockerFlags, AuthSecret, CachedAccessBlockerFlags, CachedAllowedIps,
CachedAllowedVpcEndpointIds, CachedNodeInfo, CachedRoleSecret, ControlPlaneApi,
self, AccessBlockerFlags, AuthSecret, CachedNodeInfo, ControlPlaneApi, EndpointAccessControl,
RoleAccessControl,
};
use crate::intern::EndpointIdInt;
use crate::metrics::Metrics;
use crate::protocol2::ConnectionInfoExtra;
use crate::pqproto::BeMessage;
use crate::proxy::NeonOptions;
use crate::proxy::connect_compute::ComputeConnectBackend;
use crate::rate_limiter::{BucketRateLimiter, EndpointRateLimiter};
use crate::rate_limiter::EndpointRateLimiter;
use crate::stream::Stream;
use crate::types::{EndpointCacheKey, EndpointId, RoleName};
use crate::{scram, stream};
@@ -200,78 +194,6 @@ impl TryFrom<ComputeUserInfoMaybeEndpoint> for ComputeUserInfo {
}
}
#[derive(PartialEq, PartialOrd, Hash, Eq, Ord, Debug, Copy, Clone)]
pub struct MaskedIp(IpAddr);
impl MaskedIp {
fn new(value: IpAddr, prefix: u8) -> Self {
match value {
IpAddr::V4(v4) => Self(IpAddr::V4(
Ipv4Net::new(v4, prefix).map_or(v4, |x| x.trunc().addr()),
)),
IpAddr::V6(v6) => Self(IpAddr::V6(
Ipv6Net::new(v6, prefix).map_or(v6, |x| x.trunc().addr()),
)),
}
}
}
// This can't be just per IP because that would limit some PaaS that share IP addresses
pub type AuthRateLimiter = BucketRateLimiter<(EndpointIdInt, MaskedIp)>;
impl AuthenticationConfig {
pub(crate) fn check_rate_limit(
&self,
ctx: &RequestContext,
secret: AuthSecret,
endpoint: &EndpointId,
is_cleartext: bool,
) -> auth::Result<AuthSecret> {
// we have validated the endpoint exists, so let's intern it.
let endpoint_int = EndpointIdInt::from(endpoint.normalize());
// only count the full hash count if password hack or websocket flow.
// in other words, if proxy needs to run the hashing
let password_weight = if is_cleartext {
match &secret {
#[cfg(any(test, feature = "testing"))]
AuthSecret::Md5(_) => 1,
AuthSecret::Scram(s) => s.iterations + 1,
}
} else {
// validating scram takes just 1 hmac_sha_256 operation.
1
};
let limit_not_exceeded = self.rate_limiter.check(
(
endpoint_int,
MaskedIp::new(ctx.peer_addr(), self.rate_limit_ip_subnet),
),
password_weight,
);
if !limit_not_exceeded {
warn!(
enabled = self.rate_limiter_enabled,
"rate limiting authentication"
);
Metrics::get().proxy.requests_auth_rate_limits_total.inc();
Metrics::get()
.proxy
.endpoints_auth_rate_limits
.get_metric()
.measure(endpoint);
if self.rate_limiter_enabled {
return Err(auth::AuthError::too_many_connections());
}
}
Ok(secret)
}
}
/// True to its name, this function encapsulates our current auth trade-offs.
/// Here, we choose the appropriate auth flow based on circumstances.
///
@@ -284,7 +206,7 @@ async fn auth_quirks(
allow_cleartext: bool,
config: &'static AuthenticationConfig,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> auth::Result<(ComputeCredentials, Option<Vec<IpPattern>>)> {
) -> auth::Result<ComputeCredentials> {
// If there's no project so far, that entails that client doesn't
// support SNI or other means of passing the endpoint (project) name.
// We now expect to see a very specific payload in the place of password.
@@ -300,55 +222,27 @@ async fn auth_quirks(
debug!("fetching authentication info and allowlists");
// check allowed list
let allowed_ips = if config.ip_allowlist_check_enabled {
let allowed_ips = api.get_allowed_ips(ctx, &info).await?;
if !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips) {
return Err(auth::AuthError::ip_address_not_allowed(ctx.peer_addr()));
}
allowed_ips
} else {
Cached::new_uncached(Arc::new(vec![]))
};
let access_controls = api
.get_endpoint_access_control(ctx, &info.endpoint, &info.user)
.await?;
// check if a VPC endpoint ID is coming in and if yes, if it's allowed
let access_blocks = api.get_block_public_or_vpc_access(ctx, &info).await?;
if config.is_vpc_acccess_proxy {
if access_blocks.vpc_access_blocked {
return Err(AuthError::NetworkNotAllowed);
}
access_controls.check(
ctx,
config.ip_allowlist_check_enabled,
config.is_vpc_acccess_proxy,
)?;
let incoming_vpc_endpoint_id = match ctx.extra() {
None => return Err(AuthError::MissingEndpointName),
Some(ConnectionInfoExtra::Aws { vpce_id }) => vpce_id.to_string(),
Some(ConnectionInfoExtra::Azure { link_id }) => link_id.to_string(),
};
let allowed_vpc_endpoint_ids = api.get_allowed_vpc_endpoint_ids(ctx, &info).await?;
// TODO: For now an empty VPC endpoint ID list means all are allowed. We should replace that.
if !allowed_vpc_endpoint_ids.is_empty()
&& !allowed_vpc_endpoint_ids.contains(&incoming_vpc_endpoint_id)
{
return Err(AuthError::vpc_endpoint_id_not_allowed(
incoming_vpc_endpoint_id,
));
}
} else if access_blocks.public_access_blocked {
return Err(AuthError::NetworkNotAllowed);
}
if !endpoint_rate_limiter.check(info.endpoint.clone().into(), 1) {
let endpoint = EndpointIdInt::from(&info.endpoint);
let rate_limit_config = None;
if !endpoint_rate_limiter.check(endpoint, rate_limit_config, 1) {
return Err(AuthError::too_many_connections());
}
let cached_secret = api.get_role_secret(ctx, &info).await?;
let (cached_entry, secret) = cached_secret.take_value();
let role_access = api
.get_role_access_control(ctx, &info.endpoint, &info.user)
.await?;
let secret = if let Some(secret) = secret {
config.check_rate_limit(
ctx,
secret,
&info.endpoint,
unauthenticated_password.is_some() || allow_cleartext,
)?
let secret = if let Some(secret) = role_access.secret {
secret
} else {
// If we don't have an authentication secret, we mock one to
// prevent malicious probing (possible due to missing protocol steps).
@@ -368,14 +262,8 @@ async fn auth_quirks(
)
.await
{
Ok(keys) => Ok((keys, Some(allowed_ips.as_ref().clone()))),
Err(e) => {
if e.is_password_failed() {
// The password could have been changed, so we invalidate the cache.
cached_entry.invalidate();
}
Err(e)
}
Ok(keys) => Ok(keys),
Err(e) => Err(e),
}
}
@@ -402,7 +290,7 @@ async fn authenticate_with_secret(
};
// we have authenticated the password
client.write_message_noflush(&pq_proto::BeMessage::AuthenticationOk)?;
client.write_message(BeMessage::AuthenticationOk);
return Ok(ComputeCredentials { info, keys });
}
@@ -438,7 +326,7 @@ impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint> {
allow_cleartext: bool,
config: &'static AuthenticationConfig,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> auth::Result<(Backend<'a, ComputeCredentials>, Option<Vec<IpPattern>>)> {
) -> auth::Result<Backend<'a, ComputeCredentials>> {
let res = match self {
Self::ControlPlane(api, user_info) => {
debug!(
@@ -447,17 +335,35 @@ impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint> {
"performing authentication using the console"
);
let (credentials, ip_allowlist) = auth_quirks(
let auth_res = auth_quirks(
ctx,
&*api,
user_info,
user_info.clone(),
client,
allow_cleartext,
config,
endpoint_rate_limiter,
)
.await?;
Ok((Backend::ControlPlane(api, credentials), ip_allowlist))
.await;
match auth_res {
Ok(credentials) => Ok(Backend::ControlPlane(api, credentials)),
Err(e) => {
// The password could have been changed, so we invalidate the cache.
// We should only invalidate the cache if the TTL might have expired.
if e.is_password_failed() {
#[allow(irrefutable_let_patterns)]
if let ControlPlaneClient::ProxyV1(api) = &*api {
if let Some(ep) = &user_info.endpoint_id {
api.caches
.project_info
.maybe_invalidate_role_secret(ep, &user_info.user);
}
}
}
Err(e)
}
}
}
Self::Local(_) => {
return Err(auth::AuthError::bad_auth_method("invalid for local proxy"));
@@ -474,44 +380,30 @@ impl Backend<'_, ComputeUserInfo> {
pub(crate) async fn get_role_secret(
&self,
ctx: &RequestContext,
) -> Result<CachedRoleSecret, GetAuthInfoError> {
match self {
Self::ControlPlane(api, user_info) => api.get_role_secret(ctx, user_info).await,
Self::Local(_) => Ok(Cached::new_uncached(None)),
}
}
pub(crate) async fn get_allowed_ips(
&self,
ctx: &RequestContext,
) -> Result<CachedAllowedIps, GetAuthInfoError> {
match self {
Self::ControlPlane(api, user_info) => api.get_allowed_ips(ctx, user_info).await,
Self::Local(_) => Ok(Cached::new_uncached(Arc::new(vec![]))),
}
}
pub(crate) async fn get_allowed_vpc_endpoint_ids(
&self,
ctx: &RequestContext,
) -> Result<CachedAllowedVpcEndpointIds, GetAuthInfoError> {
) -> Result<RoleAccessControl, GetAuthInfoError> {
match self {
Self::ControlPlane(api, user_info) => {
api.get_allowed_vpc_endpoint_ids(ctx, user_info).await
api.get_role_access_control(ctx, &user_info.endpoint, &user_info.user)
.await
}
Self::Local(_) => Ok(Cached::new_uncached(Arc::new(vec![]))),
Self::Local(_) => Ok(RoleAccessControl { secret: None }),
}
}
pub(crate) async fn get_block_public_or_vpc_access(
pub(crate) async fn get_endpoint_access_control(
&self,
ctx: &RequestContext,
) -> Result<CachedAccessBlockerFlags, GetAuthInfoError> {
) -> Result<EndpointAccessControl, GetAuthInfoError> {
match self {
Self::ControlPlane(api, user_info) => {
api.get_block_public_or_vpc_access(ctx, user_info).await
api.get_endpoint_access_control(ctx, &user_info.endpoint, &user_info.user)
.await
}
Self::Local(_) => Ok(Cached::new_uncached(AccessBlockerFlags::default())),
Self::Local(_) => Ok(EndpointAccessControl {
allowed_ips: Arc::new(vec![]),
allowed_vpce: Arc::new(vec![]),
flags: AccessBlockerFlags::default(),
}),
}
}
}
@@ -540,9 +432,7 @@ impl ComputeConnectBackend for Backend<'_, ComputeCredentials> {
mod tests {
#![allow(clippy::unimplemented, clippy::unwrap_used)]
use std::net::IpAddr;
use std::sync::Arc;
use std::time::Duration;
use bytes::BytesMut;
use control_plane::AuthSecret;
@@ -553,18 +443,16 @@ mod tests {
use postgres_protocol::message::frontend;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt};
use super::auth_quirks;
use super::jwt::JwkCache;
use super::{AuthRateLimiter, auth_quirks};
use crate::auth::backend::MaskedIp;
use crate::auth::{ComputeUserInfoMaybeEndpoint, IpPattern};
use crate::config::AuthenticationConfig;
use crate::context::RequestContext;
use crate::control_plane::{
self, AccessBlockerFlags, CachedAccessBlockerFlags, CachedAllowedIps,
CachedAllowedVpcEndpointIds, CachedNodeInfo, CachedRoleSecret,
self, AccessBlockerFlags, CachedNodeInfo, EndpointAccessControl, RoleAccessControl,
};
use crate::proxy::NeonOptions;
use crate::rate_limiter::{EndpointRateLimiter, RateBucketInfo};
use crate::rate_limiter::EndpointRateLimiter;
use crate::scram::ServerSecret;
use crate::scram::threadpool::ThreadPool;
use crate::stream::{PqStream, Stream};
@@ -577,46 +465,34 @@ mod tests {
}
impl control_plane::ControlPlaneApi for Auth {
async fn get_role_secret(
async fn get_role_access_control(
&self,
_ctx: &RequestContext,
_user_info: &super::ComputeUserInfo,
) -> Result<CachedRoleSecret, control_plane::errors::GetAuthInfoError> {
Ok(CachedRoleSecret::new_uncached(Some(self.secret.clone())))
_endpoint: &crate::types::EndpointId,
_role: &crate::types::RoleName,
) -> Result<RoleAccessControl, control_plane::errors::GetAuthInfoError> {
Ok(RoleAccessControl {
secret: Some(self.secret.clone()),
})
}
async fn get_allowed_ips(
async fn get_endpoint_access_control(
&self,
_ctx: &RequestContext,
_user_info: &super::ComputeUserInfo,
) -> Result<CachedAllowedIps, control_plane::errors::GetAuthInfoError> {
Ok(CachedAllowedIps::new_uncached(Arc::new(self.ips.clone())))
}
async fn get_allowed_vpc_endpoint_ids(
&self,
_ctx: &RequestContext,
_user_info: &super::ComputeUserInfo,
) -> Result<CachedAllowedVpcEndpointIds, control_plane::errors::GetAuthInfoError> {
Ok(CachedAllowedVpcEndpointIds::new_uncached(Arc::new(
self.vpc_endpoint_ids.clone(),
)))
}
async fn get_block_public_or_vpc_access(
&self,
_ctx: &RequestContext,
_user_info: &super::ComputeUserInfo,
) -> Result<CachedAccessBlockerFlags, control_plane::errors::GetAuthInfoError> {
Ok(CachedAccessBlockerFlags::new_uncached(
self.access_blocker_flags.clone(),
))
_endpoint: &crate::types::EndpointId,
_role: &crate::types::RoleName,
) -> Result<EndpointAccessControl, control_plane::errors::GetAuthInfoError> {
Ok(EndpointAccessControl {
allowed_ips: Arc::new(self.ips.clone()),
allowed_vpce: Arc::new(self.vpc_endpoint_ids.clone()),
flags: self.access_blocker_flags,
})
}
async fn get_endpoint_jwks(
&self,
_ctx: &RequestContext,
_endpoint: crate::types::EndpointId,
_endpoint: &crate::types::EndpointId,
) -> Result<Vec<super::jwt::AuthRule>, control_plane::errors::GetEndpointJwksError>
{
unimplemented!()
@@ -635,9 +511,6 @@ mod tests {
jwks_cache: JwkCache::default(),
thread_pool: ThreadPool::new(1),
scram_protocol_timeout: std::time::Duration::from_secs(5),
rate_limiter_enabled: true,
rate_limiter: AuthRateLimiter::new(&RateBucketInfo::DEFAULT_AUTH_SET),
rate_limit_ip_subnet: 64,
ip_allowlist_check_enabled: true,
is_vpc_acccess_proxy: false,
is_auth_broker: false,
@@ -654,55 +527,10 @@ mod tests {
}
}
#[test]
fn masked_ip() {
let ip_a = IpAddr::V4([127, 0, 0, 1].into());
let ip_b = IpAddr::V4([127, 0, 0, 2].into());
let ip_c = IpAddr::V4([192, 168, 1, 101].into());
let ip_d = IpAddr::V4([192, 168, 1, 102].into());
let ip_e = IpAddr::V6("abcd:abcd:abcd:abcd:abcd:abcd:abcd:abcd".parse().unwrap());
let ip_f = IpAddr::V6("abcd:abcd:abcd:abcd:1234:abcd:abcd:abcd".parse().unwrap());
assert_ne!(MaskedIp::new(ip_a, 64), MaskedIp::new(ip_b, 64));
assert_ne!(MaskedIp::new(ip_a, 32), MaskedIp::new(ip_b, 32));
assert_eq!(MaskedIp::new(ip_a, 30), MaskedIp::new(ip_b, 30));
assert_eq!(MaskedIp::new(ip_c, 30), MaskedIp::new(ip_d, 30));
assert_ne!(MaskedIp::new(ip_e, 128), MaskedIp::new(ip_f, 128));
assert_eq!(MaskedIp::new(ip_e, 64), MaskedIp::new(ip_f, 64));
}
#[test]
fn test_default_auth_rate_limit_set() {
// these values used to exceed u32::MAX
assert_eq!(
RateBucketInfo::DEFAULT_AUTH_SET,
[
RateBucketInfo {
interval: Duration::from_secs(1),
max_rpi: 1000 * 4096,
},
RateBucketInfo {
interval: Duration::from_secs(60),
max_rpi: 600 * 4096 * 60,
},
RateBucketInfo {
interval: Duration::from_secs(600),
max_rpi: 300 * 4096 * 600,
}
]
);
for x in RateBucketInfo::DEFAULT_AUTH_SET {
let y = x.to_string().parse().unwrap();
assert_eq!(x, y);
}
}
#[tokio::test]
async fn auth_quirks_scram() {
let (mut client, server) = tokio::io::duplex(1024);
let mut stream = PqStream::new(Stream::from_raw(server));
let mut stream = PqStream::new_skip_handshake(Stream::from_raw(server));
let ctx = RequestContext::test();
let api = Auth {
@@ -784,7 +612,7 @@ mod tests {
#[tokio::test]
async fn auth_quirks_cleartext() {
let (mut client, server) = tokio::io::duplex(1024);
let mut stream = PqStream::new(Stream::from_raw(server));
let mut stream = PqStream::new_skip_handshake(Stream::from_raw(server));
let ctx = RequestContext::test();
let api = Auth {
@@ -838,7 +666,7 @@ mod tests {
#[tokio::test]
async fn auth_quirks_password_hack() {
let (mut client, server) = tokio::io::duplex(1024);
let mut stream = PqStream::new(Stream::from_raw(server));
let mut stream = PqStream::new_skip_handshake(Stream::from_raw(server));
let ctx = RequestContext::test();
let api = Auth {
@@ -887,7 +715,7 @@ mod tests {
.await
.unwrap();
assert_eq!(creds.0.info.endpoint, "my-endpoint");
assert_eq!(creds.info.endpoint, "my-endpoint");
handle.await.unwrap();
}

View File

@@ -5,7 +5,6 @@ use std::net::IpAddr;
use std::str::FromStr;
use itertools::Itertools;
use pq_proto::StartupMessageParams;
use thiserror::Error;
use tracing::{debug, warn};
@@ -13,6 +12,7 @@ use crate::auth::password_hack::parse_endpoint_param;
use crate::context::RequestContext;
use crate::error::{ReportableError, UserFacingError};
use crate::metrics::{Metrics, SniGroup, SniKind};
use crate::pqproto::StartupMessageParams;
use crate::proxy::NeonOptions;
use crate::serverless::{AUTH_BROKER_SNI, SERVERLESS_DRIVER_SNI};
use crate::types::{EndpointId, RoleName};

View File

@@ -1,10 +1,8 @@
//! Main authentication flow.
use std::io;
use std::sync::Arc;
use postgres_protocol::authentication::sasl::{SCRAM_SHA_256, SCRAM_SHA_256_PLUS};
use pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be};
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::info;
@@ -13,35 +11,26 @@ use super::{AuthError, PasswordHackPayload};
use crate::context::RequestContext;
use crate::control_plane::AuthSecret;
use crate::intern::EndpointIdInt;
use crate::pqproto::{BeAuthenticationSaslMessage, BeMessage};
use crate::sasl;
use crate::scram::threadpool::ThreadPool;
use crate::scram::{self};
use crate::stream::{PqStream, Stream};
use crate::tls::TlsServerEndPoint;
/// Every authentication selector is supposed to implement this trait.
pub(crate) trait AuthMethod {
/// Any authentication selector should provide initial backend message
/// containing auth method name and parameters, e.g. md5 salt.
fn first_message(&self, channel_binding: bool) -> BeMessage<'_>;
}
/// Initial state of [`AuthFlow`].
pub(crate) struct Begin;
/// Use [SCRAM](crate::scram)-based auth in [`AuthFlow`].
pub(crate) struct Scram<'a>(
pub(crate) &'a scram::ServerSecret,
pub(crate) &'a RequestContext,
);
impl AuthMethod for Scram<'_> {
impl Scram<'_> {
#[inline(always)]
fn first_message(&self, channel_binding: bool) -> BeMessage<'_> {
if channel_binding {
Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(scram::METHODS))
BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(scram::METHODS))
} else {
Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(
BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(
scram::METHODS_WITHOUT_PLUS,
))
}
@@ -52,13 +41,6 @@ impl AuthMethod for Scram<'_> {
/// <https://github.com/neondatabase/cloud/issues/1620#issuecomment-1165332290>.
pub(crate) struct PasswordHack;
impl AuthMethod for PasswordHack {
#[inline(always)]
fn first_message(&self, _channel_binding: bool) -> BeMessage<'_> {
Be::AuthenticationCleartextPassword
}
}
/// Use clear-text password auth called `password` in docs
/// <https://www.postgresql.org/docs/current/auth-password.html>
pub(crate) struct CleartextPassword {
@@ -67,53 +49,37 @@ pub(crate) struct CleartextPassword {
pub(crate) secret: AuthSecret,
}
impl AuthMethod for CleartextPassword {
#[inline(always)]
fn first_message(&self, _channel_binding: bool) -> BeMessage<'_> {
Be::AuthenticationCleartextPassword
}
}
/// This wrapper for [`PqStream`] performs client authentication.
#[must_use]
pub(crate) struct AuthFlow<'a, S, State> {
/// The underlying stream which implements libpq's protocol.
stream: &'a mut PqStream<Stream<S>>,
/// State might contain ancillary data (see [`Self::begin`]).
/// State might contain ancillary data.
state: State,
tls_server_end_point: TlsServerEndPoint,
}
/// Initial state of the stream wrapper.
impl<'a, S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'a, S, Begin> {
impl<'a, S: AsyncRead + AsyncWrite + Unpin, M> AuthFlow<'a, S, M> {
/// Create a new wrapper for client authentication.
pub(crate) fn new(stream: &'a mut PqStream<Stream<S>>) -> Self {
pub(crate) fn new(stream: &'a mut PqStream<Stream<S>>, method: M) -> Self {
let tls_server_end_point = stream.get_ref().tls_server_end_point();
Self {
stream,
state: Begin,
state: method,
tls_server_end_point,
}
}
/// Move to the next step by sending auth method's name & params to client.
pub(crate) async fn begin<M: AuthMethod>(self, method: M) -> io::Result<AuthFlow<'a, S, M>> {
self.stream
.write_message(&method.first_message(self.tls_server_end_point.supported()))
.await?;
Ok(AuthFlow {
stream: self.stream,
state: method,
tls_server_end_point: self.tls_server_end_point,
})
}
}
impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, PasswordHack> {
/// Perform user authentication. Raise an error in case authentication failed.
pub(crate) async fn get_password(self) -> super::Result<PasswordHackPayload> {
self.stream
.write_message(BeMessage::AuthenticationCleartextPassword);
self.stream.flush().await?;
let msg = self.stream.read_password_message().await?;
let password = msg
.strip_suffix(&[0])
@@ -133,6 +99,10 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, PasswordHack> {
impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, CleartextPassword> {
/// Perform user authentication. Raise an error in case authentication failed.
pub(crate) async fn authenticate(self) -> super::Result<sasl::Outcome<ComputeCredentialKeys>> {
self.stream
.write_message(BeMessage::AuthenticationCleartextPassword);
self.stream.flush().await?;
let msg = self.stream.read_password_message().await?;
let password = msg
.strip_suffix(&[0])
@@ -147,7 +117,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, CleartextPassword> {
.await?;
if let sasl::Outcome::Success(_) = &outcome {
self.stream.write_message_noflush(&Be::AuthenticationOk)?;
self.stream.write_message(BeMessage::AuthenticationOk);
}
Ok(outcome)
@@ -159,42 +129,36 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
/// Perform user authentication. Raise an error in case authentication failed.
pub(crate) async fn authenticate(self) -> super::Result<sasl::Outcome<scram::ScramKey>> {
let Scram(secret, ctx) = self.state;
let channel_binding = self.tls_server_end_point;
// pause the timer while we communicate with the client
let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
// send sasl message.
{
// pause the timer while we communicate with the client
let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
// Initial client message contains the chosen auth method's name.
let msg = self.stream.read_password_message().await?;
let sasl = sasl::FirstMessage::parse(&msg)
.ok_or(AuthError::MalformedPassword("bad sasl message"))?;
// Currently, the only supported SASL method is SCRAM.
if !scram::METHODS.contains(&sasl.method) {
return Err(super::AuthError::bad_auth_method(sasl.method));
let sasl = self.state.first_message(channel_binding.supported());
self.stream.write_message(sasl);
self.stream.flush().await?;
}
match sasl.method {
SCRAM_SHA_256 => ctx.set_auth_method(crate::context::AuthMethod::ScramSha256),
SCRAM_SHA_256_PLUS => ctx.set_auth_method(crate::context::AuthMethod::ScramSha256Plus),
_ => {}
}
// complete sasl handshake.
sasl::authenticate(ctx, self.stream, |method| {
// Currently, the only supported SASL method is SCRAM.
match method {
SCRAM_SHA_256 => ctx.set_auth_method(crate::context::AuthMethod::ScramSha256),
SCRAM_SHA_256_PLUS => {
ctx.set_auth_method(crate::context::AuthMethod::ScramSha256Plus);
}
method => return Err(sasl::Error::BadAuthMethod(method.into())),
}
// TODO: make this a metric instead
info!("client chooses {}", sasl.method);
// TODO: make this a metric instead
info!("client chooses {}", method);
let outcome = sasl::SaslStream::new(self.stream, sasl.message)
.authenticate(scram::Exchange::new(
secret,
rand::random,
self.tls_server_end_point,
))
.await?;
if let sasl::Outcome::Success(_) = &outcome {
self.stream.write_message_noflush(&Be::AuthenticationOk)?;
}
Ok(outcome)
Ok(scram::Exchange::new(secret, rand::random, channel_binding))
})
.await
.map_err(AuthError::Sasl)
}
}

View File

@@ -32,9 +32,7 @@ use crate::ext::TaskExt;
use crate::http::health_server::AppMetrics;
use crate::intern::RoleNameInt;
use crate::metrics::{Metrics, ThreadPoolMetrics};
use crate::rate_limiter::{
BucketRateLimiter, EndpointRateLimiter, LeakyBucketConfig, RateBucketInfo,
};
use crate::rate_limiter::{EndpointRateLimiter, LeakyBucketConfig, RateBucketInfo};
use crate::scram::threadpool::ThreadPool;
use crate::serverless::cancel_set::CancelSet;
use crate::serverless::{self, GlobalConnPoolOptions};
@@ -69,15 +67,6 @@ struct LocalProxyCliArgs {
/// Can be given multiple times for different bucket sizes.
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_ENDPOINT_SET)]
user_rps_limit: Vec<RateBucketInfo>,
/// Whether the auth rate limiter actually takes effect (for testing)
#[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
auth_rate_limit_enabled: bool,
/// Authentication rate limiter max number of hashes per second.
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_AUTH_SET)]
auth_rate_limit: Vec<RateBucketInfo>,
/// The IP subnet to use when considering whether two IP addresses are considered the same.
#[clap(long, default_value_t = 64)]
auth_rate_limit_ip_subnet: u8,
/// Whether to retry the connection to the compute node
#[clap(long, default_value = config::RetryConfig::CONNECT_TO_COMPUTE_DEFAULT_VALUES)]
connect_to_compute_retry: String,
@@ -282,9 +271,6 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig
jwks_cache: JwkCache::default(),
thread_pool: ThreadPool::new(0),
scram_protocol_timeout: Duration::from_secs(10),
rate_limiter_enabled: false,
rate_limiter: BucketRateLimiter::new(vec![]),
rate_limit_ip_subnet: 64,
ip_allowlist_check_enabled: true,
is_vpc_acccess_proxy: false,
is_auth_broker: false,

View File

@@ -4,8 +4,9 @@
//! This allows connecting to pods/services running in the same Kubernetes cluster from
//! the outside. Similar to an ingress controller for HTTPS.
use std::net::SocketAddr;
use std::path::Path;
use std::{net::SocketAddr, sync::Arc};
use std::sync::Arc;
use anyhow::{Context, anyhow, bail, ensure};
use clap::Arg;
@@ -17,6 +18,7 @@ use rustls::pki_types::{DnsName, PrivateKeyDer};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::net::TcpListener;
use tokio_rustls::TlsConnector;
use tokio_rustls::server::TlsStream;
use tokio_util::sync::CancellationToken;
use tracing::{Instrument, error, info};
use utils::project_git_version;
@@ -24,10 +26,12 @@ use utils::sentry_init::init_sentry;
use crate::context::RequestContext;
use crate::metrics::{Metrics, ThreadPoolMetrics};
use crate::pqproto::FeStartupPacket;
use crate::protocol2::ConnectionInfo;
use crate::proxy::{ErrorSource, copy_bidirectional_client_compute, run_until_cancelled};
use crate::proxy::{
ErrorSource, TlsRequired, copy_bidirectional_client_compute, run_until_cancelled,
};
use crate::stream::{PqStream, Stream};
use crate::tls::TlsServerEndPoint;
project_git_version!(GIT_VERSION);
@@ -84,7 +88,7 @@ pub async fn run() -> anyhow::Result<()> {
.parse()?;
// Configure TLS
let (tls_config, tls_server_end_point): (Arc<rustls::ServerConfig>, TlsServerEndPoint) = match (
let tls_config = match (
args.get_one::<String>("tls-key"),
args.get_one::<String>("tls-cert"),
) {
@@ -117,7 +121,6 @@ pub async fn run() -> anyhow::Result<()> {
dest.clone(),
tls_config.clone(),
None,
tls_server_end_point,
proxy_listener,
cancellation_token.clone(),
))
@@ -127,7 +130,6 @@ pub async fn run() -> anyhow::Result<()> {
dest,
tls_config,
Some(compute_tls_config),
tls_server_end_point,
proxy_listener_compute_tls,
cancellation_token.clone(),
))
@@ -154,7 +156,7 @@ pub async fn run() -> anyhow::Result<()> {
pub(super) fn parse_tls(
key_path: &Path,
cert_path: &Path,
) -> anyhow::Result<(Arc<rustls::ServerConfig>, TlsServerEndPoint)> {
) -> anyhow::Result<Arc<rustls::ServerConfig>> {
let key = {
let key_bytes = std::fs::read(key_path).context("TLS key file")?;
@@ -187,10 +189,6 @@ pub(super) fn parse_tls(
})?
};
// needed for channel bindings
let first_cert = cert_chain.first().context("missing certificate")?;
let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;
let tls_config =
rustls::ServerConfig::builder_with_provider(Arc::new(ring::default_provider()))
.with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])
@@ -199,14 +197,13 @@ pub(super) fn parse_tls(
.with_single_cert(cert_chain, key)?
.into();
Ok((tls_config, tls_server_end_point))
Ok(tls_config)
}
pub(super) async fn task_main(
dest_suffix: Arc<String>,
tls_config: Arc<rustls::ServerConfig>,
compute_tls_config: Option<Arc<rustls::ClientConfig>>,
tls_server_end_point: TlsServerEndPoint,
listener: tokio::net::TcpListener,
cancellation_token: CancellationToken,
) -> anyhow::Result<()> {
@@ -242,15 +239,7 @@ pub(super) async fn task_main(
crate::metrics::Protocol::SniRouter,
"sni",
);
handle_client(
ctx,
dest_suffix,
tls_config,
compute_tls_config,
tls_server_end_point,
socket,
)
.await
handle_client(ctx, dest_suffix, tls_config, compute_tls_config, socket).await
}
.unwrap_or_else(|e| {
// Acknowledge that the task has finished with an error.
@@ -269,55 +258,26 @@ pub(super) async fn task_main(
Ok(())
}
const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
ctx: &RequestContext,
raw_stream: S,
tls_config: Arc<rustls::ServerConfig>,
tls_server_end_point: TlsServerEndPoint,
) -> anyhow::Result<Stream<S>> {
let mut stream = PqStream::new(Stream::from_raw(raw_stream));
let msg = stream.read_startup_packet().await?;
use pq_proto::FeStartupPacket::SslRequest;
) -> anyhow::Result<TlsStream<S>> {
let (mut stream, msg) = PqStream::parse_startup(Stream::from_raw(raw_stream)).await?;
match msg {
SslRequest { direct: false } => {
stream
.write_message(&pq_proto::BeMessage::EncryptionResponse(true))
.await?;
FeStartupPacket::SslRequest { direct: None } => {
let raw = stream.accept_tls().await?;
// Upgrade raw stream into a secure TLS-backed stream.
// NOTE: We've consumed `tls`; this fact will be used later.
let (raw, read_buf) = stream.into_inner();
// TODO: Normally, client doesn't send any data before
// server says TLS handshake is ok and read_buf is empty.
// However, you could imagine pipelining of postgres
// SSLRequest + TLS ClientHello in one hunk similar to
// pipelining in our node js driver. We should probably
// support that by chaining read_buf with the stream.
if !read_buf.is_empty() {
bail!("data is sent before server replied with EncryptionResponse");
}
Ok(Stream::Tls {
tls: Box::new(
raw.upgrade(tls_config, !ctx.has_private_peer_addr())
.await?,
),
tls_server_end_point,
})
Ok(raw
.upgrade(tls_config, !ctx.has_private_peer_addr())
.await?)
}
unexpected => {
info!(
?unexpected,
"unexpected startup packet, rejecting connection"
);
stream
.throw_error_str(ERR_INSECURE_CONNECTION, crate::error::ErrorKind::User, None)
.await?
Err(stream.throw_error(TlsRequired, None).await)?
}
}
}
@@ -327,15 +287,18 @@ async fn handle_client(
dest_suffix: Arc<String>,
tls_config: Arc<rustls::ServerConfig>,
compute_tls_config: Option<Arc<rustls::ClientConfig>>,
tls_server_end_point: TlsServerEndPoint,
stream: impl AsyncRead + AsyncWrite + Unpin,
) -> anyhow::Result<()> {
let mut tls_stream = ssl_handshake(&ctx, stream, tls_config, tls_server_end_point).await?;
let mut tls_stream = ssl_handshake(&ctx, stream, tls_config).await?;
// Cut off first part of the SNI domain
// We receive required destination details in the format of
// `{k8s_service_name}--{k8s_namespace}--{port}.non-sni-domain`
let sni = tls_stream.sni_hostname().ok_or(anyhow!("SNI missing"))?;
let sni = tls_stream
.get_ref()
.1
.server_name()
.ok_or(anyhow!("SNI missing"))?;
let dest: Vec<&str> = sni
.split_once('.')
.context("invalid SNI")?

View File

@@ -20,7 +20,7 @@ use utils::sentry_init::init_sentry;
use utils::{project_build_tag, project_git_version};
use crate::auth::backend::jwt::JwkCache;
use crate::auth::backend::{AuthRateLimiter, ConsoleRedirectBackend, MaybeOwned};
use crate::auth::backend::{ConsoleRedirectBackend, MaybeOwned};
use crate::cancellation::{CancellationHandler, handle_cancel_messages};
use crate::config::{
self, AuthenticationConfig, CacheOptions, ComputeConfig, HttpConfig, ProjectInfoCacheOptions,
@@ -29,9 +29,7 @@ use crate::config::{
use crate::context::parquet::ParquetUploadArgs;
use crate::http::health_server::AppMetrics;
use crate::metrics::Metrics;
use crate::rate_limiter::{
EndpointRateLimiter, LeakyBucketConfig, RateBucketInfo, WakeComputeRateLimiter,
};
use crate::rate_limiter::{EndpointRateLimiter, RateBucketInfo, WakeComputeRateLimiter};
use crate::redis::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
use crate::redis::kv_ops::RedisKVClient;
use crate::redis::{elasticache, notifications};
@@ -154,15 +152,6 @@ struct ProxyCliArgs {
/// Wake compute rate limiter max number of requests per second.
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_SET)]
wake_compute_limit: Vec<RateBucketInfo>,
/// Whether the auth rate limiter actually takes effect (for testing)
#[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
auth_rate_limit_enabled: bool,
/// Authentication rate limiter max number of hashes per second.
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_AUTH_SET)]
auth_rate_limit: Vec<RateBucketInfo>,
/// The IP subnet to use when considering whether two IP addresses are considered the same.
#[clap(long, default_value_t = 64)]
auth_rate_limit_ip_subnet: u8,
/// Redis rate limiter max number of requests per second.
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_REDIS_SET)]
redis_rps_limit: Vec<RateBucketInfo>,
@@ -410,22 +399,9 @@ pub async fn run() -> anyhow::Result<()> {
Some(tx_cancel),
));
// bit of a hack - find the min rps and max rps supported and turn it into
// leaky bucket config instead
let max = args
.endpoint_rps_limit
.iter()
.map(|x| x.rps())
.max_by(f64::total_cmp)
.unwrap_or(EndpointRateLimiter::DEFAULT.max);
let rps = args
.endpoint_rps_limit
.iter()
.map(|x| x.rps())
.min_by(f64::total_cmp)
.unwrap_or(EndpointRateLimiter::DEFAULT.rps);
let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards(
LeakyBucketConfig { rps, max },
RateBucketInfo::to_leaky_bucket(&args.endpoint_rps_limit)
.unwrap_or(EndpointRateLimiter::DEFAULT),
64,
));
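`RateBucketInfo::to_leaky_bucket` now encapsulates what the removed lines did inline; judging from the `unwrap_or` fallback, it presumably returns `None` when no buckets are configured. A sketch of the same min/max folding, assuming each bucket exposes its requests-per-second as an `f64` (illustrative, not the crate's actual signature):

```rust
// Illustrative only: the sustained rate is the smallest configured rps and
// the burst capacity is the largest, matching the removed inline "hack".
struct LeakyBucketConfig {
    rps: f64,
    max: f64,
}

fn to_leaky_bucket(rps_per_bucket: &[f64]) -> Option<LeakyBucketConfig> {
    let rps = rps_per_bucket.iter().copied().min_by(f64::total_cmp)?;
    let max = rps_per_bucket.iter().copied().max_by(f64::total_cmp)?;
    Some(LeakyBucketConfig { rps, max })
}
```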
@@ -476,8 +452,7 @@ pub async fn run() -> anyhow::Result<()> {
let key_path = args.tls_key.expect("already asserted it is set");
let cert_path = args.tls_cert.expect("already asserted it is set");
let (tls_config, tls_server_end_point) =
super::pg_sni_router::parse_tls(&key_path, &cert_path)?;
let tls_config = super::pg_sni_router::parse_tls(&key_path, &cert_path)?;
let dest = Arc::new(dest);
@@ -485,7 +460,6 @@ pub async fn run() -> anyhow::Result<()> {
dest.clone(),
tls_config.clone(),
None,
tls_server_end_point,
listen,
cancellation_token.clone(),
));
@@ -494,7 +468,6 @@ pub async fn run() -> anyhow::Result<()> {
dest,
tls_config,
Some(config.connect_to_compute.tls.clone()),
tls_server_end_point,
listen_tls,
cancellation_token.clone(),
));
@@ -681,9 +654,6 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
jwks_cache: JwkCache::default(),
thread_pool,
scram_protocol_timeout: args.scram_protocol_timeout,
rate_limiter_enabled: args.auth_rate_limit_enabled,
rate_limiter: AuthRateLimiter::new(args.auth_rate_limit.clone()),
rate_limit_ip_subnet: args.auth_rate_limit_ip_subnet,
ip_allowlist_check_enabled: !args.is_private_access_proxy,
is_vpc_acccess_proxy: args.is_private_access_proxy,
is_auth_broker: args.is_auth_broker,

View File

@@ -1,30 +1,25 @@
use std::collections::HashSet;
use std::collections::{HashMap, HashSet, hash_map};
use std::convert::Infallible;
use std::sync::Arc;
use std::sync::atomic::AtomicU64;
use std::time::Duration;
use async_trait::async_trait;
use clashmap::ClashMap;
use clashmap::mapref::one::Ref;
use rand::{Rng, thread_rng};
use smol_str::SmolStr;
use tokio::sync::Mutex;
use tokio::time::Instant;
use tracing::{debug, info};
use super::{Cache, Cached};
use crate::auth::IpPattern;
use crate::config::ProjectInfoCacheOptions;
use crate::control_plane::{AccessBlockerFlags, AuthSecret};
use crate::control_plane::{EndpointAccessControl, RoleAccessControl};
use crate::intern::{AccountIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt};
use crate::types::{EndpointId, RoleName};
#[async_trait]
pub(crate) trait ProjectInfoCache {
fn invalidate_allowed_ips_for_project(&self, project_id: ProjectIdInt);
fn invalidate_allowed_vpc_endpoint_ids_for_projects(&self, project_ids: Vec<ProjectIdInt>);
fn invalidate_allowed_vpc_endpoint_ids_for_org(&self, account_id: AccountIdInt);
fn invalidate_block_public_or_vpc_access_for_project(&self, project_id: ProjectIdInt);
fn invalidate_endpoint_access_for_project(&self, project_id: ProjectIdInt);
fn invalidate_endpoint_access_for_org(&self, account_id: AccountIdInt);
fn invalidate_role_secret_for_project(&self, project_id: ProjectIdInt, role_name: RoleNameInt);
async fn decrement_active_listeners(&self);
async fn increment_active_listeners(&self);
@@ -42,6 +37,10 @@ impl<T> Entry<T> {
value,
}
}
pub(crate) fn get(&self, valid_since: Instant) -> Option<&T> {
(valid_since < self.created_at).then_some(&self.value)
}
}
impl<T> From<T> for Entry<T> {
@@ -50,101 +49,32 @@ impl<T> From<T> for Entry<T> {
}
}
#[derive(Default)]
struct EndpointInfo {
secret: std::collections::HashMap<RoleNameInt, Entry<Option<AuthSecret>>>,
allowed_ips: Option<Entry<Arc<Vec<IpPattern>>>>,
block_public_or_vpc_access: Option<Entry<AccessBlockerFlags>>,
allowed_vpc_endpoint_ids: Option<Entry<Arc<Vec<String>>>>,
role_controls: HashMap<RoleNameInt, Entry<RoleAccessControl>>,
controls: Option<Entry<EndpointAccessControl>>,
}
impl EndpointInfo {
fn check_ignore_cache(ignore_cache_since: Option<Instant>, created_at: Instant) -> bool {
match ignore_cache_since {
None => false,
Some(t) => t < created_at,
}
}
pub(crate) fn get_role_secret(
&self,
role_name: RoleNameInt,
valid_since: Instant,
ignore_cache_since: Option<Instant>,
) -> Option<(Option<AuthSecret>, bool)> {
if let Some(secret) = self.secret.get(&role_name) {
if valid_since < secret.created_at {
return Some((
secret.value.clone(),
Self::check_ignore_cache(ignore_cache_since, secret.created_at),
));
}
}
None
) -> Option<RoleAccessControl> {
let controls = self.role_controls.get(&role_name)?;
controls.get(valid_since).cloned()
}
pub(crate) fn get_allowed_ips(
&self,
valid_since: Instant,
ignore_cache_since: Option<Instant>,
) -> Option<(Arc<Vec<IpPattern>>, bool)> {
if let Some(allowed_ips) = &self.allowed_ips {
if valid_since < allowed_ips.created_at {
return Some((
allowed_ips.value.clone(),
Self::check_ignore_cache(ignore_cache_since, allowed_ips.created_at),
));
}
}
None
}
pub(crate) fn get_allowed_vpc_endpoint_ids(
&self,
valid_since: Instant,
ignore_cache_since: Option<Instant>,
) -> Option<(Arc<Vec<String>>, bool)> {
if let Some(allowed_vpc_endpoint_ids) = &self.allowed_vpc_endpoint_ids {
if valid_since < allowed_vpc_endpoint_ids.created_at {
return Some((
allowed_vpc_endpoint_ids.value.clone(),
Self::check_ignore_cache(
ignore_cache_since,
allowed_vpc_endpoint_ids.created_at,
),
));
}
}
None
}
pub(crate) fn get_block_public_or_vpc_access(
&self,
valid_since: Instant,
ignore_cache_since: Option<Instant>,
) -> Option<(AccessBlockerFlags, bool)> {
if let Some(block_public_or_vpc_access) = &self.block_public_or_vpc_access {
if valid_since < block_public_or_vpc_access.created_at {
return Some((
block_public_or_vpc_access.value.clone(),
Self::check_ignore_cache(
ignore_cache_since,
block_public_or_vpc_access.created_at,
),
));
}
}
None
pub(crate) fn get_controls(&self, valid_since: Instant) -> Option<EndpointAccessControl> {
let controls = self.controls.as_ref()?;
controls.get(valid_since).cloned()
}
pub(crate) fn invalidate_allowed_ips(&mut self) {
self.allowed_ips = None;
}
pub(crate) fn invalidate_allowed_vpc_endpoint_ids(&mut self) {
self.allowed_vpc_endpoint_ids = None;
}
pub(crate) fn invalidate_block_public_or_vpc_access(&mut self) {
self.block_public_or_vpc_access = None;
pub(crate) fn invalidate_endpoint(&mut self) {
self.controls = None;
}
pub(crate) fn invalidate_role_secret(&mut self, role_name: RoleNameInt) {
self.secret.remove(&role_name);
self.role_controls.remove(&role_name);
}
}
@@ -170,34 +100,22 @@ pub struct ProjectInfoCacheImpl {
#[async_trait]
impl ProjectInfoCache for ProjectInfoCacheImpl {
fn invalidate_allowed_vpc_endpoint_ids_for_projects(&self, project_ids: Vec<ProjectIdInt>) {
info!(
"invalidating allowed vpc endpoint ids for projects `{}`",
project_ids
.iter()
.map(|id| id.to_string())
.collect::<Vec<_>>()
.join(", ")
);
for project_id in project_ids {
let endpoints = self
.project2ep
.get(&project_id)
.map(|kv| kv.value().clone())
.unwrap_or_default();
for endpoint_id in endpoints {
if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
endpoint_info.invalidate_allowed_vpc_endpoint_ids();
}
fn invalidate_endpoint_access_for_project(&self, project_id: ProjectIdInt) {
info!("invalidating endpoint access for project `{project_id}`");
let endpoints = self
.project2ep
.get(&project_id)
.map(|kv| kv.value().clone())
.unwrap_or_default();
for endpoint_id in endpoints {
if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
endpoint_info.invalidate_endpoint();
}
}
}
fn invalidate_allowed_vpc_endpoint_ids_for_org(&self, account_id: AccountIdInt) {
info!(
"invalidating allowed vpc endpoint ids for org `{}`",
account_id
);
fn invalidate_endpoint_access_for_org(&self, account_id: AccountIdInt) {
info!("invalidating endpoint access for org `{account_id}`");
let endpoints = self
.account2ep
.get(&account_id)
@@ -205,41 +123,11 @@ impl ProjectInfoCache for ProjectInfoCacheImpl {
.unwrap_or_default();
for endpoint_id in endpoints {
if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
endpoint_info.invalidate_allowed_vpc_endpoint_ids();
endpoint_info.invalidate_endpoint();
}
}
}
fn invalidate_block_public_or_vpc_access_for_project(&self, project_id: ProjectIdInt) {
info!(
"invalidating block public or vpc access for project `{}`",
project_id
);
let endpoints = self
.project2ep
.get(&project_id)
.map(|kv| kv.value().clone())
.unwrap_or_default();
for endpoint_id in endpoints {
if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
endpoint_info.invalidate_block_public_or_vpc_access();
}
}
}
fn invalidate_allowed_ips_for_project(&self, project_id: ProjectIdInt) {
info!("invalidating allowed ips for project `{}`", project_id);
let endpoints = self
.project2ep
.get(&project_id)
.map(|kv| kv.value().clone())
.unwrap_or_default();
for endpoint_id in endpoints {
if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
endpoint_info.invalidate_allowed_ips();
}
}
}
fn invalidate_role_secret_for_project(&self, project_id: ProjectIdInt, role_name: RoleNameInt) {
info!(
"invalidating role secret for project_id `{}` and role_name `{}`",
@@ -256,6 +144,7 @@ impl ProjectInfoCache for ProjectInfoCacheImpl {
}
}
}
async fn decrement_active_listeners(&self) {
let mut listeners_guard = self.active_listeners_lock.lock().await;
if *listeners_guard == 0 {
@@ -293,155 +182,71 @@ impl ProjectInfoCacheImpl {
}
}
fn get_endpoint_cache(
&self,
endpoint_id: &EndpointId,
) -> Option<Ref<'_, EndpointIdInt, EndpointInfo>> {
let endpoint_id = EndpointIdInt::get(endpoint_id)?;
self.cache.get(&endpoint_id)
}
pub(crate) fn get_role_secret(
&self,
endpoint_id: &EndpointId,
role_name: &RoleName,
) -> Option<Cached<&Self, Option<AuthSecret>>> {
let endpoint_id = EndpointIdInt::get(endpoint_id)?;
) -> Option<RoleAccessControl> {
let valid_since = self.get_cache_times();
let role_name = RoleNameInt::get(role_name)?;
let (valid_since, ignore_cache_since) = self.get_cache_times();
let endpoint_info = self.cache.get(&endpoint_id)?;
let (value, ignore_cache) =
endpoint_info.get_role_secret(role_name, valid_since, ignore_cache_since)?;
if !ignore_cache {
let cached = Cached {
token: Some((
self,
CachedLookupInfo::new_role_secret(endpoint_id, role_name),
)),
value,
};
return Some(cached);
}
Some(Cached::new_uncached(value))
}
pub(crate) fn get_allowed_ips(
&self,
endpoint_id: &EndpointId,
) -> Option<Cached<&Self, Arc<Vec<IpPattern>>>> {
let endpoint_id = EndpointIdInt::get(endpoint_id)?;
let (valid_since, ignore_cache_since) = self.get_cache_times();
let endpoint_info = self.cache.get(&endpoint_id)?;
let value = endpoint_info.get_allowed_ips(valid_since, ignore_cache_since);
let (value, ignore_cache) = value?;
if !ignore_cache {
let cached = Cached {
token: Some((self, CachedLookupInfo::new_allowed_ips(endpoint_id))),
value,
};
return Some(cached);
}
Some(Cached::new_uncached(value))
}
pub(crate) fn get_allowed_vpc_endpoint_ids(
&self,
endpoint_id: &EndpointId,
) -> Option<Cached<&Self, Arc<Vec<String>>>> {
let endpoint_id = EndpointIdInt::get(endpoint_id)?;
let (valid_since, ignore_cache_since) = self.get_cache_times();
let endpoint_info = self.cache.get(&endpoint_id)?;
let value = endpoint_info.get_allowed_vpc_endpoint_ids(valid_since, ignore_cache_since);
let (value, ignore_cache) = value?;
if !ignore_cache {
let cached = Cached {
token: Some((
self,
CachedLookupInfo::new_allowed_vpc_endpoint_ids(endpoint_id),
)),
value,
};
return Some(cached);
}
Some(Cached::new_uncached(value))
}
pub(crate) fn get_block_public_or_vpc_access(
&self,
endpoint_id: &EndpointId,
) -> Option<Cached<&Self, AccessBlockerFlags>> {
let endpoint_id = EndpointIdInt::get(endpoint_id)?;
let (valid_since, ignore_cache_since) = self.get_cache_times();
let endpoint_info = self.cache.get(&endpoint_id)?;
let value = endpoint_info.get_block_public_or_vpc_access(valid_since, ignore_cache_since);
let (value, ignore_cache) = value?;
if !ignore_cache {
let cached = Cached {
token: Some((
self,
CachedLookupInfo::new_block_public_or_vpc_access(endpoint_id),
)),
value,
};
return Some(cached);
}
Some(Cached::new_uncached(value))
let endpoint_info = self.get_endpoint_cache(endpoint_id)?;
endpoint_info.get_role_secret(role_name, valid_since)
}
pub(crate) fn insert_role_secret(
pub(crate) fn get_endpoint_access(
&self,
project_id: ProjectIdInt,
endpoint_id: EndpointIdInt,
role_name: RoleNameInt,
secret: Option<AuthSecret>,
) {
if self.cache.len() >= self.config.size {
// If there are too many entries, wait until the next gc cycle.
return;
}
self.insert_project2endpoint(project_id, endpoint_id);
let mut entry = self.cache.entry(endpoint_id).or_default();
if entry.secret.len() < self.config.max_roles {
entry.secret.insert(role_name, secret.into());
}
endpoint_id: &EndpointId,
) -> Option<EndpointAccessControl> {
let valid_since = self.get_cache_times();
let endpoint_info = self.get_endpoint_cache(endpoint_id)?;
endpoint_info.get_controls(valid_since)
}
pub(crate) fn insert_allowed_ips(
&self,
project_id: ProjectIdInt,
endpoint_id: EndpointIdInt,
allowed_ips: Arc<Vec<IpPattern>>,
) {
if self.cache.len() >= self.config.size {
// If there are too many entries, wait until the next gc cycle.
return;
}
self.insert_project2endpoint(project_id, endpoint_id);
self.cache.entry(endpoint_id).or_default().allowed_ips = Some(allowed_ips.into());
}
pub(crate) fn insert_allowed_vpc_endpoint_ids(
pub(crate) fn insert_endpoint_access(
&self,
account_id: Option<AccountIdInt>,
project_id: ProjectIdInt,
endpoint_id: EndpointIdInt,
allowed_vpc_endpoint_ids: Arc<Vec<String>>,
role_name: RoleNameInt,
controls: EndpointAccessControl,
role_controls: RoleAccessControl,
) {
if self.cache.len() >= self.config.size {
// If there are too many entries, wait until the next gc cycle.
return;
}
if let Some(account_id) = account_id {
self.insert_account2endpoint(account_id, endpoint_id);
}
self.insert_project2endpoint(project_id, endpoint_id);
self.cache
.entry(endpoint_id)
.or_default()
.allowed_vpc_endpoint_ids = Some(allowed_vpc_endpoint_ids.into());
}
pub(crate) fn insert_block_public_or_vpc_access(
&self,
project_id: ProjectIdInt,
endpoint_id: EndpointIdInt,
access_blockers: AccessBlockerFlags,
) {
if self.cache.len() >= self.config.size {
// If there are too many entries, wait until the next gc cycle.
return;
}
self.insert_project2endpoint(project_id, endpoint_id);
self.cache
.entry(endpoint_id)
.or_default()
.block_public_or_vpc_access = Some(access_blockers.into());
let controls = Entry::from(controls);
let role_controls = Entry::from(role_controls);
match self.cache.entry(endpoint_id) {
clashmap::Entry::Vacant(e) => {
e.insert(EndpointInfo {
role_controls: HashMap::from_iter([(role_name, role_controls)]),
controls: Some(controls),
});
}
clashmap::Entry::Occupied(mut e) => {
let ep = e.get_mut();
ep.controls = Some(controls);
if ep.role_controls.len() < self.config.max_roles {
ep.role_controls.insert(role_name, role_controls);
}
}
}
}
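// Illustrative sketch, not part of this change: the entry() match above is an insert-or-update
// where endpoint-level controls are always refreshed, but new role secrets are only added while
// the per-endpoint map holds fewer than config.max_roles entries. Roughly, with std's HashMap:
//
//     use std::collections::HashMap;
//
//     fn upsert_role(roles: &mut HashMap<u32, String>, role: u32, secret: String, max_roles: usize) {
//         // mirrors the Occupied arm above: only insert while under the role cap
//         if roles.len() < max_roles {
//             roles.insert(role, secret);
//         }
//     }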
fn insert_project2endpoint(&self, project_id: ProjectIdInt, endpoint_id: EndpointIdInt) {
@@ -452,6 +257,7 @@ impl ProjectInfoCacheImpl {
.insert(project_id, HashSet::from([endpoint_id]));
}
}
fn insert_account2endpoint(&self, account_id: AccountIdInt, endpoint_id: EndpointIdInt) {
if let Some(mut endpoints) = self.account2ep.get_mut(&account_id) {
endpoints.insert(endpoint_id);
@@ -460,21 +266,57 @@ impl ProjectInfoCacheImpl {
.insert(account_id, HashSet::from([endpoint_id]));
}
}
fn get_cache_times(&self) -> (Instant, Option<Instant>) {
let mut valid_since = Instant::now() - self.config.ttl;
// Only ignore cache if ttl is disabled.
fn ignore_ttl_since(&self) -> Option<Instant> {
let ttl_disabled_since_us = self
.ttl_disabled_since_us
.load(std::sync::atomic::Ordering::Relaxed);
let ignore_cache_since = if ttl_disabled_since_us == u64::MAX {
None
} else {
let ignore_cache_since = self.start_time + Duration::from_micros(ttl_disabled_since_us);
if ttl_disabled_since_us == u64::MAX {
return None;
}
Some(self.start_time + Duration::from_micros(ttl_disabled_since_us))
}
fn get_cache_times(&self) -> Instant {
let mut valid_since = Instant::now() - self.config.ttl;
if let Some(ignore_ttl_since) = self.ignore_ttl_since() {
// The entry is fine if it is not older than the TTL, or if it was added before we started receiving notifications.
valid_since = valid_since.min(ignore_cache_since);
Some(ignore_cache_since)
valid_since = valid_since.min(ignore_ttl_since);
}
valid_since
}
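// Illustrative note, not part of this change: an entry is served only if it was created at or
// after `valid_since`. Without active listeners that is simply `now - ttl`; once the Redis
// listener is up (`ignore_ttl_since = Some(t)`), valid_since = min(now - ttl, t), so entries
// written after the listener came up stay valid indefinitely, while entries written before it
// still age out by the normal TTL.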
pub fn maybe_invalidate_role_secret(&self, endpoint_id: &EndpointId, role_name: &RoleName) {
let Some(endpoint_id) = EndpointIdInt::get(endpoint_id) else {
return;
};
(valid_since, ignore_cache_since)
let Some(role_name) = RoleNameInt::get(role_name) else {
return;
};
let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) else {
return;
};
let entry = endpoint_info.role_controls.entry(role_name);
let hash_map::Entry::Occupied(role_controls) = entry else {
return;
};
let created_at = role_controls.get().created_at;
let expire = match self.ignore_ttl_since() {
// if ignoring TTL, we should still try to roll the password if it's old
// and the client gave an incorrect password. There could be some lag on the redis channel.
Some(_) => created_at + self.config.ttl < Instant::now(),
// edge case: redis is down, let's be generous and invalidate the cache immediately.
None => true,
};
if expire {
role_controls.remove();
}
}
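// Illustrative sketch, not part of this change: the expiry decision above, pulled out as a
// small pure function (names are made up for the example):
//
//     use std::time::{Duration, Instant};
//
//     fn should_expire(created_at: Instant, ttl: Duration, redis_listener_up: bool) -> bool {
//         if redis_listener_up {
//             // notifications are flowing: only drop the secret if it is older than the TTL
//             created_at + ttl < Instant::now()
//         } else {
//             // no notifications: be generous and drop it immediately on a failed login
//             true
//         }
//     }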
pub async fn gc_worker(&self) -> anyhow::Result<Infallible> {
@@ -509,84 +351,12 @@ impl ProjectInfoCacheImpl {
}
}
/// Lookup info for project info cache.
/// This is used to invalidate cache entries.
pub(crate) struct CachedLookupInfo {
/// Search by this key.
endpoint_id: EndpointIdInt,
lookup_type: LookupType,
}
impl CachedLookupInfo {
pub(self) fn new_role_secret(endpoint_id: EndpointIdInt, role_name: RoleNameInt) -> Self {
Self {
endpoint_id,
lookup_type: LookupType::RoleSecret(role_name),
}
}
pub(self) fn new_allowed_ips(endpoint_id: EndpointIdInt) -> Self {
Self {
endpoint_id,
lookup_type: LookupType::AllowedIps,
}
}
pub(self) fn new_allowed_vpc_endpoint_ids(endpoint_id: EndpointIdInt) -> Self {
Self {
endpoint_id,
lookup_type: LookupType::AllowedVpcEndpointIds,
}
}
pub(self) fn new_block_public_or_vpc_access(endpoint_id: EndpointIdInt) -> Self {
Self {
endpoint_id,
lookup_type: LookupType::BlockPublicOrVpcAccess,
}
}
}
enum LookupType {
RoleSecret(RoleNameInt),
AllowedIps,
AllowedVpcEndpointIds,
BlockPublicOrVpcAccess,
}
impl Cache for ProjectInfoCacheImpl {
type Key = SmolStr;
// Value is not really used here, but we need to specify it.
type Value = SmolStr;
type LookupInfo<Key> = CachedLookupInfo;
fn invalidate(&self, key: &Self::LookupInfo<SmolStr>) {
match &key.lookup_type {
LookupType::RoleSecret(role_name) => {
if let Some(mut endpoint_info) = self.cache.get_mut(&key.endpoint_id) {
endpoint_info.invalidate_role_secret(*role_name);
}
}
LookupType::AllowedIps => {
if let Some(mut endpoint_info) = self.cache.get_mut(&key.endpoint_id) {
endpoint_info.invalidate_allowed_ips();
}
}
LookupType::AllowedVpcEndpointIds => {
if let Some(mut endpoint_info) = self.cache.get_mut(&key.endpoint_id) {
endpoint_info.invalidate_allowed_vpc_endpoint_ids();
}
}
LookupType::BlockPublicOrVpcAccess => {
if let Some(mut endpoint_info) = self.cache.get_mut(&key.endpoint_id) {
endpoint_info.invalidate_block_public_or_vpc_access();
}
}
}
}
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use super::*;
use crate::control_plane::{AccessBlockerFlags, AuthSecret};
use crate::scram::ServerSecret;
use crate::types::ProjectId;
@@ -601,6 +371,8 @@ mod tests {
});
let project_id: ProjectId = "project".into();
let endpoint_id: EndpointId = "endpoint".into();
let account_id: Option<AccountIdInt> = None;
let user1: RoleName = "user1".into();
let user2: RoleName = "user2".into();
let secret1 = Some(AuthSecret::Scram(ServerSecret::mock([1; 32])));
@@ -609,183 +381,73 @@ mod tests {
"127.0.0.1".parse().unwrap(),
"127.0.0.2".parse().unwrap(),
]);
cache.insert_role_secret(
cache.insert_endpoint_access(
account_id,
(&project_id).into(),
(&endpoint_id).into(),
(&user1).into(),
secret1.clone(),
EndpointAccessControl {
allowed_ips: allowed_ips.clone(),
allowed_vpce: Arc::new(vec![]),
flags: AccessBlockerFlags::default(),
},
RoleAccessControl {
secret: secret1.clone(),
},
);
cache.insert_role_secret(
cache.insert_endpoint_access(
account_id,
(&project_id).into(),
(&endpoint_id).into(),
(&user2).into(),
secret2.clone(),
);
cache.insert_allowed_ips(
(&project_id).into(),
(&endpoint_id).into(),
allowed_ips.clone(),
EndpointAccessControl {
allowed_ips: allowed_ips.clone(),
allowed_vpce: Arc::new(vec![]),
flags: AccessBlockerFlags::default(),
},
RoleAccessControl {
secret: secret2.clone(),
},
);
let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap();
assert!(cached.cached());
assert_eq!(cached.value, secret1);
assert_eq!(cached.secret, secret1);
let cached = cache.get_role_secret(&endpoint_id, &user2).unwrap();
assert!(cached.cached());
assert_eq!(cached.value, secret2);
assert_eq!(cached.secret, secret2);
// Shouldn't add more than 2 roles.
let user3: RoleName = "user3".into();
let secret3 = Some(AuthSecret::Scram(ServerSecret::mock([3; 32])));
cache.insert_role_secret(
cache.insert_endpoint_access(
account_id,
(&project_id).into(),
(&endpoint_id).into(),
(&user3).into(),
secret3.clone(),
EndpointAccessControl {
allowed_ips: allowed_ips.clone(),
allowed_vpce: Arc::new(vec![]),
flags: AccessBlockerFlags::default(),
},
RoleAccessControl {
secret: secret3.clone(),
},
);
assert!(cache.get_role_secret(&endpoint_id, &user3).is_none());
let cached = cache.get_allowed_ips(&endpoint_id).unwrap();
assert!(cached.cached());
assert_eq!(cached.value, allowed_ips);
let cached = cache.get_endpoint_access(&endpoint_id).unwrap();
assert_eq!(cached.allowed_ips, allowed_ips);
tokio::time::advance(Duration::from_secs(2)).await;
let cached = cache.get_role_secret(&endpoint_id, &user1);
assert!(cached.is_none());
let cached = cache.get_role_secret(&endpoint_id, &user2);
assert!(cached.is_none());
let cached = cache.get_allowed_ips(&endpoint_id);
let cached = cache.get_endpoint_access(&endpoint_id);
assert!(cached.is_none());
}
#[tokio::test]
async fn test_project_info_cache_invalidations() {
tokio::time::pause();
let cache = Arc::new(ProjectInfoCacheImpl::new(ProjectInfoCacheOptions {
size: 2,
max_roles: 2,
ttl: Duration::from_secs(1),
gc_interval: Duration::from_secs(600),
}));
cache.clone().increment_active_listeners().await;
tokio::time::advance(Duration::from_secs(2)).await;
let project_id: ProjectId = "project".into();
let endpoint_id: EndpointId = "endpoint".into();
let user1: RoleName = "user1".into();
let user2: RoleName = "user2".into();
let secret1 = Some(AuthSecret::Scram(ServerSecret::mock([1; 32])));
let secret2 = Some(AuthSecret::Scram(ServerSecret::mock([2; 32])));
let allowed_ips = Arc::new(vec![
"127.0.0.1".parse().unwrap(),
"127.0.0.2".parse().unwrap(),
]);
cache.insert_role_secret(
(&project_id).into(),
(&endpoint_id).into(),
(&user1).into(),
secret1.clone(),
);
cache.insert_role_secret(
(&project_id).into(),
(&endpoint_id).into(),
(&user2).into(),
secret2.clone(),
);
cache.insert_allowed_ips(
(&project_id).into(),
(&endpoint_id).into(),
allowed_ips.clone(),
);
tokio::time::advance(Duration::from_secs(2)).await;
// Nothing should be invalidated.
let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap();
// TTL is disabled, so it should be impossible to invalidate this value.
assert!(!cached.cached());
assert_eq!(cached.value, secret1);
cached.invalidate(); // Shouldn't do anything.
let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap();
assert_eq!(cached.value, secret1);
let cached = cache.get_role_secret(&endpoint_id, &user2).unwrap();
assert!(!cached.cached());
assert_eq!(cached.value, secret2);
// The only way to invalidate this value is to invalidate via the api.
cache.invalidate_role_secret_for_project((&project_id).into(), (&user2).into());
assert!(cache.get_role_secret(&endpoint_id, &user2).is_none());
let cached = cache.get_allowed_ips(&endpoint_id).unwrap();
assert!(!cached.cached());
assert_eq!(cached.value, allowed_ips);
}
#[tokio::test]
async fn test_increment_active_listeners_invalidate_added_before() {
tokio::time::pause();
let cache = Arc::new(ProjectInfoCacheImpl::new(ProjectInfoCacheOptions {
size: 2,
max_roles: 2,
ttl: Duration::from_secs(1),
gc_interval: Duration::from_secs(600),
}));
let project_id: ProjectId = "project".into();
let endpoint_id: EndpointId = "endpoint".into();
let user1: RoleName = "user1".into();
let user2: RoleName = "user2".into();
let secret1 = Some(AuthSecret::Scram(ServerSecret::mock([1; 32])));
let secret2 = Some(AuthSecret::Scram(ServerSecret::mock([2; 32])));
let allowed_ips = Arc::new(vec![
"127.0.0.1".parse().unwrap(),
"127.0.0.2".parse().unwrap(),
]);
cache.insert_role_secret(
(&project_id).into(),
(&endpoint_id).into(),
(&user1).into(),
secret1.clone(),
);
cache.clone().increment_active_listeners().await;
tokio::time::advance(Duration::from_millis(100)).await;
cache.insert_role_secret(
(&project_id).into(),
(&endpoint_id).into(),
(&user2).into(),
secret2.clone(),
);
// Added before ttl was disabled + ttl should still be cached.
let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap();
assert!(cached.cached());
let cached = cache.get_role_secret(&endpoint_id, &user2).unwrap();
assert!(cached.cached());
tokio::time::advance(Duration::from_secs(1)).await;
// Added before ttl was disabled + ttl should expire.
assert!(cache.get_role_secret(&endpoint_id, &user1).is_none());
assert!(cache.get_role_secret(&endpoint_id, &user2).is_none());
// Added after ttl was disabled + ttl should not be cached.
cache.insert_allowed_ips(
(&project_id).into(),
(&endpoint_id).into(),
allowed_ips.clone(),
);
let cached = cache.get_allowed_ips(&endpoint_id).unwrap();
assert!(!cached.cached());
tokio::time::advance(Duration::from_secs(1)).await;
// Added before ttl was disabled + ttl should still expire.
assert!(cache.get_role_secret(&endpoint_id, &user1).is_none());
assert!(cache.get_role_secret(&endpoint_id, &user2).is_none());
// Shouldn't be invalidated.
let cached = cache.get_allowed_ips(&endpoint_id).unwrap();
assert!(!cached.cached());
assert_eq!(cached.value, allowed_ips);
}
}

View File

@@ -5,7 +5,6 @@ use anyhow::{Context, anyhow};
use ipnet::{IpNet, Ipv4Net, Ipv6Net};
use postgres_client::CancelToken;
use postgres_client::tls::MakeTlsConnect;
use pq_proto::CancelKeyData;
use redis::{Cmd, FromRedisValue, Value};
use serde::{Deserialize, Serialize};
use thiserror::Error;
@@ -13,15 +12,15 @@ use tokio::net::TcpStream;
use tokio::sync::{mpsc, oneshot};
use tracing::{debug, error, info, warn};
use crate::auth::AuthError;
use crate::auth::backend::ComputeUserInfo;
use crate::auth::{AuthError, check_peer_addr_is_in_list};
use crate::config::ComputeConfig;
use crate::context::RequestContext;
use crate::control_plane::ControlPlaneApi;
use crate::error::ReportableError;
use crate::ext::LockExt;
use crate::metrics::{CancelChannelSizeGuard, CancellationRequest, Metrics, RedisMsgKind};
use crate::protocol2::ConnectionInfoExtra;
use crate::pqproto::CancelKeyData;
use crate::rate_limiter::LeakyBucketRateLimiter;
use crate::redis::keys::KeyPrefix;
use crate::redis::kv_ops::RedisKVClient;
@@ -272,13 +271,7 @@ pub(crate) enum CancelError {
#[error("rate limit exceeded")]
RateLimit,
#[error("IP is not allowed")]
IpNotAllowed,
#[error("VPC endpoint id is not allowed to connect")]
VpcEndpointIdNotAllowed,
#[error("Authentication backend error")]
#[error("Authentication error")]
AuthError(#[from] AuthError),
#[error("key not found")]
@@ -297,10 +290,7 @@ impl ReportableError for CancelError {
}
CancelError::Postgres(_) => crate::error::ErrorKind::Compute,
CancelError::RateLimit => crate::error::ErrorKind::RateLimit,
CancelError::IpNotAllowed
| CancelError::VpcEndpointIdNotAllowed
| CancelError::NotFound => crate::error::ErrorKind::User,
CancelError::AuthError(_) => crate::error::ErrorKind::ControlPlane,
CancelError::NotFound | CancelError::AuthError(_) => crate::error::ErrorKind::User,
CancelError::InternalError => crate::error::ErrorKind::Service,
}
}
@@ -422,7 +412,13 @@ impl CancellationHandler {
IpAddr::V4(ip) => IpNet::V4(Ipv4Net::new_assert(ip, 24).trunc()), // use default mask here
IpAddr::V6(ip) => IpNet::V6(Ipv6Net::new_assert(ip, 64).trunc()),
};
if !self.limiter.lock_propagate_poison().check(subnet_key, 1) {
let allowed = {
let rate_limit_config = None;
let limiter = self.limiter.lock_propagate_poison();
limiter.check(subnet_key, rate_limit_config, 1)
};
if !allowed {
// log only the subnet part of the IP address to know which subnet is rate limited
tracing::warn!("Rate limit exceeded. Skipping cancellation message, {subnet_key}");
Metrics::get()
@@ -450,52 +446,13 @@ impl CancellationHandler {
return Err(CancelError::NotFound);
};
if check_ip_allowed {
let ip_allowlist = auth_backend
.get_allowed_ips(&ctx, &cancel_closure.user_info)
.await
.map_err(|e| CancelError::AuthError(e.into()))?;
if !check_peer_addr_is_in_list(&ctx.peer_addr(), &ip_allowlist) {
// log it here since cancel_session could be spawned in a task
tracing::warn!(
"IP is not allowed to cancel the query: {key}, address: {}",
ctx.peer_addr()
);
return Err(CancelError::IpNotAllowed);
}
}
// check if a VPC endpoint ID is coming in and if yes, if it's allowed
let access_blocks = auth_backend
.get_block_public_or_vpc_access(&ctx, &cancel_closure.user_info)
let info = &cancel_closure.user_info;
let access_controls = auth_backend
.get_endpoint_access_control(&ctx, &info.endpoint, &info.user)
.await
.map_err(|e| CancelError::AuthError(e.into()))?;
if check_vpc_allowed {
if access_blocks.vpc_access_blocked {
return Err(CancelError::AuthError(AuthError::NetworkNotAllowed));
}
let incoming_vpc_endpoint_id = match ctx.extra() {
None => return Err(CancelError::AuthError(AuthError::MissingVPCEndpointId)),
Some(ConnectionInfoExtra::Aws { vpce_id }) => vpce_id.to_string(),
Some(ConnectionInfoExtra::Azure { link_id }) => link_id.to_string(),
};
let allowed_vpc_endpoint_ids = auth_backend
.get_allowed_vpc_endpoint_ids(&ctx, &cancel_closure.user_info)
.await
.map_err(|e| CancelError::AuthError(e.into()))?;
// TODO: For now an empty VPC endpoint ID list means all are allowed. We should replace that.
if !allowed_vpc_endpoint_ids.is_empty()
&& !allowed_vpc_endpoint_ids.contains(&incoming_vpc_endpoint_id)
{
return Err(CancelError::VpcEndpointIdNotAllowed);
}
} else if access_blocks.public_access_blocked {
return Err(CancelError::VpcEndpointIdNotAllowed);
}
access_controls.check(&ctx, check_ip_allowed, check_vpc_allowed)?;
Metrics::get()
.proxy

View File

@@ -8,7 +8,6 @@ use itertools::Itertools;
use postgres_client::tls::MakeTlsConnect;
use postgres_client::{CancelToken, RawConnection};
use postgres_protocol::message::backend::NoticeResponseBody;
use pq_proto::StartupMessageParams;
use rustls::pki_types::InvalidDnsNameError;
use thiserror::Error;
use tokio::net::{TcpStream, lookup_host};
@@ -24,6 +23,7 @@ use crate::control_plane::errors::WakeComputeError;
use crate::control_plane::messages::MetricsAuxInfo;
use crate::error::{ReportableError, UserFacingError};
use crate::metrics::{Metrics, NumDbConnectionsGuard};
use crate::pqproto::StartupMessageParams;
use crate::proxy::neon_option;
use crate::tls::postgres_rustls::MakeRustlsConnect;
use crate::types::Host;

View File

@@ -7,7 +7,6 @@ use arc_swap::ArcSwapOption;
use clap::ValueEnum;
use remote_storage::RemoteStorageConfig;
use crate::auth::backend::AuthRateLimiter;
use crate::auth::backend::jwt::JwkCache;
use crate::control_plane::locks::ApiLocks;
use crate::rate_limiter::{RateBucketInfo, RateLimitAlgorithm, RateLimiterConfig};
@@ -65,9 +64,6 @@ pub struct HttpConfig {
pub struct AuthenticationConfig {
pub thread_pool: Arc<ThreadPool>,
pub scram_protocol_timeout: tokio::time::Duration,
pub rate_limiter_enabled: bool,
pub rate_limiter: AuthRateLimiter,
pub rate_limit_ip_subnet: u8,
pub ip_allowlist_check_enabled: bool,
pub is_vpc_acccess_proxy: bool,
pub jwks_cache: JwkCache,

View File

@@ -1,7 +1,7 @@
use std::sync::Arc;
use futures::{FutureExt, TryFutureExt};
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::sync::CancellationToken;
use tracing::{Instrument, debug, error, info};
@@ -221,12 +221,10 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
.await
{
Ok(auth_result) => auth_result,
Err(e) => {
return stream.throw_error(e, Some(ctx)).await?;
}
Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?,
};
let mut node = connect_to_compute(
let node = connect_to_compute(
ctx,
&TcpMechanism {
user_info,
@@ -238,7 +236,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
config.wake_compute_retry_config,
&config.connect_to_compute,
)
.or_else(|e| stream.throw_error(e, Some(ctx)))
.or_else(|e| async { Err(stream.throw_error(e, Some(ctx)).await) })
.await?;
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
@@ -246,14 +244,8 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
session.write_cancel_key(node.cancel_closure.clone())?;
prepare_client_connection(&node, *session.key(), &mut stream).await?;
// Before proxy passing, forward to compute whatever data is left in the
// PqStream input buffer. Normally there is none, but our serverless npm
// driver in pipeline mode sends startup, password and first query
// immediately after opening the connection.
let (stream, read_buf) = stream.into_inner();
node.stream.write_all(&read_buf).await?;
prepare_client_connection(&node, *session.key(), &mut stream);
let stream = stream.flush_and_into_inner().await?;
Ok(Some(ProxyPassthrough {
client: stream,

View File

@@ -4,7 +4,6 @@ use std::net::IpAddr;
use chrono::Utc;
use once_cell::sync::OnceCell;
use pq_proto::StartupMessageParams;
use smol_str::SmolStr;
use tokio::sync::mpsc;
use tracing::field::display;
@@ -20,6 +19,7 @@ use crate::metrics::{
ConnectOutcome, InvalidEndpointsGroup, LatencyAccumulated, LatencyTimer, Metrics, Protocol,
Waiting,
};
use crate::pqproto::StartupMessageParams;
use crate::protocol2::{ConnectionInfo, ConnectionInfoExtra};
use crate::types::{DbName, EndpointId, RoleName};
@@ -370,6 +370,18 @@ impl RequestContext {
}
}
pub(crate) fn latency_timer_pause_at(
&self,
at: tokio::time::Instant,
waiting_for: Waiting,
) -> LatencyTimerPause<'_> {
LatencyTimerPause {
ctx: self,
start: at,
waiting_for,
}
}
pub(crate) fn get_proxy_latency(&self) -> LatencyAccumulated {
self.0
.try_lock()

View File

@@ -11,7 +11,6 @@ use parquet::file::metadata::RowGroupMetaDataPtr;
use parquet::file::properties::{DEFAULT_PAGE_SIZE, WriterProperties, WriterPropertiesPtr};
use parquet::file::writer::SerializedFileWriter;
use parquet::record::RecordWriter;
use pq_proto::StartupMessageParams;
use remote_storage::{GenericRemoteStorage, RemotePath, RemoteStorageConfig, TimeoutOrCancel};
use serde::ser::SerializeMap;
use tokio::sync::mpsc;
@@ -24,6 +23,7 @@ use super::{LOG_CHAN, RequestContextInner};
use crate::config::remote_storage_from_toml;
use crate::context::LOG_CHAN_DISCONNECT;
use crate::ext::TaskExt;
use crate::pqproto::StartupMessageParams;
#[derive(clap::Args, Clone, Debug)]
pub struct ParquetUploadArgs {

View File

@@ -15,7 +15,6 @@ use tracing::{Instrument, debug, info, info_span, warn};
use super::super::messages::{ControlPlaneErrorMessage, GetEndpointAccessControl, WakeCompute};
use crate::auth::backend::ComputeUserInfo;
use crate::auth::backend::jwt::AuthRule;
use crate::cache::Cached;
use crate::context::RequestContext;
use crate::control_plane::caches::ApiCaches;
use crate::control_plane::errors::{
@@ -24,12 +23,12 @@ use crate::control_plane::errors::{
use crate::control_plane::locks::ApiLocks;
use crate::control_plane::messages::{ColdStartInfo, EndpointJwksResponse, Reason};
use crate::control_plane::{
AccessBlockerFlags, AuthInfo, AuthSecret, CachedAccessBlockerFlags, CachedAllowedIps,
CachedAllowedVpcEndpointIds, CachedNodeInfo, CachedRoleSecret, NodeInfo,
AccessBlockerFlags, AuthInfo, AuthSecret, CachedNodeInfo, EndpointAccessControl, NodeInfo,
RoleAccessControl,
};
use crate::metrics::{CacheOutcome, Metrics};
use crate::metrics::Metrics;
use crate::rate_limiter::WakeComputeRateLimiter;
use crate::types::{EndpointCacheKey, EndpointId};
use crate::types::{EndpointCacheKey, EndpointId, RoleName};
use crate::{compute, http, scram};
pub(crate) const X_REQUEST_ID: HeaderName = HeaderName::from_static("x-request-id");
@@ -66,65 +65,34 @@ impl NeonControlPlaneClient {
self.endpoint.url().as_str()
}
async fn do_get_auth_info(
&self,
ctx: &RequestContext,
user_info: &ComputeUserInfo,
) -> Result<AuthInfo, GetAuthInfoError> {
if !self
.caches
.endpoints_cache
.is_valid(ctx, &user_info.endpoint.normalize())
{
// TODO: refactor this because it's weird
// this is a failure to authenticate but we return Ok.
info!("endpoint is not valid, skipping the request");
return Ok(AuthInfo::default());
}
self.do_get_auth_req(user_info, &ctx.session_id(), Some(ctx))
.await
}
async fn do_get_auth_req(
&self,
user_info: &ComputeUserInfo,
session_id: &uuid::Uuid,
ctx: Option<&RequestContext>,
ctx: &RequestContext,
endpoint: &EndpointId,
role: &RoleName,
) -> Result<AuthInfo, GetAuthInfoError> {
let request_id: String = session_id.to_string();
let application_name = if let Some(ctx) = ctx {
ctx.console_application_name()
} else {
"auth_cancellation".to_string()
};
async {
let request = self
.endpoint
.get_path("get_endpoint_access_control")
.header(X_REQUEST_ID, &request_id)
.header(X_REQUEST_ID, ctx.session_id().to_string())
.header(AUTHORIZATION, format!("Bearer {}", &self.jwt))
.query(&[("session_id", session_id)])
.query(&[("session_id", ctx.session_id())])
.query(&[
("application_name", application_name.as_str()),
("endpointish", user_info.endpoint.as_str()),
("role", user_info.user.as_str()),
("application_name", ctx.console_application_name().as_str()),
("endpointish", endpoint.as_str()),
("role", role.as_str()),
])
.build()?;
debug!(url = request.url().as_str(), "sending http request");
let start = Instant::now();
let response = match ctx {
Some(ctx) => {
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Cplane);
let rsp = self.endpoint.execute(request).await;
drop(pause);
rsp?
}
None => self.endpoint.execute(request).await?,
let response = {
let _pause = ctx.latency_timer_pause_at(start, crate::metrics::Waiting::Cplane);
self.endpoint.execute(request).await?
};
info!(duration = ?start.elapsed(), "received http response");
let body = match parse_body::<GetEndpointAccessControl>(response).await {
Ok(body) => body,
// Error 404 is special: it's ok not to have a secret.
@@ -180,7 +148,7 @@ impl NeonControlPlaneClient {
async fn do_get_endpoint_jwks(
&self,
ctx: &RequestContext,
endpoint: EndpointId,
endpoint: &EndpointId,
) -> Result<Vec<AuthRule>, GetEndpointJwksError> {
if !self
.caches
@@ -313,225 +281,104 @@ impl NeonControlPlaneClient {
impl super::ControlPlaneApi for NeonControlPlaneClient {
#[tracing::instrument(skip_all)]
async fn get_role_secret(
async fn get_role_access_control(
&self,
ctx: &RequestContext,
user_info: &ComputeUserInfo,
) -> Result<CachedRoleSecret, GetAuthInfoError> {
let normalized_ep = &user_info.endpoint.normalize();
let user = &user_info.user;
if let Some(role_secret) = self
endpoint: &EndpointId,
role: &RoleName,
) -> Result<RoleAccessControl, crate::control_plane::errors::GetAuthInfoError> {
let normalized_ep = &endpoint.normalize();
if let Some(secret) = self
.caches
.project_info
.get_role_secret(normalized_ep, user)
.get_role_secret(normalized_ep, role)
{
return Ok(role_secret);
return Ok(secret);
}
let auth_info = self.do_get_auth_info(ctx, user_info).await?;
let account_id = auth_info.account_id;
if !self.caches.endpoints_cache.is_valid(ctx, normalized_ep) {
info!("endpoint is not valid, skipping the request");
return Err(GetAuthInfoError::UnknownEndpoint);
}
let auth_info = self.do_get_auth_req(ctx, endpoint, role).await?;
let control = EndpointAccessControl {
allowed_ips: Arc::new(auth_info.allowed_ips),
allowed_vpce: Arc::new(auth_info.allowed_vpc_endpoint_ids),
flags: auth_info.access_blocker_flags,
};
let role_control = RoleAccessControl {
secret: auth_info.secret,
};
if let Some(project_id) = auth_info.project_id {
let normalized_ep_int = normalized_ep.into();
self.caches.project_info.insert_role_secret(
self.caches.project_info.insert_endpoint_access(
auth_info.account_id,
project_id,
normalized_ep_int,
user.into(),
auth_info.secret.clone(),
);
self.caches.project_info.insert_allowed_ips(
project_id,
normalized_ep_int,
Arc::new(auth_info.allowed_ips),
);
self.caches.project_info.insert_allowed_vpc_endpoint_ids(
account_id,
project_id,
normalized_ep_int,
Arc::new(auth_info.allowed_vpc_endpoint_ids),
);
self.caches.project_info.insert_block_public_or_vpc_access(
project_id,
normalized_ep_int,
auth_info.access_blocker_flags,
role.into(),
control,
role_control.clone(),
);
ctx.set_project_id(project_id);
}
// When we just got a secret, we don't need to invalidate it.
Ok(Cached::new_uncached(auth_info.secret))
Ok(role_control)
}
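// Illustrative summary, not part of this change: this is a cache-aside lookup. 1) try the local
// project_info cache; 2) bail out early if the endpoint cache says the endpoint is not known
// (reported as a service-rate-limit error, since that filter only applies when control plane is
// under load); 3) otherwise fetch from control plane once, then populate both the endpoint-level
// controls and the role-level secret, keyed by project id so Redis invalidation messages can
// evict them later.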
async fn get_allowed_ips(
#[tracing::instrument(skip_all)]
async fn get_endpoint_access_control(
&self,
ctx: &RequestContext,
user_info: &ComputeUserInfo,
) -> Result<CachedAllowedIps, GetAuthInfoError> {
let normalized_ep = &user_info.endpoint.normalize();
if let Some(allowed_ips) = self.caches.project_info.get_allowed_ips(normalized_ep) {
Metrics::get()
.proxy
.allowed_ips_cache_misses // TODO SR: Should we rename this variable to something like allowed_ip_cache_stats?
.inc(CacheOutcome::Hit);
return Ok(allowed_ips);
endpoint: &EndpointId,
role: &RoleName,
) -> Result<EndpointAccessControl, GetAuthInfoError> {
let normalized_ep = &endpoint.normalize();
if let Some(control) = self.caches.project_info.get_endpoint_access(normalized_ep) {
return Ok(control);
}
Metrics::get()
.proxy
.allowed_ips_cache_misses
.inc(CacheOutcome::Miss);
let auth_info = self.do_get_auth_info(ctx, user_info).await?;
let allowed_ips = Arc::new(auth_info.allowed_ips);
let allowed_vpc_endpoint_ids = Arc::new(auth_info.allowed_vpc_endpoint_ids);
let access_blocker_flags = auth_info.access_blocker_flags;
let user = &user_info.user;
let account_id = auth_info.account_id;
if !self.caches.endpoints_cache.is_valid(ctx, normalized_ep) {
info!("endpoint is not valid, skipping the request");
return Err(GetAuthInfoError::UnknownEndpoint);
}
let auth_info = self.do_get_auth_req(ctx, endpoint, role).await?;
let control = EndpointAccessControl {
allowed_ips: Arc::new(auth_info.allowed_ips),
allowed_vpce: Arc::new(auth_info.allowed_vpc_endpoint_ids),
flags: auth_info.access_blocker_flags,
};
let role_control = RoleAccessControl {
secret: auth_info.secret,
};
if let Some(project_id) = auth_info.project_id {
let normalized_ep_int = normalized_ep.into();
self.caches.project_info.insert_role_secret(
self.caches.project_info.insert_endpoint_access(
auth_info.account_id,
project_id,
normalized_ep_int,
user.into(),
auth_info.secret.clone(),
);
self.caches.project_info.insert_allowed_ips(
project_id,
normalized_ep_int,
allowed_ips.clone(),
);
self.caches.project_info.insert_allowed_vpc_endpoint_ids(
account_id,
project_id,
normalized_ep_int,
allowed_vpc_endpoint_ids.clone(),
);
self.caches.project_info.insert_block_public_or_vpc_access(
project_id,
normalized_ep_int,
access_blocker_flags,
role.into(),
control.clone(),
role_control,
);
ctx.set_project_id(project_id);
}
Ok(Cached::new_uncached(allowed_ips))
}
async fn get_allowed_vpc_endpoint_ids(
&self,
ctx: &RequestContext,
user_info: &ComputeUserInfo,
) -> Result<CachedAllowedVpcEndpointIds, GetAuthInfoError> {
let normalized_ep = &user_info.endpoint.normalize();
if let Some(allowed_vpc_endpoint_ids) = self
.caches
.project_info
.get_allowed_vpc_endpoint_ids(normalized_ep)
{
Metrics::get()
.proxy
.vpc_endpoint_id_cache_stats
.inc(CacheOutcome::Hit);
return Ok(allowed_vpc_endpoint_ids);
}
Metrics::get()
.proxy
.vpc_endpoint_id_cache_stats
.inc(CacheOutcome::Miss);
let auth_info = self.do_get_auth_info(ctx, user_info).await?;
let allowed_ips = Arc::new(auth_info.allowed_ips);
let allowed_vpc_endpoint_ids = Arc::new(auth_info.allowed_vpc_endpoint_ids);
let access_blocker_flags = auth_info.access_blocker_flags;
let user = &user_info.user;
let account_id = auth_info.account_id;
if let Some(project_id) = auth_info.project_id {
let normalized_ep_int = normalized_ep.into();
self.caches.project_info.insert_role_secret(
project_id,
normalized_ep_int,
user.into(),
auth_info.secret.clone(),
);
self.caches.project_info.insert_allowed_ips(
project_id,
normalized_ep_int,
allowed_ips.clone(),
);
self.caches.project_info.insert_allowed_vpc_endpoint_ids(
account_id,
project_id,
normalized_ep_int,
allowed_vpc_endpoint_ids.clone(),
);
self.caches.project_info.insert_block_public_or_vpc_access(
project_id,
normalized_ep_int,
access_blocker_flags,
);
ctx.set_project_id(project_id);
}
Ok(Cached::new_uncached(allowed_vpc_endpoint_ids))
}
async fn get_block_public_or_vpc_access(
&self,
ctx: &RequestContext,
user_info: &ComputeUserInfo,
) -> Result<CachedAccessBlockerFlags, GetAuthInfoError> {
let normalized_ep = &user_info.endpoint.normalize();
if let Some(access_blocker_flags) = self
.caches
.project_info
.get_block_public_or_vpc_access(normalized_ep)
{
Metrics::get()
.proxy
.access_blocker_flags_cache_stats
.inc(CacheOutcome::Hit);
return Ok(access_blocker_flags);
}
Metrics::get()
.proxy
.access_blocker_flags_cache_stats
.inc(CacheOutcome::Miss);
let auth_info = self.do_get_auth_info(ctx, user_info).await?;
let allowed_ips = Arc::new(auth_info.allowed_ips);
let allowed_vpc_endpoint_ids = Arc::new(auth_info.allowed_vpc_endpoint_ids);
let access_blocker_flags = auth_info.access_blocker_flags;
let user = &user_info.user;
let account_id = auth_info.account_id;
if let Some(project_id) = auth_info.project_id {
let normalized_ep_int = normalized_ep.into();
self.caches.project_info.insert_role_secret(
project_id,
normalized_ep_int,
user.into(),
auth_info.secret.clone(),
);
self.caches.project_info.insert_allowed_ips(
project_id,
normalized_ep_int,
allowed_ips.clone(),
);
self.caches.project_info.insert_allowed_vpc_endpoint_ids(
account_id,
project_id,
normalized_ep_int,
allowed_vpc_endpoint_ids.clone(),
);
self.caches.project_info.insert_block_public_or_vpc_access(
project_id,
normalized_ep_int,
access_blocker_flags.clone(),
);
ctx.set_project_id(project_id);
}
Ok(Cached::new_uncached(access_blocker_flags))
Ok(control)
}
#[tracing::instrument(skip_all)]
async fn get_endpoint_jwks(
&self,
ctx: &RequestContext,
endpoint: EndpointId,
endpoint: &EndpointId,
) -> Result<Vec<AuthRule>, GetEndpointJwksError> {
self.do_get_endpoint_jwks(ctx, endpoint).await
}

View File

@@ -15,14 +15,14 @@ use crate::auth::backend::ComputeUserInfo;
use crate::auth::backend::jwt::AuthRule;
use crate::cache::Cached;
use crate::context::RequestContext;
use crate::control_plane::client::{
CachedAllowedIps, CachedAllowedVpcEndpointIds, CachedRoleSecret,
};
use crate::control_plane::errors::{
ControlPlaneError, GetAuthInfoError, GetEndpointJwksError, WakeComputeError,
};
use crate::control_plane::messages::MetricsAuxInfo;
use crate::control_plane::{AccessBlockerFlags, AuthInfo, AuthSecret, CachedNodeInfo, NodeInfo};
use crate::control_plane::{
AccessBlockerFlags, AuthInfo, AuthSecret, CachedNodeInfo, EndpointAccessControl, NodeInfo,
RoleAccessControl,
};
use crate::intern::RoleNameInt;
use crate::types::{BranchId, EndpointId, ProjectId, RoleName};
use crate::url::ApiUrl;
@@ -66,7 +66,8 @@ impl MockControlPlane {
async fn do_get_auth_info(
&self,
user_info: &ComputeUserInfo,
endpoint: &EndpointId,
role: &RoleName,
) -> Result<AuthInfo, GetAuthInfoError> {
let (secret, allowed_ips) = async {
// Perhaps we could persist this connection, but then we'd have to
@@ -80,7 +81,7 @@ impl MockControlPlane {
let secret = if let Some(entry) = get_execute_postgres_query(
&client,
"select rolpassword from pg_catalog.pg_authid where rolname = $1",
&[&&*user_info.user],
&[&role.as_str()],
"rolpassword",
)
.await?
@@ -89,7 +90,7 @@ impl MockControlPlane {
let secret = scram::ServerSecret::parse(&entry).map(AuthSecret::Scram);
secret.or_else(|| parse_md5(&entry).map(AuthSecret::Md5))
} else {
warn!("user '{}' does not exist", user_info.user);
warn!("user '{role}' does not exist");
None
};
@@ -97,7 +98,7 @@ impl MockControlPlane {
match get_execute_postgres_query(
&client,
"select allowed_ips from neon_control_plane.endpoints where endpoint_id = $1",
&[&user_info.endpoint.as_str()],
&[&endpoint.as_str()],
"allowed_ips",
)
.await?
@@ -133,7 +134,7 @@ impl MockControlPlane {
async fn do_get_endpoint_jwks(
&self,
endpoint: EndpointId,
endpoint: &EndpointId,
) -> Result<Vec<AuthRule>, GetEndpointJwksError> {
let (client, connection) =
tokio_postgres::connect(self.endpoint.as_str(), tokio_postgres::NoTls).await?;
@@ -222,53 +223,36 @@ async fn get_execute_postgres_query(
}
impl super::ControlPlaneApi for MockControlPlane {
#[tracing::instrument(skip_all)]
async fn get_role_secret(
async fn get_endpoint_access_control(
&self,
_ctx: &RequestContext,
user_info: &ComputeUserInfo,
) -> Result<CachedRoleSecret, GetAuthInfoError> {
Ok(CachedRoleSecret::new_uncached(
self.do_get_auth_info(user_info).await?.secret,
))
endpoint: &EndpointId,
role: &RoleName,
) -> Result<EndpointAccessControl, GetAuthInfoError> {
let info = self.do_get_auth_info(endpoint, role).await?;
Ok(EndpointAccessControl {
allowed_ips: Arc::new(info.allowed_ips),
allowed_vpce: Arc::new(info.allowed_vpc_endpoint_ids),
flags: info.access_blocker_flags,
})
}
async fn get_allowed_ips(
async fn get_role_access_control(
&self,
_ctx: &RequestContext,
user_info: &ComputeUserInfo,
) -> Result<CachedAllowedIps, GetAuthInfoError> {
Ok(Cached::new_uncached(Arc::new(
self.do_get_auth_info(user_info).await?.allowed_ips,
)))
}
async fn get_allowed_vpc_endpoint_ids(
&self,
_ctx: &RequestContext,
user_info: &ComputeUserInfo,
) -> Result<CachedAllowedVpcEndpointIds, super::errors::GetAuthInfoError> {
Ok(Cached::new_uncached(Arc::new(
self.do_get_auth_info(user_info)
.await?
.allowed_vpc_endpoint_ids,
)))
}
async fn get_block_public_or_vpc_access(
&self,
_ctx: &RequestContext,
user_info: &ComputeUserInfo,
) -> Result<super::CachedAccessBlockerFlags, super::errors::GetAuthInfoError> {
Ok(Cached::new_uncached(
self.do_get_auth_info(user_info).await?.access_blocker_flags,
))
endpoint: &EndpointId,
role: &RoleName,
) -> Result<RoleAccessControl, GetAuthInfoError> {
let info = self.do_get_auth_info(endpoint, role).await?;
Ok(RoleAccessControl {
secret: info.secret,
})
}
async fn get_endpoint_jwks(
&self,
_ctx: &RequestContext,
endpoint: EndpointId,
endpoint: &EndpointId,
) -> Result<Vec<AuthRule>, GetEndpointJwksError> {
self.do_get_endpoint_jwks(endpoint).await
}

View File

@@ -16,15 +16,14 @@ use crate::cache::endpoints::EndpointsCache;
use crate::cache::project_info::ProjectInfoCacheImpl;
use crate::config::{CacheOptions, EndpointCacheConfig, ProjectInfoCacheOptions};
use crate::context::RequestContext;
use crate::control_plane::{
CachedAccessBlockerFlags, CachedAllowedIps, CachedAllowedVpcEndpointIds, CachedNodeInfo,
CachedRoleSecret, ControlPlaneApi, NodeInfoCache, errors,
};
use crate::control_plane::{CachedNodeInfo, ControlPlaneApi, NodeInfoCache, errors};
use crate::error::ReportableError;
use crate::metrics::ApiLockMetrics;
use crate::rate_limiter::{DynamicLimiter, Outcome, RateLimiterConfig, Token};
use crate::types::EndpointId;
use super::{EndpointAccessControl, RoleAccessControl};
#[non_exhaustive]
#[derive(Clone)]
pub enum ControlPlaneClient {
@@ -40,68 +39,42 @@ pub enum ControlPlaneClient {
}
impl ControlPlaneApi for ControlPlaneClient {
async fn get_role_secret(
async fn get_role_access_control(
&self,
ctx: &RequestContext,
user_info: &ComputeUserInfo,
) -> Result<CachedRoleSecret, errors::GetAuthInfoError> {
endpoint: &EndpointId,
role: &crate::types::RoleName,
) -> Result<RoleAccessControl, errors::GetAuthInfoError> {
match self {
Self::ProxyV1(api) => api.get_role_secret(ctx, user_info).await,
Self::ProxyV1(api) => api.get_role_access_control(ctx, endpoint, role).await,
#[cfg(any(test, feature = "testing"))]
Self::PostgresMock(api) => api.get_role_secret(ctx, user_info).await,
Self::PostgresMock(api) => api.get_role_access_control(ctx, endpoint, role).await,
#[cfg(test)]
Self::Test(_) => {
Self::Test(_api) => {
unreachable!("this function should never be called in the test backend")
}
}
}
async fn get_allowed_ips(
async fn get_endpoint_access_control(
&self,
ctx: &RequestContext,
user_info: &ComputeUserInfo,
) -> Result<CachedAllowedIps, errors::GetAuthInfoError> {
endpoint: &EndpointId,
role: &crate::types::RoleName,
) -> Result<EndpointAccessControl, errors::GetAuthInfoError> {
match self {
Self::ProxyV1(api) => api.get_allowed_ips(ctx, user_info).await,
Self::ProxyV1(api) => api.get_endpoint_access_control(ctx, endpoint, role).await,
#[cfg(any(test, feature = "testing"))]
Self::PostgresMock(api) => api.get_allowed_ips(ctx, user_info).await,
Self::PostgresMock(api) => api.get_endpoint_access_control(ctx, endpoint, role).await,
#[cfg(test)]
Self::Test(api) => api.get_allowed_ips(),
}
}
async fn get_allowed_vpc_endpoint_ids(
&self,
ctx: &RequestContext,
user_info: &ComputeUserInfo,
) -> Result<CachedAllowedVpcEndpointIds, errors::GetAuthInfoError> {
match self {
Self::ProxyV1(api) => api.get_allowed_vpc_endpoint_ids(ctx, user_info).await,
#[cfg(any(test, feature = "testing"))]
Self::PostgresMock(api) => api.get_allowed_vpc_endpoint_ids(ctx, user_info).await,
#[cfg(test)]
Self::Test(api) => api.get_allowed_vpc_endpoint_ids(),
}
}
async fn get_block_public_or_vpc_access(
&self,
ctx: &RequestContext,
user_info: &ComputeUserInfo,
) -> Result<CachedAccessBlockerFlags, errors::GetAuthInfoError> {
match self {
Self::ProxyV1(api) => api.get_block_public_or_vpc_access(ctx, user_info).await,
#[cfg(any(test, feature = "testing"))]
Self::PostgresMock(api) => api.get_block_public_or_vpc_access(ctx, user_info).await,
#[cfg(test)]
Self::Test(api) => api.get_block_public_or_vpc_access(),
Self::Test(api) => api.get_access_control(),
}
}
async fn get_endpoint_jwks(
&self,
ctx: &RequestContext,
endpoint: EndpointId,
endpoint: &EndpointId,
) -> Result<Vec<AuthRule>, errors::GetEndpointJwksError> {
match self {
Self::ProxyV1(api) => api.get_endpoint_jwks(ctx, endpoint).await,
@@ -131,15 +104,7 @@ impl ControlPlaneApi for ControlPlaneClient {
pub(crate) trait TestControlPlaneClient: Send + Sync + 'static {
fn wake_compute(&self) -> Result<CachedNodeInfo, errors::WakeComputeError>;
fn get_allowed_ips(&self) -> Result<CachedAllowedIps, errors::GetAuthInfoError>;
fn get_allowed_vpc_endpoint_ids(
&self,
) -> Result<CachedAllowedVpcEndpointIds, errors::GetAuthInfoError>;
fn get_block_public_or_vpc_access(
&self,
) -> Result<CachedAccessBlockerFlags, errors::GetAuthInfoError>;
fn get_access_control(&self) -> Result<EndpointAccessControl, errors::GetAuthInfoError>;
fn dyn_clone(&self) -> Box<dyn TestControlPlaneClient>;
}
@@ -309,7 +274,7 @@ impl FetchAuthRules for ControlPlaneClient {
ctx: &RequestContext,
endpoint: EndpointId,
) -> Result<Vec<AuthRule>, FetchAuthRulesError> {
self.get_endpoint_jwks(ctx, endpoint)
self.get_endpoint_jwks(ctx, &endpoint)
.await
.map_err(FetchAuthRulesError::GetEndpointJwks)
}

View File

@@ -99,6 +99,10 @@ pub(crate) enum GetAuthInfoError {
#[error(transparent)]
ApiError(ControlPlaneError),
/// Proxy does not know about the endpoint in advance
#[error("endpoint not found in endpoint cache")]
UnknownEndpoint,
}
// This allows more useful interactions than `#[from]`.
@@ -115,6 +119,8 @@ impl UserFacingError for GetAuthInfoError {
Self::BadSecret => REQUEST_FAILED.to_owned(),
// However, API might return a meaningful error.
Self::ApiError(e) => e.to_string_client(),
// pretend that the control plane returned an error.
Self::UnknownEndpoint => REQUEST_FAILED.to_owned(),
}
}
}
@@ -124,6 +130,8 @@ impl ReportableError for GetAuthInfoError {
match self {
Self::BadSecret => crate::error::ErrorKind::ControlPlane,
Self::ApiError(_) => crate::error::ErrorKind::ControlPlane,
// we only apply endpoint filtering if control plane is under high load.
Self::UnknownEndpoint => crate::error::ErrorKind::ServiceRateLimit,
}
}
}

View File

@@ -11,16 +11,16 @@ pub(crate) mod errors;
use std::sync::Arc;
use crate::auth::IpPattern;
use crate::auth::backend::jwt::AuthRule;
use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo};
use crate::cache::project_info::ProjectInfoCacheImpl;
use crate::auth::{AuthError, IpPattern, check_peer_addr_is_in_list};
use crate::cache::{Cached, TimedLru};
use crate::config::ComputeConfig;
use crate::context::RequestContext;
use crate::control_plane::messages::{ControlPlaneErrorMessage, MetricsAuxInfo};
use crate::intern::{AccountIdInt, ProjectIdInt};
use crate::types::{EndpointCacheKey, EndpointId};
use crate::protocol2::ConnectionInfoExtra;
use crate::types::{EndpointCacheKey, EndpointId, RoleName};
use crate::{compute, scram};
/// Various cache-related types.
@@ -101,7 +101,7 @@ impl NodeInfo {
}
}
#[derive(Clone, Default, Eq, PartialEq, Debug)]
#[derive(Copy, Clone, Default)]
pub(crate) struct AccessBlockerFlags {
pub public_access_blocked: bool,
pub vpc_access_blocked: bool,
@@ -110,47 +110,78 @@ pub(crate) struct AccessBlockerFlags {
pub(crate) type NodeInfoCache =
TimedLru<EndpointCacheKey, Result<NodeInfo, Box<ControlPlaneErrorMessage>>>;
pub(crate) type CachedNodeInfo = Cached<&'static NodeInfoCache, NodeInfo>;
pub(crate) type CachedRoleSecret = Cached<&'static ProjectInfoCacheImpl, Option<AuthSecret>>;
pub(crate) type CachedAllowedIps = Cached<&'static ProjectInfoCacheImpl, Arc<Vec<IpPattern>>>;
pub(crate) type CachedAllowedVpcEndpointIds =
Cached<&'static ProjectInfoCacheImpl, Arc<Vec<String>>>;
pub(crate) type CachedAccessBlockerFlags =
Cached<&'static ProjectInfoCacheImpl, AccessBlockerFlags>;
#[derive(Clone)]
pub struct RoleAccessControl {
pub secret: Option<AuthSecret>,
}
#[derive(Clone)]
pub struct EndpointAccessControl {
pub allowed_ips: Arc<Vec<IpPattern>>,
pub allowed_vpce: Arc<Vec<String>>,
pub flags: AccessBlockerFlags,
}
impl EndpointAccessControl {
pub fn check(
&self,
ctx: &RequestContext,
check_ip_allowed: bool,
check_vpc_allowed: bool,
) -> Result<(), AuthError> {
if check_ip_allowed && !check_peer_addr_is_in_list(&ctx.peer_addr(), &self.allowed_ips) {
return Err(AuthError::IpAddressNotAllowed(ctx.peer_addr()));
}
// check whether a VPC endpoint ID is present and, if so, whether it's allowed
if check_vpc_allowed {
if self.flags.vpc_access_blocked {
return Err(AuthError::NetworkNotAllowed);
}
let incoming_vpc_endpoint_id = match ctx.extra() {
None => return Err(AuthError::MissingVPCEndpointId),
Some(ConnectionInfoExtra::Aws { vpce_id }) => vpce_id.to_string(),
Some(ConnectionInfoExtra::Azure { link_id }) => link_id.to_string(),
};
let vpce = &self.allowed_vpce;
// TODO: For now an empty VPC endpoint ID list means all are allowed. We should replace that.
if !vpce.is_empty() && !vpce.contains(&incoming_vpc_endpoint_id) {
return Err(AuthError::vpc_endpoint_id_not_allowed(
incoming_vpc_endpoint_id,
));
}
} else if self.flags.public_access_blocked {
return Err(AuthError::NetworkNotAllowed);
}
Ok(())
}
}
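// Illustrative sketch, not part of this change: a caller typically derives the two flags from
// the auth configuration and lets check() enforce everything in one place, e.g.
//
//     access_control.check(
//         &ctx,
//         config.ip_allowlist_check_enabled,
//         config.is_vpc_acccess_proxy, // field name spelled this way in the config struct
//     )?;
//
// The decision is: the peer IP must match allowed_ips (if IP checking is on); for VPC traffic,
// vpc_access_blocked rejects outright and the incoming VPC endpoint id must appear in
// allowed_vpce (an empty list currently means allow all); for public traffic,
// public_access_blocked rejects the connection.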
/// This will allocate on each call, but the http requests alone
/// already require a few allocations, so it should be fine.
pub(crate) trait ControlPlaneApi {
/// Get the client's auth secret for authentication.
/// Returns an `Option` because the user-not-found situation is special.
/// We still have to mock the SCRAM exchange to avoid leaking information that the user doesn't exist.
async fn get_role_secret(
async fn get_role_access_control(
&self,
ctx: &RequestContext,
user_info: &ComputeUserInfo,
) -> Result<CachedRoleSecret, errors::GetAuthInfoError>;
endpoint: &EndpointId,
role: &RoleName,
) -> Result<RoleAccessControl, errors::GetAuthInfoError>;
async fn get_allowed_ips(
async fn get_endpoint_access_control(
&self,
ctx: &RequestContext,
user_info: &ComputeUserInfo,
) -> Result<CachedAllowedIps, errors::GetAuthInfoError>;
async fn get_allowed_vpc_endpoint_ids(
&self,
ctx: &RequestContext,
user_info: &ComputeUserInfo,
) -> Result<CachedAllowedVpcEndpointIds, errors::GetAuthInfoError>;
async fn get_block_public_or_vpc_access(
&self,
ctx: &RequestContext,
user_info: &ComputeUserInfo,
) -> Result<CachedAccessBlockerFlags, errors::GetAuthInfoError>;
endpoint: &EndpointId,
role: &RoleName,
) -> Result<EndpointAccessControl, errors::GetAuthInfoError>;
async fn get_endpoint_jwks(
&self,
ctx: &RequestContext,
endpoint: EndpointId,
endpoint: &EndpointId,
) -> Result<Vec<AuthRule>, errors::GetEndpointJwksError>;
/// Wake up the compute node and return the corresponding connection info.

View File

@@ -92,6 +92,7 @@ mod logging;
mod metrics;
mod parse;
mod pglb;
mod pqproto;
mod protocol2;
mod proxy;
mod rate_limiter;

proxy/src/pqproto.rs Normal file
View File

@@ -0,0 +1,693 @@
//! Postgres protocol codec
//!
//! <https://www.postgresql.org/docs/current/protocol-message-formats.html>
use std::fmt;
use std::io::{self, Cursor};
use bytes::{Buf, BufMut};
use itertools::Itertools;
use rand::distributions::{Distribution, Standard};
use tokio::io::{AsyncRead, AsyncReadExt};
use zerocopy::{FromBytes, Immutable, IntoBytes, big_endian};
pub type ErrorCode = [u8; 5];
pub const FE_PASSWORD_MESSAGE: u8 = b'p';
pub const SQLSTATE_INTERNAL_ERROR: [u8; 5] = *b"XX000";
/// The protocol version number.
///
/// The most significant 16 bits are the major version number (3 for the protocol described here).
/// The least significant 16 bits are the minor version number (0 for the protocol described here).
/// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-STARTUPMESSAGE>
#[derive(Clone, Copy, PartialEq, PartialOrd, FromBytes, IntoBytes, Immutable)]
#[repr(C)]
pub struct ProtocolVersion {
major: big_endian::U16,
minor: big_endian::U16,
}
impl ProtocolVersion {
pub const fn new(major: u16, minor: u16) -> Self {
Self {
major: big_endian::U16::new(major),
minor: big_endian::U16::new(minor),
}
}
pub const fn minor(self) -> u16 {
self.minor.get()
}
pub const fn major(self) -> u16 {
self.major.get()
}
}
impl fmt::Debug for ProtocolVersion {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_list()
.entry(&self.major())
.entry(&self.minor())
.finish()
}
}
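// Illustrative sketch, not part of this change: the struct above is exactly the 4-byte wire
// encoding, so protocol 3.0 is the bytes [0, 3, 0, 0]. A hypothetical unit test making that
// concrete (IntoBytes is already imported at the top of this module):
//
//     #[test]
//     fn protocol_version_wire_format() {
//         let v = ProtocolVersion::new(3, 0);
//         assert_eq!(v.as_bytes(), &[0, 3, 0, 0]);
//         assert_eq!((v.major(), v.minor()), (3, 0));
//     }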
/// read the type from the stream using zerocopy.
///
/// not cancel safe.
macro_rules! read {
($s:expr => $t:ty) => {{
// cannot be implemented as a function due to lack of const-generic-expr
let mut buf = [0; size_of::<$t>()];
$s.read_exact(&mut buf).await?;
let res: $t = zerocopy::transmute!(buf);
res
}};
}
pub async fn read_startup<S>(stream: &mut S) -> io::Result<FeStartupPacket>
where
S: AsyncRead + Unpin,
{
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L118>
const MAX_STARTUP_PACKET_LENGTH: usize = 10000;
const RESERVED_INVALID_MAJOR_VERSION: u16 = 1234;
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L132>
const CANCEL_REQUEST_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5678);
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L166>
const NEGOTIATE_SSL_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5679);
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L167>
const NEGOTIATE_GSS_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5680);
/// This first reads the startup message header, which is 8 bytes.
/// The first 4 bytes are a big-endian message length, and the next 4 bytes are a version number.
///
/// The length value is inclusive of the header. For example,
/// an empty message will always have length 8.
#[derive(Clone, Copy, FromBytes, IntoBytes, Immutable)]
#[repr(C)]
struct StartupHeader {
len: big_endian::U32,
version: ProtocolVersion,
}
let header = read!(stream => StartupHeader);
// <https://github.com/postgres/postgres/blob/04bcf9e19a4261fe9c7df37c777592c2e10c32a7/src/backend/tcop/backend_startup.c#L378-L382>
// First byte indicates standard SSL handshake message
// (It can't be a Postgres startup length because in network byte order
// that would be a startup packet hundreds of megabytes long)
if header.as_bytes()[0] == 0x16 {
return Ok(FeStartupPacket::SslRequest {
// The bytes we read for the header are actually part of a TLS ClientHello.
// In theory, if the ClientHello was < 8 bytes we would fail with EOF before we get here.
// In practice though, I see no world where a ClientHello is less than 8 bytes
// since it includes ephemeral keys etc.
direct: Some(zerocopy::transmute!(header)),
});
}
let Some(len) = (header.len.get() as usize).checked_sub(8) else {
return Err(io::Error::other(format!(
"invalid startup message length {}, must be at least 8.",
header.len,
)));
};
// TODO: add a histogram for startup packet lengths
if len > MAX_STARTUP_PACKET_LENGTH {
tracing::warn!("large startup message detected: {len} bytes");
return Err(io::Error::other(format!(
"invalid startup message length {len}"
)));
}
match header.version {
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-CANCELREQUEST>
CANCEL_REQUEST_CODE => {
if len != 8 {
return Err(io::Error::other(
"CancelRequest message is malformed, backend PID / secret key missing",
));
}
Ok(FeStartupPacket::CancelRequest(
read!(stream => CancelKeyData),
))
}
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-SSLREQUEST>
NEGOTIATE_SSL_CODE => {
// Requested upgrade to SSL (aka TLS)
Ok(FeStartupPacket::SslRequest { direct: None })
}
NEGOTIATE_GSS_CODE => {
// Requested upgrade to GSSAPI
Ok(FeStartupPacket::GssEncRequest)
}
version if version.major() == RESERVED_INVALID_MAJOR_VERSION => Err(io::Error::other(
format!("Unrecognized request code {version:?}"),
)),
// StartupMessage
version => {
// The protocol version number is followed by one or more pairs of parameter name and value strings.
// A zero byte is required as a terminator after the last name/value pair.
// Parameters can appear in any order. user is required, others are optional.
let mut buf = vec![0; len];
stream.read_exact(&mut buf).await?;
if buf.pop() != Some(b'\0') {
return Err(io::Error::other(
"StartupMessage params: missing null terminator",
));
}
// TODO: Don't do this.
// There's no guarantee that these messages are utf8,
// but they usually happen to be simple ascii.
let params = String::from_utf8(buf)
.map_err(|_| io::Error::other("StartupMessage params: invalid utf-8"))?;
Ok(FeStartupPacket::StartupMessage {
version,
params: StartupMessageParams { params },
})
}
}
}
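// Illustrative walk-through, not part of this change: a minimal StartupMessage for user
// "alice" is 20 bytes on the wire --
//
//     00 00 00 14                  length = 20 (includes these 8 header bytes)
//     00 03 00 00                  protocol version 3.0
//     75 73 65 72 00               "user\0"
//     61 6c 69 63 65 00            "alice\0"
//     00                           terminator
//
// read_startup() subtracts the 8-byte header (len = 12), reads the remaining bytes, pops the
// trailing NUL, and returns FeStartupPacket::StartupMessage with params "user\0alice\0".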
/// Read a raw postgres packet, which will respect the max length requested.
///
/// This returns the message tag, as well as the message body. The message
/// body is written into `buf`, and it is otherwise completely overwritten.
///
/// This is not cancel safe.
pub async fn read_message<'a, S>(
stream: &mut S,
buf: &'a mut Vec<u8>,
max: usize,
) -> io::Result<(u8, &'a mut [u8])>
where
S: AsyncRead + Unpin,
{
/// This first reads the header, which for regular messages in the 3.0 protocol is 5 bytes.
/// The first byte is a message tag, and the next 4 bytes are a big-endian length.
///
/// Awkwardly, the length value is inclusive of itself, but not of the tag. For example,
/// an empty message will always have length 4.
#[derive(Clone, Copy, FromBytes)]
#[repr(C)]
struct Header {
tag: u8,
len: big_endian::U32,
}
let header = read!(stream => Header);
// as described above, the length must be at least 4.
let Some(len) = (header.len.get() as usize).checked_sub(4) else {
return Err(io::Error::other(format!(
"invalid startup message length {}, must be at least 4.",
header.len,
)));
};
// TODO: add a histogram for message lengths
// check if the message exceeds our desired max.
if len > max {
tracing::warn!("large postgres message detected: {len} bytes");
return Err(io::Error::other(format!("invalid message length {len}")));
}
// read in our entire message.
buf.resize(len, 0);
stream.read_exact(buf).await?;
Ok((header.tag, buf))
}
pub struct WriteBuf(Cursor<Vec<u8>>);
impl Buf for WriteBuf {
#[inline]
fn remaining(&self) -> usize {
self.0.remaining()
}
#[inline]
fn chunk(&self) -> &[u8] {
self.0.chunk()
}
#[inline]
fn advance(&mut self, cnt: usize) {
self.0.advance(cnt);
}
}
impl WriteBuf {
pub const fn new() -> Self {
Self(Cursor::new(Vec::new()))
}
/// Use a heuristic to determine if we should shrink the write buffer.
#[inline]
fn should_shrink(&self) -> bool {
let n = self.0.position() as usize;
let len = self.0.get_ref().len();
// the already-flushed space at the front of our buffer exceeds the filled portion,
// i.e. more than half the buffer is unused.
n + n > len
}
/// Shrink the write buffer so that subsequent writes have more spare capacity.
#[cold]
fn shrink(&mut self) {
let n = self.0.position() as usize;
let buf = self.0.get_mut();
// buf repr:
// [----unused------|-----filled-----|-----uninit-----]
// ^ n ^ buf.len() ^ buf.capacity()
let filled = n..buf.len();
let filled_len = filled.len();
buf.copy_within(filled, 0);
buf.truncate(filled_len);
self.0.set_position(0);
}
/// Clear the write buffer.
pub fn reset(&mut self) {
let buf = self.0.get_mut();
buf.clear();
self.0.set_position(0);
}
/// Write a raw message to the internal buffer.
///
/// The size_hint value is only a hint for reserving space. It's ok if it's incorrect, since
/// we calculate the length after the fact.
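///
/// For example, `write_raw(1, b'Z', |buf| buf.put_u8(b'I'))` (as used for
/// ReadyForQuery below) produces `b"Z\0\0\0\x05I"`: the 4-byte length (5) is
/// patched in after the closure has run.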
pub fn write_raw(&mut self, size_hint: usize, tag: u8, f: impl FnOnce(&mut Vec<u8>)) {
if self.should_shrink() {
self.shrink();
}
let buf = self.0.get_mut();
buf.reserve(5 + size_hint);
buf.push(tag);
let start = buf.len();
buf.extend_from_slice(&[0, 0, 0, 0]);
f(buf);
let end = buf.len();
let len = (end - start) as u32;
buf[start..start + 4].copy_from_slice(&len.to_be_bytes());
}
/// Write an encryption response message.
pub fn encryption(&mut self, m: u8) {
self.0.get_mut().push(m);
}
pub fn write_error(&mut self, msg: &str, error_code: ErrorCode) {
self.shrink();
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-ERRORRESPONSE>
// <https://www.postgresql.org/docs/current/protocol-error-fields.html>
// "SERROR\0CXXXXX\0M\0\0".len() == 17
self.write_raw(17 + msg.len(), b'E', |buf| {
// Severity: ERROR
buf.put_slice(b"SERROR\0");
// Code: error_code
buf.put_u8(b'C');
buf.put_slice(&error_code);
buf.put_u8(0);
// Message: msg
buf.put_u8(b'M');
buf.put_slice(msg.as_bytes());
buf.put_u8(0);
// End.
buf.put_u8(0);
});
}
}
#[derive(Debug)]
pub enum FeStartupPacket {
CancelRequest(CancelKeyData),
SslRequest {
direct: Option<[u8; 8]>,
},
GssEncRequest,
StartupMessage {
version: ProtocolVersion,
params: StartupMessageParams,
},
}
#[derive(Debug, Clone, Default)]
pub struct StartupMessageParams {
pub params: String,
}
impl StartupMessageParams {
/// Get parameter's value by its name.
pub fn get(&self, name: &str) -> Option<&str> {
self.iter().find_map(|(k, v)| (k == name).then_some(v))
}
/// Split command-line options according to PostgreSQL's logic,
/// taking into account all escape sequences but leaving them as-is.
/// [`None`] means that there's no `options` in [`Self`].
pub fn options_raw(&self) -> Option<impl Iterator<Item = &str>> {
self.get("options").map(Self::parse_options_raw)
}
/// Split command-line options according to PostgreSQL's logic,
/// taking into account all escape sequences but leaving them as-is.
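///
/// For example, `-c search_path=public -c statement_timeout=0` splits into
/// `["-c", "search_path=public", "-c", "statement_timeout=0"]`, while a
/// backslash-escaped space (`a\ b`) is kept as the single item `a\ b`, with the
/// backslash left in place.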
pub fn parse_options_raw(input: &str) -> impl Iterator<Item = &str> {
// See `postgres: pg_split_opts`.
let mut last_was_escape = false;
input
.split(move |c: char| {
// We split by non-escaped whitespace symbols.
let should_split = c.is_ascii_whitespace() && !last_was_escape;
last_was_escape = c == '\\' && !last_was_escape;
should_split
})
.filter(|s| !s.is_empty())
}
/// Iterate through key-value pairs in an arbitrary order.
pub fn iter(&self) -> impl Iterator<Item = (&str, &str)> {
self.params.split_terminator('\0').tuples()
}
// This constructor is only used in tests.
#[cfg(test)]
pub fn new<'a, const N: usize>(pairs: [(&'a str, &'a str); N]) -> Self {
let mut b = Self {
params: String::new(),
};
for (k, v) in pairs {
b.insert(k, v);
}
b
}
/// Set parameter's value by its name.
/// `name` and `value` must not contain a `\0` byte.
pub fn insert(&mut self, name: &str, value: &str) {
self.params.reserve(name.len() + value.len() + 2);
self.params.push_str(name);
self.params.push('\0');
self.params.push_str(value);
self.params.push('\0');
}
}
/// Cancel keys are usually represented as a PID + secret key pair, but to the proxy
/// they're just opaque bytes.
#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy, FromBytes, IntoBytes, Immutable)]
pub struct CancelKeyData(pub big_endian::U64);
pub fn id_to_cancel_key(id: u64) -> CancelKeyData {
CancelKeyData(big_endian::U64::new(id))
}
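// For example, the legacy pair (backend_pid = 12345, cancel_key = 54321) is packed
// by callers as `id_to_cancel_key((12345 << 32) | 54321)`, i.e. the opaque value
// 0x3039_0000_d431 (see the redis keyspace tests elsewhere in this change).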
impl fmt::Display for CancelKeyData {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let id = self.0;
f.debug_tuple("CancelKeyData")
.field(&format_args!("{id:x}"))
.finish()
}
}
impl Distribution<CancelKeyData> for Standard {
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> CancelKeyData {
id_to_cancel_key(rng.r#gen())
}
}
pub enum BeMessage<'a> {
AuthenticationOk,
AuthenticationSasl(BeAuthenticationSaslMessage<'a>),
AuthenticationCleartextPassword,
BackendKeyData(CancelKeyData),
ParameterStatus {
name: &'a [u8],
value: &'a [u8],
},
ReadyForQuery,
NoticeResponse(&'a str),
NegotiateProtocolVersion {
version: ProtocolVersion,
options: &'a [&'a str],
},
}
#[derive(Debug)]
pub enum BeAuthenticationSaslMessage<'a> {
Methods(&'a [&'a str]),
Continue(&'a [u8]),
Final(&'a [u8]),
}
impl BeMessage<'_> {
/// Write the message into an internal buffer
pub fn write_message(self, buf: &mut WriteBuf) {
match self {
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-AUTHENTICATIONOK>
BeMessage::AuthenticationOk => {
buf.write_raw(1, b'R', |buf| buf.put_i32(0));
}
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-AUTHENTICATIONCLEARTEXTPASSWORD>
BeMessage::AuthenticationCleartextPassword => {
buf.write_raw(1, b'R', |buf| buf.put_i32(3));
}
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-AUTHENTICATIONSASL>
BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(methods)) => {
let len: usize = methods.iter().map(|m| m.len() + 1).sum();
buf.write_raw(len + 2, b'R', |buf| {
buf.put_i32(10); // Specifies that SASL auth method is used.
for method in methods {
buf.put_slice(method.as_bytes());
buf.put_u8(0);
}
buf.put_u8(0); // zero terminator for the list
});
}
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-AUTHENTICATIONSASL>
BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Continue(extra)) => {
buf.write_raw(extra.len() + 1, b'R', |buf| {
buf.put_i32(11); // Continue SASL auth.
buf.put_slice(extra);
});
}
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-AUTHENTICATIONSASL>
BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Final(extra)) => {
buf.write_raw(extra.len() + 1, b'R', |buf| {
buf.put_i32(12); // Send final SASL message.
buf.put_slice(extra);
});
}
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-BACKENDKEYDATA>
BeMessage::BackendKeyData(key_data) => {
buf.write_raw(8, b'K', |buf| buf.put_slice(key_data.as_bytes()));
}
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-NOTICERESPONSE>
// <https://www.postgresql.org/docs/current/protocol-error-fields.html>
BeMessage::NoticeResponse(msg) => {
// 'N' is the message tag for NoticeResponse
buf.write_raw(18 + msg.len(), b'N', |buf| {
// Severity: NOTICE
buf.put_slice(b"SNOTICE\0");
// Code: XX000 (ignored for notice, but still required)
buf.put_slice(b"CXX000\0");
// Message: msg
buf.put_u8(b'M');
buf.put_slice(msg.as_bytes());
buf.put_u8(0);
// End notice.
buf.put_u8(0);
});
}
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-PARAMETERSTATUS>
BeMessage::ParameterStatus { name, value } => {
buf.write_raw(name.len() + value.len() + 2, b'S', |buf| {
buf.put_slice(name);
buf.put_u8(0);
buf.put_slice(value);
buf.put_u8(0);
});
}
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-READYFORQUERY>
BeMessage::ReadyForQuery => {
buf.write_raw(1, b'Z', |buf| buf.put_u8(b'I'));
}
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-NEGOTIATEPROTOCOLVERSION>
BeMessage::NegotiateProtocolVersion { version, options } => {
let len: usize = options.iter().map(|o| o.len() + 1).sum();
buf.write_raw(8 + len, b'v', |buf| {
buf.put_slice(version.as_bytes());
buf.put_u32(options.len() as u32);
for option in options {
buf.put_slice(option.as_bytes());
buf.put_u8(0);
}
});
}
}
}
}
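// A minimal usage sketch (mirroring how PqStream elsewhere in this change queues
// messages before a single flush): write several backend messages into one WriteBuf,
// then hand the buffer to the socket writer in one go, since WriteBuf implements
// bytes::Buf.
//
//     let mut buf = WriteBuf::new();
//     BeMessage::AuthenticationOk.write_message(&mut buf);
//     BeMessage::ReadyForQuery.write_message(&mut buf);
//     // e.g. stream.write_all_buf(&mut buf).await?;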
#[cfg(test)]
mod tests {
use std::io::Cursor;
use tokio::io::{AsyncWriteExt, duplex};
use zerocopy::IntoBytes;
use crate::pqproto::{FeStartupPacket, read_message, read_startup};
use super::ProtocolVersion;
#[tokio::test]
async fn reject_large_startup() {
// we're going to define a v3.0 startup message with far too many parameters.
let mut payload = vec![];
// 10001 + 8 bytes.
payload.extend_from_slice(&10009_u32.to_be_bytes());
payload.extend_from_slice(ProtocolVersion::new(3, 0).as_bytes());
payload.resize(10009, b'a');
let (mut server, mut client) = duplex(128);
#[rustfmt::skip]
let (server, client) = tokio::join!(
async move { read_startup(&mut server).await.unwrap_err() },
async move { client.write_all(&payload).await.unwrap_err() },
);
assert_eq!(server.to_string(), "invalid startup message length 10001");
assert_eq!(client.to_string(), "broken pipe");
}
#[tokio::test]
async fn reject_large_password() {
// we're going to define a password message that is far too long.
let mut payload = vec![];
payload.push(b'p');
payload.extend_from_slice(&517_u32.to_be_bytes());
payload.resize(518, b'a');
let (mut server, mut client) = duplex(128);
#[rustfmt::skip]
let (server, client) = tokio::join!(
async move { read_message(&mut server, &mut vec![], 512).await.unwrap_err() },
async move { client.write_all(&payload).await.unwrap_err() },
);
assert_eq!(server.to_string(), "invalid message length 513");
assert_eq!(client.to_string(), "broken pipe");
}
#[tokio::test]
async fn read_startup_message() {
let mut payload = vec![];
payload.extend_from_slice(&17_u32.to_be_bytes());
payload.extend_from_slice(ProtocolVersion::new(3, 0).as_bytes());
payload.extend_from_slice(b"abc\0def\0\0");
let startup = read_startup(&mut Cursor::new(&payload)).await.unwrap();
let FeStartupPacket::StartupMessage { version, params } = startup else {
panic!("unexpected startup message: {startup:?}");
};
assert_eq!(version.major(), 3);
assert_eq!(version.minor(), 0);
assert_eq!(params.params, "abc\0def\0");
}
#[tokio::test]
async fn read_ssl_message() {
let mut payload = vec![];
payload.extend_from_slice(&8_u32.to_be_bytes());
payload.extend_from_slice(ProtocolVersion::new(1234, 5679).as_bytes());
let startup = read_startup(&mut Cursor::new(&payload)).await.unwrap();
let FeStartupPacket::SslRequest { direct: None } = startup else {
panic!("unexpected startup message: {startup:?}");
};
}
#[tokio::test]
async fn read_tls_message() {
// sample client hello taken from <https://tls13.xargs.org/#client-hello>
let client_hello = [
0x16, 0x03, 0x01, 0x00, 0xf8, 0x01, 0x00, 0x00, 0xf4, 0x03, 0x03, 0x00, 0x01, 0x02,
0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10,
0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e,
0x1f, 0x20, 0xe0, 0xe1, 0xe2, 0xe3, 0xe4, 0xe5, 0xe6, 0xe7, 0xe8, 0xe9, 0xea, 0xeb,
0xec, 0xed, 0xee, 0xef, 0xf0, 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7, 0xf8, 0xf9,
0xfa, 0xfb, 0xfc, 0xfd, 0xfe, 0xff, 0x00, 0x08, 0x13, 0x02, 0x13, 0x03, 0x13, 0x01,
0x00, 0xff, 0x01, 0x00, 0x00, 0xa3, 0x00, 0x00, 0x00, 0x18, 0x00, 0x16, 0x00, 0x00,
0x13, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2e, 0x75, 0x6c, 0x66, 0x68, 0x65,
0x69, 0x6d, 0x2e, 0x6e, 0x65, 0x74, 0x00, 0x0b, 0x00, 0x04, 0x03, 0x00, 0x01, 0x02,
0x00, 0x0a, 0x00, 0x16, 0x00, 0x14, 0x00, 0x1d, 0x00, 0x17, 0x00, 0x1e, 0x00, 0x19,
0x00, 0x18, 0x01, 0x00, 0x01, 0x01, 0x01, 0x02, 0x01, 0x03, 0x01, 0x04, 0x00, 0x23,
0x00, 0x00, 0x00, 0x16, 0x00, 0x00, 0x00, 0x17, 0x00, 0x00, 0x00, 0x0d, 0x00, 0x1e,
0x00, 0x1c, 0x04, 0x03, 0x05, 0x03, 0x06, 0x03, 0x08, 0x07, 0x08, 0x08, 0x08, 0x09,
0x08, 0x0a, 0x08, 0x0b, 0x08, 0x04, 0x08, 0x05, 0x08, 0x06, 0x04, 0x01, 0x05, 0x01,
0x06, 0x01, 0x00, 0x2b, 0x00, 0x03, 0x02, 0x03, 0x04, 0x00, 0x2d, 0x00, 0x02, 0x01,
0x01, 0x00, 0x33, 0x00, 0x26, 0x00, 0x24, 0x00, 0x1d, 0x00, 0x20, 0x35, 0x80, 0x72,
0xd6, 0x36, 0x58, 0x80, 0xd1, 0xae, 0xea, 0x32, 0x9a, 0xdf, 0x91, 0x21, 0x38, 0x38,
0x51, 0xed, 0x21, 0xa2, 0x8e, 0x3b, 0x75, 0xe9, 0x65, 0xd0, 0xd2, 0xcd, 0x16, 0x62,
0x54,
];
let mut cursor = Cursor::new(&client_hello);
let startup = read_startup(&mut cursor).await.unwrap();
let FeStartupPacket::SslRequest {
direct: Some(prefix),
} = startup
else {
panic!("unexpected startup message: {startup:?}");
};
// check that no data is lost.
assert_eq!(prefix, [0x16, 0x03, 0x01, 0x00, 0xf8, 0x01, 0x00, 0x00]);
assert_eq!(cursor.position(), 8);
}
#[tokio::test]
async fn read_message_success() {
let query = b"Q\0\0\0\x0cSELECT 1Q\0\0\0\x0cSELECT 2";
let mut cursor = Cursor::new(&query);
let mut buf = vec![];
let (tag, message) = read_message(&mut cursor, &mut buf, 100).await.unwrap();
assert_eq!(tag, b'Q');
assert_eq!(message, b"SELECT 1");
let (tag, message) = read_message(&mut cursor, &mut buf, 100).await.unwrap();
assert_eq!(tag, b'Q');
assert_eq!(message, b"SELECT 2");
}
}

View File

@@ -1,5 +1,4 @@
use async_trait::async_trait;
use pq_proto::StartupMessageParams;
use tokio::time;
use tracing::{debug, info, warn};
@@ -15,6 +14,7 @@ use crate::error::ReportableError;
use crate::metrics::{
ConnectOutcome, ConnectionFailureKind, Metrics, RetriesMetricGroup, RetryType,
};
use crate::pqproto::StartupMessageParams;
use crate::proxy::retry::{CouldRetry, retry_after, should_retry};
use crate::proxy::wake_compute::wake_compute;
use crate::types::Host;

View File

@@ -1,8 +1,3 @@
use bytes::Buf;
use pq_proto::framed::Framed;
use pq_proto::{
BeMessage as Be, CancelKeyData, FeStartupPacket, ProtocolVersion, StartupMessageParams,
};
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{debug, info, warn};
@@ -12,7 +7,10 @@ use crate::config::TlsConfig;
use crate::context::RequestContext;
use crate::error::ReportableError;
use crate::metrics::Metrics;
use crate::proxy::ERR_INSECURE_CONNECTION;
use crate::pqproto::{
BeMessage, CancelKeyData, FeStartupPacket, ProtocolVersion, StartupMessageParams,
};
use crate::proxy::TlsRequired;
use crate::stream::{PqStream, Stream, StreamUpgradeError};
use crate::tls::PG_ALPN_PROTOCOL;
@@ -71,33 +69,25 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
const PG_PROTOCOL_EARLIEST: ProtocolVersion = ProtocolVersion::new(3, 0);
const PG_PROTOCOL_LATEST: ProtocolVersion = ProtocolVersion::new(3, 0);
let mut stream = PqStream::new(Stream::from_raw(stream));
let (mut stream, mut msg) = PqStream::parse_startup(Stream::from_raw(stream)).await?;
loop {
let msg = stream.read_startup_packet().await?;
match msg {
FeStartupPacket::SslRequest { direct } => match stream.get_ref() {
Stream::Raw { .. } if !tried_ssl => {
tried_ssl = true;
// We can't perform TLS handshake without a config
let have_tls = tls.is_some();
if !direct {
stream
.write_message(&Be::EncryptionResponse(have_tls))
.await?;
} else if !have_tls {
return Err(HandshakeError::ProtocolViolation);
}
if let Some(tls) = tls.take() {
// Upgrade raw stream into a secure TLS-backed stream.
// NOTE: We've consumed `tls`; this fact will be used later.
let Framed {
stream: raw,
read_buf,
write_buf,
} = stream.framed;
let mut read_buf;
let raw = if let Some(direct) = &direct {
read_buf = &direct[..];
stream.accept_direct_tls()
} else {
read_buf = &[];
stream.accept_tls().await?
};
let Stream::Raw { raw } = raw else {
return Err(HandshakeError::StreamUpgradeError(
@@ -105,12 +95,11 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
));
};
let mut read_buf = read_buf.reader();
let mut res = Ok(());
let accept = tokio_rustls::TlsAcceptor::from(tls.pg_config.clone())
.accept_with(raw, |session| {
// push the early data to the tls session
while !read_buf.get_ref().is_empty() {
while !read_buf.is_empty() {
match session.read_tls(&mut read_buf) {
Ok(_) => {}
Err(e) => {
@@ -123,7 +112,6 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
res?;
let read_buf = read_buf.into_inner();
if !read_buf.is_empty() {
return Err(HandshakeError::EarlyData);
}
@@ -157,16 +145,17 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
let (_, tls_server_end_point) =
tls.cert_resolver.resolve(conn_info.server_name());
stream = PqStream {
framed: Framed {
stream: Stream::Tls {
tls: Box::new(tls_stream),
tls_server_end_point,
},
read_buf,
write_buf,
},
let tls = Stream::Tls {
tls: Box::new(tls_stream),
tls_server_end_point,
};
(stream, msg) = PqStream::parse_startup(tls).await?;
} else {
if direct.is_some() {
// client sent us a ClientHello already, we can't do anything with it.
return Err(HandshakeError::ProtocolViolation);
}
msg = stream.reject_encryption().await?;
}
}
_ => return Err(HandshakeError::ProtocolViolation),
@@ -176,7 +165,7 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
tried_gss = true;
// Currently, we don't support GSSAPI
stream.write_message(&Be::EncryptionResponse(false)).await?;
msg = stream.reject_encryption().await?;
}
_ => return Err(HandshakeError::ProtocolViolation),
},
@@ -186,13 +175,7 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
// Check that the config has been consumed during upgrade
// OR we didn't provide it at all (for dev purposes).
if tls.is_some() {
return stream
.throw_error_str(
ERR_INSECURE_CONNECTION,
crate::error::ErrorKind::User,
None,
)
.await?;
Err(stream.throw_error(TlsRequired, None).await)?;
}
// This log highlights the start of the connection.
@@ -214,20 +197,21 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
// no protocol extensions are supported.
// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/backend/tcop/backend_startup.c#L744-L753>
let mut unsupported = vec![];
for (k, _) in params.iter() {
let mut supported = StartupMessageParams::default();
for (k, v) in params.iter() {
if k.starts_with("_pq_.") {
unsupported.push(k);
} else {
supported.insert(k, v);
}
}
// TODO: remove unsupported options so we don't send them to compute.
stream
.write_message(&Be::NegotiateProtocolVersion {
version: PG_PROTOCOL_LATEST,
options: &unsupported,
})
.await?;
stream.write_message(BeMessage::NegotiateProtocolVersion {
version: PG_PROTOCOL_LATEST,
options: &unsupported,
});
stream.flush().await?;
info!(
?version,
@@ -235,7 +219,7 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
session_type = "normal",
"successful handshake; unsupported minor version requested"
);
break Ok(HandshakeData::Startup(stream, params));
break Ok(HandshakeData::Startup(stream, supported));
}
FeStartupPacket::StartupMessage { version, params } => {
warn!(

View File

@@ -10,15 +10,14 @@ pub(crate) mod wake_compute;
use std::sync::Arc;
pub use copy_bidirectional::{ErrorSource, copy_bidirectional_client_compute};
use futures::{FutureExt, TryFutureExt};
use futures::FutureExt;
use itertools::Itertools;
use once_cell::sync::OnceCell;
use pq_proto::{BeMessage as Be, CancelKeyData, StartupMessageParams};
use regex::Regex;
use serde::{Deserialize, Serialize};
use smol_str::{SmolStr, ToSmolStr, format_smolstr};
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::sync::CancellationToken;
use tracing::{Instrument, debug, error, info, warn};
@@ -27,8 +26,9 @@ use self::passthrough::ProxyPassthrough;
use crate::cancellation::{self, CancellationHandler};
use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig};
use crate::context::RequestContext;
use crate::error::ReportableError;
use crate::error::{ReportableError, UserFacingError};
use crate::metrics::{Metrics, NumClientConnectionsGuard};
use crate::pqproto::{BeMessage, CancelKeyData, StartupMessageParams};
use crate::protocol2::{ConnectHeader, ConnectionInfo, ConnectionInfoExtra, read_proxy_protocol};
use crate::proxy::handshake::{HandshakeData, handshake};
use crate::rate_limiter::EndpointRateLimiter;
@@ -38,6 +38,18 @@ use crate::{auth, compute};
const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
#[derive(Error, Debug)]
#[error("{ERR_INSECURE_CONNECTION}")]
pub struct TlsRequired;
impl ReportableError for TlsRequired {
fn get_error_kind(&self) -> crate::error::ErrorKind {
crate::error::ErrorKind::User
}
}
impl UserFacingError for TlsRequired {}
pub async fn run_until_cancelled<F: std::future::Future>(
f: F,
cancellation_token: &CancellationToken,
@@ -329,11 +341,11 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
let user_info = match result {
Ok(user_info) => user_info,
Err(e) => stream.throw_error(e, Some(ctx)).await?,
Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?,
};
let user = user_info.get_user().to_owned();
let (user_info, _ip_allowlist) = match user_info
let user_info = match user_info
.authenticate(
ctx,
&mut stream,
@@ -349,10 +361,10 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
let app = params.get("application_name");
let params_span = tracing::info_span!("", ?user, ?db, ?app);
return stream
return Err(stream
.throw_error(e, Some(ctx))
.instrument(params_span)
.await?;
.await)?;
}
};
@@ -365,7 +377,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
.get(NeonOptions::PARAMS_COMPAT)
.is_some();
let mut node = connect_to_compute(
let res = connect_to_compute(
ctx,
&TcpMechanism {
user_info: compute_user_info.clone(),
@@ -377,22 +389,19 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
config.wake_compute_retry_config,
&config.connect_to_compute,
)
.or_else(|e| stream.throw_error(e, Some(ctx)))
.await?;
.await;
let node = match res {
Ok(node) => node,
Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?,
};
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
let session = cancellation_handler_clone.get_key();
session.write_cancel_key(node.cancel_closure.clone())?;
prepare_client_connection(&node, *session.key(), &mut stream).await?;
// Before proxy passing, forward to compute whatever data is left in the
// PqStream input buffer. Normally there is none, but our serverless npm
// driver in pipeline mode sends startup, password and first query
// immediately after opening the connection.
let (stream, read_buf) = stream.into_inner();
node.stream.write_all(&read_buf).await?;
prepare_client_connection(&node, *session.key(), &mut stream);
let stream = stream.flush_and_into_inner().await?;
let private_link_id = match ctx.extra() {
Some(ConnectionInfoExtra::Aws { vpce_id }) => Some(vpce_id.clone()),
@@ -413,31 +422,28 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
}
/// Finish client connection initialization: confirm auth success, send params, etc.
#[tracing::instrument(skip_all)]
pub(crate) async fn prepare_client_connection(
pub(crate) fn prepare_client_connection(
node: &compute::PostgresConnection,
cancel_key_data: CancelKeyData,
stream: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> Result<(), std::io::Error> {
) {
// Forward all deferred notices to the client.
for notice in &node.delayed_notice {
stream.write_message_noflush(&Be::Raw(b'N', notice.as_bytes()))?;
stream.write_raw(notice.as_bytes().len(), b'N', |buf| {
buf.extend_from_slice(notice.as_bytes());
});
}
// Forward all postgres connection params to the client.
for (name, value) in &node.params {
stream.write_message_noflush(&Be::ParameterStatus {
stream.write_message(BeMessage::ParameterStatus {
name: name.as_bytes(),
value: value.as_bytes(),
})?;
});
}
stream
.write_message_noflush(&Be::BackendKeyData(cancel_key_data))?
.write_message(&Be::ReadyForQuery)
.await?;
Ok(())
stream.write_message(BeMessage::BackendKeyData(cancel_key_data));
stream.write_message(BeMessage::ReadyForQuery);
}
#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]

View File

@@ -125,9 +125,10 @@ pub(crate) fn retry_after(num_retries: u32, config: RetryConfig) -> time::Durati
#[cfg(test)]
mod tests {
use super::ShouldRetryWakeCompute;
use postgres_client::error::{DbError, SqlState};
use super::ShouldRetryWakeCompute;
#[test]
fn should_retry_wake_compute_for_db_error() {
// These SQLStates should NOT trigger a wake_compute retry.

View File

@@ -10,7 +10,7 @@ use bytes::{Bytes, BytesMut};
use futures::{SinkExt, StreamExt};
use postgres_client::tls::TlsConnect;
use postgres_protocol::message::frontend;
use tokio::io::{AsyncReadExt, DuplexStream};
use tokio::io::{AsyncReadExt, AsyncWriteExt, DuplexStream};
use tokio_util::codec::{Decoder, Encoder};
use super::*;
@@ -49,15 +49,14 @@ async fn proxy_mitm(
};
let mut end_server = tokio_util::codec::Framed::new(end_server, PgFrame);
let (end_client, buf) = end_client.framed.into_inner();
assert!(buf.is_empty());
let end_client = end_client.flush_and_into_inner().await.unwrap();
let mut end_client = tokio_util::codec::Framed::new(end_client, PgFrame);
// give the end_server the startup parameters
let mut buf = BytesMut::new();
frontend::startup_message(
&postgres_protocol::message::frontend::StartupMessageParams {
params: startup.params.into(),
params: startup.params.as_bytes().into(),
},
&mut buf,
)

View File

@@ -26,9 +26,7 @@ use crate::auth::backend::{
use crate::config::{ComputeConfig, RetryConfig};
use crate::control_plane::client::{ControlPlaneClient, TestControlPlaneClient};
use crate::control_plane::messages::{ControlPlaneErrorMessage, Details, MetricsAuxInfo, Status};
use crate::control_plane::{
self, CachedAllowedIps, CachedAllowedVpcEndpointIds, CachedNodeInfo, NodeInfo, NodeInfoCache,
};
use crate::control_plane::{self, CachedNodeInfo, NodeInfo, NodeInfoCache};
use crate::error::ErrorKind;
use crate::tls::client_config::compute_client_config_with_certs;
use crate::tls::postgres_rustls::MakeRustlsConnect;
@@ -128,7 +126,7 @@ trait TestAuth: Sized {
self,
stream: &mut PqStream<Stream<S>>,
) -> anyhow::Result<()> {
stream.write_message_noflush(&Be::AuthenticationOk)?;
stream.write_message(BeMessage::AuthenticationOk);
Ok(())
}
}
@@ -157,9 +155,7 @@ impl TestAuth for Scram {
self,
stream: &mut PqStream<Stream<S>>,
) -> anyhow::Result<()> {
let outcome = auth::AuthFlow::new(stream)
.begin(auth::Scram(&self.0, &RequestContext::test()))
.await?
let outcome = auth::AuthFlow::new(stream, auth::Scram(&self.0, &RequestContext::test()))
.authenticate()
.await?;
@@ -185,10 +181,12 @@ async fn dummy_proxy(
auth.authenticate(&mut stream).await?;
stream
.write_message_noflush(&Be::CLIENT_ENCODING)?
.write_message(&Be::ReadyForQuery)
.await?;
stream.write_message(BeMessage::ParameterStatus {
name: b"client_encoding",
value: b"UTF8",
});
stream.write_message(BeMessage::ReadyForQuery);
stream.flush().await?;
Ok(())
}
@@ -547,20 +545,9 @@ impl TestControlPlaneClient for TestConnectMechanism {
}
}
fn get_allowed_ips(&self) -> Result<CachedAllowedIps, control_plane::errors::GetAuthInfoError> {
unimplemented!("not used in tests")
}
fn get_allowed_vpc_endpoint_ids(
fn get_access_control(
&self,
) -> Result<CachedAllowedVpcEndpointIds, control_plane::errors::GetAuthInfoError> {
unimplemented!("not used in tests")
}
fn get_block_public_or_vpc_access(
&self,
) -> Result<control_plane::CachedAccessBlockerFlags, control_plane::errors::GetAuthInfoError>
{
) -> Result<control_plane::EndpointAccessControl, control_plane::errors::GetAuthInfoError> {
unimplemented!("not used in tests")
}

View File

@@ -15,7 +15,7 @@ pub type EndpointRateLimiter = LeakyBucketRateLimiter<EndpointIdInt>;
pub struct LeakyBucketRateLimiter<Key> {
map: ClashMap<Key, LeakyBucketState, RandomState>,
config: utils::leaky_bucket::LeakyBucketConfig,
default_config: utils::leaky_bucket::LeakyBucketConfig,
access_count: AtomicUsize,
}
@@ -28,15 +28,17 @@ impl<K: Hash + Eq> LeakyBucketRateLimiter<K> {
pub fn new_with_shards(config: LeakyBucketConfig, shards: usize) -> Self {
Self {
map: ClashMap::with_hasher_and_shard_amount(RandomState::new(), shards),
config: config.into(),
default_config: config.into(),
access_count: AtomicUsize::new(0),
}
}
/// Check that number of connections to the endpoint is below `max_rps` rps.
pub(crate) fn check(&self, key: K, n: u32) -> bool {
pub(crate) fn check(&self, key: K, config: Option<LeakyBucketConfig>, n: u32) -> bool {
let now = Instant::now();
let config = config.map_or(self.default_config, Into::into);
if self.access_count.fetch_add(1, Ordering::AcqRel) % 2048 == 0 {
self.do_gc(now);
}
@@ -46,7 +48,7 @@ impl<K: Hash + Eq> LeakyBucketRateLimiter<K> {
.entry(key)
.or_insert_with(|| LeakyBucketState { empty_at: now });
entry.add_tokens(&self.config, now, n as f64).is_ok()
entry.add_tokens(&config, now, n as f64).is_ok()
}
fn do_gc(&self, now: Instant) {

View File

@@ -15,6 +15,8 @@ use tracing::info;
use crate::ext::LockExt;
use crate::intern::EndpointIdInt;
use super::LeakyBucketConfig;
pub struct GlobalRateLimiter {
data: Vec<RateBucket>,
info: Vec<RateBucketInfo>,
@@ -144,19 +146,6 @@ impl RateBucketInfo {
Self::new(50_000, Duration::from_secs(10)),
];
/// All of these are per endpoint-maskedip pair.
/// Context: 4096 rounds of pbkdf2 take about 1ms of cpu time to execute (1 milli-cpu-second or 1mcpus).
///
/// First bucket: 1000mcpus total per endpoint-ip pair
/// * 4096000 requests per second with 1 hash rounds.
/// * 1000 requests per second with 4096 hash rounds.
/// * 6.8 requests per second with 600000 hash rounds.
pub const DEFAULT_AUTH_SET: [Self; 3] = [
Self::new(1000 * 4096, Duration::from_secs(1)),
Self::new(600 * 4096, Duration::from_secs(60)),
Self::new(300 * 4096, Duration::from_secs(600)),
];
pub fn rps(&self) -> f64 {
(self.max_rpi as f64) / self.interval.as_secs_f64()
}
@@ -184,6 +173,21 @@ impl RateBucketInfo {
max_rpi: ((max_rps as u64) * (interval.as_millis() as u64) / 1000) as u32,
}
}
pub fn to_leaky_bucket(this: &[Self]) -> Option<LeakyBucketConfig> {
// Bit of a hack: find the min and max rps supported and turn them into a
// leaky bucket config instead.
let mut iter = this.iter().map(|info| info.rps());
let first = iter.next()?;
let (min, max) = (first, first);
let (min, max) = iter.fold((min, max), |(min, max), rps| {
(f64::min(min, rps), f64::max(max, rps))
});
Some(LeakyBucketConfig { rps: min, max })
}
}
impl<K: Hash + Eq> BucketRateLimiter<K> {

View File

@@ -8,4 +8,4 @@ pub(crate) use limit_algorithm::aimd::Aimd;
pub(crate) use limit_algorithm::{
DynamicLimiter, Outcome, RateLimitAlgorithm, RateLimiterConfig, Token,
};
pub use limiter::{BucketRateLimiter, GlobalRateLimiter, RateBucketInfo, WakeComputeRateLimiter};
pub use limiter::{GlobalRateLimiter, RateBucketInfo, WakeComputeRateLimiter};

View File

@@ -1,10 +1,11 @@
use core::net::IpAddr;
use std::sync::Arc;
use pq_proto::CancelKeyData;
use tokio::sync::Mutex;
use uuid::Uuid;
use crate::pqproto::CancelKeyData;
pub trait CancellationPublisherMut: Send + Sync + 'static {
#[allow(async_fn_in_trait)]
async fn try_publish(

View File

@@ -1,16 +1,15 @@
use std::io::ErrorKind;
use anyhow::Ok;
use pq_proto::{CancelKeyData, id_to_cancel_key};
use serde::{Deserialize, Serialize};
use crate::pqproto::{CancelKeyData, id_to_cancel_key};
pub mod keyspace {
pub const CANCEL_PREFIX: &str = "cancel";
}
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
#[derive(Clone, Debug, Eq, PartialEq)]
pub(crate) enum KeyPrefix {
#[serde(untagged)]
Cancel(CancelKeyData),
}
@@ -18,9 +17,7 @@ impl KeyPrefix {
pub(crate) fn build_redis_key(&self) -> String {
match self {
KeyPrefix::Cancel(key) => {
let hi = (key.backend_pid as u64) << 32;
let lo = (key.cancel_key as u64) & 0xffff_ffff;
let id = hi | lo;
let id = key.0.get();
let keyspace = keyspace::CANCEL_PREFIX;
format!("{keyspace}:{id:x}")
}
@@ -63,10 +60,7 @@ mod tests {
#[test]
fn test_build_redis_key() {
let cancel_key: KeyPrefix = KeyPrefix::Cancel(CancelKeyData {
backend_pid: 12345,
cancel_key: 54321,
});
let cancel_key: KeyPrefix = KeyPrefix::Cancel(id_to_cancel_key(12345 << 32 | 54321));
let redis_key = cancel_key.build_redis_key();
assert_eq!(redis_key, "cancel:30390000d431");
@@ -77,10 +71,7 @@ mod tests {
let redis_key = "cancel:30390000d431";
let key: KeyPrefix = parse_redis_key(redis_key).expect("Failed to parse key");
let ref_key = CancelKeyData {
backend_pid: 12345,
cancel_key: 54321,
};
let ref_key = id_to_cancel_key(12345 << 32 | 54321);
assert_eq!(key.as_str(), KeyPrefix::Cancel(ref_key).as_str());
let KeyPrefix::Cancel(cancel_key) = key;

View File

@@ -2,11 +2,9 @@ use std::convert::Infallible;
use std::sync::Arc;
use futures::StreamExt;
use pq_proto::CancelKeyData;
use redis::aio::PubSub;
use serde::{Deserialize, Serialize};
use tokio_util::sync::CancellationToken;
use uuid::Uuid;
use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
use crate::cache::project_info::ProjectInfoCache;
@@ -100,14 +98,6 @@ pub(crate) struct PasswordUpdate {
role_name: RoleNameInt,
}
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct CancelSession {
pub(crate) region_id: Option<String>,
pub(crate) cancel_key_data: CancelKeyData,
pub(crate) session_id: Uuid,
pub(crate) peer_addr: Option<std::net::IpAddr>,
}
fn deserialize_json_string<'de, D, T>(deserializer: D) -> Result<T, D::Error>
where
T: for<'de2> serde::Deserialize<'de2>,
@@ -243,29 +233,30 @@ impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
fn invalidate_cache<C: ProjectInfoCache>(cache: Arc<C>, msg: Notification) {
match msg {
Notification::AllowedIpsUpdate { allowed_ips_update } => {
cache.invalidate_allowed_ips_for_project(allowed_ips_update.project_id);
Notification::AllowedIpsUpdate {
allowed_ips_update: AllowedIpsUpdate { project_id },
}
Notification::BlockPublicOrVpcAccessUpdated {
block_public_or_vpc_access_updated,
} => cache.invalidate_block_public_or_vpc_access_for_project(
block_public_or_vpc_access_updated.project_id,
),
| Notification::BlockPublicOrVpcAccessUpdated {
block_public_or_vpc_access_updated: BlockPublicOrVpcAccessUpdated { project_id },
} => cache.invalidate_endpoint_access_for_project(project_id),
Notification::AllowedVpcEndpointsUpdatedForOrg {
allowed_vpc_endpoints_updated_for_org,
} => cache.invalidate_allowed_vpc_endpoint_ids_for_org(
allowed_vpc_endpoints_updated_for_org.account_id,
),
allowed_vpc_endpoints_updated_for_org: AllowedVpcEndpointsUpdatedForOrg { account_id },
} => cache.invalidate_endpoint_access_for_org(account_id),
Notification::AllowedVpcEndpointsUpdatedForProjects {
allowed_vpc_endpoints_updated_for_projects,
} => cache.invalidate_allowed_vpc_endpoint_ids_for_projects(
allowed_vpc_endpoints_updated_for_projects.project_ids,
),
Notification::PasswordUpdate { password_update } => cache
.invalidate_role_secret_for_project(
password_update.project_id,
password_update.role_name,
),
allowed_vpc_endpoints_updated_for_projects:
AllowedVpcEndpointsUpdatedForProjects { project_ids },
} => {
for project in project_ids {
cache.invalidate_endpoint_access_for_project(project);
}
}
Notification::PasswordUpdate {
password_update:
PasswordUpdate {
project_id,
role_name,
},
} => cache.invalidate_role_secret_for_project(project_id, role_name),
Notification::UnknownTopic => unreachable!(),
}
}

View File

@@ -1,7 +1,5 @@
//! Definitions for SASL messages.
use pq_proto::{BeAuthenticationSaslMessage, BeMessage};
use crate::parse::split_cstr;
/// SASL-specific payload of [`PasswordMessage`](pq_proto::FeMessage::PasswordMessage).
@@ -30,26 +28,6 @@ impl<'a> FirstMessage<'a> {
}
}
/// A single SASL message.
/// This struct is deliberately decoupled from lower-level
/// [`BeAuthenticationSaslMessage`].
#[derive(Debug)]
pub(super) enum ServerMessage<T> {
/// We expect to see more steps.
Continue(T),
/// This is the final step.
Final(T),
}
impl<'a> ServerMessage<&'a str> {
pub(super) fn to_reply(&self) -> BeMessage<'a> {
BeMessage::AuthenticationSasl(match self {
ServerMessage::Continue(s) => BeAuthenticationSaslMessage::Continue(s.as_bytes()),
ServerMessage::Final(s) => BeAuthenticationSaslMessage::Final(s.as_bytes()),
})
}
}
#[cfg(test)]
mod tests {
use super::*;

View File

@@ -14,7 +14,7 @@ use std::io;
pub(crate) use channel_binding::ChannelBinding;
pub(crate) use messages::FirstMessage;
pub(crate) use stream::{Outcome, SaslStream};
pub(crate) use stream::{Outcome, authenticate};
use thiserror::Error;
use crate::error::{ReportableError, UserFacingError};
@@ -22,6 +22,9 @@ use crate::error::{ReportableError, UserFacingError};
/// Fine-grained auth errors help in writing tests.
#[derive(Error, Debug)]
pub(crate) enum Error {
#[error("Unsupported authentication method: {0}")]
BadAuthMethod(Box<str>),
#[error("Channel binding failed: {0}")]
ChannelBindingFailed(&'static str),
@@ -54,6 +57,7 @@ impl UserFacingError for Error {
impl ReportableError for Error {
fn get_error_kind(&self) -> crate::error::ErrorKind {
match self {
Error::BadAuthMethod(_) => crate::error::ErrorKind::User,
Error::ChannelBindingFailed(_) => crate::error::ErrorKind::User,
Error::ChannelBindingBadMethod(_) => crate::error::ErrorKind::User,
Error::BadClientMessage(_) => crate::error::ErrorKind::User,

View File

@@ -3,61 +3,12 @@
use std::io;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::info;
use super::Mechanism;
use super::messages::ServerMessage;
use super::{Mechanism, Step};
use crate::context::RequestContext;
use crate::pqproto::{BeAuthenticationSaslMessage, BeMessage};
use crate::stream::PqStream;
/// Abstracts away all peculiarities of the libpq's protocol.
pub(crate) struct SaslStream<'a, S> {
/// The underlying stream.
stream: &'a mut PqStream<S>,
/// Current password message we received from client.
current: bytes::Bytes,
/// First SASL message produced by client.
first: Option<&'a str>,
}
impl<'a, S> SaslStream<'a, S> {
pub(crate) fn new(stream: &'a mut PqStream<S>, first: &'a str) -> Self {
Self {
stream,
current: bytes::Bytes::new(),
first: Some(first),
}
}
}
impl<S: AsyncRead + Unpin> SaslStream<'_, S> {
// Receive a new SASL message from the client.
async fn recv(&mut self) -> io::Result<&str> {
if let Some(first) = self.first.take() {
return Ok(first);
}
self.current = self.stream.read_password_message().await?;
let s = std::str::from_utf8(&self.current)
.map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "bad encoding"))?;
Ok(s)
}
}
impl<S: AsyncWrite + Unpin> SaslStream<'_, S> {
// Send a SASL message to the client.
async fn send(&mut self, msg: &ServerMessage<&str>) -> io::Result<()> {
self.stream.write_message(&msg.to_reply()).await?;
Ok(())
}
// Queue a SASL message for the client.
fn send_noflush(&mut self, msg: &ServerMessage<&str>) -> io::Result<()> {
self.stream.write_message_noflush(&msg.to_reply())?;
Ok(())
}
}
/// SASL authentication outcome.
/// It's much easier to match on those two variants
/// than to peek into a noisy protocol error type.
@@ -69,33 +20,62 @@ pub(crate) enum Outcome<R> {
Failure(&'static str),
}
impl<S: AsyncRead + AsyncWrite + Unpin> SaslStream<'_, S> {
/// Perform SASL message exchange according to the underlying algorithm
/// until user is either authenticated or denied access.
pub(crate) async fn authenticate<M: Mechanism>(
mut self,
mut mechanism: M,
) -> super::Result<Outcome<M::Output>> {
loop {
let input = self.recv().await?;
let step = mechanism.exchange(input).map_err(|error| {
info!(?error, "error during SASL exchange");
error
})?;
pub async fn authenticate<S, F, M>(
ctx: &RequestContext,
stream: &mut PqStream<S>,
mechanism: F,
) -> super::Result<Outcome<M::Output>>
where
S: AsyncRead + AsyncWrite + Unpin,
F: FnOnce(&str) -> super::Result<M>,
M: Mechanism,
{
let sasl = {
// pause the timer while we communicate with the client
let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
use super::Step;
return Ok(match step {
Step::Continue(moved_mechanism, reply) => {
self.send(&ServerMessage::Continue(&reply)).await?;
mechanism = moved_mechanism;
continue;
}
Step::Success(result, reply) => {
self.send_noflush(&ServerMessage::Final(&reply))?;
Outcome::Success(result)
}
Step::Failure(reason) => Outcome::Failure(reason),
});
// Initial client message contains the chosen auth method's name.
let msg = stream.read_password_message().await?;
super::FirstMessage::parse(msg).ok_or(super::Error::BadClientMessage("bad sasl message"))?
};
let mut mechanism = mechanism(sasl.method)?;
let mut input = sasl.message;
loop {
let step = mechanism
.exchange(input)
.inspect_err(|error| tracing::info!(?error, "error during SASL exchange"))?;
match step {
Step::Continue(moved_mechanism, reply) => {
mechanism = moved_mechanism;
// pause the timer while we communicate with the client
let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
// write reply
let sasl_msg = BeAuthenticationSaslMessage::Continue(reply.as_bytes());
stream.write_message(BeMessage::AuthenticationSasl(sasl_msg));
// get next input
stream.flush().await?;
let msg = stream.read_password_message().await?;
input = std::str::from_utf8(msg)
.map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "bad encoding"))?;
}
Step::Success(result, reply) => {
// pause the timer while we communicate with the client
let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
// write reply
let sasl_msg = BeAuthenticationSaslMessage::Final(reply.as_bytes());
stream.write_message(BeMessage::AuthenticationSasl(sasl_msg));
stream.write_message(BeMessage::AuthenticationOk);
// exit with success
break Ok(Outcome::Success(result));
}
// exit with failure
Step::Failure(reason) => break Ok(Outcome::Failure(reason)),
}
}
}

View File

@@ -22,7 +22,7 @@ use super::http_conn_pool::{self, HttpConnPool, Send, poll_http2_client};
use super::local_conn_pool::{self, EXT_NAME, EXT_SCHEMA, EXT_VERSION, LocalConnPool};
use crate::auth::backend::local::StaticAuthRules;
use crate::auth::backend::{ComputeCredentials, ComputeUserInfo};
use crate::auth::{self, AuthError, check_peer_addr_is_in_list};
use crate::auth::{self, AuthError};
use crate::compute;
use crate::compute_ctl::{
ComputeCtlError, ExtensionInstallRequest, Privilege, SetRoleGrantsRequest,
@@ -35,7 +35,6 @@ use crate::control_plane::errors::{GetAuthInfoError, WakeComputeError};
use crate::control_plane::locks::ApiLocks;
use crate::error::{ErrorKind, ReportableError, UserFacingError};
use crate::intern::EndpointIdInt;
use crate::protocol2::ConnectionInfoExtra;
use crate::proxy::connect_compute::ConnectMechanism;
use crate::proxy::retry::{CouldRetry, ShouldRetryWakeCompute};
use crate::rate_limiter::EndpointRateLimiter;
@@ -63,63 +62,24 @@ impl PoolingBackend {
let user_info = user_info.clone();
let backend = self.auth_backend.as_ref().map(|()| user_info.clone());
let allowed_ips = backend.get_allowed_ips(ctx).await?;
let access_control = backend.get_endpoint_access_control(ctx).await?;
access_control.check(
ctx,
self.config.authentication_config.ip_allowlist_check_enabled,
self.config.authentication_config.is_vpc_acccess_proxy,
)?;
if self.config.authentication_config.ip_allowlist_check_enabled
&& !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips)
{
return Err(AuthError::ip_address_not_allowed(ctx.peer_addr()));
}
let access_blocker_flags = backend.get_block_public_or_vpc_access(ctx).await?;
if self.config.authentication_config.is_vpc_acccess_proxy {
if access_blocker_flags.vpc_access_blocked {
return Err(AuthError::NetworkNotAllowed);
}
let extra = ctx.extra();
let incoming_endpoint_id = match extra {
None => String::new(),
Some(ConnectionInfoExtra::Aws { vpce_id }) => vpce_id.to_string(),
Some(ConnectionInfoExtra::Azure { link_id }) => link_id.to_string(),
};
if incoming_endpoint_id.is_empty() {
return Err(AuthError::MissingVPCEndpointId);
}
let allowed_vpc_endpoint_ids = backend.get_allowed_vpc_endpoint_ids(ctx).await?;
// TODO: For now an empty VPC endpoint ID list means all are allowed. We should replace that.
if !allowed_vpc_endpoint_ids.is_empty()
&& !allowed_vpc_endpoint_ids.contains(&incoming_endpoint_id)
{
return Err(AuthError::vpc_endpoint_id_not_allowed(incoming_endpoint_id));
}
} else if access_blocker_flags.public_access_blocked {
return Err(AuthError::NetworkNotAllowed);
}
if !self
.endpoint_rate_limiter
.check(user_info.endpoint.clone().into(), 1)
{
let ep = EndpointIdInt::from(&user_info.endpoint);
let rate_limit_config = None;
if !self.endpoint_rate_limiter.check(ep, rate_limit_config, 1) {
return Err(AuthError::too_many_connections());
}
let cached_secret = backend.get_role_secret(ctx).await?;
let secret = match cached_secret.value.clone() {
Some(secret) => self.config.authentication_config.check_rate_limit(
ctx,
secret,
&user_info.endpoint,
true,
)?,
None => {
// If we don't have an authentication secret, for the http flow we can just return an error.
info!("authentication info not found");
return Err(AuthError::password_failed(&*user_info.user));
}
let role_access = backend.get_role_secret(ctx).await?;
let Some(secret) = role_access.secret else {
// If we don't have an authentication secret, for the http flow we can just return an error.
info!("authentication info not found");
return Err(AuthError::password_failed(&*user_info.user));
};
let ep = EndpointIdInt::from(&user_info.endpoint);
let auth_outcome = crate::auth::validate_password_and_exchange(
&self.config.authentication_config.thread_pool,
ep,

View File

@@ -17,7 +17,6 @@ use postgres_client::error::{DbError, ErrorPosition, SqlState};
use postgres_client::{
GenericClient, IsolationLevel, NoTls, ReadyForQueryStatus, RowStream, Transaction,
};
use pq_proto::StartupMessageParamsBuilder;
use serde::Serialize;
use serde_json::Value;
use serde_json::value::RawValue;
@@ -41,6 +40,7 @@ use crate::context::RequestContext;
use crate::error::{ErrorKind, ReportableError, UserFacingError};
use crate::http::{ReadBodyError, read_body_with_limit};
use crate::metrics::{HttpDirection, Metrics, SniGroup, SniKind};
use crate::pqproto::StartupMessageParams;
use crate::proxy::{NeonOptions, run_until_cancelled};
use crate::serverless::backend::HttpConnError;
use crate::types::{DbName, RoleName};
@@ -219,7 +219,7 @@ fn get_conn_info(
let mut options = Option::None;
let mut params = StartupMessageParamsBuilder::default();
let mut params = StartupMessageParams::default();
params.insert("user", &username);
params.insert("database", &dbname);
for (key, value) in pairs {

View File

@@ -2,19 +2,17 @@ use std::pin::Pin;
use std::sync::Arc;
use std::{io, task};
use bytes::BytesMut;
use pq_proto::framed::{ConnectionError, Framed};
use pq_proto::{BeMessage, FeMessage, FeStartupPacket, ProtocolError};
use rustls::ServerConfig;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf};
use tokio_rustls::server::TlsStream;
use tracing::debug;
use crate::control_plane::messages::ColdStartInfo;
use crate::error::{ErrorKind, ReportableError, UserFacingError};
use crate::metrics::Metrics;
use crate::pqproto::{
BeMessage, FE_PASSWORD_MESSAGE, FeStartupPacket, SQLSTATE_INTERNAL_ERROR, WriteBuf,
read_message, read_startup,
};
use crate::tls::TlsServerEndPoint;
/// Stream wrapper which implements libpq's protocol.
@@ -23,58 +21,77 @@ use crate::tls::TlsServerEndPoint;
/// or [`AsyncWrite`] to prevent subtle errors (e.g. trying
/// to pass random malformed bytes through the connection).
pub struct PqStream<S> {
pub(crate) framed: Framed<S>,
stream: S,
read: Vec<u8>,
write: WriteBuf,
}
impl<S> PqStream<S> {
/// Construct a new libpq protocol wrapper.
pub fn new(stream: S) -> Self {
pub fn get_ref(&self) -> &S {
&self.stream
}
/// Construct a new libpq protocol wrapper over a stream without the first startup message.
#[cfg(test)]
pub fn new_skip_handshake(stream: S) -> Self {
Self {
framed: Framed::new(stream),
stream,
read: Vec::new(),
write: WriteBuf::new(),
}
}
/// Extract the underlying stream and read buffer.
pub fn into_inner(self) -> (S, BytesMut) {
self.framed.into_inner()
}
/// Get a shared reference to the underlying stream.
pub(crate) fn get_ref(&self) -> &S {
self.framed.get_ref()
}
}
fn err_connection() -> io::Error {
io::Error::new(io::ErrorKind::ConnectionAborted, "connection is lost")
impl<S: AsyncRead + AsyncWrite + Unpin> PqStream<S> {
/// Construct a new libpq protocol wrapper and read the first startup message.
///
/// This is not cancel safe.
pub async fn parse_startup(mut stream: S) -> io::Result<(Self, FeStartupPacket)> {
let startup = read_startup(&mut stream).await?;
Ok((
Self {
stream,
read: Vec::new(),
write: WriteBuf::new(),
},
startup,
))
}
/// Tell the client that encryption is not supported.
///
/// This is not cancel safe
pub async fn reject_encryption(&mut self) -> io::Result<FeStartupPacket> {
// N for No.
self.write.encryption(b'N');
self.flush().await?;
read_startup(&mut self.stream).await
}
}
impl<S: AsyncRead + Unpin> PqStream<S> {
/// Receive [`FeStartupPacket`], which is a first packet sent by a client.
pub async fn read_startup_packet(&mut self) -> io::Result<FeStartupPacket> {
self.framed
.read_startup_message()
.await
.map_err(ConnectionError::into_io_error)?
.ok_or_else(err_connection)
}
async fn read_message(&mut self) -> io::Result<FeMessage> {
self.framed
.read_message()
.await
.map_err(ConnectionError::into_io_error)?
.ok_or_else(err_connection)
}
pub(crate) async fn read_password_message(&mut self) -> io::Result<bytes::Bytes> {
match self.read_message().await? {
FeMessage::PasswordMessage(msg) => Ok(msg),
bad => Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("unexpected message type: {bad:?}"),
)),
/// Read a raw postgres packet, which will respect the max length requested.
/// This is not cancel safe.
async fn read_raw_expect(&mut self, tag: u8, max: usize) -> io::Result<&mut [u8]> {
let (actual_tag, msg) = read_message(&mut self.stream, &mut self.read, max).await?;
if actual_tag != tag {
return Err(io::Error::other(format!(
"incorrect message tag, expected {:?}, got {:?}",
tag as char, actual_tag as char,
)));
}
Ok(msg)
}
/// Read a postgres password message, which will respect the max length requested.
/// This is not cancel safe.
pub async fn read_password_message(&mut self) -> io::Result<&mut [u8]> {
// passwords are usually pretty short
// and SASL SCRAM messages are no longer than 256 bytes in my testing
// (a few hashes and random bytes, encoded into base64).
const MAX_PASSWORD_LENGTH: usize = 512;
self.read_raw_expect(FE_PASSWORD_MESSAGE, MAX_PASSWORD_LENGTH)
.await
}
}
@@ -84,6 +101,16 @@ pub struct ReportedError {
error_kind: ErrorKind,
}
impl ReportedError {
pub fn new(e: (impl UserFacingError + Into<anyhow::Error>)) -> Self {
let error_kind = e.get_error_kind();
Self {
source: e.into(),
error_kind,
}
}
}
impl std::fmt::Display for ReportedError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.source.fmt(f)
@@ -102,109 +129,65 @@ impl ReportableError for ReportedError {
}
}
#[derive(Serialize, Deserialize, Debug)]
enum ErrorTag {
#[serde(rename = "proxy")]
Proxy,
#[serde(rename = "compute")]
Compute,
#[serde(rename = "client")]
Client,
#[serde(rename = "controlplane")]
ControlPlane,
#[serde(rename = "other")]
Other,
}
impl From<ErrorKind> for ErrorTag {
fn from(error_kind: ErrorKind) -> Self {
match error_kind {
ErrorKind::User => Self::Client,
ErrorKind::ClientDisconnect => Self::Client,
ErrorKind::RateLimit => Self::Proxy,
ErrorKind::ServiceRateLimit => Self::Proxy, // considering rate limit as proxy error for SLI
ErrorKind::Quota => Self::Proxy,
ErrorKind::Service => Self::Proxy,
ErrorKind::ControlPlane => Self::ControlPlane,
ErrorKind::Postgres => Self::Other,
ErrorKind::Compute => Self::Compute,
}
}
}
#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "snake_case")]
struct ProbeErrorData {
tag: ErrorTag,
msg: String,
cold_start_info: Option<ColdStartInfo>,
}
impl<S: AsyncWrite + Unpin> PqStream<S> {
/// Write the message into an internal buffer, but don't flush the underlying stream.
pub(crate) fn write_message_noflush(
&mut self,
message: &BeMessage<'_>,
) -> io::Result<&mut Self> {
self.framed
.write_message(message)
.map_err(ProtocolError::into_io_error)?;
Ok(self)
/// Tell the client that we are willing to accept SSL.
/// This is not cancel safe
pub async fn accept_tls(mut self) -> io::Result<S> {
// S for SSL.
self.write.encryption(b'S');
self.flush().await?;
Ok(self.stream)
}
/// Write the message into an internal buffer and flush it.
pub async fn write_message(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> {
self.write_message_noflush(message)?;
self.flush().await?;
Ok(self)
/// Assert that we are using direct TLS.
pub fn accept_direct_tls(self) -> S {
self.stream
}
/// Write a raw message to the internal buffer.
pub fn write_raw(&mut self, size_hint: usize, tag: u8, f: impl FnOnce(&mut Vec<u8>)) {
self.write.write_raw(size_hint, tag, f);
}
/// Write the message into an internal buffer
pub fn write_message(&mut self, message: BeMessage<'_>) {
message.write_message(&mut self.write);
}
/// Flush the output buffer into the underlying stream.
pub(crate) async fn flush(&mut self) -> io::Result<&mut Self> {
self.framed.flush().await?;
Ok(self)
///
/// This is cancel safe.
pub async fn flush(&mut self) -> io::Result<()> {
self.stream.write_all_buf(&mut self.write).await?;
self.write.reset();
self.stream.flush().await?;
Ok(())
}
/// Writes message with the given error kind to the stream.
/// Used only for probe queries
async fn write_format_message(
&mut self,
msg: &str,
error_kind: ErrorKind,
ctx: Option<&crate::context::RequestContext>,
) -> String {
let formatted_msg = match ctx {
Some(ctx) if ctx.get_testodrome_id().is_some() => {
serde_json::to_string(&ProbeErrorData {
tag: ErrorTag::from(error_kind),
msg: msg.to_string(),
cold_start_info: Some(ctx.cold_start_info()),
})
.unwrap_or_default()
}
_ => msg.to_string(),
};
// already error case, ignore client IO error
self.write_message(&BeMessage::ErrorResponse(&formatted_msg, None))
.await
.inspect_err(|e| debug!("write_message failed: {e}"))
.ok();
formatted_msg
/// Flush the output buffer into the underlying stream.
///
/// This is cancel safe.
pub async fn flush_and_into_inner(mut self) -> io::Result<S> {
self.flush().await?;
Ok(self.stream)
}
/// Write the error message using [`Self::write_format_message`], then re-throw it.
/// Allowing string literals is safe under the assumption they might not contain any runtime info.
/// This method exists due to `&str` not implementing `Into<anyhow::Error>`.
/// Write the error message to the client, then re-throw it.
///
/// Trait [`UserFacingError`] acts as an allowlist for error types.
/// If `ctx` is provided and has testodrome_id set, error messages will be prefixed according to error kind.
pub async fn throw_error_str<T>(
pub(crate) async fn throw_error<E>(
&mut self,
msg: &'static str,
error_kind: ErrorKind,
error: E,
ctx: Option<&crate::context::RequestContext>,
) -> Result<T, ReportedError> {
self.write_format_message(msg, error_kind, ctx).await;
) -> ReportedError
where
E: UserFacingError + Into<anyhow::Error>,
{
let error_kind = error.get_error_kind();
let msg = error.to_string_client();
if error_kind != ErrorKind::RateLimit && error_kind != ErrorKind::User {
tracing::info!(
@@ -214,39 +197,39 @@ impl<S: AsyncWrite + Unpin> PqStream<S> {
);
}
Err(ReportedError {
source: anyhow::anyhow!(msg),
error_kind,
})
}
/// Write the error message using [`Self::write_format_message`], then re-throw it.
/// Trait [`UserFacingError`] acts as an allowlist for error types.
/// If `ctx` is provided and has testodrome_id set, error messages will be prefixed according to error kind.
pub(crate) async fn throw_error<T, E>(
&mut self,
error: E,
ctx: Option<&crate::context::RequestContext>,
) -> Result<T, ReportedError>
where
E: UserFacingError + Into<anyhow::Error>,
{
let error_kind = error.get_error_kind();
let msg = error.to_string_client();
self.write_format_message(&msg, error_kind, ctx).await;
if error_kind != ErrorKind::RateLimit && error_kind != ErrorKind::User {
tracing::info!(
kind=error_kind.to_metric_label(),
error=%error,
msg,
"forwarding error to user",
);
let probe_msg;
let mut msg = &*msg;
if let Some(ctx) = ctx {
if ctx.get_testodrome_id().is_some() {
let tag = match error_kind {
ErrorKind::User => "client",
ErrorKind::ClientDisconnect => "client",
ErrorKind::RateLimit => "proxy",
ErrorKind::ServiceRateLimit => "proxy",
ErrorKind::Quota => "proxy",
ErrorKind::Service => "proxy",
ErrorKind::ControlPlane => "controlplane",
ErrorKind::Postgres => "other",
ErrorKind::Compute => "compute",
};
probe_msg = typed_json::json!({
"tag": tag,
"msg": msg,
"cold_start_info": ctx.cold_start_info(),
})
.to_string();
msg = &probe_msg;
}
}
Err(ReportedError {
source: anyhow::anyhow!(error),
error_kind,
})
// TODO: either preserve the error code from postgres, or assign error codes to proxy errors.
self.write.write_error(msg, SQLSTATE_INTERNAL_ERROR);
self.flush()
.await
.unwrap_or_else(|e| tracing::debug!("write_message failed: {e}"));
ReportedError::new(error)
}
}

View File

@@ -124,6 +124,9 @@ def test_location_conf_churn(neon_env_builder: NeonEnvBuilder, make_httpserver,
".*downloading failed, possibly for shutdown",
# {tenant_id=... timeline_id=...}:handle_pagerequests:handle_get_page_at_lsn_request{rel=1664/0/1260 blkno=0 req_lsn=0/149F0D8}: error reading relation or page version: Not found: will not become active. Current state: Stopping\n'
".*page_service.*will not become active.*",
# The following errors are possible when the pageserver tries to ingest WAL records despite being in an unreadable state.
".*wal_connection_manager.*layer file download failed: No file found.*",
".*wal_connection_manager.*could not ingest record.*",
]
)