From ec991877f451893d81db5856f18ae65070baa211 Mon Sep 17 00:00:00 2001 From: Erik Grinaker Date: Mon, 26 May 2025 10:27:48 +0200 Subject: [PATCH 01/48] pageserver: add gRPC server (#11972) ## Problem We want to expose the page service over gRPC, for use with the communicator. Requires #11995. Touches #11728. ## Summary of changes This patch wires up a gRPC server in the Pageserver, using Tonic. It does not yet implement the actual page service. * Adds `listen_grpc_addr` and `grpc_auth_type` config options (disabled by default). * Enables gRPC by default with `neon_local`. * Stub implementation of `page_api.PageService`, returning unimplemented errors. * gRPC reflection service for use with e.g. `grpcurl`. Subsequent PRs will implement the actual page service, including authentication and observability. Notably, TLS support is not yet implemented. Certificate reloading requires us to reimplement the entire Tonic gRPC server. --- Cargo.lock | 21 ++++ Cargo.toml | 3 +- control_plane/safekeepers.conf | 2 + control_plane/simple.conf | 2 + control_plane/src/bin/neon_local.rs | 4 + control_plane/src/local_env.rs | 16 +++ control_plane/src/pageserver.rs | 4 +- libs/pageserver_api/src/config.rs | 6 + pageserver/Cargo.toml | 3 + pageserver/src/bin/pageserver.rs | 69 ++++++++--- pageserver/src/config.rs | 19 ++- pageserver/src/lib.rs | 11 ++ pageserver/src/page_service.rs | 167 +++++++++++++++++++++++++- pageserver/src/task_mgr.rs | 7 +- test_runner/fixtures/neon_fixtures.py | 2 + workspace_hack/Cargo.toml | 6 +- 16 files changed, 312 insertions(+), 30 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 422af2c97e..ddca5bbd3f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4321,6 +4321,7 @@ dependencies = [ "pageserver_api", "pageserver_client", "pageserver_compaction", + "pageserver_page_api", "pem", "pin-project-lite", "postgres-protocol", @@ -4363,6 +4364,8 @@ dependencies = [ "tokio-tar", "tokio-util", "toml_edit", + "tonic 0.13.1", + "tonic-reflection", "tracing", "tracing-utils", "twox-hash", @@ -7520,8 +7523,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7e581ba15a835f4d9ea06c55ab1bd4dce26fc53752c69a04aac00703bfb49ba9" dependencies = [ "async-trait", + "axum", "base64 0.22.1", "bytes", + "h2 0.4.4", "http 1.1.0", "http-body 1.0.0", "http-body-util", @@ -7532,6 +7537,7 @@ dependencies = [ "pin-project", "prost 0.13.5", "rustls-native-certs 0.8.0", + "socket2", "tokio", "tokio-rustls 0.26.2", "tokio-stream", @@ -7555,6 +7561,19 @@ dependencies = [ "syn 2.0.100", ] +[[package]] +name = "tonic-reflection" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f9687bd5bfeafebdded2356950f278bba8226f0b32109537c4253406e09aafe1" +dependencies = [ + "prost 0.13.5", + "prost-types 0.13.3", + "tokio", + "tokio-stream", + "tonic 0.13.1", +] + [[package]] name = "tower" version = "0.4.13" @@ -8526,6 +8545,8 @@ dependencies = [ "ahash", "anstream", "anyhow", + "axum", + "axum-core", "base64 0.13.1", "base64 0.21.7", "base64ct", diff --git a/Cargo.toml b/Cargo.toml index c8e2c38c85..d2c8e86bd4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -199,7 +199,8 @@ tokio-tar = "0.3" tokio-util = { version = "0.7.10", features = ["io", "rt"] } toml = "0.8" toml_edit = "0.22" -tonic = { version = "0.13.1", default-features = false, features = ["channel", "codegen", "prost", "tls-ring", "tls-native-roots"] } +tonic = { version = "0.13.1", default-features = false, features = ["channel", "codegen", "prost", "router", "server", "tls-ring", "tls-native-roots"] } +tonic-reflection = { version = "0.13.1", features = ["server"] } tower = { version = "0.5.2", default-features = false } tower-http = { version = "0.6.2", features = ["auth", "request-id", "trace"] } diff --git a/control_plane/safekeepers.conf b/control_plane/safekeepers.conf index 576cc4a3a9..a73e274dfa 100644 --- a/control_plane/safekeepers.conf +++ b/control_plane/safekeepers.conf @@ -2,8 +2,10 @@ [pageserver] listen_pg_addr = '127.0.0.1:64000' listen_http_addr = '127.0.0.1:9898' +listen_grpc_addr = '127.0.0.1:51051' pg_auth_type = 'Trust' http_auth_type = 'Trust' +grpc_auth_type = 'Trust' [[safekeepers]] id = 1 diff --git a/control_plane/simple.conf b/control_plane/simple.conf index 0ad90a4618..1eb21f846e 100644 --- a/control_plane/simple.conf +++ b/control_plane/simple.conf @@ -4,8 +4,10 @@ id=1 listen_pg_addr = '127.0.0.1:64000' listen_http_addr = '127.0.0.1:9898' +listen_grpc_addr = '127.0.0.1:51051' pg_auth_type = 'Trust' http_auth_type = 'Trust' +grpc_auth_type = 'Trust' [[safekeepers]] id = 1 diff --git a/control_plane/src/bin/neon_local.rs b/control_plane/src/bin/neon_local.rs index 98ab6e5657..3bceef8fa7 100644 --- a/control_plane/src/bin/neon_local.rs +++ b/control_plane/src/bin/neon_local.rs @@ -32,6 +32,7 @@ use control_plane::storage_controller::{ }; use nix::fcntl::{Flock, FlockArg}; use pageserver_api::config::{ + DEFAULT_GRPC_LISTEN_PORT as DEFAULT_PAGESERVER_GRPC_PORT, DEFAULT_HTTP_LISTEN_PORT as DEFAULT_PAGESERVER_HTTP_PORT, DEFAULT_PG_LISTEN_PORT as DEFAULT_PAGESERVER_PG_PORT, }; @@ -1007,13 +1008,16 @@ fn handle_init(args: &InitCmdArgs) -> anyhow::Result { let pageserver_id = NodeId(DEFAULT_PAGESERVER_ID.0 + i as u64); let pg_port = DEFAULT_PAGESERVER_PG_PORT + i; let http_port = DEFAULT_PAGESERVER_HTTP_PORT + i; + let grpc_port = DEFAULT_PAGESERVER_GRPC_PORT + i; NeonLocalInitPageserverConf { id: pageserver_id, listen_pg_addr: format!("127.0.0.1:{pg_port}"), listen_http_addr: format!("127.0.0.1:{http_port}"), listen_https_addr: None, + listen_grpc_addr: Some(format!("127.0.0.1:{grpc_port}")), pg_auth_type: AuthType::Trust, http_auth_type: AuthType::Trust, + grpc_auth_type: AuthType::Trust, other: Default::default(), // Typical developer machines use disks with slow fsync, and we don't care // about data integrity: disable disk syncs. diff --git a/control_plane/src/local_env.rs b/control_plane/src/local_env.rs index 4a8892c6de..47b77f0720 100644 --- a/control_plane/src/local_env.rs +++ b/control_plane/src/local_env.rs @@ -278,8 +278,10 @@ pub struct PageServerConf { pub listen_pg_addr: String, pub listen_http_addr: String, pub listen_https_addr: Option, + pub listen_grpc_addr: Option, pub pg_auth_type: AuthType, pub http_auth_type: AuthType, + pub grpc_auth_type: AuthType, pub no_sync: bool, } @@ -290,8 +292,10 @@ impl Default for PageServerConf { listen_pg_addr: String::new(), listen_http_addr: String::new(), listen_https_addr: None, + listen_grpc_addr: None, pg_auth_type: AuthType::Trust, http_auth_type: AuthType::Trust, + grpc_auth_type: AuthType::Trust, no_sync: false, } } @@ -306,8 +310,10 @@ pub struct NeonLocalInitPageserverConf { pub listen_pg_addr: String, pub listen_http_addr: String, pub listen_https_addr: Option, + pub listen_grpc_addr: Option, pub pg_auth_type: AuthType, pub http_auth_type: AuthType, + pub grpc_auth_type: AuthType, #[serde(default, skip_serializing_if = "std::ops::Not::not")] pub no_sync: bool, #[serde(flatten)] @@ -321,8 +327,10 @@ impl From<&NeonLocalInitPageserverConf> for PageServerConf { listen_pg_addr, listen_http_addr, listen_https_addr, + listen_grpc_addr, pg_auth_type, http_auth_type, + grpc_auth_type, no_sync, other: _, } = conf; @@ -331,7 +339,9 @@ impl From<&NeonLocalInitPageserverConf> for PageServerConf { listen_pg_addr: listen_pg_addr.clone(), listen_http_addr: listen_http_addr.clone(), listen_https_addr: listen_https_addr.clone(), + listen_grpc_addr: listen_grpc_addr.clone(), pg_auth_type: *pg_auth_type, + grpc_auth_type: *grpc_auth_type, http_auth_type: *http_auth_type, no_sync: *no_sync, } @@ -707,8 +717,10 @@ impl LocalEnv { listen_pg_addr: String, listen_http_addr: String, listen_https_addr: Option, + listen_grpc_addr: Option, pg_auth_type: AuthType, http_auth_type: AuthType, + grpc_auth_type: AuthType, #[serde(default)] no_sync: bool, } @@ -732,8 +744,10 @@ impl LocalEnv { listen_pg_addr, listen_http_addr, listen_https_addr, + listen_grpc_addr, pg_auth_type, http_auth_type, + grpc_auth_type, no_sync, } = config_toml; let IdentityTomlSubset { @@ -750,8 +764,10 @@ impl LocalEnv { listen_pg_addr, listen_http_addr, listen_https_addr, + listen_grpc_addr, pg_auth_type, http_auth_type, + grpc_auth_type, no_sync, }; pageservers.push(conf); diff --git a/control_plane/src/pageserver.rs b/control_plane/src/pageserver.rs index 756f2b02db..29314dab9e 100644 --- a/control_plane/src/pageserver.rs +++ b/control_plane/src/pageserver.rs @@ -129,7 +129,9 @@ impl PageServerNode { )); } - if conf.http_auth_type != AuthType::Trust || conf.pg_auth_type != AuthType::Trust { + if [conf.http_auth_type, conf.pg_auth_type, conf.grpc_auth_type] + .contains(&AuthType::NeonJWT) + { // Keys are generated in the toplevel repo dir, pageservers' workdirs // are one level below that, so refer to keys with ../ overrides.push("auth_validation_public_key_path='../auth_public_key.pem'".to_owned()); diff --git a/libs/pageserver_api/src/config.rs b/libs/pageserver_api/src/config.rs index 0fb2ff38ff..daec65ce2d 100644 --- a/libs/pageserver_api/src/config.rs +++ b/libs/pageserver_api/src/config.rs @@ -8,6 +8,8 @@ pub const DEFAULT_PG_LISTEN_PORT: u16 = 64000; pub const DEFAULT_PG_LISTEN_ADDR: &str = formatcp!("127.0.0.1:{DEFAULT_PG_LISTEN_PORT}"); pub const DEFAULT_HTTP_LISTEN_PORT: u16 = 9898; pub const DEFAULT_HTTP_LISTEN_ADDR: &str = formatcp!("127.0.0.1:{DEFAULT_HTTP_LISTEN_PORT}"); +// TODO: gRPC is disabled by default for now, but the port is used in neon_local. +pub const DEFAULT_GRPC_LISTEN_PORT: u16 = 51051; // storage-broker already uses 50051 use std::collections::HashMap; use std::num::{NonZeroU64, NonZeroUsize}; @@ -104,6 +106,7 @@ pub struct ConfigToml { pub listen_pg_addr: String, pub listen_http_addr: String, pub listen_https_addr: Option, + pub listen_grpc_addr: Option, pub ssl_key_file: Utf8PathBuf, pub ssl_cert_file: Utf8PathBuf, #[serde(with = "humantime_serde")] @@ -123,6 +126,7 @@ pub struct ConfigToml { pub http_auth_type: AuthType, #[serde_as(as = "serde_with::DisplayFromStr")] pub pg_auth_type: AuthType, + pub grpc_auth_type: AuthType, pub auth_validation_public_key_path: Option, pub remote_storage: Option, pub tenant_config: TenantConfigToml, @@ -588,6 +592,7 @@ impl Default for ConfigToml { listen_pg_addr: (DEFAULT_PG_LISTEN_ADDR.to_string()), listen_http_addr: (DEFAULT_HTTP_LISTEN_ADDR.to_string()), listen_https_addr: (None), + listen_grpc_addr: None, // TODO: default to 127.0.0.1:51051 ssl_key_file: Utf8PathBuf::from(DEFAULT_SSL_KEY_FILE), ssl_cert_file: Utf8PathBuf::from(DEFAULT_SSL_CERT_FILE), ssl_cert_reload_period: Duration::from_secs(60), @@ -604,6 +609,7 @@ impl Default for ConfigToml { pg_distrib_dir: None, // Utf8PathBuf::from("./pg_install"), // TODO: formely, this was std::env::current_dir() http_auth_type: (AuthType::Trust), pg_auth_type: (AuthType::Trust), + grpc_auth_type: (AuthType::Trust), auth_validation_public_key_path: (None), remote_storage: None, broker_endpoint: (storage_broker::DEFAULT_ENDPOINT diff --git a/pageserver/Cargo.toml b/pageserver/Cargo.toml index 6a9a5a292a..1f5cc89b33 100644 --- a/pageserver/Cargo.toml +++ b/pageserver/Cargo.toml @@ -43,6 +43,7 @@ nix.workspace = true num_cpus.workspace = true num-traits.workspace = true once_cell.workspace = true +pageserver_page_api.workspace = true pin-project-lite.workspace = true postgres_backend.workspace = true postgres-protocol.workspace = true @@ -71,6 +72,8 @@ tokio-rustls.workspace = true tokio-stream.workspace = true tokio-util.workspace = true toml_edit = { workspace = true, features = [ "serde" ] } +tonic.workspace = true +tonic-reflection.workspace = true tracing.workspace = true tracing-utils.workspace = true url.workspace = true diff --git a/pageserver/src/bin/pageserver.rs b/pageserver/src/bin/pageserver.rs index 6001ea0345..8d76d0d678 100644 --- a/pageserver/src/bin/pageserver.rs +++ b/pageserver/src/bin/pageserver.rs @@ -388,23 +388,30 @@ fn start_pageserver( // We need to release the lock file only when the process exits. std::mem::forget(lock_file); - // Bind the HTTP and libpq ports early, so that if they are in use by some other - // process, we error out early. - let http_addr = &conf.listen_http_addr; - info!("Starting pageserver http handler on {http_addr}"); - let http_listener = tcp_listener::bind(http_addr)?; + // Bind the HTTP, libpq, and gRPC ports early, to error out if they are + // already in use. + info!( + "Starting pageserver http handler on {} with auth {:#?}", + conf.listen_http_addr, conf.http_auth_type + ); + let http_listener = tcp_listener::bind(&conf.listen_http_addr)?; let https_listener = match conf.listen_https_addr.as_ref() { Some(https_addr) => { - info!("Starting pageserver https handler on {https_addr}"); + info!( + "Starting pageserver https handler on {https_addr} with auth {:#?}", + conf.http_auth_type + ); Some(tcp_listener::bind(https_addr)?) } None => None, }; - let pg_addr = &conf.listen_pg_addr; - info!("Starting pageserver pg protocol handler on {pg_addr}"); - let pageserver_listener = tcp_listener::bind(pg_addr)?; + info!( + "Starting pageserver pg protocol handler on {} with auth {:#?}", + conf.listen_pg_addr, conf.pg_auth_type, + ); + let pageserver_listener = tcp_listener::bind(&conf.listen_pg_addr)?; // Enable SO_KEEPALIVE on the socket, to detect dead connections faster. // These are configured via net.ipv4.tcp_keepalive_* sysctls. @@ -413,6 +420,15 @@ fn start_pageserver( // support enabling keepalives while using the default OS sysctls. setsockopt(&pageserver_listener, sockopt::KeepAlive, &true)?; + let mut grpc_listener = None; + if let Some(grpc_addr) = &conf.listen_grpc_addr { + info!( + "Starting pageserver gRPC handler on {grpc_addr} with auth {:#?}", + conf.grpc_auth_type + ); + grpc_listener = Some(tcp_listener::bind(grpc_addr).map_err(|e| anyhow!("{e}"))?); + } + // Launch broker client // The storage_broker::connect call needs to happen inside a tokio runtime thread. let broker_client = WALRECEIVER_RUNTIME @@ -440,7 +456,8 @@ fn start_pageserver( // Initialize authentication for incoming connections let http_auth; let pg_auth; - if conf.http_auth_type == AuthType::NeonJWT || conf.pg_auth_type == AuthType::NeonJWT { + let grpc_auth; + if [conf.http_auth_type, conf.pg_auth_type, conf.grpc_auth_type].contains(&AuthType::NeonJWT) { // unwrap is ok because check is performed when creating config, so path is set and exists let key_path = conf.auth_validation_public_key_path.as_ref().unwrap(); info!("Loading public key(s) for verifying JWT tokens from {key_path:?}"); @@ -448,20 +465,23 @@ fn start_pageserver( let jwt_auth = JwtAuth::from_key_path(key_path)?; let auth: Arc = Arc::new(SwappableJwtAuth::new(jwt_auth)); - http_auth = match &conf.http_auth_type { + http_auth = match conf.http_auth_type { AuthType::Trust => None, AuthType::NeonJWT => Some(auth.clone()), }; - pg_auth = match &conf.pg_auth_type { + pg_auth = match conf.pg_auth_type { + AuthType::Trust => None, + AuthType::NeonJWT => Some(auth.clone()), + }; + grpc_auth = match conf.grpc_auth_type { AuthType::Trust => None, AuthType::NeonJWT => Some(auth), }; } else { http_auth = None; pg_auth = None; + grpc_auth = None; } - info!("Using auth for http API: {:#?}", conf.http_auth_type); - info!("Using auth for pg connections: {:#?}", conf.pg_auth_type); let tls_server_config = if conf.listen_https_addr.is_some() || conf.enable_tls_page_service_api { @@ -776,9 +796,27 @@ fn start_pageserver( } else { None }, - basebackup_cache, + basebackup_cache.clone(), ); + // Spawn a Pageserver gRPC server task. It will spawn separate tasks for + // each stream/request. + // + // TODO: this uses a separate Tokio runtime for the page service. If we want + // other gRPC services, they will need their own port and runtime. Is this + // necessary? + let mut page_service_grpc = None; + if let Some(grpc_listener) = grpc_listener { + page_service_grpc = Some(page_service::spawn_grpc( + conf, + tenant_manager.clone(), + grpc_auth, + otel_guard.as_ref().map(|g| g.dispatch.clone()), + grpc_listener, + basebackup_cache, + )?); + } + // All started up! Now just sit and wait for shutdown signal. BACKGROUND_RUNTIME.block_on(async move { let signal_token = CancellationToken::new(); @@ -797,6 +835,7 @@ fn start_pageserver( http_endpoint_listener, https_endpoint_listener, page_service, + page_service_grpc, consumption_metrics_tasks, disk_usage_eviction_task, &tenant_manager, diff --git a/pageserver/src/config.rs b/pageserver/src/config.rs index e8b3b7b3ab..e8af548ec4 100644 --- a/pageserver/src/config.rs +++ b/pageserver/src/config.rs @@ -58,11 +58,16 @@ pub struct PageServerConf { pub listen_http_addr: String, /// Example: 127.0.0.1:9899 pub listen_https_addr: Option, + /// If set, expose a gRPC API on this address. + /// Example: 127.0.0.1:51051 + /// + /// EXPERIMENTAL: this protocol is unstable and under active development. + pub listen_grpc_addr: Option, - /// Path to a file with certificate's private key for https API. + /// Path to a file with certificate's private key for https and gRPC API. /// Default: server.key pub ssl_key_file: Utf8PathBuf, - /// Path to a file with a X509 certificate for https API. + /// Path to a file with a X509 certificate for https and gRPC API. /// Default: server.crt pub ssl_cert_file: Utf8PathBuf, /// Period to reload certificate and private key from files. @@ -100,6 +105,8 @@ pub struct PageServerConf { pub http_auth_type: AuthType, /// authentication method for libpq connections from compute pub pg_auth_type: AuthType, + /// authentication method for gRPC connections from compute + pub grpc_auth_type: AuthType, /// Path to a file or directory containing public key(s) for verifying JWT tokens. /// Used for both mgmt and compute auth, if enabled. pub auth_validation_public_key_path: Option, @@ -355,6 +362,7 @@ impl PageServerConf { listen_pg_addr, listen_http_addr, listen_https_addr, + listen_grpc_addr, ssl_key_file, ssl_cert_file, ssl_cert_reload_period, @@ -369,6 +377,7 @@ impl PageServerConf { pg_distrib_dir, http_auth_type, pg_auth_type, + grpc_auth_type, auth_validation_public_key_path, remote_storage, broker_endpoint, @@ -423,6 +432,7 @@ impl PageServerConf { listen_pg_addr, listen_http_addr, listen_https_addr, + listen_grpc_addr, ssl_key_file, ssl_cert_file, ssl_cert_reload_period, @@ -435,6 +445,7 @@ impl PageServerConf { max_file_descriptors, http_auth_type, pg_auth_type, + grpc_auth_type, auth_validation_public_key_path, remote_storage_config: remote_storage, broker_endpoint, @@ -531,7 +542,9 @@ impl PageServerConf { // custom validation code that covers more than one field in isolation // ------------------------------------------------------------ - if conf.http_auth_type == AuthType::NeonJWT || conf.pg_auth_type == AuthType::NeonJWT { + if [conf.http_auth_type, conf.pg_auth_type, conf.grpc_auth_type] + .contains(&AuthType::NeonJWT) + { let auth_validation_public_key_path = conf .auth_validation_public_key_path .get_or_insert_with(|| workdir.join("auth_public_key.pem")); diff --git a/pageserver/src/lib.rs b/pageserver/src/lib.rs index 71d9c6603f..25461c23ab 100644 --- a/pageserver/src/lib.rs +++ b/pageserver/src/lib.rs @@ -84,6 +84,7 @@ pub async fn shutdown_pageserver( http_listener: HttpEndpointListener, https_listener: Option, page_service: page_service::Listener, + grpc_task: Option, consumption_metrics_worker: ConsumptionMetricsTasks, disk_usage_eviction_task: Option, tenant_manager: &TenantManager, @@ -177,6 +178,16 @@ pub async fn shutdown_pageserver( ) .await; + // Shut down the gRPC server task, including request handlers. + if let Some(grpc_task) = grpc_task { + timed( + grpc_task.shutdown(), + "shutdown gRPC PageRequestHandler", + Duration::from_secs(3), + ) + .await; + } + // Shut down all the tenants. This flushes everything to disk and kills // the checkpoint and GC tasks. timed( diff --git a/pageserver/src/page_service.rs b/pageserver/src/page_service.rs index 69519dfa87..34dc158694 100644 --- a/pageserver/src/page_service.rs +++ b/pageserver/src/page_service.rs @@ -4,6 +4,7 @@ use std::borrow::Cow; use std::num::NonZeroUsize; use std::os::fd::AsRawFd; +use std::pin::Pin; use std::str::FromStr; use std::sync::Arc; use std::time::{Duration, Instant, SystemTime}; @@ -12,7 +13,7 @@ use std::{io, str}; use anyhow::{Context, bail}; use async_compression::tokio::write::GzipEncoder; use bytes::Buf; -use futures::FutureExt; +use futures::{FutureExt, Stream}; use itertools::Itertools; use jsonwebtoken::TokenData; use once_cell::sync::OnceCell; @@ -30,6 +31,7 @@ use pageserver_api::models::{ }; use pageserver_api::reltag::SlruKind; use pageserver_api::shard::TenantShardId; +use pageserver_page_api::proto; use postgres_backend::{ AuthType, PostgresBackend, PostgresBackendReader, QueryError, is_expected_io_error, }; @@ -51,9 +53,8 @@ use utils::simple_rcu::RcuReadGuard; use utils::sync::gate::{Gate, GateGuard}; use utils::sync::spsc_fold; -use crate::PERF_TRACE_TARGET; use crate::auth::check_permission; -use crate::basebackup::BasebackupError; +use crate::basebackup::{self, BasebackupError}; use crate::basebackup_cache::BasebackupCache; use crate::config::PageServerConf; use crate::context::{ @@ -75,7 +76,7 @@ use crate::tenant::mgr::{ use crate::tenant::storage_layer::IoConcurrency; use crate::tenant::timeline::{self, WaitLsnError}; use crate::tenant::{GetTimelineError, PageReconstructError, Timeline}; -use crate::{basebackup, timed_after_cancellation}; +use crate::{CancellableTask, PERF_TRACE_TARGET, timed_after_cancellation}; /// How long we may wait for a [`crate::tenant::mgr::TenantSlot::InProgress`]` and/or a [`crate::tenant::TenantShard`] which /// is not yet in state [`TenantState::Active`]. @@ -86,6 +87,26 @@ const ACTIVE_TENANT_TIMEOUT: Duration = Duration::from_millis(30000); /// Threshold at which to log slow GetPage requests. const LOG_SLOW_GETPAGE_THRESHOLD: Duration = Duration::from_secs(30); +/// The idle time before sending TCP keepalive probes for gRPC connections. The +/// interval and timeout between each probe is configured via sysctl. This +/// allows detecting dead connections sooner. +const GRPC_TCP_KEEPALIVE_TIME: Duration = Duration::from_secs(60); + +/// Whether to enable TCP nodelay for gRPC connections. This disables Nagle's +/// algorithm, which can cause latency spikes for small messages. +const GRPC_TCP_NODELAY: bool = true; + +/// The interval between HTTP2 keepalive pings. This allows shutting down server +/// tasks when clients are unresponsive. +const GRPC_HTTP2_KEEPALIVE_INTERVAL: Duration = Duration::from_secs(30); + +/// The timeout for HTTP2 keepalive pings. Should be <= GRPC_KEEPALIVE_INTERVAL. +const GRPC_HTTP2_KEEPALIVE_TIMEOUT: Duration = Duration::from_secs(20); + +/// Number of concurrent gRPC streams per TCP connection. We expect something +/// like 8 GetPage streams per connections, plus any unary requests. +const GRPC_MAX_CONCURRENT_STREAMS: u32 = 256; + /////////////////////////////////////////////////////////////////////////////// pub struct Listener { @@ -140,6 +161,83 @@ pub fn spawn( Listener { cancel, task } } +/// Spawns a gRPC server for the page service. +/// +/// TODO: this doesn't support TLS. We need TLS reloading via ReloadingCertificateResolver, so we +/// need to reimplement the TCP+TLS accept loop ourselves. +pub fn spawn_grpc( + conf: &'static PageServerConf, + tenant_manager: Arc, + auth: Option>, + perf_trace_dispatch: Option, + listener: std::net::TcpListener, + basebackup_cache: Arc, +) -> anyhow::Result { + let cancel = CancellationToken::new(); + let ctx = RequestContextBuilder::new(TaskKind::PageRequestHandler) + .download_behavior(DownloadBehavior::Download) + .perf_span_dispatch(perf_trace_dispatch) + .detached_child(); + let gate = Gate::default(); + + // Set up the TCP socket. We take a preconfigured TcpListener to bind the + // port early during startup. + let incoming = { + let _runtime = COMPUTE_REQUEST_RUNTIME.enter(); // required by TcpListener::from_std + listener.set_nonblocking(true)?; + tonic::transport::server::TcpIncoming::from(tokio::net::TcpListener::from_std(listener)?) + .with_nodelay(Some(GRPC_TCP_NODELAY)) + .with_keepalive(Some(GRPC_TCP_KEEPALIVE_TIME)) + }; + + // Set up the gRPC server. + // + // TODO: consider tuning window sizes. + // TODO: wire up tracing. + let mut server = tonic::transport::Server::builder() + .http2_keepalive_interval(Some(GRPC_HTTP2_KEEPALIVE_INTERVAL)) + .http2_keepalive_timeout(Some(GRPC_HTTP2_KEEPALIVE_TIMEOUT)) + .max_concurrent_streams(Some(GRPC_MAX_CONCURRENT_STREAMS)); + + // Main page service. + let page_service = proto::PageServiceServer::new(PageServerHandler::new( + tenant_manager, + auth, + PageServicePipeliningConfig::Serial, // TODO: unused with gRPC + conf.get_vectored_concurrent_io, + ConnectionPerfSpanFields::default(), + basebackup_cache, + ctx, + cancel.clone(), + gate.enter().expect("just created"), + )); + let server = server.add_service(page_service); + + // Reflection service for use with e.g. grpcurl. + let reflection_service = tonic_reflection::server::Builder::configure() + .register_encoded_file_descriptor_set(proto::FILE_DESCRIPTOR_SET) + .build_v1()?; + let server = server.add_service(reflection_service); + + // Spawn server task. + let task_cancel = cancel.clone(); + let task = COMPUTE_REQUEST_RUNTIME.spawn(task_mgr::exit_on_panic_or_error( + "grpc listener", + async move { + let result = server + .serve_with_incoming_shutdown(incoming, task_cancel.cancelled()) + .await; + if result.is_ok() { + // TODO: revisit shutdown logic once page service is implemented. + gate.close().await; + } + result + }, + )); + + Ok(CancellableTask { task, cancel }) +} + impl Listener { pub async fn stop_accepting(self) -> Connections { self.cancel.cancel(); @@ -259,7 +357,7 @@ type ConnectionHandlerResult = anyhow::Result<()>; /// Perf root spans start at the per-request level, after shard routing. /// This struct carries connection-level information to the root perf span definition. -#[derive(Clone)] +#[derive(Clone, Default)] struct ConnectionPerfSpanFields { peer_addr: String, application_name: Option, @@ -377,6 +475,11 @@ async fn page_service_conn_main( } } +/// Page service connection handler. +/// +/// TODO: for gRPC, this will be shared by all requests from all connections. +/// Decompose it into global state and per-connection/request state, and make +/// libpq-specific options (e.g. pipelining) separate. struct PageServerHandler { auth: Option>, claims: Option, @@ -3117,6 +3220,60 @@ where } } +/// Implements the page service over gRPC. +/// +/// TODO: not yet implemented, all methods return unimplemented. +#[tonic::async_trait] +impl proto::PageService for PageServerHandler { + type GetBaseBackupStream = Pin< + Box> + Send>, + >; + type GetPagesStream = + Pin> + Send>>; + + async fn check_rel_exists( + &self, + _: tonic::Request, + ) -> Result, tonic::Status> { + Err(tonic::Status::unimplemented("not implemented")) + } + + async fn get_base_backup( + &self, + _: tonic::Request, + ) -> Result, tonic::Status> { + Err(tonic::Status::unimplemented("not implemented")) + } + + async fn get_db_size( + &self, + _: tonic::Request, + ) -> Result, tonic::Status> { + Err(tonic::Status::unimplemented("not implemented")) + } + + async fn get_pages( + &self, + _: tonic::Request>, + ) -> Result, tonic::Status> { + Err(tonic::Status::unimplemented("not implemented")) + } + + async fn get_rel_size( + &self, + _: tonic::Request, + ) -> Result, tonic::Status> { + Err(tonic::Status::unimplemented("not implemented")) + } + + async fn get_slru_segment( + &self, + _: tonic::Request, + ) -> Result, tonic::Status> { + Err(tonic::Status::unimplemented("not implemented")) + } +} + impl From for QueryError { fn from(e: GetActiveTenantError) -> Self { match e { diff --git a/pageserver/src/task_mgr.rs b/pageserver/src/task_mgr.rs index 55272b2125..29897af642 100644 --- a/pageserver/src/task_mgr.rs +++ b/pageserver/src/task_mgr.rs @@ -276,9 +276,10 @@ pub enum TaskKind { // HTTP endpoint listener. HttpEndpointListener, - // Task that handles a single connection. A PageRequestHandler task - // starts detached from any particular tenant or timeline, but it can be - // associated with one later, after receiving a command from the client. + /// Task that handles a single page service connection. A PageRequestHandler + /// task starts detached from any particular tenant or timeline, but it can + /// be associated with one later, after receiving a command from the client. + /// Also used for the gRPC page service API, including the main server task. PageRequestHandler, /// Manages the WAL receiver connection for one timeline. diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py index 5c92f2e2d0..dda4d40a11 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -1224,6 +1224,7 @@ class NeonEnv: # Create config for pageserver http_auth_type = "NeonJWT" if config.auth_enabled else "Trust" pg_auth_type = "NeonJWT" if config.auth_enabled else "Trust" + grpc_auth_type = "NeonJWT" if config.auth_enabled else "Trust" for ps_id in range( self.BASE_PAGESERVER_ID, self.BASE_PAGESERVER_ID + config.num_pageservers ): @@ -1250,6 +1251,7 @@ class NeonEnv: else None, "pg_auth_type": pg_auth_type, "http_auth_type": http_auth_type, + "grpc_auth_type": grpc_auth_type, "availability_zone": availability_zone, # Disable pageserver disk syncs in tests: when running tests concurrently, this avoids # the pageserver taking a long time to start up due to syncfs flushing other tests' data diff --git a/workspace_hack/Cargo.toml b/workspace_hack/Cargo.toml index 9e1123ac0e..726d7c20c9 100644 --- a/workspace_hack/Cargo.toml +++ b/workspace_hack/Cargo.toml @@ -18,6 +18,8 @@ license.workspace = true ahash = { version = "0.8" } anstream = { version = "0.6" } anyhow = { version = "1", features = ["backtrace"] } +axum = { version = "0.8", features = ["ws"] } +axum-core = { version = "0.5", default-features = false, features = ["tracing"] } base64-594e8ee84c453af0 = { package = "base64", version = "0.13", features = ["alloc"] } base64-647d43efb71741da = { package = "base64", version = "0.21" } base64ct = { version = "1", default-features = false, features = ["std"] } @@ -52,7 +54,7 @@ hex = { version = "0.4", features = ["serde"] } hmac = { version = "0.12", default-features = false, features = ["reset"] } hyper-582f2526e08bb6a0 = { package = "hyper", version = "0.14", features = ["client", "http1", "http2", "runtime", "server", "stream"] } hyper-dff4ba8e3ae991db = { package = "hyper", version = "1", features = ["full"] } -hyper-util = { version = "0.1", features = ["client-legacy", "http1", "http2", "server", "service"] } +hyper-util = { version = "0.1", features = ["client-legacy", "server-auto", "service"] } indexmap = { version = "2", features = ["serde"] } itertools = { version = "0.12" } lazy_static = { version = "1", default-features = false, features = ["spin_no_std"] } @@ -98,7 +100,7 @@ tikv-jemalloc-sys = { version = "0.6", features = ["profiling", "stats", "unpref time = { version = "0.3", features = ["macros", "serde-well-known"] } tokio = { version = "1", features = ["full", "test-util"] } tokio-rustls = { version = "0.26", default-features = false, features = ["logging", "ring", "tls12"] } -tokio-stream = { version = "0.1" } +tokio-stream = { version = "0.1", features = ["net"] } tokio-util = { version = "0.7", features = ["codec", "compat", "io", "rt"] } toml_edit = { version = "0.22", features = ["serde"] } tower = { version = "0.5", default-features = false, features = ["balance", "buffer", "limit", "log"] } From a082f9814ad85248c8fcd152d34b20fc0fa1855a Mon Sep 17 00:00:00 2001 From: Erik Grinaker Date: Mon, 26 May 2025 12:24:45 +0200 Subject: [PATCH 02/48] pageserver: add gRPC authentication (#12010) ## Problem We need authentication for the gRPC server. Requires #11972. Touches #11728. ## Summary of changes Add two request interceptors that decode the tenant/timeline/shard metadata and authenticate the JWT token against them. --- pageserver/src/page_service.rs | 119 +++++++++++++++++++++++++++++++-- 1 file changed, 115 insertions(+), 4 deletions(-) diff --git a/pageserver/src/page_service.rs b/pageserver/src/page_service.rs index 34dc158694..06aa207f82 100644 --- a/pageserver/src/page_service.rs +++ b/pageserver/src/page_service.rs @@ -43,12 +43,14 @@ use strum_macros::IntoStaticStr; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufWriter}; use tokio::task::JoinHandle; use tokio_util::sync::CancellationToken; +use tonic::service::Interceptor as _; use tracing::*; use utils::auth::{Claims, Scope, SwappableJwtAuth}; use utils::failpoint_support; -use utils::id::{TenantId, TimelineId}; +use utils::id::{TenantId, TenantTimelineId, TimelineId}; use utils::logging::log_slow; use utils::lsn::Lsn; +use utils::shard::ShardIndex; use utils::simple_rcu::RcuReadGuard; use utils::sync::gate::{Gate, GateGuard}; use utils::sync::spsc_fold; @@ -200,9 +202,9 @@ pub fn spawn_grpc( .max_concurrent_streams(Some(GRPC_MAX_CONCURRENT_STREAMS)); // Main page service. - let page_service = proto::PageServiceServer::new(PageServerHandler::new( + let page_service_handler = PageServerHandler::new( tenant_manager, - auth, + auth.clone(), PageServicePipeliningConfig::Serial, // TODO: unused with gRPC conf.get_vectored_concurrent_io, ConnectionPerfSpanFields::default(), @@ -210,7 +212,18 @@ pub fn spawn_grpc( ctx, cancel.clone(), gate.enter().expect("just created"), - )); + ); + + let mut tenant_interceptor = TenantMetadataInterceptor; + let mut auth_interceptor = TenantAuthInterceptor::new(auth); + let interceptors = move |mut req: tonic::Request<()>| { + req = tenant_interceptor.call(req)?; + req = auth_interceptor.call(req)?; + Ok(req) + }; + + let page_service = + proto::PageServiceServer::with_interceptor(page_service_handler, interceptors); let server = server.add_service(page_service); // Reflection service for use with e.g. grpcurl. @@ -3290,6 +3303,104 @@ impl From for QueryError { } } +/// gRPC interceptor that decodes tenant metadata and stores it as request extensions of type +/// TenantTimelineId and ShardIndex. +/// +/// TODO: consider looking up the timeline handle here and storing it. +#[derive(Clone)] +struct TenantMetadataInterceptor; + +impl tonic::service::Interceptor for TenantMetadataInterceptor { + fn call(&mut self, mut req: tonic::Request<()>) -> Result, tonic::Status> { + // Decode the tenant ID. + let tenant_id = req + .metadata() + .get("neon-tenant-id") + .ok_or_else(|| tonic::Status::invalid_argument("missing neon-tenant-id"))? + .to_str() + .map_err(|_| tonic::Status::invalid_argument("invalid neon-tenant-id"))?; + let tenant_id = TenantId::from_str(tenant_id) + .map_err(|_| tonic::Status::invalid_argument("invalid neon-tenant-id"))?; + + // Decode the timeline ID. + let timeline_id = req + .metadata() + .get("neon-timeline-id") + .ok_or_else(|| tonic::Status::invalid_argument("missing neon-timeline-id"))? + .to_str() + .map_err(|_| tonic::Status::invalid_argument("invalid neon-timeline-id"))?; + let timeline_id = TimelineId::from_str(timeline_id) + .map_err(|_| tonic::Status::invalid_argument("invalid neon-timeline-id"))?; + + // Decode the shard ID. + let shard_index = req + .metadata() + .get("neon-shard-id") + .ok_or_else(|| tonic::Status::invalid_argument("missing neon-shard-id"))? + .to_str() + .map_err(|_| tonic::Status::invalid_argument("invalid neon-shard-id"))?; + let shard_index = ShardIndex::from_str(shard_index) + .map_err(|_| tonic::Status::invalid_argument("invalid neon-shard-id"))?; + + // Stash them in the request. + let extensions = req.extensions_mut(); + extensions.insert(TenantTimelineId::new(tenant_id, timeline_id)); + extensions.insert(shard_index); + + Ok(req) + } +} + +/// Authenticates gRPC page service requests. Must run after TenantMetadataInterceptor. +#[derive(Clone)] +struct TenantAuthInterceptor { + auth: Option>, +} + +impl TenantAuthInterceptor { + fn new(auth: Option>) -> Self { + Self { auth } + } +} + +impl tonic::service::Interceptor for TenantAuthInterceptor { + fn call(&mut self, req: tonic::Request<()>) -> Result, tonic::Status> { + // Do nothing if auth is disabled. + let Some(auth) = self.auth.as_ref() else { + return Ok(req); + }; + + // Fetch the tenant ID that's been set by TenantMetadataInterceptor. + let ttid = req + .extensions() + .get::() + .expect("TenantMetadataInterceptor must run before TenantAuthInterceptor"); + + // Fetch and decode the JWT token. + let jwt = req + .metadata() + .get("authorization") + .ok_or_else(|| tonic::Status::unauthenticated("no authorization header"))? + .to_str() + .map_err(|_| tonic::Status::invalid_argument("invalid authorization header"))? + .strip_prefix("Bearer ") + .ok_or_else(|| tonic::Status::invalid_argument("invalid authorization header"))? + .trim(); + let jwtdata: TokenData = auth + .decode(jwt) + .map_err(|err| tonic::Status::invalid_argument(format!("invalid JWT token: {err}")))?; + let claims = jwtdata.claims; + + // Check if the token is valid for this tenant. + check_permission(&claims, Some(ttid.tenant_id)) + .map_err(|err| tonic::Status::permission_denied(err.to_string()))?; + + // TODO: consider stashing the claims in the request extensions, if needed. + + Ok(req) + } +} + #[derive(Debug, thiserror::Error)] pub(crate) enum GetActiveTimelineError { #[error(transparent)] From 7cd0defaf0b71d72b8954915317dbad65f730143 Mon Sep 17 00:00:00 2001 From: Erik Grinaker Date: Mon, 26 May 2025 13:01:36 +0200 Subject: [PATCH 03/48] page_api: add Rust domain types (#11999) ## Problem For the gRPC Pageserver API, we should convert the Protobuf types to stricter, canonical Rust types. Touches https://github.com/neondatabase/neon/issues/11728. ## Summary of changes Adds Rust domain types that mirror the Protobuf types, with conversion and validation. --- Cargo.lock | 6 + pageserver/page_api/Cargo.toml | 6 + pageserver/page_api/src/lib.rs | 4 + pageserver/page_api/src/model.rs | 581 +++++++++++++++++++++++++++++++ 4 files changed, 597 insertions(+) create mode 100644 pageserver/page_api/src/model.rs diff --git a/Cargo.lock b/Cargo.lock index ddca5bbd3f..21c863ff95 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4458,9 +4458,15 @@ dependencies = [ name = "pageserver_page_api" version = "0.1.0" dependencies = [ + "bytes", + "pageserver_api", + "postgres_ffi", "prost 0.13.5", + "smallvec", + "thiserror 1.0.69", "tonic 0.13.1", "tonic-build", + "utils", "workspace_hack", ] diff --git a/pageserver/page_api/Cargo.toml b/pageserver/page_api/Cargo.toml index c237949226..4f62c77eb2 100644 --- a/pageserver/page_api/Cargo.toml +++ b/pageserver/page_api/Cargo.toml @@ -5,8 +5,14 @@ edition.workspace = true license.workspace = true [dependencies] +bytes.workspace = true +pageserver_api.workspace = true +postgres_ffi.workspace = true prost.workspace = true +smallvec.workspace = true +thiserror.workspace = true tonic.workspace = true +utils.workspace = true workspace_hack.workspace = true [build-dependencies] diff --git a/pageserver/page_api/src/lib.rs b/pageserver/page_api/src/lib.rs index 0b68d03aaa..f515f27f3e 100644 --- a/pageserver/page_api/src/lib.rs +++ b/pageserver/page_api/src/lib.rs @@ -17,3 +17,7 @@ pub mod proto { pub use page_service_client::PageServiceClient; pub use page_service_server::{PageService, PageServiceServer}; } + +mod model; + +pub use model::*; diff --git a/pageserver/page_api/src/model.rs b/pageserver/page_api/src/model.rs new file mode 100644 index 0000000000..a83d0a5947 --- /dev/null +++ b/pageserver/page_api/src/model.rs @@ -0,0 +1,581 @@ +//! Structs representing the canonical page service API. +//! +//! These mirror the autogenerated Protobuf types. The differences are: +//! +//! - Types that are in fact required by the API are not Options. The protobuf "required" +//! attribute is deprecated and 'prost' marks a lot of members as optional because of that. +//! (See for a gripe on this) +//! +//! - Use more precise datatypes, e.g. Lsn and uints shorter than 32 bits. +//! +//! - Validate protocol invariants, via try_from() and try_into(). + +use bytes::Bytes; +use postgres_ffi::Oid; +use smallvec::SmallVec; +// TODO: split out Lsn, RelTag, SlruKind, Oid and other basic types to a separate crate, to avoid +// pulling in all of their other crate dependencies when building the client. +use utils::lsn::Lsn; + +use crate::proto; + +/// A protocol error. Typically returned via try_from() or try_into(). +#[derive(thiserror::Error, Debug)] +pub enum ProtocolError { + #[error("field '{0}' has invalid value '{1}'")] + Invalid(&'static str, String), + #[error("required field '{0}' is missing")] + Missing(&'static str), +} + +impl ProtocolError { + /// Helper to generate a new ProtocolError::Invalid for the given field and value. + pub fn invalid(field: &'static str, value: impl std::fmt::Debug) -> Self { + Self::Invalid(field, format!("{value:?}")) + } +} + +/// The LSN a request should read at. +#[derive(Clone, Copy, Debug)] +pub struct ReadLsn { + /// The request's read LSN. + pub request_lsn: Lsn, + /// If given, the caller guarantees that the page has not been modified since this LSN. Must be + /// smaller than or equal to request_lsn. This allows the Pageserver to serve an old page + /// without waiting for the request LSN to arrive. Valid for all request types. + /// + /// It is undefined behaviour to make a request such that the page was, in fact, modified + /// between request_lsn and not_modified_since_lsn. The Pageserver might detect it and return an + /// error, or it might return the old page version or the new page version. Setting + /// not_modified_since_lsn equal to request_lsn is always safe, but can lead to unnecessary + /// waiting. + pub not_modified_since_lsn: Option, +} + +impl ReadLsn { + /// Validates the ReadLsn. + pub fn validate(&self) -> Result<(), ProtocolError> { + if self.request_lsn == Lsn::INVALID { + return Err(ProtocolError::invalid("request_lsn", self.request_lsn)); + } + if self.not_modified_since_lsn > Some(self.request_lsn) { + return Err(ProtocolError::invalid( + "not_modified_since_lsn", + self.not_modified_since_lsn, + )); + } + Ok(()) + } +} + +impl TryFrom for ReadLsn { + type Error = ProtocolError; + + fn try_from(pb: proto::ReadLsn) -> Result { + let read_lsn = Self { + request_lsn: Lsn(pb.request_lsn), + not_modified_since_lsn: match pb.not_modified_since_lsn { + 0 => None, + lsn => Some(Lsn(lsn)), + }, + }; + read_lsn.validate()?; + Ok(read_lsn) + } +} + +impl TryFrom for proto::ReadLsn { + type Error = ProtocolError; + + fn try_from(read_lsn: ReadLsn) -> Result { + read_lsn.validate()?; + Ok(Self { + request_lsn: read_lsn.request_lsn.0, + not_modified_since_lsn: read_lsn.not_modified_since_lsn.unwrap_or_default().0, + }) + } +} + +// RelTag is defined in pageserver_api::reltag. +pub type RelTag = pageserver_api::reltag::RelTag; + +impl TryFrom for RelTag { + type Error = ProtocolError; + + fn try_from(pb: proto::RelTag) -> Result { + Ok(Self { + spcnode: pb.spc_oid, + dbnode: pb.db_oid, + relnode: pb.rel_number, + forknum: pb + .fork_number + .try_into() + .map_err(|_| ProtocolError::invalid("fork_number", pb.fork_number))?, + }) + } +} + +impl From for proto::RelTag { + fn from(rel_tag: RelTag) -> Self { + Self { + spc_oid: rel_tag.spcnode, + db_oid: rel_tag.dbnode, + rel_number: rel_tag.relnode, + fork_number: rel_tag.forknum as u32, + } + } +} + +/// Checks whether a relation exists, at the given LSN. Only valid on shard 0, other shards error. +#[derive(Clone, Copy, Debug)] +pub struct CheckRelExistsRequest { + pub read_lsn: ReadLsn, + pub rel: RelTag, +} + +impl TryFrom for CheckRelExistsRequest { + type Error = ProtocolError; + + fn try_from(pb: proto::CheckRelExistsRequest) -> Result { + Ok(Self { + read_lsn: pb + .read_lsn + .ok_or(ProtocolError::Missing("read_lsn"))? + .try_into()?, + rel: pb.rel.ok_or(ProtocolError::Missing("rel"))?.try_into()?, + }) + } +} + +pub type CheckRelExistsResponse = bool; + +impl From for CheckRelExistsResponse { + fn from(pb: proto::CheckRelExistsResponse) -> Self { + pb.exists + } +} + +impl From for proto::CheckRelExistsResponse { + fn from(exists: CheckRelExistsResponse) -> Self { + Self { exists } + } +} + +/// Requests a base backup at a given LSN. +#[derive(Clone, Copy, Debug)] +pub struct GetBaseBackupRequest { + /// The LSN to fetch a base backup at. + pub read_lsn: ReadLsn, + /// If true, logical replication slots will not be created. + pub replica: bool, +} + +impl TryFrom for GetBaseBackupRequest { + type Error = ProtocolError; + + fn try_from(pb: proto::GetBaseBackupRequest) -> Result { + Ok(Self { + read_lsn: pb + .read_lsn + .ok_or(ProtocolError::Missing("read_lsn"))? + .try_into()?, + replica: pb.replica, + }) + } +} + +impl TryFrom for proto::GetBaseBackupRequest { + type Error = ProtocolError; + + fn try_from(request: GetBaseBackupRequest) -> Result { + Ok(Self { + read_lsn: Some(request.read_lsn.try_into()?), + replica: request.replica, + }) + } +} + +pub type GetBaseBackupResponseChunk = Bytes; + +impl TryFrom for GetBaseBackupResponseChunk { + type Error = ProtocolError; + + fn try_from(pb: proto::GetBaseBackupResponseChunk) -> Result { + if pb.chunk.is_empty() { + return Err(ProtocolError::Missing("chunk")); + } + Ok(pb.chunk) + } +} + +impl TryFrom for proto::GetBaseBackupResponseChunk { + type Error = ProtocolError; + + fn try_from(chunk: GetBaseBackupResponseChunk) -> Result { + if chunk.is_empty() { + return Err(ProtocolError::Missing("chunk")); + } + Ok(Self { chunk }) + } +} + +/// Requests the size of a database, as # of bytes. Only valid on shard 0, other shards will error. +#[derive(Clone, Copy, Debug)] +pub struct GetDbSizeRequest { + pub read_lsn: ReadLsn, + pub db_oid: Oid, +} + +impl TryFrom for GetDbSizeRequest { + type Error = ProtocolError; + + fn try_from(pb: proto::GetDbSizeRequest) -> Result { + Ok(Self { + read_lsn: pb + .read_lsn + .ok_or(ProtocolError::Missing("read_lsn"))? + .try_into()?, + db_oid: pb.db_oid, + }) + } +} + +impl TryFrom for proto::GetDbSizeRequest { + type Error = ProtocolError; + + fn try_from(request: GetDbSizeRequest) -> Result { + Ok(Self { + read_lsn: Some(request.read_lsn.try_into()?), + db_oid: request.db_oid, + }) + } +} + +pub type GetDbSizeResponse = u64; + +impl From for GetDbSizeResponse { + fn from(pb: proto::GetDbSizeResponse) -> Self { + pb.num_bytes + } +} + +impl From for proto::GetDbSizeResponse { + fn from(num_bytes: GetDbSizeResponse) -> Self { + Self { num_bytes } + } +} + +/// Requests one or more pages. +#[derive(Clone, Debug)] +pub struct GetPageRequest { + /// A request ID. Will be included in the response. Should be unique for in-flight requests on + /// the stream. + pub request_id: RequestID, + /// The request class. + pub request_class: GetPageClass, + /// The LSN to read at. + pub read_lsn: ReadLsn, + /// The relation to read from. + pub rel: RelTag, + /// Page numbers to read. Must belong to the remote shard. + /// + /// Multiple pages will be executed as a single batch by the Pageserver, amortizing layer access + /// costs and parallelizing them. This may increase the latency of any individual request, but + /// improves the overall latency and throughput of the batch as a whole. + pub block_numbers: SmallVec<[u32; 1]>, +} + +impl TryFrom for GetPageRequest { + type Error = ProtocolError; + + fn try_from(pb: proto::GetPageRequest) -> Result { + if pb.block_number.is_empty() { + return Err(ProtocolError::Missing("block_number")); + } + Ok(Self { + request_id: pb.request_id, + request_class: pb.request_class.into(), + read_lsn: pb + .read_lsn + .ok_or(ProtocolError::Missing("read_lsn"))? + .try_into()?, + rel: pb.rel.ok_or(ProtocolError::Missing("rel"))?.try_into()?, + block_numbers: pb.block_number.into(), + }) + } +} + +impl TryFrom for proto::GetPageRequest { + type Error = ProtocolError; + + fn try_from(request: GetPageRequest) -> Result { + if request.block_numbers.is_empty() { + return Err(ProtocolError::Missing("block_number")); + } + Ok(Self { + request_id: request.request_id, + request_class: request.request_class.into(), + read_lsn: Some(request.read_lsn.try_into()?), + rel: Some(request.rel.into()), + block_number: request.block_numbers.into_vec(), + }) + } +} + +/// A GetPage request ID. +pub type RequestID = u64; + +/// A GetPage request class. +#[derive(Clone, Copy, Debug)] +pub enum GetPageClass { + /// Unknown status. For backwards compatibility: used when an older client version sends a class + /// that a newer server version has removed. + Unknown, + /// A normal request. This is the default. + Normal, + /// A prefetch request. NB: can only be classified on pg < 18. + Prefetch, + /// A background request (e.g. vacuum). + Background, +} + +impl From for GetPageClass { + fn from(pb: proto::GetPageClass) -> Self { + match pb { + proto::GetPageClass::Unknown => Self::Unknown, + proto::GetPageClass::Normal => Self::Normal, + proto::GetPageClass::Prefetch => Self::Prefetch, + proto::GetPageClass::Background => Self::Background, + } + } +} + +impl From for GetPageClass { + fn from(class: i32) -> Self { + proto::GetPageClass::try_from(class) + .unwrap_or(proto::GetPageClass::Unknown) + .into() + } +} + +impl From for proto::GetPageClass { + fn from(class: GetPageClass) -> Self { + match class { + GetPageClass::Unknown => Self::Unknown, + GetPageClass::Normal => Self::Normal, + GetPageClass::Prefetch => Self::Prefetch, + GetPageClass::Background => Self::Background, + } + } +} + +impl From for i32 { + fn from(class: GetPageClass) -> Self { + proto::GetPageClass::from(class).into() + } +} + +/// A GetPage response. +/// +/// A batch response will contain all of the requested pages. We could eagerly emit individual pages +/// as soon as they are ready, but on a readv() Postgres holds buffer pool locks on all pages in the +/// batch and we'll only return once the entire batch is ready, so no one can make use of the +/// individual pages. +#[derive(Clone, Debug)] +pub struct GetPageResponse { + /// The original request's ID. + pub request_id: RequestID, + /// The response status code. + pub status: GetPageStatus, + /// A string describing the status, if any. + pub reason: Option, + /// The 8KB page images, in the same order as the request. Empty if status != OK. + pub page_images: SmallVec<[Bytes; 1]>, +} + +impl From for GetPageResponse { + fn from(pb: proto::GetPageResponse) -> Self { + Self { + request_id: pb.request_id, + status: pb.status.into(), + reason: Some(pb.reason).filter(|r| !r.is_empty()), + page_images: pb.page_image.into(), + } + } +} + +impl From for proto::GetPageResponse { + fn from(response: GetPageResponse) -> Self { + Self { + request_id: response.request_id, + status: response.status.into(), + reason: response.reason.unwrap_or_default(), + page_image: response.page_images.into_vec(), + } + } +} + +/// A GetPage response status. +#[derive(Clone, Copy, Debug)] +pub enum GetPageStatus { + /// Unknown status. For forwards compatibility: used when an older client version receives a new + /// status code from a newer server version. + Unknown, + /// The request was successful. + Ok, + /// The page did not exist. The tenant/timeline/shard has already been validated during stream + /// setup. + NotFound, + /// The request was invalid. + Invalid, + /// The tenant is rate limited. Slow down and retry later. + SlowDown, +} + +impl From for GetPageStatus { + fn from(pb: proto::GetPageStatus) -> Self { + match pb { + proto::GetPageStatus::Unknown => Self::Unknown, + proto::GetPageStatus::Ok => Self::Ok, + proto::GetPageStatus::NotFound => Self::NotFound, + proto::GetPageStatus::Invalid => Self::Invalid, + proto::GetPageStatus::SlowDown => Self::SlowDown, + } + } +} + +impl From for GetPageStatus { + fn from(status: i32) -> Self { + proto::GetPageStatus::try_from(status) + .unwrap_or(proto::GetPageStatus::Unknown) + .into() + } +} + +impl From for proto::GetPageStatus { + fn from(status: GetPageStatus) -> Self { + match status { + GetPageStatus::Unknown => Self::Unknown, + GetPageStatus::Ok => Self::Ok, + GetPageStatus::NotFound => Self::NotFound, + GetPageStatus::Invalid => Self::Invalid, + GetPageStatus::SlowDown => Self::SlowDown, + } + } +} + +impl From for i32 { + fn from(status: GetPageStatus) -> Self { + proto::GetPageStatus::from(status).into() + } +} + +// Fetches the size of a relation at a given LSN, as # of blocks. Only valid on shard 0, other +// shards will error. +pub struct GetRelSizeRequest { + pub read_lsn: ReadLsn, + pub rel: RelTag, +} + +impl TryFrom for GetRelSizeRequest { + type Error = ProtocolError; + + fn try_from(proto: proto::GetRelSizeRequest) -> Result { + Ok(Self { + read_lsn: proto + .read_lsn + .ok_or(ProtocolError::Missing("read_lsn"))? + .try_into()?, + rel: proto.rel.ok_or(ProtocolError::Missing("rel"))?.try_into()?, + }) + } +} + +impl TryFrom for proto::GetRelSizeRequest { + type Error = ProtocolError; + + fn try_from(request: GetRelSizeRequest) -> Result { + Ok(Self { + read_lsn: Some(request.read_lsn.try_into()?), + rel: Some(request.rel.into()), + }) + } +} + +pub type GetRelSizeResponse = u32; + +impl From for GetRelSizeResponse { + fn from(proto: proto::GetRelSizeResponse) -> Self { + proto.num_blocks + } +} + +impl From for proto::GetRelSizeResponse { + fn from(num_blocks: GetRelSizeResponse) -> Self { + Self { num_blocks } + } +} + +/// Requests an SLRU segment. Only valid on shard 0, other shards will error. +pub struct GetSlruSegmentRequest { + pub read_lsn: ReadLsn, + pub kind: SlruKind, + pub segno: u32, +} + +impl TryFrom for GetSlruSegmentRequest { + type Error = ProtocolError; + + fn try_from(pb: proto::GetSlruSegmentRequest) -> Result { + Ok(Self { + read_lsn: pb + .read_lsn + .ok_or(ProtocolError::Missing("read_lsn"))? + .try_into()?, + kind: u8::try_from(pb.kind) + .ok() + .and_then(SlruKind::from_repr) + .ok_or_else(|| ProtocolError::invalid("slru_kind", pb.kind))?, + segno: pb.segno, + }) + } +} + +impl TryFrom for proto::GetSlruSegmentRequest { + type Error = ProtocolError; + + fn try_from(request: GetSlruSegmentRequest) -> Result { + Ok(Self { + read_lsn: Some(request.read_lsn.try_into()?), + kind: request.kind as u32, + segno: request.segno, + }) + } +} + +pub type GetSlruSegmentResponse = Bytes; + +impl TryFrom for GetSlruSegmentResponse { + type Error = ProtocolError; + + fn try_from(pb: proto::GetSlruSegmentResponse) -> Result { + if pb.segment.is_empty() { + return Err(ProtocolError::Missing("segment")); + } + Ok(pb.segment) + } +} + +impl TryFrom for proto::GetSlruSegmentResponse { + type Error = ProtocolError; + + fn try_from(segment: GetSlruSegmentResponse) -> Result { + if segment.is_empty() { + return Err(ProtocolError::Missing("segment")); + } + Ok(Self { segment }) + } +} + +// SlruKind is defined in pageserver_api::reltag. +pub type SlruKind = pageserver_api::reltag::SlruKind; From 1369d73dcd52dc88416c9087914f2a7ac8c39876 Mon Sep 17 00:00:00 2001 From: a-masterov <72613290+a-masterov@users.noreply.github.com> Date: Mon, 26 May 2025 13:29:39 +0200 Subject: [PATCH 04/48] Add h3 to neon-extensions-test (#11946) ## Problem We didn't test the h3 extension in our test suite. ## Summary of changes Added tests for h3 and h3-postgis extensions Includes upgrade test for h3 --------- Co-authored-by: Tristan Partin --- compute/compute-node.Dockerfile | 2 +- docker-compose/ext-src/h3-pg-src/neon-test.sh | 16 ++++++++++++++++ docker-compose/ext-src/h3-pg-src/test-upgrade.sh | 7 +++++++ docker-compose/test_extensions_upgrade.sh | 3 ++- 4 files changed, 26 insertions(+), 2 deletions(-) create mode 100755 docker-compose/ext-src/h3-pg-src/neon-test.sh create mode 100755 docker-compose/ext-src/h3-pg-src/test-upgrade.sh diff --git a/compute/compute-node.Dockerfile b/compute/compute-node.Dockerfile index f4a5593b71..3459983a34 100644 --- a/compute/compute-node.Dockerfile +++ b/compute/compute-node.Dockerfile @@ -1847,7 +1847,7 @@ COPY docker-compose/ext-src/ /ext-src/ COPY --from=pg-build /postgres /postgres #COPY --from=postgis-src /ext-src/ /ext-src/ COPY --from=plv8-src /ext-src/ /ext-src/ -#COPY --from=h3-pg-src /ext-src/ /ext-src/ +COPY --from=h3-pg-src /ext-src/h3-pg-src /ext-src/h3-pg-src COPY --from=postgresql-unit-src /ext-src/ /ext-src/ COPY --from=pgvector-src /ext-src/ /ext-src/ COPY --from=pgjwt-src /ext-src/ /ext-src/ diff --git a/docker-compose/ext-src/h3-pg-src/neon-test.sh b/docker-compose/ext-src/h3-pg-src/neon-test.sh new file mode 100755 index 0000000000..e2ab22f03e --- /dev/null +++ b/docker-compose/ext-src/h3-pg-src/neon-test.sh @@ -0,0 +1,16 @@ +#!/usr/bin/env bash +set -ex +cd "$(dirname "${0}")" +PG_REGRESS=$(dirname "$(pg_config --pgxs)")/../test/regress/pg_regress +dropdb --if-exists contrib_regression +createdb contrib_regression +cd h3_postgis/test +psql -d contrib_regression -c "CREATE EXTENSION postgis" -c "CREATE EXTENSION postgis_raster" -c "CREATE EXTENSION h3" -c "CREATE EXTENSION h3_postgis" +TESTS=$(echo sql/* | sed 's|sql/||g; s|\.sql||g') +${PG_REGRESS} --use-existing --dbname contrib_regression ${TESTS} +cd ../../h3/test +TESTS=$(echo sql/* | sed 's|sql/||g; s|\.sql||g') +dropdb --if-exists contrib_regression +createdb contrib_regression +psql -d contrib_regression -c "CREATE EXTENSION h3" +${PG_REGRESS} --use-existing --dbname contrib_regression ${TESTS} diff --git a/docker-compose/ext-src/h3-pg-src/test-upgrade.sh b/docker-compose/ext-src/h3-pg-src/test-upgrade.sh new file mode 100755 index 0000000000..72d7040966 --- /dev/null +++ b/docker-compose/ext-src/h3-pg-src/test-upgrade.sh @@ -0,0 +1,7 @@ +#!/bin/sh +set -ex +cd "$(dirname ${0})" +PG_REGRESS=$(dirname "$(pg_config --pgxs)")/../test/regress/pg_regress +cd h3/test +TESTS=$(echo sql/* | sed 's|sql/||g; s|\.sql||g') +${PG_REGRESS} --use-existing --inputdir=./ --bindir='/usr/local/pgsql/bin' --dbname=contrib_regression ${TESTS} \ No newline at end of file diff --git a/docker-compose/test_extensions_upgrade.sh b/docker-compose/test_extensions_upgrade.sh index 51d1e40802..f1cf17f531 100755 --- a/docker-compose/test_extensions_upgrade.sh +++ b/docker-compose/test_extensions_upgrade.sh @@ -82,7 +82,8 @@ EXTENSIONS='[ {"extname": "pg_ivm", "extdir": "pg_ivm-src"}, {"extname": "pgjwt", "extdir": "pgjwt-src"}, {"extname": "pgtap", "extdir": "pgtap-src"}, -{"extname": "pg_repack", "extdir": "pg_repack-src"} +{"extname": "pg_repack", "extdir": "pg_repack-src"}, +{"extname": "h3", "extdir": "h3-pg-src"} ]' EXTNAMES=$(echo ${EXTENSIONS} | jq -r '.[].extname' | paste -sd ' ' -) COMPUTE_TAG=${NEW_COMPUTE_TAG} docker compose --profile test-extensions up --quiet-pull --build -d From 841517ee3760902e204b090f971513b92371abca Mon Sep 17 00:00:00 2001 From: "Alex Chi Z." <4198311+skyzh@users.noreply.github.com> Date: Mon, 26 May 2025 19:31:27 +0800 Subject: [PATCH 05/48] fix(pageserver): do not increase basebackup err counter when reconnect (#12016) ## Problem We see unexpected basebackup error alerts in the alert channel. https://github.com/neondatabase/neon/pull/11778 only fixed the alerts for shutdown errors. However, another path is that tenant shutting down while waiting LSN -> WaitLsnError::BadState -> QueryError::Reconnect. Therefore, the reconnect error should also be discarded from the ok/error counter. ## Summary of changes Do not increase ok/err counter for reconnect errors. --------- Signed-off-by: Alex Chi Z --- pageserver/src/metrics.rs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pageserver/src/metrics.rs b/pageserver/src/metrics.rs index 3076c7f1d6..0ff31dcb8a 100644 --- a/pageserver/src/metrics.rs +++ b/pageserver/src/metrics.rs @@ -2234,8 +2234,10 @@ impl BasebackupQueryTimeOngoingRecording<'_> { // If you want to change categorize of a specific error, also change it in `log_query_error`. let metric = match res { Ok(_) => &self.parent.ok, - Err(QueryError::Shutdown) => { - // Do not observe ok/err for shutdown + Err(QueryError::Shutdown) | Err(QueryError::Reconnect) => { + // Do not observe ok/err for shutdown/reconnect. + // Reconnect error might be raised when the operation is waiting for LSN and the tenant shutdown interrupts + // the operation. A reconnect error will be issued and the client will retry. return; } Err(QueryError::Disconnected(ConnectionError::Io(io_error))) From dc953de85d6ebb5bec3a4213f623f26eaf85fa25 Mon Sep 17 00:00:00 2001 From: "Alex Chi Z." <4198311+skyzh@users.noreply.github.com> Date: Mon, 26 May 2025 21:09:37 +0800 Subject: [PATCH 06/48] feat(pageserver): integrate PostHog with gc-compaction rollout (#11917) ## Problem part of https://github.com/neondatabase/neon/issues/11813 ## Summary of changes * Integrate feature store with tenant structure. * gc-compaction picks up the current strategy from the feature store. * We only log them for now for testing purpose. They will not be used until we have more patches to support different strategies defined in PostHog. * We don't support property-based evaulation for now; it will be implemented later. * Evaluating result of the feature flag is not cached -- it's not efficient and cannot be used on hot path right now. * We don't report the evaluation result back to PostHog right now. I plan to enable it in staging once we get the patch merged. --------- Signed-off-by: Alex Chi Z --- Cargo.lock | 9 ++- Cargo.toml | 5 +- libs/pageserver_api/src/config.rs | 18 +++++ libs/posthog_client_lite/Cargo.toml | 9 ++- .../src/background_loop.rs | 59 +++++++++++++++ libs/posthog_client_lite/src/lib.rs | 74 ++++++++++--------- pageserver/Cargo.toml | 58 +++++++-------- pageserver/src/bin/pageserver.rs | 16 ++++ pageserver/src/config.rs | 7 +- pageserver/src/feature_resolver.rs | 65 ++++++++++++++++ pageserver/src/lib.rs | 1 + pageserver/src/tenant.rs | 22 +++++- workspace_hack/Cargo.toml | 3 - 13 files changed, 267 insertions(+), 79 deletions(-) create mode 100644 libs/posthog_client_lite/src/background_loop.rs create mode 100644 pageserver/src/feature_resolver.rs diff --git a/Cargo.lock b/Cargo.lock index 21c863ff95..89351432c1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4330,6 +4330,7 @@ dependencies = [ "postgres_connection", "postgres_ffi", "postgres_initdb", + "posthog_client_lite", "pprof", "pq_proto", "procfs", @@ -4907,11 +4908,16 @@ name = "posthog_client_lite" version = "0.1.0" dependencies = [ "anyhow", + "arc-swap", "reqwest", "serde", "serde_json", "sha2", "thiserror 1.0.69", + "tokio", + "tokio-util", + "tracing", + "tracing-utils", "workspace_hack", ] @@ -8575,10 +8581,8 @@ dependencies = [ "fail", "form_urlencoded", "futures-channel", - "futures-core", "futures-executor", "futures-io", - "futures-task", "futures-util", "generic-array", "getrandom 0.2.11", @@ -8608,7 +8612,6 @@ dependencies = [ "once_cell", "p256 0.13.2", "parquet", - "percent-encoding", "prettyplease", "proc-macro2", "prost 0.13.5", diff --git a/Cargo.toml b/Cargo.toml index d2c8e86bd4..a040010fb7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -247,6 +247,7 @@ azure_storage_blobs = { git = "https://github.com/neondatabase/azure-sdk-for-rus ## Local libraries compute_api = { version = "0.1", path = "./libs/compute_api/" } consumption_metrics = { version = "0.1", path = "./libs/consumption_metrics/" } +desim = { version = "0.1", path = "./libs/desim" } endpoint_storage = { version = "0.0.1", path = "./endpoint_storage/" } http-utils = { version = "0.1", path = "./libs/http-utils/" } metrics = { version = "0.1", path = "./libs/metrics/" } @@ -259,19 +260,19 @@ postgres_backend = { version = "0.1", path = "./libs/postgres_backend/" } postgres_connection = { version = "0.1", path = "./libs/postgres_connection/" } postgres_ffi = { version = "0.1", path = "./libs/postgres_ffi/" } postgres_initdb = { path = "./libs/postgres_initdb" } +posthog_client_lite = { version = "0.1", path = "./libs/posthog_client_lite" } pq_proto = { version = "0.1", path = "./libs/pq_proto/" } remote_storage = { version = "0.1", path = "./libs/remote_storage/" } safekeeper_api = { version = "0.1", path = "./libs/safekeeper_api" } safekeeper_client = { path = "./safekeeper/client" } -desim = { version = "0.1", path = "./libs/desim" } storage_broker = { version = "0.1", path = "./storage_broker/" } # Note: main broker code is inside the binary crate, so linking with the library shouldn't be heavy. storage_controller_client = { path = "./storage_controller/client" } tenant_size_model = { version = "0.1", path = "./libs/tenant_size_model/" } tracing-utils = { version = "0.1", path = "./libs/tracing-utils/" } utils = { version = "0.1", path = "./libs/utils/" } vm_monitor = { version = "0.1", path = "./libs/vm_monitor/" } -walproposer = { version = "0.1", path = "./libs/walproposer/" } wal_decoder = { version = "0.1", path = "./libs/wal_decoder" } +walproposer = { version = "0.1", path = "./libs/walproposer/" } ## Common library dependency workspace_hack = { version = "0.1", path = "./workspace_hack/" } diff --git a/libs/pageserver_api/src/config.rs b/libs/pageserver_api/src/config.rs index daec65ce2d..012c020fb1 100644 --- a/libs/pageserver_api/src/config.rs +++ b/libs/pageserver_api/src/config.rs @@ -45,6 +45,21 @@ pub struct NodeMetadata { pub other: HashMap, } +/// PostHog integration config. +#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] +pub struct PostHogConfig { + /// PostHog project ID + pub project_id: String, + /// Server-side (private) API key + pub server_api_key: String, + /// Client-side (public) API key + pub client_api_key: String, + /// Private API URL + pub private_api_url: String, + /// Public API URL + pub public_api_url: String, +} + /// `pageserver.toml` /// /// We use serde derive with `#[serde(default)]` to generate a deserializer @@ -186,6 +201,8 @@ pub struct ConfigToml { pub tracing: Option, pub enable_tls_page_service_api: bool, pub dev_mode: bool, + #[serde(skip_serializing_if = "Option::is_none")] + pub posthog_config: Option, pub timeline_import_config: TimelineImportConfig, #[serde(skip_serializing_if = "Option::is_none")] pub basebackup_cache_config: Option, @@ -701,6 +718,7 @@ impl Default for ConfigToml { import_job_checkpoint_threshold: NonZeroUsize::new(128).unwrap(), }, basebackup_cache_config: None, + posthog_config: None, } } } diff --git a/libs/posthog_client_lite/Cargo.toml b/libs/posthog_client_lite/Cargo.toml index 7c19bf2ccb..05a3a9774e 100644 --- a/libs/posthog_client_lite/Cargo.toml +++ b/libs/posthog_client_lite/Cargo.toml @@ -6,9 +6,14 @@ license.workspace = true [dependencies] anyhow.workspace = true +arc-swap.workspace = true reqwest.workspace = true -serde.workspace = true serde_json.workspace = true +serde.workspace = true sha2.workspace = true -workspace_hack.workspace = true thiserror.workspace = true +tokio = { workspace = true, features = ["process", "sync", "fs", "rt", "io-util", "time"] } +tokio-util.workspace = true +tracing-utils.workspace = true +tracing.workspace = true +workspace_hack.workspace = true diff --git a/libs/posthog_client_lite/src/background_loop.rs b/libs/posthog_client_lite/src/background_loop.rs new file mode 100644 index 0000000000..9ffcda3728 --- /dev/null +++ b/libs/posthog_client_lite/src/background_loop.rs @@ -0,0 +1,59 @@ +//! A background loop that fetches feature flags from PostHog and updates the feature store. + +use std::{sync::Arc, time::Duration}; + +use arc_swap::ArcSwap; +use tokio_util::sync::CancellationToken; + +use crate::{FeatureStore, PostHogClient, PostHogClientConfig}; + +/// A background loop that fetches feature flags from PostHog and updates the feature store. +pub struct FeatureResolverBackgroundLoop { + posthog_client: PostHogClient, + feature_store: ArcSwap, + cancel: CancellationToken, +} + +impl FeatureResolverBackgroundLoop { + pub fn new(config: PostHogClientConfig, shutdown_pageserver: CancellationToken) -> Self { + Self { + posthog_client: PostHogClient::new(config), + feature_store: ArcSwap::new(Arc::new(FeatureStore::new())), + cancel: shutdown_pageserver, + } + } + + pub fn spawn(self: Arc, handle: &tokio::runtime::Handle, refresh_period: Duration) { + let this = self.clone(); + let cancel = self.cancel.clone(); + handle.spawn(async move { + tracing::info!("Starting PostHog feature resolver"); + let mut ticker = tokio::time::interval(refresh_period); + ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); + loop { + tokio::select! { + _ = ticker.tick() => {} + _ = cancel.cancelled() => break + } + let resp = match this + .posthog_client + .get_feature_flags_local_evaluation() + .await + { + Ok(resp) => resp, + Err(e) => { + tracing::warn!("Cannot get feature flags: {}", e); + continue; + } + }; + let feature_store = FeatureStore::new_with_flags(resp.flags); + this.feature_store.store(Arc::new(feature_store)); + } + tracing::info!("PostHog feature resolver stopped"); + }); + } + + pub fn feature_store(&self) -> Arc { + self.feature_store.load_full() + } +} diff --git a/libs/posthog_client_lite/src/lib.rs b/libs/posthog_client_lite/src/lib.rs index 53deb26ab7..21e978df3e 100644 --- a/libs/posthog_client_lite/src/lib.rs +++ b/libs/posthog_client_lite/src/lib.rs @@ -1,5 +1,9 @@ //! A lite version of the PostHog client that only supports local evaluation of feature flags. +mod background_loop; + +pub use background_loop::FeatureResolverBackgroundLoop; + use std::collections::HashMap; use serde::{Deserialize, Serialize}; @@ -20,8 +24,7 @@ pub enum PostHogEvaluationError { #[derive(Deserialize)] pub struct LocalEvaluationResponse { - #[allow(dead_code)] - flags: Vec, + pub flags: Vec, } #[derive(Deserialize)] @@ -94,6 +97,12 @@ impl FeatureStore { } } + pub fn new_with_flags(flags: Vec) -> Self { + let mut store = Self::new(); + store.set_flags(flags); + store + } + pub fn set_flags(&mut self, flags: Vec) { self.flags.clear(); for flag in flags { @@ -267,6 +276,7 @@ impl FeatureStore { &self, flag_key: &str, user_id: &str, + properties: &HashMap, ) -> Result { let hash_on_global_rollout_percentage = Self::consistent_hash(user_id, flag_key, "multivariate"); @@ -276,7 +286,7 @@ impl FeatureStore { flag_key, hash_on_global_rollout_percentage, hash_on_group_rollout_percentage, - &HashMap::new(), + properties, ) } @@ -344,6 +354,19 @@ impl FeatureStore { } } +pub struct PostHogClientConfig { + /// The server API key. + pub server_api_key: String, + /// The client API key. + pub client_api_key: String, + /// The project ID. + pub project_id: String, + /// The private API URL. + pub private_api_url: String, + /// The public API URL. + pub public_api_url: String, +} + /// A lite PostHog client. /// /// At the point of writing this code, PostHog does not have a functional Rust client with feature flag support. @@ -360,37 +383,16 @@ impl FeatureStore { /// want to report the feature flag usage back to PostHog. The current plan is to use PostHog only as an UI to /// configure feature flags so it is very likely that the client API will not be used. pub struct PostHogClient { - /// The server API key. - server_api_key: String, - /// The client API key. - client_api_key: String, - /// The project ID. - project_id: String, - /// The private API URL. - private_api_url: String, - /// The public API URL. - public_api_url: String, + /// The config. + config: PostHogClientConfig, /// The HTTP client. client: reqwest::Client, } impl PostHogClient { - pub fn new( - server_api_key: String, - client_api_key: String, - project_id: String, - private_api_url: String, - public_api_url: String, - ) -> Self { + pub fn new(config: PostHogClientConfig) -> Self { let client = reqwest::Client::new(); - Self { - server_api_key, - client_api_key, - project_id, - private_api_url, - public_api_url, - client, - } + Self { config, client } } pub fn new_with_us_region( @@ -398,13 +400,13 @@ impl PostHogClient { client_api_key: String, project_id: String, ) -> Self { - Self::new( + Self::new(PostHogClientConfig { server_api_key, client_api_key, project_id, - "https://us.posthog.com".to_string(), - "https://us.i.posthog.com".to_string(), - ) + private_api_url: "https://us.posthog.com".to_string(), + public_api_url: "https://us.i.posthog.com".to_string(), + }) } /// Fetch the feature flag specs from the server. @@ -422,12 +424,12 @@ impl PostHogClient { // with bearer token of self.server_api_key let url = format!( "{}/api/projects/{}/feature_flags/local_evaluation", - self.private_api_url, self.project_id + self.config.private_api_url, self.config.project_id ); let response = self .client .get(url) - .bearer_auth(&self.server_api_key) + .bearer_auth(&self.config.server_api_key) .send() .await?; let body = response.text().await?; @@ -446,11 +448,11 @@ impl PostHogClient { ) -> anyhow::Result<()> { // PUBLIC_URL/capture/ // with bearer token of self.client_api_key - let url = format!("{}/capture/", self.public_api_url); + let url = format!("{}/capture/", self.config.public_api_url); self.client .post(url) .body(serde_json::to_string(&json!({ - "api_key": self.client_api_key, + "api_key": self.config.client_api_key, "distinct_id": distinct_id, "event": event, "properties": properties, diff --git a/pageserver/Cargo.toml b/pageserver/Cargo.toml index 1f5cc89b33..c4d6d58945 100644 --- a/pageserver/Cargo.toml +++ b/pageserver/Cargo.toml @@ -17,51 +17,69 @@ anyhow.workspace = true arc-swap.workspace = true async-compression.workspace = true async-stream.workspace = true -bit_field.workspace = true bincode.workspace = true +bit_field.workspace = true byteorder.workspace = true bytes.workspace = true -camino.workspace = true camino-tempfile.workspace = true +camino.workspace = true chrono = { workspace = true, features = ["serde"] } clap = { workspace = true, features = ["string"] } consumption_metrics.workspace = true crc32c.workspace = true either.workspace = true +enum-map.workspace = true +enumset = { workspace = true, features = ["serde"]} fail.workspace = true futures.workspace = true hashlink.workspace = true hex.workspace = true -humantime.workspace = true +http-utils.workspace = true humantime-serde.workspace = true +humantime.workspace = true hyper0.workspace = true itertools.workspace = true jsonwebtoken.workspace = true md5.workspace = true +metrics.workspace = true nix.workspace = true -# hack to get the number of worker threads tokio uses -num_cpus.workspace = true +num_cpus.workspace = true # hack to get the number of worker threads tokio uses num-traits.workspace = true once_cell.workspace = true +pageserver_api.workspace = true +pageserver_client.workspace = true # for ResponseErrorMessageExt TOOD refactor that +pageserver_compaction.workspace = true pageserver_page_api.workspace = true +pem.workspace = true pin-project-lite.workspace = true postgres_backend.workspace = true +postgres_connection.workspace = true +postgres_ffi.workspace = true +postgres_initdb.workspace = true postgres-protocol.workspace = true postgres-types.workspace = true -postgres_initdb.workspace = true +posthog_client_lite.workspace = true pprof.workspace = true +pq_proto.workspace = true rand.workspace = true range-set-blaze = { version = "0.1.16", features = ["alloc"] } regex.workspace = true +remote_storage.workspace = true +reqwest.workspace = true +rpds.workspace = true rustls.workspace = true scopeguard.workspace = true send-future.workspace = true -serde.workspace = true serde_json = { workspace = true, features = ["raw_value"] } serde_path_to_error.workspace = true serde_with.workspace = true +serde.workspace = true +smallvec.workspace = true +storage_broker.workspace = true +strum_macros.workspace = true +strum.workspace = true sysinfo.workspace = true -tokio-tar.workspace = true +tenant_size_model.workspace = true thiserror.workspace = true tikv-jemallocator.workspace = true tokio = { workspace = true, features = ["process", "sync", "fs", "rt", "io-util", "time"] } @@ -70,6 +88,7 @@ tokio-io-timeout.workspace = true tokio-postgres.workspace = true tokio-rustls.workspace = true tokio-stream.workspace = true +tokio-tar.workspace = true tokio-util.workspace = true toml_edit = { workspace = true, features = [ "serde" ] } tonic.workspace = true @@ -77,29 +96,10 @@ tonic-reflection.workspace = true tracing.workspace = true tracing-utils.workspace = true url.workspace = true -walkdir.workspace = true -metrics.workspace = true -pageserver_api.workspace = true -pageserver_client.workspace = true # for ResponseErrorMessageExt TOOD refactor that -pageserver_compaction.workspace = true -pem.workspace = true -postgres_connection.workspace = true -postgres_ffi.workspace = true -pq_proto.workspace = true -remote_storage.workspace = true -storage_broker.workspace = true -tenant_size_model.workspace = true -http-utils.workspace = true utils.workspace = true -workspace_hack.workspace = true -reqwest.workspace = true -rpds.workspace = true -enum-map.workspace = true -enumset = { workspace = true, features = ["serde"]} -strum.workspace = true -strum_macros.workspace = true wal_decoder.workspace = true -smallvec.workspace = true +walkdir.workspace = true +workspace_hack.workspace = true twox-hash.workspace = true [target.'cfg(target_os = "linux")'.dependencies] diff --git a/pageserver/src/bin/pageserver.rs b/pageserver/src/bin/pageserver.rs index 8d76d0d678..df3c045145 100644 --- a/pageserver/src/bin/pageserver.rs +++ b/pageserver/src/bin/pageserver.rs @@ -21,6 +21,7 @@ use pageserver::config::{PageServerConf, PageserverIdentity, ignored_fields}; use pageserver::controller_upcall_client::StorageControllerUpcallClient; use pageserver::deletion_queue::DeletionQueue; use pageserver::disk_usage_eviction_task::{self, launch_disk_usage_global_eviction_task}; +use pageserver::feature_resolver::FeatureResolver; use pageserver::metrics::{STARTUP_DURATION, STARTUP_IS_LOADING}; use pageserver::task_mgr::{ BACKGROUND_RUNTIME, COMPUTE_REQUEST_RUNTIME, MGMT_REQUEST_RUNTIME, WALRECEIVER_RUNTIME, @@ -522,6 +523,12 @@ fn start_pageserver( // Set up remote storage client let remote_storage = BACKGROUND_RUNTIME.block_on(create_remote_storage_client(conf))?; + let feature_resolver = create_feature_resolver( + conf, + shutdown_pageserver.clone(), + BACKGROUND_RUNTIME.handle(), + )?; + // Set up deletion queue let (deletion_queue, deletion_workers) = DeletionQueue::new( remote_storage.clone(), @@ -575,6 +582,7 @@ fn start_pageserver( deletion_queue_client, l0_flush_global_state, basebackup_prepare_sender, + feature_resolver, }, order, shutdown_pageserver.clone(), @@ -849,6 +857,14 @@ fn start_pageserver( }) } +fn create_feature_resolver( + conf: &'static PageServerConf, + shutdown_pageserver: CancellationToken, + handle: &tokio::runtime::Handle, +) -> anyhow::Result { + FeatureResolver::spawn(conf, shutdown_pageserver, handle) +} + async fn create_remote_storage_client( conf: &'static PageServerConf, ) -> anyhow::Result { diff --git a/pageserver/src/config.rs b/pageserver/src/config.rs index e8af548ec4..89f7539722 100644 --- a/pageserver/src/config.rs +++ b/pageserver/src/config.rs @@ -14,7 +14,7 @@ use std::time::Duration; use anyhow::{Context, bail, ensure}; use camino::{Utf8Path, Utf8PathBuf}; use once_cell::sync::OnceCell; -use pageserver_api::config::{DiskUsageEvictionTaskConfig, MaxVectoredReadBytes}; +use pageserver_api::config::{DiskUsageEvictionTaskConfig, MaxVectoredReadBytes, PostHogConfig}; use pageserver_api::models::ImageCompressionAlgorithm; use pageserver_api::shard::TenantShardId; use pem::Pem; @@ -238,6 +238,9 @@ pub struct PageServerConf { /// This is insecure and should only be used in development environments. pub dev_mode: bool, + /// PostHog integration config. + pub posthog_config: Option, + pub timeline_import_config: pageserver_api::config::TimelineImportConfig, pub basebackup_cache_config: Option, @@ -421,6 +424,7 @@ impl PageServerConf { tracing, enable_tls_page_service_api, dev_mode, + posthog_config, timeline_import_config, basebackup_cache_config, } = config_toml; @@ -536,6 +540,7 @@ impl PageServerConf { } None => Vec::new(), }, + posthog_config, }; // ------------------------------------------------------------ diff --git a/pageserver/src/feature_resolver.rs b/pageserver/src/feature_resolver.rs new file mode 100644 index 0000000000..193fb10abc --- /dev/null +++ b/pageserver/src/feature_resolver.rs @@ -0,0 +1,65 @@ +use std::{collections::HashMap, sync::Arc, time::Duration}; + +use posthog_client_lite::{ + FeatureResolverBackgroundLoop, PostHogClientConfig, PostHogEvaluationError, +}; +use tokio_util::sync::CancellationToken; +use utils::id::TenantId; + +use crate::config::PageServerConf; + +#[derive(Clone)] +pub struct FeatureResolver { + inner: Option>, +} + +impl FeatureResolver { + pub fn new_disabled() -> Self { + Self { inner: None } + } + + pub fn spawn( + conf: &PageServerConf, + shutdown_pageserver: CancellationToken, + handle: &tokio::runtime::Handle, + ) -> anyhow::Result { + // DO NOT block in this function: make it return as fast as possible to avoid startup delays. + if let Some(posthog_config) = &conf.posthog_config { + let inner = FeatureResolverBackgroundLoop::new( + PostHogClientConfig { + server_api_key: posthog_config.server_api_key.clone(), + client_api_key: posthog_config.client_api_key.clone(), + project_id: posthog_config.project_id.clone(), + private_api_url: posthog_config.private_api_url.clone(), + public_api_url: posthog_config.public_api_url.clone(), + }, + shutdown_pageserver, + ); + let inner = Arc::new(inner); + // TODO: make this configurable + inner.clone().spawn(handle, Duration::from_secs(60)); + Ok(FeatureResolver { inner: Some(inner) }) + } else { + Ok(FeatureResolver { inner: None }) + } + } + + /// Evaluate a multivariate feature flag. Currently, we do not support any properties. + pub fn evaluate_multivariate( + &self, + flag_key: &str, + tenant_id: TenantId, + ) -> Result { + if let Some(inner) = &self.inner { + inner.feature_store().evaluate_multivariate( + flag_key, + &tenant_id.to_string(), + &HashMap::new(), + ) + } else { + Err(PostHogEvaluationError::NotAvailable( + "PostHog integration is not enabled".to_string(), + )) + } + } +} diff --git a/pageserver/src/lib.rs b/pageserver/src/lib.rs index 25461c23ab..ae7cbf1d6b 100644 --- a/pageserver/src/lib.rs +++ b/pageserver/src/lib.rs @@ -10,6 +10,7 @@ pub mod context; pub mod controller_upcall_client; pub mod deletion_queue; pub mod disk_usage_eviction_task; +pub mod feature_resolver; pub mod http; pub mod import_datadir; pub mod l0_flush; diff --git a/pageserver/src/tenant.rs b/pageserver/src/tenant.rs index bf3f71e35a..7e006ef9e6 100644 --- a/pageserver/src/tenant.rs +++ b/pageserver/src/tenant.rs @@ -84,6 +84,7 @@ use crate::context; use crate::context::RequestContextBuilder; use crate::context::{DownloadBehavior, RequestContext}; use crate::deletion_queue::{DeletionQueueClient, DeletionQueueError}; +use crate::feature_resolver::FeatureResolver; use crate::l0_flush::L0FlushGlobalState; use crate::metrics::{ BROKEN_TENANTS_SET, CIRCUIT_BREAKERS_BROKEN, CIRCUIT_BREAKERS_UNBROKEN, CONCURRENT_INITDBS, @@ -159,6 +160,7 @@ pub struct TenantSharedResources { pub deletion_queue_client: DeletionQueueClient, pub l0_flush_global_state: L0FlushGlobalState, pub basebackup_prepare_sender: BasebackupPrepareSender, + pub feature_resolver: FeatureResolver, } /// A [`TenantShard`] is really an _attached_ tenant. The configuration @@ -380,6 +382,8 @@ pub struct TenantShard { pub(crate) gc_block: gc_block::GcBlock, l0_flush_global_state: L0FlushGlobalState, + + feature_resolver: FeatureResolver, } impl std::fmt::Debug for TenantShard { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -1292,6 +1296,7 @@ impl TenantShard { deletion_queue_client, l0_flush_global_state, basebackup_prepare_sender, + feature_resolver, } = resources; let attach_mode = attached_conf.location.attach_mode; @@ -1308,6 +1313,7 @@ impl TenantShard { deletion_queue_client, l0_flush_global_state, basebackup_prepare_sender, + feature_resolver, )); // The attach task will carry a GateGuard, so that shutdown() reliably waits for it to drop out if @@ -3135,11 +3141,18 @@ impl TenantShard { .or_insert_with(|| Arc::new(GcCompactionQueue::new())) .clone() }; + let gc_compaction_strategy = self + .feature_resolver + .evaluate_multivariate("gc-comapction-strategy", self.tenant_shard_id.tenant_id) + .ok(); + let span = if let Some(gc_compaction_strategy) = gc_compaction_strategy { + info_span!("gc_compact_timeline", timeline_id = %timeline.timeline_id, strategy = %gc_compaction_strategy) + } else { + info_span!("gc_compact_timeline", timeline_id = %timeline.timeline_id) + }; outcome = queue .iteration(cancel, ctx, &self.gc_block, &timeline) - .instrument( - info_span!("gc_compact_timeline", timeline_id = %timeline.timeline_id), - ) + .instrument(span) .await?; } @@ -4247,6 +4260,7 @@ impl TenantShard { deletion_queue_client: DeletionQueueClient, l0_flush_global_state: L0FlushGlobalState, basebackup_prepare_sender: BasebackupPrepareSender, + feature_resolver: FeatureResolver, ) -> TenantShard { assert!(!attached_conf.location.generation.is_none()); @@ -4351,6 +4365,7 @@ impl TenantShard { gc_block: Default::default(), l0_flush_global_state, basebackup_prepare_sender, + feature_resolver, } } @@ -5873,6 +5888,7 @@ pub(crate) mod harness { // TODO: ideally we should run all unit tests with both configs L0FlushGlobalState::new(L0FlushConfig::default()), basebackup_requst_sender, + FeatureResolver::new_disabled(), )); let preload = tenant diff --git a/workspace_hack/Cargo.toml b/workspace_hack/Cargo.toml index 726d7c20c9..2b07889871 100644 --- a/workspace_hack/Cargo.toml +++ b/workspace_hack/Cargo.toml @@ -41,10 +41,8 @@ env_logger = { version = "0.11" } fail = { version = "0.5", default-features = false, features = ["failpoints"] } form_urlencoded = { version = "1" } futures-channel = { version = "0.3", features = ["sink"] } -futures-core = { version = "0.3" } futures-executor = { version = "0.3" } futures-io = { version = "0.3" } -futures-task = { version = "0.3", default-features = false, features = ["std"] } futures-util = { version = "0.3", features = ["channel", "io", "sink"] } generic-array = { version = "0.14", default-features = false, features = ["more_lengths", "zeroize"] } getrandom = { version = "0.2", default-features = false, features = ["std"] } @@ -74,7 +72,6 @@ num-traits = { version = "0.2", features = ["i128", "libm"] } once_cell = { version = "1" } p256 = { version = "0.13", features = ["jwk"] } parquet = { version = "53", default-features = false, features = ["zstd"] } -percent-encoding = { version = "2" } prost = { version = "0.13", features = ["no-recursion-limit", "prost-derive"] } rand = { version = "0.8", features = ["small_rng"] } regex = { version = "1" } From 23fc611461fb5f9026bd4bb3a4e3a0efcb8630b0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lassi=20P=C3=B6l=C3=B6nen?= Date: Mon, 26 May 2025 17:57:09 +0300 Subject: [PATCH 07/48] Add metadata to pgaudit log logline (#11933) Previously we were using project-id/endpoint-id as SYSLOGTAG, which has a limit of 32 characters, so the endpoint-id got truncated. The output is now in RFC5424 format, where the message is json encoded with additional metadata `endpoint_id` and `project_id` Also as pgaudit logs multiline messages, we now detect this by parsing the timestamp in the specific format, and consider non-matching lines to belong in the previous log message. Using syslog structured-data would be an alternative, but leaning towards json due to being somewhat more generic. --- compute_tools/src/compute.rs | 25 +++++++------------ .../compute_audit_rsyslog_template.conf | 22 +++++++++++++--- compute_tools/src/rsyslog.rs | 6 +++-- 3 files changed, 31 insertions(+), 22 deletions(-) diff --git a/compute_tools/src/compute.rs b/compute_tools/src/compute.rs index f494e2444a..cb857e0a3e 100644 --- a/compute_tools/src/compute.rs +++ b/compute_tools/src/compute.rs @@ -695,25 +695,18 @@ impl ComputeNode { let log_directory_path = Path::new(&self.params.pgdata).join("log"); let log_directory_path = log_directory_path.to_string_lossy().to_string(); - // Add project_id,endpoint_id tag to identify the logs. + // Add project_id,endpoint_id to identify the logs. // // These ids are passed from cplane, - // for backwards compatibility (old computes that don't have them), - // we set them to None. - // TODO: Clean up this code when all computes have them. - let tag: Option = match ( - pspec.spec.project_id.as_deref(), - pspec.spec.endpoint_id.as_deref(), - ) { - (Some(project_id), Some(endpoint_id)) => { - Some(format!("{project_id}/{endpoint_id}")) - } - (Some(project_id), None) => Some(format!("{project_id}/None")), - (None, Some(endpoint_id)) => Some(format!("None,{endpoint_id}")), - (None, None) => None, - }; + let endpoint_id = pspec.spec.endpoint_id.as_deref().unwrap_or(""); + let project_id = pspec.spec.project_id.as_deref().unwrap_or(""); - configure_audit_rsyslog(log_directory_path.clone(), tag, &remote_endpoint)?; + configure_audit_rsyslog( + log_directory_path.clone(), + endpoint_id, + project_id, + &remote_endpoint, + )?; // Launch a background task to clean up the audit logs launch_pgaudit_gc(log_directory_path); diff --git a/compute_tools/src/config_template/compute_audit_rsyslog_template.conf b/compute_tools/src/config_template/compute_audit_rsyslog_template.conf index 9ca7e36738..48b1a6f5c3 100644 --- a/compute_tools/src/config_template/compute_audit_rsyslog_template.conf +++ b/compute_tools/src/config_template/compute_audit_rsyslog_template.conf @@ -2,10 +2,24 @@ module(load="imfile") # Input configuration for log files in the specified directory -# Replace {log_directory} with the directory containing the log files -input(type="imfile" File="{log_directory}/*.log" Tag="{tag}" Severity="info" Facility="local0") +# The messages can be multiline. The start of the message is a timestamp +# in "%Y-%m-%d %H:%M:%S.%3N GMT" (so timezone hardcoded). +# Replace log_directory with the directory containing the log files +input(type="imfile" File="{log_directory}/*.log" + Tag="pgaudit_log" Severity="info" Facility="local5" + startmsg.regex="^[[:digit:]]{{4}}-[[:digit:]]{{2}}-[[:digit:]]{{2}} [[:digit:]]{{2}}:[[:digit:]]{{2}}:[[:digit:]]{{2}}.[[:digit:]]{{3}} GMT,") + # the directory to store rsyslog state files global(workDirectory="/var/log/rsyslog") -# Forward logs to remote syslog server -*.* @@{remote_endpoint} +# Construct json, endpoint_id and project_id as additional metadata +set $.json_log!endpoint_id = "{endpoint_id}"; +set $.json_log!project_id = "{project_id}"; +set $.json_log!msg = $msg; + +# Template suitable for rfc5424 syslog format +template(name="PgAuditLog" type="string" + string="<%PRI%>1 %TIMESTAMP:::date-rfc3339% %HOSTNAME% - - - - %$.json_log%") + +# Forward to remote syslog receiver (@@:;format +local5.info @@{remote_endpoint};PgAuditLog diff --git a/compute_tools/src/rsyslog.rs b/compute_tools/src/rsyslog.rs index c873697623..3bc2e72b19 100644 --- a/compute_tools/src/rsyslog.rs +++ b/compute_tools/src/rsyslog.rs @@ -84,13 +84,15 @@ fn restart_rsyslog() -> Result<()> { pub fn configure_audit_rsyslog( log_directory: String, - tag: Option, + endpoint_id: &str, + project_id: &str, remote_endpoint: &str, ) -> Result<()> { let config_content: String = format!( include_str!("config_template/compute_audit_rsyslog_template.conf"), log_directory = log_directory, - tag = tag.unwrap_or("".to_string()), + endpoint_id = endpoint_id, + project_id = project_id, remote_endpoint = remote_endpoint ); From 3e86008e66e76839b31c112dea52671604df8f23 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Arpad=20M=C3=BCller?= Date: Tue, 27 May 2025 01:23:58 +0200 Subject: [PATCH 08/48] read-only timelines (#12015) Support timeline creations on the storage controller to opt out from their creation on the safekeepers, introducing the read-only timelines concept. Read only timelines: * will never receive WAL of their own, so it's fine to not create them on the safekeepers * the property is non-transitive. children of read-only timelines aren't neccessarily read-only themselves. This feature can be used for snapshots, to prevent the safekeepers from being overloaded by empty timelines that won't ever get written to. In the current world, this is not a problem, because timelines are created implicitly by the compute connecting to a safekeeper that doesn't have the timeline yet. In the future however, where the storage controller creates timelines eagerly, we should watch out for that. We represent read-only timelines in the storage controller database so that we ensure that they never touch the safekeepers at all. Especially we don't want them to cause a mess during the importing process of the timelines from the cplane to the storcon database. In a hypothetical future where we have a feature to detach timelines from safekeepers, we'll either need to find a way to distinguish the two, or if not, asking safekeepers to list the (empty) timeline prefix and delete everything from it isn't a big issue either. This patch will unconditionally hit the new safekeeper timeline creation path for read-only timelines, without them needing the `--timelines-onto-safekeepers` flag enabled. This is done because it's lower risk (no safekeepers or computes involved at all) and gives us some initial way to verify at least some parts of that code in prod. https://github.com/neondatabase/cloud/issues/29435 https://github.com/neondatabase/neon/issues/11670 --- control_plane/src/bin/neon_local.rs | 1 + libs/pageserver_api/src/models.rs | 2 + pageserver/src/http/openapi_spec.yml | 2 + pageserver/src/http/routes.rs | 1 + storage_controller/src/service.rs | 11 ++++- .../src/service/safekeeper_service.rs | 40 +++++++++++++++---- .../regress/test_timeline_detach_ancestor.py | 22 +++++++++- 7 files changed, 67 insertions(+), 12 deletions(-) diff --git a/control_plane/src/bin/neon_local.rs b/control_plane/src/bin/neon_local.rs index 3bceef8fa7..ef6985d697 100644 --- a/control_plane/src/bin/neon_local.rs +++ b/control_plane/src/bin/neon_local.rs @@ -1279,6 +1279,7 @@ async fn handle_timeline(cmd: &TimelineCmd, env: &mut local_env::LocalEnv) -> Re mode: pageserver_api::models::TimelineCreateRequestMode::Branch { ancestor_timeline_id, ancestor_start_lsn: start_lsn, + read_only: false, pg_version: None, }, }; diff --git a/libs/pageserver_api/src/models.rs b/libs/pageserver_api/src/models.rs index 383939a13f..9f3736d57a 100644 --- a/libs/pageserver_api/src/models.rs +++ b/libs/pageserver_api/src/models.rs @@ -402,6 +402,8 @@ pub enum TimelineCreateRequestMode { // using a flattened enum, so, it was an accepted field, and // we continue to accept it by having it here. pg_version: Option, + #[serde(default, skip_serializing_if = "std::ops::Not::not")] + read_only: bool, }, ImportPgdata { import_pgdata: TimelineCreateRequestModeImportPgdata, diff --git a/pageserver/src/http/openapi_spec.yml b/pageserver/src/http/openapi_spec.yml index 7ea148971f..cf99cb110c 100644 --- a/pageserver/src/http/openapi_spec.yml +++ b/pageserver/src/http/openapi_spec.yml @@ -626,6 +626,8 @@ paths: format: hex pg_version: type: integer + read_only: + type: boolean existing_initdb_timeline_id: type: string format: hex diff --git a/pageserver/src/http/routes.rs b/pageserver/src/http/routes.rs index 0d6791cddd..65e24ff3e9 100644 --- a/pageserver/src/http/routes.rs +++ b/pageserver/src/http/routes.rs @@ -572,6 +572,7 @@ async fn timeline_create_handler( TimelineCreateRequestMode::Branch { ancestor_timeline_id, ancestor_start_lsn, + read_only: _, pg_version: _, } => tenant::CreateTimelineParams::Branch(tenant::CreateTimelineParamsBranch { new_timeline_id, diff --git a/storage_controller/src/service.rs b/storage_controller/src/service.rs index 7e4bb627af..d8167e9d94 100644 --- a/storage_controller/src/service.rs +++ b/storage_controller/src/service.rs @@ -3823,6 +3823,13 @@ impl Service { .await; failpoint_support::sleep_millis_async!("tenant-create-timeline-shared-lock"); let is_import = create_req.is_import(); + let read_only = matches!( + create_req.mode, + models::TimelineCreateRequestMode::Branch { + read_only: true, + .. + } + ); if is_import { // Ensure that there is no split on-going. @@ -3895,13 +3902,13 @@ impl Service { } None - } else if safekeepers { + } else if safekeepers || read_only { // Note that for imported timelines, we do not create the timeline on the safekeepers // straight away. Instead, we do it once the import finalized such that we know what // start LSN to provide for the safekeepers. This is done in // [`Self::finalize_timeline_import`]. let res = self - .tenant_timeline_create_safekeepers(tenant_id, &timeline_info) + .tenant_timeline_create_safekeepers(tenant_id, &timeline_info, read_only) .instrument(tracing::info_span!("timeline_create_safekeepers", %tenant_id, timeline_id=%timeline_info.timeline_id)) .await?; Some(res) diff --git a/storage_controller/src/service/safekeeper_service.rs b/storage_controller/src/service/safekeeper_service.rs index cd5ace449d..1f673fe445 100644 --- a/storage_controller/src/service/safekeeper_service.rs +++ b/storage_controller/src/service/safekeeper_service.rs @@ -208,6 +208,7 @@ impl Service { self: &Arc, tenant_id: TenantId, timeline_info: &TimelineInfo, + read_only: bool, ) -> Result { let timeline_id = timeline_info.timeline_id; let pg_version = timeline_info.pg_version * 10000; @@ -220,7 +221,11 @@ impl Service { let start_lsn = timeline_info.last_record_lsn; // Choose initial set of safekeepers respecting affinity - let sks = self.safekeepers_for_new_timeline().await?; + let sks = if !read_only { + self.safekeepers_for_new_timeline().await? + } else { + Vec::new() + }; let sks_persistence = sks.iter().map(|sk| sk.id.0 as i64).collect::>(); // Add timeline to db let mut timeline_persist = TimelinePersistence { @@ -253,6 +258,16 @@ impl Service { ))); } } + let ret = SafekeepersInfo { + generation: timeline_persist.generation as u32, + safekeepers: sks.clone(), + tenant_id, + timeline_id, + }; + if read_only { + return Ok(ret); + } + // Create the timeline on a quorum of safekeepers let remaining = self .tenant_timeline_create_safekeepers_quorum( @@ -316,12 +331,7 @@ impl Service { } } - Ok(SafekeepersInfo { - generation: timeline_persist.generation as u32, - safekeepers: sks, - tenant_id, - timeline_id, - }) + Ok(ret) } pub(crate) async fn tenant_timeline_create_safekeepers_until_success( @@ -336,8 +346,10 @@ impl Service { return Err(TimelineImportFinalizeError::ShuttingDown); } + // This function is only used in non-read-only scenarios + let read_only = false; let res = self - .tenant_timeline_create_safekeepers(tenant_id, &timeline_info) + .tenant_timeline_create_safekeepers(tenant_id, &timeline_info, read_only) .await; match res { @@ -410,6 +422,18 @@ impl Service { .chain(tl.sk_set.iter()) .collect::>(); + // The timeline has no safekeepers: we need to delete it from the db manually, + // as no safekeeper reconciler will get to it + if all_sks.is_empty() { + if let Err(err) = self + .persistence + .delete_timeline(tenant_id, timeline_id) + .await + { + tracing::warn!(%tenant_id, %timeline_id, "couldn't delete timeline from db: {err}"); + } + } + // Schedule reconciliations for &sk_id in all_sks.iter() { let pending_op = TimelinePendingOpPersistence { diff --git a/test_runner/regress/test_timeline_detach_ancestor.py b/test_runner/regress/test_timeline_detach_ancestor.py index d42c5d403e..f0810270b1 100644 --- a/test_runner/regress/test_timeline_detach_ancestor.py +++ b/test_runner/regress/test_timeline_detach_ancestor.py @@ -10,6 +10,7 @@ from queue import Empty, Queue from threading import Barrier import pytest +import requests from fixtures.common_types import Lsn, TimelineArchivalState, TimelineId from fixtures.log_helper import log from fixtures.neon_fixtures import ( @@ -401,8 +402,25 @@ def test_ancestor_detach_behavior_v2(neon_env_builder: NeonEnvBuilder, snapshots "earlier", ancestor_branch_name="main", ancestor_start_lsn=branchpoint_pipe ) - snapshot_branchpoint_old = env.create_branch( - "snapshot_branchpoint_old", ancestor_branch_name="main", ancestor_start_lsn=branchpoint_y + snapshot_branchpoint_old = TimelineId.generate() + + env.storage_controller.timeline_create( + env.initial_tenant, + { + "new_timeline_id": str(snapshot_branchpoint_old), + "ancestor_start_lsn": str(branchpoint_y), + "ancestor_timeline_id": str(env.initial_timeline), + "read_only": True, + }, + ) + sk = env.safekeepers[0] + assert sk + with pytest.raises(requests.exceptions.HTTPError, match="Not Found"): + sk.http_client().timeline_status( + tenant_id=env.initial_tenant, timeline_id=snapshot_branchpoint_old + ) + env.neon_cli.mappings_map_branch( + "snapshot_branchpoint_old", env.initial_tenant, snapshot_branchpoint_old ) snapshot_branchpoint = env.create_branch( From fe1513ca5774473214cadf7f82b36ed23d79c734 Mon Sep 17 00:00:00 2001 From: Tristan Partin Date: Mon, 26 May 2025 21:21:24 -0500 Subject: [PATCH 09/48] Add neon.safekeeper_conninfo_options GUC (#11901) In order to enable TLS connections between computes and safekeepers, we need to provide the control plane with a way to configure the various libpq keyword parameters, sslmode and sslrootcert. neon.safekeepers is a comma separated list of safekeepers formatted as host:port, so isn't available for extension in the same way that neon.pageserver_connstring is. This could be remedied in a future PR. Part-of: https://github.com/neondatabase/cloud/issues/25823 Link: https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS Signed-off-by: Tristan Partin --- libs/walproposer/src/walproposer.rs | 8 ++++++++ pgxn/neon/walproposer.c | 5 +++-- pgxn/neon/walproposer.h | 3 +++ pgxn/neon/walproposer_pg.c | 12 ++++++++++++ safekeeper/tests/walproposer_sim/simulation.rs | 1 + 5 files changed, 27 insertions(+), 2 deletions(-) diff --git a/libs/walproposer/src/walproposer.rs b/libs/walproposer/src/walproposer.rs index 4e50c21fca..e95494297c 100644 --- a/libs/walproposer/src/walproposer.rs +++ b/libs/walproposer/src/walproposer.rs @@ -1,6 +1,7 @@ #![allow(clippy::todo)] use std::ffi::CString; +use std::str::FromStr; use postgres_ffi::WAL_SEGMENT_SIZE; use utils::id::TenantTimelineId; @@ -173,6 +174,8 @@ pub struct Config { pub ttid: TenantTimelineId, /// List of safekeepers in format `host:port` pub safekeepers_list: Vec, + /// libpq connection info options + pub safekeeper_conninfo_options: String, /// Safekeeper reconnect timeout in milliseconds pub safekeeper_reconnect_timeout: i32, /// Safekeeper connection timeout in milliseconds @@ -202,6 +205,9 @@ impl Wrapper { .into_bytes_with_nul(); assert!(safekeepers_list_vec.len() == safekeepers_list_vec.capacity()); let safekeepers_list = safekeepers_list_vec.as_mut_ptr() as *mut std::ffi::c_char; + let safekeeper_conninfo_options = CString::from_str(&config.safekeeper_conninfo_options) + .unwrap() + .into_raw(); let callback_data = Box::into_raw(Box::new(api)) as *mut ::std::os::raw::c_void; @@ -209,6 +215,7 @@ impl Wrapper { neon_tenant, neon_timeline, safekeepers_list, + safekeeper_conninfo_options, safekeeper_reconnect_timeout: config.safekeeper_reconnect_timeout, safekeeper_connection_timeout: config.safekeeper_connection_timeout, wal_segment_size: WAL_SEGMENT_SIZE as i32, // default 16MB @@ -576,6 +583,7 @@ mod tests { let config = crate::walproposer::Config { ttid, safekeepers_list: vec!["localhost:5000".to_string()], + safekeeper_conninfo_options: String::new(), safekeeper_reconnect_timeout: 1000, safekeeper_connection_timeout: 10000, sync_safekeepers: true, diff --git a/pgxn/neon/walproposer.c b/pgxn/neon/walproposer.c index 3befb42030..f42103c7cd 100644 --- a/pgxn/neon/walproposer.c +++ b/pgxn/neon/walproposer.c @@ -155,8 +155,9 @@ WalProposerCreate(WalProposerConfig *config, walproposer_api api) int written = 0; written = snprintf((char *) &sk->conninfo, MAXCONNINFO, - "host=%s port=%s dbname=replication options='-c timeline_id=%s tenant_id=%s'", - sk->host, sk->port, wp->config->neon_timeline, wp->config->neon_tenant); + "%s host=%s port=%s dbname=replication options='-c timeline_id=%s tenant_id=%s'", + wp->config->safekeeper_conninfo_options, sk->host, sk->port, + wp->config->neon_timeline, wp->config->neon_tenant); if (written > MAXCONNINFO || written < 0) wp_log(FATAL, "could not create connection string for safekeeper %s:%s", sk->host, sk->port); } diff --git a/pgxn/neon/walproposer.h b/pgxn/neon/walproposer.h index 83ef72d3d7..cca20e746b 100644 --- a/pgxn/neon/walproposer.h +++ b/pgxn/neon/walproposer.h @@ -714,6 +714,9 @@ typedef struct WalProposerConfig */ char *safekeepers_list; + /* libpq connection info options. */ + char *safekeeper_conninfo_options; + /* * WalProposer reconnects to offline safekeepers once in this interval. * Time is in milliseconds. diff --git a/pgxn/neon/walproposer_pg.c b/pgxn/neon/walproposer_pg.c index 17582405db..d15bf91d24 100644 --- a/pgxn/neon/walproposer_pg.c +++ b/pgxn/neon/walproposer_pg.c @@ -64,6 +64,7 @@ char *wal_acceptors_list = ""; int wal_acceptor_reconnect_timeout = 1000; int wal_acceptor_connection_timeout = 10000; int safekeeper_proto_version = 3; +char *safekeeper_conninfo_options = ""; /* Set to true in the walproposer bgw. */ static bool am_walproposer; @@ -119,6 +120,7 @@ init_walprop_config(bool syncSafekeepers) walprop_config.neon_timeline = neon_timeline; /* WalProposerCreate scribbles directly on it, so pstrdup */ walprop_config.safekeepers_list = pstrdup(wal_acceptors_list); + walprop_config.safekeeper_conninfo_options = pstrdup(safekeeper_conninfo_options); walprop_config.safekeeper_reconnect_timeout = wal_acceptor_reconnect_timeout; walprop_config.safekeeper_connection_timeout = wal_acceptor_connection_timeout; walprop_config.wal_segment_size = wal_segment_size; @@ -203,6 +205,16 @@ nwp_register_gucs(void) * GUC_LIST_QUOTE */ NULL, assign_neon_safekeepers, NULL); + DefineCustomStringVariable( + "neon.safekeeper_conninfo_options", + "libpq keyword parameters and values to apply to safekeeper connections", + NULL, + &safekeeper_conninfo_options, + "", + PGC_POSTMASTER, + 0, + NULL, NULL, NULL); + DefineCustomIntVariable( "neon.safekeeper_reconnect_timeout", "Walproposer reconnects to offline safekeepers once in this interval.", diff --git a/safekeeper/tests/walproposer_sim/simulation.rs b/safekeeper/tests/walproposer_sim/simulation.rs index f314143952..70fecfbe22 100644 --- a/safekeeper/tests/walproposer_sim/simulation.rs +++ b/safekeeper/tests/walproposer_sim/simulation.rs @@ -87,6 +87,7 @@ impl WalProposer { let config = Config { ttid, safekeepers_list: addrs, + safekeeper_conninfo_options: String::new(), safekeeper_reconnect_timeout: 1000, safekeeper_connection_timeout: 5000, sync_safekeepers, From dd501554c9af60083ff93339b06cc4ff2022fce0 Mon Sep 17 00:00:00 2001 From: a-masterov <72613290+a-masterov@users.noreply.github.com> Date: Tue, 27 May 2025 10:54:59 +0200 Subject: [PATCH 10/48] add a script to run the test for online-advisor as a regular user. (#12017) ## Problem The regression test for the extension online_advisor fails on the staging instance due to a lack of permission to alter the database. ## Summary of changes A script was added to work around this problem. --------- Co-authored-by: Alexander Lakhin --- docker-compose/ext-src/online_advisor-src/neon-test.sh | 6 ++++++ .../ext-src/online_advisor-src/regular-test.sh | 9 +++++++++ 2 files changed, 15 insertions(+) create mode 100755 docker-compose/ext-src/online_advisor-src/neon-test.sh create mode 100755 docker-compose/ext-src/online_advisor-src/regular-test.sh diff --git a/docker-compose/ext-src/online_advisor-src/neon-test.sh b/docker-compose/ext-src/online_advisor-src/neon-test.sh new file mode 100755 index 0000000000..db5c2821fa --- /dev/null +++ b/docker-compose/ext-src/online_advisor-src/neon-test.sh @@ -0,0 +1,6 @@ +#!/bin/sh +set -ex +cd "$(dirname "${0}")" +if [ -f Makefile ]; then + make installcheck +fi diff --git a/docker-compose/ext-src/online_advisor-src/regular-test.sh b/docker-compose/ext-src/online_advisor-src/regular-test.sh new file mode 100755 index 0000000000..e94f03aa70 --- /dev/null +++ b/docker-compose/ext-src/online_advisor-src/regular-test.sh @@ -0,0 +1,9 @@ +#!/bin/sh +set -ex +cd "$(dirname ${0})" +[ -f Makefile ] || exit 0 +dropdb --if-exist contrib_regression +createdb contrib_regression +PG_REGRESS=$(dirname "$(pg_config --pgxs)")/../test/regress/pg_regress +TESTS=$(echo sql/* | sed 's|sql/||g; s|\.sql||g') +${PG_REGRESS} --use-existing --inputdir=./ --bindir='/usr/local/pgsql/bin' --dbname=contrib_regression ${TESTS} From 9657fbc1941de6c5396b1d59ff12811c6d38c00c Mon Sep 17 00:00:00 2001 From: Vlad Lazar Date: Tue, 27 May 2025 10:52:59 +0100 Subject: [PATCH 11/48] pageserver: add and stabilize import chaos test (#11982) ## Problem Test coverage of timeline imports is lacking. ## Summary of changes This PR adds a chaos import test. It runs an import while injecting various chaos events in the environment. All the commits that follow the test fix various issues that were surfaced by it. Closes https://github.com/neondatabase/neon/issues/10191 --- pageserver/src/http/routes.rs | 17 +- pageserver/src/tenant.rs | 77 +++-- .../src/tenant/timeline/import_pgdata.rs | 10 +- .../src/tenant/timeline/import_pgdata/flow.rs | 34 ++- storage_controller/src/service.rs | 10 + test_runner/fixtures/neon_fixtures.py | 41 ++- test_runner/regress/test_import_pgdata.py | 267 +++++++++++++++++- 7 files changed, 413 insertions(+), 43 deletions(-) diff --git a/pageserver/src/http/routes.rs b/pageserver/src/http/routes.rs index 65e24ff3e9..c449e3373f 100644 --- a/pageserver/src/http/routes.rs +++ b/pageserver/src/http/routes.rs @@ -370,6 +370,18 @@ impl From for ApiError { } } +impl From for ApiError { + fn from(err: crate::tenant::FinalizeTimelineImportError) -> ApiError { + use crate::tenant::FinalizeTimelineImportError::*; + match err { + ImportTaskStillRunning => { + ApiError::ResourceUnavailable("Import task still running".into()) + } + ShuttingDown => ApiError::ShuttingDown, + } + } +} + // Helper function to construct a TimelineInfo struct for a timeline async fn build_timeline_info( timeline: &Arc, @@ -3533,10 +3545,7 @@ async fn activate_post_import_handler( tenant.wait_to_become_active(ACTIVE_TENANT_TIMEOUT).await?; - tenant - .finalize_importing_timeline(timeline_id) - .await - .map_err(ApiError::InternalServerError)?; + tenant.finalize_importing_timeline(timeline_id).await?; match tenant.get_timeline(timeline_id, false) { Ok(_timeline) => { diff --git a/pageserver/src/tenant.rs b/pageserver/src/tenant.rs index 7e006ef9e6..86731fb666 100644 --- a/pageserver/src/tenant.rs +++ b/pageserver/src/tenant.rs @@ -864,6 +864,14 @@ impl Debug for SetStoppingError { } } +#[derive(thiserror::Error, Debug)] +pub(crate) enum FinalizeTimelineImportError { + #[error("Import task not done yet")] + ImportTaskStillRunning, + #[error("Shutting down")] + ShuttingDown, +} + /// Arguments to [`TenantShard::create_timeline`]. /// /// Not usable as an idempotency key for timeline creation because if [`CreateTimelineParamsBranch::ancestor_start_lsn`] @@ -1150,10 +1158,20 @@ impl TenantShard { ctx, )?; let disk_consistent_lsn = timeline.get_disk_consistent_lsn(); - anyhow::ensure!( - disk_consistent_lsn.is_valid(), - "Timeline {tenant_id}/{timeline_id} has invalid disk_consistent_lsn" - ); + + if !disk_consistent_lsn.is_valid() { + // As opposed to normal timelines which get initialised with a disk consitent LSN + // via initdb, imported timelines start from 0. If the import task stops before + // it advances disk consitent LSN, allow it to resume. + let in_progress_import = import_pgdata + .as_ref() + .map(|import| !import.is_done()) + .unwrap_or(false); + if !in_progress_import { + anyhow::bail!("Timeline {tenant_id}/{timeline_id} has invalid disk_consistent_lsn"); + } + } + assert_eq!( disk_consistent_lsn, metadata.disk_consistent_lsn(), @@ -1247,20 +1265,25 @@ impl TenantShard { } } - // Sanity check: a timeline should have some content. - anyhow::ensure!( - ancestor.is_some() - || timeline - .layers - .read() - .await - .layer_map() - .expect("currently loading, layer manager cannot be shutdown already") - .iter_historic_layers() - .next() - .is_some(), - "Timeline has no ancestor and no layer files" - ); + if disk_consistent_lsn.is_valid() { + // Sanity check: a timeline should have some content. + // Exception: importing timelines might not yet have any + anyhow::ensure!( + ancestor.is_some() + || timeline + .layers + .read() + .await + .layer_map() + .expect( + "currently loading, layer manager cannot be shutdown already" + ) + .iter_historic_layers() + .next() + .is_some(), + "Timeline has no ancestor and no layer files" + ); + } Ok(TimelineInitAndSyncResult::ReadyToActivate) } @@ -2860,13 +2883,13 @@ impl TenantShard { pub(crate) async fn finalize_importing_timeline( &self, timeline_id: TimelineId, - ) -> anyhow::Result<()> { + ) -> Result<(), FinalizeTimelineImportError> { let timeline = { let locked = self.timelines_importing.lock().unwrap(); match locked.get(&timeline_id) { Some(importing_timeline) => { if !importing_timeline.import_task_handle.is_finished() { - return Err(anyhow::anyhow!("Import task not done yet")); + return Err(FinalizeTimelineImportError::ImportTaskStillRunning); } importing_timeline.timeline.clone() @@ -2879,8 +2902,13 @@ impl TenantShard { timeline .remote_client - .schedule_index_upload_for_import_pgdata_finalize()?; - timeline.remote_client.wait_completion().await?; + .schedule_index_upload_for_import_pgdata_finalize() + .map_err(|_err| FinalizeTimelineImportError::ShuttingDown)?; + timeline + .remote_client + .wait_completion() + .await + .map_err(|_err| FinalizeTimelineImportError::ShuttingDown)?; self.timelines_importing .lock() @@ -3484,8 +3512,9 @@ impl TenantShard { let mut timelines_importing = self.timelines_importing.lock().unwrap(); timelines_importing .drain() - .for_each(|(_timeline_id, importing_timeline)| { - importing_timeline.shutdown(); + .for_each(|(timeline_id, importing_timeline)| { + let span = tracing::info_span!("importing_timeline_shutdown", %timeline_id); + js.spawn(async move { importing_timeline.shutdown().instrument(span).await }); }); } // test_long_timeline_create_then_tenant_delete is leaning on this message diff --git a/pageserver/src/tenant/timeline/import_pgdata.rs b/pageserver/src/tenant/timeline/import_pgdata.rs index 658d867c18..db62e9000c 100644 --- a/pageserver/src/tenant/timeline/import_pgdata.rs +++ b/pageserver/src/tenant/timeline/import_pgdata.rs @@ -25,8 +25,11 @@ pub(crate) struct ImportingTimeline { } impl ImportingTimeline { - pub(crate) fn shutdown(self) { + pub(crate) async fn shutdown(self) { self.import_task_handle.abort(); + let _ = self.import_task_handle.await; + + self.timeline.remote_client.shutdown().await; } } @@ -93,6 +96,11 @@ pub async fn doit( ); } + timeline + .remote_client + .schedule_index_upload_for_file_changes()?; + timeline.remote_client.wait_completion().await?; + // Communicate that shard is done. // Ensure at-least-once delivery of the upcall to storage controller // before we mark the task as done and never come here again. diff --git a/pageserver/src/tenant/timeline/import_pgdata/flow.rs b/pageserver/src/tenant/timeline/import_pgdata/flow.rs index 3e10a4e6d6..2ba4ca69ac 100644 --- a/pageserver/src/tenant/timeline/import_pgdata/flow.rs +++ b/pageserver/src/tenant/timeline/import_pgdata/flow.rs @@ -113,14 +113,14 @@ async fn run_v1( let plan_hash = hasher.finish(); if let Some(progress) = &import_progress { - if plan_hash != progress.import_plan_hash { - anyhow::bail!("Import plan does not match storcon metadata"); - } - // Handle collisions on jobs of unequal length if progress.jobs != plan.jobs.len() { anyhow::bail!("Import plan job length does not match storcon metadata") } + + if plan_hash != progress.import_plan_hash { + anyhow::bail!("Import plan does not match storcon metadata"); + } } pausable_failpoint!("import-timeline-pre-execute-pausable"); @@ -218,6 +218,19 @@ impl Planner { checkpoint_buf, ))); + // Sort the tasks by the key ranges they handle. + // The plan being generated here needs to be stable across invocations + // of this method. + self.tasks.sort_by_key(|task| match task { + AnyImportTask::SingleKey(key) => (key.key, key.key.next()), + AnyImportTask::RelBlocks(rel_blocks) => { + (rel_blocks.key_range.start, rel_blocks.key_range.end) + } + AnyImportTask::SlruBlocks(slru_blocks) => { + (slru_blocks.key_range.start, slru_blocks.key_range.end) + } + }); + // Assigns parts of key space to later parallel jobs let mut last_end_key = Key::MIN; let mut current_chunk = Vec::new(); @@ -426,6 +439,8 @@ impl Plan { })); }, maybe_complete_job_idx = work.next() => { + pausable_failpoint!("import-task-complete-pausable"); + match maybe_complete_job_idx { Some(Ok((job_idx, res))) => { assert!(last_completed_job_idx.checked_add(1).unwrap() == job_idx); @@ -440,6 +455,9 @@ impl Plan { import_plan_hash, }; + timeline.remote_client.schedule_index_upload_for_file_changes()?; + timeline.remote_client.wait_completion().await?; + storcon_client.put_timeline_import_status( timeline.tenant_shard_id, timeline.timeline_id, @@ -640,7 +658,11 @@ impl Hash for ImportSingleKeyTask { let ImportSingleKeyTask { key, buf } = self; key.hash(state); - buf.hash(state); + // The key value might not have a stable binary representation. + // For instance, the db directory uses an unstable hash-map. + // To work around this we are a bit lax here and only hash the + // size of the buffer which must be consistent. + buf.len().hash(state); } } @@ -915,7 +937,7 @@ impl ChunkProcessingJob { let guard = timeline.layers.read().await; let existing_layer = guard.try_get_from_key(&desc.key()); if let Some(layer) = existing_layer { - if layer.metadata().generation != timeline.generation { + if layer.metadata().generation == timeline.generation { return Err(anyhow::anyhow!( "Import attempted to rewrite layer file in the same generation: {}", layer.local_path() diff --git a/storage_controller/src/service.rs b/storage_controller/src/service.rs index d8167e9d94..d284747f73 100644 --- a/storage_controller/src/service.rs +++ b/storage_controller/src/service.rs @@ -3922,6 +3922,11 @@ impl Service { }) } + #[instrument(skip_all, fields( + tenant_id=%req.tenant_shard_id.tenant_id, + shard_id=%req.tenant_shard_id.shard_slug(), + timeline_id=%req.timeline_id, + ))] pub(crate) async fn handle_timeline_shard_import_progress( self: &Arc, req: TimelineImportStatusRequest, @@ -3971,6 +3976,11 @@ impl Service { }) } + #[instrument(skip_all, fields( + tenant_id=%req.tenant_shard_id.tenant_id, + shard_id=%req.tenant_shard_id.shard_slug(), + timeline_id=%req.timeline_id, + ))] pub(crate) async fn handle_timeline_shard_import_progress_upcall( self: &Arc, req: PutTimelineImportStatusRequest, diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py index dda4d40a11..7f4150b580 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -404,6 +404,29 @@ class PageserverTracingConfig: return ("tracing", value) +@dataclass +class PageserverImportConfig: + import_job_concurrency: int + import_job_soft_size_limit: int + import_job_checkpoint_threshold: int + + @staticmethod + def default() -> PageserverImportConfig: + return PageserverImportConfig( + import_job_concurrency=4, + import_job_soft_size_limit=512 * 1024, + import_job_checkpoint_threshold=4, + ) + + def to_config_key_value(self) -> tuple[str, dict[str, Any]]: + value = { + "import_job_concurrency": self.import_job_concurrency, + "import_job_soft_size_limit": self.import_job_soft_size_limit, + "import_job_checkpoint_threshold": self.import_job_checkpoint_threshold, + } + return ("timeline_import_config", value) + + class NeonEnvBuilder: """ Builder object to create a Neon runtime environment @@ -454,6 +477,7 @@ class NeonEnvBuilder: pageserver_wal_receiver_protocol: PageserverWalReceiverProtocol | None = None, pageserver_get_vectored_concurrent_io: str | None = None, pageserver_tracing_config: PageserverTracingConfig | None = None, + pageserver_import_config: PageserverImportConfig | None = None, ): self.repo_dir = repo_dir self.rust_log_override = rust_log_override @@ -511,6 +535,7 @@ class NeonEnvBuilder: ) self.pageserver_tracing_config = pageserver_tracing_config + self.pageserver_import_config = pageserver_import_config self.pageserver_default_tenant_config_compaction_algorithm: dict[str, Any] | None = ( pageserver_default_tenant_config_compaction_algorithm @@ -1179,6 +1204,10 @@ class NeonEnv: self.pageserver_wal_receiver_protocol = config.pageserver_wal_receiver_protocol self.pageserver_get_vectored_concurrent_io = config.pageserver_get_vectored_concurrent_io self.pageserver_tracing_config = config.pageserver_tracing_config + if config.pageserver_import_config is None: + self.pageserver_import_config = PageserverImportConfig.default() + else: + self.pageserver_import_config = config.pageserver_import_config # Create the neon_local's `NeonLocalInitConf` cfg: dict[str, Any] = { @@ -1258,12 +1287,6 @@ class NeonEnv: "no_sync": True, # Look for gaps in WAL received from safekeepeers "validate_wal_contiguity": True, - # TODO(vlad): make these configurable through the builder - "timeline_import_config": { - "import_job_concurrency": 4, - "import_job_soft_size_limit": 512 * 1024, - "import_job_checkpoint_threshold": 4, - }, } # Batching (https://github.com/neondatabase/neon/issues/9377): @@ -1325,6 +1348,12 @@ class NeonEnv: ps_cfg[key] = value + if self.pageserver_import_config is not None: + key, value = self.pageserver_import_config.to_config_key_value() + + if key not in ps_cfg: + ps_cfg[key] = value + # Create a corresponding NeonPageserver object ps = NeonPageserver( self, ps_id, port=pageserver_port, az_id=ps_cfg["availability_zone"] diff --git a/test_runner/regress/test_import_pgdata.py b/test_runner/regress/test_import_pgdata.py index 0472b92145..69cbdec5b0 100644 --- a/test_runner/regress/test_import_pgdata.py +++ b/test_runner/regress/test_import_pgdata.py @@ -1,7 +1,10 @@ import base64 +import concurrent.futures import json +import random +import threading import time -from enum import Enum +from enum import Enum, StrEnum from pathlib import Path from threading import Event @@ -11,7 +14,14 @@ import pytest from fixtures.common_types import Lsn, TenantId, TenantShardId, TimelineId from fixtures.fast_import import FastImport from fixtures.log_helper import log -from fixtures.neon_fixtures import NeonEnvBuilder, PgBin, PgProtocol, VanillaPostgres +from fixtures.neon_fixtures import ( + NeonEnvBuilder, + PageserverImportConfig, + PgBin, + PgProtocol, + StorageControllerMigrationConfig, + VanillaPostgres, +) from fixtures.pageserver.http import ( ImportPgdataIdemptencyKey, ) @@ -494,6 +504,259 @@ def test_import_respects_tenant_shutdown( wait_until(cplane_notified) +@skip_in_debug_build("Validation query takes too long in debug builds") +def test_import_chaos( + neon_env_builder: NeonEnvBuilder, vanilla_pg: VanillaPostgres, make_httpserver: HTTPServer +): + """ + Perform a timeline import while injecting chaos in the environment. + We expect that the import completes eventually, that it passes validation and + the resulting timeline can be written to. + """ + TARGET_RELBOCK_SIZE = 512 * 1024 * 1024 # 512 MiB + ALLOWED_IMPORT_RUNTIME = 90 # seconds + SHARD_COUNT = 4 + + neon_env_builder.num_pageservers = SHARD_COUNT + neon_env_builder.pageserver_import_config = PageserverImportConfig( + import_job_concurrency=1, + import_job_soft_size_limit=64 * 1024, + import_job_checkpoint_threshold=4, + ) + + # Set up mock control plane HTTP server to listen for import completions + import_completion_signaled = Event() + # There's some Python magic at play here. A list can be updated from the + # handler thread, but an optional cannot. Hence, use a list with one element. + import_error = [] + + def handler(request: Request) -> Response: + assert request.json is not None + + body = request.json + if "error" in body: + if body["error"]: + import_error.append(body["error"]) + + log.info(f"control plane /import_complete request: {request.json}") + import_completion_signaled.set() + return Response(json.dumps({}), status=200) + + cplane_mgmt_api_server = make_httpserver + cplane_mgmt_api_server.expect_request( + "/storage/api/v1/import_complete", method="PUT" + ).respond_with_handler(handler) + + # Plug the cplane mock in + neon_env_builder.control_plane_hooks_api = ( + f"http://{cplane_mgmt_api_server.host}:{cplane_mgmt_api_server.port}/storage/api/v1/" + ) + + # The import will specifiy a local filesystem path mocking remote storage + neon_env_builder.enable_pageserver_remote_storage(RemoteStorageKind.LOCAL_FS) + + vanilla_pg.start() + vanilla_pg.safe_psql("create user cloud_admin with password 'postgres' superuser") + vanilla_pg.safe_psql("""create table t (data char(900)) with (fillfactor = 10)""") + + nrows = 0 + while True: + relblock_size = vanilla_pg.safe_psql_scalar("select pg_relation_size('t')") + log.info( + f"relblock size: {relblock_size / 8192} pages (target: {TARGET_RELBOCK_SIZE // 8192}) pages" + ) + if relblock_size >= TARGET_RELBOCK_SIZE: + break + addrows = int((TARGET_RELBOCK_SIZE - relblock_size) // 8192) + assert addrows >= 1, "forward progress" + vanilla_pg.safe_psql( + f"insert into t select generate_series({nrows + 1}, {nrows + addrows})" + ) + nrows += addrows + + vanilla_pg.stop() + + env = neon_env_builder.init_configs() + env.start() + + # Pause after every import task to extend the test runtime and allow + # for more chaos injection. + for ps in env.pageservers: + ps.add_persistent_failpoint("import-task-complete-pausable", "sleep(5)") + + env.storage_controller.allowed_errors.extend( + [ + # The shard might have moved or the pageserver hosting the shard restarted + ".*Call to node.*management API.*failed.*", + # Migrations have their time outs set to 0 + ".*Timed out after.*downloading layers.*", + ".*Failed to prepare by downloading layers.*", + # The test may kill the storage controller or pageservers + ".*request was dropped before completing.*", + ] + ) + for ps in env.pageservers: + ps.allowed_errors.extend( + [ + # We might re-write a layer in a different generation if the import + # needs to redo some of the progress since not each job is checkpointed. + ".*was unlinked but was not dangling.*", + # The test may kill the storage controller or pageservers + ".*request was dropped before completing.*", + # Test can SIGTERM pageserver while it is downloading + ".*removing local file.*temp_download.*", + ".*Failed to flush heatmap.*", + # Test can SIGTERM the storage controller while pageserver + # is attempting to upcall. + ".*storage controller upcall failed.*timeline_import_status.*", + # TODO(vlad): TenantManager::reset_tenant returns a blanked anyhow error. + # It should return ResourceUnavailable or something that doesn't error log. + ".*activate_post_import.*InternalServerError.*tenant map is shutting down.*", + # TODO(vlad): How can this happen? + ".*Failed to download a remote file: deserialize index part file.*", + ".*Cancelled request finished with an error.*", + ] + ) + + importbucket_path = neon_env_builder.repo_dir / "test_import_chaos_bucket" + mock_import_bucket(vanilla_pg, importbucket_path) + + tenant_id = TenantId.generate() + timeline_id = TimelineId.generate() + idempotency = ImportPgdataIdemptencyKey.random() + + env.storage_controller.tenant_create( + tenant_id, shard_count=SHARD_COUNT, placement_policy={"Attached": 1} + ) + env.storage_controller.reconcile_until_idle() + + env.storage_controller.timeline_create( + tenant_id, + { + "new_timeline_id": str(timeline_id), + "import_pgdata": { + "idempotency_key": str(idempotency), + "location": {"LocalFs": {"path": str(importbucket_path.absolute())}}, + }, + }, + ) + + def chaos(stop_chaos: threading.Event): + class ChaosType(StrEnum): + MIGRATE_SHARD = "migrate_shard" + RESTART_IMMEDIATE = "restart_immediate" + RESTART = "restart" + STORCON_RESTART_IMMEDIATE = "storcon_restart_immediate" + + while not stop_chaos.is_set(): + chaos_type = random.choices( + population=[ + ChaosType.MIGRATE_SHARD, + ChaosType.RESTART, + ChaosType.RESTART_IMMEDIATE, + ChaosType.STORCON_RESTART_IMMEDIATE, + ], + weights=[0.25, 0.25, 0.25, 0.25], + k=1, + )[0] + + try: + if chaos_type == ChaosType.MIGRATE_SHARD: + target_shard_number = random.randint(0, SHARD_COUNT - 1) + target_shard = TenantShardId(tenant_id, target_shard_number, SHARD_COUNT) + + placements = env.storage_controller.get_tenants_placement() + log.info(f"{placements=}") + target_ps = placements[str(target_shard)]["intent"]["attached"] + if len(placements[str(target_shard)]["intent"]["secondary"]) == 0: + dest_ps = None + else: + dest_ps = placements[str(target_shard)]["intent"]["secondary"][0] + + if target_ps is None or dest_ps is None: + continue + + config = StorageControllerMigrationConfig( + secondary_warmup_timeout="0s", + secondary_download_request_timeout="0s", + prewarm=False, + ) + env.storage_controller.tenant_shard_migrate(target_shard, dest_ps, config) + + log.info( + f"CHAOS: Migrating shard {target_shard} from pageserver {target_ps} to {dest_ps}" + ) + elif chaos_type == ChaosType.RESTART_IMMEDIATE: + target_ps = random.choice(env.pageservers) + log.info(f"CHAOS: Immediate restart of pageserver {target_ps.id}") + target_ps.stop(immediate=True) + target_ps.start() + elif chaos_type == ChaosType.RESTART: + target_ps = random.choice(env.pageservers) + log.info(f"CHAOS: Normal restart of pageserver {target_ps.id}") + target_ps.stop(immediate=False) + target_ps.start() + elif chaos_type == ChaosType.STORCON_RESTART_IMMEDIATE: + log.info("CHAOS: Immediate restart of storage controller") + env.storage_controller.stop(immediate=True) + env.storage_controller.start() + except Exception as e: + log.warning(f"CHAOS: Error during chaos operation {chaos_type}: {e}") + + # Sleep before next chaos event + time.sleep(1) + + log.info("Chaos injector stopped") + + def wait_for_import_completion(): + start = time.time() + done = import_completion_signaled.wait(ALLOWED_IMPORT_RUNTIME) + if not done: + raise TimeoutError(f"Import did not signal completion within {ALLOWED_IMPORT_RUNTIME}") + + end = time.time() + + log.info(f"Import completion signalled after {end - start}s {import_error=}") + + if import_error: + raise RuntimeError(f"Import error: {import_error}") + + with concurrent.futures.ThreadPoolExecutor() as executor: + stop_chaos = threading.Event() + + wait_for_import_completion_fut = executor.submit(wait_for_import_completion) + chaos_fut = executor.submit(chaos, stop_chaos) + + try: + wait_for_import_completion_fut.result() + except Exception as e: + raise e + finally: + stop_chaos.set() + chaos_fut.result() + + import_branch_name = "imported" + env.neon_cli.mappings_map_branch(import_branch_name, tenant_id, timeline_id) + endpoint = env.endpoints.create_start(branch_name=import_branch_name, tenant_id=tenant_id) + + # Validate the imported data is legit + assert endpoint.safe_psql_many( + [ + "set effective_io_concurrency=32;", + "SET statement_timeout='300s';", + "select count(*), sum(data::bigint)::bigint from t", + ] + ) == [[], [], [(nrows, nrows * (nrows + 1) // 2)]] + + endpoint.stop() + + # Validate writes + workload = Workload(env, tenant_id, timeline_id, branch_name=import_branch_name) + workload.init() + workload.write_rows(64) + workload.validate() + + def test_fast_import_with_pageserver_ingest( test_output_dir, vanilla_pg: VanillaPostgres, From f3976e5c6017752ac9d9502fa1f3c6ee4017b280 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Arpad=20M=C3=BCller?= Date: Tue, 27 May 2025 13:32:15 +0200 Subject: [PATCH 12/48] remove safekeeper_proto_version = 3 from tests (#12020) Some tests still explicitly specify version 3 of the safekeeper walproposer protocol. Remove the explicit opt in from the tests as v3 is the default now since #11518. We don't touch the places where a test exercises both v2 and v3. Those we leave for #12021. Part of https://github.com/neondatabase/neon/issues/10326 --- test_runner/regress/test_storage_controller.py | 18 ++++-------------- test_runner/regress/test_wal_acceptor.py | 10 ++-------- test_runner/regress/test_wal_acceptor_async.py | 5 +---- 3 files changed, 7 insertions(+), 26 deletions(-) diff --git a/test_runner/regress/test_storage_controller.py b/test_runner/regress/test_storage_controller.py index af018f7b5d..d07fb38c5a 100644 --- a/test_runner/regress/test_storage_controller.py +++ b/test_runner/regress/test_storage_controller.py @@ -4158,17 +4158,12 @@ def test_storcon_create_delete_sk_down( env.storage_controller.stop() env.storage_controller.start() - config_lines = [ - "neon.safekeeper_proto_version = 3", - ] - with env.endpoints.create("main", tenant_id=tenant_id, config_lines=config_lines) as ep: + with env.endpoints.create("main", tenant_id=tenant_id) as ep: # endpoint should start. ep.start(safekeeper_generation=1, safekeepers=[1, 2, 3]) ep.safe_psql("CREATE TABLE IF NOT EXISTS t(key int, value text)") - with env.endpoints.create( - "child_of_main", tenant_id=tenant_id, config_lines=config_lines - ) as ep: + with env.endpoints.create("child_of_main", tenant_id=tenant_id) as ep: # endpoint should start. ep.start(safekeeper_generation=1, safekeepers=[1, 2, 3]) ep.safe_psql("CREATE TABLE IF NOT EXISTS t(key int, value text)") @@ -4249,17 +4244,12 @@ def test_storcon_few_sk( env.safekeepers[0].assert_log_contains(f"creating new timeline {tenant_id}/{timeline_id}") - config_lines = [ - "neon.safekeeper_proto_version = 3", - ] - with env.endpoints.create("main", tenant_id=tenant_id, config_lines=config_lines) as ep: + with env.endpoints.create("main", tenant_id=tenant_id) as ep: # endpoint should start. ep.start(safekeeper_generation=1, safekeepers=safekeeper_list) ep.safe_psql("CREATE TABLE IF NOT EXISTS t(key int, value text)") - with env.endpoints.create( - "child_of_main", tenant_id=tenant_id, config_lines=config_lines - ) as ep: + with env.endpoints.create("child_of_main", tenant_id=tenant_id) as ep: # endpoint should start. ep.start(safekeeper_generation=1, safekeepers=safekeeper_list) ep.safe_psql("CREATE TABLE IF NOT EXISTS t(key int, value text)") diff --git a/test_runner/regress/test_wal_acceptor.py b/test_runner/regress/test_wal_acceptor.py index a9a6699e5c..6a7c7a8bef 100644 --- a/test_runner/regress/test_wal_acceptor.py +++ b/test_runner/regress/test_wal_acceptor.py @@ -2012,10 +2012,7 @@ def test_explicit_timeline_creation(neon_env_builder: NeonEnvBuilder): tenant_id = env.initial_tenant timeline_id = env.initial_timeline - config_lines = [ - "neon.safekeeper_proto_version = 3", - ] - ep = env.endpoints.create("main", config_lines=config_lines) + ep = env.endpoints.create("main") # expected to fail because timeline is not created on safekeepers with pytest.raises(Exception, match=r".*timed out.*"): @@ -2043,10 +2040,7 @@ def test_explicit_timeline_creation_storcon(neon_env_builder: NeonEnvBuilder): } env = neon_env_builder.init_start() - config_lines = [ - "neon.safekeeper_proto_version = 3", - ] - ep = env.endpoints.create("main", config_lines=config_lines) + ep = env.endpoints.create("main") # endpoint should start. ep.start(safekeeper_generation=1, safekeepers=[1, 2, 3]) diff --git a/test_runner/regress/test_wal_acceptor_async.py b/test_runner/regress/test_wal_acceptor_async.py index c5dd34f64f..4070f99568 100644 --- a/test_runner/regress/test_wal_acceptor_async.py +++ b/test_runner/regress/test_wal_acceptor_async.py @@ -637,10 +637,7 @@ async def quorum_sanity_single( # create timeline on `members_sks` Safekeeper.create_timeline(tenant_id, timeline_id, env.pageservers[0], mconf, members_sks) - config_lines = [ - "neon.safekeeper_proto_version = 3", - ] - ep = env.endpoints.create(branch_name, config_lines=config_lines) + ep = env.endpoints.create(branch_name) ep.start(safekeeper_generation=1, safekeepers=compute_sks_ids) ep.safe_psql("create table t(key int, value text)") From 5d538a950334e49545a1b0eff9fed032c380f2e6 Mon Sep 17 00:00:00 2001 From: Erik Grinaker Date: Tue, 27 May 2025 14:06:51 +0200 Subject: [PATCH 13/48] page_api: tweak errors (#12019) ## Problem The page API gRPC errors need a few tweaks to integrate better with the GetPage machinery. Touches https://github.com/neondatabase/neon/issues/11728. ## Summary of changes * Add `GetPageStatus::InternalError` for internal server errors. * Rename `GetPageStatus::Invalid` to `InvalidRequest` for clarity. * Rename `status` and `GetPageStatus` to `status_code` and `GetPageStatusCode`. * Add an `Into` implementation for `ProtocolError`. --- pageserver/page_api/proto/page_service.proto | 46 +++++++------ pageserver/page_api/src/model.rs | 72 ++++++++++++-------- 2 files changed, 69 insertions(+), 49 deletions(-) diff --git a/pageserver/page_api/proto/page_service.proto b/pageserver/page_api/proto/page_service.proto index f6acb3eeeb..44976084bf 100644 --- a/pageserver/page_api/proto/page_service.proto +++ b/pageserver/page_api/proto/page_service.proto @@ -54,9 +54,9 @@ service PageService { // RPCs use regular unary requests, since they are not as frequent and // performance-critical, and this simplifies implementation. // - // NB: a status response (e.g. errors) will terminate the stream. The stream - // may be shared by e.g. multiple Postgres backends, so we should avoid this. - // Most errors are therefore sent as GetPageResponse.status instead. + // NB: a gRPC status response (e.g. errors) will terminate the stream. The + // stream may be shared by multiple Postgres backends, so we avoid this by + // sending them as GetPageResponse.status_code instead. rpc GetPages (stream GetPageRequest) returns (stream GetPageResponse); // Returns the size of a relation, as # of blocks. @@ -159,8 +159,8 @@ message GetPageRequest { // A GetPageRequest class. Primarily intended for observability, but may also be // used for prioritization in the future. enum GetPageClass { - // Unknown class. For forwards compatibility: used when the client sends a - // class that the server doesn't know about. + // Unknown class. For backwards compatibility: used when an older client version sends a class + // that a newer server version has removed. GET_PAGE_CLASS_UNKNOWN = 0; // A normal request. This is the default. GET_PAGE_CLASS_NORMAL = 1; @@ -180,31 +180,37 @@ message GetPageResponse { // The original request's ID. uint64 request_id = 1; // The response status code. - GetPageStatus status = 2; + GetPageStatusCode status_code = 2; // A string describing the status, if any. string reason = 3; - // The 8KB page images, in the same order as the request. Empty if status != OK. + // The 8KB page images, in the same order as the request. Empty if status_code != OK. repeated bytes page_image = 4; } -// A GetPageResponse status code. Since we use a bidirectional stream, we don't -// want to send errors as gRPC statuses, since this would terminate the stream. -enum GetPageStatus { - // Unknown status. For forwards compatibility: used when the server sends a - // status code that the client doesn't know about. - GET_PAGE_STATUS_UNKNOWN = 0; +// A GetPageResponse status code. +// +// These are effectively equivalent to gRPC statuses. However, we use a bidirectional stream +// (potentially shared by many backends), and a gRPC status response would terminate the stream so +// we send GetPageResponse messages with these codes instead. +enum GetPageStatusCode { + // Unknown status. For forwards compatibility: used when an older client version receives a new + // status code from a newer server version. + GET_PAGE_STATUS_CODE_UNKNOWN = 0; // The request was successful. - GET_PAGE_STATUS_OK = 1; + GET_PAGE_STATUS_CODE_OK = 1; // The page did not exist. The tenant/timeline/shard has already been // validated during stream setup. - GET_PAGE_STATUS_NOT_FOUND = 2; + GET_PAGE_STATUS_CODE_NOT_FOUND = 2; // The request was invalid. - GET_PAGE_STATUS_INVALID = 3; + GET_PAGE_STATUS_CODE_INVALID_REQUEST = 3; + // The request failed due to an internal server error. + GET_PAGE_STATUS_CODE_INTERNAL_ERROR = 4; // The tenant is rate limited. Slow down and retry later. - GET_PAGE_STATUS_SLOW_DOWN = 4; - // TODO: consider adding a GET_PAGE_STATUS_LAYER_DOWNLOAD in the case of a - // layer download. This could free up the server task to process other - // requests while the layer download is in progress. + GET_PAGE_STATUS_CODE_SLOW_DOWN = 5; + // NB: shutdown errors are emitted as a gRPC Unavailable status. + // + // TODO: consider adding a GET_PAGE_STATUS_CODE_LAYER_DOWNLOAD in the case of a layer download. + // This could free up the server task to process other requests while the download is in progress. } // Fetches the size of a relation at a given LSN, as # of blocks. Only valid on diff --git a/pageserver/page_api/src/model.rs b/pageserver/page_api/src/model.rs index a83d0a5947..7ab97a994e 100644 --- a/pageserver/page_api/src/model.rs +++ b/pageserver/page_api/src/model.rs @@ -35,6 +35,12 @@ impl ProtocolError { } } +impl From for tonic::Status { + fn from(err: ProtocolError) -> Self { + tonic::Status::invalid_argument(format!("{err}")) + } +} + /// The LSN a request should read at. #[derive(Clone, Copy, Debug)] pub struct ReadLsn { @@ -328,7 +334,7 @@ pub type RequestID = u64; /// A GetPage request class. #[derive(Clone, Copy, Debug)] pub enum GetPageClass { - /// Unknown status. For backwards compatibility: used when an older client version sends a class + /// Unknown class. For backwards compatibility: used when an older client version sends a class /// that a newer server version has removed. Unknown, /// A normal request. This is the default. @@ -386,7 +392,7 @@ pub struct GetPageResponse { /// The original request's ID. pub request_id: RequestID, /// The response status code. - pub status: GetPageStatus, + pub status_code: GetPageStatusCode, /// A string describing the status, if any. pub reason: Option, /// The 8KB page images, in the same order as the request. Empty if status != OK. @@ -397,7 +403,7 @@ impl From for GetPageResponse { fn from(pb: proto::GetPageResponse) -> Self { Self { request_id: pb.request_id, - status: pb.status.into(), + status_code: pb.status_code.into(), reason: Some(pb.reason).filter(|r| !r.is_empty()), page_images: pb.page_image.into(), } @@ -408,16 +414,20 @@ impl From for proto::GetPageResponse { fn from(response: GetPageResponse) -> Self { Self { request_id: response.request_id, - status: response.status.into(), + status_code: response.status_code.into(), reason: response.reason.unwrap_or_default(), page_image: response.page_images.into_vec(), } } } -/// A GetPage response status. +/// A GetPage response status code. +/// +/// These are effectively equivalent to gRPC statuses. However, we use a bidirectional stream +/// (potentially shared by many backends), and a gRPC status response would terminate the stream so +/// we send GetPageResponse messages with these codes instead. #[derive(Clone, Copy, Debug)] -pub enum GetPageStatus { +pub enum GetPageStatusCode { /// Unknown status. For forwards compatibility: used when an older client version receives a new /// status code from a newer server version. Unknown, @@ -427,46 +437,50 @@ pub enum GetPageStatus { /// setup. NotFound, /// The request was invalid. - Invalid, + InvalidRequest, + /// The request failed due to an internal server error. + InternalError, /// The tenant is rate limited. Slow down and retry later. SlowDown, } -impl From for GetPageStatus { - fn from(pb: proto::GetPageStatus) -> Self { +impl From for GetPageStatusCode { + fn from(pb: proto::GetPageStatusCode) -> Self { match pb { - proto::GetPageStatus::Unknown => Self::Unknown, - proto::GetPageStatus::Ok => Self::Ok, - proto::GetPageStatus::NotFound => Self::NotFound, - proto::GetPageStatus::Invalid => Self::Invalid, - proto::GetPageStatus::SlowDown => Self::SlowDown, + proto::GetPageStatusCode::Unknown => Self::Unknown, + proto::GetPageStatusCode::Ok => Self::Ok, + proto::GetPageStatusCode::NotFound => Self::NotFound, + proto::GetPageStatusCode::InvalidRequest => Self::InvalidRequest, + proto::GetPageStatusCode::InternalError => Self::InternalError, + proto::GetPageStatusCode::SlowDown => Self::SlowDown, } } } -impl From for GetPageStatus { - fn from(status: i32) -> Self { - proto::GetPageStatus::try_from(status) - .unwrap_or(proto::GetPageStatus::Unknown) +impl From for GetPageStatusCode { + fn from(status_code: i32) -> Self { + proto::GetPageStatusCode::try_from(status_code) + .unwrap_or(proto::GetPageStatusCode::Unknown) .into() } } -impl From for proto::GetPageStatus { - fn from(status: GetPageStatus) -> Self { - match status { - GetPageStatus::Unknown => Self::Unknown, - GetPageStatus::Ok => Self::Ok, - GetPageStatus::NotFound => Self::NotFound, - GetPageStatus::Invalid => Self::Invalid, - GetPageStatus::SlowDown => Self::SlowDown, +impl From for proto::GetPageStatusCode { + fn from(status_code: GetPageStatusCode) -> Self { + match status_code { + GetPageStatusCode::Unknown => Self::Unknown, + GetPageStatusCode::Ok => Self::Ok, + GetPageStatusCode::NotFound => Self::NotFound, + GetPageStatusCode::InvalidRequest => Self::InvalidRequest, + GetPageStatusCode::InternalError => Self::InternalError, + GetPageStatusCode::SlowDown => Self::SlowDown, } } } -impl From for i32 { - fn from(status: GetPageStatus) -> Self { - proto::GetPageStatus::from(status).into() +impl From for i32 { + fn from(status_code: GetPageStatusCode) -> Self { + proto::GetPageStatusCode::from(status_code).into() } } From 30adf8e2bd8ccabd90764158effb9145590e1fd6 Mon Sep 17 00:00:00 2001 From: Vlad Lazar Date: Tue, 27 May 2025 14:57:53 +0100 Subject: [PATCH 14/48] pageserver: add tracing spans for time spent in batch and flushing (#12012) ## Problem We have some gaps in our traces. This indicates missing spans. ## Summary of changes This PR adds two new spans: * WAIT_EXECUTOR: time a batched request spends in the batch waiting to be picked up * FLUSH_RESPONSE: time a get page request spends flushing the response to the compute ![image](https://github.com/user-attachments/assets/41b3ddb8-438d-4375-9da3-da341fc0916a) --- pageserver/src/page_service.rs | 97 ++++++++++++++++++++++++++++------ 1 file changed, 82 insertions(+), 15 deletions(-) diff --git a/pageserver/src/page_service.rs b/pageserver/src/page_service.rs index 06aa207f82..e96787e027 100644 --- a/pageserver/src/page_service.rs +++ b/pageserver/src/page_service.rs @@ -769,6 +769,9 @@ struct BatchedGetPageRequest { timer: SmgrOpTimer, lsn_range: LsnRange, ctx: RequestContext, + // If the request is perf enabled, this contains a context + // with a perf span tracking the time spent waiting for the executor. + batch_wait_ctx: Option, } #[cfg(feature = "testing")] @@ -781,6 +784,7 @@ struct BatchedTestRequest { /// so that we don't keep the [`Timeline::gate`] open while the batch /// is being built up inside the [`spsc_fold`] (pagestream pipelining). #[derive(IntoStaticStr)] +#[allow(clippy::large_enum_variant)] enum BatchedFeMessage { Exists { span: Span, @@ -1298,6 +1302,22 @@ impl PageServerHandler { } }; + let batch_wait_ctx = if ctx.has_perf_span() { + Some( + RequestContextBuilder::from(&ctx) + .perf_span(|crnt_perf_span| { + info_span!( + target: PERF_TRACE_TARGET, + parent: crnt_perf_span, + "WAIT_EXECUTOR", + ) + }) + .attached_child(), + ) + } else { + None + }; + BatchedFeMessage::GetPage { span, shard: shard.downgrade(), @@ -1309,6 +1329,7 @@ impl PageServerHandler { request_lsn: req.hdr.request_lsn }, ctx, + batch_wait_ctx, }], // The executor grabs the batch when it becomes idle. // Hence, [`GetPageBatchBreakReason::ExecutorSteal`] is the @@ -1464,7 +1485,7 @@ impl PageServerHandler { let mut flush_timers = Vec::with_capacity(handler_results.len()); for handler_result in &mut handler_results { let flush_timer = match handler_result { - Ok((_, timer)) => Some( + Ok((_response, timer, _ctx)) => Some( timer .observe_execution_end(flushing_start_time) .expect("we are the first caller"), @@ -1484,7 +1505,7 @@ impl PageServerHandler { // Some handler errors cause exit from pagestream protocol. // Other handler errors are sent back as an error message and we stay in pagestream protocol. for (handler_result, flushing_timer) in handler_results.into_iter().zip(flush_timers) { - let response_msg = match handler_result { + let (response_msg, ctx) = match handler_result { Err(e) => match &e.err { PageStreamError::Shutdown => { // If we fail to fulfil a request during shutdown, which may be _because_ of @@ -1509,15 +1530,30 @@ impl PageServerHandler { error!("error reading relation or page version: {full:#}") }); - PagestreamBeMessage::Error(PagestreamErrorResponse { - req: e.req, - message: e.err.to_string(), - }) + ( + PagestreamBeMessage::Error(PagestreamErrorResponse { + req: e.req, + message: e.err.to_string(), + }), + None, + ) } }, - Ok((response_msg, _op_timer_already_observed)) => response_msg, + Ok((response_msg, _op_timer_already_observed, ctx)) => (response_msg, Some(ctx)), }; + let ctx = ctx.map(|req_ctx| { + RequestContextBuilder::from(&req_ctx) + .perf_span(|crnt_perf_span| { + info_span!( + target: PERF_TRACE_TARGET, + parent: crnt_perf_span, + "FLUSH_RESPONSE", + ) + }) + .attached_child() + }); + // // marshal & transmit response message // @@ -1540,6 +1576,17 @@ impl PageServerHandler { )), None => futures::future::Either::Right(flush_fut), }; + + let flush_fut = if let Some(req_ctx) = ctx.as_ref() { + futures::future::Either::Left( + flush_fut.maybe_perf_instrument(req_ctx, |current_perf_span| { + current_perf_span.clone() + }), + ) + } else { + futures::future::Either::Right(flush_fut) + }; + // do it while respecting cancellation let _: () = async move { tokio::select! { @@ -1569,7 +1616,7 @@ impl PageServerHandler { ctx: &RequestContext, ) -> Result< ( - Vec>, + Vec>, Span, ), QueryError, @@ -1596,7 +1643,7 @@ impl PageServerHandler { self.handle_get_rel_exists_request(&shard, &req, &ctx) .instrument(span.clone()) .await - .map(|msg| (msg, timer)) + .map(|msg| (msg, timer, ctx)) .map_err(|err| BatchedPageStreamError { err, req: req.hdr }), ], span, @@ -1615,7 +1662,7 @@ impl PageServerHandler { self.handle_get_nblocks_request(&shard, &req, &ctx) .instrument(span.clone()) .await - .map(|msg| (msg, timer)) + .map(|msg| (msg, timer, ctx)) .map_err(|err| BatchedPageStreamError { err, req: req.hdr }), ], span, @@ -1662,7 +1709,7 @@ impl PageServerHandler { self.handle_db_size_request(&shard, &req, &ctx) .instrument(span.clone()) .await - .map(|msg| (msg, timer)) + .map(|msg| (msg, timer, ctx)) .map_err(|err| BatchedPageStreamError { err, req: req.hdr }), ], span, @@ -1681,7 +1728,7 @@ impl PageServerHandler { self.handle_get_slru_segment_request(&shard, &req, &ctx) .instrument(span.clone()) .await - .map(|msg| (msg, timer)) + .map(|msg| (msg, timer, ctx)) .map_err(|err| BatchedPageStreamError { err, req: req.hdr }), ], span, @@ -2033,12 +2080,25 @@ impl PageServerHandler { return Ok(()); } }; - let batch = match batch { + let mut batch = match batch { Ok(batch) => batch, Err(e) => { return Err(e); } }; + + if let BatchedFeMessage::GetPage { + pages, + span: _, + shard: _, + batch_break_reason: _, + } = &mut batch + { + for req in pages { + req.batch_wait_ctx.take(); + } + } + self.pagestream_handle_batched_message( pgb_writer, batch, @@ -2351,7 +2411,8 @@ impl PageServerHandler { io_concurrency: IoConcurrency, batch_break_reason: GetPageBatchBreakReason, ctx: &RequestContext, - ) -> Vec> { + ) -> Vec> + { debug_assert_current_span_has_tenant_and_timeline_id(); timeline @@ -2458,6 +2519,7 @@ impl PageServerHandler { page, }), req.timer, + req.ctx, ) }) .map_err(|e| BatchedPageStreamError { @@ -2502,7 +2564,8 @@ impl PageServerHandler { timeline: &Timeline, requests: Vec, _ctx: &RequestContext, - ) -> Vec> { + ) -> Vec> + { // real requests would do something with the timeline let mut results = Vec::with_capacity(requests.len()); for _req in requests.iter() { @@ -2529,6 +2592,10 @@ impl PageServerHandler { req: req.req.clone(), }), req.timer, + RequestContext::new( + TaskKind::PageRequestHandler, + DownloadBehavior::Warn, + ), ) }) .map_err(|e| BatchedPageStreamError { From f0bb93a9c9322d78deb21c6b4ee2dff59eda26b2 Mon Sep 17 00:00:00 2001 From: "Alex Chi Z." <4198311+skyzh@users.noreply.github.com> Date: Tue, 27 May 2025 22:29:15 +0800 Subject: [PATCH 15/48] feat(pageserver): support evaluate boolean flags (#12024) ## Problem Part of https://github.com/neondatabase/neon/issues/11813 ## Summary of changes * Support evaluate boolean flags. * Add docs on how to handle errors. * Add test cases based on real PostHog config. Signed-off-by: Alex Chi Z --- libs/posthog_client_lite/src/lib.rs | 466 ++++++++++++++++++++++------ pageserver/src/feature_resolver.rs | 29 ++ 2 files changed, 404 insertions(+), 91 deletions(-) diff --git a/libs/posthog_client_lite/src/lib.rs b/libs/posthog_client_lite/src/lib.rs index 21e978df3e..8aa8da2898 100644 --- a/libs/posthog_client_lite/src/lib.rs +++ b/libs/posthog_client_lite/src/lib.rs @@ -37,7 +37,7 @@ pub struct LocalEvaluationFlag { #[derive(Deserialize)] pub struct LocalEvaluationFlagFilters { groups: Vec, - multivariate: LocalEvaluationFlagMultivariate, + multivariate: Option, } #[derive(Deserialize)] @@ -254,7 +254,7 @@ impl FeatureStore { } } - /// Evaluate a multivariate feature flag. Returns `None` if the flag is not available or if there are errors + /// Evaluate a multivariate feature flag. Returns an error if the flag is not available or if there are errors /// during the evaluation. /// /// The parsing logic is as follows: @@ -272,6 +272,10 @@ impl FeatureStore { /// Example: we have a multivariate flag with 3 groups of the configured global rollout percentage: A (10%), B (20%), C (70%). /// There is a single group with a condition that has a rollout percentage of 10% and it does not have a variant override. /// Then, we will have 1% of the users evaluated to A, 2% to B, and 7% to C. + /// + /// Error handling: the caller should inspect the error and decide the behavior when a feature flag + /// cannot be evaluated (i.e., default to false if it cannot be resolved). The error should *not* be + /// propagated beyond where the feature flag gets resolved. pub fn evaluate_multivariate( &self, flag_key: &str, @@ -290,6 +294,35 @@ impl FeatureStore { ) } + /// Evaluate a boolean feature flag. Returns an error if the flag is not available or if there are errors + /// during the evaluation. + /// + /// The parsing logic is as follows: + /// + /// * Generate a consistent hash for the tenant-feature. + /// * Match each filter group. + /// - If a group is matched, it will first determine whether the user is in the range of the rollout + /// percentage. + /// - If the hash falls within the group's rollout percentage, return true. + /// * Otherwise, continue with the next group until all groups are evaluated and no group is within the + /// rollout percentage. + /// * If there are no matching groups, return an error. + /// + /// Returns `Ok(())` if the feature flag evaluates to true. In the future, it will return a payload. + /// + /// Error handling: the caller should inspect the error and decide the behavior when a feature flag + /// cannot be evaluated (i.e., default to false if it cannot be resolved). The error should *not* be + /// propagated beyond where the feature flag gets resolved. + pub fn evaluate_boolean( + &self, + flag_key: &str, + user_id: &str, + properties: &HashMap, + ) -> Result<(), PostHogEvaluationError> { + let hash_on_global_rollout_percentage = Self::consistent_hash(user_id, flag_key, "boolean"); + self.evaluate_boolean_inner(flag_key, hash_on_global_rollout_percentage, properties) + } + /// Evaluate a multivariate feature flag. Note that we directly take the mapped user ID /// (a consistent hash ranging from 0 to 1) so that it is easier to use it in the tests /// and avoid duplicate computations. @@ -316,6 +349,11 @@ impl FeatureStore { flag_key ))); } + let Some(ref multivariate) = flag_config.filters.multivariate else { + return Err(PostHogEvaluationError::Internal(format!( + "No multivariate available, should use evaluate_boolean?: {flag_key}" + ))); + }; // TODO: sort the groups so that variant overrides always get evaluated first and it follows the PostHog // Python SDK behavior; for now we do not configure conditions without variant overrides in Neon so it // does not matter. @@ -324,7 +362,7 @@ impl FeatureStore { GroupEvaluationResult::MatchedAndOverride(variant) => return Ok(variant), GroupEvaluationResult::MatchedAndEvaluate => { let mut percentage = 0; - for variant in &flag_config.filters.multivariate.variants { + for variant in &multivariate.variants { percentage += variant.rollout_percentage; if self .evaluate_percentage(hash_on_global_rollout_percentage, percentage) @@ -352,6 +390,64 @@ impl FeatureStore { ))) } } + + /// Evaluate a multivariate feature flag. Note that we directly take the mapped user ID + /// (a consistent hash ranging from 0 to 1) so that it is easier to use it in the tests + /// and avoid duplicate computations. + /// + /// Use a different consistent hash for evaluating the group rollout percentage. + /// The behavior: if the condition is set to rolling out to 10% of the users, and + /// we set the variant A to 20% in the global config, then 2% of the total users will + /// be evaluated to variant A. + /// + /// Note that the hash to determine group rollout percentage is shared across all groups. So if we have two + /// exactly-the-same conditions with 10% and 20% rollout percentage respectively, a total of 20% of the users + /// will be evaluated (versus 30% if group evaluation is done independently). + pub(crate) fn evaluate_boolean_inner( + &self, + flag_key: &str, + hash_on_global_rollout_percentage: f64, + properties: &HashMap, + ) -> Result<(), PostHogEvaluationError> { + if let Some(flag_config) = self.flags.get(flag_key) { + if !flag_config.active { + return Err(PostHogEvaluationError::NotAvailable(format!( + "The feature flag is not active: {}", + flag_key + ))); + } + if flag_config.filters.multivariate.is_some() { + return Err(PostHogEvaluationError::Internal(format!( + "This looks like a multivariate flag, should use evaluate_multivariate?: {flag_key}" + ))); + }; + // TODO: sort the groups so that variant overrides always get evaluated first and it follows the PostHog + // Python SDK behavior; for now we do not configure conditions without variant overrides in Neon so it + // does not matter. + for group in &flag_config.filters.groups { + match self.evaluate_group(group, hash_on_global_rollout_percentage, properties)? { + GroupEvaluationResult::MatchedAndOverride(_) => { + return Err(PostHogEvaluationError::Internal(format!( + "Boolean flag cannot have overrides: {}", + flag_key + ))); + } + GroupEvaluationResult::MatchedAndEvaluate => { + return Ok(()); + } + GroupEvaluationResult::Unmatched => continue, + } + } + // If no group is matched, the feature is not available, and up to the caller to decide what to do. + Err(PostHogEvaluationError::NoConditionGroupMatched) + } else { + // The feature flag is not available yet + Err(PostHogEvaluationError::NotAvailable(format!( + "Not found in the local evaluation spec: {}", + flag_key + ))) + } + } } pub struct PostHogClientConfig { @@ -469,95 +565,162 @@ mod tests { fn data() -> &'static str { r#"{ - "flags": [ - { - "id": 132794, - "team_id": 152860, - "name": "", - "key": "gc-compaction", - "filters": { - "groups": [ - { - "variant": "enabled-stage-2", - "properties": [ - { - "key": "plan_type", - "type": "person", - "value": [ - "free" - ], - "operator": "exact" - }, - { - "key": "pageserver_remote_size", - "type": "person", - "value": "10000000", - "operator": "lt" - } - ], - "rollout_percentage": 50 - }, - { - "properties": [ - { - "key": "plan_type", - "type": "person", - "value": [ - "free" - ], - "operator": "exact" - }, - { - "key": "pageserver_remote_size", - "type": "person", - "value": "10000000", - "operator": "lt" - } - ], - "rollout_percentage": 80 - } - ], - "payloads": {}, - "multivariate": { - "variants": [ - { - "key": "disabled", - "name": "", - "rollout_percentage": 90 - }, - { - "key": "enabled-stage-1", - "name": "", - "rollout_percentage": 10 - }, - { - "key": "enabled-stage-2", - "name": "", - "rollout_percentage": 0 - }, - { - "key": "enabled-stage-3", - "name": "", - "rollout_percentage": 0 - }, - { - "key": "enabled", - "name": "", - "rollout_percentage": 0 - } - ] - } - }, - "deleted": false, - "active": true, - "ensure_experience_continuity": false, - "has_encrypted_payloads": false, - "version": 6 - } + "flags": [ + { + "id": 141807, + "team_id": 152860, + "name": "", + "key": "image-compaction-boundary", + "filters": { + "groups": [ + { + "variant": null, + "properties": [ + { + "key": "plan_type", + "type": "person", + "value": [ + "free" + ], + "operator": "exact" + } ], - "group_type_mapping": {}, - "cohorts": {} - }"# + "rollout_percentage": 40 + }, + { + "variant": null, + "properties": [], + "rollout_percentage": 10 + } + ], + "payloads": {}, + "multivariate": null + }, + "deleted": false, + "active": true, + "ensure_experience_continuity": false, + "has_encrypted_payloads": false, + "version": 1 + }, + { + "id": 135586, + "team_id": 152860, + "name": "", + "key": "boolean-flag", + "filters": { + "groups": [ + { + "variant": null, + "properties": [ + { + "key": "plan_type", + "type": "person", + "value": [ + "free" + ], + "operator": "exact" + } + ], + "rollout_percentage": 47 + } + ], + "payloads": {}, + "multivariate": null + }, + "deleted": false, + "active": true, + "ensure_experience_continuity": false, + "has_encrypted_payloads": false, + "version": 1 + }, + { + "id": 132794, + "team_id": 152860, + "name": "", + "key": "gc-compaction", + "filters": { + "groups": [ + { + "variant": "enabled-stage-2", + "properties": [ + { + "key": "plan_type", + "type": "person", + "value": [ + "free" + ], + "operator": "exact" + }, + { + "key": "pageserver_remote_size", + "type": "person", + "value": "10000000", + "operator": "lt" + } + ], + "rollout_percentage": 50 + }, + { + "properties": [ + { + "key": "plan_type", + "type": "person", + "value": [ + "free" + ], + "operator": "exact" + }, + { + "key": "pageserver_remote_size", + "type": "person", + "value": "10000000", + "operator": "lt" + } + ], + "rollout_percentage": 80 + } + ], + "payloads": {}, + "multivariate": { + "variants": [ + { + "key": "disabled", + "name": "", + "rollout_percentage": 90 + }, + { + "key": "enabled-stage-1", + "name": "", + "rollout_percentage": 10 + }, + { + "key": "enabled-stage-2", + "name": "", + "rollout_percentage": 0 + }, + { + "key": "enabled-stage-3", + "name": "", + "rollout_percentage": 0 + }, + { + "key": "enabled", + "name": "", + "rollout_percentage": 0 + } + ] + } + }, + "deleted": false, + "active": true, + "ensure_experience_continuity": false, + "has_encrypted_payloads": false, + "version": 7 + } + ], + "group_type_mapping": {}, + "cohorts": {} +}"# } #[test] @@ -633,4 +796,125 @@ mod tests { Err(PostHogEvaluationError::NoConditionGroupMatched) ),); } + + #[test] + fn evaluate_boolean_1() { + // The `boolean-flag` feature flag only has one group that matches on the free user. + + let mut store = FeatureStore::new(); + let response: LocalEvaluationResponse = serde_json::from_str(data()).unwrap(); + store.set_flags(response.flags); + + // This lacks the required properties and cannot be evaluated. + let variant = store.evaluate_boolean_inner("boolean-flag", 1.00, &HashMap::new()); + assert!(matches!( + variant, + Err(PostHogEvaluationError::NotAvailable(_)) + ),); + + let properties_unmatched = HashMap::from([ + ( + "plan_type".to_string(), + PostHogFlagFilterPropertyValue::String("paid".to_string()), + ), + ( + "pageserver_remote_size".to_string(), + PostHogFlagFilterPropertyValue::Number(1000.0), + ), + ]); + + // This does not match any group so there will be an error. + let variant = store.evaluate_boolean_inner("boolean-flag", 1.00, &properties_unmatched); + assert!(matches!( + variant, + Err(PostHogEvaluationError::NoConditionGroupMatched) + ),); + + let properties = HashMap::from([ + ( + "plan_type".to_string(), + PostHogFlagFilterPropertyValue::String("free".to_string()), + ), + ( + "pageserver_remote_size".to_string(), + PostHogFlagFilterPropertyValue::Number(1000.0), + ), + ]); + + // It matches the first group as 0.10 <= 0.50 and the properties are matched. Then it gets evaluated to the variant override. + let variant = store.evaluate_boolean_inner("boolean-flag", 0.10, &properties); + assert!(variant.is_ok()); + + // It matches the group conditions but not the group rollout percentage. + let variant = store.evaluate_boolean_inner("boolean-flag", 1.00, &properties); + assert!(matches!( + variant, + Err(PostHogEvaluationError::NoConditionGroupMatched) + ),); + } + + #[test] + fn evaluate_boolean_2() { + // The `image-compaction-boundary` feature flag has one group that matches on the free user and a group that matches on all users. + + let mut store = FeatureStore::new(); + let response: LocalEvaluationResponse = serde_json::from_str(data()).unwrap(); + store.set_flags(response.flags); + + // This lacks the required properties and cannot be evaluated. + let variant = + store.evaluate_boolean_inner("image-compaction-boundary", 1.00, &HashMap::new()); + assert!(matches!( + variant, + Err(PostHogEvaluationError::NotAvailable(_)) + ),); + + let properties_unmatched = HashMap::from([ + ( + "plan_type".to_string(), + PostHogFlagFilterPropertyValue::String("paid".to_string()), + ), + ( + "pageserver_remote_size".to_string(), + PostHogFlagFilterPropertyValue::Number(1000.0), + ), + ]); + + // This does not match the filtered group but the all user group. + let variant = + store.evaluate_boolean_inner("image-compaction-boundary", 1.00, &properties_unmatched); + assert!(matches!( + variant, + Err(PostHogEvaluationError::NoConditionGroupMatched) + ),); + let variant = + store.evaluate_boolean_inner("image-compaction-boundary", 0.05, &properties_unmatched); + assert!(variant.is_ok()); + + let properties = HashMap::from([ + ( + "plan_type".to_string(), + PostHogFlagFilterPropertyValue::String("free".to_string()), + ), + ( + "pageserver_remote_size".to_string(), + PostHogFlagFilterPropertyValue::Number(1000.0), + ), + ]); + + // It matches the first group as 0.30 <= 0.40 and the properties are matched. Then it gets evaluated to the variant override. + let variant = store.evaluate_boolean_inner("image-compaction-boundary", 0.30, &properties); + assert!(variant.is_ok()); + + // It matches the group conditions but not the group rollout percentage. + let variant = store.evaluate_boolean_inner("image-compaction-boundary", 1.00, &properties); + assert!(matches!( + variant, + Err(PostHogEvaluationError::NoConditionGroupMatched) + ),); + + // It matches the second "all" group conditions. + let variant = store.evaluate_boolean_inner("image-compaction-boundary", 0.09, &properties); + assert!(variant.is_ok()); + } } diff --git a/pageserver/src/feature_resolver.rs b/pageserver/src/feature_resolver.rs index 193fb10abc..2b0f368079 100644 --- a/pageserver/src/feature_resolver.rs +++ b/pageserver/src/feature_resolver.rs @@ -45,6 +45,10 @@ impl FeatureResolver { } /// Evaluate a multivariate feature flag. Currently, we do not support any properties. + /// + /// Error handling: the caller should inspect the error and decide the behavior when a feature flag + /// cannot be evaluated (i.e., default to false if it cannot be resolved). The error should *not* be + /// propagated beyond where the feature flag gets resolved. pub fn evaluate_multivariate( &self, flag_key: &str, @@ -62,4 +66,29 @@ impl FeatureResolver { )) } } + + /// Evaluate a boolean feature flag. Currently, we do not support any properties. + /// + /// Returns `Ok(())` if the flag is evaluated to true, otherwise returns an error. + /// + /// Error handling: the caller should inspect the error and decide the behavior when a feature flag + /// cannot be evaluated (i.e., default to false if it cannot be resolved). The error should *not* be + /// propagated beyond where the feature flag gets resolved. + pub fn evaluate_boolean( + &self, + flag_key: &str, + tenant_id: TenantId, + ) -> Result<(), PostHogEvaluationError> { + if let Some(inner) = &self.inner { + inner.feature_store().evaluate_boolean( + flag_key, + &tenant_id.to_string(), + &HashMap::new(), + ) + } else { + Err(PostHogEvaluationError::NotAvailable( + "PostHog integration is not enabled".to_string(), + )) + } + } } From cdfa06caad553234594ff99e703f0c1bd1a4dae6 Mon Sep 17 00:00:00 2001 From: Tristan Partin Date: Tue, 27 May 2025 12:33:16 -0500 Subject: [PATCH 16/48] Remove test-images compatibility hack for confirming library load paths (#11927) This hack was needed for compatiblity tests, but after the compute release is no longer needed. Signed-off-by: Tristan Partin --- docker-compose/compute_wrapper/shell/compute.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker-compose/compute_wrapper/shell/compute.sh b/docker-compose/compute_wrapper/shell/compute.sh index 20a1ffb7a0..ab8d74d355 100755 --- a/docker-compose/compute_wrapper/shell/compute.sh +++ b/docker-compose/compute_wrapper/shell/compute.sh @@ -20,7 +20,7 @@ first_path="$(ldconfig --verbose 2>/dev/null \ | grep --invert-match ^$'\t' \ | cut --delimiter=: --fields=1 \ | head --lines=1)" -test "$first_path" == '/usr/local/lib' || true # Remove the || true in a follow-up PR. Needed for backwards compat. +test "$first_path" == '/usr/local/lib' echo "Waiting pageserver become ready." while ! nc -z pageserver 6400; do From e77961c1c6f5b242a47f923299df2237d1ee1649 Mon Sep 17 00:00:00 2001 From: Suhas Thalanki <54014218+thesuhas@users.noreply.github.com> Date: Tue, 27 May 2025 15:40:51 -0400 Subject: [PATCH 17/48] background worker that collects installed extensions (#11939) ## Problem Currently, we collect metrics of what extensions are installed on computes at start up time. We do not have a mechanism that does this at runtime. ## Summary of changes Added a background thread that queries all DBs at regular intervals and collects a list of installed extensions. --- compute_tools/src/bin/compute_ctl.rs | 5 +++ compute_tools/src/compute.rs | 53 ++++++++++++++++++++++------ 2 files changed, 47 insertions(+), 11 deletions(-) diff --git a/compute_tools/src/bin/compute_ctl.rs b/compute_tools/src/bin/compute_ctl.rs index 20b5e567a8..02339f752c 100644 --- a/compute_tools/src/bin/compute_ctl.rs +++ b/compute_tools/src/bin/compute_ctl.rs @@ -136,6 +136,10 @@ struct Cli { requires = "compute-id" )] pub control_plane_uri: Option, + + /// Interval in seconds for collecting installed extensions statistics + #[arg(long, default_value = "3600")] + pub installed_extensions_collection_interval: u64, } fn main() -> Result<()> { @@ -179,6 +183,7 @@ fn main() -> Result<()> { cgroup: cli.cgroup, #[cfg(target_os = "linux")] vm_monitor_addr: cli.vm_monitor_addr, + installed_extensions_collection_interval: cli.installed_extensions_collection_interval, }, config, )?; diff --git a/compute_tools/src/compute.rs b/compute_tools/src/compute.rs index cb857e0a3e..ff49c737f0 100644 --- a/compute_tools/src/compute.rs +++ b/compute_tools/src/compute.rs @@ -97,6 +97,9 @@ pub struct ComputeNodeParams { /// the address of extension storage proxy gateway pub remote_ext_base_url: Option, + + /// Interval for installed extensions collection + pub installed_extensions_collection_interval: u64, } /// Compute node info shared across several `compute_ctl` threads. @@ -742,17 +745,7 @@ impl ComputeNode { let conf = self.get_tokio_conn_conf(None); tokio::task::spawn(async { - let res = get_installed_extensions(conf).await; - match res { - Ok(extensions) => { - info!( - "[NEON_EXT_STAT] {}", - serde_json::to_string(&extensions) - .expect("failed to serialize extensions list") - ); - } - Err(err) => error!("could not get installed extensions: {err:?}"), - } + let _ = installed_extensions(conf).await; }); } @@ -782,6 +775,9 @@ impl ComputeNode { // Log metrics so that we can search for slow operations in logs info!(?metrics, postmaster_pid = %postmaster_pid, "compute start finished"); + // Spawn the extension stats background task + self.spawn_extension_stats_task(); + if pspec.spec.prewarm_lfc_on_startup { self.prewarm_lfc(); } @@ -2192,6 +2188,41 @@ LIMIT 100", info!("Pageserver config changed"); } } + + pub fn spawn_extension_stats_task(&self) { + let conf = self.tokio_conn_conf.clone(); + let installed_extensions_collection_interval = + self.params.installed_extensions_collection_interval; + tokio::spawn(async move { + // An initial sleep is added to ensure that two collections don't happen at the same time. + // The first collection happens during compute startup. + tokio::time::sleep(tokio::time::Duration::from_secs( + installed_extensions_collection_interval, + )) + .await; + let mut interval = tokio::time::interval(tokio::time::Duration::from_secs( + installed_extensions_collection_interval, + )); + loop { + interval.tick().await; + let _ = installed_extensions(conf.clone()).await; + } + }); + } +} + +pub async fn installed_extensions(conf: tokio_postgres::Config) -> Result<()> { + let res = get_installed_extensions(conf).await; + match res { + Ok(extensions) => { + info!( + "[NEON_EXT_STAT] {}", + serde_json::to_string(&extensions).expect("failed to serialize extensions list") + ); + } + Err(err) => error!("could not get installed extensions: {err:?}"), + } + Ok(()) } pub fn forward_termination_signal() { From 541fcd8d2fb5091ee4e8103cfd4cb19ea1ee39fd Mon Sep 17 00:00:00 2001 From: Nikita Kalyanov <44959448+nikitakalyanov@users.noreply.github.com> Date: Wed, 28 May 2025 06:39:59 +0300 Subject: [PATCH 18/48] chore: expose new mark_invisible API in openAPI spec for use in cplane (#12032) ## Problem There is a new API that I plan to use. We generate client from the spec so it should be in the spec ## Summary of changes Document the existing API in openAPI format --- pageserver/src/http/openapi_spec.yml | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/pageserver/src/http/openapi_spec.yml b/pageserver/src/http/openapi_spec.yml index cf99cb110c..e8d1367d6c 100644 --- a/pageserver/src/http/openapi_spec.yml +++ b/pageserver/src/http/openapi_spec.yml @@ -353,6 +353,33 @@ paths: "200": description: OK + /v1/tenant/{tenant_shard_id}/timeline/{timeline_id}/mark_invisible: + parameters: + - name: tenant_shard_id + in: path + required: true + schema: + type: string + - name: timeline_id + in: path + required: true + schema: + type: string + format: hex + put: + requestBody: + content: + application/json: + schema: + type: object + properties: + is_visible: + type: boolean + default: false + responses: + "200": + description: OK + /v1/tenant/{tenant_shard_id}/location_config: parameters: - name: tenant_shard_id From 67ddf1de28e5d79157cd096d04dce0010d8df2cd Mon Sep 17 00:00:00 2001 From: "Alex Chi Z." <4198311+skyzh@users.noreply.github.com> Date: Wed, 28 May 2025 15:00:52 +0800 Subject: [PATCH 19/48] feat(pageserver): create image layers at L0-L1 boundary (#12023) ## Problem Previous attempt https://github.com/neondatabase/neon/pull/10548 caused some issues in staging and we reverted it. This is a re-attempt to address https://github.com/neondatabase/neon/issues/11063. Currently we create image layers at latest record LSN. We would create "future image layers" (i.e., image layers with LSN larger than disk consistent LSN) that need special handling at startup. We also waste a lot of read operations to reconstruct from L0 layers while we could have compacted all of the L0 layers and operate on a flat level of historic layers. ## Summary of changes * Run repartition at L0-L1 boundary. * Roll out with feature flags. * Piggyback a change that downgrades "image layer creating below gc_cutoff" to debug level. --------- Signed-off-by: Alex Chi Z --- pageserver/src/tenant.rs | 25 +++++-- pageserver/src/tenant/timeline.rs | 10 ++- pageserver/src/tenant/timeline/compaction.rs | 67 ++++++++++++++++--- .../regress/test_layers_from_future.py | 3 + 4 files changed, 90 insertions(+), 15 deletions(-) diff --git a/pageserver/src/tenant.rs b/pageserver/src/tenant.rs index 86731fb666..58b766933d 100644 --- a/pageserver/src/tenant.rs +++ b/pageserver/src/tenant.rs @@ -5315,6 +5315,7 @@ impl TenantShard { l0_compaction_trigger: self.l0_compaction_trigger.clone(), l0_flush_global_state: self.l0_flush_global_state.clone(), basebackup_prepare_sender: self.basebackup_prepare_sender.clone(), + feature_resolver: self.feature_resolver.clone(), } } @@ -8359,10 +8360,24 @@ mod tests { } tline.freeze_and_flush().await?; + // Force layers to L1 + tline + .compact( + &cancel, + { + let mut flags = EnumSet::new(); + flags.insert(CompactFlags::ForceL0Compaction); + flags + }, + &ctx, + ) + .await?; if iter % 5 == 0 { + let scan_lsn = Lsn(lsn.0 + 1); + info!("scanning at {}", scan_lsn); let (_, before_delta_file_accessed) = - scan_with_statistics(&tline, &keyspace, lsn, &ctx, io_concurrency.clone()) + scan_with_statistics(&tline, &keyspace, scan_lsn, &ctx, io_concurrency.clone()) .await?; tline .compact( @@ -8371,13 +8386,14 @@ mod tests { let mut flags = EnumSet::new(); flags.insert(CompactFlags::ForceImageLayerCreation); flags.insert(CompactFlags::ForceRepartition); + flags.insert(CompactFlags::ForceL0Compaction); flags }, &ctx, ) .await?; let (_, after_delta_file_accessed) = - scan_with_statistics(&tline, &keyspace, lsn, &ctx, io_concurrency.clone()) + scan_with_statistics(&tline, &keyspace, scan_lsn, &ctx, io_concurrency.clone()) .await?; assert!( after_delta_file_accessed < before_delta_file_accessed, @@ -8818,6 +8834,8 @@ mod tests { let cancel = CancellationToken::new(); + // Image layer creation happens on the disk_consistent_lsn so we need to force set it now. + tline.force_set_disk_consistent_lsn(Lsn(0x40)); tline .compact( &cancel, @@ -8831,8 +8849,7 @@ mod tests { ) .await .unwrap(); - - // Image layers are created at last_record_lsn + // Image layers are created at repartition LSN let images = tline .inspect_image_layers(Lsn(0x40), &ctx, io_concurrency.clone()) .await diff --git a/pageserver/src/tenant/timeline.rs b/pageserver/src/tenant/timeline.rs index 54dc3b2d0b..71765b9197 100644 --- a/pageserver/src/tenant/timeline.rs +++ b/pageserver/src/tenant/timeline.rs @@ -103,6 +103,7 @@ use crate::context::{ DownloadBehavior, PerfInstrumentFutureExt, RequestContext, RequestContextBuilder, }; use crate::disk_usage_eviction_task::{DiskUsageEvictionInfo, EvictionCandidate, finite_f32}; +use crate::feature_resolver::FeatureResolver; use crate::keyspace::{KeyPartitioning, KeySpace}; use crate::l0_flush::{self, L0FlushGlobalState}; use crate::metrics::{ @@ -198,6 +199,7 @@ pub struct TimelineResources { pub l0_compaction_trigger: Arc, pub l0_flush_global_state: l0_flush::L0FlushGlobalState, pub basebackup_prepare_sender: BasebackupPrepareSender, + pub feature_resolver: FeatureResolver, } pub struct Timeline { @@ -444,6 +446,8 @@ pub struct Timeline { /// A channel to send async requests to prepare a basebackup for the basebackup cache. basebackup_prepare_sender: BasebackupPrepareSender, + + feature_resolver: FeatureResolver, } pub(crate) enum PreviousHeatmap { @@ -3072,6 +3076,8 @@ impl Timeline { wait_lsn_log_slow: tokio::sync::Semaphore::new(1), basebackup_prepare_sender: resources.basebackup_prepare_sender, + + feature_resolver: resources.feature_resolver, }; result.repartition_threshold = @@ -4906,6 +4912,7 @@ impl Timeline { LastImageLayerCreationStatus::Initial, false, // don't yield for L0, we're flushing L0 ) + .instrument(info_span!("create_image_layers", mode = %ImageLayerCreationMode::Initial, partition_mode = "initial", lsn = %self.initdb_lsn)) .await?; debug_assert!( matches!(is_complete, LastImageLayerCreationStatus::Complete), @@ -5462,7 +5469,8 @@ impl Timeline { /// Returns the image layers generated and an enum indicating whether the process is fully completed. /// true = we have generate all image layers, false = we preempt the process for L0 compaction. - #[tracing::instrument(skip_all, fields(%lsn, %mode))] + /// + /// `partition_mode` is only for logging purpose and is not used anywhere in this function. async fn create_image_layers( self: &Arc, partitioning: &KeyPartitioning, diff --git a/pageserver/src/tenant/timeline/compaction.rs b/pageserver/src/tenant/timeline/compaction.rs index 0e4b14c3e4..143c2e0865 100644 --- a/pageserver/src/tenant/timeline/compaction.rs +++ b/pageserver/src/tenant/timeline/compaction.rs @@ -1278,11 +1278,55 @@ impl Timeline { } let gc_cutoff = *self.applied_gc_cutoff_lsn.read(); + let l0_l1_boundary_lsn = { + // We do the repartition on the L0-L1 boundary. All data below the boundary + // are compacted by L0 with low read amplification, thus making the `repartition` + // function run fast. + let guard = self.layers.read().await; + guard + .all_persistent_layers() + .iter() + .map(|x| { + // Use the end LSN of delta layers OR the start LSN of image layers. + if x.is_delta { + x.lsn_range.end + } else { + x.lsn_range.start + } + }) + .max() + }; + + let (partition_mode, partition_lsn) = if cfg!(test) + || cfg!(feature = "testing") + || self + .feature_resolver + .evaluate_boolean("image-compaction-boundary", self.tenant_shard_id.tenant_id) + .is_ok() + { + let last_repartition_lsn = self.partitioning.read().1; + let lsn = match l0_l1_boundary_lsn { + Some(boundary) => gc_cutoff + .max(boundary) + .max(last_repartition_lsn) + .max(self.initdb_lsn) + .max(self.ancestor_lsn), + None => self.get_last_record_lsn(), + }; + if lsn <= self.initdb_lsn || lsn <= self.ancestor_lsn { + // Do not attempt to create image layers below the initdb or ancestor LSN -- no data below it + ("l0_l1_boundary", self.get_last_record_lsn()) + } else { + ("l0_l1_boundary", lsn) + } + } else { + ("latest_record", self.get_last_record_lsn()) + }; // 2. Repartition and create image layers if necessary match self .repartition( - self.get_last_record_lsn(), + partition_lsn, self.get_compaction_target_size(), options.flags, ctx, @@ -1301,18 +1345,19 @@ impl Timeline { .extend(sparse_partitioning.into_dense().parts); // 3. Create new image layers for partitions that have been modified "enough". + let mode = if options + .flags + .contains(CompactFlags::ForceImageLayerCreation) + { + ImageLayerCreationMode::Force + } else { + ImageLayerCreationMode::Try + }; let (image_layers, outcome) = self .create_image_layers( &partitioning, lsn, - if options - .flags - .contains(CompactFlags::ForceImageLayerCreation) - { - ImageLayerCreationMode::Force - } else { - ImageLayerCreationMode::Try - }, + mode, &image_ctx, self.last_image_layer_creation_status .load() @@ -1320,6 +1365,7 @@ impl Timeline { .clone(), options.flags.contains(CompactFlags::YieldForL0), ) + .instrument(info_span!("create_image_layers", mode = %mode, partition_mode = %partition_mode, lsn = %lsn)) .await .inspect_err(|err| { if let CreateImageLayersError::GetVectoredError( @@ -1344,7 +1390,8 @@ impl Timeline { } Ok(_) => { - info!("skipping repartitioning due to image compaction LSN being below GC cutoff"); + // This happens very frequently so we don't want to log it. + debug!("skipping repartitioning due to image compaction LSN being below GC cutoff"); } // Suppress errors when cancelled. diff --git a/test_runner/regress/test_layers_from_future.py b/test_runner/regress/test_layers_from_future.py index b4eba2779d..f3fcdb0d14 100644 --- a/test_runner/regress/test_layers_from_future.py +++ b/test_runner/regress/test_layers_from_future.py @@ -20,6 +20,9 @@ from fixtures.remote_storage import LocalFsStorage, RemoteStorageKind from fixtures.utils import query_scalar, wait_until +@pytest.mark.skip( + reason="We won't create future layers any more after https://github.com/neondatabase/neon/pull/10548" +) @pytest.mark.parametrize( "attach_mode", ["default_generation", "same_generation"], From eadabeddb892b330a9aa65034adb05141ec64035 Mon Sep 17 00:00:00 2001 From: Vlad Lazar Date: Wed, 28 May 2025 16:19:41 +0100 Subject: [PATCH 20/48] pageserver: use the same job size throughout the import lifetime (#12026) ## Problem Import planning takes a job size limit as its input. Previously, the job size came from a pageserver config field. This field may change while imports are in progress. If this happens, plans will no longer be identical and the import would fail permanently. ## Summary of Changes Bake the job size into the import progress reported to the storage controller. For new imports, use the value from the pagesever config, and, for existing imports, use the value present in the shard progress. This value is identical for all shards, but we want it to be versioned since future versions of the planner might split the jobs up differently. Hence, it ends up in `ShardImportProgress`. Closes https://github.com/neondatabase/neon/issues/11983 --- libs/pageserver_api/src/models.rs | 3 +++ .../src/tenant/timeline/import_pgdata/flow.rs | 24 ++++++++++++++++--- 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/libs/pageserver_api/src/models.rs b/libs/pageserver_api/src/models.rs index 9f3736d57a..e7d612bb7a 100644 --- a/libs/pageserver_api/src/models.rs +++ b/libs/pageserver_api/src/models.rs @@ -354,6 +354,9 @@ pub struct ShardImportProgressV1 { pub completed: usize, /// Hash of the plan pub import_plan_hash: u64, + /// Soft limit for the job size + /// This needs to remain constant throughout the import + pub job_soft_size_limit: usize, } impl ShardImportStatus { diff --git a/pageserver/src/tenant/timeline/import_pgdata/flow.rs b/pageserver/src/tenant/timeline/import_pgdata/flow.rs index 2ba4ca69ac..0d87a2f135 100644 --- a/pageserver/src/tenant/timeline/import_pgdata/flow.rs +++ b/pageserver/src/tenant/timeline/import_pgdata/flow.rs @@ -30,6 +30,7 @@ use std::collections::HashSet; use std::hash::{Hash, Hasher}; +use std::num::NonZeroUsize; use std::ops::Range; use std::sync::Arc; @@ -100,8 +101,24 @@ async fn run_v1( tasks: Vec::default(), }; - let import_config = &timeline.conf.timeline_import_config; - let plan = planner.plan(import_config).await?; + // Use the job size limit encoded in the progress if we are resuming an import. + // This ensures that imports have stable plans even if the pageserver config changes. + let import_config = { + match &import_progress { + Some(progress) => { + let base = &timeline.conf.timeline_import_config; + TimelineImportConfig { + import_job_soft_size_limit: NonZeroUsize::new(progress.job_soft_size_limit) + .unwrap(), + import_job_concurrency: base.import_job_concurrency, + import_job_checkpoint_threshold: base.import_job_checkpoint_threshold, + } + } + None => timeline.conf.timeline_import_config.clone(), + } + }; + + let plan = planner.plan(&import_config).await?; // Hash the plan and compare with the hash of the plan we got back from the storage controller. // If the two match, it means that the planning stage had the same output. @@ -126,7 +143,7 @@ async fn run_v1( pausable_failpoint!("import-timeline-pre-execute-pausable"); let start_from_job_idx = import_progress.map(|progress| progress.completed); - plan.execute(timeline, start_from_job_idx, plan_hash, import_config, ctx) + plan.execute(timeline, start_from_job_idx, plan_hash, &import_config, ctx) .await } @@ -453,6 +470,7 @@ impl Plan { jobs: jobs_in_plan, completed: last_completed_job_idx, import_plan_hash, + job_soft_size_limit: import_config.import_job_soft_size_limit.into(), }; timeline.remote_client.schedule_index_upload_for_file_changes()?; From 831f2a4ba70e11a0b340fd8f371836eeacf1c40e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Arpad=20M=C3=BCller?= Date: Wed, 28 May 2025 20:20:38 +0200 Subject: [PATCH 21/48] Fix flakiness of test_storcon_create_delete_sk_down (#12040) The `test_storcon_create_delete_sk_down` test is still flaky. This test addresses two possible causes for flakiness. both causes are related to deletion racing with `pull_timeline` which hasn't finished yet. * the first cause is timeline deletion racing with `pull_timeline`: * the first deletion attempt doesn't contain the line because the timeline doesn't exist yet * the subsequent deletion attempts don't contain it either, only a note that the timeline is already deleted. * so this patch adds the note that the timeline is already deleted to the regex * the second cause is about tenant deletion racing with `pull_timeline`: * there were no tenant specific tombstones so if a tenant was deleted, we only added tombstones for the specific timelines being deleted, not for the tenant itself. * This patch changes this, so we now have tenant specific tombstones as well as timeline specific ones, and creation of a timeline checks both. * we also don't see any retries of the tenant deletion in the logs. once it's done it's done. so extend the regex to contain the tenant deletion message as well. One could wonder why the regex and why not using the API to check whether the timeline is just "gone". The issue with the API is that it doesn't allow one to distinguish between "deleted" and "has never existed", and latter case might race with `pull_timeline`. I.e. the second case flakiness helped in the discovery of a real bug (no tenant tombstones), so the more precise check was helpful. Before, I could easily reproduce 2-9 occurences of flakiness when running the test with an additional `range(128)` parameter (i.e. 218 times 4 times). With this patch, I ran it three times, not a single failure. Fixes #11838 --- safekeeper/src/timelines_global_map.rs | 32 +++++++++++++++++-- .../regress/test_storage_controller.py | 6 ++-- 2 files changed, 32 insertions(+), 6 deletions(-) diff --git a/safekeeper/src/timelines_global_map.rs b/safekeeper/src/timelines_global_map.rs index af33bcbd20..e3f7d88f7c 100644 --- a/safekeeper/src/timelines_global_map.rs +++ b/safekeeper/src/timelines_global_map.rs @@ -44,6 +44,7 @@ struct GlobalTimelinesState { // on-demand timeline creation from recreating deleted timelines. This is only soft-enforced, as // this map is dropped on restart. tombstones: HashMap, + tenant_tombstones: HashMap, conf: Arc, broker_active_set: Arc, @@ -81,10 +82,25 @@ impl GlobalTimelinesState { } } + fn has_tombstone(&self, ttid: &TenantTimelineId) -> bool { + self.tombstones.contains_key(ttid) || self.tenant_tombstones.contains_key(&ttid.tenant_id) + } + + /// Removes all blocking tombstones for the given timeline ID. + /// Returns `true` if there have been actual changes. + fn remove_tombstone(&mut self, ttid: &TenantTimelineId) -> bool { + self.tombstones.remove(ttid).is_some() + || self.tenant_tombstones.remove(&ttid.tenant_id).is_some() + } + fn delete(&mut self, ttid: TenantTimelineId) { self.timelines.remove(&ttid); self.tombstones.insert(ttid, Instant::now()); } + + fn add_tenant_tombstone(&mut self, tenant_id: TenantId) { + self.tenant_tombstones.insert(tenant_id, Instant::now()); + } } /// A struct used to manage access to the global timelines map. @@ -99,6 +115,7 @@ impl GlobalTimelines { state: Mutex::new(GlobalTimelinesState { timelines: HashMap::new(), tombstones: HashMap::new(), + tenant_tombstones: HashMap::new(), conf, broker_active_set: Arc::new(TimelinesSet::default()), global_rate_limiter: RateLimiter::new(1, 1), @@ -245,7 +262,7 @@ impl GlobalTimelines { return Ok(timeline); } - if state.tombstones.contains_key(&ttid) { + if state.has_tombstone(&ttid) { anyhow::bail!("Timeline {ttid} is deleted, refusing to recreate"); } @@ -295,13 +312,14 @@ impl GlobalTimelines { _ => {} } if check_tombstone { - if state.tombstones.contains_key(&ttid) { + if state.has_tombstone(&ttid) { anyhow::bail!("timeline {ttid} is deleted, refusing to recreate"); } } else { // We may be have been asked to load a timeline that was previously deleted (e.g. from `pull_timeline.rs`). We trust // that the human doing this manual intervention knows what they are doing, and remove its tombstone. - if state.tombstones.remove(&ttid).is_some() { + // It's also possible that we enter this when the tenant has been deleted, even if the timeline itself has never existed. + if state.remove_tombstone(&ttid) { warn!("un-deleted timeline {ttid}"); } } @@ -482,6 +500,7 @@ impl GlobalTimelines { let tli_res = { let state = self.state.lock().unwrap(); + // Do NOT check tenant tombstones here: those were set earlier if state.tombstones.contains_key(ttid) { // Presence of a tombstone guarantees that a previous deletion has completed and there is no work to do. info!("Timeline {ttid} was already deleted"); @@ -557,6 +576,10 @@ impl GlobalTimelines { action: DeleteOrExclude, ) -> Result> { info!("deleting all timelines for tenant {}", tenant_id); + + // Adding a tombstone before getting the timelines to prevent new timeline additions + self.state.lock().unwrap().add_tenant_tombstone(*tenant_id); + let to_delete = self.get_all_for_tenant(*tenant_id); let mut err = None; @@ -600,6 +623,9 @@ impl GlobalTimelines { state .tombstones .retain(|_, v| now.duration_since(*v) < *tombstone_ttl); + state + .tenant_tombstones + .retain(|_, v| now.duration_since(*v) < *tombstone_ttl); } } diff --git a/test_runner/regress/test_storage_controller.py b/test_runner/regress/test_storage_controller.py index d07fb38c5a..346ef0951d 100644 --- a/test_runner/regress/test_storage_controller.py +++ b/test_runner/regress/test_storage_controller.py @@ -4192,10 +4192,10 @@ def test_storcon_create_delete_sk_down( # ensure the safekeeper deleted the timeline def timeline_deleted_on_active_sks(): env.safekeepers[0].assert_log_contains( - f"deleting timeline {tenant_id}/{child_timeline_id} from disk" + f"((deleting timeline|Timeline) {tenant_id}/{child_timeline_id} (from disk|was already deleted)|DELETE.*tenant/{tenant_id} .*status: 200 OK)" ) env.safekeepers[2].assert_log_contains( - f"deleting timeline {tenant_id}/{child_timeline_id} from disk" + f"((deleting timeline|Timeline) {tenant_id}/{child_timeline_id} (from disk|was already deleted)|DELETE.*tenant/{tenant_id} .*status: 200 OK)" ) wait_until(timeline_deleted_on_active_sks) @@ -4210,7 +4210,7 @@ def test_storcon_create_delete_sk_down( # ensure that there is log msgs for the third safekeeper too def timeline_deleted_on_sk(): env.safekeepers[1].assert_log_contains( - f"deleting timeline {tenant_id}/{child_timeline_id} from disk" + f"((deleting timeline|Timeline) {tenant_id}/{child_timeline_id} (from disk|was already deleted)|DELETE.*tenant/{tenant_id} .*status: 200 OK)" ) wait_until(timeline_deleted_on_sk) From 9e4cf52949621cb3d7c51f03e029869ace80dff2 Mon Sep 17 00:00:00 2001 From: "Alex Chi Z." <4198311+skyzh@users.noreply.github.com> Date: Thu, 29 May 2025 17:32:19 +0800 Subject: [PATCH 22/48] pageserver: reduce concurrency for gc-compaction (#12054) ## Problem Temporarily reduce the concurrency of gc-compaction to 1 job at a time. We are going to roll out in the largest AWS region next week. Having one job running at a time makes it easier to identify what tenant causes problem if it's not running well and pause gc-compaction for that specific tenant. (We can make this configurable via pageserver config in the future!) ## Summary of changes Reduce `CONCURRENT_GC_COMPACTION_TASKS` from 2 to 1. Signed-off-by: Alex Chi Z --- pageserver/src/tenant/timeline/compaction.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pageserver/src/tenant/timeline/compaction.rs b/pageserver/src/tenant/timeline/compaction.rs index 143c2e0865..72ca0f9cc1 100644 --- a/pageserver/src/tenant/timeline/compaction.rs +++ b/pageserver/src/tenant/timeline/compaction.rs @@ -206,8 +206,8 @@ pub struct GcCompactionQueue { } static CONCURRENT_GC_COMPACTION_TASKS: Lazy> = Lazy::new(|| { - // Only allow two timelines on one pageserver to run gc compaction at a time. - Arc::new(Semaphore::new(2)) + // Only allow one timeline on one pageserver to run gc compaction at a time. + Arc::new(Semaphore::new(1)) }); impl GcCompactionQueue { From 529d661532939a01ec74e594cac9ada54ebb2586 Mon Sep 17 00:00:00 2001 From: "devin-ai-integration[bot]" <158243242+devin-ai-integration[bot]@users.noreply.github.com> Date: Thu, 29 May 2025 12:07:09 +0100 Subject: [PATCH 23/48] storcon: skip offline nodes in get_top_tenant_shards (#12057) ## Summary The optimiser background loop could get delayed a lot by waiting for timeouts trying to talk to offline nodes. Fixes: #12056 ## Solution - Skip offline nodes in `get_top_tenant_shards` Link to Devin run: https://app.devin.ai/sessions/065afd6756734d33bbd4d012428c4b6e Requested by: John Spray (john@neon.tech) Co-authored-by: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Co-authored-by: John Spray --- storage_controller/src/service.rs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/storage_controller/src/service.rs b/storage_controller/src/service.rs index d284747f73..823f4dadfa 100644 --- a/storage_controller/src/service.rs +++ b/storage_controller/src/service.rs @@ -8538,8 +8538,9 @@ impl Service { Some(ShardCount(new_shard_count)) } - /// Fetches the top tenant shards from every node, in descending order of - /// max logical size. Any node errors will be logged and ignored. + /// Fetches the top tenant shards from every available node, in descending order of + /// max logical size. Offline nodes are skipped, and any errors from available nodes + /// will be logged and ignored. async fn get_top_tenant_shards( &self, request: &TopTenantShardsRequest, @@ -8550,6 +8551,7 @@ impl Service { .unwrap() .nodes .values() + .filter(|node| node.is_available()) .cloned() .collect_vec(); From 51639cd6afc12ddb14a475c1b4be68e996a1389b Mon Sep 17 00:00:00 2001 From: Vlad Lazar Date: Thu, 29 May 2025 12:13:52 +0100 Subject: [PATCH 24/48] pageserver: allow for deletion of importing timelines (#12033) ## Problem Importing timelines can't currently be deleted. This is problematic because: 1. Cplane cannot delete failed imports and we leave the timeline behind. 2. The flow does not support user driven cancellation of the import ## Summary of changes On the pageserver: I've taken the path of least resistance, extended `TimelineOrOffloaded` with a new variant and added handling in the right places. I'm open to thoughts here, but I think it turned out better than I was envisioning. On the storage controller: Again, fairly simple business: when a DELETE timeline request is received, we remove the import from the DB and stop any finalization tasks/futures. In order to stop finalizations, we track them in-memory. For each finalizing import, we associate a gate and a cancellation token. Note that we delete the entry from the database before cancelling any finalizations. This is such that a concurrent request can't progress the import into finalize state and race with the deletion. This concern about deleting an import with on-going finalization is theoretical in the near future. We are only going to delete importing timelines after the storage controller reports the failure to cplane. Alas, the design works for user driven cancellation too. Closes https://github.com/neondatabase/neon/issues/11897 --- pageserver/src/tenant.rs | 50 +++++++++- pageserver/src/tenant/timeline/delete.rs | 40 ++++++-- .../src/tenant/timeline/import_pgdata.rs | 17 +++- storage_controller/src/http.rs | 4 + storage_controller/src/service.rs | 93 ++++++++++++++++++- storage_controller/src/timeline_import.rs | 8 ++ test_runner/fixtures/neon_fixtures.py | 16 ++++ test_runner/fixtures/pageserver/http.py | 4 +- test_runner/regress/test_import_pgdata.py | 42 +++++++-- 9 files changed, 245 insertions(+), 29 deletions(-) diff --git a/pageserver/src/tenant.rs b/pageserver/src/tenant.rs index 58b766933d..d85d970583 100644 --- a/pageserver/src/tenant.rs +++ b/pageserver/src/tenant.rs @@ -300,7 +300,7 @@ pub struct TenantShard { /// as in progress. /// * Imported timelines are removed when the storage controller calls the post timeline /// import activation endpoint. - timelines_importing: std::sync::Mutex>, + timelines_importing: std::sync::Mutex>>, /// The last tenant manifest known to be in remote storage. None if the manifest has not yet /// been either downloaded or uploaded. Always Some after tenant attach. @@ -672,6 +672,7 @@ pub enum MaybeOffloaded { pub enum TimelineOrOffloaded { Timeline(Arc), Offloaded(Arc), + Importing(Arc), } impl TimelineOrOffloaded { @@ -683,6 +684,9 @@ impl TimelineOrOffloaded { TimelineOrOffloaded::Offloaded(offloaded) => { TimelineOrOffloadedArcRef::Offloaded(offloaded) } + TimelineOrOffloaded::Importing(importing) => { + TimelineOrOffloadedArcRef::Importing(importing) + } } } pub fn tenant_shard_id(&self) -> TenantShardId { @@ -695,12 +699,16 @@ impl TimelineOrOffloaded { match self { TimelineOrOffloaded::Timeline(timeline) => &timeline.delete_progress, TimelineOrOffloaded::Offloaded(offloaded) => &offloaded.delete_progress, + TimelineOrOffloaded::Importing(importing) => &importing.delete_progress, } } fn maybe_remote_client(&self) -> Option> { match self { TimelineOrOffloaded::Timeline(timeline) => Some(timeline.remote_client.clone()), TimelineOrOffloaded::Offloaded(_offloaded) => None, + TimelineOrOffloaded::Importing(importing) => { + Some(importing.timeline.remote_client.clone()) + } } } } @@ -708,6 +716,7 @@ impl TimelineOrOffloaded { pub enum TimelineOrOffloadedArcRef<'a> { Timeline(&'a Arc), Offloaded(&'a Arc), + Importing(&'a Arc), } impl TimelineOrOffloadedArcRef<'_> { @@ -715,12 +724,14 @@ impl TimelineOrOffloadedArcRef<'_> { match self { TimelineOrOffloadedArcRef::Timeline(timeline) => timeline.tenant_shard_id, TimelineOrOffloadedArcRef::Offloaded(offloaded) => offloaded.tenant_shard_id, + TimelineOrOffloadedArcRef::Importing(importing) => importing.timeline.tenant_shard_id, } } pub fn timeline_id(&self) -> TimelineId { match self { TimelineOrOffloadedArcRef::Timeline(timeline) => timeline.timeline_id, TimelineOrOffloadedArcRef::Offloaded(offloaded) => offloaded.timeline_id, + TimelineOrOffloadedArcRef::Importing(importing) => importing.timeline.timeline_id, } } } @@ -737,6 +748,12 @@ impl<'a> From<&'a Arc> for TimelineOrOffloadedArcRef<'a> { } } +impl<'a> From<&'a Arc> for TimelineOrOffloadedArcRef<'a> { + fn from(timeline: &'a Arc) -> Self { + Self::Importing(timeline) + } +} + #[derive(Debug, thiserror::Error, PartialEq, Eq)] pub enum GetTimelineError { #[error("Timeline is shutting down")] @@ -1789,20 +1806,25 @@ impl TenantShard { }, ) => { let timeline_id = timeline.timeline_id; + let import_task_gate = Gate::default(); + let import_task_guard = import_task_gate.enter().unwrap(); let import_task_handle = tokio::task::spawn(self.clone().create_timeline_import_pgdata_task( timeline.clone(), import_pgdata, guard, + import_task_guard, ctx.detached_child(TaskKind::ImportPgdata, DownloadBehavior::Warn), )); let prev = self.timelines_importing.lock().unwrap().insert( timeline_id, - ImportingTimeline { + Arc::new(ImportingTimeline { timeline: timeline.clone(), import_task_handle, - }, + import_task_gate, + delete_progress: TimelineDeleteProgress::default(), + }), ); assert!(prev.is_none()); @@ -2853,19 +2875,25 @@ impl TenantShard { let (timeline, timeline_create_guard) = uninit_timeline.finish_creation_myself(); + let import_task_gate = Gate::default(); + let import_task_guard = import_task_gate.enter().unwrap(); + let import_task_handle = tokio::spawn(self.clone().create_timeline_import_pgdata_task( timeline.clone(), index_part, timeline_create_guard, + import_task_guard, timeline_ctx.detached_child(TaskKind::ImportPgdata, DownloadBehavior::Warn), )); let prev = self.timelines_importing.lock().unwrap().insert( timeline.timeline_id, - ImportingTimeline { + Arc::new(ImportingTimeline { timeline: timeline.clone(), import_task_handle, - }, + import_task_gate, + delete_progress: TimelineDeleteProgress::default(), + }), ); // Idempotency is enforced higher up the stack @@ -2924,6 +2952,7 @@ impl TenantShard { timeline: Arc, index_part: import_pgdata::index_part_format::Root, timeline_create_guard: TimelineCreateGuard, + _import_task_guard: GateGuard, ctx: RequestContext, ) { debug_assert_current_span_has_tenant_and_timeline_id(); @@ -3835,6 +3864,9 @@ impl TenantShard { .build_timeline_client(offloaded.timeline_id, self.remote_storage.clone()); Arc::new(remote_client) } + TimelineOrOffloadedArcRef::Importing(_) => { + unreachable!("Importing timelines are not included in the iterator") + } }; // Shut down the timeline's remote client: this means that the indices we write @@ -5044,6 +5076,14 @@ impl TenantShard { info!("timeline already exists but is offloaded"); Err(CreateTimelineError::Conflict) } + Err(TimelineExclusionError::AlreadyExists { + existing: TimelineOrOffloaded::Importing(_existing), + .. + }) => { + // If there's a timeline already importing, then we would hit + // the [`TimelineExclusionError::AlreadyCreating`] branch above. + unreachable!("Importing timelines hold the creation guard") + } Err(TimelineExclusionError::AlreadyExists { existing: TimelineOrOffloaded::Timeline(existing), arg, diff --git a/pageserver/src/tenant/timeline/delete.rs b/pageserver/src/tenant/timeline/delete.rs index 1d4dd05e34..51bdd59f4f 100644 --- a/pageserver/src/tenant/timeline/delete.rs +++ b/pageserver/src/tenant/timeline/delete.rs @@ -121,6 +121,7 @@ async fn remove_maybe_offloaded_timeline_from_tenant( // This observes the locking order between timelines and timelines_offloaded let mut timelines = tenant.timelines.lock().unwrap(); let mut timelines_offloaded = tenant.timelines_offloaded.lock().unwrap(); + let mut timelines_importing = tenant.timelines_importing.lock().unwrap(); let offloaded_children_exist = timelines_offloaded .iter() .any(|(_, entry)| entry.ancestor_timeline_id == Some(timeline.timeline_id())); @@ -150,8 +151,12 @@ async fn remove_maybe_offloaded_timeline_from_tenant( .expect("timeline that we were deleting was concurrently removed from 'timelines_offloaded' map"); offloaded_timeline.delete_from_ancestor_with_timelines(&timelines); } + TimelineOrOffloaded::Importing(importing) => { + timelines_importing.remove(&importing.timeline.timeline_id); + } } + drop(timelines_importing); drop(timelines_offloaded); drop(timelines); @@ -203,8 +208,17 @@ impl DeleteTimelineFlow { guard.mark_in_progress()?; // Now that the Timeline is in Stopping state, request all the related tasks to shut down. - if let TimelineOrOffloaded::Timeline(timeline) = &timeline { - timeline.shutdown(super::ShutdownMode::Hard).await; + // TODO(vlad): shut down imported timeline here + match &timeline { + TimelineOrOffloaded::Timeline(timeline) => { + timeline.shutdown(super::ShutdownMode::Hard).await; + } + TimelineOrOffloaded::Importing(importing) => { + importing.shutdown().await; + } + TimelineOrOffloaded::Offloaded(_offloaded) => { + // Nothing to shut down in this case + } } tenant.gc_block.before_delete(&timeline.timeline_id()); @@ -389,10 +403,18 @@ impl DeleteTimelineFlow { Err(anyhow::anyhow!("failpoint: timeline-delete-before-rm"))? }); - // Offloaded timelines have no local state - // TODO: once we persist offloaded information, delete the timeline from there, too - if let TimelineOrOffloaded::Timeline(timeline) = timeline { - delete_local_timeline_directory(conf, tenant.tenant_shard_id, timeline).await; + match timeline { + TimelineOrOffloaded::Timeline(timeline) => { + delete_local_timeline_directory(conf, tenant.tenant_shard_id, timeline).await; + } + TimelineOrOffloaded::Importing(importing) => { + delete_local_timeline_directory(conf, tenant.tenant_shard_id, &importing.timeline) + .await; + } + TimelineOrOffloaded::Offloaded(_offloaded) => { + // Offloaded timelines have no local state + // TODO: once we persist offloaded information, delete the timeline from there, too + } } fail::fail_point!("timeline-delete-after-rm", |_| { @@ -451,12 +473,16 @@ pub(super) fn make_timeline_delete_guard( // For more context see this discussion: `https://github.com/neondatabase/neon/pull/4552#discussion_r1253437346` let timelines = tenant.timelines.lock().unwrap(); let timelines_offloaded = tenant.timelines_offloaded.lock().unwrap(); + let timelines_importing = tenant.timelines_importing.lock().unwrap(); let timeline = match timelines.get(&timeline_id) { Some(t) => TimelineOrOffloaded::Timeline(Arc::clone(t)), None => match timelines_offloaded.get(&timeline_id) { Some(t) => TimelineOrOffloaded::Offloaded(Arc::clone(t)), - None => return Err(DeleteTimelineError::NotFound), + None => match timelines_importing.get(&timeline_id) { + Some(t) => TimelineOrOffloaded::Importing(Arc::clone(t)), + None => return Err(DeleteTimelineError::NotFound), + }, }, }; diff --git a/pageserver/src/tenant/timeline/import_pgdata.rs b/pageserver/src/tenant/timeline/import_pgdata.rs index db62e9000c..bdb34ec3a3 100644 --- a/pageserver/src/tenant/timeline/import_pgdata.rs +++ b/pageserver/src/tenant/timeline/import_pgdata.rs @@ -8,8 +8,9 @@ use tokio::task::JoinHandle; use tokio_util::sync::CancellationToken; use tracing::info; use utils::lsn::Lsn; +use utils::sync::gate::Gate; -use super::Timeline; +use super::{Timeline, TimelineDeleteProgress}; use crate::context::RequestContext; use crate::controller_upcall_client::{StorageControllerUpcallApi, StorageControllerUpcallClient}; use crate::tenant::metadata::TimelineMetadata; @@ -19,15 +20,23 @@ mod importbucket_client; mod importbucket_format; pub(crate) mod index_part_format; -pub(crate) struct ImportingTimeline { +pub struct ImportingTimeline { pub import_task_handle: JoinHandle<()>, + pub import_task_gate: Gate, pub timeline: Arc, + pub delete_progress: TimelineDeleteProgress, +} + +impl std::fmt::Debug for ImportingTimeline { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "ImportingTimeline<{}>", self.timeline.timeline_id) + } } impl ImportingTimeline { - pub(crate) async fn shutdown(self) { + pub async fn shutdown(&self) { self.import_task_handle.abort(); - let _ = self.import_task_handle.await; + self.import_task_gate.close().await; self.timeline.remote_client.shutdown().await; } diff --git a/storage_controller/src/http.rs b/storage_controller/src/http.rs index 02c02c0e7f..2b1c0db12f 100644 --- a/storage_controller/src/http.rs +++ b/storage_controller/src/http.rs @@ -482,6 +482,10 @@ async fn handle_tenant_timeline_delete( ForwardOutcome::NotForwarded(_req) => {} }; + service + .maybe_delete_timeline_import(tenant_id, timeline_id) + .await?; + // For timeline deletions, which both implement an "initially return 202, then 404 once // we're done" semantic, we wrap with a retry loop to expose a simpler API upstream. async fn deletion_wrapper(service: Arc, f: F) -> Result, ApiError> diff --git a/storage_controller/src/service.rs b/storage_controller/src/service.rs index 823f4dadfa..790797bae2 100644 --- a/storage_controller/src/service.rs +++ b/storage_controller/src/service.rs @@ -99,8 +99,8 @@ use crate::tenant_shard::{ ScheduleOptimization, ScheduleOptimizationAction, TenantShard, }; use crate::timeline_import::{ - ImportResult, ShardImportStatuses, TimelineImport, TimelineImportFinalizeError, - TimelineImportState, UpcallClient, + FinalizingImport, ImportResult, ShardImportStatuses, TimelineImport, + TimelineImportFinalizeError, TimelineImportState, UpcallClient, }; const WAITER_FILL_DRAIN_POLL_TIMEOUT: Duration = Duration::from_millis(500); @@ -232,6 +232,9 @@ struct ServiceState { /// Queue of tenants who are waiting for concurrency limits to permit them to reconcile delayed_reconcile_rx: tokio::sync::mpsc::Receiver, + + /// Tracks ongoing timeline import finalization tasks + imports_finalizing: BTreeMap<(TenantId, TimelineId), FinalizingImport>, } /// Transform an error from a pageserver into an error to return to callers of a storage @@ -308,6 +311,7 @@ impl ServiceState { scheduler, ongoing_operation: None, delayed_reconcile_rx, + imports_finalizing: Default::default(), } } @@ -4097,13 +4101,58 @@ impl Service { /// /// If this method gets pre-empted by shut down, it will be called again at start-up (on-going /// imports are stored in the database). + /// + /// # Cancel-Safety + /// Not cancel safe. + /// If the caller stops polling, the import will not be removed from + /// [`ServiceState::imports_finalizing`]. #[instrument(skip_all, fields( tenant_id=%import.tenant_id, timeline_id=%import.timeline_id, ))] + async fn finalize_timeline_import( self: &Arc, import: TimelineImport, + ) -> Result<(), TimelineImportFinalizeError> { + let tenant_timeline = (import.tenant_id, import.timeline_id); + + let (_finalize_import_guard, cancel) = { + let mut locked = self.inner.write().unwrap(); + let gate = Gate::default(); + let cancel = CancellationToken::default(); + + let guard = gate.enter().unwrap(); + + locked.imports_finalizing.insert( + tenant_timeline, + FinalizingImport { + gate, + cancel: cancel.clone(), + }, + ); + + (guard, cancel) + }; + + let res = tokio::select! { + res = self.finalize_timeline_import_impl(import) => { + res + }, + _ = cancel.cancelled() => { + Err(TimelineImportFinalizeError::Cancelled) + } + }; + + let mut locked = self.inner.write().unwrap(); + locked.imports_finalizing.remove(&tenant_timeline); + + res + } + + async fn finalize_timeline_import_impl( + self: &Arc, + import: TimelineImport, ) -> Result<(), TimelineImportFinalizeError> { tracing::info!("Finalizing timeline import"); @@ -4303,6 +4352,46 @@ impl Service { .await; } + /// Delete a timeline import if it exists + /// + /// Firstly, delete the entry from the database. Any updates + /// from pageservers after the update will fail with a 404, so the + /// import cannot progress into finalizing state if it's not there already. + /// Secondly, cancel the finalization if one is in progress. + pub(crate) async fn maybe_delete_timeline_import( + self: &Arc, + tenant_id: TenantId, + timeline_id: TimelineId, + ) -> Result<(), DatabaseError> { + let tenant_has_ongoing_import = { + let locked = self.inner.read().unwrap(); + locked + .tenants + .range(TenantShardId::tenant_range(tenant_id)) + .any(|(_tid, shard)| shard.importing == TimelineImportState::Importing) + }; + + if !tenant_has_ongoing_import { + return Ok(()); + } + + self.persistence + .delete_timeline_import(tenant_id, timeline_id) + .await?; + + let maybe_finalizing = { + let mut locked = self.inner.write().unwrap(); + locked.imports_finalizing.remove(&(tenant_id, timeline_id)) + }; + + if let Some(finalizing) = maybe_finalizing { + finalizing.cancel.cancel(); + finalizing.gate.close().await; + } + + Ok(()) + } + pub(crate) async fn tenant_timeline_archival_config( &self, tenant_id: TenantId, diff --git a/storage_controller/src/timeline_import.rs b/storage_controller/src/timeline_import.rs index 909e8e2899..eb50819d02 100644 --- a/storage_controller/src/timeline_import.rs +++ b/storage_controller/src/timeline_import.rs @@ -7,6 +7,7 @@ use serde::{Deserialize, Serialize}; use pageserver_api::models::{ShardImportProgress, ShardImportStatus}; use tokio_util::sync::CancellationToken; +use utils::sync::gate::Gate; use utils::{ id::{TenantId, TimelineId}, shard::ShardIndex, @@ -55,6 +56,8 @@ pub(crate) enum TimelineImportUpdateFollowUp { pub(crate) enum TimelineImportFinalizeError { #[error("Shut down interrupted import finalize")] ShuttingDown, + #[error("Import finalization was cancelled")] + Cancelled, #[error("Mismatched shard detected during import finalize: {0}")] MismatchedShards(ShardIndex), } @@ -164,6 +167,11 @@ impl TimelineImport { } } +pub(crate) struct FinalizingImport { + pub(crate) gate: Gate, + pub(crate) cancel: CancellationToken, +} + pub(crate) type ImportResult = Result<(), String>; pub(crate) struct UpcallClient { diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py index 7f4150b580..eedeb4f696 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -2337,6 +2337,22 @@ class NeonStorageController(MetricsGetter, LogUtils): headers=self.headers(TokenScope.ADMIN), ) + def import_status( + self, tenant_shard_id: TenantShardId, timeline_id: TimelineId, generation: int + ): + payload = { + "tenant_shard_id": str(tenant_shard_id), + "timeline_id": str(timeline_id), + "generation": generation, + } + + self.request( + "GET", + f"{self.api}/upcall/v1/timeline_import_status", + headers=self.headers(TokenScope.GENERATIONS_API), + json=payload, + ) + def reconcile_all(self): r = self.request( "POST", diff --git a/test_runner/fixtures/pageserver/http.py b/test_runner/fixtures/pageserver/http.py index c2d176bf5a..c29192c25c 100644 --- a/test_runner/fixtures/pageserver/http.py +++ b/test_runner/fixtures/pageserver/http.py @@ -675,7 +675,7 @@ class PageserverHttpClient(requests.Session, MetricsGetter): def timeline_delete( self, tenant_id: TenantId | TenantShardId, timeline_id: TimelineId, **kwargs - ): + ) -> int: """ Note that deletion is not instant, it is scheduled and performed mostly in the background. So if you need to wait for it to complete use `timeline_delete_wait_completed`. @@ -688,6 +688,8 @@ class PageserverHttpClient(requests.Session, MetricsGetter): res_json = res.json() assert res_json is None + return res.status_code + def timeline_gc( self, tenant_id: TenantId | TenantShardId, diff --git a/test_runner/regress/test_import_pgdata.py b/test_runner/regress/test_import_pgdata.py index 69cbdec5b0..262ec9b06c 100644 --- a/test_runner/regress/test_import_pgdata.py +++ b/test_runner/regress/test_import_pgdata.py @@ -19,6 +19,7 @@ from fixtures.neon_fixtures import ( PageserverImportConfig, PgBin, PgProtocol, + StorageControllerApiException, StorageControllerMigrationConfig, VanillaPostgres, ) @@ -423,8 +424,12 @@ def test_import_completion_on_restart( @run_only_on_default_postgres(reason="PG version is irrelevant here") -def test_import_respects_tenant_shutdown( - neon_env_builder: NeonEnvBuilder, vanilla_pg: VanillaPostgres, make_httpserver: HTTPServer +@pytest.mark.parametrize("action", ["restart", "delete"]) +def test_import_respects_timeline_lifecycle( + neon_env_builder: NeonEnvBuilder, + vanilla_pg: VanillaPostgres, + make_httpserver: HTTPServer, + action: str, ): """ Validate that importing timelines respect the usual timeline life cycle: @@ -492,16 +497,33 @@ def test_import_respects_tenant_shutdown( wait_until(hit_failpoint) assert not import_completion_signaled.is_set() - # Restart the pageserver while an import job is in progress. - # This clears the failpoint and we expect that the import starts up afresh - # after the restart and eventually completes. - env.pageserver.stop() - env.pageserver.start() + if action == "restart": + # Restart the pageserver while an import job is in progress. + # This clears the failpoint and we expect that the import starts up afresh + # after the restart and eventually completes. + env.pageserver.stop() + env.pageserver.start() - def cplane_notified(): - assert import_completion_signaled.is_set() + def cplane_notified(): + assert import_completion_signaled.is_set() - wait_until(cplane_notified) + wait_until(cplane_notified) + elif action == "delete": + status = env.storage_controller.pageserver_api().timeline_delete(tenant_id, timeline_id) + assert status == 200 + + timeline_path = env.pageserver.timeline_dir(tenant_id, timeline_id) + assert not timeline_path.exists(), "Timeline dir exists after deletion" + + shard_zero = TenantShardId(tenant_id, 0, 0) + location = env.storage_controller.inspect(shard_zero) + assert location is not None + generation = location[0] + + with pytest.raises(StorageControllerApiException, match="not found"): + env.storage_controller.import_status(shard_zero, timeline_id, generation) + else: + raise RuntimeError(f"{action} param not recognized") @skip_in_debug_build("Validation query takes too long in debug builds") From 8a6fc6fd8c46a5cdd00bcc4999c4fb2d22cfe968 Mon Sep 17 00:00:00 2001 From: Vlad Lazar Date: Thu, 29 May 2025 14:01:10 +0100 Subject: [PATCH 25/48] pageserver: hook importing timelines up into disk usage eviction (#12038) ## Problem Disk usage eviction isn't sensitive to layers of imported timelines. ## Summary of changes Hook importing timelines up into eviction and add a test for it. I don't think we need any special eviction logic for this. These layers will all be visible and their access time will be their creation time. Hence, we'll remove covered layers first and get to the imported layers if there's still disk pressure. --- pageserver/src/disk_usage_eviction_task.rs | 25 ++- pageserver/src/tenant.rs | 11 ++ .../src/tenant/remote_timeline_client.rs | 15 ++ .../src/tenant/timeline/import_pgdata.rs | 3 + .../src/tenant/timeline/import_pgdata/flow.rs | 9 ++ test_runner/fixtures/fast_import.py | 56 +++++++ test_runner/fixtures/neon_fixtures.py | 5 + .../regress/test_disk_usage_eviction.py | 142 +++++++++++++++++- test_runner/regress/test_import_pgdata.py | 103 ++----------- 9 files changed, 273 insertions(+), 96 deletions(-) diff --git a/pageserver/src/disk_usage_eviction_task.rs b/pageserver/src/disk_usage_eviction_task.rs index 13252037e5..f13b3709f5 100644 --- a/pageserver/src/disk_usage_eviction_task.rs +++ b/pageserver/src/disk_usage_eviction_task.rs @@ -837,7 +837,30 @@ async fn collect_eviction_candidates( continue; } let info = tl.get_local_layers_for_disk_usage_eviction().await; - debug!(tenant_id=%tl.tenant_shard_id.tenant_id, shard_id=%tl.tenant_shard_id.shard_slug(), timeline_id=%tl.timeline_id, "timeline resident layers count: {}", info.resident_layers.len()); + debug!( + tenant_id=%tl.tenant_shard_id.tenant_id, + shard_id=%tl.tenant_shard_id.shard_slug(), + timeline_id=%tl.timeline_id, + "timeline resident layers count: {}", info.resident_layers.len() + ); + + tenant_candidates.extend(info.resident_layers.into_iter()); + max_layer_size = max_layer_size.max(info.max_layer_size.unwrap_or(0)); + + if cancel.is_cancelled() { + return Ok(EvictionCandidates::Cancelled); + } + } + + // Also consider layers of timelines being imported for eviction + for tl in tenant.list_importing_timelines() { + let info = tl.timeline.get_local_layers_for_disk_usage_eviction().await; + debug!( + tenant_id=%tl.timeline.tenant_shard_id.tenant_id, + shard_id=%tl.timeline.tenant_shard_id.shard_slug(), + timeline_id=%tl.timeline.timeline_id, + "timeline resident layers count: {}", info.resident_layers.len() + ); tenant_candidates.extend(info.resident_layers.into_iter()); max_layer_size = max_layer_size.max(info.max_layer_size.unwrap_or(0)); diff --git a/pageserver/src/tenant.rs b/pageserver/src/tenant.rs index d85d970583..451d266bc0 100644 --- a/pageserver/src/tenant.rs +++ b/pageserver/src/tenant.rs @@ -2442,6 +2442,17 @@ impl TenantShard { .collect() } + /// Lists timelines the tenant contains. + /// It's up to callers to omit certain timelines that are not considered ready for use. + pub fn list_importing_timelines(&self) -> Vec> { + self.timelines_importing + .lock() + .unwrap() + .values() + .map(Arc::clone) + .collect() + } + /// Lists timelines the tenant manages, including offloaded ones. /// /// It's up to callers to omit certain timelines that are not considered ready for use. diff --git a/pageserver/src/tenant/remote_timeline_client.rs b/pageserver/src/tenant/remote_timeline_client.rs index 21d68495f7..fd65000379 100644 --- a/pageserver/src/tenant/remote_timeline_client.rs +++ b/pageserver/src/tenant/remote_timeline_client.rs @@ -1348,6 +1348,21 @@ impl RemoteTimelineClient { Ok(()) } + pub(crate) fn schedule_unlinking_of_layers_from_index_part( + self: &Arc, + names: I, + ) -> Result<(), NotInitialized> + where + I: IntoIterator, + { + let mut guard = self.upload_queue.lock().unwrap(); + let upload_queue = guard.initialized_mut()?; + + self.schedule_unlinking_of_layers_from_index_part0(upload_queue, names); + + Ok(()) + } + /// Update the remote index file, removing the to-be-deleted files from the index, /// allowing scheduling of actual deletions later. fn schedule_unlinking_of_layers_from_index_part0( diff --git a/pageserver/src/tenant/timeline/import_pgdata.rs b/pageserver/src/tenant/timeline/import_pgdata.rs index bdb34ec3a3..f19a4b3e9c 100644 --- a/pageserver/src/tenant/timeline/import_pgdata.rs +++ b/pageserver/src/tenant/timeline/import_pgdata.rs @@ -8,6 +8,7 @@ use tokio::task::JoinHandle; use tokio_util::sync::CancellationToken; use tracing::info; use utils::lsn::Lsn; +use utils::pausable_failpoint; use utils::sync::gate::Gate; use super::{Timeline, TimelineDeleteProgress}; @@ -110,6 +111,8 @@ pub async fn doit( .schedule_index_upload_for_file_changes()?; timeline.remote_client.wait_completion().await?; + pausable_failpoint!("import-timeline-pre-success-notify-pausable"); + // Communicate that shard is done. // Ensure at-least-once delivery of the upcall to storage controller // before we mark the task as done and never come here again. diff --git a/pageserver/src/tenant/timeline/import_pgdata/flow.rs b/pageserver/src/tenant/timeline/import_pgdata/flow.rs index 0d87a2f135..9743aa3f26 100644 --- a/pageserver/src/tenant/timeline/import_pgdata/flow.rs +++ b/pageserver/src/tenant/timeline/import_pgdata/flow.rs @@ -982,6 +982,15 @@ impl ChunkProcessingJob { .cloned(); match existing_layer { Some(existing) => { + // Unlink the remote layer from the index without scheduling its deletion. + // When `existing_layer` drops [`LayerInner::drop`] will schedule its deletion from + // remote storage, but that assumes that the layer was unlinked from the index first. + timeline + .remote_client + .schedule_unlinking_of_layers_from_index_part(std::iter::once( + existing.layer_desc().layer_name(), + ))?; + guard.open_mut()?.rewrite_layers( &[(existing.clone(), resident_layer.clone())], &[], diff --git a/test_runner/fixtures/fast_import.py b/test_runner/fixtures/fast_import.py index f9e5f9c1db..bd6dc2583b 100644 --- a/test_runner/fixtures/fast_import.py +++ b/test_runner/fixtures/fast_import.py @@ -1,3 +1,4 @@ +import json import os import shutil import subprocess @@ -11,6 +12,7 @@ from _pytest.config import Config from fixtures.log_helper import log from fixtures.neon_cli import AbstractNeonCli +from fixtures.neon_fixtures import Endpoint, VanillaPostgres from fixtures.pg_version import PgVersion from fixtures.remote_storage import MockS3Server @@ -161,3 +163,57 @@ def fast_import( f.write(fi.cmd.stderr) log.info("Written logs to %s", test_output_dir) + + +def mock_import_bucket(vanilla_pg: VanillaPostgres, path: Path): + """ + Mock the import S3 bucket into a local directory for a provided vanilla PG instance. + """ + assert not vanilla_pg.is_running() + + path.mkdir() + # what cplane writes before scheduling fast_import + specpath = path / "spec.json" + specpath.write_text(json.dumps({"branch_id": "somebranch", "project_id": "someproject"})) + # what fast_import writes + vanilla_pg.pgdatadir.rename(path / "pgdata") + statusdir = path / "status" + statusdir.mkdir() + (statusdir / "pgdata").write_text(json.dumps({"done": True})) + (statusdir / "fast_import").write_text(json.dumps({"command": "pgdata", "done": True})) + + +def populate_vanilla_pg(vanilla_pg: VanillaPostgres, target_relblock_size: int) -> int: + assert vanilla_pg.is_running() + + vanilla_pg.safe_psql("create user cloud_admin with password 'postgres' superuser") + # fillfactor so we don't need to produce that much data + # 900 byte per row is > 10% => 1 row per page + vanilla_pg.safe_psql("""create table t (data char(900)) with (fillfactor = 10)""") + + nrows = 0 + while True: + relblock_size = vanilla_pg.safe_psql_scalar("select pg_relation_size('t')") + log.info( + f"relblock size: {relblock_size / 8192} pages (target: {target_relblock_size // 8192}) pages" + ) + if relblock_size >= target_relblock_size: + break + addrows = int((target_relblock_size - relblock_size) // 8192) + assert addrows >= 1, "forward progress" + vanilla_pg.safe_psql( + f"insert into t select generate_series({nrows + 1}, {nrows + addrows})" + ) + nrows += addrows + + return nrows + + +def validate_import_from_vanilla_pg(endpoint: Endpoint, nrows: int): + assert endpoint.safe_psql_many( + [ + "set effective_io_concurrency=32;", + "SET statement_timeout='300s';", + "select count(*), sum(data::bigint)::bigint from t", + ] + ) == [[], [], [(nrows, nrows * (nrows + 1) // 2)]] diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py index eedeb4f696..ab4885ce6b 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -2829,6 +2829,11 @@ class NeonPageserver(PgProtocol, LogUtils): if self.running: self.http_client().configure_failpoints([(name, action)]) + def clear_persistent_failpoint(self, name: str): + del self._persistent_failpoints[name] + if self.running: + self.http_client().configure_failpoints([(name, "off")]) + def timeline_dir( self, tenant_shard_id: TenantId | TenantShardId, diff --git a/test_runner/regress/test_disk_usage_eviction.py b/test_runner/regress/test_disk_usage_eviction.py index b29610e021..1420dc59a1 100644 --- a/test_runner/regress/test_disk_usage_eviction.py +++ b/test_runner/regress/test_disk_usage_eviction.py @@ -1,31 +1,41 @@ from __future__ import annotations import enum +import json import time from collections import Counter from dataclasses import dataclass from enum import StrEnum +from threading import Event from typing import TYPE_CHECKING import pytest from fixtures.common_types import Lsn, TenantId, TimelineId +from fixtures.fast_import import mock_import_bucket, populate_vanilla_pg from fixtures.log_helper import log from fixtures.neon_fixtures import ( NeonEnv, NeonEnvBuilder, NeonPageserver, PgBin, + VanillaPostgres, wait_for_last_flush_lsn, ) +from fixtures.pageserver.http import ( + ImportPgdataIdemptencyKey, +) from fixtures.pageserver.utils import wait_for_upload_queue_empty from fixtures.remote_storage import RemoteStorageKind -from fixtures.utils import human_bytes, wait_until +from fixtures.utils import human_bytes, run_only_on_default_postgres, wait_until +from werkzeug.wrappers.response import Response if TYPE_CHECKING: from collections.abc import Iterable from typing import Any from fixtures.pageserver.http import PageserverHttpClient + from pytest_httpserver import HTTPServer + from werkzeug.wrappers.request import Request GLOBAL_LRU_LOG_LINE = "tenant_min_resident_size-respecting LRU would not relieve pressure, evicting more following global LRU policy" @@ -164,6 +174,7 @@ class EvictionEnv: min_avail_bytes, mock_behavior, eviction_order: EvictionOrder, + wait_logical_size: bool = True, ): """ Starts pageserver up with mocked statvfs setup. The startup is @@ -201,11 +212,12 @@ class EvictionEnv: pageserver.start() # we now do initial logical size calculation on startup, which on debug builds can fight with disk usage based eviction - for tenant_id, timeline_id in self.timelines: - tenant_ps = self.neon_env.get_tenant_pageserver(tenant_id) - # Pageserver may be none if we are currently not attached anywhere, e.g. during secondary eviction test - if tenant_ps is not None: - tenant_ps.http_client().timeline_wait_logical_size(tenant_id, timeline_id) + if wait_logical_size: + for tenant_id, timeline_id in self.timelines: + tenant_ps = self.neon_env.get_tenant_pageserver(tenant_id) + # Pageserver may be none if we are currently not attached anywhere, e.g. during secondary eviction test + if tenant_ps is not None: + tenant_ps.http_client().timeline_wait_logical_size(tenant_id, timeline_id) def statvfs_called(): pageserver.assert_log_contains(".*running mocked statvfs.*") @@ -882,3 +894,121 @@ def test_secondary_mode_eviction(eviction_env_ha: EvictionEnv): assert total_size - post_eviction_total_size >= evict_bytes, ( "we requested at least evict_bytes worth of free space" ) + + +@run_only_on_default_postgres(reason="PG version is irrelevant here") +def test_import_timeline_disk_pressure_eviction( + neon_env_builder: NeonEnvBuilder, + vanilla_pg: VanillaPostgres, + make_httpserver: HTTPServer, + pg_bin: PgBin, +): + """ + TODO + """ + # Set up mock control plane HTTP server to listen for import completions + import_completion_signaled = Event() + + def handler(request: Request) -> Response: + log.info(f"control plane /import_complete request: {request.json}") + import_completion_signaled.set() + return Response(json.dumps({}), status=200) + + cplane_mgmt_api_server = make_httpserver + cplane_mgmt_api_server.expect_request( + "/storage/api/v1/import_complete", method="PUT" + ).respond_with_handler(handler) + + # Plug the cplane mock in + neon_env_builder.control_plane_hooks_api = ( + f"http://{cplane_mgmt_api_server.host}:{cplane_mgmt_api_server.port}/storage/api/v1/" + ) + + # The import will specifiy a local filesystem path mocking remote storage + neon_env_builder.enable_pageserver_remote_storage(RemoteStorageKind.LOCAL_FS) + + vanilla_pg.start() + target_relblock_size = 1024 * 1024 * 128 + populate_vanilla_pg(vanilla_pg, target_relblock_size) + vanilla_pg.stop() + + env = neon_env_builder.init_configs() + env.start() + + importbucket_path = neon_env_builder.repo_dir / "test_import_completion_bucket" + mock_import_bucket(vanilla_pg, importbucket_path) + + tenant_id = TenantId.generate() + timeline_id = TimelineId.generate() + idempotency = ImportPgdataIdemptencyKey.random() + + eviction_env = EvictionEnv( + timelines=[(tenant_id, timeline_id)], + neon_env=env, + pageserver_http=env.pageserver.http_client(), + layer_size=5 * 1024 * 1024, # Doesn't apply here + pg_bin=pg_bin, # Not used here + pgbench_init_lsns={}, # Not used here + ) + + # Pause before delivering the final notification to storcon. + # This keeps the import in progress. + failpoint_name = "import-timeline-pre-success-notify-pausable" + env.pageserver.add_persistent_failpoint(failpoint_name, "pause") + + env.storage_controller.tenant_create(tenant_id) + env.storage_controller.timeline_create( + tenant_id, + { + "new_timeline_id": str(timeline_id), + "import_pgdata": { + "idempotency_key": str(idempotency), + "location": {"LocalFs": {"path": str(importbucket_path.absolute())}}, + }, + }, + ) + + def hit_failpoint(): + log.info("Checking log for pattern...") + try: + assert env.pageserver.log_contains(f".*at failpoint {failpoint_name}.*") + except Exception: + log.exception("Failed to find pattern in log") + raise + + wait_until(hit_failpoint) + assert not import_completion_signaled.is_set() + + env.pageserver.stop() + + total_size, _, _ = eviction_env.timelines_du(env.pageserver) + blocksize = 512 + total_blocks = (total_size + (blocksize - 1)) // blocksize + + eviction_env.pageserver_start_with_disk_usage_eviction( + env.pageserver, + period="1s", + max_usage_pct=33, + min_avail_bytes=0, + mock_behavior={ + "type": "Success", + "blocksize": blocksize, + "total_blocks": total_blocks, + # Only count layer files towards used bytes in the mock_statvfs. + # This avoids accounting for metadata files & tenant conf in the tests. + "name_filter": ".*__.*", + }, + eviction_order=EvictionOrder.RELATIVE_ORDER_SPARE, + wait_logical_size=False, + ) + + wait_until(lambda: env.pageserver.assert_log_contains(".*disk usage pressure relieved")) + + env.pageserver.clear_persistent_failpoint(failpoint_name) + + def cplane_notified(): + assert import_completion_signaled.is_set() + + wait_until(cplane_notified) + + env.pageserver.allowed_errors.append(r".* running disk usage based eviction due to pressure.*") diff --git a/test_runner/regress/test_import_pgdata.py b/test_runner/regress/test_import_pgdata.py index 262ec9b06c..ba60c3caa6 100644 --- a/test_runner/regress/test_import_pgdata.py +++ b/test_runner/regress/test_import_pgdata.py @@ -12,7 +12,12 @@ import psycopg2 import psycopg2.errors import pytest from fixtures.common_types import Lsn, TenantId, TenantShardId, TimelineId -from fixtures.fast_import import FastImport +from fixtures.fast_import import ( + FastImport, + mock_import_bucket, + populate_vanilla_pg, + validate_import_from_vanilla_pg, +) from fixtures.log_helper import log from fixtures.neon_fixtures import ( NeonEnvBuilder, @@ -60,24 +65,6 @@ smoke_params = [ ] -def mock_import_bucket(vanilla_pg: VanillaPostgres, path: Path): - """ - Mock the import S3 bucket into a local directory for a provided vanilla PG instance. - """ - assert not vanilla_pg.is_running() - - path.mkdir() - # what cplane writes before scheduling fast_import - specpath = path / "spec.json" - specpath.write_text(json.dumps({"branch_id": "somebranch", "project_id": "someproject"})) - # what fast_import writes - vanilla_pg.pgdatadir.rename(path / "pgdata") - statusdir = path / "status" - statusdir.mkdir() - (statusdir / "pgdata").write_text(json.dumps({"done": True})) - (statusdir / "fast_import").write_text(json.dumps({"command": "pgdata", "done": True})) - - @skip_in_debug_build("MULTIPLE_RELATION_SEGMENTS has non trivial amount of data") @pytest.mark.parametrize("shard_count,stripe_size,rel_block_size", smoke_params) def test_pgdata_import_smoke( @@ -132,10 +119,6 @@ def test_pgdata_import_smoke( # Put data in vanilla pg # - vanilla_pg.start() - vanilla_pg.safe_psql("create user cloud_admin with password 'postgres' superuser") - - log.info("create relblock data") if rel_block_size == RelBlockSize.ONE_STRIPE_SIZE: target_relblock_size = stripe_size * 8192 elif rel_block_size == RelBlockSize.TWO_STRPES_PER_SHARD: @@ -146,45 +129,8 @@ def test_pgdata_import_smoke( else: raise ValueError - # fillfactor so we don't need to produce that much data - # 900 byte per row is > 10% => 1 row per page - vanilla_pg.safe_psql("""create table t (data char(900)) with (fillfactor = 10)""") - - nrows = 0 - while True: - relblock_size = vanilla_pg.safe_psql_scalar("select pg_relation_size('t')") - log.info( - f"relblock size: {relblock_size / 8192} pages (target: {target_relblock_size // 8192}) pages" - ) - if relblock_size >= target_relblock_size: - break - addrows = int((target_relblock_size - relblock_size) // 8192) - assert addrows >= 1, "forward progress" - vanilla_pg.safe_psql( - f"insert into t select generate_series({nrows + 1}, {nrows + addrows})" - ) - nrows += addrows - expect_nrows = nrows - expect_sum = ( - (nrows) * (nrows + 1) // 2 - ) # https://stackoverflow.com/questions/43901484/sum-of-the-integers-from-1-to-n - - def validate_vanilla_equivalence(ep): - # TODO: would be nicer to just compare pgdump - - # Enable IO concurrency for batching on large sequential scan, to avoid making - # this test unnecessarily onerous on CPU. Especially on debug mode, it's still - # pretty onerous though, so increase statement_timeout to avoid timeouts. - assert ep.safe_psql_many( - [ - "set effective_io_concurrency=32;", - "SET statement_timeout='300s';", - "select count(*), sum(data::bigint)::bigint from t", - ] - ) == [[], [], [(expect_nrows, expect_sum)]] - - validate_vanilla_equivalence(vanilla_pg) - + vanilla_pg.start() + rows_inserted = populate_vanilla_pg(vanilla_pg, target_relblock_size) vanilla_pg.stop() # @@ -275,14 +221,14 @@ def test_pgdata_import_smoke( config_lines=ep_config, ) - validate_vanilla_equivalence(ro_endpoint) + validate_import_from_vanilla_pg(ro_endpoint, rows_inserted) # ensure the import survives restarts ro_endpoint.stop() env.pageserver.stop(immediate=True) env.pageserver.start() ro_endpoint.start() - validate_vanilla_equivalence(ro_endpoint) + validate_import_from_vanilla_pg(ro_endpoint, rows_inserted) # # validate the layer files in each shard only have the shard-specific data @@ -322,7 +268,7 @@ def test_pgdata_import_smoke( child_workload = workload.branch(timeline_id=child_timeline_id, branch_name="br-tip") child_workload.validate() - validate_vanilla_equivalence(child_workload.endpoint()) + validate_import_from_vanilla_pg(child_workload.endpoint(), rows_inserted) # ... at the initdb lsn _ = env.create_branch( @@ -337,7 +283,7 @@ def test_pgdata_import_smoke( tenant_id=tenant_id, config_lines=ep_config, ) - validate_vanilla_equivalence(br_initdb_endpoint) + validate_import_from_vanilla_pg(br_initdb_endpoint, rows_inserted) with pytest.raises(psycopg2.errors.UndefinedTable): br_initdb_endpoint.safe_psql(f"select * from {workload.table}") @@ -578,23 +524,8 @@ def test_import_chaos( neon_env_builder.enable_pageserver_remote_storage(RemoteStorageKind.LOCAL_FS) vanilla_pg.start() - vanilla_pg.safe_psql("create user cloud_admin with password 'postgres' superuser") - vanilla_pg.safe_psql("""create table t (data char(900)) with (fillfactor = 10)""") - nrows = 0 - while True: - relblock_size = vanilla_pg.safe_psql_scalar("select pg_relation_size('t')") - log.info( - f"relblock size: {relblock_size / 8192} pages (target: {TARGET_RELBOCK_SIZE // 8192}) pages" - ) - if relblock_size >= TARGET_RELBOCK_SIZE: - break - addrows = int((TARGET_RELBOCK_SIZE - relblock_size) // 8192) - assert addrows >= 1, "forward progress" - vanilla_pg.safe_psql( - f"insert into t select generate_series({nrows + 1}, {nrows + addrows})" - ) - nrows += addrows + inserted_rows = populate_vanilla_pg(vanilla_pg, TARGET_RELBOCK_SIZE) vanilla_pg.stop() @@ -762,13 +693,7 @@ def test_import_chaos( endpoint = env.endpoints.create_start(branch_name=import_branch_name, tenant_id=tenant_id) # Validate the imported data is legit - assert endpoint.safe_psql_many( - [ - "set effective_io_concurrency=32;", - "SET statement_timeout='300s';", - "select count(*), sum(data::bigint)::bigint from t", - ] - ) == [[], [], [(nrows, nrows * (nrows + 1) // 2)]] + validate_import_from_vanilla_pg(endpoint, inserted_rows) endpoint.stop() From f060537a310bf2fa4a00a905de826f95c170320b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Arpad=20M=C3=BCller?= Date: Thu, 29 May 2025 16:07:33 +0200 Subject: [PATCH 26/48] Add safekeeper reconciler metrics (#12062) Adds two metrics to the storcon that are related to the safekeeper reconciler: * `storage_controller_safkeeper_reconciles_queued` to indicate currrent queue depth * `storage_controller_safkeeper_reconciles_complete` to indicate the number of complete reconciles Both metrics operate on a per-safekeeper basis (as reconcilers run on a per-safekeeper basis too). These metrics mirror the `storage_controller_pending_reconciles` and `storage_controller_reconcile_complete` metrics, although those are not scoped on a per-pageserver basis but are global for the entire storage controller. Part of #11670 --- storage_controller/src/metrics.rs | 19 ++++++++ .../src/service/safekeeper_reconciler.rs | 45 ++++++++++++++++++- 2 files changed, 63 insertions(+), 1 deletion(-) diff --git a/storage_controller/src/metrics.rs b/storage_controller/src/metrics.rs index 5ce2fb65e4..ccdbcad139 100644 --- a/storage_controller/src/metrics.rs +++ b/storage_controller/src/metrics.rs @@ -139,6 +139,14 @@ pub(crate) struct StorageControllerMetricGroup { /// HTTP request status counters for handled requests pub(crate) storage_controller_reconcile_long_running: measured::CounterVec, + + /// Indicator of safekeeper reconciler queue depth, broken down by safekeeper, excluding ongoing reconciles. + pub(crate) storage_controller_safkeeper_reconciles_queued: + measured::GaugeVec, + + /// Indicator of completed safekeeper reconciles, broken down by safekeeper. + pub(crate) storage_controller_safkeeper_reconciles_complete: + measured::CounterVec, } impl StorageControllerMetrics { @@ -257,6 +265,17 @@ pub(crate) enum Method { Other, } +#[derive(measured::LabelGroup, Clone)] +#[label(set = SafekeeperReconcilerLabelGroupSet)] +pub(crate) struct SafekeeperReconcilerLabelGroup<'a> { + #[label(dynamic_with = lasso::ThreadedRodeo, default)] + pub(crate) sk_az: &'a str, + #[label(dynamic_with = lasso::ThreadedRodeo, default)] + pub(crate) sk_node_id: &'a str, + #[label(dynamic_with = lasso::ThreadedRodeo, default)] + pub(crate) sk_hostname: &'a str, +} + impl From for Method { fn from(value: hyper::Method) -> Self { if value == hyper::Method::GET { diff --git a/storage_controller/src/service/safekeeper_reconciler.rs b/storage_controller/src/service/safekeeper_reconciler.rs index f756d98c64..fbf0b5c4e3 100644 --- a/storage_controller/src/service/safekeeper_reconciler.rs +++ b/storage_controller/src/service/safekeeper_reconciler.rs @@ -20,7 +20,9 @@ use utils::{ }; use crate::{ - persistence::SafekeeperTimelineOpKind, safekeeper::Safekeeper, + metrics::{METRICS_REGISTRY, SafekeeperReconcilerLabelGroup}, + persistence::SafekeeperTimelineOpKind, + safekeeper::Safekeeper, safekeeper_client::SafekeeperClient, }; @@ -218,7 +220,26 @@ impl ReconcilerHandle { fn schedule_reconcile(&self, req: ScheduleRequest) { let (cancel, token_id) = self.new_token_slot(req.tenant_id, req.timeline_id); let hostname = req.safekeeper.skp.host.clone(); + let sk_az = req.safekeeper.skp.availability_zone_id.clone(); + let sk_node_id = req.safekeeper.get_id().to_string(); + + // We don't have direct access to the queue depth here, so increase it blindly by 1. + // We know that putting into the queue increases the queue depth. The receiver will + // update with the correct value once it processes the next item. To avoid races where we + // reduce before we increase, leaving the gauge with a 1 value for a long time, we + // increase it before putting into the queue. + let queued_gauge = &METRICS_REGISTRY + .metrics_group + .storage_controller_safkeeper_reconciles_queued; + let label_group = SafekeeperReconcilerLabelGroup { + sk_az: &sk_az, + sk_node_id: &sk_node_id, + sk_hostname: &hostname, + }; + queued_gauge.inc(label_group.clone()); + if let Err(err) = self.tx.send((req, cancel, token_id)) { + queued_gauge.set(label_group, 0); tracing::info!("scheduling request onto {hostname} returned error: {err}"); } } @@ -283,6 +304,18 @@ impl SafekeeperReconciler { continue; } + let queued_gauge = &METRICS_REGISTRY + .metrics_group + .storage_controller_safkeeper_reconciles_queued; + queued_gauge.set( + SafekeeperReconcilerLabelGroup { + sk_az: &req.safekeeper.skp.availability_zone_id, + sk_node_id: &req.safekeeper.get_id().to_string(), + sk_hostname: &req.safekeeper.skp.host, + }, + self.rx.len() as i64, + ); + tokio::task::spawn(async move { let kind = req.kind; let tenant_id = req.tenant_id; @@ -511,6 +544,16 @@ impl SafekeeperReconcilerInner { req.generation, ) .await; + + let complete_counter = &METRICS_REGISTRY + .metrics_group + .storage_controller_safkeeper_reconciles_complete; + complete_counter.inc(SafekeeperReconcilerLabelGroup { + sk_az: &req.safekeeper.skp.availability_zone_id, + sk_node_id: &req.safekeeper.get_id().to_string(), + sk_hostname: &req.safekeeper.skp.host, + }); + if let Err(err) = res { tracing::info!( "couldn't remove reconciliation request onto {} from persistence: {err:?}", From 3b4d4eb53502fc16f8bef65be814b31783106f23 Mon Sep 17 00:00:00 2001 From: Gleb Novikov Date: Thu, 29 May 2025 19:25:42 +0100 Subject: [PATCH 27/48] fast_import.rs: log number of jobs for pg_dump/pg_restore (#12068) ## Problem I have a hypothesis that import might be using lower number of jobs than max for the VM, where the job is running. This change will help finding this out from logs ## Summary of changes Added logging of number of jobs, which is passed into both `pg_dump` and `pg_restore` --- compute_tools/src/bin/fast_import.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/compute_tools/src/bin/fast_import.rs b/compute_tools/src/bin/fast_import.rs index 78acd78585..e65c210b23 100644 --- a/compute_tools/src/bin/fast_import.rs +++ b/compute_tools/src/bin/fast_import.rs @@ -339,6 +339,8 @@ async fn run_dump_restore( destination_connstring: String, ) -> Result<(), anyhow::Error> { let dumpdir = workdir.join("dumpdir"); + let num_jobs = num_cpus::get().to_string(); + info!("using {num_jobs} jobs for dump/restore"); let common_args = [ // schema mapping (prob suffices to specify them on one side) @@ -354,7 +356,7 @@ async fn run_dump_restore( "directory".to_string(), // concurrency "--jobs".to_string(), - num_cpus::get().to_string(), + num_jobs, // progress updates "--verbose".to_string(), ]; From af429b4a62e54911dc58c2579e583a01db5f706c Mon Sep 17 00:00:00 2001 From: "Alex Chi Z." <4198311+skyzh@users.noreply.github.com> Date: Fri, 30 May 2025 16:02:25 +0800 Subject: [PATCH 28/48] feat(pageserver): observability for feature flags (#12034) ## Problem Part of #11813. This pull request adds misc observability improvements for the functionality. ## Summary of changes * Info span for the PostHog feature background loop. * New evaluate feature flag API. * Put the request error into the error message. * Log when feature flag gets updated. --------- Signed-off-by: Alex Chi Z --- .../src/background_loop.rs | 51 ++++++++++--------- libs/posthog_client_lite/src/lib.rs | 20 ++++++++ pageserver/src/feature_resolver.rs | 10 ++++ pageserver/src/http/routes.rs | 43 ++++++++++++++++ pageserver/src/tenant.rs | 2 +- 5 files changed, 102 insertions(+), 24 deletions(-) diff --git a/libs/posthog_client_lite/src/background_loop.rs b/libs/posthog_client_lite/src/background_loop.rs index 9ffcda3728..a05f6096b1 100644 --- a/libs/posthog_client_lite/src/background_loop.rs +++ b/libs/posthog_client_lite/src/background_loop.rs @@ -4,6 +4,7 @@ use std::{sync::Arc, time::Duration}; use arc_swap::ArcSwap; use tokio_util::sync::CancellationToken; +use tracing::{Instrument, info_span}; use crate::{FeatureStore, PostHogClient, PostHogClientConfig}; @@ -26,31 +27,35 @@ impl FeatureResolverBackgroundLoop { pub fn spawn(self: Arc, handle: &tokio::runtime::Handle, refresh_period: Duration) { let this = self.clone(); let cancel = self.cancel.clone(); - handle.spawn(async move { - tracing::info!("Starting PostHog feature resolver"); - let mut ticker = tokio::time::interval(refresh_period); - ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); - loop { - tokio::select! { - _ = ticker.tick() => {} - _ = cancel.cancelled() => break - } - let resp = match this - .posthog_client - .get_feature_flags_local_evaluation() - .await - { - Ok(resp) => resp, - Err(e) => { - tracing::warn!("Cannot get feature flags: {}", e); - continue; + handle.spawn( + async move { + tracing::info!("Starting PostHog feature resolver"); + let mut ticker = tokio::time::interval(refresh_period); + ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); + loop { + tokio::select! { + _ = ticker.tick() => {} + _ = cancel.cancelled() => break } - }; - let feature_store = FeatureStore::new_with_flags(resp.flags); - this.feature_store.store(Arc::new(feature_store)); + let resp = match this + .posthog_client + .get_feature_flags_local_evaluation() + .await + { + Ok(resp) => resp, + Err(e) => { + tracing::warn!("Cannot get feature flags: {}", e); + continue; + } + }; + let feature_store = FeatureStore::new_with_flags(resp.flags); + this.feature_store.store(Arc::new(feature_store)); + tracing::info!("Feature flag updated"); + } + tracing::info!("PostHog feature resolver stopped"); } - tracing::info!("PostHog feature resolver stopped"); - }); + .instrument(info_span!("posthog_feature_resolver")), + ); } pub fn feature_store(&self) -> Arc { diff --git a/libs/posthog_client_lite/src/lib.rs b/libs/posthog_client_lite/src/lib.rs index 8aa8da2898..ff12051196 100644 --- a/libs/posthog_client_lite/src/lib.rs +++ b/libs/posthog_client_lite/src/lib.rs @@ -448,6 +448,18 @@ impl FeatureStore { ))) } } + + /// Infer whether a feature flag is a boolean flag by checking if it has a multivariate filter. + pub fn is_feature_flag_boolean(&self, flag_key: &str) -> Result { + if let Some(flag_config) = self.flags.get(flag_key) { + Ok(flag_config.filters.multivariate.is_none()) + } else { + Err(PostHogEvaluationError::NotAvailable(format!( + "Not found in the local evaluation spec: {}", + flag_key + ))) + } + } } pub struct PostHogClientConfig { @@ -528,7 +540,15 @@ impl PostHogClient { .bearer_auth(&self.config.server_api_key) .send() .await?; + let status = response.status(); let body = response.text().await?; + if !status.is_success() { + return Err(anyhow::anyhow!( + "Failed to get feature flags: {}, {}", + status, + body + )); + } Ok(serde_json::from_str(&body)?) } diff --git a/pageserver/src/feature_resolver.rs b/pageserver/src/feature_resolver.rs index 2b0f368079..7e31b930d0 100644 --- a/pageserver/src/feature_resolver.rs +++ b/pageserver/src/feature_resolver.rs @@ -91,4 +91,14 @@ impl FeatureResolver { )) } } + + pub fn is_feature_flag_boolean(&self, flag_key: &str) -> Result { + if let Some(inner) = &self.inner { + inner.feature_store().is_feature_flag_boolean(flag_key) + } else { + Err(PostHogEvaluationError::NotAvailable( + "PostHog integration is not enabled".to_string(), + )) + } + } } diff --git a/pageserver/src/http/routes.rs b/pageserver/src/http/routes.rs index c449e3373f..1effa10404 100644 --- a/pageserver/src/http/routes.rs +++ b/pageserver/src/http/routes.rs @@ -3663,6 +3663,46 @@ async fn read_tar_eof(mut reader: (impl tokio::io::AsyncRead + Unpin)) -> anyhow Ok(()) } +async fn tenant_evaluate_feature_flag( + request: Request, + _cancel: CancellationToken, +) -> Result, ApiError> { + let tenant_shard_id: TenantShardId = parse_request_param(&request, "tenant_shard_id")?; + check_permission(&request, Some(tenant_shard_id.tenant_id))?; + + let flag: String = must_parse_query_param(&request, "flag")?; + let as_type: String = must_parse_query_param(&request, "as")?; + + let state = get_state(&request); + + async { + let tenant = state + .tenant_manager + .get_attached_tenant_shard(tenant_shard_id)?; + if as_type == "boolean" { + let result = tenant.feature_resolver.evaluate_boolean(&flag, tenant_shard_id.tenant_id); + let result = result.map(|_| true).map_err(|e| e.to_string()); + json_response(StatusCode::OK, result) + } else if as_type == "multivariate" { + let result = tenant.feature_resolver.evaluate_multivariate(&flag, tenant_shard_id.tenant_id).map_err(|e| e.to_string()); + json_response(StatusCode::OK, result) + } else { + // Auto infer the type of the feature flag. + let is_boolean = tenant.feature_resolver.is_feature_flag_boolean(&flag).map_err(|e| ApiError::InternalServerError(anyhow::anyhow!("{e}")))?; + if is_boolean { + let result = tenant.feature_resolver.evaluate_boolean(&flag, tenant_shard_id.tenant_id); + let result = result.map(|_| true).map_err(|e| e.to_string()); + json_response(StatusCode::OK, result) + } else { + let result = tenant.feature_resolver.evaluate_multivariate(&flag, tenant_shard_id.tenant_id).map_err(|e| e.to_string()); + json_response(StatusCode::OK, result) + } + } + } + .instrument(info_span!("tenant_evaluate_feature_flag", tenant_id = %tenant_shard_id.tenant_id, shard_id = %tenant_shard_id.shard_slug())) + .await +} + /// Common functionality of all the HTTP API handlers. /// /// - Adds a tracing span to each request (by `request_span`) @@ -4039,5 +4079,8 @@ pub fn make_router( "/v1/tenant/:tenant_shard_id/timeline/:timeline_id/activate_post_import", |r| api_handler(r, activate_post_import_handler), ) + .get("/v1/tenant/:tenant_shard_id/feature_flag", |r| { + api_handler(r, tenant_evaluate_feature_flag) + }) .any(handler_404)) } diff --git a/pageserver/src/tenant.rs b/pageserver/src/tenant.rs index 451d266bc0..3a054aff83 100644 --- a/pageserver/src/tenant.rs +++ b/pageserver/src/tenant.rs @@ -383,7 +383,7 @@ pub struct TenantShard { l0_flush_global_state: L0FlushGlobalState, - feature_resolver: FeatureResolver, + pub(crate) feature_resolver: FeatureResolver, } impl std::fmt::Debug for TenantShard { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { From e78d1e2ec6c398ed46ef51b60a03d5d498240848 Mon Sep 17 00:00:00 2001 From: John Spray Date: Fri, 30 May 2025 12:18:01 +0100 Subject: [PATCH 29/48] tests: tighten readability rules in test_location_conf_churn (#12059) ## Problem Checking the most recent state of pageservers was insufficient to evaluate whether another pageserver may read in a particular generation, since the latest state might mask some earlier AttachedSingle state. Related: https://github.com/neondatabase/neon/issues/11348 ## Summary of changes - Maintain a history of all attachments - Write out explicit rules for when a pageserver may read --- .../regress/test_pageserver_secondary.py | 73 +++++++++++++++---- 1 file changed, 59 insertions(+), 14 deletions(-) diff --git a/test_runner/regress/test_pageserver_secondary.py b/test_runner/regress/test_pageserver_secondary.py index f2523ec9b5..e5908de363 100644 --- a/test_runner/regress/test_pageserver_secondary.py +++ b/test_runner/regress/test_pageserver_secondary.py @@ -156,6 +156,45 @@ def test_location_conf_churn(neon_env_builder: NeonEnvBuilder, make_httpserver, env.pageservers[2].id: ("Detached", None), } + # Track all the attached locations with mode and generation + history: list[tuple[int, str, int | None]] = [] + + def may_read(pageserver: NeonPageserver, mode: str, generation: int | None) -> bool: + # Rules for when a pageserver may read: + # - our generation is higher than any previous + # - our generation is equal to previous, but no other pageserver + # in that generation has been AttachedSingle (i.e. allowed to compact/GC) + # - our generation is equal to previous, and the previous holder of this + # generation was the same node as we're attaching now. + # + # If these conditions are not met, then a read _might_ work, but the pageserver might + # also hit errors trying to download layers. + highest_historic_generation = max([i[2] for i in history if i[2] is not None], default=None) + + if generation is None: + # We're not in an attached state, we may not read + return False + elif highest_historic_generation is not None and generation < highest_historic_generation: + # We are in an outdated generation, we may not read + return False + elif highest_historic_generation is not None and generation == highest_historic_generation: + # We are re-using a generation: if any pageserver other than this one + # has held AttachedSingle mode, this node may not read (because some other + # node may be doing GC/compaction). + if any( + i[1] == "AttachedSingle" + and i[2] == highest_historic_generation + and i[0] != pageserver.id + for i in history + ): + log.info( + f"Skipping read on {pageserver.id} because other pageserver has been in AttachedSingle mode in generation {highest_historic_generation}" + ) + return False + + # Fall through: we have passed conditions for readability + return True + latest_attached = env.pageservers[0].id for _i in range(0, 64): @@ -199,9 +238,10 @@ def test_location_conf_churn(neon_env_builder: NeonEnvBuilder, make_httpserver, assert len(tenants) == 1 assert tenants[0]["generation"] == new_generation - log.info("Entering postgres...") - workload.churn_rows(rng.randint(128, 256), pageserver.id) - workload.validate(pageserver.id) + if may_read(pageserver, last_state_ps[0], last_state_ps[1]): + log.info("Entering postgres...") + workload.churn_rows(rng.randint(128, 256), pageserver.id) + workload.validate(pageserver.id) elif last_state_ps[0].startswith("Attached"): # The `storage_controller` will only re-attach on startup when a pageserver was the # holder of the latest generation: otherwise the pageserver will revert to detached @@ -241,18 +281,16 @@ def test_location_conf_churn(neon_env_builder: NeonEnvBuilder, make_httpserver, location_conf["generation"] = generation pageserver.tenant_location_configure(tenant_id, location_conf) + last_state[pageserver.id] = (mode, generation) - # It's only valid to connect to the last generation. Newer generations may yank layer - # files used in older generations. - last_generation = max( - [s[1] for s in last_state.values() if s[1] is not None], default=None - ) + may_read_this_generation = may_read(pageserver, mode, generation) + history.append((pageserver.id, mode, generation)) - if mode.startswith("Attached") and generation == last_generation: - # This is a basic test: we are validating that he endpoint works properly _between_ - # configuration changes. A stronger test would be to validate that clients see - # no errors while we are making the changes. + # This is a basic test: we are validating that he endpoint works properly _between_ + # configuration changes. A stronger test would be to validate that clients see + # no errors while we are making the changes. + if may_read_this_generation: workload.churn_rows( rng.randint(128, 256), pageserver.id, upload=mode != "AttachedStale" ) @@ -265,9 +303,16 @@ def test_location_conf_churn(neon_env_builder: NeonEnvBuilder, make_httpserver, assert gc_summary["remote_storage_errors"] == 0 assert gc_summary["indices_deleted"] > 0 - # Attach all pageservers + # Attach all pageservers, in a higher generation than any previous. We will use the same + # gen for all, and AttachedMulti mode so that they do not interfere with one another. + generation = env.storage_controller.attach_hook_issue(tenant_id, env.pageservers[0].id) for ps in env.pageservers: - location_conf = {"mode": "AttachedMulti", "secondary_conf": None, "tenant_conf": {}} + location_conf = { + "mode": "AttachedMulti", + "secondary_conf": None, + "tenant_conf": {}, + "generation": generation, + } ps.tenant_location_configure(tenant_id, location_conf) # Confirm that all are readable From 4a4a457312c6f39ee4dec137b599e92cbb7647ab Mon Sep 17 00:00:00 2001 From: Christian Schwarz Date: Fri, 30 May 2025 13:22:37 +0200 Subject: [PATCH 30/48] fix(pageserver): frozen->L0 flush failure causes data loss (#12043) This patch is a fixup for - https://github.com/neondatabase/neon/pull/6788 Background ---------- That PR 6788 added artificial advancement of `disk_consistent_lsn` and `remote_consistent_lsn` for shards that weren't written to while other shards _were_ written to. See the PR description for more context. At the time of that PR, Pageservers shards were doing WAL filtering. Nowadays, the WAL filtering happens in Safekeepers. Shards learn about the WAL gaps via `InterpretedWalRecords::next_record_lsn`. The Bug ------- That artificial advancement code also runs if the flush failed. So, we advance the disk_consistent_lsn / remote_consistent_lsn, without having the corresponding L0 to the `index_part.json`. The frozen layer remains in the layer map until detach, so we continue to serve data correctly. We're not advancing flush loop variable `flushed_to_lsn` either, so, subsequent flush requests will retry the flush and repair the situation if they succeed. But if there aren't any successful retries, eventually the tenant will be detached and when it is attached somewhere else, the `index_part.json` and therefore layer map... 1. ... does not contain the frozen layer that failed to flush and 2. ... won't re-ingest that WAL either because walreceiver starts up with the advanced disk_consistent_lsn/remote_consistent_lsn. The result is that the read path will have a gap in the reconstruct data for the keys whose modifications were lost, resulting in a) either walredo failure b) or an incorrect page@lsn image if walredo doesn't error. The Fix ------- The fix is to only do the artificial advancement if `result.is_ok()`. Misc ---- As an aside, I took some time to re-review the flush loop and its callers. I found one more bug related to error handling that I filed here: - https://github.com/neondatabase/neon/issues/12025 ## Problem ## Summary of changes --- pageserver/src/tenant.rs | 77 ++++++++++++++++++++++++++++++- pageserver/src/tenant/timeline.rs | 9 +++- 2 files changed, 83 insertions(+), 3 deletions(-) diff --git a/pageserver/src/tenant.rs b/pageserver/src/tenant.rs index 3a054aff83..308ada3fa1 100644 --- a/pageserver/src/tenant.rs +++ b/pageserver/src/tenant.rs @@ -5832,6 +5832,7 @@ pub(crate) mod harness { pub conf: &'static PageServerConf, pub tenant_conf: pageserver_api::models::TenantConfig, pub tenant_shard_id: TenantShardId, + pub shard_identity: ShardIdentity, pub generation: Generation, pub shard: ShardIndex, pub remote_storage: GenericRemoteStorage, @@ -5899,6 +5900,7 @@ pub(crate) mod harness { conf, tenant_conf, tenant_shard_id, + shard_identity, generation, shard, remote_storage, @@ -5960,8 +5962,7 @@ pub(crate) mod harness { &ShardParameters::default(), )) .unwrap(), - // This is a legacy/test code path: sharding isn't supported here. - ShardIdentity::unsharded(), + self.shard_identity, Some(walredo_mgr), self.tenant_shard_id, self.remote_storage.clone(), @@ -6083,6 +6084,7 @@ mod tests { use timeline::compaction::{KeyHistoryRetention, KeyLogAtLsn}; use timeline::{CompactOptions, DeltaLayerTestDesc, VersionedKeySpaceQuery}; use utils::id::TenantId; + use utils::shard::{ShardCount, ShardNumber}; use super::*; use crate::DEFAULT_PG_VERSION; @@ -9418,6 +9420,77 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_failed_flush_should_not_update_disk_consistent_lsn() -> anyhow::Result<()> { + // + // Setup + // + let harness = TenantHarness::create_custom( + "test_failed_flush_should_not_upload_disk_consistent_lsn", + pageserver_api::models::TenantConfig::default(), + TenantId::generate(), + ShardIdentity::new(ShardNumber(0), ShardCount(4), ShardStripeSize(128)).unwrap(), + Generation::new(1), + ) + .await?; + let (tenant, ctx) = harness.load().await; + + let timeline = tenant + .create_test_timeline(TIMELINE_ID, Lsn(0x10), DEFAULT_PG_VERSION, &ctx) + .await?; + assert_eq!(timeline.get_shard_identity().count, ShardCount(4)); + let mut writer = timeline.writer().await; + writer + .put( + *TEST_KEY, + Lsn(0x20), + &Value::Image(test_img("foo at 0x20")), + &ctx, + ) + .await?; + writer.finish_write(Lsn(0x20)); + drop(writer); + timeline.freeze_and_flush().await.unwrap(); + + timeline.remote_client.wait_completion().await.unwrap(); + let disk_consistent_lsn = timeline.get_disk_consistent_lsn(); + let remote_consistent_lsn = timeline.get_remote_consistent_lsn_projected(); + assert_eq!(Some(disk_consistent_lsn), remote_consistent_lsn); + + // + // Test + // + + let mut writer = timeline.writer().await; + writer + .put( + *TEST_KEY, + Lsn(0x30), + &Value::Image(test_img("foo at 0x30")), + &ctx, + ) + .await?; + writer.finish_write(Lsn(0x30)); + drop(writer); + + fail::cfg( + "flush-layer-before-update-remote-consistent-lsn", + "return()", + ) + .unwrap(); + + let flush_res = timeline.freeze_and_flush().await; + // if flush failed, the disk/remote consistent LSN should not be updated + assert!(flush_res.is_err()); + assert_eq!(disk_consistent_lsn, timeline.get_disk_consistent_lsn()); + assert_eq!( + remote_consistent_lsn, + timeline.get_remote_consistent_lsn_projected() + ); + + Ok(()) + } + #[cfg(feature = "testing")] #[tokio::test] async fn test_simple_bottom_most_compaction_deltas_1() -> anyhow::Result<()> { diff --git a/pageserver/src/tenant/timeline.rs b/pageserver/src/tenant/timeline.rs index 71765b9197..23c40a7629 100644 --- a/pageserver/src/tenant/timeline.rs +++ b/pageserver/src/tenant/timeline.rs @@ -4767,7 +4767,10 @@ impl Timeline { || !flushed_to_lsn.is_valid() ); - if flushed_to_lsn < frozen_to_lsn && self.shard_identity.count.count() > 1 { + if flushed_to_lsn < frozen_to_lsn + && self.shard_identity.count.count() > 1 + && result.is_ok() + { // If our layer flushes didn't carry disk_consistent_lsn up to the `to_lsn` advertised // to us via layer_flush_start_rx, then advance it here. // @@ -4946,6 +4949,10 @@ impl Timeline { return Err(FlushLayerError::Cancelled); } + fail_point!("flush-layer-before-update-remote-consistent-lsn", |_| { + Err(FlushLayerError::Other(anyhow!("failpoint").into())) + }); + let disk_consistent_lsn = Lsn(lsn_range.end.0 - 1); // The new on-disk layers are now in the layer map. We can remove the From 99726495c79941efdf5ccc695ccb6a6ad046ac7e Mon Sep 17 00:00:00 2001 From: Vlad Lazar Date: Fri, 30 May 2025 13:14:36 +0100 Subject: [PATCH 31/48] test: allow list overly eager storcon finalization (#12055) ## Problem I noticed a small percentage of flakes on some import tests. They were all instances of the storage controller being too eager on the finalization. As a refresher: the pageserver notifies the storage controller that it's done from the import task and the storage controller has to call back into it in order to finalize the import. The pageserver checks that the import task is done before serving that request. Hence, we can get this race. In practice, this has no impact since the storage controller will simply retry. ## Summary of changes Allow list such cases --- test_runner/regress/test_import_pgdata.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/test_runner/regress/test_import_pgdata.py b/test_runner/regress/test_import_pgdata.py index ba60c3caa6..8d4f908cc0 100644 --- a/test_runner/regress/test_import_pgdata.py +++ b/test_runner/regress/test_import_pgdata.py @@ -287,6 +287,17 @@ def test_pgdata_import_smoke( with pytest.raises(psycopg2.errors.UndefinedTable): br_initdb_endpoint.safe_psql(f"select * from {workload.table}") + # The storage controller might be overly eager and attempt to finalize + # the import before the task got a chance to exit. + env.storage_controller.allowed_errors.extend( + [ + ".*Call to node.*management API.*failed.*Import task still running.*", + ] + ) + + for ps in env.pageservers: + ps.allowed_errors.extend([".*Error processing HTTP request.*Import task not done yet.*"]) + @run_only_on_default_postgres(reason="PG version is irrelevant here") def test_import_completion_on_restart( @@ -471,6 +482,17 @@ def test_import_respects_timeline_lifecycle( else: raise RuntimeError(f"{action} param not recognized") + # The storage controller might be overly eager and attempt to finalize + # the import before the task got a chance to exit. + env.storage_controller.allowed_errors.extend( + [ + ".*Call to node.*management API.*failed.*Import task still running.*", + ] + ) + + for ps in env.pageservers: + ps.allowed_errors.extend([".*Error processing HTTP request.*Import task not done yet.*"]) + @skip_in_debug_build("Validation query takes too long in debug builds") def test_import_chaos( From 6d95a3fe2dcef3c8f83f5b473fee6cd3f70dd209 Mon Sep 17 00:00:00 2001 From: Vlad Lazar Date: Fri, 30 May 2025 13:30:11 +0100 Subject: [PATCH 32/48] pageserver: various import flow fixups (#12047) ## Problem There's a bunch of TODOs in the import code. ## Summary of changes 1. Bound max import byte range to 128MiB. This might still be too high, given the default job concurrency, but it needs to be balanced with going back and forth to S3. 2. Prevent unsigned overflow when determining key range splits for concurrent jobs 3. Use sharded ranges to estimate task size when splitting jobs 4. Bubble up errors that we might hit due to invalid data in the bucket back to the storage controller. 5. Tweak the import bucket S3 client configuration. --- .../src/tenant/timeline/import_pgdata/flow.rs | 68 +++++++++++-------- .../import_pgdata/importbucket_client.rs | 19 +++--- 2 files changed, 50 insertions(+), 37 deletions(-) diff --git a/pageserver/src/tenant/timeline/import_pgdata/flow.rs b/pageserver/src/tenant/timeline/import_pgdata/flow.rs index 9743aa3f26..bf3c7eeda6 100644 --- a/pageserver/src/tenant/timeline/import_pgdata/flow.rs +++ b/pageserver/src/tenant/timeline/import_pgdata/flow.rs @@ -11,19 +11,7 @@ //! - => S3 as the source for the PGDATA instead of local filesystem //! //! TODOs before productionization: -//! - ChunkProcessingJob size / ImportJob::total_size does not account for sharding. -//! => produced image layers likely too small. //! - ChunkProcessingJob should cut up an ImportJob to hit exactly target image layer size. -//! - asserts / unwraps need to be replaced with errors -//! - don't trust remote objects will be small (=prevent OOMs in those cases) -//! - limit all in-memory buffers in size, or download to disk and read from there -//! - limit task concurrency -//! - generally play nice with other tenants in the system -//! - importbucket is different bucket than main pageserver storage, so, should be fine wrt S3 rate limits -//! - but concerns like network bandwidth, local disk write bandwidth, local disk capacity, etc -//! - integrate with layer eviction system -//! - audit for Tenant::cancel nor Timeline::cancel responsivity -//! - audit for Tenant/Timeline gate holding (we spawn tokio tasks during this flow!) //! //! An incomplete set of TODOs from the Hackathon: //! - version-specific CheckPointData (=> pgv abstraction, already exists for regular walingest) @@ -44,7 +32,7 @@ use pageserver_api::key::{ rel_dir_to_key, rel_size_to_key, relmap_file_key, slru_block_to_key, slru_dir_to_key, slru_segment_size_to_key, }; -use pageserver_api::keyspace::{contiguous_range_len, is_contiguous_range, singleton_range}; +use pageserver_api::keyspace::{ShardedRange, singleton_range}; use pageserver_api::models::{ShardImportProgress, ShardImportProgressV1, ShardImportStatus}; use pageserver_api::reltag::{RelTag, SlruKind}; use pageserver_api::shard::ShardIdentity; @@ -167,6 +155,7 @@ impl Planner { /// This function is and must remain pure: given the same input, it will generate the same import plan. async fn plan(mut self, import_config: &TimelineImportConfig) -> anyhow::Result { let pgdata_lsn = Lsn(self.control_file.control_file_data().checkPoint).align(); + anyhow::ensure!(pgdata_lsn.is_valid()); let datadir = PgDataDir::new(&self.storage).await?; @@ -249,14 +238,22 @@ impl Planner { }); // Assigns parts of key space to later parallel jobs + // Note: The image layers produced here may have gaps, meaning, + // there is not an image for each key in the layer's key range. + // The read path stops traversal at the first image layer, regardless + // of whether a base image has been found for a key or not. + // (Concept of sparse image layers doesn't exist.) + // This behavior is exactly right for the base image layers we're producing here. + // But, since no other place in the code currently produces image layers with gaps, + // it seems noteworthy. let mut last_end_key = Key::MIN; let mut current_chunk = Vec::new(); let mut current_chunk_size: usize = 0; let mut jobs = Vec::new(); for task in std::mem::take(&mut self.tasks).into_iter() { - if current_chunk_size + task.total_size() - > import_config.import_job_soft_size_limit.into() - { + let task_size = task.total_size(&self.shard); + let projected_chunk_size = current_chunk_size.saturating_add(task_size); + if projected_chunk_size > import_config.import_job_soft_size_limit.into() { let key_range = last_end_key..task.key_range().start; jobs.push(ChunkProcessingJob::new( key_range.clone(), @@ -266,7 +263,7 @@ impl Planner { last_end_key = key_range.end; current_chunk_size = 0; } - current_chunk_size += task.total_size(); + current_chunk_size = current_chunk_size.saturating_add(task_size); current_chunk.push(task); } jobs.push(ChunkProcessingJob::new( @@ -604,18 +601,18 @@ impl PgDataDirDb { }; let path = datadir_path.join(rel_tag.to_segfile_name(segno)); - assert!(filesize % BLCKSZ as usize == 0); // TODO: this should result in an error + anyhow::ensure!(filesize % BLCKSZ as usize == 0); let nblocks = filesize / BLCKSZ as usize; - PgDataDirDbFile { + Ok(PgDataDirDbFile { path, filesize, rel_tag, segno, nblocks: Some(nblocks), // first non-cummulative sizes - } + }) }) - .collect(); + .collect::>()?; // Set cummulative sizes. Do all of that math here, so that later we could easier // parallelize over segments and know with which segments we need to write relsize @@ -650,12 +647,22 @@ impl PgDataDirDb { trait ImportTask { fn key_range(&self) -> Range; - fn total_size(&self) -> usize { - // TODO: revisit this - if is_contiguous_range(&self.key_range()) { - contiguous_range_len(&self.key_range()) as usize * 8192 + fn total_size(&self, shard_identity: &ShardIdentity) -> usize { + let range = ShardedRange::new(self.key_range(), shard_identity); + let page_count = range.page_count(); + if page_count == u32::MAX { + tracing::warn!( + "Import task has non contiguous key range: {}..{}", + self.key_range().start, + self.key_range().end + ); + + // Tasks should operate on contiguous ranges. It is unexpected for + // ranges to violate this assumption. Calling code handles this by mapping + // any task on a non contiguous range to its own image layer. + usize::MAX } else { - u32::MAX as usize + page_count as usize * 8192 } } @@ -753,6 +760,8 @@ impl ImportTask for ImportRelBlocksTask { layer_writer: &mut ImageLayerWriter, ctx: &RequestContext, ) -> anyhow::Result { + const MAX_BYTE_RANGE_SIZE: usize = 128 * 1024 * 1024; + debug!("Importing relation file"); let (rel_tag, start_blk) = self.key_range.start.to_rel_block()?; @@ -777,7 +786,7 @@ impl ImportTask for ImportRelBlocksTask { assert_eq!(key.len(), 1); assert!(!acc.is_empty()); assert!(acc_end > acc_start); - if acc_end == start /* TODO additional max range check here, to limit memory consumption per task to X */ { + if acc_end == start && end - acc_start <= MAX_BYTE_RANGE_SIZE { acc.push(key.pop().unwrap()); Ok((acc, acc_start, end)) } else { @@ -792,8 +801,8 @@ impl ImportTask for ImportRelBlocksTask { .get_range(&self.path, range_start.into_u64(), range_end.into_u64()) .await?; let mut buf = Bytes::from(range_buf); - // TODO: batched writes for key in keys { + // The writer buffers writes internally let image = buf.split_to(8192); layer_writer.put_image(key, image, ctx).await?; nimages += 1; @@ -846,6 +855,9 @@ impl ImportTask for ImportSlruBlocksTask { debug!("Importing SLRU segment file {}", self.path); let buf = self.storage.get(&self.path).await?; + // TODO(vlad): Does timestamp to LSN work for imported timelines? + // Probably not since we don't append the `xact_time` to it as in + // [`WalIngest::ingest_xact_record`]. let (kind, segno, start_blk) = self.key_range.start.to_slru_block()?; let (_kind, _segno, end_blk) = self.key_range.end.to_slru_block()?; let mut blknum = start_blk; diff --git a/pageserver/src/tenant/timeline/import_pgdata/importbucket_client.rs b/pageserver/src/tenant/timeline/import_pgdata/importbucket_client.rs index 34313748b7..bf2d9875c1 100644 --- a/pageserver/src/tenant/timeline/import_pgdata/importbucket_client.rs +++ b/pageserver/src/tenant/timeline/import_pgdata/importbucket_client.rs @@ -6,7 +6,7 @@ use bytes::Bytes; use postgres_ffi::ControlFileData; use remote_storage::{ Download, DownloadError, DownloadKind, DownloadOpts, GenericRemoteStorage, Listing, - ListingObject, RemotePath, + ListingObject, RemotePath, RemoteStorageConfig, }; use serde::de::DeserializeOwned; use tokio_util::sync::CancellationToken; @@ -22,11 +22,9 @@ pub async fn new( location: &index_part_format::Location, cancel: CancellationToken, ) -> Result { - // FIXME: we probably want some timeout, and we might be able to assume the max file - // size on S3 is 1GiB (postgres segment size). But the problem is that the individual - // downloaders don't know enough about concurrent downloads to make a guess on the - // expected bandwidth and resulting best timeout. - let timeout = std::time::Duration::from_secs(24 * 60 * 60); + // Downloads should be reasonably sized. We do ranged reads for relblock raw data + // and full reads for SLRU segments which are bounded by Postgres. + let timeout = RemoteStorageConfig::DEFAULT_TIMEOUT; let location_storage = match location { #[cfg(feature = "testing")] index_part_format::Location::LocalFs { path } => { @@ -50,9 +48,12 @@ pub async fn new( .import_pgdata_aws_endpoint_url .clone() .map(|url| url.to_string()), // by specifying None here, remote_storage/aws-sdk-rust will infer from env - concurrency_limit: 100.try_into().unwrap(), // TODO: think about this - max_keys_per_list_response: Some(1000), // TODO: think about this - upload_storage_class: None, // irrelevant + // This matches the default import job concurrency. This is managed + // separately from the usual S3 client, but the concern here is bandwidth + // usage. + concurrency_limit: 128.try_into().unwrap(), + max_keys_per_list_response: Some(1000), + upload_storage_class: None, // irrelevant }, timeout, ) From 35372a8f12ae143da475cc0b5de1529d5c05804e Mon Sep 17 00:00:00 2001 From: Christian Schwarz Date: Fri, 30 May 2025 15:22:53 +0200 Subject: [PATCH 33/48] adjust VirtualFile operation latency histogram buckets (#12075) The expected operating range for the production NVMe drives is in the range of 50 to 250us. The bucket boundaries before this PR were not well suited to reason about the utilization / queuing / latency variability of those devices. # Performance There was some concern about perf impact of having so many buckets, considering the impl does a linear search on each observe(). I added a benchmark and measured on relevant machines. In any way, the PR is 40 buckets, so, won't make a meaningful difference on production machines (im4gn.2xlarge), going from 30ns -> 35ns. --- libs/metrics/src/lib.rs | 1 + pageserver/benches/bench_metrics.rs | 72 ++++++++++++++++++++++++++++- pageserver/src/metrics.rs | 43 +++++++++++++++-- 3 files changed, 110 insertions(+), 6 deletions(-) diff --git a/libs/metrics/src/lib.rs b/libs/metrics/src/lib.rs index 4df8d7bc51..5d028ee041 100644 --- a/libs/metrics/src/lib.rs +++ b/libs/metrics/src/lib.rs @@ -27,6 +27,7 @@ pub use prometheus::{ pub mod launch_timestamp; mod wrappers; +pub use prometheus; pub use wrappers::{CountedReader, CountedWriter}; mod hll; pub use hll::{HyperLogLog, HyperLogLogState, HyperLogLogVec}; diff --git a/pageserver/benches/bench_metrics.rs b/pageserver/benches/bench_metrics.rs index 38025124e1..e0428f6372 100644 --- a/pageserver/benches/bench_metrics.rs +++ b/pageserver/benches/bench_metrics.rs @@ -264,10 +264,56 @@ mod propagation_of_cached_label_value { } } +criterion_group!(histograms, histograms::bench_bucket_scalability); +mod histograms { + use std::time::Instant; + + use criterion::{BenchmarkId, Criterion}; + use metrics::core::Collector; + + pub fn bench_bucket_scalability(c: &mut Criterion) { + let mut g = c.benchmark_group("bucket_scalability"); + + for n in [1, 4, 8, 16, 32, 64, 128, 256] { + g.bench_with_input(BenchmarkId::new("nbuckets", n), &n, |b, n| { + b.iter_custom(|iters| { + let buckets: Vec = (0..*n).map(|i| i as f64 * 100.0).collect(); + let histo = metrics::Histogram::with_opts( + metrics::prometheus::HistogramOpts::new("name", "help") + .buckets(buckets.clone()), + ) + .unwrap(); + let start = Instant::now(); + for i in 0..usize::try_from(iters).unwrap() { + histo.observe(buckets[i % buckets.len()]); + } + let elapsed = start.elapsed(); + // self-test + let mfs = histo.collect(); + assert_eq!(mfs.len(), 1); + let metrics = mfs[0].get_metric(); + assert_eq!(metrics.len(), 1); + let histo = metrics[0].get_histogram(); + let buckets = histo.get_bucket(); + assert!( + buckets + .iter() + .enumerate() + .all(|(i, b)| b.get_cumulative_count() + >= i as u64 * (iters / buckets.len() as u64)) + ); + elapsed + }) + }); + } + } +} + criterion_main!( label_values, single_metric_multicore_scalability, - propagation_of_cached_label_value + propagation_of_cached_label_value, + histograms, ); /* @@ -290,6 +336,14 @@ propagation_of_cached_label_value__naive/nthreads/8 time: [211.50 ns 214.44 ns propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/1 time: [14.135 ns 14.147 ns 14.160 ns] propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/4 time: [14.243 ns 14.255 ns 14.268 ns] propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/8 time: [14.470 ns 14.682 ns 14.895 ns] +bucket_scalability/nbuckets/1 time: [30.352 ns 30.353 ns 30.354 ns] +bucket_scalability/nbuckets/4 time: [30.464 ns 30.465 ns 30.467 ns] +bucket_scalability/nbuckets/8 time: [30.569 ns 30.575 ns 30.584 ns] +bucket_scalability/nbuckets/16 time: [30.961 ns 30.965 ns 30.969 ns] +bucket_scalability/nbuckets/32 time: [35.691 ns 35.707 ns 35.722 ns] +bucket_scalability/nbuckets/64 time: [47.829 ns 47.898 ns 47.974 ns] +bucket_scalability/nbuckets/128 time: [73.479 ns 73.512 ns 73.545 ns] +bucket_scalability/nbuckets/256 time: [127.92 ns 127.94 ns 127.96 ns] Results on an i3en.3xlarge instance @@ -344,6 +398,14 @@ propagation_of_cached_label_value__naive/nthreads/8 time: [434.87 ns 456.4 propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/1 time: [3.3767 ns 3.3974 ns 3.4220 ns] propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/4 time: [3.6105 ns 4.2355 ns 5.1463 ns] propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/8 time: [4.0889 ns 4.9714 ns 6.0779 ns] +bucket_scalability/nbuckets/1 time: [4.8455 ns 4.8542 ns 4.8646 ns] +bucket_scalability/nbuckets/4 time: [4.5663 ns 4.5722 ns 4.5787 ns] +bucket_scalability/nbuckets/8 time: [4.5531 ns 4.5670 ns 4.5842 ns] +bucket_scalability/nbuckets/16 time: [4.6392 ns 4.6524 ns 4.6685 ns] +bucket_scalability/nbuckets/32 time: [6.0302 ns 6.0439 ns 6.0589 ns] +bucket_scalability/nbuckets/64 time: [10.608 ns 10.644 ns 10.691 ns] +bucket_scalability/nbuckets/128 time: [22.178 ns 22.316 ns 22.483 ns] +bucket_scalability/nbuckets/256 time: [42.190 ns 42.328 ns 42.492 ns] Results on a Hetzner AX102 AMD Ryzen 9 7950X3D 16-Core Processor @@ -362,5 +424,13 @@ propagation_of_cached_label_value__naive/nthreads/8 time: [164.24 ns 170.1 propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/1 time: [2.2915 ns 2.2960 ns 2.3012 ns] propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/4 time: [2.5726 ns 2.6158 ns 2.6624 ns] propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/8 time: [2.7068 ns 2.8243 ns 2.9824 ns] +bucket_scalability/nbuckets/1 time: [6.3998 ns 6.4288 ns 6.4684 ns] +bucket_scalability/nbuckets/4 time: [6.3603 ns 6.3620 ns 6.3637 ns] +bucket_scalability/nbuckets/8 time: [6.1646 ns 6.1654 ns 6.1667 ns] +bucket_scalability/nbuckets/16 time: [6.1341 ns 6.1391 ns 6.1454 ns] +bucket_scalability/nbuckets/32 time: [8.2206 ns 8.2254 ns 8.2301 ns] +bucket_scalability/nbuckets/64 time: [13.988 ns 13.994 ns 14.000 ns] +bucket_scalability/nbuckets/128 time: [28.180 ns 28.216 ns 28.251 ns] +bucket_scalability/nbuckets/256 time: [54.914 ns 54.931 ns 54.951 ns] */ diff --git a/pageserver/src/metrics.rs b/pageserver/src/metrics.rs index 0ff31dcb8a..a9b2f1b7e0 100644 --- a/pageserver/src/metrics.rs +++ b/pageserver/src/metrics.rs @@ -1312,11 +1312,44 @@ impl EvictionsWithLowResidenceDuration { // // Roughly logarithmic scale. const STORAGE_IO_TIME_BUCKETS: &[f64] = &[ - 0.000030, // 30 usec - 0.001000, // 1000 usec - 0.030, // 30 ms - 1.000, // 1000 ms - 30.000, // 30000 ms + 0.00005, // 50us + 0.00006, // 60us + 0.00007, // 70us + 0.00008, // 80us + 0.00009, // 90us + 0.0001, // 100us + 0.000110, // 110us + 0.000120, // 120us + 0.000130, // 130us + 0.000140, // 140us + 0.000150, // 150us + 0.000160, // 160us + 0.000170, // 170us + 0.000180, // 180us + 0.000190, // 190us + 0.000200, // 200us + 0.000210, // 210us + 0.000220, // 220us + 0.000230, // 230us + 0.000240, // 240us + 0.000250, // 250us + 0.000300, // 300us + 0.000350, // 350us + 0.000400, // 400us + 0.000450, // 450us + 0.000500, // 500us + 0.000600, // 600us + 0.000700, // 700us + 0.000800, // 800us + 0.000900, // 900us + 0.001000, // 1ms + 0.002000, // 2ms + 0.003000, // 3ms + 0.004000, // 4ms + 0.005000, // 5ms + 0.01000, // 10ms + 0.02000, // 20ms + 0.05000, // 50ms ]; /// VirtualFile fs operation variants. From 8d26978ed9ede0273ece491fca21104c0ba63835 Mon Sep 17 00:00:00 2001 From: Alexander Lakhin Date: Fri, 30 May 2025 18:20:46 +0300 Subject: [PATCH 34/48] Allow known pageserver errors in test_location_conf_churn (#12082) ## Problem While a pageserver in the unreadable state could not be accessed by postgres thanks to https://github.com/neondatabase/neon/pull/12059, it may still receive WAL records and bump into the "layer file download failed: No file found" error when trying to ingest them. Closes: https://github.com/neondatabase/neon/issues/11348 ## Summary of changes Allow errors from wal_connection_manager, which are considered expected. See https://github.com/neondatabase/neon/issues/11348. --- test_runner/regress/test_pageserver_secondary.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test_runner/regress/test_pageserver_secondary.py b/test_runner/regress/test_pageserver_secondary.py index e5908de363..8d18311f3d 100644 --- a/test_runner/regress/test_pageserver_secondary.py +++ b/test_runner/regress/test_pageserver_secondary.py @@ -124,6 +124,9 @@ def test_location_conf_churn(neon_env_builder: NeonEnvBuilder, make_httpserver, ".*downloading failed, possibly for shutdown", # {tenant_id=... timeline_id=...}:handle_pagerequests:handle_get_page_at_lsn_request{rel=1664/0/1260 blkno=0 req_lsn=0/149F0D8}: error reading relation or page version: Not found: will not become active. Current state: Stopping\n' ".*page_service.*will not become active.*", + # the following errors are possible when pageserver tries to ingest wal records despite being in unreadable state + ".*wal_connection_manager.*layer file download failed: No file found.*", + ".*wal_connection_manager.*could not ingest record.*", ] ) From 62cd3b8d3d60da5c72224952da20e2ad3d494a11 Mon Sep 17 00:00:00 2001 From: Shockingly Good Date: Fri, 30 May 2025 17:26:22 +0200 Subject: [PATCH 35/48] fix(compute) Remove the hardcoded default value for PGXN HTTP URL. (#12030) Removes the hardcoded value for the Postgres Extensions HTTP gateway URL as it is always provided by the calling code. --- compute_tools/src/bin/compute_ctl.rs | 31 +--------------------------- 1 file changed, 1 insertion(+), 30 deletions(-) diff --git a/compute_tools/src/bin/compute_ctl.rs b/compute_tools/src/bin/compute_ctl.rs index 02339f752c..f9d9c03422 100644 --- a/compute_tools/src/bin/compute_ctl.rs +++ b/compute_tools/src/bin/compute_ctl.rs @@ -57,21 +57,6 @@ use tracing::{error, info}; use url::Url; use utils::failpoint_support; -// Compatibility hack: if the control plane specified any remote-ext-config -// use the default value for extension storage proxy gateway. -// Remove this once the control plane is updated to pass the gateway URL -fn parse_remote_ext_base_url(arg: &str) -> Result { - const FALLBACK_PG_EXT_GATEWAY_BASE_URL: &str = - "http://pg-ext-s3-gateway.pg-ext-s3-gateway.svc.cluster.local"; - - Ok(if arg.starts_with("http") { - arg - } else { - FALLBACK_PG_EXT_GATEWAY_BASE_URL - } - .to_owned()) -} - #[derive(Parser)] #[command(rename_all = "kebab-case")] struct Cli { @@ -80,7 +65,7 @@ struct Cli { /// The base URL for the remote extension storage proxy gateway. /// Should be in the form of `http(s)://[:]`. - #[arg(short = 'r', long, value_parser = parse_remote_ext_base_url, alias = "remote-ext-config")] + #[arg(short = 'r', long, alias = "remote-ext-config")] pub remote_ext_base_url: Option, /// The port to bind the external listening HTTP server to. Clients running @@ -276,18 +261,4 @@ mod test { fn verify_cli() { Cli::command().debug_assert() } - - #[test] - fn parse_pg_ext_gateway_base_url() { - let arg = "http://pg-ext-s3-gateway2"; - let result = super::parse_remote_ext_base_url(arg).unwrap(); - assert_eq!(result, arg); - - let arg = "pg-ext-s3-gateway"; - let result = super::parse_remote_ext_base_url(arg).unwrap(); - assert_eq!( - result, - "http://pg-ext-s3-gateway.pg-ext-s3-gateway.svc.cluster.local" - ); - } } From f6c0f6c4ecbc2bc5beba7d57d8b44ed71d8300c7 Mon Sep 17 00:00:00 2001 From: "Alex Chi Z." <4198311+skyzh@users.noreply.github.com> Date: Sat, 31 May 2025 01:00:41 +0800 Subject: [PATCH 36/48] fix(ci): install build tools with --locked (#12083) ## Problem Release pipeline failing due to some tools cannot be installed. ## Summary of changes Install with `--locked`. Signed-off-by: Alex Chi Z --- build-tools.Dockerfile | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/build-tools.Dockerfile b/build-tools.Dockerfile index 9d4c93e1cd..f97f04968e 100644 --- a/build-tools.Dockerfile +++ b/build-tools.Dockerfile @@ -310,13 +310,13 @@ RUN curl -sSO https://static.rust-lang.org/rustup/dist/$(uname -m)-unknown-linux . "$HOME/.cargo/env" && \ cargo --version && rustup --version && \ rustup component add llvm-tools rustfmt clippy && \ - cargo install rustfilt --version ${RUSTFILT_VERSION} && \ - cargo install cargo-hakari --version ${CARGO_HAKARI_VERSION} && \ - cargo install cargo-deny --locked --version ${CARGO_DENY_VERSION} && \ - cargo install cargo-hack --version ${CARGO_HACK_VERSION} && \ - cargo install cargo-nextest --version ${CARGO_NEXTEST_VERSION} && \ - cargo install cargo-chef --locked --version ${CARGO_CHEF_VERSION} && \ - cargo install diesel_cli --version ${CARGO_DIESEL_CLI_VERSION} \ + cargo install rustfilt --version ${RUSTFILT_VERSION} --locked && \ + cargo install cargo-hakari --version ${CARGO_HAKARI_VERSION} --locked && \ + cargo install cargo-deny --version ${CARGO_DENY_VERSION} --locked && \ + cargo install cargo-hack --version ${CARGO_HACK_VERSION} --locked && \ + cargo install cargo-nextest --version ${CARGO_NEXTEST_VERSION} --locked && \ + cargo install cargo-chef --version ${CARGO_CHEF_VERSION} --locked && \ + cargo install diesel_cli --version ${CARGO_DIESEL_CLI_VERSION} --locked \ --features postgres-bundled --no-default-features && \ rm -rf /home/nonroot/.cargo/registry && \ rm -rf /home/nonroot/.cargo/git From f05df409bd63424f0c5fb0efc14547fcad854c8b Mon Sep 17 00:00:00 2001 From: Shockingly Good Date: Fri, 30 May 2025 19:45:24 +0200 Subject: [PATCH 37/48] impr(compute): Remove the deprecated CLI arg alias for remote-ext-config. (#12087) Also moves it from `String` to `Url`. --- compute_tools/src/bin/compute_ctl.rs | 5 ++--- compute_tools/src/compute.rs | 3 ++- compute_tools/src/extension_server.rs | 5 +++-- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/compute_tools/src/bin/compute_ctl.rs b/compute_tools/src/bin/compute_ctl.rs index f9d9c03422..db6835da61 100644 --- a/compute_tools/src/bin/compute_ctl.rs +++ b/compute_tools/src/bin/compute_ctl.rs @@ -64,9 +64,8 @@ struct Cli { pub pgbin: String, /// The base URL for the remote extension storage proxy gateway. - /// Should be in the form of `http(s)://[:]`. - #[arg(short = 'r', long, alias = "remote-ext-config")] - pub remote_ext_base_url: Option, + #[arg(short = 'r', long)] + pub remote_ext_base_url: Option, /// The port to bind the external listening HTTP server to. Clients running /// outside the compute will talk to the compute through this port. Keep diff --git a/compute_tools/src/compute.rs b/compute_tools/src/compute.rs index ff49c737f0..d678b7d670 100644 --- a/compute_tools/src/compute.rs +++ b/compute_tools/src/compute.rs @@ -31,6 +31,7 @@ use std::time::{Duration, Instant}; use std::{env, fs}; use tokio::spawn; use tracing::{Instrument, debug, error, info, instrument, warn}; +use url::Url; use utils::id::{TenantId, TimelineId}; use utils::lsn::Lsn; use utils::measured_stream::MeasuredReader; @@ -96,7 +97,7 @@ pub struct ComputeNodeParams { pub internal_http_port: u16, /// the address of extension storage proxy gateway - pub remote_ext_base_url: Option, + pub remote_ext_base_url: Option, /// Interval for installed extensions collection pub installed_extensions_collection_interval: u64, diff --git a/compute_tools/src/extension_server.rs b/compute_tools/src/extension_server.rs index 3439383699..1857afa08c 100644 --- a/compute_tools/src/extension_server.rs +++ b/compute_tools/src/extension_server.rs @@ -83,6 +83,7 @@ use reqwest::StatusCode; use tar::Archive; use tracing::info; use tracing::log::warn; +use url::Url; use zstd::stream::read::Decoder; use crate::metrics::{REMOTE_EXT_REQUESTS_TOTAL, UNKNOWN_HTTP_STATUS}; @@ -158,14 +159,14 @@ fn parse_pg_version(human_version: &str) -> PostgresMajorVersion { pub async fn download_extension( ext_name: &str, ext_path: &RemotePath, - remote_ext_base_url: &str, + remote_ext_base_url: &Url, pgbin: &str, ) -> Result { info!("Download extension {:?} from {:?}", ext_name, ext_path); // TODO add retry logic let download_buffer = - match download_extension_tar(remote_ext_base_url, &ext_path.to_string()).await { + match download_extension_tar(remote_ext_base_url.as_str(), &ext_path.to_string()).await { Ok(buffer) => buffer, Err(error_message) => { return Err(anyhow::anyhow!( From 87179e26b3c18d9cd09b9eecbfed1db742b391ab Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Sun, 1 Jun 2025 19:41:45 +0100 Subject: [PATCH 38/48] completely rewrite pq_proto (#12085) libs/pqproto is designed for safekeeper/pageserver with maximum throughput. proxy only needs it for handshakes/authentication where throughput is not a concern but memory efficiency is. For this reason, we switch to using read_exact and only allocating as much memory as we need to. All reads return a `&'a [u8]` instead of a `Bytes` because accidental sharing of bytes can cause fragmentation. Returning the reference enforces all callers only hold onto the bytes they absolutely need. For example, before this change, `pqproto` was allocating 8KiB for the initial read `BytesMut`, and proxy was holding the `Bytes` in the `StartupMessageParams` for the entire connection through to passthrough. --- proxy/src/auth/backend/classic.rs | 26 +- proxy/src/auth/backend/console_redirect.rs | 16 +- proxy/src/auth/backend/hacks.rs | 30 +- proxy/src/auth/backend/mod.rs | 9 +- proxy/src/auth/credentials.rs | 2 +- proxy/src/auth/flow.rs | 118 ++-- proxy/src/binary/pg_sni_router.rs | 87 +-- proxy/src/binary/proxy.rs | 5 +- proxy/src/cancellation.rs | 2 +- proxy/src/compute.rs | 2 +- proxy/src/console_redirect_proxy.rs | 20 +- proxy/src/context/mod.rs | 2 +- proxy/src/context/parquet.rs | 2 +- proxy/src/lib.rs | 1 + proxy/src/pqproto.rs | 693 +++++++++++++++++++++ proxy/src/proxy/connect_compute.rs | 2 +- proxy/src/proxy/handshake.rs | 90 ++- proxy/src/proxy/mod.rs | 68 +- proxy/src/proxy/retry.rs | 3 +- proxy/src/proxy/tests/mitm.rs | 7 +- proxy/src/proxy/tests/mod.rs | 16 +- proxy/src/redis/cancellation_publisher.rs | 3 +- proxy/src/redis/keys.rs | 21 +- proxy/src/redis/notifications.rs | 10 - proxy/src/sasl/messages.rs | 22 - proxy/src/sasl/mod.rs | 6 +- proxy/src/sasl/stream.rs | 136 ++-- proxy/src/serverless/sql_over_http.rs | 4 +- proxy/src/stream.rs | 319 +++++----- 29 files changed, 1122 insertions(+), 600 deletions(-) create mode 100644 proxy/src/pqproto.rs diff --git a/proxy/src/auth/backend/classic.rs b/proxy/src/auth/backend/classic.rs index 5e494dfdd6..dcc500f2c8 100644 --- a/proxy/src/auth/backend/classic.rs +++ b/proxy/src/auth/backend/classic.rs @@ -17,35 +17,27 @@ pub(super) async fn authenticate( config: &'static AuthenticationConfig, secret: AuthSecret, ) -> auth::Result { - let flow = AuthFlow::new(client); let scram_keys = match secret { #[cfg(any(test, feature = "testing"))] AuthSecret::Md5(_) => { debug!("auth endpoint chooses MD5"); - return Err(auth::AuthError::bad_auth_method("MD5")); + return Err(auth::AuthError::MalformedPassword("MD5 not supported")); } AuthSecret::Scram(secret) => { debug!("auth endpoint chooses SCRAM"); let scram = auth::Scram(&secret, ctx); - let auth_outcome = tokio::time::timeout( - config.scram_protocol_timeout, - async { - - flow.begin(scram).await.map_err(|error| { - warn!(?error, "error sending scram acknowledgement"); - error - })?.authenticate().await.map_err(|error| { + let auth_outcome = tokio::time::timeout(config.scram_protocol_timeout, async { + AuthFlow::new(client, scram) + .authenticate() + .await + .inspect_err(|error| { warn!(?error, "error processing scram messages"); - error }) - } - ) + }) .await - .map_err(|e| { - warn!("error processing scram messages error = authentication timed out, execution time exceeded {} seconds", config.scram_protocol_timeout.as_secs()); - auth::AuthError::user_timeout(e) - })??; + .inspect_err(|_| warn!("error processing scram messages error = authentication timed out, execution time exceeded {} seconds", config.scram_protocol_timeout.as_secs())) + .map_err(auth::AuthError::user_timeout)??; let client_key = match auth_outcome { sasl::Outcome::Success(key) => key, diff --git a/proxy/src/auth/backend/console_redirect.rs b/proxy/src/auth/backend/console_redirect.rs index dd48384c03..a50c30257f 100644 --- a/proxy/src/auth/backend/console_redirect.rs +++ b/proxy/src/auth/backend/console_redirect.rs @@ -2,7 +2,6 @@ use std::fmt; use async_trait::async_trait; use postgres_client::config::SslMode; -use pq_proto::BeMessage as Be; use thiserror::Error; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::{info, info_span}; @@ -16,6 +15,7 @@ use crate::context::RequestContext; use crate::control_plane::client::cplane_proxy_v1; use crate::control_plane::{self, CachedNodeInfo, NodeInfo}; use crate::error::{ReportableError, UserFacingError}; +use crate::pqproto::BeMessage; use crate::proxy::NeonOptions; use crate::proxy::connect_compute::ComputeConnectBackend; use crate::stream::PqStream; @@ -154,11 +154,13 @@ async fn authenticate( // Give user a URL to spawn a new database. info!(parent: &span, "sending the auth URL to the user"); - client - .write_message_noflush(&Be::AuthenticationOk)? - .write_message_noflush(&Be::CLIENT_ENCODING)? - .write_message(&Be::NoticeResponse(&greeting)) - .await?; + client.write_message(BeMessage::AuthenticationOk); + client.write_message(BeMessage::ParameterStatus { + name: b"client_encoding", + value: b"UTF8", + }); + client.write_message(BeMessage::NoticeResponse(&greeting)); + client.flush().await?; // Wait for console response via control plane (see `mgmt`). info!(parent: &span, "waiting for console's reply..."); @@ -188,7 +190,7 @@ async fn authenticate( } } - client.write_message_noflush(&Be::NoticeResponse("Connecting to database."))?; + client.write_message(BeMessage::NoticeResponse("Connecting to database.")); // This config should be self-contained, because we won't // take username or dbname from client's startup message. diff --git a/proxy/src/auth/backend/hacks.rs b/proxy/src/auth/backend/hacks.rs index 3316543022..1e5c076fb9 100644 --- a/proxy/src/auth/backend/hacks.rs +++ b/proxy/src/auth/backend/hacks.rs @@ -24,23 +24,25 @@ pub(crate) async fn authenticate_cleartext( debug!("cleartext auth flow override is enabled, proceeding"); ctx.set_auth_method(crate::context::AuthMethod::Cleartext); - // pause the timer while we communicate with the client - let paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client); - let ep = EndpointIdInt::from(&info.endpoint); - let auth_flow = AuthFlow::new(client) - .begin(auth::CleartextPassword { + let auth_flow = AuthFlow::new( + client, + auth::CleartextPassword { secret, endpoint: ep, pool: config.thread_pool.clone(), - }) - .await?; - drop(paused); - // cleartext auth is only allowed to the ws/http protocol. - // If we're here, we already received the password in the first message. - // Scram protocol will be executed on the proxy side. - let auth_outcome = auth_flow.authenticate().await?; + }, + ); + let auth_outcome = { + // pause the timer while we communicate with the client + let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client); + + // cleartext auth is only allowed to the ws/http protocol. + // If we're here, we already received the password in the first message. + // Scram protocol will be executed on the proxy side. + auth_flow.authenticate().await? + }; let keys = match auth_outcome { sasl::Outcome::Success(key) => key, @@ -67,9 +69,7 @@ pub(crate) async fn password_hack_no_authentication( // pause the timer while we communicate with the client let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client); - let payload = AuthFlow::new(client) - .begin(auth::PasswordHack) - .await? + let payload = AuthFlow::new(client, auth::PasswordHack) .get_password() .await?; diff --git a/proxy/src/auth/backend/mod.rs b/proxy/src/auth/backend/mod.rs index 6e5c0a3954..8c892d90a0 100644 --- a/proxy/src/auth/backend/mod.rs +++ b/proxy/src/auth/backend/mod.rs @@ -31,6 +31,7 @@ use crate::control_plane::{ }; use crate::intern::EndpointIdInt; use crate::metrics::Metrics; +use crate::pqproto::BeMessage; use crate::protocol2::ConnectionInfoExtra; use crate::proxy::NeonOptions; use crate::proxy::connect_compute::ComputeConnectBackend; @@ -402,7 +403,7 @@ async fn authenticate_with_secret( }; // we have authenticated the password - client.write_message_noflush(&pq_proto::BeMessage::AuthenticationOk)?; + client.write_message(BeMessage::AuthenticationOk); return Ok(ComputeCredentials { info, keys }); } @@ -702,7 +703,7 @@ mod tests { #[tokio::test] async fn auth_quirks_scram() { let (mut client, server) = tokio::io::duplex(1024); - let mut stream = PqStream::new(Stream::from_raw(server)); + let mut stream = PqStream::new_skip_handshake(Stream::from_raw(server)); let ctx = RequestContext::test(); let api = Auth { @@ -784,7 +785,7 @@ mod tests { #[tokio::test] async fn auth_quirks_cleartext() { let (mut client, server) = tokio::io::duplex(1024); - let mut stream = PqStream::new(Stream::from_raw(server)); + let mut stream = PqStream::new_skip_handshake(Stream::from_raw(server)); let ctx = RequestContext::test(); let api = Auth { @@ -838,7 +839,7 @@ mod tests { #[tokio::test] async fn auth_quirks_password_hack() { let (mut client, server) = tokio::io::duplex(1024); - let mut stream = PqStream::new(Stream::from_raw(server)); + let mut stream = PqStream::new_skip_handshake(Stream::from_raw(server)); let ctx = RequestContext::test(); let api = Auth { diff --git a/proxy/src/auth/credentials.rs b/proxy/src/auth/credentials.rs index 526d0df7f2..b51da48862 100644 --- a/proxy/src/auth/credentials.rs +++ b/proxy/src/auth/credentials.rs @@ -5,7 +5,6 @@ use std::net::IpAddr; use std::str::FromStr; use itertools::Itertools; -use pq_proto::StartupMessageParams; use thiserror::Error; use tracing::{debug, warn}; @@ -13,6 +12,7 @@ use crate::auth::password_hack::parse_endpoint_param; use crate::context::RequestContext; use crate::error::{ReportableError, UserFacingError}; use crate::metrics::{Metrics, SniGroup, SniKind}; +use crate::pqproto::StartupMessageParams; use crate::proxy::NeonOptions; use crate::serverless::{AUTH_BROKER_SNI, SERVERLESS_DRIVER_SNI}; use crate::types::{EndpointId, RoleName}; diff --git a/proxy/src/auth/flow.rs b/proxy/src/auth/flow.rs index 0992c6d875..8fbc4577e9 100644 --- a/proxy/src/auth/flow.rs +++ b/proxy/src/auth/flow.rs @@ -1,10 +1,8 @@ //! Main authentication flow. -use std::io; use std::sync::Arc; use postgres_protocol::authentication::sasl::{SCRAM_SHA_256, SCRAM_SHA_256_PLUS}; -use pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be}; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::info; @@ -13,35 +11,26 @@ use super::{AuthError, PasswordHackPayload}; use crate::context::RequestContext; use crate::control_plane::AuthSecret; use crate::intern::EndpointIdInt; +use crate::pqproto::{BeAuthenticationSaslMessage, BeMessage}; use crate::sasl; use crate::scram::threadpool::ThreadPool; use crate::scram::{self}; use crate::stream::{PqStream, Stream}; use crate::tls::TlsServerEndPoint; -/// Every authentication selector is supposed to implement this trait. -pub(crate) trait AuthMethod { - /// Any authentication selector should provide initial backend message - /// containing auth method name and parameters, e.g. md5 salt. - fn first_message(&self, channel_binding: bool) -> BeMessage<'_>; -} - -/// Initial state of [`AuthFlow`]. -pub(crate) struct Begin; - /// Use [SCRAM](crate::scram)-based auth in [`AuthFlow`]. pub(crate) struct Scram<'a>( pub(crate) &'a scram::ServerSecret, pub(crate) &'a RequestContext, ); -impl AuthMethod for Scram<'_> { +impl Scram<'_> { #[inline(always)] fn first_message(&self, channel_binding: bool) -> BeMessage<'_> { if channel_binding { - Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(scram::METHODS)) + BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(scram::METHODS)) } else { - Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods( + BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Methods( scram::METHODS_WITHOUT_PLUS, )) } @@ -52,13 +41,6 @@ impl AuthMethod for Scram<'_> { /// . pub(crate) struct PasswordHack; -impl AuthMethod for PasswordHack { - #[inline(always)] - fn first_message(&self, _channel_binding: bool) -> BeMessage<'_> { - Be::AuthenticationCleartextPassword - } -} - /// Use clear-text password auth called `password` in docs /// pub(crate) struct CleartextPassword { @@ -67,53 +49,37 @@ pub(crate) struct CleartextPassword { pub(crate) secret: AuthSecret, } -impl AuthMethod for CleartextPassword { - #[inline(always)] - fn first_message(&self, _channel_binding: bool) -> BeMessage<'_> { - Be::AuthenticationCleartextPassword - } -} - /// This wrapper for [`PqStream`] performs client authentication. #[must_use] pub(crate) struct AuthFlow<'a, S, State> { /// The underlying stream which implements libpq's protocol. stream: &'a mut PqStream>, - /// State might contain ancillary data (see [`Self::begin`]). + /// State might contain ancillary data. state: State, tls_server_end_point: TlsServerEndPoint, } /// Initial state of the stream wrapper. -impl<'a, S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'a, S, Begin> { +impl<'a, S: AsyncRead + AsyncWrite + Unpin, M> AuthFlow<'a, S, M> { /// Create a new wrapper for client authentication. - pub(crate) fn new(stream: &'a mut PqStream>) -> Self { + pub(crate) fn new(stream: &'a mut PqStream>, method: M) -> Self { let tls_server_end_point = stream.get_ref().tls_server_end_point(); Self { stream, - state: Begin, + state: method, tls_server_end_point, } } - - /// Move to the next step by sending auth method's name & params to client. - pub(crate) async fn begin(self, method: M) -> io::Result> { - self.stream - .write_message(&method.first_message(self.tls_server_end_point.supported())) - .await?; - - Ok(AuthFlow { - stream: self.stream, - state: method, - tls_server_end_point: self.tls_server_end_point, - }) - } } impl AuthFlow<'_, S, PasswordHack> { /// Perform user authentication. Raise an error in case authentication failed. pub(crate) async fn get_password(self) -> super::Result { + self.stream + .write_message(BeMessage::AuthenticationCleartextPassword); + self.stream.flush().await?; + let msg = self.stream.read_password_message().await?; let password = msg .strip_suffix(&[0]) @@ -133,6 +99,10 @@ impl AuthFlow<'_, S, PasswordHack> { impl AuthFlow<'_, S, CleartextPassword> { /// Perform user authentication. Raise an error in case authentication failed. pub(crate) async fn authenticate(self) -> super::Result> { + self.stream + .write_message(BeMessage::AuthenticationCleartextPassword); + self.stream.flush().await?; + let msg = self.stream.read_password_message().await?; let password = msg .strip_suffix(&[0]) @@ -147,7 +117,7 @@ impl AuthFlow<'_, S, CleartextPassword> { .await?; if let sasl::Outcome::Success(_) = &outcome { - self.stream.write_message_noflush(&Be::AuthenticationOk)?; + self.stream.write_message(BeMessage::AuthenticationOk); } Ok(outcome) @@ -159,42 +129,36 @@ impl AuthFlow<'_, S, Scram<'_>> { /// Perform user authentication. Raise an error in case authentication failed. pub(crate) async fn authenticate(self) -> super::Result> { let Scram(secret, ctx) = self.state; + let channel_binding = self.tls_server_end_point; - // pause the timer while we communicate with the client - let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client); + // send sasl message. + { + // pause the timer while we communicate with the client + let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client); - // Initial client message contains the chosen auth method's name. - let msg = self.stream.read_password_message().await?; - let sasl = sasl::FirstMessage::parse(&msg) - .ok_or(AuthError::MalformedPassword("bad sasl message"))?; - - // Currently, the only supported SASL method is SCRAM. - if !scram::METHODS.contains(&sasl.method) { - return Err(super::AuthError::bad_auth_method(sasl.method)); + let sasl = self.state.first_message(channel_binding.supported()); + self.stream.write_message(sasl); + self.stream.flush().await?; } - match sasl.method { - SCRAM_SHA_256 => ctx.set_auth_method(crate::context::AuthMethod::ScramSha256), - SCRAM_SHA_256_PLUS => ctx.set_auth_method(crate::context::AuthMethod::ScramSha256Plus), - _ => {} - } + // complete sasl handshake. + sasl::authenticate(ctx, self.stream, |method| { + // Currently, the only supported SASL method is SCRAM. + match method { + SCRAM_SHA_256 => ctx.set_auth_method(crate::context::AuthMethod::ScramSha256), + SCRAM_SHA_256_PLUS => { + ctx.set_auth_method(crate::context::AuthMethod::ScramSha256Plus); + } + method => return Err(sasl::Error::BadAuthMethod(method.into())), + } - // TODO: make this a metric instead - info!("client chooses {}", sasl.method); + // TODO: make this a metric instead + info!("client chooses {}", method); - let outcome = sasl::SaslStream::new(self.stream, sasl.message) - .authenticate(scram::Exchange::new( - secret, - rand::random, - self.tls_server_end_point, - )) - .await?; - - if let sasl::Outcome::Success(_) = &outcome { - self.stream.write_message_noflush(&Be::AuthenticationOk)?; - } - - Ok(outcome) + Ok(scram::Exchange::new(secret, rand::random, channel_binding)) + }) + .await + .map_err(AuthError::Sasl) } } diff --git a/proxy/src/binary/pg_sni_router.rs b/proxy/src/binary/pg_sni_router.rs index 3e87538ae7..a4f517fead 100644 --- a/proxy/src/binary/pg_sni_router.rs +++ b/proxy/src/binary/pg_sni_router.rs @@ -4,8 +4,9 @@ //! This allows connecting to pods/services running in the same Kubernetes cluster from //! the outside. Similar to an ingress controller for HTTPS. +use std::net::SocketAddr; use std::path::Path; -use std::{net::SocketAddr, sync::Arc}; +use std::sync::Arc; use anyhow::{Context, anyhow, bail, ensure}; use clap::Arg; @@ -17,6 +18,7 @@ use rustls::pki_types::{DnsName, PrivateKeyDer}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::net::TcpListener; use tokio_rustls::TlsConnector; +use tokio_rustls::server::TlsStream; use tokio_util::sync::CancellationToken; use tracing::{Instrument, error, info}; use utils::project_git_version; @@ -24,10 +26,12 @@ use utils::sentry_init::init_sentry; use crate::context::RequestContext; use crate::metrics::{Metrics, ThreadPoolMetrics}; +use crate::pqproto::FeStartupPacket; use crate::protocol2::ConnectionInfo; -use crate::proxy::{ErrorSource, copy_bidirectional_client_compute, run_until_cancelled}; +use crate::proxy::{ + ErrorSource, TlsRequired, copy_bidirectional_client_compute, run_until_cancelled, +}; use crate::stream::{PqStream, Stream}; -use crate::tls::TlsServerEndPoint; project_git_version!(GIT_VERSION); @@ -84,7 +88,7 @@ pub async fn run() -> anyhow::Result<()> { .parse()?; // Configure TLS - let (tls_config, tls_server_end_point): (Arc, TlsServerEndPoint) = match ( + let tls_config = match ( args.get_one::("tls-key"), args.get_one::("tls-cert"), ) { @@ -117,7 +121,6 @@ pub async fn run() -> anyhow::Result<()> { dest.clone(), tls_config.clone(), None, - tls_server_end_point, proxy_listener, cancellation_token.clone(), )) @@ -127,7 +130,6 @@ pub async fn run() -> anyhow::Result<()> { dest, tls_config, Some(compute_tls_config), - tls_server_end_point, proxy_listener_compute_tls, cancellation_token.clone(), )) @@ -154,7 +156,7 @@ pub async fn run() -> anyhow::Result<()> { pub(super) fn parse_tls( key_path: &Path, cert_path: &Path, -) -> anyhow::Result<(Arc, TlsServerEndPoint)> { +) -> anyhow::Result> { let key = { let key_bytes = std::fs::read(key_path).context("TLS key file")?; @@ -187,10 +189,6 @@ pub(super) fn parse_tls( })? }; - // needed for channel bindings - let first_cert = cert_chain.first().context("missing certificate")?; - let tls_server_end_point = TlsServerEndPoint::new(first_cert)?; - let tls_config = rustls::ServerConfig::builder_with_provider(Arc::new(ring::default_provider())) .with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12]) @@ -199,14 +197,13 @@ pub(super) fn parse_tls( .with_single_cert(cert_chain, key)? .into(); - Ok((tls_config, tls_server_end_point)) + Ok(tls_config) } pub(super) async fn task_main( dest_suffix: Arc, tls_config: Arc, compute_tls_config: Option>, - tls_server_end_point: TlsServerEndPoint, listener: tokio::net::TcpListener, cancellation_token: CancellationToken, ) -> anyhow::Result<()> { @@ -242,15 +239,7 @@ pub(super) async fn task_main( crate::metrics::Protocol::SniRouter, "sni", ); - handle_client( - ctx, - dest_suffix, - tls_config, - compute_tls_config, - tls_server_end_point, - socket, - ) - .await + handle_client(ctx, dest_suffix, tls_config, compute_tls_config, socket).await } .unwrap_or_else(|e| { // Acknowledge that the task has finished with an error. @@ -269,55 +258,26 @@ pub(super) async fn task_main( Ok(()) } -const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)"; - async fn ssl_handshake( ctx: &RequestContext, raw_stream: S, tls_config: Arc, - tls_server_end_point: TlsServerEndPoint, -) -> anyhow::Result> { - let mut stream = PqStream::new(Stream::from_raw(raw_stream)); - - let msg = stream.read_startup_packet().await?; - use pq_proto::FeStartupPacket::SslRequest; - +) -> anyhow::Result> { + let (mut stream, msg) = PqStream::parse_startup(Stream::from_raw(raw_stream)).await?; match msg { - SslRequest { direct: false } => { - stream - .write_message(&pq_proto::BeMessage::EncryptionResponse(true)) - .await?; + FeStartupPacket::SslRequest { direct: None } => { + let raw = stream.accept_tls().await?; - // Upgrade raw stream into a secure TLS-backed stream. - // NOTE: We've consumed `tls`; this fact will be used later. - - let (raw, read_buf) = stream.into_inner(); - // TODO: Normally, client doesn't send any data before - // server says TLS handshake is ok and read_buf is empty. - // However, you could imagine pipelining of postgres - // SSLRequest + TLS ClientHello in one hunk similar to - // pipelining in our node js driver. We should probably - // support that by chaining read_buf with the stream. - if !read_buf.is_empty() { - bail!("data is sent before server replied with EncryptionResponse"); - } - - Ok(Stream::Tls { - tls: Box::new( - raw.upgrade(tls_config, !ctx.has_private_peer_addr()) - .await?, - ), - tls_server_end_point, - }) + Ok(raw + .upgrade(tls_config, !ctx.has_private_peer_addr()) + .await?) } unexpected => { info!( ?unexpected, "unexpected startup packet, rejecting connection" ); - stream - .throw_error_str(ERR_INSECURE_CONNECTION, crate::error::ErrorKind::User, None) - .await? + Err(stream.throw_error(TlsRequired, None).await)? } } } @@ -327,15 +287,18 @@ async fn handle_client( dest_suffix: Arc, tls_config: Arc, compute_tls_config: Option>, - tls_server_end_point: TlsServerEndPoint, stream: impl AsyncRead + AsyncWrite + Unpin, ) -> anyhow::Result<()> { - let mut tls_stream = ssl_handshake(&ctx, stream, tls_config, tls_server_end_point).await?; + let mut tls_stream = ssl_handshake(&ctx, stream, tls_config).await?; // Cut off first part of the SNI domain // We receive required destination details in the format of // `{k8s_service_name}--{k8s_namespace}--{port}.non-sni-domain` - let sni = tls_stream.sni_hostname().ok_or(anyhow!("SNI missing"))?; + let sni = tls_stream + .get_ref() + .1 + .server_name() + .ok_or(anyhow!("SNI missing"))?; let dest: Vec<&str> = sni .split_once('.') .context("invalid SNI")? diff --git a/proxy/src/binary/proxy.rs b/proxy/src/binary/proxy.rs index 5f24940985..9a3903ba9a 100644 --- a/proxy/src/binary/proxy.rs +++ b/proxy/src/binary/proxy.rs @@ -476,8 +476,7 @@ pub async fn run() -> anyhow::Result<()> { let key_path = args.tls_key.expect("already asserted it is set"); let cert_path = args.tls_cert.expect("already asserted it is set"); - let (tls_config, tls_server_end_point) = - super::pg_sni_router::parse_tls(&key_path, &cert_path)?; + let tls_config = super::pg_sni_router::parse_tls(&key_path, &cert_path)?; let dest = Arc::new(dest); @@ -485,7 +484,6 @@ pub async fn run() -> anyhow::Result<()> { dest.clone(), tls_config.clone(), None, - tls_server_end_point, listen, cancellation_token.clone(), )); @@ -494,7 +492,6 @@ pub async fn run() -> anyhow::Result<()> { dest, tls_config, Some(config.connect_to_compute.tls.clone()), - tls_server_end_point, listen_tls, cancellation_token.clone(), )); diff --git a/proxy/src/cancellation.rs b/proxy/src/cancellation.rs index a6e7bf85a0..0bff901376 100644 --- a/proxy/src/cancellation.rs +++ b/proxy/src/cancellation.rs @@ -5,7 +5,6 @@ use anyhow::{Context, anyhow}; use ipnet::{IpNet, Ipv4Net, Ipv6Net}; use postgres_client::CancelToken; use postgres_client::tls::MakeTlsConnect; -use pq_proto::CancelKeyData; use redis::{Cmd, FromRedisValue, Value}; use serde::{Deserialize, Serialize}; use thiserror::Error; @@ -21,6 +20,7 @@ use crate::control_plane::ControlPlaneApi; use crate::error::ReportableError; use crate::ext::LockExt; use crate::metrics::{CancelChannelSizeGuard, CancellationRequest, Metrics, RedisMsgKind}; +use crate::pqproto::CancelKeyData; use crate::protocol2::ConnectionInfoExtra; use crate::rate_limiter::LeakyBucketRateLimiter; use crate::redis::keys::KeyPrefix; diff --git a/proxy/src/compute.rs b/proxy/src/compute.rs index 26254beecf..2899f25129 100644 --- a/proxy/src/compute.rs +++ b/proxy/src/compute.rs @@ -8,7 +8,6 @@ use itertools::Itertools; use postgres_client::tls::MakeTlsConnect; use postgres_client::{CancelToken, RawConnection}; use postgres_protocol::message::backend::NoticeResponseBody; -use pq_proto::StartupMessageParams; use rustls::pki_types::InvalidDnsNameError; use thiserror::Error; use tokio::net::{TcpStream, lookup_host}; @@ -24,6 +23,7 @@ use crate::control_plane::errors::WakeComputeError; use crate::control_plane::messages::MetricsAuxInfo; use crate::error::{ReportableError, UserFacingError}; use crate::metrics::{Metrics, NumDbConnectionsGuard}; +use crate::pqproto::StartupMessageParams; use crate::proxy::neon_option; use crate::tls::postgres_rustls::MakeRustlsConnect; use crate::types::Host; diff --git a/proxy/src/console_redirect_proxy.rs b/proxy/src/console_redirect_proxy.rs index e3184e20d1..9499aba61b 100644 --- a/proxy/src/console_redirect_proxy.rs +++ b/proxy/src/console_redirect_proxy.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use futures::{FutureExt, TryFutureExt}; -use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; +use tokio::io::{AsyncRead, AsyncWrite}; use tokio_util::sync::CancellationToken; use tracing::{Instrument, debug, error, info}; @@ -221,12 +221,10 @@ pub(crate) async fn handle_client( .await { Ok(auth_result) => auth_result, - Err(e) => { - return stream.throw_error(e, Some(ctx)).await?; - } + Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?, }; - let mut node = connect_to_compute( + let node = connect_to_compute( ctx, &TcpMechanism { user_info, @@ -238,7 +236,7 @@ pub(crate) async fn handle_client( config.wake_compute_retry_config, &config.connect_to_compute, ) - .or_else(|e| stream.throw_error(e, Some(ctx))) + .or_else(|e| async { Err(stream.throw_error(e, Some(ctx)).await) }) .await?; let cancellation_handler_clone = Arc::clone(&cancellation_handler); @@ -246,14 +244,8 @@ pub(crate) async fn handle_client( session.write_cancel_key(node.cancel_closure.clone())?; - prepare_client_connection(&node, *session.key(), &mut stream).await?; - - // Before proxy passing, forward to compute whatever data is left in the - // PqStream input buffer. Normally there is none, but our serverless npm - // driver in pipeline mode sends startup, password and first query - // immediately after opening the connection. - let (stream, read_buf) = stream.into_inner(); - node.stream.write_all(&read_buf).await?; + prepare_client_connection(&node, *session.key(), &mut stream); + let stream = stream.flush_and_into_inner().await?; Ok(Some(ProxyPassthrough { client: stream, diff --git a/proxy/src/context/mod.rs b/proxy/src/context/mod.rs index 79aaf22990..de4600951e 100644 --- a/proxy/src/context/mod.rs +++ b/proxy/src/context/mod.rs @@ -4,7 +4,6 @@ use std::net::IpAddr; use chrono::Utc; use once_cell::sync::OnceCell; -use pq_proto::StartupMessageParams; use smol_str::SmolStr; use tokio::sync::mpsc; use tracing::field::display; @@ -20,6 +19,7 @@ use crate::metrics::{ ConnectOutcome, InvalidEndpointsGroup, LatencyAccumulated, LatencyTimer, Metrics, Protocol, Waiting, }; +use crate::pqproto::StartupMessageParams; use crate::protocol2::{ConnectionInfo, ConnectionInfoExtra}; use crate::types::{DbName, EndpointId, RoleName}; diff --git a/proxy/src/context/parquet.rs b/proxy/src/context/parquet.rs index f6250bcd17..c9d3905abd 100644 --- a/proxy/src/context/parquet.rs +++ b/proxy/src/context/parquet.rs @@ -11,7 +11,6 @@ use parquet::file::metadata::RowGroupMetaDataPtr; use parquet::file::properties::{DEFAULT_PAGE_SIZE, WriterProperties, WriterPropertiesPtr}; use parquet::file::writer::SerializedFileWriter; use parquet::record::RecordWriter; -use pq_proto::StartupMessageParams; use remote_storage::{GenericRemoteStorage, RemotePath, RemoteStorageConfig, TimeoutOrCancel}; use serde::ser::SerializeMap; use tokio::sync::mpsc; @@ -24,6 +23,7 @@ use super::{LOG_CHAN, RequestContextInner}; use crate::config::remote_storage_from_toml; use crate::context::LOG_CHAN_DISCONNECT; use crate::ext::TaskExt; +use crate::pqproto::StartupMessageParams; #[derive(clap::Args, Clone, Debug)] pub struct ParquetUploadArgs { diff --git a/proxy/src/lib.rs b/proxy/src/lib.rs index d1f8430b8a..d65d056585 100644 --- a/proxy/src/lib.rs +++ b/proxy/src/lib.rs @@ -92,6 +92,7 @@ mod logging; mod metrics; mod parse; mod pglb; +mod pqproto; mod protocol2; mod proxy; mod rate_limiter; diff --git a/proxy/src/pqproto.rs b/proxy/src/pqproto.rs new file mode 100644 index 0000000000..d68d9f9474 --- /dev/null +++ b/proxy/src/pqproto.rs @@ -0,0 +1,693 @@ +//! Postgres protocol codec +//! +//! + +use std::fmt; +use std::io::{self, Cursor}; + +use bytes::{Buf, BufMut}; +use itertools::Itertools; +use rand::distributions::{Distribution, Standard}; +use tokio::io::{AsyncRead, AsyncReadExt}; +use zerocopy::{FromBytes, Immutable, IntoBytes, big_endian}; + +pub type ErrorCode = [u8; 5]; + +pub const FE_PASSWORD_MESSAGE: u8 = b'p'; + +pub const SQLSTATE_INTERNAL_ERROR: [u8; 5] = *b"XX000"; + +/// The protocol version number. +/// +/// The most significant 16 bits are the major version number (3 for the protocol described here). +/// The least significant 16 bits are the minor version number (0 for the protocol described here). +/// +#[derive(Clone, Copy, PartialEq, PartialOrd, FromBytes, IntoBytes, Immutable)] +#[repr(C)] +pub struct ProtocolVersion { + major: big_endian::U16, + minor: big_endian::U16, +} + +impl ProtocolVersion { + pub const fn new(major: u16, minor: u16) -> Self { + Self { + major: big_endian::U16::new(major), + minor: big_endian::U16::new(minor), + } + } + pub const fn minor(self) -> u16 { + self.minor.get() + } + pub const fn major(self) -> u16 { + self.major.get() + } +} + +impl fmt::Debug for ProtocolVersion { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_list() + .entry(&self.major()) + .entry(&self.minor()) + .finish() + } +} + +/// read the type from the stream using zerocopy. +/// +/// not cancel safe. +macro_rules! read { + ($s:expr => $t:ty) => {{ + // cannot be implemented as a function due to lack of const-generic-expr + let mut buf = [0; size_of::<$t>()]; + $s.read_exact(&mut buf).await?; + let res: $t = zerocopy::transmute!(buf); + res + }}; +} + +pub async fn read_startup(stream: &mut S) -> io::Result +where + S: AsyncRead + Unpin, +{ + /// + const MAX_STARTUP_PACKET_LENGTH: usize = 10000; + const RESERVED_INVALID_MAJOR_VERSION: u16 = 1234; + /// + const CANCEL_REQUEST_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5678); + /// + const NEGOTIATE_SSL_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5679); + /// + const NEGOTIATE_GSS_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5680); + + /// This first reads the startup message header, is 8 bytes. + /// The first 4 bytes is a big-endian message length, and the next 4 bytes is a version number. + /// + /// The length value is inclusive of the header. For example, + /// an empty message will always have length 8. + #[derive(Clone, Copy, FromBytes, IntoBytes, Immutable)] + #[repr(C)] + struct StartupHeader { + len: big_endian::U32, + version: ProtocolVersion, + } + + let header = read!(stream => StartupHeader); + + // + // First byte indicates standard SSL handshake message + // (It can't be a Postgres startup length because in network byte order + // that would be a startup packet hundreds of megabytes long) + if header.as_bytes()[0] == 0x16 { + return Ok(FeStartupPacket::SslRequest { + // The bytes we read for the header are actually part of a TLS ClientHello. + // In theory, if the ClientHello was < 8 bytes we would fail with EOF before we get here. + // In practice though, I see no world where a ClientHello is less than 8 bytes + // since it includes ephemeral keys etc. + direct: Some(zerocopy::transmute!(header)), + }); + } + + let Some(len) = (header.len.get() as usize).checked_sub(8) else { + return Err(io::Error::other(format!( + "invalid startup message length {}, must be at least 8.", + header.len, + ))); + }; + + // TODO: add a histogram for startup packet lengths + if len > MAX_STARTUP_PACKET_LENGTH { + tracing::warn!("large startup message detected: {len} bytes"); + return Err(io::Error::other(format!( + "invalid startup message length {len}" + ))); + } + + match header.version { + // + CANCEL_REQUEST_CODE => { + if len != 8 { + return Err(io::Error::other( + "CancelRequest message is malformed, backend PID / secret key missing", + )); + } + + Ok(FeStartupPacket::CancelRequest( + read!(stream => CancelKeyData), + )) + } + // + NEGOTIATE_SSL_CODE => { + // Requested upgrade to SSL (aka TLS) + Ok(FeStartupPacket::SslRequest { direct: None }) + } + NEGOTIATE_GSS_CODE => { + // Requested upgrade to GSSAPI + Ok(FeStartupPacket::GssEncRequest) + } + version if version.major() == RESERVED_INVALID_MAJOR_VERSION => Err(io::Error::other( + format!("Unrecognized request code {version:?}"), + )), + // StartupMessage + version => { + // The protocol version number is followed by one or more pairs of parameter name and value strings. + // A zero byte is required as a terminator after the last name/value pair. + // Parameters can appear in any order. user is required, others are optional. + + let mut buf = vec![0; len]; + stream.read_exact(&mut buf).await?; + + if buf.pop() != Some(b'\0') { + return Err(io::Error::other( + "StartupMessage params: missing null terminator", + )); + } + + // TODO: Don't do this. + // There's no guarantee that these messages are utf8, + // but they usually happen to be simple ascii. + let params = String::from_utf8(buf) + .map_err(|_| io::Error::other("StartupMessage params: invalid utf-8"))?; + + Ok(FeStartupPacket::StartupMessage { + version, + params: StartupMessageParams { params }, + }) + } + } +} + +/// Read a raw postgres packet, which will respect the max length requested. +/// +/// This returns the message tag, as well as the message body. The message +/// body is written into `buf`, and it is otherwise completely overwritten. +/// +/// This is not cancel safe. +pub async fn read_message<'a, S>( + stream: &mut S, + buf: &'a mut Vec, + max: usize, +) -> io::Result<(u8, &'a mut [u8])> +where + S: AsyncRead + Unpin, +{ + /// This first reads the header, which for regular messages in the 3.0 protocol is 5 bytes. + /// The first byte is a message tag, and the next 4 bytes is a big-endian length. + /// + /// Awkwardly, the length value is inclusive of itself, but not of the tag. For example, + /// an empty message will always have length 4. + #[derive(Clone, Copy, FromBytes)] + #[repr(C)] + struct Header { + tag: u8, + len: big_endian::U32, + } + + let header = read!(stream => Header); + + // as described above, the length must be at least 4. + let Some(len) = (header.len.get() as usize).checked_sub(4) else { + return Err(io::Error::other(format!( + "invalid startup message length {}, must be at least 4.", + header.len, + ))); + }; + + // TODO: add a histogram for message lengths + + // check if the message exceeds our desired max. + if len > max { + tracing::warn!("large postgres message detected: {len} bytes"); + return Err(io::Error::other(format!("invalid message length {len}"))); + } + + // read in our entire message. + buf.resize(len, 0); + stream.read_exact(buf).await?; + + Ok((header.tag, buf)) +} + +pub struct WriteBuf(Cursor>); + +impl Buf for WriteBuf { + #[inline] + fn remaining(&self) -> usize { + self.0.remaining() + } + + #[inline] + fn chunk(&self) -> &[u8] { + self.0.chunk() + } + + #[inline] + fn advance(&mut self, cnt: usize) { + self.0.advance(cnt); + } +} + +impl WriteBuf { + pub const fn new() -> Self { + Self(Cursor::new(Vec::new())) + } + + /// Use a heuristic to determine if we should shrink the write buffer. + #[inline] + fn should_shrink(&self) -> bool { + let n = self.0.position() as usize; + let len = self.0.get_ref().len(); + + // the unused space at the front of our buffer is 2x the size of our filled portion. + n + n > len + } + + /// Shrink the write buffer so that subsequent writes have more spare capacity. + #[cold] + fn shrink(&mut self) { + let n = self.0.position() as usize; + let buf = self.0.get_mut(); + + // buf repr: + // [----unused------|-----filled-----|-----uninit-----] + // ^ n ^ buf.len() ^ buf.capacity() + let filled = n..buf.len(); + let filled_len = filled.len(); + buf.copy_within(filled, 0); + buf.truncate(filled_len); + self.0.set_position(0); + } + + /// clear the write buffer. + pub fn reset(&mut self) { + let buf = self.0.get_mut(); + buf.clear(); + self.0.set_position(0); + } + + /// Write a raw message to the internal buffer. + /// + /// The size_hint value is only a hint for reserving space. It's ok if it's incorrect, since + /// we calculate the length after the fact. + pub fn write_raw(&mut self, size_hint: usize, tag: u8, f: impl FnOnce(&mut Vec)) { + if self.should_shrink() { + self.shrink(); + } + + let buf = self.0.get_mut(); + buf.reserve(5 + size_hint); + + buf.push(tag); + let start = buf.len(); + buf.extend_from_slice(&[0, 0, 0, 0]); + + f(buf); + + let end = buf.len(); + let len = (end - start) as u32; + buf[start..start + 4].copy_from_slice(&len.to_be_bytes()); + } + + /// Write an encryption response message. + pub fn encryption(&mut self, m: u8) { + self.0.get_mut().push(m); + } + + pub fn write_error(&mut self, msg: &str, error_code: ErrorCode) { + self.shrink(); + + // + // + // "SERROR\0CXXXXX\0M\0\0".len() == 17 + self.write_raw(17 + msg.len(), b'E', |buf| { + // Severity: ERROR + buf.put_slice(b"SERROR\0"); + + // Code: error_code + buf.put_u8(b'C'); + buf.put_slice(&error_code); + buf.put_u8(0); + + // Message: msg + buf.put_u8(b'M'); + buf.put_slice(msg.as_bytes()); + buf.put_u8(0); + + // End. + buf.put_u8(0); + }); + } +} + +#[derive(Debug)] +pub enum FeStartupPacket { + CancelRequest(CancelKeyData), + SslRequest { + direct: Option<[u8; 8]>, + }, + GssEncRequest, + StartupMessage { + version: ProtocolVersion, + params: StartupMessageParams, + }, +} + +#[derive(Debug, Clone, Default)] +pub struct StartupMessageParams { + pub params: String, +} + +impl StartupMessageParams { + /// Get parameter's value by its name. + pub fn get(&self, name: &str) -> Option<&str> { + self.iter().find_map(|(k, v)| (k == name).then_some(v)) + } + + /// Split command-line options according to PostgreSQL's logic, + /// taking into account all escape sequences but leaving them as-is. + /// [`None`] means that there's no `options` in [`Self`]. + pub fn options_raw(&self) -> Option> { + self.get("options").map(Self::parse_options_raw) + } + + /// Split command-line options according to PostgreSQL's logic, + /// taking into account all escape sequences but leaving them as-is. + pub fn parse_options_raw(input: &str) -> impl Iterator { + // See `postgres: pg_split_opts`. + let mut last_was_escape = false; + input + .split(move |c: char| { + // We split by non-escaped whitespace symbols. + let should_split = c.is_ascii_whitespace() && !last_was_escape; + last_was_escape = c == '\\' && !last_was_escape; + should_split + }) + .filter(|s| !s.is_empty()) + } + + /// Iterate through key-value pairs in an arbitrary order. + pub fn iter(&self) -> impl Iterator { + self.params.split_terminator('\0').tuples() + } + + // This function is mostly useful in tests. + #[cfg(test)] + pub fn new<'a, const N: usize>(pairs: [(&'a str, &'a str); N]) -> Self { + let mut b = Self { + params: String::new(), + }; + for (k, v) in pairs { + b.insert(k, v); + } + b + } + + /// Set parameter's value by its name. + /// name and value must not contain a \0 byte + pub fn insert(&mut self, name: &str, value: &str) { + self.params.reserve(name.len() + value.len() + 2); + self.params.push_str(name); + self.params.push('\0'); + self.params.push_str(value); + self.params.push('\0'); + } +} + +/// Cancel keys usually are represented as PID+SecretKey, but to proxy they're just +/// opaque bytes. +#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy, FromBytes, IntoBytes, Immutable)] +pub struct CancelKeyData(pub big_endian::U64); + +pub fn id_to_cancel_key(id: u64) -> CancelKeyData { + CancelKeyData(big_endian::U64::new(id)) +} + +impl fmt::Display for CancelKeyData { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let id = self.0; + f.debug_tuple("CancelKeyData") + .field(&format_args!("{id:x}")) + .finish() + } +} +impl Distribution for Standard { + fn sample(&self, rng: &mut R) -> CancelKeyData { + id_to_cancel_key(rng.r#gen()) + } +} + +pub enum BeMessage<'a> { + AuthenticationOk, + AuthenticationSasl(BeAuthenticationSaslMessage<'a>), + AuthenticationCleartextPassword, + BackendKeyData(CancelKeyData), + ParameterStatus { + name: &'a [u8], + value: &'a [u8], + }, + ReadyForQuery, + NoticeResponse(&'a str), + NegotiateProtocolVersion { + version: ProtocolVersion, + options: &'a [&'a str], + }, +} + +#[derive(Debug)] +pub enum BeAuthenticationSaslMessage<'a> { + Methods(&'a [&'a str]), + Continue(&'a [u8]), + Final(&'a [u8]), +} + +impl BeMessage<'_> { + /// Write the message into an internal buffer + pub fn write_message(self, buf: &mut WriteBuf) { + match self { + // + BeMessage::AuthenticationOk => { + buf.write_raw(1, b'R', |buf| buf.put_i32(0)); + } + // + BeMessage::AuthenticationCleartextPassword => { + buf.write_raw(1, b'R', |buf| buf.put_i32(3)); + } + + // + BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(methods)) => { + let len: usize = methods.iter().map(|m| m.len() + 1).sum(); + buf.write_raw(len + 2, b'R', |buf| { + buf.put_i32(10); // Specifies that SASL auth method is used. + for method in methods { + buf.put_slice(method.as_bytes()); + buf.put_u8(0); + } + buf.put_u8(0); // zero terminator for the list + }); + } + // + BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Continue(extra)) => { + buf.write_raw(extra.len() + 1, b'R', |buf| { + buf.put_i32(11); // Continue SASL auth. + buf.put_slice(extra); + }); + } + // + BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Final(extra)) => { + buf.write_raw(extra.len() + 1, b'R', |buf| { + buf.put_i32(12); // Send final SASL message. + buf.put_slice(extra); + }); + } + + // + BeMessage::BackendKeyData(key_data) => { + buf.write_raw(8, b'K', |buf| buf.put_slice(key_data.as_bytes())); + } + + // + // + BeMessage::NoticeResponse(msg) => { + // 'N' signalizes NoticeResponse messages + buf.write_raw(18 + msg.len(), b'N', |buf| { + // Severity: NOTICE + buf.put_slice(b"SNOTICE\0"); + + // Code: XX000 (ignored for notice, but still required) + buf.put_slice(b"CXX000\0"); + + // Message: msg + buf.put_u8(b'M'); + buf.put_slice(msg.as_bytes()); + buf.put_u8(0); + + // End notice. + buf.put_u8(0); + }); + } + + // + BeMessage::ParameterStatus { name, value } => { + buf.write_raw(name.len() + value.len() + 2, b'S', |buf| { + buf.put_slice(name.as_bytes()); + buf.put_u8(0); + buf.put_slice(value.as_bytes()); + buf.put_u8(0); + }); + } + + // + BeMessage::ReadyForQuery => { + buf.write_raw(1, b'Z', |buf| buf.put_u8(b'I')); + } + + // + BeMessage::NegotiateProtocolVersion { version, options } => { + let len: usize = options.iter().map(|o| o.len() + 1).sum(); + buf.write_raw(8 + len, b'v', |buf| { + buf.put_slice(version.as_bytes()); + buf.put_u32(options.len() as u32); + for option in options { + buf.put_slice(option.as_bytes()); + buf.put_u8(0); + } + }); + } + } + } +} + +#[cfg(test)] +mod tests { + use std::io::Cursor; + + use tokio::io::{AsyncWriteExt, duplex}; + use zerocopy::IntoBytes; + + use crate::pqproto::{FeStartupPacket, read_message, read_startup}; + + use super::ProtocolVersion; + + #[tokio::test] + async fn reject_large_startup() { + // we're going to define a v3.0 startup message with far too many parameters. + let mut payload = vec![]; + // 10001 + 8 bytes. + payload.extend_from_slice(&10009_u32.to_be_bytes()); + payload.extend_from_slice(ProtocolVersion::new(3, 0).as_bytes()); + payload.resize(10009, b'a'); + + let (mut server, mut client) = duplex(128); + #[rustfmt::skip] + let (server, client) = tokio::join!( + async move { read_startup(&mut server).await.unwrap_err() }, + async move { client.write_all(&payload).await.unwrap_err() }, + ); + + assert_eq!(server.to_string(), "invalid startup message length 10001"); + assert_eq!(client.to_string(), "broken pipe"); + } + + #[tokio::test] + async fn reject_large_password() { + // we're going to define a password message that is far too long. + let mut payload = vec![]; + payload.push(b'p'); + payload.extend_from_slice(&517_u32.to_be_bytes()); + payload.resize(518, b'a'); + + let (mut server, mut client) = duplex(128); + #[rustfmt::skip] + let (server, client) = tokio::join!( + async move { read_message(&mut server, &mut vec![], 512).await.unwrap_err() }, + async move { client.write_all(&payload).await.unwrap_err() }, + ); + + assert_eq!(server.to_string(), "invalid message length 513"); + assert_eq!(client.to_string(), "broken pipe"); + } + + #[tokio::test] + async fn read_startup_message() { + let mut payload = vec![]; + payload.extend_from_slice(&17_u32.to_be_bytes()); + payload.extend_from_slice(ProtocolVersion::new(3, 0).as_bytes()); + payload.extend_from_slice(b"abc\0def\0\0"); + + let startup = read_startup(&mut Cursor::new(&payload)).await.unwrap(); + let FeStartupPacket::StartupMessage { version, params } = startup else { + panic!("unexpected startup message: {startup:?}"); + }; + + assert_eq!(version.major(), 3); + assert_eq!(version.minor(), 0); + assert_eq!(params.params, "abc\0def\0"); + } + + #[tokio::test] + async fn read_ssl_message() { + let mut payload = vec![]; + payload.extend_from_slice(&8_u32.to_be_bytes()); + payload.extend_from_slice(ProtocolVersion::new(1234, 5679).as_bytes()); + + let startup = read_startup(&mut Cursor::new(&payload)).await.unwrap(); + let FeStartupPacket::SslRequest { direct: None } = startup else { + panic!("unexpected startup message: {startup:?}"); + }; + } + + #[tokio::test] + async fn read_tls_message() { + // sample client hello taken from + let client_hello = [ + 0x16, 0x03, 0x01, 0x00, 0xf8, 0x01, 0x00, 0x00, 0xf4, 0x03, 0x03, 0x00, 0x01, 0x02, + 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, + 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, + 0x1f, 0x20, 0xe0, 0xe1, 0xe2, 0xe3, 0xe4, 0xe5, 0xe6, 0xe7, 0xe8, 0xe9, 0xea, 0xeb, + 0xec, 0xed, 0xee, 0xef, 0xf0, 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7, 0xf8, 0xf9, + 0xfa, 0xfb, 0xfc, 0xfd, 0xfe, 0xff, 0x00, 0x08, 0x13, 0x02, 0x13, 0x03, 0x13, 0x01, + 0x00, 0xff, 0x01, 0x00, 0x00, 0xa3, 0x00, 0x00, 0x00, 0x18, 0x00, 0x16, 0x00, 0x00, + 0x13, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2e, 0x75, 0x6c, 0x66, 0x68, 0x65, + 0x69, 0x6d, 0x2e, 0x6e, 0x65, 0x74, 0x00, 0x0b, 0x00, 0x04, 0x03, 0x00, 0x01, 0x02, + 0x00, 0x0a, 0x00, 0x16, 0x00, 0x14, 0x00, 0x1d, 0x00, 0x17, 0x00, 0x1e, 0x00, 0x19, + 0x00, 0x18, 0x01, 0x00, 0x01, 0x01, 0x01, 0x02, 0x01, 0x03, 0x01, 0x04, 0x00, 0x23, + 0x00, 0x00, 0x00, 0x16, 0x00, 0x00, 0x00, 0x17, 0x00, 0x00, 0x00, 0x0d, 0x00, 0x1e, + 0x00, 0x1c, 0x04, 0x03, 0x05, 0x03, 0x06, 0x03, 0x08, 0x07, 0x08, 0x08, 0x08, 0x09, + 0x08, 0x0a, 0x08, 0x0b, 0x08, 0x04, 0x08, 0x05, 0x08, 0x06, 0x04, 0x01, 0x05, 0x01, + 0x06, 0x01, 0x00, 0x2b, 0x00, 0x03, 0x02, 0x03, 0x04, 0x00, 0x2d, 0x00, 0x02, 0x01, + 0x01, 0x00, 0x33, 0x00, 0x26, 0x00, 0x24, 0x00, 0x1d, 0x00, 0x20, 0x35, 0x80, 0x72, + 0xd6, 0x36, 0x58, 0x80, 0xd1, 0xae, 0xea, 0x32, 0x9a, 0xdf, 0x91, 0x21, 0x38, 0x38, + 0x51, 0xed, 0x21, 0xa2, 0x8e, 0x3b, 0x75, 0xe9, 0x65, 0xd0, 0xd2, 0xcd, 0x16, 0x62, + 0x54, + ]; + + let mut cursor = Cursor::new(&client_hello); + + let startup = read_startup(&mut cursor).await.unwrap(); + let FeStartupPacket::SslRequest { + direct: Some(prefix), + } = startup + else { + panic!("unexpected startup message: {startup:?}"); + }; + + // check that no data is lost. + assert_eq!(prefix, [0x16, 0x03, 0x01, 0x00, 0xf8, 0x01, 0x00, 0x00]); + assert_eq!(cursor.position(), 8); + } + + #[tokio::test] + async fn read_message_success() { + let query = b"Q\0\0\0\x0cSELECT 1Q\0\0\0\x0cSELECT 2"; + let mut cursor = Cursor::new(&query); + + let mut buf = vec![]; + let (tag, message) = read_message(&mut cursor, &mut buf, 100).await.unwrap(); + assert_eq!(tag, b'Q'); + assert_eq!(message, b"SELECT 1"); + + let (tag, message) = read_message(&mut cursor, &mut buf, 100).await.unwrap(); + assert_eq!(tag, b'Q'); + assert_eq!(message, b"SELECT 2"); + } +} diff --git a/proxy/src/proxy/connect_compute.rs b/proxy/src/proxy/connect_compute.rs index e013fbbe2e..57785c9ec5 100644 --- a/proxy/src/proxy/connect_compute.rs +++ b/proxy/src/proxy/connect_compute.rs @@ -1,5 +1,4 @@ use async_trait::async_trait; -use pq_proto::StartupMessageParams; use tokio::time; use tracing::{debug, info, warn}; @@ -15,6 +14,7 @@ use crate::error::ReportableError; use crate::metrics::{ ConnectOutcome, ConnectionFailureKind, Metrics, RetriesMetricGroup, RetryType, }; +use crate::pqproto::StartupMessageParams; use crate::proxy::retry::{CouldRetry, retry_after, should_retry}; use crate::proxy::wake_compute::wake_compute; use crate::types::Host; diff --git a/proxy/src/proxy/handshake.rs b/proxy/src/proxy/handshake.rs index 54c02f2c15..13ee8c7dd2 100644 --- a/proxy/src/proxy/handshake.rs +++ b/proxy/src/proxy/handshake.rs @@ -1,8 +1,3 @@ -use bytes::Buf; -use pq_proto::framed::Framed; -use pq_proto::{ - BeMessage as Be, CancelKeyData, FeStartupPacket, ProtocolVersion, StartupMessageParams, -}; use thiserror::Error; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::{debug, info, warn}; @@ -12,7 +7,10 @@ use crate::config::TlsConfig; use crate::context::RequestContext; use crate::error::ReportableError; use crate::metrics::Metrics; -use crate::proxy::ERR_INSECURE_CONNECTION; +use crate::pqproto::{ + BeMessage, CancelKeyData, FeStartupPacket, ProtocolVersion, StartupMessageParams, +}; +use crate::proxy::TlsRequired; use crate::stream::{PqStream, Stream, StreamUpgradeError}; use crate::tls::PG_ALPN_PROTOCOL; @@ -71,33 +69,25 @@ pub(crate) async fn handshake( const PG_PROTOCOL_EARLIEST: ProtocolVersion = ProtocolVersion::new(3, 0); const PG_PROTOCOL_LATEST: ProtocolVersion = ProtocolVersion::new(3, 0); - let mut stream = PqStream::new(Stream::from_raw(stream)); + let (mut stream, mut msg) = PqStream::parse_startup(Stream::from_raw(stream)).await?; loop { - let msg = stream.read_startup_packet().await?; match msg { FeStartupPacket::SslRequest { direct } => match stream.get_ref() { Stream::Raw { .. } if !tried_ssl => { tried_ssl = true; - // We can't perform TLS handshake without a config - let have_tls = tls.is_some(); - if !direct { - stream - .write_message(&Be::EncryptionResponse(have_tls)) - .await?; - } else if !have_tls { - return Err(HandshakeError::ProtocolViolation); - } - if let Some(tls) = tls.take() { // Upgrade raw stream into a secure TLS-backed stream. // NOTE: We've consumed `tls`; this fact will be used later. - let Framed { - stream: raw, - read_buf, - write_buf, - } = stream.framed; + let mut read_buf; + let raw = if let Some(direct) = &direct { + read_buf = &direct[..]; + stream.accept_direct_tls() + } else { + read_buf = &[]; + stream.accept_tls().await? + }; let Stream::Raw { raw } = raw else { return Err(HandshakeError::StreamUpgradeError( @@ -105,12 +95,11 @@ pub(crate) async fn handshake( )); }; - let mut read_buf = read_buf.reader(); let mut res = Ok(()); let accept = tokio_rustls::TlsAcceptor::from(tls.pg_config.clone()) .accept_with(raw, |session| { // push the early data to the tls session - while !read_buf.get_ref().is_empty() { + while !read_buf.is_empty() { match session.read_tls(&mut read_buf) { Ok(_) => {} Err(e) => { @@ -123,7 +112,6 @@ pub(crate) async fn handshake( res?; - let read_buf = read_buf.into_inner(); if !read_buf.is_empty() { return Err(HandshakeError::EarlyData); } @@ -157,16 +145,17 @@ pub(crate) async fn handshake( let (_, tls_server_end_point) = tls.cert_resolver.resolve(conn_info.server_name()); - stream = PqStream { - framed: Framed { - stream: Stream::Tls { - tls: Box::new(tls_stream), - tls_server_end_point, - }, - read_buf, - write_buf, - }, + let tls = Stream::Tls { + tls: Box::new(tls_stream), + tls_server_end_point, }; + (stream, msg) = PqStream::parse_startup(tls).await?; + } else { + if direct.is_some() { + // client sent us a ClientHello already, we can't do anything with it. + return Err(HandshakeError::ProtocolViolation); + } + msg = stream.reject_encryption().await?; } } _ => return Err(HandshakeError::ProtocolViolation), @@ -176,7 +165,7 @@ pub(crate) async fn handshake( tried_gss = true; // Currently, we don't support GSSAPI - stream.write_message(&Be::EncryptionResponse(false)).await?; + msg = stream.reject_encryption().await?; } _ => return Err(HandshakeError::ProtocolViolation), }, @@ -186,13 +175,7 @@ pub(crate) async fn handshake( // Check that the config has been consumed during upgrade // OR we didn't provide it at all (for dev purposes). if tls.is_some() { - return stream - .throw_error_str( - ERR_INSECURE_CONNECTION, - crate::error::ErrorKind::User, - None, - ) - .await?; + Err(stream.throw_error(TlsRequired, None).await)?; } // This log highlights the start of the connection. @@ -214,20 +197,21 @@ pub(crate) async fn handshake( // no protocol extensions are supported. // let mut unsupported = vec![]; - for (k, _) in params.iter() { + let mut supported = StartupMessageParams::default(); + + for (k, v) in params.iter() { if k.starts_with("_pq_.") { unsupported.push(k); + } else { + supported.insert(k, v); } } - // TODO: remove unsupported options so we don't send them to compute. - - stream - .write_message(&Be::NegotiateProtocolVersion { - version: PG_PROTOCOL_LATEST, - options: &unsupported, - }) - .await?; + stream.write_message(BeMessage::NegotiateProtocolVersion { + version: PG_PROTOCOL_LATEST, + options: &unsupported, + }); + stream.flush().await?; info!( ?version, @@ -235,7 +219,7 @@ pub(crate) async fn handshake( session_type = "normal", "successful handshake; unsupported minor version requested" ); - break Ok(HandshakeData::Startup(stream, params)); + break Ok(HandshakeData::Startup(stream, supported)); } FeStartupPacket::StartupMessage { version, params } => { warn!( diff --git a/proxy/src/proxy/mod.rs b/proxy/src/proxy/mod.rs index 0a86022e78..26ac6a89e7 100644 --- a/proxy/src/proxy/mod.rs +++ b/proxy/src/proxy/mod.rs @@ -10,15 +10,14 @@ pub(crate) mod wake_compute; use std::sync::Arc; pub use copy_bidirectional::{ErrorSource, copy_bidirectional_client_compute}; -use futures::{FutureExt, TryFutureExt}; +use futures::FutureExt; use itertools::Itertools; use once_cell::sync::OnceCell; -use pq_proto::{BeMessage as Be, CancelKeyData, StartupMessageParams}; use regex::Regex; use serde::{Deserialize, Serialize}; use smol_str::{SmolStr, ToSmolStr, format_smolstr}; use thiserror::Error; -use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; +use tokio::io::{AsyncRead, AsyncWrite}; use tokio_util::sync::CancellationToken; use tracing::{Instrument, debug, error, info, warn}; @@ -27,8 +26,9 @@ use self::passthrough::ProxyPassthrough; use crate::cancellation::{self, CancellationHandler}; use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig}; use crate::context::RequestContext; -use crate::error::ReportableError; +use crate::error::{ReportableError, UserFacingError}; use crate::metrics::{Metrics, NumClientConnectionsGuard}; +use crate::pqproto::{BeMessage, CancelKeyData, StartupMessageParams}; use crate::protocol2::{ConnectHeader, ConnectionInfo, ConnectionInfoExtra, read_proxy_protocol}; use crate::proxy::handshake::{HandshakeData, handshake}; use crate::rate_limiter::EndpointRateLimiter; @@ -38,6 +38,18 @@ use crate::{auth, compute}; const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)"; +#[derive(Error, Debug)] +#[error("{ERR_INSECURE_CONNECTION}")] +pub struct TlsRequired; + +impl ReportableError for TlsRequired { + fn get_error_kind(&self) -> crate::error::ErrorKind { + crate::error::ErrorKind::User + } +} + +impl UserFacingError for TlsRequired {} + pub async fn run_until_cancelled( f: F, cancellation_token: &CancellationToken, @@ -329,7 +341,7 @@ pub(crate) async fn handle_client( let user_info = match result { Ok(user_info) => user_info, - Err(e) => stream.throw_error(e, Some(ctx)).await?, + Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?, }; let user = user_info.get_user().to_owned(); @@ -349,10 +361,10 @@ pub(crate) async fn handle_client( let app = params.get("application_name"); let params_span = tracing::info_span!("", ?user, ?db, ?app); - return stream + return Err(stream .throw_error(e, Some(ctx)) .instrument(params_span) - .await?; + .await)?; } }; @@ -365,7 +377,7 @@ pub(crate) async fn handle_client( .get(NeonOptions::PARAMS_COMPAT) .is_some(); - let mut node = connect_to_compute( + let res = connect_to_compute( ctx, &TcpMechanism { user_info: compute_user_info.clone(), @@ -377,22 +389,19 @@ pub(crate) async fn handle_client( config.wake_compute_retry_config, &config.connect_to_compute, ) - .or_else(|e| stream.throw_error(e, Some(ctx))) - .await?; + .await; + + let node = match res { + Ok(node) => node, + Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?, + }; let cancellation_handler_clone = Arc::clone(&cancellation_handler); let session = cancellation_handler_clone.get_key(); session.write_cancel_key(node.cancel_closure.clone())?; - - prepare_client_connection(&node, *session.key(), &mut stream).await?; - - // Before proxy passing, forward to compute whatever data is left in the - // PqStream input buffer. Normally there is none, but our serverless npm - // driver in pipeline mode sends startup, password and first query - // immediately after opening the connection. - let (stream, read_buf) = stream.into_inner(); - node.stream.write_all(&read_buf).await?; + prepare_client_connection(&node, *session.key(), &mut stream); + let stream = stream.flush_and_into_inner().await?; let private_link_id = match ctx.extra() { Some(ConnectionInfoExtra::Aws { vpce_id }) => Some(vpce_id.clone()), @@ -413,31 +422,28 @@ pub(crate) async fn handle_client( } /// Finish client connection initialization: confirm auth success, send params, etc. -#[tracing::instrument(skip_all)] -pub(crate) async fn prepare_client_connection( +pub(crate) fn prepare_client_connection( node: &compute::PostgresConnection, cancel_key_data: CancelKeyData, stream: &mut PqStream, -) -> Result<(), std::io::Error> { +) { // Forward all deferred notices to the client. for notice in &node.delayed_notice { - stream.write_message_noflush(&Be::Raw(b'N', notice.as_bytes()))?; + stream.write_raw(notice.as_bytes().len(), b'N', |buf| { + buf.extend_from_slice(notice.as_bytes()); + }); } // Forward all postgres connection params to the client. for (name, value) in &node.params { - stream.write_message_noflush(&Be::ParameterStatus { + stream.write_message(BeMessage::ParameterStatus { name: name.as_bytes(), value: value.as_bytes(), - })?; + }); } - stream - .write_message_noflush(&Be::BackendKeyData(cancel_key_data))? - .write_message(&Be::ReadyForQuery) - .await?; - - Ok(()) + stream.write_message(BeMessage::BackendKeyData(cancel_key_data)); + stream.write_message(BeMessage::ReadyForQuery); } #[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)] diff --git a/proxy/src/proxy/retry.rs b/proxy/src/proxy/retry.rs index 0879564ced..01e603ec14 100644 --- a/proxy/src/proxy/retry.rs +++ b/proxy/src/proxy/retry.rs @@ -125,9 +125,10 @@ pub(crate) fn retry_after(num_retries: u32, config: RetryConfig) -> time::Durati #[cfg(test)] mod tests { - use super::ShouldRetryWakeCompute; use postgres_client::error::{DbError, SqlState}; + use super::ShouldRetryWakeCompute; + #[test] fn should_retry_wake_compute_for_db_error() { // These SQLStates should NOT trigger a wake_compute retry. diff --git a/proxy/src/proxy/tests/mitm.rs b/proxy/src/proxy/tests/mitm.rs index 59c9ac27b8..c92ee49b8d 100644 --- a/proxy/src/proxy/tests/mitm.rs +++ b/proxy/src/proxy/tests/mitm.rs @@ -10,7 +10,7 @@ use bytes::{Bytes, BytesMut}; use futures::{SinkExt, StreamExt}; use postgres_client::tls::TlsConnect; use postgres_protocol::message::frontend; -use tokio::io::{AsyncReadExt, DuplexStream}; +use tokio::io::{AsyncReadExt, AsyncWriteExt, DuplexStream}; use tokio_util::codec::{Decoder, Encoder}; use super::*; @@ -49,15 +49,14 @@ async fn proxy_mitm( }; let mut end_server = tokio_util::codec::Framed::new(end_server, PgFrame); - let (end_client, buf) = end_client.framed.into_inner(); - assert!(buf.is_empty()); + let end_client = end_client.flush_and_into_inner().await.unwrap(); let mut end_client = tokio_util::codec::Framed::new(end_client, PgFrame); // give the end_server the startup parameters let mut buf = BytesMut::new(); frontend::startup_message( &postgres_protocol::message::frontend::StartupMessageParams { - params: startup.params.into(), + params: startup.params.as_bytes().into(), }, &mut buf, ) diff --git a/proxy/src/proxy/tests/mod.rs b/proxy/src/proxy/tests/mod.rs index be6426a63c..3cc053e0ad 100644 --- a/proxy/src/proxy/tests/mod.rs +++ b/proxy/src/proxy/tests/mod.rs @@ -128,7 +128,7 @@ trait TestAuth: Sized { self, stream: &mut PqStream>, ) -> anyhow::Result<()> { - stream.write_message_noflush(&Be::AuthenticationOk)?; + stream.write_message(BeMessage::AuthenticationOk); Ok(()) } } @@ -157,9 +157,7 @@ impl TestAuth for Scram { self, stream: &mut PqStream>, ) -> anyhow::Result<()> { - let outcome = auth::AuthFlow::new(stream) - .begin(auth::Scram(&self.0, &RequestContext::test())) - .await? + let outcome = auth::AuthFlow::new(stream, auth::Scram(&self.0, &RequestContext::test())) .authenticate() .await?; @@ -185,10 +183,12 @@ async fn dummy_proxy( auth.authenticate(&mut stream).await?; - stream - .write_message_noflush(&Be::CLIENT_ENCODING)? - .write_message(&Be::ReadyForQuery) - .await?; + stream.write_message(BeMessage::ParameterStatus { + name: b"client_encoding", + value: b"UTF8", + }); + stream.write_message(BeMessage::ReadyForQuery); + stream.flush().await?; Ok(()) } diff --git a/proxy/src/redis/cancellation_publisher.rs b/proxy/src/redis/cancellation_publisher.rs index 186fece4b2..6f56aeea06 100644 --- a/proxy/src/redis/cancellation_publisher.rs +++ b/proxy/src/redis/cancellation_publisher.rs @@ -1,10 +1,11 @@ use core::net::IpAddr; use std::sync::Arc; -use pq_proto::CancelKeyData; use tokio::sync::Mutex; use uuid::Uuid; +use crate::pqproto::CancelKeyData; + pub trait CancellationPublisherMut: Send + Sync + 'static { #[allow(async_fn_in_trait)] async fn try_publish( diff --git a/proxy/src/redis/keys.rs b/proxy/src/redis/keys.rs index 7527bca6d0..3113bad949 100644 --- a/proxy/src/redis/keys.rs +++ b/proxy/src/redis/keys.rs @@ -1,16 +1,15 @@ use std::io::ErrorKind; use anyhow::Ok; -use pq_proto::{CancelKeyData, id_to_cancel_key}; -use serde::{Deserialize, Serialize}; + +use crate::pqproto::{CancelKeyData, id_to_cancel_key}; pub mod keyspace { pub const CANCEL_PREFIX: &str = "cancel"; } -#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] +#[derive(Clone, Debug, Eq, PartialEq)] pub(crate) enum KeyPrefix { - #[serde(untagged)] Cancel(CancelKeyData), } @@ -18,9 +17,7 @@ impl KeyPrefix { pub(crate) fn build_redis_key(&self) -> String { match self { KeyPrefix::Cancel(key) => { - let hi = (key.backend_pid as u64) << 32; - let lo = (key.cancel_key as u64) & 0xffff_ffff; - let id = hi | lo; + let id = key.0.get(); let keyspace = keyspace::CANCEL_PREFIX; format!("{keyspace}:{id:x}") } @@ -63,10 +60,7 @@ mod tests { #[test] fn test_build_redis_key() { - let cancel_key: KeyPrefix = KeyPrefix::Cancel(CancelKeyData { - backend_pid: 12345, - cancel_key: 54321, - }); + let cancel_key: KeyPrefix = KeyPrefix::Cancel(id_to_cancel_key(12345 << 32 | 54321)); let redis_key = cancel_key.build_redis_key(); assert_eq!(redis_key, "cancel:30390000d431"); @@ -77,10 +71,7 @@ mod tests { let redis_key = "cancel:30390000d431"; let key: KeyPrefix = parse_redis_key(redis_key).expect("Failed to parse key"); - let ref_key = CancelKeyData { - backend_pid: 12345, - cancel_key: 54321, - }; + let ref_key = id_to_cancel_key(12345 << 32 | 54321); assert_eq!(key.as_str(), KeyPrefix::Cancel(ref_key).as_str()); let KeyPrefix::Cancel(cancel_key) = key; diff --git a/proxy/src/redis/notifications.rs b/proxy/src/redis/notifications.rs index 5f9f2509e2..769d519d94 100644 --- a/proxy/src/redis/notifications.rs +++ b/proxy/src/redis/notifications.rs @@ -2,11 +2,9 @@ use std::convert::Infallible; use std::sync::Arc; use futures::StreamExt; -use pq_proto::CancelKeyData; use redis::aio::PubSub; use serde::{Deserialize, Serialize}; use tokio_util::sync::CancellationToken; -use uuid::Uuid; use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider; use crate::cache::project_info::ProjectInfoCache; @@ -100,14 +98,6 @@ pub(crate) struct PasswordUpdate { role_name: RoleNameInt, } -#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] -pub(crate) struct CancelSession { - pub(crate) region_id: Option, - pub(crate) cancel_key_data: CancelKeyData, - pub(crate) session_id: Uuid, - pub(crate) peer_addr: Option, -} - fn deserialize_json_string<'de, D, T>(deserializer: D) -> Result where T: for<'de2> serde::Deserialize<'de2>, diff --git a/proxy/src/sasl/messages.rs b/proxy/src/sasl/messages.rs index 7f2f3a761c..8d26a3f453 100644 --- a/proxy/src/sasl/messages.rs +++ b/proxy/src/sasl/messages.rs @@ -1,7 +1,5 @@ //! Definitions for SASL messages. -use pq_proto::{BeAuthenticationSaslMessage, BeMessage}; - use crate::parse::split_cstr; /// SASL-specific payload of [`PasswordMessage`](pq_proto::FeMessage::PasswordMessage). @@ -30,26 +28,6 @@ impl<'a> FirstMessage<'a> { } } -/// A single SASL message. -/// This struct is deliberately decoupled from lower-level -/// [`BeAuthenticationSaslMessage`]. -#[derive(Debug)] -pub(super) enum ServerMessage { - /// We expect to see more steps. - Continue(T), - /// This is the final step. - Final(T), -} - -impl<'a> ServerMessage<&'a str> { - pub(super) fn to_reply(&self) -> BeMessage<'a> { - BeMessage::AuthenticationSasl(match self { - ServerMessage::Continue(s) => BeAuthenticationSaslMessage::Continue(s.as_bytes()), - ServerMessage::Final(s) => BeAuthenticationSaslMessage::Final(s.as_bytes()), - }) - } -} - #[cfg(test)] mod tests { use super::*; diff --git a/proxy/src/sasl/mod.rs b/proxy/src/sasl/mod.rs index f0181b404f..007b62dfd2 100644 --- a/proxy/src/sasl/mod.rs +++ b/proxy/src/sasl/mod.rs @@ -14,7 +14,7 @@ use std::io; pub(crate) use channel_binding::ChannelBinding; pub(crate) use messages::FirstMessage; -pub(crate) use stream::{Outcome, SaslStream}; +pub(crate) use stream::{Outcome, authenticate}; use thiserror::Error; use crate::error::{ReportableError, UserFacingError}; @@ -22,6 +22,9 @@ use crate::error::{ReportableError, UserFacingError}; /// Fine-grained auth errors help in writing tests. #[derive(Error, Debug)] pub(crate) enum Error { + #[error("Unsupported authentication method: {0}")] + BadAuthMethod(Box), + #[error("Channel binding failed: {0}")] ChannelBindingFailed(&'static str), @@ -54,6 +57,7 @@ impl UserFacingError for Error { impl ReportableError for Error { fn get_error_kind(&self) -> crate::error::ErrorKind { match self { + Error::BadAuthMethod(_) => crate::error::ErrorKind::User, Error::ChannelBindingFailed(_) => crate::error::ErrorKind::User, Error::ChannelBindingBadMethod(_) => crate::error::ErrorKind::User, Error::BadClientMessage(_) => crate::error::ErrorKind::User, diff --git a/proxy/src/sasl/stream.rs b/proxy/src/sasl/stream.rs index 46e6a439e5..cb15132673 100644 --- a/proxy/src/sasl/stream.rs +++ b/proxy/src/sasl/stream.rs @@ -3,61 +3,12 @@ use std::io; use tokio::io::{AsyncRead, AsyncWrite}; -use tracing::info; -use super::Mechanism; -use super::messages::ServerMessage; +use super::{Mechanism, Step}; +use crate::context::RequestContext; +use crate::pqproto::{BeAuthenticationSaslMessage, BeMessage}; use crate::stream::PqStream; -/// Abstracts away all peculiarities of the libpq's protocol. -pub(crate) struct SaslStream<'a, S> { - /// The underlying stream. - stream: &'a mut PqStream, - /// Current password message we received from client. - current: bytes::Bytes, - /// First SASL message produced by client. - first: Option<&'a str>, -} - -impl<'a, S> SaslStream<'a, S> { - pub(crate) fn new(stream: &'a mut PqStream, first: &'a str) -> Self { - Self { - stream, - current: bytes::Bytes::new(), - first: Some(first), - } - } -} - -impl SaslStream<'_, S> { - // Receive a new SASL message from the client. - async fn recv(&mut self) -> io::Result<&str> { - if let Some(first) = self.first.take() { - return Ok(first); - } - - self.current = self.stream.read_password_message().await?; - let s = std::str::from_utf8(&self.current) - .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "bad encoding"))?; - - Ok(s) - } -} - -impl SaslStream<'_, S> { - // Send a SASL message to the client. - async fn send(&mut self, msg: &ServerMessage<&str>) -> io::Result<()> { - self.stream.write_message(&msg.to_reply()).await?; - Ok(()) - } - - // Queue a SASL message for the client. - fn send_noflush(&mut self, msg: &ServerMessage<&str>) -> io::Result<()> { - self.stream.write_message_noflush(&msg.to_reply())?; - Ok(()) - } -} - /// SASL authentication outcome. /// It's much easier to match on those two variants /// than to peek into a noisy protocol error type. @@ -69,33 +20,62 @@ pub(crate) enum Outcome { Failure(&'static str), } -impl SaslStream<'_, S> { - /// Perform SASL message exchange according to the underlying algorithm - /// until user is either authenticated or denied access. - pub(crate) async fn authenticate( - mut self, - mut mechanism: M, - ) -> super::Result> { - loop { - let input = self.recv().await?; - let step = mechanism.exchange(input).map_err(|error| { - info!(?error, "error during SASL exchange"); - error - })?; +pub async fn authenticate( + ctx: &RequestContext, + stream: &mut PqStream, + mechanism: F, +) -> super::Result> +where + S: AsyncRead + AsyncWrite + Unpin, + F: FnOnce(&str) -> super::Result, + M: Mechanism, +{ + let sasl = { + // pause the timer while we communicate with the client + let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client); - use super::Step; - return Ok(match step { - Step::Continue(moved_mechanism, reply) => { - self.send(&ServerMessage::Continue(&reply)).await?; - mechanism = moved_mechanism; - continue; - } - Step::Success(result, reply) => { - self.send_noflush(&ServerMessage::Final(&reply))?; - Outcome::Success(result) - } - Step::Failure(reason) => Outcome::Failure(reason), - }); + // Initial client message contains the chosen auth method's name. + let msg = stream.read_password_message().await?; + super::FirstMessage::parse(msg).ok_or(super::Error::BadClientMessage("bad sasl message"))? + }; + + let mut mechanism = mechanism(sasl.method)?; + let mut input = sasl.message; + loop { + let step = mechanism + .exchange(input) + .inspect_err(|error| tracing::info!(?error, "error during SASL exchange"))?; + + match step { + Step::Continue(moved_mechanism, reply) => { + mechanism = moved_mechanism; + + // pause the timer while we communicate with the client + let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client); + + // write reply + let sasl_msg = BeAuthenticationSaslMessage::Continue(reply.as_bytes()); + stream.write_message(BeMessage::AuthenticationSasl(sasl_msg)); + + // get next input + stream.flush().await?; + let msg = stream.read_password_message().await?; + input = std::str::from_utf8(msg) + .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "bad encoding"))?; + } + Step::Success(result, reply) => { + // pause the timer while we communicate with the client + let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client); + + // write reply + let sasl_msg = BeAuthenticationSaslMessage::Final(reply.as_bytes()); + stream.write_message(BeMessage::AuthenticationSasl(sasl_msg)); + stream.write_message(BeMessage::AuthenticationOk); + // exit with success + break Ok(Outcome::Success(result)); + } + // exit with failure + Step::Failure(reason) => break Ok(Outcome::Failure(reason)), } } } diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index 1c5bb64480..eb80ac9ad0 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -17,7 +17,6 @@ use postgres_client::error::{DbError, ErrorPosition, SqlState}; use postgres_client::{ GenericClient, IsolationLevel, NoTls, ReadyForQueryStatus, RowStream, Transaction, }; -use pq_proto::StartupMessageParamsBuilder; use serde::Serialize; use serde_json::Value; use serde_json::value::RawValue; @@ -41,6 +40,7 @@ use crate::context::RequestContext; use crate::error::{ErrorKind, ReportableError, UserFacingError}; use crate::http::{ReadBodyError, read_body_with_limit}; use crate::metrics::{HttpDirection, Metrics, SniGroup, SniKind}; +use crate::pqproto::StartupMessageParams; use crate::proxy::{NeonOptions, run_until_cancelled}; use crate::serverless::backend::HttpConnError; use crate::types::{DbName, RoleName}; @@ -219,7 +219,7 @@ fn get_conn_info( let mut options = Option::None; - let mut params = StartupMessageParamsBuilder::default(); + let mut params = StartupMessageParams::default(); params.insert("user", &username); params.insert("database", &dbname); for (key, value) in pairs { diff --git a/proxy/src/stream.rs b/proxy/src/stream.rs index 360550b0ac..7126430a85 100644 --- a/proxy/src/stream.rs +++ b/proxy/src/stream.rs @@ -2,19 +2,17 @@ use std::pin::Pin; use std::sync::Arc; use std::{io, task}; -use bytes::BytesMut; -use pq_proto::framed::{ConnectionError, Framed}; -use pq_proto::{BeMessage, FeMessage, FeStartupPacket, ProtocolError}; use rustls::ServerConfig; -use serde::{Deserialize, Serialize}; use thiserror::Error; -use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf}; use tokio_rustls::server::TlsStream; -use tracing::debug; -use crate::control_plane::messages::ColdStartInfo; use crate::error::{ErrorKind, ReportableError, UserFacingError}; use crate::metrics::Metrics; +use crate::pqproto::{ + BeMessage, FE_PASSWORD_MESSAGE, FeStartupPacket, SQLSTATE_INTERNAL_ERROR, WriteBuf, + read_message, read_startup, +}; use crate::tls::TlsServerEndPoint; /// Stream wrapper which implements libpq's protocol. @@ -23,58 +21,77 @@ use crate::tls::TlsServerEndPoint; /// or [`AsyncWrite`] to prevent subtle errors (e.g. trying /// to pass random malformed bytes through the connection). pub struct PqStream { - pub(crate) framed: Framed, + stream: S, + read: Vec, + write: WriteBuf, } impl PqStream { - /// Construct a new libpq protocol wrapper. - pub fn new(stream: S) -> Self { + pub fn get_ref(&self) -> &S { + &self.stream + } + + /// Construct a new libpq protocol wrapper over a stream without the first startup message. + #[cfg(test)] + pub fn new_skip_handshake(stream: S) -> Self { Self { - framed: Framed::new(stream), + stream, + read: Vec::new(), + write: WriteBuf::new(), } } - - /// Extract the underlying stream and read buffer. - pub fn into_inner(self) -> (S, BytesMut) { - self.framed.into_inner() - } - - /// Get a shared reference to the underlying stream. - pub(crate) fn get_ref(&self) -> &S { - self.framed.get_ref() - } } -fn err_connection() -> io::Error { - io::Error::new(io::ErrorKind::ConnectionAborted, "connection is lost") +impl PqStream { + /// Construct a new libpq protocol wrapper and read the first startup message. + /// + /// This is not cancel safe. + pub async fn parse_startup(mut stream: S) -> io::Result<(Self, FeStartupPacket)> { + let startup = read_startup(&mut stream).await?; + Ok(( + Self { + stream, + read: Vec::new(), + write: WriteBuf::new(), + }, + startup, + )) + } + + /// Tell the client that encryption is not supported. + /// + /// This is not cancel safe + pub async fn reject_encryption(&mut self) -> io::Result { + // N for No. + self.write.encryption(b'N'); + self.flush().await?; + read_startup(&mut self.stream).await + } } impl PqStream { - /// Receive [`FeStartupPacket`], which is a first packet sent by a client. - pub async fn read_startup_packet(&mut self) -> io::Result { - self.framed - .read_startup_message() - .await - .map_err(ConnectionError::into_io_error)? - .ok_or_else(err_connection) - } - - async fn read_message(&mut self) -> io::Result { - self.framed - .read_message() - .await - .map_err(ConnectionError::into_io_error)? - .ok_or_else(err_connection) - } - - pub(crate) async fn read_password_message(&mut self) -> io::Result { - match self.read_message().await? { - FeMessage::PasswordMessage(msg) => Ok(msg), - bad => Err(io::Error::new( - io::ErrorKind::InvalidData, - format!("unexpected message type: {bad:?}"), - )), + /// Read a raw postgres packet, which will respect the max length requested. + /// This is not cancel safe. + async fn read_raw_expect(&mut self, tag: u8, max: usize) -> io::Result<&mut [u8]> { + let (actual_tag, msg) = read_message(&mut self.stream, &mut self.read, max).await?; + if actual_tag != tag { + return Err(io::Error::other(format!( + "incorrect message tag, expected {:?}, got {:?}", + tag as char, actual_tag as char, + ))); } + Ok(msg) + } + + /// Read a postgres password message, which will respect the max length requested. + /// This is not cancel safe. + pub async fn read_password_message(&mut self) -> io::Result<&mut [u8]> { + // passwords are usually pretty short + // and SASL SCRAM messages are no longer than 256 bytes in my testing + // (a few hashes and random bytes, encoded into base64). + const MAX_PASSWORD_LENGTH: usize = 512; + self.read_raw_expect(FE_PASSWORD_MESSAGE, MAX_PASSWORD_LENGTH) + .await } } @@ -84,6 +101,16 @@ pub struct ReportedError { error_kind: ErrorKind, } +impl ReportedError { + pub fn new(e: (impl UserFacingError + Into)) -> Self { + let error_kind = e.get_error_kind(); + Self { + source: e.into(), + error_kind, + } + } +} + impl std::fmt::Display for ReportedError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { self.source.fmt(f) @@ -102,109 +129,65 @@ impl ReportableError for ReportedError { } } -#[derive(Serialize, Deserialize, Debug)] -enum ErrorTag { - #[serde(rename = "proxy")] - Proxy, - #[serde(rename = "compute")] - Compute, - #[serde(rename = "client")] - Client, - #[serde(rename = "controlplane")] - ControlPlane, - #[serde(rename = "other")] - Other, -} - -impl From for ErrorTag { - fn from(error_kind: ErrorKind) -> Self { - match error_kind { - ErrorKind::User => Self::Client, - ErrorKind::ClientDisconnect => Self::Client, - ErrorKind::RateLimit => Self::Proxy, - ErrorKind::ServiceRateLimit => Self::Proxy, // considering rate limit as proxy error for SLI - ErrorKind::Quota => Self::Proxy, - ErrorKind::Service => Self::Proxy, - ErrorKind::ControlPlane => Self::ControlPlane, - ErrorKind::Postgres => Self::Other, - ErrorKind::Compute => Self::Compute, - } - } -} - -#[derive(Serialize, Deserialize, Debug)] -#[serde(rename_all = "snake_case")] -struct ProbeErrorData { - tag: ErrorTag, - msg: String, - cold_start_info: Option, -} - impl PqStream { - /// Write the message into an internal buffer, but don't flush the underlying stream. - pub(crate) fn write_message_noflush( - &mut self, - message: &BeMessage<'_>, - ) -> io::Result<&mut Self> { - self.framed - .write_message(message) - .map_err(ProtocolError::into_io_error)?; - Ok(self) + /// Tell the client that we are willing to accept SSL. + /// This is not cancel safe + pub async fn accept_tls(mut self) -> io::Result { + // S for SSL. + self.write.encryption(b'S'); + self.flush().await?; + Ok(self.stream) } - /// Write the message into an internal buffer and flush it. - pub async fn write_message(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> { - self.write_message_noflush(message)?; - self.flush().await?; - Ok(self) + /// Assert that we are using direct TLS. + pub fn accept_direct_tls(self) -> S { + self.stream + } + + /// Write a raw message to the internal buffer. + pub fn write_raw(&mut self, size_hint: usize, tag: u8, f: impl FnOnce(&mut Vec)) { + self.write.write_raw(size_hint, tag, f); + } + + /// Write the message into an internal buffer + pub fn write_message(&mut self, message: BeMessage<'_>) { + message.write_message(&mut self.write); } /// Flush the output buffer into the underlying stream. - pub(crate) async fn flush(&mut self) -> io::Result<&mut Self> { - self.framed.flush().await?; - Ok(self) + /// + /// This is cancel safe. + pub async fn flush(&mut self) -> io::Result<()> { + self.stream.write_all_buf(&mut self.write).await?; + self.write.reset(); + + self.stream.flush().await?; + + Ok(()) } - /// Writes message with the given error kind to the stream. - /// Used only for probe queries - async fn write_format_message( - &mut self, - msg: &str, - error_kind: ErrorKind, - ctx: Option<&crate::context::RequestContext>, - ) -> String { - let formatted_msg = match ctx { - Some(ctx) if ctx.get_testodrome_id().is_some() => { - serde_json::to_string(&ProbeErrorData { - tag: ErrorTag::from(error_kind), - msg: msg.to_string(), - cold_start_info: Some(ctx.cold_start_info()), - }) - .unwrap_or_default() - } - _ => msg.to_string(), - }; - - // already error case, ignore client IO error - self.write_message(&BeMessage::ErrorResponse(&formatted_msg, None)) - .await - .inspect_err(|e| debug!("write_message failed: {e}")) - .ok(); - - formatted_msg + /// Flush the output buffer into the underlying stream. + /// + /// This is cancel safe. + pub async fn flush_and_into_inner(mut self) -> io::Result { + self.flush().await?; + Ok(self.stream) } - /// Write the error message using [`Self::write_format_message`], then re-throw it. - /// Allowing string literals is safe under the assumption they might not contain any runtime info. - /// This method exists due to `&str` not implementing `Into`. + /// Write the error message to the client, then re-throw it. + /// + /// Trait [`UserFacingError`] acts as an allowlist for error types. /// If `ctx` is provided and has testodrome_id set, error messages will be prefixed according to error kind. - pub async fn throw_error_str( + pub(crate) async fn throw_error( &mut self, - msg: &'static str, - error_kind: ErrorKind, + error: E, ctx: Option<&crate::context::RequestContext>, - ) -> Result { - self.write_format_message(msg, error_kind, ctx).await; + ) -> ReportedError + where + E: UserFacingError + Into, + { + let error_kind = error.get_error_kind(); + let msg = error.to_string_client(); if error_kind != ErrorKind::RateLimit && error_kind != ErrorKind::User { tracing::info!( @@ -214,39 +197,39 @@ impl PqStream { ); } - Err(ReportedError { - source: anyhow::anyhow!(msg), - error_kind, - }) - } - - /// Write the error message using [`Self::write_format_message`], then re-throw it. - /// Trait [`UserFacingError`] acts as an allowlist for error types. - /// If `ctx` is provided and has testodrome_id set, error messages will be prefixed according to error kind. - pub(crate) async fn throw_error( - &mut self, - error: E, - ctx: Option<&crate::context::RequestContext>, - ) -> Result - where - E: UserFacingError + Into, - { - let error_kind = error.get_error_kind(); - let msg = error.to_string_client(); - self.write_format_message(&msg, error_kind, ctx).await; - if error_kind != ErrorKind::RateLimit && error_kind != ErrorKind::User { - tracing::info!( - kind=error_kind.to_metric_label(), - error=%error, - msg, - "forwarding error to user", - ); + let probe_msg; + let mut msg = &*msg; + if let Some(ctx) = ctx { + if ctx.get_testodrome_id().is_some() { + let tag = match error_kind { + ErrorKind::User => "client", + ErrorKind::ClientDisconnect => "client", + ErrorKind::RateLimit => "proxy", + ErrorKind::ServiceRateLimit => "proxy", + ErrorKind::Quota => "proxy", + ErrorKind::Service => "proxy", + ErrorKind::ControlPlane => "controlplane", + ErrorKind::Postgres => "other", + ErrorKind::Compute => "compute", + }; + probe_msg = typed_json::json!({ + "tag": tag, + "msg": msg, + "cold_start_info": ctx.cold_start_info(), + }) + .to_string(); + msg = &probe_msg; + } } - Err(ReportedError { - source: anyhow::anyhow!(error), - error_kind, - }) + // TODO: either preserve the error code from postgres, or assign error codes to proxy errors. + self.write.write_error(msg, SQLSTATE_INTERNAL_ERROR); + + self.flush() + .await + .unwrap_or_else(|e| tracing::debug!("write_message failed: {e}")); + + ReportedError::new(error) } } From 589bfdfd02c575b172a138ee5d174777da17e18a Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Mon, 2 Jun 2025 09:38:35 +0100 Subject: [PATCH 39/48] proxy: Changes to rate limits and GetEndpointAccessControl caches. (#12048) Precursor to https://github.com/neondatabase/cloud/issues/28333. We want per-endpoint configuration for rate limits, which will be distributed via the `GetEndpointAccessControl` API. This lays some of the ground work. 1. Allow the endpoint rate limiter to accept a custom leaky bucket config on check. 2. Remove the unused auth rate limiter, as I don't want to think about how it fits into this. 3. Refactor the caching of `GetEndpointAccessControl`, as it adds friction for adding new cached data to the API. That third one was rather large. I couldn't find any way to split it up. The core idea is that there's now only 2 cache APIs. `get_endpoint_access_controls` and `get_role_access_controls`. I'm pretty sure the behaviour is unchanged, except I did a drive by change to fix #8989 because it felt harmless. The change in question is that when a password validation fails, we eagerly expire the role cache if the role was cached for 5 minutes. This is to allow for edge cases where a user tries to connect with a reset password, but the cache never expires the entry due to some redis related quirk (lag, or misconfiguration, or cplane error) --- libs/metrics/src/hll.rs | 2 +- libs/utils/src/leaky_bucket.rs | 1 + proxy/src/auth/backend/mod.rs | 335 +++------ proxy/src/binary/local_proxy.rs | 16 +- proxy/src/binary/proxy.rs | 35 +- proxy/src/cache/project_info.rs | 678 +++++------------- proxy/src/cancellation.rs | 71 +- proxy/src/config.rs | 4 - proxy/src/context/mod.rs | 12 + .../control_plane/client/cplane_proxy_v1.rs | 319 +++----- proxy/src/control_plane/client/mock.rs | 74 +- proxy/src/control_plane/client/mod.rs | 75 +- proxy/src/control_plane/errors.rs | 8 + proxy/src/control_plane/mod.rs | 95 ++- proxy/src/proxy/mod.rs | 2 +- proxy/src/proxy/tests/mod.rs | 19 +- proxy/src/rate_limiter/leaky_bucket.rs | 10 +- proxy/src/rate_limiter/limiter.rs | 30 +- proxy/src/rate_limiter/mod.rs | 2 +- proxy/src/redis/notifications.rs | 41 +- proxy/src/serverless/backend.rs | 70 +- 21 files changed, 551 insertions(+), 1348 deletions(-) diff --git a/libs/metrics/src/hll.rs b/libs/metrics/src/hll.rs index 93f6a2b7cc..1a7d7a7e44 100644 --- a/libs/metrics/src/hll.rs +++ b/libs/metrics/src/hll.rs @@ -107,7 +107,7 @@ impl MetricType for HyperLogLogState { } impl HyperLogLogState { - pub fn measure(&self, item: &impl Hash) { + pub fn measure(&self, item: &(impl Hash + ?Sized)) { // changing the hasher will break compatibility with previous measurements. self.record(BuildHasherDefault::::default().hash_one(item)); } diff --git a/libs/utils/src/leaky_bucket.rs b/libs/utils/src/leaky_bucket.rs index 2398f92766..17e96bd0a9 100644 --- a/libs/utils/src/leaky_bucket.rs +++ b/libs/utils/src/leaky_bucket.rs @@ -28,6 +28,7 @@ use std::time::Duration; use tokio::sync::Notify; use tokio::time::Instant; +#[derive(Clone, Copy)] pub struct LeakyBucketConfig { /// This is the "time cost" of a single request unit. /// Should loosely represent how long it takes to handle a request unit in active resource time. diff --git a/proxy/src/auth/backend/mod.rs b/proxy/src/auth/backend/mod.rs index 8c892d90a0..735cb52f47 100644 --- a/proxy/src/auth/backend/mod.rs +++ b/proxy/src/auth/backend/mod.rs @@ -4,38 +4,31 @@ mod hacks; pub mod jwt; pub mod local; -use std::net::IpAddr; use std::sync::Arc; pub use console_redirect::ConsoleRedirectBackend; pub(crate) use console_redirect::ConsoleRedirectError; -use ipnet::{Ipv4Net, Ipv6Net}; use local::LocalBackend; use postgres_client::config::AuthKeys; use serde::{Deserialize, Serialize}; use tokio::io::{AsyncRead, AsyncWrite}; -use tracing::{debug, info, warn}; +use tracing::{debug, info}; -use crate::auth::credentials::check_peer_addr_is_in_list; -use crate::auth::{ - self, AuthError, ComputeUserInfoMaybeEndpoint, IpPattern, validate_password_and_exchange, -}; +use crate::auth::{self, AuthError, ComputeUserInfoMaybeEndpoint, validate_password_and_exchange}; use crate::cache::Cached; use crate::config::AuthenticationConfig; use crate::context::RequestContext; use crate::control_plane::client::ControlPlaneClient; use crate::control_plane::errors::GetAuthInfoError; use crate::control_plane::{ - self, AccessBlockerFlags, AuthSecret, CachedAccessBlockerFlags, CachedAllowedIps, - CachedAllowedVpcEndpointIds, CachedNodeInfo, CachedRoleSecret, ControlPlaneApi, + self, AccessBlockerFlags, AuthSecret, CachedNodeInfo, ControlPlaneApi, EndpointAccessControl, + RoleAccessControl, }; use crate::intern::EndpointIdInt; -use crate::metrics::Metrics; use crate::pqproto::BeMessage; -use crate::protocol2::ConnectionInfoExtra; use crate::proxy::NeonOptions; use crate::proxy::connect_compute::ComputeConnectBackend; -use crate::rate_limiter::{BucketRateLimiter, EndpointRateLimiter}; +use crate::rate_limiter::EndpointRateLimiter; use crate::stream::Stream; use crate::types::{EndpointCacheKey, EndpointId, RoleName}; use crate::{scram, stream}; @@ -201,78 +194,6 @@ impl TryFrom for ComputeUserInfo { } } -#[derive(PartialEq, PartialOrd, Hash, Eq, Ord, Debug, Copy, Clone)] -pub struct MaskedIp(IpAddr); - -impl MaskedIp { - fn new(value: IpAddr, prefix: u8) -> Self { - match value { - IpAddr::V4(v4) => Self(IpAddr::V4( - Ipv4Net::new(v4, prefix).map_or(v4, |x| x.trunc().addr()), - )), - IpAddr::V6(v6) => Self(IpAddr::V6( - Ipv6Net::new(v6, prefix).map_or(v6, |x| x.trunc().addr()), - )), - } - } -} - -// This can't be just per IP because that would limit some PaaS that share IP addresses -pub type AuthRateLimiter = BucketRateLimiter<(EndpointIdInt, MaskedIp)>; - -impl AuthenticationConfig { - pub(crate) fn check_rate_limit( - &self, - ctx: &RequestContext, - secret: AuthSecret, - endpoint: &EndpointId, - is_cleartext: bool, - ) -> auth::Result { - // we have validated the endpoint exists, so let's intern it. - let endpoint_int = EndpointIdInt::from(endpoint.normalize()); - - // only count the full hash count if password hack or websocket flow. - // in other words, if proxy needs to run the hashing - let password_weight = if is_cleartext { - match &secret { - #[cfg(any(test, feature = "testing"))] - AuthSecret::Md5(_) => 1, - AuthSecret::Scram(s) => s.iterations + 1, - } - } else { - // validating scram takes just 1 hmac_sha_256 operation. - 1 - }; - - let limit_not_exceeded = self.rate_limiter.check( - ( - endpoint_int, - MaskedIp::new(ctx.peer_addr(), self.rate_limit_ip_subnet), - ), - password_weight, - ); - - if !limit_not_exceeded { - warn!( - enabled = self.rate_limiter_enabled, - "rate limiting authentication" - ); - Metrics::get().proxy.requests_auth_rate_limits_total.inc(); - Metrics::get() - .proxy - .endpoints_auth_rate_limits - .get_metric() - .measure(endpoint); - - if self.rate_limiter_enabled { - return Err(auth::AuthError::too_many_connections()); - } - } - - Ok(secret) - } -} - /// True to its name, this function encapsulates our current auth trade-offs. /// Here, we choose the appropriate auth flow based on circumstances. /// @@ -285,7 +206,7 @@ async fn auth_quirks( allow_cleartext: bool, config: &'static AuthenticationConfig, endpoint_rate_limiter: Arc, -) -> auth::Result<(ComputeCredentials, Option>)> { +) -> auth::Result { // If there's no project so far, that entails that client doesn't // support SNI or other means of passing the endpoint (project) name. // We now expect to see a very specific payload in the place of password. @@ -301,55 +222,27 @@ async fn auth_quirks( debug!("fetching authentication info and allowlists"); - // check allowed list - let allowed_ips = if config.ip_allowlist_check_enabled { - let allowed_ips = api.get_allowed_ips(ctx, &info).await?; - if !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips) { - return Err(auth::AuthError::ip_address_not_allowed(ctx.peer_addr())); - } - allowed_ips - } else { - Cached::new_uncached(Arc::new(vec![])) - }; + let access_controls = api + .get_endpoint_access_control(ctx, &info.endpoint, &info.user) + .await?; - // check if a VPC endpoint ID is coming in and if yes, if it's allowed - let access_blocks = api.get_block_public_or_vpc_access(ctx, &info).await?; - if config.is_vpc_acccess_proxy { - if access_blocks.vpc_access_blocked { - return Err(AuthError::NetworkNotAllowed); - } + access_controls.check( + ctx, + config.ip_allowlist_check_enabled, + config.is_vpc_acccess_proxy, + )?; - let incoming_vpc_endpoint_id = match ctx.extra() { - None => return Err(AuthError::MissingEndpointName), - Some(ConnectionInfoExtra::Aws { vpce_id }) => vpce_id.to_string(), - Some(ConnectionInfoExtra::Azure { link_id }) => link_id.to_string(), - }; - let allowed_vpc_endpoint_ids = api.get_allowed_vpc_endpoint_ids(ctx, &info).await?; - // TODO: For now an empty VPC endpoint ID list means all are allowed. We should replace that. - if !allowed_vpc_endpoint_ids.is_empty() - && !allowed_vpc_endpoint_ids.contains(&incoming_vpc_endpoint_id) - { - return Err(AuthError::vpc_endpoint_id_not_allowed( - incoming_vpc_endpoint_id, - )); - } - } else if access_blocks.public_access_blocked { - return Err(AuthError::NetworkNotAllowed); - } - - if !endpoint_rate_limiter.check(info.endpoint.clone().into(), 1) { + let endpoint = EndpointIdInt::from(&info.endpoint); + let rate_limit_config = None; + if !endpoint_rate_limiter.check(endpoint, rate_limit_config, 1) { return Err(AuthError::too_many_connections()); } - let cached_secret = api.get_role_secret(ctx, &info).await?; - let (cached_entry, secret) = cached_secret.take_value(); + let role_access = api + .get_role_access_control(ctx, &info.endpoint, &info.user) + .await?; - let secret = if let Some(secret) = secret { - config.check_rate_limit( - ctx, - secret, - &info.endpoint, - unauthenticated_password.is_some() || allow_cleartext, - )? + let secret = if let Some(secret) = role_access.secret { + secret } else { // If we don't have an authentication secret, we mock one to // prevent malicious probing (possible due to missing protocol steps). @@ -369,14 +262,8 @@ async fn auth_quirks( ) .await { - Ok(keys) => Ok((keys, Some(allowed_ips.as_ref().clone()))), - Err(e) => { - if e.is_password_failed() { - // The password could have been changed, so we invalidate the cache. - cached_entry.invalidate(); - } - Err(e) - } + Ok(keys) => Ok(keys), + Err(e) => Err(e), } } @@ -439,7 +326,7 @@ impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint> { allow_cleartext: bool, config: &'static AuthenticationConfig, endpoint_rate_limiter: Arc, - ) -> auth::Result<(Backend<'a, ComputeCredentials>, Option>)> { + ) -> auth::Result> { let res = match self { Self::ControlPlane(api, user_info) => { debug!( @@ -448,17 +335,35 @@ impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint> { "performing authentication using the console" ); - let (credentials, ip_allowlist) = auth_quirks( + let auth_res = auth_quirks( ctx, &*api, - user_info, + user_info.clone(), client, allow_cleartext, config, endpoint_rate_limiter, ) - .await?; - Ok((Backend::ControlPlane(api, credentials), ip_allowlist)) + .await; + match auth_res { + Ok(credentials) => Ok(Backend::ControlPlane(api, credentials)), + Err(e) => { + // The password could have been changed, so we invalidate the cache. + // We should only invalidate the cache if the TTL might have expired. + if e.is_password_failed() { + #[allow(irrefutable_let_patterns)] + if let ControlPlaneClient::ProxyV1(api) = &*api { + if let Some(ep) = &user_info.endpoint_id { + api.caches + .project_info + .maybe_invalidate_role_secret(ep, &user_info.user); + } + } + } + + Err(e) + } + } } Self::Local(_) => { return Err(auth::AuthError::bad_auth_method("invalid for local proxy")); @@ -475,44 +380,30 @@ impl Backend<'_, ComputeUserInfo> { pub(crate) async fn get_role_secret( &self, ctx: &RequestContext, - ) -> Result { - match self { - Self::ControlPlane(api, user_info) => api.get_role_secret(ctx, user_info).await, - Self::Local(_) => Ok(Cached::new_uncached(None)), - } - } - - pub(crate) async fn get_allowed_ips( - &self, - ctx: &RequestContext, - ) -> Result { - match self { - Self::ControlPlane(api, user_info) => api.get_allowed_ips(ctx, user_info).await, - Self::Local(_) => Ok(Cached::new_uncached(Arc::new(vec![]))), - } - } - - pub(crate) async fn get_allowed_vpc_endpoint_ids( - &self, - ctx: &RequestContext, - ) -> Result { + ) -> Result { match self { Self::ControlPlane(api, user_info) => { - api.get_allowed_vpc_endpoint_ids(ctx, user_info).await + api.get_role_access_control(ctx, &user_info.endpoint, &user_info.user) + .await } - Self::Local(_) => Ok(Cached::new_uncached(Arc::new(vec![]))), + Self::Local(_) => Ok(RoleAccessControl { secret: None }), } } - pub(crate) async fn get_block_public_or_vpc_access( + pub(crate) async fn get_endpoint_access_control( &self, ctx: &RequestContext, - ) -> Result { + ) -> Result { match self { Self::ControlPlane(api, user_info) => { - api.get_block_public_or_vpc_access(ctx, user_info).await + api.get_endpoint_access_control(ctx, &user_info.endpoint, &user_info.user) + .await } - Self::Local(_) => Ok(Cached::new_uncached(AccessBlockerFlags::default())), + Self::Local(_) => Ok(EndpointAccessControl { + allowed_ips: Arc::new(vec![]), + allowed_vpce: Arc::new(vec![]), + flags: AccessBlockerFlags::default(), + }), } } } @@ -541,9 +432,7 @@ impl ComputeConnectBackend for Backend<'_, ComputeCredentials> { mod tests { #![allow(clippy::unimplemented, clippy::unwrap_used)] - use std::net::IpAddr; use std::sync::Arc; - use std::time::Duration; use bytes::BytesMut; use control_plane::AuthSecret; @@ -554,18 +443,16 @@ mod tests { use postgres_protocol::message::frontend; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt}; + use super::auth_quirks; use super::jwt::JwkCache; - use super::{AuthRateLimiter, auth_quirks}; - use crate::auth::backend::MaskedIp; use crate::auth::{ComputeUserInfoMaybeEndpoint, IpPattern}; use crate::config::AuthenticationConfig; use crate::context::RequestContext; use crate::control_plane::{ - self, AccessBlockerFlags, CachedAccessBlockerFlags, CachedAllowedIps, - CachedAllowedVpcEndpointIds, CachedNodeInfo, CachedRoleSecret, + self, AccessBlockerFlags, CachedNodeInfo, EndpointAccessControl, RoleAccessControl, }; use crate::proxy::NeonOptions; - use crate::rate_limiter::{EndpointRateLimiter, RateBucketInfo}; + use crate::rate_limiter::EndpointRateLimiter; use crate::scram::ServerSecret; use crate::scram::threadpool::ThreadPool; use crate::stream::{PqStream, Stream}; @@ -578,46 +465,34 @@ mod tests { } impl control_plane::ControlPlaneApi for Auth { - async fn get_role_secret( + async fn get_role_access_control( &self, _ctx: &RequestContext, - _user_info: &super::ComputeUserInfo, - ) -> Result { - Ok(CachedRoleSecret::new_uncached(Some(self.secret.clone()))) + _endpoint: &crate::types::EndpointId, + _role: &crate::types::RoleName, + ) -> Result { + Ok(RoleAccessControl { + secret: Some(self.secret.clone()), + }) } - async fn get_allowed_ips( + async fn get_endpoint_access_control( &self, _ctx: &RequestContext, - _user_info: &super::ComputeUserInfo, - ) -> Result { - Ok(CachedAllowedIps::new_uncached(Arc::new(self.ips.clone()))) - } - - async fn get_allowed_vpc_endpoint_ids( - &self, - _ctx: &RequestContext, - _user_info: &super::ComputeUserInfo, - ) -> Result { - Ok(CachedAllowedVpcEndpointIds::new_uncached(Arc::new( - self.vpc_endpoint_ids.clone(), - ))) - } - - async fn get_block_public_or_vpc_access( - &self, - _ctx: &RequestContext, - _user_info: &super::ComputeUserInfo, - ) -> Result { - Ok(CachedAccessBlockerFlags::new_uncached( - self.access_blocker_flags.clone(), - )) + _endpoint: &crate::types::EndpointId, + _role: &crate::types::RoleName, + ) -> Result { + Ok(EndpointAccessControl { + allowed_ips: Arc::new(self.ips.clone()), + allowed_vpce: Arc::new(self.vpc_endpoint_ids.clone()), + flags: self.access_blocker_flags, + }) } async fn get_endpoint_jwks( &self, _ctx: &RequestContext, - _endpoint: crate::types::EndpointId, + _endpoint: &crate::types::EndpointId, ) -> Result, control_plane::errors::GetEndpointJwksError> { unimplemented!() @@ -636,9 +511,6 @@ mod tests { jwks_cache: JwkCache::default(), thread_pool: ThreadPool::new(1), scram_protocol_timeout: std::time::Duration::from_secs(5), - rate_limiter_enabled: true, - rate_limiter: AuthRateLimiter::new(&RateBucketInfo::DEFAULT_AUTH_SET), - rate_limit_ip_subnet: 64, ip_allowlist_check_enabled: true, is_vpc_acccess_proxy: false, is_auth_broker: false, @@ -655,51 +527,6 @@ mod tests { } } - #[test] - fn masked_ip() { - let ip_a = IpAddr::V4([127, 0, 0, 1].into()); - let ip_b = IpAddr::V4([127, 0, 0, 2].into()); - let ip_c = IpAddr::V4([192, 168, 1, 101].into()); - let ip_d = IpAddr::V4([192, 168, 1, 102].into()); - let ip_e = IpAddr::V6("abcd:abcd:abcd:abcd:abcd:abcd:abcd:abcd".parse().unwrap()); - let ip_f = IpAddr::V6("abcd:abcd:abcd:abcd:1234:abcd:abcd:abcd".parse().unwrap()); - - assert_ne!(MaskedIp::new(ip_a, 64), MaskedIp::new(ip_b, 64)); - assert_ne!(MaskedIp::new(ip_a, 32), MaskedIp::new(ip_b, 32)); - assert_eq!(MaskedIp::new(ip_a, 30), MaskedIp::new(ip_b, 30)); - assert_eq!(MaskedIp::new(ip_c, 30), MaskedIp::new(ip_d, 30)); - - assert_ne!(MaskedIp::new(ip_e, 128), MaskedIp::new(ip_f, 128)); - assert_eq!(MaskedIp::new(ip_e, 64), MaskedIp::new(ip_f, 64)); - } - - #[test] - fn test_default_auth_rate_limit_set() { - // these values used to exceed u32::MAX - assert_eq!( - RateBucketInfo::DEFAULT_AUTH_SET, - [ - RateBucketInfo { - interval: Duration::from_secs(1), - max_rpi: 1000 * 4096, - }, - RateBucketInfo { - interval: Duration::from_secs(60), - max_rpi: 600 * 4096 * 60, - }, - RateBucketInfo { - interval: Duration::from_secs(600), - max_rpi: 300 * 4096 * 600, - } - ] - ); - - for x in RateBucketInfo::DEFAULT_AUTH_SET { - let y = x.to_string().parse().unwrap(); - assert_eq!(x, y); - } - } - #[tokio::test] async fn auth_quirks_scram() { let (mut client, server) = tokio::io::duplex(1024); @@ -888,7 +715,7 @@ mod tests { .await .unwrap(); - assert_eq!(creds.0.info.endpoint, "my-endpoint"); + assert_eq!(creds.info.endpoint, "my-endpoint"); handle.await.unwrap(); } diff --git a/proxy/src/binary/local_proxy.rs b/proxy/src/binary/local_proxy.rs index a566383390..ba10fce7b4 100644 --- a/proxy/src/binary/local_proxy.rs +++ b/proxy/src/binary/local_proxy.rs @@ -32,9 +32,7 @@ use crate::ext::TaskExt; use crate::http::health_server::AppMetrics; use crate::intern::RoleNameInt; use crate::metrics::{Metrics, ThreadPoolMetrics}; -use crate::rate_limiter::{ - BucketRateLimiter, EndpointRateLimiter, LeakyBucketConfig, RateBucketInfo, -}; +use crate::rate_limiter::{EndpointRateLimiter, LeakyBucketConfig, RateBucketInfo}; use crate::scram::threadpool::ThreadPool; use crate::serverless::cancel_set::CancelSet; use crate::serverless::{self, GlobalConnPoolOptions}; @@ -69,15 +67,6 @@ struct LocalProxyCliArgs { /// Can be given multiple times for different bucket sizes. #[clap(long, default_values_t = RateBucketInfo::DEFAULT_ENDPOINT_SET)] user_rps_limit: Vec, - /// Whether the auth rate limiter actually takes effect (for testing) - #[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)] - auth_rate_limit_enabled: bool, - /// Authentication rate limiter max number of hashes per second. - #[clap(long, default_values_t = RateBucketInfo::DEFAULT_AUTH_SET)] - auth_rate_limit: Vec, - /// The IP subnet to use when considering whether two IP addresses are considered the same. - #[clap(long, default_value_t = 64)] - auth_rate_limit_ip_subnet: u8, /// Whether to retry the connection to the compute node #[clap(long, default_value = config::RetryConfig::CONNECT_TO_COMPUTE_DEFAULT_VALUES)] connect_to_compute_retry: String, @@ -282,9 +271,6 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig jwks_cache: JwkCache::default(), thread_pool: ThreadPool::new(0), scram_protocol_timeout: Duration::from_secs(10), - rate_limiter_enabled: false, - rate_limiter: BucketRateLimiter::new(vec![]), - rate_limit_ip_subnet: 64, ip_allowlist_check_enabled: true, is_vpc_acccess_proxy: false, is_auth_broker: false, diff --git a/proxy/src/binary/proxy.rs b/proxy/src/binary/proxy.rs index 9a3903ba9a..dcae263647 100644 --- a/proxy/src/binary/proxy.rs +++ b/proxy/src/binary/proxy.rs @@ -20,7 +20,7 @@ use utils::sentry_init::init_sentry; use utils::{project_build_tag, project_git_version}; use crate::auth::backend::jwt::JwkCache; -use crate::auth::backend::{AuthRateLimiter, ConsoleRedirectBackend, MaybeOwned}; +use crate::auth::backend::{ConsoleRedirectBackend, MaybeOwned}; use crate::cancellation::{CancellationHandler, handle_cancel_messages}; use crate::config::{ self, AuthenticationConfig, CacheOptions, ComputeConfig, HttpConfig, ProjectInfoCacheOptions, @@ -29,9 +29,7 @@ use crate::config::{ use crate::context::parquet::ParquetUploadArgs; use crate::http::health_server::AppMetrics; use crate::metrics::Metrics; -use crate::rate_limiter::{ - EndpointRateLimiter, LeakyBucketConfig, RateBucketInfo, WakeComputeRateLimiter, -}; +use crate::rate_limiter::{EndpointRateLimiter, RateBucketInfo, WakeComputeRateLimiter}; use crate::redis::connection_with_credentials_provider::ConnectionWithCredentialsProvider; use crate::redis::kv_ops::RedisKVClient; use crate::redis::{elasticache, notifications}; @@ -154,15 +152,6 @@ struct ProxyCliArgs { /// Wake compute rate limiter max number of requests per second. #[clap(long, default_values_t = RateBucketInfo::DEFAULT_SET)] wake_compute_limit: Vec, - /// Whether the auth rate limiter actually takes effect (for testing) - #[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)] - auth_rate_limit_enabled: bool, - /// Authentication rate limiter max number of hashes per second. - #[clap(long, default_values_t = RateBucketInfo::DEFAULT_AUTH_SET)] - auth_rate_limit: Vec, - /// The IP subnet to use when considering whether two IP addresses are considered the same. - #[clap(long, default_value_t = 64)] - auth_rate_limit_ip_subnet: u8, /// Redis rate limiter max number of requests per second. #[clap(long, default_values_t = RateBucketInfo::DEFAULT_REDIS_SET)] redis_rps_limit: Vec, @@ -410,22 +399,9 @@ pub async fn run() -> anyhow::Result<()> { Some(tx_cancel), )); - // bit of a hack - find the min rps and max rps supported and turn it into - // leaky bucket config instead - let max = args - .endpoint_rps_limit - .iter() - .map(|x| x.rps()) - .max_by(f64::total_cmp) - .unwrap_or(EndpointRateLimiter::DEFAULT.max); - let rps = args - .endpoint_rps_limit - .iter() - .map(|x| x.rps()) - .min_by(f64::total_cmp) - .unwrap_or(EndpointRateLimiter::DEFAULT.rps); let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards( - LeakyBucketConfig { rps, max }, + RateBucketInfo::to_leaky_bucket(&args.endpoint_rps_limit) + .unwrap_or(EndpointRateLimiter::DEFAULT), 64, )); @@ -678,9 +654,6 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { jwks_cache: JwkCache::default(), thread_pool, scram_protocol_timeout: args.scram_protocol_timeout, - rate_limiter_enabled: args.auth_rate_limit_enabled, - rate_limiter: AuthRateLimiter::new(args.auth_rate_limit.clone()), - rate_limit_ip_subnet: args.auth_rate_limit_ip_subnet, ip_allowlist_check_enabled: !args.is_private_access_proxy, is_vpc_acccess_proxy: args.is_private_access_proxy, is_auth_broker: args.is_auth_broker, diff --git a/proxy/src/cache/project_info.rs b/proxy/src/cache/project_info.rs index 60678b034d..81c88e3ddd 100644 --- a/proxy/src/cache/project_info.rs +++ b/proxy/src/cache/project_info.rs @@ -1,30 +1,25 @@ -use std::collections::HashSet; +use std::collections::{HashMap, HashSet, hash_map}; use std::convert::Infallible; -use std::sync::Arc; use std::sync::atomic::AtomicU64; use std::time::Duration; use async_trait::async_trait; use clashmap::ClashMap; +use clashmap::mapref::one::Ref; use rand::{Rng, thread_rng}; -use smol_str::SmolStr; use tokio::sync::Mutex; use tokio::time::Instant; use tracing::{debug, info}; -use super::{Cache, Cached}; -use crate::auth::IpPattern; use crate::config::ProjectInfoCacheOptions; -use crate::control_plane::{AccessBlockerFlags, AuthSecret}; +use crate::control_plane::{EndpointAccessControl, RoleAccessControl}; use crate::intern::{AccountIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt}; use crate::types::{EndpointId, RoleName}; #[async_trait] pub(crate) trait ProjectInfoCache { - fn invalidate_allowed_ips_for_project(&self, project_id: ProjectIdInt); - fn invalidate_allowed_vpc_endpoint_ids_for_projects(&self, project_ids: Vec); - fn invalidate_allowed_vpc_endpoint_ids_for_org(&self, account_id: AccountIdInt); - fn invalidate_block_public_or_vpc_access_for_project(&self, project_id: ProjectIdInt); + fn invalidate_endpoint_access_for_project(&self, project_id: ProjectIdInt); + fn invalidate_endpoint_access_for_org(&self, account_id: AccountIdInt); fn invalidate_role_secret_for_project(&self, project_id: ProjectIdInt, role_name: RoleNameInt); async fn decrement_active_listeners(&self); async fn increment_active_listeners(&self); @@ -42,6 +37,10 @@ impl Entry { value, } } + + pub(crate) fn get(&self, valid_since: Instant) -> Option<&T> { + (valid_since < self.created_at).then_some(&self.value) + } } impl From for Entry { @@ -50,101 +49,32 @@ impl From for Entry { } } -#[derive(Default)] struct EndpointInfo { - secret: std::collections::HashMap>>, - allowed_ips: Option>>>, - block_public_or_vpc_access: Option>, - allowed_vpc_endpoint_ids: Option>>>, + role_controls: HashMap>, + controls: Option>, } impl EndpointInfo { - fn check_ignore_cache(ignore_cache_since: Option, created_at: Instant) -> bool { - match ignore_cache_since { - None => false, - Some(t) => t < created_at, - } - } pub(crate) fn get_role_secret( &self, role_name: RoleNameInt, valid_since: Instant, - ignore_cache_since: Option, - ) -> Option<(Option, bool)> { - if let Some(secret) = self.secret.get(&role_name) { - if valid_since < secret.created_at { - return Some(( - secret.value.clone(), - Self::check_ignore_cache(ignore_cache_since, secret.created_at), - )); - } - } - None + ) -> Option { + let controls = self.role_controls.get(&role_name)?; + controls.get(valid_since).cloned() } - pub(crate) fn get_allowed_ips( - &self, - valid_since: Instant, - ignore_cache_since: Option, - ) -> Option<(Arc>, bool)> { - if let Some(allowed_ips) = &self.allowed_ips { - if valid_since < allowed_ips.created_at { - return Some(( - allowed_ips.value.clone(), - Self::check_ignore_cache(ignore_cache_since, allowed_ips.created_at), - )); - } - } - None - } - pub(crate) fn get_allowed_vpc_endpoint_ids( - &self, - valid_since: Instant, - ignore_cache_since: Option, - ) -> Option<(Arc>, bool)> { - if let Some(allowed_vpc_endpoint_ids) = &self.allowed_vpc_endpoint_ids { - if valid_since < allowed_vpc_endpoint_ids.created_at { - return Some(( - allowed_vpc_endpoint_ids.value.clone(), - Self::check_ignore_cache( - ignore_cache_since, - allowed_vpc_endpoint_ids.created_at, - ), - )); - } - } - None - } - pub(crate) fn get_block_public_or_vpc_access( - &self, - valid_since: Instant, - ignore_cache_since: Option, - ) -> Option<(AccessBlockerFlags, bool)> { - if let Some(block_public_or_vpc_access) = &self.block_public_or_vpc_access { - if valid_since < block_public_or_vpc_access.created_at { - return Some(( - block_public_or_vpc_access.value.clone(), - Self::check_ignore_cache( - ignore_cache_since, - block_public_or_vpc_access.created_at, - ), - )); - } - } - None + pub(crate) fn get_controls(&self, valid_since: Instant) -> Option { + let controls = self.controls.as_ref()?; + controls.get(valid_since).cloned() } - pub(crate) fn invalidate_allowed_ips(&mut self) { - self.allowed_ips = None; - } - pub(crate) fn invalidate_allowed_vpc_endpoint_ids(&mut self) { - self.allowed_vpc_endpoint_ids = None; - } - pub(crate) fn invalidate_block_public_or_vpc_access(&mut self) { - self.block_public_or_vpc_access = None; + pub(crate) fn invalidate_endpoint(&mut self) { + self.controls = None; } + pub(crate) fn invalidate_role_secret(&mut self, role_name: RoleNameInt) { - self.secret.remove(&role_name); + self.role_controls.remove(&role_name); } } @@ -170,34 +100,22 @@ pub struct ProjectInfoCacheImpl { #[async_trait] impl ProjectInfoCache for ProjectInfoCacheImpl { - fn invalidate_allowed_vpc_endpoint_ids_for_projects(&self, project_ids: Vec) { - info!( - "invalidating allowed vpc endpoint ids for projects `{}`", - project_ids - .iter() - .map(|id| id.to_string()) - .collect::>() - .join(", ") - ); - for project_id in project_ids { - let endpoints = self - .project2ep - .get(&project_id) - .map(|kv| kv.value().clone()) - .unwrap_or_default(); - for endpoint_id in endpoints { - if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) { - endpoint_info.invalidate_allowed_vpc_endpoint_ids(); - } + fn invalidate_endpoint_access_for_project(&self, project_id: ProjectIdInt) { + info!("invalidating endpoint access for project `{project_id}`"); + let endpoints = self + .project2ep + .get(&project_id) + .map(|kv| kv.value().clone()) + .unwrap_or_default(); + for endpoint_id in endpoints { + if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) { + endpoint_info.invalidate_endpoint(); } } } - fn invalidate_allowed_vpc_endpoint_ids_for_org(&self, account_id: AccountIdInt) { - info!( - "invalidating allowed vpc endpoint ids for org `{}`", - account_id - ); + fn invalidate_endpoint_access_for_org(&self, account_id: AccountIdInt) { + info!("invalidating endpoint access for org `{account_id}`"); let endpoints = self .account2ep .get(&account_id) @@ -205,41 +123,11 @@ impl ProjectInfoCache for ProjectInfoCacheImpl { .unwrap_or_default(); for endpoint_id in endpoints { if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) { - endpoint_info.invalidate_allowed_vpc_endpoint_ids(); + endpoint_info.invalidate_endpoint(); } } } - fn invalidate_block_public_or_vpc_access_for_project(&self, project_id: ProjectIdInt) { - info!( - "invalidating block public or vpc access for project `{}`", - project_id - ); - let endpoints = self - .project2ep - .get(&project_id) - .map(|kv| kv.value().clone()) - .unwrap_or_default(); - for endpoint_id in endpoints { - if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) { - endpoint_info.invalidate_block_public_or_vpc_access(); - } - } - } - - fn invalidate_allowed_ips_for_project(&self, project_id: ProjectIdInt) { - info!("invalidating allowed ips for project `{}`", project_id); - let endpoints = self - .project2ep - .get(&project_id) - .map(|kv| kv.value().clone()) - .unwrap_or_default(); - for endpoint_id in endpoints { - if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) { - endpoint_info.invalidate_allowed_ips(); - } - } - } fn invalidate_role_secret_for_project(&self, project_id: ProjectIdInt, role_name: RoleNameInt) { info!( "invalidating role secret for project_id `{}` and role_name `{}`", @@ -256,6 +144,7 @@ impl ProjectInfoCache for ProjectInfoCacheImpl { } } } + async fn decrement_active_listeners(&self) { let mut listeners_guard = self.active_listeners_lock.lock().await; if *listeners_guard == 0 { @@ -293,155 +182,71 @@ impl ProjectInfoCacheImpl { } } + fn get_endpoint_cache( + &self, + endpoint_id: &EndpointId, + ) -> Option> { + let endpoint_id = EndpointIdInt::get(endpoint_id)?; + self.cache.get(&endpoint_id) + } + pub(crate) fn get_role_secret( &self, endpoint_id: &EndpointId, role_name: &RoleName, - ) -> Option>> { - let endpoint_id = EndpointIdInt::get(endpoint_id)?; + ) -> Option { + let valid_since = self.get_cache_times(); let role_name = RoleNameInt::get(role_name)?; - let (valid_since, ignore_cache_since) = self.get_cache_times(); - let endpoint_info = self.cache.get(&endpoint_id)?; - let (value, ignore_cache) = - endpoint_info.get_role_secret(role_name, valid_since, ignore_cache_since)?; - if !ignore_cache { - let cached = Cached { - token: Some(( - self, - CachedLookupInfo::new_role_secret(endpoint_id, role_name), - )), - value, - }; - return Some(cached); - } - Some(Cached::new_uncached(value)) - } - pub(crate) fn get_allowed_ips( - &self, - endpoint_id: &EndpointId, - ) -> Option>>> { - let endpoint_id = EndpointIdInt::get(endpoint_id)?; - let (valid_since, ignore_cache_since) = self.get_cache_times(); - let endpoint_info = self.cache.get(&endpoint_id)?; - let value = endpoint_info.get_allowed_ips(valid_since, ignore_cache_since); - let (value, ignore_cache) = value?; - if !ignore_cache { - let cached = Cached { - token: Some((self, CachedLookupInfo::new_allowed_ips(endpoint_id))), - value, - }; - return Some(cached); - } - Some(Cached::new_uncached(value)) - } - pub(crate) fn get_allowed_vpc_endpoint_ids( - &self, - endpoint_id: &EndpointId, - ) -> Option>>> { - let endpoint_id = EndpointIdInt::get(endpoint_id)?; - let (valid_since, ignore_cache_since) = self.get_cache_times(); - let endpoint_info = self.cache.get(&endpoint_id)?; - let value = endpoint_info.get_allowed_vpc_endpoint_ids(valid_since, ignore_cache_since); - let (value, ignore_cache) = value?; - if !ignore_cache { - let cached = Cached { - token: Some(( - self, - CachedLookupInfo::new_allowed_vpc_endpoint_ids(endpoint_id), - )), - value, - }; - return Some(cached); - } - Some(Cached::new_uncached(value)) - } - pub(crate) fn get_block_public_or_vpc_access( - &self, - endpoint_id: &EndpointId, - ) -> Option> { - let endpoint_id = EndpointIdInt::get(endpoint_id)?; - let (valid_since, ignore_cache_since) = self.get_cache_times(); - let endpoint_info = self.cache.get(&endpoint_id)?; - let value = endpoint_info.get_block_public_or_vpc_access(valid_since, ignore_cache_since); - let (value, ignore_cache) = value?; - if !ignore_cache { - let cached = Cached { - token: Some(( - self, - CachedLookupInfo::new_block_public_or_vpc_access(endpoint_id), - )), - value, - }; - return Some(cached); - } - Some(Cached::new_uncached(value)) + let endpoint_info = self.get_endpoint_cache(endpoint_id)?; + endpoint_info.get_role_secret(role_name, valid_since) } - pub(crate) fn insert_role_secret( + pub(crate) fn get_endpoint_access( &self, - project_id: ProjectIdInt, - endpoint_id: EndpointIdInt, - role_name: RoleNameInt, - secret: Option, - ) { - if self.cache.len() >= self.config.size { - // If there are too many entries, wait until the next gc cycle. - return; - } - self.insert_project2endpoint(project_id, endpoint_id); - let mut entry = self.cache.entry(endpoint_id).or_default(); - if entry.secret.len() < self.config.max_roles { - entry.secret.insert(role_name, secret.into()); - } + endpoint_id: &EndpointId, + ) -> Option { + let valid_since = self.get_cache_times(); + let endpoint_info = self.get_endpoint_cache(endpoint_id)?; + endpoint_info.get_controls(valid_since) } - pub(crate) fn insert_allowed_ips( - &self, - project_id: ProjectIdInt, - endpoint_id: EndpointIdInt, - allowed_ips: Arc>, - ) { - if self.cache.len() >= self.config.size { - // If there are too many entries, wait until the next gc cycle. - return; - } - self.insert_project2endpoint(project_id, endpoint_id); - self.cache.entry(endpoint_id).or_default().allowed_ips = Some(allowed_ips.into()); - } - pub(crate) fn insert_allowed_vpc_endpoint_ids( + + pub(crate) fn insert_endpoint_access( &self, account_id: Option, project_id: ProjectIdInt, endpoint_id: EndpointIdInt, - allowed_vpc_endpoint_ids: Arc>, + role_name: RoleNameInt, + controls: EndpointAccessControl, + role_controls: RoleAccessControl, ) { - if self.cache.len() >= self.config.size { - // If there are too many entries, wait until the next gc cycle. - return; - } if let Some(account_id) = account_id { self.insert_account2endpoint(account_id, endpoint_id); } self.insert_project2endpoint(project_id, endpoint_id); - self.cache - .entry(endpoint_id) - .or_default() - .allowed_vpc_endpoint_ids = Some(allowed_vpc_endpoint_ids.into()); - } - pub(crate) fn insert_block_public_or_vpc_access( - &self, - project_id: ProjectIdInt, - endpoint_id: EndpointIdInt, - access_blockers: AccessBlockerFlags, - ) { + if self.cache.len() >= self.config.size { // If there are too many entries, wait until the next gc cycle. return; } - self.insert_project2endpoint(project_id, endpoint_id); - self.cache - .entry(endpoint_id) - .or_default() - .block_public_or_vpc_access = Some(access_blockers.into()); + + let controls = Entry::from(controls); + let role_controls = Entry::from(role_controls); + + match self.cache.entry(endpoint_id) { + clashmap::Entry::Vacant(e) => { + e.insert(EndpointInfo { + role_controls: HashMap::from_iter([(role_name, role_controls)]), + controls: Some(controls), + }); + } + clashmap::Entry::Occupied(mut e) => { + let ep = e.get_mut(); + ep.controls = Some(controls); + if ep.role_controls.len() < self.config.max_roles { + ep.role_controls.insert(role_name, role_controls); + } + } + } } fn insert_project2endpoint(&self, project_id: ProjectIdInt, endpoint_id: EndpointIdInt) { @@ -452,6 +257,7 @@ impl ProjectInfoCacheImpl { .insert(project_id, HashSet::from([endpoint_id])); } } + fn insert_account2endpoint(&self, account_id: AccountIdInt, endpoint_id: EndpointIdInt) { if let Some(mut endpoints) = self.account2ep.get_mut(&account_id) { endpoints.insert(endpoint_id); @@ -460,21 +266,57 @@ impl ProjectInfoCacheImpl { .insert(account_id, HashSet::from([endpoint_id])); } } - fn get_cache_times(&self) -> (Instant, Option) { - let mut valid_since = Instant::now() - self.config.ttl; - // Only ignore cache if ttl is disabled. + + fn ignore_ttl_since(&self) -> Option { let ttl_disabled_since_us = self .ttl_disabled_since_us .load(std::sync::atomic::Ordering::Relaxed); - let ignore_cache_since = if ttl_disabled_since_us == u64::MAX { - None - } else { - let ignore_cache_since = self.start_time + Duration::from_micros(ttl_disabled_since_us); + + if ttl_disabled_since_us == u64::MAX { + return None; + } + + Some(self.start_time + Duration::from_micros(ttl_disabled_since_us)) + } + + fn get_cache_times(&self) -> Instant { + let mut valid_since = Instant::now() - self.config.ttl; + if let Some(ignore_ttl_since) = self.ignore_ttl_since() { // We are fine if entry is not older than ttl or was added before we are getting notifications. - valid_since = valid_since.min(ignore_cache_since); - Some(ignore_cache_since) + valid_since = valid_since.min(ignore_ttl_since); + } + valid_since + } + + pub fn maybe_invalidate_role_secret(&self, endpoint_id: &EndpointId, role_name: &RoleName) { + let Some(endpoint_id) = EndpointIdInt::get(endpoint_id) else { + return; }; - (valid_since, ignore_cache_since) + let Some(role_name) = RoleNameInt::get(role_name) else { + return; + }; + + let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) else { + return; + }; + + let entry = endpoint_info.role_controls.entry(role_name); + let hash_map::Entry::Occupied(role_controls) = entry else { + return; + }; + + let created_at = role_controls.get().created_at; + let expire = match self.ignore_ttl_since() { + // if ignoring TTL, we should still try and roll the password if it's old + // and we the client gave an incorrect password. There could be some lag on the redis channel. + Some(_) => created_at + self.config.ttl < Instant::now(), + // edge case: redis is down, let's be generous and invalidate the cache immediately. + None => true, + }; + + if expire { + role_controls.remove(); + } } pub async fn gc_worker(&self) -> anyhow::Result { @@ -509,84 +351,12 @@ impl ProjectInfoCacheImpl { } } -/// Lookup info for project info cache. -/// This is used to invalidate cache entries. -pub(crate) struct CachedLookupInfo { - /// Search by this key. - endpoint_id: EndpointIdInt, - lookup_type: LookupType, -} - -impl CachedLookupInfo { - pub(self) fn new_role_secret(endpoint_id: EndpointIdInt, role_name: RoleNameInt) -> Self { - Self { - endpoint_id, - lookup_type: LookupType::RoleSecret(role_name), - } - } - pub(self) fn new_allowed_ips(endpoint_id: EndpointIdInt) -> Self { - Self { - endpoint_id, - lookup_type: LookupType::AllowedIps, - } - } - pub(self) fn new_allowed_vpc_endpoint_ids(endpoint_id: EndpointIdInt) -> Self { - Self { - endpoint_id, - lookup_type: LookupType::AllowedVpcEndpointIds, - } - } - pub(self) fn new_block_public_or_vpc_access(endpoint_id: EndpointIdInt) -> Self { - Self { - endpoint_id, - lookup_type: LookupType::BlockPublicOrVpcAccess, - } - } -} - -enum LookupType { - RoleSecret(RoleNameInt), - AllowedIps, - AllowedVpcEndpointIds, - BlockPublicOrVpcAccess, -} - -impl Cache for ProjectInfoCacheImpl { - type Key = SmolStr; - // Value is not really used here, but we need to specify it. - type Value = SmolStr; - - type LookupInfo = CachedLookupInfo; - - fn invalidate(&self, key: &Self::LookupInfo) { - match &key.lookup_type { - LookupType::RoleSecret(role_name) => { - if let Some(mut endpoint_info) = self.cache.get_mut(&key.endpoint_id) { - endpoint_info.invalidate_role_secret(*role_name); - } - } - LookupType::AllowedIps => { - if let Some(mut endpoint_info) = self.cache.get_mut(&key.endpoint_id) { - endpoint_info.invalidate_allowed_ips(); - } - } - LookupType::AllowedVpcEndpointIds => { - if let Some(mut endpoint_info) = self.cache.get_mut(&key.endpoint_id) { - endpoint_info.invalidate_allowed_vpc_endpoint_ids(); - } - } - LookupType::BlockPublicOrVpcAccess => { - if let Some(mut endpoint_info) = self.cache.get_mut(&key.endpoint_id) { - endpoint_info.invalidate_block_public_or_vpc_access(); - } - } - } - } -} - #[cfg(test)] mod tests { + use std::sync::Arc; + use super::*; + use crate::control_plane::{AccessBlockerFlags, AuthSecret}; use crate::scram::ServerSecret; use crate::types::ProjectId; @@ -601,6 +371,8 @@ mod tests { }); let project_id: ProjectId = "project".into(); let endpoint_id: EndpointId = "endpoint".into(); + let account_id: Option = None; + let user1: RoleName = "user1".into(); let user2: RoleName = "user2".into(); let secret1 = Some(AuthSecret::Scram(ServerSecret::mock([1; 32]))); @@ -609,183 +381,73 @@ mod tests { "127.0.0.1".parse().unwrap(), "127.0.0.2".parse().unwrap(), ]); - cache.insert_role_secret( + + cache.insert_endpoint_access( + account_id, (&project_id).into(), (&endpoint_id).into(), (&user1).into(), - secret1.clone(), + EndpointAccessControl { + allowed_ips: allowed_ips.clone(), + allowed_vpce: Arc::new(vec![]), + flags: AccessBlockerFlags::default(), + }, + RoleAccessControl { + secret: secret1.clone(), + }, ); - cache.insert_role_secret( + + cache.insert_endpoint_access( + account_id, (&project_id).into(), (&endpoint_id).into(), (&user2).into(), - secret2.clone(), - ); - cache.insert_allowed_ips( - (&project_id).into(), - (&endpoint_id).into(), - allowed_ips.clone(), + EndpointAccessControl { + allowed_ips: allowed_ips.clone(), + allowed_vpce: Arc::new(vec![]), + flags: AccessBlockerFlags::default(), + }, + RoleAccessControl { + secret: secret2.clone(), + }, ); let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap(); - assert!(cached.cached()); - assert_eq!(cached.value, secret1); + assert_eq!(cached.secret, secret1); + let cached = cache.get_role_secret(&endpoint_id, &user2).unwrap(); - assert!(cached.cached()); - assert_eq!(cached.value, secret2); + assert_eq!(cached.secret, secret2); // Shouldn't add more than 2 roles. let user3: RoleName = "user3".into(); let secret3 = Some(AuthSecret::Scram(ServerSecret::mock([3; 32]))); - cache.insert_role_secret( + + cache.insert_endpoint_access( + account_id, (&project_id).into(), (&endpoint_id).into(), (&user3).into(), - secret3.clone(), + EndpointAccessControl { + allowed_ips: allowed_ips.clone(), + allowed_vpce: Arc::new(vec![]), + flags: AccessBlockerFlags::default(), + }, + RoleAccessControl { + secret: secret3.clone(), + }, ); + assert!(cache.get_role_secret(&endpoint_id, &user3).is_none()); - let cached = cache.get_allowed_ips(&endpoint_id).unwrap(); - assert!(cached.cached()); - assert_eq!(cached.value, allowed_ips); + let cached = cache.get_endpoint_access(&endpoint_id).unwrap(); + assert_eq!(cached.allowed_ips, allowed_ips); tokio::time::advance(Duration::from_secs(2)).await; let cached = cache.get_role_secret(&endpoint_id, &user1); assert!(cached.is_none()); let cached = cache.get_role_secret(&endpoint_id, &user2); assert!(cached.is_none()); - let cached = cache.get_allowed_ips(&endpoint_id); + let cached = cache.get_endpoint_access(&endpoint_id); assert!(cached.is_none()); } - - #[tokio::test] - async fn test_project_info_cache_invalidations() { - tokio::time::pause(); - let cache = Arc::new(ProjectInfoCacheImpl::new(ProjectInfoCacheOptions { - size: 2, - max_roles: 2, - ttl: Duration::from_secs(1), - gc_interval: Duration::from_secs(600), - })); - cache.clone().increment_active_listeners().await; - tokio::time::advance(Duration::from_secs(2)).await; - - let project_id: ProjectId = "project".into(); - let endpoint_id: EndpointId = "endpoint".into(); - let user1: RoleName = "user1".into(); - let user2: RoleName = "user2".into(); - let secret1 = Some(AuthSecret::Scram(ServerSecret::mock([1; 32]))); - let secret2 = Some(AuthSecret::Scram(ServerSecret::mock([2; 32]))); - let allowed_ips = Arc::new(vec![ - "127.0.0.1".parse().unwrap(), - "127.0.0.2".parse().unwrap(), - ]); - cache.insert_role_secret( - (&project_id).into(), - (&endpoint_id).into(), - (&user1).into(), - secret1.clone(), - ); - cache.insert_role_secret( - (&project_id).into(), - (&endpoint_id).into(), - (&user2).into(), - secret2.clone(), - ); - cache.insert_allowed_ips( - (&project_id).into(), - (&endpoint_id).into(), - allowed_ips.clone(), - ); - - tokio::time::advance(Duration::from_secs(2)).await; - // Nothing should be invalidated. - - let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap(); - // TTL is disabled, so it should be impossible to invalidate this value. - assert!(!cached.cached()); - assert_eq!(cached.value, secret1); - - cached.invalidate(); // Shouldn't do anything. - let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap(); - assert_eq!(cached.value, secret1); - - let cached = cache.get_role_secret(&endpoint_id, &user2).unwrap(); - assert!(!cached.cached()); - assert_eq!(cached.value, secret2); - - // The only way to invalidate this value is to invalidate via the api. - cache.invalidate_role_secret_for_project((&project_id).into(), (&user2).into()); - assert!(cache.get_role_secret(&endpoint_id, &user2).is_none()); - - let cached = cache.get_allowed_ips(&endpoint_id).unwrap(); - assert!(!cached.cached()); - assert_eq!(cached.value, allowed_ips); - } - - #[tokio::test] - async fn test_increment_active_listeners_invalidate_added_before() { - tokio::time::pause(); - let cache = Arc::new(ProjectInfoCacheImpl::new(ProjectInfoCacheOptions { - size: 2, - max_roles: 2, - ttl: Duration::from_secs(1), - gc_interval: Duration::from_secs(600), - })); - - let project_id: ProjectId = "project".into(); - let endpoint_id: EndpointId = "endpoint".into(); - let user1: RoleName = "user1".into(); - let user2: RoleName = "user2".into(); - let secret1 = Some(AuthSecret::Scram(ServerSecret::mock([1; 32]))); - let secret2 = Some(AuthSecret::Scram(ServerSecret::mock([2; 32]))); - let allowed_ips = Arc::new(vec![ - "127.0.0.1".parse().unwrap(), - "127.0.0.2".parse().unwrap(), - ]); - cache.insert_role_secret( - (&project_id).into(), - (&endpoint_id).into(), - (&user1).into(), - secret1.clone(), - ); - cache.clone().increment_active_listeners().await; - tokio::time::advance(Duration::from_millis(100)).await; - cache.insert_role_secret( - (&project_id).into(), - (&endpoint_id).into(), - (&user2).into(), - secret2.clone(), - ); - - // Added before ttl was disabled + ttl should be still cached. - let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap(); - assert!(cached.cached()); - let cached = cache.get_role_secret(&endpoint_id, &user2).unwrap(); - assert!(cached.cached()); - - tokio::time::advance(Duration::from_secs(1)).await; - // Added before ttl was disabled + ttl should expire. - assert!(cache.get_role_secret(&endpoint_id, &user1).is_none()); - assert!(cache.get_role_secret(&endpoint_id, &user2).is_none()); - - // Added after ttl was disabled + ttl should not be cached. - cache.insert_allowed_ips( - (&project_id).into(), - (&endpoint_id).into(), - allowed_ips.clone(), - ); - let cached = cache.get_allowed_ips(&endpoint_id).unwrap(); - assert!(!cached.cached()); - - tokio::time::advance(Duration::from_secs(1)).await; - // Added before ttl was disabled + ttl still should expire. - assert!(cache.get_role_secret(&endpoint_id, &user1).is_none()); - assert!(cache.get_role_secret(&endpoint_id, &user2).is_none()); - // Shouldn't be invalidated. - - let cached = cache.get_allowed_ips(&endpoint_id).unwrap(); - assert!(!cached.cached()); - assert_eq!(cached.value, allowed_ips); - } } diff --git a/proxy/src/cancellation.rs b/proxy/src/cancellation.rs index 0bff901376..d26641db46 100644 --- a/proxy/src/cancellation.rs +++ b/proxy/src/cancellation.rs @@ -12,8 +12,8 @@ use tokio::net::TcpStream; use tokio::sync::{mpsc, oneshot}; use tracing::{debug, error, info, warn}; +use crate::auth::AuthError; use crate::auth::backend::ComputeUserInfo; -use crate::auth::{AuthError, check_peer_addr_is_in_list}; use crate::config::ComputeConfig; use crate::context::RequestContext; use crate::control_plane::ControlPlaneApi; @@ -21,7 +21,6 @@ use crate::error::ReportableError; use crate::ext::LockExt; use crate::metrics::{CancelChannelSizeGuard, CancellationRequest, Metrics, RedisMsgKind}; use crate::pqproto::CancelKeyData; -use crate::protocol2::ConnectionInfoExtra; use crate::rate_limiter::LeakyBucketRateLimiter; use crate::redis::keys::KeyPrefix; use crate::redis::kv_ops::RedisKVClient; @@ -272,13 +271,7 @@ pub(crate) enum CancelError { #[error("rate limit exceeded")] RateLimit, - #[error("IP is not allowed")] - IpNotAllowed, - - #[error("VPC endpoint id is not allowed to connect")] - VpcEndpointIdNotAllowed, - - #[error("Authentication backend error")] + #[error("Authentication error")] AuthError(#[from] AuthError), #[error("key not found")] @@ -297,10 +290,7 @@ impl ReportableError for CancelError { } CancelError::Postgres(_) => crate::error::ErrorKind::Compute, CancelError::RateLimit => crate::error::ErrorKind::RateLimit, - CancelError::IpNotAllowed - | CancelError::VpcEndpointIdNotAllowed - | CancelError::NotFound => crate::error::ErrorKind::User, - CancelError::AuthError(_) => crate::error::ErrorKind::ControlPlane, + CancelError::NotFound | CancelError::AuthError(_) => crate::error::ErrorKind::User, CancelError::InternalError => crate::error::ErrorKind::Service, } } @@ -422,7 +412,13 @@ impl CancellationHandler { IpAddr::V4(ip) => IpNet::V4(Ipv4Net::new_assert(ip, 24).trunc()), // use defaut mask here IpAddr::V6(ip) => IpNet::V6(Ipv6Net::new_assert(ip, 64).trunc()), }; - if !self.limiter.lock_propagate_poison().check(subnet_key, 1) { + + let allowed = { + let rate_limit_config = None; + let limiter = self.limiter.lock_propagate_poison(); + limiter.check(subnet_key, rate_limit_config, 1) + }; + if !allowed { // log only the subnet part of the IP address to know which subnet is rate limited tracing::warn!("Rate limit exceeded. Skipping cancellation message, {subnet_key}"); Metrics::get() @@ -450,52 +446,13 @@ impl CancellationHandler { return Err(CancelError::NotFound); }; - if check_ip_allowed { - let ip_allowlist = auth_backend - .get_allowed_ips(&ctx, &cancel_closure.user_info) - .await - .map_err(|e| CancelError::AuthError(e.into()))?; - - if !check_peer_addr_is_in_list(&ctx.peer_addr(), &ip_allowlist) { - // log it here since cancel_session could be spawned in a task - tracing::warn!( - "IP is not allowed to cancel the query: {key}, address: {}", - ctx.peer_addr() - ); - return Err(CancelError::IpNotAllowed); - } - } - - // check if a VPC endpoint ID is coming in and if yes, if it's allowed - let access_blocks = auth_backend - .get_block_public_or_vpc_access(&ctx, &cancel_closure.user_info) + let info = &cancel_closure.user_info; + let access_controls = auth_backend + .get_endpoint_access_control(&ctx, &info.endpoint, &info.user) .await .map_err(|e| CancelError::AuthError(e.into()))?; - if check_vpc_allowed { - if access_blocks.vpc_access_blocked { - return Err(CancelError::AuthError(AuthError::NetworkNotAllowed)); - } - - let incoming_vpc_endpoint_id = match ctx.extra() { - None => return Err(CancelError::AuthError(AuthError::MissingVPCEndpointId)), - Some(ConnectionInfoExtra::Aws { vpce_id }) => vpce_id.to_string(), - Some(ConnectionInfoExtra::Azure { link_id }) => link_id.to_string(), - }; - - let allowed_vpc_endpoint_ids = auth_backend - .get_allowed_vpc_endpoint_ids(&ctx, &cancel_closure.user_info) - .await - .map_err(|e| CancelError::AuthError(e.into()))?; - // TODO: For now an empty VPC endpoint ID list means all are allowed. We should replace that. - if !allowed_vpc_endpoint_ids.is_empty() - && !allowed_vpc_endpoint_ids.contains(&incoming_vpc_endpoint_id) - { - return Err(CancelError::VpcEndpointIdNotAllowed); - } - } else if access_blocks.public_access_blocked { - return Err(CancelError::VpcEndpointIdNotAllowed); - } + access_controls.check(&ctx, check_ip_allowed, check_vpc_allowed)?; Metrics::get() .proxy diff --git a/proxy/src/config.rs b/proxy/src/config.rs index ad398c122c..a97339df9a 100644 --- a/proxy/src/config.rs +++ b/proxy/src/config.rs @@ -7,7 +7,6 @@ use arc_swap::ArcSwapOption; use clap::ValueEnum; use remote_storage::RemoteStorageConfig; -use crate::auth::backend::AuthRateLimiter; use crate::auth::backend::jwt::JwkCache; use crate::control_plane::locks::ApiLocks; use crate::rate_limiter::{RateBucketInfo, RateLimitAlgorithm, RateLimiterConfig}; @@ -65,9 +64,6 @@ pub struct HttpConfig { pub struct AuthenticationConfig { pub thread_pool: Arc, pub scram_protocol_timeout: tokio::time::Duration, - pub rate_limiter_enabled: bool, - pub rate_limiter: AuthRateLimiter, - pub rate_limit_ip_subnet: u8, pub ip_allowlist_check_enabled: bool, pub is_vpc_acccess_proxy: bool, pub jwks_cache: JwkCache, diff --git a/proxy/src/context/mod.rs b/proxy/src/context/mod.rs index de4600951e..24268997ba 100644 --- a/proxy/src/context/mod.rs +++ b/proxy/src/context/mod.rs @@ -370,6 +370,18 @@ impl RequestContext { } } + pub(crate) fn latency_timer_pause_at( + &self, + at: tokio::time::Instant, + waiting_for: Waiting, + ) -> LatencyTimerPause<'_> { + LatencyTimerPause { + ctx: self, + start: at, + waiting_for, + } + } + pub(crate) fn get_proxy_latency(&self) -> LatencyAccumulated { self.0 .try_lock() diff --git a/proxy/src/control_plane/client/cplane_proxy_v1.rs b/proxy/src/control_plane/client/cplane_proxy_v1.rs index 2765aaa462..93f4ea6cf7 100644 --- a/proxy/src/control_plane/client/cplane_proxy_v1.rs +++ b/proxy/src/control_plane/client/cplane_proxy_v1.rs @@ -15,7 +15,6 @@ use tracing::{Instrument, debug, info, info_span, warn}; use super::super::messages::{ControlPlaneErrorMessage, GetEndpointAccessControl, WakeCompute}; use crate::auth::backend::ComputeUserInfo; use crate::auth::backend::jwt::AuthRule; -use crate::cache::Cached; use crate::context::RequestContext; use crate::control_plane::caches::ApiCaches; use crate::control_plane::errors::{ @@ -24,12 +23,12 @@ use crate::control_plane::errors::{ use crate::control_plane::locks::ApiLocks; use crate::control_plane::messages::{ColdStartInfo, EndpointJwksResponse, Reason}; use crate::control_plane::{ - AccessBlockerFlags, AuthInfo, AuthSecret, CachedAccessBlockerFlags, CachedAllowedIps, - CachedAllowedVpcEndpointIds, CachedNodeInfo, CachedRoleSecret, NodeInfo, + AccessBlockerFlags, AuthInfo, AuthSecret, CachedNodeInfo, EndpointAccessControl, NodeInfo, + RoleAccessControl, }; -use crate::metrics::{CacheOutcome, Metrics}; +use crate::metrics::Metrics; use crate::rate_limiter::WakeComputeRateLimiter; -use crate::types::{EndpointCacheKey, EndpointId}; +use crate::types::{EndpointCacheKey, EndpointId, RoleName}; use crate::{compute, http, scram}; pub(crate) const X_REQUEST_ID: HeaderName = HeaderName::from_static("x-request-id"); @@ -66,65 +65,34 @@ impl NeonControlPlaneClient { self.endpoint.url().as_str() } - async fn do_get_auth_info( - &self, - ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result { - if !self - .caches - .endpoints_cache - .is_valid(ctx, &user_info.endpoint.normalize()) - { - // TODO: refactor this because it's weird - // this is a failure to authenticate but we return Ok. - info!("endpoint is not valid, skipping the request"); - return Ok(AuthInfo::default()); - } - self.do_get_auth_req(user_info, &ctx.session_id(), Some(ctx)) - .await - } - async fn do_get_auth_req( &self, - user_info: &ComputeUserInfo, - session_id: &uuid::Uuid, - ctx: Option<&RequestContext>, + ctx: &RequestContext, + endpoint: &EndpointId, + role: &RoleName, ) -> Result { - let request_id: String = session_id.to_string(); - let application_name = if let Some(ctx) = ctx { - ctx.console_application_name() - } else { - "auth_cancellation".to_string() - }; - async { let request = self .endpoint .get_path("get_endpoint_access_control") - .header(X_REQUEST_ID, &request_id) + .header(X_REQUEST_ID, ctx.session_id().to_string()) .header(AUTHORIZATION, format!("Bearer {}", &self.jwt)) - .query(&[("session_id", session_id)]) + .query(&[("session_id", ctx.session_id())]) .query(&[ - ("application_name", application_name.as_str()), - ("endpointish", user_info.endpoint.as_str()), - ("role", user_info.user.as_str()), + ("application_name", ctx.console_application_name().as_str()), + ("endpointish", endpoint.as_str()), + ("role", role.as_str()), ]) .build()?; debug!(url = request.url().as_str(), "sending http request"); let start = Instant::now(); - let response = match ctx { - Some(ctx) => { - let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Cplane); - let rsp = self.endpoint.execute(request).await; - drop(pause); - rsp? - } - None => self.endpoint.execute(request).await?, + let response = { + let _pause = ctx.latency_timer_pause_at(start, crate::metrics::Waiting::Cplane); + self.endpoint.execute(request).await? }; - info!(duration = ?start.elapsed(), "received http response"); + let body = match parse_body::(response).await { Ok(body) => body, // Error 404 is special: it's ok not to have a secret. @@ -180,7 +148,7 @@ impl NeonControlPlaneClient { async fn do_get_endpoint_jwks( &self, ctx: &RequestContext, - endpoint: EndpointId, + endpoint: &EndpointId, ) -> Result, GetEndpointJwksError> { if !self .caches @@ -313,225 +281,104 @@ impl NeonControlPlaneClient { impl super::ControlPlaneApi for NeonControlPlaneClient { #[tracing::instrument(skip_all)] - async fn get_role_secret( + async fn get_role_access_control( &self, ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result { - let normalized_ep = &user_info.endpoint.normalize(); - let user = &user_info.user; - if let Some(role_secret) = self + endpoint: &EndpointId, + role: &RoleName, + ) -> Result { + let normalized_ep = &endpoint.normalize(); + if let Some(secret) = self .caches .project_info - .get_role_secret(normalized_ep, user) + .get_role_secret(normalized_ep, role) { - return Ok(role_secret); + return Ok(secret); } - let auth_info = self.do_get_auth_info(ctx, user_info).await?; - let account_id = auth_info.account_id; + + if !self.caches.endpoints_cache.is_valid(ctx, normalized_ep) { + info!("endpoint is not valid, skipping the request"); + return Err(GetAuthInfoError::UnknownEndpoint); + } + + let auth_info = self.do_get_auth_req(ctx, endpoint, role).await?; + + let control = EndpointAccessControl { + allowed_ips: Arc::new(auth_info.allowed_ips), + allowed_vpce: Arc::new(auth_info.allowed_vpc_endpoint_ids), + flags: auth_info.access_blocker_flags, + }; + let role_control = RoleAccessControl { + secret: auth_info.secret, + }; + if let Some(project_id) = auth_info.project_id { let normalized_ep_int = normalized_ep.into(); - self.caches.project_info.insert_role_secret( + + self.caches.project_info.insert_endpoint_access( + auth_info.account_id, project_id, normalized_ep_int, - user.into(), - auth_info.secret.clone(), - ); - self.caches.project_info.insert_allowed_ips( - project_id, - normalized_ep_int, - Arc::new(auth_info.allowed_ips), - ); - self.caches.project_info.insert_allowed_vpc_endpoint_ids( - account_id, - project_id, - normalized_ep_int, - Arc::new(auth_info.allowed_vpc_endpoint_ids), - ); - self.caches.project_info.insert_block_public_or_vpc_access( - project_id, - normalized_ep_int, - auth_info.access_blocker_flags, + role.into(), + control, + role_control.clone(), ); ctx.set_project_id(project_id); } - // When we just got a secret, we don't need to invalidate it. - Ok(Cached::new_uncached(auth_info.secret)) + + Ok(role_control) } - async fn get_allowed_ips( + #[tracing::instrument(skip_all)] + async fn get_endpoint_access_control( &self, ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result { - let normalized_ep = &user_info.endpoint.normalize(); - if let Some(allowed_ips) = self.caches.project_info.get_allowed_ips(normalized_ep) { - Metrics::get() - .proxy - .allowed_ips_cache_misses // TODO SR: Should we rename this variable to something like allowed_ip_cache_stats? - .inc(CacheOutcome::Hit); - return Ok(allowed_ips); + endpoint: &EndpointId, + role: &RoleName, + ) -> Result { + let normalized_ep = &endpoint.normalize(); + if let Some(control) = self.caches.project_info.get_endpoint_access(normalized_ep) { + return Ok(control); } - Metrics::get() - .proxy - .allowed_ips_cache_misses - .inc(CacheOutcome::Miss); - let auth_info = self.do_get_auth_info(ctx, user_info).await?; - let allowed_ips = Arc::new(auth_info.allowed_ips); - let allowed_vpc_endpoint_ids = Arc::new(auth_info.allowed_vpc_endpoint_ids); - let access_blocker_flags = auth_info.access_blocker_flags; - let user = &user_info.user; - let account_id = auth_info.account_id; + + if !self.caches.endpoints_cache.is_valid(ctx, normalized_ep) { + info!("endpoint is not valid, skipping the request"); + return Err(GetAuthInfoError::UnknownEndpoint); + } + + let auth_info = self.do_get_auth_req(ctx, endpoint, role).await?; + + let control = EndpointAccessControl { + allowed_ips: Arc::new(auth_info.allowed_ips), + allowed_vpce: Arc::new(auth_info.allowed_vpc_endpoint_ids), + flags: auth_info.access_blocker_flags, + }; + let role_control = RoleAccessControl { + secret: auth_info.secret, + }; + if let Some(project_id) = auth_info.project_id { let normalized_ep_int = normalized_ep.into(); - self.caches.project_info.insert_role_secret( + + self.caches.project_info.insert_endpoint_access( + auth_info.account_id, project_id, normalized_ep_int, - user.into(), - auth_info.secret.clone(), - ); - self.caches.project_info.insert_allowed_ips( - project_id, - normalized_ep_int, - allowed_ips.clone(), - ); - self.caches.project_info.insert_allowed_vpc_endpoint_ids( - account_id, - project_id, - normalized_ep_int, - allowed_vpc_endpoint_ids.clone(), - ); - self.caches.project_info.insert_block_public_or_vpc_access( - project_id, - normalized_ep_int, - access_blocker_flags, + role.into(), + control.clone(), + role_control, ); ctx.set_project_id(project_id); } - Ok(Cached::new_uncached(allowed_ips)) - } - async fn get_allowed_vpc_endpoint_ids( - &self, - ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result { - let normalized_ep = &user_info.endpoint.normalize(); - if let Some(allowed_vpc_endpoint_ids) = self - .caches - .project_info - .get_allowed_vpc_endpoint_ids(normalized_ep) - { - Metrics::get() - .proxy - .vpc_endpoint_id_cache_stats - .inc(CacheOutcome::Hit); - return Ok(allowed_vpc_endpoint_ids); - } - - Metrics::get() - .proxy - .vpc_endpoint_id_cache_stats - .inc(CacheOutcome::Miss); - - let auth_info = self.do_get_auth_info(ctx, user_info).await?; - let allowed_ips = Arc::new(auth_info.allowed_ips); - let allowed_vpc_endpoint_ids = Arc::new(auth_info.allowed_vpc_endpoint_ids); - let access_blocker_flags = auth_info.access_blocker_flags; - let user = &user_info.user; - let account_id = auth_info.account_id; - if let Some(project_id) = auth_info.project_id { - let normalized_ep_int = normalized_ep.into(); - self.caches.project_info.insert_role_secret( - project_id, - normalized_ep_int, - user.into(), - auth_info.secret.clone(), - ); - self.caches.project_info.insert_allowed_ips( - project_id, - normalized_ep_int, - allowed_ips.clone(), - ); - self.caches.project_info.insert_allowed_vpc_endpoint_ids( - account_id, - project_id, - normalized_ep_int, - allowed_vpc_endpoint_ids.clone(), - ); - self.caches.project_info.insert_block_public_or_vpc_access( - project_id, - normalized_ep_int, - access_blocker_flags, - ); - ctx.set_project_id(project_id); - } - Ok(Cached::new_uncached(allowed_vpc_endpoint_ids)) - } - - async fn get_block_public_or_vpc_access( - &self, - ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result { - let normalized_ep = &user_info.endpoint.normalize(); - if let Some(access_blocker_flags) = self - .caches - .project_info - .get_block_public_or_vpc_access(normalized_ep) - { - Metrics::get() - .proxy - .access_blocker_flags_cache_stats - .inc(CacheOutcome::Hit); - return Ok(access_blocker_flags); - } - - Metrics::get() - .proxy - .access_blocker_flags_cache_stats - .inc(CacheOutcome::Miss); - - let auth_info = self.do_get_auth_info(ctx, user_info).await?; - let allowed_ips = Arc::new(auth_info.allowed_ips); - let allowed_vpc_endpoint_ids = Arc::new(auth_info.allowed_vpc_endpoint_ids); - let access_blocker_flags = auth_info.access_blocker_flags; - let user = &user_info.user; - let account_id = auth_info.account_id; - if let Some(project_id) = auth_info.project_id { - let normalized_ep_int = normalized_ep.into(); - self.caches.project_info.insert_role_secret( - project_id, - normalized_ep_int, - user.into(), - auth_info.secret.clone(), - ); - self.caches.project_info.insert_allowed_ips( - project_id, - normalized_ep_int, - allowed_ips.clone(), - ); - self.caches.project_info.insert_allowed_vpc_endpoint_ids( - account_id, - project_id, - normalized_ep_int, - allowed_vpc_endpoint_ids.clone(), - ); - self.caches.project_info.insert_block_public_or_vpc_access( - project_id, - normalized_ep_int, - access_blocker_flags.clone(), - ); - ctx.set_project_id(project_id); - } - Ok(Cached::new_uncached(access_blocker_flags)) + Ok(control) } #[tracing::instrument(skip_all)] async fn get_endpoint_jwks( &self, ctx: &RequestContext, - endpoint: EndpointId, + endpoint: &EndpointId, ) -> Result, GetEndpointJwksError> { self.do_get_endpoint_jwks(ctx, endpoint).await } diff --git a/proxy/src/control_plane/client/mock.rs b/proxy/src/control_plane/client/mock.rs index d3ab4abd0b..ece7153fce 100644 --- a/proxy/src/control_plane/client/mock.rs +++ b/proxy/src/control_plane/client/mock.rs @@ -15,14 +15,14 @@ use crate::auth::backend::ComputeUserInfo; use crate::auth::backend::jwt::AuthRule; use crate::cache::Cached; use crate::context::RequestContext; -use crate::control_plane::client::{ - CachedAllowedIps, CachedAllowedVpcEndpointIds, CachedRoleSecret, -}; use crate::control_plane::errors::{ ControlPlaneError, GetAuthInfoError, GetEndpointJwksError, WakeComputeError, }; use crate::control_plane::messages::MetricsAuxInfo; -use crate::control_plane::{AccessBlockerFlags, AuthInfo, AuthSecret, CachedNodeInfo, NodeInfo}; +use crate::control_plane::{ + AccessBlockerFlags, AuthInfo, AuthSecret, CachedNodeInfo, EndpointAccessControl, NodeInfo, + RoleAccessControl, +}; use crate::intern::RoleNameInt; use crate::types::{BranchId, EndpointId, ProjectId, RoleName}; use crate::url::ApiUrl; @@ -66,7 +66,8 @@ impl MockControlPlane { async fn do_get_auth_info( &self, - user_info: &ComputeUserInfo, + endpoint: &EndpointId, + role: &RoleName, ) -> Result { let (secret, allowed_ips) = async { // Perhaps we could persist this connection, but then we'd have to @@ -80,7 +81,7 @@ impl MockControlPlane { let secret = if let Some(entry) = get_execute_postgres_query( &client, "select rolpassword from pg_catalog.pg_authid where rolname = $1", - &[&&*user_info.user], + &[&role.as_str()], "rolpassword", ) .await? @@ -89,7 +90,7 @@ impl MockControlPlane { let secret = scram::ServerSecret::parse(&entry).map(AuthSecret::Scram); secret.or_else(|| parse_md5(&entry).map(AuthSecret::Md5)) } else { - warn!("user '{}' does not exist", user_info.user); + warn!("user '{role}' does not exist"); None }; @@ -97,7 +98,7 @@ impl MockControlPlane { match get_execute_postgres_query( &client, "select allowed_ips from neon_control_plane.endpoints where endpoint_id = $1", - &[&user_info.endpoint.as_str()], + &[&endpoint.as_str()], "allowed_ips", ) .await? @@ -133,7 +134,7 @@ impl MockControlPlane { async fn do_get_endpoint_jwks( &self, - endpoint: EndpointId, + endpoint: &EndpointId, ) -> Result, GetEndpointJwksError> { let (client, connection) = tokio_postgres::connect(self.endpoint.as_str(), tokio_postgres::NoTls).await?; @@ -222,53 +223,36 @@ async fn get_execute_postgres_query( } impl super::ControlPlaneApi for MockControlPlane { - #[tracing::instrument(skip_all)] - async fn get_role_secret( + async fn get_endpoint_access_control( &self, _ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result { - Ok(CachedRoleSecret::new_uncached( - self.do_get_auth_info(user_info).await?.secret, - )) + endpoint: &EndpointId, + role: &RoleName, + ) -> Result { + let info = self.do_get_auth_info(endpoint, role).await?; + Ok(EndpointAccessControl { + allowed_ips: Arc::new(info.allowed_ips), + allowed_vpce: Arc::new(info.allowed_vpc_endpoint_ids), + flags: info.access_blocker_flags, + }) } - async fn get_allowed_ips( + async fn get_role_access_control( &self, _ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result { - Ok(Cached::new_uncached(Arc::new( - self.do_get_auth_info(user_info).await?.allowed_ips, - ))) - } - - async fn get_allowed_vpc_endpoint_ids( - &self, - _ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result { - Ok(Cached::new_uncached(Arc::new( - self.do_get_auth_info(user_info) - .await? - .allowed_vpc_endpoint_ids, - ))) - } - - async fn get_block_public_or_vpc_access( - &self, - _ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result { - Ok(Cached::new_uncached( - self.do_get_auth_info(user_info).await?.access_blocker_flags, - )) + endpoint: &EndpointId, + role: &RoleName, + ) -> Result { + let info = self.do_get_auth_info(endpoint, role).await?; + Ok(RoleAccessControl { + secret: info.secret, + }) } async fn get_endpoint_jwks( &self, _ctx: &RequestContext, - endpoint: EndpointId, + endpoint: &EndpointId, ) -> Result, GetEndpointJwksError> { self.do_get_endpoint_jwks(endpoint).await } diff --git a/proxy/src/control_plane/client/mod.rs b/proxy/src/control_plane/client/mod.rs index 746595de38..9b9d1e25ea 100644 --- a/proxy/src/control_plane/client/mod.rs +++ b/proxy/src/control_plane/client/mod.rs @@ -16,15 +16,14 @@ use crate::cache::endpoints::EndpointsCache; use crate::cache::project_info::ProjectInfoCacheImpl; use crate::config::{CacheOptions, EndpointCacheConfig, ProjectInfoCacheOptions}; use crate::context::RequestContext; -use crate::control_plane::{ - CachedAccessBlockerFlags, CachedAllowedIps, CachedAllowedVpcEndpointIds, CachedNodeInfo, - CachedRoleSecret, ControlPlaneApi, NodeInfoCache, errors, -}; +use crate::control_plane::{CachedNodeInfo, ControlPlaneApi, NodeInfoCache, errors}; use crate::error::ReportableError; use crate::metrics::ApiLockMetrics; use crate::rate_limiter::{DynamicLimiter, Outcome, RateLimiterConfig, Token}; use crate::types::EndpointId; +use super::{EndpointAccessControl, RoleAccessControl}; + #[non_exhaustive] #[derive(Clone)] pub enum ControlPlaneClient { @@ -40,68 +39,42 @@ pub enum ControlPlaneClient { } impl ControlPlaneApi for ControlPlaneClient { - async fn get_role_secret( + async fn get_role_access_control( &self, ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result { + endpoint: &EndpointId, + role: &crate::types::RoleName, + ) -> Result { match self { - Self::ProxyV1(api) => api.get_role_secret(ctx, user_info).await, + Self::ProxyV1(api) => api.get_role_access_control(ctx, endpoint, role).await, #[cfg(any(test, feature = "testing"))] - Self::PostgresMock(api) => api.get_role_secret(ctx, user_info).await, + Self::PostgresMock(api) => api.get_role_access_control(ctx, endpoint, role).await, #[cfg(test)] - Self::Test(_) => { + Self::Test(_api) => { unreachable!("this function should never be called in the test backend") } } } - async fn get_allowed_ips( + async fn get_endpoint_access_control( &self, ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result { + endpoint: &EndpointId, + role: &crate::types::RoleName, + ) -> Result { match self { - Self::ProxyV1(api) => api.get_allowed_ips(ctx, user_info).await, + Self::ProxyV1(api) => api.get_endpoint_access_control(ctx, endpoint, role).await, #[cfg(any(test, feature = "testing"))] - Self::PostgresMock(api) => api.get_allowed_ips(ctx, user_info).await, + Self::PostgresMock(api) => api.get_endpoint_access_control(ctx, endpoint, role).await, #[cfg(test)] - Self::Test(api) => api.get_allowed_ips(), - } - } - - async fn get_allowed_vpc_endpoint_ids( - &self, - ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result { - match self { - Self::ProxyV1(api) => api.get_allowed_vpc_endpoint_ids(ctx, user_info).await, - #[cfg(any(test, feature = "testing"))] - Self::PostgresMock(api) => api.get_allowed_vpc_endpoint_ids(ctx, user_info).await, - #[cfg(test)] - Self::Test(api) => api.get_allowed_vpc_endpoint_ids(), - } - } - - async fn get_block_public_or_vpc_access( - &self, - ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result { - match self { - Self::ProxyV1(api) => api.get_block_public_or_vpc_access(ctx, user_info).await, - #[cfg(any(test, feature = "testing"))] - Self::PostgresMock(api) => api.get_block_public_or_vpc_access(ctx, user_info).await, - #[cfg(test)] - Self::Test(api) => api.get_block_public_or_vpc_access(), + Self::Test(api) => api.get_access_control(), } } async fn get_endpoint_jwks( &self, ctx: &RequestContext, - endpoint: EndpointId, + endpoint: &EndpointId, ) -> Result, errors::GetEndpointJwksError> { match self { Self::ProxyV1(api) => api.get_endpoint_jwks(ctx, endpoint).await, @@ -131,15 +104,7 @@ impl ControlPlaneApi for ControlPlaneClient { pub(crate) trait TestControlPlaneClient: Send + Sync + 'static { fn wake_compute(&self) -> Result; - fn get_allowed_ips(&self) -> Result; - - fn get_allowed_vpc_endpoint_ids( - &self, - ) -> Result; - - fn get_block_public_or_vpc_access( - &self, - ) -> Result; + fn get_access_control(&self) -> Result; fn dyn_clone(&self) -> Box; } @@ -309,7 +274,7 @@ impl FetchAuthRules for ControlPlaneClient { ctx: &RequestContext, endpoint: EndpointId, ) -> Result, FetchAuthRulesError> { - self.get_endpoint_jwks(ctx, endpoint) + self.get_endpoint_jwks(ctx, &endpoint) .await .map_err(FetchAuthRulesError::GetEndpointJwks) } diff --git a/proxy/src/control_plane/errors.rs b/proxy/src/control_plane/errors.rs index 850d061333..77312c89c5 100644 --- a/proxy/src/control_plane/errors.rs +++ b/proxy/src/control_plane/errors.rs @@ -99,6 +99,10 @@ pub(crate) enum GetAuthInfoError { #[error(transparent)] ApiError(ControlPlaneError), + + /// Proxy does not know about the endpoint in advanced + #[error("endpoint not found in endpoint cache")] + UnknownEndpoint, } // This allows more useful interactions than `#[from]`. @@ -115,6 +119,8 @@ impl UserFacingError for GetAuthInfoError { Self::BadSecret => REQUEST_FAILED.to_owned(), // However, API might return a meaningful error. Self::ApiError(e) => e.to_string_client(), + // pretend like control plane returned an error. + Self::UnknownEndpoint => REQUEST_FAILED.to_owned(), } } } @@ -124,6 +130,8 @@ impl ReportableError for GetAuthInfoError { match self { Self::BadSecret => crate::error::ErrorKind::ControlPlane, Self::ApiError(_) => crate::error::ErrorKind::ControlPlane, + // we only apply endpoint filtering if control plane is under high load. + Self::UnknownEndpoint => crate::error::ErrorKind::ServiceRateLimit, } } } diff --git a/proxy/src/control_plane/mod.rs b/proxy/src/control_plane/mod.rs index d592223be1..7ff093d9dc 100644 --- a/proxy/src/control_plane/mod.rs +++ b/proxy/src/control_plane/mod.rs @@ -11,16 +11,16 @@ pub(crate) mod errors; use std::sync::Arc; -use crate::auth::IpPattern; use crate::auth::backend::jwt::AuthRule; use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo}; -use crate::cache::project_info::ProjectInfoCacheImpl; +use crate::auth::{AuthError, IpPattern, check_peer_addr_is_in_list}; use crate::cache::{Cached, TimedLru}; use crate::config::ComputeConfig; use crate::context::RequestContext; use crate::control_plane::messages::{ControlPlaneErrorMessage, MetricsAuxInfo}; use crate::intern::{AccountIdInt, ProjectIdInt}; -use crate::types::{EndpointCacheKey, EndpointId}; +use crate::protocol2::ConnectionInfoExtra; +use crate::types::{EndpointCacheKey, EndpointId, RoleName}; use crate::{compute, scram}; /// Various cache-related types. @@ -101,7 +101,7 @@ impl NodeInfo { } } -#[derive(Clone, Default, Eq, PartialEq, Debug)] +#[derive(Copy, Clone, Default)] pub(crate) struct AccessBlockerFlags { pub public_access_blocked: bool, pub vpc_access_blocked: bool, @@ -110,47 +110,78 @@ pub(crate) struct AccessBlockerFlags { pub(crate) type NodeInfoCache = TimedLru>>; pub(crate) type CachedNodeInfo = Cached<&'static NodeInfoCache, NodeInfo>; -pub(crate) type CachedRoleSecret = Cached<&'static ProjectInfoCacheImpl, Option>; -pub(crate) type CachedAllowedIps = Cached<&'static ProjectInfoCacheImpl, Arc>>; -pub(crate) type CachedAllowedVpcEndpointIds = - Cached<&'static ProjectInfoCacheImpl, Arc>>; -pub(crate) type CachedAccessBlockerFlags = - Cached<&'static ProjectInfoCacheImpl, AccessBlockerFlags>; + +#[derive(Clone)] +pub struct RoleAccessControl { + pub secret: Option, +} + +#[derive(Clone)] +pub struct EndpointAccessControl { + pub allowed_ips: Arc>, + pub allowed_vpce: Arc>, + pub flags: AccessBlockerFlags, +} + +impl EndpointAccessControl { + pub fn check( + &self, + ctx: &RequestContext, + check_ip_allowed: bool, + check_vpc_allowed: bool, + ) -> Result<(), AuthError> { + if check_ip_allowed && !check_peer_addr_is_in_list(&ctx.peer_addr(), &self.allowed_ips) { + return Err(AuthError::IpAddressNotAllowed(ctx.peer_addr())); + } + + // check if a VPC endpoint ID is coming in and if yes, if it's allowed + if check_vpc_allowed { + if self.flags.vpc_access_blocked { + return Err(AuthError::NetworkNotAllowed); + } + + let incoming_vpc_endpoint_id = match ctx.extra() { + None => return Err(AuthError::MissingVPCEndpointId), + Some(ConnectionInfoExtra::Aws { vpce_id }) => vpce_id.to_string(), + Some(ConnectionInfoExtra::Azure { link_id }) => link_id.to_string(), + }; + + let vpce = &self.allowed_vpce; + // TODO: For now an empty VPC endpoint ID list means all are allowed. We should replace that. + if !vpce.is_empty() && !vpce.contains(&incoming_vpc_endpoint_id) { + return Err(AuthError::vpc_endpoint_id_not_allowed( + incoming_vpc_endpoint_id, + )); + } + } else if self.flags.public_access_blocked { + return Err(AuthError::NetworkNotAllowed); + } + + Ok(()) + } +} /// This will allocate per each call, but the http requests alone /// already require a few allocations, so it should be fine. pub(crate) trait ControlPlaneApi { - /// Get the client's auth secret for authentication. - /// Returns option because user not found situation is special. - /// We still have to mock the scram to avoid leaking information that user doesn't exist. - async fn get_role_secret( + async fn get_role_access_control( &self, ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result; + endpoint: &EndpointId, + role: &RoleName, + ) -> Result; - async fn get_allowed_ips( + async fn get_endpoint_access_control( &self, ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result; - - async fn get_allowed_vpc_endpoint_ids( - &self, - ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result; - - async fn get_block_public_or_vpc_access( - &self, - ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result; + endpoint: &EndpointId, + role: &RoleName, + ) -> Result; async fn get_endpoint_jwks( &self, ctx: &RequestContext, - endpoint: EndpointId, + endpoint: &EndpointId, ) -> Result, errors::GetEndpointJwksError>; /// Wake up the compute node and return the corresponding connection info. diff --git a/proxy/src/proxy/mod.rs b/proxy/src/proxy/mod.rs index 26ac6a89e7..ac0aca1176 100644 --- a/proxy/src/proxy/mod.rs +++ b/proxy/src/proxy/mod.rs @@ -345,7 +345,7 @@ pub(crate) async fn handle_client( }; let user = user_info.get_user().to_owned(); - let (user_info, _ip_allowlist) = match user_info + let user_info = match user_info .authenticate( ctx, &mut stream, diff --git a/proxy/src/proxy/tests/mod.rs b/proxy/src/proxy/tests/mod.rs index 3cc053e0ad..61e8ee4a10 100644 --- a/proxy/src/proxy/tests/mod.rs +++ b/proxy/src/proxy/tests/mod.rs @@ -26,9 +26,7 @@ use crate::auth::backend::{ use crate::config::{ComputeConfig, RetryConfig}; use crate::control_plane::client::{ControlPlaneClient, TestControlPlaneClient}; use crate::control_plane::messages::{ControlPlaneErrorMessage, Details, MetricsAuxInfo, Status}; -use crate::control_plane::{ - self, CachedAllowedIps, CachedAllowedVpcEndpointIds, CachedNodeInfo, NodeInfo, NodeInfoCache, -}; +use crate::control_plane::{self, CachedNodeInfo, NodeInfo, NodeInfoCache}; use crate::error::ErrorKind; use crate::tls::client_config::compute_client_config_with_certs; use crate::tls::postgres_rustls::MakeRustlsConnect; @@ -547,20 +545,9 @@ impl TestControlPlaneClient for TestConnectMechanism { } } - fn get_allowed_ips(&self) -> Result { - unimplemented!("not used in tests") - } - - fn get_allowed_vpc_endpoint_ids( + fn get_access_control( &self, - ) -> Result { - unimplemented!("not used in tests") - } - - fn get_block_public_or_vpc_access( - &self, - ) -> Result - { + ) -> Result { unimplemented!("not used in tests") } diff --git a/proxy/src/rate_limiter/leaky_bucket.rs b/proxy/src/rate_limiter/leaky_bucket.rs index 4f27c6faef..0c79b5e92f 100644 --- a/proxy/src/rate_limiter/leaky_bucket.rs +++ b/proxy/src/rate_limiter/leaky_bucket.rs @@ -15,7 +15,7 @@ pub type EndpointRateLimiter = LeakyBucketRateLimiter; pub struct LeakyBucketRateLimiter { map: ClashMap, - config: utils::leaky_bucket::LeakyBucketConfig, + default_config: utils::leaky_bucket::LeakyBucketConfig, access_count: AtomicUsize, } @@ -28,15 +28,17 @@ impl LeakyBucketRateLimiter { pub fn new_with_shards(config: LeakyBucketConfig, shards: usize) -> Self { Self { map: ClashMap::with_hasher_and_shard_amount(RandomState::new(), shards), - config: config.into(), + default_config: config.into(), access_count: AtomicUsize::new(0), } } /// Check that number of connections to the endpoint is below `max_rps` rps. - pub(crate) fn check(&self, key: K, n: u32) -> bool { + pub(crate) fn check(&self, key: K, config: Option, n: u32) -> bool { let now = Instant::now(); + let config = config.map_or(self.default_config, Into::into); + if self.access_count.fetch_add(1, Ordering::AcqRel) % 2048 == 0 { self.do_gc(now); } @@ -46,7 +48,7 @@ impl LeakyBucketRateLimiter { .entry(key) .or_insert_with(|| LeakyBucketState { empty_at: now }); - entry.add_tokens(&self.config, now, n as f64).is_ok() + entry.add_tokens(&config, now, n as f64).is_ok() } fn do_gc(&self, now: Instant) { diff --git a/proxy/src/rate_limiter/limiter.rs b/proxy/src/rate_limiter/limiter.rs index 21eaa6739b..9d700c1b52 100644 --- a/proxy/src/rate_limiter/limiter.rs +++ b/proxy/src/rate_limiter/limiter.rs @@ -15,6 +15,8 @@ use tracing::info; use crate::ext::LockExt; use crate::intern::EndpointIdInt; +use super::LeakyBucketConfig; + pub struct GlobalRateLimiter { data: Vec, info: Vec, @@ -144,19 +146,6 @@ impl RateBucketInfo { Self::new(50_000, Duration::from_secs(10)), ]; - /// All of these are per endpoint-maskedip pair. - /// Context: 4096 rounds of pbkdf2 take about 1ms of cpu time to execute (1 milli-cpu-second or 1mcpus). - /// - /// First bucket: 1000mcpus total per endpoint-ip pair - /// * 4096000 requests per second with 1 hash rounds. - /// * 1000 requests per second with 4096 hash rounds. - /// * 6.8 requests per second with 600000 hash rounds. - pub const DEFAULT_AUTH_SET: [Self; 3] = [ - Self::new(1000 * 4096, Duration::from_secs(1)), - Self::new(600 * 4096, Duration::from_secs(60)), - Self::new(300 * 4096, Duration::from_secs(600)), - ]; - pub fn rps(&self) -> f64 { (self.max_rpi as f64) / self.interval.as_secs_f64() } @@ -184,6 +173,21 @@ impl RateBucketInfo { max_rpi: ((max_rps as u64) * (interval.as_millis() as u64) / 1000) as u32, } } + + pub fn to_leaky_bucket(this: &[Self]) -> Option { + // bit of a hack - find the min rps and max rps supported and turn it into + // leaky bucket config instead + + let mut iter = this.iter().map(|info| info.rps()); + let first = iter.next()?; + + let (min, max) = (first, first); + let (min, max) = iter.fold((min, max), |(min, max), rps| { + (f64::min(min, rps), f64::max(max, rps)) + }); + + Some(LeakyBucketConfig { rps: min, max }) + } } impl BucketRateLimiter { diff --git a/proxy/src/rate_limiter/mod.rs b/proxy/src/rate_limiter/mod.rs index 5f90102da3..112b95873a 100644 --- a/proxy/src/rate_limiter/mod.rs +++ b/proxy/src/rate_limiter/mod.rs @@ -8,4 +8,4 @@ pub(crate) use limit_algorithm::aimd::Aimd; pub(crate) use limit_algorithm::{ DynamicLimiter, Outcome, RateLimitAlgorithm, RateLimiterConfig, Token, }; -pub use limiter::{BucketRateLimiter, GlobalRateLimiter, RateBucketInfo, WakeComputeRateLimiter}; +pub use limiter::{GlobalRateLimiter, RateBucketInfo, WakeComputeRateLimiter}; diff --git a/proxy/src/redis/notifications.rs b/proxy/src/redis/notifications.rs index 769d519d94..a9d6b40603 100644 --- a/proxy/src/redis/notifications.rs +++ b/proxy/src/redis/notifications.rs @@ -233,29 +233,30 @@ impl MessageHandler { fn invalidate_cache(cache: Arc, msg: Notification) { match msg { - Notification::AllowedIpsUpdate { allowed_ips_update } => { - cache.invalidate_allowed_ips_for_project(allowed_ips_update.project_id); + Notification::AllowedIpsUpdate { + allowed_ips_update: AllowedIpsUpdate { project_id }, } - Notification::BlockPublicOrVpcAccessUpdated { - block_public_or_vpc_access_updated, - } => cache.invalidate_block_public_or_vpc_access_for_project( - block_public_or_vpc_access_updated.project_id, - ), + | Notification::BlockPublicOrVpcAccessUpdated { + block_public_or_vpc_access_updated: BlockPublicOrVpcAccessUpdated { project_id }, + } => cache.invalidate_endpoint_access_for_project(project_id), Notification::AllowedVpcEndpointsUpdatedForOrg { - allowed_vpc_endpoints_updated_for_org, - } => cache.invalidate_allowed_vpc_endpoint_ids_for_org( - allowed_vpc_endpoints_updated_for_org.account_id, - ), + allowed_vpc_endpoints_updated_for_org: AllowedVpcEndpointsUpdatedForOrg { account_id }, + } => cache.invalidate_endpoint_access_for_org(account_id), Notification::AllowedVpcEndpointsUpdatedForProjects { - allowed_vpc_endpoints_updated_for_projects, - } => cache.invalidate_allowed_vpc_endpoint_ids_for_projects( - allowed_vpc_endpoints_updated_for_projects.project_ids, - ), - Notification::PasswordUpdate { password_update } => cache - .invalidate_role_secret_for_project( - password_update.project_id, - password_update.role_name, - ), + allowed_vpc_endpoints_updated_for_projects: + AllowedVpcEndpointsUpdatedForProjects { project_ids }, + } => { + for project in project_ids { + cache.invalidate_endpoint_access_for_project(project); + } + } + Notification::PasswordUpdate { + password_update: + PasswordUpdate { + project_id, + role_name, + }, + } => cache.invalidate_role_secret_for_project(project_id, role_name), Notification::UnknownTopic => unreachable!(), } } diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index 13058f08f1..bf640c05e9 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -22,7 +22,7 @@ use super::http_conn_pool::{self, HttpConnPool, Send, poll_http2_client}; use super::local_conn_pool::{self, EXT_NAME, EXT_SCHEMA, EXT_VERSION, LocalConnPool}; use crate::auth::backend::local::StaticAuthRules; use crate::auth::backend::{ComputeCredentials, ComputeUserInfo}; -use crate::auth::{self, AuthError, check_peer_addr_is_in_list}; +use crate::auth::{self, AuthError}; use crate::compute; use crate::compute_ctl::{ ComputeCtlError, ExtensionInstallRequest, Privilege, SetRoleGrantsRequest, @@ -35,7 +35,6 @@ use crate::control_plane::errors::{GetAuthInfoError, WakeComputeError}; use crate::control_plane::locks::ApiLocks; use crate::error::{ErrorKind, ReportableError, UserFacingError}; use crate::intern::EndpointIdInt; -use crate::protocol2::ConnectionInfoExtra; use crate::proxy::connect_compute::ConnectMechanism; use crate::proxy::retry::{CouldRetry, ShouldRetryWakeCompute}; use crate::rate_limiter::EndpointRateLimiter; @@ -63,63 +62,24 @@ impl PoolingBackend { let user_info = user_info.clone(); let backend = self.auth_backend.as_ref().map(|()| user_info.clone()); - let allowed_ips = backend.get_allowed_ips(ctx).await?; + let access_control = backend.get_endpoint_access_control(ctx).await?; + access_control.check( + ctx, + self.config.authentication_config.ip_allowlist_check_enabled, + self.config.authentication_config.is_vpc_acccess_proxy, + )?; - if self.config.authentication_config.ip_allowlist_check_enabled - && !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips) - { - return Err(AuthError::ip_address_not_allowed(ctx.peer_addr())); - } - - let access_blocker_flags = backend.get_block_public_or_vpc_access(ctx).await?; - if self.config.authentication_config.is_vpc_acccess_proxy { - if access_blocker_flags.vpc_access_blocked { - return Err(AuthError::NetworkNotAllowed); - } - - let extra = ctx.extra(); - let incoming_endpoint_id = match extra { - None => String::new(), - Some(ConnectionInfoExtra::Aws { vpce_id }) => vpce_id.to_string(), - Some(ConnectionInfoExtra::Azure { link_id }) => link_id.to_string(), - }; - - if incoming_endpoint_id.is_empty() { - return Err(AuthError::MissingVPCEndpointId); - } - - let allowed_vpc_endpoint_ids = backend.get_allowed_vpc_endpoint_ids(ctx).await?; - // TODO: For now an empty VPC endpoint ID list means all are allowed. We should replace that. - if !allowed_vpc_endpoint_ids.is_empty() - && !allowed_vpc_endpoint_ids.contains(&incoming_endpoint_id) - { - return Err(AuthError::vpc_endpoint_id_not_allowed(incoming_endpoint_id)); - } - } else if access_blocker_flags.public_access_blocked { - return Err(AuthError::NetworkNotAllowed); - } - - if !self - .endpoint_rate_limiter - .check(user_info.endpoint.clone().into(), 1) - { + let ep = EndpointIdInt::from(&user_info.endpoint); + let rate_limit_config = None; + if !self.endpoint_rate_limiter.check(ep, rate_limit_config, 1) { return Err(AuthError::too_many_connections()); } - let cached_secret = backend.get_role_secret(ctx).await?; - let secret = match cached_secret.value.clone() { - Some(secret) => self.config.authentication_config.check_rate_limit( - ctx, - secret, - &user_info.endpoint, - true, - )?, - None => { - // If we don't have an authentication secret, for the http flow we can just return an error. - info!("authentication info not found"); - return Err(AuthError::password_failed(&*user_info.user)); - } + let role_access = backend.get_role_secret(ctx).await?; + let Some(secret) = role_access.secret else { + // If we don't have an authentication secret, for the http flow we can just return an error. + info!("authentication info not found"); + return Err(AuthError::password_failed(&*user_info.user)); }; - let ep = EndpointIdInt::from(&user_info.endpoint); let auth_outcome = crate::auth::validate_password_and_exchange( &self.config.authentication_config.thread_pool, ep, From af5bb67f08e92c018111055dcef26d9e3bab665a Mon Sep 17 00:00:00 2001 From: Vlad Lazar Date: Mon, 2 Jun 2025 09:59:21 +0100 Subject: [PATCH 40/48] pageserver: more reactive wal receiver cancellation (#12076) ## Problem If the wal receiver is cancelled, there's a 50% chance that it will ingest yet more WAL. ## Summary of Changes Always check cancellation first. --- pageserver/src/tenant/timeline/walreceiver.rs | 2 +- .../src/tenant/timeline/walreceiver/walreceiver_connection.rs | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/pageserver/src/tenant/timeline/walreceiver.rs b/pageserver/src/tenant/timeline/walreceiver.rs index 0f73eb839b..633c94a010 100644 --- a/pageserver/src/tenant/timeline/walreceiver.rs +++ b/pageserver/src/tenant/timeline/walreceiver.rs @@ -113,7 +113,7 @@ impl WalReceiver { } connection_manager_state.shutdown().await; *loop_status.write().unwrap() = None; - debug!("task exits"); + info!("task exits"); } .instrument(info_span!(parent: None, "wal_connection_manager", tenant_id = %tenant_shard_id.tenant_id, shard_id = %tenant_shard_id.shard_slug(), timeline_id = %timeline_id)) }); diff --git a/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs b/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs index 52259f205b..249849ac4b 100644 --- a/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs +++ b/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs @@ -297,6 +297,7 @@ pub(super) async fn handle_walreceiver_connection( let mut expected_wal_start = startpoint; while let Some(replication_message) = { select! { + biased; _ = cancellation.cancelled() => { debug!("walreceiver interrupted"); None From 5b62749c42c256db001eb3e59d8d289b4ad942cd Mon Sep 17 00:00:00 2001 From: Vlad Lazar Date: Mon, 2 Jun 2025 11:29:15 +0100 Subject: [PATCH 41/48] pageserver: reduce import memory utilization (#12086) ## Problem Imports can end up allocating too much. ## Summary of Changes Nerf them a bunch and add some logs. --- libs/pageserver_api/src/config.rs | 6 +++--- pageserver/src/tenant/timeline/import_pgdata.rs | 2 ++ pageserver/src/tenant/timeline/import_pgdata/flow.rs | 12 +++++++++++- 3 files changed, 16 insertions(+), 4 deletions(-) diff --git a/libs/pageserver_api/src/config.rs b/libs/pageserver_api/src/config.rs index 012c020fb1..444983bd18 100644 --- a/libs/pageserver_api/src/config.rs +++ b/libs/pageserver_api/src/config.rs @@ -713,9 +713,9 @@ impl Default for ConfigToml { enable_tls_page_service_api: false, dev_mode: false, timeline_import_config: TimelineImportConfig { - import_job_concurrency: NonZeroUsize::new(128).unwrap(), - import_job_soft_size_limit: NonZeroUsize::new(1024 * 1024 * 1024).unwrap(), - import_job_checkpoint_threshold: NonZeroUsize::new(128).unwrap(), + import_job_concurrency: NonZeroUsize::new(32).unwrap(), + import_job_soft_size_limit: NonZeroUsize::new(256 * 1024 * 1024).unwrap(), + import_job_checkpoint_threshold: NonZeroUsize::new(32).unwrap(), }, basebackup_cache_config: None, posthog_config: None, diff --git a/pageserver/src/tenant/timeline/import_pgdata.rs b/pageserver/src/tenant/timeline/import_pgdata.rs index f19a4b3e9c..3f760d858b 100644 --- a/pageserver/src/tenant/timeline/import_pgdata.rs +++ b/pageserver/src/tenant/timeline/import_pgdata.rs @@ -106,6 +106,8 @@ pub async fn doit( ); } + tracing::info!("Import plan executed. Flushing remote changes and notifying storcon"); + timeline .remote_client .schedule_index_upload_for_file_changes()?; diff --git a/pageserver/src/tenant/timeline/import_pgdata/flow.rs b/pageserver/src/tenant/timeline/import_pgdata/flow.rs index bf3c7eeda6..760e82dd57 100644 --- a/pageserver/src/tenant/timeline/import_pgdata/flow.rs +++ b/pageserver/src/tenant/timeline/import_pgdata/flow.rs @@ -130,7 +130,15 @@ async fn run_v1( pausable_failpoint!("import-timeline-pre-execute-pausable"); + let jobs_count = import_progress.as_ref().map(|p| p.jobs); let start_from_job_idx = import_progress.map(|progress| progress.completed); + + tracing::info!( + start_from_job_idx=?start_from_job_idx, + jobs=?jobs_count, + "Executing import plan" + ); + plan.execute(timeline, start_from_job_idx, plan_hash, &import_config, ctx) .await } @@ -484,6 +492,8 @@ impl Plan { anyhow::anyhow!("Shut down while putting timeline import status") })?; } + + tracing::info!(last_completed_job_idx, jobs=%jobs_in_plan, "Checkpointing import status"); }, Some(Err(_)) => { anyhow::bail!( @@ -760,7 +770,7 @@ impl ImportTask for ImportRelBlocksTask { layer_writer: &mut ImageLayerWriter, ctx: &RequestContext, ) -> anyhow::Result { - const MAX_BYTE_RANGE_SIZE: usize = 128 * 1024 * 1024; + const MAX_BYTE_RANGE_SIZE: usize = 4 * 1024 * 1024; debug!("Importing relation file"); From 8d7ed2a4ee1e2753d8a3ac17c6bc43ccabc2ed2e Mon Sep 17 00:00:00 2001 From: Erik Grinaker Date: Mon, 2 Jun 2025 13:46:50 +0200 Subject: [PATCH 42/48] pageserver: add gRPC observability middleware (#12093) ## Problem The page service logic asserts that a tracing span is present with tenant/timeline/shard IDs. An initial gRPC page service implementation thus requires a tracing span. Touches https://github.com/neondatabase/neon/issues/11728. ## Summary of changes Adds an `ObservabilityLayer` middleware that generates a tracing span and decorates it with IDs from the gRPC metadata. This is a minimal implementation to address the tracing span assertion. It will be extended with additional observability in later PRs. --- Cargo.lock | 2 + libs/utils/src/lib.rs | 1 + libs/utils/src/span.rs | 19 +++++ pageserver/Cargo.toml | 2 + pageserver/src/page_service.rs | 137 ++++++++++++++++++++++++++------- 5 files changed, 134 insertions(+), 27 deletions(-) create mode 100644 libs/utils/src/span.rs diff --git a/Cargo.lock b/Cargo.lock index 89351432c1..4f7378e95d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4305,6 +4305,7 @@ dependencies = [ "hashlink", "hex", "hex-literal", + "http 1.1.0", "http-utils", "humantime", "humantime-serde", @@ -4367,6 +4368,7 @@ dependencies = [ "toml_edit", "tonic 0.13.1", "tonic-reflection", + "tower 0.5.2", "tracing", "tracing-utils", "twox-hash", diff --git a/libs/utils/src/lib.rs b/libs/utils/src/lib.rs index 206b8bbd8f..11f787562c 100644 --- a/libs/utils/src/lib.rs +++ b/libs/utils/src/lib.rs @@ -73,6 +73,7 @@ pub mod error; /// async timeout helper pub mod timeout; +pub mod span; pub mod sync; pub mod failpoint_support; diff --git a/libs/utils/src/span.rs b/libs/utils/src/span.rs new file mode 100644 index 0000000000..4dbc99044b --- /dev/null +++ b/libs/utils/src/span.rs @@ -0,0 +1,19 @@ +//! Tracing span helpers. + +/// Records the given fields in the current span, as a single call. The fields must already have +/// been declared for the span (typically with empty values). +#[macro_export] +macro_rules! span_record { + ($($tokens:tt)*) => {$crate::span_record_in!(::tracing::Span::current(), $($tokens)*)}; +} + +/// Records the given fields in the given span, as a single call. The fields must already have been +/// declared for the span (typically with empty values). +#[macro_export] +macro_rules! span_record_in { + ($span:expr, $($tokens:tt)*) => { + if let Some(meta) = $span.metadata() { + $span.record_all(&tracing::valueset!(meta.fields(), $($tokens)*)); + } + }; +} diff --git a/pageserver/Cargo.toml b/pageserver/Cargo.toml index c4d6d58945..9591c729e8 100644 --- a/pageserver/Cargo.toml +++ b/pageserver/Cargo.toml @@ -34,6 +34,7 @@ fail.workspace = true futures.workspace = true hashlink.workspace = true hex.workspace = true +http.workspace = true http-utils.workspace = true humantime-serde.workspace = true humantime.workspace = true @@ -93,6 +94,7 @@ tokio-util.workspace = true toml_edit = { workspace = true, features = [ "serde" ] } tonic.workspace = true tonic-reflection.workspace = true +tower.workspace = true tracing.workspace = true tracing-utils.workspace = true url.workspace = true diff --git a/pageserver/src/page_service.rs b/pageserver/src/page_service.rs index e96787e027..f011ed49d0 100644 --- a/pageserver/src/page_service.rs +++ b/pageserver/src/page_service.rs @@ -7,12 +7,14 @@ use std::os::fd::AsRawFd; use std::pin::Pin; use std::str::FromStr; use std::sync::Arc; +use std::task::{Context, Poll}; use std::time::{Duration, Instant, SystemTime}; use std::{io, str}; -use anyhow::{Context, bail}; +use anyhow::{Context as _, bail}; use async_compression::tokio::write::GzipEncoder; use bytes::Buf; +use futures::future::BoxFuture; use futures::{FutureExt, Stream}; use itertools::Itertools; use jsonwebtoken::TokenData; @@ -46,7 +48,6 @@ use tokio_util::sync::CancellationToken; use tonic::service::Interceptor as _; use tracing::*; use utils::auth::{Claims, Scope, SwappableJwtAuth}; -use utils::failpoint_support; use utils::id::{TenantId, TenantTimelineId, TimelineId}; use utils::logging::log_slow; use utils::lsn::Lsn; @@ -54,6 +55,7 @@ use utils::shard::ShardIndex; use utils::simple_rcu::RcuReadGuard; use utils::sync::gate::{Gate, GateGuard}; use utils::sync::spsc_fold; +use utils::{failpoint_support, span_record}; use crate::auth::check_permission; use crate::basebackup::{self, BasebackupError}; @@ -195,13 +197,17 @@ pub fn spawn_grpc( // Set up the gRPC server. // // TODO: consider tuning window sizes. - // TODO: wire up tracing. let mut server = tonic::transport::Server::builder() .http2_keepalive_interval(Some(GRPC_HTTP2_KEEPALIVE_INTERVAL)) .http2_keepalive_timeout(Some(GRPC_HTTP2_KEEPALIVE_TIMEOUT)) .max_concurrent_streams(Some(GRPC_MAX_CONCURRENT_STREAMS)); - // Main page service. + // Main page service stack. Uses a mix of Tonic interceptors and Tower layers: + // + // * Interceptors: can inspect and modify the gRPC request. Sync code only, runs before service. + // + // * Layers: allow async code, can run code after the service response. However, only has access + // to the raw HTTP request/response, not the gRPC types. let page_service_handler = PageServerHandler::new( tenant_manager, auth.clone(), @@ -214,16 +220,22 @@ pub fn spawn_grpc( gate.enter().expect("just created"), ); + let observability_layer = ObservabilityLayer; let mut tenant_interceptor = TenantMetadataInterceptor; let mut auth_interceptor = TenantAuthInterceptor::new(auth); - let interceptors = move |mut req: tonic::Request<()>| { - req = tenant_interceptor.call(req)?; - req = auth_interceptor.call(req)?; - Ok(req) - }; - let page_service = - proto::PageServiceServer::with_interceptor(page_service_handler, interceptors); + let page_service = tower::ServiceBuilder::new() + // Create tracing span. + .layer(observability_layer) + // Intercept gRPC requests. + .layer(tonic::service::InterceptorLayer::new(move |mut req| { + // Extract tenant metadata. + req = tenant_interceptor.call(req)?; + // Authenticate tenant JWT token. + req = auth_interceptor.call(req)?; + Ok(req) + })) + .service(proto::PageServiceServer::new(page_service_handler)); let server = server.add_service(page_service); // Reflection service for use with e.g. grpcurl. @@ -3311,6 +3323,7 @@ impl proto::PageService for PageServerHandler { type GetPagesStream = Pin> + Send>>; + #[instrument(skip_all)] async fn check_rel_exists( &self, _: tonic::Request, @@ -3318,6 +3331,7 @@ impl proto::PageService for PageServerHandler { Err(tonic::Status::unimplemented("not implemented")) } + #[instrument(skip_all)] async fn get_base_backup( &self, _: tonic::Request, @@ -3325,6 +3339,7 @@ impl proto::PageService for PageServerHandler { Err(tonic::Status::unimplemented("not implemented")) } + #[instrument(skip_all)] async fn get_db_size( &self, _: tonic::Request, @@ -3332,6 +3347,7 @@ impl proto::PageService for PageServerHandler { Err(tonic::Status::unimplemented("not implemented")) } + // NB: don't instrument this, instrument each streamed request. async fn get_pages( &self, _: tonic::Request>, @@ -3339,6 +3355,7 @@ impl proto::PageService for PageServerHandler { Err(tonic::Status::unimplemented("not implemented")) } + #[instrument(skip_all)] async fn get_rel_size( &self, _: tonic::Request, @@ -3346,6 +3363,7 @@ impl proto::PageService for PageServerHandler { Err(tonic::Status::unimplemented("not implemented")) } + #[instrument(skip_all)] async fn get_slru_segment( &self, _: tonic::Request, @@ -3354,19 +3372,65 @@ impl proto::PageService for PageServerHandler { } } -impl From for QueryError { - fn from(e: GetActiveTenantError) -> Self { - match e { - GetActiveTenantError::WaitForActiveTimeout { .. } => QueryError::Disconnected( - ConnectionError::Io(io::Error::new(io::ErrorKind::TimedOut, e.to_string())), - ), - GetActiveTenantError::Cancelled - | GetActiveTenantError::WillNotBecomeActive(TenantState::Stopping { .. }) => { - QueryError::Shutdown - } - e @ GetActiveTenantError::NotFound(_) => QueryError::NotFound(format!("{e}").into()), - e => QueryError::Other(anyhow::anyhow!(e)), - } +/// gRPC middleware layer that handles observability concerns: +/// +/// * Creates and enters a tracing span. +/// +/// TODO: add perf tracing. +/// TODO: add timing and metrics. +/// TODO: add logging. +#[derive(Clone)] +struct ObservabilityLayer; + +impl tower::Layer for ObservabilityLayer { + type Service = ObservabilityLayerService; + + fn layer(&self, inner: S) -> Self::Service { + Self::Service { inner } + } +} + +#[derive(Clone)] +struct ObservabilityLayerService { + inner: S, +} + +impl tonic::server::NamedService for ObservabilityLayerService { + const NAME: &'static str = S::NAME; // propagate inner service name +} + +impl tower::Service> for ObservabilityLayerService +where + S: tower::Service>, + S::Future: Send + 'static, +{ + type Response = S::Response; + type Error = S::Error; + type Future = BoxFuture<'static, Result>; + + fn call(&mut self, req: http::Request) -> Self::Future { + // Create a basic tracing span. Enter the span for the current thread (to use it for inner + // sync code like interceptors), and instrument the future (to use it for inner async code + // like the page service itself). + // + // The instrument() call below is not sufficient. It only affects the returned future, and + // only takes effect when the caller polls it. Any sync code executed when we call + // self.inner.call() below (such as interceptors) runs outside of the returned future, and + // is not affected by it. We therefore have to enter the span on the current thread too. + let span = info_span!( + "grpc:pageservice", + // Set by TenantMetadataInterceptor. + tenant_id = field::Empty, + timeline_id = field::Empty, + shard_id = field::Empty, + ); + let _guard = span.enter(); + + Box::pin(self.inner.call(req).instrument(span.clone())) + } + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) } } @@ -3400,19 +3464,22 @@ impl tonic::service::Interceptor for TenantMetadataInterceptor { .map_err(|_| tonic::Status::invalid_argument("invalid neon-timeline-id"))?; // Decode the shard ID. - let shard_index = req + let shard_id = req .metadata() .get("neon-shard-id") .ok_or_else(|| tonic::Status::invalid_argument("missing neon-shard-id"))? .to_str() .map_err(|_| tonic::Status::invalid_argument("invalid neon-shard-id"))?; - let shard_index = ShardIndex::from_str(shard_index) + let shard_id = ShardIndex::from_str(shard_id) .map_err(|_| tonic::Status::invalid_argument("invalid neon-shard-id"))?; // Stash them in the request. let extensions = req.extensions_mut(); extensions.insert(TenantTimelineId::new(tenant_id, timeline_id)); - extensions.insert(shard_index); + extensions.insert(shard_id); + + // Decorate the tracing span. + span_record!(%tenant_id, %timeline_id, %shard_id); Ok(req) } @@ -3486,6 +3553,22 @@ impl From for QueryError { } } +impl From for QueryError { + fn from(e: GetActiveTenantError) -> Self { + match e { + GetActiveTenantError::WaitForActiveTimeout { .. } => QueryError::Disconnected( + ConnectionError::Io(io::Error::new(io::ErrorKind::TimedOut, e.to_string())), + ), + GetActiveTenantError::Cancelled + | GetActiveTenantError::WillNotBecomeActive(TenantState::Stopping { .. }) => { + QueryError::Shutdown + } + e @ GetActiveTenantError::NotFound(_) => QueryError::NotFound(format!("{e}").into()), + e => QueryError::Other(anyhow::anyhow!(e)), + } + } +} + impl From for QueryError { fn from(e: crate::tenant::timeline::handle::HandleUpgradeError) -> Self { match e { From a21c1174edefdfb59fbdce9ae5696c446a3cfe0a Mon Sep 17 00:00:00 2001 From: Erik Grinaker Date: Mon, 2 Jun 2025 16:50:49 +0200 Subject: [PATCH 43/48] pagebench: add gRPC support for `get-page-latest-lsn` (#12077) ## Problem We need gRPC support in Pagebench to benchmark the new gRPC Pageserver implementation. Touches #11728. ## Summary of changes Adds a `Client` trait to make the client transport swappable, and a gRPC client via a `--protocol grpc` parameter. This must also specify the connstring with the gRPC port: ``` pagebench get-page-latest-lsn --protocol grpc --page-service-connstring grpc://localhost:51051 ``` The client is implemented using the raw Tonic-generated gRPC client, to minimize client overhead. --- Cargo.lock | 4 + libs/pageserver_api/src/models.rs | 4 +- libs/pageserver_api/src/reltag.rs | 2 +- pageserver/pagebench/Cargo.toml | 6 +- .../pagebench/src/cmd/getpage_latest_lsn.rs | 146 ++++++++++++++++-- 5 files changed, 144 insertions(+), 18 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 4f7378e95d..9fc233e5ec 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4236,6 +4236,7 @@ name = "pagebench" version = "0.1.0" dependencies = [ "anyhow", + "async-trait", "camino", "clap", "futures", @@ -4244,12 +4245,15 @@ dependencies = [ "humantime-serde", "pageserver_api", "pageserver_client", + "pageserver_page_api", "rand 0.8.5", "reqwest", "serde", "serde_json", "tokio", + "tokio-stream", "tokio-util", + "tonic 0.13.1", "tracing", "utils", "workspace_hack", diff --git a/libs/pageserver_api/src/models.rs b/libs/pageserver_api/src/models.rs index e7d612bb7a..01487c0f57 100644 --- a/libs/pageserver_api/src/models.rs +++ b/libs/pageserver_api/src/models.rs @@ -2045,7 +2045,7 @@ pub enum PagestreamProtocolVersion { pub type RequestId = u64; -#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[derive(Debug, Default, PartialEq, Eq, Clone, Copy)] pub struct PagestreamRequest { pub reqid: RequestId, pub request_lsn: Lsn, @@ -2064,7 +2064,7 @@ pub struct PagestreamNblocksRequest { pub rel: RelTag, } -#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[derive(Debug, Default, PartialEq, Eq, Clone, Copy)] pub struct PagestreamGetPageRequest { pub hdr: PagestreamRequest, pub rel: RelTag, diff --git a/libs/pageserver_api/src/reltag.rs b/libs/pageserver_api/src/reltag.rs index 473a44dbf9..4509cab2e0 100644 --- a/libs/pageserver_api/src/reltag.rs +++ b/libs/pageserver_api/src/reltag.rs @@ -24,7 +24,7 @@ use serde::{Deserialize, Serialize}; // FIXME: should move 'forknum' as last field to keep this consistent with Postgres. // Then we could replace the custom Ord and PartialOrd implementations below with // deriving them. This will require changes in walredoproc.c. -#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy, Serialize, Deserialize)] +#[derive(Debug, Default, PartialEq, Eq, Hash, Clone, Copy, Serialize, Deserialize)] pub struct RelTag { pub forknum: u8, pub spcnode: Oid, diff --git a/pageserver/pagebench/Cargo.toml b/pageserver/pagebench/Cargo.toml index 5b5ed09a2b..ceb1278eab 100644 --- a/pageserver/pagebench/Cargo.toml +++ b/pageserver/pagebench/Cargo.toml @@ -8,6 +8,7 @@ license.workspace = true [dependencies] anyhow.workspace = true +async-trait.workspace = true camino.workspace = true clap.workspace = true futures.workspace = true @@ -15,14 +16,17 @@ hdrhistogram.workspace = true humantime.workspace = true humantime-serde.workspace = true rand.workspace = true -reqwest.workspace=true +reqwest.workspace = true serde.workspace = true serde_json.workspace = true tracing.workspace = true tokio.workspace = true +tokio-stream.workspace = true tokio-util.workspace = true +tonic.workspace = true pageserver_client.workspace = true pageserver_api.workspace = true +pageserver_page_api.workspace = true utils = { path = "../../libs/utils/" } workspace_hack = { version = "0.1", path = "../../workspace_hack" } diff --git a/pageserver/pagebench/src/cmd/getpage_latest_lsn.rs b/pageserver/pagebench/src/cmd/getpage_latest_lsn.rs index 50419ec338..395e9cac41 100644 --- a/pageserver/pagebench/src/cmd/getpage_latest_lsn.rs +++ b/pageserver/pagebench/src/cmd/getpage_latest_lsn.rs @@ -7,11 +7,15 @@ use std::sync::{Arc, Mutex}; use std::time::{Duration, Instant}; use anyhow::Context; +use async_trait::async_trait; use camino::Utf8PathBuf; use pageserver_api::key::Key; use pageserver_api::keyspace::KeySpaceAccum; -use pageserver_api::models::{PagestreamGetPageRequest, PagestreamRequest}; +use pageserver_api::models::{ + PagestreamGetPageRequest, PagestreamGetPageResponse, PagestreamRequest, +}; use pageserver_api::shard::TenantShardId; +use pageserver_page_api::proto; use rand::prelude::*; use tokio::task::JoinSet; use tokio_util::sync::CancellationToken; @@ -22,6 +26,12 @@ use utils::lsn::Lsn; use crate::util::tokio_thread_local_stats::AllThreadLocalStats; use crate::util::{request_stats, tokio_thread_local_stats}; +#[derive(clap::ValueEnum, Clone, Debug)] +enum Protocol { + Libpq, + Grpc, +} + /// GetPage@LatestLSN, uniformly distributed across the compute-accessible keyspace. #[derive(clap::Parser)] pub(crate) struct Args { @@ -35,6 +45,8 @@ pub(crate) struct Args { num_clients: NonZeroUsize, #[clap(long)] runtime: Option, + #[clap(long, value_enum, default_value = "libpq")] + protocol: Protocol, /// Each client sends requests at the given rate. /// /// If a request takes too long and we should be issuing a new request already, @@ -303,7 +315,20 @@ async fn main_impl( .unwrap(); Box::pin(async move { - client_libpq(args, worker_id, ss, cancel, rps_period, ranges, weights).await + let client: Box = match args.protocol { + Protocol::Libpq => Box::new( + LibpqClient::new(args.page_service_connstring.clone(), worker_id.timeline) + .await + .unwrap(), + ), + + Protocol::Grpc => Box::new( + GrpcClient::new(args.page_service_connstring.clone(), worker_id.timeline) + .await + .unwrap(), + ), + }; + run_worker(args, client, ss, cancel, rps_period, ranges, weights).await }) }; @@ -355,23 +380,15 @@ async fn main_impl( anyhow::Ok(()) } -async fn client_libpq( +async fn run_worker( args: &Args, - worker_id: WorkerId, + mut client: Box, shared_state: Arc, cancel: CancellationToken, rps_period: Option, ranges: Vec, weights: rand::distributions::weighted::WeightedIndex, ) { - let client = pageserver_client::page_service::Client::new(args.page_service_connstring.clone()) - .await - .unwrap(); - let mut client = client - .pagestream(worker_id.timeline.tenant_id, worker_id.timeline.timeline_id) - .await - .unwrap(); - shared_state.start_work_barrier.wait().await; let client_start = Instant::now(); let mut ticks_processed = 0; @@ -415,12 +432,12 @@ async fn client_libpq( blkno: block_no, } }; - client.getpage_send(req).await.unwrap(); + client.send_get_page(req).await.unwrap(); inflight.push_back(start); } let start = inflight.pop_front().unwrap(); - client.getpage_recv().await.unwrap(); + client.recv_get_page().await.unwrap(); let end = Instant::now(); shared_state.live_stats.request_done(); ticks_processed += 1; @@ -442,3 +459,104 @@ async fn client_libpq( } } } + +/// A benchmark client, to allow switching out the transport protocol. +/// +/// For simplicity, this just uses separate asynchronous send/recv methods. The send method could +/// return a future that resolves when the response is received, but we don't really need it. +#[async_trait] +trait Client: Send { + /// Sends an asynchronous GetPage request to the pageserver. + async fn send_get_page(&mut self, req: PagestreamGetPageRequest) -> anyhow::Result<()>; + + /// Receives the next GetPage response from the pageserver. + async fn recv_get_page(&mut self) -> anyhow::Result; +} + +/// A libpq-based Pageserver client. +struct LibpqClient { + inner: pageserver_client::page_service::PagestreamClient, +} + +impl LibpqClient { + async fn new(connstring: String, ttid: TenantTimelineId) -> anyhow::Result { + let inner = pageserver_client::page_service::Client::new(connstring) + .await? + .pagestream(ttid.tenant_id, ttid.timeline_id) + .await?; + Ok(Self { inner }) + } +} + +#[async_trait] +impl Client for LibpqClient { + async fn send_get_page(&mut self, req: PagestreamGetPageRequest) -> anyhow::Result<()> { + self.inner.getpage_send(req).await + } + + async fn recv_get_page(&mut self) -> anyhow::Result { + self.inner.getpage_recv().await + } +} + +/// A gRPC client using the raw, no-frills gRPC client. +struct GrpcClient { + req_tx: tokio::sync::mpsc::Sender, + resp_rx: tonic::Streaming, +} + +impl GrpcClient { + async fn new(connstring: String, ttid: TenantTimelineId) -> anyhow::Result { + let mut client = pageserver_page_api::proto::PageServiceClient::connect(connstring).await?; + + // The channel has a buffer size of 1, since 0 is not allowed. It does not matter, since the + // benchmark will control the queue depth (i.e. in-flight requests) anyway, and requests are + // buffered by Tonic and the OS too. + let (req_tx, req_rx) = tokio::sync::mpsc::channel(1); + let req_stream = tokio_stream::wrappers::ReceiverStream::new(req_rx); + let mut req = tonic::Request::new(req_stream); + let metadata = req.metadata_mut(); + metadata.insert("neon-tenant-id", ttid.tenant_id.to_string().try_into()?); + metadata.insert("neon-timeline-id", ttid.timeline_id.to_string().try_into()?); + metadata.insert("neon-shard-id", "0000".try_into()?); + + let resp = client.get_pages(req).await?; + let resp_stream = resp.into_inner(); + + Ok(Self { + req_tx, + resp_rx: resp_stream, + }) + } +} + +#[async_trait] +impl Client for GrpcClient { + async fn send_get_page(&mut self, req: PagestreamGetPageRequest) -> anyhow::Result<()> { + let req = proto::GetPageRequest { + request_id: 0, + request_class: proto::GetPageClass::Normal as i32, + read_lsn: Some(proto::ReadLsn { + request_lsn: req.hdr.request_lsn.0, + not_modified_since_lsn: req.hdr.not_modified_since.0, + }), + rel: Some(req.rel.into()), + block_number: vec![req.blkno], + }; + self.req_tx.send(req).await?; + Ok(()) + } + + async fn recv_get_page(&mut self) -> anyhow::Result { + let resp = self.resp_rx.message().await?.unwrap(); + anyhow::ensure!( + resp.status_code == proto::GetPageStatusCode::Ok as i32, + "unexpected status code: {}", + resp.status_code + ); + Ok(PagestreamGetPageResponse { + page: resp.page_image[0].clone(), + req: PagestreamGetPageRequest::default(), // dummy + }) + } +} From 781bf4945d9cb3902de829a187ee0e7ebc71e432 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Mon, 2 Jun 2025 17:13:30 +0100 Subject: [PATCH 44/48] proxy: optimise future layout allocations (#12104) A smaller version of #12066 that is somewhat easier to review. Now that I've been using https://crates.io/crates/top-type-sizes I've found a lot more of the low hanging fruit that can be tweaks to reduce the memory usage. Some context for the optimisations: Rust's stack allocation in futures is quite naive. Stack variables, even if moved, often still end up taking space in the future. Rearranging the order in which variables are defined, and properly scoping them can go a long way. `async fn` and `async move {}` have a consequence that they always duplicate the "upvars" (aka captures). All captures are permanently allocated in the future, even if moved. We can be mindful when writing futures to only capture as little as possible. TlsStream is massive. Needs boxing so it doesn't contribute to the above issue. ## Measurements from `top-type-sizes`: ### Before ``` 10328 {async block@proxy::proxy::task_main::{closure#0}::{closure#0}} align=8 6120 {async fn body of proxy::proxy::handle_client>()} align=8 ``` ### After ``` 4040 {async block@proxy::proxy::task_main::{closure#0}::{closure#0}} 4704 {async fn body of proxy::proxy::handle_client>()} align=8 ``` --- proxy/src/auth/backend/classic.rs | 16 ++-- proxy/src/console_redirect_proxy.rs | 2 +- .../control_plane/client/cplane_proxy_v1.rs | 75 +++++++++++-------- proxy/src/http/mod.rs | 27 +++++-- proxy/src/pqproto.rs | 6 +- proxy/src/proxy/handshake.rs | 9 ++- proxy/src/proxy/mod.rs | 2 +- proxy/src/proxy/passthrough.rs | 2 + proxy/src/sasl/stream.rs | 49 ++++++------ proxy/src/stream.rs | 4 +- proxy/src/tls/postgres_rustls.rs | 6 +- 11 files changed, 115 insertions(+), 83 deletions(-) diff --git a/proxy/src/auth/backend/classic.rs b/proxy/src/auth/backend/classic.rs index dcc500f2c8..8445368740 100644 --- a/proxy/src/auth/backend/classic.rs +++ b/proxy/src/auth/backend/classic.rs @@ -25,19 +25,15 @@ pub(super) async fn authenticate( } AuthSecret::Scram(secret) => { debug!("auth endpoint chooses SCRAM"); - let scram = auth::Scram(&secret, ctx); - let auth_outcome = tokio::time::timeout(config.scram_protocol_timeout, async { - AuthFlow::new(client, scram) - .authenticate() - .await - .inspect_err(|error| { - warn!(?error, "error processing scram messages"); - }) - }) + let auth_outcome = tokio::time::timeout( + config.scram_protocol_timeout, + AuthFlow::new(client, auth::Scram(&secret, ctx)).authenticate(), + ) .await .inspect_err(|_| warn!("error processing scram messages error = authentication timed out, execution time exceeded {} seconds", config.scram_protocol_timeout.as_secs())) - .map_err(auth::AuthError::user_timeout)??; + .map_err(auth::AuthError::user_timeout)? + .inspect_err(|error| warn!(?error, "error processing scram messages"))?; let client_key = match auth_outcome { sasl::Outcome::Success(key) => key, diff --git a/proxy/src/console_redirect_proxy.rs b/proxy/src/console_redirect_proxy.rs index 9499aba61b..7fb84b5ee5 100644 --- a/proxy/src/console_redirect_proxy.rs +++ b/proxy/src/console_redirect_proxy.rs @@ -159,7 +159,7 @@ pub async fn task_main( } #[allow(clippy::too_many_arguments)] -pub(crate) async fn handle_client( +pub(crate) async fn handle_client( config: &'static ProxyConfig, backend: &'static ConsoleRedirectBackend, ctx: &RequestContext, diff --git a/proxy/src/control_plane/client/cplane_proxy_v1.rs b/proxy/src/control_plane/client/cplane_proxy_v1.rs index 93f4ea6cf7..da548d6b2c 100644 --- a/proxy/src/control_plane/client/cplane_proxy_v1.rs +++ b/proxy/src/control_plane/client/cplane_proxy_v1.rs @@ -7,7 +7,9 @@ use std::time::Duration; use ::http::HeaderName; use ::http::header::AUTHORIZATION; +use bytes::Bytes; use futures::TryFutureExt; +use hyper::StatusCode; use postgres_client::config::SslMode; use tokio::time::Instant; use tracing::{Instrument, debug, info, info_span, warn}; @@ -72,28 +74,34 @@ impl NeonControlPlaneClient { role: &RoleName, ) -> Result { async { - let request = self - .endpoint - .get_path("get_endpoint_access_control") - .header(X_REQUEST_ID, ctx.session_id().to_string()) - .header(AUTHORIZATION, format!("Bearer {}", &self.jwt)) - .query(&[("session_id", ctx.session_id())]) - .query(&[ - ("application_name", ctx.console_application_name().as_str()), - ("endpointish", endpoint.as_str()), - ("role", role.as_str()), - ]) - .build()?; - - debug!(url = request.url().as_str(), "sending http request"); - let start = Instant::now(); let response = { - let _pause = ctx.latency_timer_pause_at(start, crate::metrics::Waiting::Cplane); - self.endpoint.execute(request).await? - }; - info!(duration = ?start.elapsed(), "received http response"); + let request = self + .endpoint + .get_path("get_endpoint_access_control") + .header(X_REQUEST_ID, ctx.session_id().to_string()) + .header(AUTHORIZATION, format!("Bearer {}", &self.jwt)) + .query(&[("session_id", ctx.session_id())]) + .query(&[ + ("application_name", ctx.console_application_name().as_str()), + ("endpointish", endpoint.as_str()), + ("role", role.as_str()), + ]) + .build()?; - let body = match parse_body::(response).await { + debug!(url = request.url().as_str(), "sending http request"); + let start = Instant::now(); + let _pause = ctx.latency_timer_pause_at(start, crate::metrics::Waiting::Cplane); + let response = self.endpoint.execute(request).await?; + + info!(duration = ?start.elapsed(), "received http response"); + + response + }; + + let body = match parse_body::( + response.status(), + response.bytes().await?, + ) { Ok(body) => body, // Error 404 is special: it's ok not to have a secret. // TODO(anna): retry @@ -184,7 +192,10 @@ impl NeonControlPlaneClient { drop(pause); info!(duration = ?start.elapsed(), "received http response"); - let body = parse_body::(response).await?; + let body = parse_body::( + response.status(), + response.bytes().await.map_err(ControlPlaneError::from)?, + )?; let rules = body .jwks @@ -236,7 +247,7 @@ impl NeonControlPlaneClient { let response = self.endpoint.execute(request).await?; drop(pause); info!(duration = ?start.elapsed(), "received http response"); - let body = parse_body::(response).await?; + let body = parse_body::(response.status(), response.bytes().await?)?; // Unfortunately, ownership won't let us use `Option::ok_or` here. let (host, port) = match parse_host_port(&body.address) { @@ -487,33 +498,33 @@ impl super::ControlPlaneApi for NeonControlPlaneClient { } /// Parse http response body, taking status code into account. -async fn parse_body serde::Deserialize<'a>>( - response: http::Response, +fn parse_body serde::Deserialize<'a>>( + status: StatusCode, + body: Bytes, ) -> Result { - let status = response.status(); if status.is_success() { // We shouldn't log raw body because it may contain secrets. info!("request succeeded, processing the body"); - return Ok(response.json().await?); + return Ok(serde_json::from_slice(&body).map_err(std::io::Error::other)?); } - let s = response.bytes().await?; + // Log plaintext to be able to detect, whether there are some cases not covered by the error struct. - info!("response_error plaintext: {:?}", s); + info!("response_error plaintext: {:?}", body); // Don't throw an error here because it's not as important // as the fact that the request itself has failed. - let mut body = serde_json::from_slice(&s).unwrap_or_else(|e| { + let mut body = serde_json::from_slice(&body).unwrap_or_else(|e| { warn!("failed to parse error body: {e}"); - ControlPlaneErrorMessage { + Box::new(ControlPlaneErrorMessage { error: "reason unclear (malformed error message)".into(), http_status_code: status, status: None, - } + }) }); body.http_status_code = status; warn!("console responded with an error ({status}): {body:?}"); - Err(ControlPlaneError::Message(Box::new(body))) + Err(ControlPlaneError::Message(body)) } fn parse_host_port(input: &str) -> Option<(&str, u16)> { diff --git a/proxy/src/http/mod.rs b/proxy/src/http/mod.rs index 96f600d836..36607e7861 100644 --- a/proxy/src/http/mod.rs +++ b/proxy/src/http/mod.rs @@ -4,9 +4,10 @@ pub mod health_server; -use std::time::Duration; +use std::time::{Duration, Instant}; use bytes::Bytes; +use futures::FutureExt; use http::Method; use http_body_util::BodyExt; use hyper::body::Body; @@ -109,15 +110,31 @@ impl Endpoint { } /// Execute a [request](reqwest::Request). - pub(crate) async fn execute(&self, request: Request) -> Result { - let _timer = Metrics::get() + pub(crate) fn execute( + &self, + request: Request, + ) -> impl Future> { + let metric = Metrics::get() .proxy .console_request_latency - .start_timer(ConsoleRequest { + .with_labels(ConsoleRequest { request: request.url().path(), }); - self.client.execute(request).await + let req = self.client.execute(request).boxed(); + + async move { + let start = Instant::now(); + scopeguard::defer!({ + Metrics::get() + .proxy + .console_request_latency + .get_metric(metric) + .observe_duration_since(start); + }); + + req.await + } } } diff --git a/proxy/src/pqproto.rs b/proxy/src/pqproto.rs index d68d9f9474..43074bf208 100644 --- a/proxy/src/pqproto.rs +++ b/proxy/src/pqproto.rs @@ -186,7 +186,7 @@ where pub async fn read_message<'a, S>( stream: &mut S, buf: &'a mut Vec, - max: usize, + max: u32, ) -> io::Result<(u8, &'a mut [u8])> where S: AsyncRead + Unpin, @@ -206,7 +206,7 @@ where let header = read!(stream => Header); // as described above, the length must be at least 4. - let Some(len) = (header.len.get() as usize).checked_sub(4) else { + let Some(len) = header.len.get().checked_sub(4) else { return Err(io::Error::other(format!( "invalid startup message length {}, must be at least 4.", header.len, @@ -222,7 +222,7 @@ where } // read in our entire message. - buf.resize(len, 0); + buf.resize(len as usize, 0); stream.read_exact(buf).await?; Ok((header.tag, buf)) diff --git a/proxy/src/proxy/handshake.rs b/proxy/src/proxy/handshake.rs index 13ee8c7dd2..6970ab8714 100644 --- a/proxy/src/proxy/handshake.rs +++ b/proxy/src/proxy/handshake.rs @@ -1,3 +1,4 @@ +use futures::{FutureExt, TryFutureExt}; use thiserror::Error; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::{debug, info, warn}; @@ -57,7 +58,7 @@ pub(crate) enum HandshakeData { /// It's easier to work with owned `stream` here as we need to upgrade it to TLS; /// we also take an extra care of propagating only the select handshake errors to client. #[tracing::instrument(skip_all)] -pub(crate) async fn handshake( +pub(crate) async fn handshake( ctx: &RequestContext, stream: S, mut tls: Option<&TlsConfig>, @@ -108,7 +109,9 @@ pub(crate) async fn handshake( } } } - }); + }) + .map_ok(Box::new) + .boxed(); res?; @@ -146,7 +149,7 @@ pub(crate) async fn handshake( tls.cert_resolver.resolve(conn_info.server_name()); let tls = Stream::Tls { - tls: Box::new(tls_stream), + tls: tls_stream, tls_server_end_point, }; (stream, msg) = PqStream::parse_startup(tls).await?; diff --git a/proxy/src/proxy/mod.rs b/proxy/src/proxy/mod.rs index ac0aca1176..0ffc54aa88 100644 --- a/proxy/src/proxy/mod.rs +++ b/proxy/src/proxy/mod.rs @@ -270,7 +270,7 @@ impl ReportableError for ClientRequestError { } #[allow(clippy::too_many_arguments)] -pub(crate) async fn handle_client( +pub(crate) async fn handle_client( config: &'static ProxyConfig, auth_backend: &'static auth::Backend<'static, ()>, ctx: &RequestContext, diff --git a/proxy/src/proxy/passthrough.rs b/proxy/src/proxy/passthrough.rs index 8f9bd2de2d..55ab5f4dba 100644 --- a/proxy/src/proxy/passthrough.rs +++ b/proxy/src/proxy/passthrough.rs @@ -1,3 +1,4 @@ +use futures::FutureExt; use smol_str::SmolStr; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::debug; @@ -89,6 +90,7 @@ impl ProxyPassthrough { .compute .cancel_closure .try_cancel_query(compute_config) + .boxed() .await { tracing::warn!(session_id = ?self.session_id, ?err, "could not cancel the query in the database"); diff --git a/proxy/src/sasl/stream.rs b/proxy/src/sasl/stream.rs index cb15132673..52ccca58d5 100644 --- a/proxy/src/sasl/stream.rs +++ b/proxy/src/sasl/stream.rs @@ -30,52 +30,53 @@ where F: FnOnce(&str) -> super::Result, M: Mechanism, { - let sasl = { + let (mut mechanism, mut input) = { // pause the timer while we communicate with the client let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client); // Initial client message contains the chosen auth method's name. let msg = stream.read_password_message().await?; - super::FirstMessage::parse(msg).ok_or(super::Error::BadClientMessage("bad sasl message"))? + + let sasl = super::FirstMessage::parse(msg) + .ok_or(super::Error::BadClientMessage("bad sasl message"))?; + + (mechanism(sasl.method)?, sasl.message) }; - let mut mechanism = mechanism(sasl.method)?; - let mut input = sasl.message; loop { - let step = mechanism - .exchange(input) - .inspect_err(|error| tracing::info!(?error, "error during SASL exchange"))?; - - match step { - Step::Continue(moved_mechanism, reply) => { + match mechanism.exchange(input) { + Ok(Step::Continue(moved_mechanism, reply)) => { mechanism = moved_mechanism; - // pause the timer while we communicate with the client - let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client); - // write reply let sasl_msg = BeAuthenticationSaslMessage::Continue(reply.as_bytes()); stream.write_message(BeMessage::AuthenticationSasl(sasl_msg)); - - // get next input - stream.flush().await?; - let msg = stream.read_password_message().await?; - input = std::str::from_utf8(msg) - .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "bad encoding"))?; + drop(reply); } - Step::Success(result, reply) => { - // pause the timer while we communicate with the client - let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client); - + Ok(Step::Success(result, reply)) => { // write reply let sasl_msg = BeAuthenticationSaslMessage::Final(reply.as_bytes()); stream.write_message(BeMessage::AuthenticationSasl(sasl_msg)); stream.write_message(BeMessage::AuthenticationOk); + // exit with success break Ok(Outcome::Success(result)); } // exit with failure - Step::Failure(reason) => break Ok(Outcome::Failure(reason)), + Ok(Step::Failure(reason)) => break Ok(Outcome::Failure(reason)), + Err(error) => { + tracing::info!(?error, "error during SASL exchange"); + return Err(error); + } } + + // pause the timer while we communicate with the client + let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client); + + // get next input + stream.flush().await?; + let msg = stream.read_password_message().await?; + input = std::str::from_utf8(msg) + .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "bad encoding"))?; } } diff --git a/proxy/src/stream.rs b/proxy/src/stream.rs index 7126430a85..c49a431c95 100644 --- a/proxy/src/stream.rs +++ b/proxy/src/stream.rs @@ -72,7 +72,7 @@ impl PqStream { impl PqStream { /// Read a raw postgres packet, which will respect the max length requested. /// This is not cancel safe. - async fn read_raw_expect(&mut self, tag: u8, max: usize) -> io::Result<&mut [u8]> { + async fn read_raw_expect(&mut self, tag: u8, max: u32) -> io::Result<&mut [u8]> { let (actual_tag, msg) = read_message(&mut self.stream, &mut self.read, max).await?; if actual_tag != tag { return Err(io::Error::other(format!( @@ -89,7 +89,7 @@ impl PqStream { // passwords are usually pretty short // and SASL SCRAM messages are no longer than 256 bytes in my testing // (a few hashes and random bytes, encoded into base64). - const MAX_PASSWORD_LENGTH: usize = 512; + const MAX_PASSWORD_LENGTH: u32 = 512; self.read_raw_expect(FE_PASSWORD_MESSAGE, MAX_PASSWORD_LENGTH) .await } diff --git a/proxy/src/tls/postgres_rustls.rs b/proxy/src/tls/postgres_rustls.rs index f09e916a1d..013b307f0b 100644 --- a/proxy/src/tls/postgres_rustls.rs +++ b/proxy/src/tls/postgres_rustls.rs @@ -31,7 +31,9 @@ mod private { type Output = io::Result>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - Pin::new(&mut self.inner).poll(cx).map_ok(RustlsStream) + Pin::new(&mut self.inner) + .poll(cx) + .map_ok(|s| RustlsStream(Box::new(s))) } } @@ -57,7 +59,7 @@ mod private { } } - pub struct RustlsStream(TlsStream); + pub struct RustlsStream(Box>); impl postgres_client::tls::TlsStream for RustlsStream where From fc3994eb71826de6fbec023b74558aa72a7c888b Mon Sep 17 00:00:00 2001 From: Erik Grinaker Date: Mon, 2 Jun 2025 19:15:18 +0200 Subject: [PATCH 45/48] pageserver: initial gRPC page service implementation (#12094) ## Problem We should expose the page service over gRPC. Requires #12093. Touches #11728. ## Summary of changes This patch adds an initial page service implementation over gRPC. It ties in with the existing `PageServerHandler` request logic, to avoid the implementations drifting apart for the core read path. This is just a bare-bones functional implementation. Several important aspects have been omitted, and will be addressed in follow-up PRs: * Limited observability: minimal tracing, no logging, limited metrics and timing, etc. * Rate limiting will currently block. * No performance optimization. * No cancellation handling. * No tests. I've only done rudimentary testing of this, but Pagebench passes at least. --- libs/pageserver_api/src/models.rs | 2 +- libs/pageserver_api/src/reltag.rs | 10 +- pageserver/page_api/src/model.rs | 17 +- pageserver/src/basebackup.rs | 26 +- pageserver/src/bin/pageserver.rs | 4 +- pageserver/src/page_service.rs | 822 ++++++++++++++++++++++++------ pageserver/src/tenant/timeline.rs | 12 + 7 files changed, 723 insertions(+), 170 deletions(-) diff --git a/libs/pageserver_api/src/models.rs b/libs/pageserver_api/src/models.rs index 01487c0f57..28ced4a368 100644 --- a/libs/pageserver_api/src/models.rs +++ b/libs/pageserver_api/src/models.rs @@ -1934,7 +1934,7 @@ pub enum PagestreamFeMessage { } // Wrapped in libpq CopyData -#[derive(strum_macros::EnumProperty)] +#[derive(Debug, strum_macros::EnumProperty)] pub enum PagestreamBeMessage { Exists(PagestreamExistsResponse), Nblocks(PagestreamNblocksResponse), diff --git a/libs/pageserver_api/src/reltag.rs b/libs/pageserver_api/src/reltag.rs index 4509cab2e0..e0dd4fdfe8 100644 --- a/libs/pageserver_api/src/reltag.rs +++ b/libs/pageserver_api/src/reltag.rs @@ -184,12 +184,12 @@ pub enum SlruKind { MultiXactOffsets, } -impl SlruKind { - pub fn to_str(&self) -> &'static str { +impl fmt::Display for SlruKind { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - Self::Clog => "pg_xact", - Self::MultiXactMembers => "pg_multixact/members", - Self::MultiXactOffsets => "pg_multixact/offsets", + Self::Clog => write!(f, "pg_xact"), + Self::MultiXactMembers => write!(f, "pg_multixact/members"), + Self::MultiXactOffsets => write!(f, "pg_multixact/offsets"), } } } diff --git a/pageserver/page_api/src/model.rs b/pageserver/page_api/src/model.rs index 7ab97a994e..0268ab920b 100644 --- a/pageserver/page_api/src/model.rs +++ b/pageserver/page_api/src/model.rs @@ -10,6 +10,8 @@ //! //! - Validate protocol invariants, via try_from() and try_into(). +use std::fmt::Display; + use bytes::Bytes; use postgres_ffi::Oid; use smallvec::SmallVec; @@ -48,7 +50,8 @@ pub struct ReadLsn { pub request_lsn: Lsn, /// If given, the caller guarantees that the page has not been modified since this LSN. Must be /// smaller than or equal to request_lsn. This allows the Pageserver to serve an old page - /// without waiting for the request LSN to arrive. Valid for all request types. + /// without waiting for the request LSN to arrive. If not given, the request will read at the + /// request_lsn and wait for it to arrive if necessary. Valid for all request types. /// /// It is undefined behaviour to make a request such that the page was, in fact, modified /// between request_lsn and not_modified_since_lsn. The Pageserver might detect it and return an @@ -58,6 +61,17 @@ pub struct ReadLsn { pub not_modified_since_lsn: Option, } +impl Display for ReadLsn { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let req_lsn = self.request_lsn; + if let Some(mod_lsn) = self.not_modified_since_lsn { + write!(f, "{req_lsn}>={mod_lsn}") + } else { + req_lsn.fmt(f) + } + } +} + impl ReadLsn { /// Validates the ReadLsn. pub fn validate(&self) -> Result<(), ProtocolError> { @@ -584,6 +598,7 @@ impl TryFrom for proto::GetSlruSegmentResponse { type Error = ProtocolError; fn try_from(segment: GetSlruSegmentResponse) -> Result { + // TODO: can a segment legitimately be empty? if segment.is_empty() { return Err(ProtocolError::Missing("segment")); } diff --git a/pageserver/src/basebackup.rs b/pageserver/src/basebackup.rs index e89baa0bce..4dba9d267c 100644 --- a/pageserver/src/basebackup.rs +++ b/pageserver/src/basebackup.rs @@ -65,6 +65,30 @@ impl From for BasebackupError { } } +impl From for postgres_backend::QueryError { + fn from(err: BasebackupError) -> Self { + use postgres_backend::QueryError; + use pq_proto::framed::ConnectionError; + match err { + BasebackupError::Client(err, _) => QueryError::Disconnected(ConnectionError::Io(err)), + BasebackupError::Server(err) => QueryError::Other(err), + BasebackupError::Shutdown => QueryError::Shutdown, + } + } +} + +impl From for tonic::Status { + fn from(err: BasebackupError) -> Self { + use tonic::Code; + let code = match &err { + BasebackupError::Client(_, _) => Code::Cancelled, + BasebackupError::Server(_) => Code::Internal, + BasebackupError::Shutdown => Code::Unavailable, + }; + tonic::Status::new(code, err.to_string()) + } +} + /// Create basebackup with non-rel data in it. /// Only include relational data if 'full_backup' is true. /// @@ -248,7 +272,7 @@ where async fn flush(&mut self) -> Result<(), BasebackupError> { let nblocks = self.buf.len() / BLCKSZ as usize; let (kind, segno) = self.current_segment.take().unwrap(); - let segname = format!("{}/{:>04X}", kind.to_str(), segno); + let segname = format!("{kind}/{segno:>04X}"); let header = new_tar_header(&segname, self.buf.len() as u64)?; self.ar .append(&header, self.buf.as_slice()) diff --git a/pageserver/src/bin/pageserver.rs b/pageserver/src/bin/pageserver.rs index df3c045145..337aa135dc 100644 --- a/pageserver/src/bin/pageserver.rs +++ b/pageserver/src/bin/pageserver.rs @@ -804,7 +804,7 @@ fn start_pageserver( } else { None }, - basebackup_cache.clone(), + basebackup_cache, ); // Spawn a Pageserver gRPC server task. It will spawn separate tasks for @@ -816,12 +816,10 @@ fn start_pageserver( let mut page_service_grpc = None; if let Some(grpc_listener) = grpc_listener { page_service_grpc = Some(page_service::spawn_grpc( - conf, tenant_manager.clone(), grpc_auth, otel_guard.as_ref().map(|g| g.dispatch.clone()), grpc_listener, - basebackup_cache, )?); } diff --git a/pageserver/src/page_service.rs b/pageserver/src/page_service.rs index f011ed49d0..b9ba4a3555 100644 --- a/pageserver/src/page_service.rs +++ b/pageserver/src/page_service.rs @@ -1,6 +1,7 @@ //! The Page Service listens for client connections and serves their GetPage@LSN //! requests. +use std::any::Any; use std::borrow::Cow; use std::num::NonZeroUsize; use std::os::fd::AsRawFd; @@ -11,9 +12,9 @@ use std::task::{Context, Poll}; use std::time::{Duration, Instant, SystemTime}; use std::{io, str}; -use anyhow::{Context as _, bail}; +use anyhow::{Context as _, anyhow, bail}; use async_compression::tokio::write::GzipEncoder; -use bytes::Buf; +use bytes::{Buf, BytesMut}; use futures::future::BoxFuture; use futures::{FutureExt, Stream}; use itertools::Itertools; @@ -33,6 +34,7 @@ use pageserver_api::models::{ }; use pageserver_api::reltag::SlruKind; use pageserver_api::shard::TenantShardId; +use pageserver_page_api as page_api; use pageserver_page_api::proto; use postgres_backend::{ AuthType, PostgresBackend, PostgresBackendReader, QueryError, is_expected_io_error, @@ -41,8 +43,9 @@ use postgres_ffi::BLCKSZ; use postgres_ffi::pg_constants::DEFAULTTABLESPACE_OID; use pq_proto::framed::ConnectionError; use pq_proto::{BeMessage, FeMessage, FeStartupPacket, RowDescriptor}; +use smallvec::{SmallVec, smallvec}; use strum_macros::IntoStaticStr; -use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufWriter}; +use tokio::io::{AsyncRead, AsyncReadExt as _, AsyncWrite, AsyncWriteExt as _, BufWriter}; use tokio::task::JoinHandle; use tokio_util::sync::CancellationToken; use tonic::service::Interceptor as _; @@ -78,7 +81,8 @@ use crate::tenant::mgr::{ GetActiveTenantError, GetTenantError, ShardResolveResult, ShardSelector, TenantManager, }; use crate::tenant::storage_layer::IoConcurrency; -use crate::tenant::timeline::{self, WaitLsnError}; +use crate::tenant::timeline::handle::{Handle, HandleUpgradeError, WeakHandle}; +use crate::tenant::timeline::{self, WaitLsnError, WaitLsnTimeout, WaitLsnWaiter}; use crate::tenant::{GetTimelineError, PageReconstructError, Timeline}; use crate::{CancellableTask, PERF_TRACE_TARGET, timed_after_cancellation}; @@ -167,15 +171,14 @@ pub fn spawn( /// Spawns a gRPC server for the page service. /// +/// TODO: move this onto GrpcPageServiceHandler::spawn(). /// TODO: this doesn't support TLS. We need TLS reloading via ReloadingCertificateResolver, so we /// need to reimplement the TCP+TLS accept loop ourselves. pub fn spawn_grpc( - conf: &'static PageServerConf, tenant_manager: Arc, auth: Option>, perf_trace_dispatch: Option, listener: std::net::TcpListener, - basebackup_cache: Arc, ) -> anyhow::Result { let cancel = CancellationToken::new(); let ctx = RequestContextBuilder::new(TaskKind::PageRequestHandler) @@ -208,24 +211,17 @@ pub fn spawn_grpc( // // * Layers: allow async code, can run code after the service response. However, only has access // to the raw HTTP request/response, not the gRPC types. - let page_service_handler = PageServerHandler::new( + let page_service_handler = GrpcPageServiceHandler { tenant_manager, - auth.clone(), - PageServicePipeliningConfig::Serial, // TODO: unused with gRPC - conf.get_vectored_concurrent_io, - ConnectionPerfSpanFields::default(), - basebackup_cache, ctx, - cancel.clone(), - gate.enter().expect("just created"), - ); + }; let observability_layer = ObservabilityLayer; let mut tenant_interceptor = TenantMetadataInterceptor; let mut auth_interceptor = TenantAuthInterceptor::new(auth); let page_service = tower::ServiceBuilder::new() - // Create tracing span. + // Create tracing span and record request start time. .layer(observability_layer) // Intercept gRPC requests. .layer(tonic::service::InterceptorLayer::new(move |mut req| { @@ -554,7 +550,7 @@ impl TimelineHandles { tenant_id: TenantId, timeline_id: TimelineId, shard_selector: ShardSelector, - ) -> Result, GetActiveTimelineError> { + ) -> Result, GetActiveTimelineError> { if *self.wrapper.tenant_id.get_or_init(|| tenant_id) != tenant_id { return Err(GetActiveTimelineError::Tenant( GetActiveTenantError::SwitchedTenant, @@ -721,6 +717,82 @@ enum PageStreamError { BadRequest(Cow<'static, str>), } +impl PageStreamError { + /// Converts a PageStreamError into a proto::GetPageResponse with the appropriate status + /// code, or a gRPC status if it should terminate the stream (e.g. shutdown). This is a + /// convenience method for use from a get_pages gRPC stream. + #[allow(clippy::result_large_err)] + fn into_get_page_response( + self, + request_id: page_api::RequestID, + ) -> Result { + use page_api::GetPageStatusCode; + use tonic::Code; + + // We dispatch to Into first, and then map it to a GetPageResponse. + let status: tonic::Status = self.into(); + let status_code = match status.code() { + // We shouldn't see an OK status here, because we're emitting an error. + Code::Ok => { + debug_assert_ne!(status.code(), Code::Ok); + return Err(tonic::Status::internal(format!( + "unexpected OK status: {status:?}", + ))); + } + + // These are per-request errors, returned as GetPageResponses. + Code::AlreadyExists => GetPageStatusCode::InvalidRequest, + Code::DataLoss => GetPageStatusCode::InternalError, + Code::FailedPrecondition => GetPageStatusCode::InvalidRequest, + Code::InvalidArgument => GetPageStatusCode::InvalidRequest, + Code::Internal => GetPageStatusCode::InternalError, + Code::NotFound => GetPageStatusCode::NotFound, + Code::OutOfRange => GetPageStatusCode::InvalidRequest, + Code::ResourceExhausted => GetPageStatusCode::SlowDown, + + // These should terminate the stream. + Code::Aborted => return Err(status), + Code::Cancelled => return Err(status), + Code::DeadlineExceeded => return Err(status), + Code::PermissionDenied => return Err(status), + Code::Unauthenticated => return Err(status), + Code::Unavailable => return Err(status), + Code::Unimplemented => return Err(status), + Code::Unknown => return Err(status), + }; + + Ok(page_api::GetPageResponse { + request_id, + status_code, + reason: Some(status.message().to_string()), + page_images: SmallVec::new(), + } + .into()) + } +} + +impl From for tonic::Status { + fn from(err: PageStreamError) -> Self { + use tonic::Code; + let message = err.to_string(); + let code = match err { + PageStreamError::Reconnect(_) => Code::Unavailable, + PageStreamError::Shutdown => Code::Unavailable, + PageStreamError::Read(err) => match err { + PageReconstructError::Cancelled => Code::Unavailable, + PageReconstructError::MissingKey(_) => Code::NotFound, + PageReconstructError::AncestorLsnTimeout(err) => tonic::Status::from(err).code(), + PageReconstructError::Other(_) => Code::Internal, + PageReconstructError::WalRedo(_) => Code::Internal, + }, + PageStreamError::LsnTimeout(err) => tonic::Status::from(err).code(), + PageStreamError::NotFound(_) => Code::NotFound, + PageStreamError::BadRequest(_) => Code::InvalidArgument, + }; + tonic::Status::new(code, message) + } +} + impl From for PageStreamError { fn from(value: PageReconstructError) -> Self { match value { @@ -801,37 +873,37 @@ enum BatchedFeMessage { Exists { span: Span, timer: SmgrOpTimer, - shard: timeline::handle::WeakHandle, + shard: WeakHandle, req: models::PagestreamExistsRequest, }, Nblocks { span: Span, timer: SmgrOpTimer, - shard: timeline::handle::WeakHandle, + shard: WeakHandle, req: models::PagestreamNblocksRequest, }, GetPage { span: Span, - shard: timeline::handle::WeakHandle, - pages: smallvec::SmallVec<[BatchedGetPageRequest; 1]>, + shard: WeakHandle, + pages: SmallVec<[BatchedGetPageRequest; 1]>, batch_break_reason: GetPageBatchBreakReason, }, DbSize { span: Span, timer: SmgrOpTimer, - shard: timeline::handle::WeakHandle, + shard: WeakHandle, req: models::PagestreamDbSizeRequest, }, GetSlruSegment { span: Span, timer: SmgrOpTimer, - shard: timeline::handle::WeakHandle, + shard: WeakHandle, req: models::PagestreamGetSlruSegmentRequest, }, #[cfg(feature = "testing")] Test { span: Span, - shard: timeline::handle::WeakHandle, + shard: WeakHandle, requests: Vec, }, RespondError { @@ -1080,26 +1152,6 @@ impl PageServerHandler { let neon_fe_msg = PagestreamFeMessage::parse(&mut copy_data_bytes.reader(), protocol_version)?; - // TODO: turn in to async closure once available to avoid repeating received_at - async fn record_op_start_and_throttle( - shard: &timeline::handle::Handle, - op: metrics::SmgrQueryType, - received_at: Instant, - ) -> Result { - // It's important to start the smgr op metric recorder as early as possible - // so that the _started counters are incremented before we do - // any serious waiting, e.g., for throttle, batching, or actual request handling. - let mut timer = shard.query_metrics.start_smgr_op(op, received_at); - let now = Instant::now(); - timer.observe_throttle_start(now); - let throttled = tokio::select! { - res = shard.pagestream_throttle.throttle(1, now) => res, - _ = shard.cancel.cancelled() => return Err(QueryError::Shutdown), - }; - timer.observe_throttle_done(throttled); - Ok(timer) - } - let batched_msg = match neon_fe_msg { PagestreamFeMessage::Exists(req) => { let shard = timeline_handles @@ -1107,7 +1159,7 @@ impl PageServerHandler { .await?; debug_assert_current_span_has_tenant_and_timeline_id_no_shard_id(); let span = tracing::info_span!(parent: &parent_span, "handle_get_rel_exists_request", rel = %req.rel, req_lsn = %req.hdr.request_lsn, shard_id = %shard.tenant_shard_id.shard_slug()); - let timer = record_op_start_and_throttle( + let timer = Self::record_op_start_and_throttle( &shard, metrics::SmgrQueryType::GetRelExists, received_at, @@ -1125,7 +1177,7 @@ impl PageServerHandler { .get(tenant_id, timeline_id, ShardSelector::Zero) .await?; let span = tracing::info_span!(parent: &parent_span, "handle_get_nblocks_request", rel = %req.rel, req_lsn = %req.hdr.request_lsn, shard_id = %shard.tenant_shard_id.shard_slug()); - let timer = record_op_start_and_throttle( + let timer = Self::record_op_start_and_throttle( &shard, metrics::SmgrQueryType::GetRelSize, received_at, @@ -1143,7 +1195,7 @@ impl PageServerHandler { .get(tenant_id, timeline_id, ShardSelector::Zero) .await?; let span = tracing::info_span!(parent: &parent_span, "handle_db_size_request", dbnode = %req.dbnode, req_lsn = %req.hdr.request_lsn, shard_id = %shard.tenant_shard_id.shard_slug()); - let timer = record_op_start_and_throttle( + let timer = Self::record_op_start_and_throttle( &shard, metrics::SmgrQueryType::GetDbSize, received_at, @@ -1161,7 +1213,7 @@ impl PageServerHandler { .get(tenant_id, timeline_id, ShardSelector::Zero) .await?; let span = tracing::info_span!(parent: &parent_span, "handle_get_slru_segment_request", kind = %req.kind, segno = %req.segno, req_lsn = %req.hdr.request_lsn, shard_id = %shard.tenant_shard_id.shard_slug()); - let timer = record_op_start_and_throttle( + let timer = Self::record_op_start_and_throttle( &shard, metrics::SmgrQueryType::GetSlruSegment, received_at, @@ -1286,7 +1338,7 @@ impl PageServerHandler { // request handler log messages contain the request-specific fields. let span = mkspan!(shard.tenant_shard_id.shard_slug()); - let timer = record_op_start_and_throttle( + let timer = Self::record_op_start_and_throttle( &shard, metrics::SmgrQueryType::GetPageAtLsn, received_at, @@ -1333,7 +1385,7 @@ impl PageServerHandler { BatchedFeMessage::GetPage { span, shard: shard.downgrade(), - pages: smallvec::smallvec![BatchedGetPageRequest { + pages: smallvec![BatchedGetPageRequest { req, timer, lsn_range: LsnRange { @@ -1355,9 +1407,12 @@ impl PageServerHandler { .get(tenant_id, timeline_id, ShardSelector::Zero) .await?; let span = tracing::info_span!(parent: &parent_span, "handle_test_request", shard_id = %shard.tenant_shard_id.shard_slug()); - let timer = - record_op_start_and_throttle(&shard, metrics::SmgrQueryType::Test, received_at) - .await?; + let timer = Self::record_op_start_and_throttle( + &shard, + metrics::SmgrQueryType::Test, + received_at, + ) + .await?; BatchedFeMessage::Test { span, shard: shard.downgrade(), @@ -1368,6 +1423,26 @@ impl PageServerHandler { Ok(Some(batched_msg)) } + /// Starts a SmgrOpTimer at received_at and throttles the request. + async fn record_op_start_and_throttle( + shard: &Handle, + op: metrics::SmgrQueryType, + received_at: Instant, + ) -> Result { + // It's important to start the smgr op metric recorder as early as possible + // so that the _started counters are incremented before we do + // any serious waiting, e.g., for throttle, batching, or actual request handling. + let mut timer = shard.query_metrics.start_smgr_op(op, received_at); + let now = Instant::now(); + timer.observe_throttle_start(now); + let throttled = tokio::select! { + res = shard.pagestream_throttle.throttle(1, now) => res, + _ = shard.cancel.cancelled() => return Err(QueryError::Shutdown), + }; + timer.observe_throttle_done(throttled); + Ok(timer) + } + /// Post-condition: `batch` is Some() #[instrument(skip_all, level = tracing::Level::TRACE)] #[allow(clippy::boxed_local)] @@ -1465,8 +1540,11 @@ impl PageServerHandler { let (mut handler_results, span) = { // TODO: we unfortunately have to pin the future on the heap, since GetPage futures are huge and // won't fit on the stack. - let mut boxpinned = - Box::pin(self.pagestream_dispatch_batched_message(batch, io_concurrency, ctx)); + let mut boxpinned = Box::pin(Self::pagestream_dispatch_batched_message( + batch, + io_concurrency, + ctx, + )); log_slow( log_slow_name, LOG_SLOW_GETPAGE_THRESHOLD, @@ -1622,7 +1700,6 @@ impl PageServerHandler { /// Helper which dispatches a batched message to the appropriate handler. /// Returns a vec of results, along with the extracted trace span. async fn pagestream_dispatch_batched_message( - &mut self, batch: BatchedFeMessage, io_concurrency: IoConcurrency, ctx: &RequestContext, @@ -1652,10 +1729,10 @@ impl PageServerHandler { let (shard, ctx) = upgrade_handle_and_set_context!(shard); ( vec![ - self.handle_get_rel_exists_request(&shard, &req, &ctx) + Self::handle_get_rel_exists_request(&shard, &req, &ctx) .instrument(span.clone()) .await - .map(|msg| (msg, timer, ctx)) + .map(|msg| (PagestreamBeMessage::Exists(msg), timer, ctx)) .map_err(|err| BatchedPageStreamError { err, req: req.hdr }), ], span, @@ -1671,10 +1748,10 @@ impl PageServerHandler { let (shard, ctx) = upgrade_handle_and_set_context!(shard); ( vec![ - self.handle_get_nblocks_request(&shard, &req, &ctx) + Self::handle_get_nblocks_request(&shard, &req, &ctx) .instrument(span.clone()) .await - .map(|msg| (msg, timer, ctx)) + .map(|msg| (PagestreamBeMessage::Nblocks(msg), timer, ctx)) .map_err(|err| BatchedPageStreamError { err, req: req.hdr }), ], span, @@ -1692,16 +1769,15 @@ impl PageServerHandler { { let npages = pages.len(); trace!(npages, "handling getpage request"); - let res = self - .handle_get_page_at_lsn_request_batched( - &shard, - pages, - io_concurrency, - batch_break_reason, - &ctx, - ) - .instrument(span.clone()) - .await; + let res = Self::handle_get_page_at_lsn_request_batched( + &shard, + pages, + io_concurrency, + batch_break_reason, + &ctx, + ) + .instrument(span.clone()) + .await; assert_eq!(res.len(), npages); res }, @@ -1718,10 +1794,10 @@ impl PageServerHandler { let (shard, ctx) = upgrade_handle_and_set_context!(shard); ( vec![ - self.handle_db_size_request(&shard, &req, &ctx) + Self::handle_db_size_request(&shard, &req, &ctx) .instrument(span.clone()) .await - .map(|msg| (msg, timer, ctx)) + .map(|msg| (PagestreamBeMessage::DbSize(msg), timer, ctx)) .map_err(|err| BatchedPageStreamError { err, req: req.hdr }), ], span, @@ -1737,10 +1813,10 @@ impl PageServerHandler { let (shard, ctx) = upgrade_handle_and_set_context!(shard); ( vec![ - self.handle_get_slru_segment_request(&shard, &req, &ctx) + Self::handle_get_slru_segment_request(&shard, &req, &ctx) .instrument(span.clone()) .await - .map(|msg| (msg, timer, ctx)) + .map(|msg| (PagestreamBeMessage::GetSlruSegment(msg), timer, ctx)) .map_err(|err| BatchedPageStreamError { err, req: req.hdr }), ], span, @@ -1758,8 +1834,7 @@ impl PageServerHandler { { let npages = requests.len(); trace!(npages, "handling getpage request"); - let res = self - .handle_test_request_batch(&shard, requests, &ctx) + let res = Self::handle_test_request_batch(&shard, requests, &ctx) .instrument(span.clone()) .await; assert_eq!(res.len(), npages); @@ -2313,11 +2388,10 @@ impl PageServerHandler { #[instrument(skip_all, fields(shard_id))] async fn handle_get_rel_exists_request( - &mut self, timeline: &Timeline, req: &PagestreamExistsRequest, ctx: &RequestContext, - ) -> Result { + ) -> Result { let latest_gc_cutoff_lsn = timeline.get_applied_gc_cutoff_lsn(); let lsn = Self::wait_or_get_last_lsn( timeline, @@ -2339,19 +2413,15 @@ impl PageServerHandler { ) .await?; - Ok(PagestreamBeMessage::Exists(PagestreamExistsResponse { - req: *req, - exists, - })) + Ok(PagestreamExistsResponse { req: *req, exists }) } #[instrument(skip_all, fields(shard_id))] async fn handle_get_nblocks_request( - &mut self, timeline: &Timeline, req: &PagestreamNblocksRequest, ctx: &RequestContext, - ) -> Result { + ) -> Result { let latest_gc_cutoff_lsn = timeline.get_applied_gc_cutoff_lsn(); let lsn = Self::wait_or_get_last_lsn( timeline, @@ -2373,19 +2443,18 @@ impl PageServerHandler { ) .await?; - Ok(PagestreamBeMessage::Nblocks(PagestreamNblocksResponse { + Ok(PagestreamNblocksResponse { req: *req, n_blocks, - })) + }) } #[instrument(skip_all, fields(shard_id))] async fn handle_db_size_request( - &mut self, timeline: &Timeline, req: &PagestreamDbSizeRequest, ctx: &RequestContext, - ) -> Result { + ) -> Result { let latest_gc_cutoff_lsn = timeline.get_applied_gc_cutoff_lsn(); let lsn = Self::wait_or_get_last_lsn( timeline, @@ -2409,17 +2478,13 @@ impl PageServerHandler { .await?; let db_size = total_blocks as i64 * BLCKSZ as i64; - Ok(PagestreamBeMessage::DbSize(PagestreamDbSizeResponse { - req: *req, - db_size, - })) + Ok(PagestreamDbSizeResponse { req: *req, db_size }) } #[instrument(skip_all)] async fn handle_get_page_at_lsn_request_batched( - &mut self, timeline: &Timeline, - requests: smallvec::SmallVec<[BatchedGetPageRequest; 1]>, + requests: SmallVec<[BatchedGetPageRequest; 1]>, io_concurrency: IoConcurrency, batch_break_reason: GetPageBatchBreakReason, ctx: &RequestContext, @@ -2544,11 +2609,10 @@ impl PageServerHandler { #[instrument(skip_all, fields(shard_id))] async fn handle_get_slru_segment_request( - &mut self, timeline: &Timeline, req: &PagestreamGetSlruSegmentRequest, ctx: &RequestContext, - ) -> Result { + ) -> Result { let latest_gc_cutoff_lsn = timeline.get_applied_gc_cutoff_lsn(); let lsn = Self::wait_or_get_last_lsn( timeline, @@ -2563,16 +2627,13 @@ impl PageServerHandler { .ok_or(PageStreamError::BadRequest("invalid SLRU kind".into()))?; let segment = timeline.get_slru_segment(kind, req.segno, lsn, ctx).await?; - Ok(PagestreamBeMessage::GetSlruSegment( - PagestreamGetSlruSegmentResponse { req: *req, segment }, - )) + Ok(PagestreamGetSlruSegmentResponse { req: *req, segment }) } // NB: this impl mimics what we do for batched getpage requests. #[cfg(feature = "testing")] #[instrument(skip_all, fields(shard_id))] async fn handle_test_request_batch( - &mut self, timeline: &Timeline, requests: Vec, _ctx: &RequestContext, @@ -2648,15 +2709,6 @@ impl PageServerHandler { where IO: AsyncRead + AsyncWrite + Send + Sync + Unpin, { - fn map_basebackup_error(err: BasebackupError) -> QueryError { - match err { - // TODO: passthrough the error site to the final error message? - BasebackupError::Client(e, _) => QueryError::Disconnected(ConnectionError::Io(e)), - BasebackupError::Server(e) => QueryError::Other(e), - BasebackupError::Shutdown => QueryError::Shutdown, - } - } - let started = std::time::Instant::now(); let timeline = self @@ -2714,8 +2766,7 @@ impl PageServerHandler { replica, &ctx, ) - .await - .map_err(map_basebackup_error)?; + .await?; } else { let mut writer = BufWriter::new(pgb.copyout_writer()); @@ -2738,11 +2789,8 @@ impl PageServerHandler { from_cache = true; tokio::io::copy(&mut cached, &mut writer) .await - .map_err(|e| { - map_basebackup_error(BasebackupError::Client( - e, - "handle_basebackup_request,cached,copy", - )) + .map_err(|err| { + BasebackupError::Client(err, "handle_basebackup_request,cached,copy") })?; } else if gzip { let mut encoder = GzipEncoder::with_quality( @@ -2763,8 +2811,7 @@ impl PageServerHandler { replica, &ctx, ) - .await - .map_err(map_basebackup_error)?; + .await?; // shutdown the encoder to ensure the gzip footer is written encoder .shutdown() @@ -2780,15 +2827,12 @@ impl PageServerHandler { replica, &ctx, ) - .await - .map_err(map_basebackup_error)?; + .await?; } - writer.flush().await.map_err(|e| { - map_basebackup_error(BasebackupError::Client( - e, - "handle_basebackup_request,flush", - )) - })?; + writer + .flush() + .await + .map_err(|err| BasebackupError::Client(err, "handle_basebackup_request,flush"))?; } pgb.write_message_noflush(&BeMessage::CopyDone) @@ -3312,69 +3356,464 @@ where } } -/// Implements the page service over gRPC. +/// Serves the page service over gRPC. Dispatches to PageServerHandler for request processing. /// -/// TODO: not yet implemented, all methods return unimplemented. +/// TODO: rename to PageServiceHandler when libpq impl is removed. +pub struct GrpcPageServiceHandler { + tenant_manager: Arc, + ctx: RequestContext, +} + +impl GrpcPageServiceHandler { + /// Errors if the request is executed on a non-zero shard. Only shard 0 has a complete view of + /// relations and their sizes, as well as SLRU segments and similar data. + #[allow(clippy::result_large_err)] + fn ensure_shard_zero(timeline: &Handle) -> Result<(), tonic::Status> { + match timeline.get_shard_index().shard_number.0 { + 0 => Ok(()), + shard => Err(tonic::Status::invalid_argument(format!( + "request must execute on shard zero (is shard {shard})", + ))), + } + } + + /// Generates a PagestreamRequest header from a ReadLsn and request ID. + fn make_hdr(read_lsn: page_api::ReadLsn, req_id: u64) -> PagestreamRequest { + PagestreamRequest { + reqid: req_id, + request_lsn: read_lsn.request_lsn, + not_modified_since: read_lsn + .not_modified_since_lsn + .unwrap_or(read_lsn.request_lsn), + } + } + + /// Acquires a timeline handle for the given request. + /// + /// TODO: during shard splits, the compute may still be sending requests to the parent shard + /// until the entire split is committed and the compute is notified. Consider installing a + /// temporary shard router from the parent to the children while the split is in progress. + /// + /// TODO: consider moving this to a middleware layer; all requests need it. Needs to manage + /// the TimelineHandles lifecycle. + /// + /// TODO: untangle acquisition from TenantManagerWrapper::resolve() and Cache::get(), to avoid + /// the unnecessary overhead. + async fn get_request_timeline( + &self, + req: &tonic::Request, + ) -> Result, GetActiveTimelineError> { + let ttid = *extract::(req); + let shard_index = *extract::(req); + let shard_selector = ShardSelector::Known(shard_index); + + TimelineHandles::new(self.tenant_manager.clone()) + .get(ttid.tenant_id, ttid.timeline_id, shard_selector) + .await + } + + /// Starts a SmgrOpTimer at received_at, throttles the request, and records execution start. + /// Only errors if the timeline is shutting down. + /// + /// TODO: move timer construction to ObservabilityLayer (see TODO there). + /// TODO: decouple rate limiting (middleware?), and return SlowDown errors instead. + async fn record_op_start_and_throttle( + timeline: &Handle, + op: metrics::SmgrQueryType, + received_at: Instant, + ) -> Result { + let mut timer = PageServerHandler::record_op_start_and_throttle(timeline, op, received_at) + .await + .map_err(|err| match err { + // record_op_start_and_throttle() only returns Shutdown. + QueryError::Shutdown => tonic::Status::unavailable(format!("{err}")), + err => tonic::Status::internal(format!("unexpected error: {err}")), + })?; + timer.observe_execution_start(Instant::now()); + Ok(timer) + } + + /// Processes a GetPage batch request, via the GetPages bidirectional streaming RPC. + /// + /// NB: errors will terminate the stream. Per-request errors should return a GetPageResponse + /// with an appropriate status code instead. + /// + /// TODO: get_vectored() currently enforces a batch limit of 32. Postgres will typically send + /// batches up to effective_io_concurrency = 100. Either we have to accept large batches, or + /// split them up in the client or server. + #[instrument(skip_all, fields(req_id, rel, blkno, blks, req_lsn, mod_lsn))] + async fn get_page( + ctx: &RequestContext, + timeline: &WeakHandle, + req: proto::GetPageRequest, + io_concurrency: IoConcurrency, + ) -> Result { + let received_at = Instant::now(); + let timeline = timeline.upgrade()?; + let ctx = ctx.with_scope_page_service_pagestream(&timeline); + + // Validate the request, decorate the span, and convert it to a Pagestream request. + let req: page_api::GetPageRequest = req.try_into()?; + + span_record!( + req_id = %req.request_id, + rel = %req.rel, + blkno = %req.block_numbers[0], + blks = %req.block_numbers.len(), + lsn = %req.read_lsn, + ); + + let latest_gc_cutoff_lsn = timeline.get_applied_gc_cutoff_lsn(); // hold guard + let effective_lsn = match PageServerHandler::effective_request_lsn( + &timeline, + timeline.get_last_record_lsn(), + req.read_lsn.request_lsn, + req.read_lsn + .not_modified_since_lsn + .unwrap_or(req.read_lsn.request_lsn), + &latest_gc_cutoff_lsn, + ) { + Ok(lsn) => lsn, + Err(err) => return err.into_get_page_response(req.request_id), + }; + + let mut batch = SmallVec::with_capacity(req.block_numbers.len()); + for blkno in req.block_numbers { + // TODO: this creates one timer per page and throttles it. We should have a timer for + // the entire batch, and throttle only the batch, but this is equivalent to what + // PageServerHandler does already so we keep it for now. + let timer = Self::record_op_start_and_throttle( + &timeline, + metrics::SmgrQueryType::GetPageAtLsn, + received_at, + ) + .await?; + + batch.push(BatchedGetPageRequest { + req: PagestreamGetPageRequest { + hdr: Self::make_hdr(req.read_lsn, req.request_id), + rel: req.rel, + blkno, + }, + lsn_range: LsnRange { + effective_lsn, + request_lsn: req.read_lsn.request_lsn, + }, + timer, + ctx: ctx.attached_child(), + batch_wait_ctx: None, // TODO: add tracing + }); + } + + // TODO: this does a relation size query for every page in the batch. Since this batch is + // all for one relation, we could do this only once. However, this is not the case for the + // libpq implementation. + let results = PageServerHandler::handle_get_page_at_lsn_request_batched( + &timeline, + batch, + io_concurrency, + GetPageBatchBreakReason::BatchFull, // TODO: not relevant for gRPC batches + &ctx, + ) + .await; + + let mut resp = page_api::GetPageResponse { + request_id: req.request_id, + status_code: page_api::GetPageStatusCode::Ok, + reason: None, + page_images: SmallVec::with_capacity(results.len()), + }; + + for result in results { + match result { + Ok((PagestreamBeMessage::GetPage(r), _, _)) => resp.page_images.push(r.page), + Ok((resp, _, _)) => { + return Err(tonic::Status::internal(format!( + "unexpected response: {resp:?}" + ))); + } + Err(err) => return err.err.into_get_page_response(req.request_id), + }; + } + + Ok(resp.into()) + } +} + +/// Implements the gRPC page service. +/// +/// TODO: cancellation. +/// TODO: when the libpq impl is removed, remove the Pagestream types and inline the handler code. #[tonic::async_trait] -impl proto::PageService for PageServerHandler { +impl proto::PageService for GrpcPageServiceHandler { type GetBaseBackupStream = Pin< Box> + Send>, >; + type GetPagesStream = Pin> + Send>>; - #[instrument(skip_all)] + #[instrument(skip_all, fields(rel, lsn))] async fn check_rel_exists( &self, - _: tonic::Request, + req: tonic::Request, ) -> Result, tonic::Status> { - Err(tonic::Status::unimplemented("not implemented")) + let received_at = extract::(&req).0; + let timeline = self.get_request_timeline(&req).await?; + let ctx = self.ctx.with_scope_page_service_pagestream(&timeline); + + // Validate the request, decorate the span, and convert it to a Pagestream request. + Self::ensure_shard_zero(&timeline)?; + let req: page_api::CheckRelExistsRequest = req.into_inner().try_into()?; + + span_record!(rel=%req.rel, lsn=%req.read_lsn); + + let req = PagestreamExistsRequest { + hdr: Self::make_hdr(req.read_lsn, 0), + rel: req.rel, + }; + + // Execute the request and convert the response. + let _timer = Self::record_op_start_and_throttle( + &timeline, + metrics::SmgrQueryType::GetRelExists, + received_at, + ) + .await?; + + let resp = PageServerHandler::handle_get_rel_exists_request(&timeline, &req, &ctx).await?; + let resp: page_api::CheckRelExistsResponse = resp.exists; + Ok(tonic::Response::new(resp.into())) } - #[instrument(skip_all)] + // TODO: ensure clients use gzip compression for the stream. + #[instrument(skip_all, fields(lsn))] async fn get_base_backup( &self, - _: tonic::Request, + req: tonic::Request, ) -> Result, tonic::Status> { - Err(tonic::Status::unimplemented("not implemented")) + // Send 64 KB chunks to avoid large memory allocations. + const CHUNK_SIZE: usize = 64 * 1024; + + let timeline = self.get_request_timeline(&req).await?; + let ctx = self.ctx.with_scope_timeline(&timeline); + + // Validate the request, decorate the span, and wait for the LSN to arrive. + // + // TODO: this requires a read LSN, is that ok? + Self::ensure_shard_zero(&timeline)?; + if timeline.is_archived() == Some(true) { + return Err(tonic::Status::failed_precondition("timeline is archived")); + } + let req: page_api::GetBaseBackupRequest = req.into_inner().try_into()?; + + span_record!(lsn=%req.read_lsn); + + let latest_gc_cutoff_lsn = timeline.get_applied_gc_cutoff_lsn(); + timeline + .wait_lsn( + req.read_lsn.request_lsn, + WaitLsnWaiter::PageService, + WaitLsnTimeout::Default, + &ctx, + ) + .await?; + timeline + .check_lsn_is_in_scope(req.read_lsn.request_lsn, &latest_gc_cutoff_lsn) + .map_err(|err| { + tonic::Status::invalid_argument(format!("invalid basebackup LSN: {err}")) + })?; + + // Spawn a task to run the basebackup. + // + // TODO: do we need to support full base backups, for debugging? + let span = Span::current(); + let (mut simplex_read, mut simplex_write) = tokio::io::simplex(CHUNK_SIZE); + let jh = tokio::spawn(async move { + let result = basebackup::send_basebackup_tarball( + &mut simplex_write, + &timeline, + Some(req.read_lsn.request_lsn), + None, + false, + req.replica, + &ctx, + ) + .instrument(span) // propagate request span + .await; + simplex_write.shutdown().await.map_err(|err| { + BasebackupError::Server(anyhow!("simplex shutdown failed: {err}")) + })?; + result + }); + + // Emit chunks of size CHUNK_SIZE. + let chunks = async_stream::try_stream! { + let mut chunk = BytesMut::with_capacity(CHUNK_SIZE); + loop { + let n = simplex_read.read_buf(&mut chunk).await.map_err(|err| { + tonic::Status::internal(format!("failed to read basebackup chunk: {err}")) + })?; + + // If we read 0 bytes, either the chunk is full or the stream is closed. + if n == 0 { + if chunk.is_empty() { + break; + } + yield proto::GetBaseBackupResponseChunk::try_from(chunk.clone().freeze())?; + chunk.clear(); + } + } + // Wait for the basebackup task to exit and check for errors. + jh.await.map_err(|err| { + tonic::Status::internal(format!("basebackup failed: {err}")) + })??; + }; + + Ok(tonic::Response::new(Box::pin(chunks))) } - #[instrument(skip_all)] + #[instrument(skip_all, fields(db_oid, lsn))] async fn get_db_size( &self, - _: tonic::Request, + req: tonic::Request, ) -> Result, tonic::Status> { - Err(tonic::Status::unimplemented("not implemented")) + let received_at = extract::(&req).0; + let timeline = self.get_request_timeline(&req).await?; + let ctx = self.ctx.with_scope_page_service_pagestream(&timeline); + + // Validate the request, decorate the span, and convert it to a Pagestream request. + Self::ensure_shard_zero(&timeline)?; + let req: page_api::GetDbSizeRequest = req.into_inner().try_into()?; + + span_record!(db_oid=%req.db_oid, lsn=%req.read_lsn); + + let req = PagestreamDbSizeRequest { + hdr: Self::make_hdr(req.read_lsn, 0), + dbnode: req.db_oid, + }; + + // Execute the request and convert the response. + let _timer = Self::record_op_start_and_throttle( + &timeline, + metrics::SmgrQueryType::GetDbSize, + received_at, + ) + .await?; + + let resp = PageServerHandler::handle_db_size_request(&timeline, &req, &ctx).await?; + let resp = resp.db_size as page_api::GetDbSizeResponse; + Ok(tonic::Response::new(resp.into())) } // NB: don't instrument this, instrument each streamed request. async fn get_pages( &self, - _: tonic::Request>, + req: tonic::Request>, ) -> Result, tonic::Status> { - Err(tonic::Status::unimplemented("not implemented")) + // Extract the timeline from the request and check that it exists. + let ttid = *extract::(&req); + let shard_index = *extract::(&req); + let shard_selector = ShardSelector::Known(shard_index); + + let mut handles = TimelineHandles::new(self.tenant_manager.clone()); + handles + .get(ttid.tenant_id, ttid.timeline_id, shard_selector) + .await?; + + let span = Span::current(); + let ctx = self.ctx.attached_child(); + let mut reqs = req.into_inner(); + + let resps = async_stream::try_stream! { + let timeline = handles + .get(ttid.tenant_id, ttid.timeline_id, shard_selector) + .await? + .downgrade(); + while let Some(req) = reqs.message().await? { + // TODO: implement IoConcurrency sidecar. + yield Self::get_page(&ctx, &timeline, req, IoConcurrency::Sequential) + .instrument(span.clone()) // propagate request span + .await? + } + }; + + Ok(tonic::Response::new(Box::pin(resps))) } - #[instrument(skip_all)] + #[instrument(skip_all, fields(rel, lsn))] async fn get_rel_size( &self, - _: tonic::Request, + req: tonic::Request, ) -> Result, tonic::Status> { - Err(tonic::Status::unimplemented("not implemented")) + let received_at = extract::(&req).0; + let timeline = self.get_request_timeline(&req).await?; + let ctx = self.ctx.with_scope_page_service_pagestream(&timeline); + + // Validate the request, decorate the span, and convert it to a Pagestream request. + Self::ensure_shard_zero(&timeline)?; + let req: page_api::GetRelSizeRequest = req.into_inner().try_into()?; + + span_record!(rel=%req.rel, lsn=%req.read_lsn); + + let req = PagestreamNblocksRequest { + hdr: Self::make_hdr(req.read_lsn, 0), + rel: req.rel, + }; + + // Execute the request and convert the response. + let _timer = Self::record_op_start_and_throttle( + &timeline, + metrics::SmgrQueryType::GetRelSize, + received_at, + ) + .await?; + + let resp = PageServerHandler::handle_get_nblocks_request(&timeline, &req, &ctx).await?; + let resp: page_api::GetRelSizeResponse = resp.n_blocks; + Ok(tonic::Response::new(resp.into())) } - #[instrument(skip_all)] + #[instrument(skip_all, fields(kind, segno, lsn))] async fn get_slru_segment( &self, - _: tonic::Request, + req: tonic::Request, ) -> Result, tonic::Status> { - Err(tonic::Status::unimplemented("not implemented")) + let received_at = extract::(&req).0; + let timeline = self.get_request_timeline(&req).await?; + let ctx = self.ctx.with_scope_page_service_pagestream(&timeline); + + // Validate the request, decorate the span, and convert it to a Pagestream request. + Self::ensure_shard_zero(&timeline)?; + let req: page_api::GetSlruSegmentRequest = req.into_inner().try_into()?; + + span_record!(kind=%req.kind, segno=%req.segno, lsn=%req.read_lsn); + + let req = PagestreamGetSlruSegmentRequest { + hdr: Self::make_hdr(req.read_lsn, 0), + kind: req.kind as u8, + segno: req.segno, + }; + + // Execute the request and convert the response. + let _timer = Self::record_op_start_and_throttle( + &timeline, + metrics::SmgrQueryType::GetSlruSegment, + received_at, + ) + .await?; + + let resp = + PageServerHandler::handle_get_slru_segment_request(&timeline, &req, &ctx).await?; + let resp: page_api::GetSlruSegmentResponse = resp.segment; + Ok(tonic::Response::new(resp.try_into()?)) } } /// gRPC middleware layer that handles observability concerns: /// /// * Creates and enters a tracing span. +/// * Records the request start time as a ReceivedAt request extension. /// /// TODO: add perf tracing. /// TODO: add timing and metrics. @@ -3395,6 +3834,9 @@ struct ObservabilityLayerService { inner: S, } +#[derive(Clone, Copy)] +struct ReceivedAt(Instant); + impl tonic::server::NamedService for ObservabilityLayerService { const NAME: &'static str = S::NAME; // propagate inner service name } @@ -3408,7 +3850,13 @@ where type Error = S::Error; type Future = BoxFuture<'static, Result>; - fn call(&mut self, req: http::Request) -> Self::Future { + fn call(&mut self, mut req: http::Request) -> Self::Future { + // Record the request start time as a request extension. + // + // TODO: we should start a timer here instead, but it currently requires a timeline handle + // and SmgrQueryType, which we don't have yet. Refactor it to provide it later. + req.extensions_mut().insert(ReceivedAt(Instant::now())); + // Create a basic tracing span. Enter the span for the current thread (to use it for inner // sync code like interceptors), and instrument the future (to use it for inner async code // like the page service itself). @@ -3436,8 +3884,6 @@ where /// gRPC interceptor that decodes tenant metadata and stores it as request extensions of type /// TenantTimelineId and ShardIndex. -/// -/// TODO: consider looking up the timeline handle here and storing it. #[derive(Clone)] struct TenantMetadataInterceptor; @@ -3485,7 +3931,7 @@ impl tonic::service::Interceptor for TenantMetadataInterceptor { } } -/// Authenticates gRPC page service requests. Must run after TenantMetadataInterceptor. +/// Authenticates gRPC page service requests. #[derive(Clone)] struct TenantAuthInterceptor { auth: Option>, @@ -3504,11 +3950,8 @@ impl tonic::service::Interceptor for TenantAuthInterceptor { return Ok(req); }; - // Fetch the tenant ID that's been set by TenantMetadataInterceptor. - let ttid = req - .extensions() - .get::() - .expect("TenantMetadataInterceptor must run before TenantAuthInterceptor"); + // Fetch the tenant ID from the request extensions (set by TenantMetadataInterceptor). + let TenantTimelineId { tenant_id, .. } = *extract::(&req); // Fetch and decode the JWT token. let jwt = req @@ -3526,7 +3969,7 @@ impl tonic::service::Interceptor for TenantAuthInterceptor { let claims = jwtdata.claims; // Check if the token is valid for this tenant. - check_permission(&claims, Some(ttid.tenant_id)) + check_permission(&claims, Some(tenant_id)) .map_err(|err| tonic::Status::permission_denied(err.to_string()))?; // TODO: consider stashing the claims in the request extensions, if needed. @@ -3535,6 +3978,21 @@ impl tonic::service::Interceptor for TenantAuthInterceptor { } } +/// Extracts the given type from the request extensions, or panics if it is missing. +fn extract(req: &tonic::Request) -> &T { + extract_from(req.extensions()) +} + +/// Extract the given type from the request extensions, or panics if it is missing. This variant +/// can extract both from a tonic::Request and http::Request. +fn extract_from(ext: &http::Extensions) -> &T { + let Some(value) = ext.get::() else { + let name = std::any::type_name::(); + panic!("extension {name} should be set by middleware"); + }; + value +} + #[derive(Debug, thiserror::Error)] pub(crate) enum GetActiveTimelineError { #[error(transparent)] @@ -3553,6 +4011,29 @@ impl From for QueryError { } } +impl From for tonic::Status { + fn from(err: GetActiveTimelineError) -> Self { + let message = err.to_string(); + let code = match err { + GetActiveTimelineError::Tenant(err) => tonic::Status::from(err).code(), + GetActiveTimelineError::Timeline(err) => tonic::Status::from(err).code(), + }; + tonic::Status::new(code, message) + } +} + +impl From for tonic::Status { + fn from(err: GetTimelineError) -> Self { + use tonic::Code; + let code = match &err { + GetTimelineError::NotFound { .. } => Code::NotFound, + GetTimelineError::NotActive { .. } => Code::Unavailable, + GetTimelineError::ShuttingDown => Code::Unavailable, + }; + tonic::Status::new(code, err.to_string()) + } +} + impl From for QueryError { fn from(e: GetActiveTenantError) -> Self { match e { @@ -3569,10 +4050,33 @@ impl From for QueryError { } } -impl From for QueryError { - fn from(e: crate::tenant::timeline::handle::HandleUpgradeError) -> Self { +impl From for tonic::Status { + fn from(err: GetActiveTenantError) -> Self { + use tonic::Code; + let code = match &err { + GetActiveTenantError::Broken(_) => Code::Internal, + GetActiveTenantError::Cancelled => Code::Unavailable, + GetActiveTenantError::NotFound(_) => Code::NotFound, + GetActiveTenantError::SwitchedTenant => Code::Unavailable, + GetActiveTenantError::WaitForActiveTimeout { .. } => Code::Unavailable, + GetActiveTenantError::WillNotBecomeActive(_) => Code::Unavailable, + }; + tonic::Status::new(code, err.to_string()) + } +} + +impl From for QueryError { + fn from(e: HandleUpgradeError) -> Self { match e { - crate::tenant::timeline::handle::HandleUpgradeError::ShutDown => QueryError::Shutdown, + HandleUpgradeError::ShutDown => QueryError::Shutdown, + } + } +} + +impl From for tonic::Status { + fn from(err: HandleUpgradeError) -> Self { + match err { + HandleUpgradeError::ShutDown => tonic::Status::unavailable("timeline is shutting down"), } } } diff --git a/pageserver/src/tenant/timeline.rs b/pageserver/src/tenant/timeline.rs index 23c40a7629..9ddbe404d2 100644 --- a/pageserver/src/tenant/timeline.rs +++ b/pageserver/src/tenant/timeline.rs @@ -950,6 +950,18 @@ pub(crate) enum WaitLsnError { Timeout(String), } +impl From for tonic::Status { + fn from(err: WaitLsnError) -> Self { + use tonic::Code; + let code = match &err { + WaitLsnError::Timeout(_) => Code::Internal, + WaitLsnError::BadState(_) => Code::Internal, + WaitLsnError::Shutdown => Code::Unavailable, + }; + tonic::Status::new(code, err.to_string()) + } +} + // The impls below achieve cancellation mapping for errors. // Perhaps there's a way of achieving this with less cruft. From a650f7f5af4773bc6c7806a12b49e84234c7e6d6 Mon Sep 17 00:00:00 2001 From: "Alex Chi Z." <4198311+skyzh@users.noreply.github.com> Date: Tue, 3 Jun 2025 13:00:34 +0800 Subject: [PATCH 46/48] fix(pageserver): only deserialize reldir key once during get_db_size (#12102) ## Problem fix https://github.com/neondatabase/neon/issues/12101; this is a quick hack and we need better API in the future. In `get_db_size`, we call `get_reldir_size` for every relation. However, we do the same deserializing the reldir directory thing for every relation. This creates huge CPU overhead. ## Summary of changes Get and deserialize the reldir v1 key once and use it across all get_rel_size requests. --------- Signed-off-by: Alex Chi Z --- pageserver/src/pgdatadir_mapping.rs | 58 ++++++++++++++++++++++++++++- 1 file changed, 56 insertions(+), 2 deletions(-) diff --git a/pageserver/src/pgdatadir_mapping.rs b/pageserver/src/pgdatadir_mapping.rs index c6f3929257..b6f11b744b 100644 --- a/pageserver/src/pgdatadir_mapping.rs +++ b/pageserver/src/pgdatadir_mapping.rs @@ -471,8 +471,19 @@ impl Timeline { let rels = self.list_rels(spcnode, dbnode, version, ctx).await?; + if rels.is_empty() { + return Ok(0); + } + + // Pre-deserialize the rel directory to avoid duplicated work in `get_relsize_cached`. + let reldir_key = rel_dir_to_key(spcnode, dbnode); + let buf = version.get(self, reldir_key, ctx).await?; + let reldir = RelDirectory::des(&buf)?; + for rel in rels { - let n_blocks = self.get_rel_size(rel, version, ctx).await?; + let n_blocks = self + .get_rel_size_in_reldir(rel, version, Some((reldir_key, &reldir)), ctx) + .await?; total_blocks += n_blocks as usize; } Ok(total_blocks) @@ -487,6 +498,19 @@ impl Timeline { tag: RelTag, version: Version<'_>, ctx: &RequestContext, + ) -> Result { + self.get_rel_size_in_reldir(tag, version, None, ctx).await + } + + /// Get size of a relation file. The relation must exist, otherwise an error is returned. + /// + /// See [`Self::get_rel_exists_in_reldir`] on why we need `deserialized_reldir_v1`. + pub(crate) async fn get_rel_size_in_reldir( + &self, + tag: RelTag, + version: Version<'_>, + deserialized_reldir_v1: Option<(Key, &RelDirectory)>, + ctx: &RequestContext, ) -> Result { if tag.relnode == 0 { return Err(PageReconstructError::Other( @@ -499,7 +523,9 @@ impl Timeline { } if (tag.forknum == FSM_FORKNUM || tag.forknum == VISIBILITYMAP_FORKNUM) - && !self.get_rel_exists(tag, version, ctx).await? + && !self + .get_rel_exists_in_reldir(tag, version, deserialized_reldir_v1, ctx) + .await? { // FIXME: Postgres sometimes calls smgrcreate() to create // FSM, and smgrnblocks() on it immediately afterwards, @@ -521,11 +547,28 @@ impl Timeline { /// /// Only shard 0 has a full view of the relations. Other shards only know about relations that /// the shard stores pages for. + /// pub(crate) async fn get_rel_exists( &self, tag: RelTag, version: Version<'_>, ctx: &RequestContext, + ) -> Result { + self.get_rel_exists_in_reldir(tag, version, None, ctx).await + } + + /// Does the relation exist? With a cached deserialized `RelDirectory`. + /// + /// There are some cases where the caller loops across all relations. In that specific case, + /// the caller should obtain the deserialized `RelDirectory` first and then call this function + /// to avoid duplicated work of deserliazation. This is a hack and should be removed by introducing + /// a new API (e.g., `get_rel_exists_batched`). + pub(crate) async fn get_rel_exists_in_reldir( + &self, + tag: RelTag, + version: Version<'_>, + deserialized_reldir_v1: Option<(Key, &RelDirectory)>, + ctx: &RequestContext, ) -> Result { if tag.relnode == 0 { return Err(PageReconstructError::Other( @@ -568,6 +611,17 @@ impl Timeline { // fetch directory listing (old) let key = rel_dir_to_key(tag.spcnode, tag.dbnode); + + if let Some((cached_key, dir)) = deserialized_reldir_v1 { + if cached_key == key { + return Ok(dir.rels.contains(&(tag.relnode, tag.forknum))); + } else if cfg!(test) || cfg!(feature = "testing") { + panic!("cached reldir key mismatch: {cached_key} != {key}"); + } else { + warn!("cached reldir key mismatch: {cached_key} != {key}"); + } + // Fallback to reading the directory from the datadir. + } let buf = version.get(self, key, ctx).await?; let dir = RelDirectory::des(&buf)?; From 3e72edede524af50220d0d103df08a1f6e12e6a9 Mon Sep 17 00:00:00 2001 From: a-masterov <72613290+a-masterov@users.noreply.github.com> Date: Tue, 3 Jun 2025 09:23:17 +0200 Subject: [PATCH 47/48] Use full hostname for ONNX URL (#12064) ## Problem We should use the full host name for computes, according to https://github.com/neondatabase/cloud/issues/26005 , but now a truncated host name is used. ## Summary of changes The URL for REMOTE_ONNX is rewritten using the FQDN. --- compute/compute-node.Dockerfile | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/compute/compute-node.Dockerfile b/compute/compute-node.Dockerfile index 3459983a34..2afdde0cfa 100644 --- a/compute/compute-node.Dockerfile +++ b/compute/compute-node.Dockerfile @@ -1180,14 +1180,14 @@ RUN cd exts/rag && \ RUN cd exts/rag_bge_small_en_v15 && \ sed -i 's/pgrx = "0.14.1"/pgrx = { version = "0.14.1", features = [ "unsafe-postgres" ] }/g' Cargo.toml && \ ORT_LIB_LOCATION=/ext-src/onnxruntime-src/build/Linux \ - REMOTE_ONNX_URL=http://pg-ext-s3-gateway/pgrag-data/bge_small_en_v15.onnx \ + REMOTE_ONNX_URL=http://pg-ext-s3-gateway.pg-ext-s3-gateway.svc.cluster.local/pgrag-data/bge_small_en_v15.onnx \ cargo pgrx install --release --features remote_onnx && \ echo "trusted = true" >> /usr/local/pgsql/share/extension/rag_bge_small_en_v15.control RUN cd exts/rag_jina_reranker_v1_tiny_en && \ sed -i 's/pgrx = "0.14.1"/pgrx = { version = "0.14.1", features = [ "unsafe-postgres" ] }/g' Cargo.toml && \ ORT_LIB_LOCATION=/ext-src/onnxruntime-src/build/Linux \ - REMOTE_ONNX_URL=http://pg-ext-s3-gateway/pgrag-data/jina_reranker_v1_tiny_en.onnx \ + REMOTE_ONNX_URL=http://pg-ext-s3-gateway.pg-ext-s3-gateway.svc.cluster.local/pgrag-data/jina_reranker_v1_tiny_en.onnx \ cargo pgrx install --release --features remote_onnx && \ echo "trusted = true" >> /usr/local/pgsql/share/extension/rag_jina_reranker_v1_tiny_en.control From 3b8be98b67acbb3da0852ca5adf33408e2313f89 Mon Sep 17 00:00:00 2001 From: Vlad Lazar Date: Tue, 3 Jun 2025 10:07:07 +0100 Subject: [PATCH 48/48] pageserver: remove backtrace in info level log (#12108) ## Problem We print a backtrace in an info level log every 10 seconds while waiting for the import data to land in the bucket. ## Summary of changes The backtrace is not useful. Remove it. --- pageserver/src/tenant/timeline/import_pgdata.rs | 4 ++-- pageserver/src/tenant/timeline/import_pgdata/flow.rs | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pageserver/src/tenant/timeline/import_pgdata.rs b/pageserver/src/tenant/timeline/import_pgdata.rs index 3f760d858b..606ad09ef1 100644 --- a/pageserver/src/tenant/timeline/import_pgdata.rs +++ b/pageserver/src/tenant/timeline/import_pgdata.rs @@ -201,8 +201,8 @@ async fn prepare_import( .await; match res { Ok(_) => break, - Err(err) => { - info!(?err, "indefinitely waiting for pgdata to finish"); + Err(_err) => { + info!("indefinitely waiting for pgdata to finish"); if tokio::time::timeout(std::time::Duration::from_secs(10), cancel.cancelled()) .await .is_ok() diff --git a/pageserver/src/tenant/timeline/import_pgdata/flow.rs b/pageserver/src/tenant/timeline/import_pgdata/flow.rs index 760e82dd57..2ec9d86720 100644 --- a/pageserver/src/tenant/timeline/import_pgdata/flow.rs +++ b/pageserver/src/tenant/timeline/import_pgdata/flow.rs @@ -471,6 +471,8 @@ impl Plan { last_completed_job_idx = job_idx; if last_completed_job_idx % checkpoint_every == 0 { + tracing::info!(last_completed_job_idx, jobs=%jobs_in_plan, "Checkpointing import status"); + let progress = ShardImportProgressV1 { jobs: jobs_in_plan, completed: last_completed_job_idx, @@ -492,8 +494,6 @@ impl Plan { anyhow::anyhow!("Shut down while putting timeline import status") })?; } - - tracing::info!(last_completed_job_idx, jobs=%jobs_in_plan, "Checkpointing import status"); }, Some(Err(_)) => { anyhow::bail!(