From 6d95a3fe2dcef3c8f83f5b473fee6cd3f70dd209 Mon Sep 17 00:00:00 2001 From: Vlad Lazar Date: Fri, 30 May 2025 13:30:11 +0100 Subject: [PATCH 01/43] pageserver: various import flow fixups (#12047) ## Problem There's a bunch of TODOs in the import code. ## Summary of changes 1. Bound max import byte range to 128MiB. This might still be too high, given the default job concurrency, but it needs to be balanced with going back and forth to S3. 2. Prevent unsigned overflow when determining key range splits for concurrent jobs 3. Use sharded ranges to estimate task size when splitting jobs 4. Bubble up errors that we might hit due to invalid data in the bucket back to the storage controller. 5. Tweak the import bucket S3 client configuration. --- .../src/tenant/timeline/import_pgdata/flow.rs | 68 +++++++++++-------- .../import_pgdata/importbucket_client.rs | 19 +++--- 2 files changed, 50 insertions(+), 37 deletions(-) diff --git a/pageserver/src/tenant/timeline/import_pgdata/flow.rs b/pageserver/src/tenant/timeline/import_pgdata/flow.rs index 9743aa3f26..bf3c7eeda6 100644 --- a/pageserver/src/tenant/timeline/import_pgdata/flow.rs +++ b/pageserver/src/tenant/timeline/import_pgdata/flow.rs @@ -11,19 +11,7 @@ //! - => S3 as the source for the PGDATA instead of local filesystem //! //! TODOs before productionization: -//! - ChunkProcessingJob size / ImportJob::total_size does not account for sharding. -//! => produced image layers likely too small. //! - ChunkProcessingJob should cut up an ImportJob to hit exactly target image layer size. -//! - asserts / unwraps need to be replaced with errors -//! - don't trust remote objects will be small (=prevent OOMs in those cases) -//! - limit all in-memory buffers in size, or download to disk and read from there -//! - limit task concurrency -//! - generally play nice with other tenants in the system -//! - importbucket is different bucket than main pageserver storage, so, should be fine wrt S3 rate limits -//! - but concerns like network bandwidth, local disk write bandwidth, local disk capacity, etc -//! - integrate with layer eviction system -//! - audit for Tenant::cancel nor Timeline::cancel responsivity -//! - audit for Tenant/Timeline gate holding (we spawn tokio tasks during this flow!) //! //! An incomplete set of TODOs from the Hackathon: //! - version-specific CheckPointData (=> pgv abstraction, already exists for regular walingest) @@ -44,7 +32,7 @@ use pageserver_api::key::{ rel_dir_to_key, rel_size_to_key, relmap_file_key, slru_block_to_key, slru_dir_to_key, slru_segment_size_to_key, }; -use pageserver_api::keyspace::{contiguous_range_len, is_contiguous_range, singleton_range}; +use pageserver_api::keyspace::{ShardedRange, singleton_range}; use pageserver_api::models::{ShardImportProgress, ShardImportProgressV1, ShardImportStatus}; use pageserver_api::reltag::{RelTag, SlruKind}; use pageserver_api::shard::ShardIdentity; @@ -167,6 +155,7 @@ impl Planner { /// This function is and must remain pure: given the same input, it will generate the same import plan. async fn plan(mut self, import_config: &TimelineImportConfig) -> anyhow::Result { let pgdata_lsn = Lsn(self.control_file.control_file_data().checkPoint).align(); + anyhow::ensure!(pgdata_lsn.is_valid()); let datadir = PgDataDir::new(&self.storage).await?; @@ -249,14 +238,22 @@ impl Planner { }); // Assigns parts of key space to later parallel jobs + // Note: The image layers produced here may have gaps, meaning, + // there is not an image for each key in the layer's key range. + // The read path stops traversal at the first image layer, regardless + // of whether a base image has been found for a key or not. + // (Concept of sparse image layers doesn't exist.) + // This behavior is exactly right for the base image layers we're producing here. + // But, since no other place in the code currently produces image layers with gaps, + // it seems noteworthy. let mut last_end_key = Key::MIN; let mut current_chunk = Vec::new(); let mut current_chunk_size: usize = 0; let mut jobs = Vec::new(); for task in std::mem::take(&mut self.tasks).into_iter() { - if current_chunk_size + task.total_size() - > import_config.import_job_soft_size_limit.into() - { + let task_size = task.total_size(&self.shard); + let projected_chunk_size = current_chunk_size.saturating_add(task_size); + if projected_chunk_size > import_config.import_job_soft_size_limit.into() { let key_range = last_end_key..task.key_range().start; jobs.push(ChunkProcessingJob::new( key_range.clone(), @@ -266,7 +263,7 @@ impl Planner { last_end_key = key_range.end; current_chunk_size = 0; } - current_chunk_size += task.total_size(); + current_chunk_size = current_chunk_size.saturating_add(task_size); current_chunk.push(task); } jobs.push(ChunkProcessingJob::new( @@ -604,18 +601,18 @@ impl PgDataDirDb { }; let path = datadir_path.join(rel_tag.to_segfile_name(segno)); - assert!(filesize % BLCKSZ as usize == 0); // TODO: this should result in an error + anyhow::ensure!(filesize % BLCKSZ as usize == 0); let nblocks = filesize / BLCKSZ as usize; - PgDataDirDbFile { + Ok(PgDataDirDbFile { path, filesize, rel_tag, segno, nblocks: Some(nblocks), // first non-cummulative sizes - } + }) }) - .collect(); + .collect::>()?; // Set cummulative sizes. Do all of that math here, so that later we could easier // parallelize over segments and know with which segments we need to write relsize @@ -650,12 +647,22 @@ impl PgDataDirDb { trait ImportTask { fn key_range(&self) -> Range; - fn total_size(&self) -> usize { - // TODO: revisit this - if is_contiguous_range(&self.key_range()) { - contiguous_range_len(&self.key_range()) as usize * 8192 + fn total_size(&self, shard_identity: &ShardIdentity) -> usize { + let range = ShardedRange::new(self.key_range(), shard_identity); + let page_count = range.page_count(); + if page_count == u32::MAX { + tracing::warn!( + "Import task has non contiguous key range: {}..{}", + self.key_range().start, + self.key_range().end + ); + + // Tasks should operate on contiguous ranges. It is unexpected for + // ranges to violate this assumption. Calling code handles this by mapping + // any task on a non contiguous range to its own image layer. + usize::MAX } else { - u32::MAX as usize + page_count as usize * 8192 } } @@ -753,6 +760,8 @@ impl ImportTask for ImportRelBlocksTask { layer_writer: &mut ImageLayerWriter, ctx: &RequestContext, ) -> anyhow::Result { + const MAX_BYTE_RANGE_SIZE: usize = 128 * 1024 * 1024; + debug!("Importing relation file"); let (rel_tag, start_blk) = self.key_range.start.to_rel_block()?; @@ -777,7 +786,7 @@ impl ImportTask for ImportRelBlocksTask { assert_eq!(key.len(), 1); assert!(!acc.is_empty()); assert!(acc_end > acc_start); - if acc_end == start /* TODO additional max range check here, to limit memory consumption per task to X */ { + if acc_end == start && end - acc_start <= MAX_BYTE_RANGE_SIZE { acc.push(key.pop().unwrap()); Ok((acc, acc_start, end)) } else { @@ -792,8 +801,8 @@ impl ImportTask for ImportRelBlocksTask { .get_range(&self.path, range_start.into_u64(), range_end.into_u64()) .await?; let mut buf = Bytes::from(range_buf); - // TODO: batched writes for key in keys { + // The writer buffers writes internally let image = buf.split_to(8192); layer_writer.put_image(key, image, ctx).await?; nimages += 1; @@ -846,6 +855,9 @@ impl ImportTask for ImportSlruBlocksTask { debug!("Importing SLRU segment file {}", self.path); let buf = self.storage.get(&self.path).await?; + // TODO(vlad): Does timestamp to LSN work for imported timelines? + // Probably not since we don't append the `xact_time` to it as in + // [`WalIngest::ingest_xact_record`]. let (kind, segno, start_blk) = self.key_range.start.to_slru_block()?; let (_kind, _segno, end_blk) = self.key_range.end.to_slru_block()?; let mut blknum = start_blk; diff --git a/pageserver/src/tenant/timeline/import_pgdata/importbucket_client.rs b/pageserver/src/tenant/timeline/import_pgdata/importbucket_client.rs index 34313748b7..bf2d9875c1 100644 --- a/pageserver/src/tenant/timeline/import_pgdata/importbucket_client.rs +++ b/pageserver/src/tenant/timeline/import_pgdata/importbucket_client.rs @@ -6,7 +6,7 @@ use bytes::Bytes; use postgres_ffi::ControlFileData; use remote_storage::{ Download, DownloadError, DownloadKind, DownloadOpts, GenericRemoteStorage, Listing, - ListingObject, RemotePath, + ListingObject, RemotePath, RemoteStorageConfig, }; use serde::de::DeserializeOwned; use tokio_util::sync::CancellationToken; @@ -22,11 +22,9 @@ pub async fn new( location: &index_part_format::Location, cancel: CancellationToken, ) -> Result { - // FIXME: we probably want some timeout, and we might be able to assume the max file - // size on S3 is 1GiB (postgres segment size). But the problem is that the individual - // downloaders don't know enough about concurrent downloads to make a guess on the - // expected bandwidth and resulting best timeout. - let timeout = std::time::Duration::from_secs(24 * 60 * 60); + // Downloads should be reasonably sized. We do ranged reads for relblock raw data + // and full reads for SLRU segments which are bounded by Postgres. + let timeout = RemoteStorageConfig::DEFAULT_TIMEOUT; let location_storage = match location { #[cfg(feature = "testing")] index_part_format::Location::LocalFs { path } => { @@ -50,9 +48,12 @@ pub async fn new( .import_pgdata_aws_endpoint_url .clone() .map(|url| url.to_string()), // by specifying None here, remote_storage/aws-sdk-rust will infer from env - concurrency_limit: 100.try_into().unwrap(), // TODO: think about this - max_keys_per_list_response: Some(1000), // TODO: think about this - upload_storage_class: None, // irrelevant + // This matches the default import job concurrency. This is managed + // separately from the usual S3 client, but the concern here is bandwidth + // usage. + concurrency_limit: 128.try_into().unwrap(), + max_keys_per_list_response: Some(1000), + upload_storage_class: None, // irrelevant }, timeout, ) From 35372a8f12ae143da475cc0b5de1529d5c05804e Mon Sep 17 00:00:00 2001 From: Christian Schwarz Date: Fri, 30 May 2025 15:22:53 +0200 Subject: [PATCH 02/43] adjust VirtualFile operation latency histogram buckets (#12075) The expected operating range for the production NVMe drives is in the range of 50 to 250us. The bucket boundaries before this PR were not well suited to reason about the utilization / queuing / latency variability of those devices. # Performance There was some concern about perf impact of having so many buckets, considering the impl does a linear search on each observe(). I added a benchmark and measured on relevant machines. In any way, the PR is 40 buckets, so, won't make a meaningful difference on production machines (im4gn.2xlarge), going from 30ns -> 35ns. --- libs/metrics/src/lib.rs | 1 + pageserver/benches/bench_metrics.rs | 72 ++++++++++++++++++++++++++++- pageserver/src/metrics.rs | 43 +++++++++++++++-- 3 files changed, 110 insertions(+), 6 deletions(-) diff --git a/libs/metrics/src/lib.rs b/libs/metrics/src/lib.rs index 4df8d7bc51..5d028ee041 100644 --- a/libs/metrics/src/lib.rs +++ b/libs/metrics/src/lib.rs @@ -27,6 +27,7 @@ pub use prometheus::{ pub mod launch_timestamp; mod wrappers; +pub use prometheus; pub use wrappers::{CountedReader, CountedWriter}; mod hll; pub use hll::{HyperLogLog, HyperLogLogState, HyperLogLogVec}; diff --git a/pageserver/benches/bench_metrics.rs b/pageserver/benches/bench_metrics.rs index 38025124e1..e0428f6372 100644 --- a/pageserver/benches/bench_metrics.rs +++ b/pageserver/benches/bench_metrics.rs @@ -264,10 +264,56 @@ mod propagation_of_cached_label_value { } } +criterion_group!(histograms, histograms::bench_bucket_scalability); +mod histograms { + use std::time::Instant; + + use criterion::{BenchmarkId, Criterion}; + use metrics::core::Collector; + + pub fn bench_bucket_scalability(c: &mut Criterion) { + let mut g = c.benchmark_group("bucket_scalability"); + + for n in [1, 4, 8, 16, 32, 64, 128, 256] { + g.bench_with_input(BenchmarkId::new("nbuckets", n), &n, |b, n| { + b.iter_custom(|iters| { + let buckets: Vec = (0..*n).map(|i| i as f64 * 100.0).collect(); + let histo = metrics::Histogram::with_opts( + metrics::prometheus::HistogramOpts::new("name", "help") + .buckets(buckets.clone()), + ) + .unwrap(); + let start = Instant::now(); + for i in 0..usize::try_from(iters).unwrap() { + histo.observe(buckets[i % buckets.len()]); + } + let elapsed = start.elapsed(); + // self-test + let mfs = histo.collect(); + assert_eq!(mfs.len(), 1); + let metrics = mfs[0].get_metric(); + assert_eq!(metrics.len(), 1); + let histo = metrics[0].get_histogram(); + let buckets = histo.get_bucket(); + assert!( + buckets + .iter() + .enumerate() + .all(|(i, b)| b.get_cumulative_count() + >= i as u64 * (iters / buckets.len() as u64)) + ); + elapsed + }) + }); + } + } +} + criterion_main!( label_values, single_metric_multicore_scalability, - propagation_of_cached_label_value + propagation_of_cached_label_value, + histograms, ); /* @@ -290,6 +336,14 @@ propagation_of_cached_label_value__naive/nthreads/8 time: [211.50 ns 214.44 ns propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/1 time: [14.135 ns 14.147 ns 14.160 ns] propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/4 time: [14.243 ns 14.255 ns 14.268 ns] propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/8 time: [14.470 ns 14.682 ns 14.895 ns] +bucket_scalability/nbuckets/1 time: [30.352 ns 30.353 ns 30.354 ns] +bucket_scalability/nbuckets/4 time: [30.464 ns 30.465 ns 30.467 ns] +bucket_scalability/nbuckets/8 time: [30.569 ns 30.575 ns 30.584 ns] +bucket_scalability/nbuckets/16 time: [30.961 ns 30.965 ns 30.969 ns] +bucket_scalability/nbuckets/32 time: [35.691 ns 35.707 ns 35.722 ns] +bucket_scalability/nbuckets/64 time: [47.829 ns 47.898 ns 47.974 ns] +bucket_scalability/nbuckets/128 time: [73.479 ns 73.512 ns 73.545 ns] +bucket_scalability/nbuckets/256 time: [127.92 ns 127.94 ns 127.96 ns] Results on an i3en.3xlarge instance @@ -344,6 +398,14 @@ propagation_of_cached_label_value__naive/nthreads/8 time: [434.87 ns 456.4 propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/1 time: [3.3767 ns 3.3974 ns 3.4220 ns] propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/4 time: [3.6105 ns 4.2355 ns 5.1463 ns] propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/8 time: [4.0889 ns 4.9714 ns 6.0779 ns] +bucket_scalability/nbuckets/1 time: [4.8455 ns 4.8542 ns 4.8646 ns] +bucket_scalability/nbuckets/4 time: [4.5663 ns 4.5722 ns 4.5787 ns] +bucket_scalability/nbuckets/8 time: [4.5531 ns 4.5670 ns 4.5842 ns] +bucket_scalability/nbuckets/16 time: [4.6392 ns 4.6524 ns 4.6685 ns] +bucket_scalability/nbuckets/32 time: [6.0302 ns 6.0439 ns 6.0589 ns] +bucket_scalability/nbuckets/64 time: [10.608 ns 10.644 ns 10.691 ns] +bucket_scalability/nbuckets/128 time: [22.178 ns 22.316 ns 22.483 ns] +bucket_scalability/nbuckets/256 time: [42.190 ns 42.328 ns 42.492 ns] Results on a Hetzner AX102 AMD Ryzen 9 7950X3D 16-Core Processor @@ -362,5 +424,13 @@ propagation_of_cached_label_value__naive/nthreads/8 time: [164.24 ns 170.1 propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/1 time: [2.2915 ns 2.2960 ns 2.3012 ns] propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/4 time: [2.5726 ns 2.6158 ns 2.6624 ns] propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/8 time: [2.7068 ns 2.8243 ns 2.9824 ns] +bucket_scalability/nbuckets/1 time: [6.3998 ns 6.4288 ns 6.4684 ns] +bucket_scalability/nbuckets/4 time: [6.3603 ns 6.3620 ns 6.3637 ns] +bucket_scalability/nbuckets/8 time: [6.1646 ns 6.1654 ns 6.1667 ns] +bucket_scalability/nbuckets/16 time: [6.1341 ns 6.1391 ns 6.1454 ns] +bucket_scalability/nbuckets/32 time: [8.2206 ns 8.2254 ns 8.2301 ns] +bucket_scalability/nbuckets/64 time: [13.988 ns 13.994 ns 14.000 ns] +bucket_scalability/nbuckets/128 time: [28.180 ns 28.216 ns 28.251 ns] +bucket_scalability/nbuckets/256 time: [54.914 ns 54.931 ns 54.951 ns] */ diff --git a/pageserver/src/metrics.rs b/pageserver/src/metrics.rs index 0ff31dcb8a..a9b2f1b7e0 100644 --- a/pageserver/src/metrics.rs +++ b/pageserver/src/metrics.rs @@ -1312,11 +1312,44 @@ impl EvictionsWithLowResidenceDuration { // // Roughly logarithmic scale. const STORAGE_IO_TIME_BUCKETS: &[f64] = &[ - 0.000030, // 30 usec - 0.001000, // 1000 usec - 0.030, // 30 ms - 1.000, // 1000 ms - 30.000, // 30000 ms + 0.00005, // 50us + 0.00006, // 60us + 0.00007, // 70us + 0.00008, // 80us + 0.00009, // 90us + 0.0001, // 100us + 0.000110, // 110us + 0.000120, // 120us + 0.000130, // 130us + 0.000140, // 140us + 0.000150, // 150us + 0.000160, // 160us + 0.000170, // 170us + 0.000180, // 180us + 0.000190, // 190us + 0.000200, // 200us + 0.000210, // 210us + 0.000220, // 220us + 0.000230, // 230us + 0.000240, // 240us + 0.000250, // 250us + 0.000300, // 300us + 0.000350, // 350us + 0.000400, // 400us + 0.000450, // 450us + 0.000500, // 500us + 0.000600, // 600us + 0.000700, // 700us + 0.000800, // 800us + 0.000900, // 900us + 0.001000, // 1ms + 0.002000, // 2ms + 0.003000, // 3ms + 0.004000, // 4ms + 0.005000, // 5ms + 0.01000, // 10ms + 0.02000, // 20ms + 0.05000, // 50ms ]; /// VirtualFile fs operation variants. From 8d26978ed9ede0273ece491fca21104c0ba63835 Mon Sep 17 00:00:00 2001 From: Alexander Lakhin Date: Fri, 30 May 2025 18:20:46 +0300 Subject: [PATCH 03/43] Allow known pageserver errors in test_location_conf_churn (#12082) ## Problem While a pageserver in the unreadable state could not be accessed by postgres thanks to https://github.com/neondatabase/neon/pull/12059, it may still receive WAL records and bump into the "layer file download failed: No file found" error when trying to ingest them. Closes: https://github.com/neondatabase/neon/issues/11348 ## Summary of changes Allow errors from wal_connection_manager, which are considered expected. See https://github.com/neondatabase/neon/issues/11348. --- test_runner/regress/test_pageserver_secondary.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test_runner/regress/test_pageserver_secondary.py b/test_runner/regress/test_pageserver_secondary.py index e5908de363..8d18311f3d 100644 --- a/test_runner/regress/test_pageserver_secondary.py +++ b/test_runner/regress/test_pageserver_secondary.py @@ -124,6 +124,9 @@ def test_location_conf_churn(neon_env_builder: NeonEnvBuilder, make_httpserver, ".*downloading failed, possibly for shutdown", # {tenant_id=... timeline_id=...}:handle_pagerequests:handle_get_page_at_lsn_request{rel=1664/0/1260 blkno=0 req_lsn=0/149F0D8}: error reading relation or page version: Not found: will not become active. Current state: Stopping\n' ".*page_service.*will not become active.*", + # the following errors are possible when pageserver tries to ingest wal records despite being in unreadable state + ".*wal_connection_manager.*layer file download failed: No file found.*", + ".*wal_connection_manager.*could not ingest record.*", ] ) From 62cd3b8d3d60da5c72224952da20e2ad3d494a11 Mon Sep 17 00:00:00 2001 From: Shockingly Good Date: Fri, 30 May 2025 17:26:22 +0200 Subject: [PATCH 04/43] fix(compute) Remove the hardcoded default value for PGXN HTTP URL. (#12030) Removes the hardcoded value for the Postgres Extensions HTTP gateway URL as it is always provided by the calling code. --- compute_tools/src/bin/compute_ctl.rs | 31 +--------------------------- 1 file changed, 1 insertion(+), 30 deletions(-) diff --git a/compute_tools/src/bin/compute_ctl.rs b/compute_tools/src/bin/compute_ctl.rs index 02339f752c..f9d9c03422 100644 --- a/compute_tools/src/bin/compute_ctl.rs +++ b/compute_tools/src/bin/compute_ctl.rs @@ -57,21 +57,6 @@ use tracing::{error, info}; use url::Url; use utils::failpoint_support; -// Compatibility hack: if the control plane specified any remote-ext-config -// use the default value for extension storage proxy gateway. -// Remove this once the control plane is updated to pass the gateway URL -fn parse_remote_ext_base_url(arg: &str) -> Result { - const FALLBACK_PG_EXT_GATEWAY_BASE_URL: &str = - "http://pg-ext-s3-gateway.pg-ext-s3-gateway.svc.cluster.local"; - - Ok(if arg.starts_with("http") { - arg - } else { - FALLBACK_PG_EXT_GATEWAY_BASE_URL - } - .to_owned()) -} - #[derive(Parser)] #[command(rename_all = "kebab-case")] struct Cli { @@ -80,7 +65,7 @@ struct Cli { /// The base URL for the remote extension storage proxy gateway. /// Should be in the form of `http(s)://[:]`. - #[arg(short = 'r', long, value_parser = parse_remote_ext_base_url, alias = "remote-ext-config")] + #[arg(short = 'r', long, alias = "remote-ext-config")] pub remote_ext_base_url: Option, /// The port to bind the external listening HTTP server to. Clients running @@ -276,18 +261,4 @@ mod test { fn verify_cli() { Cli::command().debug_assert() } - - #[test] - fn parse_pg_ext_gateway_base_url() { - let arg = "http://pg-ext-s3-gateway2"; - let result = super::parse_remote_ext_base_url(arg).unwrap(); - assert_eq!(result, arg); - - let arg = "pg-ext-s3-gateway"; - let result = super::parse_remote_ext_base_url(arg).unwrap(); - assert_eq!( - result, - "http://pg-ext-s3-gateway.pg-ext-s3-gateway.svc.cluster.local" - ); - } } From f6c0f6c4ecbc2bc5beba7d57d8b44ed71d8300c7 Mon Sep 17 00:00:00 2001 From: "Alex Chi Z." <4198311+skyzh@users.noreply.github.com> Date: Sat, 31 May 2025 01:00:41 +0800 Subject: [PATCH 05/43] fix(ci): install build tools with --locked (#12083) ## Problem Release pipeline failing due to some tools cannot be installed. ## Summary of changes Install with `--locked`. Signed-off-by: Alex Chi Z --- build-tools.Dockerfile | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/build-tools.Dockerfile b/build-tools.Dockerfile index 9d4c93e1cd..f97f04968e 100644 --- a/build-tools.Dockerfile +++ b/build-tools.Dockerfile @@ -310,13 +310,13 @@ RUN curl -sSO https://static.rust-lang.org/rustup/dist/$(uname -m)-unknown-linux . "$HOME/.cargo/env" && \ cargo --version && rustup --version && \ rustup component add llvm-tools rustfmt clippy && \ - cargo install rustfilt --version ${RUSTFILT_VERSION} && \ - cargo install cargo-hakari --version ${CARGO_HAKARI_VERSION} && \ - cargo install cargo-deny --locked --version ${CARGO_DENY_VERSION} && \ - cargo install cargo-hack --version ${CARGO_HACK_VERSION} && \ - cargo install cargo-nextest --version ${CARGO_NEXTEST_VERSION} && \ - cargo install cargo-chef --locked --version ${CARGO_CHEF_VERSION} && \ - cargo install diesel_cli --version ${CARGO_DIESEL_CLI_VERSION} \ + cargo install rustfilt --version ${RUSTFILT_VERSION} --locked && \ + cargo install cargo-hakari --version ${CARGO_HAKARI_VERSION} --locked && \ + cargo install cargo-deny --version ${CARGO_DENY_VERSION} --locked && \ + cargo install cargo-hack --version ${CARGO_HACK_VERSION} --locked && \ + cargo install cargo-nextest --version ${CARGO_NEXTEST_VERSION} --locked && \ + cargo install cargo-chef --version ${CARGO_CHEF_VERSION} --locked && \ + cargo install diesel_cli --version ${CARGO_DIESEL_CLI_VERSION} --locked \ --features postgres-bundled --no-default-features && \ rm -rf /home/nonroot/.cargo/registry && \ rm -rf /home/nonroot/.cargo/git From f05df409bd63424f0c5fb0efc14547fcad854c8b Mon Sep 17 00:00:00 2001 From: Shockingly Good Date: Fri, 30 May 2025 19:45:24 +0200 Subject: [PATCH 06/43] impr(compute): Remove the deprecated CLI arg alias for remote-ext-config. (#12087) Also moves it from `String` to `Url`. --- compute_tools/src/bin/compute_ctl.rs | 5 ++--- compute_tools/src/compute.rs | 3 ++- compute_tools/src/extension_server.rs | 5 +++-- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/compute_tools/src/bin/compute_ctl.rs b/compute_tools/src/bin/compute_ctl.rs index f9d9c03422..db6835da61 100644 --- a/compute_tools/src/bin/compute_ctl.rs +++ b/compute_tools/src/bin/compute_ctl.rs @@ -64,9 +64,8 @@ struct Cli { pub pgbin: String, /// The base URL for the remote extension storage proxy gateway. - /// Should be in the form of `http(s)://[:]`. - #[arg(short = 'r', long, alias = "remote-ext-config")] - pub remote_ext_base_url: Option, + #[arg(short = 'r', long)] + pub remote_ext_base_url: Option, /// The port to bind the external listening HTTP server to. Clients running /// outside the compute will talk to the compute through this port. Keep diff --git a/compute_tools/src/compute.rs b/compute_tools/src/compute.rs index ff49c737f0..d678b7d670 100644 --- a/compute_tools/src/compute.rs +++ b/compute_tools/src/compute.rs @@ -31,6 +31,7 @@ use std::time::{Duration, Instant}; use std::{env, fs}; use tokio::spawn; use tracing::{Instrument, debug, error, info, instrument, warn}; +use url::Url; use utils::id::{TenantId, TimelineId}; use utils::lsn::Lsn; use utils::measured_stream::MeasuredReader; @@ -96,7 +97,7 @@ pub struct ComputeNodeParams { pub internal_http_port: u16, /// the address of extension storage proxy gateway - pub remote_ext_base_url: Option, + pub remote_ext_base_url: Option, /// Interval for installed extensions collection pub installed_extensions_collection_interval: u64, diff --git a/compute_tools/src/extension_server.rs b/compute_tools/src/extension_server.rs index 3439383699..1857afa08c 100644 --- a/compute_tools/src/extension_server.rs +++ b/compute_tools/src/extension_server.rs @@ -83,6 +83,7 @@ use reqwest::StatusCode; use tar::Archive; use tracing::info; use tracing::log::warn; +use url::Url; use zstd::stream::read::Decoder; use crate::metrics::{REMOTE_EXT_REQUESTS_TOTAL, UNKNOWN_HTTP_STATUS}; @@ -158,14 +159,14 @@ fn parse_pg_version(human_version: &str) -> PostgresMajorVersion { pub async fn download_extension( ext_name: &str, ext_path: &RemotePath, - remote_ext_base_url: &str, + remote_ext_base_url: &Url, pgbin: &str, ) -> Result { info!("Download extension {:?} from {:?}", ext_name, ext_path); // TODO add retry logic let download_buffer = - match download_extension_tar(remote_ext_base_url, &ext_path.to_string()).await { + match download_extension_tar(remote_ext_base_url.as_str(), &ext_path.to_string()).await { Ok(buffer) => buffer, Err(error_message) => { return Err(anyhow::anyhow!( From 87179e26b3c18d9cd09b9eecbfed1db742b391ab Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Sun, 1 Jun 2025 19:41:45 +0100 Subject: [PATCH 07/43] completely rewrite pq_proto (#12085) libs/pqproto is designed for safekeeper/pageserver with maximum throughput. proxy only needs it for handshakes/authentication where throughput is not a concern but memory efficiency is. For this reason, we switch to using read_exact and only allocating as much memory as we need to. All reads return a `&'a [u8]` instead of a `Bytes` because accidental sharing of bytes can cause fragmentation. Returning the reference enforces all callers only hold onto the bytes they absolutely need. For example, before this change, `pqproto` was allocating 8KiB for the initial read `BytesMut`, and proxy was holding the `Bytes` in the `StartupMessageParams` for the entire connection through to passthrough. --- proxy/src/auth/backend/classic.rs | 26 +- proxy/src/auth/backend/console_redirect.rs | 16 +- proxy/src/auth/backend/hacks.rs | 30 +- proxy/src/auth/backend/mod.rs | 9 +- proxy/src/auth/credentials.rs | 2 +- proxy/src/auth/flow.rs | 118 ++-- proxy/src/binary/pg_sni_router.rs | 87 +-- proxy/src/binary/proxy.rs | 5 +- proxy/src/cancellation.rs | 2 +- proxy/src/compute.rs | 2 +- proxy/src/console_redirect_proxy.rs | 20 +- proxy/src/context/mod.rs | 2 +- proxy/src/context/parquet.rs | 2 +- proxy/src/lib.rs | 1 + proxy/src/pqproto.rs | 693 +++++++++++++++++++++ proxy/src/proxy/connect_compute.rs | 2 +- proxy/src/proxy/handshake.rs | 90 ++- proxy/src/proxy/mod.rs | 68 +- proxy/src/proxy/retry.rs | 3 +- proxy/src/proxy/tests/mitm.rs | 7 +- proxy/src/proxy/tests/mod.rs | 16 +- proxy/src/redis/cancellation_publisher.rs | 3 +- proxy/src/redis/keys.rs | 21 +- proxy/src/redis/notifications.rs | 10 - proxy/src/sasl/messages.rs | 22 - proxy/src/sasl/mod.rs | 6 +- proxy/src/sasl/stream.rs | 136 ++-- proxy/src/serverless/sql_over_http.rs | 4 +- proxy/src/stream.rs | 319 +++++----- 29 files changed, 1122 insertions(+), 600 deletions(-) create mode 100644 proxy/src/pqproto.rs diff --git a/proxy/src/auth/backend/classic.rs b/proxy/src/auth/backend/classic.rs index 5e494dfdd6..dcc500f2c8 100644 --- a/proxy/src/auth/backend/classic.rs +++ b/proxy/src/auth/backend/classic.rs @@ -17,35 +17,27 @@ pub(super) async fn authenticate( config: &'static AuthenticationConfig, secret: AuthSecret, ) -> auth::Result { - let flow = AuthFlow::new(client); let scram_keys = match secret { #[cfg(any(test, feature = "testing"))] AuthSecret::Md5(_) => { debug!("auth endpoint chooses MD5"); - return Err(auth::AuthError::bad_auth_method("MD5")); + return Err(auth::AuthError::MalformedPassword("MD5 not supported")); } AuthSecret::Scram(secret) => { debug!("auth endpoint chooses SCRAM"); let scram = auth::Scram(&secret, ctx); - let auth_outcome = tokio::time::timeout( - config.scram_protocol_timeout, - async { - - flow.begin(scram).await.map_err(|error| { - warn!(?error, "error sending scram acknowledgement"); - error - })?.authenticate().await.map_err(|error| { + let auth_outcome = tokio::time::timeout(config.scram_protocol_timeout, async { + AuthFlow::new(client, scram) + .authenticate() + .await + .inspect_err(|error| { warn!(?error, "error processing scram messages"); - error }) - } - ) + }) .await - .map_err(|e| { - warn!("error processing scram messages error = authentication timed out, execution time exceeded {} seconds", config.scram_protocol_timeout.as_secs()); - auth::AuthError::user_timeout(e) - })??; + .inspect_err(|_| warn!("error processing scram messages error = authentication timed out, execution time exceeded {} seconds", config.scram_protocol_timeout.as_secs())) + .map_err(auth::AuthError::user_timeout)??; let client_key = match auth_outcome { sasl::Outcome::Success(key) => key, diff --git a/proxy/src/auth/backend/console_redirect.rs b/proxy/src/auth/backend/console_redirect.rs index dd48384c03..a50c30257f 100644 --- a/proxy/src/auth/backend/console_redirect.rs +++ b/proxy/src/auth/backend/console_redirect.rs @@ -2,7 +2,6 @@ use std::fmt; use async_trait::async_trait; use postgres_client::config::SslMode; -use pq_proto::BeMessage as Be; use thiserror::Error; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::{info, info_span}; @@ -16,6 +15,7 @@ use crate::context::RequestContext; use crate::control_plane::client::cplane_proxy_v1; use crate::control_plane::{self, CachedNodeInfo, NodeInfo}; use crate::error::{ReportableError, UserFacingError}; +use crate::pqproto::BeMessage; use crate::proxy::NeonOptions; use crate::proxy::connect_compute::ComputeConnectBackend; use crate::stream::PqStream; @@ -154,11 +154,13 @@ async fn authenticate( // Give user a URL to spawn a new database. info!(parent: &span, "sending the auth URL to the user"); - client - .write_message_noflush(&Be::AuthenticationOk)? - .write_message_noflush(&Be::CLIENT_ENCODING)? - .write_message(&Be::NoticeResponse(&greeting)) - .await?; + client.write_message(BeMessage::AuthenticationOk); + client.write_message(BeMessage::ParameterStatus { + name: b"client_encoding", + value: b"UTF8", + }); + client.write_message(BeMessage::NoticeResponse(&greeting)); + client.flush().await?; // Wait for console response via control plane (see `mgmt`). info!(parent: &span, "waiting for console's reply..."); @@ -188,7 +190,7 @@ async fn authenticate( } } - client.write_message_noflush(&Be::NoticeResponse("Connecting to database."))?; + client.write_message(BeMessage::NoticeResponse("Connecting to database.")); // This config should be self-contained, because we won't // take username or dbname from client's startup message. diff --git a/proxy/src/auth/backend/hacks.rs b/proxy/src/auth/backend/hacks.rs index 3316543022..1e5c076fb9 100644 --- a/proxy/src/auth/backend/hacks.rs +++ b/proxy/src/auth/backend/hacks.rs @@ -24,23 +24,25 @@ pub(crate) async fn authenticate_cleartext( debug!("cleartext auth flow override is enabled, proceeding"); ctx.set_auth_method(crate::context::AuthMethod::Cleartext); - // pause the timer while we communicate with the client - let paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client); - let ep = EndpointIdInt::from(&info.endpoint); - let auth_flow = AuthFlow::new(client) - .begin(auth::CleartextPassword { + let auth_flow = AuthFlow::new( + client, + auth::CleartextPassword { secret, endpoint: ep, pool: config.thread_pool.clone(), - }) - .await?; - drop(paused); - // cleartext auth is only allowed to the ws/http protocol. - // If we're here, we already received the password in the first message. - // Scram protocol will be executed on the proxy side. - let auth_outcome = auth_flow.authenticate().await?; + }, + ); + let auth_outcome = { + // pause the timer while we communicate with the client + let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client); + + // cleartext auth is only allowed to the ws/http protocol. + // If we're here, we already received the password in the first message. + // Scram protocol will be executed on the proxy side. + auth_flow.authenticate().await? + }; let keys = match auth_outcome { sasl::Outcome::Success(key) => key, @@ -67,9 +69,7 @@ pub(crate) async fn password_hack_no_authentication( // pause the timer while we communicate with the client let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client); - let payload = AuthFlow::new(client) - .begin(auth::PasswordHack) - .await? + let payload = AuthFlow::new(client, auth::PasswordHack) .get_password() .await?; diff --git a/proxy/src/auth/backend/mod.rs b/proxy/src/auth/backend/mod.rs index 6e5c0a3954..8c892d90a0 100644 --- a/proxy/src/auth/backend/mod.rs +++ b/proxy/src/auth/backend/mod.rs @@ -31,6 +31,7 @@ use crate::control_plane::{ }; use crate::intern::EndpointIdInt; use crate::metrics::Metrics; +use crate::pqproto::BeMessage; use crate::protocol2::ConnectionInfoExtra; use crate::proxy::NeonOptions; use crate::proxy::connect_compute::ComputeConnectBackend; @@ -402,7 +403,7 @@ async fn authenticate_with_secret( }; // we have authenticated the password - client.write_message_noflush(&pq_proto::BeMessage::AuthenticationOk)?; + client.write_message(BeMessage::AuthenticationOk); return Ok(ComputeCredentials { info, keys }); } @@ -702,7 +703,7 @@ mod tests { #[tokio::test] async fn auth_quirks_scram() { let (mut client, server) = tokio::io::duplex(1024); - let mut stream = PqStream::new(Stream::from_raw(server)); + let mut stream = PqStream::new_skip_handshake(Stream::from_raw(server)); let ctx = RequestContext::test(); let api = Auth { @@ -784,7 +785,7 @@ mod tests { #[tokio::test] async fn auth_quirks_cleartext() { let (mut client, server) = tokio::io::duplex(1024); - let mut stream = PqStream::new(Stream::from_raw(server)); + let mut stream = PqStream::new_skip_handshake(Stream::from_raw(server)); let ctx = RequestContext::test(); let api = Auth { @@ -838,7 +839,7 @@ mod tests { #[tokio::test] async fn auth_quirks_password_hack() { let (mut client, server) = tokio::io::duplex(1024); - let mut stream = PqStream::new(Stream::from_raw(server)); + let mut stream = PqStream::new_skip_handshake(Stream::from_raw(server)); let ctx = RequestContext::test(); let api = Auth { diff --git a/proxy/src/auth/credentials.rs b/proxy/src/auth/credentials.rs index 526d0df7f2..b51da48862 100644 --- a/proxy/src/auth/credentials.rs +++ b/proxy/src/auth/credentials.rs @@ -5,7 +5,6 @@ use std::net::IpAddr; use std::str::FromStr; use itertools::Itertools; -use pq_proto::StartupMessageParams; use thiserror::Error; use tracing::{debug, warn}; @@ -13,6 +12,7 @@ use crate::auth::password_hack::parse_endpoint_param; use crate::context::RequestContext; use crate::error::{ReportableError, UserFacingError}; use crate::metrics::{Metrics, SniGroup, SniKind}; +use crate::pqproto::StartupMessageParams; use crate::proxy::NeonOptions; use crate::serverless::{AUTH_BROKER_SNI, SERVERLESS_DRIVER_SNI}; use crate::types::{EndpointId, RoleName}; diff --git a/proxy/src/auth/flow.rs b/proxy/src/auth/flow.rs index 0992c6d875..8fbc4577e9 100644 --- a/proxy/src/auth/flow.rs +++ b/proxy/src/auth/flow.rs @@ -1,10 +1,8 @@ //! Main authentication flow. -use std::io; use std::sync::Arc; use postgres_protocol::authentication::sasl::{SCRAM_SHA_256, SCRAM_SHA_256_PLUS}; -use pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be}; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::info; @@ -13,35 +11,26 @@ use super::{AuthError, PasswordHackPayload}; use crate::context::RequestContext; use crate::control_plane::AuthSecret; use crate::intern::EndpointIdInt; +use crate::pqproto::{BeAuthenticationSaslMessage, BeMessage}; use crate::sasl; use crate::scram::threadpool::ThreadPool; use crate::scram::{self}; use crate::stream::{PqStream, Stream}; use crate::tls::TlsServerEndPoint; -/// Every authentication selector is supposed to implement this trait. -pub(crate) trait AuthMethod { - /// Any authentication selector should provide initial backend message - /// containing auth method name and parameters, e.g. md5 salt. - fn first_message(&self, channel_binding: bool) -> BeMessage<'_>; -} - -/// Initial state of [`AuthFlow`]. -pub(crate) struct Begin; - /// Use [SCRAM](crate::scram)-based auth in [`AuthFlow`]. pub(crate) struct Scram<'a>( pub(crate) &'a scram::ServerSecret, pub(crate) &'a RequestContext, ); -impl AuthMethod for Scram<'_> { +impl Scram<'_> { #[inline(always)] fn first_message(&self, channel_binding: bool) -> BeMessage<'_> { if channel_binding { - Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(scram::METHODS)) + BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(scram::METHODS)) } else { - Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods( + BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Methods( scram::METHODS_WITHOUT_PLUS, )) } @@ -52,13 +41,6 @@ impl AuthMethod for Scram<'_> { /// . pub(crate) struct PasswordHack; -impl AuthMethod for PasswordHack { - #[inline(always)] - fn first_message(&self, _channel_binding: bool) -> BeMessage<'_> { - Be::AuthenticationCleartextPassword - } -} - /// Use clear-text password auth called `password` in docs /// pub(crate) struct CleartextPassword { @@ -67,53 +49,37 @@ pub(crate) struct CleartextPassword { pub(crate) secret: AuthSecret, } -impl AuthMethod for CleartextPassword { - #[inline(always)] - fn first_message(&self, _channel_binding: bool) -> BeMessage<'_> { - Be::AuthenticationCleartextPassword - } -} - /// This wrapper for [`PqStream`] performs client authentication. #[must_use] pub(crate) struct AuthFlow<'a, S, State> { /// The underlying stream which implements libpq's protocol. stream: &'a mut PqStream>, - /// State might contain ancillary data (see [`Self::begin`]). + /// State might contain ancillary data. state: State, tls_server_end_point: TlsServerEndPoint, } /// Initial state of the stream wrapper. -impl<'a, S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'a, S, Begin> { +impl<'a, S: AsyncRead + AsyncWrite + Unpin, M> AuthFlow<'a, S, M> { /// Create a new wrapper for client authentication. - pub(crate) fn new(stream: &'a mut PqStream>) -> Self { + pub(crate) fn new(stream: &'a mut PqStream>, method: M) -> Self { let tls_server_end_point = stream.get_ref().tls_server_end_point(); Self { stream, - state: Begin, + state: method, tls_server_end_point, } } - - /// Move to the next step by sending auth method's name & params to client. - pub(crate) async fn begin(self, method: M) -> io::Result> { - self.stream - .write_message(&method.first_message(self.tls_server_end_point.supported())) - .await?; - - Ok(AuthFlow { - stream: self.stream, - state: method, - tls_server_end_point: self.tls_server_end_point, - }) - } } impl AuthFlow<'_, S, PasswordHack> { /// Perform user authentication. Raise an error in case authentication failed. pub(crate) async fn get_password(self) -> super::Result { + self.stream + .write_message(BeMessage::AuthenticationCleartextPassword); + self.stream.flush().await?; + let msg = self.stream.read_password_message().await?; let password = msg .strip_suffix(&[0]) @@ -133,6 +99,10 @@ impl AuthFlow<'_, S, PasswordHack> { impl AuthFlow<'_, S, CleartextPassword> { /// Perform user authentication. Raise an error in case authentication failed. pub(crate) async fn authenticate(self) -> super::Result> { + self.stream + .write_message(BeMessage::AuthenticationCleartextPassword); + self.stream.flush().await?; + let msg = self.stream.read_password_message().await?; let password = msg .strip_suffix(&[0]) @@ -147,7 +117,7 @@ impl AuthFlow<'_, S, CleartextPassword> { .await?; if let sasl::Outcome::Success(_) = &outcome { - self.stream.write_message_noflush(&Be::AuthenticationOk)?; + self.stream.write_message(BeMessage::AuthenticationOk); } Ok(outcome) @@ -159,42 +129,36 @@ impl AuthFlow<'_, S, Scram<'_>> { /// Perform user authentication. Raise an error in case authentication failed. pub(crate) async fn authenticate(self) -> super::Result> { let Scram(secret, ctx) = self.state; + let channel_binding = self.tls_server_end_point; - // pause the timer while we communicate with the client - let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client); + // send sasl message. + { + // pause the timer while we communicate with the client + let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client); - // Initial client message contains the chosen auth method's name. - let msg = self.stream.read_password_message().await?; - let sasl = sasl::FirstMessage::parse(&msg) - .ok_or(AuthError::MalformedPassword("bad sasl message"))?; - - // Currently, the only supported SASL method is SCRAM. - if !scram::METHODS.contains(&sasl.method) { - return Err(super::AuthError::bad_auth_method(sasl.method)); + let sasl = self.state.first_message(channel_binding.supported()); + self.stream.write_message(sasl); + self.stream.flush().await?; } - match sasl.method { - SCRAM_SHA_256 => ctx.set_auth_method(crate::context::AuthMethod::ScramSha256), - SCRAM_SHA_256_PLUS => ctx.set_auth_method(crate::context::AuthMethod::ScramSha256Plus), - _ => {} - } + // complete sasl handshake. + sasl::authenticate(ctx, self.stream, |method| { + // Currently, the only supported SASL method is SCRAM. + match method { + SCRAM_SHA_256 => ctx.set_auth_method(crate::context::AuthMethod::ScramSha256), + SCRAM_SHA_256_PLUS => { + ctx.set_auth_method(crate::context::AuthMethod::ScramSha256Plus); + } + method => return Err(sasl::Error::BadAuthMethod(method.into())), + } - // TODO: make this a metric instead - info!("client chooses {}", sasl.method); + // TODO: make this a metric instead + info!("client chooses {}", method); - let outcome = sasl::SaslStream::new(self.stream, sasl.message) - .authenticate(scram::Exchange::new( - secret, - rand::random, - self.tls_server_end_point, - )) - .await?; - - if let sasl::Outcome::Success(_) = &outcome { - self.stream.write_message_noflush(&Be::AuthenticationOk)?; - } - - Ok(outcome) + Ok(scram::Exchange::new(secret, rand::random, channel_binding)) + }) + .await + .map_err(AuthError::Sasl) } } diff --git a/proxy/src/binary/pg_sni_router.rs b/proxy/src/binary/pg_sni_router.rs index 3e87538ae7..a4f517fead 100644 --- a/proxy/src/binary/pg_sni_router.rs +++ b/proxy/src/binary/pg_sni_router.rs @@ -4,8 +4,9 @@ //! This allows connecting to pods/services running in the same Kubernetes cluster from //! the outside. Similar to an ingress controller for HTTPS. +use std::net::SocketAddr; use std::path::Path; -use std::{net::SocketAddr, sync::Arc}; +use std::sync::Arc; use anyhow::{Context, anyhow, bail, ensure}; use clap::Arg; @@ -17,6 +18,7 @@ use rustls::pki_types::{DnsName, PrivateKeyDer}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::net::TcpListener; use tokio_rustls::TlsConnector; +use tokio_rustls::server::TlsStream; use tokio_util::sync::CancellationToken; use tracing::{Instrument, error, info}; use utils::project_git_version; @@ -24,10 +26,12 @@ use utils::sentry_init::init_sentry; use crate::context::RequestContext; use crate::metrics::{Metrics, ThreadPoolMetrics}; +use crate::pqproto::FeStartupPacket; use crate::protocol2::ConnectionInfo; -use crate::proxy::{ErrorSource, copy_bidirectional_client_compute, run_until_cancelled}; +use crate::proxy::{ + ErrorSource, TlsRequired, copy_bidirectional_client_compute, run_until_cancelled, +}; use crate::stream::{PqStream, Stream}; -use crate::tls::TlsServerEndPoint; project_git_version!(GIT_VERSION); @@ -84,7 +88,7 @@ pub async fn run() -> anyhow::Result<()> { .parse()?; // Configure TLS - let (tls_config, tls_server_end_point): (Arc, TlsServerEndPoint) = match ( + let tls_config = match ( args.get_one::("tls-key"), args.get_one::("tls-cert"), ) { @@ -117,7 +121,6 @@ pub async fn run() -> anyhow::Result<()> { dest.clone(), tls_config.clone(), None, - tls_server_end_point, proxy_listener, cancellation_token.clone(), )) @@ -127,7 +130,6 @@ pub async fn run() -> anyhow::Result<()> { dest, tls_config, Some(compute_tls_config), - tls_server_end_point, proxy_listener_compute_tls, cancellation_token.clone(), )) @@ -154,7 +156,7 @@ pub async fn run() -> anyhow::Result<()> { pub(super) fn parse_tls( key_path: &Path, cert_path: &Path, -) -> anyhow::Result<(Arc, TlsServerEndPoint)> { +) -> anyhow::Result> { let key = { let key_bytes = std::fs::read(key_path).context("TLS key file")?; @@ -187,10 +189,6 @@ pub(super) fn parse_tls( })? }; - // needed for channel bindings - let first_cert = cert_chain.first().context("missing certificate")?; - let tls_server_end_point = TlsServerEndPoint::new(first_cert)?; - let tls_config = rustls::ServerConfig::builder_with_provider(Arc::new(ring::default_provider())) .with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12]) @@ -199,14 +197,13 @@ pub(super) fn parse_tls( .with_single_cert(cert_chain, key)? .into(); - Ok((tls_config, tls_server_end_point)) + Ok(tls_config) } pub(super) async fn task_main( dest_suffix: Arc, tls_config: Arc, compute_tls_config: Option>, - tls_server_end_point: TlsServerEndPoint, listener: tokio::net::TcpListener, cancellation_token: CancellationToken, ) -> anyhow::Result<()> { @@ -242,15 +239,7 @@ pub(super) async fn task_main( crate::metrics::Protocol::SniRouter, "sni", ); - handle_client( - ctx, - dest_suffix, - tls_config, - compute_tls_config, - tls_server_end_point, - socket, - ) - .await + handle_client(ctx, dest_suffix, tls_config, compute_tls_config, socket).await } .unwrap_or_else(|e| { // Acknowledge that the task has finished with an error. @@ -269,55 +258,26 @@ pub(super) async fn task_main( Ok(()) } -const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)"; - async fn ssl_handshake( ctx: &RequestContext, raw_stream: S, tls_config: Arc, - tls_server_end_point: TlsServerEndPoint, -) -> anyhow::Result> { - let mut stream = PqStream::new(Stream::from_raw(raw_stream)); - - let msg = stream.read_startup_packet().await?; - use pq_proto::FeStartupPacket::SslRequest; - +) -> anyhow::Result> { + let (mut stream, msg) = PqStream::parse_startup(Stream::from_raw(raw_stream)).await?; match msg { - SslRequest { direct: false } => { - stream - .write_message(&pq_proto::BeMessage::EncryptionResponse(true)) - .await?; + FeStartupPacket::SslRequest { direct: None } => { + let raw = stream.accept_tls().await?; - // Upgrade raw stream into a secure TLS-backed stream. - // NOTE: We've consumed `tls`; this fact will be used later. - - let (raw, read_buf) = stream.into_inner(); - // TODO: Normally, client doesn't send any data before - // server says TLS handshake is ok and read_buf is empty. - // However, you could imagine pipelining of postgres - // SSLRequest + TLS ClientHello in one hunk similar to - // pipelining in our node js driver. We should probably - // support that by chaining read_buf with the stream. - if !read_buf.is_empty() { - bail!("data is sent before server replied with EncryptionResponse"); - } - - Ok(Stream::Tls { - tls: Box::new( - raw.upgrade(tls_config, !ctx.has_private_peer_addr()) - .await?, - ), - tls_server_end_point, - }) + Ok(raw + .upgrade(tls_config, !ctx.has_private_peer_addr()) + .await?) } unexpected => { info!( ?unexpected, "unexpected startup packet, rejecting connection" ); - stream - .throw_error_str(ERR_INSECURE_CONNECTION, crate::error::ErrorKind::User, None) - .await? + Err(stream.throw_error(TlsRequired, None).await)? } } } @@ -327,15 +287,18 @@ async fn handle_client( dest_suffix: Arc, tls_config: Arc, compute_tls_config: Option>, - tls_server_end_point: TlsServerEndPoint, stream: impl AsyncRead + AsyncWrite + Unpin, ) -> anyhow::Result<()> { - let mut tls_stream = ssl_handshake(&ctx, stream, tls_config, tls_server_end_point).await?; + let mut tls_stream = ssl_handshake(&ctx, stream, tls_config).await?; // Cut off first part of the SNI domain // We receive required destination details in the format of // `{k8s_service_name}--{k8s_namespace}--{port}.non-sni-domain` - let sni = tls_stream.sni_hostname().ok_or(anyhow!("SNI missing"))?; + let sni = tls_stream + .get_ref() + .1 + .server_name() + .ok_or(anyhow!("SNI missing"))?; let dest: Vec<&str> = sni .split_once('.') .context("invalid SNI")? diff --git a/proxy/src/binary/proxy.rs b/proxy/src/binary/proxy.rs index 5f24940985..9a3903ba9a 100644 --- a/proxy/src/binary/proxy.rs +++ b/proxy/src/binary/proxy.rs @@ -476,8 +476,7 @@ pub async fn run() -> anyhow::Result<()> { let key_path = args.tls_key.expect("already asserted it is set"); let cert_path = args.tls_cert.expect("already asserted it is set"); - let (tls_config, tls_server_end_point) = - super::pg_sni_router::parse_tls(&key_path, &cert_path)?; + let tls_config = super::pg_sni_router::parse_tls(&key_path, &cert_path)?; let dest = Arc::new(dest); @@ -485,7 +484,6 @@ pub async fn run() -> anyhow::Result<()> { dest.clone(), tls_config.clone(), None, - tls_server_end_point, listen, cancellation_token.clone(), )); @@ -494,7 +492,6 @@ pub async fn run() -> anyhow::Result<()> { dest, tls_config, Some(config.connect_to_compute.tls.clone()), - tls_server_end_point, listen_tls, cancellation_token.clone(), )); diff --git a/proxy/src/cancellation.rs b/proxy/src/cancellation.rs index a6e7bf85a0..0bff901376 100644 --- a/proxy/src/cancellation.rs +++ b/proxy/src/cancellation.rs @@ -5,7 +5,6 @@ use anyhow::{Context, anyhow}; use ipnet::{IpNet, Ipv4Net, Ipv6Net}; use postgres_client::CancelToken; use postgres_client::tls::MakeTlsConnect; -use pq_proto::CancelKeyData; use redis::{Cmd, FromRedisValue, Value}; use serde::{Deserialize, Serialize}; use thiserror::Error; @@ -21,6 +20,7 @@ use crate::control_plane::ControlPlaneApi; use crate::error::ReportableError; use crate::ext::LockExt; use crate::metrics::{CancelChannelSizeGuard, CancellationRequest, Metrics, RedisMsgKind}; +use crate::pqproto::CancelKeyData; use crate::protocol2::ConnectionInfoExtra; use crate::rate_limiter::LeakyBucketRateLimiter; use crate::redis::keys::KeyPrefix; diff --git a/proxy/src/compute.rs b/proxy/src/compute.rs index 26254beecf..2899f25129 100644 --- a/proxy/src/compute.rs +++ b/proxy/src/compute.rs @@ -8,7 +8,6 @@ use itertools::Itertools; use postgres_client::tls::MakeTlsConnect; use postgres_client::{CancelToken, RawConnection}; use postgres_protocol::message::backend::NoticeResponseBody; -use pq_proto::StartupMessageParams; use rustls::pki_types::InvalidDnsNameError; use thiserror::Error; use tokio::net::{TcpStream, lookup_host}; @@ -24,6 +23,7 @@ use crate::control_plane::errors::WakeComputeError; use crate::control_plane::messages::MetricsAuxInfo; use crate::error::{ReportableError, UserFacingError}; use crate::metrics::{Metrics, NumDbConnectionsGuard}; +use crate::pqproto::StartupMessageParams; use crate::proxy::neon_option; use crate::tls::postgres_rustls::MakeRustlsConnect; use crate::types::Host; diff --git a/proxy/src/console_redirect_proxy.rs b/proxy/src/console_redirect_proxy.rs index e3184e20d1..9499aba61b 100644 --- a/proxy/src/console_redirect_proxy.rs +++ b/proxy/src/console_redirect_proxy.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use futures::{FutureExt, TryFutureExt}; -use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; +use tokio::io::{AsyncRead, AsyncWrite}; use tokio_util::sync::CancellationToken; use tracing::{Instrument, debug, error, info}; @@ -221,12 +221,10 @@ pub(crate) async fn handle_client( .await { Ok(auth_result) => auth_result, - Err(e) => { - return stream.throw_error(e, Some(ctx)).await?; - } + Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?, }; - let mut node = connect_to_compute( + let node = connect_to_compute( ctx, &TcpMechanism { user_info, @@ -238,7 +236,7 @@ pub(crate) async fn handle_client( config.wake_compute_retry_config, &config.connect_to_compute, ) - .or_else(|e| stream.throw_error(e, Some(ctx))) + .or_else(|e| async { Err(stream.throw_error(e, Some(ctx)).await) }) .await?; let cancellation_handler_clone = Arc::clone(&cancellation_handler); @@ -246,14 +244,8 @@ pub(crate) async fn handle_client( session.write_cancel_key(node.cancel_closure.clone())?; - prepare_client_connection(&node, *session.key(), &mut stream).await?; - - // Before proxy passing, forward to compute whatever data is left in the - // PqStream input buffer. Normally there is none, but our serverless npm - // driver in pipeline mode sends startup, password and first query - // immediately after opening the connection. - let (stream, read_buf) = stream.into_inner(); - node.stream.write_all(&read_buf).await?; + prepare_client_connection(&node, *session.key(), &mut stream); + let stream = stream.flush_and_into_inner().await?; Ok(Some(ProxyPassthrough { client: stream, diff --git a/proxy/src/context/mod.rs b/proxy/src/context/mod.rs index 79aaf22990..de4600951e 100644 --- a/proxy/src/context/mod.rs +++ b/proxy/src/context/mod.rs @@ -4,7 +4,6 @@ use std::net::IpAddr; use chrono::Utc; use once_cell::sync::OnceCell; -use pq_proto::StartupMessageParams; use smol_str::SmolStr; use tokio::sync::mpsc; use tracing::field::display; @@ -20,6 +19,7 @@ use crate::metrics::{ ConnectOutcome, InvalidEndpointsGroup, LatencyAccumulated, LatencyTimer, Metrics, Protocol, Waiting, }; +use crate::pqproto::StartupMessageParams; use crate::protocol2::{ConnectionInfo, ConnectionInfoExtra}; use crate::types::{DbName, EndpointId, RoleName}; diff --git a/proxy/src/context/parquet.rs b/proxy/src/context/parquet.rs index f6250bcd17..c9d3905abd 100644 --- a/proxy/src/context/parquet.rs +++ b/proxy/src/context/parquet.rs @@ -11,7 +11,6 @@ use parquet::file::metadata::RowGroupMetaDataPtr; use parquet::file::properties::{DEFAULT_PAGE_SIZE, WriterProperties, WriterPropertiesPtr}; use parquet::file::writer::SerializedFileWriter; use parquet::record::RecordWriter; -use pq_proto::StartupMessageParams; use remote_storage::{GenericRemoteStorage, RemotePath, RemoteStorageConfig, TimeoutOrCancel}; use serde::ser::SerializeMap; use tokio::sync::mpsc; @@ -24,6 +23,7 @@ use super::{LOG_CHAN, RequestContextInner}; use crate::config::remote_storage_from_toml; use crate::context::LOG_CHAN_DISCONNECT; use crate::ext::TaskExt; +use crate::pqproto::StartupMessageParams; #[derive(clap::Args, Clone, Debug)] pub struct ParquetUploadArgs { diff --git a/proxy/src/lib.rs b/proxy/src/lib.rs index d1f8430b8a..d65d056585 100644 --- a/proxy/src/lib.rs +++ b/proxy/src/lib.rs @@ -92,6 +92,7 @@ mod logging; mod metrics; mod parse; mod pglb; +mod pqproto; mod protocol2; mod proxy; mod rate_limiter; diff --git a/proxy/src/pqproto.rs b/proxy/src/pqproto.rs new file mode 100644 index 0000000000..d68d9f9474 --- /dev/null +++ b/proxy/src/pqproto.rs @@ -0,0 +1,693 @@ +//! Postgres protocol codec +//! +//! + +use std::fmt; +use std::io::{self, Cursor}; + +use bytes::{Buf, BufMut}; +use itertools::Itertools; +use rand::distributions::{Distribution, Standard}; +use tokio::io::{AsyncRead, AsyncReadExt}; +use zerocopy::{FromBytes, Immutable, IntoBytes, big_endian}; + +pub type ErrorCode = [u8; 5]; + +pub const FE_PASSWORD_MESSAGE: u8 = b'p'; + +pub const SQLSTATE_INTERNAL_ERROR: [u8; 5] = *b"XX000"; + +/// The protocol version number. +/// +/// The most significant 16 bits are the major version number (3 for the protocol described here). +/// The least significant 16 bits are the minor version number (0 for the protocol described here). +/// +#[derive(Clone, Copy, PartialEq, PartialOrd, FromBytes, IntoBytes, Immutable)] +#[repr(C)] +pub struct ProtocolVersion { + major: big_endian::U16, + minor: big_endian::U16, +} + +impl ProtocolVersion { + pub const fn new(major: u16, minor: u16) -> Self { + Self { + major: big_endian::U16::new(major), + minor: big_endian::U16::new(minor), + } + } + pub const fn minor(self) -> u16 { + self.minor.get() + } + pub const fn major(self) -> u16 { + self.major.get() + } +} + +impl fmt::Debug for ProtocolVersion { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_list() + .entry(&self.major()) + .entry(&self.minor()) + .finish() + } +} + +/// read the type from the stream using zerocopy. +/// +/// not cancel safe. +macro_rules! read { + ($s:expr => $t:ty) => {{ + // cannot be implemented as a function due to lack of const-generic-expr + let mut buf = [0; size_of::<$t>()]; + $s.read_exact(&mut buf).await?; + let res: $t = zerocopy::transmute!(buf); + res + }}; +} + +pub async fn read_startup(stream: &mut S) -> io::Result +where + S: AsyncRead + Unpin, +{ + /// + const MAX_STARTUP_PACKET_LENGTH: usize = 10000; + const RESERVED_INVALID_MAJOR_VERSION: u16 = 1234; + /// + const CANCEL_REQUEST_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5678); + /// + const NEGOTIATE_SSL_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5679); + /// + const NEGOTIATE_GSS_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5680); + + /// This first reads the startup message header, is 8 bytes. + /// The first 4 bytes is a big-endian message length, and the next 4 bytes is a version number. + /// + /// The length value is inclusive of the header. For example, + /// an empty message will always have length 8. + #[derive(Clone, Copy, FromBytes, IntoBytes, Immutable)] + #[repr(C)] + struct StartupHeader { + len: big_endian::U32, + version: ProtocolVersion, + } + + let header = read!(stream => StartupHeader); + + // + // First byte indicates standard SSL handshake message + // (It can't be a Postgres startup length because in network byte order + // that would be a startup packet hundreds of megabytes long) + if header.as_bytes()[0] == 0x16 { + return Ok(FeStartupPacket::SslRequest { + // The bytes we read for the header are actually part of a TLS ClientHello. + // In theory, if the ClientHello was < 8 bytes we would fail with EOF before we get here. + // In practice though, I see no world where a ClientHello is less than 8 bytes + // since it includes ephemeral keys etc. + direct: Some(zerocopy::transmute!(header)), + }); + } + + let Some(len) = (header.len.get() as usize).checked_sub(8) else { + return Err(io::Error::other(format!( + "invalid startup message length {}, must be at least 8.", + header.len, + ))); + }; + + // TODO: add a histogram for startup packet lengths + if len > MAX_STARTUP_PACKET_LENGTH { + tracing::warn!("large startup message detected: {len} bytes"); + return Err(io::Error::other(format!( + "invalid startup message length {len}" + ))); + } + + match header.version { + // + CANCEL_REQUEST_CODE => { + if len != 8 { + return Err(io::Error::other( + "CancelRequest message is malformed, backend PID / secret key missing", + )); + } + + Ok(FeStartupPacket::CancelRequest( + read!(stream => CancelKeyData), + )) + } + // + NEGOTIATE_SSL_CODE => { + // Requested upgrade to SSL (aka TLS) + Ok(FeStartupPacket::SslRequest { direct: None }) + } + NEGOTIATE_GSS_CODE => { + // Requested upgrade to GSSAPI + Ok(FeStartupPacket::GssEncRequest) + } + version if version.major() == RESERVED_INVALID_MAJOR_VERSION => Err(io::Error::other( + format!("Unrecognized request code {version:?}"), + )), + // StartupMessage + version => { + // The protocol version number is followed by one or more pairs of parameter name and value strings. + // A zero byte is required as a terminator after the last name/value pair. + // Parameters can appear in any order. user is required, others are optional. + + let mut buf = vec![0; len]; + stream.read_exact(&mut buf).await?; + + if buf.pop() != Some(b'\0') { + return Err(io::Error::other( + "StartupMessage params: missing null terminator", + )); + } + + // TODO: Don't do this. + // There's no guarantee that these messages are utf8, + // but they usually happen to be simple ascii. + let params = String::from_utf8(buf) + .map_err(|_| io::Error::other("StartupMessage params: invalid utf-8"))?; + + Ok(FeStartupPacket::StartupMessage { + version, + params: StartupMessageParams { params }, + }) + } + } +} + +/// Read a raw postgres packet, which will respect the max length requested. +/// +/// This returns the message tag, as well as the message body. The message +/// body is written into `buf`, and it is otherwise completely overwritten. +/// +/// This is not cancel safe. +pub async fn read_message<'a, S>( + stream: &mut S, + buf: &'a mut Vec, + max: usize, +) -> io::Result<(u8, &'a mut [u8])> +where + S: AsyncRead + Unpin, +{ + /// This first reads the header, which for regular messages in the 3.0 protocol is 5 bytes. + /// The first byte is a message tag, and the next 4 bytes is a big-endian length. + /// + /// Awkwardly, the length value is inclusive of itself, but not of the tag. For example, + /// an empty message will always have length 4. + #[derive(Clone, Copy, FromBytes)] + #[repr(C)] + struct Header { + tag: u8, + len: big_endian::U32, + } + + let header = read!(stream => Header); + + // as described above, the length must be at least 4. + let Some(len) = (header.len.get() as usize).checked_sub(4) else { + return Err(io::Error::other(format!( + "invalid startup message length {}, must be at least 4.", + header.len, + ))); + }; + + // TODO: add a histogram for message lengths + + // check if the message exceeds our desired max. + if len > max { + tracing::warn!("large postgres message detected: {len} bytes"); + return Err(io::Error::other(format!("invalid message length {len}"))); + } + + // read in our entire message. + buf.resize(len, 0); + stream.read_exact(buf).await?; + + Ok((header.tag, buf)) +} + +pub struct WriteBuf(Cursor>); + +impl Buf for WriteBuf { + #[inline] + fn remaining(&self) -> usize { + self.0.remaining() + } + + #[inline] + fn chunk(&self) -> &[u8] { + self.0.chunk() + } + + #[inline] + fn advance(&mut self, cnt: usize) { + self.0.advance(cnt); + } +} + +impl WriteBuf { + pub const fn new() -> Self { + Self(Cursor::new(Vec::new())) + } + + /// Use a heuristic to determine if we should shrink the write buffer. + #[inline] + fn should_shrink(&self) -> bool { + let n = self.0.position() as usize; + let len = self.0.get_ref().len(); + + // the unused space at the front of our buffer is 2x the size of our filled portion. + n + n > len + } + + /// Shrink the write buffer so that subsequent writes have more spare capacity. + #[cold] + fn shrink(&mut self) { + let n = self.0.position() as usize; + let buf = self.0.get_mut(); + + // buf repr: + // [----unused------|-----filled-----|-----uninit-----] + // ^ n ^ buf.len() ^ buf.capacity() + let filled = n..buf.len(); + let filled_len = filled.len(); + buf.copy_within(filled, 0); + buf.truncate(filled_len); + self.0.set_position(0); + } + + /// clear the write buffer. + pub fn reset(&mut self) { + let buf = self.0.get_mut(); + buf.clear(); + self.0.set_position(0); + } + + /// Write a raw message to the internal buffer. + /// + /// The size_hint value is only a hint for reserving space. It's ok if it's incorrect, since + /// we calculate the length after the fact. + pub fn write_raw(&mut self, size_hint: usize, tag: u8, f: impl FnOnce(&mut Vec)) { + if self.should_shrink() { + self.shrink(); + } + + let buf = self.0.get_mut(); + buf.reserve(5 + size_hint); + + buf.push(tag); + let start = buf.len(); + buf.extend_from_slice(&[0, 0, 0, 0]); + + f(buf); + + let end = buf.len(); + let len = (end - start) as u32; + buf[start..start + 4].copy_from_slice(&len.to_be_bytes()); + } + + /// Write an encryption response message. + pub fn encryption(&mut self, m: u8) { + self.0.get_mut().push(m); + } + + pub fn write_error(&mut self, msg: &str, error_code: ErrorCode) { + self.shrink(); + + // + // + // "SERROR\0CXXXXX\0M\0\0".len() == 17 + self.write_raw(17 + msg.len(), b'E', |buf| { + // Severity: ERROR + buf.put_slice(b"SERROR\0"); + + // Code: error_code + buf.put_u8(b'C'); + buf.put_slice(&error_code); + buf.put_u8(0); + + // Message: msg + buf.put_u8(b'M'); + buf.put_slice(msg.as_bytes()); + buf.put_u8(0); + + // End. + buf.put_u8(0); + }); + } +} + +#[derive(Debug)] +pub enum FeStartupPacket { + CancelRequest(CancelKeyData), + SslRequest { + direct: Option<[u8; 8]>, + }, + GssEncRequest, + StartupMessage { + version: ProtocolVersion, + params: StartupMessageParams, + }, +} + +#[derive(Debug, Clone, Default)] +pub struct StartupMessageParams { + pub params: String, +} + +impl StartupMessageParams { + /// Get parameter's value by its name. + pub fn get(&self, name: &str) -> Option<&str> { + self.iter().find_map(|(k, v)| (k == name).then_some(v)) + } + + /// Split command-line options according to PostgreSQL's logic, + /// taking into account all escape sequences but leaving them as-is. + /// [`None`] means that there's no `options` in [`Self`]. + pub fn options_raw(&self) -> Option> { + self.get("options").map(Self::parse_options_raw) + } + + /// Split command-line options according to PostgreSQL's logic, + /// taking into account all escape sequences but leaving them as-is. + pub fn parse_options_raw(input: &str) -> impl Iterator { + // See `postgres: pg_split_opts`. + let mut last_was_escape = false; + input + .split(move |c: char| { + // We split by non-escaped whitespace symbols. + let should_split = c.is_ascii_whitespace() && !last_was_escape; + last_was_escape = c == '\\' && !last_was_escape; + should_split + }) + .filter(|s| !s.is_empty()) + } + + /// Iterate through key-value pairs in an arbitrary order. + pub fn iter(&self) -> impl Iterator { + self.params.split_terminator('\0').tuples() + } + + // This function is mostly useful in tests. + #[cfg(test)] + pub fn new<'a, const N: usize>(pairs: [(&'a str, &'a str); N]) -> Self { + let mut b = Self { + params: String::new(), + }; + for (k, v) in pairs { + b.insert(k, v); + } + b + } + + /// Set parameter's value by its name. + /// name and value must not contain a \0 byte + pub fn insert(&mut self, name: &str, value: &str) { + self.params.reserve(name.len() + value.len() + 2); + self.params.push_str(name); + self.params.push('\0'); + self.params.push_str(value); + self.params.push('\0'); + } +} + +/// Cancel keys usually are represented as PID+SecretKey, but to proxy they're just +/// opaque bytes. +#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy, FromBytes, IntoBytes, Immutable)] +pub struct CancelKeyData(pub big_endian::U64); + +pub fn id_to_cancel_key(id: u64) -> CancelKeyData { + CancelKeyData(big_endian::U64::new(id)) +} + +impl fmt::Display for CancelKeyData { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let id = self.0; + f.debug_tuple("CancelKeyData") + .field(&format_args!("{id:x}")) + .finish() + } +} +impl Distribution for Standard { + fn sample(&self, rng: &mut R) -> CancelKeyData { + id_to_cancel_key(rng.r#gen()) + } +} + +pub enum BeMessage<'a> { + AuthenticationOk, + AuthenticationSasl(BeAuthenticationSaslMessage<'a>), + AuthenticationCleartextPassword, + BackendKeyData(CancelKeyData), + ParameterStatus { + name: &'a [u8], + value: &'a [u8], + }, + ReadyForQuery, + NoticeResponse(&'a str), + NegotiateProtocolVersion { + version: ProtocolVersion, + options: &'a [&'a str], + }, +} + +#[derive(Debug)] +pub enum BeAuthenticationSaslMessage<'a> { + Methods(&'a [&'a str]), + Continue(&'a [u8]), + Final(&'a [u8]), +} + +impl BeMessage<'_> { + /// Write the message into an internal buffer + pub fn write_message(self, buf: &mut WriteBuf) { + match self { + // + BeMessage::AuthenticationOk => { + buf.write_raw(1, b'R', |buf| buf.put_i32(0)); + } + // + BeMessage::AuthenticationCleartextPassword => { + buf.write_raw(1, b'R', |buf| buf.put_i32(3)); + } + + // + BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(methods)) => { + let len: usize = methods.iter().map(|m| m.len() + 1).sum(); + buf.write_raw(len + 2, b'R', |buf| { + buf.put_i32(10); // Specifies that SASL auth method is used. + for method in methods { + buf.put_slice(method.as_bytes()); + buf.put_u8(0); + } + buf.put_u8(0); // zero terminator for the list + }); + } + // + BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Continue(extra)) => { + buf.write_raw(extra.len() + 1, b'R', |buf| { + buf.put_i32(11); // Continue SASL auth. + buf.put_slice(extra); + }); + } + // + BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Final(extra)) => { + buf.write_raw(extra.len() + 1, b'R', |buf| { + buf.put_i32(12); // Send final SASL message. + buf.put_slice(extra); + }); + } + + // + BeMessage::BackendKeyData(key_data) => { + buf.write_raw(8, b'K', |buf| buf.put_slice(key_data.as_bytes())); + } + + // + // + BeMessage::NoticeResponse(msg) => { + // 'N' signalizes NoticeResponse messages + buf.write_raw(18 + msg.len(), b'N', |buf| { + // Severity: NOTICE + buf.put_slice(b"SNOTICE\0"); + + // Code: XX000 (ignored for notice, but still required) + buf.put_slice(b"CXX000\0"); + + // Message: msg + buf.put_u8(b'M'); + buf.put_slice(msg.as_bytes()); + buf.put_u8(0); + + // End notice. + buf.put_u8(0); + }); + } + + // + BeMessage::ParameterStatus { name, value } => { + buf.write_raw(name.len() + value.len() + 2, b'S', |buf| { + buf.put_slice(name.as_bytes()); + buf.put_u8(0); + buf.put_slice(value.as_bytes()); + buf.put_u8(0); + }); + } + + // + BeMessage::ReadyForQuery => { + buf.write_raw(1, b'Z', |buf| buf.put_u8(b'I')); + } + + // + BeMessage::NegotiateProtocolVersion { version, options } => { + let len: usize = options.iter().map(|o| o.len() + 1).sum(); + buf.write_raw(8 + len, b'v', |buf| { + buf.put_slice(version.as_bytes()); + buf.put_u32(options.len() as u32); + for option in options { + buf.put_slice(option.as_bytes()); + buf.put_u8(0); + } + }); + } + } + } +} + +#[cfg(test)] +mod tests { + use std::io::Cursor; + + use tokio::io::{AsyncWriteExt, duplex}; + use zerocopy::IntoBytes; + + use crate::pqproto::{FeStartupPacket, read_message, read_startup}; + + use super::ProtocolVersion; + + #[tokio::test] + async fn reject_large_startup() { + // we're going to define a v3.0 startup message with far too many parameters. + let mut payload = vec![]; + // 10001 + 8 bytes. + payload.extend_from_slice(&10009_u32.to_be_bytes()); + payload.extend_from_slice(ProtocolVersion::new(3, 0).as_bytes()); + payload.resize(10009, b'a'); + + let (mut server, mut client) = duplex(128); + #[rustfmt::skip] + let (server, client) = tokio::join!( + async move { read_startup(&mut server).await.unwrap_err() }, + async move { client.write_all(&payload).await.unwrap_err() }, + ); + + assert_eq!(server.to_string(), "invalid startup message length 10001"); + assert_eq!(client.to_string(), "broken pipe"); + } + + #[tokio::test] + async fn reject_large_password() { + // we're going to define a password message that is far too long. + let mut payload = vec![]; + payload.push(b'p'); + payload.extend_from_slice(&517_u32.to_be_bytes()); + payload.resize(518, b'a'); + + let (mut server, mut client) = duplex(128); + #[rustfmt::skip] + let (server, client) = tokio::join!( + async move { read_message(&mut server, &mut vec![], 512).await.unwrap_err() }, + async move { client.write_all(&payload).await.unwrap_err() }, + ); + + assert_eq!(server.to_string(), "invalid message length 513"); + assert_eq!(client.to_string(), "broken pipe"); + } + + #[tokio::test] + async fn read_startup_message() { + let mut payload = vec![]; + payload.extend_from_slice(&17_u32.to_be_bytes()); + payload.extend_from_slice(ProtocolVersion::new(3, 0).as_bytes()); + payload.extend_from_slice(b"abc\0def\0\0"); + + let startup = read_startup(&mut Cursor::new(&payload)).await.unwrap(); + let FeStartupPacket::StartupMessage { version, params } = startup else { + panic!("unexpected startup message: {startup:?}"); + }; + + assert_eq!(version.major(), 3); + assert_eq!(version.minor(), 0); + assert_eq!(params.params, "abc\0def\0"); + } + + #[tokio::test] + async fn read_ssl_message() { + let mut payload = vec![]; + payload.extend_from_slice(&8_u32.to_be_bytes()); + payload.extend_from_slice(ProtocolVersion::new(1234, 5679).as_bytes()); + + let startup = read_startup(&mut Cursor::new(&payload)).await.unwrap(); + let FeStartupPacket::SslRequest { direct: None } = startup else { + panic!("unexpected startup message: {startup:?}"); + }; + } + + #[tokio::test] + async fn read_tls_message() { + // sample client hello taken from + let client_hello = [ + 0x16, 0x03, 0x01, 0x00, 0xf8, 0x01, 0x00, 0x00, 0xf4, 0x03, 0x03, 0x00, 0x01, 0x02, + 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, + 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, + 0x1f, 0x20, 0xe0, 0xe1, 0xe2, 0xe3, 0xe4, 0xe5, 0xe6, 0xe7, 0xe8, 0xe9, 0xea, 0xeb, + 0xec, 0xed, 0xee, 0xef, 0xf0, 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7, 0xf8, 0xf9, + 0xfa, 0xfb, 0xfc, 0xfd, 0xfe, 0xff, 0x00, 0x08, 0x13, 0x02, 0x13, 0x03, 0x13, 0x01, + 0x00, 0xff, 0x01, 0x00, 0x00, 0xa3, 0x00, 0x00, 0x00, 0x18, 0x00, 0x16, 0x00, 0x00, + 0x13, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2e, 0x75, 0x6c, 0x66, 0x68, 0x65, + 0x69, 0x6d, 0x2e, 0x6e, 0x65, 0x74, 0x00, 0x0b, 0x00, 0x04, 0x03, 0x00, 0x01, 0x02, + 0x00, 0x0a, 0x00, 0x16, 0x00, 0x14, 0x00, 0x1d, 0x00, 0x17, 0x00, 0x1e, 0x00, 0x19, + 0x00, 0x18, 0x01, 0x00, 0x01, 0x01, 0x01, 0x02, 0x01, 0x03, 0x01, 0x04, 0x00, 0x23, + 0x00, 0x00, 0x00, 0x16, 0x00, 0x00, 0x00, 0x17, 0x00, 0x00, 0x00, 0x0d, 0x00, 0x1e, + 0x00, 0x1c, 0x04, 0x03, 0x05, 0x03, 0x06, 0x03, 0x08, 0x07, 0x08, 0x08, 0x08, 0x09, + 0x08, 0x0a, 0x08, 0x0b, 0x08, 0x04, 0x08, 0x05, 0x08, 0x06, 0x04, 0x01, 0x05, 0x01, + 0x06, 0x01, 0x00, 0x2b, 0x00, 0x03, 0x02, 0x03, 0x04, 0x00, 0x2d, 0x00, 0x02, 0x01, + 0x01, 0x00, 0x33, 0x00, 0x26, 0x00, 0x24, 0x00, 0x1d, 0x00, 0x20, 0x35, 0x80, 0x72, + 0xd6, 0x36, 0x58, 0x80, 0xd1, 0xae, 0xea, 0x32, 0x9a, 0xdf, 0x91, 0x21, 0x38, 0x38, + 0x51, 0xed, 0x21, 0xa2, 0x8e, 0x3b, 0x75, 0xe9, 0x65, 0xd0, 0xd2, 0xcd, 0x16, 0x62, + 0x54, + ]; + + let mut cursor = Cursor::new(&client_hello); + + let startup = read_startup(&mut cursor).await.unwrap(); + let FeStartupPacket::SslRequest { + direct: Some(prefix), + } = startup + else { + panic!("unexpected startup message: {startup:?}"); + }; + + // check that no data is lost. + assert_eq!(prefix, [0x16, 0x03, 0x01, 0x00, 0xf8, 0x01, 0x00, 0x00]); + assert_eq!(cursor.position(), 8); + } + + #[tokio::test] + async fn read_message_success() { + let query = b"Q\0\0\0\x0cSELECT 1Q\0\0\0\x0cSELECT 2"; + let mut cursor = Cursor::new(&query); + + let mut buf = vec![]; + let (tag, message) = read_message(&mut cursor, &mut buf, 100).await.unwrap(); + assert_eq!(tag, b'Q'); + assert_eq!(message, b"SELECT 1"); + + let (tag, message) = read_message(&mut cursor, &mut buf, 100).await.unwrap(); + assert_eq!(tag, b'Q'); + assert_eq!(message, b"SELECT 2"); + } +} diff --git a/proxy/src/proxy/connect_compute.rs b/proxy/src/proxy/connect_compute.rs index e013fbbe2e..57785c9ec5 100644 --- a/proxy/src/proxy/connect_compute.rs +++ b/proxy/src/proxy/connect_compute.rs @@ -1,5 +1,4 @@ use async_trait::async_trait; -use pq_proto::StartupMessageParams; use tokio::time; use tracing::{debug, info, warn}; @@ -15,6 +14,7 @@ use crate::error::ReportableError; use crate::metrics::{ ConnectOutcome, ConnectionFailureKind, Metrics, RetriesMetricGroup, RetryType, }; +use crate::pqproto::StartupMessageParams; use crate::proxy::retry::{CouldRetry, retry_after, should_retry}; use crate::proxy::wake_compute::wake_compute; use crate::types::Host; diff --git a/proxy/src/proxy/handshake.rs b/proxy/src/proxy/handshake.rs index 54c02f2c15..13ee8c7dd2 100644 --- a/proxy/src/proxy/handshake.rs +++ b/proxy/src/proxy/handshake.rs @@ -1,8 +1,3 @@ -use bytes::Buf; -use pq_proto::framed::Framed; -use pq_proto::{ - BeMessage as Be, CancelKeyData, FeStartupPacket, ProtocolVersion, StartupMessageParams, -}; use thiserror::Error; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::{debug, info, warn}; @@ -12,7 +7,10 @@ use crate::config::TlsConfig; use crate::context::RequestContext; use crate::error::ReportableError; use crate::metrics::Metrics; -use crate::proxy::ERR_INSECURE_CONNECTION; +use crate::pqproto::{ + BeMessage, CancelKeyData, FeStartupPacket, ProtocolVersion, StartupMessageParams, +}; +use crate::proxy::TlsRequired; use crate::stream::{PqStream, Stream, StreamUpgradeError}; use crate::tls::PG_ALPN_PROTOCOL; @@ -71,33 +69,25 @@ pub(crate) async fn handshake( const PG_PROTOCOL_EARLIEST: ProtocolVersion = ProtocolVersion::new(3, 0); const PG_PROTOCOL_LATEST: ProtocolVersion = ProtocolVersion::new(3, 0); - let mut stream = PqStream::new(Stream::from_raw(stream)); + let (mut stream, mut msg) = PqStream::parse_startup(Stream::from_raw(stream)).await?; loop { - let msg = stream.read_startup_packet().await?; match msg { FeStartupPacket::SslRequest { direct } => match stream.get_ref() { Stream::Raw { .. } if !tried_ssl => { tried_ssl = true; - // We can't perform TLS handshake without a config - let have_tls = tls.is_some(); - if !direct { - stream - .write_message(&Be::EncryptionResponse(have_tls)) - .await?; - } else if !have_tls { - return Err(HandshakeError::ProtocolViolation); - } - if let Some(tls) = tls.take() { // Upgrade raw stream into a secure TLS-backed stream. // NOTE: We've consumed `tls`; this fact will be used later. - let Framed { - stream: raw, - read_buf, - write_buf, - } = stream.framed; + let mut read_buf; + let raw = if let Some(direct) = &direct { + read_buf = &direct[..]; + stream.accept_direct_tls() + } else { + read_buf = &[]; + stream.accept_tls().await? + }; let Stream::Raw { raw } = raw else { return Err(HandshakeError::StreamUpgradeError( @@ -105,12 +95,11 @@ pub(crate) async fn handshake( )); }; - let mut read_buf = read_buf.reader(); let mut res = Ok(()); let accept = tokio_rustls::TlsAcceptor::from(tls.pg_config.clone()) .accept_with(raw, |session| { // push the early data to the tls session - while !read_buf.get_ref().is_empty() { + while !read_buf.is_empty() { match session.read_tls(&mut read_buf) { Ok(_) => {} Err(e) => { @@ -123,7 +112,6 @@ pub(crate) async fn handshake( res?; - let read_buf = read_buf.into_inner(); if !read_buf.is_empty() { return Err(HandshakeError::EarlyData); } @@ -157,16 +145,17 @@ pub(crate) async fn handshake( let (_, tls_server_end_point) = tls.cert_resolver.resolve(conn_info.server_name()); - stream = PqStream { - framed: Framed { - stream: Stream::Tls { - tls: Box::new(tls_stream), - tls_server_end_point, - }, - read_buf, - write_buf, - }, + let tls = Stream::Tls { + tls: Box::new(tls_stream), + tls_server_end_point, }; + (stream, msg) = PqStream::parse_startup(tls).await?; + } else { + if direct.is_some() { + // client sent us a ClientHello already, we can't do anything with it. + return Err(HandshakeError::ProtocolViolation); + } + msg = stream.reject_encryption().await?; } } _ => return Err(HandshakeError::ProtocolViolation), @@ -176,7 +165,7 @@ pub(crate) async fn handshake( tried_gss = true; // Currently, we don't support GSSAPI - stream.write_message(&Be::EncryptionResponse(false)).await?; + msg = stream.reject_encryption().await?; } _ => return Err(HandshakeError::ProtocolViolation), }, @@ -186,13 +175,7 @@ pub(crate) async fn handshake( // Check that the config has been consumed during upgrade // OR we didn't provide it at all (for dev purposes). if tls.is_some() { - return stream - .throw_error_str( - ERR_INSECURE_CONNECTION, - crate::error::ErrorKind::User, - None, - ) - .await?; + Err(stream.throw_error(TlsRequired, None).await)?; } // This log highlights the start of the connection. @@ -214,20 +197,21 @@ pub(crate) async fn handshake( // no protocol extensions are supported. // let mut unsupported = vec![]; - for (k, _) in params.iter() { + let mut supported = StartupMessageParams::default(); + + for (k, v) in params.iter() { if k.starts_with("_pq_.") { unsupported.push(k); + } else { + supported.insert(k, v); } } - // TODO: remove unsupported options so we don't send them to compute. - - stream - .write_message(&Be::NegotiateProtocolVersion { - version: PG_PROTOCOL_LATEST, - options: &unsupported, - }) - .await?; + stream.write_message(BeMessage::NegotiateProtocolVersion { + version: PG_PROTOCOL_LATEST, + options: &unsupported, + }); + stream.flush().await?; info!( ?version, @@ -235,7 +219,7 @@ pub(crate) async fn handshake( session_type = "normal", "successful handshake; unsupported minor version requested" ); - break Ok(HandshakeData::Startup(stream, params)); + break Ok(HandshakeData::Startup(stream, supported)); } FeStartupPacket::StartupMessage { version, params } => { warn!( diff --git a/proxy/src/proxy/mod.rs b/proxy/src/proxy/mod.rs index 0a86022e78..26ac6a89e7 100644 --- a/proxy/src/proxy/mod.rs +++ b/proxy/src/proxy/mod.rs @@ -10,15 +10,14 @@ pub(crate) mod wake_compute; use std::sync::Arc; pub use copy_bidirectional::{ErrorSource, copy_bidirectional_client_compute}; -use futures::{FutureExt, TryFutureExt}; +use futures::FutureExt; use itertools::Itertools; use once_cell::sync::OnceCell; -use pq_proto::{BeMessage as Be, CancelKeyData, StartupMessageParams}; use regex::Regex; use serde::{Deserialize, Serialize}; use smol_str::{SmolStr, ToSmolStr, format_smolstr}; use thiserror::Error; -use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; +use tokio::io::{AsyncRead, AsyncWrite}; use tokio_util::sync::CancellationToken; use tracing::{Instrument, debug, error, info, warn}; @@ -27,8 +26,9 @@ use self::passthrough::ProxyPassthrough; use crate::cancellation::{self, CancellationHandler}; use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig}; use crate::context::RequestContext; -use crate::error::ReportableError; +use crate::error::{ReportableError, UserFacingError}; use crate::metrics::{Metrics, NumClientConnectionsGuard}; +use crate::pqproto::{BeMessage, CancelKeyData, StartupMessageParams}; use crate::protocol2::{ConnectHeader, ConnectionInfo, ConnectionInfoExtra, read_proxy_protocol}; use crate::proxy::handshake::{HandshakeData, handshake}; use crate::rate_limiter::EndpointRateLimiter; @@ -38,6 +38,18 @@ use crate::{auth, compute}; const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)"; +#[derive(Error, Debug)] +#[error("{ERR_INSECURE_CONNECTION}")] +pub struct TlsRequired; + +impl ReportableError for TlsRequired { + fn get_error_kind(&self) -> crate::error::ErrorKind { + crate::error::ErrorKind::User + } +} + +impl UserFacingError for TlsRequired {} + pub async fn run_until_cancelled( f: F, cancellation_token: &CancellationToken, @@ -329,7 +341,7 @@ pub(crate) async fn handle_client( let user_info = match result { Ok(user_info) => user_info, - Err(e) => stream.throw_error(e, Some(ctx)).await?, + Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?, }; let user = user_info.get_user().to_owned(); @@ -349,10 +361,10 @@ pub(crate) async fn handle_client( let app = params.get("application_name"); let params_span = tracing::info_span!("", ?user, ?db, ?app); - return stream + return Err(stream .throw_error(e, Some(ctx)) .instrument(params_span) - .await?; + .await)?; } }; @@ -365,7 +377,7 @@ pub(crate) async fn handle_client( .get(NeonOptions::PARAMS_COMPAT) .is_some(); - let mut node = connect_to_compute( + let res = connect_to_compute( ctx, &TcpMechanism { user_info: compute_user_info.clone(), @@ -377,22 +389,19 @@ pub(crate) async fn handle_client( config.wake_compute_retry_config, &config.connect_to_compute, ) - .or_else(|e| stream.throw_error(e, Some(ctx))) - .await?; + .await; + + let node = match res { + Ok(node) => node, + Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?, + }; let cancellation_handler_clone = Arc::clone(&cancellation_handler); let session = cancellation_handler_clone.get_key(); session.write_cancel_key(node.cancel_closure.clone())?; - - prepare_client_connection(&node, *session.key(), &mut stream).await?; - - // Before proxy passing, forward to compute whatever data is left in the - // PqStream input buffer. Normally there is none, but our serverless npm - // driver in pipeline mode sends startup, password and first query - // immediately after opening the connection. - let (stream, read_buf) = stream.into_inner(); - node.stream.write_all(&read_buf).await?; + prepare_client_connection(&node, *session.key(), &mut stream); + let stream = stream.flush_and_into_inner().await?; let private_link_id = match ctx.extra() { Some(ConnectionInfoExtra::Aws { vpce_id }) => Some(vpce_id.clone()), @@ -413,31 +422,28 @@ pub(crate) async fn handle_client( } /// Finish client connection initialization: confirm auth success, send params, etc. -#[tracing::instrument(skip_all)] -pub(crate) async fn prepare_client_connection( +pub(crate) fn prepare_client_connection( node: &compute::PostgresConnection, cancel_key_data: CancelKeyData, stream: &mut PqStream, -) -> Result<(), std::io::Error> { +) { // Forward all deferred notices to the client. for notice in &node.delayed_notice { - stream.write_message_noflush(&Be::Raw(b'N', notice.as_bytes()))?; + stream.write_raw(notice.as_bytes().len(), b'N', |buf| { + buf.extend_from_slice(notice.as_bytes()); + }); } // Forward all postgres connection params to the client. for (name, value) in &node.params { - stream.write_message_noflush(&Be::ParameterStatus { + stream.write_message(BeMessage::ParameterStatus { name: name.as_bytes(), value: value.as_bytes(), - })?; + }); } - stream - .write_message_noflush(&Be::BackendKeyData(cancel_key_data))? - .write_message(&Be::ReadyForQuery) - .await?; - - Ok(()) + stream.write_message(BeMessage::BackendKeyData(cancel_key_data)); + stream.write_message(BeMessage::ReadyForQuery); } #[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)] diff --git a/proxy/src/proxy/retry.rs b/proxy/src/proxy/retry.rs index 0879564ced..01e603ec14 100644 --- a/proxy/src/proxy/retry.rs +++ b/proxy/src/proxy/retry.rs @@ -125,9 +125,10 @@ pub(crate) fn retry_after(num_retries: u32, config: RetryConfig) -> time::Durati #[cfg(test)] mod tests { - use super::ShouldRetryWakeCompute; use postgres_client::error::{DbError, SqlState}; + use super::ShouldRetryWakeCompute; + #[test] fn should_retry_wake_compute_for_db_error() { // These SQLStates should NOT trigger a wake_compute retry. diff --git a/proxy/src/proxy/tests/mitm.rs b/proxy/src/proxy/tests/mitm.rs index 59c9ac27b8..c92ee49b8d 100644 --- a/proxy/src/proxy/tests/mitm.rs +++ b/proxy/src/proxy/tests/mitm.rs @@ -10,7 +10,7 @@ use bytes::{Bytes, BytesMut}; use futures::{SinkExt, StreamExt}; use postgres_client::tls::TlsConnect; use postgres_protocol::message::frontend; -use tokio::io::{AsyncReadExt, DuplexStream}; +use tokio::io::{AsyncReadExt, AsyncWriteExt, DuplexStream}; use tokio_util::codec::{Decoder, Encoder}; use super::*; @@ -49,15 +49,14 @@ async fn proxy_mitm( }; let mut end_server = tokio_util::codec::Framed::new(end_server, PgFrame); - let (end_client, buf) = end_client.framed.into_inner(); - assert!(buf.is_empty()); + let end_client = end_client.flush_and_into_inner().await.unwrap(); let mut end_client = tokio_util::codec::Framed::new(end_client, PgFrame); // give the end_server the startup parameters let mut buf = BytesMut::new(); frontend::startup_message( &postgres_protocol::message::frontend::StartupMessageParams { - params: startup.params.into(), + params: startup.params.as_bytes().into(), }, &mut buf, ) diff --git a/proxy/src/proxy/tests/mod.rs b/proxy/src/proxy/tests/mod.rs index be6426a63c..3cc053e0ad 100644 --- a/proxy/src/proxy/tests/mod.rs +++ b/proxy/src/proxy/tests/mod.rs @@ -128,7 +128,7 @@ trait TestAuth: Sized { self, stream: &mut PqStream>, ) -> anyhow::Result<()> { - stream.write_message_noflush(&Be::AuthenticationOk)?; + stream.write_message(BeMessage::AuthenticationOk); Ok(()) } } @@ -157,9 +157,7 @@ impl TestAuth for Scram { self, stream: &mut PqStream>, ) -> anyhow::Result<()> { - let outcome = auth::AuthFlow::new(stream) - .begin(auth::Scram(&self.0, &RequestContext::test())) - .await? + let outcome = auth::AuthFlow::new(stream, auth::Scram(&self.0, &RequestContext::test())) .authenticate() .await?; @@ -185,10 +183,12 @@ async fn dummy_proxy( auth.authenticate(&mut stream).await?; - stream - .write_message_noflush(&Be::CLIENT_ENCODING)? - .write_message(&Be::ReadyForQuery) - .await?; + stream.write_message(BeMessage::ParameterStatus { + name: b"client_encoding", + value: b"UTF8", + }); + stream.write_message(BeMessage::ReadyForQuery); + stream.flush().await?; Ok(()) } diff --git a/proxy/src/redis/cancellation_publisher.rs b/proxy/src/redis/cancellation_publisher.rs index 186fece4b2..6f56aeea06 100644 --- a/proxy/src/redis/cancellation_publisher.rs +++ b/proxy/src/redis/cancellation_publisher.rs @@ -1,10 +1,11 @@ use core::net::IpAddr; use std::sync::Arc; -use pq_proto::CancelKeyData; use tokio::sync::Mutex; use uuid::Uuid; +use crate::pqproto::CancelKeyData; + pub trait CancellationPublisherMut: Send + Sync + 'static { #[allow(async_fn_in_trait)] async fn try_publish( diff --git a/proxy/src/redis/keys.rs b/proxy/src/redis/keys.rs index 7527bca6d0..3113bad949 100644 --- a/proxy/src/redis/keys.rs +++ b/proxy/src/redis/keys.rs @@ -1,16 +1,15 @@ use std::io::ErrorKind; use anyhow::Ok; -use pq_proto::{CancelKeyData, id_to_cancel_key}; -use serde::{Deserialize, Serialize}; + +use crate::pqproto::{CancelKeyData, id_to_cancel_key}; pub mod keyspace { pub const CANCEL_PREFIX: &str = "cancel"; } -#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] +#[derive(Clone, Debug, Eq, PartialEq)] pub(crate) enum KeyPrefix { - #[serde(untagged)] Cancel(CancelKeyData), } @@ -18,9 +17,7 @@ impl KeyPrefix { pub(crate) fn build_redis_key(&self) -> String { match self { KeyPrefix::Cancel(key) => { - let hi = (key.backend_pid as u64) << 32; - let lo = (key.cancel_key as u64) & 0xffff_ffff; - let id = hi | lo; + let id = key.0.get(); let keyspace = keyspace::CANCEL_PREFIX; format!("{keyspace}:{id:x}") } @@ -63,10 +60,7 @@ mod tests { #[test] fn test_build_redis_key() { - let cancel_key: KeyPrefix = KeyPrefix::Cancel(CancelKeyData { - backend_pid: 12345, - cancel_key: 54321, - }); + let cancel_key: KeyPrefix = KeyPrefix::Cancel(id_to_cancel_key(12345 << 32 | 54321)); let redis_key = cancel_key.build_redis_key(); assert_eq!(redis_key, "cancel:30390000d431"); @@ -77,10 +71,7 @@ mod tests { let redis_key = "cancel:30390000d431"; let key: KeyPrefix = parse_redis_key(redis_key).expect("Failed to parse key"); - let ref_key = CancelKeyData { - backend_pid: 12345, - cancel_key: 54321, - }; + let ref_key = id_to_cancel_key(12345 << 32 | 54321); assert_eq!(key.as_str(), KeyPrefix::Cancel(ref_key).as_str()); let KeyPrefix::Cancel(cancel_key) = key; diff --git a/proxy/src/redis/notifications.rs b/proxy/src/redis/notifications.rs index 5f9f2509e2..769d519d94 100644 --- a/proxy/src/redis/notifications.rs +++ b/proxy/src/redis/notifications.rs @@ -2,11 +2,9 @@ use std::convert::Infallible; use std::sync::Arc; use futures::StreamExt; -use pq_proto::CancelKeyData; use redis::aio::PubSub; use serde::{Deserialize, Serialize}; use tokio_util::sync::CancellationToken; -use uuid::Uuid; use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider; use crate::cache::project_info::ProjectInfoCache; @@ -100,14 +98,6 @@ pub(crate) struct PasswordUpdate { role_name: RoleNameInt, } -#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] -pub(crate) struct CancelSession { - pub(crate) region_id: Option, - pub(crate) cancel_key_data: CancelKeyData, - pub(crate) session_id: Uuid, - pub(crate) peer_addr: Option, -} - fn deserialize_json_string<'de, D, T>(deserializer: D) -> Result where T: for<'de2> serde::Deserialize<'de2>, diff --git a/proxy/src/sasl/messages.rs b/proxy/src/sasl/messages.rs index 7f2f3a761c..8d26a3f453 100644 --- a/proxy/src/sasl/messages.rs +++ b/proxy/src/sasl/messages.rs @@ -1,7 +1,5 @@ //! Definitions for SASL messages. -use pq_proto::{BeAuthenticationSaslMessage, BeMessage}; - use crate::parse::split_cstr; /// SASL-specific payload of [`PasswordMessage`](pq_proto::FeMessage::PasswordMessage). @@ -30,26 +28,6 @@ impl<'a> FirstMessage<'a> { } } -/// A single SASL message. -/// This struct is deliberately decoupled from lower-level -/// [`BeAuthenticationSaslMessage`]. -#[derive(Debug)] -pub(super) enum ServerMessage { - /// We expect to see more steps. - Continue(T), - /// This is the final step. - Final(T), -} - -impl<'a> ServerMessage<&'a str> { - pub(super) fn to_reply(&self) -> BeMessage<'a> { - BeMessage::AuthenticationSasl(match self { - ServerMessage::Continue(s) => BeAuthenticationSaslMessage::Continue(s.as_bytes()), - ServerMessage::Final(s) => BeAuthenticationSaslMessage::Final(s.as_bytes()), - }) - } -} - #[cfg(test)] mod tests { use super::*; diff --git a/proxy/src/sasl/mod.rs b/proxy/src/sasl/mod.rs index f0181b404f..007b62dfd2 100644 --- a/proxy/src/sasl/mod.rs +++ b/proxy/src/sasl/mod.rs @@ -14,7 +14,7 @@ use std::io; pub(crate) use channel_binding::ChannelBinding; pub(crate) use messages::FirstMessage; -pub(crate) use stream::{Outcome, SaslStream}; +pub(crate) use stream::{Outcome, authenticate}; use thiserror::Error; use crate::error::{ReportableError, UserFacingError}; @@ -22,6 +22,9 @@ use crate::error::{ReportableError, UserFacingError}; /// Fine-grained auth errors help in writing tests. #[derive(Error, Debug)] pub(crate) enum Error { + #[error("Unsupported authentication method: {0}")] + BadAuthMethod(Box), + #[error("Channel binding failed: {0}")] ChannelBindingFailed(&'static str), @@ -54,6 +57,7 @@ impl UserFacingError for Error { impl ReportableError for Error { fn get_error_kind(&self) -> crate::error::ErrorKind { match self { + Error::BadAuthMethod(_) => crate::error::ErrorKind::User, Error::ChannelBindingFailed(_) => crate::error::ErrorKind::User, Error::ChannelBindingBadMethod(_) => crate::error::ErrorKind::User, Error::BadClientMessage(_) => crate::error::ErrorKind::User, diff --git a/proxy/src/sasl/stream.rs b/proxy/src/sasl/stream.rs index 46e6a439e5..cb15132673 100644 --- a/proxy/src/sasl/stream.rs +++ b/proxy/src/sasl/stream.rs @@ -3,61 +3,12 @@ use std::io; use tokio::io::{AsyncRead, AsyncWrite}; -use tracing::info; -use super::Mechanism; -use super::messages::ServerMessage; +use super::{Mechanism, Step}; +use crate::context::RequestContext; +use crate::pqproto::{BeAuthenticationSaslMessage, BeMessage}; use crate::stream::PqStream; -/// Abstracts away all peculiarities of the libpq's protocol. -pub(crate) struct SaslStream<'a, S> { - /// The underlying stream. - stream: &'a mut PqStream, - /// Current password message we received from client. - current: bytes::Bytes, - /// First SASL message produced by client. - first: Option<&'a str>, -} - -impl<'a, S> SaslStream<'a, S> { - pub(crate) fn new(stream: &'a mut PqStream, first: &'a str) -> Self { - Self { - stream, - current: bytes::Bytes::new(), - first: Some(first), - } - } -} - -impl SaslStream<'_, S> { - // Receive a new SASL message from the client. - async fn recv(&mut self) -> io::Result<&str> { - if let Some(first) = self.first.take() { - return Ok(first); - } - - self.current = self.stream.read_password_message().await?; - let s = std::str::from_utf8(&self.current) - .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "bad encoding"))?; - - Ok(s) - } -} - -impl SaslStream<'_, S> { - // Send a SASL message to the client. - async fn send(&mut self, msg: &ServerMessage<&str>) -> io::Result<()> { - self.stream.write_message(&msg.to_reply()).await?; - Ok(()) - } - - // Queue a SASL message for the client. - fn send_noflush(&mut self, msg: &ServerMessage<&str>) -> io::Result<()> { - self.stream.write_message_noflush(&msg.to_reply())?; - Ok(()) - } -} - /// SASL authentication outcome. /// It's much easier to match on those two variants /// than to peek into a noisy protocol error type. @@ -69,33 +20,62 @@ pub(crate) enum Outcome { Failure(&'static str), } -impl SaslStream<'_, S> { - /// Perform SASL message exchange according to the underlying algorithm - /// until user is either authenticated or denied access. - pub(crate) async fn authenticate( - mut self, - mut mechanism: M, - ) -> super::Result> { - loop { - let input = self.recv().await?; - let step = mechanism.exchange(input).map_err(|error| { - info!(?error, "error during SASL exchange"); - error - })?; +pub async fn authenticate( + ctx: &RequestContext, + stream: &mut PqStream, + mechanism: F, +) -> super::Result> +where + S: AsyncRead + AsyncWrite + Unpin, + F: FnOnce(&str) -> super::Result, + M: Mechanism, +{ + let sasl = { + // pause the timer while we communicate with the client + let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client); - use super::Step; - return Ok(match step { - Step::Continue(moved_mechanism, reply) => { - self.send(&ServerMessage::Continue(&reply)).await?; - mechanism = moved_mechanism; - continue; - } - Step::Success(result, reply) => { - self.send_noflush(&ServerMessage::Final(&reply))?; - Outcome::Success(result) - } - Step::Failure(reason) => Outcome::Failure(reason), - }); + // Initial client message contains the chosen auth method's name. + let msg = stream.read_password_message().await?; + super::FirstMessage::parse(msg).ok_or(super::Error::BadClientMessage("bad sasl message"))? + }; + + let mut mechanism = mechanism(sasl.method)?; + let mut input = sasl.message; + loop { + let step = mechanism + .exchange(input) + .inspect_err(|error| tracing::info!(?error, "error during SASL exchange"))?; + + match step { + Step::Continue(moved_mechanism, reply) => { + mechanism = moved_mechanism; + + // pause the timer while we communicate with the client + let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client); + + // write reply + let sasl_msg = BeAuthenticationSaslMessage::Continue(reply.as_bytes()); + stream.write_message(BeMessage::AuthenticationSasl(sasl_msg)); + + // get next input + stream.flush().await?; + let msg = stream.read_password_message().await?; + input = std::str::from_utf8(msg) + .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "bad encoding"))?; + } + Step::Success(result, reply) => { + // pause the timer while we communicate with the client + let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client); + + // write reply + let sasl_msg = BeAuthenticationSaslMessage::Final(reply.as_bytes()); + stream.write_message(BeMessage::AuthenticationSasl(sasl_msg)); + stream.write_message(BeMessage::AuthenticationOk); + // exit with success + break Ok(Outcome::Success(result)); + } + // exit with failure + Step::Failure(reason) => break Ok(Outcome::Failure(reason)), } } } diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index 1c5bb64480..eb80ac9ad0 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -17,7 +17,6 @@ use postgres_client::error::{DbError, ErrorPosition, SqlState}; use postgres_client::{ GenericClient, IsolationLevel, NoTls, ReadyForQueryStatus, RowStream, Transaction, }; -use pq_proto::StartupMessageParamsBuilder; use serde::Serialize; use serde_json::Value; use serde_json::value::RawValue; @@ -41,6 +40,7 @@ use crate::context::RequestContext; use crate::error::{ErrorKind, ReportableError, UserFacingError}; use crate::http::{ReadBodyError, read_body_with_limit}; use crate::metrics::{HttpDirection, Metrics, SniGroup, SniKind}; +use crate::pqproto::StartupMessageParams; use crate::proxy::{NeonOptions, run_until_cancelled}; use crate::serverless::backend::HttpConnError; use crate::types::{DbName, RoleName}; @@ -219,7 +219,7 @@ fn get_conn_info( let mut options = Option::None; - let mut params = StartupMessageParamsBuilder::default(); + let mut params = StartupMessageParams::default(); params.insert("user", &username); params.insert("database", &dbname); for (key, value) in pairs { diff --git a/proxy/src/stream.rs b/proxy/src/stream.rs index 360550b0ac..7126430a85 100644 --- a/proxy/src/stream.rs +++ b/proxy/src/stream.rs @@ -2,19 +2,17 @@ use std::pin::Pin; use std::sync::Arc; use std::{io, task}; -use bytes::BytesMut; -use pq_proto::framed::{ConnectionError, Framed}; -use pq_proto::{BeMessage, FeMessage, FeStartupPacket, ProtocolError}; use rustls::ServerConfig; -use serde::{Deserialize, Serialize}; use thiserror::Error; -use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf}; use tokio_rustls::server::TlsStream; -use tracing::debug; -use crate::control_plane::messages::ColdStartInfo; use crate::error::{ErrorKind, ReportableError, UserFacingError}; use crate::metrics::Metrics; +use crate::pqproto::{ + BeMessage, FE_PASSWORD_MESSAGE, FeStartupPacket, SQLSTATE_INTERNAL_ERROR, WriteBuf, + read_message, read_startup, +}; use crate::tls::TlsServerEndPoint; /// Stream wrapper which implements libpq's protocol. @@ -23,58 +21,77 @@ use crate::tls::TlsServerEndPoint; /// or [`AsyncWrite`] to prevent subtle errors (e.g. trying /// to pass random malformed bytes through the connection). pub struct PqStream { - pub(crate) framed: Framed, + stream: S, + read: Vec, + write: WriteBuf, } impl PqStream { - /// Construct a new libpq protocol wrapper. - pub fn new(stream: S) -> Self { + pub fn get_ref(&self) -> &S { + &self.stream + } + + /// Construct a new libpq protocol wrapper over a stream without the first startup message. + #[cfg(test)] + pub fn new_skip_handshake(stream: S) -> Self { Self { - framed: Framed::new(stream), + stream, + read: Vec::new(), + write: WriteBuf::new(), } } - - /// Extract the underlying stream and read buffer. - pub fn into_inner(self) -> (S, BytesMut) { - self.framed.into_inner() - } - - /// Get a shared reference to the underlying stream. - pub(crate) fn get_ref(&self) -> &S { - self.framed.get_ref() - } } -fn err_connection() -> io::Error { - io::Error::new(io::ErrorKind::ConnectionAborted, "connection is lost") +impl PqStream { + /// Construct a new libpq protocol wrapper and read the first startup message. + /// + /// This is not cancel safe. + pub async fn parse_startup(mut stream: S) -> io::Result<(Self, FeStartupPacket)> { + let startup = read_startup(&mut stream).await?; + Ok(( + Self { + stream, + read: Vec::new(), + write: WriteBuf::new(), + }, + startup, + )) + } + + /// Tell the client that encryption is not supported. + /// + /// This is not cancel safe + pub async fn reject_encryption(&mut self) -> io::Result { + // N for No. + self.write.encryption(b'N'); + self.flush().await?; + read_startup(&mut self.stream).await + } } impl PqStream { - /// Receive [`FeStartupPacket`], which is a first packet sent by a client. - pub async fn read_startup_packet(&mut self) -> io::Result { - self.framed - .read_startup_message() - .await - .map_err(ConnectionError::into_io_error)? - .ok_or_else(err_connection) - } - - async fn read_message(&mut self) -> io::Result { - self.framed - .read_message() - .await - .map_err(ConnectionError::into_io_error)? - .ok_or_else(err_connection) - } - - pub(crate) async fn read_password_message(&mut self) -> io::Result { - match self.read_message().await? { - FeMessage::PasswordMessage(msg) => Ok(msg), - bad => Err(io::Error::new( - io::ErrorKind::InvalidData, - format!("unexpected message type: {bad:?}"), - )), + /// Read a raw postgres packet, which will respect the max length requested. + /// This is not cancel safe. + async fn read_raw_expect(&mut self, tag: u8, max: usize) -> io::Result<&mut [u8]> { + let (actual_tag, msg) = read_message(&mut self.stream, &mut self.read, max).await?; + if actual_tag != tag { + return Err(io::Error::other(format!( + "incorrect message tag, expected {:?}, got {:?}", + tag as char, actual_tag as char, + ))); } + Ok(msg) + } + + /// Read a postgres password message, which will respect the max length requested. + /// This is not cancel safe. + pub async fn read_password_message(&mut self) -> io::Result<&mut [u8]> { + // passwords are usually pretty short + // and SASL SCRAM messages are no longer than 256 bytes in my testing + // (a few hashes and random bytes, encoded into base64). + const MAX_PASSWORD_LENGTH: usize = 512; + self.read_raw_expect(FE_PASSWORD_MESSAGE, MAX_PASSWORD_LENGTH) + .await } } @@ -84,6 +101,16 @@ pub struct ReportedError { error_kind: ErrorKind, } +impl ReportedError { + pub fn new(e: (impl UserFacingError + Into)) -> Self { + let error_kind = e.get_error_kind(); + Self { + source: e.into(), + error_kind, + } + } +} + impl std::fmt::Display for ReportedError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { self.source.fmt(f) @@ -102,109 +129,65 @@ impl ReportableError for ReportedError { } } -#[derive(Serialize, Deserialize, Debug)] -enum ErrorTag { - #[serde(rename = "proxy")] - Proxy, - #[serde(rename = "compute")] - Compute, - #[serde(rename = "client")] - Client, - #[serde(rename = "controlplane")] - ControlPlane, - #[serde(rename = "other")] - Other, -} - -impl From for ErrorTag { - fn from(error_kind: ErrorKind) -> Self { - match error_kind { - ErrorKind::User => Self::Client, - ErrorKind::ClientDisconnect => Self::Client, - ErrorKind::RateLimit => Self::Proxy, - ErrorKind::ServiceRateLimit => Self::Proxy, // considering rate limit as proxy error for SLI - ErrorKind::Quota => Self::Proxy, - ErrorKind::Service => Self::Proxy, - ErrorKind::ControlPlane => Self::ControlPlane, - ErrorKind::Postgres => Self::Other, - ErrorKind::Compute => Self::Compute, - } - } -} - -#[derive(Serialize, Deserialize, Debug)] -#[serde(rename_all = "snake_case")] -struct ProbeErrorData { - tag: ErrorTag, - msg: String, - cold_start_info: Option, -} - impl PqStream { - /// Write the message into an internal buffer, but don't flush the underlying stream. - pub(crate) fn write_message_noflush( - &mut self, - message: &BeMessage<'_>, - ) -> io::Result<&mut Self> { - self.framed - .write_message(message) - .map_err(ProtocolError::into_io_error)?; - Ok(self) + /// Tell the client that we are willing to accept SSL. + /// This is not cancel safe + pub async fn accept_tls(mut self) -> io::Result { + // S for SSL. + self.write.encryption(b'S'); + self.flush().await?; + Ok(self.stream) } - /// Write the message into an internal buffer and flush it. - pub async fn write_message(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> { - self.write_message_noflush(message)?; - self.flush().await?; - Ok(self) + /// Assert that we are using direct TLS. + pub fn accept_direct_tls(self) -> S { + self.stream + } + + /// Write a raw message to the internal buffer. + pub fn write_raw(&mut self, size_hint: usize, tag: u8, f: impl FnOnce(&mut Vec)) { + self.write.write_raw(size_hint, tag, f); + } + + /// Write the message into an internal buffer + pub fn write_message(&mut self, message: BeMessage<'_>) { + message.write_message(&mut self.write); } /// Flush the output buffer into the underlying stream. - pub(crate) async fn flush(&mut self) -> io::Result<&mut Self> { - self.framed.flush().await?; - Ok(self) + /// + /// This is cancel safe. + pub async fn flush(&mut self) -> io::Result<()> { + self.stream.write_all_buf(&mut self.write).await?; + self.write.reset(); + + self.stream.flush().await?; + + Ok(()) } - /// Writes message with the given error kind to the stream. - /// Used only for probe queries - async fn write_format_message( - &mut self, - msg: &str, - error_kind: ErrorKind, - ctx: Option<&crate::context::RequestContext>, - ) -> String { - let formatted_msg = match ctx { - Some(ctx) if ctx.get_testodrome_id().is_some() => { - serde_json::to_string(&ProbeErrorData { - tag: ErrorTag::from(error_kind), - msg: msg.to_string(), - cold_start_info: Some(ctx.cold_start_info()), - }) - .unwrap_or_default() - } - _ => msg.to_string(), - }; - - // already error case, ignore client IO error - self.write_message(&BeMessage::ErrorResponse(&formatted_msg, None)) - .await - .inspect_err(|e| debug!("write_message failed: {e}")) - .ok(); - - formatted_msg + /// Flush the output buffer into the underlying stream. + /// + /// This is cancel safe. + pub async fn flush_and_into_inner(mut self) -> io::Result { + self.flush().await?; + Ok(self.stream) } - /// Write the error message using [`Self::write_format_message`], then re-throw it. - /// Allowing string literals is safe under the assumption they might not contain any runtime info. - /// This method exists due to `&str` not implementing `Into`. + /// Write the error message to the client, then re-throw it. + /// + /// Trait [`UserFacingError`] acts as an allowlist for error types. /// If `ctx` is provided and has testodrome_id set, error messages will be prefixed according to error kind. - pub async fn throw_error_str( + pub(crate) async fn throw_error( &mut self, - msg: &'static str, - error_kind: ErrorKind, + error: E, ctx: Option<&crate::context::RequestContext>, - ) -> Result { - self.write_format_message(msg, error_kind, ctx).await; + ) -> ReportedError + where + E: UserFacingError + Into, + { + let error_kind = error.get_error_kind(); + let msg = error.to_string_client(); if error_kind != ErrorKind::RateLimit && error_kind != ErrorKind::User { tracing::info!( @@ -214,39 +197,39 @@ impl PqStream { ); } - Err(ReportedError { - source: anyhow::anyhow!(msg), - error_kind, - }) - } - - /// Write the error message using [`Self::write_format_message`], then re-throw it. - /// Trait [`UserFacingError`] acts as an allowlist for error types. - /// If `ctx` is provided and has testodrome_id set, error messages will be prefixed according to error kind. - pub(crate) async fn throw_error( - &mut self, - error: E, - ctx: Option<&crate::context::RequestContext>, - ) -> Result - where - E: UserFacingError + Into, - { - let error_kind = error.get_error_kind(); - let msg = error.to_string_client(); - self.write_format_message(&msg, error_kind, ctx).await; - if error_kind != ErrorKind::RateLimit && error_kind != ErrorKind::User { - tracing::info!( - kind=error_kind.to_metric_label(), - error=%error, - msg, - "forwarding error to user", - ); + let probe_msg; + let mut msg = &*msg; + if let Some(ctx) = ctx { + if ctx.get_testodrome_id().is_some() { + let tag = match error_kind { + ErrorKind::User => "client", + ErrorKind::ClientDisconnect => "client", + ErrorKind::RateLimit => "proxy", + ErrorKind::ServiceRateLimit => "proxy", + ErrorKind::Quota => "proxy", + ErrorKind::Service => "proxy", + ErrorKind::ControlPlane => "controlplane", + ErrorKind::Postgres => "other", + ErrorKind::Compute => "compute", + }; + probe_msg = typed_json::json!({ + "tag": tag, + "msg": msg, + "cold_start_info": ctx.cold_start_info(), + }) + .to_string(); + msg = &probe_msg; + } } - Err(ReportedError { - source: anyhow::anyhow!(error), - error_kind, - }) + // TODO: either preserve the error code from postgres, or assign error codes to proxy errors. + self.write.write_error(msg, SQLSTATE_INTERNAL_ERROR); + + self.flush() + .await + .unwrap_or_else(|e| tracing::debug!("write_message failed: {e}")); + + ReportedError::new(error) } } From 589bfdfd02c575b172a138ee5d174777da17e18a Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Mon, 2 Jun 2025 09:38:35 +0100 Subject: [PATCH 08/43] proxy: Changes to rate limits and GetEndpointAccessControl caches. (#12048) Precursor to https://github.com/neondatabase/cloud/issues/28333. We want per-endpoint configuration for rate limits, which will be distributed via the `GetEndpointAccessControl` API. This lays some of the ground work. 1. Allow the endpoint rate limiter to accept a custom leaky bucket config on check. 2. Remove the unused auth rate limiter, as I don't want to think about how it fits into this. 3. Refactor the caching of `GetEndpointAccessControl`, as it adds friction for adding new cached data to the API. That third one was rather large. I couldn't find any way to split it up. The core idea is that there's now only 2 cache APIs. `get_endpoint_access_controls` and `get_role_access_controls`. I'm pretty sure the behaviour is unchanged, except I did a drive by change to fix #8989 because it felt harmless. The change in question is that when a password validation fails, we eagerly expire the role cache if the role was cached for 5 minutes. This is to allow for edge cases where a user tries to connect with a reset password, but the cache never expires the entry due to some redis related quirk (lag, or misconfiguration, or cplane error) --- libs/metrics/src/hll.rs | 2 +- libs/utils/src/leaky_bucket.rs | 1 + proxy/src/auth/backend/mod.rs | 335 +++------ proxy/src/binary/local_proxy.rs | 16 +- proxy/src/binary/proxy.rs | 35 +- proxy/src/cache/project_info.rs | 678 +++++------------- proxy/src/cancellation.rs | 71 +- proxy/src/config.rs | 4 - proxy/src/context/mod.rs | 12 + .../control_plane/client/cplane_proxy_v1.rs | 319 +++----- proxy/src/control_plane/client/mock.rs | 74 +- proxy/src/control_plane/client/mod.rs | 75 +- proxy/src/control_plane/errors.rs | 8 + proxy/src/control_plane/mod.rs | 95 ++- proxy/src/proxy/mod.rs | 2 +- proxy/src/proxy/tests/mod.rs | 19 +- proxy/src/rate_limiter/leaky_bucket.rs | 10 +- proxy/src/rate_limiter/limiter.rs | 30 +- proxy/src/rate_limiter/mod.rs | 2 +- proxy/src/redis/notifications.rs | 41 +- proxy/src/serverless/backend.rs | 70 +- 21 files changed, 551 insertions(+), 1348 deletions(-) diff --git a/libs/metrics/src/hll.rs b/libs/metrics/src/hll.rs index 93f6a2b7cc..1a7d7a7e44 100644 --- a/libs/metrics/src/hll.rs +++ b/libs/metrics/src/hll.rs @@ -107,7 +107,7 @@ impl MetricType for HyperLogLogState { } impl HyperLogLogState { - pub fn measure(&self, item: &impl Hash) { + pub fn measure(&self, item: &(impl Hash + ?Sized)) { // changing the hasher will break compatibility with previous measurements. self.record(BuildHasherDefault::::default().hash_one(item)); } diff --git a/libs/utils/src/leaky_bucket.rs b/libs/utils/src/leaky_bucket.rs index 2398f92766..17e96bd0a9 100644 --- a/libs/utils/src/leaky_bucket.rs +++ b/libs/utils/src/leaky_bucket.rs @@ -28,6 +28,7 @@ use std::time::Duration; use tokio::sync::Notify; use tokio::time::Instant; +#[derive(Clone, Copy)] pub struct LeakyBucketConfig { /// This is the "time cost" of a single request unit. /// Should loosely represent how long it takes to handle a request unit in active resource time. diff --git a/proxy/src/auth/backend/mod.rs b/proxy/src/auth/backend/mod.rs index 8c892d90a0..735cb52f47 100644 --- a/proxy/src/auth/backend/mod.rs +++ b/proxy/src/auth/backend/mod.rs @@ -4,38 +4,31 @@ mod hacks; pub mod jwt; pub mod local; -use std::net::IpAddr; use std::sync::Arc; pub use console_redirect::ConsoleRedirectBackend; pub(crate) use console_redirect::ConsoleRedirectError; -use ipnet::{Ipv4Net, Ipv6Net}; use local::LocalBackend; use postgres_client::config::AuthKeys; use serde::{Deserialize, Serialize}; use tokio::io::{AsyncRead, AsyncWrite}; -use tracing::{debug, info, warn}; +use tracing::{debug, info}; -use crate::auth::credentials::check_peer_addr_is_in_list; -use crate::auth::{ - self, AuthError, ComputeUserInfoMaybeEndpoint, IpPattern, validate_password_and_exchange, -}; +use crate::auth::{self, AuthError, ComputeUserInfoMaybeEndpoint, validate_password_and_exchange}; use crate::cache::Cached; use crate::config::AuthenticationConfig; use crate::context::RequestContext; use crate::control_plane::client::ControlPlaneClient; use crate::control_plane::errors::GetAuthInfoError; use crate::control_plane::{ - self, AccessBlockerFlags, AuthSecret, CachedAccessBlockerFlags, CachedAllowedIps, - CachedAllowedVpcEndpointIds, CachedNodeInfo, CachedRoleSecret, ControlPlaneApi, + self, AccessBlockerFlags, AuthSecret, CachedNodeInfo, ControlPlaneApi, EndpointAccessControl, + RoleAccessControl, }; use crate::intern::EndpointIdInt; -use crate::metrics::Metrics; use crate::pqproto::BeMessage; -use crate::protocol2::ConnectionInfoExtra; use crate::proxy::NeonOptions; use crate::proxy::connect_compute::ComputeConnectBackend; -use crate::rate_limiter::{BucketRateLimiter, EndpointRateLimiter}; +use crate::rate_limiter::EndpointRateLimiter; use crate::stream::Stream; use crate::types::{EndpointCacheKey, EndpointId, RoleName}; use crate::{scram, stream}; @@ -201,78 +194,6 @@ impl TryFrom for ComputeUserInfo { } } -#[derive(PartialEq, PartialOrd, Hash, Eq, Ord, Debug, Copy, Clone)] -pub struct MaskedIp(IpAddr); - -impl MaskedIp { - fn new(value: IpAddr, prefix: u8) -> Self { - match value { - IpAddr::V4(v4) => Self(IpAddr::V4( - Ipv4Net::new(v4, prefix).map_or(v4, |x| x.trunc().addr()), - )), - IpAddr::V6(v6) => Self(IpAddr::V6( - Ipv6Net::new(v6, prefix).map_or(v6, |x| x.trunc().addr()), - )), - } - } -} - -// This can't be just per IP because that would limit some PaaS that share IP addresses -pub type AuthRateLimiter = BucketRateLimiter<(EndpointIdInt, MaskedIp)>; - -impl AuthenticationConfig { - pub(crate) fn check_rate_limit( - &self, - ctx: &RequestContext, - secret: AuthSecret, - endpoint: &EndpointId, - is_cleartext: bool, - ) -> auth::Result { - // we have validated the endpoint exists, so let's intern it. - let endpoint_int = EndpointIdInt::from(endpoint.normalize()); - - // only count the full hash count if password hack or websocket flow. - // in other words, if proxy needs to run the hashing - let password_weight = if is_cleartext { - match &secret { - #[cfg(any(test, feature = "testing"))] - AuthSecret::Md5(_) => 1, - AuthSecret::Scram(s) => s.iterations + 1, - } - } else { - // validating scram takes just 1 hmac_sha_256 operation. - 1 - }; - - let limit_not_exceeded = self.rate_limiter.check( - ( - endpoint_int, - MaskedIp::new(ctx.peer_addr(), self.rate_limit_ip_subnet), - ), - password_weight, - ); - - if !limit_not_exceeded { - warn!( - enabled = self.rate_limiter_enabled, - "rate limiting authentication" - ); - Metrics::get().proxy.requests_auth_rate_limits_total.inc(); - Metrics::get() - .proxy - .endpoints_auth_rate_limits - .get_metric() - .measure(endpoint); - - if self.rate_limiter_enabled { - return Err(auth::AuthError::too_many_connections()); - } - } - - Ok(secret) - } -} - /// True to its name, this function encapsulates our current auth trade-offs. /// Here, we choose the appropriate auth flow based on circumstances. /// @@ -285,7 +206,7 @@ async fn auth_quirks( allow_cleartext: bool, config: &'static AuthenticationConfig, endpoint_rate_limiter: Arc, -) -> auth::Result<(ComputeCredentials, Option>)> { +) -> auth::Result { // If there's no project so far, that entails that client doesn't // support SNI or other means of passing the endpoint (project) name. // We now expect to see a very specific payload in the place of password. @@ -301,55 +222,27 @@ async fn auth_quirks( debug!("fetching authentication info and allowlists"); - // check allowed list - let allowed_ips = if config.ip_allowlist_check_enabled { - let allowed_ips = api.get_allowed_ips(ctx, &info).await?; - if !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips) { - return Err(auth::AuthError::ip_address_not_allowed(ctx.peer_addr())); - } - allowed_ips - } else { - Cached::new_uncached(Arc::new(vec![])) - }; + let access_controls = api + .get_endpoint_access_control(ctx, &info.endpoint, &info.user) + .await?; - // check if a VPC endpoint ID is coming in and if yes, if it's allowed - let access_blocks = api.get_block_public_or_vpc_access(ctx, &info).await?; - if config.is_vpc_acccess_proxy { - if access_blocks.vpc_access_blocked { - return Err(AuthError::NetworkNotAllowed); - } + access_controls.check( + ctx, + config.ip_allowlist_check_enabled, + config.is_vpc_acccess_proxy, + )?; - let incoming_vpc_endpoint_id = match ctx.extra() { - None => return Err(AuthError::MissingEndpointName), - Some(ConnectionInfoExtra::Aws { vpce_id }) => vpce_id.to_string(), - Some(ConnectionInfoExtra::Azure { link_id }) => link_id.to_string(), - }; - let allowed_vpc_endpoint_ids = api.get_allowed_vpc_endpoint_ids(ctx, &info).await?; - // TODO: For now an empty VPC endpoint ID list means all are allowed. We should replace that. - if !allowed_vpc_endpoint_ids.is_empty() - && !allowed_vpc_endpoint_ids.contains(&incoming_vpc_endpoint_id) - { - return Err(AuthError::vpc_endpoint_id_not_allowed( - incoming_vpc_endpoint_id, - )); - } - } else if access_blocks.public_access_blocked { - return Err(AuthError::NetworkNotAllowed); - } - - if !endpoint_rate_limiter.check(info.endpoint.clone().into(), 1) { + let endpoint = EndpointIdInt::from(&info.endpoint); + let rate_limit_config = None; + if !endpoint_rate_limiter.check(endpoint, rate_limit_config, 1) { return Err(AuthError::too_many_connections()); } - let cached_secret = api.get_role_secret(ctx, &info).await?; - let (cached_entry, secret) = cached_secret.take_value(); + let role_access = api + .get_role_access_control(ctx, &info.endpoint, &info.user) + .await?; - let secret = if let Some(secret) = secret { - config.check_rate_limit( - ctx, - secret, - &info.endpoint, - unauthenticated_password.is_some() || allow_cleartext, - )? + let secret = if let Some(secret) = role_access.secret { + secret } else { // If we don't have an authentication secret, we mock one to // prevent malicious probing (possible due to missing protocol steps). @@ -369,14 +262,8 @@ async fn auth_quirks( ) .await { - Ok(keys) => Ok((keys, Some(allowed_ips.as_ref().clone()))), - Err(e) => { - if e.is_password_failed() { - // The password could have been changed, so we invalidate the cache. - cached_entry.invalidate(); - } - Err(e) - } + Ok(keys) => Ok(keys), + Err(e) => Err(e), } } @@ -439,7 +326,7 @@ impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint> { allow_cleartext: bool, config: &'static AuthenticationConfig, endpoint_rate_limiter: Arc, - ) -> auth::Result<(Backend<'a, ComputeCredentials>, Option>)> { + ) -> auth::Result> { let res = match self { Self::ControlPlane(api, user_info) => { debug!( @@ -448,17 +335,35 @@ impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint> { "performing authentication using the console" ); - let (credentials, ip_allowlist) = auth_quirks( + let auth_res = auth_quirks( ctx, &*api, - user_info, + user_info.clone(), client, allow_cleartext, config, endpoint_rate_limiter, ) - .await?; - Ok((Backend::ControlPlane(api, credentials), ip_allowlist)) + .await; + match auth_res { + Ok(credentials) => Ok(Backend::ControlPlane(api, credentials)), + Err(e) => { + // The password could have been changed, so we invalidate the cache. + // We should only invalidate the cache if the TTL might have expired. + if e.is_password_failed() { + #[allow(irrefutable_let_patterns)] + if let ControlPlaneClient::ProxyV1(api) = &*api { + if let Some(ep) = &user_info.endpoint_id { + api.caches + .project_info + .maybe_invalidate_role_secret(ep, &user_info.user); + } + } + } + + Err(e) + } + } } Self::Local(_) => { return Err(auth::AuthError::bad_auth_method("invalid for local proxy")); @@ -475,44 +380,30 @@ impl Backend<'_, ComputeUserInfo> { pub(crate) async fn get_role_secret( &self, ctx: &RequestContext, - ) -> Result { - match self { - Self::ControlPlane(api, user_info) => api.get_role_secret(ctx, user_info).await, - Self::Local(_) => Ok(Cached::new_uncached(None)), - } - } - - pub(crate) async fn get_allowed_ips( - &self, - ctx: &RequestContext, - ) -> Result { - match self { - Self::ControlPlane(api, user_info) => api.get_allowed_ips(ctx, user_info).await, - Self::Local(_) => Ok(Cached::new_uncached(Arc::new(vec![]))), - } - } - - pub(crate) async fn get_allowed_vpc_endpoint_ids( - &self, - ctx: &RequestContext, - ) -> Result { + ) -> Result { match self { Self::ControlPlane(api, user_info) => { - api.get_allowed_vpc_endpoint_ids(ctx, user_info).await + api.get_role_access_control(ctx, &user_info.endpoint, &user_info.user) + .await } - Self::Local(_) => Ok(Cached::new_uncached(Arc::new(vec![]))), + Self::Local(_) => Ok(RoleAccessControl { secret: None }), } } - pub(crate) async fn get_block_public_or_vpc_access( + pub(crate) async fn get_endpoint_access_control( &self, ctx: &RequestContext, - ) -> Result { + ) -> Result { match self { Self::ControlPlane(api, user_info) => { - api.get_block_public_or_vpc_access(ctx, user_info).await + api.get_endpoint_access_control(ctx, &user_info.endpoint, &user_info.user) + .await } - Self::Local(_) => Ok(Cached::new_uncached(AccessBlockerFlags::default())), + Self::Local(_) => Ok(EndpointAccessControl { + allowed_ips: Arc::new(vec![]), + allowed_vpce: Arc::new(vec![]), + flags: AccessBlockerFlags::default(), + }), } } } @@ -541,9 +432,7 @@ impl ComputeConnectBackend for Backend<'_, ComputeCredentials> { mod tests { #![allow(clippy::unimplemented, clippy::unwrap_used)] - use std::net::IpAddr; use std::sync::Arc; - use std::time::Duration; use bytes::BytesMut; use control_plane::AuthSecret; @@ -554,18 +443,16 @@ mod tests { use postgres_protocol::message::frontend; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt}; + use super::auth_quirks; use super::jwt::JwkCache; - use super::{AuthRateLimiter, auth_quirks}; - use crate::auth::backend::MaskedIp; use crate::auth::{ComputeUserInfoMaybeEndpoint, IpPattern}; use crate::config::AuthenticationConfig; use crate::context::RequestContext; use crate::control_plane::{ - self, AccessBlockerFlags, CachedAccessBlockerFlags, CachedAllowedIps, - CachedAllowedVpcEndpointIds, CachedNodeInfo, CachedRoleSecret, + self, AccessBlockerFlags, CachedNodeInfo, EndpointAccessControl, RoleAccessControl, }; use crate::proxy::NeonOptions; - use crate::rate_limiter::{EndpointRateLimiter, RateBucketInfo}; + use crate::rate_limiter::EndpointRateLimiter; use crate::scram::ServerSecret; use crate::scram::threadpool::ThreadPool; use crate::stream::{PqStream, Stream}; @@ -578,46 +465,34 @@ mod tests { } impl control_plane::ControlPlaneApi for Auth { - async fn get_role_secret( + async fn get_role_access_control( &self, _ctx: &RequestContext, - _user_info: &super::ComputeUserInfo, - ) -> Result { - Ok(CachedRoleSecret::new_uncached(Some(self.secret.clone()))) + _endpoint: &crate::types::EndpointId, + _role: &crate::types::RoleName, + ) -> Result { + Ok(RoleAccessControl { + secret: Some(self.secret.clone()), + }) } - async fn get_allowed_ips( + async fn get_endpoint_access_control( &self, _ctx: &RequestContext, - _user_info: &super::ComputeUserInfo, - ) -> Result { - Ok(CachedAllowedIps::new_uncached(Arc::new(self.ips.clone()))) - } - - async fn get_allowed_vpc_endpoint_ids( - &self, - _ctx: &RequestContext, - _user_info: &super::ComputeUserInfo, - ) -> Result { - Ok(CachedAllowedVpcEndpointIds::new_uncached(Arc::new( - self.vpc_endpoint_ids.clone(), - ))) - } - - async fn get_block_public_or_vpc_access( - &self, - _ctx: &RequestContext, - _user_info: &super::ComputeUserInfo, - ) -> Result { - Ok(CachedAccessBlockerFlags::new_uncached( - self.access_blocker_flags.clone(), - )) + _endpoint: &crate::types::EndpointId, + _role: &crate::types::RoleName, + ) -> Result { + Ok(EndpointAccessControl { + allowed_ips: Arc::new(self.ips.clone()), + allowed_vpce: Arc::new(self.vpc_endpoint_ids.clone()), + flags: self.access_blocker_flags, + }) } async fn get_endpoint_jwks( &self, _ctx: &RequestContext, - _endpoint: crate::types::EndpointId, + _endpoint: &crate::types::EndpointId, ) -> Result, control_plane::errors::GetEndpointJwksError> { unimplemented!() @@ -636,9 +511,6 @@ mod tests { jwks_cache: JwkCache::default(), thread_pool: ThreadPool::new(1), scram_protocol_timeout: std::time::Duration::from_secs(5), - rate_limiter_enabled: true, - rate_limiter: AuthRateLimiter::new(&RateBucketInfo::DEFAULT_AUTH_SET), - rate_limit_ip_subnet: 64, ip_allowlist_check_enabled: true, is_vpc_acccess_proxy: false, is_auth_broker: false, @@ -655,51 +527,6 @@ mod tests { } } - #[test] - fn masked_ip() { - let ip_a = IpAddr::V4([127, 0, 0, 1].into()); - let ip_b = IpAddr::V4([127, 0, 0, 2].into()); - let ip_c = IpAddr::V4([192, 168, 1, 101].into()); - let ip_d = IpAddr::V4([192, 168, 1, 102].into()); - let ip_e = IpAddr::V6("abcd:abcd:abcd:abcd:abcd:abcd:abcd:abcd".parse().unwrap()); - let ip_f = IpAddr::V6("abcd:abcd:abcd:abcd:1234:abcd:abcd:abcd".parse().unwrap()); - - assert_ne!(MaskedIp::new(ip_a, 64), MaskedIp::new(ip_b, 64)); - assert_ne!(MaskedIp::new(ip_a, 32), MaskedIp::new(ip_b, 32)); - assert_eq!(MaskedIp::new(ip_a, 30), MaskedIp::new(ip_b, 30)); - assert_eq!(MaskedIp::new(ip_c, 30), MaskedIp::new(ip_d, 30)); - - assert_ne!(MaskedIp::new(ip_e, 128), MaskedIp::new(ip_f, 128)); - assert_eq!(MaskedIp::new(ip_e, 64), MaskedIp::new(ip_f, 64)); - } - - #[test] - fn test_default_auth_rate_limit_set() { - // these values used to exceed u32::MAX - assert_eq!( - RateBucketInfo::DEFAULT_AUTH_SET, - [ - RateBucketInfo { - interval: Duration::from_secs(1), - max_rpi: 1000 * 4096, - }, - RateBucketInfo { - interval: Duration::from_secs(60), - max_rpi: 600 * 4096 * 60, - }, - RateBucketInfo { - interval: Duration::from_secs(600), - max_rpi: 300 * 4096 * 600, - } - ] - ); - - for x in RateBucketInfo::DEFAULT_AUTH_SET { - let y = x.to_string().parse().unwrap(); - assert_eq!(x, y); - } - } - #[tokio::test] async fn auth_quirks_scram() { let (mut client, server) = tokio::io::duplex(1024); @@ -888,7 +715,7 @@ mod tests { .await .unwrap(); - assert_eq!(creds.0.info.endpoint, "my-endpoint"); + assert_eq!(creds.info.endpoint, "my-endpoint"); handle.await.unwrap(); } diff --git a/proxy/src/binary/local_proxy.rs b/proxy/src/binary/local_proxy.rs index a566383390..ba10fce7b4 100644 --- a/proxy/src/binary/local_proxy.rs +++ b/proxy/src/binary/local_proxy.rs @@ -32,9 +32,7 @@ use crate::ext::TaskExt; use crate::http::health_server::AppMetrics; use crate::intern::RoleNameInt; use crate::metrics::{Metrics, ThreadPoolMetrics}; -use crate::rate_limiter::{ - BucketRateLimiter, EndpointRateLimiter, LeakyBucketConfig, RateBucketInfo, -}; +use crate::rate_limiter::{EndpointRateLimiter, LeakyBucketConfig, RateBucketInfo}; use crate::scram::threadpool::ThreadPool; use crate::serverless::cancel_set::CancelSet; use crate::serverless::{self, GlobalConnPoolOptions}; @@ -69,15 +67,6 @@ struct LocalProxyCliArgs { /// Can be given multiple times for different bucket sizes. #[clap(long, default_values_t = RateBucketInfo::DEFAULT_ENDPOINT_SET)] user_rps_limit: Vec, - /// Whether the auth rate limiter actually takes effect (for testing) - #[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)] - auth_rate_limit_enabled: bool, - /// Authentication rate limiter max number of hashes per second. - #[clap(long, default_values_t = RateBucketInfo::DEFAULT_AUTH_SET)] - auth_rate_limit: Vec, - /// The IP subnet to use when considering whether two IP addresses are considered the same. - #[clap(long, default_value_t = 64)] - auth_rate_limit_ip_subnet: u8, /// Whether to retry the connection to the compute node #[clap(long, default_value = config::RetryConfig::CONNECT_TO_COMPUTE_DEFAULT_VALUES)] connect_to_compute_retry: String, @@ -282,9 +271,6 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig jwks_cache: JwkCache::default(), thread_pool: ThreadPool::new(0), scram_protocol_timeout: Duration::from_secs(10), - rate_limiter_enabled: false, - rate_limiter: BucketRateLimiter::new(vec![]), - rate_limit_ip_subnet: 64, ip_allowlist_check_enabled: true, is_vpc_acccess_proxy: false, is_auth_broker: false, diff --git a/proxy/src/binary/proxy.rs b/proxy/src/binary/proxy.rs index 9a3903ba9a..dcae263647 100644 --- a/proxy/src/binary/proxy.rs +++ b/proxy/src/binary/proxy.rs @@ -20,7 +20,7 @@ use utils::sentry_init::init_sentry; use utils::{project_build_tag, project_git_version}; use crate::auth::backend::jwt::JwkCache; -use crate::auth::backend::{AuthRateLimiter, ConsoleRedirectBackend, MaybeOwned}; +use crate::auth::backend::{ConsoleRedirectBackend, MaybeOwned}; use crate::cancellation::{CancellationHandler, handle_cancel_messages}; use crate::config::{ self, AuthenticationConfig, CacheOptions, ComputeConfig, HttpConfig, ProjectInfoCacheOptions, @@ -29,9 +29,7 @@ use crate::config::{ use crate::context::parquet::ParquetUploadArgs; use crate::http::health_server::AppMetrics; use crate::metrics::Metrics; -use crate::rate_limiter::{ - EndpointRateLimiter, LeakyBucketConfig, RateBucketInfo, WakeComputeRateLimiter, -}; +use crate::rate_limiter::{EndpointRateLimiter, RateBucketInfo, WakeComputeRateLimiter}; use crate::redis::connection_with_credentials_provider::ConnectionWithCredentialsProvider; use crate::redis::kv_ops::RedisKVClient; use crate::redis::{elasticache, notifications}; @@ -154,15 +152,6 @@ struct ProxyCliArgs { /// Wake compute rate limiter max number of requests per second. #[clap(long, default_values_t = RateBucketInfo::DEFAULT_SET)] wake_compute_limit: Vec, - /// Whether the auth rate limiter actually takes effect (for testing) - #[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)] - auth_rate_limit_enabled: bool, - /// Authentication rate limiter max number of hashes per second. - #[clap(long, default_values_t = RateBucketInfo::DEFAULT_AUTH_SET)] - auth_rate_limit: Vec, - /// The IP subnet to use when considering whether two IP addresses are considered the same. - #[clap(long, default_value_t = 64)] - auth_rate_limit_ip_subnet: u8, /// Redis rate limiter max number of requests per second. #[clap(long, default_values_t = RateBucketInfo::DEFAULT_REDIS_SET)] redis_rps_limit: Vec, @@ -410,22 +399,9 @@ pub async fn run() -> anyhow::Result<()> { Some(tx_cancel), )); - // bit of a hack - find the min rps and max rps supported and turn it into - // leaky bucket config instead - let max = args - .endpoint_rps_limit - .iter() - .map(|x| x.rps()) - .max_by(f64::total_cmp) - .unwrap_or(EndpointRateLimiter::DEFAULT.max); - let rps = args - .endpoint_rps_limit - .iter() - .map(|x| x.rps()) - .min_by(f64::total_cmp) - .unwrap_or(EndpointRateLimiter::DEFAULT.rps); let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards( - LeakyBucketConfig { rps, max }, + RateBucketInfo::to_leaky_bucket(&args.endpoint_rps_limit) + .unwrap_or(EndpointRateLimiter::DEFAULT), 64, )); @@ -678,9 +654,6 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { jwks_cache: JwkCache::default(), thread_pool, scram_protocol_timeout: args.scram_protocol_timeout, - rate_limiter_enabled: args.auth_rate_limit_enabled, - rate_limiter: AuthRateLimiter::new(args.auth_rate_limit.clone()), - rate_limit_ip_subnet: args.auth_rate_limit_ip_subnet, ip_allowlist_check_enabled: !args.is_private_access_proxy, is_vpc_acccess_proxy: args.is_private_access_proxy, is_auth_broker: args.is_auth_broker, diff --git a/proxy/src/cache/project_info.rs b/proxy/src/cache/project_info.rs index 60678b034d..81c88e3ddd 100644 --- a/proxy/src/cache/project_info.rs +++ b/proxy/src/cache/project_info.rs @@ -1,30 +1,25 @@ -use std::collections::HashSet; +use std::collections::{HashMap, HashSet, hash_map}; use std::convert::Infallible; -use std::sync::Arc; use std::sync::atomic::AtomicU64; use std::time::Duration; use async_trait::async_trait; use clashmap::ClashMap; +use clashmap::mapref::one::Ref; use rand::{Rng, thread_rng}; -use smol_str::SmolStr; use tokio::sync::Mutex; use tokio::time::Instant; use tracing::{debug, info}; -use super::{Cache, Cached}; -use crate::auth::IpPattern; use crate::config::ProjectInfoCacheOptions; -use crate::control_plane::{AccessBlockerFlags, AuthSecret}; +use crate::control_plane::{EndpointAccessControl, RoleAccessControl}; use crate::intern::{AccountIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt}; use crate::types::{EndpointId, RoleName}; #[async_trait] pub(crate) trait ProjectInfoCache { - fn invalidate_allowed_ips_for_project(&self, project_id: ProjectIdInt); - fn invalidate_allowed_vpc_endpoint_ids_for_projects(&self, project_ids: Vec); - fn invalidate_allowed_vpc_endpoint_ids_for_org(&self, account_id: AccountIdInt); - fn invalidate_block_public_or_vpc_access_for_project(&self, project_id: ProjectIdInt); + fn invalidate_endpoint_access_for_project(&self, project_id: ProjectIdInt); + fn invalidate_endpoint_access_for_org(&self, account_id: AccountIdInt); fn invalidate_role_secret_for_project(&self, project_id: ProjectIdInt, role_name: RoleNameInt); async fn decrement_active_listeners(&self); async fn increment_active_listeners(&self); @@ -42,6 +37,10 @@ impl Entry { value, } } + + pub(crate) fn get(&self, valid_since: Instant) -> Option<&T> { + (valid_since < self.created_at).then_some(&self.value) + } } impl From for Entry { @@ -50,101 +49,32 @@ impl From for Entry { } } -#[derive(Default)] struct EndpointInfo { - secret: std::collections::HashMap>>, - allowed_ips: Option>>>, - block_public_or_vpc_access: Option>, - allowed_vpc_endpoint_ids: Option>>>, + role_controls: HashMap>, + controls: Option>, } impl EndpointInfo { - fn check_ignore_cache(ignore_cache_since: Option, created_at: Instant) -> bool { - match ignore_cache_since { - None => false, - Some(t) => t < created_at, - } - } pub(crate) fn get_role_secret( &self, role_name: RoleNameInt, valid_since: Instant, - ignore_cache_since: Option, - ) -> Option<(Option, bool)> { - if let Some(secret) = self.secret.get(&role_name) { - if valid_since < secret.created_at { - return Some(( - secret.value.clone(), - Self::check_ignore_cache(ignore_cache_since, secret.created_at), - )); - } - } - None + ) -> Option { + let controls = self.role_controls.get(&role_name)?; + controls.get(valid_since).cloned() } - pub(crate) fn get_allowed_ips( - &self, - valid_since: Instant, - ignore_cache_since: Option, - ) -> Option<(Arc>, bool)> { - if let Some(allowed_ips) = &self.allowed_ips { - if valid_since < allowed_ips.created_at { - return Some(( - allowed_ips.value.clone(), - Self::check_ignore_cache(ignore_cache_since, allowed_ips.created_at), - )); - } - } - None - } - pub(crate) fn get_allowed_vpc_endpoint_ids( - &self, - valid_since: Instant, - ignore_cache_since: Option, - ) -> Option<(Arc>, bool)> { - if let Some(allowed_vpc_endpoint_ids) = &self.allowed_vpc_endpoint_ids { - if valid_since < allowed_vpc_endpoint_ids.created_at { - return Some(( - allowed_vpc_endpoint_ids.value.clone(), - Self::check_ignore_cache( - ignore_cache_since, - allowed_vpc_endpoint_ids.created_at, - ), - )); - } - } - None - } - pub(crate) fn get_block_public_or_vpc_access( - &self, - valid_since: Instant, - ignore_cache_since: Option, - ) -> Option<(AccessBlockerFlags, bool)> { - if let Some(block_public_or_vpc_access) = &self.block_public_or_vpc_access { - if valid_since < block_public_or_vpc_access.created_at { - return Some(( - block_public_or_vpc_access.value.clone(), - Self::check_ignore_cache( - ignore_cache_since, - block_public_or_vpc_access.created_at, - ), - )); - } - } - None + pub(crate) fn get_controls(&self, valid_since: Instant) -> Option { + let controls = self.controls.as_ref()?; + controls.get(valid_since).cloned() } - pub(crate) fn invalidate_allowed_ips(&mut self) { - self.allowed_ips = None; - } - pub(crate) fn invalidate_allowed_vpc_endpoint_ids(&mut self) { - self.allowed_vpc_endpoint_ids = None; - } - pub(crate) fn invalidate_block_public_or_vpc_access(&mut self) { - self.block_public_or_vpc_access = None; + pub(crate) fn invalidate_endpoint(&mut self) { + self.controls = None; } + pub(crate) fn invalidate_role_secret(&mut self, role_name: RoleNameInt) { - self.secret.remove(&role_name); + self.role_controls.remove(&role_name); } } @@ -170,34 +100,22 @@ pub struct ProjectInfoCacheImpl { #[async_trait] impl ProjectInfoCache for ProjectInfoCacheImpl { - fn invalidate_allowed_vpc_endpoint_ids_for_projects(&self, project_ids: Vec) { - info!( - "invalidating allowed vpc endpoint ids for projects `{}`", - project_ids - .iter() - .map(|id| id.to_string()) - .collect::>() - .join(", ") - ); - for project_id in project_ids { - let endpoints = self - .project2ep - .get(&project_id) - .map(|kv| kv.value().clone()) - .unwrap_or_default(); - for endpoint_id in endpoints { - if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) { - endpoint_info.invalidate_allowed_vpc_endpoint_ids(); - } + fn invalidate_endpoint_access_for_project(&self, project_id: ProjectIdInt) { + info!("invalidating endpoint access for project `{project_id}`"); + let endpoints = self + .project2ep + .get(&project_id) + .map(|kv| kv.value().clone()) + .unwrap_or_default(); + for endpoint_id in endpoints { + if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) { + endpoint_info.invalidate_endpoint(); } } } - fn invalidate_allowed_vpc_endpoint_ids_for_org(&self, account_id: AccountIdInt) { - info!( - "invalidating allowed vpc endpoint ids for org `{}`", - account_id - ); + fn invalidate_endpoint_access_for_org(&self, account_id: AccountIdInt) { + info!("invalidating endpoint access for org `{account_id}`"); let endpoints = self .account2ep .get(&account_id) @@ -205,41 +123,11 @@ impl ProjectInfoCache for ProjectInfoCacheImpl { .unwrap_or_default(); for endpoint_id in endpoints { if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) { - endpoint_info.invalidate_allowed_vpc_endpoint_ids(); + endpoint_info.invalidate_endpoint(); } } } - fn invalidate_block_public_or_vpc_access_for_project(&self, project_id: ProjectIdInt) { - info!( - "invalidating block public or vpc access for project `{}`", - project_id - ); - let endpoints = self - .project2ep - .get(&project_id) - .map(|kv| kv.value().clone()) - .unwrap_or_default(); - for endpoint_id in endpoints { - if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) { - endpoint_info.invalidate_block_public_or_vpc_access(); - } - } - } - - fn invalidate_allowed_ips_for_project(&self, project_id: ProjectIdInt) { - info!("invalidating allowed ips for project `{}`", project_id); - let endpoints = self - .project2ep - .get(&project_id) - .map(|kv| kv.value().clone()) - .unwrap_or_default(); - for endpoint_id in endpoints { - if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) { - endpoint_info.invalidate_allowed_ips(); - } - } - } fn invalidate_role_secret_for_project(&self, project_id: ProjectIdInt, role_name: RoleNameInt) { info!( "invalidating role secret for project_id `{}` and role_name `{}`", @@ -256,6 +144,7 @@ impl ProjectInfoCache for ProjectInfoCacheImpl { } } } + async fn decrement_active_listeners(&self) { let mut listeners_guard = self.active_listeners_lock.lock().await; if *listeners_guard == 0 { @@ -293,155 +182,71 @@ impl ProjectInfoCacheImpl { } } + fn get_endpoint_cache( + &self, + endpoint_id: &EndpointId, + ) -> Option> { + let endpoint_id = EndpointIdInt::get(endpoint_id)?; + self.cache.get(&endpoint_id) + } + pub(crate) fn get_role_secret( &self, endpoint_id: &EndpointId, role_name: &RoleName, - ) -> Option>> { - let endpoint_id = EndpointIdInt::get(endpoint_id)?; + ) -> Option { + let valid_since = self.get_cache_times(); let role_name = RoleNameInt::get(role_name)?; - let (valid_since, ignore_cache_since) = self.get_cache_times(); - let endpoint_info = self.cache.get(&endpoint_id)?; - let (value, ignore_cache) = - endpoint_info.get_role_secret(role_name, valid_since, ignore_cache_since)?; - if !ignore_cache { - let cached = Cached { - token: Some(( - self, - CachedLookupInfo::new_role_secret(endpoint_id, role_name), - )), - value, - }; - return Some(cached); - } - Some(Cached::new_uncached(value)) - } - pub(crate) fn get_allowed_ips( - &self, - endpoint_id: &EndpointId, - ) -> Option>>> { - let endpoint_id = EndpointIdInt::get(endpoint_id)?; - let (valid_since, ignore_cache_since) = self.get_cache_times(); - let endpoint_info = self.cache.get(&endpoint_id)?; - let value = endpoint_info.get_allowed_ips(valid_since, ignore_cache_since); - let (value, ignore_cache) = value?; - if !ignore_cache { - let cached = Cached { - token: Some((self, CachedLookupInfo::new_allowed_ips(endpoint_id))), - value, - }; - return Some(cached); - } - Some(Cached::new_uncached(value)) - } - pub(crate) fn get_allowed_vpc_endpoint_ids( - &self, - endpoint_id: &EndpointId, - ) -> Option>>> { - let endpoint_id = EndpointIdInt::get(endpoint_id)?; - let (valid_since, ignore_cache_since) = self.get_cache_times(); - let endpoint_info = self.cache.get(&endpoint_id)?; - let value = endpoint_info.get_allowed_vpc_endpoint_ids(valid_since, ignore_cache_since); - let (value, ignore_cache) = value?; - if !ignore_cache { - let cached = Cached { - token: Some(( - self, - CachedLookupInfo::new_allowed_vpc_endpoint_ids(endpoint_id), - )), - value, - }; - return Some(cached); - } - Some(Cached::new_uncached(value)) - } - pub(crate) fn get_block_public_or_vpc_access( - &self, - endpoint_id: &EndpointId, - ) -> Option> { - let endpoint_id = EndpointIdInt::get(endpoint_id)?; - let (valid_since, ignore_cache_since) = self.get_cache_times(); - let endpoint_info = self.cache.get(&endpoint_id)?; - let value = endpoint_info.get_block_public_or_vpc_access(valid_since, ignore_cache_since); - let (value, ignore_cache) = value?; - if !ignore_cache { - let cached = Cached { - token: Some(( - self, - CachedLookupInfo::new_block_public_or_vpc_access(endpoint_id), - )), - value, - }; - return Some(cached); - } - Some(Cached::new_uncached(value)) + let endpoint_info = self.get_endpoint_cache(endpoint_id)?; + endpoint_info.get_role_secret(role_name, valid_since) } - pub(crate) fn insert_role_secret( + pub(crate) fn get_endpoint_access( &self, - project_id: ProjectIdInt, - endpoint_id: EndpointIdInt, - role_name: RoleNameInt, - secret: Option, - ) { - if self.cache.len() >= self.config.size { - // If there are too many entries, wait until the next gc cycle. - return; - } - self.insert_project2endpoint(project_id, endpoint_id); - let mut entry = self.cache.entry(endpoint_id).or_default(); - if entry.secret.len() < self.config.max_roles { - entry.secret.insert(role_name, secret.into()); - } + endpoint_id: &EndpointId, + ) -> Option { + let valid_since = self.get_cache_times(); + let endpoint_info = self.get_endpoint_cache(endpoint_id)?; + endpoint_info.get_controls(valid_since) } - pub(crate) fn insert_allowed_ips( - &self, - project_id: ProjectIdInt, - endpoint_id: EndpointIdInt, - allowed_ips: Arc>, - ) { - if self.cache.len() >= self.config.size { - // If there are too many entries, wait until the next gc cycle. - return; - } - self.insert_project2endpoint(project_id, endpoint_id); - self.cache.entry(endpoint_id).or_default().allowed_ips = Some(allowed_ips.into()); - } - pub(crate) fn insert_allowed_vpc_endpoint_ids( + + pub(crate) fn insert_endpoint_access( &self, account_id: Option, project_id: ProjectIdInt, endpoint_id: EndpointIdInt, - allowed_vpc_endpoint_ids: Arc>, + role_name: RoleNameInt, + controls: EndpointAccessControl, + role_controls: RoleAccessControl, ) { - if self.cache.len() >= self.config.size { - // If there are too many entries, wait until the next gc cycle. - return; - } if let Some(account_id) = account_id { self.insert_account2endpoint(account_id, endpoint_id); } self.insert_project2endpoint(project_id, endpoint_id); - self.cache - .entry(endpoint_id) - .or_default() - .allowed_vpc_endpoint_ids = Some(allowed_vpc_endpoint_ids.into()); - } - pub(crate) fn insert_block_public_or_vpc_access( - &self, - project_id: ProjectIdInt, - endpoint_id: EndpointIdInt, - access_blockers: AccessBlockerFlags, - ) { + if self.cache.len() >= self.config.size { // If there are too many entries, wait until the next gc cycle. return; } - self.insert_project2endpoint(project_id, endpoint_id); - self.cache - .entry(endpoint_id) - .or_default() - .block_public_or_vpc_access = Some(access_blockers.into()); + + let controls = Entry::from(controls); + let role_controls = Entry::from(role_controls); + + match self.cache.entry(endpoint_id) { + clashmap::Entry::Vacant(e) => { + e.insert(EndpointInfo { + role_controls: HashMap::from_iter([(role_name, role_controls)]), + controls: Some(controls), + }); + } + clashmap::Entry::Occupied(mut e) => { + let ep = e.get_mut(); + ep.controls = Some(controls); + if ep.role_controls.len() < self.config.max_roles { + ep.role_controls.insert(role_name, role_controls); + } + } + } } fn insert_project2endpoint(&self, project_id: ProjectIdInt, endpoint_id: EndpointIdInt) { @@ -452,6 +257,7 @@ impl ProjectInfoCacheImpl { .insert(project_id, HashSet::from([endpoint_id])); } } + fn insert_account2endpoint(&self, account_id: AccountIdInt, endpoint_id: EndpointIdInt) { if let Some(mut endpoints) = self.account2ep.get_mut(&account_id) { endpoints.insert(endpoint_id); @@ -460,21 +266,57 @@ impl ProjectInfoCacheImpl { .insert(account_id, HashSet::from([endpoint_id])); } } - fn get_cache_times(&self) -> (Instant, Option) { - let mut valid_since = Instant::now() - self.config.ttl; - // Only ignore cache if ttl is disabled. + + fn ignore_ttl_since(&self) -> Option { let ttl_disabled_since_us = self .ttl_disabled_since_us .load(std::sync::atomic::Ordering::Relaxed); - let ignore_cache_since = if ttl_disabled_since_us == u64::MAX { - None - } else { - let ignore_cache_since = self.start_time + Duration::from_micros(ttl_disabled_since_us); + + if ttl_disabled_since_us == u64::MAX { + return None; + } + + Some(self.start_time + Duration::from_micros(ttl_disabled_since_us)) + } + + fn get_cache_times(&self) -> Instant { + let mut valid_since = Instant::now() - self.config.ttl; + if let Some(ignore_ttl_since) = self.ignore_ttl_since() { // We are fine if entry is not older than ttl or was added before we are getting notifications. - valid_since = valid_since.min(ignore_cache_since); - Some(ignore_cache_since) + valid_since = valid_since.min(ignore_ttl_since); + } + valid_since + } + + pub fn maybe_invalidate_role_secret(&self, endpoint_id: &EndpointId, role_name: &RoleName) { + let Some(endpoint_id) = EndpointIdInt::get(endpoint_id) else { + return; }; - (valid_since, ignore_cache_since) + let Some(role_name) = RoleNameInt::get(role_name) else { + return; + }; + + let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) else { + return; + }; + + let entry = endpoint_info.role_controls.entry(role_name); + let hash_map::Entry::Occupied(role_controls) = entry else { + return; + }; + + let created_at = role_controls.get().created_at; + let expire = match self.ignore_ttl_since() { + // if ignoring TTL, we should still try and roll the password if it's old + // and we the client gave an incorrect password. There could be some lag on the redis channel. + Some(_) => created_at + self.config.ttl < Instant::now(), + // edge case: redis is down, let's be generous and invalidate the cache immediately. + None => true, + }; + + if expire { + role_controls.remove(); + } } pub async fn gc_worker(&self) -> anyhow::Result { @@ -509,84 +351,12 @@ impl ProjectInfoCacheImpl { } } -/// Lookup info for project info cache. -/// This is used to invalidate cache entries. -pub(crate) struct CachedLookupInfo { - /// Search by this key. - endpoint_id: EndpointIdInt, - lookup_type: LookupType, -} - -impl CachedLookupInfo { - pub(self) fn new_role_secret(endpoint_id: EndpointIdInt, role_name: RoleNameInt) -> Self { - Self { - endpoint_id, - lookup_type: LookupType::RoleSecret(role_name), - } - } - pub(self) fn new_allowed_ips(endpoint_id: EndpointIdInt) -> Self { - Self { - endpoint_id, - lookup_type: LookupType::AllowedIps, - } - } - pub(self) fn new_allowed_vpc_endpoint_ids(endpoint_id: EndpointIdInt) -> Self { - Self { - endpoint_id, - lookup_type: LookupType::AllowedVpcEndpointIds, - } - } - pub(self) fn new_block_public_or_vpc_access(endpoint_id: EndpointIdInt) -> Self { - Self { - endpoint_id, - lookup_type: LookupType::BlockPublicOrVpcAccess, - } - } -} - -enum LookupType { - RoleSecret(RoleNameInt), - AllowedIps, - AllowedVpcEndpointIds, - BlockPublicOrVpcAccess, -} - -impl Cache for ProjectInfoCacheImpl { - type Key = SmolStr; - // Value is not really used here, but we need to specify it. - type Value = SmolStr; - - type LookupInfo = CachedLookupInfo; - - fn invalidate(&self, key: &Self::LookupInfo) { - match &key.lookup_type { - LookupType::RoleSecret(role_name) => { - if let Some(mut endpoint_info) = self.cache.get_mut(&key.endpoint_id) { - endpoint_info.invalidate_role_secret(*role_name); - } - } - LookupType::AllowedIps => { - if let Some(mut endpoint_info) = self.cache.get_mut(&key.endpoint_id) { - endpoint_info.invalidate_allowed_ips(); - } - } - LookupType::AllowedVpcEndpointIds => { - if let Some(mut endpoint_info) = self.cache.get_mut(&key.endpoint_id) { - endpoint_info.invalidate_allowed_vpc_endpoint_ids(); - } - } - LookupType::BlockPublicOrVpcAccess => { - if let Some(mut endpoint_info) = self.cache.get_mut(&key.endpoint_id) { - endpoint_info.invalidate_block_public_or_vpc_access(); - } - } - } - } -} - #[cfg(test)] mod tests { + use std::sync::Arc; + use super::*; + use crate::control_plane::{AccessBlockerFlags, AuthSecret}; use crate::scram::ServerSecret; use crate::types::ProjectId; @@ -601,6 +371,8 @@ mod tests { }); let project_id: ProjectId = "project".into(); let endpoint_id: EndpointId = "endpoint".into(); + let account_id: Option = None; + let user1: RoleName = "user1".into(); let user2: RoleName = "user2".into(); let secret1 = Some(AuthSecret::Scram(ServerSecret::mock([1; 32]))); @@ -609,183 +381,73 @@ mod tests { "127.0.0.1".parse().unwrap(), "127.0.0.2".parse().unwrap(), ]); - cache.insert_role_secret( + + cache.insert_endpoint_access( + account_id, (&project_id).into(), (&endpoint_id).into(), (&user1).into(), - secret1.clone(), + EndpointAccessControl { + allowed_ips: allowed_ips.clone(), + allowed_vpce: Arc::new(vec![]), + flags: AccessBlockerFlags::default(), + }, + RoleAccessControl { + secret: secret1.clone(), + }, ); - cache.insert_role_secret( + + cache.insert_endpoint_access( + account_id, (&project_id).into(), (&endpoint_id).into(), (&user2).into(), - secret2.clone(), - ); - cache.insert_allowed_ips( - (&project_id).into(), - (&endpoint_id).into(), - allowed_ips.clone(), + EndpointAccessControl { + allowed_ips: allowed_ips.clone(), + allowed_vpce: Arc::new(vec![]), + flags: AccessBlockerFlags::default(), + }, + RoleAccessControl { + secret: secret2.clone(), + }, ); let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap(); - assert!(cached.cached()); - assert_eq!(cached.value, secret1); + assert_eq!(cached.secret, secret1); + let cached = cache.get_role_secret(&endpoint_id, &user2).unwrap(); - assert!(cached.cached()); - assert_eq!(cached.value, secret2); + assert_eq!(cached.secret, secret2); // Shouldn't add more than 2 roles. let user3: RoleName = "user3".into(); let secret3 = Some(AuthSecret::Scram(ServerSecret::mock([3; 32]))); - cache.insert_role_secret( + + cache.insert_endpoint_access( + account_id, (&project_id).into(), (&endpoint_id).into(), (&user3).into(), - secret3.clone(), + EndpointAccessControl { + allowed_ips: allowed_ips.clone(), + allowed_vpce: Arc::new(vec![]), + flags: AccessBlockerFlags::default(), + }, + RoleAccessControl { + secret: secret3.clone(), + }, ); + assert!(cache.get_role_secret(&endpoint_id, &user3).is_none()); - let cached = cache.get_allowed_ips(&endpoint_id).unwrap(); - assert!(cached.cached()); - assert_eq!(cached.value, allowed_ips); + let cached = cache.get_endpoint_access(&endpoint_id).unwrap(); + assert_eq!(cached.allowed_ips, allowed_ips); tokio::time::advance(Duration::from_secs(2)).await; let cached = cache.get_role_secret(&endpoint_id, &user1); assert!(cached.is_none()); let cached = cache.get_role_secret(&endpoint_id, &user2); assert!(cached.is_none()); - let cached = cache.get_allowed_ips(&endpoint_id); + let cached = cache.get_endpoint_access(&endpoint_id); assert!(cached.is_none()); } - - #[tokio::test] - async fn test_project_info_cache_invalidations() { - tokio::time::pause(); - let cache = Arc::new(ProjectInfoCacheImpl::new(ProjectInfoCacheOptions { - size: 2, - max_roles: 2, - ttl: Duration::from_secs(1), - gc_interval: Duration::from_secs(600), - })); - cache.clone().increment_active_listeners().await; - tokio::time::advance(Duration::from_secs(2)).await; - - let project_id: ProjectId = "project".into(); - let endpoint_id: EndpointId = "endpoint".into(); - let user1: RoleName = "user1".into(); - let user2: RoleName = "user2".into(); - let secret1 = Some(AuthSecret::Scram(ServerSecret::mock([1; 32]))); - let secret2 = Some(AuthSecret::Scram(ServerSecret::mock([2; 32]))); - let allowed_ips = Arc::new(vec![ - "127.0.0.1".parse().unwrap(), - "127.0.0.2".parse().unwrap(), - ]); - cache.insert_role_secret( - (&project_id).into(), - (&endpoint_id).into(), - (&user1).into(), - secret1.clone(), - ); - cache.insert_role_secret( - (&project_id).into(), - (&endpoint_id).into(), - (&user2).into(), - secret2.clone(), - ); - cache.insert_allowed_ips( - (&project_id).into(), - (&endpoint_id).into(), - allowed_ips.clone(), - ); - - tokio::time::advance(Duration::from_secs(2)).await; - // Nothing should be invalidated. - - let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap(); - // TTL is disabled, so it should be impossible to invalidate this value. - assert!(!cached.cached()); - assert_eq!(cached.value, secret1); - - cached.invalidate(); // Shouldn't do anything. - let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap(); - assert_eq!(cached.value, secret1); - - let cached = cache.get_role_secret(&endpoint_id, &user2).unwrap(); - assert!(!cached.cached()); - assert_eq!(cached.value, secret2); - - // The only way to invalidate this value is to invalidate via the api. - cache.invalidate_role_secret_for_project((&project_id).into(), (&user2).into()); - assert!(cache.get_role_secret(&endpoint_id, &user2).is_none()); - - let cached = cache.get_allowed_ips(&endpoint_id).unwrap(); - assert!(!cached.cached()); - assert_eq!(cached.value, allowed_ips); - } - - #[tokio::test] - async fn test_increment_active_listeners_invalidate_added_before() { - tokio::time::pause(); - let cache = Arc::new(ProjectInfoCacheImpl::new(ProjectInfoCacheOptions { - size: 2, - max_roles: 2, - ttl: Duration::from_secs(1), - gc_interval: Duration::from_secs(600), - })); - - let project_id: ProjectId = "project".into(); - let endpoint_id: EndpointId = "endpoint".into(); - let user1: RoleName = "user1".into(); - let user2: RoleName = "user2".into(); - let secret1 = Some(AuthSecret::Scram(ServerSecret::mock([1; 32]))); - let secret2 = Some(AuthSecret::Scram(ServerSecret::mock([2; 32]))); - let allowed_ips = Arc::new(vec![ - "127.0.0.1".parse().unwrap(), - "127.0.0.2".parse().unwrap(), - ]); - cache.insert_role_secret( - (&project_id).into(), - (&endpoint_id).into(), - (&user1).into(), - secret1.clone(), - ); - cache.clone().increment_active_listeners().await; - tokio::time::advance(Duration::from_millis(100)).await; - cache.insert_role_secret( - (&project_id).into(), - (&endpoint_id).into(), - (&user2).into(), - secret2.clone(), - ); - - // Added before ttl was disabled + ttl should be still cached. - let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap(); - assert!(cached.cached()); - let cached = cache.get_role_secret(&endpoint_id, &user2).unwrap(); - assert!(cached.cached()); - - tokio::time::advance(Duration::from_secs(1)).await; - // Added before ttl was disabled + ttl should expire. - assert!(cache.get_role_secret(&endpoint_id, &user1).is_none()); - assert!(cache.get_role_secret(&endpoint_id, &user2).is_none()); - - // Added after ttl was disabled + ttl should not be cached. - cache.insert_allowed_ips( - (&project_id).into(), - (&endpoint_id).into(), - allowed_ips.clone(), - ); - let cached = cache.get_allowed_ips(&endpoint_id).unwrap(); - assert!(!cached.cached()); - - tokio::time::advance(Duration::from_secs(1)).await; - // Added before ttl was disabled + ttl still should expire. - assert!(cache.get_role_secret(&endpoint_id, &user1).is_none()); - assert!(cache.get_role_secret(&endpoint_id, &user2).is_none()); - // Shouldn't be invalidated. - - let cached = cache.get_allowed_ips(&endpoint_id).unwrap(); - assert!(!cached.cached()); - assert_eq!(cached.value, allowed_ips); - } } diff --git a/proxy/src/cancellation.rs b/proxy/src/cancellation.rs index 0bff901376..d26641db46 100644 --- a/proxy/src/cancellation.rs +++ b/proxy/src/cancellation.rs @@ -12,8 +12,8 @@ use tokio::net::TcpStream; use tokio::sync::{mpsc, oneshot}; use tracing::{debug, error, info, warn}; +use crate::auth::AuthError; use crate::auth::backend::ComputeUserInfo; -use crate::auth::{AuthError, check_peer_addr_is_in_list}; use crate::config::ComputeConfig; use crate::context::RequestContext; use crate::control_plane::ControlPlaneApi; @@ -21,7 +21,6 @@ use crate::error::ReportableError; use crate::ext::LockExt; use crate::metrics::{CancelChannelSizeGuard, CancellationRequest, Metrics, RedisMsgKind}; use crate::pqproto::CancelKeyData; -use crate::protocol2::ConnectionInfoExtra; use crate::rate_limiter::LeakyBucketRateLimiter; use crate::redis::keys::KeyPrefix; use crate::redis::kv_ops::RedisKVClient; @@ -272,13 +271,7 @@ pub(crate) enum CancelError { #[error("rate limit exceeded")] RateLimit, - #[error("IP is not allowed")] - IpNotAllowed, - - #[error("VPC endpoint id is not allowed to connect")] - VpcEndpointIdNotAllowed, - - #[error("Authentication backend error")] + #[error("Authentication error")] AuthError(#[from] AuthError), #[error("key not found")] @@ -297,10 +290,7 @@ impl ReportableError for CancelError { } CancelError::Postgres(_) => crate::error::ErrorKind::Compute, CancelError::RateLimit => crate::error::ErrorKind::RateLimit, - CancelError::IpNotAllowed - | CancelError::VpcEndpointIdNotAllowed - | CancelError::NotFound => crate::error::ErrorKind::User, - CancelError::AuthError(_) => crate::error::ErrorKind::ControlPlane, + CancelError::NotFound | CancelError::AuthError(_) => crate::error::ErrorKind::User, CancelError::InternalError => crate::error::ErrorKind::Service, } } @@ -422,7 +412,13 @@ impl CancellationHandler { IpAddr::V4(ip) => IpNet::V4(Ipv4Net::new_assert(ip, 24).trunc()), // use defaut mask here IpAddr::V6(ip) => IpNet::V6(Ipv6Net::new_assert(ip, 64).trunc()), }; - if !self.limiter.lock_propagate_poison().check(subnet_key, 1) { + + let allowed = { + let rate_limit_config = None; + let limiter = self.limiter.lock_propagate_poison(); + limiter.check(subnet_key, rate_limit_config, 1) + }; + if !allowed { // log only the subnet part of the IP address to know which subnet is rate limited tracing::warn!("Rate limit exceeded. Skipping cancellation message, {subnet_key}"); Metrics::get() @@ -450,52 +446,13 @@ impl CancellationHandler { return Err(CancelError::NotFound); }; - if check_ip_allowed { - let ip_allowlist = auth_backend - .get_allowed_ips(&ctx, &cancel_closure.user_info) - .await - .map_err(|e| CancelError::AuthError(e.into()))?; - - if !check_peer_addr_is_in_list(&ctx.peer_addr(), &ip_allowlist) { - // log it here since cancel_session could be spawned in a task - tracing::warn!( - "IP is not allowed to cancel the query: {key}, address: {}", - ctx.peer_addr() - ); - return Err(CancelError::IpNotAllowed); - } - } - - // check if a VPC endpoint ID is coming in and if yes, if it's allowed - let access_blocks = auth_backend - .get_block_public_or_vpc_access(&ctx, &cancel_closure.user_info) + let info = &cancel_closure.user_info; + let access_controls = auth_backend + .get_endpoint_access_control(&ctx, &info.endpoint, &info.user) .await .map_err(|e| CancelError::AuthError(e.into()))?; - if check_vpc_allowed { - if access_blocks.vpc_access_blocked { - return Err(CancelError::AuthError(AuthError::NetworkNotAllowed)); - } - - let incoming_vpc_endpoint_id = match ctx.extra() { - None => return Err(CancelError::AuthError(AuthError::MissingVPCEndpointId)), - Some(ConnectionInfoExtra::Aws { vpce_id }) => vpce_id.to_string(), - Some(ConnectionInfoExtra::Azure { link_id }) => link_id.to_string(), - }; - - let allowed_vpc_endpoint_ids = auth_backend - .get_allowed_vpc_endpoint_ids(&ctx, &cancel_closure.user_info) - .await - .map_err(|e| CancelError::AuthError(e.into()))?; - // TODO: For now an empty VPC endpoint ID list means all are allowed. We should replace that. - if !allowed_vpc_endpoint_ids.is_empty() - && !allowed_vpc_endpoint_ids.contains(&incoming_vpc_endpoint_id) - { - return Err(CancelError::VpcEndpointIdNotAllowed); - } - } else if access_blocks.public_access_blocked { - return Err(CancelError::VpcEndpointIdNotAllowed); - } + access_controls.check(&ctx, check_ip_allowed, check_vpc_allowed)?; Metrics::get() .proxy diff --git a/proxy/src/config.rs b/proxy/src/config.rs index ad398c122c..a97339df9a 100644 --- a/proxy/src/config.rs +++ b/proxy/src/config.rs @@ -7,7 +7,6 @@ use arc_swap::ArcSwapOption; use clap::ValueEnum; use remote_storage::RemoteStorageConfig; -use crate::auth::backend::AuthRateLimiter; use crate::auth::backend::jwt::JwkCache; use crate::control_plane::locks::ApiLocks; use crate::rate_limiter::{RateBucketInfo, RateLimitAlgorithm, RateLimiterConfig}; @@ -65,9 +64,6 @@ pub struct HttpConfig { pub struct AuthenticationConfig { pub thread_pool: Arc, pub scram_protocol_timeout: tokio::time::Duration, - pub rate_limiter_enabled: bool, - pub rate_limiter: AuthRateLimiter, - pub rate_limit_ip_subnet: u8, pub ip_allowlist_check_enabled: bool, pub is_vpc_acccess_proxy: bool, pub jwks_cache: JwkCache, diff --git a/proxy/src/context/mod.rs b/proxy/src/context/mod.rs index de4600951e..24268997ba 100644 --- a/proxy/src/context/mod.rs +++ b/proxy/src/context/mod.rs @@ -370,6 +370,18 @@ impl RequestContext { } } + pub(crate) fn latency_timer_pause_at( + &self, + at: tokio::time::Instant, + waiting_for: Waiting, + ) -> LatencyTimerPause<'_> { + LatencyTimerPause { + ctx: self, + start: at, + waiting_for, + } + } + pub(crate) fn get_proxy_latency(&self) -> LatencyAccumulated { self.0 .try_lock() diff --git a/proxy/src/control_plane/client/cplane_proxy_v1.rs b/proxy/src/control_plane/client/cplane_proxy_v1.rs index 2765aaa462..93f4ea6cf7 100644 --- a/proxy/src/control_plane/client/cplane_proxy_v1.rs +++ b/proxy/src/control_plane/client/cplane_proxy_v1.rs @@ -15,7 +15,6 @@ use tracing::{Instrument, debug, info, info_span, warn}; use super::super::messages::{ControlPlaneErrorMessage, GetEndpointAccessControl, WakeCompute}; use crate::auth::backend::ComputeUserInfo; use crate::auth::backend::jwt::AuthRule; -use crate::cache::Cached; use crate::context::RequestContext; use crate::control_plane::caches::ApiCaches; use crate::control_plane::errors::{ @@ -24,12 +23,12 @@ use crate::control_plane::errors::{ use crate::control_plane::locks::ApiLocks; use crate::control_plane::messages::{ColdStartInfo, EndpointJwksResponse, Reason}; use crate::control_plane::{ - AccessBlockerFlags, AuthInfo, AuthSecret, CachedAccessBlockerFlags, CachedAllowedIps, - CachedAllowedVpcEndpointIds, CachedNodeInfo, CachedRoleSecret, NodeInfo, + AccessBlockerFlags, AuthInfo, AuthSecret, CachedNodeInfo, EndpointAccessControl, NodeInfo, + RoleAccessControl, }; -use crate::metrics::{CacheOutcome, Metrics}; +use crate::metrics::Metrics; use crate::rate_limiter::WakeComputeRateLimiter; -use crate::types::{EndpointCacheKey, EndpointId}; +use crate::types::{EndpointCacheKey, EndpointId, RoleName}; use crate::{compute, http, scram}; pub(crate) const X_REQUEST_ID: HeaderName = HeaderName::from_static("x-request-id"); @@ -66,65 +65,34 @@ impl NeonControlPlaneClient { self.endpoint.url().as_str() } - async fn do_get_auth_info( - &self, - ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result { - if !self - .caches - .endpoints_cache - .is_valid(ctx, &user_info.endpoint.normalize()) - { - // TODO: refactor this because it's weird - // this is a failure to authenticate but we return Ok. - info!("endpoint is not valid, skipping the request"); - return Ok(AuthInfo::default()); - } - self.do_get_auth_req(user_info, &ctx.session_id(), Some(ctx)) - .await - } - async fn do_get_auth_req( &self, - user_info: &ComputeUserInfo, - session_id: &uuid::Uuid, - ctx: Option<&RequestContext>, + ctx: &RequestContext, + endpoint: &EndpointId, + role: &RoleName, ) -> Result { - let request_id: String = session_id.to_string(); - let application_name = if let Some(ctx) = ctx { - ctx.console_application_name() - } else { - "auth_cancellation".to_string() - }; - async { let request = self .endpoint .get_path("get_endpoint_access_control") - .header(X_REQUEST_ID, &request_id) + .header(X_REQUEST_ID, ctx.session_id().to_string()) .header(AUTHORIZATION, format!("Bearer {}", &self.jwt)) - .query(&[("session_id", session_id)]) + .query(&[("session_id", ctx.session_id())]) .query(&[ - ("application_name", application_name.as_str()), - ("endpointish", user_info.endpoint.as_str()), - ("role", user_info.user.as_str()), + ("application_name", ctx.console_application_name().as_str()), + ("endpointish", endpoint.as_str()), + ("role", role.as_str()), ]) .build()?; debug!(url = request.url().as_str(), "sending http request"); let start = Instant::now(); - let response = match ctx { - Some(ctx) => { - let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Cplane); - let rsp = self.endpoint.execute(request).await; - drop(pause); - rsp? - } - None => self.endpoint.execute(request).await?, + let response = { + let _pause = ctx.latency_timer_pause_at(start, crate::metrics::Waiting::Cplane); + self.endpoint.execute(request).await? }; - info!(duration = ?start.elapsed(), "received http response"); + let body = match parse_body::(response).await { Ok(body) => body, // Error 404 is special: it's ok not to have a secret. @@ -180,7 +148,7 @@ impl NeonControlPlaneClient { async fn do_get_endpoint_jwks( &self, ctx: &RequestContext, - endpoint: EndpointId, + endpoint: &EndpointId, ) -> Result, GetEndpointJwksError> { if !self .caches @@ -313,225 +281,104 @@ impl NeonControlPlaneClient { impl super::ControlPlaneApi for NeonControlPlaneClient { #[tracing::instrument(skip_all)] - async fn get_role_secret( + async fn get_role_access_control( &self, ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result { - let normalized_ep = &user_info.endpoint.normalize(); - let user = &user_info.user; - if let Some(role_secret) = self + endpoint: &EndpointId, + role: &RoleName, + ) -> Result { + let normalized_ep = &endpoint.normalize(); + if let Some(secret) = self .caches .project_info - .get_role_secret(normalized_ep, user) + .get_role_secret(normalized_ep, role) { - return Ok(role_secret); + return Ok(secret); } - let auth_info = self.do_get_auth_info(ctx, user_info).await?; - let account_id = auth_info.account_id; + + if !self.caches.endpoints_cache.is_valid(ctx, normalized_ep) { + info!("endpoint is not valid, skipping the request"); + return Err(GetAuthInfoError::UnknownEndpoint); + } + + let auth_info = self.do_get_auth_req(ctx, endpoint, role).await?; + + let control = EndpointAccessControl { + allowed_ips: Arc::new(auth_info.allowed_ips), + allowed_vpce: Arc::new(auth_info.allowed_vpc_endpoint_ids), + flags: auth_info.access_blocker_flags, + }; + let role_control = RoleAccessControl { + secret: auth_info.secret, + }; + if let Some(project_id) = auth_info.project_id { let normalized_ep_int = normalized_ep.into(); - self.caches.project_info.insert_role_secret( + + self.caches.project_info.insert_endpoint_access( + auth_info.account_id, project_id, normalized_ep_int, - user.into(), - auth_info.secret.clone(), - ); - self.caches.project_info.insert_allowed_ips( - project_id, - normalized_ep_int, - Arc::new(auth_info.allowed_ips), - ); - self.caches.project_info.insert_allowed_vpc_endpoint_ids( - account_id, - project_id, - normalized_ep_int, - Arc::new(auth_info.allowed_vpc_endpoint_ids), - ); - self.caches.project_info.insert_block_public_or_vpc_access( - project_id, - normalized_ep_int, - auth_info.access_blocker_flags, + role.into(), + control, + role_control.clone(), ); ctx.set_project_id(project_id); } - // When we just got a secret, we don't need to invalidate it. - Ok(Cached::new_uncached(auth_info.secret)) + + Ok(role_control) } - async fn get_allowed_ips( + #[tracing::instrument(skip_all)] + async fn get_endpoint_access_control( &self, ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result { - let normalized_ep = &user_info.endpoint.normalize(); - if let Some(allowed_ips) = self.caches.project_info.get_allowed_ips(normalized_ep) { - Metrics::get() - .proxy - .allowed_ips_cache_misses // TODO SR: Should we rename this variable to something like allowed_ip_cache_stats? - .inc(CacheOutcome::Hit); - return Ok(allowed_ips); + endpoint: &EndpointId, + role: &RoleName, + ) -> Result { + let normalized_ep = &endpoint.normalize(); + if let Some(control) = self.caches.project_info.get_endpoint_access(normalized_ep) { + return Ok(control); } - Metrics::get() - .proxy - .allowed_ips_cache_misses - .inc(CacheOutcome::Miss); - let auth_info = self.do_get_auth_info(ctx, user_info).await?; - let allowed_ips = Arc::new(auth_info.allowed_ips); - let allowed_vpc_endpoint_ids = Arc::new(auth_info.allowed_vpc_endpoint_ids); - let access_blocker_flags = auth_info.access_blocker_flags; - let user = &user_info.user; - let account_id = auth_info.account_id; + + if !self.caches.endpoints_cache.is_valid(ctx, normalized_ep) { + info!("endpoint is not valid, skipping the request"); + return Err(GetAuthInfoError::UnknownEndpoint); + } + + let auth_info = self.do_get_auth_req(ctx, endpoint, role).await?; + + let control = EndpointAccessControl { + allowed_ips: Arc::new(auth_info.allowed_ips), + allowed_vpce: Arc::new(auth_info.allowed_vpc_endpoint_ids), + flags: auth_info.access_blocker_flags, + }; + let role_control = RoleAccessControl { + secret: auth_info.secret, + }; + if let Some(project_id) = auth_info.project_id { let normalized_ep_int = normalized_ep.into(); - self.caches.project_info.insert_role_secret( + + self.caches.project_info.insert_endpoint_access( + auth_info.account_id, project_id, normalized_ep_int, - user.into(), - auth_info.secret.clone(), - ); - self.caches.project_info.insert_allowed_ips( - project_id, - normalized_ep_int, - allowed_ips.clone(), - ); - self.caches.project_info.insert_allowed_vpc_endpoint_ids( - account_id, - project_id, - normalized_ep_int, - allowed_vpc_endpoint_ids.clone(), - ); - self.caches.project_info.insert_block_public_or_vpc_access( - project_id, - normalized_ep_int, - access_blocker_flags, + role.into(), + control.clone(), + role_control, ); ctx.set_project_id(project_id); } - Ok(Cached::new_uncached(allowed_ips)) - } - async fn get_allowed_vpc_endpoint_ids( - &self, - ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result { - let normalized_ep = &user_info.endpoint.normalize(); - if let Some(allowed_vpc_endpoint_ids) = self - .caches - .project_info - .get_allowed_vpc_endpoint_ids(normalized_ep) - { - Metrics::get() - .proxy - .vpc_endpoint_id_cache_stats - .inc(CacheOutcome::Hit); - return Ok(allowed_vpc_endpoint_ids); - } - - Metrics::get() - .proxy - .vpc_endpoint_id_cache_stats - .inc(CacheOutcome::Miss); - - let auth_info = self.do_get_auth_info(ctx, user_info).await?; - let allowed_ips = Arc::new(auth_info.allowed_ips); - let allowed_vpc_endpoint_ids = Arc::new(auth_info.allowed_vpc_endpoint_ids); - let access_blocker_flags = auth_info.access_blocker_flags; - let user = &user_info.user; - let account_id = auth_info.account_id; - if let Some(project_id) = auth_info.project_id { - let normalized_ep_int = normalized_ep.into(); - self.caches.project_info.insert_role_secret( - project_id, - normalized_ep_int, - user.into(), - auth_info.secret.clone(), - ); - self.caches.project_info.insert_allowed_ips( - project_id, - normalized_ep_int, - allowed_ips.clone(), - ); - self.caches.project_info.insert_allowed_vpc_endpoint_ids( - account_id, - project_id, - normalized_ep_int, - allowed_vpc_endpoint_ids.clone(), - ); - self.caches.project_info.insert_block_public_or_vpc_access( - project_id, - normalized_ep_int, - access_blocker_flags, - ); - ctx.set_project_id(project_id); - } - Ok(Cached::new_uncached(allowed_vpc_endpoint_ids)) - } - - async fn get_block_public_or_vpc_access( - &self, - ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result { - let normalized_ep = &user_info.endpoint.normalize(); - if let Some(access_blocker_flags) = self - .caches - .project_info - .get_block_public_or_vpc_access(normalized_ep) - { - Metrics::get() - .proxy - .access_blocker_flags_cache_stats - .inc(CacheOutcome::Hit); - return Ok(access_blocker_flags); - } - - Metrics::get() - .proxy - .access_blocker_flags_cache_stats - .inc(CacheOutcome::Miss); - - let auth_info = self.do_get_auth_info(ctx, user_info).await?; - let allowed_ips = Arc::new(auth_info.allowed_ips); - let allowed_vpc_endpoint_ids = Arc::new(auth_info.allowed_vpc_endpoint_ids); - let access_blocker_flags = auth_info.access_blocker_flags; - let user = &user_info.user; - let account_id = auth_info.account_id; - if let Some(project_id) = auth_info.project_id { - let normalized_ep_int = normalized_ep.into(); - self.caches.project_info.insert_role_secret( - project_id, - normalized_ep_int, - user.into(), - auth_info.secret.clone(), - ); - self.caches.project_info.insert_allowed_ips( - project_id, - normalized_ep_int, - allowed_ips.clone(), - ); - self.caches.project_info.insert_allowed_vpc_endpoint_ids( - account_id, - project_id, - normalized_ep_int, - allowed_vpc_endpoint_ids.clone(), - ); - self.caches.project_info.insert_block_public_or_vpc_access( - project_id, - normalized_ep_int, - access_blocker_flags.clone(), - ); - ctx.set_project_id(project_id); - } - Ok(Cached::new_uncached(access_blocker_flags)) + Ok(control) } #[tracing::instrument(skip_all)] async fn get_endpoint_jwks( &self, ctx: &RequestContext, - endpoint: EndpointId, + endpoint: &EndpointId, ) -> Result, GetEndpointJwksError> { self.do_get_endpoint_jwks(ctx, endpoint).await } diff --git a/proxy/src/control_plane/client/mock.rs b/proxy/src/control_plane/client/mock.rs index d3ab4abd0b..ece7153fce 100644 --- a/proxy/src/control_plane/client/mock.rs +++ b/proxy/src/control_plane/client/mock.rs @@ -15,14 +15,14 @@ use crate::auth::backend::ComputeUserInfo; use crate::auth::backend::jwt::AuthRule; use crate::cache::Cached; use crate::context::RequestContext; -use crate::control_plane::client::{ - CachedAllowedIps, CachedAllowedVpcEndpointIds, CachedRoleSecret, -}; use crate::control_plane::errors::{ ControlPlaneError, GetAuthInfoError, GetEndpointJwksError, WakeComputeError, }; use crate::control_plane::messages::MetricsAuxInfo; -use crate::control_plane::{AccessBlockerFlags, AuthInfo, AuthSecret, CachedNodeInfo, NodeInfo}; +use crate::control_plane::{ + AccessBlockerFlags, AuthInfo, AuthSecret, CachedNodeInfo, EndpointAccessControl, NodeInfo, + RoleAccessControl, +}; use crate::intern::RoleNameInt; use crate::types::{BranchId, EndpointId, ProjectId, RoleName}; use crate::url::ApiUrl; @@ -66,7 +66,8 @@ impl MockControlPlane { async fn do_get_auth_info( &self, - user_info: &ComputeUserInfo, + endpoint: &EndpointId, + role: &RoleName, ) -> Result { let (secret, allowed_ips) = async { // Perhaps we could persist this connection, but then we'd have to @@ -80,7 +81,7 @@ impl MockControlPlane { let secret = if let Some(entry) = get_execute_postgres_query( &client, "select rolpassword from pg_catalog.pg_authid where rolname = $1", - &[&&*user_info.user], + &[&role.as_str()], "rolpassword", ) .await? @@ -89,7 +90,7 @@ impl MockControlPlane { let secret = scram::ServerSecret::parse(&entry).map(AuthSecret::Scram); secret.or_else(|| parse_md5(&entry).map(AuthSecret::Md5)) } else { - warn!("user '{}' does not exist", user_info.user); + warn!("user '{role}' does not exist"); None }; @@ -97,7 +98,7 @@ impl MockControlPlane { match get_execute_postgres_query( &client, "select allowed_ips from neon_control_plane.endpoints where endpoint_id = $1", - &[&user_info.endpoint.as_str()], + &[&endpoint.as_str()], "allowed_ips", ) .await? @@ -133,7 +134,7 @@ impl MockControlPlane { async fn do_get_endpoint_jwks( &self, - endpoint: EndpointId, + endpoint: &EndpointId, ) -> Result, GetEndpointJwksError> { let (client, connection) = tokio_postgres::connect(self.endpoint.as_str(), tokio_postgres::NoTls).await?; @@ -222,53 +223,36 @@ async fn get_execute_postgres_query( } impl super::ControlPlaneApi for MockControlPlane { - #[tracing::instrument(skip_all)] - async fn get_role_secret( + async fn get_endpoint_access_control( &self, _ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result { - Ok(CachedRoleSecret::new_uncached( - self.do_get_auth_info(user_info).await?.secret, - )) + endpoint: &EndpointId, + role: &RoleName, + ) -> Result { + let info = self.do_get_auth_info(endpoint, role).await?; + Ok(EndpointAccessControl { + allowed_ips: Arc::new(info.allowed_ips), + allowed_vpce: Arc::new(info.allowed_vpc_endpoint_ids), + flags: info.access_blocker_flags, + }) } - async fn get_allowed_ips( + async fn get_role_access_control( &self, _ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result { - Ok(Cached::new_uncached(Arc::new( - self.do_get_auth_info(user_info).await?.allowed_ips, - ))) - } - - async fn get_allowed_vpc_endpoint_ids( - &self, - _ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result { - Ok(Cached::new_uncached(Arc::new( - self.do_get_auth_info(user_info) - .await? - .allowed_vpc_endpoint_ids, - ))) - } - - async fn get_block_public_or_vpc_access( - &self, - _ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result { - Ok(Cached::new_uncached( - self.do_get_auth_info(user_info).await?.access_blocker_flags, - )) + endpoint: &EndpointId, + role: &RoleName, + ) -> Result { + let info = self.do_get_auth_info(endpoint, role).await?; + Ok(RoleAccessControl { + secret: info.secret, + }) } async fn get_endpoint_jwks( &self, _ctx: &RequestContext, - endpoint: EndpointId, + endpoint: &EndpointId, ) -> Result, GetEndpointJwksError> { self.do_get_endpoint_jwks(endpoint).await } diff --git a/proxy/src/control_plane/client/mod.rs b/proxy/src/control_plane/client/mod.rs index 746595de38..9b9d1e25ea 100644 --- a/proxy/src/control_plane/client/mod.rs +++ b/proxy/src/control_plane/client/mod.rs @@ -16,15 +16,14 @@ use crate::cache::endpoints::EndpointsCache; use crate::cache::project_info::ProjectInfoCacheImpl; use crate::config::{CacheOptions, EndpointCacheConfig, ProjectInfoCacheOptions}; use crate::context::RequestContext; -use crate::control_plane::{ - CachedAccessBlockerFlags, CachedAllowedIps, CachedAllowedVpcEndpointIds, CachedNodeInfo, - CachedRoleSecret, ControlPlaneApi, NodeInfoCache, errors, -}; +use crate::control_plane::{CachedNodeInfo, ControlPlaneApi, NodeInfoCache, errors}; use crate::error::ReportableError; use crate::metrics::ApiLockMetrics; use crate::rate_limiter::{DynamicLimiter, Outcome, RateLimiterConfig, Token}; use crate::types::EndpointId; +use super::{EndpointAccessControl, RoleAccessControl}; + #[non_exhaustive] #[derive(Clone)] pub enum ControlPlaneClient { @@ -40,68 +39,42 @@ pub enum ControlPlaneClient { } impl ControlPlaneApi for ControlPlaneClient { - async fn get_role_secret( + async fn get_role_access_control( &self, ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result { + endpoint: &EndpointId, + role: &crate::types::RoleName, + ) -> Result { match self { - Self::ProxyV1(api) => api.get_role_secret(ctx, user_info).await, + Self::ProxyV1(api) => api.get_role_access_control(ctx, endpoint, role).await, #[cfg(any(test, feature = "testing"))] - Self::PostgresMock(api) => api.get_role_secret(ctx, user_info).await, + Self::PostgresMock(api) => api.get_role_access_control(ctx, endpoint, role).await, #[cfg(test)] - Self::Test(_) => { + Self::Test(_api) => { unreachable!("this function should never be called in the test backend") } } } - async fn get_allowed_ips( + async fn get_endpoint_access_control( &self, ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result { + endpoint: &EndpointId, + role: &crate::types::RoleName, + ) -> Result { match self { - Self::ProxyV1(api) => api.get_allowed_ips(ctx, user_info).await, + Self::ProxyV1(api) => api.get_endpoint_access_control(ctx, endpoint, role).await, #[cfg(any(test, feature = "testing"))] - Self::PostgresMock(api) => api.get_allowed_ips(ctx, user_info).await, + Self::PostgresMock(api) => api.get_endpoint_access_control(ctx, endpoint, role).await, #[cfg(test)] - Self::Test(api) => api.get_allowed_ips(), - } - } - - async fn get_allowed_vpc_endpoint_ids( - &self, - ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result { - match self { - Self::ProxyV1(api) => api.get_allowed_vpc_endpoint_ids(ctx, user_info).await, - #[cfg(any(test, feature = "testing"))] - Self::PostgresMock(api) => api.get_allowed_vpc_endpoint_ids(ctx, user_info).await, - #[cfg(test)] - Self::Test(api) => api.get_allowed_vpc_endpoint_ids(), - } - } - - async fn get_block_public_or_vpc_access( - &self, - ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result { - match self { - Self::ProxyV1(api) => api.get_block_public_or_vpc_access(ctx, user_info).await, - #[cfg(any(test, feature = "testing"))] - Self::PostgresMock(api) => api.get_block_public_or_vpc_access(ctx, user_info).await, - #[cfg(test)] - Self::Test(api) => api.get_block_public_or_vpc_access(), + Self::Test(api) => api.get_access_control(), } } async fn get_endpoint_jwks( &self, ctx: &RequestContext, - endpoint: EndpointId, + endpoint: &EndpointId, ) -> Result, errors::GetEndpointJwksError> { match self { Self::ProxyV1(api) => api.get_endpoint_jwks(ctx, endpoint).await, @@ -131,15 +104,7 @@ impl ControlPlaneApi for ControlPlaneClient { pub(crate) trait TestControlPlaneClient: Send + Sync + 'static { fn wake_compute(&self) -> Result; - fn get_allowed_ips(&self) -> Result; - - fn get_allowed_vpc_endpoint_ids( - &self, - ) -> Result; - - fn get_block_public_or_vpc_access( - &self, - ) -> Result; + fn get_access_control(&self) -> Result; fn dyn_clone(&self) -> Box; } @@ -309,7 +274,7 @@ impl FetchAuthRules for ControlPlaneClient { ctx: &RequestContext, endpoint: EndpointId, ) -> Result, FetchAuthRulesError> { - self.get_endpoint_jwks(ctx, endpoint) + self.get_endpoint_jwks(ctx, &endpoint) .await .map_err(FetchAuthRulesError::GetEndpointJwks) } diff --git a/proxy/src/control_plane/errors.rs b/proxy/src/control_plane/errors.rs index 850d061333..77312c89c5 100644 --- a/proxy/src/control_plane/errors.rs +++ b/proxy/src/control_plane/errors.rs @@ -99,6 +99,10 @@ pub(crate) enum GetAuthInfoError { #[error(transparent)] ApiError(ControlPlaneError), + + /// Proxy does not know about the endpoint in advanced + #[error("endpoint not found in endpoint cache")] + UnknownEndpoint, } // This allows more useful interactions than `#[from]`. @@ -115,6 +119,8 @@ impl UserFacingError for GetAuthInfoError { Self::BadSecret => REQUEST_FAILED.to_owned(), // However, API might return a meaningful error. Self::ApiError(e) => e.to_string_client(), + // pretend like control plane returned an error. + Self::UnknownEndpoint => REQUEST_FAILED.to_owned(), } } } @@ -124,6 +130,8 @@ impl ReportableError for GetAuthInfoError { match self { Self::BadSecret => crate::error::ErrorKind::ControlPlane, Self::ApiError(_) => crate::error::ErrorKind::ControlPlane, + // we only apply endpoint filtering if control plane is under high load. + Self::UnknownEndpoint => crate::error::ErrorKind::ServiceRateLimit, } } } diff --git a/proxy/src/control_plane/mod.rs b/proxy/src/control_plane/mod.rs index d592223be1..7ff093d9dc 100644 --- a/proxy/src/control_plane/mod.rs +++ b/proxy/src/control_plane/mod.rs @@ -11,16 +11,16 @@ pub(crate) mod errors; use std::sync::Arc; -use crate::auth::IpPattern; use crate::auth::backend::jwt::AuthRule; use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo}; -use crate::cache::project_info::ProjectInfoCacheImpl; +use crate::auth::{AuthError, IpPattern, check_peer_addr_is_in_list}; use crate::cache::{Cached, TimedLru}; use crate::config::ComputeConfig; use crate::context::RequestContext; use crate::control_plane::messages::{ControlPlaneErrorMessage, MetricsAuxInfo}; use crate::intern::{AccountIdInt, ProjectIdInt}; -use crate::types::{EndpointCacheKey, EndpointId}; +use crate::protocol2::ConnectionInfoExtra; +use crate::types::{EndpointCacheKey, EndpointId, RoleName}; use crate::{compute, scram}; /// Various cache-related types. @@ -101,7 +101,7 @@ impl NodeInfo { } } -#[derive(Clone, Default, Eq, PartialEq, Debug)] +#[derive(Copy, Clone, Default)] pub(crate) struct AccessBlockerFlags { pub public_access_blocked: bool, pub vpc_access_blocked: bool, @@ -110,47 +110,78 @@ pub(crate) struct AccessBlockerFlags { pub(crate) type NodeInfoCache = TimedLru>>; pub(crate) type CachedNodeInfo = Cached<&'static NodeInfoCache, NodeInfo>; -pub(crate) type CachedRoleSecret = Cached<&'static ProjectInfoCacheImpl, Option>; -pub(crate) type CachedAllowedIps = Cached<&'static ProjectInfoCacheImpl, Arc>>; -pub(crate) type CachedAllowedVpcEndpointIds = - Cached<&'static ProjectInfoCacheImpl, Arc>>; -pub(crate) type CachedAccessBlockerFlags = - Cached<&'static ProjectInfoCacheImpl, AccessBlockerFlags>; + +#[derive(Clone)] +pub struct RoleAccessControl { + pub secret: Option, +} + +#[derive(Clone)] +pub struct EndpointAccessControl { + pub allowed_ips: Arc>, + pub allowed_vpce: Arc>, + pub flags: AccessBlockerFlags, +} + +impl EndpointAccessControl { + pub fn check( + &self, + ctx: &RequestContext, + check_ip_allowed: bool, + check_vpc_allowed: bool, + ) -> Result<(), AuthError> { + if check_ip_allowed && !check_peer_addr_is_in_list(&ctx.peer_addr(), &self.allowed_ips) { + return Err(AuthError::IpAddressNotAllowed(ctx.peer_addr())); + } + + // check if a VPC endpoint ID is coming in and if yes, if it's allowed + if check_vpc_allowed { + if self.flags.vpc_access_blocked { + return Err(AuthError::NetworkNotAllowed); + } + + let incoming_vpc_endpoint_id = match ctx.extra() { + None => return Err(AuthError::MissingVPCEndpointId), + Some(ConnectionInfoExtra::Aws { vpce_id }) => vpce_id.to_string(), + Some(ConnectionInfoExtra::Azure { link_id }) => link_id.to_string(), + }; + + let vpce = &self.allowed_vpce; + // TODO: For now an empty VPC endpoint ID list means all are allowed. We should replace that. + if !vpce.is_empty() && !vpce.contains(&incoming_vpc_endpoint_id) { + return Err(AuthError::vpc_endpoint_id_not_allowed( + incoming_vpc_endpoint_id, + )); + } + } else if self.flags.public_access_blocked { + return Err(AuthError::NetworkNotAllowed); + } + + Ok(()) + } +} /// This will allocate per each call, but the http requests alone /// already require a few allocations, so it should be fine. pub(crate) trait ControlPlaneApi { - /// Get the client's auth secret for authentication. - /// Returns option because user not found situation is special. - /// We still have to mock the scram to avoid leaking information that user doesn't exist. - async fn get_role_secret( + async fn get_role_access_control( &self, ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result; + endpoint: &EndpointId, + role: &RoleName, + ) -> Result; - async fn get_allowed_ips( + async fn get_endpoint_access_control( &self, ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result; - - async fn get_allowed_vpc_endpoint_ids( - &self, - ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result; - - async fn get_block_public_or_vpc_access( - &self, - ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result; + endpoint: &EndpointId, + role: &RoleName, + ) -> Result; async fn get_endpoint_jwks( &self, ctx: &RequestContext, - endpoint: EndpointId, + endpoint: &EndpointId, ) -> Result, errors::GetEndpointJwksError>; /// Wake up the compute node and return the corresponding connection info. diff --git a/proxy/src/proxy/mod.rs b/proxy/src/proxy/mod.rs index 26ac6a89e7..ac0aca1176 100644 --- a/proxy/src/proxy/mod.rs +++ b/proxy/src/proxy/mod.rs @@ -345,7 +345,7 @@ pub(crate) async fn handle_client( }; let user = user_info.get_user().to_owned(); - let (user_info, _ip_allowlist) = match user_info + let user_info = match user_info .authenticate( ctx, &mut stream, diff --git a/proxy/src/proxy/tests/mod.rs b/proxy/src/proxy/tests/mod.rs index 3cc053e0ad..61e8ee4a10 100644 --- a/proxy/src/proxy/tests/mod.rs +++ b/proxy/src/proxy/tests/mod.rs @@ -26,9 +26,7 @@ use crate::auth::backend::{ use crate::config::{ComputeConfig, RetryConfig}; use crate::control_plane::client::{ControlPlaneClient, TestControlPlaneClient}; use crate::control_plane::messages::{ControlPlaneErrorMessage, Details, MetricsAuxInfo, Status}; -use crate::control_plane::{ - self, CachedAllowedIps, CachedAllowedVpcEndpointIds, CachedNodeInfo, NodeInfo, NodeInfoCache, -}; +use crate::control_plane::{self, CachedNodeInfo, NodeInfo, NodeInfoCache}; use crate::error::ErrorKind; use crate::tls::client_config::compute_client_config_with_certs; use crate::tls::postgres_rustls::MakeRustlsConnect; @@ -547,20 +545,9 @@ impl TestControlPlaneClient for TestConnectMechanism { } } - fn get_allowed_ips(&self) -> Result { - unimplemented!("not used in tests") - } - - fn get_allowed_vpc_endpoint_ids( + fn get_access_control( &self, - ) -> Result { - unimplemented!("not used in tests") - } - - fn get_block_public_or_vpc_access( - &self, - ) -> Result - { + ) -> Result { unimplemented!("not used in tests") } diff --git a/proxy/src/rate_limiter/leaky_bucket.rs b/proxy/src/rate_limiter/leaky_bucket.rs index 4f27c6faef..0c79b5e92f 100644 --- a/proxy/src/rate_limiter/leaky_bucket.rs +++ b/proxy/src/rate_limiter/leaky_bucket.rs @@ -15,7 +15,7 @@ pub type EndpointRateLimiter = LeakyBucketRateLimiter; pub struct LeakyBucketRateLimiter { map: ClashMap, - config: utils::leaky_bucket::LeakyBucketConfig, + default_config: utils::leaky_bucket::LeakyBucketConfig, access_count: AtomicUsize, } @@ -28,15 +28,17 @@ impl LeakyBucketRateLimiter { pub fn new_with_shards(config: LeakyBucketConfig, shards: usize) -> Self { Self { map: ClashMap::with_hasher_and_shard_amount(RandomState::new(), shards), - config: config.into(), + default_config: config.into(), access_count: AtomicUsize::new(0), } } /// Check that number of connections to the endpoint is below `max_rps` rps. - pub(crate) fn check(&self, key: K, n: u32) -> bool { + pub(crate) fn check(&self, key: K, config: Option, n: u32) -> bool { let now = Instant::now(); + let config = config.map_or(self.default_config, Into::into); + if self.access_count.fetch_add(1, Ordering::AcqRel) % 2048 == 0 { self.do_gc(now); } @@ -46,7 +48,7 @@ impl LeakyBucketRateLimiter { .entry(key) .or_insert_with(|| LeakyBucketState { empty_at: now }); - entry.add_tokens(&self.config, now, n as f64).is_ok() + entry.add_tokens(&config, now, n as f64).is_ok() } fn do_gc(&self, now: Instant) { diff --git a/proxy/src/rate_limiter/limiter.rs b/proxy/src/rate_limiter/limiter.rs index 21eaa6739b..9d700c1b52 100644 --- a/proxy/src/rate_limiter/limiter.rs +++ b/proxy/src/rate_limiter/limiter.rs @@ -15,6 +15,8 @@ use tracing::info; use crate::ext::LockExt; use crate::intern::EndpointIdInt; +use super::LeakyBucketConfig; + pub struct GlobalRateLimiter { data: Vec, info: Vec, @@ -144,19 +146,6 @@ impl RateBucketInfo { Self::new(50_000, Duration::from_secs(10)), ]; - /// All of these are per endpoint-maskedip pair. - /// Context: 4096 rounds of pbkdf2 take about 1ms of cpu time to execute (1 milli-cpu-second or 1mcpus). - /// - /// First bucket: 1000mcpus total per endpoint-ip pair - /// * 4096000 requests per second with 1 hash rounds. - /// * 1000 requests per second with 4096 hash rounds. - /// * 6.8 requests per second with 600000 hash rounds. - pub const DEFAULT_AUTH_SET: [Self; 3] = [ - Self::new(1000 * 4096, Duration::from_secs(1)), - Self::new(600 * 4096, Duration::from_secs(60)), - Self::new(300 * 4096, Duration::from_secs(600)), - ]; - pub fn rps(&self) -> f64 { (self.max_rpi as f64) / self.interval.as_secs_f64() } @@ -184,6 +173,21 @@ impl RateBucketInfo { max_rpi: ((max_rps as u64) * (interval.as_millis() as u64) / 1000) as u32, } } + + pub fn to_leaky_bucket(this: &[Self]) -> Option { + // bit of a hack - find the min rps and max rps supported and turn it into + // leaky bucket config instead + + let mut iter = this.iter().map(|info| info.rps()); + let first = iter.next()?; + + let (min, max) = (first, first); + let (min, max) = iter.fold((min, max), |(min, max), rps| { + (f64::min(min, rps), f64::max(max, rps)) + }); + + Some(LeakyBucketConfig { rps: min, max }) + } } impl BucketRateLimiter { diff --git a/proxy/src/rate_limiter/mod.rs b/proxy/src/rate_limiter/mod.rs index 5f90102da3..112b95873a 100644 --- a/proxy/src/rate_limiter/mod.rs +++ b/proxy/src/rate_limiter/mod.rs @@ -8,4 +8,4 @@ pub(crate) use limit_algorithm::aimd::Aimd; pub(crate) use limit_algorithm::{ DynamicLimiter, Outcome, RateLimitAlgorithm, RateLimiterConfig, Token, }; -pub use limiter::{BucketRateLimiter, GlobalRateLimiter, RateBucketInfo, WakeComputeRateLimiter}; +pub use limiter::{GlobalRateLimiter, RateBucketInfo, WakeComputeRateLimiter}; diff --git a/proxy/src/redis/notifications.rs b/proxy/src/redis/notifications.rs index 769d519d94..a9d6b40603 100644 --- a/proxy/src/redis/notifications.rs +++ b/proxy/src/redis/notifications.rs @@ -233,29 +233,30 @@ impl MessageHandler { fn invalidate_cache(cache: Arc, msg: Notification) { match msg { - Notification::AllowedIpsUpdate { allowed_ips_update } => { - cache.invalidate_allowed_ips_for_project(allowed_ips_update.project_id); + Notification::AllowedIpsUpdate { + allowed_ips_update: AllowedIpsUpdate { project_id }, } - Notification::BlockPublicOrVpcAccessUpdated { - block_public_or_vpc_access_updated, - } => cache.invalidate_block_public_or_vpc_access_for_project( - block_public_or_vpc_access_updated.project_id, - ), + | Notification::BlockPublicOrVpcAccessUpdated { + block_public_or_vpc_access_updated: BlockPublicOrVpcAccessUpdated { project_id }, + } => cache.invalidate_endpoint_access_for_project(project_id), Notification::AllowedVpcEndpointsUpdatedForOrg { - allowed_vpc_endpoints_updated_for_org, - } => cache.invalidate_allowed_vpc_endpoint_ids_for_org( - allowed_vpc_endpoints_updated_for_org.account_id, - ), + allowed_vpc_endpoints_updated_for_org: AllowedVpcEndpointsUpdatedForOrg { account_id }, + } => cache.invalidate_endpoint_access_for_org(account_id), Notification::AllowedVpcEndpointsUpdatedForProjects { - allowed_vpc_endpoints_updated_for_projects, - } => cache.invalidate_allowed_vpc_endpoint_ids_for_projects( - allowed_vpc_endpoints_updated_for_projects.project_ids, - ), - Notification::PasswordUpdate { password_update } => cache - .invalidate_role_secret_for_project( - password_update.project_id, - password_update.role_name, - ), + allowed_vpc_endpoints_updated_for_projects: + AllowedVpcEndpointsUpdatedForProjects { project_ids }, + } => { + for project in project_ids { + cache.invalidate_endpoint_access_for_project(project); + } + } + Notification::PasswordUpdate { + password_update: + PasswordUpdate { + project_id, + role_name, + }, + } => cache.invalidate_role_secret_for_project(project_id, role_name), Notification::UnknownTopic => unreachable!(), } } diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index 13058f08f1..bf640c05e9 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -22,7 +22,7 @@ use super::http_conn_pool::{self, HttpConnPool, Send, poll_http2_client}; use super::local_conn_pool::{self, EXT_NAME, EXT_SCHEMA, EXT_VERSION, LocalConnPool}; use crate::auth::backend::local::StaticAuthRules; use crate::auth::backend::{ComputeCredentials, ComputeUserInfo}; -use crate::auth::{self, AuthError, check_peer_addr_is_in_list}; +use crate::auth::{self, AuthError}; use crate::compute; use crate::compute_ctl::{ ComputeCtlError, ExtensionInstallRequest, Privilege, SetRoleGrantsRequest, @@ -35,7 +35,6 @@ use crate::control_plane::errors::{GetAuthInfoError, WakeComputeError}; use crate::control_plane::locks::ApiLocks; use crate::error::{ErrorKind, ReportableError, UserFacingError}; use crate::intern::EndpointIdInt; -use crate::protocol2::ConnectionInfoExtra; use crate::proxy::connect_compute::ConnectMechanism; use crate::proxy::retry::{CouldRetry, ShouldRetryWakeCompute}; use crate::rate_limiter::EndpointRateLimiter; @@ -63,63 +62,24 @@ impl PoolingBackend { let user_info = user_info.clone(); let backend = self.auth_backend.as_ref().map(|()| user_info.clone()); - let allowed_ips = backend.get_allowed_ips(ctx).await?; + let access_control = backend.get_endpoint_access_control(ctx).await?; + access_control.check( + ctx, + self.config.authentication_config.ip_allowlist_check_enabled, + self.config.authentication_config.is_vpc_acccess_proxy, + )?; - if self.config.authentication_config.ip_allowlist_check_enabled - && !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips) - { - return Err(AuthError::ip_address_not_allowed(ctx.peer_addr())); - } - - let access_blocker_flags = backend.get_block_public_or_vpc_access(ctx).await?; - if self.config.authentication_config.is_vpc_acccess_proxy { - if access_blocker_flags.vpc_access_blocked { - return Err(AuthError::NetworkNotAllowed); - } - - let extra = ctx.extra(); - let incoming_endpoint_id = match extra { - None => String::new(), - Some(ConnectionInfoExtra::Aws { vpce_id }) => vpce_id.to_string(), - Some(ConnectionInfoExtra::Azure { link_id }) => link_id.to_string(), - }; - - if incoming_endpoint_id.is_empty() { - return Err(AuthError::MissingVPCEndpointId); - } - - let allowed_vpc_endpoint_ids = backend.get_allowed_vpc_endpoint_ids(ctx).await?; - // TODO: For now an empty VPC endpoint ID list means all are allowed. We should replace that. - if !allowed_vpc_endpoint_ids.is_empty() - && !allowed_vpc_endpoint_ids.contains(&incoming_endpoint_id) - { - return Err(AuthError::vpc_endpoint_id_not_allowed(incoming_endpoint_id)); - } - } else if access_blocker_flags.public_access_blocked { - return Err(AuthError::NetworkNotAllowed); - } - - if !self - .endpoint_rate_limiter - .check(user_info.endpoint.clone().into(), 1) - { + let ep = EndpointIdInt::from(&user_info.endpoint); + let rate_limit_config = None; + if !self.endpoint_rate_limiter.check(ep, rate_limit_config, 1) { return Err(AuthError::too_many_connections()); } - let cached_secret = backend.get_role_secret(ctx).await?; - let secret = match cached_secret.value.clone() { - Some(secret) => self.config.authentication_config.check_rate_limit( - ctx, - secret, - &user_info.endpoint, - true, - )?, - None => { - // If we don't have an authentication secret, for the http flow we can just return an error. - info!("authentication info not found"); - return Err(AuthError::password_failed(&*user_info.user)); - } + let role_access = backend.get_role_secret(ctx).await?; + let Some(secret) = role_access.secret else { + // If we don't have an authentication secret, for the http flow we can just return an error. + info!("authentication info not found"); + return Err(AuthError::password_failed(&*user_info.user)); }; - let ep = EndpointIdInt::from(&user_info.endpoint); let auth_outcome = crate::auth::validate_password_and_exchange( &self.config.authentication_config.thread_pool, ep, From af5bb67f08e92c018111055dcef26d9e3bab665a Mon Sep 17 00:00:00 2001 From: Vlad Lazar Date: Mon, 2 Jun 2025 09:59:21 +0100 Subject: [PATCH 09/43] pageserver: more reactive wal receiver cancellation (#12076) ## Problem If the wal receiver is cancelled, there's a 50% chance that it will ingest yet more WAL. ## Summary of Changes Always check cancellation first. --- pageserver/src/tenant/timeline/walreceiver.rs | 2 +- .../src/tenant/timeline/walreceiver/walreceiver_connection.rs | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/pageserver/src/tenant/timeline/walreceiver.rs b/pageserver/src/tenant/timeline/walreceiver.rs index 0f73eb839b..633c94a010 100644 --- a/pageserver/src/tenant/timeline/walreceiver.rs +++ b/pageserver/src/tenant/timeline/walreceiver.rs @@ -113,7 +113,7 @@ impl WalReceiver { } connection_manager_state.shutdown().await; *loop_status.write().unwrap() = None; - debug!("task exits"); + info!("task exits"); } .instrument(info_span!(parent: None, "wal_connection_manager", tenant_id = %tenant_shard_id.tenant_id, shard_id = %tenant_shard_id.shard_slug(), timeline_id = %timeline_id)) }); diff --git a/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs b/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs index 52259f205b..249849ac4b 100644 --- a/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs +++ b/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs @@ -297,6 +297,7 @@ pub(super) async fn handle_walreceiver_connection( let mut expected_wal_start = startpoint; while let Some(replication_message) = { select! { + biased; _ = cancellation.cancelled() => { debug!("walreceiver interrupted"); None From 5b62749c42c256db001eb3e59d8d289b4ad942cd Mon Sep 17 00:00:00 2001 From: Vlad Lazar Date: Mon, 2 Jun 2025 11:29:15 +0100 Subject: [PATCH 10/43] pageserver: reduce import memory utilization (#12086) ## Problem Imports can end up allocating too much. ## Summary of Changes Nerf them a bunch and add some logs. --- libs/pageserver_api/src/config.rs | 6 +++--- pageserver/src/tenant/timeline/import_pgdata.rs | 2 ++ pageserver/src/tenant/timeline/import_pgdata/flow.rs | 12 +++++++++++- 3 files changed, 16 insertions(+), 4 deletions(-) diff --git a/libs/pageserver_api/src/config.rs b/libs/pageserver_api/src/config.rs index 012c020fb1..444983bd18 100644 --- a/libs/pageserver_api/src/config.rs +++ b/libs/pageserver_api/src/config.rs @@ -713,9 +713,9 @@ impl Default for ConfigToml { enable_tls_page_service_api: false, dev_mode: false, timeline_import_config: TimelineImportConfig { - import_job_concurrency: NonZeroUsize::new(128).unwrap(), - import_job_soft_size_limit: NonZeroUsize::new(1024 * 1024 * 1024).unwrap(), - import_job_checkpoint_threshold: NonZeroUsize::new(128).unwrap(), + import_job_concurrency: NonZeroUsize::new(32).unwrap(), + import_job_soft_size_limit: NonZeroUsize::new(256 * 1024 * 1024).unwrap(), + import_job_checkpoint_threshold: NonZeroUsize::new(32).unwrap(), }, basebackup_cache_config: None, posthog_config: None, diff --git a/pageserver/src/tenant/timeline/import_pgdata.rs b/pageserver/src/tenant/timeline/import_pgdata.rs index f19a4b3e9c..3f760d858b 100644 --- a/pageserver/src/tenant/timeline/import_pgdata.rs +++ b/pageserver/src/tenant/timeline/import_pgdata.rs @@ -106,6 +106,8 @@ pub async fn doit( ); } + tracing::info!("Import plan executed. Flushing remote changes and notifying storcon"); + timeline .remote_client .schedule_index_upload_for_file_changes()?; diff --git a/pageserver/src/tenant/timeline/import_pgdata/flow.rs b/pageserver/src/tenant/timeline/import_pgdata/flow.rs index bf3c7eeda6..760e82dd57 100644 --- a/pageserver/src/tenant/timeline/import_pgdata/flow.rs +++ b/pageserver/src/tenant/timeline/import_pgdata/flow.rs @@ -130,7 +130,15 @@ async fn run_v1( pausable_failpoint!("import-timeline-pre-execute-pausable"); + let jobs_count = import_progress.as_ref().map(|p| p.jobs); let start_from_job_idx = import_progress.map(|progress| progress.completed); + + tracing::info!( + start_from_job_idx=?start_from_job_idx, + jobs=?jobs_count, + "Executing import plan" + ); + plan.execute(timeline, start_from_job_idx, plan_hash, &import_config, ctx) .await } @@ -484,6 +492,8 @@ impl Plan { anyhow::anyhow!("Shut down while putting timeline import status") })?; } + + tracing::info!(last_completed_job_idx, jobs=%jobs_in_plan, "Checkpointing import status"); }, Some(Err(_)) => { anyhow::bail!( @@ -760,7 +770,7 @@ impl ImportTask for ImportRelBlocksTask { layer_writer: &mut ImageLayerWriter, ctx: &RequestContext, ) -> anyhow::Result { - const MAX_BYTE_RANGE_SIZE: usize = 128 * 1024 * 1024; + const MAX_BYTE_RANGE_SIZE: usize = 4 * 1024 * 1024; debug!("Importing relation file"); From 8d7ed2a4ee1e2753d8a3ac17c6bc43ccabc2ed2e Mon Sep 17 00:00:00 2001 From: Erik Grinaker Date: Mon, 2 Jun 2025 13:46:50 +0200 Subject: [PATCH 11/43] pageserver: add gRPC observability middleware (#12093) ## Problem The page service logic asserts that a tracing span is present with tenant/timeline/shard IDs. An initial gRPC page service implementation thus requires a tracing span. Touches https://github.com/neondatabase/neon/issues/11728. ## Summary of changes Adds an `ObservabilityLayer` middleware that generates a tracing span and decorates it with IDs from the gRPC metadata. This is a minimal implementation to address the tracing span assertion. It will be extended with additional observability in later PRs. --- Cargo.lock | 2 + libs/utils/src/lib.rs | 1 + libs/utils/src/span.rs | 19 +++++ pageserver/Cargo.toml | 2 + pageserver/src/page_service.rs | 137 ++++++++++++++++++++++++++------- 5 files changed, 134 insertions(+), 27 deletions(-) create mode 100644 libs/utils/src/span.rs diff --git a/Cargo.lock b/Cargo.lock index 89351432c1..4f7378e95d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4305,6 +4305,7 @@ dependencies = [ "hashlink", "hex", "hex-literal", + "http 1.1.0", "http-utils", "humantime", "humantime-serde", @@ -4367,6 +4368,7 @@ dependencies = [ "toml_edit", "tonic 0.13.1", "tonic-reflection", + "tower 0.5.2", "tracing", "tracing-utils", "twox-hash", diff --git a/libs/utils/src/lib.rs b/libs/utils/src/lib.rs index 206b8bbd8f..11f787562c 100644 --- a/libs/utils/src/lib.rs +++ b/libs/utils/src/lib.rs @@ -73,6 +73,7 @@ pub mod error; /// async timeout helper pub mod timeout; +pub mod span; pub mod sync; pub mod failpoint_support; diff --git a/libs/utils/src/span.rs b/libs/utils/src/span.rs new file mode 100644 index 0000000000..4dbc99044b --- /dev/null +++ b/libs/utils/src/span.rs @@ -0,0 +1,19 @@ +//! Tracing span helpers. + +/// Records the given fields in the current span, as a single call. The fields must already have +/// been declared for the span (typically with empty values). +#[macro_export] +macro_rules! span_record { + ($($tokens:tt)*) => {$crate::span_record_in!(::tracing::Span::current(), $($tokens)*)}; +} + +/// Records the given fields in the given span, as a single call. The fields must already have been +/// declared for the span (typically with empty values). +#[macro_export] +macro_rules! span_record_in { + ($span:expr, $($tokens:tt)*) => { + if let Some(meta) = $span.metadata() { + $span.record_all(&tracing::valueset!(meta.fields(), $($tokens)*)); + } + }; +} diff --git a/pageserver/Cargo.toml b/pageserver/Cargo.toml index c4d6d58945..9591c729e8 100644 --- a/pageserver/Cargo.toml +++ b/pageserver/Cargo.toml @@ -34,6 +34,7 @@ fail.workspace = true futures.workspace = true hashlink.workspace = true hex.workspace = true +http.workspace = true http-utils.workspace = true humantime-serde.workspace = true humantime.workspace = true @@ -93,6 +94,7 @@ tokio-util.workspace = true toml_edit = { workspace = true, features = [ "serde" ] } tonic.workspace = true tonic-reflection.workspace = true +tower.workspace = true tracing.workspace = true tracing-utils.workspace = true url.workspace = true diff --git a/pageserver/src/page_service.rs b/pageserver/src/page_service.rs index e96787e027..f011ed49d0 100644 --- a/pageserver/src/page_service.rs +++ b/pageserver/src/page_service.rs @@ -7,12 +7,14 @@ use std::os::fd::AsRawFd; use std::pin::Pin; use std::str::FromStr; use std::sync::Arc; +use std::task::{Context, Poll}; use std::time::{Duration, Instant, SystemTime}; use std::{io, str}; -use anyhow::{Context, bail}; +use anyhow::{Context as _, bail}; use async_compression::tokio::write::GzipEncoder; use bytes::Buf; +use futures::future::BoxFuture; use futures::{FutureExt, Stream}; use itertools::Itertools; use jsonwebtoken::TokenData; @@ -46,7 +48,6 @@ use tokio_util::sync::CancellationToken; use tonic::service::Interceptor as _; use tracing::*; use utils::auth::{Claims, Scope, SwappableJwtAuth}; -use utils::failpoint_support; use utils::id::{TenantId, TenantTimelineId, TimelineId}; use utils::logging::log_slow; use utils::lsn::Lsn; @@ -54,6 +55,7 @@ use utils::shard::ShardIndex; use utils::simple_rcu::RcuReadGuard; use utils::sync::gate::{Gate, GateGuard}; use utils::sync::spsc_fold; +use utils::{failpoint_support, span_record}; use crate::auth::check_permission; use crate::basebackup::{self, BasebackupError}; @@ -195,13 +197,17 @@ pub fn spawn_grpc( // Set up the gRPC server. // // TODO: consider tuning window sizes. - // TODO: wire up tracing. let mut server = tonic::transport::Server::builder() .http2_keepalive_interval(Some(GRPC_HTTP2_KEEPALIVE_INTERVAL)) .http2_keepalive_timeout(Some(GRPC_HTTP2_KEEPALIVE_TIMEOUT)) .max_concurrent_streams(Some(GRPC_MAX_CONCURRENT_STREAMS)); - // Main page service. + // Main page service stack. Uses a mix of Tonic interceptors and Tower layers: + // + // * Interceptors: can inspect and modify the gRPC request. Sync code only, runs before service. + // + // * Layers: allow async code, can run code after the service response. However, only has access + // to the raw HTTP request/response, not the gRPC types. let page_service_handler = PageServerHandler::new( tenant_manager, auth.clone(), @@ -214,16 +220,22 @@ pub fn spawn_grpc( gate.enter().expect("just created"), ); + let observability_layer = ObservabilityLayer; let mut tenant_interceptor = TenantMetadataInterceptor; let mut auth_interceptor = TenantAuthInterceptor::new(auth); - let interceptors = move |mut req: tonic::Request<()>| { - req = tenant_interceptor.call(req)?; - req = auth_interceptor.call(req)?; - Ok(req) - }; - let page_service = - proto::PageServiceServer::with_interceptor(page_service_handler, interceptors); + let page_service = tower::ServiceBuilder::new() + // Create tracing span. + .layer(observability_layer) + // Intercept gRPC requests. + .layer(tonic::service::InterceptorLayer::new(move |mut req| { + // Extract tenant metadata. + req = tenant_interceptor.call(req)?; + // Authenticate tenant JWT token. + req = auth_interceptor.call(req)?; + Ok(req) + })) + .service(proto::PageServiceServer::new(page_service_handler)); let server = server.add_service(page_service); // Reflection service for use with e.g. grpcurl. @@ -3311,6 +3323,7 @@ impl proto::PageService for PageServerHandler { type GetPagesStream = Pin> + Send>>; + #[instrument(skip_all)] async fn check_rel_exists( &self, _: tonic::Request, @@ -3318,6 +3331,7 @@ impl proto::PageService for PageServerHandler { Err(tonic::Status::unimplemented("not implemented")) } + #[instrument(skip_all)] async fn get_base_backup( &self, _: tonic::Request, @@ -3325,6 +3339,7 @@ impl proto::PageService for PageServerHandler { Err(tonic::Status::unimplemented("not implemented")) } + #[instrument(skip_all)] async fn get_db_size( &self, _: tonic::Request, @@ -3332,6 +3347,7 @@ impl proto::PageService for PageServerHandler { Err(tonic::Status::unimplemented("not implemented")) } + // NB: don't instrument this, instrument each streamed request. async fn get_pages( &self, _: tonic::Request>, @@ -3339,6 +3355,7 @@ impl proto::PageService for PageServerHandler { Err(tonic::Status::unimplemented("not implemented")) } + #[instrument(skip_all)] async fn get_rel_size( &self, _: tonic::Request, @@ -3346,6 +3363,7 @@ impl proto::PageService for PageServerHandler { Err(tonic::Status::unimplemented("not implemented")) } + #[instrument(skip_all)] async fn get_slru_segment( &self, _: tonic::Request, @@ -3354,19 +3372,65 @@ impl proto::PageService for PageServerHandler { } } -impl From for QueryError { - fn from(e: GetActiveTenantError) -> Self { - match e { - GetActiveTenantError::WaitForActiveTimeout { .. } => QueryError::Disconnected( - ConnectionError::Io(io::Error::new(io::ErrorKind::TimedOut, e.to_string())), - ), - GetActiveTenantError::Cancelled - | GetActiveTenantError::WillNotBecomeActive(TenantState::Stopping { .. }) => { - QueryError::Shutdown - } - e @ GetActiveTenantError::NotFound(_) => QueryError::NotFound(format!("{e}").into()), - e => QueryError::Other(anyhow::anyhow!(e)), - } +/// gRPC middleware layer that handles observability concerns: +/// +/// * Creates and enters a tracing span. +/// +/// TODO: add perf tracing. +/// TODO: add timing and metrics. +/// TODO: add logging. +#[derive(Clone)] +struct ObservabilityLayer; + +impl tower::Layer for ObservabilityLayer { + type Service = ObservabilityLayerService; + + fn layer(&self, inner: S) -> Self::Service { + Self::Service { inner } + } +} + +#[derive(Clone)] +struct ObservabilityLayerService { + inner: S, +} + +impl tonic::server::NamedService for ObservabilityLayerService { + const NAME: &'static str = S::NAME; // propagate inner service name +} + +impl tower::Service> for ObservabilityLayerService +where + S: tower::Service>, + S::Future: Send + 'static, +{ + type Response = S::Response; + type Error = S::Error; + type Future = BoxFuture<'static, Result>; + + fn call(&mut self, req: http::Request) -> Self::Future { + // Create a basic tracing span. Enter the span for the current thread (to use it for inner + // sync code like interceptors), and instrument the future (to use it for inner async code + // like the page service itself). + // + // The instrument() call below is not sufficient. It only affects the returned future, and + // only takes effect when the caller polls it. Any sync code executed when we call + // self.inner.call() below (such as interceptors) runs outside of the returned future, and + // is not affected by it. We therefore have to enter the span on the current thread too. + let span = info_span!( + "grpc:pageservice", + // Set by TenantMetadataInterceptor. + tenant_id = field::Empty, + timeline_id = field::Empty, + shard_id = field::Empty, + ); + let _guard = span.enter(); + + Box::pin(self.inner.call(req).instrument(span.clone())) + } + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) } } @@ -3400,19 +3464,22 @@ impl tonic::service::Interceptor for TenantMetadataInterceptor { .map_err(|_| tonic::Status::invalid_argument("invalid neon-timeline-id"))?; // Decode the shard ID. - let shard_index = req + let shard_id = req .metadata() .get("neon-shard-id") .ok_or_else(|| tonic::Status::invalid_argument("missing neon-shard-id"))? .to_str() .map_err(|_| tonic::Status::invalid_argument("invalid neon-shard-id"))?; - let shard_index = ShardIndex::from_str(shard_index) + let shard_id = ShardIndex::from_str(shard_id) .map_err(|_| tonic::Status::invalid_argument("invalid neon-shard-id"))?; // Stash them in the request. let extensions = req.extensions_mut(); extensions.insert(TenantTimelineId::new(tenant_id, timeline_id)); - extensions.insert(shard_index); + extensions.insert(shard_id); + + // Decorate the tracing span. + span_record!(%tenant_id, %timeline_id, %shard_id); Ok(req) } @@ -3486,6 +3553,22 @@ impl From for QueryError { } } +impl From for QueryError { + fn from(e: GetActiveTenantError) -> Self { + match e { + GetActiveTenantError::WaitForActiveTimeout { .. } => QueryError::Disconnected( + ConnectionError::Io(io::Error::new(io::ErrorKind::TimedOut, e.to_string())), + ), + GetActiveTenantError::Cancelled + | GetActiveTenantError::WillNotBecomeActive(TenantState::Stopping { .. }) => { + QueryError::Shutdown + } + e @ GetActiveTenantError::NotFound(_) => QueryError::NotFound(format!("{e}").into()), + e => QueryError::Other(anyhow::anyhow!(e)), + } + } +} + impl From for QueryError { fn from(e: crate::tenant::timeline::handle::HandleUpgradeError) -> Self { match e { From a21c1174edefdfb59fbdce9ae5696c446a3cfe0a Mon Sep 17 00:00:00 2001 From: Erik Grinaker Date: Mon, 2 Jun 2025 16:50:49 +0200 Subject: [PATCH 12/43] pagebench: add gRPC support for `get-page-latest-lsn` (#12077) ## Problem We need gRPC support in Pagebench to benchmark the new gRPC Pageserver implementation. Touches #11728. ## Summary of changes Adds a `Client` trait to make the client transport swappable, and a gRPC client via a `--protocol grpc` parameter. This must also specify the connstring with the gRPC port: ``` pagebench get-page-latest-lsn --protocol grpc --page-service-connstring grpc://localhost:51051 ``` The client is implemented using the raw Tonic-generated gRPC client, to minimize client overhead. --- Cargo.lock | 4 + libs/pageserver_api/src/models.rs | 4 +- libs/pageserver_api/src/reltag.rs | 2 +- pageserver/pagebench/Cargo.toml | 6 +- .../pagebench/src/cmd/getpage_latest_lsn.rs | 146 ++++++++++++++++-- 5 files changed, 144 insertions(+), 18 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 4f7378e95d..9fc233e5ec 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4236,6 +4236,7 @@ name = "pagebench" version = "0.1.0" dependencies = [ "anyhow", + "async-trait", "camino", "clap", "futures", @@ -4244,12 +4245,15 @@ dependencies = [ "humantime-serde", "pageserver_api", "pageserver_client", + "pageserver_page_api", "rand 0.8.5", "reqwest", "serde", "serde_json", "tokio", + "tokio-stream", "tokio-util", + "tonic 0.13.1", "tracing", "utils", "workspace_hack", diff --git a/libs/pageserver_api/src/models.rs b/libs/pageserver_api/src/models.rs index e7d612bb7a..01487c0f57 100644 --- a/libs/pageserver_api/src/models.rs +++ b/libs/pageserver_api/src/models.rs @@ -2045,7 +2045,7 @@ pub enum PagestreamProtocolVersion { pub type RequestId = u64; -#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[derive(Debug, Default, PartialEq, Eq, Clone, Copy)] pub struct PagestreamRequest { pub reqid: RequestId, pub request_lsn: Lsn, @@ -2064,7 +2064,7 @@ pub struct PagestreamNblocksRequest { pub rel: RelTag, } -#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[derive(Debug, Default, PartialEq, Eq, Clone, Copy)] pub struct PagestreamGetPageRequest { pub hdr: PagestreamRequest, pub rel: RelTag, diff --git a/libs/pageserver_api/src/reltag.rs b/libs/pageserver_api/src/reltag.rs index 473a44dbf9..4509cab2e0 100644 --- a/libs/pageserver_api/src/reltag.rs +++ b/libs/pageserver_api/src/reltag.rs @@ -24,7 +24,7 @@ use serde::{Deserialize, Serialize}; // FIXME: should move 'forknum' as last field to keep this consistent with Postgres. // Then we could replace the custom Ord and PartialOrd implementations below with // deriving them. This will require changes in walredoproc.c. -#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy, Serialize, Deserialize)] +#[derive(Debug, Default, PartialEq, Eq, Hash, Clone, Copy, Serialize, Deserialize)] pub struct RelTag { pub forknum: u8, pub spcnode: Oid, diff --git a/pageserver/pagebench/Cargo.toml b/pageserver/pagebench/Cargo.toml index 5b5ed09a2b..ceb1278eab 100644 --- a/pageserver/pagebench/Cargo.toml +++ b/pageserver/pagebench/Cargo.toml @@ -8,6 +8,7 @@ license.workspace = true [dependencies] anyhow.workspace = true +async-trait.workspace = true camino.workspace = true clap.workspace = true futures.workspace = true @@ -15,14 +16,17 @@ hdrhistogram.workspace = true humantime.workspace = true humantime-serde.workspace = true rand.workspace = true -reqwest.workspace=true +reqwest.workspace = true serde.workspace = true serde_json.workspace = true tracing.workspace = true tokio.workspace = true +tokio-stream.workspace = true tokio-util.workspace = true +tonic.workspace = true pageserver_client.workspace = true pageserver_api.workspace = true +pageserver_page_api.workspace = true utils = { path = "../../libs/utils/" } workspace_hack = { version = "0.1", path = "../../workspace_hack" } diff --git a/pageserver/pagebench/src/cmd/getpage_latest_lsn.rs b/pageserver/pagebench/src/cmd/getpage_latest_lsn.rs index 50419ec338..395e9cac41 100644 --- a/pageserver/pagebench/src/cmd/getpage_latest_lsn.rs +++ b/pageserver/pagebench/src/cmd/getpage_latest_lsn.rs @@ -7,11 +7,15 @@ use std::sync::{Arc, Mutex}; use std::time::{Duration, Instant}; use anyhow::Context; +use async_trait::async_trait; use camino::Utf8PathBuf; use pageserver_api::key::Key; use pageserver_api::keyspace::KeySpaceAccum; -use pageserver_api::models::{PagestreamGetPageRequest, PagestreamRequest}; +use pageserver_api::models::{ + PagestreamGetPageRequest, PagestreamGetPageResponse, PagestreamRequest, +}; use pageserver_api::shard::TenantShardId; +use pageserver_page_api::proto; use rand::prelude::*; use tokio::task::JoinSet; use tokio_util::sync::CancellationToken; @@ -22,6 +26,12 @@ use utils::lsn::Lsn; use crate::util::tokio_thread_local_stats::AllThreadLocalStats; use crate::util::{request_stats, tokio_thread_local_stats}; +#[derive(clap::ValueEnum, Clone, Debug)] +enum Protocol { + Libpq, + Grpc, +} + /// GetPage@LatestLSN, uniformly distributed across the compute-accessible keyspace. #[derive(clap::Parser)] pub(crate) struct Args { @@ -35,6 +45,8 @@ pub(crate) struct Args { num_clients: NonZeroUsize, #[clap(long)] runtime: Option, + #[clap(long, value_enum, default_value = "libpq")] + protocol: Protocol, /// Each client sends requests at the given rate. /// /// If a request takes too long and we should be issuing a new request already, @@ -303,7 +315,20 @@ async fn main_impl( .unwrap(); Box::pin(async move { - client_libpq(args, worker_id, ss, cancel, rps_period, ranges, weights).await + let client: Box = match args.protocol { + Protocol::Libpq => Box::new( + LibpqClient::new(args.page_service_connstring.clone(), worker_id.timeline) + .await + .unwrap(), + ), + + Protocol::Grpc => Box::new( + GrpcClient::new(args.page_service_connstring.clone(), worker_id.timeline) + .await + .unwrap(), + ), + }; + run_worker(args, client, ss, cancel, rps_period, ranges, weights).await }) }; @@ -355,23 +380,15 @@ async fn main_impl( anyhow::Ok(()) } -async fn client_libpq( +async fn run_worker( args: &Args, - worker_id: WorkerId, + mut client: Box, shared_state: Arc, cancel: CancellationToken, rps_period: Option, ranges: Vec, weights: rand::distributions::weighted::WeightedIndex, ) { - let client = pageserver_client::page_service::Client::new(args.page_service_connstring.clone()) - .await - .unwrap(); - let mut client = client - .pagestream(worker_id.timeline.tenant_id, worker_id.timeline.timeline_id) - .await - .unwrap(); - shared_state.start_work_barrier.wait().await; let client_start = Instant::now(); let mut ticks_processed = 0; @@ -415,12 +432,12 @@ async fn client_libpq( blkno: block_no, } }; - client.getpage_send(req).await.unwrap(); + client.send_get_page(req).await.unwrap(); inflight.push_back(start); } let start = inflight.pop_front().unwrap(); - client.getpage_recv().await.unwrap(); + client.recv_get_page().await.unwrap(); let end = Instant::now(); shared_state.live_stats.request_done(); ticks_processed += 1; @@ -442,3 +459,104 @@ async fn client_libpq( } } } + +/// A benchmark client, to allow switching out the transport protocol. +/// +/// For simplicity, this just uses separate asynchronous send/recv methods. The send method could +/// return a future that resolves when the response is received, but we don't really need it. +#[async_trait] +trait Client: Send { + /// Sends an asynchronous GetPage request to the pageserver. + async fn send_get_page(&mut self, req: PagestreamGetPageRequest) -> anyhow::Result<()>; + + /// Receives the next GetPage response from the pageserver. + async fn recv_get_page(&mut self) -> anyhow::Result; +} + +/// A libpq-based Pageserver client. +struct LibpqClient { + inner: pageserver_client::page_service::PagestreamClient, +} + +impl LibpqClient { + async fn new(connstring: String, ttid: TenantTimelineId) -> anyhow::Result { + let inner = pageserver_client::page_service::Client::new(connstring) + .await? + .pagestream(ttid.tenant_id, ttid.timeline_id) + .await?; + Ok(Self { inner }) + } +} + +#[async_trait] +impl Client for LibpqClient { + async fn send_get_page(&mut self, req: PagestreamGetPageRequest) -> anyhow::Result<()> { + self.inner.getpage_send(req).await + } + + async fn recv_get_page(&mut self) -> anyhow::Result { + self.inner.getpage_recv().await + } +} + +/// A gRPC client using the raw, no-frills gRPC client. +struct GrpcClient { + req_tx: tokio::sync::mpsc::Sender, + resp_rx: tonic::Streaming, +} + +impl GrpcClient { + async fn new(connstring: String, ttid: TenantTimelineId) -> anyhow::Result { + let mut client = pageserver_page_api::proto::PageServiceClient::connect(connstring).await?; + + // The channel has a buffer size of 1, since 0 is not allowed. It does not matter, since the + // benchmark will control the queue depth (i.e. in-flight requests) anyway, and requests are + // buffered by Tonic and the OS too. + let (req_tx, req_rx) = tokio::sync::mpsc::channel(1); + let req_stream = tokio_stream::wrappers::ReceiverStream::new(req_rx); + let mut req = tonic::Request::new(req_stream); + let metadata = req.metadata_mut(); + metadata.insert("neon-tenant-id", ttid.tenant_id.to_string().try_into()?); + metadata.insert("neon-timeline-id", ttid.timeline_id.to_string().try_into()?); + metadata.insert("neon-shard-id", "0000".try_into()?); + + let resp = client.get_pages(req).await?; + let resp_stream = resp.into_inner(); + + Ok(Self { + req_tx, + resp_rx: resp_stream, + }) + } +} + +#[async_trait] +impl Client for GrpcClient { + async fn send_get_page(&mut self, req: PagestreamGetPageRequest) -> anyhow::Result<()> { + let req = proto::GetPageRequest { + request_id: 0, + request_class: proto::GetPageClass::Normal as i32, + read_lsn: Some(proto::ReadLsn { + request_lsn: req.hdr.request_lsn.0, + not_modified_since_lsn: req.hdr.not_modified_since.0, + }), + rel: Some(req.rel.into()), + block_number: vec![req.blkno], + }; + self.req_tx.send(req).await?; + Ok(()) + } + + async fn recv_get_page(&mut self) -> anyhow::Result { + let resp = self.resp_rx.message().await?.unwrap(); + anyhow::ensure!( + resp.status_code == proto::GetPageStatusCode::Ok as i32, + "unexpected status code: {}", + resp.status_code + ); + Ok(PagestreamGetPageResponse { + page: resp.page_image[0].clone(), + req: PagestreamGetPageRequest::default(), // dummy + }) + } +} From 781bf4945d9cb3902de829a187ee0e7ebc71e432 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Mon, 2 Jun 2025 17:13:30 +0100 Subject: [PATCH 13/43] proxy: optimise future layout allocations (#12104) A smaller version of #12066 that is somewhat easier to review. Now that I've been using https://crates.io/crates/top-type-sizes I've found a lot more of the low hanging fruit that can be tweaks to reduce the memory usage. Some context for the optimisations: Rust's stack allocation in futures is quite naive. Stack variables, even if moved, often still end up taking space in the future. Rearranging the order in which variables are defined, and properly scoping them can go a long way. `async fn` and `async move {}` have a consequence that they always duplicate the "upvars" (aka captures). All captures are permanently allocated in the future, even if moved. We can be mindful when writing futures to only capture as little as possible. TlsStream is massive. Needs boxing so it doesn't contribute to the above issue. ## Measurements from `top-type-sizes`: ### Before ``` 10328 {async block@proxy::proxy::task_main::{closure#0}::{closure#0}} align=8 6120 {async fn body of proxy::proxy::handle_client>()} align=8 ``` ### After ``` 4040 {async block@proxy::proxy::task_main::{closure#0}::{closure#0}} 4704 {async fn body of proxy::proxy::handle_client>()} align=8 ``` --- proxy/src/auth/backend/classic.rs | 16 ++-- proxy/src/console_redirect_proxy.rs | 2 +- .../control_plane/client/cplane_proxy_v1.rs | 75 +++++++++++-------- proxy/src/http/mod.rs | 27 +++++-- proxy/src/pqproto.rs | 6 +- proxy/src/proxy/handshake.rs | 9 ++- proxy/src/proxy/mod.rs | 2 +- proxy/src/proxy/passthrough.rs | 2 + proxy/src/sasl/stream.rs | 49 ++++++------ proxy/src/stream.rs | 4 +- proxy/src/tls/postgres_rustls.rs | 6 +- 11 files changed, 115 insertions(+), 83 deletions(-) diff --git a/proxy/src/auth/backend/classic.rs b/proxy/src/auth/backend/classic.rs index dcc500f2c8..8445368740 100644 --- a/proxy/src/auth/backend/classic.rs +++ b/proxy/src/auth/backend/classic.rs @@ -25,19 +25,15 @@ pub(super) async fn authenticate( } AuthSecret::Scram(secret) => { debug!("auth endpoint chooses SCRAM"); - let scram = auth::Scram(&secret, ctx); - let auth_outcome = tokio::time::timeout(config.scram_protocol_timeout, async { - AuthFlow::new(client, scram) - .authenticate() - .await - .inspect_err(|error| { - warn!(?error, "error processing scram messages"); - }) - }) + let auth_outcome = tokio::time::timeout( + config.scram_protocol_timeout, + AuthFlow::new(client, auth::Scram(&secret, ctx)).authenticate(), + ) .await .inspect_err(|_| warn!("error processing scram messages error = authentication timed out, execution time exceeded {} seconds", config.scram_protocol_timeout.as_secs())) - .map_err(auth::AuthError::user_timeout)??; + .map_err(auth::AuthError::user_timeout)? + .inspect_err(|error| warn!(?error, "error processing scram messages"))?; let client_key = match auth_outcome { sasl::Outcome::Success(key) => key, diff --git a/proxy/src/console_redirect_proxy.rs b/proxy/src/console_redirect_proxy.rs index 9499aba61b..7fb84b5ee5 100644 --- a/proxy/src/console_redirect_proxy.rs +++ b/proxy/src/console_redirect_proxy.rs @@ -159,7 +159,7 @@ pub async fn task_main( } #[allow(clippy::too_many_arguments)] -pub(crate) async fn handle_client( +pub(crate) async fn handle_client( config: &'static ProxyConfig, backend: &'static ConsoleRedirectBackend, ctx: &RequestContext, diff --git a/proxy/src/control_plane/client/cplane_proxy_v1.rs b/proxy/src/control_plane/client/cplane_proxy_v1.rs index 93f4ea6cf7..da548d6b2c 100644 --- a/proxy/src/control_plane/client/cplane_proxy_v1.rs +++ b/proxy/src/control_plane/client/cplane_proxy_v1.rs @@ -7,7 +7,9 @@ use std::time::Duration; use ::http::HeaderName; use ::http::header::AUTHORIZATION; +use bytes::Bytes; use futures::TryFutureExt; +use hyper::StatusCode; use postgres_client::config::SslMode; use tokio::time::Instant; use tracing::{Instrument, debug, info, info_span, warn}; @@ -72,28 +74,34 @@ impl NeonControlPlaneClient { role: &RoleName, ) -> Result { async { - let request = self - .endpoint - .get_path("get_endpoint_access_control") - .header(X_REQUEST_ID, ctx.session_id().to_string()) - .header(AUTHORIZATION, format!("Bearer {}", &self.jwt)) - .query(&[("session_id", ctx.session_id())]) - .query(&[ - ("application_name", ctx.console_application_name().as_str()), - ("endpointish", endpoint.as_str()), - ("role", role.as_str()), - ]) - .build()?; - - debug!(url = request.url().as_str(), "sending http request"); - let start = Instant::now(); let response = { - let _pause = ctx.latency_timer_pause_at(start, crate::metrics::Waiting::Cplane); - self.endpoint.execute(request).await? - }; - info!(duration = ?start.elapsed(), "received http response"); + let request = self + .endpoint + .get_path("get_endpoint_access_control") + .header(X_REQUEST_ID, ctx.session_id().to_string()) + .header(AUTHORIZATION, format!("Bearer {}", &self.jwt)) + .query(&[("session_id", ctx.session_id())]) + .query(&[ + ("application_name", ctx.console_application_name().as_str()), + ("endpointish", endpoint.as_str()), + ("role", role.as_str()), + ]) + .build()?; - let body = match parse_body::(response).await { + debug!(url = request.url().as_str(), "sending http request"); + let start = Instant::now(); + let _pause = ctx.latency_timer_pause_at(start, crate::metrics::Waiting::Cplane); + let response = self.endpoint.execute(request).await?; + + info!(duration = ?start.elapsed(), "received http response"); + + response + }; + + let body = match parse_body::( + response.status(), + response.bytes().await?, + ) { Ok(body) => body, // Error 404 is special: it's ok not to have a secret. // TODO(anna): retry @@ -184,7 +192,10 @@ impl NeonControlPlaneClient { drop(pause); info!(duration = ?start.elapsed(), "received http response"); - let body = parse_body::(response).await?; + let body = parse_body::( + response.status(), + response.bytes().await.map_err(ControlPlaneError::from)?, + )?; let rules = body .jwks @@ -236,7 +247,7 @@ impl NeonControlPlaneClient { let response = self.endpoint.execute(request).await?; drop(pause); info!(duration = ?start.elapsed(), "received http response"); - let body = parse_body::(response).await?; + let body = parse_body::(response.status(), response.bytes().await?)?; // Unfortunately, ownership won't let us use `Option::ok_or` here. let (host, port) = match parse_host_port(&body.address) { @@ -487,33 +498,33 @@ impl super::ControlPlaneApi for NeonControlPlaneClient { } /// Parse http response body, taking status code into account. -async fn parse_body serde::Deserialize<'a>>( - response: http::Response, +fn parse_body serde::Deserialize<'a>>( + status: StatusCode, + body: Bytes, ) -> Result { - let status = response.status(); if status.is_success() { // We shouldn't log raw body because it may contain secrets. info!("request succeeded, processing the body"); - return Ok(response.json().await?); + return Ok(serde_json::from_slice(&body).map_err(std::io::Error::other)?); } - let s = response.bytes().await?; + // Log plaintext to be able to detect, whether there are some cases not covered by the error struct. - info!("response_error plaintext: {:?}", s); + info!("response_error plaintext: {:?}", body); // Don't throw an error here because it's not as important // as the fact that the request itself has failed. - let mut body = serde_json::from_slice(&s).unwrap_or_else(|e| { + let mut body = serde_json::from_slice(&body).unwrap_or_else(|e| { warn!("failed to parse error body: {e}"); - ControlPlaneErrorMessage { + Box::new(ControlPlaneErrorMessage { error: "reason unclear (malformed error message)".into(), http_status_code: status, status: None, - } + }) }); body.http_status_code = status; warn!("console responded with an error ({status}): {body:?}"); - Err(ControlPlaneError::Message(Box::new(body))) + Err(ControlPlaneError::Message(body)) } fn parse_host_port(input: &str) -> Option<(&str, u16)> { diff --git a/proxy/src/http/mod.rs b/proxy/src/http/mod.rs index 96f600d836..36607e7861 100644 --- a/proxy/src/http/mod.rs +++ b/proxy/src/http/mod.rs @@ -4,9 +4,10 @@ pub mod health_server; -use std::time::Duration; +use std::time::{Duration, Instant}; use bytes::Bytes; +use futures::FutureExt; use http::Method; use http_body_util::BodyExt; use hyper::body::Body; @@ -109,15 +110,31 @@ impl Endpoint { } /// Execute a [request](reqwest::Request). - pub(crate) async fn execute(&self, request: Request) -> Result { - let _timer = Metrics::get() + pub(crate) fn execute( + &self, + request: Request, + ) -> impl Future> { + let metric = Metrics::get() .proxy .console_request_latency - .start_timer(ConsoleRequest { + .with_labels(ConsoleRequest { request: request.url().path(), }); - self.client.execute(request).await + let req = self.client.execute(request).boxed(); + + async move { + let start = Instant::now(); + scopeguard::defer!({ + Metrics::get() + .proxy + .console_request_latency + .get_metric(metric) + .observe_duration_since(start); + }); + + req.await + } } } diff --git a/proxy/src/pqproto.rs b/proxy/src/pqproto.rs index d68d9f9474..43074bf208 100644 --- a/proxy/src/pqproto.rs +++ b/proxy/src/pqproto.rs @@ -186,7 +186,7 @@ where pub async fn read_message<'a, S>( stream: &mut S, buf: &'a mut Vec, - max: usize, + max: u32, ) -> io::Result<(u8, &'a mut [u8])> where S: AsyncRead + Unpin, @@ -206,7 +206,7 @@ where let header = read!(stream => Header); // as described above, the length must be at least 4. - let Some(len) = (header.len.get() as usize).checked_sub(4) else { + let Some(len) = header.len.get().checked_sub(4) else { return Err(io::Error::other(format!( "invalid startup message length {}, must be at least 4.", header.len, @@ -222,7 +222,7 @@ where } // read in our entire message. - buf.resize(len, 0); + buf.resize(len as usize, 0); stream.read_exact(buf).await?; Ok((header.tag, buf)) diff --git a/proxy/src/proxy/handshake.rs b/proxy/src/proxy/handshake.rs index 13ee8c7dd2..6970ab8714 100644 --- a/proxy/src/proxy/handshake.rs +++ b/proxy/src/proxy/handshake.rs @@ -1,3 +1,4 @@ +use futures::{FutureExt, TryFutureExt}; use thiserror::Error; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::{debug, info, warn}; @@ -57,7 +58,7 @@ pub(crate) enum HandshakeData { /// It's easier to work with owned `stream` here as we need to upgrade it to TLS; /// we also take an extra care of propagating only the select handshake errors to client. #[tracing::instrument(skip_all)] -pub(crate) async fn handshake( +pub(crate) async fn handshake( ctx: &RequestContext, stream: S, mut tls: Option<&TlsConfig>, @@ -108,7 +109,9 @@ pub(crate) async fn handshake( } } } - }); + }) + .map_ok(Box::new) + .boxed(); res?; @@ -146,7 +149,7 @@ pub(crate) async fn handshake( tls.cert_resolver.resolve(conn_info.server_name()); let tls = Stream::Tls { - tls: Box::new(tls_stream), + tls: tls_stream, tls_server_end_point, }; (stream, msg) = PqStream::parse_startup(tls).await?; diff --git a/proxy/src/proxy/mod.rs b/proxy/src/proxy/mod.rs index ac0aca1176..0ffc54aa88 100644 --- a/proxy/src/proxy/mod.rs +++ b/proxy/src/proxy/mod.rs @@ -270,7 +270,7 @@ impl ReportableError for ClientRequestError { } #[allow(clippy::too_many_arguments)] -pub(crate) async fn handle_client( +pub(crate) async fn handle_client( config: &'static ProxyConfig, auth_backend: &'static auth::Backend<'static, ()>, ctx: &RequestContext, diff --git a/proxy/src/proxy/passthrough.rs b/proxy/src/proxy/passthrough.rs index 8f9bd2de2d..55ab5f4dba 100644 --- a/proxy/src/proxy/passthrough.rs +++ b/proxy/src/proxy/passthrough.rs @@ -1,3 +1,4 @@ +use futures::FutureExt; use smol_str::SmolStr; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::debug; @@ -89,6 +90,7 @@ impl ProxyPassthrough { .compute .cancel_closure .try_cancel_query(compute_config) + .boxed() .await { tracing::warn!(session_id = ?self.session_id, ?err, "could not cancel the query in the database"); diff --git a/proxy/src/sasl/stream.rs b/proxy/src/sasl/stream.rs index cb15132673..52ccca58d5 100644 --- a/proxy/src/sasl/stream.rs +++ b/proxy/src/sasl/stream.rs @@ -30,52 +30,53 @@ where F: FnOnce(&str) -> super::Result, M: Mechanism, { - let sasl = { + let (mut mechanism, mut input) = { // pause the timer while we communicate with the client let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client); // Initial client message contains the chosen auth method's name. let msg = stream.read_password_message().await?; - super::FirstMessage::parse(msg).ok_or(super::Error::BadClientMessage("bad sasl message"))? + + let sasl = super::FirstMessage::parse(msg) + .ok_or(super::Error::BadClientMessage("bad sasl message"))?; + + (mechanism(sasl.method)?, sasl.message) }; - let mut mechanism = mechanism(sasl.method)?; - let mut input = sasl.message; loop { - let step = mechanism - .exchange(input) - .inspect_err(|error| tracing::info!(?error, "error during SASL exchange"))?; - - match step { - Step::Continue(moved_mechanism, reply) => { + match mechanism.exchange(input) { + Ok(Step::Continue(moved_mechanism, reply)) => { mechanism = moved_mechanism; - // pause the timer while we communicate with the client - let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client); - // write reply let sasl_msg = BeAuthenticationSaslMessage::Continue(reply.as_bytes()); stream.write_message(BeMessage::AuthenticationSasl(sasl_msg)); - - // get next input - stream.flush().await?; - let msg = stream.read_password_message().await?; - input = std::str::from_utf8(msg) - .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "bad encoding"))?; + drop(reply); } - Step::Success(result, reply) => { - // pause the timer while we communicate with the client - let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client); - + Ok(Step::Success(result, reply)) => { // write reply let sasl_msg = BeAuthenticationSaslMessage::Final(reply.as_bytes()); stream.write_message(BeMessage::AuthenticationSasl(sasl_msg)); stream.write_message(BeMessage::AuthenticationOk); + // exit with success break Ok(Outcome::Success(result)); } // exit with failure - Step::Failure(reason) => break Ok(Outcome::Failure(reason)), + Ok(Step::Failure(reason)) => break Ok(Outcome::Failure(reason)), + Err(error) => { + tracing::info!(?error, "error during SASL exchange"); + return Err(error); + } } + + // pause the timer while we communicate with the client + let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client); + + // get next input + stream.flush().await?; + let msg = stream.read_password_message().await?; + input = std::str::from_utf8(msg) + .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "bad encoding"))?; } } diff --git a/proxy/src/stream.rs b/proxy/src/stream.rs index 7126430a85..c49a431c95 100644 --- a/proxy/src/stream.rs +++ b/proxy/src/stream.rs @@ -72,7 +72,7 @@ impl PqStream { impl PqStream { /// Read a raw postgres packet, which will respect the max length requested. /// This is not cancel safe. - async fn read_raw_expect(&mut self, tag: u8, max: usize) -> io::Result<&mut [u8]> { + async fn read_raw_expect(&mut self, tag: u8, max: u32) -> io::Result<&mut [u8]> { let (actual_tag, msg) = read_message(&mut self.stream, &mut self.read, max).await?; if actual_tag != tag { return Err(io::Error::other(format!( @@ -89,7 +89,7 @@ impl PqStream { // passwords are usually pretty short // and SASL SCRAM messages are no longer than 256 bytes in my testing // (a few hashes and random bytes, encoded into base64). - const MAX_PASSWORD_LENGTH: usize = 512; + const MAX_PASSWORD_LENGTH: u32 = 512; self.read_raw_expect(FE_PASSWORD_MESSAGE, MAX_PASSWORD_LENGTH) .await } diff --git a/proxy/src/tls/postgres_rustls.rs b/proxy/src/tls/postgres_rustls.rs index f09e916a1d..013b307f0b 100644 --- a/proxy/src/tls/postgres_rustls.rs +++ b/proxy/src/tls/postgres_rustls.rs @@ -31,7 +31,9 @@ mod private { type Output = io::Result>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - Pin::new(&mut self.inner).poll(cx).map_ok(RustlsStream) + Pin::new(&mut self.inner) + .poll(cx) + .map_ok(|s| RustlsStream(Box::new(s))) } } @@ -57,7 +59,7 @@ mod private { } } - pub struct RustlsStream(TlsStream); + pub struct RustlsStream(Box>); impl postgres_client::tls::TlsStream for RustlsStream where From fc3994eb71826de6fbec023b74558aa72a7c888b Mon Sep 17 00:00:00 2001 From: Erik Grinaker Date: Mon, 2 Jun 2025 19:15:18 +0200 Subject: [PATCH 14/43] pageserver: initial gRPC page service implementation (#12094) ## Problem We should expose the page service over gRPC. Requires #12093. Touches #11728. ## Summary of changes This patch adds an initial page service implementation over gRPC. It ties in with the existing `PageServerHandler` request logic, to avoid the implementations drifting apart for the core read path. This is just a bare-bones functional implementation. Several important aspects have been omitted, and will be addressed in follow-up PRs: * Limited observability: minimal tracing, no logging, limited metrics and timing, etc. * Rate limiting will currently block. * No performance optimization. * No cancellation handling. * No tests. I've only done rudimentary testing of this, but Pagebench passes at least. --- libs/pageserver_api/src/models.rs | 2 +- libs/pageserver_api/src/reltag.rs | 10 +- pageserver/page_api/src/model.rs | 17 +- pageserver/src/basebackup.rs | 26 +- pageserver/src/bin/pageserver.rs | 4 +- pageserver/src/page_service.rs | 822 ++++++++++++++++++++++++------ pageserver/src/tenant/timeline.rs | 12 + 7 files changed, 723 insertions(+), 170 deletions(-) diff --git a/libs/pageserver_api/src/models.rs b/libs/pageserver_api/src/models.rs index 01487c0f57..28ced4a368 100644 --- a/libs/pageserver_api/src/models.rs +++ b/libs/pageserver_api/src/models.rs @@ -1934,7 +1934,7 @@ pub enum PagestreamFeMessage { } // Wrapped in libpq CopyData -#[derive(strum_macros::EnumProperty)] +#[derive(Debug, strum_macros::EnumProperty)] pub enum PagestreamBeMessage { Exists(PagestreamExistsResponse), Nblocks(PagestreamNblocksResponse), diff --git a/libs/pageserver_api/src/reltag.rs b/libs/pageserver_api/src/reltag.rs index 4509cab2e0..e0dd4fdfe8 100644 --- a/libs/pageserver_api/src/reltag.rs +++ b/libs/pageserver_api/src/reltag.rs @@ -184,12 +184,12 @@ pub enum SlruKind { MultiXactOffsets, } -impl SlruKind { - pub fn to_str(&self) -> &'static str { +impl fmt::Display for SlruKind { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - Self::Clog => "pg_xact", - Self::MultiXactMembers => "pg_multixact/members", - Self::MultiXactOffsets => "pg_multixact/offsets", + Self::Clog => write!(f, "pg_xact"), + Self::MultiXactMembers => write!(f, "pg_multixact/members"), + Self::MultiXactOffsets => write!(f, "pg_multixact/offsets"), } } } diff --git a/pageserver/page_api/src/model.rs b/pageserver/page_api/src/model.rs index 7ab97a994e..0268ab920b 100644 --- a/pageserver/page_api/src/model.rs +++ b/pageserver/page_api/src/model.rs @@ -10,6 +10,8 @@ //! //! - Validate protocol invariants, via try_from() and try_into(). +use std::fmt::Display; + use bytes::Bytes; use postgres_ffi::Oid; use smallvec::SmallVec; @@ -48,7 +50,8 @@ pub struct ReadLsn { pub request_lsn: Lsn, /// If given, the caller guarantees that the page has not been modified since this LSN. Must be /// smaller than or equal to request_lsn. This allows the Pageserver to serve an old page - /// without waiting for the request LSN to arrive. Valid for all request types. + /// without waiting for the request LSN to arrive. If not given, the request will read at the + /// request_lsn and wait for it to arrive if necessary. Valid for all request types. /// /// It is undefined behaviour to make a request such that the page was, in fact, modified /// between request_lsn and not_modified_since_lsn. The Pageserver might detect it and return an @@ -58,6 +61,17 @@ pub struct ReadLsn { pub not_modified_since_lsn: Option, } +impl Display for ReadLsn { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let req_lsn = self.request_lsn; + if let Some(mod_lsn) = self.not_modified_since_lsn { + write!(f, "{req_lsn}>={mod_lsn}") + } else { + req_lsn.fmt(f) + } + } +} + impl ReadLsn { /// Validates the ReadLsn. pub fn validate(&self) -> Result<(), ProtocolError> { @@ -584,6 +598,7 @@ impl TryFrom for proto::GetSlruSegmentResponse { type Error = ProtocolError; fn try_from(segment: GetSlruSegmentResponse) -> Result { + // TODO: can a segment legitimately be empty? if segment.is_empty() { return Err(ProtocolError::Missing("segment")); } diff --git a/pageserver/src/basebackup.rs b/pageserver/src/basebackup.rs index e89baa0bce..4dba9d267c 100644 --- a/pageserver/src/basebackup.rs +++ b/pageserver/src/basebackup.rs @@ -65,6 +65,30 @@ impl From for BasebackupError { } } +impl From for postgres_backend::QueryError { + fn from(err: BasebackupError) -> Self { + use postgres_backend::QueryError; + use pq_proto::framed::ConnectionError; + match err { + BasebackupError::Client(err, _) => QueryError::Disconnected(ConnectionError::Io(err)), + BasebackupError::Server(err) => QueryError::Other(err), + BasebackupError::Shutdown => QueryError::Shutdown, + } + } +} + +impl From for tonic::Status { + fn from(err: BasebackupError) -> Self { + use tonic::Code; + let code = match &err { + BasebackupError::Client(_, _) => Code::Cancelled, + BasebackupError::Server(_) => Code::Internal, + BasebackupError::Shutdown => Code::Unavailable, + }; + tonic::Status::new(code, err.to_string()) + } +} + /// Create basebackup with non-rel data in it. /// Only include relational data if 'full_backup' is true. /// @@ -248,7 +272,7 @@ where async fn flush(&mut self) -> Result<(), BasebackupError> { let nblocks = self.buf.len() / BLCKSZ as usize; let (kind, segno) = self.current_segment.take().unwrap(); - let segname = format!("{}/{:>04X}", kind.to_str(), segno); + let segname = format!("{kind}/{segno:>04X}"); let header = new_tar_header(&segname, self.buf.len() as u64)?; self.ar .append(&header, self.buf.as_slice()) diff --git a/pageserver/src/bin/pageserver.rs b/pageserver/src/bin/pageserver.rs index df3c045145..337aa135dc 100644 --- a/pageserver/src/bin/pageserver.rs +++ b/pageserver/src/bin/pageserver.rs @@ -804,7 +804,7 @@ fn start_pageserver( } else { None }, - basebackup_cache.clone(), + basebackup_cache, ); // Spawn a Pageserver gRPC server task. It will spawn separate tasks for @@ -816,12 +816,10 @@ fn start_pageserver( let mut page_service_grpc = None; if let Some(grpc_listener) = grpc_listener { page_service_grpc = Some(page_service::spawn_grpc( - conf, tenant_manager.clone(), grpc_auth, otel_guard.as_ref().map(|g| g.dispatch.clone()), grpc_listener, - basebackup_cache, )?); } diff --git a/pageserver/src/page_service.rs b/pageserver/src/page_service.rs index f011ed49d0..b9ba4a3555 100644 --- a/pageserver/src/page_service.rs +++ b/pageserver/src/page_service.rs @@ -1,6 +1,7 @@ //! The Page Service listens for client connections and serves their GetPage@LSN //! requests. +use std::any::Any; use std::borrow::Cow; use std::num::NonZeroUsize; use std::os::fd::AsRawFd; @@ -11,9 +12,9 @@ use std::task::{Context, Poll}; use std::time::{Duration, Instant, SystemTime}; use std::{io, str}; -use anyhow::{Context as _, bail}; +use anyhow::{Context as _, anyhow, bail}; use async_compression::tokio::write::GzipEncoder; -use bytes::Buf; +use bytes::{Buf, BytesMut}; use futures::future::BoxFuture; use futures::{FutureExt, Stream}; use itertools::Itertools; @@ -33,6 +34,7 @@ use pageserver_api::models::{ }; use pageserver_api::reltag::SlruKind; use pageserver_api::shard::TenantShardId; +use pageserver_page_api as page_api; use pageserver_page_api::proto; use postgres_backend::{ AuthType, PostgresBackend, PostgresBackendReader, QueryError, is_expected_io_error, @@ -41,8 +43,9 @@ use postgres_ffi::BLCKSZ; use postgres_ffi::pg_constants::DEFAULTTABLESPACE_OID; use pq_proto::framed::ConnectionError; use pq_proto::{BeMessage, FeMessage, FeStartupPacket, RowDescriptor}; +use smallvec::{SmallVec, smallvec}; use strum_macros::IntoStaticStr; -use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufWriter}; +use tokio::io::{AsyncRead, AsyncReadExt as _, AsyncWrite, AsyncWriteExt as _, BufWriter}; use tokio::task::JoinHandle; use tokio_util::sync::CancellationToken; use tonic::service::Interceptor as _; @@ -78,7 +81,8 @@ use crate::tenant::mgr::{ GetActiveTenantError, GetTenantError, ShardResolveResult, ShardSelector, TenantManager, }; use crate::tenant::storage_layer::IoConcurrency; -use crate::tenant::timeline::{self, WaitLsnError}; +use crate::tenant::timeline::handle::{Handle, HandleUpgradeError, WeakHandle}; +use crate::tenant::timeline::{self, WaitLsnError, WaitLsnTimeout, WaitLsnWaiter}; use crate::tenant::{GetTimelineError, PageReconstructError, Timeline}; use crate::{CancellableTask, PERF_TRACE_TARGET, timed_after_cancellation}; @@ -167,15 +171,14 @@ pub fn spawn( /// Spawns a gRPC server for the page service. /// +/// TODO: move this onto GrpcPageServiceHandler::spawn(). /// TODO: this doesn't support TLS. We need TLS reloading via ReloadingCertificateResolver, so we /// need to reimplement the TCP+TLS accept loop ourselves. pub fn spawn_grpc( - conf: &'static PageServerConf, tenant_manager: Arc, auth: Option>, perf_trace_dispatch: Option, listener: std::net::TcpListener, - basebackup_cache: Arc, ) -> anyhow::Result { let cancel = CancellationToken::new(); let ctx = RequestContextBuilder::new(TaskKind::PageRequestHandler) @@ -208,24 +211,17 @@ pub fn spawn_grpc( // // * Layers: allow async code, can run code after the service response. However, only has access // to the raw HTTP request/response, not the gRPC types. - let page_service_handler = PageServerHandler::new( + let page_service_handler = GrpcPageServiceHandler { tenant_manager, - auth.clone(), - PageServicePipeliningConfig::Serial, // TODO: unused with gRPC - conf.get_vectored_concurrent_io, - ConnectionPerfSpanFields::default(), - basebackup_cache, ctx, - cancel.clone(), - gate.enter().expect("just created"), - ); + }; let observability_layer = ObservabilityLayer; let mut tenant_interceptor = TenantMetadataInterceptor; let mut auth_interceptor = TenantAuthInterceptor::new(auth); let page_service = tower::ServiceBuilder::new() - // Create tracing span. + // Create tracing span and record request start time. .layer(observability_layer) // Intercept gRPC requests. .layer(tonic::service::InterceptorLayer::new(move |mut req| { @@ -554,7 +550,7 @@ impl TimelineHandles { tenant_id: TenantId, timeline_id: TimelineId, shard_selector: ShardSelector, - ) -> Result, GetActiveTimelineError> { + ) -> Result, GetActiveTimelineError> { if *self.wrapper.tenant_id.get_or_init(|| tenant_id) != tenant_id { return Err(GetActiveTimelineError::Tenant( GetActiveTenantError::SwitchedTenant, @@ -721,6 +717,82 @@ enum PageStreamError { BadRequest(Cow<'static, str>), } +impl PageStreamError { + /// Converts a PageStreamError into a proto::GetPageResponse with the appropriate status + /// code, or a gRPC status if it should terminate the stream (e.g. shutdown). This is a + /// convenience method for use from a get_pages gRPC stream. + #[allow(clippy::result_large_err)] + fn into_get_page_response( + self, + request_id: page_api::RequestID, + ) -> Result { + use page_api::GetPageStatusCode; + use tonic::Code; + + // We dispatch to Into first, and then map it to a GetPageResponse. + let status: tonic::Status = self.into(); + let status_code = match status.code() { + // We shouldn't see an OK status here, because we're emitting an error. + Code::Ok => { + debug_assert_ne!(status.code(), Code::Ok); + return Err(tonic::Status::internal(format!( + "unexpected OK status: {status:?}", + ))); + } + + // These are per-request errors, returned as GetPageResponses. + Code::AlreadyExists => GetPageStatusCode::InvalidRequest, + Code::DataLoss => GetPageStatusCode::InternalError, + Code::FailedPrecondition => GetPageStatusCode::InvalidRequest, + Code::InvalidArgument => GetPageStatusCode::InvalidRequest, + Code::Internal => GetPageStatusCode::InternalError, + Code::NotFound => GetPageStatusCode::NotFound, + Code::OutOfRange => GetPageStatusCode::InvalidRequest, + Code::ResourceExhausted => GetPageStatusCode::SlowDown, + + // These should terminate the stream. + Code::Aborted => return Err(status), + Code::Cancelled => return Err(status), + Code::DeadlineExceeded => return Err(status), + Code::PermissionDenied => return Err(status), + Code::Unauthenticated => return Err(status), + Code::Unavailable => return Err(status), + Code::Unimplemented => return Err(status), + Code::Unknown => return Err(status), + }; + + Ok(page_api::GetPageResponse { + request_id, + status_code, + reason: Some(status.message().to_string()), + page_images: SmallVec::new(), + } + .into()) + } +} + +impl From for tonic::Status { + fn from(err: PageStreamError) -> Self { + use tonic::Code; + let message = err.to_string(); + let code = match err { + PageStreamError::Reconnect(_) => Code::Unavailable, + PageStreamError::Shutdown => Code::Unavailable, + PageStreamError::Read(err) => match err { + PageReconstructError::Cancelled => Code::Unavailable, + PageReconstructError::MissingKey(_) => Code::NotFound, + PageReconstructError::AncestorLsnTimeout(err) => tonic::Status::from(err).code(), + PageReconstructError::Other(_) => Code::Internal, + PageReconstructError::WalRedo(_) => Code::Internal, + }, + PageStreamError::LsnTimeout(err) => tonic::Status::from(err).code(), + PageStreamError::NotFound(_) => Code::NotFound, + PageStreamError::BadRequest(_) => Code::InvalidArgument, + }; + tonic::Status::new(code, message) + } +} + impl From for PageStreamError { fn from(value: PageReconstructError) -> Self { match value { @@ -801,37 +873,37 @@ enum BatchedFeMessage { Exists { span: Span, timer: SmgrOpTimer, - shard: timeline::handle::WeakHandle, + shard: WeakHandle, req: models::PagestreamExistsRequest, }, Nblocks { span: Span, timer: SmgrOpTimer, - shard: timeline::handle::WeakHandle, + shard: WeakHandle, req: models::PagestreamNblocksRequest, }, GetPage { span: Span, - shard: timeline::handle::WeakHandle, - pages: smallvec::SmallVec<[BatchedGetPageRequest; 1]>, + shard: WeakHandle, + pages: SmallVec<[BatchedGetPageRequest; 1]>, batch_break_reason: GetPageBatchBreakReason, }, DbSize { span: Span, timer: SmgrOpTimer, - shard: timeline::handle::WeakHandle, + shard: WeakHandle, req: models::PagestreamDbSizeRequest, }, GetSlruSegment { span: Span, timer: SmgrOpTimer, - shard: timeline::handle::WeakHandle, + shard: WeakHandle, req: models::PagestreamGetSlruSegmentRequest, }, #[cfg(feature = "testing")] Test { span: Span, - shard: timeline::handle::WeakHandle, + shard: WeakHandle, requests: Vec, }, RespondError { @@ -1080,26 +1152,6 @@ impl PageServerHandler { let neon_fe_msg = PagestreamFeMessage::parse(&mut copy_data_bytes.reader(), protocol_version)?; - // TODO: turn in to async closure once available to avoid repeating received_at - async fn record_op_start_and_throttle( - shard: &timeline::handle::Handle, - op: metrics::SmgrQueryType, - received_at: Instant, - ) -> Result { - // It's important to start the smgr op metric recorder as early as possible - // so that the _started counters are incremented before we do - // any serious waiting, e.g., for throttle, batching, or actual request handling. - let mut timer = shard.query_metrics.start_smgr_op(op, received_at); - let now = Instant::now(); - timer.observe_throttle_start(now); - let throttled = tokio::select! { - res = shard.pagestream_throttle.throttle(1, now) => res, - _ = shard.cancel.cancelled() => return Err(QueryError::Shutdown), - }; - timer.observe_throttle_done(throttled); - Ok(timer) - } - let batched_msg = match neon_fe_msg { PagestreamFeMessage::Exists(req) => { let shard = timeline_handles @@ -1107,7 +1159,7 @@ impl PageServerHandler { .await?; debug_assert_current_span_has_tenant_and_timeline_id_no_shard_id(); let span = tracing::info_span!(parent: &parent_span, "handle_get_rel_exists_request", rel = %req.rel, req_lsn = %req.hdr.request_lsn, shard_id = %shard.tenant_shard_id.shard_slug()); - let timer = record_op_start_and_throttle( + let timer = Self::record_op_start_and_throttle( &shard, metrics::SmgrQueryType::GetRelExists, received_at, @@ -1125,7 +1177,7 @@ impl PageServerHandler { .get(tenant_id, timeline_id, ShardSelector::Zero) .await?; let span = tracing::info_span!(parent: &parent_span, "handle_get_nblocks_request", rel = %req.rel, req_lsn = %req.hdr.request_lsn, shard_id = %shard.tenant_shard_id.shard_slug()); - let timer = record_op_start_and_throttle( + let timer = Self::record_op_start_and_throttle( &shard, metrics::SmgrQueryType::GetRelSize, received_at, @@ -1143,7 +1195,7 @@ impl PageServerHandler { .get(tenant_id, timeline_id, ShardSelector::Zero) .await?; let span = tracing::info_span!(parent: &parent_span, "handle_db_size_request", dbnode = %req.dbnode, req_lsn = %req.hdr.request_lsn, shard_id = %shard.tenant_shard_id.shard_slug()); - let timer = record_op_start_and_throttle( + let timer = Self::record_op_start_and_throttle( &shard, metrics::SmgrQueryType::GetDbSize, received_at, @@ -1161,7 +1213,7 @@ impl PageServerHandler { .get(tenant_id, timeline_id, ShardSelector::Zero) .await?; let span = tracing::info_span!(parent: &parent_span, "handle_get_slru_segment_request", kind = %req.kind, segno = %req.segno, req_lsn = %req.hdr.request_lsn, shard_id = %shard.tenant_shard_id.shard_slug()); - let timer = record_op_start_and_throttle( + let timer = Self::record_op_start_and_throttle( &shard, metrics::SmgrQueryType::GetSlruSegment, received_at, @@ -1286,7 +1338,7 @@ impl PageServerHandler { // request handler log messages contain the request-specific fields. let span = mkspan!(shard.tenant_shard_id.shard_slug()); - let timer = record_op_start_and_throttle( + let timer = Self::record_op_start_and_throttle( &shard, metrics::SmgrQueryType::GetPageAtLsn, received_at, @@ -1333,7 +1385,7 @@ impl PageServerHandler { BatchedFeMessage::GetPage { span, shard: shard.downgrade(), - pages: smallvec::smallvec![BatchedGetPageRequest { + pages: smallvec![BatchedGetPageRequest { req, timer, lsn_range: LsnRange { @@ -1355,9 +1407,12 @@ impl PageServerHandler { .get(tenant_id, timeline_id, ShardSelector::Zero) .await?; let span = tracing::info_span!(parent: &parent_span, "handle_test_request", shard_id = %shard.tenant_shard_id.shard_slug()); - let timer = - record_op_start_and_throttle(&shard, metrics::SmgrQueryType::Test, received_at) - .await?; + let timer = Self::record_op_start_and_throttle( + &shard, + metrics::SmgrQueryType::Test, + received_at, + ) + .await?; BatchedFeMessage::Test { span, shard: shard.downgrade(), @@ -1368,6 +1423,26 @@ impl PageServerHandler { Ok(Some(batched_msg)) } + /// Starts a SmgrOpTimer at received_at and throttles the request. + async fn record_op_start_and_throttle( + shard: &Handle, + op: metrics::SmgrQueryType, + received_at: Instant, + ) -> Result { + // It's important to start the smgr op metric recorder as early as possible + // so that the _started counters are incremented before we do + // any serious waiting, e.g., for throttle, batching, or actual request handling. + let mut timer = shard.query_metrics.start_smgr_op(op, received_at); + let now = Instant::now(); + timer.observe_throttle_start(now); + let throttled = tokio::select! { + res = shard.pagestream_throttle.throttle(1, now) => res, + _ = shard.cancel.cancelled() => return Err(QueryError::Shutdown), + }; + timer.observe_throttle_done(throttled); + Ok(timer) + } + /// Post-condition: `batch` is Some() #[instrument(skip_all, level = tracing::Level::TRACE)] #[allow(clippy::boxed_local)] @@ -1465,8 +1540,11 @@ impl PageServerHandler { let (mut handler_results, span) = { // TODO: we unfortunately have to pin the future on the heap, since GetPage futures are huge and // won't fit on the stack. - let mut boxpinned = - Box::pin(self.pagestream_dispatch_batched_message(batch, io_concurrency, ctx)); + let mut boxpinned = Box::pin(Self::pagestream_dispatch_batched_message( + batch, + io_concurrency, + ctx, + )); log_slow( log_slow_name, LOG_SLOW_GETPAGE_THRESHOLD, @@ -1622,7 +1700,6 @@ impl PageServerHandler { /// Helper which dispatches a batched message to the appropriate handler. /// Returns a vec of results, along with the extracted trace span. async fn pagestream_dispatch_batched_message( - &mut self, batch: BatchedFeMessage, io_concurrency: IoConcurrency, ctx: &RequestContext, @@ -1652,10 +1729,10 @@ impl PageServerHandler { let (shard, ctx) = upgrade_handle_and_set_context!(shard); ( vec![ - self.handle_get_rel_exists_request(&shard, &req, &ctx) + Self::handle_get_rel_exists_request(&shard, &req, &ctx) .instrument(span.clone()) .await - .map(|msg| (msg, timer, ctx)) + .map(|msg| (PagestreamBeMessage::Exists(msg), timer, ctx)) .map_err(|err| BatchedPageStreamError { err, req: req.hdr }), ], span, @@ -1671,10 +1748,10 @@ impl PageServerHandler { let (shard, ctx) = upgrade_handle_and_set_context!(shard); ( vec![ - self.handle_get_nblocks_request(&shard, &req, &ctx) + Self::handle_get_nblocks_request(&shard, &req, &ctx) .instrument(span.clone()) .await - .map(|msg| (msg, timer, ctx)) + .map(|msg| (PagestreamBeMessage::Nblocks(msg), timer, ctx)) .map_err(|err| BatchedPageStreamError { err, req: req.hdr }), ], span, @@ -1692,16 +1769,15 @@ impl PageServerHandler { { let npages = pages.len(); trace!(npages, "handling getpage request"); - let res = self - .handle_get_page_at_lsn_request_batched( - &shard, - pages, - io_concurrency, - batch_break_reason, - &ctx, - ) - .instrument(span.clone()) - .await; + let res = Self::handle_get_page_at_lsn_request_batched( + &shard, + pages, + io_concurrency, + batch_break_reason, + &ctx, + ) + .instrument(span.clone()) + .await; assert_eq!(res.len(), npages); res }, @@ -1718,10 +1794,10 @@ impl PageServerHandler { let (shard, ctx) = upgrade_handle_and_set_context!(shard); ( vec![ - self.handle_db_size_request(&shard, &req, &ctx) + Self::handle_db_size_request(&shard, &req, &ctx) .instrument(span.clone()) .await - .map(|msg| (msg, timer, ctx)) + .map(|msg| (PagestreamBeMessage::DbSize(msg), timer, ctx)) .map_err(|err| BatchedPageStreamError { err, req: req.hdr }), ], span, @@ -1737,10 +1813,10 @@ impl PageServerHandler { let (shard, ctx) = upgrade_handle_and_set_context!(shard); ( vec![ - self.handle_get_slru_segment_request(&shard, &req, &ctx) + Self::handle_get_slru_segment_request(&shard, &req, &ctx) .instrument(span.clone()) .await - .map(|msg| (msg, timer, ctx)) + .map(|msg| (PagestreamBeMessage::GetSlruSegment(msg), timer, ctx)) .map_err(|err| BatchedPageStreamError { err, req: req.hdr }), ], span, @@ -1758,8 +1834,7 @@ impl PageServerHandler { { let npages = requests.len(); trace!(npages, "handling getpage request"); - let res = self - .handle_test_request_batch(&shard, requests, &ctx) + let res = Self::handle_test_request_batch(&shard, requests, &ctx) .instrument(span.clone()) .await; assert_eq!(res.len(), npages); @@ -2313,11 +2388,10 @@ impl PageServerHandler { #[instrument(skip_all, fields(shard_id))] async fn handle_get_rel_exists_request( - &mut self, timeline: &Timeline, req: &PagestreamExistsRequest, ctx: &RequestContext, - ) -> Result { + ) -> Result { let latest_gc_cutoff_lsn = timeline.get_applied_gc_cutoff_lsn(); let lsn = Self::wait_or_get_last_lsn( timeline, @@ -2339,19 +2413,15 @@ impl PageServerHandler { ) .await?; - Ok(PagestreamBeMessage::Exists(PagestreamExistsResponse { - req: *req, - exists, - })) + Ok(PagestreamExistsResponse { req: *req, exists }) } #[instrument(skip_all, fields(shard_id))] async fn handle_get_nblocks_request( - &mut self, timeline: &Timeline, req: &PagestreamNblocksRequest, ctx: &RequestContext, - ) -> Result { + ) -> Result { let latest_gc_cutoff_lsn = timeline.get_applied_gc_cutoff_lsn(); let lsn = Self::wait_or_get_last_lsn( timeline, @@ -2373,19 +2443,18 @@ impl PageServerHandler { ) .await?; - Ok(PagestreamBeMessage::Nblocks(PagestreamNblocksResponse { + Ok(PagestreamNblocksResponse { req: *req, n_blocks, - })) + }) } #[instrument(skip_all, fields(shard_id))] async fn handle_db_size_request( - &mut self, timeline: &Timeline, req: &PagestreamDbSizeRequest, ctx: &RequestContext, - ) -> Result { + ) -> Result { let latest_gc_cutoff_lsn = timeline.get_applied_gc_cutoff_lsn(); let lsn = Self::wait_or_get_last_lsn( timeline, @@ -2409,17 +2478,13 @@ impl PageServerHandler { .await?; let db_size = total_blocks as i64 * BLCKSZ as i64; - Ok(PagestreamBeMessage::DbSize(PagestreamDbSizeResponse { - req: *req, - db_size, - })) + Ok(PagestreamDbSizeResponse { req: *req, db_size }) } #[instrument(skip_all)] async fn handle_get_page_at_lsn_request_batched( - &mut self, timeline: &Timeline, - requests: smallvec::SmallVec<[BatchedGetPageRequest; 1]>, + requests: SmallVec<[BatchedGetPageRequest; 1]>, io_concurrency: IoConcurrency, batch_break_reason: GetPageBatchBreakReason, ctx: &RequestContext, @@ -2544,11 +2609,10 @@ impl PageServerHandler { #[instrument(skip_all, fields(shard_id))] async fn handle_get_slru_segment_request( - &mut self, timeline: &Timeline, req: &PagestreamGetSlruSegmentRequest, ctx: &RequestContext, - ) -> Result { + ) -> Result { let latest_gc_cutoff_lsn = timeline.get_applied_gc_cutoff_lsn(); let lsn = Self::wait_or_get_last_lsn( timeline, @@ -2563,16 +2627,13 @@ impl PageServerHandler { .ok_or(PageStreamError::BadRequest("invalid SLRU kind".into()))?; let segment = timeline.get_slru_segment(kind, req.segno, lsn, ctx).await?; - Ok(PagestreamBeMessage::GetSlruSegment( - PagestreamGetSlruSegmentResponse { req: *req, segment }, - )) + Ok(PagestreamGetSlruSegmentResponse { req: *req, segment }) } // NB: this impl mimics what we do for batched getpage requests. #[cfg(feature = "testing")] #[instrument(skip_all, fields(shard_id))] async fn handle_test_request_batch( - &mut self, timeline: &Timeline, requests: Vec, _ctx: &RequestContext, @@ -2648,15 +2709,6 @@ impl PageServerHandler { where IO: AsyncRead + AsyncWrite + Send + Sync + Unpin, { - fn map_basebackup_error(err: BasebackupError) -> QueryError { - match err { - // TODO: passthrough the error site to the final error message? - BasebackupError::Client(e, _) => QueryError::Disconnected(ConnectionError::Io(e)), - BasebackupError::Server(e) => QueryError::Other(e), - BasebackupError::Shutdown => QueryError::Shutdown, - } - } - let started = std::time::Instant::now(); let timeline = self @@ -2714,8 +2766,7 @@ impl PageServerHandler { replica, &ctx, ) - .await - .map_err(map_basebackup_error)?; + .await?; } else { let mut writer = BufWriter::new(pgb.copyout_writer()); @@ -2738,11 +2789,8 @@ impl PageServerHandler { from_cache = true; tokio::io::copy(&mut cached, &mut writer) .await - .map_err(|e| { - map_basebackup_error(BasebackupError::Client( - e, - "handle_basebackup_request,cached,copy", - )) + .map_err(|err| { + BasebackupError::Client(err, "handle_basebackup_request,cached,copy") })?; } else if gzip { let mut encoder = GzipEncoder::with_quality( @@ -2763,8 +2811,7 @@ impl PageServerHandler { replica, &ctx, ) - .await - .map_err(map_basebackup_error)?; + .await?; // shutdown the encoder to ensure the gzip footer is written encoder .shutdown() @@ -2780,15 +2827,12 @@ impl PageServerHandler { replica, &ctx, ) - .await - .map_err(map_basebackup_error)?; + .await?; } - writer.flush().await.map_err(|e| { - map_basebackup_error(BasebackupError::Client( - e, - "handle_basebackup_request,flush", - )) - })?; + writer + .flush() + .await + .map_err(|err| BasebackupError::Client(err, "handle_basebackup_request,flush"))?; } pgb.write_message_noflush(&BeMessage::CopyDone) @@ -3312,69 +3356,464 @@ where } } -/// Implements the page service over gRPC. +/// Serves the page service over gRPC. Dispatches to PageServerHandler for request processing. /// -/// TODO: not yet implemented, all methods return unimplemented. +/// TODO: rename to PageServiceHandler when libpq impl is removed. +pub struct GrpcPageServiceHandler { + tenant_manager: Arc, + ctx: RequestContext, +} + +impl GrpcPageServiceHandler { + /// Errors if the request is executed on a non-zero shard. Only shard 0 has a complete view of + /// relations and their sizes, as well as SLRU segments and similar data. + #[allow(clippy::result_large_err)] + fn ensure_shard_zero(timeline: &Handle) -> Result<(), tonic::Status> { + match timeline.get_shard_index().shard_number.0 { + 0 => Ok(()), + shard => Err(tonic::Status::invalid_argument(format!( + "request must execute on shard zero (is shard {shard})", + ))), + } + } + + /// Generates a PagestreamRequest header from a ReadLsn and request ID. + fn make_hdr(read_lsn: page_api::ReadLsn, req_id: u64) -> PagestreamRequest { + PagestreamRequest { + reqid: req_id, + request_lsn: read_lsn.request_lsn, + not_modified_since: read_lsn + .not_modified_since_lsn + .unwrap_or(read_lsn.request_lsn), + } + } + + /// Acquires a timeline handle for the given request. + /// + /// TODO: during shard splits, the compute may still be sending requests to the parent shard + /// until the entire split is committed and the compute is notified. Consider installing a + /// temporary shard router from the parent to the children while the split is in progress. + /// + /// TODO: consider moving this to a middleware layer; all requests need it. Needs to manage + /// the TimelineHandles lifecycle. + /// + /// TODO: untangle acquisition from TenantManagerWrapper::resolve() and Cache::get(), to avoid + /// the unnecessary overhead. + async fn get_request_timeline( + &self, + req: &tonic::Request, + ) -> Result, GetActiveTimelineError> { + let ttid = *extract::(req); + let shard_index = *extract::(req); + let shard_selector = ShardSelector::Known(shard_index); + + TimelineHandles::new(self.tenant_manager.clone()) + .get(ttid.tenant_id, ttid.timeline_id, shard_selector) + .await + } + + /// Starts a SmgrOpTimer at received_at, throttles the request, and records execution start. + /// Only errors if the timeline is shutting down. + /// + /// TODO: move timer construction to ObservabilityLayer (see TODO there). + /// TODO: decouple rate limiting (middleware?), and return SlowDown errors instead. + async fn record_op_start_and_throttle( + timeline: &Handle, + op: metrics::SmgrQueryType, + received_at: Instant, + ) -> Result { + let mut timer = PageServerHandler::record_op_start_and_throttle(timeline, op, received_at) + .await + .map_err(|err| match err { + // record_op_start_and_throttle() only returns Shutdown. + QueryError::Shutdown => tonic::Status::unavailable(format!("{err}")), + err => tonic::Status::internal(format!("unexpected error: {err}")), + })?; + timer.observe_execution_start(Instant::now()); + Ok(timer) + } + + /// Processes a GetPage batch request, via the GetPages bidirectional streaming RPC. + /// + /// NB: errors will terminate the stream. Per-request errors should return a GetPageResponse + /// with an appropriate status code instead. + /// + /// TODO: get_vectored() currently enforces a batch limit of 32. Postgres will typically send + /// batches up to effective_io_concurrency = 100. Either we have to accept large batches, or + /// split them up in the client or server. + #[instrument(skip_all, fields(req_id, rel, blkno, blks, req_lsn, mod_lsn))] + async fn get_page( + ctx: &RequestContext, + timeline: &WeakHandle, + req: proto::GetPageRequest, + io_concurrency: IoConcurrency, + ) -> Result { + let received_at = Instant::now(); + let timeline = timeline.upgrade()?; + let ctx = ctx.with_scope_page_service_pagestream(&timeline); + + // Validate the request, decorate the span, and convert it to a Pagestream request. + let req: page_api::GetPageRequest = req.try_into()?; + + span_record!( + req_id = %req.request_id, + rel = %req.rel, + blkno = %req.block_numbers[0], + blks = %req.block_numbers.len(), + lsn = %req.read_lsn, + ); + + let latest_gc_cutoff_lsn = timeline.get_applied_gc_cutoff_lsn(); // hold guard + let effective_lsn = match PageServerHandler::effective_request_lsn( + &timeline, + timeline.get_last_record_lsn(), + req.read_lsn.request_lsn, + req.read_lsn + .not_modified_since_lsn + .unwrap_or(req.read_lsn.request_lsn), + &latest_gc_cutoff_lsn, + ) { + Ok(lsn) => lsn, + Err(err) => return err.into_get_page_response(req.request_id), + }; + + let mut batch = SmallVec::with_capacity(req.block_numbers.len()); + for blkno in req.block_numbers { + // TODO: this creates one timer per page and throttles it. We should have a timer for + // the entire batch, and throttle only the batch, but this is equivalent to what + // PageServerHandler does already so we keep it for now. + let timer = Self::record_op_start_and_throttle( + &timeline, + metrics::SmgrQueryType::GetPageAtLsn, + received_at, + ) + .await?; + + batch.push(BatchedGetPageRequest { + req: PagestreamGetPageRequest { + hdr: Self::make_hdr(req.read_lsn, req.request_id), + rel: req.rel, + blkno, + }, + lsn_range: LsnRange { + effective_lsn, + request_lsn: req.read_lsn.request_lsn, + }, + timer, + ctx: ctx.attached_child(), + batch_wait_ctx: None, // TODO: add tracing + }); + } + + // TODO: this does a relation size query for every page in the batch. Since this batch is + // all for one relation, we could do this only once. However, this is not the case for the + // libpq implementation. + let results = PageServerHandler::handle_get_page_at_lsn_request_batched( + &timeline, + batch, + io_concurrency, + GetPageBatchBreakReason::BatchFull, // TODO: not relevant for gRPC batches + &ctx, + ) + .await; + + let mut resp = page_api::GetPageResponse { + request_id: req.request_id, + status_code: page_api::GetPageStatusCode::Ok, + reason: None, + page_images: SmallVec::with_capacity(results.len()), + }; + + for result in results { + match result { + Ok((PagestreamBeMessage::GetPage(r), _, _)) => resp.page_images.push(r.page), + Ok((resp, _, _)) => { + return Err(tonic::Status::internal(format!( + "unexpected response: {resp:?}" + ))); + } + Err(err) => return err.err.into_get_page_response(req.request_id), + }; + } + + Ok(resp.into()) + } +} + +/// Implements the gRPC page service. +/// +/// TODO: cancellation. +/// TODO: when the libpq impl is removed, remove the Pagestream types and inline the handler code. #[tonic::async_trait] -impl proto::PageService for PageServerHandler { +impl proto::PageService for GrpcPageServiceHandler { type GetBaseBackupStream = Pin< Box> + Send>, >; + type GetPagesStream = Pin> + Send>>; - #[instrument(skip_all)] + #[instrument(skip_all, fields(rel, lsn))] async fn check_rel_exists( &self, - _: tonic::Request, + req: tonic::Request, ) -> Result, tonic::Status> { - Err(tonic::Status::unimplemented("not implemented")) + let received_at = extract::(&req).0; + let timeline = self.get_request_timeline(&req).await?; + let ctx = self.ctx.with_scope_page_service_pagestream(&timeline); + + // Validate the request, decorate the span, and convert it to a Pagestream request. + Self::ensure_shard_zero(&timeline)?; + let req: page_api::CheckRelExistsRequest = req.into_inner().try_into()?; + + span_record!(rel=%req.rel, lsn=%req.read_lsn); + + let req = PagestreamExistsRequest { + hdr: Self::make_hdr(req.read_lsn, 0), + rel: req.rel, + }; + + // Execute the request and convert the response. + let _timer = Self::record_op_start_and_throttle( + &timeline, + metrics::SmgrQueryType::GetRelExists, + received_at, + ) + .await?; + + let resp = PageServerHandler::handle_get_rel_exists_request(&timeline, &req, &ctx).await?; + let resp: page_api::CheckRelExistsResponse = resp.exists; + Ok(tonic::Response::new(resp.into())) } - #[instrument(skip_all)] + // TODO: ensure clients use gzip compression for the stream. + #[instrument(skip_all, fields(lsn))] async fn get_base_backup( &self, - _: tonic::Request, + req: tonic::Request, ) -> Result, tonic::Status> { - Err(tonic::Status::unimplemented("not implemented")) + // Send 64 KB chunks to avoid large memory allocations. + const CHUNK_SIZE: usize = 64 * 1024; + + let timeline = self.get_request_timeline(&req).await?; + let ctx = self.ctx.with_scope_timeline(&timeline); + + // Validate the request, decorate the span, and wait for the LSN to arrive. + // + // TODO: this requires a read LSN, is that ok? + Self::ensure_shard_zero(&timeline)?; + if timeline.is_archived() == Some(true) { + return Err(tonic::Status::failed_precondition("timeline is archived")); + } + let req: page_api::GetBaseBackupRequest = req.into_inner().try_into()?; + + span_record!(lsn=%req.read_lsn); + + let latest_gc_cutoff_lsn = timeline.get_applied_gc_cutoff_lsn(); + timeline + .wait_lsn( + req.read_lsn.request_lsn, + WaitLsnWaiter::PageService, + WaitLsnTimeout::Default, + &ctx, + ) + .await?; + timeline + .check_lsn_is_in_scope(req.read_lsn.request_lsn, &latest_gc_cutoff_lsn) + .map_err(|err| { + tonic::Status::invalid_argument(format!("invalid basebackup LSN: {err}")) + })?; + + // Spawn a task to run the basebackup. + // + // TODO: do we need to support full base backups, for debugging? + let span = Span::current(); + let (mut simplex_read, mut simplex_write) = tokio::io::simplex(CHUNK_SIZE); + let jh = tokio::spawn(async move { + let result = basebackup::send_basebackup_tarball( + &mut simplex_write, + &timeline, + Some(req.read_lsn.request_lsn), + None, + false, + req.replica, + &ctx, + ) + .instrument(span) // propagate request span + .await; + simplex_write.shutdown().await.map_err(|err| { + BasebackupError::Server(anyhow!("simplex shutdown failed: {err}")) + })?; + result + }); + + // Emit chunks of size CHUNK_SIZE. + let chunks = async_stream::try_stream! { + let mut chunk = BytesMut::with_capacity(CHUNK_SIZE); + loop { + let n = simplex_read.read_buf(&mut chunk).await.map_err(|err| { + tonic::Status::internal(format!("failed to read basebackup chunk: {err}")) + })?; + + // If we read 0 bytes, either the chunk is full or the stream is closed. + if n == 0 { + if chunk.is_empty() { + break; + } + yield proto::GetBaseBackupResponseChunk::try_from(chunk.clone().freeze())?; + chunk.clear(); + } + } + // Wait for the basebackup task to exit and check for errors. + jh.await.map_err(|err| { + tonic::Status::internal(format!("basebackup failed: {err}")) + })??; + }; + + Ok(tonic::Response::new(Box::pin(chunks))) } - #[instrument(skip_all)] + #[instrument(skip_all, fields(db_oid, lsn))] async fn get_db_size( &self, - _: tonic::Request, + req: tonic::Request, ) -> Result, tonic::Status> { - Err(tonic::Status::unimplemented("not implemented")) + let received_at = extract::(&req).0; + let timeline = self.get_request_timeline(&req).await?; + let ctx = self.ctx.with_scope_page_service_pagestream(&timeline); + + // Validate the request, decorate the span, and convert it to a Pagestream request. + Self::ensure_shard_zero(&timeline)?; + let req: page_api::GetDbSizeRequest = req.into_inner().try_into()?; + + span_record!(db_oid=%req.db_oid, lsn=%req.read_lsn); + + let req = PagestreamDbSizeRequest { + hdr: Self::make_hdr(req.read_lsn, 0), + dbnode: req.db_oid, + }; + + // Execute the request and convert the response. + let _timer = Self::record_op_start_and_throttle( + &timeline, + metrics::SmgrQueryType::GetDbSize, + received_at, + ) + .await?; + + let resp = PageServerHandler::handle_db_size_request(&timeline, &req, &ctx).await?; + let resp = resp.db_size as page_api::GetDbSizeResponse; + Ok(tonic::Response::new(resp.into())) } // NB: don't instrument this, instrument each streamed request. async fn get_pages( &self, - _: tonic::Request>, + req: tonic::Request>, ) -> Result, tonic::Status> { - Err(tonic::Status::unimplemented("not implemented")) + // Extract the timeline from the request and check that it exists. + let ttid = *extract::(&req); + let shard_index = *extract::(&req); + let shard_selector = ShardSelector::Known(shard_index); + + let mut handles = TimelineHandles::new(self.tenant_manager.clone()); + handles + .get(ttid.tenant_id, ttid.timeline_id, shard_selector) + .await?; + + let span = Span::current(); + let ctx = self.ctx.attached_child(); + let mut reqs = req.into_inner(); + + let resps = async_stream::try_stream! { + let timeline = handles + .get(ttid.tenant_id, ttid.timeline_id, shard_selector) + .await? + .downgrade(); + while let Some(req) = reqs.message().await? { + // TODO: implement IoConcurrency sidecar. + yield Self::get_page(&ctx, &timeline, req, IoConcurrency::Sequential) + .instrument(span.clone()) // propagate request span + .await? + } + }; + + Ok(tonic::Response::new(Box::pin(resps))) } - #[instrument(skip_all)] + #[instrument(skip_all, fields(rel, lsn))] async fn get_rel_size( &self, - _: tonic::Request, + req: tonic::Request, ) -> Result, tonic::Status> { - Err(tonic::Status::unimplemented("not implemented")) + let received_at = extract::(&req).0; + let timeline = self.get_request_timeline(&req).await?; + let ctx = self.ctx.with_scope_page_service_pagestream(&timeline); + + // Validate the request, decorate the span, and convert it to a Pagestream request. + Self::ensure_shard_zero(&timeline)?; + let req: page_api::GetRelSizeRequest = req.into_inner().try_into()?; + + span_record!(rel=%req.rel, lsn=%req.read_lsn); + + let req = PagestreamNblocksRequest { + hdr: Self::make_hdr(req.read_lsn, 0), + rel: req.rel, + }; + + // Execute the request and convert the response. + let _timer = Self::record_op_start_and_throttle( + &timeline, + metrics::SmgrQueryType::GetRelSize, + received_at, + ) + .await?; + + let resp = PageServerHandler::handle_get_nblocks_request(&timeline, &req, &ctx).await?; + let resp: page_api::GetRelSizeResponse = resp.n_blocks; + Ok(tonic::Response::new(resp.into())) } - #[instrument(skip_all)] + #[instrument(skip_all, fields(kind, segno, lsn))] async fn get_slru_segment( &self, - _: tonic::Request, + req: tonic::Request, ) -> Result, tonic::Status> { - Err(tonic::Status::unimplemented("not implemented")) + let received_at = extract::(&req).0; + let timeline = self.get_request_timeline(&req).await?; + let ctx = self.ctx.with_scope_page_service_pagestream(&timeline); + + // Validate the request, decorate the span, and convert it to a Pagestream request. + Self::ensure_shard_zero(&timeline)?; + let req: page_api::GetSlruSegmentRequest = req.into_inner().try_into()?; + + span_record!(kind=%req.kind, segno=%req.segno, lsn=%req.read_lsn); + + let req = PagestreamGetSlruSegmentRequest { + hdr: Self::make_hdr(req.read_lsn, 0), + kind: req.kind as u8, + segno: req.segno, + }; + + // Execute the request and convert the response. + let _timer = Self::record_op_start_and_throttle( + &timeline, + metrics::SmgrQueryType::GetSlruSegment, + received_at, + ) + .await?; + + let resp = + PageServerHandler::handle_get_slru_segment_request(&timeline, &req, &ctx).await?; + let resp: page_api::GetSlruSegmentResponse = resp.segment; + Ok(tonic::Response::new(resp.try_into()?)) } } /// gRPC middleware layer that handles observability concerns: /// /// * Creates and enters a tracing span. +/// * Records the request start time as a ReceivedAt request extension. /// /// TODO: add perf tracing. /// TODO: add timing and metrics. @@ -3395,6 +3834,9 @@ struct ObservabilityLayerService { inner: S, } +#[derive(Clone, Copy)] +struct ReceivedAt(Instant); + impl tonic::server::NamedService for ObservabilityLayerService { const NAME: &'static str = S::NAME; // propagate inner service name } @@ -3408,7 +3850,13 @@ where type Error = S::Error; type Future = BoxFuture<'static, Result>; - fn call(&mut self, req: http::Request) -> Self::Future { + fn call(&mut self, mut req: http::Request) -> Self::Future { + // Record the request start time as a request extension. + // + // TODO: we should start a timer here instead, but it currently requires a timeline handle + // and SmgrQueryType, which we don't have yet. Refactor it to provide it later. + req.extensions_mut().insert(ReceivedAt(Instant::now())); + // Create a basic tracing span. Enter the span for the current thread (to use it for inner // sync code like interceptors), and instrument the future (to use it for inner async code // like the page service itself). @@ -3436,8 +3884,6 @@ where /// gRPC interceptor that decodes tenant metadata and stores it as request extensions of type /// TenantTimelineId and ShardIndex. -/// -/// TODO: consider looking up the timeline handle here and storing it. #[derive(Clone)] struct TenantMetadataInterceptor; @@ -3485,7 +3931,7 @@ impl tonic::service::Interceptor for TenantMetadataInterceptor { } } -/// Authenticates gRPC page service requests. Must run after TenantMetadataInterceptor. +/// Authenticates gRPC page service requests. #[derive(Clone)] struct TenantAuthInterceptor { auth: Option>, @@ -3504,11 +3950,8 @@ impl tonic::service::Interceptor for TenantAuthInterceptor { return Ok(req); }; - // Fetch the tenant ID that's been set by TenantMetadataInterceptor. - let ttid = req - .extensions() - .get::() - .expect("TenantMetadataInterceptor must run before TenantAuthInterceptor"); + // Fetch the tenant ID from the request extensions (set by TenantMetadataInterceptor). + let TenantTimelineId { tenant_id, .. } = *extract::(&req); // Fetch and decode the JWT token. let jwt = req @@ -3526,7 +3969,7 @@ impl tonic::service::Interceptor for TenantAuthInterceptor { let claims = jwtdata.claims; // Check if the token is valid for this tenant. - check_permission(&claims, Some(ttid.tenant_id)) + check_permission(&claims, Some(tenant_id)) .map_err(|err| tonic::Status::permission_denied(err.to_string()))?; // TODO: consider stashing the claims in the request extensions, if needed. @@ -3535,6 +3978,21 @@ impl tonic::service::Interceptor for TenantAuthInterceptor { } } +/// Extracts the given type from the request extensions, or panics if it is missing. +fn extract(req: &tonic::Request) -> &T { + extract_from(req.extensions()) +} + +/// Extract the given type from the request extensions, or panics if it is missing. This variant +/// can extract both from a tonic::Request and http::Request. +fn extract_from(ext: &http::Extensions) -> &T { + let Some(value) = ext.get::() else { + let name = std::any::type_name::(); + panic!("extension {name} should be set by middleware"); + }; + value +} + #[derive(Debug, thiserror::Error)] pub(crate) enum GetActiveTimelineError { #[error(transparent)] @@ -3553,6 +4011,29 @@ impl From for QueryError { } } +impl From for tonic::Status { + fn from(err: GetActiveTimelineError) -> Self { + let message = err.to_string(); + let code = match err { + GetActiveTimelineError::Tenant(err) => tonic::Status::from(err).code(), + GetActiveTimelineError::Timeline(err) => tonic::Status::from(err).code(), + }; + tonic::Status::new(code, message) + } +} + +impl From for tonic::Status { + fn from(err: GetTimelineError) -> Self { + use tonic::Code; + let code = match &err { + GetTimelineError::NotFound { .. } => Code::NotFound, + GetTimelineError::NotActive { .. } => Code::Unavailable, + GetTimelineError::ShuttingDown => Code::Unavailable, + }; + tonic::Status::new(code, err.to_string()) + } +} + impl From for QueryError { fn from(e: GetActiveTenantError) -> Self { match e { @@ -3569,10 +4050,33 @@ impl From for QueryError { } } -impl From for QueryError { - fn from(e: crate::tenant::timeline::handle::HandleUpgradeError) -> Self { +impl From for tonic::Status { + fn from(err: GetActiveTenantError) -> Self { + use tonic::Code; + let code = match &err { + GetActiveTenantError::Broken(_) => Code::Internal, + GetActiveTenantError::Cancelled => Code::Unavailable, + GetActiveTenantError::NotFound(_) => Code::NotFound, + GetActiveTenantError::SwitchedTenant => Code::Unavailable, + GetActiveTenantError::WaitForActiveTimeout { .. } => Code::Unavailable, + GetActiveTenantError::WillNotBecomeActive(_) => Code::Unavailable, + }; + tonic::Status::new(code, err.to_string()) + } +} + +impl From for QueryError { + fn from(e: HandleUpgradeError) -> Self { match e { - crate::tenant::timeline::handle::HandleUpgradeError::ShutDown => QueryError::Shutdown, + HandleUpgradeError::ShutDown => QueryError::Shutdown, + } + } +} + +impl From for tonic::Status { + fn from(err: HandleUpgradeError) -> Self { + match err { + HandleUpgradeError::ShutDown => tonic::Status::unavailable("timeline is shutting down"), } } } diff --git a/pageserver/src/tenant/timeline.rs b/pageserver/src/tenant/timeline.rs index 23c40a7629..9ddbe404d2 100644 --- a/pageserver/src/tenant/timeline.rs +++ b/pageserver/src/tenant/timeline.rs @@ -950,6 +950,18 @@ pub(crate) enum WaitLsnError { Timeout(String), } +impl From for tonic::Status { + fn from(err: WaitLsnError) -> Self { + use tonic::Code; + let code = match &err { + WaitLsnError::Timeout(_) => Code::Internal, + WaitLsnError::BadState(_) => Code::Internal, + WaitLsnError::Shutdown => Code::Unavailable, + }; + tonic::Status::new(code, err.to_string()) + } +} + // The impls below achieve cancellation mapping for errors. // Perhaps there's a way of achieving this with less cruft. From a650f7f5af4773bc6c7806a12b49e84234c7e6d6 Mon Sep 17 00:00:00 2001 From: "Alex Chi Z." <4198311+skyzh@users.noreply.github.com> Date: Tue, 3 Jun 2025 13:00:34 +0800 Subject: [PATCH 15/43] fix(pageserver): only deserialize reldir key once during get_db_size (#12102) ## Problem fix https://github.com/neondatabase/neon/issues/12101; this is a quick hack and we need better API in the future. In `get_db_size`, we call `get_reldir_size` for every relation. However, we do the same deserializing the reldir directory thing for every relation. This creates huge CPU overhead. ## Summary of changes Get and deserialize the reldir v1 key once and use it across all get_rel_size requests. --------- Signed-off-by: Alex Chi Z --- pageserver/src/pgdatadir_mapping.rs | 58 ++++++++++++++++++++++++++++- 1 file changed, 56 insertions(+), 2 deletions(-) diff --git a/pageserver/src/pgdatadir_mapping.rs b/pageserver/src/pgdatadir_mapping.rs index c6f3929257..b6f11b744b 100644 --- a/pageserver/src/pgdatadir_mapping.rs +++ b/pageserver/src/pgdatadir_mapping.rs @@ -471,8 +471,19 @@ impl Timeline { let rels = self.list_rels(spcnode, dbnode, version, ctx).await?; + if rels.is_empty() { + return Ok(0); + } + + // Pre-deserialize the rel directory to avoid duplicated work in `get_relsize_cached`. + let reldir_key = rel_dir_to_key(spcnode, dbnode); + let buf = version.get(self, reldir_key, ctx).await?; + let reldir = RelDirectory::des(&buf)?; + for rel in rels { - let n_blocks = self.get_rel_size(rel, version, ctx).await?; + let n_blocks = self + .get_rel_size_in_reldir(rel, version, Some((reldir_key, &reldir)), ctx) + .await?; total_blocks += n_blocks as usize; } Ok(total_blocks) @@ -487,6 +498,19 @@ impl Timeline { tag: RelTag, version: Version<'_>, ctx: &RequestContext, + ) -> Result { + self.get_rel_size_in_reldir(tag, version, None, ctx).await + } + + /// Get size of a relation file. The relation must exist, otherwise an error is returned. + /// + /// See [`Self::get_rel_exists_in_reldir`] on why we need `deserialized_reldir_v1`. + pub(crate) async fn get_rel_size_in_reldir( + &self, + tag: RelTag, + version: Version<'_>, + deserialized_reldir_v1: Option<(Key, &RelDirectory)>, + ctx: &RequestContext, ) -> Result { if tag.relnode == 0 { return Err(PageReconstructError::Other( @@ -499,7 +523,9 @@ impl Timeline { } if (tag.forknum == FSM_FORKNUM || tag.forknum == VISIBILITYMAP_FORKNUM) - && !self.get_rel_exists(tag, version, ctx).await? + && !self + .get_rel_exists_in_reldir(tag, version, deserialized_reldir_v1, ctx) + .await? { // FIXME: Postgres sometimes calls smgrcreate() to create // FSM, and smgrnblocks() on it immediately afterwards, @@ -521,11 +547,28 @@ impl Timeline { /// /// Only shard 0 has a full view of the relations. Other shards only know about relations that /// the shard stores pages for. + /// pub(crate) async fn get_rel_exists( &self, tag: RelTag, version: Version<'_>, ctx: &RequestContext, + ) -> Result { + self.get_rel_exists_in_reldir(tag, version, None, ctx).await + } + + /// Does the relation exist? With a cached deserialized `RelDirectory`. + /// + /// There are some cases where the caller loops across all relations. In that specific case, + /// the caller should obtain the deserialized `RelDirectory` first and then call this function + /// to avoid duplicated work of deserliazation. This is a hack and should be removed by introducing + /// a new API (e.g., `get_rel_exists_batched`). + pub(crate) async fn get_rel_exists_in_reldir( + &self, + tag: RelTag, + version: Version<'_>, + deserialized_reldir_v1: Option<(Key, &RelDirectory)>, + ctx: &RequestContext, ) -> Result { if tag.relnode == 0 { return Err(PageReconstructError::Other( @@ -568,6 +611,17 @@ impl Timeline { // fetch directory listing (old) let key = rel_dir_to_key(tag.spcnode, tag.dbnode); + + if let Some((cached_key, dir)) = deserialized_reldir_v1 { + if cached_key == key { + return Ok(dir.rels.contains(&(tag.relnode, tag.forknum))); + } else if cfg!(test) || cfg!(feature = "testing") { + panic!("cached reldir key mismatch: {cached_key} != {key}"); + } else { + warn!("cached reldir key mismatch: {cached_key} != {key}"); + } + // Fallback to reading the directory from the datadir. + } let buf = version.get(self, key, ctx).await?; let dir = RelDirectory::des(&buf)?; From 3e72edede524af50220d0d103df08a1f6e12e6a9 Mon Sep 17 00:00:00 2001 From: a-masterov <72613290+a-masterov@users.noreply.github.com> Date: Tue, 3 Jun 2025 09:23:17 +0200 Subject: [PATCH 16/43] Use full hostname for ONNX URL (#12064) ## Problem We should use the full host name for computes, according to https://github.com/neondatabase/cloud/issues/26005 , but now a truncated host name is used. ## Summary of changes The URL for REMOTE_ONNX is rewritten using the FQDN. --- compute/compute-node.Dockerfile | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/compute/compute-node.Dockerfile b/compute/compute-node.Dockerfile index 3459983a34..2afdde0cfa 100644 --- a/compute/compute-node.Dockerfile +++ b/compute/compute-node.Dockerfile @@ -1180,14 +1180,14 @@ RUN cd exts/rag && \ RUN cd exts/rag_bge_small_en_v15 && \ sed -i 's/pgrx = "0.14.1"/pgrx = { version = "0.14.1", features = [ "unsafe-postgres" ] }/g' Cargo.toml && \ ORT_LIB_LOCATION=/ext-src/onnxruntime-src/build/Linux \ - REMOTE_ONNX_URL=http://pg-ext-s3-gateway/pgrag-data/bge_small_en_v15.onnx \ + REMOTE_ONNX_URL=http://pg-ext-s3-gateway.pg-ext-s3-gateway.svc.cluster.local/pgrag-data/bge_small_en_v15.onnx \ cargo pgrx install --release --features remote_onnx && \ echo "trusted = true" >> /usr/local/pgsql/share/extension/rag_bge_small_en_v15.control RUN cd exts/rag_jina_reranker_v1_tiny_en && \ sed -i 's/pgrx = "0.14.1"/pgrx = { version = "0.14.1", features = [ "unsafe-postgres" ] }/g' Cargo.toml && \ ORT_LIB_LOCATION=/ext-src/onnxruntime-src/build/Linux \ - REMOTE_ONNX_URL=http://pg-ext-s3-gateway/pgrag-data/jina_reranker_v1_tiny_en.onnx \ + REMOTE_ONNX_URL=http://pg-ext-s3-gateway.pg-ext-s3-gateway.svc.cluster.local/pgrag-data/jina_reranker_v1_tiny_en.onnx \ cargo pgrx install --release --features remote_onnx && \ echo "trusted = true" >> /usr/local/pgsql/share/extension/rag_jina_reranker_v1_tiny_en.control From 3b8be98b67acbb3da0852ca5adf33408e2313f89 Mon Sep 17 00:00:00 2001 From: Vlad Lazar Date: Tue, 3 Jun 2025 10:07:07 +0100 Subject: [PATCH 17/43] pageserver: remove backtrace in info level log (#12108) ## Problem We print a backtrace in an info level log every 10 seconds while waiting for the import data to land in the bucket. ## Summary of changes The backtrace is not useful. Remove it. --- pageserver/src/tenant/timeline/import_pgdata.rs | 4 ++-- pageserver/src/tenant/timeline/import_pgdata/flow.rs | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pageserver/src/tenant/timeline/import_pgdata.rs b/pageserver/src/tenant/timeline/import_pgdata.rs index 3f760d858b..606ad09ef1 100644 --- a/pageserver/src/tenant/timeline/import_pgdata.rs +++ b/pageserver/src/tenant/timeline/import_pgdata.rs @@ -201,8 +201,8 @@ async fn prepare_import( .await; match res { Ok(_) => break, - Err(err) => { - info!(?err, "indefinitely waiting for pgdata to finish"); + Err(_err) => { + info!("indefinitely waiting for pgdata to finish"); if tokio::time::timeout(std::time::Duration::from_secs(10), cancel.cancelled()) .await .is_ok() diff --git a/pageserver/src/tenant/timeline/import_pgdata/flow.rs b/pageserver/src/tenant/timeline/import_pgdata/flow.rs index 760e82dd57..2ec9d86720 100644 --- a/pageserver/src/tenant/timeline/import_pgdata/flow.rs +++ b/pageserver/src/tenant/timeline/import_pgdata/flow.rs @@ -471,6 +471,8 @@ impl Plan { last_completed_job_idx = job_idx; if last_completed_job_idx % checkpoint_every == 0 { + tracing::info!(last_completed_job_idx, jobs=%jobs_in_plan, "Checkpointing import status"); + let progress = ShardImportProgressV1 { jobs: jobs_in_plan, completed: last_completed_job_idx, @@ -492,8 +494,6 @@ impl Plan { anyhow::anyhow!("Shut down while putting timeline import status") })?; } - - tracing::info!(last_completed_job_idx, jobs=%jobs_in_plan, "Checkpointing import status"); }, Some(Err(_)) => { anyhow::bail!( From e00fd45bba98dcc22b14e3555684245a43bee160 Mon Sep 17 00:00:00 2001 From: Erik Grinaker Date: Tue, 3 Jun 2025 14:20:34 +0200 Subject: [PATCH 18/43] page_api: remove smallvec (#12095) ## Problem The gRPC `page_api` domain types used smallvecs to avoid heap allocations in the common case where a single page is requested. However, this is pointless: the Protobuf types use a normal vec, and converting a smallvec into a vec always causes a heap allocation anyway. ## Summary of changes Use a normal `Vec` instead of a `SmallVec` in `page_api` domain types. --- Cargo.lock | 1 - pageserver/page_api/Cargo.toml | 1 - pageserver/page_api/src/model.rs | 13 ++++++------- pageserver/src/page_service.rs | 4 ++-- 4 files changed, 8 insertions(+), 11 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 9fc233e5ec..542a4528d3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4469,7 +4469,6 @@ dependencies = [ "pageserver_api", "postgres_ffi", "prost 0.13.5", - "smallvec", "thiserror 1.0.69", "tonic 0.13.1", "tonic-build", diff --git a/pageserver/page_api/Cargo.toml b/pageserver/page_api/Cargo.toml index 4f62c77eb2..e643b5749b 100644 --- a/pageserver/page_api/Cargo.toml +++ b/pageserver/page_api/Cargo.toml @@ -9,7 +9,6 @@ bytes.workspace = true pageserver_api.workspace = true postgres_ffi.workspace = true prost.workspace = true -smallvec.workspace = true thiserror.workspace = true tonic.workspace = true utils.workspace = true diff --git a/pageserver/page_api/src/model.rs b/pageserver/page_api/src/model.rs index 0268ab920b..a6853895d9 100644 --- a/pageserver/page_api/src/model.rs +++ b/pageserver/page_api/src/model.rs @@ -14,7 +14,6 @@ use std::fmt::Display; use bytes::Bytes; use postgres_ffi::Oid; -use smallvec::SmallVec; // TODO: split out Lsn, RelTag, SlruKind, Oid and other basic types to a separate crate, to avoid // pulling in all of their other crate dependencies when building the client. use utils::lsn::Lsn; @@ -302,7 +301,7 @@ pub struct GetPageRequest { /// Multiple pages will be executed as a single batch by the Pageserver, amortizing layer access /// costs and parallelizing them. This may increase the latency of any individual request, but /// improves the overall latency and throughput of the batch as a whole. - pub block_numbers: SmallVec<[u32; 1]>, + pub block_numbers: Vec, } impl TryFrom for GetPageRequest { @@ -320,7 +319,7 @@ impl TryFrom for GetPageRequest { .ok_or(ProtocolError::Missing("read_lsn"))? .try_into()?, rel: pb.rel.ok_or(ProtocolError::Missing("rel"))?.try_into()?, - block_numbers: pb.block_number.into(), + block_numbers: pb.block_number, }) } } @@ -337,7 +336,7 @@ impl TryFrom for proto::GetPageRequest { request_class: request.request_class.into(), read_lsn: Some(request.read_lsn.try_into()?), rel: Some(request.rel.into()), - block_number: request.block_numbers.into_vec(), + block_number: request.block_numbers, }) } } @@ -410,7 +409,7 @@ pub struct GetPageResponse { /// A string describing the status, if any. pub reason: Option, /// The 8KB page images, in the same order as the request. Empty if status != OK. - pub page_images: SmallVec<[Bytes; 1]>, + pub page_images: Vec, } impl From for GetPageResponse { @@ -419,7 +418,7 @@ impl From for GetPageResponse { request_id: pb.request_id, status_code: pb.status_code.into(), reason: Some(pb.reason).filter(|r| !r.is_empty()), - page_images: pb.page_image.into(), + page_images: pb.page_image, } } } @@ -430,7 +429,7 @@ impl From for proto::GetPageResponse { request_id: response.request_id, status_code: response.status_code.into(), reason: response.reason.unwrap_or_default(), - page_image: response.page_images.into_vec(), + page_image: response.page_images, } } } diff --git a/pageserver/src/page_service.rs b/pageserver/src/page_service.rs index b9ba4a3555..735c944970 100644 --- a/pageserver/src/page_service.rs +++ b/pageserver/src/page_service.rs @@ -765,7 +765,7 @@ impl PageStreamError { request_id, status_code, reason: Some(status.message().to_string()), - page_images: SmallVec::new(), + page_images: Vec::new(), } .into()) } @@ -3521,7 +3521,7 @@ impl GrpcPageServiceHandler { request_id: req.request_id, status_code: page_api::GetPageStatusCode::Ok, reason: None, - page_images: SmallVec::with_capacity(results.len()), + page_images: Vec::with_capacity(results.len()), }; for result in results { From 25fffd3a55b8c3dfba01631e94c25335cdcaa190 Mon Sep 17 00:00:00 2001 From: Trung Dinh Date: Tue, 3 Jun 2025 06:37:11 -0700 Subject: [PATCH 19/43] Validate max_batch_size against max_get_vectored_keys (#12052) ## Problem Setting `max_batch_size` to anything higher than `Timeline::MAX_GET_VECTORED_KEYS` will cause runtime error. We should rather fail fast at startup if this is the case. ## Summary of changes * Create `max_get_vectored_keys` as a new configuration (default to 32); * Validate `max_batch_size` against `max_get_vectored_keys` right at config parsing and validation. Closes https://github.com/neondatabase/neon/issues/11994 --- libs/pageserver_api/src/config.rs | 18 ++++++++++- pageserver/src/basebackup.rs | 2 +- pageserver/src/config.rs | 48 ++++++++++++++++++++++++++++- pageserver/src/metrics.rs | 6 ++-- pageserver/src/pgdatadir_mapping.rs | 10 +++--- pageserver/src/tenant.rs | 6 ++-- pageserver/src/tenant/timeline.rs | 18 ++++++----- 7 files changed, 86 insertions(+), 22 deletions(-) diff --git a/libs/pageserver_api/src/config.rs b/libs/pageserver_api/src/config.rs index 444983bd18..7ca623f8e5 100644 --- a/libs/pageserver_api/src/config.rs +++ b/libs/pageserver_api/src/config.rs @@ -181,6 +181,7 @@ pub struct ConfigToml { pub virtual_file_io_engine: Option, pub ingest_batch_size: u64, pub max_vectored_read_bytes: MaxVectoredReadBytes, + pub max_get_vectored_keys: MaxGetVectoredKeys, pub image_compression: ImageCompressionAlgorithm, pub timeline_offloading: bool, pub ephemeral_bytes_per_memory_kb: usize, @@ -229,7 +230,7 @@ pub enum PageServicePipeliningConfig { } #[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] pub struct PageServicePipeliningConfigPipelined { - /// Causes runtime errors if larger than max get_vectored batch size. + /// Failed config parsing and validation if larger than `max_get_vectored_keys`. pub max_batch_size: NonZeroUsize, pub execution: PageServiceProtocolPipelinedExecutionStrategy, // The default below is such that new versions of the software can start @@ -403,6 +404,16 @@ impl Default for EvictionOrder { #[serde(transparent)] pub struct MaxVectoredReadBytes(pub NonZeroUsize); +#[derive(Copy, Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)] +#[serde(transparent)] +pub struct MaxGetVectoredKeys(NonZeroUsize); + +impl MaxGetVectoredKeys { + pub fn get(&self) -> usize { + self.0.get() + } +} + /// Tenant-level configuration values, used for various purposes. #[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] #[serde(default)] @@ -587,6 +598,8 @@ pub mod defaults { /// That is, slightly above 128 kB. pub const DEFAULT_MAX_VECTORED_READ_BYTES: usize = 130 * 1024; // 130 KiB + pub const DEFAULT_MAX_GET_VECTORED_KEYS: usize = 32; + pub const DEFAULT_IMAGE_COMPRESSION: ImageCompressionAlgorithm = ImageCompressionAlgorithm::Zstd { level: Some(1) }; @@ -685,6 +698,9 @@ impl Default for ConfigToml { max_vectored_read_bytes: (MaxVectoredReadBytes( NonZeroUsize::new(DEFAULT_MAX_VECTORED_READ_BYTES).unwrap(), )), + max_get_vectored_keys: (MaxGetVectoredKeys( + NonZeroUsize::new(DEFAULT_MAX_GET_VECTORED_KEYS).unwrap(), + )), image_compression: (DEFAULT_IMAGE_COMPRESSION), timeline_offloading: true, ephemeral_bytes_per_memory_kb: (DEFAULT_EPHEMERAL_BYTES_PER_MEMORY_KB), diff --git a/pageserver/src/basebackup.rs b/pageserver/src/basebackup.rs index 4dba9d267c..2a0548b811 100644 --- a/pageserver/src/basebackup.rs +++ b/pageserver/src/basebackup.rs @@ -371,7 +371,7 @@ where .await? .partition( self.timeline.get_shard_identity(), - Timeline::MAX_GET_VECTORED_KEYS * BLCKSZ as u64, + self.timeline.conf.max_get_vectored_keys.get() as u64 * BLCKSZ as u64, ); let mut slru_builder = SlruSegmentsBuilder::new(&mut self.ar); diff --git a/pageserver/src/config.rs b/pageserver/src/config.rs index 89f7539722..628d4f6021 100644 --- a/pageserver/src/config.rs +++ b/pageserver/src/config.rs @@ -14,7 +14,10 @@ use std::time::Duration; use anyhow::{Context, bail, ensure}; use camino::{Utf8Path, Utf8PathBuf}; use once_cell::sync::OnceCell; -use pageserver_api::config::{DiskUsageEvictionTaskConfig, MaxVectoredReadBytes, PostHogConfig}; +use pageserver_api::config::{ + DiskUsageEvictionTaskConfig, MaxGetVectoredKeys, MaxVectoredReadBytes, + PageServicePipeliningConfig, PageServicePipeliningConfigPipelined, PostHogConfig, +}; use pageserver_api::models::ImageCompressionAlgorithm; use pageserver_api::shard::TenantShardId; use pem::Pem; @@ -185,6 +188,9 @@ pub struct PageServerConf { pub max_vectored_read_bytes: MaxVectoredReadBytes, + /// Maximum number of keys to be read in a single get_vectored call. + pub max_get_vectored_keys: MaxGetVectoredKeys, + pub image_compression: ImageCompressionAlgorithm, /// Whether to offload archived timelines automatically @@ -404,6 +410,7 @@ impl PageServerConf { secondary_download_concurrency, ingest_batch_size, max_vectored_read_bytes, + max_get_vectored_keys, image_compression, timeline_offloading, ephemeral_bytes_per_memory_kb, @@ -470,6 +477,7 @@ impl PageServerConf { secondary_download_concurrency, ingest_batch_size, max_vectored_read_bytes, + max_get_vectored_keys, image_compression, timeline_offloading, ephemeral_bytes_per_memory_kb, @@ -598,6 +606,19 @@ impl PageServerConf { ) })?; + if let PageServicePipeliningConfig::Pipelined(PageServicePipeliningConfigPipelined { + max_batch_size, + .. + }) = conf.page_service_pipelining + { + if max_batch_size.get() > conf.max_get_vectored_keys.get() { + return Err(anyhow::anyhow!( + "`max_batch_size` ({max_batch_size}) must be less than or equal to `max_get_vectored_keys` ({})", + conf.max_get_vectored_keys.get() + )); + } + }; + Ok(conf) } @@ -685,6 +706,7 @@ impl ConfigurableSemaphore { mod tests { use camino::Utf8PathBuf; + use rstest::rstest; use utils::id::NodeId; use super::PageServerConf; @@ -724,4 +746,28 @@ mod tests { PageServerConf::parse_and_validate(NodeId(0), config_toml, &workdir) .expect_err("parse_and_validate should fail for endpoint without scheme"); } + + #[rstest] + #[case(32, 32, true)] + #[case(64, 32, false)] + #[case(64, 64, true)] + #[case(128, 128, true)] + fn test_config_max_batch_size_is_valid( + #[case] max_batch_size: usize, + #[case] max_get_vectored_keys: usize, + #[case] is_valid: bool, + ) { + let input = format!( + r#" + control_plane_api = "http://localhost:6666" + max_get_vectored_keys = {max_get_vectored_keys} + page_service_pipelining = {{ mode="pipelined", execution="concurrent-futures", max_batch_size={max_batch_size}, batching="uniform-lsn" }} + "#, + ); + let config_toml = toml_edit::de::from_str::(&input) + .expect("config has valid fields"); + let workdir = Utf8PathBuf::from("/nonexistent"); + let result = PageServerConf::parse_and_validate(NodeId(0), config_toml, &workdir); + assert_eq!(result.is_ok(), is_valid); + } } diff --git a/pageserver/src/metrics.rs b/pageserver/src/metrics.rs index a9b2f1b7e0..4bef9b44a7 100644 --- a/pageserver/src/metrics.rs +++ b/pageserver/src/metrics.rs @@ -15,6 +15,7 @@ use metrics::{ register_int_gauge, register_int_gauge_vec, register_uint_gauge, register_uint_gauge_vec, }; use once_cell::sync::Lazy; +use pageserver_api::config::defaults::DEFAULT_MAX_GET_VECTORED_KEYS; use pageserver_api::config::{ PageServicePipeliningConfig, PageServicePipeliningConfigPipelined, PageServiceProtocolPipelinedBatchingStrategy, PageServiceProtocolPipelinedExecutionStrategy, @@ -32,7 +33,6 @@ use crate::config::PageServerConf; use crate::context::{PageContentKind, RequestContext}; use crate::pgdatadir_mapping::DatadirModificationStats; use crate::task_mgr::TaskKind; -use crate::tenant::Timeline; use crate::tenant::layer_map::LayerMap; use crate::tenant::mgr::TenantSlot; use crate::tenant::storage_layer::{InMemoryLayer, PersistentLayerDesc}; @@ -1939,7 +1939,7 @@ static SMGR_QUERY_TIME_GLOBAL: Lazy = Lazy::new(|| { }); static PAGE_SERVICE_BATCH_SIZE_BUCKETS_GLOBAL: Lazy> = Lazy::new(|| { - (1..=u32::try_from(Timeline::MAX_GET_VECTORED_KEYS).unwrap()) + (1..=u32::try_from(DEFAULT_MAX_GET_VECTORED_KEYS).unwrap()) .map(|v| v.into()) .collect() }); @@ -1957,7 +1957,7 @@ static PAGE_SERVICE_BATCH_SIZE_BUCKETS_PER_TIMELINE: Lazy> = Lazy::new( let mut buckets = Vec::new(); for i in 0.. { let bucket = 1 << i; - if bucket > u32::try_from(Timeline::MAX_GET_VECTORED_KEYS).unwrap() { + if bucket > u32::try_from(DEFAULT_MAX_GET_VECTORED_KEYS).unwrap() { break; } buckets.push(bucket.into()); diff --git a/pageserver/src/pgdatadir_mapping.rs b/pageserver/src/pgdatadir_mapping.rs index b6f11b744b..633d62210d 100644 --- a/pageserver/src/pgdatadir_mapping.rs +++ b/pageserver/src/pgdatadir_mapping.rs @@ -431,10 +431,10 @@ impl Timeline { GetVectoredError::InvalidLsn(e) => { Err(anyhow::anyhow!("invalid LSN: {e:?}").into()) } - // NB: this should never happen in practice because we limit MAX_GET_VECTORED_KEYS + // NB: this should never happen in practice because we limit batch size to be smaller than max_get_vectored_keys // TODO: we can prevent this error class by moving this check into the type system - GetVectoredError::Oversized(err) => { - Err(anyhow::anyhow!("batching oversized: {err:?}").into()) + GetVectoredError::Oversized(err, max) => { + Err(anyhow::anyhow!("batching oversized: {err} > {max}").into()) } }; @@ -719,7 +719,7 @@ impl Timeline { let batches = keyspace.partition( self.get_shard_identity(), - Timeline::MAX_GET_VECTORED_KEYS * BLCKSZ as u64, + self.conf.max_get_vectored_keys.get() as u64 * BLCKSZ as u64, ); let io_concurrency = IoConcurrency::spawn_from_conf( @@ -959,7 +959,7 @@ impl Timeline { let batches = keyspace.partition( self.get_shard_identity(), - Timeline::MAX_GET_VECTORED_KEYS * BLCKSZ as u64, + self.conf.max_get_vectored_keys.get() as u64 * BLCKSZ as u64, ); let io_concurrency = IoConcurrency::spawn_from_conf( diff --git a/pageserver/src/tenant.rs b/pageserver/src/tenant.rs index 308ada3fa1..f9fdc143b4 100644 --- a/pageserver/src/tenant.rs +++ b/pageserver/src/tenant.rs @@ -7197,7 +7197,7 @@ mod tests { let end = desc .key_range .start - .add(Timeline::MAX_GET_VECTORED_KEYS.try_into().unwrap()); + .add(tenant.conf.max_get_vectored_keys.get() as u32); reads.push(KeySpace { ranges: vec![start..end], }); @@ -11260,11 +11260,11 @@ mod tests { let mut keyspaces_at_lsn: HashMap = HashMap::default(); let mut used_keys: HashSet = HashSet::default(); - while used_keys.len() < Timeline::MAX_GET_VECTORED_KEYS as usize { + while used_keys.len() < tenant.conf.max_get_vectored_keys.get() { let selected_lsn = interesting_lsns.choose(&mut random).expect("not empty"); let mut selected_key = start_key.add(random.gen_range(0..KEY_DIMENSION_SIZE)); - while used_keys.len() < Timeline::MAX_GET_VECTORED_KEYS as usize { + while used_keys.len() < tenant.conf.max_get_vectored_keys.get() { if used_keys.contains(&selected_key) || selected_key >= start_key.add(KEY_DIMENSION_SIZE) { diff --git a/pageserver/src/tenant/timeline.rs b/pageserver/src/tenant/timeline.rs index 9ddbe404d2..8fdf84b6d3 100644 --- a/pageserver/src/tenant/timeline.rs +++ b/pageserver/src/tenant/timeline.rs @@ -817,8 +817,8 @@ pub(crate) enum GetVectoredError { #[error("timeline shutting down")] Cancelled, - #[error("requested too many keys: {0} > {}", Timeline::MAX_GET_VECTORED_KEYS)] - Oversized(u64), + #[error("requested too many keys: {0} > {1}")] + Oversized(u64, u64), #[error("requested at invalid LSN: {0}")] InvalidLsn(Lsn), @@ -1019,7 +1019,7 @@ impl From for PageReconstructError { match e { GetVectoredError::Cancelled => PageReconstructError::Cancelled, GetVectoredError::InvalidLsn(_) => PageReconstructError::Other(anyhow!("Invalid LSN")), - err @ GetVectoredError::Oversized(_) => PageReconstructError::Other(err.into()), + err @ GetVectoredError::Oversized(_, _) => PageReconstructError::Other(err.into()), GetVectoredError::MissingKey(err) => PageReconstructError::MissingKey(err), GetVectoredError::GetReadyAncestorError(err) => PageReconstructError::from(err), GetVectoredError::Other(err) => PageReconstructError::Other(err), @@ -1199,7 +1199,6 @@ impl Timeline { } } - pub(crate) const MAX_GET_VECTORED_KEYS: u64 = 32; pub(crate) const LAYERS_VISITED_WARN_THRESHOLD: u32 = 100; /// Look up multiple page versions at a given LSN @@ -1214,9 +1213,12 @@ impl Timeline { ) -> Result>, GetVectoredError> { let total_keyspace = query.total_keyspace(); - let key_count = total_keyspace.total_raw_size().try_into().unwrap(); - if key_count > Timeline::MAX_GET_VECTORED_KEYS { - return Err(GetVectoredError::Oversized(key_count)); + let key_count = total_keyspace.total_raw_size(); + if key_count > self.conf.max_get_vectored_keys.get() { + return Err(GetVectoredError::Oversized( + key_count as u64, + self.conf.max_get_vectored_keys.get() as u64, + )); } for range in &total_keyspace.ranges { @@ -5270,7 +5272,7 @@ impl Timeline { key = key.next(); // Maybe flush `key_rest_accum` - if key_request_accum.raw_size() >= Timeline::MAX_GET_VECTORED_KEYS + if key_request_accum.raw_size() >= self.conf.max_get_vectored_keys.get() as u64 || (last_key_in_range && key_request_accum.raw_size() > 0) { let query = From 5bdba70f7dc1c7a489769029d762f1bbea9c727e Mon Sep 17 00:00:00 2001 From: Erik Grinaker Date: Tue, 3 Jun 2025 15:50:41 +0200 Subject: [PATCH 20/43] =?UTF-8?q?page=5Fapi:=20only=20validate=20Protobuf?= =?UTF-8?q?=20=E2=86=92=20domain=20type=20conversion=20(#12115)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Problem Currently, `page_api` domain types validate message invariants both when converting Protobuf → domain and domain → Protobuf. This is annoying for clients, because they can't use stream combinators to convert streamed requests (needed for hot path performance), and also performs the validation twice in the common case. Blocks #12099. ## Summary of changes Only validate the Protobuf → domain type conversion, i.e. on the receiver side, and make domain → Protobuf infallible. This is where it matters -- the Protobuf types are less strict than the domain types, and receivers should expect all sorts of junk from senders (they're not required to validate anyway, and can just construct an invalid message manually). Also adds a missing `impl From for proto::CheckRelExistsRequest`. --- pageserver/page_api/src/model.rs | 142 +++++++++++++------------------ pageserver/src/page_service.rs | 4 +- 2 files changed, 62 insertions(+), 84 deletions(-) diff --git a/pageserver/page_api/src/model.rs b/pageserver/page_api/src/model.rs index a6853895d9..1a08d04cc1 100644 --- a/pageserver/page_api/src/model.rs +++ b/pageserver/page_api/src/model.rs @@ -9,6 +9,11 @@ //! - Use more precise datatypes, e.g. Lsn and uints shorter than 32 bits. //! //! - Validate protocol invariants, via try_from() and try_into(). +//! +//! Validation only happens on the receiver side, i.e. when converting from Protobuf to domain +//! types. This is where it matters -- the Protobuf types are less strict than the domain types, and +//! receivers should expect all sorts of junk from senders. This also allows the sender to use e.g. +//! stream combinators without dealing with errors, and avoids validating the same message twice. use std::fmt::Display; @@ -71,47 +76,35 @@ impl Display for ReadLsn { } } -impl ReadLsn { - /// Validates the ReadLsn. - pub fn validate(&self) -> Result<(), ProtocolError> { - if self.request_lsn == Lsn::INVALID { - return Err(ProtocolError::invalid("request_lsn", self.request_lsn)); - } - if self.not_modified_since_lsn > Some(self.request_lsn) { - return Err(ProtocolError::invalid( - "not_modified_since_lsn", - self.not_modified_since_lsn, - )); - } - Ok(()) - } -} - impl TryFrom for ReadLsn { type Error = ProtocolError; fn try_from(pb: proto::ReadLsn) -> Result { - let read_lsn = Self { + if pb.request_lsn == 0 { + return Err(ProtocolError::invalid("request_lsn", pb.request_lsn)); + } + if pb.not_modified_since_lsn > pb.request_lsn { + return Err(ProtocolError::invalid( + "not_modified_since_lsn", + pb.not_modified_since_lsn, + )); + } + Ok(Self { request_lsn: Lsn(pb.request_lsn), not_modified_since_lsn: match pb.not_modified_since_lsn { 0 => None, lsn => Some(Lsn(lsn)), }, - }; - read_lsn.validate()?; - Ok(read_lsn) + }) } } -impl TryFrom for proto::ReadLsn { - type Error = ProtocolError; - - fn try_from(read_lsn: ReadLsn) -> Result { - read_lsn.validate()?; - Ok(Self { +impl From for proto::ReadLsn { + fn from(read_lsn: ReadLsn) -> Self { + Self { request_lsn: read_lsn.request_lsn.0, not_modified_since_lsn: read_lsn.not_modified_since_lsn.unwrap_or_default().0, - }) + } } } @@ -166,6 +159,15 @@ impl TryFrom for CheckRelExistsRequest { } } +impl From for proto::CheckRelExistsRequest { + fn from(request: CheckRelExistsRequest) -> Self { + Self { + read_lsn: Some(request.read_lsn.into()), + rel: Some(request.rel.into()), + } + } +} + pub type CheckRelExistsResponse = bool; impl From for CheckRelExistsResponse { @@ -203,14 +205,12 @@ impl TryFrom for GetBaseBackupRequest { } } -impl TryFrom for proto::GetBaseBackupRequest { - type Error = ProtocolError; - - fn try_from(request: GetBaseBackupRequest) -> Result { - Ok(Self { - read_lsn: Some(request.read_lsn.try_into()?), +impl From for proto::GetBaseBackupRequest { + fn from(request: GetBaseBackupRequest) -> Self { + Self { + read_lsn: Some(request.read_lsn.into()), replica: request.replica, - }) + } } } @@ -227,14 +227,9 @@ impl TryFrom for GetBaseBackupResponseChunk { } } -impl TryFrom for proto::GetBaseBackupResponseChunk { - type Error = ProtocolError; - - fn try_from(chunk: GetBaseBackupResponseChunk) -> Result { - if chunk.is_empty() { - return Err(ProtocolError::Missing("chunk")); - } - Ok(Self { chunk }) +impl From for proto::GetBaseBackupResponseChunk { + fn from(chunk: GetBaseBackupResponseChunk) -> Self { + Self { chunk } } } @@ -259,14 +254,12 @@ impl TryFrom for GetDbSizeRequest { } } -impl TryFrom for proto::GetDbSizeRequest { - type Error = ProtocolError; - - fn try_from(request: GetDbSizeRequest) -> Result { - Ok(Self { - read_lsn: Some(request.read_lsn.try_into()?), +impl From for proto::GetDbSizeRequest { + fn from(request: GetDbSizeRequest) -> Self { + Self { + read_lsn: Some(request.read_lsn.into()), db_oid: request.db_oid, - }) + } } } @@ -324,20 +317,15 @@ impl TryFrom for GetPageRequest { } } -impl TryFrom for proto::GetPageRequest { - type Error = ProtocolError; - - fn try_from(request: GetPageRequest) -> Result { - if request.block_numbers.is_empty() { - return Err(ProtocolError::Missing("block_number")); - } - Ok(Self { +impl From for proto::GetPageRequest { + fn from(request: GetPageRequest) -> Self { + Self { request_id: request.request_id, request_class: request.request_class.into(), - read_lsn: Some(request.read_lsn.try_into()?), + read_lsn: Some(request.read_lsn.into()), rel: Some(request.rel.into()), block_number: request.block_numbers, - }) + } } } @@ -518,14 +506,12 @@ impl TryFrom for GetRelSizeRequest { } } -impl TryFrom for proto::GetRelSizeRequest { - type Error = ProtocolError; - - fn try_from(request: GetRelSizeRequest) -> Result { - Ok(Self { - read_lsn: Some(request.read_lsn.try_into()?), +impl From for proto::GetRelSizeRequest { + fn from(request: GetRelSizeRequest) -> Self { + Self { + read_lsn: Some(request.read_lsn.into()), rel: Some(request.rel.into()), - }) + } } } @@ -568,15 +554,13 @@ impl TryFrom for GetSlruSegmentRequest { } } -impl TryFrom for proto::GetSlruSegmentRequest { - type Error = ProtocolError; - - fn try_from(request: GetSlruSegmentRequest) -> Result { - Ok(Self { - read_lsn: Some(request.read_lsn.try_into()?), +impl From for proto::GetSlruSegmentRequest { + fn from(request: GetSlruSegmentRequest) -> Self { + Self { + read_lsn: Some(request.read_lsn.into()), kind: request.kind as u32, segno: request.segno, - }) + } } } @@ -593,15 +577,9 @@ impl TryFrom for GetSlruSegmentResponse { } } -impl TryFrom for proto::GetSlruSegmentResponse { - type Error = ProtocolError; - - fn try_from(segment: GetSlruSegmentResponse) -> Result { - // TODO: can a segment legitimately be empty? - if segment.is_empty() { - return Err(ProtocolError::Missing("segment")); - } - Ok(Self { segment }) +impl From for proto::GetSlruSegmentResponse { + fn from(segment: GetSlruSegmentResponse) -> Self { + Self { segment } } } diff --git a/pageserver/src/page_service.rs b/pageserver/src/page_service.rs index 735c944970..77e5f0a92b 100644 --- a/pageserver/src/page_service.rs +++ b/pageserver/src/page_service.rs @@ -3660,7 +3660,7 @@ impl proto::PageService for GrpcPageServiceHandler { if chunk.is_empty() { break; } - yield proto::GetBaseBackupResponseChunk::try_from(chunk.clone().freeze())?; + yield proto::GetBaseBackupResponseChunk::from(chunk.clone().freeze()); chunk.clear(); } } @@ -3806,7 +3806,7 @@ impl proto::PageService for GrpcPageServiceHandler { let resp = PageServerHandler::handle_get_slru_segment_request(&timeline, &req, &ctx).await?; let resp: page_api::GetSlruSegmentResponse = resp.segment; - Ok(tonic::Response::new(resp.try_into()?)) + Ok(tonic::Response::new(resp.into())) } } From a963aab14b80062acdfac54af29f876b0f044552 Mon Sep 17 00:00:00 2001 From: Vlad Lazar Date: Tue, 3 Jun 2025 15:57:36 +0100 Subject: [PATCH 21/43] pagserver: set default wal receiver proto to interpreted (#12100) ## Problem This is already the default in production and in our test suite. ## Summary of changes Set the default proto to interpreted to reduce friction when spinning up new regions or cells. --- libs/pageserver_api/src/config.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/libs/pageserver_api/src/config.rs b/libs/pageserver_api/src/config.rs index 7ca623f8e5..237833b9de 100644 --- a/libs/pageserver_api/src/config.rs +++ b/libs/pageserver_api/src/config.rs @@ -608,7 +608,10 @@ pub mod defaults { pub const DEFAULT_IO_BUFFER_ALIGNMENT: usize = 512; pub const DEFAULT_WAL_RECEIVER_PROTOCOL: utils::postgres_client::PostgresClientProtocol = - utils::postgres_client::PostgresClientProtocol::Vanilla; + utils::postgres_client::PostgresClientProtocol::Interpreted { + format: utils::postgres_client::InterpretedFormat::Protobuf, + compression: Some(utils::postgres_client::Compression::Zstd { level: 1 }), + }; pub const DEFAULT_SSL_KEY_FILE: &str = "server.key"; pub const DEFAULT_SSL_CERT_FILE: &str = "server.crt"; From 4a3f32bf4a2865db28aa8d99e80e4766169e0edf Mon Sep 17 00:00:00 2001 From: Tristan Partin Date: Tue, 3 Jun 2025 11:00:56 -0500 Subject: [PATCH 22/43] Clean up compute_tools::http::JsonResponse::invalid_status() (#12110) JsonResponse::error() properly logs an error message which can be read in the compute logs. invalid_status() was not going through that helper function, thus not logging anything. Signed-off-by: Tristan Partin --- compute_tools/src/http/mod.rs | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/compute_tools/src/http/mod.rs b/compute_tools/src/http/mod.rs index 9ecc1b0093..9b01def966 100644 --- a/compute_tools/src/http/mod.rs +++ b/compute_tools/src/http/mod.rs @@ -48,11 +48,9 @@ impl JsonResponse { /// Create an error response related to the compute being in an invalid state pub(self) fn invalid_status(status: ComputeStatus) -> Response { - Self::create_response( + Self::error( StatusCode::PRECONDITION_FAILED, - &GenericAPIError { - error: format!("invalid compute status: {status}"), - }, + format!("invalid compute status: {status}"), ) } } From c698cee19ac8378fe536d65140dd24009e65ab35 Mon Sep 17 00:00:00 2001 From: Mikhail Date: Wed, 4 Jun 2025 06:38:03 +0100 Subject: [PATCH 23/43] ComputeSpec: prewarm_lfc_on_startup -> autoprewarm (#12120) https://github.com/neondatabase/cloud/pull/29472 https://github.com/neondatabase/cloud/issues/26346 --- compute_tools/src/compute.rs | 10 +++++----- compute_tools/tests/pg_helpers_tests.rs | 2 +- control_plane/src/endpoint.rs | 2 +- libs/compute_api/src/spec.rs | 4 ++-- libs/compute_api/tests/cluster_spec.json | 2 +- 5 files changed, 10 insertions(+), 10 deletions(-) diff --git a/compute_tools/src/compute.rs b/compute_tools/src/compute.rs index d678b7d670..1ed71180d0 100644 --- a/compute_tools/src/compute.rs +++ b/compute_tools/src/compute.rs @@ -396,7 +396,7 @@ impl ComputeNode { // because QEMU will already have its memory allocated from the host, and // the necessary binaries will already be cached. if cli_spec.is_none() { - this.prewarm_postgres()?; + this.prewarm_postgres_vm_memory()?; } // Set the up metric with Empty status before starting the HTTP server. @@ -779,7 +779,7 @@ impl ComputeNode { // Spawn the extension stats background task self.spawn_extension_stats_task(); - if pspec.spec.prewarm_lfc_on_startup { + if pspec.spec.autoprewarm { self.prewarm_lfc(); } Ok(()) @@ -1307,8 +1307,8 @@ impl ComputeNode { } /// Start and stop a postgres process to warm up the VM for startup. - pub fn prewarm_postgres(&self) -> Result<()> { - info!("prewarming"); + pub fn prewarm_postgres_vm_memory(&self) -> Result<()> { + info!("prewarming VM memory"); // Create pgdata let pgdata = &format!("{}.warmup", self.params.pgdata); @@ -1350,7 +1350,7 @@ impl ComputeNode { kill(pm_pid, Signal::SIGQUIT)?; info!("sent SIGQUIT signal"); pg.wait()?; - info!("done prewarming"); + info!("done prewarming vm memory"); // clean up let _ok = fs::remove_dir_all(pgdata); diff --git a/compute_tools/tests/pg_helpers_tests.rs b/compute_tools/tests/pg_helpers_tests.rs index 04b6ed2256..2b865c75d0 100644 --- a/compute_tools/tests/pg_helpers_tests.rs +++ b/compute_tools/tests/pg_helpers_tests.rs @@ -30,7 +30,7 @@ mod pg_helpers_tests { r#"fsync = off wal_level = logical hot_standby = on -prewarm_lfc_on_startup = off +autoprewarm = off neon.safekeepers = '127.0.0.1:6502,127.0.0.1:6503,127.0.0.1:6501' wal_log_hints = on log_connections = on diff --git a/control_plane/src/endpoint.rs b/control_plane/src/endpoint.rs index 708745446d..774a0053f8 100644 --- a/control_plane/src/endpoint.rs +++ b/control_plane/src/endpoint.rs @@ -747,7 +747,7 @@ impl Endpoint { logs_export_host: None::, endpoint_storage_addr: Some(endpoint_storage_addr), endpoint_storage_token: Some(endpoint_storage_token), - prewarm_lfc_on_startup: false, + autoprewarm: false, }; // this strange code is needed to support respec() in tests diff --git a/libs/compute_api/src/spec.rs b/libs/compute_api/src/spec.rs index 09b550b96c..9b7cf43bb9 100644 --- a/libs/compute_api/src/spec.rs +++ b/libs/compute_api/src/spec.rs @@ -178,9 +178,9 @@ pub struct ComputeSpec { /// JWT for authorizing requests to endpoint storage service pub endpoint_storage_token: Option, - /// If true, download LFC state from endpoint_storage and pass it to Postgres on startup + /// Download LFC state from endpoint_storage and pass it to Postgres on startup #[serde(default)] - pub prewarm_lfc_on_startup: bool, + pub autoprewarm: bool, } /// Feature flag to signal `compute_ctl` to enable certain experimental functionality. diff --git a/libs/compute_api/tests/cluster_spec.json b/libs/compute_api/tests/cluster_spec.json index 30e788a601..2dd2aae015 100644 --- a/libs/compute_api/tests/cluster_spec.json +++ b/libs/compute_api/tests/cluster_spec.json @@ -85,7 +85,7 @@ "vartype": "bool" }, { - "name": "prewarm_lfc_on_startup", + "name": "autoprewarm", "value": "off", "vartype": "bool" }, From c567ed0de0837bc3923d39ca8706ac353a2ba50b Mon Sep 17 00:00:00 2001 From: "Alex Chi Z." <4198311+skyzh@users.noreply.github.com> Date: Wed, 4 Jun 2025 14:41:42 +0800 Subject: [PATCH 24/43] feat(pageserver): feature flag counter metrics (#12112) ## Problem Part of https://github.com/neondatabase/neon/issues/11813 ## Summary of changes Add a counter on the feature evaluation outcome and we will set up alerts for too many failed evaluations in the future. Signed-off-by: Alex Chi Z --- libs/posthog_client_lite/src/lib.rs | 10 ++++++++ pageserver/src/feature_resolver.rs | 36 +++++++++++++++++++++++++---- pageserver/src/metrics.rs | 9 ++++++++ 3 files changed, 50 insertions(+), 5 deletions(-) diff --git a/libs/posthog_client_lite/src/lib.rs b/libs/posthog_client_lite/src/lib.rs index ff12051196..281a8f8011 100644 --- a/libs/posthog_client_lite/src/lib.rs +++ b/libs/posthog_client_lite/src/lib.rs @@ -22,6 +22,16 @@ pub enum PostHogEvaluationError { Internal(String), } +impl PostHogEvaluationError { + pub fn as_variant_str(&self) -> &'static str { + match self { + PostHogEvaluationError::NotAvailable(_) => "not_available", + PostHogEvaluationError::NoConditionGroupMatched => "no_condition_group_matched", + PostHogEvaluationError::Internal(_) => "internal", + } + } +} + #[derive(Deserialize)] pub struct LocalEvaluationResponse { pub flags: Vec, diff --git a/pageserver/src/feature_resolver.rs b/pageserver/src/feature_resolver.rs index 7e31b930d0..7463186c06 100644 --- a/pageserver/src/feature_resolver.rs +++ b/pageserver/src/feature_resolver.rs @@ -6,7 +6,7 @@ use posthog_client_lite::{ use tokio_util::sync::CancellationToken; use utils::id::TenantId; -use crate::config::PageServerConf; +use crate::{config::PageServerConf, metrics::FEATURE_FLAG_EVALUATION}; #[derive(Clone)] pub struct FeatureResolver { @@ -55,11 +55,24 @@ impl FeatureResolver { tenant_id: TenantId, ) -> Result { if let Some(inner) = &self.inner { - inner.feature_store().evaluate_multivariate( + let res = inner.feature_store().evaluate_multivariate( flag_key, &tenant_id.to_string(), &HashMap::new(), - ) + ); + match &res { + Ok(value) => { + FEATURE_FLAG_EVALUATION + .with_label_values(&[flag_key, "ok", value]) + .inc(); + } + Err(e) => { + FEATURE_FLAG_EVALUATION + .with_label_values(&[flag_key, "error", e.as_variant_str()]) + .inc(); + } + } + res } else { Err(PostHogEvaluationError::NotAvailable( "PostHog integration is not enabled".to_string(), @@ -80,11 +93,24 @@ impl FeatureResolver { tenant_id: TenantId, ) -> Result<(), PostHogEvaluationError> { if let Some(inner) = &self.inner { - inner.feature_store().evaluate_boolean( + let res = inner.feature_store().evaluate_boolean( flag_key, &tenant_id.to_string(), &HashMap::new(), - ) + ); + match &res { + Ok(()) => { + FEATURE_FLAG_EVALUATION + .with_label_values(&[flag_key, "ok", "true"]) + .inc(); + } + Err(e) => { + FEATURE_FLAG_EVALUATION + .with_label_values(&[flag_key, "error", e.as_variant_str()]) + .inc(); + } + } + res } else { Err(PostHogEvaluationError::NotAvailable( "PostHog integration is not enabled".to_string(), diff --git a/pageserver/src/metrics.rs b/pageserver/src/metrics.rs index 4bef9b44a7..3173ab221f 100644 --- a/pageserver/src/metrics.rs +++ b/pageserver/src/metrics.rs @@ -446,6 +446,15 @@ static PAGE_CACHE_ERRORS: Lazy = Lazy::new(|| { .expect("failed to define a metric") }); +pub(crate) static FEATURE_FLAG_EVALUATION: Lazy = Lazy::new(|| { + register_counter_vec!( + "pageserver_feature_flag_evaluation", + "Number of times a feature flag is evaluated", + &["flag_key", "status", "value"], + ) + .unwrap() +}); + #[derive(IntoStaticStr)] #[strum(serialize_all = "kebab_case")] pub(crate) enum PageCacheErrorKind { From 208cbd52d44fb5dd322e6190cdd1ab25e8160d24 Mon Sep 17 00:00:00 2001 From: a-masterov <72613290+a-masterov@users.noreply.github.com> Date: Wed, 4 Jun 2025 11:57:31 +0200 Subject: [PATCH 25/43] Add postgis to the test image (#11672) ## Problem We don't currently run tests for PostGIS in our test environment. ## Summary of Changes - Added PostGIS test support for PostgreSQL v16 and v17 - Configured different PostGIS versions based on PostgreSQL version: - PostgreSQL v17: PostGIS 3.5.0 - PostgreSQL v14/v15/v16: PostGIS 3.3.3 - Added necessary test scripts and configurations This ensures our PostgreSQL implementation remains compatible with this widely-used extension. --------- Co-authored-by: Alexander Bayandin Co-authored-by: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> --- compute/compute-node.Dockerfile | 18 +- docker-compose/compute_wrapper/Dockerfile | 2 +- docker-compose/docker-compose.yml | 7 +- docker-compose/docker_compose_test.sh | 11 +- .../ext-src/postgis-src/README-Neon.md | 70 ++++++ .../ext-src/postgis-src/neon-test.sh | 9 + .../postgis-src/postgis-no-upgrade-test.patch | 21 ++ .../postgis-src/postgis-regular-v16.patch | 198 ++++++++++++++++ .../postgis-src/postgis-regular-v17.patch | 218 ++++++++++++++++++ .../postgis-src/raster_outdb_template.sql | 46 ++++ .../ext-src/postgis-src/regular-test.sh | 17 ++ docker-compose/run-tests.sh | 4 + 12 files changed, 615 insertions(+), 6 deletions(-) create mode 100644 docker-compose/ext-src/postgis-src/README-Neon.md create mode 100755 docker-compose/ext-src/postgis-src/neon-test.sh create mode 100644 docker-compose/ext-src/postgis-src/postgis-no-upgrade-test.patch create mode 100644 docker-compose/ext-src/postgis-src/postgis-regular-v16.patch create mode 100644 docker-compose/ext-src/postgis-src/postgis-regular-v17.patch create mode 100644 docker-compose/ext-src/postgis-src/raster_outdb_template.sql create mode 100755 docker-compose/ext-src/postgis-src/regular-test.sh diff --git a/compute/compute-node.Dockerfile b/compute/compute-node.Dockerfile index 2afdde0cfa..6ece9c6566 100644 --- a/compute/compute-node.Dockerfile +++ b/compute/compute-node.Dockerfile @@ -297,6 +297,7 @@ RUN ./autogen.sh && \ ./configure --with-sfcgal=/usr/local/bin/sfcgal-config && \ make -j $(getconf _NPROCESSORS_ONLN) && \ make -j $(getconf _NPROCESSORS_ONLN) install && \ + make staged-install && \ cd extensions/postgis && \ make clean && \ make -j $(getconf _NPROCESSORS_ONLN) install && \ @@ -1842,10 +1843,25 @@ RUN make PG_VERSION="${PG_VERSION:?}" -C compute FROM pg-build AS extension-tests ARG PG_VERSION +# This is required for the PostGIS test +RUN apt-get update && case $DEBIAN_VERSION in \ + bullseye) \ + apt-get install -y libproj19 libgdal28 time; \ + ;; \ + bookworm) \ + apt-get install -y libgdal32 libproj25 time; \ + ;; \ + *) \ + echo "Unknown Debian version ${DEBIAN_VERSION}" && exit 1 \ + ;; \ + esac + COPY docker-compose/ext-src/ /ext-src/ COPY --from=pg-build /postgres /postgres -#COPY --from=postgis-src /ext-src/ /ext-src/ +COPY --from=postgis-build /usr/local/pgsql/ /usr/local/pgsql/ +COPY --from=postgis-build /ext-src/postgis-src /ext-src/postgis-src +COPY --from=postgis-build /sfcgal/* /usr COPY --from=plv8-src /ext-src/ /ext-src/ COPY --from=h3-pg-src /ext-src/h3-pg-src /ext-src/h3-pg-src COPY --from=postgresql-unit-src /ext-src/ /ext-src/ diff --git a/docker-compose/compute_wrapper/Dockerfile b/docker-compose/compute_wrapper/Dockerfile index 9ef831a9cd..b89e69c650 100644 --- a/docker-compose/compute_wrapper/Dockerfile +++ b/docker-compose/compute_wrapper/Dockerfile @@ -13,6 +13,6 @@ RUN echo 'Acquire::Retries "5";' > /etc/apt/apt.conf.d/80-retries && \ jq \ netcat-openbsd #This is required for the pg_hintplan test -RUN mkdir -p /ext-src/pg_hint_plan-src /postgres/contrib/file_fdw && chown postgres /ext-src/pg_hint_plan-src /postgres/contrib/file_fdw +RUN mkdir -p /ext-src/pg_hint_plan-src /postgres/contrib/file_fdw /ext-src/postgis-src/ && chown postgres /ext-src/pg_hint_plan-src /postgres/contrib/file_fdw /ext-src/postgis-src USER postgres diff --git a/docker-compose/docker-compose.yml b/docker-compose/docker-compose.yml index fd3ad1fffc..2519b75c7f 100644 --- a/docker-compose/docker-compose.yml +++ b/docker-compose/docker-compose.yml @@ -186,13 +186,14 @@ services: neon-test-extensions: profiles: ["test-extensions"] - image: ${REPOSITORY:-ghcr.io/neondatabase}/neon-test-extensions-v${PG_TEST_VERSION:-16}:${TEST_EXTENSIONS_TAG:-${TAG:-latest}} + image: ${REPOSITORY:-ghcr.io/neondatabase}/neon-test-extensions-v${PG_TEST_VERSION:-${PG_VERSION:-16}}:${TEST_EXTENSIONS_TAG:-${TAG:-latest}} environment: - - PGPASSWORD=cloud_admin + - PGUSER=${PGUSER:-cloud_admin} + - PGPASSWORD=${PGPASSWORD:-cloud_admin} entrypoint: - "/bin/bash" - "-c" command: - - sleep 1800 + - sleep 3600 depends_on: - compute diff --git a/docker-compose/docker_compose_test.sh b/docker-compose/docker_compose_test.sh index 2645a49883..6edf90ca8d 100755 --- a/docker-compose/docker_compose_test.sh +++ b/docker-compose/docker_compose_test.sh @@ -54,6 +54,15 @@ for pg_version in ${TEST_VERSION_ONLY-14 15 16 17}; do # It cannot be moved to Dockerfile now because the database directory is created after the start of the container echo Adding dummy config docker compose exec compute touch /var/db/postgres/compute/compute_ctl_temp_override.conf + # Prepare for the PostGIS test + docker compose exec compute mkdir -p /tmp/pgis_reg/pgis_reg_tmp + TMPDIR=$(mktemp -d) + docker compose cp neon-test-extensions:/ext-src/postgis-src/raster/test "${TMPDIR}" + docker compose cp neon-test-extensions:/ext-src/postgis-src/regress/00-regress-install "${TMPDIR}" + docker compose exec compute mkdir -p /ext-src/postgis-src/raster /ext-src/postgis-src/regress /ext-src/postgis-src/regress/00-regress-install + docker compose cp "${TMPDIR}/test" compute:/ext-src/postgis-src/raster/test + docker compose cp "${TMPDIR}/00-regress-install" compute:/ext-src/postgis-src/regress + rm -rf "${TMPDIR}" # The following block copies the files for the pg_hintplan test to the compute node for the extension test in an isolated docker-compose environment TMPDIR=$(mktemp -d) docker compose cp neon-test-extensions:/ext-src/pg_hint_plan-src/data "${TMPDIR}/data" @@ -68,7 +77,7 @@ for pg_version in ${TEST_VERSION_ONLY-14 15 16 17}; do docker compose exec -T neon-test-extensions bash -c "(cd /postgres && patch -p1)" <"../compute/patches/contrib_pg${pg_version}.patch" # We are running tests now rm -f testout.txt testout_contrib.txt - docker compose exec -e USE_PGXS=1 -e SKIP=timescaledb-src,rdkit-src,postgis-src,pg_jsonschema-src,kq_imcx-src,wal2json_2_5-src,rag_jina_reranker_v1_tiny_en-src,rag_bge_small_en_v15-src \ + docker compose exec -e USE_PGXS=1 -e SKIP=timescaledb-src,rdkit-src,pg_jsonschema-src,kq_imcx-src,wal2json_2_5-src,rag_jina_reranker_v1_tiny_en-src,rag_bge_small_en_v15-src \ neon-test-extensions /run-tests.sh /ext-src | tee testout.txt && EXT_SUCCESS=1 || EXT_SUCCESS=0 docker compose exec -e SKIP=start-scripts,postgres_fdw,ltree_plpython,jsonb_plpython,jsonb_plperl,hstore_plpython,hstore_plperl,dblink,bool_plperl \ neon-test-extensions /run-tests.sh /postgres/contrib | tee testout_contrib.txt && CONTRIB_SUCCESS=1 || CONTRIB_SUCCESS=0 diff --git a/docker-compose/ext-src/postgis-src/README-Neon.md b/docker-compose/ext-src/postgis-src/README-Neon.md new file mode 100644 index 0000000000..5937fc782b --- /dev/null +++ b/docker-compose/ext-src/postgis-src/README-Neon.md @@ -0,0 +1,70 @@ +# PostGIS Testing in Neon + +This directory contains configuration files and patches for running PostGIS tests in the Neon database environment. + +## Overview + +PostGIS is a spatial database extension for PostgreSQL that adds support for geographic objects. Testing PostGIS compatibility ensures that Neon's modifications to PostgreSQL don't break compatibility with this critical extension. + +## PostGIS Versions + +- PostgreSQL v17: PostGIS 3.5.0 +- PostgreSQL v14/v15/v16: PostGIS 3.3.3 + +## Test Configuration + +The test setup includes: + +- `postgis-no-upgrade-test.patch`: Disables upgrade tests by removing the upgrade test section from regress/runtest.mk +- `postgis-regular-v16.patch`: Version-specific patch for PostgreSQL v16 +- `postgis-regular-v17.patch`: Version-specific patch for PostgreSQL v17 +- `regular-test.sh`: Script to run PostGIS tests as a regular user +- `neon-test.sh`: Script to handle version-specific test configurations +- `raster_outdb_template.sql`: Template for raster tests with explicit file paths + +## Excluded Tests + +**Important Note:** The test exclusions listed below are specifically for regular-user tests against staging instances. These exclusions are necessary because staging instances run with limited privileges and cannot perform operations requiring superuser access. Docker-compose based tests are not affected by these exclusions. + +### Tests Requiring Superuser Permissions + +These tests cannot be run as a regular user: +- `estimatedextent` +- `regress/core/legacy` +- `regress/core/typmod` +- `regress/loader/TestSkipANALYZE` +- `regress/loader/TestANALYZE` + +### Tests Requiring Filesystem Access + +These tests need direct filesystem access that is only possible for superusers: +- `loader/load_outdb` + +### Tests with Flaky Results + +These tests have assumptions that don't always hold true: +- `regress/core/computed_columns` - Assumes computed columns always outperform alternatives, which is not consistently true + +### Tests Requiring Tunable Parameter Modifications + +These tests attempt to modify the `postgis.gdal_enabled_drivers` parameter, which is only accessible to superusers: +- `raster/test/regress/rt_wkb` +- `raster/test/regress/rt_addband` +- `raster/test/regress/rt_setbandpath` +- `raster/test/regress/rt_fromgdalraster` +- `raster/test/regress/rt_asgdalraster` +- `raster/test/regress/rt_astiff` +- `raster/test/regress/rt_asjpeg` +- `raster/test/regress/rt_aspng` +- `raster/test/regress/permitted_gdal_drivers` +- Loader tests: `BasicOutDB`, `Tiled10x10`, `Tiled10x10Copy`, `Tiled8x8`, `TiledAuto`, `TiledAutoSkipNoData`, `TiledAutoCopyn` + +### Topology Tests (v17 only) +- `populate_topology_layer` +- `renametopogeometrycolumn` + +## Other Modifications + +- Binary.sql tests are modified to use explicit file paths +- Server-side SQL COPY commands (which require superuser privileges) are converted to client-side `\copy` commands +- Upgrade tests are disabled diff --git a/docker-compose/ext-src/postgis-src/neon-test.sh b/docker-compose/ext-src/postgis-src/neon-test.sh new file mode 100755 index 0000000000..2866649a1b --- /dev/null +++ b/docker-compose/ext-src/postgis-src/neon-test.sh @@ -0,0 +1,9 @@ +#!/bin/bash +set -ex +cd "$(dirname "$0")" +if [[ ${PG_VERSION} = v17 ]]; then + sed -i '/computed_columns/d' regress/core/tests.mk +fi +patch -p1 =" 120),1) +- TESTS += \ +- $(top_srcdir)/regress/core/computed_columns +-endif +- + ifeq ($(shell expr "$(POSTGIS_GEOS_VERSION)" ">=" 30700),1) + # GEOS-3.7 adds: + # ST_FrechetDistance +diff --git a/regress/loader/tests.mk b/regress/loader/tests.mk +index 1fc77ac..c3cb9de 100644 +--- a/regress/loader/tests.mk ++++ b/regress/loader/tests.mk +@@ -38,7 +38,5 @@ TESTS += \ + $(top_srcdir)/regress/loader/Latin1 \ + $(top_srcdir)/regress/loader/Latin1-implicit \ + $(top_srcdir)/regress/loader/mfile \ +- $(top_srcdir)/regress/loader/TestSkipANALYZE \ +- $(top_srcdir)/regress/loader/TestANALYZE \ + $(top_srcdir)/regress/loader/CharNoWidth + +diff --git a/regress/run_test.pl b/regress/run_test.pl +index 0ec5b2d..1c331f4 100755 +--- a/regress/run_test.pl ++++ b/regress/run_test.pl +@@ -147,7 +147,6 @@ $ENV{"LANG"} = "C"; + # Add locale info to the psql options + # Add pg12 precision suppression + my $PGOPTIONS = $ENV{"PGOPTIONS"}; +-$PGOPTIONS .= " -c lc_messages=C"; + $PGOPTIONS .= " -c client_min_messages=NOTICE"; + $PGOPTIONS .= " -c extra_float_digits=0"; + $ENV{"PGOPTIONS"} = $PGOPTIONS; diff --git a/docker-compose/ext-src/postgis-src/postgis-regular-v17.patch b/docker-compose/ext-src/postgis-src/postgis-regular-v17.patch new file mode 100644 index 0000000000..f4a9d83478 --- /dev/null +++ b/docker-compose/ext-src/postgis-src/postgis-regular-v17.patch @@ -0,0 +1,218 @@ +diff --git a/raster/test/regress/tests.mk b/raster/test/regress/tests.mk +index 00918e1..7e2b6cd 100644 +--- a/raster/test/regress/tests.mk ++++ b/raster/test/regress/tests.mk +@@ -17,9 +17,7 @@ override RUNTESTFLAGS_INTERNAL := \ + $(RUNTESTFLAGS_INTERNAL) \ + --after-upgrade-script $(top_srcdir)/raster/test/regress/hooks/hook-after-upgrade-raster.sql + +-RASTER_TEST_FIRST = \ +- $(top_srcdir)/raster/test/regress/check_gdal \ +- $(top_srcdir)/raster/test/regress/loader/load_outdb ++RASTER_TEST_FIRST = + + RASTER_TEST_LAST = \ + $(top_srcdir)/raster/test/regress/clean +@@ -33,9 +31,7 @@ RASTER_TEST_IO = \ + + RASTER_TEST_BASIC_FUNC = \ + $(top_srcdir)/raster/test/regress/rt_bytea \ +- $(top_srcdir)/raster/test/regress/rt_wkb \ + $(top_srcdir)/raster/test/regress/box3d \ +- $(top_srcdir)/raster/test/regress/rt_addband \ + $(top_srcdir)/raster/test/regress/rt_band \ + $(top_srcdir)/raster/test/regress/rt_tile + +@@ -73,16 +69,10 @@ RASTER_TEST_BANDPROPS = \ + $(top_srcdir)/raster/test/regress/rt_neighborhood \ + $(top_srcdir)/raster/test/regress/rt_nearestvalue \ + $(top_srcdir)/raster/test/regress/rt_pixelofvalue \ +- $(top_srcdir)/raster/test/regress/rt_polygon \ +- $(top_srcdir)/raster/test/regress/rt_setbandpath ++ $(top_srcdir)/raster/test/regress/rt_polygon + + RASTER_TEST_UTILITY = \ + $(top_srcdir)/raster/test/regress/rt_utility \ +- $(top_srcdir)/raster/test/regress/rt_fromgdalraster \ +- $(top_srcdir)/raster/test/regress/rt_asgdalraster \ +- $(top_srcdir)/raster/test/regress/rt_astiff \ +- $(top_srcdir)/raster/test/regress/rt_asjpeg \ +- $(top_srcdir)/raster/test/regress/rt_aspng \ + $(top_srcdir)/raster/test/regress/rt_reclass \ + $(top_srcdir)/raster/test/regress/rt_gdalwarp \ + $(top_srcdir)/raster/test/regress/rt_gdalcontour \ +@@ -120,21 +110,13 @@ RASTER_TEST_SREL = \ + + RASTER_TEST_BUGS = \ + $(top_srcdir)/raster/test/regress/bug_test_car5 \ +- $(top_srcdir)/raster/test/regress/permitted_gdal_drivers \ + $(top_srcdir)/raster/test/regress/tickets + + RASTER_TEST_LOADER = \ + $(top_srcdir)/raster/test/regress/loader/Basic \ + $(top_srcdir)/raster/test/regress/loader/Projected \ + $(top_srcdir)/raster/test/regress/loader/BasicCopy \ +- $(top_srcdir)/raster/test/regress/loader/BasicFilename \ +- $(top_srcdir)/raster/test/regress/loader/BasicOutDB \ +- $(top_srcdir)/raster/test/regress/loader/Tiled10x10 \ +- $(top_srcdir)/raster/test/regress/loader/Tiled10x10Copy \ +- $(top_srcdir)/raster/test/regress/loader/Tiled8x8 \ +- $(top_srcdir)/raster/test/regress/loader/TiledAuto \ +- $(top_srcdir)/raster/test/regress/loader/TiledAutoSkipNoData \ +- $(top_srcdir)/raster/test/regress/loader/TiledAutoCopyn ++ $(top_srcdir)/raster/test/regress/loader/BasicFilename + + RASTER_TESTS := $(RASTER_TEST_FIRST) \ + $(RASTER_TEST_METADATA) $(RASTER_TEST_IO) $(RASTER_TEST_BASIC_FUNC) \ +diff --git a/regress/core/binary.sql b/regress/core/binary.sql +index 7a36b65..ad78fc7 100644 +--- a/regress/core/binary.sql ++++ b/regress/core/binary.sql +@@ -1,4 +1,5 @@ + SET client_min_messages TO warning; ++ + CREATE SCHEMA tm; + + CREATE TABLE tm.geoms (id serial, g geometry); +@@ -31,24 +32,39 @@ SELECT st_force4d(g) FROM tm.geoms WHERE id < 15 ORDER BY id; + INSERT INTO tm.geoms(g) + SELECT st_setsrid(g,4326) FROM tm.geoms ORDER BY id; + +-COPY tm.geoms TO :tmpfile WITH BINARY; ++-- define temp file path ++\set tmpfile '/tmp/postgis_binary_test.dat' ++ ++-- export ++\set command '\\copy tm.geoms TO ':tmpfile' WITH (FORMAT BINARY)' ++:command ++ ++-- import + CREATE TABLE tm.geoms_in AS SELECT * FROM tm.geoms LIMIT 0; +-COPY tm.geoms_in FROM :tmpfile WITH BINARY; +-SELECT 'geometry', count(*) FROM tm.geoms_in i, tm.geoms o WHERE i.id = o.id +- AND ST_OrderingEquals(i.g, o.g); ++\set command '\\copy tm.geoms_in FROM ':tmpfile' WITH (FORMAT BINARY)' ++:command ++ ++SELECT 'geometry', count(*) FROM tm.geoms_in i, tm.geoms o ++WHERE i.id = o.id AND ST_OrderingEquals(i.g, o.g); + + CREATE TABLE tm.geogs AS SELECT id,g::geography FROM tm.geoms + WHERE geometrytype(g) NOT LIKE '%CURVE%' + AND geometrytype(g) NOT LIKE '%CIRCULAR%' + AND geometrytype(g) NOT LIKE '%SURFACE%' + AND geometrytype(g) NOT LIKE 'TRIANGLE%' +- AND geometrytype(g) NOT LIKE 'TIN%' +-; ++ AND geometrytype(g) NOT LIKE 'TIN%'; + +-COPY tm.geogs TO :tmpfile WITH BINARY; ++-- export ++\set command '\\copy tm.geogs TO ':tmpfile' WITH (FORMAT BINARY)' ++:command ++ ++-- import + CREATE TABLE tm.geogs_in AS SELECT * FROM tm.geogs LIMIT 0; +-COPY tm.geogs_in FROM :tmpfile WITH BINARY; +-SELECT 'geometry', count(*) FROM tm.geogs_in i, tm.geogs o WHERE i.id = o.id +- AND ST_OrderingEquals(i.g::geometry, o.g::geometry); ++\set command '\\copy tm.geogs_in FROM ':tmpfile' WITH (FORMAT BINARY)' ++:command ++ ++SELECT 'geometry', count(*) FROM tm.geogs_in i, tm.geogs o ++WHERE i.id = o.id AND ST_OrderingEquals(i.g::geometry, o.g::geometry); + + DROP SCHEMA tm CASCADE; ++ +diff --git a/regress/core/tests.mk b/regress/core/tests.mk +index 9e05244..a63a3e1 100644 +--- a/regress/core/tests.mk ++++ b/regress/core/tests.mk +@@ -16,14 +16,13 @@ POSTGIS_PGSQL_VERSION=170 + POSTGIS_GEOS_VERSION=31101 + HAVE_JSON=yes + HAVE_SPGIST=yes +-INTERRUPTTESTS=yes ++INTERRUPTTESTS=no + + current_dir := $(dir $(abspath $(lastword $(MAKEFILE_LIST)))) + + RUNTESTFLAGS_INTERNAL += \ + --before-upgrade-script $(top_srcdir)/regress/hooks/hook-before-upgrade.sql \ + --after-upgrade-script $(top_srcdir)/regress/hooks/hook-after-upgrade.sql \ +- --after-create-script $(top_srcdir)/regress/hooks/hook-after-create.sql \ + --before-uninstall-script $(top_srcdir)/regress/hooks/hook-before-uninstall.sql + + TESTS += \ +@@ -40,7 +39,6 @@ TESTS += \ + $(top_srcdir)/regress/core/dumppoints \ + $(top_srcdir)/regress/core/dumpsegments \ + $(top_srcdir)/regress/core/empty \ +- $(top_srcdir)/regress/core/estimatedextent \ + $(top_srcdir)/regress/core/forcecurve \ + $(top_srcdir)/regress/core/flatgeobuf \ + $(top_srcdir)/regress/core/frechet \ +@@ -60,7 +58,6 @@ TESTS += \ + $(top_srcdir)/regress/core/out_marc21 \ + $(top_srcdir)/regress/core/in_encodedpolyline \ + $(top_srcdir)/regress/core/iscollection \ +- $(top_srcdir)/regress/core/legacy \ + $(top_srcdir)/regress/core/letters \ + $(top_srcdir)/regress/core/lwgeom_regress \ + $(top_srcdir)/regress/core/measures \ +@@ -119,7 +116,6 @@ TESTS += \ + $(top_srcdir)/regress/core/temporal_knn \ + $(top_srcdir)/regress/core/tickets \ + $(top_srcdir)/regress/core/twkb \ +- $(top_srcdir)/regress/core/typmod \ + $(top_srcdir)/regress/core/wkb \ + $(top_srcdir)/regress/core/wkt \ + $(top_srcdir)/regress/core/wmsservers \ +@@ -143,8 +139,7 @@ TESTS += \ + $(top_srcdir)/regress/core/oriented_envelope \ + $(top_srcdir)/regress/core/point_coordinates \ + $(top_srcdir)/regress/core/out_geojson \ +- $(top_srcdir)/regress/core/wrapx \ +- $(top_srcdir)/regress/core/computed_columns ++ $(top_srcdir)/regress/core/wrapx + + # Slow slow tests + TESTS_SLOW = \ +diff --git a/regress/loader/tests.mk b/regress/loader/tests.mk +index ac4f8ad..4bad4fc 100644 +--- a/regress/loader/tests.mk ++++ b/regress/loader/tests.mk +@@ -38,7 +38,5 @@ TESTS += \ + $(top_srcdir)/regress/loader/Latin1 \ + $(top_srcdir)/regress/loader/Latin1-implicit \ + $(top_srcdir)/regress/loader/mfile \ +- $(top_srcdir)/regress/loader/TestSkipANALYZE \ +- $(top_srcdir)/regress/loader/TestANALYZE \ + $(top_srcdir)/regress/loader/CharNoWidth \ + +diff --git a/regress/run_test.pl b/regress/run_test.pl +index cac4b2e..4c7c82b 100755 +--- a/regress/run_test.pl ++++ b/regress/run_test.pl +@@ -238,7 +238,6 @@ $ENV{"LANG"} = "C"; + # Add locale info to the psql options + # Add pg12 precision suppression + my $PGOPTIONS = $ENV{"PGOPTIONS"}; +-$PGOPTIONS .= " -c lc_messages=C"; + $PGOPTIONS .= " -c client_min_messages=NOTICE"; + $PGOPTIONS .= " -c extra_float_digits=0"; + $ENV{"PGOPTIONS"} = $PGOPTIONS; +diff --git a/topology/test/tests.mk b/topology/test/tests.mk +index cbe2633..2c7c18f 100644 +--- a/topology/test/tests.mk ++++ b/topology/test/tests.mk +@@ -46,9 +46,7 @@ TESTS += \ + $(top_srcdir)/topology/test/regress/legacy_query.sql \ + $(top_srcdir)/topology/test/regress/legacy_validate.sql \ + $(top_srcdir)/topology/test/regress/polygonize.sql \ +- $(top_srcdir)/topology/test/regress/populate_topology_layer.sql \ + $(top_srcdir)/topology/test/regress/removeunusedprimitives.sql \ +- $(top_srcdir)/topology/test/regress/renametopogeometrycolumn.sql \ + $(top_srcdir)/topology/test/regress/renametopology.sql \ + $(top_srcdir)/topology/test/regress/share_sequences.sql \ + $(top_srcdir)/topology/test/regress/sqlmm.sql \ diff --git a/docker-compose/ext-src/postgis-src/raster_outdb_template.sql b/docker-compose/ext-src/postgis-src/raster_outdb_template.sql new file mode 100644 index 0000000000..16232f28dd --- /dev/null +++ b/docker-compose/ext-src/postgis-src/raster_outdb_template.sql @@ -0,0 +1,46 @@ +-- +-- PostgreSQL database dump +-- + +-- Dumped from database version 17.4 +-- Dumped by pg_dump version 17.4 + +SET statement_timeout = 0; +SET lock_timeout = 0; +SET idle_in_transaction_session_timeout = 0; +SET transaction_timeout = 0; +SET client_encoding = 'UTF8'; +SET standard_conforming_strings = on; +SELECT pg_catalog.set_config('search_path', '', false); +SET check_function_bodies = false; +SET xmloption = content; +SET client_min_messages = warning; + +-- +-- Name: raster_outdb_template; Type: TABLE; Schema: public; Owner: cloud_admin +-- + +CREATE TABLE public.raster_outdb_template ( + rid integer, + rast public.raster +); + + +ALTER TABLE public.raster_outdb_template OWNER TO cloud_admin; + +-- +-- Data for Name: raster_outdb_template; Type: TABLE DATA; Schema: public; Owner: cloud_admin +-- + +COPY public.raster_outdb_template (rid, rast) FROM stdin; +1 0100000300000000000000F03F000000000000F0BF0000000000000000000000000000000000000000000000000000000000000000000000005A0032008400002F6578742D7372632F706F73746769732D7372632F726567726573732F2E2E2F7261737465722F746573742F726567726573732F6C6F616465722F746573747261737465722E746966008400012F6578742D7372632F706F73746769732D7372632F726567726573732F2E2E2F7261737465722F746573742F726567726573732F6C6F616465722F746573747261737465722E746966008400022F6578742D7372632F706F73746769732D7372632F726567726573732F2E2E2F7261737465722F746573742F726567726573732F6C6F616465722F746573747261737465722E74696600 +2 0100000300000000000000F03F000000000000F0BF0000000000000000000000000000000000000000000000000000000000000000000000005A0032008400002F6578742D7372632F706F73746769732D7372632F726567726573732F2E2E2F7261737465722F746573742F726567726573732F6C6F616465722F746573747261737465722E746966008400012F6578742D7372632F706F73746769732D7372632F726567726573732F2E2E2F7261737465722F746573742F726567726573732F6C6F616465722F746573747261737465722E746966008400022F6578742D7372632F706F73746769732D7372632F726567726573732F2E2E2F7261737465722F746573742F726567726573732F6C6F616465722F746573747261737465722E74696600 +3 0100000200000000000000F03F000000000000F0BF0000000000000000000000000000000000000000000000000000000000000000000000005A00320044000101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101018400012F6578742D7372632F706F73746769732D7372632F726567726573732F2E2E2F7261737465722F746573742F726567726573732F6C6F616465722F746573747261737465722E74696600 +4 0100000200000000000000F03F000000000000F0BF0000000000000000000000000000000000000000000000000000000000000000000000005A003200C4FF012F6578742D7372632F706F73746769732D7372632F726567726573732F2E2E2F7261737465722F746573742F726567726573732F6C6F616465722F746573747261737465722E746966004400010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101 +\. + + +-- +-- PostgreSQL database dump complete +-- + diff --git a/docker-compose/ext-src/postgis-src/regular-test.sh b/docker-compose/ext-src/postgis-src/regular-test.sh new file mode 100755 index 0000000000..4b0b929946 --- /dev/null +++ b/docker-compose/ext-src/postgis-src/regular-test.sh @@ -0,0 +1,17 @@ +#!/bin/bash +set -ex +cd "$(dirname "${0}")" +dropdb --if-exist contrib_regression +createdb contrib_regression +psql -d contrib_regression -c "ALTER DATABASE contrib_regression SET TimeZone='UTC'" \ + -c "ALTER DATABASE contrib_regression SET DateStyle='ISO, MDY'" \ + -c "CREATE EXTENSION postgis SCHEMA public" \ + -c "CREATE EXTENSION postgis_topology" \ + -c "CREATE EXTENSION postgis_tiger_geocoder CASCADE" \ + -c "CREATE EXTENSION postgis_raster SCHEMA public" \ + -c "CREATE EXTENSION postgis_sfcgal SCHEMA public" +patch -p1 Date: Wed, 4 Jun 2025 11:44:23 +0100 Subject: [PATCH 26/43] pageserver: make import job max byte range size configurable (#12117) ## Problem We want to repro an OOM situation, but large partial reads are required. ## Summary of Changes Make the max partial read size configurable for import jobs. --- libs/pageserver_api/src/config.rs | 3 ++ .../src/tenant/timeline/import_pgdata/flow.rs | 28 +++++++++++++------ test_runner/fixtures/neon_fixtures.py | 1 + 3 files changed, 23 insertions(+), 9 deletions(-) diff --git a/libs/pageserver_api/src/config.rs b/libs/pageserver_api/src/config.rs index 237833b9de..46903965b1 100644 --- a/libs/pageserver_api/src/config.rs +++ b/libs/pageserver_api/src/config.rs @@ -330,6 +330,8 @@ pub struct TimelineImportConfig { pub import_job_concurrency: NonZeroUsize, pub import_job_soft_size_limit: NonZeroUsize, pub import_job_checkpoint_threshold: NonZeroUsize, + /// Max size of the remote storage partial read done by any job + pub import_job_max_byte_range_size: NonZeroUsize, } #[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] @@ -735,6 +737,7 @@ impl Default for ConfigToml { import_job_concurrency: NonZeroUsize::new(32).unwrap(), import_job_soft_size_limit: NonZeroUsize::new(256 * 1024 * 1024).unwrap(), import_job_checkpoint_threshold: NonZeroUsize::new(32).unwrap(), + import_job_max_byte_range_size: NonZeroUsize::new(4 * 1024 * 1024).unwrap(), }, basebackup_cache_config: None, posthog_config: None, diff --git a/pageserver/src/tenant/timeline/import_pgdata/flow.rs b/pageserver/src/tenant/timeline/import_pgdata/flow.rs index 2ec9d86720..e003bb6810 100644 --- a/pageserver/src/tenant/timeline/import_pgdata/flow.rs +++ b/pageserver/src/tenant/timeline/import_pgdata/flow.rs @@ -100,6 +100,7 @@ async fn run_v1( .unwrap(), import_job_concurrency: base.import_job_concurrency, import_job_checkpoint_threshold: base.import_job_checkpoint_threshold, + import_job_max_byte_range_size: base.import_job_max_byte_range_size, } } None => timeline.conf.timeline_import_config.clone(), @@ -441,6 +442,7 @@ impl Plan { let mut last_completed_job_idx = start_after_job_idx.unwrap_or(0); let checkpoint_every: usize = import_config.import_job_checkpoint_threshold.into(); + let max_byte_range_size: usize = import_config.import_job_max_byte_range_size.into(); // Run import jobs concurrently up to the limit specified by the pageserver configuration. // Note that we process completed futures in the oreder of insertion. This will be the @@ -456,7 +458,7 @@ impl Plan { work.push_back(tokio::task::spawn(async move { let _permit = permit; - let res = job.run(job_timeline, &ctx).await; + let res = job.run(job_timeline, max_byte_range_size, &ctx).await; (job_idx, res) })); }, @@ -679,6 +681,7 @@ trait ImportTask { async fn doit( self, layer_writer: &mut ImageLayerWriter, + max_byte_range_size: usize, ctx: &RequestContext, ) -> anyhow::Result; } @@ -715,6 +718,7 @@ impl ImportTask for ImportSingleKeyTask { async fn doit( self, layer_writer: &mut ImageLayerWriter, + _max_byte_range_size: usize, ctx: &RequestContext, ) -> anyhow::Result { layer_writer.put_image(self.key, self.buf, ctx).await?; @@ -768,10 +772,9 @@ impl ImportTask for ImportRelBlocksTask { async fn doit( self, layer_writer: &mut ImageLayerWriter, + max_byte_range_size: usize, ctx: &RequestContext, ) -> anyhow::Result { - const MAX_BYTE_RANGE_SIZE: usize = 4 * 1024 * 1024; - debug!("Importing relation file"); let (rel_tag, start_blk) = self.key_range.start.to_rel_block()?; @@ -796,7 +799,7 @@ impl ImportTask for ImportRelBlocksTask { assert_eq!(key.len(), 1); assert!(!acc.is_empty()); assert!(acc_end > acc_start); - if acc_end == start && end - acc_start <= MAX_BYTE_RANGE_SIZE { + if acc_end == start && end - acc_start <= max_byte_range_size { acc.push(key.pop().unwrap()); Ok((acc, acc_start, end)) } else { @@ -860,6 +863,7 @@ impl ImportTask for ImportSlruBlocksTask { async fn doit( self, layer_writer: &mut ImageLayerWriter, + _max_byte_range_size: usize, ctx: &RequestContext, ) -> anyhow::Result { debug!("Importing SLRU segment file {}", self.path); @@ -906,12 +910,13 @@ impl ImportTask for AnyImportTask { async fn doit( self, layer_writer: &mut ImageLayerWriter, + max_byte_range_size: usize, ctx: &RequestContext, ) -> anyhow::Result { match self { - Self::SingleKey(t) => t.doit(layer_writer, ctx).await, - Self::RelBlocks(t) => t.doit(layer_writer, ctx).await, - Self::SlruBlocks(t) => t.doit(layer_writer, ctx).await, + Self::SingleKey(t) => t.doit(layer_writer, max_byte_range_size, ctx).await, + Self::RelBlocks(t) => t.doit(layer_writer, max_byte_range_size, ctx).await, + Self::SlruBlocks(t) => t.doit(layer_writer, max_byte_range_size, ctx).await, } } } @@ -952,7 +957,12 @@ impl ChunkProcessingJob { } } - async fn run(self, timeline: Arc, ctx: &RequestContext) -> anyhow::Result<()> { + async fn run( + self, + timeline: Arc, + max_byte_range_size: usize, + ctx: &RequestContext, + ) -> anyhow::Result<()> { let mut writer = ImageLayerWriter::new( timeline.conf, timeline.timeline_id, @@ -967,7 +977,7 @@ impl ChunkProcessingJob { let mut nimages = 0; for task in self.tasks { - nimages += task.doit(&mut writer, ctx).await?; + nimages += task.doit(&mut writer, max_byte_range_size, ctx).await?; } let resident_layer = if nimages > 0 { diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py index ab4885ce6b..19d12da5e3 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -423,6 +423,7 @@ class PageserverImportConfig: "import_job_concurrency": self.import_job_concurrency, "import_job_soft_size_limit": self.import_job_soft_size_limit, "import_job_checkpoint_threshold": self.import_job_checkpoint_threshold, + "import_job_max_byte_range_size": 4 * 1024 * 1024, # Pageserver default } return ("timeline_import_config", value) From e4ca3ac7458ca22dd97e89da859adee966b11580 Mon Sep 17 00:00:00 2001 From: a-masterov <72613290+a-masterov@users.noreply.github.com> Date: Wed, 4 Jun 2025 17:07:48 +0200 Subject: [PATCH 27/43] Fix codestyle for compute.sh for docker-compose (#12128) ## Problem The script `compute.sh` had a non-consistent coding style and didn't follow best practices for modern bash scripts ## Summary of changes The coding style was fixed to follow best practices. --- .../compute_wrapper/shell/compute.sh | 36 +++++++++---------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/docker-compose/compute_wrapper/shell/compute.sh b/docker-compose/compute_wrapper/shell/compute.sh index ab8d74d355..c8ca812bf9 100755 --- a/docker-compose/compute_wrapper/shell/compute.sh +++ b/docker-compose/compute_wrapper/shell/compute.sh @@ -1,18 +1,18 @@ -#!/bin/bash +#!/usr/bin/env bash set -eux # Generate a random tenant or timeline ID # # Takes a variable name as argument. The result is stored in that variable. generate_id() { - local -n resvar=$1 - printf -v resvar '%08x%08x%08x%08x' $SRANDOM $SRANDOM $SRANDOM $SRANDOM + local -n resvar=${1} + printf -v resvar '%08x%08x%08x%08x' ${SRANDOM} ${SRANDOM} ${SRANDOM} ${SRANDOM} } PG_VERSION=${PG_VERSION:-14} -CONFIG_FILE_ORG=/var/db/postgres/configs/config.json -CONFIG_FILE=/tmp/config.json +readonly CONFIG_FILE_ORG=/var/db/postgres/configs/config.json +readonly CONFIG_FILE=/tmp/config.json # Test that the first library path that the dynamic loader looks in is the path # that we use for custom compiled software @@ -20,17 +20,17 @@ first_path="$(ldconfig --verbose 2>/dev/null \ | grep --invert-match ^$'\t' \ | cut --delimiter=: --fields=1 \ | head --lines=1)" -test "$first_path" == '/usr/local/lib' +test "${first_path}" = '/usr/local/lib' echo "Waiting pageserver become ready." while ! nc -z pageserver 6400; do - sleep 1; + sleep 1 done echo "Page server is ready." -cp ${CONFIG_FILE_ORG} ${CONFIG_FILE} +cp "${CONFIG_FILE_ORG}" "${CONFIG_FILE}" - if [ -n "${TENANT_ID:-}" ] && [ -n "${TIMELINE_ID:-}" ]; then + if [[ -n "${TENANT_ID:-}" && -n "${TIMELINE_ID:-}" ]]; then tenant_id=${TENANT_ID} timeline_id=${TIMELINE_ID} else @@ -41,7 +41,7 @@ else "http://pageserver:9898/v1/tenant" ) tenant_id=$(curl "${PARAMS[@]}" | jq -r .[0].id) - if [ -z "${tenant_id}" ] || [ "${tenant_id}" = null ]; then + if [[ -z "${tenant_id}" || "${tenant_id}" = null ]]; then echo "Create a tenant" generate_id tenant_id PARAMS=( @@ -51,7 +51,7 @@ else "http://pageserver:9898/v1/tenant/${tenant_id}/location_config" ) result=$(curl "${PARAMS[@]}") - echo $result | jq . + printf '%s\n' "${result}" | jq . fi echo "Check if a timeline present" @@ -61,7 +61,7 @@ else "http://pageserver:9898/v1/tenant/${tenant_id}/timeline" ) timeline_id=$(curl "${PARAMS[@]}" | jq -r .[0].timeline_id) - if [ -z "${timeline_id}" ] || [ "${timeline_id}" = null ]; then + if [[ -z "${timeline_id}" || "${timeline_id}" = null ]]; then generate_id timeline_id PARAMS=( -sbf @@ -71,7 +71,7 @@ else "http://pageserver:9898/v1/tenant/${tenant_id}/timeline/" ) result=$(curl "${PARAMS[@]}") - echo $result | jq . + printf '%s\n' "${result}" | jq . fi fi @@ -82,10 +82,10 @@ else fi echo "Adding pgx_ulid" shared_libraries=$(jq -r '.spec.cluster.settings[] | select(.name=="shared_preload_libraries").value' ${CONFIG_FILE}) -sed -i "s/${shared_libraries}/${shared_libraries},${ulid_extension}/" ${CONFIG_FILE} +sed -i "s|${shared_libraries}|${shared_libraries},${ulid_extension}|" ${CONFIG_FILE} echo "Overwrite tenant id and timeline id in spec file" -sed -i "s/TENANT_ID/${tenant_id}/" ${CONFIG_FILE} -sed -i "s/TIMELINE_ID/${timeline_id}/" ${CONFIG_FILE} +sed -i "s|TENANT_ID|${tenant_id}|" ${CONFIG_FILE} +sed -i "s|TIMELINE_ID|${timeline_id}|" ${CONFIG_FILE} cat ${CONFIG_FILE} @@ -93,5 +93,5 @@ echo "Start compute node" /usr/local/bin/compute_ctl --pgdata /var/db/postgres/compute \ -C "postgresql://cloud_admin@localhost:55433/postgres" \ -b /usr/local/bin/postgres \ - --compute-id "compute-$RANDOM" \ - --config "$CONFIG_FILE" + --compute-id "compute-${RANDOM}" \ + --config "${CONFIG_FILE}" From e7d6f525b3fa4f3fb849b9d515f39b0d2e09bf7e Mon Sep 17 00:00:00 2001 From: Erik Grinaker Date: Wed, 4 Jun 2025 17:14:17 +0200 Subject: [PATCH 28/43] pageserver: support `get_vectored_concurrent_io` with gRPC (#12131) ## Problem The gRPC page service doesn't respect `get_vectored_concurrent_io` and always uses sequential IO. ## Summary of changes Spawn a sidecar task for concurrent IO when enabled. Cancellation will be addressed separately. --- pageserver/src/bin/pageserver.rs | 1 + pageserver/src/page_service.rs | 20 ++++++++++++++------ 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/pageserver/src/bin/pageserver.rs b/pageserver/src/bin/pageserver.rs index 337aa135dc..be7e634d4c 100644 --- a/pageserver/src/bin/pageserver.rs +++ b/pageserver/src/bin/pageserver.rs @@ -819,6 +819,7 @@ fn start_pageserver( tenant_manager.clone(), grpc_auth, otel_guard.as_ref().map(|g| g.dispatch.clone()), + conf.get_vectored_concurrent_io, grpc_listener, )?); } diff --git a/pageserver/src/page_service.rs b/pageserver/src/page_service.rs index 77e5f0a92b..4a1ddf09b5 100644 --- a/pageserver/src/page_service.rs +++ b/pageserver/src/page_service.rs @@ -178,6 +178,7 @@ pub fn spawn_grpc( tenant_manager: Arc, auth: Option>, perf_trace_dispatch: Option, + get_vectored_concurrent_io: GetVectoredConcurrentIo, listener: std::net::TcpListener, ) -> anyhow::Result { let cancel = CancellationToken::new(); @@ -214,6 +215,8 @@ pub fn spawn_grpc( let page_service_handler = GrpcPageServiceHandler { tenant_manager, ctx, + gate_guard: gate.enter().expect("gate was just created"), + get_vectored_concurrent_io, }; let observability_layer = ObservabilityLayer; @@ -497,10 +500,6 @@ async fn page_service_conn_main( } /// Page service connection handler. -/// -/// TODO: for gRPC, this will be shared by all requests from all connections. -/// Decompose it into global state and per-connection/request state, and make -/// libpq-specific options (e.g. pipelining) separate. struct PageServerHandler { auth: Option>, claims: Option, @@ -3362,6 +3361,8 @@ where pub struct GrpcPageServiceHandler { tenant_manager: Arc, ctx: RequestContext, + gate_guard: GateGuard, + get_vectored_concurrent_io: GetVectoredConcurrentIo, } impl GrpcPageServiceHandler { @@ -3721,6 +3722,14 @@ impl proto::PageService for GrpcPageServiceHandler { .get(ttid.tenant_id, ttid.timeline_id, shard_selector) .await?; + // Spawn an IoConcurrency sidecar, if enabled. + let Ok(gate_guard) = self.gate_guard.try_clone() else { + return Err(tonic::Status::unavailable("shutting down")); + }; + let io_concurrency = + IoConcurrency::spawn_from_conf(self.get_vectored_concurrent_io, gate_guard); + + // Spawn a task to handle the GetPageRequest stream. let span = Span::current(); let ctx = self.ctx.attached_child(); let mut reqs = req.into_inner(); @@ -3731,8 +3740,7 @@ impl proto::PageService for GrpcPageServiceHandler { .await? .downgrade(); while let Some(req) = reqs.message().await? { - // TODO: implement IoConcurrency sidecar. - yield Self::get_page(&ctx, &timeline, req, IoConcurrency::Sequential) + yield Self::get_page(&ctx, &timeline, req, io_concurrency.clone()) .instrument(span.clone()) // propagate request span .await? } From 3fd5a94a853604d26ff69acaf88c35f765e51698 Mon Sep 17 00:00:00 2001 From: Tristan Partin Date: Wed, 4 Jun 2025 10:56:12 -0500 Subject: [PATCH 29/43] Use Url::join() when creating the final remote extension URL (#12121) Url::to_string() adds a trailing slash on the base URL, so when we did the format!(), we were adding a double forward slash. Signed-off-by: Tristan Partin --- compute_tools/src/bin/compute_ctl.rs | 67 ++++++++++++++- compute_tools/src/extension_server.rs | 16 ++-- libs/compute_api/src/spec.rs | 85 ++++++++++++++----- .../regress/test_download_extensions.py | 7 +- 4 files changed, 140 insertions(+), 35 deletions(-) diff --git a/compute_tools/src/bin/compute_ctl.rs b/compute_tools/src/bin/compute_ctl.rs index db6835da61..8b502a058e 100644 --- a/compute_tools/src/bin/compute_ctl.rs +++ b/compute_tools/src/bin/compute_ctl.rs @@ -40,7 +40,7 @@ use std::sync::mpsc; use std::thread; use std::time::Duration; -use anyhow::{Context, Result}; +use anyhow::{Context, Result, bail}; use clap::Parser; use compute_api::responses::ComputeConfig; use compute_tools::compute::{ @@ -57,14 +57,14 @@ use tracing::{error, info}; use url::Url; use utils::failpoint_support; -#[derive(Parser)] +#[derive(Debug, Parser)] #[command(rename_all = "kebab-case")] struct Cli { #[arg(short = 'b', long, default_value = "postgres", env = "POSTGRES_PATH")] pub pgbin: String, /// The base URL for the remote extension storage proxy gateway. - #[arg(short = 'r', long)] + #[arg(short = 'r', long, value_parser = Self::parse_remote_ext_base_url)] pub remote_ext_base_url: Option, /// The port to bind the external listening HTTP server to. Clients running @@ -126,6 +126,25 @@ struct Cli { pub installed_extensions_collection_interval: u64, } +impl Cli { + /// Parse a URL from an argument. By default, this isn't necessary, but we + /// want to do some sanity checking. + fn parse_remote_ext_base_url(value: &str) -> Result { + // Remove extra trailing slashes, and add one. We use Url::join() later + // when downloading remote extensions. If the base URL is something like + // http://example.com/pg-ext-s3-gateway, and join() is called with + // something like "xyz", the resulting URL is http://example.com/xyz. + let value = value.trim_end_matches('/').to_owned() + "/"; + let url = Url::parse(&value)?; + + if url.query_pairs().count() != 0 { + bail!("parameters detected in remote extensions base URL") + } + + Ok(url) + } +} + fn main() -> Result<()> { let cli = Cli::parse(); @@ -252,7 +271,8 @@ fn handle_exit_signal(sig: i32) { #[cfg(test)] mod test { - use clap::CommandFactory; + use clap::{CommandFactory, Parser}; + use url::Url; use super::Cli; @@ -260,4 +280,43 @@ mod test { fn verify_cli() { Cli::command().debug_assert() } + + #[test] + fn verify_remote_ext_base_url() { + let cli = Cli::parse_from([ + "compute_ctl", + "--pgdata=test", + "--connstr=test", + "--compute-id=test", + "--remote-ext-base-url", + "https://example.com/subpath", + ]); + assert_eq!( + cli.remote_ext_base_url.unwrap(), + Url::parse("https://example.com/subpath/").unwrap() + ); + + let cli = Cli::parse_from([ + "compute_ctl", + "--pgdata=test", + "--connstr=test", + "--compute-id=test", + "--remote-ext-base-url", + "https://example.com//", + ]); + assert_eq!( + cli.remote_ext_base_url.unwrap(), + Url::parse("https://example.com").unwrap() + ); + + Cli::try_parse_from([ + "compute_ctl", + "--pgdata=test", + "--connstr=test", + "--compute-id=test", + "--remote-ext-base-url", + "https://example.com?hello=world", + ]) + .expect_err("URL parameters are not allowed"); + } } diff --git a/compute_tools/src/extension_server.rs b/compute_tools/src/extension_server.rs index 1857afa08c..3764bc1525 100644 --- a/compute_tools/src/extension_server.rs +++ b/compute_tools/src/extension_server.rs @@ -166,7 +166,7 @@ pub async fn download_extension( // TODO add retry logic let download_buffer = - match download_extension_tar(remote_ext_base_url.as_str(), &ext_path.to_string()).await { + match download_extension_tar(remote_ext_base_url, &ext_path.to_string()).await { Ok(buffer) => buffer, Err(error_message) => { return Err(anyhow::anyhow!( @@ -271,10 +271,14 @@ pub fn create_control_files(remote_extensions: &RemoteExtSpec, pgbin: &str) { } // Do request to extension storage proxy, e.g., -// curl http://pg-ext-s3-gateway/latest/v15/extensions/anon.tar.zst +// curl http://pg-ext-s3-gateway.pg-ext-s3-gateway.svc.cluster.local/latest/v15/extensions/anon.tar.zst // using HTTP GET and return the response body as bytes. -async fn download_extension_tar(remote_ext_base_url: &str, ext_path: &str) -> Result { - let uri = format!("{}/{}", remote_ext_base_url, ext_path); +async fn download_extension_tar(remote_ext_base_url: &Url, ext_path: &str) -> Result { + let uri = remote_ext_base_url.join(ext_path).with_context(|| { + format!( + "failed to create the remote extension URI for {ext_path} using {remote_ext_base_url}" + ) + })?; let filename = Path::new(ext_path) .file_name() .unwrap_or_else(|| std::ffi::OsStr::new("unknown")) @@ -284,7 +288,7 @@ async fn download_extension_tar(remote_ext_base_url: &str, ext_path: &str) -> Re info!("Downloading extension file '{}' from uri {}", filename, uri); - match do_extension_server_request(&uri).await { + match do_extension_server_request(uri).await { Ok(resp) => { info!("Successfully downloaded remote extension data {}", ext_path); REMOTE_EXT_REQUESTS_TOTAL @@ -303,7 +307,7 @@ async fn download_extension_tar(remote_ext_base_url: &str, ext_path: &str) -> Re // Do a single remote extensions server request. // Return result or (error message + stringified status code) in case of any failures. -async fn do_extension_server_request(uri: &str) -> Result { +async fn do_extension_server_request(uri: Url) -> Result { let resp = reqwest::get(uri).await.map_err(|e| { ( format!( diff --git a/libs/compute_api/src/spec.rs b/libs/compute_api/src/spec.rs index 9b7cf43bb9..343923d446 100644 --- a/libs/compute_api/src/spec.rs +++ b/libs/compute_api/src/spec.rs @@ -250,34 +250,44 @@ impl RemoteExtSpec { } match self.extension_data.get(real_ext_name) { - Some(_ext_data) => { - // We have decided to use the Go naming convention due to Kubernetes. - - let arch = match std::env::consts::ARCH { - "x86_64" => "amd64", - "aarch64" => "arm64", - arch => arch, - }; - - // Construct the path to the extension archive - // BUILD_TAG/PG_MAJOR_VERSION/extensions/EXTENSION_NAME.tar.zst - // - // Keep it in sync with path generation in - // https://github.com/neondatabase/build-custom-extensions/tree/main - let archive_path_str = format!( - "{build_tag}/{arch}/{pg_major_version}/extensions/{real_ext_name}.tar.zst" - ); - Ok(( - real_ext_name.to_string(), - RemotePath::from_string(&archive_path_str)?, - )) - } + Some(_ext_data) => Ok(( + real_ext_name.to_string(), + Self::build_remote_path(build_tag, pg_major_version, real_ext_name)?, + )), None => Err(anyhow::anyhow!( "real_ext_name {} is not found", real_ext_name )), } } + + /// Get the architecture-specific portion of the remote extension path. We + /// use the Go naming convention due to Kubernetes. + fn get_arch() -> &'static str { + match std::env::consts::ARCH { + "x86_64" => "amd64", + "aarch64" => "arm64", + arch => arch, + } + } + + /// Build a [`RemotePath`] for an extension. + fn build_remote_path( + build_tag: &str, + pg_major_version: &str, + ext_name: &str, + ) -> anyhow::Result { + let arch = Self::get_arch(); + + // Construct the path to the extension archive + // BUILD_TAG/PG_MAJOR_VERSION/extensions/EXTENSION_NAME.tar.zst + // + // Keep it in sync with path generation in + // https://github.com/neondatabase/build-custom-extensions/tree/main + RemotePath::from_string(&format!( + "{build_tag}/{arch}/{pg_major_version}/extensions/{ext_name}.tar.zst" + )) + } } #[derive(Clone, Copy, Debug, Default, Eq, PartialEq, Deserialize, Serialize)] @@ -518,6 +528,37 @@ mod tests { .expect("Library should be found"); } + #[test] + fn remote_extension_path() { + let rspec: RemoteExtSpec = serde_json::from_value(serde_json::json!({ + "public_extensions": ["ext"], + "custom_extensions": [], + "library_index": { + "extlib": "ext", + }, + "extension_data": { + "ext": { + "control_data": { + "ext.control": "" + }, + "archive_path": "" + } + }, + })) + .unwrap(); + + let (_ext_name, ext_path) = rspec + .get_ext("ext", false, "latest", "v17") + .expect("Extension should be found"); + // Starting with a forward slash would have consequences for the + // Url::join() that occurs when downloading a remote extension. + assert!(!ext_path.to_string().starts_with("/")); + assert_eq!( + ext_path, + RemoteExtSpec::build_remote_path("latest", "v17", "ext").unwrap() + ); + } + #[test] fn parse_spec_file() { let file = File::open("tests/cluster_spec.json").unwrap(); diff --git a/test_runner/regress/test_download_extensions.py b/test_runner/regress/test_download_extensions.py index 24ba0713d2..fe3b220c67 100644 --- a/test_runner/regress/test_download_extensions.py +++ b/test_runner/regress/test_download_extensions.py @@ -159,7 +159,8 @@ def test_remote_extensions( # Setup a mock nginx S3 gateway which will return our test extension. (host, port) = httpserver_listen_address - extensions_endpoint = f"http://{host}:{port}/pg-ext-s3-gateway" + remote_ext_base_url = f"http://{host}:{port}/pg-ext-s3-gateway" + log.info(f"remote extensions base URL: {remote_ext_base_url}") extension.build(pg_config, test_output_dir) tarball = extension.package(test_output_dir) @@ -221,7 +222,7 @@ def test_remote_extensions( endpoint.create_remote_extension_spec(spec) - endpoint.start(remote_ext_base_url=extensions_endpoint) + endpoint.start(remote_ext_base_url=remote_ext_base_url) with endpoint.connect() as conn: with conn.cursor() as cur: @@ -249,7 +250,7 @@ def test_remote_extensions( # Remove the extension files to force a redownload of the extension. extension.remove(test_output_dir, pg_version) - endpoint.start(remote_ext_base_url=extensions_endpoint) + endpoint.start(remote_ext_base_url=remote_ext_base_url) # Test that ALTER EXTENSION UPDATE statements also fetch remote extensions. with endpoint.connect() as conn: From 838622c5944d22dbb645a93fec915a465e9242f8 Mon Sep 17 00:00:00 2001 From: Suhas Thalanki <54014218+thesuhas@users.noreply.github.com> Date: Wed, 4 Jun 2025 14:03:59 -0400 Subject: [PATCH 30/43] compute: Add manifest.yml for default Postgres configuration settings (#11820) Adds a `manifest.yml` file that contains the default settings for compute. Currently, it comes from cplane code [here](https://github.com/neondatabase/cloud/blob/0cda3d4b014765c4e2e551a1d56efc095b004e77/goapp/controlplane/internal/pkg/compute/computespec/pg_settings.go#L110). Related RFC: https://github.com/neondatabase/neon/blob/main/docs/rfcs/038-independent-compute-release.md Related Issue: https://github.com/neondatabase/cloud/issues/11698 --- compute/manifest.yaml | 121 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 121 insertions(+) create mode 100644 compute/manifest.yaml diff --git a/compute/manifest.yaml b/compute/manifest.yaml new file mode 100644 index 0000000000..f1cd20c497 --- /dev/null +++ b/compute/manifest.yaml @@ -0,0 +1,121 @@ +pg_settings: + # Common settings for primaries and replicas of all versions. + common: + # Check for client disconnection every 1 minute. By default, Postgres will detect the + # loss of the connection only at the next interaction with the socket, when it waits + # for, receives or sends data, so it will likely waste resources till the end of the + # query execution. There should be no drawbacks in setting this for everyone, so enable + # it by default. If anyone will complain, we can allow editing it. + # https://www.postgresql.org/docs/16/runtime-config-connection.html#GUC-CLIENT-CONNECTION-CHECK-INTERVAL + client_connection_check_interval: "60000" # 1 minute + # ---- IO ---- + effective_io_concurrency: "20" + maintenance_io_concurrency: "100" + fsync: "off" + hot_standby: "off" + # We allow users to change this if needed, but by default we + # just don't want to see long-lasting idle transactions, as they + # prevent activity monitor from suspending projects. + idle_in_transaction_session_timeout: "300000" # 5 minutes + listen_addresses: "*" + # --- LOGGING ---- helps investigations + log_connections: "on" + log_disconnections: "on" + # 1GB, unit is KB + log_temp_files: "1048576" + # Disable dumping customer data to logs, both to increase data privacy + # and to reduce the amount the logs. + log_error_verbosity: "terse" + log_min_error_statement: "panic" + max_connections: "100" + # --- WAL --- + # - flush lag is the max amount of WAL that has been generated but not yet stored + # to disk in the page server. A smaller value means less delay after a pageserver + # restart, but if you set it too small you might again need to slow down writes if the + # pageserver cannot flush incoming WAL to disk fast enough. This must be larger + # than the pageserver's checkpoint interval, currently 1 GB! Otherwise you get a + # a deadlock where the compute node refuses to generate more WAL before the + # old WAL has been uploaded to S3, but the pageserver is waiting for more WAL + # to be generated before it is uploaded to S3. + max_replication_flush_lag: "10GB" + max_replication_slots: "10" + # Backpressure configuration: + # - write lag is the max amount of WAL that has been generated by Postgres but not yet + # processed by the page server. Making this smaller reduces the worst case latency + # of a GetPage request, if you request a page that was recently modified. On the other + # hand, if this is too small, the compute node might need to wait on a write if there is a + # hiccup in the network or page server so that the page server has temporarily fallen + # behind. + # + # Previously it was set to 500 MB, but it caused compute being unresponsive under load + # https://github.com/neondatabase/neon/issues/2028 + max_replication_write_lag: "500MB" + max_wal_senders: "10" + # A Postgres checkpoint is cheap in storage, as doesn't involve any significant amount + # of real I/O. Only the SLRU buffers and some other small files are flushed to disk. + # However, as long as we have full_page_writes=on, page updates after a checkpoint + # include full-page images which bloats the WAL. So may want to bump max_wal_size to + # reduce the WAL bloating, but at the same it will increase pg_wal directory size on + # compute and can lead to out of disk error on k8s nodes. + max_wal_size: "1024" + wal_keep_size: "0" + wal_level: "replica" + # Reduce amount of WAL generated by default. + wal_log_hints: "off" + # - without wal_sender_timeout set we don't get feedback messages, + # required for backpressure. + wal_sender_timeout: "10000" + # We have some experimental extensions, which we don't want users to install unconsciously. + # To install them, users would need to set the `neon.allow_unstable_extensions` setting. + # There are two of them currently: + # - `pgrag` - https://github.com/neondatabase-labs/pgrag - extension is actually called just `rag`, + # and two dependencies: + # - `rag_bge_small_en_v15` + # - `rag_jina_reranker_v1_tiny_en` + # - `pg_mooncake` - https://github.com/Mooncake-Labs/pg_mooncake/ + neon.unstable_extensions: "rag,rag_bge_small_en_v15,rag_jina_reranker_v1_tiny_en,pg_mooncake,anon" + neon.protocol_version: "3" + password_encryption: "scram-sha-256" + # This is important to prevent Postgres from trying to perform + # a local WAL redo after backend crash. It should exit and let + # the systemd or k8s to do a fresh startup with compute_ctl. + restart_after_crash: "off" + # By default 3. We have the following persistent connections in the VM: + # * compute_activity_monitor (from compute_ctl) + # * postgres-exporter (metrics collector; it has 2 connections) + # * sql_exporter (metrics collector; we have 2 instances [1 for us & users; 1 for autoscaling]) + # * vm-monitor (to query & change file cache size) + # i.e. total of 6. Let's reserve 7, so there's still at least one left over. + superuser_reserved_connections: "7" + synchronous_standby_names: "walproposer" + + replica: + hot_standby: "on" + + per_version: + 17: + common: + # PostgreSQL 17 has a new IO system called "read stream", which can combine IOs up to some + # size. It still has some issues with readahead, though, so we default to disabled/ + # "no combining of IOs" to make sure we get the maximum prefetch depth. + # See also: https://github.com/neondatabase/neon/pull/9860 + io_combine_limit: "1" + replica: + # prefetching of blocks referenced in WAL doesn't make sense for us + # Neon hot standby ignores pages that are not in the shared_buffers + recovery_prefetch: "off" + 16: + common: + replica: + # prefetching of blocks referenced in WAL doesn't make sense for us + # Neon hot standby ignores pages that are not in the shared_buffers + recovery_prefetch: "off" + 15: + common: + replica: + # prefetching of blocks referenced in WAL doesn't make sense for us + # Neon hot standby ignores pages that are not in the shared_buffers + recovery_prefetch: "off" + 14: + common: + replica: From 1fb1315aedc8816aa038520b2275b1434a14418f Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Wed, 4 Jun 2025 21:07:47 +0100 Subject: [PATCH 31/43] compute-ctl: add spec for enable_tls, separate from compute-ctl config (#12109) ## Problem Inbetween adding the TLS config for compute-ctl, and adding the TLS config in controlplane, we switched from using a provision flag to a bind flag. This happened to work in all of my testing in preview regions as they have no VM pool, so each bind was also a provision. However, in staging I found that the TLS config is still only processed during provision, even though it's only sent on bind. ## Summary of changes * Add a new feature flag value, `tls_experimental`, which tells postgres/pgbouncer/local_proxy to use the TLS certificates on bind. * compute_ctl on provision will be told where the certificates are, instead of being told on bind. --- compute_tools/src/compute.rs | 41 +++++++++++++++++----- compute_tools/src/http/routes/configure.rs | 2 +- libs/compute_api/src/spec.rs | 3 ++ 3 files changed, 37 insertions(+), 9 deletions(-) diff --git a/compute_tools/src/compute.rs b/compute_tools/src/compute.rs index 1ed71180d0..bd6ed910be 100644 --- a/compute_tools/src/compute.rs +++ b/compute_tools/src/compute.rs @@ -3,7 +3,7 @@ use chrono::{DateTime, Utc}; use compute_api::privilege::Privilege; use compute_api::responses::{ ComputeConfig, ComputeCtlConfig, ComputeMetrics, ComputeStatus, LfcOffloadState, - LfcPrewarmState, + LfcPrewarmState, TlsConfig, }; use compute_api::spec::{ ComputeAudit, ComputeFeature, ComputeMode, ComputeSpec, ExtVersion, PgIdent, @@ -603,6 +603,8 @@ impl ComputeNode { }); } + let tls_config = self.tls_config(&pspec.spec); + // If there are any remote extensions in shared_preload_libraries, start downloading them if pspec.spec.remote_extensions.is_some() { let (this, spec) = (self.clone(), pspec.spec.clone()); @@ -659,7 +661,7 @@ impl ComputeNode { info!("tuning pgbouncer"); let pgbouncer_settings = pgbouncer_settings.clone(); - let tls_config = self.compute_ctl_config.tls.clone(); + let tls_config = tls_config.clone(); // Spawn a background task to do the tuning, // so that we don't block the main thread that starts Postgres. @@ -678,7 +680,10 @@ impl ComputeNode { // Spawn a background task to do the configuration, // so that we don't block the main thread that starts Postgres. - let local_proxy = local_proxy.clone(); + + let mut local_proxy = local_proxy.clone(); + local_proxy.tls = tls_config.clone(); + let _handle = tokio::spawn(async move { if let Err(err) = local_proxy::configure(&local_proxy) { error!("error while configuring local_proxy: {err:?}"); @@ -1205,13 +1210,15 @@ impl ComputeNode { let spec = &pspec.spec; let pgdata_path = Path::new(&self.params.pgdata); + let tls_config = self.tls_config(&pspec.spec); + // Remove/create an empty pgdata directory and put configuration there. self.create_pgdata()?; config::write_postgres_conf( pgdata_path, &pspec.spec, self.params.internal_http_port, - &self.compute_ctl_config.tls, + tls_config, )?; // Syncing safekeepers is only safe with primary nodes: if a primary @@ -1536,14 +1543,22 @@ impl ComputeNode { .clone(), ); + let mut tls_config = None::; + if spec.features.contains(&ComputeFeature::TlsExperimental) { + tls_config = self.compute_ctl_config.tls.clone(); + } + let max_concurrent_connections = self.max_service_connections(compute_state, &spec); // Merge-apply spec & changes to PostgreSQL state. self.apply_spec_sql(spec.clone(), conf.clone(), max_concurrent_connections)?; if let Some(local_proxy) = &spec.clone().local_proxy_config { + let mut local_proxy = local_proxy.clone(); + local_proxy.tls = tls_config.clone(); + info!("configuring local_proxy"); - local_proxy::configure(local_proxy).context("apply_config local_proxy")?; + local_proxy::configure(&local_proxy).context("apply_config local_proxy")?; } // Run migrations separately to not hold up cold starts @@ -1595,11 +1610,13 @@ impl ComputeNode { pub fn reconfigure(&self) -> Result<()> { let spec = self.state.lock().unwrap().pspec.clone().unwrap().spec; + let tls_config = self.tls_config(&spec); + if let Some(ref pgbouncer_settings) = spec.pgbouncer_settings { info!("tuning pgbouncer"); let pgbouncer_settings = pgbouncer_settings.clone(); - let tls_config = self.compute_ctl_config.tls.clone(); + let tls_config = tls_config.clone(); // Spawn a background task to do the tuning, // so that we don't block the main thread that starts Postgres. @@ -1617,7 +1634,7 @@ impl ComputeNode { // Spawn a background task to do the configuration, // so that we don't block the main thread that starts Postgres. let mut local_proxy = local_proxy.clone(); - local_proxy.tls = self.compute_ctl_config.tls.clone(); + local_proxy.tls = tls_config.clone(); tokio::spawn(async move { if let Err(err) = local_proxy::configure(&local_proxy) { error!("error while configuring local_proxy: {err:?}"); @@ -1635,7 +1652,7 @@ impl ComputeNode { pgdata_path, &spec, self.params.internal_http_port, - &self.compute_ctl_config.tls, + tls_config, )?; if !spec.skip_pg_catalog_updates { @@ -1755,6 +1772,14 @@ impl ComputeNode { } } + pub fn tls_config(&self, spec: &ComputeSpec) -> &Option { + if spec.features.contains(&ComputeFeature::TlsExperimental) { + &self.compute_ctl_config.tls + } else { + &None:: + } + } + /// Update the `last_active` in the shared state, but ensure that it's a more recent one. pub fn update_last_active(&self, last_active: Option>) { let mut state = self.state.lock().unwrap(); diff --git a/compute_tools/src/http/routes/configure.rs b/compute_tools/src/http/routes/configure.rs index f7a19da611..c29e3a97da 100644 --- a/compute_tools/src/http/routes/configure.rs +++ b/compute_tools/src/http/routes/configure.rs @@ -22,7 +22,7 @@ pub(in crate::http) async fn configure( State(compute): State>, request: Json, ) -> Response { - let pspec = match ParsedSpec::try_from(request.spec.clone()) { + let pspec = match ParsedSpec::try_from(request.0.spec) { Ok(p) => p, Err(e) => return JsonResponse::error(StatusCode::BAD_REQUEST, e), }; diff --git a/libs/compute_api/src/spec.rs b/libs/compute_api/src/spec.rs index 343923d446..0e23b70265 100644 --- a/libs/compute_api/src/spec.rs +++ b/libs/compute_api/src/spec.rs @@ -192,6 +192,9 @@ pub enum ComputeFeature { /// track short-lived connections as user activity. ActivityMonitorExperimental, + /// Enable TLS functionality. + TlsExperimental, + /// This is a special feature flag that is used to represent unknown feature flags. /// Basically all unknown to enum flags are represented as this one. See unit test /// `parse_unknown_features()` for more details. From dae203ef692415428977d49bdf116c3a5fed344e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Arpad=20M=C3=BCller?= Date: Thu, 5 Jun 2025 01:02:31 +0200 Subject: [PATCH 32/43] pgxn: support generations in safekeepers_cmp (#12129) `safekeepers_cmp` was added by #8840 to make changes of the safekeeper set order independent: a `sk1,sk2,sk3` specifier changed to `sk3,sk1,sk2` should not cause a walproposer restart. However, this check didn't support generations, in the sense that it would see the `g#123:` as part of the first safekeeper in the list, and if the first safekeeper changes, it would also restart the walproposer. Therefore, parse the generation properly and make it not be a part of the generation. This PR doesn't add a specific test, but I have confirmed locally that `test_safekeepers_reconfigure_reorder` is fixed with the changes of PR #11712 applied thanks to this PR. Part of https://github.com/neondatabase/neon/issues/11670 --- pgxn/neon/walproposer.c | 4 ++++ pgxn/neon/walproposer_pg.c | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+) diff --git a/pgxn/neon/walproposer.c b/pgxn/neon/walproposer.c index f42103c7cd..c63edb1398 100644 --- a/pgxn/neon/walproposer.c +++ b/pgxn/neon/walproposer.c @@ -119,6 +119,10 @@ WalProposerCreate(WalProposerConfig *config, walproposer_api api) { wp_log(FATAL, "failed to parse neon.safekeepers generation number: %m"); } + if (*endptr != ':') + { + wp_log(FATAL, "failed to parse neon.safekeepers: no colon after generation"); + } /* Skip past : to the first hostname. */ host = endptr + 1; } diff --git a/pgxn/neon/walproposer_pg.c b/pgxn/neon/walproposer_pg.c index d15bf91d24..c90702a282 100644 --- a/pgxn/neon/walproposer_pg.c +++ b/pgxn/neon/walproposer_pg.c @@ -272,6 +272,30 @@ split_safekeepers_list(char *safekeepers_list, char *safekeepers[]) return n_safekeepers; } +static char *split_off_safekeepers_generation(char *safekeepers_list, uint32 *generation) +{ + char *endptr; + + if (strncmp(safekeepers_list, "g#", 2) != 0) + { + return safekeepers_list; + } + else + { + errno = 0; + *generation = strtoul(safekeepers_list + 2, &endptr, 10); + if (errno != 0) + { + wp_log(FATAL, "failed to parse neon.safekeepers generation number: %m"); + } + if (*endptr != ':') + { + wp_log(FATAL, "failed to parse neon.safekeepers: no colon after generation"); + } + return endptr + 1; + } +} + /* * Accept two coma-separated strings with list of safekeeper host:port addresses. * Split them into arrays and return false if two sets do not match, ignoring the order. @@ -283,6 +307,16 @@ safekeepers_cmp(char *old, char *new) char *safekeepers_new[MAX_SAFEKEEPERS]; int len_old = 0; int len_new = 0; + uint32 gen_old = INVALID_GENERATION; + uint32 gen_new = INVALID_GENERATION; + + old = split_off_safekeepers_generation(old, &gen_old); + new = split_off_safekeepers_generation(new, &gen_new); + + if (gen_old != gen_new) + { + return false; + } len_old = split_safekeepers_list(old, safekeepers_old); len_new = split_safekeepers_list(new, safekeepers_new); From 56d505bce6088f72e4291afee8ae8ac4fe44dbfb Mon Sep 17 00:00:00 2001 From: Konstantin Knizhnik Date: Thu, 5 Jun 2025 08:48:25 +0300 Subject: [PATCH 33/43] Update online_advisor (#12045) ## Problem Investigate crash of online_advisor in image check ## Summary of changes --------- Co-authored-by: Konstantin Knizhnik --- compute/compute-node.Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compute/compute-node.Dockerfile b/compute/compute-node.Dockerfile index 6ece9c6566..d3008c8085 100644 --- a/compute/compute-node.Dockerfile +++ b/compute/compute-node.Dockerfile @@ -603,7 +603,7 @@ RUN case "${PG_VERSION:?}" in \ ;; \ esac && \ wget https://github.com/knizhnik/online_advisor/archive/refs/tags/1.0.tar.gz -O online_advisor.tar.gz && \ - echo "059b7d9e5a90013a58bdd22e9505b88406ce05790675eb2d8434e5b215652d54 online_advisor.tar.gz" | sha256sum --check && \ + echo "37dcadf8f7cc8d6cc1f8831276ee245b44f1b0274f09e511e47a67738ba9ed0f online_advisor.tar.gz" | sha256sum --check && \ mkdir online_advisor-src && cd online_advisor-src && tar xzf ../online_advisor.tar.gz --strip-components=1 -C . FROM pg-build AS online_advisor-build From c8a96cf7223de04f208766a6d162205269a61030 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Thu, 5 Jun 2025 08:12:00 +0100 Subject: [PATCH 34/43] update proxy protocol parsing to not a rw wrapper (#12035) ## Problem I believe in all environments we now specify either required/rejected for proxy-protocol V2 as required. We no longer rely on the supported flow. This means we no longer need to keep around read bytes incase they're not in a header. While I designed ChainRW to be fast (the hot path with an empty buffer is very easy to branch predict), it's still unnecessary. ## Summary of changes * Remove the ChainRW wrapper * Refactor how we read the proxy-protocol header using read_exact. Slightly worse perf but it's hardly significant. * Don't try and parse the header if it's rejected. --- proxy/src/binary/proxy.rs | 3 +- proxy/src/config.rs | 2 - proxy/src/console_redirect_proxy.rs | 44 +++---- proxy/src/protocol2.rs | 177 +++++----------------------- proxy/src/proxy/mod.rs | 42 +++---- proxy/src/proxy/tests/mod.rs | 1 - proxy/src/serverless/mod.rs | 55 ++++----- 7 files changed, 94 insertions(+), 230 deletions(-) diff --git a/proxy/src/binary/proxy.rs b/proxy/src/binary/proxy.rs index dcae263647..757c1e988b 100644 --- a/proxy/src/binary/proxy.rs +++ b/proxy/src/binary/proxy.rs @@ -221,8 +221,7 @@ struct ProxyCliArgs { is_private_access_proxy: bool, /// Configure whether all incoming requests have a Proxy Protocol V2 packet. - // TODO(conradludgate): switch default to rejected or required once we've updated all deployments - #[clap(value_enum, long, default_value_t = ProxyProtocolV2::Supported)] + #[clap(value_enum, long, default_value_t = ProxyProtocolV2::Rejected)] proxy_protocol_v2: ProxyProtocolV2, /// Time the proxy waits for the webauth session to be confirmed by the control plane. diff --git a/proxy/src/config.rs b/proxy/src/config.rs index a97339df9a..248584a19a 100644 --- a/proxy/src/config.rs +++ b/proxy/src/config.rs @@ -39,8 +39,6 @@ pub struct ComputeConfig { pub enum ProxyProtocolV2 { /// Connection will error if PROXY protocol v2 header is missing Required, - /// Connection will parse PROXY protocol v2 header, but accept the connection if it's missing. - Supported, /// Connection will error if PROXY protocol v2 header is provided Rejected, } diff --git a/proxy/src/console_redirect_proxy.rs b/proxy/src/console_redirect_proxy.rs index 7fb84b5ee5..6755499b45 100644 --- a/proxy/src/console_redirect_proxy.rs +++ b/proxy/src/console_redirect_proxy.rs @@ -54,30 +54,24 @@ pub async fn task_main( debug!(protocol = "tcp", %session_id, "accepted new TCP connection"); connections.spawn(async move { - let (socket, peer_addr) = match read_proxy_protocol(socket).await { - Err(e) => { - error!("per-client task finished with an error: {e:#}"); - return; + let (socket, conn_info) = match config.proxy_protocol_v2 { + ProxyProtocolV2::Required => { + match read_proxy_protocol(socket).await { + Err(e) => { + error!("per-client task finished with an error: {e:#}"); + return; + } + // our load balancers will not send any more data. let's just exit immediately + Ok((_socket, ConnectHeader::Local)) => { + debug!("healthcheck received"); + return; + } + Ok((socket, ConnectHeader::Proxy(info))) => (socket, info), + } } - // our load balancers will not send any more data. let's just exit immediately - Ok((_socket, ConnectHeader::Local)) => { - debug!("healthcheck received"); - return; - } - Ok((_socket, ConnectHeader::Missing)) - if config.proxy_protocol_v2 == ProxyProtocolV2::Required => - { - error!("missing required proxy protocol header"); - return; - } - Ok((_socket, ConnectHeader::Proxy(_))) - if config.proxy_protocol_v2 == ProxyProtocolV2::Rejected => - { - error!("proxy protocol header not supported"); - return; - } - Ok((socket, ConnectHeader::Proxy(info))) => (socket, info), - Ok((socket, ConnectHeader::Missing)) => ( + // ignore the header - it cannot be confused for a postgres or http connection so will + // error later. + ProxyProtocolV2::Rejected => ( socket, ConnectionInfo { addr: peer_addr, @@ -86,7 +80,7 @@ pub async fn task_main( ), }; - match socket.inner.set_nodelay(true) { + match socket.set_nodelay(true) { Ok(()) => {} Err(e) => { error!( @@ -98,7 +92,7 @@ pub async fn task_main( let ctx = RequestContext::new( session_id, - peer_addr, + conn_info, crate::metrics::Protocol::Tcp, &config.region, ); diff --git a/proxy/src/protocol2.rs b/proxy/src/protocol2.rs index 0793998639..5bec6d6ca3 100644 --- a/proxy/src/protocol2.rs +++ b/proxy/src/protocol2.rs @@ -4,60 +4,13 @@ use core::fmt; use std::io; use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr}; -use std::pin::Pin; -use std::task::{Context, Poll}; -use bytes::{Buf, Bytes, BytesMut}; -use pin_project_lite::pin_project; +use bytes::Buf; use smol_str::SmolStr; use strum_macros::FromRepr; -use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf}; +use tokio::io::{AsyncRead, AsyncReadExt}; use zerocopy::{FromBytes, Immutable, KnownLayout, Unaligned, network_endian}; -pin_project! { - /// A chained [`AsyncRead`] with [`AsyncWrite`] passthrough - pub(crate) struct ChainRW { - #[pin] - pub(crate) inner: T, - buf: BytesMut, - } -} - -impl AsyncWrite for ChainRW { - #[inline] - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - self.project().inner.poll_write(cx, buf) - } - - #[inline] - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().inner.poll_flush(cx) - } - - #[inline] - fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().inner.poll_shutdown(cx) - } - - #[inline] - fn poll_write_vectored( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - bufs: &[io::IoSlice<'_>], - ) -> Poll> { - self.project().inner.poll_write_vectored(cx, bufs) - } - - #[inline] - fn is_write_vectored(&self) -> bool { - self.inner.is_write_vectored() - } -} - /// Proxy Protocol Version 2 Header const SIGNATURE: [u8; 12] = [ 0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A, @@ -79,7 +32,6 @@ pub struct ConnectionInfo { #[derive(PartialEq, Eq, Clone, Debug)] pub enum ConnectHeader { - Missing, Local, Proxy(ConnectionInfo), } @@ -106,47 +58,24 @@ pub enum ConnectionInfoExtra { pub(crate) async fn read_proxy_protocol( mut read: T, -) -> std::io::Result<(ChainRW, ConnectHeader)> { - let mut buf = BytesMut::with_capacity(128); - let header = loop { - let bytes_read = read.read_buf(&mut buf).await?; - - // exit for bad header signature - let len = usize::min(buf.len(), SIGNATURE.len()); - if buf[..len] != SIGNATURE[..len] { - return Ok((ChainRW { inner: read, buf }, ConnectHeader::Missing)); - } - - // if no more bytes available then exit - if bytes_read == 0 { - return Ok((ChainRW { inner: read, buf }, ConnectHeader::Missing)); - } - - // check if we have enough bytes to continue - if let Some(header) = buf.try_get::() { - break header; - } - }; - - let remaining_length = usize::from(header.len.get()); - - while buf.len() < remaining_length { - if read.read_buf(&mut buf).await? == 0 { - return Err(io::Error::new( - io::ErrorKind::UnexpectedEof, - "stream closed while waiting for proxy protocol addresses", - )); - } +) -> std::io::Result<(T, ConnectHeader)> { + let mut header = [0; size_of::()]; + read.read_exact(&mut header).await?; + let header: ProxyProtocolV2Header = zerocopy::transmute!(header); + if header.signature != SIGNATURE { + return Err(std::io::Error::other("invalid proxy protocol header")); } - let payload = buf.split_to(remaining_length); - let res = process_proxy_payload(header, payload)?; - Ok((ChainRW { inner: read, buf }, res)) + let mut payload = vec![0; usize::from(header.len.get())]; + read.read_exact(&mut payload).await?; + + let res = process_proxy_payload(header, &payload)?; + Ok((read, res)) } fn process_proxy_payload( header: ProxyProtocolV2Header, - mut payload: BytesMut, + mut payload: &[u8], ) -> std::io::Result { match header.version_and_command { // the connection was established on purpose by the proxy @@ -162,13 +91,12 @@ fn process_proxy_payload( PROXY_V2 => {} // other values are unassigned and must not be emitted by senders. Receivers // must drop connections presenting unexpected values here. - #[rustfmt::skip] // https://github.com/rust-lang/rustfmt/issues/6384 - _ => return Err(io::Error::other( - format!( + _ => { + return Err(io::Error::other(format!( "invalid proxy protocol command 0x{:02X}. expected local (0x20) or proxy (0x21)", header.version_and_command - ), - )), + ))); + } } let size_err = @@ -206,7 +134,7 @@ fn process_proxy_payload( } let subtype = tlv.value.get_u8(); match Pp2AwsType::from_repr(subtype) { - Some(Pp2AwsType::VpceId) => match std::str::from_utf8(&tlv.value) { + Some(Pp2AwsType::VpceId) => match std::str::from_utf8(tlv.value) { Ok(s) => { extra = Some(ConnectionInfoExtra::Aws { vpce_id: s.into() }); } @@ -282,65 +210,28 @@ enum Pp2AzureType { PrivateEndpointLinkId = 0x01, } -impl AsyncRead for ChainRW { - #[inline] - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - if self.buf.is_empty() { - self.project().inner.poll_read(cx, buf) - } else { - self.read_from_buf(buf) - } - } -} - -impl ChainRW { - #[cold] - fn read_from_buf(self: Pin<&mut Self>, buf: &mut ReadBuf<'_>) -> Poll> { - debug_assert!(!self.buf.is_empty()); - let this = self.project(); - - let write = usize::min(this.buf.len(), buf.remaining()); - let slice = this.buf.split_to(write).freeze(); - buf.put_slice(&slice); - - // reset the allocation so it can be freed - if this.buf.is_empty() { - *this.buf = BytesMut::new(); - } - - Poll::Ready(Ok(())) - } -} - #[derive(Debug)] -struct Tlv { +struct Tlv<'a> { kind: u8, - value: Bytes, + value: &'a [u8], } -fn read_tlv(b: &mut BytesMut) -> Option { +fn read_tlv<'a>(b: &mut &'a [u8]) -> Option> { let tlv_header = b.try_get::()?; let len = usize::from(tlv_header.len.get()); - if b.len() < len { - return None; - } Some(Tlv { kind: tlv_header.kind, - value: b.split_to(len).freeze(), + value: b.split_off(..len)?, }) } trait BufExt: Sized { fn try_get(&mut self) -> Option; } -impl BufExt for BytesMut { +impl BufExt for &[u8] { fn try_get(&mut self) -> Option { - let (res, _) = T::read_from_prefix(self).ok()?; - self.advance(size_of::()); + let (res, rest) = T::read_from_prefix(self).ok()?; + *self = rest; Some(res) } } @@ -481,27 +372,19 @@ mod tests { } #[tokio::test] + #[should_panic = "invalid proxy protocol header"] async fn test_invalid() { let data = [0x55; 256]; - let (mut read, info) = read_proxy_protocol(data.as_slice()).await.unwrap(); - - let mut bytes = vec![]; - read.read_to_end(&mut bytes).await.unwrap(); - assert_eq!(bytes, data); - assert_eq!(info, ConnectHeader::Missing); + read_proxy_protocol(data.as_slice()).await.unwrap(); } #[tokio::test] + #[should_panic = "early eof"] async fn test_short() { let data = [0x55; 10]; - let (mut read, info) = read_proxy_protocol(data.as_slice()).await.unwrap(); - - let mut bytes = vec![]; - read.read_to_end(&mut bytes).await.unwrap(); - assert_eq!(bytes, data); - assert_eq!(info, ConnectHeader::Missing); + read_proxy_protocol(data.as_slice()).await.unwrap(); } #[tokio::test] diff --git a/proxy/src/proxy/mod.rs b/proxy/src/proxy/mod.rs index 0ffc54aa88..477baff1c9 100644 --- a/proxy/src/proxy/mod.rs +++ b/proxy/src/proxy/mod.rs @@ -102,30 +102,24 @@ pub async fn task_main( let endpoint_rate_limiter2 = endpoint_rate_limiter.clone(); connections.spawn(async move { - let (socket, conn_info) = match read_proxy_protocol(socket).await { - Err(e) => { - warn!("per-client task finished with an error: {e:#}"); - return; + let (socket, conn_info) = match config.proxy_protocol_v2 { + ProxyProtocolV2::Required => { + match read_proxy_protocol(socket).await { + Err(e) => { + warn!("per-client task finished with an error: {e:#}"); + return; + } + // our load balancers will not send any more data. let's just exit immediately + Ok((_socket, ConnectHeader::Local)) => { + debug!("healthcheck received"); + return; + } + Ok((socket, ConnectHeader::Proxy(info))) => (socket, info), + } } - // our load balancers will not send any more data. let's just exit immediately - Ok((_socket, ConnectHeader::Local)) => { - debug!("healthcheck received"); - return; - } - Ok((_socket, ConnectHeader::Missing)) - if config.proxy_protocol_v2 == ProxyProtocolV2::Required => - { - warn!("missing required proxy protocol header"); - return; - } - Ok((_socket, ConnectHeader::Proxy(_))) - if config.proxy_protocol_v2 == ProxyProtocolV2::Rejected => - { - warn!("proxy protocol header not supported"); - return; - } - Ok((socket, ConnectHeader::Proxy(info))) => (socket, info), - Ok((socket, ConnectHeader::Missing)) => ( + // ignore the header - it cannot be confused for a postgres or http connection so will + // error later. + ProxyProtocolV2::Rejected => ( socket, ConnectionInfo { addr: peer_addr, @@ -134,7 +128,7 @@ pub async fn task_main( ), }; - match socket.inner.set_nodelay(true) { + match socket.set_nodelay(true) { Ok(()) => {} Err(e) => { error!( diff --git a/proxy/src/proxy/tests/mod.rs b/proxy/src/proxy/tests/mod.rs index 61e8ee4a10..117c42e19c 100644 --- a/proxy/src/proxy/tests/mod.rs +++ b/proxy/src/proxy/tests/mod.rs @@ -173,7 +173,6 @@ async fn dummy_proxy( tls: Option, auth: impl TestAuth + Send, ) -> anyhow::Result<()> { - let (client, _) = read_proxy_protocol(client).await?; let mut stream = match handshake(&RequestContext::test(), client, tls.as_ref(), false).await? { HandshakeData::Startup(stream, _) => stream, HandshakeData::Cancel(_) => bail!("cancellation not supported"), diff --git a/proxy/src/serverless/mod.rs b/proxy/src/serverless/mod.rs index 2a7069b1c2..f6f681ac45 100644 --- a/proxy/src/serverless/mod.rs +++ b/proxy/src/serverless/mod.rs @@ -49,7 +49,7 @@ use crate::config::{ProxyConfig, ProxyProtocolV2}; use crate::context::RequestContext; use crate::ext::TaskExt; use crate::metrics::Metrics; -use crate::protocol2::{ChainRW, ConnectHeader, ConnectionInfo, read_proxy_protocol}; +use crate::protocol2::{ConnectHeader, ConnectionInfo, read_proxy_protocol}; use crate::proxy::run_until_cancelled; use crate::rate_limiter::EndpointRateLimiter; use crate::serverless::backend::PoolingBackend; @@ -207,12 +207,12 @@ pub(crate) type AsyncRW = Pin>; #[async_trait] trait MaybeTlsAcceptor: Send + Sync + 'static { - async fn accept(&self, conn: ChainRW) -> std::io::Result; + async fn accept(&self, conn: TcpStream) -> std::io::Result; } #[async_trait] impl MaybeTlsAcceptor for &'static ArcSwapOption { - async fn accept(&self, conn: ChainRW) -> std::io::Result { + async fn accept(&self, conn: TcpStream) -> std::io::Result { match &*self.load() { Some(config) => Ok(Box::pin( TlsAcceptor::from(config.http_config.clone()) @@ -235,33 +235,30 @@ async fn connection_startup( peer_addr: SocketAddr, ) -> Option<(AsyncRW, ConnectionInfo)> { // handle PROXY protocol - let (conn, peer) = match read_proxy_protocol(conn).await { - Ok(c) => c, - Err(e) => { - tracing::warn!(?session_id, %peer_addr, "failed to accept TCP connection: invalid PROXY protocol V2 header: {e:#}"); - return None; + let (conn, conn_info) = match config.proxy_protocol_v2 { + ProxyProtocolV2::Required => { + match read_proxy_protocol(conn).await { + Err(e) => { + warn!("per-client task finished with an error: {e:#}"); + return None; + } + // our load balancers will not send any more data. let's just exit immediately + Ok((_conn, ConnectHeader::Local)) => { + tracing::debug!("healthcheck received"); + return None; + } + Ok((conn, ConnectHeader::Proxy(info))) => (conn, info), + } } - }; - - let conn_info = match peer { - // our load balancers will not send any more data. let's just exit immediately - ConnectHeader::Local => { - tracing::debug!("healthcheck received"); - return None; - } - ConnectHeader::Missing if config.proxy_protocol_v2 == ProxyProtocolV2::Required => { - tracing::warn!("missing required proxy protocol header"); - return None; - } - ConnectHeader::Proxy(_) if config.proxy_protocol_v2 == ProxyProtocolV2::Rejected => { - tracing::warn!("proxy protocol header not supported"); - return None; - } - ConnectHeader::Proxy(info) => info, - ConnectHeader::Missing => ConnectionInfo { - addr: peer_addr, - extra: None, - }, + // ignore the header - it cannot be confused for a postgres or http connection so will + // error later. + ProxyProtocolV2::Rejected => ( + conn, + ConnectionInfo { + addr: peer_addr, + extra: None, + }, + ), }; let has_private_peer_addr = match conn_info.addr.ip() { From d8ebd1d7717a4871cca5dda1b60ae4167e4bf514 Mon Sep 17 00:00:00 2001 From: "Alex Chi Z." <4198311+skyzh@users.noreply.github.com> Date: Thu, 5 Jun 2025 15:48:36 +0800 Subject: [PATCH 35/43] feat(pageserver): report tenant properties to posthog (#12113) ## Problem Part of https://github.com/neondatabase/neon/issues/11813 In PostHog UI, we need to create the properties before using them as a filter. We report all variants automatically when we start the pageserver. In the future, we can report all real tenants instead of fake tenants (we do that now to save money + we don't need real tenants in the UI). ## Summary of changes * Collect `region`, `availability_zone`, `pageserver_id` properties and use them in the feature evaluation. * Report 10 fake tenants on each pageserver startup. --------- Signed-off-by: Alex Chi Z --- .../src/background_loop.rs | 27 +++- libs/posthog_client_lite/src/lib.rs | 48 ++++++- pageserver/src/feature_resolver.rs | 123 ++++++++++++++++-- pageserver/src/http/routes.rs | 10 +- 4 files changed, 190 insertions(+), 18 deletions(-) diff --git a/libs/posthog_client_lite/src/background_loop.rs b/libs/posthog_client_lite/src/background_loop.rs index a05f6096b1..a404c76da9 100644 --- a/libs/posthog_client_lite/src/background_loop.rs +++ b/libs/posthog_client_lite/src/background_loop.rs @@ -6,7 +6,7 @@ use arc_swap::ArcSwap; use tokio_util::sync::CancellationToken; use tracing::{Instrument, info_span}; -use crate::{FeatureStore, PostHogClient, PostHogClientConfig}; +use crate::{CaptureEvent, FeatureStore, PostHogClient, PostHogClientConfig}; /// A background loop that fetches feature flags from PostHog and updates the feature store. pub struct FeatureResolverBackgroundLoop { @@ -24,9 +24,16 @@ impl FeatureResolverBackgroundLoop { } } - pub fn spawn(self: Arc, handle: &tokio::runtime::Handle, refresh_period: Duration) { + pub fn spawn( + self: Arc, + handle: &tokio::runtime::Handle, + refresh_period: Duration, + fake_tenants: Vec, + ) { let this = self.clone(); let cancel = self.cancel.clone(); + + // Main loop of updating the feature flags. handle.spawn( async move { tracing::info!("Starting PostHog feature resolver"); @@ -56,6 +63,22 @@ impl FeatureResolverBackgroundLoop { } .instrument(info_span!("posthog_feature_resolver")), ); + + // Report fake tenants to PostHog so that we have the combination of all the properties in the UI. + // Do one report per pageserver restart. + let this = self.clone(); + handle.spawn( + async move { + tracing::info!("Starting PostHog feature reporter"); + for tenant in &fake_tenants { + tracing::info!("Reporting fake tenant: {:?}", tenant); + } + if let Err(e) = this.posthog_client.capture_event_batch(&fake_tenants).await { + tracing::warn!("Cannot report fake tenants: {}", e); + } + } + .instrument(info_span!("posthog_feature_reporter")), + ); } pub fn feature_store(&self) -> Arc { diff --git a/libs/posthog_client_lite/src/lib.rs b/libs/posthog_client_lite/src/lib.rs index 281a8f8011..f607b1be0a 100644 --- a/libs/posthog_client_lite/src/lib.rs +++ b/libs/posthog_client_lite/src/lib.rs @@ -64,7 +64,7 @@ pub struct LocalEvaluationFlagFilterProperty { operator: String, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, Clone)] #[serde(untagged)] pub enum PostHogFlagFilterPropertyValue { String(String), @@ -507,6 +507,13 @@ pub struct PostHogClient { client: reqwest::Client, } +#[derive(Serialize, Debug)] +pub struct CaptureEvent { + pub event: String, + pub distinct_id: String, + pub properties: serde_json::Value, +} + impl PostHogClient { pub fn new(config: PostHogClientConfig) -> Self { let client = reqwest::Client::new(); @@ -570,12 +577,12 @@ impl PostHogClient { &self, event: &str, distinct_id: &str, - properties: &HashMap, + properties: &serde_json::Value, ) -> anyhow::Result<()> { // PUBLIC_URL/capture/ - // with bearer token of self.client_api_key let url = format!("{}/capture/", self.config.public_api_url); - self.client + let response = self + .client .post(url) .body(serde_json::to_string(&json!({ "api_key": self.config.client_api_key, @@ -585,6 +592,39 @@ impl PostHogClient { }))?) .send() .await?; + let status = response.status(); + let body = response.text().await?; + if !status.is_success() { + return Err(anyhow::anyhow!( + "Failed to capture events: {}, {}", + status, + body + )); + } + Ok(()) + } + + pub async fn capture_event_batch(&self, events: &[CaptureEvent]) -> anyhow::Result<()> { + // PUBLIC_URL/batch/ + let url = format!("{}/batch/", self.config.public_api_url); + let response = self + .client + .post(url) + .body(serde_json::to_string(&json!({ + "api_key": self.config.client_api_key, + "batch": events, + }))?) + .send() + .await?; + let status = response.status(); + let body = response.text().await?; + if !status.is_success() { + return Err(anyhow::anyhow!( + "Failed to capture events: {}, {}", + status, + body + )); + } Ok(()) } } diff --git a/pageserver/src/feature_resolver.rs b/pageserver/src/feature_resolver.rs index 7463186c06..50de3b691c 100644 --- a/pageserver/src/feature_resolver.rs +++ b/pageserver/src/feature_resolver.rs @@ -1,8 +1,11 @@ use std::{collections::HashMap, sync::Arc, time::Duration}; use posthog_client_lite::{ - FeatureResolverBackgroundLoop, PostHogClientConfig, PostHogEvaluationError, + CaptureEvent, FeatureResolverBackgroundLoop, PostHogClientConfig, PostHogEvaluationError, + PostHogFlagFilterPropertyValue, }; +use remote_storage::RemoteStorageKind; +use serde_json::json; use tokio_util::sync::CancellationToken; use utils::id::TenantId; @@ -11,11 +14,15 @@ use crate::{config::PageServerConf, metrics::FEATURE_FLAG_EVALUATION}; #[derive(Clone)] pub struct FeatureResolver { inner: Option>, + internal_properties: Option>>, } impl FeatureResolver { pub fn new_disabled() -> Self { - Self { inner: None } + Self { + inner: None, + internal_properties: None, + } } pub fn spawn( @@ -36,14 +43,114 @@ impl FeatureResolver { shutdown_pageserver, ); let inner = Arc::new(inner); - // TODO: make this configurable - inner.clone().spawn(handle, Duration::from_secs(60)); - Ok(FeatureResolver { inner: Some(inner) }) + + // The properties shared by all tenants on this pageserver. + let internal_properties = { + let mut properties = HashMap::new(); + properties.insert( + "pageserver_id".to_string(), + PostHogFlagFilterPropertyValue::String(conf.id.to_string()), + ); + if let Some(availability_zone) = &conf.availability_zone { + properties.insert( + "availability_zone".to_string(), + PostHogFlagFilterPropertyValue::String(availability_zone.clone()), + ); + } + // Infer region based on the remote storage config. + if let Some(remote_storage) = &conf.remote_storage_config { + match &remote_storage.storage { + RemoteStorageKind::AwsS3(config) => { + properties.insert( + "region".to_string(), + PostHogFlagFilterPropertyValue::String(format!( + "aws-{}", + config.bucket_region + )), + ); + } + RemoteStorageKind::AzureContainer(config) => { + properties.insert( + "region".to_string(), + PostHogFlagFilterPropertyValue::String(format!( + "azure-{}", + config.container_region + )), + ); + } + RemoteStorageKind::LocalFs { .. } => { + properties.insert( + "region".to_string(), + PostHogFlagFilterPropertyValue::String("local".to_string()), + ); + } + } + } + // TODO: add pageserver URL. + Arc::new(properties) + }; + let fake_tenants = { + let mut tenants = Vec::new(); + for i in 0..10 { + let distinct_id = format!( + "fake_tenant_{}_{}_{}", + conf.availability_zone.as_deref().unwrap_or_default(), + conf.id, + i + ); + let properties = Self::collect_properties_inner( + distinct_id.clone(), + Some(&internal_properties), + ); + tenants.push(CaptureEvent { + event: "initial_tenant_report".to_string(), + distinct_id, + properties: json!({ "$set": properties }), // use `$set` to set the person properties instead of the event properties + }); + } + tenants + }; + // TODO: make refresh period configurable + inner + .clone() + .spawn(handle, Duration::from_secs(60), fake_tenants); + Ok(FeatureResolver { + inner: Some(inner), + internal_properties: Some(internal_properties), + }) } else { - Ok(FeatureResolver { inner: None }) + Ok(FeatureResolver { + inner: None, + internal_properties: None, + }) } } + fn collect_properties_inner( + tenant_id: String, + internal_properties: Option<&HashMap>, + ) -> HashMap { + let mut properties = HashMap::new(); + if let Some(internal_properties) = internal_properties { + for (key, value) in internal_properties.iter() { + properties.insert(key.clone(), value.clone()); + } + } + properties.insert( + "tenant_id".to_string(), + PostHogFlagFilterPropertyValue::String(tenant_id), + ); + properties + } + + /// Collect all properties availble for the feature flag evaluation. + pub(crate) fn collect_properties( + &self, + tenant_id: TenantId, + ) -> HashMap { + Self::collect_properties_inner(tenant_id.to_string(), self.internal_properties.as_deref()) + } + /// Evaluate a multivariate feature flag. Currently, we do not support any properties. /// /// Error handling: the caller should inspect the error and decide the behavior when a feature flag @@ -58,7 +165,7 @@ impl FeatureResolver { let res = inner.feature_store().evaluate_multivariate( flag_key, &tenant_id.to_string(), - &HashMap::new(), + &self.collect_properties(tenant_id), ); match &res { Ok(value) => { @@ -96,7 +203,7 @@ impl FeatureResolver { let res = inner.feature_store().evaluate_boolean( flag_key, &tenant_id.to_string(), - &HashMap::new(), + &self.collect_properties(tenant_id), ); match &res { Ok(()) => { diff --git a/pageserver/src/http/routes.rs b/pageserver/src/http/routes.rs index 1effa10404..c8a2a0209f 100644 --- a/pageserver/src/http/routes.rs +++ b/pageserver/src/http/routes.rs @@ -43,6 +43,7 @@ use pageserver_api::models::{ use pageserver_api::shard::{ShardCount, TenantShardId}; use remote_storage::{DownloadError, GenericRemoteStorage, TimeTravelError}; use scopeguard::defer; +use serde_json::json; use tenant_size_model::svg::SvgBranchKind; use tenant_size_model::{SizeResult, StorageModel}; use tokio::time::Instant; @@ -3679,23 +3680,24 @@ async fn tenant_evaluate_feature_flag( let tenant = state .tenant_manager .get_attached_tenant_shard(tenant_shard_id)?; + let properties = tenant.feature_resolver.collect_properties(tenant_shard_id.tenant_id); if as_type == "boolean" { let result = tenant.feature_resolver.evaluate_boolean(&flag, tenant_shard_id.tenant_id); let result = result.map(|_| true).map_err(|e| e.to_string()); - json_response(StatusCode::OK, result) + json_response(StatusCode::OK, json!({ "result": result, "properties": properties })) } else if as_type == "multivariate" { let result = tenant.feature_resolver.evaluate_multivariate(&flag, tenant_shard_id.tenant_id).map_err(|e| e.to_string()); - json_response(StatusCode::OK, result) + json_response(StatusCode::OK, json!({ "result": result, "properties": properties })) } else { // Auto infer the type of the feature flag. let is_boolean = tenant.feature_resolver.is_feature_flag_boolean(&flag).map_err(|e| ApiError::InternalServerError(anyhow::anyhow!("{e}")))?; if is_boolean { let result = tenant.feature_resolver.evaluate_boolean(&flag, tenant_shard_id.tenant_id); let result = result.map(|_| true).map_err(|e| e.to_string()); - json_response(StatusCode::OK, result) + json_response(StatusCode::OK, json!({ "result": result, "properties": properties })) } else { let result = tenant.feature_resolver.evaluate_multivariate(&flag, tenant_shard_id.tenant_id).map_err(|e| e.to_string()); - json_response(StatusCode::OK, result) + json_response(StatusCode::OK, json!({ "result": result, "properties": properties })) } } } From 1577665c20c98ab82c9aa5edcd2e4972887b5e5f Mon Sep 17 00:00:00 2001 From: Folke Behrens Date: Thu, 5 Jun 2025 11:00:23 +0000 Subject: [PATCH 36/43] proxy: Move PGLB-related modules into pglb root module. (#12144) Split the modules responsible for passing data and connecting to compute from auth and waking for PGLB. This PR just moves files. The waking is going to get removed from pglb after this. --- proxy/src/auth/backend/console_redirect.rs | 2 +- proxy/src/auth/backend/mod.rs | 2 +- proxy/src/console_redirect_proxy.rs | 6 +++--- proxy/src/{proxy => pglb}/connect_compute.rs | 3 +-- proxy/src/{proxy => pglb}/copy_bidirectional.rs | 0 proxy/src/{proxy => pglb}/handshake.rs | 0 proxy/src/pglb/mod.rs | 4 ++++ proxy/src/{proxy => pglb}/passthrough.rs | 2 +- proxy/src/proxy/mod.rs | 14 +++++--------- proxy/src/proxy/tests/mod.rs | 2 +- proxy/src/proxy/wake_compute.rs | 2 +- proxy/src/serverless/backend.rs | 6 +++--- 12 files changed, 21 insertions(+), 22 deletions(-) rename proxy/src/{proxy => pglb}/connect_compute.rs (98%) rename proxy/src/{proxy => pglb}/copy_bidirectional.rs (100%) rename proxy/src/{proxy => pglb}/handshake.rs (100%) rename proxy/src/{proxy => pglb}/passthrough.rs (97%) diff --git a/proxy/src/auth/backend/console_redirect.rs b/proxy/src/auth/backend/console_redirect.rs index a50c30257f..c388848926 100644 --- a/proxy/src/auth/backend/console_redirect.rs +++ b/proxy/src/auth/backend/console_redirect.rs @@ -15,9 +15,9 @@ use crate::context::RequestContext; use crate::control_plane::client::cplane_proxy_v1; use crate::control_plane::{self, CachedNodeInfo, NodeInfo}; use crate::error::{ReportableError, UserFacingError}; +use crate::pglb::connect_compute::ComputeConnectBackend; use crate::pqproto::BeMessage; use crate::proxy::NeonOptions; -use crate::proxy::connect_compute::ComputeConnectBackend; use crate::stream::PqStream; use crate::types::RoleName; use crate::{auth, compute, waiters}; diff --git a/proxy/src/auth/backend/mod.rs b/proxy/src/auth/backend/mod.rs index 735cb52f47..f978f655c4 100644 --- a/proxy/src/auth/backend/mod.rs +++ b/proxy/src/auth/backend/mod.rs @@ -25,9 +25,9 @@ use crate::control_plane::{ RoleAccessControl, }; use crate::intern::EndpointIdInt; +use crate::pglb::connect_compute::ComputeConnectBackend; use crate::pqproto::BeMessage; use crate::proxy::NeonOptions; -use crate::proxy::connect_compute::ComputeConnectBackend; use crate::rate_limiter::EndpointRateLimiter; use crate::stream::Stream; use crate::types::{EndpointCacheKey, EndpointId, RoleName}; diff --git a/proxy/src/console_redirect_proxy.rs b/proxy/src/console_redirect_proxy.rs index 6755499b45..f2484b54b8 100644 --- a/proxy/src/console_redirect_proxy.rs +++ b/proxy/src/console_redirect_proxy.rs @@ -11,10 +11,10 @@ use crate::config::{ProxyConfig, ProxyProtocolV2}; use crate::context::RequestContext; use crate::error::ReportableError; use crate::metrics::{Metrics, NumClientConnectionsGuard}; +use crate::pglb::connect_compute::{TcpMechanism, connect_to_compute}; +use crate::pglb::handshake::{HandshakeData, handshake}; +use crate::pglb::passthrough::ProxyPassthrough; use crate::protocol2::{ConnectHeader, ConnectionInfo, read_proxy_protocol}; -use crate::proxy::connect_compute::{TcpMechanism, connect_to_compute}; -use crate::proxy::handshake::{HandshakeData, handshake}; -use crate::proxy::passthrough::ProxyPassthrough; use crate::proxy::{ ClientRequestError, ErrorSource, prepare_client_connection, run_until_cancelled, }; diff --git a/proxy/src/proxy/connect_compute.rs b/proxy/src/pglb/connect_compute.rs similarity index 98% rename from proxy/src/proxy/connect_compute.rs rename to proxy/src/pglb/connect_compute.rs index 57785c9ec5..1d6ca5fbb3 100644 --- a/proxy/src/proxy/connect_compute.rs +++ b/proxy/src/pglb/connect_compute.rs @@ -2,7 +2,6 @@ use async_trait::async_trait; use tokio::time; use tracing::{debug, info, warn}; -use super::retry::ShouldRetryWakeCompute; use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo}; use crate::compute::{self, COULD_NOT_CONNECT, PostgresConnection}; use crate::config::{ComputeConfig, RetryConfig}; @@ -15,7 +14,7 @@ use crate::metrics::{ ConnectOutcome, ConnectionFailureKind, Metrics, RetriesMetricGroup, RetryType, }; use crate::pqproto::StartupMessageParams; -use crate::proxy::retry::{CouldRetry, retry_after, should_retry}; +use crate::proxy::retry::{CouldRetry, ShouldRetryWakeCompute, retry_after, should_retry}; use crate::proxy::wake_compute::wake_compute; use crate::types::Host; diff --git a/proxy/src/proxy/copy_bidirectional.rs b/proxy/src/pglb/copy_bidirectional.rs similarity index 100% rename from proxy/src/proxy/copy_bidirectional.rs rename to proxy/src/pglb/copy_bidirectional.rs diff --git a/proxy/src/proxy/handshake.rs b/proxy/src/pglb/handshake.rs similarity index 100% rename from proxy/src/proxy/handshake.rs rename to proxy/src/pglb/handshake.rs diff --git a/proxy/src/pglb/mod.rs b/proxy/src/pglb/mod.rs index 1088859fb9..4b107142a7 100644 --- a/proxy/src/pglb/mod.rs +++ b/proxy/src/pglb/mod.rs @@ -1 +1,5 @@ +pub mod connect_compute; +pub mod copy_bidirectional; +pub mod handshake; pub mod inprocess; +pub mod passthrough; diff --git a/proxy/src/proxy/passthrough.rs b/proxy/src/pglb/passthrough.rs similarity index 97% rename from proxy/src/proxy/passthrough.rs rename to proxy/src/pglb/passthrough.rs index 55ab5f4dba..6f651d383d 100644 --- a/proxy/src/proxy/passthrough.rs +++ b/proxy/src/pglb/passthrough.rs @@ -53,7 +53,7 @@ pub(crate) async fn proxy_pass( // Starting from here we only proxy the client's traffic. debug!("performing the proxy pass..."); - let _ = crate::proxy::copy_bidirectional::copy_bidirectional_client_compute( + let _ = crate::pglb::copy_bidirectional::copy_bidirectional_client_compute( &mut client, &mut compute, ) diff --git a/proxy/src/proxy/mod.rs b/proxy/src/proxy/mod.rs index 477baff1c9..0e138cc0c7 100644 --- a/proxy/src/proxy/mod.rs +++ b/proxy/src/proxy/mod.rs @@ -1,15 +1,10 @@ #[cfg(test)] mod tests; -pub(crate) mod connect_compute; -mod copy_bidirectional; -pub(crate) mod handshake; -pub(crate) mod passthrough; pub(crate) mod retry; pub(crate) mod wake_compute; use std::sync::Arc; -pub use copy_bidirectional::{ErrorSource, copy_bidirectional_client_compute}; use futures::FutureExt; use itertools::Itertools; use once_cell::sync::OnceCell; @@ -21,16 +16,17 @@ use tokio::io::{AsyncRead, AsyncWrite}; use tokio_util::sync::CancellationToken; use tracing::{Instrument, debug, error, info, warn}; -use self::connect_compute::{TcpMechanism, connect_to_compute}; -use self::passthrough::ProxyPassthrough; use crate::cancellation::{self, CancellationHandler}; use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig}; use crate::context::RequestContext; use crate::error::{ReportableError, UserFacingError}; use crate::metrics::{Metrics, NumClientConnectionsGuard}; +use crate::pglb::connect_compute::{TcpMechanism, connect_to_compute}; +pub use crate::pglb::copy_bidirectional::{ErrorSource, copy_bidirectional_client_compute}; +use crate::pglb::handshake::{HandshakeData, HandshakeError, handshake}; +use crate::pglb::passthrough::ProxyPassthrough; use crate::pqproto::{BeMessage, CancelKeyData, StartupMessageParams}; use crate::protocol2::{ConnectHeader, ConnectionInfo, ConnectionInfoExtra, read_proxy_protocol}; -use crate::proxy::handshake::{HandshakeData, handshake}; use crate::rate_limiter::EndpointRateLimiter; use crate::stream::{PqStream, Stream}; use crate::types::EndpointCacheKey; @@ -242,7 +238,7 @@ pub(crate) enum ClientRequestError { #[error("{0}")] Cancellation(#[from] cancellation::CancelError), #[error("{0}")] - Handshake(#[from] handshake::HandshakeError), + Handshake(#[from] HandshakeError), #[error("{0}")] HandshakeTimeout(#[from] tokio::time::error::Elapsed), #[error("{0}")] diff --git a/proxy/src/proxy/tests/mod.rs b/proxy/src/proxy/tests/mod.rs index 117c42e19c..e5db0013a7 100644 --- a/proxy/src/proxy/tests/mod.rs +++ b/proxy/src/proxy/tests/mod.rs @@ -17,7 +17,6 @@ use rustls::pki_types; use tokio::io::DuplexStream; use tracing_test::traced_test; -use super::connect_compute::ConnectMechanism; use super::retry::CouldRetry; use super::*; use crate::auth::backend::{ @@ -28,6 +27,7 @@ use crate::control_plane::client::{ControlPlaneClient, TestControlPlaneClient}; use crate::control_plane::messages::{ControlPlaneErrorMessage, Details, MetricsAuxInfo, Status}; use crate::control_plane::{self, CachedNodeInfo, NodeInfo, NodeInfoCache}; use crate::error::ErrorKind; +use crate::pglb::connect_compute::ConnectMechanism; use crate::tls::client_config::compute_client_config_with_certs; use crate::tls::postgres_rustls::MakeRustlsConnect; use crate::tls::server_config::CertResolver; diff --git a/proxy/src/proxy/wake_compute.rs b/proxy/src/proxy/wake_compute.rs index 9d8915e24a..06c2da58db 100644 --- a/proxy/src/proxy/wake_compute.rs +++ b/proxy/src/proxy/wake_compute.rs @@ -1,6 +1,5 @@ use tracing::{error, info}; -use super::connect_compute::ComputeConnectBackend; use crate::config::RetryConfig; use crate::context::RequestContext; use crate::control_plane::CachedNodeInfo; @@ -9,6 +8,7 @@ use crate::error::ReportableError; use crate::metrics::{ ConnectOutcome, ConnectionFailuresBreakdownGroup, Metrics, RetriesMetricGroup, RetryType, }; +use crate::pglb::connect_compute::ComputeConnectBackend; use crate::proxy::retry::{retry_after, should_retry}; // Use macro to retain original callsite. diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index bf640c05e9..748e0ce6f2 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -35,7 +35,7 @@ use crate::control_plane::errors::{GetAuthInfoError, WakeComputeError}; use crate::control_plane::locks::ApiLocks; use crate::error::{ErrorKind, ReportableError, UserFacingError}; use crate::intern::EndpointIdInt; -use crate::proxy::connect_compute::ConnectMechanism; +use crate::pglb::connect_compute::ConnectMechanism; use crate::proxy::retry::{CouldRetry, ShouldRetryWakeCompute}; use crate::rate_limiter::EndpointRateLimiter; use crate::types::{EndpointId, Host, LOCAL_PROXY_SUFFIX}; @@ -182,7 +182,7 @@ impl PoolingBackend { tracing::Span::current().record("conn_id", display(conn_id)); info!(%conn_id, "pool: opening a new connection '{conn_info}'"); let backend = self.auth_backend.as_ref().map(|()| keys); - crate::proxy::connect_compute::connect_to_compute( + crate::pglb::connect_compute::connect_to_compute( ctx, &TokioMechanism { conn_id, @@ -226,7 +226,7 @@ impl PoolingBackend { }, keys: crate::auth::backend::ComputeCredentialKeys::None, }); - crate::proxy::connect_compute::connect_to_compute( + crate::pglb::connect_compute::connect_to_compute( ctx, &HyperMechanism { conn_id, From 6123fe2d5e0e860fc3da74422d094927ecd6191e Mon Sep 17 00:00:00 2001 From: Konstantin Knizhnik Date: Thu, 5 Jun 2025 14:23:39 +0300 Subject: [PATCH 37/43] Add query execution time histogram (#10050) ## Problem It will be useful to understand what kind of queries our clients are executed. And one of the most important characteristic of query is query execution time - at least it allows to distinguish OLAP and OLTP queries. Also monitoring query execution time can help to detect problem with performance (assuming that workload is more or less stable). ## Summary of changes Add query execution time histogram. --------- Co-authored-by: Konstantin Knizhnik --- pgxn/neon/neon.c | 76 ++++++++++++ pgxn/neon/neon_perf_counters.c | 122 ++++++++++++++++---- pgxn/neon/neon_perf_counters.h | 28 +++++ test_runner/regress/test_compute_metrics.py | 5 + 4 files changed, 208 insertions(+), 23 deletions(-) diff --git a/pgxn/neon/neon.c b/pgxn/neon/neon.c index a6a7021756..5b4ced7cf0 100644 --- a/pgxn/neon/neon.c +++ b/pgxn/neon/neon.c @@ -16,6 +16,7 @@ #if PG_MAJORVERSION_NUM >= 15 #include "access/xlogrecovery.h" #endif +#include "executor/instrument.h" #include "replication/logical.h" #include "replication/logicallauncher.h" #include "replication/slot.h" @@ -33,6 +34,7 @@ #include "file_cache.h" #include "neon.h" #include "neon_lwlsncache.h" +#include "neon_perf_counters.h" #include "control_plane_connector.h" #include "logical_replication_monitor.h" #include "unstable_extensions.h" @@ -46,6 +48,13 @@ void _PG_init(void); static int running_xacts_overflow_policy; +static bool monitor_query_exec_time = false; + +static ExecutorStart_hook_type prev_ExecutorStart = NULL; +static ExecutorEnd_hook_type prev_ExecutorEnd = NULL; + +static void neon_ExecutorStart(QueryDesc *queryDesc, int eflags); +static void neon_ExecutorEnd(QueryDesc *queryDesc); #if PG_MAJORVERSION_NUM >= 16 static shmem_startup_hook_type prev_shmem_startup_hook; @@ -470,6 +479,16 @@ _PG_init(void) 0, NULL, NULL, NULL); + DefineCustomBoolVariable( + "neon.monitor_query_exec_time", + "Collect infortmation about query execution time", + NULL, + &monitor_query_exec_time, + false, + PGC_USERSET, + 0, + NULL, NULL, NULL); + DefineCustomBoolVariable( "neon.allow_replica_misconfig", "Allow replica startup when some critical GUCs have smaller value than on primary node", @@ -508,6 +527,11 @@ _PG_init(void) EmitWarningsOnPlaceholders("neon"); ReportSearchPath(); + + prev_ExecutorStart = ExecutorStart_hook; + ExecutorStart_hook = neon_ExecutorStart; + prev_ExecutorEnd = ExecutorEnd_hook; + ExecutorEnd_hook = neon_ExecutorEnd; } PG_FUNCTION_INFO_V1(pg_cluster_size); @@ -581,3 +605,55 @@ neon_shmem_startup_hook(void) #endif } #endif + +/* + * ExecutorStart hook: start up tracking if needed + */ +static void +neon_ExecutorStart(QueryDesc *queryDesc, int eflags) +{ + if (prev_ExecutorStart) + prev_ExecutorStart(queryDesc, eflags); + else + standard_ExecutorStart(queryDesc, eflags); + + if (monitor_query_exec_time) + { + /* + * Set up to track total elapsed time in ExecutorRun. Make sure the + * space is allocated in the per-query context so it will go away at + * ExecutorEnd. + */ + if (queryDesc->totaltime == NULL) + { + MemoryContext oldcxt; + + oldcxt = MemoryContextSwitchTo(queryDesc->estate->es_query_cxt); + queryDesc->totaltime = InstrAlloc(1, INSTRUMENT_TIMER, false); + MemoryContextSwitchTo(oldcxt); + } + } +} + +/* + * ExecutorEnd hook: store results if needed + */ +static void +neon_ExecutorEnd(QueryDesc *queryDesc) +{ + if (monitor_query_exec_time && queryDesc->totaltime) + { + /* + * Make sure stats accumulation is done. (Note: it's okay if several + * levels of hook all do this.) + */ + InstrEndLoop(queryDesc->totaltime); + + inc_query_time(queryDesc->totaltime->total*1000000); /* convert to usec */ + } + + if (prev_ExecutorEnd) + prev_ExecutorEnd(queryDesc); + else + standard_ExecutorEnd(queryDesc); +} diff --git a/pgxn/neon/neon_perf_counters.c b/pgxn/neon/neon_perf_counters.c index c77d99d636..d0a3d15108 100644 --- a/pgxn/neon/neon_perf_counters.c +++ b/pgxn/neon/neon_perf_counters.c @@ -71,6 +71,27 @@ inc_iohist(IOHistogram hist, uint64 latency_us) hist->wait_us_count++; } +static inline void +inc_qthist(QTHistogram hist, uint64 elapsed_us) +{ + int lo = 0; + int hi = NUM_QT_BUCKETS - 1; + + /* Find the right bucket with binary search */ + while (lo < hi) + { + int mid = (lo + hi) / 2; + + if (elapsed_us < qt_bucket_thresholds[mid]) + hi = mid; + else + lo = mid + 1; + } + hist->elapsed_us_bucket[lo]++; + hist->elapsed_us_sum += elapsed_us; + hist->elapsed_us_count++; +} + /* * Count a GetPage wait operation. */ @@ -98,6 +119,13 @@ inc_page_cache_write_wait(uint64 latency) inc_iohist(&MyNeonCounters->file_cache_write_hist, latency); } + +void +inc_query_time(uint64 elapsed) +{ + inc_qthist(&MyNeonCounters->query_time_hist, elapsed); +} + /* * Support functions for the views, neon_backend_perf_counters and * neon_perf_counters. @@ -112,11 +140,11 @@ typedef struct } metric_t; static int -histogram_to_metrics(IOHistogram histogram, - metric_t *metrics, - const char *count, - const char *sum, - const char *bucket) +io_histogram_to_metrics(IOHistogram histogram, + metric_t *metrics, + const char *count, + const char *sum, + const char *bucket) { int i = 0; uint64 bucket_accum = 0; @@ -145,10 +173,44 @@ histogram_to_metrics(IOHistogram histogram, return i; } +static int +qt_histogram_to_metrics(QTHistogram histogram, + metric_t *metrics, + const char *count, + const char *sum, + const char *bucket) +{ + int i = 0; + uint64 bucket_accum = 0; + + metrics[i].name = count; + metrics[i].is_bucket = false; + metrics[i].value = (double) histogram->elapsed_us_count; + i++; + metrics[i].name = sum; + metrics[i].is_bucket = false; + metrics[i].value = (double) histogram->elapsed_us_sum / 1000000.0; + i++; + for (int bucketno = 0; bucketno < NUM_QT_BUCKETS; bucketno++) + { + uint64 threshold = qt_bucket_thresholds[bucketno]; + + bucket_accum += histogram->elapsed_us_bucket[bucketno]; + + metrics[i].name = bucket; + metrics[i].is_bucket = true; + metrics[i].bucket_le = (threshold == UINT64_MAX) ? INFINITY : ((double) threshold) / 1000000.0; + metrics[i].value = (double) bucket_accum; + i++; + } + + return i; +} + static metric_t * neon_perf_counters_to_metrics(neon_per_backend_counters *counters) { -#define NUM_METRICS ((2 + NUM_IO_WAIT_BUCKETS) * 3 + 12) +#define NUM_METRICS ((2 + NUM_IO_WAIT_BUCKETS) * 3 + (2 + NUM_QT_BUCKETS) + 12) metric_t *metrics = palloc((NUM_METRICS + 1) * sizeof(metric_t)); int i = 0; @@ -159,10 +221,10 @@ neon_perf_counters_to_metrics(neon_per_backend_counters *counters) i++; \ } while (false) - i += histogram_to_metrics(&counters->getpage_hist, &metrics[i], - "getpage_wait_seconds_count", - "getpage_wait_seconds_sum", - "getpage_wait_seconds_bucket"); + i += io_histogram_to_metrics(&counters->getpage_hist, &metrics[i], + "getpage_wait_seconds_count", + "getpage_wait_seconds_sum", + "getpage_wait_seconds_bucket"); APPEND_METRIC(getpage_prefetch_requests_total); APPEND_METRIC(getpage_sync_requests_total); @@ -178,14 +240,19 @@ neon_perf_counters_to_metrics(neon_per_backend_counters *counters) APPEND_METRIC(file_cache_hits_total); - i += histogram_to_metrics(&counters->file_cache_read_hist, &metrics[i], - "file_cache_read_wait_seconds_count", - "file_cache_read_wait_seconds_sum", - "file_cache_read_wait_seconds_bucket"); - i += histogram_to_metrics(&counters->file_cache_write_hist, &metrics[i], - "file_cache_write_wait_seconds_count", - "file_cache_write_wait_seconds_sum", - "file_cache_write_wait_seconds_bucket"); + i += io_histogram_to_metrics(&counters->file_cache_read_hist, &metrics[i], + "file_cache_read_wait_seconds_count", + "file_cache_read_wait_seconds_sum", + "file_cache_read_wait_seconds_bucket"); + i += io_histogram_to_metrics(&counters->file_cache_write_hist, &metrics[i], + "file_cache_write_wait_seconds_count", + "file_cache_write_wait_seconds_sum", + "file_cache_write_wait_seconds_bucket"); + + i += qt_histogram_to_metrics(&counters->query_time_hist, &metrics[i], + "query_time_seconds_count", + "query_time_seconds_sum", + "query_time_seconds_bucket"); Assert(i == NUM_METRICS); @@ -257,7 +324,7 @@ neon_get_backend_perf_counters(PG_FUNCTION_ARGS) } static inline void -histogram_merge_into(IOHistogram into, IOHistogram from) +io_histogram_merge_into(IOHistogram into, IOHistogram from) { into->wait_us_count += from->wait_us_count; into->wait_us_sum += from->wait_us_sum; @@ -265,6 +332,15 @@ histogram_merge_into(IOHistogram into, IOHistogram from) into->wait_us_bucket[bucketno] += from->wait_us_bucket[bucketno]; } +static inline void +qt_histogram_merge_into(QTHistogram into, QTHistogram from) +{ + into->elapsed_us_count += from->elapsed_us_count; + into->elapsed_us_sum += from->elapsed_us_sum; + for (int bucketno = 0; bucketno < NUM_QT_BUCKETS; bucketno++) + into->elapsed_us_bucket[bucketno] += from->elapsed_us_bucket[bucketno]; +} + PG_FUNCTION_INFO_V1(neon_get_perf_counters); Datum neon_get_perf_counters(PG_FUNCTION_ARGS) @@ -283,7 +359,7 @@ neon_get_perf_counters(PG_FUNCTION_ARGS) { neon_per_backend_counters *counters = &neon_per_backend_counters_shared[procno]; - histogram_merge_into(&totals.getpage_hist, &counters->getpage_hist); + io_histogram_merge_into(&totals.getpage_hist, &counters->getpage_hist); totals.getpage_prefetch_requests_total += counters->getpage_prefetch_requests_total; totals.getpage_sync_requests_total += counters->getpage_sync_requests_total; totals.getpage_prefetch_misses_total += counters->getpage_prefetch_misses_total; @@ -294,13 +370,13 @@ neon_get_perf_counters(PG_FUNCTION_ARGS) totals.pageserver_open_requests += counters->pageserver_open_requests; totals.getpage_prefetches_buffered += counters->getpage_prefetches_buffered; totals.file_cache_hits_total += counters->file_cache_hits_total; - histogram_merge_into(&totals.file_cache_read_hist, &counters->file_cache_read_hist); - histogram_merge_into(&totals.file_cache_write_hist, &counters->file_cache_write_hist); - totals.compute_getpage_stuck_requests_total += counters->compute_getpage_stuck_requests_total; totals.compute_getpage_max_inflight_stuck_time_ms = Max( totals.compute_getpage_max_inflight_stuck_time_ms, counters->compute_getpage_max_inflight_stuck_time_ms); + io_histogram_merge_into(&totals.file_cache_read_hist, &counters->file_cache_read_hist); + io_histogram_merge_into(&totals.file_cache_write_hist, &counters->file_cache_write_hist); + qt_histogram_merge_into(&totals.query_time_hist, &counters->query_time_hist); } metrics = neon_perf_counters_to_metrics(&totals); diff --git a/pgxn/neon/neon_perf_counters.h b/pgxn/neon/neon_perf_counters.h index 10cf094d4a..4b611b0636 100644 --- a/pgxn/neon/neon_perf_counters.h +++ b/pgxn/neon/neon_perf_counters.h @@ -36,6 +36,28 @@ typedef struct IOHistogramData typedef IOHistogramData *IOHistogram; +static const uint64 qt_bucket_thresholds[] = { + 2, 3, 6, 10, /* 0 us - 10 us */ + 20, 30, 60, 100, /* 10 us - 100 us */ + 200, 300, 600, 1000, /* 100 us - 1 ms */ + 2000, 3000, 6000, 10000, /* 1 ms - 10 ms */ + 20000, 30000, 60000, 100000, /* 10 ms - 100 ms */ + 200000, 300000, 600000, 1000000, /* 100 ms - 1 s */ + 2000000, 3000000, 6000000, 10000000, /* 1 s - 10 s */ + 20000000, 30000000, 60000000, 100000000, /* 10 s - 100 s */ + UINT64_MAX, +}; +#define NUM_QT_BUCKETS (lengthof(qt_bucket_thresholds)) + +typedef struct QTHistogramData +{ + uint64 elapsed_us_count; + uint64 elapsed_us_sum; + uint64 elapsed_us_bucket[NUM_QT_BUCKETS]; +} QTHistogramData; + +typedef QTHistogramData *QTHistogram; + typedef struct { /* @@ -127,6 +149,11 @@ typedef struct /* LFC I/O time buckets */ IOHistogramData file_cache_read_hist; IOHistogramData file_cache_write_hist; + + /* + * Histogram of query execution time. + */ + QTHistogramData query_time_hist; } neon_per_backend_counters; /* Pointer to the shared memory array of neon_per_backend_counters structs */ @@ -149,6 +176,7 @@ extern neon_per_backend_counters *neon_per_backend_counters_shared; extern void inc_getpage_wait(uint64 latency); extern void inc_page_cache_read_wait(uint64 latency); extern void inc_page_cache_write_wait(uint64 latency); +extern void inc_query_time(uint64 elapsed); extern Size NeonPerfCountersShmemSize(void); extern void NeonPerfCountersShmemInit(void); diff --git a/test_runner/regress/test_compute_metrics.py b/test_runner/regress/test_compute_metrics.py index 2cb2ee7b58..c751a3e7cc 100644 --- a/test_runner/regress/test_compute_metrics.py +++ b/test_runner/regress/test_compute_metrics.py @@ -466,8 +466,13 @@ def test_perf_counters(neon_simple_env: NeonEnv): # # 1.5 is the minimum version to contain these views. cur.execute("CREATE EXTENSION neon VERSION '1.5'") + cur.execute("set neon.monitor_query_exec_time = on") cur.execute("SELECT * FROM neon_perf_counters") cur.execute("SELECT * FROM neon_backend_perf_counters") + cur.execute( + "select value from neon_backend_perf_counters where metric='query_time_seconds_count' and pid=pg_backend_pid()" + ) + assert cur.fetchall()[0][0] == 2 def collect_metric( From 9c6c7802012cd4f5d1ed24aeaeb5e60f28075922 Mon Sep 17 00:00:00 2001 From: Konstantin Knizhnik Date: Thu, 5 Jun 2025 14:27:14 +0300 Subject: [PATCH 38/43] Replica promote (#12090) ## Problem This PR is part of larger computes support activity: https://www.notion.so/neondatabase/Larger-computes-114f189e00478080ba01e8651ab7da90 Epic: https://github.com/neondatabase/cloud/issues/19010 In case of planned node restart, we are going to 1. create new read-only replica 2. capture LFC state at primary 3. use this state to prewarm replica 4. stop old primary 5. promote replica to primary Steps 1-3 are currently implemented and support from compute side. This PR provides compute level implementation of replica promotion. Support replica promotion ## Summary of changes Right now replica promotion is done in three steps: 1. Set safekeepers list (now it is empty for replica) 2. Call `pg_promote()` top promote replica 3. Update endpoint setting to that it ids not more treated as replica. May be all this three steps should be done by some function in compute_ctl. But right now this logic is only implement5ed in test. Postgres submodules PRs: https://github.com/neondatabase/postgres/pull/648 https://github.com/neondatabase/postgres/pull/649 https://github.com/neondatabase/postgres/pull/650 https://github.com/neondatabase/postgres/pull/651 --------- Co-authored-by: Matthias van de Meent Co-authored-by: Konstantin Knizhnik --- libs/walproposer/src/api_bindings.rs | 1 + pgxn/neon/neon_pgversioncompat.c | 8 ++ pgxn/neon/neon_pgversioncompat.h | 1 + pgxn/neon/neon_walreader.c | 15 ++- pgxn/neon/neon_walreader.h | 3 +- pgxn/neon/walproposer.c | 3 +- pgxn/neon/walproposer.h | 4 + pgxn/neon/walproposer_pg.c | 26 ++-- pgxn/neon/walsender_hooks.c | 4 +- test_runner/fixtures/neon_fixtures.py | 4 +- test_runner/regress/test_hot_standby.py | 9 +- test_runner/regress/test_replica_promotes.py | 133 +++++++++++++++++++ vendor/postgres-v14 | 2 +- vendor/postgres-v15 | 2 +- vendor/postgres-v16 | 2 +- vendor/postgres-v17 | 2 +- vendor/revisions.json | 8 +- 17 files changed, 199 insertions(+), 28 deletions(-) create mode 100644 test_runner/regress/test_replica_promotes.py diff --git a/libs/walproposer/src/api_bindings.rs b/libs/walproposer/src/api_bindings.rs index d660602149..4d6cbae9a9 100644 --- a/libs/walproposer/src/api_bindings.rs +++ b/libs/walproposer/src/api_bindings.rs @@ -439,6 +439,7 @@ pub fn empty_shmem() -> crate::bindings::WalproposerShmemState { currentClusterSize: crate::bindings::pg_atomic_uint64 { value: 0 }, shard_ps_feedback: [empty_feedback; 128], num_shards: 0, + replica_promote: false, min_ps_feedback: empty_feedback, } } diff --git a/pgxn/neon/neon_pgversioncompat.c b/pgxn/neon/neon_pgversioncompat.c index 7c404fb5a9..6f57b618da 100644 --- a/pgxn/neon/neon_pgversioncompat.c +++ b/pgxn/neon/neon_pgversioncompat.c @@ -5,6 +5,7 @@ #include "funcapi.h" #include "miscadmin.h" +#include "access/xlog.h" #include "utils/tuplestore.h" #include "neon_pgversioncompat.h" @@ -41,5 +42,12 @@ InitMaterializedSRF(FunctionCallInfo fcinfo, bits32 flags) rsinfo->setDesc = stored_tupdesc; MemoryContextSwitchTo(old_context); } + +TimeLineID GetWALInsertionTimeLine(void) +{ + return ThisTimeLineID + 1; +} + + #endif diff --git a/pgxn/neon/neon_pgversioncompat.h b/pgxn/neon/neon_pgversioncompat.h index bf91a02b45..787bd552f8 100644 --- a/pgxn/neon/neon_pgversioncompat.h +++ b/pgxn/neon/neon_pgversioncompat.h @@ -162,6 +162,7 @@ InitBufferTag(BufferTag *tag, const RelFileNode *rnode, #if PG_MAJORVERSION_NUM < 15 extern void InitMaterializedSRF(FunctionCallInfo fcinfo, bits32 flags); +extern TimeLineID GetWALInsertionTimeLine(void); #endif #endif /* NEON_PGVERSIONCOMPAT_H */ diff --git a/pgxn/neon/neon_walreader.c b/pgxn/neon/neon_walreader.c index d5e3a38dbb..0a1f6d9c72 100644 --- a/pgxn/neon/neon_walreader.c +++ b/pgxn/neon/neon_walreader.c @@ -69,6 +69,7 @@ struct NeonWALReader WALSegmentContext segcxt; WALOpenSegment seg; int wre_errno; + TimeLineID local_active_tlid; /* Explains failure to read, static for simplicity. */ char err_msg[NEON_WALREADER_ERR_MSG_LEN]; @@ -106,7 +107,7 @@ struct NeonWALReader /* palloc and initialize NeonWALReader */ NeonWALReader * -NeonWALReaderAllocate(int wal_segment_size, XLogRecPtr available_lsn, char *log_prefix) +NeonWALReaderAllocate(int wal_segment_size, XLogRecPtr available_lsn, char *log_prefix, TimeLineID tlid) { NeonWALReader *reader; @@ -118,6 +119,7 @@ NeonWALReaderAllocate(int wal_segment_size, XLogRecPtr available_lsn, char *log_ MemoryContextAllocZero(TopMemoryContext, sizeof(NeonWALReader)); reader->available_lsn = available_lsn; + reader->local_active_tlid = tlid; reader->seg.ws_file = -1; reader->seg.ws_segno = 0; reader->seg.ws_tli = 0; @@ -577,6 +579,17 @@ NeonWALReaderIsRemConnEstablished(NeonWALReader *state) return state->rem_state == RS_ESTABLISHED; } +/* + * Whether remote connection is established. Once this is done, until successful + * local read or error socket is stable and user can update socket events + * instead of readding it each time. + */ +TimeLineID +NeonWALReaderLocalActiveTimeLineID(NeonWALReader *state) +{ + return state->local_active_tlid; +} + /* * Returns events user should wait on connection socket or 0 if remote * connection is not active. diff --git a/pgxn/neon/neon_walreader.h b/pgxn/neon/neon_walreader.h index 3e41825069..722bc10537 100644 --- a/pgxn/neon/neon_walreader.h +++ b/pgxn/neon/neon_walreader.h @@ -19,9 +19,10 @@ typedef enum NEON_WALREAD_ERROR, } NeonWALReadResult; -extern NeonWALReader *NeonWALReaderAllocate(int wal_segment_size, XLogRecPtr available_lsn, char *log_prefix); +extern NeonWALReader *NeonWALReaderAllocate(int wal_segment_size, XLogRecPtr available_lsn, char *log_prefix, TimeLineID tlid); extern void NeonWALReaderFree(NeonWALReader *state); extern void NeonWALReaderResetRemote(NeonWALReader *state); +extern TimeLineID NeonWALReaderLocalActiveTimeLineID(NeonWALReader *state); extern NeonWALReadResult NeonWALRead(NeonWALReader *state, char *buf, XLogRecPtr startptr, Size count, TimeLineID tli); extern pgsocket NeonWALReaderSocket(NeonWALReader *state); extern uint32 NeonWALReaderEvents(NeonWALReader *state); diff --git a/pgxn/neon/walproposer.c b/pgxn/neon/walproposer.c index c63edb1398..91d39345e2 100644 --- a/pgxn/neon/walproposer.c +++ b/pgxn/neon/walproposer.c @@ -98,6 +98,7 @@ WalProposerCreate(WalProposerConfig *config, walproposer_api api) wp = palloc0(sizeof(WalProposer)); wp->config = config; wp->api = api; + wp->localTimeLineID = config->pgTimeline; wp->state = WPS_COLLECTING_TERMS; wp->mconf.generation = INVALID_GENERATION; wp->mconf.members.len = 0; @@ -1384,7 +1385,7 @@ ProcessPropStartPos(WalProposer *wp) * we must bail out, as clog and other non rel data is inconsistent. */ walprop_shared = wp->api.get_shmem_state(wp); - if (!wp->config->syncSafekeepers) + if (!wp->config->syncSafekeepers && !walprop_shared->replica_promote) { /* * Basebackup LSN always points to the beginning of the record (not diff --git a/pgxn/neon/walproposer.h b/pgxn/neon/walproposer.h index cca20e746b..08087e5a55 100644 --- a/pgxn/neon/walproposer.h +++ b/pgxn/neon/walproposer.h @@ -391,6 +391,7 @@ typedef struct WalproposerShmemState /* last feedback from each shard */ PageserverFeedback shard_ps_feedback[MAX_SHARDS]; int num_shards; + bool replica_promote; /* aggregated feedback with min LSNs across shards */ PageserverFeedback min_ps_feedback; @@ -806,6 +807,9 @@ typedef struct WalProposer /* Safekeepers walproposer is connecting to. */ Safekeeper safekeeper[MAX_SAFEKEEPERS]; + /* Current local TimeLineId in use */ + TimeLineID localTimeLineID; + /* WAL has been generated up to this point */ XLogRecPtr availableLsn; diff --git a/pgxn/neon/walproposer_pg.c b/pgxn/neon/walproposer_pg.c index c90702a282..3d6a92ad79 100644 --- a/pgxn/neon/walproposer_pg.c +++ b/pgxn/neon/walproposer_pg.c @@ -35,6 +35,7 @@ #include "storage/proc.h" #include "storage/ipc.h" #include "storage/lwlock.h" +#include "storage/pg_shmem.h" #include "storage/shmem.h" #include "storage/spin.h" #include "tcop/tcopprot.h" @@ -159,12 +160,19 @@ WalProposerMain(Datum main_arg) { WalProposer *wp; + if (*wal_acceptors_list == '\0') + { + wpg_log(WARNING, "Safekeepers list is empty"); + return; + } + init_walprop_config(false); walprop_pg_init_bgworker(); am_walproposer = true; walprop_pg_load_libpqwalreceiver(); wp = WalProposerCreate(&walprop_config, walprop_pg); + wp->localTimeLineID = GetWALInsertionTimeLine(); wp->last_reconnect_attempt = walprop_pg_get_current_timestamp(wp); walprop_pg_init_walsender(); @@ -350,6 +358,9 @@ assign_neon_safekeepers(const char *newval, void *extra) char *newval_copy; char *oldval; + if (newval && *newval != '\0' && UsedShmemSegAddr && walprop_shared && RecoveryInProgress()) + walprop_shared->replica_promote = true; + if (!am_walproposer) return; @@ -540,16 +551,15 @@ BackpressureThrottlingTime(void) /* * Register a background worker proposing WAL to wal acceptors. + * We start walproposer bgworker even for replicas in order to support possible replica promotion. + * When pg_promote() function is called, then walproposer bgworker registered with BgWorkerStart_RecoveryFinished + * is automatically launched when promotion is completed. */ static void walprop_register_bgworker(void) { BackgroundWorker bgw; - /* If no wal acceptors are specified, don't start the background worker. */ - if (*wal_acceptors_list == '\0') - return; - memset(&bgw, 0, sizeof(bgw)); bgw.bgw_flags = BGWORKER_SHMEM_ACCESS; bgw.bgw_start_time = BgWorkerStart_RecoveryFinished; @@ -1326,9 +1336,7 @@ StartProposerReplication(WalProposer *wp, StartReplicationCmd *cmd) #if PG_VERSION_NUM < 150000 if (ThisTimeLineID == 0) - ereport(ERROR, - (errcode(ERRCODE_OBJECT_NOT_IN_PREREQUISITE_STATE), - errmsg("IDENTIFY_SYSTEM has not been run before START_REPLICATION"))); + ThisTimeLineID = 1; #endif /* @@ -1542,7 +1550,7 @@ walprop_pg_wal_reader_allocate(Safekeeper *sk) snprintf(log_prefix, sizeof(log_prefix), WP_LOG_PREFIX "sk %s:%s nwr: ", sk->host, sk->port); Assert(!sk->xlogreader); - sk->xlogreader = NeonWALReaderAllocate(wal_segment_size, sk->wp->propTermStartLsn, log_prefix); + sk->xlogreader = NeonWALReaderAllocate(wal_segment_size, sk->wp->propTermStartLsn, log_prefix, sk->wp->localTimeLineID); if (sk->xlogreader == NULL) wpg_log(FATAL, "failed to allocate xlog reader"); } @@ -1556,7 +1564,7 @@ walprop_pg_wal_read(Safekeeper *sk, char *buf, XLogRecPtr startptr, Size count, buf, startptr, count, - walprop_pg_get_timeline_id()); + sk->wp->localTimeLineID); if (res == NEON_WALREAD_SUCCESS) { diff --git a/pgxn/neon/walsender_hooks.c b/pgxn/neon/walsender_hooks.c index 81198d6c8d..534bf1c19b 100644 --- a/pgxn/neon/walsender_hooks.c +++ b/pgxn/neon/walsender_hooks.c @@ -111,7 +111,7 @@ NeonWALPageRead( readBuf, targetPagePtr, count, - walprop_pg_get_timeline_id()); + NeonWALReaderLocalActiveTimeLineID(wal_reader)); if (res == NEON_WALREAD_SUCCESS) { @@ -202,7 +202,7 @@ NeonOnDemandXLogReaderRoutines(XLogReaderRoutine *xlr) { elog(ERROR, "unable to start walsender when basebackupLsn is 0"); } - wal_reader = NeonWALReaderAllocate(wal_segment_size, basebackupLsn, "[walsender] "); + wal_reader = NeonWALReaderAllocate(wal_segment_size, basebackupLsn, "[walsender] ", 1); } xlr->page_read = NeonWALPageRead; xlr->segment_open = NeonWALReadSegmentOpen; diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py index 19d12da5e3..f9337bed89 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -4711,7 +4711,7 @@ class EndpointFactory: origin: Endpoint, endpoint_id: str | None = None, config_lines: list[str] | None = None, - ): + ) -> Endpoint: branch_name = origin.branch_name assert origin in self.endpoints assert branch_name is not None @@ -4730,7 +4730,7 @@ class EndpointFactory: origin: Endpoint, endpoint_id: str | None = None, config_lines: list[str] | None = None, - ): + ) -> Endpoint: branch_name = origin.branch_name assert origin in self.endpoints assert branch_name is not None diff --git a/test_runner/regress/test_hot_standby.py b/test_runner/regress/test_hot_standby.py index 4044f25b37..1ff61ce8dc 100644 --- a/test_runner/regress/test_hot_standby.py +++ b/test_runner/regress/test_hot_standby.py @@ -74,8 +74,9 @@ def test_hot_standby(neon_simple_env: NeonEnv): for query in queries: with s_con.cursor() as secondary_cursor: secondary_cursor.execute(query) - response = secondary_cursor.fetchone() - assert response is not None + res = secondary_cursor.fetchone() + assert res is not None + response = res assert response == responses[query] # Check for corrupted WAL messages which might otherwise go unnoticed if @@ -164,7 +165,7 @@ def test_hot_standby_gc(neon_env_builder: NeonEnvBuilder, pause_apply: bool): s_cur.execute("SELECT COUNT(*) FROM test") res = s_cur.fetchone() - assert res[0] == 10000 + assert res == (10000,) # Clear the cache in the standby, so that when we # re-execute the query, it will make GetPage @@ -195,7 +196,7 @@ def test_hot_standby_gc(neon_env_builder: NeonEnvBuilder, pause_apply: bool): s_cur.execute("SELECT COUNT(*) FROM test") log_replica_lag(primary, secondary) res = s_cur.fetchone() - assert res[0] == 10000 + assert res == (10000,) def run_pgbench(connstr: str, pg_bin: PgBin): diff --git a/test_runner/regress/test_replica_promotes.py b/test_runner/regress/test_replica_promotes.py new file mode 100644 index 0000000000..e378d37635 --- /dev/null +++ b/test_runner/regress/test_replica_promotes.py @@ -0,0 +1,133 @@ +""" +File with secondary->primary promotion testing. + +This far, only contains a test that we don't break and that the data is persisted. +""" + +import psycopg2 +from fixtures.log_helper import log +from fixtures.neon_fixtures import Endpoint, NeonEnv, wait_replica_caughtup +from fixtures.pg_version import PgVersion +from pytest import raises + + +def test_replica_promotes(neon_simple_env: NeonEnv, pg_version: PgVersion): + """ + Test that a replica safely promotes, and can commit data updates which + show up when the primary boots up after the promoted secondary endpoint + shut down. + """ + + # Initialize the primary, a test table, and a helper function to create lots + # of subtransactions. + env: NeonEnv = neon_simple_env + primary: Endpoint = env.endpoints.create_start(branch_name="main", endpoint_id="primary") + secondary: Endpoint = env.endpoints.new_replica_start(origin=primary, endpoint_id="secondary") + + with primary.connect() as primary_conn: + primary_cur = primary_conn.cursor() + primary_cur.execute( + "create table t(pk bigint GENERATED ALWAYS AS IDENTITY, payload integer)" + ) + primary_cur.execute("INSERT INTO t(payload) SELECT generate_series(1, 100)") + primary_cur.execute( + """ + SELECT pg_current_wal_insert_lsn(), + pg_current_wal_lsn(), + pg_current_wal_flush_lsn() + """ + ) + log.info(f"Primary: Current LSN after workload is {primary_cur.fetchone()}") + primary_cur.execute("show neon.safekeepers") + safekeepers = primary_cur.fetchall()[0][0] + + wait_replica_caughtup(primary, secondary) + + with secondary.connect() as secondary_conn: + secondary_cur = secondary_conn.cursor() + secondary_cur.execute("select count(*) from t") + + assert secondary_cur.fetchone() == (100,) + + with raises(psycopg2.Error): + secondary_cur.execute("INSERT INTO t (payload) SELECT generate_series(101, 200)") + secondary_conn.commit() + + secondary_conn.rollback() + secondary_cur.execute("select count(*) from t") + assert secondary_cur.fetchone() == (100,) + + primary.stop_and_destroy(mode="immediate") + + # Reconnect to the secondary to make sure we get a read-write connection + promo_conn = secondary.connect() + promo_cur = promo_conn.cursor() + promo_cur.execute(f"alter system set neon.safekeepers='{safekeepers}'") + promo_cur.execute("select pg_reload_conf()") + + promo_cur.execute("SELECT * FROM pg_promote()") + assert promo_cur.fetchone() == (True,) + promo_cur.execute( + """ + SELECT pg_current_wal_insert_lsn(), + pg_current_wal_lsn(), + pg_current_wal_flush_lsn() + """ + ) + log.info(f"Secondary: LSN after promotion is {promo_cur.fetchone()}") + + # Reconnect to the secondary to make sure we get a read-write connection + with secondary.connect() as new_primary_conn: + new_primary_cur = new_primary_conn.cursor() + new_primary_cur.execute("select count(*) from t") + assert new_primary_cur.fetchone() == (100,) + + new_primary_cur.execute( + "INSERT INTO t (payload) SELECT generate_series(101, 200) RETURNING payload" + ) + assert new_primary_cur.fetchall() == [(it,) for it in range(101, 201)] + + new_primary_cur = new_primary_conn.cursor() + new_primary_cur.execute("select payload from t") + assert new_primary_cur.fetchall() == [(it,) for it in range(1, 201)] + + new_primary_cur.execute("select count(*) from t") + assert new_primary_cur.fetchone() == (200,) + new_primary_cur.execute( + """ + SELECT pg_current_wal_insert_lsn(), + pg_current_wal_lsn(), + pg_current_wal_flush_lsn() + """ + ) + log.info(f"Secondary: LSN after workload is {new_primary_cur.fetchone()}") + + with secondary.connect() as second_viewpoint_conn: + new_primary_cur = second_viewpoint_conn.cursor() + new_primary_cur.execute("select payload from t") + assert new_primary_cur.fetchall() == [(it,) for it in range(1, 201)] + + # wait_for_last_flush_lsn(env, secondary, env.initial_tenant, env.initial_timeline) + + secondary.stop_and_destroy() + + primary = env.endpoints.create_start(branch_name="main", endpoint_id="primary") + + with primary.connect() as new_primary: + new_primary_cur = new_primary.cursor() + new_primary_cur.execute( + """ + SELECT pg_current_wal_insert_lsn(), + pg_current_wal_lsn(), + pg_current_wal_flush_lsn() + """ + ) + log.info(f"New primary: Boot LSN is {new_primary_cur.fetchone()}") + + new_primary_cur.execute("select count(*) from t") + assert new_primary_cur.fetchone() == (200,) + new_primary_cur.execute("INSERT INTO t (payload) SELECT generate_series(201, 300)") + new_primary_cur.execute("select count(*) from t") + assert new_primary_cur.fetchone() == (300,) + + primary.stop(mode="immediate") diff --git a/vendor/postgres-v14 b/vendor/postgres-v14 index 55c0d45abe..6770bc2513 160000 --- a/vendor/postgres-v14 +++ b/vendor/postgres-v14 @@ -1 +1 @@ -Subproject commit 55c0d45abe6467c02084c2192bca117eda6ce1e7 +Subproject commit 6770bc251301ef40c66f7ecb731741dc435b5051 diff --git a/vendor/postgres-v15 b/vendor/postgres-v15 index de7640f55d..8c3249f36c 160000 --- a/vendor/postgres-v15 +++ b/vendor/postgres-v15 @@ -1 +1 @@ -Subproject commit de7640f55da07512834d5cc40c4b3fb376b5f04f +Subproject commit 8c3249f36c7df6ac0efb8ee9f1baf4aa1b83e5c9 diff --git a/vendor/postgres-v16 b/vendor/postgres-v16 index 0bf96bd6d7..7a4c0eacae 160000 --- a/vendor/postgres-v16 +++ b/vendor/postgres-v16 @@ -1 +1 @@ -Subproject commit 0bf96bd6d70301a0b43b0b3457bb3cf8fb43c198 +Subproject commit 7a4c0eacaeb9b97416542fa19103061c166460b1 diff --git a/vendor/postgres-v17 b/vendor/postgres-v17 index 8be779fd3a..db424d42d7 160000 --- a/vendor/postgres-v17 +++ b/vendor/postgres-v17 @@ -1 +1 @@ -Subproject commit 8be779fd3ab9e87206da96a7e4842ef1abf04f44 +Subproject commit db424d42d748f8ad91ac00e28db2c7f2efa42f7f diff --git a/vendor/revisions.json b/vendor/revisions.json index 3e999760f4..12d5499ddb 100644 --- a/vendor/revisions.json +++ b/vendor/revisions.json @@ -1,18 +1,18 @@ { "v17": [ "17.5", - "8be779fd3ab9e87206da96a7e4842ef1abf04f44" + "db424d42d748f8ad91ac00e28db2c7f2efa42f7f" ], "v16": [ "16.9", - "0bf96bd6d70301a0b43b0b3457bb3cf8fb43c198" + "7a4c0eacaeb9b97416542fa19103061c166460b1" ], "v15": [ "15.13", - "de7640f55da07512834d5cc40c4b3fb376b5f04f" + "8c3249f36c7df6ac0efb8ee9f1baf4aa1b83e5c9" ], "v14": [ "14.18", - "55c0d45abe6467c02084c2192bca117eda6ce1e7" + "6770bc251301ef40c66f7ecb731741dc435b5051" ] } From 868f194a3b7cf2463c86e1fbd0e83360a40d31ef Mon Sep 17 00:00:00 2001 From: Vlad Lazar Date: Thu, 5 Jun 2025 12:43:04 +0100 Subject: [PATCH 39/43] pageserver: remove handling of vanilla protocol (#12126) ## Problem We support two ingest protocols on the pageserver: vanilla and interpreted. Interpreted has been the only protocol in use for a long time. ## Summary of changes * Remove the ingest handling of the vanilla protocol * Remove tenant and pageserver configuration for it * Update all tests that tweaked the ingest protocol ## Compatibility Backward compatibility: * The new pageserver version can read the existing pageserver configuration and it will ignore the unknown field. * When the tenant config is read from the storcon db or from the pageserver disk, the extra field will be ignored. Forward compatiblity: * Both the pageserver config and the tenant config map missing fields to their default value. I'm not aware of any tenant level override that was made for this knob. --- control_plane/src/pageserver.rs | 5 - libs/pageserver_api/src/config.rs | 12 -- libs/pageserver_api/src/models.rs | 14 -- pageserver/src/bin/pageserver.rs | 1 - pageserver/src/config.rs | 5 - pageserver/src/metrics.rs | 6 - pageserver/src/tenant/timeline.rs | 23 +-- .../walreceiver/connection_manager.rs | 32 ++-- .../walreceiver/walreceiver_connection.rs | 161 +----------------- test_runner/fixtures/neon_fixtures.py | 39 ----- .../performance/test_sharded_ingest.py | 39 ----- .../regress/test_attach_tenant_config.py | 4 - test_runner/regress/test_compaction.py | 8 - test_runner/regress/test_crafted_wal_end.py | 13 +- test_runner/regress/test_subxacts.py | 10 +- test_runner/regress/test_tenant_conf.py | 8 +- .../regress/test_wal_acceptor_async.py | 10 +- 17 files changed, 39 insertions(+), 351 deletions(-) diff --git a/control_plane/src/pageserver.rs b/control_plane/src/pageserver.rs index 29314dab9e..0cf7ca184d 100644 --- a/control_plane/src/pageserver.rs +++ b/control_plane/src/pageserver.rs @@ -513,11 +513,6 @@ impl PageServerNode { .map(|x| x.parse::()) .transpose() .context("Failed to parse 'timeline_offloading' as bool")?, - wal_receiver_protocol_override: settings - .remove("wal_receiver_protocol_override") - .map(serde_json::from_str) - .transpose() - .context("parse `wal_receiver_protocol_override` from json")?, rel_size_v2_enabled: settings .remove("rel_size_v2_enabled") .map(|x| x.parse::()) diff --git a/libs/pageserver_api/src/config.rs b/libs/pageserver_api/src/config.rs index 46903965b1..30b0612082 100644 --- a/libs/pageserver_api/src/config.rs +++ b/libs/pageserver_api/src/config.rs @@ -20,7 +20,6 @@ use postgres_backend::AuthType; use remote_storage::RemoteStorageConfig; use serde_with::serde_as; use utils::logging::LogFormat; -use utils::postgres_client::PostgresClientProtocol; use crate::models::{ImageCompressionAlgorithm, LsnLease}; @@ -189,7 +188,6 @@ pub struct ConfigToml { pub virtual_file_io_mode: Option, #[serde(skip_serializing_if = "Option::is_none")] pub no_sync: Option, - pub wal_receiver_protocol: PostgresClientProtocol, pub page_service_pipelining: PageServicePipeliningConfig, pub get_vectored_concurrent_io: GetVectoredConcurrentIo, pub enable_read_path_debugging: Option, @@ -527,8 +525,6 @@ pub struct TenantConfigToml { /// (either this flag or the pageserver-global one need to be set) pub timeline_offloading: bool, - pub wal_receiver_protocol_override: Option, - /// Enable rel_size_v2 for this tenant. Once enabled, the tenant will persist this information into /// `index_part.json`, and it cannot be reversed. pub rel_size_v2_enabled: bool, @@ -609,12 +605,6 @@ pub mod defaults { pub const DEFAULT_IO_BUFFER_ALIGNMENT: usize = 512; - pub const DEFAULT_WAL_RECEIVER_PROTOCOL: utils::postgres_client::PostgresClientProtocol = - utils::postgres_client::PostgresClientProtocol::Interpreted { - format: utils::postgres_client::InterpretedFormat::Protobuf, - compression: Some(utils::postgres_client::Compression::Zstd { level: 1 }), - }; - pub const DEFAULT_SSL_KEY_FILE: &str = "server.key"; pub const DEFAULT_SSL_CERT_FILE: &str = "server.crt"; } @@ -713,7 +703,6 @@ impl Default for ConfigToml { virtual_file_io_mode: None, tenant_config: TenantConfigToml::default(), no_sync: None, - wal_receiver_protocol: DEFAULT_WAL_RECEIVER_PROTOCOL, page_service_pipelining: PageServicePipeliningConfig::Pipelined( PageServicePipeliningConfigPipelined { max_batch_size: NonZeroUsize::new(32).unwrap(), @@ -858,7 +847,6 @@ impl Default for TenantConfigToml { lsn_lease_length: LsnLease::DEFAULT_LENGTH, lsn_lease_length_for_ts: LsnLease::DEFAULT_LENGTH_FOR_TS, timeline_offloading: true, - wal_receiver_protocol_override: None, rel_size_v2_enabled: false, gc_compaction_enabled: DEFAULT_GC_COMPACTION_ENABLED, gc_compaction_verification: DEFAULT_GC_COMPACTION_VERIFICATION, diff --git a/libs/pageserver_api/src/models.rs b/libs/pageserver_api/src/models.rs index 28ced4a368..881f24b86c 100644 --- a/libs/pageserver_api/src/models.rs +++ b/libs/pageserver_api/src/models.rs @@ -20,7 +20,6 @@ use serde_with::serde_as; pub use utilization::PageserverUtilization; use utils::id::{NodeId, TenantId, TimelineId}; use utils::lsn::Lsn; -use utils::postgres_client::PostgresClientProtocol; use utils::{completion, serde_system_time}; use crate::config::Ratio; @@ -622,8 +621,6 @@ pub struct TenantConfigPatch { #[serde(skip_serializing_if = "FieldPatch::is_noop")] pub timeline_offloading: FieldPatch, #[serde(skip_serializing_if = "FieldPatch::is_noop")] - pub wal_receiver_protocol_override: FieldPatch, - #[serde(skip_serializing_if = "FieldPatch::is_noop")] pub rel_size_v2_enabled: FieldPatch, #[serde(skip_serializing_if = "FieldPatch::is_noop")] pub gc_compaction_enabled: FieldPatch, @@ -748,9 +745,6 @@ pub struct TenantConfig { #[serde(skip_serializing_if = "Option::is_none")] pub timeline_offloading: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub wal_receiver_protocol_override: Option, - #[serde(skip_serializing_if = "Option::is_none")] pub rel_size_v2_enabled: Option, @@ -812,7 +806,6 @@ impl TenantConfig { mut lsn_lease_length, mut lsn_lease_length_for_ts, mut timeline_offloading, - mut wal_receiver_protocol_override, mut rel_size_v2_enabled, mut gc_compaction_enabled, mut gc_compaction_verification, @@ -905,9 +898,6 @@ impl TenantConfig { .map(|v| humantime::parse_duration(&v))? .apply(&mut lsn_lease_length_for_ts); patch.timeline_offloading.apply(&mut timeline_offloading); - patch - .wal_receiver_protocol_override - .apply(&mut wal_receiver_protocol_override); patch.rel_size_v2_enabled.apply(&mut rel_size_v2_enabled); patch .gc_compaction_enabled @@ -960,7 +950,6 @@ impl TenantConfig { lsn_lease_length, lsn_lease_length_for_ts, timeline_offloading, - wal_receiver_protocol_override, rel_size_v2_enabled, gc_compaction_enabled, gc_compaction_verification, @@ -1058,9 +1047,6 @@ impl TenantConfig { timeline_offloading: self .timeline_offloading .unwrap_or(global_conf.timeline_offloading), - wal_receiver_protocol_override: self - .wal_receiver_protocol_override - .or(global_conf.wal_receiver_protocol_override), rel_size_v2_enabled: self .rel_size_v2_enabled .unwrap_or(global_conf.rel_size_v2_enabled), diff --git a/pageserver/src/bin/pageserver.rs b/pageserver/src/bin/pageserver.rs index be7e634d4c..a1a95ad2d1 100644 --- a/pageserver/src/bin/pageserver.rs +++ b/pageserver/src/bin/pageserver.rs @@ -158,7 +158,6 @@ fn main() -> anyhow::Result<()> { // (maybe we should automate this with a visitor?). info!(?conf.virtual_file_io_engine, "starting with virtual_file IO engine"); info!(?conf.virtual_file_io_mode, "starting with virtual_file IO mode"); - info!(?conf.wal_receiver_protocol, "starting with WAL receiver protocol"); info!(?conf.validate_wal_contiguity, "starting with WAL contiguity validation"); info!(?conf.page_service_pipelining, "starting with page service pipelining config"); info!(?conf.get_vectored_concurrent_io, "starting with get_vectored IO concurrency config"); diff --git a/pageserver/src/config.rs b/pageserver/src/config.rs index 628d4f6021..3492a8d966 100644 --- a/pageserver/src/config.rs +++ b/pageserver/src/config.rs @@ -27,7 +27,6 @@ use reqwest::Url; use storage_broker::Uri; use utils::id::{NodeId, TimelineId}; use utils::logging::{LogFormat, SecretString}; -use utils::postgres_client::PostgresClientProtocol; use crate::tenant::storage_layer::inmemory_layer::IndexEntry; use crate::tenant::{TENANTS_SEGMENT_NAME, TIMELINES_SEGMENT_NAME}; @@ -211,8 +210,6 @@ pub struct PageServerConf { /// Optionally disable disk syncs (unsafe!) pub no_sync: bool, - pub wal_receiver_protocol: PostgresClientProtocol, - pub page_service_pipelining: pageserver_api::config::PageServicePipeliningConfig, pub get_vectored_concurrent_io: pageserver_api::config::GetVectoredConcurrentIo, @@ -421,7 +418,6 @@ impl PageServerConf { virtual_file_io_engine, tenant_config, no_sync, - wal_receiver_protocol, page_service_pipelining, get_vectored_concurrent_io, enable_read_path_debugging, @@ -484,7 +480,6 @@ impl PageServerConf { import_pgdata_upcall_api, import_pgdata_upcall_api_token: import_pgdata_upcall_api_token.map(SecretString::from), import_pgdata_aws_endpoint_url, - wal_receiver_protocol, page_service_pipelining, get_vectored_concurrent_io, tracing, diff --git a/pageserver/src/metrics.rs b/pageserver/src/metrics.rs index 3173ab221f..3eb70ffac2 100644 --- a/pageserver/src/metrics.rs +++ b/pageserver/src/metrics.rs @@ -2855,7 +2855,6 @@ pub(crate) struct WalIngestMetrics { pub(crate) records_received: IntCounter, pub(crate) records_observed: IntCounter, pub(crate) records_committed: IntCounter, - pub(crate) records_filtered: IntCounter, pub(crate) values_committed_metadata_images: IntCounter, pub(crate) values_committed_metadata_deltas: IntCounter, pub(crate) values_committed_data_images: IntCounter, @@ -2911,11 +2910,6 @@ pub(crate) static WAL_INGEST: Lazy = Lazy::new(|| { "Number of WAL records which resulted in writes to pageserver storage" ) .expect("failed to define a metric"), - records_filtered: register_int_counter!( - "pageserver_wal_ingest_records_filtered", - "Number of WAL records filtered out due to sharding" - ) - .expect("failed to define a metric"), values_committed_metadata_images: values_committed.with_label_values(&["metadata", "image"]), values_committed_metadata_deltas: values_committed.with_label_values(&["metadata", "delta"]), values_committed_data_images: values_committed.with_label_values(&["data", "image"]), diff --git a/pageserver/src/tenant/timeline.rs b/pageserver/src/tenant/timeline.rs index 8fdf84b6d3..6798606141 100644 --- a/pageserver/src/tenant/timeline.rs +++ b/pageserver/src/tenant/timeline.rs @@ -2845,21 +2845,6 @@ impl Timeline { ) } - /// Resolve the effective WAL receiver protocol to use for this tenant. - /// - /// Priority order is: - /// 1. Tenant config override - /// 2. Default value for tenant config override - /// 3. Pageserver config override - /// 4. Pageserver config default - pub fn resolve_wal_receiver_protocol(&self) -> PostgresClientProtocol { - let tenant_conf = self.tenant_conf.load().tenant_conf.clone(); - tenant_conf - .wal_receiver_protocol_override - .or(self.conf.default_tenant_conf.wal_receiver_protocol_override) - .unwrap_or(self.conf.wal_receiver_protocol) - } - pub(super) fn tenant_conf_updated(&self, new_conf: &AttachedTenantConf) { // NB: Most tenant conf options are read by background loops, so, // changes will automatically be picked up. @@ -3215,10 +3200,16 @@ impl Timeline { guard.is_none(), "multiple launches / re-launches of WAL receiver are not supported" ); + + let protocol = PostgresClientProtocol::Interpreted { + format: utils::postgres_client::InterpretedFormat::Protobuf, + compression: Some(utils::postgres_client::Compression::Zstd { level: 1 }), + }; + *guard = Some(WalReceiver::start( Arc::clone(self), WalReceiverConf { - protocol: self.resolve_wal_receiver_protocol(), + protocol, wal_connect_timeout, lagging_wal_timeout, max_lsn_wal_lag, diff --git a/pageserver/src/tenant/timeline/walreceiver/connection_manager.rs b/pageserver/src/tenant/timeline/walreceiver/connection_manager.rs index 3c3608d1bd..7e0b0e9b25 100644 --- a/pageserver/src/tenant/timeline/walreceiver/connection_manager.rs +++ b/pageserver/src/tenant/timeline/walreceiver/connection_manager.rs @@ -32,9 +32,7 @@ use utils::backoff::{ }; use utils::id::{NodeId, TenantTimelineId}; use utils::lsn::Lsn; -use utils::postgres_client::{ - ConnectionConfigArgs, PostgresClientProtocol, wal_stream_connection_config, -}; +use utils::postgres_client::{ConnectionConfigArgs, wal_stream_connection_config}; use super::walreceiver_connection::{WalConnectionStatus, WalReceiverError}; use super::{TaskEvent, TaskHandle, TaskStateUpdate, WalReceiverConf}; @@ -991,19 +989,12 @@ impl ConnectionManagerState { return None; // no connection string, ignore sk } - let (shard_number, shard_count, shard_stripe_size) = match self.conf.protocol { - PostgresClientProtocol::Vanilla => { - (None, None, None) - }, - PostgresClientProtocol::Interpreted { .. } => { - let shard_identity = self.timeline.get_shard_identity(); - ( - Some(shard_identity.number.0), - Some(shard_identity.count.0), - Some(shard_identity.stripe_size.0), - ) - } - }; + let shard_identity = self.timeline.get_shard_identity(); + let (shard_number, shard_count, shard_stripe_size) = ( + Some(shard_identity.number.0), + Some(shard_identity.count.0), + Some(shard_identity.stripe_size.0), + ); let connection_conf_args = ConnectionConfigArgs { protocol: self.conf.protocol, @@ -1120,8 +1111,8 @@ impl ReconnectReason { #[cfg(test)] mod tests { - use pageserver_api::config::defaults::DEFAULT_WAL_RECEIVER_PROTOCOL; use url::Host; + use utils::postgres_client::PostgresClientProtocol; use super::*; use crate::tenant::harness::{TIMELINE_ID, TenantHarness}; @@ -1552,6 +1543,11 @@ mod tests { .await .expect("Failed to create an empty timeline for dummy wal connection manager"); + let protocol = PostgresClientProtocol::Interpreted { + format: utils::postgres_client::InterpretedFormat::Protobuf, + compression: Some(utils::postgres_client::Compression::Zstd { level: 1 }), + }; + ConnectionManagerState { id: TenantTimelineId { tenant_id: harness.tenant_shard_id.tenant_id, @@ -1560,7 +1556,7 @@ mod tests { timeline, cancel: CancellationToken::new(), conf: WalReceiverConf { - protocol: DEFAULT_WAL_RECEIVER_PROTOCOL, + protocol, wal_connect_timeout: Duration::from_secs(1), lagging_wal_timeout: Duration::from_secs(1), max_lsn_wal_lag: NonZeroU64::new(1024 * 1024).unwrap(), diff --git a/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs b/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs index 249849ac4b..343e04f5f0 100644 --- a/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs +++ b/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs @@ -15,7 +15,7 @@ use postgres_backend::is_expected_io_error; use postgres_connection::PgConnectionConfig; use postgres_ffi::WAL_SEGMENT_SIZE; use postgres_ffi::v14::xlog_utils::normalize_lsn; -use postgres_ffi::waldecoder::{WalDecodeError, WalStreamDecoder}; +use postgres_ffi::waldecoder::WalDecodeError; use postgres_protocol::message::backend::ReplicationMessage; use postgres_types::PgLsn; use tokio::sync::watch; @@ -31,7 +31,7 @@ use utils::lsn::Lsn; use utils::pageserver_feedback::PageserverFeedback; use utils::postgres_client::PostgresClientProtocol; use utils::sync::gate::GateError; -use wal_decoder::models::{FlushUncommittedRecords, InterpretedWalRecord, InterpretedWalRecords}; +use wal_decoder::models::{FlushUncommittedRecords, InterpretedWalRecords}; use wal_decoder::wire_format::FromWireFormat; use super::TaskStateUpdate; @@ -275,8 +275,6 @@ pub(super) async fn handle_walreceiver_connection( let copy_stream = replication_client.copy_both_simple(&query).await?; let mut physical_stream = pin!(ReplicationStream::new(copy_stream)); - let mut waldecoder = WalStreamDecoder::new(startpoint, timeline.pg_version); - let mut walingest = WalIngest::new(timeline.as_ref(), startpoint, &ctx) .await .map_err(|e| match e.kind { @@ -284,14 +282,16 @@ pub(super) async fn handle_walreceiver_connection( _ => WalReceiverError::Other(e.into()), })?; - let shard = vec![*timeline.get_shard_identity()]; - - let interpreted_proto_config = match protocol { - PostgresClientProtocol::Vanilla => None, + let (format, compression) = match protocol { PostgresClientProtocol::Interpreted { format, compression, - } => Some((format, compression)), + } => (format, compression), + PostgresClientProtocol::Vanilla => { + return Err(WalReceiverError::Other(anyhow!( + "Vanilla WAL receiver protocol is no longer supported for ingest" + ))); + } }; let mut expected_wal_start = startpoint; @@ -313,16 +313,6 @@ pub(super) async fn handle_walreceiver_connection( // Update the connection status before processing the message. If the message processing // fails (e.g. in walingest), we still want to know latests LSNs from the safekeeper. match &replication_message { - ReplicationMessage::XLogData(xlog_data) => { - connection_status.latest_connection_update = now; - connection_status.commit_lsn = Some(Lsn::from(xlog_data.wal_end())); - connection_status.streaming_lsn = Some(Lsn::from( - xlog_data.wal_start() + xlog_data.data().len() as u64, - )); - if !xlog_data.data().is_empty() { - connection_status.latest_wal_update = now; - } - } ReplicationMessage::PrimaryKeepAlive(keepalive) => { connection_status.latest_connection_update = now; connection_status.commit_lsn = Some(Lsn::from(keepalive.wal_end())); @@ -353,7 +343,6 @@ pub(super) async fn handle_walreceiver_connection( // were interpreted. let streaming_lsn = Lsn::from(raw.streaming_lsn()); - let (format, compression) = interpreted_proto_config.unwrap(); let batch = InterpretedWalRecords::from_wire(raw.data(), format, compression) .await .with_context(|| { @@ -509,138 +498,6 @@ pub(super) async fn handle_walreceiver_connection( Some(streaming_lsn) } - ReplicationMessage::XLogData(xlog_data) => { - async fn commit( - modification: &mut DatadirModification<'_>, - uncommitted: &mut u64, - filtered: &mut u64, - ctx: &RequestContext, - ) -> anyhow::Result<()> { - let stats = modification.stats(); - modification.commit(ctx).await?; - WAL_INGEST - .records_committed - .inc_by(*uncommitted - *filtered); - WAL_INGEST.inc_values_committed(&stats); - *uncommitted = 0; - *filtered = 0; - Ok(()) - } - - // Pass the WAL data to the decoder, and see if we can decode - // more records as a result. - let data = xlog_data.data(); - let startlsn = Lsn::from(xlog_data.wal_start()); - let endlsn = startlsn + data.len() as u64; - - trace!("received XLogData between {startlsn} and {endlsn}"); - - WAL_INGEST.bytes_received.inc_by(data.len() as u64); - waldecoder.feed_bytes(data); - - { - let mut modification = timeline.begin_modification(startlsn); - let mut uncommitted_records = 0; - let mut filtered_records = 0; - - while let Some((next_record_lsn, recdata)) = waldecoder.poll_decode()? { - // It is important to deal with the aligned records as lsn in getPage@LSN is - // aligned and can be several bytes bigger. Without this alignment we are - // at risk of hitting a deadlock. - if !next_record_lsn.is_aligned() { - return Err(WalReceiverError::Other(anyhow!("LSN not aligned"))); - } - - // Deserialize and interpret WAL record - let interpreted = InterpretedWalRecord::from_bytes_filtered( - recdata, - &shard, - next_record_lsn, - modification.tline.pg_version, - )? - .remove(timeline.get_shard_identity()) - .unwrap(); - - if matches!(interpreted.flush_uncommitted, FlushUncommittedRecords::Yes) - && uncommitted_records > 0 - { - // Special case: legacy PG database creations operate by reading pages from a 'template' database: - // these are the only kinds of WAL record that require reading data blocks while ingesting. Ensure - // all earlier writes of data blocks are visible by committing any modification in flight. - commit( - &mut modification, - &mut uncommitted_records, - &mut filtered_records, - &ctx, - ) - .await?; - } - - // Ingest the records without immediately committing them. - timeline.metrics.wal_records_received.inc(); - let ingested = walingest - .ingest_record(interpreted, &mut modification, &ctx) - .await - .with_context(|| { - format!("could not ingest record at {next_record_lsn}") - }) - .inspect_err(|err| { - // TODO: we can't differentiate cancellation errors with - // anyhow::Error, so just ignore it if we're cancelled. - if !cancellation.is_cancelled() && !timeline.is_stopping() { - critical!("{err:?}") - } - })?; - if !ingested { - tracing::debug!("ingest: filtered out record @ LSN {next_record_lsn}"); - WAL_INGEST.records_filtered.inc(); - filtered_records += 1; - } - - // FIXME: this cannot be made pausable_failpoint without fixing the - // failpoint library; in tests, the added amount of debugging will cause us - // to timeout the tests. - fail_point!("walreceiver-after-ingest"); - - last_rec_lsn = next_record_lsn; - - // Commit every ingest_batch_size records. Even if we filtered out - // all records, we still need to call commit to advance the LSN. - uncommitted_records += 1; - if uncommitted_records >= ingest_batch_size - || modification.approx_pending_bytes() - > DatadirModification::MAX_PENDING_BYTES - { - commit( - &mut modification, - &mut uncommitted_records, - &mut filtered_records, - &ctx, - ) - .await?; - } - } - - // Commit the remaining records. - if uncommitted_records > 0 { - commit( - &mut modification, - &mut uncommitted_records, - &mut filtered_records, - &ctx, - ) - .await?; - } - } - - if !caught_up && endlsn >= end_of_wal { - info!("caught up at LSN {endlsn}"); - caught_up = true; - } - - Some(endlsn) - } - ReplicationMessage::PrimaryKeepAlive(keepalive) => { let wal_end = keepalive.wal_end(); let timestamp = keepalive.timestamp(); diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py index f9337bed89..db3f080261 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -357,31 +357,6 @@ class PgProtocol: return TimelineId(cast("str", self.safe_psql("show neon.timeline_id")[0][0])) -class PageserverWalReceiverProtocol(StrEnum): - VANILLA = "vanilla" - INTERPRETED = "interpreted" - - @staticmethod - def to_config_key_value(proto) -> tuple[str, dict[str, Any]]: - if proto == PageserverWalReceiverProtocol.VANILLA: - return ( - "wal_receiver_protocol", - { - "type": "vanilla", - }, - ) - elif proto == PageserverWalReceiverProtocol.INTERPRETED: - return ( - "wal_receiver_protocol", - { - "type": "interpreted", - "args": {"format": "protobuf", "compression": {"zstd": {"level": 1}}}, - }, - ) - else: - raise ValueError(f"Unknown protocol type: {proto}") - - @dataclass class PageserverTracingConfig: sampling_ratio: tuple[int, int] @@ -475,7 +450,6 @@ class NeonEnvBuilder: safekeeper_extra_opts: list[str] | None = None, storage_controller_port_override: int | None = None, pageserver_virtual_file_io_mode: str | None = None, - pageserver_wal_receiver_protocol: PageserverWalReceiverProtocol | None = None, pageserver_get_vectored_concurrent_io: str | None = None, pageserver_tracing_config: PageserverTracingConfig | None = None, pageserver_import_config: PageserverImportConfig | None = None, @@ -552,11 +526,6 @@ class NeonEnvBuilder: self.pageserver_virtual_file_io_mode = pageserver_virtual_file_io_mode - if pageserver_wal_receiver_protocol is not None: - self.pageserver_wal_receiver_protocol = pageserver_wal_receiver_protocol - else: - self.pageserver_wal_receiver_protocol = PageserverWalReceiverProtocol.INTERPRETED - assert test_name.startswith("test_"), ( "Unexpectedly instantiated from outside a test function" ) @@ -1202,7 +1171,6 @@ class NeonEnv: self.pageserver_virtual_file_io_engine = config.pageserver_virtual_file_io_engine self.pageserver_virtual_file_io_mode = config.pageserver_virtual_file_io_mode - self.pageserver_wal_receiver_protocol = config.pageserver_wal_receiver_protocol self.pageserver_get_vectored_concurrent_io = config.pageserver_get_vectored_concurrent_io self.pageserver_tracing_config = config.pageserver_tracing_config if config.pageserver_import_config is None: @@ -1334,13 +1302,6 @@ class NeonEnv: for key, value in override.items(): ps_cfg[key] = value - if self.pageserver_wal_receiver_protocol is not None: - key, value = PageserverWalReceiverProtocol.to_config_key_value( - self.pageserver_wal_receiver_protocol - ) - if key not in ps_cfg: - ps_cfg[key] = value - if self.pageserver_tracing_config is not None: key, value = self.pageserver_tracing_config.to_config_key_value() diff --git a/test_runner/performance/test_sharded_ingest.py b/test_runner/performance/test_sharded_ingest.py index 293026d40a..364fcf3737 100644 --- a/test_runner/performance/test_sharded_ingest.py +++ b/test_runner/performance/test_sharded_ingest.py @@ -15,19 +15,10 @@ from fixtures.neon_fixtures import ( @pytest.mark.timeout(1200) @pytest.mark.parametrize("shard_count", [1, 8, 32]) -@pytest.mark.parametrize( - "wal_receiver_protocol", - [ - "vanilla", - "interpreted-bincode-compressed", - "interpreted-protobuf-compressed", - ], -) def test_sharded_ingest( neon_env_builder: NeonEnvBuilder, zenbenchmark: NeonBenchmarker, shard_count: int, - wal_receiver_protocol: str, ): """ Benchmarks sharded ingestion throughput, by ingesting a large amount of WAL into a Safekeeper @@ -39,36 +30,6 @@ def test_sharded_ingest( neon_env_builder.num_pageservers = shard_count env = neon_env_builder.init_configs() - for ps in env.pageservers: - if wal_receiver_protocol == "vanilla": - ps.patch_config_toml_nonrecursive( - { - "wal_receiver_protocol": { - "type": "vanilla", - } - } - ) - elif wal_receiver_protocol == "interpreted-bincode-compressed": - ps.patch_config_toml_nonrecursive( - { - "wal_receiver_protocol": { - "type": "interpreted", - "args": {"format": "bincode", "compression": {"zstd": {"level": 1}}}, - } - } - ) - elif wal_receiver_protocol == "interpreted-protobuf-compressed": - ps.patch_config_toml_nonrecursive( - { - "wal_receiver_protocol": { - "type": "interpreted", - "args": {"format": "protobuf", "compression": {"zstd": {"level": 1}}}, - } - } - ) - else: - raise AssertionError("Test must use explicit wal receiver protocol config") - env.start() # Create a sharded tenant and timeline, and migrate it to the respective pageservers. Ensure diff --git a/test_runner/regress/test_attach_tenant_config.py b/test_runner/regress/test_attach_tenant_config.py index 3eb6b7193c..dc44fc77db 100644 --- a/test_runner/regress/test_attach_tenant_config.py +++ b/test_runner/regress/test_attach_tenant_config.py @@ -182,10 +182,6 @@ def test_fully_custom_config(positive_env: NeonEnv): "lsn_lease_length": "1m", "lsn_lease_length_for_ts": "5s", "timeline_offloading": False, - "wal_receiver_protocol_override": { - "type": "interpreted", - "args": {"format": "bincode", "compression": {"zstd": {"level": 1}}}, - }, "rel_size_v2_enabled": True, "relsize_snapshot_cache_capacity": 10000, "gc_compaction_enabled": True, diff --git a/test_runner/regress/test_compaction.py b/test_runner/regress/test_compaction.py index 370f57b19d..1570d40ae9 100644 --- a/test_runner/regress/test_compaction.py +++ b/test_runner/regress/test_compaction.py @@ -10,7 +10,6 @@ import pytest from fixtures.log_helper import log from fixtures.neon_fixtures import ( NeonEnvBuilder, - PageserverWalReceiverProtocol, generate_uploads_and_deletions, ) from fixtures.pageserver.http import PageserverApiException @@ -68,14 +67,9 @@ PREEMPT_GC_COMPACTION_TENANT_CONF = { @skip_in_debug_build("only run with release build") -@pytest.mark.parametrize( - "wal_receiver_protocol", - [PageserverWalReceiverProtocol.VANILLA, PageserverWalReceiverProtocol.INTERPRETED], -) @pytest.mark.timeout(900) def test_pageserver_compaction_smoke( neon_env_builder: NeonEnvBuilder, - wal_receiver_protocol: PageserverWalReceiverProtocol, ): """ This is a smoke test that compaction kicks in. The workload repeatedly churns @@ -85,8 +79,6 @@ def test_pageserver_compaction_smoke( observed bounds. """ - neon_env_builder.pageserver_wal_receiver_protocol = wal_receiver_protocol - # Effectively disable the page cache to rely only on image layers # to shorten reads. neon_env_builder.pageserver_config_override = """ diff --git a/test_runner/regress/test_crafted_wal_end.py b/test_runner/regress/test_crafted_wal_end.py index 6b9dcbba07..89ff873ca3 100644 --- a/test_runner/regress/test_crafted_wal_end.py +++ b/test_runner/regress/test_crafted_wal_end.py @@ -1,9 +1,13 @@ from __future__ import annotations +from typing import TYPE_CHECKING + import pytest from fixtures.log_helper import log from fixtures.neon_cli import WalCraft -from fixtures.neon_fixtures import NeonEnvBuilder, PageserverWalReceiverProtocol + +if TYPE_CHECKING: + from fixtures.neon_fixtures import NeonEnvBuilder # Restart nodes with WAL end having specially crafted shape, like last record # crossing segment boundary, to test decoding issues. @@ -19,17 +23,10 @@ from fixtures.neon_fixtures import NeonEnvBuilder, PageserverWalReceiverProtocol "wal_record_crossing_segment_followed_by_small_one", ], ) -@pytest.mark.parametrize( - "wal_receiver_protocol", - [PageserverWalReceiverProtocol.VANILLA, PageserverWalReceiverProtocol.INTERPRETED], -) def test_crafted_wal_end( neon_env_builder: NeonEnvBuilder, wal_type: str, - wal_receiver_protocol: PageserverWalReceiverProtocol, ): - neon_env_builder.pageserver_wal_receiver_protocol = wal_receiver_protocol - env = neon_env_builder.init_start() env.create_branch("test_crafted_wal_end") env.pageserver.allowed_errors.extend( diff --git a/test_runner/regress/test_subxacts.py b/test_runner/regress/test_subxacts.py index b235da0bc7..92a21007c8 100644 --- a/test_runner/regress/test_subxacts.py +++ b/test_runner/regress/test_subxacts.py @@ -1,9 +1,7 @@ from __future__ import annotations -import pytest from fixtures.neon_fixtures import ( NeonEnvBuilder, - PageserverWalReceiverProtocol, check_restored_datadir_content, ) @@ -14,13 +12,7 @@ from fixtures.neon_fixtures import ( # maintained in the pageserver, so subtransactions are not very exciting for # Neon. They are included in the commit record though and updated in the # CLOG. -@pytest.mark.parametrize( - "wal_receiver_protocol", - [PageserverWalReceiverProtocol.VANILLA, PageserverWalReceiverProtocol.INTERPRETED], -) -def test_subxacts(neon_env_builder: NeonEnvBuilder, test_output_dir, wal_receiver_protocol): - neon_env_builder.pageserver_wal_receiver_protocol = wal_receiver_protocol - +def test_subxacts(neon_env_builder: NeonEnvBuilder, test_output_dir): env = neon_env_builder.init_start() endpoint = env.endpoints.create_start("main") diff --git a/test_runner/regress/test_tenant_conf.py b/test_runner/regress/test_tenant_conf.py index de6bdc0aec..d78b9d8817 100644 --- a/test_runner/regress/test_tenant_conf.py +++ b/test_runner/regress/test_tenant_conf.py @@ -348,7 +348,6 @@ def test_tenant_config_patch(neon_env_builder: NeonEnvBuilder, ps_managed_by: st def assert_tenant_conf_semantically_equal(lhs, rhs): """ - Storcon returns None for fields that are not set while the pageserver does not. Compare two tenant's config overrides semantically, by dropping the None values. """ lhs = {k: v for k, v in lhs.items() if v is not None} @@ -375,10 +374,7 @@ def test_tenant_config_patch(neon_env_builder: NeonEnvBuilder, ps_managed_by: st patch: dict[str, Any | None] = { "gc_period": "3h", - "wal_receiver_protocol_override": { - "type": "interpreted", - "args": {"format": "bincode", "compression": {"zstd": {"level": 1}}}, - }, + "gc_compaction_ratio_percent": 10, } api.patch_tenant_config(env.initial_tenant, patch) tenant_conf_after_patch = api.tenant_config(env.initial_tenant).tenant_specific_overrides @@ -391,7 +387,7 @@ def test_tenant_config_patch(neon_env_builder: NeonEnvBuilder, ps_managed_by: st assert_tenant_conf_semantically_equal(tenant_conf_after_patch, crnt_tenant_conf | patch) crnt_tenant_conf = tenant_conf_after_patch - patch = {"gc_period": "5h", "wal_receiver_protocol_override": None} + patch = {"gc_period": "5h", "gc_compaction_ratio_percent": None} api.patch_tenant_config(env.initial_tenant, patch) tenant_conf_after_patch = api.tenant_config(env.initial_tenant).tenant_specific_overrides if ps_managed_by == "storcon": diff --git a/test_runner/regress/test_wal_acceptor_async.py b/test_runner/regress/test_wal_acceptor_async.py index 4070f99568..d8a7dc2a2b 100644 --- a/test_runner/regress/test_wal_acceptor_async.py +++ b/test_runner/regress/test_wal_acceptor_async.py @@ -14,7 +14,6 @@ from fixtures.neon_fixtures import ( Endpoint, NeonEnv, NeonEnvBuilder, - PageserverWalReceiverProtocol, Safekeeper, ) from fixtures.remote_storage import RemoteStorageKind @@ -751,15 +750,8 @@ async def run_segment_init_failure(env: NeonEnv): # Test (injected) failure during WAL segment init. # https://github.com/neondatabase/neon/issues/6401 # https://github.com/neondatabase/neon/issues/6402 -@pytest.mark.parametrize( - "wal_receiver_protocol", - [PageserverWalReceiverProtocol.VANILLA, PageserverWalReceiverProtocol.INTERPRETED], -) -def test_segment_init_failure( - neon_env_builder: NeonEnvBuilder, wal_receiver_protocol: PageserverWalReceiverProtocol -): +def test_segment_init_failure(neon_env_builder: NeonEnvBuilder): neon_env_builder.num_safekeepers = 1 - neon_env_builder.pageserver_wal_receiver_protocol = wal_receiver_protocol env = neon_env_builder.init_start() asyncio.run(run_segment_init_failure(env)) From 6a43f23ecaeb12385dc7b7bc5b1b4a5f52047b37 Mon Sep 17 00:00:00 2001 From: Erik Grinaker Date: Thu, 5 Jun 2025 13:52:52 +0200 Subject: [PATCH 40/43] pagebench: add batch support (#12133) ## Problem The new gRPC page service protocol supports client-side batches. The current libpq protocol only does best-effort server-side batching. To compare these approaches, Pagebench should support submitting contiguous page batches, similar to how Postgres will submit them (e.g. with prefetches or vectored reads). ## Summary of changes Add a `--batch-size` parameter specifying the size of contiguous page batches. One batch counts as 1 RPS and 1 queue depth. For the libpq protocol, a batch is submitted as individual requests and we rely on the server to batch them for us. This will give a realistic comparison of how these would be processed in the wild (e.g. when Postgres sends 100 prefetch requests). This patch also adds some basic validation of responses. --- Cargo.lock | 1 + pageserver/pagebench/Cargo.toml | 1 + .../pagebench/src/cmd/getpage_latest_lsn.rs | 195 ++++++++++++++---- 3 files changed, 152 insertions(+), 45 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 542a4528d3..588a63b6a3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4237,6 +4237,7 @@ version = "0.1.0" dependencies = [ "anyhow", "async-trait", + "bytes", "camino", "clap", "futures", diff --git a/pageserver/pagebench/Cargo.toml b/pageserver/pagebench/Cargo.toml index ceb1278eab..5e4af88e69 100644 --- a/pageserver/pagebench/Cargo.toml +++ b/pageserver/pagebench/Cargo.toml @@ -9,6 +9,7 @@ license.workspace = true [dependencies] anyhow.workspace = true async-trait.workspace = true +bytes.workspace = true camino.workspace = true clap.workspace = true futures.workspace = true diff --git a/pageserver/pagebench/src/cmd/getpage_latest_lsn.rs b/pageserver/pagebench/src/cmd/getpage_latest_lsn.rs index 395e9cac41..3f3b6e396e 100644 --- a/pageserver/pagebench/src/cmd/getpage_latest_lsn.rs +++ b/pageserver/pagebench/src/cmd/getpage_latest_lsn.rs @@ -1,4 +1,4 @@ -use std::collections::{HashSet, VecDeque}; +use std::collections::{HashMap, HashSet, VecDeque}; use std::future::Future; use std::num::NonZeroUsize; use std::pin::Pin; @@ -8,12 +8,12 @@ use std::time::{Duration, Instant}; use anyhow::Context; use async_trait::async_trait; +use bytes::Bytes; use camino::Utf8PathBuf; use pageserver_api::key::Key; use pageserver_api::keyspace::KeySpaceAccum; -use pageserver_api::models::{ - PagestreamGetPageRequest, PagestreamGetPageResponse, PagestreamRequest, -}; +use pageserver_api::models::{PagestreamGetPageRequest, PagestreamRequest}; +use pageserver_api::reltag::RelTag; use pageserver_api::shard::TenantShardId; use pageserver_page_api::proto; use rand::prelude::*; @@ -77,6 +77,16 @@ pub(crate) struct Args { #[clap(long, default_value = "1")] queue_depth: NonZeroUsize, + /// Batch size of contiguous pages generated by each client. This is equivalent to how Postgres + /// will request page batches (e.g. prefetches or vectored reads). A batch counts as 1 RPS and + /// 1 queue depth. + /// + /// The libpq protocol does not support client-side batching, and will submit batches as many + /// individual requests, in the hope that the server will batch them. Each batch still counts as + /// 1 RPS and 1 queue depth. + #[clap(long, default_value = "1")] + batch_size: NonZeroUsize, + #[clap(long)] only_relnode: Option, @@ -392,7 +402,16 @@ async fn run_worker( shared_state.start_work_barrier.wait().await; let client_start = Instant::now(); let mut ticks_processed = 0; - let mut inflight = VecDeque::new(); + let mut req_id = 0; + let batch_size: usize = args.batch_size.into(); + + // Track inflight requests by request ID and start time. This times the request duration, and + // ensures responses match requests. We don't expect responses back in any particular order. + // + // NB: this does not check that all requests received a response, because we don't wait for the + // inflight requests to complete when the duration elapses. + let mut inflight: HashMap = HashMap::new(); + while !cancel.is_cancelled() { // Detect if a request took longer than the RPS rate if let Some(period) = &rps_period { @@ -408,36 +427,72 @@ async fn run_worker( } while inflight.len() < args.queue_depth.get() { + req_id += 1; let start = Instant::now(); - let req = { + let (req_lsn, mod_lsn, rel, blks) = { + /// Converts a compact i128 key to a relation tag and block number. + fn key_to_block(key: i128) -> (RelTag, u32) { + let key = Key::from_i128(key); + assert!(key.is_rel_block_key()); + key.to_rel_block() + .expect("we filter non-rel-block keys out above") + } + + // Pick a random page from a random relation. let mut rng = rand::thread_rng(); let r = &ranges[weights.sample(&mut rng)]; let key: i128 = rng.gen_range(r.start..r.end); - let key = Key::from_i128(key); - assert!(key.is_rel_block_key()); - let (rel_tag, block_no) = key - .to_rel_block() - .expect("we filter non-rel-block keys out above"); - PagestreamGetPageRequest { - hdr: PagestreamRequest { - reqid: 0, - request_lsn: if rng.gen_bool(args.req_latest_probability) { - Lsn::MAX - } else { - r.timeline_lsn - }, - not_modified_since: r.timeline_lsn, - }, - rel: rel_tag, - blkno: block_no, + let (rel_tag, block_no) = key_to_block(key); + + let mut blks = VecDeque::with_capacity(batch_size); + blks.push_back(block_no); + + // If requested, populate a batch of sequential pages. This is how Postgres will + // request page batches (e.g. prefetches). If we hit the end of the relation, we + // grow the batch towards the start too. + for i in 1..batch_size { + let (r, b) = key_to_block(key + i as i128); + if r != rel_tag { + break; // went outside relation + } + blks.push_back(b) } + + if blks.len() < batch_size { + // Grow batch backwards if needed. + for i in 1..batch_size { + let (r, b) = key_to_block(key - i as i128); + if r != rel_tag { + break; // went outside relation + } + blks.push_front(b) + } + } + + // We assume that the entire batch can fit within the relation. + assert_eq!(blks.len(), batch_size, "incomplete batch"); + + let req_lsn = if rng.gen_bool(args.req_latest_probability) { + Lsn::MAX + } else { + r.timeline_lsn + }; + (req_lsn, r.timeline_lsn, rel_tag, blks.into()) }; - client.send_get_page(req).await.unwrap(); - inflight.push_back(start); + client + .send_get_page(req_id, req_lsn, mod_lsn, rel, blks) + .await + .unwrap(); + let old = inflight.insert(req_id, start); + assert!(old.is_none(), "duplicate request ID {req_id}"); } - let start = inflight.pop_front().unwrap(); - client.recv_get_page().await.unwrap(); + let (req_id, pages) = client.recv_get_page().await.unwrap(); + assert_eq!(pages.len(), batch_size, "unexpected page count"); + assert!(pages.iter().all(|p| !p.is_empty()), "empty page"); + let start = inflight + .remove(&req_id) + .expect("response for unknown request ID"); let end = Instant::now(); shared_state.live_stats.request_done(); ticks_processed += 1; @@ -467,15 +522,24 @@ async fn run_worker( #[async_trait] trait Client: Send { /// Sends an asynchronous GetPage request to the pageserver. - async fn send_get_page(&mut self, req: PagestreamGetPageRequest) -> anyhow::Result<()>; + async fn send_get_page( + &mut self, + req_id: u64, + req_lsn: Lsn, + mod_lsn: Lsn, + rel: RelTag, + blks: Vec, + ) -> anyhow::Result<()>; /// Receives the next GetPage response from the pageserver. - async fn recv_get_page(&mut self) -> anyhow::Result; + async fn recv_get_page(&mut self) -> anyhow::Result<(u64, Vec)>; } /// A libpq-based Pageserver client. struct LibpqClient { inner: pageserver_client::page_service::PagestreamClient, + // Track sent batches, so we know how many responses to expect. + batch_sizes: VecDeque, } impl LibpqClient { @@ -484,18 +548,55 @@ impl LibpqClient { .await? .pagestream(ttid.tenant_id, ttid.timeline_id) .await?; - Ok(Self { inner }) + Ok(Self { + inner, + batch_sizes: VecDeque::new(), + }) } } #[async_trait] impl Client for LibpqClient { - async fn send_get_page(&mut self, req: PagestreamGetPageRequest) -> anyhow::Result<()> { - self.inner.getpage_send(req).await + async fn send_get_page( + &mut self, + req_id: u64, + req_lsn: Lsn, + mod_lsn: Lsn, + rel: RelTag, + blks: Vec, + ) -> anyhow::Result<()> { + // libpq doesn't support client-side batches, so we send a bunch of individual requests + // instead in the hope that the server will batch them for us. We use the same request ID + // for all, because we'll return a single batch response. + self.batch_sizes.push_back(blks.len()); + for blkno in blks { + let req = PagestreamGetPageRequest { + hdr: PagestreamRequest { + reqid: req_id, + request_lsn: req_lsn, + not_modified_since: mod_lsn, + }, + rel, + blkno, + }; + self.inner.getpage_send(req).await?; + } + Ok(()) } - async fn recv_get_page(&mut self) -> anyhow::Result { - self.inner.getpage_recv().await + async fn recv_get_page(&mut self) -> anyhow::Result<(u64, Vec)> { + let batch_size = self.batch_sizes.pop_front().unwrap(); + let mut batch = Vec::with_capacity(batch_size); + let mut req_id = None; + for _ in 0..batch_size { + let resp = self.inner.getpage_recv().await?; + if req_id.is_none() { + req_id = Some(resp.req.hdr.reqid); + } + assert_eq!(req_id, Some(resp.req.hdr.reqid), "request ID mismatch"); + batch.push(resp.page); + } + Ok((req_id.unwrap(), batch)) } } @@ -532,31 +633,35 @@ impl GrpcClient { #[async_trait] impl Client for GrpcClient { - async fn send_get_page(&mut self, req: PagestreamGetPageRequest) -> anyhow::Result<()> { + async fn send_get_page( + &mut self, + req_id: u64, + req_lsn: Lsn, + mod_lsn: Lsn, + rel: RelTag, + blks: Vec, + ) -> anyhow::Result<()> { let req = proto::GetPageRequest { - request_id: 0, + request_id: req_id, request_class: proto::GetPageClass::Normal as i32, read_lsn: Some(proto::ReadLsn { - request_lsn: req.hdr.request_lsn.0, - not_modified_since_lsn: req.hdr.not_modified_since.0, + request_lsn: req_lsn.0, + not_modified_since_lsn: mod_lsn.0, }), - rel: Some(req.rel.into()), - block_number: vec![req.blkno], + rel: Some(rel.into()), + block_number: blks, }; self.req_tx.send(req).await?; Ok(()) } - async fn recv_get_page(&mut self) -> anyhow::Result { + async fn recv_get_page(&mut self) -> anyhow::Result<(u64, Vec)> { let resp = self.resp_rx.message().await?.unwrap(); anyhow::ensure!( resp.status_code == proto::GetPageStatusCode::Ok as i32, "unexpected status code: {}", resp.status_code ); - Ok(PagestreamGetPageResponse { - page: resp.page_image[0].clone(), - req: PagestreamGetPageRequest::default(), // dummy - }) + Ok((resp.request_id, resp.page_image)) } } From 038e967daf145af48290eb8560b8a3a43f3579fc Mon Sep 17 00:00:00 2001 From: a-masterov <72613290+a-masterov@users.noreply.github.com> Date: Thu, 5 Jun 2025 14:03:51 +0200 Subject: [PATCH 41/43] Configure the dynamic loader for the extension-tests image (#12142) ## Problem The same problem, fixed in https://github.com/neondatabase/neon/issues/11857, but for the image `neon-extesions-test` ## Summary of changes The config file was added to use our library --- compute/compute-node.Dockerfile | 1 + 1 file changed, 1 insertion(+) diff --git a/compute/compute-node.Dockerfile b/compute/compute-node.Dockerfile index d3008c8085..248f52088b 100644 --- a/compute/compute-node.Dockerfile +++ b/compute/compute-node.Dockerfile @@ -1902,6 +1902,7 @@ COPY compute/patches/pg_repack.patch /ext-src RUN cd /ext-src/pg_repack-src && patch -p1 /etc/ld.so.conf.d/00-neon.conf && /sbin/ldconfig RUN apt-get update && apt-get install -y libtap-parser-sourcehandler-pgtap-perl jq \ && apt clean && rm -rf /ext-src/*.tar.gz /ext-src/*.patch /var/lib/apt/lists/* ENV PATH=/usr/local/pgsql/bin:$PATH From f7ec7668a2a452f23f197ee2670a4d799be20ed4 Mon Sep 17 00:00:00 2001 From: Dmitrii Kovalkov <34828390+DimasKovas@users.noreply.github.com> Date: Thu, 5 Jun 2025 16:04:37 +0400 Subject: [PATCH 42/43] pageserver, tests: prepare test_basebackup_cache for --timelines-onto-safekeepers (#12143) ## Problem - `test_basebackup_cache` fails in https://github.com/neondatabase/neon/pull/11712 because after the timelines on safekeepers are managed by storage controller, they do contain proper start_lsn and the compute_ctl tool sends the first basebackup request with this LSN. - `Failed to prepare basebackup` log messages during timeline initialization, because the timeline is not yet in the global timeline map. - Relates to https://github.com/neondatabase/cloud/issues/29353 ## Summary of changes - Account for `timeline_onto_safekeepers` storcon's option in the test. - Do not trigger basebackup prepare during the timeline initialization. --- pageserver/src/tenant/timeline.rs | 7 ++++++ test_runner/regress/test_basebackup.py | 33 +++++++++++++++++++------- 2 files changed, 31 insertions(+), 9 deletions(-) diff --git a/pageserver/src/tenant/timeline.rs b/pageserver/src/tenant/timeline.rs index 6798606141..3522af2de0 100644 --- a/pageserver/src/tenant/timeline.rs +++ b/pageserver/src/tenant/timeline.rs @@ -2506,6 +2506,13 @@ impl Timeline { // Preparing basebackup doesn't make sense for shards other than shard zero. return; } + if !self.is_active() { + // May happen during initial timeline creation. + // Such timeline is not in the global timeline map yet, + // so basebackup cache will not be able to find it. + // TODO(diko): We can prepare such timelines in finish_creation(). + return; + } let res = self .basebackup_prepare_sender diff --git a/test_runner/regress/test_basebackup.py b/test_runner/regress/test_basebackup.py index b083c394c7..2d42be4051 100644 --- a/test_runner/regress/test_basebackup.py +++ b/test_runner/regress/test_basebackup.py @@ -26,6 +26,10 @@ def test_basebackup_cache(neon_env_builder: NeonEnvBuilder): ps = env.pageserver ps_http = ps.http_client() + storcon_managed_timelines = (env.storage_controller_config or {}).get( + "timelines_onto_safekeepers", False + ) + # 1. Check that we always hit the cache after compute restart. for i in range(3): ep.start() @@ -33,15 +37,26 @@ def test_basebackup_cache(neon_env_builder: NeonEnvBuilder): def check_metrics(i=i): metrics = ps_http.get_metrics() - # Never miss. - # The first time compute_ctl sends `get_basebackup` with lsn=None, we do not cache such requests. - # All other requests should be a hit - assert ( - metrics.query_one( - "pageserver_basebackup_cache_read_total", {"result": "miss"} - ).value - == 0 - ) + if storcon_managed_timelines: + # We do not cache the initial basebackup yet, + # so the first compute startup should be a miss. + assert ( + metrics.query_one( + "pageserver_basebackup_cache_read_total", {"result": "miss"} + ).value + == 1 + ) + else: + # If the timeline is not initialized on safekeeprs, + # the compute_ctl sends `get_basebackup` with lsn=None for the first startup. + # We do not use cache for such requests, so it's niether a hit nor a miss. + assert ( + metrics.query_one( + "pageserver_basebackup_cache_read_total", {"result": "miss"} + ).value + == 0 + ) + # All but the first requests are hits. assert ( metrics.query_one("pageserver_basebackup_cache_read_total", {"result": "hit"}).value From 6ae4b89000b7dd80e5f950023e4f29c384dc926c Mon Sep 17 00:00:00 2001 From: Alexey Kondratov Date: Thu, 5 Jun 2025 14:17:28 +0200 Subject: [PATCH 43/43] feat(compute_ctl): Implement graceful compute monitor exit (#11911) ## Problem After introducing a naive downtime calculation for the Postgres process inside compute in https://github.com/neondatabase/neon/pull/11346, I noticed that some amount of computes regularly report short downtime. After checking some particular cases, it looks like all of them report downtime close to the end of the life of the compute, i.e., when the control plane calls a `/terminate` and we are waiting for Postgres to exit. Compute monitor also produces a lot of error logs because Postgres stops accepting connections, but it's expected during the termination process. ## Summary of changes Regularly check the compute status inside the main compute monitor loop and exit gracefully when the compute is in some terminal or soon-to-be-terminal state. --------- Co-authored-by: Tristan Partin --- compute_tools/src/monitor.rs | 66 ++++++++++++++++++++++++++++++++---- 1 file changed, 59 insertions(+), 7 deletions(-) diff --git a/compute_tools/src/monitor.rs b/compute_tools/src/monitor.rs index 3311ee47b3..bacaf05cd5 100644 --- a/compute_tools/src/monitor.rs +++ b/compute_tools/src/monitor.rs @@ -13,6 +13,12 @@ use crate::metrics::{PG_CURR_DOWNTIME_MS, PG_TOTAL_DOWNTIME_MS}; const MONITOR_CHECK_INTERVAL: Duration = Duration::from_millis(500); +/// Struct to store runtime state of the compute monitor thread. +/// In theory, this could be a part of `Compute`, but i) +/// this state is expected to be accessed only by single thread, +/// so we don't need to care about locking; ii) `Compute` is +/// already quite big. Thus, it seems to be a good idea to keep +/// all the activity/health monitoring parts here. struct ComputeMonitor { compute: Arc, @@ -70,12 +76,36 @@ impl ComputeMonitor { ) } + /// Check if compute is in some terminal or soon-to-be-terminal + /// state, then return `true`, signalling the caller that it + /// should exit gracefully. Otherwise, return `false`. + fn check_interrupts(&mut self) -> bool { + let compute_status = self.compute.get_status(); + if matches!( + compute_status, + ComputeStatus::Terminated | ComputeStatus::TerminationPending | ComputeStatus::Failed + ) { + info!( + "compute is in {} status, stopping compute monitor", + compute_status + ); + return true; + } + + false + } + /// Spin in a loop and figure out the last activity time in the Postgres. - /// Then update it in the shared state. This function never errors out. + /// Then update it in the shared state. This function currently never + /// errors out explicitly, but there is a graceful termination path. + /// Every time we receive an error trying to check Postgres, we use + /// [`ComputeMonitor::check_interrupts()`] because it could be that + /// compute is being terminated already, then we can exit gracefully + /// to not produce errors' noise in the log. /// NB: the only expected panic is at `Mutex` unwrap(), all other errors /// should be handled gracefully. #[instrument(skip_all)] - pub fn run(&mut self) { + pub fn run(&mut self) -> anyhow::Result<()> { // Suppose that `connstr` doesn't change let connstr = self.compute.params.connstr.clone(); let conf = self @@ -93,6 +123,10 @@ impl ComputeMonitor { info!("starting compute monitor for {}", connstr); loop { + if self.check_interrupts() { + break; + } + match &mut client { Ok(cli) => { if cli.is_closed() { @@ -100,6 +134,10 @@ impl ComputeMonitor { downtime_info = self.downtime_info(), "connection to Postgres is closed, trying to reconnect" ); + if self.check_interrupts() { + break; + } + self.report_down(); // Connection is closed, reconnect and try again. @@ -111,15 +149,19 @@ impl ComputeMonitor { self.compute.update_last_active(self.last_active); } Err(e) => { + error!( + downtime_info = self.downtime_info(), + "could not check Postgres: {}", e + ); + if self.check_interrupts() { + break; + } + // Although we have many places where we can return errors in `check()`, // normally it shouldn't happen. I.e., we will likely return error if // connection got broken, query timed out, Postgres returned invalid data, etc. // In all such cases it's suspicious, so let's report this as downtime. self.report_down(); - error!( - downtime_info = self.downtime_info(), - "could not check Postgres: {}", e - ); // Reconnect to Postgres just in case. During tests, I noticed // that queries in `check()` can fail with `connection closed`, @@ -136,6 +178,10 @@ impl ComputeMonitor { downtime_info = self.downtime_info(), "could not connect to Postgres: {}, retrying", e ); + if self.check_interrupts() { + break; + } + self.report_down(); // Establish a new connection and try again. @@ -147,6 +193,9 @@ impl ComputeMonitor { self.last_checked = Utc::now(); thread::sleep(MONITOR_CHECK_INTERVAL); } + + // Graceful termination path + Ok(()) } #[instrument(skip_all)] @@ -429,7 +478,10 @@ pub fn launch_monitor(compute: &Arc) -> thread::JoinHandle<()> { .spawn(move || { let span = span!(Level::INFO, "compute_monitor"); let _enter = span.enter(); - monitor.run(); + match monitor.run() { + Ok(_) => info!("compute monitor thread terminated gracefully"), + Err(err) => error!("compute monitor thread terminated abnormally {:?}", err), + } }) .expect("cannot launch compute monitor thread") }