diff --git a/build-tools.Dockerfile b/build-tools.Dockerfile index 9d4c93e1cd..f97f04968e 100644 --- a/build-tools.Dockerfile +++ b/build-tools.Dockerfile @@ -310,13 +310,13 @@ RUN curl -sSO https://static.rust-lang.org/rustup/dist/$(uname -m)-unknown-linux . "$HOME/.cargo/env" && \ cargo --version && rustup --version && \ rustup component add llvm-tools rustfmt clippy && \ - cargo install rustfilt --version ${RUSTFILT_VERSION} && \ - cargo install cargo-hakari --version ${CARGO_HAKARI_VERSION} && \ - cargo install cargo-deny --locked --version ${CARGO_DENY_VERSION} && \ - cargo install cargo-hack --version ${CARGO_HACK_VERSION} && \ - cargo install cargo-nextest --version ${CARGO_NEXTEST_VERSION} && \ - cargo install cargo-chef --locked --version ${CARGO_CHEF_VERSION} && \ - cargo install diesel_cli --version ${CARGO_DIESEL_CLI_VERSION} \ + cargo install rustfilt --version ${RUSTFILT_VERSION} --locked && \ + cargo install cargo-hakari --version ${CARGO_HAKARI_VERSION} --locked && \ + cargo install cargo-deny --version ${CARGO_DENY_VERSION} --locked && \ + cargo install cargo-hack --version ${CARGO_HACK_VERSION} --locked && \ + cargo install cargo-nextest --version ${CARGO_NEXTEST_VERSION} --locked && \ + cargo install cargo-chef --version ${CARGO_CHEF_VERSION} --locked && \ + cargo install diesel_cli --version ${CARGO_DIESEL_CLI_VERSION} --locked \ --features postgres-bundled --no-default-features && \ rm -rf /home/nonroot/.cargo/registry && \ rm -rf /home/nonroot/.cargo/git diff --git a/compute_tools/src/bin/compute_ctl.rs b/compute_tools/src/bin/compute_ctl.rs index 02339f752c..db6835da61 100644 --- a/compute_tools/src/bin/compute_ctl.rs +++ b/compute_tools/src/bin/compute_ctl.rs @@ -57,21 +57,6 @@ use tracing::{error, info}; use url::Url; use utils::failpoint_support; -// Compatibility hack: if the control plane specified any remote-ext-config -// use the default value for extension storage proxy gateway. -// Remove this once the control plane is updated to pass the gateway URL -fn parse_remote_ext_base_url(arg: &str) -> Result<String> { - const FALLBACK_PG_EXT_GATEWAY_BASE_URL: &str = - "http://pg-ext-s3-gateway.pg-ext-s3-gateway.svc.cluster.local"; - - Ok(if arg.starts_with("http") { - arg - } else { - FALLBACK_PG_EXT_GATEWAY_BASE_URL - } - .to_owned()) -} - #[derive(Parser)] #[command(rename_all = "kebab-case")] struct Cli { @@ -79,9 +64,8 @@ struct Cli { pub pgbin: String, /// The base URL for the remote extension storage proxy gateway. - /// Should be in the form of `http(s)://[:]`. - #[arg(short = 'r', long, value_parser = parse_remote_ext_base_url, alias = "remote-ext-config")] - pub remote_ext_base_url: Option<String>, + #[arg(short = 'r', long)] + pub remote_ext_base_url: Option<Url>, /// The port to bind the external listening HTTP server to. Clients running /// outside the compute will talk to the compute through this port.
@@ -276,18 +260,4 @@ mod test { fn verify_cli() { Cli::command().debug_assert() } - - #[test] - fn parse_pg_ext_gateway_base_url() { - let arg = "http://pg-ext-s3-gateway2"; - let result = super::parse_remote_ext_base_url(arg).unwrap(); - assert_eq!(result, arg); - - let arg = "pg-ext-s3-gateway"; - let result = super::parse_remote_ext_base_url(arg).unwrap(); - assert_eq!( - result, - "http://pg-ext-s3-gateway.pg-ext-s3-gateway.svc.cluster.local" - ); - } } diff --git a/compute_tools/src/compute.rs b/compute_tools/src/compute.rs index ff49c737f0..d678b7d670 100644 --- a/compute_tools/src/compute.rs +++ b/compute_tools/src/compute.rs @@ -31,6 +31,7 @@ use std::time::{Duration, Instant}; use std::{env, fs}; use tokio::spawn; use tracing::{Instrument, debug, error, info, instrument, warn}; +use url::Url; use utils::id::{TenantId, TimelineId}; use utils::lsn::Lsn; use utils::measured_stream::MeasuredReader; @@ -96,7 +97,7 @@ pub struct ComputeNodeParams { pub internal_http_port: u16, /// the address of extension storage proxy gateway - pub remote_ext_base_url: Option<String>, + pub remote_ext_base_url: Option<Url>, /// Interval for installed extensions collection pub installed_extensions_collection_interval: u64, diff --git a/compute_tools/src/extension_server.rs b/compute_tools/src/extension_server.rs index 3439383699..1857afa08c 100644 --- a/compute_tools/src/extension_server.rs +++ b/compute_tools/src/extension_server.rs @@ -83,6 +83,7 @@ use reqwest::StatusCode; use tar::Archive; use tracing::info; use tracing::log::warn; +use url::Url; use zstd::stream::read::Decoder; use crate::metrics::{REMOTE_EXT_REQUESTS_TOTAL, UNKNOWN_HTTP_STATUS}; @@ -158,14 +159,14 @@ fn parse_pg_version(human_version: &str) -> PostgresMajorVersion { pub async fn download_extension( ext_name: &str, ext_path: &RemotePath, - remote_ext_base_url: &str, + remote_ext_base_url: &Url, pgbin: &str, ) -> Result { info!("Download extension {:?} from {:?}", ext_name, ext_path); // TODO add retry logic let download_buffer = - match download_extension_tar(remote_ext_base_url, &ext_path.to_string()).await { + match download_extension_tar(remote_ext_base_url.as_str(), &ext_path.to_string()).await { Ok(buffer) => buffer, Err(error_message) => { return Err(anyhow::anyhow!( diff --git a/libs/metrics/src/hll.rs b/libs/metrics/src/hll.rs index 93f6a2b7cc..1a7d7a7e44 100644 --- a/libs/metrics/src/hll.rs +++ b/libs/metrics/src/hll.rs @@ -107,7 +107,7 @@ impl MetricType for HyperLogLogState { } impl HyperLogLogState { - pub fn measure(&self, item: &impl Hash) { + pub fn measure(&self, item: &(impl Hash + ?Sized)) { // changing the hasher will break compatibility with previous measurements.
self.record(BuildHasherDefault::::default().hash_one(item)); } diff --git a/libs/pageserver_api/src/config.rs b/libs/pageserver_api/src/config.rs index 012c020fb1..444983bd18 100644 --- a/libs/pageserver_api/src/config.rs +++ b/libs/pageserver_api/src/config.rs @@ -713,9 +713,9 @@ impl Default for ConfigToml { enable_tls_page_service_api: false, dev_mode: false, timeline_import_config: TimelineImportConfig { - import_job_concurrency: NonZeroUsize::new(128).unwrap(), - import_job_soft_size_limit: NonZeroUsize::new(1024 * 1024 * 1024).unwrap(), - import_job_checkpoint_threshold: NonZeroUsize::new(128).unwrap(), + import_job_concurrency: NonZeroUsize::new(32).unwrap(), + import_job_soft_size_limit: NonZeroUsize::new(256 * 1024 * 1024).unwrap(), + import_job_checkpoint_threshold: NonZeroUsize::new(32).unwrap(), }, basebackup_cache_config: None, posthog_config: None, diff --git a/libs/utils/src/leaky_bucket.rs b/libs/utils/src/leaky_bucket.rs index 2398f92766..17e96bd0a9 100644 --- a/libs/utils/src/leaky_bucket.rs +++ b/libs/utils/src/leaky_bucket.rs @@ -28,6 +28,7 @@ use std::time::Duration; use tokio::sync::Notify; use tokio::time::Instant; +#[derive(Clone, Copy)] pub struct LeakyBucketConfig { /// This is the "time cost" of a single request unit. /// Should loosely represent how long it takes to handle a request unit in active resource time. diff --git a/pageserver/src/tenant/timeline/import_pgdata.rs b/pageserver/src/tenant/timeline/import_pgdata.rs index f19a4b3e9c..3f760d858b 100644 --- a/pageserver/src/tenant/timeline/import_pgdata.rs +++ b/pageserver/src/tenant/timeline/import_pgdata.rs @@ -106,6 +106,8 @@ pub async fn doit( ); } + tracing::info!("Import plan executed. Flushing remote changes and notifying storcon"); + timeline .remote_client .schedule_index_upload_for_file_changes()?; diff --git a/pageserver/src/tenant/timeline/import_pgdata/flow.rs b/pageserver/src/tenant/timeline/import_pgdata/flow.rs index bf3c7eeda6..760e82dd57 100644 --- a/pageserver/src/tenant/timeline/import_pgdata/flow.rs +++ b/pageserver/src/tenant/timeline/import_pgdata/flow.rs @@ -130,7 +130,15 @@ async fn run_v1( pausable_failpoint!("import-timeline-pre-execute-pausable"); + let jobs_count = import_progress.as_ref().map(|p| p.jobs); let start_from_job_idx = import_progress.map(|progress| progress.completed); + + tracing::info!( + start_from_job_idx=?start_from_job_idx, + jobs=?jobs_count, + "Executing import plan" + ); + plan.execute(timeline, start_from_job_idx, plan_hash, &import_config, ctx) .await } @@ -484,6 +492,8 @@ impl Plan { anyhow::anyhow!("Shut down while putting timeline import status") })?; } + + tracing::info!(last_completed_job_idx, jobs=%jobs_in_plan, "Checkpointing import status"); }, Some(Err(_)) => { anyhow::bail!( @@ -760,7 +770,7 @@ impl ImportTask for ImportRelBlocksTask { layer_writer: &mut ImageLayerWriter, ctx: &RequestContext, ) -> anyhow::Result { - const MAX_BYTE_RANGE_SIZE: usize = 128 * 1024 * 1024; + const MAX_BYTE_RANGE_SIZE: usize = 4 * 1024 * 1024; debug!("Importing relation file"); diff --git a/pageserver/src/tenant/timeline/walreceiver.rs b/pageserver/src/tenant/timeline/walreceiver.rs index 0f73eb839b..633c94a010 100644 --- a/pageserver/src/tenant/timeline/walreceiver.rs +++ b/pageserver/src/tenant/timeline/walreceiver.rs @@ -113,7 +113,7 @@ impl WalReceiver { } connection_manager_state.shutdown().await; *loop_status.write().unwrap() = None; - debug!("task exits"); + info!("task exits"); } .instrument(info_span!(parent: None, "wal_connection_manager", tenant_id = 
%tenant_shard_id.tenant_id, shard_id = %tenant_shard_id.shard_slug(), timeline_id = %timeline_id)) }); diff --git a/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs b/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs index 52259f205b..249849ac4b 100644 --- a/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs +++ b/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs @@ -297,6 +297,7 @@ pub(super) async fn handle_walreceiver_connection( let mut expected_wal_start = startpoint; while let Some(replication_message) = { select! { + biased; _ = cancellation.cancelled() => { debug!("walreceiver interrupted"); None diff --git a/proxy/src/auth/backend/classic.rs b/proxy/src/auth/backend/classic.rs index 5e494dfdd6..dcc500f2c8 100644 --- a/proxy/src/auth/backend/classic.rs +++ b/proxy/src/auth/backend/classic.rs @@ -17,35 +17,27 @@ pub(super) async fn authenticate( config: &'static AuthenticationConfig, secret: AuthSecret, ) -> auth::Result { - let flow = AuthFlow::new(client); let scram_keys = match secret { #[cfg(any(test, feature = "testing"))] AuthSecret::Md5(_) => { debug!("auth endpoint chooses MD5"); - return Err(auth::AuthError::bad_auth_method("MD5")); + return Err(auth::AuthError::MalformedPassword("MD5 not supported")); } AuthSecret::Scram(secret) => { debug!("auth endpoint chooses SCRAM"); let scram = auth::Scram(&secret, ctx); - let auth_outcome = tokio::time::timeout( - config.scram_protocol_timeout, - async { - - flow.begin(scram).await.map_err(|error| { - warn!(?error, "error sending scram acknowledgement"); - error - })?.authenticate().await.map_err(|error| { + let auth_outcome = tokio::time::timeout(config.scram_protocol_timeout, async { + AuthFlow::new(client, scram) + .authenticate() + .await + .inspect_err(|error| { warn!(?error, "error processing scram messages"); - error }) - } - ) + }) .await - .map_err(|e| { - warn!("error processing scram messages error = authentication timed out, execution time exceeded {} seconds", config.scram_protocol_timeout.as_secs()); - auth::AuthError::user_timeout(e) - })??; + .inspect_err(|_| warn!("error processing scram messages error = authentication timed out, execution time exceeded {} seconds", config.scram_protocol_timeout.as_secs())) + .map_err(auth::AuthError::user_timeout)??; let client_key = match auth_outcome { sasl::Outcome::Success(key) => key, diff --git a/proxy/src/auth/backend/console_redirect.rs b/proxy/src/auth/backend/console_redirect.rs index dd48384c03..a50c30257f 100644 --- a/proxy/src/auth/backend/console_redirect.rs +++ b/proxy/src/auth/backend/console_redirect.rs @@ -2,7 +2,6 @@ use std::fmt; use async_trait::async_trait; use postgres_client::config::SslMode; -use pq_proto::BeMessage as Be; use thiserror::Error; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::{info, info_span}; @@ -16,6 +15,7 @@ use crate::context::RequestContext; use crate::control_plane::client::cplane_proxy_v1; use crate::control_plane::{self, CachedNodeInfo, NodeInfo}; use crate::error::{ReportableError, UserFacingError}; +use crate::pqproto::BeMessage; use crate::proxy::NeonOptions; use crate::proxy::connect_compute::ComputeConnectBackend; use crate::stream::PqStream; @@ -154,11 +154,13 @@ async fn authenticate( // Give user a URL to spawn a new database. info!(parent: &span, "sending the auth URL to the user"); - client - .write_message_noflush(&Be::AuthenticationOk)? - .write_message_noflush(&Be::CLIENT_ENCODING)? 
- .write_message(&Be::NoticeResponse(&greeting)) - .await?; + client.write_message(BeMessage::AuthenticationOk); + client.write_message(BeMessage::ParameterStatus { + name: b"client_encoding", + value: b"UTF8", + }); + client.write_message(BeMessage::NoticeResponse(&greeting)); + client.flush().await?; // Wait for console response via control plane (see `mgmt`). info!(parent: &span, "waiting for console's reply..."); @@ -188,7 +190,7 @@ async fn authenticate( } } - client.write_message_noflush(&Be::NoticeResponse("Connecting to database."))?; + client.write_message(BeMessage::NoticeResponse("Connecting to database.")); // This config should be self-contained, because we won't // take username or dbname from client's startup message. diff --git a/proxy/src/auth/backend/hacks.rs b/proxy/src/auth/backend/hacks.rs index 3316543022..1e5c076fb9 100644 --- a/proxy/src/auth/backend/hacks.rs +++ b/proxy/src/auth/backend/hacks.rs @@ -24,23 +24,25 @@ pub(crate) async fn authenticate_cleartext( debug!("cleartext auth flow override is enabled, proceeding"); ctx.set_auth_method(crate::context::AuthMethod::Cleartext); - // pause the timer while we communicate with the client - let paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client); - let ep = EndpointIdInt::from(&info.endpoint); - let auth_flow = AuthFlow::new(client) - .begin(auth::CleartextPassword { + let auth_flow = AuthFlow::new( + client, + auth::CleartextPassword { secret, endpoint: ep, pool: config.thread_pool.clone(), - }) - .await?; - drop(paused); - // cleartext auth is only allowed to the ws/http protocol. - // If we're here, we already received the password in the first message. - // Scram protocol will be executed on the proxy side. - let auth_outcome = auth_flow.authenticate().await?; + }, + ); + let auth_outcome = { + // pause the timer while we communicate with the client + let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client); + + // cleartext auth is only allowed to the ws/http protocol. + // If we're here, we already received the password in the first message. + // Scram protocol will be executed on the proxy side. + auth_flow.authenticate().await? + }; let keys = match auth_outcome { sasl::Outcome::Success(key) => key, @@ -67,9 +69,7 @@ pub(crate) async fn password_hack_no_authentication( // pause the timer while we communicate with the client let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client); - let payload = AuthFlow::new(client) - .begin(auth::PasswordHack) - .await? 
+ let payload = AuthFlow::new(client, auth::PasswordHack) .get_password() .await?; diff --git a/proxy/src/auth/backend/mod.rs b/proxy/src/auth/backend/mod.rs index 6e5c0a3954..735cb52f47 100644 --- a/proxy/src/auth/backend/mod.rs +++ b/proxy/src/auth/backend/mod.rs @@ -4,37 +4,31 @@ mod hacks; pub mod jwt; pub mod local; -use std::net::IpAddr; use std::sync::Arc; pub use console_redirect::ConsoleRedirectBackend; pub(crate) use console_redirect::ConsoleRedirectError; -use ipnet::{Ipv4Net, Ipv6Net}; use local::LocalBackend; use postgres_client::config::AuthKeys; use serde::{Deserialize, Serialize}; use tokio::io::{AsyncRead, AsyncWrite}; -use tracing::{debug, info, warn}; +use tracing::{debug, info}; -use crate::auth::credentials::check_peer_addr_is_in_list; -use crate::auth::{ - self, AuthError, ComputeUserInfoMaybeEndpoint, IpPattern, validate_password_and_exchange, -}; +use crate::auth::{self, AuthError, ComputeUserInfoMaybeEndpoint, validate_password_and_exchange}; use crate::cache::Cached; use crate::config::AuthenticationConfig; use crate::context::RequestContext; use crate::control_plane::client::ControlPlaneClient; use crate::control_plane::errors::GetAuthInfoError; use crate::control_plane::{ - self, AccessBlockerFlags, AuthSecret, CachedAccessBlockerFlags, CachedAllowedIps, - CachedAllowedVpcEndpointIds, CachedNodeInfo, CachedRoleSecret, ControlPlaneApi, + self, AccessBlockerFlags, AuthSecret, CachedNodeInfo, ControlPlaneApi, EndpointAccessControl, + RoleAccessControl, }; use crate::intern::EndpointIdInt; -use crate::metrics::Metrics; -use crate::protocol2::ConnectionInfoExtra; +use crate::pqproto::BeMessage; use crate::proxy::NeonOptions; use crate::proxy::connect_compute::ComputeConnectBackend; -use crate::rate_limiter::{BucketRateLimiter, EndpointRateLimiter}; +use crate::rate_limiter::EndpointRateLimiter; use crate::stream::Stream; use crate::types::{EndpointCacheKey, EndpointId, RoleName}; use crate::{scram, stream}; @@ -200,78 +194,6 @@ impl TryFrom for ComputeUserInfo { } } -#[derive(PartialEq, PartialOrd, Hash, Eq, Ord, Debug, Copy, Clone)] -pub struct MaskedIp(IpAddr); - -impl MaskedIp { - fn new(value: IpAddr, prefix: u8) -> Self { - match value { - IpAddr::V4(v4) => Self(IpAddr::V4( - Ipv4Net::new(v4, prefix).map_or(v4, |x| x.trunc().addr()), - )), - IpAddr::V6(v6) => Self(IpAddr::V6( - Ipv6Net::new(v6, prefix).map_or(v6, |x| x.trunc().addr()), - )), - } - } -} - -// This can't be just per IP because that would limit some PaaS that share IP addresses -pub type AuthRateLimiter = BucketRateLimiter<(EndpointIdInt, MaskedIp)>; - -impl AuthenticationConfig { - pub(crate) fn check_rate_limit( - &self, - ctx: &RequestContext, - secret: AuthSecret, - endpoint: &EndpointId, - is_cleartext: bool, - ) -> auth::Result { - // we have validated the endpoint exists, so let's intern it. - let endpoint_int = EndpointIdInt::from(endpoint.normalize()); - - // only count the full hash count if password hack or websocket flow. - // in other words, if proxy needs to run the hashing - let password_weight = if is_cleartext { - match &secret { - #[cfg(any(test, feature = "testing"))] - AuthSecret::Md5(_) => 1, - AuthSecret::Scram(s) => s.iterations + 1, - } - } else { - // validating scram takes just 1 hmac_sha_256 operation. 
- 1 - }; - - let limit_not_exceeded = self.rate_limiter.check( - ( - endpoint_int, - MaskedIp::new(ctx.peer_addr(), self.rate_limit_ip_subnet), - ), - password_weight, - ); - - if !limit_not_exceeded { - warn!( - enabled = self.rate_limiter_enabled, - "rate limiting authentication" - ); - Metrics::get().proxy.requests_auth_rate_limits_total.inc(); - Metrics::get() - .proxy - .endpoints_auth_rate_limits - .get_metric() - .measure(endpoint); - - if self.rate_limiter_enabled { - return Err(auth::AuthError::too_many_connections()); - } - } - - Ok(secret) - } -} - /// True to its name, this function encapsulates our current auth trade-offs. /// Here, we choose the appropriate auth flow based on circumstances. /// @@ -284,7 +206,7 @@ async fn auth_quirks( allow_cleartext: bool, config: &'static AuthenticationConfig, endpoint_rate_limiter: Arc, -) -> auth::Result<(ComputeCredentials, Option>)> { +) -> auth::Result { // If there's no project so far, that entails that client doesn't // support SNI or other means of passing the endpoint (project) name. // We now expect to see a very specific payload in the place of password. @@ -300,55 +222,27 @@ async fn auth_quirks( debug!("fetching authentication info and allowlists"); - // check allowed list - let allowed_ips = if config.ip_allowlist_check_enabled { - let allowed_ips = api.get_allowed_ips(ctx, &info).await?; - if !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips) { - return Err(auth::AuthError::ip_address_not_allowed(ctx.peer_addr())); - } - allowed_ips - } else { - Cached::new_uncached(Arc::new(vec![])) - }; + let access_controls = api + .get_endpoint_access_control(ctx, &info.endpoint, &info.user) + .await?; - // check if a VPC endpoint ID is coming in and if yes, if it's allowed - let access_blocks = api.get_block_public_or_vpc_access(ctx, &info).await?; - if config.is_vpc_acccess_proxy { - if access_blocks.vpc_access_blocked { - return Err(AuthError::NetworkNotAllowed); - } + access_controls.check( + ctx, + config.ip_allowlist_check_enabled, + config.is_vpc_acccess_proxy, + )?; - let incoming_vpc_endpoint_id = match ctx.extra() { - None => return Err(AuthError::MissingEndpointName), - Some(ConnectionInfoExtra::Aws { vpce_id }) => vpce_id.to_string(), - Some(ConnectionInfoExtra::Azure { link_id }) => link_id.to_string(), - }; - let allowed_vpc_endpoint_ids = api.get_allowed_vpc_endpoint_ids(ctx, &info).await?; - // TODO: For now an empty VPC endpoint ID list means all are allowed. We should replace that. - if !allowed_vpc_endpoint_ids.is_empty() - && !allowed_vpc_endpoint_ids.contains(&incoming_vpc_endpoint_id) - { - return Err(AuthError::vpc_endpoint_id_not_allowed( - incoming_vpc_endpoint_id, - )); - } - } else if access_blocks.public_access_blocked { - return Err(AuthError::NetworkNotAllowed); - } - - if !endpoint_rate_limiter.check(info.endpoint.clone().into(), 1) { + let endpoint = EndpointIdInt::from(&info.endpoint); + let rate_limit_config = None; + if !endpoint_rate_limiter.check(endpoint, rate_limit_config, 1) { return Err(AuthError::too_many_connections()); } - let cached_secret = api.get_role_secret(ctx, &info).await?; - let (cached_entry, secret) = cached_secret.take_value(); + let role_access = api + .get_role_access_control(ctx, &info.endpoint, &info.user) + .await?; - let secret = if let Some(secret) = secret { - config.check_rate_limit( - ctx, - secret, - &info.endpoint, - unauthenticated_password.is_some() || allow_cleartext, - )? 
+ let secret = if let Some(secret) = role_access.secret { + secret } else { // If we don't have an authentication secret, we mock one to // prevent malicious probing (possible due to missing protocol steps). @@ -368,14 +262,8 @@ async fn auth_quirks( ) .await { - Ok(keys) => Ok((keys, Some(allowed_ips.as_ref().clone()))), - Err(e) => { - if e.is_password_failed() { - // The password could have been changed, so we invalidate the cache. - cached_entry.invalidate(); - } - Err(e) - } + Ok(keys) => Ok(keys), + Err(e) => Err(e), } } @@ -402,7 +290,7 @@ async fn authenticate_with_secret( }; // we have authenticated the password - client.write_message_noflush(&pq_proto::BeMessage::AuthenticationOk)?; + client.write_message(BeMessage::AuthenticationOk); return Ok(ComputeCredentials { info, keys }); } @@ -438,7 +326,7 @@ impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint> { allow_cleartext: bool, config: &'static AuthenticationConfig, endpoint_rate_limiter: Arc, - ) -> auth::Result<(Backend<'a, ComputeCredentials>, Option>)> { + ) -> auth::Result> { let res = match self { Self::ControlPlane(api, user_info) => { debug!( @@ -447,17 +335,35 @@ impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint> { "performing authentication using the console" ); - let (credentials, ip_allowlist) = auth_quirks( + let auth_res = auth_quirks( ctx, &*api, - user_info, + user_info.clone(), client, allow_cleartext, config, endpoint_rate_limiter, ) - .await?; - Ok((Backend::ControlPlane(api, credentials), ip_allowlist)) + .await; + match auth_res { + Ok(credentials) => Ok(Backend::ControlPlane(api, credentials)), + Err(e) => { + // The password could have been changed, so we invalidate the cache. + // We should only invalidate the cache if the TTL might have expired. + if e.is_password_failed() { + #[allow(irrefutable_let_patterns)] + if let ControlPlaneClient::ProxyV1(api) = &*api { + if let Some(ep) = &user_info.endpoint_id { + api.caches + .project_info + .maybe_invalidate_role_secret(ep, &user_info.user); + } + } + } + + Err(e) + } + } } Self::Local(_) => { return Err(auth::AuthError::bad_auth_method("invalid for local proxy")); @@ -474,44 +380,30 @@ impl Backend<'_, ComputeUserInfo> { pub(crate) async fn get_role_secret( &self, ctx: &RequestContext, - ) -> Result { - match self { - Self::ControlPlane(api, user_info) => api.get_role_secret(ctx, user_info).await, - Self::Local(_) => Ok(Cached::new_uncached(None)), - } - } - - pub(crate) async fn get_allowed_ips( - &self, - ctx: &RequestContext, - ) -> Result { - match self { - Self::ControlPlane(api, user_info) => api.get_allowed_ips(ctx, user_info).await, - Self::Local(_) => Ok(Cached::new_uncached(Arc::new(vec![]))), - } - } - - pub(crate) async fn get_allowed_vpc_endpoint_ids( - &self, - ctx: &RequestContext, - ) -> Result { + ) -> Result { match self { Self::ControlPlane(api, user_info) => { - api.get_allowed_vpc_endpoint_ids(ctx, user_info).await + api.get_role_access_control(ctx, &user_info.endpoint, &user_info.user) + .await } - Self::Local(_) => Ok(Cached::new_uncached(Arc::new(vec![]))), + Self::Local(_) => Ok(RoleAccessControl { secret: None }), } } - pub(crate) async fn get_block_public_or_vpc_access( + pub(crate) async fn get_endpoint_access_control( &self, ctx: &RequestContext, - ) -> Result { + ) -> Result { match self { Self::ControlPlane(api, user_info) => { - api.get_block_public_or_vpc_access(ctx, user_info).await + api.get_endpoint_access_control(ctx, &user_info.endpoint, &user_info.user) + .await } - Self::Local(_) => 
Ok(Cached::new_uncached(AccessBlockerFlags::default())), + Self::Local(_) => Ok(EndpointAccessControl { + allowed_ips: Arc::new(vec![]), + allowed_vpce: Arc::new(vec![]), + flags: AccessBlockerFlags::default(), + }), } } } @@ -540,9 +432,7 @@ impl ComputeConnectBackend for Backend<'_, ComputeCredentials> { mod tests { #![allow(clippy::unimplemented, clippy::unwrap_used)] - use std::net::IpAddr; use std::sync::Arc; - use std::time::Duration; use bytes::BytesMut; use control_plane::AuthSecret; @@ -553,18 +443,16 @@ mod tests { use postgres_protocol::message::frontend; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt}; + use super::auth_quirks; use super::jwt::JwkCache; - use super::{AuthRateLimiter, auth_quirks}; - use crate::auth::backend::MaskedIp; use crate::auth::{ComputeUserInfoMaybeEndpoint, IpPattern}; use crate::config::AuthenticationConfig; use crate::context::RequestContext; use crate::control_plane::{ - self, AccessBlockerFlags, CachedAccessBlockerFlags, CachedAllowedIps, - CachedAllowedVpcEndpointIds, CachedNodeInfo, CachedRoleSecret, + self, AccessBlockerFlags, CachedNodeInfo, EndpointAccessControl, RoleAccessControl, }; use crate::proxy::NeonOptions; - use crate::rate_limiter::{EndpointRateLimiter, RateBucketInfo}; + use crate::rate_limiter::EndpointRateLimiter; use crate::scram::ServerSecret; use crate::scram::threadpool::ThreadPool; use crate::stream::{PqStream, Stream}; @@ -577,46 +465,34 @@ mod tests { } impl control_plane::ControlPlaneApi for Auth { - async fn get_role_secret( + async fn get_role_access_control( &self, _ctx: &RequestContext, - _user_info: &super::ComputeUserInfo, - ) -> Result { - Ok(CachedRoleSecret::new_uncached(Some(self.secret.clone()))) + _endpoint: &crate::types::EndpointId, + _role: &crate::types::RoleName, + ) -> Result { + Ok(RoleAccessControl { + secret: Some(self.secret.clone()), + }) } - async fn get_allowed_ips( + async fn get_endpoint_access_control( &self, _ctx: &RequestContext, - _user_info: &super::ComputeUserInfo, - ) -> Result { - Ok(CachedAllowedIps::new_uncached(Arc::new(self.ips.clone()))) - } - - async fn get_allowed_vpc_endpoint_ids( - &self, - _ctx: &RequestContext, - _user_info: &super::ComputeUserInfo, - ) -> Result { - Ok(CachedAllowedVpcEndpointIds::new_uncached(Arc::new( - self.vpc_endpoint_ids.clone(), - ))) - } - - async fn get_block_public_or_vpc_access( - &self, - _ctx: &RequestContext, - _user_info: &super::ComputeUserInfo, - ) -> Result { - Ok(CachedAccessBlockerFlags::new_uncached( - self.access_blocker_flags.clone(), - )) + _endpoint: &crate::types::EndpointId, + _role: &crate::types::RoleName, + ) -> Result { + Ok(EndpointAccessControl { + allowed_ips: Arc::new(self.ips.clone()), + allowed_vpce: Arc::new(self.vpc_endpoint_ids.clone()), + flags: self.access_blocker_flags, + }) } async fn get_endpoint_jwks( &self, _ctx: &RequestContext, - _endpoint: crate::types::EndpointId, + _endpoint: &crate::types::EndpointId, ) -> Result, control_plane::errors::GetEndpointJwksError> { unimplemented!() @@ -635,9 +511,6 @@ mod tests { jwks_cache: JwkCache::default(), thread_pool: ThreadPool::new(1), scram_protocol_timeout: std::time::Duration::from_secs(5), - rate_limiter_enabled: true, - rate_limiter: AuthRateLimiter::new(&RateBucketInfo::DEFAULT_AUTH_SET), - rate_limit_ip_subnet: 64, ip_allowlist_check_enabled: true, is_vpc_acccess_proxy: false, is_auth_broker: false, @@ -654,55 +527,10 @@ mod tests { } } - #[test] - fn masked_ip() { - let ip_a = IpAddr::V4([127, 0, 0, 1].into()); - let ip_b = IpAddr::V4([127, 0, 0, 
2].into()); - let ip_c = IpAddr::V4([192, 168, 1, 101].into()); - let ip_d = IpAddr::V4([192, 168, 1, 102].into()); - let ip_e = IpAddr::V6("abcd:abcd:abcd:abcd:abcd:abcd:abcd:abcd".parse().unwrap()); - let ip_f = IpAddr::V6("abcd:abcd:abcd:abcd:1234:abcd:abcd:abcd".parse().unwrap()); - - assert_ne!(MaskedIp::new(ip_a, 64), MaskedIp::new(ip_b, 64)); - assert_ne!(MaskedIp::new(ip_a, 32), MaskedIp::new(ip_b, 32)); - assert_eq!(MaskedIp::new(ip_a, 30), MaskedIp::new(ip_b, 30)); - assert_eq!(MaskedIp::new(ip_c, 30), MaskedIp::new(ip_d, 30)); - - assert_ne!(MaskedIp::new(ip_e, 128), MaskedIp::new(ip_f, 128)); - assert_eq!(MaskedIp::new(ip_e, 64), MaskedIp::new(ip_f, 64)); - } - - #[test] - fn test_default_auth_rate_limit_set() { - // these values used to exceed u32::MAX - assert_eq!( - RateBucketInfo::DEFAULT_AUTH_SET, - [ - RateBucketInfo { - interval: Duration::from_secs(1), - max_rpi: 1000 * 4096, - }, - RateBucketInfo { - interval: Duration::from_secs(60), - max_rpi: 600 * 4096 * 60, - }, - RateBucketInfo { - interval: Duration::from_secs(600), - max_rpi: 300 * 4096 * 600, - } - ] - ); - - for x in RateBucketInfo::DEFAULT_AUTH_SET { - let y = x.to_string().parse().unwrap(); - assert_eq!(x, y); - } - } - #[tokio::test] async fn auth_quirks_scram() { let (mut client, server) = tokio::io::duplex(1024); - let mut stream = PqStream::new(Stream::from_raw(server)); + let mut stream = PqStream::new_skip_handshake(Stream::from_raw(server)); let ctx = RequestContext::test(); let api = Auth { @@ -784,7 +612,7 @@ mod tests { #[tokio::test] async fn auth_quirks_cleartext() { let (mut client, server) = tokio::io::duplex(1024); - let mut stream = PqStream::new(Stream::from_raw(server)); + let mut stream = PqStream::new_skip_handshake(Stream::from_raw(server)); let ctx = RequestContext::test(); let api = Auth { @@ -838,7 +666,7 @@ mod tests { #[tokio::test] async fn auth_quirks_password_hack() { let (mut client, server) = tokio::io::duplex(1024); - let mut stream = PqStream::new(Stream::from_raw(server)); + let mut stream = PqStream::new_skip_handshake(Stream::from_raw(server)); let ctx = RequestContext::test(); let api = Auth { @@ -887,7 +715,7 @@ mod tests { .await .unwrap(); - assert_eq!(creds.0.info.endpoint, "my-endpoint"); + assert_eq!(creds.info.endpoint, "my-endpoint"); handle.await.unwrap(); } diff --git a/proxy/src/auth/credentials.rs b/proxy/src/auth/credentials.rs index 526d0df7f2..b51da48862 100644 --- a/proxy/src/auth/credentials.rs +++ b/proxy/src/auth/credentials.rs @@ -5,7 +5,6 @@ use std::net::IpAddr; use std::str::FromStr; use itertools::Itertools; -use pq_proto::StartupMessageParams; use thiserror::Error; use tracing::{debug, warn}; @@ -13,6 +12,7 @@ use crate::auth::password_hack::parse_endpoint_param; use crate::context::RequestContext; use crate::error::{ReportableError, UserFacingError}; use crate::metrics::{Metrics, SniGroup, SniKind}; +use crate::pqproto::StartupMessageParams; use crate::proxy::NeonOptions; use crate::serverless::{AUTH_BROKER_SNI, SERVERLESS_DRIVER_SNI}; use crate::types::{EndpointId, RoleName}; diff --git a/proxy/src/auth/flow.rs b/proxy/src/auth/flow.rs index 0992c6d875..8fbc4577e9 100644 --- a/proxy/src/auth/flow.rs +++ b/proxy/src/auth/flow.rs @@ -1,10 +1,8 @@ //! Main authentication flow. 
-use std::io; use std::sync::Arc; use postgres_protocol::authentication::sasl::{SCRAM_SHA_256, SCRAM_SHA_256_PLUS}; -use pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be}; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::info; @@ -13,35 +11,26 @@ use super::{AuthError, PasswordHackPayload}; use crate::context::RequestContext; use crate::control_plane::AuthSecret; use crate::intern::EndpointIdInt; +use crate::pqproto::{BeAuthenticationSaslMessage, BeMessage}; use crate::sasl; use crate::scram::threadpool::ThreadPool; use crate::scram::{self}; use crate::stream::{PqStream, Stream}; use crate::tls::TlsServerEndPoint; -/// Every authentication selector is supposed to implement this trait. -pub(crate) trait AuthMethod { - /// Any authentication selector should provide initial backend message - /// containing auth method name and parameters, e.g. md5 salt. - fn first_message(&self, channel_binding: bool) -> BeMessage<'_>; -} - -/// Initial state of [`AuthFlow`]. -pub(crate) struct Begin; - /// Use [SCRAM](crate::scram)-based auth in [`AuthFlow`]. pub(crate) struct Scram<'a>( pub(crate) &'a scram::ServerSecret, pub(crate) &'a RequestContext, ); -impl AuthMethod for Scram<'_> { +impl Scram<'_> { #[inline(always)] fn first_message(&self, channel_binding: bool) -> BeMessage<'_> { if channel_binding { - Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(scram::METHODS)) + BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(scram::METHODS)) } else { - Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods( + BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Methods( scram::METHODS_WITHOUT_PLUS, )) } @@ -52,13 +41,6 @@ impl AuthMethod for Scram<'_> { /// . pub(crate) struct PasswordHack; -impl AuthMethod for PasswordHack { - #[inline(always)] - fn first_message(&self, _channel_binding: bool) -> BeMessage<'_> { - Be::AuthenticationCleartextPassword - } -} - /// Use clear-text password auth called `password` in docs /// pub(crate) struct CleartextPassword { @@ -67,53 +49,37 @@ pub(crate) struct CleartextPassword { pub(crate) secret: AuthSecret, } -impl AuthMethod for CleartextPassword { - #[inline(always)] - fn first_message(&self, _channel_binding: bool) -> BeMessage<'_> { - Be::AuthenticationCleartextPassword - } -} - /// This wrapper for [`PqStream`] performs client authentication. #[must_use] pub(crate) struct AuthFlow<'a, S, State> { /// The underlying stream which implements libpq's protocol. stream: &'a mut PqStream>, - /// State might contain ancillary data (see [`Self::begin`]). + /// State might contain ancillary data. state: State, tls_server_end_point: TlsServerEndPoint, } /// Initial state of the stream wrapper. -impl<'a, S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'a, S, Begin> { +impl<'a, S: AsyncRead + AsyncWrite + Unpin, M> AuthFlow<'a, S, M> { /// Create a new wrapper for client authentication. - pub(crate) fn new(stream: &'a mut PqStream>) -> Self { + pub(crate) fn new(stream: &'a mut PqStream>, method: M) -> Self { let tls_server_end_point = stream.get_ref().tls_server_end_point(); Self { stream, - state: Begin, + state: method, tls_server_end_point, } } - - /// Move to the next step by sending auth method's name & params to client. 
- pub(crate) async fn begin(self, method: M) -> io::Result> { - self.stream - .write_message(&method.first_message(self.tls_server_end_point.supported())) - .await?; - - Ok(AuthFlow { - stream: self.stream, - state: method, - tls_server_end_point: self.tls_server_end_point, - }) - } } impl AuthFlow<'_, S, PasswordHack> { /// Perform user authentication. Raise an error in case authentication failed. pub(crate) async fn get_password(self) -> super::Result { + self.stream + .write_message(BeMessage::AuthenticationCleartextPassword); + self.stream.flush().await?; + let msg = self.stream.read_password_message().await?; let password = msg .strip_suffix(&[0]) @@ -133,6 +99,10 @@ impl AuthFlow<'_, S, PasswordHack> { impl AuthFlow<'_, S, CleartextPassword> { /// Perform user authentication. Raise an error in case authentication failed. pub(crate) async fn authenticate(self) -> super::Result> { + self.stream + .write_message(BeMessage::AuthenticationCleartextPassword); + self.stream.flush().await?; + let msg = self.stream.read_password_message().await?; let password = msg .strip_suffix(&[0]) @@ -147,7 +117,7 @@ impl AuthFlow<'_, S, CleartextPassword> { .await?; if let sasl::Outcome::Success(_) = &outcome { - self.stream.write_message_noflush(&Be::AuthenticationOk)?; + self.stream.write_message(BeMessage::AuthenticationOk); } Ok(outcome) @@ -159,42 +129,36 @@ impl AuthFlow<'_, S, Scram<'_>> { /// Perform user authentication. Raise an error in case authentication failed. pub(crate) async fn authenticate(self) -> super::Result> { let Scram(secret, ctx) = self.state; + let channel_binding = self.tls_server_end_point; - // pause the timer while we communicate with the client - let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client); + // send sasl message. + { + // pause the timer while we communicate with the client + let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client); - // Initial client message contains the chosen auth method's name. - let msg = self.stream.read_password_message().await?; - let sasl = sasl::FirstMessage::parse(&msg) - .ok_or(AuthError::MalformedPassword("bad sasl message"))?; - - // Currently, the only supported SASL method is SCRAM. - if !scram::METHODS.contains(&sasl.method) { - return Err(super::AuthError::bad_auth_method(sasl.method)); + let sasl = self.state.first_message(channel_binding.supported()); + self.stream.write_message(sasl); + self.stream.flush().await?; } - match sasl.method { - SCRAM_SHA_256 => ctx.set_auth_method(crate::context::AuthMethod::ScramSha256), - SCRAM_SHA_256_PLUS => ctx.set_auth_method(crate::context::AuthMethod::ScramSha256Plus), - _ => {} - } + // complete sasl handshake. + sasl::authenticate(ctx, self.stream, |method| { + // Currently, the only supported SASL method is SCRAM. 
+ match method { + SCRAM_SHA_256 => ctx.set_auth_method(crate::context::AuthMethod::ScramSha256), + SCRAM_SHA_256_PLUS => { + ctx.set_auth_method(crate::context::AuthMethod::ScramSha256Plus); + } + method => return Err(sasl::Error::BadAuthMethod(method.into())), + } - // TODO: make this a metric instead - info!("client chooses {}", sasl.method); + // TODO: make this a metric instead + info!("client chooses {}", method); - let outcome = sasl::SaslStream::new(self.stream, sasl.message) - .authenticate(scram::Exchange::new( - secret, - rand::random, - self.tls_server_end_point, - )) - .await?; - - if let sasl::Outcome::Success(_) = &outcome { - self.stream.write_message_noflush(&Be::AuthenticationOk)?; - } - - Ok(outcome) + Ok(scram::Exchange::new(secret, rand::random, channel_binding)) + }) + .await + .map_err(AuthError::Sasl) } } diff --git a/proxy/src/binary/local_proxy.rs b/proxy/src/binary/local_proxy.rs index a566383390..ba10fce7b4 100644 --- a/proxy/src/binary/local_proxy.rs +++ b/proxy/src/binary/local_proxy.rs @@ -32,9 +32,7 @@ use crate::ext::TaskExt; use crate::http::health_server::AppMetrics; use crate::intern::RoleNameInt; use crate::metrics::{Metrics, ThreadPoolMetrics}; -use crate::rate_limiter::{ - BucketRateLimiter, EndpointRateLimiter, LeakyBucketConfig, RateBucketInfo, -}; +use crate::rate_limiter::{EndpointRateLimiter, LeakyBucketConfig, RateBucketInfo}; use crate::scram::threadpool::ThreadPool; use crate::serverless::cancel_set::CancelSet; use crate::serverless::{self, GlobalConnPoolOptions}; @@ -69,15 +67,6 @@ struct LocalProxyCliArgs { /// Can be given multiple times for different bucket sizes. #[clap(long, default_values_t = RateBucketInfo::DEFAULT_ENDPOINT_SET)] user_rps_limit: Vec, - /// Whether the auth rate limiter actually takes effect (for testing) - #[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)] - auth_rate_limit_enabled: bool, - /// Authentication rate limiter max number of hashes per second. - #[clap(long, default_values_t = RateBucketInfo::DEFAULT_AUTH_SET)] - auth_rate_limit: Vec, - /// The IP subnet to use when considering whether two IP addresses are considered the same. - #[clap(long, default_value_t = 64)] - auth_rate_limit_ip_subnet: u8, /// Whether to retry the connection to the compute node #[clap(long, default_value = config::RetryConfig::CONNECT_TO_COMPUTE_DEFAULT_VALUES)] connect_to_compute_retry: String, @@ -282,9 +271,6 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig jwks_cache: JwkCache::default(), thread_pool: ThreadPool::new(0), scram_protocol_timeout: Duration::from_secs(10), - rate_limiter_enabled: false, - rate_limiter: BucketRateLimiter::new(vec![]), - rate_limit_ip_subnet: 64, ip_allowlist_check_enabled: true, is_vpc_acccess_proxy: false, is_auth_broker: false, diff --git a/proxy/src/binary/pg_sni_router.rs b/proxy/src/binary/pg_sni_router.rs index 3e87538ae7..a4f517fead 100644 --- a/proxy/src/binary/pg_sni_router.rs +++ b/proxy/src/binary/pg_sni_router.rs @@ -4,8 +4,9 @@ //! This allows connecting to pods/services running in the same Kubernetes cluster from //! the outside. Similar to an ingress controller for HTTPS. 
+use std::net::SocketAddr; use std::path::Path; -use std::{net::SocketAddr, sync::Arc}; +use std::sync::Arc; use anyhow::{Context, anyhow, bail, ensure}; use clap::Arg; @@ -17,6 +18,7 @@ use rustls::pki_types::{DnsName, PrivateKeyDer}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::net::TcpListener; use tokio_rustls::TlsConnector; +use tokio_rustls::server::TlsStream; use tokio_util::sync::CancellationToken; use tracing::{Instrument, error, info}; use utils::project_git_version; @@ -24,10 +26,12 @@ use utils::sentry_init::init_sentry; use crate::context::RequestContext; use crate::metrics::{Metrics, ThreadPoolMetrics}; +use crate::pqproto::FeStartupPacket; use crate::protocol2::ConnectionInfo; -use crate::proxy::{ErrorSource, copy_bidirectional_client_compute, run_until_cancelled}; +use crate::proxy::{ + ErrorSource, TlsRequired, copy_bidirectional_client_compute, run_until_cancelled, +}; use crate::stream::{PqStream, Stream}; -use crate::tls::TlsServerEndPoint; project_git_version!(GIT_VERSION); @@ -84,7 +88,7 @@ pub async fn run() -> anyhow::Result<()> { .parse()?; // Configure TLS - let (tls_config, tls_server_end_point): (Arc, TlsServerEndPoint) = match ( + let tls_config = match ( args.get_one::("tls-key"), args.get_one::("tls-cert"), ) { @@ -117,7 +121,6 @@ pub async fn run() -> anyhow::Result<()> { dest.clone(), tls_config.clone(), None, - tls_server_end_point, proxy_listener, cancellation_token.clone(), )) @@ -127,7 +130,6 @@ pub async fn run() -> anyhow::Result<()> { dest, tls_config, Some(compute_tls_config), - tls_server_end_point, proxy_listener_compute_tls, cancellation_token.clone(), )) @@ -154,7 +156,7 @@ pub async fn run() -> anyhow::Result<()> { pub(super) fn parse_tls( key_path: &Path, cert_path: &Path, -) -> anyhow::Result<(Arc, TlsServerEndPoint)> { +) -> anyhow::Result> { let key = { let key_bytes = std::fs::read(key_path).context("TLS key file")?; @@ -187,10 +189,6 @@ pub(super) fn parse_tls( })? }; - // needed for channel bindings - let first_cert = cert_chain.first().context("missing certificate")?; - let tls_server_end_point = TlsServerEndPoint::new(first_cert)?; - let tls_config = rustls::ServerConfig::builder_with_provider(Arc::new(ring::default_provider())) .with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12]) @@ -199,14 +197,13 @@ pub(super) fn parse_tls( .with_single_cert(cert_chain, key)? .into(); - Ok((tls_config, tls_server_end_point)) + Ok(tls_config) } pub(super) async fn task_main( dest_suffix: Arc, tls_config: Arc, compute_tls_config: Option>, - tls_server_end_point: TlsServerEndPoint, listener: tokio::net::TcpListener, cancellation_token: CancellationToken, ) -> anyhow::Result<()> { @@ -242,15 +239,7 @@ pub(super) async fn task_main( crate::metrics::Protocol::SniRouter, "sni", ); - handle_client( - ctx, - dest_suffix, - tls_config, - compute_tls_config, - tls_server_end_point, - socket, - ) - .await + handle_client(ctx, dest_suffix, tls_config, compute_tls_config, socket).await } .unwrap_or_else(|e| { // Acknowledge that the task has finished with an error. 
@@ -269,55 +258,26 @@ pub(super) async fn task_main( Ok(()) } -const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)"; - async fn ssl_handshake( ctx: &RequestContext, raw_stream: S, tls_config: Arc, - tls_server_end_point: TlsServerEndPoint, -) -> anyhow::Result> { - let mut stream = PqStream::new(Stream::from_raw(raw_stream)); - - let msg = stream.read_startup_packet().await?; - use pq_proto::FeStartupPacket::SslRequest; - +) -> anyhow::Result> { + let (mut stream, msg) = PqStream::parse_startup(Stream::from_raw(raw_stream)).await?; match msg { - SslRequest { direct: false } => { - stream - .write_message(&pq_proto::BeMessage::EncryptionResponse(true)) - .await?; + FeStartupPacket::SslRequest { direct: None } => { + let raw = stream.accept_tls().await?; - // Upgrade raw stream into a secure TLS-backed stream. - // NOTE: We've consumed `tls`; this fact will be used later. - - let (raw, read_buf) = stream.into_inner(); - // TODO: Normally, client doesn't send any data before - // server says TLS handshake is ok and read_buf is empty. - // However, you could imagine pipelining of postgres - // SSLRequest + TLS ClientHello in one hunk similar to - // pipelining in our node js driver. We should probably - // support that by chaining read_buf with the stream. - if !read_buf.is_empty() { - bail!("data is sent before server replied with EncryptionResponse"); - } - - Ok(Stream::Tls { - tls: Box::new( - raw.upgrade(tls_config, !ctx.has_private_peer_addr()) - .await?, - ), - tls_server_end_point, - }) + Ok(raw + .upgrade(tls_config, !ctx.has_private_peer_addr()) + .await?) } unexpected => { info!( ?unexpected, "unexpected startup packet, rejecting connection" ); - stream - .throw_error_str(ERR_INSECURE_CONNECTION, crate::error::ErrorKind::User, None) - .await? + Err(stream.throw_error(TlsRequired, None).await)? } } } @@ -327,15 +287,18 @@ async fn handle_client( dest_suffix: Arc, tls_config: Arc, compute_tls_config: Option>, - tls_server_end_point: TlsServerEndPoint, stream: impl AsyncRead + AsyncWrite + Unpin, ) -> anyhow::Result<()> { - let mut tls_stream = ssl_handshake(&ctx, stream, tls_config, tls_server_end_point).await?; + let mut tls_stream = ssl_handshake(&ctx, stream, tls_config).await?; // Cut off first part of the SNI domain // We receive required destination details in the format of // `{k8s_service_name}--{k8s_namespace}--{port}.non-sni-domain` - let sni = tls_stream.sni_hostname().ok_or(anyhow!("SNI missing"))?; + let sni = tls_stream + .get_ref() + .1 + .server_name() + .ok_or(anyhow!("SNI missing"))?; let dest: Vec<&str> = sni .split_once('.') .context("invalid SNI")? 
diff --git a/proxy/src/binary/proxy.rs b/proxy/src/binary/proxy.rs index 5f24940985..dcae263647 100644 --- a/proxy/src/binary/proxy.rs +++ b/proxy/src/binary/proxy.rs @@ -20,7 +20,7 @@ use utils::sentry_init::init_sentry; use utils::{project_build_tag, project_git_version}; use crate::auth::backend::jwt::JwkCache; -use crate::auth::backend::{AuthRateLimiter, ConsoleRedirectBackend, MaybeOwned}; +use crate::auth::backend::{ConsoleRedirectBackend, MaybeOwned}; use crate::cancellation::{CancellationHandler, handle_cancel_messages}; use crate::config::{ self, AuthenticationConfig, CacheOptions, ComputeConfig, HttpConfig, ProjectInfoCacheOptions, @@ -29,9 +29,7 @@ use crate::config::{ use crate::context::parquet::ParquetUploadArgs; use crate::http::health_server::AppMetrics; use crate::metrics::Metrics; -use crate::rate_limiter::{ - EndpointRateLimiter, LeakyBucketConfig, RateBucketInfo, WakeComputeRateLimiter, -}; +use crate::rate_limiter::{EndpointRateLimiter, RateBucketInfo, WakeComputeRateLimiter}; use crate::redis::connection_with_credentials_provider::ConnectionWithCredentialsProvider; use crate::redis::kv_ops::RedisKVClient; use crate::redis::{elasticache, notifications}; @@ -154,15 +152,6 @@ struct ProxyCliArgs { /// Wake compute rate limiter max number of requests per second. #[clap(long, default_values_t = RateBucketInfo::DEFAULT_SET)] wake_compute_limit: Vec, - /// Whether the auth rate limiter actually takes effect (for testing) - #[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)] - auth_rate_limit_enabled: bool, - /// Authentication rate limiter max number of hashes per second. - #[clap(long, default_values_t = RateBucketInfo::DEFAULT_AUTH_SET)] - auth_rate_limit: Vec, - /// The IP subnet to use when considering whether two IP addresses are considered the same. - #[clap(long, default_value_t = 64)] - auth_rate_limit_ip_subnet: u8, /// Redis rate limiter max number of requests per second. 
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_REDIS_SET)] redis_rps_limit: Vec, @@ -410,22 +399,9 @@ pub async fn run() -> anyhow::Result<()> { Some(tx_cancel), )); - // bit of a hack - find the min rps and max rps supported and turn it into - // leaky bucket config instead - let max = args - .endpoint_rps_limit - .iter() - .map(|x| x.rps()) - .max_by(f64::total_cmp) - .unwrap_or(EndpointRateLimiter::DEFAULT.max); - let rps = args - .endpoint_rps_limit - .iter() - .map(|x| x.rps()) - .min_by(f64::total_cmp) - .unwrap_or(EndpointRateLimiter::DEFAULT.rps); let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards( - LeakyBucketConfig { rps, max }, + RateBucketInfo::to_leaky_bucket(&args.endpoint_rps_limit) + .unwrap_or(EndpointRateLimiter::DEFAULT), 64, )); @@ -476,8 +452,7 @@ pub async fn run() -> anyhow::Result<()> { let key_path = args.tls_key.expect("already asserted it is set"); let cert_path = args.tls_cert.expect("already asserted it is set"); - let (tls_config, tls_server_end_point) = - super::pg_sni_router::parse_tls(&key_path, &cert_path)?; + let tls_config = super::pg_sni_router::parse_tls(&key_path, &cert_path)?; let dest = Arc::new(dest); @@ -485,7 +460,6 @@ pub async fn run() -> anyhow::Result<()> { dest.clone(), tls_config.clone(), None, - tls_server_end_point, listen, cancellation_token.clone(), )); @@ -494,7 +468,6 @@ pub async fn run() -> anyhow::Result<()> { dest, tls_config, Some(config.connect_to_compute.tls.clone()), - tls_server_end_point, listen_tls, cancellation_token.clone(), )); @@ -681,9 +654,6 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { jwks_cache: JwkCache::default(), thread_pool, scram_protocol_timeout: args.scram_protocol_timeout, - rate_limiter_enabled: args.auth_rate_limit_enabled, - rate_limiter: AuthRateLimiter::new(args.auth_rate_limit.clone()), - rate_limit_ip_subnet: args.auth_rate_limit_ip_subnet, ip_allowlist_check_enabled: !args.is_private_access_proxy, is_vpc_acccess_proxy: args.is_private_access_proxy, is_auth_broker: args.is_auth_broker, diff --git a/proxy/src/cache/project_info.rs b/proxy/src/cache/project_info.rs index 60678b034d..81c88e3ddd 100644 --- a/proxy/src/cache/project_info.rs +++ b/proxy/src/cache/project_info.rs @@ -1,30 +1,25 @@ -use std::collections::HashSet; +use std::collections::{HashMap, HashSet, hash_map}; use std::convert::Infallible; -use std::sync::Arc; use std::sync::atomic::AtomicU64; use std::time::Duration; use async_trait::async_trait; use clashmap::ClashMap; +use clashmap::mapref::one::Ref; use rand::{Rng, thread_rng}; -use smol_str::SmolStr; use tokio::sync::Mutex; use tokio::time::Instant; use tracing::{debug, info}; -use super::{Cache, Cached}; -use crate::auth::IpPattern; use crate::config::ProjectInfoCacheOptions; -use crate::control_plane::{AccessBlockerFlags, AuthSecret}; +use crate::control_plane::{EndpointAccessControl, RoleAccessControl}; use crate::intern::{AccountIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt}; use crate::types::{EndpointId, RoleName}; #[async_trait] pub(crate) trait ProjectInfoCache { - fn invalidate_allowed_ips_for_project(&self, project_id: ProjectIdInt); - fn invalidate_allowed_vpc_endpoint_ids_for_projects(&self, project_ids: Vec); - fn invalidate_allowed_vpc_endpoint_ids_for_org(&self, account_id: AccountIdInt); - fn invalidate_block_public_or_vpc_access_for_project(&self, project_id: ProjectIdInt); + fn invalidate_endpoint_access_for_project(&self, project_id: ProjectIdInt); + fn 
invalidate_endpoint_access_for_org(&self, account_id: AccountIdInt); fn invalidate_role_secret_for_project(&self, project_id: ProjectIdInt, role_name: RoleNameInt); async fn decrement_active_listeners(&self); async fn increment_active_listeners(&self); @@ -42,6 +37,10 @@ impl Entry { value, } } + + pub(crate) fn get(&self, valid_since: Instant) -> Option<&T> { + (valid_since < self.created_at).then_some(&self.value) + } } impl From for Entry { @@ -50,101 +49,32 @@ impl From for Entry { } } -#[derive(Default)] struct EndpointInfo { - secret: std::collections::HashMap>>, - allowed_ips: Option>>>, - block_public_or_vpc_access: Option>, - allowed_vpc_endpoint_ids: Option>>>, + role_controls: HashMap>, + controls: Option>, } impl EndpointInfo { - fn check_ignore_cache(ignore_cache_since: Option, created_at: Instant) -> bool { - match ignore_cache_since { - None => false, - Some(t) => t < created_at, - } - } pub(crate) fn get_role_secret( &self, role_name: RoleNameInt, valid_since: Instant, - ignore_cache_since: Option, - ) -> Option<(Option, bool)> { - if let Some(secret) = self.secret.get(&role_name) { - if valid_since < secret.created_at { - return Some(( - secret.value.clone(), - Self::check_ignore_cache(ignore_cache_since, secret.created_at), - )); - } - } - None + ) -> Option { + let controls = self.role_controls.get(&role_name)?; + controls.get(valid_since).cloned() } - pub(crate) fn get_allowed_ips( - &self, - valid_since: Instant, - ignore_cache_since: Option, - ) -> Option<(Arc>, bool)> { - if let Some(allowed_ips) = &self.allowed_ips { - if valid_since < allowed_ips.created_at { - return Some(( - allowed_ips.value.clone(), - Self::check_ignore_cache(ignore_cache_since, allowed_ips.created_at), - )); - } - } - None - } - pub(crate) fn get_allowed_vpc_endpoint_ids( - &self, - valid_since: Instant, - ignore_cache_since: Option, - ) -> Option<(Arc>, bool)> { - if let Some(allowed_vpc_endpoint_ids) = &self.allowed_vpc_endpoint_ids { - if valid_since < allowed_vpc_endpoint_ids.created_at { - return Some(( - allowed_vpc_endpoint_ids.value.clone(), - Self::check_ignore_cache( - ignore_cache_since, - allowed_vpc_endpoint_ids.created_at, - ), - )); - } - } - None - } - pub(crate) fn get_block_public_or_vpc_access( - &self, - valid_since: Instant, - ignore_cache_since: Option, - ) -> Option<(AccessBlockerFlags, bool)> { - if let Some(block_public_or_vpc_access) = &self.block_public_or_vpc_access { - if valid_since < block_public_or_vpc_access.created_at { - return Some(( - block_public_or_vpc_access.value.clone(), - Self::check_ignore_cache( - ignore_cache_since, - block_public_or_vpc_access.created_at, - ), - )); - } - } - None + pub(crate) fn get_controls(&self, valid_since: Instant) -> Option { + let controls = self.controls.as_ref()?; + controls.get(valid_since).cloned() } - pub(crate) fn invalidate_allowed_ips(&mut self) { - self.allowed_ips = None; - } - pub(crate) fn invalidate_allowed_vpc_endpoint_ids(&mut self) { - self.allowed_vpc_endpoint_ids = None; - } - pub(crate) fn invalidate_block_public_or_vpc_access(&mut self) { - self.block_public_or_vpc_access = None; + pub(crate) fn invalidate_endpoint(&mut self) { + self.controls = None; } + pub(crate) fn invalidate_role_secret(&mut self, role_name: RoleNameInt) { - self.secret.remove(&role_name); + self.role_controls.remove(&role_name); } } @@ -170,34 +100,22 @@ pub struct ProjectInfoCacheImpl { #[async_trait] impl ProjectInfoCache for ProjectInfoCacheImpl { - fn invalidate_allowed_vpc_endpoint_ids_for_projects(&self, project_ids: Vec) { 
- info!( - "invalidating allowed vpc endpoint ids for projects `{}`", - project_ids - .iter() - .map(|id| id.to_string()) - .collect::>() - .join(", ") - ); - for project_id in project_ids { - let endpoints = self - .project2ep - .get(&project_id) - .map(|kv| kv.value().clone()) - .unwrap_or_default(); - for endpoint_id in endpoints { - if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) { - endpoint_info.invalidate_allowed_vpc_endpoint_ids(); - } + fn invalidate_endpoint_access_for_project(&self, project_id: ProjectIdInt) { + info!("invalidating endpoint access for project `{project_id}`"); + let endpoints = self + .project2ep + .get(&project_id) + .map(|kv| kv.value().clone()) + .unwrap_or_default(); + for endpoint_id in endpoints { + if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) { + endpoint_info.invalidate_endpoint(); } } } - fn invalidate_allowed_vpc_endpoint_ids_for_org(&self, account_id: AccountIdInt) { - info!( - "invalidating allowed vpc endpoint ids for org `{}`", - account_id - ); + fn invalidate_endpoint_access_for_org(&self, account_id: AccountIdInt) { + info!("invalidating endpoint access for org `{account_id}`"); let endpoints = self .account2ep .get(&account_id) @@ -205,41 +123,11 @@ impl ProjectInfoCache for ProjectInfoCacheImpl { .unwrap_or_default(); for endpoint_id in endpoints { if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) { - endpoint_info.invalidate_allowed_vpc_endpoint_ids(); + endpoint_info.invalidate_endpoint(); } } } - fn invalidate_block_public_or_vpc_access_for_project(&self, project_id: ProjectIdInt) { - info!( - "invalidating block public or vpc access for project `{}`", - project_id - ); - let endpoints = self - .project2ep - .get(&project_id) - .map(|kv| kv.value().clone()) - .unwrap_or_default(); - for endpoint_id in endpoints { - if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) { - endpoint_info.invalidate_block_public_or_vpc_access(); - } - } - } - - fn invalidate_allowed_ips_for_project(&self, project_id: ProjectIdInt) { - info!("invalidating allowed ips for project `{}`", project_id); - let endpoints = self - .project2ep - .get(&project_id) - .map(|kv| kv.value().clone()) - .unwrap_or_default(); - for endpoint_id in endpoints { - if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) { - endpoint_info.invalidate_allowed_ips(); - } - } - } fn invalidate_role_secret_for_project(&self, project_id: ProjectIdInt, role_name: RoleNameInt) { info!( "invalidating role secret for project_id `{}` and role_name `{}`", @@ -256,6 +144,7 @@ impl ProjectInfoCache for ProjectInfoCacheImpl { } } } + async fn decrement_active_listeners(&self) { let mut listeners_guard = self.active_listeners_lock.lock().await; if *listeners_guard == 0 { @@ -293,155 +182,71 @@ impl ProjectInfoCacheImpl { } } + fn get_endpoint_cache( + &self, + endpoint_id: &EndpointId, + ) -> Option> { + let endpoint_id = EndpointIdInt::get(endpoint_id)?; + self.cache.get(&endpoint_id) + } + pub(crate) fn get_role_secret( &self, endpoint_id: &EndpointId, role_name: &RoleName, - ) -> Option>> { - let endpoint_id = EndpointIdInt::get(endpoint_id)?; + ) -> Option { + let valid_since = self.get_cache_times(); let role_name = RoleNameInt::get(role_name)?; - let (valid_since, ignore_cache_since) = self.get_cache_times(); - let endpoint_info = self.cache.get(&endpoint_id)?; - let (value, ignore_cache) = - endpoint_info.get_role_secret(role_name, valid_since, ignore_cache_since)?; - if !ignore_cache { - let cached = Cached { - 
token: Some(( - self, - CachedLookupInfo::new_role_secret(endpoint_id, role_name), - )), - value, - }; - return Some(cached); - } - Some(Cached::new_uncached(value)) - } - pub(crate) fn get_allowed_ips( - &self, - endpoint_id: &EndpointId, - ) -> Option>>> { - let endpoint_id = EndpointIdInt::get(endpoint_id)?; - let (valid_since, ignore_cache_since) = self.get_cache_times(); - let endpoint_info = self.cache.get(&endpoint_id)?; - let value = endpoint_info.get_allowed_ips(valid_since, ignore_cache_since); - let (value, ignore_cache) = value?; - if !ignore_cache { - let cached = Cached { - token: Some((self, CachedLookupInfo::new_allowed_ips(endpoint_id))), - value, - }; - return Some(cached); - } - Some(Cached::new_uncached(value)) - } - pub(crate) fn get_allowed_vpc_endpoint_ids( - &self, - endpoint_id: &EndpointId, - ) -> Option>>> { - let endpoint_id = EndpointIdInt::get(endpoint_id)?; - let (valid_since, ignore_cache_since) = self.get_cache_times(); - let endpoint_info = self.cache.get(&endpoint_id)?; - let value = endpoint_info.get_allowed_vpc_endpoint_ids(valid_since, ignore_cache_since); - let (value, ignore_cache) = value?; - if !ignore_cache { - let cached = Cached { - token: Some(( - self, - CachedLookupInfo::new_allowed_vpc_endpoint_ids(endpoint_id), - )), - value, - }; - return Some(cached); - } - Some(Cached::new_uncached(value)) - } - pub(crate) fn get_block_public_or_vpc_access( - &self, - endpoint_id: &EndpointId, - ) -> Option> { - let endpoint_id = EndpointIdInt::get(endpoint_id)?; - let (valid_since, ignore_cache_since) = self.get_cache_times(); - let endpoint_info = self.cache.get(&endpoint_id)?; - let value = endpoint_info.get_block_public_or_vpc_access(valid_since, ignore_cache_since); - let (value, ignore_cache) = value?; - if !ignore_cache { - let cached = Cached { - token: Some(( - self, - CachedLookupInfo::new_block_public_or_vpc_access(endpoint_id), - )), - value, - }; - return Some(cached); - } - Some(Cached::new_uncached(value)) + let endpoint_info = self.get_endpoint_cache(endpoint_id)?; + endpoint_info.get_role_secret(role_name, valid_since) } - pub(crate) fn insert_role_secret( + pub(crate) fn get_endpoint_access( &self, - project_id: ProjectIdInt, - endpoint_id: EndpointIdInt, - role_name: RoleNameInt, - secret: Option, - ) { - if self.cache.len() >= self.config.size { - // If there are too many entries, wait until the next gc cycle. - return; - } - self.insert_project2endpoint(project_id, endpoint_id); - let mut entry = self.cache.entry(endpoint_id).or_default(); - if entry.secret.len() < self.config.max_roles { - entry.secret.insert(role_name, secret.into()); - } + endpoint_id: &EndpointId, + ) -> Option { + let valid_since = self.get_cache_times(); + let endpoint_info = self.get_endpoint_cache(endpoint_id)?; + endpoint_info.get_controls(valid_since) } - pub(crate) fn insert_allowed_ips( - &self, - project_id: ProjectIdInt, - endpoint_id: EndpointIdInt, - allowed_ips: Arc>, - ) { - if self.cache.len() >= self.config.size { - // If there are too many entries, wait until the next gc cycle. 
- return; - } - self.insert_project2endpoint(project_id, endpoint_id); - self.cache.entry(endpoint_id).or_default().allowed_ips = Some(allowed_ips.into()); - } - pub(crate) fn insert_allowed_vpc_endpoint_ids( + + pub(crate) fn insert_endpoint_access( &self, account_id: Option, project_id: ProjectIdInt, endpoint_id: EndpointIdInt, - allowed_vpc_endpoint_ids: Arc>, + role_name: RoleNameInt, + controls: EndpointAccessControl, + role_controls: RoleAccessControl, ) { - if self.cache.len() >= self.config.size { - // If there are too many entries, wait until the next gc cycle. - return; - } if let Some(account_id) = account_id { self.insert_account2endpoint(account_id, endpoint_id); } self.insert_project2endpoint(project_id, endpoint_id); - self.cache - .entry(endpoint_id) - .or_default() - .allowed_vpc_endpoint_ids = Some(allowed_vpc_endpoint_ids.into()); - } - pub(crate) fn insert_block_public_or_vpc_access( - &self, - project_id: ProjectIdInt, - endpoint_id: EndpointIdInt, - access_blockers: AccessBlockerFlags, - ) { + if self.cache.len() >= self.config.size { // If there are too many entries, wait until the next gc cycle. return; } - self.insert_project2endpoint(project_id, endpoint_id); - self.cache - .entry(endpoint_id) - .or_default() - .block_public_or_vpc_access = Some(access_blockers.into()); + + let controls = Entry::from(controls); + let role_controls = Entry::from(role_controls); + + match self.cache.entry(endpoint_id) { + clashmap::Entry::Vacant(e) => { + e.insert(EndpointInfo { + role_controls: HashMap::from_iter([(role_name, role_controls)]), + controls: Some(controls), + }); + } + clashmap::Entry::Occupied(mut e) => { + let ep = e.get_mut(); + ep.controls = Some(controls); + if ep.role_controls.len() < self.config.max_roles { + ep.role_controls.insert(role_name, role_controls); + } + } + } } fn insert_project2endpoint(&self, project_id: ProjectIdInt, endpoint_id: EndpointIdInt) { @@ -452,6 +257,7 @@ impl ProjectInfoCacheImpl { .insert(project_id, HashSet::from([endpoint_id])); } } + fn insert_account2endpoint(&self, account_id: AccountIdInt, endpoint_id: EndpointIdInt) { if let Some(mut endpoints) = self.account2ep.get_mut(&account_id) { endpoints.insert(endpoint_id); @@ -460,21 +266,57 @@ impl ProjectInfoCacheImpl { .insert(account_id, HashSet::from([endpoint_id])); } } - fn get_cache_times(&self) -> (Instant, Option) { - let mut valid_since = Instant::now() - self.config.ttl; - // Only ignore cache if ttl is disabled. + + fn ignore_ttl_since(&self) -> Option { let ttl_disabled_since_us = self .ttl_disabled_since_us .load(std::sync::atomic::Ordering::Relaxed); - let ignore_cache_since = if ttl_disabled_since_us == u64::MAX { - None - } else { - let ignore_cache_since = self.start_time + Duration::from_micros(ttl_disabled_since_us); + + if ttl_disabled_since_us == u64::MAX { + return None; + } + + Some(self.start_time + Duration::from_micros(ttl_disabled_since_us)) + } + + fn get_cache_times(&self) -> Instant { + let mut valid_since = Instant::now() - self.config.ttl; + if let Some(ignore_ttl_since) = self.ignore_ttl_since() { // We are fine if entry is not older than ttl or was added before we are getting notifications. 
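            // (Informal sketch of the check this produces: an entry created at
            // time C is served iff C > min(now - ttl, ignore_ttl_since). An
            // entry written after the Redis listener came up therefore never
            // ages out, since an invalidation message is expected to remove it,
            // while entries written before that still expire after the normal ttl.)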
-            valid_since = valid_since.min(ignore_cache_since);
-            Some(ignore_cache_since)
+            valid_since = valid_since.min(ignore_ttl_since);
+        }
+        valid_since
+    }
+
+    pub fn maybe_invalidate_role_secret(&self, endpoint_id: &EndpointId, role_name: &RoleName) {
+        let Some(endpoint_id) = EndpointIdInt::get(endpoint_id) else {
+            return;
         };
-        (valid_since, ignore_cache_since)
+        let Some(role_name) = RoleNameInt::get(role_name) else {
+            return;
+        };
+
+        let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) else {
+            return;
+        };
+
+        let entry = endpoint_info.role_controls.entry(role_name);
+        let hash_map::Entry::Occupied(role_controls) = entry else {
+            return;
+        };
+
+        let created_at = role_controls.get().created_at;
+        let expire = match self.ignore_ttl_since() {
+            // if ignoring TTL, we should still try to roll the password if it's old
+            // and the client gave an incorrect password. There could be some lag on the redis channel.
+            Some(_) => created_at + self.config.ttl < Instant::now(),
+            // edge case: redis is down, let's be generous and invalidate the cache immediately.
+            None => true,
+        };
+
+        if expire {
+            role_controls.remove();
+        }
     }

     pub async fn gc_worker(&self) -> anyhow::Result {
@@ -509,84 +351,12 @@ impl ProjectInfoCacheImpl {
     }
 }

-/// Lookup info for project info cache.
-/// This is used to invalidate cache entries.
-pub(crate) struct CachedLookupInfo {
-    /// Search by this key.
-    endpoint_id: EndpointIdInt,
-    lookup_type: LookupType,
-}
-
-impl CachedLookupInfo {
-    pub(self) fn new_role_secret(endpoint_id: EndpointIdInt, role_name: RoleNameInt) -> Self {
-        Self {
-            endpoint_id,
-            lookup_type: LookupType::RoleSecret(role_name),
-        }
-    }
-    pub(self) fn new_allowed_ips(endpoint_id: EndpointIdInt) -> Self {
-        Self {
-            endpoint_id,
-            lookup_type: LookupType::AllowedIps,
-        }
-    }
-    pub(self) fn new_allowed_vpc_endpoint_ids(endpoint_id: EndpointIdInt) -> Self {
-        Self {
-            endpoint_id,
-            lookup_type: LookupType::AllowedVpcEndpointIds,
-        }
-    }
-    pub(self) fn new_block_public_or_vpc_access(endpoint_id: EndpointIdInt) -> Self {
-        Self {
-            endpoint_id,
-            lookup_type: LookupType::BlockPublicOrVpcAccess,
-        }
-    }
-}
-
-enum LookupType {
-    RoleSecret(RoleNameInt),
-    AllowedIps,
-    AllowedVpcEndpointIds,
-    BlockPublicOrVpcAccess,
-}
-
-impl Cache for ProjectInfoCacheImpl {
-    type Key = SmolStr;
-    // Value is not really used here, but we need to specify it.
- type Value = SmolStr; - - type LookupInfo = CachedLookupInfo; - - fn invalidate(&self, key: &Self::LookupInfo) { - match &key.lookup_type { - LookupType::RoleSecret(role_name) => { - if let Some(mut endpoint_info) = self.cache.get_mut(&key.endpoint_id) { - endpoint_info.invalidate_role_secret(*role_name); - } - } - LookupType::AllowedIps => { - if let Some(mut endpoint_info) = self.cache.get_mut(&key.endpoint_id) { - endpoint_info.invalidate_allowed_ips(); - } - } - LookupType::AllowedVpcEndpointIds => { - if let Some(mut endpoint_info) = self.cache.get_mut(&key.endpoint_id) { - endpoint_info.invalidate_allowed_vpc_endpoint_ids(); - } - } - LookupType::BlockPublicOrVpcAccess => { - if let Some(mut endpoint_info) = self.cache.get_mut(&key.endpoint_id) { - endpoint_info.invalidate_block_public_or_vpc_access(); - } - } - } - } -} - #[cfg(test)] mod tests { + use std::sync::Arc; + use super::*; + use crate::control_plane::{AccessBlockerFlags, AuthSecret}; use crate::scram::ServerSecret; use crate::types::ProjectId; @@ -601,6 +371,8 @@ mod tests { }); let project_id: ProjectId = "project".into(); let endpoint_id: EndpointId = "endpoint".into(); + let account_id: Option = None; + let user1: RoleName = "user1".into(); let user2: RoleName = "user2".into(); let secret1 = Some(AuthSecret::Scram(ServerSecret::mock([1; 32]))); @@ -609,183 +381,73 @@ mod tests { "127.0.0.1".parse().unwrap(), "127.0.0.2".parse().unwrap(), ]); - cache.insert_role_secret( + + cache.insert_endpoint_access( + account_id, (&project_id).into(), (&endpoint_id).into(), (&user1).into(), - secret1.clone(), + EndpointAccessControl { + allowed_ips: allowed_ips.clone(), + allowed_vpce: Arc::new(vec![]), + flags: AccessBlockerFlags::default(), + }, + RoleAccessControl { + secret: secret1.clone(), + }, ); - cache.insert_role_secret( + + cache.insert_endpoint_access( + account_id, (&project_id).into(), (&endpoint_id).into(), (&user2).into(), - secret2.clone(), - ); - cache.insert_allowed_ips( - (&project_id).into(), - (&endpoint_id).into(), - allowed_ips.clone(), + EndpointAccessControl { + allowed_ips: allowed_ips.clone(), + allowed_vpce: Arc::new(vec![]), + flags: AccessBlockerFlags::default(), + }, + RoleAccessControl { + secret: secret2.clone(), + }, ); let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap(); - assert!(cached.cached()); - assert_eq!(cached.value, secret1); + assert_eq!(cached.secret, secret1); + let cached = cache.get_role_secret(&endpoint_id, &user2).unwrap(); - assert!(cached.cached()); - assert_eq!(cached.value, secret2); + assert_eq!(cached.secret, secret2); // Shouldn't add more than 2 roles. 
let user3: RoleName = "user3".into(); let secret3 = Some(AuthSecret::Scram(ServerSecret::mock([3; 32]))); - cache.insert_role_secret( + + cache.insert_endpoint_access( + account_id, (&project_id).into(), (&endpoint_id).into(), (&user3).into(), - secret3.clone(), + EndpointAccessControl { + allowed_ips: allowed_ips.clone(), + allowed_vpce: Arc::new(vec![]), + flags: AccessBlockerFlags::default(), + }, + RoleAccessControl { + secret: secret3.clone(), + }, ); + assert!(cache.get_role_secret(&endpoint_id, &user3).is_none()); - let cached = cache.get_allowed_ips(&endpoint_id).unwrap(); - assert!(cached.cached()); - assert_eq!(cached.value, allowed_ips); + let cached = cache.get_endpoint_access(&endpoint_id).unwrap(); + assert_eq!(cached.allowed_ips, allowed_ips); tokio::time::advance(Duration::from_secs(2)).await; let cached = cache.get_role_secret(&endpoint_id, &user1); assert!(cached.is_none()); let cached = cache.get_role_secret(&endpoint_id, &user2); assert!(cached.is_none()); - let cached = cache.get_allowed_ips(&endpoint_id); + let cached = cache.get_endpoint_access(&endpoint_id); assert!(cached.is_none()); } - - #[tokio::test] - async fn test_project_info_cache_invalidations() { - tokio::time::pause(); - let cache = Arc::new(ProjectInfoCacheImpl::new(ProjectInfoCacheOptions { - size: 2, - max_roles: 2, - ttl: Duration::from_secs(1), - gc_interval: Duration::from_secs(600), - })); - cache.clone().increment_active_listeners().await; - tokio::time::advance(Duration::from_secs(2)).await; - - let project_id: ProjectId = "project".into(); - let endpoint_id: EndpointId = "endpoint".into(); - let user1: RoleName = "user1".into(); - let user2: RoleName = "user2".into(); - let secret1 = Some(AuthSecret::Scram(ServerSecret::mock([1; 32]))); - let secret2 = Some(AuthSecret::Scram(ServerSecret::mock([2; 32]))); - let allowed_ips = Arc::new(vec![ - "127.0.0.1".parse().unwrap(), - "127.0.0.2".parse().unwrap(), - ]); - cache.insert_role_secret( - (&project_id).into(), - (&endpoint_id).into(), - (&user1).into(), - secret1.clone(), - ); - cache.insert_role_secret( - (&project_id).into(), - (&endpoint_id).into(), - (&user2).into(), - secret2.clone(), - ); - cache.insert_allowed_ips( - (&project_id).into(), - (&endpoint_id).into(), - allowed_ips.clone(), - ); - - tokio::time::advance(Duration::from_secs(2)).await; - // Nothing should be invalidated. - - let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap(); - // TTL is disabled, so it should be impossible to invalidate this value. - assert!(!cached.cached()); - assert_eq!(cached.value, secret1); - - cached.invalidate(); // Shouldn't do anything. - let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap(); - assert_eq!(cached.value, secret1); - - let cached = cache.get_role_secret(&endpoint_id, &user2).unwrap(); - assert!(!cached.cached()); - assert_eq!(cached.value, secret2); - - // The only way to invalidate this value is to invalidate via the api. 
- cache.invalidate_role_secret_for_project((&project_id).into(), (&user2).into()); - assert!(cache.get_role_secret(&endpoint_id, &user2).is_none()); - - let cached = cache.get_allowed_ips(&endpoint_id).unwrap(); - assert!(!cached.cached()); - assert_eq!(cached.value, allowed_ips); - } - - #[tokio::test] - async fn test_increment_active_listeners_invalidate_added_before() { - tokio::time::pause(); - let cache = Arc::new(ProjectInfoCacheImpl::new(ProjectInfoCacheOptions { - size: 2, - max_roles: 2, - ttl: Duration::from_secs(1), - gc_interval: Duration::from_secs(600), - })); - - let project_id: ProjectId = "project".into(); - let endpoint_id: EndpointId = "endpoint".into(); - let user1: RoleName = "user1".into(); - let user2: RoleName = "user2".into(); - let secret1 = Some(AuthSecret::Scram(ServerSecret::mock([1; 32]))); - let secret2 = Some(AuthSecret::Scram(ServerSecret::mock([2; 32]))); - let allowed_ips = Arc::new(vec![ - "127.0.0.1".parse().unwrap(), - "127.0.0.2".parse().unwrap(), - ]); - cache.insert_role_secret( - (&project_id).into(), - (&endpoint_id).into(), - (&user1).into(), - secret1.clone(), - ); - cache.clone().increment_active_listeners().await; - tokio::time::advance(Duration::from_millis(100)).await; - cache.insert_role_secret( - (&project_id).into(), - (&endpoint_id).into(), - (&user2).into(), - secret2.clone(), - ); - - // Added before ttl was disabled + ttl should be still cached. - let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap(); - assert!(cached.cached()); - let cached = cache.get_role_secret(&endpoint_id, &user2).unwrap(); - assert!(cached.cached()); - - tokio::time::advance(Duration::from_secs(1)).await; - // Added before ttl was disabled + ttl should expire. - assert!(cache.get_role_secret(&endpoint_id, &user1).is_none()); - assert!(cache.get_role_secret(&endpoint_id, &user2).is_none()); - - // Added after ttl was disabled + ttl should not be cached. - cache.insert_allowed_ips( - (&project_id).into(), - (&endpoint_id).into(), - allowed_ips.clone(), - ); - let cached = cache.get_allowed_ips(&endpoint_id).unwrap(); - assert!(!cached.cached()); - - tokio::time::advance(Duration::from_secs(1)).await; - // Added before ttl was disabled + ttl still should expire. - assert!(cache.get_role_secret(&endpoint_id, &user1).is_none()); - assert!(cache.get_role_secret(&endpoint_id, &user2).is_none()); - // Shouldn't be invalidated. 
- - let cached = cache.get_allowed_ips(&endpoint_id).unwrap(); - assert!(!cached.cached()); - assert_eq!(cached.value, allowed_ips); - } } diff --git a/proxy/src/cancellation.rs b/proxy/src/cancellation.rs index a6e7bf85a0..d26641db46 100644 --- a/proxy/src/cancellation.rs +++ b/proxy/src/cancellation.rs @@ -5,7 +5,6 @@ use anyhow::{Context, anyhow}; use ipnet::{IpNet, Ipv4Net, Ipv6Net}; use postgres_client::CancelToken; use postgres_client::tls::MakeTlsConnect; -use pq_proto::CancelKeyData; use redis::{Cmd, FromRedisValue, Value}; use serde::{Deserialize, Serialize}; use thiserror::Error; @@ -13,15 +12,15 @@ use tokio::net::TcpStream; use tokio::sync::{mpsc, oneshot}; use tracing::{debug, error, info, warn}; +use crate::auth::AuthError; use crate::auth::backend::ComputeUserInfo; -use crate::auth::{AuthError, check_peer_addr_is_in_list}; use crate::config::ComputeConfig; use crate::context::RequestContext; use crate::control_plane::ControlPlaneApi; use crate::error::ReportableError; use crate::ext::LockExt; use crate::metrics::{CancelChannelSizeGuard, CancellationRequest, Metrics, RedisMsgKind}; -use crate::protocol2::ConnectionInfoExtra; +use crate::pqproto::CancelKeyData; use crate::rate_limiter::LeakyBucketRateLimiter; use crate::redis::keys::KeyPrefix; use crate::redis::kv_ops::RedisKVClient; @@ -272,13 +271,7 @@ pub(crate) enum CancelError { #[error("rate limit exceeded")] RateLimit, - #[error("IP is not allowed")] - IpNotAllowed, - - #[error("VPC endpoint id is not allowed to connect")] - VpcEndpointIdNotAllowed, - - #[error("Authentication backend error")] + #[error("Authentication error")] AuthError(#[from] AuthError), #[error("key not found")] @@ -297,10 +290,7 @@ impl ReportableError for CancelError { } CancelError::Postgres(_) => crate::error::ErrorKind::Compute, CancelError::RateLimit => crate::error::ErrorKind::RateLimit, - CancelError::IpNotAllowed - | CancelError::VpcEndpointIdNotAllowed - | CancelError::NotFound => crate::error::ErrorKind::User, - CancelError::AuthError(_) => crate::error::ErrorKind::ControlPlane, + CancelError::NotFound | CancelError::AuthError(_) => crate::error::ErrorKind::User, CancelError::InternalError => crate::error::ErrorKind::Service, } } @@ -422,7 +412,13 @@ impl CancellationHandler { IpAddr::V4(ip) => IpNet::V4(Ipv4Net::new_assert(ip, 24).trunc()), // use defaut mask here IpAddr::V6(ip) => IpNet::V6(Ipv6Net::new_assert(ip, 64).trunc()), }; - if !self.limiter.lock_propagate_poison().check(subnet_key, 1) { + + let allowed = { + let rate_limit_config = None; + let limiter = self.limiter.lock_propagate_poison(); + limiter.check(subnet_key, rate_limit_config, 1) + }; + if !allowed { // log only the subnet part of the IP address to know which subnet is rate limited tracing::warn!("Rate limit exceeded. 
Skipping cancellation message, {subnet_key}"); Metrics::get() @@ -450,52 +446,13 @@ impl CancellationHandler { return Err(CancelError::NotFound); }; - if check_ip_allowed { - let ip_allowlist = auth_backend - .get_allowed_ips(&ctx, &cancel_closure.user_info) - .await - .map_err(|e| CancelError::AuthError(e.into()))?; - - if !check_peer_addr_is_in_list(&ctx.peer_addr(), &ip_allowlist) { - // log it here since cancel_session could be spawned in a task - tracing::warn!( - "IP is not allowed to cancel the query: {key}, address: {}", - ctx.peer_addr() - ); - return Err(CancelError::IpNotAllowed); - } - } - - // check if a VPC endpoint ID is coming in and if yes, if it's allowed - let access_blocks = auth_backend - .get_block_public_or_vpc_access(&ctx, &cancel_closure.user_info) + let info = &cancel_closure.user_info; + let access_controls = auth_backend + .get_endpoint_access_control(&ctx, &info.endpoint, &info.user) .await .map_err(|e| CancelError::AuthError(e.into()))?; - if check_vpc_allowed { - if access_blocks.vpc_access_blocked { - return Err(CancelError::AuthError(AuthError::NetworkNotAllowed)); - } - - let incoming_vpc_endpoint_id = match ctx.extra() { - None => return Err(CancelError::AuthError(AuthError::MissingVPCEndpointId)), - Some(ConnectionInfoExtra::Aws { vpce_id }) => vpce_id.to_string(), - Some(ConnectionInfoExtra::Azure { link_id }) => link_id.to_string(), - }; - - let allowed_vpc_endpoint_ids = auth_backend - .get_allowed_vpc_endpoint_ids(&ctx, &cancel_closure.user_info) - .await - .map_err(|e| CancelError::AuthError(e.into()))?; - // TODO: For now an empty VPC endpoint ID list means all are allowed. We should replace that. - if !allowed_vpc_endpoint_ids.is_empty() - && !allowed_vpc_endpoint_ids.contains(&incoming_vpc_endpoint_id) - { - return Err(CancelError::VpcEndpointIdNotAllowed); - } - } else if access_blocks.public_access_blocked { - return Err(CancelError::VpcEndpointIdNotAllowed); - } + access_controls.check(&ctx, check_ip_allowed, check_vpc_allowed)?; Metrics::get() .proxy diff --git a/proxy/src/compute.rs b/proxy/src/compute.rs index 26254beecf..2899f25129 100644 --- a/proxy/src/compute.rs +++ b/proxy/src/compute.rs @@ -8,7 +8,6 @@ use itertools::Itertools; use postgres_client::tls::MakeTlsConnect; use postgres_client::{CancelToken, RawConnection}; use postgres_protocol::message::backend::NoticeResponseBody; -use pq_proto::StartupMessageParams; use rustls::pki_types::InvalidDnsNameError; use thiserror::Error; use tokio::net::{TcpStream, lookup_host}; @@ -24,6 +23,7 @@ use crate::control_plane::errors::WakeComputeError; use crate::control_plane::messages::MetricsAuxInfo; use crate::error::{ReportableError, UserFacingError}; use crate::metrics::{Metrics, NumDbConnectionsGuard}; +use crate::pqproto::StartupMessageParams; use crate::proxy::neon_option; use crate::tls::postgres_rustls::MakeRustlsConnect; use crate::types::Host; diff --git a/proxy/src/config.rs b/proxy/src/config.rs index ad398c122c..a97339df9a 100644 --- a/proxy/src/config.rs +++ b/proxy/src/config.rs @@ -7,7 +7,6 @@ use arc_swap::ArcSwapOption; use clap::ValueEnum; use remote_storage::RemoteStorageConfig; -use crate::auth::backend::AuthRateLimiter; use crate::auth::backend::jwt::JwkCache; use crate::control_plane::locks::ApiLocks; use crate::rate_limiter::{RateBucketInfo, RateLimitAlgorithm, RateLimiterConfig}; @@ -65,9 +64,6 @@ pub struct HttpConfig { pub struct AuthenticationConfig { pub thread_pool: Arc, pub scram_protocol_timeout: tokio::time::Duration, - pub rate_limiter_enabled: bool, - pub 
rate_limiter: AuthRateLimiter, - pub rate_limit_ip_subnet: u8, pub ip_allowlist_check_enabled: bool, pub is_vpc_acccess_proxy: bool, pub jwks_cache: JwkCache, diff --git a/proxy/src/console_redirect_proxy.rs b/proxy/src/console_redirect_proxy.rs index e3184e20d1..9499aba61b 100644 --- a/proxy/src/console_redirect_proxy.rs +++ b/proxy/src/console_redirect_proxy.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use futures::{FutureExt, TryFutureExt}; -use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; +use tokio::io::{AsyncRead, AsyncWrite}; use tokio_util::sync::CancellationToken; use tracing::{Instrument, debug, error, info}; @@ -221,12 +221,10 @@ pub(crate) async fn handle_client( .await { Ok(auth_result) => auth_result, - Err(e) => { - return stream.throw_error(e, Some(ctx)).await?; - } + Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?, }; - let mut node = connect_to_compute( + let node = connect_to_compute( ctx, &TcpMechanism { user_info, @@ -238,7 +236,7 @@ pub(crate) async fn handle_client( config.wake_compute_retry_config, &config.connect_to_compute, ) - .or_else(|e| stream.throw_error(e, Some(ctx))) + .or_else(|e| async { Err(stream.throw_error(e, Some(ctx)).await) }) .await?; let cancellation_handler_clone = Arc::clone(&cancellation_handler); @@ -246,14 +244,8 @@ pub(crate) async fn handle_client( session.write_cancel_key(node.cancel_closure.clone())?; - prepare_client_connection(&node, *session.key(), &mut stream).await?; - - // Before proxy passing, forward to compute whatever data is left in the - // PqStream input buffer. Normally there is none, but our serverless npm - // driver in pipeline mode sends startup, password and first query - // immediately after opening the connection. - let (stream, read_buf) = stream.into_inner(); - node.stream.write_all(&read_buf).await?; + prepare_client_connection(&node, *session.key(), &mut stream); + let stream = stream.flush_and_into_inner().await?; Ok(Some(ProxyPassthrough { client: stream, diff --git a/proxy/src/context/mod.rs b/proxy/src/context/mod.rs index 79aaf22990..24268997ba 100644 --- a/proxy/src/context/mod.rs +++ b/proxy/src/context/mod.rs @@ -4,7 +4,6 @@ use std::net::IpAddr; use chrono::Utc; use once_cell::sync::OnceCell; -use pq_proto::StartupMessageParams; use smol_str::SmolStr; use tokio::sync::mpsc; use tracing::field::display; @@ -20,6 +19,7 @@ use crate::metrics::{ ConnectOutcome, InvalidEndpointsGroup, LatencyAccumulated, LatencyTimer, Metrics, Protocol, Waiting, }; +use crate::pqproto::StartupMessageParams; use crate::protocol2::{ConnectionInfo, ConnectionInfoExtra}; use crate::types::{DbName, EndpointId, RoleName}; @@ -370,6 +370,18 @@ impl RequestContext { } } + pub(crate) fn latency_timer_pause_at( + &self, + at: tokio::time::Instant, + waiting_for: Waiting, + ) -> LatencyTimerPause<'_> { + LatencyTimerPause { + ctx: self, + start: at, + waiting_for, + } + } + pub(crate) fn get_proxy_latency(&self) -> LatencyAccumulated { self.0 .try_lock() diff --git a/proxy/src/context/parquet.rs b/proxy/src/context/parquet.rs index f6250bcd17..c9d3905abd 100644 --- a/proxy/src/context/parquet.rs +++ b/proxy/src/context/parquet.rs @@ -11,7 +11,6 @@ use parquet::file::metadata::RowGroupMetaDataPtr; use parquet::file::properties::{DEFAULT_PAGE_SIZE, WriterProperties, WriterPropertiesPtr}; use parquet::file::writer::SerializedFileWriter; use parquet::record::RecordWriter; -use pq_proto::StartupMessageParams; use remote_storage::{GenericRemoteStorage, RemotePath, RemoteStorageConfig, TimeoutOrCancel}; use serde::ser::SerializeMap; use 
tokio::sync::mpsc; @@ -24,6 +23,7 @@ use super::{LOG_CHAN, RequestContextInner}; use crate::config::remote_storage_from_toml; use crate::context::LOG_CHAN_DISCONNECT; use crate::ext::TaskExt; +use crate::pqproto::StartupMessageParams; #[derive(clap::Args, Clone, Debug)] pub struct ParquetUploadArgs { diff --git a/proxy/src/control_plane/client/cplane_proxy_v1.rs b/proxy/src/control_plane/client/cplane_proxy_v1.rs index 2765aaa462..93f4ea6cf7 100644 --- a/proxy/src/control_plane/client/cplane_proxy_v1.rs +++ b/proxy/src/control_plane/client/cplane_proxy_v1.rs @@ -15,7 +15,6 @@ use tracing::{Instrument, debug, info, info_span, warn}; use super::super::messages::{ControlPlaneErrorMessage, GetEndpointAccessControl, WakeCompute}; use crate::auth::backend::ComputeUserInfo; use crate::auth::backend::jwt::AuthRule; -use crate::cache::Cached; use crate::context::RequestContext; use crate::control_plane::caches::ApiCaches; use crate::control_plane::errors::{ @@ -24,12 +23,12 @@ use crate::control_plane::errors::{ use crate::control_plane::locks::ApiLocks; use crate::control_plane::messages::{ColdStartInfo, EndpointJwksResponse, Reason}; use crate::control_plane::{ - AccessBlockerFlags, AuthInfo, AuthSecret, CachedAccessBlockerFlags, CachedAllowedIps, - CachedAllowedVpcEndpointIds, CachedNodeInfo, CachedRoleSecret, NodeInfo, + AccessBlockerFlags, AuthInfo, AuthSecret, CachedNodeInfo, EndpointAccessControl, NodeInfo, + RoleAccessControl, }; -use crate::metrics::{CacheOutcome, Metrics}; +use crate::metrics::Metrics; use crate::rate_limiter::WakeComputeRateLimiter; -use crate::types::{EndpointCacheKey, EndpointId}; +use crate::types::{EndpointCacheKey, EndpointId, RoleName}; use crate::{compute, http, scram}; pub(crate) const X_REQUEST_ID: HeaderName = HeaderName::from_static("x-request-id"); @@ -66,65 +65,34 @@ impl NeonControlPlaneClient { self.endpoint.url().as_str() } - async fn do_get_auth_info( - &self, - ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result { - if !self - .caches - .endpoints_cache - .is_valid(ctx, &user_info.endpoint.normalize()) - { - // TODO: refactor this because it's weird - // this is a failure to authenticate but we return Ok. 
- info!("endpoint is not valid, skipping the request"); - return Ok(AuthInfo::default()); - } - self.do_get_auth_req(user_info, &ctx.session_id(), Some(ctx)) - .await - } - async fn do_get_auth_req( &self, - user_info: &ComputeUserInfo, - session_id: &uuid::Uuid, - ctx: Option<&RequestContext>, + ctx: &RequestContext, + endpoint: &EndpointId, + role: &RoleName, ) -> Result { - let request_id: String = session_id.to_string(); - let application_name = if let Some(ctx) = ctx { - ctx.console_application_name() - } else { - "auth_cancellation".to_string() - }; - async { let request = self .endpoint .get_path("get_endpoint_access_control") - .header(X_REQUEST_ID, &request_id) + .header(X_REQUEST_ID, ctx.session_id().to_string()) .header(AUTHORIZATION, format!("Bearer {}", &self.jwt)) - .query(&[("session_id", session_id)]) + .query(&[("session_id", ctx.session_id())]) .query(&[ - ("application_name", application_name.as_str()), - ("endpointish", user_info.endpoint.as_str()), - ("role", user_info.user.as_str()), + ("application_name", ctx.console_application_name().as_str()), + ("endpointish", endpoint.as_str()), + ("role", role.as_str()), ]) .build()?; debug!(url = request.url().as_str(), "sending http request"); let start = Instant::now(); - let response = match ctx { - Some(ctx) => { - let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Cplane); - let rsp = self.endpoint.execute(request).await; - drop(pause); - rsp? - } - None => self.endpoint.execute(request).await?, + let response = { + let _pause = ctx.latency_timer_pause_at(start, crate::metrics::Waiting::Cplane); + self.endpoint.execute(request).await? }; - info!(duration = ?start.elapsed(), "received http response"); + let body = match parse_body::(response).await { Ok(body) => body, // Error 404 is special: it's ok not to have a secret. 
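// For context, a sketch of how a caller now combines the two lookups introduced
// above (anyhow only keeps the sketch short; real call sites keep their own error types):
async fn fetch_and_check(
    api: &impl ControlPlaneApi,
    ctx: &RequestContext,
    endpoint: &EndpointId,
    role: &RoleName,
) -> anyhow::Result<RoleAccessControl> {
    // Endpoint-level controls: allowed IPs, allowed VPC endpoint IDs, blocker flags.
    let ep_controls = api.get_endpoint_access_control(ctx, endpoint, role).await?;
    // Role-level controls: just the auth secret for this role.
    let role_controls = api.get_role_access_control(ctx, endpoint, role).await?;
    // Enforce the IP allow-list; VPC checks are left disabled in this sketch.
    ep_controls.check(ctx, true, false)?;
    Ok(role_controls)
}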
@@ -180,7 +148,7 @@ impl NeonControlPlaneClient { async fn do_get_endpoint_jwks( &self, ctx: &RequestContext, - endpoint: EndpointId, + endpoint: &EndpointId, ) -> Result, GetEndpointJwksError> { if !self .caches @@ -313,225 +281,104 @@ impl NeonControlPlaneClient { impl super::ControlPlaneApi for NeonControlPlaneClient { #[tracing::instrument(skip_all)] - async fn get_role_secret( + async fn get_role_access_control( &self, ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result { - let normalized_ep = &user_info.endpoint.normalize(); - let user = &user_info.user; - if let Some(role_secret) = self + endpoint: &EndpointId, + role: &RoleName, + ) -> Result { + let normalized_ep = &endpoint.normalize(); + if let Some(secret) = self .caches .project_info - .get_role_secret(normalized_ep, user) + .get_role_secret(normalized_ep, role) { - return Ok(role_secret); + return Ok(secret); } - let auth_info = self.do_get_auth_info(ctx, user_info).await?; - let account_id = auth_info.account_id; + + if !self.caches.endpoints_cache.is_valid(ctx, normalized_ep) { + info!("endpoint is not valid, skipping the request"); + return Err(GetAuthInfoError::UnknownEndpoint); + } + + let auth_info = self.do_get_auth_req(ctx, endpoint, role).await?; + + let control = EndpointAccessControl { + allowed_ips: Arc::new(auth_info.allowed_ips), + allowed_vpce: Arc::new(auth_info.allowed_vpc_endpoint_ids), + flags: auth_info.access_blocker_flags, + }; + let role_control = RoleAccessControl { + secret: auth_info.secret, + }; + if let Some(project_id) = auth_info.project_id { let normalized_ep_int = normalized_ep.into(); - self.caches.project_info.insert_role_secret( + + self.caches.project_info.insert_endpoint_access( + auth_info.account_id, project_id, normalized_ep_int, - user.into(), - auth_info.secret.clone(), - ); - self.caches.project_info.insert_allowed_ips( - project_id, - normalized_ep_int, - Arc::new(auth_info.allowed_ips), - ); - self.caches.project_info.insert_allowed_vpc_endpoint_ids( - account_id, - project_id, - normalized_ep_int, - Arc::new(auth_info.allowed_vpc_endpoint_ids), - ); - self.caches.project_info.insert_block_public_or_vpc_access( - project_id, - normalized_ep_int, - auth_info.access_blocker_flags, + role.into(), + control, + role_control.clone(), ); ctx.set_project_id(project_id); } - // When we just got a secret, we don't need to invalidate it. - Ok(Cached::new_uncached(auth_info.secret)) + + Ok(role_control) } - async fn get_allowed_ips( + #[tracing::instrument(skip_all)] + async fn get_endpoint_access_control( &self, ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result { - let normalized_ep = &user_info.endpoint.normalize(); - if let Some(allowed_ips) = self.caches.project_info.get_allowed_ips(normalized_ep) { - Metrics::get() - .proxy - .allowed_ips_cache_misses // TODO SR: Should we rename this variable to something like allowed_ip_cache_stats? 
- .inc(CacheOutcome::Hit); - return Ok(allowed_ips); + endpoint: &EndpointId, + role: &RoleName, + ) -> Result { + let normalized_ep = &endpoint.normalize(); + if let Some(control) = self.caches.project_info.get_endpoint_access(normalized_ep) { + return Ok(control); } - Metrics::get() - .proxy - .allowed_ips_cache_misses - .inc(CacheOutcome::Miss); - let auth_info = self.do_get_auth_info(ctx, user_info).await?; - let allowed_ips = Arc::new(auth_info.allowed_ips); - let allowed_vpc_endpoint_ids = Arc::new(auth_info.allowed_vpc_endpoint_ids); - let access_blocker_flags = auth_info.access_blocker_flags; - let user = &user_info.user; - let account_id = auth_info.account_id; + + if !self.caches.endpoints_cache.is_valid(ctx, normalized_ep) { + info!("endpoint is not valid, skipping the request"); + return Err(GetAuthInfoError::UnknownEndpoint); + } + + let auth_info = self.do_get_auth_req(ctx, endpoint, role).await?; + + let control = EndpointAccessControl { + allowed_ips: Arc::new(auth_info.allowed_ips), + allowed_vpce: Arc::new(auth_info.allowed_vpc_endpoint_ids), + flags: auth_info.access_blocker_flags, + }; + let role_control = RoleAccessControl { + secret: auth_info.secret, + }; + if let Some(project_id) = auth_info.project_id { let normalized_ep_int = normalized_ep.into(); - self.caches.project_info.insert_role_secret( + + self.caches.project_info.insert_endpoint_access( + auth_info.account_id, project_id, normalized_ep_int, - user.into(), - auth_info.secret.clone(), - ); - self.caches.project_info.insert_allowed_ips( - project_id, - normalized_ep_int, - allowed_ips.clone(), - ); - self.caches.project_info.insert_allowed_vpc_endpoint_ids( - account_id, - project_id, - normalized_ep_int, - allowed_vpc_endpoint_ids.clone(), - ); - self.caches.project_info.insert_block_public_or_vpc_access( - project_id, - normalized_ep_int, - access_blocker_flags, + role.into(), + control.clone(), + role_control, ); ctx.set_project_id(project_id); } - Ok(Cached::new_uncached(allowed_ips)) - } - async fn get_allowed_vpc_endpoint_ids( - &self, - ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result { - let normalized_ep = &user_info.endpoint.normalize(); - if let Some(allowed_vpc_endpoint_ids) = self - .caches - .project_info - .get_allowed_vpc_endpoint_ids(normalized_ep) - { - Metrics::get() - .proxy - .vpc_endpoint_id_cache_stats - .inc(CacheOutcome::Hit); - return Ok(allowed_vpc_endpoint_ids); - } - - Metrics::get() - .proxy - .vpc_endpoint_id_cache_stats - .inc(CacheOutcome::Miss); - - let auth_info = self.do_get_auth_info(ctx, user_info).await?; - let allowed_ips = Arc::new(auth_info.allowed_ips); - let allowed_vpc_endpoint_ids = Arc::new(auth_info.allowed_vpc_endpoint_ids); - let access_blocker_flags = auth_info.access_blocker_flags; - let user = &user_info.user; - let account_id = auth_info.account_id; - if let Some(project_id) = auth_info.project_id { - let normalized_ep_int = normalized_ep.into(); - self.caches.project_info.insert_role_secret( - project_id, - normalized_ep_int, - user.into(), - auth_info.secret.clone(), - ); - self.caches.project_info.insert_allowed_ips( - project_id, - normalized_ep_int, - allowed_ips.clone(), - ); - self.caches.project_info.insert_allowed_vpc_endpoint_ids( - account_id, - project_id, - normalized_ep_int, - allowed_vpc_endpoint_ids.clone(), - ); - self.caches.project_info.insert_block_public_or_vpc_access( - project_id, - normalized_ep_int, - access_blocker_flags, - ); - ctx.set_project_id(project_id); - } - 
Ok(Cached::new_uncached(allowed_vpc_endpoint_ids)) - } - - async fn get_block_public_or_vpc_access( - &self, - ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result { - let normalized_ep = &user_info.endpoint.normalize(); - if let Some(access_blocker_flags) = self - .caches - .project_info - .get_block_public_or_vpc_access(normalized_ep) - { - Metrics::get() - .proxy - .access_blocker_flags_cache_stats - .inc(CacheOutcome::Hit); - return Ok(access_blocker_flags); - } - - Metrics::get() - .proxy - .access_blocker_flags_cache_stats - .inc(CacheOutcome::Miss); - - let auth_info = self.do_get_auth_info(ctx, user_info).await?; - let allowed_ips = Arc::new(auth_info.allowed_ips); - let allowed_vpc_endpoint_ids = Arc::new(auth_info.allowed_vpc_endpoint_ids); - let access_blocker_flags = auth_info.access_blocker_flags; - let user = &user_info.user; - let account_id = auth_info.account_id; - if let Some(project_id) = auth_info.project_id { - let normalized_ep_int = normalized_ep.into(); - self.caches.project_info.insert_role_secret( - project_id, - normalized_ep_int, - user.into(), - auth_info.secret.clone(), - ); - self.caches.project_info.insert_allowed_ips( - project_id, - normalized_ep_int, - allowed_ips.clone(), - ); - self.caches.project_info.insert_allowed_vpc_endpoint_ids( - account_id, - project_id, - normalized_ep_int, - allowed_vpc_endpoint_ids.clone(), - ); - self.caches.project_info.insert_block_public_or_vpc_access( - project_id, - normalized_ep_int, - access_blocker_flags.clone(), - ); - ctx.set_project_id(project_id); - } - Ok(Cached::new_uncached(access_blocker_flags)) + Ok(control) } #[tracing::instrument(skip_all)] async fn get_endpoint_jwks( &self, ctx: &RequestContext, - endpoint: EndpointId, + endpoint: &EndpointId, ) -> Result, GetEndpointJwksError> { self.do_get_endpoint_jwks(ctx, endpoint).await } diff --git a/proxy/src/control_plane/client/mock.rs b/proxy/src/control_plane/client/mock.rs index d3ab4abd0b..ece7153fce 100644 --- a/proxy/src/control_plane/client/mock.rs +++ b/proxy/src/control_plane/client/mock.rs @@ -15,14 +15,14 @@ use crate::auth::backend::ComputeUserInfo; use crate::auth::backend::jwt::AuthRule; use crate::cache::Cached; use crate::context::RequestContext; -use crate::control_plane::client::{ - CachedAllowedIps, CachedAllowedVpcEndpointIds, CachedRoleSecret, -}; use crate::control_plane::errors::{ ControlPlaneError, GetAuthInfoError, GetEndpointJwksError, WakeComputeError, }; use crate::control_plane::messages::MetricsAuxInfo; -use crate::control_plane::{AccessBlockerFlags, AuthInfo, AuthSecret, CachedNodeInfo, NodeInfo}; +use crate::control_plane::{ + AccessBlockerFlags, AuthInfo, AuthSecret, CachedNodeInfo, EndpointAccessControl, NodeInfo, + RoleAccessControl, +}; use crate::intern::RoleNameInt; use crate::types::{BranchId, EndpointId, ProjectId, RoleName}; use crate::url::ApiUrl; @@ -66,7 +66,8 @@ impl MockControlPlane { async fn do_get_auth_info( &self, - user_info: &ComputeUserInfo, + endpoint: &EndpointId, + role: &RoleName, ) -> Result { let (secret, allowed_ips) = async { // Perhaps we could persist this connection, but then we'd have to @@ -80,7 +81,7 @@ impl MockControlPlane { let secret = if let Some(entry) = get_execute_postgres_query( &client, "select rolpassword from pg_catalog.pg_authid where rolname = $1", - &[&&*user_info.user], + &[&role.as_str()], "rolpassword", ) .await? 
@@ -89,7 +90,7 @@ impl MockControlPlane { let secret = scram::ServerSecret::parse(&entry).map(AuthSecret::Scram); secret.or_else(|| parse_md5(&entry).map(AuthSecret::Md5)) } else { - warn!("user '{}' does not exist", user_info.user); + warn!("user '{role}' does not exist"); None }; @@ -97,7 +98,7 @@ impl MockControlPlane { match get_execute_postgres_query( &client, "select allowed_ips from neon_control_plane.endpoints where endpoint_id = $1", - &[&user_info.endpoint.as_str()], + &[&endpoint.as_str()], "allowed_ips", ) .await? @@ -133,7 +134,7 @@ impl MockControlPlane { async fn do_get_endpoint_jwks( &self, - endpoint: EndpointId, + endpoint: &EndpointId, ) -> Result, GetEndpointJwksError> { let (client, connection) = tokio_postgres::connect(self.endpoint.as_str(), tokio_postgres::NoTls).await?; @@ -222,53 +223,36 @@ async fn get_execute_postgres_query( } impl super::ControlPlaneApi for MockControlPlane { - #[tracing::instrument(skip_all)] - async fn get_role_secret( + async fn get_endpoint_access_control( &self, _ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result { - Ok(CachedRoleSecret::new_uncached( - self.do_get_auth_info(user_info).await?.secret, - )) + endpoint: &EndpointId, + role: &RoleName, + ) -> Result { + let info = self.do_get_auth_info(endpoint, role).await?; + Ok(EndpointAccessControl { + allowed_ips: Arc::new(info.allowed_ips), + allowed_vpce: Arc::new(info.allowed_vpc_endpoint_ids), + flags: info.access_blocker_flags, + }) } - async fn get_allowed_ips( + async fn get_role_access_control( &self, _ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result { - Ok(Cached::new_uncached(Arc::new( - self.do_get_auth_info(user_info).await?.allowed_ips, - ))) - } - - async fn get_allowed_vpc_endpoint_ids( - &self, - _ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result { - Ok(Cached::new_uncached(Arc::new( - self.do_get_auth_info(user_info) - .await? 
- .allowed_vpc_endpoint_ids, - ))) - } - - async fn get_block_public_or_vpc_access( - &self, - _ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result { - Ok(Cached::new_uncached( - self.do_get_auth_info(user_info).await?.access_blocker_flags, - )) + endpoint: &EndpointId, + role: &RoleName, + ) -> Result { + let info = self.do_get_auth_info(endpoint, role).await?; + Ok(RoleAccessControl { + secret: info.secret, + }) } async fn get_endpoint_jwks( &self, _ctx: &RequestContext, - endpoint: EndpointId, + endpoint: &EndpointId, ) -> Result, GetEndpointJwksError> { self.do_get_endpoint_jwks(endpoint).await } diff --git a/proxy/src/control_plane/client/mod.rs b/proxy/src/control_plane/client/mod.rs index 746595de38..9b9d1e25ea 100644 --- a/proxy/src/control_plane/client/mod.rs +++ b/proxy/src/control_plane/client/mod.rs @@ -16,15 +16,14 @@ use crate::cache::endpoints::EndpointsCache; use crate::cache::project_info::ProjectInfoCacheImpl; use crate::config::{CacheOptions, EndpointCacheConfig, ProjectInfoCacheOptions}; use crate::context::RequestContext; -use crate::control_plane::{ - CachedAccessBlockerFlags, CachedAllowedIps, CachedAllowedVpcEndpointIds, CachedNodeInfo, - CachedRoleSecret, ControlPlaneApi, NodeInfoCache, errors, -}; +use crate::control_plane::{CachedNodeInfo, ControlPlaneApi, NodeInfoCache, errors}; use crate::error::ReportableError; use crate::metrics::ApiLockMetrics; use crate::rate_limiter::{DynamicLimiter, Outcome, RateLimiterConfig, Token}; use crate::types::EndpointId; +use super::{EndpointAccessControl, RoleAccessControl}; + #[non_exhaustive] #[derive(Clone)] pub enum ControlPlaneClient { @@ -40,68 +39,42 @@ pub enum ControlPlaneClient { } impl ControlPlaneApi for ControlPlaneClient { - async fn get_role_secret( + async fn get_role_access_control( &self, ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result { + endpoint: &EndpointId, + role: &crate::types::RoleName, + ) -> Result { match self { - Self::ProxyV1(api) => api.get_role_secret(ctx, user_info).await, + Self::ProxyV1(api) => api.get_role_access_control(ctx, endpoint, role).await, #[cfg(any(test, feature = "testing"))] - Self::PostgresMock(api) => api.get_role_secret(ctx, user_info).await, + Self::PostgresMock(api) => api.get_role_access_control(ctx, endpoint, role).await, #[cfg(test)] - Self::Test(_) => { + Self::Test(_api) => { unreachable!("this function should never be called in the test backend") } } } - async fn get_allowed_ips( + async fn get_endpoint_access_control( &self, ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result { + endpoint: &EndpointId, + role: &crate::types::RoleName, + ) -> Result { match self { - Self::ProxyV1(api) => api.get_allowed_ips(ctx, user_info).await, + Self::ProxyV1(api) => api.get_endpoint_access_control(ctx, endpoint, role).await, #[cfg(any(test, feature = "testing"))] - Self::PostgresMock(api) => api.get_allowed_ips(ctx, user_info).await, + Self::PostgresMock(api) => api.get_endpoint_access_control(ctx, endpoint, role).await, #[cfg(test)] - Self::Test(api) => api.get_allowed_ips(), - } - } - - async fn get_allowed_vpc_endpoint_ids( - &self, - ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result { - match self { - Self::ProxyV1(api) => api.get_allowed_vpc_endpoint_ids(ctx, user_info).await, - #[cfg(any(test, feature = "testing"))] - Self::PostgresMock(api) => api.get_allowed_vpc_endpoint_ids(ctx, user_info).await, - #[cfg(test)] - Self::Test(api) => api.get_allowed_vpc_endpoint_ids(), - } - } - - async fn 
get_block_public_or_vpc_access(
-        &self,
-        ctx: &RequestContext,
-        user_info: &ComputeUserInfo,
-    ) -> Result {
-        match self {
-            Self::ProxyV1(api) => api.get_block_public_or_vpc_access(ctx, user_info).await,
-            #[cfg(any(test, feature = "testing"))]
-            Self::PostgresMock(api) => api.get_block_public_or_vpc_access(ctx, user_info).await,
-            #[cfg(test)]
-            Self::Test(api) => api.get_block_public_or_vpc_access(),
+            Self::Test(api) => api.get_access_control(),
         }
     }

     async fn get_endpoint_jwks(
         &self,
         ctx: &RequestContext,
-        endpoint: EndpointId,
+        endpoint: &EndpointId,
     ) -> Result, errors::GetEndpointJwksError> {
         match self {
             Self::ProxyV1(api) => api.get_endpoint_jwks(ctx, endpoint).await,
@@ -131,15 +104,7 @@ impl ControlPlaneApi for ControlPlaneClient {
 pub(crate) trait TestControlPlaneClient: Send + Sync + 'static {
     fn wake_compute(&self) -> Result;

-    fn get_allowed_ips(&self) -> Result;
-
-    fn get_allowed_vpc_endpoint_ids(
-        &self,
-    ) -> Result;
-
-    fn get_block_public_or_vpc_access(
-        &self,
-    ) -> Result;
+    fn get_access_control(&self) -> Result;

     fn dyn_clone(&self) -> Box;
 }
@@ -309,7 +274,7 @@ impl FetchAuthRules for ControlPlaneClient {
         ctx: &RequestContext,
         endpoint: EndpointId,
     ) -> Result, FetchAuthRulesError> {
-        self.get_endpoint_jwks(ctx, endpoint)
+        self.get_endpoint_jwks(ctx, &endpoint)
             .await
             .map_err(FetchAuthRulesError::GetEndpointJwks)
     }
diff --git a/proxy/src/control_plane/errors.rs b/proxy/src/control_plane/errors.rs
index 850d061333..77312c89c5 100644
--- a/proxy/src/control_plane/errors.rs
+++ b/proxy/src/control_plane/errors.rs
@@ -99,6 +99,10 @@ pub(crate) enum GetAuthInfoError {
     #[error(transparent)]
     ApiError(ControlPlaneError),
+
+    /// Proxy does not know about the endpoint in advance
+    #[error("endpoint not found in endpoint cache")]
+    UnknownEndpoint,
 }

 // This allows more useful interactions than `#[from]`.
@@ -115,6 +119,8 @@ impl UserFacingError for GetAuthInfoError {
             Self::BadSecret => REQUEST_FAILED.to_owned(),
             // However, API might return a meaningful error.
             Self::ApiError(e) => e.to_string_client(),
+            // pretend that the control plane returned an error.
+            Self::UnknownEndpoint => REQUEST_FAILED.to_owned(),
         }
     }
 }
@@ -124,6 +130,8 @@ impl ReportableError for GetAuthInfoError {
         match self {
             Self::BadSecret => crate::error::ErrorKind::ControlPlane,
             Self::ApiError(_) => crate::error::ErrorKind::ControlPlane,
+            // we only apply endpoint filtering if control plane is under high load.
+            Self::UnknownEndpoint => crate::error::ErrorKind::ServiceRateLimit,
         }
     }
 }
diff --git a/proxy/src/control_plane/mod.rs b/proxy/src/control_plane/mod.rs
index d592223be1..7ff093d9dc 100644
--- a/proxy/src/control_plane/mod.rs
+++ b/proxy/src/control_plane/mod.rs
@@ -11,16 +11,16 @@ pub(crate) mod errors;
 use std::sync::Arc;

-use crate::auth::IpPattern;
 use crate::auth::backend::jwt::AuthRule;
 use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo};
-use crate::cache::project_info::ProjectInfoCacheImpl;
+use crate::auth::{AuthError, IpPattern, check_peer_addr_is_in_list};
 use crate::cache::{Cached, TimedLru};
 use crate::config::ComputeConfig;
 use crate::context::RequestContext;
 use crate::control_plane::messages::{ControlPlaneErrorMessage, MetricsAuxInfo};
 use crate::intern::{AccountIdInt, ProjectIdInt};
-use crate::types::{EndpointCacheKey, EndpointId};
+use crate::protocol2::ConnectionInfoExtra;
+use crate::types::{EndpointCacheKey, EndpointId, RoleName};
 use crate::{compute, scram};

 /// Various cache-related types.
@@ -101,7 +101,7 @@ impl NodeInfo { } } -#[derive(Clone, Default, Eq, PartialEq, Debug)] +#[derive(Copy, Clone, Default)] pub(crate) struct AccessBlockerFlags { pub public_access_blocked: bool, pub vpc_access_blocked: bool, @@ -110,47 +110,78 @@ pub(crate) struct AccessBlockerFlags { pub(crate) type NodeInfoCache = TimedLru>>; pub(crate) type CachedNodeInfo = Cached<&'static NodeInfoCache, NodeInfo>; -pub(crate) type CachedRoleSecret = Cached<&'static ProjectInfoCacheImpl, Option>; -pub(crate) type CachedAllowedIps = Cached<&'static ProjectInfoCacheImpl, Arc>>; -pub(crate) type CachedAllowedVpcEndpointIds = - Cached<&'static ProjectInfoCacheImpl, Arc>>; -pub(crate) type CachedAccessBlockerFlags = - Cached<&'static ProjectInfoCacheImpl, AccessBlockerFlags>; + +#[derive(Clone)] +pub struct RoleAccessControl { + pub secret: Option, +} + +#[derive(Clone)] +pub struct EndpointAccessControl { + pub allowed_ips: Arc>, + pub allowed_vpce: Arc>, + pub flags: AccessBlockerFlags, +} + +impl EndpointAccessControl { + pub fn check( + &self, + ctx: &RequestContext, + check_ip_allowed: bool, + check_vpc_allowed: bool, + ) -> Result<(), AuthError> { + if check_ip_allowed && !check_peer_addr_is_in_list(&ctx.peer_addr(), &self.allowed_ips) { + return Err(AuthError::IpAddressNotAllowed(ctx.peer_addr())); + } + + // check if a VPC endpoint ID is coming in and if yes, if it's allowed + if check_vpc_allowed { + if self.flags.vpc_access_blocked { + return Err(AuthError::NetworkNotAllowed); + } + + let incoming_vpc_endpoint_id = match ctx.extra() { + None => return Err(AuthError::MissingVPCEndpointId), + Some(ConnectionInfoExtra::Aws { vpce_id }) => vpce_id.to_string(), + Some(ConnectionInfoExtra::Azure { link_id }) => link_id.to_string(), + }; + + let vpce = &self.allowed_vpce; + // TODO: For now an empty VPC endpoint ID list means all are allowed. We should replace that. + if !vpce.is_empty() && !vpce.contains(&incoming_vpc_endpoint_id) { + return Err(AuthError::vpc_endpoint_id_not_allowed( + incoming_vpc_endpoint_id, + )); + } + } else if self.flags.public_access_blocked { + return Err(AuthError::NetworkNotAllowed); + } + + Ok(()) + } +} /// This will allocate per each call, but the http requests alone /// already require a few allocations, so it should be fine. pub(crate) trait ControlPlaneApi { - /// Get the client's auth secret for authentication. - /// Returns option because user not found situation is special. - /// We still have to mock the scram to avoid leaking information that user doesn't exist. - async fn get_role_secret( + async fn get_role_access_control( &self, ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result; + endpoint: &EndpointId, + role: &RoleName, + ) -> Result; - async fn get_allowed_ips( + async fn get_endpoint_access_control( &self, ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result; - - async fn get_allowed_vpc_endpoint_ids( - &self, - ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result; - - async fn get_block_public_or_vpc_access( - &self, - ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result; + endpoint: &EndpointId, + role: &RoleName, + ) -> Result; async fn get_endpoint_jwks( &self, ctx: &RequestContext, - endpoint: EndpointId, + endpoint: &EndpointId, ) -> Result, errors::GetEndpointJwksError>; /// Wake up the compute node and return the corresponding connection info. 
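// Both the authentication and cancellation paths now funnel through
// EndpointAccessControl::check, so the per-connection policy reduces to a call
// like the sketch below (field values invented; `ctx` is assumed to be the live
// RequestContext carrying the peer address and any VPC endpoint ID).
fn enforce_example(ctx: &RequestContext) -> Result<(), AuthError> {
    let controls = EndpointAccessControl {
        allowed_ips: Arc::new(vec!["127.0.0.1".parse().unwrap()]),
        allowed_vpce: Arc::new(vec![]), // an empty list currently means "allow all" (see TODO above)
        flags: AccessBlockerFlags::default(),
    };
    // Public connection: enforce the IP allow-list, skip the VPC endpoint check.
    controls.check(ctx, true, false)
}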
diff --git a/proxy/src/lib.rs b/proxy/src/lib.rs index d1f8430b8a..d65d056585 100644 --- a/proxy/src/lib.rs +++ b/proxy/src/lib.rs @@ -92,6 +92,7 @@ mod logging; mod metrics; mod parse; mod pglb; +mod pqproto; mod protocol2; mod proxy; mod rate_limiter; diff --git a/proxy/src/pqproto.rs b/proxy/src/pqproto.rs new file mode 100644 index 0000000000..d68d9f9474 --- /dev/null +++ b/proxy/src/pqproto.rs @@ -0,0 +1,693 @@ +//! Postgres protocol codec +//! +//! + +use std::fmt; +use std::io::{self, Cursor}; + +use bytes::{Buf, BufMut}; +use itertools::Itertools; +use rand::distributions::{Distribution, Standard}; +use tokio::io::{AsyncRead, AsyncReadExt}; +use zerocopy::{FromBytes, Immutable, IntoBytes, big_endian}; + +pub type ErrorCode = [u8; 5]; + +pub const FE_PASSWORD_MESSAGE: u8 = b'p'; + +pub const SQLSTATE_INTERNAL_ERROR: [u8; 5] = *b"XX000"; + +/// The protocol version number. +/// +/// The most significant 16 bits are the major version number (3 for the protocol described here). +/// The least significant 16 bits are the minor version number (0 for the protocol described here). +/// +#[derive(Clone, Copy, PartialEq, PartialOrd, FromBytes, IntoBytes, Immutable)] +#[repr(C)] +pub struct ProtocolVersion { + major: big_endian::U16, + minor: big_endian::U16, +} + +impl ProtocolVersion { + pub const fn new(major: u16, minor: u16) -> Self { + Self { + major: big_endian::U16::new(major), + minor: big_endian::U16::new(minor), + } + } + pub const fn minor(self) -> u16 { + self.minor.get() + } + pub const fn major(self) -> u16 { + self.major.get() + } +} + +impl fmt::Debug for ProtocolVersion { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_list() + .entry(&self.major()) + .entry(&self.minor()) + .finish() + } +} + +/// read the type from the stream using zerocopy. +/// +/// not cancel safe. +macro_rules! read { + ($s:expr => $t:ty) => {{ + // cannot be implemented as a function due to lack of const-generic-expr + let mut buf = [0; size_of::<$t>()]; + $s.read_exact(&mut buf).await?; + let res: $t = zerocopy::transmute!(buf); + res + }}; +} + +pub async fn read_startup(stream: &mut S) -> io::Result +where + S: AsyncRead + Unpin, +{ + /// + const MAX_STARTUP_PACKET_LENGTH: usize = 10000; + const RESERVED_INVALID_MAJOR_VERSION: u16 = 1234; + /// + const CANCEL_REQUEST_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5678); + /// + const NEGOTIATE_SSL_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5679); + /// + const NEGOTIATE_GSS_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5680); + + /// This first reads the startup message header, is 8 bytes. + /// The first 4 bytes is a big-endian message length, and the next 4 bytes is a version number. + /// + /// The length value is inclusive of the header. For example, + /// an empty message will always have length 8. + #[derive(Clone, Copy, FromBytes, IntoBytes, Immutable)] + #[repr(C)] + struct StartupHeader { + len: big_endian::U32, + version: ProtocolVersion, + } + + let header = read!(stream => StartupHeader); + + // + // First byte indicates standard SSL handshake message + // (It can't be a Postgres startup length because in network byte order + // that would be a startup packet hundreds of megabytes long) + if header.as_bytes()[0] == 0x16 { + return Ok(FeStartupPacket::SslRequest { + // The bytes we read for the header are actually part of a TLS ClientHello. + // In theory, if the ClientHello was < 8 bytes we would fail with EOF before we get here. 
+ // In practice though, I see no world where a ClientHello is less than 8 bytes + // since it includes ephemeral keys etc. + direct: Some(zerocopy::transmute!(header)), + }); + } + + let Some(len) = (header.len.get() as usize).checked_sub(8) else { + return Err(io::Error::other(format!( + "invalid startup message length {}, must be at least 8.", + header.len, + ))); + }; + + // TODO: add a histogram for startup packet lengths + if len > MAX_STARTUP_PACKET_LENGTH { + tracing::warn!("large startup message detected: {len} bytes"); + return Err(io::Error::other(format!( + "invalid startup message length {len}" + ))); + } + + match header.version { + // + CANCEL_REQUEST_CODE => { + if len != 8 { + return Err(io::Error::other( + "CancelRequest message is malformed, backend PID / secret key missing", + )); + } + + Ok(FeStartupPacket::CancelRequest( + read!(stream => CancelKeyData), + )) + } + // + NEGOTIATE_SSL_CODE => { + // Requested upgrade to SSL (aka TLS) + Ok(FeStartupPacket::SslRequest { direct: None }) + } + NEGOTIATE_GSS_CODE => { + // Requested upgrade to GSSAPI + Ok(FeStartupPacket::GssEncRequest) + } + version if version.major() == RESERVED_INVALID_MAJOR_VERSION => Err(io::Error::other( + format!("Unrecognized request code {version:?}"), + )), + // StartupMessage + version => { + // The protocol version number is followed by one or more pairs of parameter name and value strings. + // A zero byte is required as a terminator after the last name/value pair. + // Parameters can appear in any order. user is required, others are optional. + + let mut buf = vec![0; len]; + stream.read_exact(&mut buf).await?; + + if buf.pop() != Some(b'\0') { + return Err(io::Error::other( + "StartupMessage params: missing null terminator", + )); + } + + // TODO: Don't do this. + // There's no guarantee that these messages are utf8, + // but they usually happen to be simple ascii. + let params = String::from_utf8(buf) + .map_err(|_| io::Error::other("StartupMessage params: invalid utf-8"))?; + + Ok(FeStartupPacket::StartupMessage { + version, + params: StartupMessageParams { params }, + }) + } + } +} + +/// Read a raw postgres packet, which will respect the max length requested. +/// +/// This returns the message tag, as well as the message body. The message +/// body is written into `buf`, and it is otherwise completely overwritten. +/// +/// This is not cancel safe. +pub async fn read_message<'a, S>( + stream: &mut S, + buf: &'a mut Vec, + max: usize, +) -> io::Result<(u8, &'a mut [u8])> +where + S: AsyncRead + Unpin, +{ + /// This first reads the header, which for regular messages in the 3.0 protocol is 5 bytes. + /// The first byte is a message tag, and the next 4 bytes is a big-endian length. + /// + /// Awkwardly, the length value is inclusive of itself, but not of the tag. For example, + /// an empty message will always have length 4. + #[derive(Clone, Copy, FromBytes)] + #[repr(C)] + struct Header { + tag: u8, + len: big_endian::U32, + } + + let header = read!(stream => Header); + + // as described above, the length must be at least 4. + let Some(len) = (header.len.get() as usize).checked_sub(4) else { + return Err(io::Error::other(format!( + "invalid startup message length {}, must be at least 4.", + header.len, + ))); + }; + + // TODO: add a histogram for message lengths + + // check if the message exceeds our desired max. 
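    // (Worked example: a client password message is sent as tag b'p' plus a
    // length that counts the 4 length bytes but not the tag, so a 100-byte
    // SASL response arrives with header.len == 104 and len == 100, and it is
    // that body length which is compared against `max` below.)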
+ if len > max { + tracing::warn!("large postgres message detected: {len} bytes"); + return Err(io::Error::other(format!("invalid message length {len}"))); + } + + // read in our entire message. + buf.resize(len, 0); + stream.read_exact(buf).await?; + + Ok((header.tag, buf)) +} + +pub struct WriteBuf(Cursor>); + +impl Buf for WriteBuf { + #[inline] + fn remaining(&self) -> usize { + self.0.remaining() + } + + #[inline] + fn chunk(&self) -> &[u8] { + self.0.chunk() + } + + #[inline] + fn advance(&mut self, cnt: usize) { + self.0.advance(cnt); + } +} + +impl WriteBuf { + pub const fn new() -> Self { + Self(Cursor::new(Vec::new())) + } + + /// Use a heuristic to determine if we should shrink the write buffer. + #[inline] + fn should_shrink(&self) -> bool { + let n = self.0.position() as usize; + let len = self.0.get_ref().len(); + + // the unused space at the front of our buffer is 2x the size of our filled portion. + n + n > len + } + + /// Shrink the write buffer so that subsequent writes have more spare capacity. + #[cold] + fn shrink(&mut self) { + let n = self.0.position() as usize; + let buf = self.0.get_mut(); + + // buf repr: + // [----unused------|-----filled-----|-----uninit-----] + // ^ n ^ buf.len() ^ buf.capacity() + let filled = n..buf.len(); + let filled_len = filled.len(); + buf.copy_within(filled, 0); + buf.truncate(filled_len); + self.0.set_position(0); + } + + /// clear the write buffer. + pub fn reset(&mut self) { + let buf = self.0.get_mut(); + buf.clear(); + self.0.set_position(0); + } + + /// Write a raw message to the internal buffer. + /// + /// The size_hint value is only a hint for reserving space. It's ok if it's incorrect, since + /// we calculate the length after the fact. + pub fn write_raw(&mut self, size_hint: usize, tag: u8, f: impl FnOnce(&mut Vec)) { + if self.should_shrink() { + self.shrink(); + } + + let buf = self.0.get_mut(); + buf.reserve(5 + size_hint); + + buf.push(tag); + let start = buf.len(); + buf.extend_from_slice(&[0, 0, 0, 0]); + + f(buf); + + let end = buf.len(); + let len = (end - start) as u32; + buf[start..start + 4].copy_from_slice(&len.to_be_bytes()); + } + + /// Write an encryption response message. + pub fn encryption(&mut self, m: u8) { + self.0.get_mut().push(m); + } + + pub fn write_error(&mut self, msg: &str, error_code: ErrorCode) { + self.shrink(); + + // + // + // "SERROR\0CXXXXX\0M\0\0".len() == 17 + self.write_raw(17 + msg.len(), b'E', |buf| { + // Severity: ERROR + buf.put_slice(b"SERROR\0"); + + // Code: error_code + buf.put_u8(b'C'); + buf.put_slice(&error_code); + buf.put_u8(0); + + // Message: msg + buf.put_u8(b'M'); + buf.put_slice(msg.as_bytes()); + buf.put_u8(0); + + // End. + buf.put_u8(0); + }); + } +} + +#[derive(Debug)] +pub enum FeStartupPacket { + CancelRequest(CancelKeyData), + SslRequest { + direct: Option<[u8; 8]>, + }, + GssEncRequest, + StartupMessage { + version: ProtocolVersion, + params: StartupMessageParams, + }, +} + +#[derive(Debug, Clone, Default)] +pub struct StartupMessageParams { + pub params: String, +} + +impl StartupMessageParams { + /// Get parameter's value by its name. + pub fn get(&self, name: &str) -> Option<&str> { + self.iter().find_map(|(k, v)| (k == name).then_some(v)) + } + + /// Split command-line options according to PostgreSQL's logic, + /// taking into account all escape sequences but leaving them as-is. + /// [`None`] means that there's no `options` in [`Self`]. 
+ pub fn options_raw(&self) -> Option> { + self.get("options").map(Self::parse_options_raw) + } + + /// Split command-line options according to PostgreSQL's logic, + /// taking into account all escape sequences but leaving them as-is. + pub fn parse_options_raw(input: &str) -> impl Iterator { + // See `postgres: pg_split_opts`. + let mut last_was_escape = false; + input + .split(move |c: char| { + // We split by non-escaped whitespace symbols. + let should_split = c.is_ascii_whitespace() && !last_was_escape; + last_was_escape = c == '\\' && !last_was_escape; + should_split + }) + .filter(|s| !s.is_empty()) + } + + /// Iterate through key-value pairs in an arbitrary order. + pub fn iter(&self) -> impl Iterator { + self.params.split_terminator('\0').tuples() + } + + // This function is mostly useful in tests. + #[cfg(test)] + pub fn new<'a, const N: usize>(pairs: [(&'a str, &'a str); N]) -> Self { + let mut b = Self { + params: String::new(), + }; + for (k, v) in pairs { + b.insert(k, v); + } + b + } + + /// Set parameter's value by its name. + /// name and value must not contain a \0 byte + pub fn insert(&mut self, name: &str, value: &str) { + self.params.reserve(name.len() + value.len() + 2); + self.params.push_str(name); + self.params.push('\0'); + self.params.push_str(value); + self.params.push('\0'); + } +} + +/// Cancel keys usually are represented as PID+SecretKey, but to proxy they're just +/// opaque bytes. +#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy, FromBytes, IntoBytes, Immutable)] +pub struct CancelKeyData(pub big_endian::U64); + +pub fn id_to_cancel_key(id: u64) -> CancelKeyData { + CancelKeyData(big_endian::U64::new(id)) +} + +impl fmt::Display for CancelKeyData { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let id = self.0; + f.debug_tuple("CancelKeyData") + .field(&format_args!("{id:x}")) + .finish() + } +} +impl Distribution for Standard { + fn sample(&self, rng: &mut R) -> CancelKeyData { + id_to_cancel_key(rng.r#gen()) + } +} + +pub enum BeMessage<'a> { + AuthenticationOk, + AuthenticationSasl(BeAuthenticationSaslMessage<'a>), + AuthenticationCleartextPassword, + BackendKeyData(CancelKeyData), + ParameterStatus { + name: &'a [u8], + value: &'a [u8], + }, + ReadyForQuery, + NoticeResponse(&'a str), + NegotiateProtocolVersion { + version: ProtocolVersion, + options: &'a [&'a str], + }, +} + +#[derive(Debug)] +pub enum BeAuthenticationSaslMessage<'a> { + Methods(&'a [&'a str]), + Continue(&'a [u8]), + Final(&'a [u8]), +} + +impl BeMessage<'_> { + /// Write the message into an internal buffer + pub fn write_message(self, buf: &mut WriteBuf) { + match self { + // + BeMessage::AuthenticationOk => { + buf.write_raw(1, b'R', |buf| buf.put_i32(0)); + } + // + BeMessage::AuthenticationCleartextPassword => { + buf.write_raw(1, b'R', |buf| buf.put_i32(3)); + } + + // + BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(methods)) => { + let len: usize = methods.iter().map(|m| m.len() + 1).sum(); + buf.write_raw(len + 2, b'R', |buf| { + buf.put_i32(10); // Specifies that SASL auth method is used. + for method in methods { + buf.put_slice(method.as_bytes()); + buf.put_u8(0); + } + buf.put_u8(0); // zero terminator for the list + }); + } + // + BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Continue(extra)) => { + buf.write_raw(extra.len() + 1, b'R', |buf| { + buf.put_i32(11); // Continue SASL auth. 
+ buf.put_slice(extra); + }); + } + // + BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Final(extra)) => { + buf.write_raw(extra.len() + 1, b'R', |buf| { + buf.put_i32(12); // Send final SASL message. + buf.put_slice(extra); + }); + } + + // + BeMessage::BackendKeyData(key_data) => { + buf.write_raw(8, b'K', |buf| buf.put_slice(key_data.as_bytes())); + } + + // + // + BeMessage::NoticeResponse(msg) => { + // 'N' signalizes NoticeResponse messages + buf.write_raw(18 + msg.len(), b'N', |buf| { + // Severity: NOTICE + buf.put_slice(b"SNOTICE\0"); + + // Code: XX000 (ignored for notice, but still required) + buf.put_slice(b"CXX000\0"); + + // Message: msg + buf.put_u8(b'M'); + buf.put_slice(msg.as_bytes()); + buf.put_u8(0); + + // End notice. + buf.put_u8(0); + }); + } + + // + BeMessage::ParameterStatus { name, value } => { + buf.write_raw(name.len() + value.len() + 2, b'S', |buf| { + buf.put_slice(name.as_bytes()); + buf.put_u8(0); + buf.put_slice(value.as_bytes()); + buf.put_u8(0); + }); + } + + // + BeMessage::ReadyForQuery => { + buf.write_raw(1, b'Z', |buf| buf.put_u8(b'I')); + } + + // + BeMessage::NegotiateProtocolVersion { version, options } => { + let len: usize = options.iter().map(|o| o.len() + 1).sum(); + buf.write_raw(8 + len, b'v', |buf| { + buf.put_slice(version.as_bytes()); + buf.put_u32(options.len() as u32); + for option in options { + buf.put_slice(option.as_bytes()); + buf.put_u8(0); + } + }); + } + } + } +} + +#[cfg(test)] +mod tests { + use std::io::Cursor; + + use tokio::io::{AsyncWriteExt, duplex}; + use zerocopy::IntoBytes; + + use crate::pqproto::{FeStartupPacket, read_message, read_startup}; + + use super::ProtocolVersion; + + #[tokio::test] + async fn reject_large_startup() { + // we're going to define a v3.0 startup message with far too many parameters. + let mut payload = vec![]; + // 10001 + 8 bytes. + payload.extend_from_slice(&10009_u32.to_be_bytes()); + payload.extend_from_slice(ProtocolVersion::new(3, 0).as_bytes()); + payload.resize(10009, b'a'); + + let (mut server, mut client) = duplex(128); + #[rustfmt::skip] + let (server, client) = tokio::join!( + async move { read_startup(&mut server).await.unwrap_err() }, + async move { client.write_all(&payload).await.unwrap_err() }, + ); + + assert_eq!(server.to_string(), "invalid startup message length 10001"); + assert_eq!(client.to_string(), "broken pipe"); + } + + #[tokio::test] + async fn reject_large_password() { + // we're going to define a password message that is far too long. 
+ let mut payload = vec![]; + payload.push(b'p'); + payload.extend_from_slice(&517_u32.to_be_bytes()); + payload.resize(518, b'a'); + + let (mut server, mut client) = duplex(128); + #[rustfmt::skip] + let (server, client) = tokio::join!( + async move { read_message(&mut server, &mut vec![], 512).await.unwrap_err() }, + async move { client.write_all(&payload).await.unwrap_err() }, + ); + + assert_eq!(server.to_string(), "invalid message length 513"); + assert_eq!(client.to_string(), "broken pipe"); + } + + #[tokio::test] + async fn read_startup_message() { + let mut payload = vec![]; + payload.extend_from_slice(&17_u32.to_be_bytes()); + payload.extend_from_slice(ProtocolVersion::new(3, 0).as_bytes()); + payload.extend_from_slice(b"abc\0def\0\0"); + + let startup = read_startup(&mut Cursor::new(&payload)).await.unwrap(); + let FeStartupPacket::StartupMessage { version, params } = startup else { + panic!("unexpected startup message: {startup:?}"); + }; + + assert_eq!(version.major(), 3); + assert_eq!(version.minor(), 0); + assert_eq!(params.params, "abc\0def\0"); + } + + #[tokio::test] + async fn read_ssl_message() { + let mut payload = vec![]; + payload.extend_from_slice(&8_u32.to_be_bytes()); + payload.extend_from_slice(ProtocolVersion::new(1234, 5679).as_bytes()); + + let startup = read_startup(&mut Cursor::new(&payload)).await.unwrap(); + let FeStartupPacket::SslRequest { direct: None } = startup else { + panic!("unexpected startup message: {startup:?}"); + }; + } + + #[tokio::test] + async fn read_tls_message() { + // sample client hello taken from + let client_hello = [ + 0x16, 0x03, 0x01, 0x00, 0xf8, 0x01, 0x00, 0x00, 0xf4, 0x03, 0x03, 0x00, 0x01, 0x02, + 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, + 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, + 0x1f, 0x20, 0xe0, 0xe1, 0xe2, 0xe3, 0xe4, 0xe5, 0xe6, 0xe7, 0xe8, 0xe9, 0xea, 0xeb, + 0xec, 0xed, 0xee, 0xef, 0xf0, 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7, 0xf8, 0xf9, + 0xfa, 0xfb, 0xfc, 0xfd, 0xfe, 0xff, 0x00, 0x08, 0x13, 0x02, 0x13, 0x03, 0x13, 0x01, + 0x00, 0xff, 0x01, 0x00, 0x00, 0xa3, 0x00, 0x00, 0x00, 0x18, 0x00, 0x16, 0x00, 0x00, + 0x13, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2e, 0x75, 0x6c, 0x66, 0x68, 0x65, + 0x69, 0x6d, 0x2e, 0x6e, 0x65, 0x74, 0x00, 0x0b, 0x00, 0x04, 0x03, 0x00, 0x01, 0x02, + 0x00, 0x0a, 0x00, 0x16, 0x00, 0x14, 0x00, 0x1d, 0x00, 0x17, 0x00, 0x1e, 0x00, 0x19, + 0x00, 0x18, 0x01, 0x00, 0x01, 0x01, 0x01, 0x02, 0x01, 0x03, 0x01, 0x04, 0x00, 0x23, + 0x00, 0x00, 0x00, 0x16, 0x00, 0x00, 0x00, 0x17, 0x00, 0x00, 0x00, 0x0d, 0x00, 0x1e, + 0x00, 0x1c, 0x04, 0x03, 0x05, 0x03, 0x06, 0x03, 0x08, 0x07, 0x08, 0x08, 0x08, 0x09, + 0x08, 0x0a, 0x08, 0x0b, 0x08, 0x04, 0x08, 0x05, 0x08, 0x06, 0x04, 0x01, 0x05, 0x01, + 0x06, 0x01, 0x00, 0x2b, 0x00, 0x03, 0x02, 0x03, 0x04, 0x00, 0x2d, 0x00, 0x02, 0x01, + 0x01, 0x00, 0x33, 0x00, 0x26, 0x00, 0x24, 0x00, 0x1d, 0x00, 0x20, 0x35, 0x80, 0x72, + 0xd6, 0x36, 0x58, 0x80, 0xd1, 0xae, 0xea, 0x32, 0x9a, 0xdf, 0x91, 0x21, 0x38, 0x38, + 0x51, 0xed, 0x21, 0xa2, 0x8e, 0x3b, 0x75, 0xe9, 0x65, 0xd0, 0xd2, 0xcd, 0x16, 0x62, + 0x54, + ]; + + let mut cursor = Cursor::new(&client_hello); + + let startup = read_startup(&mut cursor).await.unwrap(); + let FeStartupPacket::SslRequest { + direct: Some(prefix), + } = startup + else { + panic!("unexpected startup message: {startup:?}"); + }; + + // check that no data is lost. 
+ assert_eq!(prefix, [0x16, 0x03, 0x01, 0x00, 0xf8, 0x01, 0x00, 0x00]); + assert_eq!(cursor.position(), 8); + } + + #[tokio::test] + async fn read_message_success() { + let query = b"Q\0\0\0\x0cSELECT 1Q\0\0\0\x0cSELECT 2"; + let mut cursor = Cursor::new(&query); + + let mut buf = vec![]; + let (tag, message) = read_message(&mut cursor, &mut buf, 100).await.unwrap(); + assert_eq!(tag, b'Q'); + assert_eq!(message, b"SELECT 1"); + + let (tag, message) = read_message(&mut cursor, &mut buf, 100).await.unwrap(); + assert_eq!(tag, b'Q'); + assert_eq!(message, b"SELECT 2"); + } +} diff --git a/proxy/src/proxy/connect_compute.rs b/proxy/src/proxy/connect_compute.rs index e013fbbe2e..57785c9ec5 100644 --- a/proxy/src/proxy/connect_compute.rs +++ b/proxy/src/proxy/connect_compute.rs @@ -1,5 +1,4 @@ use async_trait::async_trait; -use pq_proto::StartupMessageParams; use tokio::time; use tracing::{debug, info, warn}; @@ -15,6 +14,7 @@ use crate::error::ReportableError; use crate::metrics::{ ConnectOutcome, ConnectionFailureKind, Metrics, RetriesMetricGroup, RetryType, }; +use crate::pqproto::StartupMessageParams; use crate::proxy::retry::{CouldRetry, retry_after, should_retry}; use crate::proxy::wake_compute::wake_compute; use crate::types::Host; diff --git a/proxy/src/proxy/handshake.rs b/proxy/src/proxy/handshake.rs index 54c02f2c15..13ee8c7dd2 100644 --- a/proxy/src/proxy/handshake.rs +++ b/proxy/src/proxy/handshake.rs @@ -1,8 +1,3 @@ -use bytes::Buf; -use pq_proto::framed::Framed; -use pq_proto::{ - BeMessage as Be, CancelKeyData, FeStartupPacket, ProtocolVersion, StartupMessageParams, -}; use thiserror::Error; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::{debug, info, warn}; @@ -12,7 +7,10 @@ use crate::config::TlsConfig; use crate::context::RequestContext; use crate::error::ReportableError; use crate::metrics::Metrics; -use crate::proxy::ERR_INSECURE_CONNECTION; +use crate::pqproto::{ + BeMessage, CancelKeyData, FeStartupPacket, ProtocolVersion, StartupMessageParams, +}; +use crate::proxy::TlsRequired; use crate::stream::{PqStream, Stream, StreamUpgradeError}; use crate::tls::PG_ALPN_PROTOCOL; @@ -71,33 +69,25 @@ pub(crate) async fn handshake( const PG_PROTOCOL_EARLIEST: ProtocolVersion = ProtocolVersion::new(3, 0); const PG_PROTOCOL_LATEST: ProtocolVersion = ProtocolVersion::new(3, 0); - let mut stream = PqStream::new(Stream::from_raw(stream)); + let (mut stream, mut msg) = PqStream::parse_startup(Stream::from_raw(stream)).await?; loop { - let msg = stream.read_startup_packet().await?; match msg { FeStartupPacket::SslRequest { direct } => match stream.get_ref() { Stream::Raw { .. } if !tried_ssl => { tried_ssl = true; - // We can't perform TLS handshake without a config - let have_tls = tls.is_some(); - if !direct { - stream - .write_message(&Be::EncryptionResponse(have_tls)) - .await?; - } else if !have_tls { - return Err(HandshakeError::ProtocolViolation); - } - if let Some(tls) = tls.take() { // Upgrade raw stream into a secure TLS-backed stream. // NOTE: We've consumed `tls`; this fact will be used later. - let Framed { - stream: raw, - read_buf, - write_buf, - } = stream.framed; + let mut read_buf; + let raw = if let Some(direct) = &direct { + read_buf = &direct[..]; + stream.accept_direct_tls() + } else { + read_buf = &[]; + stream.accept_tls().await? 
+ }; let Stream::Raw { raw } = raw else { return Err(HandshakeError::StreamUpgradeError( @@ -105,12 +95,11 @@ pub(crate) async fn handshake( )); }; - let mut read_buf = read_buf.reader(); let mut res = Ok(()); let accept = tokio_rustls::TlsAcceptor::from(tls.pg_config.clone()) .accept_with(raw, |session| { // push the early data to the tls session - while !read_buf.get_ref().is_empty() { + while !read_buf.is_empty() { match session.read_tls(&mut read_buf) { Ok(_) => {} Err(e) => { @@ -123,7 +112,6 @@ pub(crate) async fn handshake( res?; - let read_buf = read_buf.into_inner(); if !read_buf.is_empty() { return Err(HandshakeError::EarlyData); } @@ -157,16 +145,17 @@ pub(crate) async fn handshake( let (_, tls_server_end_point) = tls.cert_resolver.resolve(conn_info.server_name()); - stream = PqStream { - framed: Framed { - stream: Stream::Tls { - tls: Box::new(tls_stream), - tls_server_end_point, - }, - read_buf, - write_buf, - }, + let tls = Stream::Tls { + tls: Box::new(tls_stream), + tls_server_end_point, }; + (stream, msg) = PqStream::parse_startup(tls).await?; + } else { + if direct.is_some() { + // client sent us a ClientHello already, we can't do anything with it. + return Err(HandshakeError::ProtocolViolation); + } + msg = stream.reject_encryption().await?; } } _ => return Err(HandshakeError::ProtocolViolation), @@ -176,7 +165,7 @@ pub(crate) async fn handshake( tried_gss = true; // Currently, we don't support GSSAPI - stream.write_message(&Be::EncryptionResponse(false)).await?; + msg = stream.reject_encryption().await?; } _ => return Err(HandshakeError::ProtocolViolation), }, @@ -186,13 +175,7 @@ pub(crate) async fn handshake( // Check that the config has been consumed during upgrade // OR we didn't provide it at all (for dev purposes). if tls.is_some() { - return stream - .throw_error_str( - ERR_INSECURE_CONNECTION, - crate::error::ErrorKind::User, - None, - ) - .await?; + Err(stream.throw_error(TlsRequired, None).await)?; } // This log highlights the start of the connection. @@ -214,20 +197,21 @@ pub(crate) async fn handshake( // no protocol extensions are supported. // let mut unsupported = vec![]; - for (k, _) in params.iter() { + let mut supported = StartupMessageParams::default(); + + for (k, v) in params.iter() { if k.starts_with("_pq_.") { unsupported.push(k); + } else { + supported.insert(k, v); } } - // TODO: remove unsupported options so we don't send them to compute. 
- - stream - .write_message(&Be::NegotiateProtocolVersion { - version: PG_PROTOCOL_LATEST, - options: &unsupported, - }) - .await?; + stream.write_message(BeMessage::NegotiateProtocolVersion { + version: PG_PROTOCOL_LATEST, + options: &unsupported, + }); + stream.flush().await?; info!( ?version, @@ -235,7 +219,7 @@ pub(crate) async fn handshake( session_type = "normal", "successful handshake; unsupported minor version requested" ); - break Ok(HandshakeData::Startup(stream, params)); + break Ok(HandshakeData::Startup(stream, supported)); } FeStartupPacket::StartupMessage { version, params } => { warn!( diff --git a/proxy/src/proxy/mod.rs b/proxy/src/proxy/mod.rs index 0a86022e78..ac0aca1176 100644 --- a/proxy/src/proxy/mod.rs +++ b/proxy/src/proxy/mod.rs @@ -10,15 +10,14 @@ pub(crate) mod wake_compute; use std::sync::Arc; pub use copy_bidirectional::{ErrorSource, copy_bidirectional_client_compute}; -use futures::{FutureExt, TryFutureExt}; +use futures::FutureExt; use itertools::Itertools; use once_cell::sync::OnceCell; -use pq_proto::{BeMessage as Be, CancelKeyData, StartupMessageParams}; use regex::Regex; use serde::{Deserialize, Serialize}; use smol_str::{SmolStr, ToSmolStr, format_smolstr}; use thiserror::Error; -use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; +use tokio::io::{AsyncRead, AsyncWrite}; use tokio_util::sync::CancellationToken; use tracing::{Instrument, debug, error, info, warn}; @@ -27,8 +26,9 @@ use self::passthrough::ProxyPassthrough; use crate::cancellation::{self, CancellationHandler}; use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig}; use crate::context::RequestContext; -use crate::error::ReportableError; +use crate::error::{ReportableError, UserFacingError}; use crate::metrics::{Metrics, NumClientConnectionsGuard}; +use crate::pqproto::{BeMessage, CancelKeyData, StartupMessageParams}; use crate::protocol2::{ConnectHeader, ConnectionInfo, ConnectionInfoExtra, read_proxy_protocol}; use crate::proxy::handshake::{HandshakeData, handshake}; use crate::rate_limiter::EndpointRateLimiter; @@ -38,6 +38,18 @@ use crate::{auth, compute}; const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)"; +#[derive(Error, Debug)] +#[error("{ERR_INSECURE_CONNECTION}")] +pub struct TlsRequired; + +impl ReportableError for TlsRequired { + fn get_error_kind(&self) -> crate::error::ErrorKind { + crate::error::ErrorKind::User + } +} + +impl UserFacingError for TlsRequired {} + pub async fn run_until_cancelled( f: F, cancellation_token: &CancellationToken, @@ -329,11 +341,11 @@ pub(crate) async fn handle_client( let user_info = match result { Ok(user_info) => user_info, - Err(e) => stream.throw_error(e, Some(ctx)).await?, + Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?, }; let user = user_info.get_user().to_owned(); - let (user_info, _ip_allowlist) = match user_info + let user_info = match user_info .authenticate( ctx, &mut stream, @@ -349,10 +361,10 @@ pub(crate) async fn handle_client( let app = params.get("application_name"); let params_span = tracing::info_span!("", ?user, ?db, ?app); - return stream + return Err(stream .throw_error(e, Some(ctx)) .instrument(params_span) - .await?; + .await)?; } }; @@ -365,7 +377,7 @@ pub(crate) async fn handle_client( .get(NeonOptions::PARAMS_COMPAT) .is_some(); - let mut node = connect_to_compute( + let res = connect_to_compute( ctx, &TcpMechanism { user_info: compute_user_info.clone(), @@ -377,22 +389,19 @@ pub(crate) async fn handle_client( config.wake_compute_retry_config, 
&config.connect_to_compute, ) - .or_else(|e| stream.throw_error(e, Some(ctx))) - .await?; + .await; + + let node = match res { + Ok(node) => node, + Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?, + }; let cancellation_handler_clone = Arc::clone(&cancellation_handler); let session = cancellation_handler_clone.get_key(); session.write_cancel_key(node.cancel_closure.clone())?; - - prepare_client_connection(&node, *session.key(), &mut stream).await?; - - // Before proxy passing, forward to compute whatever data is left in the - // PqStream input buffer. Normally there is none, but our serverless npm - // driver in pipeline mode sends startup, password and first query - // immediately after opening the connection. - let (stream, read_buf) = stream.into_inner(); - node.stream.write_all(&read_buf).await?; + prepare_client_connection(&node, *session.key(), &mut stream); + let stream = stream.flush_and_into_inner().await?; let private_link_id = match ctx.extra() { Some(ConnectionInfoExtra::Aws { vpce_id }) => Some(vpce_id.clone()), @@ -413,31 +422,28 @@ pub(crate) async fn handle_client( } /// Finish client connection initialization: confirm auth success, send params, etc. -#[tracing::instrument(skip_all)] -pub(crate) async fn prepare_client_connection( +pub(crate) fn prepare_client_connection( node: &compute::PostgresConnection, cancel_key_data: CancelKeyData, stream: &mut PqStream, -) -> Result<(), std::io::Error> { +) { // Forward all deferred notices to the client. for notice in &node.delayed_notice { - stream.write_message_noflush(&Be::Raw(b'N', notice.as_bytes()))?; + stream.write_raw(notice.as_bytes().len(), b'N', |buf| { + buf.extend_from_slice(notice.as_bytes()); + }); } // Forward all postgres connection params to the client. for (name, value) in &node.params { - stream.write_message_noflush(&Be::ParameterStatus { + stream.write_message(BeMessage::ParameterStatus { name: name.as_bytes(), value: value.as_bytes(), - })?; + }); } - stream - .write_message_noflush(&Be::BackendKeyData(cancel_key_data))? - .write_message(&Be::ReadyForQuery) - .await?; - - Ok(()) + stream.write_message(BeMessage::BackendKeyData(cancel_key_data)); + stream.write_message(BeMessage::ReadyForQuery); } #[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)] diff --git a/proxy/src/proxy/retry.rs b/proxy/src/proxy/retry.rs index 0879564ced..01e603ec14 100644 --- a/proxy/src/proxy/retry.rs +++ b/proxy/src/proxy/retry.rs @@ -125,9 +125,10 @@ pub(crate) fn retry_after(num_retries: u32, config: RetryConfig) -> time::Durati #[cfg(test)] mod tests { - use super::ShouldRetryWakeCompute; use postgres_client::error::{DbError, SqlState}; + use super::ShouldRetryWakeCompute; + #[test] fn should_retry_wake_compute_for_db_error() { // These SQLStates should NOT trigger a wake_compute retry. 
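A note on the write path used by prepare_client_connection above: with the new PqStream, writers only queue frames into the internal WriteBuf, and nothing reaches the socket until an explicit flush. A minimal sketch of the pattern, not part of the patch (names as in this patch; `stream` is a PqStream over any AsyncRead + AsyncWrite socket, error handling elided):

    // queue the deferred frames; these only fill the internal write buffer
    stream.write_message(BeMessage::BackendKeyData(cancel_key_data));
    stream.write_message(BeMessage::ReadyForQuery);

    // a single flush later pushes every queued frame to the client
    let stream = stream.flush_and_into_inner().await?;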
diff --git a/proxy/src/proxy/tests/mitm.rs b/proxy/src/proxy/tests/mitm.rs index 59c9ac27b8..c92ee49b8d 100644 --- a/proxy/src/proxy/tests/mitm.rs +++ b/proxy/src/proxy/tests/mitm.rs @@ -10,7 +10,7 @@ use bytes::{Bytes, BytesMut}; use futures::{SinkExt, StreamExt}; use postgres_client::tls::TlsConnect; use postgres_protocol::message::frontend; -use tokio::io::{AsyncReadExt, DuplexStream}; +use tokio::io::{AsyncReadExt, AsyncWriteExt, DuplexStream}; use tokio_util::codec::{Decoder, Encoder}; use super::*; @@ -49,15 +49,14 @@ async fn proxy_mitm( }; let mut end_server = tokio_util::codec::Framed::new(end_server, PgFrame); - let (end_client, buf) = end_client.framed.into_inner(); - assert!(buf.is_empty()); + let end_client = end_client.flush_and_into_inner().await.unwrap(); let mut end_client = tokio_util::codec::Framed::new(end_client, PgFrame); // give the end_server the startup parameters let mut buf = BytesMut::new(); frontend::startup_message( &postgres_protocol::message::frontend::StartupMessageParams { - params: startup.params.into(), + params: startup.params.as_bytes().into(), }, &mut buf, ) diff --git a/proxy/src/proxy/tests/mod.rs b/proxy/src/proxy/tests/mod.rs index be6426a63c..61e8ee4a10 100644 --- a/proxy/src/proxy/tests/mod.rs +++ b/proxy/src/proxy/tests/mod.rs @@ -26,9 +26,7 @@ use crate::auth::backend::{ use crate::config::{ComputeConfig, RetryConfig}; use crate::control_plane::client::{ControlPlaneClient, TestControlPlaneClient}; use crate::control_plane::messages::{ControlPlaneErrorMessage, Details, MetricsAuxInfo, Status}; -use crate::control_plane::{ - self, CachedAllowedIps, CachedAllowedVpcEndpointIds, CachedNodeInfo, NodeInfo, NodeInfoCache, -}; +use crate::control_plane::{self, CachedNodeInfo, NodeInfo, NodeInfoCache}; use crate::error::ErrorKind; use crate::tls::client_config::compute_client_config_with_certs; use crate::tls::postgres_rustls::MakeRustlsConnect; @@ -128,7 +126,7 @@ trait TestAuth: Sized { self, stream: &mut PqStream>, ) -> anyhow::Result<()> { - stream.write_message_noflush(&Be::AuthenticationOk)?; + stream.write_message(BeMessage::AuthenticationOk); Ok(()) } } @@ -157,9 +155,7 @@ impl TestAuth for Scram { self, stream: &mut PqStream>, ) -> anyhow::Result<()> { - let outcome = auth::AuthFlow::new(stream) - .begin(auth::Scram(&self.0, &RequestContext::test())) - .await? + let outcome = auth::AuthFlow::new(stream, auth::Scram(&self.0, &RequestContext::test())) .authenticate() .await?; @@ -185,10 +181,12 @@ async fn dummy_proxy( auth.authenticate(&mut stream).await?; - stream - .write_message_noflush(&Be::CLIENT_ENCODING)? 
- .write_message(&Be::ReadyForQuery) - .await?; + stream.write_message(BeMessage::ParameterStatus { + name: b"client_encoding", + value: b"UTF8", + }); + stream.write_message(BeMessage::ReadyForQuery); + stream.flush().await?; Ok(()) } @@ -547,20 +545,9 @@ impl TestControlPlaneClient for TestConnectMechanism { } } - fn get_allowed_ips(&self) -> Result { - unimplemented!("not used in tests") - } - - fn get_allowed_vpc_endpoint_ids( + fn get_access_control( &self, - ) -> Result { - unimplemented!("not used in tests") - } - - fn get_block_public_or_vpc_access( - &self, - ) -> Result - { + ) -> Result { unimplemented!("not used in tests") } diff --git a/proxy/src/rate_limiter/leaky_bucket.rs b/proxy/src/rate_limiter/leaky_bucket.rs index 4f27c6faef..0c79b5e92f 100644 --- a/proxy/src/rate_limiter/leaky_bucket.rs +++ b/proxy/src/rate_limiter/leaky_bucket.rs @@ -15,7 +15,7 @@ pub type EndpointRateLimiter = LeakyBucketRateLimiter; pub struct LeakyBucketRateLimiter { map: ClashMap, - config: utils::leaky_bucket::LeakyBucketConfig, + default_config: utils::leaky_bucket::LeakyBucketConfig, access_count: AtomicUsize, } @@ -28,15 +28,17 @@ impl LeakyBucketRateLimiter { pub fn new_with_shards(config: LeakyBucketConfig, shards: usize) -> Self { Self { map: ClashMap::with_hasher_and_shard_amount(RandomState::new(), shards), - config: config.into(), + default_config: config.into(), access_count: AtomicUsize::new(0), } } /// Check that number of connections to the endpoint is below `max_rps` rps. - pub(crate) fn check(&self, key: K, n: u32) -> bool { + pub(crate) fn check(&self, key: K, config: Option, n: u32) -> bool { let now = Instant::now(); + let config = config.map_or(self.default_config, Into::into); + if self.access_count.fetch_add(1, Ordering::AcqRel) % 2048 == 0 { self.do_gc(now); } @@ -46,7 +48,7 @@ impl LeakyBucketRateLimiter { .entry(key) .or_insert_with(|| LeakyBucketState { empty_at: now }); - entry.add_tokens(&self.config, now, n as f64).is_ok() + entry.add_tokens(&config, now, n as f64).is_ok() } fn do_gc(&self, now: Instant) { diff --git a/proxy/src/rate_limiter/limiter.rs b/proxy/src/rate_limiter/limiter.rs index 21eaa6739b..9d700c1b52 100644 --- a/proxy/src/rate_limiter/limiter.rs +++ b/proxy/src/rate_limiter/limiter.rs @@ -15,6 +15,8 @@ use tracing::info; use crate::ext::LockExt; use crate::intern::EndpointIdInt; +use super::LeakyBucketConfig; + pub struct GlobalRateLimiter { data: Vec, info: Vec, @@ -144,19 +146,6 @@ impl RateBucketInfo { Self::new(50_000, Duration::from_secs(10)), ]; - /// All of these are per endpoint-maskedip pair. - /// Context: 4096 rounds of pbkdf2 take about 1ms of cpu time to execute (1 milli-cpu-second or 1mcpus). - /// - /// First bucket: 1000mcpus total per endpoint-ip pair - /// * 4096000 requests per second with 1 hash rounds. - /// * 1000 requests per second with 4096 hash rounds. - /// * 6.8 requests per second with 600000 hash rounds. 
- pub const DEFAULT_AUTH_SET: [Self; 3] = [ - Self::new(1000 * 4096, Duration::from_secs(1)), - Self::new(600 * 4096, Duration::from_secs(60)), - Self::new(300 * 4096, Duration::from_secs(600)), - ]; - pub fn rps(&self) -> f64 { (self.max_rpi as f64) / self.interval.as_secs_f64() } @@ -184,6 +173,21 @@ impl RateBucketInfo { max_rpi: ((max_rps as u64) * (interval.as_millis() as u64) / 1000) as u32, } } + + pub fn to_leaky_bucket(this: &[Self]) -> Option { + // bit of a hack - find the min rps and max rps supported and turn it into + // leaky bucket config instead + + let mut iter = this.iter().map(|info| info.rps()); + let first = iter.next()?; + + let (min, max) = (first, first); + let (min, max) = iter.fold((min, max), |(min, max), rps| { + (f64::min(min, rps), f64::max(max, rps)) + }); + + Some(LeakyBucketConfig { rps: min, max }) + } } impl BucketRateLimiter { diff --git a/proxy/src/rate_limiter/mod.rs b/proxy/src/rate_limiter/mod.rs index 5f90102da3..112b95873a 100644 --- a/proxy/src/rate_limiter/mod.rs +++ b/proxy/src/rate_limiter/mod.rs @@ -8,4 +8,4 @@ pub(crate) use limit_algorithm::aimd::Aimd; pub(crate) use limit_algorithm::{ DynamicLimiter, Outcome, RateLimitAlgorithm, RateLimiterConfig, Token, }; -pub use limiter::{BucketRateLimiter, GlobalRateLimiter, RateBucketInfo, WakeComputeRateLimiter}; +pub use limiter::{GlobalRateLimiter, RateBucketInfo, WakeComputeRateLimiter}; diff --git a/proxy/src/redis/cancellation_publisher.rs b/proxy/src/redis/cancellation_publisher.rs index 186fece4b2..6f56aeea06 100644 --- a/proxy/src/redis/cancellation_publisher.rs +++ b/proxy/src/redis/cancellation_publisher.rs @@ -1,10 +1,11 @@ use core::net::IpAddr; use std::sync::Arc; -use pq_proto::CancelKeyData; use tokio::sync::Mutex; use uuid::Uuid; +use crate::pqproto::CancelKeyData; + pub trait CancellationPublisherMut: Send + Sync + 'static { #[allow(async_fn_in_trait)] async fn try_publish( diff --git a/proxy/src/redis/keys.rs b/proxy/src/redis/keys.rs index 7527bca6d0..3113bad949 100644 --- a/proxy/src/redis/keys.rs +++ b/proxy/src/redis/keys.rs @@ -1,16 +1,15 @@ use std::io::ErrorKind; use anyhow::Ok; -use pq_proto::{CancelKeyData, id_to_cancel_key}; -use serde::{Deserialize, Serialize}; + +use crate::pqproto::{CancelKeyData, id_to_cancel_key}; pub mod keyspace { pub const CANCEL_PREFIX: &str = "cancel"; } -#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] +#[derive(Clone, Debug, Eq, PartialEq)] pub(crate) enum KeyPrefix { - #[serde(untagged)] Cancel(CancelKeyData), } @@ -18,9 +17,7 @@ impl KeyPrefix { pub(crate) fn build_redis_key(&self) -> String { match self { KeyPrefix::Cancel(key) => { - let hi = (key.backend_pid as u64) << 32; - let lo = (key.cancel_key as u64) & 0xffff_ffff; - let id = hi | lo; + let id = key.0.get(); let keyspace = keyspace::CANCEL_PREFIX; format!("{keyspace}:{id:x}") } @@ -63,10 +60,7 @@ mod tests { #[test] fn test_build_redis_key() { - let cancel_key: KeyPrefix = KeyPrefix::Cancel(CancelKeyData { - backend_pid: 12345, - cancel_key: 54321, - }); + let cancel_key: KeyPrefix = KeyPrefix::Cancel(id_to_cancel_key(12345 << 32 | 54321)); let redis_key = cancel_key.build_redis_key(); assert_eq!(redis_key, "cancel:30390000d431"); @@ -77,10 +71,7 @@ mod tests { let redis_key = "cancel:30390000d431"; let key: KeyPrefix = parse_redis_key(redis_key).expect("Failed to parse key"); - let ref_key = CancelKeyData { - backend_pid: 12345, - cancel_key: 54321, - }; + let ref_key = id_to_cancel_key(12345 << 32 | 54321); assert_eq!(key.as_str(), 
KeyPrefix::Cancel(ref_key).as_str()); let KeyPrefix::Cancel(cancel_key) = key; diff --git a/proxy/src/redis/notifications.rs b/proxy/src/redis/notifications.rs index 5f9f2509e2..a9d6b40603 100644 --- a/proxy/src/redis/notifications.rs +++ b/proxy/src/redis/notifications.rs @@ -2,11 +2,9 @@ use std::convert::Infallible; use std::sync::Arc; use futures::StreamExt; -use pq_proto::CancelKeyData; use redis::aio::PubSub; use serde::{Deserialize, Serialize}; use tokio_util::sync::CancellationToken; -use uuid::Uuid; use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider; use crate::cache::project_info::ProjectInfoCache; @@ -100,14 +98,6 @@ pub(crate) struct PasswordUpdate { role_name: RoleNameInt, } -#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] -pub(crate) struct CancelSession { - pub(crate) region_id: Option, - pub(crate) cancel_key_data: CancelKeyData, - pub(crate) session_id: Uuid, - pub(crate) peer_addr: Option, -} - fn deserialize_json_string<'de, D, T>(deserializer: D) -> Result where T: for<'de2> serde::Deserialize<'de2>, @@ -243,29 +233,30 @@ impl MessageHandler { fn invalidate_cache(cache: Arc, msg: Notification) { match msg { - Notification::AllowedIpsUpdate { allowed_ips_update } => { - cache.invalidate_allowed_ips_for_project(allowed_ips_update.project_id); + Notification::AllowedIpsUpdate { + allowed_ips_update: AllowedIpsUpdate { project_id }, } - Notification::BlockPublicOrVpcAccessUpdated { - block_public_or_vpc_access_updated, - } => cache.invalidate_block_public_or_vpc_access_for_project( - block_public_or_vpc_access_updated.project_id, - ), + | Notification::BlockPublicOrVpcAccessUpdated { + block_public_or_vpc_access_updated: BlockPublicOrVpcAccessUpdated { project_id }, + } => cache.invalidate_endpoint_access_for_project(project_id), Notification::AllowedVpcEndpointsUpdatedForOrg { - allowed_vpc_endpoints_updated_for_org, - } => cache.invalidate_allowed_vpc_endpoint_ids_for_org( - allowed_vpc_endpoints_updated_for_org.account_id, - ), + allowed_vpc_endpoints_updated_for_org: AllowedVpcEndpointsUpdatedForOrg { account_id }, + } => cache.invalidate_endpoint_access_for_org(account_id), Notification::AllowedVpcEndpointsUpdatedForProjects { - allowed_vpc_endpoints_updated_for_projects, - } => cache.invalidate_allowed_vpc_endpoint_ids_for_projects( - allowed_vpc_endpoints_updated_for_projects.project_ids, - ), - Notification::PasswordUpdate { password_update } => cache - .invalidate_role_secret_for_project( - password_update.project_id, - password_update.role_name, - ), + allowed_vpc_endpoints_updated_for_projects: + AllowedVpcEndpointsUpdatedForProjects { project_ids }, + } => { + for project in project_ids { + cache.invalidate_endpoint_access_for_project(project); + } + } + Notification::PasswordUpdate { + password_update: + PasswordUpdate { + project_id, + role_name, + }, + } => cache.invalidate_role_secret_for_project(project_id, role_name), Notification::UnknownTopic => unreachable!(), } } diff --git a/proxy/src/sasl/messages.rs b/proxy/src/sasl/messages.rs index 7f2f3a761c..8d26a3f453 100644 --- a/proxy/src/sasl/messages.rs +++ b/proxy/src/sasl/messages.rs @@ -1,7 +1,5 @@ //! Definitions for SASL messages. -use pq_proto::{BeAuthenticationSaslMessage, BeMessage}; - use crate::parse::split_cstr; /// SASL-specific payload of [`PasswordMessage`](pq_proto::FeMessage::PasswordMessage). @@ -30,26 +28,6 @@ impl<'a> FirstMessage<'a> { } } -/// A single SASL message. 
-/// This struct is deliberately decoupled from lower-level -/// [`BeAuthenticationSaslMessage`]. -#[derive(Debug)] -pub(super) enum ServerMessage { - /// We expect to see more steps. - Continue(T), - /// This is the final step. - Final(T), -} - -impl<'a> ServerMessage<&'a str> { - pub(super) fn to_reply(&self) -> BeMessage<'a> { - BeMessage::AuthenticationSasl(match self { - ServerMessage::Continue(s) => BeAuthenticationSaslMessage::Continue(s.as_bytes()), - ServerMessage::Final(s) => BeAuthenticationSaslMessage::Final(s.as_bytes()), - }) - } -} - #[cfg(test)] mod tests { use super::*; diff --git a/proxy/src/sasl/mod.rs b/proxy/src/sasl/mod.rs index f0181b404f..007b62dfd2 100644 --- a/proxy/src/sasl/mod.rs +++ b/proxy/src/sasl/mod.rs @@ -14,7 +14,7 @@ use std::io; pub(crate) use channel_binding::ChannelBinding; pub(crate) use messages::FirstMessage; -pub(crate) use stream::{Outcome, SaslStream}; +pub(crate) use stream::{Outcome, authenticate}; use thiserror::Error; use crate::error::{ReportableError, UserFacingError}; @@ -22,6 +22,9 @@ use crate::error::{ReportableError, UserFacingError}; /// Fine-grained auth errors help in writing tests. #[derive(Error, Debug)] pub(crate) enum Error { + #[error("Unsupported authentication method: {0}")] + BadAuthMethod(Box), + #[error("Channel binding failed: {0}")] ChannelBindingFailed(&'static str), @@ -54,6 +57,7 @@ impl UserFacingError for Error { impl ReportableError for Error { fn get_error_kind(&self) -> crate::error::ErrorKind { match self { + Error::BadAuthMethod(_) => crate::error::ErrorKind::User, Error::ChannelBindingFailed(_) => crate::error::ErrorKind::User, Error::ChannelBindingBadMethod(_) => crate::error::ErrorKind::User, Error::BadClientMessage(_) => crate::error::ErrorKind::User, diff --git a/proxy/src/sasl/stream.rs b/proxy/src/sasl/stream.rs index 46e6a439e5..cb15132673 100644 --- a/proxy/src/sasl/stream.rs +++ b/proxy/src/sasl/stream.rs @@ -3,61 +3,12 @@ use std::io; use tokio::io::{AsyncRead, AsyncWrite}; -use tracing::info; -use super::Mechanism; -use super::messages::ServerMessage; +use super::{Mechanism, Step}; +use crate::context::RequestContext; +use crate::pqproto::{BeAuthenticationSaslMessage, BeMessage}; use crate::stream::PqStream; -/// Abstracts away all peculiarities of the libpq's protocol. -pub(crate) struct SaslStream<'a, S> { - /// The underlying stream. - stream: &'a mut PqStream, - /// Current password message we received from client. - current: bytes::Bytes, - /// First SASL message produced by client. - first: Option<&'a str>, -} - -impl<'a, S> SaslStream<'a, S> { - pub(crate) fn new(stream: &'a mut PqStream, first: &'a str) -> Self { - Self { - stream, - current: bytes::Bytes::new(), - first: Some(first), - } - } -} - -impl SaslStream<'_, S> { - // Receive a new SASL message from the client. - async fn recv(&mut self) -> io::Result<&str> { - if let Some(first) = self.first.take() { - return Ok(first); - } - - self.current = self.stream.read_password_message().await?; - let s = std::str::from_utf8(&self.current) - .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "bad encoding"))?; - - Ok(s) - } -} - -impl SaslStream<'_, S> { - // Send a SASL message to the client. - async fn send(&mut self, msg: &ServerMessage<&str>) -> io::Result<()> { - self.stream.write_message(&msg.to_reply()).await?; - Ok(()) - } - - // Queue a SASL message for the client. 
- fn send_noflush(&mut self, msg: &ServerMessage<&str>) -> io::Result<()> { - self.stream.write_message_noflush(&msg.to_reply())?; - Ok(()) - } -} - /// SASL authentication outcome. /// It's much easier to match on those two variants /// than to peek into a noisy protocol error type. @@ -69,33 +20,62 @@ pub(crate) enum Outcome { Failure(&'static str), } -impl SaslStream<'_, S> { - /// Perform SASL message exchange according to the underlying algorithm - /// until user is either authenticated or denied access. - pub(crate) async fn authenticate( - mut self, - mut mechanism: M, - ) -> super::Result> { - loop { - let input = self.recv().await?; - let step = mechanism.exchange(input).map_err(|error| { - info!(?error, "error during SASL exchange"); - error - })?; +pub async fn authenticate( + ctx: &RequestContext, + stream: &mut PqStream, + mechanism: F, +) -> super::Result> +where + S: AsyncRead + AsyncWrite + Unpin, + F: FnOnce(&str) -> super::Result, + M: Mechanism, +{ + let sasl = { + // pause the timer while we communicate with the client + let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client); - use super::Step; - return Ok(match step { - Step::Continue(moved_mechanism, reply) => { - self.send(&ServerMessage::Continue(&reply)).await?; - mechanism = moved_mechanism; - continue; - } - Step::Success(result, reply) => { - self.send_noflush(&ServerMessage::Final(&reply))?; - Outcome::Success(result) - } - Step::Failure(reason) => Outcome::Failure(reason), - }); + // Initial client message contains the chosen auth method's name. + let msg = stream.read_password_message().await?; + super::FirstMessage::parse(msg).ok_or(super::Error::BadClientMessage("bad sasl message"))? + }; + + let mut mechanism = mechanism(sasl.method)?; + let mut input = sasl.message; + loop { + let step = mechanism + .exchange(input) + .inspect_err(|error| tracing::info!(?error, "error during SASL exchange"))?; + + match step { + Step::Continue(moved_mechanism, reply) => { + mechanism = moved_mechanism; + + // pause the timer while we communicate with the client + let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client); + + // write reply + let sasl_msg = BeAuthenticationSaslMessage::Continue(reply.as_bytes()); + stream.write_message(BeMessage::AuthenticationSasl(sasl_msg)); + + // get next input + stream.flush().await?; + let msg = stream.read_password_message().await?; + input = std::str::from_utf8(msg) + .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "bad encoding"))?; + } + Step::Success(result, reply) => { + // pause the timer while we communicate with the client + let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client); + + // write reply + let sasl_msg = BeAuthenticationSaslMessage::Final(reply.as_bytes()); + stream.write_message(BeMessage::AuthenticationSasl(sasl_msg)); + stream.write_message(BeMessage::AuthenticationOk); + // exit with success + break Ok(Outcome::Success(result)); + } + // exit with failure + Step::Failure(reason) => break Ok(Outcome::Failure(reason)), } } } diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index 13058f08f1..bf640c05e9 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -22,7 +22,7 @@ use super::http_conn_pool::{self, HttpConnPool, Send, poll_http2_client}; use super::local_conn_pool::{self, EXT_NAME, EXT_SCHEMA, EXT_VERSION, LocalConnPool}; use crate::auth::backend::local::StaticAuthRules; use crate::auth::backend::{ComputeCredentials, ComputeUserInfo}; -use 
crate::auth::{self, AuthError, check_peer_addr_is_in_list}; +use crate::auth::{self, AuthError}; use crate::compute; use crate::compute_ctl::{ ComputeCtlError, ExtensionInstallRequest, Privilege, SetRoleGrantsRequest, @@ -35,7 +35,6 @@ use crate::control_plane::errors::{GetAuthInfoError, WakeComputeError}; use crate::control_plane::locks::ApiLocks; use crate::error::{ErrorKind, ReportableError, UserFacingError}; use crate::intern::EndpointIdInt; -use crate::protocol2::ConnectionInfoExtra; use crate::proxy::connect_compute::ConnectMechanism; use crate::proxy::retry::{CouldRetry, ShouldRetryWakeCompute}; use crate::rate_limiter::EndpointRateLimiter; @@ -63,63 +62,24 @@ impl PoolingBackend { let user_info = user_info.clone(); let backend = self.auth_backend.as_ref().map(|()| user_info.clone()); - let allowed_ips = backend.get_allowed_ips(ctx).await?; + let access_control = backend.get_endpoint_access_control(ctx).await?; + access_control.check( + ctx, + self.config.authentication_config.ip_allowlist_check_enabled, + self.config.authentication_config.is_vpc_acccess_proxy, + )?; - if self.config.authentication_config.ip_allowlist_check_enabled - && !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips) - { - return Err(AuthError::ip_address_not_allowed(ctx.peer_addr())); - } - - let access_blocker_flags = backend.get_block_public_or_vpc_access(ctx).await?; - if self.config.authentication_config.is_vpc_acccess_proxy { - if access_blocker_flags.vpc_access_blocked { - return Err(AuthError::NetworkNotAllowed); - } - - let extra = ctx.extra(); - let incoming_endpoint_id = match extra { - None => String::new(), - Some(ConnectionInfoExtra::Aws { vpce_id }) => vpce_id.to_string(), - Some(ConnectionInfoExtra::Azure { link_id }) => link_id.to_string(), - }; - - if incoming_endpoint_id.is_empty() { - return Err(AuthError::MissingVPCEndpointId); - } - - let allowed_vpc_endpoint_ids = backend.get_allowed_vpc_endpoint_ids(ctx).await?; - // TODO: For now an empty VPC endpoint ID list means all are allowed. We should replace that. - if !allowed_vpc_endpoint_ids.is_empty() - && !allowed_vpc_endpoint_ids.contains(&incoming_endpoint_id) - { - return Err(AuthError::vpc_endpoint_id_not_allowed(incoming_endpoint_id)); - } - } else if access_blocker_flags.public_access_blocked { - return Err(AuthError::NetworkNotAllowed); - } - - if !self - .endpoint_rate_limiter - .check(user_info.endpoint.clone().into(), 1) - { + let ep = EndpointIdInt::from(&user_info.endpoint); + let rate_limit_config = None; + if !self.endpoint_rate_limiter.check(ep, rate_limit_config, 1) { return Err(AuthError::too_many_connections()); } - let cached_secret = backend.get_role_secret(ctx).await?; - let secret = match cached_secret.value.clone() { - Some(secret) => self.config.authentication_config.check_rate_limit( - ctx, - secret, - &user_info.endpoint, - true, - )?, - None => { - // If we don't have an authentication secret, for the http flow we can just return an error. - info!("authentication info not found"); - return Err(AuthError::password_failed(&*user_info.user)); - } + let role_access = backend.get_role_secret(ctx).await?; + let Some(secret) = role_access.secret else { + // If we don't have an authentication secret, for the http flow we can just return an error. 
+ info!("authentication info not found"); + return Err(AuthError::password_failed(&*user_info.user)); }; - let ep = EndpointIdInt::from(&user_info.endpoint); let auth_outcome = crate::auth::validate_password_and_exchange( &self.config.authentication_config.thread_pool, ep, diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index 1c5bb64480..eb80ac9ad0 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -17,7 +17,6 @@ use postgres_client::error::{DbError, ErrorPosition, SqlState}; use postgres_client::{ GenericClient, IsolationLevel, NoTls, ReadyForQueryStatus, RowStream, Transaction, }; -use pq_proto::StartupMessageParamsBuilder; use serde::Serialize; use serde_json::Value; use serde_json::value::RawValue; @@ -41,6 +40,7 @@ use crate::context::RequestContext; use crate::error::{ErrorKind, ReportableError, UserFacingError}; use crate::http::{ReadBodyError, read_body_with_limit}; use crate::metrics::{HttpDirection, Metrics, SniGroup, SniKind}; +use crate::pqproto::StartupMessageParams; use crate::proxy::{NeonOptions, run_until_cancelled}; use crate::serverless::backend::HttpConnError; use crate::types::{DbName, RoleName}; @@ -219,7 +219,7 @@ fn get_conn_info( let mut options = Option::None; - let mut params = StartupMessageParamsBuilder::default(); + let mut params = StartupMessageParams::default(); params.insert("user", &username); params.insert("database", &dbname); for (key, value) in pairs { diff --git a/proxy/src/stream.rs b/proxy/src/stream.rs index 360550b0ac..7126430a85 100644 --- a/proxy/src/stream.rs +++ b/proxy/src/stream.rs @@ -2,19 +2,17 @@ use std::pin::Pin; use std::sync::Arc; use std::{io, task}; -use bytes::BytesMut; -use pq_proto::framed::{ConnectionError, Framed}; -use pq_proto::{BeMessage, FeMessage, FeStartupPacket, ProtocolError}; use rustls::ServerConfig; -use serde::{Deserialize, Serialize}; use thiserror::Error; -use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf}; use tokio_rustls::server::TlsStream; -use tracing::debug; -use crate::control_plane::messages::ColdStartInfo; use crate::error::{ErrorKind, ReportableError, UserFacingError}; use crate::metrics::Metrics; +use crate::pqproto::{ + BeMessage, FE_PASSWORD_MESSAGE, FeStartupPacket, SQLSTATE_INTERNAL_ERROR, WriteBuf, + read_message, read_startup, +}; use crate::tls::TlsServerEndPoint; /// Stream wrapper which implements libpq's protocol. @@ -23,58 +21,77 @@ use crate::tls::TlsServerEndPoint; /// or [`AsyncWrite`] to prevent subtle errors (e.g. trying /// to pass random malformed bytes through the connection). pub struct PqStream { - pub(crate) framed: Framed, + stream: S, + read: Vec, + write: WriteBuf, } impl PqStream { - /// Construct a new libpq protocol wrapper. - pub fn new(stream: S) -> Self { + pub fn get_ref(&self) -> &S { + &self.stream + } + + /// Construct a new libpq protocol wrapper over a stream without the first startup message. + #[cfg(test)] + pub fn new_skip_handshake(stream: S) -> Self { Self { - framed: Framed::new(stream), + stream, + read: Vec::new(), + write: WriteBuf::new(), } } - - /// Extract the underlying stream and read buffer. - pub fn into_inner(self) -> (S, BytesMut) { - self.framed.into_inner() - } - - /// Get a shared reference to the underlying stream. 
- pub(crate) fn get_ref(&self) -> &S { - self.framed.get_ref() - } } -fn err_connection() -> io::Error { - io::Error::new(io::ErrorKind::ConnectionAborted, "connection is lost") +impl PqStream { + /// Construct a new libpq protocol wrapper and read the first startup message. + /// + /// This is not cancel safe. + pub async fn parse_startup(mut stream: S) -> io::Result<(Self, FeStartupPacket)> { + let startup = read_startup(&mut stream).await?; + Ok(( + Self { + stream, + read: Vec::new(), + write: WriteBuf::new(), + }, + startup, + )) + } + + /// Tell the client that encryption is not supported. + /// + /// This is not cancel safe + pub async fn reject_encryption(&mut self) -> io::Result { + // N for No. + self.write.encryption(b'N'); + self.flush().await?; + read_startup(&mut self.stream).await + } } impl PqStream { - /// Receive [`FeStartupPacket`], which is a first packet sent by a client. - pub async fn read_startup_packet(&mut self) -> io::Result { - self.framed - .read_startup_message() - .await - .map_err(ConnectionError::into_io_error)? - .ok_or_else(err_connection) - } - - async fn read_message(&mut self) -> io::Result { - self.framed - .read_message() - .await - .map_err(ConnectionError::into_io_error)? - .ok_or_else(err_connection) - } - - pub(crate) async fn read_password_message(&mut self) -> io::Result { - match self.read_message().await? { - FeMessage::PasswordMessage(msg) => Ok(msg), - bad => Err(io::Error::new( - io::ErrorKind::InvalidData, - format!("unexpected message type: {bad:?}"), - )), + /// Read a raw postgres packet, which will respect the max length requested. + /// This is not cancel safe. + async fn read_raw_expect(&mut self, tag: u8, max: usize) -> io::Result<&mut [u8]> { + let (actual_tag, msg) = read_message(&mut self.stream, &mut self.read, max).await?; + if actual_tag != tag { + return Err(io::Error::other(format!( + "incorrect message tag, expected {:?}, got {:?}", + tag as char, actual_tag as char, + ))); } + Ok(msg) + } + + /// Read a postgres password message, which will respect the max length requested. + /// This is not cancel safe. + pub async fn read_password_message(&mut self) -> io::Result<&mut [u8]> { + // passwords are usually pretty short + // and SASL SCRAM messages are no longer than 256 bytes in my testing + // (a few hashes and random bytes, encoded into base64). 
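// 512 doubles that observed upper bound while still keeping the per-connection read buffer small.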
+ const MAX_PASSWORD_LENGTH: usize = 512; + self.read_raw_expect(FE_PASSWORD_MESSAGE, MAX_PASSWORD_LENGTH) + .await } } @@ -84,6 +101,16 @@ pub struct ReportedError { error_kind: ErrorKind, } +impl ReportedError { + pub fn new(e: (impl UserFacingError + Into)) -> Self { + let error_kind = e.get_error_kind(); + Self { + source: e.into(), + error_kind, + } + } +} + impl std::fmt::Display for ReportedError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { self.source.fmt(f) @@ -102,109 +129,65 @@ impl ReportableError for ReportedError { } } -#[derive(Serialize, Deserialize, Debug)] -enum ErrorTag { - #[serde(rename = "proxy")] - Proxy, - #[serde(rename = "compute")] - Compute, - #[serde(rename = "client")] - Client, - #[serde(rename = "controlplane")] - ControlPlane, - #[serde(rename = "other")] - Other, -} - -impl From for ErrorTag { - fn from(error_kind: ErrorKind) -> Self { - match error_kind { - ErrorKind::User => Self::Client, - ErrorKind::ClientDisconnect => Self::Client, - ErrorKind::RateLimit => Self::Proxy, - ErrorKind::ServiceRateLimit => Self::Proxy, // considering rate limit as proxy error for SLI - ErrorKind::Quota => Self::Proxy, - ErrorKind::Service => Self::Proxy, - ErrorKind::ControlPlane => Self::ControlPlane, - ErrorKind::Postgres => Self::Other, - ErrorKind::Compute => Self::Compute, - } - } -} - -#[derive(Serialize, Deserialize, Debug)] -#[serde(rename_all = "snake_case")] -struct ProbeErrorData { - tag: ErrorTag, - msg: String, - cold_start_info: Option, -} - impl PqStream { - /// Write the message into an internal buffer, but don't flush the underlying stream. - pub(crate) fn write_message_noflush( - &mut self, - message: &BeMessage<'_>, - ) -> io::Result<&mut Self> { - self.framed - .write_message(message) - .map_err(ProtocolError::into_io_error)?; - Ok(self) + /// Tell the client that we are willing to accept SSL. + /// This is not cancel safe + pub async fn accept_tls(mut self) -> io::Result { + // S for SSL. + self.write.encryption(b'S'); + self.flush().await?; + Ok(self.stream) } - /// Write the message into an internal buffer and flush it. - pub async fn write_message(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> { - self.write_message_noflush(message)?; - self.flush().await?; - Ok(self) + /// Assert that we are using direct TLS. + pub fn accept_direct_tls(self) -> S { + self.stream + } + + /// Write a raw message to the internal buffer. + pub fn write_raw(&mut self, size_hint: usize, tag: u8, f: impl FnOnce(&mut Vec)) { + self.write.write_raw(size_hint, tag, f); + } + + /// Write the message into an internal buffer + pub fn write_message(&mut self, message: BeMessage<'_>) { + message.write_message(&mut self.write); } /// Flush the output buffer into the underlying stream. - pub(crate) async fn flush(&mut self) -> io::Result<&mut Self> { - self.framed.flush().await?; - Ok(self) + /// + /// This is cancel safe. + pub async fn flush(&mut self) -> io::Result<()> { + self.stream.write_all_buf(&mut self.write).await?; + self.write.reset(); + + self.stream.flush().await?; + + Ok(()) } - /// Writes message with the given error kind to the stream. 
- /// Used only for probe queries - async fn write_format_message( - &mut self, - msg: &str, - error_kind: ErrorKind, - ctx: Option<&crate::context::RequestContext>, - ) -> String { - let formatted_msg = match ctx { - Some(ctx) if ctx.get_testodrome_id().is_some() => { - serde_json::to_string(&ProbeErrorData { - tag: ErrorTag::from(error_kind), - msg: msg.to_string(), - cold_start_info: Some(ctx.cold_start_info()), - }) - .unwrap_or_default() - } - _ => msg.to_string(), - }; - - // already error case, ignore client IO error - self.write_message(&BeMessage::ErrorResponse(&formatted_msg, None)) - .await - .inspect_err(|e| debug!("write_message failed: {e}")) - .ok(); - - formatted_msg + /// Flush the output buffer into the underlying stream. + /// + /// This is cancel safe. + pub async fn flush_and_into_inner(mut self) -> io::Result { + self.flush().await?; + Ok(self.stream) } - /// Write the error message using [`Self::write_format_message`], then re-throw it. - /// Allowing string literals is safe under the assumption they might not contain any runtime info. - /// This method exists due to `&str` not implementing `Into`. + /// Write the error message to the client, then re-throw it. + /// + /// Trait [`UserFacingError`] acts as an allowlist for error types. /// If `ctx` is provided and has testodrome_id set, error messages will be prefixed according to error kind. - pub async fn throw_error_str( + pub(crate) async fn throw_error( &mut self, - msg: &'static str, - error_kind: ErrorKind, + error: E, ctx: Option<&crate::context::RequestContext>, - ) -> Result { - self.write_format_message(msg, error_kind, ctx).await; + ) -> ReportedError + where + E: UserFacingError + Into, + { + let error_kind = error.get_error_kind(); + let msg = error.to_string_client(); if error_kind != ErrorKind::RateLimit && error_kind != ErrorKind::User { tracing::info!( @@ -214,39 +197,39 @@ impl PqStream { ); } - Err(ReportedError { - source: anyhow::anyhow!(msg), - error_kind, - }) - } - - /// Write the error message using [`Self::write_format_message`], then re-throw it. - /// Trait [`UserFacingError`] acts as an allowlist for error types. - /// If `ctx` is provided and has testodrome_id set, error messages will be prefixed according to error kind. - pub(crate) async fn throw_error( - &mut self, - error: E, - ctx: Option<&crate::context::RequestContext>, - ) -> Result - where - E: UserFacingError + Into, - { - let error_kind = error.get_error_kind(); - let msg = error.to_string_client(); - self.write_format_message(&msg, error_kind, ctx).await; - if error_kind != ErrorKind::RateLimit && error_kind != ErrorKind::User { - tracing::info!( - kind=error_kind.to_metric_label(), - error=%error, - msg, - "forwarding error to user", - ); + let probe_msg; + let mut msg = &*msg; + if let Some(ctx) = ctx { + if ctx.get_testodrome_id().is_some() { + let tag = match error_kind { + ErrorKind::User => "client", + ErrorKind::ClientDisconnect => "client", + ErrorKind::RateLimit => "proxy", + ErrorKind::ServiceRateLimit => "proxy", + ErrorKind::Quota => "proxy", + ErrorKind::Service => "proxy", + ErrorKind::ControlPlane => "controlplane", + ErrorKind::Postgres => "other", + ErrorKind::Compute => "compute", + }; + probe_msg = typed_json::json!({ + "tag": tag, + "msg": msg, + "cold_start_info": ctx.cold_start_info(), + }) + .to_string(); + msg = &probe_msg; + } } - Err(ReportedError { - source: anyhow::anyhow!(error), - error_kind, - }) + // TODO: either preserve the error code from postgres, or assign error codes to proxy errors. 
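// For now, every proxy-generated error goes out as SQLSTATE XX000 (internal_error).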
+ self.write.write_error(msg, SQLSTATE_INTERNAL_ERROR); + + self.flush() + .await + .unwrap_or_else(|e| tracing::debug!("write_message failed: {e}")); + + ReportedError::new(error) } } diff --git a/test_runner/regress/test_pageserver_secondary.py b/test_runner/regress/test_pageserver_secondary.py index e5908de363..8d18311f3d 100644 --- a/test_runner/regress/test_pageserver_secondary.py +++ b/test_runner/regress/test_pageserver_secondary.py @@ -124,6 +124,9 @@ def test_location_conf_churn(neon_env_builder: NeonEnvBuilder, make_httpserver, ".*downloading failed, possibly for shutdown", # {tenant_id=... timeline_id=...}:handle_pagerequests:handle_get_page_at_lsn_request{rel=1664/0/1260 blkno=0 req_lsn=0/149F0D8}: error reading relation or page version: Not found: will not become active. Current state: Stopping\n' ".*page_service.*will not become active.*", + # the following errors are possible when pageserver tries to ingest wal records despite being in unreadable state + ".*wal_connection_manager.*layer file download failed: No file found.*", + ".*wal_connection_manager.*could not ingest record.*", ] )
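For orientation, the pqproto + PqStream API introduced above is consumed roughly as follows (a minimal sketch, not part of the patch; `socket` stands for any AsyncRead + AsyncWrite connection and the branch bodies are elided):

    let (mut stream, startup) = PqStream::parse_startup(socket).await?;
    match startup {
        // plain SSLRequest: reply 'S' (accept_tls) or 'N' (reject_encryption)
        FeStartupPacket::SslRequest { direct: None } => { /* ... */ }
        // direct TLS: the 8 bytes already read are part of a ClientHello and must be replayed
        FeStartupPacket::SslRequest { direct: Some(_prefix) } => { /* ... */ }
        // v3.0 startup: proceed to authentication with the supplied parameters
        FeStartupPacket::StartupMessage { version: _, params: _ } => { /* ... */ }
        // cancellation: the key is opaque to the proxy; look it up and forward the request
        FeStartupPacket::CancelRequest(_key) => { /* ... */ }
        // GSSAPI encryption is not supported; reply 'N'
        FeStartupPacket::GssEncRequest => { /* ... */ }
    }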