diff --git a/.github/actions/allure-report-generate/action.yml b/.github/actions/allure-report-generate/action.yml index 1ecb5ecc7e..f84beff20c 100644 --- a/.github/actions/allure-report-generate/action.yml +++ b/.github/actions/allure-report-generate/action.yml @@ -150,7 +150,7 @@ runs: # Use aws s3 cp (instead of aws s3 sync) to keep files from previous runs to make old URLs work, # and to keep files on the host to upload them to the database - time aws s3 cp --recursive --only-show-errors "${WORKDIR}/report" "s3://${BUCKET}/${REPORT_PREFIX}/${GITHUB_RUN_ID}" + time s5cmd --log error cp "${WORKDIR}/report/*" "s3://${BUCKET}/${REPORT_PREFIX}/${GITHUB_RUN_ID}/" # Generate redirect cat < ${WORKDIR}/index.html diff --git a/.github/workflows/approved-for-ci-run.yml b/.github/workflows/approved-for-ci-run.yml index 69c48d86b9..ab616d17e2 100644 --- a/.github/workflows/approved-for-ci-run.yml +++ b/.github/workflows/approved-for-ci-run.yml @@ -18,6 +18,7 @@ on: concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number }} + cancel-in-progress: false env: GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/build-build-tools-image.yml b/.github/workflows/build-build-tools-image.yml index 251423e701..c527cef1ac 100644 --- a/.github/workflows/build-build-tools-image.yml +++ b/.github/workflows/build-build-tools-image.yml @@ -21,6 +21,7 @@ defaults: concurrency: group: build-build-tools-image-${{ inputs.image-tag }} + cancel-in-progress: false # No permission for GITHUB_TOKEN by default; the **minimal required** set of permissions should be granted in each job. permissions: {} diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index 36922d5294..1d35fa9223 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -1133,8 +1133,6 @@ jobs: -f deployPreprodRegion=true gh workflow --repo neondatabase/aws run deploy-prod.yml --ref main \ - -f deployPgSniRouter=false \ - -f deployProxy=false \ -f deployStorage=true \ -f deployStorageBroker=true \ -f deployStorageController=true \ diff --git a/.github/workflows/check-build-tools-image.yml b/.github/workflows/check-build-tools-image.yml index 28646dfc19..a1e22cf93f 100644 --- a/.github/workflows/check-build-tools-image.yml +++ b/.github/workflows/check-build-tools-image.yml @@ -28,7 +28,9 @@ jobs: - name: Get build-tools image tag for the current commit id: get-build-tools-tag env: - COMMIT_SHA: ${{ github.event.pull_request.head.sha || github.sha }} + # Usually, for COMMIT_SHA, we use `github.event.pull_request.head.sha || github.sha`, but here, even for PRs, + # we want to use `github.sha` i.e. point to a phantom merge commit to determine the image tag correctly. + COMMIT_SHA: ${{ github.sha }} GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} run: | LAST_BUILD_TOOLS_SHA=$( diff --git a/.github/workflows/pin-build-tools-image.yml b/.github/workflows/pin-build-tools-image.yml index c941692066..d495a158e8 100644 --- a/.github/workflows/pin-build-tools-image.yml +++ b/.github/workflows/pin-build-tools-image.yml @@ -20,6 +20,7 @@ defaults: concurrency: group: pin-build-tools-image-${{ inputs.from-tag }} + cancel-in-progress: false permissions: {} diff --git a/Cargo.lock b/Cargo.lock index bdf2b08c5c..6faf4b72f0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2932,9 +2932,9 @@ checksum = "490cc448043f947bae3cbee9c203358d62dbee0db12107a74be5c30ccfd09771" [[package]] name = "measured" -version = "0.0.20" +version = "0.0.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3cbf033874bea03565f2449572c8640ca37ec26300455faf36001f24755da452" +checksum = "652bc741286361c06de8cb4d89b21a6437f120c508c51713663589eeb9928ac5" dependencies = [ "bytes", "crossbeam-utils", @@ -2950,9 +2950,9 @@ dependencies = [ [[package]] name = "measured-derive" -version = "0.0.20" +version = "0.0.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be9e29b682b38f8af2a89f960455054ab1a9f5a06822f6f3500637ad9fa57def" +checksum = "6ea497f33e1e856a376c32ad916f69a0bd3c597db1f912a399f842b01a4a685d" dependencies = [ "heck 0.5.0", "proc-macro2", @@ -2962,9 +2962,9 @@ dependencies = [ [[package]] name = "measured-process" -version = "0.0.20" +version = "0.0.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a20849acdd04c5d6a88f565559044546904648a1842a2937cfff0b48b4ca7ef2" +checksum = "b364ccb66937a814b6b2ad751d1a2f7a9d5a78c761144036825fb36bb0771000" dependencies = [ "libc", "measured", @@ -4322,6 +4322,7 @@ dependencies = [ "itertools", "lasso", "md5", + "measured", "metrics", "native-tls", "once_cell", diff --git a/Cargo.toml b/Cargo.toml index feea17ab05..8310d2d522 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -107,8 +107,8 @@ lasso = "0.7" leaky-bucket = "1.0.1" libc = "0.2" md5 = "0.7.0" -measured = { version = "0.0.20", features=["lasso"] } -measured-process = { version = "0.0.20" } +measured = { version = "0.0.21", features=["lasso"] } +measured-process = { version = "0.0.21" } memoffset = "0.8" native-tls = "0.2" nix = { version = "0.27", features = ["fs", "process", "socket", "signal", "poll"] } diff --git a/Dockerfile.build-tools b/Dockerfile.build-tools index 1ed6f87473..a082f15c34 100644 --- a/Dockerfile.build-tools +++ b/Dockerfile.build-tools @@ -58,6 +58,12 @@ RUN curl -fsSL "https://github.com/protocolbuffers/protobuf/releases/download/v$ && mv protoc/include/google /usr/local/include/google \ && rm -rf protoc.zip protoc +# s5cmd +ENV S5CMD_VERSION=2.2.2 +RUN curl -sL "https://github.com/peak/s5cmd/releases/download/v${S5CMD_VERSION}/s5cmd_${S5CMD_VERSION}_Linux-$(uname -m | sed 's/x86_64/64bit/g' | sed 's/aarch64/arm64/g').tar.gz" | tar zxvf - s5cmd \ + && chmod +x s5cmd \ + && mv s5cmd /usr/local/bin/s5cmd + # LLVM ENV LLVM_VERSION=17 RUN curl -fsSL 'https://apt.llvm.org/llvm-snapshot.gpg.key' | apt-key add - \ diff --git a/compute_tools/src/compute.rs b/compute_tools/src/compute.rs index 88dc4aca2b..40060f4117 100644 --- a/compute_tools/src/compute.rs +++ b/compute_tools/src/compute.rs @@ -818,9 +818,15 @@ impl ComputeNode { Client::connect(zenith_admin_connstr.as_str(), NoTls) .context("broken cloud_admin credential: tried connecting with cloud_admin but could not authenticate, and zenith_admin does not work either")?; // Disable forwarding so that users don't get a cloud_admin role - client.simple_query("SET neon.forward_ddl = false")?; - client.simple_query("CREATE USER cloud_admin WITH SUPERUSER")?; - client.simple_query("GRANT zenith_admin TO cloud_admin")?; + + let mut func = || { + client.simple_query("SET neon.forward_ddl = false")?; + client.simple_query("CREATE USER cloud_admin WITH SUPERUSER")?; + client.simple_query("GRANT zenith_admin TO cloud_admin")?; + Ok::<_, anyhow::Error>(()) + }; + func().context("apply_config setup cloud_admin")?; + drop(client); // reconnect with connstring with expected name @@ -832,24 +838,29 @@ impl ComputeNode { }; // Disable DDL forwarding because control plane already knows about these roles/databases. - client.simple_query("SET neon.forward_ddl = false")?; + client + .simple_query("SET neon.forward_ddl = false") + .context("apply_config SET neon.forward_ddl = false")?; // Proceed with post-startup configuration. Note, that order of operations is important. let spec = &compute_state.pspec.as_ref().expect("spec must be set").spec; - create_neon_superuser(spec, &mut client)?; - cleanup_instance(&mut client)?; - handle_roles(spec, &mut client)?; - handle_databases(spec, &mut client)?; - handle_role_deletions(spec, connstr.as_str(), &mut client)?; + create_neon_superuser(spec, &mut client).context("apply_config create_neon_superuser")?; + cleanup_instance(&mut client).context("apply_config cleanup_instance")?; + handle_roles(spec, &mut client).context("apply_config handle_roles")?; + handle_databases(spec, &mut client).context("apply_config handle_databases")?; + handle_role_deletions(spec, connstr.as_str(), &mut client) + .context("apply_config handle_role_deletions")?; handle_grants( spec, &mut client, connstr.as_str(), self.has_feature(ComputeFeature::AnonExtension), - )?; - handle_extensions(spec, &mut client)?; - handle_extension_neon(&mut client)?; - create_availability_check_data(&mut client)?; + ) + .context("apply_config handle_grants")?; + handle_extensions(spec, &mut client).context("apply_config handle_extensions")?; + handle_extension_neon(&mut client).context("apply_config handle_extension_neon")?; + create_availability_check_data(&mut client) + .context("apply_config create_availability_check_data")?; // 'Close' connection drop(client); @@ -857,7 +868,7 @@ impl ComputeNode { // Run migrations separately to not hold up cold starts thread::spawn(move || { let mut client = Client::connect(connstr.as_str(), NoTls)?; - handle_migrations(&mut client) + handle_migrations(&mut client).context("apply_config handle_migrations") }); Ok(()) } diff --git a/compute_tools/src/spec.rs b/compute_tools/src/spec.rs index 5643634633..269177ee16 100644 --- a/compute_tools/src/spec.rs +++ b/compute_tools/src/spec.rs @@ -2,7 +2,7 @@ use std::fs::File; use std::path::Path; use std::str::FromStr; -use anyhow::{anyhow, bail, Result}; +use anyhow::{anyhow, bail, Context, Result}; use postgres::config::Config; use postgres::{Client, NoTls}; use reqwest::StatusCode; @@ -698,7 +698,8 @@ pub fn handle_grants( // it is important to run this after all grants if enable_anon_extension { - handle_extension_anon(spec, &db.owner, &mut db_client, false)?; + handle_extension_anon(spec, &db.owner, &mut db_client, false) + .context("handle_grants handle_extension_anon")?; } } @@ -813,28 +814,36 @@ $$;"#, // Add new migrations below. ]; - let mut query = "CREATE SCHEMA IF NOT EXISTS neon_migration"; - client.simple_query(query)?; + let mut func = || { + let query = "CREATE SCHEMA IF NOT EXISTS neon_migration"; + client.simple_query(query)?; - query = "CREATE TABLE IF NOT EXISTS neon_migration.migration_id (key INT NOT NULL PRIMARY KEY, id bigint NOT NULL DEFAULT 0)"; - client.simple_query(query)?; + let query = "CREATE TABLE IF NOT EXISTS neon_migration.migration_id (key INT NOT NULL PRIMARY KEY, id bigint NOT NULL DEFAULT 0)"; + client.simple_query(query)?; - query = "INSERT INTO neon_migration.migration_id VALUES (0, 0) ON CONFLICT DO NOTHING"; - client.simple_query(query)?; + let query = "INSERT INTO neon_migration.migration_id VALUES (0, 0) ON CONFLICT DO NOTHING"; + client.simple_query(query)?; - query = "ALTER SCHEMA neon_migration OWNER TO cloud_admin"; - client.simple_query(query)?; + let query = "ALTER SCHEMA neon_migration OWNER TO cloud_admin"; + client.simple_query(query)?; - query = "REVOKE ALL ON SCHEMA neon_migration FROM PUBLIC"; - client.simple_query(query)?; + let query = "REVOKE ALL ON SCHEMA neon_migration FROM PUBLIC"; + client.simple_query(query)?; + Ok::<_, anyhow::Error>(()) + }; + func().context("handle_migrations prepare")?; - query = "SELECT id FROM neon_migration.migration_id"; - let row = client.query_one(query, &[])?; + let query = "SELECT id FROM neon_migration.migration_id"; + let row = client + .query_one(query, &[]) + .context("handle_migrations get migration_id")?; let mut current_migration: usize = row.get::<&str, i64>("id") as usize; let starting_migration_id = current_migration; - query = "BEGIN"; - client.simple_query(query)?; + let query = "BEGIN"; + client + .simple_query(query) + .context("handle_migrations begin")?; while current_migration < migrations.len() { let migration = &migrations[current_migration]; @@ -842,7 +851,9 @@ $$;"#, info!("Skip migration id={}", current_migration); } else { info!("Running migration:\n{}\n", migration); - client.simple_query(migration)?; + client.simple_query(migration).with_context(|| { + format!("handle_migrations current_migration={}", current_migration) + })?; } current_migration += 1; } @@ -850,10 +861,14 @@ $$;"#, "UPDATE neon_migration.migration_id SET id={}", migrations.len() ); - client.simple_query(&setval)?; + client + .simple_query(&setval) + .context("handle_migrations update id")?; - query = "COMMIT"; - client.simple_query(query)?; + let query = "COMMIT"; + client + .simple_query(query) + .context("handle_migrations commit")?; info!( "Ran {} migrations", diff --git a/control_plane/src/bin/neon_local.rs b/control_plane/src/bin/neon_local.rs index 56495dd2da..7f8f6d21e0 100644 --- a/control_plane/src/bin/neon_local.rs +++ b/control_plane/src/bin/neon_local.rs @@ -1231,7 +1231,7 @@ async fn try_stop_all(env: &local_env::LocalEnv, immediate: bool) { match ComputeControlPlane::load(env.clone()) { Ok(cplane) => { for (_k, node) in cplane.endpoints { - if let Err(e) = node.stop(if immediate { "immediate" } else { "fast " }, false) { + if let Err(e) = node.stop(if immediate { "immediate" } else { "fast" }, false) { eprintln!("postgres stop failed: {e:#}"); } } @@ -1417,6 +1417,7 @@ fn cli() -> Command { .subcommand( Command::new("timeline") .about("Manage timelines") + .arg_required_else_help(true) .subcommand(Command::new("list") .about("List all timelines, available to this pageserver") .arg(tenant_id_arg.clone())) diff --git a/control_plane/src/local_env.rs b/control_plane/src/local_env.rs index bd3dbef453..38b7fffd09 100644 --- a/control_plane/src/local_env.rs +++ b/control_plane/src/local_env.rs @@ -156,6 +156,7 @@ pub struct SafekeeperConf { pub remote_storage: Option, pub backup_threads: Option, pub auth_enabled: bool, + pub listen_addr: Option, } impl Default for SafekeeperConf { @@ -169,6 +170,7 @@ impl Default for SafekeeperConf { remote_storage: None, backup_threads: None, auth_enabled: false, + listen_addr: None, } } } diff --git a/control_plane/src/safekeeper.rs b/control_plane/src/safekeeper.rs index 6ac71dfe51..d62a2e80b5 100644 --- a/control_plane/src/safekeeper.rs +++ b/control_plane/src/safekeeper.rs @@ -70,24 +70,31 @@ pub struct SafekeeperNode { pub pg_connection_config: PgConnectionConfig, pub env: LocalEnv, pub http_client: reqwest::Client, + pub listen_addr: String, pub http_base_url: String, } impl SafekeeperNode { pub fn from_env(env: &LocalEnv, conf: &SafekeeperConf) -> SafekeeperNode { + let listen_addr = if let Some(ref listen_addr) = conf.listen_addr { + listen_addr.clone() + } else { + "127.0.0.1".to_string() + }; SafekeeperNode { id: conf.id, conf: conf.clone(), - pg_connection_config: Self::safekeeper_connection_config(conf.pg_port), + pg_connection_config: Self::safekeeper_connection_config(&listen_addr, conf.pg_port), env: env.clone(), http_client: reqwest::Client::new(), - http_base_url: format!("http://127.0.0.1:{}/v1", conf.http_port), + http_base_url: format!("http://{}:{}/v1", listen_addr, conf.http_port), + listen_addr, } } /// Construct libpq connection string for connecting to this safekeeper. - fn safekeeper_connection_config(port: u16) -> PgConnectionConfig { - PgConnectionConfig::new_host_port(url::Host::parse("127.0.0.1").unwrap(), port) + fn safekeeper_connection_config(addr: &str, port: u16) -> PgConnectionConfig { + PgConnectionConfig::new_host_port(url::Host::parse(addr).unwrap(), port) } pub fn datadir_path_by_id(env: &LocalEnv, sk_id: NodeId) -> PathBuf { @@ -111,8 +118,8 @@ impl SafekeeperNode { ); io::stdout().flush().unwrap(); - let listen_pg = format!("127.0.0.1:{}", self.conf.pg_port); - let listen_http = format!("127.0.0.1:{}", self.conf.http_port); + let listen_pg = format!("{}:{}", self.listen_addr, self.conf.pg_port); + let listen_http = format!("{}:{}", self.listen_addr, self.conf.http_port); let id = self.id; let datadir = self.datadir_path(); @@ -139,7 +146,7 @@ impl SafekeeperNode { availability_zone, ]; if let Some(pg_tenant_only_port) = self.conf.pg_tenant_only_port { - let listen_pg_tenant_only = format!("127.0.0.1:{}", pg_tenant_only_port); + let listen_pg_tenant_only = format!("{}:{}", self.listen_addr, pg_tenant_only_port); args.extend(["--listen-pg-tenant-only".to_owned(), listen_pg_tenant_only]); } if !self.conf.sync { diff --git a/libs/metrics/src/hll.rs b/libs/metrics/src/hll.rs index dfb4461ce9..f53511ab5c 100644 --- a/libs/metrics/src/hll.rs +++ b/libs/metrics/src/hll.rs @@ -7,14 +7,19 @@ //! use significantly less memory than this, but can only approximate the cardinality. use std::{ - collections::HashMap, - hash::{BuildHasher, BuildHasherDefault, Hash, Hasher}, - sync::{atomic::AtomicU8, Arc, RwLock}, + hash::{BuildHasher, BuildHasherDefault, Hash}, + sync::atomic::AtomicU8, }; -use prometheus::{ - core::{self, Describer}, - proto, Opts, +use measured::{ + label::{LabelGroupVisitor, LabelName, LabelValue, LabelVisitor}, + metric::{ + group::{Encoding, MetricValue}, + name::MetricNameEncoder, + Metric, MetricType, MetricVec, + }, + text::TextEncoder, + LabelGroup, }; use twox_hash::xxh3; @@ -93,203 +98,25 @@ macro_rules! register_hll { /// ``` /// /// See for estimates on alpha -#[derive(Clone)] -pub struct HyperLogLogVec { - core: Arc>, +pub type HyperLogLogVec = MetricVec, L>; +pub type HyperLogLog = Metric>; + +pub struct HyperLogLogState { + shards: [AtomicU8; N], } - -struct HyperLogLogVecCore { - pub children: RwLock, BuildHasherDefault>>, - pub desc: core::Desc, - pub opts: Opts, -} - -impl core::Collector for HyperLogLogVec { - fn desc(&self) -> Vec<&core::Desc> { - vec![&self.core.desc] - } - - fn collect(&self) -> Vec { - let mut m = proto::MetricFamily::default(); - m.set_name(self.core.desc.fq_name.clone()); - m.set_help(self.core.desc.help.clone()); - m.set_field_type(proto::MetricType::GAUGE); - - let mut metrics = Vec::new(); - for child in self.core.children.read().unwrap().values() { - child.core.collect_into(&mut metrics); - } - m.set_metric(metrics); - - vec![m] +impl Default for HyperLogLogState { + fn default() -> Self { + #[allow(clippy::declare_interior_mutable_const)] + const ZERO: AtomicU8 = AtomicU8::new(0); + Self { shards: [ZERO; N] } } } -impl HyperLogLogVec { - /// Create a new [`HyperLogLogVec`] based on the provided - /// [`Opts`] and partitioned by the given label names. At least one label name must be - /// provided. - pub fn new(opts: Opts, label_names: &[&str]) -> prometheus::Result { - assert!(N.is_power_of_two()); - let variable_names = label_names.iter().map(|s| (*s).to_owned()).collect(); - let opts = opts.variable_labels(variable_names); - - let desc = opts.describe()?; - let v = HyperLogLogVecCore { - children: RwLock::new(HashMap::default()), - desc, - opts, - }; - - Ok(Self { core: Arc::new(v) }) - } - - /// `get_metric_with_label_values` returns the [`HyperLogLog

`] for the given slice - /// of label values (same order as the VariableLabels in Desc). If that combination of - /// label values is accessed for the first time, a new [`HyperLogLog

`] is created. - /// - /// An error is returned if the number of label values is not the same as the - /// number of VariableLabels in Desc. - pub fn get_metric_with_label_values( - &self, - vals: &[&str], - ) -> prometheus::Result> { - self.core.get_metric_with_label_values(vals) - } - - /// `with_label_values` works as `get_metric_with_label_values`, but panics if an error - /// occurs. - pub fn with_label_values(&self, vals: &[&str]) -> HyperLogLog { - self.get_metric_with_label_values(vals).unwrap() - } +impl MetricType for HyperLogLogState { + type Metadata = (); } -impl HyperLogLogVecCore { - pub fn get_metric_with_label_values( - &self, - vals: &[&str], - ) -> prometheus::Result> { - let h = self.hash_label_values(vals)?; - - if let Some(metric) = self.children.read().unwrap().get(&h).cloned() { - return Ok(metric); - } - - self.get_or_create_metric(h, vals) - } - - pub(crate) fn hash_label_values(&self, vals: &[&str]) -> prometheus::Result { - if vals.len() != self.desc.variable_labels.len() { - return Err(prometheus::Error::InconsistentCardinality { - expect: self.desc.variable_labels.len(), - got: vals.len(), - }); - } - - let mut h = xxh3::Hash64::default(); - for val in vals { - h.write(val.as_bytes()); - } - - Ok(h.finish()) - } - - fn get_or_create_metric( - &self, - hash: u64, - label_values: &[&str], - ) -> prometheus::Result> { - let mut children = self.children.write().unwrap(); - // Check exist first. - if let Some(metric) = children.get(&hash).cloned() { - return Ok(metric); - } - - let metric = HyperLogLog::with_opts_and_label_values(&self.opts, label_values)?; - children.insert(hash, metric.clone()); - Ok(metric) - } -} - -/// HLL is a probabilistic cardinality measure. -/// -/// How to use this time-series for a metric name `my_metrics_total_hll`: -/// -/// ```promql -/// # harmonic mean -/// 1 / ( -/// sum ( -/// 2 ^ -( -/// # HLL merge operation -/// max (my_metrics_total_hll{}) by (hll_shard, other_labels...) -/// ) -/// ) without (hll_shard) -/// ) -/// * alpha -/// * shards_count -/// * shards_count -/// ``` -/// -/// If you want an estimate over time, you can use the following query: -/// -/// ```promql -/// # harmonic mean -/// 1 / ( -/// sum ( -/// 2 ^ -( -/// # HLL merge operation -/// max ( -/// max_over_time(my_metrics_total_hll{}[$__rate_interval]) -/// ) by (hll_shard, other_labels...) -/// ) -/// ) without (hll_shard) -/// ) -/// * alpha -/// * shards_count -/// * shards_count -/// ``` -/// -/// In the case of low cardinality, you might want to use the linear counting approximation: -/// -/// ```promql -/// # LinearCounting(m, V) = m log (m / V) -/// shards_count * ln(shards_count / -/// # calculate V = how many shards contain a 0 -/// count(max (proxy_connecting_endpoints{}) by (hll_shard, protocol) == 0) without (hll_shard) -/// ) -/// ``` -/// -/// See for estimates on alpha -#[derive(Clone)] -pub struct HyperLogLog { - core: Arc>, -} - -impl HyperLogLog { - /// Create a [`HyperLogLog`] with the `name` and `help` arguments. - pub fn new, S2: Into>(name: S1, help: S2) -> prometheus::Result { - assert!(N.is_power_of_two()); - let opts = Opts::new(name, help); - Self::with_opts(opts) - } - - /// Create a [`HyperLogLog`] with the `opts` options. - pub fn with_opts(opts: Opts) -> prometheus::Result { - Self::with_opts_and_label_values(&opts, &[]) - } - - fn with_opts_and_label_values(opts: &Opts, label_values: &[&str]) -> prometheus::Result { - let desc = opts.describe()?; - let labels = make_label_pairs(&desc, label_values)?; - - let v = HyperLogLogCore { - shards: [0; N].map(AtomicU8::new), - desc, - labels, - }; - Ok(Self { core: Arc::new(v) }) - } - +impl HyperLogLogState { pub fn measure(&self, item: &impl Hash) { // changing the hasher will break compatibility with previous measurements. self.record(BuildHasherDefault::::default().hash_one(item)); @@ -299,42 +126,11 @@ impl HyperLogLog { let p = N.ilog2() as u8; let j = hash & (N as u64 - 1); let rho = (hash >> p).leading_zeros() as u8 + 1 - p; - self.core.shards[j as usize].fetch_max(rho, std::sync::atomic::Ordering::Relaxed); - } -} - -struct HyperLogLogCore { - shards: [AtomicU8; N], - desc: core::Desc, - labels: Vec, -} - -impl core::Collector for HyperLogLog { - fn desc(&self) -> Vec<&core::Desc> { - vec![&self.core.desc] + self.shards[j as usize].fetch_max(rho, std::sync::atomic::Ordering::Relaxed); } - fn collect(&self) -> Vec { - let mut m = proto::MetricFamily::default(); - m.set_name(self.core.desc.fq_name.clone()); - m.set_help(self.core.desc.help.clone()); - m.set_field_type(proto::MetricType::GAUGE); - - let mut metrics = Vec::new(); - self.core.collect_into(&mut metrics); - m.set_metric(metrics); - - vec![m] - } -} - -impl HyperLogLogCore { - fn collect_into(&self, metrics: &mut Vec) { - self.shards.iter().enumerate().for_each(|(i, x)| { - let mut shard_label = proto::LabelPair::default(); - shard_label.set_name("hll_shard".to_owned()); - shard_label.set_value(format!("{i}")); - + fn take_sample(&self) -> [u8; N] { + self.shards.each_ref().map(|x| { // We reset the counter to 0 so we can perform a cardinality measure over any time slice in prometheus. // This seems like it would be a race condition, @@ -344,85 +140,90 @@ impl HyperLogLogCore { // TODO: maybe we shouldn't reset this on every collect, instead, only after a time window. // this would mean that a dev port-forwarding the metrics url won't break the sampling. - let v = x.swap(0, std::sync::atomic::Ordering::Relaxed); - - let mut m = proto::Metric::default(); - let mut c = proto::Gauge::default(); - c.set_value(v as f64); - m.set_gauge(c); - - let mut labels = Vec::with_capacity(self.labels.len() + 1); - labels.extend_from_slice(&self.labels); - labels.push(shard_label); - - m.set_label(labels); - metrics.push(m); + x.swap(0, std::sync::atomic::Ordering::Relaxed) }) } } - -fn make_label_pairs( - desc: &core::Desc, - label_values: &[&str], -) -> prometheus::Result> { - if desc.variable_labels.len() != label_values.len() { - return Err(prometheus::Error::InconsistentCardinality { - expect: desc.variable_labels.len(), - got: label_values.len(), - }); +impl measured::metric::MetricEncoding> + for HyperLogLogState +{ + fn write_type( + name: impl MetricNameEncoder, + enc: &mut TextEncoder, + ) -> Result<(), std::io::Error> { + enc.write_type(&name, measured::text::MetricType::Gauge) } + fn collect_into( + &self, + _: &(), + labels: impl LabelGroup, + name: impl MetricNameEncoder, + enc: &mut TextEncoder, + ) -> Result<(), std::io::Error> { + struct I64(i64); + impl LabelValue for I64 { + fn visit(&self, v: V) -> V::Output { + v.write_int(self.0) + } + } - let total_len = desc.variable_labels.len() + desc.const_label_pairs.len(); - if total_len == 0 { - return Ok(vec![]); - } + struct HllShardLabel { + hll_shard: i64, + } - if desc.variable_labels.is_empty() { - return Ok(desc.const_label_pairs.clone()); - } + impl LabelGroup for HllShardLabel { + fn visit_values(&self, v: &mut impl LabelGroupVisitor) { + const LE: &LabelName = LabelName::from_str("hll_shard"); + v.write_value(LE, &I64(self.hll_shard)); + } + } - let mut label_pairs = Vec::with_capacity(total_len); - for (i, n) in desc.variable_labels.iter().enumerate() { - let mut label_pair = proto::LabelPair::default(); - label_pair.set_name(n.clone()); - label_pair.set_value(label_values[i].to_owned()); - label_pairs.push(label_pair); + self.take_sample() + .into_iter() + .enumerate() + .try_for_each(|(hll_shard, val)| { + enc.write_metric_value( + name.by_ref(), + labels.by_ref().compose_with(HllShardLabel { + hll_shard: hll_shard as i64, + }), + MetricValue::Int(val as i64), + ) + }) } - - for label_pair in &desc.const_label_pairs { - label_pairs.push(label_pair.clone()); - } - label_pairs.sort(); - Ok(label_pairs) } #[cfg(test)] mod tests { use std::collections::HashSet; - use prometheus::{proto, Opts}; + use measured::{label::StaticLabelSet, FixedCardinalityLabel}; use rand::{rngs::StdRng, Rng, SeedableRng}; use rand_distr::{Distribution, Zipf}; use crate::HyperLogLogVec; - fn collect(hll: &HyperLogLogVec<32>) -> Vec { - let mut metrics = vec![]; - hll.core - .children - .read() - .unwrap() - .values() - .for_each(|c| c.core.collect_into(&mut metrics)); - metrics + #[derive(FixedCardinalityLabel, Clone, Copy)] + #[label(singleton = "x")] + enum Label { + A, + B, } - fn get_cardinality(metrics: &[proto::Metric], filter: impl Fn(&proto::Metric) -> bool) -> f64 { + + fn collect(hll: &HyperLogLogVec, 32>) -> ([u8; 32], [u8; 32]) { + // cannot go through the `hll.collect_family_into` interface yet... + // need to see if I can fix the conflicting impls problem in measured. + ( + hll.get_metric(hll.with_labels(Label::A)).take_sample(), + hll.get_metric(hll.with_labels(Label::B)).take_sample(), + ) + } + + fn get_cardinality(samples: &[[u8; 32]]) -> f64 { let mut buckets = [0.0; 32]; - for metric in metrics.chunks_exact(32) { - if filter(&metric[0]) { - for (i, m) in metric.iter().enumerate() { - buckets[i] = f64::max(buckets[i], m.get_gauge().get_value()); - } + for &sample in samples { + for (i, m) in sample.into_iter().enumerate() { + buckets[i] = f64::max(buckets[i], m as f64); } } @@ -437,7 +238,7 @@ mod tests { } fn test_cardinality(n: usize, dist: impl Distribution) -> ([usize; 3], [f64; 3]) { - let hll = HyperLogLogVec::<32>::new(Opts::new("foo", "bar"), &["x"]).unwrap(); + let hll = HyperLogLogVec::, 32>::new(); let mut iter = StdRng::seed_from_u64(0x2024_0112).sample_iter(dist); let mut set_a = HashSet::new(); @@ -445,18 +246,20 @@ mod tests { for x in iter.by_ref().take(n) { set_a.insert(x.to_bits()); - hll.with_label_values(&["a"]).measure(&x.to_bits()); + hll.get_metric(hll.with_labels(Label::A)) + .measure(&x.to_bits()); } for x in iter.by_ref().take(n) { set_b.insert(x.to_bits()); - hll.with_label_values(&["b"]).measure(&x.to_bits()); + hll.get_metric(hll.with_labels(Label::B)) + .measure(&x.to_bits()); } let merge = &set_a | &set_b; - let metrics = collect(&hll); - let len = get_cardinality(&metrics, |_| true); - let len_a = get_cardinality(&metrics, |l| l.get_label()[0].get_value() == "a"); - let len_b = get_cardinality(&metrics, |l| l.get_label()[0].get_value() == "b"); + let (a, b) = collect(&hll); + let len = get_cardinality(&[a, b]); + let len_a = get_cardinality(&[a]); + let len_b = get_cardinality(&[b]); ([merge.len(), set_a.len(), set_b.len()], [len, len_a, len_b]) } diff --git a/libs/metrics/src/lib.rs b/libs/metrics/src/lib.rs index 6cff28c0ca..2cf3cdeaa7 100644 --- a/libs/metrics/src/lib.rs +++ b/libs/metrics/src/lib.rs @@ -5,7 +5,7 @@ #![deny(clippy::undocumented_unsafe_blocks)] use measured::{ - label::{LabelGroupVisitor, LabelName, NoLabels}, + label::{LabelGroupSet, LabelGroupVisitor, LabelName, NoLabels}, metric::{ counter::CounterState, gauge::GaugeState, @@ -40,7 +40,7 @@ pub mod launch_timestamp; mod wrappers; pub use wrappers::{CountedReader, CountedWriter}; mod hll; -pub use hll::{HyperLogLog, HyperLogLogVec}; +pub use hll::{HyperLogLog, HyperLogLogState, HyperLogLogVec}; #[cfg(target_os = "linux")] pub mod more_process_metrics; @@ -421,3 +421,171 @@ pub type IntCounterPair = GenericCounterPair; /// A guard for [`IntCounterPair`] that will decrement the gauge on drop pub type IntCounterPairGuard = GenericCounterPairGuard; + +pub trait CounterPairAssoc { + const INC_NAME: &'static MetricName; + const DEC_NAME: &'static MetricName; + + const INC_HELP: &'static str; + const DEC_HELP: &'static str; + + type LabelGroupSet: LabelGroupSet; +} + +pub struct CounterPairVec { + vec: measured::metric::MetricVec, +} + +impl Default for CounterPairVec +where + A::LabelGroupSet: Default, +{ + fn default() -> Self { + Self { + vec: Default::default(), + } + } +} + +impl CounterPairVec { + pub fn guard( + &self, + labels: ::Group<'_>, + ) -> MeasuredCounterPairGuard<'_, A> { + let id = self.vec.with_labels(labels); + self.vec.get_metric(id).inc.inc(); + MeasuredCounterPairGuard { vec: &self.vec, id } + } + pub fn inc(&self, labels: ::Group<'_>) { + let id = self.vec.with_labels(labels); + self.vec.get_metric(id).inc.inc(); + } + pub fn dec(&self, labels: ::Group<'_>) { + let id = self.vec.with_labels(labels); + self.vec.get_metric(id).dec.inc(); + } + pub fn remove_metric( + &self, + labels: ::Group<'_>, + ) -> Option { + let id = self.vec.with_labels(labels); + self.vec.remove_metric(id) + } +} + +impl ::measured::metric::group::MetricGroup for CounterPairVec +where + T: ::measured::metric::group::Encoding, + A: CounterPairAssoc, + ::measured::metric::counter::CounterState: ::measured::metric::MetricEncoding, +{ + fn collect_group_into(&self, enc: &mut T) -> Result<(), T::Err> { + // write decrement first to avoid a race condition where inc - dec < 0 + T::write_help(enc, A::DEC_NAME, A::DEC_HELP)?; + self.vec + .collect_family_into(A::DEC_NAME, &mut Dec(&mut *enc))?; + + T::write_help(enc, A::INC_NAME, A::INC_HELP)?; + self.vec + .collect_family_into(A::INC_NAME, &mut Inc(&mut *enc))?; + + Ok(()) + } +} + +#[derive(MetricGroup, Default)] +pub struct MeasuredCounterPairState { + pub inc: CounterState, + pub dec: CounterState, +} + +impl measured::metric::MetricType for MeasuredCounterPairState { + type Metadata = (); +} + +pub struct MeasuredCounterPairGuard<'a, A: CounterPairAssoc> { + vec: &'a measured::metric::MetricVec, + id: measured::metric::LabelId, +} + +impl Drop for MeasuredCounterPairGuard<'_, A> { + fn drop(&mut self) { + self.vec.get_metric(self.id).dec.inc(); + } +} + +/// [`MetricEncoding`] for [`MeasuredCounterPairState`] that only writes the inc counter to the inner encoder. +struct Inc(T); +/// [`MetricEncoding`] for [`MeasuredCounterPairState`] that only writes the dec counter to the inner encoder. +struct Dec(T); + +impl Encoding for Inc { + type Err = T::Err; + + fn write_help(&mut self, name: impl MetricNameEncoder, help: &str) -> Result<(), Self::Err> { + self.0.write_help(name, help) + } + + fn write_metric_value( + &mut self, + name: impl MetricNameEncoder, + labels: impl LabelGroup, + value: MetricValue, + ) -> Result<(), Self::Err> { + self.0.write_metric_value(name, labels, value) + } +} + +impl MetricEncoding> for MeasuredCounterPairState +where + CounterState: MetricEncoding, +{ + fn write_type(name: impl MetricNameEncoder, enc: &mut Inc) -> Result<(), T::Err> { + CounterState::write_type(name, &mut enc.0) + } + fn collect_into( + &self, + metadata: &(), + labels: impl LabelGroup, + name: impl MetricNameEncoder, + enc: &mut Inc, + ) -> Result<(), T::Err> { + self.inc.collect_into(metadata, labels, name, &mut enc.0) + } +} + +impl Encoding for Dec { + type Err = T::Err; + + fn write_help(&mut self, name: impl MetricNameEncoder, help: &str) -> Result<(), Self::Err> { + self.0.write_help(name, help) + } + + fn write_metric_value( + &mut self, + name: impl MetricNameEncoder, + labels: impl LabelGroup, + value: MetricValue, + ) -> Result<(), Self::Err> { + self.0.write_metric_value(name, labels, value) + } +} + +/// Write the dec counter to the encoder +impl MetricEncoding> for MeasuredCounterPairState +where + CounterState: MetricEncoding, +{ + fn write_type(name: impl MetricNameEncoder, enc: &mut Dec) -> Result<(), T::Err> { + CounterState::write_type(name, &mut enc.0) + } + fn collect_into( + &self, + metadata: &(), + labels: impl LabelGroup, + name: impl MetricNameEncoder, + enc: &mut Dec, + ) -> Result<(), T::Err> { + self.dec.collect_into(metadata, labels, name, &mut enc.0) + } +} diff --git a/libs/pageserver_api/src/models.rs b/libs/pageserver_api/src/models.rs index b4909f247f..f441d1ff1a 100644 --- a/libs/pageserver_api/src/models.rs +++ b/libs/pageserver_api/src/models.rs @@ -747,10 +747,18 @@ pub struct TimelineGcRequest { pub gc_horizon: Option, } +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct WalRedoManagerProcessStatus { + pub pid: u32, + /// The strum-generated `into::<&'static str>()` for `pageserver::walredo::ProcessKind`. + /// `ProcessKind` are a transitory thing, so, they have no enum representation in `pageserver_api`. + pub kind: Cow<'static, str>, +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct WalRedoManagerStatus { pub last_redo_at: Option>, - pub pid: Option, + pub process: Option, } /// The progress of a secondary tenant is mostly useful when doing a long running download: e.g. initiating diff --git a/libs/pageserver_api/src/shard.rs b/libs/pageserver_api/src/shard.rs index a2a9165184..c293ad705b 100644 --- a/libs/pageserver_api/src/shard.rs +++ b/libs/pageserver_api/src/shard.rs @@ -8,12 +8,89 @@ use hex::FromHex; use serde::{Deserialize, Serialize}; use utils::id::TenantId; +/// See docs/rfcs/031-sharding-static.md for an overview of sharding. +/// +/// This module contains a variety of types used to represent the concept of sharding +/// a Neon tenant across multiple physical shards. Since there are quite a few of these, +/// we provide an summary here. +/// +/// Types used to describe shards: +/// - [`ShardCount`] describes how many shards make up a tenant, plus the magic `unsharded` value +/// which identifies a tenant which is not shard-aware. This means its storage paths do not include +/// a shard suffix. +/// - [`ShardNumber`] is simply the zero-based index of a shard within a tenant. +/// - [`ShardIndex`] is the 2-tuple of `ShardCount` and `ShardNumber`, it's just like a `TenantShardId` +/// without the tenant ID. This is useful for things that are implicitly scoped to a particular +/// tenant, such as layer files. +/// - [`ShardIdentity`]` is the full description of a particular shard's parameters, in sufficient +/// detail to convert a [`Key`] to a [`ShardNumber`] when deciding where to write/read. +/// - The [`ShardSlug`] is a terse formatter for ShardCount and ShardNumber, written as +/// four hex digits. An unsharded tenant is `0000`. +/// - [`TenantShardId`] is the unique ID of a particular shard within a particular tenant +/// +/// Types used to describe the parameters for data distribution in a sharded tenant: +/// - [`ShardStripeSize`] controls how long contiguous runs of [`Key`]s (stripes) are when distributed across +/// multiple shards. Its value is given in 8kiB pages. +/// - [`ShardLayout`] describes the data distribution scheme, and at time of writing is +/// always zero: this is provided for future upgrades that might introduce different +/// data distribution schemes. +/// +/// Examples: +/// - A legacy unsharded tenant has one shard with ShardCount(0), ShardNumber(0), and its slug is 0000 +/// - A single sharded tenant has one shard with ShardCount(1), ShardNumber(0), and its slug is 0001 +/// - In a tenant with 4 shards, each shard has ShardCount(N), ShardNumber(i) where i in 0..N-1 (inclusive), +/// and their slugs are 0004, 0104, 0204, and 0304. + #[derive(Ord, PartialOrd, Eq, PartialEq, Clone, Copy, Serialize, Deserialize, Debug, Hash)] pub struct ShardNumber(pub u8); #[derive(Ord, PartialOrd, Eq, PartialEq, Clone, Copy, Serialize, Deserialize, Debug, Hash)] pub struct ShardCount(u8); +/// Combination of ShardNumber and ShardCount. For use within the context of a particular tenant, +/// when we need to know which shard we're dealing with, but do not need to know the full +/// ShardIdentity (because we won't be doing any page->shard mapping), and do not need to know +/// the fully qualified TenantShardId. +#[derive(Eq, PartialEq, PartialOrd, Ord, Clone, Copy, Hash)] +pub struct ShardIndex { + pub shard_number: ShardNumber, + pub shard_count: ShardCount, +} + +/// The ShardIdentity contains enough information to map a [`Key`] to a [`ShardNumber`], +/// and to check whether that [`ShardNumber`] is the same as the current shard. +#[derive(Clone, Copy, Serialize, Deserialize, Eq, PartialEq, Debug)] +pub struct ShardIdentity { + pub number: ShardNumber, + pub count: ShardCount, + pub stripe_size: ShardStripeSize, + layout: ShardLayout, +} + +/// Formatting helper, for generating the `shard_id` label in traces. +struct ShardSlug<'a>(&'a TenantShardId); + +/// TenantShardId globally identifies a particular shard in a particular tenant. +/// +/// These are written as `-`, for example: +/// # The second shard in a two-shard tenant +/// 072f1291a5310026820b2fe4b2968934-0102 +/// +/// If the `ShardCount` is _unsharded_, the `TenantShardId` is written without +/// a shard suffix and is equivalent to the encoding of a `TenantId`: this enables +/// an unsharded [`TenantShardId`] to be used interchangably with a [`TenantId`]. +/// +/// The human-readable encoding of an unsharded TenantShardId, such as used in API URLs, +/// is both forward and backward compatible with TenantId: a legacy TenantId can be +/// decoded as a TenantShardId, and when re-encoded it will be parseable +/// as a TenantId. +#[derive(Eq, PartialEq, PartialOrd, Ord, Clone, Copy, Hash)] +pub struct TenantShardId { + pub tenant_id: TenantId, + pub shard_number: ShardNumber, + pub shard_count: ShardCount, +} + impl ShardCount { pub const MAX: Self = Self(u8::MAX); @@ -38,6 +115,7 @@ impl ShardCount { self.0 } + /// pub fn is_unsharded(&self) -> bool { self.0 == 0 } @@ -53,33 +131,6 @@ impl ShardNumber { pub const MAX: Self = Self(u8::MAX); } -/// TenantShardId identify the units of work for the Pageserver. -/// -/// These are written as `-`, for example: -/// -/// # The second shard in a two-shard tenant -/// 072f1291a5310026820b2fe4b2968934-0102 -/// -/// Historically, tenants could not have multiple shards, and were identified -/// by TenantId. To support this, TenantShardId has a special legacy -/// mode where `shard_count` is equal to zero: this represents a single-sharded -/// tenant which should be written as a TenantId with no suffix. -/// -/// The human-readable encoding of TenantShardId, such as used in API URLs, -/// is both forward and backward compatible: a legacy TenantId can be -/// decoded as a TenantShardId, and when re-encoded it will be parseable -/// as a TenantId. -/// -/// Note that the binary encoding is _not_ backward compatible, because -/// at the time sharding is introduced, there are no existing binary structures -/// containing TenantId that we need to handle. -#[derive(Eq, PartialEq, PartialOrd, Ord, Clone, Copy, Hash)] -pub struct TenantShardId { - pub tenant_id: TenantId, - pub shard_number: ShardNumber, - pub shard_count: ShardCount, -} - impl TenantShardId { pub fn unsharded(tenant_id: TenantId) -> Self { Self { @@ -111,10 +162,13 @@ impl TenantShardId { } /// Convenience for code that has special behavior on the 0th shard. - pub fn is_zero(&self) -> bool { + pub fn is_shard_zero(&self) -> bool { self.shard_number == ShardNumber(0) } + /// The "unsharded" value is distinct from simply having a single shard: it represents + /// a tenant which is not shard-aware at all, and whose storage paths will not include + /// a shard suffix. pub fn is_unsharded(&self) -> bool { self.shard_number == ShardNumber(0) && self.shard_count.is_unsharded() } @@ -150,9 +204,6 @@ impl TenantShardId { } } -/// Formatting helper -struct ShardSlug<'a>(&'a TenantShardId); - impl<'a> std::fmt::Display for ShardSlug<'a> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!( @@ -222,16 +273,6 @@ impl From<[u8; 18]> for TenantShardId { } } -/// For use within the context of a particular tenant, when we need to know which -/// shard we're dealing with, but do not need to know the full ShardIdentity (because -/// we won't be doing any page->shard mapping), and do not need to know the fully qualified -/// TenantShardId. -#[derive(Eq, PartialEq, PartialOrd, Ord, Clone, Copy, Hash)] -pub struct ShardIndex { - pub shard_number: ShardNumber, - pub shard_count: ShardCount, -} - impl ShardIndex { pub fn new(number: ShardNumber, count: ShardCount) -> Self { Self { @@ -246,6 +287,9 @@ impl ShardIndex { } } + /// The "unsharded" value is distinct from simply having a single shard: it represents + /// a tenant which is not shard-aware at all, and whose storage paths will not include + /// a shard suffix. pub fn is_unsharded(&self) -> bool { self.shard_number == ShardNumber(0) && self.shard_count == ShardCount(0) } @@ -313,6 +357,8 @@ impl Serialize for TenantShardId { if serializer.is_human_readable() { serializer.collect_str(self) } else { + // Note: while human encoding of [`TenantShardId`] is backward and forward + // compatible, this binary encoding is not. let mut packed: [u8; 18] = [0; 18]; packed[0..16].clone_from_slice(&self.tenant_id.as_arr()); packed[16] = self.shard_number.0; @@ -390,16 +436,6 @@ const LAYOUT_BROKEN: ShardLayout = ShardLayout(255); /// Default stripe size in pages: 256MiB divided by 8kiB page size. const DEFAULT_STRIPE_SIZE: ShardStripeSize = ShardStripeSize(256 * 1024 / 8); -/// The ShardIdentity contains the information needed for one member of map -/// to resolve a key to a shard, and then check whether that shard is ==self. -#[derive(Clone, Copy, Serialize, Deserialize, Eq, PartialEq, Debug)] -pub struct ShardIdentity { - pub number: ShardNumber, - pub count: ShardCount, - pub stripe_size: ShardStripeSize, - layout: ShardLayout, -} - #[derive(thiserror::Error, Debug, PartialEq, Eq)] pub enum ShardConfigError { #[error("Invalid shard count")] @@ -439,6 +475,9 @@ impl ShardIdentity { } } + /// The "unsharded" value is distinct from simply having a single shard: it represents + /// a tenant which is not shard-aware at all, and whose storage paths will not include + /// a shard suffix. pub fn is_unsharded(&self) -> bool { self.number == ShardNumber(0) && self.count == ShardCount(0) } @@ -487,6 +526,8 @@ impl ShardIdentity { } /// Return true if the key should be ingested by this shard + /// + /// Shards must ingest _at least_ keys which return true from this check. pub fn is_key_local(&self, key: &Key) -> bool { assert!(!self.is_broken()); if self.count < ShardCount(2) || (key_is_shard0(key) && self.number == ShardNumber(0)) { @@ -497,7 +538,9 @@ impl ShardIdentity { } /// Return true if the key should be discarded if found in this shard's - /// data store, e.g. during compaction after a split + /// data store, e.g. during compaction after a split. + /// + /// Shards _may_ drop keys which return false here, but are not obliged to. pub fn is_key_disposable(&self, key: &Key) -> bool { if key_is_shard0(key) { // Q: Why can't we dispose of shard0 content if we're not shard 0? @@ -523,7 +566,7 @@ impl ShardIdentity { /// Convenience for checking if this identity is the 0th shard in a tenant, /// for special cases on shard 0 such as ingesting relation sizes. - pub fn is_zero(&self) -> bool { + pub fn is_shard_zero(&self) -> bool { self.number == ShardNumber(0) } } diff --git a/libs/utils/src/lib.rs b/libs/utils/src/lib.rs index b09350d11e..2953f0aad4 100644 --- a/libs/utils/src/lib.rs +++ b/libs/utils/src/lib.rs @@ -92,6 +92,8 @@ pub mod zstd; pub mod env; +pub mod poison; + /// This is a shortcut to embed git sha into binaries and avoid copying the same build script to all packages /// /// we have several cases: diff --git a/libs/utils/src/poison.rs b/libs/utils/src/poison.rs new file mode 100644 index 0000000000..0bf5664f47 --- /dev/null +++ b/libs/utils/src/poison.rs @@ -0,0 +1,121 @@ +//! Protect a piece of state from reuse after it is left in an inconsistent state. +//! +//! # Example +//! +//! ``` +//! # tokio_test::block_on(async { +//! use utils::poison::Poison; +//! use std::time::Duration; +//! +//! struct State { +//! clean: bool, +//! } +//! let state = tokio::sync::Mutex::new(Poison::new("mystate", State { clean: true })); +//! +//! let mut mutex_guard = state.lock().await; +//! let mut poison_guard = mutex_guard.check_and_arm()?; +//! let state = poison_guard.data_mut(); +//! state.clean = false; +//! // If we get cancelled at this await point, subsequent check_and_arm() calls will fail. +//! tokio::time::sleep(Duration::from_secs(10)).await; +//! state.clean = true; +//! poison_guard.disarm(); +//! # Ok::<(), utils::poison::Error>(()) +//! # }); +//! ``` + +use tracing::warn; + +pub struct Poison { + what: &'static str, + state: State, + data: T, +} + +#[derive(Clone, Copy)] +enum State { + Clean, + Armed, + Poisoned { at: chrono::DateTime }, +} + +impl Poison { + /// We log `what` `warning!` level if the [`Guard`] gets dropped without being [`Guard::disarm`]ed. + pub fn new(what: &'static str, data: T) -> Self { + Self { + what, + state: State::Clean, + data, + } + } + + /// Check for poisoning and return a [`Guard`] that provides access to the wrapped state. + pub fn check_and_arm(&mut self) -> Result, Error> { + match self.state { + State::Clean => { + self.state = State::Armed; + Ok(Guard(self)) + } + State::Armed => unreachable!("transient state"), + State::Poisoned { at } => Err(Error::Poisoned { + what: self.what, + at, + }), + } + } +} + +/// Use [`Self::data`] and [`Self::data_mut`] to access the wrapped state. +/// Once modifications are done, use [`Self::disarm`]. +/// If [`Guard`] gets dropped instead of calling [`Self::disarm`], the state is poisoned +/// and subsequent calls to [`Poison::check_and_arm`] will fail with an error. +pub struct Guard<'a, T>(&'a mut Poison); + +impl<'a, T> Guard<'a, T> { + pub fn data(&self) -> &T { + &self.0.data + } + pub fn data_mut(&mut self) -> &mut T { + &mut self.0.data + } + + pub fn disarm(self) { + match self.0.state { + State::Clean => unreachable!("we set it to Armed in check_and_arm()"), + State::Armed => { + self.0.state = State::Clean; + } + State::Poisoned { at } => { + unreachable!("we fail check_and_arm() if it's in that state: {at}") + } + } + } +} + +impl<'a, T> Drop for Guard<'a, T> { + fn drop(&mut self) { + match self.0.state { + State::Clean => { + // set by disarm() + } + State::Armed => { + // still armed => poison it + let at = chrono::Utc::now(); + self.0.state = State::Poisoned { at }; + warn!(at=?at, "poisoning {}", self.0.what); + } + State::Poisoned { at } => { + unreachable!("we fail check_and_arm() if it's in that state: {at}") + } + } + } +} + +#[derive(thiserror::Error, Debug)] +pub enum Error { + #[error("poisoned at {at}: {what}")] + Poisoned { + what: &'static str, + at: chrono::DateTime, + }, +} diff --git a/pageserver/benches/bench_walredo.rs b/pageserver/benches/bench_walredo.rs index ffe607be4b..5b871c5d5e 100644 --- a/pageserver/benches/bench_walredo.rs +++ b/pageserver/benches/bench_walredo.rs @@ -27,30 +27,50 @@ //! //! # Reference Numbers //! -//! 2024-04-04 on i3en.3xlarge +//! 2024-04-15 on i3en.3xlarge //! //! ```text -//! short/1 time: [25.925 µs 26.060 µs 26.209 µs] -//! short/2 time: [31.277 µs 31.483 µs 31.722 µs] -//! short/4 time: [45.496 µs 45.831 µs 46.182 µs] -//! short/8 time: [84.298 µs 84.920 µs 85.566 µs] -//! short/16 time: [185.04 µs 186.41 µs 187.88 µs] -//! short/32 time: [385.01 µs 386.77 µs 388.70 µs] -//! short/64 time: [770.24 µs 773.04 µs 776.04 µs] -//! short/128 time: [1.5017 ms 1.5064 ms 1.5113 ms] -//! medium/1 time: [106.65 µs 107.20 µs 107.85 µs] -//! medium/2 time: [153.28 µs 154.24 µs 155.56 µs] -//! medium/4 time: [325.67 µs 327.01 µs 328.71 µs] -//! medium/8 time: [646.82 µs 650.17 µs 653.91 µs] -//! medium/16 time: [1.2645 ms 1.2701 ms 1.2762 ms] -//! medium/32 time: [2.4409 ms 2.4550 ms 2.4692 ms] -//! medium/64 time: [4.6814 ms 4.7114 ms 4.7408 ms] -//! medium/128 time: [8.7790 ms 8.9037 ms 9.0282 ms] +//! async-short/1 time: [24.584 µs 24.737 µs 24.922 µs] +//! async-short/2 time: [33.479 µs 33.660 µs 33.888 µs] +//! async-short/4 time: [42.713 µs 43.046 µs 43.440 µs] +//! async-short/8 time: [71.814 µs 72.478 µs 73.240 µs] +//! async-short/16 time: [132.73 µs 134.45 µs 136.22 µs] +//! async-short/32 time: [258.31 µs 260.73 µs 263.27 µs] +//! async-short/64 time: [511.61 µs 514.44 µs 517.51 µs] +//! async-short/128 time: [992.64 µs 998.23 µs 1.0042 ms] +//! async-medium/1 time: [110.11 µs 110.50 µs 110.96 µs] +//! async-medium/2 time: [153.06 µs 153.85 µs 154.99 µs] +//! async-medium/4 time: [317.51 µs 319.92 µs 322.85 µs] +//! async-medium/8 time: [638.30 µs 644.68 µs 652.12 µs] +//! async-medium/16 time: [1.2651 ms 1.2773 ms 1.2914 ms] +//! async-medium/32 time: [2.5117 ms 2.5410 ms 2.5720 ms] +//! async-medium/64 time: [4.8088 ms 4.8555 ms 4.9047 ms] +//! async-medium/128 time: [8.8311 ms 8.9849 ms 9.1263 ms] +//! sync-short/1 time: [25.503 µs 25.626 µs 25.771 µs] +//! sync-short/2 time: [30.850 µs 31.013 µs 31.208 µs] +//! sync-short/4 time: [45.543 µs 45.856 µs 46.193 µs] +//! sync-short/8 time: [84.114 µs 84.639 µs 85.220 µs] +//! sync-short/16 time: [185.22 µs 186.15 µs 187.13 µs] +//! sync-short/32 time: [377.43 µs 378.87 µs 380.46 µs] +//! sync-short/64 time: [756.49 µs 759.04 µs 761.70 µs] +//! sync-short/128 time: [1.4825 ms 1.4874 ms 1.4923 ms] +//! sync-medium/1 time: [105.66 µs 106.01 µs 106.43 µs] +//! sync-medium/2 time: [153.10 µs 153.84 µs 154.72 µs] +//! sync-medium/4 time: [327.13 µs 329.44 µs 332.27 µs] +//! sync-medium/8 time: [654.26 µs 658.73 µs 663.63 µs] +//! sync-medium/16 time: [1.2682 ms 1.2748 ms 1.2816 ms] +//! sync-medium/32 time: [2.4456 ms 2.4595 ms 2.4731 ms] +//! sync-medium/64 time: [4.6523 ms 4.6890 ms 4.7256 ms] +//! sync-medium/128 time: [8.7215 ms 8.8323 ms 8.9344 ms] //! ``` use bytes::{Buf, Bytes}; use criterion::{BenchmarkId, Criterion}; -use pageserver::{config::PageServerConf, walrecord::NeonWalRecord, walredo::PostgresRedoManager}; +use pageserver::{ + config::PageServerConf, + walrecord::NeonWalRecord, + walredo::{PostgresRedoManager, ProcessKind}, +}; use pageserver_api::{key::Key, shard::TenantShardId}; use std::{ sync::Arc, @@ -60,33 +80,39 @@ use tokio::{sync::Barrier, task::JoinSet}; use utils::{id::TenantId, lsn::Lsn}; fn bench(c: &mut Criterion) { - { - let nclients = [1, 2, 4, 8, 16, 32, 64, 128]; - for nclients in nclients { - let mut group = c.benchmark_group("short"); - group.bench_with_input( - BenchmarkId::from_parameter(nclients), - &nclients, - |b, nclients| { - let redo_work = Arc::new(Request::short_input()); - b.iter_custom(|iters| bench_impl(Arc::clone(&redo_work), iters, *nclients)); - }, - ); + for process_kind in &[ProcessKind::Async, ProcessKind::Sync] { + { + let nclients = [1, 2, 4, 8, 16, 32, 64, 128]; + for nclients in nclients { + let mut group = c.benchmark_group(format!("{process_kind}-short")); + group.bench_with_input( + BenchmarkId::from_parameter(nclients), + &nclients, + |b, nclients| { + let redo_work = Arc::new(Request::short_input()); + b.iter_custom(|iters| { + bench_impl(*process_kind, Arc::clone(&redo_work), iters, *nclients) + }); + }, + ); + } } - } - { - let nclients = [1, 2, 4, 8, 16, 32, 64, 128]; - for nclients in nclients { - let mut group = c.benchmark_group("medium"); - group.bench_with_input( - BenchmarkId::from_parameter(nclients), - &nclients, - |b, nclients| { - let redo_work = Arc::new(Request::medium_input()); - b.iter_custom(|iters| bench_impl(Arc::clone(&redo_work), iters, *nclients)); - }, - ); + { + let nclients = [1, 2, 4, 8, 16, 32, 64, 128]; + for nclients in nclients { + let mut group = c.benchmark_group(format!("{process_kind}-medium")); + group.bench_with_input( + BenchmarkId::from_parameter(nclients), + &nclients, + |b, nclients| { + let redo_work = Arc::new(Request::medium_input()); + b.iter_custom(|iters| { + bench_impl(*process_kind, Arc::clone(&redo_work), iters, *nclients) + }); + }, + ); + } } } } @@ -94,10 +120,16 @@ criterion::criterion_group!(benches, bench); criterion::criterion_main!(benches); // Returns the sum of each client's wall-clock time spent executing their share of the n_redos. -fn bench_impl(redo_work: Arc, n_redos: u64, nclients: u64) -> Duration { +fn bench_impl( + process_kind: ProcessKind, + redo_work: Arc, + n_redos: u64, + nclients: u64, +) -> Duration { let repo_dir = camino_tempfile::tempdir_in(env!("CARGO_TARGET_TMPDIR")).unwrap(); - let conf = PageServerConf::dummy_conf(repo_dir.path().to_path_buf()); + let mut conf = PageServerConf::dummy_conf(repo_dir.path().to_path_buf()); + conf.walredo_process_kind = process_kind; let conf = Box::leak(Box::new(conf)); let tenant_shard_id = TenantShardId::unsharded(TenantId::generate()); @@ -113,25 +145,40 @@ fn bench_impl(redo_work: Arc, n_redos: u64, nclients: u64) -> Duration let manager = PostgresRedoManager::new(conf, tenant_shard_id); let manager = Arc::new(manager); + // divide the amount of work equally among the clients. + let nredos_per_client = n_redos / nclients; for _ in 0..nclients { rt.block_on(async { tasks.spawn(client( Arc::clone(&manager), Arc::clone(&start), Arc::clone(&redo_work), - // divide the amount of work equally among the clients - n_redos / nclients, + nredos_per_client, )) }); } - rt.block_on(async move { - let mut total_wallclock_time = std::time::Duration::from_millis(0); + let elapsed = rt.block_on(async move { + let mut total_wallclock_time = Duration::ZERO; while let Some(res) = tasks.join_next().await { total_wallclock_time += res.unwrap(); } total_wallclock_time - }) + }); + + // consistency check to ensure process kind setting worked + if nredos_per_client > 0 { + assert_eq!( + manager + .status() + .process + .map(|p| p.kind) + .expect("the benchmark work causes a walredo process to be spawned"), + std::borrow::Cow::Borrowed(process_kind.into()) + ); + } + + elapsed } async fn client( diff --git a/pageserver/src/bin/pageserver.rs b/pageserver/src/bin/pageserver.rs index 0903b206ff..41835f9843 100644 --- a/pageserver/src/bin/pageserver.rs +++ b/pageserver/src/bin/pageserver.rs @@ -285,6 +285,7 @@ fn start_pageserver( )) .unwrap(); pageserver::preinitialize_metrics(); + pageserver::metrics::wal_redo::set_process_kind_metric(conf.walredo_process_kind); // If any failpoints were set from FAILPOINTS environment variable, // print them to the log for debugging purposes diff --git a/pageserver/src/config.rs b/pageserver/src/config.rs index 1837da34ce..e10db2b853 100644 --- a/pageserver/src/config.rs +++ b/pageserver/src/config.rs @@ -97,6 +97,8 @@ pub mod defaults { pub const DEFAULT_EPHEMERAL_BYTES_PER_MEMORY_KB: usize = 0; + pub const DEFAULT_WALREDO_PROCESS_KIND: &str = "sync"; + /// /// Default built-in configuration file. /// @@ -140,6 +142,8 @@ pub mod defaults { #validate_vectored_get = '{DEFAULT_VALIDATE_VECTORED_GET}' +#walredo_process_kind = '{DEFAULT_WALREDO_PROCESS_KIND}' + [tenant_config] #checkpoint_distance = {DEFAULT_CHECKPOINT_DISTANCE} # in bytes #checkpoint_timeout = {DEFAULT_CHECKPOINT_TIMEOUT} @@ -290,6 +294,8 @@ pub struct PageServerConf { /// /// Setting this to zero disables limits on total ephemeral layer size. pub ephemeral_bytes_per_memory_kb: usize, + + pub walredo_process_kind: crate::walredo::ProcessKind, } /// We do not want to store this in a PageServerConf because the latter may be logged @@ -413,6 +419,8 @@ struct PageServerConfigBuilder { validate_vectored_get: BuilderValue, ephemeral_bytes_per_memory_kb: BuilderValue, + + walredo_process_kind: BuilderValue, } impl PageServerConfigBuilder { @@ -500,6 +508,8 @@ impl PageServerConfigBuilder { )), validate_vectored_get: Set(DEFAULT_VALIDATE_VECTORED_GET), ephemeral_bytes_per_memory_kb: Set(DEFAULT_EPHEMERAL_BYTES_PER_MEMORY_KB), + + walredo_process_kind: Set(DEFAULT_WALREDO_PROCESS_KIND.parse().unwrap()), } } } @@ -683,6 +693,10 @@ impl PageServerConfigBuilder { self.ephemeral_bytes_per_memory_kb = BuilderValue::Set(value); } + pub fn get_walredo_process_kind(&mut self, value: crate::walredo::ProcessKind) { + self.walredo_process_kind = BuilderValue::Set(value); + } + pub fn build(self) -> anyhow::Result { let default = Self::default_values(); @@ -739,6 +753,7 @@ impl PageServerConfigBuilder { max_vectored_read_bytes, validate_vectored_get, ephemeral_bytes_per_memory_kb, + walredo_process_kind, } CUSTOM LOGIC { @@ -1032,6 +1047,9 @@ impl PageServerConf { "ephemeral_bytes_per_memory_kb" => { builder.get_ephemeral_bytes_per_memory_kb(parse_toml_u64("ephemeral_bytes_per_memory_kb", item)? as usize) } + "walredo_process_kind" => { + builder.get_walredo_process_kind(parse_toml_from_str("walredo_process_kind", item)?) + } _ => bail!("unrecognized pageserver option '{key}'"), } } @@ -1114,6 +1132,7 @@ impl PageServerConf { ), validate_vectored_get: defaults::DEFAULT_VALIDATE_VECTORED_GET, ephemeral_bytes_per_memory_kb: defaults::DEFAULT_EPHEMERAL_BYTES_PER_MEMORY_KB, + walredo_process_kind: defaults::DEFAULT_WALREDO_PROCESS_KIND.parse().unwrap(), } } } @@ -1351,7 +1370,8 @@ background_task_maximum_delay = '334 s' .expect("Invalid default constant") ), validate_vectored_get: defaults::DEFAULT_VALIDATE_VECTORED_GET, - ephemeral_bytes_per_memory_kb: defaults::DEFAULT_EPHEMERAL_BYTES_PER_MEMORY_KB + ephemeral_bytes_per_memory_kb: defaults::DEFAULT_EPHEMERAL_BYTES_PER_MEMORY_KB, + walredo_process_kind: defaults::DEFAULT_WALREDO_PROCESS_KIND.parse().unwrap(), }, "Correct defaults should be used when no config values are provided" ); @@ -1423,7 +1443,8 @@ background_task_maximum_delay = '334 s' .expect("Invalid default constant") ), validate_vectored_get: defaults::DEFAULT_VALIDATE_VECTORED_GET, - ephemeral_bytes_per_memory_kb: defaults::DEFAULT_EPHEMERAL_BYTES_PER_MEMORY_KB + ephemeral_bytes_per_memory_kb: defaults::DEFAULT_EPHEMERAL_BYTES_PER_MEMORY_KB, + walredo_process_kind: defaults::DEFAULT_WALREDO_PROCESS_KIND.parse().unwrap(), }, "Should be able to parse all basic config values correctly" ); diff --git a/pageserver/src/consumption_metrics.rs b/pageserver/src/consumption_metrics.rs index f5540e896f..62bbde42f4 100644 --- a/pageserver/src/consumption_metrics.rs +++ b/pageserver/src/consumption_metrics.rs @@ -304,7 +304,7 @@ async fn calculate_synthetic_size_worker( continue; } - if !tenant_shard_id.is_zero() { + if !tenant_shard_id.is_shard_zero() { // We only send consumption metrics from shard 0, so don't waste time calculating // synthetic size on other shards. continue; diff --git a/pageserver/src/consumption_metrics/metrics.rs b/pageserver/src/consumption_metrics/metrics.rs index 6740c1360b..7ba2d04c4f 100644 --- a/pageserver/src/consumption_metrics/metrics.rs +++ b/pageserver/src/consumption_metrics/metrics.rs @@ -199,7 +199,7 @@ pub(super) async fn collect_all_metrics( }; let tenants = futures::stream::iter(tenants).filter_map(|(id, state, _)| async move { - if state != TenantState::Active || !id.is_zero() { + if state != TenantState::Active || !id.is_shard_zero() { None } else { tenant_manager diff --git a/pageserver/src/http/openapi_spec.yml b/pageserver/src/http/openapi_spec.yml index 2713309824..d89f949688 100644 --- a/pageserver/src/http/openapi_spec.yml +++ b/pageserver/src/http/openapi_spec.yml @@ -58,24 +58,6 @@ paths: responses: "200": description: The reload completed successfully. - "401": - description: Unauthorized Error - content: - application/json: - schema: - $ref: "#/components/schemas/UnauthorizedError" - "403": - description: Forbidden Error - content: - application/json: - schema: - $ref: "#/components/schemas/ForbiddenError" - "500": - description: Generic operation error (also hits if no keys were found) - content: - application/json: - schema: - $ref: "#/components/schemas/Error" /v1/tenant/{tenant_id}: parameters: @@ -93,62 +75,14 @@ paths: application/json: schema: $ref: "#/components/schemas/TenantInfo" - "400": - description: Error when no tenant id found in path or no timeline id - content: - application/json: - schema: - $ref: "#/components/schemas/Error" - "401": - description: Unauthorized Error - content: - application/json: - schema: - $ref: "#/components/schemas/UnauthorizedError" - "403": - description: Forbidden Error - content: - application/json: - schema: - $ref: "#/components/schemas/ForbiddenError" - "500": - description: Generic operation error - content: - application/json: - schema: - $ref: "#/components/schemas/Error" - "503": - description: Temporarily unavailable, please retry. - content: - application/json: - schema: - $ref: "#/components/schemas/ServiceUnavailableError" delete: description: | Attempts to delete specified tenant. 500, 503 and 409 errors should be retried until 404 is retrieved. 404 means that deletion successfully finished" responses: - "400": - description: Error when no tenant id found in path - content: - application/json: - schema: - $ref: "#/components/schemas/Error" - "401": - description: Unauthorized Error - content: - application/json: - schema: - $ref: "#/components/schemas/UnauthorizedError" - "403": - description: Forbidden Error - content: - application/json: - schema: - $ref: "#/components/schemas/ForbiddenError" "404": - description: Tenant not found + description: Tenant not found. This is the success path. content: application/json: schema: @@ -165,18 +99,6 @@ paths: application/json: schema: $ref: "#/components/schemas/PreconditionFailedError" - "500": - description: Generic operation error - content: - application/json: - schema: - $ref: "#/components/schemas/Error" - "503": - description: Temporarily unavailable, please retry. - content: - application/json: - schema: - $ref: "#/components/schemas/ServiceUnavailableError" /v1/tenant/{tenant_id}/time_travel_remote_storage: parameters: @@ -206,36 +128,6 @@ paths: application/json: schema: type: string - "400": - description: Error when no tenant id found in path or invalid timestamp - content: - application/json: - schema: - $ref: "#/components/schemas/Error" - "401": - description: Unauthorized Error - content: - application/json: - schema: - $ref: "#/components/schemas/UnauthorizedError" - "403": - description: Forbidden Error - content: - application/json: - schema: - $ref: "#/components/schemas/ForbiddenError" - "500": - description: Generic operation error - content: - application/json: - schema: - $ref: "#/components/schemas/Error" - "503": - description: Temporarily unavailable, please retry. - content: - application/json: - schema: - $ref: "#/components/schemas/ServiceUnavailableError" /v1/tenant/{tenant_id}/timeline: parameters: @@ -255,36 +147,6 @@ paths: type: array items: $ref: "#/components/schemas/TimelineInfo" - "400": - description: Error when no tenant id found in path - content: - application/json: - schema: - $ref: "#/components/schemas/Error" - "401": - description: Unauthorized Error - content: - application/json: - schema: - $ref: "#/components/schemas/UnauthorizedError" - "403": - description: Forbidden Error - content: - application/json: - schema: - $ref: "#/components/schemas/ForbiddenError" - "500": - description: Generic operation error - content: - application/json: - schema: - $ref: "#/components/schemas/Error" - "503": - description: Temporarily unavailable, please retry. - content: - application/json: - schema: - $ref: "#/components/schemas/ServiceUnavailableError" /v1/tenant/{tenant_id}/timeline/{timeline_id}: @@ -309,60 +171,12 @@ paths: application/json: schema: $ref: "#/components/schemas/TimelineInfo" - "400": - description: Error when no tenant id found in path or no timeline id - content: - application/json: - schema: - $ref: "#/components/schemas/Error" - "401": - description: Unauthorized Error - content: - application/json: - schema: - $ref: "#/components/schemas/UnauthorizedError" - "403": - description: Forbidden Error - content: - application/json: - schema: - $ref: "#/components/schemas/ForbiddenError" - "500": - description: Generic operation error - content: - application/json: - schema: - $ref: "#/components/schemas/Error" - "503": - description: Temporarily unavailable, please retry. - content: - application/json: - schema: - $ref: "#/components/schemas/ServiceUnavailableError" delete: description: "Attempts to delete specified timeline. 500 and 409 errors should be retried" responses: - "400": - description: Error when no tenant id found in path or no timeline id - content: - application/json: - schema: - $ref: "#/components/schemas/Error" - "401": - description: Unauthorized Error - content: - application/json: - schema: - $ref: "#/components/schemas/UnauthorizedError" - "403": - description: Forbidden Error - content: - application/json: - schema: - $ref: "#/components/schemas/ForbiddenError" "404": - description: Timeline not found + description: Timeline not found. This is the success path. content: application/json: schema: @@ -379,18 +193,6 @@ paths: application/json: schema: $ref: "#/components/schemas/PreconditionFailedError" - "500": - description: Generic operation error - content: - application/json: - schema: - $ref: "#/components/schemas/Error" - "503": - description: Temporarily unavailable, please retry. - content: - application/json: - schema: - $ref: "#/components/schemas/ServiceUnavailableError" /v1/tenant/{tenant_id}/timeline/{timeline_id}/get_timestamp_of_lsn: parameters: @@ -423,36 +225,6 @@ paths: schema: type: string format: date-time - "400": - description: Error when no tenant id found in path, no timeline id or invalid timestamp - content: - application/json: - schema: - $ref: "#/components/schemas/Error" - "401": - description: Unauthorized Error - content: - application/json: - schema: - $ref: "#/components/schemas/UnauthorizedError" - "403": - description: Forbidden Error - content: - application/json: - schema: - $ref: "#/components/schemas/ForbiddenError" - "404": - description: Timeline not found, or there is no timestamp information for the given lsn - content: - application/json: - schema: - $ref: "#/components/schemas/NotFoundError" - "500": - description: Generic operation error - content: - application/json: - schema: - $ref: "#/components/schemas/Error" /v1/tenant/{tenant_id}/timeline/{timeline_id}/get_lsn_by_timestamp: parameters: @@ -484,36 +256,6 @@ paths: application/json: schema: $ref: "#/components/schemas/LsnByTimestampResponse" - "400": - description: Error when no tenant id found in path, no timeline id or invalid timestamp - content: - application/json: - schema: - $ref: "#/components/schemas/Error" - "401": - description: Unauthorized Error - content: - application/json: - schema: - $ref: "#/components/schemas/UnauthorizedError" - "403": - description: Forbidden Error - content: - application/json: - schema: - $ref: "#/components/schemas/ForbiddenError" - "500": - description: Generic operation error - content: - application/json: - schema: - $ref: "#/components/schemas/Error" - "503": - description: Temporarily unavailable, please retry. - content: - application/json: - schema: - $ref: "#/components/schemas/ServiceUnavailableError" /v1/tenant/{tenant_id}/timeline/{timeline_id}/do_gc: parameters: @@ -537,36 +279,6 @@ paths: application/json: schema: type: string - "400": - description: Error when no tenant id found in path, no timeline id or invalid timestamp - content: - application/json: - schema: - $ref: "#/components/schemas/Error" - "401": - description: Unauthorized Error - content: - application/json: - schema: - $ref: "#/components/schemas/UnauthorizedError" - "403": - description: Forbidden Error - content: - application/json: - schema: - $ref: "#/components/schemas/ForbiddenError" - "500": - description: Generic operation error - content: - application/json: - schema: - $ref: "#/components/schemas/Error" - "503": - description: Temporarily unavailable, please retry. - content: - application/json: - schema: - $ref: "#/components/schemas/ServiceUnavailableError" /v1/tenant/{tenant_shard_id}/location_config: parameters: - name: tenant_shard_id @@ -628,24 +340,6 @@ paths: application/json: schema: $ref: "#/components/schemas/TenantLocationConfigResponse" - "503": - description: Tenant's state cannot be changed right now. Wait a few seconds and retry. - content: - application/json: - schema: - $ref: "#/components/schemas/Error" - "401": - description: Unauthorized Error - content: - application/json: - schema: - $ref: "#/components/schemas/UnauthorizedError" - "403": - description: Forbidden Error - content: - application/json: - schema: - $ref: "#/components/schemas/ForbiddenError" "409": description: | The tenant is already known to Pageserver in some way, @@ -662,12 +356,6 @@ paths: application/json: schema: $ref: "#/components/schemas/ConflictError" - "500": - description: Generic operation error - content: - application/json: - schema: - $ref: "#/components/schemas/Error" /v1/tenant/{tenant_id}/ignore: parameters: - name: tenant_id @@ -684,36 +372,6 @@ paths: responses: "200": description: Tenant ignored - "400": - description: Error when no tenant id found in path parameters - content: - application/json: - schema: - $ref: "#/components/schemas/Error" - "401": - description: Unauthorized Error - content: - application/json: - schema: - $ref: "#/components/schemas/UnauthorizedError" - "403": - description: Forbidden Error - content: - application/json: - schema: - $ref: "#/components/schemas/ForbiddenError" - "500": - description: Generic operation error - content: - application/json: - schema: - $ref: "#/components/schemas/Error" - "503": - description: Temporarily unavailable, please retry. - content: - application/json: - schema: - $ref: "#/components/schemas/ServiceUnavailableError" /v1/tenant/{tenant_id}/load: @@ -740,36 +398,6 @@ paths: responses: "202": description: Tenant scheduled to load successfully - "400": - description: Error when no tenant id found in path parameters - content: - application/json: - schema: - $ref: "#/components/schemas/Error" - "401": - description: Unauthorized Error - content: - application/json: - schema: - $ref: "#/components/schemas/UnauthorizedError" - "403": - description: Forbidden Error - content: - application/json: - schema: - $ref: "#/components/schemas/ForbiddenError" - "500": - description: Generic operation error - content: - application/json: - schema: - $ref: "#/components/schemas/Error" - "503": - description: Temporarily unavailable, please retry. - content: - application/json: - schema: - $ref: "#/components/schemas/ServiceUnavailableError" /v1/tenant/{tenant_id}/{timeline_id}/preserve_initdb_archive: parameters: @@ -790,37 +418,6 @@ paths: responses: "202": description: Tenant scheduled to load successfully - "404": - description: No tenant or timeline found for the specified ids - content: - application/json: - schema: - $ref: "#/components/schemas/Error" - "401": - description: Unauthorized Error - content: - application/json: - schema: - $ref: "#/components/schemas/UnauthorizedError" - "403": - description: Forbidden Error - content: - application/json: - schema: - $ref: "#/components/schemas/ForbiddenError" - "500": - description: Generic operation error - content: - application/json: - schema: - $ref: "#/components/schemas/Error" - "503": - description: Temporarily unavailable, please retry. - content: - application/json: - schema: - $ref: "#/components/schemas/ServiceUnavailableError" - /v1/tenant/{tenant_id}/synthetic_size: parameters: @@ -839,31 +436,8 @@ paths: application/json: schema: $ref: "#/components/schemas/SyntheticSizeResponse" - "401": - description: Unauthorized Error - content: - application/json: - schema: - $ref: "#/components/schemas/UnauthorizedError" - "403": - description: Forbidden Error - content: - application/json: - schema: - $ref: "#/components/schemas/ForbiddenError" - "500": - description: Generic operation error - content: - application/json: - schema: - $ref: "#/components/schemas/Error" - "503": - description: Temporarily unavailable, please retry. - content: - application/json: - schema: - $ref: "#/components/schemas/ServiceUnavailableError" + # This route has no handler. TODO: remove? /v1/tenant/{tenant_id}/size: parameters: - name: tenant_id @@ -945,18 +519,6 @@ paths: responses: "200": description: Success - "500": - description: Generic operation error - content: - application/json: - schema: - $ref: "#/components/schemas/Error" - "503": - description: Temporarily unavailable, please retry. - content: - application/json: - schema: - $ref: "#/components/schemas/ServiceUnavailableError" /v1/tenant/{tenant_shard_id}/secondary/download: parameters: @@ -987,20 +549,6 @@ paths: application/json: schema: $ref: "#/components/schemas/SecondaryProgress" - "500": - description: Generic operation error - content: - application/json: - schema: - $ref: "#/components/schemas/Error" - "503": - description: Temporarily unavailable, please retry. - content: - application/json: - schema: - $ref: "#/components/schemas/ServiceUnavailableError" - - /v1/tenant/{tenant_id}/timeline/: parameters: @@ -1043,24 +591,6 @@ paths: application/json: schema: $ref: "#/components/schemas/TimelineInfo" - "400": - description: Malformed timeline create request - content: - application/json: - schema: - $ref: "#/components/schemas/Error" - "401": - description: Unauthorized Error - content: - application/json: - schema: - $ref: "#/components/schemas/UnauthorizedError" - "403": - description: Forbidden Error - content: - application/json: - schema: - $ref: "#/components/schemas/ForbiddenError" "406": description: Permanently unsatisfiable request, don't retry. content: @@ -1079,18 +609,6 @@ paths: application/json: schema: $ref: "#/components/schemas/Error" - "500": - description: Generic operation error - content: - application/json: - schema: - $ref: "#/components/schemas/Error" - "503": - description: Temporarily unavailable, please retry. - content: - application/json: - schema: - $ref: "#/components/schemas/ServiceUnavailableError" /v1/tenant/: get: @@ -1104,30 +622,6 @@ paths: type: array items: $ref: "#/components/schemas/TenantInfo" - "401": - description: Unauthorized Error - content: - application/json: - schema: - $ref: "#/components/schemas/UnauthorizedError" - "403": - description: Forbidden Error - content: - application/json: - schema: - $ref: "#/components/schemas/ForbiddenError" - "500": - description: Generic operation error - content: - application/json: - schema: - $ref: "#/components/schemas/Error" - "503": - description: Temporarily unavailable, please retry. - content: - application/json: - schema: - $ref: "#/components/schemas/ServiceUnavailableError" post: description: | @@ -1148,43 +642,12 @@ paths: application/json: schema: type: string - "400": - description: Malformed tenant create request - content: - application/json: - schema: - $ref: "#/components/schemas/Error" - "401": - description: Unauthorized Error - content: - application/json: - schema: - $ref: "#/components/schemas/UnauthorizedError" - "403": - description: Forbidden Error - content: - application/json: - schema: - $ref: "#/components/schemas/ForbiddenError" "409": description: Tenant already exists, creation skipped content: application/json: schema: $ref: "#/components/schemas/ConflictError" - "500": - description: Generic operation error - content: - application/json: - schema: - $ref: "#/components/schemas/Error" - "503": - description: Temporarily unavailable, please retry. - content: - application/json: - schema: - $ref: "#/components/schemas/ServiceUnavailableError" - /v1/tenant/config: put: @@ -1206,36 +669,6 @@ paths: type: array items: $ref: "#/components/schemas/TenantInfo" - "400": - description: Malformed tenant config request - content: - application/json: - schema: - $ref: "#/components/schemas/Error" - "401": - description: Unauthorized Error - content: - application/json: - schema: - $ref: "#/components/schemas/UnauthorizedError" - "403": - description: Forbidden Error - content: - application/json: - schema: - $ref: "#/components/schemas/ForbiddenError" - "500": - description: Generic operation error - content: - application/json: - schema: - $ref: "#/components/schemas/Error" - "503": - description: Temporarily unavailable, please retry. - content: - application/json: - schema: - $ref: "#/components/schemas/ServiceUnavailableError" /v1/tenant/{tenant_id}/config/: parameters: @@ -1255,42 +688,6 @@ paths: application/json: schema: $ref: "#/components/schemas/TenantConfigResponse" - "400": - description: Malformed get tenanant config request - content: - application/json: - schema: - $ref: "#/components/schemas/Error" - "401": - description: Unauthorized Error - content: - application/json: - schema: - $ref: "#/components/schemas/UnauthorizedError" - "403": - description: Forbidden Error - content: - application/json: - schema: - $ref: "#/components/schemas/ForbiddenError" - "404": - description: Tenand or timeline were not found - content: - application/json: - schema: - $ref: "#/components/schemas/NotFoundError" - "500": - description: Generic operation error - content: - application/json: - schema: - $ref: "#/components/schemas/Error" - "503": - description: Temporarily unavailable, please retry. - content: - application/json: - schema: - $ref: "#/components/schemas/ServiceUnavailableError" /v1/utilization: get: @@ -1304,12 +701,6 @@ paths: application/json: schema: $ref: "#/components/schemas/PageserverUtilization" - "500": - description: Generic operation error - content: - application/json: - schema: - $ref: "#/components/schemas/Error" components: securitySchemes: diff --git a/pageserver/src/http/routes.rs b/pageserver/src/http/routes.rs index 47d8ae1148..20258dd950 100644 --- a/pageserver/src/http/routes.rs +++ b/pageserver/src/http/routes.rs @@ -457,8 +457,12 @@ async fn reload_auth_validation_keys_handler( json_response(StatusCode::OK, ()) } Err(e) => { + let err_msg = "Error reloading public keys"; warn!("Error reloading public keys from {key_path:?}: {e:}"); - json_response(StatusCode::INTERNAL_SERVER_ERROR, ()) + json_response( + StatusCode::INTERNAL_SERVER_ERROR, + HttpErrorBody::from_msg(err_msg.to_string()), + ) } } } @@ -696,7 +700,7 @@ async fn get_lsn_by_timestamp_handler( check_permission(&request, Some(tenant_shard_id.tenant_id))?; let state = get_state(&request); - if !tenant_shard_id.is_zero() { + if !tenant_shard_id.is_shard_zero() { // Requires SLRU contents, which are only stored on shard zero return Err(ApiError::BadRequest(anyhow!( "Size calculations are only available on shard zero" @@ -747,7 +751,7 @@ async fn get_timestamp_of_lsn_handler( check_permission(&request, Some(tenant_shard_id.tenant_id))?; let state = get_state(&request); - if !tenant_shard_id.is_zero() { + if !tenant_shard_id.is_shard_zero() { // Requires SLRU contents, which are only stored on shard zero return Err(ApiError::BadRequest(anyhow!( "Size calculations are only available on shard zero" @@ -772,7 +776,9 @@ async fn get_timestamp_of_lsn_handler( let time = format_rfc3339(postgres_ffi::from_pg_timestamp(time)).to_string(); json_response(StatusCode::OK, time) } - None => json_response(StatusCode::NOT_FOUND, ()), + None => Err(ApiError::NotFound( + anyhow::anyhow!("Timestamp for lsn {} not found", lsn).into(), + )), } } @@ -1086,7 +1092,7 @@ async fn tenant_size_handler( let headers = request.headers(); let state = get_state(&request); - if !tenant_shard_id.is_zero() { + if !tenant_shard_id.is_shard_zero() { return Err(ApiError::BadRequest(anyhow!( "Size calculations are only available on shard zero" ))); diff --git a/pageserver/src/metrics.rs b/pageserver/src/metrics.rs index 3160f204e2..be61a755ff 100644 --- a/pageserver/src/metrics.rs +++ b/pageserver/src/metrics.rs @@ -1819,6 +1819,29 @@ impl Default for WalRedoProcessCounters { pub(crate) static WAL_REDO_PROCESS_COUNTERS: Lazy = Lazy::new(WalRedoProcessCounters::default); +#[cfg(not(test))] +pub mod wal_redo { + use super::*; + + static PROCESS_KIND: Lazy> = Lazy::new(|| { + std::sync::Mutex::new( + register_uint_gauge_vec!( + "pageserver_wal_redo_process_kind", + "The configured process kind for walredo", + &["kind"], + ) + .unwrap(), + ) + }); + + pub fn set_process_kind_metric(kind: crate::walredo::ProcessKind) { + // use guard to avoid races around the next two steps + let guard = PROCESS_KIND.lock().unwrap(); + guard.reset(); + guard.with_label_values(&[&format!("{kind}")]).set(1); + } +} + /// Similar to `prometheus::HistogramTimer` but does not record on drop. pub(crate) struct StorageTimeMetricsTimer { metrics: StorageTimeMetrics, @@ -2089,7 +2112,7 @@ impl TimelineMetrics { pub(crate) fn remove_tenant_metrics(tenant_shard_id: &TenantShardId) { // Only shard zero deals in synthetic sizes - if tenant_shard_id.is_zero() { + if tenant_shard_id.is_shard_zero() { let tid = tenant_shard_id.tenant_id.to_string(); let _ = TENANT_SYNTHETIC_SIZE_METRIC.remove_label_values(&[&tid]); } diff --git a/pageserver/src/tenant.rs b/pageserver/src/tenant.rs index 17ff033e00..35ea037a55 100644 --- a/pageserver/src/tenant.rs +++ b/pageserver/src/tenant.rs @@ -386,7 +386,7 @@ impl WalRedoManager { pub(crate) fn status(&self) -> Option { match self { - WalRedoManager::Prod(m) => m.status(), + WalRedoManager::Prod(m) => Some(m.status()), #[cfg(test)] WalRedoManager::Test(_) => None, } @@ -3190,7 +3190,7 @@ impl Tenant { run_initdb(self.conf, &pgdata_path, pg_version, &self.cancel).await?; // Upload the created data dir to S3 - if self.tenant_shard_id().is_zero() { + if self.tenant_shard_id().is_shard_zero() { self.upload_initdb(&timelines_path, &pgdata_path, &timeline_id) .await?; } @@ -3437,7 +3437,7 @@ impl Tenant { .store(size, Ordering::Relaxed); // Only shard zero should be calculating synthetic sizes - debug_assert!(self.shard_identity.is_zero()); + debug_assert!(self.shard_identity.is_shard_zero()); TENANT_SYNTHETIC_SIZE_METRIC .get_metric_with_label_values(&[&self.tenant_shard_id.tenant_id.to_string()]) diff --git a/pageserver/src/tenant/delete.rs b/pageserver/src/tenant/delete.rs index d1881f3897..33d0f677e5 100644 --- a/pageserver/src/tenant/delete.rs +++ b/pageserver/src/tenant/delete.rs @@ -436,6 +436,11 @@ impl DeleteTenantFlow { .await } + /// Check whether background deletion of this tenant is currently in progress + pub(crate) fn is_in_progress(tenant: &Tenant) -> bool { + tenant.delete_progress.try_lock().is_err() + } + async fn prepare( tenant: &Arc, ) -> Result, DeleteTenantError> { diff --git a/pageserver/src/tenant/mgr.rs b/pageserver/src/tenant/mgr.rs index b1b46d487b..73967f2949 100644 --- a/pageserver/src/tenant/mgr.rs +++ b/pageserver/src/tenant/mgr.rs @@ -1410,9 +1410,15 @@ impl TenantManager { match tenant.current_state() { TenantState::Broken { .. } | TenantState::Stopping { .. } => { - // If a tenant is broken or stopping, DeleteTenantFlow can - // handle it: broken tenants proceed to delete, stopping tenants - // are checked for deletion already in progress. + // If deletion is already in progress, return success (the semantics of this + // function are to rerturn success afterr deletion is spawned in background). + // Otherwise fall through and let [`DeleteTenantFlow`] handle this state. + if DeleteTenantFlow::is_in_progress(&tenant) { + // The `delete_progress` lock is held: deletion is already happening + // in the bacckground + slot_guard.revert(); + return Ok(()); + } } _ => { tenant diff --git a/pageserver/src/tenant/remote_timeline_client/upload.rs b/pageserver/src/tenant/remote_timeline_client/upload.rs index 137fe48b73..0227331953 100644 --- a/pageserver/src/tenant/remote_timeline_client/upload.rs +++ b/pageserver/src/tenant/remote_timeline_client/upload.rs @@ -167,7 +167,7 @@ pub(crate) async fn time_travel_recover_tenant( let warn_after = 3; let max_attempts = 10; let mut prefixes = Vec::with_capacity(2); - if tenant_shard_id.is_zero() { + if tenant_shard_id.is_shard_zero() { // Also recover the unsharded prefix for a shard of zero: // - if the tenant is totally unsharded, the unsharded prefix contains all the data // - if the tenant is sharded, we still want to recover the initdb data, but we only diff --git a/pageserver/src/tenant/storage_layer/delta_layer.rs b/pageserver/src/tenant/storage_layer/delta_layer.rs index 466d95f46d..255855a246 100644 --- a/pageserver/src/tenant/storage_layer/delta_layer.rs +++ b/pageserver/src/tenant/storage_layer/delta_layer.rs @@ -939,7 +939,7 @@ impl DeltaLayerInner { } if !range_end_handled { - tracing::info!("Handling range end fallback at {}", data_end_offset); + tracing::debug!("Handling range end fallback at {}", data_end_offset); planner.handle_range_end(data_end_offset); } } diff --git a/pageserver/src/tenant/timeline.rs b/pageserver/src/tenant/timeline.rs index d046a60af4..46b3d41e2b 100644 --- a/pageserver/src/tenant/timeline.rs +++ b/pageserver/src/tenant/timeline.rs @@ -1344,7 +1344,7 @@ impl Timeline { background_jobs_can_start: Option<&completion::Barrier>, ctx: &RequestContext, ) { - if self.tenant_shard_id.is_zero() { + if self.tenant_shard_id.is_shard_zero() { // Logical size is only maintained accurately on shard zero. self.spawn_initial_logical_size_computation_task(ctx); } @@ -2237,7 +2237,7 @@ impl Timeline { priority: GetLogicalSizePriority, ctx: &RequestContext, ) -> logical_size::CurrentLogicalSize { - if !self.tenant_shard_id.is_zero() { + if !self.tenant_shard_id.is_shard_zero() { // Logical size is only accurately maintained on shard zero: when called elsewhere, for example // when HTTP API is serving a GET for timeline zero, return zero return logical_size::CurrentLogicalSize::Approximate(logical_size::Approximate::zero()); @@ -2533,7 +2533,7 @@ impl Timeline { crate::span::debug_assert_current_span_has_tenant_and_timeline_id(); // We should never be calculating logical sizes on shard !=0, because these shards do not have // accurate relation sizes, and they do not emit consumption metrics. - debug_assert!(self.tenant_shard_id.is_zero()); + debug_assert!(self.tenant_shard_id.is_shard_zero()); let guard = self .gate diff --git a/pageserver/src/tenant/timeline/eviction_task.rs b/pageserver/src/tenant/timeline/eviction_task.rs index 522c5b57de..304d0d60ee 100644 --- a/pageserver/src/tenant/timeline/eviction_task.rs +++ b/pageserver/src/tenant/timeline/eviction_task.rs @@ -378,7 +378,7 @@ impl Timeline { gate: &GateGuard, ctx: &RequestContext, ) -> ControlFlow<()> { - if !self.tenant_shard_id.is_zero() { + if !self.tenant_shard_id.is_shard_zero() { // Shards !=0 do not maintain accurate relation sizes, and do not need to calculate logical size // for consumption metrics (consumption metrics are only sent from shard 0). We may therefore // skip imitating logical size accesses for eviction purposes. diff --git a/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs b/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs index 3f3419e886..c6ee6b90c4 100644 --- a/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs +++ b/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs @@ -427,7 +427,7 @@ pub(super) async fn handle_walreceiver_connection( // Send the replication feedback message. // Regular standby_status_update fields are put into this message. - let current_timeline_size = if timeline.tenant_shard_id.is_zero() { + let current_timeline_size = if timeline.tenant_shard_id.is_shard_zero() { timeline .get_current_logical_size( crate::tenant::timeline::GetLogicalSizePriority::User, diff --git a/pageserver/src/walingest.rs b/pageserver/src/walingest.rs index 9c7e8748d5..4f83b118ae 100644 --- a/pageserver/src/walingest.rs +++ b/pageserver/src/walingest.rs @@ -403,7 +403,7 @@ impl WalIngest { ); if !key_is_local { - if self.shard.is_zero() { + if self.shard.is_shard_zero() { // Shard 0 tracks relation sizes. Although we will not store this block, we will observe // its blkno in case it implicitly extends a relation. self.observe_decoded_block(modification, blk, ctx).await?; diff --git a/pageserver/src/walredo.rs b/pageserver/src/walredo.rs index ca41a576fd..9776d4ce88 100644 --- a/pageserver/src/walredo.rs +++ b/pageserver/src/walredo.rs @@ -20,6 +20,7 @@ /// Process lifecycle and abstracction for the IPC protocol. mod process; +pub use process::Kind as ProcessKind; /// Code to apply [`NeonWalRecord`]s. pub(crate) mod apply_neon; @@ -34,7 +35,7 @@ use crate::walrecord::NeonWalRecord; use anyhow::Context; use bytes::{Bytes, BytesMut}; use pageserver_api::key::key_to_rel_block; -use pageserver_api::models::WalRedoManagerStatus; +use pageserver_api::models::{WalRedoManagerProcessStatus, WalRedoManagerStatus}; use pageserver_api::shard::TenantShardId; use std::sync::Arc; use std::time::Duration; @@ -54,7 +55,7 @@ pub struct PostgresRedoManager { tenant_shard_id: TenantShardId, conf: &'static PageServerConf, last_redo_at: std::sync::Mutex>, - /// The current [`process::WalRedoProcess`] that is used by new redo requests. + /// The current [`process::Process`] that is used by new redo requests. /// We use [`heavier_once_cell`] for coalescing the spawning, but the redo /// requests don't use the [`heavier_once_cell::Guard`] to keep ahold of the /// their process object; we use [`Arc::clone`] for that. @@ -66,7 +67,7 @@ pub struct PostgresRedoManager { /// still be using the old redo process. But, those other tasks will most likely /// encounter an error as well, and errors are an unexpected condition anyway. /// So, probably we could get rid of the `Arc` in the future. - redo_process: heavier_once_cell::OnceCell>, + redo_process: heavier_once_cell::OnceCell>, } /// @@ -139,8 +140,8 @@ impl PostgresRedoManager { } } - pub(crate) fn status(&self) -> Option { - Some(WalRedoManagerStatus { + pub fn status(&self) -> WalRedoManagerStatus { + WalRedoManagerStatus { last_redo_at: { let at = *self.last_redo_at.lock().unwrap(); at.and_then(|at| { @@ -149,8 +150,14 @@ impl PostgresRedoManager { chrono::Utc::now().checked_sub_signed(chrono::Duration::from_std(age).ok()?) }) }, - pid: self.redo_process.get().map(|p| p.id()), - }) + process: self + .redo_process + .get() + .map(|p| WalRedoManagerProcessStatus { + pid: p.id(), + kind: std::borrow::Cow::Borrowed(p.kind().into()), + }), + } } } @@ -208,37 +215,33 @@ impl PostgresRedoManager { const MAX_RETRY_ATTEMPTS: u32 = 1; let mut n_attempts = 0u32; loop { - let proc: Arc = - match self.redo_process.get_or_init_detached().await { - Ok(guard) => Arc::clone(&guard), - Err(permit) => { - // don't hold poison_guard, the launch code can bail - let start = Instant::now(); - let proc = Arc::new( - process::WalRedoProcess::launch( - self.conf, - self.tenant_shard_id, - pg_version, - ) + let proc: Arc = match self.redo_process.get_or_init_detached().await { + Ok(guard) => Arc::clone(&guard), + Err(permit) => { + // don't hold poison_guard, the launch code can bail + let start = Instant::now(); + let proc = Arc::new( + process::Process::launch(self.conf, self.tenant_shard_id, pg_version) .context("launch walredo process")?, - ); - let duration = start.elapsed(); - WAL_REDO_PROCESS_LAUNCH_DURATION_HISTOGRAM.observe(duration.as_secs_f64()); - info!( - duration_ms = duration.as_millis(), - pid = proc.id(), - "launched walredo process" - ); - self.redo_process.set(Arc::clone(&proc), permit); - proc - } - }; + ); + let duration = start.elapsed(); + WAL_REDO_PROCESS_LAUNCH_DURATION_HISTOGRAM.observe(duration.as_secs_f64()); + info!( + duration_ms = duration.as_millis(), + pid = proc.id(), + "launched walredo process" + ); + self.redo_process.set(Arc::clone(&proc), permit); + proc + } + }; let started_at = std::time::Instant::now(); // Relational WAL records are applied using wal-redo-postgres let result = proc .apply_wal_records(rel, blknum, &base_img, records, wal_redo_timeout) + .await .context("apply_wal_records"); let duration = started_at.elapsed(); diff --git a/pageserver/src/walredo/process.rs b/pageserver/src/walredo/process.rs index bcbb263663..ad6b4e5fe9 100644 --- a/pageserver/src/walredo/process.rs +++ b/pageserver/src/walredo/process.rs @@ -1,186 +1,67 @@ -use self::no_leak_child::NoLeakChild; -use crate::{ - config::PageServerConf, - metrics::{WalRedoKillCause, WAL_REDO_PROCESS_COUNTERS, WAL_REDO_RECORD_COUNTER}, - walrecord::NeonWalRecord, -}; -use anyhow::Context; +use std::time::Duration; + use bytes::Bytes; -use nix::poll::{PollFd, PollFlags}; use pageserver_api::{reltag::RelTag, shard::TenantShardId}; -use postgres_ffi::BLCKSZ; -use std::os::fd::AsRawFd; -#[cfg(feature = "testing")] -use std::sync::atomic::AtomicUsize; -use std::{ - collections::VecDeque, - io::{Read, Write}, - process::{ChildStdin, ChildStdout, Command, Stdio}, - sync::{Mutex, MutexGuard}, - time::Duration, -}; -use tracing::{debug, error, instrument, Instrument}; -use utils::{lsn::Lsn, nonblock::set_nonblock}; +use utils::lsn::Lsn; + +use crate::{config::PageServerConf, walrecord::NeonWalRecord}; mod no_leak_child; /// The IPC protocol that pageserver and walredo process speak over their shared pipe. mod protocol; -pub struct WalRedoProcess { - #[allow(dead_code)] - conf: &'static PageServerConf, - tenant_shard_id: TenantShardId, - // Some() on construction, only becomes None on Drop. - child: Option, - stdout: Mutex, - stdin: Mutex, - /// Counter to separate same sized walredo inputs failing at the same millisecond. - #[cfg(feature = "testing")] - dump_sequence: AtomicUsize, +mod process_impl { + pub(super) mod process_async; + pub(super) mod process_std; } -struct ProcessInput { - stdin: ChildStdin, - n_requests: usize, +#[derive( + Clone, + Copy, + Debug, + PartialEq, + Eq, + strum_macros::EnumString, + strum_macros::Display, + strum_macros::IntoStaticStr, + serde_with::DeserializeFromStr, + serde_with::SerializeDisplay, +)] +#[strum(serialize_all = "kebab-case")] +#[repr(u8)] +pub enum Kind { + Sync, + Async, } -struct ProcessOutput { - stdout: ChildStdout, - pending_responses: VecDeque>, - n_processed_responses: usize, +pub(crate) enum Process { + Sync(process_impl::process_std::WalRedoProcess), + Async(process_impl::process_async::WalRedoProcess), } -impl WalRedoProcess { - // - // Start postgres binary in special WAL redo mode. - // - #[instrument(skip_all,fields(pg_version=pg_version))] - pub(crate) fn launch( +impl Process { + #[inline(always)] + pub fn launch( conf: &'static PageServerConf, tenant_shard_id: TenantShardId, pg_version: u32, ) -> anyhow::Result { - crate::span::debug_assert_current_span_has_tenant_id(); - - let pg_bin_dir_path = conf.pg_bin_dir(pg_version).context("pg_bin_dir")?; // TODO these should be infallible. - let pg_lib_dir_path = conf.pg_lib_dir(pg_version).context("pg_lib_dir")?; - - use no_leak_child::NoLeakChildCommandExt; - // Start postgres itself - let child = Command::new(pg_bin_dir_path.join("postgres")) - // the first arg must be --wal-redo so the child process enters into walredo mode - .arg("--wal-redo") - // the child doesn't process this arg, but, having it in the argv helps indentify the - // walredo process for a particular tenant when debugging a pagserver - .args(["--tenant-shard-id", &format!("{tenant_shard_id}")]) - .stdin(Stdio::piped()) - .stderr(Stdio::piped()) - .stdout(Stdio::piped()) - .env_clear() - .env("LD_LIBRARY_PATH", &pg_lib_dir_path) - .env("DYLD_LIBRARY_PATH", &pg_lib_dir_path) - // NB: The redo process is not trusted after we sent it the first - // walredo work. Before that, it is trusted. Specifically, we trust - // it to - // 1. close all file descriptors except stdin, stdout, stderr because - // pageserver might not be 100% diligent in setting FD_CLOEXEC on all - // the files it opens, and - // 2. to use seccomp to sandbox itself before processing the first - // walredo request. - .spawn_no_leak_child(tenant_shard_id) - .context("spawn process")?; - WAL_REDO_PROCESS_COUNTERS.started.inc(); - let mut child = scopeguard::guard(child, |child| { - error!("killing wal-redo-postgres process due to a problem during launch"); - child.kill_and_wait(WalRedoKillCause::Startup); - }); - - let stdin = child.stdin.take().unwrap(); - let stdout = child.stdout.take().unwrap(); - let stderr = child.stderr.take().unwrap(); - let stderr = tokio::process::ChildStderr::from_std(stderr) - .context("convert to tokio::ChildStderr")?; - macro_rules! set_nonblock_or_log_err { - ($file:ident) => {{ - let res = set_nonblock($file.as_raw_fd()); - if let Err(e) = &res { - error!(error = %e, file = stringify!($file), pid = child.id(), "set_nonblock failed"); - } - res - }}; - } - set_nonblock_or_log_err!(stdin)?; - set_nonblock_or_log_err!(stdout)?; - - // all fallible operations post-spawn are complete, so get rid of the guard - let child = scopeguard::ScopeGuard::into_inner(child); - - tokio::spawn( - async move { - scopeguard::defer! { - debug!("wal-redo-postgres stderr_logger_task finished"); - crate::metrics::WAL_REDO_PROCESS_COUNTERS.active_stderr_logger_tasks_finished.inc(); - } - debug!("wal-redo-postgres stderr_logger_task started"); - crate::metrics::WAL_REDO_PROCESS_COUNTERS.active_stderr_logger_tasks_started.inc(); - - use tokio::io::AsyncBufReadExt; - let mut stderr_lines = tokio::io::BufReader::new(stderr); - let mut buf = Vec::new(); - let res = loop { - buf.clear(); - // TODO we don't trust the process to cap its stderr length. - // Currently it can do unbounded Vec allocation. - match stderr_lines.read_until(b'\n', &mut buf).await { - Ok(0) => break Ok(()), // eof - Ok(num_bytes) => { - let output = String::from_utf8_lossy(&buf[..num_bytes]); - error!(%output, "received output"); - } - Err(e) => { - break Err(e); - } - } - }; - match res { - Ok(()) => (), - Err(e) => { - error!(error=?e, "failed to read from walredo stderr"); - } - } - }.instrument(tracing::info_span!(parent: None, "wal-redo-postgres-stderr", pid = child.id(), tenant_id = %tenant_shard_id.tenant_id, shard_id = %tenant_shard_id.shard_slug(), %pg_version)) - ); - - Ok(Self { - conf, - tenant_shard_id, - child: Some(child), - stdin: Mutex::new(ProcessInput { - stdin, - n_requests: 0, - }), - stdout: Mutex::new(ProcessOutput { - stdout, - pending_responses: VecDeque::new(), - n_processed_responses: 0, - }), - #[cfg(feature = "testing")] - dump_sequence: AtomicUsize::default(), + Ok(match conf.walredo_process_kind { + Kind::Sync => Self::Sync(process_impl::process_std::WalRedoProcess::launch( + conf, + tenant_shard_id, + pg_version, + )?), + Kind::Async => Self::Async(process_impl::process_async::WalRedoProcess::launch( + conf, + tenant_shard_id, + pg_version, + )?), }) } - pub(crate) fn id(&self) -> u32 { - self.child - .as_ref() - .expect("must not call this during Drop") - .id() - } - - // Apply given WAL records ('records') over an old page image. Returns - // new page image. - // - #[instrument(skip_all, fields(tenant_id=%self.tenant_shard_id.tenant_id, shard_id=%self.tenant_shard_id.shard_slug(), pid=%self.id()))] - pub(crate) fn apply_wal_records( + #[inline(always)] + pub(crate) async fn apply_wal_records( &self, rel: RelTag, blknum: u32, @@ -188,221 +69,29 @@ impl WalRedoProcess { records: &[(Lsn, NeonWalRecord)], wal_redo_timeout: Duration, ) -> anyhow::Result { - let tag = protocol::BufferTag { rel, blknum }; - let input = self.stdin.lock().unwrap(); - - // Serialize all the messages to send the WAL redo process first. - // - // This could be problematic if there are millions of records to replay, - // but in practice the number of records is usually so small that it doesn't - // matter, and it's better to keep this code simple. - // - // Most requests start with a before-image with BLCKSZ bytes, followed by - // by some other WAL records. Start with a buffer that can hold that - // comfortably. - let mut writebuf: Vec = Vec::with_capacity((BLCKSZ as usize) * 3); - protocol::build_begin_redo_for_block_msg(tag, &mut writebuf); - if let Some(img) = base_img { - protocol::build_push_page_msg(tag, img, &mut writebuf); - } - for (lsn, rec) in records.iter() { - if let NeonWalRecord::Postgres { - will_init: _, - rec: postgres_rec, - } = rec - { - protocol::build_apply_record_msg(*lsn, postgres_rec, &mut writebuf); - } else { - anyhow::bail!("tried to pass neon wal record to postgres WAL redo"); + match self { + Process::Sync(p) => { + p.apply_wal_records(rel, blknum, base_img, records, wal_redo_timeout) + .await } - } - protocol::build_get_page_msg(tag, &mut writebuf); - WAL_REDO_RECORD_COUNTER.inc_by(records.len() as u64); - - let res = self.apply_wal_records0(&writebuf, input, wal_redo_timeout); - - if res.is_err() { - // not all of these can be caused by this particular input, however these are so rare - // in tests so capture all. - self.record_and_log(&writebuf); - } - - res - } - - fn apply_wal_records0( - &self, - writebuf: &[u8], - input: MutexGuard, - wal_redo_timeout: Duration, - ) -> anyhow::Result { - let mut proc = { input }; // TODO: remove this legacy rename, but this keep the patch small. - let mut nwrite = 0usize; - - while nwrite < writebuf.len() { - let mut stdin_pollfds = [PollFd::new(&proc.stdin, PollFlags::POLLOUT)]; - let n = loop { - match nix::poll::poll(&mut stdin_pollfds[..], wal_redo_timeout.as_millis() as i32) { - Err(nix::errno::Errno::EINTR) => continue, - res => break res, - } - }?; - - if n == 0 { - anyhow::bail!("WAL redo timed out"); + Process::Async(p) => { + p.apply_wal_records(rel, blknum, base_img, records, wal_redo_timeout) + .await } - - // If 'stdin' is writeable, do write. - let in_revents = stdin_pollfds[0].revents().unwrap(); - if in_revents & (PollFlags::POLLERR | PollFlags::POLLOUT) != PollFlags::empty() { - nwrite += proc.stdin.write(&writebuf[nwrite..])?; - } - if in_revents.contains(PollFlags::POLLHUP) { - // We still have more data to write, but the process closed the pipe. - anyhow::bail!("WAL redo process closed its stdin unexpectedly"); - } - } - let request_no = proc.n_requests; - proc.n_requests += 1; - drop(proc); - - // To improve walredo performance we separate sending requests and receiving - // responses. Them are protected by different mutexes (output and input). - // If thread T1, T2, T3 send requests D1, D2, D3 to walredo process - // then there is not warranty that T1 will first granted output mutex lock. - // To address this issue we maintain number of sent requests, number of processed - // responses and ring buffer with pending responses. After sending response - // (under input mutex), threads remembers request number. Then it releases - // input mutex, locks output mutex and fetch in ring buffer all responses until - // its stored request number. The it takes correspondent element from - // pending responses ring buffer and truncate all empty elements from the front, - // advancing processed responses number. - - let mut output = self.stdout.lock().unwrap(); - let n_processed_responses = output.n_processed_responses; - while n_processed_responses + output.pending_responses.len() <= request_no { - // We expect the WAL redo process to respond with an 8k page image. We read it - // into this buffer. - let mut resultbuf = vec![0; BLCKSZ.into()]; - let mut nresult: usize = 0; // # of bytes read into 'resultbuf' so far - while nresult < BLCKSZ.into() { - let mut stdout_pollfds = [PollFd::new(&output.stdout, PollFlags::POLLIN)]; - // We do two things simultaneously: reading response from stdout - // and forward any logging information that the child writes to its stderr to the page server's log. - let n = loop { - match nix::poll::poll( - &mut stdout_pollfds[..], - wal_redo_timeout.as_millis() as i32, - ) { - Err(nix::errno::Errno::EINTR) => continue, - res => break res, - } - }?; - - if n == 0 { - anyhow::bail!("WAL redo timed out"); - } - - // If we have some data in stdout, read it to the result buffer. - let out_revents = stdout_pollfds[0].revents().unwrap(); - if out_revents & (PollFlags::POLLERR | PollFlags::POLLIN) != PollFlags::empty() { - nresult += output.stdout.read(&mut resultbuf[nresult..])?; - } - if out_revents.contains(PollFlags::POLLHUP) { - anyhow::bail!("WAL redo process closed its stdout unexpectedly"); - } - } - output - .pending_responses - .push_back(Some(Bytes::from(resultbuf))); - } - // Replace our request's response with None in `pending_responses`. - // Then make space in the ring buffer by clearing out any seqence of contiguous - // `None`'s from the front of `pending_responses`. - // NB: We can't pop_front() because other requests' responses because another - // requester might have grabbed the output mutex before us: - // T1: grab input mutex - // T1: send request_no 23 - // T1: release input mutex - // T2: grab input mutex - // T2: send request_no 24 - // T2: release input mutex - // T2: grab output mutex - // T2: n_processed_responses + output.pending_responses.len() <= request_no - // 23 0 24 - // T2: enters poll loop that reads stdout - // T2: put response for 23 into pending_responses - // T2: put response for 24 into pending_resposnes - // pending_responses now looks like this: Front Some(response_23) Some(response_24) Back - // T2: takes its response_24 - // pending_responses now looks like this: Front Some(response_23) None Back - // T2: does the while loop below - // pending_responses now looks like this: Front Some(response_23) None Back - // T2: releases output mutex - // T1: grabs output mutex - // T1: n_processed_responses + output.pending_responses.len() > request_no - // 23 2 23 - // T1: skips poll loop that reads stdout - // T1: takes its response_23 - // pending_responses now looks like this: Front None None Back - // T2: does the while loop below - // pending_responses now looks like this: Front Back - // n_processed_responses now has value 25 - let res = output.pending_responses[request_no - n_processed_responses] - .take() - .expect("we own this request_no, nobody else is supposed to take it"); - while let Some(front) = output.pending_responses.front() { - if front.is_none() { - output.pending_responses.pop_front(); - output.n_processed_responses += 1; - } else { - break; - } - } - Ok(res) - } - - #[cfg(feature = "testing")] - fn record_and_log(&self, writebuf: &[u8]) { - use std::sync::atomic::Ordering; - - let millis = std::time::SystemTime::now() - .duration_since(std::time::SystemTime::UNIX_EPOCH) - .unwrap() - .as_millis(); - - let seq = self.dump_sequence.fetch_add(1, Ordering::Relaxed); - - // these files will be collected to an allure report - let filename = format!("walredo-{millis}-{}-{seq}.walredo", writebuf.len()); - - let path = self.conf.tenant_path(&self.tenant_shard_id).join(&filename); - - let res = std::fs::OpenOptions::new() - .write(true) - .create_new(true) - .read(true) - .open(path) - .and_then(|mut f| f.write_all(writebuf)); - - // trip up allowed_errors - if let Err(e) = res { - tracing::error!(target=%filename, length=writebuf.len(), "failed to write out the walredo errored input: {e}"); - } else { - tracing::error!(filename, "erroring walredo input saved"); } } - #[cfg(not(feature = "testing"))] - fn record_and_log(&self, _: &[u8]) {} -} + pub(crate) fn id(&self) -> u32 { + match self { + Process::Sync(p) => p.id(), + Process::Async(p) => p.id(), + } + } -impl Drop for WalRedoProcess { - fn drop(&mut self) { - self.child - .take() - .expect("we only do this once") - .kill_and_wait(WalRedoKillCause::WalRedoProcessDrop); - // no way to wait for stderr_logger_task from Drop because that is async only + pub(crate) fn kind(&self) -> Kind { + match self { + Process::Sync(_) => Kind::Sync, + Process::Async(_) => Kind::Async, + } } } diff --git a/pageserver/src/walredo/process/process_impl/process_async.rs b/pageserver/src/walredo/process/process_impl/process_async.rs new file mode 100644 index 0000000000..262858b033 --- /dev/null +++ b/pageserver/src/walredo/process/process_impl/process_async.rs @@ -0,0 +1,374 @@ +use self::no_leak_child::NoLeakChild; +use crate::{ + config::PageServerConf, + metrics::{WalRedoKillCause, WAL_REDO_PROCESS_COUNTERS, WAL_REDO_RECORD_COUNTER}, + walrecord::NeonWalRecord, + walredo::process::{no_leak_child, protocol}, +}; +use anyhow::Context; +use bytes::Bytes; +use pageserver_api::{reltag::RelTag, shard::TenantShardId}; +use postgres_ffi::BLCKSZ; +#[cfg(feature = "testing")] +use std::sync::atomic::AtomicUsize; +use std::{ + collections::VecDeque, + process::{Command, Stdio}, + time::Duration, +}; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tracing::{debug, error, instrument, Instrument}; +use utils::{lsn::Lsn, poison::Poison}; + +pub struct WalRedoProcess { + #[allow(dead_code)] + conf: &'static PageServerConf, + tenant_shard_id: TenantShardId, + // Some() on construction, only becomes None on Drop. + child: Option, + stdout: tokio::sync::Mutex>, + stdin: tokio::sync::Mutex>, + /// Counter to separate same sized walredo inputs failing at the same millisecond. + #[cfg(feature = "testing")] + dump_sequence: AtomicUsize, +} + +struct ProcessInput { + stdin: tokio::process::ChildStdin, + n_requests: usize, +} + +struct ProcessOutput { + stdout: tokio::process::ChildStdout, + pending_responses: VecDeque>, + n_processed_responses: usize, +} + +impl WalRedoProcess { + // + // Start postgres binary in special WAL redo mode. + // + #[instrument(skip_all,fields(pg_version=pg_version))] + pub(crate) fn launch( + conf: &'static PageServerConf, + tenant_shard_id: TenantShardId, + pg_version: u32, + ) -> anyhow::Result { + crate::span::debug_assert_current_span_has_tenant_id(); + + let pg_bin_dir_path = conf.pg_bin_dir(pg_version).context("pg_bin_dir")?; // TODO these should be infallible. + let pg_lib_dir_path = conf.pg_lib_dir(pg_version).context("pg_lib_dir")?; + + use no_leak_child::NoLeakChildCommandExt; + // Start postgres itself + let child = Command::new(pg_bin_dir_path.join("postgres")) + // the first arg must be --wal-redo so the child process enters into walredo mode + .arg("--wal-redo") + // the child doesn't process this arg, but, having it in the argv helps indentify the + // walredo process for a particular tenant when debugging a pagserver + .args(["--tenant-shard-id", &format!("{tenant_shard_id}")]) + .stdin(Stdio::piped()) + .stderr(Stdio::piped()) + .stdout(Stdio::piped()) + .env_clear() + .env("LD_LIBRARY_PATH", &pg_lib_dir_path) + .env("DYLD_LIBRARY_PATH", &pg_lib_dir_path) + // NB: The redo process is not trusted after we sent it the first + // walredo work. Before that, it is trusted. Specifically, we trust + // it to + // 1. close all file descriptors except stdin, stdout, stderr because + // pageserver might not be 100% diligent in setting FD_CLOEXEC on all + // the files it opens, and + // 2. to use seccomp to sandbox itself before processing the first + // walredo request. + .spawn_no_leak_child(tenant_shard_id) + .context("spawn process")?; + WAL_REDO_PROCESS_COUNTERS.started.inc(); + let mut child = scopeguard::guard(child, |child| { + error!("killing wal-redo-postgres process due to a problem during launch"); + child.kill_and_wait(WalRedoKillCause::Startup); + }); + + let stdin = child.stdin.take().unwrap(); + let stdout = child.stdout.take().unwrap(); + let stderr = child.stderr.take().unwrap(); + let stderr = tokio::process::ChildStderr::from_std(stderr) + .context("convert to tokio::ChildStderr")?; + let stdin = + tokio::process::ChildStdin::from_std(stdin).context("convert to tokio::ChildStdin")?; + let stdout = tokio::process::ChildStdout::from_std(stdout) + .context("convert to tokio::ChildStdout")?; + + // all fallible operations post-spawn are complete, so get rid of the guard + let child = scopeguard::ScopeGuard::into_inner(child); + + tokio::spawn( + async move { + scopeguard::defer! { + debug!("wal-redo-postgres stderr_logger_task finished"); + crate::metrics::WAL_REDO_PROCESS_COUNTERS.active_stderr_logger_tasks_finished.inc(); + } + debug!("wal-redo-postgres stderr_logger_task started"); + crate::metrics::WAL_REDO_PROCESS_COUNTERS.active_stderr_logger_tasks_started.inc(); + + use tokio::io::AsyncBufReadExt; + let mut stderr_lines = tokio::io::BufReader::new(stderr); + let mut buf = Vec::new(); + let res = loop { + buf.clear(); + // TODO we don't trust the process to cap its stderr length. + // Currently it can do unbounded Vec allocation. + match stderr_lines.read_until(b'\n', &mut buf).await { + Ok(0) => break Ok(()), // eof + Ok(num_bytes) => { + let output = String::from_utf8_lossy(&buf[..num_bytes]); + error!(%output, "received output"); + } + Err(e) => { + break Err(e); + } + } + }; + match res { + Ok(()) => (), + Err(e) => { + error!(error=?e, "failed to read from walredo stderr"); + } + } + }.instrument(tracing::info_span!(parent: None, "wal-redo-postgres-stderr", pid = child.id(), tenant_id = %tenant_shard_id.tenant_id, shard_id = %tenant_shard_id.shard_slug(), %pg_version)) + ); + + Ok(Self { + conf, + tenant_shard_id, + child: Some(child), + stdin: tokio::sync::Mutex::new(Poison::new( + "stdin", + ProcessInput { + stdin, + n_requests: 0, + }, + )), + stdout: tokio::sync::Mutex::new(Poison::new( + "stdout", + ProcessOutput { + stdout, + pending_responses: VecDeque::new(), + n_processed_responses: 0, + }, + )), + #[cfg(feature = "testing")] + dump_sequence: AtomicUsize::default(), + }) + } + + pub(crate) fn id(&self) -> u32 { + self.child + .as_ref() + .expect("must not call this during Drop") + .id() + } + + /// Apply given WAL records ('records') over an old page image. Returns + /// new page image. + /// + /// # Cancel-Safety + /// + /// Cancellation safe. + #[instrument(skip_all, fields(tenant_id=%self.tenant_shard_id.tenant_id, shard_id=%self.tenant_shard_id.shard_slug(), pid=%self.id()))] + pub(crate) async fn apply_wal_records( + &self, + rel: RelTag, + blknum: u32, + base_img: &Option, + records: &[(Lsn, NeonWalRecord)], + wal_redo_timeout: Duration, + ) -> anyhow::Result { + let tag = protocol::BufferTag { rel, blknum }; + + // Serialize all the messages to send the WAL redo process first. + // + // This could be problematic if there are millions of records to replay, + // but in practice the number of records is usually so small that it doesn't + // matter, and it's better to keep this code simple. + // + // Most requests start with a before-image with BLCKSZ bytes, followed by + // by some other WAL records. Start with a buffer that can hold that + // comfortably. + let mut writebuf: Vec = Vec::with_capacity((BLCKSZ as usize) * 3); + protocol::build_begin_redo_for_block_msg(tag, &mut writebuf); + if let Some(img) = base_img { + protocol::build_push_page_msg(tag, img, &mut writebuf); + } + for (lsn, rec) in records.iter() { + if let NeonWalRecord::Postgres { + will_init: _, + rec: postgres_rec, + } = rec + { + protocol::build_apply_record_msg(*lsn, postgres_rec, &mut writebuf); + } else { + anyhow::bail!("tried to pass neon wal record to postgres WAL redo"); + } + } + protocol::build_get_page_msg(tag, &mut writebuf); + WAL_REDO_RECORD_COUNTER.inc_by(records.len() as u64); + + let Ok(res) = + tokio::time::timeout(wal_redo_timeout, self.apply_wal_records0(&writebuf)).await + else { + anyhow::bail!("WAL redo timed out"); + }; + + if res.is_err() { + // not all of these can be caused by this particular input, however these are so rare + // in tests so capture all. + self.record_and_log(&writebuf); + } + + res + } + + /// # Cancel-Safety + /// + /// When not polled to completion (e.g. because in `tokio::select!` another + /// branch becomes ready before this future), concurrent and subsequent + /// calls may fail due to [`utils::poison::Poison::check_and_arm`] calls. + /// Dispose of this process instance and create a new one. + async fn apply_wal_records0(&self, writebuf: &[u8]) -> anyhow::Result { + let request_no = { + let mut lock_guard = self.stdin.lock().await; + let mut poison_guard = lock_guard.check_and_arm()?; + let input = poison_guard.data_mut(); + input + .stdin + .write_all(writebuf) + .await + .context("write to walredo stdin")?; + let request_no = input.n_requests; + input.n_requests += 1; + poison_guard.disarm(); + request_no + }; + + // To improve walredo performance we separate sending requests and receiving + // responses. Them are protected by different mutexes (output and input). + // If thread T1, T2, T3 send requests D1, D2, D3 to walredo process + // then there is not warranty that T1 will first granted output mutex lock. + // To address this issue we maintain number of sent requests, number of processed + // responses and ring buffer with pending responses. After sending response + // (under input mutex), threads remembers request number. Then it releases + // input mutex, locks output mutex and fetch in ring buffer all responses until + // its stored request number. The it takes correspondent element from + // pending responses ring buffer and truncate all empty elements from the front, + // advancing processed responses number. + + let mut lock_guard = self.stdout.lock().await; + let mut poison_guard = lock_guard.check_and_arm()?; + let output = poison_guard.data_mut(); + let n_processed_responses = output.n_processed_responses; + while n_processed_responses + output.pending_responses.len() <= request_no { + // We expect the WAL redo process to respond with an 8k page image. We read it + // into this buffer. + let mut resultbuf = vec![0; BLCKSZ.into()]; + output + .stdout + .read_exact(&mut resultbuf) + .await + .context("read walredo stdout")?; + output + .pending_responses + .push_back(Some(Bytes::from(resultbuf))); + } + // Replace our request's response with None in `pending_responses`. + // Then make space in the ring buffer by clearing out any seqence of contiguous + // `None`'s from the front of `pending_responses`. + // NB: We can't pop_front() because other requests' responses because another + // requester might have grabbed the output mutex before us: + // T1: grab input mutex + // T1: send request_no 23 + // T1: release input mutex + // T2: grab input mutex + // T2: send request_no 24 + // T2: release input mutex + // T2: grab output mutex + // T2: n_processed_responses + output.pending_responses.len() <= request_no + // 23 0 24 + // T2: enters poll loop that reads stdout + // T2: put response for 23 into pending_responses + // T2: put response for 24 into pending_resposnes + // pending_responses now looks like this: Front Some(response_23) Some(response_24) Back + // T2: takes its response_24 + // pending_responses now looks like this: Front Some(response_23) None Back + // T2: does the while loop below + // pending_responses now looks like this: Front Some(response_23) None Back + // T2: releases output mutex + // T1: grabs output mutex + // T1: n_processed_responses + output.pending_responses.len() > request_no + // 23 2 23 + // T1: skips poll loop that reads stdout + // T1: takes its response_23 + // pending_responses now looks like this: Front None None Back + // T2: does the while loop below + // pending_responses now looks like this: Front Back + // n_processed_responses now has value 25 + let res = output.pending_responses[request_no - n_processed_responses] + .take() + .expect("we own this request_no, nobody else is supposed to take it"); + while let Some(front) = output.pending_responses.front() { + if front.is_none() { + output.pending_responses.pop_front(); + output.n_processed_responses += 1; + } else { + break; + } + } + poison_guard.disarm(); + Ok(res) + } + + #[cfg(feature = "testing")] + fn record_and_log(&self, writebuf: &[u8]) { + use std::sync::atomic::Ordering; + + let millis = std::time::SystemTime::now() + .duration_since(std::time::SystemTime::UNIX_EPOCH) + .unwrap() + .as_millis(); + + let seq = self.dump_sequence.fetch_add(1, Ordering::Relaxed); + + // these files will be collected to an allure report + let filename = format!("walredo-{millis}-{}-{seq}.walredo", writebuf.len()); + + let path = self.conf.tenant_path(&self.tenant_shard_id).join(&filename); + + use std::io::Write; + let res = std::fs::OpenOptions::new() + .write(true) + .create_new(true) + .read(true) + .open(path) + .and_then(|mut f| f.write_all(writebuf)); + + // trip up allowed_errors + if let Err(e) = res { + tracing::error!(target=%filename, length=writebuf.len(), "failed to write out the walredo errored input: {e}"); + } else { + tracing::error!(filename, "erroring walredo input saved"); + } + } + + #[cfg(not(feature = "testing"))] + fn record_and_log(&self, _: &[u8]) {} +} + +impl Drop for WalRedoProcess { + fn drop(&mut self) { + self.child + .take() + .expect("we only do this once") + .kill_and_wait(WalRedoKillCause::WalRedoProcessDrop); + // no way to wait for stderr_logger_task from Drop because that is async only + } +} diff --git a/pageserver/src/walredo/process/process_impl/process_std.rs b/pageserver/src/walredo/process/process_impl/process_std.rs new file mode 100644 index 0000000000..e7a6c263c9 --- /dev/null +++ b/pageserver/src/walredo/process/process_impl/process_std.rs @@ -0,0 +1,405 @@ +use self::no_leak_child::NoLeakChild; +use crate::{ + config::PageServerConf, + metrics::{WalRedoKillCause, WAL_REDO_PROCESS_COUNTERS, WAL_REDO_RECORD_COUNTER}, + walrecord::NeonWalRecord, + walredo::process::{no_leak_child, protocol}, +}; +use anyhow::Context; +use bytes::Bytes; +use nix::poll::{PollFd, PollFlags}; +use pageserver_api::{reltag::RelTag, shard::TenantShardId}; +use postgres_ffi::BLCKSZ; +use std::os::fd::AsRawFd; +#[cfg(feature = "testing")] +use std::sync::atomic::AtomicUsize; +use std::{ + collections::VecDeque, + io::{Read, Write}, + process::{ChildStdin, ChildStdout, Command, Stdio}, + sync::{Mutex, MutexGuard}, + time::Duration, +}; +use tracing::{debug, error, instrument, Instrument}; +use utils::{lsn::Lsn, nonblock::set_nonblock}; + +pub struct WalRedoProcess { + #[allow(dead_code)] + conf: &'static PageServerConf, + tenant_shard_id: TenantShardId, + // Some() on construction, only becomes None on Drop. + child: Option, + stdout: Mutex, + stdin: Mutex, + /// Counter to separate same sized walredo inputs failing at the same millisecond. + #[cfg(feature = "testing")] + dump_sequence: AtomicUsize, +} + +struct ProcessInput { + stdin: ChildStdin, + n_requests: usize, +} + +struct ProcessOutput { + stdout: ChildStdout, + pending_responses: VecDeque>, + n_processed_responses: usize, +} + +impl WalRedoProcess { + // + // Start postgres binary in special WAL redo mode. + // + #[instrument(skip_all,fields(pg_version=pg_version))] + pub(crate) fn launch( + conf: &'static PageServerConf, + tenant_shard_id: TenantShardId, + pg_version: u32, + ) -> anyhow::Result { + crate::span::debug_assert_current_span_has_tenant_id(); + + let pg_bin_dir_path = conf.pg_bin_dir(pg_version).context("pg_bin_dir")?; // TODO these should be infallible. + let pg_lib_dir_path = conf.pg_lib_dir(pg_version).context("pg_lib_dir")?; + + use no_leak_child::NoLeakChildCommandExt; + // Start postgres itself + let child = Command::new(pg_bin_dir_path.join("postgres")) + // the first arg must be --wal-redo so the child process enters into walredo mode + .arg("--wal-redo") + // the child doesn't process this arg, but, having it in the argv helps indentify the + // walredo process for a particular tenant when debugging a pagserver + .args(["--tenant-shard-id", &format!("{tenant_shard_id}")]) + .stdin(Stdio::piped()) + .stderr(Stdio::piped()) + .stdout(Stdio::piped()) + .env_clear() + .env("LD_LIBRARY_PATH", &pg_lib_dir_path) + .env("DYLD_LIBRARY_PATH", &pg_lib_dir_path) + // NB: The redo process is not trusted after we sent it the first + // walredo work. Before that, it is trusted. Specifically, we trust + // it to + // 1. close all file descriptors except stdin, stdout, stderr because + // pageserver might not be 100% diligent in setting FD_CLOEXEC on all + // the files it opens, and + // 2. to use seccomp to sandbox itself before processing the first + // walredo request. + .spawn_no_leak_child(tenant_shard_id) + .context("spawn process")?; + WAL_REDO_PROCESS_COUNTERS.started.inc(); + let mut child = scopeguard::guard(child, |child| { + error!("killing wal-redo-postgres process due to a problem during launch"); + child.kill_and_wait(WalRedoKillCause::Startup); + }); + + let stdin = child.stdin.take().unwrap(); + let stdout = child.stdout.take().unwrap(); + let stderr = child.stderr.take().unwrap(); + let stderr = tokio::process::ChildStderr::from_std(stderr) + .context("convert to tokio::ChildStderr")?; + macro_rules! set_nonblock_or_log_err { + ($file:ident) => {{ + let res = set_nonblock($file.as_raw_fd()); + if let Err(e) = &res { + error!(error = %e, file = stringify!($file), pid = child.id(), "set_nonblock failed"); + } + res + }}; + } + set_nonblock_or_log_err!(stdin)?; + set_nonblock_or_log_err!(stdout)?; + + // all fallible operations post-spawn are complete, so get rid of the guard + let child = scopeguard::ScopeGuard::into_inner(child); + + tokio::spawn( + async move { + scopeguard::defer! { + debug!("wal-redo-postgres stderr_logger_task finished"); + crate::metrics::WAL_REDO_PROCESS_COUNTERS.active_stderr_logger_tasks_finished.inc(); + } + debug!("wal-redo-postgres stderr_logger_task started"); + crate::metrics::WAL_REDO_PROCESS_COUNTERS.active_stderr_logger_tasks_started.inc(); + + use tokio::io::AsyncBufReadExt; + let mut stderr_lines = tokio::io::BufReader::new(stderr); + let mut buf = Vec::new(); + let res = loop { + buf.clear(); + // TODO we don't trust the process to cap its stderr length. + // Currently it can do unbounded Vec allocation. + match stderr_lines.read_until(b'\n', &mut buf).await { + Ok(0) => break Ok(()), // eof + Ok(num_bytes) => { + let output = String::from_utf8_lossy(&buf[..num_bytes]); + error!(%output, "received output"); + } + Err(e) => { + break Err(e); + } + } + }; + match res { + Ok(()) => (), + Err(e) => { + error!(error=?e, "failed to read from walredo stderr"); + } + } + }.instrument(tracing::info_span!(parent: None, "wal-redo-postgres-stderr", pid = child.id(), tenant_id = %tenant_shard_id.tenant_id, shard_id = %tenant_shard_id.shard_slug(), %pg_version)) + ); + + Ok(Self { + conf, + tenant_shard_id, + child: Some(child), + stdin: Mutex::new(ProcessInput { + stdin, + n_requests: 0, + }), + stdout: Mutex::new(ProcessOutput { + stdout, + pending_responses: VecDeque::new(), + n_processed_responses: 0, + }), + #[cfg(feature = "testing")] + dump_sequence: AtomicUsize::default(), + }) + } + + pub(crate) fn id(&self) -> u32 { + self.child + .as_ref() + .expect("must not call this during Drop") + .id() + } + + // Apply given WAL records ('records') over an old page image. Returns + // new page image. + // + #[instrument(skip_all, fields(tenant_id=%self.tenant_shard_id.tenant_id, shard_id=%self.tenant_shard_id.shard_slug(), pid=%self.id()))] + pub(crate) async fn apply_wal_records( + &self, + rel: RelTag, + blknum: u32, + base_img: &Option, + records: &[(Lsn, NeonWalRecord)], + wal_redo_timeout: Duration, + ) -> anyhow::Result { + let tag = protocol::BufferTag { rel, blknum }; + let input = self.stdin.lock().unwrap(); + + // Serialize all the messages to send the WAL redo process first. + // + // This could be problematic if there are millions of records to replay, + // but in practice the number of records is usually so small that it doesn't + // matter, and it's better to keep this code simple. + // + // Most requests start with a before-image with BLCKSZ bytes, followed by + // by some other WAL records. Start with a buffer that can hold that + // comfortably. + let mut writebuf: Vec = Vec::with_capacity((BLCKSZ as usize) * 3); + protocol::build_begin_redo_for_block_msg(tag, &mut writebuf); + if let Some(img) = base_img { + protocol::build_push_page_msg(tag, img, &mut writebuf); + } + for (lsn, rec) in records.iter() { + if let NeonWalRecord::Postgres { + will_init: _, + rec: postgres_rec, + } = rec + { + protocol::build_apply_record_msg(*lsn, postgres_rec, &mut writebuf); + } else { + anyhow::bail!("tried to pass neon wal record to postgres WAL redo"); + } + } + protocol::build_get_page_msg(tag, &mut writebuf); + WAL_REDO_RECORD_COUNTER.inc_by(records.len() as u64); + + let res = self.apply_wal_records0(&writebuf, input, wal_redo_timeout); + + if res.is_err() { + // not all of these can be caused by this particular input, however these are so rare + // in tests so capture all. + self.record_and_log(&writebuf); + } + + res + } + + fn apply_wal_records0( + &self, + writebuf: &[u8], + input: MutexGuard, + wal_redo_timeout: Duration, + ) -> anyhow::Result { + let mut proc = { input }; // TODO: remove this legacy rename, but this keep the patch small. + let mut nwrite = 0usize; + + while nwrite < writebuf.len() { + let mut stdin_pollfds = [PollFd::new(&proc.stdin, PollFlags::POLLOUT)]; + let n = loop { + match nix::poll::poll(&mut stdin_pollfds[..], wal_redo_timeout.as_millis() as i32) { + Err(nix::errno::Errno::EINTR) => continue, + res => break res, + } + }?; + + if n == 0 { + anyhow::bail!("WAL redo timed out"); + } + + // If 'stdin' is writeable, do write. + let in_revents = stdin_pollfds[0].revents().unwrap(); + if in_revents & (PollFlags::POLLERR | PollFlags::POLLOUT) != PollFlags::empty() { + nwrite += proc.stdin.write(&writebuf[nwrite..])?; + } + if in_revents.contains(PollFlags::POLLHUP) { + // We still have more data to write, but the process closed the pipe. + anyhow::bail!("WAL redo process closed its stdin unexpectedly"); + } + } + let request_no = proc.n_requests; + proc.n_requests += 1; + drop(proc); + + // To improve walredo performance we separate sending requests and receiving + // responses. Them are protected by different mutexes (output and input). + // If thread T1, T2, T3 send requests D1, D2, D3 to walredo process + // then there is not warranty that T1 will first granted output mutex lock. + // To address this issue we maintain number of sent requests, number of processed + // responses and ring buffer with pending responses. After sending response + // (under input mutex), threads remembers request number. Then it releases + // input mutex, locks output mutex and fetch in ring buffer all responses until + // its stored request number. The it takes correspondent element from + // pending responses ring buffer and truncate all empty elements from the front, + // advancing processed responses number. + + let mut output = self.stdout.lock().unwrap(); + let n_processed_responses = output.n_processed_responses; + while n_processed_responses + output.pending_responses.len() <= request_no { + // We expect the WAL redo process to respond with an 8k page image. We read it + // into this buffer. + let mut resultbuf = vec![0; BLCKSZ.into()]; + let mut nresult: usize = 0; // # of bytes read into 'resultbuf' so far + while nresult < BLCKSZ.into() { + let mut stdout_pollfds = [PollFd::new(&output.stdout, PollFlags::POLLIN)]; + // We do two things simultaneously: reading response from stdout + // and forward any logging information that the child writes to its stderr to the page server's log. + let n = loop { + match nix::poll::poll( + &mut stdout_pollfds[..], + wal_redo_timeout.as_millis() as i32, + ) { + Err(nix::errno::Errno::EINTR) => continue, + res => break res, + } + }?; + + if n == 0 { + anyhow::bail!("WAL redo timed out"); + } + + // If we have some data in stdout, read it to the result buffer. + let out_revents = stdout_pollfds[0].revents().unwrap(); + if out_revents & (PollFlags::POLLERR | PollFlags::POLLIN) != PollFlags::empty() { + nresult += output.stdout.read(&mut resultbuf[nresult..])?; + } + if out_revents.contains(PollFlags::POLLHUP) { + anyhow::bail!("WAL redo process closed its stdout unexpectedly"); + } + } + output + .pending_responses + .push_back(Some(Bytes::from(resultbuf))); + } + // Replace our request's response with None in `pending_responses`. + // Then make space in the ring buffer by clearing out any seqence of contiguous + // `None`'s from the front of `pending_responses`. + // NB: We can't pop_front() because other requests' responses because another + // requester might have grabbed the output mutex before us: + // T1: grab input mutex + // T1: send request_no 23 + // T1: release input mutex + // T2: grab input mutex + // T2: send request_no 24 + // T2: release input mutex + // T2: grab output mutex + // T2: n_processed_responses + output.pending_responses.len() <= request_no + // 23 0 24 + // T2: enters poll loop that reads stdout + // T2: put response for 23 into pending_responses + // T2: put response for 24 into pending_resposnes + // pending_responses now looks like this: Front Some(response_23) Some(response_24) Back + // T2: takes its response_24 + // pending_responses now looks like this: Front Some(response_23) None Back + // T2: does the while loop below + // pending_responses now looks like this: Front Some(response_23) None Back + // T2: releases output mutex + // T1: grabs output mutex + // T1: n_processed_responses + output.pending_responses.len() > request_no + // 23 2 23 + // T1: skips poll loop that reads stdout + // T1: takes its response_23 + // pending_responses now looks like this: Front None None Back + // T2: does the while loop below + // pending_responses now looks like this: Front Back + // n_processed_responses now has value 25 + let res = output.pending_responses[request_no - n_processed_responses] + .take() + .expect("we own this request_no, nobody else is supposed to take it"); + while let Some(front) = output.pending_responses.front() { + if front.is_none() { + output.pending_responses.pop_front(); + output.n_processed_responses += 1; + } else { + break; + } + } + Ok(res) + } + + #[cfg(feature = "testing")] + fn record_and_log(&self, writebuf: &[u8]) { + use std::sync::atomic::Ordering; + + let millis = std::time::SystemTime::now() + .duration_since(std::time::SystemTime::UNIX_EPOCH) + .unwrap() + .as_millis(); + + let seq = self.dump_sequence.fetch_add(1, Ordering::Relaxed); + + // these files will be collected to an allure report + let filename = format!("walredo-{millis}-{}-{seq}.walredo", writebuf.len()); + + let path = self.conf.tenant_path(&self.tenant_shard_id).join(&filename); + + let res = std::fs::OpenOptions::new() + .write(true) + .create_new(true) + .read(true) + .open(path) + .and_then(|mut f| f.write_all(writebuf)); + + // trip up allowed_errors + if let Err(e) = res { + tracing::error!(target=%filename, length=writebuf.len(), "failed to write out the walredo errored input: {e}"); + } else { + tracing::error!(filename, "erroring walredo input saved"); + } + } + + #[cfg(not(feature = "testing"))] + fn record_and_log(&self, _: &[u8]) {} +} + +impl Drop for WalRedoProcess { + fn drop(&mut self) { + self.child + .take() + .expect("we only do this once") + .kill_and_wait(WalRedoKillCause::WalRedoProcessDrop); + // no way to wait for stderr_logger_task from Drop because that is async only + } +} diff --git a/poetry.lock b/poetry.lock index 7b49daf42a..aca88073a8 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. [[package]] name = "aiohttp" @@ -1191,13 +1191,13 @@ files = [ [[package]] name = "idna" -version = "3.3" +version = "3.7" description = "Internationalized Domain Names in Applications (IDNA)" optional = false python-versions = ">=3.5" files = [ - {file = "idna-3.3-py3-none-any.whl", hash = "sha256:84d9dd047ffa80596e0f246e2eab0b391788b0503584e8945f2368256d2735ff"}, - {file = "idna-3.3.tar.gz", hash = "sha256:9d643ff0a55b762d5cdb124b8eaa99c66322e2157b69160bc32796e824360e6d"}, + {file = "idna-3.7-py3-none-any.whl", hash = "sha256:82fee1fc78add43492d3a1898bfa6d8a904cc97d8427f683ed8e798d07761aa0"}, + {file = "idna-3.7.tar.gz", hash = "sha256:028ff3aadf0609c1fd278d8ea3089299412a7a8b9bd005dd08b9f8285bcb5cfc"}, ] [[package]] @@ -2182,6 +2182,7 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -2652,6 +2653,16 @@ files = [ {file = "wrapt-1.14.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:8ad85f7f4e20964db4daadcab70b47ab05c7c1cf2a7c1e51087bfaa83831854c"}, {file = "wrapt-1.14.1-cp310-cp310-win32.whl", hash = "sha256:a9a52172be0b5aae932bef82a79ec0a0ce87288c7d132946d645eba03f0ad8a8"}, {file = "wrapt-1.14.1-cp310-cp310-win_amd64.whl", hash = "sha256:6d323e1554b3d22cfc03cd3243b5bb815a51f5249fdcbb86fda4bf62bab9e164"}, + {file = "wrapt-1.14.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ecee4132c6cd2ce5308e21672015ddfed1ff975ad0ac8d27168ea82e71413f55"}, + {file = "wrapt-1.14.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2020f391008ef874c6d9e208b24f28e31bcb85ccff4f335f15a3251d222b92d9"}, + {file = "wrapt-1.14.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2feecf86e1f7a86517cab34ae6c2f081fd2d0dac860cb0c0ded96d799d20b335"}, + {file = "wrapt-1.14.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:240b1686f38ae665d1b15475966fe0472f78e71b1b4903c143a842659c8e4cb9"}, + {file = "wrapt-1.14.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a9008dad07d71f68487c91e96579c8567c98ca4c3881b9b113bc7b33e9fd78b8"}, + {file = "wrapt-1.14.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:6447e9f3ba72f8e2b985a1da758767698efa72723d5b59accefd716e9e8272bf"}, + {file = "wrapt-1.14.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:acae32e13a4153809db37405f5eba5bac5fbe2e2ba61ab227926a22901051c0a"}, + {file = "wrapt-1.14.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:49ef582b7a1152ae2766557f0550a9fcbf7bbd76f43fbdc94dd3bf07cc7168be"}, + {file = "wrapt-1.14.1-cp311-cp311-win32.whl", hash = "sha256:358fe87cc899c6bb0ddc185bf3dbfa4ba646f05b1b0b9b5a27c2cb92c2cea204"}, + {file = "wrapt-1.14.1-cp311-cp311-win_amd64.whl", hash = "sha256:26046cd03936ae745a502abf44dac702a5e6880b2b01c29aea8ddf3353b68224"}, {file = "wrapt-1.14.1-cp35-cp35m-manylinux1_i686.whl", hash = "sha256:43ca3bbbe97af00f49efb06e352eae40434ca9d915906f77def219b88e85d907"}, {file = "wrapt-1.14.1-cp35-cp35m-manylinux1_x86_64.whl", hash = "sha256:6b1a564e6cb69922c7fe3a678b9f9a3c54e72b469875aa8018f18b4d1dd1adf3"}, {file = "wrapt-1.14.1-cp35-cp35m-manylinux2010_i686.whl", hash = "sha256:00b6d4ea20a906c0ca56d84f93065b398ab74b927a7a3dbd470f6fc503f95dc3"}, diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index 12bd67ea36..6b8f2ecbf4 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -44,6 +44,7 @@ ipnet.workspace = true itertools.workspace = true lasso = { workspace = true, features = ["multi-threaded"] } md5.workspace = true +measured = { workspace = true, features = ["lasso"] } metrics.workspace = true once_cell.workspace = true opentelemetry.workspace = true diff --git a/proxy/src/auth/backend.rs b/proxy/src/auth/backend.rs index e421798067..3795e3b608 100644 --- a/proxy/src/auth/backend.rs +++ b/proxy/src/auth/backend.rs @@ -2,8 +2,15 @@ mod classic; mod hacks; mod link; +use std::net::IpAddr; +use std::sync::Arc; +use std::time::Duration; + +use ipnet::{Ipv4Net, Ipv6Net}; pub use link::LinkAuthError; +use tokio::io::{AsyncRead, AsyncWrite}; use tokio_postgres::config::AuthKeys; +use tracing::{info, warn}; use crate::auth::credentials::check_peer_addr_is_in_list; use crate::auth::validate_password_and_exchange; @@ -13,9 +20,10 @@ use crate::console::provider::{CachedRoleSecret, ConsoleBackend}; use crate::console::{AuthSecret, NodeInfo}; use crate::context::RequestMonitoring; use crate::intern::EndpointIdInt; -use crate::metrics::{AUTH_RATE_LIMIT_HITS, ENDPOINTS_AUTH_RATE_LIMITED}; +use crate::metrics::Metrics; use crate::proxy::connect_compute::ComputeConnectBackend; use crate::proxy::NeonOptions; +use crate::rate_limiter::{BucketRateLimiter, RateBucketInfo}; use crate::stream::Stream; use crate::{ auth::{self, ComputeUserInfoMaybeEndpoint}, @@ -27,10 +35,7 @@ use crate::{ }, stream, url, }; -use crate::{scram, EndpointCacheKey, EndpointId, RoleName}; -use std::sync::Arc; -use tokio::io::{AsyncRead, AsyncWrite}; -use tracing::{info, warn}; +use crate::{scram, EndpointCacheKey, EndpointId, Normalize, RoleName}; /// Alternative to [`std::borrow::Cow`] but doesn't need `T: ToOwned` as we don't need that functionality pub enum MaybeOwned<'a, T> { @@ -176,17 +181,51 @@ impl TryFrom for ComputeUserInfo { } } +#[derive(PartialEq, PartialOrd, Hash, Eq, Ord, Debug, Copy, Clone)] +pub struct MaskedIp(IpAddr); + +impl MaskedIp { + fn new(value: IpAddr, prefix: u8) -> Self { + match value { + IpAddr::V4(v4) => Self(IpAddr::V4( + Ipv4Net::new(v4, prefix).map_or(v4, |x| x.trunc().addr()), + )), + IpAddr::V6(v6) => Self(IpAddr::V6( + Ipv6Net::new(v6, prefix).map_or(v6, |x| x.trunc().addr()), + )), + } + } +} + +// This can't be just per IP because that would limit some PaaS that share IP addresses +pub type AuthRateLimiter = BucketRateLimiter<(EndpointIdInt, MaskedIp)>; + +impl RateBucketInfo { + /// All of these are per endpoint-maskedip pair. + /// Context: 4096 rounds of pbkdf2 take about 1ms of cpu time to execute (1 milli-cpu-second or 1mcpus). + /// + /// First bucket: 1000mcpus total per endpoint-ip pair + /// * 4096000 requests per second with 1 hash rounds. + /// * 1000 requests per second with 4096 hash rounds. + /// * 6.8 requests per second with 600000 hash rounds. + pub const DEFAULT_AUTH_SET: [Self; 3] = [ + Self::new(1000 * 4096, Duration::from_secs(1)), + Self::new(600 * 4096, Duration::from_secs(60)), + Self::new(300 * 4096, Duration::from_secs(600)), + ]; +} + impl AuthenticationConfig { pub fn check_rate_limit( &self, - ctx: &mut RequestMonitoring, + config: &AuthenticationConfig, secret: AuthSecret, endpoint: &EndpointId, is_cleartext: bool, ) -> auth::Result { // we have validated the endpoint exists, so let's intern it. - let endpoint_int = EndpointIdInt::from(endpoint); + let endpoint_int = EndpointIdInt::from(endpoint.normalize()); // only count the full hash count if password hack or websocket flow. // in other words, if proxy needs to run the hashing @@ -201,17 +240,25 @@ impl AuthenticationConfig { 1 }; - let limit_not_exceeded = self - .rate_limiter - .check((endpoint_int, ctx.peer_addr), password_weight); + let limit_not_exceeded = self.rate_limiter.check( + ( + endpoint_int, + MaskedIp::new(ctx.peer_addr, config.rate_limit_ip_subnet), + ), + password_weight, + ); if !limit_not_exceeded { warn!( enabled = self.rate_limiter_enabled, "rate limiting authentication" ); - AUTH_RATE_LIMIT_HITS.inc(); - ENDPOINTS_AUTH_RATE_LIMITED.measure(endpoint); + Metrics::get().proxy.requests_auth_rate_limits_total.inc(); + Metrics::get() + .proxy + .endpoints_auth_rate_limits + .get_metric() + .measure(endpoint); if self.rate_limiter_enabled { return Err(auth::AuthError::too_many_connections()); @@ -267,6 +314,7 @@ async fn auth_quirks( let secret = match secret { Some(secret) => config.check_rate_limit( ctx, + config, secret, &info.endpoint, unauthenticated_password.is_some() || allow_cleartext, @@ -469,7 +517,7 @@ impl ComputeConnectBackend for BackendType<'_, ComputeCredentials, &()> { #[cfg(test)] mod tests { - use std::sync::Arc; + use std::{net::IpAddr, sync::Arc, time::Duration}; use bytes::BytesMut; use fallible_iterator::FallibleIterator; @@ -482,7 +530,7 @@ mod tests { use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt}; use crate::{ - auth::{ComputeUserInfoMaybeEndpoint, IpPattern}, + auth::{backend::MaskedIp, ComputeUserInfoMaybeEndpoint, IpPattern}, config::AuthenticationConfig, console::{ self, @@ -491,12 +539,12 @@ mod tests { }, context::RequestMonitoring, proxy::NeonOptions, - rate_limiter::{AuthRateLimiter, RateBucketInfo}, + rate_limiter::RateBucketInfo, scram::ServerSecret, stream::{PqStream, Stream}, }; - use super::auth_quirks; + use super::{auth_quirks, AuthRateLimiter}; struct Auth { ips: Vec, @@ -537,6 +585,7 @@ mod tests { scram_protocol_timeout: std::time::Duration::from_secs(5), rate_limiter_enabled: true, rate_limiter: AuthRateLimiter::new(&RateBucketInfo::DEFAULT_AUTH_SET), + rate_limit_ip_subnet: 64, }); async fn read_message(r: &mut (impl AsyncRead + Unpin), b: &mut BytesMut) -> PgMessage { @@ -548,6 +597,51 @@ mod tests { } } + #[test] + fn masked_ip() { + let ip_a = IpAddr::V4([127, 0, 0, 1].into()); + let ip_b = IpAddr::V4([127, 0, 0, 2].into()); + let ip_c = IpAddr::V4([192, 168, 1, 101].into()); + let ip_d = IpAddr::V4([192, 168, 1, 102].into()); + let ip_e = IpAddr::V6("abcd:abcd:abcd:abcd:abcd:abcd:abcd:abcd".parse().unwrap()); + let ip_f = IpAddr::V6("abcd:abcd:abcd:abcd:1234:abcd:abcd:abcd".parse().unwrap()); + + assert_ne!(MaskedIp::new(ip_a, 64), MaskedIp::new(ip_b, 64)); + assert_ne!(MaskedIp::new(ip_a, 32), MaskedIp::new(ip_b, 32)); + assert_eq!(MaskedIp::new(ip_a, 30), MaskedIp::new(ip_b, 30)); + assert_eq!(MaskedIp::new(ip_c, 30), MaskedIp::new(ip_d, 30)); + + assert_ne!(MaskedIp::new(ip_e, 128), MaskedIp::new(ip_f, 128)); + assert_eq!(MaskedIp::new(ip_e, 64), MaskedIp::new(ip_f, 64)); + } + + #[test] + fn test_default_auth_rate_limit_set() { + // these values used to exceed u32::MAX + assert_eq!( + RateBucketInfo::DEFAULT_AUTH_SET, + [ + RateBucketInfo { + interval: Duration::from_secs(1), + max_rpi: 1000 * 4096, + }, + RateBucketInfo { + interval: Duration::from_secs(60), + max_rpi: 600 * 4096 * 60, + }, + RateBucketInfo { + interval: Duration::from_secs(600), + max_rpi: 300 * 4096 * 600, + } + ] + ); + + for x in RateBucketInfo::DEFAULT_AUTH_SET { + let y = x.to_string().parse().unwrap(); + assert_eq!(x, y); + } + } + #[tokio::test] async fn auth_quirks_scram() { let (mut client, server) = tokio::io::duplex(1024); diff --git a/proxy/src/auth/credentials.rs b/proxy/src/auth/credentials.rs index 89773aa1ff..783a1a5a21 100644 --- a/proxy/src/auth/credentials.rs +++ b/proxy/src/auth/credentials.rs @@ -4,7 +4,7 @@ use crate::{ auth::password_hack::parse_endpoint_param, context::RequestMonitoring, error::{ReportableError, UserFacingError}, - metrics::NUM_CONNECTION_ACCEPTED_BY_SNI, + metrics::{Metrics, SniKind}, proxy::NeonOptions, serverless::SERVERLESS_DRIVER_SNI, EndpointId, RoleName, @@ -144,21 +144,22 @@ impl ComputeUserInfoMaybeEndpoint { ctx.set_endpoint_id(ep.clone()); } + let metrics = Metrics::get(); info!(%user, "credentials"); if sni.is_some() { info!("Connection with sni"); - NUM_CONNECTION_ACCEPTED_BY_SNI - .with_label_values(&["sni"]) - .inc(); + metrics.proxy.accepted_connections_by_sni.inc(SniKind::Sni); } else if endpoint.is_some() { - NUM_CONNECTION_ACCEPTED_BY_SNI - .with_label_values(&["no_sni"]) - .inc(); + metrics + .proxy + .accepted_connections_by_sni + .inc(SniKind::NoSni); info!("Connection without sni"); } else { - NUM_CONNECTION_ACCEPTED_BY_SNI - .with_label_values(&["password_hack"]) - .inc(); + metrics + .proxy + .accepted_connections_by_sni + .inc(SniKind::PasswordHack); info!("Connection with password hack"); } diff --git a/proxy/src/bin/pg_sni_router.rs b/proxy/src/bin/pg_sni_router.rs index c28814b1c8..7a693002a8 100644 --- a/proxy/src/bin/pg_sni_router.rs +++ b/proxy/src/bin/pg_sni_router.rs @@ -9,15 +9,13 @@ use futures::future::Either; use itertools::Itertools; use proxy::config::TlsServerEndPoint; use proxy::context::RequestMonitoring; -use proxy::proxy::run_until_cancelled; -use proxy::{BranchId, EndpointId, ProjectId}; +use proxy::proxy::{copy_bidirectional_client_compute, run_until_cancelled}; use rustls::pki_types::PrivateKeyDer; use tokio::net::TcpListener; use anyhow::{anyhow, bail, ensure, Context}; use clap::Arg; use futures::TryFutureExt; -use proxy::console::messages::MetricsAuxInfo; use proxy::stream::{PqStream, Stream}; use tokio::io::{AsyncRead, AsyncWrite}; @@ -176,7 +174,12 @@ async fn task_main( .context("failed to set socket option")?; info!(%peer_addr, "serving"); - let ctx = RequestMonitoring::new(session_id, peer_addr.ip(), "sni_router", "sni"); + let ctx = RequestMonitoring::new( + session_id, + peer_addr.ip(), + proxy::metrics::Protocol::SniRouter, + "sni", + ); handle_client(ctx, dest_suffix, tls_config, tls_server_end_point, socket).await } .unwrap_or_else(|e| { @@ -199,6 +202,7 @@ async fn task_main( const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)"; async fn ssl_handshake( + ctx: &mut RequestMonitoring, raw_stream: S, tls_config: Arc, tls_server_end_point: TlsServerEndPoint, @@ -228,7 +232,10 @@ async fn ssl_handshake( } Ok(Stream::Tls { - tls: Box::new(raw.upgrade(tls_config).await?), + tls: Box::new( + raw.upgrade(tls_config, !ctx.has_private_peer_addr()) + .await?, + ), tls_server_end_point, }) } @@ -251,7 +258,7 @@ async fn handle_client( tls_server_end_point: TlsServerEndPoint, stream: impl AsyncRead + AsyncWrite + Unpin, ) -> anyhow::Result<()> { - let tls_stream = ssl_handshake(stream, tls_config, tls_server_end_point).await?; + let mut tls_stream = ssl_handshake(&mut ctx, stream, tls_config, tls_server_end_point).await?; // Cut off first part of the SNI domain // We receive required destination details in the format of @@ -268,18 +275,15 @@ async fn handle_client( info!("destination: {}", destination); - let client = tokio::net::TcpStream::connect(destination).await?; - - let metrics_aux: MetricsAuxInfo = MetricsAuxInfo { - endpoint_id: (&EndpointId::from("")).into(), - project_id: (&ProjectId::from("")).into(), - branch_id: (&BranchId::from("")).into(), - cold_start_info: proxy::console::messages::ColdStartInfo::Unknown, - }; + let mut client = tokio::net::TcpStream::connect(destination).await?; // doesn't yet matter as pg-sni-router doesn't report analytics logs ctx.set_success(); ctx.log(); - proxy::proxy::passthrough::proxy_pass(tls_stream, client, metrics_aux).await + // Starting from here we only proxy the client's traffic. + info!("performing the proxy pass..."); + let _ = copy_bidirectional_client_compute(&mut tls_stream, &mut client).await?; + + Ok(()) } diff --git a/proxy/src/bin/proxy.rs b/proxy/src/bin/proxy.rs index 56a3ef79cd..b54f8c131c 100644 --- a/proxy/src/bin/proxy.rs +++ b/proxy/src/bin/proxy.rs @@ -7,6 +7,7 @@ use aws_config::provider_config::ProviderConfig; use aws_config::web_identity_token::WebIdentityTokenCredentialsProvider; use futures::future::Either; use proxy::auth; +use proxy::auth::backend::AuthRateLimiter; use proxy::auth::backend::MaybeOwned; use proxy::cancellation::CancelMap; use proxy::cancellation::CancellationHandler; @@ -18,11 +19,10 @@ use proxy::config::ProjectInfoCacheOptions; use proxy::console; use proxy::context::parquet::ParquetUploadArgs; use proxy::http; -use proxy::metrics::NUM_CANCELLATION_REQUESTS_SOURCE_FROM_CLIENT; -use proxy::rate_limiter::AuthRateLimiter; +use proxy::http::health_server::AppMetrics; +use proxy::metrics::Metrics; use proxy::rate_limiter::EndpointRateLimiter; use proxy::rate_limiter::RateBucketInfo; -use proxy::rate_limiter::RateLimiterConfig; use proxy::redis::cancellation_publisher::RedisPublisherClient; use proxy::redis::connection_with_credentials_provider::ConnectionWithCredentialsProvider; use proxy::redis::elasticache; @@ -42,6 +42,7 @@ use tokio::task::JoinSet; use tokio_util::sync::CancellationToken; use tracing::info; use tracing::warn; +use tracing::Instrument; use utils::{project_build_tag, project_git_version, sentry_init::init_sentry}; project_git_version!(GIT_VERSION); @@ -131,14 +132,8 @@ struct ProxyCliArgs { #[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)] require_client_ip: bool, /// Disable dynamic rate limiter and store the metrics to ensure its production behaviour. - #[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)] + #[clap(long, default_value_t = true, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)] disable_dynamic_rate_limiter: bool, - /// Rate limit algorithm. Makes sense only if `disable_rate_limiter` is `false`. - #[clap(value_enum, long, default_value_t = proxy::rate_limiter::RateLimitAlgorithm::Aimd)] - rate_limit_algorithm: proxy::rate_limiter::RateLimitAlgorithm, - /// Timeout for rate limiter. If it didn't manage to aquire a permit in this time, it will return an error. - #[clap(long, default_value = "15s", value_parser = humantime::parse_duration)] - rate_limiter_timeout: tokio::time::Duration, /// Endpoint rate limiter max number of requests per second. /// /// Provided in the form '@'. @@ -151,14 +146,12 @@ struct ProxyCliArgs { /// Authentication rate limiter max number of hashes per second. #[clap(long, default_values_t = RateBucketInfo::DEFAULT_AUTH_SET)] auth_rate_limit: Vec, + /// The IP subnet to use when considering whether two IP addresses are considered the same. + #[clap(long, default_value_t = 64)] + auth_rate_limit_ip_subnet: u8, /// Redis rate limiter max number of requests per second. #[clap(long, default_values_t = RateBucketInfo::DEFAULT_ENDPOINT_SET)] redis_rps_limit: Vec, - /// Initial limit for dynamic rate limiter. Makes sense only if `rate_limit_algorithm` is *not* `None`. - #[clap(long, default_value_t = 100)] - initial_limit: usize, - #[clap(flatten)] - aimd_config: proxy::rate_limiter::AimdConfig, /// cache for `allowed_ips` (use `size=0` to disable) #[clap(long, default_value = config::CacheOptions::CACHE_DEFAULT_OPTIONS)] allowed_ips_cache: String, @@ -189,7 +182,9 @@ struct ProxyCliArgs { /// cache for `project_info` (use `size=0` to disable) #[clap(long, default_value = config::ProjectInfoCacheOptions::CACHE_DEFAULT_OPTIONS)] project_info_cache: String, - + /// cache for all valid endpoints + #[clap(long, default_value = config::EndpointCacheConfig::CACHE_DEFAULT_OPTIONS)] + endpoint_cache_config: String, #[clap(flatten)] parquet_upload: ParquetUploadArgs, @@ -249,14 +244,18 @@ async fn main() -> anyhow::Result<()> { info!("Version: {GIT_VERSION}"); info!("Build_tag: {BUILD_TAG}"); - ::metrics::set_build_info_metric(GIT_VERSION, BUILD_TAG); + let neon_metrics = ::metrics::NeonMetrics::new(::metrics::BuildInfo { + revision: GIT_VERSION, + build_tag: BUILD_TAG, + }); - match proxy::jemalloc::MetricRecorder::new(prometheus::default_registry()) { - Ok(t) => { - t.start(); + let jemalloc = match proxy::jemalloc::MetricRecorder::new() { + Ok(t) => Some(t), + Err(e) => { + tracing::error!(error = ?e, "could not start jemalloc metrics loop"); + None } - Err(e) => tracing::error!(error = ?e, "could not start jemalloc metrics loop"), - } + }; let args = ProxyCliArgs::parse(); let config = build_config(&args)?; @@ -296,27 +295,27 @@ async fn main() -> anyhow::Result<()> { ), aws_credentials_provider, )); - let redis_notifications_client = - match (args.redis_notifications, (args.redis_host, args.redis_port)) { - (Some(url), _) => { - info!("Starting redis notifications listener ({url})"); - Some(ConnectionWithCredentialsProvider::new_with_static_credentials(url)) - } - (None, (Some(host), Some(port))) => Some( - ConnectionWithCredentialsProvider::new_with_credentials_provider( - host, - port, - elasticache_credentials_provider.clone(), - ), + let regional_redis_client = match (args.redis_host, args.redis_port) { + (Some(host), Some(port)) => Some( + ConnectionWithCredentialsProvider::new_with_credentials_provider( + host, + port, + elasticache_credentials_provider.clone(), ), - (None, (None, None)) => { - warn!("Redis is disabled"); - None - } - _ => { - bail!("redis-host and redis-port must be specified together"); - } - }; + ), + (None, None) => { + warn!("Redis events from console are disabled"); + None + } + _ => { + bail!("redis-host and redis-port must be specified together"); + } + }; + let redis_notifications_client = if let Some(url) = args.redis_notifications { + Some(ConnectionWithCredentialsProvider::new_with_static_credentials(url)) + } else { + regional_redis_client.clone() + }; // Check that we can bind to address before further initialization let http_address: SocketAddr = args.http.parse()?; @@ -332,11 +331,9 @@ async fn main() -> anyhow::Result<()> { let proxy_listener = TcpListener::bind(proxy_address).await?; let cancellation_token = CancellationToken::new(); - let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new(&config.endpoint_rps_limit)); let cancel_map = CancelMap::default(); - // let redis_notifications_client = redis_notifications_client.map(|x| Box::leak(Box::new(x))); - let redis_publisher = match &redis_notifications_client { + let redis_publisher = match ®ional_redis_client { Some(redis_publisher) => Some(Arc::new(Mutex::new(RedisPublisherClient::new( redis_publisher.clone(), args.region.clone(), @@ -349,7 +346,7 @@ async fn main() -> anyhow::Result<()> { >::new( cancel_map.clone(), redis_publisher, - NUM_CANCELLATION_REQUESTS_SOURCE_FROM_CLIENT, + proxy::metrics::CancellationSource::FromClient, )); // client facing tasks. these will exit on error or on cancellation @@ -359,7 +356,6 @@ async fn main() -> anyhow::Result<()> { config, proxy_listener, cancellation_token.clone(), - endpoint_rate_limiter.clone(), cancellation_handler.clone(), )); @@ -374,7 +370,6 @@ async fn main() -> anyhow::Result<()> { config, serverless_listener, cancellation_token.clone(), - endpoint_rate_limiter.clone(), cancellation_handler.clone(), )); } @@ -387,7 +382,14 @@ async fn main() -> anyhow::Result<()> { // maintenance tasks. these never return unless there's an error let mut maintenance_tasks = JoinSet::new(); maintenance_tasks.spawn(proxy::handle_signals(cancellation_token.clone())); - maintenance_tasks.spawn(http::health_server::task_main(http_listener)); + maintenance_tasks.spawn(http::health_server::task_main( + http_listener, + AppMetrics { + jemalloc, + neon_metrics, + proxy: proxy::metrics::Metrics::get(), + }, + )); maintenance_tasks.spawn(console::mgmt::task_main(mgmt_listener)); if let Some(metrics_config) = &config.metric_collection { @@ -404,13 +406,19 @@ async fn main() -> anyhow::Result<()> { if let Some(redis_notifications_client) = redis_notifications_client { let cache = api.caches.project_info.clone(); maintenance_tasks.spawn(notifications::task_main( - redis_notifications_client.clone(), + redis_notifications_client, cache.clone(), cancel_map.clone(), args.region.clone(), )); maintenance_tasks.spawn(async move { cache.clone().gc_worker().await }); } + if let Some(regional_redis_client) = regional_redis_client { + let cache = api.caches.endpoints_cache.clone(); + let con = regional_redis_client; + let span = tracing::info_span!("endpoints_cache"); + maintenance_tasks.spawn(async move { cache.do_read(con).await }.instrument(span)); + } } } @@ -476,27 +484,27 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { and metric-collection-interval must be specified" ), }; - let rate_limiter_config = RateLimiterConfig { - disable: args.disable_dynamic_rate_limiter, - algorithm: args.rate_limit_algorithm, - timeout: args.rate_limiter_timeout, - initial_limit: args.initial_limit, - aimd_config: Some(args.aimd_config), - }; + if !args.disable_dynamic_rate_limiter { + bail!("dynamic rate limiter should be disabled"); + } let auth_backend = match &args.auth_backend { AuthBackend::Console => { let wake_compute_cache_config: CacheOptions = args.wake_compute_cache.parse()?; let project_info_cache_config: ProjectInfoCacheOptions = args.project_info_cache.parse()?; + let endpoint_cache_config: config::EndpointCacheConfig = + args.endpoint_cache_config.parse()?; info!("Using NodeInfoCache (wake_compute) with options={wake_compute_cache_config:?}"); info!( "Using AllowedIpsCache (wake_compute) with options={project_info_cache_config:?}" ); + info!("Using EndpointCacheConfig with options={endpoint_cache_config:?}"); let caches = Box::leak(Box::new(console::caches::ApiCaches::new( wake_compute_cache_config, project_info_cache_config, + endpoint_cache_config, ))); let config::WakeComputeLockOptions { @@ -507,15 +515,26 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { } = args.wake_compute_lock.parse()?; info!(permits, shards, ?epoch, "Using NodeLocks (wake_compute)"); let locks = Box::leak(Box::new( - console::locks::ApiLocks::new("wake_compute_lock", permits, shards, timeout) - .unwrap(), + console::locks::ApiLocks::new( + "wake_compute_lock", + permits, + shards, + timeout, + epoch, + &Metrics::get().wake_compute_lock, + ) + .unwrap(), )); - tokio::spawn(locks.garbage_collect_worker(epoch)); + tokio::spawn(locks.garbage_collect_worker()); let url = args.auth_endpoint.parse()?; - let endpoint = http::Endpoint::new(url, http::new_client(rate_limiter_config)); + let endpoint = http::Endpoint::new(url, http::new_client()); - let api = console::provider::neon::Api::new(endpoint, caches, locks); + let mut endpoint_rps_limit = args.endpoint_rps_limit.clone(); + RateBucketInfo::validate(&mut endpoint_rps_limit)?; + let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new(endpoint_rps_limit)); + let api = + console::provider::neon::Api::new(endpoint, caches, locks, endpoint_rate_limiter); let api = console::provider::ConsoleBackend::Console(api); auth::BackendType::Console(MaybeOwned::Owned(api), ()) } @@ -546,10 +565,9 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { scram_protocol_timeout: args.scram_protocol_timeout, rate_limiter_enabled: args.auth_rate_limit_enabled, rate_limiter: AuthRateLimiter::new(args.auth_rate_limit.clone()), + rate_limit_ip_subnet: args.auth_rate_limit_ip_subnet, }; - let mut endpoint_rps_limit = args.endpoint_rps_limit.clone(); - RateBucketInfo::validate(&mut endpoint_rps_limit)?; let mut redis_rps_limit = args.redis_rps_limit.clone(); RateBucketInfo::validate(&mut redis_rps_limit)?; @@ -562,7 +580,6 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { authentication_config, require_client_ip: args.require_client_ip, disable_ip_check_for_http: args.disable_ip_check_for_http, - endpoint_rps_limit, redis_rps_limit, handshake_timeout: args.handshake_timeout, region: args.region.clone(), diff --git a/proxy/src/cache.rs b/proxy/src/cache.rs index fc5f416395..d1d4087241 100644 --- a/proxy/src/cache.rs +++ b/proxy/src/cache.rs @@ -1,4 +1,5 @@ pub mod common; +pub mod endpoints; pub mod project_info; mod timed_lru; diff --git a/proxy/src/cache/endpoints.rs b/proxy/src/cache/endpoints.rs new file mode 100644 index 0000000000..2aa1986d5e --- /dev/null +++ b/proxy/src/cache/endpoints.rs @@ -0,0 +1,227 @@ +use std::{ + convert::Infallible, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, +}; + +use dashmap::DashSet; +use redis::{ + streams::{StreamReadOptions, StreamReadReply}, + AsyncCommands, FromRedisValue, Value, +}; +use serde::Deserialize; +use tokio::sync::Mutex; +use tracing::info; + +use crate::{ + config::EndpointCacheConfig, + context::RequestMonitoring, + intern::{BranchIdInt, EndpointIdInt, ProjectIdInt}, + metrics::{Metrics, RedisErrors}, + rate_limiter::GlobalRateLimiter, + redis::connection_with_credentials_provider::ConnectionWithCredentialsProvider, + EndpointId, +}; + +#[derive(Deserialize, Debug, Clone)] +pub struct ControlPlaneEventKey { + endpoint_created: Option, + branch_created: Option, + project_created: Option, +} +#[derive(Deserialize, Debug, Clone)] +struct EndpointCreated { + endpoint_id: String, +} +#[derive(Deserialize, Debug, Clone)] +struct BranchCreated { + branch_id: String, +} +#[derive(Deserialize, Debug, Clone)] +struct ProjectCreated { + project_id: String, +} + +pub struct EndpointsCache { + config: EndpointCacheConfig, + endpoints: DashSet, + branches: DashSet, + projects: DashSet, + ready: AtomicBool, + limiter: Arc>, +} + +impl EndpointsCache { + pub fn new(config: EndpointCacheConfig) -> Self { + Self { + limiter: Arc::new(Mutex::new(GlobalRateLimiter::new( + config.limiter_info.clone(), + ))), + config, + endpoints: DashSet::new(), + branches: DashSet::new(), + projects: DashSet::new(), + ready: AtomicBool::new(false), + } + } + pub async fn is_valid(&self, ctx: &mut RequestMonitoring, endpoint: &EndpointId) -> bool { + if !self.ready.load(Ordering::Acquire) { + return true; + } + let rejected = self.should_reject(endpoint); + ctx.set_rejected(rejected); + info!(?rejected, "check endpoint is valid, disabled cache"); + // If cache is disabled, just collect the metrics and return or + // If the limiter allows, we don't need to check the cache. + if self.config.disable_cache || self.limiter.lock().await.check() { + return true; + } + !rejected + } + fn should_reject(&self, endpoint: &EndpointId) -> bool { + if endpoint.is_endpoint() { + !self.endpoints.contains(&EndpointIdInt::from(endpoint)) + } else if endpoint.is_branch() { + !self + .branches + .contains(&BranchIdInt::from(&endpoint.as_branch())) + } else { + !self + .projects + .contains(&ProjectIdInt::from(&endpoint.as_project())) + } + } + fn insert_event(&self, key: ControlPlaneEventKey) { + // Do not do normalization here, we expect the events to be normalized. + if let Some(endpoint_created) = key.endpoint_created { + self.endpoints + .insert(EndpointIdInt::from(&endpoint_created.endpoint_id.into())); + } + if let Some(branch_created) = key.branch_created { + self.branches + .insert(BranchIdInt::from(&branch_created.branch_id.into())); + } + if let Some(project_created) = key.project_created { + self.projects + .insert(ProjectIdInt::from(&project_created.project_id.into())); + } + } + pub async fn do_read( + &self, + mut con: ConnectionWithCredentialsProvider, + ) -> anyhow::Result { + let mut last_id = "0-0".to_string(); + loop { + self.ready.store(false, Ordering::Release); + if let Err(e) = con.connect().await { + tracing::error!("error connecting to redis: {:?}", e); + continue; + } + if let Err(e) = self.read_from_stream(&mut con, &mut last_id).await { + tracing::error!("error reading from redis: {:?}", e); + } + tokio::time::sleep(self.config.retry_interval).await; + } + } + async fn read_from_stream( + &self, + con: &mut ConnectionWithCredentialsProvider, + last_id: &mut String, + ) -> anyhow::Result<()> { + tracing::info!("reading endpoints/branches/projects from redis"); + self.batch_read( + con, + StreamReadOptions::default().count(self.config.initial_batch_size), + last_id, + true, + ) + .await?; + tracing::info!("ready to filter user requests"); + self.ready.store(true, Ordering::Release); + self.batch_read( + con, + StreamReadOptions::default() + .count(self.config.default_batch_size) + .block(self.config.xread_timeout.as_millis() as usize), + last_id, + false, + ) + .await + } + fn parse_key_value(value: &Value) -> anyhow::Result { + let s: String = FromRedisValue::from_redis_value(value)?; + Ok(serde_json::from_str(&s)?) + } + async fn batch_read( + &self, + conn: &mut ConnectionWithCredentialsProvider, + opts: StreamReadOptions, + last_id: &mut String, + return_when_finish: bool, + ) -> anyhow::Result<()> { + let mut total: usize = 0; + loop { + let mut res: StreamReadReply = conn + .xread_options(&[&self.config.stream_name], &[last_id.as_str()], &opts) + .await?; + + if res.keys.is_empty() { + if return_when_finish { + if total != 0 { + break; + } + anyhow::bail!( + "Redis stream {} is empty, cannot be used to filter endpoints", + self.config.stream_name + ); + } + // If we are not returning when finish, we should wait for more data. + continue; + } + if res.keys.len() != 1 { + anyhow::bail!("Cannot read from redis stream {}", self.config.stream_name); + } + + let res = res.keys.pop().expect("Checked length above"); + let len = res.ids.len(); + for x in res.ids { + total += 1; + for (_, v) in x.map { + let key = match Self::parse_key_value(&v) { + Ok(x) => x, + Err(e) => { + Metrics::get().proxy.redis_errors_total.inc(RedisErrors { + channel: &self.config.stream_name, + }); + tracing::error!("error parsing value {v:?}: {e:?}"); + continue; + } + }; + self.insert_event(key); + } + if total.is_power_of_two() { + tracing::debug!("endpoints read {}", total); + } + *last_id = x.id; + } + if return_when_finish && len <= self.config.default_batch_size { + break; + } + } + tracing::info!("read {} endpoints/branches/projects from redis", total); + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::ControlPlaneEventKey; + + #[test] + fn test() { + let s = "{\"branch_created\":null,\"endpoint_created\":{\"endpoint_id\":\"ep-rapid-thunder-w0qqw2q9\"},\"project_created\":null,\"type\":\"endpoint_created\"}"; + let _: ControlPlaneEventKey = serde_json::from_str(s).unwrap(); + } +} diff --git a/proxy/src/cancellation.rs b/proxy/src/cancellation.rs index 6151513614..34512e9f5b 100644 --- a/proxy/src/cancellation.rs +++ b/proxy/src/cancellation.rs @@ -10,7 +10,7 @@ use uuid::Uuid; use crate::{ error::ReportableError, - metrics::NUM_CANCELLATION_REQUESTS, + metrics::{CancellationRequest, CancellationSource, Metrics}, redis::cancellation_publisher::{ CancellationPublisher, CancellationPublisherMut, RedisPublisherClient, }, @@ -28,7 +28,7 @@ pub struct CancellationHandler

{ client: P, /// This field used for the monitoring purposes. /// Represents the source of the cancellation request. - from: &'static str, + from: CancellationSource, } #[derive(Debug, Error)] @@ -89,9 +89,13 @@ impl CancellationHandler

{ // NB: we should immediately release the lock after cloning the token. let Some(cancel_closure) = self.map.get(&key).and_then(|x| x.clone()) else { tracing::warn!("query cancellation key not found: {key}"); - NUM_CANCELLATION_REQUESTS - .with_label_values(&[self.from, "not_found"]) - .inc(); + Metrics::get() + .proxy + .cancellation_requests_total + .inc(CancellationRequest { + source: self.from, + kind: crate::metrics::CancellationOutcome::NotFound, + }); match self.client.try_publish(key, session_id).await { Ok(()) => {} // do nothing Err(e) => { @@ -103,9 +107,13 @@ impl CancellationHandler

{ } return Ok(()); }; - NUM_CANCELLATION_REQUESTS - .with_label_values(&[self.from, "found"]) - .inc(); + Metrics::get() + .proxy + .cancellation_requests_total + .inc(CancellationRequest { + source: self.from, + kind: crate::metrics::CancellationOutcome::Found, + }); info!("cancelling query per user's request using key {key}"); cancel_closure.try_cancel_query().await } @@ -122,7 +130,7 @@ impl CancellationHandler

{ } impl CancellationHandler<()> { - pub fn new(map: CancelMap, from: &'static str) -> Self { + pub fn new(map: CancelMap, from: CancellationSource) -> Self { Self { map, client: (), @@ -132,7 +140,7 @@ impl CancellationHandler<()> { } impl CancellationHandler>>> { - pub fn new(map: CancelMap, client: Option>>, from: &'static str) -> Self { + pub fn new(map: CancelMap, client: Option>>, from: CancellationSource) -> Self { Self { map, client, from } } } @@ -192,15 +200,13 @@ impl

Drop for Session

{ #[cfg(test)] mod tests { - use crate::metrics::NUM_CANCELLATION_REQUESTS_SOURCE_FROM_REDIS; - use super::*; #[tokio::test] async fn check_session_drop() -> anyhow::Result<()> { let cancellation_handler = Arc::new(CancellationHandler::<()>::new( CancelMap::default(), - NUM_CANCELLATION_REQUESTS_SOURCE_FROM_REDIS, + CancellationSource::FromRedis, )); let session = cancellation_handler.clone().get_session(); @@ -214,7 +220,7 @@ mod tests { #[tokio::test] async fn cancel_session_noop_regression() { - let handler = CancellationHandler::<()>::new(Default::default(), "local"); + let handler = CancellationHandler::<()>::new(Default::default(), CancellationSource::Local); handler .cancel_session( CancelKeyData { diff --git a/proxy/src/compute.rs b/proxy/src/compute.rs index ee33b97fbd..149a619316 100644 --- a/proxy/src/compute.rs +++ b/proxy/src/compute.rs @@ -4,12 +4,11 @@ use crate::{ console::{errors::WakeComputeError, messages::MetricsAuxInfo}, context::RequestMonitoring, error::{ReportableError, UserFacingError}, - metrics::NUM_DB_CONNECTIONS_GAUGE, + metrics::{Metrics, NumDbConnectionsGuard}, proxy::neon_option, }; use futures::{FutureExt, TryFutureExt}; use itertools::Itertools; -use metrics::IntCounterPairGuard; use pq_proto::StartupMessageParams; use std::{io, net::SocketAddr, time::Duration}; use thiserror::Error; @@ -249,7 +248,7 @@ pub struct PostgresConnection { /// Labels for proxy's metrics. pub aux: MetricsAuxInfo, - _guage: IntCounterPairGuard, + _guage: NumDbConnectionsGuard<'static>, } impl ConnCfg { @@ -295,9 +294,7 @@ impl ConnCfg { params, cancel_closure, aux, - _guage: NUM_DB_CONNECTIONS_GAUGE - .with_label_values(&[ctx.protocol]) - .guard(), + _guage: Metrics::get().proxy.db_connections.guard(ctx.protocol), }; Ok(connection) diff --git a/proxy/src/config.rs b/proxy/src/config.rs index fc490c7348..f9519c7645 100644 --- a/proxy/src/config.rs +++ b/proxy/src/config.rs @@ -1,6 +1,6 @@ use crate::{ - auth, - rate_limiter::{AuthRateLimiter, RateBucketInfo}, + auth::{self, backend::AuthRateLimiter}, + rate_limiter::RateBucketInfo, serverless::GlobalConnPoolOptions, }; use anyhow::{bail, ensure, Context, Ok}; @@ -29,7 +29,6 @@ pub struct ProxyConfig { pub authentication_config: AuthenticationConfig, pub require_client_ip: bool, pub disable_ip_check_for_http: bool, - pub endpoint_rps_limit: Vec, pub redis_rps_limit: Vec, pub region: String, pub handshake_timeout: Duration, @@ -58,6 +57,7 @@ pub struct AuthenticationConfig { pub scram_protocol_timeout: tokio::time::Duration, pub rate_limiter_enabled: bool, pub rate_limiter: AuthRateLimiter, + pub rate_limit_ip_subnet: u8, } impl TlsConfig { @@ -313,6 +313,80 @@ impl CertResolver { } } +#[derive(Debug)] +pub struct EndpointCacheConfig { + /// Batch size to receive all endpoints on the startup. + pub initial_batch_size: usize, + /// Batch size to receive endpoints. + pub default_batch_size: usize, + /// Timeouts for the stream read operation. + pub xread_timeout: Duration, + /// Stream name to read from. + pub stream_name: String, + /// Limiter info (to distinguish when to enable cache). + pub limiter_info: Vec, + /// Disable cache. + /// If true, cache is ignored, but reports all statistics. + pub disable_cache: bool, + /// Retry interval for the stream read operation. + pub retry_interval: Duration, +} + +impl EndpointCacheConfig { + /// Default options for [`crate::console::provider::NodeInfoCache`]. + /// Notice that by default the limiter is empty, which means that cache is disabled. + pub const CACHE_DEFAULT_OPTIONS: &'static str = + "initial_batch_size=1000,default_batch_size=10,xread_timeout=5m,stream_name=controlPlane,disable_cache=true,limiter_info=1000@1s,retry_interval=1s"; + + /// Parse cache options passed via cmdline. + /// Example: [`Self::CACHE_DEFAULT_OPTIONS`]. + fn parse(options: &str) -> anyhow::Result { + let mut initial_batch_size = None; + let mut default_batch_size = None; + let mut xread_timeout = None; + let mut stream_name = None; + let mut limiter_info = vec![]; + let mut disable_cache = false; + let mut retry_interval = None; + + for option in options.split(',') { + let (key, value) = option + .split_once('=') + .with_context(|| format!("bad key-value pair: {option}"))?; + + match key { + "initial_batch_size" => initial_batch_size = Some(value.parse()?), + "default_batch_size" => default_batch_size = Some(value.parse()?), + "xread_timeout" => xread_timeout = Some(humantime::parse_duration(value)?), + "stream_name" => stream_name = Some(value.to_string()), + "limiter_info" => limiter_info.push(RateBucketInfo::from_str(value)?), + "disable_cache" => disable_cache = value.parse()?, + "retry_interval" => retry_interval = Some(humantime::parse_duration(value)?), + unknown => bail!("unknown key: {unknown}"), + } + } + RateBucketInfo::validate(&mut limiter_info)?; + + Ok(Self { + initial_batch_size: initial_batch_size.context("missing `initial_batch_size`")?, + default_batch_size: default_batch_size.context("missing `default_batch_size`")?, + xread_timeout: xread_timeout.context("missing `xread_timeout`")?, + stream_name: stream_name.context("missing `stream_name`")?, + disable_cache, + limiter_info, + retry_interval: retry_interval.context("missing `retry_interval`")?, + }) + } +} + +impl FromStr for EndpointCacheConfig { + type Err = anyhow::Error; + + fn from_str(options: &str) -> Result { + let error = || format!("failed to parse endpoint cache options '{options}'"); + Self::parse(options).with_context(error) + } +} #[derive(Debug)] pub struct MetricBackupCollectionConfig { pub interval: Duration, diff --git a/proxy/src/console/messages.rs b/proxy/src/console/messages.rs index 45161f5ac8..9869b95768 100644 --- a/proxy/src/console/messages.rs +++ b/proxy/src/console/messages.rs @@ -1,3 +1,4 @@ +use measured::FixedCardinalityLabel; use serde::{Deserialize, Serialize}; use std::fmt; @@ -102,7 +103,7 @@ pub struct MetricsAuxInfo { pub cold_start_info: ColdStartInfo, } -#[derive(Debug, Default, Serialize, Deserialize, Clone, Copy)] +#[derive(Debug, Default, Serialize, Deserialize, Clone, Copy, FixedCardinalityLabel)] #[serde(rename_all = "snake_case")] pub enum ColdStartInfo { #[default] @@ -110,9 +111,11 @@ pub enum ColdStartInfo { /// Compute was already running Warm, #[serde(rename = "pool_hit")] + #[label(rename = "pool_hit")] /// Compute was not running but there was an available VM VmPoolHit, #[serde(rename = "pool_miss")] + #[label(rename = "pool_miss")] /// Compute was not running and there were no VMs available VmPoolMiss, diff --git a/proxy/src/console/provider.rs b/proxy/src/console/provider.rs index f7d621fb12..aa1800a9da 100644 --- a/proxy/src/console/provider.rs +++ b/proxy/src/console/provider.rs @@ -8,11 +8,12 @@ use crate::{ backend::{ComputeCredentialKeys, ComputeUserInfo}, IpPattern, }, - cache::{project_info::ProjectInfoCacheImpl, Cached, TimedLru}, + cache::{endpoints::EndpointsCache, project_info::ProjectInfoCacheImpl, Cached, TimedLru}, compute, - config::{CacheOptions, ProjectInfoCacheOptions}, + config::{CacheOptions, EndpointCacheConfig, ProjectInfoCacheOptions}, context::RequestMonitoring, intern::ProjectIdInt, + metrics::ApiLockMetrics, scram, EndpointCacheKey, }; use dashmap::DashMap; @@ -207,6 +208,9 @@ pub mod errors { #[error(transparent)] ApiError(ApiError), + #[error("Too many connections attempts")] + TooManyConnections, + #[error("Timeout waiting to acquire wake compute lock")] TimeoutError, } @@ -239,6 +243,8 @@ pub mod errors { // However, API might return a meaningful error. ApiError(e) => e.to_string_client(), + TooManyConnections => self.to_string(), + TimeoutError => "timeout while acquiring the compute resource lock".to_owned(), } } @@ -249,6 +255,7 @@ pub mod errors { match self { WakeComputeError::BadComputeAddress(_) => crate::error::ErrorKind::ControlPlane, WakeComputeError::ApiError(e) => e.get_error_kind(), + WakeComputeError::TooManyConnections => crate::error::ErrorKind::RateLimit, WakeComputeError::TimeoutError => crate::error::ErrorKind::ServiceRateLimit, } } @@ -416,12 +423,15 @@ pub struct ApiCaches { pub node_info: NodeInfoCache, /// Cache which stores project_id -> endpoint_ids mapping. pub project_info: Arc, + /// List of all valid endpoints. + pub endpoints_cache: Arc, } impl ApiCaches { pub fn new( wake_compute_cache_config: CacheOptions, project_info_cache_config: ProjectInfoCacheOptions, + endpoint_cache_config: EndpointCacheConfig, ) -> Self { Self { node_info: NodeInfoCache::new( @@ -431,6 +441,7 @@ impl ApiCaches { true, ), project_info: Arc::new(ProjectInfoCacheImpl::new(project_info_cache_config)), + endpoints_cache: Arc::new(EndpointsCache::new(endpoint_cache_config)), } } } @@ -441,10 +452,8 @@ pub struct ApiLocks { node_locks: DashMap>, permits: usize, timeout: Duration, - registered: prometheus::IntCounter, - unregistered: prometheus::IntCounter, - reclamation_lag: prometheus::Histogram, - lock_acquire_lag: prometheus::Histogram, + epoch: std::time::Duration, + metrics: &'static ApiLockMetrics, } impl ApiLocks { @@ -453,54 +462,16 @@ impl ApiLocks { permits: usize, shards: usize, timeout: Duration, + epoch: std::time::Duration, + metrics: &'static ApiLockMetrics, ) -> prometheus::Result { - let registered = prometheus::IntCounter::with_opts( - prometheus::Opts::new( - "semaphores_registered", - "Number of semaphores registered in this api lock", - ) - .namespace(name), - )?; - prometheus::register(Box::new(registered.clone()))?; - let unregistered = prometheus::IntCounter::with_opts( - prometheus::Opts::new( - "semaphores_unregistered", - "Number of semaphores unregistered in this api lock", - ) - .namespace(name), - )?; - prometheus::register(Box::new(unregistered.clone()))?; - let reclamation_lag = prometheus::Histogram::with_opts( - prometheus::HistogramOpts::new( - "reclamation_lag_seconds", - "Time it takes to reclaim unused semaphores in the api lock", - ) - .namespace(name) - // 1us -> 65ms - // benchmarks on my mac indicate it's usually in the range of 256us and 512us - .buckets(prometheus::exponential_buckets(1e-6, 2.0, 16)?), - )?; - prometheus::register(Box::new(reclamation_lag.clone()))?; - let lock_acquire_lag = prometheus::Histogram::with_opts( - prometheus::HistogramOpts::new( - "semaphore_acquire_seconds", - "Time it takes to reclaim unused semaphores in the api lock", - ) - .namespace(name) - // 0.1ms -> 6s - .buckets(prometheus::exponential_buckets(1e-4, 2.0, 16)?), - )?; - prometheus::register(Box::new(lock_acquire_lag.clone()))?; - Ok(Self { name, node_locks: DashMap::with_shard_amount(shards), permits, timeout, - lock_acquire_lag, - registered, - unregistered, - reclamation_lag, + epoch, + metrics, }) } @@ -520,7 +491,7 @@ impl ApiLocks { self.node_locks .entry(key.clone()) .or_insert_with(|| { - self.registered.inc(); + self.metrics.semaphores_registered.inc(); Arc::new(Semaphore::new(self.permits)) }) .clone() @@ -528,20 +499,21 @@ impl ApiLocks { }; let permit = tokio::time::timeout_at(now + self.timeout, semaphore.acquire_owned()).await; - self.lock_acquire_lag - .observe((Instant::now() - now).as_secs_f64()); + self.metrics + .semaphore_acquire_seconds + .observe(now.elapsed().as_secs_f64()); Ok(WakeComputePermit { permit: Some(permit??), }) } - pub async fn garbage_collect_worker(&self, epoch: std::time::Duration) { + pub async fn garbage_collect_worker(&self) { if self.permits == 0 { return; } - - let mut interval = tokio::time::interval(epoch / (self.node_locks.shards().len()) as u32); + let mut interval = + tokio::time::interval(self.epoch / (self.node_locks.shards().len()) as u32); loop { for (i, shard) in self.node_locks.shards().iter().enumerate() { interval.tick().await; @@ -554,13 +526,13 @@ impl ApiLocks { "performing epoch reclamation on api lock" ); let mut lock = shard.write(); - let timer = self.reclamation_lag.start_timer(); + let timer = self.metrics.reclamation_lag_seconds.start_timer(); let count = lock .extract_if(|_, semaphore| Arc::strong_count(semaphore.get_mut()) == 1) .count(); drop(lock); - self.unregistered.inc_by(count as u64); - timer.observe_duration() + self.metrics.semaphores_unregistered.inc_by(count as u64); + timer.observe(); } } } diff --git a/proxy/src/console/provider/neon.rs b/proxy/src/console/provider/neon.rs index 1a3e2ca795..58b2a1570c 100644 --- a/proxy/src/console/provider/neon.rs +++ b/proxy/src/console/provider/neon.rs @@ -7,13 +7,15 @@ use super::{ NodeInfo, }; use crate::{ - auth::backend::ComputeUserInfo, compute, console::messages::ColdStartInfo, http, scram, -}; -use crate::{ - cache::Cached, - context::RequestMonitoring, - metrics::{ALLOWED_IPS_BY_CACHE_OUTCOME, ALLOWED_IPS_NUMBER}, + auth::backend::ComputeUserInfo, + compute, + console::messages::ColdStartInfo, + http, + metrics::{CacheOutcome, Metrics}, + rate_limiter::EndpointRateLimiter, + scram, Normalize, }; +use crate::{cache::Cached, context::RequestMonitoring}; use futures::TryFutureExt; use std::sync::Arc; use tokio::time::Instant; @@ -23,7 +25,8 @@ use tracing::{error, info, info_span, warn, Instrument}; pub struct Api { endpoint: http::Endpoint, pub caches: &'static ApiCaches, - locks: &'static ApiLocks, + pub locks: &'static ApiLocks, + pub endpoint_rate_limiter: Arc, jwt: String, } @@ -33,6 +36,7 @@ impl Api { endpoint: http::Endpoint, caches: &'static ApiCaches, locks: &'static ApiLocks, + endpoint_rate_limiter: Arc, ) -> Self { let jwt: String = match std::env::var("NEON_PROXY_TO_CONTROLPLANE_TOKEN") { Ok(v) => v, @@ -42,6 +46,7 @@ impl Api { endpoint, caches, locks, + endpoint_rate_limiter, jwt, } } @@ -55,6 +60,15 @@ impl Api { ctx: &mut RequestMonitoring, user_info: &ComputeUserInfo, ) -> Result { + if !self + .caches + .endpoints_cache + .is_valid(ctx, &user_info.endpoint.normalize()) + .await + { + info!("endpoint is not valid, skipping the request"); + return Ok(AuthInfo::default()); + } let request_id = ctx.session_id.to_string(); let application_name = ctx.console_application_name(); async { @@ -81,7 +95,9 @@ impl Api { Ok(body) => body, // Error 404 is special: it's ok not to have a secret. Err(e) => match e.http_status_code() { - Some(http::StatusCode::NOT_FOUND) => return Ok(AuthInfo::default()), + Some(http::StatusCode::NOT_FOUND) => { + return Ok(AuthInfo::default()); + } _otherwise => return Err(e.into()), }, }; @@ -95,7 +111,10 @@ impl Api { Some(secret) }; let allowed_ips = body.allowed_ips.unwrap_or_default(); - ALLOWED_IPS_NUMBER.observe(allowed_ips.len() as f64); + Metrics::get() + .proxy + .allowed_ips_number + .observe(allowed_ips.len() as f64); Ok(AuthInfo { secret, allowed_ips, @@ -174,23 +193,27 @@ impl super::Api for Api { ctx: &mut RequestMonitoring, user_info: &ComputeUserInfo, ) -> Result { - let ep = &user_info.endpoint; + let normalized_ep = &user_info.endpoint.normalize(); let user = &user_info.user; - if let Some(role_secret) = self.caches.project_info.get_role_secret(ep, user) { + if let Some(role_secret) = self + .caches + .project_info + .get_role_secret(normalized_ep, user) + { return Ok(role_secret); } let auth_info = self.do_get_auth_info(ctx, user_info).await?; if let Some(project_id) = auth_info.project_id { - let ep_int = ep.into(); + let normalized_ep_int = normalized_ep.into(); self.caches.project_info.insert_role_secret( project_id, - ep_int, + normalized_ep_int, user.into(), auth_info.secret.clone(), ); self.caches.project_info.insert_allowed_ips( project_id, - ep_int, + normalized_ep_int, Arc::new(auth_info.allowed_ips), ); ctx.set_project_id(project_id); @@ -204,30 +227,34 @@ impl super::Api for Api { ctx: &mut RequestMonitoring, user_info: &ComputeUserInfo, ) -> Result<(CachedAllowedIps, Option), GetAuthInfoError> { - let ep = &user_info.endpoint; - if let Some(allowed_ips) = self.caches.project_info.get_allowed_ips(ep) { - ALLOWED_IPS_BY_CACHE_OUTCOME - .with_label_values(&["hit"]) - .inc(); + let normalized_ep = &user_info.endpoint.normalize(); + if let Some(allowed_ips) = self.caches.project_info.get_allowed_ips(normalized_ep) { + Metrics::get() + .proxy + .allowed_ips_cache_misses + .inc(CacheOutcome::Hit); return Ok((allowed_ips, None)); } - ALLOWED_IPS_BY_CACHE_OUTCOME - .with_label_values(&["miss"]) - .inc(); + Metrics::get() + .proxy + .allowed_ips_cache_misses + .inc(CacheOutcome::Miss); let auth_info = self.do_get_auth_info(ctx, user_info).await?; let allowed_ips = Arc::new(auth_info.allowed_ips); let user = &user_info.user; if let Some(project_id) = auth_info.project_id { - let ep_int = ep.into(); + let normalized_ep_int = normalized_ep.into(); self.caches.project_info.insert_role_secret( project_id, - ep_int, + normalized_ep_int, user.into(), auth_info.secret.clone(), ); - self.caches - .project_info - .insert_allowed_ips(project_id, ep_int, allowed_ips.clone()); + self.caches.project_info.insert_allowed_ips( + project_id, + normalized_ep_int, + allowed_ips.clone(), + ); ctx.set_project_id(project_id); } Ok(( @@ -254,6 +281,14 @@ impl super::Api for Api { return Ok(cached); } + // check rate limit + if !self + .endpoint_rate_limiter + .check(user_info.endpoint.normalize().into(), 1) + { + return Err(WakeComputeError::TooManyConnections); + } + let permit = self.locks.get_wake_compute_permit(&key).await?; // after getting back a permit - it's possible the cache was filled diff --git a/proxy/src/context.rs b/proxy/src/context.rs index fec95f4722..17b82c08aa 100644 --- a/proxy/src/context.rs +++ b/proxy/src/context.rs @@ -5,14 +5,14 @@ use once_cell::sync::OnceCell; use smol_str::SmolStr; use std::net::IpAddr; use tokio::sync::mpsc; -use tracing::{field::display, info_span, Span}; +use tracing::{field::display, info, info_span, Span}; use uuid::Uuid; use crate::{ console::messages::{ColdStartInfo, MetricsAuxInfo}, error::ErrorKind, intern::{BranchIdInt, ProjectIdInt}, - metrics::{LatencyTimer, ENDPOINT_ERRORS_BY_KIND, ERROR_BY_KIND}, + metrics::{ConnectOutcome, InvalidEndpointsGroup, LatencyTimer, Metrics, Protocol}, DbName, EndpointId, RoleName, }; @@ -29,7 +29,7 @@ static LOG_CHAN: OnceCell> = OnceCell::ne pub struct RequestMonitoring { pub peer_addr: IpAddr, pub session_id: Uuid, - pub protocol: &'static str, + pub protocol: Protocol, first_packet: chrono::DateTime, region: &'static str, pub span: Span, @@ -50,6 +50,8 @@ pub struct RequestMonitoring { // This sender is here to keep the request monitoring channel open while requests are taking place. sender: Option>, pub latency_timer: LatencyTimer, + // Whether proxy decided that it's not a valid endpoint end rejected it before going to cplane. + rejected: Option, } #[derive(Clone, Debug)] @@ -65,7 +67,7 @@ impl RequestMonitoring { pub fn new( session_id: Uuid, peer_addr: IpAddr, - protocol: &'static str, + protocol: Protocol, region: &'static str, ) -> Self { let span = info_span!( @@ -74,6 +76,7 @@ impl RequestMonitoring { ?session_id, %peer_addr, ep = tracing::field::Empty, + role = tracing::field::Empty, ); Self { @@ -93,6 +96,7 @@ impl RequestMonitoring { error_kind: None, auth_method: None, success: false, + rejected: None, cold_start_info: ColdStartInfo::Unknown, sender: LOG_CHAN.get().and_then(|tx| tx.upgrade()), @@ -102,7 +106,7 @@ impl RequestMonitoring { #[cfg(test)] pub fn test() -> Self { - RequestMonitoring::new(Uuid::now_v7(), [127, 0, 0, 1].into(), "test", "test") + RequestMonitoring::new(Uuid::now_v7(), [127, 0, 0, 1].into(), Protocol::Tcp, "test") } pub fn console_application_name(&self) -> String { @@ -113,6 +117,10 @@ impl RequestMonitoring { ) } + pub fn set_rejected(&mut self, rejected: bool) { + self.rejected = Some(rejected); + } + pub fn set_cold_start_info(&mut self, info: ColdStartInfo) { self.cold_start_info = info; self.latency_timer.cold_start_info(info); @@ -134,9 +142,9 @@ impl RequestMonitoring { pub fn set_endpoint_id(&mut self, endpoint_id: EndpointId) { if self.endpoint_id.is_none() { self.span.record("ep", display(&endpoint_id)); - crate::metrics::CONNECTING_ENDPOINTS - .with_label_values(&[self.protocol]) - .measure(&endpoint_id); + let metric = &Metrics::get().proxy.connecting_endpoints; + let label = metric.with_labels(self.protocol); + metric.get_metric(label).measure(&endpoint_id); self.endpoint_id = Some(endpoint_id); } } @@ -150,6 +158,7 @@ impl RequestMonitoring { } pub fn set_user(&mut self, user: RoleName) { + self.span.record("role", display(&user)); self.user = Some(user); } @@ -157,14 +166,22 @@ impl RequestMonitoring { self.auth_method = Some(auth_method); } + pub fn has_private_peer_addr(&self) -> bool { + match self.peer_addr { + IpAddr::V4(ip) => ip.is_private(), + _ => false, + } + } + pub fn set_error_kind(&mut self, kind: ErrorKind) { - ERROR_BY_KIND - .with_label_values(&[kind.to_metric_label()]) - .inc(); + // Do not record errors from the private address to metrics. + if !self.has_private_peer_addr() { + Metrics::get().proxy.errors_total.inc(kind); + } if let Some(ep) = &self.endpoint_id { - ENDPOINT_ERRORS_BY_KIND - .with_label_values(&[kind.to_metric_label()]) - .measure(ep); + let metric = &Metrics::get().proxy.endpoints_affected_by_errors; + let label = metric.with_labels(kind); + metric.get_metric(label).measure(ep); } self.error_kind = Some(kind); } @@ -178,6 +195,33 @@ impl RequestMonitoring { impl Drop for RequestMonitoring { fn drop(&mut self) { + let outcome = if self.success { + ConnectOutcome::Success + } else { + ConnectOutcome::Failed + }; + if let Some(rejected) = self.rejected { + let ep = self + .endpoint_id + .as_ref() + .map(|x| x.as_str()) + .unwrap_or_default(); + // This makes sense only if cache is disabled + info!( + ?outcome, + ?rejected, + ?ep, + "check endpoint is valid with outcome" + ); + Metrics::get() + .proxy + .invalid_endpoints_total + .inc(InvalidEndpointsGroup { + protocol: self.protocol, + rejected: rejected.into(), + outcome, + }); + } if let Some(tx) = self.sender.take() { let _: Result<(), _> = tx.send(RequestData::from(&*self)); } diff --git a/proxy/src/context/parquet.rs b/proxy/src/context/parquet.rs index eb77409429..e061216d15 100644 --- a/proxy/src/context/parquet.rs +++ b/proxy/src/context/parquet.rs @@ -111,7 +111,7 @@ impl From<&RequestMonitoring> for RequestData { super::AuthMethod::ScramSha256Plus => "scram_sha_256_plus", super::AuthMethod::Cleartext => "cleartext", }), - protocol: value.protocol, + protocol: value.protocol.as_str(), region: value.region, error: value.error_kind.as_ref().map(|e| e.to_metric_label()), success: value.success, diff --git a/proxy/src/error.rs b/proxy/src/error.rs index 4614f3913d..fdfe50a494 100644 --- a/proxy/src/error.rs +++ b/proxy/src/error.rs @@ -1,5 +1,7 @@ use std::{error::Error as StdError, fmt, io}; +use measured::FixedCardinalityLabel; + /// Upcast (almost) any error into an opaque [`io::Error`]. pub fn io_error(e: impl Into>) -> io::Error { io::Error::new(io::ErrorKind::Other, e) @@ -29,24 +31,29 @@ pub trait UserFacingError: ReportableError { } } -#[derive(Copy, Clone, Debug, Eq, PartialEq)] +#[derive(Copy, Clone, Debug, Eq, PartialEq, FixedCardinalityLabel)] +#[label(singleton = "type")] pub enum ErrorKind { /// Wrong password, unknown endpoint, protocol violation, etc... User, /// Network error between user and proxy. Not necessarily user error + #[label(rename = "clientdisconnect")] ClientDisconnect, /// Proxy self-imposed user rate limits + #[label(rename = "ratelimit")] RateLimit, /// Proxy self-imposed service-wise rate limits + #[label(rename = "serviceratelimit")] ServiceRateLimit, /// internal errors Service, /// Error communicating with control plane + #[label(rename = "controlplane")] ControlPlane, /// Postgres error diff --git a/proxy/src/http.rs b/proxy/src/http.rs index 59e1492ed4..e20488e23c 100644 --- a/proxy/src/http.rs +++ b/proxy/src/http.rs @@ -13,13 +13,16 @@ pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware}; use tokio::time::Instant; use tracing::trace; -use crate::{metrics::CONSOLE_REQUEST_LATENCY, rate_limiter, url::ApiUrl}; +use crate::{ + metrics::{ConsoleRequest, Metrics}, + url::ApiUrl, +}; use reqwest_middleware::RequestBuilder; /// This is the preferred way to create new http clients, /// because it takes care of observability (OpenTelemetry). /// We deliberately don't want to replace this with a public static. -pub fn new_client(rate_limiter_config: rate_limiter::RateLimiterConfig) -> ClientWithMiddleware { +pub fn new_client() -> ClientWithMiddleware { let client = reqwest::ClientBuilder::new() .dns_resolver(Arc::new(GaiResolver::default())) .connection_verbose(true) @@ -28,7 +31,6 @@ pub fn new_client(rate_limiter_config: rate_limiter::RateLimiterConfig) -> Clien reqwest_middleware::ClientBuilder::new(client) .with(reqwest_tracing::TracingMiddleware::default()) - .with(rate_limiter::Limiter::new(rate_limiter_config)) .build() } @@ -90,13 +92,14 @@ impl Endpoint { /// Execute a [request](reqwest::Request). pub async fn execute(&self, request: Request) -> Result { - let path = request.url().path().to_string(); - let start = Instant::now(); - let res = self.client.execute(request).await; - CONSOLE_REQUEST_LATENCY - .with_label_values(&[&path]) - .observe(start.elapsed().as_secs_f64()); - res + let _timer = Metrics::get() + .proxy + .console_request_latency + .start_timer(ConsoleRequest { + request: request.url().path(), + }); + + self.client.execute(request).await } } diff --git a/proxy/src/http/health_server.rs b/proxy/src/http/health_server.rs index cbb17ebcb7..cae9eb5b97 100644 --- a/proxy/src/http/health_server.rs +++ b/proxy/src/http/health_server.rs @@ -1,30 +1,49 @@ use anyhow::{anyhow, bail}; -use hyper::{Body, Request, Response, StatusCode}; -use std::{convert::Infallible, net::TcpListener}; -use tracing::info; +use hyper::{header::CONTENT_TYPE, Body, Request, Response, StatusCode}; +use measured::{text::BufferedTextEncoder, MetricGroup}; +use metrics::NeonMetrics; +use std::{ + convert::Infallible, + net::TcpListener, + sync::{Arc, Mutex}, +}; +use tracing::{info, info_span}; use utils::http::{ - endpoint::{self, prometheus_metrics_handler, request_span}, + endpoint::{self, request_span}, error::ApiError, json::json_response, RouterBuilder, RouterService, }; +use crate::jemalloc; + async fn status_handler(_: Request) -> Result, ApiError> { json_response(StatusCode::OK, "") } -fn make_router() -> RouterBuilder { +fn make_router(metrics: AppMetrics) -> RouterBuilder { + let state = Arc::new(Mutex::new(PrometheusHandler { + encoder: BufferedTextEncoder::new(), + metrics, + })); + endpoint::make_router() - .get("/metrics", |r| request_span(r, prometheus_metrics_handler)) + .get("/metrics", move |r| { + let state = state.clone(); + request_span(r, move |b| prometheus_metrics_handler(b, state)) + }) .get("/v1/status", status_handler) } -pub async fn task_main(http_listener: TcpListener) -> anyhow::Result { +pub async fn task_main( + http_listener: TcpListener, + metrics: AppMetrics, +) -> anyhow::Result { scopeguard::defer! { info!("http has shut down"); } - let service = || RouterService::new(make_router().build()?); + let service = || RouterService::new(make_router(metrics).build()?); hyper::Server::from_tcp(http_listener)? .serve(service().map_err(|e| anyhow!(e))?) @@ -32,3 +51,57 @@ pub async fn task_main(http_listener: TcpListener) -> anyhow::Result bail!("hyper server without shutdown handling cannot shutdown successfully"); } + +struct PrometheusHandler { + encoder: BufferedTextEncoder, + metrics: AppMetrics, +} + +#[derive(MetricGroup)] +pub struct AppMetrics { + #[metric(namespace = "jemalloc")] + pub jemalloc: Option, + #[metric(flatten)] + pub neon_metrics: NeonMetrics, + #[metric(flatten)] + pub proxy: &'static crate::metrics::Metrics, +} + +async fn prometheus_metrics_handler( + _req: Request, + state: Arc>, +) -> Result, ApiError> { + let started_at = std::time::Instant::now(); + + let span = info_span!("blocking"); + let body = tokio::task::spawn_blocking(move || { + let _span = span.entered(); + + let mut state = state.lock().unwrap(); + let PrometheusHandler { encoder, metrics } = &mut *state; + + metrics + .collect_group_into(&mut *encoder) + .unwrap_or_else(|infallible| match infallible {}); + + let body = encoder.finish(); + + tracing::info!( + bytes = body.len(), + elapsed_ms = started_at.elapsed().as_millis(), + "responded /metrics" + ); + + body + }) + .await + .unwrap(); + + let response = Response::builder() + .status(200) + .header(CONTENT_TYPE, "text/plain; version=0.0.4") + .body(Body::from(body)) + .unwrap(); + + Ok(response) +} diff --git a/proxy/src/intern.rs b/proxy/src/intern.rs index a6519bdff9..e38135dd22 100644 --- a/proxy/src/intern.rs +++ b/proxy/src/intern.rs @@ -160,6 +160,11 @@ impl From<&EndpointId> for EndpointIdInt { EndpointIdTag::get_interner().get_or_intern(value) } } +impl From for EndpointIdInt { + fn from(value: EndpointId) -> Self { + EndpointIdTag::get_interner().get_or_intern(&value) + } +} #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] pub struct BranchIdTag; @@ -175,6 +180,11 @@ impl From<&BranchId> for BranchIdInt { BranchIdTag::get_interner().get_or_intern(value) } } +impl From for BranchIdInt { + fn from(value: BranchId) -> Self { + BranchIdTag::get_interner().get_or_intern(&value) + } +} #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] pub struct ProjectIdTag; @@ -190,6 +200,11 @@ impl From<&ProjectId> for ProjectIdInt { ProjectIdTag::get_interner().get_or_intern(value) } } +impl From for ProjectIdInt { + fn from(value: ProjectId) -> Self { + ProjectIdTag::get_interner().get_or_intern(&value) + } +} #[cfg(test)] mod tests { diff --git a/proxy/src/jemalloc.rs b/proxy/src/jemalloc.rs index ed20798d56..3243e6a140 100644 --- a/proxy/src/jemalloc.rs +++ b/proxy/src/jemalloc.rs @@ -1,27 +1,45 @@ -use std::time::Duration; +use std::marker::PhantomData; -use metrics::IntGauge; -use prometheus::{register_int_gauge_with_registry, Registry}; +use measured::{ + label::NoLabels, + metric::{ + gauge::GaugeState, group::Encoding, group::MetricValue, name::MetricNameEncoder, + MetricEncoding, MetricFamilyEncoding, MetricType, + }, + text::TextEncoder, + LabelGroup, MetricGroup, +}; use tikv_jemalloc_ctl::{config, epoch, epoch_mib, stats, version}; pub struct MetricRecorder { epoch: epoch_mib, - active: stats::active_mib, - active_gauge: IntGauge, - allocated: stats::allocated_mib, - allocated_gauge: IntGauge, - mapped: stats::mapped_mib, - mapped_gauge: IntGauge, - metadata: stats::metadata_mib, - metadata_gauge: IntGauge, - resident: stats::resident_mib, - resident_gauge: IntGauge, - retained: stats::retained_mib, - retained_gauge: IntGauge, + inner: Metrics, +} + +#[derive(MetricGroup)] +struct Metrics { + active_bytes: JemallocGaugeFamily, + allocated_bytes: JemallocGaugeFamily, + mapped_bytes: JemallocGaugeFamily, + metadata_bytes: JemallocGaugeFamily, + resident_bytes: JemallocGaugeFamily, + retained_bytes: JemallocGaugeFamily, +} + +impl MetricGroup for MetricRecorder +where + Metrics: MetricGroup, +{ + fn collect_group_into(&self, enc: &mut Enc) -> Result<(), Enc::Err> { + if self.epoch.advance().is_ok() { + self.inner.collect_group_into(enc)?; + } + Ok(()) + } } impl MetricRecorder { - pub fn new(registry: &Registry) -> Result { + pub fn new() -> Result { tracing::info!( config = config::malloc_conf::read()?, version = version::read()?, @@ -30,71 +48,69 @@ impl MetricRecorder { Ok(Self { epoch: epoch::mib()?, - active: stats::active::mib()?, - active_gauge: register_int_gauge_with_registry!( - "jemalloc_active_bytes", - "Total number of bytes in active pages allocated by the process", - registry - )?, - allocated: stats::allocated::mib()?, - allocated_gauge: register_int_gauge_with_registry!( - "jemalloc_allocated_bytes", - "Total number of bytes allocated by the process", - registry - )?, - mapped: stats::mapped::mib()?, - mapped_gauge: register_int_gauge_with_registry!( - "jemalloc_mapped_bytes", - "Total number of bytes in active extents mapped by the allocator", - registry - )?, - metadata: stats::metadata::mib()?, - metadata_gauge: register_int_gauge_with_registry!( - "jemalloc_metadata_bytes", - "Total number of bytes dedicated to jemalloc metadata", - registry - )?, - resident: stats::resident::mib()?, - resident_gauge: register_int_gauge_with_registry!( - "jemalloc_resident_bytes", - "Total number of bytes in physically resident data pages mapped by the allocator", - registry - )?, - retained: stats::retained::mib()?, - retained_gauge: register_int_gauge_with_registry!( - "jemalloc_retained_bytes", - "Total number of bytes in virtual memory mappings that were retained rather than being returned to the operating system", - registry - )?, - }) - } - - fn _poll(&self) -> Result<(), anyhow::Error> { - self.epoch.advance()?; - self.active_gauge.set(self.active.read()? as i64); - self.allocated_gauge.set(self.allocated.read()? as i64); - self.mapped_gauge.set(self.mapped.read()? as i64); - self.metadata_gauge.set(self.metadata.read()? as i64); - self.resident_gauge.set(self.resident.read()? as i64); - self.retained_gauge.set(self.retained.read()? as i64); - Ok(()) - } - - #[inline] - pub fn poll(&self) { - if let Err(error) = self._poll() { - tracing::warn!(%error, "Failed to poll jemalloc stats"); - } - } - - pub fn start(self) -> tokio::task::JoinHandle<()> { - tokio::task::spawn(async move { - let mut interval = tokio::time::interval(Duration::from_secs(15)); - interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); - loop { - self.poll(); - interval.tick().await; - } + inner: Metrics { + active_bytes: JemallocGaugeFamily(stats::active::mib()?), + allocated_bytes: JemallocGaugeFamily(stats::allocated::mib()?), + mapped_bytes: JemallocGaugeFamily(stats::mapped::mib()?), + metadata_bytes: JemallocGaugeFamily(stats::metadata::mib()?), + resident_bytes: JemallocGaugeFamily(stats::resident::mib()?), + retained_bytes: JemallocGaugeFamily(stats::retained::mib()?), + }, }) } } + +struct JemallocGauge(PhantomData); + +impl Default for JemallocGauge { + fn default() -> Self { + JemallocGauge(PhantomData) + } +} +impl MetricType for JemallocGauge { + type Metadata = T; +} + +struct JemallocGaugeFamily(T); +impl MetricFamilyEncoding for JemallocGaugeFamily +where + JemallocGauge: MetricEncoding, +{ + fn collect_family_into(&self, name: impl MetricNameEncoder, enc: &mut T) -> Result<(), T::Err> { + JemallocGauge::write_type(&name, enc)?; + JemallocGauge(PhantomData).collect_into(&self.0, NoLabels, name, enc) + } +} + +macro_rules! jemalloc_gauge { + ($stat:ident, $mib:ident) => { + impl MetricEncoding> for JemallocGauge { + fn write_type( + name: impl MetricNameEncoder, + enc: &mut TextEncoder, + ) -> Result<(), std::io::Error> { + GaugeState::write_type(name, enc) + } + + fn collect_into( + &self, + mib: &stats::$mib, + labels: impl LabelGroup, + name: impl MetricNameEncoder, + enc: &mut TextEncoder, + ) -> Result<(), std::io::Error> { + if let Ok(v) = mib.read() { + enc.write_metric_value(name, labels, MetricValue::Int(v as i64))?; + } + Ok(()) + } + } + }; +} + +jemalloc_gauge!(active, active_mib); +jemalloc_gauge!(allocated, allocated_mib); +jemalloc_gauge!(mapped, mapped_mib); +jemalloc_gauge!(metadata, metadata_mib); +jemalloc_gauge!(resident, resident_mib); +jemalloc_gauge!(retained, retained_mib); diff --git a/proxy/src/lib.rs b/proxy/src/lib.rs index da7c7f3ed2..3f6d985fe8 100644 --- a/proxy/src/lib.rs +++ b/proxy/src/lib.rs @@ -127,6 +127,24 @@ macro_rules! smol_str_wrapper { }; } +const POOLER_SUFFIX: &str = "-pooler"; + +pub trait Normalize { + fn normalize(&self) -> Self; +} + +impl + From> Normalize for S { + fn normalize(&self) -> Self { + if self.as_ref().ends_with(POOLER_SUFFIX) { + let mut s = self.as_ref().to_string(); + s.truncate(s.len() - POOLER_SUFFIX.len()); + s.into() + } else { + self.clone() + } + } +} + // 90% of role name strings are 20 characters or less. smol_str_wrapper!(RoleName); // 50% of endpoint strings are 23 characters or less. @@ -140,3 +158,22 @@ smol_str_wrapper!(ProjectId); smol_str_wrapper!(EndpointCacheKey); smol_str_wrapper!(DbName); + +// Endpoints are a bit tricky. Rare they might be branches or projects. +impl EndpointId { + pub fn is_endpoint(&self) -> bool { + self.0.starts_with("ep-") + } + pub fn is_branch(&self) -> bool { + self.0.starts_with("br-") + } + pub fn is_project(&self) -> bool { + !self.is_endpoint() && !self.is_branch() + } + pub fn as_branch(&self) -> BranchId { + BranchId(self.0.clone()) + } + pub fn as_project(&self) -> ProjectId { + ProjectId(self.0.clone()) + } +} diff --git a/proxy/src/metrics.rs b/proxy/src/metrics.rs index 59ee899c08..3a4e54aea0 100644 --- a/proxy/src/metrics.rs +++ b/proxy/src/metrics.rs @@ -1,176 +1,348 @@ -use ::metrics::{ - exponential_buckets, register_histogram, register_histogram_vec, register_hll_vec, - register_int_counter_pair_vec, register_int_counter_vec, register_int_gauge, - register_int_gauge_vec, Histogram, HistogramVec, HyperLogLogVec, IntCounterPairVec, - IntCounterVec, IntGauge, IntGaugeVec, -}; -use metrics::{ - register_hll, register_int_counter, register_int_counter_pair, HyperLogLog, IntCounter, - IntCounterPair, -}; +use std::sync::OnceLock; + +use lasso::ThreadedRodeo; +use measured::{ + label::StaticLabelSet, + metric::{histogram::Thresholds, name::MetricName}, + Counter, CounterVec, FixedCardinalityLabel, Gauge, Histogram, HistogramVec, LabelGroup, + MetricGroup, +}; +use metrics::{CounterPairAssoc, CounterPairVec, HyperLogLog, HyperLogLogVec}; -use once_cell::sync::Lazy; use tokio::time::{self, Instant}; use crate::console::messages::ColdStartInfo; -pub static NUM_DB_CONNECTIONS_GAUGE: Lazy = Lazy::new(|| { - register_int_counter_pair_vec!( - "proxy_opened_db_connections_total", - "Number of opened connections to a database.", - "proxy_closed_db_connections_total", - "Number of closed connections to a database.", - &["protocol"], - ) - .unwrap() -}); +#[derive(MetricGroup)] +pub struct Metrics { + #[metric(namespace = "proxy")] + pub proxy: ProxyMetrics, -pub static NUM_CLIENT_CONNECTION_GAUGE: Lazy = Lazy::new(|| { - register_int_counter_pair_vec!( - "proxy_opened_client_connections_total", - "Number of opened connections from a client.", - "proxy_closed_client_connections_total", - "Number of closed connections from a client.", - &["protocol"], - ) - .unwrap() -}); + #[metric(namespace = "wake_compute_lock")] + pub wake_compute_lock: ApiLockMetrics, +} -pub static NUM_CONNECTION_REQUESTS_GAUGE: Lazy = Lazy::new(|| { - register_int_counter_pair_vec!( - "proxy_accepted_connections_total", - "Number of client connections accepted.", - "proxy_closed_connections_total", - "Number of client connections closed.", - &["protocol"], - ) - .unwrap() -}); +impl Metrics { + pub fn get() -> &'static Self { + static SELF: OnceLock = OnceLock::new(); + SELF.get_or_init(|| Metrics { + proxy: ProxyMetrics::default(), + wake_compute_lock: ApiLockMetrics::new(), + }) + } +} -pub static COMPUTE_CONNECTION_LATENCY: Lazy = Lazy::new(|| { - register_histogram_vec!( - "proxy_compute_connection_latency_seconds", - "Time it took for proxy to establish a connection to the compute endpoint", - // http/ws/tcp, true/false, true/false, success/failure, client/client_and_cplane - // 3 * 6 * 2 * 2 = 72 counters - &["protocol", "cold_start_info", "outcome", "excluded"], - // largest bucket = 2^16 * 0.5ms = 32s - exponential_buckets(0.0005, 2.0, 16).unwrap(), - ) - .unwrap() -}); +#[derive(MetricGroup)] +#[metric(new())] +pub struct ProxyMetrics { + #[metric(flatten)] + pub db_connections: CounterPairVec, + #[metric(flatten)] + pub client_connections: CounterPairVec, + #[metric(flatten)] + pub connection_requests: CounterPairVec, + #[metric(flatten)] + pub http_endpoint_pools: HttpEndpointPools, -pub static CONSOLE_REQUEST_LATENCY: Lazy = Lazy::new(|| { - register_histogram_vec!( - "proxy_console_request_latency", - "Time it took for proxy to establish a connection to the compute endpoint", - // proxy_wake_compute/proxy_get_role_info - &["request"], + /// Time it took for proxy to establish a connection to the compute endpoint. + // largest bucket = 2^16 * 0.5ms = 32s + #[metric(metadata = Thresholds::exponential_buckets(0.0005, 2.0))] + pub compute_connection_latency_seconds: HistogramVec, + + /// Time it took for proxy to receive a response from control plane. + #[metric( // largest bucket = 2^16 * 0.2ms = 13s - exponential_buckets(0.0002, 2.0, 16).unwrap(), - ) - .unwrap() -}); + metadata = Thresholds::exponential_buckets(0.0002, 2.0), + )] + pub console_request_latency: HistogramVec, -pub static ALLOWED_IPS_BY_CACHE_OUTCOME: Lazy = Lazy::new(|| { - register_int_counter_vec!( - "proxy_allowed_ips_cache_misses", - "Number of cache hits/misses for allowed ips", - // hit/miss - &["outcome"], - ) - .unwrap() -}); + /// Time it takes to acquire a token to call console plane. + // largest bucket = 3^16 * 0.05ms = 2.15s + #[metric(metadata = Thresholds::exponential_buckets(0.00005, 3.0))] + pub control_plane_token_acquire_seconds: Histogram<16>, -pub static RATE_LIMITER_ACQUIRE_LATENCY: Lazy = Lazy::new(|| { - register_histogram!( - "proxy_control_plane_token_acquire_seconds", - "Time it took for proxy to establish a connection to the compute endpoint", - // largest bucket = 3^16 * 0.05ms = 2.15s - exponential_buckets(0.00005, 3.0, 16).unwrap(), - ) - .unwrap() -}); + /// Size of the HTTP request body lengths. + // smallest bucket = 16 bytes + // largest bucket = 4^12 * 16 bytes = 256MB + #[metric(metadata = Thresholds::exponential_buckets(16.0, 4.0))] + pub http_conn_content_length_bytes: HistogramVec, 12>, -pub static RATE_LIMITER_LIMIT: Lazy = Lazy::new(|| { - register_int_gauge_vec!( - "semaphore_control_plane_limit", - "Current limit of the semaphore control plane", - &["limit"], // 2 counters - ) - .unwrap() -}); + /// Time it takes to reclaim unused connection pools. + #[metric(metadata = Thresholds::exponential_buckets(1e-6, 2.0))] + pub http_pool_reclaimation_lag_seconds: Histogram<16>, -pub static NUM_CONNECTION_ACCEPTED_BY_SNI: Lazy = Lazy::new(|| { - register_int_counter_vec!( - "proxy_accepted_connections_by_sni", - "Number of connections (per sni).", - &["kind"], - ) - .unwrap() -}); + /// Number of opened connections to a database. + pub http_pool_opened_connections: Gauge, -pub static ALLOWED_IPS_NUMBER: Lazy = Lazy::new(|| { - register_histogram!( - "proxy_allowed_ips_number", - "Number of allowed ips", - vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 10.0, 20.0, 50.0, 100.0], - ) - .unwrap() -}); + /// Number of cache hits/misses for allowed ips. + pub allowed_ips_cache_misses: CounterVec>, -pub static HTTP_CONTENT_LENGTH: Lazy = Lazy::new(|| { - register_histogram_vec!( - "proxy_http_conn_content_length_bytes", - "Number of bytes the HTTP response content consumes", - // request/response - &["direction"], - // smallest bucket = 16 bytes - // largest bucket = 4^12 * 16 bytes = 256MB - exponential_buckets(16.0, 4.0, 12).unwrap() - ) - .unwrap() -}); + /// Number of allowed ips + #[metric(metadata = Thresholds::with_buckets([0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 10.0, 20.0, 50.0, 100.0]))] + pub allowed_ips_number: Histogram<10>, -pub static GC_LATENCY: Lazy = Lazy::new(|| { - register_histogram!( - "proxy_http_pool_reclaimation_lag_seconds", - "Time it takes to reclaim unused connection pools", - // 1us -> 65ms - exponential_buckets(1e-6, 2.0, 16).unwrap(), - ) - .unwrap() -}); + /// Number of connections (per sni). + pub accepted_connections_by_sni: CounterVec>, -pub static ENDPOINT_POOLS: Lazy = Lazy::new(|| { - register_int_counter_pair!( - "proxy_http_pool_endpoints_registered_total", - "Number of endpoints we have registered pools for", - "proxy_http_pool_endpoints_unregistered_total", - "Number of endpoints we have unregistered pools for", - ) - .unwrap() -}); + /// Number of connection failures (per kind). + pub connection_failures_total: CounterVec>, -pub static NUM_OPEN_CLIENTS_IN_HTTP_POOL: Lazy = Lazy::new(|| { - register_int_gauge!( - "proxy_http_pool_opened_connections", - "Number of opened connections to a database.", - ) - .unwrap() -}); + /// Number of wake-up failures (per kind). + pub connection_failures_breakdown: CounterVec, -pub static NUM_CANCELLATION_REQUESTS: Lazy = Lazy::new(|| { - register_int_counter_vec!( - "proxy_cancellation_requests_total", - "Number of cancellation requests (per found/not_found).", - &["source", "kind"], - ) - .unwrap() -}); + /// Number of bytes sent/received between all clients and backends. + pub io_bytes: CounterVec>, -pub const NUM_CANCELLATION_REQUESTS_SOURCE_FROM_CLIENT: &str = "from_client"; -pub const NUM_CANCELLATION_REQUESTS_SOURCE_FROM_REDIS: &str = "from_redis"; + /// Number of errors by a given classification. + pub errors_total: CounterVec>, + + /// Number of cancellation requests (per found/not_found). + pub cancellation_requests_total: CounterVec, + + /// Number of errors by a given classification + pub redis_errors_total: CounterVec, + + /// Number of TLS handshake failures + pub tls_handshake_failures: Counter, + + /// Number of connection requests affected by authentication rate limits + pub requests_auth_rate_limits_total: Counter, + + /// HLL approximate cardinality of endpoints that are connecting + pub connecting_endpoints: HyperLogLogVec, 32>, + + /// Number of endpoints affected by errors of a given classification + pub endpoints_affected_by_errors: HyperLogLogVec, 32>, + + /// Number of endpoints affected by authentication rate limits + pub endpoints_auth_rate_limits: HyperLogLog<32>, + + /// Number of invalid endpoints (per protocol, per rejected). + pub invalid_endpoints_total: CounterVec, +} + +#[derive(MetricGroup)] +#[metric(new())] +pub struct ApiLockMetrics { + /// Number of semaphores registered in this api lock + pub semaphores_registered: Counter, + /// Number of semaphores unregistered in this api lock + pub semaphores_unregistered: Counter, + /// Time it takes to reclaim unused semaphores in the api lock + #[metric(metadata = Thresholds::exponential_buckets(1e-6, 2.0))] + pub reclamation_lag_seconds: Histogram<16>, + /// Time it takes to acquire a semaphore lock + #[metric(metadata = Thresholds::exponential_buckets(1e-4, 2.0))] + pub semaphore_acquire_seconds: Histogram<16>, +} + +impl Default for ProxyMetrics { + fn default() -> Self { + Self::new() + } +} + +#[derive(FixedCardinalityLabel, Copy, Clone)] +#[label(singleton = "direction")] +pub enum HttpDirection { + Request, + Response, +} + +#[derive(FixedCardinalityLabel, Copy, Clone)] +#[label(singleton = "direction")] +pub enum Direction { + Tx, + Rx, +} + +#[derive(FixedCardinalityLabel, Clone, Copy, Debug)] +#[label(singleton = "protocol")] +pub enum Protocol { + Http, + Ws, + Tcp, + SniRouter, +} + +impl Protocol { + pub fn as_str(&self) -> &'static str { + match self { + Protocol::Http => "http", + Protocol::Ws => "ws", + Protocol::Tcp => "tcp", + Protocol::SniRouter => "sni_router", + } + } +} + +impl std::fmt::Display for Protocol { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(self.as_str()) + } +} + +#[derive(FixedCardinalityLabel, Copy, Clone)] +pub enum Bool { + True, + False, +} + +#[derive(FixedCardinalityLabel, Copy, Clone)] +#[label(singleton = "outcome")] +pub enum Outcome { + Success, + Failed, +} + +#[derive(FixedCardinalityLabel, Copy, Clone)] +#[label(singleton = "outcome")] +pub enum CacheOutcome { + Hit, + Miss, +} + +#[derive(LabelGroup)] +#[label(set = ConsoleRequestSet)] +pub struct ConsoleRequest<'a> { + #[label(dynamic_with = ThreadedRodeo, default)] + pub request: &'a str, +} + +#[derive(MetricGroup, Default)] +pub struct HttpEndpointPools { + /// Number of endpoints we have registered pools for + pub http_pool_endpoints_registered_total: Counter, + /// Number of endpoints we have unregistered pools for + pub http_pool_endpoints_unregistered_total: Counter, +} + +pub struct HttpEndpointPoolsGuard<'a> { + dec: &'a Counter, +} + +impl Drop for HttpEndpointPoolsGuard<'_> { + fn drop(&mut self) { + self.dec.inc(); + } +} + +impl HttpEndpointPools { + pub fn guard(&self) -> HttpEndpointPoolsGuard { + self.http_pool_endpoints_registered_total.inc(); + HttpEndpointPoolsGuard { + dec: &self.http_pool_endpoints_unregistered_total, + } + } +} +pub struct NumDbConnectionsGauge; +impl CounterPairAssoc for NumDbConnectionsGauge { + const INC_NAME: &'static MetricName = MetricName::from_str("opened_db_connections_total"); + const DEC_NAME: &'static MetricName = MetricName::from_str("closed_db_connections_total"); + const INC_HELP: &'static str = "Number of opened connections to a database."; + const DEC_HELP: &'static str = "Number of closed connections to a database."; + type LabelGroupSet = StaticLabelSet; +} +pub type NumDbConnectionsGuard<'a> = metrics::MeasuredCounterPairGuard<'a, NumDbConnectionsGauge>; + +pub struct NumClientConnectionsGauge; +impl CounterPairAssoc for NumClientConnectionsGauge { + const INC_NAME: &'static MetricName = MetricName::from_str("opened_client_connections_total"); + const DEC_NAME: &'static MetricName = MetricName::from_str("closed_client_connections_total"); + const INC_HELP: &'static str = "Number of opened connections from a client."; + const DEC_HELP: &'static str = "Number of closed connections from a client."; + type LabelGroupSet = StaticLabelSet; +} +pub type NumClientConnectionsGuard<'a> = + metrics::MeasuredCounterPairGuard<'a, NumClientConnectionsGauge>; + +pub struct NumConnectionRequestsGauge; +impl CounterPairAssoc for NumConnectionRequestsGauge { + const INC_NAME: &'static MetricName = MetricName::from_str("accepted_connections_total"); + const DEC_NAME: &'static MetricName = MetricName::from_str("closed_connections_total"); + const INC_HELP: &'static str = "Number of client connections accepted."; + const DEC_HELP: &'static str = "Number of client connections closed."; + type LabelGroupSet = StaticLabelSet; +} +pub type NumConnectionRequestsGuard<'a> = + metrics::MeasuredCounterPairGuard<'a, NumConnectionRequestsGauge>; + +#[derive(LabelGroup)] +#[label(set = ComputeConnectionLatencySet)] +pub struct ComputeConnectionLatencyGroup { + protocol: Protocol, + cold_start_info: ColdStartInfo, + outcome: ConnectOutcome, + excluded: LatencyExclusions, +} + +#[derive(FixedCardinalityLabel, Copy, Clone)] +pub enum LatencyExclusions { + Client, + ClientAndCplane, +} + +#[derive(FixedCardinalityLabel, Copy, Clone)] +#[label(singleton = "kind")] +pub enum SniKind { + Sni, + NoSni, + PasswordHack, +} + +#[derive(FixedCardinalityLabel, Copy, Clone)] +#[label(singleton = "kind")] +pub enum ConnectionFailureKind { + ComputeCached, + ComputeUncached, +} + +#[derive(FixedCardinalityLabel, Copy, Clone)] +#[label(singleton = "kind")] +pub enum WakeupFailureKind { + BadComputeAddress, + ApiTransportError, + QuotaExceeded, + ApiConsoleLocked, + ApiConsoleBadRequest, + ApiConsoleOtherServerError, + ApiConsoleOtherError, + TimeoutError, +} + +#[derive(LabelGroup)] +#[label(set = ConnectionFailuresBreakdownSet)] +pub struct ConnectionFailuresBreakdownGroup { + pub kind: WakeupFailureKind, + pub retry: Bool, +} + +#[derive(LabelGroup, Copy, Clone)] +#[label(set = RedisErrorsSet)] +pub struct RedisErrors<'a> { + #[label(dynamic_with = ThreadedRodeo, default)] + pub channel: &'a str, +} + +#[derive(FixedCardinalityLabel, Copy, Clone)] +pub enum CancellationSource { + FromClient, + FromRedis, + Local, +} + +#[derive(FixedCardinalityLabel, Copy, Clone)] +pub enum CancellationOutcome { + NotFound, + Found, +} + +#[derive(LabelGroup)] +#[label(set = CancellationRequestSet)] +pub struct CancellationRequest { + pub source: CancellationSource, + pub kind: CancellationOutcome, +} pub enum Waiting { Cplane, @@ -185,20 +357,6 @@ struct Accumulated { compute: time::Duration, } -enum Outcome { - Success, - Failed, -} - -impl Outcome { - fn as_str(&self) -> &'static str { - match self { - Outcome::Success => "success", - Outcome::Failed => "failed", - } - } -} - pub struct LatencyTimer { // time since the stopwatch was started start: time::Instant, @@ -207,9 +365,9 @@ pub struct LatencyTimer { // accumulated time on the stopwatch accumulated: Accumulated, // label data - protocol: &'static str, + protocol: Protocol, cold_start_info: ColdStartInfo, - outcome: Outcome, + outcome: ConnectOutcome, } pub struct LatencyTimerPause<'a> { @@ -219,7 +377,7 @@ pub struct LatencyTimerPause<'a> { } impl LatencyTimer { - pub fn new(protocol: &'static str) -> Self { + pub fn new(protocol: Protocol) -> Self { Self { start: time::Instant::now(), stop: None, @@ -227,7 +385,7 @@ impl LatencyTimer { protocol, cold_start_info: ColdStartInfo::Unknown, // assume failed unless otherwise specified - outcome: Outcome::Failed, + outcome: ConnectOutcome::Failed, } } @@ -248,7 +406,7 @@ impl LatencyTimer { self.stop = Some(time::Instant::now()); // success - self.outcome = Outcome::Success; + self.outcome = ConnectOutcome::Success; } } @@ -263,128 +421,62 @@ impl Drop for LatencyTimerPause<'_> { } } +#[derive(FixedCardinalityLabel, Clone, Copy, Debug)] +pub enum ConnectOutcome { + Success, + Failed, +} + impl Drop for LatencyTimer { fn drop(&mut self) { let duration = self .stop .unwrap_or_else(time::Instant::now) .duration_since(self.start); - // Excluding cplane communication from the accumulated time. - COMPUTE_CONNECTION_LATENCY - .with_label_values(&[ - self.protocol, - self.cold_start_info.as_str(), - self.outcome.as_str(), - "client", - ]) - .observe((duration.saturating_sub(self.accumulated.client)).as_secs_f64()); + + let metric = &Metrics::get().proxy.compute_connection_latency_seconds; + + // Excluding client communication from the accumulated time. + metric.observe( + ComputeConnectionLatencyGroup { + protocol: self.protocol, + cold_start_info: self.cold_start_info, + outcome: self.outcome, + excluded: LatencyExclusions::Client, + }, + duration + .saturating_sub(self.accumulated.client) + .as_secs_f64(), + ); + // Exclude client and cplane communication from the accumulated time. let accumulated_total = self.accumulated.client + self.accumulated.cplane; - COMPUTE_CONNECTION_LATENCY - .with_label_values(&[ - self.protocol, - self.cold_start_info.as_str(), - self.outcome.as_str(), - "client_and_cplane", - ]) - .observe((duration.saturating_sub(accumulated_total)).as_secs_f64()); + metric.observe( + ComputeConnectionLatencyGroup { + protocol: self.protocol, + cold_start_info: self.cold_start_info, + outcome: self.outcome, + excluded: LatencyExclusions::ClientAndCplane, + }, + duration.saturating_sub(accumulated_total).as_secs_f64(), + ); } } -pub static NUM_CONNECTION_FAILURES: Lazy = Lazy::new(|| { - register_int_counter_vec!( - "proxy_connection_failures_total", - "Number of connection failures (per kind).", - &["kind"], - ) - .unwrap() -}); - -pub static NUM_WAKEUP_FAILURES: Lazy = Lazy::new(|| { - register_int_counter_vec!( - "proxy_connection_failures_breakdown", - "Number of wake-up failures (per kind).", - &["retry", "kind"], - ) - .unwrap() -}); - -pub static NUM_BYTES_PROXIED_COUNTER: Lazy = Lazy::new(|| { - register_int_counter_vec!( - "proxy_io_bytes", - "Number of bytes sent/received between all clients and backends.", - &["direction"], - ) - .unwrap() -}); - -pub const fn bool_to_str(x: bool) -> &'static str { - if x { - "true" - } else { - "false" +impl From for Bool { + fn from(value: bool) -> Self { + if value { + Bool::True + } else { + Bool::False + } } } -pub static CONNECTING_ENDPOINTS: Lazy> = Lazy::new(|| { - register_hll_vec!( - 32, - "proxy_connecting_endpoints", - "HLL approximate cardinality of endpoints that are connecting", - &["protocol"], - ) - .unwrap() -}); - -pub static ERROR_BY_KIND: Lazy = Lazy::new(|| { - register_int_counter_vec!( - "proxy_errors_total", - "Number of errors by a given classification", - &["type"], - ) - .unwrap() -}); - -pub static ENDPOINT_ERRORS_BY_KIND: Lazy> = Lazy::new(|| { - register_hll_vec!( - 32, - "proxy_endpoints_affected_by_errors", - "Number of endpoints affected by errors of a given classification", - &["type"], - ) - .unwrap() -}); - -pub static REDIS_BROKEN_MESSAGES: Lazy = Lazy::new(|| { - register_int_counter_vec!( - "proxy_redis_errors_total", - "Number of errors by a given classification", - &["channel"], - ) - .unwrap() -}); - -pub static TLS_HANDSHAKE_FAILURES: Lazy = Lazy::new(|| { - register_int_counter!( - "proxy_tls_handshake_failures", - "Number of TLS handshake failures", - ) - .unwrap() -}); - -pub static ENDPOINTS_AUTH_RATE_LIMITED: Lazy> = Lazy::new(|| { - register_hll!( - 32, - "proxy_endpoints_auth_rate_limits", - "Number of endpoints affected by authentication rate limits", - ) - .unwrap() -}); - -pub static AUTH_RATE_LIMIT_HITS: Lazy = Lazy::new(|| { - register_int_counter!( - "proxy_requests_auth_rate_limits_total", - "Number of connection requests affected by authentication rate limits", - ) - .unwrap() -}); +#[derive(LabelGroup)] +#[label(set = InvalidEndpointsSet)] +pub struct InvalidEndpointsGroup { + pub protocol: Protocol, + pub rejected: Bool, + pub outcome: ConnectOutcome, +} diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index 6051c0a812..4321bad968 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -7,6 +7,7 @@ pub mod handshake; pub mod passthrough; pub mod retry; pub mod wake_compute; +pub use copy_bidirectional::copy_bidirectional_client_compute; use crate::{ auth, @@ -15,16 +16,14 @@ use crate::{ config::{ProxyConfig, TlsConfig}, context::RequestMonitoring, error::ReportableError, - metrics::{NUM_CLIENT_CONNECTION_GAUGE, NUM_CONNECTION_REQUESTS_GAUGE}, + metrics::{Metrics, NumClientConnectionsGuard}, protocol2::WithClientIp, proxy::handshake::{handshake, HandshakeData}, - rate_limiter::EndpointRateLimiter, stream::{PqStream, Stream}, EndpointCacheKey, }; use futures::TryFutureExt; use itertools::Itertools; -use metrics::IntCounterPairGuard; use once_cell::sync::OnceCell; use pq_proto::{BeMessage as Be, StartupMessageParams}; use regex::Regex; @@ -61,7 +60,6 @@ pub async fn task_main( config: &'static ProxyConfig, listener: tokio::net::TcpListener, cancellation_token: CancellationToken, - endpoint_rate_limiter: Arc, cancellation_handler: Arc, ) -> anyhow::Result<()> { scopeguard::defer! { @@ -79,13 +77,13 @@ pub async fn task_main( { let (socket, peer_addr) = accept_result?; - let conn_gauge = NUM_CLIENT_CONNECTION_GAUGE - .with_label_values(&["tcp"]) - .guard(); + let conn_gauge = Metrics::get() + .proxy + .client_connections + .guard(crate::metrics::Protocol::Tcp); let session_id = uuid::Uuid::new_v4(); let cancellation_handler = Arc::clone(&cancellation_handler); - let endpoint_rate_limiter = endpoint_rate_limiter.clone(); tracing::info!(protocol = "tcp", %session_id, "accepted new TCP connection"); @@ -113,7 +111,12 @@ pub async fn task_main( }, }; - let mut ctx = RequestMonitoring::new(session_id, peer_addr, "tcp", &config.region); + let mut ctx = RequestMonitoring::new( + session_id, + peer_addr, + crate::metrics::Protocol::Tcp, + &config.region, + ); let span = ctx.span.clone(); let res = handle_client( @@ -122,7 +125,6 @@ pub async fn task_main( cancellation_handler, socket, ClientMode::Tcp, - endpoint_rate_limiter, conn_gauge, ) .instrument(span.clone()) @@ -236,20 +238,23 @@ pub async fn handle_client( cancellation_handler: Arc, stream: S, mode: ClientMode, - endpoint_rate_limiter: Arc, - conn_gauge: IntCounterPairGuard, + conn_gauge: NumClientConnectionsGuard<'static>, ) -> Result>, ClientRequestError> { - info!("handling interactive connection from client"); + info!( + protocol = %ctx.protocol, + "handling interactive connection from client" + ); + let metrics = &Metrics::get().proxy; let proto = ctx.protocol; - let _request_gauge = NUM_CONNECTION_REQUESTS_GAUGE - .with_label_values(&[proto]) - .guard(); + // let _client_gauge = metrics.client_connections.guard(proto); + let _request_gauge = metrics.connection_requests.guard(proto); let tls = config.tls_config.as_ref(); + let record_handshake_error = !ctx.has_private_peer_addr(); let pause = ctx.latency_timer.pause(crate::metrics::Waiting::Client); - let do_handshake = handshake(stream, mode.handshake_tls(tls)); + let do_handshake = handshake(stream, mode.handshake_tls(tls), record_handshake_error); let (mut stream, params) = match tokio::time::timeout(config.handshake_timeout, do_handshake).await?? { HandshakeData::Startup(stream, params) => (stream, params), @@ -278,15 +283,6 @@ pub async fn handle_client( Err(e) => stream.throw_error(e).await?, }; - // check rate limit - if let Some(ep) = user_info.get_endpoint() { - if !endpoint_rate_limiter.check(ep, 1) { - return stream - .throw_error(auth::AuthError::too_many_connections()) - .await?; - } - } - let user = user_info.get_user().to_owned(); let user_info = match user_info .authenticate( diff --git a/proxy/src/proxy/connect_compute.rs b/proxy/src/proxy/connect_compute.rs index 4c0d68ce0b..33f394c550 100644 --- a/proxy/src/proxy/connect_compute.rs +++ b/proxy/src/proxy/connect_compute.rs @@ -4,7 +4,7 @@ use crate::{ console::{self, errors::WakeComputeError, CachedNodeInfo, NodeInfo}, context::RequestMonitoring, error::ReportableError, - metrics::NUM_CONNECTION_FAILURES, + metrics::{ConnectionFailureKind, Metrics}, proxy::{ retry::{retry_after, ShouldRetry}, wake_compute::wake_compute, @@ -27,10 +27,10 @@ pub fn invalidate_cache(node_info: console::CachedNodeInfo) -> NodeInfo { warn!("invalidating stalled compute node info cache entry"); } let label = match is_cached { - true => "compute_cached", - false => "compute_uncached", + true => ConnectionFailureKind::ComputeCached, + false => ConnectionFailureKind::ComputeUncached, }; - NUM_CONNECTION_FAILURES.with_label_values(&[label]).inc(); + Metrics::get().proxy.connection_failures_total.inc(label); node_info.invalidate() } diff --git a/proxy/src/proxy/copy_bidirectional.rs b/proxy/src/proxy/copy_bidirectional.rs index 684be74f9a..4b09ebd8dc 100644 --- a/proxy/src/proxy/copy_bidirectional.rs +++ b/proxy/src/proxy/copy_bidirectional.rs @@ -41,7 +41,7 @@ where } #[tracing::instrument(skip_all)] -pub(super) async fn copy_bidirectional_client_compute( +pub async fn copy_bidirectional_client_compute( client: &mut Client, compute: &mut Compute, ) -> Result<(u64, u64), std::io::Error> diff --git a/proxy/src/proxy/handshake.rs b/proxy/src/proxy/handshake.rs index 4665e07d23..dd935cc245 100644 --- a/proxy/src/proxy/handshake.rs +++ b/proxy/src/proxy/handshake.rs @@ -63,6 +63,7 @@ pub enum HandshakeData { pub async fn handshake( stream: S, mut tls: Option<&TlsConfig>, + record_handshake_error: bool, ) -> Result, HandshakeError> { // Client may try upgrading to each protocol only once let (mut tried_ssl, mut tried_gss) = (false, false); @@ -95,7 +96,9 @@ pub async fn handshake( if !read_buf.is_empty() { return Err(HandshakeError::EarlyData); } - let tls_stream = raw.upgrade(tls.to_server_config()).await?; + let tls_stream = raw + .upgrade(tls.to_server_config(), record_handshake_error) + .await?; let (_, tls_server_end_point) = tls .cert_resolver diff --git a/proxy/src/proxy/passthrough.rs b/proxy/src/proxy/passthrough.rs index c81a1a8292..62de79946f 100644 --- a/proxy/src/proxy/passthrough.rs +++ b/proxy/src/proxy/passthrough.rs @@ -2,11 +2,10 @@ use crate::{ cancellation, compute::PostgresConnection, console::messages::MetricsAuxInfo, - metrics::NUM_BYTES_PROXIED_COUNTER, + metrics::{Direction, Metrics, NumClientConnectionsGuard, NumConnectionRequestsGuard}, stream::Stream, usage_metrics::{Ids, MetricCounterRecorder, USAGE_METRICS}, }; -use metrics::IntCounterPairGuard; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::info; use utils::measured_stream::MeasuredStream; @@ -23,24 +22,25 @@ pub async fn proxy_pass( branch_id: aux.branch_id, }); - let m_sent = NUM_BYTES_PROXIED_COUNTER.with_label_values(&["tx"]); + let metrics = &Metrics::get().proxy.io_bytes; + let m_sent = metrics.with_labels(Direction::Tx); let mut client = MeasuredStream::new( client, |_| {}, |cnt| { // Number of bytes we sent to the client (outbound). - m_sent.inc_by(cnt as u64); + metrics.get_metric(m_sent).inc_by(cnt as u64); usage.record_egress(cnt as u64); }, ); - let m_recv = NUM_BYTES_PROXIED_COUNTER.with_label_values(&["rx"]); + let m_recv = metrics.with_labels(Direction::Rx); let mut compute = MeasuredStream::new( compute, |_| {}, |cnt| { // Number of bytes the client sent to the compute node (inbound). - m_recv.inc_by(cnt as u64); + metrics.get_metric(m_recv).inc_by(cnt as u64); }, ); @@ -60,8 +60,8 @@ pub struct ProxyPassthrough { pub compute: PostgresConnection, pub aux: MetricsAuxInfo, - pub req: IntCounterPairGuard, - pub conn: IntCounterPairGuard, + pub req: NumConnectionRequestsGuard<'static>, + pub conn: NumClientConnectionsGuard<'static>, pub cancel: cancellation::Session

, } diff --git a/proxy/src/proxy/tests.rs b/proxy/src/proxy/tests.rs index 71d85e106d..849e9bd33c 100644 --- a/proxy/src/proxy/tests.rs +++ b/proxy/src/proxy/tests.rs @@ -175,7 +175,7 @@ async fn dummy_proxy( auth: impl TestAuth + Send, ) -> anyhow::Result<()> { let client = WithClientIp::new(client); - let mut stream = match handshake(client, tls.as_ref()).await? { + let mut stream = match handshake(client, tls.as_ref(), false).await? { HandshakeData::Startup(stream, _) => stream, HandshakeData::Cancel(_) => bail!("cancellation not supported"), }; diff --git a/proxy/src/proxy/tests/mitm.rs b/proxy/src/proxy/tests/mitm.rs index 3b760e5dab..cbfc9f1358 100644 --- a/proxy/src/proxy/tests/mitm.rs +++ b/proxy/src/proxy/tests/mitm.rs @@ -34,7 +34,10 @@ async fn proxy_mitm( tokio::spawn(async move { // begin handshake with end_server let end_server = connect_tls(server2, client_config2.make_tls_connect().unwrap()).await; - let (end_client, startup) = match handshake(client1, Some(&server_config1)).await.unwrap() { + let (end_client, startup) = match handshake(client1, Some(&server_config1), false) + .await + .unwrap() + { HandshakeData::Startup(stream, params) => (stream, params), HandshakeData::Cancel(_) => panic!("cancellation not supported"), }; diff --git a/proxy/src/proxy/wake_compute.rs b/proxy/src/proxy/wake_compute.rs index bfe4b7ec3a..fe228ab33d 100644 --- a/proxy/src/proxy/wake_compute.rs +++ b/proxy/src/proxy/wake_compute.rs @@ -1,6 +1,6 @@ use crate::console::{errors::WakeComputeError, provider::CachedNodeInfo}; use crate::context::RequestMonitoring; -use crate::metrics::{bool_to_str, NUM_WAKEUP_FAILURES}; +use crate::metrics::{ConnectionFailuresBreakdownGroup, Metrics, WakeupFailureKind}; use crate::proxy::retry::retry_after; use hyper::StatusCode; use std::ops::ControlFlow; @@ -57,39 +57,47 @@ pub fn handle_try_wake( fn report_error(e: &WakeComputeError, retry: bool) { use crate::console::errors::ApiError; - let retry = bool_to_str(retry); let kind = match e { - WakeComputeError::BadComputeAddress(_) => "bad_compute_address", - WakeComputeError::ApiError(ApiError::Transport(_)) => "api_transport_error", + WakeComputeError::BadComputeAddress(_) => WakeupFailureKind::BadComputeAddress, + WakeComputeError::ApiError(ApiError::Transport(_)) => WakeupFailureKind::ApiTransportError, WakeComputeError::ApiError(ApiError::Console { status: StatusCode::LOCKED, ref text, }) if text.contains("written data quota exceeded") || text.contains("the limit for current plan reached") => { - "quota_exceeded" + WakeupFailureKind::QuotaExceeded } WakeComputeError::ApiError(ApiError::Console { status: StatusCode::UNPROCESSABLE_ENTITY, ref text, }) if text.contains("compute time quota of non-primary branches is exceeded") => { - "quota_exceeded" + WakeupFailureKind::QuotaExceeded } WakeComputeError::ApiError(ApiError::Console { status: StatusCode::LOCKED, .. - }) => "api_console_locked", + }) => WakeupFailureKind::ApiConsoleLocked, WakeComputeError::ApiError(ApiError::Console { status: StatusCode::BAD_REQUEST, .. - }) => "api_console_bad_request", + }) => WakeupFailureKind::ApiConsoleBadRequest, WakeComputeError::ApiError(ApiError::Console { status, .. }) if status.is_server_error() => { - "api_console_other_server_error" + WakeupFailureKind::ApiConsoleOtherServerError } - WakeComputeError::ApiError(ApiError::Console { .. }) => "api_console_other_error", - WakeComputeError::TimeoutError => "timeout_error", + WakeComputeError::ApiError(ApiError::Console { .. }) => { + WakeupFailureKind::ApiConsoleOtherError + } + WakeComputeError::TooManyConnections => WakeupFailureKind::ApiConsoleLocked, + WakeComputeError::TimeoutError => WakeupFailureKind::TimeoutError, }; - NUM_WAKEUP_FAILURES.with_label_values(&[retry, kind]).inc(); + Metrics::get() + .proxy + .connection_failures_breakdown + .inc(ConnectionFailuresBreakdownGroup { + kind, + retry: retry.into(), + }); } diff --git a/proxy/src/rate_limiter.rs b/proxy/src/rate_limiter.rs index 13dffffca0..c542267547 100644 --- a/proxy/src/rate_limiter.rs +++ b/proxy/src/rate_limiter.rs @@ -1,7 +1,2 @@ -mod aimd; -mod limit_algorithm; mod limiter; -pub use aimd::Aimd; -pub use limit_algorithm::{AimdConfig, Fixed, RateLimitAlgorithm, RateLimiterConfig}; -pub use limiter::Limiter; -pub use limiter::{AuthRateLimiter, EndpointRateLimiter, RateBucketInfo, RedisRateLimiter}; +pub use limiter::{BucketRateLimiter, EndpointRateLimiter, GlobalRateLimiter, RateBucketInfo}; diff --git a/proxy/src/rate_limiter/aimd.rs b/proxy/src/rate_limiter/aimd.rs deleted file mode 100644 index 2c14a54a6c..0000000000 --- a/proxy/src/rate_limiter/aimd.rs +++ /dev/null @@ -1,166 +0,0 @@ -use std::usize; - -use async_trait::async_trait; - -use super::limit_algorithm::{AimdConfig, LimitAlgorithm, Sample}; - -use super::limiter::Outcome; - -/// Loss-based congestion avoidance. -/// -/// Additive-increase, multiplicative decrease. -/// -/// Adds available currency when: -/// 1. no load-based errors are observed, and -/// 2. the utilisation of the current limit is high. -/// -/// Reduces available concurrency by a factor when load-based errors are detected. -pub struct Aimd { - min_limit: usize, - max_limit: usize, - decrease_factor: f32, - increase_by: usize, - min_utilisation_threshold: f32, -} - -impl Aimd { - pub fn new(config: AimdConfig) -> Self { - Self { - min_limit: config.aimd_min_limit, - max_limit: config.aimd_max_limit, - decrease_factor: config.aimd_decrease_factor, - increase_by: config.aimd_increase_by, - min_utilisation_threshold: config.aimd_min_utilisation_threshold, - } - } -} - -#[async_trait] -impl LimitAlgorithm for Aimd { - async fn update(&mut self, old_limit: usize, sample: Sample) -> usize { - use Outcome::*; - match sample.outcome { - Success => { - let utilisation = sample.in_flight as f32 / old_limit as f32; - - if utilisation > self.min_utilisation_threshold { - let limit = old_limit + self.increase_by; - limit.clamp(self.min_limit, self.max_limit) - } else { - old_limit - } - } - Overload => { - let limit = old_limit as f32 * self.decrease_factor; - - // Floor instead of round, so the limit reduces even with small numbers. - // E.g. round(2 * 0.9) = 2, but floor(2 * 0.9) = 1 - let limit = limit.floor() as usize; - - limit.clamp(self.min_limit, self.max_limit) - } - } - } -} - -#[cfg(test)] -mod tests { - use std::sync::Arc; - - use tokio::sync::Notify; - - use super::*; - - use crate::rate_limiter::{Limiter, RateLimiterConfig}; - - #[tokio::test] - async fn should_decrease_limit_on_overload() { - let config = RateLimiterConfig { - initial_limit: 10, - aimd_config: Some(AimdConfig { - aimd_decrease_factor: 0.5, - ..Default::default() - }), - disable: false, - ..Default::default() - }; - - let release_notifier = Arc::new(Notify::new()); - - let limiter = Limiter::new(config).with_release_notifier(release_notifier.clone()); - - let token = limiter.try_acquire().unwrap(); - limiter.release(token, Some(Outcome::Overload)).await; - release_notifier.notified().await; - assert_eq!(limiter.state().limit(), 5, "overload: decrease"); - } - - #[tokio::test] - async fn should_increase_limit_on_success_when_using_gt_util_threshold() { - let config = RateLimiterConfig { - initial_limit: 4, - aimd_config: Some(AimdConfig { - aimd_decrease_factor: 0.5, - aimd_min_utilisation_threshold: 0.5, - aimd_increase_by: 1, - ..Default::default() - }), - disable: false, - ..Default::default() - }; - - let limiter = Limiter::new(config); - - let token = limiter.try_acquire().unwrap(); - let _token = limiter.try_acquire().unwrap(); - let _token = limiter.try_acquire().unwrap(); - - limiter.release(token, Some(Outcome::Success)).await; - assert_eq!(limiter.state().limit(), 5, "success: increase"); - } - - #[tokio::test] - async fn should_not_change_limit_on_success_when_using_lt_util_threshold() { - let config = RateLimiterConfig { - initial_limit: 4, - aimd_config: Some(AimdConfig { - aimd_decrease_factor: 0.5, - aimd_min_utilisation_threshold: 0.5, - ..Default::default() - }), - disable: false, - ..Default::default() - }; - - let limiter = Limiter::new(config); - - let token = limiter.try_acquire().unwrap(); - - limiter.release(token, Some(Outcome::Success)).await; - assert_eq!( - limiter.state().limit(), - 4, - "success: ignore when < half limit" - ); - } - - #[tokio::test] - async fn should_not_change_limit_when_no_outcome() { - let config = RateLimiterConfig { - initial_limit: 10, - aimd_config: Some(AimdConfig { - aimd_decrease_factor: 0.5, - aimd_min_utilisation_threshold: 0.5, - ..Default::default() - }), - disable: false, - ..Default::default() - }; - - let limiter = Limiter::new(config); - - let token = limiter.try_acquire().unwrap(); - limiter.release(token, None).await; - assert_eq!(limiter.state().limit(), 10, "ignore"); - } -} diff --git a/proxy/src/rate_limiter/limit_algorithm.rs b/proxy/src/rate_limiter/limit_algorithm.rs deleted file mode 100644 index 5cd2d5ebb7..0000000000 --- a/proxy/src/rate_limiter/limit_algorithm.rs +++ /dev/null @@ -1,98 +0,0 @@ -//! Algorithms for controlling concurrency limits. -use async_trait::async_trait; -use std::time::Duration; - -use super::{limiter::Outcome, Aimd}; - -/// An algorithm for controlling a concurrency limit. -#[async_trait] -pub trait LimitAlgorithm: Send + Sync + 'static { - /// Update the concurrency limit in response to a new job completion. - async fn update(&mut self, old_limit: usize, sample: Sample) -> usize; -} - -/// The result of a job (or jobs), including the [Outcome] (loss) and latency (delay). -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct Sample { - pub(crate) latency: Duration, - /// Jobs in flight when the sample was taken. - pub(crate) in_flight: usize, - pub(crate) outcome: Outcome, -} - -#[derive(Clone, Copy, Debug, Default, clap::ValueEnum)] -pub enum RateLimitAlgorithm { - Fixed, - #[default] - Aimd, -} - -pub struct Fixed; - -#[async_trait] -impl LimitAlgorithm for Fixed { - async fn update(&mut self, old_limit: usize, _sample: Sample) -> usize { - old_limit - } -} - -#[derive(Clone, Copy, Debug)] -pub struct RateLimiterConfig { - pub disable: bool, - pub algorithm: RateLimitAlgorithm, - pub timeout: Duration, - pub initial_limit: usize, - pub aimd_config: Option, -} - -impl RateLimiterConfig { - pub fn create_rate_limit_algorithm(self) -> Box { - match self.algorithm { - RateLimitAlgorithm::Fixed => Box::new(Fixed), - RateLimitAlgorithm::Aimd => Box::new(Aimd::new(self.aimd_config.unwrap())), // For aimd algorithm config is mandatory. - } - } -} - -impl Default for RateLimiterConfig { - fn default() -> Self { - Self { - disable: true, - algorithm: RateLimitAlgorithm::Aimd, - timeout: Duration::from_secs(1), - initial_limit: 100, - aimd_config: Some(AimdConfig::default()), - } - } -} - -#[derive(clap::Parser, Clone, Copy, Debug)] -pub struct AimdConfig { - /// Minimum limit for AIMD algorithm. Makes sense only if `rate_limit_algorithm` is `Aimd`. - #[clap(long, default_value_t = 1)] - pub aimd_min_limit: usize, - /// Maximum limit for AIMD algorithm. Makes sense only if `rate_limit_algorithm` is `Aimd`. - #[clap(long, default_value_t = 1500)] - pub aimd_max_limit: usize, - /// Increase AIMD increase by value in case of success. Makes sense only if `rate_limit_algorithm` is `Aimd`. - #[clap(long, default_value_t = 10)] - pub aimd_increase_by: usize, - /// Decrease AIMD decrease by value in case of timout/429. Makes sense only if `rate_limit_algorithm` is `Aimd`. - #[clap(long, default_value_t = 0.9)] - pub aimd_decrease_factor: f32, - /// A threshold below which the limit won't be increased. Makes sense only if `rate_limit_algorithm` is `Aimd`. - #[clap(long, default_value_t = 0.8)] - pub aimd_min_utilisation_threshold: f32, -} - -impl Default for AimdConfig { - fn default() -> Self { - Self { - aimd_min_limit: 1, - aimd_max_limit: 1500, - aimd_increase_by: 10, - aimd_decrease_factor: 0.9, - aimd_min_utilisation_threshold: 0.8, - } - } -} diff --git a/proxy/src/rate_limiter/limiter.rs b/proxy/src/rate_limiter/limiter.rs index f590896dd9..5ba2c36436 100644 --- a/proxy/src/rate_limiter/limiter.rs +++ b/proxy/src/rate_limiter/limiter.rs @@ -2,10 +2,9 @@ use std::{ borrow::Cow, collections::hash_map::RandomState, hash::{BuildHasher, Hash}, - net::IpAddr, sync::{ atomic::{AtomicUsize, Ordering}, - Arc, Mutex, + Mutex, }, }; @@ -13,24 +12,18 @@ use anyhow::bail; use dashmap::DashMap; use itertools::Itertools; use rand::{rngs::StdRng, Rng, SeedableRng}; -use tokio::sync::{Mutex as AsyncMutex, Semaphore, SemaphorePermit}; -use tokio::time::{timeout, Duration, Instant}; +use tokio::time::{Duration, Instant}; use tracing::info; -use crate::{intern::EndpointIdInt, EndpointId}; +use crate::intern::EndpointIdInt; -use super::{ - limit_algorithm::{LimitAlgorithm, Sample}, - RateLimiterConfig, -}; - -pub struct RedisRateLimiter { +pub struct GlobalRateLimiter { data: Vec, - info: &'static [RateBucketInfo], + info: Vec, } -impl RedisRateLimiter { - pub fn new(info: &'static [RateBucketInfo]) -> Self { +impl GlobalRateLimiter { + pub fn new(info: Vec) -> Self { Self { data: vec![ RateBucket { @@ -50,7 +43,7 @@ impl RedisRateLimiter { let should_allow_request = self .data .iter_mut() - .zip(self.info) + .zip(&self.info) .all(|(bucket, info)| bucket.should_allow_request(info, now, 1)); if should_allow_request { @@ -68,15 +61,7 @@ impl RedisRateLimiter { // Purposefully ignore user name and database name as clients can reconnect // with different names, so we'll end up sending some http requests to // the control plane. -// -// We also may save quite a lot of CPU (I think) by bailing out right after we -// saw SNI, before doing TLS handshake. User-side error messages in that case -// does not look very nice (`SSL SYSCALL error: Undefined error: 0`), so for now -// I went with a more expensive way that yields user-friendlier error messages. -pub type EndpointRateLimiter = BucketRateLimiter; - -// This can't be just per IP because that would limit some PaaS that share IP addresses -pub type AuthRateLimiter = BucketRateLimiter<(EndpointIdInt, IpAddr), StdRng, RandomState>; +pub type EndpointRateLimiter = BucketRateLimiter; pub struct BucketRateLimiter { map: DashMap, Hasher>, @@ -149,19 +134,6 @@ impl RateBucketInfo { Self::new(100, Duration::from_secs(600)), ]; - /// All of these are per endpoint-ip pair. - /// Context: 4096 rounds of pbkdf2 take about 1ms of cpu time to execute (1 milli-cpu-second or 1mcpus). - /// - /// First bucket: 300mcpus total per endpoint-ip pair - /// * 1228800 requests per second with 1 hash rounds. (endpoint rate limiter will catch this first) - /// * 300 requests per second with 4096 hash rounds. - /// * 2 requests per second with 600000 hash rounds. - pub const DEFAULT_AUTH_SET: [Self; 3] = [ - Self::new(300 * 4096, Duration::from_secs(1)), - Self::new(200 * 4096, Duration::from_secs(60)), - Self::new(100 * 4096, Duration::from_secs(600)), - ]; - pub fn validate(info: &mut [Self]) -> anyhow::Result<()> { info.sort_unstable_by_key(|info| info.interval); let invalid = info @@ -259,419 +231,16 @@ impl BucketRateLimiter { } } -/// Limits the number of concurrent jobs. -/// -/// Concurrency is limited through the use of [Token]s. Acquire a token to run a job, and release the -/// token once the job is finished. -/// -/// The limit will be automatically adjusted based on observed latency (delay) and/or failures -/// caused by overload (loss). -pub struct Limiter { - limit_algo: AsyncMutex>, - semaphore: std::sync::Arc, - config: RateLimiterConfig, - - // ONLY WRITE WHEN LIMIT_ALGO IS LOCKED - limits: AtomicUsize, - - // ONLY USE ATOMIC ADD/SUB - in_flight: Arc, - - #[cfg(test)] - notifier: Option>, -} - -/// A concurrency token, required to run a job. -/// -/// Release the token back to the [Limiter] after the job is complete. -#[derive(Debug)] -pub struct Token<'t> { - permit: Option>, - start: Instant, - in_flight: Arc, -} - -/// A snapshot of the state of the [Limiter]. -/// -/// Not guaranteed to be consistent under high concurrency. -#[derive(Debug, Clone, Copy)] -pub struct LimiterState { - limit: usize, - in_flight: usize, -} - -/// Whether a job succeeded or failed as a result of congestion/overload. -/// -/// Errors not considered to be caused by overload should be ignored. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum Outcome { - /// The job succeeded, or failed in a way unrelated to overload. - Success, - /// The job failed because of overload, e.g. it timed out or an explicit backpressure signal - /// was observed. - Overload, -} - -impl Outcome { - fn from_reqwest_error(error: &reqwest_middleware::Error) -> Self { - match error { - reqwest_middleware::Error::Middleware(_) => Outcome::Success, - reqwest_middleware::Error::Reqwest(e) => { - if let Some(status) = e.status() { - if status.is_server_error() - || reqwest::StatusCode::TOO_MANY_REQUESTS.as_u16() == status - { - Outcome::Overload - } else { - Outcome::Success - } - } else { - Outcome::Success - } - } - } - } - fn from_reqwest_response(response: &reqwest::Response) -> Self { - if response.status().is_server_error() - || response.status() == reqwest::StatusCode::TOO_MANY_REQUESTS - { - Outcome::Overload - } else { - Outcome::Success - } - } -} - -impl Limiter { - /// Create a limiter with a given limit control algorithm. - pub fn new(config: RateLimiterConfig) -> Self { - assert!(config.initial_limit > 0); - Self { - limit_algo: AsyncMutex::new(config.create_rate_limit_algorithm()), - semaphore: Arc::new(Semaphore::new(config.initial_limit)), - config, - limits: AtomicUsize::new(config.initial_limit), - in_flight: Arc::new(AtomicUsize::new(0)), - #[cfg(test)] - notifier: None, - } - } - // pub fn new(limit_algorithm: T, timeout: Duration, initial_limit: usize) -> Self { - // assert!(initial_limit > 0); - - // Self { - // limit_algo: AsyncMutex::new(limit_algorithm), - // semaphore: Arc::new(Semaphore::new(initial_limit)), - // timeout, - // limits: AtomicUsize::new(initial_limit), - // in_flight: Arc::new(AtomicUsize::new(0)), - // #[cfg(test)] - // notifier: None, - // } - // } - - /// In some cases [Token]s are acquired asynchronously when updating the limit. - #[cfg(test)] - pub fn with_release_notifier(mut self, n: std::sync::Arc) -> Self { - self.notifier = Some(n); - self - } - - /// Try to immediately acquire a concurrency [Token]. - /// - /// Returns `None` if there are none available. - pub fn try_acquire(&self) -> Option { - let result = if self.config.disable { - // If the rate limiter is disabled, we can always acquire a token. - Some(Token::new(None, self.in_flight.clone())) - } else { - self.semaphore - .try_acquire() - .map(|permit| Token::new(Some(permit), self.in_flight.clone())) - .ok() - }; - if result.is_some() { - self.in_flight.fetch_add(1, Ordering::AcqRel); - } - result - } - - /// Try to acquire a concurrency [Token], waiting for `duration` if there are none available. - /// - /// Returns `None` if there are none available after `duration`. - pub async fn acquire_timeout(&self, duration: Duration) -> Option> { - info!("acquiring token: {:?}", self.semaphore.available_permits()); - let result = if self.config.disable { - // If the rate limiter is disabled, we can always acquire a token. - Some(Token::new(None, self.in_flight.clone())) - } else { - match timeout(duration, self.semaphore.acquire()).await { - Ok(maybe_permit) => maybe_permit - .map(|permit| Token::new(Some(permit), self.in_flight.clone())) - .ok(), - Err(_) => None, - } - }; - if result.is_some() { - self.in_flight.fetch_add(1, Ordering::AcqRel); - } - result - } - - /// Return the concurrency [Token], along with the outcome of the job. - /// - /// The [Outcome] of the job, and the time taken to perform it, may be used - /// to update the concurrency limit. - /// - /// Set the outcome to `None` to ignore the job. - pub async fn release(&self, mut token: Token<'_>, outcome: Option) { - tracing::info!("outcome is {:?}", outcome); - let in_flight = self.in_flight.load(Ordering::Acquire); - let old_limit = self.limits.load(Ordering::Acquire); - let available = if self.config.disable { - 0 // This is not used in the algorithm and can be anything. If the config disable it makes sense to set it to 0. - } else { - self.semaphore.available_permits() - }; - let total = in_flight + available; - - let mut algo = self.limit_algo.lock().await; - - let new_limit = if let Some(outcome) = outcome { - let sample = Sample { - latency: token.start.elapsed(), - in_flight, - outcome, - }; - algo.update(old_limit, sample).await - } else { - old_limit - }; - tracing::info!("new limit is {}", new_limit); - let actual_limit = if new_limit < total { - token.forget(); - total.saturating_sub(1) - } else { - if !self.config.disable { - self.semaphore.add_permits(new_limit.saturating_sub(total)); - } - new_limit - }; - crate::metrics::RATE_LIMITER_LIMIT - .with_label_values(&["expected"]) - .set(new_limit as i64); - crate::metrics::RATE_LIMITER_LIMIT - .with_label_values(&["actual"]) - .set(actual_limit as i64); - self.limits.store(new_limit, Ordering::Release); - #[cfg(test)] - if let Some(n) = &self.notifier { - n.notify_one(); - } - } - - /// The current state of the limiter. - pub fn state(&self) -> LimiterState { - let limit = self.limits.load(Ordering::Relaxed); - let in_flight = self.in_flight.load(Ordering::Relaxed); - LimiterState { limit, in_flight } - } -} - -impl<'t> Token<'t> { - fn new(permit: Option>, in_flight: Arc) -> Self { - Self { - permit, - start: Instant::now(), - in_flight, - } - } - - pub fn forget(&mut self) { - if let Some(permit) = self.permit.take() { - permit.forget(); - } - } -} - -impl Drop for Token<'_> { - fn drop(&mut self) { - self.in_flight.fetch_sub(1, Ordering::AcqRel); - } -} - -impl LimiterState { - /// The current concurrency limit. - pub fn limit(&self) -> usize { - self.limit - } - /// The number of jobs in flight. - pub fn in_flight(&self) -> usize { - self.in_flight - } -} - -#[async_trait::async_trait] -impl reqwest_middleware::Middleware for Limiter { - async fn handle( - &self, - req: reqwest::Request, - extensions: &mut task_local_extensions::Extensions, - next: reqwest_middleware::Next<'_>, - ) -> reqwest_middleware::Result { - let start = Instant::now(); - let token = self - .acquire_timeout(self.config.timeout) - .await - .ok_or_else(|| { - reqwest_middleware::Error::Middleware( - // TODO: Should we map it into user facing errors? - crate::console::errors::ApiError::Console { - status: crate::http::StatusCode::TOO_MANY_REQUESTS, - text: "Too many requests".into(), - } - .into(), - ) - })?; - info!(duration = ?start.elapsed(), "waiting for token to connect to the control plane"); - crate::metrics::RATE_LIMITER_ACQUIRE_LATENCY.observe(start.elapsed().as_secs_f64()); - match next.run(req, extensions).await { - Ok(response) => { - self.release(token, Some(Outcome::from_reqwest_response(&response))) - .await; - Ok(response) - } - Err(e) => { - self.release(token, Some(Outcome::from_reqwest_error(&e))) - .await; - Err(e) - } - } - } -} - #[cfg(test)] mod tests { - use std::{hash::BuildHasherDefault, pin::pin, task::Context, time::Duration}; + use std::{hash::BuildHasherDefault, time::Duration}; - use futures::{task::noop_waker_ref, Future}; use rand::SeedableRng; use rustc_hash::FxHasher; use tokio::time; - use super::{BucketRateLimiter, EndpointRateLimiter, Limiter, Outcome}; - use crate::{ - rate_limiter::{RateBucketInfo, RateLimitAlgorithm}, - EndpointId, - }; - - #[tokio::test] - async fn it_works() { - let config = super::RateLimiterConfig { - algorithm: RateLimitAlgorithm::Fixed, - timeout: Duration::from_secs(1), - initial_limit: 10, - disable: false, - ..Default::default() - }; - let limiter = Limiter::new(config); - - let token = limiter.try_acquire().unwrap(); - - limiter.release(token, Some(Outcome::Success)).await; - - assert_eq!(limiter.state().limit(), 10); - } - - #[tokio::test] - async fn is_fair() { - let config = super::RateLimiterConfig { - algorithm: RateLimitAlgorithm::Fixed, - timeout: Duration::from_secs(1), - initial_limit: 1, - disable: false, - ..Default::default() - }; - let limiter = Limiter::new(config); - - // === TOKEN 1 === - let token1 = limiter.try_acquire().unwrap(); - - let mut token2_fut = pin!(limiter.acquire_timeout(Duration::from_secs(1))); - assert!( - token2_fut - .as_mut() - .poll(&mut Context::from_waker(noop_waker_ref())) - .is_pending(), - "token is acquired by token1" - ); - - let mut token3_fut = pin!(limiter.acquire_timeout(Duration::from_secs(1))); - assert!( - token3_fut - .as_mut() - .poll(&mut Context::from_waker(noop_waker_ref())) - .is_pending(), - "token is acquired by token1" - ); - - limiter.release(token1, Some(Outcome::Success)).await; - // === END TOKEN 1 === - - // === TOKEN 2 === - assert!( - limiter.try_acquire().is_none(), - "token is acquired by token2" - ); - - assert!( - token3_fut - .as_mut() - .poll(&mut Context::from_waker(noop_waker_ref())) - .is_pending(), - "token is acquired by token2" - ); - - let token2 = token2_fut.await.unwrap(); - - limiter.release(token2, Some(Outcome::Success)).await; - // === END TOKEN 2 === - - // === TOKEN 3 === - assert!( - limiter.try_acquire().is_none(), - "token is acquired by token3" - ); - - let token3 = token3_fut.await.unwrap(); - limiter.release(token3, Some(Outcome::Success)).await; - // === END TOKEN 3 === - - // === TOKEN 4 === - let token4 = limiter.try_acquire().unwrap(); - limiter.release(token4, Some(Outcome::Success)).await; - } - - #[tokio::test] - async fn disable() { - let config = super::RateLimiterConfig { - algorithm: RateLimitAlgorithm::Fixed, - timeout: Duration::from_secs(1), - initial_limit: 1, - disable: true, - ..Default::default() - }; - let limiter = Limiter::new(config); - - // === TOKEN 1 === - let token1 = limiter.try_acquire().unwrap(); - let token2 = limiter.try_acquire().unwrap(); - let state = limiter.state(); - assert_eq!(state.limit(), 1); - assert_eq!(state.in_flight(), 2); // For disabled limiter, it's expected. - limiter.release(token1, None).await; - limiter.release(token2, None).await; - } + use super::{BucketRateLimiter, EndpointRateLimiter}; + use crate::{intern::EndpointIdInt, rate_limiter::RateBucketInfo, EndpointId}; #[test] fn rate_bucket_rpi() { @@ -721,39 +290,40 @@ mod tests { let limiter = EndpointRateLimiter::new(rates); let endpoint = EndpointId::from("ep-my-endpoint-1234"); + let endpoint = EndpointIdInt::from(endpoint); time::pause(); for _ in 0..100 { - assert!(limiter.check(endpoint.clone(), 1)); + assert!(limiter.check(endpoint, 1)); } // more connections fail - assert!(!limiter.check(endpoint.clone(), 1)); + assert!(!limiter.check(endpoint, 1)); // fail even after 500ms as it's in the same bucket time::advance(time::Duration::from_millis(500)).await; - assert!(!limiter.check(endpoint.clone(), 1)); + assert!(!limiter.check(endpoint, 1)); // after a full 1s, 100 requests are allowed again time::advance(time::Duration::from_millis(500)).await; for _ in 1..6 { for _ in 0..50 { - assert!(limiter.check(endpoint.clone(), 2)); + assert!(limiter.check(endpoint, 2)); } time::advance(time::Duration::from_millis(1000)).await; } // more connections after 600 will exceed the 20rps@30s limit - assert!(!limiter.check(endpoint.clone(), 1)); + assert!(!limiter.check(endpoint, 1)); // will still fail before the 30 second limit time::advance(time::Duration::from_millis(30_000 - 6_000 - 1)).await; - assert!(!limiter.check(endpoint.clone(), 1)); + assert!(!limiter.check(endpoint, 1)); // after the full 30 seconds, 100 requests are allowed again time::advance(time::Duration::from_millis(1)).await; for _ in 0..100 { - assert!(limiter.check(endpoint.clone(), 1)); + assert!(limiter.check(endpoint, 1)); } } @@ -773,31 +343,4 @@ mod tests { } assert!(limiter.map.len() < 150_000); } - - #[test] - fn test_default_auth_set() { - // these values used to exceed u32::MAX - assert_eq!( - RateBucketInfo::DEFAULT_AUTH_SET, - [ - RateBucketInfo { - interval: Duration::from_secs(1), - max_rpi: 300 * 4096, - }, - RateBucketInfo { - interval: Duration::from_secs(60), - max_rpi: 200 * 4096 * 60, - }, - RateBucketInfo { - interval: Duration::from_secs(600), - max_rpi: 100 * 4096 * 600, - } - ] - ); - - for x in RateBucketInfo::DEFAULT_AUTH_SET { - let y = x.to_string().parse().unwrap(); - assert_eq!(x, y); - } - } } diff --git a/proxy/src/redis/cancellation_publisher.rs b/proxy/src/redis/cancellation_publisher.rs index 422789813c..7baf104374 100644 --- a/proxy/src/redis/cancellation_publisher.rs +++ b/proxy/src/redis/cancellation_publisher.rs @@ -5,7 +5,7 @@ use redis::AsyncCommands; use tokio::sync::Mutex; use uuid::Uuid; -use crate::rate_limiter::{RateBucketInfo, RedisRateLimiter}; +use crate::rate_limiter::{GlobalRateLimiter, RateBucketInfo}; use super::{ connection_with_credentials_provider::ConnectionWithCredentialsProvider, @@ -80,7 +80,7 @@ impl CancellationPublisher for Arc> { pub struct RedisPublisherClient { client: ConnectionWithCredentialsProvider, region_id: String, - limiter: RedisRateLimiter, + limiter: GlobalRateLimiter, } impl RedisPublisherClient { @@ -92,7 +92,7 @@ impl RedisPublisherClient { Ok(Self { client, region_id, - limiter: RedisRateLimiter::new(info), + limiter: GlobalRateLimiter::new(info.into()), }) } diff --git a/proxy/src/redis/connection_with_credentials_provider.rs b/proxy/src/redis/connection_with_credentials_provider.rs index d183abb53a..3a90d911c2 100644 --- a/proxy/src/redis/connection_with_credentials_provider.rs +++ b/proxy/src/redis/connection_with_credentials_provider.rs @@ -77,10 +77,14 @@ impl ConnectionWithCredentialsProvider { } } + async fn ping(con: &mut MultiplexedConnection) -> RedisResult<()> { + redis::cmd("PING").query_async(con).await + } + pub async fn connect(&mut self) -> anyhow::Result<()> { let _guard = self.mutex.lock().await; if let Some(con) = self.con.as_mut() { - match redis::cmd("PING").query_async(con).await { + match Self::ping(con).await { Ok(()) => { return Ok(()); } @@ -96,7 +100,7 @@ impl ConnectionWithCredentialsProvider { if let Some(f) = self.refresh_token_task.take() { f.abort() } - let con = self + let mut con = self .get_client() .await? .get_multiplexed_tokio_connection() @@ -109,6 +113,14 @@ impl ConnectionWithCredentialsProvider { }); self.refresh_token_task = Some(f); } + match Self::ping(&mut con).await { + Ok(()) => { + info!("Connection succesfully established"); + } + Err(e) => { + error!("Connection is broken. Error during PING: {e:?}"); + } + } self.con = Some(con); Ok(()) } diff --git a/proxy/src/redis/notifications.rs b/proxy/src/redis/notifications.rs index 8b7e3e3419..5a38530faf 100644 --- a/proxy/src/redis/notifications.rs +++ b/proxy/src/redis/notifications.rs @@ -11,7 +11,7 @@ use crate::{ cache::project_info::ProjectInfoCache, cancellation::{CancelMap, CancellationHandler}, intern::{ProjectIdInt, RoleNameInt}, - metrics::{NUM_CANCELLATION_REQUESTS_SOURCE_FROM_REDIS, REDIS_BROKEN_MESSAGES}, + metrics::{Metrics, RedisErrors}, }; const CPLANE_CHANNEL_NAME: &str = "neondb-proxy-ws-updates"; @@ -104,9 +104,9 @@ impl MessageHandler { let msg: Notification = match serde_json::from_str(&payload) { Ok(msg) => msg, Err(e) => { - REDIS_BROKEN_MESSAGES - .with_label_values(&[msg.get_channel_name()]) - .inc(); + Metrics::get().proxy.redis_errors_total.inc(RedisErrors { + channel: msg.get_channel_name(), + }); tracing::error!("broken message: {e}"); return Ok(()); } @@ -183,7 +183,7 @@ where cache, Arc::new(CancellationHandler::<()>::new( cancel_map, - NUM_CANCELLATION_REQUESTS_SOURCE_FROM_REDIS, + crate::metrics::CancellationSource::FromRedis, )), region_id, ); diff --git a/proxy/src/serverless.rs b/proxy/src/serverless.rs index f275caa7eb..b0f4026c76 100644 --- a/proxy/src/serverless.rs +++ b/proxy/src/serverless.rs @@ -32,10 +32,9 @@ use tokio_util::task::TaskTracker; use crate::cancellation::CancellationHandlerMain; use crate::config::ProxyConfig; use crate::context::RequestMonitoring; -use crate::metrics::{NUM_CLIENT_CONNECTION_GAUGE, TLS_HANDSHAKE_FAILURES}; +use crate::metrics::Metrics; use crate::protocol2::WithClientIp; use crate::proxy::run_until_cancelled; -use crate::rate_limiter::EndpointRateLimiter; use crate::serverless::backend::PoolingBackend; use crate::serverless::http_util::{api_error_into_response, json_response}; @@ -53,7 +52,6 @@ pub async fn task_main( config: &'static ProxyConfig, ws_listener: TcpListener, cancellation_token: CancellationToken, - endpoint_rate_limiter: Arc, cancellation_handler: Arc, ) -> anyhow::Result<()> { scopeguard::defer! { @@ -117,7 +115,6 @@ pub async fn task_main( backend.clone(), connections.clone(), cancellation_handler.clone(), - endpoint_rate_limiter.clone(), cancellation_token.clone(), server.clone(), tls_acceptor.clone(), @@ -147,7 +144,6 @@ async fn connection_handler( backend: Arc, connections: TaskTracker, cancellation_handler: Arc, - endpoint_rate_limiter: Arc, cancellation_token: CancellationToken, server: Builder, tls_acceptor: TlsAcceptor, @@ -156,9 +152,10 @@ async fn connection_handler( ) { let session_id = uuid::Uuid::new_v4(); - let _gauge = NUM_CLIENT_CONNECTION_GAUGE - .with_label_values(&["http"]) - .guard(); + let _gauge = Metrics::get() + .proxy + .client_connections + .guard(crate::metrics::Protocol::Http); // handle PROXY protocol let mut conn = WithClientIp::new(conn); @@ -171,6 +168,10 @@ async fn connection_handler( }; let peer_addr = peer.unwrap_or(peer_addr).ip(); + let has_private_peer_addr = match peer_addr { + IpAddr::V4(ip) => ip.is_private(), + _ => false, + }; info!(?session_id, %peer_addr, "accepted new TCP connection"); // try upgrade to TLS, but with a timeout. @@ -181,13 +182,17 @@ async fn connection_handler( } // The handshake failed Ok(Err(e)) => { - TLS_HANDSHAKE_FAILURES.inc(); + if !has_private_peer_addr { + Metrics::get().proxy.tls_handshake_failures.inc(); + } warn!(?session_id, %peer_addr, "failed to accept TLS connection: {e:?}"); return; } // The handshake timed out Err(e) => { - TLS_HANDSHAKE_FAILURES.inc(); + if !has_private_peer_addr { + Metrics::get().proxy.tls_handshake_failures.inc(); + } warn!(?session_id, %peer_addr, "failed to accept TLS connection: {e:?}"); return; } @@ -222,7 +227,6 @@ async fn connection_handler( cancellation_handler.clone(), session_id, peer_addr, - endpoint_rate_limiter.clone(), http_request_token, ) .in_current_span() @@ -261,7 +265,6 @@ async fn request_handler( cancellation_handler: Arc, session_id: uuid::Uuid, peer_addr: IpAddr, - endpoint_rate_limiter: Arc, // used to cancel in-flight HTTP requests. not used to cancel websockets http_cancellation_token: CancellationToken, ) -> Result>, ApiError> { @@ -274,7 +277,13 @@ async fn request_handler( // Check if the request is a websocket upgrade request. if hyper_tungstenite::is_upgrade_request(&request) { - let ctx = RequestMonitoring::new(session_id, peer_addr, "ws", &config.region); + let ctx = RequestMonitoring::new( + session_id, + peer_addr, + crate::metrics::Protocol::Ws, + &config.region, + ); + let span = ctx.span.clone(); info!(parent: &span, "performing websocket upgrade"); @@ -283,15 +292,9 @@ async fn request_handler( ws_connections.spawn( async move { - if let Err(e) = websocket::serve_websocket( - config, - ctx, - websocket, - cancellation_handler, - host, - endpoint_rate_limiter, - ) - .await + if let Err(e) = + websocket::serve_websocket(config, ctx, websocket, cancellation_handler, host) + .await { error!("error in websocket connection: {e:#}"); } @@ -302,7 +305,12 @@ async fn request_handler( // Return the response so the spawned future can continue. Ok(response) } else if request.uri().path() == "/sql" && *request.method() == Method::POST { - let ctx = RequestMonitoring::new(session_id, peer_addr, "http", &config.region); + let ctx = RequestMonitoring::new( + session_id, + peer_addr, + crate::metrics::Protocol::Http, + &config.region, + ); let span = ctx.span.clone(); sql_over_http::handle(config, ctx, request, backend, http_cancellation_token) diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index 8aa5ad4e8a..e74c63599a 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -6,7 +6,7 @@ use tracing::{field::display, info}; use crate::{ auth::{backend::ComputeCredentials, check_peer_addr_is_in_list, AuthError}, compute, - config::ProxyConfig, + config::{AuthenticationConfig, ProxyConfig}, console::{ errors::{GetAuthInfoError, WakeComputeError}, CachedNodeInfo, @@ -27,6 +27,7 @@ impl PoolingBackend { pub async fn authenticate( &self, ctx: &mut RequestMonitoring, + config: &AuthenticationConfig, conn_info: &ConnInfo, ) -> Result { let user_info = conn_info.user_info.clone(); @@ -43,6 +44,7 @@ impl PoolingBackend { let secret = match cached_secret.value.clone() { Some(secret) => self.config.authentication_config.check_rate_limit( ctx, + config, secret, &user_info.endpoint, true, diff --git a/proxy/src/serverless/conn_pool.rs b/proxy/src/serverless/conn_pool.rs index 0a89c56b6f..798e488509 100644 --- a/proxy/src/serverless/conn_pool.rs +++ b/proxy/src/serverless/conn_pool.rs @@ -1,6 +1,5 @@ use dashmap::DashMap; use futures::{future::poll_fn, Future}; -use metrics::IntCounterPairGuard; use parking_lot::RwLock; use rand::Rng; use smallvec::SmallVec; @@ -19,11 +18,10 @@ use tokio_postgres::{AsyncMessage, ReadyForQueryStatus, Socket}; use tokio_util::sync::CancellationToken; use crate::console::messages::{ColdStartInfo, MetricsAuxInfo}; -use crate::metrics::{ENDPOINT_POOLS, GC_LATENCY, NUM_OPEN_CLIENTS_IN_HTTP_POOL}; +use crate::metrics::{HttpEndpointPoolsGuard, Metrics}; use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS}; use crate::{ - auth::backend::ComputeUserInfo, context::RequestMonitoring, metrics::NUM_DB_CONNECTIONS_GAUGE, - DbName, EndpointCacheKey, RoleName, + auth::backend::ComputeUserInfo, context::RequestMonitoring, DbName, EndpointCacheKey, RoleName, }; use tracing::{debug, error, warn, Span}; @@ -79,7 +77,7 @@ pub struct EndpointConnPool { pools: HashMap<(DbName, RoleName), DbUserConnPool>, total_conns: usize, max_conns: usize, - _guard: IntCounterPairGuard, + _guard: HttpEndpointPoolsGuard<'static>, global_connections_count: Arc, global_pool_size_max_conns: usize, } @@ -111,7 +109,11 @@ impl EndpointConnPool { let removed = old_len - new_len; if removed > 0 { global_connections_count.fetch_sub(removed, atomic::Ordering::Relaxed); - NUM_OPEN_CLIENTS_IN_HTTP_POOL.sub(removed as i64); + Metrics::get() + .proxy + .http_pool_opened_connections + .get_metric() + .dec_by(removed as i64); } *total_conns -= removed; removed > 0 @@ -157,7 +159,11 @@ impl EndpointConnPool { pool.total_conns += 1; pool.global_connections_count .fetch_add(1, atomic::Ordering::Relaxed); - NUM_OPEN_CLIENTS_IN_HTTP_POOL.inc(); + Metrics::get() + .proxy + .http_pool_opened_connections + .get_metric() + .inc(); } pool.total_conns @@ -177,7 +183,11 @@ impl Drop for EndpointConnPool { if self.total_conns > 0 { self.global_connections_count .fetch_sub(self.total_conns, atomic::Ordering::Relaxed); - NUM_OPEN_CLIENTS_IN_HTTP_POOL.sub(self.total_conns as i64); + Metrics::get() + .proxy + .http_pool_opened_connections + .get_metric() + .dec_by(self.total_conns as i64); } } } @@ -216,7 +226,11 @@ impl DbUserConnPool { removed += 1; } global_connections_count.fetch_sub(removed, atomic::Ordering::Relaxed); - NUM_OPEN_CLIENTS_IN_HTTP_POOL.sub(removed as i64); + Metrics::get() + .proxy + .http_pool_opened_connections + .get_metric() + .dec_by(removed as i64); conn } } @@ -304,7 +318,10 @@ impl GlobalConnPool { // acquire a random shard lock let mut shard = self.global_pool.shards()[shard].write(); - let timer = GC_LATENCY.start_timer(); + let timer = Metrics::get() + .proxy + .http_pool_reclaimation_lag_seconds + .start_timer(); let current_len = shard.len(); let mut clients_removed = 0; shard.retain(|endpoint, x| { @@ -332,7 +349,7 @@ impl GlobalConnPool { let new_len = shard.len(); drop(shard); - timer.observe_duration(); + timer.observe(); // Do logging outside of the lock. if clients_removed > 0 { @@ -340,7 +357,11 @@ impl GlobalConnPool { .global_connections_count .fetch_sub(clients_removed, atomic::Ordering::Relaxed) - clients_removed; - NUM_OPEN_CLIENTS_IN_HTTP_POOL.sub(clients_removed as i64); + Metrics::get() + .proxy + .http_pool_opened_connections + .get_metric() + .dec_by(clients_removed as i64); info!("pool: performed global pool gc. removed {clients_removed} clients, total number of clients in pool is {size}"); } let removed = current_len - new_len; @@ -411,7 +432,7 @@ impl GlobalConnPool { pools: HashMap::new(), total_conns: 0, max_conns: self.config.pool_options.max_conns_per_endpoint, - _guard: ENDPOINT_POOLS.guard(), + _guard: Metrics::get().proxy.http_endpoint_pools.guard(), global_connections_count: self.global_connections_count.clone(), global_pool_size_max_conns: self.config.pool_options.max_total_conns, })); @@ -451,9 +472,7 @@ pub fn poll_client( conn_id: uuid::Uuid, aux: MetricsAuxInfo, ) -> Client { - let conn_gauge = NUM_DB_CONNECTIONS_GAUGE - .with_label_values(&[ctx.protocol]) - .guard(); + let conn_gauge = Metrics::get().proxy.db_connections.guard(ctx.protocol); let mut session_id = ctx.session_id; let (tx, mut rx) = tokio::sync::watch::channel(session_id); diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index 7f7f93988c..e856053a7e 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -43,8 +43,8 @@ use crate::context::RequestMonitoring; use crate::error::ErrorKind; use crate::error::ReportableError; use crate::error::UserFacingError; -use crate::metrics::HTTP_CONTENT_LENGTH; -use crate::metrics::NUM_CONNECTION_REQUESTS_GAUGE; +use crate::metrics::HttpDirection; +use crate::metrics::Metrics; use crate::proxy::run_until_cancelled; use crate::proxy::NeonOptions; use crate::serverless::backend::HttpConnError; @@ -494,10 +494,11 @@ async fn handle_inner( request: Request, backend: Arc, ) -> Result>, SqlOverHttpError> { - let _request_gauge = NUM_CONNECTION_REQUESTS_GAUGE - .with_label_values(&[ctx.protocol]) - .guard(); - info!("handling interactive connection from client"); + let _requeset_gauge = Metrics::get().proxy.connection_requests.guard(ctx.protocol); + info!( + protocol = %ctx.protocol, + "handling interactive connection from client" + ); // // Determine the destination and connection params @@ -520,9 +521,10 @@ async fn handle_inner( None => MAX_REQUEST_SIZE + 1, }; info!(request_content_length, "request size in bytes"); - HTTP_CONTENT_LENGTH - .with_label_values(&["request"]) - .observe(request_content_length as f64); + Metrics::get() + .proxy + .http_conn_content_length_bytes + .observe(HttpDirection::Request, request_content_length as f64); // we don't have a streaming request support yet so this is to prevent OOM // from a malicious user sending an extremely large request body @@ -539,7 +541,9 @@ async fn handle_inner( .map_err(SqlOverHttpError::from); let authenticate_and_connect = async { - let keys = backend.authenticate(ctx, &conn_info).await?; + let keys = backend + .authenticate(ctx, &config.authentication_config, &conn_info) + .await?; let client = backend .connect_to_compute(ctx, conn_info, keys, !allow_pool) .await?; @@ -607,9 +611,10 @@ async fn handle_inner( // count the egress bytes - we miss the TLS and header overhead but oh well... // moving this later in the stack is going to be a lot of effort and ehhhh metrics.record_egress(len as u64); - HTTP_CONTENT_LENGTH - .with_label_values(&["response"]) - .observe(len as f64); + Metrics::get() + .proxy + .http_conn_content_length_bytes + .observe(HttpDirection::Response, len as f64); Ok(response) } diff --git a/proxy/src/serverless/websocket.rs b/proxy/src/serverless/websocket.rs index ada6c974f4..eddd278b7d 100644 --- a/proxy/src/serverless/websocket.rs +++ b/proxy/src/serverless/websocket.rs @@ -3,9 +3,8 @@ use crate::{ config::ProxyConfig, context::RequestMonitoring, error::{io_error, ReportableError}, - metrics::NUM_CLIENT_CONNECTION_GAUGE, + metrics::Metrics, proxy::{handle_client, ClientMode}, - rate_limiter::EndpointRateLimiter, }; use bytes::{Buf, Bytes}; use futures::{Sink, Stream}; @@ -136,12 +135,12 @@ pub async fn serve_websocket( websocket: HyperWebsocket, cancellation_handler: Arc, hostname: Option, - endpoint_rate_limiter: Arc, ) -> anyhow::Result<()> { let websocket = websocket.await?; - let conn_gauge = NUM_CLIENT_CONNECTION_GAUGE - .with_label_values(&["ws"]) - .guard(); + let conn_gauge = Metrics::get() + .proxy + .client_connections + .guard(crate::metrics::Protocol::Ws); let res = handle_client( config, @@ -149,7 +148,6 @@ pub async fn serve_websocket( cancellation_handler, WebSocketRw::new(websocket), ClientMode::Websockets { hostname }, - endpoint_rate_limiter, conn_gauge, ) .await; diff --git a/proxy/src/stream.rs b/proxy/src/stream.rs index b6b7a85659..690e92ffb1 100644 --- a/proxy/src/stream.rs +++ b/proxy/src/stream.rs @@ -1,6 +1,6 @@ use crate::config::TlsServerEndPoint; use crate::error::{ErrorKind, ReportableError, UserFacingError}; -use crate::metrics::TLS_HANDSHAKE_FAILURES; +use crate::metrics::Metrics; use bytes::BytesMut; use pq_proto::framed::{ConnectionError, Framed}; @@ -223,12 +223,20 @@ pub enum StreamUpgradeError { impl Stream { /// If possible, upgrade raw stream into a secure TLS-based stream. - pub async fn upgrade(self, cfg: Arc) -> Result, StreamUpgradeError> { + pub async fn upgrade( + self, + cfg: Arc, + record_handshake_error: bool, + ) -> Result, StreamUpgradeError> { match self { Stream::Raw { raw } => Ok(tokio_rustls::TlsAcceptor::from(cfg) .accept(raw) .await - .inspect_err(|_| TLS_HANDSHAKE_FAILURES.inc())?), + .inspect_err(|_| { + if record_handshake_error { + Metrics::get().proxy.tls_handshake_failures.inc() + } + })?), Stream::Tls { .. } => Err(StreamUpgradeError::AlreadyTls), } } diff --git a/proxy/src/usage_metrics.rs b/proxy/src/usage_metrics.rs index 5ffbf95c07..56ed2145dc 100644 --- a/proxy/src/usage_metrics.rs +++ b/proxy/src/usage_metrics.rs @@ -495,7 +495,7 @@ mod tests { use url::Url; use super::*; - use crate::{http, rate_limiter::RateLimiterConfig, BranchId, EndpointId}; + use crate::{http, BranchId, EndpointId}; #[tokio::test] async fn metrics() { @@ -525,7 +525,7 @@ mod tests { tokio::spawn(server); let metrics = Metrics::default(); - let client = http::new_client(RateLimiterConfig::default()); + let client = http::new_client(); let endpoint = Url::parse(&format!("http://{addr}")).unwrap(); let now = Utc::now(); diff --git a/storage_controller/src/compute_hook.rs b/storage_controller/src/compute_hook.rs index eb0c4472e4..1ed8998713 100644 --- a/storage_controller/src/compute_hook.rs +++ b/storage_controller/src/compute_hook.rs @@ -17,6 +17,8 @@ use crate::service::Config; const SLOWDOWN_DELAY: Duration = Duration::from_secs(5); +const NOTIFY_REQUEST_TIMEOUT: Duration = Duration::from_secs(10); + pub(crate) const API_CONCURRENCY: usize = 32; struct UnshardedComputeHookTenant { @@ -242,6 +244,10 @@ pub(super) struct ComputeHook { // This lock is only used in testing enviroments, to serialize calls into neon_lock neon_local_lock: tokio::sync::Mutex<()>, + + // We share a client across all notifications to enable connection re-use etc when + // sending large numbers of notifications + client: reqwest::Client, } impl ComputeHook { @@ -251,12 +257,18 @@ impl ComputeHook { .clone() .map(|jwt| format!("Bearer {}", jwt)); + let client = reqwest::ClientBuilder::new() + .timeout(NOTIFY_REQUEST_TIMEOUT) + .build() + .expect("Failed to construct HTTP client"); + Self { state: Default::default(), config, authorization_header, neon_local_lock: Default::default(), api_concurrency: tokio::sync::Semaphore::new(API_CONCURRENCY), + client, } } @@ -310,12 +322,11 @@ impl ComputeHook { async fn do_notify_iteration( &self, - client: &reqwest::Client, url: &String, reconfigure_request: &ComputeHookNotifyRequest, cancel: &CancellationToken, ) -> Result<(), NotifyError> { - let req = client.request(Method::PUT, url); + let req = self.client.request(Method::PUT, url); let req = if let Some(value) = &self.authorization_header { req.header(reqwest::header::AUTHORIZATION, value) } else { @@ -381,8 +392,6 @@ impl ComputeHook { reconfigure_request: &ComputeHookNotifyRequest, cancel: &CancellationToken, ) -> Result<(), NotifyError> { - let client = reqwest::Client::new(); - // We hold these semaphore units across all retries, rather than only across each // HTTP request: this is to preserve fairness and avoid a situation where a retry might // time out waiting for a semaphore. @@ -394,7 +403,7 @@ impl ComputeHook { .map_err(|_| NotifyError::ShuttingDown)?; backoff::retry( - || self.do_notify_iteration(&client, url, reconfigure_request, cancel), + || self.do_notify_iteration(url, reconfigure_request, cancel), |e| { matches!( e, diff --git a/storage_controller/src/scheduler.rs b/storage_controller/src/scheduler.rs index 862ac0cbfe..3ff0d87988 100644 --- a/storage_controller/src/scheduler.rs +++ b/storage_controller/src/scheduler.rs @@ -84,6 +84,20 @@ impl std::ops::Add for AffinityScore { } } +/// Hint for whether this is a sincere attempt to schedule, or a speculative +/// check for where we _would_ schedule (done during optimization) +#[derive(Debug)] +pub(crate) enum ScheduleMode { + Normal, + Speculative, +} + +impl Default for ScheduleMode { + fn default() -> Self { + Self::Normal + } +} + // For carrying state between multiple calls to [`TenantShard::schedule`], e.g. when calling // it for many shards in the same tenant. #[derive(Debug, Default)] @@ -93,6 +107,8 @@ pub(crate) struct ScheduleContext { /// Specifically how many _attached_ locations are on each node pub(crate) attached_nodes: HashMap, + + pub(crate) mode: ScheduleMode, } impl ScheduleContext { @@ -329,27 +345,34 @@ impl Scheduler { scores.sort_by_key(|i| (i.1, i.2, i.0)); if scores.is_empty() { - // After applying constraints, no pageservers were left. We log some detail about - // the state of nodes to help understand why this happened. This is not logged as an error because - // it is legitimately possible for enough nodes to be Offline to prevent scheduling a shard. - tracing::info!("Scheduling failure, while excluding {hard_exclude:?}, node states:"); - for (node_id, node) in &self.nodes { + // After applying constraints, no pageservers were left. + if !matches!(context.mode, ScheduleMode::Speculative) { + // If this was not a speculative attempt, log details to understand why we couldn't + // schedule: this may help an engineer understand if some nodes are marked offline + // in a way that's preventing progress. tracing::info!( - "Node {node_id}: may_schedule={} shards={}", - node.may_schedule != MaySchedule::No, - node.shard_count + "Scheduling failure, while excluding {hard_exclude:?}, node states:" ); + for (node_id, node) in &self.nodes { + tracing::info!( + "Node {node_id}: may_schedule={} shards={}", + node.may_schedule != MaySchedule::No, + node.shard_count + ); + } } - return Err(ScheduleError::ImpossibleConstraint); } // Lowest score wins let node_id = scores.first().unwrap().0; - tracing::info!( + + if !matches!(context.mode, ScheduleMode::Speculative) { + tracing::info!( "scheduler selected node {node_id} (elegible nodes {:?}, hard exclude: {hard_exclude:?}, soft exclude: {context:?})", scores.iter().map(|i| i.0 .0).collect::>() ); + } // Note that we do not update shard count here to reflect the scheduling: that // is IntentState's job when the scheduled location is used. diff --git a/storage_controller/src/service.rs b/storage_controller/src/service.rs index 010558b797..0565f8e7b4 100644 --- a/storage_controller/src/service.rs +++ b/storage_controller/src/service.rs @@ -11,7 +11,7 @@ use crate::{ id_lock_map::IdLockMap, persistence::{AbortShardSplitStatus, TenantFilter}, reconciler::ReconcileError, - scheduler::ScheduleContext, + scheduler::{ScheduleContext, ScheduleMode}, }; use anyhow::Context; use control_plane::storage_controller::{ @@ -2744,7 +2744,7 @@ impl Service { let mut describe_shards = Vec::new(); for shard in shards { - if shard.tenant_shard_id.is_zero() { + if shard.tenant_shard_id.is_shard_zero() { shard_zero = Some(shard); } @@ -4084,7 +4084,7 @@ impl Service { let mut reconciles_spawned = 0; for (tenant_shard_id, shard) in tenants.iter_mut() { - if tenant_shard_id.is_zero() { + if tenant_shard_id.is_shard_zero() { schedule_context = ScheduleContext::default(); } @@ -4134,9 +4134,10 @@ impl Service { let mut work = Vec::new(); for (tenant_shard_id, shard) in tenants.iter() { - if tenant_shard_id.is_zero() { + if tenant_shard_id.is_shard_zero() { // Reset accumulators on the first shard in a tenant schedule_context = ScheduleContext::default(); + schedule_context.mode = ScheduleMode::Speculative; tenant_shards.clear(); } diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py index 0e4a58c099..c2c661088b 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -2449,10 +2449,12 @@ class NeonPageserver(PgProtocol): if cur_line_no < skip_until_line_no: cur_line_no += 1 continue - if contains_re.search(line): + elif contains_re.search(line): # found it! cur_line_no += 1 return (line, LogCursor(cur_line_no)) + else: + cur_line_no += 1 return None def tenant_attach( diff --git a/test_runner/regress/test_compatibility.py b/test_runner/regress/test_compatibility.py index 208263a22a..ddad98a5fa 100644 --- a/test_runner/regress/test_compatibility.py +++ b/test_runner/regress/test_compatibility.py @@ -192,9 +192,6 @@ def test_backward_compatibility( assert not breaking_changes_allowed, "Breaking changes are allowed by ALLOW_BACKWARD_COMPATIBILITY_BREAKAGE, but the test has passed without any breakage" -# Forward compatibility is broken due to https://github.com/neondatabase/neon/pull/6530 -# The test is disabled until the next release deployment -@pytest.mark.xfail @check_ondisk_data_compatibility_if_enabled @pytest.mark.xdist_group("compatibility") @pytest.mark.order(after="test_create_snapshot") diff --git a/test_runner/regress/test_pageserver_config.py b/test_runner/regress/test_pageserver_config.py new file mode 100644 index 0000000000..c04348b488 --- /dev/null +++ b/test_runner/regress/test_pageserver_config.py @@ -0,0 +1,35 @@ +import pytest +from fixtures.neon_fixtures import ( + NeonEnvBuilder, + last_flush_lsn_upload, +) + + +@pytest.mark.parametrize("kind", ["sync", "async"]) +def test_walredo_process_kind_config(neon_env_builder: NeonEnvBuilder, kind: str): + neon_env_builder.pageserver_config_override = f"walredo_process_kind = '{kind}'" + # ensure it starts + env = neon_env_builder.init_start() + # ensure the metric is set + ps_http = env.pageserver.http_client() + metrics = ps_http.get_metrics() + samples = metrics.query_all("pageserver_wal_redo_process_kind") + assert [(s.labels, s.value) for s in samples] == [({"kind": kind}, 1)] + # ensure default tenant's config kind matches + # => write some data to force-spawn walredo + ep = env.endpoints.create_start("main") + with ep.connect() as conn: + with conn.cursor() as cur: + cur.execute("create table foo(bar text)") + cur.execute("insert into foo select from generate_series(1, 100)") + last_flush_lsn_upload(env, ep, env.initial_tenant, env.initial_timeline) + ep.stop() + ep.start() + with ep.connect() as conn: + with conn.cursor() as cur: + cur.execute("select count(*) from foo") + [(count,)] = cur.fetchall() + assert count == 100 + + status = ps_http.tenant_status(env.initial_tenant) + assert status["walredo"]["process"]["kind"] == kind diff --git a/test_runner/regress/test_proxy_rate_limiter.py b/test_runner/regress/test_proxy_rate_limiter.py deleted file mode 100644 index f39f0cad07..0000000000 --- a/test_runner/regress/test_proxy_rate_limiter.py +++ /dev/null @@ -1,84 +0,0 @@ -import asyncio -import time -from pathlib import Path -from typing import Iterator - -import pytest -from fixtures.neon_fixtures import ( - PSQL, - NeonProxy, -) -from fixtures.port_distributor import PortDistributor -from pytest_httpserver import HTTPServer -from werkzeug.wrappers.response import Response - - -def waiting_handler(status_code: int) -> Response: - # wait more than timeout to make sure that both (two) connections are open. - # It would be better to use a barrier here, but I don't know how to do that together with pytest-httpserver. - time.sleep(2) - return Response(status=status_code) - - -@pytest.fixture(scope="function") -def proxy_with_rate_limit( - port_distributor: PortDistributor, - neon_binpath: Path, - httpserver_listen_address, - test_output_dir: Path, -) -> Iterator[NeonProxy]: - """Neon proxy that routes directly to vanilla postgres.""" - - proxy_port = port_distributor.get_port() - mgmt_port = port_distributor.get_port() - http_port = port_distributor.get_port() - external_http_port = port_distributor.get_port() - (host, port) = httpserver_listen_address - endpoint = f"http://{host}:{port}/billing/api/v1/usage_events" - - with NeonProxy( - neon_binpath=neon_binpath, - test_output_dir=test_output_dir, - proxy_port=proxy_port, - http_port=http_port, - mgmt_port=mgmt_port, - external_http_port=external_http_port, - auth_backend=NeonProxy.Console(endpoint, fixed_rate_limit=5), - ) as proxy: - proxy.start() - yield proxy - - -@pytest.mark.asyncio -async def test_proxy_rate_limit( - httpserver: HTTPServer, - proxy_with_rate_limit: NeonProxy, -): - uri = "/billing/api/v1/usage_events/proxy_get_role_secret" - # mock control plane service - httpserver.expect_ordered_request(uri, method="GET").respond_with_handler( - lambda _: Response(status=200) - ) - httpserver.expect_ordered_request(uri, method="GET").respond_with_handler( - lambda _: waiting_handler(429) - ) - httpserver.expect_ordered_request(uri, method="GET").respond_with_handler( - lambda _: waiting_handler(500) - ) - - psql = PSQL(host=proxy_with_rate_limit.host, port=proxy_with_rate_limit.proxy_port) - f = await psql.run("select 42;") - await proxy_with_rate_limit.find_auth_link(uri, f) - # Limit should be 2. - - # Run two queries in parallel. - f1, f2 = await asyncio.gather(psql.run("select 42;"), psql.run("select 42;")) - await proxy_with_rate_limit.find_auth_link(uri, f1) - await proxy_with_rate_limit.find_auth_link(uri, f2) - - # Now limit should be 0. - f = await psql.run("select 42;") - await proxy_with_rate_limit.find_auth_link(uri, f) - - # There last query shouldn't reach the http-server. - assert httpserver.assertions == [] diff --git a/test_runner/regress/test_tenant_delete.py b/test_runner/regress/test_tenant_delete.py index a164c7f60a..c115c0375b 100644 --- a/test_runner/regress/test_tenant_delete.py +++ b/test_runner/regress/test_tenant_delete.py @@ -469,7 +469,8 @@ def test_tenant_delete_concurrent( ): """ Validate that concurrent delete requests to the same tenant behave correctly: - exactly one should succeed. + exactly one should execute: the rest should give 202 responses but not start + another deletion. This is a reproducer for https://github.com/neondatabase/neon/issues/5936 """ @@ -484,14 +485,10 @@ def test_tenant_delete_concurrent( run_pg_bench_small(pg_bin, endpoint.connstr()) last_flush_lsn_upload(env, endpoint, tenant_id, timeline_id) - CONFLICT_MESSAGE = "Precondition failed: Invalid state Stopping. Expected Active or Broken" - env.pageserver.allowed_errors.extend( [ # lucky race with stopping from flushing a layer we fail to schedule any uploads ".*layer flush task.+: could not flush frozen layer: update_metadata_file", - # Errors logged from our 4xx requests - f".*{CONFLICT_MESSAGE}.*", ] ) @@ -507,7 +504,7 @@ def test_tenant_delete_concurrent( return ps_http.tenant_delete(tenant_id) def hit_remove_failpoint(): - env.pageserver.assert_log_contains(f"at failpoint {BEFORE_REMOVE_FAILPOINT}") + return env.pageserver.assert_log_contains(f"at failpoint {BEFORE_REMOVE_FAILPOINT}")[1] def hit_run_failpoint(): env.pageserver.assert_log_contains(f"at failpoint {BEFORE_RUN_FAILPOINT}") @@ -518,11 +515,14 @@ def test_tenant_delete_concurrent( # Wait until the first request completes its work and is blocked on removing # the TenantSlot from tenant manager. - wait_until(100, 0.1, hit_remove_failpoint) + log_cursor = wait_until(100, 0.1, hit_remove_failpoint) + assert log_cursor is not None - # Start another request: this should fail when it sees a tenant in Stopping state - with pytest.raises(PageserverApiException, match=CONFLICT_MESSAGE): - ps_http.tenant_delete(tenant_id) + # Start another request: this should succeed without actually entering the deletion code + ps_http.tenant_delete(tenant_id) + assert not env.pageserver.log_contains( + f"at failpoint {BEFORE_RUN_FAILPOINT}", offset=log_cursor + ) # Start another background request, which will pause after acquiring a TenantSlotGuard # but before completing. @@ -539,8 +539,10 @@ def test_tenant_delete_concurrent( # Permit the duplicate background request to run to completion and fail. ps_http.configure_failpoints((BEFORE_RUN_FAILPOINT, "off")) - with pytest.raises(PageserverApiException, match=CONFLICT_MESSAGE): - background_4xx_req.result(timeout=10) + background_4xx_req.result(timeout=10) + assert not env.pageserver.log_contains( + f"at failpoint {BEFORE_RUN_FAILPOINT}", offset=log_cursor + ) # Physical deletion should have happened assert_prefix_empty(